"""This file and its contents are licensed under the Apache License 2.0. Please see the included NOTICE for copyright information and LICENSE for a copy of the license. """ from __future__ import unicode_literals import calendar import contextlib import copy import importlib import logging import os import random import re import time import traceback as tb import uuid from collections import defaultdict from copy import deepcopy from functools import wraps from typing import Any, Callable, Generator, Iterable, Mapping, Optional import pytz import requests import ujson as json from colorama import Fore from core.utils.params import get_env from django.conf import settings from django.contrib.postgres.operations import BtreeGinExtension, TrigramExtension from django.core.exceptions import ValidationError from django.core.paginator import EmptyPage, Paginator from django.core.validators import URLValidator from django.db import models, transaction from django.db.models.signals import ( post_delete, post_init, post_migrate, post_save, pre_delete, pre_init, pre_migrate, pre_save, ) from django.db.utils import OperationalError from django.utils import timezone from django.utils.crypto import get_random_string from django.utils.module_loading import import_string from drf_spectacular.types import OpenApiTypes from drf_spectacular.utils import OpenApiParameter, OpenApiResponse from label_studio_sdk._extensions.label_studio_tools.core.utils.exceptions import ( LabelStudioXMLSyntaxErrorSentryIgnored, ) from packaging.version import parse as parse_version from pyboxen import boxen from rest_framework import status from rest_framework.exceptions import APIException, ErrorDetail from rest_framework.views import Response, exception_handler import label_studio try: from sentry_sdk import capture_exception, set_tag sentry_sdk_loaded = True except (ModuleNotFoundError, ImportError): sentry_sdk_loaded = False from core import version from core.utils.exceptions import 
LabelStudioDatabaseLockedException # these functions will be included to another modules, don't remove them from core.utils.params import int_from_request logger = logging.getLogger(__name__) url_validator = URLValidator() def _override_exceptions(exc): if isinstance(exc, OperationalError) and 'database is locked' in str(exc): return LabelStudioDatabaseLockedException() return exc def custom_exception_handler(exc, context): """Make custom exception treatment in RestFramework :param exc: Exception - you can check specific exception :param context: context :return: response with error desc """ exception_id = uuid.uuid4() sentry_skip = False if isinstance(exc, APIException) and exc.status_code < 500: # Skipping Sentry for non-500 unhandled exceptions sentry_skip = True logger.error( '{} {}'.format(exception_id, exc), exc_info=True, extra={'sentry_skip': sentry_skip, 'exception_id': exception_id}, ) exc = _override_exceptions(exc) # error body structure response_data = { 'id': exception_id, 'status_code': status.HTTP_500_INTERNAL_SERVER_ERROR, # default value 'version': label_studio.__version__, 'detail': 'Unknown error', # default value 'exc_info': None, } if hasattr(exc, 'display_context'): response_data['display_context'] = deepcopy(exc.display_context) # try rest framework handler response = exception_handler(exc, context) if response is not None: response_data['status_code'] = response.status_code if 'detail' in response.data and isinstance(response.data['detail'], ErrorDetail): response_data['detail'] = response.data['detail'] response.data = response_data # move validation errors to separate namespace else: response_data['detail'] = 'Validation error' response_data['validation_errors'] = ( response.data if isinstance(response.data, dict) else {'non_field_errors': response.data} ) response.data = response_data # non-standard exception else: if sentry_sdk_loaded: # pass exception to sentry set_tag('exception_id', exception_id) capture_exception(exc) exc_tb = 
tb.format_exc() logger.debug(exc_tb) response_data['detail'] = str(exc) if not settings.DEBUG_MODAL_EXCEPTIONS: exc_tb = None response_data['exc_info'] = exc_tb # Thrown by sdk when label config is invalid if isinstance(exc, LabelStudioXMLSyntaxErrorSentryIgnored): response_data['status_code'] = status.HTTP_400_BAD_REQUEST response = Response(status=status.HTTP_400_BAD_REQUEST, data=response_data) else: response = Response(status=status.HTTP_500_INTERNAL_SERVER_ERROR, data=response_data) return response def create_hash() -> str: """This function creates a secure token for the organization""" return get_random_string(length=40) def paginator(objects, request, default_page=1, default_size=50): """DEPRECATED TODO: change to standard drf pagination class Get from request page and page_size and return paginated objects :param objects: all queryset :param request: view request object :param default_page: start page if there is no page in GET :param default_size: page size if there is no page in GET :return: paginated objects """ page_size = request.GET.get('page_size', request.GET.get('length', default_size)) if settings.TASK_API_PAGE_SIZE_MAX and (int(page_size) > settings.TASK_API_PAGE_SIZE_MAX or page_size == '-1'): page_size = settings.TASK_API_PAGE_SIZE_MAX if 'start' in request.GET: page = int_from_request(request.GET, 'start', default_page) if page and int(page) > int(page_size) > 0: page = int(page / int(page_size)) + 1 else: page += 1 else: page = int_from_request(request.GET, 'page', default_page) if page_size == '-1': return objects try: return Paginator(objects, page_size).page(page).object_list except ZeroDivisionError: return [] except EmptyPage: return [] def paginator_help(objects_name, tag): """API help for paginator, use it with drf_spectacular :return: dict """ if settings.TASK_API_PAGE_SIZE_MAX: page_size_description = f'[or "length"] {objects_name} per page. 
def string_is_url(url):
    """Return True if ``url`` passes Django's URL validation, False otherwise."""
    try:
        url_validator(url)
    except ValidationError:
        return False
    else:
        return True


def safe_float(v, default=0):
    """Return ``v`` unless it is NaN (the only value where v != v), then ``default``."""
    if v != v:
        return default
    return v


def sample_query(q, sample_size):
    """Return a queryset of ``sample_size`` rows sampled at random from ``q``.

    :raises ValueError: if the queryset is empty
    """
    n = q.count()
    if n == 0:
        raise ValueError("Can't sample from empty query")
    ids = q.values_list('id', flat=True)
    random_ids = random.sample(list(ids), sample_size)
    return q.filter(id__in=random_ids)


def get_client_ip(request):
    """Get IP address from django request

    :param request: django request
    :return: str with ip
    """
    x_forwarded_for = request.META.get('HTTP_X_FORWARDED_FOR')
    if x_forwarded_for:
        # first entry is the originating client
        ip = x_forwarded_for.split(',')[0]
    else:
        ip = request.META.get('REMOTE_ADDR')
    return ip


def get_attr_or_item(obj, key):
    """Fetch ``key`` as an attribute of ``obj``, falling back to a dict item.

    :raises KeyError: when neither lookup succeeds
    """
    if hasattr(obj, key):
        return getattr(obj, key)
    elif isinstance(obj, dict) and key in obj:
        return obj[key]
    else:
        raise KeyError(f"Can't get attribute or dict key '{key}' from {obj}")


def datetime_to_timestamp(dt):
    """Convert a datetime to a POSIX timestamp (seconds); aware datetimes go via UTC."""
    if dt.tzinfo:
        dt = dt.astimezone(pytz.UTC)
    return calendar.timegm(dt.timetuple())


def timestamp_now():
    """Current time as an integer POSIX timestamp."""
    return datetime_to_timestamp(timezone.now())


def find_first_one_to_one_related_field_by_prefix(instance, prefix):
    """Return the first reverse one-to-one related object whose accessor name
    matches the ``prefix`` regex, caching the result on the instance.
    """
    if hasattr(instance, '_find_first_one_to_one_related_field_by_prefix_cache'):
        return getattr(instance, '_find_first_one_to_one_related_field_by_prefix_cache')

    result = None
    for field in instance._meta.get_fields():
        if issubclass(type(field), models.fields.related.OneToOneRel):
            attr_name = field.get_accessor_name()
            if re.match(prefix, attr_name) and hasattr(instance, attr_name):
                result = getattr(instance, attr_name)
                break

    instance._find_first_one_to_one_related_field_by_prefix_cache = result
    return result


def start_browser(ls_url, no_browser):
    """Open ``ls_url`` in the default browser after a short delay (no-op if ``no_browser``)."""
    import threading
    import webbrowser

    if no_browser:
        return

    browser_url = ls_url
    # delay so the server has time to start listening before the browser connects
    threading.Timer(2.5, lambda: webbrowser.open(browser_url)).start()
    logger.info('Start browser at URL: ' + browser_url)


def db_is_not_sqlite() -> bool:
    """
    A common predicate for use with conditional_atomic.

    Checks if the DB is NOT sqlite, because sqlite dbs are locked during any write.
    """
    return settings.DJANGO_DB != settings.DJANGO_DB_SQLITE


@contextlib.contextmanager
def conditional_atomic(
    predicate: Callable[..., bool],
    predicate_args: Optional[Iterable[Any]] = None,
    predicate_kwargs: Optional[Mapping[str, Any]] = None,
) -> Generator[None, None, None]:
    """Use transaction if and only if the passed predicate function returns true

    Params:
        predicate: function taking any combination of args and kwargs
        predicate_args: optional array of positional args for the predicate
        predicate_kwargs: optional map of keyword args for the predicate
    """
    should_use_transaction = predicate(*(predicate_args or []), **(predicate_kwargs or {}))
    if should_use_transaction:
        with transaction.atomic():
            yield
    else:
        yield


def retry_database_locked():
    """Decorator factory: retry a callable up to 10 times with exponential
    back-off when sqlite raises "database is locked"; other errors re-raise.
    """
    back_off = 2

    def deco_retry(f):
        @wraps(f)
        def f_retry(*args, **kwargs):
            mtries, mdelay = 10, 3
            while mtries > 0:
                try:
                    return f(*args, **kwargs)
                except OperationalError as e:
                    if 'database is locked' in str(e):
                        time.sleep(mdelay)
                        mtries -= 1
                        mdelay *= back_off
                    else:
                        raise
            # final attempt after retries are exhausted; lets the error propagate
            return f(*args, **kwargs)

        return f_retry

    return deco_retry


def get_app_version():
    """Return the installed label-studio package version string."""
    return importlib.metadata.version('label-studio')


def get_latest_version():
    """Get version from pypi https://pypi.org/pypi/%s/json"""
    # NOTE(review): this queries a PyPI mirror, not pypi.org as the docstring
    # says — confirm the mirror is intentional
    pypi_url = 'https://pypi.tuna.tsinghua.edu.cn/pypi/%s/json' % label_studio.package_name
    try:
        response = requests.get(pypi_url, timeout=10).text
        data = json.loads(response)
        latest_version = data['info']['version']
        upload_time = data.get('releases', {}).get(latest_version, [{}])[-1].get('upload_time', None)
    except Exception:
        # best-effort: version check failures must never break the app
        logger.warning("Can't get latest version", exc_info=True)
    else:
        return {'latest_version': latest_version, 'upload_time': upload_time}
def current_version_is_outdated(latest_version):
    """Return True when the running package version is older than ``latest_version``."""
    latest_version = parse_version(latest_version)
    current_version = parse_version(label_studio.__version__)
    return current_version < latest_version


def check_for_the_latest_version(print_message):
    """Check latest pypi version"""
    if not settings.LATEST_VERSION_CHECK:
        return

    import label_studio

    # prevent excess checks by time intervals
    current_time = time.time()
    if label_studio.__latest_version_check_time__ and current_time - label_studio.__latest_version_check_time__ < 60:
        return
    label_studio.__latest_version_check_time__ = current_time

    data = get_latest_version()
    if not data:
        return
    latest_version = data['latest_version']
    outdated = latest_version and current_version_is_outdated(latest_version)

    def update_package_message():
        # render a boxed upgrade hint with the pip command highlighted
        update_command = 'pip install -U ' + label_studio.package_name
        return boxen(
            'Update available {curr_version} → {latest_version}\nRun {command}'.format(
                curr_version=label_studio.__version__, latest_version=latest_version, command=update_command
            ),
            style='double',
        ).replace(update_command, Fore.CYAN + update_command + Fore.RESET)

    if outdated and print_message:
        print(update_package_message())

    # stash results on the package module for collect_versions() to report
    label_studio.__latest_version__ = latest_version
    label_studio.__latest_version_upload_time__ = data['upload_time']
    label_studio.__current_version_is_outdated__ = outdated


# check version ASAP while package loading
# skip notification for uwsgi, as we're running in production ready mode
if settings.APP_WEBSERVER != 'uwsgi':
    check_for_the_latest_version(print_message=True)


def collect_versions(force=False):
    """Collect versions for all modules

    :return: dict with sub-dicts of version descriptions
    """
    import label_studio

    # prevent excess checks by time intervals
    current_time = time.time()
    need_check = current_time - settings.VERSIONS_CHECK_TIME > 300
    settings.VERSIONS_CHECK_TIME = current_time

    if settings.VERSIONS and not force and not need_check:
        return settings.VERSIONS

    # main pypi package
    result = {
        'release': label_studio.__version__,
        'label-studio-os-package': {
            'version': label_studio.__version__,
            'short_version': '.'.join(label_studio.__version__.split('.')[:2]),
            'latest_version_from_pypi': label_studio.__latest_version__,
            'latest_version_upload_time': label_studio.__latest_version_upload_time__,
            'current_version_is_outdated': label_studio.__current_version_is_outdated__,
        },
        # backend full git info
        'label-studio-os-backend': version.get_git_commit_info(ls=True),
    }

    # label studio frontend (optional: version.json may be absent or invalid)
    try:
        with open(os.path.join(settings.EDITOR_ROOT, 'version.json')) as f:
            lsf = json.load(f)
        result['label-studio-frontend'] = lsf
    except Exception:
        pass

    # data manager (optional: version.json may be absent or invalid)
    try:
        with open(os.path.join(settings.DM_ROOT, 'version.json')) as f:
            dm = json.load(f)
        result['dm2'] = dm
    except Exception:
        pass

    # converter from label-studio-sdk
    try:
        import label_studio_sdk.converter

        result['label-studio-converter'] = {'version': label_studio_sdk.__version__}
    except Exception:
        pass

    # ml
    try:
        import label_studio_ml

        result['label-studio-ml'] = {'version': label_studio_ml.__version__}
    except Exception:
        pass

    result.update(settings.COLLECT_VERSIONS(result=result))

    # truncate long commit messages for display
    for key in result:
        if 'message' in result[key] and len(result[key]['message']) > 70:
            result[key]['message'] = result[key]['message'][0:70] + ' ...'

    if settings.SENTRY_DSN:
        import sentry_sdk

        sentry_sdk.set_context('versions', copy.deepcopy(result))

        for package in result:
            if 'version' in result[package]:
                sentry_sdk.set_tag('version-' + package, result[package]['version'])
            if 'commit' in result[package]:
                sentry_sdk.set_tag('commit-' + package, result[package]['commit'])

    # edition type
    result['edition'] = settings.VERSION_EDITION

    settings.VERSIONS = result
    return result


def get_organization_from_request(request):
    """Helper for backward compatibility with org_pk in session"""
    # TODO remove session logic in next release
    user = request.user
    if user and user.is_authenticated:
        if user.active_organization is None:
            organization_pk = request.session.get('organization_pk')
            if organization_pk:
                user.active_organization_id = organization_pk
                user.save()
                request.session.pop('organization_pk', None)
                request.session.modified = True
        return user.active_organization_id


def load_func(func_string):
    """
    If the given setting is a string import notation,
    then perform the necessary import or imports.
    """
    if func_string is None:
        return None
    elif isinstance(func_string, str):
        return import_from_string(func_string)
    return func_string


def import_from_string(func_string):
    """
    Attempt to import a class from a string representation.
    """
    try:
        return import_string(func_string)
    except ImportError as e:
        msg = f'Could not import {func_string} from settings: {e}'
        # chain the original error so the real import failure stays visible
        raise ImportError(msg) from e
""" try: return import_string(func_string) except ImportError as e: msg = f'Could not import {func_string} from settings: {e}' raise ImportError(msg) class temporary_disconnect_signal: """Temporarily disconnect a model from a signal Example: with temporary_disconnect_all_signals( signals.post_delete, update_is_labeled_after_removing_annotation, Annotation): do_something() """ def __init__(self, signal, receiver, sender, dispatch_uid=None): self.signal = signal self.receiver = receiver self.sender = sender self.dispatch_uid = dispatch_uid def __enter__(self): self.signal.disconnect(receiver=self.receiver, sender=self.sender, dispatch_uid=self.dispatch_uid) def __exit__(self, type_, value, traceback): self.signal.connect(receiver=self.receiver, sender=self.sender, dispatch_uid=self.dispatch_uid) class temporary_disconnect_all_signals(object): def __init__(self, disabled_signals=None): self.stashed_signals = defaultdict(list) self.disabled_signals = disabled_signals or [ pre_init, post_init, pre_save, post_save, pre_delete, post_delete, pre_migrate, post_migrate, ] def __enter__(self): for signal in self.disabled_signals: self.disconnect(signal) def __exit__(self, exc_type, exc_val, exc_tb): for signal in list(self.stashed_signals): self.reconnect(signal) def disconnect(self, signal): self.stashed_signals[signal] = signal.receivers signal.receivers = [] def reconnect(self, signal): signal.receivers = self.stashed_signals.get(signal, []) del self.stashed_signals[signal] def batch(iterable, n=1): l = len(iterable) # noqa: E741 for ndx in range(0, l, n): yield iterable[ndx : min(ndx + n, l)] def batched_iterator(iterable, n): """ TODO: replace with itertools.batched when we drop support for Python < 3.12 """ iterator = iter(iterable) while True: batch = [] for _ in range(n): try: batch.append(next(iterator)) except StopIteration: if batch: yield batch return yield batch def round_floats(o): if isinstance(o, float): return round(o, 2) if isinstance(o, dict): return {k: 
round_floats(v) for k, v in o.items()} if isinstance(o, (list, tuple)): return [round_floats(x) for x in o] return o class temporary_disconnect_list_signal: """Temporarily disconnect a list of signals Each signal tuple: (signal_type, signal_method, object) Example: with temporary_disconnect_list_signal( [(signals.post_delete, update_is_labeled_after_removing_annotation, Annotation)] ): do_something() """ def __init__(self, signals): self.signals = signals def __enter__(self): for signal in self.signals: sig = signal[0] receiver = signal[1] sender = signal[2] dispatch_uid = signal[3] if len(signal) > 3 else None sig.disconnect(receiver=receiver, sender=sender, dispatch_uid=dispatch_uid) def __exit__(self, type_, value, traceback): for signal in self.signals: sig = signal[0] receiver = signal[1] sender = signal[2] dispatch_uid = signal[3] if len(signal) > 3 else None sig.connect(receiver=receiver, sender=sender, dispatch_uid=dispatch_uid) def trigram_migration_operations(next_step): ops = [ TrigramExtension(), next_step, ] SKIP_TRIGRAM_EXTENSION = get_env('SKIP_TRIGRAM_EXTENSION', None) if SKIP_TRIGRAM_EXTENSION == '1' or SKIP_TRIGRAM_EXTENSION == 'yes' or SKIP_TRIGRAM_EXTENSION == 'true': ops = [next_step] if SKIP_TRIGRAM_EXTENSION == 'full': ops = [] return ops def btree_gin_migration_operations(next_step): ops = [ BtreeGinExtension(), next_step, ] SKIP_BTREE_GIN_EXTENSION = get_env('SKIP_BTREE_GIN_EXTENSION', None) if SKIP_BTREE_GIN_EXTENSION == '1' or SKIP_BTREE_GIN_EXTENSION == 'yes' or SKIP_BTREE_GIN_EXTENSION == 'true': ops = [next_step] if SKIP_BTREE_GIN_EXTENSION == 'full': ops = [] return ops def merge_labels_counters(dict1, dict2): """ Merge two dictionaries with nested dictionary values into a single dictionary. Args: dict1 (dict): The first dictionary to merge. dict2 (dict): The second dictionary to merge. Returns: dict: A new dictionary with the merged nested dictionaries. 
Example: dict1 = {'sentiment': {'Negative': 1, 'Positive': 1}} dict2 = {'sentiment': {'Positive': 2, 'Neutral': 1}} result_dict = merge_nested_dicts(dict1, dict2) # {'sentiment': {'Negative': 1, 'Positive': 3, 'Neutral': 1}} """ result_dict = {} # iterate over keys in both dictionaries for key in set(dict1.keys()) | set(dict2.keys()): # add the corresponding values if they exist in both dictionaries value = {} if key in dict1: value.update(dict1[key]) if key in dict2: for subkey in dict2[key]: value[subkey] = value.get(subkey, 0) + dict2[key][subkey] # add the key-value pair to the result dictionary result_dict[key] = value return result_dict def timeit(func): def wrapper(*args, **kwargs): start = time.time() result = func(*args, **kwargs) end = time.time() logging.debug(f'{func.__name__} execution time: {end-start} seconds') return result return wrapper def empty(*args, **kwargs): pass def get_ttl_hash(seconds: int = 60) -> int: """Return the same value within `seconds` time period""" return round(time.time() / seconds) def is_community(): """Determine if the current Label Studio instance is the community edition (aka LSO). Returns ------- bool True if running open-source Label Studio, False otherwise. """ try: import label_studio_enterprise # noqa: F401 return False except ImportError: return True