"""This file and its contents are licensed under the Apache License 2.0. Please see the included NOTICE for copyright information and LICENSE for a copy of the license.
"""
import base64
import fnmatch
import functools
import logging
import re
from urllib.parse import urlparse

import boto3
from botocore.exceptions import ClientError
from core.utils.params import get_env
from django.conf import settings
from tldextract import TLDExtract

logger = logging.getLogger(__name__)


def _mask_secret(value):
    """Return a log-safe preview of a credential: its first four characters plus '...', or None when unset."""
    return value[:4] + '...' if value else None


def _close_body(body):
    """Close an S3 response body if it exposes ``close()``; objects without it (and None) are ignored."""
    close = getattr(body, 'close', None)
    if callable(close):
        close()


def get_client_and_resource(
    aws_access_key_id=None, aws_secret_access_key=None, aws_session_token=None, region_name=None, s3_endpoint=None
):
    """
    Build a boto3 S3 client and S3 resource that share one session.

    Every argument that is falsy is looked up through ``get_env`` instead
    (``AWS_ACCESS_KEY_ID``, ``AWS_SECRET_ACCESS_KEY``, ``AWS_SESSION_TOKEN``, ``S3_region``, ``S3_ENDPOINT``).

    :param aws_access_key_id: AWS access key id
    :param aws_secret_access_key: AWS secret access key
    :param aws_session_token: AWS session token
    :param region_name: Region name; defaults to 'us-east-1' when neither the argument nor ``S3_region`` is set
    :param s3_endpoint: Custom endpoint URL; passed as ``endpoint_url`` only when set
    :return: Tuple ``(client, resource)``
    """
    aws_access_key_id = aws_access_key_id or get_env('AWS_ACCESS_KEY_ID')
    aws_secret_access_key = aws_secret_access_key or get_env('AWS_SECRET_ACCESS_KEY')
    aws_session_token = aws_session_token or get_env('AWS_SESSION_TOKEN')

    # Never write full secrets to the log: both the secret key and the session token are truncated.
    logger.debug(
        'Create boto3 session with access key id=%s, secret key=%s, session token=%s',
        aws_access_key_id,
        _mask_secret(aws_secret_access_key),
        _mask_secret(aws_session_token),
    )

    session = boto3.Session(
        aws_access_key_id=aws_access_key_id,
        aws_secret_access_key=aws_secret_access_key,
        aws_session_token=aws_session_token,
    )

    # Named ``client_kwargs`` (not ``settings``) so the Django settings import is not shadowed.
    client_kwargs = {'region_name': region_name or get_env('S3_region') or 'us-east-1'}
    s3_endpoint = s3_endpoint or get_env('S3_ENDPOINT')
    if s3_endpoint:
        client_kwargs['endpoint_url'] = s3_endpoint

    client = session.client('s3', config=boto3.session.Config(signature_version='s3v4'), **client_kwargs)
    resource = session.resource('s3', config=boto3.session.Config(signature_version='s3v4'), **client_kwargs)
    return client, resource


def resolve_s3_url(url, client, presign=True, expires_in=3600):
    """
    Turn an ``s3://bucket/key`` style URL into something a consumer can fetch.

    :param url: URL whose network location is the bucket name and whose path is the object key
    :param client: S3 client used for ``get_object`` / ``generate_presigned_url``
    :param presign: If True, return a presigned GET URL; otherwise return the object inlined as a base64 data URI
    :param expires_in: Presigned URL lifetime, passed as ``ExpiresIn``
    :return: Presigned URL, a ``data:<content-type>;base64,...`` string, or the original ``url`` when
        presigning fails with a ``ClientError``
    """
    parsed = urlparse(url, allow_fragments=False)
    bucket_name = parsed.netloc
    key = parsed.path.lstrip('/')

    # Return blob as base64 encoded string if presigned urls are disabled
    if not presign:
        response = client.get_object(Bucket=bucket_name, Key=key)
        body = response['Body']
        try:
            content_type = response['ResponseMetadata']['HTTPHeaders']['content-type']
            payload = body.read()
        finally:
            # Close the body even if reading fails so the underlying connection is not leaked.
            _close_body(body)
        return 'data:' + content_type + ';base64,' + base64.b64encode(payload).decode('utf-8')

    # Otherwise try to generate presigned url
    try:
        presigned_url = client.generate_presigned_url(
            ClientMethod='get_object', Params={'Bucket': bucket_name, 'Key': key}, ExpiresIn=expires_in
        )
    except ClientError as exc:
        # Best effort: fall back to the unresolved URL instead of failing the caller.
        logger.warning("Can't generate presigned URL. Reason: %s", exc)
        return url

    logger.debug('Presigned URL %s generated for %s', presigned_url, url)
    return presigned_url


class AWS(object):
    """Namespace for S3 helper class methods."""

    @classmethod
    def get_blob_metadata(
        cls,
        url: str,
        bucket_name: str,
        client=None,
        aws_access_key_id=None,
        aws_secret_access_key=None,
        aws_session_token=None,
        region_name=None,
        s3_endpoint=None,
    ) -> dict:
        """
        Get blob metadata by url

        :param url: Object key
        :param bucket_name: AWS bucket name
        :param client: AWS client for batch processing; when None a new one is created from the
            credential/region/endpoint arguments below
        :param aws_access_key_id: AWS access key id (used only when ``client`` is None)
        :param aws_secret_access_key: AWS secret access key (used only when ``client`` is None)
        :param aws_session_token: AWS session token (used only when ``client`` is None)
        :param region_name: Region name (used only when ``client`` is None)
        :param s3_endpoint: Custom S3 endpoint URL (used only when ``client`` is None)
        :return: Object metadata dict("name": "value")
        """
        if client is None:
            client, _ = get_client_and_resource(
                aws_access_key_id=aws_access_key_id,
                aws_secret_access_key=aws_secret_access_key,
                aws_session_token=aws_session_token,
                region_name=region_name,
                s3_endpoint=s3_endpoint,
            )

        # NOTE(review): head_object would avoid opening a body at all - confirm that no caller
        # relies on GetObject-only fields before switching.
        response = client.get_object(Bucket=bucket_name, Key=url)
        metadata = dict(response)

        # remove unused fields; the body is never read, so close it to release its connection
        _close_body(metadata.pop('Body', None))
        metadata.pop('ResponseMetadata', None)
        return metadata

    @classmethod
    def validate_pattern(cls, storage, pattern, glob_pattern=True):
        """
        Validate pattern against S3 Storage

        Lists the bucket (restricted to ``storage.prefix`` when set, non-recursively unless
        ``storage.recursive_scan``) and stops at the first non-folder key matching the pattern.

        :param storage: S3 Storage instance
        :param pattern: Pattern to validate
        :param glob_pattern: If True, pattern is a glob pattern, otherwise it is a regex pattern
        :return: Message if pattern is not valid, empty string otherwise
        """
        client, bucket = storage.get_client_and_bucket()

        if glob_pattern:
            pattern = fnmatch.translate(pattern)
        regex = re.compile(pattern)

        if storage.prefix:
            list_kwargs = {'Prefix': storage.prefix.rstrip('/') + '/'}
            if not storage.recursive_scan:
                list_kwargs['Delimiter'] = '/'
            bucket_iter = bucket.objects.filter(**list_kwargs)
        else:
            bucket_iter = bucket.objects

        bucket_iter = bucket_iter.page_size(settings.CLOUD_STORAGE_CHECK_FOR_RECORDS_PAGE_SIZE).all()

        for obj in bucket_iter:
            key = obj.key
            # skip directories
            if key.endswith('/'):
                logger.debug('%s is skipped because it is a folder', key)
                continue
            if regex.match(key):
                logger.debug('%s matches file pattern', key)
                return ''
        return 'No objects found matching the provided glob pattern'


class S3StorageError(Exception):
    """Raised in place of the original error when an S3 call fails against an untrusted endpoint domain."""


# see https://github.com/john-kurkowski/tldextract?tab=readme-ov-file#note-about-caching
# prevents network call on first use
extractor = TLDExtract(suffix_list_urls=())


def catch_and_reraise_from_none(func):
    """
    For S3 storages - if s3_endpoint is not on a known domain,
    catch exception and raise a new one with the previous context suppressed.
    See also: https://peps.python.org/pep-0409/

    :param func: Method to wrap; its first argument must expose ``s3_endpoint``
    :return: Wrapped method that raises S3StorageError for endpoints outside
        ``settings.S3_TRUSTED_STORAGE_DOMAINS`` and re-raises the original exception otherwise
    """

    @functools.wraps(func)  # keep the wrapped method's name/docstring for introspection and logs
    def wrapper(self, *args, **kwargs):
        try:
            return func(self, *args, **kwargs)
        except Exception as e:
            trusted_domains = {trusted_domain.lower() for trusted_domain in settings.S3_TRUSTED_STORAGE_DOMAINS}
            if (
                self.s3_endpoint
                and (domain := extractor.extract_urllib(urlparse(self.s3_endpoint)).registered_domain.lower())
                not in trusted_domains
            ):
                # Full details go to the server log only; the caller sees a generic message.
                logger.error(f'Exception from unrecognized S3 domain: {e}', exc_info=True)
                raise S3StorageError(
                    f'Debugging info is not available for s3 endpoints on domain: {domain}. '
                    'Please contact your Label Studio devops team if you require detailed error reporting for this domain.'
                ) from None
            # Trusted (or default AWS) endpoint: propagate the original exception with its traceback intact.
            raise

    return wrapper