@@ -55,6 +55,29 @@ def init_with_ckpt(net: PlmMultiLabelEncoder, ckpt_root_path: str, engine: str)
55
55
)
56
56
57
57
58
def ckpt_dump(
    model: torch.nn.Module,
    global_step_id: int,
    batch_id: int,
    epoch_id: int,
    configs: Dict,
    version=None,
) -> None:
    """Persist a training checkpoint (model weights + training config).

    Creates ``<configs["ckpt_dir"]>/<version>/`` and writes two files:

    * ``train.json`` -- the training configuration, JSON-serialized.
    * ``model.pt``   -- the model ``state_dict``.

    Args:
        model: Model to checkpoint. Under the "ray" engine this may be a
            DDP-wrapped module; the inner ``model.module`` weights are saved
            when present.
        global_step_id: Global step counter (used in the default dir name).
        batch_id: Batch index within the epoch (used in the default dir name).
        epoch_id: Current epoch index (used in the default dir name).
        configs: Training configuration; must contain ``"ckpt_dir"`` and
            ``"training_engine"`` ("torch" or "ray").
        version: Optional explicit checkpoint sub-directory name (e.g.
            ``"final"``). Defaults to ``"step{}-batch{}-epoch{}"`` built from
            the id arguments, preserving the original naming scheme.
    """
    if version is None:
        version = "step{}-batch{}-epoch{}".format(global_step_id, batch_id, epoch_id)
    ckpt_dir: str = os.path.join(configs["ckpt_dir"], version)
    # os.makedirs is portable and avoids shelling out; the previous
    # os.system("mkdir -p %s") let the shell interpret the path.
    os.makedirs(ckpt_dir, exist_ok=True)
    print("Saving ckpt to %s" % ckpt_dir)
    if configs["training_engine"] == "torch":
        with open(os.path.join(ckpt_dir, "train.json"), "w") as f:
            f.write(json.dumps(configs))
        torch.save(model.state_dict(), os.path.join(ckpt_dir, "model.pt"))
    elif configs["training_engine"] == "ray":
        # Only rank 0 writes, so N workers don't clobber the same files.
        if ray.train.get_context().get_world_rank() == 0:
            with open(os.path.join(ckpt_dir, "train.json"), "w") as f:
                f.write(json.dumps(configs))
            try:
                # DDP wraps the model; unwrap to save portable weights.
                torch.save(model.module.state_dict(), os.path.join(ckpt_dir, "model.pt"))
            except AttributeError:
                # Model was not DDP-wrapped (e.g. single-worker run);
                # narrow except replaces the original bare `except:`.
                torch.save(model.state_dict(), os.path.join(ckpt_dir, "model.pt"))
58
81
def loss_fn (
59
82
logits : FloatTensor , label_one_hot : FloatTensor , bias : float = 1e-10
60
83
) -> FloatTensor :
@@ -229,36 +252,14 @@ def train_func(configs: Dict) -> None:
229
252
elif configs ["training_engine" ] == "ray" :
230
253
ray .train .report (metrics = eval_metrics )
231
254
232
- if global_step_id % configs ["dump_period" ] == 0 :
233
- version : str = "step{}-batch{}-epoch{}" .format (global_step_id , batch_id , epoch_id )
234
- ckpt_dir : str = os .path .join (configs ["ckpt_dir" ], version )
235
- os .system ("mkdir -p %s" % ckpt_dir )
236
- print ("Saving ckpt to %s" % ckpt_dir )
237
- if configs ["training_engine" ] == "torch" :
238
- open (os .path .join (ckpt_dir , "train.json" ), "w" ).write (json .dumps (configs ))
239
- torch .save (model .state_dict (), os .path .join (ckpt_dir , "model.pt" ))
240
- elif configs ["training_engine" ] == "ray" :
241
- if ray .train .get_context ().get_world_rank () == 0 :
242
- open (os .path .join (ckpt_dir , "train.json" ), "w" ).write (json .dumps (configs ))
243
- try :
244
- torch .save (model .module .state_dict (), os .path .join (ckpt_dir , "model.pt" ))
245
- except :
246
- torch .save (model .state_dict (), os .path .join (ckpt_dir , "model.pt" ))
255
+ if global_step_id % configs ["dump_period" ] == 0 :
256
+ ckpt_dump (model , global_step_id , batch_id , epoch_id , configs )
247
257
global_step_id += 1
248
258
249
259
final_ckpt_dir : str = os .path .join (configs ["ckpt_dir" ], "final" )
250
260
os .system ("mkdir -p %s" % final_ckpt_dir )
251
261
print ("Saving final ckpt to %s" % final_ckpt_dir )
252
- if configs ["training_engine" ] == "torch" :
253
- open (os .path .join (final_ckpt_dir , "train.json" ), "w" ).write (json .dumps (configs ))
254
- torch .save (model .state_dict (), os .path .join (final_ckpt_dir , "model.pt" ))
255
- elif configs ["training_engine" ] == "ray" :
256
- if ray .train .get_context ().get_world_rank () == 0 :
257
- open (os .path .join (final_ckpt_dir , "train.json" ), "w" ).write (json .dumps (configs ))
258
- try :
259
- torch .save (model .module .state_dict (), os .path .join (final_ckpt_dir , "model.pt" ))
260
- except :
261
- torch .save (model .state_dict (), os .path .join (ckpt_dir , "model.pt" ))
262
+ ckpt_dump (model , global_step_id , batch_id , epoch_id , configs )
262
263
263
264
264
265
if __name__ == "__main__" :
0 commit comments