diff --git a/src/codegate/api/v1.py b/src/codegate/api/v1.py index 0c756e2e..33efea33 100644 --- a/src/codegate/api/v1.py +++ b/src/codegate/api/v1.py @@ -677,7 +677,7 @@ async def get_workspace_token_usage(workspace_name: str) -> v1_models.TokenUsage async def list_personas() -> List[Persona]: """List all personas.""" try: - personas = await dbreader.get_all_personas() + personas = await persona_manager.get_all_personas() return personas except Exception: logger.exception("Error while getting personas") @@ -688,15 +688,11 @@ async def list_personas() -> List[Persona]: async def get_persona(persona_name: str) -> Persona: """Get a persona by name.""" try: - persona = await dbreader.get_persona_by_name(persona_name) - if not persona: - raise HTTPException(status_code=404, detail=f"Persona {persona_name} not found") + persona = await persona_manager.get_persona(persona_name) return persona - except Exception as e: - if isinstance(e, HTTPException): - raise e - logger.exception(f"Error while getting persona {persona_name}") - raise HTTPException(status_code=500, detail="Internal server error") + except PersonaDoesNotExistError: + logger.exception("Error while getting persona") + raise HTTPException(status_code=404, detail="Persona does not exist") @v1.post("/personas", tags=["Personas"], generate_unique_id_function=uniq_name, status_code=201) @@ -712,6 +708,15 @@ async def create_persona(request: v1_models.PersonaRequest) -> Persona: except AlreadyExistsError: logger.exception("Error while creating persona") raise HTTPException(status_code=409, detail="Persona already exists") + except ValidationError: + logger.exception("Error while creating persona") + raise HTTPException( + status_code=400, + detail=( + "Persona has invalid name, check is alphanumeric " + "and only contains dashes and underscores" + ), + ) except Exception: logger.exception("Error while creating persona") raise HTTPException(status_code=500, detail="Internal server error") @@ -735,6 +740,15 @@ async def update_persona(persona_name: str, request: v1_models.PersonaUpdateRequ except AlreadyExistsError: logger.exception("Error while updating persona") raise HTTPException(status_code=409, detail="Persona already exists") + except ValidationError: + logger.exception("Error while creating persona") + raise HTTPException( + status_code=400, + detail=( + "Persona has invalid name, check is alphanumeric " + "and only contains dashes and underscores" + ), + ) except Exception: logger.exception("Error while updating persona") raise HTTPException(status_code=500, detail="Internal server error") diff --git a/src/codegate/db/models.py b/src/codegate/db/models.py index 6f146b34..f9f63614 100644 --- a/src/codegate/db/models.py +++ b/src/codegate/db/models.py @@ -3,7 +3,15 @@ from typing import Annotated, Any, Dict, List, Optional import numpy as np -from pydantic import BaseModel, BeforeValidator, ConfigDict, PlainSerializer, StringConstraints +import regex as re +from pydantic import ( + BaseModel, + BeforeValidator, + ConfigDict, + PlainSerializer, + StringConstraints, + field_validator, +) class AlertSeverity(str, Enum): @@ -266,6 +274,8 @@ def nd_array_custom_serializer(x): PlainSerializer(nd_array_custom_serializer, return_type=str), ] +VALID_PERSONA_NAME_PATTERN = re.compile(r"^[a-zA-Z0-9_ -]+$") + class Persona(BaseModel): """ @@ -276,6 +286,15 @@ class Persona(BaseModel): name: str description: str + @field_validator("name", mode="after") + @classmethod + def validate_persona_name(cls, value: str) -> str: + if VALID_PERSONA_NAME_PATTERN.match(value): + return value + raise ValueError( + "Invalid persona name. It should be alphanumeric with underscores and dashes." + ) + class PersonaEmbedding(Persona): """ diff --git a/src/codegate/muxing/persona.py b/src/codegate/muxing/persona.py index 615b3256..ac21205c 100644 --- a/src/codegate/muxing/persona.py +++ b/src/codegate/muxing/persona.py @@ -1,6 +1,6 @@ import unicodedata import uuid -from typing import Optional +from typing import List, Optional import numpy as np import regex as re @@ -165,6 +165,21 @@ async def add_persona(self, persona_name: str, persona_desc: str) -> None: await self._db_recorder.add_persona(new_persona) logger.info(f"Added persona {persona_name} to the database.") + async def get_persona(self, persona_name: str) -> db_models.Persona: + """ + Get a persona from the database by name. + """ + persona = await self._db_reader.get_persona_by_name(persona_name) + if not persona: + raise PersonaDoesNotExistError(f"Persona {persona_name} does not exist.") + return persona + + async def get_all_personas(self) -> List[db_models.Persona]: + """ + Get all personas from the database. + """ + return await self._db_reader.get_all_personas() + async def update_persona( self, persona_name: str, new_persona_name: str, new_persona_desc: str ) -> None: diff --git a/tests/muxing/test_persona.py b/tests/muxing/test_persona.py index 4e221d8a..fd0003c9 100644 --- a/tests/muxing/test_persona.py +++ b/tests/muxing/test_persona.py @@ -3,7 +3,7 @@ from typing import List import pytest -from pydantic import BaseModel +from pydantic import BaseModel, ValidationError from codegate.db import connection from codegate.muxing.persona import ( @@ -54,7 +54,7 @@ async def test_add_persona(semantic_router_mocked_db: PersonaManager): persona_name = "test_persona" persona_desc = "test_persona_desc" await semantic_router_mocked_db.add_persona(persona_name, persona_desc) - retrieved_persona = await semantic_router_mocked_db._db_reader.get_persona_by_name(persona_name) + retrieved_persona = await semantic_router_mocked_db.get_persona(persona_name) assert retrieved_persona.name == persona_name assert retrieved_persona.description == persona_desc @@ -72,6 +72,18 @@ async def test_add_duplicate_persona(semantic_router_mocked_db: PersonaManager): await semantic_router_mocked_db.add_persona(persona_name, updated_description) +@pytest.mark.asyncio +async def test_add_persona_invalid_name(semantic_router_mocked_db: PersonaManager): + """Test adding a persona to the database.""" + persona_name = "test_persona&" + persona_desc = "test_persona_desc" + with pytest.raises(ValidationError): + await semantic_router_mocked_db.add_persona(persona_name, persona_desc) + + with pytest.raises(PersonaDoesNotExistError): + await semantic_router_mocked_db.delete_persona(persona_name) + + @pytest.mark.asyncio async def test_persona_not_exist_match(semantic_router_mocked_db: PersonaManager): """Test checking persona match when persona does not exist""" @@ -235,7 +247,7 @@ class PersonaMatchTest(BaseModel): # DevOps/SRE Engineer Persona devops_sre = PersonaMatchTest( - persona_name="devops/sre engineer", + persona_name="devops sre engineer", persona_desc=""" Expert in infrastructure automation, deployment pipelines, and operational reliability. Specializes in building and maintaining scalable, resilient, and secure infrastructure. @@ -441,8 +453,8 @@ async def test_delete_persona(semantic_router_mocked_db: PersonaManager): await semantic_router_mocked_db.delete_persona(persona_name) - persona_found = await semantic_router_mocked_db._db_reader.get_persona_by_name(persona_name) - assert persona_found is None + with pytest.raises(PersonaDoesNotExistError): + await semantic_router_mocked_db.get_persona(persona_name) @pytest.mark.asyncio @@ -451,3 +463,28 @@ async def test_delete_persona_not_exists(semantic_router_mocked_db: PersonaManag with pytest.raises(PersonaDoesNotExistError): await semantic_router_mocked_db.delete_persona(persona_name) + + +@pytest.mark.asyncio +async def test_get_personas(semantic_router_mocked_db: PersonaManager): + """Test getting personas from the database.""" + persona_name = "test_persona" + persona_desc = "test_persona_desc" + await semantic_router_mocked_db.add_persona(persona_name, persona_desc) + + persona_name_2 = "test_persona_2" + persona_desc_2 = "foo and bar" + await semantic_router_mocked_db.add_persona(persona_name_2, persona_desc_2) + + all_personas = await semantic_router_mocked_db.get_all_personas() + assert len(all_personas) == 2 + assert all_personas[0].name == persona_name + assert all_personas[1].name == persona_name_2 + + +@pytest.mark.asyncio +async def test_get_personas_empty(semantic_router_mocked_db: PersonaManager): + """Test adding a persona to the database.""" + + all_personas = await semantic_router_mocked_db.get_all_personas() + assert len(all_personas) == 0