This repository was archived by the owner on Jun 5, 2025. It is now read-only.

Commit d4f1ab8

Merge pull request #149 from stacklok/normalize-vllm-output

Respond with JSON if the request is non-stream

2 parents 366bd6e + 9b3e488, commit d4f1ab8

9 files changed: 50 additions, 13 deletions
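The practical effect of the merge is that a request with "stream": false now comes back as a plain JSON body instead of a server-sent-event stream. A hedged client-side sketch of that check follows; the port, endpoint path, and payload are placeholders for illustration and are not taken from this commit.

import requests

# Placeholder endpoint and payload; substitute whatever provider route you actually proxy.
payload = {
    "model": "some-model",
    "messages": [{"role": "user", "content": "hello"}],
    "stream": False,  # non-stream request -> JSON body after this change
}
resp = requests.post("http://localhost:8989/openai/chat/completions", json=payload)
# Before the merge the non-stream case still came back as text/event-stream;
# afterwards it should be application/json.
print(resp.headers.get("content-type"))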

src/codegate/providers/anthropic/provider.py (1 addition, 1 deletion)

@@ -48,4 +48,4 @@ async def create_message(
 
         is_fim_request = self._is_fim_request(request, data)
         stream = await self.complete(data, x_api_key, is_fim_request)
-        return self._completion_handler.create_streaming_response(stream)
+        return self._completion_handler.create_response(stream)

src/codegate/providers/completion/base.py (15 additions, 2 deletions)

@@ -1,7 +1,8 @@
 from abc import ABC, abstractmethod
+from collections.abc import Iterator
 from typing import Any, AsyncIterator, Optional, Union
 
-from fastapi.responses import StreamingResponse
+from fastapi.responses import JSONResponse, StreamingResponse
 from litellm import ChatCompletionRequest, ModelResponse
 
 
@@ -23,5 +24,17 @@ async def execute_completion(
         pass
 
     @abstractmethod
-    def create_streaming_response(self, stream: AsyncIterator[Any]) -> StreamingResponse:
+    def _create_streaming_response(self, stream: AsyncIterator[Any]) -> StreamingResponse:
         pass
+
+    @abstractmethod
+    def _create_json_response(self, response: Any) -> JSONResponse:
+        pass
+
+    def create_response(self, response: Any) -> Union[JSONResponse, StreamingResponse]:
+        """
+        Create a FastAPI response from the completion response.
+        """
+        if isinstance(response, Iterator):
+            return self._create_streaming_response(response)
+        return self._create_json_response(response)
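The new create_response helper is the only behavioral addition to the abstract base: it inspects the completion result and picks either the streaming or the JSON path. Below is a minimal sketch of that dispatch using a toy handler rather than the real BaseCompletionHandler; the check uses collections.abc.Iterator, so a plain (synchronous) iterator takes the streaming branch and anything else falls through to JSON.

from collections.abc import Iterator
from typing import Any, Union

from fastapi.responses import JSONResponse, StreamingResponse


class ToyHandler:
    """Toy stand-in for a BaseCompletionHandler subclass (illustration only)."""

    def _create_streaming_response(self, stream: Any) -> StreamingResponse:
        # StreamingResponse accepts sync and async iterables alike.
        return StreamingResponse(stream, media_type="text/event-stream")

    def _create_json_response(self, response: Any) -> JSONResponse:
        return JSONResponse(status_code=200, content=response)

    def create_response(self, response: Any) -> Union[JSONResponse, StreamingResponse]:
        # Same dispatch as the new base-class method in this diff.
        if isinstance(response, Iterator):
            return self._create_streaming_response(response)
        return self._create_json_response(response)


handler = ToyHandler()
assert isinstance(handler.create_response(iter(["data: {}\n\n"])), StreamingResponse)
assert isinstance(handler.create_response({"choices": []}), JSONResponse)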

src/codegate/providers/litellmshim/litellmshim.py (21 additions, 3 deletions)

@@ -1,10 +1,17 @@
 from typing import Any, AsyncIterator, Callable, Optional, Union
 
-from fastapi.responses import StreamingResponse
-from litellm import ChatCompletionRequest, ModelResponse, acompletion
+import structlog
+from fastapi.responses import JSONResponse, StreamingResponse
+from litellm import (
+    ChatCompletionRequest,
+    ModelResponse,
+    acompletion,
+)
 
 from codegate.providers.base import BaseCompletionHandler, StreamGenerator
 
+logger = structlog.get_logger("codegate")
+
 
 class LiteLLmShim(BaseCompletionHandler):
     """
@@ -42,7 +49,7 @@ async def execute_completion(
             return await self._fim_completion_func(**request)
         return await self._completion_func(**request)
 
-    def create_streaming_response(self, stream: AsyncIterator[Any]) -> StreamingResponse:
+    def _create_streaming_response(self, stream: AsyncIterator[Any]) -> StreamingResponse:
         """
         Create a streaming response from a stream generator. The StreamingResponse
         is the format that FastAPI expects for streaming responses.
@@ -56,3 +63,14 @@ def create_streaming_response(self, stream: AsyncIterator[Any]) -> StreamingResponse:
             },
             status_code=200,
         )
+
+    def _create_json_response(self, response: ModelResponse) -> JSONResponse:
+        """
+        Create a JSON FastAPI response from a ModelResponse object.
+        ModelResponse is obtained when the request is not streaming.
+        """
+        # ModelResponse is not a Pydantic object but has a json method we can use to serialize it
+        if isinstance(response, ModelResponse):
+            return JSONResponse(status_code=200, content=response.json())
+        # Most other objects in LiteLLM are Pydantic, so we can use their model_dump method
+        return JSONResponse(status_code=200, content=response.model_dump())
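The JSON branch serializes two kinds of objects: a LiteLLM ModelResponse through its json() method, and any other (Pydantic) response object through model_dump(). A small sketch of that split follows, using stand-in classes rather than LiteLLM's real types and assuming the branch order shown in the diff.

from typing import Any

from fastapi.responses import JSONResponse
from pydantic import BaseModel


class JsonStyleResponse:
    """Stand-in for an object that serializes through a json() method."""

    def json(self) -> dict:
        return {"id": "resp-1", "choices": []}


class PydanticStyleResponse(BaseModel):
    """Stand-in for a Pydantic response object exposing model_dump()."""

    id: str = "resp-2"
    choices: list = []


def to_json_response(response: Any) -> JSONResponse:
    # Mirrors the branch order in _create_json_response: json()-style objects
    # first, Pydantic model_dump() as the fallback.
    if isinstance(response, JsonStyleResponse):
        return JSONResponse(status_code=200, content=response.json())
    return JSONResponse(status_code=200, content=response.model_dump())


assert to_json_response(JsonStyleResponse()).status_code == 200
assert to_json_response(PydanticStyleResponse()).status_code == 200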

src/codegate/providers/llamacpp/completion_handler.py (5 additions, 2 deletions)

@@ -2,7 +2,7 @@
 import json
 from typing import Any, AsyncIterator, Iterator, Optional, Union
 
-from fastapi.responses import StreamingResponse
+from fastapi.responses import JSONResponse, StreamingResponse
 from litellm import ChatCompletionRequest, ModelResponse
 from llama_cpp.llama_types import (
     CreateChatCompletionStreamResponse,
@@ -75,7 +75,7 @@ async def execute_completion(
 
         return convert_to_async_iterator(response) if stream else response
 
-    def create_streaming_response(self, stream: AsyncIterator[Any]) -> StreamingResponse:
+    def _create_streaming_response(self, stream: AsyncIterator[Any]) -> StreamingResponse:
        """
         Create a streaming response from a stream generator. The StreamingResponse
         is the format that FastAPI expects for streaming responses.
@@ -89,3 +89,6 @@ def create_streaming_response(self, stream: AsyncIterator[Any]) -> StreamingResponse:
             },
             status_code=200,
         )
+
+    def _create_json_response(self, response: Any) -> JSONResponse:
+        raise NotImplementedError("JSON response in LlamaCPP is not implemented yet.")

src/codegate/providers/llamacpp/provider.py (1 addition, 1 deletion)

@@ -43,4 +43,4 @@ async def create_completion(
 
         is_fim_request = self._is_fim_request(request, data)
         stream = await self.complete(data, None, is_fim_request=is_fim_request)
-        return self._completion_handler.create_streaming_response(stream)
+        return self._completion_handler.create_response(stream)

src/codegate/providers/openai/provider.py (1 addition, 1 deletion)

@@ -49,4 +49,4 @@ async def create_completion(
 
         is_fim_request = self._is_fim_request(request, data)
         stream = await self.complete(data, api_key, is_fim_request=is_fim_request)
-        return self._completion_handler.create_streaming_response(stream)
+        return self._completion_handler.create_response(stream)

src/codegate/providers/vllm/provider.py (1 addition, 1 deletion)

@@ -57,4 +57,4 @@ async def create_completion(
 
         is_fim_request = self._is_fim_request(request, data)
         stream = await self.complete(data, api_key, is_fim_request=is_fim_request)
-        return self._completion_handler.create_streaming_response(stream)
+        return self._completion_handler.create_response(stream)

tests/providers/litellmshim/test_litellmshim.py (1 addition, 1 deletion)

@@ -117,7 +117,7 @@ async def mock_stream_gen():
     generator = mock_stream_gen()
 
     litellm_shim = LiteLLmShim(stream_generator=sse_stream_generator)
-    response = litellm_shim.create_streaming_response(generator)
+    response = litellm_shim._create_streaming_response(generator)
 
     # Verify response metadata
     assert isinstance(response, StreamingResponse)
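A complementary test for the new JSON path would look roughly like the sketch below. It is not part of this commit; the import path for LiteLLmShim and sse_stream_generator is assumed to match what the existing test module already uses, and the ModelResponse constructor arguments are illustrative.

from fastapi.responses import JSONResponse
from litellm import ModelResponse

# Import path assumed; the existing test already imports these names.
from codegate.providers.litellmshim import LiteLLmShim, sse_stream_generator


def test_create_json_response_from_model_response():
    litellm_shim = LiteLLmShim(stream_generator=sse_stream_generator)
    model_response = ModelResponse(model="some-model")

    response = litellm_shim._create_json_response(model_response)

    assert isinstance(response, JSONResponse)
    assert response.status_code == 200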

tests/providers/test_registry.py (4 additions, 1 deletion)

@@ -43,12 +43,15 @@ def execute_completion(
     ) -> Any:
         pass
 
-    def create_streaming_response(
+    def _create_streaming_response(
         self,
         stream: AsyncIterator[Any],
     ) -> StreamingResponse:
         return StreamingResponse(stream)
 
+    def _create_json_response(self, response: Any) -> Any:
+        raise NotImplementedError
+
 
 class MockInputNormalizer(ModelInputNormalizer):
     def normalize(self, data: Dict) -> Dict:
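The mock handler needs the extra stub because _create_json_response is declared as an @abstractmethod on the base class: an ABC subclass that leaves any abstract method unimplemented cannot be instantiated. A minimal, self-contained illustration (not code from this repository):

from abc import ABC, abstractmethod
from typing import Any


class Handler(ABC):
    """Toy stand-in for BaseCompletionHandler's abstract surface."""

    @abstractmethod
    def _create_streaming_response(self, stream: Any) -> Any: ...

    @abstractmethod
    def _create_json_response(self, response: Any) -> Any: ...


class IncompleteHandler(Handler):
    # Implements only the streaming hook; the JSON hook is missing.
    def _create_streaming_response(self, stream: Any) -> Any:
        return stream


try:
    IncompleteHandler()
except TypeError as exc:
    # "Can't instantiate abstract class IncompleteHandler ..."
    print(exc)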
