Skip to content
This repository was archived by the owner on Jun 5, 2025. It is now read-only.

Copilot chats are sent through an input pipeline #315

Merged
merged 3 commits into from
Dec 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/codegate/llm_utils/extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ async def extract_packages(
model: str = None,
base_url: Optional[str] = None,
api_key: Optional[str] = None,
extra_headers: Optional[Dict[str, str]] = None
extra_headers: Optional[Dict[str, str]] = None,
) -> List[str]:
"""Extract package names from the given content."""
system_prompt = Config.get_config().prompts.lookup_packages
Expand All @@ -51,7 +51,7 @@ async def extract_ecosystem(
model: str = None,
base_url: Optional[str] = None,
api_key: Optional[str] = None,
extra_headers: Optional[Dict[str, str]] = None
extra_headers: Optional[Dict[str, str]] = None,
) -> List[str]:
"""Extract ecosystem from the given content."""
system_prompt = Config.get_config().prompts.lookup_ecosystem
Expand Down
2 changes: 1 addition & 1 deletion src/codegate/llm_utils/llmclient.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ async def _complete_litellm(
temperature=request["temperature"],
base_url=base_url,
response_format=request["response_format"],
extra_headers=extra_headers
extra_headers=extra_headers,
)
content = response["choices"][0]["message"]["content"]

Expand Down
4 changes: 2 additions & 2 deletions src/codegate/pipeline/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,7 +224,7 @@ async def process_request(
model: str,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
extra_headers: Optional[Dict[str, str]] = None
extra_headers: Optional[Dict[str, str]] = None,
) -> PipelineResult:
"""Process a request through all pipeline steps"""
self.context.sensitive = PipelineSensitiveData(
Expand Down Expand Up @@ -273,7 +273,7 @@ async def process_request(
model: str,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
extra_headers: Optional[Dict[str, str]] = None
extra_headers: Optional[Dict[str, str]] = None,
) -> PipelineResult:
"""Create a new pipeline instance and process the request"""
instance = self.create_instance()
Expand Down
4 changes: 2 additions & 2 deletions src/codegate/pipeline/codegate_context_retriever/codegate.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ async def __lookup_packages(self, user_query: str, context: PipelineContext):
model=context.sensitive.model,
api_key=context.sensitive.api_key,
base_url=context.sensitive.api_base,
extra_headers=context.metadata.get('extra_headers', None),
extra_headers=context.metadata.get("extra_headers", None),
)

logger.info(f"Packages in user query: {packages}")
Expand All @@ -80,7 +80,7 @@ async def __lookup_ecosystem(self, user_query: str, context: PipelineContext):
model=context.sensitive.model,
api_key=context.sensitive.api_key,
base_url=context.sensitive.api_base,
extra_headers=context.metadata.get('extra_headers', None),
extra_headers=context.metadata.get("extra_headers", None),
)

logger.info(f"Ecosystem in user query: {ecosystem}")
Expand Down
2 changes: 1 addition & 1 deletion src/codegate/pipeline/secrets/secrets.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,7 @@ async def process(
PipelineResult containing the processed request and context with redaction metadata
"""

if 'messages' not in request:
if "messages" not in request:
return PipelineResult(request=request, context=context)

secrets_manager = context.sensitive.manager
Expand Down
62 changes: 53 additions & 9 deletions src/codegate/providers/copilot/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,13 +44,21 @@ def _request_id(headers: list[str]) -> str:

@staticmethod
def _get_copilot_headers(headers: Dict[str, str]) -> Dict[str, str]:
copilot_header_names = ['copilot-integration-id', 'editor-plugin-version', 'editor-version',
'openai-intent', 'openai-organization', 'user-agent',
'vscode-machineid', 'vscode-sessionid', 'x-github-api-version',
'x-request-id']
copilot_header_names = [
"copilot-integration-id",
"editor-plugin-version",
"editor-version",
"openai-intent",
"openai-organization",
"user-agent",
"vscode-machineid",
"vscode-sessionid",
"x-github-api-version",
"x-request-id",
]
copilot_headers = {}
for a_name in copilot_header_names:
copilot_headers[a_name] = headers.get(a_name, '')
copilot_headers[a_name] = headers.get(a_name, "")

return copilot_headers

Expand All @@ -59,15 +67,23 @@ async def process_body(self, headers: list[str], body: bytes) -> bytes:
try:
normalized_body = self.normalizer.normalize(body)

headers_dict = {}
for header in headers:
try:
name, value = header.split(":", 1)
headers_dict[name.strip().lower()] = value.strip()
except ValueError:
continue

pipeline = self.create_pipeline()
result = await pipeline.process_request(
request=normalized_body,
provider=self.provider_name,
prompt_id=self._request_id(headers),
model=normalized_body.get("model", "gpt-4o-mini"),
api_key = headers.get('authorization','').replace('Bearer ', ''),
api_base = "https://" + headers.get('host', ''),
extra_headers=CopilotPipeline._get_copilot_headers(headers)
api_key=headers_dict.get("authorization", "").replace("Bearer ", ""),
api_base="https://" + headers_dict.get("host", ""),
extra_headers=CopilotPipeline._get_copilot_headers(headers_dict),
)

if result.request:
Expand Down Expand Up @@ -101,14 +117,42 @@ def denormalize(self, request_from_pipeline: ChatCompletionRequest) -> bytes:
return json.dumps(normalized_json_body).encode()


class CopilotChatNormalizer:
    """Normalizer for Copilot chat requests.

    Chat requests already arrive in the OpenAI wire format, so
    normalization is just a JSON round-trip into ``ChatCompletionRequest``
    and back.
    """

    def normalize(self, body: bytes) -> ChatCompletionRequest:
        parsed = json.loads(body)
        return ChatCompletionRequest(**parsed)

    def denormalize(self, request_from_pipeline: ChatCompletionRequest) -> bytes:
        serialized = json.dumps(request_from_pipeline)
        return serialized.encode()


class CopilotFimPipeline(CopilotPipeline):
    """Copilot fill-in-the-middle (FIM) pipeline.

    Pairs the Copilot-specific FIM normalizer with the shared FIM
    pipeline used by every provider.
    """

    def _create_normalizer(self):
        # FIM payloads use the custom Copilot FIM normalizer.
        return CopilotFimNormalizer()

    def create_pipeline(self):
        # Delegate pipeline construction to the shared factory.
        return self.pipeline_factory.create_fim_pipeline()


class CopilotChatPipeline(CopilotPipeline):
    """
    A pipeline for the Chat format used by Copilot. Combines the normalizer for the
    Chat format and the input pipeline used by all providers.
    """

    def _create_normalizer(self):
        # Chat requests are already OpenAI-shaped; the normalizer is a JSON round-trip.
        return CopilotChatNormalizer()

    def create_pipeline(self):
        return self.pipeline_factory.create_input_pipeline()
130 changes: 106 additions & 24 deletions src/codegate/providers/copilot/provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,11 @@
from codegate.pipeline.factory import PipelineFactory
from codegate.pipeline.secrets.manager import SecretsManager
from codegate.providers.copilot.mapping import VALIDATED_ROUTES
from codegate.providers.copilot.pipeline import CopilotFimPipeline
from codegate.providers.copilot.pipeline import (
CopilotChatPipeline,
CopilotFimPipeline,
CopilotPipeline,
)

logger = structlog.get_logger("codegate")

Expand All @@ -38,6 +42,61 @@ class HttpRequest:
headers: List[str]
original_path: str
target: Optional[str] = None
body: Optional[bytes] = None

def reconstruct(self) -> bytes:
"""Reconstruct HTTP request from stored details"""
headers = "\r\n".join(self.headers)
request_line = f"{self.method} /{self.path} {self.version}\r\n"
header_block = f"{request_line}{headers}\r\n\r\n"

# Convert header block to bytes and combine with body
result = header_block.encode("utf-8")
if self.body:
result += self.body

return result


def extract_path(full_path: str) -> str:
    """Extract clean path from full URL or path string"""
    logger.debug(f"Extracting path from {full_path}")
    if not full_path.startswith(("http://", "https://")):
        # Already a bare path; just drop any leading slash.
        return full_path.lstrip("/")
    parts = urlparse(full_path)
    cleaned = parts.path if not parts.query else f"{parts.path}?{parts.query}"
    return cleaned.lstrip("/")


def http_request_from_bytes(data: bytes) -> Optional[HttpRequest]:
    """Parse raw bytes into an HttpRequest, or None if they aren't parseable.

    Only a minimal, best-effort parse is performed: the request line and
    headers are split on CRLF and everything after the blank line is kept
    as the raw body. Malformed or binary data returns None so callers can
    forward the bytes unmodified.
    """
    if b"\r\n\r\n" not in data:
        return None

    # Split once: bytes before the blank line are the request line plus
    # headers; everything after it is the body.
    headers_end = data.index(b"\r\n\r\n")
    header_lines = data[:headers_end].split(b"\r\n")
    body = data[headers_end + 4 :]

    try:
        request_line = header_lines[0].decode("utf-8")
        method, full_path, version = request_line.split(" ")
        headers = [header.decode("utf-8") for header in header_lines[1:]]
    except (UnicodeDecodeError, ValueError):
        # Not a well-formed HTTP/1.x request (binary data, or a request
        # line without exactly three tokens) — treat it as opaque.
        return None

    return HttpRequest(
        method=method,
        path=extract_path(full_path),
        version=version,
        headers=headers,
        original_path=full_path,
        target=full_path if method == "CONNECT" else None,
        body=body,
    )


class CopilotProvider(asyncio.Protocol):
Expand All @@ -63,20 +122,26 @@ def __init__(self, loop: asyncio.AbstractEventLoop):
self.pipeline_factory = PipelineFactory(SecretsManager())
self.context_tracking: Optional[PipelineContext] = None

def _select_pipeline(self):
if (
self.request.method == "POST"
and self.request.path == "v1/engines/copilot-codex/completions"
):
def _select_pipeline(self, method: str, path: str) -> Optional[CopilotPipeline]:
if method == "POST" and path == "v1/engines/copilot-codex/completions":
logger.debug("Selected CopilotFimStrategy")
return CopilotFimPipeline(self.pipeline_factory)
if method == "POST" and path == "chat/completions":
logger.debug("Selected CopilotChatStrategy")
return CopilotChatPipeline(self.pipeline_factory)

logger.debug("No pipeline strategy selected")
return None

async def _body_through_pipeline(self, headers: list[str], body: bytes) -> bytes:
async def _body_through_pipeline(
self,
method: str,
path: str,
headers: list[str],
body: bytes,
) -> bytes:
logger.debug(f"Processing body through pipeline: {len(body)} bytes")
strategy = self._select_pipeline()
strategy = self._select_pipeline(method, path)
if strategy is None:
# if we didn't select any strategy that would change the request
# let's just pass through the body as-is
Expand All @@ -89,7 +154,12 @@ async def _request_to_target(self, headers: list[str], body: bytes):
).encode()
logger.debug(f"Request Line: {request_line}")

body = await self._body_through_pipeline(headers, body)
body = await self._body_through_pipeline(
self.request.method,
self.request.path,
headers,
body,
)

for header in headers:
if header.lower().startswith("content-length:"):
Expand All @@ -113,18 +183,6 @@ def connection_made(self, transport: asyncio.Transport) -> None:
self.peername = transport.get_extra_info("peername")
logger.debug(f"Client connected from {self.peername}")

@staticmethod
def extract_path(full_path: str) -> str:
"""Extract clean path from full URL or path string"""
logger.debug(f"Extracting path from {full_path}")
if full_path.startswith(("http://", "https://")):
parsed = urlparse(full_path)
path = parsed.path
if parsed.query:
path = f"{path}?{parsed.query}"
return path.lstrip("/")
return full_path.lstrip("/")

def get_headers_dict(self) -> Dict[str, str]:
"""Convert raw headers to dictionary format"""
headers_dict = {}
Expand Down Expand Up @@ -161,7 +219,7 @@ def parse_headers(self) -> bool:

self.request = HttpRequest(
method=method,
path=self.extract_path(full_path),
path=extract_path(full_path),
version=version,
headers=[header.decode("utf-8") for header in headers[1:]],
original_path=full_path,
Expand All @@ -179,9 +237,33 @@ def _check_buffer_size(self, new_data: bytes) -> bool:
"""Check if adding new data would exceed buffer size limit"""
return len(self.buffer) + len(new_data) <= MAX_BUFFER_SIZE

def _forward_data_to_target(self, data: bytes) -> None:
async def _forward_data_through_pipeline(self, data: bytes) -> bytes:
http_request = http_request_from_bytes(data)
if not http_request:
# we couldn't parse this into an HTTP request, so we just pass through
return data

http_request.body = await self._body_through_pipeline(
http_request.method,
http_request.path,
http_request.headers,
http_request.body,
)

for header in http_request.headers:
if header.lower().startswith("content-length:"):
http_request.headers.remove(header)
break
http_request.headers.append(f"Content-Length: {len(http_request.body)}")

pipeline_data = http_request.reconstruct()

return pipeline_data

async def _forward_data_to_target(self, data: bytes) -> None:
"""Forward data to target if connection is established"""
if self.target_transport and not self.target_transport.is_closing():
data = await self._forward_data_through_pipeline(data)
self.target_transport.write(data)

def data_received(self, data: bytes) -> None:
Expand All @@ -201,7 +283,7 @@ def data_received(self, data: bytes) -> None:
else:
asyncio.create_task(self.handle_http_request())
else:
self._forward_data_to_target(data)
asyncio.create_task(self._forward_data_to_target(data))

except Exception as e:
logger.error(f"Error processing received data: {e}")
Expand Down