
PLM-ICD-multi-label-classifier

An unofficial multi-label classifier based on the PLM-ICD paper.

This is a personal side project. The goal is to gain a deep understanding of the paper, and to provide a more concise and clear implementation, which makes customization and extension easier.

Although the model comes from the paper, I tried my best to make this a general program for text multi-label classification tasks.

Usage

Python Env

micromamba env create -f environment.yaml -p ./_pyenv --yes
micromamba activate ./_pyenv
pip install -r requirements.txt

Run Tests

python -m pytest ./test --cov=./src/plm_icd_multi_label_classifier --durations=0 -v

Custom Dataset Preparation

The training dataset should be a directory with the following structure:

├── dev.jsonl
├── dict.json
└── train.jsonl

0 directories, 3 files

train.jsonl and dev.jsonl are the training and validation datasets, in JSON Lines format. Each JSON record should contain at least 2 fields, which hold the input text and the output label names respectively. The following is an example:

{
  "input_text": "...",
  "labels": "label_1, label_5, ..., label_m"
}

Based on this sample data format, you should have the following settings in your config:

{
  ...
  "data_dir": "your dataset directory path",
  "text_col": "input_text",
  "label_col": "labels"
  ...   
}
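For concreteness, here is a minimal sketch of producing such a JSON Lines dataset in Python (the field names input_text and labels match the example above; the records themselves are just placeholders):

import json

# Placeholder records; each record becomes one line of the JSONL file.
records = [
    {"input_text": "first document text ...", "labels": "label_1, label_5"},
    {"input_text": "second document text ...", "labels": "label_2"},
]

with open("train.jsonl", "w", encoding="utf-8") as f:
    for record in records:
        f.write(json.dumps(record) + "\n")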

dict.json holds the bidirectional mapping between label names and IDs; the format is:

{
  "label2id": {
    "label_0": 0,
    "label_1": 1, 
    "label_2": 2, 
    ...
    "label_n": n
  },
  "id2label": {
    0: "label_0",
    1: "label_1",
    2: "label_2",
    ...
    n: "label_n"
  }
}

Because label IDs are also used as indices into the one-hot label vector, they must start from 0.
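For example, a minimal sketch of building dict.json from train.jsonl (assuming the sample field name labels and a comma separator; both actually come from your config):

import json

label_set = set()
with open("train.jsonl", "r", encoding="utf-8") as f:
    for line in f:
        record = json.loads(line)
        # Split the concatenated label string into individual label names.
        label_set.update(name.strip() for name in record["labels"].split(","))

# Assign contiguous IDs starting from 0, since IDs index the one-hot vector.
label2id = {label: i for i, label in enumerate(sorted(label_set))}
id2label = {str(i): label for label, i in label2id.items()}

with open("dict.json", "w", encoding="utf-8") as f:
    json.dump({"label2id": label2id, "id2label": id2label}, f, indent=2)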

Since the original paper uses MIMIC-III as its dataset, a pre-built ETL script is also provided to generate training data from MIMIC-III.

Training and Evaluation

CUDA_VISIBLE_DEVICES=0,1,2,3 python ./train.py ${TRAIN_CONFIG_JSON_FILE_PATH}

The config file is in JSON format; most of the parameters are easy to understand if you are an MLE, data scientist, or researcher:

  • chunk_size: The number of token IDs in each chunk.
  • chunk_num: The number of chunks each text/document should have; shorter texts are padded at the beginning.
  • hf_lm: HuggingFace language model name/path. Each hf_lm may have a different lm_hidden_dim; I personally tried 2 LMs:
    • "distilbert-base-uncased" with lm_hidden_dim as 768
    • "medicalai/ClinicalBERT" with lm_hidden_dim as 768
  • lm_hidden_dim: The dimension of the language model's hidden output layer.
  • data_dir: Data directory, which should contain at least two files generated by etl_mimic3_processing.py:
    • train.jsonl
    • dev.jsonl
    • (test.jsonl)
  • training_engine: Training engine, either "torch" or "ray". Torch mode is mainly used for debugging and does not support distributed training.
  • single_worker_batch_size: Each worker's batch size. Note that when training with the "torch" engine, there is only one worker.
  • lr: Initial learning rate.
  • epochs: Number of training epochs.
  • gpu: Whether to use GPUs for training.
  • workers: Number of workers in distributed training. This is only effective when using "ray" as the training engine.
  • single_worker_eval_size: Each worker's maximum evaluation sample size. Again, when using "torch" as the training engine, there is only one worker.
  • random_seed: Random seed, which makes training 100% reproducible.
  • text_col: Text column name in the train/dev/test JSON Lines dataset.
  • label_col: Label column name in the train/dev/test JSON Lines dataset.
  • ckpt_dir: Checkpoint directory name.
  • log_period: How many batches pass between evaluation log prints.
  • dump_period: How many steps pass between checkpoint dumps.
  • label_splitter: The separator used to split the concatenated label string into a list of label names.
  • eval.label_confidence_threshold: Each label's confidence threshold; labels scoring higher are set as positive during evaluation.
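Putting these parameters together, a hypothetical train config could look like the following (all values are illustrative placeholders, and nesting label_confidence_threshold under an "eval" object is an assumption based on the dotted name):

{
  "data_dir": "./data",
  "text_col": "input_text",
  "label_col": "labels",
  "label_splitter": ",",
  "hf_lm": "distilbert-base-uncased",
  "lm_hidden_dim": 768,
  "chunk_size": 128,
  "chunk_num": 8,
  "training_engine": "ray",
  "single_worker_batch_size": 8,
  "single_worker_eval_size": 1024,
  "lr": 5e-5,
  "epochs": 10,
  "gpu": true,
  "workers": 4,
  "random_seed": 42,
  "ckpt_dir": "./ckpt",
  "log_period": 100,
  "dump_period": 1000,
  "eval": {
    "label_confidence_threshold": 0.5
  }
}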

Inference

python inf.py inf.json

Most parameter explanations are already in inf.json.

Evaluation

python eval.py eval.json

Most parameter explanations are already in eval.json.
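To make eval.label_confidence_threshold concrete, here is a generic sketch of how per-label confidences are turned into positive labels during evaluation (an illustration of the thresholding rule, not this repo's exact code):

import numpy as np

def confidences_to_labels(probs, id2label, threshold=0.5):
    # Labels whose confidence is higher than the threshold are positive.
    positive_ids = np.nonzero(probs > threshold)[0]
    return [id2label[str(i)] for i in positive_ids]

# Example with 4 labels: labels 1 and 3 exceed the 0.5 threshold.
probs = np.array([0.1, 0.8, 0.3, 0.7])
id2label = {"0": "label_0", "1": "label_1", "2": "label_2", "3": "label_3"}
print(confidences_to_labels(probs, id2label))  # ['label_1', 'label_3']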

Examples

Training Examples

Other Implementation Details

  • After chunk_size and chunk_num are defined, each text's token ID length is fixed to chunk_size * chunk_num; if a text is not long enough, it is automatically padded at the beginning (see the sketch below).
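A minimal sketch of that chunking logic, assuming a HuggingFace tokenizer and the pre-padding described above (the repo's actual handling of special tokens and truncation may differ):

from transformers import AutoTokenizer

CHUNK_SIZE = 128  # illustrative values
CHUNK_NUM = 8
MAX_LEN = CHUNK_SIZE * CHUNK_NUM  # every text is fixed to this many token IDs

tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")

def to_chunks(text: str) -> list:
    ids = tokenizer(text, add_special_tokens=False)["input_ids"][:MAX_LEN]
    # Pad at the beginning when the text is shorter than MAX_LEN.
    ids = [tokenizer.pad_token_id] * (MAX_LEN - len(ids)) + ids
    # Split the flat ID sequence into chunk_num chunks of chunk_size each.
    return [ids[i:i + CHUNK_SIZE] for i in range(0, MAX_LEN, CHUNK_SIZE)]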