
Improved utilities, adds examples, tests #3594

Merged (24 commits, Mar 30, 2021)
4 changes: 4 additions & 0 deletions examples/python/README.md
@@ -4,6 +4,8 @@
[Examples of Tensor Images transformations](https://github.com/pytorch/vision/blob/master/examples/python/tensor_transforms.ipynb)
- [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/pytorch/vision/blob/master/examples/python/video_api.ipynb)
[Example of VideoAPI](https://github.com/pytorch/vision/blob/master/examples/python/video_api.ipynb)
- [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/pytorch/vision/blob/master/examples/python/visualization_utils.ipynb)
[Example of Visualization Utils](https://github.com/pytorch/vision/blob/master/examples/python/visualization_utils.ipynb)


Prior to v0.8.0, transforms in torchvision have traditionally been PIL-centric and presented multiple limitations due to
@@ -16,3 +18,5 @@ features:
- read and decode data directly as torch tensor with torchscript support (for PNG and JPEG image formats); a minimal sketch follows below
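
As a minimal sketch of that feature (assuming torchvision >= 0.8, the `torchvision.io.read_image` function, and a hypothetical local file `dog.png`):

```python
import torch
import torchvision.transforms as T
from torchvision.io import read_image

# Decode a PNG/JPEG directly into a uint8 tensor of shape (C, H, W), with no PIL image in between.
img = read_image("dog.png")  # hypothetical local file

# Tensor transforms accept single (C, H, W) or batched (B, C, H, W) inputs, can run on GPU
# (e.g. img = img.to("cuda")), and can be scripted when wrapped in nn.Sequential.
transforms = torch.nn.Sequential(
    T.CenterCrop(64),
    T.RandomHorizontalFlip(p=0.5),
)
scripted = torch.jit.script(transforms)
out = scripted(img.unsqueeze(0))  # batched input works as well
```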

Furthermore, we previously provided only a very high-level API for video decoding, which left little control to the user. We are now expanding that API (and will eventually replace it) with a lower-level API that gives the user frame-based access to a video.
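
A minimal sketch of that frame-based access (assuming the beta `torchvision.io.VideoReader` API is available in your build, and a hypothetical local file `video.mp4`):

```python
import torch
from torchvision.io import VideoReader

# Build a reader over the video stream of the file.
reader = VideoReader("video.mp4", "video")  # hypothetical local file

# Start decoding from an arbitrary point in time (in seconds) ...
reader.seek(2.0)

# ... and pull frames one by one; each frame is a dict with 'data' (a (C, H, W) tensor) and 'pts'.
frames = []
for frame in reader:
    frames.append(frame["data"])
    if frame["pts"] > 4.0:  # stop after roughly two seconds of video
        break

video = torch.stack(frames)  # (T, C, H, W)
```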

Torchvision also provides utilities to visualize results. You can make a grid of images and plot bounding boxes as well as segmentation masks. These utilities work standalone and also with torchvision models for detection and segmentation.
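
For example, a minimal sketch (the batch, boxes, labels, colors, and output path below are made up for illustration):

```python
import torch
from torchvision.utils import make_grid, save_image, draw_bounding_boxes

# Arrange a dummy batch of 8 RGB images into a 4-per-row grid and save it.
batch = torch.rand(8, 3, 64, 64)
grid = make_grid(batch, nrow=4, padding=2)  # returns a (3, H, W) tensor
save_image(grid, "grid.png")                # hypothetical output path

# Draw labelled boxes on a uint8 image; each box is (xmin, ymin, xmax, ymax).
img = torch.full((3, 100, 100), 255, dtype=torch.uint8)
boxes = torch.tensor([[10, 15, 30, 35], [23, 35, 93, 95]], dtype=torch.float)
drawn = draw_bounding_boxes(img, boxes, labels=["cat", "dog"], colors=["red", "#FF00FF"])
```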
683 changes: 683 additions & 0 deletions examples/python/visualization_utils.ipynb

Large diffs are not rendered by default.

Binary file added test/assets/fakedata/draw_boxes_vanilla.png
58 changes: 56 additions & 2 deletions test/test_utils.py
@@ -9,6 +9,9 @@
import torchvision.transforms.functional as F
from PIL import Image

boxes = torch.tensor([[0, 0, 20, 20], [0, 0, 0, 0],
[10, 15, 30, 35], [23, 35, 93, 95]], dtype=torch.float)

masks = torch.tensor([
[
[-2.2799, -2.2799, -2.2799, -2.2799, -2.2799],
@@ -106,8 +109,8 @@ def test_save_image_single_pixel_file_object(self):

def test_draw_boxes(self):
img = torch.full((3, 100, 100), 255, dtype=torch.uint8)
boxes = torch.tensor([[0, 0, 20, 20], [0, 0, 0, 0],
[10, 15, 30, 35], [23, 35, 93, 95]], dtype=torch.float)
img_cp = img.clone()
boxes_cp = boxes.clone()
labels = ["a", "b", "c", "d"]
colors = ["green", "#FF00FF", (0, 255, 0), "red"]
result = utils.draw_bounding_boxes(img, boxes, labels=labels, colors=colors, fill=True)
@@ -119,9 +122,41 @@ def test_draw_boxes(self):

expected = torch.as_tensor(np.array(Image.open(path))).permute(2, 0, 1)
self.assertTrue(torch.equal(result, expected))
# Check that the inputs were not modified in place
self.assertTrue(torch.all(torch.eq(boxes, boxes_cp)).item())
self.assertTrue(torch.all(torch.eq(img, img_cp)).item())

def test_draw_boxes_vanilla(self):
img = torch.full((3, 100, 100), 0, dtype=torch.uint8)
img_cp = img.clone()
boxes_cp = boxes.clone()
result = utils.draw_bounding_boxes(img, boxes, fill=False, width=7)

path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets", "fakedata", "draw_boxes_vanilla.png")
if not os.path.exists(path):
res = Image.fromarray(result.permute(1, 2, 0).contiguous().numpy())
res.save(path)

expected = torch.as_tensor(np.array(Image.open(path))).permute(2, 0, 1)
self.assertTrue(torch.equal(result, expected))
# Check that the inputs were not modified in place
self.assertTrue(torch.all(torch.eq(boxes, boxes_cp)).item())
self.assertTrue(torch.all(torch.eq(img, img_cp)).item())

def test_draw_invalid_boxes(self):
img_tp = ((1, 1, 1), (1, 2, 3))
img_wrong1 = torch.full((3, 5, 5), 255, dtype=torch.float)
img_wrong2 = torch.full((1, 3, 5, 5), 255, dtype=torch.uint8)
boxes = torch.tensor([[0, 0, 20, 20], [0, 0, 0, 0],
[10, 15, 30, 35], [23, 35, 93, 95]], dtype=torch.float)
self.assertRaises(TypeError, utils.draw_bounding_boxes, img_tp, boxes)
self.assertRaises(ValueError, utils.draw_bounding_boxes, img_wrong1, boxes)
self.assertRaises(ValueError, utils.draw_bounding_boxes, img_wrong2, boxes)

def test_draw_segmentation_masks_colors(self):
img = torch.full((3, 5, 5), 255, dtype=torch.uint8)
img_cp = img.clone()
masks_cp = masks.clone()
colors = ["#FF00FF", (0, 255, 0), "red"]
result = utils.draw_segmentation_masks(img, masks, colors=colors)

@@ -134,9 +169,14 @@ def test_draw_segmentation_masks_colors(self):

expected = torch.as_tensor(np.array(Image.open(path))).permute(2, 0, 1)
self.assertTrue(torch.equal(result, expected))
# Check that the inputs were not modified in place
self.assertTrue(torch.all(torch.eq(img, img_cp)).item())
self.assertTrue(torch.all(torch.eq(masks, masks_cp)).item())

def test_draw_segmentation_masks_no_colors(self):
img = torch.full((3, 20, 20), 255, dtype=torch.uint8)
img_cp = img.clone()
masks_cp = masks.clone()
result = utils.draw_segmentation_masks(img, masks, colors=None)

path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets",
@@ -148,6 +188,20 @@ def test_draw_segmentation_masks_no_colors(self):

expected = torch.as_tensor(np.array(Image.open(path))).permute(2, 0, 1)
self.assertTrue(torch.equal(result, expected))
# Check that the inputs were not modified in place
self.assertTrue(torch.all(torch.eq(img, img_cp)).item())
self.assertTrue(torch.all(torch.eq(masks, masks_cp)).item())

def test_draw_invalid_masks(self):
img_tp = ((1, 1, 1), (1, 2, 3))
img_wrong1 = torch.full((3, 5, 5), 255, dtype=torch.float)
img_wrong2 = torch.full((1, 3, 5, 5), 255, dtype=torch.uint8)
img_wrong3 = torch.full((4, 5, 5), 255, dtype=torch.uint8)

self.assertRaises(TypeError, utils.draw_segmentation_masks, img_tp, masks)
self.assertRaises(ValueError, utils.draw_segmentation_masks, img_wrong1, masks)
self.assertRaises(ValueError, utils.draw_segmentation_masks, img_wrong2, masks)
self.assertRaises(ValueError, utils.draw_segmentation_masks, img_wrong3, masks)


if __name__ == '__main__':
31 changes: 25 additions & 6 deletions torchvision/utils.py
@@ -20,7 +20,8 @@ def make_grid(
pad_value: int = 0,
**kwargs
) -> torch.Tensor:
"""Make a grid of images.
"""
Make a grid of images.

Args:
tensor (Tensor or list): 4D mini-batch Tensor of shape (B x C x H x W)
@@ -37,9 +38,12 @@
images separately rather than the (min, max) over all images. Default: ``False``.
pad_value (float, optional): Value for the padded pixels. Default: ``0``.

Example:
See this notebook `here <https://gist.github.com/anonymous/bf16430f7750c023141c562f3e9f2a91>`_
Returns:
grid (Tensor): the tensor containing the grid of images.

Example:
See this notebook
`here <https://github.com/pytorch/vision/blob/master/examples/python/visualization_utils.ipynb>`_
"""
if not (torch.is_tensor(tensor) or
(isinstance(tensor, list) and all(torch.is_tensor(t) for t in tensor))):
@@ -117,7 +121,8 @@ def save_image(
format: Optional[str] = None,
**kwargs
) -> None:
"""Save a given Tensor into an image file.
"""
Save a given Tensor into an image file.

Args:
tensor (Tensor or list): Image to be saved. If given a mini-batch tensor,
@@ -150,7 +155,7 @@ def draw_bounding_boxes(
"""
Draws bounding boxes on given image.
The values of the input image should be uint8 between 0 and 255.
If filled, Resulting Tensor should be saved as PNG image.
If fill is True, the resulting Tensor should be saved as a PNG image.

Args:
image (Tensor): Tensor of shape (C x H x W) and dtype uint8.
@@ -166,6 +171,13 @@
also search in other directories, such as the `fonts/` directory on Windows or `/Library/Fonts/`,
`/System/Library/Fonts/` and `~/Library/Fonts/` on macOS.
font_size (int): The requested font size in points.

Returns:
img (Tensor[C, H, W]): Image Tensor of dtype uint8 with bounding boxes plotted.

Example:
See this notebook
`linked <https://github.com/pytorch/vision/blob/master/examples/python/visualization_utils.ipynb>`_
"""

if not isinstance(image, torch.Tensor):
@@ -209,7 +221,7 @@ if labels is not None:
if labels is not None:
draw.text((bbox[0], bbox[1]), labels[i], fill=color, font=txt_font)

return torch.from_numpy(np.array(img_to_draw)).permute(2, 0, 1)
return torch.from_numpy(np.array(img_to_draw)).permute(2, 0, 1).to(dtype=torch.uint8)


@torch.no_grad()
@@ -230,6 +242,13 @@ def draw_segmentation_masks(
alpha (float): Float number between 0 and 1 denoting the transparency factor of the masks.
colors (List[Union[str, Tuple[int, int, int]]]): List containing the colors of masks. The colors can
be represented as `str` or `Tuple[int, int, int]`.

Returns:
img (Tensor[C, H, W]): Image Tensor of dtype uint8 with segmentation masks plotted.

Example:
See this notebook
`attached <https://github.com/pytorch/vision/blob/master/examples/python/visualization_utils.ipynb>`_
"""

if not isinstance(image, torch.Tensor):
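
For reference, a minimal usage sketch of the two drawing utilities: the image, boxes, and mask scores below are dummy tensors, the keyword values are only illustrative, and the mask format assumed is the float score tensor of shape (num_classes, H, W) that this version of `draw_segmentation_masks` expects.

```python
import torch
from torchvision.utils import draw_bounding_boxes, draw_segmentation_masks

img = torch.zeros((3, 100, 100), dtype=torch.uint8)  # uint8 image of shape (C, H, W)

# Filled, wide boxes; per the docstring, a filled result is best saved as a PNG.
boxes = torch.tensor([[0, 0, 20, 20], [10, 15, 30, 35]], dtype=torch.float)
boxed = draw_bounding_boxes(img, boxes, colors=["green", "red"], fill=True, width=7)

# Overlay per-class masks; the scores could be the raw output of a segmentation model
# for a single image.
masks = torch.randn(3, 100, 100)
overlaid = draw_segmentation_masks(img, masks, alpha=0.3,
                                   colors=["#FF00FF", (0, 255, 0), "red"])
```

Both calls return new uint8 tensors of shape (C, H, W) and leave their inputs untouched, which is what the in-place checks added to `test_utils.py` above verify.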