Skip to content
This repository was archived by the owner on Jun 5, 2025. It is now read-only.

don't delete models on provider update #836

Merged
merged 1 commit into from
Jan 30, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions src/codegate/db/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -469,14 +469,15 @@ async def add_provider_model(self, model: ProviderModel) -> ProviderModel:
added_model = await self._execute_update_pydantic_model(model, sql, should_raise=True)
return added_model

async def delete_provider_models(self, provider_id: str):
async def delete_provider_model(self, provider_id: str, model: str) -> Optional[ProviderModel]:
sql = text(
"""
DELETE FROM provider_models
WHERE provider_endpoint_id = :provider_endpoint_id
WHERE provider_endpoint_id = :provider_endpoint_id AND name = :name
"""
)
conditions = {"provider_endpoint_id": provider_id}

conditions = {"provider_endpoint_id": provider_id, "name": model}
await self._execute_with_no_return(sql, conditions)

async def delete_muxes_by_workspace(self, workspace_id: str):
Expand Down
32 changes: 23 additions & 9 deletions src/codegate/providers/crud/crud.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,26 +141,40 @@ async def update_endpoint(
except Exception as err:
raise ValueError("Unable to get models from provider: {}".format(str(err)))

# Reset all provider models.
await self._db_writer.delete_provider_models(str(endpoint.id))
models_set = set(models)

for model in models:
# Get the models from the provider
models_in_db = await self._db_reader.get_provider_models_by_provider_id(str(endpoint.id))

models_in_db_set = set(model.name for model in models_in_db)

# Add the models that are in the provider but not in the DB
for model in models_set - models_in_db_set:
await self._db_writer.add_provider_model(
dbmodels.ProviderModel(
provider_endpoint_id=founddbe.id,
name=model,
)
)

# Remove the models that are in the DB but not in the provider
for model in models_in_db_set - models_set:
await self._db_writer.delete_provider_model(
founddbe.id,
model,
)

dbendpoint = await self._db_writer.update_provider_endpoint(endpoint.to_db_model())

await self._db_writer.push_provider_auth_material(
dbmodels.ProviderAuthMaterial(
provider_endpoint_id=dbendpoint.id,
auth_type=endpoint.auth_type,
auth_blob=endpoint.api_key if endpoint.api_key else "",
# If an API key was provided or we've changed the auth type, we update the auth material
if endpoint.auth_type != founddbe.auth_type or endpoint.api_key:
await self._db_writer.push_provider_auth_material(
dbmodels.ProviderAuthMaterial(
provider_endpoint_id=dbendpoint.id,
auth_type=endpoint.auth_type,
auth_blob=endpoint.api_key if endpoint.api_key else "",
)
)
)

return apimodelsv1.ProviderEndpoint.from_db_model(dbendpoint)

Expand Down
Loading