"""This file and its contents are licensed under the Apache License 2.0. Please see the included NOTICE for copyright information and LICENSE for a copy of the license.
"""
import logging
import os.path
import re
import tempfile
from contextlib import contextmanager
from copy import deepcopy
from functools import wraps
from pathlib import Path
from types import SimpleNamespace
from unittest import mock

import pytest
import requests
import requests_mock
import ujson as json
from box import Box
from core.feature_flags import flag_set
from data_export.models import ConvertedFormat, Export
from django.apps import apps
from django.conf import settings
from django.test import Client
from ml.models import MLBackend
from organizations.models import Organization
from projects.models import Project
from tasks.serializers import TaskWithAnnotationsSerializer
from users.models import User

try:
    from businesses.models import BillingPlan, Business
except ImportError:
    BillingPlan = Business = None

logger = logging.getLogger(__name__)


@contextmanager
def ml_backend_mock(**kwargs):
    """Yield a requests_mock Mocker with fake ML backend endpoints registered.

    Requests that match no registered URL fall through to the real network
    (``real_http=True``). ``kwargs`` are forwarded to ``register_ml_backend_mock``.
    """
    with requests_mock.Mocker(real_http=True) as m:
        yield register_ml_backend_mock(m, **kwargs)


def register_ml_backend_mock(
    m,
    url='http://localhost:9090',
    predictions=None,
    health_connect_timeout=False,
    train_job_id='123',
    setup_model_version='abc',
):
    """Register fake ML backend endpoints on the requests_mock Mocker ``m``.

    Args:
        m: requests_mock Mocker to register the routes on.
        url: Base URL of the fake backend.
        predictions: JSON body returned by ``/predict`` (``{}`` when falsy).
        health_connect_timeout: If True, ``/health`` raises ConnectTimeout instead of answering.
        train_job_id: ``job_id`` returned by ``/train``.
        setup_model_version: ``model_version`` returned by ``/setup``.

    Returns:
        The same Mocker ``m``.
    """
    m.post(f'{url}/setup', text=json.dumps({'status': 'ok', 'model_version': setup_model_version}))
    if health_connect_timeout:
        m.get(f'{url}/health', exc=requests.exceptions.ConnectTimeout)
    else:
        m.get(f'{url}/health', text=json.dumps({'status': 'UP'}))
    m.post(f'{url}/train', text=json.dumps({'status': 'ok', 'job_id': train_job_id}))
    m.post(f'{url}/predict', text=json.dumps(predictions or {}))
    m.post(f'{url}/webhook', text=json.dumps({}))
    m.get(f'{url}/versions', text=json.dumps({'versions': ['1', '2']}))
    return m


@contextmanager
def import_from_url_mock(**kwargs):
    """Serve the sample CSV for ``data.heartextest.net/test_1.csv`` and skip upload-URL validation.

    ``kwargs`` are accepted for interface compatibility and are not used.
    """
    with mock.patch('core.utils.io.validate_upload_url'):
        with requests_mock.Mocker(real_http=True) as m:
            with open('./tests/test_suites/samples/test_1.csv', 'rb') as f:
                # Raw string: '\.' in a normal literal is an invalid escape sequence
                # (SyntaxWarning on Python 3.12+).
                matcher = re.compile(r'data\.heartextest\.net/test_1\.csv')
                m.get(matcher, body=f, headers={'Content-Length': '100'})
                # The file must stay open while the mock may stream it as the response body.
                yield m


class _TestJob:
    """Minimal stand-in for a job object that only exposes ``id``."""

    def __init__(self, job_id):
        self.id = job_id


@contextmanager
def email_mock():
    """Patch ``EmailMultiAlternatives.send`` so no e-mail is actually sent."""
    from django.core.mail import EmailMultiAlternatives

    with mock.patch.object(EmailMultiAlternatives, 'send'):
        yield


@contextmanager
def gcs_client_mock():
    """Replace ``google.cloud.storage.Client`` with an in-memory fake for the duration of the context.

    Be careful: this is a global context manager. It affects every test that
    constructs a GCS client while it is active. It may lead to flaky tests if
    the sample blob names are not deterministic.
    """
    from collections import namedtuple

    from google.cloud import storage as google_storage

    def get_sample_blob_names_for_bucket(bucket_name):
        """Return the deterministic list of object keys exposed by the fake bucket ``bucket_name``."""
        # Bucket-specific logic to avoid test bleed
        if bucket_name in ['pytest-recursive-scan-bucket']:
            result = ['dataset/', 'dataset/a.json', 'dataset/sub/b.json', 'other/c.json']
            logger.info(f'get_sample_blob_names_for_bucket({bucket_name}) -> {result} (recursive scan bucket)')
            return result
        elif bucket_name.startswith('multitask_'):
            result = ['test.json']
            logger.info(f'get_sample_blob_names_for_bucket({bucket_name}) -> {result} (multitask)')
            return result
        elif bucket_name.startswith('test-gs-bucket'):
            # Force deterministic samples for standard GCS test buckets - never use closure variable
            result = ['abc', 'def', 'ghi']
            logger.info(f'get_sample_blob_names_for_bucket({bucket_name}) -> {result} (test-gs-bucket prefix)')
            return result
        else:
            result = ['abc', 'def', 'ghi']
            logger.info(f'get_sample_blob_names_for_bucket({bucket_name}) -> {result} (default)')
            return result

    def filter_direct_children(names, prefix, delimiter):
        """Return the names that are direct children of ``prefix`` for a non-recursive (delimited) listing.

        A name is a direct child when it starts with ``prefix`` (normalized to end
        with ``delimiter``) and its remainder contains no further ``delimiter``.
        With an empty or None prefix, only names that contain no ``delimiter`` at all
        are returned.
        """
        pref = prefix or ''
        if pref:
            search_prefix = pref if pref.endswith(delimiter) else pref + delimiter
            return [
                name
                for name in names
                if name.startswith(search_prefix) and delimiter not in name[len(search_prefix) :]
            ]
        # Root-level: only keys without delimiter are direct children
        return [name for name in names if delimiter not in name]

    class DummyGCSBlob:
        """Fake ``google.cloud.storage.Blob``.

        Content is either a fixed text (``test_blob_<key>``) or a JSON payload,
        depending on ``is_json``.
        """

        def __init__(self, bucket_name, key, is_json, is_multitask):
            self.key = key
            self.bucket_name = bucket_name
            # Align with google-cloud-storage: Blob.name is the object key within the bucket
            self.name = key
            self.is_json = is_json
            # Multitask buckets hold a list of tasks; other JSON buckets hold a single flat object.
            self.sample_json_contents = (
                [
                    {'data': {'image_url': 'http://ggg.com/image.jpg', 'text': 'Task 1 text'}},
                    {'data': {'image_url': 'http://ggg.com/image2.jpg', 'text': 'Task 2 text'}},
                ]
                if is_multitask
                else {
                    'str_field': 'test',
                    'int_field': 123,
                    'dict_field': {'one': 'wow', 'two': 456},
                }
            )

        def download_as_string(self):
            data = f'test_blob_{self.key}'
            if self.is_json:
                payload = json.dumps(self.sample_json_contents)
                logger.info(
                    f'DummyGCSBlob.download_as_string bucket={self.bucket_name} key={self.key} json=True bytes={len(payload)}'
                )
                return payload
            logger.info(f'DummyGCSBlob.download_as_string bucket={self.bucket_name} key={self.key} json=False')
            return data

        def upload_from_string(self, string):
            print(f'String {string} uploaded to bucket {self.bucket_name}')

        def generate_signed_url(self, **kwargs):
            url = f'https://storage.googleapis.com/{self.bucket_name}/{self.key}'
            logger.info(f'DummyGCSBlob.generate_signed_url url={url}')
            return url

        def download_as_bytes(self):
            b = self.download_as_string().encode('utf-8')
            logger.info(f'DummyGCSBlob.download_as_bytes bucket={self.bucket_name} key={self.key} size={len(b)}')
            return b

    class DummyGCSBucket:
        """Fake ``google.cloud.storage.Bucket`` listing deterministic, bucket-specific keys."""

        def __init__(self, bucket_name, is_json, is_multitask):
            self.name = bucket_name
            self.is_json = is_json
            self.is_multitask = is_multitask
            # Use bucket-specific sample names
            self.sample_blob_names = get_sample_blob_names_for_bucket(bucket_name)

        def list_blobs(self, prefix, **kwargs):
            File = namedtuple('File', ['name'])
            # Guard against prefix=None: it is a valid "no prefix" value below.
            if prefix and 'fake' in prefix:
                logger.info(f'DummyGCSBucket.list_blobs bucket={self.name} prefix={prefix} -> [] (fake)')
                return []
            # Handle delimiter for non-recursive listing (only direct children)
            delimiter = kwargs.get('delimiter')
            if delimiter:
                filtered_names = filter_direct_children(self.sample_blob_names, prefix, delimiter)
                logger.info(
                    f'DummyGCSBucket.list_blobs bucket={self.name} prefix={prefix} delimiter={delimiter} -> {filtered_names}'
                )
                return [File(name) for name in filtered_names]
            result = [name for name in self.sample_blob_names if prefix is None or name.startswith(prefix)]
            logger.info(f'DummyGCSBucket.list_blobs bucket={self.name} prefix={prefix} -> {result}')
            return [File(name) for name in result]

        def blob(self, key):
            logger.info(f'DummyGCSBucket.blob bucket={self.name} key={key}')
            return DummyGCSBlob(self.name, key, self.is_json, self.is_multitask)

    class DummyGCSClient:
        """Fake ``google.cloud.storage.Client``.

        Bucket behavior is derived from the bucket name: a ``_JSON`` suffix makes
        blobs return JSON, and a ``multitask_`` prefix makes that JSON a task list.
        """

        def get_bucket(self, bucket_name):
            is_json = bucket_name.endswith('_JSON')
            is_multitask = bucket_name.startswith('multitask_')
            logger.info(
                f'DummyGCSClient.get_bucket bucket={bucket_name} is_json={is_json} is_multitask={is_multitask}'
            )
            return DummyGCSBucket(bucket_name, is_json, is_multitask)

        def list_blobs(self, bucket_name, prefix, delimiter=None):
            is_json = bucket_name.endswith('_JSON')
            is_multitask = bucket_name.startswith('multitask_')
            sample_blob_names = get_sample_blob_names_for_bucket(bucket_name)
            # Handle delimiter for non-recursive listing (only direct children)
            if delimiter:
                filtered_names = filter_direct_children(sample_blob_names, prefix, delimiter)
                logger.info(
                    f'DummyGCSClient.list_blobs bucket={bucket_name} prefix={prefix} delimiter={delimiter} -> {filtered_names}'
                )
                return [DummyGCSBlob(bucket_name, name, is_json, is_multitask) for name in filtered_names]
            result = [name for name in sample_blob_names if prefix is None or name.startswith(prefix)]
            logger.info(f'DummyGCSClient.list_blobs bucket={bucket_name} prefix={prefix} -> {result}')
            return [DummyGCSBlob(bucket_name, name, is_json, is_multitask) for name in result]

    with mock.patch.object(google_storage, 'Client', return_value=DummyGCSClient()):
        logger.info('gcs_client_mock installed')
        yield google_storage


@contextmanager
def azure_client_mock(sample_json_contents=None, sample_blob_names=None):
    """Patch the Azure blob client factory and ``generate_blob_sas`` in ``io_storages.azure_blob.models``.

    Be careful: this is a global context manager (sample_json_contents,
    sample_blob_names). It affects every test that uses the Azure client while
    it is active, and it may lead to flaky tests if the sample blob names are
    not deterministic.

    Args:
        sample_json_contents: Object returned (JSON-encoded) by every blob's content; a default dict when falsy.
        sample_blob_names: Names listed by every container; ``['abc', 'def', 'ghi']`` when falsy.
    """
    from collections import namedtuple

    from io_storages.azure_blob import models

    File = namedtuple('File', ['name'])

    sample_json_contents = sample_json_contents or {
        'str_field': 'test',
        'int_field': 123,
        'dict_field': {'one': 'wow', 'two': 456},
    }
    sample_blob_names = sample_blob_names or ['abc', 'def', 'ghi']

    class DummyAzureBlob:
        def __init__(self, container_name, key):
            self.key = key
            self.container_name = container_name

        def download_as_string(self):
            return f'test_blob_{self.key}'

        def upload_blob(self, string, overwrite):
            print(f'String {string} uploaded to bucket {self.container_name}')

        def generate_signed_url(self, **kwargs):
            return f'https://storage.googleapis.com/{self.container_name}/{self.key}'

        def content_as_text(self):
            return json.dumps(sample_json_contents)

        def content_as_bytes(self):
            return json.dumps(sample_json_contents).encode('utf-8')

    class DummyAzureContainer:
        def __init__(self, container_name, **kwargs):
            self.name = container_name
            self.sample_blob_names = deepcopy(sample_blob_names)

        # NOTE(review): name_starts_with / delimiter are ignored; every sample name is always listed.
        def list_blobs(self, name_starts_with):
            return [File(name) for name in self.sample_blob_names]

        def walk_blobs(self, name_starts_with, delimiter):
            return [File(name) for name in self.sample_blob_names]

        def get_blob_client(self, key):
            return DummyAzureBlob(self.name, key)

        def get_container_properties(self, **kwargs):
            return SimpleNamespace(
                name='test-container',
                last_modified='2022-01-01 01:01:01',
                etag='test-etag',
                lease='test-lease',
                public_access='public',
                has_immutability_policy=True,
                has_legal_hold=True,
                immutable_storage_with_versioning_enabled=True,
                metadata={'key': 'value'},
                encryption_scope='test-scope',
                deleted=False,
                version='1.0.0',
            )

        def download_blob(self, key):
            return DummyAzureBlob(self.name, key)

    class DummyAzureClient:
        def get_container_client(self, container_name):
            return DummyAzureContainer(container_name)

    with mock.patch.object(models.BlobServiceClient, 'from_connection_string', return_value=DummyAzureClient()):
        with mock.patch.object(models, 'generate_blob_sas', return_value='token'):
            yield


@contextmanager
def redis_client_mock():
    """Make ``RedisStorageMixin.get_redis_connection`` return an in-memory FakeRedis; yields that instance."""
    from fakeredis import FakeRedis
    from io_storages.redis.models import RedisStorageMixin

    redis = FakeRedis(decode_responses=True)
    # TODO: add mocked redis data

    with mock.patch.object(RedisStorageMixin, 'get_redis_connection', return_value=redis):
        yield redis


def upload_data(client, project, tasks):
    """POST ``tasks`` (serialized with their annotations) to the project's bulk task endpoint."""
    tasks = TaskWithAnnotationsSerializer(tasks, many=True).data
    data = [{'data': task['data'], 'annotations': task['annotations']} for task in tasks]
    return client.post(f'/api/projects/{project.id}/tasks/bulk', data=data, content_type='application/json')


def make_project(config, user, use_ml_backend=True, team_id=None, org=None):
    """Create a Project from ``config`` owned by ``user``.

    If ``org`` is None, the first organization created by ``user`` is used.
    When ``use_ml_backend`` is True, an MLBackend at http://localhost:8999 is attached.
    ``team_id`` is accepted for compatibility and not used here.
    """
    if org is None:
        org = Organization.objects.filter(created_by=user).first()
    project = Project.objects.create(created_by=user, organization=org, **config)
    if use_ml_backend:
        MLBackend.objects.create(project=project, url='http://localhost:8999')

    return project


# Marks on fixtures have no effect (deprecated in pytest 8), so no django_db mark here;
# database access must come from the requesting test / its fixtures (e.g. business_client).
@pytest.fixture
def project_id(business_client):
    """Create a project through the API with ``business_client`` and return its id."""
    payload = dict(
        title='test_project',
        label_config='',
    )
    response = business_client.post(
        '/api/projects/',
        data=json.dumps(payload),
        content_type='application/json',
    )
    return response.json()['id']


def make_task(config, project):
    """Create a Task in ``project`` with ``overlap`` set to the project's ``maximum_annotations``."""
    from tasks.models import Task

    return Task.objects.create(project=project, overlap=project.maximum_annotations, **config)


def create_business(user):
    """Placeholder hook: does nothing and returns None in this module."""
    return None


def make_annotation(config, task_id):
    """Create an Annotation for the task with pk ``task_id`` (project taken from the task)."""
    from tasks.models import Annotation, Task

    task = Task.objects.get(pk=task_id)

    return Annotation.objects.create(project_id=task.project_id, task_id=task_id, **config)


def make_prediction(config, task_id):
    """Create a Prediction for the task with pk ``task_id`` (project taken from the task)."""
    from tasks.models import Prediction, Task

    task = Task.objects.get(pk=task_id)

    return Prediction.objects.create(task_id=task_id, project=task.project, **config)


def make_annotator(config, project, login=False, client=None):
    """Create a user from ``config`` (password '12345') and add it as a collaborator on ``project``.

    When ``login`` is True, an organization is created for the user and the user is
    signed in on ``client`` (a new ``Client`` if None). The client is returned with
    ``client.annotator`` set to the user. Otherwise the user itself is returned.
    """
    user = User.objects.create(**config)
    user.set_password('12345')
    user.save()

    create_business(user)

    if login:
        Organization.create_organization(created_by=user, title=user.first_name)

        if client is None:
            client = Client()
        signin_status_code = signin(client, config['email'], '12345').status_code
        assert signin_status_code == 302, f'Sign-in status code: {signin_status_code}'

    project.add_collaborator(user)
    if login:
        client.annotator = user
        return client
    return user


def invite_client_to_project(client, project):
    """Follow the annotator invite link when the ``annotators`` app is installed; otherwise fake a 200 response."""
    if apps.is_installed('annotators'):
        return client.get(f'/annotator/invites/{project.token}/')
    else:
        return SimpleNamespace(status_code=200)


def login(client, email, password):
    """Log in if a user with ``email`` exists, otherwise sign up; asserts a 302 redirect either way."""
    if User.objects.filter(email=email).exists():
        r = client.post('/user/login/', data={'email': email, 'password': password})
        assert r.status_code == 302, r.status_code
    else:
        r = client.post('/user/signup/', data={'email': email, 'password': password, 'title': 'Whatever'})
        assert r.status_code == 302, r.status_code


def signin(client, email, password):
    """POST the login form and return the response."""
    return client.post('/user/login/', data={'email': email, 'password': password})


def signout(client):
    """GET the logout URL and return the response."""
    return client.get('/logout')


def _client_is_annotator(client):
    """Return True if the client's user email contains 'annotator'."""
    return 'annotator' in client.user.email


def save_response(response):
    """Dump ``response.json()`` to ``<TEST_DATA_ROOT>/tavern-output.json``."""
    fp = os.path.join(settings.TEST_DATA_ROOT, 'tavern-output.json')
    with open(fp, 'w') as f:
        json.dump(response.json(), f)


def os_independent_path(_, path, add_tempdir=False):
    """Return a Box with OS-normalized forms of ``path``.

    The Box contains the path itself, its parent, and the system temp directory.
    When ``add_tempdir`` is True, ``path`` is placed under the temp directory.
    The first positional argument is ignored.
    """
    normalized = Path(path)
    if add_tempdir:
        tempdir = Path(tempfile.gettempdir())
        normalized = tempdir / normalized

    normalized_parent = normalized.parent
    return Box(
        {
            'os_independent_path': str(normalized),
            'os_independent_path_parent': str(normalized_parent),
            'os_independent_path_tmpdir': str(Path(tempfile.gettempdir())),
        }
    )


def verify_docs(response):
    """Assert that no operation in the OpenAPI ``paths`` of ``response`` still carries the 'api' tag."""
    for _, path in response.json()['paths'].items():
        print(path)
        for _, method in path.items():
            print(method)
            if isinstance(method, dict):
                assert 'api' not in method['tags'], f'Need docs for API method {method}'


def empty_list(response):
    """Assert that the JSON body of ``response`` is empty."""
    assert len(response.json()) == 0, f'Response should be empty, but is {response.json()}'


def save_export_file_path(response):
    """Return a Box with the file path of the Export whose id is in ``response``."""
    export_id = response.json().get('id')
    export = Export.objects.get(id=export_id)
    file_path = export.file.path
    return Box({'file_path': file_path})


def save_convert_file_path(response, export_id=None):
    """Return a Box with the file path of the first converted format of the first export in ``response``.

    ``convert_file_path`` is None when the converted format has no file
    (``FieldFile.path`` raises ValueError). ``export_id`` is not used.
    """
    export = response.json()[0]
    convert = export['converted_formats'][0]

    converted = ConvertedFormat.objects.get(id=convert['id'])

    dir_path = os.path.join(settings.MEDIA_ROOT, settings.DELAYED_EXPORT_DIR)
    # Result unused; raises FileNotFoundError if the delayed-export directory is missing.
    os.listdir(dir_path)

    try:
        file_path = converted.file.path
        return Box({'convert_file_path': file_path})
    except ValueError:
        return Box({'convert_file_path': None})


def file_exists_in_storage(response, exists=True, file_path=None):
    """Assert that ``file_path`` (or, if not given, the Export file from ``response``) exists iff ``exists``."""
    if not file_path:
        export_id = response.json().get('id')
        export = Export.objects.get(id=export_id)
        file_path = export.file.path

    assert os.path.isfile(file_path) == exists


def mock_feature_flag(flag_name: str, value: bool, parent_module: str = 'core.feature_flags'):
    """Decorator to mock a feature flag state for a test function.

    Args:
        flag_name: Name of the feature flag to mock
        value: True or False to set the flag state
        parent_module: Module path containing the flag_set function to patch
    """

    def decorator(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            def fake_flag_set(feature_flag, *flag_args, **flag_kwargs):
                if feature_flag == flag_name:
                    return value
                # Other flags delegate to the real implementation imported at module load time.
                return flag_set(feature_flag, *flag_args, **flag_kwargs)

            with mock.patch(f'{parent_module}.flag_set', wraps=fake_flag_set):
                return func(*args, **kwargs)

        return wrapper

    return decorator