From 62dc50a30cd9ae8679668de23f06f0abdccfa5db Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Mon, 27 Sep 2021 12:56:45 +0100 Subject: [PATCH 1/6] Reuse EfficientNet SE layer. --- torchvision/models/mobilenetv3.py | 21 ++++--------------- .../models/quantization/mobilenetv3.py | 2 +- 2 files changed, 5 insertions(+), 18 deletions(-) diff --git a/torchvision/models/mobilenetv3.py b/torchvision/models/mobilenetv3.py index ebe3f510a49..0836ae4a7a2 100644 --- a/torchvision/models/mobilenetv3.py +++ b/torchvision/models/mobilenetv3.py @@ -6,6 +6,7 @@ from typing import Any, Callable, Dict, List, Optional, Sequence from .._internally_replaced_utils import load_state_dict_from_url +from torchvision.models.efficientnet import SqueezeExcitation as SElayer from torchvision.models.mobilenetv2 import _make_divisible, ConvBNActivation @@ -18,25 +19,11 @@ } -class SqueezeExcitation(nn.Module): - # Implemented as described at Figure 4 of the MobileNetV3 paper +class SqueezeExcitation(SElayer): def __init__(self, input_channels: int, squeeze_factor: int = 4): - super().__init__() squeeze_channels = _make_divisible(input_channels // squeeze_factor, 8) - self.fc1 = nn.Conv2d(input_channels, squeeze_channels, 1) - self.relu = nn.ReLU(inplace=True) - self.fc2 = nn.Conv2d(squeeze_channels, input_channels, 1) - - def _scale(self, input: Tensor, inplace: bool) -> Tensor: - scale = F.adaptive_avg_pool2d(input, 1) - scale = self.fc1(scale) - scale = self.relu(scale) - scale = self.fc2(scale) - return F.hardsigmoid(scale, inplace=inplace) - - def forward(self, input: Tensor) -> Tensor: - scale = self._scale(input, True) - return scale * input + super().__init__(input_channels, squeeze_channels, activation=nn.ReLU, scale_activation=nn.Hardsigmoid) + self.relu = self.activation class InvertedResidualConfig: diff --git a/torchvision/models/quantization/mobilenetv3.py b/torchvision/models/quantization/mobilenetv3.py index a1aa9d7d4bd..348fc502510 100644 --- a/torchvision/models/quantization/mobilenetv3.py +++ b/torchvision/models/quantization/mobilenetv3.py @@ -22,7 +22,7 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: self.skip_mul = nn.quantized.FloatFunctional() def forward(self, input: Tensor) -> Tensor: - return self.skip_mul.mul(self._scale(input, False), input) + return self.skip_mul.mul(self._scale(input), input) def fuse_model(self) -> None: fuse_modules(self, ['fc1', 'relu'], inplace=True) From 72cecb189b6590e506985e53838db1c28636ca9f Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Mon, 27 Sep 2021 14:49:58 +0100 Subject: [PATCH 2/6] Deprecating the mobilenetv3.SqueezeExcitation layer. 
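
The class now exists only for backward compatibility; new code should build the shared
EfficientNet SE block directly, as the InvertedResidual default does below. A minimal
migration sketch (the channel counts here are illustrative, not part of this patch):

    from torch import nn
    from torchvision.models.efficientnet import SqueezeExcitation as SElayer
    from torchvision.models.mobilenetv2 import _make_divisible

    in_ch = 96                                   # hypothetical expanded-channel count
    squeeze_ch = _make_divisible(in_ch // 4, 8)  # same rounding the blocks use

    # Before (now deprecated, emits a FutureWarning):
    #   from torchvision.models.mobilenetv3 import SqueezeExcitation
    #   se = SqueezeExcitation(in_ch)

    # After: pass the squeeze width explicitly and pick Hardsigmoid as the gate.
    se = SElayer(in_ch, squeeze_ch, scale_activation=nn.Hardsigmoid)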
--- torchvision/models/mobilenetv3.py | 19 ++++++++++++------- .../models/quantization/mobilenetv3.py | 9 +++++---- 2 files changed, 17 insertions(+), 11 deletions(-) diff --git a/torchvision/models/mobilenetv3.py b/torchvision/models/mobilenetv3.py index 0836ae4a7a2..11dd36c804b 100644 --- a/torchvision/models/mobilenetv3.py +++ b/torchvision/models/mobilenetv3.py @@ -1,13 +1,13 @@ +import warnings import torch from functools import partial from torch import nn, Tensor -from torch.nn import functional as F -from typing import Any, Callable, Dict, List, Optional, Sequence +from typing import Any, Callable, List, Optional, Sequence from .._internally_replaced_utils import load_state_dict_from_url -from torchvision.models.efficientnet import SqueezeExcitation as SElayer -from torchvision.models.mobilenetv2 import _make_divisible, ConvBNActivation +from .efficientnet import SqueezeExcitation as SElayer +from .mobilenetv2 import _make_divisible, ConvBNActivation __all__ = ["MobileNetV3", "mobilenet_v3_large", "mobilenet_v3_small"] @@ -20,10 +20,14 @@ class SqueezeExcitation(SElayer): + """DEPRECATED + """ def __init__(self, input_channels: int, squeeze_factor: int = 4): squeeze_channels = _make_divisible(input_channels // squeeze_factor, 8) - super().__init__(input_channels, squeeze_channels, activation=nn.ReLU, scale_activation=nn.Hardsigmoid) + super().__init__(input_channels, squeeze_channels, scale_activation=nn.Hardsigmoid) self.relu = self.activation + warnings.warn( + "This SqueezeExcitation class is deprecated and will be removed in future versions.", FutureWarning) class InvertedResidualConfig: @@ -47,7 +51,7 @@ def adjust_channels(channels: int, width_mult: float): class InvertedResidual(nn.Module): # Implemented as described at section 5 of MobileNetV3 paper def __init__(self, cnf: InvertedResidualConfig, norm_layer: Callable[..., nn.Module], - se_layer: Callable[..., nn.Module] = SqueezeExcitation): + se_layer: Callable[..., nn.Module] = partial(SElayer, scale_activation=nn.Hardsigmoid)): super().__init__() if not (1 <= cnf.stride <= 2): raise ValueError('illegal stride value') @@ -68,7 +72,8 @@ def __init__(self, cnf: InvertedResidualConfig, norm_layer: Callable[..., nn.Mod stride=stride, dilation=cnf.dilation, groups=cnf.expanded_channels, norm_layer=norm_layer, activation_layer=activation_layer)) if cnf.use_se: - layers.append(se_layer(cnf.expanded_channels)) + squeeze_channels = _make_divisible(cnf.expanded_channels // 4, 8) + layers.append(se_layer(cnf.expanded_channels, squeeze_channels)) # project layers.append(ConvBNActivation(cnf.expanded_channels, cnf.out_channels, kernel_size=1, norm_layer=norm_layer, diff --git a/torchvision/models/quantization/mobilenetv3.py b/torchvision/models/quantization/mobilenetv3.py index 348fc502510..dcff2fc1f6c 100644 --- a/torchvision/models/quantization/mobilenetv3.py +++ b/torchvision/models/quantization/mobilenetv3.py @@ -1,8 +1,9 @@ import torch from torch import nn, Tensor from ..._internally_replaced_utils import load_state_dict_from_url -from torchvision.models.mobilenetv3 import InvertedResidual, InvertedResidualConfig, ConvBNActivation, MobileNetV3,\ - SqueezeExcitation, model_urls, _mobilenet_v3_conf +from ..efficientnet import SqueezeExcitation as SElayer +from ..mobilenetv3 import InvertedResidual, InvertedResidualConfig, ConvBNActivation, MobileNetV3,\ + model_urls, _mobilenet_v3_conf from torch.quantization import QuantStub, DeQuantStub, fuse_modules from typing import Any, List, Optional from .utils import _replace_relu @@ 
-16,7 +17,7 @@ } -class QuantizableSqueezeExcitation(SqueezeExcitation): +class QuantizableSqueezeExcitation(SElayer): def __init__(self, *args: Any, **kwargs: Any) -> None: super().__init__(*args, **kwargs) self.skip_mul = nn.quantized.FloatFunctional() @@ -25,7 +26,7 @@ def forward(self, input: Tensor) -> Tensor: return self.skip_mul.mul(self._scale(input), input) def fuse_model(self) -> None: - fuse_modules(self, ['fc1', 'relu'], inplace=True) + fuse_modules(self, ['fc1', 'activation'], inplace=True) class QuantizableInvertedResidual(InvertedResidual): From eedb93940742794d27632fab0e032b2666e751ea Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Mon, 27 Sep 2021 21:44:42 +0100 Subject: [PATCH 3/6] Passing the right activation on quantization. --- torchvision/models/mobilenetv3.py | 1 + torchvision/models/quantization/mobilenetv3.py | 8 +++++--- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/torchvision/models/mobilenetv3.py b/torchvision/models/mobilenetv3.py index 11dd36c804b..0485c9d61e5 100644 --- a/torchvision/models/mobilenetv3.py +++ b/torchvision/models/mobilenetv3.py @@ -26,6 +26,7 @@ def __init__(self, input_channels: int, squeeze_factor: int = 4): squeeze_channels = _make_divisible(input_channels // squeeze_factor, 8) super().__init__(input_channels, squeeze_channels, scale_activation=nn.Hardsigmoid) self.relu = self.activation + delattr(self, 'activation') warnings.warn( "This SqueezeExcitation class is deprecated and will be removed in future versions.", FutureWarning) diff --git a/torchvision/models/quantization/mobilenetv3.py b/torchvision/models/quantization/mobilenetv3.py index dcff2fc1f6c..8840747222a 100644 --- a/torchvision/models/quantization/mobilenetv3.py +++ b/torchvision/models/quantization/mobilenetv3.py @@ -19,6 +19,7 @@ class QuantizableSqueezeExcitation(SElayer): def __init__(self, *args: Any, **kwargs: Any) -> None: + kwargs["scale_activation"] = nn.Hardswish super().__init__(*args, **kwargs) self.skip_mul = nn.quantized.FloatFunctional() @@ -80,11 +81,12 @@ def _load_weights( model: QuantizableMobileNetV3, model_url: Optional[str], progress: bool, + strict: bool ) -> None: if model_url is None: raise ValueError("No checkpoint is available for {}".format(arch)) state_dict = load_state_dict_from_url(model_url, progress=progress) - model.load_state_dict(state_dict) + model.load_state_dict(state_dict, strict=strict) def _mobilenet_v3_model( @@ -108,13 +110,13 @@ def _mobilenet_v3_model( torch.quantization.prepare_qat(model, inplace=True) if pretrained: - _load_weights(arch, model, quant_model_urls.get(arch + '_' + backend, None), progress) + _load_weights(arch, model, quant_model_urls.get(arch + '_' + backend, None), progress, False) torch.quantization.convert(model, inplace=True) model.eval() else: if pretrained: - _load_weights(arch, model, model_urls.get(arch, None), progress) + _load_weights(arch, model, model_urls.get(arch, None), progress, True) return model From b396443fb365169a74d893d65a3912372db33c7d Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Tue, 28 Sep 2021 19:39:10 +0100 Subject: [PATCH 4/6] Making strict named param. 
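
Call sites now spell out strict=... only where the non-default is needed. For context, a
minimal sketch of what the flag controls downstream (the toy module is hypothetical, not
part of this patch):

    import torch
    from torch import nn

    model = nn.Sequential(nn.Conv2d(3, 8, 1), nn.ReLU())
    state_dict = model.state_dict()
    del state_dict['0.bias']          # simulate an older checkpoint missing an entry

    # strict=True (the default) raises a RuntimeError for the missing key;
    # strict=False records it and loads whatever does match.
    result = model.load_state_dict(state_dict, strict=False)
    print(result.missing_keys)        # ['0.bias']

This is why the QAT path can load pre-existing quantized checkpoints with strict=False,
while the regular pretrained path keeps the default strict=True.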
--- torchvision/models/quantization/mobilenetv3.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/torchvision/models/quantization/mobilenetv3.py b/torchvision/models/quantization/mobilenetv3.py index 8840747222a..64001146e97 100644 --- a/torchvision/models/quantization/mobilenetv3.py +++ b/torchvision/models/quantization/mobilenetv3.py @@ -81,7 +81,7 @@ def _load_weights( model: QuantizableMobileNetV3, model_url: Optional[str], progress: bool, - strict: bool + strict: bool = True ) -> None: if model_url is None: raise ValueError("No checkpoint is available for {}".format(arch)) @@ -110,13 +110,13 @@ def _mobilenet_v3_model( torch.quantization.prepare_qat(model, inplace=True) if pretrained: - _load_weights(arch, model, quant_model_urls.get(arch + '_' + backend, None), progress, False) + _load_weights(arch, model, quant_model_urls.get(arch + '_' + backend, None), progress, strict=False) torch.quantization.convert(model, inplace=True) model.eval() else: if pretrained: - _load_weights(arch, model, model_urls.get(arch, None), progress, True) + _load_weights(arch, model, model_urls.get(arch, None), progress) return model From b491fa26fd7638fed05665fcf40a51f42734122a Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Wed, 29 Sep 2021 13:08:05 +0100 Subject: [PATCH 5/6] Set default params if missing. --- .../models/quantization/mobilenetv3.py | 45 +++++++++++++++++-- 1 file changed, 41 insertions(+), 4 deletions(-) diff --git a/torchvision/models/quantization/mobilenetv3.py b/torchvision/models/quantization/mobilenetv3.py index 64001146e97..4a5282c92cd 100644 --- a/torchvision/models/quantization/mobilenetv3.py +++ b/torchvision/models/quantization/mobilenetv3.py @@ -18,6 +18,8 @@ class QuantizableSqueezeExcitation(SElayer): + _version = 2 + def __init__(self, *args: Any, **kwargs: Any) -> None: kwargs["scale_activation"] = nn.Hardswish super().__init__(*args, **kwargs) @@ -29,6 +31,42 @@ def forward(self, input: Tensor) -> Tensor: def fuse_model(self) -> None: fuse_modules(self, ['fc1', 'activation'], inplace=True) + def _load_from_state_dict( + self, + state_dict, + prefix, + local_metadata, + strict, + missing_keys, + unexpected_keys, + error_msgs, + ): + version = local_metadata.get("version", None) + + if version is None or version < 2: + default_state_dict = { + "scale_activation.activation_post_process.scale": torch.tensor([1.]), + "scale_activation.activation_post_process.zero_point": torch.tensor([0], dtype=torch.int32), + "scale_activation.activation_post_process.fake_quant_enabled": torch.tensor([1]), + "scale_activation.activation_post_process.observer_enabled": torch.tensor([1]), + "scale_activation.activation_post_process.activation_post_process.min_val": torch.tensor(float('inf')), + "scale_activation.activation_post_process.activation_post_process.max_val": torch.tensor(-float('inf')), + } + for k, v in default_state_dict.items(): + full_key = prefix + k + if full_key not in state_dict: + state_dict[full_key] = v + + super()._load_from_state_dict( + state_dict, + prefix, + local_metadata, + strict, + missing_keys, + unexpected_keys, + error_msgs, + ) + class QuantizableInvertedResidual(InvertedResidual): # TODO https://github.com/pytorch/vision/pull/4232#pullrequestreview-730461659 @@ -80,13 +118,12 @@ def _load_weights( arch: str, model: QuantizableMobileNetV3, model_url: Optional[str], - progress: bool, - strict: bool = True + progress: bool ) -> None: if model_url is None: raise ValueError("No checkpoint is available for 
{}".format(arch)) state_dict = load_state_dict_from_url(model_url, progress=progress) - model.load_state_dict(state_dict, strict=strict) + model.load_state_dict(state_dict) def _mobilenet_v3_model( @@ -110,7 +147,7 @@ def _mobilenet_v3_model( torch.quantization.prepare_qat(model, inplace=True) if pretrained: - _load_weights(arch, model, quant_model_urls.get(arch + '_' + backend, None), progress, strict=False) + _load_weights(arch, model, quant_model_urls.get(arch + '_' + backend, None), progress) torch.quantization.convert(model, inplace=True) model.eval() From 24ce2bdb8e7e50d92032a8633dc75b4010a25c3a Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Wed, 29 Sep 2021 14:42:00 +0100 Subject: [PATCH 6/6] Fixing typos. --- torchvision/models/quantization/mobilenetv3.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/torchvision/models/quantization/mobilenetv3.py b/torchvision/models/quantization/mobilenetv3.py index 4a5282c92cd..8c64f137053 100644 --- a/torchvision/models/quantization/mobilenetv3.py +++ b/torchvision/models/quantization/mobilenetv3.py @@ -21,7 +21,7 @@ class QuantizableSqueezeExcitation(SElayer): _version = 2 def __init__(self, *args: Any, **kwargs: Any) -> None: - kwargs["scale_activation"] = nn.Hardswish + kwargs["scale_activation"] = nn.Hardsigmoid super().__init__(*args, **kwargs) self.skip_mul = nn.quantized.FloatFunctional() @@ -49,8 +49,6 @@ def _load_from_state_dict( "scale_activation.activation_post_process.zero_point": torch.tensor([0], dtype=torch.int32), "scale_activation.activation_post_process.fake_quant_enabled": torch.tensor([1]), "scale_activation.activation_post_process.observer_enabled": torch.tensor([1]), - "scale_activation.activation_post_process.activation_post_process.min_val": torch.tensor(float('inf')), - "scale_activation.activation_post_process.activation_post_process.max_val": torch.tensor(-float('inf')), } for k, v in default_state_dict.items(): full_key = prefix + k