chenzhaoyang
2025-12-17 063da0bf961e1d35e25dc107f883f7492f4c5a7c
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
import logging
from collections import Counter
from typing import List, Tuple, Union
 
from core.feature_flags import flag_set
from core.utils.common import conditional_atomic, db_is_not_sqlite, load_func
from core.utils.db import fast_first
from django.conf import settings
from django.db.models import BooleanField, Case, Count, Exists, F, Max, OuterRef, Q, QuerySet, Value, When
from django.db.models.fields import DecimalField
from projects.functions.stream_history import add_stream_history
from projects.models import Project
from tasks.models import Annotation, Task
from users.models import User
 
logger = logging.getLogger(__name__)
 
 
# Hook for GT-first gating (Enterprise can override via settings)
def _oss_should_attempt_gt_first(user: User, project: Project) -> bool:
    """Open-source default for the ground-truth-first gate.

    No onboarding checks are applied: GT-first is attempted exactly when the project's
    ``show_ground_truth_first`` setting is truthy. ``user`` is accepted only so this function
    matches the ``(user, project)`` hook signature and is not inspected.
    """
    gt_first_enabled = project.show_ground_truth_first
    return bool(gt_first_enabled)
 
 
# Optional hooks resolved from settings via load_func.
# The agreement-queryset hook is used as-is; callers test it for truthiness before calling it.
get_tasks_agreement_queryset = load_func(settings.GET_TASKS_AGREEMENT_QUERYSET)

# GT-first gate: use the configured override when load_func yields a truthy callable
# (presumably it is falsy when the setting is empty), otherwise the open-source default.
should_attempt_ground_truth_first = load_func(settings.SHOULD_ATTEMPT_GROUND_TRUTH_FIRST)
if not should_attempt_ground_truth_first:
    should_attempt_ground_truth_first = _oss_should_attempt_gt_first
 
 
def get_next_task_logging_level(user: User) -> int:
    """Return the log level for the ``get_next_task`` summary line.

    ``logging.INFO`` when the additional-logging feature flag is enabled for ``user``,
    ``logging.DEBUG`` otherwise.
    """
    verbose = flag_set('fflag_fix_back_dev_4185_next_task_additional_logging_long', user=user)
    return logging.INFO if verbose else logging.DEBUG
 
 
def _get_random_unlocked(task_query: QuerySet[Task], user: User, upper_limit=None) -> Union[Task, None]:
    """Return a randomly chosen task from ``task_query`` that ``user`` may take, or ``None``.

    Each candidate is re-fetched with ``select_for_update(skip_locked=True)``. A row locked by a
    concurrent transaction is not returned (``Task.DoesNotExist``) and is skipped. A task whose
    ``has_lock(user)`` is truthy is skipped as well.

    :param task_query: candidate tasks. Its ordering only matters when ``upper_limit`` is given.
    :param user: the user requesting a task.
    :param upper_limit: when truthy, first pick at random among the first ``upper_limit`` tasks
        of ``task_query`` in its current ordering. This respects the caller's ranking (e.g. the
        uncertainty ordering built by ``_try_uncertainty_sampling``) while spreading concurrent
        annotators over the top candidates. If none of those is available, fall back to a
        uniform random sample of the whole query (the previous behavior).
    :return: the selected task, or ``None`` when no candidate is available.
    """
    if upper_limit:
        import random

        # This argument used to be ignored, and order_by('?') discarded the caller's ranking.
        # dict.fromkeys drops duplicate ids (ordering across joined relations can repeat a task)
        # while keeping rank order; then shuffle only within the top slice.
        top_ranked_ids = list(dict.fromkeys(task.id for task in task_query.only('id')[:upper_limit]))
        random.shuffle(top_ranked_ids)
        task = _select_first_available(top_ranked_ids, user)
        if task is not None:
            return task

    sampled_ids = (task.id for task in task_query.order_by('?').only('id')[: settings.RANDOM_NEXT_TASK_SAMPLE_SIZE])
    return _select_first_available(sampled_ids, user)


def _select_first_available(task_ids, user: User) -> Union[Task, None]:
    """Row-lock and return the first task in ``task_ids`` that is not locked for ``user``."""
    for task_id in task_ids:
        try:
            task = Task.objects.select_for_update(skip_locked=True).get(pk=task_id)
            if not task.has_lock(user):
                return task
        except Task.DoesNotExist:
            logger.debug('Task with id {} locked'.format(task_id))
    return None
 
 
def _get_first_unlocked(tasks_query: QuerySet[Task], user) -> Union[Task, None]:
    """Return the first task of ``tasks_query`` (in its ordering) that ``user`` may take.

    Rows locked by another DB transaction are skipped via ``skip_locked``. Tasks whose
    ``has_lock(user)`` is truthy (e.g. taken by collaborators) are skipped too.
    Returns ``None`` when no candidate is available.
    """
    locking_manager = Task.objects.select_for_update(skip_locked=True)
    for candidate_id in tasks_query.values_list('id', flat=True):
        try:
            candidate = locking_manager.get(pk=candidate_id)
            if not candidate.has_lock(user):
                return candidate
        except Task.DoesNotExist:
            logger.debug('Task with id {} locked'.format(candidate_id))
    return None
 
 
def _try_ground_truth(tasks: QuerySet[Task], project: Project, user: User) -> Union[Task, None]:
    """Returns task from ground truth set

    Narrows ``tasks`` to those with at least one ``ground_truth=True`` annotation. Sequence-sampling
    projects get the first available such task; other projects get a random available one.
    Returns ``None`` if there is no such task or none is available to ``user``.
    """
    has_gt_annotation = Exists(Annotation.objects.filter(task=OuterRef('pk'), ground_truth=True))
    gt_tasks = tasks.annotate(has_ground_truths=has_gt_annotation).filter(has_ground_truths=True)
    if not gt_tasks.exists():
        return None
    pick = _get_first_unlocked if project.sampling == project.SEQUENCE else _get_random_unlocked
    return pick(gt_tasks, user)
 
 
def _try_tasks_with_overlap(tasks: QuerySet[Task]) -> Tuple[Union[Task, None], QuerySet[Task]]:
    """Filter out tasks without overlap (doesn't return next task)

    Returns ``(None, subset)``. ``subset`` holds the tasks with ``overlap > 1`` if any exist,
    otherwise the tasks with ``overlap == 1``. The first element is always ``None``.
    """
    multi_overlap = tasks.filter(overlap__gt=1)
    remaining = multi_overlap if multi_overlap.exists() else tasks.filter(overlap=1)
    return None, remaining
 
 
def _try_breadth_first(tasks: QuerySet[Task], user: User, project: Project) -> Union[Task, None]:
    """Try to find tasks with maximum amount of annotations, since we are trying to label tasks as fast as possible

    For each task this counts the annotations not completed by ``user``. Unless the project has
    ``show_ground_truth_first`` enabled, ``ground_truth`` annotations are not counted either.
    It then randomly picks an available task among those with the highest count.

    :return: a task, or ``None`` in three cases: the candidate set is empty, no candidate has any
        counted annotation, or none of the top tasks is available to ``user``.
    """
    # Exclude ground truth annotations from the count when the project does not show ground
    # truth first, to prevent GT tasks from being prioritized via breadth-first logic
    annotation_filter = ~Q(annotations__completed_by=user)
    if not project.show_ground_truth_first:
        annotation_filter &= ~Q(annotations__ground_truth=True)

    tasks = tasks.annotate(annotations_count=Count('annotations', filter=annotation_filter))
    max_annotations_count = tasks.aggregate(Max('annotations_count'))['annotations_count__max']
    # Max() over an empty queryset is None; 0 means no task has labeling in progress.
    # Either way there is nothing to continue, so skip the remaining queries.
    if not max_annotations_count:
        return None

    # find any task with maximal amount of created annotations (filtering on the aggregate -> HAVING)
    tasks_with_max_annotations = tasks.filter(annotations_count=max_annotations_count)
    if tasks_with_max_annotations.exists():
        # try to complete tasks that are already in progress
        return _get_random_unlocked(tasks_with_max_annotations, user)
    return None
 
 
def _try_uncertainty_sampling(
    tasks: QuerySet[Task],
    project: Project,
    user_solved_tasks_array: List[int],
    user: User,
    prepared_tasks: QuerySet[Task],
) -> Union[Task, None]:
    """Pick a task by uncertainty (active learning) sampling, or by random sampling as a fallback.

    Candidates are the tasks with a prediction for ``project.model_version``. They are ordered by
    how many tasks ``user`` already solved in the same prediction cluster (fewest first), then by
    ascending ``predictions__score``.

    With more than one project annotator, the ordered candidates go to ``_get_random_unlocked``
    with ``upper_limit=min(annotators + 1, candidate_count)``. Otherwise the first available
    candidate is taken. When no task has a current prediction, ``tasks`` is sampled at random.
    """
    predicted = tasks.filter(predictions__model_version=project.model_version)
    if not predicted.exists():
        # uncertainty sampling fallback: choose by random sampling
        logger.debug(
            f'Uncertainty sampling fallbacks to random sampling '
            f'(current project.model_version={str(project.model_version)})'
        )
        return _get_random_unlocked(tasks, user)

    logger.debug('Use uncertainty sampling')
    # cluster of every task the user already solved (highest cluster id among its predictions)
    solved_cluster_ids = (
        prepared_tasks.filter(pk__in=user_solved_tasks_array)
        .annotate(cluster=Max('predictions__cluster'))
        .values_list('cluster', flat=True)
    )
    solved_per_cluster = Counter(solved_cluster_ids)
    # rank each candidate by how many tasks the user solved in its cluster
    cluster_rank_cases = [
        When(predictions__cluster=cluster, then=n_solved) for cluster, n_solved in solved_per_cluster.items()
    ]

    # Count BEFORE annotating: per the original note, this count doesn't work after the annotate below
    predicted_count = predicted.count()
    if cluster_rank_cases:
        # least solved cluster first, then lowest prediction score
        candidates = predicted.annotate(
            cluster_num_solved=Case(*cluster_rank_cases, default=0, output_field=DecimalField())
        ).order_by('cluster_num_solved', 'predictions__score')
    else:
        candidates = predicted.order_by('predictions__score')

    annotator_count = project.annotators().count()
    if annotator_count > 1 and predicted_count > 0:
        # try to randomize tasks to avoid concurrent labeling between several annotators
        return _get_random_unlocked(candidates, user, upper_limit=min(annotator_count + 1, predicted_count))
    return _get_first_unlocked(candidates, user)
 
 
def get_not_solved_tasks_qs(
    user: User,
    project: Project,
    prepared_tasks: QuerySet[Task],
    assigned_flag: Union[bool, None],
    queue_info: str,
    allow_gt_first: bool,
) -> Tuple[QuerySet[Task], List[int], str, bool]:
    """Build the queryset of candidate tasks ``user`` has not solved yet in ``project``.

    Steps:

    1. Exclude from ``prepared_tasks`` every task on which ``user`` has an annotation, and every
       task on which ``user`` has a postponed draft.
    2. Unless ``assigned_flag`` is truthy, apply one of two paths.

       Agreement path: taken when the project has an ``lse_project`` with ``agreement_threshold``
       set, the ``get_tasks_agreement_queryset`` hook is configured and ``user`` is a project
       annotator. A task is kept when both of these hold:

       - it is unlabeled, or labeled with ``_agreement`` below the threshold;
       - its distinct-annotator count is below ``overlap + max_additional_annotators_assignable``.

       With ``project.show_ground_truth_first``, every task having a ground-truth annotation is
       kept as well. The result may then be reordered by ``_prioritize_low_agreement_tasks``.

       Otherwise: drop already-labeled tasks, unless ``allow_gt_first`` is truthy.
    3. While the overlap-first-order feature flag is OFF, narrow to overlap tasks when
       ``project.show_overlap_first`` is set and the set was not reordered for low agreement.
       With the flag ON, ``get_next_task`` does this step itself.

    :param assigned_flag: truthy when the user works on manually assigned tasks. Labeled tasks are
        then not filtered out.
    :param queue_info: human-readable description of the queues used so far. It is returned,
        possibly with ``'Show overlap first'`` appended.
    :param allow_gt_first: result of the GT-first gate. Only consulted in the non-agreement path.
    :return: ``(not_solved_tasks, user_solved_tasks_array, queue_info, prioritized_on_agreement)``.
        ``user_solved_tasks_array`` is a lazy ``values_list`` queryset of the task pks the user has
        annotated (not a materialized list, despite the annotation).
    """
    # pks of tasks the user already annotated in this project (kept lazy; used as a subquery)
    user_solved_tasks_array = user.annotations.filter(project=project, task__isnull=False)
    user_solved_tasks_array = user_solved_tasks_array.distinct().values_list('task__pk', flat=True)
    not_solved_tasks = prepared_tasks.exclude(pk__in=user_solved_tasks_array)

    # annotation can't have postponed draft, so skip annotation__project filter
    postponed_drafts = user.drafts.filter(task__project=project, was_postponed=True)
    if postponed_drafts.exists():
        user_postponed_tasks = postponed_drafts.distinct().values_list('task__pk', flat=True)
        not_solved_tasks = not_solved_tasks.exclude(pk__in=user_postponed_tasks)

    prioritized_on_agreement = False
    # if the annotator is assigned to tasks, they must solve them regardless of is_labeled=True
    if not assigned_flag:
        # low agreement strategy for auto-assigned annotators:
        # include completed tasks whose agreement is below the threshold, if the threshold setting is set
        lse_project = getattr(project, 'lse_project', None)
        if (
            lse_project
            and lse_project.agreement_threshold is not None
            and get_tasks_agreement_queryset
            and user.is_project_annotator(project)
        ):
            # Onboarding mode (GT-first) should keep GT tasks eligible regardless of is_labeled/agreement.
            # NOTE(review): the GT branch below checks project.show_ground_truth_first, not allow_gt_first.
            # The agreement hook presumably annotates `_agreement`, which the predicates below rely on.
            qs = get_tasks_agreement_queryset(not_solved_tasks)
            # number of distinct annotation authors per task (cancelled annotations are not excluded)
            qs = qs.annotate(annotators=Count('annotations__completed_by', distinct=True))

            low_agreement_pred = Q(_agreement__lt=lse_project.agreement_threshold, is_labeled=True) | Q(
                is_labeled=False
            )
            # a task can take annotators up to its overlap plus the configured extra reassignments
            capacity_pred = Q(annotators__lt=F('overlap') + (lse_project.max_additional_annotators_assignable or 0))

            if project.show_ground_truth_first:
                gt_subq = Annotation.objects.filter(task=OuterRef('pk'), ground_truth=True)
                qs = qs.annotate(has_ground_truths=Exists(gt_subq))
                # Keep all GT tasks + apply low-agreement+capacity to the rest. Alternatively we could:
                # - keep GT tasks only while user.solved_tasks_array.count < lse_project.annotator_evaluation_minimum_tasks
                # - otherwise apply low-agreement+capacity to everything (maybe better performance)
                # It's an open question which is better; this version is simpler from the code perspective.
                not_solved_tasks = qs.filter(Q(has_ground_truths=True) | (low_agreement_pred & capacity_pred))
            else:
                not_solved_tasks = qs.filter(low_agreement_pred & capacity_pred)

            # may reorder not_solved_tasks so low-agreement labeled tasks come first
            prioritized_on_agreement, not_solved_tasks = _prioritize_low_agreement_tasks(not_solved_tasks, lse_project)

        # otherwise, filtering out completed tasks is sufficient
        else:
            # ignore tasks that are already labeled when GT-first is NOT allowed
            if not allow_gt_first:
                not_solved_tasks = not_solved_tasks.filter(is_labeled=False)

    # legacy overlap-first ordering; with the flag ON this happens later in get_next_task instead
    if not flag_set('fflag_fix_back_lsdv_4523_show_overlap_first_order_27022023_short'):
        # show tasks with overlap > 1 first (unless tasks are already prioritized on agreement)
        if project.show_overlap_first and not prioritized_on_agreement:
            # don't output anything - just filter tasks with overlap
            logger.debug(f'User={user} tries overlap first from prepared tasks')
            _, not_solved_tasks = _try_tasks_with_overlap(not_solved_tasks)
            queue_info += (' & ' if queue_info else '') + 'Show overlap first'

    return not_solved_tasks, user_solved_tasks_array, queue_info, prioritized_on_agreement
 
 
def _prioritize_low_agreement_tasks(tasks, lse_project):
    """Reorder ``tasks`` for low-agreement reassignment when such tasks exist.

    Checks whether at least one task is labeled and has ``_agreement`` below
    ``lse_project.agreement_threshold``.

    - If so, returns ``(True, tasks ordered by '-is_labeled', then ascending '_agreement')``.
    - Otherwise returns ``(False, tasks)`` unchanged.
    """
    low_agreement_labeled = tasks.filter(_agreement__lt=lse_project.agreement_threshold, is_labeled=True)
    if not low_agreement_labeled.exists():
        return False, tasks
    # labeled tasks first, and among them the least agreement first
    return True, tasks.order_by('-is_labeled', '_agreement')
 
 
def get_next_task_without_dm_queue(
    user: User,
    project: Project,
    not_solved_tasks: QuerySet,
    assigned_flag: Union[bool, None],
    prioritized_low_agreement: bool,
    allow_gt_first: bool,
) -> Tuple[Union[Task, None], bool, str]:
    """Pick the next task from the label-stream priority queues (non data-manager mode).

    Queues are tried in this order, each only if no task was found yet:

    1. manually assigned (when ``assigned_flag``);
    2. a task already locked for ``user``;
    3. ground truth (when ``allow_gt_first``);
    4. low agreement (when ``prioritized_low_agreement``);
    5. breadth first (when ``project.maximum_annotations > 1``).

    :return: ``(next_task, use_task_lock, queue_info)``.

        - ``use_task_lock`` is ``False`` when ``assigned_flag`` is set or the task was already
          locked for ``user``, so the caller does not lock it again.
        - ``queue_info`` joins with ``' & '`` the names of the queues that produced a task.
          ``'Manually assigned queue'`` is included whenever ``assigned_flag`` is set.
    """
    used_queues = []
    use_task_lock = True
    next_task = None

    # Manually assigned tasks: take the first one, without locking
    if assigned_flag:
        logger.debug(f'User={user} try to get task from assigned')
        next_task = not_solved_tasks.first()
        use_task_lock = False
        used_queues.append('Manually assigned queue')

    # Task lock: a task already locked for this user is returned without setting the lock again
    if not next_task:
        next_task = Task.get_locked_by(user, tasks=not_solved_tasks)
        if next_task:
            logger.debug(f'User={user} got already locked for them {next_task}')
            use_task_lock = False
            used_queues.append('Task lock')

    # Ground truth: GT-first gating is decided by the caller (allow_gt_first)
    if not next_task and allow_gt_first:
        logger.debug(f'User={user} tries ground truth from prepared tasks')
        next_task = _try_ground_truth(not_solved_tasks, project, user)
        if next_task:
            used_queues.append('Ground truth queue')

    # Low agreement: reassign this annotator following not_solved_tasks' ordering
    if not next_task and prioritized_low_agreement:
        logger.debug(f'User={user} tries low agreement from prepared tasks')
        next_task = _get_first_unlocked(not_solved_tasks, user)
        if next_task:
            used_queues.append('Low agreement queue')

    # Breadth first: finish in-progress tasks (overlap not reached) before starting new ones
    if not next_task and project.maximum_annotations > 1:
        logger.debug(f'User={user} tries depth first from prepared tasks')
        next_task = _try_breadth_first(not_solved_tasks, user, project)
        if next_task:
            used_queues.append('Breadth first queue')

    return next_task, use_task_lock, ' & '.join(used_queues)
 
 
def skipped_queue(next_task, prepared_tasks, project, user, assigned_flag, queue_info):
    """Fall back to tasks the user skipped, when the project re-queues skips for the same user.

    Runs only when no task was found yet and ``project.skip_queue`` is ``REQUEUE_FOR_ME``.
    Candidates are tasks that ``user`` cancelled (skipped) and that are still unlabeled. They are
    restricted to ``prepared_tasks`` and ordered by the cancelled annotation's ``updated_at``,
    oldest first.

    :return: ``(next_task, queue_info)``. ``queue_info`` becomes ``'Skipped queue'`` whenever
        skipped tasks exist, even if none of them could be selected.
    """
    if next_task or project.skip_queue != project.SkipQueue.REQUEUE_FOR_ME:
        return next_task, queue_info

    cancelled_filter = Q(project=project, task__isnull=False, was_cancelled=True, task__is_labeled=False)
    skipped_ids = user.annotations.filter(cancelled_filter).order_by('updated_at').values_list('task__pk', flat=True)
    if not skipped_ids.exists():
        return next_task, queue_info

    # keep the skip order: rank each pk by its position in skipped_ids
    rank_by_skip_order = Case(*[When(pk=task_pk, then=rank) for rank, task_pk in enumerate(skipped_ids)])
    candidates = prepared_tasks.filter(pk__in=skipped_ids).order_by(rank_by_skip_order)

    # For assigned annotators locks don't make sense. Moreover, _get_first_unlocked breaks the label
    # stream in manual mode: it evaluates locks with auto-mode logic and returns None when there
    # are no more tasks to label in auto-mode.
    if assigned_flag:
        next_task = fast_first(candidates)
    else:
        next_task = _get_first_unlocked(candidates, user)
    return next_task, 'Skipped queue'
 
 
def postponed_queue(next_task, prepared_tasks, project, user, assigned_flag, queue_info):
    """Fall back to tasks on which the user postponed a draft, when no task was found yet.

    Candidates are tasks in ``project`` with a postponed draft by ``user`` that are still
    unlabeled. They are restricted to ``prepared_tasks`` and ordered by the draft's
    ``updated_at``, oldest first. A selected task gets ``allow_postpone = False``.

    :return: ``(next_task, queue_info)``. ``queue_info`` becomes ``'Postponed draft queue'``
        whenever postponed drafts exist, even if no task could be selected.
    """
    if next_task:
        return next_task, queue_info

    draft_filter = Q(task__project=project, task__isnull=False, was_postponed=True, task__is_labeled=False)
    postponed_ids = user.drafts.filter(draft_filter).order_by('updated_at').values_list('task__pk', flat=True)
    if not postponed_ids.exists():
        return next_task, queue_info

    # keep the postpone order: rank each pk by its position in postponed_ids
    rank_by_postpone_order = Case(*[When(pk=task_pk, then=rank) for rank, task_pk in enumerate(postponed_ids)])
    candidates = prepared_tasks.filter(pk__in=postponed_ids).order_by(rank_by_postpone_order)

    # For assigned annotators locks don't make sense. Moreover, _get_first_unlocked breaks the label
    # stream in manual mode: it evaluates locks with auto-mode logic and returns None when there
    # are no more tasks to label in auto-mode.
    next_task = fast_first(candidates) if assigned_flag else _get_first_unlocked(candidates, user)
    if next_task is not None:
        # presumably read by the client to disable postponing this task again
        next_task.allow_postpone = False
    return next_task, 'Postponed draft queue'
 
 
def get_task_from_qs_with_sampling(
    not_solved_tasks: QuerySet[Task],
    user_solved_tasks_array: List[int],
    prepared_tasks: QuerySet,
    user: User,
    project: Project,
    queue_info: str,
) -> Tuple[Union[Task, None], str]:
    """Pick a task from ``not_solved_tasks`` according to ``project.sampling``.

    - SEQUENCE: the first available task in queryset order.
    - UNCERTAINTY: ``_try_uncertainty_sampling`` (active learning, random fallback).
    - UNIFORM: a random available task.
    - Any other value: ``None``.

    The matching queue name is appended to ``queue_info`` (joined with ``' & '``) only when a
    task is found.
    """
    sampling = project.sampling
    if sampling == project.SEQUENCE:
        logger.debug(f'User={user} tries sequence sampling from prepared tasks')
        next_task = _get_first_unlocked(not_solved_tasks, user)
        label = 'Sequence queue'
    elif sampling == project.UNCERTAINTY:
        logger.debug(f'User={user} tries uncertainty sampling from prepared tasks')
        next_task = _try_uncertainty_sampling(not_solved_tasks, project, user_solved_tasks_array, user, prepared_tasks)
        label = 'Active learning or random queue'
    elif sampling == project.UNIFORM:
        logger.debug(f'User={user} tries random sampling from prepared tasks')
        next_task = _get_random_unlocked(not_solved_tasks, user)
        label = 'Uniform random queue'
    else:
        return None, queue_info

    if next_task:
        separator = ' & ' if queue_info else ''
        queue_info = f'{queue_info}{separator}{label}'
    return next_task, queue_info
 
 
def get_next_task(
    user: User,
    prepared_tasks: QuerySet,
    project: Project,
    dm_queue: Union[bool, None],
    assigned_flag: Union[bool, None] = None,
) -> Tuple[Union[Task, None], str]:
    """Select the next task for ``user`` in ``project`` and record it in the stream history.

    Runs inside ``conditional_atomic(predicate=db_is_not_sqlite)``, presumably a transaction on
    non-SQLite databases, so the ``select_for_update`` row locks taken while choosing are held.

    Order of attempts:

    1. Decide GT-first gating via ``should_attempt_ground_truth_first``.
    2. Build the not-solved candidate queryset.
    3. Unless ``dm_queue``, try the priority queues (assigned, locked, GT, low agreement,
       breadth first).
    4. With the overlap-first-order flag ON and ``show_overlap_first``, sample from overlap tasks.
    5. Otherwise take the first task in data-manager order (``dm_queue``) or use project sampling.
    6. Try the postponed-draft queue, then the skipped queue.
    7. Lock the chosen task unless it came from the assigned queue or was already locked.

    :param prepared_tasks: tasks visible to the user (e.g. filtered by the data manager).
    :param dm_queue: truthy to follow the data manager ordering instead of the priority queues.
    :param assigned_flag: truthy when the user labels manually assigned tasks.
    :return: ``(next_task or None, queue_info)``. ``queue_info`` is a human-readable description
        of the queues used.
    """
    logger.debug(f'get_next_task called. user: {user}, project: {project}, dm_queue: {dm_queue}')

    with conditional_atomic(predicate=db_is_not_sqlite):
        next_task = None
        use_task_lock = True
        queue_info = ''

        # Ground truth: label GT first only during onboarding window for user (gated by min tasks and min score)
        allow_gt_first = should_attempt_ground_truth_first(user, project)

        not_solved_tasks, user_solved_tasks_array, queue_info, prioritized_low_agreement = get_not_solved_tasks_qs(
            user, project, prepared_tasks, assigned_flag, queue_info, allow_gt_first
        )

        if not dm_queue:
            next_task, use_task_lock, priority_queue_info = get_next_task_without_dm_queue(
                user, project, not_solved_tasks, assigned_flag, prioritized_low_agreement, allow_gt_first
            )
            # Append instead of overwriting, so info gathered while building not_solved_tasks
            # (e.g. 'Show overlap first') is kept, as it is on the data-manager path.
            if priority_queue_info:
                queue_info += (' & ' if queue_info else '') + priority_queue_info

        if flag_set('fflag_fix_back_lsdv_4523_show_overlap_first_order_27022023_short'):
            # show tasks with overlap > 1 first
            if not next_task and project.show_overlap_first:
                # don't output anything - just filter tasks with overlap
                logger.debug(f'User={user} tries overlap first from prepared tasks')
                _, tasks_with_overlap = _try_tasks_with_overlap(not_solved_tasks)
                queue_info += (' & ' if queue_info else '') + 'Show overlap first'
                next_task, queue_info = get_task_from_qs_with_sampling(
                    tasks_with_overlap, user_solved_tasks_array, prepared_tasks, user, project, queue_info
                )

        if not next_task:
            if dm_queue:
                queue_info += (' & ' if queue_info else '') + 'Data manager queue'
                logger.debug(f'User={user} tries sequence sampling from prepared tasks')
                next_task = not_solved_tasks.first()

            else:
                next_task, queue_info = get_task_from_qs_with_sampling(
                    not_solved_tasks, user_solved_tasks_array, prepared_tasks, user, project, queue_info
                )

        next_task, queue_info = postponed_queue(next_task, prepared_tasks, project, user, assigned_flag, queue_info)

        next_task, queue_info = skipped_queue(next_task, prepared_tasks, project, user, assigned_flag, queue_info)

        if next_task and use_task_lock:
            # set lock for the task with TTL 3x time more than current average lead time (or 1 hour by default)
            next_task.set_lock(user)

        logger.log(
            get_next_task_logging_level(user),
            f'get_next_task finished. next_task: {next_task}, queue_info: {queue_info}',
        )

        # debug for critical overlap issue
        if next_task and flag_set('fflag_fix_back_dev_4185_next_task_additional_logging_long', user=user):
            try:
                count = next_task.annotations.filter(was_cancelled=False).count()
                task_overlap_reached = count >= next_task.overlap
                global_overlap_reached = count >= project.maximum_annotations
                locks = next_task.locks.count() > project.maximum_annotations - next_task.annotations.count()
                if next_task.is_labeled or task_overlap_reached or global_overlap_reached or locks:
                    from tasks.serializers import TaskSimpleSerializer

                    # snapshot of local state for the log, minus the (large) querysets
                    local = dict(locals())
                    local.pop('prepared_tasks', None)
                    local.pop('user_solved_tasks_array', None)
                    local.pop('not_solved_tasks', None)

                    task = TaskSimpleSerializer(next_task).data
                    task.pop('data', None)
                    task.pop('predictions', None)
                    for i, annotation in enumerate(task['annotations']):
                        annotation = dict(annotation)
                        annotation.pop('result', None)
                        task['annotations'][i] = annotation

                    # separate name: don't rebind `project`, which add_stream_history uses below
                    task_project = next_task.project
                    project_data = {
                        'maximum_annotations': task_project.maximum_annotations,
                        'skip_queue': task_project.skip_queue,
                        'sampling': task_project.sampling,
                        'show_ground_truth_first': task_project.show_ground_truth_first,
                        'show_overlap_first': task_project.show_overlap_first,
                        'overlap_cohort_percentage': task_project.overlap_cohort_percentage,
                        'project_id': task_project.id,
                        'title': task_project.title,
                    }
                    logger.info(
                        f'DEBUG INFO: get_next_task is_labeled/overlap: '
                        f'LOCALS ==> {local} :: PROJECT ==> {project_data} :: '
                        f'NEXT_TASK ==> {task}'
                    )
            except Exception as e:
                # best-effort diagnostics: never let debug logging break task selection,
                # but keep the traceback so the failure can be investigated
                logger.exception(f'get_next_task is_labeled/overlap try/except: {str(e)}')

        add_stream_history(next_task, user, project)
        return next_task, queue_info