This repository was archived by the owner on Jun 5, 2025. It is now read-only.

Normalize key messages for all providers #170

Merged: 2 commits, Dec 3, 2024
4 changes: 3 additions & 1 deletion src/codegate/providers/litellmshim/adapter.py
@@ -50,7 +50,9 @@ def normalize(self, data: Dict) -> ChatCompletionRequest:
         Uses a LiteLLM adapter to translate the request data from the native
         LLM format to the OpenAI API format used by LiteLLM internally.
         """
-        return self._adapter.translate_completion_input_params(data)
+        # Make a copy of the data to avoid modifying the original and normalize the message content
+        normalized_data = self._normalize_content_messages(data)
+        return self._adapter.translate_completion_input_params(normalized_data)

     def denormalize(self, data: ChatCompletionRequest) -> Dict:
         """
13 changes: 9 additions & 4 deletions src/codegate/providers/llamacpp/normalizer.py
@@ -17,13 +17,18 @@ def normalize(self, data: Dict) -> ChatCompletionRequest:
         """
         Normalize the input data
         """
+        # Make a copy of the data to avoid modifying the original and normalize the message content
+        normalized_data = self._normalize_content_messages(data)
+
         # When doing FIM, we receive "prompt" instead of messages. Normalizing.
-        if "prompt" in data:
-            data["messages"] = [{"content": data.pop("prompt"), "role": "user"}]
+        if "prompt" in normalized_data:
+            normalized_data["messages"] = [
+                {"content": normalized_data.pop("prompt"), "role": "user"}
+            ]
             # We can add extra keys to the request; ChatCompletionRequest is not strict.
-            data["had_prompt_before"] = True
+            normalized_data["had_prompt_before"] = True
         try:
-            return ChatCompletionRequest(**data)
+            return ChatCompletionRequest(**normalized_data)
         except Exception as e:
             raise ValueError(f"Invalid completion parameters: {str(e)}")
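
Note: to make the FIM path concrete, here is a minimal standalone sketch mirroring the diff above (not repo code). The model name is illustrative, and had_prompt_before is the marker the normalizer attaches, presumably so later stages can restore the original "prompt" shape:

from typing import Dict


def normalize_fim(data: Dict) -> Dict:
    """Standalone sketch of the llama.cpp FIM handling shown above."""
    normalized = data.copy()
    if "prompt" in normalized:
        # Rewrite the bare completion prompt as a single user message
        normalized["messages"] = [
            {"content": normalized.pop("prompt"), "role": "user"}
        ]
        # Extra key is tolerated because ChatCompletionRequest is not strict
        normalized["had_prompt_before"] = True
    return normalized


print(normalize_fim({"prompt": "def add(a, b):", "model": "qwen2.5-coder"}))
# {'model': 'qwen2.5-coder',
#  'messages': [{'content': 'def add(a, b):', 'role': 'user'}],
#  'had_prompt_before': True}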
26 changes: 26 additions & 0 deletions src/codegate/providers/normalizer/base.py
@@ -11,6 +11,32 @@ class ModelInputNormalizer(ABC):
     to the format expected by the pipeline.
     """

+    def _normalize_content_messages(self, data: Dict) -> Dict:
+        """
+        If the request contains the "messages" key, ensure each message's content is a string.
+        """
+        # Copy the original data in all cases to avoid modifying it
+        if "messages" not in data:
+            return data.copy()
+
+        normalized_data = data.copy()
+        messages = normalized_data["messages"]
+        converted_messages = []
+        for msg in messages:
+            role = msg.get("role", "")
+            content = msg.get("content", "")
+            new_msg = {"role": role, "content": content}
+            if isinstance(content, list):
+                # Convert list format to string
+                content_parts = []
+                for part in msg["content"]:
+                    if isinstance(part, dict) and part.get("type") == "text":
+                        content_parts.append(part["text"])
+                new_msg["content"] = " ".join(content_parts)
+            converted_messages.append(new_msg)
+        normalized_data["messages"] = converted_messages
+        return normalized_data
+
     @abstractmethod
     def normalize(self, data: Dict) -> ChatCompletionRequest:
         """Normalize the input data"""
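
Note: a minimal, runnable mirror of the new shared helper (a sketch, not the repo's abstract class) to show the flattening behavior on list-form content:

from typing import Dict


def normalize_content_messages(data: Dict) -> Dict:
    """Standalone mirror of ModelInputNormalizer._normalize_content_messages."""
    if "messages" not in data:
        return data.copy()
    normalized = data.copy()
    converted = []
    for msg in normalized["messages"]:
        content = msg.get("content", "")
        new_msg = {"role": msg.get("role", ""), "content": content}
        if isinstance(content, list):
            # Keep only the text parts and join them into one string
            parts = [
                p["text"]
                for p in content
                if isinstance(p, dict) and p.get("type") == "text"
            ]
            new_msg["content"] = " ".join(parts)
        converted.append(new_msg)
    normalized["messages"] = converted
    return normalized


request = {
    "messages": [
        {"role": "user", "content": [{"type": "text", "text": "Review this code"}]}
    ]
}
print(normalize_content_messages(request))
# {'messages': [{'role': 'user', 'content': 'Review this code'}]}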
22 changes: 2 additions & 20 deletions src/codegate/providers/ollama/adapter.py
@@ -13,8 +13,8 @@ def normalize(self, data: Dict) -> ChatCompletionRequest:
         """
         Normalize the input data to the format expected by Ollama.
         """
-        # Make a copy of the data to avoid modifying the original
-        normalized_data = data.copy()
+        # Make a copy of the data to avoid modifying the original and normalize the message content
+        normalized_data = self._normalize_content_messages(data)
         normalized_data["options"] = data.get("options", {})

         # Add any context or system prompt if provided
@@ -27,24 +27,6 @@ def normalize(self, data: Dict) -> ChatCompletionRequest:
         if "model" in normalized_data:
             normalized_data["model"] = data["model"].strip()

-        # Convert messages format if needed
-        if "messages" in data:
-            messages = data["messages"]
-            converted_messages = []
-            for msg in messages:
-                role = msg.get("role", "")
-                content = msg.get("content", "")
-                new_msg = {"role": role, "content": content}
-                if isinstance(content, list):
-                    # Convert list format to string
-                    content_parts = []
-                    for part in msg["content"]:
-                        if part.get("type") == "text":
-                            content_parts.append(part["text"])
-                    new_msg["content"] = " ".join(content_parts)
-                converted_messages.append(new_msg)
-            normalized_data["messages"] = converted_messages
-
         # Ensure the base_url ends with /api if provided
         if "base_url" in normalized_data:
             base_url = normalized_data["base_url"].rstrip("/")
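
Note: the per-provider message conversion deleted here is now covered by the shared helper in normalizer/base.py. The surrounding context also shows the base_url handling; a tiny illustration, with the URL hypothetical and the endswith guard assumed (the actual check is collapsed in this diff):

base_url = "http://localhost:11434/".rstrip("/")
if not base_url.endswith("/api"):  # assumed guard; collapsed in the diff
    base_url = f"{base_url}/api"
print(base_url)  # http://localhost:11434/api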
3 changes: 2 additions & 1 deletion src/codegate/providers/openai/adapter.py
@@ -13,7 +13,8 @@ def normalize(self, data: Dict) -> ChatCompletionRequest:
         """
         Normalize the message content; the request is otherwise already in OpenAI format
         """
-        return ChatCompletionRequest(**data)
+        normalized_data = self._normalize_content_messages(data)
+        return ChatCompletionRequest(**normalized_data)

     def denormalize(self, data: ChatCompletionRequest) -> Dict:
         """
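
Note: even requests that are "already OpenAI format" can carry list-form content parts, which now get flattened before validation. Reusing the normalize_content_messages sketch from the base-normalizer note above (model name illustrative):

request = {
    "model": "gpt-4o-mini",  # illustrative model name
    "messages": [{"role": "user", "content": [{"type": "text", "text": "hello"}]}],
}
print(normalize_content_messages(request)["messages"][0]["content"])  # hello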
6 changes: 4 additions & 2 deletions src/codegate/providers/vllm/adapter.py
@@ -108,8 +108,8 @@ def normalize(self, data: Dict) -> ChatCompletionRequest:
         Normalize the input data to the format expected by LiteLLM.
         Ensures the model name has the hosted_vllm prefix and base_url has /v1.
         """
-        # Make a copy of the data to avoid modifying the original
-        normalized_data = data.copy()
+        # Make a copy of the data to avoid modifying the original and normalize the message content
+        normalized_data = self._normalize_content_messages(data)

         # Format the model name to include the provider
         if "model" in normalized_data:
@@ -126,6 +126,8 @@ def normalize(self, data: Dict) -> ChatCompletionRequest:
         ret_data = normalized_data
         if self._has_chat_ml_format(normalized_data):
             ret_data = self._chat_ml_normalizer.normalize(normalized_data)
+        else:
+            ret_data = ChatCompletionRequest(**normalized_data)
         return ret_data

     def denormalize(self, data: ChatCompletionRequest) -> Dict:
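
Note: the new else branch closes a gap where non-ChatML input fell through as a plain dict; both paths now return a validated request. A sketch of the resulting control flow, assuming the same litellm import the repo's type hints already rely on:

from litellm import ChatCompletionRequest


def finish_normalize(normalized_data: dict, is_chat_ml: bool, chat_ml_normalizer):
    """Sketch of the branch added above; arguments stand in for self.* members."""
    if is_chat_ml:
        # ChatML-formatted requests take the dedicated normalizer
        return chat_ml_normalizer.normalize(normalized_data)
    # Everything else is validated by constructing the request type directly
    return ChatCompletionRequest(**normalized_data)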
2 changes: 1 addition & 1 deletion tests/providers/anthropic/test_adapter.py
@@ -38,7 +38,7 @@ def test_normalize_anthropic_input(input_normalizer):
         "max_tokens": 1024,
         "messages": [
             {"content": "You are an expert code reviewer", "role": "system"},
-            {"content": [{"text": "Review this code", "type": "text"}], "role": "user"},
+            {"content": "Review this code", "role": "user"},
         ],
         "model": "claude-3-haiku-20240307",
         "stream": True,
6 changes: 4 additions & 2 deletions tests/test_storage.py
@@ -35,8 +35,10 @@ def mock_inference_engine():
 @pytest.mark.asyncio
 async def test_search(mock_weaviate_client, mock_inference_engine):
     # Patch the LlamaCppInferenceEngine.embed method (not the entire class)
-    with patch("codegate.inference.inference_engine.LlamaCppInferenceEngine.embed",
-               mock_inference_engine.embed):
+    with patch(
+        "codegate.inference.inference_engine.LlamaCppInferenceEngine.embed",
+        mock_inference_engine.embed,
+    ):

         # Mock the WeaviateClient as before
         with patch("weaviate.WeaviateClient", return_value=mock_weaviate_client):