From 3f620a2fda3d5f887ad61aa69f2ed4023a27cd5c Mon Sep 17 00:00:00 2001
From: Alejandro Ponce
Date: Mon, 16 Dec 2024 15:02:04 +0100
Subject: [PATCH] Cache FIM entries in memory to avoid repeated writes to DB

---
 src/codegate/config.py                     |  2 +
 src/codegate/db/connection.py              | 89 ++++++++++++++++++++--
 src/codegate/pipeline/output.py            |  4 +-
 src/codegate/providers/base.py             |  4 +-
 src/codegate/providers/copilot/provider.py |  4 +-
 5 files changed, 89 insertions(+), 14 deletions(-)

diff --git a/src/codegate/config.py b/src/codegate/config.py
index be7b0143..76d947f2 100644
--- a/src/codegate/config.py
+++ b/src/codegate/config.py
@@ -52,6 +52,8 @@ class Config:
     server_key: str = "server.key"
     force_certs: bool = False
 
+    max_fim_hash_lifetime: int = 60 * 5  # Time in seconds. Default is 5 minutes.
+
     # Provider URLs with defaults
     provider_urls: Dict[str, str] = field(default_factory=lambda: DEFAULT_PROVIDER_URLS.copy())
 
diff --git a/src/codegate/db/connection.py b/src/codegate/db/connection.py
index 5dbe14fa..521cef7d 100644
--- a/src/codegate/db/connection.py
+++ b/src/codegate/db/connection.py
@@ -1,5 +1,7 @@
 import asyncio
+import hashlib
 import json
+import re
 from pathlib import Path
 from typing import List, Optional
 
@@ -8,6 +10,7 @@
 from sqlalchemy import text
 from sqlalchemy.ext.asyncio import create_async_engine
 
+from codegate.config import Config
 from codegate.db.models import Alert, Output, Prompt
 from codegate.db.queries import (
     AsyncQuerier,
@@ -18,6 +21,7 @@
 
 logger = structlog.get_logger("codegate")
 alert_queue = asyncio.Queue()
+fim_entries = {}
 
 
 class DbCodeGate:
@@ -178,14 +182,89 @@ async def record_alerts(self, alerts: List[Alert]) -> List[Alert]:
         logger.debug(f"Recorded alerts: {recorded_alerts}")
         return recorded_alerts
 
-    async def record_context(self, context: PipelineContext) -> None:
-        logger.info(
-            f"Recording context in DB. Output chunks: {len(context.output_responses)}. "
-            f"Alerts: {len(context.alerts_raised)}."
-        )
+    def _extract_request_message(self, request: str) -> Optional[dict]:
+        """Extract the user message from the FIM request"""
+        try:
+            parsed_request = json.loads(request)
+        except Exception as e:
+            logger.exception(f"Failed to extract request message: {request}", error=str(e))
+            return None
+
+        messages = [message for message in parsed_request["messages"] if message["role"] == "user"]
+        if len(messages) != 1:
+            logger.warning(f"Expected one user message, found {len(messages)}.")
+            return None
+
+        content_message = messages[0].get("content")
+        return content_message
+
+    def _create_hash_key(self, message: str, provider: str) -> str:
+        """Creates a hash key from the message and includes the provider"""
+        # Try to extract the path from the message. Most of the times is at the top of the message.
+        # The pattern was generated using ChatGPT. Should match common occurrences like:
+        # folder/testing_file.py
+        # Path: file3.py
+        pattern = r"(?:[a-zA-Z]:\\|\/)?(?:[^\s\/]+\/)*[^\s\/]+\.[^\s\/]+"
+        match = re.search(pattern, message)
+        # Copilot is the only provider that has an easy path to extract.
+        # Other providers are harder to extract. This part needs to be revisited; for the moment
+        # we hash the entire request message.
+        if match is None or provider != "copilot":
+            logger.warning("No path found in message or not copilot. Creating hash from message.")
+            message_to_hash = f"{message}-{provider}"
+        else:
+            message_to_hash = f"{match.group(0)}-{provider}"
+
+        logger.debug(f"Message to hash: {message_to_hash}")
+        hashed_content = hashlib.sha256(message_to_hash.encode("utf-8")).hexdigest()
+        logger.debug(f"Hashed contnet: {hashed_content}")
+        return hashed_content
+
+    def _should_record_context(self, context: Optional[PipelineContext]) -> bool:
+        """Check if the context should be recorded in DB"""
+        if context is None or context.metadata.get("stored_in_db", False):
+            return False
+
+        if not context.input_request:
+            logger.warning("No input request found. Skipping recording context.")
+            return False
+
+        # If it's not a FIM prompt, we don't need to check anything else.
+        if context.input_request.type != "fim":
+            return True
+
+        # Couldn't process the user message. Skip creating a mapping entry.
+        message = self._extract_request_message(context.input_request.request)
+        if message is None:
+            logger.warning(f"Couldn't read FIM message: {message}. Will not record to DB.")
+            return False
+
+        hash_key = self._create_hash_key(message, context.input_request.provider)
+        old_timestamp = fim_entries.get(hash_key, None)
+        if old_timestamp is None:
+            fim_entries[hash_key] = context.input_request.timestamp
+            return True
+
+        elapsed_seconds = (context.input_request.timestamp - old_timestamp).total_seconds()
+        if elapsed_seconds < Config.get_config().max_fim_hash_lifetime:
+            logger.info(f"Skipping context recording. Elapsed time: {elapsed_seconds} seconds.")
+            return False
+
+        # The cached entry expired: refresh its timestamp so the FIM request is recorded again.
+        fim_entries[hash_key] = context.input_request.timestamp
+        return True
+
+    async def record_context(self, context: Optional[PipelineContext]) -> None:
+        if not self._should_record_context(context):
+            return
         await self.record_request(context.input_request)
         await self.record_outputs(context.output_responses)
         await self.record_alerts(context.alerts_raised)
+        context.metadata["stored_in_db"] = True
+        logger.info(
+            f"Recorded context in DB. Output chunks: {len(context.output_responses)}. "
+            f"Alerts: {len(context.alerts_raised)}."
+        )
 
 
 class DbReader(DbCodeGate):
diff --git a/src/codegate/pipeline/output.py b/src/codegate/pipeline/output.py
index ad4f14b9..43751786 100644
--- a/src/codegate/pipeline/output.py
+++ b/src/codegate/pipeline/output.py
@@ -112,9 +112,7 @@ def _store_chunk_content(self, chunk: ModelResponse) -> None:
                 self._context.processed_content.append(choice.delta.content)
 
     async def _record_to_db(self):
-        if self._input_context and not self._input_context.metadata.get("stored_in_db", False):
-            await self._db_recorder.record_context(self._input_context)
-            self._input_context.metadata["stored_in_db"] = True
+        await self._db_recorder.record_context(self._input_context)
 
     async def process_stream(
         self, stream: AsyncIterator[ModelResponse]
diff --git a/src/codegate/providers/base.py b/src/codegate/providers/base.py
index a0350737..f6e696be 100644
--- a/src/codegate/providers/base.py
+++ b/src/codegate/providers/base.py
@@ -193,9 +193,7 @@ async def _cleanup_after_streaming(
         finally:
             if context:
                 # Record to DB the objects captured during the stream
-                if not context.metadata.get("stored_in_db", False):
-                    await self._db_recorder.record_context(context)
-                    context.metadata["stored_in_db"] = True
+                await self._db_recorder.record_context(context)
                 # Ensure sensitive data is cleaned up after the stream is consumed
                 if context.sensitive:
                     context.sensitive.secure_cleanup()
diff --git a/src/codegate/providers/copilot/provider.py b/src/codegate/providers/copilot/provider.py
index b514f3bc..6580bf92 100644
--- a/src/codegate/providers/copilot/provider.py
+++ b/src/codegate/providers/copilot/provider.py
@@ -609,9 +609,7 @@ async def stream_iterator():
                             StreamingChoices(
                                 finish_reason=choice.get("finish_reason", None),
                                 index=0,
-                                delta=Delta(
-                                    content=content, role="assistant"
-                                ),
+                                delta=Delta(content=content, role="assistant"),
                                 logprobs=None,
                             )
                         )