Bin
2025-12-17 1442f92732d7c5311a627a7ba3aaa0bb8ffc539f
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
"""This file and its contents are licensed under the Apache License 2.0. Please see the included NOTICE for copyright information and LICENSE for a copy of the license.
"""
import logging
from typing import List
 
from django.conf import settings
from django.db import models
from django.utils.translation import gettext_lazy as _
from tasks.models import PredictionMeta
 
logger = logging.getLogger(__name__)
 
 
class ModelProviders(models.TextChoices):
    OPENAI = 'OpenAI', _('OpenAI')
    AZURE_OPENAI = 'AzureOpenAI', _('AzureOpenAI')
    AZURE_AI_FOUNDRY = 'AzureAIFoundry', _('AzureAIFoundry')
    VERTEX_AI = 'VertexAI', _('VertexAI')
    GEMINI = 'Gemini', _('Gemini')
    ANTHROPIC = 'Anthropic', _('Anthropic')
    CUSTOM = 'Custom', _('Custom')
 
 
class ModelProviderConnectionScopes(models.TextChoices):
    ORG = 'Organization', _('Organization')
    USER = 'User', _('User')
    MODEL = 'Model', _('Model')
 
 
class ModelProviderConnection(models.Model):
 
    provider = models.CharField(max_length=255, choices=ModelProviders.choices, default=ModelProviders.OPENAI)
 
    api_key = models.TextField(_('api_key'), null=True, blank=True, help_text='Model provider API key')
 
    auth_token = models.TextField(_('auth_token'), null=True, blank=True, help_text='Model provider Auth token')
 
    deployment_name = models.CharField(max_length=512, null=True, blank=True, help_text='Azure OpenAI deployment name')
 
    endpoint = models.CharField(max_length=512, null=True, blank=True, help_text='Azure OpenAI endpoint')
 
    google_application_credentials = models.TextField(
        _('google application credentials'),
        null=True,
        blank=True,
        help_text='The content of GOOGLE_APPLICATION_CREDENTIALS json file',
    )
 
    google_project_id = models.CharField(
        _('google project id'), max_length=255, null=True, blank=True, help_text='Google project ID'
    )
 
    google_location = models.CharField(
        _('google location'), max_length=255, null=True, blank=True, help_text='Google project location'
    )
 
    cached_available_models = models.CharField(
        max_length=4096, null=True, blank=True, help_text='List of available models from the provider'
    )
 
    scope = models.CharField(
        max_length=255, choices=ModelProviderConnectionScopes.choices, default=ModelProviderConnectionScopes.ORG
    )
 
    organization = models.ForeignKey(
        'organizations.Organization', on_delete=models.CASCADE, related_name='model_provider_connections', null=True
    )
 
    created_by = models.ForeignKey(
        settings.AUTH_USER_MODEL,
        related_name='created_model_provider_connections',
        on_delete=models.SET_NULL,
        null=True,
    )
 
    # Future work - add foreign key for modelinterface / modelinstance
 
    created_at = models.DateTimeField(_('created at'), auto_now_add=True)
 
    updated_at = models.DateTimeField(_('updated at'), auto_now=True)
 
    is_internal = models.BooleanField(
        _('is_internal'),
        default=False,
        help_text='Whether the model provider connection is internal, not visible to the user',
        null=True,
        blank=True,
    )
 
    budget_limit = models.FloatField(
        _('budget_limit'),
        null=True,
        blank=True,
        default=None,
        help_text='Budget limit for the model provider connection (null if unlimited)',
    )
 
    budget_last_reset_date = models.DateTimeField(
        _('budget_last_reset_date'),
        null=True,
        blank=True,
        default=None,
        help_text='Date and time the budget was last reset',
    )
 
    budget_reset_period = models.CharField(
        _('budget_reset_period'),
        max_length=20,
        choices=[
            ('Monthly', 'Monthly'),
            ('Yearly', 'Yearly'),
        ],
        null=True,
        blank=True,
        default=None,
        help_text='Budget reset period for the model provider connection (null if not reset)',
    )
 
    budget_total_spent = models.FloatField(
        _('budget_total_spent'),
        null=True,
        blank=True,
        default=None,
        help_text='Tracked total budget spent for the given provider connection within the current budget period',
    )
 
    budget_alert_threshold = models.FloatField(
        _('budget_alert_threshold'),
        null=True,
        blank=True,
        default=None,
        help_text='Budget alert threshold for the given provider connection',
    )
 
    # Check if user is Admin or Owner
    # This will need to be updated if we ever use this model in LSO as `is_owner` and
    # `is_administrator` only exist in LSE
    def has_permission(self, user):
        return (
            user.is_administrator or user.is_owner or user.is_manager
        ) and user.active_organization_id == self.organization_id
 
    def update_budget_total_spent_from_predictions_meta(self, predictions_meta: List[PredictionMeta]):
        total_cost = sum(meta.total_cost or 0 for meta in predictions_meta)
        # opting for the goofy "self.budget_total_spent or 0" to avoid a db migration
        self.budget_total_spent = (self.budget_total_spent or 0) + total_cost
        self.save(update_fields=['budget_total_spent'])
 
    def has_reached_budget_limit(self):
        if (
            self.is_internal
            and self.budget_total_spent
            and self.budget_limit
            and self.budget_total_spent > self.budget_limit
        ):
            logger.info(
                f'Model connection {self.id} has reached the budget limit: '
                f'{self.budget_total_spent} > {self.budget_limit}'
            )
            return True
        return False