Skip to content

Commit 6753700

Browse files
authored
eng 986: cli supply agent parameters with runs (#36)
* new: implemented run context * fix: fixed typing bug * new: showing run context if set (ENG-989) * improved run context format * docs: documented run context data * fix: lint driven fix
1 parent 18a69f1 commit 6753700

File tree

7 files changed

+108
-2
lines changed

7 files changed

+108
-2
lines changed

CLI.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,9 @@ $ dreadnode agent deploy [OPTIONS]
8585

8686
* `-m, --model TEXT`: The inference model to use for this run
8787
* `-d, --dir DIRECTORY`: The agent directory [default: .]
88+
* `-e, --env-var TEXT`: Environment vars to override for this run (key=value)
89+
* `-p, --param TEXT`: Define custom parameters for this run (key = value in toml syntax or @filename.toml for multiple values)
90+
* `-c, --command TEXT`: Override the container command for this run.
8891
* `-s, --strike TEXT`: The strike to use for this run
8992
* `-w, --watch`: Watch the run status [default: True]
9093
* `--help`: Show this message and exit.

README.md

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -165,6 +165,18 @@ dreadnode agent push
165165
# start a new run using the latest agent version.
166166
dreadnode agent deploy
167167

168+
# start a new run using the latest agent version with custom environment variables
169+
dreadnode agent deploy --env-var TEST_ENV=test --env-var ANOTHER_ENV=another_value
170+
171+
# start a new run using the latest agent version with custom parameters (using toml syntax)
172+
dreadnode agent deploy --param "foo = 'bar'" --param "baz = 123.0"
173+
174+
# start a new run using the latest agent version with custom parameters from a toml file
175+
dreadnode agent deploy --param @parameters.toml
176+
177+
# start a new run using the latest agent version and override the container command
178+
dreadnode agent deploy --command "echo 'Hello, world!'"
179+
168180
# show the latest run of the currently active agent
169181
dreadnode agent latest
170182

dreadnode_cli/agent/cli.py

Lines changed: 47 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import time
55
import typing as t
66

7+
import toml
78
import typer
89
from rich import box, print
910
from rich.live import Live
@@ -318,6 +319,31 @@ def push(
318319
print(":tada: Agent pushed. use [bold]dreadnode agent deploy[/] to start a new run.")
319320

320321

322+
def prepare_run_context(
323+
env_vars: list[str] | None, parameters: list[str] | None, command: str | None
324+
) -> Client.StrikeRunContext | None:
325+
if not env_vars and not parameters and not command:
326+
return None
327+
328+
context = Client.StrikeRunContext()
329+
330+
if env_vars:
331+
context.environment = {env_var.split("=")[0]: env_var.split("=")[1] for env_var in env_vars}
332+
333+
if parameters:
334+
context.parameters = {}
335+
for param in parameters:
336+
if param.startswith("@"):
337+
context.parameters.update(toml.load(open(param[1:])))
338+
else:
339+
context.parameters.update(toml.loads(param))
340+
341+
if command:
342+
context.command = command
343+
344+
return context
345+
346+
321347
@cli.command(help="Start a new run using the latest active agent version")
322348
@pretty_cli
323349
def deploy(
@@ -328,6 +354,22 @@ def deploy(
328354
pathlib.Path,
329355
typer.Option("--dir", "-d", help="The agent directory", file_okay=False, resolve_path=True),
330356
] = pathlib.Path("."),
357+
env_vars: t.Annotated[
358+
list[str] | None,
359+
typer.Option("--env-var", "-e", help="Environment vars to override for this run (key=value)"),
360+
] = None,
361+
parameters: t.Annotated[
362+
list[str] | None,
363+
typer.Option(
364+
"--param",
365+
"-p",
366+
help="Define custom parameters for this run (key = value in toml syntax or @filename.toml for multiple values)",
367+
),
368+
] = None,
369+
command: t.Annotated[
370+
str | None,
371+
typer.Option("--command", "-c", help="Override the container command for this run."),
372+
] = None,
331373
strike: t.Annotated[str | None, typer.Option("--strike", "-s", help="The strike to use for this run")] = None,
332374
watch: t.Annotated[bool, typer.Option("--watch", "-w", help="Watch the run status")] = True,
333375
) -> None:
@@ -346,6 +388,8 @@ def deploy(
346388
if strike is None:
347389
raise Exception("No strike specified, use -s/--strike or set the strike in the agent config")
348390

391+
context = prepare_run_context(env_vars, parameters, command)
392+
349393
user_models = UserModels.read()
350394
user_model: Client.UserModel | None = None
351395

@@ -376,7 +420,9 @@ def deploy(
376420
f"Model '{model}' is not user-defined nor is it available in strike '{strike_response.name}'"
377421
)
378422

379-
run = client.start_strike_run(agent.latest_version.id, strike=strike, model=model, user_model=user_model)
423+
run = client.start_strike_run(
424+
agent.latest_version.id, strike=strike, model=model, user_model=user_model, context=context
425+
)
380426
agent_config.add_run(run.id).write(directory)
381427
formatted = format_run(run, server_url=server_config.url)
382428

dreadnode_cli/agent/format.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -291,6 +291,19 @@ def format_run(
291291
table.add_row("start", format_time(run.start))
292292
table.add_row("end", format_time(run.end))
293293

294+
if run.context and (run.context.environment or run.context.parameters or run.context.command):
295+
table.add_row("", "")
296+
if run.context.environment:
297+
table.add_row(
298+
"environment", " ".join(f"[magenta]{k}[/]=[yellow]{v}[/]" for k, v in run.context.environment.items())
299+
)
300+
if run.context.parameters:
301+
table.add_row(
302+
"parameters", " ".join(f"[magenta]{k}[/]=[yellow]{v}[/]" for k, v in run.context.parameters.items())
303+
)
304+
if run.context.command:
305+
table.add_row("command", f"[bold][red]{run.context.command}[/red][/bold]")
306+
294307
components: list[RenderableType] = [
295308
table,
296309
format_zones_verbose(run.zones, include_logs=include_logs) if verbose else format_zones_summary(run.zones),

dreadnode_cli/api.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -351,6 +351,11 @@ class StrikeRunZone(_StrikeRunZone):
351351
outputs: list["Client.StrikeRunOutput"]
352352
inferences: list[dict[str, t.Any]]
353353

354+
class StrikeRunContext(BaseModel):
355+
environment: dict[str, str] | None = None
356+
parameters: dict[str, t.Any] | None = None
357+
command: str | None = None
358+
354359
class _StrikeRun(BaseModel):
355360
id: UUID
356361
strike_id: UUID
@@ -364,6 +369,7 @@ class _StrikeRun(BaseModel):
364369
agent_name: str | None = None
365370
agent_revision: int
366371
agent_version: "Client.StrikeAgentVersion"
372+
context: "Client.StrikeRunContext | None" = None
367373
status: "Client.StrikeRunStatus"
368374
start: datetime | None
369375
end: datetime | None
@@ -440,6 +446,7 @@ def start_strike_run(
440446
*,
441447
model: str | None = None,
442448
user_model: UserModel | None = None,
449+
context: StrikeRunContext | None = None,
443450
strike: UUID | str | None = None,
444451
) -> StrikeRunResponse:
445452
response = self.request(
@@ -450,6 +457,7 @@ def start_strike_run(
450457
"model": model,
451458
"user_model": user_model.model_dump(mode="json") if user_model else None,
452459
"strike": str(strike) if strike else None,
460+
"context": context.model_dump(mode="json") if context else None,
453461
},
454462
)
455463
return self.StrikeRunResponse(**response.json())

poetry.lock

Lines changed: 23 additions & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyproject.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@ httpx = "^0.27.2"
2020
ruamel-yaml = "^0.18.6"
2121
docker = "^7.1.0"
2222
pydantic-yaml = "^1.4.0"
23+
toml = "^0.10.2"
24+
types-toml = "^0.10.8.20240310"
2325

2426
[tool.pytest.ini_options]
2527
asyncio_mode = "auto"

0 commit comments

Comments
 (0)