|
13 | 13 | import ray.train
|
14 | 14 | import torch.nn.functional as F
|
15 | 15 | import numpy as np
|
16 |
| -from typing import Dict |
| 16 | +from typing import Dict, Optional |
17 | 17 | from transformers import AutoTokenizer
|
18 | 18 | from torch import device
|
19 | 19 | from torch import LongTensor, FloatTensor, IntTensor
|
@@ -63,10 +63,12 @@ def ckpt_dump(
|
63 | 63 | global_step_id: int,
|
64 | 64 | batch_id: int,
|
65 | 65 | epoch_id: int,
|
66 |
| - configs: Dict |
| 66 | + configs: Dict, |
| 67 | + ckpt_dir: Optional[str]=None |
67 | 68 | ) -> None:
|
68 |
| - version: str = "step{}-batch{}-epoch{}".format(global_step_id, batch_id, epoch_id) |
69 |
| - ckpt_dir: str = os.path.join(configs["ckpt_dir"], version) |
| 69 | + if ckpt_dir is None: |
| 70 | + version: str = "step{}-batch{}-epoch{}".format(global_step_id, batch_id, epoch_id) |
| 71 | + ckpt_dir = os.path.join(configs["ckpt_dir"], version) |
70 | 72 | os.system("mkdir -p %s" % ckpt_dir)
|
71 | 73 | print("Saving ckpt to %s" % ckpt_dir)
|
72 | 74 | if configs["training_engine"] == "torch":
|
@@ -189,10 +191,21 @@ def train_func(configs: Dict) -> None:
|
189 | 191 | global_step_id += 1
|
190 | 192 | scheduler.step()
|
191 | 193 |
|
| 194 | + eval_results: Dict[str, float] = evaluation( |
| 195 | + model, |
| 196 | + dev_dataloader, |
| 197 | + device, |
| 198 | + configs["single_worker_eval_size"], |
| 199 | + label_confidence_threshold=configs["eval"]["label_confidence_threshold"], |
| 200 | + verbose=True |
| 201 | + ) |
| 202 | + print({k: v for k, v in eval_results.items() if k != "verbose"}) |
192 | 203 | final_ckpt_dir: str = os.path.join(configs["ckpt_dir"], "final")
|
193 |
| - os.system("mkdir -p %s" % final_ckpt_dir) |
194 |
| - print("Saving final ckpt to %s" % final_ckpt_dir) |
195 |
| - ckpt_dump(model, global_step_id, batch_id, epoch_id, configs) |
| 204 | + ckpt_dump(model, global_step_id, batch_id, epoch_id, configs, final_ckpt_dir) |
| 205 | + os.system("cp %s %s" % (data_dict_path, final_ckpt_dir)) |
| 206 | + open(os.path.join(final_ckpt_dir, "eval_results.json"), "w").write( |
| 207 | + json.dumps(eval_results) |
| 208 | + ) |
196 | 209 |
|
197 | 210 |
|
198 | 211 | if __name__ == "__main__":
|
|
0 commit comments