@@ -41,7 +41,7 @@ def train_libmultilabel_tune(config, datasets, classes, word_dict):
41
41
classes = classes ,
42
42
word_dict = word_dict ,
43
43
search_params = True ,
44
- save_checkpoints = False )
44
+ save_checkpoints = True )
45
45
trainer .train ()
46
46
47
47
@@ -213,9 +213,20 @@ def retrain_best_model(exp_name, best_config, best_log_dir, merge_train_val):
213
213
214
214
data = load_static_data (
215
215
best_config , merge_train_val = best_config .merge_train_val )
216
- logging .info (f'Re-training with best config: \n { best_config } ' )
217
- trainer = TorchTrainer (config = best_config , ** data )
218
- trainer .train ()
216
+
217
+ if merge_train_val :
218
+ logging .info (f'Re-training with best config: \n { best_config } ' )
219
+ trainer = TorchTrainer (config = best_config , ** data )
220
+ trainer .train ()
221
+ else :
222
+ # If not merging training and validation data, load the best result from tune experiments.
223
+ logging .info (f'Loading best model with best config: \n { best_config } ' )
224
+ trainer = TorchTrainer (config = best_config , ** data )
225
+ best_checkpoint = os .path .join (best_log_dir , 'best_model.ckpt' )
226
+ last_checkpoint = os .path .join (best_log_dir , 'last.ckpt' )
227
+ trainer ._setup_model (checkpoint_path = best_checkpoint )
228
+ os .popen (f"cp { best_checkpoint } { os .path .join (checkpoint_dir , 'best_model.ckpt' )} " )
229
+ os .popen (f"cp { last_checkpoint } { os .path .join (checkpoint_dir , 'last.ckpt' )} " )
219
230
220
231
if 'test' in data ['datasets' ]:
221
232
test_results = trainer .test ()
0 commit comments