Skip to content

Commit 4f937ac

Browse files
authored
CLI Arg Groups: Add more structure to the arg parser for Model specification and Model configuration (pytorch#937)
* Suppress the occurrence of unused args in README subcommands * Move GGUF into model specification exclusive require group * Further group args into parsing groups * Function Description Typos
1 parent 87da8c4 commit 4f937ac

File tree

1 file changed

+113
-75
lines changed

1 file changed

+113
-75
lines changed

cli.py

Lines changed: 113 additions & 75 deletions
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,11 @@
2828
# Subcommands related to downloading and managing model artifacts
2929
INVENTORY_VERBS = ["download", "list", "remove", "where"]
3030

31+
# Subcommands related to generating inference output based on user prompts
32+
GENERATION_VERBS = ["browser", "chat", "generate", "server"]
33+
3134
# List of all supported subcommands in torchchat
32-
KNOWN_VERBS = ["chat", "browser", "generate", "server", "eval", "export"] + INVENTORY_VERBS
35+
KNOWN_VERBS = GENERATION_VERBS + ["eval", "export"] + INVENTORY_VERBS
3336

3437

3538
# Handle CLI arguments that are common to a majority of subcommands.
@@ -46,61 +49,97 @@ def check_args(args, verb: str) -> None:
4649

4750
# Given an arg parser and a subcommand (verb), add the appropriate arguments
4851
# for that subcommand.
52+
#
53+
# Note the use of argparse.SUPPRESS to hide arguments from --help due to
54+
# legacy CLI arg parsing. See https://github.com/pytorch/torchchat/issues/932
4955
def add_arguments_for_verb(parser, verb: str) -> None:
5056
# Argument closure for inventory related subcommands
5157
if verb in INVENTORY_VERBS:
5258
_configure_artifact_inventory_args(parser, verb)
5359
_add_cli_metadata_args(parser)
5460
return
5561

56-
# Model specification
57-
# A model can be specified using a positional model name or checkpoint path
58-
parser.add_argument(
62+
# Add argument groups for model specification (what base model to use)
63+
_add_model_specification_args(parser)
64+
65+
# Add argument groups for exported model path IO
66+
_add_exported_input_path_args(parser, verb)
67+
_add_export_output_path_args(parser, verb)
68+
69+
# Add argument groups for model configuration (compilation, quant, etc)
70+
_add_model_config_args(parser, verb)
71+
72+
# Add thematic argument groups based on the subcommand
73+
if verb in ["browser", "chat", "generate", "server"]:
74+
_add_generation_args(parser, verb)
75+
if verb == "eval":
76+
_add_evaluation_args(parser)
77+
78+
# Add CLI Args related to downloading of model artifacts (if not already downloaded)
79+
_add_jit_downloading_args(parser)
80+
81+
# Add CLI Args that are general to subcommand cli execution
82+
_add_cli_metadata_args(parser)
83+
84+
# WIP Features (suppressed from --help)
85+
_add_distributed_args(parser)
86+
_add_custom_model_args(parser)
87+
_add_speculative_execution_args(parser)
88+
89+
90+
# Add CLI Args related to model specification (what base model to use)
91+
def _add_model_specification_args(parser) -> None:
92+
model_specification_parser = parser.add_argument_group("Model Specification", "(REQUIRED) Specify the base model. Args are mutually exclusive.")
93+
exclusive_parser = model_specification_parser.add_mutually_exclusive_group(required=True)
94+
exclusive_parser.add_argument(
5995
"model",
6096
type=str,
6197
nargs="?",
6298
default=None,
6399
help="Model name for well-known models",
64100
)
65-
parser.add_argument(
101+
exclusive_parser.add_argument(
66102
"--checkpoint-path",
67103
type=Path,
68104
default="not_specified",
69105
help="Use the specified model checkpoint path",
70106
)
107+
# See _add_custom_model_args() for more details
108+
exclusive_parser.add_argument(
109+
"--gguf-path",
110+
type=Path,
111+
default=None,
112+
help=argparse.SUPPRESS,
113+
# "Use the specified GGUF model file",
114+
)
71115

72-
# Add thematic argument groups based on the subcommand
73-
if verb in ["browser", "chat", "generate", "server"]:
74-
_add_generation_args(parser)
75-
if verb == "eval":
76-
_add_evaluation_args(parser)
77-
78-
# Add argument groups for exported model path IO
79-
_add_exported_input_path_args(parser)
80-
_add_export_output_path_args(parser)
81-
82-
parser.add_argument(
116+
model_specification_parser.add_argument(
83117
"--is-chat-model",
84118
action="store_true",
85-
help="Indicate that the model was trained to support chat functionality",
119+
# help="Indicate that the model was trained to support chat functionality",
120+
help=argparse.SUPPRESS,
86121
)
87-
parser.add_argument(
122+
123+
# Add CLI Args related to model configuration (compilation, quant, etc)
124+
def _add_model_config_args(parser, verb: str) -> None:
125+
model_config_parser = parser.add_argument_group("Model Configuration", "Specify model configurations")
126+
model_config_parser.add_argument(
88127
"--compile",
89128
action="store_true",
90129
help="Whether to compile the model with torch.compile",
91130
)
92-
parser.add_argument(
131+
model_config_parser.add_argument(
93132
"--compile-prefill",
94133
action="store_true",
95134
help="Whether to compile the prefill. Improves prefill perf, but has higher compile times.",
96135
)
97-
parser.add_argument(
136+
model_config_parser.add_argument(
98137
"--dtype",
99138
default="fast",
100139
choices=allowable_dtype_names(),
101140
help="Override the dtype of the model (default is the checkpoint dtype). Options: bf16, fp16, fp32, fast16, fast",
102141
)
103-
parser.add_argument(
142+
model_config_parser.add_argument(
104143
"--quantize",
105144
type=str,
106145
default="{ }",
@@ -109,81 +148,84 @@ def add_arguments_for_verb(parser, verb: str) -> None:
109148
+ "modes are: embedding, linear:int8, linear:int4, linear:a8w4dq, precision."
110149
),
111150
)
112-
parser.add_argument(
151+
model_config_parser.add_argument(
113152
"--device",
114153
type=str,
115154
default=default_device,
116155
choices=["fast", "cpu", "cuda", "mps"],
117156
help="Hardware device to use. Options: cpu, cuda, mps",
118157
)
119-
parser.add_argument(
120-
"--hf-token",
121-
type=str,
122-
default=None,
123-
help="A HuggingFace API token to use when downloading model artifacts",
124-
)
125-
parser.add_argument(
126-
"--model-directory",
127-
type=Path,
128-
default=default_model_dir,
129-
help=f"The directory to store downloaded model artifacts. Default: {default_model_dir}",
130-
)
131-
parser.add_argument(
132-
"--profile",
133-
type=Path,
134-
default=None,
135-
help="Profile path.",
136-
)
137-
_add_cli_metadata_args(parser)
138-
139-
# WIP Features (suppressed from --help)
140-
_add_distributed_args(parser)
141-
_add_custom_model_args(parser)
142-
_add_speculative_execution_args(parser)
143158

159+
# Add CLI Args representing output paths of exported model files
160+
def _add_export_output_path_args(parser, verb: str) -> None:
161+
is_export = verb == "export"
144162

145-
# Add CLI Args representing user provided exported model files
146-
def _add_export_output_path_args(parser) -> None:
147163
output_path_parser = parser.add_argument_group(
148-
"Export Output Path Args",
149-
"Specify the output path for the exported model files",
164+
"Export Output Path" if is_export else None,
165+
"Specify the output path for the exported model files" if is_export else None,
150166
)
151-
output_path_parser.add_argument(
167+
exclusive_parser = output_path_parser.add_mutually_exclusive_group()
168+
exclusive_parser.add_argument(
152169
"--output-pte-path",
153170
type=str,
154171
default=None,
155-
help="Output to the specified ExecuTorch .pte model file",
172+
help="Output to the specified ExecuTorch .pte model file" if is_export else argparse.SUPPRESS,
156173
)
157-
output_path_parser.add_argument(
174+
exclusive_parser.add_argument(
158175
"--output-dso-path",
159176
type=str,
160177
default=None,
161-
help="Output to the specified AOT Inductor .dso model file",
178+
help="Output to the specified AOT Inductor .dso model file" if is_export else argparse.SUPPRESS,
162179
)
163180

164181

165182
# Add CLI Args representing user provided exported model files
166-
def _add_exported_input_path_args(parser) -> None:
183+
def _add_exported_input_path_args(parser, verb: str) -> None:
184+
is_generation_verb = verb in GENERATION_VERBS
185+
167186
exported_model_path_parser = parser.add_argument_group(
168-
"Exported Model Path Args",
169-
"Specify the path of the exported model files to ingest",
187+
"Exported Model Path" if is_generation_verb else None,
188+
"Specify the path of the exported model files to ingest" if is_generation_verb else None,
170189
)
171-
exported_model_path_parser.add_argument(
190+
exclusive_parser = exported_model_path_parser.add_mutually_exclusive_group()
191+
exclusive_parser.add_argument(
172192
"--dso-path",
173193
type=Path,
174194
default=None,
175-
help="Use the specified AOT Inductor .dso model file",
195+
help="Use the specified AOT Inductor .dso model file" if is_generation_verb else argparse.SUPPRESS,
176196
)
177-
exported_model_path_parser.add_argument(
197+
exclusive_parser.add_argument(
178198
"--pte-path",
179199
type=Path,
180200
default=None,
181-
help="Use the specified ExecuTorch .pte model file",
201+
help="Use the specified ExecuTorch .pte model file" if is_generation_verb else argparse.SUPPRESS,
182202
)
183203

204+
# Add CLI Args related to JIT downloading of model artifacts
205+
def _add_jit_downloading_args(parser) -> None:
206+
jit_downloading_parser = parser.add_argument_group("Model Downloading", "Specify args for model downloading (if model is not downloaded)",)
207+
jit_downloading_parser.add_argument(
208+
"--hf-token",
209+
type=str,
210+
default=None,
211+
help="A HuggingFace API token to use when downloading model artifacts",
212+
)
213+
jit_downloading_parser.add_argument(
214+
"--model-directory",
215+
type=Path,
216+
default=default_model_dir,
217+
help=f"The directory to store downloaded model artifacts. Default: {default_model_dir}",
218+
)
184219

185220
# Add CLI Args that are general to subcommand cli execution
186221
def _add_cli_metadata_args(parser) -> None:
222+
parser.add_argument(
223+
"--profile",
224+
type=Path,
225+
default=None,
226+
# help="Profile path.",
227+
help=argparse.SUPPRESS,
228+
)
187229
parser.add_argument(
188230
"-v",
189231
"--verbose",
@@ -227,25 +269,27 @@ def _configure_artifact_inventory_args(parser, verb: str) -> None:
227269

228270

229271
# Add CLI Args specific to user prompted generation
230-
def _add_generation_args(parser) -> None:
272+
def _add_generation_args(parser, verb: str) -> None:
231273
generator_parser = parser.add_argument_group(
232-
"Generation Args", "Configs for generating output based on provided prompt"
274+
"Generation", "Configs for generating output based on provided prompt"
233275
)
234276
generator_parser.add_argument(
235277
"--prompt",
236278
type=str,
237279
default="Hello, my name is",
238-
help="Input prompt for manual output generation",
280+
help="Input prompt for manual output generation" if verb == "generate" else argparse.SUPPRESS,
239281
)
240282
generator_parser.add_argument(
241283
"--chat",
242284
action="store_true",
243-
help="Whether to start an interactive chat session",
285+
# help="Whether to start an interactive chat session",
286+
help=argparse.SUPPRESS,
244287
)
245288
generator_parser.add_argument(
246289
"--gui",
247290
action="store_true",
248-
help="Whether to use a web UI for an interactive chat session",
291+
# help="Whether to use a web UI for an interactive chat session",
292+
help=argparse.SUPPRESS,
249293
)
250294
generator_parser.add_argument(
251295
"--num-samples",
@@ -271,14 +315,15 @@ def _add_generation_args(parser) -> None:
271315
generator_parser.add_argument(
272316
"--sequential-prefill",
273317
action="store_true",
274-
help="Whether to perform prefill sequentially. Only used for model debug.",
318+
# help="Whether to perform prefill sequentially. Only used for model debug.",
319+
help=argparse.SUPPRESS,
275320
)
276321

277322

278323
# Add CLI Args specific to Model Evaluation
279324
def _add_evaluation_args(parser) -> None:
280325
eval_parser = parser.add_argument_group(
281-
"Evaluation Args", "Configs for evaluating model performance"
326+
"Evaluation", "Configs for evaluating model performance"
282327
)
283328
eval_parser.add_argument(
284329
"--tasks",
@@ -337,13 +382,6 @@ def _add_custom_model_args(parser) -> None:
337382
help=argparse.SUPPRESS,
338383
# "Use the specified parameter file",
339384
)
340-
parser.add_argument(
341-
"--gguf-path",
342-
type=Path,
343-
default=None,
344-
help=argparse.SUPPRESS,
345-
# "Use the specified GGUF model file",
346-
)
347385
parser.add_argument(
348386
"--tokenizer-path",
349387
type=Path,

0 commit comments

Comments
 (0)