diff --git a/src/codegate/providers/litellmshim/adapter.py b/src/codegate/providers/litellmshim/adapter.py
index c0b1a6a9..fed6f1fc 100644
--- a/src/codegate/providers/litellmshim/adapter.py
+++ b/src/codegate/providers/litellmshim/adapter.py
@@ -50,7 +50,9 @@ def normalize(self, data: Dict) -> ChatCompletionRequest:
         Uses an LiteLLM adapter to translate the request data from the native
         LLM format to the OpenAI API format used by LiteLLM internally.
         """
-        return self._adapter.translate_completion_input_params(data)
+        # Make a copy of the data to avoid modifying the original and normalize the message content
+        normalized_data = self._normalize_content_messages(data)
+        return self._adapter.translate_completion_input_params(normalized_data)
 
     def denormalize(self, data: ChatCompletionRequest) -> Dict:
         """
diff --git a/src/codegate/providers/llamacpp/normalizer.py b/src/codegate/providers/llamacpp/normalizer.py
index 6ca08d58..4fbff365 100644
--- a/src/codegate/providers/llamacpp/normalizer.py
+++ b/src/codegate/providers/llamacpp/normalizer.py
@@ -17,13 +17,18 @@ def normalize(self, data: Dict) -> ChatCompletionRequest:
         """
         Normalize the input data
         """
+        # Make a copy of the data to avoid modifying the original and normalize the message content
+        normalized_data = self._normalize_content_messages(data)
+
         # When doing FIM, we receive "prompt" instead of messages. Normalizing.
-        if "prompt" in data:
-            data["messages"] = [{"content": data.pop("prompt"), "role": "user"}]
+        if "prompt" in normalized_data:
+            normalized_data["messages"] = [
+                {"content": normalized_data.pop("prompt"), "role": "user"}
+            ]
             # We can add as many parameters as we like to data. ChatCompletionRequest is not strict.
-            data["had_prompt_before"] = True
+            normalized_data["had_prompt_before"] = True
         try:
-            return ChatCompletionRequest(**data)
+            return ChatCompletionRequest(**normalized_data)
         except Exception as e:
             raise ValueError(f"Invalid completion parameters: {str(e)}")
 
diff --git a/src/codegate/providers/normalizer/base.py b/src/codegate/providers/normalizer/base.py
index 625842c9..e52bb859 100644
--- a/src/codegate/providers/normalizer/base.py
+++ b/src/codegate/providers/normalizer/base.py
@@ -11,6 +11,32 @@ class ModelInputNormalizer(ABC):
     to the format expected by the pipeline.
     """
 
+    def _normalize_content_messages(self, data: Dict) -> Dict:
+        """
+        If the request contains the "messages" key, make sure that its content is a string.
+        """
+        # Always copy the original data to avoid modifying it
+        if "messages" not in data:
+            return data.copy()
+
+        normalized_data = data.copy()
+        messages = normalized_data["messages"]
+        converted_messages = []
+        for msg in messages:
+            role = msg.get("role", "")
+            content = msg.get("content", "")
+            new_msg = {"role": role, "content": content}
+            if isinstance(content, list):
+                # Convert list format to string
+                content_parts = []
+                for part in msg["content"]:
+                    if isinstance(part, dict) and part.get("type") == "text":
+                        content_parts.append(part["text"])
+                new_msg["content"] = " ".join(content_parts)
+            converted_messages.append(new_msg)
+        normalized_data["messages"] = converted_messages
+        return normalized_data
+
     @abstractmethod
     def normalize(self, data: Dict) -> ChatCompletionRequest:
         """Normalize the input data"""
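The shared `_normalize_content_messages` helper added to `ModelInputNormalizer` above collapses OpenAI-style list-of-parts content into a plain string. A minimal standalone sketch of that behavior, mirroring the logic visible in the diff (the function name and sample payload here are illustrative, not part of the patch):

```python
from typing import Dict


def normalize_content_messages(data: Dict) -> Dict:
    """Standalone sketch of the helper added in normalizer/base.py."""
    if "messages" not in data:
        return data.copy()

    normalized_data = data.copy()
    converted_messages = []
    for msg in normalized_data["messages"]:
        new_msg = {"role": msg.get("role", ""), "content": msg.get("content", "")}
        if isinstance(new_msg["content"], list):
            # Keep only the "text" parts and join them into a single string
            parts = [
                p["text"]
                for p in new_msg["content"]
                if isinstance(p, dict) and p.get("type") == "text"
            ]
            new_msg["content"] = " ".join(parts)
        converted_messages.append(new_msg)
    normalized_data["messages"] = converted_messages
    return normalized_data


# List-form content is flattened; plain string content passes through unchanged
request = {
    "messages": [
        {"role": "system", "content": "You are an expert code reviewer"},
        {"role": "user", "content": [{"type": "text", "text": "Review this code"}]},
    ]
}
assert normalize_content_messages(request)["messages"][1]["content"] == "Review this code"
```

Note the `isinstance(part, dict)` guard: unlike the per-provider copy this replaces (see the Ollama hunk below), the shared helper tolerates non-dict content parts instead of raising `AttributeError`.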
+ """ + # Anyways copy the original data to avoid modifying it + if "messages" not in data: + return data.copy() + + normalized_data = data.copy() + messages = normalized_data["messages"] + converted_messages = [] + for msg in messages: + role = msg.get("role", "") + content = msg.get("content", "") + new_msg = {"role": role, "content": content} + if isinstance(content, list): + # Convert list format to string + content_parts = [] + for part in msg["content"]: + if isinstance(part, dict) and part.get("type") == "text": + content_parts.append(part["text"]) + new_msg["content"] = " ".join(content_parts) + converted_messages.append(new_msg) + normalized_data["messages"] = converted_messages + return normalized_data + @abstractmethod def normalize(self, data: Dict) -> ChatCompletionRequest: """Normalize the input data""" diff --git a/src/codegate/providers/ollama/adapter.py b/src/codegate/providers/ollama/adapter.py index ebc91e85..9e847774 100644 --- a/src/codegate/providers/ollama/adapter.py +++ b/src/codegate/providers/ollama/adapter.py @@ -13,8 +13,8 @@ def normalize(self, data: Dict) -> ChatCompletionRequest: """ Normalize the input data to the format expected by Ollama. """ - # Make a copy of the data to avoid modifying the original - normalized_data = data.copy() + # Make a copy of the data to avoid modifying the original and normalize the message content + normalized_data = self._normalize_content_messages(data) normalized_data["options"] = data.get("options", {}) # Add any context or system prompt if provided @@ -27,24 +27,6 @@ def normalize(self, data: Dict) -> ChatCompletionRequest: if "model" in normalized_data: normalized_data["model"] = data["model"].strip() - # Convert messages format if needed - if "messages" in data: - messages = data["messages"] - converted_messages = [] - for msg in messages: - role = msg.get("role", "") - content = msg.get("content", "") - new_msg = {"role": role, "content": content} - if isinstance(content, list): - # Convert list format to string - content_parts = [] - for part in msg["content"]: - if part.get("type") == "text": - content_parts.append(part["text"]) - new_msg["content"] = " ".join(content_parts) - converted_messages.append(new_msg) - normalized_data["messages"] = converted_messages - # Ensure the base_url ends with /api if provided if "base_url" in normalized_data: base_url = normalized_data["base_url"].rstrip("/") diff --git a/src/codegate/providers/openai/adapter.py b/src/codegate/providers/openai/adapter.py index b5f4565a..43baf88d 100644 --- a/src/codegate/providers/openai/adapter.py +++ b/src/codegate/providers/openai/adapter.py @@ -13,7 +13,8 @@ def normalize(self, data: Dict) -> ChatCompletionRequest: """ No normalizing needed, already OpenAI format """ - return ChatCompletionRequest(**data) + normalized_data = self._normalize_content_messages(data) + return ChatCompletionRequest(**normalized_data) def denormalize(self, data: ChatCompletionRequest) -> Dict: """ diff --git a/src/codegate/providers/vllm/adapter.py b/src/codegate/providers/vllm/adapter.py index a6240f1b..ebd92d22 100644 --- a/src/codegate/providers/vllm/adapter.py +++ b/src/codegate/providers/vllm/adapter.py @@ -108,8 +108,8 @@ def normalize(self, data: Dict) -> ChatCompletionRequest: Normalize the input data to the format expected by LiteLLM. Ensures the model name has the hosted_vllm prefix and base_url has /v1. 
""" - # Make a copy of the data to avoid modifying the original - normalized_data = data.copy() + # Make a copy of the data to avoid modifying the original and normalize the message content + normalized_data = self._normalize_content_messages(data) # Format the model name to include the provider if "model" in normalized_data: @@ -126,6 +126,8 @@ def normalize(self, data: Dict) -> ChatCompletionRequest: ret_data = normalized_data if self._has_chat_ml_format(normalized_data): ret_data = self._chat_ml_normalizer.normalize(normalized_data) + else: + ret_data = ChatCompletionRequest(**normalized_data) return ret_data def denormalize(self, data: ChatCompletionRequest) -> Dict: diff --git a/tests/providers/anthropic/test_adapter.py b/tests/providers/anthropic/test_adapter.py index 9bb81e54..69735aa3 100644 --- a/tests/providers/anthropic/test_adapter.py +++ b/tests/providers/anthropic/test_adapter.py @@ -38,7 +38,7 @@ def test_normalize_anthropic_input(input_normalizer): "max_tokens": 1024, "messages": [ {"content": "You are an expert code reviewer", "role": "system"}, - {"content": [{"text": "Review this code", "type": "text"}], "role": "user"}, + {"content": "Review this code", "role": "user"}, ], "model": "claude-3-haiku-20240307", "stream": True, diff --git a/tests/test_storage.py b/tests/test_storage.py index f34fa8bb..2b7b9c16 100644 --- a/tests/test_storage.py +++ b/tests/test_storage.py @@ -35,8 +35,10 @@ def mock_inference_engine(): @pytest.mark.asyncio async def test_search(mock_weaviate_client, mock_inference_engine): # Patch the LlamaCppInferenceEngine.embed method (not the entire class) - with patch("codegate.inference.inference_engine.LlamaCppInferenceEngine.embed", - mock_inference_engine.embed): + with patch( + "codegate.inference.inference_engine.LlamaCppInferenceEngine.embed", + mock_inference_engine.embed, + ): # Mock the WeaviateClient as before with patch("weaviate.WeaviateClient", return_value=mock_weaviate_client):