# Generated by Django 3.2.25 on 2024-09-12 21:59

from django.db import migrations, models, transaction
import django.db.models.deletion
import django_migration_linter as linter

from core.redis import start_job_async_or_sync
# NOTE(review): these are the *current* models, not historical ones from
# `apps.get_model`. Presumably this is because the fill may be executed as a
# background job (see `forwards`), where the migration's `apps` registry is not
# available — confirm. Queries below must therefore only touch fields that
# exist both at this migration and in the current model definitions.
from ml_models.models import ThirdPartyModelVersion
from ml_model_providers.models import ModelProviderConnection, ModelProviders


def _fill_model_version_model_provider_connection(db_alias: str):
    """Link OpenAI / Azure OpenAI model versions to a provider connection.

    For every ``ThirdPartyModelVersion`` whose provider is OPENAI or
    AZURE_OPENAI, find a ``ModelProviderConnection`` in the same organization
    with the same provider (and, for Azure, whose ``deployment_name`` equals
    the version's ``provider_model_id``) and store its id in
    ``model_provider_connection_id``. Versions with no matching connection are
    set to ``None``.

    When several connections match, the one with the lowest id is chosen so
    that the result is deterministic.

    Args:
        db_alias: database alias to run all queries against.
    """
    for provider in [ModelProviders.OPENAI, ModelProviders.AZURE_OPENAI]:
        is_azure = provider == ModelProviders.AZURE_OPENAI
        # Many versions share the same (organization, deployment) pair; look
        # each pair up once instead of once per version.
        connection_id_by_key = {}

        this_provider_model_versions = (
            ThirdPartyModelVersion.objects.using(db_alias)
            .filter(provider=provider)
            .values('id', 'organization_id', 'provider_model_id')
        )
        for provider_model_version in this_provider_model_versions:
            organization_id = provider_model_version['organization_id']
            # Only Azure connections are additionally keyed by deployment name.
            deployment_name = provider_model_version['provider_model_id'] if is_azure else None
            cache_key = (organization_id, deployment_name)

            if cache_key not in connection_id_by_key:
                lookup = {'organization_id': organization_id, 'provider': provider}
                if is_azure:
                    # A None value here becomes an IS NULL lookup, as before.
                    lookup['deployment_name'] = deployment_name
                connection_id_by_key[cache_key] = (
                    ModelProviderConnection.objects.using(db_alias)
                    .filter(**lookup)
                    .order_by('id')  # deterministic pick when several match
                    .values_list('id', flat=True)
                    .first()
                )

            ThirdPartyModelVersion.objects.using(db_alias).filter(id=provider_model_version['id']).update(
                model_provider_connection_id=connection_id_by_key[cache_key]
            )


def forwards(apps, schema_editor):
    """Populate ``model_provider_connection`` for existing model versions.

    The work is handed to ``start_job_async_or_sync``. Depending on its
    configuration, the fill may run after ``migrate`` returns — presumably;
    confirm in ``core.redis``.
    """
    db_alias = schema_editor.connection.alias
    start_job_async_or_sync(_fill_model_version_model_provider_connection, db_alias=db_alias)


def backwards(apps, schema_editor):
    """No-op: reversing ``AddField`` drops the column, so there is nothing to undo."""
    pass


class Migration(migrations.Migration):
    # Not wrapped in a single transaction, so the AddField below commits
    # independently of the RunPython data fill.
    atomic = False

    dependencies = [
        ('ml_model_providers', '0003_modelproviderconnection_cached_available_models'),
        ('ml_models', '0010_modelinterface_skill_name'),
    ]

    operations = [
        linter.IgnoreMigration(),
        migrations.AddField(
            model_name='thirdpartymodelversion',
            name='model_provider_connection',
            field=models.ForeignKey(
                null=True,
                on_delete=django.db.models.deletion.SET_NULL,
                related_name='model_versions',
                to='ml_model_providers.modelproviderconnection',
            ),
        ),
        migrations.RunPython(forwards, backwards),
    ]