Skip to content
This repository was archived by the owner on Jun 5, 2025. It is now read-only.

Commit abafcc0

Browse files
committed
inject data from weaviate in context
1 parent ef28395 commit abafcc0

File tree

6 files changed

+117
-29
lines changed

6 files changed

+117
-29
lines changed

scripts/import_packages.py

Lines changed: 2 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import asyncio
22
import json
33

4+
from src.codegate.utils.utils import generate_vector_string
45
import weaviate
56
from weaviate.classes.config import DataType, Property
67
from weaviate.embedded import EmbeddedOptions
@@ -37,33 +38,8 @@ def setup_schema(self):
3738
],
3839
)
3940

40-
def generate_vector_string(self, package):
41-
vector_str = f"{package['name']}"
42-
package_url = ""
43-
type_map = {
44-
"pypi": "Python package available on PyPI",
45-
"npm": "JavaScript package available on NPM",
46-
"go": "Go package",
47-
"crates": "Rust package available on Crates",
48-
"java": "Java package",
49-
}
50-
status_messages = {
51-
"archived": "However, this package is found to be archived and no longer maintained.",
52-
"deprecated": "However, this package is found to be deprecated and no longer "
53-
"recommended for use.",
54-
"malicious": "However, this package is found to be malicious.",
55-
}
56-
vector_str += f" is a {type_map.get(package['type'], 'unknown type')} "
57-
package_url = f"https://trustypkg.dev/{package['type']}/{package['name']}"
58-
59-
# Add extra status
60-
status_suffix = status_messages.get(package["status"], "")
61-
if status_suffix:
62-
vector_str += f"{status_suffix} For additional information refer to {package_url}"
63-
return vector_str
64-
6541
async def process_package(self, batch, package):
66-
vector_str = self.generate_vector_string(package)
42+
vector_str = generate_vector_string(package)
6743
vector = await self.inference_engine.embed(self.model_path, [vector_str])
6844
# This is where the synchronous call is made
6945
batch.add_object(properties=package, vector=vector[0])
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
from codegate.pipeline.codegate_context_retriever.codegate import CodegateContextRetriever
2+
3+
__all__ = ["CodegateContextRetriever"]
Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
from typing import Optional
2+
3+
from litellm import ChatCompletionRequest, ChatCompletionSystemMessage
4+
5+
from codegate.pipeline.base import (
6+
PipelineContext,
7+
PipelineResult,
8+
PipelineStep,
9+
)
10+
from src.codegate.storage.storage_engine import StorageEngine
11+
from src.codegate.utils.utils import generate_vector_string
12+
13+
14+
class CodegateContextRetriever(PipelineStep):
    """
    Pipeline step that adds a context message to the completion request when it detects
    the word "codegate" in the user message.

    The context is built from the objects returned by the storage engine search for the
    user's message (with the word "codegate" stripped out).
    """

    def __init__(self, system_prompt_message: Optional[str] = None):
        """
        Args:
            system_prompt_message: Configured prompt text. When it is empty or None the
                step is disabled and `process` returns requests untouched.
        """
        self._system_message = ChatCompletionSystemMessage(
            content=system_prompt_message,
            role="system",
        )
        self.storage_engine = StorageEngine()

    @property
    def name(self) -> str:
        """
        Returns the name of this pipeline step.
        """
        return "codegate-context-retriever"

    async def get_objects_from_search(self, search: str) -> list[object]:
        """
        Returns whatever the storage engine's search yields for the given query string.
        """
        return await self.storage_engine.search(search)

    def generate_context_str(self, objects: list[object]) -> str:
        """
        Builds the context text: a fixed header followed by one line per object.

        Each object must expose a `properties` mapping with the keys "name", "type",
        "status" and "description"; those are turned into a sentence by
        `generate_vector_string`.
        """
        header = "Please use the information about related packages to influence your answer:\n"
        package_keys = ("name", "type", "status", "description")
        package_lines = [
            generate_vector_string({key: obj.properties[key] for key in package_keys}) + "\n"
            for obj in objects
        ]
        return header + "".join(package_lines)

    async def process(
        self, request: ChatCompletionRequest, context: PipelineContext
    ) -> PipelineResult:
        """
        Process the completion request and add a system message with related package
        information if the last user message contains the word "codegate".

        The incoming request is never modified: a copy with its own messages list is
        returned when context is injected, otherwise the original request is returned.
        """
        # no prompt configured
        if not self._system_message["content"]:
            return PipelineResult(request=request)

        # get_last_user_message is not defined in this class (presumably inherited
        # from PipelineStep); it yields None or a (text, index) pair.
        last_user_message = self.get_last_user_message(request)
        if last_user_message is None:
            return PipelineResult(request=request)

        last_user_message_str, last_user_idx = last_user_message
        if "codegate" not in last_user_message_str.lower():
            return PipelineResult(request=request)

        # strip codegate from prompt and trim it
        search_query = last_user_message_str.lower().replace("codegate", "").strip()
        searched_objects = await self.get_objects_from_search(search_query)

        # Nothing relevant found: do not inject a header-only context message.
        if not searched_objects:
            return PipelineResult(request=request)

        # The messages list holds role/content message objects, so the context has to
        # be wrapped in one rather than inserted as a bare string.
        context_message = ChatCompletionSystemMessage(
            content=self.generate_context_str(searched_objects),
            role="system",
        )

        # request.copy() is shallow; give the copy its own list so that inserting the
        # context does not alter the caller's original messages.
        new_request = request.copy()
        new_request["messages"] = list(request["messages"])
        new_request["messages"].insert(last_user_idx, context_message)
        return PipelineResult(request=new_request)

src/codegate/server.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from codegate.config import Config
77
from codegate.pipeline.base import PipelineStep, SequentialPipelineProcessor
88
from codegate.pipeline.codegate_system_prompt.codegate import CodegateSystemPrompt
9+
from codegate.pipeline.codegate_context_retriever.codegate import CodegateContextRetriever
910
from codegate.pipeline.version.version import CodegateVersion
1011
from codegate.providers.anthropic.provider import AnthropicProvider
1112
from codegate.providers.llamacpp.provider import LlamaCppProvider
@@ -24,6 +25,7 @@ def init_app() -> FastAPI:
2425
steps: List[PipelineStep] = [
2526
CodegateVersion(),
2627
CodegateSystemPrompt(Config.get_config().prompts.codegate_chat),
28+
CodegateContextRetriever(Config.get_config().prompts.codegate_chat),
2729
# CodegateSecrets(),
2830
]
2931
# Leaving the pipeline empty for now

src/codegate/storage/storage_engine.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from codegate.codegate_logging import setup_logging
22
from codegate.inference.inference_engine import LlamaCppInferenceEngine
3-
from weaviate.classes.config import DataType, Property
3+
from weaviate.classes.config import DataType
44
from weaviate.classes.query import MetadataQuery
55
import weaviate
66

@@ -45,7 +45,7 @@ def setup_schema(self):
4545
self.client.collections.create(class_config['name'], properties=class_config['properties'])
4646
self.__logger.info(f"Weaviate schema for class {class_config['name']} setup complete.")
4747

48-
async def search(self, query, limit=5, distance=0.1):
48+
async def search(self, query: str, limit=5, distance=0.3) -> list[object]:
4949
"""
5050
Search the 'Package' collection based on a query string.
5151
@@ -62,7 +62,7 @@ async def search(self, query, limit=5, distance=0.1):
6262
# Perform the vector search
6363
try:
6464
collection = self.client.collections.get("Package")
65-
response = collection.query.near_vector(query_vector, limit=limit, distance=distance, return_metadata=MetadataQuery(distance=True))
65+
response = collection.query.near_vector(query_vector[0], limit=limit, distance=distance, return_metadata=MetadataQuery(distance=True))
6666
if not response:
6767
return []
6868
return response.objects

src/codegate/utils/utils.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
def generate_vector_string(package) -> str:
    """
    Build the one-line textual description of a package.

    Args:
        package: Mapping with the required keys "name" and "type", and the optional
            keys "status" and "description".

    Returns:
        A sentence of the form "<name> is a <type description> ", followed by a
        warning and a trustypkg.dev link when the status is "archived", "deprecated"
        or "malicious", followed by the description when one is present. A missing,
        None or empty description is left out entirely.

    Raises:
        KeyError: If "name" or "type" is missing from the package.
    """
    type_map = {
        "pypi": "Python package available on PyPI",
        "npm": "JavaScript package available on NPM",
        "go": "Go package",
        "crates": "Rust package available on Crates",
        "java": "Java package",
    }
    status_messages = {
        "archived": "However, this package is found to be archived and no longer maintained.",
        "deprecated": "However, this package is found to be deprecated and no longer "
        "recommended for use.",
        "malicious": "However, this package is found to be malicious.",
    }

    name = package["name"]
    package_type = package["type"]
    # The trailing space is kept so the output stays identical to earlier versions.
    vector_str = f"{name} is a {type_map.get(package_type, 'unknown type')} "

    # Add extra status
    status_suffix = status_messages.get(package.get("status"), "")
    if status_suffix:
        package_url = f"https://trustypkg.dev/{package_type}/{name}"
        vector_str += f"{status_suffix} For additional information refer to {package_url}"

    # add description; skipped when absent so the text never contains "None"
    description = package.get("description")
    if description:
        vector_str += f" - Package offers this functionality: {description}"
    return vector_str

0 commit comments

Comments
 (0)