diff --git a/src/codegate/pipeline/fim/secret_analyzer.py b/src/codegate/pipeline/fim/secret_analyzer.py
new file mode 100644
index 00000000..83f89a11
--- /dev/null
+++ b/src/codegate/pipeline/fim/secret_analyzer.py
@@ -0,0 +1,52 @@
+from litellm import ChatCompletionRequest
+
+from codegate.codegate_logging import setup_logging
+from codegate.pipeline.base import PipelineContext, PipelineResponse, PipelineResult, PipelineStep
+
+logger = setup_logging()
+
+
+class SecretAnalyzer(PipelineStep):
+ """Pipeline step that handles analyzing secrets in FIM pipeline."""
+
+ message_blocked = """
+ ⚠️ CodeGate Security Warning! Analysis Report ⚠️
+ Potential leak of sensitive credentials blocked
+
+ Recommendations:
+ - Use environment variables for secrets
+ """
+
+ @property
+ def name(self) -> str:
+ """
+ Returns the name of this pipeline step.
+
+ Returns:
+ str: The identifier 'fim-secret-analyzer'
+ """
+ return "fim-secret-analyzer"
+
+ async def process(
+ self,
+ request: ChatCompletionRequest,
+ context: PipelineContext
+ ) -> PipelineResult:
+ # The secrets-blocking module should be called here to check whether the request
+ # messages contain secrets, e.g.:
+ # messages_contain_secrets = [analyze_msg_secrets(msg) for msg in request.messages]
+ # message_with_secrets = any(messages_contain_secrets)
+
+ # For the moment, hardcode the result so the rest of the pipeline always runs;
+ # flip it to True to exercise the short-circuit path below.
+ message_with_secrets = False
+ if message_with_secrets:
+ logger.info('Blocking message with secrets.')
+ return PipelineResult(
+ response=PipelineResponse(
+ step_name=self.name,
+ content=self.message_blocked,
+ model=request["model"],
+ ),
+ )
+
+ # No messages with secrets, execute the rest of the pipeline
+ return PipelineResult(request=request)
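For context on the commented-out hook above, here is a minimal sketch of what an `analyze_msg_secrets` helper could look like once the secrets-blocking module exists. The function name comes from the comment in `process`; the regex patterns are illustrative assumptions, not the real module's detection logic.

```python
import re
from typing import Any, Dict

# Illustrative patterns only; the actual secrets-blocking module will ship its own rules.
_SECRET_PATTERNS = [
    re.compile(r"AKIA[0-9A-Z]{16}"),                                # AWS access key ID
    re.compile(r"ghp_[A-Za-z0-9]{36}"),                             # GitHub personal access token
    re.compile(r"-----BEGIN (?:RSA|EC|OPENSSH) PRIVATE KEY-----"),  # private key blocks
]


def analyze_msg_secrets(msg: Dict[str, Any]) -> bool:
    """Return True if the message content appears to contain a secret."""
    content = msg.get("content", "")
    if isinstance(content, list):
        # Content may be a list of parts, e.g. [{"type": "text", "text": ...}].
        content = " ".join(part.get("text", "") for part in content)
    return any(pattern.search(content) for pattern in _SECRET_PATTERNS)
```

With a helper like this in place, the two commented-out lines in `process` would work as written.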
diff --git a/src/codegate/providers/anthropic/completion_handler.py b/src/codegate/providers/anthropic/completion_handler.py
new file mode 100644
index 00000000..ed45eb40
--- /dev/null
+++ b/src/codegate/providers/anthropic/completion_handler.py
@@ -0,0 +1,33 @@
+from typing import AsyncIterator, Optional, Union
+
+from litellm import ChatCompletionRequest, ModelResponse
+
+from codegate.providers.litellmshim import LiteLLmShim
+
+
+class AnthropicCompletion(LiteLLmShim):
+ """
+ AnthropicCompletion used by the Anthropic provider to execute completions
+ """
+
+ async def execute_completion(
+ self,
+ request: ChatCompletionRequest,
+ api_key: Optional[str],
+ stream: bool = False,
+ ) -> Union[ModelResponse, AsyncIterator[ModelResponse]]:
+ """
+ Ensures the model name is prefixed with 'anthropic/' to explicitly route to Anthropic's API.
+
+ LiteLLM automatically maps most model names, but prepending 'anthropic/' forces the request
+ to Anthropic. This avoids issues with unrecognized names like 'claude-3-5-sonnet-latest',
+ which LiteLLM doesn't accept as a valid Anthropic model. This safeguard may be unnecessary
+ but ensures compatibility.
+
+ For more details, refer to the
+ [LiteLLM Documentation](https://docs.litellm.ai/docs/providers/anthropic).
+ """
+ model_in_request = request['model']
+ if not model_in_request.startswith('anthropic/'):
+ request['model'] = f'anthropic/{model_in_request}'
+ return await super().execute_completion(request, api_key, stream)
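A quick standalone illustration of the prefixing rule described in the docstring; the request dicts here are hand-built stand-ins for litellm's `ChatCompletionRequest`.

```python
def ensure_anthropic_prefix(request: dict) -> dict:
    """Same rule as execute_completion above, extracted for illustration."""
    model = request["model"]
    if not model.startswith("anthropic/"):
        request["model"] = f"anthropic/{model}"
    return request


# The prefix is only added when missing, so applying the rule twice is harmless.
assert ensure_anthropic_prefix({"model": "claude-3-5-sonnet-latest"})["model"] \
    == "anthropic/claude-3-5-sonnet-latest"
assert ensure_anthropic_prefix({"model": "anthropic/claude-3-opus"})["model"] \
    == "anthropic/claude-3-opus"
```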
diff --git a/src/codegate/providers/anthropic/provider.py b/src/codegate/providers/anthropic/provider.py
index a16c5921..e0042ace 100644
--- a/src/codegate/providers/anthropic/provider.py
+++ b/src/codegate/providers/anthropic/provider.py
@@ -1,20 +1,27 @@
import json
+from typing import Optional
from fastapi import Header, HTTPException, Request
from codegate.providers.anthropic.adapter import AnthropicInputNormalizer, AnthropicOutputNormalizer
-from codegate.providers.base import BaseProvider
-from codegate.providers.litellmshim import LiteLLmShim, anthropic_stream_generator
+from codegate.providers.anthropic.completion_handler import AnthropicCompletion
+from codegate.providers.base import BaseProvider, SequentialPipelineProcessor
+from codegate.providers.litellmshim import anthropic_stream_generator
class AnthropicProvider(BaseProvider):
- def __init__(self, pipeline_processor=None):
- completion_handler = LiteLLmShim(stream_generator=anthropic_stream_generator)
+ def __init__(
+ self,
+ pipeline_processor: Optional[SequentialPipelineProcessor] = None,
+ fim_pipeline_processor: Optional[SequentialPipelineProcessor] = None
+ ):
+ completion_handler = AnthropicCompletion(stream_generator=anthropic_stream_generator)
super().__init__(
AnthropicInputNormalizer(),
AnthropicOutputNormalizer(),
completion_handler,
pipeline_processor,
+ fim_pipeline_processor
)
@property
@@ -39,5 +46,6 @@ async def create_message(
body = await request.body()
data = json.loads(body)
- stream = await self.complete(data, x_api_key)
+ is_fim_request = self._is_fim_request(request, data)
+ stream = await self.complete(data, x_api_key, is_fim_request)
return self._completion_handler.create_streaming_response(stream)
diff --git a/src/codegate/providers/base.py b/src/codegate/providers/base.py
index 940d93b2..4f2d65e0 100644
--- a/src/codegate/providers/base.py
+++ b/src/codegate/providers/base.py
@@ -1,15 +1,17 @@
from abc import ABC, abstractmethod
from typing import Any, AsyncIterator, Callable, Dict, Optional, Union
-from fastapi import APIRouter
+from fastapi import APIRouter, Request
from litellm import ModelResponse
from litellm.types.llms.openai import ChatCompletionRequest
+from codegate.codegate_logging import setup_logging
from codegate.pipeline.base import PipelineResult, SequentialPipelineProcessor
from codegate.providers.completion.base import BaseCompletionHandler
from codegate.providers.formatting.input_pipeline import PipelineResponseFormatter
from codegate.providers.normalizer.base import ModelInputNormalizer, ModelOutputNormalizer
+logger = setup_logging()
StreamGenerator = Callable[[AsyncIterator[Any]], AsyncIterator[str]]
@@ -25,12 +27,14 @@ def __init__(
output_normalizer: ModelOutputNormalizer,
completion_handler: BaseCompletionHandler,
pipeline_processor: Optional[SequentialPipelineProcessor] = None,
+ fim_pipeline_processor: Optional[SequentialPipelineProcessor] = None,
):
self.router = APIRouter()
self._completion_handler = completion_handler
self._input_normalizer = input_normalizer
self._output_normalizer = output_normalizer
self._pipeline_processor = pipeline_processor
+ self._fim_pipeline_processor = fim_pipeline_processor
self._pipeline_response_formatter = PipelineResponseFormatter(output_normalizer)
@@ -48,11 +52,19 @@ def provider_route_name(self) -> str:
async def _run_input_pipeline(
self,
normalized_request: ChatCompletionRequest,
+ is_fim_request: bool
) -> PipelineResult:
- if self._pipeline_processor is None:
+ # Decide which pipeline processor to use
+ if is_fim_request:
+ pipeline_processor = self._fim_pipeline_processor
+ logger.info('FIM pipeline selected for execution.')
+ else:
+ pipeline_processor = self._pipeline_processor
+ logger.info('Chat completion pipeline selected for execution.')
+ if pipeline_processor is None:
return PipelineResult(request=normalized_request)
- result = await self._pipeline_processor.process_request(normalized_request)
+ result = await pipeline_processor.process_request(normalized_request)
# TODO(jakub): handle this by returning a message to the client
if result.error_message:
@@ -60,10 +72,56 @@ async def _run_input_pipeline(
return result
+ def _is_fim_request_url(self, request: Request) -> bool:
+ """
+ Checks the request URL to determine if a request is FIM or chat completion.
+ Used by: llama.cpp
+ """
+ request_path = request.url.path
+ # Check the longer suffix first: "/chat/completions" also ends with "/completions".
+ if request_path.endswith("/chat/completions"):
+ return False
+
+ if request_path.endswith("/completions"):
+ return True
+
+ return False
+
+ def _is_fim_request_body(self, data: Dict) -> bool:
+ """
+ Determine from the raw incoming data if it's a FIM request.
+ Used by: OpenAI and Anthropic
+ """
+ messages = data.get('messages', [])
+ if not messages:
+ return False
+
+ first_message_content = messages[0].get('content')
+ if first_message_content is None:
+ return False
+
+ # Marker tags that FIM prompt templates (e.g. Continue's hole-filler prompt) embed
+ # in the first message.
+ fim_stop_sequences = ['</COMPLETION>', '<COMPLETION>', '</QUERY>', '<QUERY>']
+ if isinstance(first_message_content, str):
+ msg_prompt = first_message_content
+ elif isinstance(first_message_content, list):
+ msg_prompt = first_message_content[0].get('text', '')
+ else:
+ logger.warning(f'Could not determine if message was FIM from data: {data}')
+ return False
+ return all([stop_sequence in msg_prompt for stop_sequence in fim_stop_sequences])
+
+ def _is_fim_request(self, request: Request, data: Dict) -> bool:
+ """
+ Determine whether a request is FIM based on its URL or its body.
+ """
+ # Avoid more expensive inspection of body by just checking the URL.
+ if self._is_fim_request_url(request):
+ return True
+
+ return self._is_fim_request_body(data)
+
async def complete(
- self,
- data: Dict,
- api_key: Optional[str],
+ self, data: Dict, api_key: Optional[str], is_fim_request: bool
) -> Union[ModelResponse, AsyncIterator[ModelResponse]]:
"""
Main completion flow with pipeline integration
@@ -78,8 +136,7 @@ async def complete(
"""
normalized_request = self._input_normalizer.normalize(data)
streaming = data.get("stream", False)
-
- input_pipeline_result = await self._run_input_pipeline(normalized_request)
+ input_pipeline_result = await self._run_input_pipeline(normalized_request, is_fim_request)
if input_pipeline_result.response:
return self._pipeline_response_formatter.handle_pipeline_response(
input_pipeline_result.response, streaming
@@ -93,7 +150,6 @@ async def complete(
model_response = await self._completion_handler.execute_completion(
provider_request, api_key=api_key, stream=streaming
)
-
if not streaming:
return self._output_normalizer.denormalize(model_response)
return self._output_normalizer.denormalize_streaming(model_response)
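To make the new routing concrete, here is a small self-contained sketch of the two heuristics added to `BaseProvider`, run against made-up payloads; the marker list mirrors `fim_stop_sequences` in `_is_fim_request_body`.

```python
FIM_STOP_SEQUENCES = ["</COMPLETION>", "<COMPLETION>", "</QUERY>", "<QUERY>"]


def is_fim_by_url(path: str) -> bool:
    # "/chat/completions" also ends with "/completions", so check the longer suffix first.
    if path.endswith("/chat/completions"):
        return False
    return path.endswith("/completions")


def is_fim_by_body(data: dict) -> bool:
    messages = data.get("messages", [])
    if not messages:
        return False
    content = messages[0].get("content")
    if isinstance(content, list):
        content = content[0].get("text", "")
    if not isinstance(content, str):
        return False
    return all(seq in content for seq in FIM_STOP_SEQUENCES)


# llama.cpp clients call /completions directly, so the URL alone identifies FIM.
assert is_fim_by_url("/llamacpp/completions") is True
assert is_fim_by_url("/openai/chat/completions") is False

# OpenAI/Anthropic clients use /chat/completions for both, so FIM is detected
# from the marker tags in the first message instead.
fim_prompt = "<QUERY>def add(a, b):</QUERY> <COMPLETION></COMPLETION>"
assert is_fim_by_body({"messages": [{"role": "user", "content": fim_prompt}]}) is True
assert is_fim_by_body({"messages": [{"role": "user", "content": "hello"}]}) is False
```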
diff --git a/src/codegate/providers/llamacpp/normalizer.py b/src/codegate/providers/llamacpp/normalizer.py
index 38a97a79..8a0ae077 100644
--- a/src/codegate/providers/llamacpp/normalizer.py
+++ b/src/codegate/providers/llamacpp/normalizer.py
@@ -10,6 +10,11 @@ def normalize(self, data: Dict) -> ChatCompletionRequest:
"""
Normalize the input data
"""
+ # FIM requests carry "prompt" instead of "messages"; convert it to the chat shape.
+ if "prompt" in data:
+     data["messages"] = [{"content": data.pop("prompt"), "role": "user"}]
+     # ChatCompletionRequest is not strict, so we can stash an extra key in data
+     # to remember that the original request used "prompt".
+     data["had_prompt_before"] = True
try:
return ChatCompletionRequest(**data)
except Exception as e:
@@ -19,6 +24,11 @@ def denormalize(self, data: ChatCompletionRequest) -> Dict:
"""
Denormalize the input data
"""
+ # If we receive "prompt" in FIM, we need convert it back.
+ if data.get("had_prompt_before", False):
+ data["prompt"] = data["messages"][0]["content"]
+ del data["had_prompt_before"]
+ del data["messages"]
return data
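The prompt↔messages conversion above round-trips cleanly. A minimal sketch using plain dicts (not the actual normalizer classes) to show the intended behaviour:

```python
def normalize_fim(data: dict) -> dict:
    """FIM requests carry a prompt; rewrite it as a single user message."""
    if "prompt" in data:
        data["messages"] = [{"content": data.pop("prompt"), "role": "user"}]
        data["had_prompt_before"] = True
    return data


def denormalize_fim(data: dict) -> dict:
    """Restore the original prompt shape before sending the request to llama.cpp."""
    if data.get("had_prompt_before", False):
        data["prompt"] = data["messages"][0]["content"]
        del data["had_prompt_before"]
        del data["messages"]
    return data


fim_request = {"prompt": "def fib(n):", "stream": True}
assert denormalize_fim(normalize_fim(dict(fim_request))) == fim_request

# Chat requests already have "messages" and pass through untouched.
chat_request = {"messages": [{"role": "user", "content": "hi"}], "stream": True}
assert normalize_fim(dict(chat_request)) == chat_request
```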
diff --git a/src/codegate/providers/llamacpp/provider.py b/src/codegate/providers/llamacpp/provider.py
index 26291cdc..eddfa901 100644
--- a/src/codegate/providers/llamacpp/provider.py
+++ b/src/codegate/providers/llamacpp/provider.py
@@ -1,20 +1,26 @@
import json
+from typing import Optional
from fastapi import Request
-from codegate.providers.base import BaseProvider
+from codegate.providers.base import BaseProvider, SequentialPipelineProcessor
from codegate.providers.llamacpp.completion_handler import LlamaCppCompletionHandler
from codegate.providers.llamacpp.normalizer import LLamaCppInputNormalizer, LLamaCppOutputNormalizer
class LlamaCppProvider(BaseProvider):
- def __init__(self, pipeline_processor=None):
+ def __init__(
+ self,
+ pipeline_processor: Optional[SequentialPipelineProcessor] = None,
+ fim_pipeline_processor: Optional[SequentialPipelineProcessor] = None
+ ):
completion_handler = LlamaCppCompletionHandler()
super().__init__(
LLamaCppInputNormalizer(),
LLamaCppOutputNormalizer(),
completion_handler,
pipeline_processor,
+ fim_pipeline_processor
)
@property
@@ -34,5 +40,6 @@ async def create_completion(
body = await request.body()
data = json.loads(body)
- stream = await self.complete(data, api_key=None)
+ is_fim_request = self._is_fim_request(request, data)
+ stream = await self.complete(data, None, is_fim_request=is_fim_request)
return self._completion_handler.create_streaming_response(stream)
diff --git a/src/codegate/providers/openai/provider.py b/src/codegate/providers/openai/provider.py
index 6d1e6c1d..209118b5 100644
--- a/src/codegate/providers/openai/provider.py
+++ b/src/codegate/providers/openai/provider.py
@@ -1,20 +1,26 @@
import json
+from typing import Optional
from fastapi import Header, HTTPException, Request
-from codegate.providers.base import BaseProvider
+from codegate.providers.base import BaseProvider, SequentialPipelineProcessor
from codegate.providers.litellmshim import LiteLLmShim, sse_stream_generator
from codegate.providers.openai.adapter import OpenAIInputNormalizer, OpenAIOutputNormalizer
class OpenAIProvider(BaseProvider):
- def __init__(self, pipeline_processor=None):
+ def __init__(
+ self,
+ pipeline_processor: Optional[SequentialPipelineProcessor] = None,
+ fim_pipeline_processor: Optional[SequentialPipelineProcessor] = None
+ ):
completion_handler = LiteLLmShim(stream_generator=sse_stream_generator)
super().__init__(
OpenAIInputNormalizer(),
OpenAIOutputNormalizer(),
completion_handler,
pipeline_processor,
+ fim_pipeline_processor
)
@property
@@ -29,6 +35,7 @@ def _setup_routes(self):
"""
@self.router.post(f"/{self.provider_route_name}/chat/completions")
+ @self.router.post(f"/{self.provider_route_name}/completions")
async def create_completion(
request: Request,
authorization: str = Header(..., description="Bearer token"),
@@ -40,5 +47,6 @@ async def create_completion(
body = await request.body()
data = json.loads(body)
- stream = await self.complete(data, api_key)
+ is_fim_request = self._is_fim_request(request, data)
+ stream = await self.complete(data, api_key, is_fim_request=is_fim_request)
return self._completion_handler.create_streaming_response(stream)
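With the extra route in place, a FIM request and a chat request hit the same provider but run through different pipelines. A hedged usage sketch with httpx; the port, API key, and the exact FIM prompt are assumptions for illustration, and the `/openai/...` paths follow the `provider_route_name` convention used above.

```python
import httpx

BASE_URL = "http://localhost:8989"                     # assumed local codegate address
HEADERS = {"Authorization": "Bearer sk-placeholder"}   # placeholder API key

# FIM-style request sent to the newly exposed completions route.
fim_payload = {
    "model": "gpt-4o-mini",
    "messages": [{
        "role": "user",
        "content": "<QUERY>def add(a, b):</QUERY> <COMPLETION></COMPLETION>",
    }],
    "stream": True,
}

# Regular chat request, routed through the chat-completions pipeline.
chat_payload = {
    "model": "gpt-4o-mini",
    "messages": [{"role": "user", "content": "Explain fill-in-the-middle in one sentence."}],
    "stream": True,
}

with httpx.Client(base_url=BASE_URL, headers=HEADERS, timeout=60) as client:
    fim_resp = client.post("/openai/completions", json=fim_payload)
    chat_resp = client.post("/openai/chat/completions", json=chat_payload)
    print(fim_resp.status_code, chat_resp.status_code)
```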
diff --git a/src/codegate/server.py b/src/codegate/server.py
index 359425a2..14ccb60a 100644
--- a/src/codegate/server.py
+++ b/src/codegate/server.py
@@ -21,15 +21,28 @@ def init_app() -> FastAPI:
steps: List[PipelineStep] = [
CodegateVersion(),
]
-
+ # The FIM pipeline is intentionally left empty for now; FIM-specific steps
+ # (e.g. SecretAnalyzer) can be added here once they are wired up.
+ fim_steps: List[PipelineStep] = []
pipeline = SequentialPipelineProcessor(steps)
+ fim_pipeline = SequentialPipelineProcessor(fim_steps)
+
# Create provider registry
registry = ProviderRegistry(app)
# Register all known providers
- registry.add_provider("openai", OpenAIProvider(pipeline_processor=pipeline))
- registry.add_provider("anthropic", AnthropicProvider(pipeline_processor=pipeline))
- registry.add_provider("llamacpp", LlamaCppProvider(pipeline_processor=pipeline))
+ registry.add_provider("openai", OpenAIProvider(
+ pipeline_processor=pipeline,
+ fim_pipeline_processor=fim_pipeline
+ ))
+ registry.add_provider("anthropic", AnthropicProvider(
+ pipeline_processor=pipeline,
+ fim_pipeline_processor=fim_pipeline
+ ))
+ registry.add_provider("llamacpp", LlamaCppProvider(
+ pipeline_processor=pipeline,
+ fim_pipeline_processor=fim_pipeline
+ ))
# Create and add system routes
system_router = APIRouter(tags=["System"]) # Tags group endpoints in the docs
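When the FIM pipeline gains its first step, the wiring stays the same. A sketch of enabling the SecretAnalyzer added in this PR (assuming it is ready to run) inside `init_app`:

```python
from typing import List

from codegate.pipeline.base import PipelineStep, SequentialPipelineProcessor
from codegate.pipeline.fim.secret_analyzer import SecretAnalyzer

fim_steps: List[PipelineStep] = [
    SecretAnalyzer(),
]
fim_pipeline = SequentialPipelineProcessor(fim_steps)
```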
diff --git a/tests/test_provider.py b/tests/test_provider.py
new file mode 100644
index 00000000..0957d618
--- /dev/null
+++ b/tests/test_provider.py
@@ -0,0 +1,82 @@
+from unittest.mock import MagicMock
+
+import pytest
+
+from codegate.providers.base import BaseProvider
+
+
+class MockProvider(BaseProvider):
+
+ def __init__(self):
+ mocked_input_normalizer = MagicMock()
+ mocked_output_normalizer = MagicMock()
+ mocked_completion_handler = MagicMock()
+ mocked_pipeline = MagicMock()
+ mocked_fim_pipeline = MagicMock()
+ super().__init__(
+ mocked_input_normalizer,
+ mocked_output_normalizer,
+ mocked_completion_handler,
+ mocked_pipeline,
+ mocked_fim_pipeline
+ )
+
+ def _setup_routes(self) -> None:
+ pass
+
+ @property
+ def provider_route_name(self) -> str:
+ return 'mock-provider'
+
+
+@pytest.mark.parametrize("url, expected_bool", [
+ ("http://example.com", False),
+ ("http://test.com/chat/completions", False),
+ ("http://example.com/completions", True),
+])
+def test_is_fim_request_url(url, expected_bool):
+ mock_provider = MockProvider()
+ request = MagicMock()
+ request.url.path = url
+ assert mock_provider._is_fim_request_url(request) == expected_bool
+
+
+DATA_CONTENT_STR = {
+ "messages": [
+ {
+ "role": "user",
+ "content": ' ',
+ }
+ ]
+}
+DATA_CONTENT_LIST = {
+ "messages": [
+ {
+ "role": "user",
+ "content": [
+ {
+ "type": "text",
+ "text": " "
+ }
+ ],
+ }
+ ]
+}
+INVALID_DATA_CONTENT = {
+ "messages": [
+ {
+ "role": "user",
+ "content": "http://example.com/completions",
+ }
+ ]
+}
+
+
+@pytest.mark.parametrize("data, expected_bool", [
+ (DATA_CONTENT_STR, True),
+ (DATA_CONTENT_LIST, True),
+ (INVALID_DATA_CONTENT, False),
+])
+def test_is_fim_request_body(data, expected_bool):
+ mock_provider = MockProvider()
+ assert mock_provider._is_fim_request_body(data) == expected_bool