import logging
import time
import traceback
from typing import Callable, Optional

from core.feature_flags import flag_set
from core.utils.common import load_func
from data_import.uploader import load_tasks_for_async_import_streaming
from django.conf import settings
from django.db import transaction
from label_studio_sdk.label_interface import LabelInterface
from projects.models import ProjectImport, ProjectReimport, ProjectSummary
from rest_framework.exceptions import ValidationError
from tasks.models import Task
from users.models import User
from webhooks.models import WebhookAction
from webhooks.utils import emit_webhooks_for_instance

from .models import FileUpload
from .serializers import ImportApiSerializer
from .uploader import load_tasks_for_async_import

logger = logging.getLogger(__name__)


def async_import_background(
    import_id, user_id, recalculate_stats_func: Optional[Callable[..., None]] = None, **kwargs
):
    with transaction.atomic():
        try:
            project_import = ProjectImport.objects.get(id=import_id)
        except ProjectImport.DoesNotExist:
            logger.error(f'ProjectImport with id {import_id} not found, import processing failed')
            return
        if project_import.status != ProjectImport.Status.CREATED:
            logger.error(f'Processing import with id {import_id} already started')
            return
        project_import.status = ProjectImport.Status.IN_PROGRESS
        project_import.save(update_fields=['status'])

    user = User.objects.get(id=user_id)

    if flag_set('fflag_fix_back_plt_902_async_import_background_oom_fix_22092025_short', user='auto'):
        logger.info(f'Using streaming import for project {project_import.project.id}')
        _async_import_background_streaming(project_import, user)
        return

    start = time.time()
    project = project_import.project
    tasks = None

    # Upload files from the request and parse all tasks.
    # TODO: stop passing the request to the load_tasks function; perform all validation beforehand
    tasks, file_upload_ids, found_formats, data_columns = load_tasks_for_async_import(project_import, user)

    if project_import.preannotated_from_fields:
        # turn flat task JSONs {"column1": value, "column2": value} into
        # {"data": {"column1": ...}, "predictions": [{... "column2" ...}]}
        raise_errors = flag_set(
            'fflag_feat_utc_210_prediction_validation_15082025', user=project.organization.created_by
        )
        logger.info(f'Reformatting predictions with raise_errors: {raise_errors}')
        tasks = reformat_predictions(tasks, project_import.preannotated_from_fields, project, raise_errors)

    # Always validate predictions regardless of commit_to_project setting
    if project.label_config_is_not_default and flag_set(
        'fflag_feat_utc_210_prediction_validation_15082025', user=project.organization.created_by
    ):
        validation_errors = []
        li = LabelInterface(project.label_config)
        for i, task in enumerate(tasks):
            if 'predictions' in task:
                for j, prediction in enumerate(task['predictions']):
                    try:
                        validation_errors_list = li.validate_prediction(prediction, return_errors=True)
                        if validation_errors_list:
                            for error in validation_errors_list:
                                validation_errors.append(f'Task {i}, prediction {j}: {error}')
                    except Exception as e:
                        error_msg = f'Task {i}, prediction {j}: Error validating prediction - {str(e)}'
                        validation_errors.append(error_msg)
                        logger.error(f'Exception during validation: {error_msg}')

        if validation_errors:
            error_message = f'Prediction validation failed ({len(validation_errors)} errors):\n'
            for error in validation_errors:
                error_message += f'- {error}\n'
            if flag_set('fflag_feat_utc_210_prediction_validation_15082025', user=project.organization.created_by):
                project_import.error = error_message
                project_import.status = ProjectImport.Status.FAILED
                project_import.save(update_fields=['error', 'status'])
                return
            else:
                logger.error(
                    f'Prediction validation failed, not raising error - ({len(validation_errors)} errors):\n{error_message}'
                )

    if project_import.commit_to_project:
        with transaction.atomic():
            # Lock summary for update to avoid race conditions
            summary = ProjectSummary.objects.select_for_update().get(project=project)
            # Immediately create project tasks and update project states and counters
            serializer = ImportApiSerializer(data=tasks, many=True, context={'project': project})
            serializer.is_valid(raise_exception=True)
            try:
                tasks = serializer.save(project_id=project.id)
                emit_webhooks_for_instance(user.active_organization, project, WebhookAction.TASKS_CREATED, tasks)
                task_count = len(tasks)
                annotation_count = len(serializer.db_annotations)
                prediction_count = len(serializer.db_predictions)

                # Update counters (like total_annotations) for new tasks and after bulk update tasks stats.
                # It should be a single operation as counters affect the bulk is_labeled update
                recalculate_stats_counts = {
                    'task_count': task_count,
                    'annotation_count': annotation_count,
                    'prediction_count': prediction_count,
                }
                project.update_tasks_counters_and_task_states(
                    tasks_queryset=tasks,
                    maximum_annotations_changed=False,
                    overlap_cohort_percentage_changed=False,
                    tasks_number_changed=True,
                    recalculate_stats_counts=recalculate_stats_counts,
                )
                logger.info('Tasks bulk_update finished (async import)')

                summary.update_data_columns(tasks)
                # TODO: summary.update_created_annotations_and_labels
            except Exception as e:
                # Handle any other unexpected errors during task creation
                error_message = f'Error creating tasks: {str(e)}'
                project_import.error = error_message
                project_import.status = ProjectImport.Status.FAILED
                project_import.save(update_fields=['error', 'status'])
                return
    else:
        # Do nothing - just output file upload ids for further use
        task_count = len(tasks)
        annotation_count = None
        prediction_count = None

    duration = time.time() - start

    project_import.task_count = task_count or 0
    project_import.annotation_count = annotation_count or 0
    project_import.prediction_count = prediction_count or 0
    project_import.duration = duration
    project_import.file_upload_ids = file_upload_ids
    project_import.found_formats = found_formats
    project_import.data_columns = data_columns
    if project_import.return_task_ids:
        project_import.task_ids = [task.id for task in tasks]
    project_import.status = ProjectImport.Status.COMPLETED
    project_import.save()
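
# Illustrative shape of the `tasks` payload handed to ImportApiSerializer above.
# The values are made up; each item is a task dict where 'predictions' (and,
# depending on the input format, 'annotations') are optional keys:
#
#   [
#       {
#           'data': {'text': 'Great product!'},
#           'predictions': [{'result': [...], 'score': 1.0, 'model_version': 'v1'}],
#       },
#   ]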


def set_import_background_failure(job, connection, type, value, _):
    import_id = job.args[0]
    ProjectImport.objects.filter(id=import_id).update(
        status=ProjectImport.Status.FAILED, traceback=traceback.format_exc(), error=str(value)
    )


def set_reimport_background_failure(job, connection, type, value, _):
    reimport_id = job.args[0]
    ProjectReimport.objects.filter(id=reimport_id).update(
        status=ProjectReimport.Status.FAILED,
        traceback=traceback.format_exc(),
        error=str(value),
    )
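
# Both handlers above match RQ's on_failure callback signature
# (job, connection, exc_type, exc_value, traceback); the traceback argument is
# named `_` and the handlers call traceback.format_exc() instead. A minimal
# sketch of how such a handler could be attached when enqueueing (illustrative
# only; the actual enqueueing code lives elsewhere):
#
#   import django_rq
#
#   queue = django_rq.get_queue('default')
#   queue.enqueue(
#       async_import_background,
#       project_import.id,
#       user.id,
#       on_failure=set_import_background_failure,
#   )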


def reformat_predictions(tasks, preannotated_from_fields, project=None, raise_errors=False):
    """Transform flat task JSON objects into the proper format with separate data and predictions fields.

    Also validates the predictions to ensure they are properly formatted, using LabelInterface.

    Args:
        tasks: List of task data
        preannotated_from_fields: List of field names to convert to predictions
        project: Optional project instance to determine correct to_name and type from label config
        raise_errors: If True, raise a ValidationError when a preannotated field is missing from task data
    """
    new_tasks = []
    validation_errors = []

    # If project is provided, create LabelInterface to determine correct mappings
    li = None
    if project:
        try:
            li = LabelInterface(project.label_config)
        except Exception as e:
            logger.warning(f'Could not create LabelInterface for project {project.id}: {e}')

    for task_index, task in enumerate(tasks):
        if 'data' in task:
            task_data = task['data']
        else:
            task_data = task

        predictions = []
        for field in preannotated_from_fields:
            if field not in task_data:
                validation_errors.append(f"Task {task_index}: Preannotated field '{field}' not found in task data")
                continue

            value = task_data[field]
            if value is not None:
                # Try to determine correct to_name and type from project configuration
                to_name = 'text'  # Default fallback
                prediction_type = 'choices'  # Default fallback

                if li:
                    # Find a control tag that matches the field name
                    try:
                        control_tag = li.get_control(field)
                        # Use the control's to_name and determine type
                        if hasattr(control_tag, 'to_name') and control_tag.to_name:
                            to_name = (
                                control_tag.to_name[0]
                                if isinstance(control_tag.to_name, list)
                                else control_tag.to_name
                            )
                        prediction_type = control_tag.tag.lower()
                    except Exception:
                        # Control not found, use defaults
                        pass

                # Create a prediction from the preannotated field,
                # handling different types of values:
                if isinstance(value, dict):
                    # For complex structures like bounding boxes, use the value directly
                    prediction_value = value
                else:
                    # For simple values, use the prediction_type as the key,
                    # handling cases where the type doesn't match the expected key
                    value_key = prediction_type
                    if prediction_type == 'textarea':
                        value_key = 'text'

                    # Most types expect lists, but some expect single values
                    if prediction_type in ['rating', 'number', 'datetime']:
                        prediction_value = {value_key: value}
                    else:
                        # Wrap in a list for most types
                        prediction_value = {value_key: [value] if not isinstance(value, list) else value}

                prediction = {
                    'result': [
                        {
                            'from_name': field,
                            'to_name': to_name,
                            'type': prediction_type,
                            'value': prediction_value,
                        }
                    ],
                    'score': 1.0,
                    'model_version': 'preannotated',
                }
                predictions.append(prediction)

        # Create new task structure
        new_task = {'data': task_data, 'predictions': predictions}
        new_tasks.append(new_task)

    # If there are validation errors, raise them
    if validation_errors and raise_errors:
        raise ValidationError({'preannotated_fields': validation_errors})

    return new_tasks
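
# Example of the transformation reformat_predictions performs, with made-up
# values, assuming a Choices control named 'sentiment' whose to_name is 'text':
#
#   input:  {'text': 'Great product!', 'sentiment': 'positive'}
#   output: {
#       'data': {'text': 'Great product!', 'sentiment': 'positive'},
#       'predictions': [{
#           'result': [{
#               'from_name': 'sentiment',
#               'to_name': 'text',
#               'type': 'choices',
#               'value': {'choices': ['positive']},
#           }],
#           'score': 1.0,
#           'model_version': 'preannotated',
#       }],
#   }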


post_process_reimport = load_func(settings.POST_PROCESS_REIMPORT)


def _async_reimport_background_streaming(reimport, project, organization_id, user):
    """Streaming version of reimport that processes tasks in batches to reduce memory usage"""
    try:
        # Get batch size from settings or use default
        batch_size = settings.REIMPORT_BATCH_SIZE

        # Initialize counters
        total_task_count = 0
        total_annotation_count = 0
        total_prediction_count = 0
        all_found_formats = {}
        all_data_columns = set()
        all_created_task_ids = []

        # Remove old tasks once before starting
        with transaction.atomic():
            project.remove_tasks_by_file_uploads(reimport.file_upload_ids)

        # Process tasks in batches
        batch_number = 0
        for batch_tasks, batch_formats, batch_columns in FileUpload.load_tasks_from_uploaded_files_streaming(
            project, reimport.file_upload_ids, files_as_tasks_list=reimport.files_as_tasks_list, batch_size=batch_size
        ):
            if not batch_tasks:
                logger.info(f'Empty batch received for reimport {reimport.id}')
                continue

            batch_number += 1
            logger.info(f'Processing batch {batch_number} with {len(batch_tasks)} tasks for reimport {reimport.id}')

            # Process batch in a transaction
            with transaction.atomic():
                # Lock summary for update to avoid race conditions
                summary = ProjectSummary.objects.select_for_update().get(project=project)

                # Serialize and save batch
                serializer = ImportApiSerializer(
                    data=batch_tasks, many=True, context={'project': project, 'user': user}
                )
                serializer.is_valid(raise_exception=True)
                batch_db_tasks = serializer.save(project_id=project.id)

                # Collect task IDs for later use
                all_created_task_ids.extend([t.id for t in batch_db_tasks])

                # Update batch counters
                batch_task_count = len(batch_db_tasks)
                batch_annotation_count = len(serializer.db_annotations)
                batch_prediction_count = len(serializer.db_predictions)

                total_task_count += batch_task_count
                total_annotation_count += batch_annotation_count
                total_prediction_count += batch_prediction_count

                # Update formats and columns
                all_found_formats.update(batch_formats)
                if batch_columns:
                    if not all_data_columns:
                        all_data_columns = batch_columns
                    else:
                        all_data_columns &= batch_columns

                # Update data columns in summary
                summary.update_data_columns(batch_db_tasks)

                logger.info(
                    f'Batch {batch_number} processed successfully: {batch_task_count} tasks, '
                    f'{batch_annotation_count} annotations, {batch_prediction_count} predictions'
                )

        # After all batches are processed, emit webhooks and update task states once
        if all_created_task_ids:
            logger.info(
                f'Finalizing reimport: emitting webhooks and updating task states for {len(all_created_task_ids)} tasks'
            )

            # Emit webhooks for all tasks at once (passing list of IDs)
            emit_webhooks_for_instance(organization_id, project, WebhookAction.TASKS_CREATED, all_created_task_ids)

            # Update task states for all tasks at once
            all_tasks_queryset = Task.objects.filter(id__in=all_created_task_ids)
            recalculate_stats_counts = {
                'task_count': total_task_count,
                'annotation_count': total_annotation_count,
                'prediction_count': total_prediction_count,
            }
            project.update_tasks_counters_and_task_states(
                tasks_queryset=all_tasks_queryset,
                maximum_annotations_changed=False,
                overlap_cohort_percentage_changed=False,
                tasks_number_changed=True,
                recalculate_stats_counts=recalculate_stats_counts,
            )
            logger.info('Tasks bulk_update finished (async streaming reimport)')

        # Update reimport with final statistics
        reimport.task_count = total_task_count
        reimport.annotation_count = total_annotation_count
        reimport.prediction_count = total_prediction_count
        reimport.found_formats = all_found_formats
        reimport.data_columns = list(all_data_columns)
        reimport.status = ProjectReimport.Status.COMPLETED
        reimport.save()

        logger.info(f'Streaming reimport {reimport.id} completed: {total_task_count} tasks imported')

        # Run post-processing
        post_process_reimport(reimport)

    except Exception as e:
        logger.error(f'Error in streaming reimport {reimport.id}: {str(e)}', exc_info=True)
        reimport.status = ProjectReimport.Status.FAILED
        reimport.traceback = traceback.format_exc()
        reimport.error = str(e)
        reimport.save()
        raise
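
# Shape of the batches consumed above, inferred from this call site of
# FileUpload.load_tasks_from_uploaded_files_streaming (not from the generator
# itself):
#
#   batch_tasks   - list of parsed task dicts, at most batch_size long
#   batch_formats - dict of file formats detected while parsing the batch
#   batch_columns - set of data column names seen in the batch; note the loop
#                   intersects (&=) these across batches, so only columns
#                   common to every non-empty batch survive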


def _async_import_background_streaming(project_import, user):
    """Streaming version of async import that processes tasks in batches to reduce memory usage"""
    try:
        batch_size = settings.IMPORT_BATCH_SIZE

        total_task_count = 0
        total_annotation_count = 0
        total_prediction_count = 0
        all_created_task_ids = []

        project = project_import.project
        start = time.time()
        batch_number = 0

        streaming_generator = load_tasks_for_async_import_streaming(project_import, user, batch_size)

        final_file_upload_ids = []
        final_found_formats = {}
        final_data_columns = set()

        for batch_tasks, file_upload_ids, found_formats, data_columns in streaming_generator:
            if not batch_tasks:
                logger.info(f'Empty batch received for import {project_import.id}')
                continue

            batch_number += 1
            logger.info(
                f'Processing batch {batch_number} with {len(batch_tasks)} tasks for import {project_import.id}'
            )

            if file_upload_ids and file_upload_ids not in final_file_upload_ids:
                final_file_upload_ids = file_upload_ids
            final_found_formats.update(found_formats)
            final_data_columns.update(data_columns)

            if project_import.preannotated_from_fields:
                raise_errors = flag_set(
                    'fflag_feat_utc_210_prediction_validation_15082025', user=project.organization.created_by
                )
                logger.info(f'Reformatting predictions with raise_errors: {raise_errors}')
                batch_tasks = reformat_predictions(
                    batch_tasks, project_import.preannotated_from_fields, project, raise_errors
                )

            if project.label_config_is_not_default and flag_set(
                'fflag_feat_utc_210_prediction_validation_15082025', user=project.organization.created_by
            ):
                validation_errors = []
                li = LabelInterface(project.label_config)
                for i, task in enumerate(batch_tasks):
                    if 'predictions' in task:
                        for j, prediction in enumerate(task['predictions']):
                            try:
                                validation_errors_list = li.validate_prediction(prediction, return_errors=True)
                                if validation_errors_list:
                                    for error in validation_errors_list:
                                        validation_errors.append(
                                            f'Task {total_task_count + i}, prediction {j}: {error}'
                                        )
                            except Exception as e:
                                error_msg = f'Task {total_task_count + i}, prediction {j}: Error validating prediction - {str(e)}'
                                validation_errors.append(error_msg)
                                logger.error(f'Exception during validation: {error_msg}')

                if validation_errors:
                    error_message = f'Prediction validation failed ({len(validation_errors)} errors):\n'
                    for error in validation_errors:
                        error_message += f'- {error}\n'
                    if flag_set(
                        'fflag_feat_utc_210_prediction_validation_15082025', user=project.organization.created_by
                    ):
                        project_import.error = error_message
                        project_import.status = ProjectImport.Status.FAILED
                        project_import.save(update_fields=['error', 'status'])
                        return
                    else:
                        logger.error(
                            f'Prediction validation failed, not raising error - ({len(validation_errors)} errors):\n{error_message}'
                        )

            if project_import.commit_to_project:
                with transaction.atomic():
                    summary = ProjectSummary.objects.select_for_update().get(project=project)

                    serializer = ImportApiSerializer(data=batch_tasks, many=True, context={'project': project})
                    serializer.is_valid(raise_exception=True)
                    batch_db_tasks = serializer.save(project_id=project.id)

                    all_created_task_ids.extend([t.id for t in batch_db_tasks])

                    batch_task_count = len(batch_db_tasks)
                    batch_annotation_count = len(serializer.db_annotations)
                    batch_prediction_count = len(serializer.db_predictions)

                    total_task_count += batch_task_count
                    total_annotation_count += batch_annotation_count
                    total_prediction_count += batch_prediction_count

                    summary.update_data_columns(batch_db_tasks)
            else:
                total_task_count += len(batch_tasks)

            logger.info(f'Batch {batch_number} processed successfully: {len(batch_tasks)} tasks')

        final_data_columns = list(final_data_columns)

        if project_import.commit_to_project and all_created_task_ids:
            logger.info(
                f'Finalizing import: emitting webhooks and updating task states for {len(all_created_task_ids)} tasks'
            )
            emit_webhooks_for_instance(
                user.active_organization, project, WebhookAction.TASKS_CREATED, all_created_task_ids
            )
            recalculate_stats_counts = {
                'task_count': total_task_count,
                'annotation_count': total_annotation_count,
                'prediction_count': total_prediction_count,
            }
            all_tasks_queryset = Task.objects.filter(id__in=all_created_task_ids)
            project.update_tasks_counters_and_task_states(
                tasks_queryset=all_tasks_queryset,
                maximum_annotations_changed=False,
                overlap_cohort_percentage_changed=False,
                tasks_number_changed=True,
                recalculate_stats_counts=recalculate_stats_counts,
            )
            logger.info('Tasks bulk_update finished (async streaming import)')

        duration = time.time() - start

        project_import.task_count = total_task_count or 0
        project_import.annotation_count = total_annotation_count or 0
        project_import.prediction_count = total_prediction_count or 0
        project_import.duration = duration
        project_import.file_upload_ids = final_file_upload_ids
        project_import.found_formats = final_found_formats
        project_import.data_columns = final_data_columns
        if project_import.return_task_ids:
            project_import.task_ids = all_created_task_ids
        project_import.status = ProjectImport.Status.COMPLETED
        project_import.save()

        logger.info(f'Streaming import {project_import.id} completed: {total_task_count} tasks imported')

    except Exception as e:
        logger.error(f'Error in streaming import {project_import.id}: {str(e)}', exc_info=True)
        project_import.status = ProjectImport.Status.FAILED
        project_import.traceback = traceback.format_exc()
        project_import.error = str(e)
        project_import.save()
        raise
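
# For comparison, the tuples yielded by load_tasks_for_async_import_streaming,
# again inferred from the consuming loop above (not from the generator itself):
#
#   batch_tasks     - list of parsed task dicts for the batch
#   file_upload_ids - IDs of the uploaded files seen so far (the loop keeps the
#                     most recent non-empty list)
#   found_formats   - dict of detected file formats, merged across batches
#   data_columns    - set of data column names; unlike the reimport path, these
#                     are unioned across batches rather than intersected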


def async_reimport_background(reimport_id, organization_id, user, **kwargs):
    with transaction.atomic():
        try:
            reimport = ProjectReimport.objects.get(id=reimport_id)
        except ProjectReimport.DoesNotExist:
            logger.error(f'ProjectReimport with id {reimport_id} not found, import processing failed')
            return

        if reimport.status != ProjectReimport.Status.CREATED:
            logger.error(f'Processing reimport with id {reimport_id} already started')
            return

        reimport.status = ProjectReimport.Status.IN_PROGRESS
        reimport.save(update_fields=['status'])

    project = reimport.project

    # Check feature flag for memory improvement
    if flag_set('fflag_fix_back_plt_838_reimport_memory_improvement_05082025_short', user='auto'):
        logger.info(f'Using streaming reimport for project {project.id}')
        _async_reimport_background_streaming(reimport, project, organization_id, user)
    else:
        # Original implementation
        tasks, found_formats, data_columns = FileUpload.load_tasks_from_uploaded_files(
            reimport.project, reimport.file_upload_ids, files_as_tasks_list=reimport.files_as_tasks_list
        )

        with transaction.atomic():
            # Lock summary for update to avoid race conditions
            summary = ProjectSummary.objects.select_for_update().get(project=project)

            project.remove_tasks_by_file_uploads(reimport.file_upload_ids)
            serializer = ImportApiSerializer(data=tasks, many=True, context={'project': project, 'user': user})
            serializer.is_valid(raise_exception=True)
            tasks = serializer.save(project_id=project.id)
            emit_webhooks_for_instance(organization_id, project, WebhookAction.TASKS_CREATED, tasks)

            task_count = len(tasks)
            annotation_count = len(serializer.db_annotations)
            prediction_count = len(serializer.db_predictions)

            recalculate_stats_counts = {
                'task_count': task_count,
                'annotation_count': annotation_count,
                'prediction_count': prediction_count,
            }

            # Update counters (like total_annotations) for new tasks and after bulk update tasks stats.
            # It should be a single operation as counters affect the bulk is_labeled update
            project.update_tasks_counters_and_task_states(
                tasks_queryset=tasks,
                maximum_annotations_changed=False,
                overlap_cohort_percentage_changed=False,
                tasks_number_changed=True,
                recalculate_stats_counts=recalculate_stats_counts,
            )
            logger.info('Tasks bulk_update finished (async reimport)')

            summary.update_data_columns(tasks)
            # TODO: summary.update_created_annotations_and_labels

        reimport.task_count = task_count
        reimport.annotation_count = annotation_count
        reimport.prediction_count = prediction_count
        reimport.found_formats = found_formats
        reimport.data_columns = list(data_columns)
        reimport.status = ProjectReimport.Status.COMPLETED
        reimport.save()

        post_process_reimport(reimport)