Skip to content
This repository was archived by the owner on Jun 5, 2025. It is now read-only.

Fix Unit Tests #298

Merged
merged 2 commits into from
Dec 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
178 changes: 178 additions & 0 deletions cert_gen.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,178 @@
import datetime
import os

from cryptography import x509
from cryptography.hazmat.primitives import hashes, serialization
from cryptography.hazmat.primitives.asymmetric import rsa
from cryptography.x509.oid import ExtendedKeyUsageOID, NameOID


def generate_certificates(cert_dir="certs"):
"""Generate self-signed certificates with proper extensions for HTTPS proxy"""
# Create certificates directory if it doesn't exist
if not os.path.exists(cert_dir):
print("Making: ", cert_dir)
os.makedirs(cert_dir)

# Generate private key
ca_private_key = rsa.generate_private_key(
public_exponent=65537,
key_size=4096, # Increased key size for better security
)

# Generate public key
ca_public_key = ca_private_key.public_key()

# CA BEGIN
name = x509.Name(
[
x509.NameAttribute(NameOID.COMMON_NAME, "Proxy Pilot CA"),
x509.NameAttribute(NameOID.ORGANIZATION_NAME, "Proxy Pilot"),
x509.NameAttribute(NameOID.ORGANIZATIONAL_UNIT_NAME, "Development"),
x509.NameAttribute(NameOID.COUNTRY_NAME, "UK"),
]
)

builder = x509.CertificateBuilder()
builder = builder.subject_name(name)
builder = builder.issuer_name(name)
builder = builder.public_key(ca_public_key)
builder = builder.serial_number(x509.random_serial_number())
builder = builder.not_valid_before(datetime.datetime.utcnow())
builder = builder.not_valid_after(
datetime.datetime.utcnow() + datetime.timedelta(days=3650) # 10 years
)

builder = builder.add_extension(
x509.BasicConstraints(ca=True, path_length=None),
critical=True,
)

builder = builder.add_extension(
x509.KeyUsage(
digital_signature=True,
content_commitment=False,
key_encipherment=True,
data_encipherment=False,
key_agreement=False,
key_cert_sign=True, # This is a CA
crl_sign=True,
encipher_only=False,
decipher_only=False,
),
critical=True,
)

ca_cert = builder.sign(
private_key=ca_private_key,
algorithm=hashes.SHA256(),
)

# Save CA certificate and key

with open("certs/ca.crt", "wb") as f:
f.write(ca_cert.public_bytes(serialization.Encoding.PEM))

with open("certs/ca.key", "wb") as f:
f.write(
ca_private_key.private_bytes(
encoding=serialization.Encoding.PEM,
format=serialization.PrivateFormat.PKCS8,
encryption_algorithm=serialization.NoEncryption(),
)
)
# CA END

# SERVER BEGIN

# Generate new certificate for domain
server_key = rsa.generate_private_key(
public_exponent=65537,
key_size=2048, # 2048 bits is sufficient for domain certs
)

name = x509.Name(
[
x509.NameAttribute(NameOID.COMMON_NAME, "Proxy Pilot CA"),
x509.NameAttribute(NameOID.ORGANIZATION_NAME, "Proxy Pilot Generated"),
]
)

builder = x509.CertificateBuilder()
builder = builder.subject_name(name)
builder = builder.issuer_name(ca_cert.subject)
builder = builder.public_key(server_key.public_key())
builder = builder.serial_number(x509.random_serial_number())
builder = builder.not_valid_before(datetime.datetime.utcnow())
builder = builder.not_valid_after(datetime.datetime.utcnow() + datetime.timedelta(days=365))

# Add domain to SAN
builder = builder.add_extension(
x509.SubjectAlternativeName([x509.DNSName("localhost")]),
critical=False,
)

# Add extended key usage
builder = builder.add_extension(
x509.ExtendedKeyUsage(
[
ExtendedKeyUsageOID.SERVER_AUTH,
ExtendedKeyUsageOID.CLIENT_AUTH,
]
),
critical=False,
)

# Basic constraints (not a CA)
builder = builder.add_extension(
x509.BasicConstraints(ca=False, path_length=None),
critical=True,
)

certificate = builder.sign(
private_key=ca_private_key,
algorithm=hashes.SHA256(),
)

with open("certs/server.crt", "wb") as f:
f.write(certificate.public_bytes(serialization.Encoding.PEM))

with open("certs/server.key", "wb") as f:
f.write(
server_key.private_bytes(
encoding=serialization.Encoding.PEM,
format=serialization.PrivateFormat.PKCS8,
encryption_algorithm=serialization.NoEncryption(),
)
)

print("Certificates generated successfully in the 'certs' directory")
print("\nTo trust these certificates:")
print("\nOn macOS:")
print(
"sudo security add-trusted-cert -d -r trustRoot "
"-k /Library/Keychains/System.keychain certs/server.crt"
)
print("\nOn Windows (PowerShell as Admin):")
print(
'Import-Certificate -FilePath "certs\\server.crt" '
"-CertStoreLocation Cert:\\LocalMachine\\Root"
)
print("\nOn Linux:")
print("sudo cp certs/server.crt /usr/local/share/ca-certificates/proxy-pilot.crt")
print("sudo update-ca-certificates")
print("\nFor VSCode, add to settings.json:")
print(
"""{
"http.proxy": "https://localhost:8989",
"http.proxySupport": "on",
"github.copilot.advanced": {
"debug.testOverrideProxyUrl": "https://localhost:8989",
"debug.overrideProxyUrl": "https://localhost:8989"
}
}"""
)


if __name__ == "__main__":
generate_certificates()
1 change: 0 additions & 1 deletion src/codegate/ca/codegate_ca.py
Original file line number Diff line number Diff line change
Expand Up @@ -349,7 +349,6 @@ def generate_certificates(self) -> Tuple[str, str]:
algorithm=hashes.SHA256(),
)

# os.path.join(Config.get_config().server_key)
with open(
os.path.join(Config.get_config().certs_dir, Config.get_config().server_cert), "wb"
) as f:
Expand Down
39 changes: 24 additions & 15 deletions src/codegate/cli.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Command-line interface for codegate."""

import asyncio
import signal
import sys
from pathlib import Path
from typing import Dict, Optional
Expand All @@ -18,8 +19,6 @@
from codegate.server import init_app
from codegate.storage.utils import restore_storage_backup

logger = structlog.get_logger("codegate")


class UvicornServer:
def __init__(self, config: UvicornConfig, server: Server):
Expand All @@ -32,10 +31,16 @@ def __init__(self, config: UvicornConfig, server: Server):
self._startup_complete = asyncio.Event()
self._shutdown_event = asyncio.Event()
self._should_exit = False
self.logger = structlog.get_logger("codegate")

async def serve(self) -> None:
"""Start the uvicorn server and handle shutdown gracefully."""
logger.debug(f"Starting server on {self.host}:{self.port}")
self.logger.debug(f"Starting server on {self.host}:{self.port}")

# Set up signal handlers
loop = asyncio.get_running_loop()
loop.add_signal_handler(signal.SIGTERM, lambda: asyncio.create_task(self.cleanup()))
loop.add_signal_handler(signal.SIGINT, lambda: asyncio.create_task(self.cleanup()))

self.server = Server(config=self.config)
self.server.force_exit = True
Expand All @@ -44,40 +49,41 @@ async def serve(self) -> None:
self._startup_complete.set()
await self.server.serve()
except asyncio.CancelledError:
logger.info("Server received cancellation")
self.logger.info("Server received cancellation")
except Exception as e:
logger.exception("Unexpected error occurred during server execution", exc_info=e)
self.logger.exception("Unexpected error occurred during server execution", exc_info=e)
finally:
await self.cleanup()

async def wait_startup_complete(self) -> None:
"""Wait for the server to complete startup."""
logger.debug("Waiting for server startup to complete")
self.logger.debug("Waiting for server startup to complete")
await self._startup_complete.wait()

async def cleanup(self) -> None:
"""Cleanup server resources and ensure graceful shutdown."""
logger.debug("Cleaning up server resources")
self.logger.debug("Cleaning up server resources")
if not self._should_exit:
self._should_exit = True
logger.debug("Initiating server shutdown")
self.logger.debug("Initiating server shutdown")
self._shutdown_event.set()

if hasattr(self.server, "shutdown"):
logger.debug("Shutting down server")
self.logger.debug("Shutting down server")
await self.server.shutdown()

# Ensure all connections are closed
tasks = [t for t in asyncio.all_tasks() if t is not asyncio.current_task()]
[task.cancel() for task in tasks]

await asyncio.gather(*tasks, return_exceptions=True)
logger.debug("Server shutdown complete")
self.logger.debug("Server shutdown complete")


def validate_port(ctx: click.Context, param: click.Parameter, value: int) -> int:
logger.debug(f"Validating port number: {value}")
"""Validate the port number is in valid range."""
logger = structlog.get_logger("codegate")
logger.debug(f"Validating port number: {value}")
if value is not None and not (1 <= value <= 65535):
raise click.BadParameter("Port must be between 1 and 65535")
return value
Expand Down Expand Up @@ -286,10 +292,14 @@ def serve(
db_path=db_path,
)

# Set up logging first
setup_logging(cfg.log_level, cfg.log_format)
logger = structlog.get_logger("codegate")

init_db_sync(cfg.db_path)

# Check certificates and create CA if necessary
logger.info("Checking certificates and creating CA our created")
logger.info("Checking certificates and creating CA if needed")
ca = CertificateAuthority.get_instance()
ca.ensure_certificates_exist()

Expand All @@ -311,16 +321,15 @@ def serve(
click.echo(f"Configuration error: {e}", err=True)
sys.exit(1)
except Exception as e:
if logger:
logger.exception("Unexpected error occurred")
logger = structlog.get_logger("codegate")
logger.exception("Unexpected error occurred")
click.echo(f"Error: {e}", err=True)
sys.exit(1)


async def run_servers(cfg: Config, app) -> None:
"""Run the codegate server."""
try:
setup_logging(cfg.log_level, cfg.log_format)
logger = structlog.get_logger("codegate")
logger.info(
"Starting server",
Expand Down
6 changes: 4 additions & 2 deletions src/codegate/pipeline/secrets/signatures.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,8 +175,10 @@ def _load_signatures(cls) -> None:
yaml_data = cls._load_yaml(cls._yaml_path)

# Add custom GitHub token patterns
github_patterns = {"Access Token": r"ghp_[0-9a-zA-Z]{32}",
"Personal Token": r"github_pat_[a-zA-Z0-9]{22}_[a-zA-Z0-9]{59}"}
github_patterns = {
"Access Token": r"ghp_[0-9a-zA-Z]{32}",
"Personal Token": r"github_pat_[a-zA-Z0-9]{22}_[a-zA-Z0-9]{59}",
}
cls._add_signature_group("GitHub", github_patterns)

# Process patterns from YAML
Expand Down
Loading
Loading