from django.test import TestCase
|
from django.urls import reverse
|
from django.utils import timezone
|
from django.utils.http import urlencode
|
from projects.tests.factories import ProjectFactory
|
from rest_framework.test import APIClient, APITestCase
|
from tasks.models import Task
|
from tasks.tests.factories import PredictionFactory, TaskFactory
|
|
|
class TestProjectCountsListAPI(TestCase):
    """Checks the ``projects:api:project-counts-list`` endpoint against two projects.

    Both projects belong to the same organization; the first holds two tasks
    and the second holds one. No annotations or predictions are created.
    """

    @classmethod
    def setUpTestData(cls):
        """Create the two projects and their three tasks once for the whole class."""
        cls.project_1 = ProjectFactory()
        # The second project reuses the first project's organization.
        cls.project_2 = ProjectFactory(organization=cls.project_1.organization)

        # Tasks are created in the same order as their numbering.
        task_specs = (
            (cls.project_1, 'Task 1'),
            (cls.project_1, 'Task 2'),
            (cls.project_2, 'Task 3'),
        )
        for owner, text in task_specs:
            Task.objects.create(project=owner, data={'text': text})

    def get_url(self, **params):
        """Return the counts-list URL with ``params`` encoded as its query string."""
        endpoint = reverse('projects:api:project-counts-list')
        query_string = urlencode(params)
        return f'{endpoint}?{query_string}'

    def test_get_counts(self):
        """The endpoint lists both projects with their task and prediction counters."""
        requested_fields = 'id,task_number,finished_task_number,total_predictions_number'

        api = APIClient()
        # Authenticate as the user that created the first project.
        api.force_authenticate(user=self.project_1.created_by)
        response = api.get(self.get_url(include=requested_fields))

        self.assertEqual(response.status_code, 200)
        payload = response.json()
        self.assertEqual(payload['count'], 2)

        expected = [
            {
                'id': project.id,
                'task_number': task_total,
                'finished_task_number': 0,
                'total_predictions_number': 0,
            }
            for project, task_total in ((self.project_1, 2), (self.project_2, 1))
        ]
        # Sort by id so the comparison does not depend on the response ordering.
        actual = sorted(payload['results'], key=lambda row: row['id'])
        self.assertEqual(actual, expected)
|
|
|
class TestProjectModelVersionsAPI(APITestCase):
    """Checks ``/api/projects/<id>/model-versions`` with and without query parameters.

    One task carries four predictions: two for ``model_1`` and one each for
    ``model_2`` and ``model_3``.
    """

    @classmethod
    def setUpTestData(cls):
        """Create a project, one task and four predictions across three model versions."""
        cls.project = ProjectFactory()
        cls.user = cls.project.created_by
        cls.task = TaskFactory(project=cls.project)

        # Creation order matters: model_1 (twice), then model_2, then model_3.
        cls.prediction_m1 = PredictionFactory(task=cls.task, model_version='model_1')
        cls.prediction_m1_2 = PredictionFactory(task=cls.task, model_version='model_1')
        cls.prediction_m2 = PredictionFactory(task=cls.task, model_version='model_2')
        cls.prediction_m3 = PredictionFactory(task=cls.task, model_version='model_3')

        # Refresh model_2's timestamp after model_3 was created, so the
        # "last used" ordering exercised below puts model_2 first.
        cls.prediction_m2.created_at = timezone.now()
        cls.prediction_m2.save()

    def _fetch(self, query=''):
        """Authenticate as the project creator and GET the endpoint with ``query`` appended."""
        self.client.force_authenticate(user=self.user)
        return self.client.get(f'/api/projects/{self.project.id}/model-versions{query}')

    def test_no_params(self):
        """Without parameters, every model version is returned with its prediction count."""
        response = self._fetch()
        assert response.status_code == 200

        expected_counts = {'model_2': 1, 'model_3': 1, 'model_1': 2}
        assert response.json() == expected_counts

    def test_limit(self):
        """``limit=2`` keeps only the ``model_2`` and ``model_3`` entries."""
        response = self._fetch('?limit=2')
        assert response.status_code == 200

        expected_counts = {'model_2': 1, 'model_3': 1}
        assert response.json() == expected_counts

    def test_extended(self):
        """``extended=true`` returns a null ``live`` entry and an ordered ``static`` list."""
        response = self._fetch('?extended=true')
        assert response.status_code == 200

        body = response.json()
        assert body['live'] is None

        static_entries = body['static']
        expected_sequence = [('model_2', 1), ('model_3', 1), ('model_1', 2)]
        for position, (version, count) in enumerate(expected_sequence):
            assert static_entries[position]['model_version'] == version
            assert static_entries[position]['count'] == count
|