@@ -109,9 +109,9 @@ def eval(
109
109
prob50_metrics : Dict [str , float ] = metrics_func (
110
110
output_one_hot .int (), label_one_hot .int ()
111
111
)
112
- top5_metrics : Dict [str , float ] = topk_metrics_func (logits , label_one_hot , top_k = 5 )
113
- top8_metrics : Dict [str , float ] = topk_metrics_func (logits , label_one_hot , top_k = 8 )
114
- top15_metrics : Dict [str , float ] = topk_metrics_func (logits , label_one_hot , top_k = 15 )
112
+ # top5_metrics: Dict[str, float] = topk_metrics_func(logits, label_one_hot, top_k=5)
113
+ # top8_metrics: Dict[str, float] = topk_metrics_func(logits, label_one_hot, top_k=8)
114
+ # top15_metrics: Dict[str, float] = topk_metrics_func(logits, label_one_hot, top_k=15)
115
115
116
116
out = {
117
117
"loss" : round (loss , 8 ),
@@ -244,7 +244,6 @@ def train_func(configs: Dict) -> None:
244
244
torch .save (model .module .state_dict (), os .path .join (ckpt_dir , "model.pt" ))
245
245
except :
246
246
torch .save (model .state_dict (), os .path .join (ckpt_dir , "model.pt" ))
247
-
248
247
global_step_id += 1
249
248
250
249
final_ckpt_dir : str = os .path .join (configs ["ckpt_dir" ], "final" )
0 commit comments