import json
|
|
import pytest
|
|
from label_studio.tests.utils import make_project, make_task
|
|
|
@pytest.mark.django_db
|
def test_get_single_prediction_on_task(business_client, ml_backend_for_test_predict):
|
project = make_project(
|
config=dict(
|
is_published=True,
|
label_config="""
|
<View>
|
<Text name="text" value="$text"></Text>
|
<Choices name="label" choice="single" toName="text">
|
<Choice value="label_A"></Choice>
|
<Choice value="label_B"></Choice>
|
</Choices>
|
</View>""",
|
title='test_get_single_prediction_on_task',
|
),
|
user=business_client.user,
|
use_ml_backend=False,
|
)
|
|
make_task({'data': {'text': 'test 1'}}, project)
|
|
# setup ML backend with single prediction per task
|
response = business_client.post(
|
'/api/ml/',
|
data={
|
'project': project.id,
|
'title': 'ModelSingle',
|
'url': 'http://test.ml.backend.for.sdk.com:9092',
|
},
|
)
|
assert response.status_code == 201
|
|
# get next task
|
response = business_client.get(f'/api/projects/{project.id}/next')
|
payload = json.loads(response.content)
|
|
# ensure task has a single prediction with the correct value
|
assert len(payload['predictions']) == 1
|
assert payload['predictions'][0]['result'][0]['value']['choices'][0] == 'label_A'
|
assert payload['predictions'][0]['model_version'] == 'ModelSingle'
|
|
|
@pytest.mark.django_db
|
def test_get_multiple_predictions_on_task(business_client, ml_backend_for_test_predict):
|
project = make_project(
|
config=dict(
|
is_published=True,
|
label_config="""
|
<View>
|
<Text name="text" value="$text"></Text>
|
<Choices name="label" choice="single" toName="text">
|
<Choice value="label_A"></Choice>
|
<Choice value="label_B"></Choice>
|
</Choices>
|
</View>""",
|
title='test_get_multiple_predictions_on_task',
|
),
|
user=business_client.user,
|
use_ml_backend=False,
|
)
|
|
make_task({'data': {'text': 'test 1'}}, project)
|
|
# setup ML backend with multiple predictions per task
|
response = business_client.post(
|
'/api/ml/',
|
data={
|
'project': project.id,
|
'title': 'ModelA',
|
'url': 'http://test.ml.backend.for.sdk.com:9093',
|
},
|
)
|
assert response.status_code == 201
|
|
# get next task
|
response = business_client.get(f'/api/projects/{project.id}/next')
|
payload = json.loads(response.content)
|
|
# ensure task has multiple predictions with the correct values
|
assert len(payload['predictions']) == 2
|
assert payload['predictions'][0]['result'][0]['value']['choices'][0] == 'label_A'
|
assert payload['predictions'][0]['model_version'] == 'ModelA'
|
assert payload['predictions'][1]['result'][0]['value']['choices'][0] == 'label_B'
|
assert payload['predictions'][1]['model_version'] == 'ModelB'
|