From 87cd807eae2e7f165912204ce115393468be2196 Mon Sep 17 00:00:00 2001
From: Michelangelo Mori <328978+blkt@users.noreply.github.com>
Date: Sat, 22 Mar 2025 11:43:51 +0100
Subject: [PATCH] Fix Anthropic FIM with muxing.

In the context of muxing, the code determining which mapper to use
when receiving requests to be routed towards Anthropic relied on
`is_fim_request` alone, and did not take into account whether the
endpoint actually receiving the request was the legacy one
(i.e. `/completions`) or the current one (i.e. `/chat/completions`).

This caused the wrong mapper to be used, which led to an empty text
content for the FIM request.

A better way to determine which mapper to use is to look at the
effective type of the parsed request, since that is the real source
of truth for the translation.
---
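
A minimal, self-contained sketch of the dispatch rule adopted here; the
stand-in request classes below are illustrative only, the actual check in
the router is `isinstance(parsed, openai.LegacyCompletionRequest)`:

    from dataclasses import dataclass


    @dataclass
    class LegacyCompletionRequest:  # stand-in for the /completions payload type
        prompt: str


    @dataclass
    class ChatCompletionRequest:  # stand-in for the /chat/completions payload type
        messages: list


    def uses_legacy_mappers(parsed) -> bool:
        """True when the legacy /completions mappers should translate the request."""
        return isinstance(parsed, LegacyCompletionRequest)


    # A FIM request sent to /chat/completions parses as a chat request, so
    # dispatching on the is_fim_request flag alone can pick the wrong mapper.
    assert uses_legacy_mappers(LegacyCompletionRequest(prompt="def add(a, b):"))
    assert not uses_legacy_mappers(ChatCompletionRequest(messages=[]))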

 src/codegate/muxing/router.py                |  10 ++-
 src/codegate/providers/anthropic/provider.py |  37 ++++++++---
 .../providers/ollama/completion_handler.py   |   2 +-
 src/codegate/types/anthropic/__init__.py     |   4 ++
 src/codegate/types/anthropic/_generators.py  |  61 ++++++++++++++++++-
 .../types/anthropic/_request_models.py       |   1 +
 src/codegate/types/generators.py             |  46 ++++++--------
 src/codegate/types/ollama/_generators.py     |   2 +-
 src/codegate/types/openai/_generators.py     |  13 ----
 9 files changed, 120 insertions(+), 56 deletions(-)

diff --git a/src/codegate/muxing/router.py b/src/codegate/muxing/router.py
index 8e1c2045..04086791 100644
--- a/src/codegate/muxing/router.py
+++ b/src/codegate/muxing/router.py
@@ -138,7 +138,15 @@ async def route_to_dest_provider(
         # TODO this should be improved
         match model_route.endpoint.provider_type:
             case ProviderType.anthropic:
-                if is_fim_request:
+                # Note: despite `is_fim_request` being true, our
+                # integration tests query the `/chat/completions`
+                # endpoint, which causes the
+                # `anthropic_from_legacy_openai` to incorrectly
+                # populate the struct.
+                #
+                # Checking for the actual type is a much more
+                # reliable way of determining the right mapper.
+                if isinstance(parsed, openai.LegacyCompletionRequest):
                     completion_function = anthropic.acompletion
                     from_openai = anthropic_from_legacy_openai
                     to_openai = anthropic_to_legacy_openai
diff --git a/src/codegate/providers/anthropic/provider.py b/src/codegate/providers/anthropic/provider.py
index 3b23fe39..13741b85 100644
--- a/src/codegate/providers/anthropic/provider.py
+++ b/src/codegate/providers/anthropic/provider.py
@@ -11,7 +11,15 @@ from codegate.providers.anthropic.completion_handler import AnthropicCompletion
 from codegate.providers.base import BaseProvider, ModelFetchError
 from codegate.providers.fim_analyzer import FIMAnalyzer
-from codegate.types.anthropic import ChatCompletionRequest, stream_generator
+from codegate.types.anthropic import (
+    ChatCompletionRequest,
+    single_message,
+    single_response,
+    stream_generator,
+)
+from codegate.types.generators import (
+    completion_handler_replacement,
+)


 logger = structlog.get_logger("codegate")

@@ -118,18 +126,29 @@ async def create_message(
             body = await request.body()

             if os.getenv("CODEGATE_DEBUG_ANTHROPIC") is not None:
-                print(f"{create_message.__name__}: {body}")
+                print(f"{body.decode('utf-8')}")

             req = ChatCompletionRequest.model_validate_json(body)
             is_fim_request = FIMAnalyzer.is_fim_request(request.url.path, req)

-            return await self.process_request(
-                req,
-                x_api_key,
-                self.base_url,
-                is_fim_request,
-                request.state.detected_client,
-            )
+            if req.stream:
+                return await self.process_request(
+                    req,
+                    x_api_key,
+                    self.base_url,
+                    is_fim_request,
+                    request.state.detected_client,
+                )
+            else:
+                return await self.process_request(
+                    req,
+                    x_api_key,
+                    self.base_url,
+                    is_fim_request,
+                    request.state.detected_client,
+                    completion_handler=completion_handler_replacement(single_message),
+                    stream_generator=single_response,
+                )


 async def dumper(stream):
diff --git a/src/codegate/providers/ollama/completion_handler.py b/src/codegate/providers/ollama/completion_handler.py
index b1782a9a..d134fd66 100644
--- a/src/codegate/providers/ollama/completion_handler.py
+++ b/src/codegate/providers/ollama/completion_handler.py
@@ -73,7 +73,7 @@ async def _ollama_dispatcher(  # noqa: C901
         stream = openai_stream_generator(prepend(first, stream))

     if isinstance(first, OpenAIChatCompletion):
-        stream = openai_single_response_generator(first, stream)
+        stream = openai_single_response_generator(first)

     async for item in stream:
         yield item
diff --git a/src/codegate/types/anthropic/__init__.py b/src/codegate/types/anthropic/__init__.py
index 10d225a8..f037cc5c 100644
--- a/src/codegate/types/anthropic/__init__.py
+++ b/src/codegate/types/anthropic/__init__.py
@@ -1,6 +1,8 @@
 from ._generators import (
     acompletion,
     message_wrapper,
+    single_message,
+    single_response,
     stream_generator,
 )
 from ._request_models import (
@@ -49,6 +51,8 @@
 __all__ = [
     "acompletion",
     "message_wrapper",
+    "single_message",
+    "single_response",
     "stream_generator",
     "AssistantMessage",
     "CacheControl",
diff --git a/src/codegate/types/anthropic/_generators.py b/src/codegate/types/anthropic/_generators.py
index 4c7449d7..64c99229 100644
--- a/src/codegate/types/anthropic/_generators.py
+++ b/src/codegate/types/anthropic/_generators.py
@@ -12,6 +12,7 @@
     ContentBlockDelta,
     ContentBlockStart,
     ContentBlockStop,
+    Message,
     MessageDelta,
     MessageError,
     MessagePing,
@@ -27,7 +28,7 @@ async def stream_generator(stream: AsyncIterator[Any]) -> AsyncIterator[str]:
     try:
         async for chunk in stream:
             try:
-                body = chunk.json(exclude_defaults=True, exclude_unset=True)
+                body = chunk.json(exclude_unset=True)
             except Exception as e:
                 logger.error("failed serializing payload", exc_info=e)
                 err = MessageError(
@@ -37,7 +38,7 @@ async def stream_generator(stream: AsyncIterator[Any]) -> AsyncIterator[str]:
                         message=str(e),
                     ),
                 )
-                body = err.json(exclude_defaults=True, exclude_unset=True)
+                body = err.json(exclude_unset=True)
                 yield f"event: error\ndata: {body}\n\n"

             data = f"event: {chunk.type}\ndata: {body}\n\n"
@@ -55,10 +56,60 @@
                 message=str(e),
             ),
         )
-        body = err.json(exclude_defaults=True, exclude_unset=True)
+        body = err.json(exclude_unset=True)
         yield f"event: error\ndata: {body}\n\n"


+async def single_response(stream: AsyncIterator[Any]) -> AsyncIterator[str]:
+    """Wraps a single response object in an AsyncIterator. This is
+    meant to be used for non-streaming responses.
+
+    """
+    resp = await anext(stream)
+    yield resp.model_dump_json(exclude_unset=True)
+
+
+async def single_message(request, api_key, base_url, stream=None, is_fim_request=None):
+    headers = {
+        "anthropic-version": "2023-06-01",
+        "x-api-key": api_key,
+        "accept": "application/json",
+        "content-type": "application/json",
+    }
+    payload = request.model_dump_json(exclude_unset=True)
+
+    if os.getenv("CODEGATE_DEBUG_ANTHROPIC") is not None:
+        print(payload)
+
+    client = httpx.AsyncClient()
+    async with client.stream(
+        "POST",
+        f"{base_url}/v1/messages",
+        headers=headers,
+        content=payload,
+        timeout=60,  # TODO this should not be hardcoded
+    ) as resp:
+        match resp.status_code:
+            case 200:
+                text = await resp.aread()
+                if os.getenv("CODEGATE_DEBUG_ANTHROPIC") is not None:
+                    print(text.decode("utf-8"))
+                yield Message.model_validate_json(text)
+            case 400 | 401 | 403 | 404 | 413 | 429:
+                text = await resp.aread()
+                if os.getenv("CODEGATE_DEBUG_ANTHROPIC") is not None:
+                    print(text.decode("utf-8"))
+                yield MessageError.model_validate_json(text)
+            case 500 | 529:
+                text = await resp.aread()
+                if os.getenv("CODEGATE_DEBUG_ANTHROPIC") is not None:
+                    print(text.decode("utf-8"))
+                yield MessageError.model_validate_json(text)
+            case _:
+                logger.error(f"unexpected status code {resp.status_code}", provider="anthropic")
+                raise ValueError(f"unexpected status code {resp.status_code}", provider="anthropic")
+
+
 async def acompletion(request, api_key, base_url):
     headers = {
         "anthropic-version": "2023-06-01",
@@ -86,9 +137,13 @@ async def acompletion(request, api_key, base_url):
                 yield event
             case 400 | 401 | 403 | 404 | 413 | 429:
                 text = await resp.aread()
+                if os.getenv("CODEGATE_DEBUG_ANTHROPIC") is not None:
+                    print(text.decode("utf-8"))
                 yield MessageError.model_validate_json(text)
             case 500 | 529:
                 text = await resp.aread()
+                if os.getenv("CODEGATE_DEBUG_ANTHROPIC") is not None:
+                    print(text.decode("utf-8"))
                 yield MessageError.model_validate_json(text)
             case _:
                 logger.error(f"unexpected status code {resp.status_code}", provider="anthropic")
diff --git a/src/codegate/types/anthropic/_request_models.py b/src/codegate/types/anthropic/_request_models.py
index 592b9712..fb2c22b4 100644
--- a/src/codegate/types/anthropic/_request_models.py
+++ b/src/codegate/types/anthropic/_request_models.py
@@ -155,6 +155,7 @@ class ToolDef(pydantic.BaseModel):
     Literal["auto"],
     Literal["any"],
     Literal["tool"],
+    Literal["none"],
 ]


diff --git a/src/codegate/types/generators.py b/src/codegate/types/generators.py
index affca5ba..6ab0ee97 100644
--- a/src/codegate/types/generators.py
+++ b/src/codegate/types/generators.py
@@ -1,37 +1,27 @@
-import os
 from typing import (
-    Any,
-    AsyncIterator,
+    Callable,
 )

-import pydantic
 import structlog


 logger = structlog.get_logger("codegate")


-# Since different providers typically use one of these formats for streaming
-# responses, we have a single stream generator for each format that is then plugged
-# into the adapter.
+def completion_handler_replacement(
+    completion_handler: Callable,
+):
+    async def _inner(
+        request,
+        base_url,
+        api_key,
+        stream=None,
+        is_fim_request=None,
+    ):
+        # Execute e.g. acompletion from Anthropic types
+        return completion_handler(
+            request,
+            api_key,
+            base_url,
+        )
-
-async def sse_stream_generator(stream: AsyncIterator[Any]) -> AsyncIterator[str]:
-    """OpenAI-style SSE format"""
-    try:
-        async for chunk in stream:
-            if isinstance(chunk, pydantic.BaseModel):
-                # alternatively we might want to just dump the whole object
-                # this might even allow us to tighten the typing of the stream
-                chunk = chunk.model_dump_json(exclude_none=True, exclude_unset=True)
-            try:
-                if os.getenv("CODEGATE_DEBUG_OPENAI") is not None:
-                    print(chunk)
-                yield f"data: {chunk}\n\n"
-            except Exception as e:
-                logger.error("failed generating output payloads", exc_info=e)
-                yield f"data: {str(e)}\n\n"
-    except Exception as e:
-        logger.error("failed generating output payloads", exc_info=e)
-        yield f"data: {str(e)}\n\n"
-    finally:
-        yield "data: [DONE]\n\n"
+
+    return _inner
diff --git a/src/codegate/types/ollama/_generators.py b/src/codegate/types/ollama/_generators.py
index 2c141158..896cc7fe 100644
--- a/src/codegate/types/ollama/_generators.py
+++ b/src/codegate/types/ollama/_generators.py
@@ -23,7 +23,7 @@ async def stream_generator(
     try:
         async for chunk in stream:
             try:
-                body = chunk.model_dump_json(exclude_none=True, exclude_unset=True)
+                body = chunk.model_dump_json(exclude_unset=True)
                 data = f"{body}\n"

                 if os.getenv("CODEGATE_DEBUG_OLLAMA") is not None:
diff --git a/src/codegate/types/openai/_generators.py b/src/codegate/types/openai/_generators.py
index 2a36229c..1d0f215c 100644
--- a/src/codegate/types/openai/_generators.py
+++ b/src/codegate/types/openai/_generators.py
@@ -50,7 +50,6 @@ async def stream_generator(stream: AsyncIterator[StreamingChatCompletion]) -> As

 async def single_response_generator(
     first: ChatCompletion,
-    stream: AsyncIterator[ChatCompletion],
 ) -> AsyncIterator[ChatCompletion]:
     """Wraps a single response object in an AsyncIterator. This is
     meant to be used for non-streaming responses.
@@ -58,18 +57,6 @@ async def single_response_generator(
     """
     yield first.model_dump_json(exclude_none=True, exclude_unset=True)

-    # Note: this async for loop is necessary to force Python to return
-    # an AsyncIterator. This is necessary because of the wiring at the
-    # Provider level expecting an AsyncIterator rather than a single
-    # response payload.
-    #
-    # Refactoring this means adding a code path specific for when we
-    # expect single response payloads rather than an SSE stream.
-    async for item in stream:
-        if item:
-            logger.error("no further items were expected", item=item)
-            yield item.model_dump_json(exclude_none=True, exclude_unset=True)
-

 async def completions_streaming(request, api_key, base_url):
     if base_url is None:
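
The non-streaming path added by this patch wires three new pieces together:
single_message (a one-shot call to /v1/messages), completion_handler_replacement
(an adapter to the provider machinery's calling convention) and single_response
(which drains the resulting one-item iterator). A rough, self-contained sketch
of how they compose, using simplified stand-ins rather than the real CodeGate
types:

    import asyncio


    async def single_message(request, api_key, base_url):
        # Stand-in for the upstream call: an async generator that yields
        # exactly one response object (the real one POSTs to /v1/messages).
        yield {"role": "assistant", "content": f"echo: {request['prompt']}"}


    def completion_handler_replacement(completion_handler):
        # Adapt a (request, api_key, base_url) generator to the provider
        # machinery's (request, base_url, api_key, ...) calling convention.
        async def _inner(request, base_url, api_key, stream=None, is_fim_request=None):
            return completion_handler(request, api_key, base_url)

        return _inner


    async def single_response(stream):
        # Serialize the single item produced by the handler.
        resp = await anext(stream)
        yield str(resp)


    async def main():
        handler = completion_handler_replacement(single_message)
        # Dummy base_url and api_key; the stand-in handler ignores them.
        stream = await handler({"prompt": "hi"}, "https://api.example", "sk-test")
        async for body in single_response(stream):
            print(body)


    asyncio.run(main())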