This repository was archived by the owner on Jun 5, 2025. It is now read-only.

Cache FIM entries in memory to avoid repeated writes to DB #372

Merged · 1 commit · Dec 16, 2024
2 changes: 2 additions & 0 deletions src/codegate/config.py
@@ -52,6 +52,8 @@ class Config:
server_key: str = "server.key"
force_certs: bool = False

max_fim_hash_lifetime: int = 60 * 5 # Time in seconds. Default is 5 minutes.

# Provider URLs with defaults
provider_urls: Dict[str, str] = field(default_factory=lambda: DEFAULT_PROVIDER_URLS.copy())

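For context, the new setting is read through the existing Config singleton (the same Config.get_config() call used in connection.py below). A minimal illustration, with hypothetical usage:

from codegate.config import Config

# Seconds an in-memory FIM hash stays fresh; identical requests inside this
# window are not re-written to the DB. Defaults to 60 * 5 (5 minutes).
ttl_seconds = Config.get_config().max_fim_hash_lifetime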
85 changes: 80 additions & 5 deletions src/codegate/db/connection.py
@@ -1,5 +1,7 @@
import asyncio
import hashlib
import json
import re
from pathlib import Path
from typing import List, Optional

@@ -8,6 +10,7 @@
from sqlalchemy import text
from sqlalchemy.ext.asyncio import create_async_engine

from codegate.config import Config
from codegate.db.models import Alert, Output, Prompt
from codegate.db.queries import (
AsyncQuerier,
@@ -18,6 +21,7 @@

logger = structlog.get_logger("codegate")
alert_queue = asyncio.Queue()
fim_entries = {}


class DbCodeGate:
@@ -178,14 +182,85 @@ async def record_alerts(self, alerts: List[Alert]) -> List[Alert]:
logger.debug(f"Recorded alerts: {recorded_alerts}")
return recorded_alerts

async def record_context(self, context: PipelineContext) -> None:
logger.info(
f"Recording context in DB. Output chunks: {len(context.output_responses)}. "
f"Alerts: {len(context.alerts_raised)}."
)
def _extract_request_message(self, request: str) -> Optional[dict]:
"""Extract the user message from the FIM request"""
try:
parsed_request = json.loads(request)
except Exception as e:
logger.exception(f"Failed to extract request message: {request}", error=str(e))
return None

messages = [message for message in parsed_request["messages"] if message["role"] == "user"]
if len(messages) != 1:
logger.warning(f"Expected one user message, found {len(messages)}.")
return None

content_message = messages[0].get("content")
return content_message

def _create_hash_key(self, message: str, provider: str) -> str:
"""Creates a hash key from the message and includes the provider"""
# Try to extract the path from the message. Most of the time it is at the top of the message.
# The pattern was generated using ChatGPT. Should match common occurrences like:
# folder/testing_file.py
# Path: file3.py
pattern = r"(?:[a-zA-Z]:\\|\/)?(?:[^\s\/]+\/)*[^\s\/]+\.[^\s\/]+"
match = re.search(pattern, message)
# Copilot is the only provider that has an easy path to extract.
# Other providers are harder to extract. This part needs to be revisited; for the moment
# we hash the entire request message.
if match is None or provider != "copilot":
logger.warning("No path found in message or provider is not copilot. Hashing the entire message.")
message_to_hash = f"{message}-{provider}"
else:
message_to_hash = f"{match.group(0)}-{provider}"

logger.debug(f"Message to hash: {message_to_hash}")
hashed_content = hashlib.sha256(message_to_hash.encode("utf-8")).hexdigest()
logger.debug(f"Hashed contnet: {hashed_content}")
return hashed_content

def _should_record_context(self, context: Optional[PipelineContext]) -> bool:
"""Check if the context should be recorded in DB"""
if context is None or context.metadata.get("stored_in_db", False):
return False

if not context.input_request:
logger.warning("No input request found. Skipping recording context.")
return False

# If it's not a FIM prompt, we don't need to check anything else.
if context.input_request.type != "fim":
return True

# Couldn't process the user message. Skip creating a mapping entry.
message = self._extract_request_message(context.input_request.request)
if message is None:
logger.warning(f"Couldn't read FIM message: {message}. Will not record to DB.")
return False

hash_key = self._create_hash_key(message, context.input_request.provider)
old_timestamp = fim_entries.get(hash_key, None)
if old_timestamp is None:
fim_entries[hash_key] = context.input_request.timestamp
return True

elapsed_seconds = (context.input_request.timestamp - old_timestamp).total_seconds()
if elapsed_seconds < Config.get_config().max_fim_hash_lifetime:
logger.info(f"Skipping context recording. Elapsed time: {elapsed_seconds} seconds.")
return False

# Entry expired: refresh the cached timestamp so the window restarts and record again.
fim_entries[hash_key] = context.input_request.timestamp
return True

async def record_context(self, context: Optional[PipelineContext]) -> None:
if not self._should_record_context(context):
return
await self.record_request(context.input_request)
await self.record_outputs(context.output_responses)
await self.record_alerts(context.alerts_raised)
context.metadata["stored_in_db"] = True
logger.info(
f"Recorded context in DB. Output chunks: {len(context.output_responses)}. "
f"Alerts: {len(context.alerts_raised)}."
)


class DbReader(DbCodeGate):
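As a standalone sketch (not part of the diff), the caching pattern introduced by the new DbRecorder helpers boils down to the following; the sample messages and helper names here are illustrative only:

import hashlib
import re
from datetime import datetime, timedelta

# Mirrors Config.max_fim_hash_lifetime (60 * 5 seconds by default).
MAX_FIM_HASH_LIFETIME = 60 * 5
# Same path pattern used by _create_hash_key above.
PATH_PATTERN = r"(?:[a-zA-Z]:\\|\/)?(?:[^\s\/]+\/)*[^\s\/]+\.[^\s\/]+"

fim_entries: dict[str, datetime] = {}


def hash_key(message: str, provider: str) -> str:
    """Hash the extracted file path (Copilot) or the whole message, plus the provider."""
    match = re.search(PATH_PATTERN, message)
    base = match.group(0) if match and provider == "copilot" else message
    return hashlib.sha256(f"{base}-{provider}".encode("utf-8")).hexdigest()


def should_record(message: str, provider: str, now: datetime) -> bool:
    """Return True when this FIM request should be written to the DB."""
    key = hash_key(message, provider)
    previous = fim_entries.get(key)
    if previous is None or (now - previous).total_seconds() >= MAX_FIM_HASH_LIFETIME:
        fim_entries[key] = now
        return True
    return False


t0 = datetime.now()
print(should_record("Path: file3.py\ndef foo():", "copilot", t0))                          # True: first sighting
print(should_record("Path: file3.py\ndef bar():", "copilot", t0 + timedelta(seconds=30)))  # False: same path within TTL
print(should_record("Path: file3.py\ndef baz():", "copilot", t0 + timedelta(minutes=10)))  # True: TTL expired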
4 changes: 1 addition & 3 deletions src/codegate/pipeline/output.py
@@ -112,9 +112,7 @@ def _store_chunk_content(self, chunk: ModelResponse) -> None:
self._context.processed_content.append(choice.delta.content)

async def _record_to_db(self):
if self._input_context and not self._input_context.metadata.get("stored_in_db", False):
await self._db_recorder.record_context(self._input_context)
self._input_context.metadata["stored_in_db"] = True
await self._db_recorder.record_context(self._input_context)

async def process_stream(
self, stream: AsyncIterator[ModelResponse]
4 changes: 1 addition & 3 deletions src/codegate/providers/base.py
@@ -193,9 +193,7 @@ async def _cleanup_after_streaming(
finally:
if context:
# Record to DB the objects captured during the stream
if not context.metadata.get("stored_in_db", False):
await self._db_recorder.record_context(context)
context.metadata["stored_in_db"] = True
await self._db_recorder.record_context(context)
# Ensure sensitive data is cleaned up after the stream is consumed
if context.sensitive:
context.sensitive.secure_cleanup()
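Since record_context now performs the stored_in_db check itself, the call sites in output.py and base.py above can invoke it unconditionally. A tiny runnable sketch of that idempotent pattern, using hypothetical stand-ins for the recorder and context:

import asyncio
from dataclasses import dataclass, field


@dataclass
class FakeContext:
    # Stand-in for PipelineContext; only the metadata flag matters here.
    metadata: dict = field(default_factory=dict)


class FakeRecorder:
    # Stand-in for DbRecorder.record_context with the guard built in.
    async def record_context(self, context):
        if context is None or context.metadata.get("stored_in_db", False):
            return
        print("writing context to DB")
        context.metadata["stored_in_db"] = True


async def main():
    recorder, ctx = FakeRecorder(), FakeContext()
    await recorder.record_context(ctx)  # writes once and sets the flag
    await recorder.record_context(ctx)  # no-op: flag already set


asyncio.run(main())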
4 changes: 1 addition & 3 deletions src/codegate/providers/copilot/provider.py
@@ -609,9 +609,7 @@ async def stream_iterator():
StreamingChoices(
finish_reason=choice.get("finish_reason", None),
index=0,
delta=Delta(
content=content, role="assistant"
),
delta=Delta(content=content, role="assistant"),
logprobs=None,
)
)