@@ -17,7 +17,7 @@ def _argsort_top_k(preds: np.ndarray, top_k: int) -> np.ndarray:
17
17
return np .take_along_axis (top_k_idx , argsort_top_k , axis = - 1 )
18
18
19
19
20
- def _DCG_argsort (argsort_preds : np .ndarray , target : np .ndarray , top_k : int ) -> np .ndarray :
20
+ def _dcg_argsort (argsort_preds : np .ndarray , target : np .ndarray , top_k : int ) -> np .ndarray :
21
21
"""Computes DCG@k with a sorted preds array and a target array."""
22
22
top_k_idx = argsort_preds [:, - top_k :][:, ::- 1 ]
23
23
gains = np .take_along_axis (target , top_k_idx , axis = - 1 )
@@ -27,7 +27,7 @@ def _DCG_argsort(argsort_preds: np.ndarray, target: np.ndarray, top_k: int) -> n
27
27
return dcgs
28
28
29
29
30
- def _IDCG (target : np .ndarray , top_k : int ) -> np .ndarray :
30
+ def _idcg (target : np .ndarray , top_k : int ) -> np .ndarray :
31
31
"""Computes IDCG@k for a 0/1 target array. A 0/1 target is a special case that
32
32
doesn't require sorting. If IDCG is computed with DCG,
33
33
then target will need to be sorted, which incurs a large overhead.
@@ -43,10 +43,13 @@ def _IDCG(target: np.ndarray, top_k: int) -> np.ndarray:
43
43
return cum_discount [idx ]
44
44
45
45
46
- class NDCG :
47
- def __init__ (self , top_k : int ):
48
- """Compute the normalized DCG@k (nDCG@k).
46
+ class NDCGAtK :
47
+ """Compute the normalized DCG@k (nDCG@k). Please refer to the `implementation document`
48
+ (https://www.csie.ntu.edu.tw/~cjlin/papers/libmultilabel/libmultilabel_implementation.pdf) for details.
49
+ """
49
50
51
+ def __init__ (self , top_k : int ):
52
+ """
50
53
Args:
51
54
top_k: Consider only the top k elements for each query.
52
55
"""
@@ -61,8 +64,8 @@ def update(self, preds: np.ndarray, target: np.ndarray):
61
64
return self .update_argsort (_argsort_top_k (preds , self .top_k ), target )
62
65
63
66
def update_argsort (self , argsort_preds : np .ndarray , target : np .ndarray ):
64
- dcg = _DCG_argsort (argsort_preds , target , self .top_k )
65
- idcg = _IDCG (target , self .top_k )
67
+ dcg = _dcg_argsort (argsort_preds , target , self .top_k )
68
+ idcg = _idcg (target , self .top_k )
66
69
ndcg_score = dcg / idcg
67
70
# by convention, ndcg is 0 for zero label instances
68
71
self .score += np .nan_to_num (ndcg_score , nan = 0.0 ).sum ()
@@ -76,10 +79,13 @@ def reset(self):
76
79
self .num_sample = 0
77
80
78
81
79
- class RPrecision :
80
- def __init__ (self , top_k : int ):
81
- """Compute the R-Precision@K.
82
+ class RPrecisionAtK :
83
+ """Compute the R-Precision@K. Please refer to the `implementation document`
84
+ (https://www.csie.ntu.edu.tw/~cjlin/papers/libmultilabel/libmultilabel_implementation.pdf) for details.
85
+ """
82
86
87
+ def __init__ (self , top_k : int ):
88
+ """
83
89
Args:
84
90
top_k: Consider only the top k elements for each query.
85
91
"""
@@ -108,18 +114,16 @@ def reset(self):
108
114
self .num_sample = 0
109
115
110
116
111
- class Precision :
112
- def __init__ (self , num_classes : int , average : str , top_k : int ):
113
- """Compute the Precision@K.
117
+ class PrecisionAtK :
118
+ """Compute the Precision@K. Please refer to the `implementation document`
119
+ (https://www.csie.ntu.edu.tw/~cjlin/papers/libmultilabel/libmultilabel_implementation.pdf) for details.
120
+ """
114
121
122
+ def __init__ (self , top_k : int ):
123
+ """
115
124
Args:
116
- num_classes: The number of classes.
117
- average: Define the reduction that is applied over labels. Currently only "samples" is supported.
118
125
top_k: Consider only the top k elements for each query.
119
126
"""
120
- if average != "samples" :
121
- raise ValueError ("unsupported average" )
122
-
123
127
_check_top_k (top_k )
124
128
125
129
self .top_k = top_k
@@ -144,18 +148,16 @@ def reset(self):
144
148
self .num_sample = 0
145
149
146
150
147
- class Recall :
148
- def __init__ (self , num_classes : int , average : str , top_k : int ):
149
- """Compute the Recall@K.
151
+ class RecallAtK :
152
+ """Compute the Recall@K. Please refer to the `implementation document`
153
+ (https://www.csie.ntu.edu.tw/~cjlin/papers/libmultilabel/libmultilabel_implementation.pdf) for details.
154
+ """
150
155
156
+ def __init__ (self , top_k : int ):
157
+ """
151
158
Args:
152
- num_classes: The number of classes.
153
- average: Define the reduction that is applied over labels. Currently only "samples" is supported.
154
159
top_k: Consider only the top k elements for each query.
155
160
"""
156
- if average != "samples" :
157
- raise ValueError ("unsupported average" )
158
-
159
161
_check_top_k (top_k )
160
162
161
163
self .top_k = top_k
@@ -182,9 +184,12 @@ def reset(self):
182
184
183
185
184
186
class F1 :
185
- def __init__ (self , num_classes : int , average : str , multiclass = False ):
186
- """Compute the F1 score.
187
+ """Compute the F1 score. Please refer to the `implementation document`
188
+ (https://www.csie.ntu.edu.tw/~cjlin/papers/libmultilabel/libmultilabel_implementation.pdf) for details.
189
+ """
187
190
191
+ def __init__ (self , num_classes : int , average : str , multiclass = False ):
192
+ """
188
193
Args:
189
194
num_classes: The number of labels.
190
195
average: Define the reduction that is applied over labels. Should be one of "macro", "micro",
@@ -296,13 +301,13 @@ def get_metrics(monitor_metrics: list[str], num_classes: int, multiclass: bool =
296
301
metrics = {}
297
302
for metric in monitor_metrics :
298
303
if re .match ("P@\d+" , metric ):
299
- metrics [metric ] = Precision ( num_classes , average = "samples" , top_k = int (metric [2 :]))
304
+ metrics [metric ] = PrecisionAtK ( top_k = int (metric [2 :]))
300
305
elif re .match ("R@\d+" , metric ):
301
- metrics [metric ] = Recall ( num_classes , average = "samples" , top_k = int (metric [2 :]))
306
+ metrics [metric ] = RecallAtK ( top_k = int (metric [2 :]))
302
307
elif re .match ("RP@\d+" , metric ):
303
- metrics [metric ] = RPrecision (top_k = int (metric [3 :]))
308
+ metrics [metric ] = RPrecisionAtK (top_k = int (metric [3 :]))
304
309
elif re .match ("NDCG@\d+" , metric ):
305
- metrics [metric ] = NDCG (top_k = int (metric [5 :]))
310
+ metrics [metric ] = NDCGAtK (top_k = int (metric [5 :]))
306
311
elif metric in {"Another-Macro-F1" , "Macro-F1" , "Micro-F1" }:
307
312
metrics [metric ] = F1 (num_classes , average = metric [:- 3 ].lower (), multiclass = multiclass )
308
313
else :
0 commit comments