Skip to content

Add utility to draw bounding boxes #2785

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 23 commits into from
Nov 27, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/source/utils.rst
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,4 @@ torchvision.utils

.. autofunction:: save_image

.. autofunction:: draw_bounding_boxes
Binary file added test/assets/fakedata/draw_boxes_util.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
7 changes: 7 additions & 0 deletions test/common_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import torch
import errno
import __main__
import random

from numbers import Number
from torch._six import string_classes
Expand All @@ -30,6 +31,12 @@ def get_tmp_dir(src=None, **kwargs):
shutil.rmtree(tmp_dir)


def set_rng_seed(seed):
torch.manual_seed(seed)
random.seed(seed)
np.random.seed(seed)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This change was originally done on an intermediate commit where I was producing a random image and had to fix the seed. Though I switched to non-random to reduce the size, I think it's a good idea to move this method from test_models.py to commont_utils.py, so I kept the change in this PR.



ACCEPT = os.getenv('EXPECTTEST_ACCEPT')
TEST_WITH_SLOW = os.getenv('PYTORCH_TEST_WITH_SLOW', '0') == '1'
# TEST_WITH_SLOW = True # TODO: Delete this line once there is a PYTORCH_TEST_WITH_SLOW aware CI job
Expand Down
10 changes: 1 addition & 9 deletions test/test_models.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,15 @@
from common_utils import TestCase, map_nested_tensor_object, freeze_rng_state
from common_utils import TestCase, map_nested_tensor_object, freeze_rng_state, set_rng_seed
from collections import OrderedDict
from itertools import product
import functools
import operator
import torch
import torch.nn as nn
import numpy as np
from torchvision import models
import unittest
import random
import warnings


def set_rng_seed(seed):
torch.manual_seed(seed)
random.seed(seed)
np.random.seed(seed)


def get_available_classification_models():
# TODO add a registration mechanism to torchvision.models
return [k for k, v in models.__dict__.items() if callable(v) and k[0].lower() == k[0] and k[0] != "_"]
Expand Down
16 changes: 16 additions & 0 deletions test/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import unittest
from io import BytesIO
import torchvision.transforms.functional as F
from torchvision.io.image import read_image
from PIL import Image


Expand Down Expand Up @@ -79,6 +80,21 @@ def test_save_image_single_pixel_file_object(self):
self.assertTrue(torch.equal(F.to_tensor(img_orig), F.to_tensor(img_bytes)),
'Pixel Image not stored in file object')

def test_draw_boxes(self):
img = torch.full((3, 100, 100), 255, dtype=torch.uint8)
boxes = torch.tensor([[0, 0, 20, 20], [0, 0, 0, 0],
[10, 15, 30, 35], [23, 35, 93, 95]], dtype=torch.float)
labels = ["a", "b", "c", "d"]
colors = ["green", "#FF00FF", (0, 255, 0), "red"]
result = utils.draw_bounding_boxes(img, boxes, labels=labels, colors=colors)

path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets", "fakedata", "draw_boxes_util.png")
if not os.path.exists(path):
Image.fromarray(result.permute(1, 2, 0).numpy()).save(path)
Comment on lines +92 to +93
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: any particular reason why you use PIL to save the result, and not write_image? Although this is not really important as the file is committed to the repo.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree that this is worth changing.


expected = read_image(path)
self.assertTrue(torch.equal(result, expected))


if __name__ == '__main__':
unittest.main()
63 changes: 61 additions & 2 deletions torchvision/utils.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,13 @@
from typing import Union, Optional, List, Tuple, Text, BinaryIO
import io
import pathlib
import torch
import math
import numpy as np
from PIL import Image, ImageDraw
from PIL import ImageFont

__all__ = ["make_grid", "save_image", "draw_bounding_boxes"]

irange = range


Expand Down Expand Up @@ -121,10 +126,64 @@ def save_image(
If a file object was used instead of a filename, this parameter should always be used.
**kwargs: Other arguments are documented in ``make_grid``.
"""
from PIL import Image
grid = make_grid(tensor, nrow=nrow, padding=padding, pad_value=pad_value,
normalize=normalize, range=range, scale_each=scale_each)
# Add 0.5 after unnormalizing to [0, 255] to round to nearest integer
ndarr = grid.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to('cpu', torch.uint8).numpy()
im = Image.fromarray(ndarr)
im.save(fp, format=format)


@torch.no_grad()
def draw_bounding_boxes(
image: torch.Tensor,
boxes: torch.Tensor,
labels: Optional[List[str]] = None,
colors: Optional[List[Union[str, Tuple[int, int, int]]]] = None,
width: int = 1,
font: Optional[str] = None,
font_size: int = 10
) -> torch.Tensor:

"""
Draws bounding boxes on given image.
The values of the input image should be uint8 between 0 and 255.

Args:
image (Tensor): Tensor of shape (C x H x W)
bboxes (Tensor): Tensor of size (N, 4) containing bounding boxes in (xmin, ymin, xmax, ymax) format. Note that
the boxes are absolute coordinates with respect to the image. In other words: `0 <= xmin < xmax < W` and
`0 <= ymin < ymax < H`.
labels (List[str]): List containing the labels of bounding boxes.
colors (List[Union[str, Tuple[int, int, int]]]): List containing the colors of bounding boxes. The colors can
be represented as `str` or `Tuple[int, int, int]`.
width (int): Width of bounding box.
font (str): A filename containing a TrueType font. If the file is not found in this filename, the loader may
also search in other directories, such as the `fonts/` directory on Windows or `/Library/Fonts/`,
`/System/Library/Fonts/` and `~/Library/Fonts/` on macOS.
font_size (int): The requested font size in points.
"""

if not isinstance(image, torch.Tensor):
raise TypeError(f"Tensor expected, got {type(image)}")
elif image.dtype != torch.uint8:
raise ValueError(f"Tensor uint8 expected, got {image.dtype}")
elif image.dim() != 3:
raise ValueError("Pass individual images, not batches")

ndarr = image.permute(1, 2, 0).numpy()
img_to_draw = Image.fromarray(ndarr)

img_boxes = boxes.to(torch.int64).tolist()

draw = ImageDraw.Draw(img_to_draw)

for i, bbox in enumerate(img_boxes):
color = None if colors is None else colors[i]
draw.rectangle(bbox, width=width, outline=color)

if labels is not None:
txt_font = ImageFont.load_default() if font is None else ImageFont.truetype(font=font, size=font_size)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit for a follow-up PR: we can move this to outside of the for loop

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Agreed, this can move outside of the loop.

draw.text((bbox[0], bbox[1]), labels[i], fill=color, font=txt_font)

return torch.from_numpy(np.array(img_to_draw)).permute(2, 0, 1)