"""This file and its contents are licensed under the Apache License 2.0. Please see the included NOTICE for copyright information and LICENSE for a copy of the license.
|
"""
|
import json
|
|
import pytest
|
import requests_mock
|
from django.apps import apps
|
from django.urls import reverse
|
from projects.models import Project
|
from tasks.models import Annotation, Task
|
|
from .utils import _client_is_annotator, invite_client_to_project
|
|
|
@pytest.fixture
|
def configured_project_min_annotations_1(configured_project):
|
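    """configured_project with min_annotations_to_start_training lowered to 1."""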
    p = Project.objects.get(id=configured_project.id)
    p.min_annotations_to_start_training = 1
    # p.agreement_method = p.SINGLE
    p.save()
    return p


@pytest.mark.django_db
@pytest.mark.parametrize(
    'result, logtext, ml_upload_called',
    [
        (
            json.dumps(
                [
                    {
                        'from_name': 'text_class',
                        'to_name': 'text',
                        'type': 'labels',
                        'value': {'labels': ['class_A'], 'start': 0, 'end': 1},
                    }
                ]
            ),
            None,
            True,
        ),
        (json.dumps([]), None, True),
    ],
)
def test_create_annotation(
    caplog, any_client, configured_project_min_annotations_1, result, logtext, ml_upload_called
):
    task = Task.objects.first()
    if _client_is_annotator(any_client):
        assert invite_client_to_project(any_client, task.project).status_code == 200
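    # intercept outgoing calls to the /train and /webhook endpoints on localhost:8999,
    # so creating the annotation below doesn't require a live ML backend or webhook receiver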
    with requests_mock.Mocker() as m:
        m.post('http://localhost:8999/train')
        m.post('http://localhost:8999/webhook')
        annotation = {'result': result, 'task': task.id, 'lead_time': 2.54}
        r = any_client.post(f'/api/tasks/{task.id}/annotations/', data=annotation)
        # check that submitting valid data for the task annotation marks the task as labeled
        task.refresh_from_db()
        assert task.is_labeled is True
        assert r.status_code == 201
        annotation = Annotation.objects.all()
        assert annotation.count() == 1
        annotation = annotation.first()
        assert annotation.task.id == task.id
        # annotator client
        if hasattr(any_client, 'annotator') and any_client.annotator is not None:
            assert annotation.completed_by.id == any_client.user.id
            assert annotation.updated_by.id == any_client.user.id
        # business client
        else:
            assert annotation.completed_by.id == any_client.business.admin.id
            assert annotation.updated_by.id == any_client.business.admin.id

    if apps.is_installed('businesses'):
        assert annotation.task.accuracy == 1.0

    if logtext:
        assert logtext in caplog.text


@pytest.mark.django_db
def test_create_annotation_with_ground_truth(caplog, any_client, configured_project_min_annotations_1):

    task = Task.objects.first()
    client_is_annotator = _client_is_annotator(any_client)
    if client_is_annotator:
        assert invite_client_to_project(any_client, task.project).status_code == 200
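    # the webhook is expected to fire only for non-annotator (business/admin) clients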
    webhook_called = not client_is_annotator
    ground_truth = {
        'task': task.id,
        'result': json.dumps(
            [{'from_name': 'text_class', 'to_name': 'text', 'value': {'labels': ['class_A'], 'start': 0, 'end': 1}}]
        ),
        'ground_truth': True,
    }

    annotation = {
        'task': task.id,
        'result': json.dumps(
            [{'from_name': 'text_class', 'to_name': 'text', 'value': {'labels': ['class_B'], 'start': 0, 'end': 1}}]
        ),
    }
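    # m.called reports whether any of the mocked endpoints received a request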
    with requests_mock.Mocker() as m:
        m.post('http://localhost:8999/webhook')
        m.post('http://localhost:8999/train')

        # ground_truth doesn't affect statistics or the ML backend; the webhook is still called for admin accounts
        r = any_client.post('/api/tasks/{}/annotations/'.format(task.id), data=ground_truth)
        assert r.status_code == 201
        assert m.called == webhook_called

        # a real annotation triggers upload to the ML backend and recalculation of accuracy
        r = any_client.post('/api/tasks/{}/annotations/'.format(task.id), data=annotation)
        assert r.status_code == 201
        assert m.called
        task = Task.objects.get(id=task.id)
        assert task.annotations.count() == 2
        annotations = Annotation.objects.filter(task=task)
        for a in annotations:
            assert a.updated_by.id == any_client.user.id


@pytest.mark.django_db
def test_delete_annotation(business_client, configured_project):
    task = Task.objects.first()
    annotation = Annotation.objects.create(task=task, project=configured_project, result=[])
    assert task.annotations.count() == 1
    r = business_client.delete('/api/annotations/{}/'.format(annotation.id))
    assert r.status_code == 204
    assert task.annotations.count() == 0


@pytest.fixture
def annotations():
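    """Annotation payloads for the first task: two conflicting labels and an empty result."""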
    task = Task.objects.first()
    return {
        'class_A': {
            'task': task.id,
            'result': json.dumps(
                [
                    {
                        'from_name': 'text_class',
                        'to_name': 'text',
                        'type': 'labels',
                        'value': {'labels': ['class_A'], 'start': 0, 'end': 10},
                    }
                ]
            ),
        },
        'class_B': {
            'task': task.id,
            'result': json.dumps(
                [
                    {
                        'from_name': 'text_class',
                        'to_name': 'text',
                        'type': 'labels',
                        'value': {'labels': ['class_B'], 'start': 0, 'end': 10},
                    }
                ]
            ),
        },
        'empty': {'task': task.id, 'result': json.dumps([])},
    }


@pytest.fixture
def project_with_max_annotations_2(configured_project):
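    """configured_project with the required annotation overlap (maximum_annotations) raised to 2."""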
    configured_project.maximum_annotations = 2
    # configured_project.agreement_method = Project.SINGLE
    configured_project.save()


@pytest.mark.parametrize(
    'annotations_sequence, accuracy, is_labeled',
    [
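        # each case: a sequence of (annotation_key, client_key) submissions, the expected
        # accuracy, and whether the task should end up labeled (only is_labeled is asserted below)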
        ([], None, False),
        ([('class_A', 'business')], 1, False),
        ([('class_A', 'annotator')], 1, False),
        # Same user twice doesn't meet the overlap requirement for distinct annotators
        ([('class_A', 'business'), ('class_A', 'business')], 1, False),
        ([('class_A', 'business'), ('class_A', 'annotator')], 1, True),
        ([('class_A', 'annotator'), ('class_A', 'business')], 1, True),
        # Same user twice doesn't meet the overlap requirement for distinct annotators
        ([('class_A', 'business'), ('class_B', 'business')], 0.5, False),
        ([('class_A', 'business'), ('class_B', 'annotator')], 0.5, True),
        ([('class_A', 'annotator'), ('class_B', 'business')], 0.5, True),
        ([('empty', 'annotator'), ('empty', 'business')], 1, True),
        ([('class_A', 'annotator'), ('empty', 'business')], 0.5, True),
    ],
)
@pytest.mark.django_db
def test_accuracy(
    business_client,
    annotator_client,
    project_with_max_annotations_2,
    annotations,
    annotations_sequence,
    accuracy,
    is_labeled,
):
    client = {'business': business_client, 'annotator': annotator_client}
    task_id = next(iter(annotations.values()))['task']
    task = Task.objects.get(id=task_id)
    invite_client_to_project(annotator_client, task.project)

    for annotation_key, client_key in annotations_sequence:
        r = client[client_key].post(
            reverse('tasks:api:task-annotations', kwargs={'pk': task_id}), data=annotations[annotation_key]
        )
        assert r.status_code == 201
    task = Task.objects.get(id=task_id)
    assert task.is_labeled == is_labeled


@pytest.mark.django_db
def test_accuracy_on_delete(business_client, annotator_client, project_with_max_annotations_2, annotations):
    task_id = next(iter(annotations.values()))['task']
    task = Task.objects.get(id=task_id)
    invite_client_to_project(annotator_client, task.project)

    # Create annotations from two different users to meet overlap=2 with distinct annotators
    annotation_list = list(annotations.values())
    business_client.post(reverse('tasks:api:task-annotations', kwargs={'pk': task_id}), data=annotation_list[0])
    business_client.post(reverse('tasks:api:task-annotations', kwargs={'pk': task_id}), data=annotation_list[1])
    annotator_client.post(reverse('tasks:api:task-annotations', kwargs={'pk': task_id}), data=annotation_list[2])

    task = Task.objects.get(id=task_id)
    assert task.annotations.count() == len(annotations)
    assert task.is_labeled  # 2 distinct annotators >= overlap of 2

    annotation_ids = [c.id for c in task.annotations.all()]
    # Delete one of business_client's annotations; 2 distinct annotators remain
    r = business_client.delete(reverse('tasks:api-annotations:annotation-detail', kwargs={'pk': annotation_ids[0]}))
    assert r.status_code == 204
    task = Task.objects.get(id=task_id)
    assert task.is_labeled  # still 2 distinct annotators (business + annotator)

    # Delete the annotator's annotation; only 1 distinct annotator remains
    r = business_client.delete(reverse('tasks:api-annotations:annotation-detail', kwargs={'pk': annotation_ids[2]}))
    assert r.status_code == 204
    task = Task.objects.get(id=task_id)
    assert not task.is_labeled  # only 1 distinct annotator < overlap of 2

    # Delete the last annotation
    r = business_client.delete(reverse('tasks:api-annotations:annotation-detail', kwargs={'pk': annotation_ids[1]}))
    assert r.status_code == 204
    task = Task.objects.get(id=task_id)
    assert not task.is_labeled
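# NOTE: an earlier variant of test_accuracy_on_delete is kept commented out below for reference;
# it additionally asserted task.accuracy when the optional 'businesses' app is installed.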
# @pytest.mark.django_db
# def test_accuracy_on_delete(business_client, project_with_max_annotations_2, annotations):
#     task_id = next(iter(annotations.values()))['task']
#     for annotation in annotations.values():
#         business_client.post(reverse('tasks:api:task-annotations', kwargs={'pk': task_id}), data=annotation)
#
#     task = Task.objects.get(id=task_id)
#     assert task.annotations.count() == len(annotations)
#     if apps.is_installed('businesses'):
#         assert math.fabs(task.accuracy - 1 / 3) < 0.00001
#     assert task.is_labeled
#     annotation_ids = [c.id for c in task.annotations.all()]
#     r = business_client.delete(reverse('tasks:api-annotations:annotation-detail', kwargs={'pk': annotation_ids[0]}))
#     assert r.status_code == 204
#     task = Task.objects.get(id=task_id)
#     if apps.is_installed('businesses'):
#         assert task.accuracy == 0.5
#     assert task.is_labeled  # project.max_annotations = 2
#
#     r = business_client.delete(reverse('tasks:api-annotations:annotation-detail', kwargs={'pk': annotation_ids[1]}))
#     assert r.status_code == 204
#     task = Task.objects.get(id=task_id)
#     if apps.is_installed('businesses'):
#         assert task.accuracy == 1.0
#     assert not task.is_labeled
#
#     r = business_client.delete(reverse('tasks:api-annotations:annotation-detail', kwargs={'pk': annotation_ids[2]}))
#     assert r.status_code == 204
#     task = Task.objects.get(id=task_id)
#     if apps.is_installed('businesses'):
#         assert task.accuracy is None
#     assert not task.is_labeled