This is the official PyTorch implementation of UDI proposed in our paper, "Unsqueeze [CLS] bottleneck to Learn Rich Representations", available at:
Fig. 1 UDI Framework. UDI is an SSL method based on the joint-embedding strategy with multilevel self-distillation objectives. Specifically, for each image, UDI creates two views with one cropped out from the other, followed by two random augmentations, respectively, for student network
We provide pretrained backbone weights as well as full checkpoints containing the weights of the backbone, prediction head and other modules for both the student and teacher networks.
arch | params | pretraining epochs | k-NN | linear | download | | | |
---|---|---|---|---|---|---|---|---|
ViT-S/16 | 21M | 100 | 74.9% | 76.3% | backbone only | full ckpt | args | logs |
ViT-S/16 | 21M | 300 | 75.6% | 77.6% | backbone only | full ckpt | args | logs |
Table 1. k-NN and linear probing performance with the corresponding hyper-parameters, logs and model weights.
Fig 2. Top-1 accuracy on IN-1K and mAP on MS-COCO. UDI achieves more balanced performance between image-level and dense prediction tasks.
To reproduce our results, please install PyTorch and download the ImageNet dataset.
This codebase has been developed with Python 3.9, PyTorch 1.12.1, CUDA 11.6 and torchvision 0.13.1. For the full environment, please refer to the requirement.txt file.
To pretrain with UDI, please find the exact hyper-parameter settings in the args column of Table 1. To train ViT-Small, we use one node with 8 A100 GPUs (total minibatch size of 1024) and the following command:
[300 epochs]
torchrun --standalone --nproc_per_node=8 main_udi.py \
--data_path $DATASET_ROOT \
--output_dir $OUTPUT_ROOT \
--arch vit_small \
--teacher_temp 0.07 \
--warmup_teacher_temp_epochs 30 \
--warmup_epochs 10 \
--local_crops_number 10 \
--norm_last_layer false \
--epochs 300
**ViT-Small trained for 800 epochs and larger models (ViT-B, ViT-L) will be released in the future.**
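The `--teacher_temp` and `--warmup_teacher_temp_epochs` flags in the pretraining command describe a temperature schedule for the teacher. As a minimal sketch, assuming UDI follows the DINO-style convention of a linear ramp followed by a constant value, the per-epoch temperature could be computed as below. The starting value `0.04` and the function name `teacher_temp_schedule` are assumptions for illustration and are not taken from this repository; the args file linked in Table 1 holds the actual settings.

```python
def teacher_temp_schedule(start_temp, final_temp, warmup_epochs, total_epochs):
    """Return one teacher temperature per epoch: linear ramp, then constant."""
    schedule = []
    for epoch in range(total_epochs):
        if epoch < warmup_epochs and warmup_epochs > 1:
            # Linear interpolation from start_temp to final_temp.
            frac = epoch / (warmup_epochs - 1)
            schedule.append(start_temp + frac * (final_temp - start_temp))
        else:
            schedule.append(final_temp)
    return schedule


# start_temp=0.04 is an assumed value; the other numbers come from the command above.
temps = teacher_temp_schedule(0.04, 0.07, warmup_epochs=30, total_epochs=300)
print(len(temps), temps[0], temps[-1])
```

A lower temperature early in training sharpens the teacher's output distribution, and the ramp lets it soften gradually before staying fixed for the remaining epochs.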
To evaluate a simple k-NN classifier with a single GPU on a pre-trained model, run:
torchrun --standalone --nproc_per_node=$NPROC_PER_NODE ./eval_knn/eval_knn.py \
--data_path $DATASET_ROOT \
--pretrained_weights $PRETRAINED_WEIGHTS \
--arch vit_small \
--batch_size_per_gpu 256
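For readers unfamiliar with k-NN evaluation of frozen features, the following is a minimal pure-Python sketch of weighted nearest-neighbor voting. It assumes the DINO-style protocol (cosine similarity on L2-normalized features, votes weighted by `exp(similarity / T)`); whether `eval_knn.py` uses exactly this rule is an assumption, and `knn_predict`, `k` and `temperature` are hypothetical names. The real script operates on backbone features extracted on GPU rather than on Python lists.

```python
import math


def l2_normalize(vec):
    """Scale a vector to unit length (zero vectors are returned unchanged)."""
    norm = math.sqrt(sum(x * x for x in vec)) or 1.0
    return [x / norm for x in vec]


def knn_predict(query, train_feats, train_labels, k=3, temperature=0.07):
    """Predict a label by similarity-weighted voting among the k nearest neighbors."""
    q = l2_normalize(query)
    sims = []
    for feat, label in zip(train_feats, train_labels):
        f = l2_normalize(feat)
        # Cosine similarity between the query and one training feature.
        sims.append((sum(a * b for a, b in zip(q, f)), label))
    # Keep the k most similar training samples.
    top_k = sorted(sims, key=lambda s: s[0], reverse=True)[:k]
    votes = {}
    for sim, label in top_k:
        votes[label] = votes.get(label, 0.0) + math.exp(sim / temperature)
    return max(votes, key=votes.get)


# Toy 2-D features standing in for backbone embeddings.
train_feats = [[1.0, 0.0], [0.9, 0.1], [0.0, 1.0], [0.1, 0.9]]
train_labels = ["cat", "cat", "dog", "dog"]
print(knn_predict([0.8, 0.2], train_feats, train_labels))
```

Because no parameters are trained, this kind of evaluation measures the quality of the frozen representation directly.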
To train a supervised linear classifier on frozen weights on a single node with 8 GPUs, run:
torchrun --standalone --nproc_per_node=$NPROC_PER_NODE ./eval_linear_probing/eval_linear.py \
--data_path $DATASET_ROOT \
--pretrained_weights $PRETRAINED_WEIGHTS \
--arch vit_small \
--lr 0.02 \
--epochs 100
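To evaluate fine-tuning on a pre-trained model, you need to first extract the backbone weights from the checkpoint and then run fine-tuning on the extracted weights: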
python ./eval_finetuning/extract_backbone_weights_for_finetuning.py \
--checkpoint $CHECKPOINT \
--output $OUTPUT \
--checkpoint_key teacher
torchrun --standalone --nproc_per_node=4 ./eval_finetuning/eval_finetuning.py --data_path $DATASET_ROOT --finetune $OUTPUT \
--model vit_small --epochs 200 --batch_size 256 --warmup_epochs 20 --drop_path 0.1 --lr 0.001 --layer_decay 0.55 \
--mixup 0.8 --cutmix 1.0 --layer_scale_init_value 0.0 \
--disable_rel_pos_bias --abs_pos_emb --use_cls --imagenet_default_mean_and_std
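The extraction command above selects one network from the full checkpoint via `--checkpoint_key teacher`. A minimal sketch of that idea on plain dictionaries follows. The key layout (`module.` and `backbone.` prefixes, a `head.` submodule) is assumed from common DINO-style checkpoints and is not confirmed for this repository, and `extract_backbone` is a hypothetical helper, not the repository's script. With real checkpoints, the dictionary would come from `torch.load` and the values would be tensors.

```python
def extract_backbone(checkpoint, checkpoint_key="teacher",
                     strip_prefixes=("module.", "backbone.")):
    """Return backbone-only weights from a full checkpoint dictionary."""
    state_dict = checkpoint[checkpoint_key]
    backbone = {}
    for name, weight in state_dict.items():
        # Strip wrapper prefixes in order (e.g. DistributedDataParallel's "module.").
        for prefix in strip_prefixes:
            if name.startswith(prefix):
                name = name[len(prefix):]
        # Skip head weights (assumed to live under "head."); keep only the backbone.
        if name.startswith("head."):
            continue
        backbone[name] = weight
    return backbone


# Toy checkpoint with placeholder values standing in for tensors.
full_ckpt = {
    "teacher": {
        "module.backbone.cls_token": 0,
        "module.backbone.blocks.0.attn.qkv.weight": 1,
        "module.head.mlp.0.weight": 2,
    },
    "student": {"module.backbone.cls_token": 3},
}
print(sorted(extract_backbone(full_ckpt)))
```

Stripping the prefixes matters because the fine-tuning model is built without the distributed wrapper or the self-distillation heads, so its parameter names would otherwise not match the saved keys.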
To visualize the self-attention map of the [CLS] token on the heads of the last layer, run:
python visualize_attention.py --pretrained_weights ${dir_to_model}
Fig 3. Self-attention of ViT-Small/16 trained with UDI.
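To make the visualization concrete: for a ViT with patch size 16, a 224×224 input gives a 14×14 grid of patches, and the [CLS] row of each head's attention matrix holds one weight per patch after the [CLS] entry itself. The sketch below turns such a row into a 2-D map, assuming the [CLS] token sits at index 0 and the image size is divisible by the patch size. `cls_attention_to_grid` is a hypothetical helper, and the 224-pixel input is an illustrative choice, not necessarily what `visualize_attention.py` uses.

```python
def cls_attention_to_grid(cls_row, image_size=224, patch_size=16):
    """Reshape one head's [CLS]-to-token attention row into a 2-D patch grid."""
    side = image_size // patch_size
    patch_weights = cls_row[1:]  # drop the [CLS]-to-[CLS] entry at index 0
    if len(patch_weights) != side * side:
        raise ValueError("attention row does not match the patch grid")
    return [patch_weights[r * side:(r + 1) * side] for r in range(side)]


# Uniform attention over 1 + 14 * 14 tokens as a stand-in for real model output.
num_tokens = 1 + 14 * 14
row = [1.0 / num_tokens] * num_tokens
grid = cls_attention_to_grid(row)
print(len(grid), len(grid[0]))
```

Each head yields one such grid, which can then be upsampled to the image resolution and rendered as a heat map.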
This repository is released under the Apache 2.0 license as found in the LICENSE file.
If you find this repository useful, please consider giving a star ⭐ and citation 📘:
@inproceedings{udi2024ssl,
title={Unsqueeze [CLS] bottleneck to Learn Rich Representations},
author={Qing Su and Shihao Ji},
booktitle={The 18th European Conference on Computer Vision (ECCV)},
year={2024}
}
This software was created by Georgia State University Research Foundation under Army Research Laboratory (ARL) Award Number W911NF-23-2-0224. ARL, as the Federal awarding agency, reserves a royalty-free, nonexclusive and irrevocable right to reproduce, publish, or otherwise use this software for Federal purposes, and to authorize others to do so in accordance with 2 CFR 200.315(b).