import json
from unittest.mock import patch

import projects.api
import pytest
from django.test import TestCase
from django.urls import reverse
from projects.tests.factories import ProjectFactory
from rest_framework.test import APIClient


@pytest.mark.django_db
class TestProjectSampleTask(TestCase):
    """Tests for the ``projects:api:project-sample-task`` endpoint.

    Every test posts a JSON payload as the project's creator and patches the
    ``LabelInterface`` / ``Project`` attributes reached through
    ``projects.api`` to control which sample task the view can obtain.
    """

    # NOTE(review): the config is whitespace-only as written; confirm that no
    # labeling markup was intended here.
    LABEL_CONFIG = """ """

    @classmethod
    def setUpTestData(cls):
        """Create the single project shared by all tests in this class."""
        cls.project = ProjectFactory()

    @property
    def url(self):
        """Return the sample-task endpoint URL for the shared project."""
        return reverse('projects:api:project-sample-task', kwargs={'pk': self.project.id})

    def _post_sample_task(self, **options):
        """POST a sample-task request authenticated as the project's creator.

        Args:
            **options: Extra JSON fields sent next to ``label_config``
                (e.g. ``include_annotation_and_prediction``).

        Returns:
            The response object returned by ``APIClient.post``.
        """
        client = APIClient()
        client.force_authenticate(user=self.project.created_by)
        payload = {'label_config': self.LABEL_CONFIG, **options}
        return client.post(
            self.url,
            data=json.dumps(payload),
            content_type='application/json',
        )

    def test_sample_task_with_happy_path(self):
        """Test that ProjectSampleTask.post successfully creates a complete sample task with annotations and predictions"""
        user_id = self.project.created_by.id

        sample_prediction = {
            'model_version': 'sample model version',
            'result': [
                {
                    'id': 'abc123',
                    'from_name': 'sentiment',
                    'to_name': 'text',
                    'type': 'choices',
                    'value': {'choices': ['Positive']},
                }
            ],
            'score': 0.95,
        }
        sample_annotation = {
            'was_cancelled': False,
            'ground_truth': False,
            'result': [
                {
                    'id': 'def456',
                    'from_name': 'sentiment',
                    'to_name': 'text',
                    'type': 'choices',
                    'value': {'choices': ['Positive']},
                }
            ],
            'completed_by': -1,
        }
        sample_task = {
            'id': 1,
            'data': {'text': 'This is a sample task for labeling.'},
            'predictions': [sample_prediction],
            'annotations': [sample_annotation],
        }

        with patch.object(
            projects.api.LabelInterface,
            'generate_complete_sample_task',
            return_value=sample_task,
        ):
            response = self._post_sample_task(include_annotation_and_prediction=True)

        assert response.status_code == 200
        response_data = response.json()
        assert 'sample_task' in response_data

        # Build the expectation from fresh dicts: a shallow ``dict.copy()``
        # would share the nested annotation with the mock's return value, so
        # setting ``completed_by`` on it would also rewrite the fixture.
        sample_task_with_annotator_id_set = {
            **sample_task,
            'annotations': [{**sample_annotation, 'completed_by': user_id}],
        }
        assert response_data['sample_task'] == sample_task_with_annotator_id_set

    def test_sample_task_fallback_when_generate_task_fails(self):
        """Test fallback to project.get_sample_task when LabelInterface.generate_complete_sample_task fails"""
        fallback_data = {'id': 999, 'data': {'text': 'Fallback task'}}

        with (
            patch.object(
                projects.api.LabelInterface,
                'generate_complete_sample_task',
                side_effect=ValueError('Failed to generate sample task'),
            ),
            patch('projects.api.Project.get_sample_task', return_value=fallback_data),
        ):
            response = self._post_sample_task(include_annotation_and_prediction=True)

        assert response.status_code == 200
        response_data = response.json()
        assert 'sample_task' in response_data
        assert response_data['sample_task'] == fallback_data

    def test_sample_task_fallback_when_prediction_generation_fails(self):
        """Test fallback to project.get_sample_task when LabelInterface.generate_sample_prediction returns None"""
        fallback_data = {'id': 999, 'data': {'text': 'Fallback task'}}

        with (
            patch.object(
                projects.api.LabelInterface,
                'generate_sample_prediction',
                return_value=None,
            ),
            patch('projects.api.Project.get_sample_task', return_value=fallback_data),
        ):
            response = self._post_sample_task(include_annotation_and_prediction=True)

        assert response.status_code == 200
        response_data = response.json()
        assert 'sample_task' in response_data
        assert response_data['sample_task'] == fallback_data

    def test_sample_task_with_include_annotation_and_prediction_false(self):
        """Test that setting include_annotation_and_prediction=False bypasses LabelInterface.generate_complete_sample_task"""
        with (
            patch('projects.api.Project.get_sample_task', return_value=None) as mock_get_sample_task,
            patch.object(
                projects.api.LabelInterface, 'generate_complete_sample_task', return_value=None
            ) as mock_generate_complete,  # Shouldn't be called
        ):
            self._post_sample_task(include_annotation_and_prediction=False)

        mock_get_sample_task.assert_called_once()
        mock_generate_complete.assert_not_called()

    def test_sample_task_default_behavior(self):
        """Test that omitting include_annotation_and_prediction defaults to False and uses simple sample task"""
        with (
            patch('projects.api.Project.get_sample_task', return_value=None) as mock_get_sample_task,
            patch.object(
                projects.api.LabelInterface, 'generate_complete_sample_task', return_value=None
            ) as mock_generate_complete,  # Shouldn't be called
        ):
            self._post_sample_task()

        mock_get_sample_task.assert_called_once()
        mock_generate_complete.assert_not_called()