Skip to content
This repository was archived by the owner on Jun 5, 2025. It is now read-only.

Commit 466d495

Browse files
authored
Merge pull request #128 from stacklok/issue-63
feat: enable weaviate usage in codegate
2 parents d4f1ab8 + 529ef2d commit 466d495

31 files changed

+286
-29
lines changed

poetry.lock

Lines changed: 3 additions & 3 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

scripts/import_packages.py

Lines changed: 2 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from weaviate.util import generate_uuid5
88

99
from codegate.inference.inference_engine import LlamaCppInferenceEngine
10+
from src.codegate.utils.utils import generate_vector_string
1011

1112

1213
class PackageImporter:
@@ -37,33 +38,8 @@ def setup_schema(self):
3738
],
3839
)
3940

40-
def generate_vector_string(self, package):
41-
vector_str = f"{package['name']}"
42-
package_url = ""
43-
type_map = {
44-
"pypi": "Python package available on PyPI",
45-
"npm": "JavaScript package available on NPM",
46-
"go": "Go package",
47-
"crates": "Rust package available on Crates",
48-
"java": "Java package",
49-
}
50-
status_messages = {
51-
"archived": "However, this package is found to be archived and no longer maintained.",
52-
"deprecated": "However, this package is found to be deprecated and no longer "
53-
"recommended for use.",
54-
"malicious": "However, this package is found to be malicious.",
55-
}
56-
vector_str += f" is a {type_map.get(package['type'], 'unknown type')} "
57-
package_url = f"https://trustypkg.dev/{package['type']}/{package['name']}"
58-
59-
# Add extra status
60-
status_suffix = status_messages.get(package["status"], "")
61-
if status_suffix:
62-
vector_str += f"{status_suffix} For additional information refer to {package_url}"
63-
return vector_str
64-
6541
async def process_package(self, batch, package):
66-
vector_str = self.generate_vector_string(package)
42+
vector_str = generate_vector_string(package)
6743
vector = await self.inference_engine.embed(self.model_path, [vector_str])
6844
# This is where the synchronous call is made
6945
batch.add_object(properties=package, vector=vector[0])
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
from codegate.pipeline.codegate_context_retriever.codegate import CodegateContextRetriever
2+
3+
__all__ = ["CodegateContextRetriever"]
Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
from typing import Optional
2+
3+
from litellm import ChatCompletionRequest, ChatCompletionSystemMessage
4+
5+
from codegate.pipeline.base import (
6+
PipelineContext,
7+
PipelineResult,
8+
PipelineStep,
9+
)
10+
from src.codegate.storage.storage_engine import StorageEngine
11+
from src.codegate.utils.utils import generate_vector_string
12+
13+
14+
class CodegateContextRetriever(PipelineStep):
    """
    Pipeline step that adds a context message to the completion request when it detects
    the word "codegate" in the user message.

    The context is built from the objects returned by a storage-engine search on the
    user's message (with the word "codegate" stripped out).
    """

    def __init__(self, system_prompt_message: Optional[str] = None):
        """
        Args:
            system_prompt_message: Configured prompt text. When it is empty or None,
                `process` returns every request unchanged.
        """
        self._system_message = ChatCompletionSystemMessage(
            content=system_prompt_message, role="system"
        )
        self.storage_engine = StorageEngine()

    @property
    def name(self) -> str:
        """
        Returns the name of this pipeline step.
        """
        return "codegate-context-retriever"

    async def get_objects_from_search(self, search: str) -> list[object]:
        """
        Return the objects the storage engine finds for the given search text.
        """
        return await self.storage_engine.search(search)

    def generate_context_str(self, objects: list[object]) -> str:
        """
        Build the context text: a header line followed by one line per object.

        Each object must expose a `properties` mapping with the keys
        "name", "type", "status" and "description".
        """
        # The parentheses matter: without them the second literal is a separate
        # expression statement and the end of the header is silently dropped.
        context_str = (
            "Please use the information about related packages "
            "to influence your answer:\n"
        )
        for obj in objects:
            # generate dictionary from object
            package_obj = {
                "name": obj.properties["name"],
                "type": obj.properties["type"],
                "status": obj.properties["status"],
                "description": obj.properties["description"],
            }
            context_str += generate_vector_string(package_obj) + "\n"
        return context_str

    async def process(
        self, request: ChatCompletionRequest, context: PipelineContext
    ) -> PipelineResult:
        """
        Process the completion request and add a context message if the user message
        contains the word "codegate".

        The incoming request is never modified; when context is added, a copy of the
        request with its own message list is returned.
        """
        # no prompt configured
        if not self._system_message["content"]:
            return PipelineResult(request=request)

        # NOTE(review): get_last_user_message is not defined in this class; presumably
        # it is inherited from PipelineStep and returns (text, index) or None.
        last_user_message = self.get_last_user_message(request)
        if last_user_message is None:
            return PipelineResult(request=request)

        last_user_message_str, last_user_idx = last_user_message
        if "codegate" not in last_user_message_str.lower():
            return PipelineResult(request=request)

        # strip codegate from prompt and trim it
        search_query = last_user_message_str.lower().replace("codegate", "").strip()
        searched_objects = await self.get_objects_from_search(search_query)
        context_str = self.generate_context_str(searched_objects)

        # request.copy() is shallow, so copy the message list as well; inserting
        # into the shared list would also change the caller's request.
        new_request = request.copy()
        messages = list(new_request["messages"])
        # Entries of "messages" are chat message objects, not bare strings.
        messages.insert(
            last_user_idx,
            ChatCompletionSystemMessage(content=context_str, role="system"),
        )
        new_request["messages"] = messages
        return PipelineResult(request=new_request)

src/codegate/server.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from codegate import __description__, __version__
66
from codegate.config import Config
77
from codegate.pipeline.base import PipelineStep, SequentialPipelineProcessor
8+
from codegate.pipeline.codegate_context_retriever.codegate import CodegateContextRetriever
89
from codegate.pipeline.codegate_system_prompt.codegate import CodegateSystemPrompt
910
from codegate.pipeline.extract_snippets.extract_snippets import CodeSnippetExtractor
1011
from codegate.pipeline.version.version import CodegateVersion
@@ -27,6 +28,7 @@ def init_app() -> FastAPI:
2728
CodegateVersion(),
2829
CodeSnippetExtractor(),
2930
CodegateSystemPrompt(Config.get_config().prompts.codegate_chat),
31+
CodegateContextRetriever(Config.get_config().prompts.codegate_chat),
3032
# CodegateSecrets(),
3133
]
3234
# Leaving the pipeline empty for now

src/codegate/storage/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
from codegate.storage.storage_engine import StorageEngine
2+
3+
# __all__ must list names as strings; putting the class object here makes
# `from codegate.storage import *` raise TypeError.
__all__ = ["StorageEngine"]
Lines changed: 106 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,106 @@
1+
import structlog
2+
import weaviate
3+
from weaviate.classes.config import DataType
4+
from weaviate.classes.query import MetadataQuery
5+
6+
from codegate.inference.inference_engine import LlamaCppInferenceEngine
7+
8+
logger = structlog.get_logger("codegate")
9+
10+
schema_config = [
11+
{
12+
"name": "Package",
13+
"properties": [
14+
{"name": "name", "data_type": DataType.TEXT},
15+
{"name": "type", "data_type": DataType.TEXT},
16+
{"name": "status", "data_type": DataType.TEXT},
17+
{"name": "description", "data_type": DataType.TEXT},
18+
],
19+
},
20+
]
21+
22+
23+
class StorageEngine:
    """
    Access to the 'Package' collection stored in an embedded Weaviate instance.

    A new client is created, connected and closed for each operation (schema setup
    in the constructor, and every search).
    """

    def __init__(self, data_path="./weaviate_data"):
        """
        Store the configuration and make sure the collections in `schema_config` exist.

        Args:
            data_path: Persistence directory passed to the embedded Weaviate client.

        Failures while creating the client, connecting, or setting up the schema are
        logged and do not raise.
        """
        self.data_path = data_path
        self.inference_engine = LlamaCppInferenceEngine()
        self.model_path = "./models/all-minilm-L6-v2-q5_k_m.gguf"
        self.schema_config = schema_config

        # setup schema for weaviate
        weaviate_client = self.get_client(self.data_path)
        if weaviate_client is None:
            logger.error("Could not find client, skipping schema setup.")
            return

        try:
            weaviate_client.connect()
            self.setup_schema(weaviate_client)
        except Exception as e:
            logger.error(f"Failed to connect or setup schema: {str(e)}")
        finally:
            self._close_client(weaviate_client)

    def get_client(self, data_path):
        """
        Create an embedded Weaviate client persisting to `data_path`.

        Returns:
            The (not yet connected) client, or None if creation failed; the error is logged.
        """
        try:
            return weaviate.WeaviateClient(
                embedded_options=weaviate.EmbeddedOptions(persistence_data_path=data_path),
            )
        except Exception as e:
            logger.error(f"Error during client creation: {str(e)}")
            return None

    def _close_client(self, client):
        """Close `client`, logging (not raising) any error from the close call."""
        try:
            client.close()
        except Exception as e:
            logger.info(f"Failed to close client: {str(e)}")

    def setup_schema(self, client):
        """
        Create every collection listed in `self.schema_config` that does not exist yet.

        Args:
            client: A connected Weaviate client.
        """
        for class_config in self.schema_config:
            if not client.collections.exists(class_config["name"]):
                client.collections.create(
                    class_config["name"], properties=class_config["properties"]
                )
            logger.info(f"Weaviate schema for class {class_config['name']} setup complete.")

    async def search(self, query: str, limit=5, distance=0.3) -> list[object]:
        """
        Search the 'Package' collection based on a query string.

        Args:
            query (str): The text query for which to search.
            limit (int): The number of results to return.
            distance (float): Distance value passed through to the near-vector query.

        Returns:
            list: A list of matching results with their properties and distances.
            An empty list is returned when the client cannot be created or the
            search fails; the error is logged.
        """
        # Generate the vector for the query
        query_vector = await self.inference_engine.embed(self.model_path, [query])

        # Perform the vector search
        weaviate_client = self.get_client(self.data_path)
        if weaviate_client is None:
            logger.error("Could not find client, not returning results.")
            return []

        try:
            weaviate_client.connect()
            collection = weaviate_client.collections.get("Package")
            response = collection.query.near_vector(
                query_vector[0],
                limit=limit,
                distance=distance,
                return_metadata=MetadataQuery(distance=True),
            )
            if not response:
                return []
            return response.objects
        except Exception as e:
            logger.error(f"Error during search: {str(e)}")
            return []
        finally:
            # Single cleanup point: runs on success and on failure, so no
            # separate close() call is needed on the success path.
            self._close_client(weaviate_client)

src/codegate/utils/utils.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
def generate_vector_string(package) -> str:
    """
    Build the one-line textual description of a package.

    Args:
        package: Mapping with the keys "name", "type", "status" and "description".

    Returns:
        "<name> is a <type description> ", followed by a warning sentence and a
        trustypkg.dev link when the status is "archived", "deprecated" or
        "malicious", followed by the package description.
    """
    kind_descriptions = {
        "pypi": "Python package available on PyPI",
        "npm": "JavaScript package available on NPM",
        "go": "Go package",
        "crates": "Rust package available on Crates",
        "java": "Java package",
    }
    status_notes = {
        "archived": "However, this package is found to be archived and no longer maintained.",
        "deprecated": "However, this package is found to be deprecated and no longer "
        "recommended for use.",
        "malicious": "However, this package is found to be malicious.",
    }

    pkg_name = package["name"]
    pkg_type = package["type"]
    pieces = [f"{pkg_name} is a {kind_descriptions.get(pkg_type, 'unknown type')} "]

    # Add extra status
    note = status_notes.get(package["status"], "")
    if note:
        reference_url = f"https://trustypkg.dev/{pkg_type}/{pkg_name}"
        pieces.append(f"{note} For additional information refer to {reference_url}")

    # add description
    pieces.append(f" - Package offers this functionality: {package['description']}")
    return "".join(pieces)

src/weaviate_data/classifications.db

128 KB
Binary file not shown.

src/weaviate_data/migration1.19.filter2search.skip.flag

Whitespace-only changes.

0 commit comments

Comments
 (0)