From 57dbfb1a5a64de843c4789d95bab80605ccc591c Mon Sep 17 00:00:00 2001 From: Jakub Hrozek Date: Thu, 12 Dec 2024 12:54:55 +0000 Subject: [PATCH 1/3] Copilot chats are sent through an input pipeline --- src/codegate/providers/copilot/pipeline.py | 30 ++++- src/codegate/providers/copilot/provider.py | 130 +++++++++++++++++---- 2 files changed, 135 insertions(+), 25 deletions(-) diff --git a/src/codegate/providers/copilot/pipeline.py b/src/codegate/providers/copilot/pipeline.py index 3a0c8dee..c61bc9c1 100644 --- a/src/codegate/providers/copilot/pipeline.py +++ b/src/codegate/providers/copilot/pipeline.py @@ -101,6 +101,21 @@ def denormalize(self, request_from_pipeline: ChatCompletionRequest) -> bytes: return json.dumps(normalized_json_body).encode() +class CopilotChatNormalizer: + """ + A custom normalizer for the chat format used by Copilot + The requests are already in the OpenAI format, we just need + to unmarshall them and marshall them back. + """ + + def normalize(self, body: bytes) -> ChatCompletionRequest: + json_body = json.loads(body) + return ChatCompletionRequest(**json_body) + + def denormalize(self, request_from_pipeline: ChatCompletionRequest) -> bytes: + return json.dumps(request_from_pipeline).encode() + + class CopilotFimPipeline(CopilotPipeline): """ A pipeline for the FIM format used by Copilot. Combines the normalizer for the FIM @@ -108,7 +123,20 @@ class CopilotFimPipeline(CopilotPipeline): """ def _create_normalizer(self): - return CopilotFimNormalizer() # Uses your custom normalizer + return CopilotFimNormalizer() def create_pipeline(self): return self.pipeline_factory.create_fim_pipeline() + + +class CopilotChatPipeline(CopilotPipeline): + """ + A pipeline for the Chat format used by Copilot. Combines the normalizer for the FIM + format and the FIM pipeline used by all providers. + """ + + def _create_normalizer(self): + return CopilotChatNormalizer() + + def create_pipeline(self): + return self.pipeline_factory.create_input_pipeline() diff --git a/src/codegate/providers/copilot/provider.py b/src/codegate/providers/copilot/provider.py index ab040531..28e89409 100644 --- a/src/codegate/providers/copilot/provider.py +++ b/src/codegate/providers/copilot/provider.py @@ -13,7 +13,11 @@ from codegate.pipeline.factory import PipelineFactory from codegate.pipeline.secrets.manager import SecretsManager from codegate.providers.copilot.mapping import VALIDATED_ROUTES -from codegate.providers.copilot.pipeline import CopilotFimPipeline +from codegate.providers.copilot.pipeline import ( + CopilotChatPipeline, + CopilotFimPipeline, + CopilotPipeline, +) logger = structlog.get_logger("codegate") @@ -38,6 +42,61 @@ class HttpRequest: headers: List[str] original_path: str target: Optional[str] = None + body: Optional[bytes] = None + + def reconstruct(self) -> bytes: + """Reconstruct HTTP request from stored details""" + headers = "\r\n".join(self.headers) + request_line = f"{self.method} /{self.path} {self.version}\r\n" + header_block = f"{request_line}{headers}\r\n\r\n" + + # Convert header block to bytes and combine with body + result = header_block.encode("utf-8") + if self.body: + result += self.body + + return result + + +def extract_path(full_path: str) -> str: + """Extract clean path from full URL or path string""" + logger.debug(f"Extracting path from {full_path}") + if full_path.startswith(("http://", "https://")): + parsed = urlparse(full_path) + path = parsed.path + if parsed.query: + path = f"{path}?{parsed.query}" + return path.lstrip("/") + return full_path.lstrip("/") + + +def http_request_from_bytes(data: bytes) -> Optional[HttpRequest]: + """ + Parse HTTP request details from raw bytes data. + TODO: Make safer by checking for valid HTTP request format, check + if there is a method if there are headers, etc. + """ + if b"\r\n\r\n" not in data: + return None + + headers_end = data.index(b"\r\n\r\n") + headers = data[:headers_end].split(b"\r\n") + + request = headers[0].decode("utf-8") + method, full_path, version = request.split(" ") + + body_start = data.index(b"\r\n\r\n") + 4 + body = data[body_start:] + + return HttpRequest( + method=method, + path=extract_path(full_path), + version=version, + headers=[header.decode("utf-8") for header in headers[1:]], + original_path=full_path, + target=full_path if method == "CONNECT" else None, + body=body, + ) class CopilotProvider(asyncio.Protocol): @@ -63,20 +122,26 @@ def __init__(self, loop: asyncio.AbstractEventLoop): self.pipeline_factory = PipelineFactory(SecretsManager()) self.context_tracking: Optional[PipelineContext] = None - def _select_pipeline(self): - if ( - self.request.method == "POST" - and self.request.path == "v1/engines/copilot-codex/completions" - ): + def _select_pipeline(self, method: str, path: str) -> Optional[CopilotPipeline]: + if method == "POST" and path == "v1/engines/copilot-codex/completions": logger.debug("Selected CopilotFimStrategy") return CopilotFimPipeline(self.pipeline_factory) + if method == "POST" and path == "chat/completions": + logger.debug("Selected CopilotChatStrategy") + return CopilotChatPipeline(self.pipeline_factory) logger.debug("No pipeline strategy selected") return None - async def _body_through_pipeline(self, headers: list[str], body: bytes) -> bytes: + async def _body_through_pipeline( + self, + method: str, + path: str, + headers: list[str], + body: bytes, + ) -> bytes: logger.debug(f"Processing body through pipeline: {len(body)} bytes") - strategy = self._select_pipeline() + strategy = self._select_pipeline(method, path) if strategy is None: # if we didn't select any strategy that would change the request # let's just pass through the body as-is @@ -89,7 +154,12 @@ async def _request_to_target(self, headers: list[str], body: bytes): ).encode() logger.debug(f"Request Line: {request_line}") - body = await self._body_through_pipeline(headers, body) + body = await self._body_through_pipeline( + self.request.method, + self.request.path, + headers, + body, + ) for header in headers: if header.lower().startswith("content-length:"): @@ -113,18 +183,6 @@ def connection_made(self, transport: asyncio.Transport) -> None: self.peername = transport.get_extra_info("peername") logger.debug(f"Client connected from {self.peername}") - @staticmethod - def extract_path(full_path: str) -> str: - """Extract clean path from full URL or path string""" - logger.debug(f"Extracting path from {full_path}") - if full_path.startswith(("http://", "https://")): - parsed = urlparse(full_path) - path = parsed.path - if parsed.query: - path = f"{path}?{parsed.query}" - return path.lstrip("/") - return full_path.lstrip("/") - def get_headers_dict(self) -> Dict[str, str]: """Convert raw headers to dictionary format""" headers_dict = {} @@ -161,7 +219,7 @@ def parse_headers(self) -> bool: self.request = HttpRequest( method=method, - path=self.extract_path(full_path), + path=extract_path(full_path), version=version, headers=[header.decode("utf-8") for header in headers[1:]], original_path=full_path, @@ -179,9 +237,33 @@ def _check_buffer_size(self, new_data: bytes) -> bool: """Check if adding new data would exceed buffer size limit""" return len(self.buffer) + len(new_data) <= MAX_BUFFER_SIZE - def _forward_data_to_target(self, data: bytes) -> None: + async def _forward_data_through_pipeline(self, data: bytes) -> bytes: + http_request = http_request_from_bytes(data) + if not http_request: + # we couldn't parse this into an HTTP request, so we just pass through + return data + + http_request.body = await self._body_through_pipeline( + http_request.method, + http_request.path, + http_request.headers, + http_request.body, + ) + + for header in http_request.headers: + if header.lower().startswith("content-length:"): + http_request.headers.remove(header) + break + http_request.headers.append(f"Content-Length: {len(http_request.body)}") + + pipeline_data = http_request.reconstruct() + + return pipeline_data + + async def _forward_data_to_target(self, data: bytes) -> None: """Forward data to target if connection is established""" if self.target_transport and not self.target_transport.is_closing(): + data = await self._forward_data_through_pipeline(data) self.target_transport.write(data) def data_received(self, data: bytes) -> None: @@ -201,7 +283,7 @@ def data_received(self, data: bytes) -> None: else: asyncio.create_task(self.handle_http_request()) else: - self._forward_data_to_target(data) + asyncio.create_task(self._forward_data_to_target(data)) except Exception as e: logger.error(f"Error processing received data: {e}") From ddc0723f1f197688359651ee7848151343730320 Mon Sep 17 00:00:00 2001 From: Jakub Hrozek Date: Thu, 12 Dec 2024 13:36:19 +0000 Subject: [PATCH 2/3] Fix up headers into dictionary --- src/codegate/providers/copilot/pipeline.py | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/src/codegate/providers/copilot/pipeline.py b/src/codegate/providers/copilot/pipeline.py index c61bc9c1..6ad519ce 100644 --- a/src/codegate/providers/copilot/pipeline.py +++ b/src/codegate/providers/copilot/pipeline.py @@ -59,15 +59,23 @@ async def process_body(self, headers: list[str], body: bytes) -> bytes: try: normalized_body = self.normalizer.normalize(body) + headers_dict = {} + for header in headers: + try: + name, value = header.split(":", 1) + headers_dict[name.strip().lower()] = value.strip() + except ValueError: + continue + pipeline = self.create_pipeline() result = await pipeline.process_request( request=normalized_body, provider=self.provider_name, prompt_id=self._request_id(headers), model=normalized_body.get("model", "gpt-4o-mini"), - api_key = headers.get('authorization','').replace('Bearer ', ''), - api_base = "https://" + headers.get('host', ''), - extra_headers=CopilotPipeline._get_copilot_headers(headers) + api_key = headers_dict.get('authorization','').replace('Bearer ', ''), + api_base = "https://" + headers_dict.get('host', ''), + extra_headers=CopilotPipeline._get_copilot_headers(headers_dict) ) if result.request: From 68bb0b2d3920a93956b6c9392414724517d577b5 Mon Sep 17 00:00:00 2001 From: Jakub Hrozek Date: Thu, 12 Dec 2024 13:36:50 +0000 Subject: [PATCH 3/3] Run make all --- src/codegate/llm_utils/extractor.py | 4 ++-- src/codegate/llm_utils/llmclient.py | 2 +- src/codegate/pipeline/base.py | 4 ++-- .../codegate_context_retriever/codegate.py | 4 ++-- src/codegate/pipeline/secrets/secrets.py | 2 +- src/codegate/providers/copilot/pipeline.py | 24 ++++++++++++------- 6 files changed, 24 insertions(+), 16 deletions(-) diff --git a/src/codegate/llm_utils/extractor.py b/src/codegate/llm_utils/extractor.py index b4b2514e..4325f909 100644 --- a/src/codegate/llm_utils/extractor.py +++ b/src/codegate/llm_utils/extractor.py @@ -24,7 +24,7 @@ async def extract_packages( model: str = None, base_url: Optional[str] = None, api_key: Optional[str] = None, - extra_headers: Optional[Dict[str, str]] = None + extra_headers: Optional[Dict[str, str]] = None, ) -> List[str]: """Extract package names from the given content.""" system_prompt = Config.get_config().prompts.lookup_packages @@ -51,7 +51,7 @@ async def extract_ecosystem( model: str = None, base_url: Optional[str] = None, api_key: Optional[str] = None, - extra_headers: Optional[Dict[str, str]] = None + extra_headers: Optional[Dict[str, str]] = None, ) -> List[str]: """Extract ecosystem from the given content.""" system_prompt = Config.get_config().prompts.lookup_ecosystem diff --git a/src/codegate/llm_utils/llmclient.py b/src/codegate/llm_utils/llmclient.py index 64c04f1a..a74c4d00 100644 --- a/src/codegate/llm_utils/llmclient.py +++ b/src/codegate/llm_utils/llmclient.py @@ -137,7 +137,7 @@ async def _complete_litellm( temperature=request["temperature"], base_url=base_url, response_format=request["response_format"], - extra_headers=extra_headers + extra_headers=extra_headers, ) content = response["choices"][0]["message"]["content"] diff --git a/src/codegate/pipeline/base.py b/src/codegate/pipeline/base.py index f2f44e65..d39da34f 100644 --- a/src/codegate/pipeline/base.py +++ b/src/codegate/pipeline/base.py @@ -224,7 +224,7 @@ async def process_request( model: str, api_key: Optional[str] = None, api_base: Optional[str] = None, - extra_headers: Optional[Dict[str, str]] = None + extra_headers: Optional[Dict[str, str]] = None, ) -> PipelineResult: """Process a request through all pipeline steps""" self.context.sensitive = PipelineSensitiveData( @@ -273,7 +273,7 @@ async def process_request( model: str, api_key: Optional[str] = None, api_base: Optional[str] = None, - extra_headers: Optional[Dict[str, str]] = None + extra_headers: Optional[Dict[str, str]] = None, ) -> PipelineResult: """Create a new pipeline instance and process the request""" instance = self.create_instance() diff --git a/src/codegate/pipeline/codegate_context_retriever/codegate.py b/src/codegate/pipeline/codegate_context_retriever/codegate.py index 62780235..6da1da93 100644 --- a/src/codegate/pipeline/codegate_context_retriever/codegate.py +++ b/src/codegate/pipeline/codegate_context_retriever/codegate.py @@ -66,7 +66,7 @@ async def __lookup_packages(self, user_query: str, context: PipelineContext): model=context.sensitive.model, api_key=context.sensitive.api_key, base_url=context.sensitive.api_base, - extra_headers=context.metadata.get('extra_headers', None), + extra_headers=context.metadata.get("extra_headers", None), ) logger.info(f"Packages in user query: {packages}") @@ -80,7 +80,7 @@ async def __lookup_ecosystem(self, user_query: str, context: PipelineContext): model=context.sensitive.model, api_key=context.sensitive.api_key, base_url=context.sensitive.api_base, - extra_headers=context.metadata.get('extra_headers', None), + extra_headers=context.metadata.get("extra_headers", None), ) logger.info(f"Ecosystem in user query: {ecosystem}") diff --git a/src/codegate/pipeline/secrets/secrets.py b/src/codegate/pipeline/secrets/secrets.py index a775f0f5..f12eda5f 100644 --- a/src/codegate/pipeline/secrets/secrets.py +++ b/src/codegate/pipeline/secrets/secrets.py @@ -176,7 +176,7 @@ async def process( PipelineResult containing the processed request and context with redaction metadata """ - if 'messages' not in request: + if "messages" not in request: return PipelineResult(request=request, context=context) secrets_manager = context.sensitive.manager diff --git a/src/codegate/providers/copilot/pipeline.py b/src/codegate/providers/copilot/pipeline.py index 6ad519ce..44d519a5 100644 --- a/src/codegate/providers/copilot/pipeline.py +++ b/src/codegate/providers/copilot/pipeline.py @@ -44,13 +44,21 @@ def _request_id(headers: list[str]) -> str: @staticmethod def _get_copilot_headers(headers: Dict[str, str]) -> Dict[str, str]: - copilot_header_names = ['copilot-integration-id', 'editor-plugin-version', 'editor-version', - 'openai-intent', 'openai-organization', 'user-agent', - 'vscode-machineid', 'vscode-sessionid', 'x-github-api-version', - 'x-request-id'] + copilot_header_names = [ + "copilot-integration-id", + "editor-plugin-version", + "editor-version", + "openai-intent", + "openai-organization", + "user-agent", + "vscode-machineid", + "vscode-sessionid", + "x-github-api-version", + "x-request-id", + ] copilot_headers = {} for a_name in copilot_header_names: - copilot_headers[a_name] = headers.get(a_name, '') + copilot_headers[a_name] = headers.get(a_name, "") return copilot_headers @@ -73,9 +81,9 @@ async def process_body(self, headers: list[str], body: bytes) -> bytes: provider=self.provider_name, prompt_id=self._request_id(headers), model=normalized_body.get("model", "gpt-4o-mini"), - api_key = headers_dict.get('authorization','').replace('Bearer ', ''), - api_base = "https://" + headers_dict.get('host', ''), - extra_headers=CopilotPipeline._get_copilot_headers(headers_dict) + api_key=headers_dict.get("authorization", "").replace("Bearer ", ""), + api_base="https://" + headers_dict.get("host", ""), + extra_headers=CopilotPipeline._get_copilot_headers(headers_dict), ) if result.request: