From 09060f1db89e122ff0f7b2f2c1bf0c64af67af99 Mon Sep 17 00:00:00 2001 From: Jakub Hrozek Date: Tue, 26 Nov 2024 13:09:38 +0100 Subject: [PATCH 1/9] Add input processing pipeline + codegate-version pipeline step This adds a pipeline processing before the completion is ran where the request is either change or can be shortcut. This pipeline consists of steps, for now we implement a single step `CodegateVersion` that responds with the codegate version if the verbatim `codegate-version` string is found in the input. The pipeline also passes along a context, for now that is unused but I thought this would be where we store extracted code snippets etc. To avoid import loops, we also move the `BaseCompletionHandler` class to a new `completion` package. Since the shortcut replies are more or less simple strings, we add yet another package `providers/formatting` whose responsibility is to convert the string returned by the shortcut response to the format expected by the client, meaning either a reply or a stream of replies in the LLM-specific format. We use the `BaseCompletionHandler` as a way to convert to the LLM-specific format. --- src/codegate/providers/litellmshim/generators.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/codegate/providers/litellmshim/generators.py b/src/codegate/providers/litellmshim/generators.py index 306f1900..e36354ac 100644 --- a/src/codegate/providers/litellmshim/generators.py +++ b/src/codegate/providers/litellmshim/generators.py @@ -12,6 +12,7 @@ async def sse_stream_generator(stream: AsyncIterator[Any]) -> AsyncIterator[str] """OpenAI-style SSE format""" try: async for chunk in stream: + print(chunk) if isinstance(chunk, BaseModel): # alternatively we might want to just dump the whole object # this might even allow us to tighten the typing of the stream From c36a53e78ea9417e9940964fe542d75ef3752b06 Mon Sep 17 00:00:00 2001 From: Alejandro Ponce Date: Wed, 27 Nov 2024 12:19:53 +0200 Subject: [PATCH 2/9] Add a FIM pipeline to providers Related: #87, #43 The PR adds a FIM pipeline independent from chat completion pipeline. It could still be faulty since we need: - Message normalizer. We now expect all messages to have the key `messages`. However, there are incoming messages with `prompt`. - Secreets detector. There's the skeleton of a class called SecretAnalyzer that is meant to analyze the messages and return a warning if it detected a secret. --- src/codegate/pipeline/fim/secret_analyzer.py | 48 +++++++++++++++++++ src/codegate/providers/anthropic/provider.py | 1 + .../providers/litellmshim/litellmshim.py | 8 ++++ .../providers/llamacpp/completion_handler.py | 8 ++++ src/codegate/providers/llamacpp/provider.py | 2 + src/codegate/providers/openai/provider.py | 17 +++++++ src/codegate/server.py | 12 +++-- 7 files changed, 93 insertions(+), 3 deletions(-) create mode 100644 src/codegate/pipeline/fim/secret_analyzer.py diff --git a/src/codegate/pipeline/fim/secret_analyzer.py b/src/codegate/pipeline/fim/secret_analyzer.py new file mode 100644 index 00000000..83d05403 --- /dev/null +++ b/src/codegate/pipeline/fim/secret_analyzer.py @@ -0,0 +1,48 @@ +from litellm import ChatCompletionRequest + +from codegate.pipeline.base import PipelineContext, PipelineResult, PipelineStep, PipelineResponse + + +class SecretAnalyzer(PipelineStep): + """Pipeline step that handles version information requests.""" + + message_blocked = """ + ⚠️ CodeGate Security Warning! 
Analysis Report ⚠️ + Potential leak of sensitive credentials blocked + + Recommendations: + - Use environment variables for secrets + """ + + @property + def name(self) -> str: + """ + Returns the name of this pipeline step. + + Returns: + str: The identifier 'fim-secret-analyzer' + """ + return "fim-secret-analyzer" + + async def process( + self, + request: ChatCompletionRequest, + context: PipelineContext + ) -> PipelineResult: + # We should call here Secrets Blocking module to see if the request messages contain secrets + # messages_contain_secrets = [analyze_msg_secrets(msg) for msg in request.messages] + # message_with_secrets = any(messages_contain_secretes) + + # For the moment to test shortcutting just treat all messages as if they contain secrets + message_with_secrets = True + if message_with_secrets: + return PipelineResult( + response=PipelineResponse( + step_name=self.name, + content=self.message_blocked, + model=request["model"], + ), + ) + + # No messages with secrets, execute the rest of the pipeline + return PipelineResult(request=request) diff --git a/src/codegate/providers/anthropic/provider.py b/src/codegate/providers/anthropic/provider.py index a16c5921..851bdd0a 100644 --- a/src/codegate/providers/anthropic/provider.py +++ b/src/codegate/providers/anthropic/provider.py @@ -1,4 +1,5 @@ import json +from typing import Optional from fastapi import Header, HTTPException, Request diff --git a/src/codegate/providers/litellmshim/litellmshim.py b/src/codegate/providers/litellmshim/litellmshim.py index 1b4dcdf5..a2be0431 100644 --- a/src/codegate/providers/litellmshim/litellmshim.py +++ b/src/codegate/providers/litellmshim/litellmshim.py @@ -43,3 +43,11 @@ def create_streaming_response(self, stream: AsyncIterator[Any]) -> StreamingResp }, status_code=200, ) + + def is_fim_request(self, data: Dict) -> bool: + """ + Determine from the raw incoming data if it's a FIM request. + This is needed here since completion_handler is used by provider and provider + doesn't know about the adapter. + """ + return self._adapter.is_fim_request(data) diff --git a/src/codegate/providers/llamacpp/completion_handler.py b/src/codegate/providers/llamacpp/completion_handler.py index 40046975..e99deaf3 100644 --- a/src/codegate/providers/llamacpp/completion_handler.py +++ b/src/codegate/providers/llamacpp/completion_handler.py @@ -65,3 +65,11 @@ def create_streaming_response(self, stream: Iterator[Any]) -> StreamingResponse: }, status_code=200, ) + + def is_fim_request(self, data: Dict) -> bool: + """ + Determine from the raw incoming data if it's a FIM request. + This is needed here since completion_handler is used by provider and provider + doesn't know about the adapter. 
+ """ + return self._adapter.is_fim_request(data) diff --git a/src/codegate/providers/llamacpp/provider.py b/src/codegate/providers/llamacpp/provider.py index 26291cdc..8f792846 100644 --- a/src/codegate/providers/llamacpp/provider.py +++ b/src/codegate/providers/llamacpp/provider.py @@ -1,7 +1,9 @@ import json +from typing import Optional from fastapi import Request +from codegate.pipeline.base import SequentialPipelineProcessor from codegate.providers.base import BaseProvider from codegate.providers.llamacpp.completion_handler import LlamaCppCompletionHandler from codegate.providers.llamacpp.normalizer import LLamaCppInputNormalizer, LLamaCppOutputNormalizer diff --git a/src/codegate/providers/openai/provider.py b/src/codegate/providers/openai/provider.py index 6d1e6c1d..e221ea92 100644 --- a/src/codegate/providers/openai/provider.py +++ b/src/codegate/providers/openai/provider.py @@ -1,7 +1,9 @@ import json +from typing import Optional from fastapi import Header, HTTPException, Request +from codegate.pipeline.base import SequentialPipelineProcessor from codegate.providers.base import BaseProvider from codegate.providers.litellmshim import LiteLLmShim, sse_stream_generator from codegate.providers.openai.adapter import OpenAIInputNormalizer, OpenAIOutputNormalizer @@ -42,3 +44,18 @@ async def create_completion( stream = await self.complete(data, api_key) return self._completion_handler.create_streaming_response(stream) + + @self.router.post(f"/{self.provider_route_name}/completions") + async def create_fim( + request: Request, + authorization: str = Header(..., description="Bearer token"), + ): + if not authorization.startswith("Bearer "): + raise HTTPException(status_code=401, detail="Invalid authorization header") + + api_key = authorization.split(" ")[1] + body = await request.body() + data = json.loads(body) + + stream = await self.complete(data, api_key) + return self._completion_handler.create_streaming_response(stream) diff --git a/src/codegate/server.py b/src/codegate/server.py index 359425a2..7c28dc42 100644 --- a/src/codegate/server.py +++ b/src/codegate/server.py @@ -5,6 +5,7 @@ from codegate import __description__, __version__ from codegate.pipeline.base import PipelineStep, SequentialPipelineProcessor from codegate.pipeline.version.version import CodegateVersion +from codegate.pipeline.fim.secret_analyzer import SecretAnalyzer from codegate.providers.anthropic.provider import AnthropicProvider from codegate.providers.llamacpp.provider import LlamaCppProvider from codegate.providers.openai.provider import OpenAIProvider @@ -21,15 +22,20 @@ def init_app() -> FastAPI: steps: List[PipelineStep] = [ CodegateVersion(), ] + fim_steps: List[PipelineStep] = [ + SecretAnalyzer(), + ] pipeline = SequentialPipelineProcessor(steps) + fim_pipeline = SequentialPipelineProcessor(fim_steps) + # Create provider registry registry = ProviderRegistry(app) # Register all known providers - registry.add_provider("openai", OpenAIProvider(pipeline_processor=pipeline)) - registry.add_provider("anthropic", AnthropicProvider(pipeline_processor=pipeline)) - registry.add_provider("llamacpp", LlamaCppProvider(pipeline_processor=pipeline)) + registry.add_provider("openai", OpenAIProvider(pipeline_processor=pipeline, fim_pipeline_processor=fim_pipeline)) + registry.add_provider("anthropic", AnthropicProvider(pipeline_processor=pipeline, fim_pipeline_processor=fim_pipeline)) + registry.add_provider("llamacpp", LlamaCppProvider(pipeline_processor=pipeline, fim_pipeline_processor=fim_pipeline)) # Create and add 
system routes system_router = APIRouter(tags=["System"]) # Tags group endpoints in the docs From b7d54891e9383ec29ae5fbe2067eb5c8253fb3a1 Mon Sep 17 00:00:00 2001 From: Alejandro Ponce Date: Wed, 27 Nov 2024 15:41:03 +0200 Subject: [PATCH 3/9] Adding output from make format and fixing unit tests --- src/codegate/pipeline/fim/secret_analyzer.py | 2 +- src/codegate/providers/llamacpp/provider.py | 1 + src/codegate/server.py | 16 +++++++++++++--- tests/providers/anthropic/test_adapter.py | 2 +- tests/providers/litellmshim/test_litellmshim.py | 3 +++ tests/providers/test_registry.py | 3 +++ 6 files changed, 22 insertions(+), 5 deletions(-) diff --git a/src/codegate/pipeline/fim/secret_analyzer.py b/src/codegate/pipeline/fim/secret_analyzer.py index 83d05403..eb25ad4f 100644 --- a/src/codegate/pipeline/fim/secret_analyzer.py +++ b/src/codegate/pipeline/fim/secret_analyzer.py @@ -1,6 +1,6 @@ from litellm import ChatCompletionRequest -from codegate.pipeline.base import PipelineContext, PipelineResult, PipelineStep, PipelineResponse +from codegate.pipeline.base import PipelineContext, PipelineResponse, PipelineResult, PipelineStep class SecretAnalyzer(PipelineStep): diff --git a/src/codegate/providers/llamacpp/provider.py b/src/codegate/providers/llamacpp/provider.py index 8f792846..f7e8c981 100644 --- a/src/codegate/providers/llamacpp/provider.py +++ b/src/codegate/providers/llamacpp/provider.py @@ -5,6 +5,7 @@ from codegate.pipeline.base import SequentialPipelineProcessor from codegate.providers.base import BaseProvider +from codegate.providers.llamacpp.adapter import LlamaCppAdapter from codegate.providers.llamacpp.completion_handler import LlamaCppCompletionHandler from codegate.providers.llamacpp.normalizer import LLamaCppInputNormalizer, LLamaCppOutputNormalizer diff --git a/src/codegate/server.py b/src/codegate/server.py index 7c28dc42..db6b7244 100644 --- a/src/codegate/server.py +++ b/src/codegate/server.py @@ -6,6 +6,7 @@ from codegate.pipeline.base import PipelineStep, SequentialPipelineProcessor from codegate.pipeline.version.version import CodegateVersion from codegate.pipeline.fim.secret_analyzer import SecretAnalyzer +from codegate.pipeline.version.version import CodegateVersion from codegate.providers.anthropic.provider import AnthropicProvider from codegate.providers.llamacpp.provider import LlamaCppProvider from codegate.providers.openai.provider import OpenAIProvider @@ -33,9 +34,18 @@ def init_app() -> FastAPI: registry = ProviderRegistry(app) # Register all known providers - registry.add_provider("openai", OpenAIProvider(pipeline_processor=pipeline, fim_pipeline_processor=fim_pipeline)) - registry.add_provider("anthropic", AnthropicProvider(pipeline_processor=pipeline, fim_pipeline_processor=fim_pipeline)) - registry.add_provider("llamacpp", LlamaCppProvider(pipeline_processor=pipeline, fim_pipeline_processor=fim_pipeline)) + registry.add_provider("openai", OpenAIProvider( + pipeline_processor=pipeline, + fim_pipeline_processor=fim_pipeline + )) + registry.add_provider("anthropic", AnthropicProvider( + pipeline_processor=pipeline, + fim_pipeline_processor=fim_pipeline + )) + registry.add_provider("llamacpp", LlamaCppProvider( + pipeline_processor=pipeline, + fim_pipeline_processor=fim_pipeline + )) # Create and add system routes system_router = APIRouter(tags=["System"]) # Tags group endpoints in the docs diff --git a/tests/providers/anthropic/test_adapter.py b/tests/providers/anthropic/test_adapter.py index 9bb81e54..34452912 100644 --- 
a/tests/providers/anthropic/test_adapter.py +++ b/tests/providers/anthropic/test_adapter.py @@ -40,7 +40,7 @@ def test_normalize_anthropic_input(input_normalizer): {"content": "You are an expert code reviewer", "role": "system"}, {"content": [{"text": "Review this code", "type": "text"}], "role": "user"}, ], - "model": "claude-3-haiku-20240307", + "model": "anthropic/claude-3-haiku-20240307", "stream": True, } diff --git a/tests/providers/litellmshim/test_litellmshim.py b/tests/providers/litellmshim/test_litellmshim.py index 73889a34..edf87996 100644 --- a/tests/providers/litellmshim/test_litellmshim.py +++ b/tests/providers/litellmshim/test_litellmshim.py @@ -37,6 +37,9 @@ async def modified_stream(): return modified_stream() + def is_fim_request(self, data: Dict) -> bool: + return False + @pytest.mark.asyncio async def test_complete_non_streaming(): diff --git a/tests/providers/test_registry.py b/tests/providers/test_registry.py index 8c957f13..64234f32 100644 --- a/tests/providers/test_registry.py +++ b/tests/providers/test_registry.py @@ -49,6 +49,9 @@ def create_streaming_response( ) -> StreamingResponse: return StreamingResponse(stream) + def is_fim_request(self, data: Dict) -> bool: + return False + class MockInputNormalizer(ModelInputNormalizer): def normalize(self, data: Dict) -> Dict: From 0172668f1978e9bb1f13309408f6e37b8b0796b9 Mon Sep 17 00:00:00 2001 From: Alejandro Ponce Date: Wed, 27 Nov 2024 18:06:51 +0200 Subject: [PATCH 4/9] Channging check of is_fim to chec for stop_sequences --- src/codegate/providers/litellmshim/generators.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/codegate/providers/litellmshim/generators.py b/src/codegate/providers/litellmshim/generators.py index e36354ac..306f1900 100644 --- a/src/codegate/providers/litellmshim/generators.py +++ b/src/codegate/providers/litellmshim/generators.py @@ -12,7 +12,6 @@ async def sse_stream_generator(stream: AsyncIterator[Any]) -> AsyncIterator[str] """OpenAI-style SSE format""" try: async for chunk in stream: - print(chunk) if isinstance(chunk, BaseModel): # alternatively we might want to just dump the whole object # this might even allow us to tighten the typing of the stream From 4ee938ffb9cdc6f1c1c700983e44e4386dff8c78 Mon Sep 17 00:00:00 2001 From: Alejandro Ponce Date: Thu, 28 Nov 2024 11:49:24 +0200 Subject: [PATCH 5/9] Reformatted checkf for is_fim_request --- src/codegate/pipeline/fim/secret_analyzer.py | 6 +- .../providers/anthropic/completion_handler.py | 27 +++++++ src/codegate/providers/anthropic/provider.py | 17 +++-- src/codegate/providers/base.py | 72 +++++++++++++++++-- .../providers/litellmshim/litellmshim.py | 8 --- .../providers/llamacpp/completion_handler.py | 8 --- src/codegate/providers/llamacpp/provider.py | 14 ++-- src/codegate/providers/openai/provider.py | 29 +++----- src/codegate/server.py | 2 - tests/providers/anthropic/test_adapter.py | 2 +- .../providers/litellmshim/test_litellmshim.py | 3 - 11 files changed, 129 insertions(+), 59 deletions(-) create mode 100644 src/codegate/providers/anthropic/completion_handler.py diff --git a/src/codegate/pipeline/fim/secret_analyzer.py b/src/codegate/pipeline/fim/secret_analyzer.py index eb25ad4f..ff0eeb4b 100644 --- a/src/codegate/pipeline/fim/secret_analyzer.py +++ b/src/codegate/pipeline/fim/secret_analyzer.py @@ -1,7 +1,10 @@ from litellm import ChatCompletionRequest +from codegate.codegate_logging import setup_logging from codegate.pipeline.base import PipelineContext, PipelineResponse, PipelineResult, PipelineStep +logger = 
setup_logging() + class SecretAnalyzer(PipelineStep): """Pipeline step that handles version information requests.""" @@ -34,8 +37,9 @@ async def process( # message_with_secrets = any(messages_contain_secretes) # For the moment to test shortcutting just treat all messages as if they contain secrets - message_with_secrets = True + message_with_secrets = False if message_with_secrets: + logger.info('Blocking message with secrets.') return PipelineResult( response=PipelineResponse( step_name=self.name, diff --git a/src/codegate/providers/anthropic/completion_handler.py b/src/codegate/providers/anthropic/completion_handler.py new file mode 100644 index 00000000..e5b77819 --- /dev/null +++ b/src/codegate/providers/anthropic/completion_handler.py @@ -0,0 +1,27 @@ +from typing import AsyncIterator, Optional, Union + +from litellm import ChatCompletionRequest, ModelResponse + +from codegate.providers.litellmshim import LiteLLmShim + + +class AnthropicCompletion(LiteLLmShim): + """ + LiteLLM Shim is a wrapper around LiteLLM's API that allows us to use it with + our own completion handler interface without exposing the underlying + LiteLLM API. + """ + + async def execute_completion( + self, + request: ChatCompletionRequest, + api_key: Optional[str], + stream: bool = False, + ) -> Union[ModelResponse, AsyncIterator[ModelResponse]]: + """ + Execute the completion request with LiteLLM's API + """ + model_in_request = request['model'] + if not model_in_request.startswith('anthropic/'): + request['model'] = f'anthropic/{model_in_request}' + return await super().execute_completion(request, api_key, stream) diff --git a/src/codegate/providers/anthropic/provider.py b/src/codegate/providers/anthropic/provider.py index 851bdd0a..e0042ace 100644 --- a/src/codegate/providers/anthropic/provider.py +++ b/src/codegate/providers/anthropic/provider.py @@ -4,18 +4,24 @@ from fastapi import Header, HTTPException, Request from codegate.providers.anthropic.adapter import AnthropicInputNormalizer, AnthropicOutputNormalizer -from codegate.providers.base import BaseProvider -from codegate.providers.litellmshim import LiteLLmShim, anthropic_stream_generator +from codegate.providers.anthropic.completion_handler import AnthropicCompletion +from codegate.providers.base import BaseProvider, SequentialPipelineProcessor +from codegate.providers.litellmshim import anthropic_stream_generator class AnthropicProvider(BaseProvider): - def __init__(self, pipeline_processor=None): - completion_handler = LiteLLmShim(stream_generator=anthropic_stream_generator) + def __init__( + self, + pipeline_processor: Optional[SequentialPipelineProcessor] = None, + fim_pipeline_processor: Optional[SequentialPipelineProcessor] = None + ): + completion_handler = AnthropicCompletion(stream_generator=anthropic_stream_generator) super().__init__( AnthropicInputNormalizer(), AnthropicOutputNormalizer(), completion_handler, pipeline_processor, + fim_pipeline_processor ) @property @@ -40,5 +46,6 @@ async def create_message( body = await request.body() data = json.loads(body) - stream = await self.complete(data, x_api_key) + is_fim_request = self._is_fim_request(request, data) + stream = await self.complete(data, x_api_key, is_fim_request) return self._completion_handler.create_streaming_response(stream) diff --git a/src/codegate/providers/base.py b/src/codegate/providers/base.py index 940d93b2..65a922d6 100644 --- a/src/codegate/providers/base.py +++ b/src/codegate/providers/base.py @@ -1,15 +1,17 @@ from abc import ABC, abstractmethod from typing import 
Any, AsyncIterator, Callable, Dict, Optional, Union -from fastapi import APIRouter +from fastapi import APIRouter, Request from litellm import ModelResponse from litellm.types.llms.openai import ChatCompletionRequest +from codegate.codegate_logging import setup_logging from codegate.pipeline.base import PipelineResult, SequentialPipelineProcessor from codegate.providers.completion.base import BaseCompletionHandler from codegate.providers.formatting.input_pipeline import PipelineResponseFormatter from codegate.providers.normalizer.base import ModelInputNormalizer, ModelOutputNormalizer +logger = setup_logging() StreamGenerator = Callable[[AsyncIterator[Any]], AsyncIterator[str]] @@ -25,12 +27,14 @@ def __init__( output_normalizer: ModelOutputNormalizer, completion_handler: BaseCompletionHandler, pipeline_processor: Optional[SequentialPipelineProcessor] = None, + fim_pipeline_processor: Optional[SequentialPipelineProcessor] = None, ): self.router = APIRouter() self._completion_handler = completion_handler self._input_normalizer = input_normalizer self._output_normalizer = output_normalizer self._pipeline_processor = pipeline_processor + self._fim_pipelin_processor = fim_pipeline_processor self._pipeline_response_formatter = PipelineResponseFormatter(output_normalizer) @@ -48,11 +52,19 @@ def provider_route_name(self) -> str: async def _run_input_pipeline( self, normalized_request: ChatCompletionRequest, + is_fim_request: bool ) -> PipelineResult: - if self._pipeline_processor is None: + # Decide which pipeline processor to use + if is_fim_request: + pipeline_processor = self._fim_pipelin_processor + logger.info('FIM pipeline selected for execution.') + else: + pipeline_processor = self._pipeline_processor + logger.info('Chat completion pipeline selected for execution.') + if pipeline_processor is None: return PipelineResult(request=normalized_request) - result = await self._pipeline_processor.process_request(normalized_request) + result = await pipeline_processor.process_request(normalized_request) # TODO(jakub): handle this by returning a message to the client if result.error_message: @@ -60,10 +72,56 @@ async def _run_input_pipeline( return result + def _is_fim_request_url(self, request: Request) -> bool: + """ + Checks the request URL to determine if a request is FIM or chat completion. + Used by: llama.cpp + """ + request_path = request.url.path + # Evaluate first a larger substring. + if request_path.endswith("/chat/completions"): + return False + + if request_path.endswith("/completions"): + return True + + return False + + def _is_fim_request_body(self, data: Dict) -> bool: + """ + Determine from the raw incoming data if it's a FIM request. + Used by: OpenAI and Anthropic + """ + messages = data.get('messages', []) + if not messages: + return False + + first_message_content = messages[0].get('content') + if first_message_content is None: + return False + + fim_stop_sequences = ['', '', '', ''] + if isinstance(first_message_content, str): + msg_prompt = first_message_content + elif isinstance(first_message_content, list): + msg_prompt = first_message_content[0].get('text', '') + else: + logger.warning(f'Could not determine if message was FIM from data: {data}') + return False + return all([stop_sequence in msg_prompt for stop_sequence in fim_stop_sequences]) + + def _is_fim_request(self, request: Request, data: Dict) -> bool: + """ + Determin if the request is FIM by the URL or the data of the request. + """ + # Avoid more expensive inspection of body by just checking the URL. 
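+        # For example, a request posted to ".../completions" (but not ".../chat/completions")
+        # is classified as FIM from the URL alone; anything else falls through to the
+        # body check below, which looks for the FIM marker sequences in the first message.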
+ if self._is_fim_request_url(request): + return True + + return self._is_fim_request_body(data) + async def complete( - self, - data: Dict, - api_key: Optional[str], + self, data: Dict, api_key: Optional[str], is_fim_request: bool ) -> Union[ModelResponse, AsyncIterator[ModelResponse]]: """ Main completion flow with pipeline integration @@ -79,7 +137,7 @@ async def complete( normalized_request = self._input_normalizer.normalize(data) streaming = data.get("stream", False) - input_pipeline_result = await self._run_input_pipeline(normalized_request) + input_pipeline_result = await self._run_input_pipeline(normalized_request, is_fim_request) if input_pipeline_result.response: return self._pipeline_response_formatter.handle_pipeline_response( input_pipeline_result.response, streaming diff --git a/src/codegate/providers/litellmshim/litellmshim.py b/src/codegate/providers/litellmshim/litellmshim.py index a2be0431..1b4dcdf5 100644 --- a/src/codegate/providers/litellmshim/litellmshim.py +++ b/src/codegate/providers/litellmshim/litellmshim.py @@ -43,11 +43,3 @@ def create_streaming_response(self, stream: AsyncIterator[Any]) -> StreamingResp }, status_code=200, ) - - def is_fim_request(self, data: Dict) -> bool: - """ - Determine from the raw incoming data if it's a FIM request. - This is needed here since completion_handler is used by provider and provider - doesn't know about the adapter. - """ - return self._adapter.is_fim_request(data) diff --git a/src/codegate/providers/llamacpp/completion_handler.py b/src/codegate/providers/llamacpp/completion_handler.py index e99deaf3..40046975 100644 --- a/src/codegate/providers/llamacpp/completion_handler.py +++ b/src/codegate/providers/llamacpp/completion_handler.py @@ -65,11 +65,3 @@ def create_streaming_response(self, stream: Iterator[Any]) -> StreamingResponse: }, status_code=200, ) - - def is_fim_request(self, data: Dict) -> bool: - """ - Determine from the raw incoming data if it's a FIM request. - This is needed here since completion_handler is used by provider and provider - doesn't know about the adapter. 
- """ - return self._adapter.is_fim_request(data) diff --git a/src/codegate/providers/llamacpp/provider.py b/src/codegate/providers/llamacpp/provider.py index f7e8c981..eddfa901 100644 --- a/src/codegate/providers/llamacpp/provider.py +++ b/src/codegate/providers/llamacpp/provider.py @@ -3,21 +3,24 @@ from fastapi import Request -from codegate.pipeline.base import SequentialPipelineProcessor -from codegate.providers.base import BaseProvider -from codegate.providers.llamacpp.adapter import LlamaCppAdapter +from codegate.providers.base import BaseProvider, SequentialPipelineProcessor from codegate.providers.llamacpp.completion_handler import LlamaCppCompletionHandler from codegate.providers.llamacpp.normalizer import LLamaCppInputNormalizer, LLamaCppOutputNormalizer class LlamaCppProvider(BaseProvider): - def __init__(self, pipeline_processor=None): + def __init__( + self, + pipeline_processor: Optional[SequentialPipelineProcessor] = None, + fim_pipeline_processor: Optional[SequentialPipelineProcessor] = None + ): completion_handler = LlamaCppCompletionHandler() super().__init__( LLamaCppInputNormalizer(), LLamaCppOutputNormalizer(), completion_handler, pipeline_processor, + fim_pipeline_processor ) @property @@ -37,5 +40,6 @@ async def create_completion( body = await request.body() data = json.loads(body) - stream = await self.complete(data, api_key=None) + is_fim_request = self._is_fim_request(request, data) + stream = await self.complete(data, None, is_fim_request=is_fim_request) return self._completion_handler.create_streaming_response(stream) diff --git a/src/codegate/providers/openai/provider.py b/src/codegate/providers/openai/provider.py index e221ea92..209118b5 100644 --- a/src/codegate/providers/openai/provider.py +++ b/src/codegate/providers/openai/provider.py @@ -3,20 +3,24 @@ from fastapi import Header, HTTPException, Request -from codegate.pipeline.base import SequentialPipelineProcessor -from codegate.providers.base import BaseProvider +from codegate.providers.base import BaseProvider, SequentialPipelineProcessor from codegate.providers.litellmshim import LiteLLmShim, sse_stream_generator from codegate.providers.openai.adapter import OpenAIInputNormalizer, OpenAIOutputNormalizer class OpenAIProvider(BaseProvider): - def __init__(self, pipeline_processor=None): + def __init__( + self, + pipeline_processor: Optional[SequentialPipelineProcessor] = None, + fim_pipeline_processor: Optional[SequentialPipelineProcessor] = None + ): completion_handler = LiteLLmShim(stream_generator=sse_stream_generator) super().__init__( OpenAIInputNormalizer(), OpenAIOutputNormalizer(), completion_handler, pipeline_processor, + fim_pipeline_processor ) @property @@ -31,22 +35,8 @@ def _setup_routes(self): """ @self.router.post(f"/{self.provider_route_name}/chat/completions") - async def create_completion( - request: Request, - authorization: str = Header(..., description="Bearer token"), - ): - if not authorization.startswith("Bearer "): - raise HTTPException(status_code=401, detail="Invalid authorization header") - - api_key = authorization.split(" ")[1] - body = await request.body() - data = json.loads(body) - - stream = await self.complete(data, api_key) - return self._completion_handler.create_streaming_response(stream) - @self.router.post(f"/{self.provider_route_name}/completions") - async def create_fim( + async def create_completion( request: Request, authorization: str = Header(..., description="Bearer token"), ): @@ -57,5 +47,6 @@ async def create_fim( body = await request.body() data = 
json.loads(body) - stream = await self.complete(data, api_key) + is_fim_request = self._is_fim_request(request, data) + stream = await self.complete(data, api_key, is_fim_request=is_fim_request) return self._completion_handler.create_streaming_response(stream) diff --git a/src/codegate/server.py b/src/codegate/server.py index db6b7244..e506faa8 100644 --- a/src/codegate/server.py +++ b/src/codegate/server.py @@ -4,7 +4,6 @@ from codegate import __description__, __version__ from codegate.pipeline.base import PipelineStep, SequentialPipelineProcessor -from codegate.pipeline.version.version import CodegateVersion from codegate.pipeline.fim.secret_analyzer import SecretAnalyzer from codegate.pipeline.version.version import CodegateVersion from codegate.providers.anthropic.provider import AnthropicProvider @@ -26,7 +25,6 @@ def init_app() -> FastAPI: fim_steps: List[PipelineStep] = [ SecretAnalyzer(), ] - pipeline = SequentialPipelineProcessor(steps) fim_pipeline = SequentialPipelineProcessor(fim_steps) diff --git a/tests/providers/anthropic/test_adapter.py b/tests/providers/anthropic/test_adapter.py index 34452912..9bb81e54 100644 --- a/tests/providers/anthropic/test_adapter.py +++ b/tests/providers/anthropic/test_adapter.py @@ -40,7 +40,7 @@ def test_normalize_anthropic_input(input_normalizer): {"content": "You are an expert code reviewer", "role": "system"}, {"content": [{"text": "Review this code", "type": "text"}], "role": "user"}, ], - "model": "anthropic/claude-3-haiku-20240307", + "model": "claude-3-haiku-20240307", "stream": True, } diff --git a/tests/providers/litellmshim/test_litellmshim.py b/tests/providers/litellmshim/test_litellmshim.py index edf87996..73889a34 100644 --- a/tests/providers/litellmshim/test_litellmshim.py +++ b/tests/providers/litellmshim/test_litellmshim.py @@ -37,9 +37,6 @@ async def modified_stream(): return modified_stream() - def is_fim_request(self, data: Dict) -> bool: - return False - @pytest.mark.asyncio async def test_complete_non_streaming(): From e2971a2a0552e45037984d21aa0b616307807d28 Mon Sep 17 00:00:00 2001 From: Alejandro Ponce Date: Thu, 28 Nov 2024 12:46:46 +0200 Subject: [PATCH 6/9] Added basic unit test --- tests/providers/test_registry.py | 3 -- tests/test_provider.py | 82 ++++++++++++++++++++++++++++++++ 2 files changed, 82 insertions(+), 3 deletions(-) create mode 100644 tests/test_provider.py diff --git a/tests/providers/test_registry.py b/tests/providers/test_registry.py index 64234f32..8c957f13 100644 --- a/tests/providers/test_registry.py +++ b/tests/providers/test_registry.py @@ -49,9 +49,6 @@ def create_streaming_response( ) -> StreamingResponse: return StreamingResponse(stream) - def is_fim_request(self, data: Dict) -> bool: - return False - class MockInputNormalizer(ModelInputNormalizer): def normalize(self, data: Dict) -> Dict: diff --git a/tests/test_provider.py b/tests/test_provider.py new file mode 100644 index 00000000..0957d618 --- /dev/null +++ b/tests/test_provider.py @@ -0,0 +1,82 @@ +from unittest.mock import MagicMock + +import pytest + +from codegate.providers.base import BaseProvider + + +class MockProvider(BaseProvider): + + def __init__(self): + mocked_input_normalizer = MagicMock() + mocked_output_normalizer = MagicMock() + mocked_completion_handler = MagicMock() + mocked_pipepeline = MagicMock() + mocked_fim_pipeline = MagicMock() + super().__init__( + mocked_input_normalizer, + mocked_output_normalizer, + mocked_completion_handler, + mocked_pipepeline, + mocked_fim_pipeline + ) + + def _setup_routes(self) -> None: + 
pass + + @property + def provider_route_name(self) -> str: + return 'mock-provider' + + +@pytest.mark.parametrize("url, expected_bool", [ + ("http://example.com", False), + ("http://test.com/chat/completions", False), + ("http://example.com/completions", True), +]) +def test_is_fim_request_url(url, expected_bool): + mock_provider = MockProvider() + request = MagicMock() + request.url.path = url + assert mock_provider._is_fim_request_url(request) == expected_bool + + +DATA_CONTENT_STR = { + "messages": [ + { + "role": "user", + "content": ' ', + } + ] +} +DATA_CONTENT_LIST = { + "messages": [ + { + "role": "user", + "content": [ + { + "type": "text", + "text": " " + } + ], + } + ] +} +INVALID_DATA_CONTET = { + "messages": [ + { + "role": "user", + "content": "http://example.com/completions", + } + ] +} + + +@pytest.mark.parametrize("data, expected_bool", [ + (DATA_CONTENT_STR, True), + (DATA_CONTENT_LIST, True), + (INVALID_DATA_CONTET, False), +]) +def test_is_fim_request_body(data, expected_bool): + mock_provider = MockProvider() + assert mock_provider._is_fim_request_body(data) == expected_bool From e45e15d792e4be4c9986322106d4104605a49218 Mon Sep 17 00:00:00 2001 From: Alejandro Ponce Date: Thu, 28 Nov 2024 13:09:54 +0200 Subject: [PATCH 7/9] Minor modifications to docstrings --- src/codegate/pipeline/fim/secret_analyzer.py | 2 +- .../providers/anthropic/completion_handler.py | 14 ++++++++++---- 2 files changed, 11 insertions(+), 5 deletions(-) diff --git a/src/codegate/pipeline/fim/secret_analyzer.py b/src/codegate/pipeline/fim/secret_analyzer.py index ff0eeb4b..83f89a11 100644 --- a/src/codegate/pipeline/fim/secret_analyzer.py +++ b/src/codegate/pipeline/fim/secret_analyzer.py @@ -7,7 +7,7 @@ class SecretAnalyzer(PipelineStep): - """Pipeline step that handles version information requests.""" + """Pipeline step that handles analyzing secrets in FIM pipeline.""" message_blocked = """ ⚠️ CodeGate Security Warning! Analysis Report ⚠️ diff --git a/src/codegate/providers/anthropic/completion_handler.py b/src/codegate/providers/anthropic/completion_handler.py index e5b77819..ed45eb40 100644 --- a/src/codegate/providers/anthropic/completion_handler.py +++ b/src/codegate/providers/anthropic/completion_handler.py @@ -7,9 +7,7 @@ class AnthropicCompletion(LiteLLmShim): """ - LiteLLM Shim is a wrapper around LiteLLM's API that allows us to use it with - our own completion handler interface without exposing the underlying - LiteLLM API. + AnthropicCompletion used by the Anthropic provider to execute completions """ async def execute_completion( @@ -19,7 +17,15 @@ async def execute_completion( stream: bool = False, ) -> Union[ModelResponse, AsyncIterator[ModelResponse]]: """ - Execute the completion request with LiteLLM's API + Ensures the model name is prefixed with 'anthropic/' to explicitly route to Anthropic's API. + + LiteLLM automatically maps most model names, but prepending 'anthropic/' forces the request + to Anthropic. This avoids issues with unrecognized names like 'claude-3-5-sonnet-latest', + which LiteLLM doesn't accept as a valid Anthropic model. This safeguard may be unnecessary + but ensures compatibility. + + For more details, refer to the + [LiteLLM Documentation](https://docs.litellm.ai/docs/providers/anthropic). 
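+
+        For example, a request that names the model 'claude-3-5-sonnet-latest' is
+        rewritten to 'anthropic/claude-3-5-sonnet-latest' before the call is executed.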
""" model_in_request = request['model'] if not model_in_request.startswith('anthropic/'): From 438714278e6ec094c380a319171afb853e2a5b49 Mon Sep 17 00:00:00 2001 From: Alejandro Ponce Date: Thu, 28 Nov 2024 16:40:18 +0200 Subject: [PATCH 8/9] Leaving the FIM pipeline empty for now --- src/codegate/providers/base.py | 3 +-- src/codegate/providers/llamacpp/normalizer.py | 10 ++++++++++ src/codegate/server.py | 3 +-- 3 files changed, 12 insertions(+), 4 deletions(-) diff --git a/src/codegate/providers/base.py b/src/codegate/providers/base.py index 65a922d6..311de2d0 100644 --- a/src/codegate/providers/base.py +++ b/src/codegate/providers/base.py @@ -136,7 +136,6 @@ async def complete( """ normalized_request = self._input_normalizer.normalize(data) streaming = data.get("stream", False) - input_pipeline_result = await self._run_input_pipeline(normalized_request, is_fim_request) if input_pipeline_result.response: return self._pipeline_response_formatter.handle_pipeline_response( @@ -151,7 +150,7 @@ async def complete( model_response = await self._completion_handler.execute_completion( provider_request, api_key=api_key, stream=streaming ) - + print(f'Model response: {model_response}') if not streaming: return self._output_normalizer.denormalize(model_response) return self._output_normalizer.denormalize_streaming(model_response) diff --git a/src/codegate/providers/llamacpp/normalizer.py b/src/codegate/providers/llamacpp/normalizer.py index 38a97a79..8a0ae077 100644 --- a/src/codegate/providers/llamacpp/normalizer.py +++ b/src/codegate/providers/llamacpp/normalizer.py @@ -10,6 +10,11 @@ def normalize(self, data: Dict) -> ChatCompletionRequest: """ Normalize the input data """ + # When doing FIM, we receive "prompt" instead of messages. Normalizing. + if "prompt" in data: + data["messages"] = [{"content": data.pop("prompt"), "role": "user"}] + # We can add as many parameters as we like to data. ChatCompletionRequest is not strict. + data["had_prompt_before"] = True try: return ChatCompletionRequest(**data) except Exception as e: @@ -19,6 +24,11 @@ def denormalize(self, data: ChatCompletionRequest) -> Dict: """ Denormalize the input data """ + # If we receive "prompt" in FIM, we need convert it back. 
+ if data.get("had_prompt_before", False): + data["prompt"] = data["messages"][0]["content"] + del data["had_prompt_before"] + del data["messages"] return data diff --git a/src/codegate/server.py b/src/codegate/server.py index e506faa8..14ccb60a 100644 --- a/src/codegate/server.py +++ b/src/codegate/server.py @@ -4,7 +4,6 @@ from codegate import __description__, __version__ from codegate.pipeline.base import PipelineStep, SequentialPipelineProcessor -from codegate.pipeline.fim.secret_analyzer import SecretAnalyzer from codegate.pipeline.version.version import CodegateVersion from codegate.providers.anthropic.provider import AnthropicProvider from codegate.providers.llamacpp.provider import LlamaCppProvider @@ -22,8 +21,8 @@ def init_app() -> FastAPI: steps: List[PipelineStep] = [ CodegateVersion(), ] + # Leaving the pipeline empty for now fim_steps: List[PipelineStep] = [ - SecretAnalyzer(), ] pipeline = SequentialPipelineProcessor(steps) fim_pipeline = SequentialPipelineProcessor(fim_steps) From 8bb074c93e9bc7e0d20158bcc4fd97006d6f3426 Mon Sep 17 00:00:00 2001 From: Alejandro Ponce Date: Thu, 28 Nov 2024 16:56:10 +0200 Subject: [PATCH 9/9] Remove print statement --- src/codegate/providers/base.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/codegate/providers/base.py b/src/codegate/providers/base.py index 311de2d0..4f2d65e0 100644 --- a/src/codegate/providers/base.py +++ b/src/codegate/providers/base.py @@ -150,7 +150,6 @@ async def complete( model_response = await self._completion_handler.execute_completion( provider_request, api_key=api_key, stream=streaming ) - print(f'Model response: {model_response}') if not streaming: return self._output_normalizer.denormalize(model_response) return self._output_normalizer.denormalize_streaming(model_response)
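As a rough sketch of how the URL and body checks combine, a follow-up test along these lines could be added to tests/test_provider.py (hypothetical, not part of this series; it reuses MockProvider, DATA_CONTENT_STR, INVALID_DATA_CONTET and the imports already present in that file):

# Hypothetical test: either the URL or the FIM markers in the message body should be
# enough for _is_fim_request to route the request to the FIM pipeline.
@pytest.mark.parametrize("url, data, expected_bool", [
    ("http://example.com/completions", {"messages": []}, True),        # FIM decided by URL alone
    ("http://test.com/chat/completions", DATA_CONTENT_STR, True),      # FIM decided by body markers
    ("http://test.com/chat/completions", INVALID_DATA_CONTET, False),  # ordinary chat completion
])
def test_is_fim_request(url, data, expected_bool):
    mock_provider = MockProvider()
    request = MagicMock()
    request.url.path = url
    assert mock_provider._is_fim_request(request, data) == expected_bool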