This is the official Python code implementation for:
Spatiotemporal Modeling Encounters 3D Medical Image Analysis: Slice-Shift UNet with Multi-View Fusion
Cynthia I. Ugwu, Sofia Casarin, Oswald Lanz
[ICMVA]
The ssh-unet repository builds on insights from video action recognition to propose Slice Shift UNet (SSH-UNet), a 2D-based model that encodes 3D features with the efficiency of 2D CNNs. It achieves this by applying 2D convolutions across three orthogonal planes while using weight sharing to integrate multi-view features. The neglected third dimension is reintroduced via feature map shifting along the slice axis.
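As a rough, repository-independent sketch of the multi-view idea (the names `multi_view_conv` and `shared_conv` below are illustrative, not the actual SSH-UNet modules): the slice axis of each orthogonal view is folded into the batch dimension, the same 2D convolution is applied to all three views, and the outputs are fused by averaging.

```python
import torch
import torch.nn as nn

def multi_view_conv(x: torch.Tensor, shared_conv: nn.Conv2d) -> torch.Tensor:
    """Apply one weight-shared 2D conv along three orthogonal planes of a 3D volume.

    x: (B, C, D, H, W) feature map. Illustrative sketch, not the SSH-UNet code.
    shared_conv must preserve spatial size (e.g. kernel 3, padding 1).
    """
    def conv_along(view: torch.Tensor) -> torch.Tensor:
        # view: (B, C, S, P, Q) -> fold slice axis S into the batch, run the 2D conv, unfold.
        b, c, s, p, q = view.shape
        out = shared_conv(view.permute(0, 2, 1, 3, 4).reshape(b * s, c, p, q))
        return out.reshape(b, s, *out.shape[1:]).permute(0, 2, 1, 3, 4)

    axial    = conv_along(x)                              # slices along D, planes (H, W)
    coronal  = conv_along(x.permute(0, 1, 3, 2, 4))       # slices along H, planes (D, W)
    sagittal = conv_along(x.permute(0, 1, 4, 2, 3))       # slices along W, planes (D, H)

    # Bring every view back to (B, C', D, H, W) and fuse by averaging.
    coronal  = coronal.permute(0, 1, 3, 2, 4)
    sagittal = sagittal.permute(0, 1, 3, 4, 2)
    return (axial + coronal + sagittal) / 3.0

x = torch.randn(1, 8, 32, 48, 40)                   # toy (B, C, D, H, W) volume
conv = nn.Conv2d(8, 16, kernel_size=3, padding=1)   # shared, spatial-size-preserving
print(multi_view_conv(x, conv).shape)               # torch.Size([1, 16, 32, 48, 40])
```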
- Shift Mechanism Update: Shifting is now applied after both convolutional layers (previously only after the second), leading to better results.
- Optional 3D Fusion Block: Added the option to use a Volumetric Fusion Net (VFN) [1] at the output block to enhance performance with minimal parameter overhead (from 6.48M to 12.51M parameters).
- Advanced Shifting Mechanisms: Two new shift strategies, GSM and GSF [2,3], are offered in addition to the original parameter-free TSM [4]; a minimal sketch of the shift follows this list.
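For intuition, here is a minimal sketch of a parameter-free TSM-style shift [4] along the slice axis, assuming a (B, C, D, H, W) layout; it illustrates the mechanism rather than reproducing the repository's implementation (GSM and GSF [2,3] replace this fixed shift with learned gating).

```python
import torch

def slice_shift(x: torch.Tensor, fold_div: int = 8) -> torch.Tensor:
    """TSM-style parameter-free shift along the slice (depth) axis.

    x: (B, C, D, H, W). 1/fold_div of the channels is shifted one slice back,
    another 1/fold_div one slice forward, the remaining channels are untouched.
    Illustrative sketch only.
    """
    fold = x.shape[1] // fold_div
    out = torch.zeros_like(x)
    out[:, :fold, :-1] = x[:, :fold, 1:]                   # pull features from the next slice
    out[:, fold:2 * fold, 1:] = x[:, fold:2 * fold, :-1]   # pull features from the previous slice
    out[:, 2 * fold:] = x[:, 2 * fold:]                    # leave the rest unchanged
    return out

x = torch.randn(2, 16, 24, 48, 48)
print(slice_shift(x).shape)   # torch.Size([2, 16, 24, 48, 48])
```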
SSH-UNet is a re-adaptation of MONAI's DynUNet, with most of the code based on the official MONAI implementation.
The code uses PyTorch for training. Follow the official PyTorch installation guide to install the version compatible with your system. Additional required packages are listed in requirements.txt.
# clone project
git clone https://github.com/cugwu/ssh-unet.git
# [RECOMMENDED] set up a virtual environment
python -m venv venv_name # choose your preferred venv_name
source venv_name/bin/activate
# install requirements
pip install -r requirements.txt
The training data comes from the BTCV challenge dataset and the AMOS segmentation dataset. Use the provided links to download the datasets.
We provide the JSON files btcv_dataset.json and amos_task01_dataset.json that are used to train our model.
For testing, submit your predictions on the original challenge websites.
BTCV
- Target: 13 abdominal organs
- Task: Segmentation
- Modality: CT
- Size: 30 3D volumes (24 Training + 6 Validation)
AMOS
- Target: 15 abdominal organs
- Task: Segmentation
- Modality: CT (Task 1)
- Size: 300 3D volumes (200 Training + 100 Validation)
Make sure to specify the location of your dataset directory using --data_dir. Additionally, the JSON file must be placed in the same folder as the dataset.
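The provided JSON files are MONAI-style datalists; a quick sanity check like the one below (the path and the "training"/"image"/"label" keys are assumptions based on the usual decathlon-style convention) can confirm your copy of the data matches the list.

```python
import json
import os

data_dir = "./data"   # adjust to your --data_dir
with open(os.path.join(data_dir, "btcv_dataset.json")) as f:
    datalist = json.load(f)

# Decathlon-style datalists keep image/label pairs under "training" (and usually
# "validation"/"test"); file paths are typically relative to the dataset folder.
print(list(datalist.keys()))
print(len(datalist.get("training", [])), "training entries")
print(datalist["training"][0])   # e.g. {"image": "imagesTr/...", "label": "labelsTr/..."}
```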
We provide results on the validation set for several possible configurations; the output logs can be found in ./logs_sshunet.
The Dice acc. column reports the best validation accuracy obtained during training, with the validation accuracy in the original image space given in brackets.
| SSH-UNet | with VFN | shifting type | # params | Dice acc. (overlap = 0.5) |
|---|---|---|---|---|
| BTCV | no | tsm | 6.48M | 0.826 (0.843) |
| BTCV | no | gsm | 6.51M | 0.831 (0.849) |
| BTCV | no | gsf | 6.48M | 0.834 (0.849) |
| BTCV | yes | tsm | 12.51M | 0.846 (0.825) |
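The overlap value in the table presumably refers to sliding-window inference over 96×96×96 patches at validation time; below is a minimal sketch using MONAI's sliding_window_inference, where the model and volume are placeholders standing in for SSH-UNet and a real CT scan.

```python
import torch
from monai.inferers import sliding_window_inference

model = torch.nn.Conv3d(1, 14, kernel_size=1)   # placeholder for SSH-UNet (14 BTCV classes)
volume = torch.randn(1, 1, 128, 128, 128)       # placeholder (B, C, D, H, W) CT volume

with torch.no_grad():
    logits = sliding_window_inference(
        inputs=volume,
        roi_size=(96, 96, 96),   # matches --roi_x/--roi_y/--roi_z
        sw_batch_size=4,
        predictor=model,
        overlap=0.5,             # the overlap reported in the table
    )
print(logits.shape)              # torch.Size([1, 14, 128, 128, 128])
```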
In the ./runs directory, you can find the original parameter configurations used to train AMOS and BTCV.
Below is an example command for training AMOS without k-fold cross-validation, using TSM as the shifting mechanism
positioned at all three resolution levels of SSH-UNet. The shifting mechanism is also applied in the bottleneck block and decoder.
The model is trained for 1000 epochs and evaluated every 100 epochs.
⚠ Note: The --cache_num argument enables a caching mechanism that preloads a given number of samples, which helps speed up training. If your memory is limited, consider reducing this value; in the command below it is set to 200.
python -u main.py --save_checkpoint --logdir "./results" --data_dir "./data" --json_list "task01_dataset.json" \
--roi_x 96 --roi_y 96 --roi_z 96 --out_channels 16 --nfolds 1 --kernels 133 133 133 133 133 --gate_type "tsm" \
--gate_pos 0 1 2 --gate_bottleneck --gate_dec --max_epochs 1000 --val_every 100 --patience 5 --optim_name "sgd" --optim_lr 0.01 \
--batch_size 1 --cache_num 200 --a_min -991.0 --a_max 362.0 --b_min 0.0 --b_max 1.0 \
--RandFlipd_prob 0.2 --RandRotate90d_prob 0.2 --RandScaleIntensityd_prob 0.5 --RandShiftIntensityd_prob 0.5
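The caching controlled by --cache_num most likely corresponds to MONAI's CacheDataset, which keeps up to cache_num pre-transformed samples in RAM; the sketch below is a hedged illustration of that behavior, with a placeholder datalist and transform chain rather than the repository's training pipeline.

```python
from monai.data import CacheDataset, DataLoader
from monai.transforms import Compose, LoadImaged, EnsureChannelFirstd

# Placeholder datalist; in the repo the entries come from the JSON file.
datalist = [{"image": "./data/imagesTr/case_0001.nii.gz",
             "label": "./data/labelsTr/case_0001.nii.gz"}]
transforms = Compose([
    LoadImaged(keys=["image", "label"]),
    EnsureChannelFirstd(keys=["image", "label"]),
])

# cache_num bounds how many samples are kept pre-transformed in memory;
# lower it if RAM is tight (the command above uses 200).
train_ds = CacheDataset(data=datalist, transform=transforms, cache_num=200, num_workers=4)
loader = DataLoader(train_ds, batch_size=1, shuffle=True)
```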
Use the following commands to validate your model. Before computing accuracy, the transformations applied to the image will be reversed.
python -u validation.py --pretrained_dir "./results" --pretrained_model_name "0model_lastsaved_fold0.pt" \
--data_dir "./data" --exp_name "val_pred" --json_list "task01_dataset.json" \
--roi_x 96 --roi_y 96 --roi_z 96 --out_channels 16 --kernels 133 133 133 133 133 \
--gate_type "tsm"--gate_pos 0 1 2 --gate_bottleneck --gate_dec
To validate using an ensemble of models, use the following command. You can choose between mean and vote ensembling by setting the --ensemble_method parameter accordingly.
If using the mean ensemble, you must also specify the --w_mee parameter, which assigns a weight to each model in the ensemble.
The ensemble predictions are saved in the folder defined by --out_dir and --exp_name.
python -u validation_ensemble.py --pretrained_dir "./results" --pretrained_model_name "model_lastsaved_fold" \
--data_dir "./data" --out_dir "/runs" --exp_name "val_pred_ensemble" --num_models 5 \
--json_list "btcv_dataset.json" --roi_x 96 --roi_y 96 --roi_z 96 --out_channels 14 --kernels 133 133 133 133 133 \
--gate_type "tsm" --gate_pos 0 1 2 --gate_bottleneck --gate_dec --ensemble_method "mean" --w_mee 0.95 0.94 0.95 0.94 0.90
With the following command, the challenge test set is loaded and the predictions are saved in the folder specified by --out_dir and --exp_name.
The saved images can be uploaded to the original challenge website.
python -u test.py --pretrained_dir "./results" --pretrained_model_name "0model_lastsaved_fold0.pt" \
--data_dir "./data" --out_dir "./runs" --exp_name "val_pred" \
--json_list "btcv_dataset.json" --roi_x 96 --roi_y 96 --roi_z 96 --out_channels 14 \
--kernels 133 133 133 133 133 --gate_type "tsm" --gate_pos 0 1 2 --gate_bottleneck --gate_dec
This project is distributed under the Apache License, Version 2.0. See LICENSE.txt for more details.
Some parts of the code are licensed under the MIT License; in those cases, the MIT License text is included at the top of the respective files.
If you use this code or find our work helpful for your research, please cite the following paper:
@inproceedings{ugwu2024spatiotemporal,
title={Spatiotemporal Modeling Encounters 3D Medical Image Analysis: Slice-Shift UNet with Multi-View Fusion},
author={Ugwu, Cynthia Ifeyinwa and Casarin, Sofia and Lanz, Oswald},
booktitle={Proceedings of the 2024 7th International Conference on Machine Vision and Applications},
pages={126--133},
year={2024}
}
[1] Y. Xia, L. Xie, F. Liu, Z. Zhu, E. K. Fishman, A. L. Yuille, Bridging the gap between 2D and 3D organ segmentation with volumetric fusion net, in: International Conference on Medical Image Computing and Computer-Assisted Intervention, Springer, 2018, pp. 445–453.
[2] S. Sudhakaran, S. Escalera, O. Lanz, Gate-shift networks for video action recognition, in: Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, 2020, pp. 1102–1111.
[3] S. Sudhakaran, S. Escalera, O. Lanz, Gate-shift-fuse for video action recognition, IEEE Transactions on Pattern Analysis and Machine Intelligence, 2023, 45(9): 10913–10928.
[4] J. Lin, C. Gan, S. Han, TSM: Temporal shift module for efficient video understanding, in: Proceedings of the IEEE/CVF International Conference on Computer Vision, 2019, pp. 7083–7093.