This repository was archived by the owner on Jun 5, 2025. It is now read-only.

Keep the code coverage high #80

Merged
merged 1 commit into from Nov 25, 2024
1 change: 1 addition & 0 deletions pyproject.toml
@@ -23,6 +23,7 @@ bandit = ">=1.7.10"
 build = ">=1.0.0"
 wheel = ">=0.40.0"
 litellm = ">=1.52.11"
+pytest-asyncio = "0.24.0"
 
 [build-system]
 requires = ["poetry-core"]
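Note: the pinned pytest-asyncio plugin is what lets the new test modules in this PR define coroutine tests with @pytest.mark.asyncio. A minimal standalone sketch of the pattern they rely on (this test is illustrative, not part of the PR):

import asyncio

import pytest


@pytest.mark.asyncio
async def test_event_loop_is_provided():
    # pytest-asyncio collects this coroutine and runs it on an event loop;
    # without the plugin, pytest cannot execute async test functions.
    await asyncio.sleep(0)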
6 changes: 5 additions & 1 deletion src/codegate/providers/litellmshim/generators.py
@@ -1,6 +1,8 @@
 import json
 from typing import Any, AsyncIterator
 
+from pydantic import BaseModel
+
 # Since different providers typically use one of these formats for streaming
 # responses, we have a single stream generator for each format that is then plugged
 # into the adapter.
@@ -10,7 +12,9 @@ async def sse_stream_generator(stream: AsyncIterator[Any]) -> AsyncIterator[str]
     """OpenAI-style SSE format"""
     try:
         async for chunk in stream:
-            if hasattr(chunk, "model_dump_json"):
+            if isinstance(chunk, BaseModel):
+                # alternatively we might want to just dump the whole object
+                # this might even allow us to tighten the typing of the stream
                 chunk = chunk.model_dump_json(exclude_none=True, exclude_unset=True)
             try:
                 yield f"data:{chunk}\n\n"
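For context, a minimal standalone sketch (not part of the PR) of what the new isinstance branch produces for a pydantic chunk; exclude_none and exclude_unset trim the SSE payload to the fields that were actually set:

from typing import Optional

from pydantic import BaseModel


class FakeChunk(BaseModel):
    # hypothetical model, standing in for a provider's response chunk type
    id: str
    text: Optional[str] = None


chunk = FakeChunk(id="1")
# Only explicitly-set, non-None fields survive the dump: {"id":"1"}
print(chunk.model_dump_json(exclude_none=True, exclude_unset=True))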
5 changes: 3 additions & 2 deletions src/codegate/providers/litellmshim/litellmshim.py
@@ -14,8 +14,9 @@ class LiteLLmShim(BaseCompletionHandler):
     LiteLLM API.
     """
 
-    def __init__(self, adapter: BaseAdapter):
+    def __init__(self, adapter: BaseAdapter, completion_func=acompletion):
         self._adapter = adapter
+        self._completion_func = completion_func
 
     async def complete(self, data: Dict, api_key: str) -> AsyncIterator[Any]:
         """
@@ -28,7 +29,7 @@ async def complete(self, data: Dict, api_key: str) -> AsyncIterator[Any]:
         if completion_request is None:
             raise Exception("Couldn't translate the request")
 
-        response = await acompletion(**completion_request)
+        response = await self._completion_func(**completion_request)
 
         if isinstance(response, ModelResponse):
             return self._adapter.translate_completion_output_params(response)
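Injecting completion_func is what makes the shim unit-testable without network calls. A hypothetical sketch of such a test (the stub adapter assumes complete() translates the request via translate_completion_input_params, which matches the adapter API exercised elsewhere in this PR, but this test is not part of the change):

import pytest
from litellm import ModelResponse

from codegate.providers.litellmshim.litellmshim import LiteLLmShim


class _StubAdapter:
    # Pass-through stand-in for BaseAdapter; only the two methods the
    # sketch needs are stubbed, and their use inside complete() is an assumption.
    def translate_completion_input_params(self, data):
        return data

    def translate_completion_output_params(self, response):
        return response


@pytest.mark.asyncio
async def test_complete_uses_injected_function():
    calls = []

    async def fake_completion(**kwargs):
        # Records the request instead of calling litellm.acompletion.
        calls.append(kwargs)
        return ModelResponse(id="fake_id")

    shim = LiteLLmShim(adapter=_StubAdapter(), completion_func=fake_completion)
    response = await shim.complete({"model": "gpt-4"}, api_key="fake-key")

    assert calls and calls[0]["model"] == "gpt-4"
    assert response.id == "fake_id"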
154 changes: 154 additions & 0 deletions tests/providers/anthropic/test_adapter.py
@@ -0,0 +1,154 @@
from typing import AsyncIterator, Dict, List, Union

import pytest
from litellm import ModelResponse
from litellm.adapters.anthropic_adapter import AnthropicStreamWrapper
from litellm.types.llms.anthropic import (
    ContentBlockDelta,
    ContentBlockStart,
    ContentTextBlockDelta,
    MessageChunk,
    MessageStartBlock,
)
from litellm.types.utils import Delta, StreamingChoices

from codegate.providers.anthropic.adapter import AnthropicAdapter


@pytest.fixture
def adapter():
    return AnthropicAdapter()


def test_translate_completion_input_params(adapter):
    # Test input data
    completion_request = {
        "model": "claude-3-haiku-20240307",
        "max_tokens": 1024,
        "stream": True,
        "messages": [
            {
                "role": "user",
                "system": "You are an expert code reviewer",
                "content": [
                    {
                        "type": "text",
                        "text": "Review this code",
                    }
                ],
            }
        ],
    }
    expected = {
        "max_tokens": 1024,
        "messages": [
            {"content": [{"text": "Review this code", "type": "text"}], "role": "user"}
        ],
        "model": "claude-3-haiku-20240307",
        "stream": True,
    }

    # Get translation
    result = adapter.translate_completion_input_params(completion_request)
    assert result == expected

@pytest.mark.asyncio
async def test_translate_completion_output_params_streaming(adapter):
    # Test stream data
    async def mock_stream():
        messages = [
            ModelResponse(
                id="test_id_1",
                choices=[
                    StreamingChoices(
                        finish_reason=None,
                        index=0,
                        delta=Delta(content="Hello", role="assistant"),
                    ),
                ],
                model="claude-3-haiku-20240307",
            ),
            ModelResponse(
                id="test_id_2",
                choices=[
                    StreamingChoices(
                        finish_reason=None,
                        index=0,
                        delta=Delta(content="world", role="assistant"),
                    ),
                ],
                model="claude-3-haiku-20240307",
            ),
            ModelResponse(
                id="test_id_2",
                choices=[
                    StreamingChoices(
                        finish_reason=None,
                        index=0,
                        delta=Delta(content="!", role="assistant"),
                    ),
                ],
                model="claude-3-haiku-20240307",
            ),
        ]
        for msg in messages:
            yield msg

    expected: List[Union[MessageStartBlock, ContentBlockStart, ContentBlockDelta]] = [
        MessageStartBlock(
            type="message_start",
            message=MessageChunk(
                id="msg_1nZdL29xx5MUA1yADyHTEsnR8uuvGzszyY",
                type="message",
                role="assistant",
                content=[],
                # litellm makes up a message start block with hardcoded values
                model="claude-3-5-sonnet-20240620",
                stop_reason=None,
                stop_sequence=None,
                usage={"input_tokens": 25, "output_tokens": 1},
            ),
        ),
        ContentBlockStart(
            type="content_block_start",
            index=0,
            content_block={"type": "text", "text": ""},
        ),
        ContentBlockDelta(
            type="content_block_delta",
            index=0,
            delta=ContentTextBlockDelta(type="text_delta", text="Hello"),
        ),
        ContentBlockDelta(
            type="content_block_delta",
            index=0,
            delta=ContentTextBlockDelta(type="text_delta", text="world"),
        ),
        ContentBlockDelta(
            type="content_block_delta",
            index=0,
            delta=ContentTextBlockDelta(type="text_delta", text="!"),
        ),
        # litellm doesn't seem to have a type for message stop
        dict(type="message_stop"),
    ]

    stream = adapter.translate_completion_output_params_streaming(mock_stream())
    assert isinstance(stream, AnthropicStreamWrapper)

    # just so that we can zip over the expected chunks
    stream_list = [chunk async for chunk in stream]
    # Verify we got all chunks
    assert len(stream_list) == 6

    for chunk, expected_chunk in zip(stream_list, expected):
        assert chunk == expected_chunk


def test_stream_generator_initialization(adapter):
    # Verify the default stream generator is set
    from codegate.providers.litellmshim import anthropic_stream_generator

    assert adapter.stream_generator == anthropic_stream_generator


def test_custom_stream_generator():
    # Test that we can inject a custom stream generator
    async def custom_generator(stream: AsyncIterator[Dict]) -> AsyncIterator[str]:
        async for chunk in stream:
            yield "custom: " + str(chunk)

    adapter = AnthropicAdapter(stream_generator=custom_generator)
    assert adapter.stream_generator == custom_generator
80 changes: 80 additions & 0 deletions tests/providers/litellmshim/test_generators.py
@@ -0,0 +1,80 @@
from typing import AsyncIterator

import pytest
from litellm import ModelResponse

from codegate.providers.litellmshim import (
    anthropic_stream_generator,
    sse_stream_generator,
)


@pytest.mark.asyncio
async def test_sse_stream_generator():
    # Mock stream data
    mock_chunks = [
        ModelResponse(id="1", choices=[{"text": "Hello"}]),
        ModelResponse(id="2", choices=[{"text": "World"}]),
    ]

    async def mock_stream():
        for chunk in mock_chunks:
            yield chunk

    # Collect generated SSE messages
    messages = []
    async for message in sse_stream_generator(mock_stream()):
        messages.append(message)

    # Verify format and content
    assert len(messages) == len(mock_chunks) + 1  # +1 for the [DONE] message
    assert all(msg.startswith("data:") for msg in messages)
    assert "Hello" in messages[0]
    assert "World" in messages[1]
    assert messages[-1] == "data: [DONE]\n\n"


@pytest.mark.asyncio
async def test_anthropic_stream_generator():
    # Mock Anthropic-style chunks
    mock_chunks = [
        {"type": "message_start", "message": {"id": "1"}},
        {"type": "content_block_start", "content_block": {"text": "Hello"}},
        {"type": "content_block_stop", "content_block": {"text": "World"}},
    ]

    async def mock_stream():
        for chunk in mock_chunks:
            yield chunk

    # Collect generated SSE messages
    messages = []
    async for message in anthropic_stream_generator(mock_stream()):
        messages.append(message)

    # Verify format and content
    assert len(messages) == 3
    for msg, chunk in zip(messages, mock_chunks):
        assert msg.startswith(f"event: {chunk['type']}\ndata:")
    assert "Hello" in messages[1]  # content_block_start message
    assert "World" in messages[2]  # content_block_stop message


@pytest.mark.asyncio
async def test_generators_error_handling():
    async def error_stream() -> AsyncIterator[str]:
        raise Exception("Test error")
        yield  # This will never be reached, but is needed for AsyncIterator typing

    # Test SSE generator error handling
    messages = []
    async for message in sse_stream_generator(error_stream()):
        messages.append(message)
    assert len(messages) == 2
    assert "Test error" in messages[0]
    assert messages[1] == "data: [DONE]\n\n"

    # Test Anthropic generator error handling
    messages = []
    async for message in anthropic_stream_generator(error_stream()):
        messages.append(message)
    assert len(messages) == 1
    assert "Test error" in messages[0]