|
1 | 1 | import asyncio
|
2 |
| -import json |
3 | 2 | import re
|
4 | 3 | import ssl
|
5 | 4 | from dataclasses import dataclass
|
6 | 5 | from typing import Dict, List, Optional, Tuple
|
7 | 6 | from urllib.parse import unquote, urljoin, urlparse
|
8 | 7 |
|
9 | 8 | import structlog
|
| 9 | +from litellm.types.utils import Delta, ModelResponse, StreamingChoices |
10 | 10 |
|
11 | 11 | from codegate.ca.codegate_ca import CertificateAuthority
|
12 | 12 | from codegate.config import Config
|
@@ -559,32 +559,75 @@ def __init__(self, proxy: CopilotProvider):
|
559 | 559 | self.headers_sent = False
|
560 | 560 | self.sse_processor: Optional[SSEProcessor] = None
|
561 | 561 | self.output_pipeline_instance: Optional[OutputPipelineInstance] = None
|
| 562 | + self.stream_queue: Optional[asyncio.Queue] = None |
562 | 563 |
|
563 | 564 | def connection_made(self, transport: asyncio.Transport) -> None:
|
564 | 565 | """Handle successful connection to target"""
|
565 | 566 | self.transport = transport
|
566 | 567 | self.proxy.target_transport = transport
|
567 | 568 |
|
568 |
| - def _process_chunk(self, chunk: bytes): |
569 |
| - records = self.sse_processor.process_chunk(chunk) |
| 569 | + async def _process_stream(self): |
| 570 | + try: |
570 | 571 |
|
571 |
| - for record in records: |
572 |
| - if record["type"] == "done": |
573 |
| - sse_data = b"data: [DONE]\n\n" |
574 |
| - # Add chunk size for DONE message too |
575 |
| - chunk_size = hex(len(sse_data))[2:] + "\r\n" |
576 |
| - self._proxy_transport_write(chunk_size.encode()) |
577 |
| - self._proxy_transport_write(sse_data) |
578 |
| - self._proxy_transport_write(b"\r\n") |
579 |
| - # Now send the final zero chunk |
580 |
| - self._proxy_transport_write(b"0\r\n\r\n") |
581 |
| - else: |
582 |
| - sse_data = f"data: {json.dumps(record['content'])}\n\n".encode("utf-8") |
| 572 | + async def stream_iterator(): |
| 573 | + while True: |
| 574 | + incoming_record = await self.stream_queue.get() |
| 575 | + record_content = incoming_record.get("content", {}) |
| 576 | + |
| 577 | + streaming_choices = [] |
| 578 | + for choice in record_content.get("choices", []): |
| 579 | + streaming_choices.append( |
| 580 | + StreamingChoices( |
| 581 | + finish_reason=choice.get("finish_reason", None), |
| 582 | + index=0, |
| 583 | + delta=Delta( |
| 584 | + content=choice.get("delta", {}).get("content"), role="assistant" |
| 585 | + ), |
| 586 | + logprobs=None, |
| 587 | + ) |
| 588 | + ) |
| 589 | + |
| 590 | + # Convert record to ModelResponse |
| 591 | + mr = ModelResponse( |
| 592 | + id=record_content.get("id", ""), |
| 593 | + choices=streaming_choices, |
| 594 | + created=record_content.get("created", 0), |
| 595 | + model=record_content.get("model", ""), |
| 596 | + object="chat.completion.chunk", |
| 597 | + ) |
| 598 | + yield mr |
| 599 | + |
| 600 | + async for record in self.output_pipeline_instance.process_stream(stream_iterator()): |
| 601 | + chunk = record.model_dump_json(exclude_none=True, exclude_unset=True) |
| 602 | + sse_data = f"data:{chunk}\n\n".encode("utf-8") |
583 | 603 | chunk_size = hex(len(sse_data))[2:] + "\r\n"
|
584 | 604 | self._proxy_transport_write(chunk_size.encode())
|
585 | 605 | self._proxy_transport_write(sse_data)
|
586 | 606 | self._proxy_transport_write(b"\r\n")
|
587 | 607 |
|
| 608 | + sse_data = b"data: [DONE]\n\n" |
| 609 | + # Add chunk size for DONE message too |
| 610 | + chunk_size = hex(len(sse_data))[2:] + "\r\n" |
| 611 | + self._proxy_transport_write(chunk_size.encode()) |
| 612 | + self._proxy_transport_write(sse_data) |
| 613 | + self._proxy_transport_write(b"\r\n") |
| 614 | + # Now send the final zero chunk |
| 615 | + self._proxy_transport_write(b"0\r\n\r\n") |
| 616 | + |
| 617 | + except Exception as e: |
| 618 | + logger.error(f"Error processing stream: {e}") |
| 619 | + |
| 620 | + def _process_chunk(self, chunk: bytes): |
| 621 | + records = self.sse_processor.process_chunk(chunk) |
| 622 | + |
| 623 | + for record in records: |
| 624 | + if self.stream_queue is None: |
| 625 | + # Initialize queue and start processing task on first record |
| 626 | + self.stream_queue = asyncio.Queue() |
| 627 | + self.processing_task = asyncio.create_task(self._process_stream()) |
| 628 | + |
| 629 | + self.stream_queue.put_nowait(record) |
| 630 | + |
588 | 631 | def _proxy_transport_write(self, data: bytes):
|
589 | 632 | self.proxy.transport.write(data)
|
590 | 633 |
|
|
0 commit comments