"""This file and its contents are licensed under the Apache License 2.0. Please see the included NOTICE for copyright information and LICENSE for a copy of the license. """ import json import logging import mimetypes import time from urllib.parse import unquote, urlparse from core.decorators import override_report_only_csp from core.feature_flags import flag_set from core.permissions import ViewClassPermission, all_permissions from core.redis import start_job_async_or_sync from core.utils.common import retry_database_locked, timeit from core.utils.params import bool_from_request, list_of_strings_from_request from csp.decorators import csp from django.conf import settings from django.db import transaction from django.http import HttpResponse from django.utils.decorators import method_decorator from drf_spectacular.types import OpenApiTypes from drf_spectacular.utils import OpenApiParameter, OpenApiResponse, extend_schema from label_studio_sdk.label_interface import LabelInterface from projects.models import Project, ProjectImport, ProjectReimport from ranged_fileresponse import RangedFileResponse from rest_framework import generics, status from rest_framework.exceptions import ValidationError from rest_framework.parsers import FormParser, JSONParser, MultiPartParser from rest_framework.permissions import IsAuthenticated from rest_framework.response import Response from rest_framework.settings import api_settings from rest_framework.views import APIView from tasks.functions import update_tasks_counters from tasks.models import Prediction, Task from users.models import User from webhooks.models import WebhookAction from webhooks.utils import emit_webhooks_for_instance from label_studio.core.utils.common import load_func from .functions import ( async_import_background, async_reimport_background, reformat_predictions, set_import_background_failure, set_reimport_background_failure, ) from .models import FileUpload from .serializers import FileUploadSerializer, ImportApiSerializer, PredictionSerializer from .uploader import create_file_uploads, load_tasks logger = logging.getLogger(__name__) ProjectImportPermission = load_func(settings.PROJECT_IMPORT_PERMISSION) task_create_response_scheme = { 201: OpenApiResponse( description='Tasks successfully imported', response={ 'title': 'Task creation response', 'description': 'Task creation response', 'type': 'object', 'properties': { 'task_count': { 'title': 'task_count', 'description': 'Number of tasks added', 'type': 'integer', }, 'annotation_count': { 'title': 'annotation_count', 'description': 'Number of annotations added', 'type': 'integer', }, 'predictions_count': { 'title': 'predictions_count', 'description': 'Number of predictions added', 'type': 'integer', }, 'duration': { 'title': 'duration', 'description': 'Time in seconds to create', 'type': 'number', }, 'file_upload_ids': { 'title': 'file_upload_ids', 'description': 'Database IDs of uploaded files', 'type': 'array', 'items': { 'title': 'File Upload IDs', 'type': 'integer', }, }, 'could_be_tasks_list': { 'title': 'could_be_tasks_list', 'description': 'Whether uploaded files can contain lists of tasks, like CSV/TSV files', 'type': 'boolean', }, 'found_formats': { 'title': 'found_formats', 'description': 'The list of found file formats', 'type': 'array', 'items': { 'title': 'File format', 'type': 'string', }, }, 'data_columns': { 'title': 'data_columns', 'description': 'The list of found data columns', 'type': 'array', 'items': { 'title': 'Data column name', 'type': 'string', }, }, }, }, ), 
400: OpenApiResponse( description='Bad Request', response={ 'title': 'Incorrect task data', 'description': 'String with error description', 'type': 'string', }, ), } @method_decorator( name='post', decorator=extend_schema( tags=['Import'], responses=task_create_response_scheme, parameters=[ OpenApiParameter( name='id', type=OpenApiTypes.INT, location='path', description='A unique integer value identifying this project.', ), OpenApiParameter( name='commit_to_project', type=OpenApiTypes.BOOL, location='query', description='Set to "true" to immediately commit tasks to the project.', default=True, required=False, ), OpenApiParameter( name='return_task_ids', type=OpenApiTypes.BOOL, location='query', description='Set to "true" to return task IDs in the response.', default=False, required=False, ), OpenApiParameter( name='preannotated_from_fields', many=True, location='query', description='List of fields to preannotate from the task data. For example, if you provide a list of' ' `{"text": "text", "prediction": "label"}` items in the request, the system will create ' 'a task with the `text` field and a prediction with the `label` field when ' '`preannotated_from_fields=["prediction"]`.', default=None, required=False, ), ], summary='Import tasks', description=""" Import data as labeling tasks in bulk using this API endpoint. One POST request is limited to 250K tasks and 200 MB. **Note:** Imported data is verified against a project *label_config* and must include all variables that were used in the *label_config*. For example, if the label configuration has a *$text* variable, then each item in a data object must include a "text" field.
## POST requests
There are three possible ways to import tasks with this endpoint: ### 1. **POST with data** Send JSON tasks as POST data. Only JSON is supported when POSTing data directly. Update this example to specify your authorization token and Label Studio instance host, then run the following from the command line. ```bash curl -H 'Content-Type: application/json' -H 'Authorization: Token abc123' \\ -X POST '{host}/api/projects/1/import' --data '[{{"text": "Some text 1"}}, {{"text": "Some text 2"}}]' ``` ### 2. **POST with files** Send tasks as files. You can attach multiple files with different names. - **JSON**: text files in JavaScript Object Notation format - **CSV**: text files with tables in Comma Separated Values format - **TSV**: text files with tables in Tab Separated Values format - **TXT**: simple text files, similar to CSV with one column and no header; supported only for projects with a single source Update this example to specify your authorization token, Label Studio instance host, and file name and path, then run the following from the command line: ```bash curl -H 'Authorization: Token abc123' \\ -X POST '{host}/api/projects/1/import' -F 'file=@path/to/my_file.csv' ``` ### 3. **POST with URL** You can also provide a URL to a file with labeling tasks. Supported file formats are the same as in option 2. ```bash curl -H 'Content-Type: application/json' -H 'Authorization: Token abc123' \\ -X POST '{host}/api/projects/1/import' \\ --data '[{{"url": "http://example.com/test1.csv"}}, {{"url": "http://example.com/test2.csv"}}]' ```
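Any of the requests above can also be sent from Python. The sketch below mirrors option 1 using the third-party `requests` package; it is illustrative only, and the token and project ID are placeholders to replace with your own values.

```python
import requests

# Placeholder token and project ID - substitute your own values
headers = {{'Authorization': 'Token abc123'}}
tasks = [{{'text': 'Some text 1'}}, {{'text': 'Some text 2'}}]

response = requests.post('{host}/api/projects/1/import', headers=headers, json=tasks)
# A successful request returns HTTP 201 and an import summary,
# e.g. {{'task_count': 2, 'annotation_count': 0, ...}}
print(response.status_code, response.json())
```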
""".format( host=(settings.HOSTNAME or 'https://localhost:8080') ), request=ImportApiSerializer(many=True), extensions={ 'x-fern-sdk-group-name': 'projects', 'x-fern-sdk-method-name': 'import_tasks', 'x-fern-audiences': ['public'], }, ), ) # Import class ImportAPI(generics.CreateAPIView): permission_required = all_permissions.projects_change permission_classes = api_settings.DEFAULT_PERMISSION_CLASSES + [ProjectImportPermission] parser_classes = (JSONParser, MultiPartParser, FormParser) serializer_class = ImportApiSerializer queryset = Task.objects.all() def get_serializer_context(self): project_id = self.kwargs.get('pk') if project_id: project = generics.get_object_or_404(Project.objects.for_user(self.request.user), pk=project_id) else: project = None return {'project': project, 'user': self.request.user} def post(self, *args, **kwargs): return super(ImportAPI, self).post(*args, **kwargs) def _save(self, tasks): serializer = self.get_serializer(data=tasks, many=True) serializer.is_valid(raise_exception=True) task_instances = serializer.save(project_id=self.kwargs['pk']) project = generics.get_object_or_404(Project.objects.for_user(self.request.user), pk=self.kwargs['pk']) emit_webhooks_for_instance( self.request.user.active_organization, project, WebhookAction.TASKS_CREATED, task_instances ) return task_instances, serializer def sync_import(self, request, project, preannotated_from_fields, commit_to_project, return_task_ids): start = time.time() tasks = None # upload files from request, and parse all tasks # TODO: Stop passing request to load_tasks function, make all validation before parsed_data, file_upload_ids, could_be_tasks_list, found_formats, data_columns = load_tasks(request, project) if preannotated_from_fields: # turn flat task JSONs {"column1": value, "column2": value} into {"data": {"column1"..}, "predictions": [{..."column2"}] raise_errors = flag_set('fflag_feat_utc_210_prediction_validation_15082025', user='auto') logger.info(f'Reformatting predictions with raise_errors: {raise_errors}') parsed_data = reformat_predictions(parsed_data, preannotated_from_fields, project, raise_errors) # Conditionally validate predictions: skip when label config is default during project creation if project.label_config_is_not_default: validation_errors = [] li = LabelInterface(project.label_config) for i, task in enumerate(parsed_data): if 'predictions' in task: for j, prediction in enumerate(task['predictions']): try: validation_errors_list = li.validate_prediction(prediction, return_errors=True) if validation_errors_list: for error in validation_errors_list: validation_errors.append(f'Task {i}, prediction {j}: {error}') except Exception as e: error_msg = f'Task {i}, prediction {j}: Error validating prediction - {str(e)}' validation_errors.append(error_msg) if validation_errors: error_message = f'Prediction validation failed ({len(validation_errors)} errors):\n' for error in validation_errors: error_message += f'- {error}\n' if flag_set('fflag_feat_utc_210_prediction_validation_15082025', user='auto'): raise ValidationError({'predictions': [error_message]}) else: logger.error( f'Prediction validation failed, not raising error - ({len(validation_errors)} errors):\n{error_message}' ) if commit_to_project: # Immediately create project tasks and update project states and counters tasks, serializer = self._save(parsed_data) task_count = len(tasks) annotation_count = len(serializer.db_annotations) prediction_count = len(serializer.db_predictions) recalculate_stats_counts = { 'task_count': 
task_count, 'annotation_count': annotation_count, 'prediction_count': prediction_count, } # Update counters (like total_annotations) for new tasks and after bulk update tasks stats. It should be a # single operation as counters affect bulk is_labeled update project.update_tasks_counters_and_task_states( tasks_queryset=tasks, maximum_annotations_changed=False, overlap_cohort_percentage_changed=False, tasks_number_changed=True, recalculate_stats_counts=recalculate_stats_counts, ) logger.info('Tasks bulk_update finished (sync import)') project.summary.update_data_columns(parsed_data) # TODO: project.summary.update_created_annotations_and_labels else: # Do nothing - just output file upload ids for further use task_count = len(parsed_data) annotation_count = None prediction_count = None duration = time.time() - start response = { 'task_count': task_count, 'annotation_count': annotation_count, 'prediction_count': prediction_count, 'duration': duration, 'file_upload_ids': file_upload_ids, 'could_be_tasks_list': could_be_tasks_list, 'found_formats': found_formats, 'data_columns': data_columns, } if tasks and return_task_ids: response['task_ids'] = [task.id for task in tasks] return Response(response, status=status.HTTP_201_CREATED) @timeit def async_import(self, request, project, preannotated_from_fields, commit_to_project, return_task_ids): project_import = ProjectImport.objects.create( project=project, preannotated_from_fields=preannotated_from_fields, commit_to_project=commit_to_project, return_task_ids=return_task_ids, ) if len(request.FILES): logger.debug(f'Import from files: {request.FILES}') file_upload_ids, could_be_tasks_list = create_file_uploads(request.user, project, request.FILES) project_import.file_upload_ids = file_upload_ids project_import.could_be_tasks_list = could_be_tasks_list project_import.save(update_fields=['file_upload_ids', 'could_be_tasks_list']) elif 'application/x-www-form-urlencoded' in request.content_type: logger.debug(f'Import from url: {request.data.get("url")}') # empty url url = request.data.get('url') if not url: raise ValidationError('"url" is not found in request data') project_import.url = url project_import.save(update_fields=['url']) # take one task from request DATA elif 'application/json' in request.content_type and isinstance(request.data, dict): project_import.tasks = [request.data] project_import.save(update_fields=['tasks']) # take many tasks from request DATA elif 'application/json' in request.content_type and isinstance(request.data, list): project_import.tasks = request.data project_import.save(update_fields=['tasks']) # incorrect data source else: raise ValidationError('load_tasks: No data found in DATA or in FILES') start_job_async_or_sync( async_import_background, project_import.id, request.user.id, queue_name='high', on_failure=set_import_background_failure, project_id=project.id, organization_id=request.user.active_organization.id, ) response = {'import': project_import.id} return Response(response, status=status.HTTP_201_CREATED) def create(self, request, *args, **kwargs): commit_to_project = bool_from_request(request.query_params, 'commit_to_project', True) return_task_ids = bool_from_request(request.query_params, 'return_task_ids', False) preannotated_from_fields = list_of_strings_from_request(request.query_params, 'preannotated_from_fields', None) # check project permissions project = generics.get_object_or_404(Project.objects.for_user(self.request.user), pk=self.kwargs['pk']) if settings.VERSION_EDITION != 'Community': return 
self.async_import(request, project, preannotated_from_fields, commit_to_project, return_task_ids) else: return self.sync_import(request, project, preannotated_from_fields, commit_to_project, return_task_ids) # Import @method_decorator( name='post', decorator=extend_schema( tags=['Import'], summary='Import predictions', description='Import model predictions for tasks in the specified project.', parameters=[ OpenApiParameter( name='id', type=OpenApiTypes.INT, location='path', description='A unique integer value identifying this project.', ), ], request=PredictionSerializer(many=True), responses={ 201: OpenApiResponse( description='Predictions successfully imported', response={ 'title': 'Predictions import response', 'description': 'Import result', 'type': 'object', 'properties': { 'created': { 'title': 'created', 'description': 'Number of predictions created', 'type': 'integer', } }, }, ), 400: OpenApiResponse( description='Bad Request', ), }, extensions={ 'x-fern-sdk-group-name': 'projects', 'x-fern-sdk-method-name': 'import_predictions', 'x-fern-audiences': ['public'], }, ), ) class ImportPredictionsAPI(generics.CreateAPIView): """ API for importing predictions to a project. Memory optimization controlled by feature flag: 'fflag_fix_back_4620_memory_efficient_predictions_import_08012025_short' When flag is enabled: Uses memory-efficient batch processing (reduces memory usage by 90-99%) When flag is disabled: Uses legacy implementation for safe fallback """ permission_required = all_permissions.projects_change parser_classes = (JSONParser, MultiPartParser, FormParser) serializer_class = PredictionSerializer queryset = Project.objects.all() def create(self, request, *args, **kwargs): # check project permissions project = self.get_object() # Use feature flag to control memory-efficient implementation rollout if flag_set('fflag_fix_back_4620_memory_efficient_predictions_import_08012025_short', user=self.request.user): return self._create_memory_efficient(project) else: return self._create_legacy(project) def _create_memory_efficient(self, project): """Memory-efficient batch processing implementation""" # Configure batch processing settings # Use smaller batch size for processing to avoid memory issues PROCESSING_BATCH_SIZE = getattr(settings, 'PREDICTION_IMPORT_BATCH_SIZE', 500) request_data = self.request.data total_predictions = len(request_data) logger.debug( f'Importing {total_predictions} predictions to project {project} using memory-efficient batch processing (batch size: {PROCESSING_BATCH_SIZE})' ) total_created = 0 all_task_ids = set() # Process predictions in smaller batches to avoid memory issues for batch_start in range(0, total_predictions, PROCESSING_BATCH_SIZE): batch_end = min(batch_start + PROCESSING_BATCH_SIZE, total_predictions) batch_items = request_data[batch_start:batch_end] # Extract task IDs for this batch batch_task_ids = [item.get('task') for item in batch_items] # Validate that all task IDs in this batch exist in the project # This is much more memory efficient than loading all project task IDs upfront existing_task_ids = set( Task.objects.filter(project=project, id__in=batch_task_ids).values_list('id', flat=True) ) # Build predictions for this batch batch_predictions = [] for item in batch_items: task_id = item.get('task') if task_id not in existing_task_ids: raise ValidationError( f'{item} contains invalid "task" field: task ID {task_id} ' f'not found in project {project}' ) batch_predictions.append( Prediction( task_id=task_id, project_id=project.id, 
result=Prediction.prepare_prediction_result(item.get('result'), project), score=item.get('score'), model_version=item.get('model_version', 'undefined'), ) ) all_task_ids.add(task_id) # Bulk create this batch with the configured batch size batch_created = Prediction.objects.bulk_create(batch_predictions, batch_size=settings.BATCH_SIZE) total_created += len(batch_created) logger.debug( f'Processed batch {batch_start}-{batch_end-1}: created {len(batch_created)} predictions ' f'(total so far: {total_created})' ) # Update task counters for all affected tasks # Only pass the unique task IDs that were actually processed if all_task_ids: start_job_async_or_sync(update_tasks_counters, Task.objects.filter(id__in=all_task_ids)) return Response({'created': total_created}, status=status.HTTP_201_CREATED) def _create_legacy(self, project): """Legacy implementation - kept for safe rollback""" tasks_ids = set(Task.objects.filter(project=project).values_list('id', flat=True)) logger.debug( f'Importing {len(self.request.data)} predictions to project {project} with {len(tasks_ids)} tasks (legacy mode)' ) li = LabelInterface(project.label_config) # Validate all predictions before creating any validation_errors = [] predictions = [] for i, item in enumerate(self.request.data): # Validate task ID if item.get('task') not in tasks_ids: if flag_set('fflag_feat_utc_210_prediction_validation_15082025', user='auto'): validation_errors.append( f'Prediction {i}: Invalid task ID {item.get("task")} - task not found in project' ) continue else: # Before change we raised only here raise ValidationError( f'{item} contains invalid "task" field: corresponding task ID couldn\'t be retrieved ' f'from project {project} tasks' ) # Validate prediction using LabelInterface only try: validation_errors_list = li.validate_prediction(item, return_errors=True) # If prediction is invalid, add error to validation_errors list and continue to next prediction if validation_errors_list: # Format errors for better readability for error in validation_errors_list: validation_errors.append(f'Prediction {i}: {error}') continue except Exception as e: validation_errors.append(f'Prediction {i}: Error validating prediction - {str(e)}') continue # If prediction is valid, add it to predictions list to be created try: predictions.append( Prediction( task_id=item['task'], project_id=project.id, result=Prediction.prepare_prediction_result(item.get('result'), project), score=item.get('score'), model_version=item.get('model_version', 'undefined'), ) ) except Exception as e: validation_errors.append(f'Prediction {i}: Failed to create prediction - {str(e)}') continue # If there are validation errors, raise them before creating any predictions if validation_errors: if flag_set('fflag_feat_utc_210_prediction_validation_15082025', user='auto'): raise ValidationError(validation_errors) else: logger.error(f'Prediction validation failed ({len(validation_errors)} errors):\n{validation_errors}') predictions_obj = Prediction.objects.bulk_create(predictions, batch_size=settings.BATCH_SIZE) start_job_async_or_sync(update_tasks_counters, Task.objects.filter(id__in=tasks_ids)) return Response({'created': len(predictions_obj)}, status=status.HTTP_201_CREATED) @extend_schema(exclude=True) class TasksBulkCreateAPI(ImportAPI): # just for compatibility - can be safely removed pass class ReImportAPI(ImportAPI): permission_required = all_permissions.projects_change def sync_reimport(self, project, file_upload_ids, files_as_tasks_list): start = time.time() tasks, found_formats, 
data_columns = FileUpload.load_tasks_from_uploaded_files( project, file_upload_ids, files_as_tasks_list=files_as_tasks_list ) with transaction.atomic(): project.remove_tasks_by_file_uploads(file_upload_ids) tasks, serializer = self._save(tasks) duration = time.time() - start task_count = len(tasks) annotation_count = len(serializer.db_annotations) prediction_count = len(serializer.db_predictions) # Update counters (like total_annotations) for new tasks and after bulk update tasks stats. It should be a # single operation as counters affect bulk is_labeled update project.update_tasks_counters_and_task_states( tasks_queryset=tasks, maximum_annotations_changed=False, overlap_cohort_percentage_changed=False, tasks_number_changed=True, recalculate_stats_counts={ 'task_count': task_count, 'annotation_count': annotation_count, 'prediction_count': prediction_count, }, ) logger.info('Tasks bulk_update finished (sync reimport)') project.summary.update_data_columns(tasks) # TODO: project.summary.update_created_annotations_and_labels return Response( { 'task_count': task_count, 'annotation_count': annotation_count, 'prediction_count': prediction_count, 'duration': duration, 'file_upload_ids': file_upload_ids, 'found_formats': found_formats, 'data_columns': data_columns, }, status=status.HTTP_201_CREATED, ) def async_reimport(self, project, file_upload_ids, files_as_tasks_list, organization_id): project_reimport = ProjectReimport.objects.create( project=project, file_upload_ids=file_upload_ids, files_as_tasks_list=files_as_tasks_list ) start_job_async_or_sync( async_reimport_background, project_reimport.id, organization_id, self.request.user, queue_name='high', on_failure=set_reimport_background_failure, project_id=project.id, ) response = {'reimport': project_reimport.id} return Response(response, status=status.HTTP_201_CREATED) @retry_database_locked() def create(self, request, *args, **kwargs): files_as_tasks_list = bool_from_request(request.data, 'files_as_tasks_list', True) file_upload_ids = self.request.data.get('file_upload_ids') # check project permissions project = generics.get_object_or_404(Project.objects.for_user(self.request.user), pk=self.kwargs['pk']) if not file_upload_ids: return Response( { 'task_count': 0, 'annotation_count': 0, 'prediction_count': 0, 'duration': 0, 'file_upload_ids': [], 'found_formats': {}, 'data_columns': [], }, status=status.HTTP_200_OK, ) if settings.VERSION_EDITION != 'Community': return self.async_reimport( project, file_upload_ids, files_as_tasks_list, request.user.active_organization_id ) else: return self.sync_reimport(project, file_upload_ids, files_as_tasks_list) @extend_schema( exclude=True, summary='Re-import tasks', description=""" Re-import tasks using the specified file upload IDs for a specific project. """, ) def post(self, *args, **kwargs): return super(ReImportAPI, self).post(*args, **kwargs) @method_decorator( name='get', decorator=extend_schema( tags=['Import'], summary='Get files list', parameters=[ OpenApiParameter( name='all', type=OpenApiTypes.BOOL, location='query', description='Set to "true" if you want to retrieve all file uploads', ), OpenApiParameter( name='ids', many=True, location='query', description='Specify the list of file upload IDs to retrieve, e.g. ids=[1,2,3]', ), ], description=""" Retrieve the list of uploaded files used to create labeling tasks for a specific project. 
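For example, to fetch only specific uploads with the `ids` parameter (an illustrative sketch using the third-party `requests` package; the host, token, and the `/api/projects/<id>/file-uploads` route are placeholders and assumptions to adjust for your deployment):

```python
import requests

# Placeholder host, token and project ID - substitute your own values
response = requests.get(
    'https://localhost:8080/api/projects/1/file-uploads',
    headers={'Authorization': 'Token abc123'},
    # 'ids' is a JSON-encoded list of file upload IDs;
    # pass all=true instead to return every upload for the project
    params={'ids': '[1,2,3]'},
)
print(response.json())
```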
""", extensions={ 'x-fern-sdk-group-name': ['files'], 'x-fern-sdk-method-name': 'list', 'x-fern-audiences': ['public'], }, ), ) @method_decorator( name='delete', decorator=extend_schema( tags=['Import'], summary='Delete files', description=""" Delete uploaded files for a specific project. """, extensions={ 'x-fern-sdk-group-name': ['files'], 'x-fern-sdk-method-name': 'delete_many', 'x-fern-audiences': ['public'], }, ), ) class FileUploadListAPI(generics.mixins.ListModelMixin, generics.mixins.DestroyModelMixin, generics.GenericAPIView): parser_classes = (JSONParser, MultiPartParser, FormParser) serializer_class = FileUploadSerializer permission_required = ViewClassPermission( GET=all_permissions.projects_view, DELETE=all_permissions.projects_change, ) queryset = FileUpload.objects.all() def get_queryset(self): project = generics.get_object_or_404(Project.objects.for_user(self.request.user), pk=self.kwargs.get('pk', 0)) if project.is_draft or bool_from_request(self.request.query_params, 'all', False): # If project is in draft state, we return all uploaded files, ignoring queried ids logger.debug(f'Return all uploaded files for draft project {project}') return FileUpload.objects.filter(project_id=project.id, user=self.request.user) # If requested in regular import, only queried IDs are returned to avoid showing previously imported ids = json.loads(self.request.query_params.get('ids', '[]')) logger.debug(f'File Upload IDs found: {ids}') return FileUpload.objects.filter(project_id=project.id, id__in=ids, user=self.request.user) def get(self, request, *args, **kwargs): return self.list(request, *args, **kwargs) def delete(self, request, *args, **kwargs): project = generics.get_object_or_404(Project.objects.for_user(self.request.user), pk=self.kwargs['pk']) ids = self.request.data.get('file_upload_ids') if ids is None: deleted, _ = FileUpload.objects.filter(project=project).delete() elif isinstance(ids, list): deleted, _ = FileUpload.objects.filter(project=project, id__in=ids).delete() else: raise ValueError('"file_upload_ids" parameter must be a list of integers') return Response({'deleted': deleted}, status=status.HTTP_200_OK) @method_decorator( name='get', decorator=extend_schema( tags=['Import'], summary='Get file upload', description='Retrieve details about a specific uploaded file.', extensions={ 'x-fern-sdk-group-name': ['files'], 'x-fern-sdk-method-name': 'get', 'x-fern-audiences': ['public'], }, ), ) @method_decorator( name='patch', decorator=extend_schema( tags=['Import'], summary='Update file upload', description='Update a specific uploaded file.', request=FileUploadSerializer, extensions={ 'x-fern-sdk-group-name': ['files'], 'x-fern-sdk-method-name': 'update', 'x-fern-audiences': ['public'], }, ), ) @method_decorator( name='delete', decorator=extend_schema( tags=['Import'], summary='Delete file upload', description='Delete a specific uploaded file.', extensions={ 'x-fern-sdk-group-name': ['files'], 'x-fern-sdk-method-name': 'delete', 'x-fern-audiences': ['public'], }, ), ) class FileUploadAPI(generics.RetrieveUpdateDestroyAPIView): parser_classes = (JSONParser, MultiPartParser, FormParser) permission_classes = (IsAuthenticated,) serializer_class = FileUploadSerializer queryset = FileUpload.objects.all() def get(self, *args, **kwargs): return super(FileUploadAPI, self).get(*args, **kwargs) def patch(self, *args, **kwargs): return super(FileUploadAPI, self).patch(*args, **kwargs) def delete(self, *args, **kwargs): return super(FileUploadAPI, self).delete(*args, **kwargs) 
@extend_schema(exclude=True) def put(self, *args, **kwargs): return super(FileUploadAPI, self).put(*args, **kwargs) @method_decorator( name='get', decorator=extend_schema( tags=['Import'], summary='Download file', description='Download a specific uploaded file.', extensions={ 'x-fern-sdk-group-name': ['files'], 'x-fern-sdk-method-name': 'download', 'x-fern-audiences': ['public'], }, responses={ 200: OpenApiResponse(description='File downloaded successfully'), }, ), ) class UploadedFileResponse(generics.RetrieveAPIView): """Serve uploaded files from the local drive""" permission_classes = (IsAuthenticated,) @override_report_only_csp @csp(SANDBOX=[]) def get(self, *args, **kwargs): request = self.request filename = kwargs['filename'] # XXX plain string concatenation is needed here: on Windows, os.path.join generates '\' separators, which break FileUpload lookups file = settings.UPLOAD_DIR + ('/' if not settings.UPLOAD_DIR.endswith('/') else '') + filename logger.debug(f'Fetch uploaded file by user {request.user} => {file}') file_upload = FileUpload.objects.filter(file=file).last() if file_upload is None or not file_upload.has_permission(request.user): return Response(status=status.HTTP_403_FORBIDDEN) file = file_upload.file if file.storage.exists(file.name): content_type, encoding = mimetypes.guess_type(str(file.name)) content_type = content_type or 'application/octet-stream' return RangedFileResponse(request, file.open(mode='rb'), content_type=content_type) return Response(status=status.HTTP_404_NOT_FOUND) @extend_schema(exclude=True) class DownloadStorageData(APIView): """ Secure file download API for persistent storage (S3, GCS, Azure, etc.) This view provides authenticated access to uploaded files and user avatars stored in cloud storage or local filesystems. It supports two operational modes with different trade-offs between performance and simplicity. ## Operation Modes: ### 1. NGINX Mode (Default - USE_NGINX_FOR_UPLOADS=True) - **High Performance**: Uses X-Accel-Redirect headers for efficient file serving - **How it works**: 1. Validates user permissions and file access 2. Returns HttpResponse with X-Accel-Redirect header pointing to storage URL 3. NGINX intercepts and serves the file directly from storage - **Benefits**: Reduces Django server load, better performance for large files ### 2. Direct Mode (USE_NGINX_FOR_UPLOADS=False) - **Direct Serving**: Django serves files using RangedFileResponse - **How it works**: 1. Validates user permissions and file access 2. Opens file from storage and streams it with range request support 3. 
Supports partial content requests (HTTP 206) - **Benefits**: Works without NGINX, supports range requests for media files ## Content-Disposition Logic: - **Inline**: PDFs, audio, and video files, because media files are displayed directly in the browser """ http_method_names = ['get'] permission_classes = (IsAuthenticated,) def get(self, request, *args, **kwargs): """Serve an uploaded file or user avatar identified by its storage path""" filepath = request.GET.get('filepath') if filepath is None: return Response(status=status.HTTP_404_NOT_FOUND) filepath = unquote(request.GET['filepath']) file_obj = None if filepath.startswith(settings.UPLOAD_DIR): logger.debug(f'Fetch uploaded file by user {request.user} => {filepath}') file_upload = FileUpload.objects.filter(file=filepath).last() if file_upload is not None and file_upload.has_permission(request.user): file_obj = file_upload.file elif filepath.startswith(settings.AVATAR_PATH): user = User.objects.filter(avatar=filepath).first() if user is not None and request.user.active_organization.has_user(user): file_obj = user.avatar if file_obj is None: return Response(status=status.HTTP_403_FORBIDDEN) # NGINX handling is the default for better performance if settings.USE_NGINX_FOR_UPLOADS: url = file_obj.storage.url(file_obj.name, storage_url=True) protocol = urlparse(url).scheme response = HttpResponse() # This header tells NGINX to intercept the request and serve the file itself, see deploy/default.conf redirect = '/file_download/' + protocol + '/' + url.replace(protocol + '://', '') response['X-Accel-Redirect'] = redirect response['Content-Disposition'] = f'inline; filename="{filepath}"' return response # No NGINX: standard way for direct file serving else: content_type, _ = mimetypes.guess_type(filepath) content_type = content_type or 'application/octet-stream' response = RangedFileResponse(request, file_obj.open(mode='rb'), content_type=content_type) response['Content-Disposition'] = f'inline; filename="{filepath}"' response['filename'] = filepath return response
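# Illustration of the X-Accel-Redirect path built in DownloadStorageData.get above.
# The storage URL below is hypothetical; the construction mirrors the view's own logic:
#
#   from urllib.parse import urlparse
#
#   url = 'https://my-bucket.s3.amazonaws.com/upload/1/file.png'
#   protocol = urlparse(url).scheme  # 'https'
#   redirect = '/file_download/' + protocol + '/' + url.replace(protocol + '://', '')
#   assert redirect == '/file_download/https/my-bucket.s3.amazonaws.com/upload/1/file.png'
#
# NGINX matches the /file_download/ prefix (see deploy/default.conf) and proxies the
# request to the storage backend instead of streaming the file through Django.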