diff --git a/poetry.lock b/poetry.lock index 66a75b80..79240af8 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1778,13 +1778,13 @@ files = [ [[package]] name = "openai" -version = "1.55.2" +version = "1.55.3" description = "The official Python library for the openai API" optional = false python-versions = ">=3.8" files = [ - {file = "openai-1.55.2-py3-none-any.whl", hash = "sha256:3027c7fa4a33ed759f4a3d076093fcfa1c55658660c889bec33f651e2dc77922"}, - {file = "openai-1.55.2.tar.gz", hash = "sha256:5cc0b1162b65dcdf670b4b41448f18dd470d2724ca04821ab1e86b6b4e88650b"}, + {file = "openai-1.55.3-py3-none-any.whl", hash = "sha256:2a235d0e1e312cd982f561b18c27692e253852f4e5fb6ccf08cb13540a9bdaa1"}, + {file = "openai-1.55.3.tar.gz", hash = "sha256:547e85b94535469f137a779d8770c8c5adebd507c2cc6340ca401a7c4d5d16f0"}, ] [package.dependencies] diff --git a/scripts/import_packages.py b/scripts/import_packages.py index f980d17d..3ab75c00 100644 --- a/scripts/import_packages.py +++ b/scripts/import_packages.py @@ -7,6 +7,7 @@ from weaviate.util import generate_uuid5 from codegate.inference.inference_engine import LlamaCppInferenceEngine +from src.codegate.utils.utils import generate_vector_string class PackageImporter: @@ -37,33 +38,8 @@ def setup_schema(self): ], ) - def generate_vector_string(self, package): - vector_str = f"{package['name']}" - package_url = "" - type_map = { - "pypi": "Python package available on PyPI", - "npm": "JavaScript package available on NPM", - "go": "Go package", - "crates": "Rust package available on Crates", - "java": "Java package", - } - status_messages = { - "archived": "However, this package is found to be archived and no longer maintained.", - "deprecated": "However, this package is found to be deprecated and no longer " - "recommended for use.", - "malicious": "However, this package is found to be malicious.", - } - vector_str += f" is a {type_map.get(package['type'], 'unknown type')} " - package_url = 
f"https://trustypkg.dev/{package['type']}/{package['name']}" - - # Add extra status - status_suffix = status_messages.get(package["status"], "") - if status_suffix: - vector_str += f"{status_suffix} For additional information refer to {package_url}" - return vector_str - async def process_package(self, batch, package): - vector_str = self.generate_vector_string(package) + vector_str = generate_vector_string(package) vector = await self.inference_engine.embed(self.model_path, [vector_str]) # This is where the synchronous call is made batch.add_object(properties=package, vector=vector[0]) diff --git a/src/codegate/pipeline/codegate_context_retriever/__init__.py b/src/codegate/pipeline/codegate_context_retriever/__init__.py new file mode 100644 index 00000000..2bb4ce76 --- /dev/null +++ b/src/codegate/pipeline/codegate_context_retriever/__init__.py @@ -0,0 +1,3 @@ +from codegate.pipeline.codegate_context_retriever.codegate import CodegateContextRetriever + +__all__ = ["CodegateContextRetriever"] diff --git a/src/codegate/pipeline/codegate_context_retriever/codegate.py b/src/codegate/pipeline/codegate_context_retriever/codegate.py new file mode 100644 index 00000000..b87a33a6 --- /dev/null +++ b/src/codegate/pipeline/codegate_context_retriever/codegate.py @@ -0,0 +1,82 @@ +from typing import Optional + +from litellm import ChatCompletionRequest, ChatCompletionSystemMessage + +from codegate.pipeline.base import ( + PipelineContext, + PipelineResult, + PipelineStep, +) +from codegate.storage.storage_engine import StorageEngine +from codegate.utils.utils import generate_vector_string + + +class CodegateContextRetriever(PipelineStep): + """ + Pipeline step that adds a context message to the completion request when it detects + the word "codegate" in the user message. 
+ """ + + def __init__(self, system_prompt_message: Optional[str] = None): + self._system_message = ChatCompletionSystemMessage( + content=system_prompt_message, role="system" + ) + self.storage_engine = StorageEngine() + + @property + def name(self) -> str: + """ + Returns the name of this pipeline step. + """ + return "codegate-context-retriever" + + async def get_objects_from_search(self, search: str) -> list[object]: + objects = await self.storage_engine.search(search) + return objects + + def generate_context_str(self, objects: list[object]) -> str: + context_str = ("Please use the information about related packages " + "to influence your answer:\n") + for obj in objects: + # generate dictionary from object + package_obj = { + "name": obj.properties["name"], + "type": obj.properties["type"], + "status": obj.properties["status"], + "description": obj.properties["description"], + } + package_str = generate_vector_string(package_obj) + context_str += package_str + "\n" + return context_str + + async def process( + self, request: ChatCompletionRequest, context: PipelineContext + ) -> PipelineResult: + """ + Process the completion request and add a system prompt if the user message contains + the word "codegate". 
+ """ + # no prompt configured + if not self._system_message["content"]: + return PipelineResult(request=request) + + last_user_message = self.get_last_user_message(request) + + if last_user_message is not None: + last_user_message_str, last_user_idx = last_user_message + if "codegate" in last_user_message_str.lower(): + # strip codegate from prompt and trim it + last_user_message_str = ( + last_user_message_str.lower().replace("codegate", "").strip() + ) + searched_objects = await self.get_objects_from_search(last_user_message_str) + context_str = self.generate_context_str(searched_objects) + # Add a system prompt to the completion request + new_request = request.copy() + new_request["messages"].insert(last_user_idx, ChatCompletionSystemMessage(content=context_str, role="system")) + return PipelineResult( + request=new_request, + ) + + # Fall through + return PipelineResult(request=request) diff --git a/src/codegate/pipeline/extract_snippets/extract_snippets.py b/src/codegate/pipeline/extract_snippets/extract_snippets.py index a50460ee..bee756e0 100644 --- a/src/codegate/pipeline/extract_snippets/extract_snippets.py +++ b/src/codegate/pipeline/extract_snippets/extract_snippets.py @@ -13,6 +13,7 @@ logger = structlog.get_logger("codegate") + + def ecosystem_from_filepath(filepath: str) -> Optional[str]: """ Determine language from filepath. 
diff --git a/src/codegate/server.py b/src/codegate/server.py index 631824bf..94ecac56 100644 --- a/src/codegate/server.py +++ b/src/codegate/server.py @@ -5,6 +5,7 @@ from codegate import __description__, __version__ from codegate.config import Config from codegate.pipeline.base import PipelineStep, SequentialPipelineProcessor +from codegate.pipeline.codegate_context_retriever.codegate import CodegateContextRetriever from codegate.pipeline.codegate_system_prompt.codegate import CodegateSystemPrompt from codegate.pipeline.extract_snippets.extract_snippets import CodeSnippetExtractor from codegate.pipeline.version.version import CodegateVersion @@ -26,6 +27,7 @@ def init_app() -> FastAPI: CodegateVersion(), CodeSnippetExtractor(), CodegateSystemPrompt(Config.get_config().prompts.codegate_chat), + CodegateContextRetriever(Config.get_config().prompts.codegate_chat), # CodegateSecrets(), ] # Leaving the pipeline empty for now diff --git a/src/codegate/storage/__init__.py b/src/codegate/storage/__init__.py new file mode 100644 index 00000000..032dc8fd --- /dev/null +++ b/src/codegate/storage/__init__.py @@ -0,0 +1,3 @@ +from codegate.storage.storage_engine import StorageEngine + +__all__ = ["StorageEngine"] diff --git a/src/codegate/storage/storage_engine.py b/src/codegate/storage/storage_engine.py new file mode 100644 index 00000000..b27cac7a --- /dev/null +++ b/src/codegate/storage/storage_engine.py @@ -0,0 +1,106 @@ +import structlog +import weaviate +from weaviate.classes.config import DataType +from weaviate.classes.query import MetadataQuery + +from codegate.inference.inference_engine import LlamaCppInferenceEngine + +logger = structlog.get_logger("codegate") + +schema_config = [ + { + "name": "Package", + "properties": [ + {"name": "name", "data_type": DataType.TEXT}, + {"name": "type", "data_type": DataType.TEXT}, + {"name": "status", "data_type": DataType.TEXT}, + {"name": "description", "data_type": DataType.TEXT}, + ], + }, +] + + +class StorageEngine: + def 
get_client(self, data_path): + try: + client = weaviate.WeaviateClient( + embedded_options=weaviate.EmbeddedOptions(persistence_data_path=data_path), + ) + return client + except Exception as e: + logger.error(f"Error during client creation: {str(e)}") + return None + + def __init__(self, data_path="./weaviate_data"): + self.data_path = data_path + self.inference_engine = LlamaCppInferenceEngine() + self.model_path = "./models/all-minilm-L6-v2-q5_k_m.gguf" + self.schema_config = schema_config + + # setup schema for weaviate + weaviate_client = self.get_client(self.data_path) + if weaviate_client is not None: + try: + weaviate_client.connect() + self.setup_schema(weaviate_client) + except Exception as e: + logger.error(f"Failed to connect or setup schema: {str(e)}") + finally: + try: + weaviate_client.close() + except Exception as e: + logger.info(f"Failed to close client: {str(e)}") + else: + logger.error("Could not find client, skipping schema setup.") + + def setup_schema(self, client): + for class_config in self.schema_config: + if not client.collections.exists(class_config["name"]): + client.collections.create( + class_config["name"], properties=class_config["properties"] + ) + logger.info(f"Weaviate schema for class {class_config['name']} setup complete.") + + async def search(self, query: str, limit=5, distance=0.3) -> list[object]: + """ + Search the 'Package' collection based on a query string. + + Args: + query (str): The text query for which to search. + limit (int): The number of results to return. + + Returns: + list: A list of matching results with their properties and distances. 
+ """ + # Generate the vector for the query + query_vector = await self.inference_engine.embed(self.model_path, [query]) + + # Perform the vector search + weaviate_client = self.get_client(self.data_path) + if weaviate_client is None: + logger.error("Could not find client, not returning results.") + return [] + + try: + weaviate_client.connect() + collection = weaviate_client.collections.get("Package") + response = collection.query.near_vector( + query_vector[0], + limit=limit, + distance=distance, + return_metadata=MetadataQuery(distance=True), + ) + + weaviate_client.close() + if not response: + return [] + return response.objects + + except Exception as e: + logger.error(f"Error during search: {str(e)}") + return [] + finally: + try: + weaviate_client.close() + except Exception as e: + logger.info(f"Failed to close client: {str(e)}") diff --git a/src/codegate/utils/utils.py b/src/codegate/utils/utils.py new file mode 100644 index 00000000..e38de0a6 --- /dev/null +++ b/src/codegate/utils/utils.py @@ -0,0 +1,27 @@ +def generate_vector_string(package) -> str: + vector_str = f"{package['name']}" + package_url = "" + type_map = { + "pypi": "Python package available on PyPI", + "npm": "JavaScript package available on NPM", + "go": "Go package", + "crates": "Rust package available on Crates", + "java": "Java package", + } + status_messages = { + "archived": "However, this package is found to be archived and no longer maintained.", + "deprecated": "However, this package is found to be deprecated and no longer " + "recommended for use.", + "malicious": "However, this package is found to be malicious.", + } + vector_str += f" is a {type_map.get(package['type'], 'unknown type')} " + package_url = f"https://trustypkg.dev/{package['type']}/{package['name']}" + + # Add extra status + status_suffix = status_messages.get(package["status"], "") + if status_suffix: + vector_str += f"{status_suffix} For additional information refer to {package_url}" + + # add description + 
vector_str += f" - Package offers this functionality: {package['description']}" + return vector_str diff --git a/src/weaviate_data/classifications.db b/src/weaviate_data/classifications.db new file mode 100644 index 00000000..db8fa1b9 Binary files /dev/null and b/src/weaviate_data/classifications.db differ diff --git a/src/weaviate_data/migration1.19.filter2search.skip.flag b/src/weaviate_data/migration1.19.filter2search.skip.flag new file mode 100644 index 00000000..e69de29b diff --git a/src/weaviate_data/migration1.19.filter2search.state b/src/weaviate_data/migration1.19.filter2search.state new file mode 100644 index 00000000..e69de29b diff --git a/src/weaviate_data/migration1.22.fs.hierarchy b/src/weaviate_data/migration1.22.fs.hierarchy new file mode 100644 index 00000000..e69de29b diff --git a/src/weaviate_data/modules.db b/src/weaviate_data/modules.db new file mode 100644 index 00000000..0b745ffb Binary files /dev/null and b/src/weaviate_data/modules.db differ diff --git a/src/weaviate_data/package/9e4pu9kSqOe2/indexcount b/src/weaviate_data/package/9e4pu9kSqOe2/indexcount new file mode 100644 index 00000000..e69de29b diff --git a/src/weaviate_data/package/9e4pu9kSqOe2/main.hnsw.commitlog.d/1732804018 b/src/weaviate_data/package/9e4pu9kSqOe2/main.hnsw.commitlog.d/1732804018 new file mode 100644 index 00000000..e69de29b diff --git a/src/weaviate_data/package/9e4pu9kSqOe2/proplengths b/src/weaviate_data/package/9e4pu9kSqOe2/proplengths new file mode 100644 index 00000000..ea48a43d --- /dev/null +++ b/src/weaviate_data/package/9e4pu9kSqOe2/proplengths @@ -0,0 +1 @@ +{"BucketedData":{},"SumData":{},"CountData":{},"ObjectCount":0} \ No newline at end of file diff --git a/src/weaviate_data/package/9e4pu9kSqOe2/version b/src/weaviate_data/package/9e4pu9kSqOe2/version new file mode 100644 index 00000000..5407bf3d Binary files /dev/null and b/src/weaviate_data/package/9e4pu9kSqOe2/version differ diff --git a/src/weaviate_data/raft/raft.db 
b/src/weaviate_data/raft/raft.db new file mode 100644 index 00000000..5292bf64 Binary files /dev/null and b/src/weaviate_data/raft/raft.db differ diff --git a/src/weaviate_data/schema.db b/src/weaviate_data/schema.db new file mode 100644 index 00000000..8af6f9f6 Binary files /dev/null and b/src/weaviate_data/schema.db differ diff --git a/tests/test_storage.py b/tests/test_storage.py new file mode 100644 index 00000000..6e6ee6be --- /dev/null +++ b/tests/test_storage.py @@ -0,0 +1,56 @@ +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from codegate.storage.storage_engine import ( + StorageEngine, +) # Adjust the import based on your actual path + + +@pytest.fixture +def mock_weaviate_client(): + client = MagicMock() + response = MagicMock() + response.objects = [ + { + "properties": { + "name": "test", + "type": "library", + "status": "active", + "description": "test description", + } + } + ] + client.collections.get.return_value.query.near_vector.return_value = response + return client + + +@pytest.fixture +def mock_inference_engine(): + engine = AsyncMock() + engine.embed.return_value = [0.1, 0.2, 0.3] # Adjust based on expected vector dimensions + return engine + + +@pytest.mark.asyncio +async def test_search(mock_weaviate_client, mock_inference_engine): + # Patch the WeaviateClient and LlamaCppInferenceEngine inside the test function + with ( + patch("weaviate.WeaviateClient", return_value=mock_weaviate_client), + patch( + "codegate.inference.inference_engine.LlamaCppInferenceEngine", + return_value=mock_inference_engine, + ), + ): + + # Initialize StorageEngine + storage_engine = StorageEngine(data_path="./weaviate_data") + + # Invoke the search method + results = await storage_engine.search("test query", 5, 0.3) + + # Assertions to validate the expected behavior + assert len(results) == 1 # Assert that one result is returned + assert results[0]["properties"]["name"] == "test" + mock_weaviate_client.connect.assert_called() + 
mock_weaviate_client.close.assert_called() diff --git a/weaviate_data/classifications.db b/weaviate_data/classifications.db new file mode 100644 index 00000000..6f0b2331 Binary files /dev/null and b/weaviate_data/classifications.db differ diff --git a/weaviate_data/migration1.19.filter2search.skip.flag b/weaviate_data/migration1.19.filter2search.skip.flag new file mode 100644 index 00000000..e69de29b diff --git a/weaviate_data/migration1.19.filter2search.state b/weaviate_data/migration1.19.filter2search.state new file mode 100644 index 00000000..e69de29b diff --git a/weaviate_data/migration1.22.fs.hierarchy b/weaviate_data/migration1.22.fs.hierarchy new file mode 100644 index 00000000..e69de29b diff --git a/weaviate_data/modules.db b/weaviate_data/modules.db new file mode 100644 index 00000000..0b745ffb Binary files /dev/null and b/weaviate_data/modules.db differ diff --git a/weaviate_data/package/yhcabdxdWUhw/indexcount b/weaviate_data/package/yhcabdxdWUhw/indexcount new file mode 100644 index 00000000..e69de29b diff --git a/weaviate_data/package/yhcabdxdWUhw/main.hnsw.commitlog.d/1732894628 b/weaviate_data/package/yhcabdxdWUhw/main.hnsw.commitlog.d/1732894628 new file mode 100644 index 00000000..e69de29b diff --git a/weaviate_data/package/yhcabdxdWUhw/proplengths b/weaviate_data/package/yhcabdxdWUhw/proplengths new file mode 100644 index 00000000..ea48a43d --- /dev/null +++ b/weaviate_data/package/yhcabdxdWUhw/proplengths @@ -0,0 +1 @@ +{"BucketedData":{},"SumData":{},"CountData":{},"ObjectCount":0} \ No newline at end of file diff --git a/weaviate_data/package/yhcabdxdWUhw/version b/weaviate_data/package/yhcabdxdWUhw/version new file mode 100644 index 00000000..5407bf3d Binary files /dev/null and b/weaviate_data/package/yhcabdxdWUhw/version differ diff --git a/weaviate_data/raft/raft.db b/weaviate_data/raft/raft.db new file mode 100644 index 00000000..1105373f Binary files /dev/null and b/weaviate_data/raft/raft.db differ diff --git a/weaviate_data/schema.db 
b/weaviate_data/schema.db new file mode 100644 index 00000000..f7f96e10 Binary files /dev/null and b/weaviate_data/schema.db differ