diff --git a/torchvision/models/detection/retinanet.py b/torchvision/models/detection/retinanet.py index 39b7edc8c13..d277f130ad3 100644 --- a/torchvision/models/detection/retinanet.py +++ b/torchvision/models/detection/retinanet.py @@ -672,17 +672,22 @@ def forward(self, images, targets=None): return self.eager_outputs(losses, detections) +_COMMON_META = { + "task": "image_object_detection", + "architecture": "RetinaNet", + "categories": _COCO_CATEGORIES, + "interpolation": InterpolationMode.BILINEAR, +} + + class RetinaNet_ResNet50_FPN_Weights(WeightsEnum): COCO_V1 = Weights( url="https://download.pytorch.org/models/retinanet_resnet50_fpn_coco-eeacb38b.pth", transforms=ObjectDetection, meta={ - "task": "image_object_detection", - "architecture": "RetinaNet", + **_COMMON_META, "publication_year": 2017, "num_params": 34014999, - "categories": _COCO_CATEGORIES, - "interpolation": InterpolationMode.BILINEAR, "recipe": "https://github.com/pytorch/vision/tree/main/references/detection#retinanet", "map": 36.4, }, @@ -691,7 +696,18 @@ class RetinaNet_ResNet50_FPN_Weights(WeightsEnum): class RetinaNet_ResNet50_FPN_V2_Weights(WeightsEnum): - pass + COCO_V1 = Weights( + url="https://download.pytorch.org/models/retinanet_resnet50_fpn_v2_coco-5905b1c5.pth", + transforms=ObjectDetection, + meta={ + **_COMMON_META, + "publication_year": 2019, + "num_params": 38198935, + "recipe": "https://github.com/pytorch/vision/pull/5756", + "map": 41.5, + }, + ) + DEFAULT = COCO_V1 @handle_legacy_interface(