Skip to content
This repository was archived by the owner on Jun 5, 2025. It is now read-only.

Commit 58de99c

Browse files
authored
Merge pull request #319 from jhrozek/copilot_fim_pipeline
Pipe the Copilot output chunks through the output pipeline
Merge commit 58de99c (2 parents: e890340 and c3afb52)

File tree

4 files changed

+61
-22
lines changed

4 files changed

+61
-22
lines changed

src/codegate/pipeline/extract_snippets/output.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,7 @@ async def process_chunk(
9898
input_context: Optional[PipelineContext] = None,
9999
) -> list[ModelResponse]:
100100
"""Process a single chunk of the stream"""
101-
if not chunk.choices[0].delta.content:
101+
if len(chunk.choices) == 0 or not chunk.choices[0].delta.content:
102102
return [chunk]
103103

104104
# Get current content plus this new chunk

src/codegate/pipeline/secrets/secrets.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -262,7 +262,7 @@ async def process_chunk(
262262
if input_context.sensitive.session_id == "":
263263
raise ValueError("Session ID not found in input context")
264264

265-
if not chunk.choices[0].delta.content:
265+
if len(chunk.choices) == 0 or not chunk.choices[0].delta.content:
266266
return [chunk]
267267

268268
# Check the buffered content

src/codegate/providers/copilot/provider.py

Lines changed: 58 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,12 @@
11
import asyncio
2-
import json
32
import re
43
import ssl
54
from dataclasses import dataclass
65
from typing import Dict, List, Optional, Tuple
76
from urllib.parse import unquote, urljoin, urlparse
87

98
import structlog
9+
from litellm.types.utils import Delta, ModelResponse, StreamingChoices
1010

1111
from codegate.ca.codegate_ca import CertificateAuthority
1212
from codegate.config import Config
@@ -559,32 +559,75 @@ def __init__(self, proxy: CopilotProvider):
559559
self.headers_sent = False
560560
self.sse_processor: Optional[SSEProcessor] = None
561561
self.output_pipeline_instance: Optional[OutputPipelineInstance] = None
562+
self.stream_queue: Optional[asyncio.Queue] = None
562563

563564
def connection_made(self, transport: asyncio.Transport) -> None:
564565
"""Handle successful connection to target"""
565566
self.transport = transport
566567
self.proxy.target_transport = transport
567568

568-
def _process_chunk(self, chunk: bytes):
569-
records = self.sse_processor.process_chunk(chunk)
569+
async def _process_stream(self):
570+
try:
570571

571-
for record in records:
572-
if record["type"] == "done":
573-
sse_data = b"data: [DONE]\n\n"
574-
# Add chunk size for DONE message too
575-
chunk_size = hex(len(sse_data))[2:] + "\r\n"
576-
self._proxy_transport_write(chunk_size.encode())
577-
self._proxy_transport_write(sse_data)
578-
self._proxy_transport_write(b"\r\n")
579-
# Now send the final zero chunk
580-
self._proxy_transport_write(b"0\r\n\r\n")
581-
else:
582-
sse_data = f"data: {json.dumps(record['content'])}\n\n".encode("utf-8")
572+
async def stream_iterator():
573+
while True:
574+
incoming_record = await self.stream_queue.get()
575+
record_content = incoming_record.get("content", {})
576+
577+
streaming_choices = []
578+
for choice in record_content.get("choices", []):
579+
streaming_choices.append(
580+
StreamingChoices(
581+
finish_reason=choice.get("finish_reason", None),
582+
index=0,
583+
delta=Delta(
584+
content=choice.get("delta", {}).get("content"), role="assistant"
585+
),
586+
logprobs=None,
587+
)
588+
)
589+
590+
# Convert record to ModelResponse
591+
mr = ModelResponse(
592+
id=record_content.get("id", ""),
593+
choices=streaming_choices,
594+
created=record_content.get("created", 0),
595+
model=record_content.get("model", ""),
596+
object="chat.completion.chunk",
597+
)
598+
yield mr
599+
600+
async for record in self.output_pipeline_instance.process_stream(stream_iterator()):
601+
chunk = record.model_dump_json(exclude_none=True, exclude_unset=True)
602+
sse_data = f"data:{chunk}\n\n".encode("utf-8")
583603
chunk_size = hex(len(sse_data))[2:] + "\r\n"
584604
self._proxy_transport_write(chunk_size.encode())
585605
self._proxy_transport_write(sse_data)
586606
self._proxy_transport_write(b"\r\n")
587607

608+
sse_data = b"data: [DONE]\n\n"
609+
# Add chunk size for DONE message too
610+
chunk_size = hex(len(sse_data))[2:] + "\r\n"
611+
self._proxy_transport_write(chunk_size.encode())
612+
self._proxy_transport_write(sse_data)
613+
self._proxy_transport_write(b"\r\n")
614+
# Now send the final zero chunk
615+
self._proxy_transport_write(b"0\r\n\r\n")
616+
617+
except Exception as e:
618+
logger.error(f"Error processing stream: {e}")
619+
620+
def _process_chunk(self, chunk: bytes):
621+
records = self.sse_processor.process_chunk(chunk)
622+
623+
for record in records:
624+
if self.stream_queue is None:
625+
# Initialize queue and start processing task on first record
626+
self.stream_queue = asyncio.Queue()
627+
self.processing_task = asyncio.create_task(self._process_stream())
628+
629+
self.stream_queue.put_nowait(record)
630+
588631
def _proxy_transport_write(self, data: bytes):
589632
self.proxy.transport.write(data)
590633

src/codegate/providers/copilot/streaming.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,6 @@ def __init__(self):
1313
self.size_written = False
1414

1515
def process_chunk(self, chunk: bytes) -> list:
16-
print("BUFFER AT START")
17-
print(self.buffer)
18-
print("BUFFER AT START - END")
1916
# Skip any chunk size lines (hex number followed by \r\n)
2017
try:
2118
chunk_str = chunk.decode("utf-8")
@@ -25,13 +22,12 @@ def process_chunk(self, chunk: bytes) -> list:
2522
continue
2623
self.buffer += line
2724
except UnicodeDecodeError:
28-
print("Failed to decode chunk")
25+
logger.error("Failed to decode chunk")
2926

3027
records = []
3128
while True:
3229
record_end = self.buffer.find("\n\n")
3330
if record_end == -1:
34-
print(f"REMAINING BUFFER {self.buffer}")
3531
break
3632

3733
record = self.buffer[:record_end]

0 commit comments

Comments (0)