Skip to content

Add utility to draw bounding boxes #2785

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 23 commits into from
Nov 27, 2020
Merged
Show file tree
Hide file tree
Changes from 20 commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/source/utils.rst
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,4 @@ torchvision.utils

.. autofunction:: save_image

.. autofunction:: draw_bounding_boxes
11 changes: 10 additions & 1 deletion test/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import unittest
from io import BytesIO
import torchvision.transforms.functional as F
from PIL import Image
from PIL import Image, ImageDraw


class Tester(unittest.TestCase):
Expand Down Expand Up @@ -79,6 +79,15 @@ def test_save_image_single_pixel_file_object(self):
self.assertTrue(torch.equal(F.to_tensor(img_orig), F.to_tensor(img_bytes)),
'Pixel Image not stored in file object')

def test_draw_boxes(self):
img = torch.randint(0, 255, (3, 226, 226), dtype=torch.uint8)
boxes = torch.tensor([[0, 0, 100, 100], [0, 0, 0, 0],
[10, 15, 30, 35], [23, 35, 93, 95]], dtype=torch.float)
labels = ['a', 'b', 'c', 'd']
utils.draw_bounding_boxes(img, boxes, labels=labels)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can you also add a test checking that the color of the output image at pixel values out[:, 0, 0:100] == fillcolor etc, so that we know that we are masking the correct pixels in the image?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree we should test pixels but I would rather test all functionalities including labels, fonts etc. I wonder if that's possible or if it will crate a flaky test due to differences on fonts across platforms. I'll give a try to test what I proposed on the earlier comment and see if this works.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

See latest code for the proposed approach of testing.


return True


if __name__ == '__main__':
unittest.main()
59 changes: 57 additions & 2 deletions torchvision/utils.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,13 @@
from typing import Union, Optional, List, Tuple, Text, BinaryIO
import io
import pathlib
import torch
import math
import numpy as np
from PIL import Image, ImageDraw
from PIL.ImageFont import ImageFont

__all__ = ["make_grid", "save_image", "draw_bounding_boxes"]

irange = range


Expand Down Expand Up @@ -121,10 +126,60 @@ def save_image(
If a file object was used instead of a filename, this parameter should always be used.
**kwargs: Other arguments are documented in ``make_grid``.
"""
from PIL import Image
grid = make_grid(tensor, nrow=nrow, padding=padding, pad_value=pad_value,
normalize=normalize, range=range, scale_each=scale_each)
# Add 0.5 after unnormalizing to [0, 255] to round to nearest integer
ndarr = grid.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to('cpu', torch.uint8).numpy()
im = Image.fromarray(ndarr)
im.save(fp, format=format)


def draw_bounding_boxes(
image: torch.Tensor,
boxes: torch.Tensor,
colors: Optional[List[str]] = None,
labels: Optional[List[str]] = None,
width: int = 1,
font: Optional[ImageFont] = None
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Given that we won't be using this function in torchscript, I'm ok having the input type of the function to be PIL-specific

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not terribly excited about this TBH:

  • On one hand the method receives a uint8 tensor as input (not a PIL image) and hides completely any dependency on PIL. I would agree with earlier comments of yours that it's a bit odd that we expose ImageFont here.
  • On the other hand, using PIL's ImageFont gives the flexibility to the user to do whatever they want without having to deal on our side with the details on how to instantiate the object. It's surely is ugly though and makes for a weird API.

I could try to create a font parameter similar to PIL with description "A filename or file-like object containing a TrueType font." and a font_size. Thoughts?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Have a look on the latest commit for an alternative to passing ImageFont. We can choose any of the two options, I'm OK with both.

) -> torch.Tensor:

"""
Draws bounding boxes on given image.
The values of the input image should be uint8 between 0 and 255.

Args:
image (Tensor): Tensor of shape (C x H x W)
bboxes (Tensor): Tensor of size (N, 4) containing bounding boxes in (xmin, ymin, xmax, ymax) format.
colors (List): List containing the colors of bounding boxes excluding background.
labels (List): List containing the labels of bounding boxes excluding background.
width (int): Width of bounding box.
font (ImageFont): The PIL ImageFont object used to for drawing the labels.
"""

if not isinstance(image, torch.Tensor):
raise TypeError(f"Tensor expected, got {type(image)}")
elif image.dtype != torch.uint8:
raise ValueError(f"Tensor uint8 expected, got {image.dtype}")
elif image.dim() != 3:
raise ValueError("Pass individual images, not batches")

if image.requires_grad:
image = image.detach()
if boxes.requires_grad:
boxes = boxes.detach()

ndarr = image.permute(1, 2, 0).numpy()
img_to_draw = Image.fromarray(ndarr)

img_boxes = boxes.to(torch.int64).tolist()

draw = ImageDraw.Draw(img_to_draw)

for i, bbox in enumerate(img_boxes):
color = None if colors is None else colors[i]
draw.rectangle(bbox, width=width, outline=color)

if labels is not None:
draw.text((bbox[0], bbox[1]), labels[i], fill=color, font=font)

return torch.from_numpy(np.array(img_to_draw)).permute(2, 0, 1)