import hashlib
import io
import json
import logging
import pathlib
import shutil
from datetime import datetime
from functools import reduce

import django_rq
from core.feature_flags import flag_set
from core.redis import redis_connected
from core.utils.common import batch
from core.utils.io import (
    SerializableGenerator,
    get_all_dirs_from_dir,
    get_all_files_from_dir,
    get_temp_dir,
)
from data_manager.models import View
from django.conf import settings
from django.core.files import File
from django.core.files import temp as tempfile
from django.db import transaction
from django.db.models import Prefetch
from django.db.models.query_utils import Q
from django.utils import dateformat, timezone
from label_studio_sdk.converter import Converter
from tasks.models import Annotation, AnnotationDraft, Task

ONLY = 'only'
EXCLUDE = 'exclude'


logger = logging.getLogger(__name__)


class ExportMixin:
    def has_permission(self, user):
        user.project = self.project  # link for activity log
        return self.project.has_permission(user)

    def get_default_title(self):
        return f"{self.project.title.replace(' ', '-')}-at-{dateformat.format(timezone.now(), 'Y-m-d-H-i')}"

    def _get_filtered_tasks(self, tasks, task_filter_options=None):
        """
        task_filter_options: None or Dict({
            view: optional int view id
            skipped: optional None or str:("only|exclude")
            finished: optional None or str:("only|exclude")
            annotated: optional None or str:("only|exclude")
        })
        """
        if not isinstance(task_filter_options, dict):
            return tasks
        if 'view' in task_filter_options:
            try:
                value = int(task_filter_options['view'])
                prepare_params = View.objects.get(project=self.project, id=value).get_prepare_tasks_params(
                    add_selected_items=True
                )
                tab_tasks = Task.prepared.only_filtered(prepare_params=prepare_params).values_list('id', flat=True)
                tasks = tasks.filter(id__in=tab_tasks)
            except (ValueError, View.DoesNotExist) as exc:
                logger.warning(f'Incorrect view params {exc}')
        if 'skipped' in task_filter_options:
            value = task_filter_options['skipped']
            if value == ONLY:
                tasks = tasks.filter(annotations__was_cancelled=True)
            elif value == EXCLUDE:
                tasks = tasks.exclude(annotations__was_cancelled=True)
        if 'finished' in task_filter_options:
            value = task_filter_options['finished']
            if value == ONLY:
                tasks = tasks.filter(is_labeled=True)
            elif value == EXCLUDE:
                tasks = tasks.exclude(is_labeled=True)
        if 'annotated' in task_filter_options:
            value = task_filter_options['annotated']
            # if any annotation exists and is not cancelled
            if value == ONLY:
                tasks = tasks.filter(annotations__was_cancelled=False)
            elif value == EXCLUDE:
                tasks = tasks.exclude(annotations__was_cancelled=False)

        return tasks

    def _get_filtered_annotations_queryset(self, annotation_filter_options=None):
        """
        Filter annotations using a disjunction (OR) of the enabled conditions

        annotation_filter_options: None or Dict({
            usual: optional None or bool
            ground_truth: optional None or bool
            skipped: optional None or bool
        })
        """
        queryset = Annotation.objects.all()
        if isinstance(annotation_filter_options, dict):
            q_list = []
            if annotation_filter_options.get('usual'):
                q_list.append(Q(was_cancelled=False, ground_truth=False))
            if annotation_filter_options.get('ground_truth'):
                q_list.append(Q(ground_truth=True))
            if annotation_filter_options.get('skipped'):
                q_list.append(Q(was_cancelled=True))
            if q_list:
                q = reduce(lambda x, y: x | y, q_list)
                queryset = queryset.filter(q)

        # pre-select completed_by user info
        queryset = queryset.select_related('completed_by')
        # prefetch reviews in LSE
        if hasattr(queryset.model, 'reviews'):
            from reviews.models import AnnotationReview

            queryset = queryset.prefetch_related(
                Prefetch('reviews', queryset=AnnotationReview.objects.select_related('created_by'))
            )

        return queryset

    @staticmethod
    def _get_export_serializer_option(serialization_options):
        options = {'expand': []}
        if isinstance(serialization_options, dict):
            if (
                'drafts' in serialization_options
                and isinstance(serialization_options['drafts'], dict)
                and not serialization_options['drafts'].get('only_id')
            ):
                options['expand'].append('drafts')
            if (
                'predictions' in serialization_options
                and isinstance(serialization_options['predictions'], dict)
                and not serialization_options['predictions'].get('only_id')
            ):
                options['expand'].append('predictions')
            if 'annotations__completed_by' in serialization_options and not serialization_options[
                'annotations__completed_by'
            ].get('only_id'):
                options['expand'].append('annotations.completed_by')
            options['context'] = {'interpolate_key_frames': settings.INTERPOLATE_KEY_FRAMES}
            if 'interpolate_key_frames' in serialization_options:
                options['context']['interpolate_key_frames'] = serialization_options['interpolate_key_frames']
            if serialization_options.get('include_annotation_history') is False:
                options['omit'] = ['annotations.history']
            # download resources
            if serialization_options.get('download_resources') is True:
                options['download_resources'] = True
        return options
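
    # For instance (illustrative input): serialization_options =
    # {'drafts': {'only_id': False}, 'include_annotation_history': False}
    # produces {'expand': ['drafts'], 'context': {'interpolate_key_frames': ...},
    # 'omit': ['annotations.history']}.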

    def get_task_queryset(self, ids, annotation_filter_options):
        annotations_qs = self._get_filtered_annotations_queryset(annotation_filter_options=annotation_filter_options)

        # Only annotate FSM state if both feature flags are enabled
        # This prevents unnecessary query annotations when state won't be serialized
        user = getattr(self, 'created_by', None)
        if (
            flag_set('fflag_feat_fit_568_finite_state_management', user=user)
            and flag_set('fflag_feat_fit_710_fsm_state_fields', user=user)
            and hasattr(annotations_qs, 'with_state')
        ):
            annotations_qs = annotations_qs.with_state()

        qs = (
            Task.objects.filter(id__in=ids)
            .select_related('file_upload')  # select_related is more efficient for a regular foreign-key relationship
            .prefetch_related(
                Prefetch('annotations', queryset=annotations_qs),
                Prefetch('drafts', queryset=AnnotationDraft.objects.select_related('user')),
                'comment_authors',
            )
        )

        # Add FSM state annotation to tasks as well to avoid N+1 queries during export
        if (
            flag_set('fflag_feat_fit_568_finite_state_management', user=user)
            and flag_set('fflag_feat_fit_710_fsm_state_fields', user=user)
            and hasattr(qs, 'with_state')
        ):
            qs = qs.with_state()

        return qs

    def get_export_data(self, task_filter_options=None, annotation_filter_options=None, serialization_options=None):
        """
        serialization_options: None or Dict({
            drafts: optional None or Dict({only_id: true/false})
            predictions: optional None or Dict({only_id: true/false})
            annotations__completed_by: optional None or Dict({only_id: true/false})
        })
        """
        from .serializers import ExportDataSerializer

        logger.debug('Run get_task_queryset')

        start = datetime.now()
        with transaction.atomic():
            # TODO: make counters from queryset
            # counters = Project.objects.with_counts().filter(id=self.project.id)[0].get_counters()
            self.counters = {'task_number': 0}
            all_tasks = self.project.tasks
            logger.debug('Tasks filtration')
            task_ids = list(
                self._get_filtered_tasks(all_tasks, task_filter_options=task_filter_options)
                .distinct()
                .values_list('id', flat=True)
            )
            base_export_serializer_option = self._get_export_serializer_option(serialization_options)
            i = 0

            if flag_set('fflag_fix_back_plt_807_batch_size_26062025_short', self.project.organization.created_by):
                BATCH_SIZE = self.project.get_task_batch_size()
            else:
                BATCH_SIZE = settings.BATCH_SIZE

            for ids in batch(task_ids, BATCH_SIZE):
                i += 1
                tasks = list(self.get_task_queryset(ids, annotation_filter_options))
                logger.debug(f'Batch: {i*BATCH_SIZE}')
                if isinstance(task_filter_options, dict) and task_filter_options.get('only_with_annotations'):
                    tasks = [task for task in tasks if task.annotations.exists()]

                if serialization_options and serialization_options.get('include_annotation_history') is True:
                    # use a separate name to avoid shadowing the outer task_ids consumed by batch()
                    batch_task_ids = [task.id for task in tasks]
                    annotation_ids = Annotation.objects.filter(task_id__in=batch_task_ids).values_list('id', flat=True)
                    base_export_serializer_option = self.update_export_serializer_option(
                        base_export_serializer_option, annotation_ids
                    )

                serializer = ExportDataSerializer(tasks, many=True, **base_export_serializer_option)
                self.counters['task_number'] += len(tasks)
                for task in serializer.data:
                    yield task
        duration = datetime.now() - start
        logger.info(
            f'{self.counters["task_number"]} tasks from project {self.project_id} exported in {duration.total_seconds():.2f} seconds'
        )

    def update_export_serializer_option(self, base_export_serializer_option, annotation_ids):
        # hook for subclasses to adjust serializer options per batch; the base implementation is a no-op
        return base_export_serializer_option

    @staticmethod
    def eval_md5(file):
        md5_object = hashlib.md5()  # nosec
        block_size = 128 * md5_object.block_size
        chunk = file.read(block_size)
        while chunk:
            md5_object.update(chunk)
            chunk = file.read(block_size)
        md5 = md5_object.hexdigest()
        return md5
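
    # Illustrative usage: the digest is computed over an open binary file, e.g.
    #   with open('export.json', 'rb') as f:  # hypothetical path
    #       checksum = ExportMixin.eval_md5(f)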

    def save_file(self, file, md5):
        now = datetime.now()
        file_name = f'project-{self.project.id}-at-{now.strftime("%Y-%m-%d-%H-%M")}-{md5[0:8]}.json'
        file_path = f'{self.project.id}/{file_name}'  # the file ends up under settings.DELAYED_EXPORT_DIR/<project.id>/<file_name>
        file_ = File(file, name=file_path)
        self.file.save(file_path, file_)
        self.md5 = md5
        self.save(update_fields=['file', 'md5', 'counters'])

    def export_to_file(self, task_filter_options=None, annotation_filter_options=None, serialization_options=None):
        logger.debug(
            f'Run export for {self.id} with params:\n'
            f'task_filter_options: {task_filter_options}\n'
            f'annotation_filter_options: {annotation_filter_options}\n'
            f'serialization_options: {serialization_options}\n'
        )
        try:
            iter_json = json.JSONEncoder(ensure_ascii=False).iterencode(
                SerializableGenerator(
                    self.get_export_data(
                        task_filter_options=task_filter_options,
                        annotation_filter_options=annotation_filter_options,
                        serialization_options=serialization_options,
                    )
                )
            )
            with tempfile.NamedTemporaryFile(suffix='.export.json', dir=settings.FILE_UPLOAD_TEMP_DIR) as file:
                for chunk in iter_json:
                    encoded_chunk = chunk.encode('utf-8')
                    file.write(encoded_chunk)
                file.seek(0)

                md5 = self.eval_md5(file)
                self.save_file(file, md5)

            self.status = self.Status.COMPLETED
            self.save(update_fields=['status'])

        except Exception as e:
            self.status = self.Status.FAILED
            self.save(update_fields=['status'])
            logger.exception('Export failed: %s', e)
        finally:
            self.finished_at = datetime.now()
            self.save(update_fields=['finished_at'])

    def run_file_exporting(self, task_filter_options=None, annotation_filter_options=None, serialization_options=None):
        if self.status == self.Status.IN_PROGRESS:
            logger.warning('Export is already in progress, skipping a new run')
            return

        self.status = self.Status.IN_PROGRESS
        self.save(update_fields=['status'])

        if redis_connected():
            queue = django_rq.get_queue('default')
            queue.enqueue(
                export_background,
                self.id,
                task_filter_options,
                annotation_filter_options,
                serialization_options,
                on_failure=set_export_background_failure,
                job_timeout='3h',  # 3 hours
            )
        else:
            self.export_to_file(
                task_filter_options=task_filter_options,
                annotation_filter_options=annotation_filter_options,
                serialization_options=serialization_options,
            )
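
    # Illustrative usage (hypothetical `export` instance): the run is queued in
    # RQ when Redis is available, so callers poll `export.status`, which the
    # methods above move from IN_PROGRESS to COMPLETED or FAILED:
    #   export.run_file_exporting(task_filter_options={'finished': 'only'})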

    def convert_file(self, to_format, download_resources=False, hostname=None):
        with get_temp_dir() as tmp_dir:
            OUT = 'out'
            out_dir = pathlib.Path(tmp_dir) / OUT
            out_dir.mkdir(mode=0o700, parents=True, exist_ok=True)

            converter = Converter(
                config=self.project.get_parsed_config(),
                project_dir=None,
                upload_dir=out_dir,
                download_resources=download_resources,
                # for downloading resources we need access to the API
                access_token=self.project.organization.created_by.auth_token.key,
                hostname=hostname,
            )
            input_name = pathlib.Path(self.file.name).name
            input_file_path = pathlib.Path(tmp_dir) / input_name

            with open(input_file_path, 'wb') as file_:
                file_.write(self.file.open().read())

            converter.convert(input_file_path, out_dir, to_format, is_dir=False)

            files = get_all_files_from_dir(out_dir)
            dirs = get_all_dirs_from_dir(out_dir)

            if len(files) == 0 and len(dirs) == 0:
                return None
            elif len(files) == 1 and len(dirs) == 0:
                output_file = files[0]
                filename = pathlib.Path(input_name).stem + pathlib.Path(output_file).suffix
            else:
                # multiple outputs: pack the whole out_dir into a zip archive next to it
                shutil.make_archive(out_dir, 'zip', out_dir)
                output_file = pathlib.Path(tmp_dir) / (str(out_dir.stem) + '.zip')
                filename = pathlib.Path(input_name).stem + '.zip'

            # TODO(jo): can we avoid the `f.read()` here?
            with open(output_file, mode='rb') as f:
                return File(
                    io.BytesIO(f.read()),
                    name=filename,
                )
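
    # Illustrative usage (hypothetical values): convert a finished export to
    # another format supported by label_studio_sdk's Converter for this
    # labeling config, e.g.
    #   converted = export.convert_file('CSV', hostname='https://app.example.com')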


def export_background(
    export_id, task_filter_options, annotation_filter_options, serialization_options, *args, **kwargs
):
    from data_export.models import Export

    Export.objects.get(id=export_id).export_to_file(
        task_filter_options,
        annotation_filter_options,
        serialization_options,
    )


def set_export_background_failure(job, connection, type, value, traceback):
    from data_export.models import Export

    export_id = job.args[0]
    Export.objects.filter(id=export_id).update(status=Export.Status.FAILED)
|