Tune context and prompts #200
Changes from all commits
The first diff touches the `CodegateContextRetriever` pipeline step:
```diff
@@ -1,7 +1,10 @@
-from typing import Optional
+import json
 
-from litellm import ChatCompletionRequest, ChatCompletionSystemMessage
+import structlog
+from litellm import ChatCompletionRequest
 
+from codegate.config import Config
+from codegate.inference.inference_engine import LlamaCppInferenceEngine
 from codegate.pipeline.base import (
     PipelineContext,
     PipelineResult,
```
```diff
@@ -10,18 +13,18 @@
 from src.codegate.storage.storage_engine import StorageEngine
 from src.codegate.utils.utils import generate_vector_string
 
+logger = structlog.get_logger("codegate")
+
 
 class CodegateContextRetriever(PipelineStep):
     """
     Pipeline step that adds a context message to the completion request when it detects
     the word "codegate" in the user message.
     """
 
-    def __init__(self, system_prompt_message: Optional[str] = None):
-        self._system_message = ChatCompletionSystemMessage(
-            content=system_prompt_message, role="system"
-        )
+    def __init__(self):
+        self.storage_engine = StorageEngine()
+        self.inference_engine = LlamaCppInferenceEngine()
 
     @property
     def name(self) -> str:
```
```diff
@@ -30,8 +33,10 @@ def name(self) -> str:
         """
        return "codegate-context-retriever"
 
-    async def get_objects_from_search(self, search: str) -> list[object]:
-        objects = await self.storage_engine.search(search, distance=0.5)
+    async def get_objects_from_search(
+        self, search: str, packages: list[str] = None
+    ) -> list[object]:
+        objects = await self.storage_engine.search(search, distance=0.8, packages=packages)
         return objects
 
     def generate_context_str(self, objects: list[object]) -> str:
```
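The `StorageEngine.search` implementation is not part of this diff. Purely as an illustration of what the looser `distance=0.8` threshold and the new `packages` filter might translate to, here is a hypothetical sketch against the weaviate-client v4 API (the collection name, property name, and use of a wildcard `like` filter are assumptions, not the merged code). A wildcard `like` filter matches substrings, which would explain why `process` below re-checks exact package names:

```python
# Hypothetical sketch of a Weaviate v4 search with a distance cap and a
# package-name filter; StorageEngine's real internals are not shown in this PR.
import weaviate
from weaviate.classes.query import Filter, MetadataQuery


def search_packages(client: weaviate.WeaviateClient, query: str,
                    distance: float = 0.8, packages: list[str] = None):
    collection = client.collections.get("Package")  # assumed collection name
    filters = None
    if packages:
        # Wildcard `like` matches substrings ("invokehttp" also hits
        # "invokehttp2"), hence the exact-name post-filter in the caller.
        filters = Filter.any_of(
            [Filter.by_property("name").like(f"*{pkg}*") for pkg in packages]
        )
    response = collection.query.near_text(
        query=query,
        distance=distance,  # 0.8 admits weaker matches than the previous 0.5
        filters=filters,
        return_metadata=MetadataQuery(distance=True),
    )
    return response.objects
```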
```diff
@@ -48,49 +53,87 @@ def generate_context_str(self, objects: list[object]) -> str:
             context_str += package_str + "\n"
         return context_str
 
+    async def __lookup_packages(self, user_query: str):
+        ## Check which packages are referenced in the user query
+        request = {
+            "messages": [
+                {"role": "system", "content": Config.get_config().prompts.lookup_packages},
+                {"role": "user", "content": user_query},
+            ],
+            "model": "qwen2-1_5b-instruct-q5_k_m",
+            "stream": False,
+            "response_format": {"type": "json_object"},
+            "temperature": 0,
+        }
```
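With `response_format` set to `json_object` and temperature 0, the local model is steered to reply with a JSON object, and the code below reads only its `packages` key. An illustrative, made-up reply:

```python
import json

# Made-up reply from the lookup model; only the "packages" key is consumed.
reply = '{"packages": ["invokehttp"]}'
packages = json.loads(reply)["packages"]
print(packages)  # ['invokehttp']
```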
```diff
+        result = await self.inference_engine.chat(
+            f"{Config.get_config().model_base_path}/{request['model']}.gguf",
+            n_ctx=Config.get_config().chat_model_n_ctx,
+            n_gpu_layers=Config.get_config().chat_model_n_gpu_layers,
+            **request,
+        )
+
+        result = json.loads(result["choices"][0]["message"]["content"])
+        logger.info(f"Packages in user query: {result['packages']}")
+        return result["packages"]
```

Two review comments were left on the `inference_engine.chat` call:

> What do you think about wrapping this in a try/except, returning an empty list, and logging an error? In general our exception handling is not great (unrelated to this PR, of course), and I was wondering if it would make sense to mark pipeline steps as critical or nice-to-have and handle exceptions in the pipeline processor rather than having to handle them in the steps themselves. That would be outside the scope of this patch; for this one I just wonder about wrapping the chat in try/except and returning [] in case of an exception.

> Oh, and since we were talking on Slack about the performance of the local vs. remote model, just noting here that the local LLM takes anywhere between 1.5 and 4 seconds on my laptop. I will also measure the hosted LLMs for the same task.
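A minimal sketch of the guard suggested above, reusing the engine and config objects from the diff; this is the reviewer's proposed variant, not the merged code:

```python
# Sketch only: make package lookup best-effort, so a failed local-LLM call
# degrades to "no packages referenced" instead of breaking the pipeline step.
async def __lookup_packages(self, user_query: str) -> list[str]:
    request = ...  # built exactly as in the diff above
    try:
        result = await self.inference_engine.chat(
            f"{Config.get_config().model_base_path}/{request['model']}.gguf",
            n_ctx=Config.get_config().chat_model_n_ctx,
            n_gpu_layers=Config.get_config().chat_model_n_gpu_layers,
            **request,
        )
        parsed = json.loads(result["choices"][0]["message"]["content"])
        return parsed["packages"]
    except Exception:
        logger.exception("Package lookup failed; skipping context retrieval")
        return []
```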
```diff
     async def process(
         self, request: ChatCompletionRequest, context: PipelineContext
     ) -> PipelineResult:
         """
-        Process the completion request and add a system prompt if the user message contains
-        the word "codegate".
+        Use RAG DB to add context to the user request
         """
-        # no prompt configured
-        if not self._system_message["content"]:
-            return PipelineResult(request=request)
-
         # Get the last user message
         last_user_message = self.get_last_user_message(request)
 
-        if last_user_message is not None:
-            last_user_message_str, last_user_idx = last_user_message
-            if last_user_message_str.lower():
-                # Look for matches in vector DB
-                searched_objects = await self.get_objects_from_search(last_user_message_str)
-
-                # If matches are found, add the matched content to context
-                if len(searched_objects) > 0:
-                    context_str = self.generate_context_str(searched_objects)
-
-                    # Make a copy of the request
-                    new_request = request.copy()
-
-                    # Add the context to the last user message
-                    # Format: "Context: {context_str} \n Query: {last user message content}"
-                    # Handle the two cases: (a) message content is str, (b) message content
-                    # is list
-                    message = new_request["messages"][last_user_idx]
-                    if isinstance(message["content"], str):
-                        message["content"] = (
-                            f'Context: {context_str} \n\n Query: {message["content"]}'
-                        )
-                    elif isinstance(message["content"], (list, tuple)):
-                        for item in message["content"]:
-                            if isinstance(item, dict) and item.get("type") == "text":
-                                item["text"] = f'Context: {context_str} \n\n Query: {item["text"]}'
-
-                    return PipelineResult(
-                        request=new_request,
-                    )
+        # Nothing to do if the last user message is none
+        if last_user_message is None:
+            return PipelineResult(request=request)
+
+        # Extract packages from the user message
+        last_user_message_str, last_user_idx = last_user_message
+        packages = await self.__lookup_packages(last_user_message_str)
+
+        # If user message does not reference any packages, then just return
+        if len(packages) == 0:
+            return PipelineResult(request=request)
+
+        # Look for matches in vector DB using list of packages as filter
+        searched_objects = await self.get_objects_from_search(last_user_message_str, packages)
+
+        # If matches are found, add the matched content to context
+        if len(searched_objects) > 0:
+            # Remove searched objects that are not in packages. This is needed
+            # since Weaviate performs substring match in the filter.
+            updated_searched_objects = []
+            for searched_object in searched_objects:
+                if searched_object.properties["name"] in packages:
+                    updated_searched_objects.append(searched_object)
+            searched_objects = updated_searched_objects
+
+            # Generate context string using the searched objects
+            logger.info(f"Adding {len(searched_objects)} packages to the context")
+            context_str = self.generate_context_str(searched_objects)
+
+            # Make a copy of the request
+            new_request = request.copy()
+
+            # Add the context to the last user message
+            # Format: "Context: {context_str} \n Query: {last user message content}"
+            # Handle the two cases: (a) message content is str, (b) message content
+            # is list
+            message = new_request["messages"][last_user_idx]
+            if isinstance(message["content"], str):
+                message["content"] = f'Context: {context_str} \n\n Query: {message["content"]}'
+            elif isinstance(message["content"], (list, tuple)):
+                for item in message["content"]:
+                    if isinstance(item, dict) and item.get("type") == "text":
+                        item["text"] = f'Context: {context_str} \n\n Query: {item["text"]}'
+
+            return PipelineResult(
+                request=new_request,
+            )
+
+        # Fall through
+        return PipelineResult(request=request)
```
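A self-contained toy run of the two subtle parts of `process`: dropping the substring matches that Weaviate's filter lets through, and the `Context:`/`Query:` injection format. The package names and context line are invented for illustration:

```python
from dataclasses import dataclass, field


@dataclass
class SearchHit:
    """Stand-in for a Weaviate result object with a .properties dict."""
    properties: dict = field(default_factory=dict)


packages = ["invokehttp"]
hits = [SearchHit({"name": "invokehttp"}), SearchHit({"name": "invokehttp2"})]

# Weaviate's substring filter also returned "invokehttp2"; keep exact names only.
hits = [h for h in hits if h.properties["name"] in packages]
assert [h.properties["name"] for h in hits] == ["invokehttp"]

# The surviving context is prepended to the user's message.
context_str = "Package invokehttp is malicious\n"
query = "Should I use the invokehttp package?"
print(f"Context: {context_str} \n\n Query: {query}")
```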
This file was deleted.
This file was deleted.
The second diff adds what appears to be a package `__init__.py`:

```diff
@@ -0,0 +1,3 @@
+from codegate.pipeline.system_prompt.codegate import SystemPrompt
+
+__all__ = ["SystemPrompt"]
```
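Given the import path, this is presumably `codegate/pipeline/system_prompt/__init__.py`; the re-export lets callers import `SystemPrompt` from the package rather than the module:

```python
# Both imports resolve; the second is what the new __init__.py enables.
from codegate.pipeline.system_prompt.codegate import SystemPrompt  # module path
from codegate.pipeline.system_prompt import SystemPrompt  # package re-export
```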