This repository was archived by the owner on Jun 5, 2025. It is now read-only.

Fix Anthropic FIM with muxing. #1304

Merged 1 commit on Mar 25, 2025
src/codegate/muxing/router.py (9 additions, 1 deletion)
@@ -138,7 +138,15 @@ async def route_to_dest_provider(
# TODO this should be improved
match model_route.endpoint.provider_type:
case ProviderType.anthropic:
if is_fim_request:
# Note: despite `is_fim_request` being true, our
# integration tests query the `/chat/completions`
# endpoint, which causes `anthropic_from_legacy_openai`
# to populate the struct incorrectly.
#
# Checking for the actual type is a much more
# reliable way of determining the right mapper.
if isinstance(parsed, openai.LegacyCompletionRequest):
completion_function = anthropic.acompletion
from_openai = anthropic_from_legacy_openai
to_openai = anthropic_to_legacy_openai
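The router change above keys the choice of request/response mappers off the concrete type of the parsed body rather than the `is_fim_request` flag alone. Below is a minimal, self-contained sketch of that dispatch idea; the dataclasses are hypothetical stand-ins for the real request models in `codegate.types.openai`, and the returned labels only name which mapper pair would be selected.

```python
# Illustrative only: type-based mapper selection, mirroring the isinstance
# check added in router.py.  The classes below are stand-ins, not the real
# codegate request models.
from dataclasses import dataclass, field


@dataclass
class LegacyCompletionRequest:  # stand-in for openai.LegacyCompletionRequest
    prompt: str


@dataclass
class ChatCompletionRequest:  # stand-in for the chat-format request
    messages: list = field(default_factory=list)


def pick_mappers(parsed):
    # The parsed type, not a separate FIM flag, decides the mapper pair.
    if isinstance(parsed, LegacyCompletionRequest):
        return "legacy-completions mappers (FIM)"
    return "chat-completions mappers"


print(pick_mappers(LegacyCompletionRequest(prompt="def add(a, b):")))  # FIM path
print(pick_mappers(ChatCompletionRequest()))                           # chat path
```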
src/codegate/providers/anthropic/provider.py (28 additions, 9 deletions)
@@ -11,7 +11,15 @@
from codegate.providers.anthropic.completion_handler import AnthropicCompletion
from codegate.providers.base import BaseProvider, ModelFetchError
from codegate.providers.fim_analyzer import FIMAnalyzer
from codegate.types.anthropic import ChatCompletionRequest, stream_generator
from codegate.types.anthropic import (
ChatCompletionRequest,
single_message,
single_response,
stream_generator,
)
from codegate.types.generators import (
completion_handler_replacement,
)

logger = structlog.get_logger("codegate")

@@ -118,18 +126,29 @@ async def create_message(
body = await request.body()

if os.getenv("CODEGATE_DEBUG_ANTHROPIC") is not None:
print(f"{create_message.__name__}: {body}")
print(f"{body.decode('utf-8')}")

req = ChatCompletionRequest.model_validate_json(body)
is_fim_request = FIMAnalyzer.is_fim_request(request.url.path, req)

return await self.process_request(
req,
x_api_key,
self.base_url,
is_fim_request,
request.state.detected_client,
)
if req.stream:
return await self.process_request(
req,
x_api_key,
self.base_url,
is_fim_request,
request.state.detected_client,
)
else:
return await self.process_request(
req,
x_api_key,
self.base_url,
is_fim_request,
request.state.detected_client,
completion_handler=completion_handler_replacement(single_message),
stream_generator=single_response,
)


async def dumper(stream):
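The provider change branches on the request's `stream` flag: streaming requests keep the default SSE handling, while non-streaming requests swap in `completion_handler_replacement(single_message)` and serialize the result with `single_response`. An illustrative pair of request bodies follows; field names match the public Anthropic Messages API, and the model id is just a placeholder.

```python
# Illustrative Anthropic Messages payloads; the new branch only inspects
# req.stream to decide which completion handler and serializer to use.
streaming_req = {
    "model": "claude-3-5-sonnet-latest",  # placeholder model id
    "max_tokens": 64,
    "messages": [{"role": "user", "content": "hello"}],
    "stream": True,   # -> default handler + SSE stream_generator
}

non_streaming_req = {
    **streaming_req,
    "stream": False,  # -> completion_handler_replacement(single_message) + single_response
}
```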
src/codegate/providers/ollama/completion_handler.py (1 addition, 1 deletion)
@@ -73,7 +73,7 @@ async def _ollama_dispatcher( # noqa: C901
stream = openai_stream_generator(prepend(first, stream))

if isinstance(first, OpenAIChatCompletion):
stream = openai_single_response_generator(first, stream)
stream = openai_single_response_generator(first)

async for item in stream:
yield item
src/codegate/types/anthropic/__init__.py (4 additions, 0 deletions)
@@ -1,6 +1,8 @@
from ._generators import (
acompletion,
message_wrapper,
single_message,
single_response,
stream_generator,
)
from ._request_models import (
@@ -49,6 +51,8 @@
__all__ = [
"acompletion",
"message_wrapper",
"single_message",
"single_response",
"stream_generator",
"AssistantMessage",
"CacheControl",
src/codegate/types/anthropic/_generators.py (58 additions, 3 deletions)
@@ -12,6 +12,7 @@
ContentBlockDelta,
ContentBlockStart,
ContentBlockStop,
Message,
MessageDelta,
MessageError,
MessagePing,
@@ -27,7 +28,7 @@ async def stream_generator(stream: AsyncIterator[Any]) -> AsyncIterator[str]:
try:
async for chunk in stream:
try:
body = chunk.json(exclude_defaults=True, exclude_unset=True)
body = chunk.json(exclude_unset=True)
except Exception as e:
logger.error("failed serializing payload", exc_info=e)
err = MessageError(
@@ -37,7 +38,7 @@ async def stream_generator(stream: AsyncIterator[Any]) -> AsyncIterator[str]:
message=str(e),
),
)
body = err.json(exclude_defaults=True, exclude_unset=True)
body = err.json(exclude_unset=True)
yield f"event: error\ndata: {body}\n\n"

data = f"event: {chunk.type}\ndata: {body}\n\n"
@@ -55,10 +56,60 @@ async def stream_generator(stream: AsyncIterator[Any]) -> AsyncIterator[str]:
message=str(e),
),
)
body = err.json(exclude_defaults=True, exclude_unset=True)
body = err.json(exclude_unset=True)
yield f"event: error\ndata: {body}\n\n"


async def single_response(stream: AsyncIterator[Any]) -> AsyncIterator[str]:
"""Wraps a single response object in an AsyncIterator. This is
meant to be used for non-streaming responses.

"""
resp = await anext(stream)
yield resp.model_dump_json(exclude_unset=True)


async def single_message(request, api_key, base_url, stream=None, is_fim_request=None):
headers = {
"anthropic-version": "2023-06-01",
"x-api-key": api_key,
"accept": "application/json",
"content-type": "application/json",
}
payload = request.model_dump_json(exclude_unset=True)

if os.getenv("CODEGATE_DEBUG_ANTHROPIC") is not None:
print(payload)

client = httpx.AsyncClient()
async with client.stream(
"POST",
f"{base_url}/v1/messages",
headers=headers,
content=payload,
timeout=60, # TODO this should not be hardcoded
) as resp:
match resp.status_code:
case 200:
text = await resp.aread()
if os.getenv("CODEGATE_DEBUG_ANTHROPIC") is not None:
print(text.decode("utf-8"))
yield Message.model_validate_json(text)
case 400 | 401 | 403 | 404 | 413 | 429:
text = await resp.aread()
if os.getenv("CODEGATE_DEBUG_ANTHROPIC") is not None:
print(text.decode("utf-8"))
yield MessageError.model_validate_json(text)
case 500 | 529:
text = await resp.aread()
if os.getenv("CODEGATE_DEBUG_ANTHROPIC") is not None:
print(text.decode("utf-8"))
yield MessageError.model_validate_json(text)
case _:
logger.error(f"unexpected status code {resp.status_code}", provider="anthropic")
raise ValueError(f"unexpected status code {resp.status_code}")


async def acompletion(request, api_key, base_url):
headers = {
"anthropic-version": "2023-06-01",
Expand Down Expand Up @@ -86,9 +137,13 @@ async def acompletion(request, api_key, base_url):
yield event
case 400 | 401 | 403 | 404 | 413 | 429:
text = await resp.aread()
if os.getenv("CODEGATE_DEBUG_ANTHROPIC") is not None:
print(text.decode("utf-8"))
yield MessageError.model_validate_json(text)
case 500 | 529:
text = await resp.aread()
if os.getenv("CODEGATE_DEBUG_ANTHROPIC") is not None:
print(text.decode("utf-8"))
yield MessageError.model_validate_json(text)
case _:
logger.error(f"unexpected status code {resp.status_code}", provider="anthropic")
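`single_message` performs one non-streaming POST to `/v1/messages` and yields exactly one parsed object (a `Message` on success, a `MessageError` otherwise); `single_response` then pulls that single item out of the async iterator and emits its JSON form. The sketch below is a hedged, self-contained illustration of how the pair composes, with a stub in place of the real HTTP call.

```python
# Hypothetical usage sketch: a stub stands in for single_message's HTTP
# request, and the serializer mirrors single_response's single-item shape.
import asyncio
import json


async def fake_single_message(request, api_key, base_url):
    # stand-in for single_message: one request in, exactly one object out
    yield {"id": "msg_123", "role": "assistant",
           "content": [{"type": "text", "text": "hi"}]}


async def fake_single_response(stream):
    # mirrors single_response: take the single item and emit it as JSON
    resp = await anext(stream)
    yield json.dumps(resp)


async def main():
    stream = fake_single_message({"messages": []}, "sk-demo", "https://api.example")
    async for body in fake_single_response(stream):
        print(body)


asyncio.run(main())
```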
src/codegate/types/anthropic/_request_models.py (1 addition, 0 deletions)
@@ -155,6 +155,7 @@ class ToolDef(pydantic.BaseModel):
Literal["auto"],
Literal["any"],
Literal["tool"],
Literal["none"],
]


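The added `Literal["none"]` arm appears to extend the accepted `tool_choice` values to match Anthropic's option for disabling tool use. An illustrative request fragment is below; the shape follows the public Anthropic Messages API and all values are placeholders.

```python
# Illustrative only: an Anthropic Messages request that supplies tools but
# asks the model not to call them; "tool_choice": {"type": "none"} is the
# value the new literal allows through validation.
request_body = {
    "model": "claude-3-5-sonnet-latest",  # placeholder model id
    "max_tokens": 128,
    "messages": [{"role": "user", "content": "Answer directly, no tools."}],
    "tools": [{
        "name": "get_weather",
        "description": "Look up current weather",
        "input_schema": {"type": "object", "properties": {}},
    }],
    "tool_choice": {"type": "none"},
}
```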
src/codegate/types/generators.py (18 additions, 28 deletions)
@@ -1,37 +1,27 @@
import os
from typing import (
Any,
AsyncIterator,
Callable,
)

import pydantic
import structlog

logger = structlog.get_logger("codegate")


# Since different providers typically use one of these formats for streaming
# responses, we have a single stream generator for each format that is then plugged
# into the adapter.
def completion_handler_replacement(
completion_handler: Callable,
):
async def _inner(
request,
base_url,
api_key,
stream=None,
is_fim_request=None,
):
# Execute e.g. acompletion from Anthropic types
return completion_handler(
request,
api_key,
base_url,
)


async def sse_stream_generator(stream: AsyncIterator[Any]) -> AsyncIterator[str]:
"""OpenAI-style SSE format"""
try:
async for chunk in stream:
if isinstance(chunk, pydantic.BaseModel):
# alternatively we might want to just dump the whole object
# this might even allow us to tighten the typing of the stream
chunk = chunk.model_dump_json(exclude_none=True, exclude_unset=True)
try:
if os.getenv("CODEGATE_DEBUG_OPENAI") is not None:
print(chunk)
yield f"data: {chunk}\n\n"
except Exception as e:
logger.error("failed generating output payloads", exc_info=e)
yield f"data: {str(e)}\n\n"
except Exception as e:
logger.error("failed generating output payloads", exc_info=e)
yield f"data: {str(e)}\n\n"
finally:
yield "data: [DONE]\n\n"
return _inner
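`completion_handler_replacement` adapts a provider-specific coroutine such as `single_message` (which takes `request, api_key, base_url`) to the calling convention the provider plumbing uses (`request, base_url, api_key, stream, is_fim_request`), ignoring the extra arguments. The usage sketch below is self-contained: the wrapper body is copied from the diff above, and `echo_handler` is a hypothetical stand-in that only illustrates the argument reordering.

```python
# Hedged usage sketch of the adapter returned by completion_handler_replacement.
import asyncio
from typing import Callable


def completion_handler_replacement(completion_handler: Callable):
    async def _inner(request, base_url, api_key, stream=None, is_fim_request=None):
        # forward only the arguments the wrapped handler expects
        return completion_handler(request, api_key, base_url)
    return _inner


async def echo_handler(request, api_key, base_url):  # stand-in for single_message
    yield f"POST {base_url}/v1/messages with key={api_key!r}: {request}"


async def main():
    adapter = completion_handler_replacement(echo_handler)
    stream = await adapter({"prompt": "hi"}, "https://api.example", "sk-demo")
    async for item in stream:
        print(item)


asyncio.run(main())
```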
src/codegate/types/ollama/_generators.py (1 addition, 1 deletion)
@@ -23,7 +23,7 @@ async def stream_generator(
try:
async for chunk in stream:
try:
body = chunk.model_dump_json(exclude_none=True, exclude_unset=True)
body = chunk.model_dump_json(exclude_unset=True)
data = f"{body}\n"

if os.getenv("CODEGATE_DEBUG_OLLAMA") is not None:
src/codegate/types/openai/_generators.py (0 additions, 13 deletions)
@@ -50,26 +50,13 @@ async def stream_generator(stream: AsyncIterator[StreamingChatCompletion]) -> As

async def single_response_generator(
first: ChatCompletion,
stream: AsyncIterator[ChatCompletion],
) -> AsyncIterator[ChatCompletion]:
"""Wraps a single response object in an AsyncIterator. This is
meant to be used for non-streaming responses.

"""
yield first.model_dump_json(exclude_none=True, exclude_unset=True)

# Note: this async for loop is necessary to force Python to return
# an AsyncIterator. This is necessary because of the wiring at the
# Provider level expecting an AsyncIterator rather than a single
# response payload.
#
# Refactoring this means adding a code path specific for when we
# expect single response payloads rather than an SSE stream.
async for item in stream:
if item:
logger.error("no further items were expected", item=item)
yield item.model_dump_json(exclude_none=True, exclude_unset=True)


async def completions_streaming(request, api_key, base_url):
if base_url is None: