1
1
import os
2
+ import logging
2
3
from math import ceil
3
4
4
5
import numpy as np
5
6
6
7
import libmultilabel .linear as linear
8
+ from libmultilabel .utils import dump_log , argsort_top_k
7
9
8
10
9
11
def linear_test (config , model , datasets ):
@@ -13,12 +15,24 @@ def linear_test(config, model, datasets):
13
15
datasets ['test' ]['y' ].shape [1 ]
14
16
)
15
17
num_instance = datasets ['test' ]['x' ].shape [0 ]
18
+
19
+ k = config .save_k_predictions
20
+ top_k_idx = np .zeros ((num_instance , k ), dtype = 'i' )
21
+ top_k_scores = np .zeros ((num_instance , k ), dtype = 'd' )
22
+
16
23
for i in range (ceil (num_instance / config .eval_batch_size )):
17
24
slice = np .s_ [i * config .eval_batch_size :(i + 1 )* config .eval_batch_size ]
18
25
preds = linear .predict_values (model , datasets ['test' ]['x' ][slice ])
19
26
target = datasets ['test' ]['y' ][slice ].toarray ()
20
27
metrics .update (preds , target )
21
- print (linear .tabulate_metrics (metrics .compute (), 'test' ))
28
+
29
+ if k > 0 :
30
+ top_k_idx [slice ] = argsort_top_k (preds , k , axis = 1 )
31
+ top_k_scores [slice ] = np .take_along_axis (
32
+ preds , top_k_idx [slice ], axis = 1 )
33
+
34
+ metric_dict = metrics .compute ()
35
+ return (metric_dict , top_k_idx , top_k_scores )
22
36
23
37
24
38
def linear_train (datasets , config ):
@@ -46,5 +60,18 @@ def linear_run(config):
46
60
linear .save_pipeline (config .checkpoint_dir , preprocessor , model )
47
61
48
62
if os .path .exists (config .test_path ):
49
- linear_test (config , model , datasets )
50
- # TODO: dump logs?
63
+ metric_dict , top_k_idx , top_k_scores = linear_test (
64
+ config , model , datasets )
65
+
66
+ dump_log (config = config , metrics = metric_dict ,
67
+ split = 'test' , log_path = config .log_path )
68
+ print (linear .tabulate_metrics (metric_dict , 'test' ))
69
+
70
+ if config .save_k_predictions > 0 :
71
+ classes = preprocessor .binarizer .classes_
72
+ with open (config .predict_out_path , 'w' ) as fp :
73
+ for idx , score in zip (top_k_idx , top_k_scores ):
74
+ out_str = ' ' .join ([f'{ classes [i ]} :{ s :.4} ' for i , s in zip (
75
+ idx , score )])
76
+ fp .write (out_str + '\n ' )
77
+ logging .info (f'Saved predictions to: { config .predict_out_path } ' )
0 commit comments