9
9
import tempfile
10
10
import json
11
11
import torch
12
+ import random
12
13
import ray .train
13
14
import torch .nn .functional as F
15
+ import numpy as np
14
16
from typing import Dict
15
17
from transformers import AutoTokenizer
16
18
from torch import device
28
30
from src .plm_icd_multi_label_classifier .metrics import metrics_func , topk_metrics_func
29
31
30
32
33
+ THRESHOLD : float = 0.6
34
+
35
+
31
36
def init_with_ckpt (net : PlmMultiLabelEncoder , ckpt_root_path : str , engine : str ) -> None :
32
37
ckpts : List [str ] = [x for x in os .listdir (ckpt_root_path ) if x != "bak" ]
33
38
if len (ckpts ) == 0 :
@@ -94,7 +99,7 @@ def eval(
94
99
95
100
logits : FloatTensor = torch .concat (all_logits , dim = 0 )
96
101
output_label_probs : FloatTensor = torch .sigmoid (logits )
97
- output_one_hot : FloatTensor = (output_label_probs > 0.5 ).float ()
102
+ output_one_hot : FloatTensor = (output_label_probs > THRESHOLD ).float ()
98
103
label_one_hot : FloatTensor = torch .concat (all_label_one_hots , dim = 0 )
99
104
# Loss
100
105
loss : float = float (
@@ -104,9 +109,9 @@ def eval(
104
109
prob50_metrics : Dict [str , float ] = metrics_func (
105
110
output_one_hot .int (), label_one_hot .int ()
106
111
)
107
- top5_metrics : Dict [str , float ] = topk_metrics_func (logits , label_one_hot , top_k = 5 )
108
- top8_metrics : Dict [str , float ] = topk_metrics_func (logits , label_one_hot , top_k = 8 )
109
- top15_metrics : Dict [str , float ] = topk_metrics_func (logits , label_one_hot , top_k = 15 )
112
+ # top5_metrics: Dict[str, float] = topk_metrics_func(logits, label_one_hot, top_k=5)
113
+ # top8_metrics: Dict[str, float] = topk_metrics_func(logits, label_one_hot, top_k=8)
114
+ # top15_metrics: Dict[str, float] = topk_metrics_func(logits, label_one_hot, top_k=15)
110
115
111
116
out = {
112
117
"loss" : round (loss , 8 ),
@@ -116,29 +121,31 @@ def eval(
116
121
"macro_recall" : round (prob50_metrics ["macro_recall" ], 4 ),
117
122
"macro_precision" : round (prob50_metrics ["macro_precision" ], 4 ),
118
123
"macro_f1" : round (prob50_metrics ["macro_f1" ], 4 ),
119
- "micro_recall@5" : round (top5_metrics ["micro_recall@5" ], 4 ),
120
- "micro_precision@5" : round (top5_metrics ["micro_precision@5" ], 4 ),
121
- "micro_f1@5" : round (top5_metrics ["micro_f1@5" ], 4 ),
122
- "macro_recall@5" : round (top5_metrics ["macro_recall@5" ], 4 ),
123
- "macro_precision@5" : round (top5_metrics ["macro_precision@5" ], 4 ),
124
- "macro_f1@5" : round (top5_metrics ["macro_f1@5" ], 4 ),
125
- "micro_recall@8" : round (top8_metrics ["micro_recall@8" ], 4 ),
126
- "micro_precision@8" : round (top8_metrics ["micro_precision@8" ], 4 ),
127
- "micro_f1@8" : round (top8_metrics ["micro_f1@8" ], 4 ),
128
- "macro_recall@8" : round (top8_metrics ["macro_recall@8" ], 4 ),
129
- "macro_precision@8" : round (top8_metrics ["macro_precision@8" ], 4 ),
130
- "macro_f1@8" : round (top8_metrics ["macro_f1@8" ], 4 ),
131
- "micro_recall@15" : round (top15_metrics ["micro_recall@15" ], 4 ),
132
- "micro_precision@15" : round (top15_metrics ["micro_precision@15" ], 4 ),
133
- "micro_f1@15" : round (top15_metrics ["micro_f1@15" ], 4 ),
134
- "macro_recall@15" : round (top15_metrics ["macro_recall@15" ], 4 ),
135
- "macro_precision@15" : round (top15_metrics ["macro_precision@15" ], 4 ),
136
- "macro_f1@15" : round (top15_metrics ["macro_f1@15" ], 4 )
124
+ # "micro_recall@5": round(top5_metrics["micro_recall@5"], 4),
125
+ # "micro_precision@5": round(top5_metrics["micro_precision@5"], 4),
126
+ # "micro_f1@5": round(top5_metrics["micro_f1@5"], 4),
127
+ # "macro_recall@5": round(top5_metrics["macro_recall@5"], 4),
128
+ # "macro_precision@5": round(top5_metrics["macro_precision@5"], 4),
129
+ # "macro_f1@5": round(top5_metrics["macro_f1@5"], 4),
130
+ # "micro_recall@8": round(top8_metrics["micro_recall@8"], 4),
131
+ # "micro_precision@8": round(top8_metrics["micro_precision@8"], 4),
132
+ # "micro_f1@8": round(top8_metrics["micro_f1@8"], 4),
133
+ # "macro_recall@8": round(top8_metrics["macro_recall@8"], 4),
134
+ # "macro_precision@8": round(top8_metrics["macro_precision@8"], 4),
135
+ # "macro_f1@8": round(top8_metrics["macro_f1@8"], 4),
136
+ # "micro_recall@15": round(top15_metrics["micro_recall@15"], 4),
137
+ # "micro_precision@15": round(top15_metrics["micro_precision@15"], 4),
138
+ # "micro_f1@15": round(top15_metrics["micro_f1@15"], 4),
139
+ # "macro_recall@15": round(top15_metrics["macro_recall@15"], 4),
140
+ # "macro_precision@15": round(top15_metrics["macro_precision@15"], 4),
141
+ # "macro_f1@15": round(top15_metrics["macro_f1@15"], 4)
137
142
}
138
143
return out
139
144
140
145
def train_func (configs : Dict ) -> None :
141
146
torch .manual_seed (configs ["random_seed" ])
147
+ random .seed (configs ["random_seed" ])
148
+ np .random .seed (configs ["random_seed" ])
142
149
143
150
device : device = None
144
151
if configs ["training_engine" ] == "torch" :
@@ -233,8 +240,10 @@ def train_func(configs: Dict) -> None:
233
240
elif configs ["training_engine" ] == "ray" :
234
241
if ray .train .get_context ().get_world_rank () == 0 :
235
242
open (os .path .join (ckpt_dir , "train.json" ), "w" ).write (json .dumps (configs ))
236
- torch .save (model .module .state_dict (), os .path .join (ckpt_dir , "model.pt" ))
237
-
243
+ try :
244
+ torch .save (model .module .state_dict (), os .path .join (ckpt_dir , "model.pt" ))
245
+ except :
246
+ torch .save (model .state_dict (), os .path .join (ckpt_dir , "model.pt" ))
238
247
global_step_id += 1
239
248
240
249
final_ckpt_dir : str = os .path .join (configs ["ckpt_dir" ], "final" )
@@ -246,7 +255,10 @@ def train_func(configs: Dict) -> None:
246
255
elif configs ["training_engine" ] == "ray" :
247
256
if ray .train .get_context ().get_world_rank () == 0 :
248
257
open (os .path .join (final_ckpt_dir , "train.json" ), "w" ).write (json .dumps (configs ))
249
- torch .save (model .module .state_dict (), os .path .join (final_ckpt_dir , "model.pt" ))
258
+ try :
259
+ torch .save (model .module .state_dict (), os .path .join (final_ckpt_dir , "model.pt" ))
260
+ except :
261
+ torch .save (model .state_dict (), os .path .join (ckpt_dir , "model.pt" ))
250
262
251
263
252
264
if __name__ == "__main__" :
@@ -256,7 +268,9 @@ def train_func(configs: Dict) -> None:
256
268
if os .path .exists (train_conf ["hf_lm" ]):
257
269
train_conf ["hf_lm" ] = os .path .abspath (train_conf ["hf_lm" ])
258
270
print ("Training config:\n {}" .format (train_conf ))
259
-
271
+
272
+ os .environ ["HF_TOKEN" ] = train_conf ["hf_key" ]
273
+
260
274
os .system ("mkdir -p %s" % train_conf ["ckpt_dir" ])
261
275
262
276
if train_conf ["training_engine" ] == "torch" :
0 commit comments