Skip to content

Commit 8532799

Browse files
authored
Merge pull request #235 from ASUS-AICS/no-retrain
update to really no retrain
2 parents 96148e5 + 42fd85d commit 8532799

File tree

1 file changed

+15
-4
lines changed

1 file changed

+15
-4
lines changed

search_params.py

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ def train_libmultilabel_tune(config, datasets, classes, word_dict):
4141
classes=classes,
4242
word_dict=word_dict,
4343
search_params=True,
44-
save_checkpoints=False)
44+
save_checkpoints=True)
4545
trainer.train()
4646

4747

@@ -213,9 +213,20 @@ def retrain_best_model(exp_name, best_config, best_log_dir, merge_train_val):
213213

214214
data = load_static_data(
215215
best_config, merge_train_val=best_config.merge_train_val)
216-
logging.info(f'Re-training with best config: \n{best_config}')
217-
trainer = TorchTrainer(config=best_config, **data)
218-
trainer.train()
216+
217+
if merge_train_val:
218+
logging.info(f'Re-training with best config: \n{best_config}')
219+
trainer = TorchTrainer(config=best_config, **data)
220+
trainer.train()
221+
else:
222+
# If not merging training and validation data, load the best result from tune experiments.
223+
logging.info(f'Loading best model with best config: \n{best_config}')
224+
trainer = TorchTrainer(config=best_config, **data)
225+
best_checkpoint = os.path.join(best_log_dir, 'best_model.ckpt')
226+
last_checkpoint = os.path.join(best_log_dir, 'last.ckpt')
227+
trainer._setup_model(checkpoint_path=best_checkpoint)
228+
os.popen(f"cp {best_checkpoint} {os.path.join(checkpoint_dir, 'best_model.ckpt')}")
229+
os.popen(f"cp {last_checkpoint} {os.path.join(checkpoint_dir, 'last.ckpt')}")
219230

220231
if 'test' in data['datasets']:
221232
test_results = trainer.test()

0 commit comments

Comments
 (0)