Bin
2025-12-17 1d710f844b65d9bfdf986a71a3b924cd70598a41
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
"""Tests for the cache_labels action."""
 
import pytest
from data_manager.actions.cache_labels import cache_labels_job
from django.contrib.auth import get_user_model
from projects.models import Project
from tasks.models import Annotation, Prediction, Task
 
 
@pytest.mark.django_db
@pytest.mark.parametrize(
    'source, control_tag, with_counters, expected_cache_column, use_predictions',
    [
        # Test case 1: Annotations, control tag 'ALL', with counters
        ('annotations', 'ALL', 'Yes', 'cache_all', False),
        # Test case 2: Annotations, specific control tag, with counters
        ('annotations', 'label', 'Yes', 'cache_label', False),
        # Test case 3: Annotations, control tag 'ALL', without counters
        ('annotations', 'ALL', 'No', 'cache_all', False),
        # Test case 4: Predictions, control tag 'ALL', with counters
        ('predictions', 'ALL', 'Yes', 'cache_predictions_all', True),
    ],
)
def test_cache_labels_job(source, control_tag, with_counters, expected_cache_column, use_predictions):
    # Initialize a test user and project
    User = get_user_model()
    test_user = User.objects.create(username='test_user')
    project = Project.objects.create(title='Test Project', created_by=test_user)
 
    # Create a few tasks
    tasks = []
    for i in range(3):
        task = Task.objects.create(project=project, data={'text': f'This is task {i}'})
        tasks.append(task)
 
    # Add a few annotations or predictions to these tasks
    for i, task in enumerate(tasks):
        result = [
            {
                'from_name': 'label',  # Control tag used in the result
                'to_name': 'text',
                'type': 'labels',
                'value': {'labels': [f'Label_{i%2+1}']},
            }
        ]
        if use_predictions:
            Prediction.objects.create(task=task, project=project, result=result, model_version='v1')
        else:
            Annotation.objects.create(task=task, project=project, completed_by=test_user, result=result)
 
    # Prepare the request data
    request_data = {'source': source, 'control_tag': control_tag, 'with_counters': with_counters}
 
    # Get the queryset of tasks to process
    queryset = Task.objects.filter(project=project)
 
    # Run cache_labels_job
    cache_labels_job(project, queryset, request_data=request_data)
 
    # Check that the expected cache column is added to task['data']
    for task in tasks:
        task.refresh_from_db()
        cache_column = expected_cache_column
        assert cache_column in task.data
        cached_labels = task.data[cache_column]
        assert cached_labels is not None
 
        # Verify the contents of the cached labels
        if use_predictions:
            source_objects = Prediction.objects.filter(task=task)
        else:
            source_objects = Annotation.objects.filter(task=task)
 
        all_labels = []
        for source_obj in source_objects:
            for result in source_obj.result:
                # Apply similar logic as in extract_labels
                from_name = result.get('from_name')
                if control_tag == 'ALL' or control_tag == from_name:
                    value = result.get('value', {})
                    for key in value:
                        if isinstance(value[key], list) and value[key] and isinstance(value[key][0], str):
                            all_labels.extend(value[key])
                            break
 
        if with_counters.lower() == 'yes':
            expected_cache = ', '.join(sorted([f'{label}: {all_labels.count(label)}' for label in set(all_labels)]))
        else:
            expected_cache = ', '.join(sorted(list(set(all_labels))))
 
        assert cached_labels == expected_cache