"""This file and its contents are licensed under the Apache License 2.0. Please see the included NOTICE for copyright information and LICENSE for a copy of the license.
"""
import logging

from core.feature_flags import flag_set
from core.permissions import ViewClassPermission, all_permissions
from django.conf import settings
from django.utils.decorators import method_decorator
from django_filters.rest_framework import DjangoFilterBackend
from drf_spectacular.types import OpenApiTypes
from drf_spectacular.utils import OpenApiParameter, OpenApiResponse, extend_schema
from ml.models import MLBackend
from ml.serializers import MLBackendSerializer, MLInteractiveAnnotatingRequest
from projects.models import Project, Task
from rest_framework import generics, status
from rest_framework.parsers import FormParser, JSONParser, MultiPartParser
from rest_framework.response import Response
from rest_framework.views import APIView

logger = logging.getLogger(__name__)

# Query-string values (compared case-insensitively) that count as "on" for boolean flags
# such as ``?random=true``.
_TRUTHY_QUERY_VALUES = frozenset({'1', 'true', 'yes', 'on'})


def _query_flag(request, name):
    """Return True if query parameter ``name`` is present and holds a truthy value.

    ``request.query_params.get()`` yields the raw string, and every non-empty string
    (including ``'false'`` and ``'0'``) is truthy in Python, so the value has to be
    parsed explicitly instead of being used directly in an ``if``.
    """
    value = request.query_params.get(name)
    return value is not None and value.strip().lower() in _TRUTHY_QUERY_VALUES


# Hand-written OpenAPI request body shared by the "create" (POST) and "update" (PATCH)
# ML backend endpoints below.
_ml_backend_schema = {
    'type': 'object',
    'properties': {
        'url': {
            'type': 'string',
            'description': 'ML backend URL',
        },
        'project': {
            'type': 'integer',
            'description': 'Project ID',
        },
        'is_interactive': {
            'type': 'boolean',
            'description': 'Is interactive',
        },
        'title': {
            'type': 'string',
            'description': 'Title',
        },
        'description': {
            'type': 'string',
            'description': 'Description',
        },
        'auth_method': {
            'type': 'string',
            'description': 'Auth method',
            'enum': ['NONE', 'BASIC_AUTH'],
        },
        'basic_auth_user': {
            'type': 'string',
            'description': 'Basic auth user',
        },
        'basic_auth_pass': {
            'type': 'string',
            'description': 'Basic auth password',
        },
        'extra_params': {
            'type': 'object',
            'description': 'Extra parameters',
        },
        'timeout': {
            'type': 'integer',
            'description': 'Response model timeout',
        },
    },
    'required': [],
}


@method_decorator(
    name='post',
    decorator=extend_schema(
        tags=['Machine Learning'],
        summary='Add ML Backend',
        description="""
    Add an ML backend to a project using the Label Studio UI or by sending a POST request using the following cURL
    command:
    ```bash
    curl -X POST -H 'Content-type: application/json' {host}/api/ml -H 'Authorization: Token abc123'\\
    --data '{{"url": "http://localhost:9090", "project": {{project_id}}}}'
    """.format(
            host=(settings.HOSTNAME or 'https://localhost:8080')
        ),
        request={
            'application/json': _ml_backend_schema,
        },
        extensions={
            'x-fern-sdk-group-name': 'ml',
            'x-fern-sdk-method-name': 'create',
            'x-fern-audiences': ['public'],
        },
    ),
)
@method_decorator(
    name='get',
    decorator=extend_schema(
        tags=['Machine Learning'],
        summary='List ML backends',
        description="""
    List all configured ML backends for a specific project by ID.
    Use the following cURL command:
    ```bash
    curl {host}/api/ml?project={{project_id}} -H 'Authorization: Token abc123'
    """.format(
            host=(settings.HOSTNAME or 'https://localhost:8080')
        ),
        parameters=[
            OpenApiParameter(name='project', type=OpenApiTypes.INT, location='query', description='Project ID'),
        ],
        extensions={
            'x-fern-sdk-group-name': 'ml',
            'x-fern-sdk-method-name': 'list',
            'x-fern-audiences': ['public'],
        },
    ),
)
class MLBackendListAPI(generics.ListCreateAPIView):
    """List the ML backends of a project (GET) or connect a new ML backend (POST)."""

    parser_classes = (JSONParser, FormParser, MultiPartParser)
    permission_required = ViewClassPermission(
        GET=all_permissions.projects_view,
        POST=all_permissions.projects_change,
    )
    serializer_class = MLBackendSerializer
    filter_backends = [DjangoFilterBackend]
    filterset_fields = ['is_interactive']

    def get_queryset(self):
        """Return the ML backends of the project named by the ``project`` query parameter.

        Raises 404 (via ``get_object_or_404``) when no matching project exists, and runs
        object-level permission checks against that project before listing.
        """
        project_pk = self.request.query_params.get('project')
        project = generics.get_object_or_404(Project, pk=project_pk)
        self.check_object_permissions(self.request, project)
        # NOTE(review): presumably refreshes each backend's state and returns them as a
        # queryset (DjangoFilterBackend filters it by ``is_interactive``) — confirm in
        # projects.models.
        ml_backends = project.update_ml_backends_state()
        return ml_backends

    def perform_create(self, serializer):
        """Save the new backend, refresh its state, and possibly make it the default model version."""
        ml_backend = serializer.save()
        ml_backend.update_state()

        project = ml_backend.project
        # In case we are adding the model, let's set it as the default
        # to obtain predictions. This approach is consistent with uploading
        # offline predictions, which would be set automatically.
        if project.show_collab_predictions and not project.model_version:
            project.model_version = ml_backend.title
            project.save(update_fields=['model_version'])


@method_decorator(
    name='patch',
    decorator=extend_schema(
        tags=['Machine Learning'],
        summary='Update ML Backend',
        description="""
    Update ML backend parameters using the Label Studio UI or by sending a PATCH request using the following cURL command:
    ```bash
    curl -X PATCH -H 'Content-type: application/json' {host}/api/ml/{{ml_backend_ID}} -H 'Authorization: Token abc123'\\
    --data '{{"url": "http://localhost:9091"}}'
    """.format(
            host=(settings.HOSTNAME or 'https://localhost:8080')
        ),
        request={
            'application/json': _ml_backend_schema,
        },
        extensions={
            'x-fern-sdk-group-name': 'ml',
            'x-fern-sdk-method-name': 'update',
            'x-fern-audiences': ['public'],
        },
    ),
)
@method_decorator(
    name='get',
    decorator=extend_schema(
        tags=['Machine Learning'],
        summary='Get ML Backend',
        description="""
    Get details about a specific ML backend connection by ID. For example, make a GET request using the
    following cURL command:
    ```bash
    curl {host}/api/ml/{{ml_backend_ID}} -H 'Authorization: Token abc123'
    """.format(
            host=(settings.HOSTNAME or 'https://localhost:8080')
        ),
        request=None,
        extensions={
            'x-fern-sdk-group-name': 'ml',
            'x-fern-sdk-method-name': 'get',
            'x-fern-audiences': ['public'],
        },
    ),
)
@method_decorator(
    name='delete',
    decorator=extend_schema(
        tags=['Machine Learning'],
        summary='Remove ML Backend',
        description="""
    Remove an existing ML backend connection by ID. For example, use the
    following cURL command:
    ```bash
    curl -X DELETE {host}/api/ml/{{ml_backend_ID}} -H 'Authorization: Token abc123'
    """.format(
            host=(settings.HOSTNAME or 'https://localhost:8080')
        ),
        request=None,
        extensions={
            'x-fern-sdk-group-name': 'ml',
            'x-fern-sdk-method-name': 'delete',
            'x-fern-audiences': ['public'],
        },
    ),
)
# Full replacement (PUT) is still routed but hidden from the published schema; PATCH is documented instead.
@method_decorator(name='put', decorator=extend_schema(exclude=True))
class MLBackendDetailAPI(generics.RetrieveUpdateDestroyAPIView):
    """Retrieve, update or delete a single ML backend; every method requires ``projects_change``."""

    parser_classes = (JSONParser, FormParser, MultiPartParser)
    serializer_class = MLBackendSerializer
    permission_required = all_permissions.projects_change
    queryset = MLBackend.objects.all()

    def get_object(self):
        """Look the backend up as usual, then call ``update_state()`` on it before returning it."""
        ml_backend = super().get_object()
        ml_backend.update_state()
        return ml_backend

    def perform_update(self, serializer):
        """Persist the changes, then refresh the backend's state against the new settings."""
        ml_backend = serializer.save()
        ml_backend.update_state()


@method_decorator(
    name='post',
    decorator=extend_schema(
        tags=['Machine Learning'],
        summary='Train',
        description="""
        After you add an ML backend, call this API with the ML backend ID to start training with
        already-labeled tasks.

        Get the ML backend ID by [listing the ML backends for a project](https://labelstud.io/api/#operation/api_ml_list).
        """,
        parameters=[
            OpenApiParameter(
                name='id',
                type=OpenApiTypes.INT,
                location='path',
                description='A unique integer value identifying this ML backend.',
            ),
        ],
        request={
            'application/json': {
                'type': 'object',
                'properties': {
                    'use_ground_truth': {
                        'type': 'boolean',
                        'description': 'Whether to include ground truth annotations in training',
                    },
                },
            },
        },
        responses={
            200: OpenApiResponse(description='Training has successfully started.'),
            500: OpenApiResponse(
                description='Training error',
                response={
                    'description': 'Error message',
                    'type': 'string',
                    'example': 'Server responded with an error.',
                },
            ),
        },
        extensions={
            'x-fern-sdk-group-name': 'ml',
            'x-fern-sdk-method-name': 'train',
            'x-fern-audiences': ['public'],
        },
    ),
)
class MLBackendTrainAPI(APIView):
    """Trigger training on an ML backend."""

    permission_required = all_permissions.projects_change

    def post(self, request, *args, **kwargs):
        """Start training on the backend identified by the ``pk`` URL kwarg; responds 200 with no body."""
        ml_backend = generics.get_object_or_404(MLBackend, pk=self.kwargs['pk'])
        self.check_object_permissions(self.request, ml_backend)
        # NOTE(review): the schema above documents a ``use_ground_truth`` body field, but the
        # request body is never read here — confirm whether train() should receive it.
        ml_backend.train()
        return Response(status=status.HTTP_200_OK)


@method_decorator(
    name='post',
    decorator=extend_schema(
        tags=['Machine Learning'],
        summary='Test prediction',
        description="""
        After you add an ML backend, call this API with the ML backend ID to run a test prediction on specific task data
        """,
        parameters=[
            OpenApiParameter(
                name='id',
                type=OpenApiTypes.INT,
                location='path',
                description='A unique integer value identifying this ML backend.',
            ),
        ],
        responses={
            200: OpenApiResponse(description='Predicting has successfully started.'),
            500: OpenApiResponse(
                description='Predicting error',
                response={
                    'description': 'Error message',
                    'type': 'string',
                    'example': 'Server responded with an error.',
                },
            ),
        },
        extensions={
            'x-fern-sdk-group-name': 'ml',
            'x-fern-sdk-method-name': 'test_predict',
            'x-fern-audiences': ['internal'],
        },
    ),
)
class MLBackendPredictTestAPI(APIView):
    """Run a test prediction on a random task of the backend's project."""

    serializer_class = MLBackendSerializer
    permission_required = all_permissions.projects_change

    def post(self, request, *args, **kwargs):
        """Send one random project task to the ML backend and relay its response.

        Only the ``?random=true`` mode (also accepts ``1``/``yes``/``on``) is supported;
        any other request gets 501. Responds 500 if the project has no tasks or the
        backend returned nothing.
        """
        ml_backend = generics.get_object_or_404(MLBackend, pk=self.kwargs['pk'])
        self.check_object_permissions(self.request, ml_backend)

        # Parse the flag explicitly: a raw ``'false'`` string would otherwise be truthy.
        if not _query_flag(request, 'random'):
            return Response(
                status=status.HTTP_501_NOT_IMPLEMENTED,
                data={'error': 'Not implemented - you must provide random=true query parameter'},
            )

        task = Task.get_random(project=ml_backend.project)
        if not task:
            return Response(
                status=status.HTTP_500_INTERNAL_SERVER_ERROR,
                data={
                    'detail': 'Project has no tasks to run prediction on, import at least 1 task to run prediction'
                },
            )

        # _predict() returns keyword arguments for Response (presumably data/status) — the
        # name avoids shadowing this method's own **kwargs.
        response_kwargs = ml_backend._predict(task)
        if not response_kwargs:
            return Response(
                status=status.HTTP_500_INTERNAL_SERVER_ERROR,
                data={
                    'detail': 'ML backend did not return any predictions, check ML backend logs for more details'
                },
            )
        return Response(**response_kwargs)


@method_decorator(
    name='post',
    decorator=extend_schema(
        tags=['Machine Learning'],
        summary='Request Interactive Annotation',
        description="""
        Send a request to the machine learning backend set up to be used for interactive preannotations to retrieve a
        predicted region based on annotator input.
        See [set up machine learning](https://labelstud.io/guide/ml.html#Get-interactive-preannotations) for more.
        """,
        parameters=[
            OpenApiParameter(
                name='id',
                type=OpenApiTypes.INT,
                location='path',
                description='A unique integer value identifying this ML backend.',
            ),
        ],
        request=MLInteractiveAnnotatingRequest,
        responses={
            200: OpenApiResponse(description='Interactive annotation has succeeded.'),
        },
        extensions={
            'x-fern-sdk-group-name': 'ml',
            'x-fern-sdk-method-name': 'predict_interactive',
            'x-fern-audiences': ['public'],
        },
    ),
)
class MLBackendInteractiveAnnotating(APIView):
    """
    Send a request to the machine learning backend set up to be used for interactive preannotations to retrieve a
    predicted region based on annotator input.
    """

    permission_required = all_permissions.tasks_view

    def _error_response(self, message, log_function=logger.info):
        """Log ``message`` and return it in an ``errors`` list with HTTP 200.

        NOTE(review): not called from this class's visible code; possibly used by subclasses.
        """
        log_function(message)
        return Response({'errors': [message]}, status=status.HTTP_200_OK)

    def _get_task(self, ml_backend, validated_data):
        """Fetch the requested task, restricted to the backend's project (404 otherwise)."""
        return generics.get_object_or_404(Task, pk=validated_data['task'], project=ml_backend.project)

    def _get_credentials(self, request, context, project):
        """Add the project's task-data credentials to ``context`` when the feature flag is on.

        ``context`` is updated in place and also returned.
        """
        if flag_set('ff_back_dev_2362_project_credentials_060722_short', request.user):
            context.update(
                project_credentials_login=project.task_data_login,
                project_credentials_password=project.task_data_password,
            )
        return context

    def post(self, request, *args, **kwargs):
        """
        Send a request to the machine learning backend set up to be used for interactive preannotations to retrieve a
        predicted region based on annotator input.
        """
        ml_backend = generics.get_object_or_404(MLBackend, pk=self.kwargs['pk'])
        self.check_object_permissions(self.request, ml_backend)
        serializer = MLInteractiveAnnotatingRequest(data=request.data)
        serializer.is_valid(raise_exception=True)

        task = self._get_task(ml_backend, serializer.validated_data)
        context = self._get_credentials(request, serializer.validated_data.get('context', {}), task.project)

        result = ml_backend.interactive_annotating(task, context, user=request.user)

        return Response(
            result,
            status=status.HTTP_200_OK,
        )


@method_decorator(
    name='get',
    decorator=extend_schema(
        tags=['Machine Learning'],
        summary='Get model versions',
        description='Get available versions of the model.',
        responses={
            200: OpenApiResponse(
                description='List of available versions.',
                response={
                    'type': 'object',
                    'properties': {
                        'versions': {
                            'type': 'array',
                            'items': {
                                'type': 'string',
                            },
                        },
                        'message': {
                            'type': 'string',
                        },
                    },
                },
            ),
        },
        extensions={
            'x-fern-sdk-group-name': 'ml',
            'x-fern-sdk-method-name': 'list_model_versions',
            'x-fern-audiences': ['public'],
        },
    ),
)
class MLBackendVersionsAPI(generics.RetrieveAPIView):
    """Report the model versions an ML backend says it has available."""

    # TODO(jo): implement this view with a serializer and replace the handwritten schema above with it
    permission_required = all_permissions.projects_change

    def get(self, request, *args, **kwargs):
        """Return ``{'versions': [...]}`` from the backend, or an ``error`` payload on failure."""
        ml_backend = generics.get_object_or_404(MLBackend, pk=self.kwargs['pk'])
        self.check_object_permissions(self.request, ml_backend)
        versions_response = ml_backend.get_versions()
        backend_status = versions_response.status_code

        if backend_status == 200:
            result = {'versions': versions_response.response.get('versions', [])}
            return Response(data=result, status=status.HTTP_200_OK)

        if backend_status == 404:
            # Presumably the ML backend predates the versions endpoint (hence the "upgrade"
            # message): fall back to the version recorded on the backend itself.
            result = {'versions': [ml_backend.model_version], 'message': 'Upgrade your ML backend version to latest.'}
            return Response(data=result, status=status.HTTP_200_OK)

        result = {'error': str(versions_response.error_message)}
        # Relay the backend's HTTP status; a missing or non-positive code (presumably no HTTP
        # response was received at all) becomes a 500.
        status_code = backend_status if (backend_status or 0) > 0 else status.HTTP_500_INTERNAL_SERVER_ERROR
        return Response(data=result, status=status_code)