This repository was archived by the owner on Jun 5, 2025. It is now read-only.

Implement provider interface and OpenAI and Anthropic providers #66

Merged: 1 commit, Nov 22, 2024
1 change: 1 addition & 0 deletions pyproject.toml
@@ -20,6 +20,7 @@ dev = [
"bandit>=1.7.10",
"build>=1.0.0",
"wheel>=0.40.0",
"litellm>=1.52.11",
]

[build-system]
2 changes: 2 additions & 0 deletions src/codegate/__init__.py
@@ -4,8 +4,10 @@

try:
__version__ = metadata.version("codegate")
__description__ = metadata.metadata("codegate")["Summary"]
except metadata.PackageNotFoundError: # pragma: no cover
__version__ = "unknown"
__description__ = "codegate"

from .config import Config, ConfigurationError
from .logging import setup_logging
23 changes: 15 additions & 8 deletions src/codegate/cli.py
@@ -9,6 +9,7 @@

from .config import Config, ConfigurationError, LogFormat, LogLevel
from .logging import setup_logging
from .server import init_app


def validate_port(ctx: click.Context, param: click.Parameter, value: int) -> int:
@@ -65,7 +66,6 @@ def serve(
config: Optional[Path],
) -> None:
"""Start the codegate server."""

try:
# Load configuration with priority resolution
cfg = Config.load(
@@ -79,11 +79,6 @@
setup_logging(cfg.log_level, cfg.log_format)
logger = logging.getLogger(__name__)

logger.info("This is an info message")
logger.debug("This is a debug message")
logger.error("This is an error message")
logger.warning("This is a warning message")

logger.info(
"Starting server",
extra={
@@ -94,13 +89,25 @@
},
)

# TODO: Jakub Implement actual server logic here
logger.info("Server started successfully")
app = init_app()

import uvicorn

uvicorn.run(
app,
host=cfg.host,
port=cfg.port,
log_level=cfg.log_level.value.lower(),
log_config=None,  # Keep the logging configuration set up above
)

except KeyboardInterrupt:
logger.info("Shutting down server")
except ConfigurationError as e:
click.echo(f"Configuration error: {e}", err=True)
sys.exit(1)
except Exception as e:
logger.exception("Unexpected error occurred")
click.echo(f"Error: {e}", err=True)
sys.exit(1)

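Note: the serve command now builds the app via init_app from .server, but server.py itself is not part of this diff excerpt. A minimal sketch of what that factory could look like, assuming it simply mounts each provider's router through BaseProvider.get_routes() (module layout, route prefixes, and any use of ProviderRegistry are assumptions, not shown here):

# Hypothetical sketch of src/codegate/server.py -- not part of this diff.
from fastapi import FastAPI

from .providers import AnthropicProvider, OpenAIProvider


def init_app() -> FastAPI:
    """Assemble the FastAPI app and mount one router per provider."""
    app = FastAPI(title="codegate")

    # Prefixes are illustrative; the real server may register providers
    # through ProviderRegistry instead of mounting them directly.
    app.include_router(OpenAIProvider().get_routes(), prefix="/openai")
    app.include_router(AnthropicProvider().get_routes(), prefix="/anthropic")
    return app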
11 changes: 11 additions & 0 deletions src/codegate/providers/__init__.py
@@ -0,0 +1,11 @@
from .anthropic.provider import AnthropicProvider
from .base import BaseProvider
from .openai.provider import OpenAIProvider
from .registry import ProviderRegistry

__all__ = [
"BaseProvider",
"ProviderRegistry",
"OpenAIProvider",
"AnthropicProvider",
]
Empty file.
46 changes: 46 additions & 0 deletions src/codegate/providers/anthropic/adapter.py
@@ -0,0 +1,46 @@
from typing import Any, Dict, Optional

from litellm import AdapterCompletionStreamWrapper, ChatCompletionRequest, ModelResponse
from litellm.adapters.anthropic_adapter import (
AnthropicAdapter as LitellmAnthropicAdapter,
)
from litellm.types.llms.anthropic import AnthropicResponse

from ..base import StreamGenerator
from ..litellmshim import anthropic_stream_generator
from ..litellmshim.litellmshim import BaseAdapter


class AnthropicAdapter(BaseAdapter):
"""
    LiteLLM's adapter class interface is used to translate between the Anthropic data
    format and the underlying model. LiteLLM's own AnthropicAdapter contains the actual
    implementation of the interface methods; this class just forwards the calls to it.
"""
def __init__(self, stream_generator: StreamGenerator = anthropic_stream_generator):
self.litellm_anthropic_adapter = LitellmAnthropicAdapter()
super().__init__(stream_generator)

def translate_completion_input_params(
self,
completion_request: Dict,
) -> Optional[ChatCompletionRequest]:
return self.litellm_anthropic_adapter.translate_completion_input_params(
completion_request
)

def translate_completion_output_params(
self, response: ModelResponse
) -> Optional[AnthropicResponse]:
return self.litellm_anthropic_adapter.translate_completion_output_params(
response
)

def translate_completion_output_params_streaming(
self, completion_stream: Any
) -> AdapterCompletionStreamWrapper | None:
return (
self.litellm_anthropic_adapter.translate_completion_output_params_streaming(
completion_stream
)
)
34 changes: 34 additions & 0 deletions src/codegate/providers/anthropic/provider.py
@@ -0,0 +1,34 @@
import json

from fastapi import Header, HTTPException, Request

from ..base import BaseProvider
from ..litellmshim.litellmshim import LiteLLmShim
from .adapter import AnthropicAdapter


class AnthropicProvider(BaseProvider):
def __init__(self):
adapter = AnthropicAdapter()
completion_handler = LiteLLmShim(adapter)
super().__init__(completion_handler)

def _setup_routes(self):
"""
Sets up the /messages route for the provider as expected by the Anthropic
API. Extracts the API key from the "x-api-key" header and passes it to the
completion handler.
"""
@self.router.post("/messages")
async def create_message(
request: Request,
x_api_key: str = Header(None),
):
            if not x_api_key:
raise HTTPException(status_code=401, detail="No API key provided")

body = await request.body()
data = json.loads(body)

stream = await self.complete(data, x_api_key)
return self._completion_handler.create_streaming_response(stream)
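For reference, a client-side sketch of exercising the new /messages route with httpx; the host, port, and mount prefix are assumptions, since this diff does not show where the router is mounted:

# Hypothetical client call -- base URL and API key are placeholders.
import asyncio

import httpx


async def main() -> None:
    payload = {
        "model": "claude-3-5-sonnet-20241022",
        "max_tokens": 256,
        "messages": [{"role": "user", "content": "Hello"}],
    }
    async with httpx.AsyncClient() as client:
        async with client.stream(
            "POST",
            "http://localhost:8989/messages",
            json=payload,
            headers={"x-api-key": "sk-ant-..."},
        ) as response:
            # The provider streams Anthropic-style SSE events back.
            async for line in response.aiter_lines():
                if line:
                    print(line)


asyncio.run(main())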
46 changes: 46 additions & 0 deletions src/codegate/providers/base.py
@@ -0,0 +1,46 @@
from abc import ABC, abstractmethod
from typing import Any, AsyncIterator, Callable, Dict

from fastapi import APIRouter
from fastapi.responses import StreamingResponse

StreamGenerator = Callable[[AsyncIterator[Any]], AsyncIterator[str]]


class BaseCompletionHandler(ABC):
"""
The completion handler is responsible for executing the completion request
and creating the streaming response.
"""

@abstractmethod
async def complete(self, data: Dict, api_key: str) -> AsyncIterator[Any]:
pass

@abstractmethod
def create_streaming_response(
self, stream: AsyncIterator[Any]
) -> StreamingResponse:
pass


class BaseProvider(ABC):
"""
The provider class is responsible for defining the API routes and
calling the completion method using the completion handler.
"""

def __init__(self, completion_handler: BaseCompletionHandler):
self.router = APIRouter()
self._completion_handler = completion_handler
self._setup_routes()

@abstractmethod
def _setup_routes(self) -> None:
pass

async def complete(self, data: Dict, api_key: str) -> AsyncIterator[Any]:
return await self._completion_handler.complete(data, api_key)

def get_routes(self) -> APIRouter:
return self.router
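To make the contract concrete, a minimal sketch of a provider built on these base classes; the echo handler and the /echo route are invented for illustration and are not part of this PR:

# Illustration of the BaseProvider / BaseCompletionHandler contract.
from typing import Any, AsyncIterator, Dict

from fastapi import Request
from fastapi.responses import StreamingResponse

from codegate.providers.base import BaseCompletionHandler, BaseProvider


class EchoCompletionHandler(BaseCompletionHandler):
    async def complete(self, data: Dict, api_key: str) -> AsyncIterator[Any]:
        async def stream() -> AsyncIterator[Any]:
            yield data  # echo the request back as a single chunk
        return stream()

    def create_streaming_response(
        self, stream: AsyncIterator[Any]
    ) -> StreamingResponse:
        async def body() -> AsyncIterator[str]:
            async for chunk in stream:
                yield f"data: {chunk}\n\n"
        return StreamingResponse(body(), media_type="text/event-stream")


class EchoProvider(BaseProvider):
    def __init__(self):
        super().__init__(EchoCompletionHandler())

    def _setup_routes(self) -> None:
        @self.router.post("/echo")
        async def echo(request: Request):
            data = await request.json()
            stream = await self.complete(data, api_key="")
            return self._completion_handler.create_streaming_response(stream)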
10 changes: 10 additions & 0 deletions src/codegate/providers/litellmshim/__init__.py
@@ -0,0 +1,10 @@
from .adapter import BaseAdapter
from .generators import anthropic_stream_generator, sse_stream_generator
from .litellmshim import LiteLLmShim

__all__ = [
"sse_stream_generator",
"anthropic_stream_generator",
"LiteLLmShim",
"BaseAdapter",
]
44 changes: 44 additions & 0 deletions src/codegate/providers/litellmshim/adapter.py
@@ -0,0 +1,44 @@
from abc import ABC, abstractmethod
from typing import Any, Dict, Optional

from litellm import ChatCompletionRequest, ModelResponse

from codegate.providers.base import StreamGenerator


class BaseAdapter(ABC):
"""
The adapter class is responsible for translating input and output
parameters between the provider-specific on-the-wire API and the
underlying model. We use LiteLLM's ChatCompletionRequest and ModelResponse
    as our data model.

The methods in this class implement LiteLLM's Adapter interface and are
not our own. This is to allow us to use LiteLLM's adapter classes as a
drop-in replacement for our own adapters.
"""

def __init__(self, stream_generator: StreamGenerator):
self.stream_generator = stream_generator

@abstractmethod
def translate_completion_input_params(
self, kwargs: Dict
) -> Optional[ChatCompletionRequest]:
"""Convert input parameters to LiteLLM's ChatCompletionRequest format"""
pass

@abstractmethod
def translate_completion_output_params(self, response: ModelResponse) -> Any:
"""Convert non-streaming response from LiteLLM ModelResponse format"""
pass

@abstractmethod
def translate_completion_output_params_streaming(
self, completion_stream: Any
) -> Any:
"""
Convert streaming response from LiteLLM format to a format that
can be passed to a stream generator and to the client.
"""
pass
34 changes: 34 additions & 0 deletions src/codegate/providers/litellmshim/generators.py
@@ -0,0 +1,34 @@
import json
from typing import Any, AsyncIterator

# Since different providers typically use one of these formats for streaming
# responses, we have a single stream generator for each format that is then plugged
# into the adapter.

async def sse_stream_generator(stream: AsyncIterator[Any]) -> AsyncIterator[str]:
"""OpenAI-style SSE format"""
try:
async for chunk in stream:
if hasattr(chunk, "model_dump_json"):
chunk = chunk.model_dump_json(exclude_none=True, exclude_unset=True)
try:
yield f"data:{chunk}\n\n"
except Exception as e:
yield f"data:{str(e)}\n\n"
except Exception as e:
yield f"data: {str(e)}\n\n"
finally:
yield "data: [DONE]\n\n"


async def anthropic_stream_generator(stream: AsyncIterator[Any]) -> AsyncIterator[str]:
"""Anthropic-style SSE format"""
try:
async for chunk in stream:
event_type = chunk.get("type")
try:
yield f"event: {event_type}\ndata:{json.dumps(chunk)}\n\n"
except Exception as e:
yield f"event: {event_type}\ndata:{str(e)}\n\n"
except Exception as e:
yield f"data: {str(e)}\n\n"
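A small self-contained demo of what sse_stream_generator emits, using a fake async iterator of plain dicts (real chunks come from LiteLLM and expose model_dump_json):

# Demo only: feed sse_stream_generator a fake stream of dict chunks.
import asyncio
from typing import Any, AsyncIterator

from codegate.providers.litellmshim import sse_stream_generator


async def fake_stream() -> AsyncIterator[Any]:
    yield {"choices": [{"delta": {"content": "Hel"}}]}
    yield {"choices": [{"delta": {"content": "lo"}}]}


async def main() -> None:
    async for line in sse_stream_generator(fake_stream()):
        print(repr(line))  # "data:{...}\n\n" chunks, then "data: [DONE]\n\n"


asyncio.run(main())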
51 changes: 51 additions & 0 deletions src/codegate/providers/litellmshim/litellmshim.py
@@ -0,0 +1,51 @@
from typing import Any, AsyncIterator, Dict

from fastapi.responses import StreamingResponse
from litellm import ModelResponse, acompletion

from ..base import BaseCompletionHandler
from .adapter import BaseAdapter


class LiteLLmShim(BaseCompletionHandler):
"""
LiteLLM Shim is a wrapper around LiteLLM's API that allows us to use it with
our own completion handler interface without exposing the underlying
LiteLLM API.
"""
def __init__(self, adapter: BaseAdapter):
self._adapter = adapter

async def complete(self, data: Dict, api_key: str) -> AsyncIterator[Any]:
"""
Translate the input parameters to LiteLLM's format using the adapter and
call the LiteLLM API. Then translate the response back to our format using
the adapter.
"""
data["api_key"] = api_key
completion_request = self._adapter.translate_completion_input_params(data)
if completion_request is None:
raise Exception("Couldn't translate the request")

response = await acompletion(**completion_request)

if isinstance(response, ModelResponse):
return self._adapter.translate_completion_output_params(response)
return self._adapter.translate_completion_output_params_streaming(response)

def create_streaming_response(
self, stream: AsyncIterator[Any]
) -> StreamingResponse:
"""
Create a streaming response from a stream generator. The StreamingResponse
is the format that FastAPI expects for streaming responses.
"""
return StreamingResponse(
self._adapter.stream_generator(stream),
headers={
"Cache-Control": "no-cache",
"Connection": "keep-alive",
"Transfer-Encoding": "chunked",
},
status_code=200,
)
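A hedged usage sketch of the shim end to end: with "stream": True in the request, acompletion returns a stream, so complete() hands back the streaming translation, ready to be wrapped for FastAPI. The model name and API key are placeholders, and running this would make a real API call:

# Sketch: wiring LiteLLmShim with the pass-through OpenAI adapter.
import asyncio

from codegate.providers.litellmshim import LiteLLmShim
from codegate.providers.openai.adapter import OpenAIAdapter


async def main() -> None:
    shim = LiteLLmShim(OpenAIAdapter())
    data = {
        "model": "gpt-4o-mini",  # placeholder model
        "messages": [{"role": "user", "content": "Hi"}],
        "stream": True,  # acompletion returns a stream when this is set
    }
    stream = await shim.complete(data, api_key="sk-...")  # placeholder key
    response = shim.create_streaming_response(stream)
    print(response.status_code)  # a FastAPI StreamingResponse, ready to return


asyncio.run(main())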
Empty file.
33 changes: 33 additions & 0 deletions src/codegate/providers/openai/adapter.py
@@ -0,0 +1,33 @@
from typing import Any, AsyncIterator, Dict, Optional

from litellm import ChatCompletionRequest, ModelResponse

from ..base import StreamGenerator
from ..litellmshim import sse_stream_generator
from ..litellmshim.litellmshim import BaseAdapter


class OpenAIAdapter(BaseAdapter):
"""
    This is just a wrapper around LiteLLM's adapter class interface that passes
    the input and output through as-is, since LiteLLM's API already expects
    OpenAI's API format.
"""
def __init__(self, stream_generator: StreamGenerator = sse_stream_generator):
super().__init__(stream_generator)

def translate_completion_input_params(
self, kwargs: Dict
) -> Optional[ChatCompletionRequest]:
try:
return ChatCompletionRequest(**kwargs)
except Exception as e:
raise ValueError(f"Invalid completion parameters: {str(e)}")

def translate_completion_output_params(self, response: ModelResponse) -> Any:
return response

def translate_completion_output_params_streaming(
self, completion_stream: AsyncIterator[ModelResponse]
) -> AsyncIterator[ModelResponse]:
return completion_stream
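Because LiteLLM already speaks the OpenAI format, the translation here is effectively a repackaging step; a quick sketch (the request body is invented for illustration):

# The OpenAI adapter passes requests and responses through unchanged.
from codegate.providers.openai.adapter import OpenAIAdapter

adapter = OpenAIAdapter()
request = adapter.translate_completion_input_params(
    {"model": "gpt-4o-mini", "messages": [{"role": "user", "content": "Hi"}]}
)
print(request)  # the same mapping, typed as LiteLLM's ChatCompletionRequest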