"""Tests exercising ``_batch_update_with_retry`` on factory-built projects."""
from unittest.mock import patch

import pytest
from django.db import OperationalError
from projects.tests.factories import ProjectFactory
from tasks.models import Task
from tasks.tests.factories import TaskFactory


def _project_with_tasks(task_count):
    """Build a project and ``task_count`` tasks attached to it.

    Returns a 3-tuple: the project, the list of ids of the created tasks,
    and a ``Task`` queryset filtered on those ids.
    """
    project = ProjectFactory(overlap_cohort_percentage=50, maximum_annotations=3)
    ids = [created.id for created in TaskFactory.create_batch(task_count, project=project)]
    return project, ids, Task.objects.filter(id__in=ids)


@pytest.mark.django_db
class TestBatchUpdateWithRetry:
    """Database-backed checks of ``project._batch_update_with_retry``."""

    def test_batch_update_success_without_deadlock(self):
        """Ten tasks updated in batches of three all end up with overlap=2."""
        project, ids, selection = _project_with_tasks(10)

        project._batch_update_with_retry(selection, batch_size=3, overlap=2)

        assert Task.objects.filter(id__in=ids, overlap=2).count() == 10

    def test_batch_update_with_deadlock_retry(self):
        """A failed first attempt is followed by a call that updates all five tasks.

        The method is replaced on the project instance by a wrapper. On its
        first invocation the wrapper runs the real method once while
        ``django.db.transaction.atomic`` is patched to raise
        ``OperationalError`` (ignoring that error), then runs the real method
        again without the patch.
        """
        project, ids, selection = _project_with_tasks(5)

        # Capture the bound method before it is patched on the instance below.
        real_update = project._batch_update_with_retry
        invocations = []

        def flaky_update(queryset, *args, **kwargs):
            invocations.append(kwargs)
            if len(invocations) == 1:
                # Simulate deadlock on first call
                deadlock = OperationalError('deadlock detected')
                with patch('django.db.transaction.atomic', side_effect=deadlock):
                    try:
                        real_update(queryset, *args, **kwargs)
                    except OperationalError:
                        # The simulated failure is expected; the unpatched call below does the real work.
                        pass
            return real_update(queryset, *args, **kwargs)

        with patch.object(project, '_batch_update_with_retry', side_effect=flaky_update):
            project._batch_update_with_retry(selection, batch_size=5, overlap=2)

        assert len(invocations) >= 1
        assert Task.objects.filter(id__in=ids, overlap=2).count() == 5

    def test_batch_update_with_multiple_batches(self):
        """Fifteen tasks in batches of five receive both overlap=3 and is_labeled=True."""
        project, ids, selection = _project_with_tasks(15)

        project._batch_update_with_retry(selection, batch_size=5, overlap=3, is_labeled=True)

        matching = Task.objects.filter(id__in=ids, overlap=3, is_labeled=True)
        assert matching.count() == 15

    def test_batch_update_exceeds_max_retries(self):
        """``OperationalError`` propagates when every attempt hits the patched ``atomic``.

        The project method is swapped for a wrapper that calls
        ``core.utils.db.batch_update_with_retry`` with ``max_retries=2`` while
        ``django.db.transaction.atomic`` always raises.
        """
        project, _ids, selection = _project_with_tasks(5)

        def always_deadlock(queryset, *args, **kwargs):
            # Always simulate deadlock to test max retries exceeded.
            # Positional extras are accepted but not forwarded; only kwargs are passed on.
            deadlock = OperationalError('deadlock detected')
            with patch('django.db.transaction.atomic', side_effect=deadlock):
                from core.utils.db import batch_update_with_retry

                batch_update_with_retry(queryset, max_retries=2, **kwargs)

        method_patch = patch.object(project, '_batch_update_with_retry', side_effect=always_deadlock)
        with method_patch, pytest.raises(OperationalError, match='deadlock detected'):
            project._batch_update_with_retry(selection, batch_size=5, overlap=2)