Skip to content

Commit c4f4c64

Browse files
authored
Merge pull request #31 from dreadnode/simone/eng-652-support-user-provided-api-keys-and-model-configurations-for
new: implemented support for user models (ENG-652)
2 parents 210f8e1 + 4335a76 commit c4f4c64

File tree

5 files changed

+97
-18
lines changed

5 files changed

+97
-18
lines changed

dreadnode_cli/agent/cli.py

Lines changed: 25 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -16,15 +16,16 @@
1616
from dreadnode_cli.agent.format import (
1717
format_agent,
1818
format_agent_versions,
19-
format_models,
2019
format_run,
2120
format_runs,
21+
format_strike_models,
2222
format_strikes,
23+
format_user_models,
2324
)
2425
from dreadnode_cli.agent.templates import cli as templates_cli
2526
from dreadnode_cli.agent.templates.format import format_templates
2627
from dreadnode_cli.agent.templates.manager import TemplateManager
27-
from dreadnode_cli.config import UserConfig
28+
from dreadnode_cli.config import UserConfig, UserModel, UserModels
2829
from dreadnode_cli.profile.cli import switch as switch_profile
2930
from dreadnode_cli.types import GithubRepo
3031
from dreadnode_cli.utils import download_and_unzip_archive, get_repo_archive_source_path, pretty_cli
@@ -339,14 +340,22 @@ def deploy(
339340
if strike is None:
340341
raise Exception("No strike specified, use -s/--strike or set the strike in the agent config")
341342

343+
user_models = UserModels.read()
344+
user_model: UserModel | None = None
345+
342346
# Verify the model if it was supplied
343347
if model is not None:
344-
strike_response = client.get_strike(strike)
345-
if not any(m.key == model for m in strike_response.models):
346-
print(format_models(strike_response.models))
347-
raise Exception(f"Model '{model}' not found in strike '{strike_response.name}'")
348-
349-
run = client.start_strike_run(agent.latest_version.id, strike=strike, model=model)
348+
# check if it's a user model
349+
user_model = next((m for m in user_models.models if m.key == model), None)
350+
if not user_model:
351+
# check if it's a strike model
352+
strike_response = client.get_strike(strike)
353+
if not any(m.key == model for m in strike_response.models):
354+
models(directory, strike=strike)
355+
print()
356+
raise Exception(f"Model '{model}' is not a user model nor was found in strike '{strike_response.name}'")
357+
358+
run = client.start_strike_run(agent.latest_version.id, strike=strike, model=model, user_model=user_model)
350359
agent_config.add_run(run.id).write(directory)
351360
formatted = format_run(run)
352361

@@ -369,6 +378,11 @@ def models(
369378
] = pathlib.Path("."),
370379
strike: t.Annotated[str | None, typer.Option("--strike", "-s", help="The strike to query")] = None,
371380
) -> None:
381+
user_models = UserModels.read()
382+
if user_models.models:
383+
print("[bold]User models:[/]\n")
384+
print(format_user_models(user_models.models))
385+
372386
if strike is None:
373387
agent_config = AgentConfig.read(directory)
374388
ensure_profile(agent_config)
@@ -378,7 +392,9 @@ def models(
378392
raise Exception("No strike specified, use -s/--strike or set the strike in the agent config")
379393

380394
strike_response = api.create_client().get_strike(strike)
381-
print(format_models(strike_response.models))
395+
if user_models.models:
396+
print("\n[bold]Strike models:[/]\n")
397+
print(format_strike_models(strike_response.models))
382398

383399

384400
@cli.command(help="List available strikes")

dreadnode_cli/agent/format.py

Lines changed: 26 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,13 @@
99
from rich.text import Text
1010

1111
from dreadnode_cli import api
12+
from dreadnode_cli.config import UserModel
1213

1314
P = t.ParamSpec("P")
1415

16+
# um@ is added to indicate a user model
17+
USER_MODEL_PREFIX: str = "um@"
18+
1519

1620
def get_status_style(status: api.Client.StrikeRunStatus | api.Client.StrikeRunZoneStatus | None) -> str:
1721
return (
@@ -62,7 +66,26 @@ def format_time(dt: datetime | None) -> str:
6266
return dt.astimezone().strftime("%c") if dt else "-"
6367

6468

65-
def format_models(models: list[api.Client.StrikeModel]) -> RenderableType:
69+
def format_user_models(models: list[UserModel]) -> RenderableType:
70+
table = Table(box=box.ROUNDED)
71+
table.add_column("key")
72+
table.add_column("name")
73+
table.add_column("provider")
74+
table.add_column("api_key")
75+
76+
for model in models:
77+
provider_style = get_model_provider_style(model.provider)
78+
table.add_row(
79+
Text(model.key),
80+
Text(model.name, style=f"bold {provider_style}"),
81+
Text(model.provider, style=provider_style),
82+
Text("yes" if model.api_key else "no", style="green" if model.api_key else "dim"),
83+
)
84+
85+
return table
86+
87+
88+
def format_strike_models(models: list[api.Client.StrikeModel]) -> RenderableType:
6689
table = Table(box=box.ROUNDED)
6790
table.add_column("key")
6891
table.add_column("name")
@@ -272,7 +295,7 @@ def format_run(run: api.Client.StrikeRunResponse, *, verbose: bool = False, incl
272295
agent_name = f"[bold magenta]{run.agent_key}[/]"
273296

274297
table.add_row("", "")
275-
table.add_row("model", run.model or "<default>")
298+
table.add_row("model", run.model.replace(USER_MODEL_PREFIX, "") if run.model else "<default>")
276299
table.add_row("agent", f"{agent_name} ([dim]rev[/] [yellow]{run.agent_revision}[/])")
277300
table.add_row("image", Text(run.agent_version.container.image, style="cyan"))
278301
table.add_row("notes", run.agent_version.notes or "-")
@@ -304,7 +327,7 @@ def format_runs(runs: list[api.Client.StrikeRunSummaryResponse]) -> RenderableTy
304327
str(run.id),
305328
f"[bold magenta]{run.agent_key}[/] [dim]:[/] [yellow]{run.agent_revision}[/]",
306329
Text(run.status, style="bold " + get_status_style(run.status)),
307-
Text(run.model or "-"),
330+
Text(run.model.replace(USER_MODEL_PREFIX, "") if run.model else "-"),
308331
format_time(run.start),
309332
Text(format_duration(run.start, run.end), style="bold cyan"),
310333
)

dreadnode_cli/api.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from rich import print
1212

1313
from dreadnode_cli import __version__, utils
14-
from dreadnode_cli.config import UserConfig
14+
from dreadnode_cli.config import UserConfig, UserModel
1515
from dreadnode_cli.defaults import (
1616
DEBUG,
1717
DEFAULT_MAX_POLL_TIME,
@@ -430,14 +430,20 @@ def create_strike_agent_version(
430430
return self.StrikeAgentResponse(**response.json())
431431

432432
def start_strike_run(
433-
self, agent_version_id: UUID, *, model: str | None = None, strike: UUID | str | None = None
433+
self,
434+
agent_version_id: UUID,
435+
*,
436+
model: str | None = None,
437+
user_model: UserModel | None = None,
438+
strike: UUID | str | None = None,
434439
) -> StrikeRunResponse:
435440
response = self.request(
436441
"POST",
437442
"/api/strikes/runs",
438443
json_data={
439444
"agent_version_id": str(agent_version_id),
440445
"model": model,
446+
"user_model": user_model.model_dump(mode="json") if user_model else None,
441447
"strike": str(strike) if strike else None,
442448
},
443449
)

dreadnode_cli/config.py

Lines changed: 32 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
1-
import pydantic
1+
from pydantic import BaseModel
22
from rich import print
33
from ruamel.yaml import YAML
44

5-
from dreadnode_cli.defaults import DEFAULT_PROFILE_NAME, USER_CONFIG_PATH
5+
from dreadnode_cli.defaults import DEFAULT_PROFILE_NAME, USER_CONFIG_PATH, USER_MODELS_CONFIG_PATH
66

77

8-
class ServerConfig(pydantic.BaseModel):
8+
class ServerConfig(BaseModel):
99
"""Server specific authentication data and API URL."""
1010

1111
url: str
@@ -16,7 +16,7 @@ class ServerConfig(pydantic.BaseModel):
1616
refresh_token: str
1717

1818

19-
class UserConfig(pydantic.BaseModel):
19+
class UserConfig(BaseModel):
2020
"""User configuration supporting multiple server profiles."""
2121

2222
active: str | None = None
@@ -74,3 +74,31 @@ def set_server_config(self, config: ServerConfig, profile: str | None = None) ->
7474
profile = profile or self.active or DEFAULT_PROFILE_NAME
7575
self.servers[profile] = config
7676
return self
77+
78+
79+
class UserModel(BaseModel):
80+
"""
81+
A user defined model.
82+
"""
83+
84+
key: str
85+
name: str
86+
provider: str
87+
generator_id: str
88+
api_key: str | None = None
89+
90+
91+
class UserModels(BaseModel):
92+
"""User models configuration."""
93+
94+
models: list[UserModel] = []
95+
96+
@classmethod
97+
def read(cls) -> "UserModels":
98+
"""Read the user models configuration from the file system or return an empty instance."""
99+
100+
if not USER_MODELS_CONFIG_PATH.exists():
101+
return cls()
102+
103+
with USER_MODELS_CONFIG_PATH.open("r") as f:
104+
return cls.model_validate(YAML().load(f))

dreadnode_cli/defaults.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,12 @@
2525
os.getenv("DREADNODE_USER_CONFIG_FILE") or pathlib.Path.home() / ".dreadnode" / "config"
2626
)
2727

28+
# path to the user models configuration file
29+
USER_MODELS_CONFIG_PATH = pathlib.Path(
30+
# allow overriding the user config file via env variable
31+
os.getenv("DREADNODE_USER_CONFIG_FILE") or pathlib.Path.home() / ".dreadnode" / "models.yml"
32+
)
33+
2834
# path to the templates directory
2935
TEMPLATES_PATH = pathlib.Path(
3036
# allow overriding the templates path via env variable

0 commit comments

Comments
 (0)