"""This file and its contents are licensed under the Apache License 2.0. Please see the included NOTICE for copyright information and LICENSE for a copy of the license. """ import base64 import concurrent.futures import itertools import json import logging import os import sys import traceback as tb from concurrent.futures import ThreadPoolExecutor from dataclasses import asdict from datetime import datetime from typing import Any, Iterator, Union from urllib.parse import urljoin import django_rq import rq import rq.exceptions from core.feature_flags import flag_set from core.redis import is_job_in_queue, is_job_on_worker, redis_connected, start_job_async_or_sync from core.utils.common import load_func from core.utils.iterators import iterate_queryset from data_export.serializers import ExportDataSerializer from django.conf import settings from django.contrib.auth.models import AnonymousUser from django.db import models, transaction from django.db.models import JSONField from django.shortcuts import reverse from django.utils import timezone from django.utils.translation import gettext_lazy as _ from io_storages.utils import StorageObject, get_uri_via_regex, parse_bucket_uri from rest_framework.exceptions import ValidationError from rq.job import Job from tasks.models import Annotation, Task from tasks.serializers import AnnotationSerializer, PredictionSerializer from webhooks.models import WebhookAction from webhooks.utils import emit_webhooks_for_instance from .exceptions import UnsupportedFileFormatError logger = logging.getLogger(__name__) class StorageInfo(models.Model): """ StorageInfo helps to understand storage status and progress that happens in background jobs """ class Status(models.TextChoices): INITIALIZED = 'initialized', _('Initialized') QUEUED = 'queued', _('Queued') IN_PROGRESS = 'in_progress', _('In progress') FAILED = 'failed', _('Failed') COMPLETED = 'completed', _('Completed') COMPLETED_WITH_ERRORS = 'completed_with_errors', _('Completed with errors') class Meta: abstract = True last_sync = models.DateTimeField(_('last sync'), null=True, blank=True, help_text='Last sync finished time') last_sync_count = models.PositiveIntegerField( _('last sync count'), null=True, blank=True, help_text='Count of tasks synced last time' ) last_sync_job = models.CharField( _('last_sync_job'), null=True, blank=True, max_length=256, help_text='Last sync job ID' ) status = models.CharField( max_length=64, choices=Status.choices, default=Status.INITIALIZED, ) traceback = models.TextField(null=True, blank=True, help_text='Traceback report for the last failed sync') meta = JSONField('meta', null=True, default=dict, help_text='Meta and debug information about storage processes') def info_set_job(self, job_id): self.last_sync_job = job_id self.save(update_fields=['last_sync_job']) def _update_queued_status(self): self.last_sync = None self.last_sync_count = None self.last_sync_job = None self.status = self.Status.QUEUED # reset and init meta self.meta = {'attempts': self.meta.get('attempts', 0) + 1, 'time_queued': str(timezone.now())} self.save(update_fields=['last_sync_job', 'last_sync', 'last_sync_count', 'status', 'meta']) def info_set_queued(self): if settings.DJANGO_DB == settings.DJANGO_DB_SQLITE: self._update_queued_status() return True with transaction.atomic(): try: locked_storage = self.__class__.objects.select_for_update().get(pk=self.pk) except self.__class__.DoesNotExist: logger.error(f'Storage {self.__class__.__name__} with pk={self.pk} does not exist') return False if locked_storage.status in [self.Status.QUEUED, self.Status.IN_PROGRESS]: logger.error( f'Storage {locked_storage} (id={locked_storage.id}) is already in status ' f'"{locked_storage.status}". Cannot set to QUEUED. ' f'Last sync job: {locked_storage.last_sync_job}, ' f'Meta: {locked_storage.meta}' ) return False locked_storage._update_queued_status() self.refresh_from_db() return True def info_set_in_progress(self): # only QUEUED => IN_PROGRESS transition is possible, because in QUEUED we reset states if self.status != self.Status.QUEUED: raise ValueError(f'Storage status ({self.status}) must be QUEUED to move it IN_PROGRESS') self.status = self.Status.IN_PROGRESS dt = timezone.now() self.meta['time_in_progress'] = str(dt) # at the very beginning it's the same as in progress time self.meta['time_last_ping'] = str(dt) self.save(update_fields=['status', 'meta']) @property def time_in_progress(self): if 'time_failure' not in self.meta: return datetime.fromisoformat(self.meta['time_in_progress']) else: return datetime.fromisoformat(self.meta['time_failure']) def info_set_completed(self, last_sync_count, **kwargs): self.status = self.Status.COMPLETED self.last_sync = timezone.now() self.last_sync_count = last_sync_count time_completed = timezone.now() self.meta['time_completed'] = str(time_completed) self.meta['duration'] = (time_completed - self.time_in_progress).total_seconds() self.meta.update(kwargs) self.save(update_fields=['status', 'meta', 'last_sync', 'last_sync_count']) def info_set_completed_with_errors(self, last_sync_count, validation_errors, **kwargs): self.status = self.Status.COMPLETED_WITH_ERRORS self.last_sync = timezone.now() self.last_sync_count = last_sync_count self.traceback = '\n'.join(validation_errors) time_completed = timezone.now() self.meta['time_completed'] = str(time_completed) self.meta['duration'] = (time_completed - self.time_in_progress).total_seconds() self.meta['tasks_failed_validation'] = len(validation_errors) self.meta.update(kwargs) self.save(update_fields=['status', 'meta', 'last_sync', 'last_sync_count', 'traceback']) def info_set_failed(self): self.status = self.Status.FAILED # Get the current exception info exc_type, exc_value, exc_traceback = sys.exc_info() # Extract human-readable error messages from ValidationError if exc_type and issubclass(exc_type, ValidationError): error_messages = [] if hasattr(exc_value, 'detail'): # Handle ValidationError.detail which can be a dict or list if isinstance(exc_value.detail, dict): for field, errors in exc_value.detail.items(): if isinstance(errors, list): for error in errors: if hasattr(error, 'string'): error_messages.append(error.string) else: error_messages.append(str(error)) else: error_messages.append(str(errors)) elif isinstance(exc_value.detail, list): for error in exc_value.detail: if hasattr(error, 'string'): error_messages.append(error.string) else: error_messages.append(str(error)) else: error_messages.append(str(exc_value.detail)) # Use human-readable messages if available, otherwise fall back to full traceback if error_messages: self.traceback = '\n'.join(error_messages) else: self.traceback = str(tb.format_exc()) else: # For non-ValidationError exceptions, use the full traceback self.traceback = str(tb.format_exc()) time_failure = timezone.now() self.meta['time_failure'] = str(time_failure) self.meta['duration'] = (time_failure - self.time_in_progress).total_seconds() self.save(update_fields=['status', 'traceback', 'meta']) def info_update_progress(self, last_sync_count, **kwargs): # update db counter once per 5 seconds to avid db overloads now = timezone.now() last_ping = datetime.fromisoformat(self.meta['time_last_ping']) delta = (now - last_ping).total_seconds() if delta > settings.STORAGE_IN_PROGRESS_TIMER: self.last_sync_count = last_sync_count self.meta['time_last_ping'] = str(now) self.meta['duration'] = (now - self.time_in_progress).total_seconds() self.meta.update(kwargs) self.save(update_fields=['last_sync_count', 'meta']) @staticmethod def ensure_storage_statuses(storages): """Check failed jobs and set storage status as failed if job is failed :param storages: Import or Export storages """ # iterate over all storages storages = storages.only('id', 'last_sync_job', 'status', 'meta') for storage in storages: storage.health_check() def health_check(self): # get duration between last ping time and now now = timezone.now() last_ping = datetime.fromisoformat(self.meta.get('time_last_ping', str(now))) delta = (now - last_ping).total_seconds() # check redis connection if redis_connected(): self.job_health_check() # in progress last ping time, job is not needed here if self.status == self.Status.IN_PROGRESS and delta > settings.STORAGE_IN_PROGRESS_TIMER * 5: self.status = self.Status.FAILED self.traceback = ( 'It appears the job was failed because the last ping time is too old, ' 'and no traceback information is available.\n' 'This typically occurs if job was manually removed ' 'or workers reloaded unexpectedly.' ) self.save(update_fields=['status', 'traceback']) logger.info( f'Storage {self} status moved to `failed` ' f'because the job {self.last_sync_job} has too old ping time' ) def job_health_check(self): Status = self.Status if self.status not in [Status.IN_PROGRESS, Status.QUEUED]: return queue = django_rq.get_queue('low') try: sync_job = Job.fetch(self.last_sync_job, connection=queue.connection) job_status = sync_job.get_status() except rq.exceptions.NoSuchJobError: job_status = 'not found' # broken synchronization between storage and job # this might happen when job was stopped because of OOM and on_failure wasn't called if job_status == 'failed': self.status = Status.FAILED self.traceback = ( 'It appears the job was terminated unexpectedly, ' 'and no traceback information is available.\n' 'This typically occurs due to an out-of-memory (OOM) error.' ) self.save(update_fields=['status', 'traceback']) logger.info(f'Storage {self} status moved to `failed` ' f'because of the failed job {self.last_sync_job}') # job is not found in redis (maybe deleted while redeploy), storage status is still active elif job_status == 'not found': self.status = Status.FAILED self.traceback = ( 'It appears the job was not found in redis, ' 'and no traceback information is available.\n' 'This typically occurs if job was manually removed ' 'or workers reloaded unexpectedly.' ) self.save(update_fields=['status', 'traceback']) logger.info( f'Storage {self} status moved to `failed` ' f'because the job {self.last_sync_job} was not found' ) class Storage(StorageInfo): url_scheme = '' title = models.CharField(_('title'), null=True, blank=True, max_length=256, help_text='Cloud storage title') description = models.TextField(_('description'), null=True, blank=True, help_text='Cloud storage description') created_at = models.DateTimeField(_('created at'), auto_now_add=True, help_text='Creation time') synchronizable = models.BooleanField(_('synchronizable'), default=True, help_text='If storage can be synced') def validate_connection(self, client=None): raise NotImplementedError('validate_connection is not implemented') class Meta: abstract = True class ImportStorage(Storage): def iter_objects(self) -> Iterator[Any]: """ Returns: Iterator[Any]: An iterator for objects in the storage. """ raise NotImplementedError def iter_keys(self) -> Iterator[str]: """ Returns: Iterator[str]: An iterator of keys for each object in the storage. """ raise NotImplementedError def get_unified_metadata(self, obj: Any) -> dict: """ Args: obj: The storage object to get metadata for Returns: dict: A dictionary of metadata for the object with keys: 'key', 'last_modified', 'size'. """ raise NotImplementedError def get_data(self, key) -> list[StorageObject]: raise NotImplementedError def generate_http_url(self, url): raise NotImplementedError def get_bytes_stream(self, uri): """Get file bytes from storage as a stream and content type. Args: uri: The URI of the file to retrieve Returns: Tuple of (BytesIO stream, content_type) """ raise NotImplementedError def can_resolve_url(self, url: Union[str, None]) -> bool: return self.can_resolve_scheme(url) def can_resolve_scheme(self, url: Union[str, None]) -> bool: if not url: return False # TODO: Search for occurrences inside string, e.g. for cases like "gs://bucket/file.pdf" or "" _, prefix = get_uri_via_regex(url, prefixes=(self.url_scheme,)) bucket_uri = parse_bucket_uri(url, self) # If there is a prefix and the bucket matches the storage's bucket/container/path if prefix == self.url_scheme and bucket_uri: # bucket is used for s3 and gcs if hasattr(self, 'bucket') and bucket_uri.bucket == self.bucket: return True # container is used for azure blob if hasattr(self, 'container') and bucket_uri.bucket == self.container: return True # path is used for redis if hasattr(self, 'path') and bucket_uri.bucket == self.path: return True # if not found any occurrences - this Storage can't resolve url return False def resolve_uri(self, uri, task=None): # list of objects if isinstance(uri, list): resolved = [] for item in uri: result = self.resolve_uri(item, task) resolved.append(result if result else item) return resolved # dict of objects elif isinstance(uri, dict): resolved = {} for key in uri.keys(): result = self.resolve_uri(uri[key], task) resolved[key] = result if result else uri[key] return resolved # string: process one url elif isinstance(uri, str) and self.url_scheme in uri: try: # extract uri first from task data extracted_uri, _ = get_uri_via_regex(uri, prefixes=(self.url_scheme,)) if not self.can_resolve_url(extracted_uri): logger.debug(f'No storage info found for URI={uri}') return if flag_set('fflag_optic_all_optic_1938_storage_proxy', user=self.project.organization.created_by): if task is None: logger.error(f'Task is required to resolve URI={uri}', exc_info=True) raise ValueError(f'Task is required to resolve URI={uri}') proxy_url = urljoin( settings.HOSTNAME, reverse('storages:task-storage-data-resolve', kwargs={'task_id': task.id}) + f'?fileuri={base64.urlsafe_b64encode(extracted_uri.encode()).decode()}', ) return uri.replace(extracted_uri, proxy_url) # ff off: old logic without proxy else: if self.presign and task is not None: proxy_url = urljoin( settings.HOSTNAME, reverse('storages:task-storage-data-presign', kwargs={'task_id': task.id}) + f'?fileuri={base64.urlsafe_b64encode(extracted_uri.encode()).decode()}', ) return uri.replace(extracted_uri, proxy_url) else: # this branch is our old approach: # it generates presigned URLs if storage.presign=True; # or it inserts base64 media into task data if storage.presign=False http_url = self.generate_http_url(extracted_uri) return uri.replace(extracted_uri, http_url) except Exception: logger.info(f"Can't resolve URI={uri}", exc_info=True) def _scan_and_create_links_v2(self): # Async job execution for batch of objects: # e.g. GCS example # | "GetKey" >> --> read file content into label_studio_semantic_search.indexer.RawDataObject repr # | "AggregateBatch" >> beam.Combine --> combine read objects into a batch # | "AddObjects" >> label_studio_semantic_search.indexer.add_objects_from_bucket # --> add objects from batch to Vector DB # or for project task creation last step would be # | "AddObject" >> ImportStorage.add_task raise NotImplementedError @classmethod def add_task(cls, project, maximum_annotations, max_inner_id, storage, link_object: StorageObject, link_class): link_kwargs = asdict(link_object) data = link_kwargs.pop('task_data', None) allow_skip = data.get('allow_skip', None) # predictions predictions = data.get('predictions') or [] if predictions: if 'data' not in data: raise ValueError( 'If you use "predictions" field in the task, ' 'you must put "data" field in the task too' ) # annotations annotations = data.get('annotations') or [] cancelled_annotations = 0 if annotations: if 'data' not in data: raise ValueError( 'If you use "annotations" field in the task, ' 'you must put "data" field in the task too' ) cancelled_annotations = len([a for a in annotations if a.get('was_cancelled', False)]) if 'data' in data and isinstance(data['data'], dict): if data['data'] is not None: data = data['data'] else: data.pop('data') with transaction.atomic(): # Create task without skip_fsm (it's not a model field) task = Task( data=data, project=project, overlap=maximum_annotations, is_labeled=len(annotations) >= maximum_annotations, total_predictions=len(predictions), total_annotations=len(annotations) - cancelled_annotations, cancelled_annotations=cancelled_annotations, inner_id=max_inner_id, allow_skip=(allow_skip if allow_skip is not None else True), ) # Save with skip_fsm flag to bypass FSM during bulk import task.save(skip_fsm=True) link_class.create(task, storage=storage, **link_kwargs) logger.debug(f'Create {storage.__class__.__name__} link with {link_kwargs} for {task=}') raise_exception = not flag_set( 'ff_fix_back_dev_3342_storage_scan_with_invalid_annotations', user=AnonymousUser() ) # add predictions logger.debug(f'Create {len(predictions)} predictions for task={task}') for prediction in predictions: prediction['task'] = task.id prediction['project'] = project.id prediction_ser = PredictionSerializer(data=predictions, many=True) # Always validate predictions and raise exception if invalid raise_prediction_exception = ( flag_set('fflag_feat_utc_210_prediction_validation_15082025', user=project.organization.created_by) or raise_exception ) if prediction_ser.is_valid(raise_exception=raise_prediction_exception): prediction_ser.save() # add annotations logger.debug(f'Create {len(annotations)} annotations for task={task}') for annotation in annotations: annotation['task'] = task.id annotation['project'] = project.id annotation_ser = AnnotationSerializer(data=annotations, many=True) # Always validate annotations, but control error handling based on FF if annotation_ser.is_valid(): annotation_ser.save() else: # Log validation errors but don't save invalid annotations logger.error(f'Invalid annotations for task {task.id}: {annotation_ser.errors}') if raise_exception: raise ValidationError(annotation_ser.errors) return task # FIXME: add_annotation_history / post_process_annotations should be here def _scan_and_create_links(self, link_class): """ TODO: deprecate this function and transform it to "pipeline" version _scan_and_create_links_v2, TODO: it must be compatible with opensource, so old version is needed as well """ # set in progress status for storage info self.info_set_in_progress() tasks_existed = tasks_created = 0 maximum_annotations = self.project.maximum_annotations task = self.project.tasks.order_by('-inner_id').first() max_inner_id = (task.inner_id + 1) if task else 1 validation_errors = [] # Check feature flags once for the entire sync process check_file_extension = flag_set( 'fflag_fix_back_plt_804_check_file_extension_11072025_short', organization=self.project.organization ) existed_count_flag_set = flag_set( 'fflag_root_212_reduce_importstoragelink_counts', organization=self.project.organization ) tasks_for_webhook = [] for keys_batch in _batched( self.iter_keys(), settings.STORAGE_EXISTED_COUNT_BATCH_SIZE if existed_count_flag_set else 1 ): deduplicated_keys = list(dict.fromkeys(keys_batch)) # preserve order for key in deduplicated_keys: logger.debug(f'Scanning key {key}') # w/o Dataflow # pubsub.push(topic, key) # -> GF.pull(topic, key) + env -> add_task() # skip if key has already been synced existing_keys = link_class.exists(deduplicated_keys, self) tasks_existed += link_class.objects.filter(key__in=existing_keys, storage=self.id).count() self.info_update_progress(last_sync_count=tasks_created, tasks_existed=tasks_existed) for key in deduplicated_keys: if key in existing_keys: logger.debug(f'{self.__class__.__name__} already has tasks linked to {key=}') continue logger.debug(f'{self}: found new key {key}') # Check if file should be processed as JSON based on extension # Skip non-JSON files if use_blob_urls is False if check_file_extension and not self.use_blob_urls: _, ext = os.path.splitext(key.lower()) # Only process files with JSON/JSONL/PARQUET extensions json_extensions = {'.json', '.jsonl', '.parquet'} if ext and ext not in json_extensions: raise UnsupportedFileFormatError( f'File "{key}" is not a JSON/JSONL/Parquet file. Only .json, .jsonl, and .parquet files can be processed.\n' f"If you're trying to import non-JSON data (images, audio, text, etc.), " f'edit storage settings and enable "Tasks" import method' ) try: link_objects = self.get_data(key) except (UnicodeDecodeError, json.decoder.JSONDecodeError) as exc: logger.debug(exc, exc_info=True) raise ValueError( f'Error loading JSON from file "{key}".\nIf you\'re trying to import non-JSON data ' f'(images, audio, text, etc.), edit storage settings and enable ' f'"Tasks" import method' ) for link_object in link_objects: # TODO: batch this loop body with add_task -> add_tasks in a single bulk write. # See DIA-2062 for prerequisites try: task = self.add_task( self.project, maximum_annotations, max_inner_id, self, link_object, link_class=link_class, ) max_inner_id += 1 # update progress counters for storage info tasks_created += 1 # add task to webhook list tasks_for_webhook.append(task.id) except ValidationError as e: # Log validation errors but continue processing other tasks error_message = f'Validation error for task from {link_object.key}: {e}' logger.error(error_message) validation_errors.append(error_message) continue # settings.WEBHOOK_BATCH_SIZE # `WEBHOOK_BATCH_SIZE` sets the maximum number of tasks sent in a single webhook call, ensuring manageable payload sizes. # When `tasks_for_webhook` accumulates tasks equal to/exceeding `WEBHOOK_BATCH_SIZE`, they're sent in a webhook via # `emit_webhooks_for_instance`, and `tasks_for_webhook` is cleared for new tasks. # If tasks remain in `tasks_for_webhook` at process end (less than `WEBHOOK_BATCH_SIZE`), they're sent in a final webhook # call to ensure all tasks are processed and no task is left unreported in the webhook. if len(tasks_for_webhook) >= settings.WEBHOOK_BATCH_SIZE: emit_webhooks_for_instance( self.project.organization, self.project, WebhookAction.TASKS_CREATED, tasks_for_webhook ) tasks_for_webhook = [] self.info_update_progress(last_sync_count=tasks_created, tasks_existed=tasks_existed) if tasks_for_webhook: emit_webhooks_for_instance( self.project.organization, self.project, WebhookAction.TASKS_CREATED, tasks_for_webhook ) # Create initial FSM states for all tasks created during storage sync # CurrentContext is now available because we use start_job_async_or_sync from fsm.functions import backfill_fsm_states_for_tasks backfill_fsm_states_for_tasks(self.id, tasks_created, link_class) self.project.update_tasks_states( maximum_annotations_changed=False, overlap_cohort_percentage_changed=False, tasks_number_changed=True ) if validation_errors: # sync is finished, set completed with errors status for storage info self.info_set_completed_with_errors( last_sync_count=tasks_created, tasks_existed=tasks_existed, validation_errors=validation_errors ) else: # sync is finished, set completed status for storage info self.info_set_completed(last_sync_count=tasks_created, tasks_existed=tasks_existed) def scan_and_create_links(self): """This is proto method - you can override it, or just replace ImportStorageLink by your own model""" self._scan_and_create_links(ImportStorageLink) def sync(self): if redis_connected(): queue_name = 'low' queue = django_rq.get_queue(queue_name) meta = {'project': self.project.id, 'storage': self.id} if not is_job_in_queue(queue, 'import_sync_background', meta=meta) and not is_job_on_worker( job_id=self.last_sync_job, queue_name=queue_name ): if not self.info_set_queued(): return # Use start_job_async_or_sync to automatically capture and restore CurrentContext # This ensures user_id, organization_id, and request_id are available in the worker sync_job = start_job_async_or_sync( import_sync_background, self.__class__, self.id, queue_name=queue_name, meta=meta, project_id=self.project.id, organization_id=self.project.organization.id, on_failure=storage_background_failure, job_timeout=settings.RQ_LONG_JOB_TIMEOUT, ) self.info_set_job(sync_job.id) logger.info(f'Storage sync background job {sync_job.id} for storage {self} has been started') else: try: logger.info(f'Start syncing storage {self}') if not self.info_set_queued(): return import_sync_background(self.__class__, self.id) except Exception: # needed to facilitate debugging storage-related testcases, since otherwise no exception is logged logger.debug(f'Storage {self} failed', exc_info=True) storage_background_failure(self) class Meta: abstract = True class ProjectStorageMixin(models.Model): project = models.ForeignKey( 'projects.Project', related_name='%(app_label)s_%(class)ss', on_delete=models.CASCADE, help_text='A unique integer value identifying this project.', ) def has_permission(self, user): user.project = self.project # link for activity log if self.project.has_permission(user): return True return False class Meta: abstract = True def import_sync_background(storage_class, storage_id, timeout=settings.RQ_LONG_JOB_TIMEOUT, **kwargs): storage = storage_class.objects.get(id=storage_id) try: storage.scan_and_create_links() except UnsupportedFileFormatError: # This is an expected error when user tries to import non-JSON files without enabling blob URLs # We don't want to fail the job in this case, just mark the storage as failed with a clear message storage.info_set_failed() # Exit gracefully without raising exception to avoid job failure return def export_sync_background(storage_class, storage_id, **kwargs): storage = storage_class.objects.get(id=storage_id) storage.save_all_annotations() def export_sync_only_new_background(storage_class, storage_id, **kwargs): storage = storage_class.objects.get(id=storage_id) storage.save_only_new_annotations() def storage_background_failure(*args, **kwargs): # job is used in rqworker failure, extract storage id from job arguments if isinstance(args[0], rq.job.Job): sync_job = args[0] _class = sync_job.args[0] storage_id = sync_job.args[1] storage = _class.objects.filter(id=storage_id).first() if storage is None: logger.info(f'Storage {_class} {storage_id} not found at job {sync_job} failure') return # storage is used when redis and rqworkers are not available (e.g. in opensource) elif isinstance(args[0], Storage): # we have to load storage with the last states from DB # the current args[0] instance might be outdated storage_id = args[0].id storage = args[0].__class__.objects.filter(id=storage_id).first() else: raise ValueError(f'Unknown storage in {args}') # save info about failure for storage info storage.info_set_failed() # note: this is available in python 3.12 , #TODO to switch to builtin function when we move to it. def _batched(iterable, n): # batched('ABCDEFG', 3) --> ABC DEF G if n < 1: raise ValueError('n must be at least one') it = iter(iterable) while batch := tuple(itertools.islice(it, n)): yield batch class ExportStorage(Storage, ProjectStorageMixin): can_delete_objects = models.BooleanField( _('can_delete_objects'), null=True, blank=True, help_text='Deletion from storage enabled' ) # Use 8 threads, unless we know we only have a single core # TODO from testing, more than 8 seems to cause problems. revisit to add more parallelism. max_workers = min(8, (os.cpu_count() or 2) * 4) def _get_serialized_data(self, annotation): user = self.project.organization.created_by flag = flag_set( 'fflag_feat_optic_650_target_storage_task_format_long', user=user, override_system_default=False ) if settings.FUTURE_SAVE_TASK_TO_STORAGE or flag: # export task with annotations # TODO: we have to rewrite save_all_annotations, because this func will be called for each annotation # TODO: instead of each task, however, we have to call it only once per task expand = ['annotations.reviews', 'annotations.completed_by'] context = {'project': self.project} return ExportDataSerializer(annotation.task, context=context, expand=expand).data else: serializer_class = load_func(settings.STORAGE_ANNOTATION_SERIALIZER) # deprecated functionality - save only annotation return serializer_class(annotation, context={'project': self.project}).data def save_annotation(self, annotation): raise NotImplementedError def save_annotations(self, annotations: models.QuerySet[Annotation]): annotation_exported = 0 total_annotations = annotations.count() self.info_set_in_progress() self.cached_user = self.project.organization.created_by # Calculate optimal batch size based on project data and worker count project_batch_size = self.project.get_task_batch_size() chunk_size = max(1, project_batch_size // self.max_workers) logger.info( f'Export storage {self.id}: using chunk_size={chunk_size} ' f'(project_batch_size={project_batch_size}, max_workers={self.max_workers})' ) with ThreadPoolExecutor(max_workers=self.max_workers) as executor: # Batch annotations so that we update progress before having to submit every future. # Updating progress in thread requires coordinating on count and db writes, so just # batching to keep it simpler. for annotation_batch in _batched( iterate_queryset(Annotation.objects.filter(project=self.project), chunk_size=chunk_size), chunk_size, ): futures = [] for annotation in annotation_batch: annotation.cached_user = self.cached_user futures.append(executor.submit(self.save_annotation, annotation)) for future in concurrent.futures.as_completed(futures): annotation_exported += 1 self.info_update_progress(last_sync_count=annotation_exported, total_annotations=total_annotations) self.info_set_completed(last_sync_count=annotation_exported, total_annotations=total_annotations) def save_all_annotations(self): self.save_annotations(Annotation.objects.filter(project=self.project)) def save_only_new_annotations(self): """Do not update existing annotations, only ensure that all annotations have an ExportStorageLink""" # Get the storage-specific ExportStorageLink model storage_link_model = self.links.model new_annotations = Annotation.objects.filter(project=self.project).exclude( id__in=storage_link_model.objects.filter(storage=self, annotation__project=self.project).values( 'annotation_id' ) ) self.save_annotations(new_annotations) def sync(self, save_only_new_annotations: bool = False): if save_only_new_annotations: export_sync_fn = export_sync_only_new_background else: export_sync_fn = export_sync_background if redis_connected(): queue = django_rq.get_queue('low') if not self.info_set_queued(): return sync_job = queue.enqueue( export_sync_fn, self.__class__, self.id, job_timeout=settings.RQ_LONG_JOB_TIMEOUT, project_id=self.project.id, organization_id=self.project.organization.id, on_failure=storage_background_failure, ) self.info_set_job(sync_job.id) logger.info(f'Storage sync background job {sync_job.id} for storage {self} has been queued') else: try: logger.info(f'Start syncing storage {self}') if not self.info_set_queued(): return export_sync_fn(self.__class__, self.id) except Exception: storage_background_failure(self) class Meta: abstract = True class ImportStorageLink(models.Model): task = models.OneToOneField('tasks.Task', on_delete=models.CASCADE, related_name='%(app_label)s_%(class)s') key = models.TextField(_('key'), null=False, help_text='External link key') # This field is set to True on creation and never updated; it should not be relied upon. object_exists = models.BooleanField( _('object exists'), help_text='Whether object under external link still exists', default=True ) created_at = models.DateTimeField(_('created at'), auto_now_add=True, help_text='Creation time') row_group = models.IntegerField(null=True, blank=True, help_text='Parquet row group') row_index = models.IntegerField(null=True, blank=True, help_text='Parquet row index, or JSON[L] object index') @classmethod def exists(cls, keys, storage) -> set[str]: return set(cls.objects.filter(key__in=keys, storage=storage.id).values_list('key', flat=True).distinct()) @classmethod def create(cls, task, key, storage, row_index=None, row_group=None): link, created = cls.objects.get_or_create( task_id=task.id, key=key, row_index=row_index, row_group=row_group, storage=storage, object_exists=True ) return link class Meta: abstract = True class ExportStorageLink(models.Model): annotation = models.ForeignKey( 'tasks.Annotation', on_delete=models.CASCADE, related_name='%(app_label)s_%(class)s' ) object_exists = models.BooleanField( _('object exists'), help_text='Whether object under external link still exists', default=True ) created_at = models.DateTimeField(_('created at'), auto_now_add=True, help_text='Creation time') updated_at = models.DateTimeField(_('updated at'), auto_now=True, help_text='Update time') @staticmethod def get_key(annotation): # get user who created the organization explicitly using filter/values_list to avoid prefetching user = getattr(annotation, 'cached_user', None) # when signal for annotation save is called, user is not cached if user is None: user = annotation.project.organization.created_by flag = flag_set('fflag_feat_optic_650_target_storage_task_format_long', user=user) if settings.FUTURE_SAVE_TASK_TO_STORAGE or flag: ext = '.json' if settings.FUTURE_SAVE_TASK_TO_STORAGE_JSON_EXT or flag else '' return str(annotation.task.id) + ext else: return str(annotation.id) @property def key(self): return self.get_key(self.annotation) @classmethod def exists(cls, annotation, storage): return cls.objects.filter(annotation=annotation.id, storage=storage.id).exists() @classmethod def create(cls, annotation, storage): link, created = cls.objects.get_or_create(annotation=annotation, storage=storage, object_exists=True) if not created: # update updated_at field link.save() return link def has_permission(self, user): user.project = self.annotation.project # link for activity log if self.annotation.has_permission(user): return True return False class Meta: abstract = True