diff --git a/gallery/plot_visualization_utils.py b/gallery/plot_visualization_utils.py
index 58788437a28..04c5e3dcb53 100644
--- a/gallery/plot_visualization_utils.py
+++ b/gallery/plot_visualization_utils.py
@@ -24,7 +24,8 @@ def show(imgs):
         imgs = [imgs]
     fix, axs = plt.subplots(ncols=len(imgs), squeeze=False)
     for i, img in enumerate(imgs):
-        img = F.to_pil_image(img.to('cpu'))
+        img = img.detach()
+        img = F.to_pil_image(img)
         axs[0, i].imshow(np.asarray(img))
         axs[0, i].set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])
 
@@ -50,9 +51,8 @@ def show(imgs):
 # Visualizing bounding boxes
 # --------------------------
 # We can use :func:`~torchvision.utils.draw_bounding_boxes` to draw boxes on an
-# image. We can set the colors, labels, width as well as font and font size !
-# The boxes are in ``(xmin, ymin, xmax, ymax)`` format
-# from torchvision.utils import draw_bounding_boxes
+# image. We can set the colors, labels, width as well as font and font size.
+# The boxes are in ``(xmin, ymin, xmax, ymax)`` format.
 
 from torchvision.utils import draw_bounding_boxes
 
@@ -74,9 +74,8 @@ def show(imgs):
 
 from torchvision.transforms.functional import convert_image_dtype
 
-dog1_float = convert_image_dtype(dog1_int, dtype=torch.float)
-dog2_float = convert_image_dtype(dog2_int, dtype=torch.float)
-batch = torch.stack([dog1_float, dog2_float])
+batch_int = torch.stack([dog1_int, dog2_int])
+batch = convert_image_dtype(batch_int, dtype=torch.float)
 
 model = fasterrcnn_resnet50_fpn(pretrained=True, progress=False)
 model = model.eval()
@@ -91,7 +90,7 @@ def show(imgs):
 threshold = .8
 dogs_with_boxes = [
     draw_bounding_boxes(dog_int, boxes=output['boxes'][output['scores'] > threshold], width=4)
-    for dog_int, output in zip((dog1_int, dog2_int), outputs)
+    for dog_int, output in zip(batch_int, outputs)
 ]
 show(dogs_with_boxes)
 
@@ -99,33 +98,255 @@ def show(imgs):
 # Visualizing segmentation masks
 # ------------------------------
 # The :func:`~torchvision.utils.draw_segmentation_masks` function can be used to
-# draw segmentation amasks on images. We can set the colors as well as
-# transparency of masks.
+# draw segmentation masks on images. Semantic segmentation and instance
+# segmentation models have different outputs, so we will treat each
+# independently.
 #
-# Here is demo with torchvision's FCN Resnet-50, loaded with
-# :func:`~torchvision.models.segmentation.fcn_resnet50`.
-# You can also try using
-# DeepLabv3 (:func:`~torchvision.models.segmentation.deeplabv3_resnet50`)
-# or lraspp mobilenet models
+# Semantic segmentation models
+# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+#
+# We will see how to use it with torchvision's FCN Resnet-50, loaded with
+# :func:`~torchvision.models.segmentation.fcn_resnet50`. You can also try using
+# DeepLabv3 (:func:`~torchvision.models.segmentation.deeplabv3_resnet50`) or
+# lraspp mobilenet models
 # (:func:`~torchvision.models.segmentation.lraspp_mobilenet_v3_large`).
 #
-# Like :func:`~torchvision.utils.draw_bounding_boxes`,
-# :func:`~torchvision.utils.draw_segmentation_masks` requires a single RGB image
-# of dtype `uint8`.
+# Let's start by looking at the output of the model. Remember that in general,
+# images must be normalized before they're passed to a semantic segmentation
+# model.
 from torchvision.models.segmentation import fcn_resnet50
-from torchvision.utils import draw_segmentation_masks
 
 model = fcn_resnet50(pretrained=True, progress=False)
 model = model.eval()
 
-# The model expects the batch to be normalized
-batch = F.normalize(batch, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))
-outputs = model(batch)
+normalized_batch = F.normalize(batch, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))
+output = model(normalized_batch)['out']
+print(output.shape, output.min().item(), output.max().item())
+
+#####################################
+# As we can see above, the output of the segmentation model is a tensor of shape
+# ``(batch_size, num_classes, H, W)``. Each value is a non-normalized score, and
+# we can normalize them into ``[0, 1]`` by using a softmax. After the softmax,
+# we can interpret each value as a probability indicating how likely a given
+# pixel is to belong to a given class.
+#
+# Let's plot the masks that have been detected for the dog class and for the
+# boat class:
+
+sem_classes = [
+    '__background__', 'aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus',
+    'car', 'cat', 'chair', 'cow', 'diningtable', 'dog', 'horse', 'motorbike',
+    'person', 'pottedplant', 'sheep', 'sofa', 'train', 'tvmonitor'
+]
+sem_class_to_idx = {cls: idx for (idx, cls) in enumerate(sem_classes)}
+
+normalized_masks = torch.nn.functional.softmax(output, dim=1)
+
+dog_and_boat_masks = [
+    normalized_masks[img_idx, sem_class_to_idx[cls]]
+    for img_idx in range(batch.shape[0])
+    for cls in ('dog', 'boat')
+]
+
+show(dog_and_boat_masks)
+
+#####################################
+# As expected, the model is confident about the dog class, but not so much for
+# the boat class.
+#
+# The :func:`~torchvision.utils.draw_segmentation_masks` function can be used to
+# plot those masks on top of the original image. This function expects the
+# masks to be boolean masks, but our masks above contain probabilities in
+# ``[0, 1]``. To get boolean masks, we can do the following:
+
+class_dim = 1
+boolean_dog_masks = (normalized_masks.argmax(class_dim) == sem_class_to_idx['dog'])
+print(f"shape = {boolean_dog_masks.shape}, dtype = {boolean_dog_masks.dtype}")
+show([m.float() for m in boolean_dog_masks])
+
+
+#####################################
+# The line above where we define ``boolean_dog_masks`` is a bit cryptic, but you
+# can read it as the following query: "For which pixels is 'dog' the most likely
+# class?"
+#
+# .. note::
+#    While we're using the ``normalized_masks`` here, we would have
+#    gotten the same result by using the non-normalized scores of the model
+#    directly (as the softmax operation preserves the order).
+#
+# Now that we have boolean masks, we can use them with
+# :func:`~torchvision.utils.draw_segmentation_masks` to plot them on top of the
+# original images:
+
+from torchvision.utils import draw_segmentation_masks
+
+dogs_with_masks = [
+    draw_segmentation_masks(img, masks=mask, alpha=0.7)
+    for img, mask in zip(batch_int, boolean_dog_masks)
+]
+show(dogs_with_masks)
+
+#####################################
+# We can plot more than one mask per image! Remember that the model returned as
+# many masks as there are classes. Let's ask the same query as above, but this
+# time for *all* classes, not just the dog class: "For each pixel and each class
+# C, is class C the most likely class?"
+#
+# This one is a bit more involved, so we'll first show how to do it with a
+# single image, and then we'll generalize to the batch.
+
+num_classes = normalized_masks.shape[1]
+dog1_masks = normalized_masks[0]
+class_dim = 0
+dog1_all_classes_masks = dog1_masks.argmax(class_dim) == torch.arange(num_classes)[:, None, None]
+
+print(f"dog1_masks shape = {dog1_masks.shape}, dtype = {dog1_masks.dtype}")
+print(f"dog1_all_classes_masks shape = {dog1_all_classes_masks.shape}, dtype = {dog1_all_classes_masks.dtype}")
+
+dog_with_all_masks = draw_segmentation_masks(dog1_int, masks=dog1_all_classes_masks, alpha=.6)
+show(dog_with_all_masks)
+
+#####################################
+# We can see in the image above that only 2 masks were drawn: the mask for the
+# background and the mask for the dog. This is because the model thinks that
+# only these 2 classes are the most likely ones across all the pixels. If the
+# model had detected another class as the most likely one for some pixels, we
+# would have seen its mask above.
+#
+# Removing the background mask is as simple as passing
+# ``masks=dog1_all_classes_masks[1:]``, because the background class is the
+# class with index 0.
+#
+# Let's now do the same but for an entire batch of images. The code is similar
+# but involves a bit more juggling with the dimensions.
+
+class_dim = 1
+all_classes_masks = normalized_masks.argmax(class_dim) == torch.arange(num_classes)[:, None, None, None]
+print(f"shape = {all_classes_masks.shape}, dtype = {all_classes_masks.dtype}")
+# The first dimension is the classes now, so we need to swap it
+all_classes_masks = all_classes_masks.swapaxes(0, 1)
 
 dogs_with_masks = [
-    draw_segmentation_masks(dog_int, masks=masks, alpha=0.6)
-    for dog_int, masks in zip((dog1_int, dog2_int), outputs['out'])
+    draw_segmentation_masks(img, masks=mask, alpha=.6)
+    for img, mask in zip(batch_int, all_classes_masks)
 ]
 show(dogs_with_masks)
+
+
+#####################################
+# Instance segmentation models
+# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+#
+# Instance segmentation models have a significantly different output from the
+# semantic segmentation models. We will see here how to plot the masks for such
+# models. Let's start by analyzing the output of a Mask-RCNN model. Note that
+# these models don't require the images to be normalized, so we don't need to
+# use the normalized batch.
+
+from torchvision.models.detection import maskrcnn_resnet50_fpn
+model = maskrcnn_resnet50_fpn(pretrained=True, progress=False)
+model = model.eval()
+
+output = model(batch)
+print(output)
+
+#####################################
+# Let's break this down. For each image in the batch, the model outputs some
+# detections (or instances). The number of detections varies for each input
+# image. Each instance is described by its bounding box, its label, its score
+# and its mask.
+#
+# The way the output is organized is as follows: the output is a list of length
+# ``batch_size``. Each entry in the list corresponds to an input image, and it
+# is a dict with keys 'boxes', 'labels', 'scores', and 'masks'. Each value
+# associated with those keys has ``num_instances`` elements in it. In our case
+# above there are 3 instances detected in the first image, and 2 instances in
+# the second one.
+#
+# The boxes can be plotted with :func:`~torchvision.utils.draw_bounding_boxes`
+# as above (see the brief sketch below), but here we're more interested in the
+# masks. These masks are quite different from the masks that we saw above for
+# the semantic segmentation models.
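+
+#####################################
+# As a brief aside — this is a minimal sketch, not required for the rest of
+# this example: the instance predictions also come with boxes, so we could
+# reuse :func:`~torchvision.utils.draw_bounding_boxes` from the detection
+# section (``dogs_with_instance_boxes`` is just an illustrative name):
+
+dogs_with_instance_boxes = [
+    draw_bounding_boxes(img, boxes=out['boxes'], width=4)
+    for img, out in zip(batch_int, output)
+]
+show(dogs_with_instance_boxes)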
+
+dog1_output = output[0]
+dog1_masks = dog1_output['masks']
+print(f"shape = {dog1_masks.shape}, dtype = {dog1_masks.dtype}, "
+      f"min = {dog1_masks.min()}, max = {dog1_masks.max()}")
+
+#####################################
+# Here the masks correspond to probabilities indicating, for each pixel, how
+# likely it is to belong to the predicted label of that instance. Those
+# predicted labels correspond to the 'labels' element in the same output dict.
+# Let's see which labels were predicted for the instances of the first image.
+
+inst_classes = [
+    '__background__', 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus',
+    'train', 'truck', 'boat', 'traffic light', 'fire hydrant', 'N/A', 'stop sign',
+    'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow',
+    'elephant', 'bear', 'zebra', 'giraffe', 'N/A', 'backpack', 'umbrella', 'N/A', 'N/A',
+    'handbag', 'tie', 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball',
+    'kite', 'baseball bat', 'baseball glove', 'skateboard', 'surfboard', 'tennis racket',
+    'bottle', 'N/A', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl',
+    'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza',
+    'donut', 'cake', 'chair', 'couch', 'potted plant', 'bed', 'N/A', 'dining table',
+    'N/A', 'N/A', 'toilet', 'N/A', 'tv', 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone',
+    'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'N/A', 'book',
+    'clock', 'vase', 'scissors', 'teddy bear', 'hair drier', 'toothbrush'
+]
+
+inst_class_to_idx = {cls: idx for (idx, cls) in enumerate(inst_classes)}
+
+print("For the first dog, the following instances were detected:")
+print([inst_classes[label] for label in dog1_output['labels']])
+
+#####################################
+# Interestingly, the model detects two people in the image. Let's go ahead and
+# plot those masks. Since :func:`~torchvision.utils.draw_segmentation_masks`
+# expects boolean masks, we need to convert those probabilities into boolean
+# values. Remember that the semantics of those masks is "How likely is this pixel
+# to belong to the predicted class?". As a result, a natural way of converting
+# those masks into boolean values is to threshold them at a probability of 0.5
+# (one could also choose a different threshold).
+
+proba_threshold = 0.5
+dog1_bool_masks = dog1_output['masks'] > proba_threshold
+print(f"shape = {dog1_bool_masks.shape}, dtype = {dog1_bool_masks.dtype}")
+
+# There's an extra dimension (1) in the masks. We need to remove it
+dog1_bool_masks = dog1_bool_masks.squeeze(1)
+
+show(draw_segmentation_masks(dog1_int, dog1_bool_masks, alpha=0.9))
+
+#####################################
+# The model seems to have properly detected the dog, but it also confused trees
+# with people. Looking more closely at the scores will help us plot more
+# relevant masks:
+
+print(dog1_output['scores'])
+
+#####################################
+# Clearly the model is less confident about the dog detection than it is about
+# the people detections. That's good news. When plotting the masks, we can ask
+# for only those that have a good score. Let's use a score threshold of .75
+# here, and also plot the masks of the second dog.
+
+score_threshold = .75
+
+boolean_masks = [
+    out['masks'][out['scores'] > score_threshold] > proba_threshold
+    for out in output
+]
+
+dogs_with_masks = [
+    draw_segmentation_masks(img, mask.squeeze(1))
+    for img, mask in zip(batch_int, boolean_masks)
+]
+show(dogs_with_masks)
+
+#####################################
+# The two 'people' masks in the first image were not selected because they have
+# a lower score than the score threshold. Similarly, in the second image, the
+# instance with class 15 (which corresponds to 'bench') was not selected.
diff --git a/test/assets/fakedata/draw_segm_masks_colors_util.png b/test/assets/fakedata/draw_segm_masks_colors_util.png
deleted file mode 100644
index 454b3555631..00000000000
Binary files a/test/assets/fakedata/draw_segm_masks_colors_util.png and /dev/null differ
diff --git a/test/assets/fakedata/draw_segm_masks_no_colors_util.png b/test/assets/fakedata/draw_segm_masks_no_colors_util.png
deleted file mode 100644
index f048d2469d2..00000000000
Binary files a/test/assets/fakedata/draw_segm_masks_no_colors_util.png and /dev/null differ
diff --git a/test/test_utils.py b/test/test_utils.py
index 8c4cc620229..ee683b27ca4 100644
--- a/test/test_utils.py
+++ b/test/test_utils.py
@@ -1,3 +1,4 @@
+import pytest
 import numpy as np
 import os
 import sys
@@ -7,7 +8,7 @@ import unittest
 from io import BytesIO
 
 import torchvision.transforms.functional as F
-from PIL import Image, __version__ as PILLOW_VERSION
+from PIL import Image, __version__ as PILLOW_VERSION, ImageColor
 
 PILLOW_VERSION = tuple(int(x) for x in PILLOW_VERSION.split('.'))
 
@@ -159,55 +160,88 @@ def test_draw_invalid_boxes(self):
         self.assertRaises(ValueError, utils.draw_bounding_boxes, img_wrong1, boxes)
         self.assertRaises(ValueError, utils.draw_bounding_boxes, img_wrong2, boxes)
 
-    def test_draw_segmentation_masks_colors(self):
-        img = torch.full((3, 5, 5), 255, dtype=torch.uint8)
-        img_cp = img.clone()
-        masks_cp = masks.clone()
-        colors = ["#FF00FF", (0, 255, 0), "red"]
-        result = utils.draw_segmentation_masks(img, masks, colors=colors)
-
-        path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets",
-                            "fakedata", "draw_segm_masks_colors_util.png")
-
-        if not os.path.exists(path):
-            res = Image.fromarray(result.permute(1, 2, 0).contiguous().numpy())
-            res.save(path)
-
-        expected = torch.as_tensor(np.array(Image.open(path))).permute(2, 0, 1)
-        self.assertTrue(torch.equal(result, expected))
-        # Check if modification is not in place
-        self.assertTrue(torch.all(torch.eq(img, img_cp)).item())
-        self.assertTrue(torch.all(torch.eq(masks, masks_cp)).item())
-
-    def test_draw_segmentation_masks_no_colors(self):
-        img = torch.full((3, 20, 20), 255, dtype=torch.uint8)
-        img_cp = img.clone()
-        masks_cp = masks.clone()
-        result = utils.draw_segmentation_masks(img, masks, colors=None)
-
-        path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets",
-                            "fakedata", "draw_segm_masks_no_colors_util.png")
-
-        if not os.path.exists(path):
-            res = Image.fromarray(result.permute(1, 2, 0).contiguous().numpy())
-            res.save(path)
-
-        expected = torch.as_tensor(np.array(Image.open(path))).permute(2, 0, 1)
-        self.assertTrue(torch.equal(result, expected))
-        # Check if modification is not in place
-        self.assertTrue(torch.all(torch.eq(img, img_cp)).item())
-        self.assertTrue(torch.all(torch.eq(masks, masks_cp)).item())
-
-    def test_draw_invalid_masks(self):
-        img_tp = ((1, 1, 1), (1, 2, 3))
-        img_wrong1 = torch.full((3, 5, 5), 255, dtype=torch.float)
-        img_wrong2 = torch.full((1, 3, 5, 5), 255, dtype=torch.uint8)
-        img_wrong3 = torch.full((4, 5, 5), 255, dtype=torch.uint8)
-        self.assertRaises(TypeError, utils.draw_segmentation_masks, img_tp, masks)
-        self.assertRaises(ValueError, utils.draw_segmentation_masks, img_wrong1, masks)
-        self.assertRaises(ValueError, utils.draw_segmentation_masks, img_wrong2, masks)
-        self.assertRaises(ValueError, utils.draw_segmentation_masks, img_wrong3, masks)
+
+@pytest.mark.parametrize('colors', [
+    None,
+    ['red', 'blue'],
+    ['#FF00FF', (1, 34, 122)],
+])
+@pytest.mark.parametrize('alpha', (0, .5, .7, 1))
+def test_draw_segmentation_masks(colors, alpha):
+    """This test makes sure that masks draw their corresponding color where they should"""
+    num_masks, h, w = 2, 100, 100
+    dtype = torch.uint8
+    img = torch.randint(0, 256, size=(3, h, w), dtype=dtype)
+    masks = torch.randint(0, 2, (num_masks, h, w), dtype=torch.bool)
+
+    # For testing we enforce that there's no overlap between the masks. The
+    # current behaviour is that the last mask's color will take priority when
+    # masks overlap, but this makes testing slightly harder so we don't really
+    # care
+    overlap = masks[0] & masks[1]
+    masks[:, overlap] = False
+
+    out = utils.draw_segmentation_masks(img, masks, colors=colors, alpha=alpha)
+    assert out.dtype == dtype
+    assert out is not img
+
+    # Make sure the image didn't change where there's no mask
+    masked_pixels = masks[0] | masks[1]
+    assert (img[:, ~masked_pixels] == out[:, ~masked_pixels]).all()
+
+    if colors is None:
+        colors = utils._generate_color_palette(num_masks)
+
+    # Make sure each mask draws with its own color
+    for mask, color in zip(masks, colors):
+        if isinstance(color, str):
+            color = ImageColor.getrgb(color)
+        color = torch.tensor(color, dtype=dtype)
+
+        if alpha == 1:
+            assert (out[:, mask] == color[:, None]).all()
+        elif alpha == 0:
+            assert (out[:, mask] == img[:, mask]).all()
+
+        interpolated_color = (img[:, mask] * (1 - alpha) + color[:, None] * alpha)
+        max_diff = (out[:, mask] - interpolated_color).abs().max()
+        assert max_diff <= 1
+
+
+def test_draw_segmentation_masks_errors():
+    h, w = 10, 10
+
+    masks = torch.randint(0, 2, size=(h, w), dtype=torch.bool)
+    img = torch.randint(0, 256, size=(3, h, w), dtype=torch.uint8)
+
+    with pytest.raises(TypeError, match="The image must be a tensor"):
+        utils.draw_segmentation_masks(image="Not A Tensor Image", masks=masks)
+    with pytest.raises(ValueError, match="The image dtype must be"):
+        img_bad_dtype = torch.randint(0, 256, size=(3, h, w), dtype=torch.int64)
+        utils.draw_segmentation_masks(image=img_bad_dtype, masks=masks)
+    with pytest.raises(ValueError, match="Pass individual images, not batches"):
+        batch = torch.randint(0, 256, size=(10, 3, h, w), dtype=torch.uint8)
+        utils.draw_segmentation_masks(image=batch, masks=masks)
+    with pytest.raises(ValueError, match="Pass an RGB image"):
+        one_channel = torch.randint(0, 256, size=(1, h, w), dtype=torch.uint8)
+        utils.draw_segmentation_masks(image=one_channel, masks=masks)
+    with pytest.raises(ValueError, match="The masks must be of dtype bool"):
+        masks_bad_dtype = torch.randint(0, 2, size=(h, w), dtype=torch.float)
+        utils.draw_segmentation_masks(image=img, masks=masks_bad_dtype)
+    with pytest.raises(ValueError, match="masks must be of shape"):
+        masks_bad_shape = torch.randint(0, 2, size=(3, 2, h, w), dtype=torch.bool)
+        utils.draw_segmentation_masks(image=img, masks=masks_bad_shape)
+    with pytest.raises(ValueError, match="must have the same height and width"):
+        masks_bad_shape = torch.randint(0, 2, size=(h + 4, w), dtype=torch.bool)
+        utils.draw_segmentation_masks(image=img, masks=masks_bad_shape)
+    with pytest.raises(ValueError, match="There are more masks"):
+        utils.draw_segmentation_masks(image=img, masks=masks, colors=[])
+    with pytest.raises(ValueError, match="colors must be a tuple or a string, or a list thereof"):
+        bad_colors = np.array(['red', 'blue'])  # should be a list
+        utils.draw_segmentation_masks(image=img, masks=masks, colors=bad_colors)
+    with pytest.raises(ValueError, match="It seems that you passed a tuple of colors instead of"):
+        bad_colors = ('red', 'blue')  # should be a list
+        utils.draw_segmentation_masks(image=img, masks=masks, colors=bad_colors)
 
 
 if __name__ == '__main__':
diff --git a/torchvision/utils.py b/torchvision/utils.py
index 9d9bbdb3c80..8b23cae6eee 100644
--- a/torchvision/utils.py
+++ b/torchvision/utils.py
@@ -220,7 +220,7 @@ def draw_bounding_boxes(
 
 def draw_segmentation_masks(
     image: torch.Tensor,
     masks: torch.Tensor,
-    alpha: float = 0.2,
+    alpha: float = 0.8,
     colors: Optional[List[Union[str, Tuple[int, int, int]]]] = None,
 ) -> torch.Tensor:
@@ -229,49 +229,68 @@
     The values of the input image should be uint8 between 0 and 255.
 
     Args:
-        image (Tensor): Tensor of shape (3 x H x W) and dtype uint8.
-        masks (Tensor): Tensor of shape (num_masks, H, W). Each containing probability of predicted class.
-        alpha (float): Float number between 0 and 1 denoting factor of transparency of masks.
-        colors (List[Union[str, Tuple[int, int, int]]]): List containing the colors of masks. The colors can
-            be represented as `str` or `Tuple[int, int, int]`.
+        image (Tensor): Tensor of shape (3, H, W) and dtype uint8.
+        masks (Tensor): Tensor of shape (num_masks, H, W) or (H, W) and dtype bool.
+        alpha (float): Float number between 0 and 1 denoting the transparency of the masks.
+            0 means full transparency, 1 means no transparency.
+        colors (list or None): List containing the colors of the masks. The colors can
+            be represented as PIL strings e.g. "red" or "#FF00FF", or as RGB tuples e.g. ``(240, 10, 157)``.
+            When ``masks`` has a single entry of shape (H, W), you can pass a single color instead of a list
+            with one element. By default, random colors are generated for each mask.
 
     Returns:
-        img (Tensor[C, H, W]): Image Tensor of dtype uint8 with segmentation masks plotted.
+        img (Tensor[C, H, W]): Image Tensor, with segmentation masks drawn on top.
     """
 
     if not isinstance(image, torch.Tensor):
-        raise TypeError(f"Tensor expected, got {type(image)}")
+        raise TypeError(f"The image must be a tensor, got {type(image)}")
     elif image.dtype != torch.uint8:
-        raise ValueError(f"Tensor uint8 expected, got {image.dtype}")
+        raise ValueError(f"The image dtype must be uint8, got {image.dtype}")
     elif image.dim() != 3:
         raise ValueError("Pass individual images, not batches")
     elif image.size()[0] != 3:
         raise ValueError("Pass an RGB image. Other Image formats are not supported")
+    if masks.ndim == 2:
+        masks = masks[None, :, :]
+    if masks.ndim != 3:
+        raise ValueError("masks must be of shape (H, W) or (batch_size, H, W)")
+    if masks.dtype != torch.bool:
+        raise ValueError(f"The masks must be of dtype bool. Got {masks.dtype}")
+    if masks.shape[-2:] != image.shape[-2:]:
+        raise ValueError("The image and the masks must have the same height and width")
 
     num_masks = masks.size()[0]
-    masks = masks.argmax(0)
+    if colors is not None and num_masks > len(colors):
+        raise ValueError(f"There are more masks ({num_masks}) than colors ({len(colors)})")
 
     if colors is None:
-        palette = torch.tensor([2 ** 25 - 1, 2 ** 15 - 1, 2 ** 21 - 1])
-        colors_t = torch.as_tensor([i for i in range(num_masks)])[:, None] * palette
-        color_arr = (colors_t % 255).numpy().astype("uint8")
-    else:
-        color_list = []
-        for color in colors:
-            if isinstance(color, str):
-                # This will automatically raise Error if rgb cannot be parsed.
-                fill_color = ImageColor.getrgb(color)
-                color_list.append(fill_color)
-            elif isinstance(color, tuple):
-                color_list.append(color)
+        colors = _generate_color_palette(num_masks)
+
+    if not isinstance(colors, list):
+        colors = [colors]
+    if not isinstance(colors[0], (tuple, str)):
+        raise ValueError("colors must be a tuple or a string, or a list thereof")
+    if isinstance(colors[0], tuple) and len(colors[0]) != 3:
+        raise ValueError("It seems that you passed a tuple of colors instead of a list of colors")
+
+    out_dtype = torch.uint8
+
+    colors_ = []
+    for color in colors:
+        if isinstance(color, str):
+            color = ImageColor.getrgb(color)
+        color = torch.tensor(color, dtype=out_dtype)
+        colors_.append(color)
 
-        color_arr = np.array(color_list).astype("uint8")
+    img_to_draw = image.detach().clone()
+    # TODO: There might be a way to vectorize this
+    for mask, color in zip(masks, colors_):
+        img_to_draw[:, mask] = color[:, None]
 
-    _, h, w = image.size()
-    img_to_draw = Image.fromarray(masks.byte().cpu().numpy()).resize((w, h))
-    img_to_draw.putpalette(color_arr)
+    out = image * (1 - alpha) + img_to_draw * alpha
+    return out.to(out_dtype)
 
-    img_to_draw = torch.from_numpy(np.array(img_to_draw.convert('RGB')))
-    img_to_draw = img_to_draw.permute((2, 0, 1))
-    return (image.float() * alpha + img_to_draw.float() * (1.0 - alpha)).to(dtype=torch.uint8)
+
+def _generate_color_palette(num_masks):
+    palette = torch.tensor([2 ** 25 - 1, 2 ** 15 - 1, 2 ** 21 - 1])
+    return [tuple((i * palette) % 255) for i in range(num_masks)]
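
Usage sketch (not part of the patch): a minimal example of ``draw_segmentation_masks`` as changed above, assuming a torchvision build with this patch applied. It exercises the new boolean-mask contract and the new alpha semantics; the names ``img``, ``masks`` and ``out`` are illustrative:

    import torch
    from torchvision.utils import draw_segmentation_masks

    # A random uint8 RGB image and two non-overlapping boolean masks
    img = torch.randint(0, 256, size=(3, 100, 100), dtype=torch.uint8)
    masks = torch.zeros(2, 100, 100, dtype=torch.bool)
    masks[0, :50] = True   # top half
    masks[1, 50:] = True   # bottom half

    # Colors may be PIL color strings or RGB tuples; alpha now defaults to 0.8
    # (alpha=0 keeps the original pixels, alpha=1 paints the solid mask color)
    out = draw_segmentation_masks(img, masks, colors=["red", (0, 255, 0)], alpha=0.8)
    assert out.dtype == torch.uint8 and out.shape == img.shape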