diff --git a/prompts/default.yaml b/prompts/default.yaml index 0ec7ccd0..92f029ee 100644 --- a/prompts/default.yaml +++ b/prompts/default.yaml @@ -19,14 +19,20 @@ default_chat: | if no malicious, deprecated or archived package is detected. If no malicious, deprecated or archived packages are detected, respond to the user request using your prior - knowledge. You always begin your responses with: - "**CodeGate Security Analysis**" - Structure your responses to emphasize security considerations, focusing on: + knowledge. + + If the user requests security review, perform the review and structure your responses to emphasize + security considerations, focusing on: 1. Security vulnerabilities and potential risks 2. Best practices for secure implementation 3. Recommendations for security improvements 4. References to relevant security standards or guidelines + You always begin your responses for security review with: + "**CodeGate Security Analysis**" + + If you see a string that begins with REDACTED word, DO NOT MODIFY THAT STRING while responding back. + lookup_packages: | You are a software expert with knowledge of packages from various ecosystems. Your job is to extract any software packages from user's request. diff --git a/src/codegate/pipeline/base.py b/src/codegate/pipeline/base.py index d39da34f..b96ed07b 100644 --- a/src/codegate/pipeline/base.py +++ b/src/codegate/pipeline/base.py @@ -178,29 +178,30 @@ def get_last_user_message( for i in reversed(range(len(request["messages"]))): if request["messages"][i]["role"] == "user": content = request["messages"][i]["content"] + return content, i - # This is really another LiteLLM weirdness. Depending on the - # provider inside the ChatCompletionRequest you might either - # have a string or a list of Union, one of which is a - # ChatCompletionTextObject. We'll handle this better by - # either dumping litellm completely or converting to a more sane - # format # in our own adapter + return None - # Handle string content - if isinstance(content, str): - return content, i + @staticmethod + def get_last_user_message_idx(request: ChatCompletionRequest) -> int: + if request.get("messages") is None: + return -1 - # Handle iterable of ChatCompletionTextObject - if isinstance(content, (list, tuple)): - # Find first text content - for item in content: - if isinstance(item, dict) and item.get("type") == "text": - return item["text"], i + for idx, message in reversed(list(enumerate(request['messages']))): + if message.get("role", "") == "user": + return idx - # If no text content found, return None - return None + return -1 - return None + @staticmethod + def get_all_user_messages(request: ChatCompletionRequest) -> str: + all_user_messages = "" + + for message in request.get("messages", []): + if message["role"] == "user": + all_user_messages += "\n" + message["content"] + + return all_user_messages @abstractmethod async def process( diff --git a/src/codegate/pipeline/codegate_context_retriever/codegate.py b/src/codegate/pipeline/codegate_context_retriever/codegate.py index 6da1da93..fb333e42 100644 --- a/src/codegate/pipeline/codegate_context_retriever/codegate.py +++ b/src/codegate/pipeline/codegate_context_retriever/codegate.py @@ -93,17 +93,16 @@ async def process( Use RAG DB to add context to the user request """ - # Get the last user message - last_user_message = self.get_last_user_message(request) + # Get all user messages + user_messages = self.get_all_user_messages(request) - # Nothing to do if the last user message is none - if last_user_message is None: + # Nothing to do if the user_messages string is empty + if len(user_messages) == 0: return PipelineResult(request=request) # Extract packages from the user message - last_user_message_str, last_user_idx = last_user_message - ecosystem = await self.__lookup_ecosystem(last_user_message_str, context) - packages = await self.__lookup_packages(last_user_message_str, context) + ecosystem = await self.__lookup_ecosystem(user_messages, context) + packages = await self.__lookup_packages(user_messages, context) packages = [pkg.lower() for pkg in packages] # If user message does not reference any packages, then just return @@ -112,7 +111,7 @@ async def process( # Look for matches in vector DB using list of packages as filter searched_objects = await self.get_objects_from_search( - last_user_message_str, ecosystem, packages + user_messages, ecosystem, packages ) logger.info( @@ -136,24 +135,18 @@ async def process( else: context_str = "Codegate did not find any malicious or archived packages." + last_user_idx = self.get_last_user_message_idx(request) + if last_user_idx == -1: + return PipelineResult(request=request, context=context) + # Make a copy of the request new_request = request.copy() # Add the context to the last user message - # Format: "Context: {context_str} \n Query: {last user message conent}" - # Handle the two cases: (a) message content is str, (b)message content - # is list + # Format: "Context: {context_str} \n Query: {last user message content}" message = new_request["messages"][last_user_idx] - if isinstance(message["content"], str): - context_msg = f'Context: {context_str} \n\n Query: {message["content"]}' - message["content"] = context_msg - elif isinstance(message["content"], (list, tuple)): - for item in message["content"]: - if isinstance(item, dict) and item.get("type") == "text": - context_msg = f'Context: {context_str} \n\n Query: {item["text"]}' - item["text"] = context_msg - - return PipelineResult(request=new_request, context=context) - - # Fall through - return PipelineResult(request=request, context=context) + context_msg = f'Context: {context_str} \n\n Query: {message["content"]}' + message["content"] = context_msg + + return PipelineResult(request=new_request, context=context) + diff --git a/src/codegate/providers/copilot/pipeline.py b/src/codegate/providers/copilot/pipeline.py index f41d7531..d20dee86 100644 --- a/src/codegate/providers/copilot/pipeline.py +++ b/src/codegate/providers/copilot/pipeline.py @@ -21,7 +21,7 @@ class CopilotPipeline(ABC): def __init__(self, pipeline_factory): self.pipeline_factory = pipeline_factory self.normalizer = self._create_normalizer() - self.provider_name = "copilot" + self.provider_name = "openai" @abstractmethod def _create_normalizer(self):