# Source code for registrar.source

"""
source.py
==========

Contains the different file source concepts describing where data may be
stored, such as local storage, S3, OpenStack Swift and HTTP.
"""

import re
from os.path import normpath, join, isabs
import shutil
from glob import glob
from fnmatch import fnmatch
from typing import List, Optional
from abc import ABC, abstractmethod
from urllib.parse import urljoin, urlparse

import requests
import structlog

import boto3
import boto3.session
import botocore.session
import botocore.handlers
from botocore import UNSIGNED
from botocore.config import Config
from swiftclient.service import SwiftService

from .config import SourceConfig


# Module-level structlog logger, named after this module.
logger = structlog.getLogger(__name__)


class Source(ABC):
    """Abstract base class for all sources.

    A source describes one storage backend. Subclasses must implement path
    splitting, file listing, file retrieval and GDAL VSI path construction.

    Args:
        name (str, optional): Name of the source. Defaults to None.
        endpoint_url (str, optional): Base URL of the backend, if it has
            one. Defaults to "".
    """

    def __init__(self, name: Optional[str] = None, endpoint_url: str = ""):
        self.name = name
        self.endpoint_url = endpoint_url

    def __repr__(self) -> str:
        return f"<{self.__class__.__name__} name={self.name}>"

    @abstractmethod
    def get_container_and_path(self, path: str):
        """Split the input path into a container and a path part."""

    @abstractmethod
    def list_files(
        self, path: str, glob_patterns: Optional[list] = None
    ) -> List[str]:
        """Return a list of file references for the given base path and
        glob pattern(s)."""

    @abstractmethod
    def get_file(self, path: str, target_path: str) -> None:
        """Download the given file to the target location."""

    @abstractmethod
    def get_vsi_env_and_path(self, path: str):
        """Get a VSI conformant path together with the GDAL environment
        needed to access it.

        See https://gdal.org/user/virtual_file_systems.html
        """
class SwiftSource(Source):
    """Handles data located in an OpenStack Swift container.

    Args:
        name (str, optional): Name of bucket. Defaults to None.
        username (str, optional): username for authentication.
            Defaults to None.
        password (str, optional): password for authentication.
            Defaults to None.
        project_name (str, optional): name of swift tenant.
            Defaults to None.
        project_id (str, optional): id of swift tenant. Defaults to None.
        region_name (str, optional): name of region. Defaults to None.
        project_domain_id (str, optional): project domain identifier.
            Defaults to None.
        project_domain_name (str, optional): name of project domain.
            Defaults to None.
        user_domain_id (str, optional): user domain identifier.
            Defaults to None.
        user_domain_name (str, optional): name of user domain.
            Defaults to None.
        auth_url (str, optional): url to authenticate to. Defaults to None.
        auth_url_short (str, optional): short url to auth to.
            Defaults to None.
        auth_version (str, optional): swift auth version. Defaults to None.
        container (str, optional): name of swift container.
            Defaults to None.
        streaming (bool, optional): If streaming is to be used by GDAL.
            Defaults to False.
    """

    def __init__(
        self,
        name: Optional[str] = None,
        username: Optional[str] = None,
        password: Optional[str] = None,
        project_name: Optional[str] = None,
        project_id: Optional[str] = None,
        region_name: Optional[str] = None,
        project_domain_id: Optional[str] = None,
        project_domain_name: Optional[str] = None,
        user_domain_id: Optional[str] = None,
        user_domain_name: Optional[str] = None,
        auth_url: Optional[str] = None,
        auth_url_short: Optional[str] = None,
        auth_version: Optional[str] = None,
        container: Optional[str] = None,
        streaming: bool = False,
    ):
        super().__init__(name)

        self.username = username
        self.password = password
        self.project_name = project_name
        self.project_id = project_id
        self.region_name = region_name
        self.project_domain_id = project_domain_id
        self.project_domain_name = project_domain_name
        self.user_domain_id = user_domain_id
        self.user_domain_name = user_domain_name
        self.auth_url = auth_url
        self.auth_url_short = auth_url_short
        self.auth_version = auth_version  # TODO: assume 3
        self.container = container
        self.streaming = streaming

    def get_service(self):
        """Returns the swiftclient.SwiftService for the options."""
        return SwiftService(
            options={
                "os_username": self.username,
                "os_password": self.password,
                "os_project_name": self.project_name,
                "os_project_id": self.project_id,
                "os_region_name": self.region_name,
                "os_auth_url": self.auth_url,
                "auth_version": self.auth_version,
                "os_project_domain_id": self.project_domain_id,
                "os_project_domain_name": self.project_domain_name,
                "os_user_domain_id": self.user_domain_id,
                "os_user_domain_name": self.user_domain_name,
            }
        )

    def get_container_and_path(self, path: str):
        """Split ``path`` into a ``(container, object_path)`` tuple.

        If a container is configured on this source, ``path`` is returned
        unchanged as the object path. Otherwise the container is the first
        path segment. The accepted forms are ``swift://container/object``,
        ``/container/object`` and ``container/object``.
        """
        container = self.container
        if container is None or container == "":
            scheme = "swift://"
            if path.startswith(scheme):
                # strip only the leading scheme; str.replace would also remove
                # any later "swift://" occurrence inside the object name
                fullpath = path[len(scheme):]
            else:
                # fallback if in schema /bucket/object or bucket/object
                fullpath = path[1:] if path.startswith("/") else path
            container, _, path = fullpath.partition("/")

        return container, path

    def list_files(self, path: str, glob_patterns: Optional[list] = None):
        """List object names below ``path``, optionally filtered by globs.

        Each glob pattern is joined onto the listing prefix before matching.
        If no container is configured on the source, the returned names are
        prefixed with the container name.

        Raises:
            The error reported by swiftclient for a failed listing page.
        """
        container, path = self.get_container_and_path(path)

        if glob_patterns and not isinstance(glob_patterns, list):
            glob_patterns = [glob_patterns]

        filenames = []
        with self.get_service() as swift:
            pages = swift.list(
                container=container,
                options={"prefix": path},
            )
            # consume the pages while the service (and its workers) is open
            for page in pages:
                if not page["success"]:
                    raise page["error"]

                for item in page["listing"]:
                    name = item["name"]
                    if glob_patterns is None or any(
                        fnmatch(name, join(path, glob_pattern))
                        for glob_pattern in glob_patterns
                    ):
                        filenames.append(
                            name if self.container else join(container, name)
                        )

        return filenames

    def get_file(self, path: str, target_path: str):
        """Download the object at ``path`` to ``target_path``.

        Raises:
            Exception: if swiftclient reports an unsuccessful download. The
                swift error, when it is an exception, is chained as the cause.
        """
        container, path = self.get_container_and_path(path)

        with self.get_service() as swift:
            results = swift.download(
                container, [path], options={"out_file": target_path}
            )

            for result in results:
                if not result["success"]:
                    error = result.get("error")
                    cause = error if isinstance(error, BaseException) else None
                    raise Exception(f"Failed to download {path}") from cause

    def get_vsi_env_and_path(self, path: str):
        """Build the GDAL ``/vsiswift/`` path and its environment.

        ``/vsiswift_streaming/`` is used when ``streaming`` is enabled.
        """
        container, path = self.get_container_and_path(path)
        protocol = "vsiswift" if not self.streaming else "vsiswift_streaming"

        return {
            "OS_IDENTITY_API_VERSION": self.auth_version,
            "OS_AUTH_URL": self.auth_url,
            "OS_USERNAME": self.username,
            "OS_PASSWORD": self.password,
            "OS_USER_DOMAIN_NAME": self.user_domain_name,
            # GDAL's Keystone v3 authentication for /vsiswift/ also reads
            # the project scope options
            "OS_PROJECT_NAME": self.project_name,
            "OS_PROJECT_DOMAIN_NAME": self.project_domain_name,
            "OS_REGION_NAME": self.region_name,
        }, f"/{protocol}/{container}/{path}"
class S3Source(Source):
    """Handles data located on an S3 bucket.

    Args:
        name (str, optional): Name of the source. Defaults to None.
        bucket_name (str, optional): Name of the bucket. Defaults to None.
        secret_access_key (str, optional): secret access key.
            Defaults to None.
        access_key_id (str, optional): access key identifier.
            Defaults to None.
        endpoint_url (str, optional): endpoint url. An empty value lets
            boto3 use its default endpoint. Defaults to "".
        strip_bucket (bool, optional): whether to strip bucket name when
            constructing paths. Defaults to True.
        validate_bucket_name (bool, optional): whether to validate the
            name of bucket. Defaults to True.
        region_name (str, optional): name of aws s3 region.
            Defaults to None.
        public (bool, optional): whether the data is public or not.
            Defaults to False.
        streaming (bool, optional): If streaming is to be used by GDAL.
            Defaults to False.
        **client_kwargs: passed through to ``boto3.session.Session.client``.
    """

    def __init__(
        self,
        name: Optional[str] = None,
        bucket_name: Optional[str] = None,
        secret_access_key: Optional[str] = None,
        access_key_id: Optional[str] = None,
        endpoint_url: str = "",
        strip_bucket: bool = True,
        validate_bucket_name: bool = True,
        region_name: Optional[str] = None,
        public: bool = False,
        streaming: bool = False,
        **client_kwargs,
    ):
        super().__init__(name, endpoint_url)

        # see
        # https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/
        # session.html#boto3.session.Session.client
        # for client_kwargs
        self.bucket_name = bucket_name
        self.secret_access_key = secret_access_key
        self.access_key_id = access_key_id
        self.strip_bucket = strip_bucket
        self.region_name = region_name
        self.public = public
        self.streaming = streaming

        botocore_session = botocore.session.Session()
        if not validate_bucket_name:
            # allow bucket names that botocore's client-side check rejects
            botocore_session.unregister(
                "before-parameter-build.s3",
                botocore.handlers.validate_bucket_name,
            )

        session = boto3.session.Session(botocore_session=botocore_session)
        self.client = session.client(
            "s3",
            aws_secret_access_key=secret_access_key,
            aws_access_key_id=access_key_id,
            # an empty string is not a valid endpoint for boto; None makes it
            # resolve the default endpoint for the region
            endpoint_url=endpoint_url or None,
            region_name=region_name,
            config=Config(signature_version=UNSIGNED) if public else None,
            **client_kwargs,
        )

    def get_container_and_path(self, path: str):
        """Split ``path`` into a ``(bucket, key)`` tuple.

        A fully qualified ``s3://bucket/key`` URL is split directly. Without
        a scheme, the bucket is the first path segment when no
        ``bucket_name`` is configured. Otherwise ``bucket_name`` is used, and
        with ``strip_bucket`` a leading segment equal to it is removed.

        Raises:
            ValueError: if ``path`` has a URL scheme other than ``s3``.
        """
        # try to see if we have a fully qualified S3 URL, in which case we will
        # return bucket/path from there
        parsed = urlparse(path)
        if parsed.scheme:
            if parsed.scheme.lower() != "s3":
                raise ValueError(f"invalid S3 URL {path}")
            return (parsed.netloc, parsed.path[1:])

        bucket = self.bucket_name
        if bucket is None:
            parts = (path[1:] if path.startswith("/") else path).split("/")
            bucket, path = parts[0], "/".join(parts[1:])
        elif self.strip_bucket:
            parts = (path[1:] if path.startswith("/") else path).split("/")
            if parts[0] == bucket:
                parts.pop(0)
            path = "/".join(parts)

        return bucket, path

    def list_files(self, path: str, glob_patterns: Optional[list] = None):
        """List ``bucket/key`` references below ``path``.

        All result pages are followed, so listings larger than a single
        ``ListObjectsV2`` response are complete. Each glob pattern is joined
        onto the key prefix before matching. An empty listing yields ``[]``.
        """
        if glob_patterns and not isinstance(glob_patterns, list):
            glob_patterns = [glob_patterns]

        bucket, key = self.get_container_and_path(path)
        logger.info("Listing S3 files", bucket=bucket, prefix=key)

        paginator = self.client.get_paginator("list_objects_v2")
        filenames = []
        for page in paginator.paginate(Bucket=bucket, Prefix=key):
            # "Contents" is absent from the response when nothing matches
            for item in page.get("Contents", []):
                object_key = item["Key"]
                if glob_patterns is None or any(
                    fnmatch(object_key, join(key, glob_pattern))
                    for glob_pattern in glob_patterns
                ):
                    filenames.append(f"{bucket}/{object_key}")

        return filenames

    def get_file(self, path: str, target_path: str):
        """Download the object at ``path`` to ``target_path``."""
        bucket, key = self.get_container_and_path(path)
        logger.info(
            "Retrieving file from S3",
            bucket=bucket,
            key=key,
            target_path=target_path,
        )
        self.client.download_file(bucket, key, target_path)

    def get_vsi_env_and_path(self, path: str):
        """Build the GDAL ``/vsis3/`` path and its environment.

        ``/vsis3_streaming/`` is used when ``streaming`` is enabled.
        """
        bucket, key = self.get_container_and_path(path)
        protocol = "vsis3" if not self.streaming else "vsis3_streaming"

        return (
            {
                "AWS_SECRET_ACCESS_KEY": self.secret_access_key,
                "AWS_ACCESS_KEY_ID": self.access_key_id,
                "AWS_S3_ENDPOINT": self.endpoint_url,
                "AWS_NO_SIGN_REQUEST": "YES" if self.public else "NO",
            },
            f'/{protocol}/{bucket}/{key}',
        )
class LocalSource(Source):
    """Handles data on the local filesystem.

    Args:
        name (str): Name of the filesystem
        root_directory (str): path to root
    """

    def __init__(self, name: str, root_directory: str):
        super().__init__(name)

        self.root_directory = root_directory

    def get_container_and_path(self, path: str):
        return (self.root_directory, path)

    def _join_path(self, path):
        """Resolve ``path`` below ``root_directory``.

        Absolute paths are treated as relative to the root directory.
        """
        path = normpath(path)
        if isabs(path):
            # normpath preserves a leading "//" on POSIX; removing just one
            # character would leave an absolute path, and join() would then
            # discard root_directory entirely
            path = path.lstrip("/")
        return join(self.root_directory, path)

    def list_files(self, path: str, glob_patterns: Optional[list] = None):
        """List files below ``path`` that match any of ``glob_patterns``.

        Without patterns, every entry directly below ``path`` is returned.
        Matches for several patterns are combined in pattern order, and
        duplicates are removed.
        """
        if glob_patterns and not isinstance(glob_patterns, list):
            glob_patterns = [glob_patterns]

        base = self._join_path(path)
        if glob_patterns is None:
            return glob(join(base, "*"))

        matches = []
        seen = set()
        for glob_pattern in glob_patterns:
            for match in glob(join(base, glob_pattern)):
                if match not in seen:
                    seen.add(match)
                    matches.append(match)
        return matches

    def get_file(self, path: str, target_path: str):
        """Copy the file at ``path`` (below the root) to ``target_path``."""
        shutil.copy(self._join_path(path), target_path)

    def get_vsi_env_and_path(self, path: str):
        """Return an empty environment and the resolved local path."""
        return {}, self._join_path(path)
class HTTPSource(Source):
    """Source class for HTTP locations.

    Args:
        name (str): Name of the source.
        endpoint_url (str): Base URL that relative paths are resolved against.
        streaming (bool): If streaming is to be used by GDAL.
    """

    def __init__(self, name: str, endpoint_url: str, streaming: bool):
        super().__init__(name, endpoint_url)
        self.streaming = streaming

    def get_container_and_path(self, path: str):
        return (self.endpoint_url, path)

    def list_files(self, path: str, glob_patterns: Optional[list] = None):
        raise NotImplementedError()

    def get_file(self, path: str, target_path: str):
        """Download ``path`` (resolved against ``endpoint_url``) to
        ``target_path``.

        The body is streamed to disk in chunks rather than held in memory.

        Raises:
            requests.HTTPError: if the server answers with an error status.
                In that case the target file is not created.
        """
        url = urljoin(self.endpoint_url, path)
        with requests.get(url, allow_redirects=True, stream=True) as response:
            # fail before touching the target, so error pages are never
            # stored as if they were data
            response.raise_for_status()
            with open(target_path, "wb") as f:
                # iter_content decodes transfer encodings like .content did
                for chunk in response.iter_content(chunk_size=64 * 1024):
                    f.write(chunk)

    def get_vsi_env_and_path(self, path: str):
        """Build the GDAL ``/vsicurl/`` (or ``/vsicurl_streaming/``) path."""
        if self.streaming:
            prefix = "/vsicurl_streaming/"
        else:
            prefix = "/vsicurl/"
        return {}, f"{prefix}{urljoin(self.endpoint_url, path)}"
# Registry from a source config's ``type`` value to the Source subclass
# that get_source instantiates for it.
SOURCE_TYPES = dict(
    swift=SwiftSource,
    s3=S3Source,
    directory=LocalSource,
    http=HTTPSource,
)
def get_source(
    source_cfgs: List[SourceConfig], hrefs: List[str]
) -> Optional[Source]:
    """Retrieves a Source from a given list of SourceConfigs and a list of
    hrefs to test against.

    The first config without a filter, or whose filter regex matches the
    start of at least one href, is used.

    Arguments:
        source_cfgs (List[SourceConfig]): the source configs to test
        hrefs (List[str]): the hrefs to test the sources against

    Returns:
        Source: the constructed source from the tested source configuration,
            or None if no configuration matched
    """
    for source_cfg in source_cfgs:
        pattern = source_cfg.filter
        # test the match objects rather than the hrefs themselves, so a
        # matching empty href still counts as a match
        if pattern and not any(re.match(pattern, href) for href in hrefs):
            continue

        return SOURCE_TYPES[source_cfg.type](
            source_cfg.name, *source_cfg.args, **source_cfg.kwargs
        )

    return None