"""This file and its contents are licensed under the Apache License 2.0. Please see the included NOTICE for copyright information and LICENSE for a copy of the license."""
|
|
import logging
|
|
from django.conf import settings
|
from django.db import models
|
from django.db.models import Q
|
from django.utils.translation import gettext_lazy as _
|
from ml_model_providers.models import ModelProviderConnection, ModelProviders
|
from projects.models import Project
|
from rest_framework.exceptions import ValidationError
|
from tasks.models import Annotation, FailedPrediction, Prediction, PredictionMeta
|
|
logger = logging.getLogger(__name__)
|
|
|
class SkillNames(models.TextChoices):
    """Categories of labeling tasks ("skills").

    Skills partition projects (label config + input columns + output columns)
    into categories of labeling tasks. Used as the choices of
    ``ModelInterface.skill_name``.
    """

    TEXT_CLASSIFICATION = 'TextClassification', _('TextClassification')
    NAMED_ENTITY_RECOGNITION = 'NamedEntityRecognition', _('NamedEntityRecognition')
|
|
|
def validate_string_list(value):
    """Validate that ``value`` is a non-empty list of strings.

    Attached as a field validator to JSON fields that must hold a list of
    strings (e.g. ``ModelInterface.input_fields`` / ``output_classes``).

    :param value: the value to validate.
    :raises ValidationError: (``rest_framework.exceptions.ValidationError``) if
        ``value`` is not a list, is an empty list, or contains a non-string item.
    """
    # Check the type first, so falsy non-lists (None, '', {}) get the accurate
    # "must be a list" message instead of "should not be empty".
    if not isinstance(value, list):
        raise ValidationError('Value must be a list')

    if not value:
        raise ValidationError('list should not be empty')

    if not all(isinstance(item, str) for item in value):
        raise ValidationError('All items in the list must be strings')
|
|
|
class ModelInterface(models.Model):
    """A named model definition owned by an organization.

    Describes what a model consumes (``input_fields``) and produces
    (``output_classes``) for a given skill. Concrete versions point back to it
    through ``ModelVersion.parent_model`` (related name ``model_versions``).
    """

    title = models.CharField(_('title'), max_length=500, null=False, blank=False, help_text='Model name')

    description = models.TextField(_('description'), null=True, blank=True, help_text='Model description')

    # Creator; set to NULL if the user is deleted.
    created_by = models.ForeignKey(
        settings.AUTH_USER_MODEL, related_name='created_models', on_delete=models.SET_NULL, null=True
    )

    created_at = models.DateTimeField(_('created at'), auto_now_add=True)

    updated_at = models.DateTimeField(_('updated at'), auto_now=True)

    # Owning organization; deleting the organization deletes its model interfaces.
    organization = models.ForeignKey(
        'organizations.Organization', on_delete=models.CASCADE, related_name='model_interfaces', null=True
    )

    # One of SkillNames (nullable).
    skill_name = models.CharField(max_length=255, choices=SkillNames.choices, null=True)

    # Lists of strings, checked by validate_string_list when field validators are run.
    # NOTE(review): the default (an empty list) is itself rejected by
    # validate_string_list — confirm callers always supply a value.
    input_fields = models.JSONField(default=list, validators=[validate_string_list])

    output_classes = models.JSONField(default=list, validators=[validate_string_list])

    # Projects this interface is linked to.
    associated_projects = models.ManyToManyField('projects.Project', blank=True)

    def has_permission(self, user):
        """Return True if ``user``'s active organization is this interface's organization.

        NOTE(review): ``organization`` is nullable, so a user without an active
        organization is granted access to an interface without one (None == None).
        Confirm that this is intended.
        """
        return user.active_organization == self.organization
|
|
|
class ModelVersion(models.Model):
    """Abstract base for a concrete, prompt-based version of a ModelInterface."""

    class Meta:
        abstract = True

    title = models.CharField(_('title'), max_length=500, null=False, blank=False, help_text='Model name')

    # Deleting the parent ModelInterface deletes its versions.
    parent_model = models.ForeignKey(ModelInterface, related_name='model_versions', on_delete=models.CASCADE)

    prompt = models.TextField(_('prompt'), null=False, blank=False, help_text='Prompt to execute')

    # Provider connection used by this version; set to NULL if the connection is deleted.
    model_provider_connection = models.ForeignKey(
        ModelProviderConnection, related_name='model_versions', on_delete=models.SET_NULL, null=True
    )

    @property
    def full_title(self):
        """Return ``'<parent model title>__<version title>'``."""
        return f'{self.parent_model.title}__{self.title}'

    def delete(self, *args, **kwargs):
        """
        Deletes Predictions associated with ModelVersion, then the ModelVersion itself.

        ModelRun rows are removed by the CASCADE on ``ModelRun.model_version``,
        but Django's cascade does not call ``ModelRun.delete()``. The prediction
        cleanup of each run is therefore triggered explicitly here.

        :return: the value returned by ``Model.delete()`` (number of deleted
            objects and a per-model breakdown).
        """
        model_runs = ModelRun.objects.filter(model_version=self.id)
        for model_run in model_runs:
            model_run.delete_predictions()
        return super().delete(*args, **kwargs)
|
|
|
class ThirdPartyModelVersion(ModelVersion):
    """A ModelVersion that uses a model from a third-party provider (e.g. OpenAI)."""

    provider = models.CharField(
        max_length=255,
        choices=ModelProviders.choices,
        default=ModelProviders.OPENAI,
        help_text='The model provider to use e.g. OpenAI',
    )

    provider_model_id = models.CharField(
        max_length=255,
        blank=False,
        null=False,
        help_text='The model ID to use within the given provider, e.g. gpt-3.5',
    )

    # Creator; set to NULL if the user is deleted.
    created_by = models.ForeignKey(
        settings.AUTH_USER_MODEL,
        related_name='created_third_party_model_versions',
        on_delete=models.SET_NULL,
        null=True,
    )

    created_at = models.DateTimeField(_('created at'), auto_now_add=True)

    updated_at = models.DateTimeField(_('updated at'), auto_now=True)

    # Owning organization; deleting the organization deletes its model versions.
    organization = models.ForeignKey(
        'organizations.Organization', on_delete=models.CASCADE, related_name='third_party_model_versions', null=True
    )

    @property
    def project(self):
        """Return the first project associated with the parent model, or None if there is none.

        TODO: can it be just a property of the model version?
        """
        parent = self.parent_model
        if not parent:
            return None
        # ``first()`` already returns None for an empty queryset, so a separate
        # ``exists()`` query beforehand is unnecessary.
        return parent.associated_projects.first()

    def has_permission(self, user):
        """Return True if ``user``'s active organization is this version's organization."""
        return user.active_organization == self.organization
|
|
|
class ModelRun(models.Model):
    """A run of a ThirdPartyModelVersion over a Project.

    Tracks the run's status, the inference job ID and prediction/task counters.
    """

    class ProjectSubset(models.TextChoices):
        ALL = 'All', _('All')
        HASGT = 'HasGT', _('HasGT')
        SAMPLE = 'Sample', _('Sample')

    class FileType(models.TextChoices):
        INPUT = 'Input', _('Input')
        OUTPUT = 'Output', _('Output')

    class ModelRunStatus(models.TextChoices):
        # All labels are wrapped in gettext_lazy for consistency; the label
        # strings are unchanged, so no migration is needed.
        PENDING = 'Pending', _('Pending')
        INPROGRESS = 'InProgress', _('InProgress')
        COMPLETED = 'Completed', _('Completed')
        FAILED = 'Failed', _('Failed')
        CANCELED = 'Canceled', _('Canceled')

    organization = models.ForeignKey(
        'organizations.Organization', on_delete=models.CASCADE, related_name='model_runs', null=True
    )

    project = models.ForeignKey(Project, on_delete=models.CASCADE, related_name='model_runs')

    model_version = models.ForeignKey(ThirdPartyModelVersion, on_delete=models.CASCADE, related_name='model_runs')

    created_by = models.ForeignKey(
        settings.AUTH_USER_MODEL,
        related_name='model_runs',
        on_delete=models.SET_NULL,
        null=True,
    )

    project_subset = models.CharField(max_length=255, choices=ProjectSubset.choices, default=ProjectSubset.HASGT)

    status = models.CharField(max_length=255, choices=ModelRunStatus.choices, default=ModelRunStatus.PENDING)

    job_id = models.CharField(
        max_length=255,
        null=True,
        blank=True,
        default=None,
        help_text='Job ID for inference job for a ModelRun e.g. Adala job ID',
    )

    total_predictions = models.IntegerField(_('total predictions'), default=0)

    total_correct_predictions = models.IntegerField(_('total correct predictions'), default=0)

    total_tasks = models.IntegerField(_('total tasks'), default=0)

    created_at = models.DateTimeField(_('created at'), auto_now_add=True)

    triggered_at = models.DateTimeField(_('triggered at'), null=True, default=None)

    predictions_updated_at = models.DateTimeField(_('predictions updated at'), null=True, default=None)

    completed_at = models.DateTimeField(_('completed at'), null=True, default=None)

    def has_permission(self, user):
        """Return True if ``user``'s active organization is this run's organization."""
        return user.active_organization == self.organization

    def delete_predictions(self):
        """
        Deletes any predictions (and failed predictions) that have originated from a ModelRun.

        A raw SQL delete is executed here for speed. It ignores any foreign key relationships,
        so if another model has a Prediction fk set to on_delete=CASCADE, for example,
        it will not take effect. The only relationship like this that currently exists
        is Annotation.parent_prediction, which is handled here. Dependent PredictionMeta
        rows and (when the optional stats app is available) PredictionStats /
        PredictionPairStats rows are deleted first.
        """
        predictions = Prediction.objects.filter(model_run=self.id)
        # Only primary keys are needed; avoid loading full Prediction rows.
        prediction_ids = list(predictions.values_list('id', flat=True))

        # Detach annotations that reference these predictions, because the raw delete below
        # will not cascade or null them out.
        Annotation.objects.filter(parent_prediction__in=prediction_ids).update(parent_prediction=None)

        try:
            from stats.models import PredictionPairStats, PredictionStats

            PredictionStats.objects.filter(prediction_to__in=prediction_ids).delete()
            PredictionPairStats.objects.filter(
                Q(prediction_to_id__in=prediction_ids) | Q(prediction_from_id__in=prediction_ids)
            ).delete()
        except ImportError as e:
            # The stats app is optional; without it there is nothing to clean up.
            logger.info(f'PredictionStats or PredictionPairStats model does not exist , exception:{e}')
        except Exception:
            # Best-effort cleanup: still delete the predictions themselves, but do not
            # misreport an unexpected failure as a missing model.
            logger.warning('Failed to delete prediction stats for model run %s', self.id, exc_info=True)

        # Delete failed predictions. Currently no other model references this, so there are
        # no fk relationships to remove.
        failed_predictions = FailedPrediction.objects.filter(model_run=self.id)
        failed_prediction_ids = list(failed_predictions.values_list('id', flat=True))

        # Delete predictions meta, which references both kinds of prediction.
        PredictionMeta.objects.filter(prediction__in=prediction_ids).delete()
        PredictionMeta.objects.filter(failed_prediction__in=failed_prediction_ids).delete()

        # Remove predictions from the db without Django's cascade collector.
        predictions._raw_delete(predictions.db)
        failed_predictions._raw_delete(failed_predictions.db)

    def delete(self, *args, **kwargs):
        """
        Deletes Predictions associated with ModelRun, then the ModelRun itself.

        :return: the value returned by ``Model.delete()`` (number of deleted
            objects and a per-model breakdown).
        """
        self.delete_predictions()
        return super().delete(*args, **kwargs)
|