From 316c4af696f2d5bc594737125e26ba640bf27bae Mon Sep 17 00:00:00 2001 From: Jakub Hrozek Date: Fri, 22 Nov 2024 11:22:26 +0100 Subject: [PATCH] Implement provider interface and OpenAI and Anthropic providers Implements an interface for building CodeGate providers. For now we use LiteLLM, but we might switch in the future, so the LiteLLM calls are abstracted away in the `litellmshim` module in two classes: - `BaseAdapter` which provides means for reusing LiteLLM adapters with the same interface. - `LiteLLmShim` that actually calls into liteLLM's completion and calls the adapter before completion to convert into liteLLM's format and then back after completion Using those interfaces, implements an OpenAI and an Anthropic provider. With this patch, codegate allows to pass through requests towards OpenAI and Anthropic. Next, we'll build a pipeline interface to modify the inputs and outputs. --- pyproject.toml | 1 + src/codegate/__init__.py | 2 + src/codegate/cli.py | 23 ++++++--- src/codegate/providers/__init__.py | 11 ++++ src/codegate/providers/anthropic/__init__.py | 0 src/codegate/providers/anthropic/adapter.py | 46 +++++++++++++++++ src/codegate/providers/anthropic/provider.py | 34 +++++++++++++ src/codegate/providers/base.py | 46 +++++++++++++++++ .../providers/litellmshim/__init__.py | 10 ++++ src/codegate/providers/litellmshim/adapter.py | 44 ++++++++++++++++ .../providers/litellmshim/generators.py | 34 +++++++++++++ .../providers/litellmshim/litellmshim.py | 51 +++++++++++++++++++ src/codegate/providers/openai/__init__.py | 0 src/codegate/providers/openai/adapter.py | 33 ++++++++++++ src/codegate/providers/openai/provider.py | 37 ++++++++++++++ src/codegate/providers/registry.py | 25 +++++++++ src/codegate/server.py | 33 ++++++++++++ 17 files changed, 422 insertions(+), 8 deletions(-) create mode 100644 src/codegate/providers/__init__.py create mode 100644 src/codegate/providers/anthropic/__init__.py create mode 100644 src/codegate/providers/anthropic/adapter.py create mode 100644 src/codegate/providers/anthropic/provider.py create mode 100644 src/codegate/providers/base.py create mode 100644 src/codegate/providers/litellmshim/__init__.py create mode 100644 src/codegate/providers/litellmshim/adapter.py create mode 100644 src/codegate/providers/litellmshim/generators.py create mode 100644 src/codegate/providers/litellmshim/litellmshim.py create mode 100644 src/codegate/providers/openai/__init__.py create mode 100644 src/codegate/providers/openai/adapter.py create mode 100644 src/codegate/providers/openai/provider.py create mode 100644 src/codegate/providers/registry.py create mode 100644 src/codegate/server.py diff --git a/pyproject.toml b/pyproject.toml index f3e794f2..1ed942bd 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -20,6 +20,7 @@ dev = [ "bandit>=1.7.10", "build>=1.0.0", "wheel>=0.40.0", + "litellm>=1.52.11", ] [build-system] diff --git a/src/codegate/__init__.py b/src/codegate/__init__.py index 8635637a..529a2354 100644 --- a/src/codegate/__init__.py +++ b/src/codegate/__init__.py @@ -4,8 +4,10 @@ try: __version__ = metadata.version("codegate") + __description__ = metadata.metadata("codegate")["Summary"] except metadata.PackageNotFoundError: # pragma: no cover __version__ = "unknown" + __description__ = "codegate" from .config import Config, ConfigurationError from .logging import setup_logging diff --git a/src/codegate/cli.py b/src/codegate/cli.py index d189832f..bc1c4d13 100644 --- a/src/codegate/cli.py +++ b/src/codegate/cli.py @@ -9,6 +9,7 @@ from .config import Config, ConfigurationError, LogFormat, LogLevel from .logging import setup_logging +from .server import init_app def validate_port(ctx: click.Context, param: click.Parameter, value: int) -> int: @@ -65,7 +66,6 @@ def serve( config: Optional[Path], ) -> None: """Start the codegate server.""" - try: # Load configuration with priority resolution cfg = Config.load( @@ -79,11 +79,6 @@ def serve( setup_logging(cfg.log_level, cfg.log_format) logger = logging.getLogger(__name__) - logger.info("This is an info message") - logger.debug("This is a debug message") - logger.error("This is an error message") - logger.warning("This is a warning message") - logger.info( "Starting server", extra={ @@ -94,13 +89,25 @@ def serve( }, ) - # TODO: Jakub Implement actual server logic here - logger.info("Server started successfully") + app = init_app() + + import uvicorn + + uvicorn.run( + app, + host=cfg.host, + port=cfg.port, + log_level=cfg.log_level.value.lower(), + log_config=None, # Default logging configuration + ) + except KeyboardInterrupt: + logger.info("Shutting down server") except ConfigurationError as e: click.echo(f"Configuration error: {e}", err=True) sys.exit(1) except Exception as e: + logger.exception("Unexpected error occurred") click.echo(f"Error: {e}", err=True) sys.exit(1) diff --git a/src/codegate/providers/__init__.py b/src/codegate/providers/__init__.py new file mode 100644 index 00000000..304eba7d --- /dev/null +++ b/src/codegate/providers/__init__.py @@ -0,0 +1,11 @@ +from .anthropic.provider import AnthropicProvider +from .base import BaseProvider +from .openai.provider import OpenAIProvider +from .registry import ProviderRegistry + +__all__ = [ + "BaseProvider", + "ProviderRegistry", + "OpenAIProvider", + "AnthropicProvider", +] diff --git a/src/codegate/providers/anthropic/__init__.py b/src/codegate/providers/anthropic/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/codegate/providers/anthropic/adapter.py b/src/codegate/providers/anthropic/adapter.py new file mode 100644 index 00000000..18a5d208 --- /dev/null +++ b/src/codegate/providers/anthropic/adapter.py @@ -0,0 +1,46 @@ +from typing import Any, Dict, Optional + +from litellm import AdapterCompletionStreamWrapper, ChatCompletionRequest, ModelResponse +from litellm.adapters.anthropic_adapter import ( + AnthropicAdapter as LitellmAnthropicAdapter, +) +from litellm.types.llms.anthropic import AnthropicResponse + +from ..base import StreamGenerator +from ..litellmshim import anthropic_stream_generator +from ..litellmshim.litellmshim import BaseAdapter + + +class AnthropicAdapter(BaseAdapter): + """ + LiteLLM's adapter class interface is used to translate between the Anthropic data + format and the underlying model. The AnthropicAdapter class contains the actual + implementation of the interface methods, we just forward the calls to it. + """ + def __init__(self, stream_generator: StreamGenerator = anthropic_stream_generator): + self.litellm_anthropic_adapter = LitellmAnthropicAdapter() + super().__init__(stream_generator) + + def translate_completion_input_params( + self, + completion_request: Dict, + ) -> Optional[ChatCompletionRequest]: + return self.litellm_anthropic_adapter.translate_completion_input_params( + completion_request + ) + + def translate_completion_output_params( + self, response: ModelResponse + ) -> Optional[AnthropicResponse]: + return self.litellm_anthropic_adapter.translate_completion_output_params( + response + ) + + def translate_completion_output_params_streaming( + self, completion_stream: Any + ) -> AdapterCompletionStreamWrapper | None: + return ( + self.litellm_anthropic_adapter.translate_completion_output_params_streaming( + completion_stream + ) + ) diff --git a/src/codegate/providers/anthropic/provider.py b/src/codegate/providers/anthropic/provider.py new file mode 100644 index 00000000..22e6775b --- /dev/null +++ b/src/codegate/providers/anthropic/provider.py @@ -0,0 +1,34 @@ +import json + +from fastapi import Header, HTTPException, Request + +from ..base import BaseProvider +from ..litellmshim.litellmshim import LiteLLmShim +from .adapter import AnthropicAdapter + + +class AnthropicProvider(BaseProvider): + def __init__(self): + adapter = AnthropicAdapter() + completion_handler = LiteLLmShim(adapter) + super().__init__(completion_handler) + + def _setup_routes(self): + """ + Sets up the /messages route for the provider as expected by the Anthropic + API. Extracts the API key from the "x-api-key" header and passes it to the + completion handler. + """ + @self.router.post("/messages") + async def create_message( + request: Request, + x_api_key: str = Header(None), + ): + if x_api_key == "": + raise HTTPException(status_code=401, detail="No API key provided") + + body = await request.body() + data = json.loads(body) + + stream = await self.complete(data, x_api_key) + return self._completion_handler.create_streaming_response(stream) diff --git a/src/codegate/providers/base.py b/src/codegate/providers/base.py new file mode 100644 index 00000000..bd937873 --- /dev/null +++ b/src/codegate/providers/base.py @@ -0,0 +1,46 @@ +from abc import ABC, abstractmethod +from typing import Any, AsyncIterator, Callable, Dict + +from fastapi import APIRouter +from fastapi.responses import StreamingResponse + +StreamGenerator = Callable[[AsyncIterator[Any]], AsyncIterator[str]] + + +class BaseCompletionHandler(ABC): + """ + The completion handler is responsible for executing the completion request + and creating the streaming response. + """ + + @abstractmethod + async def complete(self, data: Dict, api_key: str) -> AsyncIterator[Any]: + pass + + @abstractmethod + def create_streaming_response( + self, stream: AsyncIterator[Any] + ) -> StreamingResponse: + pass + + +class BaseProvider(ABC): + """ + The provider class is responsible for defining the API routes and + calling the completion method using the completion handler. + """ + + def __init__(self, completion_handler: BaseCompletionHandler): + self.router = APIRouter() + self._completion_handler = completion_handler + self._setup_routes() + + @abstractmethod + def _setup_routes(self) -> None: + pass + + async def complete(self, data: Dict, api_key: str) -> AsyncIterator[Any]: + return await self._completion_handler.complete(data, api_key) + + def get_routes(self) -> APIRouter: + return self.router diff --git a/src/codegate/providers/litellmshim/__init__.py b/src/codegate/providers/litellmshim/__init__.py new file mode 100644 index 00000000..ec191270 --- /dev/null +++ b/src/codegate/providers/litellmshim/__init__.py @@ -0,0 +1,10 @@ +from .adapter import BaseAdapter +from .generators import anthropic_stream_generator, sse_stream_generator +from .litellmshim import LiteLLmShim + +__all__ = [ + "sse_stream_generator", + "anthropic_stream_generator", + "LiteLLmShim", + "BaseAdapter", +] diff --git a/src/codegate/providers/litellmshim/adapter.py b/src/codegate/providers/litellmshim/adapter.py new file mode 100644 index 00000000..b1c349f0 --- /dev/null +++ b/src/codegate/providers/litellmshim/adapter.py @@ -0,0 +1,44 @@ +from abc import ABC, abstractmethod +from typing import Any, Dict, Optional + +from litellm import ChatCompletionRequest, ModelResponse + +from codegate.providers.base import StreamGenerator + + +class BaseAdapter(ABC): + """ + The adapter class is responsible for translating input and output + parameters between the provider-specific on-the-wire API and the + underlying model. We use LiteLLM's ChatCompletionRequest and ModelResponse + is our data model. + + The methods in this class implement LiteLLM's Adapter interface and are + not our own. This is to allow us to use LiteLLM's adapter classes as a + drop-in replacement for our own adapters. + """ + + def __init__(self, stream_generator: StreamGenerator): + self.stream_generator = stream_generator + + @abstractmethod + def translate_completion_input_params( + self, kwargs: Dict + ) -> Optional[ChatCompletionRequest]: + """Convert input parameters to LiteLLM's ChatCompletionRequest format""" + pass + + @abstractmethod + def translate_completion_output_params(self, response: ModelResponse) -> Any: + """Convert non-streaming response from LiteLLM ModelResponse format""" + pass + + @abstractmethod + def translate_completion_output_params_streaming( + self, completion_stream: Any + ) -> Any: + """ + Convert streaming response from LiteLLM format to a format that + can be passed to a stream generator and to the client. + """ + pass diff --git a/src/codegate/providers/litellmshim/generators.py b/src/codegate/providers/litellmshim/generators.py new file mode 100644 index 00000000..6ea57a8e --- /dev/null +++ b/src/codegate/providers/litellmshim/generators.py @@ -0,0 +1,34 @@ +import json +from typing import Any, AsyncIterator + +# Since different providers typically use one of these formats for streaming +# responses, we have a single stream generator for each format that is then plugged +# into the adapter. + +async def sse_stream_generator(stream: AsyncIterator[Any]) -> AsyncIterator[str]: + """OpenAI-style SSE format""" + try: + async for chunk in stream: + if hasattr(chunk, "model_dump_json"): + chunk = chunk.model_dump_json(exclude_none=True, exclude_unset=True) + try: + yield f"data:{chunk}\n\n" + except Exception as e: + yield f"data:{str(e)}\n\n" + except Exception as e: + yield f"data: {str(e)}\n\n" + finally: + yield "data: [DONE]\n\n" + + +async def anthropic_stream_generator(stream: AsyncIterator[Any]) -> AsyncIterator[str]: + """Anthropic-style SSE format""" + try: + async for chunk in stream: + event_type = chunk.get("type") + try: + yield f"event: {event_type}\ndata:{json.dumps(chunk)}\n\n" + except Exception as e: + yield f"event: {event_type}\ndata:{str(e)}\n\n" + except Exception as e: + yield f"data: {str(e)}\n\n" diff --git a/src/codegate/providers/litellmshim/litellmshim.py b/src/codegate/providers/litellmshim/litellmshim.py new file mode 100644 index 00000000..364757f0 --- /dev/null +++ b/src/codegate/providers/litellmshim/litellmshim.py @@ -0,0 +1,51 @@ +from typing import Any, AsyncIterator, Dict + +from fastapi.responses import StreamingResponse +from litellm import ModelResponse, acompletion + +from ..base import BaseCompletionHandler +from .adapter import BaseAdapter + + +class LiteLLmShim(BaseCompletionHandler): + """ + LiteLLM Shim is a wrapper around LiteLLM's API that allows us to use it with + our own completion handler interface without exposing the underlying + LiteLLM API. + """ + def __init__(self, adapter: BaseAdapter): + self._adapter = adapter + + async def complete(self, data: Dict, api_key: str) -> AsyncIterator[Any]: + """ + Translate the input parameters to LiteLLM's format using the adapter and + call the LiteLLM API. Then translate the response back to our format using + the adapter. + """ + data["api_key"] = api_key + completion_request = self._adapter.translate_completion_input_params(data) + if completion_request is None: + raise Exception("Couldn't translate the request") + + response = await acompletion(**completion_request) + + if isinstance(response, ModelResponse): + return self._adapter.translate_completion_output_params(response) + return self._adapter.translate_completion_output_params_streaming(response) + + def create_streaming_response( + self, stream: AsyncIterator[Any] + ) -> StreamingResponse: + """ + Create a streaming response from a stream generator. The StreamingResponse + is the format that FastAPI expects for streaming responses. + """ + return StreamingResponse( + self._adapter.stream_generator(stream), + headers={ + "Cache-Control": "no-cache", + "Connection": "keep-alive", + "Transfer-Encoding": "chunked", + }, + status_code=200, + ) diff --git a/src/codegate/providers/openai/__init__.py b/src/codegate/providers/openai/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/codegate/providers/openai/adapter.py b/src/codegate/providers/openai/adapter.py new file mode 100644 index 00000000..4661adee --- /dev/null +++ b/src/codegate/providers/openai/adapter.py @@ -0,0 +1,33 @@ +from typing import Any, AsyncIterator, Dict, Optional + +from litellm import ChatCompletionRequest, ModelResponse + +from ..base import StreamGenerator +from ..litellmshim import sse_stream_generator +from ..litellmshim.litellmshim import BaseAdapter + + +class OpenAIAdapter(BaseAdapter): + """ + This is just a wrapper around LiteLLM's adapter class interface that passes + through the input and output as-is - LiteLLM's API expects OpenAI's API + format. + """ + def __init__(self, stream_generator: StreamGenerator = sse_stream_generator): + super().__init__(stream_generator) + + def translate_completion_input_params( + self, kwargs: Dict + ) -> Optional[ChatCompletionRequest]: + try: + return ChatCompletionRequest(**kwargs) + except Exception as e: + raise ValueError(f"Invalid completion parameters: {str(e)}") + + def translate_completion_output_params(self, response: ModelResponse) -> Any: + return response + + def translate_completion_output_params_streaming( + self, completion_stream: AsyncIterator[ModelResponse] + ) -> AsyncIterator[ModelResponse]: + return completion_stream diff --git a/src/codegate/providers/openai/provider.py b/src/codegate/providers/openai/provider.py new file mode 100644 index 00000000..210ec40d --- /dev/null +++ b/src/codegate/providers/openai/provider.py @@ -0,0 +1,37 @@ +import json + +from fastapi import Header, HTTPException, Request + +from ..base import BaseProvider +from ..litellmshim.litellmshim import LiteLLmShim +from .adapter import OpenAIAdapter + + +class OpenAIProvider(BaseProvider): + def __init__(self): + adapter = OpenAIAdapter() + completion_handler = LiteLLmShim(adapter) + super().__init__(completion_handler) + + def _setup_routes(self): + """ + Sets up the /chat/completions route for the provider as expected by the + OpenAI API. Extracts the API key from the "Authorization" header and + passes it to the completion handler. + """ + @self.router.post("/chat/completions") + async def create_completion( + request: Request, + authorization: str = Header(..., description="Bearer token"), + ): + if not authorization.startswith("Bearer "): + raise HTTPException( + status_code=401, detail="Invalid authorization header" + ) + + api_key = authorization.split(" ")[1] + body = await request.body() + data = json.loads(body) + + stream = await self.complete(data, api_key) + return self._completion_handler.create_streaming_response(stream) diff --git a/src/codegate/providers/registry.py b/src/codegate/providers/registry.py new file mode 100644 index 00000000..b42a6f51 --- /dev/null +++ b/src/codegate/providers/registry.py @@ -0,0 +1,25 @@ +from typing import Dict, Optional + +from fastapi import FastAPI + +from .base import BaseProvider + + +class ProviderRegistry: + def __init__(self, app: FastAPI): + self.app = app + self.providers: Dict[str, BaseProvider] = {} + + def add_provider(self, name: str, provider: BaseProvider): + """ + Adds a provider to the registry. This will also add the provider's routes + to the FastAPI app. + """ + self.providers[name] = provider + self.app.include_router(provider.get_routes()) + + def get_provider(self, name: str) -> Optional[BaseProvider]: + """ + Retrieves a provider from the registry by name. + """ + return self.providers.get(name) diff --git a/src/codegate/server.py b/src/codegate/server.py new file mode 100644 index 00000000..a9c203e2 --- /dev/null +++ b/src/codegate/server.py @@ -0,0 +1,33 @@ +from fastapi import APIRouter, FastAPI + +from . import __description__, __version__ +from .providers.anthropic.provider import AnthropicProvider +from .providers.openai.provider import OpenAIProvider +from .providers.registry import ProviderRegistry + + +def init_app() -> FastAPI: + app = FastAPI( + title="CodeGate", + description=__description__, + version=__version__, + ) + + # Create provider registry + registry = ProviderRegistry(app) + + # Register all known providers + registry.add_provider("openai", OpenAIProvider()) + registry.add_provider("anthropic", AnthropicProvider()) + + # Create and add system routes + system_router = APIRouter(tags=["System"]) # Tags group endpoints in the docs + + @system_router.get("/health") + async def health_check(): + return {"status": "healthy"} + + # Include the router in the app - this exposes the health check endpoint + app.include_router(system_router) + + return app