-
Notifications
You must be signed in to change notification settings - Fork 7.1k
Add utility to draw bounding boxes #2785
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
0a40928
42ab7aa
e229fc7
47acfad
86c0dc9
e6225e5
a66076b
124977f
83524a5
deccdab
6886cac
fc34ccb
396dc53
1aa1b03
c486660
9c14a15
28d29af
d8e10b4
cbd5ee9
35f6951
07274f1
30905e9
1568b54
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -7,3 +7,4 @@ torchvision.utils | |
|
||
.. autofunction:: save_image | ||
|
||
.. autofunction:: draw_bounding_boxes |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -6,6 +6,7 @@ | |
import unittest | ||
from io import BytesIO | ||
import torchvision.transforms.functional as F | ||
from torchvision.io.image import read_image | ||
from PIL import Image | ||
|
||
|
||
|
@@ -79,6 +80,21 @@ def test_save_image_single_pixel_file_object(self): | |
self.assertTrue(torch.equal(F.to_tensor(img_orig), F.to_tensor(img_bytes)), | ||
'Pixel Image not stored in file object') | ||
|
||
def test_draw_boxes(self): | ||
img = torch.full((3, 100, 100), 255, dtype=torch.uint8) | ||
boxes = torch.tensor([[0, 0, 20, 20], [0, 0, 0, 0], | ||
[10, 15, 30, 35], [23, 35, 93, 95]], dtype=torch.float) | ||
labels = ["a", "b", "c", "d"] | ||
colors = ["green", "#FF00FF", (0, 255, 0), "red"] | ||
result = utils.draw_bounding_boxes(img, boxes, labels=labels, colors=colors) | ||
|
||
path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets", "fakedata", "draw_boxes_util.png") | ||
if not os.path.exists(path): | ||
Image.fromarray(result.permute(1, 2, 0).numpy()).save(path) | ||
Comment on lines
+92
to
+93
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. nit: any particular reason why you use PIL to save the result, and not There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I agree that this is worth changing. |
||
|
||
expected = read_image(path) | ||
self.assertTrue(torch.equal(result, expected)) | ||
|
||
|
||
if __name__ == '__main__': | ||
unittest.main() |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,8 +1,13 @@ | ||
from typing import Union, Optional, List, Tuple, Text, BinaryIO | ||
import io | ||
import pathlib | ||
import torch | ||
import math | ||
import numpy as np | ||
from PIL import Image, ImageDraw | ||
from PIL import ImageFont | ||
|
||
__all__ = ["make_grid", "save_image", "draw_bounding_boxes"] | ||
|
||
irange = range | ||
|
||
|
||
|
@@ -121,10 +126,64 @@ def save_image( | |
If a file object was used instead of a filename, this parameter should always be used. | ||
**kwargs: Other arguments are documented in ``make_grid``. | ||
""" | ||
from PIL import Image | ||
grid = make_grid(tensor, nrow=nrow, padding=padding, pad_value=pad_value, | ||
normalize=normalize, range=range, scale_each=scale_each) | ||
# Add 0.5 after unnormalizing to [0, 255] to round to nearest integer | ||
ndarr = grid.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to('cpu', torch.uint8).numpy() | ||
im = Image.fromarray(ndarr) | ||
im.save(fp, format=format) | ||
|
||
|
||
@torch.no_grad()
def draw_bounding_boxes(
    image: torch.Tensor,
    boxes: torch.Tensor,
    labels: Optional[List[str]] = None,
    colors: Optional[List[Union[str, Tuple[int, int, int]]]] = None,
    width: int = 1,
    font: Optional[str] = None,
    font_size: int = 10
) -> torch.Tensor:
    """
    Draws bounding boxes on given image.
    The values of the input image should be uint8 between 0 and 255.

    Args:
        image (Tensor): Tensor of shape (C x H x W)
        boxes (Tensor): Tensor of size (N, 4) containing bounding boxes in (xmin, ymin, xmax, ymax) format. Note that
            the boxes are absolute coordinates with respect to the image. In other words: `0 <= xmin < xmax < W` and
            `0 <= ymin < ymax < H`. Coordinates are truncated to integers before drawing.
        labels (List[str]): List containing the labels of bounding boxes, indexed in the same order as ``boxes``.
        colors (List[Union[str, Tuple[int, int, int]]]): List containing the colors of bounding boxes, indexed in the
            same order as ``boxes``. The colors can be represented as `str` or `Tuple[int, int, int]`.
        width (int): Width of bounding box.
        font (str): A filename containing a TrueType font. If the file is not found in this filename, the loader may
            also search in other directories, such as the `fonts/` directory on Windows or `/Library/Fonts/`,
            `/System/Library/Fonts/` and `~/Library/Fonts/` on macOS. When ``None``, PIL's default font is used.
        font_size (int): The requested font size in points. Only used when ``font`` is given.

    Returns:
        Tensor: The image with the boxes (and labels, if any) drawn on it, laid out as (C x H x W).

    Raises:
        TypeError: If ``image`` is not a ``torch.Tensor``.
        ValueError: If ``image`` is not of dtype ``torch.uint8`` or is not 3-dimensional.
    """

    if not isinstance(image, torch.Tensor):
        raise TypeError(f"Tensor expected, got {type(image)}")
    elif image.dtype != torch.uint8:
        raise ValueError(f"Tensor uint8 expected, got {image.dtype}")
    elif image.dim() != 3:
        raise ValueError("Pass individual images, not batches")

    # PIL expects channels-last (H x W x C) data.
    ndarr = image.permute(1, 2, 0).numpy()
    img_to_draw = Image.fromarray(ndarr)

    # ImageDraw works on plain Python numbers; the cast truncates fractional coordinates.
    img_boxes = boxes.to(torch.int64).tolist()

    draw = ImageDraw.Draw(img_to_draw)

    # The font does not depend on the box, so load it once instead of once per box.
    txt_font = None
    if labels is not None:
        txt_font = ImageFont.load_default() if font is None else ImageFont.truetype(font=font, size=font_size)

    for i, bbox in enumerate(img_boxes):
        color = None if colors is None else colors[i]
        draw.rectangle(bbox, width=width, outline=color)

        if labels is not None:
            # Anchor the label at the top-left corner of its box, in the box's color.
            draw.text((bbox[0], bbox[1]), labels[i], fill=color, font=txt_font)

    # np.array() makes a fresh H x W x C array from the PIL image; permute back to channels-first.
    return torch.from_numpy(np.array(img_to_draw)).permute(2, 0, 1)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This change was originally done on an intermediate commit where I was producing a random image and had to fix the seed. Though I switched to a non-random image to reduce the size, I think it's a good idea to move this method from `test_models.py` to `common_utils.py`, so I kept the change in this PR.