From 5a2bbc57d1c5d6bcecbb35346e5fab7e68e0c47c Mon Sep 17 00:00:00 2001 From: Sasank Chilamkurthy Date: Sun, 3 Sep 2017 23:33:22 +0530 Subject: [PATCH 1/9] First cut refactoring (cherry picked from commit 71afec427baca8e37cd9e10d98812bc586e9a4ac) --- torchvision/transforms.py | 246 ++++++++++++++++++++++---------------- 1 file changed, 146 insertions(+), 100 deletions(-) diff --git a/torchvision/transforms.py b/torchvision/transforms.py index a413fb68553..2137ed90346 100644 --- a/torchvision/transforms.py +++ b/torchvision/transforms.py @@ -13,6 +13,112 @@ import collections +def to_tensor(pic): + if isinstance(pic, np.ndarray): + # handle numpy array + img = torch.from_numpy(pic.transpose((2, 0, 1))) + # backward compatibility + return img.float().div(255) + + if accimage is not None and isinstance(pic, accimage.Image): + nppic = np.zeros([pic.channels, pic.height, pic.width], dtype=np.float32) + pic.copyto(nppic) + return torch.from_numpy(nppic) + + # handle PIL Image + if pic.mode == 'I': + img = torch.from_numpy(np.array(pic, np.int32, copy=False)) + elif pic.mode == 'I;16': + img = torch.from_numpy(np.array(pic, np.int16, copy=False)) + else: + img = torch.ByteTensor(torch.ByteStorage.from_buffer(pic.tobytes())) + # PIL image mode: 1, L, P, I, F, RGB, YCbCr, RGBA, CMYK + if pic.mode == 'YCbCr': + nchannel = 3 + elif pic.mode == 'I;16': + nchannel = 1 + else: + nchannel = len(pic.mode) + img = img.view(pic.size[1], pic.size[0], nchannel) + # put it from HWC to CHW format + # yikes, this transpose takes 80% of the loading time/CPU + img = img.transpose(0, 1).transpose(0, 2).contiguous() + if isinstance(img, torch.ByteTensor): + return img.float().div(255) + else: + return img + + +def to_pilimage(pic): + npimg = pic + mode = None + if isinstance(pic, torch.FloatTensor): + pic = pic.mul(255).byte() + if torch.is_tensor(pic): + npimg = np.transpose(pic.numpy(), (1, 2, 0)) + assert isinstance(npimg, np.ndarray), 'pic should be Tensor or ndarray' + if npimg.shape[2] == 1: + npimg = npimg[:, :, 0] + + if npimg.dtype == np.uint8: + mode = 'L' + if npimg.dtype == np.int16: + mode = 'I;16' + if npimg.dtype == np.int32: + mode = 'I' + elif npimg.dtype == np.float32: + mode = 'F' + else: + if npimg.dtype == np.uint8: + mode = 'RGB' + assert mode is not None, '{} is not supported'.format(npimg.dtype) + return Image.fromarray(npimg, mode=mode) + + +def normalize(tensor, mean, std): + # TODO: make efficient + for t, m, s in zip(tensor, mean, std): + t.sub_(m).div_(s) + return tensor + + +def scale(img, size, interpolation=Image.BILINEAR): + assert isinstance(size, int) or (isinstance(size, collections.Iterable) and len(size) == 2) + if isinstance(size, int): + w, h = img.size + if (w <= h and w == size) or (h <= w and h == size): + return img + if w < h: + ow = size + oh = int(size * h / w) + return img.resize((ow, oh), interpolation) + else: + oh = size + ow = int(size * w / h) + return img.resize((ow, oh), interpolation) + else: + return img.resize(size, interpolation) + + +def pad(img, padding, fill=0): + assert isinstance(padding, numbers.Number) + assert isinstance(fill, numbers.Number) or isinstance(fill, str) or isinstance(fill, tuple) + return ImageOps.expand(img, border=padding, fill=fill) + + +def crop(img, x, y, w, h): + return img.crop((x, y, x + w, y + h)) + + +def scaled_crop(img, x, y, w, h, size, interpolation=Image.BILINEAR): + img = crop(img, x, y, w, h) + img = scale(img, size, interpolation) + + +def hflip(img): + return img.transpose(Image.FLIP_LEFT_RIGHT) + + class 
Compose(object): """Composes several transforms together. @@ -50,39 +156,7 @@ def __call__(self, pic): Returns: Tensor: Converted image. """ - if isinstance(pic, np.ndarray): - # handle numpy array - img = torch.from_numpy(pic.transpose((2, 0, 1))) - # backward compatibility - return img.float().div(255) - - if accimage is not None and isinstance(pic, accimage.Image): - nppic = np.zeros([pic.channels, pic.height, pic.width], dtype=np.float32) - pic.copyto(nppic) - return torch.from_numpy(nppic) - - # handle PIL Image - if pic.mode == 'I': - img = torch.from_numpy(np.array(pic, np.int32, copy=False)) - elif pic.mode == 'I;16': - img = torch.from_numpy(np.array(pic, np.int16, copy=False)) - else: - img = torch.ByteTensor(torch.ByteStorage.from_buffer(pic.tobytes())) - # PIL image mode: 1, L, P, I, F, RGB, YCbCr, RGBA, CMYK - if pic.mode == 'YCbCr': - nchannel = 3 - elif pic.mode == 'I;16': - nchannel = 1 - else: - nchannel = len(pic.mode) - img = img.view(pic.size[1], pic.size[0], nchannel) - # put it from HWC to CHW format - # yikes, this transpose takes 80% of the loading time/CPU - img = img.transpose(0, 1).transpose(0, 2).contiguous() - if isinstance(img, torch.ByteTensor): - return img.float().div(255) - else: - return img + return to_tensor(pic) class ToPILImage(object): @@ -101,29 +175,7 @@ def __call__(self, pic): PIL.Image: Image converted to PIL.Image. """ - npimg = pic - mode = None - if isinstance(pic, torch.FloatTensor): - pic = pic.mul(255).byte() - if torch.is_tensor(pic): - npimg = np.transpose(pic.numpy(), (1, 2, 0)) - assert isinstance(npimg, np.ndarray), 'pic should be Tensor or ndarray' - if npimg.shape[2] == 1: - npimg = npimg[:, :, 0] - - if npimg.dtype == np.uint8: - mode = 'L' - if npimg.dtype == np.int16: - mode = 'I;16' - if npimg.dtype == np.int32: - mode = 'I' - elif npimg.dtype == np.float32: - mode = 'F' - else: - if npimg.dtype == np.uint8: - mode = 'RGB' - assert mode is not None, '{} is not supported'.format(npimg.dtype) - return Image.fromarray(npimg, mode=mode) + return to_pilimage(pic) class Normalize(object): @@ -151,10 +203,7 @@ def __call__(self, tensor): Returns: Tensor: Normalized image. """ - # TODO: make efficient - for t, m, s in zip(tensor, self.mean, self.std): - t.sub_(m).div_(s) - return tensor + return normalize(tensor, self.mean, self.std) class Scale(object): @@ -183,20 +232,7 @@ def __call__(self, img): Returns: PIL.Image: Rescaled image. """ - if isinstance(self.size, int): - w, h = img.size - if (w <= h and w == self.size) or (h <= w and h == self.size): - return img - if w < h: - ow = self.size - oh = int(self.size * h / w) - return img.resize((ow, oh), self.interpolation) - else: - oh = self.size - ow = int(self.size * w / h) - return img.resize((ow, oh), self.interpolation) - else: - return img.resize(self.size, self.interpolation) + return scale(img, self.size, self.interpolation) class CenterCrop(object): @@ -214,6 +250,13 @@ def __init__(self, size): else: self.size = size + def get_params(self, img): + w, h = img.size + th, tw = self.size + x1 = int(round((w - tw) / 2.)) + y1 = int(round((h - th) / 2.)) + return x1, y1, tw, th + def __call__(self, img): """ Args: @@ -222,11 +265,8 @@ def __call__(self, img): Returns: PIL.Image: Cropped image. 
""" - w, h = img.size - th, tw = self.size - x1 = int(round((w - tw) / 2.)) - y1 = int(round((h - th) / 2.)) - return img.crop((x1, y1, x1 + tw, y1 + th)) + x1, y1, tw, th = self.get_params(img) + return crop(img, x1, y1, tw, th) class Pad(object): @@ -260,7 +300,7 @@ def __call__(self, img): Returns: PIL.Image: Padded image. """ - return ImageOps.expand(img, border=self.padding, fill=self.fill) + return pad(img, self.padding, self.fill) class Lambda(object): @@ -298,6 +338,16 @@ def __init__(self, size, padding=0): self.size = size self.padding = padding + def get_params(self, img): + w, h = img.size + th, tw = self.size + if w == tw and h == th: + return img + + x1 = random.randint(0, w - tw) + y1 = random.randint(0, h - th) + return x1, y1, tw, th + def __call__(self, img): """ Args: @@ -307,16 +357,11 @@ def __call__(self, img): PIL.Image: Cropped image. """ if self.padding > 0: - img = ImageOps.expand(img, border=self.padding, fill=0) + img = pad(img, self.padding) - w, h = img.size - th, tw = self.size - if w == tw and h == th: - return img + x1, y1, tw, th = self.get_params(img) - x1 = random.randint(0, w - tw) - y1 = random.randint(0, h - th) - return img.crop((x1, y1, x1 + tw, y1 + th)) + return crop(img, x1, y1, tw, th) class RandomHorizontalFlip(object): @@ -331,7 +376,7 @@ def __call__(self, img): PIL.Image: Randomly flipped image. """ if random.random() < 0.5: - return img.transpose(Image.FLIP_LEFT_RIGHT) + return hflip(img) return img @@ -352,7 +397,7 @@ def __init__(self, size, interpolation=Image.BILINEAR): self.size = size self.interpolation = interpolation - def __call__(self, img): + def get_params(self, img): for attempt in range(10): area = img.size[0] * img.size[1] target_area = random.uniform(0.08, 1.0) * area @@ -365,15 +410,16 @@ def __call__(self, img): w, h = h, w if w <= img.size[0] and h <= img.size[1]: - x1 = random.randint(0, img.size[0] - w) - y1 = random.randint(0, img.size[1] - h) - - img = img.crop((x1, y1, x1 + w, y1 + h)) - assert(img.size == (w, h)) - - return img.resize((self.size, self.size), self.interpolation) + x = random.randint(0, img.size[0] - w) + y = random.randint(0, img.size[1] - h) + return x, y, w, h # Fallback - scale = Scale(self.size, interpolation=self.interpolation) - crop = CenterCrop(self.size) - return crop(scale(img)) + w = min(img.size[0], img.shape[1]) + x = (img.shape[0] - w) // 2 + y = (img.shape[1] - w) // 2 + return x, y, w, w + + def __call__(self, img): + x, y, w, h = self.get_params(img) + return scaled_crop(img, x, y, w, h, self.size, self.interpolation) From bf38166d59da13a4917c34269fec7b07681f2a41 Mon Sep 17 00:00:00 2001 From: Sasank Chilamkurthy Date: Sun, 3 Sep 2017 23:42:36 +0530 Subject: [PATCH 2/9] Modify assert for pad --- torchvision/transforms.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/torchvision/transforms.py b/torchvision/transforms.py index 2137ed90346..410e7ceeb81 100644 --- a/torchvision/transforms.py +++ b/torchvision/transforms.py @@ -101,8 +101,12 @@ def scale(img, size, interpolation=Image.BILINEAR): def pad(img, padding, fill=0): - assert isinstance(padding, numbers.Number) - assert isinstance(fill, numbers.Number) or isinstance(fill, str) or isinstance(fill, tuple) + assert isinstance(padding, (numbers.Number, tuple)) + assert isinstance(fill, (numbers.Number, str, tuple)) + if isinstance(padding, collections.Sequence) and len(padding) not in [2, 4]: + raise ValueError("Padding must be an int or a 2, or 4 element tuple, not a " + + "{} element 
tuple".format(len(padding)))
+
     return ImageOps.expand(img, border=padding, fill=fill)
 

From 7aeec57f749e3e7b2f8680e10da52ca7b3e7aac9 Mon Sep 17 00:00:00 2001
From: Sasank Chilamkurthy
Date: Mon, 4 Sep 2017 00:32:02 +0530
Subject: [PATCH 3/9] Asserts for functions

---
 torchvision/transforms.py | 27 ++++++++++++++++++++++++++-
 1 file changed, 26 insertions(+), 1 deletion(-)

diff --git a/torchvision/transforms.py b/torchvision/transforms.py
index 410e7ceeb81..cec628da179 100644
--- a/torchvision/transforms.py
+++ b/torchvision/transforms.py
@@ -13,7 +13,24 @@
 import collections
 
 
+def _is_pil_image(img):
+    if accimage is not None:
+        return isinstance(img, (Image.Image, accimage.Image))
+    else:
+        return isinstance(img, Image.Image)
+
+
+def _is_tensor_image(img):
+    return torch.is_tensor(img) and img.ndimension() == 3
+
+
+def _is_numpy_image(img):
+    return isinstance(img, np.ndarray) and (img.ndim in {2, 3})
+
+
 def to_tensor(pic):
+    assert _is_pil_image(pic) or _is_numpy_image(pic), 'pic should be PIL Image or ndarray'
+
     if isinstance(pic, np.ndarray):
         # handle numpy array
         img = torch.from_numpy(pic.transpose((2, 0, 1)))
@@ -50,13 +67,15 @@ def to_tensor(pic):
 
 
 def to_pilimage(pic):
+    assert _is_numpy_image(pic) or _is_tensor_image(pic), 'pic should be Tensor or ndarray'
+
     npimg = pic
     mode = None
     if isinstance(pic, torch.FloatTensor):
         pic = pic.mul(255).byte()
     if torch.is_tensor(pic):
         npimg = np.transpose(pic.numpy(), (1, 2, 0))
-    assert isinstance(npimg, np.ndarray), 'pic should be Tensor or ndarray'
+    assert isinstance(npimg, np.ndarray)
     if npimg.shape[2] == 1:
         npimg = npimg[:, :, 0]
 
@@ -76,6 +95,7 @@ def to_pilimage(pic):
 
 
 def normalize(tensor, mean, std):
+    assert _is_tensor_image(tensor)
     # TODO: make efficient
     for t, m, s in zip(tensor, mean, std):
         t.sub_(m).div_(s)
@@ -83,6 +103,7 @@ def normalize(tensor, mean, std):
 
 
 def scale(img, size, interpolation=Image.BILINEAR):
+    assert _is_pil_image(img), 'img should be PIL Image'
     assert isinstance(size, int) or (isinstance(size, collections.Iterable) and len(size) == 2)
     if isinstance(size, int):
         w, h = img.size
@@ -101,6 +122,7 @@ def scale(img, size, interpolation=Image.BILINEAR):
 
 
 def pad(img, padding, fill=0):
+    assert _is_pil_image(img), 'img should be PIL Image'
     assert isinstance(padding, (numbers.Number, tuple))
     assert isinstance(fill, (numbers.Number, str, tuple))
     if isinstance(padding, collections.Sequence) and len(padding) not in [2, 4]:
@@ -111,15 +133,18 @@ def pad(img, padding, fill=0):
 
 
 def crop(img, x, y, w, h):
+    assert _is_pil_image(img), 'img should be PIL Image'
     return img.crop((x, y, x + w, y + h))
 
 
 def scaled_crop(img, x, y, w, h, size, interpolation=Image.BILINEAR):
+    assert _is_pil_image(img), 'img should be PIL Image'
     img = crop(img, x, y, w, h)
     img = scale(img, size, interpolation)
 
 
 def hflip(img):
+    assert _is_pil_image(img), 'img should be PIL Image'
     return img.transpose(Image.FLIP_LEFT_RIGHT)
 
 

From 8b18f526a7d1a728448c683e050489a99d967845 Mon Sep 17 00:00:00 2001
From: Sasank Chilamkurthy
Date: Mon, 4 Sep 2017 14:51:24 +0530
Subject: [PATCH 4/9] raise TypeErrors instead of assertions

---
 torchvision/transforms.py | 35 +++++++++++++++++++++++++----------
 1 file changed, 25 insertions(+), 10 deletions(-)

diff --git a/torchvision/transforms.py b/torchvision/transforms.py
index cec628da179..bf3ecb0ba9e 100644
--- a/torchvision/transforms.py
+++ b/torchvision/transforms.py
@@ -29,7 +29,8 @@ def _is_numpy_image(img):
 
 
 def to_tensor(pic):
-    assert _is_pil_image(pic) or _is_numpy_image(pic), 'pic should be PIL Image or ndarray'
+    if 
not(_is_pil_image(pic) or _is_numpy_image(pic)): + raise TypeError('pic should be PIL Image or ndarray. Got {}'.format(type(pic))) if isinstance(pic, np.ndarray): # handle numpy array @@ -67,7 +68,8 @@ def to_tensor(pic): def to_pilimage(pic): - assert _is_numpy_image(pic) or _is_tensor_image(pic), 'pic should be Tensor or ndarray' + if not(_is_numpy_image(pic) or _is_tensor_image(pic)): + raise TypeError('pic should be Tensor or ndarray. Got {}.'.format(type(pic))) npimg = pic mode = None @@ -95,7 +97,8 @@ def to_pilimage(pic): def normalize(tensor, mean, std): - assert _is_tensor_image(tensor) + if not _is_tensor_image(tensor): + raise TypeError('tensor is not a torch image.') # TODO: make efficient for t, m, s in zip(tensor, mean, std): t.sub_(m).div_(s) @@ -103,8 +106,11 @@ def normalize(tensor, mean, std): def scale(img, size, interpolation=Image.BILINEAR): - assert _is_pil_image(img), 'img should be PIL Image' - assert isinstance(size, int) or (isinstance(size, collections.Iterable) and len(size) == 2) + if not _is_pil_image(img): + raise TypeError('img should be PIL Image. Got {}'.format(type(img))) + if not (isinstance(size, int) or (isinstance(size, collections.Iterable) and len(size) == 2)): + raise TypeError('Got inappropriate size arg: {}'.format(size)) + if isinstance(size, int): w, h = img.size if (w <= h and w == size) or (h <= w and h == size): @@ -122,9 +128,14 @@ def scale(img, size, interpolation=Image.BILINEAR): def pad(img, padding, fill=0): - assert _is_pil_image(img), 'img should be PIL Image' - assert isinstance(padding, (numbers.Number, tuple)) - assert isinstance(fill, (numbers.Number, str, tuple)) + if not _is_pil_image(img): + raise TypeError('img should be PIL Image. Got {}'.format(type(img))) + + if not isinstance(padding, (numbers.Number, tuple)): + raise TypeError('Got inappropriate padding arg') + if not isinstance(fill, (numbers.Number, str, tuple)): + raise TypeError('Got inappropriate fill arg') + if isinstance(padding, collections.Sequence) and len(padding) not in [2, 4]: raise ValueError("Padding must be an int or a 2, or 4 element tuple, not a " + "{} element tuple".format(len(padding))) @@ -133,7 +144,9 @@ def pad(img, padding, fill=0): def crop(img, x, y, w, h): - assert _is_pil_image(img), 'img should be PIL Image' + if not _is_pil_image(img): + raise TypeError('img should be PIL Image. Got {}'.format(type(img))) + return img.crop((x, y, x + w, y + h)) @@ -144,7 +157,9 @@ def scaled_crop(img, x, y, w, h, size, interpolation=Image.BILINEAR): def hflip(img): - assert _is_pil_image(img), 'img should be PIL Image' + if not _is_pil_image(img): + raise TypeError('img should be PIL Image. Got {}'.format(type(img))) + return img.transpose(Image.FLIP_LEFT_RIGHT) From 4390b559afc2774f0b506e59a82a0b1e9c374c52 Mon Sep 17 00:00:00 2001 From: Sasank Chilamkurthy Date: Sat, 16 Sep 2017 22:16:37 +0530 Subject: [PATCH 5/9] Make get_params static method --- torchvision/transforms.py | 21 ++++++++++++--------- 1 file changed, 12 insertions(+), 9 deletions(-) diff --git a/torchvision/transforms.py b/torchvision/transforms.py index bf3ecb0ba9e..7f5c20af05b 100644 --- a/torchvision/transforms.py +++ b/torchvision/transforms.py @@ -67,7 +67,7 @@ def to_tensor(pic): return img -def to_pilimage(pic): +def to_pil_image(pic): if not(_is_numpy_image(pic) or _is_tensor_image(pic)): raise TypeError('pic should be Tensor or ndarray. Got {}.'.format(type(pic))) @@ -219,7 +219,7 @@ def __call__(self, pic): PIL.Image: Image converted to PIL.Image. 
""" - return to_pilimage(pic) + return to_pil_image(pic) class Normalize(object): @@ -294,9 +294,10 @@ def __init__(self, size): else: self.size = size - def get_params(self, img): + @staticmethod + def get_params(img, output_size): w, h = img.size - th, tw = self.size + th, tw = output_size x1 = int(round((w - tw) / 2.)) y1 = int(round((h - th) / 2.)) return x1, y1, tw, th @@ -309,7 +310,7 @@ def __call__(self, img): Returns: PIL.Image: Cropped image. """ - x1, y1, tw, th = self.get_params(img) + x1, y1, tw, th = self.get_params(img, self.size) return crop(img, x1, y1, tw, th) @@ -382,9 +383,10 @@ def __init__(self, size, padding=0): self.size = size self.padding = padding - def get_params(self, img): + @staticmethod + def get_params(img, output_size): w, h = img.size - th, tw = self.size + th, tw = output_size if w == tw and h == th: return img @@ -403,7 +405,7 @@ def __call__(self, img): if self.padding > 0: img = pad(img, self.padding) - x1, y1, tw, th = self.get_params(img) + x1, y1, tw, th = self.get_params(img, self.size) return crop(img, x1, y1, tw, th) @@ -441,7 +443,8 @@ def __init__(self, size, interpolation=Image.BILINEAR): self.size = size self.interpolation = interpolation - def get_params(self, img): + @staticmethod + def get_params(img): for attempt in range(10): area = img.size[0] * img.size[1] target_area = random.uniform(0.08, 1.0) * area From f4ddc92406c2e627a5ae858e5be0730fc5de86d7 Mon Sep 17 00:00:00 2001 From: Sasank Chilamkurthy Date: Sat, 16 Sep 2017 22:54:03 +0530 Subject: [PATCH 6/9] Add documentation --- torchvision/transforms.py | 136 +++++++++++++++++++++++++++++++++++++- 1 file changed, 135 insertions(+), 1 deletion(-) diff --git a/torchvision/transforms.py b/torchvision/transforms.py index 7f5c20af05b..2a4650803fc 100644 --- a/torchvision/transforms.py +++ b/torchvision/transforms.py @@ -29,6 +29,16 @@ def _is_numpy_image(img): def to_tensor(pic): + """Convert a ``PIL.Image`` or ``numpy.ndarray`` to tensor. + + See ``ToTensor`` for more details. + + Args: + pic (PIL.Image or numpy.ndarray): Image to be converted to tensor. + + Returns: + Tensor: Converted image. + """ if not(_is_pil_image(pic) or _is_numpy_image(pic)): raise TypeError('pic should be PIL Image or ndarray. Got {}'.format(type(pic))) @@ -68,6 +78,16 @@ def to_tensor(pic): def to_pil_image(pic): + """Convert a tensor or an ndarray to PIL Image. + + See ``ToPIlImage`` for more details. + + Args: + pic (Tensor or numpy.ndarray): Image to be converted to PIL.Image. + + Returns: + PIL.Image: Image converted to PIL.Image. + """ if not(_is_numpy_image(pic) or _is_tensor_image(pic)): raise TypeError('pic should be Tensor or ndarray. Got {}.'.format(type(pic))) @@ -97,6 +117,19 @@ def to_pil_image(pic): def normalize(tensor, mean, std): + """Normalize an tensor image with mean and standard deviation. + + See ``Normalize`` for more details. + + Args: + tensor (Tensor): Tensor image of size (C, H, W) to be normalized. + mean (sequence): Sequence of means for R, G, B channels respecitvely. + std (sequence): Sequence of standard deviations for R, G, B channels + respecitvely. + + Returns: + Tensor: Normalized image. + """ if not _is_tensor_image(tensor): raise TypeError('tensor is not a torch image.') # TODO: make efficient @@ -106,6 +139,21 @@ def normalize(tensor, mean, std): def scale(img, size, interpolation=Image.BILINEAR): + """Rescale the input PIL.Image to the given size. + + Args: + img (PIL.Image): Image to be scaled. + size (sequence or int): Desired output size. 
If size is a sequence like + (w, h), output size will be matched to this. If size is an int, + smaller edge of the image will be matched to this number. + i.e, if height > width, then image will be rescaled to + (size * height / width, size) + interpolation (int, optional): Desired interpolation. Default is + ``PIL.Image.BILINEAR`` + + Returns: + PIL.Image: Rescaled image. + """ if not _is_pil_image(img): raise TypeError('img should be PIL Image. Got {}'.format(type(img))) if not (isinstance(size, int) or (isinstance(size, collections.Iterable) and len(size) == 2)): @@ -128,6 +176,21 @@ def scale(img, size, interpolation=Image.BILINEAR): def pad(img, padding, fill=0): + """Pad the given PIL.Image on all sides with the given "pad" value. + + Args: + img (PIL.Image): Image to be padded. + padding (int or tuple): Padding on each border. If a single int is provided this + is used to pad all borders. If tuple of length 2 is provided this is the padding + on left/right and top/bottom respectively. If a tuple of length 4 is provided + this is the padding for the left, top, right and bottom borders + respectively. + fill: Pixel fill value. Default is 0. If a tuple of + length 3, it is used to fill R, G, B channels respectively. + + Returns: + PIL.Image: Padded image. + """ if not _is_pil_image(img): raise TypeError('img should be PIL Image. Got {}'.format(type(img))) @@ -144,6 +207,18 @@ def pad(img, padding, fill=0): def crop(img, x, y, w, h): + """Crop the given PIL.Image. + + Args: + img (PIL.Image): Image to be cropped. + x: Left pixel coordinate. + y: Upper pixel coordinate. + w: Width of the cropped image. + h: Height of the cropped image. + + Returns: + PIL.Image: Cropped image. + """ if not _is_pil_image(img): raise TypeError('img should be PIL Image. Got {}'.format(type(img))) @@ -151,12 +226,36 @@ def crop(img, x, y, w, h): def scaled_crop(img, x, y, w, h, size, interpolation=Image.BILINEAR): + """Crop the given PIL.Image and scale it to desired size. + + Notably used in RandomSizedCrop. + + Args: + img (PIL.Image): Image to be cropped. + x: Left pixel coordinate. + y: Upper pixel coordinate. + w: Width of the cropped image. + h: Height of the cropped image. + size (sequence or int): Desired output size. Same semantics as ``scale``. + interpolation (int, optional): Desired interpolation. Default is + ``PIL.Image.BILINEAR``. + Returns: + PIL.Image: Cropped image. + """ assert _is_pil_image(img), 'img should be PIL Image' img = crop(img, x, y, w, h) img = scale(img, size, interpolation) def hflip(img): + """Horizontally flip the given PIL.Image. + + Args: + img (PIL.Image): Image to be flipped. + + Returns: + PIL.Image: Horizontall flipped image. + """ if not _is_pil_image(img): raise TypeError('img should be PIL Image. Got {}'.format(type(img))) @@ -204,7 +303,7 @@ def __call__(self, pic): class ToPILImage(object): - """Convert a tensor to PIL Image. + """Convert a tensor or an ndarray to PIL Image. Converts a torch.*Tensor of shape C x H x W or a numpy ndarray of shape H x W x C to a PIL.Image while preserving the value range. @@ -296,6 +395,15 @@ def __init__(self, size): @staticmethod def get_params(img, output_size): + """Get parameters for ``crop`` for center crop. + + Args: + img (PIL.Image): Image to be cropped. + output_size (tuple): Expected output size of the crop. + + Returns: + tuple: params (x, y, w, h) to be passed to ``crop`` for center crop. 
+ """ w, h = img.size th, tw = output_size x1 = int(round((w - tw) / 2.)) @@ -385,6 +493,15 @@ def __init__(self, size, padding=0): @staticmethod def get_params(img, output_size): + """Get parameters for ``crop`` for a random crop. + + Args: + img (PIL.Image): Image to be cropped. + output_size (tuple): Expected output size of the crop. + + Returns: + tuple: params (x, y, w, h) to be passed to ``crop`` for random crop. + """ w, h = img.size th, tw = output_size if w == tw and h == th: @@ -445,6 +562,16 @@ def __init__(self, size, interpolation=Image.BILINEAR): @staticmethod def get_params(img): + """Get parameters for ``crop`` for a random sized crop. + + Args: + img (PIL.Image): Image to be cropped. + output_size (tuple): Expected output size of the crop. + + Returns: + tuple: params (x, y, w, h) to be passed to ``crop`` for a random + sized crop. + """ for attempt in range(10): area = img.size[0] * img.size[1] target_area = random.uniform(0.08, 1.0) * area @@ -468,5 +595,12 @@ def get_params(img): return x, y, w, w def __call__(self, img): + """ + Args: + img (PIL.Image): Image to be flipped. + + Returns: + PIL.Image: Randomly cropped and scaled image. + """ x, y, w, h = self.get_params(img) return scaled_crop(img, x, y, w, h, self.size, self.interpolation) From 538d87b1c7d90b0403d9f3c55bb575ecc3b71150 Mon Sep 17 00:00:00 2001 From: Sasank Chilamkurthy Date: Sun, 17 Sep 2017 11:08:50 +0530 Subject: [PATCH 7/9] Fix a bug in randomsizedcrop --- torchvision/transforms.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/torchvision/transforms.py b/torchvision/transforms.py index 2a4650803fc..b9d6f4a6b38 100644 --- a/torchvision/transforms.py +++ b/torchvision/transforms.py @@ -245,6 +245,7 @@ def scaled_crop(img, x, y, w, h, size, interpolation=Image.BILINEAR): assert _is_pil_image(img), 'img should be PIL Image' img = crop(img, x, y, w, h) img = scale(img, size, interpolation) + return img def hflip(img): @@ -552,12 +553,12 @@ class RandomSizedCrop(object): This is popularly used to train the Inception networks. Args: - size: size of the smaller edge + size: expected output size of each edge interpolation: Default: PIL.Image.BILINEAR """ def __init__(self, size, interpolation=Image.BILINEAR): - self.size = size + self.size = (size, size) self.interpolation = interpolation @staticmethod @@ -566,7 +567,6 @@ def get_params(img): Args: img (PIL.Image): Image to be cropped. - output_size (tuple): Expected output size of the crop. Returns: tuple: params (x, y, w, h) to be passed to ``crop`` for a random From 4d7f70b5d8f523b9677b1fc2d17a217ef3204c82 Mon Sep 17 00:00:00 2001 From: Sasank Chilamkurthy Date: Tue, 19 Sep 2017 11:22:46 +0530 Subject: [PATCH 8/9] scale change to (h, w) ordering. (based on #256) --- torchvision/transforms.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/torchvision/transforms.py b/torchvision/transforms.py index b9d6f4a6b38..dd1034dcfd8 100644 --- a/torchvision/transforms.py +++ b/torchvision/transforms.py @@ -144,7 +144,7 @@ def scale(img, size, interpolation=Image.BILINEAR): Args: img (PIL.Image): Image to be scaled. size (sequence or int): Desired output size. If size is a sequence like - (w, h), output size will be matched to this. If size is an int, + (h, w), output size will be matched to this. If size is an int, smaller edge of the image will be matched to this number. 
i.e, if height > width, then image will be rescaled to (size * height / width, size) @@ -172,7 +172,7 @@ def scale(img, size, interpolation=Image.BILINEAR): ow = int(size * w / h) return img.resize((ow, oh), interpolation) else: - return img.resize(size, interpolation) + return img.resize(size[::-1], interpolation) def pad(img, padding, fill=0): @@ -355,7 +355,7 @@ class Scale(object): Args: size (sequence or int): Desired output size. If size is a sequence like - (w, h), output size will be matched to this. If size is an int, + (h, w), output size will be matched to this. If size is an int, smaller edge of the image will be matched to this number. i.e, if height > width, then image will be rescaled to (size * height / width, size) From 2cc58ed0a01a8f0938a5ccb7452fc0f19c23c5e1 Mon Sep 17 00:00:00 2001 From: Sasank Chilamkurthy Date: Tue, 26 Sep 2017 18:31:05 +0530 Subject: [PATCH 9/9] change x,y,w,h -> i,j,h,w --- torchvision/transforms.py | 62 +++++++++++++++++++-------------------- 1 file changed, 31 insertions(+), 31 deletions(-) diff --git a/torchvision/transforms.py b/torchvision/transforms.py index fbde8ac4e38..75e8a64012e 100644 --- a/torchvision/transforms.py +++ b/torchvision/transforms.py @@ -209,15 +209,15 @@ def pad(img, padding, fill=0): return ImageOps.expand(img, border=padding, fill=fill) -def crop(img, x, y, w, h): +def crop(img, i, j, h, w): """Crop the given PIL.Image. Args: img (PIL.Image): Image to be cropped. - x: Left pixel coordinate. - y: Upper pixel coordinate. - w: Width of the cropped image. + i: Upper pixel coordinate. + j: Left pixel coordinate. h: Height of the cropped image. + w: Width of the cropped image. Returns: PIL.Image: Cropped image. @@ -225,20 +225,20 @@ def crop(img, x, y, w, h): if not _is_pil_image(img): raise TypeError('img should be PIL Image. Got {}'.format(type(img))) - return img.crop((x, y, x + w, y + h)) + return img.crop((j, i, j + w, i + h)) -def scaled_crop(img, x, y, w, h, size, interpolation=Image.BILINEAR): +def scaled_crop(img, i, j, h, w, size, interpolation=Image.BILINEAR): """Crop the given PIL.Image and scale it to desired size. Notably used in RandomSizedCrop. Args: img (PIL.Image): Image to be cropped. - x: Left pixel coordinate. - y: Upper pixel coordinate. - w: Width of the cropped image. + i: Upper pixel coordinate. + j: Left pixel coordinate. h: Height of the cropped image. + w: Width of the cropped image. size (sequence or int): Desired output size. Same semantics as ``scale``. interpolation (int, optional): Desired interpolation. Default is ``PIL.Image.BILINEAR``. @@ -246,7 +246,7 @@ def scaled_crop(img, x, y, w, h, size, interpolation=Image.BILINEAR): PIL.Image: Cropped image. """ assert _is_pil_image(img), 'img should be PIL Image' - img = crop(img, x, y, w, h) + img = crop(img, i, j, h, w) img = scale(img, size, interpolation) return img @@ -406,13 +406,13 @@ def get_params(img, output_size): output_size (tuple): Expected output size of the crop. Returns: - tuple: params (x, y, w, h) to be passed to ``crop`` for center crop. + tuple: params (i, j, h, w) to be passed to ``crop`` for center crop. """ w, h = img.size th, tw = output_size - x1 = int(round((w - tw) / 2.)) - y1 = int(round((h - th) / 2.)) - return x1, y1, tw, th + i = int(round((h - th) / 2.)) + j = int(round((w - tw) / 2.)) + return i, j, th, tw def __call__(self, img): """ @@ -422,8 +422,8 @@ def __call__(self, img): Returns: PIL.Image: Cropped image. 
""" - x1, y1, tw, th = self.get_params(img, self.size) - return crop(img, x1, y1, tw, th) + i, j, h, w = self.get_params(img, self.size) + return crop(img, i, j, h, w) class Pad(object): @@ -504,16 +504,16 @@ def get_params(img, output_size): output_size (tuple): Expected output size of the crop. Returns: - tuple: params (x, y, w, h) to be passed to ``crop`` for random crop. + tuple: params (i, j, h, w) to be passed to ``crop`` for random crop. """ w, h = img.size th, tw = output_size if w == tw and h == th: return img - x1 = random.randint(0, w - tw) - y1 = random.randint(0, h - th) - return x1, y1, tw, th + i = random.randint(0, h - th) + j = random.randint(0, w - tw) + return i, j, th, tw def __call__(self, img): """ @@ -526,9 +526,9 @@ def __call__(self, img): if self.padding > 0: img = pad(img, self.padding) - x1, y1, tw, th = self.get_params(img, self.size) + i, j, h, w = self.get_params(img, self.size) - return crop(img, x1, y1, tw, th) + return crop(img, i, j, h, w) class RandomHorizontalFlip(object): @@ -572,7 +572,7 @@ def get_params(img): img (PIL.Image): Image to be cropped. Returns: - tuple: params (x, y, w, h) to be passed to ``crop`` for a random + tuple: params (i, j, h, w) to be passed to ``crop`` for a random sized crop. """ for attempt in range(10): @@ -587,15 +587,15 @@ def get_params(img): w, h = h, w if w <= img.size[0] and h <= img.size[1]: - x = random.randint(0, img.size[0] - w) - y = random.randint(0, img.size[1] - h) - return x, y, w, h + i = random.randint(0, img.size[1] - h) + j = random.randint(0, img.size[0] - w) + return i, j, h, w # Fallback w = min(img.size[0], img.shape[1]) - x = (img.shape[0] - w) // 2 - y = (img.shape[1] - w) // 2 - return x, y, w, w + i = (img.shape[1] - w) // 2 + j = (img.shape[0] - w) // 2 + return i, j, w, w def __call__(self, img): """ @@ -605,5 +605,5 @@ def __call__(self, img): Returns: PIL.Image: Randomly cropped and scaled image. """ - x, y, w, h = self.get_params(img) - return scaled_crop(img, x, y, w, h, self.size, self.interpolation) + i, j, h, w = self.get_params(img) + return scaled_crop(img, i, j, h, w, self.size, self.interpolation)