Skip to content

Commit bb18530

Browse files
committed
Add --group arg to deploy. Add run-groups subcommand. Structure some formatting. Small bug fix for regarding links.
1 parent 18a69f1 commit bb18530

File tree

3 files changed

+83
-8
lines changed

3 files changed

+83
-8
lines changed

dreadnode_cli/agent/cli.py

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
format_agent,
1919
format_agent_versions,
2020
format_run,
21+
format_run_groups,
2122
format_runs,
2223
format_strike_models,
2324
format_strikes,
@@ -67,7 +68,9 @@ def ensure_profile(agent_config: AgentConfig, *, user_config: UserConfig | None
6768
):
6869
print()
6970
raise Exception(
70-
"Agent link does not match the current server profile. Use [bold]dreadnode agent switch[/] or [bold]dreadnode profile switch[/]."
71+
f"Current agent link ([yellow]{agent_config.active_link.profile}[/]) does not match "
72+
f"the current server profile ([magenta]{user_config.active_profile_name}[/]). "
73+
"Use [bold]dreadnode agent switch[/] or [bold]dreadnode profile switch[/]."
7174
)
7275

7376
switch_profile(agent_config.active_link.profile)
@@ -247,7 +250,14 @@ def push(
247250

248251
if agent_config.links and not agent_config.has_link_to_profile(user_config.active_profile_name):
249252
print(f":link: Linking as a fresh agent to the current profile [magenta]{user_config.active_profile_name}[/]")
253+
print()
250254
new = True
255+
elif agent_config.active and agent_config.active_link.profile != user_config.active_profile_name:
256+
raise Exception(
257+
f"Current agent link ([yellow]{agent_config.active_link.profile}[/]) does not match "
258+
f"the current server profile ([magenta]{user_config.active_profile_name}[/]). "
259+
"Use [bold]dreadnode agent switch[/] or [bold]dreadnode profile switch[/]."
260+
)
251261

252262
server_config = user_config.get_server_config()
253263

@@ -330,6 +340,7 @@ def deploy(
330340
] = pathlib.Path("."),
331341
strike: t.Annotated[str | None, typer.Option("--strike", "-s", help="The strike to use for this run")] = None,
332342
watch: t.Annotated[bool, typer.Option("--watch", "-w", help="Watch the run status")] = True,
343+
group: t.Annotated[str | None, typer.Option("--group", "-g", help="Group to associate this run with")] = None,
333344
) -> None:
334345
agent_config = AgentConfig.read(directory)
335346
ensure_profile(agent_config)
@@ -376,7 +387,9 @@ def deploy(
376387
f"Model '{model}' is not user-defined nor is it available in strike '{strike_response.name}'"
377388
)
378389

379-
run = client.start_strike_run(agent.latest_version.id, strike=strike, model=model, user_model=user_model)
390+
run = client.start_strike_run(
391+
agent.latest_version.id, strike=strike, model=model, user_model=user_model, group=group
392+
)
380393
agent_config.add_run(run.id).write(directory)
381394
formatted = format_run(run, server_url=server_config.url)
382395

@@ -569,6 +582,14 @@ def switch(
569582
print(f":exclamation: '{agent_or_profile}' not found, use [bold]dreadnode agent links[/]")
570583

571584

585+
@cli.command(help="List strike run groups")
586+
@pretty_cli
587+
def run_groups() -> None:
588+
client = api.create_client()
589+
groups = client.list_strike_run_groups()
590+
print(format_run_groups(groups))
591+
592+
572593
@cli.command(help="Clone a github repository", no_args_is_help=True)
573594
@pretty_cli
574595
def clone(

dreadnode_cli/agent/format.py

Lines changed: 30 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -267,9 +267,17 @@ def format_run(
267267
table.add_column("Property", style="dim")
268268
table.add_column("Value")
269269

270+
table.add_row("key", run.key)
270271
table.add_row("status", Text(run.status, style=get_status_style(run.status)))
271272
table.add_row("strike", f"[magenta]{run.strike_name}[/] ([dim]{run.strike_key}[/])")
272273
table.add_row("type", run.strike_type)
274+
table.add_row("group", Text(run.group_key or "-", style="blue" if run.group_key else ""))
275+
276+
if server_url != "":
277+
table.add_row("", "")
278+
table.add_row(
279+
"url", Text(f"{server_url.rstrip('/')}/strikes/agents/{run.agent_key}/runs/{run.id}", style="cyan")
280+
)
273281

274282
if run.agent_name:
275283
agent_name = f"[bold magenta]{run.agent_name}[/] [[dim]{run.agent_key}[/]]"
@@ -280,10 +288,6 @@ def format_run(
280288
table.add_row("model", run.model.replace(USER_MODEL_PREFIX, "") if run.model else "<default>")
281289
table.add_row("agent", f"{agent_name} ([dim]rev[/] [yellow]{run.agent_revision}[/])")
282290
table.add_row("image", Text(run.agent_version.container.image, style="cyan"))
283-
if server_url != "":
284-
table.add_row(
285-
"run url", Text(f"{server_url.rstrip('/')}/strikes/agents/{run.agent_key}/runs/{run.id}", style="cyan")
286-
)
287291
table.add_row("notes", run.agent_version.notes or "-")
288292

289293
table.add_row("", "")
@@ -301,21 +305,41 @@ def format_run(
301305

302306
def format_runs(runs: list[api.Client.StrikeRunSummaryResponse]) -> RenderableType:
303307
table = Table(box=box.ROUNDED)
304-
table.add_column("id", style="dim")
308+
table.add_column("key", style="dim")
305309
table.add_column("agent")
306310
table.add_column("status")
307311
table.add_column("model")
312+
table.add_column("group")
308313
table.add_column("started")
309314
table.add_column("duration")
310315

311316
for run in runs:
312317
table.add_row(
313-
str(run.id),
318+
run.key,
314319
f"[bold magenta]{run.agent_key}[/] [dim]:[/] [yellow]{run.agent_revision}[/]",
315320
Text(run.status, style="bold " + get_status_style(run.status)),
316321
Text(run.model.replace(USER_MODEL_PREFIX, "") if run.model else "-"),
322+
Text(run.group_key or "-", style="blue" if run.group_key else "dim"),
317323
format_time(run.start),
318324
Text(format_duration(run.start, run.end), style="bold cyan"),
319325
)
320326

321327
return table
328+
329+
330+
def format_run_groups(groups: list[api.Client.StrikeRunGroupResponse]) -> RenderableType:
331+
table = Table(box=box.ROUNDED)
332+
table.add_column("Name", style="bold cyan")
333+
table.add_column("description")
334+
table.add_column("runs", style="yellow")
335+
table.add_column("created", style="dim")
336+
337+
for group in groups:
338+
table.add_row(
339+
group.key,
340+
group.description or "-",
341+
str(group.run_count),
342+
group.created_at.astimezone().strftime("%c"),
343+
)
344+
345+
return table

dreadnode_cli/api.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -293,6 +293,16 @@ class Container(BaseModel):
293293
env: dict[str, str]
294294
name: str | None
295295

296+
class StrikeMetricPoint(BaseModel):
297+
timestamp: datetime
298+
value: float
299+
metadata: dict[str, t.Any]
300+
301+
class StrikeMetric(BaseModel):
302+
type: str
303+
description: str | None
304+
points: "list[Client.StrikeMetricPoint]"
305+
296306
class StrikeAgentVersion(BaseModel):
297307
id: UUID
298308
created_at: datetime
@@ -350,9 +360,11 @@ class StrikeRunZone(_StrikeRunZone):
350360
container_logs: dict[str, str]
351361
outputs: list["Client.StrikeRunOutput"]
352362
inferences: list[dict[str, t.Any]]
363+
metrics: dict[str, "Client.StrikeMetric"]
353364

354365
class _StrikeRun(BaseModel):
355366
id: UUID
367+
key: str
356368
strike_id: UUID
357369
strike_key: str
358370
strike_name: str
@@ -367,6 +379,9 @@ class _StrikeRun(BaseModel):
367379
status: "Client.StrikeRunStatus"
368380
start: datetime | None
369381
end: datetime | None
382+
group_id: UUID | None
383+
group_key: str | None
384+
group_name: str | None
370385

371386
def is_running(self) -> bool:
372387
return self.status in ["pending", "deploying", "running"]
@@ -382,6 +397,15 @@ class UserModel(BaseModel):
382397
generator_id: str
383398
api_key: str
384399

400+
class StrikeRunGroupResponse(BaseModel):
401+
id: UUID
402+
key: str
403+
name: str
404+
description: str | None
405+
created_at: datetime
406+
updated_at: datetime
407+
run_count: int
408+
385409
def get_strike(self, strike: str) -> StrikeResponse:
386410
response = self.request("GET", f"/api/strikes/{strike}")
387411
return self.StrikeResponse(**response.json())
@@ -441,6 +465,7 @@ def start_strike_run(
441465
model: str | None = None,
442466
user_model: UserModel | None = None,
443467
strike: UUID | str | None = None,
468+
group: UUID | str | None = None,
444469
) -> StrikeRunResponse:
445470
response = self.request(
446471
"POST",
@@ -450,6 +475,7 @@ def start_strike_run(
450475
"model": model,
451476
"user_model": user_model.model_dump(mode="json") if user_model else None,
452477
"strike": str(strike) if strike else None,
478+
"group": str(group) if group else None,
453479
},
454480
)
455481
return self.StrikeRunResponse(**response.json())
@@ -464,6 +490,10 @@ def list_strike_runs(self, *, strike_id: UUID | str | None = None) -> list[Strik
464490
)
465491
return [self.StrikeRunSummaryResponse(**run) for run in response.json()]
466492

493+
def list_strike_run_groups(self) -> list[StrikeRunGroupResponse]:
494+
response = self.request("GET", "/api/strikes/groups")
495+
return [self.StrikeRunGroupResponse(**group) for group in response.json()]
496+
467497

468498
def create_client(*, profile: str | None = None) -> Client:
469499
"""Create an authenticated API client using stored configuration data."""

0 commit comments

Comments
 (0)