diff --git a/torchvision/models/mobilenetv3.py b/torchvision/models/mobilenetv3.py
index ebe3f510a49..0485c9d61e5 100644
--- a/torchvision/models/mobilenetv3.py
+++ b/torchvision/models/mobilenetv3.py
@@ -1,12 +1,13 @@
+import warnings
 import torch
 
 from functools import partial
 from torch import nn, Tensor
-from torch.nn import functional as F
-from typing import Any, Callable, Dict, List, Optional, Sequence
+from typing import Any, Callable, List, Optional, Sequence
 
 from .._internally_replaced_utils import load_state_dict_from_url
-from torchvision.models.mobilenetv2 import _make_divisible, ConvBNActivation
+from .efficientnet import SqueezeExcitation as SElayer
+from .mobilenetv2 import _make_divisible, ConvBNActivation
 
 
 __all__ = ["MobileNetV3", "mobilenet_v3_large", "mobilenet_v3_small"]
@@ -18,25 +19,16 @@
 }
 
 
-class SqueezeExcitation(nn.Module):
-    # Implemented as described at Figure 4 of the MobileNetV3 paper
+class SqueezeExcitation(SElayer):
+    """DEPRECATED
+    """
     def __init__(self, input_channels: int, squeeze_factor: int = 4):
-        super().__init__()
         squeeze_channels = _make_divisible(input_channels // squeeze_factor, 8)
-        self.fc1 = nn.Conv2d(input_channels, squeeze_channels, 1)
-        self.relu = nn.ReLU(inplace=True)
-        self.fc2 = nn.Conv2d(squeeze_channels, input_channels, 1)
-
-    def _scale(self, input: Tensor, inplace: bool) -> Tensor:
-        scale = F.adaptive_avg_pool2d(input, 1)
-        scale = self.fc1(scale)
-        scale = self.relu(scale)
-        scale = self.fc2(scale)
-        return F.hardsigmoid(scale, inplace=inplace)
-
-    def forward(self, input: Tensor) -> Tensor:
-        scale = self._scale(input, True)
-        return scale * input
+        super().__init__(input_channels, squeeze_channels, scale_activation=nn.Hardsigmoid)
+        self.relu = self.activation
+        delattr(self, 'activation')
+        warnings.warn(
+            "This SqueezeExcitation class is deprecated and will be removed in future versions.", FutureWarning)
 
 
 class InvertedResidualConfig:
@@ -60,7 +52,7 @@ def adjust_channels(channels: int, width_mult: float):
 class InvertedResidual(nn.Module):
     # Implemented as described at section 5 of MobileNetV3 paper
     def __init__(self, cnf: InvertedResidualConfig, norm_layer: Callable[..., nn.Module],
-                 se_layer: Callable[..., nn.Module] = SqueezeExcitation):
+                 se_layer: Callable[..., nn.Module] = partial(SElayer, scale_activation=nn.Hardsigmoid)):
         super().__init__()
         if not (1 <= cnf.stride <= 2):
             raise ValueError('illegal stride value')
@@ -81,7 +73,8 @@ def __init__(self, cnf: InvertedResidualConfig, norm_layer: Callable[..., nn.Mod
                                            stride=stride, dilation=cnf.dilation, groups=cnf.expanded_channels,
                                            norm_layer=norm_layer, activation_layer=activation_layer))
         if cnf.use_se:
-            layers.append(se_layer(cnf.expanded_channels))
+            squeeze_channels = _make_divisible(cnf.expanded_channels // 4, 8)
+            layers.append(se_layer(cnf.expanded_channels, squeeze_channels))
 
         # project
         layers.append(ConvBNActivation(cnf.expanded_channels, cnf.out_channels, kernel_size=1, norm_layer=norm_layer,
diff --git a/torchvision/models/quantization/mobilenetv3.py b/torchvision/models/quantization/mobilenetv3.py
index a1aa9d7d4bd..8c64f137053 100644
--- a/torchvision/models/quantization/mobilenetv3.py
+++ b/torchvision/models/quantization/mobilenetv3.py
@@ -1,8 +1,9 @@
 import torch
 from torch import nn, Tensor
 from ..._internally_replaced_utils import load_state_dict_from_url
-from torchvision.models.mobilenetv3 import InvertedResidual, InvertedResidualConfig, ConvBNActivation, MobileNetV3,\
-    SqueezeExcitation, model_urls, _mobilenet_v3_conf
+from ..efficientnet import SqueezeExcitation as SElayer
+from ..mobilenetv3 import InvertedResidual, InvertedResidualConfig, ConvBNActivation, MobileNetV3,\
+    model_urls, _mobilenet_v3_conf
 from torch.quantization import QuantStub, DeQuantStub, fuse_modules
 from typing import Any, List, Optional
 from .utils import _replace_relu
@@ -16,16 +17,53 @@
 }
 
 
-class QuantizableSqueezeExcitation(SqueezeExcitation):
+class QuantizableSqueezeExcitation(SElayer):
+    _version = 2
+
     def __init__(self, *args: Any, **kwargs: Any) -> None:
+        kwargs["scale_activation"] = nn.Hardsigmoid
         super().__init__(*args, **kwargs)
         self.skip_mul = nn.quantized.FloatFunctional()
 
     def forward(self, input: Tensor) -> Tensor:
-        return self.skip_mul.mul(self._scale(input, False), input)
+        return self.skip_mul.mul(self._scale(input), input)
 
     def fuse_model(self) -> None:
-        fuse_modules(self, ['fc1', 'relu'], inplace=True)
+        fuse_modules(self, ['fc1', 'activation'], inplace=True)
+
+    def _load_from_state_dict(
+        self,
+        state_dict,
+        prefix,
+        local_metadata,
+        strict,
+        missing_keys,
+        unexpected_keys,
+        error_msgs,
+    ):
+        version = local_metadata.get("version", None)
+
+        if version is None or version < 2:
+            default_state_dict = {
+                "scale_activation.activation_post_process.scale": torch.tensor([1.]),
+                "scale_activation.activation_post_process.zero_point": torch.tensor([0], dtype=torch.int32),
+                "scale_activation.activation_post_process.fake_quant_enabled": torch.tensor([1]),
+                "scale_activation.activation_post_process.observer_enabled": torch.tensor([1]),
+            }
+            for k, v in default_state_dict.items():
+                full_key = prefix + k
+                if full_key not in state_dict:
+                    state_dict[full_key] = v
+
+        super()._load_from_state_dict(
+            state_dict,
+            prefix,
+            local_metadata,
+            strict,
+            missing_keys,
+            unexpected_keys,
+            error_msgs,
+        )
 
 
 class QuantizableInvertedResidual(InvertedResidual):
@@ -78,7 +116,7 @@ def _load_weights(
     arch: str,
     model: QuantizableMobileNetV3,
     model_url: Optional[str],
-    progress: bool,
+    progress: bool
 ) -> None:
     if model_url is None:
         raise ValueError("No checkpoint is available for {}".format(arch))
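
Two short sketches for reviewers, illustrating the behavior this diff relies on. Both are illustrative only: anything not shown in the diff itself (example channel widths, the `Block`/`gain` names below) is an assumption.

First, the call-site change in `InvertedResidual`. The deprecated `SqueezeExcitation(input_channels)` derived `squeeze_channels` internally from a hard-coded `squeeze_factor=4`; the shared `efficientnet.SqueezeExcitation` takes `squeeze_channels` explicitly, so the caller now computes it with the same `_make_divisible` formula. A minimal sketch, assuming the efficientnet signature used in the diff (`SqueezeExcitation(input_channels, squeeze_channels, ..., scale_activation=...)`):

```python
import torch
from torch import nn
from torchvision.models.efficientnet import SqueezeExcitation as SElayer
from torchvision.models.mobilenetv2 import _make_divisible

expanded_channels = 96  # arbitrary example width, not from the diff

# New call site: the caller computes squeeze_channels, reproducing what the
# deprecated wrapper used to do internally with squeeze_factor=4.
squeeze_channels = _make_divisible(expanded_channels // 4, 8)
se = SElayer(expanded_channels, squeeze_channels, scale_activation=nn.Hardsigmoid)

x = torch.randn(1, expanded_channels, 14, 14)
assert se(x).shape == x.shape  # SE rescales channels; spatial shape is unchanged
```

Second, the `_version = 2` bump plus the `_load_from_state_dict` override keeps pre-existing quantized checkpoints loadable: checkpoints saved before this change carry `version < 2` (or no version) in their metadata and lack the `scale_activation` observer keys, so the override back-fills defaults before delegating to the base implementation. A generic sketch of that PyTorch mechanism, with hypothetical names (`Block`, `gain`), not the torchvision code:

```python
import torch
from torch import nn

class Block(nn.Module):
    _version = 2  # recorded into the state-dict metadata on save

    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 4)
        self.register_buffer("gain", torch.ones(1))  # buffer added in version 2

    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
                              missing_keys, unexpected_keys, error_msgs):
        version = local_metadata.get("version", None)
        if version is None or version < 2:
            # Old checkpoints predate the new buffer: back-fill a default so
            # strict loading does not report a missing key.
            state_dict.setdefault(prefix + "gain", torch.ones(1))
        super()._load_from_state_dict(state_dict, prefix, local_metadata, strict,
                                      missing_keys, unexpected_keys, error_msgs)
```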