diff --git a/src/codegate/pipeline/fim/secret_analyzer.py b/src/codegate/pipeline/fim/secret_analyzer.py
new file mode 100644
index 00000000..83f89a11
--- /dev/null
+++ b/src/codegate/pipeline/fim/secret_analyzer.py
@@ -0,0 +1,52 @@
+from litellm import ChatCompletionRequest
+
+from codegate.codegate_logging import setup_logging
+from codegate.pipeline.base import PipelineContext, PipelineResponse, PipelineResult, PipelineStep
+
+logger = setup_logging()
+
+
+class SecretAnalyzer(PipelineStep):
+    """Pipeline step that handles analyzing secrets in the FIM pipeline."""
+
+    message_blocked = """
+    ⚠️ CodeGate Security Warning! Analysis Report ⚠️
+    Potential leak of sensitive credentials blocked
+
+    Recommendations:
+    - Use environment variables for secrets
+    """
+
+    @property
+    def name(self) -> str:
+        """
+        Returns the name of this pipeline step.
+
+        Returns:
+            str: The identifier 'fim-secret-analyzer'
+        """
+        return "fim-secret-analyzer"
+
+    async def process(
+        self,
+        request: ChatCompletionRequest,
+        context: PipelineContext
+    ) -> PipelineResult:
+        # Here we should call the secrets-blocking module to see if the request messages contain secrets:
+        # messages_contain_secrets = [analyze_msg_secrets(msg) for msg in request.messages]
+        # message_with_secrets = any(messages_contain_secrets)
+
+        # Secrets analysis is not wired up yet, so for now treat every message as containing no secrets
+        message_with_secrets = False
+        if message_with_secrets:
+            logger.info('Blocking message with secrets.')
+            return PipelineResult(
+                response=PipelineResponse(
+                    step_name=self.name,
+                    content=self.message_blocked,
+                    model=request["model"],
+                ),
+            )
+
+        # No messages with secrets, execute the rest of the pipeline
+        return PipelineResult(request=request)
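Because the step above always falls through for now, a quick way to see the intended short-circuit behaviour is to call it directly. The following is a minimal sketch, not part of the diff, assuming `context` is not needed by this step (its `process` body never touches it) and that litellm's `ChatCompletionRequest` can be built like the TypedDict it is:

```python
# Hypothetical usage sketch for the SecretAnalyzer step above (not part of this change).
import asyncio

from litellm import ChatCompletionRequest

from codegate.pipeline.fim.secret_analyzer import SecretAnalyzer


async def main() -> None:
    request = ChatCompletionRequest(
        model="some-model",  # illustrative model name
        messages=[{"role": "user", "content": "def foo():"}],
    )
    # The step does not use its context argument today, so None is passed here (assumption).
    result = await SecretAnalyzer().process(request, context=None)
    if result.response is not None:
        # A response means the pipeline short-circuits and the warning goes back to the client.
        print("blocked:", result.response.content)
    else:
        print("no secrets detected, the rest of the pipeline runs")


asyncio.run(main())
```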
diff --git a/src/codegate/providers/anthropic/completion_handler.py b/src/codegate/providers/anthropic/completion_handler.py
new file mode 100644
index 00000000..ed45eb40
--- /dev/null
+++ b/src/codegate/providers/anthropic/completion_handler.py
@@ -0,0 +1,33 @@
+from typing import AsyncIterator, Optional, Union
+
+from litellm import ChatCompletionRequest, ModelResponse
+
+from codegate.providers.litellmshim import LiteLLmShim
+
+
+class AnthropicCompletion(LiteLLmShim):
+    """
+    AnthropicCompletion is used by the Anthropic provider to execute completions.
+    """
+
+    async def execute_completion(
+        self,
+        request: ChatCompletionRequest,
+        api_key: Optional[str],
+        stream: bool = False,
+    ) -> Union[ModelResponse, AsyncIterator[ModelResponse]]:
+        """
+        Ensures the model name is prefixed with 'anthropic/' to explicitly route to Anthropic's API.
+
+        LiteLLM automatically maps most model names, but prepending 'anthropic/' forces the request
+        to Anthropic. This avoids issues with unrecognized names like 'claude-3-5-sonnet-latest',
+        which LiteLLM doesn't accept as a valid Anthropic model. This safeguard may be unnecessary
+        but ensures compatibility.
+
+        For more details, refer to the
+        [LiteLLM Documentation](https://docs.litellm.ai/docs/providers/anthropic).
+        """
+        model_in_request = request['model']
+        if not model_in_request.startswith('anthropic/'):
+            request['model'] = f'anthropic/{model_in_request}'
+        return await super().execute_completion(request, api_key, stream)
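The prefixing rule above is self-contained and easy to check in isolation; the snippet below restates the same check outside the class (the model names are only illustrative):

```python
# Standalone restatement of the 'anthropic/' prefixing rule from AnthropicCompletion above.
def ensure_anthropic_prefix(model: str) -> str:
    # Only prepend the prefix when it is missing, exactly as execute_completion does.
    if not model.startswith("anthropic/"):
        return f"anthropic/{model}"
    return model


assert ensure_anthropic_prefix("claude-3-5-sonnet-latest") == "anthropic/claude-3-5-sonnet-latest"
assert ensure_anthropic_prefix("anthropic/claude-3-5-sonnet-latest") == "anthropic/claude-3-5-sonnet-latest"
```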
diff --git a/src/codegate/providers/anthropic/provider.py b/src/codegate/providers/anthropic/provider.py
index a16c5921..e0042ace 100644
--- a/src/codegate/providers/anthropic/provider.py
+++ b/src/codegate/providers/anthropic/provider.py
@@ -1,20 +1,27 @@
 import json
+from typing import Optional
 
 from fastapi import Header, HTTPException, Request
 
 from codegate.providers.anthropic.adapter import AnthropicInputNormalizer, AnthropicOutputNormalizer
-from codegate.providers.base import BaseProvider
-from codegate.providers.litellmshim import LiteLLmShim, anthropic_stream_generator
+from codegate.providers.anthropic.completion_handler import AnthropicCompletion
+from codegate.providers.base import BaseProvider, SequentialPipelineProcessor
+from codegate.providers.litellmshim import anthropic_stream_generator
 
 
 class AnthropicProvider(BaseProvider):
-    def __init__(self, pipeline_processor=None):
-        completion_handler = LiteLLmShim(stream_generator=anthropic_stream_generator)
+    def __init__(
+        self,
+        pipeline_processor: Optional[SequentialPipelineProcessor] = None,
+        fim_pipeline_processor: Optional[SequentialPipelineProcessor] = None
+    ):
+        completion_handler = AnthropicCompletion(stream_generator=anthropic_stream_generator)
         super().__init__(
             AnthropicInputNormalizer(),
             AnthropicOutputNormalizer(),
             completion_handler,
             pipeline_processor,
+            fim_pipeline_processor
         )
 
     @property
@@ -39,5 +46,6 @@ async def create_message(
         body = await request.body()
         data = json.loads(body)
 
-        stream = await self.complete(data, x_api_key)
+        is_fim_request = self._is_fim_request(request, data)
+        stream = await self.complete(data, x_api_key, is_fim_request)
         return self._completion_handler.create_streaming_response(stream)
diff --git a/src/codegate/providers/base.py b/src/codegate/providers/base.py
index 940d93b2..4f2d65e0 100644
--- a/src/codegate/providers/base.py
+++ b/src/codegate/providers/base.py
@@ -1,15 +1,17 @@
 from abc import ABC, abstractmethod
 from typing import Any, AsyncIterator, Callable, Dict, Optional, Union
 
-from fastapi import APIRouter
+from fastapi import APIRouter, Request
 from litellm import ModelResponse
 from litellm.types.llms.openai import ChatCompletionRequest
 
+from codegate.codegate_logging import setup_logging
 from codegate.pipeline.base import PipelineResult, SequentialPipelineProcessor
 from codegate.providers.completion.base import BaseCompletionHandler
 from codegate.providers.formatting.input_pipeline import PipelineResponseFormatter
 from codegate.providers.normalizer.base import ModelInputNormalizer, ModelOutputNormalizer
 
+logger = setup_logging()
 StreamGenerator = Callable[[AsyncIterator[Any]], AsyncIterator[str]]
@@ -25,12 +27,14 @@ def __init__(
         output_normalizer: ModelOutputNormalizer,
         completion_handler: BaseCompletionHandler,
         pipeline_processor: Optional[SequentialPipelineProcessor] = None,
+        fim_pipeline_processor: Optional[SequentialPipelineProcessor] = None,
    ):
         self.router = APIRouter()
         self._completion_handler = completion_handler
         self._input_normalizer = input_normalizer
         self._output_normalizer = output_normalizer
         self._pipeline_processor = pipeline_processor
+        self._fim_pipeline_processor = fim_pipeline_processor
 
         self._pipeline_response_formatter = PipelineResponseFormatter(output_normalizer)
@@ -48,11 +52,19 @@ def provider_route_name(self) -> str:
     async def _run_input_pipeline(
         self,
         normalized_request: ChatCompletionRequest,
+        is_fim_request: bool
     ) -> PipelineResult:
-        if self._pipeline_processor is None:
+        # Decide which pipeline processor to use
+        if is_fim_request:
+            pipeline_processor = self._fim_pipeline_processor
+            logger.info('FIM pipeline selected for execution.')
+        else:
+            pipeline_processor = self._pipeline_processor
+            logger.info('Chat completion pipeline selected for execution.')
+        if pipeline_processor is None:
             return PipelineResult(request=normalized_request)
 
-        result = await self._pipeline_processor.process_request(normalized_request)
+        result = await pipeline_processor.process_request(normalized_request)
 
         # TODO(jakub): handle this by returning a message to the client
         if result.error_message:
@@ -60,10 +72,56 @@ async def _run_input_pipeline(
 
         return result
 
+    def _is_fim_request_url(self, request: Request) -> bool:
+        """
+        Checks the request URL to determine whether a request is FIM or chat completion.
+        Used by: llama.cpp
+        """
+        request_path = request.url.path
+        # Check the longer suffix first, since "/completions" is also a suffix of "/chat/completions".
+        if request_path.endswith("/chat/completions"):
+            return False
+
+        if request_path.endswith("/completions"):
+            return True
+
+        return False
+
+    def _is_fim_request_body(self, data: Dict) -> bool:
+        """
+        Determine from the raw incoming data whether it's a FIM request.
+        Used by: OpenAI and Anthropic
+        """
+        messages = data.get('messages', [])
+        if not messages:
+            return False
+
+        first_message_content = messages[0].get('content')
+        if first_message_content is None:
+            return False
+
+        fim_stop_sequences = ['', '', '', '']
+        if isinstance(first_message_content, str):
+            msg_prompt = first_message_content
+        elif isinstance(first_message_content, list):
+            msg_prompt = first_message_content[0].get('text', '')
+        else:
+            logger.warning(f'Could not determine if message was FIM from data: {data}')
+            return False
+        return all([stop_sequence in msg_prompt for stop_sequence in fim_stop_sequences])
+
+    def _is_fim_request(self, request: Request, data: Dict) -> bool:
+        """
+        Determine whether the request is FIM from either the URL or the body of the request.
+        """
+        # Avoid the more expensive body inspection by checking the URL first.
+        if self._is_fim_request_url(request):
+            return True
+
+        return self._is_fim_request_body(data)
+
     async def complete(
-        self,
-        data: Dict,
-        api_key: Optional[str],
+        self, data: Dict, api_key: Optional[str], is_fim_request: bool
     ) -> Union[ModelResponse, AsyncIterator[ModelResponse]]:
         """
         Main completion flow with pipeline integration
@@ -78,8 +136,7 @@ async def complete(
         """
         normalized_request = self._input_normalizer.normalize(data)
         streaming = data.get("stream", False)
-
-        input_pipeline_result = await self._run_input_pipeline(normalized_request)
+        input_pipeline_result = await self._run_input_pipeline(normalized_request, is_fim_request)
         if input_pipeline_result.response:
             return self._pipeline_response_formatter.handle_pipeline_response(
                 input_pipeline_result.response, streaming
@@ -93,7 +150,6 @@ async def complete(
         model_response = await self._completion_handler.execute_completion(
             provider_request, api_key=api_key, stream=streaming
         )
-
         if not streaming:
             return self._output_normalizer.denormalize(model_response)
         return self._output_normalizer.denormalize_streaming(model_response)
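The ordering of the two suffix checks in `_is_fim_request_url` matters, because "/completions" is itself a suffix of "/chat/completions". A standalone sketch of the same rule, with illustrative paths:

```python
# Standalone restatement of the URL-based check from _is_fim_request_url above.
def looks_like_fim_path(path: str) -> bool:
    # Test the longer suffix first; otherwise every chat-completion path would
    # also match the plain "/completions" suffix and be misclassified as FIM.
    if path.endswith("/chat/completions"):
        return False
    if path.endswith("/completions"):
        return True
    return False


assert looks_like_fim_path("/llamacpp/completions") is True
assert looks_like_fim_path("/llamacpp/chat/completions") is False
assert looks_like_fim_path("/health") is False
```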
diff --git a/src/codegate/providers/llamacpp/normalizer.py b/src/codegate/providers/llamacpp/normalizer.py
index 38a97a79..8a0ae077 100644
--- a/src/codegate/providers/llamacpp/normalizer.py
+++ b/src/codegate/providers/llamacpp/normalizer.py
@@ -10,6 +10,11 @@ def normalize(self, data: Dict) -> ChatCompletionRequest:
         """
         Normalize the input data
         """
+        # When doing FIM, we receive "prompt" instead of messages. Normalizing.
+        if "prompt" in data:
+            data["messages"] = [{"content": data.pop("prompt"), "role": "user"}]
+            # We can add as many parameters as we like to data. ChatCompletionRequest is not strict.
+            data["had_prompt_before"] = True
         try:
             return ChatCompletionRequest(**data)
         except Exception as e:
@@ -19,6 +24,11 @@ def denormalize(self, data: ChatCompletionRequest) -> Dict:
         """
         Denormalize the input data
         """
+        # If we received "prompt" during FIM, we need to convert it back.
+        if data.get("had_prompt_before", False):
+            data["prompt"] = data["messages"][0]["content"]
+            del data["had_prompt_before"]
+            del data["messages"]
         return data
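The normalize/denormalize pair above is just a reversible reshaping of the payload. A plain-dict sketch of the round-trip, assuming `ChatCompletionRequest` behaves like a dict (which is how it is indexed here):

```python
# Plain-dict sketch of the FIM prompt <-> messages round-trip handled by the normalizer above.
fim_request = {"prompt": "def hello():", "stream": True}

# normalize(): "prompt" becomes a single user message and the conversion is remembered.
data = dict(fim_request)
data["messages"] = [{"content": data.pop("prompt"), "role": "user"}]
data["had_prompt_before"] = True

# denormalize(): the original llama.cpp shape is restored.
if data.pop("had_prompt_before", False):
    data["prompt"] = data.pop("messages")[0]["content"]

assert data == fim_request
```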
diff --git a/src/codegate/providers/llamacpp/provider.py b/src/codegate/providers/llamacpp/provider.py
index 26291cdc..eddfa901 100644
--- a/src/codegate/providers/llamacpp/provider.py
+++ b/src/codegate/providers/llamacpp/provider.py
@@ -1,20 +1,26 @@
 import json
+from typing import Optional
 
 from fastapi import Request
 
-from codegate.providers.base import BaseProvider
+from codegate.providers.base import BaseProvider, SequentialPipelineProcessor
 from codegate.providers.llamacpp.completion_handler import LlamaCppCompletionHandler
 from codegate.providers.llamacpp.normalizer import LLamaCppInputNormalizer, LLamaCppOutputNormalizer
 
 
 class LlamaCppProvider(BaseProvider):
-    def __init__(self, pipeline_processor=None):
+    def __init__(
+        self,
+        pipeline_processor: Optional[SequentialPipelineProcessor] = None,
+        fim_pipeline_processor: Optional[SequentialPipelineProcessor] = None
+    ):
         completion_handler = LlamaCppCompletionHandler()
         super().__init__(
             LLamaCppInputNormalizer(),
             LLamaCppOutputNormalizer(),
             completion_handler,
             pipeline_processor,
+            fim_pipeline_processor
         )
 
     @property
@@ -34,5 +40,6 @@ async def create_completion(
         body = await request.body()
         data = json.loads(body)
 
-        stream = await self.complete(data, api_key=None)
+        is_fim_request = self._is_fim_request(request, data)
+        stream = await self.complete(data, None, is_fim_request=is_fim_request)
         return self._completion_handler.create_streaming_response(stream)
diff --git a/src/codegate/providers/openai/provider.py b/src/codegate/providers/openai/provider.py
index 6d1e6c1d..209118b5 100644
--- a/src/codegate/providers/openai/provider.py
+++ b/src/codegate/providers/openai/provider.py
@@ -1,20 +1,26 @@
 import json
+from typing import Optional
 
 from fastapi import Header, HTTPException, Request
 
-from codegate.providers.base import BaseProvider
+from codegate.providers.base import BaseProvider, SequentialPipelineProcessor
 from codegate.providers.litellmshim import LiteLLmShim, sse_stream_generator
 from codegate.providers.openai.adapter import OpenAIInputNormalizer, OpenAIOutputNormalizer
 
 
 class OpenAIProvider(BaseProvider):
-    def __init__(self, pipeline_processor=None):
+    def __init__(
+        self,
+        pipeline_processor: Optional[SequentialPipelineProcessor] = None,
+        fim_pipeline_processor: Optional[SequentialPipelineProcessor] = None
+    ):
         completion_handler = LiteLLmShim(stream_generator=sse_stream_generator)
         super().__init__(
             OpenAIInputNormalizer(),
             OpenAIOutputNormalizer(),
             completion_handler,
             pipeline_processor,
+            fim_pipeline_processor
         )
 
     @property
@@ -29,6 +35,7 @@ def _setup_routes(self):
         """
 
         @self.router.post(f"/{self.provider_route_name}/chat/completions")
+        @self.router.post(f"/{self.provider_route_name}/completions")
         async def create_completion(
             request: Request,
             authorization: str = Header(..., description="Bearer token"),
@@ -40,5 +47,6 @@ async def create_completion(
         body = await request.body()
         data = json.loads(body)
 
-        stream = await self.complete(data, api_key)
+        is_fim_request = self._is_fim_request(request, data)
+        stream = await self.complete(data, api_key, is_fim_request=is_fim_request)
         return self._completion_handler.create_streaming_response(stream)
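Stacking the two route decorators in the OpenAI provider registers the same handler under both paths, so FIM requests sent to `/completions` and chat requests sent to `/chat/completions` land in the same place and `_is_fim_request` does the disambiguation. A minimal FastAPI sketch of that pattern (the route prefix is illustrative):

```python
# Minimal sketch of the double route registration used by OpenAIProvider above.
from fastapi import APIRouter, FastAPI, Request

router = APIRouter()


@router.post("/openai/chat/completions")
@router.post("/openai/completions")
async def create_completion(request: Request) -> dict:
    # Both routes reach this handler; the matched path is what a later
    # _is_fim_request_url-style check would inspect.
    return {"path": request.url.path}


app = FastAPI()
app.include_router(router)
```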
diff --git a/src/codegate/server.py b/src/codegate/server.py
index 359425a2..14ccb60a 100644
--- a/src/codegate/server.py
+++ b/src/codegate/server.py
@@ -21,15 +21,28 @@ def init_app() -> FastAPI:
     steps: List[PipelineStep] = [
         CodegateVersion(),
     ]
-
+    # Leaving the FIM pipeline empty for now
+    fim_steps: List[PipelineStep] = [
+    ]
     pipeline = SequentialPipelineProcessor(steps)
+    fim_pipeline = SequentialPipelineProcessor(fim_steps)
+
     # Create provider registry
     registry = ProviderRegistry(app)
 
     # Register all known providers
-    registry.add_provider("openai", OpenAIProvider(pipeline_processor=pipeline))
-    registry.add_provider("anthropic", AnthropicProvider(pipeline_processor=pipeline))
-    registry.add_provider("llamacpp", LlamaCppProvider(pipeline_processor=pipeline))
+    registry.add_provider("openai", OpenAIProvider(
+        pipeline_processor=pipeline,
+        fim_pipeline_processor=fim_pipeline
+    ))
+    registry.add_provider("anthropic", AnthropicProvider(
+        pipeline_processor=pipeline,
+        fim_pipeline_processor=fim_pipeline
+    ))
+    registry.add_provider("llamacpp", LlamaCppProvider(
+        pipeline_processor=pipeline,
+        fim_pipeline_processor=fim_pipeline
+    ))
 
     # Create and add system routes
     system_router = APIRouter(tags=["System"])  # Tags group endpoints in the docs
diff --git a/tests/test_provider.py b/tests/test_provider.py
new file mode 100644
index 00000000..0957d618
--- /dev/null
+++ b/tests/test_provider.py
@@ -0,0 +1,82 @@
+from unittest.mock import MagicMock
+
+import pytest
+
+from codegate.providers.base import BaseProvider
+
+
+class MockProvider(BaseProvider):
+
+    def __init__(self):
+        mocked_input_normalizer = MagicMock()
+        mocked_output_normalizer = MagicMock()
+        mocked_completion_handler = MagicMock()
+        mocked_pipeline = MagicMock()
+        mocked_fim_pipeline = MagicMock()
+        super().__init__(
+            mocked_input_normalizer,
+            mocked_output_normalizer,
+            mocked_completion_handler,
+            mocked_pipeline,
+            mocked_fim_pipeline
+        )
+
+    def _setup_routes(self) -> None:
+        pass
+
+    @property
+    def provider_route_name(self) -> str:
+        return 'mock-provider'
+
+
+@pytest.mark.parametrize("url, expected_bool", [
+    ("http://example.com", False),
+    ("http://test.com/chat/completions", False),
+    ("http://example.com/completions", True),
+])
+def test_is_fim_request_url(url, expected_bool):
+    mock_provider = MockProvider()
+    request = MagicMock()
+    request.url.path = url
+    assert mock_provider._is_fim_request_url(request) == expected_bool
+
+
+DATA_CONTENT_STR = {
+    "messages": [
+        {
+            "role": "user",
+            "content": ' ',
+        }
+    ]
+}
+DATA_CONTENT_LIST = {
+    "messages": [
+        {
+            "role": "user",
+            "content": [
+                {
+                    "type": "text",
+                    "text": " "
+                }
+            ],
+        }
+    ]
+}
+INVALID_DATA_CONTENT = {
+    "messages": [
+        {
+            "role": "user",
+            "content": "http://example.com/completions",
+        }
+    ]
+}
+
+
+@pytest.mark.parametrize("data, expected_bool", [
+    (DATA_CONTENT_STR, True),
+    (DATA_CONTENT_LIST, True),
+    (INVALID_DATA_CONTENT, False),
+])
+def test_is_fim_request_body(data, expected_bool):
+    mock_provider = MockProvider()
+    assert mock_provider._is_fim_request_body(data) == expected_bool
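The FIM pipeline registered in `src/codegate/server.py` is intentionally empty for now. Once the secrets analysis in `SecretAnalyzer` is implemented, wiring it in would presumably mirror the existing chat pipeline setup; a sketch under that assumption, not part of this change:

```python
# Hypothetical follow-up (not part of this diff): populate the FIM pipeline
# with the SecretAnalyzer step once its secrets detection is implemented.
from typing import List

from codegate.pipeline.base import PipelineStep, SequentialPipelineProcessor
from codegate.pipeline.fim.secret_analyzer import SecretAnalyzer

fim_steps: List[PipelineStep] = [
    SecretAnalyzer(),
]
fim_pipeline = SequentialPipelineProcessor(fim_steps)
```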