diff --git a/docs/source/utils.rst b/docs/source/utils.rst index ad2fc91c897..0ae450487e3 100644 --- a/docs/source/utils.rst +++ b/docs/source/utils.rst @@ -7,3 +7,4 @@ torchvision.utils .. autofunction:: save_image +.. autofunction:: draw_bounding_boxes \ No newline at end of file diff --git a/test/assets/fakedata/draw_boxes_util.png b/test/assets/fakedata/draw_boxes_util.png new file mode 100644 index 00000000000..e6b9286bf92 Binary files /dev/null and b/test/assets/fakedata/draw_boxes_util.png differ diff --git a/test/common_utils.py b/test/common_utils.py index c951a0e34c7..8202378bec8 100644 --- a/test/common_utils.py +++ b/test/common_utils.py @@ -9,6 +9,7 @@ import torch import errno import __main__ +import random from numbers import Number from torch._six import string_classes @@ -30,6 +31,12 @@ def get_tmp_dir(src=None, **kwargs): shutil.rmtree(tmp_dir) +def set_rng_seed(seed): + torch.manual_seed(seed) + random.seed(seed) + np.random.seed(seed) + + ACCEPT = os.getenv('EXPECTTEST_ACCEPT') TEST_WITH_SLOW = os.getenv('PYTORCH_TEST_WITH_SLOW', '0') == '1' # TEST_WITH_SLOW = True # TODO: Delete this line once there is a PYTORCH_TEST_WITH_SLOW aware CI job diff --git a/test/test_models.py b/test/test_models.py index aacb19bdb42..f374c20ab40 100644 --- a/test/test_models.py +++ b/test/test_models.py @@ -1,23 +1,15 @@ -from common_utils import TestCase, map_nested_tensor_object, freeze_rng_state +from common_utils import TestCase, map_nested_tensor_object, freeze_rng_state, set_rng_seed from collections import OrderedDict from itertools import product import functools import operator import torch import torch.nn as nn -import numpy as np from torchvision import models import unittest -import random import warnings -def set_rng_seed(seed): - torch.manual_seed(seed) - random.seed(seed) - np.random.seed(seed) - - def get_available_classification_models(): # TODO add a registration mechanism to torchvision.models return [k for k, v in models.__dict__.items() if callable(v) and k[0].lower() == k[0] and k[0] != "_"] diff --git a/test/test_utils.py b/test/test_utils.py index f1982130f75..18722fe0fb5 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -6,6 +6,7 @@ import unittest from io import BytesIO import torchvision.transforms.functional as F +from torchvision.io.image import read_image from PIL import Image @@ -79,6 +80,21 @@ def test_save_image_single_pixel_file_object(self): self.assertTrue(torch.equal(F.to_tensor(img_orig), F.to_tensor(img_bytes)), 'Pixel Image not stored in file object') + def test_draw_boxes(self): + img = torch.full((3, 100, 100), 255, dtype=torch.uint8) + boxes = torch.tensor([[0, 0, 20, 20], [0, 0, 0, 0], + [10, 15, 30, 35], [23, 35, 93, 95]], dtype=torch.float) + labels = ["a", "b", "c", "d"] + colors = ["green", "#FF00FF", (0, 255, 0), "red"] + result = utils.draw_bounding_boxes(img, boxes, labels=labels, colors=colors) + + path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets", "fakedata", "draw_boxes_util.png") + if not os.path.exists(path): + Image.fromarray(result.permute(1, 2, 0).numpy()).save(path) + + expected = read_image(path) + self.assertTrue(torch.equal(result, expected)) + if __name__ == '__main__': unittest.main() diff --git a/torchvision/utils.py b/torchvision/utils.py index d40284e09a5..65b07032408 100644 --- a/torchvision/utils.py +++ b/torchvision/utils.py @@ -1,8 +1,13 @@ from typing import Union, Optional, List, Tuple, Text, BinaryIO -import io import pathlib import torch import math +import numpy as np +from PIL import Image, ImageDraw +from PIL import ImageFont + +__all__ = ["make_grid", "save_image", "draw_bounding_boxes"] + irange = range @@ -121,10 +126,64 @@ def save_image( If a file object was used instead of a filename, this parameter should always be used. **kwargs: Other arguments are documented in ``make_grid``. """ - from PIL import Image grid = make_grid(tensor, nrow=nrow, padding=padding, pad_value=pad_value, normalize=normalize, range=range, scale_each=scale_each) # Add 0.5 after unnormalizing to [0, 255] to round to nearest integer ndarr = grid.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to('cpu', torch.uint8).numpy() im = Image.fromarray(ndarr) im.save(fp, format=format) + + +@torch.no_grad() +def draw_bounding_boxes( + image: torch.Tensor, + boxes: torch.Tensor, + labels: Optional[List[str]] = None, + colors: Optional[List[Union[str, Tuple[int, int, int]]]] = None, + width: int = 1, + font: Optional[str] = None, + font_size: int = 10 +) -> torch.Tensor: + + """ + Draws bounding boxes on given image. + The values of the input image should be uint8 between 0 and 255. + + Args: + image (Tensor): Tensor of shape (C x H x W) + bboxes (Tensor): Tensor of size (N, 4) containing bounding boxes in (xmin, ymin, xmax, ymax) format. Note that + the boxes are absolute coordinates with respect to the image. In other words: `0 <= xmin < xmax < W` and + `0 <= ymin < ymax < H`. + labels (List[str]): List containing the labels of bounding boxes. + colors (List[Union[str, Tuple[int, int, int]]]): List containing the colors of bounding boxes. The colors can + be represented as `str` or `Tuple[int, int, int]`. + width (int): Width of bounding box. + font (str): A filename containing a TrueType font. If the file is not found in this filename, the loader may + also search in other directories, such as the `fonts/` directory on Windows or `/Library/Fonts/`, + `/System/Library/Fonts/` and `~/Library/Fonts/` on macOS. + font_size (int): The requested font size in points. + """ + + if not isinstance(image, torch.Tensor): + raise TypeError(f"Tensor expected, got {type(image)}") + elif image.dtype != torch.uint8: + raise ValueError(f"Tensor uint8 expected, got {image.dtype}") + elif image.dim() != 3: + raise ValueError("Pass individual images, not batches") + + ndarr = image.permute(1, 2, 0).numpy() + img_to_draw = Image.fromarray(ndarr) + + img_boxes = boxes.to(torch.int64).tolist() + + draw = ImageDraw.Draw(img_to_draw) + + for i, bbox in enumerate(img_boxes): + color = None if colors is None else colors[i] + draw.rectangle(bbox, width=width, outline=color) + + if labels is not None: + txt_font = ImageFont.load_default() if font is None else ImageFont.truetype(font=font, size=font_size) + draw.text((bbox[0], bbox[1]), labels[i], fill=color, font=txt_font) + + return torch.from_numpy(np.array(img_to_draw)).permute(2, 0, 1)