Skip to content

Commit b385826

Browse files
authored
Merge pull request #7 from innerNULL/dev
Fix Distributed Training and Metrics Calculation
2 parents b6350fe + d79a62d commit b385826

File tree

7 files changed

+97
-49
lines changed

7 files changed

+97
-49
lines changed

README.md

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -12,11 +12,9 @@ to make this as a general program for text multi-label classification task.
1212
## Usage
1313
### Python Env
1414
```sh
15-
python -m venv ./_venv --copies
16-
source ./_venv/bin/activate
17-
python -m pip install --upgrade pip
18-
python -m pip install -r requirements.txt
19-
# deactivate
15+
micromamba env create -f environment.yaml -p ./_pyenv --yes
16+
micromamba activate ./_pyenv
17+
pip install -r requirements.txt
2018
```
2119
### Run Tests
2220
```sh

environment.yaml

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
name: pyenv
2+
channels:
3+
- conda-forge
4+
- defaults
5+
dependencies:
6+
- python=3.11
7+
- setuptools<65
8+
- gfortran_linux-64>=11.2.0
9+
- openblas>=0.3.18
10+
- ninja>=1.10.2
11+
- openmpi>=5.0.8

requirements.txt

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,12 @@
1-
mypy==0.982
2-
torch==2.0.0
3-
#torchmetrics==1.2.0
4-
sentencepiece==0.1.96
5-
duckdb==0.9.1
6-
pandas==2.0.0
7-
numpy==1.24.2
8-
scikit-learn==1.2.2
9-
transformers==4.28.1
10-
onnx==1.14.0
11-
onnxruntime==1.15.0
12-
ray[train]==2.7.0
1+
mypy>=1.17.0,<=1.17.0
2+
torch>=2.0.0,<=2.7.1
3+
sentencepiece>=0.1.96
4+
duckdb>=0.9.1,<=1.3.2
5+
pandas>=2.0.0,<=2.3.1
6+
numpy>=1.24.2,<=2.3.1
7+
scikit-learn>=1.2.2,<=1.7.1
8+
transformers>=4.53.1,<=4.54.1
9+
onnx>=1.14.0,<=1.18.0
10+
onnxruntime>=1.15.0,<=1.22.1
11+
ray[train]>=2.7.0,<=2.48.0
12+
xformers==0.0.31.post1

src/plm_icd_multi_label_classifier/metrics.py

Lines changed: 26 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,20 +5,41 @@
55

66
import pdb
77
import torch
8-
from typing import Dict, Optional
8+
from typing import List, Dict, Optional
99
from torch import Tensor, IntTensor, FloatTensor
1010

1111

1212
def metrics_func(
1313
preds_one_hot: IntTensor, label_one_hot: IntTensor, bias: float=1e-6
1414
) -> float:
15+
pred_nonzero_idx: List[int] | int = torch.nonzero(preds_one_hot.sum(dim=0))\
16+
.squeeze()\
17+
.tolist()
18+
gt_nonzero_idx: List[int] | int = torch.nonzero(label_one_hot.sum(dim=0))\
19+
.squeeze()\
20+
.tolist()
21+
pred_nonzero_idx = (
22+
[pred_nonzero_idx] if isinstance(pred_nonzero_idx, int)
23+
else pred_nonzero_idx
24+
)
25+
gt_nonzero_idx = (
26+
[gt_nonzero_idx] if isinstance(gt_nonzero_idx, int)
27+
else gt_nonzero_idx
28+
)
29+
target_label_idx: List[str] = sorted(
30+
list(set(pred_nonzero_idx + gt_nonzero_idx))
31+
)
32+
preds_one_hot = preds_one_hot[:, target_label_idx]
33+
label_one_hot = label_one_hot[:, target_label_idx]
34+
1535
# 1 represents correctly predicted positive class
1636
pred_pos_correctness: IntTensor = preds_one_hot.mul(label_one_hot)
1737

18-
correct_pos_pred_cnt: IntTensor = pred_pos_correctness.sum(dim=1)
19-
sample_label_cnt: IntTensor = label_one_hot.sum(dim=1) + bias
20-
pred_label_cnt: IntTensor = preds_one_hot.sum(dim=1) + bias
21-
38+
# Label level statistics
39+
correct_pos_pred_cnt: IntTensor = pred_pos_correctness.sum(dim=0)
40+
sample_label_cnt: IntTensor = label_one_hot.sum(dim=0) + bias
41+
pred_label_cnt: IntTensor = preds_one_hot.sum(dim=0) + bias
42+
2243
macro_recall: FloatTensor = correct_pos_pred_cnt.div(sample_label_cnt).mean()
2344
macro_precision: FloatTensor = correct_pos_pred_cnt.div(pred_label_cnt).mean()
2445
macro_f1: FloatTensor = 2 * macro_recall * macro_precision / (macro_recall + macro_precision + bias)

src/plm_icd_multi_label_classifier/model.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,10 @@ def __init__(self,
2121
super().__init__()
2222

2323
# Language model
24-
self._lm: Module = AutoModel.from_pretrained(lm) if isinstance(lm, str) else lm
24+
self._lm: Module = (
25+
AutoModel.from_pretrained(lm, trust_remote_code=True) if isinstance(lm, str)
26+
else lm
27+
)
2528

2629
# Dimension info
2730
self._label_num: int = label_num

train.py

Lines changed: 40 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,10 @@
99
import tempfile
1010
import json
1111
import torch
12+
import random
1213
import ray.train
1314
import torch.nn.functional as F
15+
import numpy as np
1416
from typing import Dict
1517
from transformers import AutoTokenizer
1618
from torch import device
@@ -28,6 +30,9 @@
2830
from src.plm_icd_multi_label_classifier.metrics import metrics_func, topk_metrics_func
2931

3032

33+
THRESHOLD: float = 0.6
34+
35+
3136
def init_with_ckpt(net: PlmMultiLabelEncoder, ckpt_root_path: str, engine: str) -> None:
3237
ckpts: List[str] = [x for x in os.listdir(ckpt_root_path) if x != "bak"]
3338
if len(ckpts) == 0:
@@ -94,7 +99,7 @@ def eval(
9499

95100
logits: FloatTensor = torch.concat(all_logits, dim=0)
96101
output_label_probs: FloatTensor = torch.sigmoid(logits)
97-
output_one_hot: FloatTensor = (output_label_probs > 0.5).float()
102+
output_one_hot: FloatTensor = (output_label_probs > THRESHOLD).float()
98103
label_one_hot: FloatTensor = torch.concat(all_label_one_hots, dim=0)
99104
# Loss
100105
loss: float = float(
@@ -104,9 +109,9 @@ def eval(
104109
prob50_metrics: Dict[str, float] = metrics_func(
105110
output_one_hot.int(), label_one_hot.int()
106111
)
107-
top5_metrics: Dict[str, float] = topk_metrics_func(logits, label_one_hot, top_k=5)
108-
top8_metrics: Dict[str, float] = topk_metrics_func(logits, label_one_hot, top_k=8)
109-
top15_metrics: Dict[str, float] = topk_metrics_func(logits, label_one_hot, top_k=15)
112+
#top5_metrics: Dict[str, float] = topk_metrics_func(logits, label_one_hot, top_k=5)
113+
#top8_metrics: Dict[str, float] = topk_metrics_func(logits, label_one_hot, top_k=8)
114+
#top15_metrics: Dict[str, float] = topk_metrics_func(logits, label_one_hot, top_k=15)
110115

111116
out = {
112117
"loss": round(loss, 8),
@@ -116,29 +121,31 @@ def eval(
116121
"macro_recall": round(prob50_metrics["macro_recall"], 4),
117122
"macro_precision": round(prob50_metrics["macro_precision"], 4),
118123
"macro_f1": round(prob50_metrics["macro_f1"], 4),
119-
"micro_recall@5": round(top5_metrics["micro_recall@5"], 4),
120-
"micro_precision@5": round(top5_metrics["micro_precision@5"], 4),
121-
"micro_f1@5": round(top5_metrics["micro_f1@5"], 4),
122-
"macro_recall@5": round(top5_metrics["macro_recall@5"], 4),
123-
"macro_precision@5": round(top5_metrics["macro_precision@5"], 4),
124-
"macro_f1@5": round(top5_metrics["macro_f1@5"], 4),
125-
"micro_recall@8": round(top8_metrics["micro_recall@8"], 4),
126-
"micro_precision@8": round(top8_metrics["micro_precision@8"], 4),
127-
"micro_f1@8": round(top8_metrics["micro_f1@8"], 4),
128-
"macro_recall@8": round(top8_metrics["macro_recall@8"], 4),
129-
"macro_precision@8": round(top8_metrics["macro_precision@8"], 4),
130-
"macro_f1@8": round(top8_metrics["macro_f1@8"], 4),
131-
"micro_recall@15": round(top15_metrics["micro_recall@15"], 4),
132-
"micro_precision@15": round(top15_metrics["micro_precision@15"], 4),
133-
"micro_f1@15": round(top15_metrics["micro_f1@15"], 4),
134-
"macro_recall@15": round(top15_metrics["macro_recall@15"], 4),
135-
"macro_precision@15": round(top15_metrics["macro_precision@15"], 4),
136-
"macro_f1@15": round(top15_metrics["macro_f1@15"], 4)
124+
#"micro_recall@5": round(top5_metrics["micro_recall@5"], 4),
125+
#"micro_precision@5": round(top5_metrics["micro_precision@5"], 4),
126+
#"micro_f1@5": round(top5_metrics["micro_f1@5"], 4),
127+
#"macro_recall@5": round(top5_metrics["macro_recall@5"], 4),
128+
#"macro_precision@5": round(top5_metrics["macro_precision@5"], 4),
129+
#"macro_f1@5": round(top5_metrics["macro_f1@5"], 4),
130+
#"micro_recall@8": round(top8_metrics["micro_recall@8"], 4),
131+
#"micro_precision@8": round(top8_metrics["micro_precision@8"], 4),
132+
#"micro_f1@8": round(top8_metrics["micro_f1@8"], 4),
133+
#"macro_recall@8": round(top8_metrics["macro_recall@8"], 4),
134+
#"macro_precision@8": round(top8_metrics["macro_precision@8"], 4),
135+
#"macro_f1@8": round(top8_metrics["macro_f1@8"], 4),
136+
#"micro_recall@15": round(top15_metrics["micro_recall@15"], 4),
137+
#"micro_precision@15": round(top15_metrics["micro_precision@15"], 4),
138+
#"micro_f1@15": round(top15_metrics["micro_f1@15"], 4),
139+
#"macro_recall@15": round(top15_metrics["macro_recall@15"], 4),
140+
#"macro_precision@15": round(top15_metrics["macro_precision@15"], 4),
141+
#"macro_f1@15": round(top15_metrics["macro_f1@15"], 4)
137142
}
138143
return out
139144

140145
def train_func(configs: Dict) -> None:
141146
torch.manual_seed(configs["random_seed"])
147+
random.seed(configs["random_seed"])
148+
np.random.seed(configs["random_seed"])
142149

143150
device: device = None
144151
if configs["training_engine"] == "torch":
@@ -233,8 +240,10 @@ def train_func(configs: Dict) -> None:
233240
elif configs["training_engine"] == "ray":
234241
if ray.train.get_context().get_world_rank() == 0:
235242
open(os.path.join(ckpt_dir, "train.json"), "w").write(json.dumps(configs))
236-
torch.save(model.module.state_dict(), os.path.join(ckpt_dir, "model.pt"))
237-
243+
try:
244+
torch.save(model.module.state_dict(), os.path.join(ckpt_dir, "model.pt"))
245+
except:
246+
torch.save(model.state_dict(), os.path.join(ckpt_dir, "model.pt"))
238247
global_step_id += 1
239248

240249
final_ckpt_dir: str = os.path.join(configs["ckpt_dir"], "final")
@@ -246,7 +255,10 @@ def train_func(configs: Dict) -> None:
246255
elif configs["training_engine"] == "ray":
247256
if ray.train.get_context().get_world_rank() == 0:
248257
open(os.path.join(final_ckpt_dir, "train.json"), "w").write(json.dumps(configs))
249-
torch.save(model.module.state_dict(), os.path.join(final_ckpt_dir, "model.pt"))
258+
try:
259+
torch.save(model.module.state_dict(), os.path.join(final_ckpt_dir, "model.pt"))
260+
except:
261+
torch.save(model.state_dict(), os.path.join(ckpt_dir, "model.pt"))
250262

251263

252264
if __name__ == "__main__":
@@ -256,7 +268,9 @@ def train_func(configs: Dict) -> None:
256268
if os.path.exists(train_conf["hf_lm"]):
257269
train_conf["hf_lm"] = os.path.abspath(train_conf["hf_lm"])
258270
print("Training config:\n{}".format(train_conf))
259-
271+
272+
os.environ["HF_TOKEN"] = train_conf["hf_key"]
273+
260274
os.system("mkdir -p %s" % train_conf["ckpt_dir"])
261275

262276
if train_conf["training_engine"] == "torch":

train_mimic3_icd.json

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
"chunk_size": 256,
33
"chunk_num": 3,
44
"hf_lm": "distilbert-base-uncased",
5+
"hf_key": "",
56
"lm_hidden_dim": 768,
67
"data_dir": "./_data/etl/mimic3",
78
"training_engine": "ray",

0 commit comments

Comments
 (0)