Commit 1b7779b

Merge pull request #343 from ntumlgroup/update_pkgs
Update packages: torch, torchmetrics, lightning
2 parents: 10819d0 + 5947a68

File tree: 11 files changed (+205, -191 lines)


README.md (2 additions & 2 deletions)

@@ -10,8 +10,8 @@ This is an on-going development so many improvements are still being made. Comme
 
 ## Environments
 - Python: 3.8+
-- CUDA: 11.6 (if training neural networks by GPU)
-- Pytorch 1.13.1
+- CUDA: 11.8, 12.1 (if training neural networks by GPU)
+- Pytorch: 2.0.1+
 
 If you have a different version of CUDA, follow the installation instructions for PyTorch LTS at their [website](https://pytorch.org/).
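The bumped requirements are easy to sanity-check from Python. A minimal sketch, assuming torch is already installed (this snippet is illustrative and not part of the commit):

    import sys
    import torch

    # Quick check against the updated environment requirements.
    assert sys.version_info >= (3, 8), "Python 3.8+ required"
    print(torch.__version__)   # expect 2.0.1 or later
    print(torch.version.cuda)  # expect 11.8 or 12.1 for GPU training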
docs/cli/ov_data_format.rst (2 additions & 2 deletions)

@@ -25,8 +25,8 @@ Install LibMultiLabel from Source
 * Environment
 
   * Python: 3.8+
-  * CUDA: 11.6 (if training neural networks by GPU)
-  * Pytorch 1.13.1
+  * CUDA: 11.8, 12.1 (if training neural networks by GPU)
+  * Pytorch 2.0.1+
 
 It is optional but highly recommended to
 create a virtual environment.

libmultilabel/linear/metrics.py (37 additions & 32 deletions)
@@ -17,7 +17,7 @@ def _argsort_top_k(preds: np.ndarray, top_k: int) -> np.ndarray:
     return np.take_along_axis(top_k_idx, argsort_top_k, axis=-1)
 
 
-def _DCG_argsort(argsort_preds: np.ndarray, target: np.ndarray, top_k: int) -> np.ndarray:
+def _dcg_argsort(argsort_preds: np.ndarray, target: np.ndarray, top_k: int) -> np.ndarray:
     """Computes DCG@k with a sorted preds array and a target array."""
     top_k_idx = argsort_preds[:, -top_k:][:, ::-1]
     gains = np.take_along_axis(target, top_k_idx, axis=-1)
@@ -27,7 +27,7 @@ def _DCG_argsort(argsort_preds: np.ndarray, target: np.ndarray, top_k: int) -> np.ndarray:
     return dcgs
 
 
-def _IDCG(target: np.ndarray, top_k: int) -> np.ndarray:
+def _idcg(target: np.ndarray, top_k: int) -> np.ndarray:
     """Computes IDCG@k for a 0/1 target array. A 0/1 target is a special case that
     doesn't require sorting. If IDCG is computed with DCG,
     then target will need to be sorted, which incurs a large overhead.
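The renamed helpers are easiest to follow with a small numeric example. The sketch below reconstructs the DCG@k / IDCG@k logic from the docstrings above in plain numpy; it is illustrative, not the library's exact code:

    import numpy as np

    target = np.array([[0, 1, 0, 1, 1]])            # one instance, 0/1 relevance
    scores = np.array([[0.1, 0.9, 0.3, 0.8, 0.2]])  # predicted scores
    k = 3

    # DCG@k: gains of the top-k predictions, discounted by 1/log2(rank + 1).
    top_k_idx = np.argsort(scores, axis=-1)[:, -k:][:, ::-1]  # best first
    gains = np.take_along_axis(target, top_k_idx, axis=-1)
    discount = 1.0 / np.log2(np.arange(2, k + 2))
    dcg = (gains * discount).sum(axis=-1)

    # IDCG@k for a 0/1 target: the ideal ranking puts all relevant labels
    # first, so it is a cumulative sum of discounts cut off at min(#relevant, k).
    n_relevant = np.minimum(target.sum(axis=-1), k)
    idcg = np.cumsum(discount)[n_relevant - 1]  # assumes >= 1 relevant label
    print(dcg / idcg)                           # nDCG@3, about 0.765 here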
@@ -43,10 +43,13 @@ def _IDCG(target: np.ndarray, top_k: int) -> np.ndarray:
     return cum_discount[idx]
 
 
-class NDCG:
-    def __init__(self, top_k: int):
-        """Compute the normalized DCG@k (nDCG@k).
+class NDCGAtK:
+    """Compute the normalized DCG@k (nDCG@k). Please refer to the `implementation document`
+    (https://www.csie.ntu.edu.tw/~cjlin/papers/libmultilabel/libmultilabel_implementation.pdf) for details.
+    """
 
+    def __init__(self, top_k: int):
+        """
         Args:
             top_k: Consider only the top k elements for each query.
         """
@@ -61,8 +64,8 @@ def update(self, preds: np.ndarray, target: np.ndarray):
         return self.update_argsort(_argsort_top_k(preds, self.top_k), target)
 
     def update_argsort(self, argsort_preds: np.ndarray, target: np.ndarray):
-        dcg = _DCG_argsort(argsort_preds, target, self.top_k)
-        idcg = _IDCG(target, self.top_k)
+        dcg = _dcg_argsort(argsort_preds, target, self.top_k)
+        idcg = _idcg(target, self.top_k)
         ndcg_score = dcg / idcg
         # by convention, ndcg is 0 for zero label instances
         self.score += np.nan_to_num(ndcg_score, nan=0.0).sum()
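The zero-label convention in the last line above is worth a concrete illustration: an instance with no relevant labels has IDCG@k = 0, the division yields NaN, and np.nan_to_num maps it to the conventional score of 0. A small standalone sketch (not from the diff):

    import numpy as np

    dcg = np.array([0.0, 1.63])
    idcg = np.array([0.0, 2.13])  # first instance has no relevant labels
    with np.errstate(invalid="ignore"):
        ndcg = np.nan_to_num(dcg / idcg, nan=0.0)
    print(ndcg)  # [0.    0.765...]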
@@ -76,10 +79,13 @@ def reset(self):
         self.num_sample = 0
 
 
-class RPrecision:
-    def __init__(self, top_k: int):
-        """Compute the R-Precision@K.
+class RPrecisionAtK:
+    """Compute the R-Precision@K. Please refer to the `implementation document`
+    (https://www.csie.ntu.edu.tw/~cjlin/papers/libmultilabel/libmultilabel_implementation.pdf) for details.
+    """
 
+    def __init__(self, top_k: int):
+        """
         Args:
             top_k: Consider only the top k elements for each query.
         """
@@ -108,18 +114,16 @@ def reset(self):
         self.num_sample = 0
 
 
-class Precision:
-    def __init__(self, num_classes: int, average: str, top_k: int):
-        """Compute the Precision@K.
+class PrecisionAtK:
+    """Compute the Precision@K. Please refer to the `implementation document`
+    (https://www.csie.ntu.edu.tw/~cjlin/papers/libmultilabel/libmultilabel_implementation.pdf) for details.
+    """
 
+    def __init__(self, top_k: int):
+        """
         Args:
-            num_classes: The number of classes.
-            average: Define the reduction that is applied over labels. Currently only "samples" is supported.
             top_k: Consider only the top k elements for each query.
         """
-        if average != "samples":
-            raise ValueError("unsupported average")
-
         _check_top_k(top_k)
 
         self.top_k = top_k
@@ -144,18 +148,16 @@ def reset(self):
         self.num_sample = 0
 
 
-class Recall:
-    def __init__(self, num_classes: int, average: str, top_k: int):
-        """Compute the Recall@K.
+class RecallAtK:
+    """Compute the Recall@K. Please refer to the `implementation document`
+    (https://www.csie.ntu.edu.tw/~cjlin/papers/libmultilabel/libmultilabel_implementation.pdf) for details.
+    """
 
+    def __init__(self, top_k: int):
+        """
         Args:
-            num_classes: The number of classes.
-            average: Define the reduction that is applied over labels. Currently only "samples" is supported.
             top_k: Consider only the top k elements for each query.
         """
-        if average != "samples":
-            raise ValueError("unsupported average")
-
         _check_top_k(top_k)
 
         self.top_k = top_k
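With num_classes and average gone, P@K and R@K reduce to pure top-k counting. A hedged numpy sketch of what the two classes measure (the actual class internals may differ):

    import numpy as np

    target = np.array([[0, 1, 0, 1, 1]])
    scores = np.array([[0.1, 0.9, 0.3, 0.8, 0.2]])
    k = 3

    top_k_idx = np.argsort(scores, axis=-1)[:, -k:]
    hits = np.take_along_axis(target, top_k_idx, axis=-1).sum(axis=-1)
    print(hits / k)                    # Precision@3: relevant share of the top k
    print(hits / target.sum(axis=-1))  # Recall@3: retrieved share of relevant labels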
@@ -182,9 +184,12 @@ def reset(self):
 
 
 class F1:
-    def __init__(self, num_classes: int, average: str, multiclass=False):
-        """Compute the F1 score.
+    """Compute the F1 score. Please refer to the `implementation document`
+    (https://www.csie.ntu.edu.tw/~cjlin/papers/libmultilabel/libmultilabel_implementation.pdf) for details.
+    """
 
+    def __init__(self, num_classes: int, average: str, multiclass=False):
+        """
         Args:
             num_classes: The number of labels.
             average: Define the reduction that is applied over labels. Should be one of "macro", "micro",
@@ -296,13 +301,13 @@ def get_metrics(monitor_metrics: list[str], num_classes: int, multiclass: bool =
     metrics = {}
     for metric in monitor_metrics:
         if re.match("P@\d+", metric):
-            metrics[metric] = Precision(num_classes, average="samples", top_k=int(metric[2:]))
+            metrics[metric] = PrecisionAtK(top_k=int(metric[2:]))
         elif re.match("R@\d+", metric):
-            metrics[metric] = Recall(num_classes, average="samples", top_k=int(metric[2:]))
+            metrics[metric] = RecallAtK(top_k=int(metric[2:]))
         elif re.match("RP@\d+", metric):
-            metrics[metric] = RPrecision(top_k=int(metric[3:]))
+            metrics[metric] = RPrecisionAtK(top_k=int(metric[3:]))
         elif re.match("NDCG@\d+", metric):
-            metrics[metric] = NDCG(top_k=int(metric[5:]))
+            metrics[metric] = NDCGAtK(top_k=int(metric[5:]))
         elif metric in {"Another-Macro-F1", "Macro-F1", "Micro-F1"}:
             metrics[metric] = F1(num_classes, average=metric[:-3].lower(), multiclass=multiclass)
         else:
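Callers that construct these metrics directly must switch to the new names and drop the num_classes/average arguments. A hedged usage sketch, assuming the update(preds, target) interface shown above for NDCGAtK applies to the other @K classes as well:

    import numpy as np
    from libmultilabel.linear.metrics import PrecisionAtK, NDCGAtK

    # Before this commit: Precision(num_classes, average="samples", top_k=5)
    p_at_5 = PrecisionAtK(top_k=5)
    ndcg_at_10 = NDCGAtK(top_k=10)

    preds = np.random.rand(4, 100)                       # 4 instances, 100 labels
    target = (np.random.rand(4, 100) > 0.9).astype(int)  # sparse 0/1 relevance
    for m in (p_at_5, ndcg_at_10):
        m.update(preds, target)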
