diff --git a/docs/source/models.rst b/docs/source/models.rst
index 51881e505e4..3b1db1b3bf4 100644
--- a/docs/source/models.rst
+++ b/docs/source/models.rst
@@ -176,8 +176,10 @@ Densenet-201                     76.896        93.370
 Densenet-161                     77.138        93.560
 Inception v3                     77.294        93.450
 GoogleNet                        69.778        89.530
-ShuffleNet V2 x1.0               69.362        88.316
 ShuffleNet V2 x0.5               60.552        81.746
+ShuffleNet V2 x1.0               69.362        88.316
+ShuffleNet V2 x1.5               72.996        91.086
+ShuffleNet V2 x2.0               76.230        93.006
 MobileNet V2                     71.878        90.286
 MobileNet V3 Large               74.042        91.340
 MobileNet V3 Small               67.668        87.402
@@ -499,8 +501,10 @@ Model                            Acc@1         Acc@5
 ================================ ============= =============
 MobileNet V2                     71.658        90.150
 MobileNet V3 Large               73.004        90.858
-ShuffleNet V2 x1.0               68.360        87.582
 ShuffleNet V2 x0.5               57.972        79.780
+ShuffleNet V2 x1.0               68.360        87.582
+ShuffleNet V2 x1.5               72.052        90.700
+ShuffleNet V2 x2.0               75.354        92.488
 ResNet 18                        69.494        88.882
 ResNet 50                        75.920        92.814
 ResNext 101 32x8d                78.986        94.480
diff --git a/hubconf.py b/hubconf.py
index bbd5da52b13..28d2c5a5d01 100644
--- a/hubconf.py
+++ b/hubconf.py
@@ -59,7 +59,12 @@
     deeplabv3_mobilenet_v3_large,
     lraspp_mobilenet_v3_large,
 )
-from torchvision.models.shufflenetv2 import shufflenet_v2_x0_5, shufflenet_v2_x1_0
+from torchvision.models.shufflenetv2 import (
+    shufflenet_v2_x0_5,
+    shufflenet_v2_x1_0,
+    shufflenet_v2_x1_5,
+    shufflenet_v2_x2_0,
+)
 from torchvision.models.squeezenet import squeezenet1_0, squeezenet1_1
 from torchvision.models.vgg import vgg11, vgg13, vgg16, vgg19, vgg11_bn, vgg13_bn, vgg16_bn, vgg19_bn
 from torchvision.models.vision_transformer import (
diff --git a/references/classification/README.md b/references/classification/README.md
index 758c57ce27d..9eb95fd00e9 100644
--- a/references/classification/README.md
+++ b/references/classification/README.md
@@ -236,6 +236,20 @@ torchrun --nproc_per_node=8 train.py\
 Note that `--val-resize-size` was optimized in a post-training step, see their `Weights` entry for the exact value.
 
+### ShuffleNet V2
+```
+torchrun --nproc_per_node=8 train.py --model=$MODEL \
+--batch-size=128 \
+--lr=0.5 --lr-scheduler=cosineannealinglr --lr-warmup-epochs=5 --lr-warmup-method=linear \
+--auto-augment=ta_wide --epochs=600 --random-erase=0.1 --weight-decay=0.00002 \
+--norm-weight-decay=0.0 --label-smoothing=0.1 --mixup-alpha=0.2 --cutmix-alpha=1.0 \
+--train-crop-size=176 --model-ema --val-resize-size=232 --ra-sampler --ra-reps=4
+```
+Here `$MODEL` is either `shufflenet_v2_x1_5` or `shufflenet_v2_x2_0`.
+
+The models `shufflenet_v2_x0_5` and `shufflenet_v2_x1_0` were contributed by the community. See [PR-849](https://github.com/pytorch/vision/pull/849#issuecomment-483391686) for details.
+
+
 ## Mixed precision training
 
 Automatic Mixed Precision (AMP) training on GPU for Pytorch can be enabled with the [torch.cuda.amp](https://pytorch.org/docs/stable/amp.html?highlight=amp#module-torch.cuda.amp).
 
@@ -263,6 +277,21 @@ python train_quantization.py --device='cpu' --post-training-quantize --backend='
 ```
 Here `$MODEL` is one of `googlenet`, `inception_v3`, `resnet18`, `resnet50`, `resnext101_32x8d`, `shufflenet_v2_x0_5` and `shufflenet_v2_x1_0`.
 
+### Quantized ShuffleNet V2
+
+Here are the commands that we use to quantize the `shufflenet_v2_x1_5` and `shufflenet_v2_x2_0` models.
+```
+# For shufflenet_v2_x1_5
+python train_quantization.py --device='cpu' --post-training-quantize --backend='fbgemm' \
+    --model=shufflenet_v2_x1_5 --weights="ShuffleNet_V2_X1_5_Weights.IMAGENET1K_V1" \
+    --train-crop-size 176 --val-resize-size 232 --data-path /datasets01_ontap/imagenet_full_size/061417/
+
+# For shufflenet_v2_x2_0
+python train_quantization.py --device='cpu' --post-training-quantize --backend='fbgemm' \
+    --model=shufflenet_v2_x2_0 --weights="ShuffleNet_V2_X2_0_Weights.IMAGENET1K_V1" \
+    --train-crop-size 176 --val-resize-size 232 --data-path /datasets01_ontap/imagenet_full_size/061417/
+```
+
 ### QAT MobileNetV2
 
 For Mobilenet-v2, the model was trained with quantization aware training, the settings used are:
diff --git a/torchvision/models/quantization/shufflenetv2.py b/torchvision/models/quantization/shufflenetv2.py
index 44e7772aa37..cfbdc5b0c24 100644
--- a/torchvision/models/quantization/shufflenetv2.py
+++ b/torchvision/models/quantization/shufflenetv2.py
@@ -10,7 +10,12 @@
 from .._api import WeightsEnum, Weights
 from .._meta import _IMAGENET_CATEGORIES
 from .._utils import handle_legacy_interface, _ovewrite_named_param
-from ..shufflenetv2 import ShuffleNet_V2_X0_5_Weights, ShuffleNet_V2_X1_0_Weights
+from ..shufflenetv2 import (
+    ShuffleNet_V2_X0_5_Weights,
+    ShuffleNet_V2_X1_0_Weights,
+    ShuffleNet_V2_X1_5_Weights,
+    ShuffleNet_V2_X2_0_Weights,
+)
 from .utils import _fuse_modules, _replace_relu, quantize_model
 
 
@@ -18,8 +23,12 @@
     "QuantizableShuffleNetV2",
     "ShuffleNet_V2_X0_5_QuantizedWeights",
     "ShuffleNet_V2_X1_0_QuantizedWeights",
+    "ShuffleNet_V2_X1_5_QuantizedWeights",
+    "ShuffleNet_V2_X2_0_QuantizedWeights",
     "shufflenet_v2_x0_5",
     "shufflenet_v2_x1_0",
+    "shufflenet_v2_x1_5",
+    "shufflenet_v2_x2_0",
 ]
 
 
@@ -143,6 +152,42 @@ class ShuffleNet_V2_X1_0_QuantizedWeights(WeightsEnum):
     DEFAULT = IMAGENET1K_FBGEMM_V1
 
 
+class ShuffleNet_V2_X1_5_QuantizedWeights(WeightsEnum):
+    IMAGENET1K_FBGEMM_V1 = Weights(
+        url="https://download.pytorch.org/models/quantized/shufflenetv2_x1_5_fbgemm-d7401f05.pth",
+        transforms=partial(ImageClassification, crop_size=224, resize_size=232),
+        meta={
+            **_COMMON_META,
+            "recipe": "https://github.com/pytorch/vision/pull/5906",
+            "num_params": 3503624,
+            "unquantized": ShuffleNet_V2_X1_5_Weights.IMAGENET1K_V1,
+            "metrics": {
+                "acc@1": 72.052,
+                "acc@5": 90.700,
+            },
+        },
+    )
+    DEFAULT = IMAGENET1K_FBGEMM_V1
+
+
+class ShuffleNet_V2_X2_0_QuantizedWeights(WeightsEnum):
+    IMAGENET1K_FBGEMM_V1 = Weights(
+        url="https://download.pytorch.org/models/quantized/shufflenetv2_x2_0_fbgemm-5cac526c.pth",
+        transforms=partial(ImageClassification, crop_size=224, resize_size=232),
+        meta={
+            **_COMMON_META,
+            "recipe": "https://github.com/pytorch/vision/pull/5906",
+            "num_params": 7393996,
+            "unquantized": ShuffleNet_V2_X2_0_Weights.IMAGENET1K_V1,
+            "metrics": {
+                "acc@1": 75.354,
+                "acc@5": 92.488,
+            },
+        },
+    )
+    DEFAULT = IMAGENET1K_FBGEMM_V1
+
+
 @handle_legacy_interface(
     weights=(
         "pretrained",
@@ -205,3 +250,51 @@ def shufflenet_v2_x1_0(
     return _shufflenetv2(
         [4, 8, 4], [24, 116, 232, 464, 1024], weights=weights, progress=progress, quantize=quantize, **kwargs
     )
+
+
+def shufflenet_v2_x1_5(
+    *,
+    weights: Optional[Union[ShuffleNet_V2_X1_5_QuantizedWeights, ShuffleNet_V2_X1_5_Weights]] = None,
+    progress: bool = True,
+    quantize: bool = False,
+    **kwargs: Any,
+) -> QuantizableShuffleNetV2:
+    """
+    Constructs a ShuffleNetV2 with 1.5x output channels, as described in
+    `"ShuffleNet V2: Practical Guidelines for Efficient CNN Architecture Design"
+    <https://arxiv.org/abs/1807.11164>`_.
+
+    Args:
+        weights (ShuffleNet_V2_X1_5_QuantizedWeights or ShuffleNet_V2_X1_5_Weights, optional): The pretrained
+            weights for the model
+        progress (bool): If True, displays a progress bar of the download to stderr
+        quantize (bool): If True, return a quantized version of the model
+    """
+    weights = (ShuffleNet_V2_X1_5_QuantizedWeights if quantize else ShuffleNet_V2_X1_5_Weights).verify(weights)
+    return _shufflenetv2(
+        [4, 8, 4], [24, 176, 352, 704, 1024], weights=weights, progress=progress, quantize=quantize, **kwargs
+    )
+
+
+def shufflenet_v2_x2_0(
+    *,
+    weights: Optional[Union[ShuffleNet_V2_X2_0_QuantizedWeights, ShuffleNet_V2_X2_0_Weights]] = None,
+    progress: bool = True,
+    quantize: bool = False,
+    **kwargs: Any,
+) -> QuantizableShuffleNetV2:
+    """
+    Constructs a ShuffleNetV2 with 2.0x output channels, as described in
+    `"ShuffleNet V2: Practical Guidelines for Efficient CNN Architecture Design"
+    <https://arxiv.org/abs/1807.11164>`_.
+
+    Args:
+        weights (ShuffleNet_V2_X2_0_QuantizedWeights or ShuffleNet_V2_X2_0_Weights, optional): The pretrained
+            weights for the model
+        progress (bool): If True, displays a progress bar of the download to stderr
+        quantize (bool): If True, return a quantized version of the model
+    """
+    weights = (ShuffleNet_V2_X2_0_QuantizedWeights if quantize else ShuffleNet_V2_X2_0_Weights).verify(weights)
+    return _shufflenetv2(
+        [4, 8, 4], [24, 244, 488, 976, 2048], weights=weights, progress=progress, quantize=quantize, **kwargs
+    )
diff --git a/torchvision/models/shufflenetv2.py b/torchvision/models/shufflenetv2.py
index 68243e06e8f..5e58c489ee0 100644
--- a/torchvision/models/shufflenetv2.py
+++ b/torchvision/models/shufflenetv2.py
@@ -223,11 +223,37 @@ class ShuffleNet_V2_X1_0_Weights(WeightsEnum):
 
 
 class ShuffleNet_V2_X1_5_Weights(WeightsEnum):
-    pass
+    IMAGENET1K_V1 = Weights(
+        url="https://download.pytorch.org/models/shufflenetv2_x1_5-3c479a10.pth",
+        transforms=partial(ImageClassification, crop_size=224, resize_size=232),
+        meta={
+            **_COMMON_META,
+            "recipe": "https://github.com/pytorch/vision/pull/5906",
+            "num_params": 3503624,
+            "metrics": {
+                "acc@1": 72.996,
+                "acc@5": 91.086,
+            },
+        },
+    )
+    DEFAULT = IMAGENET1K_V1
 
 
 class ShuffleNet_V2_X2_0_Weights(WeightsEnum):
-    pass
+    IMAGENET1K_V1 = Weights(
+        url="https://download.pytorch.org/models/shufflenetv2_x2_0-8be3c8ee.pth",
+        transforms=partial(ImageClassification, crop_size=224, resize_size=232),
+        meta={
+            **_COMMON_META,
+            "recipe": "https://github.com/pytorch/vision/pull/5906",
+            "num_params": 7393996,
+            "metrics": {
+                "acc@1": 76.230,
+                "acc@5": 93.006,
+            },
+        },
+    )
+    DEFAULT = IMAGENET1K_V1
 
 
 @handle_legacy_interface(weights=("pretrained", ShuffleNet_V2_X0_5_Weights.IMAGENET1K_V1))
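
For reviewers who want to smoke-test the new float weights locally, here is a minimal inference sketch using only the APIs this diff adds (the `shufflenet_v2_x1_5` builder, its `Weights` enum, and the bundled `transforms()`). It is an illustration, not part of the patch: it assumes a torchvision build that already includes this diff, and `"dog.jpg"` is a placeholder image path.

```python
# Sketch: classify one image with the new ShuffleNet V2 x1.5 ImageNet weights.
import torch
from torchvision.io import read_image
from torchvision.models import shufflenet_v2_x1_5, ShuffleNet_V2_X1_5_Weights

weights = ShuffleNet_V2_X1_5_Weights.IMAGENET1K_V1
model = shufflenet_v2_x1_5(weights=weights)
model.eval()

# Each Weights entry bundles its own eval preprocessing
# (here: resize to 232, center-crop to 224, rescale and normalize).
preprocess = weights.transforms()
img = read_image("dog.jpg")  # placeholder path
batch = preprocess(img).unsqueeze(0)

with torch.inference_mode():
    logits = model(batch)

probs = logits.softmax(dim=1)
class_id = int(probs.argmax(dim=1))
print(weights.meta["categories"][class_id], float(probs[0, class_id]))
```

The same pattern works for `shufflenet_v2_x2_0` by swapping in the builder and `ShuffleNet_V2_X2_0_Weights.IMAGENET1K_V1`.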
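
And a companion sketch for the quantized variants, again under the same assumption that this patch is installed. Passing `quantize=True` routes the builder to the `*_QuantizedWeights` path; the engine is set to FBGEMM because that is the backend these weights were produced with (see the recipe above), and the random uint8 tensor is a stand-in for a real image.

```python
# Sketch: run the post-training-quantized ShuffleNet V2 x2.0 on CPU.
import torch
from torchvision.models.quantization import shufflenet_v2_x2_0, ShuffleNet_V2_X2_0_QuantizedWeights

torch.backends.quantized.engine = "fbgemm"  # x86 quantization backend used for these weights

weights = ShuffleNet_V2_X2_0_QuantizedWeights.IMAGENET1K_FBGEMM_V1
model = shufflenet_v2_x2_0(weights=weights, quantize=True)
model.eval()

# Use the bundled eval transforms so the reported accuracies are reproducible.
preprocess = weights.transforms()
dummy = torch.randint(0, 256, (3, 256, 256), dtype=torch.uint8)  # stand-in image
batch = preprocess(dummy).unsqueeze(0)

with torch.inference_mode():
    out = model(batch)

print(out.shape)  # expected: torch.Size([1, 1000])
```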