"""API tests for ML backends (``/api/ml/``) and their effect on ``Project.model_version``."""
import json

import pytest
from projects.models import Task
from rest_framework import status

from label_studio.tests.utils import make_project, register_ml_backend_mock

ORIG_MODEL_NAME = 'basic_ml_backend'

PROJECT_CONFIG = """"""

# URL shared by the mocked backend registration and every request payload below.
_ML_BACKEND_URL = 'https://ml_backend_for_test_api'


def _make_published_project(user, **make_project_kwargs):
    """Create the published project used by every test in this module.

    Args:
        user: Owner passed to ``make_project``.
        **make_project_kwargs: Extra keyword arguments forwarded verbatim to
            ``make_project`` (for example ``use_ml_backend=True``).

    Returns:
        Whatever ``make_project`` returns (used as a project instance by the tests).
    """
    project_config = dict(
        is_published=True,
        label_config=PROJECT_CONFIG,
        title='test_ml_backend_creation',
    )
    return make_project(config=project_config, user=user, **make_project_kwargs)


def _post_ml_backend(client, project, title, **extra_fields):
    """POST a new ML backend for ``project`` and return the raw response.

    Args:
        client: Test client used to issue the request.
        project: Project the backend is attached to; only ``project.id`` is sent.
        title: Value of the ``title`` field.
        **extra_fields: Additional payload fields appended after the base ones.
    """
    payload = {
        'project': project.id,
        'title': title,
        'url': _ML_BACKEND_URL,
        **extra_fields,
    }
    return client.post('/api/ml/', data=payload)


def _patch_json(client, url, payload):
    """PATCH ``url`` with ``payload`` serialized as a JSON body and return the response."""
    return client.patch(url, data=json.dumps(payload), content_type='application/json')


@pytest.fixture
def ml_backend_for_test_api(ml_backend):
    """Register the mocked ML backend under the URL used by these tests."""
    register_ml_backend_mock(
        ml_backend,
        url=_ML_BACKEND_URL,
        setup_model_version='1.0.0',
    )
    yield ml_backend


@pytest.fixture
def mock_gethostbyname(mocker):
    """Patch ``socket.gethostbyname`` to return a fixed address string.

    Presumably this keeps the tests from performing real DNS lookups for the
    fake backend host - confirm against the URL validation code.
    """
    # NOTE(review): '321.21.21.21' is not a valid IPv4 address (first octet > 255).
    # Kept as-is because the consuming code is not visible from this file.
    mocker.patch('socket.gethostbyname', return_value='321.21.21.21')


@pytest.mark.django_db
def test_ml_backend_set_for_prelabeling(business_client, ml_backend_for_test_api, mock_gethostbyname):
    """Creating a backend on a project with an empty model version sets it to the backend title."""
    project = _make_published_project(business_client.user)
    assert project.model_version == ''

    # create ML backend
    response = _post_ml_backend(business_client, project, 'ml_backend_title')
    assert response.status_code == status.HTTP_201_CREATED

    project.refresh_from_db()
    assert project.model_version == 'ml_backend_title'


@pytest.mark.django_db
def test_ml_backend_not_set_for_prelabeling(business_client, ml_backend_for_test_api, mock_gethostbyname):
    """Creating a backend must not overwrite a model version that is already set.

    This covers, for example, projects whose predictions were uploaded before
    the backend was connected.
    """
    project = _make_published_project(business_client.user)
    project.model_version = ORIG_MODEL_NAME
    project.save()

    # create ML backend
    response = _post_ml_backend(business_client, project, 'ml_backend_title')
    assert response.status_code == status.HTTP_201_CREATED

    project.refresh_from_db()
    assert project.model_version == ORIG_MODEL_NAME


@pytest.mark.django_db
def test_model_version_on_save(business_client, ml_backend_for_test_api, mock_gethostbyname):
    """Renaming the selected backend updates the project's model version to the new title."""
    project = _make_published_project(business_client.user)
    assert project.model_version == ''

    # create ML backend
    response = _post_ml_backend(business_client, project, 'test_ml_backend_creation_ML_backend')
    assert response.status_code == status.HTTP_201_CREATED
    ml_backend_id = response.json()['id']

    response = business_client.get(f'/api/ml/{ml_backend_id}')
    assert response.status_code == status.HTTP_200_OK
    assert response.json()['state'] == 'CO'

    # select model version in project
    response = _patch_json(
        business_client,
        f'/api/projects/{project.id}',
        {'model_version': 'test_ml_backend_creation_ML_backend'},
    )
    assert response.status_code == status.HTTP_200_OK

    # change ML backend title --> model version should be updated
    response = _patch_json(
        business_client,
        f'/api/ml/{ml_backend_id}',
        {
            'project': project.id,
            'title': 'new_title',
            'url': _ML_BACKEND_URL,
        },
    )
    assert response.status_code == status.HTTP_200_OK

    project.refresh_from_db()
    assert project.model_version == 'new_title'


@pytest.mark.django_db
def test_model_version_on_delete(business_client, ml_backend_for_test_api, mock_gethostbyname):
    """Deleting the selected backend resets the project's model version to an empty string."""
    project = _make_published_project(business_client.user)
    assert project.model_version == ''

    # create ML backend
    response = _post_ml_backend(business_client, project, 'test_ml_backend_creation_ML_backend')
    assert response.status_code == status.HTTP_201_CREATED
    ml_backend_id = response.json()['id']

    response = business_client.get(f'/api/ml/{ml_backend_id}')
    assert response.status_code == status.HTTP_200_OK
    assert response.json()['state'] == 'CO'

    # select model version in project
    response = _patch_json(
        business_client,
        f'/api/projects/{project.id}',
        {'model_version': 'test_ml_backend_creation_ML_backend'},
    )
    assert response.status_code == status.HTTP_200_OK

    project.refresh_from_db()
    assert project.model_version == 'test_ml_backend_creation_ML_backend'

    # delete ML backend --> project's model version should be reset
    response = business_client.delete(f'/api/ml/{ml_backend_id}')
    assert response.status_code == status.HTTP_204_NO_CONTENT

    project.refresh_from_db()
    assert project.model_version == ''


@pytest.mark.django_db
def test_security_write_only_payload(business_client, ml_backend_for_test_api, mock_gethostbyname):
    """``basic_auth_pass`` must never appear in POST, GET or PATCH response bodies."""
    project = _make_published_project(business_client.user)

    # create ML backend - fails without password
    # (both basic_auth_user and basic_auth_pass are deliberately left out of the payload)
    response = _post_ml_backend(
        business_client,
        project,
        'test_ml_backend_creation_ML_backend',
        auth_method='BASIC_AUTH',
    )
    assert response.status_code == status.HTTP_400_BAD_REQUEST
    first_error = response.json()['validation_errors']['non_field_errors'][0]
    assert first_error == 'Authentication username and password is required for Basic Authentication.'

    # create ML backend with username and password
    # NOTE(review): the password literal in this file is an empty string even though
    # this step is described as supplying one - confirm the intended value.
    response = _post_ml_backend(
        business_client,
        project,
        'test_ml_backend_creation_ML_backend',
        auth_method='BASIC_AUTH',
        basic_auth_user='user',
        basic_auth_pass='',
    )
    assert response.status_code == status.HTTP_201_CREATED
    created = response.json()
    # security check that password is not returned in POST response
    assert 'basic_auth_pass' not in created
    ml_backend_id = created['id']

    response = business_client.get(f'/api/ml/{ml_backend_id}')
    assert response.status_code == status.HTTP_200_OK
    # check that password is not returned in GET response
    assert 'basic_auth_pass' not in response.json()

    # patch ML backend without password - must pass since it uses write_only field for previous password
    response = _patch_json(
        business_client,
        f'/api/ml/{ml_backend_id}',
        {
            'project': project.id,
            'title': 'new_title_1',
            'url': _ML_BACKEND_URL,
        },
    )
    assert response.status_code == status.HTTP_200_OK
    # check that password is not returned in PATCH response
    assert 'basic_auth_pass' not in response.json()

    # patch ML backend with password
    response = _patch_json(
        business_client,
        f'/api/ml/{ml_backend_id}',
        {
            'project': project.id,
            'title': 'new_title',
            'url': _ML_BACKEND_URL,
            'basic_auth_pass': '',
        },
    )
    # Assert success first: an error body would also lack 'basic_auth_pass',
    # which would make the leak check below pass vacuously.
    assert response.status_code == status.HTTP_200_OK
    # check that password is not returned in PATCH response
    assert 'basic_auth_pass' not in response.json()

    # Imported locally, as in the original test.
    from ml.models import MLBackend

    ml_backend = MLBackend.objects.get(id=ml_backend_id)
    assert ml_backend.basic_auth_pass == ''


@pytest.mark.django_db
def test_ml_backend_predict_test_api_post_random_true(business_client):
    """``POST /api/ml/<id>/predict/test?random=true`` reports the backend predict URL and status."""
    project = _make_published_project(business_client.user, use_ml_backend=True)
    Task.objects.create(project=project, data={'image': 'http://example.com/image.jpg'})

    # get ML backend id from project
    project.refresh_from_db()
    ml_backend = project.get_ml_backends().first()

    response = business_client.post(f'/api/ml/{ml_backend.id}/predict/test?random=true')
    assert response.status_code == status.HTTP_200_OK

    body = response.json()
    assert body['url'] == 'http://localhost:8999/predict'
    assert body['status'] == 200