Added a class which performs semantic routing #1192

Merged 2 commits on Mar 4, 2025. (The repository was archived by the owner on Jun 5, 2025 and is now read-only.)
@@ -0,0 +1,50 @@
"""add persona table

Revision ID: 02b710eda156
Revises: 5e5cd2288147
Create Date: 2025-03-03 10:08:16.206617+00:00

"""

from typing import Sequence, Union

from alembic import op

# revision identifiers, used by Alembic.
revision: str = "02b710eda156"
down_revision: Union[str, None] = "5e5cd2288147"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None


def upgrade() -> None:
# Begin transaction
op.execute("BEGIN TRANSACTION;")

op.execute(
"""
CREATE TABLE IF NOT EXISTS personas (
id TEXT PRIMARY KEY, -- UUID stored as TEXT
name TEXT NOT NULL UNIQUE,
description TEXT NOT NULL,

Contributor:

Do we want to make descriptions unique as well? If someone adds two similar descriptions, it would be very hard for the matcher to work properly. Perhaps enforcing uniqueness is the way to go as a first step, and in a further iteration we could check for description similarity. wdyt?

Contributor Author:

That's a really nice suggestion! But actually, making the descriptions unique won't cut it: if the difference is a single letter, we would still accept the new description. What I will do is check the cosine distance to the existing descriptions and only accept a new persona if it's sufficiently different (a sketch of this check follows the migration below). Will upload a commit soon

Contributor:

cool beans

description_embedding BLOB NOT NULL
);
"""

Contributor:

I think it's a good idea to have personas not be namespaced within a workspace since this allows us to share personas between workspaces. Do you think we should also add a namespaced persona concept? This is not a blocker and if we decide a namespaced persona makes sense, this can be left as a TODO for another PR.

Contributor Author:

Uuhm, probably the concept of persona namespaces makes sense, although right now I can't think of what the difference would be wrt. our workspaces. In other words, I like the idea but lack the use cases atm. Let's introduce them when we need them

)

# Finish transaction
op.execute("COMMIT;")


def downgrade() -> None:
# Begin transaction
op.execute("BEGIN TRANSACTION;")

op.execute(
"""
DROP TABLE personas;
"""
)

# Finish transaction
op.execute("COMMIT;")
3 changes: 3 additions & 0 deletions src/codegate/config.py
@@ -57,6 +57,9 @@ class Config:
force_certs: bool = False

max_fim_hash_lifetime: int = 60 * 5 # Time in seconds. Default is 5 minutes.
# Min value is 0 (identical direction), max value is 2 (opposite direction); orthogonal vectors score 1
# The value 0.75 was found through experimentation. See /tests/muxing/test_semantic_router.py
persona_threshold: float = 0.75

# Provider URLs with defaults
provider_urls: Dict[str, str] = field(default_factory=lambda: DEFAULT_PROVIDER_URLS.copy())
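
To make the comment on persona_threshold concrete, here is a small illustration of the cosine-distance scale and of how the threshold would gate a match. The cosine_distance helper below merely mirrors what sqlite-vec's vec_distance_cosine computes and is not part of this PR:

import numpy as np

def cosine_distance(a: np.ndarray, b: np.ndarray) -> float:
    # 1 - cosine similarity, the same quantity vec_distance_cosine returns.
    return 1.0 - float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)))

v = np.array([1.0, 0.0])
print(cosine_distance(v, np.array([3.0, 0.0])))   # 0.0 -> same direction
print(cosine_distance(v, np.array([0.0, 1.0])))   # 1.0 -> orthogonal
print(cosine_distance(v, np.array([-2.0, 0.0])))  # 2.0 -> opposite direction

PERSONA_THRESHOLD = 0.75  # mirrors the value added to Config above
print(0.4 < PERSONA_THRESHOLD)  # True  -> close enough, counts as a persona match
print(1.2 < PERSONA_THRESHOLD)  # False -> too far away, no match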
101 changes: 100 additions & 1 deletion src/codegate/db/connection.py
@@ -1,9 +1,12 @@
import asyncio
import json
import sqlite3
import uuid
from pathlib import Path
from typing import Dict, List, Optional, Type

import numpy as np
import sqlite_vec_sl_tmp
import structlog
from alembic import command as alembic_command
from alembic.config import Config as AlembicConfig
@@ -22,6 +25,9 @@
IntermediatePromptWithOutputUsageAlerts,
MuxRule,
Output,
Persona,
PersonaDistance,
PersonaEmbedding,
Prompt,
ProviderAuthMaterial,
ProviderEndpoint,
@@ -65,7 +71,7 @@ def __new__(cls, *args, **kwargs):
# It should only be used for testing
if "_no_singleton" in kwargs and kwargs["_no_singleton"]:
kwargs.pop("_no_singleton")
return super().__new__(cls, *args, **kwargs)
return super().__new__(cls)

if cls._instance is None:
cls._instance = super().__new__(cls)
@@ -92,6 +98,22 @@ def __init__(self, sqlite_path: Optional[str] = None, **kwargs):
}
self._async_db_engine = create_async_engine(**engine_dict)

def _get_vec_db_connection(self):
"""
Vector database connection is a separate connection to the SQLite database. aiosqlite
does not support loading extensions, so we need to use the sqlite3 module to load the
vector extension.
"""
try:
conn = sqlite3.connect(self._db_path)
conn.enable_load_extension(True)
sqlite_vec_sl_tmp.load(conn)
conn.enable_load_extension(False)
return conn
except Exception:
logger.exception("Failed to initialize vector database connection")
raise

def does_db_exist(self):
return self._db_path.is_file()

@@ -523,6 +545,30 @@ async def add_mux(self, mux: MuxRule) -> MuxRule:
added_mux = await self._execute_update_pydantic_model(mux, sql, should_raise=True)
return added_mux

async def add_persona(self, persona: PersonaEmbedding) -> None:
"""Add a new Persona to the DB.

This handles validation and insertion of a new persona.

It may raise an AlreadyExistsError if the persona already exists.
"""
sql = text(
"""
INSERT INTO personas (id, name, description, description_embedding)
VALUES (:id, :name, :description, :description_embedding)
"""
)

try:
# Pydantic converts the numpy array to a string when serializing with .model_dump().
# We need to convert it back to a numpy array before inserting it into the DB.
persona_dict = persona.model_dump()
persona_dict["description_embedding"] = persona.description_embedding
await self._execute_with_no_return(sql, persona_dict)
except IntegrityError as e:
logger.debug(f"Exception type: {type(e)}")
raise AlreadyExistsError(f"Persona '{persona.name}' already exists.")


class DbReader(DbCodeGate):
def __init__(self, sqlite_path: Optional[str] = None, *args, **kwargs):
@@ -569,6 +615,20 @@ async def _exec_select_conditions_to_pydantic(
raise e
return None

async def _exec_vec_db_query_to_pydantic(
self, sql_command: str, conditions: dict, model_type: Type[BaseModel]
) -> List[BaseModel]:
"""
Execute a query on the vector database. This is a separate connection to the SQLite
database that has the vector extension loaded.
"""
conn = self._get_vec_db_connection()
conn.row_factory = sqlite3.Row
cursor = conn.cursor()
results = [model_type(**row) for row in cursor.execute(sql_command, conditions)]
conn.close()
return results

async def get_prompts_with_output(self, workpace_id: str) -> List[GetPromptWithOutputsRow]:
sql = text(
"""
@@ -893,6 +953,45 @@ async def get_muxes_by_workspace(self, workspace_id: str) -> List[MuxRule]:
)
return muxes

async def get_persona_by_name(self, persona_name: str) -> Optional[Persona]:
"""
Get a persona by name.
"""
sql = text(
"""
SELECT
id, name, description
FROM personas
WHERE name = :name
"""
)
conditions = {"name": persona_name}
personas = await self._exec_select_conditions_to_pydantic(
Persona, sql, conditions, should_raise=True
)
return personas[0] if personas else None

async def get_distance_to_persona(
self, persona_id: str, query_embedding: np.ndarray
) -> PersonaDistance:
"""
Get the distance between a persona and a query embedding.
"""
sql = """
SELECT
id,
name,
description,
vec_distance_cosine(description_embedding, :query_embedding) as distance
FROM personas
WHERE id = :id
"""
conditions = {"id": persona_id, "query_embedding": query_embedding}
persona_distance = await self._exec_vec_db_query_to_pydantic(
sql, conditions, PersonaDistance
)
return persona_distance[0]


def init_db_sync(db_path: Optional[str] = None):
"""DB will be initialized in the constructor in case it doesn't exist."""
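The semantic-routing class the PR title refers to is not part of the diff shown here, but the DbReader methods it relies on are. A minimal sketch of how they might be composed, assuming the caller already has a numpy embedding of the incoming query; the function name is hypothetical:

import numpy as np

from codegate.config import Config
from codegate.db.connection import DbReader

async def query_matches_persona(persona_name: str, query_embedding: np.ndarray) -> bool:
    """Return True if the query is semantically close enough to the named persona."""
    reader = DbReader()
    persona = await reader.get_persona_by_name(persona_name)
    if persona is None:
        return False
    result = await reader.get_distance_to_persona(persona.id, query_embedding)
    # Smaller cosine distance means more similar; see persona_threshold in config.py.
    return result.distance < Config.persona_threshold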
58 changes: 57 additions & 1 deletion src/codegate/db/models.py
@@ -2,7 +2,8 @@
from enum import Enum
from typing import Annotated, Any, Dict, List, Optional

from pydantic import BaseModel, StringConstraints
import numpy as np
from pydantic import BaseModel, BeforeValidator, ConfigDict, PlainSerializer, StringConstraints


class AlertSeverity(str, Enum):
@@ -240,3 +241,58 @@ class MuxRule(BaseModel):
priority: int
created_at: Optional[datetime.datetime] = None
updated_at: Optional[datetime.datetime] = None


def nd_array_custom_before_validator(x):
# Custom before-validation logic: pass the numpy array through unchanged.

Contributor:

you might want to reclarify this comment.

Contributor Author:

Done! Let me know if more clarification is needed

return x


def nd_array_custom_serializer(x):
# Custom serialization logic: store the array as its string representation.
return str(x)


# Pydantic doesn't support numpy arrays out of the box, hence we need to construct a custom type.
# A Pydantic custom type needs two things: a validator and a serializer.
# The lines below build our custom type.
# Docs: https://docs.pydantic.dev/latest/concepts/types/#adding-validation-and-serialization
# Open Pydantic issue for npy support: https://github.com/pydantic/pydantic/issues/7017
NdArray = Annotated[
np.ndarray,
BeforeValidator(nd_array_custom_before_validator),
PlainSerializer(nd_array_custom_serializer, return_type=str),
]


class Persona(BaseModel):
"""
Represents a persona object.
"""

id: str
name: str
description: str


class PersonaEmbedding(Persona):
"""
Represents a persona object with an embedding.
"""

description_embedding: NdArray

# Part of the workaround to allow numpy arrays in pydantic models
model_config = ConfigDict(arbitrary_types_allowed=True)


class PersonaDistance(Persona):
"""
Result of an SQL query to get the distance between the query and the persona description.

A vector similarity search is performed to get the distance. Distance values range over [0, 2]:
0 means the vectors point in the same direction, 1 means they are orthogonal, and 2 means they point in opposite directions.
See [sqlite docs](https://alexgarcia.xyz/sqlite-vec/api-reference.html#vec_distance_cosine)
"""

distance: float
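
A short illustration of the custom NdArray type in practice. The id, name, description and the tiny 3-dimensional embedding below are made-up demo values; real description embeddings would come from an embedding model:

import numpy as np

from codegate.db.models import PersonaEmbedding

emb = PersonaEmbedding(
    id="123e4567-e89b-12d3-a456-426614174000",
    name="architect",
    description="Answers questions about system design",
    description_embedding=np.array([0.1, 0.2, 0.3], dtype=np.float32),
)

# The before-validator passes the numpy array through untouched...
print(type(emb.description_embedding))  # <class 'numpy.ndarray'>

# ...while model_dump() serializes it to a string, which is why add_persona()
# restores the original array before writing the row.
print(type(emb.model_dump()["description_embedding"]))  # <class 'str'>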