|
10 | 10 | import inspect
|
11 | 11 | import os
|
12 | 12 | from argparse import Namespace
|
13 |
| -from typing import Any, Callable, Dict, List, Mapping, Optional, Union |
| 13 | +from typing import Any, Callable, Dict, List, Mapping, NamedTuple, Optional, Union |
14 | 14 |
|
15 | 15 | from torchx.specs.api import BindMount, MountType, VolumeMount
|
16 | 16 | from torchx.specs.file_linter import get_fn_docstring, TorchXArgumentHelpFormatter
|
|
19 | 19 | from .api import AppDef, DeviceMount
|
20 | 20 |
|
21 | 21 |
|
| 22 | +class ComponentArgs(NamedTuple): |
| 23 | + """Parsed component function arguments""" |
| 24 | + |
| 25 | + positional_args: dict[str, Any] |
| 26 | + var_args: list[str] |
| 27 | + kwargs: dict[str, Any] |
| 28 | + |
| 29 | + |
22 | 30 | def _create_args_parser(
|
23 | 31 | cmpnt_fn: Callable[..., AppDef],
|
24 | 32 | cmpnt_defaults: Optional[Dict[str, str]] = None,
|
@@ -140,6 +148,91 @@ def parse_args(
|
140 | 148 | return parsed_args
|
141 | 149 |
|
142 | 150 |
|
| 151 | +def component_args_from_str( |
| 152 | + cmpnt_fn: Callable[..., Any], # pyre-fixme[2]: Enforce AppDef type |
| 153 | + cmpnt_args: list[str], |
| 154 | + cmpnt_args_defaults: Optional[Dict[str, Any]] = None, |
| 155 | + config: Optional[Dict[str, Any]] = None, |
| 156 | +) -> ComponentArgs: |
| 157 | + """ |
| 158 | + Parses and decodes command-line arguments for a component function. |
| 159 | +
|
| 160 | + This function takes a component function and its arguments, parses them using argparse, |
| 161 | + and decodes the arguments into their expected types based on the function's signature. |
| 162 | + It separates positional arguments, variable positional arguments (*args), and keyword-only arguments. |
| 163 | +
|
| 164 | + Args: |
| 165 | + cmpnt_fn: The component function whose arguments are to be parsed and decoded. |
| 166 | + cmpnt_args: List of command-line arguments to be parsed. Supports both space separated and '=' separated arguments. |
| 167 | + cmpnt_args_defaults: Optional dictionary of default values for the component function's parameters. |
| 168 | + config: Optional dictionary containing additional configuration values. |
| 169 | +
|
| 170 | + Returns: |
| 171 | + ComponentArgs representing the input args to a component function containing: |
| 172 | + - positional_args: Dictionary of positional and positional-or-keyword arguments. |
| 173 | + - var_args: List of variable positional arguments (*args). |
| 174 | + - kwargs: Dictionary of keyword-only arguments. |
| 175 | +
|
| 176 | + Usage: |
| 177 | +
|
| 178 | + .. doctest:: |
| 179 | + from torchx.specs.api import AppDef |
| 180 | + from torchx.specs.builders import component_args_from_str |
| 181 | +
|
| 182 | + def example_component_fn(foo: str, *args: str, bar: str = "asdf") -> AppDef: |
| 183 | + return AppDef(name="example") |
| 184 | +
|
| 185 | + # Supports space separated arguments |
| 186 | + args = ["--foo", "fooval", "--bar", "barval", "arg1", "arg2"] |
| 187 | + parsed_args = component_args_from_str(example_component_fn, args) |
| 188 | +
|
| 189 | + assert parsed_args.positional_args == {"foo": "fooval"} |
| 190 | + assert parsed_args.var_args == ["arg1", "arg2"] |
| 191 | + assert parsed_args.kwargs == {"bar": "barval"} |
| 192 | +
|
| 193 | + # Supports '=' separated arguments |
| 194 | + args = ["--foo=fooval", "--bar=barval", "arg1", "arg2"] |
| 195 | + parsed_args = component_args_from_str(example_component_fn, args) |
| 196 | +
|
| 197 | + assert parsed_args.positional_args == {"foo": "fooval"} |
| 198 | + assert parsed_args.var_args == ["arg1", "arg2"] |
| 199 | + assert parsed_args.kwargs == {"bar": "barval"} |
| 200 | +
|
| 201 | +
|
| 202 | + """ |
| 203 | + parsed_args: Namespace = parse_args( |
| 204 | + cmpnt_fn, cmpnt_args, cmpnt_args_defaults, config |
| 205 | + ) |
| 206 | + |
| 207 | + positional_args = {} |
| 208 | + var_args = [] |
| 209 | + kwargs = {} |
| 210 | + |
| 211 | + parameters = inspect.signature(cmpnt_fn).parameters |
| 212 | + for param_name, parameter in parameters.items(): |
| 213 | + arg_value = getattr(parsed_args, param_name) |
| 214 | + parameter_type = parameter.annotation |
| 215 | + parameter_type = decode_optional(parameter_type) |
| 216 | + arg_value = decode(arg_value, parameter_type) |
| 217 | + if parameter.kind == inspect.Parameter.VAR_POSITIONAL: |
| 218 | + var_args = arg_value |
| 219 | + elif parameter.kind == inspect.Parameter.KEYWORD_ONLY: |
| 220 | + kwargs[param_name] = arg_value |
| 221 | + elif parameter.kind == inspect.Parameter.VAR_KEYWORD: |
| 222 | + raise TypeError( |
| 223 | + f"component fn param `{param_name}` is a '**kwargs' which is not supported; consider changing the " |
| 224 | + f"type to a dict or explicitly declare the params" |
| 225 | + ) |
| 226 | + else: |
| 227 | + # POSITIONAL or POSITIONAL_OR_KEYWORD |
| 228 | + positional_args[param_name] = arg_value |
| 229 | + |
| 230 | + if len(var_args) > 0 and var_args[0] == "--": |
| 231 | + var_args = var_args[1:] |
| 232 | + |
| 233 | + return ComponentArgs(positional_args, var_args, kwargs) |
| 234 | + |
| 235 | + |
143 | 236 | def materialize_appdef(
|
144 | 237 | cmpnt_fn: Callable[..., Any], # pyre-ignore[2]
|
145 | 238 | cmpnt_args: List[str],
|
@@ -174,30 +267,14 @@ def materialize_appdef(
|
174 | 267 | An application spec
|
175 | 268 | """
|
176 | 269 |
|
177 |
| - function_args = [] |
178 |
| - var_arg = [] |
179 |
| - kwargs = {} |
180 |
| - |
181 |
| - parsed_args = parse_args(cmpnt_fn, cmpnt_args, cmpnt_defaults, config) |
182 |
| - |
183 |
| - parameters = inspect.signature(cmpnt_fn).parameters |
184 |
| - for param_name, parameter in parameters.items(): |
185 |
| - arg_value = getattr(parsed_args, param_name) |
186 |
| - parameter_type = parameter.annotation |
187 |
| - parameter_type = decode_optional(parameter_type) |
188 |
| - arg_value = decode(arg_value, parameter_type) |
189 |
| - if parameter.kind == inspect.Parameter.VAR_POSITIONAL: |
190 |
| - var_arg = arg_value |
191 |
| - elif parameter.kind == inspect.Parameter.KEYWORD_ONLY: |
192 |
| - kwargs[param_name] = arg_value |
193 |
| - elif parameter.kind == inspect.Parameter.VAR_KEYWORD: |
194 |
| - raise TypeError("**kwargs are not supported for component definitions") |
195 |
| - else: |
196 |
| - function_args.append(arg_value) |
197 |
| - if len(var_arg) > 0 and var_arg[0] == "--": |
198 |
| - var_arg = var_arg[1:] |
| 270 | + component_args: ComponentArgs = component_args_from_str( |
| 271 | + cmpnt_fn, cmpnt_args, cmpnt_defaults, config |
| 272 | + ) |
| 273 | + positional_arg_values = list(component_args.positional_args.values()) |
| 274 | + appdef = cmpnt_fn( |
| 275 | + *positional_arg_values, *component_args.var_args, **component_args.kwargs |
| 276 | + ) |
199 | 277 |
|
200 |
| - appdef = cmpnt_fn(*function_args, *var_arg, **kwargs) |
201 | 278 | if not isinstance(appdef, AppDef):
|
202 | 279 | raise TypeError(
|
203 | 280 | f"Expected a component that returns `AppDef`, but got `{type(appdef)}`"
|
|
0 commit comments