Commit 9b667b8

Merge pull request #9 from innerNULL/dev
Batch Fixs and Upgrades
2 parents bba88b7 + a17e753

File tree

7 files changed: +238 −108 lines changed


README.md

Lines changed: 49 additions & 1 deletion
````diff
@@ -21,7 +21,55 @@ pip install -r requirements.txt
 python -m pytest ./test --cov=./src/plm_icd_multi_label_classifier --durations=0 -v
 ```
 
-### ETL
+### Custom Dataset Preparation
+The training dataset should be a directory with the following structure:
+```
+├── dev.jsonl
+├── dict.json
+└── train.jsonl
+
+0 directories, 3 files
+```
+
+`train.jsonl` and `dev.jsonl` are the training and validation datasets, both in JSON Lines
+format. Each JSON record should contain at least 2 fields, holding the input text and the
+output label names respectively. The following is an example:
+```
+{
+    "input_text": "...",
+    "labels": "label_1, label_5, ..., label_m"
+}
+```
+Based on this sample data format, you should have the following settings in your config:
+```
+{
+    ...
+    "data_dir": "your dataset directory path",
+    "text_col": "input_text",
+    "label_col": "labels",
+    ...
+}
+```
+
+`dict.json` holds the bidirectional mapping between label names and IDs; its format is:
+```
+{
+    "label2id": {
+        "label_0": 0,
+        "label_1": 1,
+        ...
+        "label_n": n
+    },
+    "id2label": {
+        "0": "label_0",
+        "1": "label_1",
+        "2": "label_2",
+        ...
+        "n": "label_n"
+    }
+}
+```
+Since each label ID is also used as an index into the one-hot label vector, the IDs must
+start from 0.
+
+
+### (MIMIC3 Dataset Preparation)
 The ETL contains the following steps:
 * Origin JSON line dataset preparation
 * Transform the JSON line file to a **limited** JSON line file, which means all `list` or `dict`
````
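For reference, `dict.json` can be derived directly from `train.jsonl`. The following is a minimal sketch, not part of this commit; the `build_label_dict` helper name and the comma delimiter are assumptions (use whatever splitter your labels need):

```python
# Hypothetical helper (not part of this commit): derive dict.json from train.jsonl.
# Assumes each record keeps its label names under "labels", separated by commas.
import json

def build_label_dict(train_path: str, out_path: str,
                     label_col: str = "labels", splitter: str = ",") -> None:
    names: set = set()
    with open(train_path, "r") as f:
        for line in f:
            record = json.loads(line)
            names.update(x.strip() for x in record[label_col].split(splitter))
    ordered = sorted(names)
    # Label IDs double as one-hot indices, so they start at 0 and stay contiguous.
    payload = {
        "label2id": {name: i for i, name in enumerate(ordered)},
        "id2label": {str(i): name for i, name in enumerate(ordered)},
    }
    with open(out_path, "w") as f:
        json.dump(payload, f, indent=2)

build_label_dict("train.jsonl", "dict.json")
```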

src/plm_icd_multi_label_classifier/data.py

Lines changed: 16 additions & 6 deletions
```diff
@@ -20,18 +20,27 @@
 
 class TextOnlyDataset(Dataset):
     def __init__(self,
-        data_path: str, data_dict_path: str, tokenizer: AutoTokenizer,
-        text_col: str="text", label_col: str="label", data_format: int="csv",
-        chunk_size: int=512, chunk_num: int=2
+        data_path: str,
+        data_dict_path: str,
+        tokenizer: AutoTokenizer,
+        text_col: str="text",
+        label_col: str="label",
+        data_format: str="csv",
+        chunk_size: int=512,
+        chunk_num: int=2,
+        label_splitter: str=","
     ):
         self.data_path: str = data_path
         self.data_dict: Dict[str, Dict] = json.loads(open(data_dict_path, "r").read())
         self.text_col: str = text_col
         self.label_col: str = label_col
         self.data: List[Dict] = []
         self.model_ctx: PlmIcdCtx = PlmIcdCtx().init(
-            data_dict_path=data_dict_path, lm_tokenizer=tokenizer,
-            chunk_size=chunk_size, chunk_num=chunk_num
+            data_dict_path=data_dict_path,
+            lm_tokenizer=tokenizer,
+            chunk_size=chunk_size,
+            chunk_num=chunk_num,
+            label_splitter=label_splitter
         )
 
         if data_format == "csv":
@@ -45,7 +54,8 @@ def __init__(self,
         # using customized dev/test data to do evaluation.
         for i, record in enumerate(self.data):
             curr_filtered_label: List[str] = [
-                x for x in record[label_col].split(",") if x in self.data_dict["label2id"]
+                x for x in record[label_col].split(label_splitter)
+                if x in self.data_dict["label2id"]
             ]
             if len(curr_filtered_label) == 0:
                 self.data[i] = None
```
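As a quick illustration of the new `label_splitter` parameter, here is a hypothetical construction of the dataset; the file paths, the model name, and the `"jsonl"` format value are assumptions, not taken from the repo:

```python
# Hypothetical usage sketch; paths, model name, and the "jsonl" format value
# are illustrative assumptions.
from transformers import AutoTokenizer
from plm_icd_multi_label_classifier.data import TextOnlyDataset

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
dataset = TextOnlyDataset(
    data_path="data/train.jsonl",
    data_dict_path="data/dict.json",
    tokenizer=tokenizer,
    text_col="input_text",
    label_col="labels",
    data_format="jsonl",    # assumed value for the JSON Lines branch
    chunk_size=512,
    chunk_num=2,
    label_splitter=","      # new in this commit: configurable label delimiter
)
```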
src/plm_icd_multi_label_classifier/eval.py

Lines changed: 106 additions & 0 deletions

@@ -0,0 +1,106 @@
```python
# -*- coding: utf-8 -*-
# file: eval.py
# date: 2025-08-04


import json
import torch
import torch.nn.functional as F
from typing import Dict, List
from torch import device
from torch import LongTensor, FloatTensor, IntTensor
from torch.utils.data import DataLoader

from .model import PlmMultiLabelEncoder
from .metrics import metrics_func, topk_metrics_func


THRESHOLD: float = 0.6


def evaluation(
    model: PlmMultiLabelEncoder,
    dataloader: DataLoader,
    device: device=None,
    max_sample: int=10000,
    label_confidence_threshold: float=THRESHOLD,
    verbose: bool=False
) -> Dict[str, float]:
    out: Dict[str, float] = {}
    total_cnt: int = 0
    all_logits: List[FloatTensor] = []
    all_label_one_hots: List[FloatTensor] = []

    model.eval()
    with torch.no_grad():
        for batch in dataloader:
            curr_label_one_hot: FloatTensor = None
            curr_text_ids: LongTensor = None
            curr_attn_masks: LongTensor = None

            curr_text_ids, curr_attn_masks, curr_label_one_hot = batch

            if device is not None:
                curr_label_one_hot = curr_label_one_hot.to(device)
                curr_text_ids = curr_text_ids.to(device)
                curr_attn_masks = curr_attn_masks.to(device)

            curr_logits: FloatTensor = model(curr_text_ids, curr_attn_masks)
            all_logits.append(curr_logits)
            all_label_one_hots.append(curr_label_one_hot)

            total_cnt += curr_text_ids.shape[0]
            if total_cnt >= max_sample:
                break

    logits: FloatTensor = torch.concat(all_logits, dim=0)
    output_label_probs: FloatTensor = torch.sigmoid(logits)
    output_one_hot: FloatTensor = (
        (output_label_probs > label_confidence_threshold).float()
    )
    label_one_hot: FloatTensor = torch.concat(all_label_one_hots, dim=0)
    # Loss
    loss: float = float(
        F.binary_cross_entropy(output_label_probs, label_one_hot).cpu()
    )
    # Metrics
    prob50_metrics: Dict[str, float] = metrics_func(
        output_one_hot.int(), label_one_hot.int()
    )
    #top5_metrics: Dict[str, float] = topk_metrics_func(logits, label_one_hot, top_k=5)
    #top8_metrics: Dict[str, float] = topk_metrics_func(logits, label_one_hot, top_k=8)
    #top15_metrics: Dict[str, float] = topk_metrics_func(logits, label_one_hot, top_k=15)

    out = {
        "loss": round(loss, 8),
        "micro_recall": round(prob50_metrics["micro_recall"], 4),
        "micro_precision": round(prob50_metrics["micro_precision"], 4),
        "micro_f1": round(prob50_metrics["micro_f1"], 4),
        "macro_recall": round(prob50_metrics["macro_recall"], 4),
        "macro_precision": round(prob50_metrics["macro_precision"], 4),
        "macro_f1": round(prob50_metrics["macro_f1"], 4),
        #"micro_recall@5": round(top5_metrics["micro_recall@5"], 4),
        #"micro_precision@5": round(top5_metrics["micro_precision@5"], 4),
        #"micro_f1@5": round(top5_metrics["micro_f1@5"], 4),
        #"macro_recall@5": round(top5_metrics["macro_recall@5"], 4),
        #"macro_precision@5": round(top5_metrics["macro_precision@5"], 4),
        #"macro_f1@5": round(top5_metrics["macro_f1@5"], 4),
        #"micro_recall@8": round(top8_metrics["micro_recall@8"], 4),
        #"micro_precision@8": round(top8_metrics["micro_precision@8"], 4),
        #"micro_f1@8": round(top8_metrics["micro_f1@8"], 4),
        #"macro_recall@8": round(top8_metrics["macro_recall@8"], 4),
        #"macro_precision@8": round(top8_metrics["macro_precision@8"], 4),
        #"macro_f1@8": round(top8_metrics["macro_f1@8"], 4),
        #"micro_recall@15": round(top15_metrics["micro_recall@15"], 4),
        #"micro_precision@15": round(top15_metrics["micro_precision@15"], 4),
        #"micro_f1@15": round(top15_metrics["micro_f1@15"], 4),
        #"macro_recall@15": round(top15_metrics["macro_recall@15"], 4),
        #"macro_precision@15": round(top15_metrics["macro_precision@15"], 4),
        #"macro_f1@15": round(top15_metrics["macro_f1@15"], 4)
    }
    if verbose:
        out["verbose"] = {}
        out["verbose"]["pred_one_hot"] = output_one_hot.int().tolist()
        out["verbose"]["gt_one_hot"] = label_one_hot.int().tolist()
    return out
```
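A hypothetical call site for the new `evaluation` function; the trained model and dev dataset here are stand-ins, not code from the repo:

```python
# Hypothetical usage sketch; `model` is assumed to be a trained
# PlmMultiLabelEncoder and `dev_dataset` a TextOnlyDataset built as shown above.
import torch
from torch.utils.data import DataLoader
from plm_icd_multi_label_classifier.eval import evaluation

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
dev_loader = DataLoader(dev_dataset, batch_size=32, shuffle=False)
metrics = evaluation(
    model.to(device),
    dev_loader,
    device=device,
    max_sample=10000,
    label_confidence_threshold=0.6,
    verbose=False,
)
print(metrics["micro_f1"], metrics["macro_f1"], metrics["loss"])
```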

src/plm_icd_multi_label_classifier/model.py

Lines changed: 8 additions & 2 deletions
```diff
@@ -15,8 +15,12 @@
 class PlmMultiLabelEncoder(Module):
     def __init__(self,
         label_num: int,
-        lm: Union[str, Module], lm_embd_dim: int, chunk_size: int=128, chunk_num: int=5,
-        first_attn_hidden_dim: int=512
+        lm: Union[str, Module],
+        lm_embd_dim: int,
+        chunk_size: int=128,
+        chunk_num: int=5,
+        first_attn_hidden_dim: int=512,
+        freeze_lm: bool=False
     ):
         super().__init__()
 
@@ -25,6 +29,8 @@ def __init__(self,
             AutoModel.from_pretrained(lm, trust_remote_code=True) if isinstance(lm, str)
             else lm
         )
+        if freeze_lm:
+            self._lm.requires_grad_(False)
 
         # Dimension info
         self._label_num: int = label_num
```
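With `freeze_lm=True`, gradients are disabled for every backbone parameter, so only the layers above the LM are updated. A hypothetical construction (the label count and model name are illustrative):

```python
# Hypothetical usage sketch; label_num and the HF model name are illustrative.
from plm_icd_multi_label_classifier.model import PlmMultiLabelEncoder

model = PlmMultiLabelEncoder(
    label_num=50,
    lm="bert-base-uncased",
    lm_embd_dim=768,
    chunk_size=128,
    chunk_num=5,
    first_attn_hidden_dim=512,
    freeze_lm=True,   # new in this commit: keep the backbone LM fixed
)
# Sanity check: no backbone parameter should require gradients.
assert not any(p.requires_grad for p in model._lm.parameters())
```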

src/plm_icd_multi_label_classifier/model_ctx.py

Lines changed: 15 additions & 4 deletions
```diff
@@ -21,20 +21,30 @@ def __init__(self):
         self.label2id: Dict[str, int] = {}
         self.chunk_size: int = -1
         self.chunk_num: int = -1
+        self.label_splitter: str = ","
 
     def init_by_train_config(self, train_conf_path: str):
         train_conf: Dict = json.loads(open(train_conf_path, "r").read())
         data_dict_path: str = os.path.join(train_conf["data_dir"], "dict.json")
         lm_tokenizer: str = train_conf["hf_lm"]
         chunk_size: int = train_conf["chunk_size"]
         chunk_num: int = train_conf["chunk_num"]
-
-        return self.init(data_dict_path, lm_tokenizer, chunk_size, chunk_num)
+        label_splitter: str = train_conf["label_splitter"]
+
+        return self.init(
+            data_dict_path,
+            lm_tokenizer,
+            chunk_size,
+            chunk_num,
+            label_splitter
+        )
 
     def init(self,
         data_dict_path: str,
         lm_tokenizer: Union[str, AutoTokenizer],
-        chunk_size: int, chunk_num: int
+        chunk_size: int,
+        chunk_num: int,
+        label_splitter: str
     ):
         self.data_dict = json.loads(open(data_dict_path, "r").read())
         self.id2label = {int(k): v for k, v in self.data_dict["id2label"].items()}
@@ -44,6 +54,7 @@ def init(self,
             else lm_tokenizer
         self.chunk_size = chunk_size
         self.chunk_num = chunk_num
+        self.label_splitter = label_splitter
         return self
 
     def json_inputs2model_inf_inputs(self,
@@ -64,7 +75,7 @@ def json_inputs2model_train_inputs(self,
         model_inputs: Dict[str, Tensor] = self.json_inputs2model_inf_inputs(
             json_inputs, text_fields
         )
-        label_names: List[str] = json_inputs[label_field].split(",")
+        label_names: List[str] = json_inputs[label_field].split(self.label_splitter)
         label_ids: List[int] = [
             self.label2id[x] for x in label_names if x in self.label2id
         ]
```
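Note that `init_by_train_config` reads `label_splitter` with direct indexing, so existing training configs need the new key. A sketch of the relevant config fields follows; the keys are the ones the code reads, while the values are illustrative:

```
{
    ...
    "data_dir": "your dataset directory path",
    "hf_lm": "bert-base-uncased",
    "chunk_size": 512,
    "chunk_num": 2,
    "label_splitter": ","
}
```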
