Skip to content

Commit bba88b7

Browse files
authored
Merge pull request #8 from innerNULL/dev
Re-Implement Ckpt Dumping
2 parents b385826 + e550ccb commit bba88b7

File tree

1 file changed

+26
-25
lines changed

1 file changed

+26
-25
lines changed

train.py

Lines changed: 26 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,29 @@ def init_with_ckpt(net: PlmMultiLabelEncoder, ckpt_root_path: str, engine: str)
5555
)
5656

5757

58+
def ckpt_dump(
    model: torch.nn.Module,
    global_step_id: int,
    batch_id: int,
    epoch_id: int,
    configs: Dict
) -> None:
    """Write a versioned checkpoint (config JSON + model weights) to disk.

    The checkpoint lands in
    ``<configs["ckpt_dir"]>/step{global_step_id}-batch{batch_id}-epoch{epoch_id}``
    and consists of ``train.json`` (the JSON-serialized ``configs``) and
    ``model.pt`` (the model's ``state_dict``).

    Args:
        model: Network whose weights are saved. If it exposes a ``module``
            attribute (as wrapper modules do), the wrapped module's
            ``state_dict`` is saved when running under the "ray" engine.
        global_step_id: Global step counter, used in the directory name.
        batch_id: Batch index, used in the directory name.
        epoch_id: Epoch index, used in the directory name.
        configs: Training configuration; must contain ``"ckpt_dir"`` and
            ``"training_engine"`` and be JSON-serializable.

    Raises:
        OSError: If the checkpoint directory or files cannot be created.
        KeyError: If a required key is missing from ``configs``.
    """
    version: str = "step{}-batch{}-epoch{}".format(global_step_id, batch_id, epoch_id)
    ckpt_dir: str = os.path.join(configs["ckpt_dir"], version)
    # Create the directory without a shell so paths containing spaces or shell
    # metacharacters are handled correctly and failures raise instead of
    # being silently ignored.
    os.makedirs(ckpt_dir, exist_ok=True)
    print("Saving ckpt to %s" % ckpt_dir)

    engine: str = configs["training_engine"]
    if engine == "torch":
        net = model
    elif engine == "ray":
        # NOTE(review): the original nesting was not recoverable; this assumes
        # both the config and the weights are written by rank 0 only, so that
        # workers do not race on the same files -- confirm against train.py.
        if ray.train.get_context().get_world_rank() != 0:
            return
        # Unwrap an explicit `.module` wrapper instead of relying on a bare
        # `except:`, which also hid genuine save errors.
        net = getattr(model, "module", model)
    else:
        # Unknown engines were a silent no-op originally; keep that behavior.
        return

    with open(os.path.join(ckpt_dir, "train.json"), "w") as config_file:
        config_file.write(json.dumps(configs))
    torch.save(net.state_dict(), os.path.join(ckpt_dir, "model.pt"))
79+
80+
5881
def loss_fn(
5982
logits: FloatTensor, label_one_hot: FloatTensor, bias: float=1e-10
6083
) -> FloatTensor:
@@ -229,36 +252,14 @@ def train_func(configs: Dict) -> None:
229252
elif configs["training_engine"] == "ray":
230253
ray.train.report(metrics=eval_metrics)
231254

232-
if global_step_id % configs["dump_period"] == 0:
233-
version: str = "step{}-batch{}-epoch{}".format(global_step_id, batch_id, epoch_id)
234-
ckpt_dir: str = os.path.join(configs["ckpt_dir"], version)
235-
os.system("mkdir -p %s" % ckpt_dir)
236-
print("Saving ckpt to %s" % ckpt_dir)
237-
if configs["training_engine"] == "torch":
238-
open(os.path.join(ckpt_dir, "train.json"), "w").write(json.dumps(configs))
239-
torch.save(model.state_dict(), os.path.join(ckpt_dir, "model.pt"))
240-
elif configs["training_engine"] == "ray":
241-
if ray.train.get_context().get_world_rank() == 0:
242-
open(os.path.join(ckpt_dir, "train.json"), "w").write(json.dumps(configs))
243-
try:
244-
torch.save(model.module.state_dict(), os.path.join(ckpt_dir, "model.pt"))
245-
except:
246-
torch.save(model.state_dict(), os.path.join(ckpt_dir, "model.pt"))
255+
if global_step_id % configs["dump_period"] == 0:
256+
ckpt_dump(model, global_step_id, batch_id, epoch_id, configs)
247257
global_step_id += 1
248258

249259
final_ckpt_dir: str = os.path.join(configs["ckpt_dir"], "final")
250260
os.system("mkdir -p %s" % final_ckpt_dir)
251261
print("Saving final ckpt to %s" % final_ckpt_dir)
252-
if configs["training_engine"] == "torch":
253-
open(os.path.join(final_ckpt_dir, "train.json"), "w").write(json.dumps(configs))
254-
torch.save(model.state_dict(), os.path.join(final_ckpt_dir, "model.pt"))
255-
elif configs["training_engine"] == "ray":
256-
if ray.train.get_context().get_world_rank() == 0:
257-
open(os.path.join(final_ckpt_dir, "train.json"), "w").write(json.dumps(configs))
258-
try:
259-
torch.save(model.module.state_dict(), os.path.join(final_ckpt_dir, "model.pt"))
260-
except:
261-
torch.save(model.state_dict(), os.path.join(ckpt_dir, "model.pt"))
262+
ckpt_dump(model, global_step_id, batch_id, epoch_id, configs)
262263

263264

264265
if __name__ == "__main__":

0 commit comments

Comments
 (0)