"""This file and its contents are licensed under the Apache License 2.0. Please see the included NOTICE for copyright information and LICENSE for a copy of the license.
"""
import logging
from collections import Counter

from core.permissions import AllPermissions
from core.redis import start_job_async_or_sync
from data_manager.actions import DataManagerAction
from label_studio_sdk.label_interface import LabelInterface
from tasks.models import Annotation, Prediction, Task

logger = logging.getLogger(__name__)
all_permissions = AllPermissions()


def cache_labels_job(project, queryset, **kwargs):
    """Write the labels found in each task's annotations or predictions into a ``task.data`` column.

    The target column is ``cache_<tag>`` for annotations and ``cache_predictions_<tag>``
    for predictions, where ``<tag>`` is the chosen control tag name or ``all``.

    Args:
        project: project whose ``label_config`` is parsed and whose summary is refreshed.
        queryset: queryset of the tasks to process.
        **kwargs: must contain ``request_data``, a mapping with the optional keys
            ``source`` ('annotations' or 'predictions', case-insensitive, default 'annotations'),
            ``custom_control_tag`` / ``control_tag`` (the custom one wins; 'ALL' or a missing
            value means "collect from every control tag") and
            ``with_counters`` ('Yes' by default; anything else disables counters).

    Returns:
        dict with ``response_code`` 200 and a ``detail`` message with the number of updated tasks.

    Raises:
        KeyError: if ``request_data`` is not passed.
        ValueError: if ``source`` is neither 'annotations' nor 'predictions'.
    """
    request_data = kwargs['request_data']
    source = request_data.get('source', 'annotations').lower()
    # explicit check instead of `assert`, which disappears when Python runs with -O
    if source not in ('annotations', 'predictions'):
        raise ValueError('Source must be annotations or predictions')
    source_class = Annotation if source == 'annotations' else Prediction
    control_tag = request_data.get('custom_control_tag') or request_data.get('control_tag')
    with_counters = request_data.get('with_counters', 'Yes').lower() == 'yes'
    label_interface = LabelInterface(project.label_config)
    label_interface_tags = {tag.name: tag for tag in label_interface.find_tags('control')}

    column_name = 'cache' if source == 'annotations' else 'cache_predictions'

    # ALL is a special case, we will cache all labels from all control tags into one column
    if control_tag == 'ALL' or control_tag is None:
        control_tag = None
        column_name = f'{column_name}_all'
    else:
        column_name = f'{column_name}_{control_tag}'

    tasks = list(queryset.only('data'))
    logger.info(f'Cache labels for {len(tasks)} tasks and control tag {control_tag}')
    if not tasks:
        # nothing to update; also avoids calling `.id` on the None returned by queryset.first()
        return {'response_code': 200, 'detail': f'Updated {len(tasks)} tasks'}

    for task in tasks:
        # one pass over the labels instead of list.count() per distinct label
        label_counts = Counter()
        for annotation in source_class.objects.filter(task=task).only('result'):
            label_counts.update(extract_labels(annotation, control_tag, label_interface_tags))

        # cache labels in separate data column
        if with_counters:
            cells = [f'{label}: {count}' for label, count in label_counts.items()]
        else:
            cells = list(label_counts)
        task.data[column_name] = ', '.join(sorted(cells))

    Task.objects.bulk_update(tasks, fields=['data'], batch_size=1000)

    # `tasks` were loaded with only('data'), so one full task is re-read for the summary update
    # NOTE(review): presumably this registers the new data column in the project summary - confirm
    first_task = Task.objects.get(id=queryset.first().id)
    project.summary.update_data_columns([first_task])
    return {'response_code': 200, 'detail': f'Updated {len(tasks)} tasks'}


def extract_labels(annotation, control_tag, label_interface_tags=None):
    """Collect label strings from the result items of one annotation or prediction.

    Args:
        annotation: object with a ``result`` attribute holding a list of region dicts
            (``None`` or an empty list yields no labels).
        control_tag: only regions whose ``from_name`` equals this value are used;
            ``None`` means every region is used.
        label_interface_tags: optional mapping of control tag name to a tag object with an
            ``attr`` mapping; used to read the taxonomy options ``showFullPath`` and ``pathSeparator``.

    Returns:
        list of labels in encounter order, duplicates included.
    """
    labels = []
    for region in annotation.result or []:
        if not isinstance(region, dict):
            continue
        # items such as relations carry no 'from_name'; they must not raise KeyError
        from_name = region.get('from_name')
        if control_tag is not None and from_name != control_tag:
            continue
        value = region.get('value')
        if not isinstance(value, dict):
            continue

        # scan value for a field with list of strings (eg choices, textareas)
        # or taxonomy (list of string-lists)
        for key, items in value.items():
            if not (items and isinstance(items, list)):
                continue

            if key == 'taxonomy':
                # defaults for a tag that is not in the label config (custom control tag):
                # full path joined by '/'
                show_full_path = 'true'
                path_separator = '/'
                if label_interface_tags is not None and from_name in label_interface_tags:
                    # the tag is known, so take the taxonomy formatting params from the label config;
                    # here a missing showFullPath falls back to 'false'
                    tag = label_interface_tags[from_name]
                    show_full_path = tag.attr.get('showFullPath', 'false')
                    path_separator = tag.attr.get('pathSeparator', '/')

                if show_full_path == 'false':
                    # just the leaf node of each taxonomy selection
                    labels.extend(path[-1] for path in items)
                else:
                    # the full delimited taxonomy path
                    labels.extend(path_separator.join(path) for path in items)

            # other control tag types like Choices & TextAreas
            elif isinstance(items[0], str):
                labels.extend(items)

            # NOTE(review): scanning stops at the first non-empty list value, even when it holds
            # no strings (e.g. a list of points) - confirm this is intended
            break
    return labels


def cache_labels(project, queryset, request, **kwargs):
    """Cache labels from annotations to a new column in tasks

    Hands ``cache_labels_job`` to ``start_job_async_or_sync`` together with ``request.data``
    and returns ``{'response_code': 200}`` without inspecting the job's result.
    """
    start_job_async_or_sync(
        cache_labels_job,
        project,
        queryset,
        organization_id=project.organization_id,
        request_data=request.data,
        job_timeout=60 * 60 * 5,  # max allowed duration is 5 hours
    )
    return {'response_code': 200}


def cache_labels_form(user, project):
    """Describe the dialog form of the "Cache Labels" action.

    Args:
        user: not used by this function.
        project: project whose parsed label config supplies the control tag names.

    Returns:
        list with a single one-column form definition: control tag select ('ALL' first),
        custom control tag input, counters switch and source select.
    """
    parsed_config = project.get_parsed_config()
    control_tags = ['ALL', *parsed_config]

    return [
        {
            'columnCount': 1,
            'fields': [
                {
                    'type': 'select',
                    'name': 'control_tag',
                    'label': 'Choose a control tag',
                    'options': control_tags,
                },
                {
                    'type': 'input',
                    'name': 'custom_control_tag',
                    'label': "Custom control tag if it's not in label config",
                },
                {
                    'type': 'select',
                    'name': 'with_counters',
                    'label': 'With counters',
                    'options': ['Yes', 'No'],
                },
                {
                    'type': 'select',
                    'name': 'source',
                    'label': 'Source',
                    'options': ['Annotations', 'Predictions'],
                },
            ],
        }
    ]


# Data Manager action entries defined by this module
actions: list[DataManagerAction] = [
    {
        'entry_point': cache_labels,
        'permission': all_permissions.projects_change,
        'title': 'Cache Labels',
        'order': 1,
        'experimental': True,
        'dialog': {
            'text': 'Confirm that you want to add a new task.data field with cached labels from annotations. '
            'This field will help you to quickly filter or order tasks by labels. '
            'After this operation you must refresh the Data Manager page fully to see the new column!',
            'type': 'confirm',
            'form': cache_labels_form,
        },
    },
]