"""This file and its contents are licensed under the Apache License 2.0. Please see the included NOTICE for copyright information and LICENSE for a copy of the license.
|
"""
|
import logging
|
import os
|
import urllib
|
|
import requests
|
from core.feature_flags import flag_set
|
from core.utils.common import load_func
|
from core.version import get_git_version
|
from data_export.serializers import ExportDataSerializer
|
from django.conf import settings
|
from django.contrib.auth.models import AnonymousUser
|
from django.db.models import Count
|
from requests.adapters import HTTPAdapter
|
from requests.auth import HTTPBasicAuth
|
|
from label_studio.core.utils.params import get_env
|
|
# Git-derived version of this build; appended to the User-Agent header below.
version = get_git_version()
logger = logging.getLogger(__name__)

# --- Timeouts (seconds); each one can be overridden through the environment ---
# Time allowed for establishing the connection.
CONNECTION_TIMEOUT = float(get_env('ML_CONNECTION_TIMEOUT', 1))
# Read timeout applied when a client is created without an explicit timeout.
TIMEOUT_DEFAULT = float(get_env('ML_TIMEOUT_DEFAULT', 100))

# Per-endpoint read timeouts.
TIMEOUT_TRAIN = float(get_env('ML_TIMEOUT_TRAIN', 30))
TIMEOUT_PREDICT = float(get_env('ML_TIMEOUT_PREDICT', 100))
TIMEOUT_HEALTH = float(get_env('ML_TIMEOUT_HEALTH', 1))
TIMEOUT_SETUP = float(get_env('ML_TIMEOUT_SETUP', 3))
TIMEOUT_DUPLICATE_MODEL = float(get_env('ML_TIMEOUT_DUPLICATE_MODEL', 1))
TIMEOUT_DELETE = float(get_env('ML_TIMEOUT_DELETE', 1))
TIMEOUT_TRAIN_JOB_STATUS = float(get_env('ML_TIMEOUT_TRAIN_JOB_STATUS', 1))

# --- Endpoint names, joined onto the ML backend base URL ---
# TODO: these are hard-coded; we would need to make them configurable
# on the ML backend side too.
PREDICT_URL = 'predict'
HEALTH_URL = 'health'
VALIDATE_URL = 'validate'
SETUP_URL = 'setup'
DUPLICATE_URL = 'duplicate_model'
DELETE_URL = 'delete'
JOB_STATUS_URL = 'job_status'
VERSIONS_URL = 'versions'
|
|
|
class BaseHTTPAPI(object):
    """Small HTTP client built on ``requests`` with default timeouts, retries and optional basic auth.

    A separate ``requests.Session`` is kept for every process id (see
    ``_session_key``), presumably so that sessions are not shared across
    forked workers.
    """

    # Default number of retries handed to the HTTPAdapter.
    MAX_RETRIES = 2
    # Headers applied to every session before the per-instance headers.
    HEADERS = {
        'User-Agent': 'heartex/' + (version or ''),
    }

    def __init__(
        self, url, timeout=None, connection_timeout=None, max_retries=None, headers=None, auth_method=None, **kwargs
    ):
        """Store connection settings and create the session for the current process.

        Args:
            url: Base URL of the remote service.
            timeout: Read timeout in seconds; falls back to ``TIMEOUT_DEFAULT`` when falsy.
            connection_timeout: Connect timeout in seconds; falls back to
                ``CONNECTION_TIMEOUT`` when falsy.
            max_retries: Retry count for the HTTP adapters; falls back to ``MAX_RETRIES`` when falsy.
            headers: Extra headers merged on top of ``HEADERS``.
            auth_method: Stored as-is; not consulted by this class.
            **kwargs: ``basic_auth_user`` / ``basic_auth_pass`` are read from here;
                other keys are ignored.
        """
        self._url = url
        self._timeout = timeout or TIMEOUT_DEFAULT
        self._connection_timeout = connection_timeout or CONNECTION_TIMEOUT
        self._headers = headers or {}
        self._auth_method = auth_method

        # TODO basic auth parameters must be required for auth_method == 'basic'
        self._basic_auth = (kwargs.get('basic_auth_user'), kwargs.get('basic_auth_pass'))

        self._max_retries = max_retries or self.MAX_RETRIES
        self._sessions = {self._session_key(): self.create_session()}

    def create_session(self):
        """Return a new ``requests.Session`` with default headers and retrying adapters mounted."""
        session = requests.Session()
        session.headers.update(self.HEADERS)
        # Instance headers are applied last so they override the class defaults.
        session.headers.update(self._headers)
        session.mount('http://', HTTPAdapter(max_retries=self._max_retries))
        session.mount('https://', HTTPAdapter(max_retries=self._max_retries))
        return session

    def _session_key(self):
        """Return the key under which the current process' session is cached (its pid)."""
        return os.getpid()

    @property
    def http(self):
        """Session for the current process, created and cached on first use."""
        key = self._session_key()
        session = self._sessions.get(key)
        if session is None:
            session = self._sessions[key] = self.create_session()
        return session

    def _prepare_kwargs(self, kwargs):
        """Fill in timeout and basic-auth defaults on ``kwargs`` in place.

        The timeout is normalised to a ``(connect, read)`` tuple: a missing
        timeout becomes the instance defaults, a bare number is treated as the
        read timeout. Any other value (``None``, an explicit tuple) is left
        untouched. Basic auth is attached whenever both credentials are set,
        independently of how the timeout was supplied.
        """
        if 'timeout' not in kwargs:
            # add timeout if it's not presented
            kwargs['timeout'] = self._connection_timeout, self._timeout
        elif isinstance(kwargs['timeout'], (float, int)):
            # add connection timeout if it's not presented
            kwargs['timeout'] = (self._connection_timeout, kwargs['timeout'])

        # Auth must not depend on the timeout branch above: callers that pass
        # an explicit timeout still need their credentials sent.
        user, password = self._basic_auth
        if user and password:
            kwargs['auth'] = HTTPBasicAuth(user, password)

    def request(self, method, *args, **kwargs):
        """Send a request through the process-local session after applying defaults."""
        self._prepare_kwargs(kwargs)
        return self.http.request(method, *args, **kwargs)

    def get(self, *args, **kwargs):
        """Shortcut for ``request('GET', ...)``."""
        return self.request('GET', *args, **kwargs)

    def post(self, *args, **kwargs):
        """Shortcut for ``request('POST', ...)``."""
        return self.request('POST', *args, **kwargs)
|
|
|
class MLApiResult:
    """
    Class for storing the result of ML API request
    """

    def __init__(self, url='', request='', response=None, headers=None, type='ok', status_code=200):
        """Record the outcome of one request.

        Args:
            url: URL the request was sent to.
            request: Payload that was sent.
            response: Decoded response body; ``None`` is replaced by a fresh ``{}``.
            headers: Headers associated with the request; ``None`` is replaced by a fresh ``{}``.
            type: ``'ok'`` or ``'error'``.
            status_code: HTTP status code of the response.
        """
        self.url = url
        self.request = request
        # Fresh containers per instance; never share a mutable default.
        self.response = {} if response is None else response
        self.headers = {} if headers is None else headers
        self.type = type
        self.status_code = status_code

    @property
    def is_error(self):
        """True when this result was recorded with type ``'error'``."""
        return self.type == 'error'

    @property
    def error_message(self):
        """Value stored under ``'error'`` in the response, or ``None`` if unavailable.

        The response may be any decoded JSON value (list, string, number),
        not only an object, so a missing ``get`` yields ``None`` instead of
        raising ``AttributeError``.
        """
        getter = getattr(self.response, 'get', None)
        if not callable(getter):
            return None
        return getter('error')
|
|
|
class MLApi(BaseHTTPAPI):
    """
    Class for ML API connector

    Every public method sends one request to an ML backend endpoint and
    returns an ``MLApiResult``; transport and decoding failures are reported
    through the result (``type='error'``) rather than raised.
    """

    def __init__(self, **kwargs):
        """Forward all keyword arguments to ``BaseHTTPAPI`` and set the validate timeout."""
        super().__init__(**kwargs)
        # Read timeout (seconds) used by ``validate``.
        self._validate_request_timeout = 10

    def _get_url(self, url_suffix):
        """Join ``url_suffix`` onto the base URL, adding a trailing slash to the base if missing."""
        url = self._url
        # endswith() instead of url[-1] so an empty base URL does not raise IndexError.
        if not url.endswith('/'):
            url += '/'
        return urllib.parse.urljoin(url, url_suffix)

    def _request(self, url_suffix, request=None, verbose=True, method='POST', *args, **kwargs):
        """Send a request to ``url_suffix`` and wrap the outcome in an ``MLApiResult``.

        Args:
            url_suffix: Endpoint name appended to the base URL.
            request: JSON payload; sent as the body for POST. For GET it is
                only recorded in the result, not transmitted.
            verbose: Currently unused.
            method: ``'POST'`` or ``'GET'``.
            *args, **kwargs: Passed through to the underlying HTTP call
                (e.g. ``timeout``).

        Returns:
            MLApiResult: ``type='error'`` when the request fails, the status
            is an HTTP error, or the body is not valid JSON; otherwise the
            decoded JSON body with the response status code.
        """
        assert method in ('POST', 'GET')
        url = self._get_url(url_suffix)
        request = request or {}
        headers = dict(self.http.headers)

        response = None
        try:
            if method == 'POST':
                response = self.post(url=url, json=request, *args, **kwargs)
            else:
                response = self.get(url=url, *args, **kwargs)
            response.raise_for_status()
        except requests.exceptions.RequestException as e:
            error_string = str(e)
            # ``response`` is still None when the failure happened before any reply arrived.
            status_code = response.status_code if response is not None else 0
            return MLApiResult(url, request, {'error': error_string}, headers, 'error', status_code=status_code)

        status_code = response.status_code
        try:
            response = response.json()
        except ValueError as e:
            # Body is not JSON: keep the raw content next to the decode error.
            return MLApiResult(
                url=url,
                request=request,
                response={'error': str(e), 'response': response.content},
                headers=headers,
                type='error',
                status_code=status_code,
            )

        return MLApiResult(url=url, request=request, response=response, headers=headers, status_code=status_code)

    def _create_project_uid(self, project):
        """Return ``'<project.id>.<creation timestamp in whole seconds>'``."""
        time_id = int(project.created_at.timestamp())
        return f'{project.id}.{time_id}'

    def train(self, project, use_ground_truth=False):
        """Ask the ML backend to start training for ``project``.

        With the webhook feature flag enabled a ``START_TRAINING`` event is
        posted to ``webhook``; otherwise all tasks that have at least one
        annotation are serialized and posted to ``train``.

        Args:
            project: Project to train on.
            use_ground_truth: Currently unused.

        Returns:
            MLApiResult
        """
        # NOTE(review): both branches use TIMEOUT_PREDICT; the module-level
        # TIMEOUT_TRAIN is not used here. Left unchanged because switching
        # would lower the effective default timeout - confirm intent.
        # TODO Replace AnonymousUser with real user from request
        user = AnonymousUser()
        # Identify if feature flag is turned on
        if flag_set('ff_back_dev_1417_start_training_mlbackend_webhooks_250122_long', user):
            request = {
                'action': 'START_TRAINING',
                'project': load_func(settings.WEBHOOK_SERIALIZERS['project'])(instance=project).data,
            }
            return self._request('webhook', request, verbose=False, timeout=TIMEOUT_PREDICT)

        # get only tasks with annotations
        tasks = project.tasks.annotate(num_annotations=Count('annotations')).filter(num_annotations__gt=0)
        # create serialized tasks with annotations: {"data": {...}, "annotations": [{...}], "predictions": [{...}]}
        tasks_ser = ExportDataSerializer(tasks, many=True).data
        logger.debug('%s tasks with annotations are sent to ML backend for training.', len(tasks_ser))
        request = {
            'annotations': tasks_ser,
            'project': self._create_project_uid(project),
            'label_config': project.label_config,
            'params': {'login': project.task_data_login, 'password': project.task_data_password},
        }
        return self._request('train', request, verbose=False, timeout=TIMEOUT_PREDICT)

    def _prep_prediction_req(self, tasks, project, context=None):
        """Build the JSON payload for a prediction request."""
        return {
            'tasks': tasks,
            'project': self._create_project_uid(project),
            'label_config': project.label_config,
            'params': {
                'login': project.task_data_login,
                'password': project.task_data_password,
                'context': context,
            },
        }

    def make_predictions(self, tasks, project, context=None):
        """Post ``tasks`` to the predict endpoint and return the ``MLApiResult``."""
        request = self._prep_prediction_req(tasks, project, context=context)
        return self._request(PREDICT_URL, request, verbose=False, timeout=TIMEOUT_PREDICT)

    def health(self):
        """GET the health endpoint."""
        return self._request(HEALTH_URL, method='GET', timeout=TIMEOUT_HEALTH)

    def validate(self, config):
        """Post a labeling ``config`` to the validate endpoint."""
        return self._request(VALIDATE_URL, request={'config': config}, timeout=self._validate_request_timeout)

    def setup(self, project, extra_params=None, **kwargs):
        """Post the project's setup information to the ML backend.

        Args:
            project: Project whose uid, label config and creator token are sent.
            extra_params: Forwarded verbatim as ``extra_params``.
            **kwargs: Accepted but not used.
        """
        # f-string rather than ``+`` so a non-string INTERNAL_PORT cannot raise TypeError.
        hostname = settings.HOSTNAME or f'http://localhost:{settings.INTERNAL_PORT}'
        return self._request(
            SETUP_URL,
            request={
                'project': self._create_project_uid(project),
                'schema': project.label_config,
                'hostname': hostname,
                'access_token': project.created_by.auth_token.key,
                'extra_params': extra_params,
            },
            timeout=TIMEOUT_SETUP,
        )

    def duplicate_model(self, project_src, project_dst):
        """Ask the backend to duplicate the model of ``project_src`` for ``project_dst``."""
        return self._request(
            DUPLICATE_URL,
            request={
                'project_src': self._create_project_uid(project_src),
                'project_dst': self._create_project_uid(project_dst),
            },
            timeout=TIMEOUT_DUPLICATE_MODEL,
        )

    def delete(self, project):
        """Post the project uid to the delete endpoint."""
        return self._request(
            DELETE_URL, request={'project': self._create_project_uid(project)}, timeout=TIMEOUT_DELETE
        )

    def get_train_job_status(self, train_job):
        """Post ``train_job.job_id`` to the job status endpoint."""
        return self._request(JOB_STATUS_URL, request={'job': train_job.job_id}, timeout=TIMEOUT_TRAIN_JOB_STATUS)

    def get_versions(self, project):
        """GET the versions endpoint.

        NOTE(review): ``_request`` does not transmit ``request`` for GET, so
        the project uid below is only recorded in the result - confirm
        whether the backend needs it (e.g. as a query parameter).
        """
        return self._request(
            VERSIONS_URL, request={'project': self._create_project_uid(project)}, timeout=TIMEOUT_SETUP, method='GET'
        )
|
|
|
def get_ml_api(project):
    """Build an ``MLApi`` client for the project's active ML backend.

    Returns ``None`` when the project has no active connection or the
    connection has no ML backend attached; otherwise a client configured
    with that backend's ``url`` and ``timeout``.
    """
    connection = project.ml_backend_active_connection
    backend = None if connection is None else connection.ml_backend
    if backend is None:
        return None
    return MLApi(url=backend.url, timeout=backend.timeout)
|