diff --git a/config.yaml.example b/config.yaml.example
index dcf3db30..33e83f09 100644
--- a/config.yaml.example
+++ b/config.yaml.example
@@ -1,19 +1,42 @@
-# Example configuration file
-# Copy this file to config.yaml and modify as needed
+# Codegate Example Configuration
-# Server configuration
-port: 8989
-host: "localhost"
+# Network settings
+port: 8989 # Port to listen on (1-65535)
+host: "localhost" # Host to bind to (use localhost for all interfaces)
# Logging configuration
-log_level: "INFO" # ERROR, WARNING, INFO, DEBUG
-log_format: "JSON" # JSON, TEXT
+log_level: "INFO" # One of: ERROR, WARNING, INFO, DEBUG
-# Prompts configuration
-# Option 1: Define prompts directly in the config file
-prompts:
- my_system_prompt: "Custom system prompt defined in config"
- another_prompt: "Another custom prompt"
+# Note: This configuration can be overridden by:
+# 1. CLI arguments (--port, --host, --log-level)
+# 2. Environment variables (CODEGATE_APP_PORT, CODEGATE_APP_HOST, CODEGATE_APP_LOG_LEVEL)
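+#
+# For example (values here are illustrative):
+#   CODEGATE_APP_PORT=9000 codegate serve --log-level DEBUG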
-# Option 2: Reference a separate prompts file
-# prompts: "prompts.yaml" # Path to prompts file (relative to config file or absolute)
+# Provider URLs
+provider_urls:
+ openai: "https://api.openai.com/v1"
+ anthropic: "https://api.anthropic.com/v1"
+ vllm: "http://localhost:8000" # Base URL without /v1 path, it will be added automatically
+
+# Note: Provider URLs can be overridden by environment variables:
+# CODEGATE_PROVIDER_OPENAI_URL
+# CODEGATE_PROVIDER_ANTHROPIC_URL
+# CODEGATE_PROVIDER_VLLM_URL
+# Or by CLI flags:
+# --vllm-url
+# --openai-url
+# --anthropic-url
+
+# Embedding model configuration
+
+# Inference model configuration
+
+# Model to use for chatting
+chat_model_path: "./models/qwen2.5-coder-1.5b-instruct-q5_k_m.gguf"
+
+# Context length of the model
+chat_model_n_ctx: 32768
+
+# Number of layers to offload to GPU. If -1, all layers are offloaded.
+chat_model_n_gpu_layers: -1
diff --git a/docs/cli.md b/docs/cli.md
index 74684e31..9142d405 100644
--- a/docs/cli.md
+++ b/docs/cli.md
@@ -46,6 +46,21 @@ codegate serve [OPTIONS]
- Must be a valid YAML file
- Overrides default prompts and configuration file prompts
+- `--vllm-url TEXT`: vLLM provider URL (default: http://localhost:8000)
+ - Optional
+ - Base URL for vLLM provider (/v1 path is added automatically)
+ - Overrides configuration file and environment variables
+
+- `--openai-url TEXT`: OpenAI provider URL (default: https://api.openai.com/v1)
+ - Optional
+ - Base URL for OpenAI provider
+ - Overrides configuration file and environment variables
+
+- `--anthropic-url TEXT`: Anthropic provider URL (default: https://api.anthropic.com/v1)
+ - Optional
+ - Base URL for Anthropic provider
+ - Overrides configuration file and environment variables
+
### show-prompts
Display the loaded system prompts:
@@ -100,6 +115,11 @@ Start server with custom prompts:
codegate serve --prompts my-prompts.yaml
```
+Start server with custom vLLM endpoint:
+```bash
+codegate serve --vllm-url https://vllm.example.com
+```
+
Show default system prompts:
```bash
codegate show-prompts
diff --git a/docs/configuration.md b/docs/configuration.md
index a54796ad..6a015d28 100644
--- a/docs/configuration.md
+++ b/docs/configuration.md
@@ -16,6 +16,10 @@ The configuration system in Codegate is managed through the `Config` class in `c
- Log Level: "INFO"
- Log Format: "JSON"
- Prompts: Default prompts from prompts/default.yaml
+- Provider URLs:
+ - vLLM: "http://localhost:8000"
+ - OpenAI: "https://api.openai.com/v1"
+ - Anthropic: "https://api.anthropic.com/v1"
## Configuration Methods
@@ -27,6 +31,18 @@ Load configuration from a YAML file:
config = Config.from_file("config.yaml")
```
+Example config.yaml:
+```yaml
+port: 8989
+host: localhost
+log_level: INFO
+log_format: JSON
+provider_urls:
+ vllm: "https://vllm.example.com"
+ openai: "https://api.openai.com/v1"
+ anthropic: "https://api.anthropic.com/v1"
+```
+
### From Environment Variables
Environment variables are automatically loaded with these mappings:
@@ -36,6 +52,9 @@ Environment variables are automatically loaded with these mappings:
- `CODEGATE_APP_LOG_LEVEL`: Logging level
- `CODEGATE_LOG_FORMAT`: Log format
- `CODEGATE_PROMPTS_FILE`: Path to prompts YAML file
+- `CODEGATE_PROVIDER_VLLM_URL`: vLLM provider URL
+- `CODEGATE_PROVIDER_OPENAI_URL`: OpenAI provider URL
+- `CODEGATE_PROVIDER_ANTHROPIC_URL`: Anthropic provider URL
```python
config = Config.from_env()
@@ -43,6 +62,32 @@ config = Config.from_env()
## Configuration Options
+### Provider URLs
+
+Provider URLs can be configured in several ways:
+
+1. In Configuration File:
+ ```yaml
+ provider_urls:
+ vllm: "https://vllm.example.com" # /v1 path is added automatically
+ openai: "https://api.openai.com/v1"
+ anthropic: "https://api.anthropic.com/v1"
+ ```
+
+2. Via Environment Variables:
+ ```bash
+ export CODEGATE_PROVIDER_VLLM_URL=https://vllm.example.com
+ export CODEGATE_PROVIDER_OPENAI_URL=https://api.openai.com/v1
+ export CODEGATE_PROVIDER_ANTHROPIC_URL=https://api.anthropic.com/v1
+ ```
+
+3. Via CLI Flags:
+ ```bash
+ codegate serve --vllm-url https://vllm.example.com
+ ```
+
+Note: For the vLLM provider, the /v1 path is automatically appended to the base URL if not present.
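+
+A minimal sketch of that rule (the helper name is illustrative; the same logic lives in the vLLM input normalizer):
+
+```python
+def ensure_v1_suffix(base_url: str) -> str:
+    """Append /v1 to a provider base URL unless it is already present."""
+    base_url = base_url.rstrip("/")
+    return base_url if base_url.endswith("/v1") else f"{base_url}/v1"
+
+
+assert ensure_v1_suffix("https://vllm.example.com") == "https://vllm.example.com/v1"
+assert ensure_v1_suffix("https://vllm.example.com/v1/") == "https://vllm.example.com/v1"
+```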
+
### Log Levels
Available log levels (case-insensitive):
diff --git a/docs/development.md b/docs/development.md
index 1f954939..ef66219b 100644
--- a/docs/development.md
+++ b/docs/development.md
@@ -9,6 +9,7 @@ Codegate is a configurable Generative AI gateway designed to protect developers
- Secure coding recommendations
- Prevention of AI recommending deprecated/malicious libraries
- Modular system prompts configuration
+- Multiple AI provider support with configurable endpoints
## Development Setup
@@ -53,7 +54,11 @@ codegate/
│ ├── logging.py # Logging setup
│ ├── prompts.py # Prompts management
│ ├── server.py # Main server implementation
-│ └── providers/* # External service providers (anthropic, openai, etc.)
+│ └── providers/ # External service providers
+│ ├── anthropic/ # Anthropic provider implementation
+│ ├── openai/ # OpenAI provider implementation
+│ ├── vllm/ # vLLM provider implementation
+│ └── base.py # Base provider interface
├── tests/ # Test files
└── docs/ # Documentation
```
@@ -128,9 +133,87 @@ Codegate uses a hierarchical configuration system with the following priority (h
- Log Level: Logging level (ERROR|WARNING|INFO|DEBUG)
- Log Format: Log format (JSON|TEXT)
- Prompts: System prompts configuration
+- Provider URLs: AI provider endpoint configuration
See [Configuration Documentation](configuration.md) for detailed information.
+## Working with Providers
+
+Codegate supports multiple AI providers through a modular provider system.
+
+### Available Providers
+
+1. **vLLM Provider**
+ - Default URL: http://localhost:8000
+ - Supports OpenAI-compatible API
+ - Automatically adds /v1 path to base URL
+ - Model names are prefixed with "hosted_vllm/"
+
+2. **OpenAI Provider**
+ - Default URL: https://api.openai.com/v1
+ - Standard OpenAI API implementation
+
+3. **Anthropic Provider**
+ - Default URL: https://api.anthropic.com/v1
+ - Anthropic Claude API implementation
+
+### Configuring Providers
+
+Provider URLs can be configured through:
+
+1. Config file (config.yaml):
+ ```yaml
+ provider_urls:
+ vllm: "https://vllm.example.com"
+ openai: "https://api.openai.com/v1"
+ anthropic: "https://api.anthropic.com/v1"
+ ```
+
+2. Environment variables:
+ ```bash
+ export CODEGATE_PROVIDER_VLLM_URL=https://vllm.example.com
+ export CODEGATE_PROVIDER_OPENAI_URL=https://api.openai.com/v1
+ export CODEGATE_PROVIDER_ANTHROPIC_URL=https://api.anthropic.com/v1
+ ```
+
+3. CLI flags:
+ ```bash
+ codegate serve --vllm-url https://vllm.example.com
+ ```
+
+### Implementing New Providers
+
+To add a new provider:
+
+1. Create a new directory in `src/codegate/providers/`
+2. Implement required components:
+ - `provider.py`: Main provider class extending BaseProvider
+ - `adapter.py`: Input/output normalizers
+ - `__init__.py`: Export provider class
+
+Example structure:
+```python
+from codegate.providers.base import BaseProvider
+
+class NewProvider(BaseProvider):
+ def __init__(self, ...):
+ super().__init__(
+ InputNormalizer(),
+ OutputNormalizer(),
+ completion_handler,
+ pipeline_processor,
+ fim_pipeline_processor
+ )
+
+ @property
+ def provider_route_name(self) -> str:
+ return "provider_name"
+
+ def _setup_routes(self):
+ # Implement route setup
+ pass
+```
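+
+The `adapter.py` module referenced above supplies the normalizers the provider passes to `BaseProvider`. A minimal sketch, assuming no request or response rewriting is needed (the class and method shapes mirror the vLLM adapter added in this change):
+
+```python
+from typing import Any, Dict
+
+from litellm import ChatCompletionRequest
+
+from codegate.providers.normalizer.base import ModelInputNormalizer, ModelOutputNormalizer
+
+
+class InputNormalizer(ModelInputNormalizer):
+    def normalize(self, data: Dict) -> ChatCompletionRequest:
+        # Adjust the raw request here (e.g. prefix the model name) before LiteLLM sees it
+        return ChatCompletionRequest(**data)
+
+    def denormalize(self, data: ChatCompletionRequest) -> Dict:
+        # Convert back to the raw dict format expected by the upstream API
+        return data
+
+
+class OutputNormalizer(ModelOutputNormalizer):
+    def normalize(self, model_reply: Any) -> Any:
+        return model_reply
+
+    def denormalize(self, normalized_reply: Any) -> Any:
+        return normalized_reply
+
+    def normalize_streaming(self, model_reply: Any) -> Any:
+        return model_reply
+
+    def denormalize_streaming(self, normalized_reply: Any) -> Any:
+        return normalized_reply
+```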
+
## Working with Prompts
### Default Prompts
@@ -188,8 +271,9 @@ codegate serve --port 8989 --host localhost --log-level DEBUG
# Start with custom prompts
codegate serve --prompts my-prompts.yaml
+
+# Start with custom provider URL
+codegate serve --vllm-url https://vllm.example.com
```
See [CLI Documentation](cli.md) for detailed command information.
-
-[Rest of development.md content remains unchanged...]
diff --git a/src/codegate/cli.py b/src/codegate/cli.py
index 8688947b..b6911bfe 100644
--- a/src/codegate/cli.py
+++ b/src/codegate/cli.py
@@ -2,7 +2,7 @@
import sys
from pathlib import Path
-from typing import Optional
+from typing import Dict, Optional
import click
@@ -88,6 +88,24 @@ def show_prompts(prompts: Optional[Path]) -> None:
default=None,
help="Path to YAML prompts file",
)
+@click.option(
+ "--vllm-url",
+ type=str,
+ default=None,
+ help="vLLM provider URL (default: http://localhost:8000/v1)",
+)
+@click.option(
+ "--openai-url",
+ type=str,
+ default=None,
+ help="OpenAI provider URL (default: https://api.openai.com/v1)",
+)
+@click.option(
+ "--anthropic-url",
+ type=str,
+ default=None,
+ help="Anthropic provider URL (default: https://api.anthropic.com/v1)",
+)
def serve(
port: Optional[int],
host: Optional[str],
@@ -95,10 +113,22 @@ def serve(
log_format: Optional[str],
config: Optional[Path],
prompts: Optional[Path],
+ vllm_url: Optional[str],
+ openai_url: Optional[str],
+ anthropic_url: Optional[str],
) -> None:
"""Start the codegate server."""
logger = None
try:
+ # Create provider URLs dict from CLI options
+ cli_provider_urls: Dict[str, str] = {}
+ if vllm_url:
+ cli_provider_urls["vllm"] = vllm_url
+ if openai_url:
+ cli_provider_urls["openai"] = openai_url
+ if anthropic_url:
+ cli_provider_urls["anthropic"] = anthropic_url
+
# Load configuration with priority resolution
cfg = Config.load(
config_path=config,
@@ -107,6 +137,7 @@ def serve(
cli_host=host,
cli_log_level=log_level,
cli_log_format=log_format,
+ cli_provider_urls=cli_provider_urls,
)
logger = setup_logging(cfg.log_level, cfg.log_format)
@@ -118,6 +149,7 @@ def serve(
"log_level": cfg.log_level.value,
"log_format": cfg.log_format.value,
"prompts_loaded": len(cfg.prompts.prompts),
+ "provider_urls": cfg.provider_urls,
},
)
diff --git a/src/codegate/config.py b/src/codegate/config.py
index e63e5fc1..1f304227 100644
--- a/src/codegate/config.py
+++ b/src/codegate/config.py
@@ -3,7 +3,7 @@
import os
from dataclasses import dataclass, field
from pathlib import Path
-from typing import Optional, Union
+from typing import Dict, Optional, Union
import yaml
@@ -13,6 +13,13 @@
logger = setup_logging()
+# Default provider URLs
+DEFAULT_PROVIDER_URLS = {
+ "openai": "https://api.openai.com/v1",
+ "anthropic": "https://api.anthropic.com/v1",
+ "vllm": "http://localhost:8000", # Base URL without /v1 path
+}
+
@dataclass
class Config:
@@ -32,6 +39,9 @@ class Config:
chat_model_n_ctx: int = 32768
chat_model_n_gpu_layers: int = -1
+ # Provider URLs with defaults
+ provider_urls: Dict[str, str] = field(default_factory=lambda: DEFAULT_PROVIDER_URLS.copy())
+
def __post_init__(self) -> None:
"""Validate configuration after initialization."""
if not isinstance(self.port, int) or not (1 <= self.port <= 65535):
@@ -95,19 +105,23 @@ def from_file(cls, config_path: Union[str, Path]) -> "Config":
prompts_path = Path(config_path).parent / prompts_path
prompts_config = PromptConfig.from_file(prompts_path)
+ # Get provider URLs from config
+ provider_urls = DEFAULT_PROVIDER_URLS.copy()
+ if "provider_urls" in config_data:
+ provider_urls.update(config_data.pop("provider_urls"))
+
return cls(
port=config_data.get("port", cls.port),
host=config_data.get("host", cls.host),
log_level=config_data.get("log_level", cls.log_level.value),
log_format=config_data.get("log_format", cls.log_format.value),
model_base_path=config_data.get("chat_model_path", cls.model_base_path),
- chat_model_n_ctx=config_data.get(
- "chat_model_n_ctx", cls.chat_model_n_ctx
- ),
+ chat_model_n_ctx=config_data.get("chat_model_n_ctx", cls.chat_model_n_ctx),
chat_model_n_gpu_layers=config_data.get(
"chat_model_n_gpu_layers", cls.chat_model_n_gpu_layers
),
prompts=prompts_config,
+ provider_urls=provider_urls,
)
except yaml.YAMLError as e:
raise ConfigurationError(f"Failed to parse config file: {e}")
@@ -138,6 +152,12 @@ def from_env(cls) -> "Config":
os.environ["CODEGATE_PROMPTS_FILE"]
) # noqa: E501
+ # Load provider URLs from environment variables
+ for provider in DEFAULT_PROVIDER_URLS.keys():
+ env_var = f"CODEGATE_PROVIDER_{provider.upper()}_URL"
+ if env_var in os.environ:
+ config.provider_urls[provider] = os.environ[env_var]
+
return config
except ValueError as e:
raise ConfigurationError(f"Invalid environment variable value: {e}")
@@ -151,6 +171,7 @@ def load(
cli_host: Optional[str] = None,
cli_log_level: Optional[str] = None,
cli_log_format: Optional[str] = None,
+ cli_provider_urls: Optional[Dict[str, str]] = None,
) -> "Config":
"""Load configuration with priority resolution.
@@ -167,6 +188,7 @@ def load(
cli_host: Optional CLI host override
cli_log_level: Optional CLI log level override
cli_log_format: Optional CLI log format override
+ cli_provider_urls: Optional dict of provider URLs from CLI
Returns:
Config: Resolved configuration
@@ -198,6 +220,10 @@ def load(
if "CODEGATE_PROMPTS_FILE" in os.environ:
config.prompts = env_config.prompts
+ # Override provider URLs from environment
+ for provider, url in env_config.provider_urls.items():
+ config.provider_urls[provider] = url
+
# Override with CLI arguments
if cli_port is not None:
config.port = cli_port
@@ -209,6 +235,8 @@ def load(
config.log_format = LogFormat(cli_log_format)
if prompts_path is not None:
config.prompts = PromptConfig.from_file(prompts_path)
+ if cli_provider_urls is not None:
+ config.provider_urls.update(cli_provider_urls)
# Set the __config class attribute
Config.__config = config
diff --git a/src/codegate/pipeline/fim/secret_analyzer.py b/src/codegate/pipeline/fim/secret_analyzer.py
index 83f89a11..68f38351 100644
--- a/src/codegate/pipeline/fim/secret_analyzer.py
+++ b/src/codegate/pipeline/fim/secret_analyzer.py
@@ -28,9 +28,7 @@ def name(self) -> str:
return "fim-secret-analyzer"
async def process(
- self,
- request: ChatCompletionRequest,
- context: PipelineContext
+ self, request: ChatCompletionRequest, context: PipelineContext
) -> PipelineResult:
# We should call here Secrets Blocking module to see if the request messages contain secrets
# messages_contain_secrets = [analyze_msg_secrets(msg) for msg in request.messages]
@@ -39,7 +37,7 @@ async def process(
# For the moment to test shortcutting just treat all messages as if they contain secrets
message_with_secrets = False
if message_with_secrets:
- logger.info('Blocking message with secrets.')
+ logger.info("Blocking message with secrets.")
return PipelineResult(
response=PipelineResponse(
step_name=self.name,
diff --git a/src/codegate/providers/__init__.py b/src/codegate/providers/__init__.py
index d83fa21a..1ec69183 100644
--- a/src/codegate/providers/__init__.py
+++ b/src/codegate/providers/__init__.py
@@ -2,10 +2,12 @@
from codegate.providers.base import BaseProvider
from codegate.providers.openai.provider import OpenAIProvider
from codegate.providers.registry import ProviderRegistry
+from codegate.providers.vllm.provider import VLLMProvider
__all__ = [
"BaseProvider",
"ProviderRegistry",
"OpenAIProvider",
"AnthropicProvider",
+ "VLLMProvider",
]
diff --git a/src/codegate/providers/anthropic/completion_handler.py b/src/codegate/providers/anthropic/completion_handler.py
index ed45eb40..253e2970 100644
--- a/src/codegate/providers/anthropic/completion_handler.py
+++ b/src/codegate/providers/anthropic/completion_handler.py
@@ -27,7 +27,7 @@ async def execute_completion(
For more details, refer to the
[LiteLLM Documentation](https://docs.litellm.ai/docs/providers/anthropic).
"""
- model_in_request = request['model']
- if not model_in_request.startswith('anthropic/'):
- request['model'] = f'anthropic/{model_in_request}'
+ model_in_request = request["model"]
+ if not model_in_request.startswith("anthropic/"):
+ request["model"] = f"anthropic/{model_in_request}"
return await super().execute_completion(request, api_key, stream)
diff --git a/src/codegate/providers/anthropic/provider.py b/src/codegate/providers/anthropic/provider.py
index e0042ace..a035f7aa 100644
--- a/src/codegate/providers/anthropic/provider.py
+++ b/src/codegate/providers/anthropic/provider.py
@@ -11,17 +11,17 @@
class AnthropicProvider(BaseProvider):
def __init__(
- self,
- pipeline_processor: Optional[SequentialPipelineProcessor] = None,
- fim_pipeline_processor: Optional[SequentialPipelineProcessor] = None
- ):
+ self,
+ pipeline_processor: Optional[SequentialPipelineProcessor] = None,
+ fim_pipeline_processor: Optional[SequentialPipelineProcessor] = None,
+ ):
completion_handler = AnthropicCompletion(stream_generator=anthropic_stream_generator)
super().__init__(
AnthropicInputNormalizer(),
AnthropicOutputNormalizer(),
completion_handler,
pipeline_processor,
- fim_pipeline_processor
+ fim_pipeline_processor,
)
@property
diff --git a/src/codegate/providers/base.py b/src/codegate/providers/base.py
index 4f2d65e0..a350d0a9 100644
--- a/src/codegate/providers/base.py
+++ b/src/codegate/providers/base.py
@@ -50,17 +50,15 @@ def provider_route_name(self) -> str:
pass
async def _run_input_pipeline(
- self,
- normalized_request: ChatCompletionRequest,
- is_fim_request: bool
+ self, normalized_request: ChatCompletionRequest, is_fim_request: bool
) -> PipelineResult:
# Decide which pipeline processor to use
if is_fim_request:
pipeline_processor = self._fim_pipelin_processor
- logger.info('FIM pipeline selected for execution.')
+ logger.info("FIM pipeline selected for execution.")
else:
pipeline_processor = self._pipeline_processor
- logger.info('Chat completion pipeline selected for execution.')
+ logger.info("Chat completion pipeline selected for execution.")
if pipeline_processor is None:
return PipelineResult(request=normalized_request)
@@ -92,21 +90,21 @@ def _is_fim_request_body(self, data: Dict) -> bool:
Determine from the raw incoming data if it's a FIM request.
Used by: OpenAI and Anthropic
"""
- messages = data.get('messages', [])
+ messages = data.get("messages", [])
if not messages:
return False
- first_message_content = messages[0].get('content')
+ first_message_content = messages[0].get("content")
if first_message_content is None:
return False
- fim_stop_sequences = ['', '', '', '']
+ fim_stop_sequences = ["", "", "", ""]
if isinstance(first_message_content, str):
msg_prompt = first_message_content
elif isinstance(first_message_content, list):
- msg_prompt = first_message_content[0].get('text', '')
+ msg_prompt = first_message_content[0].get("text", "")
else:
- logger.warning(f'Could not determine if message was FIM from data: {data}')
+ logger.warning(f"Could not determine if message was FIM from data: {data}")
return False
return all([stop_sequence in msg_prompt for stop_sequence in fim_stop_sequences])
@@ -121,7 +119,7 @@ def _is_fim_request(self, request: Request, data: Dict) -> bool:
return self._is_fim_request_body(data)
async def complete(
- self, data: Dict, api_key: Optional[str], is_fim_request: bool
+ self, data: Dict, api_key: Optional[str], is_fim_request: bool
) -> Union[ModelResponse, AsyncIterator[ModelResponse]]:
"""
Main completion flow with pipeline integration
diff --git a/src/codegate/providers/llamacpp/completion_handler.py b/src/codegate/providers/llamacpp/completion_handler.py
index 40046975..0c074985 100644
--- a/src/codegate/providers/llamacpp/completion_handler.py
+++ b/src/codegate/providers/llamacpp/completion_handler.py
@@ -1,5 +1,5 @@
-import json
import asyncio
+import json
from typing import Any, AsyncIterator, Iterator, Optional, Union
from fastapi.responses import StreamingResponse
@@ -39,16 +39,20 @@ async def execute_completion(
"""
model_path = f"{Config.get_config().model_base_path}/{request['model']}.gguf"
- if 'prompt' in request:
- response = await self.inference_engine.complete(model_path,
- Config.get_config().chat_model_n_ctx,
- Config.get_config().chat_model_n_gpu_layers,
- **request)
+ if "prompt" in request:
+ response = await self.inference_engine.complete(
+ model_path,
+ Config.get_config().chat_model_n_ctx,
+ Config.get_config().chat_model_n_gpu_layers,
+ **request,
+ )
else:
- response = await self.inference_engine.chat(model_path,
- Config.get_config().chat_model_n_ctx,
- Config.get_config().chat_model_n_gpu_layers,
- **request)
+ response = await self.inference_engine.chat(
+ model_path,
+ Config.get_config().chat_model_n_ctx,
+ Config.get_config().chat_model_n_gpu_layers,
+ **request,
+ )
return response
def create_streaming_response(self, stream: Iterator[Any]) -> StreamingResponse:
diff --git a/src/codegate/providers/llamacpp/provider.py b/src/codegate/providers/llamacpp/provider.py
index eddfa901..befc169e 100644
--- a/src/codegate/providers/llamacpp/provider.py
+++ b/src/codegate/providers/llamacpp/provider.py
@@ -10,17 +10,17 @@
class LlamaCppProvider(BaseProvider):
def __init__(
- self,
- pipeline_processor: Optional[SequentialPipelineProcessor] = None,
- fim_pipeline_processor: Optional[SequentialPipelineProcessor] = None
- ):
+ self,
+ pipeline_processor: Optional[SequentialPipelineProcessor] = None,
+ fim_pipeline_processor: Optional[SequentialPipelineProcessor] = None,
+ ):
completion_handler = LlamaCppCompletionHandler()
super().__init__(
LLamaCppInputNormalizer(),
LLamaCppOutputNormalizer(),
completion_handler,
pipeline_processor,
- fim_pipeline_processor
+ fim_pipeline_processor,
)
@property
@@ -32,6 +32,7 @@ def _setup_routes(self):
Sets up the /completions and /chat/completions routes for the
provider as expected by the Llama API.
"""
+
@self.router.post(f"/{self.provider_route_name}/completions")
@self.router.post(f"/{self.provider_route_name}/chat/completions")
async def create_completion(
diff --git a/src/codegate/providers/openai/provider.py b/src/codegate/providers/openai/provider.py
index 209118b5..60c36a4b 100644
--- a/src/codegate/providers/openai/provider.py
+++ b/src/codegate/providers/openai/provider.py
@@ -10,17 +10,17 @@
class OpenAIProvider(BaseProvider):
def __init__(
- self,
- pipeline_processor: Optional[SequentialPipelineProcessor] = None,
- fim_pipeline_processor: Optional[SequentialPipelineProcessor] = None
- ):
+ self,
+ pipeline_processor: Optional[SequentialPipelineProcessor] = None,
+ fim_pipeline_processor: Optional[SequentialPipelineProcessor] = None,
+ ):
completion_handler = LiteLLmShim(stream_generator=sse_stream_generator)
super().__init__(
OpenAIInputNormalizer(),
OpenAIOutputNormalizer(),
completion_handler,
pipeline_processor,
- fim_pipeline_processor
+ fim_pipeline_processor,
)
@property
diff --git a/src/codegate/providers/vllm/__init__.py b/src/codegate/providers/vllm/__init__.py
new file mode 100644
index 00000000..a6bf6555
--- /dev/null
+++ b/src/codegate/providers/vllm/__init__.py
@@ -0,0 +1,3 @@
+from codegate.providers.vllm.provider import VLLMProvider
+
+__all__ = ["VLLMProvider"]
diff --git a/src/codegate/providers/vllm/adapter.py b/src/codegate/providers/vllm/adapter.py
new file mode 100644
index 00000000..38d78740
--- /dev/null
+++ b/src/codegate/providers/vllm/adapter.py
@@ -0,0 +1,73 @@
+from typing import Any, Dict
+
+from litellm import ChatCompletionRequest
+
+from codegate.providers.normalizer.base import ModelInputNormalizer, ModelOutputNormalizer
+
+
+class VLLMInputNormalizer(ModelInputNormalizer):
+ def __init__(self):
+ super().__init__()
+
+ def normalize(self, data: Dict) -> ChatCompletionRequest:
+ """
+ Normalize the input data to the format expected by LiteLLM.
+ Ensures the model name has the hosted_vllm prefix and base_url has /v1.
+ """
+ # Make a copy of the data to avoid modifying the original
+ normalized_data = data.copy()
+
+ # Format the model name to include the provider
+ if "model" in normalized_data:
+ model_name = normalized_data["model"]
+ if not model_name.startswith("hosted_vllm/"):
+ normalized_data["model"] = f"hosted_vllm/{model_name}"
+
+ # Ensure the base_url ends with /v1 if provided
+ if "base_url" in normalized_data:
+ base_url = normalized_data["base_url"].rstrip("/")
+ if not base_url.endswith("/v1"):
+ normalized_data["base_url"] = f"{base_url}/v1"
+
+ return ChatCompletionRequest(**normalized_data)
+
+ def denormalize(self, data: ChatCompletionRequest) -> Dict:
+ """
+ Convert back to raw format for the API request
+ """
+ return data
+
+
+class VLLMOutputNormalizer(ModelOutputNormalizer):
+ def __init__(self):
+ super().__init__()
+
+ def normalize_streaming(
+ self,
+ model_reply: Any,
+ ) -> Any:
+ """
+ No normalizing needed for streaming responses
+ """
+ return model_reply
+
+ def normalize(self, model_reply: Any) -> Any:
+ """
+ No normalizing needed for responses
+ """
+ return model_reply
+
+ def denormalize(self, normalized_reply: Any) -> Any:
+ """
+ No denormalizing needed for responses
+ """
+ return normalized_reply
+
+ def denormalize_streaming(
+ self,
+ normalized_reply: Any,
+ ) -> Any:
+ """
+ No denormalizing needed for streaming responses
+ """
+ return normalized_reply
diff --git a/src/codegate/providers/vllm/provider.py b/src/codegate/providers/vllm/provider.py
new file mode 100644
index 00000000..05c8f720
--- /dev/null
+++ b/src/codegate/providers/vllm/provider.py
@@ -0,0 +1,57 @@
+import json
+from typing import Optional
+
+from fastapi import Header, HTTPException, Request
+
+from codegate.config import Config
+from codegate.providers.base import BaseProvider, SequentialPipelineProcessor
+from codegate.providers.litellmshim import LiteLLmShim, sse_stream_generator
+from codegate.providers.vllm.adapter import VLLMInputNormalizer, VLLMOutputNormalizer
+
+
+class VLLMProvider(BaseProvider):
+ def __init__(
+ self,
+ pipeline_processor: Optional[SequentialPipelineProcessor] = None,
+ fim_pipeline_processor: Optional[SequentialPipelineProcessor] = None,
+ ):
+ completion_handler = LiteLLmShim(stream_generator=sse_stream_generator)
+ super().__init__(
+ VLLMInputNormalizer(),
+ VLLMOutputNormalizer(),
+ completion_handler,
+ pipeline_processor,
+ fim_pipeline_processor,
+ )
+
+ @property
+ def provider_route_name(self) -> str:
+ return "vllm"
+
+ def _setup_routes(self):
+ """
+ Sets up the /chat/completions route for the provider as expected by the
+ OpenAI API. Extracts the API key from the "Authorization" header and
+ passes it to the completion handler.
+ """
+
+ @self.router.post(f"/{self.provider_route_name}/chat/completions")
+ @self.router.post(f"/{self.provider_route_name}/completions")
+ async def create_completion(
+ request: Request,
+ authorization: str = Header(..., description="Bearer token"),
+ ):
+ if not authorization.startswith("Bearer "):
+ raise HTTPException(status_code=401, detail="Invalid authorization header")
+
+ api_key = authorization.split(" ")[1]
+ body = await request.body()
+ data = json.loads(body)
+
+ # Add the vLLM base URL to the request
+ config = Config.get_config()
+ data["base_url"] = config.provider_urls.get("vllm")
+
+ is_fim_request = self._is_fim_request(request, data)
+ stream = await self.complete(data, api_key, is_fim_request=is_fim_request)
+ return self._completion_handler.create_streaming_response(stream)
diff --git a/src/codegate/server.py b/src/codegate/server.py
index 14ccb60a..046ed60e 100644
--- a/src/codegate/server.py
+++ b/src/codegate/server.py
@@ -9,6 +9,7 @@
from codegate.providers.llamacpp.provider import LlamaCppProvider
from codegate.providers.openai.provider import OpenAIProvider
from codegate.providers.registry import ProviderRegistry
+from codegate.providers.vllm.provider import VLLMProvider
def init_app() -> FastAPI:
@@ -22,8 +23,7 @@ def init_app() -> FastAPI:
CodegateVersion(),
]
# Leaving the pipeline empty for now
- fim_steps: List[PipelineStep] = [
- ]
+ fim_steps: List[PipelineStep] = []
pipeline = SequentialPipelineProcessor(steps)
fim_pipeline = SequentialPipelineProcessor(fim_steps)
@@ -31,18 +31,20 @@ def init_app() -> FastAPI:
registry = ProviderRegistry(app)
# Register all known providers
- registry.add_provider("openai", OpenAIProvider(
- pipeline_processor=pipeline,
- fim_pipeline_processor=fim_pipeline
- ))
- registry.add_provider("anthropic", AnthropicProvider(
- pipeline_processor=pipeline,
- fim_pipeline_processor=fim_pipeline
- ))
- registry.add_provider("llamacpp", LlamaCppProvider(
- pipeline_processor=pipeline,
- fim_pipeline_processor=fim_pipeline
- ))
+ registry.add_provider(
+ "openai", OpenAIProvider(pipeline_processor=pipeline, fim_pipeline_processor=fim_pipeline)
+ )
+ registry.add_provider(
+ "anthropic",
+ AnthropicProvider(pipeline_processor=pipeline, fim_pipeline_processor=fim_pipeline),
+ )
+ registry.add_provider(
+ "llamacpp",
+ LlamaCppProvider(pipeline_processor=pipeline, fim_pipeline_processor=fim_pipeline),
+ )
+ registry.add_provider(
+ "vllm", VLLMProvider(pipeline_processor=pipeline, fim_pipeline_processor=fim_pipeline)
+ )
# Create and add system routes
system_router = APIRouter(tags=["System"]) # Tags group endpoints in the docs
diff --git a/tests/test_cli.py b/tests/test_cli.py
index 2329e788..64f8a53e 100644
--- a/tests/test_cli.py
+++ b/tests/test_cli.py
@@ -9,6 +9,7 @@
from codegate.cli import cli
from codegate.codegate_logging import LogFormat, LogLevel
+from codegate.config import DEFAULT_PROVIDER_URLS
@pytest.fixture
@@ -63,6 +64,7 @@ def test_serve_default_options(cli_runner: CliRunner, mock_logging: Any) -> None
"log_level": "INFO",
"log_format": "JSON",
"prompts_loaded": 7, # Default prompts are loaded
+ "provider_urls": DEFAULT_PROVIDER_URLS,
},
)
mock_run.assert_called_once()
@@ -98,6 +100,7 @@ def test_serve_custom_options(cli_runner: CliRunner, mock_logging: Any) -> None:
"log_level": "DEBUG",
"log_format": "TEXT",
"prompts_loaded": 7, # Default prompts are loaded
+ "provider_urls": DEFAULT_PROVIDER_URLS,
},
)
mock_run.assert_called_once()
@@ -136,6 +139,7 @@ def test_serve_with_config_file(
"log_level": "DEBUG",
"log_format": "JSON",
"prompts_loaded": 7, # Default prompts are loaded
+ "provider_urls": DEFAULT_PROVIDER_URLS,
},
)
mock_run.assert_called_once()
@@ -182,6 +186,7 @@ def test_serve_priority_resolution(
"log_level": "ERROR",
"log_format": "TEXT",
"prompts_loaded": 7, # Default prompts are loaded
+ "provider_urls": DEFAULT_PROVIDER_URLS,
},
)
mock_run.assert_called_once()
diff --git a/tests/test_config.py b/tests/test_config.py
index 68e387fa..3f6ab585 100644
--- a/tests/test_config.py
+++ b/tests/test_config.py
@@ -6,7 +6,7 @@
import pytest
import yaml
-from codegate.config import Config, ConfigurationError, LogFormat, LogLevel
+from codegate.config import DEFAULT_PROVIDER_URLS, Config, ConfigurationError, LogFormat, LogLevel
def test_default_config(default_config: Config) -> None:
@@ -15,6 +15,7 @@ def test_default_config(default_config: Config) -> None:
assert default_config.host == "localhost"
assert default_config.log_level == LogLevel.INFO
assert default_config.log_format == LogFormat.JSON
+ assert default_config.provider_urls == DEFAULT_PROVIDER_URLS
def test_config_from_file(temp_config_file: Path) -> None:
@@ -24,6 +25,7 @@ def test_config_from_file(temp_config_file: Path) -> None:
assert config.host == "localhost"
assert config.log_level == LogLevel.DEBUG
assert config.log_format == LogFormat.JSON
+ assert config.provider_urls == DEFAULT_PROVIDER_URLS
def test_config_from_invalid_file(tmp_path: Path) -> None:
@@ -49,6 +51,7 @@ def test_config_from_env(env_vars: None) -> None:
assert config.host == "localhost"
assert config.log_level == LogLevel.WARNING
assert config.log_format == LogFormat.TEXT
+ assert config.provider_urls == DEFAULT_PROVIDER_URLS
def test_config_priority_resolution(temp_config_file: Path, env_vars: None) -> None:
@@ -60,11 +63,13 @@ def test_config_priority_resolution(temp_config_file: Path, env_vars: None) -> N
cli_host="example.com",
cli_log_level="WARNING",
cli_log_format="TEXT",
+ cli_provider_urls={"vllm": "https://custom.vllm.server"},
)
assert config.port == 8080
assert config.host == "example.com"
assert config.log_level == LogLevel.WARNING
assert config.log_format == LogFormat.TEXT
+ assert config.provider_urls["vllm"] == "https://custom.vllm.server"
# Env vars should override config file
config = Config.load(config_path=temp_config_file)
@@ -72,6 +77,7 @@ def test_config_priority_resolution(temp_config_file: Path, env_vars: None) -> N
assert config.host == "localhost" # from env
assert config.log_level == LogLevel.WARNING # from env
assert config.log_format == LogFormat.TEXT # from env
+ assert config.provider_urls == DEFAULT_PROVIDER_URLS # no env override
# Config file should override defaults
os.environ.clear() # Remove env vars
@@ -80,6 +86,34 @@ def test_config_priority_resolution(temp_config_file: Path, env_vars: None) -> N
assert config.host == "localhost" # from file
assert config.log_level == LogLevel.DEBUG # from file
assert config.log_format == LogFormat.JSON # from file
+ assert config.provider_urls == DEFAULT_PROVIDER_URLS # default values
+
+
+def test_provider_urls_from_config(tmp_path: Path) -> None:
+ """Test loading provider URLs from config file."""
+ config_file = tmp_path / "config.yaml"
+ custom_urls = {
+ "vllm": "https://custom.vllm.server",
+ "openai": "https://custom.openai.server",
+ }
+ with open(config_file, "w") as f:
+ yaml.dump({"provider_urls": custom_urls}, f)
+
+ config = Config.from_file(config_file)
+ assert config.provider_urls["vllm"] == custom_urls["vllm"]
+ assert config.provider_urls["openai"] == custom_urls["openai"]
+ assert config.provider_urls["anthropic"] == DEFAULT_PROVIDER_URLS["anthropic"]
+
+
+def test_provider_urls_from_env() -> None:
+ """Test loading provider URLs from environment variables."""
+ os.environ["CODEGATE_PROVIDER_VLLM_URL"] = "https://custom.vllm.server"
+ try:
+ config = Config.from_env()
+ assert config.provider_urls["vllm"] == "https://custom.vllm.server"
+ assert config.provider_urls["openai"] == DEFAULT_PROVIDER_URLS["openai"]
+ finally:
+ del os.environ["CODEGATE_PROVIDER_VLLM_URL"]
def test_invalid_log_level() -> None:
diff --git a/tests/test_inference.py b/tests/test_inference.py
index 86c2eeea..9cfe5d14 100644
--- a/tests/test_inference.py
+++ b/tests/test_inference.py
@@ -51,7 +51,7 @@ async def test_chat(inference_engine) -> None:
response = await inference_engine.chat(model_path, **chat_request)
for chunk in response:
- assert 'delta' in chunk["choices"][0]
+ assert "delta" in chunk["choices"][0]
@pytest.mark.asyncio
diff --git a/tests/test_provider.py b/tests/test_provider.py
index 0957d618..f2c4011f 100644
--- a/tests/test_provider.py
+++ b/tests/test_provider.py
@@ -14,26 +14,29 @@ def __init__(self):
mocked_pipepeline = MagicMock()
mocked_fim_pipeline = MagicMock()
super().__init__(
- mocked_input_normalizer,
- mocked_output_normalizer,
- mocked_completion_handler,
- mocked_pipepeline,
- mocked_fim_pipeline
- )
+ mocked_input_normalizer,
+ mocked_output_normalizer,
+ mocked_completion_handler,
+ mocked_pipepeline,
+ mocked_fim_pipeline,
+ )
def _setup_routes(self) -> None:
pass
@property
def provider_route_name(self) -> str:
- return 'mock-provider'
+ return "mock-provider"
-@pytest.mark.parametrize("url, expected_bool", [
- ("http://example.com", False),
- ("http://test.com/chat/completions", False),
- ("http://example.com/completions", True),
-])
+@pytest.mark.parametrize(
+ "url, expected_bool",
+ [
+ ("http://example.com", False),
+ ("http://test.com/chat/completions", False),
+ ("http://example.com/completions", True),
+ ],
+)
def test_is_fim_request_url(url, expected_bool):
mock_provider = MockProvider()
request = MagicMock()
@@ -45,7 +48,7 @@ def test_is_fim_request_url(url, expected_bool):
"messages": [
{
"role": "user",
- "content": ' ',
+ "content": " ",
}
]
}
@@ -53,12 +56,7 @@ def test_is_fim_request_url(url, expected_bool):
"messages": [
{
"role": "user",
- "content": [
- {
- "type": "text",
- "text": " "
- }
- ],
+ "content": [{"type": "text", "text": " "}],
}
]
}
@@ -72,11 +70,14 @@ def test_is_fim_request_url(url, expected_bool):
}
-@pytest.mark.parametrize("data, expected_bool", [
- (DATA_CONTENT_STR, True),
- (DATA_CONTENT_LIST, True),
- (INVALID_DATA_CONTET, False),
-])
+@pytest.mark.parametrize(
+ "data, expected_bool",
+ [
+ (DATA_CONTENT_STR, True),
+ (DATA_CONTENT_LIST, True),
+ (INVALID_DATA_CONTET, False),
+ ],
+)
def test_is_fim_request_body(data, expected_bool):
mock_provider = MockProvider()
assert mock_provider._is_fim_request_body(data) == expected_bool