Bin
2025-12-17 262fecaa75b2909ad244f12c3b079ed3ff4ae329
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
# Generated by Django 3.2.25 on 2024-09-12 21:59
 
from django.db import migrations, models, transaction
import django.db.models.deletion
import django_migration_linter as linter
from core.redis import start_job_async_or_sync
from ml_models.models import ThirdPartyModelVersion
from ml_model_providers.models import ModelProviderConnection, ModelProviders
 
 
def _fill_model_version_model_provider_connection(db_alias: str) -> None:
    """Link existing OpenAI / Azure OpenAI model versions to a provider connection.

    For each provider, every model version that is not linked yet is matched to a
    ``ModelProviderConnection`` with the same ``organization_id`` and ``provider``.
    For Azure OpenAI, the connection's ``deployment_name`` must also equal the
    version's ``provider_model_id``. When several connections match, the one with
    the lowest id is chosen so the outcome is deterministic.

    This uses the real model classes rather than the historical ``apps`` registry.
    NOTE(review): presumably because the function may be executed outside the
    migration by ``start_job_async_or_sync`` — confirm. Only explicitly listed
    columns are read (``values``/``values_list``) and a single column is written
    (``update``).

    Args:
        db_alias: Alias of the database to read from and write to.
    """
    for provider in [ModelProviders.OPENAI, ModelProviders.AZURE_OPENAI]:
        is_azure = provider == ModelProviders.AZURE_OPENAI

        # Only rows that are still unlinked: the job may run asynchronously or be
        # retried, and rows that already have a connection must not be overwritten.
        unlinked_versions = (
            ThirdPartyModelVersion.objects.using(db_alias)
            .filter(provider=provider, model_provider_connection__isnull=True)
            .values('id', 'organization_id', 'provider_model_id')
        )
        for version in unlinked_versions:
            connection_filter = {
                'organization_id': version['organization_id'],
                'provider': provider,
            }
            if is_azure:
                # Azure OpenAI versions are matched on deployment name as well.
                connection_filter['deployment_name'] = version['provider_model_id']

            # Explicit ordering makes the choice stable when several connections match.
            connection_id = (
                ModelProviderConnection.objects.using(db_alias)
                .filter(**connection_filter)
                .order_by('id')
                .values_list('id', flat=True)
                .first()
            )
            if connection_id is None:
                # The column is nullable and still NULL here; nothing to write.
                continue

            # Re-check the NULL condition at write time so a value set concurrently
            # (e.g. while the async job was running) is preserved.
            ThirdPartyModelVersion.objects.using(db_alias).filter(
                id=version['id'],
                model_provider_connection__isnull=True,
            ).update(model_provider_connection_id=connection_id)
 
def forwards(apps, schema_editor):
    """Forward step of the RunPython operation.

    Hands ``_fill_model_version_model_provider_connection`` to
    ``start_job_async_or_sync``, passing the alias of the database connection
    that is being migrated. ``apps`` is not used.
    """
    start_job_async_or_sync(
        _fill_model_version_model_provider_connection,
        db_alias=schema_editor.connection.alias,
    )
 
 
def backwards(apps, schema_editor):
    """Reverse step of the RunPython operation: intentionally does nothing.

    Unapplying the migration reverses the AddField operation, which removes the
    column holding the backfilled values, so no separate cleanup is needed.
    """
 
 
class Migration(migrations.Migration):
    """Add ThirdPartyModelVersion.model_provider_connection and backfill it."""

    # The migration is not wrapped in a single transaction.
    # NOTE(review): presumably because the RunPython backfill may be dispatched via
    # start_job_async_or_sync and needs the new column committed first — confirm.
    atomic = False

    dependencies = [
        ('ml_model_providers', '0003_modelproviderconnection_cached_available_models'),
        ('ml_models', '0010_modelinterface_skill_name'),
    ]

    operations = [
        # Tells django-migration-linter to skip checks for this migration.
        linter.IgnoreMigration(),
        # Nullable FK to ModelProviderConnection; deleting a connection sets this
        # field to NULL. The reverse accessor on the connection is `model_versions`.
        migrations.AddField(
            model_name='thirdpartymodelversion',
            name='model_provider_connection',
            field=models.ForeignKey(null=True, on_delete=django.db.models.deletion.SET_NULL, related_name='model_versions', to='ml_model_providers.modelproviderconnection'),
        ),
        # Backfill existing rows via forwards(); the reverse step (backwards) is a no-op.
        migrations.RunPython(forwards, backwards)
    ]