import json
import pytest
from label_studio.tests.utils import make_project, make_task
@pytest.mark.django_db
def test_get_single_prediction_on_task(business_client, ml_backend_for_test_predict):
    """Check that the next task carries the single prediction of a connected ML backend.

    Creates a published project with one task, registers an ML backend titled
    'ModelSingle' via POST /api/ml/, then fetches the project's next task and
    verifies that it holds exactly one prediction with the expected choice and
    model version.

    NOTE(review): ``ml_backend_for_test_predict`` is requested but never
    referenced in the body; presumably it stubs the backend URL used below —
    confirm against the fixture definition.
    """
    project = make_project(
        config=dict(
            is_published=True,
            # NOTE(review): the label config body looks empty here, while the
            # assertions below read a 'choices' value — confirm the intended config.
            label_config="""
""",
            title='test_get_single_prediction_on_task',
        ),
        user=business_client.user,
        use_ml_backend=False,
    )
    make_task({'data': {'text': 'test 1'}}, project)

    # setup ML backend with single prediction per task
    response = business_client.post(
        '/api/ml/',
        data={
            'project': project.id,
            'title': 'ModelSingle',
            'url': 'http://test.ml.backend.for.sdk.com:9092',
        },
    )
    assert response.status_code == 201

    # get next task
    response = business_client.get(f'/api/projects/{project.id}/next')
    # Check the HTTP status first so a failed request is reported as such,
    # not as a KeyError while indexing the payload below.
    assert response.status_code == 200
    payload = json.loads(response.content)

    # ensure task has a single prediction with the correct value
    predictions = payload['predictions']
    assert len(predictions) == 1
    prediction = predictions[0]
    assert prediction['result'][0]['value']['choices'][0] == 'label_A'
    assert prediction['model_version'] == 'ModelSingle'
@pytest.mark.django_db
def test_get_multiple_predictions_on_task(business_client, ml_backend_for_test_predict):
    """Check that the next task carries every prediction returned by the ML backend.

    Creates a published project with one task, registers an ML backend titled
    'ModelA' via POST /api/ml/, then fetches the project's next task and
    verifies that it holds two predictions, in order: 'label_A' from 'ModelA'
    and 'label_B' from 'ModelB'.

    NOTE(review): ``ml_backend_for_test_predict`` is requested but never
    referenced in the body; presumably it stubs the backend URL used below —
    confirm against the fixture definition.
    """
    project = make_project(
        config=dict(
            is_published=True,
            # NOTE(review): the label config body looks empty here, while the
            # assertions below read a 'choices' value — confirm the intended config.
            label_config="""
""",
            title='test_get_multiple_predictions_on_task',
        ),
        user=business_client.user,
        use_ml_backend=False,
    )
    make_task({'data': {'text': 'test 1'}}, project)

    # setup ML backend with multiple predictions per task
    response = business_client.post(
        '/api/ml/',
        data={
            'project': project.id,
            'title': 'ModelA',
            'url': 'http://test.ml.backend.for.sdk.com:9093',
        },
    )
    assert response.status_code == 201

    # get next task
    response = business_client.get(f'/api/projects/{project.id}/next')
    # Check the HTTP status first so a failed request is reported as such,
    # not as a KeyError while indexing the payload below.
    assert response.status_code == 200
    payload = json.loads(response.content)

    # ensure task has multiple predictions with the correct values
    predictions = payload['predictions']
    assert len(predictions) == 2
    # Compare (choice, model_version) pairs in one ordered assertion so a
    # failure shows every observed value at once.
    observed = [
        (prediction['result'][0]['value']['choices'][0], prediction['model_version'])
        for prediction in predictions
    ]
    assert observed == [('label_A', 'ModelA'), ('label_B', 'ModelB')]