
Commit 1660500

Merge pull request #170 from stacklok/normalize-anthropic
Normalize key `messages` for all providers
2 parents: 9f8ce05 + 6c3b13c

8 files changed, +51 −31 lines


src/codegate/providers/litellmshim/adapter.py
Lines changed: 3 additions & 1 deletion

@@ -50,7 +50,9 @@ def normalize(self, data: Dict) -> ChatCompletionRequest:
         Uses an LiteLLM adapter to translate the request data from the native
         LLM format to the OpenAI API format used by LiteLLM internally.
         """
-        return self._adapter.translate_completion_input_params(data)
+        # Make a copy of the data to avoid modifying the original and normalize the message content
+        normalized_data = self._normalize_content_messages(data)
+        return self._adapter.translate_completion_input_params(normalized_data)
 
     def denormalize(self, data: ChatCompletionRequest) -> Dict:
         """

src/codegate/providers/llamacpp/normalizer.py
Lines changed: 9 additions & 4 deletions

@@ -17,13 +17,18 @@ def normalize(self, data: Dict) -> ChatCompletionRequest:
         """
         Normalize the input data
         """
+        # Make a copy of the data to avoid modifying the original and normalize the message content
+        normalized_data = self._normalize_content_messages(data)
+
         # When doing FIM, we receive "prompt" instead of messages. Normalizing.
-        if "prompt" in data:
-            data["messages"] = [{"content": data.pop("prompt"), "role": "user"}]
+        if "prompt" in normalized_data:
+            normalized_data["messages"] = [
+                {"content": normalized_data.pop("prompt"), "role": "user"}
+            ]
             # We can add as many parameters as we like to data. ChatCompletionRequest is not strict.
-            data["had_prompt_before"] = True
+            normalized_data["had_prompt_before"] = True
         try:
-            return ChatCompletionRequest(**data)
+            return ChatCompletionRequest(**normalized_data)
         except Exception as e:
             raise ValueError(f"Invalid completion parameters: {str(e)}")
 
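A standalone sketch of the FIM handling above, re-implementing the hunk's logic as a free function for illustration (the real code is a method that first runs the shared content normalization):

from typing import Dict


def normalize_fim(data: Dict) -> Dict:
    normalized_data = data.copy()
    # When doing FIM, providers send "prompt" instead of "messages"
    if "prompt" in normalized_data:
        normalized_data["messages"] = [
            {"content": normalized_data.pop("prompt"), "role": "user"}
        ]
        normalized_data["had_prompt_before"] = True
    return normalized_data


print(normalize_fim({"prompt": "def fib(n):"}))
# {'messages': [{'content': 'def fib(n):', 'role': 'user'}],
#  'had_prompt_before': True}

The had_prompt_before flag presumably lets later pipeline stages know the request arrived in prompt form; as the code comment notes, ChatCompletionRequest is not strict, so extra keys are allowed.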

src/codegate/providers/normalizer/base.py
Lines changed: 26 additions & 0 deletions

@@ -11,6 +11,32 @@ class ModelInputNormalizer(ABC):
     to the format expected by the pipeline.
     """
 
+    def _normalize_content_messages(self, data: Dict) -> Dict:
+        """
+        If the request contains the "messages" key, make sure that it's content is a string.
+        """
+        # Anyways copy the original data to avoid modifying it
+        if "messages" not in data:
+            return data.copy()
+
+        normalized_data = data.copy()
+        messages = normalized_data["messages"]
+        converted_messages = []
+        for msg in messages:
+            role = msg.get("role", "")
+            content = msg.get("content", "")
+            new_msg = {"role": role, "content": content}
+            if isinstance(content, list):
+                # Convert list format to string
+                content_parts = []
+                for part in msg["content"]:
+                    if isinstance(part, dict) and part.get("type") == "text":
+                        content_parts.append(part["text"])
+                new_msg["content"] = " ".join(content_parts)
+            converted_messages.append(new_msg)
+        normalized_data["messages"] = converted_messages
+        return normalized_data
+
     @abstractmethod
     def normalize(self, data: Dict) -> ChatCompletionRequest:
         """Normalize the input data"""

src/codegate/providers/ollama/adapter.py
Lines changed: 2 additions & 20 deletions

@@ -13,8 +13,8 @@ def normalize(self, data: Dict) -> ChatCompletionRequest:
         """
         Normalize the input data to the format expected by Ollama.
         """
-        # Make a copy of the data to avoid modifying the original
-        normalized_data = data.copy()
+        # Make a copy of the data to avoid modifying the original and normalize the message content
+        normalized_data = self._normalize_content_messages(data)
         normalized_data["options"] = data.get("options", {})
 
         # Add any context or system prompt if provided
@@ -27,24 +27,6 @@ def normalize(self, data: Dict) -> ChatCompletionRequest:
         if "model" in normalized_data:
             normalized_data["model"] = data["model"].strip()
 
-        # Convert messages format if needed
-        if "messages" in data:
-            messages = data["messages"]
-            converted_messages = []
-            for msg in messages:
-                role = msg.get("role", "")
-                content = msg.get("content", "")
-                new_msg = {"role": role, "content": content}
-                if isinstance(content, list):
-                    # Convert list format to string
-                    content_parts = []
-                    for part in msg["content"]:
-                        if part.get("type") == "text":
-                            content_parts.append(part["text"])
-                    new_msg["content"] = " ".join(content_parts)
-                converted_messages.append(new_msg)
-            normalized_data["messages"] = converted_messages
-
         # Ensure the base_url ends with /api if provided
         if "base_url" in normalized_data:
             base_url = normalized_data["base_url"].rstrip("/")

src/codegate/providers/openai/adapter.py
Lines changed: 2 additions & 1 deletion

@@ -13,7 +13,8 @@ def normalize(self, data: Dict) -> ChatCompletionRequest:
         """
         No normalizing needed, already OpenAI format
         """
-        return ChatCompletionRequest(**data)
+        normalized_data = self._normalize_content_messages(data)
+        return ChatCompletionRequest(**normalized_data)
 
     def denormalize(self, data: ChatCompletionRequest) -> Dict:
         """

src/codegate/providers/vllm/adapter.py
Lines changed: 4 additions & 2 deletions

@@ -108,8 +108,8 @@ def normalize(self, data: Dict) -> ChatCompletionRequest:
         Normalize the input data to the format expected by LiteLLM.
         Ensures the model name has the hosted_vllm prefix and base_url has /v1.
         """
-        # Make a copy of the data to avoid modifying the original
-        normalized_data = data.copy()
+        # Make a copy of the data to avoid modifying the original and normalize the message content
+        normalized_data = self._normalize_content_messages(data)
 
         # Format the model name to include the provider
         if "model" in normalized_data:
@@ -126,6 +126,8 @@ def normalize(self, data: Dict) -> ChatCompletionRequest:
         ret_data = normalized_data
         if self._has_chat_ml_format(normalized_data):
             ret_data = self._chat_ml_normalizer.normalize(normalized_data)
+        else:
+            ret_data = ChatCompletionRequest(**normalized_data)
         return ret_data
 
     def denormalize(self, data: ChatCompletionRequest) -> Dict:
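
A self-contained sketch of the new return shape: with the added else branch, the non-ChatML path wraps the dict in a request object instead of returning it raw. The stub type and the marker-based format check are assumptions for illustration, not codegate's actual implementations:

from typing import Dict


class ChatCompletionRequestStub(dict):
    # Stand-in for litellm's ChatCompletionRequest mapping type
    pass


def has_chat_ml_format(data: Dict) -> bool:
    # Assumed heuristic for this sketch: ChatML prompts carry <|im_start|> markers
    return "<|im_start|>" in str(data.get("prompt", ""))


def finish_normalize(normalized_data: Dict):
    ret_data = normalized_data
    if has_chat_ml_format(normalized_data):
        pass  # real code: ret_data = self._chat_ml_normalizer.normalize(...)
    else:
        ret_data = ChatCompletionRequestStub(**normalized_data)  # the new branch
    return ret_data


print(type(finish_normalize({"messages": []})).__name__)
# ChatCompletionRequestStub

Before this change, a non-ChatML request fell through as a plain dict; now both branches return a normalized request object.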

tests/providers/anthropic/test_adapter.py
Lines changed: 1 addition & 1 deletion

@@ -38,7 +38,7 @@ def test_normalize_anthropic_input(input_normalizer):
         "max_tokens": 1024,
         "messages": [
             {"content": "You are an expert code reviewer", "role": "system"},
-            {"content": [{"text": "Review this code", "type": "text"}], "role": "user"},
+            {"content": "Review this code", "role": "user"},
         ],
         "model": "claude-3-haiku-20240307",
         "stream": True,

tests/test_storage.py
Lines changed: 4 additions & 2 deletions

@@ -35,8 +35,10 @@ def mock_inference_engine():
 @pytest.mark.asyncio
 async def test_search(mock_weaviate_client, mock_inference_engine):
     # Patch the LlamaCppInferenceEngine.embed method (not the entire class)
-    with patch("codegate.inference.inference_engine.LlamaCppInferenceEngine.embed",
-               mock_inference_engine.embed):
+    with patch(
+        "codegate.inference.inference_engine.LlamaCppInferenceEngine.embed",
+        mock_inference_engine.embed,
+    ):
 
         # Mock the WeaviateClient as before
         with patch("weaviate.WeaviateClient", return_value=mock_weaviate_client):
