From d22abcc2dee1ecf617f2b9bb4cae4be09478ee49 Mon Sep 17 00:00:00 2001 From: Alejandro Ponce Date: Mon, 3 Mar 2025 11:55:12 +0200 Subject: [PATCH 1/2] Added a class which performs semantic routing Related to: #1055 For the current implementation of muxing we only need to match a single Persona at a time. For example: 1. mux1 -> persona Architect -> openai o1 2. mux2 -> catch all -> openai gpt4o In the above case we would only need to know if the request matches the persona `Architect`. There is no need to match any other personas, even if they exist in the DB. This PR introduces what's necessary to do the above without actually wiring in muxing rules. The PR: - Creates the persona table in DB - Adds methods to write to and read from the new persona table - Implements a function to check if a query matches the specified persona For more details about the personas and the queries, see the unit tests --- ..._03_1008-02b710eda156_add_persona_table.py | 50 ++ src/codegate/config.py | 1 + src/codegate/db/connection.py | 103 ++- src/codegate/db/models.py | 39 +- src/codegate/muxing/semantic_router.py | 129 ++++ tests/muxing/test_semantic_router.py | 590 ++++++++++++++++++ 6 files changed, 910 insertions(+), 2 deletions(-) create mode 100644 migrations/versions/2025_03_03_1008-02b710eda156_add_persona_table.py create mode 100644 src/codegate/muxing/semantic_router.py create mode 100644 tests/muxing/test_semantic_router.py diff --git a/migrations/versions/2025_03_03_1008-02b710eda156_add_persona_table.py b/migrations/versions/2025_03_03_1008-02b710eda156_add_persona_table.py new file mode 100644 index 00000000..e6b90a46 --- /dev/null +++ b/migrations/versions/2025_03_03_1008-02b710eda156_add_persona_table.py @@ -0,0 +1,50 @@ +"""add persona table + +Revision ID: 02b710eda156 +Revises: 5e5cd2288147 +Create Date: 2025-03-03 10:08:16.206617+00:00 + +""" + +from typing import Sequence, Union + +from alembic import op + +# revision identifiers, used by Alembic. 
+revision: str = "02b710eda156" +down_revision: Union[str, None] = "5e5cd2288147" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # Begin transaction + op.execute("BEGIN TRANSACTION;") + + op.execute( + """ + CREATE TABLE IF NOT EXISTS personas ( + id TEXT PRIMARY KEY, -- UUID stored as TEXT + name TEXT NOT NULL UNIQUE, + description TEXT NOT NULL, + description_embedding BLOB NOT NULL + ); + """ + ) + + # Finish transaction + op.execute("COMMIT;") + + +def downgrade() -> None: + # Begin transaction + op.execute("BEGIN TRANSACTION;") + + op.execute( + """ + DROP TABLE personas; + """ + ) + + # Finish transaction + op.execute("COMMIT;") diff --git a/src/codegate/config.py b/src/codegate/config.py index 11cd96bf..8b177056 100644 --- a/src/codegate/config.py +++ b/src/codegate/config.py @@ -57,6 +57,7 @@ class Config: force_certs: bool = False max_fim_hash_lifetime: int = 60 * 5 # Time in seconds. Default is 5 minutes. 
+ persona_threshold = 0.75 # Min value is 0 (max similarity), max value is 2 (orthogonal) # Provider URLs with defaults provider_urls: Dict[str, str] = field(default_factory=lambda: DEFAULT_PROVIDER_URLS.copy()) diff --git a/src/codegate/db/connection.py b/src/codegate/db/connection.py index 2d56fccd..78eaa9c7 100644 --- a/src/codegate/db/connection.py +++ b/src/codegate/db/connection.py @@ -1,9 +1,12 @@ import asyncio import json +import sqlite3 import uuid from pathlib import Path from typing import Dict, List, Optional, Type +import numpy as np +import sqlite_vec_sl_tmp import structlog from alembic import command as alembic_command from alembic.config import Config as AlembicConfig @@ -22,6 +25,9 @@ IntermediatePromptWithOutputUsageAlerts, MuxRule, Output, + Persona, + PersonaDistance, + PersonaEmbedding, Prompt, ProviderAuthMaterial, ProviderEndpoint, @@ -65,7 +71,7 @@ def __new__(cls, *args, **kwargs): # It should only be used for testing if "_no_singleton" in kwargs and kwargs["_no_singleton"]: kwargs.pop("_no_singleton") - return super().__new__(cls, *args, **kwargs) + return super().__new__(cls) if cls._instance is None: cls._instance = super().__new__(cls) @@ -92,6 +98,22 @@ def __init__(self, sqlite_path: Optional[str] = None, **kwargs): } self._async_db_engine = create_async_engine(**engine_dict) + def _get_vec_db_connection(self): + """ + Vector database connection is a separate connection to the SQLite database. aiosqlite + does not support loading extensions, so we need to use the sqlite3 module to load the + vector extension. 
+ """ + try: + conn = sqlite3.connect(self._db_path) + conn.enable_load_extension(True) + sqlite_vec_sl_tmp.load(conn) + conn.enable_load_extension(False) + return conn + except Exception: + logger.exception("Failed to initialize vector database connection") + raise + def does_db_exist(self): return self._db_path.is_file() @@ -523,6 +545,30 @@ async def add_mux(self, mux: MuxRule) -> MuxRule: added_mux = await self._execute_update_pydantic_model(mux, sql, should_raise=True) return added_mux + async def add_persona(self, persona: PersonaEmbedding) -> None: + """Add a new Persona to the DB. + + This handles validation and insertion of a new persona. + + It may raise a AlreadyExistsError if the persona already exists. + """ + sql = text( + """ + INSERT INTO personas (id, name, description, description_embedding) + VALUES (:id, :name, :description, :description_embedding) + """ + ) + + try: + # For Pydantic we conver the numpy array to a string when serializing. + # We need to convert it back to a numpy array before inserting it into the DB. + persona_dict = persona.model_dump() + persona_dict["description_embedding"] = persona.description_embedding + await self._execute_with_no_return(sql, persona_dict) + except IntegrityError as e: + logger.debug(f"Exception type: {type(e)}") + raise AlreadyExistsError(f"Persona '{persona.name}' already exists.") + class DbReader(DbCodeGate): def __init__(self, sqlite_path: Optional[str] = None, *args, **kwargs): @@ -569,6 +615,18 @@ async def _exec_select_conditions_to_pydantic( raise e return None + async def _exec_vec_db_query( + self, sql_command: str, conditions: dict + ) -> Optional[CursorResult]: + """ + Execute a query on the vector database. This is a separate connection to the SQLite + database that has the vector extension loaded. 
+ """ + conn = self._get_vec_db_connection() + cursor = conn.cursor() + cursor.execute(sql_command, conditions) + return cursor + async def get_prompts_with_output(self, workpace_id: str) -> List[GetPromptWithOutputsRow]: sql = text( """ @@ -893,6 +951,49 @@ async def get_muxes_by_workspace(self, workspace_id: str) -> List[MuxRule]: ) return muxes + async def get_persona_by_name(self, persona_name: str) -> Optional[Persona]: + """ + Get a persona by name. + """ + sql = text( + """ + SELECT + id, name, description + FROM personas + WHERE name = :name + """ + ) + conditions = {"name": persona_name} + personas = await self._exec_select_conditions_to_pydantic( + Persona, sql, conditions, should_raise=True + ) + return personas[0] if personas else None + + async def get_distance_to_persona( + self, persona_id: str, query_embedding: np.ndarray + ) -> PersonaDistance: + """ + Get the distance between a persona and a query embedding. + """ + sql = """ + SELECT + id, + name, + description, + vec_distance_cosine(description_embedding, :query_embedding) as distance + FROM personas + WHERE id = :id + """ + conditions = {"id": persona_id, "query_embedding": query_embedding} + persona_distance_cursor = await self._exec_vec_db_query(sql, conditions) + persona_distance_raw = persona_distance_cursor.fetchone() + return PersonaDistance( + id=persona_distance_raw[0], + name=persona_distance_raw[1], + description=persona_distance_raw[2], + distance=persona_distance_raw[3], + ) + def init_db_sync(db_path: Optional[str] = None): """DB will be initialized in the constructor in case it doesn't exist.""" diff --git a/src/codegate/db/models.py b/src/codegate/db/models.py index 8f2365a0..1fa0d9ff 100644 --- a/src/codegate/db/models.py +++ b/src/codegate/db/models.py @@ -2,7 +2,8 @@ from enum import Enum from typing import Annotated, Any, Dict, List, Optional -from pydantic import BaseModel, StringConstraints +import numpy as np +from pydantic import BaseModel, BeforeValidator, ConfigDict, 
PlainSerializer, StringConstraints class AlertSeverity(str, Enum): @@ -240,3 +241,39 @@ class MuxRule(BaseModel): priority: int created_at: Optional[datetime.datetime] = None updated_at: Optional[datetime.datetime] = None + + +# Pydantic doesn't support numpy arrays out of the box. Defining a custom type +# Reference: https://github.com/pydantic/pydantic/issues/7017 +def nd_array_custom_before_validator(x): + # custome before validation logic + return x + + +def nd_array_custom_serializer(x): + # custome serialization logic + return str(x) + + +NdArray = Annotated[ + np.ndarray, + BeforeValidator(nd_array_custom_before_validator), + PlainSerializer(nd_array_custom_serializer, return_type=str), +] + + +class Persona(BaseModel): + id: str + name: str + description: str + + +class PersonaEmbedding(Persona): + description_embedding: NdArray # sqlite-vec will handle numpy arrays directly + + # Part of the workaround to allow numpy arrays in pydantic models + model_config = ConfigDict(arbitrary_types_allowed=True) + + +class PersonaDistance(Persona): + distance: float diff --git a/src/codegate/muxing/semantic_router.py b/src/codegate/muxing/semantic_router.py new file mode 100644 index 00000000..bcb48082 --- /dev/null +++ b/src/codegate/muxing/semantic_router.py @@ -0,0 +1,129 @@ +import unicodedata +import uuid + +import numpy as np +import regex as re +import structlog + +from codegate.config import Config +from codegate.db import models as db_models +from codegate.db.connection import DbReader, DbRecorder +from codegate.inference.inference_engine import LlamaCppInferenceEngine + +logger = structlog.get_logger("codegate") + + +class PersonaDoesNotExistError(Exception): + pass + + +class SemanticRouter: + + def __init__(self): + self._inference_engine = LlamaCppInferenceEngine() + conf = Config.get_config() + self._embeddings_model = f"{conf.model_base_path}/{conf.embedding_model}" + self._n_gpu = conf.chat_model_n_gpu_layers + self._persona_threshold = 
conf.persona_threshold + self._db_recorder = DbRecorder() + self._db_reader = DbReader() + + def _clean_text_for_embedding(self, text: str) -> str: + """ + Clean the text for embedding. This function should be used to preprocess the text + before embedding. + + Performs the following operations: + 1. Replaces newlines and carriage returns with spaces + 2. Removes extra whitespace + 3. Converts to lowercase (NOTE(review): not applied below; only strip() is called) + 4. Removes URLs and email addresses + 5. Removes code block markers and other markdown syntax + 6. Normalizes Unicode characters + 7. Handles special characters and punctuation + 8. Normalizes numbers + """ + if not text: + return "" + + # Replace newlines and carriage returns with spaces + text = text.replace("\n", " ").replace("\r", " ") + + # Normalize Unicode characters (e.g., convert accented characters to ASCII equivalents) + text = unicodedata.normalize("NFKD", text) + text = "".join([c for c in text if not unicodedata.combining(c)]) + + # Remove URLs + text = re.sub(r"https?://\S+|www\.\S+", " ", text) + + # Remove email addresses + text = re.sub(r"\S+@\S+", " ", text) + + # Remove code block markers and other markdown/code syntax + text = re.sub(r"```[\s\S]*?```", " ", text) # Code blocks + text = re.sub(r"`[^`]*`", " ", text) # Inline code + + # Remove HTML/XML tags + text = re.sub(r"<[^>]+>", " ", text) + + # Normalize numbers (replace with placeholder) + text = re.sub(r"\b\d+\.\d+\b", " NUM ", text) # Decimal numbers + text = re.sub(r"\b\d+\b", " NUM ", text) # Integer numbers + + # Replace punctuation with spaces (keeping apostrophes for contractions) + text = re.sub(r"[^\w\s\']", " ", text) + + # Normalize whitespace (replace multiple spaces with a single space) + text = re.sub(r"\s+", " ", text) + + # Strip surrounding whitespace (NOTE(review): the text is not lowercased here) + text = text.strip() + + return text + + async def _embed_text(self, text: str) -> np.ndarray: + """ + Helper function to embed text using the inference engine. 
+ """ + cleaned_text = self._clean_text_for_embedding(text) + # .embed returns a list of embeddings + embed_list = await self._inference_engine.embed( + self._embeddings_model, [cleaned_text], n_gpu_layers=self._n_gpu + ) + # Use only the first entry in the list and make sure we have the appropriate type + logger.debug("Text embedded in semantic routing", text=cleaned_text[:100]) + return np.array(embed_list[0], dtype=np.float32) + + async def add_persona(self, persona_name: str, persona_desc: str) -> None: + """ + Add a new persona to the database. The persona description is embedded + and stored in the database. + """ + emb_persona_desc = await self._embed_text(persona_desc) + new_persona = db_models.PersonaEmbedding( + id=str(uuid.uuid4()), + name=persona_name, + description=persona_desc, + description_embedding=emb_persona_desc, + ) + await self._db_recorder.add_persona(new_persona) + logger.info(f"Added persona {persona_name} to the database.") + + async def check_persona_match(self, persona_name: str, query: str) -> bool: + """ + Check if the query matches the persona description. A vector similarity + search is performed between the query and the persona description. + 0 means the vectors are identical, 1 means they are orthogonal, 2 means they are opposite. 
+ See + [sqlite docs](https://alexgarcia.xyz/sqlite-vec/api-reference.html#vec_distance_cosine) + """ + persona = await self._db_reader.get_persona_by_name(persona_name) + if not persona: + raise PersonaDoesNotExistError(f"Persona {persona_name} does not exist.") + + emb_query = await self._embed_text(query) + persona_distance = await self._db_reader.get_distance_to_persona(persona.id, emb_query) + logger.info(f"Persona distance to {persona_name}", distance=persona_distance.distance) + if persona_distance.distance < self._persona_threshold: + return True + return False diff --git a/tests/muxing/test_semantic_router.py b/tests/muxing/test_semantic_router.py new file mode 100644 index 00000000..c8c7edc6 --- /dev/null +++ b/tests/muxing/test_semantic_router.py @@ -0,0 +1,590 @@ +import uuid +from pathlib import Path +from typing import List + +import pytest +from pydantic import BaseModel + +from codegate.db import connection +from codegate.muxing.semantic_router import PersonaDoesNotExistError, SemanticRouter + + +@pytest.fixture +def db_path(): + """Creates a temporary database file path.""" + current_test_dir = Path(__file__).parent + db_filepath = current_test_dir / f"codegate_test_{uuid.uuid4()}.db" + db_fullpath = db_filepath.absolute() + connection.init_db_sync(str(db_fullpath)) + yield db_fullpath + if db_fullpath.is_file(): + db_fullpath.unlink() + + +@pytest.fixture() +def db_recorder(db_path) -> connection.DbRecorder: + """Creates a DbRecorder instance with test database.""" + return connection.DbRecorder(sqlite_path=db_path, _no_singleton=True) + + +@pytest.fixture() +def db_reader(db_path) -> connection.DbReader: + """Creates a DbReader instance with test database.""" + return connection.DbReader(sqlite_path=db_path, _no_singleton=True) + + +@pytest.fixture() +def semantic_router_mocked_db( + db_recorder: connection.DbRecorder, db_reader: connection.DbReader +) -> SemanticRouter: + """Creates a SemanticRouter instance with mocked database.""" + 
semantic_router = SemanticRouter() + semantic_router._db_reader = db_reader + semantic_router._db_recorder = db_recorder + return semantic_router + + +@pytest.mark.asyncio +async def test_add_persona(semantic_router_mocked_db: SemanticRouter): + """Test adding a persona to the database.""" + persona_name = "test_persona" + persona_desc = "test_persona_desc" + await semantic_router_mocked_db.add_persona(persona_name, persona_desc) + retrieved_persona = await semantic_router_mocked_db._db_reader.get_persona_by_name(persona_name) + assert retrieved_persona.name == persona_name + assert retrieved_persona.description == persona_desc + + +@pytest.mark.asyncio +async def test_persona_not_exist_match(semantic_router_mocked_db: SemanticRouter): + """Test checking persona match when persona does not exist""" + persona_name = "test_persona" + query = "test_query" + with pytest.raises(PersonaDoesNotExistError): + await semantic_router_mocked_db.check_persona_match(persona_name, query) + + +class PersonaMatchTest(BaseModel): + persona_name: str + persona_desc: str + pass_queries: List[str] + fail_queries: List[str] + + +simple_persona = PersonaMatchTest( + persona_name="test_persona", + persona_desc="test_desc", + pass_queries=["test_desc", "test_desc2"], + fail_queries=["foo"], +) + +software_architect = PersonaMatchTest( + persona_name="software architect", + persona_desc=""" + Expert in designing large-scale software systems and technical infrastructure. + Specializes in distributed systems, microservices architecture, + and cloud-native applications. + Deep knowledge of architectural patterns like CQRS, event sourcing, hexagonal architecture, + and domain-driven design. + Experienced in designing scalable, resilient, and maintainable software solutions. + Proficient in evaluating technology stacks and making strategic technical decisions. + Skilled at creating architecture diagrams, technical specifications, + and system documentation. 
+ Focuses on non-functional requirements like performance, security, and reliability. + Guides development teams on best practices for implementing complex systems. + """, + pass_queries=[ + """ + How should I design a microservices architecture that can handle high traffic loads? + """, + """ + What's the best approach for implementing event sourcing in a distributed system? + """, + """ + I need to design a system that can scale to millions of users. What architecture would you + recommend? + """, + """ + Can you explain the trade-offs between monolithic and microservices architectures for our + new project? + """, + ], + fail_queries=[ + """ + How do I create a simple landing page with HTML and CSS? + """, + """ + What's the best way to optimize my SQL query performance? + """, + """ + Can you help me debug this JavaScript function that's throwing an error? + """, + """ + How do I implement user authentication in my React application? + """, + ], +) + +# Data Scientist Persona +data_scientist = PersonaMatchTest( + persona_name="data scientist", + persona_desc=""" + Expert in analyzing and interpreting complex data to solve business problems. + Specializes in statistical analysis, machine learning algorithms, and predictive modeling. + Builds and deploys models for classification, regression, clustering, and anomaly detection. + Proficient in data preprocessing, feature engineering, and model evaluation techniques. + Uses Python with libraries like NumPy, Pandas, scikit-learn, TensorFlow, and PyTorch. + Experienced with data visualization using Matplotlib, Seaborn, and interactive dashboards. + Applies experimental design principles and A/B testing methodologies. + Works with structured and unstructured data, including time series and text. + Implements data pipelines for model training, validation, and deployment. + Communicates insights and recommendations based on data analysis to stakeholders. 
+ + Handles class imbalance problems in classification tasks using techniques like SMOTE, + undersampling, oversampling, and class weighting. Addresses customer churn prediction + challenges by identifying key features that indicate potential churners. + + Applies feature selection methods for high-dimensional datasets, including filter methods + (correlation, chi-square), wrapper methods (recursive feature elimination), and embedded + methods (LASSO regularization). + + Prevents overfitting and high variance in tree-based models like random forests through + techniques such as pruning, setting maximum depth, adjusting minimum samples per leaf, + and cross-validation. + + Specializes in time series forecasting for sales and demand prediction, using methods like + ARIMA, SARIMA, Prophet, and exponential smoothing to handle seasonal patterns and trends. + Implements forecasting models that account for quarterly business cycles and seasonal + variations in customer behavior. + + Evaluates model performance using appropriate metrics: accuracy, precision, recall, + F1-score + for classification; RMSE, MAE, R-squared for regression; and specialized metrics for + time series forecasting like MAPE and SMAPE. + + Experienced in developing customer segmentation models, recommendation systems, + anomaly detection algorithms, and predictive maintenance solutions. + """, + pass_queries=[ + """ + How should I handle class imbalance in my customer churn prediction model? + """, + """ + What feature selection techniques would work best for my high-dimensional dataset? + """, + """ + I'm getting high variance in my random forest model. How can I prevent overfitting? + """, + """ + What's the best approach for forecasting seasonal time series data for our sales + predictions? + """, + ], + fail_queries=[ + """ + How do I structure my React components for a single-page application? + """, + """ + What's the best way to implement a CI/CD pipeline for my microservices? 
+ """, + """ + Can you help me design a responsive layout for mobile and desktop browsers? + """, + """ + How should I configure my Kubernetes cluster for high availability? + """, + ], +) + +# UX Designer Persona +ux_designer = PersonaMatchTest( + persona_name="ux designer", + persona_desc=""" + Expert in creating intuitive, user-centered digital experiences and interfaces. + Specializes in user research, usability testing, and interaction design. + Creates wireframes, prototypes, and user flows to visualize design solutions. + Conducts user interviews, usability studies, and analyzes user feedback. + Develops user personas and journey maps to understand user needs and pain points. + Designs information architecture and navigation systems for complex applications. + Applies design thinking methodology to solve user experience problems. + Knowledgeable about accessibility standards and inclusive design principles. + Collaborates with product managers and developers to implement user-friendly features. + Uses tools like Figma, Sketch, and Adobe XD to create high-fidelity mockups. + """, + pass_queries=[ + """ + How can I improve the user onboarding experience for my mobile application? + """, + """ + What usability testing methods would you recommend for evaluating our new interface design? + """, + """ + I'm designing a complex dashboard. What information architecture would make it most + intuitive for users? + """, + """ + How should I structure user research to identify pain points in our current + checkout process? + """, + ], + fail_queries=[ + """ + How do I configure a load balancer for my web servers? + """, + """ + What's the best way to implement a caching layer in my application? + """, + """ + Can you explain how to set up a CI/CD pipeline with GitHub Actions? + """, + """ + How do I optimize my database queries for better performance? 
+ """, + ], +) + +# DevOps Engineer Persona +devops_engineer = PersonaMatchTest( + persona_name="devops engineer", + persona_desc=""" + Expertise: Infrastructure automation, CI/CD pipelines, cloud services, containerization, + and monitoring. + Proficient with tools like Docker, Kubernetes, Terraform, Ansible, and Jenkins. + Experienced with cloud platforms including AWS, Azure, and Google Cloud. + Strong knowledge of Linux/Unix systems administration and shell scripting. + Skilled in implementing microservices architectures and service mesh technologies. + Focus on reliability, scalability, security, and operational efficiency. + Practices infrastructure as code, GitOps, and site reliability engineering principles. + Experienced with monitoring tools like Prometheus, Grafana, and ELK stack. + """, + pass_queries=[ + """ + What's the best way to set up auto-scaling for my Kubernetes cluster on AWS? + """, + """ + I need to implement a zero-downtime deployment strategy for my microservices. + What approaches would you recommend? + """, + """ + How can I improve the security of my CI/CD pipeline and prevent supply chain attacks? + """, + """ + What monitoring metrics should I track to ensure the reliability of my distributed system? + """, + ], + fail_queries=[ + """ + How do I design an effective user onboarding flow for my mobile app? + """, + """ + What's the best algorithm for sentiment analysis on customer reviews? + """, + """ + Can you help me with color theory for my website redesign? + """, + """ + I need advice on optimizing my SQL queries for a reporting dashboard. + """, + ], +) + +# Security Specialist Persona +security_specialist = PersonaMatchTest( + persona_name="security specialist", + persona_desc=""" + Expert in cybersecurity, application security, and secure system design. + Specializes in identifying and mitigating security vulnerabilities and threats. + Performs security assessments, penetration testing, and code security reviews. 
+ Implements security controls like authentication, authorization, and encryption. + Knowledgeable about common attack vectors such as injection attacks, XSS, CSRF, and SSRF. + Experienced with security frameworks and standards like OWASP Top 10, NIST, and ISO 27001. + Designs secure architectures and implements defense-in-depth strategies. + Conducts security incident response and forensic analysis. + Implements security monitoring, logging, and alerting systems. + Stays current with emerging security threats and mitigation techniques. + """, + pass_queries=[ + """ + How can I protect my web application from SQL injection attacks? + """, + """ + What security controls should I implement for storing sensitive user data? + """, + """ + How do I conduct a thorough security assessment of our cloud infrastructure? + """, + """ + What's the best approach for implementing secure authentication in my API? + """, + ], + fail_queries=[ + """ + How do I optimize the loading speed of my website? + """, + """ + What's the best way to implement responsive design for mobile devices? + """, + """ + Can you help me design a database schema for my e-commerce application? + """, + """ + How should I structure my React components for better code organization? + """, + ], +) + +# Mobile Developer Persona +mobile_developer = PersonaMatchTest( + persona_name="mobile developer", + persona_desc=""" + Expert in building native and cross-platform mobile applications for iOS and Android. + Specializes in mobile UI development, responsive layouts, and platform-specific + design patterns. + Proficient in Swift and SwiftUI for iOS, Kotlin for Android, and React Native or + Flutter for cross-platform. + Implements mobile-specific features like push notifications, offline storage, and + location services. + Optimizes mobile applications for performance, battery efficiency, and limited + network connectivity. + Experienced with mobile app architecture patterns like MVVM, MVC, and Redux. 
+ Integrates with device hardware features including camera, biometrics, sensors, + and Bluetooth. + Familiar with app store submission processes, app signing, and distribution workflows. + Implements secure data storage, authentication, and API communication on mobile devices. + Designs and develops responsive interfaces that work across different screen sizes + and orientations. + + Implements sophisticated offline-first data synchronization strategies + for mobile applications, + handling conflict resolution, data merging, and background syncing when connectivity + is restored. + Uses technologies like Realm, SQLite, Core Data, and Room Database to enable seamless + offline + experiences in React Native and native apps. + + Structures Swift code following the MVVM (Model-View-ViewModel) architectural pattern + to create + maintainable, testable iOS applications. Implements proper separation of concerns + with bindings + between views and view models using Combine, RxSwift, or SwiftUI's native state management. + + Specializes in deep linking implementation for both Android and iOS, enabling app-to-app + communication, marketing campaign tracking, and seamless user experiences when navigating + between web and mobile contexts. Configures Universal Links, App Links, and custom URL + schemes. + + Optimizes battery usage for location-based features by implementing intelligent location + tracking + strategies, including geofencing, significant location changes, deferred location updates, + and + region monitoring. Balances accuracy requirements with power consumption constraints. + + Develops efficient state management solutions for complex mobile applications using Redux, + MobX, Provider, or Riverpod for React Native apps, and native state management approaches + for iOS and Android. 
+ + Creates responsive mobile interfaces that adapt to different device orientations, + screen sizes, + and pixel densities using constraint layouts, auto layout, size classes, and flexible + grid systems. + """, + pass_queries=[ + """ + What's the best approach for implementing offline-first data synchronization in my mobile + app? + """, + """ + How should I structure my Swift code to implement the MVVM pattern effectively? + """, + """ + What's the most efficient way to handle deep linking and app-to-app communication on + Android? + """, + """ + How can I optimize battery usage when implementing background location tracking? + """, + ], + fail_queries=[ + """ + How do I design a database schema with proper normalization for my web application? + """, + """ + What's the best approach for implementing a distributed caching layer in my microservices? + """, + """ + Can you help me set up a data pipeline for processing large datasets with Apache Spark? + """, + """ + How should I configure my load balancer to distribute traffic across my web servers? + """, + ], +) + +# Database Administrator Persona +database_administrator = PersonaMatchTest( + persona_name="database administrator", + persona_desc=""" + Expert in designing, implementing, and managing database systems for optimal performance and + reliability. + Specializes in database architecture, schema design, and query optimization techniques. + Proficient with relational databases like PostgreSQL, MySQL, Oracle, and SQL Server. + Implements and manages database security, access controls, and data protection measures. + Designs high-availability solutions using replication, clustering, and failover mechanisms. + Develops and executes backup strategies, disaster recovery plans, and data retention + policies. + Monitors database performance, identifies bottlenecks, and implements optimization + solutions. + Creates and maintains indexes, partitioning schemes, and other performance-enhancing + structures. 
+ Experienced with database migration, version control, and change management processes. + Implements data integrity constraints, stored procedures, triggers, and database automation. + + Optimizes complex JOIN query performance in PostgreSQL through advanced techniques including + query rewriting, proper indexing strategies, materialized views, and query plan analysis. + Uses EXPLAIN ANALYZE to identify bottlenecks in query execution plans and implements + appropriate optimizations for specific query patterns. + + Designs and implements high-availability MySQL configurations with automatic failover using + technologies like MySQL Group Replication, Galera Cluster, Percona XtraDB Cluster, or MySQL + InnoDB Cluster with MySQL Router. Configures synchronous and asynchronous replication + strategies + to balance consistency and performance requirements. + + Develops sophisticated indexing strategies for tables with frequent write operations and + complex + read queries, balancing write performance with read optimization. Implements partial + indexes, + covering indexes, and composite indexes based on query patterns and cardinality analysis. + + Specializes in large-scale database migrations between different database engines, + particularly + Oracle to PostgreSQL transitions. Uses tools like ora2pg, AWS DMS, and custom ETL processes + to + ensure data integrity, schema compatibility, and minimal downtime during migration. + + Implements table partitioning schemes based on data access patterns, including range + partitioning + for time-series data, list partitioning for categorical data, and hash partitioning for + evenly + distributed workloads. + + Configures and manages database connection pooling, query caching, and buffer management to + optimize resource utilization and throughput under varying workloads. 
+ + Designs and implements database sharding strategies for horizontal scaling, including + consistent hashing algorithms, shard key selection, and cross-shard query optimization. + """, + pass_queries=[ + """ + How can I optimize the performance of complex JOIN queries in my PostgreSQL database? + """, + """ + What's the best approach for implementing a high-availability MySQL setup with automatic + failover? + """, + """ + How should I design my indexing strategy for a table with frequent writes and complex read + queries? + """, + """ + What's the most efficient way to migrate a large Oracle database to PostgreSQL with minimal + downtime? + """, + ], + fail_queries=[ + """ + How do I structure my React components to implement the Redux state management pattern? + """, + """ + What's the best approach for implementing responsive design with CSS Grid and Flexbox? + """, + """ + Can you help me set up a CI/CD pipeline for my containerized microservices? + """, + ], +) + +# Natural Language Processing Specialist Persona +nlp_specialist = PersonaMatchTest( + persona_name="nlp specialist", + persona_desc=""" + Expertise: Natural language processing, computational linguistics, and text analytics. + Proficient with NLP libraries and frameworks like NLTK, spaCy, Hugging Face Transformers, + and Gensim. + Experience with language models such as BERT, GPT, T5, and their applications. + Skilled in text preprocessing, tokenization, lemmatization, and feature extraction + techniques. + Knowledge of sentiment analysis, named entity recognition, topic modeling, and text + classification. + Familiar with word embeddings, contextual embeddings, and language representation methods. + Understanding of machine translation, question answering, and text summarization systems. + Background in information retrieval, semantic search, and conversational AI development. + """, + pass_queries=[ + """ + What approach should I take to fine-tune BERT for my custom text classification task? 
+ """, + """ + How can I improve the accuracy of my named entity recognition system for medical texts? + """, + """ + What's the best way to implement semantic search using embeddings from language models? + """, + """ + I need to build a sentiment analysis system that can handle sarcasm and idioms. + Any suggestions? + """, + ], + fail_queries=[ + """ + How do I optimize my React components to reduce rendering time? + """, + """ + What's the best approach for implementing a CI/CD pipeline with Jenkins? + """, + """ + Can you help me design a responsive UI for my web application? + """, + """ + How should I structure my microservices architecture for scalability? + """, + ], +) + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "persona_match_test", + [ + simple_persona, + software_architect, + data_scientist, + ux_designer, + devops_engineer, + security_specialist, + mobile_developer, + database_administrator, + nlp_specialist, + ], +) +async def test_check_persona_match( + semantic_router_mocked_db: SemanticRouter, persona_match_test: PersonaMatchTest +): + """Test checking persona match.""" + await semantic_router_mocked_db.add_persona( + persona_match_test.persona_name, persona_match_test.persona_desc + ) + + # Check for the queries that should pass + for query in persona_match_test.pass_queries: + match = await semantic_router_mocked_db.check_persona_match( + persona_match_test.persona_name, query + ) + assert match is True + + # Check for the queries that should fail + for query in persona_match_test.fail_queries: + match = await semantic_router_mocked_db.check_persona_match( + persona_match_test.persona_name, query + ) + assert match is False From 0e37312bb823a16ba60525ac0116aced7115c738 Mon Sep 17 00:00:00 2001 From: Alejandro Ponce Date: Tue, 4 Mar 2025 13:17:01 +0200 Subject: [PATCH 2/2] Attended PR comments --- src/codegate/config.py | 4 +++- src/codegate/db/connection.py | 24 +++++++++----------- src/codegate/db/models.py | 25 ++++++++++++++++++--- 
src/codegate/muxing/semantic_router.py | 31 +++++++++++++++++--------- 4 files changed, 57 insertions(+), 27 deletions(-) diff --git a/src/codegate/config.py b/src/codegate/config.py index 8b177056..761ca09e 100644 --- a/src/codegate/config.py +++ b/src/codegate/config.py @@ -57,7 +57,9 @@ class Config: force_certs: bool = False max_fim_hash_lifetime: int = 60 * 5 # Time in seconds. Default is 5 minutes. - persona_threshold = 0.75 # Min value is 0 (max similarity), max value is 2 (orthogonal) + # Min value is 0 (max similarity), max value is 2 (opposite vectors) + # The value 0.75 was found through experimentation. See /tests/muxing/test_semantic_router.py + persona_threshold = 0.75 # Provider URLs with defaults provider_urls: Dict[str, str] = field(default_factory=lambda: DEFAULT_PROVIDER_URLS.copy()) diff --git a/src/codegate/db/connection.py b/src/codegate/db/connection.py index 78eaa9c7..803943b3 100644 --- a/src/codegate/db/connection.py +++ b/src/codegate/db/connection.py @@ -560,7 +560,7 @@ async def add_persona(self, persona: PersonaEmbedding) -> None: ) try: - # For Pydantic we conver the numpy array to a string when serializing. + # For Pydantic we convert the numpy array to string when serializing with .model_dump() # We need to convert it back to a numpy array before inserting it into the DB. persona_dict = persona.model_dump() persona_dict["description_embedding"] = persona.description_embedding @@ -615,17 +615,19 @@ async def _exec_select_conditions_to_pydantic( raise e return None - async def _exec_vec_db_query( - self, sql_command: str, conditions: dict - ) -> Optional[CursorResult]: + async def _exec_vec_db_query_to_pydantic( + self, sql_command: str, conditions: dict, model_type: Type[BaseModel] + ) -> List[BaseModel]: """ Execute a query on the vector database. This is a separate connection to the SQLite database that has the vector extension loaded.
""" conn = self._get_vec_db_connection() + conn.row_factory = sqlite3.Row cursor = conn.cursor() - cursor.execute(sql_command, conditions) - return cursor + results = [model_type(**row) for row in cursor.execute(sql_command, conditions)] + conn.close() + return results async def get_prompts_with_output(self, workpace_id: str) -> List[GetPromptWithOutputsRow]: sql = text( @@ -985,14 +987,10 @@ async def get_distance_to_persona( WHERE id = :id """ conditions = {"id": persona_id, "query_embedding": query_embedding} - persona_distance_cursor = await self._exec_vec_db_query(sql, conditions) - persona_distance_raw = persona_distance_cursor.fetchone() - return PersonaDistance( - id=persona_distance_raw[0], - name=persona_distance_raw[1], - description=persona_distance_raw[2], - distance=persona_distance_raw[3], + persona_distance = await self._exec_vec_db_query_to_pydantic( + sql, conditions, PersonaDistance ) + return persona_distance[0] def init_db_sync(db_path: Optional[str] = None): diff --git a/src/codegate/db/models.py b/src/codegate/db/models.py index 1fa0d9ff..a5941e96 100644 --- a/src/codegate/db/models.py +++ b/src/codegate/db/models.py @@ -243,8 +243,6 @@ class MuxRule(BaseModel): updated_at: Optional[datetime.datetime] = None -# Pydantic doesn't support numpy arrays out of the box. Defining a custom type -# Reference: https://github.com/pydantic/pydantic/issues/7017 def nd_array_custom_before_validator(x): # custome before validation logic return x @@ -255,6 +253,11 @@ def nd_array_custom_serializer(x): return str(x) +# Pydantic doesn't support numpy arrays out of the box hence we need to construct a custom type. 
+# There are 2 things necessary for a Pydantic custom type: Validator and Serializer +# The lines below build our custom type +# Docs: https://docs.pydantic.dev/latest/concepts/types/#adding-validation-and-serialization +# Open Pydantic issue for npy support: https://github.com/pydantic/pydantic/issues/7017 NdArray = Annotated[ np.ndarray, BeforeValidator(nd_array_custom_before_validator), @@ -263,17 +266,33 @@ def nd_array_custom_serializer(x): class Persona(BaseModel): + """ + Represents a persona object. + """ + id: str name: str description: str class PersonaEmbedding(Persona): - description_embedding: NdArray # sqlite-vec will handle numpy arrays directly + """ + Represents a persona object with an embedding. + """ + + description_embedding: NdArray # Part of the workaround to allow numpy arrays in pydantic models model_config = ConfigDict(arbitrary_types_allowed=True) class PersonaDistance(Persona): + """ + Result of an SQL query to get the distance between the query and the persona description. + + A vector similarity search is performed to get the distance. Distance values range in [0, 2]. + 0 means the vectors point the same way, 1 means orthogonal, 2 means opposite.
+ See [sqlite docs](https://alexgarcia.xyz/sqlite-vec/api-reference.html#vec_distance_cosine) + """ + distance: float diff --git a/src/codegate/muxing/semantic_router.py b/src/codegate/muxing/semantic_router.py index bcb48082..ce240b1f 100644 --- a/src/codegate/muxing/semantic_router.py +++ b/src/codegate/muxing/semantic_router.py @@ -13,6 +13,17 @@ logger = structlog.get_logger("codegate") +REMOVE_URLS = re.compile(r"https?://\S+|www\.\S+") +REMOVE_EMAILS = re.compile(r"\S+@\S+") +REMOVE_CODE_BLOCKS = re.compile(r"```[\s\S]*?```") +REMOVE_INLINE_CODE = re.compile(r"`[^`]*`") +REMOVE_HTML_TAGS = re.compile(r"<[^>]+>") +REMOVE_PUNCTUATION = re.compile(r"[^\w\s\']") +NORMALIZE_WHITESPACE = re.compile(r"\s+") +NORMALIZE_DECIMAL_NUMBERS = re.compile(r"\b\d+\.\d+\b") +NORMALIZE_INTEGER_NUMBERS = re.compile(r"\b\d+\b") + + class PersonaDoesNotExistError(Exception): pass @@ -54,27 +65,27 @@ def _clean_text_for_embedding(self, text: str) -> str: text = "".join([c for c in text if not unicodedata.combining(c)]) # Remove URLs - text = re.sub(r"https?://\S+|www\.\S+", " ", text) + text = REMOVE_URLS.sub(" ", text) # Remove email addresses - text = re.sub(r"\S+@\S+", " ", text) + text = REMOVE_EMAILS.sub(" ", text) # Remove code block markers and other markdown/code syntax - text = re.sub(r"```[\s\S]*?```", " ", text) # Code blocks - text = re.sub(r"`[^`]*`", " ", text) # Inline code + text = REMOVE_CODE_BLOCKS.sub(" ", text) + text = REMOVE_INLINE_CODE.sub(" ", text) # Remove HTML/XML tags - text = re.sub(r"<[^>]+>", " ", text) + text = REMOVE_HTML_TAGS.sub(" ", text) # Normalize numbers (replace with placeholder) - text = re.sub(r"\b\d+\.\d+\b", " NUM ", text) # Decimal numbers - text = re.sub(r"\b\d+\b", " NUM ", text) # Integer numbers + text = NORMALIZE_DECIMAL_NUMBERS.sub(" NUM ", text) # Decimal numbers + text = NORMALIZE_INTEGER_NUMBERS.sub(" NUM ", text) # Integer numbers # Replace punctuation with spaces (keeping apostrophes for contractions) - text = 
re.sub(r"[^\w\s\']", " ", text) + text = REMOVE_PUNCTUATION.sub(" ", text) # Normalize whitespace (replace multiple spaces with a single space) - text = re.sub(r"\s+", " ", text) + text = NORMALIZE_WHITESPACE.sub(" ", text) # Convert to lowercase and strip text = text.strip() @@ -91,7 +102,7 @@ async def _embed_text(self, text: str) -> np.ndarray: self._embeddings_model, [cleaned_text], n_gpu_layers=self._n_gpu ) # Use only the first entry in the list and make sure we have the appropriate type - logger.debug("Text embedded in semantic routing", text=cleaned_text[:100]) + logger.debug("Text embedded in semantic routing", text=cleaned_text[:50]) return np.array(embed_list[0], dtype=np.float32) async def add_persona(self, persona_name: str, persona_desc: str) -> None: