chenzhaoyang
2025-12-17 d3e5a4b7658ece4f845bbc0c4f95acf3fbdf8a61
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
"""This file and its contents are licensed under the Apache License 2.0. Please see the included NOTICE for copyright information and LICENSE for a copy of the license.
"""
from core.utils.io import validate_upload_url
from django.conf import settings
from ml.models import MLBackend, MLBackendAuth
from rest_framework import serializers
 
 
class MLBackendSerializer(serializers.ModelSerializer):
    """
    Serializer for MLBackend model.
    """
 
    readable_state = serializers.SerializerMethodField()
    basic_auth_pass = serializers.CharField(write_only=True, required=False, allow_null=True, allow_blank=True)
    basic_auth_pass_is_set = serializers.SerializerMethodField()
 
    def get_basic_auth_pass_is_set(self, obj):
        return bool(obj.basic_auth_pass)
 
    def get_readable_state(self, obj):
        return obj.get_state_display()
 
    def validate_basic_auth_pass(self, value):
        # Checks if the new password and old password are non-existent.
        if not value:
            if not self.instance.basic_auth_pass:
                raise serializers.ValidationError('Authentication password is required for Basic Authentication.')
            else:
                # If user is not changing the password, return the old password.
                return self.instance.basic_auth_pass
        return value
 
    def validate_url(self, value):
        validate_upload_url(value, block_local_urls=settings.ML_BLOCK_LOCAL_IP)
 
        return value
 
    def _validate_authentication(self, attrs):
        if attrs.get('auth_method') == MLBackendAuth.BASIC_AUTH:
            required_fields = ['basic_auth_user', 'basic_auth_pass']
 
            if any(field not in attrs for field in required_fields):
                raise serializers.ValidationError(
                    'Authentication username and password is required for Basic Authentication.'
                )
 
    def _validate_healthcheck(self, attrs):
        healthcheck_response = MLBackend.healthcheck_(**attrs)
 
        if healthcheck_response.is_error:
            if healthcheck_response.status_code == 401:
                message = (
                    'Able to connect to ML Server, but authentication parameters were '
                    'either not provided or are incorrect.'
                )
            else:
                message = (
                    f"Can't connect to ML backend {attrs['url']}, health check failed. "
                    'Make sure it is up and your firewall is properly configured. '
                    f'<a href="https://labelstud.io/guide/ml.html">Learn more</a> '
                    f'about how to set up an ML backend. Additional info: {healthcheck_response.error_message}'
                )
 
            raise serializers.ValidationError(message)
 
    def _validate_setup(self, attrs):
        setup_response = MLBackend.setup_(**attrs)
 
        if setup_response.is_error:
            message = (
                f"Successfully connected to {attrs['url']} but it doesn't look like a valid ML backend. "
                f'Reason: {setup_response.error_message}.\n'
                'Check the ML backend server console logs to check the status.'
                'There might be something wrong with your model or it might be incompatible with the current labeling configuration.'
            )
 
            raise serializers.ValidationError(message)
 
    def validate(self, attrs):
        attrs = super().validate(attrs)
 
        self._validate_authentication(attrs)
        self._validate_healthcheck(attrs)
        self._validate_setup(attrs)
 
        return attrs
 
    class Meta:
        model = MLBackend
        fields = [
            'id',
            'state',
            'readable_state',
            'is_interactive',
            'url',
            'error_message',
            'title',
            'auth_method',
            'basic_auth_user',
            'basic_auth_pass',
            'basic_auth_pass_is_set',
            'description',
            'extra_params',
            'model_version',
            'timeout',
            'created_at',
            'updated_at',
            'auto_update',
            'project',
        ]
 
 
class MLInteractiveAnnotatingRequest(serializers.Serializer):
    """
    Serializer for ML interactive annotating request.
    """
 
    task = serializers.IntegerField(help_text='ID of task to annotate', required=True)
    context = serializers.JSONField(help_text='Context for ML model', allow_null=True, default=None)