Skip to content
This repository was archived by the owner on Jun 5, 2025. It is now read-only.

Some fixes on Persona CRUD #1241

Merged
merged 1 commit into from
Mar 6, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 23 additions & 9 deletions src/codegate/api/v1.py
Original file line number Diff line number Diff line change
Expand Up @@ -677,7 +677,7 @@ async def get_workspace_token_usage(workspace_name: str) -> v1_models.TokenUsage
async def list_personas() -> List[Persona]:
"""List all personas."""
try:
personas = await dbreader.get_all_personas()
personas = await persona_manager.get_all_personas()
return personas
except Exception:
logger.exception("Error while getting personas")
Expand All @@ -688,15 +688,11 @@ async def list_personas() -> List[Persona]:
async def get_persona(persona_name: str) -> Persona:
"""Get a persona by name."""
try:
persona = await dbreader.get_persona_by_name(persona_name)
if not persona:
raise HTTPException(status_code=404, detail=f"Persona {persona_name} not found")
persona = await persona_manager.get_persona(persona_name)
return persona
except Exception as e:
if isinstance(e, HTTPException):
raise e
logger.exception(f"Error while getting persona {persona_name}")
raise HTTPException(status_code=500, detail="Internal server error")
except PersonaDoesNotExistError:
logger.exception("Error while getting persona")
raise HTTPException(status_code=404, detail="Persona does not exist")


@v1.post("/personas", tags=["Personas"], generate_unique_id_function=uniq_name, status_code=201)
Expand All @@ -712,6 +708,15 @@ async def create_persona(request: v1_models.PersonaRequest) -> Persona:
except AlreadyExistsError:
logger.exception("Error while creating persona")
raise HTTPException(status_code=409, detail="Persona already exists")
except ValidationError:
logger.exception("Error while creating persona")
raise HTTPException(
status_code=400,
detail=(
"Persona has invalid name, check is alphanumeric "
"and only contains dashes and underscores"
),
)
except Exception:
logger.exception("Error while creating persona")
raise HTTPException(status_code=500, detail="Internal server error")
Expand All @@ -735,6 +740,15 @@ async def update_persona(persona_name: str, request: v1_models.PersonaUpdateRequ
except AlreadyExistsError:
logger.exception("Error while updating persona")
raise HTTPException(status_code=409, detail="Persona already exists")
except ValidationError:
logger.exception("Error while creating persona")
raise HTTPException(
status_code=400,
detail=(
"Persona has invalid name, check is alphanumeric "
"and only contains dashes and underscores"
),
)
except Exception:
logger.exception("Error while updating persona")
raise HTTPException(status_code=500, detail="Internal server error")
Expand Down
21 changes: 20 additions & 1 deletion src/codegate/db/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,15 @@
from typing import Annotated, Any, Dict, List, Optional

import numpy as np
from pydantic import BaseModel, BeforeValidator, ConfigDict, PlainSerializer, StringConstraints
import regex as re
from pydantic import (
BaseModel,
BeforeValidator,
ConfigDict,
PlainSerializer,
StringConstraints,
field_validator,
)


class AlertSeverity(str, Enum):
Expand Down Expand Up @@ -266,6 +274,8 @@ def nd_array_custom_serializer(x):
PlainSerializer(nd_array_custom_serializer, return_type=str),
]

VALID_PERSONA_NAME_PATTERN = re.compile(r"^[a-zA-Z0-9_ -]+$")


class Persona(BaseModel):
"""
Expand All @@ -276,6 +286,15 @@ class Persona(BaseModel):
name: str
description: str

@field_validator("name", mode="after")
@classmethod
def validate_persona_name(cls, value: str) -> str:
if VALID_PERSONA_NAME_PATTERN.match(value):
return value
raise ValueError(
"Invalid persona name. It should be alphanumeric with underscores and dashes."
)


class PersonaEmbedding(Persona):
"""
Expand Down
17 changes: 16 additions & 1 deletion src/codegate/muxing/persona.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import unicodedata
import uuid
from typing import Optional
from typing import List, Optional

import numpy as np
import regex as re
Expand Down Expand Up @@ -165,6 +165,21 @@ async def add_persona(self, persona_name: str, persona_desc: str) -> None:
await self._db_recorder.add_persona(new_persona)
logger.info(f"Added persona {persona_name} to the database.")

async def get_persona(self, persona_name: str) -> db_models.Persona:
"""
Get a persona from the database by name.
"""
persona = await self._db_reader.get_persona_by_name(persona_name)
if not persona:
raise PersonaDoesNotExistError(f"Persona {persona_name} does not exist.")
return persona

async def get_all_personas(self) -> List[db_models.Persona]:
"""
Get all personas from the database.
"""
return await self._db_reader.get_all_personas()

async def update_persona(
self, persona_name: str, new_persona_name: str, new_persona_desc: str
) -> None:
Expand Down
47 changes: 42 additions & 5 deletions tests/muxing/test_persona.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from typing import List

import pytest
from pydantic import BaseModel
from pydantic import BaseModel, ValidationError

from codegate.db import connection
from codegate.muxing.persona import (
Expand Down Expand Up @@ -54,7 +54,7 @@ async def test_add_persona(semantic_router_mocked_db: PersonaManager):
persona_name = "test_persona"
persona_desc = "test_persona_desc"
await semantic_router_mocked_db.add_persona(persona_name, persona_desc)
retrieved_persona = await semantic_router_mocked_db._db_reader.get_persona_by_name(persona_name)
retrieved_persona = await semantic_router_mocked_db.get_persona(persona_name)
assert retrieved_persona.name == persona_name
assert retrieved_persona.description == persona_desc

Expand All @@ -72,6 +72,18 @@ async def test_add_duplicate_persona(semantic_router_mocked_db: PersonaManager):
await semantic_router_mocked_db.add_persona(persona_name, updated_description)


@pytest.mark.asyncio
async def test_add_persona_invalid_name(semantic_router_mocked_db: PersonaManager):
"""Test adding a persona to the database."""
persona_name = "test_persona&"
persona_desc = "test_persona_desc"
with pytest.raises(ValidationError):
await semantic_router_mocked_db.add_persona(persona_name, persona_desc)

with pytest.raises(PersonaDoesNotExistError):
await semantic_router_mocked_db.delete_persona(persona_name)


@pytest.mark.asyncio
async def test_persona_not_exist_match(semantic_router_mocked_db: PersonaManager):
"""Test checking persona match when persona does not exist"""
Expand Down Expand Up @@ -235,7 +247,7 @@ class PersonaMatchTest(BaseModel):

# DevOps/SRE Engineer Persona
devops_sre = PersonaMatchTest(
persona_name="devops/sre engineer",
persona_name="devops sre engineer",
persona_desc="""
Expert in infrastructure automation, deployment pipelines, and operational reliability.
Specializes in building and maintaining scalable, resilient, and secure infrastructure.
Expand Down Expand Up @@ -441,8 +453,8 @@ async def test_delete_persona(semantic_router_mocked_db: PersonaManager):

await semantic_router_mocked_db.delete_persona(persona_name)

persona_found = await semantic_router_mocked_db._db_reader.get_persona_by_name(persona_name)
assert persona_found is None
with pytest.raises(PersonaDoesNotExistError):
await semantic_router_mocked_db.get_persona(persona_name)


@pytest.mark.asyncio
Expand All @@ -451,3 +463,28 @@ async def test_delete_persona_not_exists(semantic_router_mocked_db: PersonaManag

with pytest.raises(PersonaDoesNotExistError):
await semantic_router_mocked_db.delete_persona(persona_name)


@pytest.mark.asyncio
async def test_get_personas(semantic_router_mocked_db: PersonaManager):
"""Test getting personas from the database."""
persona_name = "test_persona"
persona_desc = "test_persona_desc"
await semantic_router_mocked_db.add_persona(persona_name, persona_desc)

persona_name_2 = "test_persona_2"
persona_desc_2 = "foo and bar"
await semantic_router_mocked_db.add_persona(persona_name_2, persona_desc_2)

all_personas = await semantic_router_mocked_db.get_all_personas()
assert len(all_personas) == 2
assert all_personas[0].name == persona_name
assert all_personas[1].name == persona_name_2


@pytest.mark.asyncio
async def test_get_personas_empty(semantic_router_mocked_db: PersonaManager):
"""Test adding a persona to the database."""

all_personas = await semantic_router_mocked_db.get_all_personas()
assert len(all_personas) == 0