"""This file and its contents are licensed under the Apache License 2.0. Please see the included NOTICE for copyright information and LICENSE for a copy of the license.
"""
import logging
from typing import List

from django.conf import settings
from django.db import models
from django.utils.translation import gettext_lazy as _
from tasks.models import PredictionMeta

logger = logging.getLogger(__name__)


class ModelProviders(models.TextChoices):
    """Model provider backends a connection can target (stored value equals its label)."""

    OPENAI = 'OpenAI', _('OpenAI')
    AZURE_OPENAI = 'AzureOpenAI', _('AzureOpenAI')
    AZURE_AI_FOUNDRY = 'AzureAIFoundry', _('AzureAIFoundry')
    VERTEX_AI = 'VertexAI', _('VertexAI')
    GEMINI = 'Gemini', _('Gemini')
    ANTHROPIC = 'Anthropic', _('Anthropic')
    CUSTOM = 'Custom', _('Custom')


class ModelProviderConnectionScopes(models.TextChoices):
    """Allowed values for ``ModelProviderConnection.scope``."""

    ORG = 'Organization', _('Organization')
    USER = 'User', _('User')
    MODEL = 'Model', _('Model')


class ModelProviderConnection(models.Model):
    """Credentials and settings for one external model provider, plus optional budget tracking.

    A connection belongs to an organization, records who created it, and can
    optionally track spend against a budget limit (see
    ``update_budget_total_spent_from_predictions_meta`` and
    ``has_reached_budget_limit``).
    """

    provider = models.CharField(max_length=255, choices=ModelProviders.choices, default=ModelProviders.OPENAI)

    # Credentials / endpoint settings. Which of these are needed presumably depends on
    # `provider` (the help texts mention Azure OpenAI and Google) -- TODO confirm against
    # the code that consumes them.
    api_key = models.TextField(_('api_key'), null=True, blank=True, help_text='Model provider API key')
    auth_token = models.TextField(_('auth_token'), null=True, blank=True, help_text='Model provider Auth token')
    deployment_name = models.CharField(max_length=512, null=True, blank=True, help_text='Azure OpenAI deployment name')
    endpoint = models.CharField(max_length=512, null=True, blank=True, help_text='Azure OpenAI endpoint')
    google_application_credentials = models.TextField(
        _('google application credentials'),
        null=True,
        blank=True,
        help_text='The content of GOOGLE_APPLICATION_CREDENTIALS json file',
    )
    google_project_id = models.CharField(
        _('google project id'), max_length=255, null=True, blank=True, help_text='Google project ID'
    )
    google_location = models.CharField(
        _('google location'), max_length=255, null=True, blank=True, help_text='Google project location'
    )

    # Stored as a single string (max 4096 chars); the serialization format of the list is
    # not visible here -- TODO confirm with the code that writes it.
    cached_available_models = models.CharField(
        max_length=4096, null=True, blank=True, help_text='List of available models from the provider'
    )

    scope = models.CharField(
        max_length=255, choices=ModelProviderConnectionScopes.choices, default=ModelProviderConnectionScopes.ORG
    )

    organization = models.ForeignKey(
        'organizations.Organization', on_delete=models.CASCADE, related_name='model_provider_connections', null=True
    )

    created_by = models.ForeignKey(
        settings.AUTH_USER_MODEL,
        related_name='created_model_provider_connections',
        on_delete=models.SET_NULL,
        null=True,
    )

    # Future work - add foreign key for modelinterface / modelinstance

    created_at = models.DateTimeField(_('created at'), auto_now_add=True)
    updated_at = models.DateTimeField(_('updated at'), auto_now=True)

    is_internal = models.BooleanField(
        _('is_internal'),
        default=False,
        help_text='Whether the model provider connection is internal, not visible to the user',
        null=True,
        blank=True,
    )

    # Budget tracking. All budget fields default to NULL; a NULL budget_limit means
    # "unlimited". The limit is only enforced for internal connections
    # (see has_reached_budget_limit).
    budget_limit = models.FloatField(
        _('budget_limit'),
        null=True,
        blank=True,
        default=None,
        help_text='Budget limit for the model provider connection (null if unlimited)',
    )

    budget_last_reset_date = models.DateTimeField(
        _('budget_last_reset_date'),
        null=True,
        blank=True,
        default=None,
        help_text='Date and time the budget was last reset',
    )

    budget_reset_period = models.CharField(
        _('budget_reset_period'),
        max_length=20,
        choices=[
            ('Monthly', 'Monthly'),
            ('Yearly', 'Yearly'),
        ],
        null=True,
        blank=True,
        default=None,
        help_text='Budget reset period for the model provider connection (null if not reset)',
    )

    budget_total_spent = models.FloatField(
        _('budget_total_spent'),
        null=True,
        blank=True,
        default=None,
        help_text='Tracked total budget spent for the given provider connection within the current budget period',
    )

    budget_alert_threshold = models.FloatField(
        _('budget_alert_threshold'),
        null=True,
        blank=True,
        default=None,
        help_text='Budget alert threshold for the given provider connection',
    )

    def has_permission(self, user) -> bool:
        """Return whether ``user`` has permission on this connection.

        The user must be an administrator, owner, or manager, AND the user's active
        organization must be the organization this connection belongs to.

        NOTE: ``is_owner`` and ``is_administrator`` only exist in LSE; this check will
        need to be updated if this model is ever used in LSO.
        """
        return (
            user.is_administrator or user.is_owner or user.is_manager
        ) and user.active_organization_id == self.organization_id

    def update_budget_total_spent_from_predictions_meta(self, predictions_meta: List[PredictionMeta]) -> None:
        """Add the summed cost of ``predictions_meta`` to ``budget_total_spent`` and persist it.

        Entries whose ``total_cost`` is missing (``None``) contribute 0. Only the
        ``budget_total_spent`` column is written (``update_fields``).

        Args:
            predictions_meta: prediction metadata records whose ``total_cost`` is summed.
        """
        total_cost = sum(meta.total_cost or 0 for meta in predictions_meta)
        # budget_total_spent is nullable with a NULL default (changing that would require a
        # db migration), so a missing value is treated as 0 here.
        # NOTE(review): this is an in-memory read-modify-write; if two callers update the
        # same connection concurrently, one increment can be lost. Consider an F()-based
        # update if that can happen -- TODO confirm.
        self.budget_total_spent = (self.budget_total_spent or 0) + total_cost
        self.save(update_fields=['budget_total_spent'])

    def has_reached_budget_limit(self) -> bool:
        """Return True if this internal connection has spent more than its budget limit.

        - Only internal connections (``is_internal`` truthy) are ever limited.
        - ``budget_limit`` of ``None`` means unlimited. Any other value, including
          ``0``, is a real limit.
        - A missing ``budget_total_spent`` counts as 0.
        - The comparison is strict, so spending exactly the limit is not "reached".
        """
        # Explicit None check: a truthiness test would silently treat a limit of 0 as
        # "unlimited", contradicting the field's "null if unlimited" contract.
        if not self.is_internal or self.budget_limit is None:
            return False

        total_spent = self.budget_total_spent or 0
        if total_spent > self.budget_limit:
            logger.info(
                'Model connection %s has reached the budget limit: %s > %s',
                self.id,
                total_spent,
                self.budget_limit,
            )
            return True
        return False