Skip to content

Commit 723786a

Browse files
authored
Switch export_model to use AutoModel and AutoTokenizer (#1260)
* refactor export_model to use AutoModel and AutoTokenizer * use AutoConfig, AutoTokenizer, and AutoModel instead of jiant model_type * Switch to hf_pretrained_model_name_or_path. Remove unused tokenizer_path. Update notebooks with AutoClass changes.
1 parent 6e6c2e3 commit 723786a

21 files changed

+150
-249
lines changed

examples/notebooks/jiant_Basic_Example.ipynb

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -158,8 +158,8 @@
158158
},
159159
"outputs": [],
160160
"source": [
161-
"export_model.lookup_and_export_model(\n",
162-
" model_type=\"roberta-base\",\n",
161+
"export_model.export_model(\n",
162+
" hf_pretrained_model_name_or_path=\"roberta-base\",\n",
163163
" output_base_path=\"./models/roberta-base\",\n",
164164
")"
165165
]
@@ -191,8 +191,7 @@
191191
"\n",
192192
"tokenize_and_cache.main(tokenize_and_cache.RunConfiguration(\n",
193193
" task_config_path=f\"./tasks/configs/{task_name}_config.json\",\n",
194-
" model_type=\"roberta-base\",\n",
195-
" model_tokenizer_path=\"./models/roberta-base/tokenizer\",\n",
194+
" hf_pretrained_model_name_or_path=\"roberta-base\",\n",
196195
" output_dir=f\"./cache/{task_name}\",\n",
197196
" phases=[\"train\", \"val\"],\n",
198197
"))"
@@ -309,10 +308,9 @@
309308
"run_args = main_runscript.RunConfiguration(\n",
310309
" jiant_task_container_config_path=\"./run_configs/mrpc_run_config.json\",\n",
311310
" output_dir=\"./runs/mrpc\",\n",
312-
" model_type=\"roberta-base\",\n",
311+
" hf_pretrained_model_name_or_path=\"roberta-base\",\n",
313312
" model_path=\"./models/roberta-base/model/roberta-base.p\",\n",
314313
" model_config_path=\"./models/roberta-base/model/roberta-base.json\",\n",
315-
" model_tokenizer_path=\"./models/roberta-base/tokenizer\",\n",
316314
" learning_rate=1e-5,\n",
317315
" eval_every_steps=500,\n",
318316
" do_train=True,\n",

examples/notebooks/jiant_EdgeProbing_Example.ipynb

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2702,8 +2702,8 @@
27022702
"outputId": "c21bdffa-0ff3-49f3-e734-af5530ab4711"
27032703
},
27042704
"source": [
2705-
"export_model.lookup_and_export_model(\n",
2706-
" model_type=\"roberta-base\",\n",
2705+
"export_model.export_model(\n",
2706+
" hf_pretrained_model_name_or_path=\"roberta-base\",\n",
27072707
" output_base_path=\"./models/roberta-base\",\n",
27082708
")"
27092709
],
@@ -2856,8 +2856,7 @@
28562856
"\n",
28572857
"tokenize_and_cache.main(tokenize_and_cache.RunConfiguration(\n",
28582858
" task_config_path=f\"./tasks/configs/{task_name}_config.json\",\n",
2859-
" model_type=\"roberta-base\",\n",
2860-
" model_tokenizer_path=\"./models/roberta-base/tokenizer\",\n",
2859+
" hf_pretrained_model_name_or_path=\"roberta-base\",\n",
28612860
" output_dir=f\"./cache/{task_name}\",\n",
28622861
" phases=[\"train\", \"val\"],\n",
28632862
"))"
@@ -3147,10 +3146,9 @@
31473146
"run_args = main_runscript.RunConfiguration(\n",
31483147
" jiant_task_container_config_path=\"./run_configs/semeval_run_config.json\",\n",
31493148
" output_dir=\"./runs/semeval\",\n",
3150-
" model_type=\"roberta-base\",\n",
3149+
" hf_pretrained_model_name_or_path=\"roberta-base\",\n",
31513150
" model_path=\"./models/roberta-base/model/roberta-base.p\",\n",
31523151
" model_config_path=\"./models/roberta-base/model/roberta-base.json\",\n",
3153-
" model_tokenizer_path=\"./models/roberta-base/tokenizer\",\n",
31543152
" learning_rate=1e-5,\n",
31553153
" eval_every_steps=500,\n",
31563154
" do_train=True,\n",
@@ -3170,7 +3168,6 @@
31703168
" model_type: roberta-base\n",
31713169
" model_path: ./models/roberta-base/model/roberta-base.p\n",
31723170
" model_config_path: ./models/roberta-base/model/roberta-base.json\n",
3173-
" model_tokenizer_path: ./models/roberta-base/tokenizer\n",
31743171
" model_load_mode: from_transformers\n",
31753172
" do_train: True\n",
31763173
" do_val: True\n",
@@ -3204,7 +3201,6 @@
32043201
" \"model_type\": \"roberta-base\",\n",
32053202
" \"model_path\": \"./models/roberta-base/model/roberta-base.p\",\n",
32063203
" \"model_config_path\": \"./models/roberta-base/model/roberta-base.json\",\n",
3207-
" \"model_tokenizer_path\": \"./models/roberta-base/tokenizer\",\n",
32083204
" \"model_load_mode\": \"from_transformers\",\n",
32093205
" \"do_train\": true,\n",
32103206
" \"do_val\": true,\n",

examples/notebooks/jiant_MNLI_Diagnostic_Example.ipynb

Lines changed: 6 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -140,8 +140,8 @@
140140
},
141141
"outputs": [],
142142
"source": [
143-
"export_model.lookup_and_export_model(\n",
144-
" model_type=\"roberta-base\",\n",
143+
"export_model.export_model(\n",
144+
" hf_pretrained_model_name_or_path=\"roberta-base\",\n",
145145
" output_base_path=\"./models/roberta-base\",\n",
146146
")"
147147
]
@@ -169,24 +169,21 @@
169169
"# Tokenize and cache each task\n",
170170
"tokenize_and_cache.main(tokenize_and_cache.RunConfiguration(\n",
171171
" task_config_path=f\"./tasks/configs/mnli_config.json\",\n",
172-
" model_type=\"roberta-base\",\n",
173-
" model_tokenizer_path=\"./models/roberta-base/tokenizer\",\n",
172+
" hf_pretrained_model_name_or_path=\"roberta-base\",\n",
174173
" output_dir=f\"./cache/mnli\",\n",
175174
" phases=[\"train\", \"val\"],\n",
176175
"))\n",
177176
"\n",
178177
"tokenize_and_cache.main(tokenize_and_cache.RunConfiguration(\n",
179178
" task_config_path=f\"./tasks/configs/mnli_mismatched_config.json\",\n",
180-
" model_type=\"roberta-base\",\n",
181-
" model_tokenizer_path=\"./models/roberta-base/tokenizer\",\n",
179+
" hf_pretrained_model_name_or_path=\"roberta-base\",\n",
182180
" output_dir=f\"./cache/mnli_mismatched\",\n",
183181
" phases=[\"val\"],\n",
184182
"))\n",
185183
"\n",
186184
"tokenize_and_cache.main(tokenize_and_cache.RunConfiguration(\n",
187185
" task_config_path=f\"./tasks/configs/glue_diagnostics_config.json\",\n",
188-
" model_type=\"roberta-base\",\n",
189-
" model_tokenizer_path=\"./models/roberta-base/tokenizer\",\n",
186+
" hf_pretrained_model_name_or_path=\"roberta-base\",\n",
190187
" output_dir=f\"./cache/glue_diagnostics\",\n",
191188
" phases=[\"test\"],\n",
192189
"))"
@@ -323,10 +320,9 @@
323320
"run_args = main_runscript.RunConfiguration(\n",
324321
" jiant_task_container_config_path=\"./run_configs/jiant_run_config.json\",\n",
325322
" output_dir=\"./runs/run1\",\n",
326-
" model_type=\"roberta-base\",\n",
323+
" hf_pretrained_model_name_or_path=\"roberta-base\",\n",
327324
" model_path=\"./models/roberta-base/model/roberta-base.p\",\n",
328325
" model_config_path=\"./models/roberta-base/model/roberta-base.json\",\n",
329-
" model_tokenizer_path=\"./models/roberta-base/tokenizer\",\n",
330326
" learning_rate=1e-5,\n",
331327
" eval_every_steps=500,\n",
332328
" do_train=True,\n",

examples/notebooks/jiant_Multi_Task_Example.ipynb

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -161,8 +161,8 @@
161161
},
162162
"outputs": [],
163163
"source": [
164-
"export_model.lookup_and_export_model(\n",
165-
" model_type=\"roberta-base\",\n",
164+
"export_model.export_model(\n",
165+
" hf_pretrained_model_name_or_path=\"roberta-base\",\n",
166166
" output_base_path=\"./models/roberta-base\",\n",
167167
")"
168168
]
@@ -193,8 +193,7 @@
193193
"for task_name in [\"rte\", \"stsb\", \"commonsenseqa\"]:\n",
194194
" tokenize_and_cache.main(tokenize_and_cache.RunConfiguration(\n",
195195
" task_config_path=f\"./tasks/configs/{task_name}_config.json\",\n",
196-
" model_type=\"roberta-base\",\n",
197-
" model_tokenizer_path=\"./models/roberta-base/tokenizer\",\n",
196+
" hf_pretrained_model_name_or_path=\"roberta-base\",\n",
198197
" output_dir=f\"./cache/{task_name}\",\n",
199198
" phases=[\"train\", \"val\"],\n",
200199
" ))"
@@ -342,10 +341,9 @@
342341
"run_args = main_runscript.RunConfiguration(\n",
343342
" jiant_task_container_config_path=\"./run_configs/jiant_run_config.json\",\n",
344343
" output_dir=\"./runs/run1\",\n",
345-
" model_type=\"roberta-base\",\n",
344+
" hf_pretrained_model_name_or_path=\"roberta-base\",\n",
346345
" model_path=\"./models/roberta-base/model/roberta-base.p\",\n",
347346
" model_config_path=\"./models/roberta-base/model/roberta-base.json\",\n",
348-
" model_tokenizer_path=\"./models/roberta-base/tokenizer\",\n",
349347
" learning_rate=1e-5,\n",
350348
" eval_every_steps=500,\n",
351349
" do_train=True,\n",

examples/notebooks/jiant_STILTs_Example.ipynb

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -163,8 +163,8 @@
163163
},
164164
"outputs": [],
165165
"source": [
166-
"export_model.lookup_and_export_model(\n",
167-
" model_type=\"roberta-base\",\n",
166+
"export_model.export_model(\n",
167+
" hf_pretrained_model_name_or_path=\"roberta-base\",\n",
168168
" output_base_path=\"./models/roberta-base\",\n",
169169
")"
170170
]
@@ -195,8 +195,7 @@
195195
"for task_name in [\"mnli\", \"rte\"]:\n",
196196
" tokenize_and_cache.main(tokenize_and_cache.RunConfiguration(\n",
197197
" task_config_path=f\"./tasks/configs/{task_name}_config.json\",\n",
198-
" model_type=\"roberta-base\",\n",
199-
" model_tokenizer_path=\"./models/roberta-base/tokenizer\",\n",
198+
" hf_pretrained_model_name_or_path=\"roberta-base\",\n",
200199
" output_dir=f\"./cache/{task_name}\",\n",
201200
" phases=[\"train\", \"val\"],\n",
202201
" ))"
@@ -367,10 +366,9 @@
367366
"run_args = main_runscript.RunConfiguration(\n",
368367
" jiant_task_container_config_path=\"./run_configs/mnli_run_config.json\",\n",
369368
" output_dir=\"./runs/mnli\",\n",
370-
" model_type=\"roberta-base\",\n",
369+
" hf_pretrained_model_name_or_path=\"roberta-base\",\n",
371370
" model_path=\"./models/roberta-base/model/roberta-base.p\",\n",
372371
" model_config_path=\"./models/roberta-base/model/roberta-base.json\",\n",
373-
" model_tokenizer_path=\"./models/roberta-base/tokenizer\",\n",
374372
" learning_rate=1e-5,\n",
375373
" eval_every_steps=500,\n",
376374
" do_train=True,\n",
@@ -404,11 +402,10 @@
404402
"run_args = main_runscript.RunConfiguration(\n",
405403
" jiant_task_container_config_path=\"./run_configs/rte_run_config.json\",\n",
406404
" output_dir=\"./runs/mnli___rte\",\n",
407-
" model_type=\"roberta-base\",\n",
405+
" hf_pretrained_model_name_or_path=\"roberta-base\",\n",
408406
" model_path=\"./runs/mnli/best_model.p\", # Loading the best model\n",
409407
" model_load_mode=\"partial\",\n",
410408
" model_config_path=\"./models/roberta-base/model/roberta-base.json\",\n",
411-
" model_tokenizer_path=\"./models/roberta-base/tokenizer\",\n",
412409
" learning_rate=1e-5,\n",
413410
" eval_every_steps=500,\n",
414411
" do_train=True,\n",

examples/notebooks/jiant_XNLI_Example.ipynb

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -164,8 +164,8 @@
164164
},
165165
"outputs": [],
166166
"source": [
167-
"export_model.lookup_and_export_model(\n",
168-
" model_type=\"xlm-roberta-base\",\n",
167+
"export_model.export_model(\n",
168+
" hf_pretrained_model_name_or_path=\"xlm-roberta-base\",\n",
169169
" output_base_path=\"./models/xlm-roberta-base\",\n",
170170
")"
171171
]
@@ -197,8 +197,7 @@
197197
"# Tokenize and cache MNLI\n",
198198
"tokenize_and_cache.main(tokenize_and_cache.RunConfiguration(\n",
199199
" task_config_path=f\"./tasks/configs/mnli_config.json\",\n",
200-
" model_type=\"xlm-roberta-base\",\n",
201-
" model_tokenizer_path=\"./models/xlm-roberta-base/tokenizer\",\n",
200+
" hf_pretrained_model_name_or_path=\"xlm-roberta-base\",\n",
202201
" output_dir=f\"./cache/mnli\",\n",
203202
" phases=[\"train\", \"val\"],\n",
204203
"))\n",
@@ -207,8 +206,7 @@
207206
"for lang in [\"de\", \"zh\"]:\n",
208207
" tokenize_and_cache.main(tokenize_and_cache.RunConfiguration(\n",
209208
" task_config_path=f\"./tasks/configs/xnli_{lang}_config.json\",\n",
210-
" model_type=\"xlm-roberta-base\",\n",
211-
" model_tokenizer_path=\"./models/xlm-roberta-base/tokenizer\",\n",
209+
" hf_pretrained_model_name_or_path=\"xlm-roberta-base\",\n",
212210
" output_dir=f\"./cache/xnli_{lang}\",\n",
213211
" phases=[\"val\"],\n",
214212
" ))"
@@ -384,10 +382,9 @@
384382
"run_args = main_runscript.RunConfiguration(\n",
385383
" jiant_task_container_config_path=\"./run_configs/jiant_run_config.json\",\n",
386384
" output_dir=\"./runs/run1\",\n",
387-
" model_type=\"xlm-roberta-base\",\n",
385+
" hf_pretrained_model_name_or_path=\"xlm-roberta-base\",\n",
388386
" model_path=\"./models/xlm-roberta-base/model/xlm-roberta-base.p\",\n",
389387
" model_config_path=\"./models/xlm-roberta-base/model/xlm-roberta-base.json\",\n",
390-
" model_tokenizer_path=\"./models/xlm-roberta-base/tokenizer\",\n",
391388
" learning_rate=1e-5,\n",
392389
" eval_every_steps=500,\n",
393390
" do_train=True,\n",

guides/tutorials/quick_start_main.md

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -26,17 +26,15 @@ python jiant/scripts/download_data/runscript.py \
2626
2. Next, we download our RoBERTa-base model
2727
```bash
2828
python jiant/proj/main/export_model.py \
29-
--model_type ${MODEL_TYPE} \
29+
--hf_pretrained_model_name_or_path ${MODEL_TYPE} \
3030
--output_base_path ${EXP_DIR}/models/${MODEL_TYPE}
3131
```
3232

3333
3. Next, we tokenize and cache the inputs and labels for our RTE task
3434
```bash
3535
python jiant/proj/main/tokenize_and_cache.py \
3636
--task_config_path ${EXP_DIR}/tasks/configs/${TASK}_config.json \
37-
--model_type ${MODEL_TYPE} \
38-
--model_tokenizer_path \
39-
${EXP_DIR}/models/${MODEL_TYPE}/tokenizer \
37+
--hf_pretrained_model_name_or_path ${MODEL_TYPE} \
4038
--output_dir ${EXP_DIR}/cache/${MODEL_TYPE}/${TASK} \
4139
--phases train,val \
4240
--max_seq_length 256 \

0 commit comments

Comments
 (0)