Skip to content

Commit dee6828

Browse files
authored
Merge pull request #32 from dreadnode/feat/expand-user-defined-models
feat: Expand user-defined models
2 parents c4f4c64 + 987baea commit dee6828

File tree

12 files changed

+235
-81
lines changed

12 files changed

+235
-81
lines changed

dreadnode_cli/agent/cli.py

Lines changed: 43 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import os
12
import pathlib
23
import shutil
34
import time
@@ -20,19 +21,21 @@
2021
format_runs,
2122
format_strike_models,
2223
format_strikes,
23-
format_user_models,
2424
)
2525
from dreadnode_cli.agent.templates import cli as templates_cli
2626
from dreadnode_cli.agent.templates.format import format_templates
2727
from dreadnode_cli.agent.templates.manager import TemplateManager
28-
from dreadnode_cli.config import UserConfig, UserModel, UserModels
28+
from dreadnode_cli.api import Client
29+
from dreadnode_cli.config import UserConfig
30+
from dreadnode_cli.model.config import UserModels
31+
from dreadnode_cli.model.format import format_user_models
2932
from dreadnode_cli.profile.cli import switch as switch_profile
3033
from dreadnode_cli.types import GithubRepo
3134
from dreadnode_cli.utils import download_and_unzip_archive, get_repo_archive_source_path, pretty_cli
3235

3336
cli = typer.Typer(no_args_is_help=True)
3437

35-
cli.add_typer(templates_cli, name="templates", help="Interact with Strike templates")
38+
cli.add_typer(templates_cli, name="templates", help="Manage Agent templates")
3639

3740

3841
def ensure_profile(agent_config: AgentConfig, *, user_config: UserConfig | None = None) -> None:
@@ -48,8 +51,8 @@ def ensure_profile(agent_config: AgentConfig, *, user_config: UserConfig | None
4851
plural = "s" if len(agent_config.linked_profiles) > 1 else ""
4952
raise Exception(
5053
f"This agent is linked to the [magenta]{linked_profiles}[/] server profile{plural}, "
51-
f"but the current server profile is [yellow]{user_config.active_profile_name}[/], ",
52-
"use [bold]dreadnode agent push[/] to create a new link with this profile.",
54+
f"but the current server profile is [yellow]{user_config.active_profile_name}[/], "
55+
"use [bold]dreadnode agent push[/] to create a new link with this profile."
5356
)
5457

5558
if agent_config.active_link.profile != user_config.active_profile_name:
@@ -70,7 +73,7 @@ def ensure_profile(agent_config: AgentConfig, *, user_config: UserConfig | None
7073
switch_profile(agent_config.active_link.profile)
7174

7275

73-
@cli.command(help="Initialize a new agent project")
76+
@cli.command(help="Initialize a new agent project", no_args_is_help=True)
7477
@pretty_cli
7578
def init(
7679
strike: t.Annotated[str, typer.Argument(help="The target strike")],
@@ -341,19 +344,34 @@ def deploy(
341344
raise Exception("No strike specified, use -s/--strike or set the strike in the agent config")
342345

343346
user_models = UserModels.read()
344-
user_model: UserModel | None = None
345-
346-
# Verify the model if it was supplied
347-
if model is not None:
348-
# check if it's a user model
349-
user_model = next((m for m in user_models.models if m.key == model), None)
350-
if not user_model:
351-
# check if it's a strike model
352-
strike_response = client.get_strike(strike)
353-
if not any(m.key == model for m in strike_response.models):
354-
models(directory, strike=strike)
355-
print()
356-
raise Exception(f"Model '{model}' is not a user model nor was found in strike '{strike_response.name}'")
347+
user_model: Client.UserModel | None = None
348+
349+
# Check for a user-defined model
350+
if model in user_models.models:
351+
user_model = Client.UserModel(
352+
key=model,
353+
generator_id=user_models.models[model].generator_id,
354+
api_key=user_models.models[model].api_key,
355+
)
356+
357+
# Resolve the API key from env vars
358+
if user_model.api_key.startswith("$"):
359+
try:
360+
user_model.api_key = os.environ[user_model.api_key[1:]]
361+
except KeyError as e:
362+
raise Exception(
363+
f"API key cannot be read from '{user_model.api_key}', environment variable not found."
364+
) from e
365+
366+
# Otherwise we'll ensure this is a valid strike-native model
367+
if user_model is None and model is not None:
368+
strike_response = client.get_strike(strike)
369+
if not any(m.key == model for m in strike_response.models):
370+
models(directory, strike=strike)
371+
print()
372+
raise Exception(
373+
f"Model '{model}' is not user-defined nor is it available in strike '{strike_response.name}'"
374+
)
357375

358376
run = client.start_strike_run(agent.latest_version.id, strike=strike, model=model, user_model=user_model)
359377
agent_config.add_run(run.id).write(directory)
@@ -380,20 +398,21 @@ def models(
380398
) -> None:
381399
user_models = UserModels.read()
382400
if user_models.models:
383-
print("[bold]User models:[/]\n")
401+
print("[bold]User-defined models:[/]\n")
384402
print(format_user_models(user_models.models))
403+
print()
385404

386405
if strike is None:
387406
agent_config = AgentConfig.read(directory)
388407
ensure_profile(agent_config)
408+
strike = agent_config.strike
389409

390-
strike = strike or agent_config.strike
391410
if strike is None:
392411
raise Exception("No strike specified, use -s/--strike or set the strike in the agent config")
393412

394413
strike_response = api.create_client().get_strike(strike)
395414
if user_models.models:
396-
print("\n[bold]Strike models:[/]\n")
415+
print("\n[bold]Dreadnode-provided models:[/]\n")
397416
print(format_strike_models(strike_response.models))
398417

399418

@@ -522,7 +541,7 @@ def links(
522541
print(table)
523542

524543

525-
@cli.command(help="Switch to a different agent link")
544+
@cli.command(help="Switch to a different agent link", no_args_is_help=True)
526545
@pretty_cli
527546
def switch(
528547
agent_or_profile: t.Annotated[str, typer.Argument(help="Agent key/id or profile name")],
@@ -544,7 +563,7 @@ def switch(
544563
print(f":exclamation: '{agent_or_profile}' not found, use [bold]dreadnode agent links[/]")
545564

546565

547-
@cli.command(help="Clone a github repository")
566+
@cli.command(help="Clone a github repository", no_args_is_help=True)
548567
@pretty_cli
549568
def clone(
550569
repo: t.Annotated[str, typer.Argument(help="Repository name or URL")],

dreadnode_cli/agent/format.py

Lines changed: 0 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99
from rich.text import Text
1010

1111
from dreadnode_cli import api
12-
from dreadnode_cli.config import UserModel
1312

1413
P = t.ParamSpec("P")
1514

@@ -66,25 +65,6 @@ def format_time(dt: datetime | None) -> str:
6665
return dt.astimezone().strftime("%c") if dt else "-"
6766

6867

69-
def format_user_models(models: list[UserModel]) -> RenderableType:
70-
table = Table(box=box.ROUNDED)
71-
table.add_column("key")
72-
table.add_column("name")
73-
table.add_column("provider")
74-
table.add_column("api_key")
75-
76-
for model in models:
77-
provider_style = get_model_provider_style(model.provider)
78-
table.add_row(
79-
Text(model.key),
80-
Text(model.name, style=f"bold {provider_style}"),
81-
Text(model.provider, style=provider_style),
82-
Text("yes" if model.api_key else "no", style="green" if model.api_key else "dim"),
83-
)
84-
85-
return table
86-
87-
8868
def format_strike_models(models: list[api.Client.StrikeModel]) -> RenderableType:
8969
table = Table(box=box.ROUNDED)
9070
table.add_column("key")

dreadnode_cli/agent/templates/cli.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,13 +9,14 @@
99
from dreadnode_cli.agent.templates.format import format_templates
1010
from dreadnode_cli.agent.templates.manager import TemplateManager
1111
from dreadnode_cli.defaults import TEMPLATES_DEFAULT_REPO
12+
from dreadnode_cli.ext.typer import AliasGroup
1213
from dreadnode_cli.types import GithubRepo
1314
from dreadnode_cli.utils import download_and_unzip_archive, get_repo_archive_source_path, pretty_cli
1415

15-
cli = typer.Typer(no_args_is_help=True)
16+
cli = typer.Typer(no_args_is_help=True, cls=AliasGroup)
1617

1718

18-
@cli.command(help="List available agent templates with their descriptions")
19+
@cli.command("show|list", help="List available agent templates with their descriptions")
1920
@pretty_cli
2021
def show() -> None:
2122
template_manager = TemplateManager()

dreadnode_cli/api.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from rich import print
1212

1313
from dreadnode_cli import __version__, utils
14-
from dreadnode_cli.config import UserConfig, UserModel
14+
from dreadnode_cli.config import UserConfig
1515
from dreadnode_cli.defaults import (
1616
DEBUG,
1717
DEFAULT_MAX_POLL_TIME,
@@ -377,6 +377,11 @@ class StrikeRunSummaryResponse(_StrikeRun):
377377
class StrikeRunResponse(_StrikeRun):
378378
zones: list["Client.StrikeRunZone"]
379379

380+
class UserModel(BaseModel):
381+
key: str
382+
generator_id: str
383+
api_key: str
384+
380385
def get_strike(self, strike: str) -> StrikeResponse:
381386
response = self.request("GET", f"/api/strikes/{strike}")
382387
return self.StrikeResponse(**response.json())

dreadnode_cli/cli.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from dreadnode_cli.challenge import cli as challenge_cli
1010
from dreadnode_cli.config import ServerConfig, UserConfig
1111
from dreadnode_cli.defaults import PLATFORM_BASE_URL
12+
from dreadnode_cli.model import cli as models_cli
1213
from dreadnode_cli.profile import cli as profile_cli
1314
from dreadnode_cli.utils import pretty_cli
1415

@@ -21,6 +22,7 @@
2122
cli.add_typer(profile_cli, name="profile", help="Manage server profiles")
2223
cli.add_typer(challenge_cli, name="challenge", help="Interact with Crucible challenges")
2324
cli.add_typer(agent_cli, name="agent", help="Interact with Strike agents")
25+
cli.add_typer(models_cli, name="model", help="Manage user-defined inference models")
2426

2527

2628
@cli.command(help="Authenticate to the platform.")

dreadnode_cli/config.py

Lines changed: 1 addition & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
from rich import print
33
from ruamel.yaml import YAML
44

5-
from dreadnode_cli.defaults import DEFAULT_PROFILE_NAME, USER_CONFIG_PATH, USER_MODELS_CONFIG_PATH
5+
from dreadnode_cli.defaults import DEFAULT_PROFILE_NAME, USER_CONFIG_PATH
66

77

88
class ServerConfig(BaseModel):
@@ -74,31 +74,3 @@ def set_server_config(self, config: ServerConfig, profile: str | None = None) ->
7474
profile = profile or self.active or DEFAULT_PROFILE_NAME
7575
self.servers[profile] = config
7676
return self
77-
78-
79-
class UserModel(BaseModel):
80-
"""
81-
A user defined model.
82-
"""
83-
84-
key: str
85-
name: str
86-
provider: str
87-
generator_id: str
88-
api_key: str | None = None
89-
90-
91-
class UserModels(BaseModel):
92-
"""User models configuration."""
93-
94-
models: list[UserModel] = []
95-
96-
@classmethod
97-
def read(cls) -> "UserModels":
98-
"""Read the user models configuration from the file system or return an empty instance."""
99-
100-
if not USER_MODELS_CONFIG_PATH.exists():
101-
return cls()
102-
103-
with USER_MODELS_CONFIG_PATH.open("r") as f:
104-
return cls.model_validate(YAML().load(f))

dreadnode_cli/ext/typer.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
import re
2+
3+
from click import Command, Context
4+
from typer.core import TyperGroup
5+
6+
# https://github.com/fastapi/typer/issues/132
7+
8+
9+
class AliasGroup(TyperGroup):
10+
_CMD_SPLIT_P = re.compile(r" ?[,|] ?")
11+
12+
def get_command(self, ctx: Context, cmd_name: str) -> Command | None:
13+
cmd_name = self._group_cmd_name(cmd_name)
14+
return super().get_command(ctx, cmd_name)
15+
16+
def _group_cmd_name(self, default_name: str) -> str:
17+
for cmd in self.commands.values():
18+
name = cmd.name
19+
if name and default_name in self._CMD_SPLIT_P.split(name):
20+
return name
21+
return default_name

dreadnode_cli/model/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
from dreadnode_cli.model.cli import cli
2+
3+
__all__ = ["cli"]

dreadnode_cli/model/cli.py

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
import typing as t
2+
3+
import typer
4+
from rich import print
5+
6+
from dreadnode_cli.defaults import USER_MODELS_CONFIG_PATH
7+
from dreadnode_cli.ext.typer import AliasGroup
8+
from dreadnode_cli.model.config import UserModel, UserModels
9+
from dreadnode_cli.model.format import format_user_models
10+
from dreadnode_cli.utils import pretty_cli
11+
12+
cli = typer.Typer(no_args_is_help=True, cls=AliasGroup)
13+
14+
15+
@cli.command("show|list", help="List all configured models")
16+
@pretty_cli
17+
def show() -> None:
18+
config = UserModels.read()
19+
if not config.models:
20+
print(":exclamation: No models are configured, use [bold]dreadnode models add[/].")
21+
return
22+
23+
print(format_user_models(config.models))
24+
25+
26+
@cli.command(
27+
help="Add a new inference model",
28+
epilog="If $ENV_VAR syntax is used for the api key, it will be replaced with the environment value when used.",
29+
no_args_is_help=True,
30+
)
31+
@pretty_cli
32+
def add(
33+
id: t.Annotated[str, typer.Option("--id", help="Identifier for referencing this model")],
34+
generator_id: t.Annotated[str, typer.Option("--generator-id", "-g", help="Rigging (LiteLLM) generator id")],
35+
api_key: t.Annotated[
36+
str, typer.Option("--api-key", "-k", help="API key for the inference provider (supports $ENV_VAR syntax)")
37+
],
38+
name: t.Annotated[str | None, typer.Option("--name", "-n", help="Friendly name")] = None,
39+
provider: t.Annotated[str | None, typer.Option("--provider", "-p", help="Provider name")] = None,
40+
update: t.Annotated[bool, typer.Option("--update", "-u", help="Update an existing model if it exists")] = False,
41+
) -> None:
42+
config = UserModels.read()
43+
exists = id in config.models
44+
45+
if exists and not update:
46+
print(f":exclamation: Model with id [bold]{id}[/] already exists (use -u/--update to modify)")
47+
return
48+
49+
config.models[id] = UserModel(name=name, provider=provider, generator_id=generator_id, api_key=api_key)
50+
config.write()
51+
52+
print(f":wrench: {'Updated' if exists else 'Added'} model [bold]{id}[/] in {USER_MODELS_CONFIG_PATH}")
53+
54+
55+
@cli.command(help="Remove an user inference model", no_args_is_help=True)
56+
@pretty_cli
57+
def forget(id: t.Annotated[str, typer.Argument(help="Model to remove")]) -> None:
58+
config = UserModels.read()
59+
if id not in config.models:
60+
print(f":exclamation: Model with id [bold]{id}[/] does not exist")
61+
return
62+
63+
del config.models[id]
64+
config.write()
65+
66+
print(f":axe: Forgot about [bold]{id}[/] in {USER_MODELS_CONFIG_PATH}")

0 commit comments

Comments
 (0)