Source code for vs_common.archive

"""
archive.py
============
This module contains functions for handling archives.
"""

import pathlib
import re
import tarfile
import zipfile
from os.path import join
from urllib.parse import urlparse
from abc import ABC, abstractmethod
from typing import Optional, List, Type, Union

import fsspec

from .model import FilesystemType, VSI_MAPPING


class ArchiveError(Exception):
    """Raised for archive-handling failures.

    Examples: a member whose info object is neither a tar nor a zip entry,
    or a requested file that cannot be found in the archive.
    """
# Metadata object describing one archive member (tar or zip entry).
InfoFile = Union[tarfile.TarInfo, zipfile.ZipInfo]

# Readable file-like object returned when a member is opened in memory.
ExFile = Union[tarfile.ExFileObject, zipfile.ZipExtFile]
class Archive(ABC):
    """Common interface for archive readers.

    An ``Archive`` wraps ``path`` (a filesystem path or an already-open file
    object) and is used as a context manager: ``__enter__`` opens the
    underlying archive and ``__exit__`` releases it.
    """

    # Virtual-file-system prefix for this archive type; empty in the base class.
    vsi_prefix: str = ""

    def __init__(self, path: Union[str, fsspec.spec.AbstractBufferedFile]):
        # Nothing is opened here; subclasses open ``path`` in ``__enter__``.
        # ``path`` may also be an in-memory member stream of an outer archive.
        self.path = path

    @abstractmethod
    def __enter__(self):
        """Open the archive and return the context object."""

    @abstractmethod
    def __exit__(self, exc_type, exc_value, traceback):
        """Release the opened archive."""

    @abstractmethod
    def get_members(self) -> List[InfoFile]:
        """List the members stored in the archive.

        Returns:
            List[InfoFile]: One info object per archive member.
        """

    @abstractmethod
    def extract_file(self, f: InfoFile) -> ExFile:
        """Open a single member for in-memory reading.

        Args:
            f (InfoFile): Info object obtained from :meth:`get_members`.

        Returns:
            ExFile: Readable file-like object with the member's contents.
        """
class TarArchive(Archive):
    """Archive backend for ``.tar``, ``.tgz`` and ``.tar.gz`` files."""

    vsi_prefix = "/vsitar/"

    def __enter__(self):
        try:
            handle = tarfile.open(self.path)
        except TypeError:
            # ``self.path`` is not usable as a file name, so treat it as an
            # already-open binary file object instead.
            handle = tarfile.open(fileobj=self.path)
        self.archive = handle
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.archive.close()

    def get_members(self):
        """Return the ``TarInfo`` objects of every member, in archive order."""
        members = self.archive.getmembers()
        return members

    def extract_file(self, f):
        """Return a reader for member ``f`` (``tarfile.TarFile.extractfile``)."""
        return self.archive.extractfile(f)
class ZipArchive(Archive):
    """Archive backend for ``.zip`` files."""

    vsi_prefix = "/vsizip/"

    def __enter__(self):
        # zipfile.ZipFile accepts either a path or a binary file object.
        zip_handle = zipfile.ZipFile(self.path)
        self.archive = zip_handle
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.archive.close()

    def get_members(self):
        """Return the ``ZipInfo`` entries of the archive."""
        return self.archive.infolist()

    def extract_file(self, f):
        """Open member ``f`` for reading and return the ``ZipExtFile``."""
        return self.archive.open(f, "r")
class GzArchive(Archive):
    """Handles .gz archives (standalone gzip files are not supported).

    ``.gz`` names are routed to this class by ``ARCHIVE_MAPPING``. The
    abstract methods are implemented so that the class can be instantiated,
    and each one raises :class:`ArchiveError` with an explicit message.
    Without them, construction would fail with an opaque ``TypeError``
    about an abstract class.
    """

    _UNSUPPORTED_MESSAGE = "Standalone .gz archives are not supported"

    def __enter__(self):
        raise ArchiveError(self._UNSUPPORTED_MESSAGE)

    def __exit__(self, exc_type, exc_value, traceback):
        # Nothing was opened in __enter__, so there is nothing to release.
        return None

    def get_members(self):
        raise ArchiveError(self._UNSUPPORTED_MESSAGE)

    def extract_file(self, f):
        raise ArchiveError(self._UNSUPPORTED_MESSAGE)
def _recursive_archive_search(
    archive: Archive,
    regex: str,
    matches: List[str],
    prefixes: List[str],
    path: Optional[str] = None,
) -> None:
    """Collect paths of all members matching ``regex``, descending into nested archives.

    Args:
        archive (Archive): An opened (entered) archive to scan.
        regex (str): Pattern applied with ``re.search`` to each member's base name.
        matches (List[str]): Output list; ``join(path, member_name)`` is appended
            for every matching member.
        prefixes (List[str]): Output list; the ``vsi_prefix`` of every nested
            archive that gets opened is appended to it.
        path (Optional[str]): Base path recorded for ``archive``'s members;
            derived via ``_prepare_path`` when omitted.
    """
    if path is None:
        path = _prepare_path(archive)
    for member in archive.get_members():
        f_name = _get_member_name(member)
        file = pathlib.Path(f_name)
        if re.search(regex, file.name):
            matches.append(join(path, str(file)))
        elif file.suffix.lower() in ARCHIVE_MAPPING:
            suffix = _get_archive_suffix(file)
            ArchiveClass = _get_archive(suffix)
            new_archive = archive.extract_file(member)
            if new_archive is None:
                # tarfile returns None for members that are not regular files
                # or links (e.g. a directory named "foo.zip"); nothing to open.
                continue
            new_path = join(path, f_name)
            # Close the member stream once the nested archive has been scanned.
            with new_archive, ArchiveClass(new_archive) as arch:
                prefixes.append(ArchiveClass.vsi_prefix)
                _recursive_archive_search(arch, regex, matches, prefixes, new_path)


def _recursive_archive_extract(
    archive: Archive,
    regex: str,
    to_path: str,
) -> str:
    """Extract the first member whose base name matches ``regex`` into ``to_path``.

    Direct members of ``archive`` are checked first. Nested archives are then
    searched in order until one of them yields a match.

    Args:
        archive (Archive): An opened (entered) archive.
        regex (str): Pattern applied with ``re.match`` to each member's base name.
        to_path (str): Local directory the member is written into.

    Returns:
        str: ``join(to_path, member_name)``, the path of the written file.

    Raises:
        ArchiveError: If no member matches, or if a matching member's name
            would place it outside ``to_path``.
    """
    extract_path = _extract_first_match(archive, regex, to_path)
    if extract_path is None:
        raise ArchiveError(f"Cannot find {regex} in {archive}")
    return extract_path


def _extract_first_match(archive: Archive, regex: str, to_path: str) -> Optional[str]:
    """Extract the first match in ``archive`` (or its nested archives); None if absent."""
    archives = []
    for member in archive.get_members():
        f_name = _get_member_name(member)
        file = pathlib.Path(f_name)
        # NOTE(review): re.match anchors at the start of the name, unlike the
        # re.search used by _recursive_archive_search — kept for compatibility.
        if re.match(regex, file.name):
            data = archive.extract_file(member)
            if data is None:
                # Non-regular tar member (e.g. a directory): nothing to write.
                continue
            extract_path = _safe_extract_path(to_path, str(file))
            # Members may live in sub-directories that do not exist yet.
            pathlib.Path(extract_path).parent.mkdir(parents=True, exist_ok=True)
            with data, open(extract_path, "wb") as f:
                f.write(data.read())
            return extract_path
        elif file.suffix.lower() in ARCHIVE_MAPPING:
            archives.append((file, member))
    for file, member in archives:
        ArchiveClass = _get_archive(_get_archive_suffix(file))
        new_archive = archive.extract_file(member)
        if new_archive is None:
            continue
        with new_archive, ArchiveClass(new_archive) as nested:
            extract_path = _extract_first_match(nested, regex, to_path)
        # Keep looking in the remaining nested archives when this one has no match.
        if extract_path is not None:
            return extract_path
    return None


def _safe_extract_path(to_path: str, member_name: str) -> str:
    """Join ``member_name`` onto ``to_path``, refusing names that escape ``to_path``."""
    extract_path = join(to_path, member_name)
    root = pathlib.Path(to_path).resolve()
    target = pathlib.Path(extract_path).resolve()
    # Absolute member names or ".." components would otherwise write anywhere.
    if root not in target.parents:
        raise ArchiveError(f"Refusing to extract {member_name} outside of {to_path}")
    return extract_path
def archive_extract(path: str, regex: str, to_path: str) -> str:
    """Extract the first archive member whose name matches ``regex``.

    The archive type is chosen from the suffix of ``path``. The search is
    delegated to ``_recursive_archive_extract``, which also looks inside
    nested archives.

    Args:
        path (str): Location of the archive; opened with ``fsspec.open``.
        regex (str): Pattern matched against each member's base name.
        to_path (str): Directory the matching member is written into.

    Returns:
        str: Path to the extracted file.

    Raises:
        KeyError: If the suffix of ``path`` is not a known archive type.
    """
    archive_cls = _get_archive(_get_archive_suffix(path))
    with fsspec.open(path) as handle:
        with archive_cls(handle) as archive:
            return _recursive_archive_extract(archive, regex, to_path)
def _prepare_path(archive: Archive) -> str:
    """Return the base path under which members of ``archive`` are recorded.

    A string ``archive.path`` is returned unchanged. Otherwise ``archive.path``
    must expose a ``.path`` attribute (as fsspec file objects do); that value is
    normalised through ``pathlib.Path``. For swift URLs, the first three
    components of the normalised path are dropped.
    """
    path = archive.path
    if isinstance(path, str):
        return path
    _path = pathlib.Path(path.path)
    parsed_url = urlparse(path.path)
    if parsed_url.scheme == FilesystemType.swift:
        # NOTE(review): pathlib collapses "swift://host/container/obj" into the
        # parts ("swift:", "host", "container", "obj"); dropping three presumably
        # strips scheme, host and container — confirm against the swift layout.
        parts = _path.parts[3:]
        return join(*parts)
    return str(_path)


# Recognised archive suffixes (lower-case) and the class that reads each one.
ARCHIVE_MAPPING = {
    ".tar": TarArchive,
    ".tgz": TarArchive,
    ".tar.gz": TarArchive,
    ".zip": ZipArchive,
    ".gz": GzArchive,
}


def _get_archive_suffix(path: Union[str, pathlib.Path]) -> str:
    """Return the longest ``ARCHIVE_MAPPING`` key that the file name ends with.

    Matching is case-insensitive and uses only the final path component, so
    ``"dir/Data.TAR.GZ"`` gives ``".tar.gz"`` and ``"backup.tar.zip"`` gives
    ``".zip"``. Returns ``""`` when no known archive suffix is present.
    """
    name = pathlib.Path(path).name.lower()
    # Longest first, so ".tar.gz" wins over ".gz".
    for suffix in sorted(ARCHIVE_MAPPING, key=len, reverse=True):
        # A bare ".zip" name has no suffix, mirroring pathlib's semantics.
        if name.endswith(suffix) and len(name) > len(suffix):
            return suffix
    return ""


def _get_archive(suffix: str) -> Type[Archive]:
    """Return the archive class for ``suffix``; raises KeyError if unknown."""
    return ARCHIVE_MAPPING[suffix]


def _get_member_name(member: InfoFile) -> str:
    """Return the stored name of a tar or zip member.

    Raises:
        ArchiveError: If ``member`` is neither a ``TarInfo`` nor a ``ZipInfo``.
    """
    if isinstance(member, tarfile.TarInfo):
        return member.name
    if isinstance(member, zipfile.ZipInfo):
        return member.filename
    raise ArchiveError(f"Member {member} not recognized as valid InfoFile")


def _get_vsi_network_prefix(path: str) -> str:
    """Map the URL scheme of ``path`` (``"file"`` if absent) to its VSI prefix."""
    scheme = urlparse(path).scheme or "file"
    storage_type = FilesystemType(scheme)
    return VSI_MAPPING[storage_type]


def _create_vsi_path(path: str, prefixes: List[str]) -> str:
    """Prepend ``prefixes`` to ``path``, last-appended prefix first."""
    prepared_prefixes = "".join(reversed(prefixes))
    return f"{prepared_prefixes}{path}"