Skip to content
This repository was archived by the owner on Jun 5, 2025. It is now read-only.

Add logic to check if certs exist for generate_certs #334

Merged
merged 3 commits into from
Dec 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 12 additions & 43 deletions src/codegate/ca/codegate_ca.py
Original file line number Diff line number Diff line change
Expand Up @@ -368,35 +368,6 @@ def generate_certificates(self) -> Tuple[str, str]:
)

# Print instructions for trusting the certificates
logger.info(
"""
Certificates generated successfully in the 'certs' directory
To trust these certificates:

On macOS:
`sudo security add-trusted-cert -d -r trustRoot -k /Library/Keychains/System.keychain certs/ca.crt`

On Windows (PowerShell as Admin):
`Import-Certificate -FilePath "certs\\ca.crt" -CertStoreLocation Cert:\\LocalMachine\\Root`

On Linux:
`sudo cp certs/ca.crt /usr/local/share/ca-certificates/codegate.crt`
`sudo update-ca-certificates`

For VSCode, add to settings.json:
{
"http.proxy": "https://localhost:8990",
"http.proxyStrictSSL": true,
"http.proxySupport": "on",
"github.copilot.advanced": {
"debug.useNodeFetcher": true,
"debug.useElectronFetcher": true,
"debug.testOverrideProxyUrl": "https://localhost:8990",
"debug.overrideProxyUrl": "https://localhost:8990"
},
}
"""
)
logger.debug("Certificates generated successfully")
return server_cert, server_key

Expand All @@ -422,23 +393,21 @@ def create_ssl_context(self) -> ssl.SSLContext:
logger.debug("SSL context created successfully")
return ssl_context

def ensure_certificates_exist(self) -> None:
def check_certificates_exist(self) -> bool:
"""Check if SSL certificates exist"""
logger.debug("Checking if certificates exist fn: check_certificates_exist")
return os.path.exists(
os.path.join(Config.get_config().certs_dir, Config.get_config().server_cert)
) and os.path.exists(
os.path.join(Config.get_config().certs_dir, Config.get_config().server_key)
)

def ensure_certificates_exist(self) -> bool:
"""Ensure SSL certificates exist, generate if they don't"""
logger.debug("Ensuring certificates exist. fn ensure_certificates_exist")
if not (
os.path.exists(
os.path.join(Config.get_config().certs_dir, Config.get_config().server_cert)
)
and os.path.exists(
os.path.join(Config.get_config().certs_dir, Config.get_config().server_key)
)
):
logger.debug("Certificates not found, generating new certificates")
if not self.check_certificates_exist():
logger.info("Certificates not found. Generating new certificates.")
self.generate_certificates()
else:
server_cert = Config.get_config().server_cert
server_key = Config.get_config().server_key
logger.debug(f"Certificates found at: {server_cert} and {server_key}.")

def get_ssl_context(self) -> ssl.SSLContext:
"""Get SSL context with certificates"""
Expand Down
19 changes: 18 additions & 1 deletion src/codegate/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -448,6 +448,12 @@ def restore_backup(backup_path: Path, backup_name: str) -> None:
default=None,
help="Name that will be given to the created server-key.",
)
@click.option(
"--force-certs",
is_flag=True,
default=False,
help="Force the generation of certificates even if they already exist.",
)
@click.option(
"--log-level",
type=click.Choice([level.value for level in LogLevel]),
Expand All @@ -466,6 +472,7 @@ def generate_certs(
ca_key_name: Optional[str],
server_cert_name: Optional[str],
server_key_name: Optional[str],
force_certs: bool,
log_level: Optional[str],
log_format: Optional[str],
) -> None:
Expand All @@ -476,12 +483,22 @@ def generate_certs(
ca_key=ca_key_name,
server_cert=server_cert_name,
server_key=server_key_name,
force_certs=force_certs,
cli_log_level=log_level,
cli_log_format=log_format,
)
setup_logging(cfg.log_level, cfg.log_format)

ca = CertificateAuthority.get_instance()
ca.generate_certificates()
should_generate = force_certs or not ca.check_certificates_exist()

if should_generate:
ca.generate_certificates()
click.echo("Certificates generated successfully.")
click.echo(f"Certificates saved to: {cfg.certs_dir}")
click.echo("Make sure to add the new CA certificate to the operating system trust store.")
else:
click.echo("Certificates already exist. Skipping generation...")


def main() -> None:
Expand Down
20 changes: 10 additions & 10 deletions src/codegate/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ class Config:
ca_key: str = "ca.key"
server_cert: str = "server.crt"
server_key: str = "server.key"
force_certs: bool = False

# Provider URLs with defaults
provider_urls: Dict[str, str] = field(default_factory=lambda: DEFAULT_PROVIDER_URLS.copy())
Expand Down Expand Up @@ -142,6 +143,7 @@ def from_file(cls, config_path: Union[str, Path]) -> "Config":
ca_key=config_data.get("ca_key", cls.ca_key),
server_cert=config_data.get("server_cert", cls.server_cert),
server_key=config_data.get("server_key", cls.server_key),
force_certs=config_data.get("force_certs", cls.force_certs),
prompts=prompts_config,
provider_urls=provider_urls,
)
Expand Down Expand Up @@ -187,6 +189,8 @@ def from_env(cls) -> "Config":
config.server_cert = os.environ["CODEGATE_SERVER_CERT"]
if "CODEGATE_SERVER_KEY" in os.environ:
config.server_key = os.environ["CODEGATE_SERVER_KEY"]
if "CODEGATE_FORCE_CERTS" in os.environ:
config.force_certs = os.environ["CODEGATE_FORCE_CERTS"]

# Load provider URLs from environment variables
for provider in DEFAULT_PROVIDER_URLS.keys():
Expand Down Expand Up @@ -216,6 +220,7 @@ def load(
ca_key: Optional[str] = None,
server_cert: Optional[str] = None,
server_key: Optional[str] = None,
force_certs: Optional[bool] = None,
db_path: Optional[str] = None,
) -> "Config":
"""Load configuration with priority resolution.
Expand All @@ -242,6 +247,7 @@ def load(
ca_key: Optional path to CA key
server_cert: Optional path to server certificate
server_key: Optional path to server key
force_certs: Optional flag to force certificate generation
db_path: Optional path to the SQLite database file

Returns:
Expand Down Expand Up @@ -289,6 +295,8 @@ def load(
config.server_cert = env_config.server_cert
if "CODEGATE_SERVER_KEY" in os.environ:
config.server_key = env_config.server_key
if "CODEGATE_FORCE_CERTS" in os.environ:
config.force_certs = env_config.force_certs

# Override provider URLs from environment
for provider, url in env_config.provider_urls.items():
Expand Down Expand Up @@ -325,16 +333,8 @@ def load(
config.server_key = server_key
if db_path is not None:
config.db_path = db_path
if certs_dir is not None:
config.certs_dir = certs_dir
if ca_cert is not None:
config.ca_cert = ca_cert
if ca_key is not None:
config.ca_key = ca_key
if server_cert is not None:
config.server_cert = server_cert
if server_key is not None:
config.server_key = server_key
if force_certs is not None:
config.force_certs = force_certs

# Set the __config class attribute
Config.__config = config
Expand Down
2 changes: 1 addition & 1 deletion src/codegate/pipeline/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,7 @@ def get_last_user_message_idx(request: ChatCompletionRequest) -> int:
if request.get("messages") is None:
return -1

for idx, message in reversed(list(enumerate(request['messages']))):
for idx, message in reversed(list(enumerate(request["messages"]))):
if message.get("role", "") == "user":
return idx

Expand Down
5 changes: 1 addition & 4 deletions src/codegate/pipeline/codegate_context_retriever/codegate.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,9 +110,7 @@ async def process(
return PipelineResult(request=request)

# Look for matches in vector DB using list of packages as filter
searched_objects = await self.get_objects_from_search(
user_messages, ecosystem, packages
)
searched_objects = await self.get_objects_from_search(user_messages, ecosystem, packages)

logger.info(
f"Found {len(searched_objects)} matches in the database",
Expand Down Expand Up @@ -149,4 +147,3 @@ async def process(
message["content"] = context_msg

return PipelineResult(request=new_request, context=context)

Loading
Loading