"""
archive.py
============
This module contains functions for handling archives.
"""
import pathlib
import re
import tarfile
import zipfile
from os.path import join
from urllib.parse import urlparse
from abc import ABC, abstractmethod
from typing import Optional, List, Type, Union
import fsspec
from .model import FilesystemType, VSI_MAPPING
[docs]
class ArchiveError(Exception):
    """Error raised by the archive helpers in this module.

    It is used when a search or extraction finds no matching member and
    when an archive member is of an unrecognised type.
    """
# Metadata record describing a single archive member, as returned by
# ``TarFile.getmembers()`` or ``ZipFile.infolist()``.
InfoFile = Union[tarfile.TarInfo, zipfile.ZipInfo]
# Readable binary file object giving access to the content of one tar or
# zip member.
ExFile = Union[tarfile.ExFileObject, zipfile.ZipExtFile]
[docs]
class Archive(ABC):
    """Abstract, context-managed wrapper around an archive file.

    Subclasses open the underlying archive in ``__enter__``, close it in
    ``__exit__`` and list its members through :meth:`get_members`.
    """

    # Prefix this archive type contributes when vsi paths are assembled.
    vsi_prefix: str = ""

    def __init__(self, path: Union[str, fsspec.spec.AbstractBufferedFile]):
        """Remember where the archive lives; nothing is opened here.

        Args:
            path: Path to the archive or an already opened file object.
        """
        self.path = path

    @abstractmethod
    def __enter__(self):
        """Context manager for archives"""

    @abstractmethod
    def __exit__(self, exc_type, exc_value, traceback):
        """Exits the context"""

    @abstractmethod
    def get_members(self) -> List[InfoFile]:
        """Return a list of archive info files

        Returns:
            List[InfoFile]: List of InfoFiles
        """

    def extract_file(self, member: InfoFile) -> ExFile:
        """Return a readable file object for the content of ``member``.

        The recursive search and extract helpers of this module call this
        method on every archive they traverse. It is deliberately not
        abstract so that subclasses written before it was declared can
        still be instantiated.

        Args:
            member: Entry previously returned by :meth:`get_members`.

        Raises:
            NotImplementedError: If the subclass does not support reading
                members.
        """
        raise NotImplementedError(
            f"{type(self).__name__} does not implement extract_file"
        )
[docs]
class TarArchive(Archive):
    """Handles .tar, .tgz and .tar.gz archives"""

    vsi_prefix = "/vsitar/"

    def __enter__(self):
        """Open the tar archive and return this wrapper."""
        try:
            self.archive = tarfile.open(self.path)
        except TypeError:
            # ``self.path`` was not accepted as a file name; treat it as an
            # already opened file object instead.
            self.archive = tarfile.open(fileobj=self.path)
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        """Close the tar archive opened by ``__enter__``."""
        self.archive.close()

    def get_members(self):
        """Return the ``TarInfo`` entries of the opened archive."""
        return self.archive.getmembers()

    def extract_file(self, member):
        """Return a readable file object for the content of ``member``.

        Args:
            member (tarfile.TarInfo): Entry returned by ``get_members``.

        Raises:
            ArchiveError: If the member has no data stream;
                ``TarFile.extractfile`` returns ``None`` for entries that
                are not regular files or links (e.g. directories).

        Returns:
            Binary file object positioned at the start of the member.
        """
        extracted = self.archive.extractfile(member)
        if extracted is None:
            raise ArchiveError(f"Member {member.name} cannot be read as a file")
        return extracted
[docs]
class ZipArchive(Archive):
    """Handles .zip archives"""

    vsi_prefix = "/vsizip/"

    def __enter__(self):
        """Open the zip archive and return this wrapper."""
        self.archive = zipfile.ZipFile(self.path)
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        """Close the zip archive opened by ``__enter__``."""
        self.archive.close()

    def get_members(self):
        """Return the ``ZipInfo`` entries of the opened archive."""
        return self.archive.infolist()

    def extract_file(self, member):
        """Return a readable file object for the content of ``member``.

        Args:
            member (zipfile.ZipInfo): Entry returned by ``get_members``.

        Returns:
            Binary file object streaming the decompressed member.
        """
        return self.archive.open(member)
[docs]
class GzArchive(Archive):
    """Handles .gz archives

    Placeholder: none of the ``Archive`` methods are implemented here.
    """
def _recursive_archive_search(
    archive: Archive,
    regex: str,
    matches: List[str],
    prefixes: List[str],
    path: Optional[str] = None,
) -> None:
    """Collect members whose file name matches ``regex``, descending into
    nested archives.

    Results are reported through the two list arguments, which are mutated
    in place.

    Args:
        archive: Opened archive to inspect.
        regex: Pattern applied with ``re.search`` to each member's file name
            (final path component only).
        matches: Receives ``path`` joined with the member name for every hit.
        prefixes: Receives the ``vsi_prefix`` of every nested archive that
            contributed at least one match.
        path: Location of ``archive`` used as the base of reported matches;
            derived from ``archive.path`` when omitted.

    NOTE(review): ``prefixes`` is shared by all matches, so hits found at
    different nesting depths all receive the same combined prefix when the
    caller builds the final paths.
    """
    if path is None:
        path = _prepare_path(archive)
    # Compile once instead of handing the raw pattern to ``re`` per member.
    pattern = re.compile(regex)
    for member in archive.get_members():
        f_name = _get_member_name(member)
        file = pathlib.Path(f_name)
        if pattern.search(file.name):
            matches.append(join(path, str(file)))
        elif file.suffix.lower() in ARCHIVE_MAPPING:
            suffix = _get_archive_suffix(file)
            ArchiveClass = _get_archive(suffix)
            new_archive = archive.extract_file(member)
            new_path = join(path, f_name)
            prefix_mark = len(prefixes)
            match_mark = len(matches)
            try:
                with ArchiveClass(new_archive) as arch:
                    prefixes.append(ArchiveClass.vsi_prefix)
                    _recursive_archive_search(
                        arch, regex, matches, prefixes, new_path
                    )
            finally:
                # The archive wrapper does not own the member stream.
                new_archive.close()
            if len(matches) == match_mark:
                # Nothing was found in this nested archive: roll back its
                # prefix so it does not leak into unrelated matches.
                del prefixes[prefix_mark:]
def _recursive_archive_extract(
    archive: Archive,
    regex: str,
    to_path: str,
) -> str:
    """Extract the first member whose file name matches ``regex``.

    Members of ``archive`` are checked first; nested archives are searched
    afterwards, one by one, until one of them yields a match.

    Args:
        archive: Opened archive to inspect.
        regex: Pattern applied with ``re.match`` to each member's file name
            (final path component only).
        to_path: Directory below which the member is written, keeping the
            member's relative path inside the archive.

    Raises:
        ArchiveError: If no member matches, or if the matching member's
            name would place the output outside ``to_path``.

    Returns:
        str: Path of the extracted file.

    NOTE(review): this uses ``re.match`` (anchored at the start of the file
    name) whereas ``_recursive_archive_search`` uses ``re.search`` - confirm
    the difference is intended.
    """
    nested = []
    for member in archive.get_members():
        f_name = _get_member_name(member)
        file = pathlib.Path(f_name)
        if re.match(regex, file.name):
            extract_path = join(to_path, str(file))
            target = pathlib.Path(extract_path)
            # Member names come from the archive and may contain ".." or be
            # absolute; never write outside the requested directory.
            try:
                target.resolve().relative_to(pathlib.Path(to_path).resolve())
            except ValueError:
                raise ArchiveError(
                    f"Refusing to extract {f_name} outside of {to_path}"
                ) from None
            # Obtain the stream before creating the output file so a failed
            # extraction does not leave an empty file behind.
            data = archive.extract_file(member)
            try:
                target.parent.mkdir(parents=True, exist_ok=True)
                with open(extract_path, "wb") as f:
                    # Copy in chunks rather than loading the member at once.
                    while chunk := data.read(1024 * 1024):
                        f.write(chunk)
            finally:
                data.close()
            return extract_path
        elif file.suffix.lower() in ARCHIVE_MAPPING:
            nested.append((file, member))
    last_error = None
    for file, member in nested:
        suffix = _get_archive_suffix(file)
        ArchiveClass = _get_archive(suffix)
        new_archive = archive.extract_file(member)
        try:
            with ArchiveClass(new_archive) as t:
                return _recursive_archive_extract(t, regex, to_path)
        except ArchiveError as err:
            # Not found in this nested archive; keep looking in the others.
            last_error = err
        finally:
            new_archive.close()
    raise ArchiveError(f"Cannot find {regex} in {archive}") from last_error
[docs]
def archive_search(path: str, regex: str) -> List[str]:
    """Recursively look inside an archive for files matching a regex.

    Args:
        path (str): Location of the archive
        regex (str): Regular expression matched against member file names

    Raises:
        ArchiveError: When nothing in the archive matches

    Returns:
        List[str]: Every matching member, formatted as a vsi path
    """
    archive_cls = _get_archive(_get_archive_suffix(path))
    found: List[str] = []
    vsi_prefixes: List[str] = []
    with fsspec.open(path) as handle, archive_cls(handle) as opened:
        network_prefix = _get_vsi_network_prefix(path)
        if network_prefix:
            vsi_prefixes.append(network_prefix)
        vsi_prefixes.append(archive_cls.vsi_prefix)
        _recursive_archive_search(opened, regex, found, vsi_prefixes)
    if found:
        return [_create_vsi_path(hit, vsi_prefixes) for hit in found]
    raise ArchiveError(f"{regex} expression did not return any matches")
def _prepare_path(archive: Archive) -> str:
    """Return the location of ``archive`` as a plain string.

    A string ``archive.path`` is returned unchanged. Otherwise the ``path``
    attribute of the file object is used; for the swift scheme the first
    three path components are dropped.
    """
    location = archive.path
    if isinstance(location, str):
        return location
    raw = location.path
    as_path = pathlib.Path(raw)
    if urlparse(raw).scheme == FilesystemType.swift:
        return join(*as_path.parts[3:])
    return join(str(as_path))
# Maps a lower-cased archive suffix (see ``_get_archive_suffix``) to the
# ``Archive`` subclass used to open files with that suffix.
ARCHIVE_MAPPING = {
    ".tar": TarArchive,
    ".tgz": TarArchive,
    ".tar.gz": TarArchive,
    ".zip": ZipArchive,
    ".gz": GzArchive,
}
def _get_archive_suffix(path: Union[str, pathlib.Path]) -> str:
    """Return the archive suffix of ``path`` as a key of ``ARCHIVE_MAPPING``.

    The recognised suffixes of the file name are lower-cased and joined
    (``"x.TAR.GZ"`` gives ``".tar.gz"``). When that combination is not a
    known key (``"x.zip.gz"`` would give ``".zip.gz"``), the longest
    trailing combination that is a key is returned instead.

    Args:
        path: File name or path to inspect.

    Returns:
        str: A key of ``ARCHIVE_MAPPING``, or ``""`` when the name carries
        no recognised archive suffix.
    """
    if isinstance(path, str):
        path = pathlib.Path(path)
    known = [
        suffix.lower()
        for suffix in path.suffixes
        if suffix.lower() in ARCHIVE_MAPPING
    ]
    # ``start == 0`` is the full combination; each later step drops one more
    # leading suffix. Every single entry of ``known`` is itself a key, so
    # the loop always returns when ``known`` is non-empty.
    for start in range(len(known)):
        candidate = "".join(known[start:])
        if candidate in ARCHIVE_MAPPING:
            return candidate
    return ""
def _get_archive(suffix: str) -> Type[Archive]:
    """Look up the ``Archive`` subclass registered for ``suffix``.

    Raises:
        KeyError: If ``suffix`` is not a key of ``ARCHIVE_MAPPING``.
    """
    archive_cls = ARCHIVE_MAPPING[suffix]
    return archive_cls
def _get_member_name(member: InfoFile) -> str:
    """Return the name an archive member is stored under.

    Raises:
        ArchiveError: If ``member`` is neither a ``TarInfo`` nor a
            ``ZipInfo``.
    """
    if isinstance(member, tarfile.TarInfo):
        return member.name
    if isinstance(member, zipfile.ZipInfo):
        return member.filename
    raise ArchiveError(f"Member {member} not recognized as valid InfoFile")
def _get_vsi_network_prefix(path: str) -> str:
    """Return the ``VSI_MAPPING`` entry for the URL scheme of ``path``.

    A path without a scheme is treated as having the ``"file"`` scheme.
    """
    scheme = urlparse(path).scheme
    if not scheme:
        scheme = "file"
    return VSI_MAPPING[FilesystemType(scheme)]
def _create_vsi_path(path: str, prefixes: List[str]) -> str:
    """Prepend ``prefixes`` to ``path`` in reverse order.

    The last prefix in the list ends up first in the returned string.
    """
    combined = ""
    for prefix in prefixes:
        # Prepending each item in turn reverses the list order.
        combined = prefix + combined
    return f"{combined}{path}"