
Commit d0667be

Merge pull request #135 from ASUS-AICS/labels

Fix Macro-F1 & Use train labels by default

2 parents ee53873 + aa04e8c commit d0667be

File tree

10 files changed (+70, -20 lines)

example_config/MIMIC-50/bigru.yml

Lines changed: 1 addition & 0 deletions
@@ -3,6 +3,7 @@ data_dir: data/MIMIC-50
 data_name: MIMIC-50
 min_vocab_freq: 3
 max_seq_length: 2500
+include_test_labels: true
 
 # train
 seed: 1337

example_config/MIMIC-50/caml.yml

Lines changed: 2 additions & 0 deletions
@@ -4,6 +4,8 @@ data_dir: data/MIMIC-50
 data_name: MIMIC-50
 min_vocab_freq: 3
 max_seq_length: 2500
+# We follow caml-mimic, which includes labels from both the training and test datasets.
+include_test_labels: true
 
 # train
 seed: 1337

example_config/MIMIC-50/caml_tune.yml

Lines changed: 1 addition & 0 deletions
@@ -3,6 +3,7 @@ data_dir: data/MIMIC-50
 data_name: MIMIC-50
 min_vocab_freq: 3
 max_seq_length: 2500
+include_test_labels: true
 
 # train
 seed: 1337

example_config/MIMIC-50/cnn.yml

Lines changed: 1 addition & 0 deletions
@@ -3,6 +3,7 @@ data_dir: data/MIMIC-50
 data_name: MIMIC-50
 min_vocab_freq: 3
 max_seq_length: 2500
+include_test_labels: true
 
 # train
 seed: 1337
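All four MIMIC-50 configs opt back into the old label behavior. A minimal sketch of how the new flag surfaces at load time, assuming the configs are parsed with PyYAML (the exact loading path in main.py may differ):

import yaml

# Hypothetical check: read one of the updated configs and inspect the new flag.
with open('example_config/MIMIC-50/cnn.yml') as f:
    config = yaml.safe_load(f)

print(config['include_test_labels'])  # True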

libmultilabel/metrics.py

Lines changed: 48 additions & 10 deletions
@@ -1,7 +1,7 @@
 import re
 
-import torch
 import numpy as np
+import torch
 import torchmetrics.classification
 from torchmetrics import Metric, MetricCollection, Precision, Recall, RetrievalNormalizedDCG
 from torchmetrics.utilities.data import select_topk
@@ -40,12 +40,54 @@ def compute(self):
         return self.score / self.num_sample
 
 
+class MacroF1(Metric):
+    """The macro-f1 score computes the average f1 scores of all labels in the dataset.
+
+    Args:
+        num_classes (int): The number of classes.
+        metric_threshold (float): Threshold to monitor for metrics.
+        another_macro_f1 (bool, optional): Whether to compute the 'Another-Macro-F1' score.
+            The 'Another-Macro-F1' is the f1 value of macro-precision and macro-recall.
+            This variant of macro-f1 is less preferred but is used in some works.
+            Please refer to Opitz et al. 2019 [https://arxiv.org/pdf/1911.03347.pdf].
+            Defaults to False.
+    """
+    def __init__(
+        self,
+        num_classes,
+        metric_threshold,
+        another_macro_f1=False
+    ):
+        super().__init__()
+        self.metric_threshold = metric_threshold
+        self.another_macro_f1 = another_macro_f1
+        self.add_state("preds_sum", default=torch.zeros(num_classes, dtype=torch.double))
+        self.add_state("target_sum", default=torch.zeros(num_classes, dtype=torch.double))
+        self.add_state("tp_sum", default=torch.zeros(num_classes, dtype=torch.double))
+
+    def update(self, preds, target):
+        assert preds.shape == target.shape
+        preds = torch.where(preds > self.metric_threshold, 1, 0)
+        self.preds_sum = torch.add(self.preds_sum, preds.sum(dim=0))
+        self.target_sum = torch.add(self.target_sum, target.sum(dim=0))
+        self.tp_sum = torch.add(self.tp_sum, (preds & target).sum(dim=0))
+
+    def compute(self):
+        if self.another_macro_f1:
+            macro_prec = torch.mean(torch.nan_to_num(self.tp_sum / self.preds_sum, posinf=0.))
+            macro_recall = torch.mean(torch.nan_to_num(self.tp_sum / self.target_sum, posinf=0.))
+            return 2 * (macro_prec * macro_recall) / (macro_prec + macro_recall + 1e-10)
+        else:
+            label_f1 = 2 * self.tp_sum / (self.preds_sum + self.target_sum + 1e-10)
+            return torch.mean(label_f1)
+
+
 def get_metrics(metric_threshold, monitor_metrics, num_classes):
     """Map monitor metrics to the corresponding classes defined in `torchmetrics.Metric`
     (https://torchmetrics.readthedocs.io/en/latest/references/modules.html).
 
     Args:
-        metric_threshold (float): Thresholds to monitor for metrics.
+        metric_threshold (float): Threshold to monitor for metrics.
         monitor_metrics (list): Metrics to monitor while validating.
         num_classes (int): Total number of classes.
@@ -86,15 +128,11 @@ def get_metrics(metric_threshold, monitor_metrics, num_classes):
         elif metric_abbr == 'nDCG':
             metrics[metric] = RetrievalNormalizedDCG(k=top_k)
         elif metric == 'Another-Macro-F1':
-            # The f1 value of macro_precision and macro_recall. This variant of
-            # macro_f1 is less preferred but is used in some works. Please
-            # refer to Opitz et al. 2019 [https://arxiv.org/pdf/1911.03347.pdf]
-            macro_prec = Precision(num_classes, metric_threshold, average='macro')
-            macro_recall = Recall(num_classes, metric_threshold, average='macro')
-            metrics[metric] = 2 * (macro_prec * macro_recall) / \
-                (macro_prec + macro_recall + 1e-10)
+            metrics[metric] = MacroF1(num_classes, metric_threshold, another_macro_f1=True)
+        elif metric == 'Macro-F1':
+            metrics[metric] = MacroF1(num_classes, metric_threshold)
         elif match_metric:
-            average_type = match_metric.group(1).lower()  # Micro or Macro
+            average_type = match_metric.group(1).lower()  # Micro
             metric_type = match_metric.group(2)  # Precision, Recall, or F1
             metrics[metric] = getattr(torchmetrics.classification, metric_type)(
                 num_classes, metric_threshold, average=average_type)
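The new class makes the two macro-averaged scores explicit. Below is a minimal sketch of the difference on a toy batch, assuming the post-commit module is importable as libmultilabel.metrics; the expected values follow from the formulas above.

import torch

from libmultilabel.metrics import MacroF1  # class added by this commit

preds = torch.tensor([[0.9, 0.2],
                      [0.8, 0.7]])   # two samples, two labels
target = torch.tensor([[1, 0],
                       [0, 1]])

# Standard Macro-F1: per-label F1 scores (2/3 and 1.0), then averaged.
macro_f1 = MacroF1(num_classes=2, metric_threshold=0.5)
macro_f1.update(preds, target)
print(macro_f1.compute())  # ~0.8333

# 'Another-Macro-F1': F1 of macro-precision (0.75) and macro-recall (1.0).
another = MacroF1(num_classes=2, metric_threshold=0.5, another_macro_f1=True)
another.update(preds, target)
print(another.compute())  # ~0.8571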

libmultilabel/nn/data_utils.py

Lines changed: 12 additions & 7 deletions
@@ -227,27 +227,32 @@ def load_or_build_text_dict(
     return vocabs
 
 
-def load_or_build_label(datasets, label_file=None, silent=False):
+def load_or_build_label(datasets, label_file=None, include_test_labels=False):
     """Generate label set either by the given datasets or a predefined label file.
 
     Args:
-        datasets (dict): A dictionary of datasets. Each dataset contains list of instances with index, label, and tokenized text.
+        datasets (dict): A dictionary of datasets. Each dataset contains list of instances
+            with index, label, and tokenized text.
         label_file (str, optional): Path to a file holding all labels.
-        silent (bool, optional): Disable print. Defaults to False.
+        include_test_labels (bool, optional): Whether to include labels in the test dataset.
+            Defaults to False.
 
     Returns:
         list: A list of labels sorted in alphabetical order.
     """
     if label_file:
-        logging.info('Load labels from {label_file}')
+        logging.info(f'Load labels from {label_file}.')
         with open(label_file, 'r') as fp:
             classes = sorted([s.strip() for s in fp.readlines()])
     else:
         classes = set()
-        for dataset in datasets.values():
-            for d in tqdm(dataset, disable=silent):
-                classes.update(d['label'])
+        for split, data in datasets.items():
+            if split == 'test' and not include_test_labels:
+                continue
+            for instance in data:
+                classes.update(instance['label'])
         classes = sorted(classes)
+    logging.info(f'Read {len(classes)} labels.')
     return classes
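The behavioral change is easiest to see on a toy datasets dict. A minimal sketch, assuming the updated function is importable as below and that instances only need a 'label' field (as the loop above suggests):

from libmultilabel.nn.data_utils import load_or_build_label

datasets = {
    'train': [{'label': ['A', 'B']}],
    'test':  [{'label': ['B', 'C']}],   # 'C' never appears in training data
}

# New default: the label space is built from the train (and val) splits only.
print(load_or_build_label(datasets))                            # ['A', 'B']

# include_test_labels=True restores the previous behaviour, which the
# MIMIC-50 configs opt into to stay comparable with caml-mimic.
print(load_or_build_label(datasets, include_test_labels=True))  # ['A', 'B', 'C']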

libmultilabel/nn/model.py

Lines changed: 1 addition & 1 deletion
@@ -19,7 +19,7 @@ class MultiLabelModel(pl.LightningModule):
         optimizer (str, optional): Optimizer name (i.e., sgd, adam, or adamw). Defaults to 'adam'.
         momentum (float, optional): Momentum factor for SGD only. Defaults to 0.9.
         weight_decay (int, optional): Weight decay factor. Defaults to 0.
-        metric_threshold (float, optional): Thresholds to monitor for metrics. Defaults to 0.5.
+        metric_threshold (float, optional): Threshold to monitor for metrics. Defaults to 0.5.
         monitor_metrics (list, optional): Metrics to monitor while validating. Defaults to None.
         log_path (str): Path to a directory holding the log files and models.
         silent (bool, optional): Enable silent mode. Defaults to False.

libmultilabel/nn/nn_utils.py

Lines changed: 1 addition & 1 deletion
@@ -62,7 +62,7 @@ def init_model(model_name,
         optimizer (str, optional): Optimizer name (i.e., sgd, adam, or adamw). Defaults to 'adam'.
         momentum (float, optional): Momentum factor for SGD only. Defaults to 0.9.
         weight_decay (int, optional): Weight decay factor. Defaults to 0.
-        metric_threshold (float, optional): Thresholds to monitor for metrics. Defaults to 0.5.
+        metric_threshold (float, optional): Threshold to monitor for metrics. Defaults to 0.5.
         monitor_metrics (list, optional): Metrics to monitor while validating. Defaults to None.
         silent (bool, optional): Enable silent mode. Defaults to False.
         save_k_predictions (int, optional): Save top k predictions on test set. Defaults to 0.

main.py

Lines changed: 2 additions & 0 deletions
@@ -47,6 +47,8 @@ def get_config():
                         help='Whether to shuffle training data before each epoch (default: %(default)s)')
     parser.add_argument('--merge_train_val', action='store_true',
                         help='Whether to merge the training and validation data. (default: %(default)s)')
+    parser.add_argument('--include_test_labels', action='store_true',
+                        help='Whether to include labels in the test dataset. (default: %(default)s)')
 
     # train
     parser.add_argument('--seed', type=int,
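Since store_true flags default to False, simply omitting --include_test_labels gives the new train-labels-only behavior. A stand-alone sketch of the flag's semantics (a hypothetical minimal parser, not the full get_config()):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--include_test_labels', action='store_true',
                    help='Whether to include labels in the test dataset. (default: %(default)s)')

print(parser.parse_args([]).include_test_labels)                         # False
print(parser.parse_args(['--include_test_labels']).include_test_labels)  # True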

torch_trainer.py

Lines changed: 1 addition & 1 deletion
@@ -111,7 +111,7 @@ def _setup_model(
         )
         if not classes:
             classes = data_utils.load_or_build_label(
-                self.datasets, self.config.label_file, self.config.silent)
+                self.datasets, self.config.label_file, self.config.include_test_labels)
 
         if self.config.val_metric not in self.config.monitor_metrics:
             logging.warn(
