diff --git a/src/codegate/pipeline/codegate_context_retriever/codegate.py b/src/codegate/pipeline/codegate_context_retriever/codegate.py index 08bc1ed6..9bbd8565 100644 --- a/src/codegate/pipeline/codegate_context_retriever/codegate.py +++ b/src/codegate/pipeline/codegate_context_retriever/codegate.py @@ -1,4 +1,5 @@ import json +import re import structlog from litellm import ChatCompletionRequest @@ -9,7 +10,9 @@ PipelineResult, PipelineStep, ) +from codegate.pipeline.extract_snippets.extract_snippets import extract_snippets from codegate.storage.storage_engine import StorageEngine +from codegate.utils.package_extractor import PackageExtractor from codegate.utils.utils import generate_vector_string logger = structlog.get_logger("codegate") @@ -64,26 +67,45 @@ async def process( if len(user_messages) == 0: return PipelineResult(request=request) - context_str = "CodeGate did not find any malicious or archived packages." + # Create storage engine object + storage_engine = StorageEngine() + + # Extract any code snippets + snippets = extract_snippets(user_messages) + + # Collect all packages referenced in the snippets + snippet_packages = [] + for snippet in snippets: + snippet_packages.extend( + PackageExtractor.extract_packages(snippet.code, snippet.language) + ) + logger.info(f"Found {len(snippet_packages)} packages in code snippets.") + + # Find bad packages in the snippets + bad_snippet_packages = await storage_engine.search_by_property("name", snippet_packages) + logger.info(f"Found {len(bad_snippet_packages)} bad packages in code snippets.") + + # Remove code snippets from the user messages and search for bad packages + # in the rest of the user query/messsages + user_messages = re.sub(r"```.*?```", "", user_messages, flags=re.DOTALL) # Vector search to find bad packages - storage_engine = StorageEngine() - searched_objects = await storage_engine.search(query=user_messages, distance=0.8, limit=100) + bad_packages = await storage_engine.search(query=user_messages, distance=0.8, limit=100) - logger.info( - f"Found {len(searched_objects)} matches in the database", - searched_objects=searched_objects, - ) + # All bad packages + all_bad_packages = bad_snippet_packages + bad_packages + + logger.info(f"Adding {len(all_bad_packages)} bad packages to the context.") # Generate context string using the searched objects - logger.info(f"Adding {len(searched_objects)} packages to the context") + context_str = "CodeGate did not find any malicious or archived packages." # Nothing to do if no bad packages are found - if len(searched_objects) == 0: + if len(all_bad_packages) == 0: return PipelineResult(request=request, context=context) else: # Add context for bad packages - context_str = self.generate_context_str(searched_objects, context) + context_str = self.generate_context_str(all_bad_packages, context) context.bad_packages_found = True last_user_idx = self.get_last_user_message_idx(request)