Skip to content

Commit fbe4ad5

Browse files
add RandomVerticalFlip transform (pytorch#262)
add RandomVerticalFlip transform
1 parent 459dc59 commit fbe4ad5

File tree

2 files changed

+49
-0
lines changed

2 files changed

+49
-0
lines changed

test/test_transforms.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,11 @@
99
except ImportError:
1010
accimage = None
1111

12+
try:
13+
from scipy import stats
14+
except ImportError:
15+
stats = None
16+
1217

1318
GRACE_HOPPER = 'assets/grace_hopper_517x606.jpg'
1419

@@ -327,6 +332,34 @@ def test_ndarray_gray_int32_to_pil_image(self):
327332
assert img.mode == 'I'
328333
assert np.allclose(img, img_data[:, :, 0])
329334

335+
@unittest.skipIf(stats is None, 'scipy.stats not available')
336+
def test_random_vertical_flip(self):
337+
img = transforms.ToPILImage()(torch.rand(3, 10, 10))
338+
vimg = img.transpose(Image.FLIP_TOP_BOTTOM)
339+
340+
num_vertical = 0
341+
for _ in range(100):
342+
out = transforms.RandomVerticalFlip()(img)
343+
if out == vimg:
344+
num_vertical += 1
345+
346+
p_value = stats.binom_test(num_vertical, 100, p=0.5)
347+
assert p_value > 0.05
348+
349+
@unittest.skipIf(stats is None, 'scipy.stats not available')
350+
def test_random_horizontal_flip(self):
351+
img = transforms.ToPILImage()(torch.rand(3, 10, 10))
352+
himg = img.transpose(Image.FLIP_LEFT_RIGHT)
353+
354+
num_horizontal = 0
355+
for _ in range(100):
356+
out = transforms.RandomHorizontalFlip()(img)
357+
if out == himg:
358+
num_horizontal += 1
359+
360+
p_value = stats.binom_test(num_horizontal, 100, p=0.5)
361+
assert p_value > 0.05
362+
330363

331364
if __name__ == '__main__':
332365
unittest.main()

torchvision/transforms.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -547,6 +547,22 @@ def __call__(self, img):
547547
return img
548548

549549

550+
class RandomVerticalFlip(object):
551+
"""Vertically flip the given PIL.Image randomly with a probability of 0.5"""
552+
553+
def __call__(self, img):
554+
"""
555+
Args:
556+
img (PIL.Image): Image to be flipped.
557+
558+
Returns:
559+
PIL.Image: Randomly flipped image.
560+
"""
561+
if random.random() < 0.5:
562+
return img.transpose(Image.FLIP_TOP_BOTTOM)
563+
return img
564+
565+
550566
class RandomSizedCrop(object):
551567
"""Crop the given PIL.Image to random size and aspect ratio.
552568

0 commit comments

Comments
 (0)