Skip to content
This repository was archived by the owner on Jun 5, 2025. It is now read-only.

Inspect all user messages for malicious packages #318

Merged
merged 1 commit into from
Dec 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 9 additions & 3 deletions prompts/default.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,20 @@ default_chat: |
if no malicious, deprecated or archived package is detected.

If no malicious, deprecated or archived packages are detected, respond to the user request using your prior
knowledge. You always begin your responses with:
"**CodeGate Security Analysis**"
Structure your responses to emphasize security considerations, focusing on:
knowledge.

If the user requests security review, perform the review and structure your responses to emphasize
security considerations, focusing on:
1. Security vulnerabilities and potential risks
2. Best practices for secure implementation
3. Recommendations for security improvements
4. References to relevant security standards or guidelines

You always begin your responses for security review with:
"**CodeGate Security Analysis**"

If you see a string that begins with REDACTED word, DO NOT MODIFY THAT STRING while responding back.

lookup_packages: |
You are a software expert with knowledge of packages from various ecosystems.
Your job is to extract any software packages from user's request.
Expand Down
37 changes: 19 additions & 18 deletions src/codegate/pipeline/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,29 +178,30 @@ def get_last_user_message(
for i in reversed(range(len(request["messages"]))):
if request["messages"][i]["role"] == "user":
content = request["messages"][i]["content"]
return content, i

# This is really another LiteLLM weirdness. Depending on the
# provider inside the ChatCompletionRequest you might either
# have a string or a list of Union, one of which is a
# ChatCompletionTextObject. We'll handle this better by
# either dumping litellm completely or converting to a more sane
# format # in our own adapter
return None

# Handle string content
if isinstance(content, str):
return content, i
@staticmethod
def get_last_user_message_idx(request: ChatCompletionRequest) -> int:
if request.get("messages") is None:
return -1

# Handle iterable of ChatCompletionTextObject
if isinstance(content, (list, tuple)):
# Find first text content
for item in content:
if isinstance(item, dict) and item.get("type") == "text":
return item["text"], i
for idx, message in reversed(list(enumerate(request['messages']))):
if message.get("role", "") == "user":
return idx

# If no text content found, return None
return None
return -1

return None
@staticmethod
def get_all_user_messages(request: ChatCompletionRequest) -> str:
    """
    Concatenate the text of every user message in the request.

    Every piece of text is prefixed with a newline character, so a
    non-empty result always starts with a newline and an empty string
    means that no user text was found.

    Message content can be either a plain string or a list/tuple of
    content parts; in the latter case only dict parts whose "type" is
    "text" contribute to the result.
    """
    texts = []

    # `or []` also covers a request whose "messages" key is present but None
    for message in request.get("messages") or []:
        if message.get("role", "") != "user":
            continue

        content = message.get("content")
        if isinstance(content, str):
            texts.append(content)
        elif isinstance(content, (list, tuple)):
            # Multi-part content: keep the text parts, skip images and
            # anything else that cannot be inspected as plain text
            texts.extend(
                item["text"]
                for item in content
                if isinstance(item, dict)
                and item.get("type") == "text"
                and isinstance(item.get("text"), str)
            )

    return "".join("\n" + text for text in texts)

@abstractmethod
async def process(
Expand Down
41 changes: 17 additions & 24 deletions src/codegate/pipeline/codegate_context_retriever/codegate.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,17 +93,16 @@ async def process(
Use RAG DB to add context to the user request
"""

# Get the last user message
last_user_message = self.get_last_user_message(request)
# Get all user messages
user_messages = self.get_all_user_messages(request)

# Nothing to do if the last user message is none
if last_user_message is None:
# Nothing to do if the user_messages string is empty
if len(user_messages) == 0:
return PipelineResult(request=request)

# Extract packages from the user message
last_user_message_str, last_user_idx = last_user_message
ecosystem = await self.__lookup_ecosystem(last_user_message_str, context)
packages = await self.__lookup_packages(last_user_message_str, context)
ecosystem = await self.__lookup_ecosystem(user_messages, context)
packages = await self.__lookup_packages(user_messages, context)
packages = [pkg.lower() for pkg in packages]

# If user message does not reference any packages, then just return
Expand All @@ -112,7 +111,7 @@ async def process(

# Look for matches in vector DB using list of packages as filter
searched_objects = await self.get_objects_from_search(
last_user_message_str, ecosystem, packages
user_messages, ecosystem, packages
)

logger.info(
Expand All @@ -136,24 +135,18 @@ async def process(
else:
context_str = "Codegate did not find any malicious or archived packages."

last_user_idx = self.get_last_user_message_idx(request)
if last_user_idx == -1:
return PipelineResult(request=request, context=context)

# Make a copy of the request
new_request = request.copy()

# Add the context to the last user message
# Format: "Context: {context_str} \n Query: {last user message conent}"
# Handle the two cases: (a) message content is str, (b)message content
# is list
# Format: "Context: {context_str} \n Query: {last user message content}"
message = new_request["messages"][last_user_idx]
if isinstance(message["content"], str):
context_msg = f'Context: {context_str} \n\n Query: {message["content"]}'
message["content"] = context_msg
elif isinstance(message["content"], (list, tuple)):
for item in message["content"]:
if isinstance(item, dict) and item.get("type") == "text":
context_msg = f'Context: {context_str} \n\n Query: {item["text"]}'
item["text"] = context_msg

return PipelineResult(request=new_request, context=context)

# Fall through
return PipelineResult(request=request, context=context)
context_msg = f'Context: {context_str} \n\n Query: {message["content"]}'
message["content"] = context_msg

return PipelineResult(request=new_request, context=context)

2 changes: 1 addition & 1 deletion src/codegate/providers/copilot/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ class CopilotPipeline(ABC):
def __init__(self, pipeline_factory):
self.pipeline_factory = pipeline_factory
self.normalizer = self._create_normalizer()
self.provider_name = "copilot"
self.provider_name = "openai"

@abstractmethod
def _create_normalizer(self):
Expand Down
Loading