"""Tests for the cache_labels action.""" import pytest from data_manager.actions.cache_labels import cache_labels_job from django.contrib.auth import get_user_model from projects.models import Project from tasks.models import Annotation, Prediction, Task @pytest.mark.django_db @pytest.mark.parametrize( 'source, control_tag, with_counters, expected_cache_column, use_predictions', [ # Test case 1: Annotations, control tag 'ALL', with counters ('annotations', 'ALL', 'Yes', 'cache_all', False), # Test case 2: Annotations, specific control tag, with counters ('annotations', 'label', 'Yes', 'cache_label', False), # Test case 3: Annotations, control tag 'ALL', without counters ('annotations', 'ALL', 'No', 'cache_all', False), # Test case 4: Predictions, control tag 'ALL', with counters ('predictions', 'ALL', 'Yes', 'cache_predictions_all', True), ], ) def test_cache_labels_job(source, control_tag, with_counters, expected_cache_column, use_predictions): # Initialize a test user and project User = get_user_model() test_user = User.objects.create(username='test_user') project = Project.objects.create(title='Test Project', created_by=test_user) # Create a few tasks tasks = [] for i in range(3): task = Task.objects.create(project=project, data={'text': f'This is task {i}'}) tasks.append(task) # Add a few annotations or predictions to these tasks for i, task in enumerate(tasks): result = [ { 'from_name': 'label', # Control tag used in the result 'to_name': 'text', 'type': 'labels', 'value': {'labels': [f'Label_{i%2+1}']}, } ] if use_predictions: Prediction.objects.create(task=task, project=project, result=result, model_version='v1') else: Annotation.objects.create(task=task, project=project, completed_by=test_user, result=result) # Prepare the request data request_data = {'source': source, 'control_tag': control_tag, 'with_counters': with_counters} # Get the queryset of tasks to process queryset = Task.objects.filter(project=project) # Run cache_labels_job cache_labels_job(project, queryset, request_data=request_data) # Check that the expected cache column is added to task['data'] for task in tasks: task.refresh_from_db() cache_column = expected_cache_column assert cache_column in task.data cached_labels = task.data[cache_column] assert cached_labels is not None # Verify the contents of the cached labels if use_predictions: source_objects = Prediction.objects.filter(task=task) else: source_objects = Annotation.objects.filter(task=task) all_labels = [] for source_obj in source_objects: for result in source_obj.result: # Apply similar logic as in extract_labels from_name = result.get('from_name') if control_tag == 'ALL' or control_tag == from_name: value = result.get('value', {}) for key in value: if isinstance(value[key], list) and value[key] and isinstance(value[key][0], str): all_labels.extend(value[key]) break if with_counters.lower() == 'yes': expected_cache = ', '.join(sorted([f'{label}: {all_labels.count(label)}' for label in set(all_labels)])) else: expected_cache = ', '.join(sorted(list(set(all_labels)))) assert cached_labels == expected_cache