"""This file and its contents are licensed under the Apache License 2.0. Please see the included NOTICE for copyright information and LICENSE for a copy of the license. """ import logging from typing import Dict, List from core.utils.common import conditional_atomic, db_is_not_sqlite, load_func from django.conf import settings from django.db import models, transaction from django.db.models import Count, JSONField, Q from django.db.models.signals import post_save, pre_delete from django.dispatch import receiver from django.utils.translation import gettext_lazy as _ from ml.api_connector import PREDICT_URL, TIMEOUT_PREDICT, MLApi from projects.models import Project from tasks.serializers import PredictionSerializer, TaskSimpleSerializer from webhooks.serializers import Webhook, WebhookSerializer logger = logging.getLogger(__name__) MAX_JOBS_PER_PROJECT = 1 InteractiveAnnotatingDataSerializer = load_func(settings.INTERACTIVE_DATA_SERIALIZER) class MLBackendState(models.TextChoices): CONNECTED = 'CO', _('Connected') DISCONNECTED = 'DI', _('Disconnected') ERROR = 'ER', _('Error') TRAINING = 'TR', _('Training') PREDICTING = 'PR', _('Predicting') class MLBackendAuth(models.TextChoices): NONE = 'NONE', _('None') BASIC_AUTH = 'BASIC_AUTH', _('Basic Auth') class MLBackend(models.Model): """ """ state = models.CharField( max_length=2, choices=MLBackendState.choices, default=MLBackendState.DISCONNECTED, ) is_interactive = models.BooleanField( _('is_interactive'), default=False, help_text=('Used to interactively annotate tasks. ' 'If true, model returns one list with results'), ) url = models.TextField( _('url'), help_text='URL for the machine learning model server', ) error_message = models.TextField( _('error_message'), blank=True, null=True, help_text='Error message in error state', ) title = models.TextField( _('title'), blank=True, null=True, default='default', help_text='Name of the machine learning backend', ) auth_method = models.CharField( max_length=255, choices=MLBackendAuth.choices, default=MLBackendAuth.NONE, ) basic_auth_user = models.TextField( _('basic auth user'), blank=True, null=True, default='', help_text='HTTP Basic Auth user', ) basic_auth_pass = models.TextField( _('basic auth password'), blank=True, null=True, default='', help_text='HTTP Basic Auth password', ) description = models.TextField( _('description'), blank=True, null=True, default='', help_text='Description for the machine learning backend', ) extra_params = JSONField( _('extra params'), null=True, help_text='Any extra parameters passed to the ML Backend during the setup', ) model_version = models.TextField( _('model version'), blank=True, null=True, default='', help_text='Current model version associated with this machine learning backend', ) timeout = models.FloatField( _('timeout'), blank=True, default=100.0, help_text='Response model timeout', ) project = models.ForeignKey( Project, on_delete=models.CASCADE, related_name='ml_backends', ) created_at = models.DateTimeField(_('created at'), auto_now_add=True) updated_at = models.DateTimeField(_('updated at'), auto_now=True) auto_update = models.BooleanField( _('auto_update'), default=True, help_text='If false, model version is set by the user, if true - getting latest version from backend.', ) def __str__(self): return f'{self.title} (id={self.id}, url={self.url})' def __init__(self, *args, **kwargs): super(MLBackend, self).__init__(*args, **kwargs) self.__original_title = self.title def save(self, *args, **kwargs): """ Overrides the save() method to update the 
    def save(self, *args, **kwargs):
        """
        Overrides the save() method to update the associated project's model_version field.

        If the title of the model instance is changed and the model_version of the related
        project is currently the same as the original title, the project's model_version
        is updated to the new title.
        """
        p = self.project
        if self.title != self.__original_title and p.model_version == self.__original_title:
            with transaction.atomic():
                p.model_version = self.title
                p.save(update_fields=['model_version'])
                super().save(*args, **kwargs)
            # reset original field to current field after save
            self.__original_title = self.title
        else:
            super().save(*args, **kwargs)

    @staticmethod
    def healthcheck_(url, auth_method=None, **kwargs):
        return MLApi(url=url, auth_method=auth_method, **kwargs).health()

    def has_permission(self, user):
        user.project = self.project  # link for activity log
        return self.project.has_permission(user)

    @staticmethod
    def setup_(url, project, auth_method=None, **kwargs):
        api = MLApi(url=url, auth_method=auth_method, **kwargs)
        if not isinstance(project, Project):
            project = Project.objects.get(pk=project)
        return api.setup(project, **kwargs)

    def healthcheck(self):
        return self.healthcheck_(
            self.url, self.auth_method, basic_auth_user=self.basic_auth_user, basic_auth_pass=self.basic_auth_pass
        )

    def setup(self):
        return self.setup_(
            self.url,
            self.project,
            self.auth_method,
            extra_params=self.extra_params,
            basic_auth_user=self.basic_auth_user,
            basic_auth_pass=self.basic_auth_pass,
        )

    @property
    def api(self):
        return MLApi(
            url=self.url,
            timeout=self.timeout,
            auth_method=self.auth_method,
            basic_auth_user=self.basic_auth_user,
            basic_auth_pass=self.basic_auth_pass,
        )

    @property
    def not_ready(self):
        return self.state in (MLBackendState.DISCONNECTED, MLBackendState.ERROR)

    def update_state(self):
        model_version = None
        if self.healthcheck().is_error:
            self.state = MLBackendState.DISCONNECTED
        else:
            setup_response = self.setup()
            if setup_response.is_error:
                logger.info(f'ML backend responds with error: {setup_response.error_message}')
                self.state = MLBackendState.ERROR
                self.error_message = setup_response.error_message
            else:
                self.state = MLBackendState.CONNECTED
                model_version = setup_response.response.get('model_version')
                logger.info(f'ML backend responds with success: {setup_response.response}')
                if self.auto_update:
                    logger.debug(f'Changing model version: {self.model_version} -> {model_version}')
                    self.model_version = model_version
                self.error_message = None
        self.save()
        return model_version

    def train(self):
        train_response = self.api.train(self.project)
        if train_response.is_error:
            self.state = MLBackendState.ERROR
            self.error_message = train_response.error_message
        else:
            self.state = MLBackendState.TRAINING
            current_train_job = train_response.response.get('job')
            if current_train_job:
                MLBackendTrainJob.objects.create(job_id=current_train_job, ml_backend=self)
        self.save()

    def _predict(self, task):
        """Low-level prediction method used for debugging."""
        ml_api = self.api
        task_ser = TaskSimpleSerializer(task).data
        request_params = ml_api._prep_prediction_req([task_ser], self.project)
        ml_api_result = ml_api._request(PREDICT_URL, request_params, verbose=False, timeout=TIMEOUT_PREDICT)

        if ml_api_result.is_error:
            logger.info(f'Prediction not created for project {self}: {ml_api_result.error_message}')
            return

        results = ml_api_result.response.get('results', None)

        return {
            'status': 200,
            'data': {
                'status': ml_api_result.status_code,
                'error_message': ml_api_result.error_message,
                'url': ml_api._get_url(PREDICT_URL),
                'task': task_ser,
                'request': request_params,
                'response': results,
            },
        }
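
    # The helpers below expect the ML backend to answer /predict with a payload
    # shaped roughly like this (a sketch inferred from the validation logic
    # below, not a formal schema):
    #
    #   {
    #       "results": [
    #           {"result": [...], "score": 0.95, "model_version": "v2"},
    #           ...
    #       ]
    #   }
    #
    # with one entry per task for batch-capable backends; a single entry for a
    # multi-task request triggers the one-by-one fallback.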
    def _get_predictions_from_ml_backend_one_by_one(
        self, serialized_tasks: List[Dict], current_responses: List[Dict]
    ) -> List[Dict]:
        """Helper to get predictions from the ML backend one task at a time,
        used when the number of tasks doesn't match the number of responses.

        Note: don't use this function outside of this class.
        """
        if len(current_responses) == 1:
            # In case the ML backend doesn't support a batch of tasks, process them one by one
            # TODO: remove this block after all ML backends support batch processing
            logger.warning(
                f"ML backend '{self.title}' doesn't support batch processing of tasks, "
                f'switched to one-by-one task retrieval'
            )
            predictions = []
            for serialized_task in serialized_tasks:
                # get predictions per task
                predictions.extend(self._get_predictions_from_ml_backend([serialized_task]))
            return predictions
        else:
            # complete failure - the ML backend likely skipped some tasks, so we can't match them
            logger.error(
                f'Number of tasks and responses are not equal: '
                f'{len(serialized_tasks)} tasks != {len(current_responses)} responses. '
                f'Returning empty predictions.'
            )
            return []

    def _get_predictions_from_ml_backend(self, serialized_tasks: List[Dict]) -> List[Dict]:
        result = self.api.make_predictions(serialized_tasks, self.project)

        # response validation
        if result.is_error:
            logger.error(f'Error occurred: {result.error_message}')
            return []
        elif not isinstance(result.response, dict) or 'results' not in result.response:
            logger.error(f'ML backend returned an incorrect response, it must be a dict: {result.response}')
            return []
        elif not isinstance(result.response['results'], list) or len(result.response['results']) == 0:
            logger.error(
                'ML backend returned an incorrect response: the results field must be a list with at least one item'
            )
            return []

        responses = result.response['results']
        predictions = []

        if len(serialized_tasks) != len(responses):
            # The numbers of tasks and responses are not equal.
            # This can happen if the ML backend doesn't support batch processing and processes one task at a time.
            # In future versions, we may treat this as an error and deprecate this code branch.
            return self._get_predictions_from_ml_backend_one_by_one(serialized_tasks, responses)

        # ML backend supports batch processing
        for task, response in zip(serialized_tasks, responses):
            if isinstance(response, dict):
                # the ML backend can return a single prediction per task or multiple predictions
                response = [response]

            # get all predictions per task
            for r in response:
                if 'result' not in r:
                    logger.error(
                        f"ML backend returned an incorrect prediction, it should be a dict with the 'result' field:"
                        f' {r}'
                    )
                    continue

                predictions.append(
                    {
                        'task': task['id'],
                        'result': r['result'],
                        'score': r.get('score'),
                        'model_version': r.get('model_version', self.model_version),
                        'project': task['project'],
                    }
                )
        return predictions

    def predict_tasks(self, tasks):
        model_version = self.update_state()
        if self.not_ready:
            logger.debug(f'ML backend {self} is not ready')
            return

        if isinstance(tasks, list):
            from tasks.models import Task

            tasks = Task.objects.filter(id__in=[task.id for task in tasks])

        # Filter out tasks that already contain a prediction from the current model version
        tasks = tasks.annotate(predictions_count=Count('predictions')).exclude(
            Q(predictions_count__gt=0) & Q(predictions__model_version=model_version)
        )
        if not tasks.exists():
            logger.debug(f'All tasks already have predictions from model version={self.model_version}')
            return model_version

        tasks_ser = TaskSimpleSerializer(tasks, many=True).data
        predictions = self._get_predictions_from_ml_backend(tasks_ser)
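
        # Bulk-create the predictions, wrapped in a transaction on every database
        # except SQLite: conditional_atomic skips the transaction there, presumably
        # to sidestep SQLite's coarse write locking on long transactions.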
        with conditional_atomic(predicate=db_is_not_sqlite):
            prediction_ser = PredictionSerializer(data=predictions, many=True)
            prediction_ser.is_valid(raise_exception=True)
            instances = prediction_ser.save()
            return instances

    def interactive_annotating(self, task, context=None, user=None):
        result = {}
        options = {}
        if user:
            options = {'user': user}

        if not self.is_interactive:
            result['errors'] = ['Model is not set to be used for interactive preannotations']
            return result

        tasks_ser = InteractiveAnnotatingDataSerializer(
            [task], many=True, expand=['drafts', 'predictions', 'annotations'], context=options
        ).data

        ml_api_result = self.api.make_predictions(
            tasks=tasks_ser,
            project=self.project,
            context=context,
        )
        if ml_api_result.is_error:
            logger.info(f'Prediction not created for project {self}: {ml_api_result.error_message}')
            result['errors'] = [ml_api_result.error_message]
            return result

        if not (isinstance(ml_api_result.response, dict) and 'results' in ml_api_result.response):
            logger.info(f'ML backend returned an incorrect response, it must be a dict: {ml_api_result.response}')
            result['errors'] = [
                'Incorrect response from ML service: ML backend returned an incorrect response, it must be a dict.'
            ]
            return result

        ml_results = ml_api_result.response.get(
            'results',
            [
                None,
            ],
        )
        if not isinstance(ml_results, list) or len(ml_results) < 1:
            logger.warning(
                f'ML backend has to return a list with at least 1 result, but it returned: {type(ml_results)}'
            )
            result['errors'] = [
                'Incorrect response from ML service: ML backend has to return a list with at least 1 result.'
            ]
            return result

        result['data'] = ml_results[0]
        return result

    @staticmethod
    def get_versions_(url, project, auth_method, **kwargs):
        api = MLApi(url=url, auth_method=auth_method, **kwargs)
        if not isinstance(project, Project):
            project = Project.objects.get(pk=project)
        return api.get_versions(project)

    def get_versions(self):
        return self.get_versions_(
            self.url,
            self.project,
            self.auth_method,
            basic_auth_user=self.basic_auth_user,
            basic_auth_pass=self.basic_auth_pass,
        )


class MLBackendPredictionJob(models.Model):
    job_id = models.CharField(max_length=128)
    ml_backend = models.ForeignKey(MLBackend, related_name='prediction_jobs', on_delete=models.CASCADE)
    model_version = models.TextField(
        _('model version'), blank=True, null=True, help_text='Model version this job is associated with'
    )
    batch_size = models.PositiveSmallIntegerField(
        _('batch size'), default=100, help_text='Number of tasks processed per batch'
    )
    created_at = models.DateTimeField(_('created at'), auto_now_add=True)
    updated_at = models.DateTimeField(_('updated at'), auto_now=True)


class MLBackendTrainJob(models.Model):
    job_id = models.CharField(max_length=128)
    ml_backend = models.ForeignKey(MLBackend, related_name='train_jobs', on_delete=models.CASCADE)
    model_version = models.TextField(
        _('model version'),
        blank=True,
        null=True,
        help_text='Model version this job is associated with',
    )
    created_at = models.DateTimeField(_('created at'), auto_now_add=True)
    updated_at = models.DateTimeField(_('updated at'), auto_now=True)

    def get_status(self):
        project = self.ml_backend.project
        ml_api = project.get_ml_api()
        if not ml_api:
            logger.error(
                f"Training job {self.id}: Can't collect training jobs for project {project.id}: ML API is null"
            )
            return None
        ml_api_result = ml_api.get_train_job_status(self)
        if ml_api_result.is_error:
            if ml_api_result.status_code == 410:
                return {'job_status': 'removed'}
            logger.info(
                f"Training job {self.id}: Can't collect training jobs for project {project}: "
                f'ML API returned error {ml_api_result.error_message}'
            )
            return None
        return ml_api_result.response
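
    # get_status() returns a dict with a 'job_status' key ('queued', 'started',
    # 'removed', ...) on success, or None when the ML API is unreachable, so the
    # property below treats None as "not running".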
    @property
    def is_running(self):
        status = self.get_status()
        # get_status() can return None; guard against it instead of raising a TypeError
        if not status:
            return False
        return status['job_status'] in ('queued', 'started')


def _validate_ml_api_result(ml_api_result, tasks, curr_logger):
    if ml_api_result.is_error:
        curr_logger.info(ml_api_result.error_message)
        return False

    results = ml_api_result.response['results']
    if not isinstance(results, list) or len(results) != len(tasks):
        curr_logger.warning('Num input tasks is %d but ML API returns %d results', len(tasks), len(results))
        return False

    return True


@receiver(pre_delete, sender=MLBackend)
def modify_project_model_version(sender, instance, **kwargs):
    project = instance.project
    if project.model_version == instance.title:
        project.model_version = ''
        project.save(update_fields=['model_version'])


@receiver(post_save, sender=MLBackend)
def create_ml_webhook(sender, instance, created, **kwargs):
    if not created:
        return
    ml_backend = instance
    webhook_url = ml_backend.url.rstrip('/') + '/webhook'
    project = ml_backend.project
    if Webhook.objects.filter(project=project, url=webhook_url).exists():
        logger.info(f'Webhook {webhook_url} already exists for project {project}: skip creating new one.')
        return
    logger.info(f'Create ML backend webhook {webhook_url}')
    ser = WebhookSerializer(
        data=dict(project=project.id, url=webhook_url, send_payload=True, send_for_all_actions=True)
    )
    if ser.is_valid():
        ser.save(organization=project.organization)
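

# Minimal usage sketch (hypothetical project and URL; run in e.g. `manage.py shell`):
#
#   backend = MLBackend.objects.create(
#       project=project,
#       url='http://localhost:9090',
#       title='my-model',
#   )
#   backend.update_state()              # health check + setup handshake
#   if not backend.not_ready:
#       backend.predict_tasks(project.tasks.all())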