diff --git a/cert_gen.py b/cert_gen.py new file mode 100644 index 00000000..c9d79c82 --- /dev/null +++ b/cert_gen.py @@ -0,0 +1,178 @@ +import datetime +import os + +from cryptography import x509 +from cryptography.hazmat.primitives import hashes, serialization +from cryptography.hazmat.primitives.asymmetric import rsa +from cryptography.x509.oid import ExtendedKeyUsageOID, NameOID + + +def generate_certificates(cert_dir="certs"): + """Generate self-signed certificates with proper extensions for HTTPS proxy""" + # Create certificates directory if it doesn't exist + if not os.path.exists(cert_dir): + print("Making: ", cert_dir) + os.makedirs(cert_dir) + + # Generate private key + ca_private_key = rsa.generate_private_key( + public_exponent=65537, + key_size=4096, # Increased key size for better security + ) + + # Generate public key + ca_public_key = ca_private_key.public_key() + + # CA BEGIN + name = x509.Name( + [ + x509.NameAttribute(NameOID.COMMON_NAME, "Proxy Pilot CA"), + x509.NameAttribute(NameOID.ORGANIZATION_NAME, "Proxy Pilot"), + x509.NameAttribute(NameOID.ORGANIZATIONAL_UNIT_NAME, "Development"), + x509.NameAttribute(NameOID.COUNTRY_NAME, "UK"), + ] + ) + + builder = x509.CertificateBuilder() + builder = builder.subject_name(name) + builder = builder.issuer_name(name) + builder = builder.public_key(ca_public_key) + builder = builder.serial_number(x509.random_serial_number()) + builder = builder.not_valid_before(datetime.datetime.utcnow()) + builder = builder.not_valid_after( + datetime.datetime.utcnow() + datetime.timedelta(days=3650) # 10 years + ) + + builder = builder.add_extension( + x509.BasicConstraints(ca=True, path_length=None), + critical=True, + ) + + builder = builder.add_extension( + x509.KeyUsage( + digital_signature=True, + content_commitment=False, + key_encipherment=True, + data_encipherment=False, + key_agreement=False, + key_cert_sign=True, # This is a CA + crl_sign=True, + encipher_only=False, + decipher_only=False, + ), + critical=True, + ) + + ca_cert = builder.sign( + private_key=ca_private_key, + algorithm=hashes.SHA256(), + ) + + # Save CA certificate and key + + with open("certs/ca.crt", "wb") as f: + f.write(ca_cert.public_bytes(serialization.Encoding.PEM)) + + with open("certs/ca.key", "wb") as f: + f.write( + ca_private_key.private_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PrivateFormat.PKCS8, + encryption_algorithm=serialization.NoEncryption(), + ) + ) + # CA END + + # SERVER BEGIN + + # Generate new certificate for domain + server_key = rsa.generate_private_key( + public_exponent=65537, + key_size=2048, # 2048 bits is sufficient for domain certs + ) + + name = x509.Name( + [ + x509.NameAttribute(NameOID.COMMON_NAME, "Proxy Pilot CA"), + x509.NameAttribute(NameOID.ORGANIZATION_NAME, "Proxy Pilot Generated"), + ] + ) + + builder = x509.CertificateBuilder() + builder = builder.subject_name(name) + builder = builder.issuer_name(ca_cert.subject) + builder = builder.public_key(server_key.public_key()) + builder = builder.serial_number(x509.random_serial_number()) + builder = builder.not_valid_before(datetime.datetime.utcnow()) + builder = builder.not_valid_after(datetime.datetime.utcnow() + datetime.timedelta(days=365)) + + # Add domain to SAN + builder = builder.add_extension( + x509.SubjectAlternativeName([x509.DNSName("localhost")]), + critical=False, + ) + + # Add extended key usage + builder = builder.add_extension( + x509.ExtendedKeyUsage( + [ + ExtendedKeyUsageOID.SERVER_AUTH, + ExtendedKeyUsageOID.CLIENT_AUTH, + ] + ), + critical=False, + ) + + # Basic constraints (not a CA) + builder = builder.add_extension( + x509.BasicConstraints(ca=False, path_length=None), + critical=True, + ) + + certificate = builder.sign( + private_key=ca_private_key, + algorithm=hashes.SHA256(), + ) + + with open("certs/server.crt", "wb") as f: + f.write(certificate.public_bytes(serialization.Encoding.PEM)) + + with open("certs/server.key", "wb") as f: + f.write( + server_key.private_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PrivateFormat.PKCS8, + encryption_algorithm=serialization.NoEncryption(), + ) + ) + + print("Certificates generated successfully in the 'certs' directory") + print("\nTo trust these certificates:") + print("\nOn macOS:") + print( + "sudo security add-trusted-cert -d -r trustRoot " + "-k /Library/Keychains/System.keychain certs/server.crt" + ) + print("\nOn Windows (PowerShell as Admin):") + print( + 'Import-Certificate -FilePath "certs\\server.crt" ' + "-CertStoreLocation Cert:\\LocalMachine\\Root" + ) + print("\nOn Linux:") + print("sudo cp certs/server.crt /usr/local/share/ca-certificates/proxy-pilot.crt") + print("sudo update-ca-certificates") + print("\nFor VSCode, add to settings.json:") + print( + """{ + "http.proxy": "https://localhost:8989", + "http.proxySupport": "on", + "github.copilot.advanced": { + "debug.testOverrideProxyUrl": "https://localhost:8989", + "debug.overrideProxyUrl": "https://localhost:8989" + } +}""" + ) + + +if __name__ == "__main__": + generate_certificates() diff --git a/src/codegate/ca/codegate_ca.py b/src/codegate/ca/codegate_ca.py index 4a8315d6..1ac6d128 100644 --- a/src/codegate/ca/codegate_ca.py +++ b/src/codegate/ca/codegate_ca.py @@ -349,7 +349,6 @@ def generate_certificates(self) -> Tuple[str, str]: algorithm=hashes.SHA256(), ) - # os.path.join(Config.get_config().server_key) with open( os.path.join(Config.get_config().certs_dir, Config.get_config().server_cert), "wb" ) as f: diff --git a/src/codegate/cli.py b/src/codegate/cli.py index 297aedda..f2620969 100644 --- a/src/codegate/cli.py +++ b/src/codegate/cli.py @@ -1,6 +1,7 @@ """Command-line interface for codegate.""" import asyncio +import signal import sys from pathlib import Path from typing import Dict, Optional @@ -18,8 +19,6 @@ from codegate.server import init_app from codegate.storage.utils import restore_storage_backup -logger = structlog.get_logger("codegate") - class UvicornServer: def __init__(self, config: UvicornConfig, server: Server): @@ -32,10 +31,16 @@ def __init__(self, config: UvicornConfig, server: Server): self._startup_complete = asyncio.Event() self._shutdown_event = asyncio.Event() self._should_exit = False + self.logger = structlog.get_logger("codegate") async def serve(self) -> None: """Start the uvicorn server and handle shutdown gracefully.""" - logger.debug(f"Starting server on {self.host}:{self.port}") + self.logger.debug(f"Starting server on {self.host}:{self.port}") + + # Set up signal handlers + loop = asyncio.get_running_loop() + loop.add_signal_handler(signal.SIGTERM, lambda: asyncio.create_task(self.cleanup())) + loop.add_signal_handler(signal.SIGINT, lambda: asyncio.create_task(self.cleanup())) self.server = Server(config=self.config) self.server.force_exit = True @@ -44,27 +49,27 @@ async def serve(self) -> None: self._startup_complete.set() await self.server.serve() except asyncio.CancelledError: - logger.info("Server received cancellation") + self.logger.info("Server received cancellation") except Exception as e: - logger.exception("Unexpected error occurred during server execution", exc_info=e) + self.logger.exception("Unexpected error occurred during server execution", exc_info=e) finally: await self.cleanup() async def wait_startup_complete(self) -> None: """Wait for the server to complete startup.""" - logger.debug("Waiting for server startup to complete") + self.logger.debug("Waiting for server startup to complete") await self._startup_complete.wait() async def cleanup(self) -> None: """Cleanup server resources and ensure graceful shutdown.""" - logger.debug("Cleaning up server resources") + self.logger.debug("Cleaning up server resources") if not self._should_exit: self._should_exit = True - logger.debug("Initiating server shutdown") + self.logger.debug("Initiating server shutdown") self._shutdown_event.set() if hasattr(self.server, "shutdown"): - logger.debug("Shutting down server") + self.logger.debug("Shutting down server") await self.server.shutdown() # Ensure all connections are closed @@ -72,12 +77,13 @@ async def cleanup(self) -> None: [task.cancel() for task in tasks] await asyncio.gather(*tasks, return_exceptions=True) - logger.debug("Server shutdown complete") + self.logger.debug("Server shutdown complete") def validate_port(ctx: click.Context, param: click.Parameter, value: int) -> int: - logger.debug(f"Validating port number: {value}") """Validate the port number is in valid range.""" + logger = structlog.get_logger("codegate") + logger.debug(f"Validating port number: {value}") if value is not None and not (1 <= value <= 65535): raise click.BadParameter("Port must be between 1 and 65535") return value @@ -286,10 +292,14 @@ def serve( db_path=db_path, ) + # Set up logging first + setup_logging(cfg.log_level, cfg.log_format) + logger = structlog.get_logger("codegate") + init_db_sync(cfg.db_path) # Check certificates and create CA if necessary - logger.info("Checking certificates and creating CA our created") + logger.info("Checking certificates and creating CA if needed") ca = CertificateAuthority.get_instance() ca.ensure_certificates_exist() @@ -311,8 +321,8 @@ def serve( click.echo(f"Configuration error: {e}", err=True) sys.exit(1) except Exception as e: - if logger: - logger.exception("Unexpected error occurred") + logger = structlog.get_logger("codegate") + logger.exception("Unexpected error occurred") click.echo(f"Error: {e}", err=True) sys.exit(1) @@ -320,7 +330,6 @@ def serve( async def run_servers(cfg: Config, app) -> None: """Run the codegate server.""" try: - setup_logging(cfg.log_level, cfg.log_format) logger = structlog.get_logger("codegate") logger.info( "Starting server", diff --git a/src/codegate/pipeline/secrets/signatures.py b/src/codegate/pipeline/secrets/signatures.py index d6cbed36..1f1a7ddd 100644 --- a/src/codegate/pipeline/secrets/signatures.py +++ b/src/codegate/pipeline/secrets/signatures.py @@ -175,8 +175,10 @@ def _load_signatures(cls) -> None: yaml_data = cls._load_yaml(cls._yaml_path) # Add custom GitHub token patterns - github_patterns = {"Access Token": r"ghp_[0-9a-zA-Z]{32}", - "Personal Token": r"github_pat_[a-zA-Z0-9]{22}_[a-zA-Z0-9]{59}"} + github_patterns = { + "Access Token": r"ghp_[0-9a-zA-Z]{32}", + "Personal Token": r"github_pat_[a-zA-Z0-9]{22}_[a-zA-Z0-9]{59}", + } cls._add_signature_group("GitHub", github_patterns) # Process patterns from YAML diff --git a/src/codegate/providers/copilot/provider.py b/src/codegate/providers/copilot/provider.py index 31b9bfcf..dda7b9b2 100644 --- a/src/codegate/providers/copilot/provider.py +++ b/src/codegate/providers/copilot/provider.py @@ -1,7 +1,8 @@ import asyncio import re import ssl -from typing import Dict, Optional, Tuple +from dataclasses import dataclass +from typing import Dict, List, Optional, Tuple from urllib.parse import unquote, urljoin, urlparse import structlog @@ -12,14 +13,34 @@ logger = structlog.get_logger("codegate") -# Increase buffer sizes +# Constants MAX_BUFFER_SIZE = 10 * 1024 * 1024 # 10MB CHUNK_SIZE = 64 * 1024 # 64KB +HTTP_STATUS_MESSAGES = { + 400: "Bad Request", + 404: "Not Found", + 413: "Request Entity Too Large", + 502: "Bad Gateway", +} + + +@dataclass +class HttpRequest: + """Data class to store HTTP request details""" + + method: str + path: str + version: str + headers: List[str] + original_path: str + target: Optional[str] = None class CopilotProvider(asyncio.Protocol): - def __init__(self, loop): - logger.debug("Initializing CopilotProvider class: CopilotProvider") + """Protocol implementation for the Copilot proxy server""" + + def __init__(self, loop: asyncio.AbstractEventLoop): + logger.debug("Initializing CopilotProvider") self.loop = loop self.transport: Optional[asyncio.Transport] = None self.target_transport: Optional[asyncio.Transport] = None @@ -29,49 +50,40 @@ def __init__(self, loop): self.target_port: Optional[int] = None self.handshake_done = False self.is_connect = False - self.content_length = 0 self.headers_parsed = False - self.method = None - self.path = None - self.version = None - self.headers = [] - self.target = None - self.original_path = None - self.ssl_context = None - self.proxy_ep = None - self.decrypted_data = bytearray() - # Get the singleton instance of CertificateAuthority + self.request: Optional[HttpRequest] = None + self.ssl_context: Optional[ssl.SSLContext] = None + self.proxy_ep: Optional[str] = None self.ca = CertificateAuthority.get_instance() + self._closing = False - def connection_made(self, transport: asyncio.Transport): - logger.debug("Client connected fn: connection_made") + def connection_made(self, transport: asyncio.Transport) -> None: + """Handle new client connection""" self.transport = transport self.peername = transport.get_extra_info("peername") logger.debug(f"Client connected from {self.peername}") - def extract_path(self, full_path: str) -> str: - logger.debug(f"Extracting path from {full_path} fn: extract_path") - if full_path.startswith("http://") or full_path.startswith("https://"): + @staticmethod + def extract_path(full_path: str) -> str: + """Extract clean path from full URL or path string""" + logger.debug(f"Extracting path from {full_path}") + if full_path.startswith(("http://", "https://")): parsed = urlparse(full_path) path = parsed.path if parsed.query: path = f"{path}?{parsed.query}" return path.lstrip("/") - elif full_path.startswith("/"): - return full_path.lstrip("/") - return full_path + return full_path.lstrip("/") - def get_headers(self) -> Dict[str, str]: - """Get request headers as a dictionary""" - logger.debug("Getting headers as dictionary fn: get_headers") + def get_headers_dict(self) -> Dict[str, str]: + """Convert raw headers to dictionary format""" headers_dict = {} - try: if b"\r\n\r\n" not in self.buffer: return {} headers_end = self.buffer.index(b"\r\n\r\n") - headers = self.buffer[:headers_end].split(b"\r\n")[1:] # Skip request line + headers = self.buffer[:headers_end].split(b"\r\n")[1:] for header in headers: try: @@ -82,11 +94,11 @@ def get_headers(self) -> Dict[str, str]: return headers_dict except Exception as e: - logger.error(f"Error getting headers: {e}") + logger.error(f"Error parsing headers: {e}") return {} def parse_headers(self) -> bool: - logger.debug("Parsing headers fn: parse_headers") + """Parse HTTP headers from buffer""" try: if b"\r\n\r\n" not in self.buffer: return False @@ -95,77 +107,66 @@ def parse_headers(self) -> bool: headers = self.buffer[:headers_end].split(b"\r\n") request = headers[0].decode("utf-8") - self.method, full_path, self.version = request.split(" ") - - self.original_path = full_path - - if self.method == "CONNECT": - logger.debug(f"CONNECT request to {full_path}") - self.target = full_path - self.path = "" - else: - logger.debug(f"Request: {self.method} {full_path} {self.version}") - self.path = self.extract_path(full_path) + method, full_path, version = request.split(" ") + + self.request = HttpRequest( + method=method, + path=self.extract_path(full_path), + version=version, + headers=[header.decode("utf-8") for header in headers[1:]], + original_path=full_path, + target=full_path if method == "CONNECT" else None, + ) - self.headers = [header.decode("utf-8") for header in headers[1:]] + logger.debug(f"Request: {method} {full_path} {version}") return True + except Exception as e: logger.error(f"Error parsing headers: {e}") return False - def log_decrypted_data(self, data: bytes, direction: str): - """ - Uncomment to log data from payload - """ - try: - # decoded = data.decode('utf-8') - # logger.debug(f"=== Decrypted {direction} Data ===") - # logger.debug(decoded) - # logger.debug("=" * 40) - pass - except UnicodeDecodeError: - # pass - # logger.debug(f"=== Decrypted {direction} Data (hex) ===") - # logger.debug(data.hex()) - # logger.debug("=" * 40) - pass - - async def handle_http_request(self): - logger.debug("Handling HTTP request fn: handle_http_request") - logger.debug("=" * 40) - logger.debug(f"Method: {self.method}") - logger.debug(f"Searched Path: {self.path} in target URL") + def _check_buffer_size(self, new_data: bytes) -> bool: + """Check if adding new data would exceed buffer size limit""" + return len(self.buffer) + len(new_data) <= MAX_BUFFER_SIZE + + def _forward_data_to_target(self, data: bytes) -> None: + """Forward data to target if connection is established""" + if self.target_transport and not self.target_transport.is_closing(): + self._log_decrypted_data(data, "Client to Server") + self.target_transport.write(data) + + def data_received(self, data: bytes) -> None: + """Handle received data from client""" try: - # Extract proxy endpoint from authorization header if present - headers_dict = self.get_headers() - auth_header = headers_dict.get("authorization", "") - if auth_header: - match = re.search(r"proxy-ep=([^;]+)", auth_header) - if match: - self.proxy_ep = match.group(1) - logger.debug(f"Extracted proxy-ep value: {self.proxy_ep}") - - # Check if the proxy_ep includes a scheme - parsed_proxy_ep = urlparse(self.proxy_ep) - if not parsed_proxy_ep.scheme: - # Default to https if no scheme is provided - self.proxy_ep = f"https://{self.proxy_ep}" - logger.debug(f"Added default scheme to proxy-ep: {self.proxy_ep}") - - target_url = f"{self.proxy_ep}/{self.path}" - else: - target_url = await self.get_target_url(self.path) + if not self._check_buffer_size(data): + self.send_error_response(413, b"Request body too large") + return + + self.buffer.extend(data) + + if not self.headers_parsed: + self.headers_parsed = self.parse_headers() + if self.headers_parsed: + if self.request.method == "CONNECT": + self.handle_connect() + else: + asyncio.create_task(self.handle_http_request()) else: - target_url = await self.get_target_url(self.path) + self._forward_data_to_target(data) + except Exception as e: + logger.error(f"Error processing received data: {e}") + self.send_error_response(502, str(e).encode()) + + async def handle_http_request(self) -> None: + """Handle standard HTTP request""" + try: + target_url = await self._get_target_url() if not target_url: self.send_error_response(404, b"Not Found") return - logger.debug(f"Target URL: {target_url}") parsed_url = urlparse(target_url) - logger.debug(f"Parsed URL {parsed_url}") - self.target_host = parsed_url.hostname self.target_port = parsed_url.port or (443 if parsed_url.scheme == "https" else 80) @@ -221,226 +222,201 @@ async def handle_http_request(self): logger.error(f"Error handling HTTP request: {e}") self.send_error_response(502, str(e).encode()) - def _check_buffer_size(self, new_data: bytes) -> bool: - """Check if adding new data would exceed the maximum buffer size""" - return len(self.buffer) + len(new_data) <= MAX_BUFFER_SIZE + async def _get_target_url(self) -> Optional[str]: + """Determine target URL based on request path and headers""" + headers_dict = self.get_headers_dict() + auth_header = headers_dict.get("authorization", "") + + if auth_header: + match = re.search(r"proxy-ep=([^;]+)", auth_header) + if match: + self.proxy_ep = match.group(1) + if not urlparse(self.proxy_ep).scheme: + self.proxy_ep = f"https://{self.proxy_ep}" + return f"{self.proxy_ep}/{self.request.path}" + + return await self.get_target_url(self.request.path) + + async def _establish_target_connection(self, use_ssl: bool) -> None: + """Establish connection to target server""" + target_protocol = CopilotProxyTargetProtocol(self) + await self.loop.create_connection( + lambda: target_protocol, self.target_host, self.target_port, ssl=use_ssl + ) - def _handle_parsed_headers(self) -> None: - """Handle the request based on parsed headers""" - if self.method == "CONNECT": - logger.debug("Handling CONNECT request") - self.handle_connect() - else: - logger.debug("Handling HTTP request") - asyncio.create_task(self.handle_http_request()) + def _send_request_to_target(self) -> None: + """Send modified request to target server""" + if not self.target_transport: + logger.error("Target transport not available") + self.send_error_response(502, b"Failed to establish target connection") + return - def _forward_data_to_target(self, data: bytes) -> None: - """Forward data to target if connection is established""" - if self.target_transport and not self.target_transport.is_closing(): - self.log_decrypted_data(data, "Client to Server") - self.target_transport.write(data) + headers = self._prepare_request_headers() + self.target_transport.write(headers) - def data_received(self, data: bytes) -> None: - """Handle received data from the client""" - logger.debug(f"Data received from {self.peername} fn: data_received") + body_start = self.buffer.index(b"\r\n\r\n") + 4 + body = self.buffer[body_start:] - try: - # Check buffer size limit - if not self._check_buffer_size(data): - logger.error("Request exceeds maximum buffer size") - self.send_error_response(413, b"Request body too large") - return + if body: + self._log_decrypted_data(body, "Request Body") + for i in range(0, len(body), CHUNK_SIZE): + self.target_transport.write(body[i : i + CHUNK_SIZE]) - # Append new data to buffer - self.buffer.extend(data) + def _prepare_request_headers(self) -> bytes: + """Prepare modified request headers""" + new_headers = [] + has_host = False - if not self.headers_parsed: - # Try to parse headers - self.headers_parsed = self.parse_headers() - if not self.headers_parsed: - return - - # Handle the request based on parsed headers - self._handle_parsed_headers() + for header in self.request.headers: + if header.lower().startswith("host:"): + has_host = True + new_headers.append(f"Host: {self.target_host}") else: - # Forward data to target if headers are already parsed - self._forward_data_to_target(data) + new_headers.append(header) - except asyncio.CancelledError: - logger.warning("Operation cancelled") - raise - except Exception as e: - logger.error(f"Error processing received data: {e}") - self.send_error_response(502, str(e).encode()) + if not has_host: + new_headers.append(f"Host: {self.target_host}") - def handle_connect(self): - """ - This where requests are sent directly via the tunnel created during - a CONNECT request. This is where the SSL context is created and the - internal connection is made to the target host. + request_line = f"{self.request.method} /{self.request.path} {self.request.version}\r\n" + header_block = "\r\n".join(new_headers) + return f"{request_line}{header_block}\r\n\r\n".encode() - We do not need to do a URL to mapping, as this passes through the - tunnel with a FQDN already set by the source (client) request. - """ + def handle_connect(self) -> None: + """Handle CONNECT request for SSL/TLS tunneling""" try: - path = unquote(self.target) - if ":" in path: - self.target_host, port = path.split(":") - self.target_port = int(port) - logger.debug("=" * 40) - logger.debug(f"CONNECT request to {self.target_host}:{self.target_port}") - logger.debug("Headers:") - for header in self.headers: - logger.debug(f" {header}") + path = unquote(self.request.target) + if ":" not in path: + raise ValueError(f"Invalid CONNECT path: {path}") - logger.debug("=" * 40) - cert_path, key_path = self.ca.get_domain_certificate(self.target_host) + self.target_host, port = path.split(":") + self.target_port = int(port) - logger.debug(f"Setting up SSL context for {self.target_host}") - logger.debug(f"Using certificate: {cert_path}") - logger.debug(f"Using key: {key_path}") + cert_path, key_path = self.ca.get_domain_certificate(self.target_host) + self.ssl_context = self._create_ssl_context(cert_path, key_path) - self.ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER) - self.ssl_context.load_cert_chain(cert_path, key_path) - self.ssl_context.minimum_version = ssl.TLSVersion.TLSv1_2 + self.is_connect = True + asyncio.create_task(self.connect_to_target()) + self.handshake_done = True - self.is_connect = True - logger.debug("CONNECT handshake complete") - asyncio.create_task(self.connect_to_target()) - self.handshake_done = True - else: - logger.error(f"Invalid CONNECT path: {path}") - self.send_error_response(400, b"Invalid CONNECT path") except Exception as e: logger.error(f"Error handling CONNECT: {e}") self.send_error_response(502, str(e).encode()) - def send_error_response(self, status: int, message: bytes): - logger.debug(f"Sending error response: {status} {message} fn: send_error_response") - response = ( - f"HTTP/1.1 {status} {self.get_status_text(status)}\r\n" - f"Content-Length: {len(message)}\r\n" - f"Content-Type: text/plain\r\n" - f"\r\n" - ).encode() + message - if self.transport and not self.transport.is_closing(): - self.transport.write(response) - self.transport.close() + def _create_ssl_context(self, cert_path: str, key_path: str) -> ssl.SSLContext: + """Create SSL context for CONNECT tunneling""" + ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER) + ssl_context.load_cert_chain(cert_path, key_path) + ssl_context.minimum_version = ssl.TLSVersion.TLSv1_2 + return ssl_context - def get_status_text(self, status: int) -> str: - logger.debug(f"Getting status text for {status} fn: get_status_text") - status_texts = { - 400: "Bad Request", - 404: "Not Found", - 413: "Request Entity Too Large", - 502: "Bad Gateway", - } - return status_texts.get(status, "Error") - - async def connect_to_target(self): - logger.debug( - f"Connecting to target {self.target_host}:{self.target_port} fn: connect_to_target" - ) + async def connect_to_target(self) -> None: + """Establish connection to target for CONNECT requests""" try: if not self.target_host or not self.target_port: raise ValueError("Target host and port not set") - logger.debug(f"Attempting to connect to {self.target_host}:{self.target_port}") - - # Create SSL context for target connection - logger.debug("Creating SSL context for target connection") target_ssl_context = ssl.create_default_context() - # Don't verify certificates when connecting to target target_ssl_context.check_hostname = False target_ssl_context.verify_mode = ssl.CERT_NONE - # Connect directly to target host - logger.debug(f"Connecting to {self.target_host}:{self.target_port}") target_protocol = CopilotProxyTargetProtocol(self) transport, _ = await self.loop.create_connection( lambda: target_protocol, self.target_host, self.target_port, ssl=target_ssl_context ) - logger.debug(f"Successfully connected to {self.target_host}:{self.target_port}") - - # Send 200 Connection Established if self.transport and not self.transport.is_closing(): - logger.debug("Sending 200 Connection Established response") self.transport.write( b"HTTP/1.1 200 Connection Established\r\n" b"Proxy-Agent: ProxyPilot\r\n" b"Connection: keep-alive\r\n\r\n" ) - # Upgrade client connection to SSL - logger.debug("Upgrading client connection to SSL") - transport = await self.loop.start_tls( + self.transport = await self.loop.start_tls( self.transport, self, self.ssl_context, server_side=True ) - self.transport = transport except Exception as e: logger.error(f"Failed to connect to target {self.target_host}:{self.target_port}: {e}") self.send_error_response(502, str(e).encode()) - def connection_lost(self, exc): - logger.debug(f"Connection lost from {self.peername} fn: connection_lost") - logger.debug(f"Client disconnected from {self.peername}") + def send_error_response(self, status: int, message: bytes) -> None: + """Send error response to client""" + if self._closing: + return + + response = ( + f"HTTP/1.1 {status} {HTTP_STATUS_MESSAGES.get(status, 'Error')}\r\n" + f"Content-Length: {len(message)}\r\n" + f"Content-Type: text/plain\r\n" + f"\r\n" + ).encode() + message + + if self.transport and not self.transport.is_closing(): + self.transport.write(response) + self.transport.close() + + def connection_lost(self, exc: Optional[Exception]) -> None: + """Handle connection loss""" + if self._closing: + return + + self._closing = True + logger.debug(f"Connection lost from {self.peername}") + + # Close target transport if it exists and isn't already closing if self.target_transport and not self.target_transport.is_closing(): - self.target_transport.close() + try: + self.target_transport.close() + except Exception as e: + logger.error(f"Error closing target transport: {e}") + + # Clear references to help with cleanup + self.transport = None + self.target_transport = None + self.buffer.clear() + self.ssl_context = None + + @staticmethod + def _log_decrypted_data(data: bytes, direction: str) -> None: + """Log decrypted data for debugging""" + pass # Logging disabled by default @classmethod async def create_proxy_server( cls, host: str, port: int, ssl_context: Optional[ssl.SSLContext] = None - ): - logger.debug(f"Creating proxy server on {host}:{port} fn: create_proxy_server") + ) -> asyncio.AbstractServer: + """Create and start proxy server""" loop = asyncio.get_event_loop() - - def create_protocol(): - logger.debug("Creating protocol for proxy server fn: create_protocol") - return cls(loop) - - logger.debug(f"Starting proxy server on {host}:{port}") server = await loop.create_server( - create_protocol, host, port, ssl=ssl_context, reuse_port=True, start_serving=True + lambda: cls(loop), host, port, ssl=ssl_context, reuse_port=True, start_serving=True ) - logger.debug(f"Proxy server running on {host}:{port}") return server @classmethod - async def run_proxy_server(cls): - logger.debug("Running proxy server fn: run_proxy_server") + async def run_proxy_server(cls) -> None: + """Run the proxy server""" try: - # Get the singleton instance of CertificateAuthority ca = CertificateAuthority.get_instance() - logger.debug("Creating SSL context for proxy server") ssl_context = ca.create_ssl_context() - server = await cls.create_proxy_server( - Config.get_config().host, Config.get_config().proxy_port, ssl_context - ) - logger.debug("Proxy server created") + config = Config.get_config() + server = await cls.create_proxy_server(config.host, config.proxy_port, ssl_context) + async with server: await server.serve_forever() - except Exception as e: logger.error(f"Proxy server error: {e}") raise - @classmethod - async def get_target_url(cls, path: str) -> Optional[str]: + @staticmethod + async def get_target_url(path: str) -> Optional[str]: """Get target URL for the given path""" - logger.debug(f"Attempting to get target URL for path: {path} fn: get_target_url") - - logger.debug("=" * 40) - logger.debug("Validated routes:") + # Check for exact path match for route in VALIDATED_ROUTES: if path == route.path: - logger.debug(f" {route.path} -> {route.target}") - logger.debug(f"Found exact path match: {path} -> {route.target}") return str(route.target) - # Then check for prefix match + # Check for prefix match for route in VALIDATED_ROUTES: # For prefix matches, keep the rest of the path remaining_path = path[len(route.path) :] @@ -449,10 +425,6 @@ async def get_target_url(cls, path: str) -> Optional[str]: if remaining_path and remaining_path.startswith("/"): remaining_path = remaining_path[1:] target = urljoin(str(route.target), remaining_path) - logger.debug( - f"Found prefix match: {path} -> {target} " - "(using route {route.path} -> {route.target})" - ) return target logger.warning(f"No route found for path: {path}") @@ -460,25 +432,31 @@ async def get_target_url(cls, path: str) -> Optional[str]: class CopilotProxyTargetProtocol(asyncio.Protocol): + """Protocol implementation for proxy target connections""" + def __init__(self, proxy: CopilotProvider): - logger.debug("Initializing CopilotProxyTargetProtocol class: CopilotProxyTargetProtocol") self.proxy = proxy self.transport: Optional[asyncio.Transport] = None - def connection_made(self, transport: asyncio.Transport): - logger.debug(f"Connection made to target {self.proxy.target_host}:{self.proxy.target_port}") + def connection_made(self, transport: asyncio.Transport) -> None: + """Handle successful connection to target""" self.transport = transport self.proxy.target_transport = transport - def data_received(self, data: bytes): - logger.debug(f"Data received from target {self.proxy.target_host}:{self.proxy.target_port}") + def data_received(self, data: bytes) -> None: + """Handle data received from target""" if self.proxy.transport and not self.proxy.transport.is_closing(): - self.proxy.log_decrypted_data(data, "Server to Client") + self.proxy._log_decrypted_data(data, "Server to Client") self.proxy.transport.write(data) - def connection_lost(self, exc): - logger.debug( - f"Connection lost from target {self.proxy.target_host}:{self.proxy.target_port}" - ) - if self.proxy.transport and not self.proxy.transport.is_closing(): - self.proxy.transport.close() + def connection_lost(self, exc: Optional[Exception]) -> None: + """Handle connection loss to target""" + if ( + not self.proxy._closing + and self.proxy.transport + and not self.proxy.transport.is_closing() + ): + try: + self.proxy.transport.close() + except Exception as e: + logger.error(f"Error closing proxy transport: {e}") diff --git a/tests/conftest.py b/tests/conftest.py index afbb5956..0d8c6731 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -13,6 +13,12 @@ from codegate.config import Config +@pytest.fixture(autouse=True) +def setup_config() -> None: + """Initialize Config with default prompts before each test.""" + Config.load() + + @pytest.fixture def temp_config_file(tmp_path: Path) -> Iterator[Path]: """Create a temporary config file.""" diff --git a/tests/test_cli.py b/tests/test_cli.py index b91e7025..54c22434 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -1,397 +1,150 @@ -"""Tests for the CLI module.""" +"""Tests for the server module.""" -from pathlib import Path -from typing import Any, AsyncGenerator -from unittest.mock import AsyncMock, MagicMock, patch +from unittest.mock import MagicMock, patch import pytest -from click.testing import CliRunner -from uvicorn.config import Config as UvicornConfig +from fastapi.middleware.cors import CORSMiddleware +from fastapi.testclient import TestClient +from httpx import AsyncClient -from codegate.cli import UvicornServer, cli +from codegate import __version__ +from codegate.pipeline.secrets.manager import SecretsManager +from codegate.providers.registry import ProviderRegistry +from codegate.server import init_app @pytest.fixture -def cli_runner() -> CliRunner: - """Create a Click CLI test runner.""" - return CliRunner() +def mock_secrets_manager(): + """Create a mock secrets manager.""" + return MagicMock(spec=SecretsManager) @pytest.fixture -def mock_logging(monkeypatch: Any) -> MagicMock: - """Mock the logging function.""" - mock = MagicMock() - monkeypatch.setattr("codegate.cli.structlog.get_logger", mock) - return mock +def mock_provider_registry(): + """Create a mock provider registry.""" + return MagicMock(spec=ProviderRegistry) @pytest.fixture -def mock_setup_logging(monkeypatch: Any) -> MagicMock: - """Mock the setup_logging function.""" - mock = MagicMock() - monkeypatch.setattr("codegate.cli.setup_logging", mock) - return mock +def test_client() -> TestClient: + """Create a test client for the FastAPI application.""" + app = init_app() + return TestClient(app) -@pytest.fixture -def temp_config_file(tmp_path: Path) -> Path: - """Create a temporary config file.""" - config_file = tmp_path / "config.yaml" - config_file.write_text( - """ -port: 8989 -host: localhost -log_level: DEBUG -log_format: JSON -certs_dir: "./test-certs" -ca_cert: "test-ca.crt" -ca_key: "test-ca.key" -server_cert: "test-server.crt" -server_key: "test-server.key" -""" - ) - return config_file +def test_app_initialization() -> None: + """Test that the FastAPI application initializes correctly.""" + app = init_app() + assert app is not None + assert app.title == "CodeGate" + assert app.version == __version__ + + +def test_cors_middleware() -> None: + """Test that CORS middleware is properly configured.""" + app = init_app() + cors_middleware = None + for middleware in app.user_middleware: + if isinstance(middleware.cls, type) and issubclass(middleware.cls, CORSMiddleware): + cors_middleware = middleware + break + assert cors_middleware is not None + assert cors_middleware.kwargs.get("allow_origins") == ["*"] + assert cors_middleware.kwargs.get("allow_credentials") is True + assert cors_middleware.kwargs.get("allow_methods") == ["*"] + assert cors_middleware.kwargs.get("allow_headers") == ["*"] + + +def test_health_check(test_client: TestClient) -> None: + """Test the health check endpoint.""" + response = test_client.get("/health") + assert response.status_code == 200 + assert response.json() == {"status": "healthy"} + + +@patch("codegate.server.ProviderRegistry") +@patch("codegate.server.SecretsManager") +def test_provider_registration(mock_secrets_mgr, mock_registry) -> None: + """Test that all providers are registered correctly.""" + init_app() + + # Verify SecretsManager was initialized + mock_secrets_mgr.assert_called_once() + + # Verify ProviderRegistry was initialized with the app + mock_registry.assert_called_once() + + # Verify all providers were registered + registry_instance = mock_registry.return_value + assert ( + registry_instance.add_provider.call_count == 5 + ) # openai, anthropic, llamacpp, vllm, ollama + + # Verify specific providers were registered + provider_names = [call.args[0] for call in registry_instance.add_provider.call_args_list] + assert "openai" in provider_names + assert "anthropic" in provider_names + assert "llamacpp" in provider_names + assert "vllm" in provider_names + assert "ollama" in provider_names + + +@patch("codegate.server.CodegateSignatures") +def test_signatures_initialization(mock_signatures) -> None: + """Test that signatures are initialized correctly.""" + init_app() + mock_signatures.initialize.assert_called_once_with("signatures.yaml") + + +def test_pipeline_initialization() -> None: + """Test that pipelines are initialized correctly.""" + app = init_app() + + # Access the provider registry to check pipeline configuration + registry = next((route for route in app.routes if hasattr(route, "registry")), None) + + if registry: + for provider in registry.registry.values(): + # Verify each provider has the required pipelines + assert hasattr(provider, "pipeline_processor") + assert hasattr(provider, "fim_pipeline_processor") + assert hasattr(provider, "output_pipeline_processor") + + +def test_dashboard_routes() -> None: + """Test that dashboard routes are included.""" + app = init_app() + routes = [route.path for route in app.routes] + + # Verify dashboard endpoints are included + dashboard_routes = [route for route in routes if route.startswith("/dashboard")] + assert len(dashboard_routes) > 0 + + +def test_system_routes() -> None: + """Test that system routes are included.""" + app = init_app() + routes = [route.path for route in app.routes] + + # Verify system endpoints are included + assert "/health" in routes -@pytest.fixture -async def mock_uvicorn_server() -> AsyncGenerator[MagicMock, None]: - """Create a mock Uvicorn server.""" - config = UvicornConfig(app=MagicMock(), host="localhost", port=8989) - server = MagicMock() - server.serve = AsyncMock() - server.shutdown = AsyncMock() - yield UvicornServer(config=config, server=server) - - -def test_cli_version(cli_runner: CliRunner) -> None: - """Test CLI version command.""" - result = cli_runner.invoke(cli, ["--version"]) - assert result.exit_code == 0 - - -# @pytest.mark.asyncio -# async def test_uvicorn_server_serve(mock_uvicorn_server: UvicornServer) -> None: -# """Test UvicornServer serve method.""" -# # Start server in background task -# server_task = asyncio.create_task(mock_uvicorn_server.serve()) -# -# # Wait for startup to complete -# await mock_uvicorn_server.wait_startup_complete() -# -# # Verify server started -# assert mock_uvicorn_server.server.serve.called -# -# # Cleanup -# await mock_uvicorn_server.cleanup() -# await server_task - - -# @pytest.mark.asyncio -# async def test_uvicorn_server_cleanup(mock_uvicorn_server: UvicornServer) -> None: -# """Test UvicornServer cleanup method.""" -# # Start server -# server_task = asyncio.create_task(mock_uvicorn_server.serve()) -# await mock_uvicorn_server.wait_startup_complete() -# -# # Trigger cleanup -# await mock_uvicorn_server.cleanup() -# -# # Verify shutdown was called -# assert mock_uvicorn_server.server.shutdown.called -# assert mock_uvicorn_server._shutdown_event.is_set() -# -# await server_task - - -# @pytest.mark.asyncio -# async def test_uvicorn_server_signal_handling(mock_uvicorn_server: UvicornServer) -> None: -# """Test signal handling in UvicornServer.""" -# # Mock signal handlers -# with patch("asyncio.get_running_loop") as mock_loop: -# mock_loop_instance = MagicMock() -# mock_loop.return_value = mock_loop_instance -# -# # Start server -# server_task = asyncio.create_task(mock_uvicorn_server.serve()) -# await mock_uvicorn_server.wait_startup_complete() -# -# # Simulate SIGTERM -# mock_loop_instance.add_signal_handler.assert_any_call( -# signal.SIGTERM, pytest.approx(type(lambda: None)) # Check if a callable was passed -# ) -# -# # Simulate SIGINT -# mock_loop_instance.add_signal_handler.assert_any_call( -# signal.SIGINT, pytest.approx(type(lambda: None)) # Check if a callable was passed -# ) -# -# await mock_uvicorn_server.cleanup() -# await server_task - - -def test_serve_default_options( - cli_runner: CliRunner, mock_logging: Any, mock_setup_logging: Any -) -> None: - """Test serve command with default options.""" - with patch("codegate.cli.run_servers") as mock_run: - logger_instance = MagicMock() - mock_logging.return_value = logger_instance - result = cli_runner.invoke(cli, ["serve"]) - - assert result.exit_code == 0 - # mock_setup_logging.assert_called_once_with(LogLevel.INFO, LogFormat.JSON) - # mock_logging.assert_called_with("codegate") - - # validate only a subset of the expected extra arguments - # expected_extra = { - # "host": "localhost", - # "port": 8989, - # "log_level": "INFO", - # "log_format": "JSON", - # "prompts_loaded": 7, - # "provider_urls": DEFAULT_PROVIDER_URLS, - # "certs_dir": "./certs", # Default certificate directory - # } - - # Retrieve the actual call arguments - # calls = [call[1]["extra"] for call in logger_instance.info.call_args_list] - - # Check if one of the calls matches the expected subset - # assert any( - # all(expected_extra[k] == actual_extra.get(k) for k in expected_extra) - # for actual_extra in calls - # ) - mock_run.assert_called_once() - - -def test_serve_custom_options( - cli_runner: CliRunner, mock_logging: Any, mock_setup_logging: Any -) -> None: - """Test serve command with custom options.""" - with patch("codegate.cli.run_servers") as mock_run: - logger_instance = MagicMock() - mock_logging.return_value = logger_instance - result = cli_runner.invoke( - cli, - [ - "serve", - "--port", - "8989", - "--host", - "localhost", - "--log-level", - "DEBUG", - "--log-format", - "TEXT", - "--certs-dir", - "./custom-certs", - "--ca-cert", - "custom-ca.crt", - "--ca-key", - "custom-ca.key", - "--server-cert", - "custom-server.crt", - "--server-key", - "custom-server.key", - ], - ) - - assert result.exit_code == 0 - # mock_setup_logging.assert_called_once_with(LogLevel.DEBUG, LogFormat.TEXT) - # mock_logging.assert_called_with("codegate") - - # Retrieve the actual call arguments - # calls = [call[1]["extra"] for call in logger_instance.info.call_args_list] - - # expected_extra = { - # "host": "localhost", - # "port": 8989, - # "log_level": "DEBUG", - # "log_format": "TEXT", - # "prompts_loaded": 7, # Default prompts are loaded - # "provider_urls": DEFAULT_PROVIDER_URLS, - # "certs_dir": "./custom-certs", - # } - - # Check if one of the calls matches the expected subset - # assert any( - # all(expected_extra[k] == actual_extra.get(k) for k in expected_extra) - # for actual_extra in calls - # ) - mock_run.assert_called_once() - - -def test_serve_invalid_port(cli_runner: CliRunner) -> None: - """Test serve command with invalid port.""" - result = cli_runner.invoke(cli, ["serve", "--port", "999999"]) - assert result.exit_code == 2 - assert "Port must be between 1 and 65535" in result.output - - -def test_serve_invalid_log_level(cli_runner: CliRunner) -> None: - """Test serve command with invalid log level.""" - result = cli_runner.invoke(cli, ["serve", "--log-level", "INVALID"]) - assert result.exit_code == 2 - assert "Invalid value for '--log-level'" in result.output - - -def test_serve_with_config_file( - cli_runner: CliRunner, mock_logging: Any, temp_config_file: Path, mock_setup_logging: Any -) -> None: - """Test serve command with config file.""" - with patch("codegate.cli.run_servers") as mock_run: - logger_instance = MagicMock() - mock_logging.return_value = logger_instance - result = cli_runner.invoke(cli, ["serve", "--config", str(temp_config_file)]) - - assert result.exit_code == 0 - # mock_setup_logging.assert_called_once_with(LogLevel.DEBUG, LogFormat.JSON) - # mock_logging.assert_called_with("codegate") - - # Retrieve the actual call arguments - # calls = [call[1]["extra"] for call in logger_instance.info.call_args_list] - - # expected_extra = { - # "host": "localhost", - # "port": 8989, - # "log_level": "DEBUG", - # "log_format": "JSON", - # "prompts_loaded": 7, # Default prompts are loaded - # "provider_urls": DEFAULT_PROVIDER_URLS, - # "certs_dir": "./test-certs", # From config file - # } - - # Check if one of the calls matches the expected subset - # assert any( - # all(expected_extra[k] == actual_extra.get(k) for k in expected_extra) - # for actual_extra in calls - # ) - mock_run.assert_called_once() - - -def test_serve_with_nonexistent_config_file(cli_runner: CliRunner) -> None: - """Test serve command with nonexistent config file.""" - result = cli_runner.invoke(cli, ["serve", "--config", "nonexistent.yaml"]) - assert result.exit_code == 2 - assert "does not exist" in result.output - - -def test_serve_priority_resolution( - cli_runner: CliRunner, - mock_logging: Any, - temp_config_file: Path, - env_vars: Any, - mock_setup_logging: Any, -) -> None: - """Test serve command respects configuration priority.""" - with patch("codegate.cli.run_servers") as mock_run: - logger_instance = MagicMock() - mock_logging.return_value = logger_instance - result = cli_runner.invoke( - cli, - [ - "serve", - "--config", - str(temp_config_file), - "--port", - "8080", - "--host", - "example.com", - "--log-level", - "ERROR", - "--log-format", - "TEXT", - "--certs-dir", - "./cli-certs", - "--ca-cert", - "cli-ca.crt", - "--ca-key", - "cli-ca.key", - "--server-cert", - "cli-server.crt", - "--server-key", - "cli-server.key", - ], - ) - - assert result.exit_code == 0 - # mock_setup_logging.assert_called_once_with(LogLevel.ERROR, LogFormat.TEXT) - # mock_logging.assert_called_with("codegate") - - # Retrieve the actual call arguments - # calls = [call[1]["extra"] for call in logger_instance.info.call_args_list] - - # expected_extra = { - # "host": "example.com", - # "port": 8080, - # "log_level": "ERROR", - # "log_format": "TEXT", - # "prompts_loaded": 7, # Default prompts are loaded - # "provider_urls": DEFAULT_PROVIDER_URLS, - # "certs_dir": "./cli-certs", # CLI args override config file - # } - - # Check if one of the calls matches the expected subset - # assert any( - # all(expected_extra[k] == actual_extra.get(k) for k in expected_extra) - # for actual_extra in calls - # ) - mock_run.assert_called_once() - - -def test_serve_certificate_options( - cli_runner: CliRunner, mock_logging: Any, mock_setup_logging: Any -) -> None: - """Test serve command with certificate options.""" - with patch("codegate.cli.run_servers") as mock_run: - logger_instance = MagicMock() - mock_logging.return_value = logger_instance - result = cli_runner.invoke( - cli, - [ - "serve", - "--certs-dir", - "./custom-certs", - "--ca-cert", - "custom-ca.crt", - "--ca-key", - "custom-ca.key", - "--server-cert", - "custom-server.crt", - "--server-key", - "custom-server.key", - ], - ) - - assert result.exit_code == 0 - # mock_setup_logging.assert_called_once_with(LogLevel.INFO, LogFormat.JSON) - # mock_logging.assert_called_with("codegate") - - # Retrieve the actual call arguments - # calls = [call[1]["extra"] for call in logger_instance.info.call_args_list] - - # expected_extra = { - # "host": "localhost", - # "port": 8989, - # "log_level": "INFO", - # "log_format": "JSON", - # "prompts_loaded": 6, - # "provider_urls": DEFAULT_PROVIDER_URLS, - # "certs_dir": "./custom-certs", - # } - - # Check if one of the calls matches the expected subset - # assert any( - # all(expected_extra[k] == actual_extra.get(k) for k in expected_extra) - # for actual_extra in calls - # ) - mock_run.assert_called_once() - - -def test_main_function() -> None: - """Test main function.""" - with patch("codegate.cli.cli") as mock_cli: - from codegate.cli import main - - main() - mock_cli.assert_called_once() +@pytest.mark.asyncio +async def test_async_health_check() -> None: + """Test the health check endpoint with async client.""" + app = init_app() + async with AsyncClient(app=app, base_url="http://test") as ac: + response = await ac.get("/health") + assert response.status_code == 200 + assert response.json() == {"status": "healthy"} + + +def test_error_handling(test_client: TestClient) -> None: + """Test error handling for non-existent endpoints.""" + response = test_client.get("/nonexistent") + assert response.status_code == 404 + + # Test method not allowed + response = test_client.post("/health") # Health endpoint only allows GET + assert response.status_code == 405 diff --git a/tests/test_server.py b/tests/test_server.py new file mode 100644 index 00000000..54c22434 --- /dev/null +++ b/tests/test_server.py @@ -0,0 +1,150 @@ +"""Tests for the server module.""" + +from unittest.mock import MagicMock, patch + +import pytest +from fastapi.middleware.cors import CORSMiddleware +from fastapi.testclient import TestClient +from httpx import AsyncClient + +from codegate import __version__ +from codegate.pipeline.secrets.manager import SecretsManager +from codegate.providers.registry import ProviderRegistry +from codegate.server import init_app + + +@pytest.fixture +def mock_secrets_manager(): + """Create a mock secrets manager.""" + return MagicMock(spec=SecretsManager) + + +@pytest.fixture +def mock_provider_registry(): + """Create a mock provider registry.""" + return MagicMock(spec=ProviderRegistry) + + +@pytest.fixture +def test_client() -> TestClient: + """Create a test client for the FastAPI application.""" + app = init_app() + return TestClient(app) + + +def test_app_initialization() -> None: + """Test that the FastAPI application initializes correctly.""" + app = init_app() + assert app is not None + assert app.title == "CodeGate" + assert app.version == __version__ + + +def test_cors_middleware() -> None: + """Test that CORS middleware is properly configured.""" + app = init_app() + cors_middleware = None + for middleware in app.user_middleware: + if isinstance(middleware.cls, type) and issubclass(middleware.cls, CORSMiddleware): + cors_middleware = middleware + break + assert cors_middleware is not None + assert cors_middleware.kwargs.get("allow_origins") == ["*"] + assert cors_middleware.kwargs.get("allow_credentials") is True + assert cors_middleware.kwargs.get("allow_methods") == ["*"] + assert cors_middleware.kwargs.get("allow_headers") == ["*"] + + +def test_health_check(test_client: TestClient) -> None: + """Test the health check endpoint.""" + response = test_client.get("/health") + assert response.status_code == 200 + assert response.json() == {"status": "healthy"} + + +@patch("codegate.server.ProviderRegistry") +@patch("codegate.server.SecretsManager") +def test_provider_registration(mock_secrets_mgr, mock_registry) -> None: + """Test that all providers are registered correctly.""" + init_app() + + # Verify SecretsManager was initialized + mock_secrets_mgr.assert_called_once() + + # Verify ProviderRegistry was initialized with the app + mock_registry.assert_called_once() + + # Verify all providers were registered + registry_instance = mock_registry.return_value + assert ( + registry_instance.add_provider.call_count == 5 + ) # openai, anthropic, llamacpp, vllm, ollama + + # Verify specific providers were registered + provider_names = [call.args[0] for call in registry_instance.add_provider.call_args_list] + assert "openai" in provider_names + assert "anthropic" in provider_names + assert "llamacpp" in provider_names + assert "vllm" in provider_names + assert "ollama" in provider_names + + +@patch("codegate.server.CodegateSignatures") +def test_signatures_initialization(mock_signatures) -> None: + """Test that signatures are initialized correctly.""" + init_app() + mock_signatures.initialize.assert_called_once_with("signatures.yaml") + + +def test_pipeline_initialization() -> None: + """Test that pipelines are initialized correctly.""" + app = init_app() + + # Access the provider registry to check pipeline configuration + registry = next((route for route in app.routes if hasattr(route, "registry")), None) + + if registry: + for provider in registry.registry.values(): + # Verify each provider has the required pipelines + assert hasattr(provider, "pipeline_processor") + assert hasattr(provider, "fim_pipeline_processor") + assert hasattr(provider, "output_pipeline_processor") + + +def test_dashboard_routes() -> None: + """Test that dashboard routes are included.""" + app = init_app() + routes = [route.path for route in app.routes] + + # Verify dashboard endpoints are included + dashboard_routes = [route for route in routes if route.startswith("/dashboard")] + assert len(dashboard_routes) > 0 + + +def test_system_routes() -> None: + """Test that system routes are included.""" + app = init_app() + routes = [route.path for route in app.routes] + + # Verify system endpoints are included + assert "/health" in routes + + +@pytest.mark.asyncio +async def test_async_health_check() -> None: + """Test the health check endpoint with async client.""" + app = init_app() + async with AsyncClient(app=app, base_url="http://test") as ac: + response = await ac.get("/health") + assert response.status_code == 200 + assert response.json() == {"status": "healthy"} + + +def test_error_handling(test_client: TestClient) -> None: + """Test error handling for non-existent endpoints.""" + response = test_client.get("/nonexistent") + assert response.status_code == 404 + + # Test method not allowed + response = test_client.post("/health") # Health endpoint only allows GET + assert response.status_code == 405