Skip to content
This repository was archived by the owner on Jun 5, 2025. It is now read-only.

Extract and process code snippets in the user query #493

Merged
merged 1 commit into from
Jan 6, 2025
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 32 additions & 10 deletions src/codegate/pipeline/codegate_context_retriever/codegate.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import json
import re

import structlog
from litellm import ChatCompletionRequest
Expand All @@ -9,7 +10,9 @@
PipelineResult,
PipelineStep,
)
from codegate.pipeline.extract_snippets.extract_snippets import extract_snippets
from codegate.storage.storage_engine import StorageEngine
from codegate.utils.package_extractor import PackageExtractor
from codegate.utils.utils import generate_vector_string

logger = structlog.get_logger("codegate")
Expand Down Expand Up @@ -64,26 +67,45 @@ async def process(
if len(user_messages) == 0:
return PipelineResult(request=request)

context_str = "CodeGate did not find any malicious or archived packages."
# Create storage engine object
storage_engine = StorageEngine()

# Extract any code snippets
snippets = extract_snippets(user_messages)

# Collect all packages referenced in the snippets
snippet_packages = []
for snippet in snippets:
snippet_packages.extend(
PackageExtractor.extract_packages(snippet.code, snippet.language)
)
logger.info(f"Found {len(snippet_packages)} packages in code snippets.")

# Find bad packages in the snippets
bad_snippet_packages = await storage_engine.search_by_property("name", snippet_packages)
logger.info(f"Found {len(bad_snippet_packages)} bad packages in code snippets.")

# Remove code snippets from the user messages and search for bad packages
# in the rest of the user query/messsages
user_messages = re.sub(r"```.*?```", "", user_messages, flags=re.DOTALL)

# Vector search to find bad packages
storage_engine = StorageEngine()
searched_objects = await storage_engine.search(query=user_messages, distance=0.8, limit=100)
bad_packages = await storage_engine.search(query=user_messages, distance=0.8, limit=100)

logger.info(
f"Found {len(searched_objects)} matches in the database",
searched_objects=searched_objects,
)
# All bad packages
all_bad_packages = bad_snippet_packages + bad_packages

logger.info(f"Adding {len(all_bad_packages)} bad packages to the context.")

# Generate context string using the searched objects
logger.info(f"Adding {len(searched_objects)} packages to the context")
context_str = "CodeGate did not find any malicious or archived packages."

# Nothing to do if no bad packages are found
if len(searched_objects) == 0:
if len(all_bad_packages) == 0:
return PipelineResult(request=request, context=context)
else:
# Add context for bad packages
context_str = self.generate_context_str(searched_objects, context)
context_str = self.generate_context_str(all_bad_packages, context)
context.bad_packages_found = True

last_user_idx = self.get_last_user_message_idx(request)
Expand Down
Loading