Skip to content

Commit a17e753

Browse files
author
innerNULL
committed
feat: Upgraded ckpt ad data dumping logics
1 parent 75fdabf commit a17e753

File tree

1 file changed

+20
-7
lines changed

1 file changed

+20
-7
lines changed

train.py

Lines changed: 20 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
import ray.train
1414
import torch.nn.functional as F
1515
import numpy as np
16-
from typing import Dict
16+
from typing import Dict, Optional
1717
from transformers import AutoTokenizer
1818
from torch import device
1919
from torch import LongTensor, FloatTensor, IntTensor
@@ -63,10 +63,12 @@ def ckpt_dump(
6363
global_step_id: int,
6464
batch_id: int,
6565
epoch_id: int,
66-
configs: Dict
66+
configs: Dict,
67+
ckpt_dir: Optional[str]=None
6768
) -> None:
68-
version: str = "step{}-batch{}-epoch{}".format(global_step_id, batch_id, epoch_id)
69-
ckpt_dir: str = os.path.join(configs["ckpt_dir"], version)
69+
if ckpt_dir is None:
70+
version: str = "step{}-batch{}-epoch{}".format(global_step_id, batch_id, epoch_id)
71+
ckpt_dir = os.path.join(configs["ckpt_dir"], version)
7072
os.system("mkdir -p %s" % ckpt_dir)
7173
print("Saving ckpt to %s" % ckpt_dir)
7274
if configs["training_engine"] == "torch":
@@ -189,10 +191,21 @@ def train_func(configs: Dict) -> None:
189191
global_step_id += 1
190192
scheduler.step()
191193

194+
eval_results: Dict[str, float] = evaluation(
195+
model,
196+
dev_dataloader,
197+
device,
198+
configs["single_worker_eval_size"],
199+
label_confidence_threshold=configs["eval"]["label_confidence_threshold"],
200+
verbose=True
201+
)
202+
print({k: v for k, v in eval_results.items() if k != "verbose"})
192203
final_ckpt_dir: str = os.path.join(configs["ckpt_dir"], "final")
193-
os.system("mkdir -p %s" % final_ckpt_dir)
194-
print("Saving final ckpt to %s" % final_ckpt_dir)
195-
ckpt_dump(model, global_step_id, batch_id, epoch_id, configs)
204+
ckpt_dump(model, global_step_id, batch_id, epoch_id, configs, final_ckpt_dir)
205+
os.system("cp %s %s" % (data_dict_path, final_ckpt_dir))
206+
open(os.path.join(final_ckpt_dir, "eval_results.json"), "w").write(
207+
json.dumps(eval_results)
208+
)
196209

197210

198211
if __name__ == "__main__":

0 commit comments

Comments
 (0)