"""
source.py
==========
Contains the different source abstractions describing where data may be
stored, such as local storage, S3, Swift or HTTP.
"""
import re
from os.path import normpath, join, isabs
import shutil
from glob import glob
from fnmatch import fnmatch
from typing import List, Optional
from abc import ABC, abstractmethod
from urllib.parse import urljoin, urlparse
import requests
import structlog
import boto3
import boto3.session
import botocore.session
import botocore.handlers
from botocore import UNSIGNED
from botocore.config import Config
from swiftclient.service import SwiftService
from .config import SourceConfig
# Module-level structlog logger, named after this module.
logger = structlog.getLogger(__name__)
class Source(ABC):
    """Common interface that every concrete source has to implement.

    Args:
        name (str, optional): name of the source. Defaults to None.
        endpoint_url (str, optional): endpoint URL associated with the
            source. Defaults to "".
    """

    def __init__(self, name: str = None, endpoint_url: str = ""):
        self.endpoint_url = endpoint_url
        self.name = name

    def __repr__(self) -> str:
        return "<{} name={}>".format(type(self).__name__, self.name)

    @abstractmethod
    def get_container_and_path(self, path: str):
        """Break ``path`` up into a container part and a path part."""

    @abstractmethod
    def list_files(self, path: str, glob_patterns: list = None):
        """List the file references found below ``path``, optionally
        restricted to the ones matching ``glob_patterns``.
        """

    @abstractmethod
    def get_file(self, path: str, target_path: str):
        """Fetch the file at ``path`` and store it at ``target_path``."""

    @abstractmethod
    def get_vsi_env_and_path(self, path: str):
        """Build a VSI conformant path for ``path``.

        See https://gdal.org/user/virtual_file_systems.html
        """
class SwiftSource(Source):
    """Handles data located on an OpenStack Swift bucket

    Args:
        name (str, optional): Name of bucket. Defaults to None.
        username (str, optional): username for authentication.
            Defaults to None.
        password (str, optional): password for authentication.
            Defaults to None.
        project_name (str, optional): name of swift tenant. Defaults to None.
        project_id (str, optional): id of swift tenant. Defaults to None.
        region_name (str, optional): name of region. Defaults to None.
        project_domain_id (str, optional): project domain identifier.
            Defaults to None.
        project_domain_name (str, optional): name of project domain.
            Defaults to None.
        user_domain_id (str, optional): user domain identifier.
            Defaults to None.
        user_domain_name (str, optional): name of user domain.
            Defaults to None.
        auth_url (str, optional): url to authenticate to. Defaults to None.
        auth_url_short (str, optional): short url to auth to.
            Defaults to None.
        auth_version (str, optional): swift auth version. Defaults to None.
        container (str, optional): name of swift container.
            Defaults to None.
        streaming (bool, optional): If streaming is to be used by GDAL.
            Defaults to False.
    """

    def __init__(
        self,
        name: str = None,
        username: str = None,
        password: str = None,
        project_name: str = None,
        project_id: str = None,
        region_name: str = None,
        project_domain_id: str = None,
        project_domain_name: str = None,
        user_domain_id: str = None,
        user_domain_name: str = None,
        auth_url: str = None,
        auth_url_short: str = None,
        auth_version: str = None,
        container: str = None,
        streaming: bool = False,
    ):
        super().__init__(name)
        self.username = username
        self.password = password
        self.project_name = project_name
        self.project_id = project_id
        self.region_name = region_name
        self.project_domain_id = project_domain_id
        self.project_domain_name = project_domain_name
        self.user_domain_id = user_domain_id
        self.user_domain_name = user_domain_name
        self.auth_url = auth_url
        self.auth_url_short = auth_url_short
        self.auth_version = auth_version  # TODO: assume 3
        self.container = container
        self.streaming = streaming

    def get_service(self):
        """Returns the swiftclient.SwiftService for the options."""
        return SwiftService(
            options={
                "os_username": self.username,
                "os_password": self.password,
                "os_project_name": self.project_name,
                "os_project_id": self.project_id,
                "os_region_name": self.region_name,
                "os_auth_url": self.auth_url,
                "auth_version": self.auth_version,
                "os_project_domain_id": self.project_domain_id,
                "os_project_domain_name": self.project_domain_name,
                "os_user_domain_id": self.user_domain_id,
                "os_user_domain_name": self.user_domain_name,
            }
        )

    def get_container_and_path(self, path: str):
        """Split the input path into a container and a path part.

        When the source has a configured ``container`` the path is returned
        untouched. Otherwise the container is the first segment of a path
        given as ``swift://container/object``, ``/container/object`` or
        ``container/object``.

        Returns:
            tuple: ``(container, path)``
        """
        container = self.container
        if not container:
            scheme = "swift://"
            if path.startswith(scheme):
                # only drop the leading scheme, never a later occurrence
                fullpath = path[len(scheme):]
            else:
                # fallback if in schema /bucket/object or bucket/object
                fullpath = path[1:] if path.startswith("/") else path
            container, _, path = fullpath.partition("/")
        return container, path

    def list_files(self, path: str, glob_patterns: list = None):
        """Return the object names below ``path``.

        Args:
            path (str): prefix to list; resolved through
                ``get_container_and_path``.
            glob_patterns (list, optional): pattern or list of patterns,
                each joined onto the prefix and matched with ``fnmatch``.
                ``None`` disables filtering.

        Returns:
            list: object names; prefixed with the container when the source
            has no configured ``container``.

        Raises:
            The ``error`` entry of a listing page reported as unsuccessful.
        """
        container, path = self.get_container_and_path(path)
        if glob_patterns and not isinstance(glob_patterns, list):
            glob_patterns = [glob_patterns]

        # the full patterns do not depend on the listed item: build them once
        full_patterns = (
            None
            if glob_patterns is None
            else [join(path, pattern) for pattern in glob_patterns]
        )

        filenames = []
        with self.get_service() as swift:
            pages = swift.list(
                container=container,
                options={"prefix": path},
            )
            for page in pages:
                if not page["success"]:
                    raise page["error"]
                for item in page["listing"]:
                    item_name = item["name"]
                    if full_patterns is None or any(
                        fnmatch(item_name, pattern)
                        for pattern in full_patterns
                    ):
                        filenames.append(
                            item_name
                            if self.container
                            else join(container, item_name)
                        )
        return filenames

    def get_file(self, path: str, target_path: str):
        """Download the given file to the target location.

        Raises:
            Exception: when a download result is reported as unsuccessful.
                The error carried by the result, if it is an exception, is
                attached as the cause.
        """
        container, path = self.get_container_and_path(path)
        with self.get_service() as swift:
            results = swift.download(
                container, [path], options={"out_file": target_path}
            )
            for result in results:
                if not result["success"]:
                    error = result.get("error")
                    # keep the underlying failure visible in the traceback
                    cause = error if isinstance(error, BaseException) else None
                    raise Exception(f"Failed to download {path}") from cause

    def get_vsi_env_and_path(self, path: str):
        """Get the environment variables and VSI conformant path.

        See https://gdal.org/user/virtual_file_systems.html

        Returns:
            tuple: ``(env, vsi_path)`` where ``env`` is a dict of ``OS_*``
            settings and ``vsi_path`` uses ``/vsiswift/`` or, when
            ``streaming`` is set, ``/vsiswift_streaming/``.
        """
        container, path = self.get_container_and_path(path)
        protocol = "vsiswift" if not self.streaming else "vsiswift_streaming"
        # NOTE: project name and project domain name are not exported here
        return {
            "OS_IDENTITY_API_VERSION": self.auth_version,
            "OS_AUTH_URL": self.auth_url,
            "OS_USERNAME": self.username,
            "OS_PASSWORD": self.password,
            "OS_USER_DOMAIN_NAME": self.user_domain_name,
            "OS_REGION_NAME": self.region_name,
        }, f"/{protocol}/{container}/{path}"
class S3Source(Source):
    """Handles data located on an S3 bucket

    Args:
        name (str, optional): Name of the source. Defaults to None.
        bucket_name (str, optional): Name of the bucket. Defaults to None.
        secret_access_key (str, optional): secret access key.
            Defaults to None.
        access_key_id (str, optional): access key identifier.
            Defaults to None.
        endpoint_url (str, optional): endpoint url. Defaults to "".
        strip_bucket (bool, optional): whether to strip bucket name when
            constructing paths. Defaults to True.
        validate_bucket_name (bool, optional): whether to validate the name
            of bucket. Defaults to True.
        region_name (str, optional): name of aws s3 region.
            Defaults to None.
        public (bool, optional): whether the data is public or not.
            Defaults to False.
        streaming (bool, optional): If streaming is to be used by GDAL.
            Defaults to False.
        **client_kwargs: additional keyword arguments forwarded to
            ``boto3.session.Session.client``.
    """

    def __init__(
        self,
        name: str = None,
        bucket_name: str = None,
        secret_access_key: str = None,
        access_key_id: str = None,
        endpoint_url: str = "",
        strip_bucket: bool = True,
        validate_bucket_name: bool = True,
        region_name: str = None,
        public: bool = False,
        streaming: bool = False,
        **client_kwargs,
    ):
        super().__init__(name, endpoint_url)

        # see
        # https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/
        # session.html#boto3.session.Session.client
        # for client_kwargs
        self.bucket_name = bucket_name
        self.secret_access_key = secret_access_key
        self.access_key_id = access_key_id
        self.strip_bucket = strip_bucket
        self.region_name = region_name
        self.public = public
        self.streaming = streaming

        botocore_session = botocore.session.Session()
        if not validate_bucket_name:
            # drop botocore's bucket name check for this session
            botocore_session.unregister(
                "before-parameter-build.s3",
                botocore.handlers.validate_bucket_name,
            )

        session = boto3.session.Session(botocore_session=botocore_session)
        self.client = session.client(
            "s3",
            aws_secret_access_key=secret_access_key,
            aws_access_key_id=access_key_id,
            endpoint_url=endpoint_url,
            region_name=region_name,
            # public data is accessed with unsigned requests
            config=Config(signature_version=UNSIGNED) if public else None,
            **client_kwargs,
        )

    def get_container_and_path(self, path: str):
        """Split the input path into a bucket and a key part.

        A fully qualified ``s3://bucket/key`` URL always wins. Otherwise the
        configured ``bucket_name`` is used, or, when none is configured, the
        first path segment.

        Returns:
            tuple: ``(bucket, key)``

        Raises:
            ValueError: if ``path`` is a URL with a scheme other than ``s3``.
        """
        # try to see if we have a fully qualified S3 URL, in which case we will
        # return bucket/path from there
        parsed = urlparse(path)
        if parsed.scheme:
            if parsed.scheme.lower() != "s3":
                raise ValueError(f"invalid S3 URL {path}")
            return (parsed.netloc, parsed.path[1:])

        bucket = self.bucket_name
        if bucket is None:
            parts = (path[1:] if path.startswith("/") else path).split("/")
            bucket, path = parts[0], "/".join(parts[1:])
        elif self.strip_bucket:
            parts = (path[1:] if path.startswith("/") else path).split("/")
            if parts[0] == bucket:
                parts.pop(0)
            path = "/".join(parts)
        return bucket, path

    def list_files(self, path: str, glob_patterns: list = None):
        """Return ``bucket/key`` references for the objects below ``path``.

        All result pages are read, and a prefix without any object yields an
        empty list.

        Args:
            path (str): prefix to list; resolved through
                ``get_container_and_path``.
            glob_patterns (list, optional): pattern or list of patterns,
                each joined onto the prefix and matched with ``fnmatch``.
                ``None`` disables filtering.

        Returns:
            list: ``"<bucket>/<key>"`` strings.
        """
        if glob_patterns and not isinstance(glob_patterns, list):
            glob_patterns = [glob_patterns]

        bucket, key = self.get_container_and_path(path)
        logger.info("Listing S3 files", bucket=bucket, prefix=key)

        # the full patterns do not depend on the listed item: build them once
        full_patterns = (
            None
            if glob_patterns is None
            else [join(key, pattern) for pattern in glob_patterns]
        )

        filenames = []
        paginator = self.client.get_paginator("list_objects_v2")
        for page in paginator.paginate(Bucket=bucket, Prefix=key):
            # "Contents" is absent when nothing matches the prefix
            for item in page.get("Contents", []):
                item_key = item["Key"]
                if full_patterns is None or any(
                    fnmatch(item_key, pattern) for pattern in full_patterns
                ):
                    filenames.append(f"{bucket}/{item_key}")
        return filenames

    def get_file(self, path: str, target_path: str):
        """Download the given file to the target location."""
        bucket, key = self.get_container_and_path(path)
        logger.info(
            "Retrieving file from S3",
            bucket=bucket,
            key=key,
            target_path=target_path,
        )
        self.client.download_file(bucket, key, target_path)

    def get_vsi_env_and_path(self, path: str):
        """Get the environment variables and VSI conformant path.

        See https://gdal.org/user/virtual_file_systems.html

        Returns:
            tuple: ``(env, vsi_path)`` where ``env`` is a dict of ``AWS_*``
            settings and ``vsi_path`` uses ``/vsis3/`` or, when ``streaming``
            is set, ``/vsis3_streaming/``.
        """
        bucket, key = self.get_container_and_path(path)
        protocol = "vsis3" if not self.streaming else "vsis3_streaming"
        return (
            {
                "AWS_SECRET_ACCESS_KEY": self.secret_access_key,
                "AWS_ACCESS_KEY_ID": self.access_key_id,
                "AWS_S3_ENDPOINT": self.endpoint_url,
                "AWS_NO_SIGN_REQUEST": "YES" if self.public else "NO",
            },
            f'/{protocol}/{bucket}/{key}',
        )
class LocalSource(Source):
    """Handles data on local filesystem

    Args:
        name (str): Name of the filesystem
        root_directory (str): path to root
    """

    def __init__(self, name: str, root_directory: str):
        super().__init__(name)
        self.root_directory = root_directory

    def get_container_and_path(self, path: str):
        """Return the root directory as container and ``path`` unchanged."""
        return (self.root_directory, path)

    def _join_path(self, path):
        """Normalize ``path`` and anchor it below the root directory."""
        path = normpath(path)
        # ``join`` discards the root when given an absolute path, and
        # ``normpath`` may keep more than one leading separator ("//x"), so
        # strip until the path is relative.
        while isabs(path):
            path = path[1:]
        return join(self.root_directory, path)

    def list_files(self, path: str, glob_patterns: list = None):
        """Return the files below ``path`` matching any of the patterns.

        Args:
            path (str): directory, relative to the root directory.
            glob_patterns (list, optional): pattern or list of patterns.
                ``None`` lists everything directly below ``path``; empty
                patterns are ignored.

        Returns:
            list: matching paths, without duplicates, in the order found.
        """
        base_path = self._join_path(path)
        if glob_patterns is None:
            return glob(join(base_path, "*"))

        if isinstance(glob_patterns, str):
            glob_patterns = [glob_patterns]

        matches = []
        for glob_pattern in glob_patterns:
            if glob_pattern:
                matches.extend(glob(join(base_path, glob_pattern)))
        # a file may match several patterns: report it once, keep the order
        return list(dict.fromkeys(matches))

    def get_file(self, path: str, target_path: str):
        """Copy the given file to the target location."""
        shutil.copy(self._join_path(path), target_path)

    def get_vsi_env_and_path(self, path: str):
        """Return an empty environment and the path below the root
        directory.
        """
        return {}, self._join_path(path)
class HTTPSource(Source):
    """Source class for HTTP locations

    Args:
        name (str): Name of the source.
        endpoint_url (str): base URL that paths are resolved against.
        streaming (bool, optional): If streaming is to be used by GDAL.
            Defaults to False.
    """

    def __init__(self, name: str, endpoint_url: str, streaming: bool = False):
        super().__init__(name, endpoint_url)
        self.streaming = streaming

    def get_container_and_path(self, path: str):
        """Return the endpoint URL as container and ``path`` unchanged."""
        return (self.endpoint_url, path)

    def list_files(self, path: str, glob_patterns: list = None):
        """Listing is not supported for HTTP sources.

        Raises:
            NotImplementedError: always.
        """
        raise NotImplementedError()

    def get_file(self, path: str, target_path: str):
        """Download the given file to the target location.

        The response body is written in chunks instead of being held in
        memory as a whole.

        Raises:
            requests.HTTPError: if the server answers with an error status;
                nothing is written to ``target_path`` in that case.
        """
        url = urljoin(self.endpoint_url, path)
        chunk_size = 64 * 1024
        with requests.get(url, allow_redirects=True, stream=True) as response:
            # do not store an error page as if it were the requested file
            response.raise_for_status()
            with open(target_path, "wb") as f:
                for chunk in response.iter_content(chunk_size=chunk_size):
                    f.write(chunk)

    def get_vsi_env_and_path(self, path: str):
        """Get the environment variables and VSI conformant path.

        See https://gdal.org/user/virtual_file_systems.html

        Returns:
            tuple: an empty dict and the URL prefixed with ``/vsicurl/`` or,
            when ``streaming`` is set, ``/vsicurl_streaming/``.
        """
        if self.streaming:
            prefix = "/vsicurl_streaming/"
        else:
            prefix = "/vsicurl/"
        return {}, f"{prefix}{urljoin(self.endpoint_url, path)}"
# Maps the ``type`` of a source configuration to the Source class that
# implements it; looked up by ``get_source``.
SOURCE_TYPES = {
"swift": SwiftSource,
"s3": S3Source,
"directory": LocalSource,
"http": HTTPSource,
}
def get_source(
    source_cfgs: List[SourceConfig], hrefs: List[str]
) -> Optional[Source]:
    """Retrieves a Source from a given list of SourceConfigs and a list of
    hrefs to test against.

    The first configuration that either has no filter, or whose filter
    matches (``re.match``) at least one of the hrefs, is used.

    Arguments:
        source_cfgs (List[SourceConfig]): the source configs to test
        hrefs (List[str]): the hrefs to test the sources against

    Returns:
        Source: the constructed source from the tested source configuration,
        or None if no configuration applies

    Raises:
        KeyError: if the selected configuration has a type that is not in
            ``SOURCE_TYPES``.
    """
    for source_cfg in source_cfgs:
        filter_ = source_cfg.filter
        # test the match result itself, not the truthiness of the href
        if not filter_ or any(
            re.match(filter_, href) is not None for href in hrefs
        ):
            source_class = SOURCE_TYPES[source_cfg.type]
            return source_class(
                source_cfg.name, *source_cfg.args, **source_cfg.kwargs
            )
    return None