Skip to content

Commit 29472a9

Browse files
authored
(torchx/specs) Add intermediate helper method to parse component arguments
Differential Revision: D80215193 Pull Request resolved: #1097
1 parent cc41de6 commit 29472a9

File tree

2 files changed

+133
-24
lines changed

2 files changed

+133
-24
lines changed

torchx/specs/builders.py

Lines changed: 101 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
import inspect
1111
import os
1212
from argparse import Namespace
13-
from typing import Any, Callable, Dict, List, Mapping, Optional, Union
13+
from typing import Any, Callable, Dict, List, Mapping, NamedTuple, Optional, Union
1414

1515
from torchx.specs.api import BindMount, MountType, VolumeMount
1616
from torchx.specs.file_linter import get_fn_docstring, TorchXArgumentHelpFormatter
@@ -19,6 +19,14 @@
1919
from .api import AppDef, DeviceMount
2020

2121

22+
class ComponentArgs(NamedTuple):
23+
"""Parsed component function arguments"""
24+
25+
positional_args: dict[str, Any]
26+
var_args: list[str]
27+
kwargs: dict[str, Any]
28+
29+
2230
def _create_args_parser(
2331
cmpnt_fn: Callable[..., AppDef],
2432
cmpnt_defaults: Optional[Dict[str, str]] = None,
@@ -140,6 +148,91 @@ def parse_args(
140148
return parsed_args
141149

142150

151+
def component_args_from_str(
152+
cmpnt_fn: Callable[..., Any], # pyre-fixme[2]: Enforce AppDef type
153+
cmpnt_args: list[str],
154+
cmpnt_args_defaults: Optional[Dict[str, Any]] = None,
155+
config: Optional[Dict[str, Any]] = None,
156+
) -> ComponentArgs:
157+
"""
158+
Parses and decodes command-line arguments for a component function.
159+
160+
This function takes a component function and its arguments, parses them using argparse,
161+
and decodes the arguments into their expected types based on the function's signature.
162+
It separates positional arguments, variable positional arguments (*args), and keyword-only arguments.
163+
164+
Args:
165+
cmpnt_fn: The component function whose arguments are to be parsed and decoded.
166+
cmpnt_args: List of command-line arguments to be parsed. Supports both space separated and '=' separated arguments.
167+
cmpnt_args_defaults: Optional dictionary of default values for the component function's parameters.
168+
config: Optional dictionary containing additional configuration values.
169+
170+
Returns:
171+
ComponentArgs representing the input args to a component function containing:
172+
- positional_args: Dictionary of positional and positional-or-keyword arguments.
173+
- var_args: List of variable positional arguments (*args).
174+
- kwargs: Dictionary of keyword-only arguments.
175+
176+
Usage:
177+
178+
.. doctest::
179+
from torchx.specs.api import AppDef
180+
from torchx.specs.builders import component_args_from_str
181+
182+
def example_component_fn(foo: str, *args: str, bar: str = "asdf") -> AppDef:
183+
return AppDef(name="example")
184+
185+
# Supports space separated arguments
186+
args = ["--foo", "fooval", "--bar", "barval", "arg1", "arg2"]
187+
parsed_args = component_args_from_str(example_component_fn, args)
188+
189+
assert parsed_args.positional_args == {"foo": "fooval"}
190+
assert parsed_args.var_args == ["arg1", "arg2"]
191+
assert parsed_args.kwargs == {"bar": "barval"}
192+
193+
# Supports '=' separated arguments
194+
args = ["--foo=fooval", "--bar=barval", "arg1", "arg2"]
195+
parsed_args = component_args_from_str(example_component_fn, args)
196+
197+
assert parsed_args.positional_args == {"foo": "fooval"}
198+
assert parsed_args.var_args == ["arg1", "arg2"]
199+
assert parsed_args.kwargs == {"bar": "barval"}
200+
201+
202+
"""
203+
parsed_args: Namespace = parse_args(
204+
cmpnt_fn, cmpnt_args, cmpnt_args_defaults, config
205+
)
206+
207+
positional_args = {}
208+
var_args = []
209+
kwargs = {}
210+
211+
parameters = inspect.signature(cmpnt_fn).parameters
212+
for param_name, parameter in parameters.items():
213+
arg_value = getattr(parsed_args, param_name)
214+
parameter_type = parameter.annotation
215+
parameter_type = decode_optional(parameter_type)
216+
arg_value = decode(arg_value, parameter_type)
217+
if parameter.kind == inspect.Parameter.VAR_POSITIONAL:
218+
var_args = arg_value
219+
elif parameter.kind == inspect.Parameter.KEYWORD_ONLY:
220+
kwargs[param_name] = arg_value
221+
elif parameter.kind == inspect.Parameter.VAR_KEYWORD:
222+
raise TypeError(
223+
f"component fn param `{param_name}` is a '**kwargs' which is not supported; consider changing the "
224+
f"type to a dict or explicitly declare the params"
225+
)
226+
else:
227+
# POSITIONAL or POSITIONAL_OR_KEYWORD
228+
positional_args[param_name] = arg_value
229+
230+
if len(var_args) > 0 and var_args[0] == "--":
231+
var_args = var_args[1:]
232+
233+
return ComponentArgs(positional_args, var_args, kwargs)
234+
235+
143236
def materialize_appdef(
144237
cmpnt_fn: Callable[..., Any], # pyre-ignore[2]
145238
cmpnt_args: List[str],
@@ -174,30 +267,14 @@ def materialize_appdef(
174267
An application spec
175268
"""
176269

177-
function_args = []
178-
var_arg = []
179-
kwargs = {}
180-
181-
parsed_args = parse_args(cmpnt_fn, cmpnt_args, cmpnt_defaults, config)
182-
183-
parameters = inspect.signature(cmpnt_fn).parameters
184-
for param_name, parameter in parameters.items():
185-
arg_value = getattr(parsed_args, param_name)
186-
parameter_type = parameter.annotation
187-
parameter_type = decode_optional(parameter_type)
188-
arg_value = decode(arg_value, parameter_type)
189-
if parameter.kind == inspect.Parameter.VAR_POSITIONAL:
190-
var_arg = arg_value
191-
elif parameter.kind == inspect.Parameter.KEYWORD_ONLY:
192-
kwargs[param_name] = arg_value
193-
elif parameter.kind == inspect.Parameter.VAR_KEYWORD:
194-
raise TypeError("**kwargs are not supported for component definitions")
195-
else:
196-
function_args.append(arg_value)
197-
if len(var_arg) > 0 and var_arg[0] == "--":
198-
var_arg = var_arg[1:]
270+
component_args: ComponentArgs = component_args_from_str(
271+
cmpnt_fn, cmpnt_args, cmpnt_defaults, config
272+
)
273+
positional_arg_values = list(component_args.positional_args.values())
274+
appdef = cmpnt_fn(
275+
*positional_arg_values, *component_args.var_args, **component_args.kwargs
276+
)
199277

200-
appdef = cmpnt_fn(*function_args, *var_arg, **kwargs)
201278
if not isinstance(appdef, AppDef):
202279
raise TypeError(
203280
f"Expected a component that returns `AppDef`, but got `{type(appdef)}`"

torchx/specs/test/builders_test.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@
1818
from torchx.specs.builders import (
1919
_create_args_parser,
2020
BindMount,
21+
component_args_from_str,
22+
ComponentArgs,
2123
DeviceMount,
2224
make_app_handle,
2325
materialize_appdef,
@@ -281,6 +283,36 @@ def _get_app_args_and_defaults_with_nested_objects(
281283
*role_args,
282284
], defaults
283285

286+
def test_component_args_from_str(self) -> None:
287+
component_fn_args = [
288+
"--foo",
289+
"fooval",
290+
"--bar",
291+
"barval",
292+
"arg1",
293+
"arg2",
294+
]
295+
parsed_args: ComponentArgs = component_args_from_str(
296+
example_var_args, component_fn_args
297+
)
298+
self.assertEqual(parsed_args.positional_args, {"foo": "fooval"})
299+
self.assertEqual(parsed_args.var_args, ["arg1", "arg2"])
300+
self.assertEqual(parsed_args.kwargs, {"bar": "barval"})
301+
302+
def test_component_args_from_str_equals_separated(self) -> None:
303+
component_fn_args = [
304+
"--foo=fooval",
305+
"--bar=barval",
306+
"arg1",
307+
"arg2",
308+
]
309+
parsed_args: ComponentArgs = component_args_from_str(
310+
example_var_args, component_fn_args
311+
)
312+
self.assertEqual(parsed_args.positional_args, {"foo": "fooval"})
313+
self.assertEqual(parsed_args.var_args, ["arg1", "arg2"])
314+
self.assertEqual(parsed_args.kwargs, {"bar": "barval"})
315+
284316
def test_load_from_fn_empty(self) -> None:
285317
actual_app = materialize_appdef(example_empty_fn, [])
286318
expected_app = get_dummy_application("trainer")

0 commit comments

Comments
 (0)