Skip to content

Commit 5b58112

Browse files
authored
feat: Assign strike runs to groups (ENG-981) (#35)
* Add --group arg to deploy. Add run-groups subcommand. Structure some formatting. Small bug fix for regarding links. * Fix tests
1 parent 6753700 commit 5b58112

File tree

4 files changed

+82
-9
lines changed

4 files changed

+82
-9
lines changed

dreadnode_cli/agent/cli.py

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
format_agent,
2020
format_agent_versions,
2121
format_run,
22+
format_run_groups,
2223
format_runs,
2324
format_strike_models,
2425
format_strikes,
@@ -68,7 +69,9 @@ def ensure_profile(agent_config: AgentConfig, *, user_config: UserConfig | None
6869
):
6970
print()
7071
raise Exception(
71-
"Agent link does not match the current server profile. Use [bold]dreadnode agent switch[/] or [bold]dreadnode profile switch[/]."
72+
f"Current agent link ([yellow]{agent_config.active_link.profile}[/]) does not match "
73+
f"the current server profile ([magenta]{user_config.active_profile_name}[/]). "
74+
"Use [bold]dreadnode agent switch[/] or [bold]dreadnode profile switch[/]."
7275
)
7376

7477
switch_profile(agent_config.active_link.profile)
@@ -248,7 +251,14 @@ def push(
248251

249252
if agent_config.links and not agent_config.has_link_to_profile(user_config.active_profile_name):
250253
print(f":link: Linking as a fresh agent to the current profile [magenta]{user_config.active_profile_name}[/]")
254+
print()
251255
new = True
256+
elif agent_config.active and agent_config.active_link.profile != user_config.active_profile_name:
257+
raise Exception(
258+
f"Current agent link ([yellow]{agent_config.active_link.profile}[/]) does not match "
259+
f"the current server profile ([magenta]{user_config.active_profile_name}[/]). "
260+
"Use [bold]dreadnode agent switch[/] or [bold]dreadnode profile switch[/]."
261+
)
252262

253263
server_config = user_config.get_server_config()
254264

@@ -372,6 +382,7 @@ def deploy(
372382
] = None,
373383
strike: t.Annotated[str | None, typer.Option("--strike", "-s", help="The strike to use for this run")] = None,
374384
watch: t.Annotated[bool, typer.Option("--watch", "-w", help="Watch the run status")] = True,
385+
group: t.Annotated[str | None, typer.Option("--group", "-g", help="Group to associate this run with")] = None,
375386
) -> None:
376387
agent_config = AgentConfig.read(directory)
377388
ensure_profile(agent_config)
@@ -421,7 +432,7 @@ def deploy(
421432
)
422433

423434
run = client.start_strike_run(
424-
agent.latest_version.id, strike=strike, model=model, user_model=user_model, context=context
435+
agent.latest_version.id, strike=strike, model=model, user_model=user_model, group=group, context=context
425436
)
426437
agent_config.add_run(run.id).write(directory)
427438
formatted = format_run(run, server_url=server_config.url)
@@ -615,6 +626,14 @@ def switch(
615626
print(f":exclamation: '{agent_or_profile}' not found, use [bold]dreadnode agent links[/]")
616627

617628

629+
@cli.command(help="List strike run groups")
630+
@pretty_cli
631+
def run_groups() -> None:
632+
client = api.create_client()
633+
groups = client.list_strike_run_groups()
634+
print(format_run_groups(groups))
635+
636+
618637
@cli.command(help="Clone a github repository", no_args_is_help=True)
619638
@pretty_cli
620639
def clone(

dreadnode_cli/agent/format.py

Lines changed: 30 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -267,9 +267,17 @@ def format_run(
267267
table.add_column("Property", style="dim")
268268
table.add_column("Value")
269269

270+
table.add_row("key", run.key)
270271
table.add_row("status", Text(run.status, style=get_status_style(run.status)))
271272
table.add_row("strike", f"[magenta]{run.strike_name}[/] ([dim]{run.strike_key}[/])")
272273
table.add_row("type", run.strike_type)
274+
table.add_row("group", Text(run.group_key or "-", style="blue" if run.group_key else ""))
275+
276+
if server_url != "":
277+
table.add_row("", "")
278+
table.add_row(
279+
"url", Text(f"{server_url.rstrip('/')}/strikes/agents/{run.agent_key}/runs/{run.id}", style="cyan")
280+
)
273281

274282
if run.agent_name:
275283
agent_name = f"[bold magenta]{run.agent_name}[/] [[dim]{run.agent_key}[/]]"
@@ -280,10 +288,6 @@ def format_run(
280288
table.add_row("model", run.model.replace(USER_MODEL_PREFIX, "") if run.model else "<default>")
281289
table.add_row("agent", f"{agent_name} ([dim]rev[/] [yellow]{run.agent_revision}[/])")
282290
table.add_row("image", Text(run.agent_version.container.image, style="cyan"))
283-
if server_url != "":
284-
table.add_row(
285-
"run url", Text(f"{server_url.rstrip('/')}/strikes/agents/{run.agent_key}/runs/{run.id}", style="cyan")
286-
)
287291
table.add_row("notes", run.agent_version.notes or "-")
288292

289293
table.add_row("", "")
@@ -314,21 +318,41 @@ def format_run(
314318

315319
def format_runs(runs: list[api.Client.StrikeRunSummaryResponse]) -> RenderableType:
316320
table = Table(box=box.ROUNDED)
317-
table.add_column("id", style="dim")
321+
table.add_column("key", style="dim")
318322
table.add_column("agent")
319323
table.add_column("status")
320324
table.add_column("model")
325+
table.add_column("group")
321326
table.add_column("started")
322327
table.add_column("duration")
323328

324329
for run in runs:
325330
table.add_row(
326-
str(run.id),
331+
run.key,
327332
f"[bold magenta]{run.agent_key}[/] [dim]:[/] [yellow]{run.agent_revision}[/]",
328333
Text(run.status, style="bold " + get_status_style(run.status)),
329334
Text(run.model.replace(USER_MODEL_PREFIX, "") if run.model else "-"),
335+
Text(run.group_key or "-", style="blue" if run.group_key else "dim"),
330336
format_time(run.start),
331337
Text(format_duration(run.start, run.end), style="bold cyan"),
332338
)
333339

334340
return table
341+
342+
343+
def format_run_groups(groups: list[api.Client.StrikeRunGroupResponse]) -> RenderableType:
344+
table = Table(box=box.ROUNDED)
345+
table.add_column("Name", style="bold cyan")
346+
table.add_column("description")
347+
table.add_column("runs", style="yellow")
348+
table.add_column("created", style="dim")
349+
350+
for group in groups:
351+
table.add_row(
352+
group.key,
353+
group.description or "-",
354+
str(group.run_count),
355+
group.created_at.astimezone().strftime("%c"),
356+
)
357+
358+
return table

dreadnode_cli/agent/tests/test_config.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -121,7 +121,7 @@ def test_ensure_profile() -> None:
121121
agent_config.add_link("test-main", UUID("00000000-0000-0000-0000-000000000000"), "main")
122122
agent_config.active = "test-other"
123123
with patch("rich.prompt.Prompt.ask", return_value="n"):
124-
with pytest.raises(Exception, match="Agent link does not match the current server profile"):
124+
with pytest.raises(Exception, match="Current agent link"):
125125
ensure_profile(agent_config, user_config=user_config)
126126

127127
# We should switch if the user agrees

dreadnode_cli/api.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -293,6 +293,16 @@ class Container(BaseModel):
293293
env: dict[str, str]
294294
name: str | None
295295

296+
class StrikeMetricPoint(BaseModel):
297+
timestamp: datetime
298+
value: float
299+
metadata: dict[str, t.Any]
300+
301+
class StrikeMetric(BaseModel):
302+
type: str
303+
description: str | None
304+
points: "list[Client.StrikeMetricPoint]"
305+
296306
class StrikeAgentVersion(BaseModel):
297307
id: UUID
298308
created_at: datetime
@@ -350,6 +360,7 @@ class StrikeRunZone(_StrikeRunZone):
350360
container_logs: dict[str, str]
351361
outputs: list["Client.StrikeRunOutput"]
352362
inferences: list[dict[str, t.Any]]
363+
metrics: dict[str, "Client.StrikeMetric"]
353364

354365
class StrikeRunContext(BaseModel):
355366
environment: dict[str, str] | None = None
@@ -358,6 +369,7 @@ class StrikeRunContext(BaseModel):
358369

359370
class _StrikeRun(BaseModel):
360371
id: UUID
372+
key: str
361373
strike_id: UUID
362374
strike_key: str
363375
strike_name: str
@@ -373,6 +385,9 @@ class _StrikeRun(BaseModel):
373385
status: "Client.StrikeRunStatus"
374386
start: datetime | None
375387
end: datetime | None
388+
group_id: UUID | None
389+
group_key: str | None
390+
group_name: str | None
376391

377392
def is_running(self) -> bool:
378393
return self.status in ["pending", "deploying", "running"]
@@ -388,6 +403,15 @@ class UserModel(BaseModel):
388403
generator_id: str
389404
api_key: str
390405

406+
class StrikeRunGroupResponse(BaseModel):
407+
id: UUID
408+
key: str
409+
name: str
410+
description: str | None
411+
created_at: datetime
412+
updated_at: datetime
413+
run_count: int
414+
391415
def get_strike(self, strike: str) -> StrikeResponse:
392416
response = self.request("GET", f"/api/strikes/{strike}")
393417
return self.StrikeResponse(**response.json())
@@ -448,6 +472,7 @@ def start_strike_run(
448472
user_model: UserModel | None = None,
449473
context: StrikeRunContext | None = None,
450474
strike: UUID | str | None = None,
475+
group: UUID | str | None = None,
451476
) -> StrikeRunResponse:
452477
response = self.request(
453478
"POST",
@@ -457,6 +482,7 @@ def start_strike_run(
457482
"model": model,
458483
"user_model": user_model.model_dump(mode="json") if user_model else None,
459484
"strike": str(strike) if strike else None,
485+
"group": str(group) if group else None,
460486
"context": context.model_dump(mode="json") if context else None,
461487
},
462488
)
@@ -472,6 +498,10 @@ def list_strike_runs(self, *, strike_id: UUID | str | None = None) -> list[Strik
472498
)
473499
return [self.StrikeRunSummaryResponse(**run) for run in response.json()]
474500

501+
def list_strike_run_groups(self) -> list[StrikeRunGroupResponse]:
502+
response = self.request("GET", "/api/strikes/groups")
503+
return [self.StrikeRunGroupResponse(**group) for group in response.json()]
504+
475505

476506
def create_client(*, profile: str | None = None) -> Client:
477507
"""Create an authenticated API client using stored configuration data."""

0 commit comments

Comments
 (0)