"""This file and its contents are licensed under the Apache License 2.0. Please see the included NOTICE for copyright information and LICENSE for a copy of the license. """ import json import logging import re from typing import Union from urllib.parse import urlparse import boto3 from core.feature_flags import flag_set from core.redis import start_job_async_or_sync from django.conf import settings from django.db import models from django.db.models.signals import post_save, pre_delete from django.dispatch import receiver from django.utils.translation import gettext_lazy as _ from io_storages.base_models import ( ExportStorage, ExportStorageLink, ImportStorage, ImportStorageLink, ProjectStorageMixin, ) from io_storages.s3.utils import ( catch_and_reraise_from_none, get_client_and_resource, resolve_s3_url, ) from io_storages.utils import StorageObject, load_tasks_json, storage_can_resolve_bucket_url from tasks.models import Annotation from label_studio.io_storages.s3.utils import AWS logger = logging.getLogger(__name__) logging.getLogger('botocore').setLevel(logging.CRITICAL) boto3.set_stream_logger(level=logging.INFO) clients_cache = {} class S3StorageMixin(models.Model): bucket = models.TextField(_('bucket'), null=True, blank=True, help_text='S3 bucket name') prefix = models.TextField(_('prefix'), null=True, blank=True, help_text='S3 bucket prefix') regex_filter = models.TextField( _('regex_filter'), null=True, blank=True, help_text='Cloud storage regex for filtering objects', ) use_blob_urls = models.BooleanField( _('use_blob_urls'), default=False, help_text='Interpret objects as BLOBs and generate URLs', ) aws_access_key_id = models.TextField(_('aws_access_key_id'), null=True, blank=True, help_text='AWS_ACCESS_KEY_ID') aws_secret_access_key = models.TextField( _('aws_secret_access_key'), null=True, blank=True, help_text='AWS_SECRET_ACCESS_KEY', ) aws_session_token = models.TextField(_('aws_session_token'), null=True, blank=True, help_text='AWS_SESSION_TOKEN') aws_sse_kms_key_id = models.TextField( _('aws_sse_kms_key_id'), null=True, blank=True, help_text='AWS SSE KMS Key ID' ) region_name = models.TextField(_('region_name'), null=True, blank=True, help_text='AWS Region') s3_endpoint = models.TextField(_('s3_endpoint'), null=True, blank=True, help_text='S3 Endpoint') @catch_and_reraise_from_none def get_client_and_resource(self): # s3 client initialization ~ 100 ms, for 30 tasks it's a 3 seconds, so we need to cache it cache_key = f'{self.aws_access_key_id}:{self.aws_secret_access_key}:{self.aws_session_token}:{self.region_name}:{self.s3_endpoint}' if cache_key in clients_cache: return clients_cache[cache_key] result = get_client_and_resource( self.aws_access_key_id, self.aws_secret_access_key, self.aws_session_token, self.region_name, self.s3_endpoint, ) clients_cache[cache_key] = result return result def get_client(self): client, _ = self.get_client_and_resource() return client def get_client_and_bucket(self, validate_connection=True): client, s3 = self.get_client_and_resource() if validate_connection: self.validate_connection(client) return client, s3.Bucket(self.bucket) @catch_and_reraise_from_none def validate_connection(self, client=None): logger.debug('validate_connection') if client is None: client = self.get_client() # TODO(jo): add check for write access for .*Export.* classes is_export = 'Export' in self.__class__.__name__ if self.prefix: logger.debug( f'[Class {self.__class__.__name__}]: Test connection to bucket {self.bucket} with prefix {self.prefix} using ListObjectsV2 
            logger.debug(
                f'[Class {self.__class__.__name__}]: Test connection to bucket {self.bucket} '
                f'with prefix {self.prefix} using ListObjectsV2 operation'
            )
            result = client.list_objects_v2(Bucket=self.bucket, Prefix=self.prefix, MaxKeys=1)
            # We expect 1 key with the prefix for imports. For exports it's okay if there are 0 with the prefix.
            expected_keycount = 0 if is_export else 1
            if (keycount := result.get('KeyCount')) is None or keycount < expected_keycount:
                raise KeyError(f'{self.url_scheme}://{self.bucket}/{self.prefix} not found.')
        else:
            logger.debug(
                f'[Class {self.__class__.__name__}]: Test connection to bucket {self.bucket} using HeadBucket operation'
            )
            client.head_bucket(Bucket=self.bucket)

    @property
    def path_full(self):
        prefix = self.prefix or ''
        return f'{self.url_scheme}://{self.bucket}/{prefix}'

    @property
    def type_full(self):
        return 'Amazon AWS S3'

    @catch_and_reraise_from_none
    def get_bytes_stream(self, uri, range_header=None):
        """Get file directly from S3 using iter_chunks without wrapper.

        This method forwards Range headers directly to S3 and returns the raw stream.
        Note: The returned stream is NOT seekable and will break if seeking backwards.

        Args:
            uri: The S3 URI of the file to retrieve
            range_header: Optional HTTP Range header to forward to S3

        Returns:
            Tuple of (stream, content_type, metadata) where metadata contains
            important S3 headers like ETag, ContentLength, etc.
        """
        # Parse URI to get bucket and key
        parsed_uri = urlparse(uri, allow_fragments=False)
        bucket_name = parsed_uri.netloc
        key = parsed_uri.path.lstrip('/')

        # Get S3 client
        client = self.get_client()

        try:
            # Forward Range header to S3 if provided
            request_params = {'Bucket': bucket_name, 'Key': key}
            if range_header:
                request_params['Range'] = range_header

            # Get the object from S3
            response = client.get_object(**request_params)

            # Extract metadata to return
            metadata = {
                'ETag': response.get('ETag'),
                'ContentLength': response.get('ContentLength'),
                'ContentRange': response.get('ContentRange'),
                'LastModified': response.get('LastModified'),
                'StatusCode': response['ResponseMetadata']['HTTPStatusCode'],
            }

            # Return the streaming body directly
            return response['Body'], response.get('ContentType'), metadata
        except Exception as e:
            logger.error(f'Error getting direct stream from S3 for uri {uri}: {e}', exc_info=True)
            return None, None, {}

    class Meta:
        abstract = True


class S3ImportStorageBase(S3StorageMixin, ImportStorage):
    url_scheme = 's3'

    presign = models.BooleanField(_('presign'), default=True, help_text='Generate presigned URLs')
    presign_ttl = models.PositiveSmallIntegerField(
        _('presign_ttl'), default=1, help_text='Presigned URLs TTL (in minutes)'
    )
    recursive_scan = models.BooleanField(
        _('recursive scan'),
        default=False,
        help_text=_('Perform recursive scan over the bucket content'),
    )

    @catch_and_reraise_from_none
    def iter_objects(self):
        _, bucket = self.get_client_and_bucket()
        list_kwargs = {}
        if self.prefix:
            list_kwargs['Prefix'] = self.prefix.rstrip('/') + '/'
            if not self.recursive_scan:
                list_kwargs['Delimiter'] = '/'
        bucket_iter = bucket.objects.filter(**list_kwargs).all()
        regex = re.compile(str(self.regex_filter)) if self.regex_filter else None
        for obj in bucket_iter:
            key = obj.key
            if key.endswith('/'):
                logger.debug(key + ' is skipped because it is a folder')
                continue
            if regex and not regex.match(key):
                logger.debug(key + ' is skipped by regex filter')
                continue
            logger.debug(f's3 {key} has passed the regex filter')
            yield obj

    @catch_and_reraise_from_none
    def iter_keys(self):
        for obj in self.iter_objects():
            yield obj.key

    def get_unified_metadata(self, obj):
        return {
            'key': obj.key,
            'last_modified': obj.last_modified,
            'size': obj.size,
        }
    @catch_and_reraise_from_none
    def scan_and_create_links(self):
        return self._scan_and_create_links(S3ImportStorageLink)

    @catch_and_reraise_from_none
    def get_data(self, key) -> list[StorageObject]:
        uri = f'{self.url_scheme}://{self.bucket}/{key}'
        if self.use_blob_urls:
            data_key = settings.DATA_UNDEFINED_NAME
            task = {data_key: uri}
            return [StorageObject(key=key, task_data=task)]

        # read task json from bucket and validate it
        _, s3 = self.get_client_and_resource()
        bucket = s3.Bucket(self.bucket)
        obj = s3.Object(bucket.name, key).get()['Body'].read()
        return load_tasks_json(obj, key)

    @catch_and_reraise_from_none
    def generate_http_url(self, url):
        return resolve_s3_url(url, self.get_client(), self.presign, expires_in=self.presign_ttl * 60)

    @catch_and_reraise_from_none
    def can_resolve_url(self, url: Union[str, None]) -> bool:
        return storage_can_resolve_bucket_url(self, url)

    @catch_and_reraise_from_none
    def get_blob_metadata(self, key):
        return AWS.get_blob_metadata(
            key,
            self.bucket,
            aws_access_key_id=self.aws_access_key_id,
            aws_secret_access_key=self.aws_secret_access_key,
            aws_session_token=self.aws_session_token,
            region_name=self.region_name,
            s3_endpoint=self.s3_endpoint,
        )

    class Meta:
        abstract = True


class S3ImportStorage(ProjectStorageMixin, S3ImportStorageBase):
    class Meta:
        abstract = False


class S3ExportStorage(S3StorageMixin, ExportStorage):
    @catch_and_reraise_from_none
    def save_annotation(self, annotation):
        client, s3 = self.get_client_and_resource()
        logger.debug(f'Creating new object on {self.__class__.__name__} Storage {self} for annotation {annotation}')
        ser_annotation = self._get_serialized_data(annotation)

        # get key that identifies this object in storage
        key = S3ExportStorageLink.get_key(annotation)
        key = str(self.prefix) + '/' + key if self.prefix else key

        # put object into storage
        additional_params = {}
        self.cached_user = getattr(self, 'cached_user', self.project.organization.created_by)
        if flag_set(
            'fflag_feat_back_lsdv_3958_server_side_encryption_for_target_storage_short',
            user=self.cached_user,
        ):
            if self.aws_sse_kms_key_id:
                additional_params['SSEKMSKeyId'] = self.aws_sse_kms_key_id
                additional_params['ServerSideEncryption'] = 'aws:kms'
            else:
                additional_params['ServerSideEncryption'] = 'AES256'

        s3.Object(self.bucket, key).put(Body=json.dumps(ser_annotation), **additional_params)

        # create link if everything ok
        S3ExportStorageLink.create(annotation, self)

    @catch_and_reraise_from_none
    def delete_annotation(self, annotation):
        client, s3 = self.get_client_and_resource()
        logger.debug(f'Deleting object on {self.__class__.__name__} Storage {self} for annotation {annotation}')

        # get key that identifies this object in storage
        key = S3ExportStorageLink.get_key(annotation)
        key = str(self.prefix) + '/' + key if self.prefix else key

        # delete object from storage
        s3.Object(self.bucket, key).delete()

        # delete link if everything ok
        S3ExportStorageLink.objects.filter(storage=self, annotation=annotation).delete()


def async_export_annotation_to_s3_storages(annotation):
    project = annotation.project
    if hasattr(project, 'io_storages_s3exportstorages'):
        for storage in project.io_storages_s3exportstorages.all():
            logger.debug(f'Export {annotation} to S3 storage {storage}')
            storage.save_annotation(annotation)


@receiver(post_save, sender=Annotation)
def export_annotation_to_s3_storages(sender, instance, **kwargs):
    storages = getattr(instance.project, 'io_storages_s3exportstorages', None)
    if storages and storages.exists():  # avoid excess jobs in rq
        start_job_async_or_sync(async_export_annotation_to_s3_storages, instance)


@receiver(pre_delete,
          sender=Annotation)
def delete_annotation_from_s3_storages(sender, instance, **kwargs):
    links = S3ExportStorageLink.objects.filter(annotation=instance)
    for link in links:
        storage = link.storage
        if storage.can_delete_objects:
            logger.debug(f'Delete {instance} from S3 storage {storage}')  # nosec
            storage.delete_annotation(instance)


class S3ImportStorageLink(ImportStorageLink):
    storage = models.ForeignKey(S3ImportStorage, on_delete=models.CASCADE, related_name='links')

    @classmethod
    def exists(cls, keys, storage) -> set[str]:
        super_exists = super(S3ImportStorageLink, cls).exists
        # TODO: this is a workaround to be compatible with old keys version - remove it later
        # guard against prefix=None: str(None) would otherwise yield the literal string 'None'
        prefix = str(storage.prefix or '')
        if prefix:
            return (
                super_exists(keys, storage)
                | super_exists([prefix + key for key in keys], storage)
                | super_exists([prefix + '/' + key for key in keys], storage)
            )
        else:
            return super_exists(keys, storage)


class S3ExportStorageLink(ExportStorageLink):
    storage = models.ForeignKey(S3ExportStorage, on_delete=models.CASCADE, related_name='links')
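

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, never executed by this module): roughly how
# the import storage defined above is typically driven from application code.
# The bucket name, prefix and the `project` variable are hypothetical
# placeholders; the sketch assumes an existing Project row and AWS credentials
# resolvable by boto3.
#
#   storage = S3ImportStorage(
#       project=project,                 # existing Project instance (assumed)
#       bucket='my-bucket',              # hypothetical bucket name
#       prefix='tasks/',
#       use_blob_urls=True,              # interpret objects as BLOBs and resolve s3:// URLs
#       presign=True,
#   )
#   storage.validate_connection()        # ListObjectsV2 check, since a prefix is set
#   storage.save()
#   storage.scan_and_create_links()      # imports keys as tasks via S3ImportStorageLink
# ---------------------------------------------------------------------------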