diff --git a/config.yaml.example b/config.yaml.example index dcf3db30..33e83f09 100644 --- a/config.yaml.example +++ b/config.yaml.example @@ -1,19 +1,42 @@ -# Example configuration file -# Copy this file to config.yaml and modify as needed +# Codegate Example Configuration -# Server configuration -port: 8989 -host: "localhost" +# Network settings +port: 8989 # Port to listen on (1-65535) +host: "localhost" # Host to bind to (use "0.0.0.0" to listen on all interfaces) # Logging configuration -log_level: "INFO" # ERROR, WARNING, INFO, DEBUG -log_format: "JSON" # JSON, TEXT +log_level: "INFO" # One of: ERROR, WARNING, INFO, DEBUG -# Prompts configuration -# Option 1: Define prompts directly in the config file -prompts: - my_system_prompt: "Custom system prompt defined in config" - another_prompt: "Another custom prompt" +# Note: This configuration can be overridden by: +# 1. CLI arguments (--port, --host, --log-level) +# 2. Environment variables (CODEGATE_APP_PORT, CODEGATE_APP_HOST, CODEGATE_APP_LOG_LEVEL) -# Option 2: Reference a separate prompts file -# prompts: "prompts.yaml" # Path to prompts file (relative to config file or absolute) +# Provider URLs +provider_urls: + openai: "https://api.openai.com/v1" + anthropic: "https://api.anthropic.com/v1" + vllm: "http://localhost:8000" # Base URL without the /v1 path; it is added automatically + +# Note: Provider URLs can be overridden by environment variables: +# CODEGATE_PROVIDER_OPENAI_URL +# CODEGATE_PROVIDER_ANTHROPIC_URL +# CODEGATE_PROVIDER_VLLM_URL +# Or by CLI flags: +# --vllm-url +# --openai-url +# --anthropic-url + +# Embedding model configuration + +#### +# Inference model configuration +## + +# Model to use for chatting +chat_model_path: "./models/qwen2.5-coder-1.5b-instruct-q5_k_m.gguf" + +# Context length of the model +chat_model_n_ctx: 32768 + +# Number of layers to offload to GPU. If -1, all layers are offloaded.
+chat_model_n_gpu_layers: -1 diff --git a/docs/cli.md b/docs/cli.md index 74684e31..9142d405 100644 --- a/docs/cli.md +++ b/docs/cli.md @@ -46,6 +46,21 @@ codegate serve [OPTIONS] - Must be a valid YAML file - Overrides default prompts and configuration file prompts +- `--vllm-url TEXT`: vLLM provider URL (default: http://localhost:8000) + - Optional + - Base URL for vLLM provider (/v1 path is added automatically) + - Overrides configuration file and environment variables + +- `--openai-url TEXT`: OpenAI provider URL (default: https://api.openai.com/v1) + - Optional + - Base URL for OpenAI provider + - Overrides configuration file and environment variables + +- `--anthropic-url TEXT`: Anthropic provider URL (default: https://api.anthropic.com/v1) + - Optional + - Base URL for Anthropic provider + - Overrides configuration file and environment variables + ### show-prompts Display the loaded system prompts: @@ -100,6 +115,11 @@ Start server with custom prompts: codegate serve --prompts my-prompts.yaml ``` +Start server with custom vLLM endpoint: +```bash +codegate serve --vllm-url https://vllm.example.com +``` + Show default system prompts: ```bash codegate show-prompts diff --git a/docs/configuration.md b/docs/configuration.md index a54796ad..6a015d28 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -16,6 +16,10 @@ The configuration system in Codegate is managed through the `Config` class in `c - Log Level: "INFO" - Log Format: "JSON" - Prompts: Default prompts from prompts/default.yaml +- Provider URLs: + - vLLM: "http://localhost:8000" + - OpenAI: "https://api.openai.com/v1" + - Anthropic: "https://api.anthropic.com/v1" ## Configuration Methods @@ -27,6 +31,18 @@ Load configuration from a YAML file: config = Config.from_file("config.yaml") ``` +Example config.yaml: +```yaml +port: 8989 +host: localhost +log_level: INFO +log_format: JSON +provider_urls: + vllm: "https://vllm.example.com" + openai: "https://api.openai.com/v1" + anthropic: "https://api.anthropic.com/v1" +``` + ### From Environment Variables Environment variables are automatically loaded with these mappings: @@ -36,6 +52,9 @@ Environment variables are automatically loaded with these mappings: - `CODEGATE_APP_LOG_LEVEL`: Logging level - `CODEGATE_LOG_FORMAT`: Log format - `CODEGATE_PROMPTS_FILE`: Path to prompts YAML file +- `CODEGATE_PROVIDER_VLLM_URL`: vLLM provider URL +- `CODEGATE_PROVIDER_OPENAI_URL`: OpenAI provider URL +- `CODEGATE_PROVIDER_ANTHROPIC_URL`: Anthropic provider URL ```python config = Config.from_env() @@ -43,6 +62,32 @@ config = Config.from_env() ## Configuration Options +### Provider URLs + +Provider URLs can be configured in several ways: + +1. In Configuration File: + ```yaml + provider_urls: + vllm: "https://vllm.example.com" # /v1 path is added automatically + openai: "https://api.openai.com/v1" + anthropic: "https://api.anthropic.com/v1" + ``` + +2. Via Environment Variables: + ```bash + export CODEGATE_PROVIDER_VLLM_URL=https://vllm.example.com + export CODEGATE_PROVIDER_OPENAI_URL=https://api.openai.com/v1 + export CODEGATE_PROVIDER_ANTHROPIC_URL=https://api.anthropic.com/v1 + ``` + +3. Via CLI Flags: + ```bash + codegate serve --vllm-url https://vllm.example.com + ``` + +Note: For the vLLM provider, the /v1 path is automatically appended to the base URL if not present. 
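For illustration, a minimal Python sketch of the precedence described above, using `Config.load` from `src/codegate/config.py`. The config path and override URL are hypothetical, and only the vLLM entry is overridden here; other providers keep whatever the file, environment, or defaults supply.

```python
from codegate.config import Config

# CLI-provided URLs are applied last, so they win over environment
# variables and the config file; providers not overridden keep their
# file/env/default values.
cfg = Config.load(
    config_path="config.yaml",  # hypothetical file; may also be omitted
    cli_provider_urls={"vllm": "https://vllm.example.com"},
)

print(cfg.provider_urls["vllm"])       # https://vllm.example.com
print(cfg.provider_urls["anthropic"])  # default, unless overridden elsewhere
```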
+ ### Log Levels Available log levels (case-insensitive): diff --git a/docs/development.md b/docs/development.md index 1f954939..ef66219b 100644 --- a/docs/development.md +++ b/docs/development.md @@ -9,6 +9,7 @@ Codegate is a configurable Generative AI gateway designed to protect developers - Secure coding recommendations - Prevention of AI recommending deprecated/malicious libraries - Modular system prompts configuration +- Multiple AI provider support with configurable endpoints ## Development Setup @@ -53,7 +54,11 @@ codegate/ │ ├── logging.py # Logging setup │ ├── prompts.py # Prompts management │ ├── server.py # Main server implementation -│ └── providers/* # External service providers (anthropic, openai, etc.) +│ └── providers/ # External service providers +│ ├── anthropic/ # Anthropic provider implementation +│ ├── openai/ # OpenAI provider implementation +│ ├── vllm/ # vLLM provider implementation +│ └── base.py # Base provider interface ├── tests/ # Test files └── docs/ # Documentation ``` @@ -128,9 +133,87 @@ Codegate uses a hierarchical configuration system with the following priority (h - Log Level: Logging level (ERROR|WARNING|INFO|DEBUG) - Log Format: Log format (JSON|TEXT) - Prompts: System prompts configuration +- Provider URLs: AI provider endpoint configuration See [Configuration Documentation](configuration.md) for detailed information. +## Working with Providers + +Codegate supports multiple AI providers through a modular provider system. + +### Available Providers + +1. **vLLM Provider** + - Default URL: http://localhost:8000 + - Supports OpenAI-compatible API + - Automatically adds /v1 path to base URL + - Model names are prefixed with "hosted_vllm/" + +2. **OpenAI Provider** + - Default URL: https://api.openai.com/v1 + - Standard OpenAI API implementation + +3. **Anthropic Provider** + - Default URL: https://api.anthropic.com/v1 + - Anthropic Claude API implementation + +### Configuring Providers + +Provider URLs can be configured through: + +1. Config file (config.yaml): + ```yaml + provider_urls: + vllm: "https://vllm.example.com" + openai: "https://api.openai.com/v1" + anthropic: "https://api.anthropic.com/v1" + ``` + +2. Environment variables: + ```bash + export CODEGATE_PROVIDER_VLLM_URL=https://vllm.example.com + export CODEGATE_PROVIDER_OPENAI_URL=https://api.openai.com/v1 + export CODEGATE_PROVIDER_ANTHROPIC_URL=https://api.anthropic.com/v1 + ``` + +3. CLI flags: + ```bash + codegate serve --vllm-url https://vllm.example.com + ``` + +### Implementing New Providers + +To add a new provider: + +1. Create a new directory in `src/codegate/providers/` +2. 
Implement required components: + - `provider.py`: Main provider class extending BaseProvider + - `adapter.py`: Input/output normalizers + - `__init__.py`: Export provider class + +Example structure: +```python +from codegate.providers.base import BaseProvider + +class NewProvider(BaseProvider): + def __init__(self, ...): + super().__init__( + InputNormalizer(), + OutputNormalizer(), + completion_handler, + pipeline_processor, + fim_pipeline_processor + ) + + @property + def provider_route_name(self) -> str: + return "provider_name" + + def _setup_routes(self): + # Implement route setup + pass +``` + ## Working with Prompts ### Default Prompts @@ -188,8 +271,9 @@ codegate serve --port 8989 --host localhost --log-level DEBUG # Start with custom prompts codegate serve --prompts my-prompts.yaml + +# Start with custom provider URL +codegate serve --vllm-url https://vllm.example.com ``` See [CLI Documentation](cli.md) for detailed command information. - -[Rest of development.md content remains unchanged...] diff --git a/src/codegate/cli.py b/src/codegate/cli.py index 8688947b..b6911bfe 100644 --- a/src/codegate/cli.py +++ b/src/codegate/cli.py @@ -2,7 +2,7 @@ import sys from pathlib import Path -from typing import Optional +from typing import Dict, Optional import click @@ -88,6 +88,24 @@ def show_prompts(prompts: Optional[Path]) -> None: default=None, help="Path to YAML prompts file", ) +@click.option( + "--vllm-url", + type=str, + default=None, + help="vLLM provider URL (default: http://localhost:8000/v1)", +) +@click.option( + "--openai-url", + type=str, + default=None, + help="OpenAI provider URL (default: https://api.openai.com/v1)", +) +@click.option( + "--anthropic-url", + type=str, + default=None, + help="Anthropic provider URL (default: https://api.anthropic.com/v1)", +) def serve( port: Optional[int], host: Optional[str], @@ -95,10 +113,22 @@ def serve( log_format: Optional[str], config: Optional[Path], prompts: Optional[Path], + vllm_url: Optional[str], + openai_url: Optional[str], + anthropic_url: Optional[str], ) -> None: """Start the codegate server.""" logger = None try: + # Create provider URLs dict from CLI options + cli_provider_urls: Dict[str, str] = {} + if vllm_url: + cli_provider_urls["vllm"] = vllm_url + if openai_url: + cli_provider_urls["openai"] = openai_url + if anthropic_url: + cli_provider_urls["anthropic"] = anthropic_url + # Load configuration with priority resolution cfg = Config.load( config_path=config, @@ -107,6 +137,7 @@ def serve( cli_host=host, cli_log_level=log_level, cli_log_format=log_format, + cli_provider_urls=cli_provider_urls, ) logger = setup_logging(cfg.log_level, cfg.log_format) @@ -118,6 +149,7 @@ def serve( "log_level": cfg.log_level.value, "log_format": cfg.log_format.value, "prompts_loaded": len(cfg.prompts.prompts), + "provider_urls": cfg.provider_urls, }, ) diff --git a/src/codegate/config.py b/src/codegate/config.py index e63e5fc1..1f304227 100644 --- a/src/codegate/config.py +++ b/src/codegate/config.py @@ -3,7 +3,7 @@ import os from dataclasses import dataclass, field from pathlib import Path -from typing import Optional, Union +from typing import Dict, Optional, Union import yaml @@ -13,6 +13,13 @@ logger = setup_logging() +# Default provider URLs +DEFAULT_PROVIDER_URLS = { + "openai": "https://api.openai.com/v1", + "anthropic": "https://api.anthropic.com/v1", + "vllm": "http://localhost:8000", # Base URL without /v1 path +} + @dataclass class Config: @@ -32,6 +39,9 @@ class Config: chat_model_n_ctx: int = 32768 chat_model_n_gpu_layers: 
int = -1 + # Provider URLs with defaults + provider_urls: Dict[str, str] = field(default_factory=lambda: DEFAULT_PROVIDER_URLS.copy()) + def __post_init__(self) -> None: """Validate configuration after initialization.""" if not isinstance(self.port, int) or not (1 <= self.port <= 65535): @@ -95,19 +105,23 @@ def from_file(cls, config_path: Union[str, Path]) -> "Config": prompts_path = Path(config_path).parent / prompts_path prompts_config = PromptConfig.from_file(prompts_path) + # Get provider URLs from config + provider_urls = DEFAULT_PROVIDER_URLS.copy() + if "provider_urls" in config_data: + provider_urls.update(config_data.pop("provider_urls")) + return cls( port=config_data.get("port", cls.port), host=config_data.get("host", cls.host), log_level=config_data.get("log_level", cls.log_level.value), log_format=config_data.get("log_format", cls.log_format.value), model_base_path=config_data.get("chat_model_path", cls.model_base_path), - chat_model_n_ctx=config_data.get( - "chat_model_n_ctx", cls.chat_model_n_ctx - ), + chat_model_n_ctx=config_data.get("chat_model_n_ctx", cls.chat_model_n_ctx), chat_model_n_gpu_layers=config_data.get( "chat_model_n_gpu_layers", cls.chat_model_n_gpu_layers ), prompts=prompts_config, + provider_urls=provider_urls, ) except yaml.YAMLError as e: raise ConfigurationError(f"Failed to parse config file: {e}") @@ -138,6 +152,12 @@ def from_env(cls) -> "Config": os.environ["CODEGATE_PROMPTS_FILE"] ) # noqa: E501 + # Load provider URLs from environment variables + for provider in DEFAULT_PROVIDER_URLS.keys(): + env_var = f"CODEGATE_PROVIDER_{provider.upper()}_URL" + if env_var in os.environ: + config.provider_urls[provider] = os.environ[env_var] + return config except ValueError as e: raise ConfigurationError(f"Invalid environment variable value: {e}") @@ -151,6 +171,7 @@ def load( cli_host: Optional[str] = None, cli_log_level: Optional[str] = None, cli_log_format: Optional[str] = None, + cli_provider_urls: Optional[Dict[str, str]] = None, ) -> "Config": """Load configuration with priority resolution. 
@@ -167,6 +188,7 @@ def load( cli_host: Optional CLI host override cli_log_level: Optional CLI log level override cli_log_format: Optional CLI log format override + cli_provider_urls: Optional dict of provider URLs from CLI Returns: Config: Resolved configuration @@ -198,6 +220,10 @@ def load( if "CODEGATE_PROMPTS_FILE" in os.environ: config.prompts = env_config.prompts + # Override provider URLs from environment + for provider, url in env_config.provider_urls.items(): + config.provider_urls[provider] = url + # Override with CLI arguments if cli_port is not None: config.port = cli_port @@ -209,6 +235,8 @@ def load( config.log_format = LogFormat(cli_log_format) if prompts_path is not None: config.prompts = PromptConfig.from_file(prompts_path) + if cli_provider_urls is not None: + config.provider_urls.update(cli_provider_urls) # Set the __config class attribute Config.__config = config diff --git a/src/codegate/pipeline/fim/secret_analyzer.py b/src/codegate/pipeline/fim/secret_analyzer.py index 83f89a11..68f38351 100644 --- a/src/codegate/pipeline/fim/secret_analyzer.py +++ b/src/codegate/pipeline/fim/secret_analyzer.py @@ -28,9 +28,7 @@ def name(self) -> str: return "fim-secret-analyzer" async def process( - self, - request: ChatCompletionRequest, - context: PipelineContext + self, request: ChatCompletionRequest, context: PipelineContext ) -> PipelineResult: # We should call here Secrets Blocking module to see if the request messages contain secrets # messages_contain_secrets = [analyze_msg_secrets(msg) for msg in request.messages] @@ -39,7 +37,7 @@ async def process( # For the moment to test shortcutting just treat all messages as if they contain secrets message_with_secrets = False if message_with_secrets: - logger.info('Blocking message with secrets.') + logger.info("Blocking message with secrets.") return PipelineResult( response=PipelineResponse( step_name=self.name, diff --git a/src/codegate/providers/__init__.py b/src/codegate/providers/__init__.py index d83fa21a..1ec69183 100644 --- a/src/codegate/providers/__init__.py +++ b/src/codegate/providers/__init__.py @@ -2,10 +2,12 @@ from codegate.providers.base import BaseProvider from codegate.providers.openai.provider import OpenAIProvider from codegate.providers.registry import ProviderRegistry +from codegate.providers.vllm.provider import VLLMProvider __all__ = [ "BaseProvider", "ProviderRegistry", "OpenAIProvider", "AnthropicProvider", + "VLLMProvider", ] diff --git a/src/codegate/providers/anthropic/completion_handler.py b/src/codegate/providers/anthropic/completion_handler.py index ed45eb40..253e2970 100644 --- a/src/codegate/providers/anthropic/completion_handler.py +++ b/src/codegate/providers/anthropic/completion_handler.py @@ -27,7 +27,7 @@ async def execute_completion( For more details, refer to the [LiteLLM Documentation](https://docs.litellm.ai/docs/providers/anthropic). 
""" - model_in_request = request['model'] - if not model_in_request.startswith('anthropic/'): - request['model'] = f'anthropic/{model_in_request}' + model_in_request = request["model"] + if not model_in_request.startswith("anthropic/"): + request["model"] = f"anthropic/{model_in_request}" return await super().execute_completion(request, api_key, stream) diff --git a/src/codegate/providers/anthropic/provider.py b/src/codegate/providers/anthropic/provider.py index e0042ace..a035f7aa 100644 --- a/src/codegate/providers/anthropic/provider.py +++ b/src/codegate/providers/anthropic/provider.py @@ -11,17 +11,17 @@ class AnthropicProvider(BaseProvider): def __init__( - self, - pipeline_processor: Optional[SequentialPipelineProcessor] = None, - fim_pipeline_processor: Optional[SequentialPipelineProcessor] = None - ): + self, + pipeline_processor: Optional[SequentialPipelineProcessor] = None, + fim_pipeline_processor: Optional[SequentialPipelineProcessor] = None, + ): completion_handler = AnthropicCompletion(stream_generator=anthropic_stream_generator) super().__init__( AnthropicInputNormalizer(), AnthropicOutputNormalizer(), completion_handler, pipeline_processor, - fim_pipeline_processor + fim_pipeline_processor, ) @property diff --git a/src/codegate/providers/base.py b/src/codegate/providers/base.py index 4f2d65e0..a350d0a9 100644 --- a/src/codegate/providers/base.py +++ b/src/codegate/providers/base.py @@ -50,17 +50,15 @@ def provider_route_name(self) -> str: pass async def _run_input_pipeline( - self, - normalized_request: ChatCompletionRequest, - is_fim_request: bool + self, normalized_request: ChatCompletionRequest, is_fim_request: bool ) -> PipelineResult: # Decide which pipeline processor to use if is_fim_request: pipeline_processor = self._fim_pipelin_processor - logger.info('FIM pipeline selected for execution.') + logger.info("FIM pipeline selected for execution.") else: pipeline_processor = self._pipeline_processor - logger.info('Chat completion pipeline selected for execution.') + logger.info("Chat completion pipeline selected for execution.") if pipeline_processor is None: return PipelineResult(request=normalized_request) @@ -92,21 +90,21 @@ def _is_fim_request_body(self, data: Dict) -> bool: Determine from the raw incoming data if it's a FIM request. 
Used by: OpenAI and Anthropic """ - messages = data.get('messages', []) + messages = data.get("messages", []) if not messages: return False - first_message_content = messages[0].get('content') + first_message_content = messages[0].get("content") if first_message_content is None: return False - fim_stop_sequences = ['', '', '', ''] + fim_stop_sequences = ["", "", "", ""] if isinstance(first_message_content, str): msg_prompt = first_message_content elif isinstance(first_message_content, list): - msg_prompt = first_message_content[0].get('text', '') + msg_prompt = first_message_content[0].get("text", "") else: - logger.warning(f'Could not determine if message was FIM from data: {data}') + logger.warning(f"Could not determine if message was FIM from data: {data}") return False return all([stop_sequence in msg_prompt for stop_sequence in fim_stop_sequences]) @@ -121,7 +119,7 @@ def _is_fim_request(self, request: Request, data: Dict) -> bool: return self._is_fim_request_body(data) async def complete( - self, data: Dict, api_key: Optional[str], is_fim_request: bool + self, data: Dict, api_key: Optional[str], is_fim_request: bool ) -> Union[ModelResponse, AsyncIterator[ModelResponse]]: """ Main completion flow with pipeline integration diff --git a/src/codegate/providers/llamacpp/completion_handler.py b/src/codegate/providers/llamacpp/completion_handler.py index 40046975..0c074985 100644 --- a/src/codegate/providers/llamacpp/completion_handler.py +++ b/src/codegate/providers/llamacpp/completion_handler.py @@ -1,5 +1,5 @@ -import json import asyncio +import json from typing import Any, AsyncIterator, Iterator, Optional, Union from fastapi.responses import StreamingResponse @@ -39,16 +39,20 @@ async def execute_completion( """ model_path = f"{Config.get_config().model_base_path}/{request['model']}.gguf" - if 'prompt' in request: - response = await self.inference_engine.complete(model_path, - Config.get_config().chat_model_n_ctx, - Config.get_config().chat_model_n_gpu_layers, - **request) + if "prompt" in request: + response = await self.inference_engine.complete( + model_path, + Config.get_config().chat_model_n_ctx, + Config.get_config().chat_model_n_gpu_layers, + **request, + ) else: - response = await self.inference_engine.chat(model_path, - Config.get_config().chat_model_n_ctx, - Config.get_config().chat_model_n_gpu_layers, - **request) + response = await self.inference_engine.chat( + model_path, + Config.get_config().chat_model_n_ctx, + Config.get_config().chat_model_n_gpu_layers, + **request, + ) return response def create_streaming_response(self, stream: Iterator[Any]) -> StreamingResponse: diff --git a/src/codegate/providers/llamacpp/provider.py b/src/codegate/providers/llamacpp/provider.py index eddfa901..befc169e 100644 --- a/src/codegate/providers/llamacpp/provider.py +++ b/src/codegate/providers/llamacpp/provider.py @@ -10,17 +10,17 @@ class LlamaCppProvider(BaseProvider): def __init__( - self, - pipeline_processor: Optional[SequentialPipelineProcessor] = None, - fim_pipeline_processor: Optional[SequentialPipelineProcessor] = None - ): + self, + pipeline_processor: Optional[SequentialPipelineProcessor] = None, + fim_pipeline_processor: Optional[SequentialPipelineProcessor] = None, + ): completion_handler = LlamaCppCompletionHandler() super().__init__( LLamaCppInputNormalizer(), LLamaCppOutputNormalizer(), completion_handler, pipeline_processor, - fim_pipeline_processor + fim_pipeline_processor, ) @property @@ -32,6 +32,7 @@ def _setup_routes(self): Sets up the /completions and 
/chat/completions routes for the provider as expected by the Llama API. """ + @self.router.post(f"/{self.provider_route_name}/completions") @self.router.post(f"/{self.provider_route_name}/chat/completions") async def create_completion( diff --git a/src/codegate/providers/openai/provider.py b/src/codegate/providers/openai/provider.py index 209118b5..60c36a4b 100644 --- a/src/codegate/providers/openai/provider.py +++ b/src/codegate/providers/openai/provider.py @@ -10,17 +10,17 @@ class OpenAIProvider(BaseProvider): def __init__( - self, - pipeline_processor: Optional[SequentialPipelineProcessor] = None, - fim_pipeline_processor: Optional[SequentialPipelineProcessor] = None - ): + self, + pipeline_processor: Optional[SequentialPipelineProcessor] = None, + fim_pipeline_processor: Optional[SequentialPipelineProcessor] = None, + ): completion_handler = LiteLLmShim(stream_generator=sse_stream_generator) super().__init__( OpenAIInputNormalizer(), OpenAIOutputNormalizer(), completion_handler, pipeline_processor, - fim_pipeline_processor + fim_pipeline_processor, ) @property diff --git a/src/codegate/providers/vllm/__init__.py b/src/codegate/providers/vllm/__init__.py new file mode 100644 index 00000000..a6bf6555 --- /dev/null +++ b/src/codegate/providers/vllm/__init__.py @@ -0,0 +1,3 @@ +from codegate.providers.vllm.provider import VLLMProvider + +__all__ = ["VLLMProvider"] diff --git a/src/codegate/providers/vllm/adapter.py b/src/codegate/providers/vllm/adapter.py new file mode 100644 index 00000000..38d78740 --- /dev/null +++ b/src/codegate/providers/vllm/adapter.py @@ -0,0 +1,73 @@ +from typing import Any, Dict + +from litellm import ChatCompletionRequest + +from codegate.providers.normalizer.base import ModelInputNormalizer, ModelOutputNormalizer + + +class VLLMInputNormalizer(ModelInputNormalizer): + def __init__(self): + super().__init__() + + def normalize(self, data: Dict) -> ChatCompletionRequest: + """ + Normalize the input data to the format expected by LiteLLM. + Ensures the model name has the hosted_vllm prefix and base_url has /v1. 
+ """ + # Make a copy of the data to avoid modifying the original + normalized_data = data.copy() + + # Format the model name to include the provider + if "model" in normalized_data: + model_name = normalized_data["model"] + if not model_name.startswith("hosted_vllm/"): + normalized_data["model"] = f"hosted_vllm/{model_name}" + + # Ensure the base_url ends with /v1 if provided + if "base_url" in normalized_data: + base_url = normalized_data["base_url"].rstrip("/") + if not base_url.endswith("/v1"): + normalized_data["base_url"] = f"{base_url}/v1" + + return ChatCompletionRequest(**normalized_data) + + def denormalize(self, data: ChatCompletionRequest) -> Dict: + """ + Convert back to raw format for the API request + """ + return data + + +class VLLMOutputNormalizer(ModelOutputNormalizer): + def __init__(self): + super().__init__() + + def normalize_streaming( + self, + model_reply: Any, + ) -> Any: + """ + No normalizing needed for streaming responses + """ + return model_reply + + def normalize(self, model_reply: Any) -> Any: + """ + No normalizing needed for responses + """ + return model_reply + + def denormalize(self, normalized_reply: Any) -> Any: + """ + No denormalizing needed for responses + """ + return normalized_reply + + def denormalize_streaming( + self, + normalized_reply: Any, + ) -> Any: + """ + No denormalizing needed for streaming responses + """ + return normalized_reply diff --git a/src/codegate/providers/vllm/provider.py b/src/codegate/providers/vllm/provider.py new file mode 100644 index 00000000..05c8f720 --- /dev/null +++ b/src/codegate/providers/vllm/provider.py @@ -0,0 +1,57 @@ +import json +from typing import Optional + +from fastapi import Header, HTTPException, Request + +from codegate.config import Config +from codegate.providers.base import BaseProvider, SequentialPipelineProcessor +from codegate.providers.litellmshim import LiteLLmShim, sse_stream_generator +from codegate.providers.vllm.adapter import VLLMInputNormalizer, VLLMOutputNormalizer + + +class VLLMProvider(BaseProvider): + def __init__( + self, + pipeline_processor: Optional[SequentialPipelineProcessor] = None, + fim_pipeline_processor: Optional[SequentialPipelineProcessor] = None, + ): + completion_handler = LiteLLmShim(stream_generator=sse_stream_generator) + super().__init__( + VLLMInputNormalizer(), + VLLMOutputNormalizer(), + completion_handler, + pipeline_processor, + fim_pipeline_processor, + ) + + @property + def provider_route_name(self) -> str: + return "vllm" + + def _setup_routes(self): + """ + Sets up the /chat/completions route for the provider as expected by the + OpenAI API. Extracts the API key from the "Authorization" header and + passes it to the completion handler. 
+ """ + + @self.router.post(f"/{self.provider_route_name}/chat/completions") + @self.router.post(f"/{self.provider_route_name}/completions") + async def create_completion( + request: Request, + authorization: str = Header(..., description="Bearer token"), + ): + if not authorization.startswith("Bearer "): + raise HTTPException(status_code=401, detail="Invalid authorization header") + + api_key = authorization.split(" ")[1] + body = await request.body() + data = json.loads(body) + + # Add the vLLM base URL to the request + config = Config.get_config() + data["base_url"] = config.provider_urls.get("vllm") + + is_fim_request = self._is_fim_request(request, data) + stream = await self.complete(data, api_key, is_fim_request=is_fim_request) + return self._completion_handler.create_streaming_response(stream) diff --git a/src/codegate/server.py b/src/codegate/server.py index 14ccb60a..046ed60e 100644 --- a/src/codegate/server.py +++ b/src/codegate/server.py @@ -9,6 +9,7 @@ from codegate.providers.llamacpp.provider import LlamaCppProvider from codegate.providers.openai.provider import OpenAIProvider from codegate.providers.registry import ProviderRegistry +from codegate.providers.vllm.provider import VLLMProvider def init_app() -> FastAPI: @@ -22,8 +23,7 @@ def init_app() -> FastAPI: CodegateVersion(), ] # Leaving the pipeline empty for now - fim_steps: List[PipelineStep] = [ - ] + fim_steps: List[PipelineStep] = [] pipeline = SequentialPipelineProcessor(steps) fim_pipeline = SequentialPipelineProcessor(fim_steps) @@ -31,18 +31,20 @@ def init_app() -> FastAPI: registry = ProviderRegistry(app) # Register all known providers - registry.add_provider("openai", OpenAIProvider( - pipeline_processor=pipeline, - fim_pipeline_processor=fim_pipeline - )) - registry.add_provider("anthropic", AnthropicProvider( - pipeline_processor=pipeline, - fim_pipeline_processor=fim_pipeline - )) - registry.add_provider("llamacpp", LlamaCppProvider( - pipeline_processor=pipeline, - fim_pipeline_processor=fim_pipeline - )) + registry.add_provider( + "openai", OpenAIProvider(pipeline_processor=pipeline, fim_pipeline_processor=fim_pipeline) + ) + registry.add_provider( + "anthropic", + AnthropicProvider(pipeline_processor=pipeline, fim_pipeline_processor=fim_pipeline), + ) + registry.add_provider( + "llamacpp", + LlamaCppProvider(pipeline_processor=pipeline, fim_pipeline_processor=fim_pipeline), + ) + registry.add_provider( + "vllm", VLLMProvider(pipeline_processor=pipeline, fim_pipeline_processor=fim_pipeline) + ) # Create and add system routes system_router = APIRouter(tags=["System"]) # Tags group endpoints in the docs diff --git a/tests/test_cli.py b/tests/test_cli.py index 2329e788..64f8a53e 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -9,6 +9,7 @@ from codegate.cli import cli from codegate.codegate_logging import LogFormat, LogLevel +from codegate.config import DEFAULT_PROVIDER_URLS @pytest.fixture @@ -63,6 +64,7 @@ def test_serve_default_options(cli_runner: CliRunner, mock_logging: Any) -> None "log_level": "INFO", "log_format": "JSON", "prompts_loaded": 7, # Default prompts are loaded + "provider_urls": DEFAULT_PROVIDER_URLS, }, ) mock_run.assert_called_once() @@ -98,6 +100,7 @@ def test_serve_custom_options(cli_runner: CliRunner, mock_logging: Any) -> None: "log_level": "DEBUG", "log_format": "TEXT", "prompts_loaded": 7, # Default prompts are loaded + "provider_urls": DEFAULT_PROVIDER_URLS, }, ) mock_run.assert_called_once() @@ -136,6 +139,7 @@ def test_serve_with_config_file( "log_level": "DEBUG", 
"log_format": "JSON", "prompts_loaded": 7, # Default prompts are loaded + "provider_urls": DEFAULT_PROVIDER_URLS, }, ) mock_run.assert_called_once() @@ -182,6 +186,7 @@ def test_serve_priority_resolution( "log_level": "ERROR", "log_format": "TEXT", "prompts_loaded": 7, # Default prompts are loaded + "provider_urls": DEFAULT_PROVIDER_URLS, }, ) mock_run.assert_called_once() diff --git a/tests/test_config.py b/tests/test_config.py index 68e387fa..3f6ab585 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -6,7 +6,7 @@ import pytest import yaml -from codegate.config import Config, ConfigurationError, LogFormat, LogLevel +from codegate.config import DEFAULT_PROVIDER_URLS, Config, ConfigurationError, LogFormat, LogLevel def test_default_config(default_config: Config) -> None: @@ -15,6 +15,7 @@ def test_default_config(default_config: Config) -> None: assert default_config.host == "localhost" assert default_config.log_level == LogLevel.INFO assert default_config.log_format == LogFormat.JSON + assert default_config.provider_urls == DEFAULT_PROVIDER_URLS def test_config_from_file(temp_config_file: Path) -> None: @@ -24,6 +25,7 @@ def test_config_from_file(temp_config_file: Path) -> None: assert config.host == "localhost" assert config.log_level == LogLevel.DEBUG assert config.log_format == LogFormat.JSON + assert config.provider_urls == DEFAULT_PROVIDER_URLS def test_config_from_invalid_file(tmp_path: Path) -> None: @@ -49,6 +51,7 @@ def test_config_from_env(env_vars: None) -> None: assert config.host == "localhost" assert config.log_level == LogLevel.WARNING assert config.log_format == LogFormat.TEXT + assert config.provider_urls == DEFAULT_PROVIDER_URLS def test_config_priority_resolution(temp_config_file: Path, env_vars: None) -> None: @@ -60,11 +63,13 @@ def test_config_priority_resolution(temp_config_file: Path, env_vars: None) -> N cli_host="example.com", cli_log_level="WARNING", cli_log_format="TEXT", + cli_provider_urls={"vllm": "https://custom.vllm.server"}, ) assert config.port == 8080 assert config.host == "example.com" assert config.log_level == LogLevel.WARNING assert config.log_format == LogFormat.TEXT + assert config.provider_urls["vllm"] == "https://custom.vllm.server" # Env vars should override config file config = Config.load(config_path=temp_config_file) @@ -72,6 +77,7 @@ def test_config_priority_resolution(temp_config_file: Path, env_vars: None) -> N assert config.host == "localhost" # from env assert config.log_level == LogLevel.WARNING # from env assert config.log_format == LogFormat.TEXT # from env + assert config.provider_urls == DEFAULT_PROVIDER_URLS # no env override # Config file should override defaults os.environ.clear() # Remove env vars @@ -80,6 +86,34 @@ def test_config_priority_resolution(temp_config_file: Path, env_vars: None) -> N assert config.host == "localhost" # from file assert config.log_level == LogLevel.DEBUG # from file assert config.log_format == LogFormat.JSON # from file + assert config.provider_urls == DEFAULT_PROVIDER_URLS # default values + + +def test_provider_urls_from_config(tmp_path: Path) -> None: + """Test loading provider URLs from config file.""" + config_file = tmp_path / "config.yaml" + custom_urls = { + "vllm": "https://custom.vllm.server", + "openai": "https://custom.openai.server", + } + with open(config_file, "w") as f: + yaml.dump({"provider_urls": custom_urls}, f) + + config = Config.from_file(config_file) + assert config.provider_urls["vllm"] == custom_urls["vllm"] + assert config.provider_urls["openai"] == 
custom_urls["openai"] + assert config.provider_urls["anthropic"] == DEFAULT_PROVIDER_URLS["anthropic"] + + +def test_provider_urls_from_env() -> None: + """Test loading provider URLs from environment variables.""" + os.environ["CODEGATE_PROVIDER_VLLM_URL"] = "https://custom.vllm.server" + try: + config = Config.from_env() + assert config.provider_urls["vllm"] == "https://custom.vllm.server" + assert config.provider_urls["openai"] == DEFAULT_PROVIDER_URLS["openai"] + finally: + del os.environ["CODEGATE_PROVIDER_VLLM_URL"] def test_invalid_log_level() -> None: diff --git a/tests/test_inference.py b/tests/test_inference.py index 86c2eeea..9cfe5d14 100644 --- a/tests/test_inference.py +++ b/tests/test_inference.py @@ -51,7 +51,7 @@ async def test_chat(inference_engine) -> None: response = await inference_engine.chat(model_path, **chat_request) for chunk in response: - assert 'delta' in chunk["choices"][0] + assert "delta" in chunk["choices"][0] @pytest.mark.asyncio diff --git a/tests/test_provider.py b/tests/test_provider.py index 0957d618..f2c4011f 100644 --- a/tests/test_provider.py +++ b/tests/test_provider.py @@ -14,26 +14,29 @@ def __init__(self): mocked_pipepeline = MagicMock() mocked_fim_pipeline = MagicMock() super().__init__( - mocked_input_normalizer, - mocked_output_normalizer, - mocked_completion_handler, - mocked_pipepeline, - mocked_fim_pipeline - ) + mocked_input_normalizer, + mocked_output_normalizer, + mocked_completion_handler, + mocked_pipepeline, + mocked_fim_pipeline, + ) def _setup_routes(self) -> None: pass @property def provider_route_name(self) -> str: - return 'mock-provider' + return "mock-provider" -@pytest.mark.parametrize("url, expected_bool", [ - ("http://example.com", False), - ("http://test.com/chat/completions", False), - ("http://example.com/completions", True), -]) +@pytest.mark.parametrize( + "url, expected_bool", + [ + ("http://example.com", False), + ("http://test.com/chat/completions", False), + ("http://example.com/completions", True), + ], +) def test_is_fim_request_url(url, expected_bool): mock_provider = MockProvider() request = MagicMock() @@ -45,7 +48,7 @@ def test_is_fim_request_url(url, expected_bool): "messages": [ { "role": "user", - "content": ' ', + "content": " ", } ] } @@ -53,12 +56,7 @@ def test_is_fim_request_url(url, expected_bool): "messages": [ { "role": "user", - "content": [ - { - "type": "text", - "text": " " - } - ], + "content": [{"type": "text", "text": " "}], } ] } @@ -72,11 +70,14 @@ def test_is_fim_request_url(url, expected_bool): } -@pytest.mark.parametrize("data, expected_bool", [ - (DATA_CONTENT_STR, True), - (DATA_CONTENT_LIST, True), - (INVALID_DATA_CONTET, False), -]) +@pytest.mark.parametrize( + "data, expected_bool", + [ + (DATA_CONTENT_STR, True), + (DATA_CONTENT_LIST, True), + (INVALID_DATA_CONTET, False), + ], +) def test_is_fim_request_body(data, expected_bool): mock_provider = MockProvider() assert mock_provider._is_fim_request_body(data) == expected_bool