-
Notifications
You must be signed in to change notification settings - Fork 7.1k
Add utility to draw bounding boxes #2785
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 20 commits
0a40928
42ab7aa
e229fc7
47acfad
86c0dc9
e6225e5
a66076b
124977f
83524a5
deccdab
6886cac
fc34ccb
396dc53
1aa1b03
c486660
9c14a15
28d29af
d8e10b4
cbd5ee9
35f6951
07274f1
30905e9
1568b54
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -7,3 +7,4 @@ torchvision.utils | |
|
||
.. autofunction:: save_image | ||
|
||
.. autofunction:: draw_bounding_boxes |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,8 +1,13 @@ | ||
from typing import Union, Optional, List, Tuple, Text, BinaryIO | ||
import io | ||
import pathlib | ||
import torch | ||
import math | ||
import numpy as np | ||
from PIL import Image, ImageDraw | ||
from PIL.ImageFont import ImageFont | ||
|
||
__all__ = ["make_grid", "save_image", "draw_bounding_boxes"] | ||
|
||
irange = range | ||
|
||
|
||
|
@@ -121,10 +126,60 @@ def save_image( | |
If a file object was used instead of a filename, this parameter should always be used. | ||
**kwargs: Other arguments are documented in ``make_grid``. | ||
""" | ||
from PIL import Image | ||
grid = make_grid(tensor, nrow=nrow, padding=padding, pad_value=pad_value, | ||
normalize=normalize, range=range, scale_each=scale_each) | ||
# Add 0.5 after unnormalizing to [0, 255] to round to nearest integer | ||
ndarr = grid.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to('cpu', torch.uint8).numpy() | ||
im = Image.fromarray(ndarr) | ||
im.save(fp, format=format) | ||
|
||
|
||
def draw_bounding_boxes( | ||
image: torch.Tensor, | ||
boxes: torch.Tensor, | ||
colors: Optional[List[str]] = None, | ||
datumbox marked this conversation as resolved.
Show resolved
Hide resolved
|
||
labels: Optional[List[str]] = None, | ||
datumbox marked this conversation as resolved.
Show resolved
Hide resolved
|
||
width: int = 1, | ||
font: Optional[ImageFont] = None | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Given that we won't be using this function in torchscript, I'm ok having the input type of the function to be PIL-specific There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'm not terribly excited about this TBH:
I could try to create a There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Have a look on the latest commit for an alternative to passing ImageFont. We can choose any of the two options, I'm OK with both. |
||
) -> torch.Tensor: | ||
|
||
""" | ||
Draws bounding boxes on given image. | ||
The values of the input image should be uint8 between 0 and 255. | ||
|
||
Args: | ||
image (Tensor): Tensor of shape (C x H x W) | ||
bboxes (Tensor): Tensor of size (N, 4) containing bounding boxes in (xmin, ymin, xmax, ymax) format. | ||
datumbox marked this conversation as resolved.
Show resolved
Hide resolved
|
||
colors (List): List containing the colors of bounding boxes excluding background. | ||
datumbox marked this conversation as resolved.
Show resolved
Hide resolved
|
||
labels (List): List containing the labels of bounding boxes excluding background. | ||
width (int): Width of bounding box. | ||
datumbox marked this conversation as resolved.
Show resolved
Hide resolved
|
||
font (ImageFont): The PIL ImageFont object used to for drawing the labels. | ||
""" | ||
|
||
if not isinstance(image, torch.Tensor): | ||
raise TypeError(f"Tensor expected, got {type(image)}") | ||
elif image.dtype != torch.uint8: | ||
raise ValueError(f"Tensor uint8 expected, got {image.dtype}") | ||
elif image.dim() != 3: | ||
raise ValueError("Pass individual images, not batches") | ||
|
||
if image.requires_grad: | ||
image = image.detach() | ||
if boxes.requires_grad: | ||
boxes = boxes.detach() | ||
datumbox marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
ndarr = image.permute(1, 2, 0).numpy() | ||
img_to_draw = Image.fromarray(ndarr) | ||
|
||
img_boxes = boxes.to(torch.int64).tolist() | ||
|
||
draw = ImageDraw.Draw(img_to_draw) | ||
|
||
for i, bbox in enumerate(img_boxes): | ||
color = None if colors is None else colors[i] | ||
draw.rectangle(bbox, width=width, outline=color) | ||
|
||
if labels is not None: | ||
draw.text((bbox[0], bbox[1]), labels[i], fill=color, font=font) | ||
|
||
return torch.from_numpy(np.array(img_to_draw)).permute(2, 0, 1) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
can you also add a test checking that the color of the output image at pixel values
out[:, 0, 0:100] == fillcolor
etc, so that we know that we are masking the correct pixels in the image?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I agree we should test pixels but I would rather test all functionalities including labels, fonts etc. I wonder if that's possible or if it will crate a flaky test due to differences on fonts across platforms. I'll give a try to test what I proposed on the earlier comment and see if this works.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
See latest code for the proposed approach of testing.