diff --git a/src/codegate/llm_utils/extractor.py b/src/codegate/llm_utils/extractor.py
index 4325f909..1f95270e 100644
--- a/src/codegate/llm_utils/extractor.py
+++ b/src/codegate/llm_utils/extractor.py
@@ -30,7 +30,7 @@ async def extract_packages(
         system_prompt = Config.get_config().prompts.lookup_packages
 
         result = await LLMClient.complete(
-            content=content,
+            content=content.lower(),
             system_prompt=system_prompt,
             provider=provider,
             model=model,
@@ -41,6 +41,10 @@ async def extract_packages(
 
         # Handle both formats: {"packages": [...]} and direct list [...]
         packages = result if isinstance(result, list) else result.get("packages", [])
+
+        # Keep only packages named in the content; lower-case both sides before comparing
+        lowered_content = content.lower()
+        packages = [pkg for pkg in map(str.lower, packages) if pkg in lowered_content]
         logger.info(f"Extracted packages: {packages}")
         return packages
 
diff --git a/src/codegate/pipeline/codegate_context_retriever/codegate.py b/src/codegate/pipeline/codegate_context_retriever/codegate.py
index eab0b9bc..13853b1a 100644
--- a/src/codegate/pipeline/codegate_context_retriever/codegate.py
+++ b/src/codegate/pipeline/codegate_context_retriever/codegate.py
@@ -29,12 +29,12 @@ def name(self) -> str:
         """
         return "codegate-context-retriever"
 
-    async def get_objects_from_search(
-        self, search: str, ecosystem, packages: list[str] = None
+    async def get_objects_from_db(
+        self, ecosystem, packages: list[str] = None
     ) -> list[object]:
         storage_engine = StorageEngine()
         objects = await storage_engine.search(
-            search, distance=0.8, ecosystem=ecosystem, packages=packages
+            distance=0.8, ecosystem=ecosystem, packages=packages
         )
         return objects
 
@@ -103,39 +103,25 @@ async def process(
         # Extract packages from the user message
         ecosystem = await self.__lookup_ecosystem(user_messages, context)
         packages = await self.__lookup_packages(user_messages, context)
-        packages = [pkg.lower() for pkg in packages]
 
-        # If user message does not reference any packages, then just return
-        if len(packages) == 0:
-            return PipelineResult(request=request)
-
-        # Look for matches in vector DB using list of packages as filter
-        searched_objects = await self.get_objects_from_search(user_messages, ecosystem, packages)
+        context_str = "CodeGate did not find any malicious or archived packages."
 
-        logger.info(
-            f"Found {len(searched_objects)} matches in the database",
-            searched_objects=searched_objects,
-        )
+        if len(packages) > 0:
+            # Look for matches in DB using packages and ecosystem
+            searched_objects = await self.get_objects_from_db(ecosystem, packages)
 
-        # Remove searched objects that are not in packages. This is needed
-        # since Weaviate performs substring match in the filter.
-        updated_searched_objects = []
-        for searched_object in searched_objects:
-            if searched_object.properties["name"].lower() in packages:
-                updated_searched_objects.append(searched_object)
-        searched_objects = updated_searched_objects
+            logger.info(
+                f"Found {len(searched_objects)} matches in the database",
+                searched_objects=searched_objects,
+            )
 
-        # Generate context string using the searched objects
-        logger.info(f"Adding {len(searched_objects)} packages to the context")
+            # Generate context string using the searched objects
+            logger.info(f"Adding {len(searched_objects)} packages to the context")
 
-        if len(searched_objects) > 0:
-            context_str = self.generate_context_str(searched_objects, context)
-        else:
-            context_str = "CodeGate did not find any malicious or archived packages."
+            if len(searched_objects) > 0:
+                context_str = self.generate_context_str(searched_objects, context)
 
         last_user_idx = self.get_last_user_message_idx(request)
-        if last_user_idx == -1:
-            return PipelineResult(request=request, context=context)
 
         # Make a copy of the request
         new_request = request.copy()
diff --git a/src/codegate/storage/storage_engine.py b/src/codegate/storage/storage_engine.py
index e58646a6..5aaf91b4 100644
--- a/src/codegate/storage/storage_engine.py
+++ b/src/codegate/storage/storage_engine.py
@@ -121,7 +121,6 @@ async def search_by_property(self, name: str, properties: List[str]) -> list[obj
             return []
 
         try:
-
             packages = self.weaviate_client.collections.get("Package")
             response = packages.query.fetch_objects(
                 filters=Filter.by_property(name).contains_any(properties),
@@ -145,10 +144,19 @@ async def search_by_property(self, name: str, properties: List[str]) -> list[obj
             return []
 
     async def search(
-        self, query: str, limit=5, distance=0.3, ecosystem=None, packages=None
+        self,
+        query: str = None,
+        ecosystem: str = None,
+        packages: List[str] = None,
+        limit: int = 5,
+        distance: float = 0.3,
     ) -> list[object]:
         """
-        Search the 'Package' collection based on a query string.
+        Search the 'Package' collection based on a query string, ecosystem and packages.
+        If packages and ecosystem are both not none, then filter the objects using them.
+        If packages is given but ecosystem is missing or unsupported, then filter the
+        objects using package names only.
+        If packages is none, then perform vector search.
 
         Args:
             query (str): The text query for which to search.
@@ -160,26 +168,40 @@ async def search(
         Returns:
             list: A list of matching results with their properties and distances.
         """
-        # Generate the vector for the query
-        query_vector = await self.inference_engine.embed(self.model_path, [query])
 
-        # Perform the vector search
         try:
             collection = self.weaviate_client.collections.get("Package")
-            if packages:
-                # filter by packages and ecosystem if present
-                filters = []
-                if ecosystem and ecosystem in VALID_ECOSYSTEMS:
-                    filters.append(wvc.query.Filter.by_property("type").equal(ecosystem))
-                filters.append(wvc.query.Filter.by_property("name").contains_any(packages))
-                response = collection.query.near_vector(
-                    query_vector[0],
-                    limit=limit,
-                    distance=distance,
-                    filters=wvc.query.Filter.all_of(filters),
-                    return_metadata=MetadataQuery(distance=True),
+
+            response = None
+            if packages and ecosystem and ecosystem in VALID_ECOSYSTEMS:
+                response = collection.query.fetch_objects(
+                    filters=wvc.query.Filter.all_of([
+                        wvc.query.Filter.by_property("name").contains_any(packages),
+                        wvc.query.Filter.by_property("type").equal(ecosystem)
+                    ]),
                 )
-            else:
+                response.objects = [
+                    obj
+                    for obj in response.objects
+                    if obj.properties["name"].lower() in packages
+                    and obj.properties["type"].lower() == ecosystem.lower()
+                ]
+            elif packages:
+                response = collection.query.fetch_objects(
+                    filters=wvc.query.Filter.all_of([
+                        wvc.query.Filter.by_property("name").contains_any(packages),
+                    ]),
+                )
+                response.objects = [
+                    obj
+                    for obj in response.objects
+                    if obj.properties["name"].lower() in packages
+                ]
+            elif query:
+                # Perform the vector search
+                # Generate the vector for the query
+                query_vector = await self.inference_engine.embed(self.model_path, [query])
+
                 response = collection.query.near_vector(
                     query_vector[0],
                     limit=limit,
diff --git a/tests/test_storage.py b/tests/test_storage.py
index 965bf071..dafb1547 100644
--- a/tests/test_storage.py
+++ b/tests/test_storage.py
@@ -52,7 +52,7 @@ async def test_search(mock_weaviate_client, mock_inference_engine):
         storage_engine = StorageEngine.recreate_instance(data_path="./weaviate_data")
 
         # Invoke the search method
-        results = await storage_engine.search("test query", 5, 0.3)
+        results = await storage_engine.search(query="test query", limit=5, distance=0.3)
 
         # Assertions to validate the expected behavior
         assert len(results) == 1  # Assert that one result is returned