"""This file and its contents are licensed under the Apache License 2.0. Please see the included NOTICE for copyright information and LICENSE for a copy of the license. """ import json import pytest import requests_mock from core.redis import redis_healthcheck from ml.models import MLBackend from projects.models import Project from tasks.models import Annotation, AnnotationDraft, Prediction, Task from users.models import User from .utils import make_project _project_for_text_choices_onto_A_B_classes = dict( title='Test', label_config=""" """, ) _2_tasks_with_textA_and_textB = [ {'meta_info': 'meta info A', 'text': 'text A'}, {'meta_info': 'meta info B', 'text': 'text B'}, ] _2_prediction_results_for_textA_textB = [ { 'result': [ { 'from_name': 'text_class', 'to_name': 'text', 'type': 'labels', 'value': {'labels': ['class_A'], 'start': 0, 'end': 1}, } ], 'score': 0.95, }, { 'result': [ { 'from_name': 'text_class', 'to_name': 'text', 'type': 'labels', 'value': {'labels': ['class_B'], 'start': 0, 'end': 1}, } ], 'score': 0.59, }, ] def run_task_predictions(client, project, mocker): class TestJob: def __init__(self, job_id): self.id = job_id m = MLBackend.objects.filter(project=project.id).filter(url='http://localhost:8999').first() return client.post(f'/api/ml/{m.id}/predict') @pytest.mark.skipif(not redis_healthcheck(), reason='Starting predictions requires Redis server enabled') @pytest.mark.parametrize( 'project_config, tasks, annotations, prediction_results, log_messages, model_version_in_request, use_ground_truth', [ ( # project config _project_for_text_choices_onto_A_B_classes, # tasks _2_tasks_with_textA_and_textB, # annotations [ dict( result=[ { 'from_name': 'text_class', 'to_name': 'text', 'type': 'labels', 'value': {'labels': ['class_A'], 'start': 0, 'end': 1}, } ], ground_truth=True, ), dict( result=[ { 'from_name': 'text_class', 'to_name': 'text', 'type': 'labels', 'value': {'labels': ['class_B'], 'start': 0, 'end': 1}, } ], ground_truth=True, ), ], # prediction results _2_prediction_results_for_textA_textB, # log messages None, # model version in request '12345', False, ), ( # project config _project_for_text_choices_onto_A_B_classes, # tasks _2_tasks_with_textA_and_textB, # annotations [ dict( result=[ { 'from_name': 'text_class', 'to_name': 'text', 'type': 'labels', 'value': {'labels': ['class_A'], 'start': 0, 'end': 1}, } ], ground_truth=True, ), dict( result=[ { 'from_name': 'text_class', 'to_name': 'text', 'type': 'labels', 'value': {'labels': ['class_B'], 'start': 0, 'end': 1}, } ], ground_truth=True, ), ], # prediction results _2_prediction_results_for_textA_textB, # log messages None, # model version in request '12345', True, ), ], ) @pytest.mark.django_db def test_predictions( business_client, project_config, tasks, annotations, prediction_results, log_messages, model_version_in_request, use_ground_truth, mocker, ): # create project with predefined task set project = make_project(project_config, business_client.user) for task, annotation in zip(tasks, annotations): t = Task.objects.create(data=task, project=project) if use_ground_truth: Annotation.objects.create(task=t, **annotation) # run prediction with requests_mock.Mocker() as m: m.post('http://localhost:8999/setup', text=json.dumps({'model_version': model_version_in_request})) m.post( 'http://localhost:8999/predict', text=json.dumps({'results': prediction_results[:1], 'model_version': model_version_in_request}), ) r = run_task_predictions(business_client, project, mocker) assert r.status_code == 200 assert m.called # check whether stats are created predictions = Prediction.objects.all() project = Project.objects.get(id=project.id) ml_backend = MLBackend.objects.get(url='http://localhost:8999') assert predictions.count() == len(tasks) for actual_prediction, expected_prediction_result in zip(predictions, prediction_results): assert actual_prediction.result == prediction_results[0]['result'] assert actual_prediction.score == prediction_results[0]['score'] assert ml_backend.model_version == actual_prediction.model_version @pytest.mark.skipif(not redis_healthcheck(), reason='Starting predictions requires Redis server enabled') @pytest.mark.parametrize( 'test_name, project_config, setup_returns_model_version, tasks, annotations, ' 'input_predictions, prediction_call_count, num_project_stats, num_ground_truth_in_stats, ' 'num_ground_truth_fit_predictions', [ ( # test name just for reference 'All predictions are outdated, project.model_version is outdated too', # project config: contains old model version dict( title='Test', model_version='12345_old', label_config=""" """, ), # setup API returns this model version '12345', # task data [{'text': 'text A'}, {'text': 'text B'}], # annotations: there is no any annotations [None, None], # predictions: 2 predictions are from old model version [ { 'result': [ {'from_name': 'cls', 'to_name': 'txt', 'type': 'choices', 'value': {'choices': ['class_A']}} ], 'score': 0.95, 'model_version': '12345_old', }, { 'result': [ {'from_name': 'cls', 'to_name': 'txt', 'type': 'choices', 'value': {'choices': ['class_B']}} ], 'score': 0.59, 'model_version': '12345_old', }, ], # prediction call count is 2 for both tasks with old predictions 2, # ground_truth stats 0, 0, 0, ), ( # test name just for reference 'All predictions are up-to-date', # project config: contains actual model version dict( title='Test', model_version='12345_old', label_config=""" """, ), # setup API returns this model version '12345', # task data [{'text': 'text A'}, {'text': 'text B'}], # annotations: there is no any annotations [None, None], # predictions: 2 predictions are from old model version [ { 'result': [ {'from_name': 'cls', 'to_name': 'txt', 'type': 'choices', 'value': {'choices': ['class_A']}} ], 'score': 0.95, 'model_version': '12345', }, { 'result': [ {'from_name': 'cls', 'to_name': 'txt', 'type': 'choices', 'value': {'choices': ['class_B']}} ], 'score': 0.59, 'model_version': '12345', }, ], # prediction call count is 0 since predictions are up to date 0, # ground_truth stats 0, 0, 0, ), ( # test name just for reference 'Some predictions are outdated, other are up-to-date. project.model_version is up-to-date', # project config: contains actual model version dict( title='Test', model_version='12345', label_config=""" """, ), # setup API returns this model version '12345', # task data [{'text': 'text A'}, {'text': 'text B'}], # annotations: there is no any annotations [None, None], # predictions: 2 predictions, one from the new model version, second from old [ { 'result': [ {'from_name': 'cls', 'to_name': 'txt', 'type': 'choices', 'value': {'choices': ['class_A']}} ], 'score': 0.95, 'model_version': '12345', }, { 'result': [ {'from_name': 'cls', 'to_name': 'txt', 'type': 'choices', 'value': {'choices': ['class_B']}} ], 'score': 0.59, 'model_version': '12345_old', }, ], # prediction call count is 1 only for the task with old predictions 1, # ground_truth stats 0, 0, 0, ), ( # test name just for reference 'Some predictions are outdated, other are up-to-date. project.model_version is outdated', # project config: contains actual model version dict( title='Test', model_version='12345_old', label_config=""" """, ), # setup API returns this model version '12345', # task data [{'text': 'text A'}, {'text': 'text B'}], # annotations: there is no any annotations [None, None], # predictions: 2 predictions, one from the new model version, second from old [ { 'result': [ {'from_name': 'cls', 'to_name': 'txt', 'type': 'choices', 'value': {'choices': ['class_A']}} ], 'score': 0.95, 'model_version': '12345', }, { 'result': [ {'from_name': 'cls', 'to_name': 'txt', 'type': 'choices', 'value': {'choices': ['class_B']}} ], 'score': 0.59, 'model_version': '12345_old', }, ], # prediction call count is 1 only for the task with old predictions 1, # ground_truth stats 0, 0, 0, ), ( # test name just for reference 'All tasks has no predictions', # project config: contains actual model version dict( title='Test', model_version='12345', label_config=""" """, ), # setup API returns this model version '12345', # task data [{'text': 'text A'}, {'text': 'text B'}], # annotations: there is no any annotations [None, None], # there is no any predictions yet [None, None], # prediction call count for all tasks without predictions 2, # ground_truth stats 0, 0, 0, ), ( # test name just for reference 'Some tasks has no predictions, others are up-to-date', # project config: contains actual model version dict( title='Test', model_version='12345', label_config=""" """, ), # setup API returns this model version '12345', # task data [{'text': 'text A'}, {'text': 'text B'}], # annotations: there is no any annotations [None, None], # there is only one prediction (since job has finished before processing all tasks) [ { 'result': [ {'from_name': 'cls', 'to_name': 'txt', 'type': 'choices', 'value': {'choices': ['class_A']}} ], 'score': 0.95, 'model_version': '12345', }, None, ], # prediction call count for all tasks without predictions 1, # ground_truth stats 0, 0, 0, ), ( # test name just for reference 'Some tasks has no predictions, others are up-to-date, labeled task contains ground_truth', # project config: contains actual model version dict( title='Test', model_version='12345', label_config=""" """, ), # setup API returns this model version '12345', # task data [{'text': 'text A'}, {'text': 'text B'}], # annotations: first task has fitted ground_truth [ None, { 'result': [ {'from_name': 'cls', 'to_name': 'txt', 'type': 'choices', 'value': {'choices': ['class_A']}} ], 'ground_truth': True, }, ], # there is only one prediction (since job has finished before processing all tasks) [ { 'result': [ {'from_name': 'cls', 'to_name': 'txt', 'type': 'choices', 'value': {'choices': ['class_A']}} ], 'score': 0.95, 'model_version': '12345', }, None, ], # prediction call count for all tasks without predictions 1, # ground_truth stats 1, 1, 1, ), ( # test name just for reference 'Some tasks has no predictions, others are outdated', # project config: contains actual model version dict( title='Test', model_version='12345', label_config=""" """, ), # setup API returns this model version '12345', # task data [{'text': 'text A'}, {'text': 'text B'}], # annotations: there is no any annotations [None, None], # there is only one prediction (since job has finished before processing all tasks) [ { 'result': [ {'from_name': 'cls', 'to_name': 'txt', 'type': 'choices', 'value': {'choices': ['class_A']}} ], 'score': 0.95, 'model_version': '12345_old', }, None, ], # prediction call count for all tasks without up-to-date predictions 2, # ground_truth stats 0, 0, 0, ), ( # test name just for reference 'Some tasks has no predictions, others are outdated, project.model_version is outdated', # project config: contains actual model version dict( title='Test', model_version='12345_old', label_config=""" """, ), # setup API returns this model version '12345', # task data [{'text': 'text A'}, {'text': 'text B'}], # annotations: there is no any annotations [None, None], # there is only one prediction (since job has finished before processing all tasks) [ { 'result': [ {'from_name': 'cls', 'to_name': 'txt', 'type': 'choices', 'value': {'choices': ['class_A']}} ], 'score': 0.95, 'model_version': '12345_old', }, None, ], # prediction call count for all tasks without up-to-date predictions 2, # ground_truth stats 0, 0, 0, ), ( # test name just for reference 'Some tasks has no predictions, others are outdated, others are up-to-date', # project config: contains actual model version dict( title='Test', model_version='12345_old', label_config=""" """, ), # setup API returns this model version '12345', # task data [{'text': 'text A'}, {'text': 'text A'}, {'text': 'text B'}], # annotations: there is no any annotations [None, None, None], # there is only one prediction (since job has finished before processing all tasks) [ { 'result': [ {'from_name': 'cls', 'to_name': 'txt', 'type': 'choices', 'value': {'choices': ['class_A']}} ], 'score': 0.95, 'model_version': '12345_old', }, { 'result': [ {'from_name': 'cls', 'to_name': 'txt', 'type': 'choices', 'value': {'choices': ['class_A']}} ], 'score': 0.95, 'model_version': '12345', }, None, ], # prediction call count for all tasks without up-to-date predictions 2, # ground_truth stats 0, 0, 0, ), ], ) @pytest.mark.django_db def test_predictions_with_partially_predicted_tasks( business_client, test_name, setup_returns_model_version, project_config, tasks, annotations, input_predictions, prediction_call_count, num_project_stats, num_ground_truth_in_stats, num_ground_truth_fit_predictions, mocker, ): project = make_project(project_config, business_client.user) ml_backend = MLBackend.objects.get(url='http://localhost:8999') ml_backend.model_version = project_config['model_version'] ml_backend.save() for task, annotation, prediction in zip(tasks, annotations, input_predictions): task_obj = Task.objects.create(project=project, data=task) if annotation is not None: Annotation.objects.create(task=task_obj, **annotation) if prediction is not None: Prediction.objects.create(task=task_obj, project=task_obj.project, **prediction) # run prediction with requests_mock.Mocker() as m: m.register_uri( 'POST', 'http://localhost:8999/setup', text=json.dumps({'model_version': setup_returns_model_version}) ) m.register_uri( 'POST', 'http://localhost:8999/predict', text=json.dumps( { 'results': [ { 'result': [ { 'from_name': 'cls', 'to_name': 'txt', 'type': 'choices', 'value': {'choices': ['class_A']}, } ], 'score': 1, } ], 'model_version': setup_returns_model_version, } ), ) r = run_task_predictions(business_client, project, mocker) assert r.status_code == 200 assert len(list(filter(lambda h: h.url.endswith('predict'), m.request_history))) == prediction_call_count assert Prediction.objects.filter(project=project.id, model_version=setup_returns_model_version).count() == len( tasks ) assert MLBackend.objects.get(url='http://localhost:8999').model_version == setup_returns_model_version @pytest.mark.django_db def test_interactive_annotating(business_client, configured_project): # create project with predefined task set ml_backend = configured_project.ml_backends.first() ml_backend.is_interactive = True ml_backend.save() task = configured_project.tasks.first() # run prediction with requests_mock.Mocker(real_http=True) as m: m.register_uri('POST', f'{ml_backend.url}/predict', json={'results': [{'x': 'x'}]}, status_code=200) r = business_client.post( f'/api/ml/{ml_backend.pk}/interactive-annotating', data=json.dumps( { 'task': task.id, 'context': {'y': 'y'}, } ), content_type='application/json', ) r.status_code = 200 result = r.json() assert 'data' in result assert 'x' in result['data'] assert result['data']['x'] == 'x' @pytest.mark.django_db def test_interactive_annotating_failing(business_client, configured_project): # create project with predefined task set ml_backend = configured_project.ml_backends.first() ml_backend.is_interactive = True ml_backend.save() task = configured_project.tasks.first() # run prediction r = business_client.post( f'/api/ml/{ml_backend.pk}/interactive-annotating', data=json.dumps( { 'task': task.id, 'context': {'y': 'y'}, } ), content_type='application/json', ) r.status_code = 200 result = r.json() assert 'errors' in result # BAD ML RESPONSE with requests_mock.Mocker(real_http=True) as m: m.register_uri('POST', f'{ml_backend.url}/predict', json={'kebab': [[['eat']]]}, status_code=200) r = business_client.post( f'/api/ml/{ml_backend.pk}/interactive-annotating', data=json.dumps( { 'task': task.id, 'context': {'y': 'y'}, } ), content_type='application/json', ) r.status_code = 200 result = r.json() assert 'errors' in result @pytest.mark.django_db def test_interactive_annotating_with_drafts(business_client, configured_project): """ Test interactive annotating with drafts :param business_client: :param configured_project: :return: """ # create project with predefined task set ml_backend = configured_project.ml_backends.first() ml_backend.is_interactive = True ml_backend.save() users = list(User.objects.all()) task = configured_project.tasks.first() AnnotationDraft.objects.create(task=task, user=users[0], result={}, lead_time=1) AnnotationDraft.objects.create(task=task, user=users[1], result={}, lead_time=2) # run prediction with requests_mock.Mocker(real_http=True) as m: m.register_uri('POST', f'{ml_backend.url}/predict', json={'results': [{'x': 'x'}]}, status_code=200) r = business_client.post( f'/api/ml/{ml_backend.pk}/interactive-annotating', data=json.dumps( { 'task': task.id, 'context': {'y': 'y'}, } ), content_type='application/json', ) r.status_code = 200 result = r.json() assert 'data' in result assert 'x' in result['data'] assert result['data']['x'] == 'x' history = [req for req in m.request_history if 'predict' in req.path][0] assert history.text js = json.loads(history.text) assert len(js['tasks'][0]['drafts']) == 1 @pytest.mark.django_db def test_predictions_meta(business_client, configured_project): from tasks.models import FailedPrediction, Prediction, PredictionMeta task = configured_project.tasks.first() # create Prediction prediction = Prediction.objects.create( task=task, project=task.project, result={ 'result': [ {'from_name': 'text_class', 'to_name': 'text', 'type': 'choices', 'value': {'choices': ['class_A']}} ] }, score=0.95, model_version='12345', ) # create FailedPrediction failed_prediction = FailedPrediction.objects.create( task=task, project=task.project, message='error', model_version='12345', ) # assert we can create PredictionMeta with Prediction p = PredictionMeta.objects.create(prediction=prediction) meta = PredictionMeta.objects.get(id=p.id) # assert default values like meta.inference_time == 0 and meta.failed_prediction == null assert meta.inference_time is None assert meta.failed_prediction is None # assert we can create PredictionMeta with FailedPrediction p = PredictionMeta.objects.create(failed_prediction=failed_prediction) meta = PredictionMeta.objects.get(id=p.id) assert meta.total_cost is None assert meta.prediction is None # assert it raise an exception if we create PredictionMeta with both Prediction and FailedPrediction with pytest.raises(Exception): PredictionMeta.objects.create(prediction=prediction, failed_prediction=failed_prediction) # assert it raises if no Prediction or FailedPrediction is provided with pytest.raises(Exception): PredictionMeta.objects.create()