+import math
+import numbers
+import random
+import warnings
+from collections.abc import Sequence
+from typing import Tuple, List, Optional
+
+import torch
+from PIL import Image
+from torch import Tensor
+
+try:
+ import accimage
+except ImportError:
+ accimage = None
+
+from . import functional as F
+
+__all__ = ["Compose", "ToTensor", "PILToTensor", "ConvertImageDtype", "ToPILImage", "Normalize", "Resize", "Scale",
+ "CenterCrop", "Pad", "Lambda", "RandomApply", "RandomChoice", "RandomOrder", "RandomCrop",
+ "RandomHorizontalFlip", "RandomVerticalFlip", "RandomResizedCrop", "RandomSizedCrop", "FiveCrop", "TenCrop",
+ "LinearTransformation", "ColorJitter", "RandomRotation", "RandomAffine", "Grayscale", "RandomGrayscale",
+ "RandomPerspective", "RandomErasing", "GaussianBlur"]
+
+_pil_interpolation_to_str = {
+ Image.NEAREST: 'PIL.Image.NEAREST',
+ Image.BILINEAR: 'PIL.Image.BILINEAR',
+ Image.BICUBIC: 'PIL.Image.BICUBIC',
+ Image.LANCZOS: 'PIL.Image.LANCZOS',
+ Image.HAMMING: 'PIL.Image.HAMMING',
+ Image.BOX: 'PIL.Image.BOX',
+}
+
+
+class Compose:
+
"""Composes several transforms together. This transform does not support torchscript.
+
    Please see the note below.
+
+
Args:
+
transforms (list of ``Transform`` objects): list of transforms to compose.
+
+
Example:
+
>>> transforms.Compose([
+
>>> transforms.CenterCrop(10),
+
>>> transforms.ToTensor(),
+
>>> ])
+
+
.. note::
+
In order to script the transformations, please use ``torch.nn.Sequential`` as below.
+
+
>>> transforms = torch.nn.Sequential(
+
>>> transforms.CenterCrop(10),
+
>>> transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
+
>>> )
+
>>> scripted_transforms = torch.jit.script(transforms)
+
+
    Make sure to use only scriptable transformations, i.e. those that work with ``torch.Tensor`` and do not
    require `lambda` functions or ``PIL.Image``.
+
+
"""
+
+
def __init__(self, transforms):
+
self.transforms = transforms
+
+
def __call__(self, img):
+
for t in self.transforms:
+
img = t(img)
+
return img
+
+
def __repr__(self):
+
format_string = self.__class__.__name__ + '('
+
for t in self.transforms:
+
format_string += '\n'
+
format_string += ' {0}'.format(t)
+
format_string += '\n)'
+
return format_string
+
+
+class ToTensor:
+
"""Convert a ``PIL Image`` or ``numpy.ndarray`` to tensor. This transform does not support torchscript.
+
+
Converts a PIL Image or numpy.ndarray (H x W x C) in the range
+
[0, 255] to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0]
+
if the PIL Image belongs to one of the modes (L, LA, P, I, F, RGB, YCbCr, RGBA, CMYK, 1)
+
or if the numpy.ndarray has dtype = np.uint8
+
+
In the other cases, tensors are returned without scaling.
+
+
.. note::
+
Because the input image is scaled to [0.0, 1.0], this transformation should not be used when
+
transforming target image masks. See the `references`_ for implementing the transforms for image masks.
+
+
.. _references: https://github.com/pytorch/vision/tree/master/references/segmentation
+
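    Example (illustrative sketch; ``example.jpg`` is a hypothetical 640x480 RGB image):
        >>> img = Image.open('example.jpg')
        >>> t = ToTensor()(img)
        >>> t.shape, t.dtype
        (torch.Size([3, 480, 640]), torch.float32)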
"""
+
+
def __call__(self, pic):
+
"""
+
Args:
+
pic (PIL Image or numpy.ndarray): Image to be converted to tensor.
+
+
Returns:
+
Tensor: Converted image.
+
"""
+
return F.to_tensor(pic)
+
+
def __repr__(self):
+
return self.__class__.__name__ + '()'
+
+
+class PILToTensor:
+ """Convert a ``PIL Image`` to a tensor of the same type. This transform does not support torchscript.
+
+ Converts a PIL Image (H x W x C) to a Tensor of shape (C x H x W).
+ """
+
+ def __call__(self, pic):
+ """
+ Args:
+ pic (PIL Image): Image to be converted to tensor.
+
+ Returns:
+ Tensor: Converted image.
+ """
+ return F.pil_to_tensor(pic)
+
+ def __repr__(self):
+ return self.__class__.__name__ + '()'
+
+
+class ConvertImageDtype(torch.nn.Module):
+
"""Convert a tensor image to the given ``dtype`` and scale the values accordingly
+
+
Args:
+
dtype (torch.dtype): Desired data type of the output
+
+
.. note::
+
+
When converting from a smaller to a larger integer ``dtype`` the maximum values are **not** mapped exactly.
+
If converted back and forth, this mismatch has no effect.
+
+
Raises:
+
RuntimeError: When trying to cast :class:`torch.float32` to :class:`torch.int32` or :class:`torch.int64` as
+
well as for trying to cast :class:`torch.float64` to :class:`torch.int64`. These conversions might lead to
+
overflow errors since the floating point ``dtype`` cannot store consecutive integers over the whole range
+
of the integer ``dtype``.
+
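    Example (illustrative; converting ``uint8`` to ``float32`` rescales 0..255 to 0.0..1.0):
        >>> t = torch.tensor([[[0, 128, 255]]], dtype=torch.uint8)
        >>> ConvertImageDtype(torch.float32)(t)
        tensor([[[0.0000, 0.5020, 1.0000]]])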
"""
+
+
def __init__(self, dtype: torch.dtype) -> None:
+
super().__init__()
+
self.dtype = dtype
+
+
def forward(self, image: torch.Tensor) -> torch.Tensor:
+
return F.convert_image_dtype(image, self.dtype)
+
+
+class ToPILImage:
+
"""Convert a tensor or an ndarray to PIL Image. This transform does not support torchscript.
+
+
Converts a torch.*Tensor of shape C x H x W or a numpy ndarray of shape
+
H x W x C to a PIL Image while preserving the value range.
+
+
Args:
+
mode (`PIL.Image mode`_): color space and pixel depth of input data (optional).
+
If ``mode`` is ``None`` (default) there are some assumptions made about the input data:
+
- If the input has 4 channels, the ``mode`` is assumed to be ``RGBA``.
+
- If the input has 3 channels, the ``mode`` is assumed to be ``RGB``.
+
- If the input has 2 channels, the ``mode`` is assumed to be ``LA``.
+
            - If the input has 1 channel, the ``mode`` is determined by the data type (i.e. ``int``, ``float``,
+
``short``).
+
+
.. _PIL.Image mode: https://pillow.readthedocs.io/en/latest/handbook/concepts.html#concept-modes
+
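    Example (illustrative; assumes a float tensor with values in [0, 1]):
        >>> pil_img = ToPILImage()(torch.rand(3, 240, 320))
        >>> pil_img.size   # PIL reports (width, height)
        (320, 240)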
"""
+
def __init__(self, mode=None):
+
self.mode = mode
+
+
def __call__(self, pic):
+
"""
+
Args:
+
pic (Tensor or numpy.ndarray): Image to be converted to PIL Image.
+
+
Returns:
+
PIL Image: Image converted to PIL Image.
+
+
"""
+
return F.to_pil_image(pic, self.mode)
+
+
def __repr__(self):
+
format_string = self.__class__.__name__ + '('
+
if self.mode is not None:
+
format_string += 'mode={0}'.format(self.mode)
+
format_string += ')'
+
return format_string
+
+
+class Normalize(torch.nn.Module):
+
"""Normalize a tensor image with mean and standard deviation.
+
    Given mean: ``(mean[1],...,mean[n])`` and std: ``(std[1],...,std[n])`` for ``n``
+
channels, this transform will normalize each channel of the input
+
``torch.*Tensor`` i.e.,
+
``output[channel] = (input[channel] - mean[channel]) / std[channel]``
+
+
.. note::
+
This transform acts out of place, i.e., it does not mutate the input tensor.
+
+
Args:
+
mean (sequence): Sequence of means for each channel.
+
std (sequence): Sequence of standard deviations for each channel.
+
inplace(bool,optional): Bool to make this operation in-place.
+
+
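    Example (illustrative; a single-channel 1x3 image chosen so the arithmetic is easy to follow):
        >>> Normalize(mean=[0.5], std=[0.5])(torch.tensor([[[0.0, 0.5, 1.0]]]))
        tensor([[[-1.,  0.,  1.]]])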
"""
+
+
def __init__(self, mean, std, inplace=False):
+
super().__init__()
+
self.mean = mean
+
self.std = std
+
self.inplace = inplace
+
+
    def forward(self, tensor: Tensor) -> Tensor:
+
"""
+
Args:
+
tensor (Tensor): Tensor image of size (C, H, W) to be normalized.
+
+
Returns:
+
Tensor: Normalized Tensor image.
+
"""
+
return F.normalize(tensor, self.mean, self.std, self.inplace)
+
+
def __repr__(self):
+
return self.__class__.__name__ + '(mean={0}, std={1})'.format(self.mean, self.std)
+
+
+class Resize(torch.nn.Module):
+
"""Resize the input image to the given size.
+
The image can be a PIL Image or a torch Tensor, in which case it is expected
+
to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions
+
+
Args:
+
size (sequence or int): Desired output size. If size is a sequence like
+
(h, w), output size will be matched to this. If size is an int,
+
smaller edge of the image will be matched to this number.
+
            i.e., if height > width, then image will be rescaled to
+
(size * height / width, size).
+
            In torchscript mode size as a single int is not supported, use a tuple or
+
list of length 1: ``[size, ]``.
+
interpolation (int, optional): Desired interpolation enum defined by `filters`_.
+
Default is ``PIL.Image.BILINEAR``. If input is Tensor, only ``PIL.Image.NEAREST``, ``PIL.Image.BILINEAR``
+
and ``PIL.Image.BICUBIC`` are supported.
+
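    Example (illustrative; assumes a 640x480 (WxH) ``PIL.Image`` input):
        >>> resized = Resize(256)(img)
        >>> resized.size   # smaller edge (the height) is matched to 256
        (341, 256)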
"""
+
+
def __init__(self, size, interpolation=Image.BILINEAR):
+
super().__init__()
+
self.size = _setup_size(size, error_msg="If size is a sequence, it should have 2 values")
+
self.interpolation = interpolation
+
+
    def forward(self, img):
+
"""
+
Args:
+
img (PIL Image or Tensor): Image to be scaled.
+
+
Returns:
+
PIL Image or Tensor: Rescaled image.
+
"""
+
return F.resize(img, self.size, self.interpolation)
+
+
def __repr__(self):
+
interpolate_str = _pil_interpolation_to_str[self.interpolation]
+
return self.__class__.__name__ + '(size={0}, interpolation={1})'.format(self.size, interpolate_str)
+
+
+class Scale(Resize):
+
"""
+
Note: This transform is deprecated in favor of Resize.
+
"""
+
def __init__(self, *args, **kwargs):
+
warnings.warn("The use of the transforms.Scale transform is deprecated, " +
+
"please use transforms.Resize instead.")
+
super(Scale, self).__init__(*args, **kwargs)
+
+
+class CenterCrop(torch.nn.Module):
+
"""Crops the given image at the center.
+
The image can be a PIL Image or a torch Tensor, in which case it is expected
+
to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions
+
+
Args:
+
size (sequence or int): Desired output size of the crop. If size is an
+
int instead of sequence like (h, w), a square crop (size, size) is
+
made. If provided a tuple or list of length 1, it will be interpreted as (size[0], size[0]).
+
"""
+
+
def __init__(self, size):
+
super().__init__()
+
self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.")
+
+
    def forward(self, img):
+
"""
+
Args:
+
img (PIL Image or Tensor): Image to be cropped.
+
+
Returns:
+
PIL Image or Tensor: Cropped image.
+
"""
+
return F.center_crop(img, self.size)
+
+
def __repr__(self):
+
return self.__class__.__name__ + '(size={0})'.format(self.size)
+
+
+class Pad(torch.nn.Module):
+
"""Pad the given image on all sides with the given "pad" value.
+
The image can be a PIL Image or a torch Tensor, in which case it is expected
+
to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions
+
+
Args:
+
padding (int or tuple or list): Padding on each border. If a single int is provided this
+
is used to pad all borders. If tuple of length 2 is provided this is the padding
+
on left/right and top/bottom respectively. If a tuple of length 4 is provided
+
this is the padding for the left, top, right and bottom borders respectively.
+
In torchscript mode padding as single int is not supported, use a tuple or
+
list of length 1: ``[padding, ]``.
+
fill (int or tuple): Pixel fill value for constant fill. Default is 0. If a tuple of
+
length 3, it is used to fill R, G, B channels respectively.
+
This value is only used when the padding_mode is constant
+
padding_mode (str): Type of padding. Should be: constant, edge, reflect or symmetric.
+
Default is constant. Mode symmetric is not yet supported for Tensor inputs.
+
+
- constant: pads with a constant value, this value is specified with fill
+
+
- edge: pads with the last value at the edge of the image
+
+
- reflect: pads with reflection of image without repeating the last value on the edge
+
+
For example, padding [1, 2, 3, 4] with 2 elements on both sides in reflect mode
+
will result in [3, 2, 1, 2, 3, 4, 3, 2]
+
+
- symmetric: pads with reflection of image repeating the last value on the edge
+
+
For example, padding [1, 2, 3, 4] with 2 elements on both sides in symmetric mode
+
will result in [2, 1, 1, 2, 3, 4, 4, 3]
+
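    Example (illustrative sketch):
        >>> padded = Pad(padding=(1, 2), fill=0)(img)
        >>> # a 640x480 (WxH) input becomes 642x484: 1 px added left/right, 2 px top/bottom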
"""
+
+
def __init__(self, padding, fill=0, padding_mode="constant"):
+
super().__init__()
+
if not isinstance(padding, (numbers.Number, tuple, list)):
+
raise TypeError("Got inappropriate padding arg")
+
+
if not isinstance(fill, (numbers.Number, str, tuple)):
+
raise TypeError("Got inappropriate fill arg")
+
+
if padding_mode not in ["constant", "edge", "reflect", "symmetric"]:
+
raise ValueError("Padding mode should be either constant, edge, reflect or symmetric")
+
+
if isinstance(padding, Sequence) and len(padding) not in [1, 2, 4]:
+
raise ValueError("Padding must be an int or a 1, 2, or 4 element tuple, not a " +
+
"{} element tuple".format(len(padding)))
+
+
self.padding = padding
+
self.fill = fill
+
self.padding_mode = padding_mode
+
+
    def forward(self, img):
+
"""
+
Args:
+
img (PIL Image or Tensor): Image to be padded.
+
+
Returns:
+
PIL Image or Tensor: Padded image.
+
"""
+
return F.pad(img, self.padding, self.fill, self.padding_mode)
+
+
def __repr__(self):
+
return self.__class__.__name__ + '(padding={0}, fill={1}, padding_mode={2})'.\
+
format(self.padding, self.fill, self.padding_mode)
+
+
+class Lambda:
+
"""Apply a user-defined lambda as a transform. This transform does not support torchscript.
+
+
Args:
+
lambd (function): Lambda/function to be used for transform.
+
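    Example (an illustrative sketch; remember that lambdas cannot be scripted with ``torch.jit``):
        >>> add_noise = Lambda(lambda x: x + 0.01 * torch.randn_like(x))
        >>> out = add_noise(torch.zeros(3, 4, 4))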
"""
+
+
def __init__(self, lambd):
+
if not callable(lambd):
+
raise TypeError("Argument lambd should be callable, got {}".format(repr(type(lambd).__name__)))
+
self.lambd = lambd
+
+
def __call__(self, img):
+
return self.lambd(img)
+
+
def __repr__(self):
+
return self.__class__.__name__ + '()'
+
+
+class RandomTransforms:
+ """Base class for a list of transformations with randomness
+
+ Args:
+ transforms (list or tuple): list of transformations
+ """
+
+ def __init__(self, transforms):
+ if not isinstance(transforms, Sequence):
+ raise TypeError("Argument transforms should be a sequence")
+ self.transforms = transforms
+
+ def __call__(self, *args, **kwargs):
+ raise NotImplementedError()
+
+ def __repr__(self):
+ format_string = self.__class__.__name__ + '('
+ for t in self.transforms:
+ format_string += '\n'
+ format_string += ' {0}'.format(t)
+ format_string += '\n)'
+ return format_string
+
+
+class RandomApply(torch.nn.Module):
+
"""Apply randomly a list of transformations with a given probability.
+
+
.. note::
+
In order to script the transformation, please use ``torch.nn.ModuleList`` as input instead of list/tuple of
+
transforms as shown below:
+
+
>>> transforms = transforms.RandomApply(torch.nn.ModuleList([
+
>>> transforms.ColorJitter(),
+
>>> ]), p=0.3)
+
>>> scripted_transforms = torch.jit.script(transforms)
+
+
    Make sure to use only scriptable transformations, i.e. those that work with ``torch.Tensor`` and do not
    require `lambda` functions or ``PIL.Image``.
+
+
Args:
+
transforms (list or tuple or torch.nn.Module): list of transformations
+
p (float): probability
+
"""
+
+
def __init__(self, transforms, p=0.5):
+
super().__init__()
+
self.transforms = transforms
+
self.p = p
+
+
def forward(self, img):
+
if self.p < torch.rand(1):
+
return img
+
for t in self.transforms:
+
img = t(img)
+
return img
+
+
def __repr__(self):
+
format_string = self.__class__.__name__ + '('
+
format_string += '\n p={}'.format(self.p)
+
for t in self.transforms:
+
format_string += '\n'
+
format_string += ' {0}'.format(t)
+
format_string += '\n)'
+
return format_string
+
+
+class RandomOrder(RandomTransforms):
+
"""Apply a list of transformations in a random order. This transform does not support torchscript.
+
"""
+
def __call__(self, img):
+
order = list(range(len(self.transforms)))
+
random.shuffle(order)
+
for i in order:
+
img = self.transforms[i](img)
+
return img
+
+
+class RandomChoice(RandomTransforms):
+
"""Apply single transformation randomly picked from a list. This transform does not support torchscript.
+
"""
+
def __call__(self, img):
+
t = random.choice(self.transforms)
+
return t(img)
+
+
+class RandomCrop(torch.nn.Module):
+
"""Crop the given image at a random location.
+
The image can be a PIL Image or a Tensor, in which case it is expected
+
to have [..., H, W] shape, where ... means an arbitrary number of leading
+
dimensions
+
+
Args:
+
size (sequence or int): Desired output size of the crop. If size is an
+
int instead of sequence like (h, w), a square crop (size, size) is
+
made. If provided a tuple or list of length 1, it will be interpreted as (size[0], size[0]).
+
padding (int or sequence, optional): Optional padding on each border
+
of the image. Default is None. If a single int is provided this
+
is used to pad all borders. If tuple of length 2 is provided this is the padding
+
on left/right and top/bottom respectively. If a tuple of length 4 is provided
+
this is the padding for the left, top, right and bottom borders respectively.
+
In torchscript mode padding as single int is not supported, use a tuple or
+
list of length 1: ``[padding, ]``.
+
pad_if_needed (boolean): It will pad the image if smaller than the
+
desired size to avoid raising an exception. Since cropping is done
+
after padding, the padding seems to be done at a random offset.
+
fill (int or tuple): Pixel fill value for constant fill. Default is 0. If a tuple of
+
length 3, it is used to fill R, G, B channels respectively.
+
This value is only used when the padding_mode is constant
+
padding_mode (str): Type of padding. Should be: constant, edge, reflect or symmetric. Default is constant.
+
Mode symmetric is not yet supported for Tensor inputs.
+
+
- constant: pads with a constant value, this value is specified with fill
+
+
- edge: pads with the last value on the edge of the image
+
+
- reflect: pads with reflection of image (without repeating the last value on the edge)
+
+
padding [1, 2, 3, 4] with 2 elements on both sides in reflect mode
+
will result in [3, 2, 1, 2, 3, 4, 3, 2]
+
+
- symmetric: pads with reflection of image (repeating the last value on the edge)
+
+
padding [1, 2, 3, 4] with 2 elements on both sides in symmetric mode
+
will result in [2, 1, 1, 2, 3, 4, 4, 3]
+
+
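    Example (illustrative; ``pad_if_needed`` guards against inputs smaller than the crop size):
        >>> cropper = RandomCrop(size=(224, 224), pad_if_needed=True)
        >>> patch = cropper(img)   # 224x224 crop taken at a random location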
"""
+
+
    @staticmethod
+
def get_params(img: Tensor, output_size: Tuple[int, int]) -> Tuple[int, int, int, int]:
+
"""Get parameters for ``crop`` for a random crop.
+
+
Args:
+
img (PIL Image or Tensor): Image to be cropped.
+
output_size (tuple): Expected output size of the crop.
+
+
Returns:
+
tuple: params (i, j, h, w) to be passed to ``crop`` for random crop.
+
"""
+
w, h = F._get_image_size(img)
+
th, tw = output_size
+
if w == tw and h == th:
+
return 0, 0, h, w
+
+
i = torch.randint(0, h - th + 1, size=(1, )).item()
+
j = torch.randint(0, w - tw + 1, size=(1, )).item()
+
return i, j, th, tw
+
+
def __init__(self, size, padding=None, pad_if_needed=False, fill=0, padding_mode="constant"):
+
super().__init__()
+
+
self.size = tuple(_setup_size(
+
size, error_msg="Please provide only two dimensions (h, w) for size."
+
))
+
+
self.padding = padding
+
self.pad_if_needed = pad_if_needed
+
self.fill = fill
+
self.padding_mode = padding_mode
+
+
    def forward(self, img):
+
"""
+
Args:
+
img (PIL Image or Tensor): Image to be cropped.
+
+
Returns:
+
PIL Image or Tensor: Cropped image.
+
"""
+
if self.padding is not None:
+
img = F.pad(img, self.padding, self.fill, self.padding_mode)
+
+
width, height = F._get_image_size(img)
+
# pad the width if needed
+
if self.pad_if_needed and width < self.size[1]:
+
padding = [self.size[1] - width, 0]
+
img = F.pad(img, padding, self.fill, self.padding_mode)
+
# pad the height if needed
+
if self.pad_if_needed and height < self.size[0]:
+
padding = [0, self.size[0] - height]
+
img = F.pad(img, padding, self.fill, self.padding_mode)
+
+
i, j, h, w = self.get_params(img, self.size)
+
+
return F.crop(img, i, j, h, w)
+
+
def __repr__(self):
+
return self.__class__.__name__ + "(size={0}, padding={1})".format(self.size, self.padding)
+
+
+class RandomHorizontalFlip(torch.nn.Module):
+
"""Horizontally flip the given image randomly with a given probability.
+
The image can be a PIL Image or a torch Tensor, in which case it is expected
+
to have [..., H, W] shape, where ... means an arbitrary number of leading
+
dimensions
+
+
Args:
+
p (float): probability of the image being flipped. Default value is 0.5
+
"""
+
+
def __init__(self, p=0.5):
+
super().__init__()
+
self.p = p
+
+
    def forward(self, img):
+
"""
+
Args:
+
img (PIL Image or Tensor): Image to be flipped.
+
+
Returns:
+
PIL Image or Tensor: Randomly flipped image.
+
"""
+
if torch.rand(1) < self.p:
+
return F.hflip(img)
+
return img
+
+
def __repr__(self):
+
return self.__class__.__name__ + '(p={})'.format(self.p)
+
+
+class RandomVerticalFlip(torch.nn.Module):
+
"""Vertically flip the given image randomly with a given probability.
+
The image can be a PIL Image or a torch Tensor, in which case it is expected
+
to have [..., H, W] shape, where ... means an arbitrary number of leading
+
dimensions
+
+
Args:
+
p (float): probability of the image being flipped. Default value is 0.5
+
"""
+
+
def __init__(self, p=0.5):
+
super().__init__()
+
self.p = p
+
+
    def forward(self, img):
+
"""
+
Args:
+
img (PIL Image or Tensor): Image to be flipped.
+
+
Returns:
+
PIL Image or Tensor: Randomly flipped image.
+
"""
+
if torch.rand(1) < self.p:
+
return F.vflip(img)
+
return img
+
+
def __repr__(self):
+
return self.__class__.__name__ + '(p={})'.format(self.p)
+
+
+class RandomPerspective(torch.nn.Module):
+
"""Performs a random perspective transformation of the given image with a given probability.
+
The image can be a PIL Image or a Tensor, in which case it is expected
+
to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions.
+
+
Args:
+
distortion_scale (float): argument to control the degree of distortion and ranges from 0 to 1.
+
Default is 0.5.
+
p (float): probability of the image being transformed. Default is 0.5.
+
interpolation (int): Interpolation type. If input is Tensor, only ``PIL.Image.NEAREST`` and
+
``PIL.Image.BILINEAR`` are supported. Default, ``PIL.Image.BILINEAR`` for PIL images and Tensors.
+
        fill (n-tuple or int or float): Pixel fill value for the area outside the transformed
            image. If int or float, the value is used for all bands. Default is 0.
+
This option is only available for ``pillow>=5.0.0``. This option is not supported for Tensor
+
input. Fill value for the area outside the transform in the output image is always 0.
+
+
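    Example (illustrative; ``p=1.0`` is used here so the warp is always applied):
        >>> warp = RandomPerspective(distortion_scale=0.6, p=1.0)
        >>> out = warp(img)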
"""
+
+
def __init__(self, distortion_scale=0.5, p=0.5, interpolation=Image.BILINEAR, fill=0):
+
super().__init__()
+
self.p = p
+
self.interpolation = interpolation
+
self.distortion_scale = distortion_scale
+
self.fill = fill
+
+
    def forward(self, img):
+
"""
+
Args:
+
img (PIL Image or Tensor): Image to be Perspectively transformed.
+
+
Returns:
+
PIL Image or Tensor: Randomly transformed image.
+
"""
+
if torch.rand(1) < self.p:
+
width, height = F._get_image_size(img)
+
startpoints, endpoints = self.get_params(width, height, self.distortion_scale)
+
return F.perspective(img, startpoints, endpoints, self.interpolation, self.fill)
+
return img
+
+
    @staticmethod
+
def get_params(width: int, height: int, distortion_scale: float) -> Tuple[List[List[int]], List[List[int]]]:
+
"""Get parameters for ``perspective`` for a random perspective transform.
+
+
Args:
+
width (int): width of the image.
+
height (int): height of the image.
+
distortion_scale (float): argument to control the degree of distortion and ranges from 0 to 1.
+
+
Returns:
+
List containing [top-left, top-right, bottom-right, bottom-left] of the original image,
+
List containing [top-left, top-right, bottom-right, bottom-left] of the transformed image.
+
"""
+
half_height = height // 2
+
half_width = width // 2
+
topleft = [
+
int(torch.randint(0, int(distortion_scale * half_width) + 1, size=(1, )).item()),
+
int(torch.randint(0, int(distortion_scale * half_height) + 1, size=(1, )).item())
+
]
+
topright = [
+
int(torch.randint(width - int(distortion_scale * half_width) - 1, width, size=(1, )).item()),
+
int(torch.randint(0, int(distortion_scale * half_height) + 1, size=(1, )).item())
+
]
+
botright = [
+
int(torch.randint(width - int(distortion_scale * half_width) - 1, width, size=(1, )).item()),
+
int(torch.randint(height - int(distortion_scale * half_height) - 1, height, size=(1, )).item())
+
]
+
botleft = [
+
int(torch.randint(0, int(distortion_scale * half_width) + 1, size=(1, )).item()),
+
int(torch.randint(height - int(distortion_scale * half_height) - 1, height, size=(1, )).item())
+
]
+
startpoints = [[0, 0], [width - 1, 0], [width - 1, height - 1], [0, height - 1]]
+
endpoints = [topleft, topright, botright, botleft]
+
return startpoints, endpoints
+
+
def __repr__(self):
+
return self.__class__.__name__ + '(p={})'.format(self.p)
+
+
+class RandomResizedCrop(torch.nn.Module):
+
"""Crop the given image to random size and aspect ratio.
+
The image can be a PIL Image or a Tensor, in which case it is expected
+
to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions
+
+
    A crop of a random size (default: 0.08 to 1.0 of the original area) and a random
    aspect ratio (default: 3/4 to 4/3 of the original aspect ratio) is made. This crop
+
is finally resized to given size.
+
This is popularly used to train the Inception networks.
+
+
Args:
+
size (int or sequence): expected output size of each edge. If size is an
+
int instead of sequence like (h, w), a square output size ``(size, size)`` is
+
made. If provided a tuple or list of length 1, it will be interpreted as (size[0], size[0]).
+
        scale (tuple of float): range of the cropped area, relative to the area of the original image.
        ratio (tuple of float): range of the aspect ratio of the crop, relative to the original aspect ratio.
+
interpolation (int): Desired interpolation enum defined by `filters`_.
+
Default is ``PIL.Image.BILINEAR``. If input is Tensor, only ``PIL.Image.NEAREST``, ``PIL.Image.BILINEAR``
+
and ``PIL.Image.BICUBIC`` are supported.
+
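    Example (illustrative; the defaults correspond to the Inception-style training crop):
        >>> rrc = RandomResizedCrop(224)
        >>> out = rrc(img)   # random area in [8%, 100%], aspect ratio in [3/4, 4/3], resized to 224x224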
"""
+
+
def __init__(self, size, scale=(0.08, 1.0), ratio=(3. / 4., 4. / 3.), interpolation=Image.BILINEAR):
+
super().__init__()
+
self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.")
+
+
if not isinstance(scale, Sequence):
+
raise TypeError("Scale should be a sequence")
+
if not isinstance(ratio, Sequence):
+
raise TypeError("Ratio should be a sequence")
+
if (scale[0] > scale[1]) or (ratio[0] > ratio[1]):
+
warnings.warn("Scale and ratio should be of kind (min, max)")
+
+
self.interpolation = interpolation
+
self.scale = scale
+
self.ratio = ratio
+
+
    @staticmethod
+
def get_params(
+
img: Tensor, scale: List[float], ratio: List[float]
+
) -> Tuple[int, int, int, int]:
+
"""Get parameters for ``crop`` for a random sized crop.
+
+
Args:
+
img (PIL Image or Tensor): Input image.
+
            scale (list): range of the cropped area, relative to the area of the original image
            ratio (list): range of the aspect ratio of the crop, relative to the original aspect ratio
+
+
Returns:
+
tuple: params (i, j, h, w) to be passed to ``crop`` for a random
+
sized crop.
+
"""
+
width, height = F._get_image_size(img)
+
area = height * width
+
+
for _ in range(10):
+
target_area = area * torch.empty(1).uniform_(scale[0], scale[1]).item()
+
log_ratio = torch.log(torch.tensor(ratio))
+
aspect_ratio = torch.exp(
+
torch.empty(1).uniform_(log_ratio[0], log_ratio[1])
+
).item()
+
+
w = int(round(math.sqrt(target_area * aspect_ratio)))
+
h = int(round(math.sqrt(target_area / aspect_ratio)))
+
+
if 0 < w <= width and 0 < h <= height:
+
i = torch.randint(0, height - h + 1, size=(1,)).item()
+
j = torch.randint(0, width - w + 1, size=(1,)).item()
+
return i, j, h, w
+
+
# Fallback to central crop
+
in_ratio = float(width) / float(height)
+
if in_ratio < min(ratio):
+
w = width
+
h = int(round(w / min(ratio)))
+
elif in_ratio > max(ratio):
+
h = height
+
w = int(round(h * max(ratio)))
+
else: # whole image
+
w = width
+
h = height
+
i = (height - h) // 2
+
j = (width - w) // 2
+
return i, j, h, w
+
+
    def forward(self, img):
+
"""
+
Args:
+
img (PIL Image or Tensor): Image to be cropped and resized.
+
+
Returns:
+
PIL Image or Tensor: Randomly cropped and resized image.
+
"""
+
i, j, h, w = self.get_params(img, self.scale, self.ratio)
+
return F.resized_crop(img, i, j, h, w, self.size, self.interpolation)
+
+
def __repr__(self):
+
interpolate_str = _pil_interpolation_to_str[self.interpolation]
+
format_string = self.__class__.__name__ + '(size={0}'.format(self.size)
+
format_string += ', scale={0}'.format(tuple(round(s, 4) for s in self.scale))
+
format_string += ', ratio={0}'.format(tuple(round(r, 4) for r in self.ratio))
+
format_string += ', interpolation={0})'.format(interpolate_str)
+
return format_string
+
+
+class RandomSizedCrop(RandomResizedCrop):
+
"""
+
Note: This transform is deprecated in favor of RandomResizedCrop.
+
"""
+
def __init__(self, *args, **kwargs):
+
warnings.warn("The use of the transforms.RandomSizedCrop transform is deprecated, " +
+
"please use transforms.RandomResizedCrop instead.")
+
super(RandomSizedCrop, self).__init__(*args, **kwargs)
+
+
+class FiveCrop(torch.nn.Module):
+
"""Crop the given image into four corners and the central crop.
+
The image can be a PIL Image or a Tensor, in which case it is expected
+
to have [..., H, W] shape, where ... means an arbitrary number of leading
+
dimensions
+
+
.. Note::
+
This transform returns a tuple of images and there may be a mismatch in the number of
+
inputs and targets your Dataset returns. See below for an example of how to deal with
+
this.
+
+
Args:
+
size (sequence or int): Desired output size of the crop. If size is an ``int``
+
instead of sequence like (h, w), a square crop of size (size, size) is made.
+
If provided a tuple or list of length 1, it will be interpreted as (size[0], size[0]).
+
+
Example:
+
>>> transform = Compose([
+
>>> FiveCrop(size), # this is a list of PIL Images
+
>>> Lambda(lambda crops: torch.stack([ToTensor()(crop) for crop in crops])) # returns a 4D tensor
+
>>> ])
+
>>> #In your test loop you can do the following:
+
>>> input, target = batch # input is a 5d tensor, target is 2d
+
>>> bs, ncrops, c, h, w = input.size()
+
>>> result = model(input.view(-1, c, h, w)) # fuse batch size and ncrops
+
>>> result_avg = result.view(bs, ncrops, -1).mean(1) # avg over crops
+
"""
+
+
def __init__(self, size):
+
super().__init__()
+
self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.")
+
+
    def forward(self, img):
+
"""
+
Args:
+
img (PIL Image or Tensor): Image to be cropped.
+
+
Returns:
+
tuple of 5 images. Image can be PIL Image or Tensor
+
"""
+
return F.five_crop(img, self.size)
+
+
def __repr__(self):
+
return self.__class__.__name__ + '(size={0})'.format(self.size)
+
+
+class TenCrop(torch.nn.Module):
+
"""Crop the given image into four corners and the central crop plus the flipped version of
+
these (horizontal flipping is used by default).
+
The image can be a PIL Image or a Tensor, in which case it is expected
+
to have [..., H, W] shape, where ... means an arbitrary number of leading
+
dimensions
+
+
.. Note::
+
This transform returns a tuple of images and there may be a mismatch in the number of
+
inputs and targets your Dataset returns. See below for an example of how to deal with
+
this.
+
+
Args:
+
size (sequence or int): Desired output size of the crop. If size is an
+
int instead of sequence like (h, w), a square crop (size, size) is
+
made. If provided a tuple or list of length 1, it will be interpreted as (size[0], size[0]).
+
vertical_flip (bool): Use vertical flipping instead of horizontal
+
+
Example:
+
>>> transform = Compose([
+
>>> TenCrop(size), # this is a list of PIL Images
+
>>> Lambda(lambda crops: torch.stack([ToTensor()(crop) for crop in crops])) # returns a 4D tensor
+
>>> ])
+
>>> #In your test loop you can do the following:
+
>>> input, target = batch # input is a 5d tensor, target is 2d
+
>>> bs, ncrops, c, h, w = input.size()
+
>>> result = model(input.view(-1, c, h, w)) # fuse batch size and ncrops
+
>>> result_avg = result.view(bs, ncrops, -1).mean(1) # avg over crops
+
"""
+
+
def __init__(self, size, vertical_flip=False):
+
super().__init__()
+
self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.")
+
self.vertical_flip = vertical_flip
+
+
    def forward(self, img):
+
"""
+
Args:
+
img (PIL Image or Tensor): Image to be cropped.
+
+
Returns:
+
tuple of 10 images. Image can be PIL Image or Tensor
+
"""
+
return F.ten_crop(img, self.size, self.vertical_flip)
+
+
def __repr__(self):
+
return self.__class__.__name__ + '(size={0}, vertical_flip={1})'.format(self.size, self.vertical_flip)
+
+
+class ColorJitter(torch.nn.Module):
+
"""Randomly change the brightness, contrast and saturation of an image.
+
+
Args:
+
brightness (float or tuple of float (min, max)): How much to jitter brightness.
+
brightness_factor is chosen uniformly from [max(0, 1 - brightness), 1 + brightness]
+
or the given [min, max]. Should be non negative numbers.
+
contrast (float or tuple of float (min, max)): How much to jitter contrast.
+
contrast_factor is chosen uniformly from [max(0, 1 - contrast), 1 + contrast]
+
or the given [min, max]. Should be non negative numbers.
+
saturation (float or tuple of float (min, max)): How much to jitter saturation.
+
saturation_factor is chosen uniformly from [max(0, 1 - saturation), 1 + saturation]
+
or the given [min, max]. Should be non negative numbers.
+
hue (float or tuple of float (min, max)): How much to jitter hue.
+
hue_factor is chosen uniformly from [-hue, hue] or the given [min, max].
+
            Should have 0 <= hue <= 0.5 or -0.5 <= min <= max <= 0.5.
+
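    Example (illustrative; the factor ranges below are one arbitrary choice):
        >>> jitter = ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1)
        >>> out = jitter(img)   # brightness/contrast/saturation factors drawn from [0.6, 1.4], hue from [-0.1, 0.1]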
"""
+
+
def __init__(self, brightness=0, contrast=0, saturation=0, hue=0):
+
super().__init__()
+
self.brightness = self._check_input(brightness, 'brightness')
+
self.contrast = self._check_input(contrast, 'contrast')
+
self.saturation = self._check_input(saturation, 'saturation')
+
self.hue = self._check_input(hue, 'hue', center=0, bound=(-0.5, 0.5),
+
clip_first_on_zero=False)
+
+
@torch.jit.unused
+
def _check_input(self, value, name, center=1, bound=(0, float('inf')), clip_first_on_zero=True):
+
if isinstance(value, numbers.Number):
+
if value < 0:
+
raise ValueError("If {} is a single number, it must be non negative.".format(name))
+
value = [center - float(value), center + float(value)]
+
if clip_first_on_zero:
+
value[0] = max(value[0], 0.0)
+
elif isinstance(value, (tuple, list)) and len(value) == 2:
+
if not bound[0] <= value[0] <= value[1] <= bound[1]:
+
raise ValueError("{} values should be between {}".format(name, bound))
+
else:
+
            raise TypeError("{} should be a single number or a list/tuple with length 2.".format(name))
+
+
# if value is 0 or (1., 1.) for brightness/contrast/saturation
+
# or (0., 0.) for hue, do nothing
+
if value[0] == value[1] == center:
+
value = None
+
return value
+
+
    @staticmethod
+
@torch.jit.unused
+
def get_params(brightness, contrast, saturation, hue):
+
"""Get a randomized transform to be applied on image.
+
+
Arguments are same as that of __init__.
+
+
Returns:
+
Transform which randomly adjusts brightness, contrast and
+
saturation in a random order.
+
"""
+
transforms = []
+
+
if brightness is not None:
+
brightness_factor = random.uniform(brightness[0], brightness[1])
+
transforms.append(Lambda(lambda img: F.adjust_brightness(img, brightness_factor)))
+
+
if contrast is not None:
+
contrast_factor = random.uniform(contrast[0], contrast[1])
+
transforms.append(Lambda(lambda img: F.adjust_contrast(img, contrast_factor)))
+
+
if saturation is not None:
+
saturation_factor = random.uniform(saturation[0], saturation[1])
+
transforms.append(Lambda(lambda img: F.adjust_saturation(img, saturation_factor)))
+
+
if hue is not None:
+
hue_factor = random.uniform(hue[0], hue[1])
+
transforms.append(Lambda(lambda img: F.adjust_hue(img, hue_factor)))
+
+
random.shuffle(transforms)
+
transform = Compose(transforms)
+
+
return transform
+
+
    def forward(self, img):
+
"""
+
Args:
+
img (PIL Image or Tensor): Input image.
+
+
Returns:
+
PIL Image or Tensor: Color jittered image.
+
"""
+
fn_idx = torch.randperm(4)
+
for fn_id in fn_idx:
+
if fn_id == 0 and self.brightness is not None:
+
brightness = self.brightness
+
brightness_factor = torch.tensor(1.0).uniform_(brightness[0], brightness[1]).item()
+
img = F.adjust_brightness(img, brightness_factor)
+
+
if fn_id == 1 and self.contrast is not None:
+
contrast = self.contrast
+
contrast_factor = torch.tensor(1.0).uniform_(contrast[0], contrast[1]).item()
+
img = F.adjust_contrast(img, contrast_factor)
+
+
if fn_id == 2 and self.saturation is not None:
+
saturation = self.saturation
+
saturation_factor = torch.tensor(1.0).uniform_(saturation[0], saturation[1]).item()
+
img = F.adjust_saturation(img, saturation_factor)
+
+
if fn_id == 3 and self.hue is not None:
+
hue = self.hue
+
hue_factor = torch.tensor(1.0).uniform_(hue[0], hue[1]).item()
+
img = F.adjust_hue(img, hue_factor)
+
+
return img
+
+
def __repr__(self):
+
format_string = self.__class__.__name__ + '('
+
format_string += 'brightness={0}'.format(self.brightness)
+
format_string += ', contrast={0}'.format(self.contrast)
+
format_string += ', saturation={0}'.format(self.saturation)
+
format_string += ', hue={0})'.format(self.hue)
+
return format_string
+
+
+class RandomRotation(torch.nn.Module):
+
"""Rotate the image by angle.
+
The image can be a PIL Image or a Tensor, in which case it is expected
+
to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions.
+
+
Args:
+
degrees (sequence or float or int): Range of degrees to select from.
+
If degrees is a number instead of sequence like (min, max), the range of degrees
+
will be (-degrees, +degrees).
+
resample (int, optional): An optional resampling filter. See `filters`_ for more information.
+
If omitted, or if the image has mode "1" or "P", it is set to PIL.Image.NEAREST.
+
If input is Tensor, only ``PIL.Image.NEAREST`` and ``PIL.Image.BILINEAR`` are supported.
+
expand (bool, optional): Optional expansion flag.
+
If true, expands the output to make it large enough to hold the entire rotated image.
+
If false or omitted, make the output image the same size as the input image.
+
Note that the expand flag assumes rotation around the center and no translation.
+
center (list or tuple, optional): Optional center of rotation, (x, y). Origin is the upper left corner.
+
Default is the center of the image.
+
fill (n-tuple or int or float): Pixel fill value for area outside the rotated
+
            image. If int or float, the value is used for all bands.
+
Defaults to 0 for all bands. This option is only available for Pillow>=5.2.0.
+
This option is not supported for Tensor input. Fill value for the area outside the transform in the output
+
image is always 0.
+
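    Example (illustrative):
        >>> out = RandomRotation(degrees=30)(img)         # angle drawn uniformly from [-30, 30]
        >>> out = RandomRotation(degrees=(90, 90))(img)   # always rotates by exactly 90 degrees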
+
.. _filters: https://pillow.readthedocs.io/en/latest/handbook/concepts.html#filters
+
+
"""
+
+
def __init__(self, degrees, resample=False, expand=False, center=None, fill=None):
+
super().__init__()
+
self.degrees = _setup_angle(degrees, name="degrees", req_sizes=(2, ))
+
+
if center is not None:
+
_check_sequence_input(center, "center", req_sizes=(2, ))
+
+
self.center = center
+
+
self.resample = resample
+
self.expand = expand
+
self.fill = fill
+
+
    @staticmethod
+
def get_params(degrees: List[float]) -> float:
+
"""Get parameters for ``rotate`` for a random rotation.
+
+
Returns:
+
float: angle parameter to be passed to ``rotate`` for random rotation.
+
"""
+
angle = float(torch.empty(1).uniform_(float(degrees[0]), float(degrees[1])).item())
+
return angle
+
+
    def forward(self, img):
+
"""
+
Args:
+
img (PIL Image or Tensor): Image to be rotated.
+
+
Returns:
+
PIL Image or Tensor: Rotated image.
+
"""
+
angle = self.get_params(self.degrees)
+
return F.rotate(img, angle, self.resample, self.expand, self.center, self.fill)
+
+
def __repr__(self):
+
format_string = self.__class__.__name__ + '(degrees={0}'.format(self.degrees)
+
format_string += ', resample={0}'.format(self.resample)
+
format_string += ', expand={0}'.format(self.expand)
+
if self.center is not None:
+
format_string += ', center={0}'.format(self.center)
+
if self.fill is not None:
+
format_string += ', fill={0}'.format(self.fill)
+
format_string += ')'
+
return format_string
+
+
+class RandomAffine(torch.nn.Module):
+
"""Random affine transformation of the image keeping center invariant.
+
The image can be a PIL Image or a Tensor, in which case it is expected
+
to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions.
+
+
Args:
+
degrees (sequence or float or int): Range of degrees to select from.
+
If degrees is a number instead of sequence like (min, max), the range of degrees
+
will be (-degrees, +degrees). Set to 0 to deactivate rotations.
+
translate (tuple, optional): tuple of maximum absolute fraction for horizontal
+
and vertical translations. For example translate=(a, b), then horizontal shift
+
is randomly sampled in the range -img_width * a < dx < img_width * a and vertical shift is
+
randomly sampled in the range -img_height * b < dy < img_height * b. Will not translate by default.
+
scale (tuple, optional): scaling factor interval, e.g (a, b), then scale is
+
randomly sampled from the range a <= scale <= b. Will keep original scale by default.
+
shear (sequence or float or int, optional): Range of degrees to select from.
+
If shear is a number, a shear parallel to the x axis in the range (-shear, +shear)
+
will be applied. Else if shear is a tuple or list of 2 values a shear parallel to the x axis in the
+
range (shear[0], shear[1]) will be applied. Else if shear is a tuple or list of 4 values,
+
a x-axis shear in (shear[0], shear[1]) and y-axis shear in (shear[2], shear[3]) will be applied.
+
Will not apply shear by default.
+
resample (int, optional): An optional resampling filter. See `filters`_ for more information.
+
If omitted, or if the image has mode "1" or "P", it is set to ``PIL.Image.NEAREST``.
+
If input is Tensor, only ``PIL.Image.NEAREST`` and ``PIL.Image.BILINEAR`` are supported.
+
fillcolor (tuple or int): Optional fill color (Tuple for RGB Image and int for grayscale) for the area
+
outside the transform in the output image (Pillow>=5.0.0). This option is not supported for Tensor
+
input. Fill value for the area outside the transform in the output image is always 0.
+
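    Example (illustrative parameter choice):
        >>> affine = RandomAffine(degrees=10, translate=(0.1, 0.1), scale=(0.9, 1.1))
        >>> out = affine(img)   # rotate within [-10, 10] deg, shift up to 10% of each side, rescale by 0.9-1.1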
+
.. _filters: https://pillow.readthedocs.io/en/latest/handbook/concepts.html#filters
+
+
"""
+
+
def __init__(self, degrees, translate=None, scale=None, shear=None, resample=0, fillcolor=0):
+
super().__init__()
+
self.degrees = _setup_angle(degrees, name="degrees", req_sizes=(2, ))
+
+
if translate is not None:
+
_check_sequence_input(translate, "translate", req_sizes=(2, ))
+
for t in translate:
+
if not (0.0 <= t <= 1.0):
+
raise ValueError("translation values should be between 0 and 1")
+
self.translate = translate
+
+
if scale is not None:
+
_check_sequence_input(scale, "scale", req_sizes=(2, ))
+
for s in scale:
+
if s <= 0:
+
raise ValueError("scale values should be positive")
+
self.scale = scale
+
+
if shear is not None:
+
self.shear = _setup_angle(shear, name="shear", req_sizes=(2, 4))
+
else:
+
self.shear = shear
+
+
self.resample = resample
+
self.fillcolor = fillcolor
+
+
    @staticmethod
+
def get_params(
+
degrees: List[float],
+
translate: Optional[List[float]],
+
scale_ranges: Optional[List[float]],
+
shears: Optional[List[float]],
+
img_size: List[int]
+
) -> Tuple[float, Tuple[int, int], float, Tuple[float, float]]:
+
"""Get parameters for affine transformation
+
+
Returns:
+
params to be passed to the affine transformation
+
"""
+
angle = float(torch.empty(1).uniform_(float(degrees[0]), float(degrees[1])).item())
+
if translate is not None:
+
max_dx = float(translate[0] * img_size[0])
+
max_dy = float(translate[1] * img_size[1])
+
tx = int(round(torch.empty(1).uniform_(-max_dx, max_dx).item()))
+
ty = int(round(torch.empty(1).uniform_(-max_dy, max_dy).item()))
+
translations = (tx, ty)
+
else:
+
translations = (0, 0)
+
+
if scale_ranges is not None:
+
scale = float(torch.empty(1).uniform_(scale_ranges[0], scale_ranges[1]).item())
+
else:
+
scale = 1.0
+
+
shear_x = shear_y = 0.0
+
if shears is not None:
+
shear_x = float(torch.empty(1).uniform_(shears[0], shears[1]).item())
+
if len(shears) == 4:
+
shear_y = float(torch.empty(1).uniform_(shears[2], shears[3]).item())
+
+
shear = (shear_x, shear_y)
+
+
return angle, translations, scale, shear
+
+
    def forward(self, img):
+
"""
+
        Args:
            img (PIL Image or Tensor): Image to be transformed.
+
+
Returns:
+
PIL Image or Tensor: Affine transformed image.
+
"""
+
+
img_size = F._get_image_size(img)
+
+
ret = self.get_params(self.degrees, self.translate, self.scale, self.shear, img_size)
+
return F.affine(img, *ret, resample=self.resample, fillcolor=self.fillcolor)
+
+
def __repr__(self):
+
s = '{name}(degrees={degrees}'
+
if self.translate is not None:
+
s += ', translate={translate}'
+
if self.scale is not None:
+
s += ', scale={scale}'
+
if self.shear is not None:
+
s += ', shear={shear}'
+
if self.resample > 0:
+
s += ', resample={resample}'
+
if self.fillcolor != 0:
+
s += ', fillcolor={fillcolor}'
+
s += ')'
+
d = dict(self.__dict__)
+
d['resample'] = _pil_interpolation_to_str[d['resample']]
+
return s.format(name=self.__class__.__name__, **d)
+
+
+class Grayscale(torch.nn.Module):
+
"""Convert image to grayscale.
+
The image can be a PIL Image or a Tensor, in which case it is expected
+
to have [..., 3, H, W] shape, where ... means an arbitrary number of leading
+
dimensions
+
+
Args:
+
num_output_channels (int): (1 or 3) number of channels desired for output image
+
+
Returns:
+
PIL Image: Grayscale version of the input.
+
- If ``num_output_channels == 1`` : returned image is single channel
+
- If ``num_output_channels == 3`` : returned image is 3 channel with r == g == b
+
+
"""
+
+
def __init__(self, num_output_channels=1):
+
super().__init__()
+
self.num_output_channels = num_output_channels
+
+
    def forward(self, img: Tensor) -> Tensor:
+
"""
+
Args:
+
img (PIL Image or Tensor): Image to be converted to grayscale.
+
+
Returns:
+
PIL Image or Tensor: Grayscaled image.
+
"""
+
return F.rgb_to_grayscale(img, num_output_channels=self.num_output_channels)
+
+
def __repr__(self):
+
return self.__class__.__name__ + '(num_output_channels={0})'.format(self.num_output_channels)
+
+
+class RandomGrayscale(torch.nn.Module):
+
"""Randomly convert image to grayscale with a probability of p (default 0.1).
+
The image can be a PIL Image or a Tensor, in which case it is expected
+
to have [..., 3, H, W] shape, where ... means an arbitrary number of leading
+
dimensions
+
+
Args:
+
p (float): probability that image should be converted to grayscale.
+
+
Returns:
+
PIL Image or Tensor: Grayscale version of the input image with probability p and unchanged
+
with probability (1-p).
+
- If input image is 1 channel: grayscale version is 1 channel
+
- If input image is 3 channel: grayscale version is 3 channel with r == g == b
+
+
"""
+
+
def __init__(self, p=0.1):
+
super().__init__()
+
self.p = p
+
+
    def forward(self, img: Tensor) -> Tensor:
+
"""
+
Args:
+
img (PIL Image or Tensor): Image to be converted to grayscale.
+
+
Returns:
+
PIL Image or Tensor: Randomly grayscaled image.
+
"""
+
num_output_channels = F._get_image_num_channels(img)
+
if torch.rand(1) < self.p:
+
return F.rgb_to_grayscale(img, num_output_channels=num_output_channels)
+
return img
+
+
def __repr__(self):
+
return self.__class__.__name__ + '(p={0})'.format(self.p)
+
+
+class RandomErasing(torch.nn.Module):
+
""" Randomly selects a rectangle region in an image and erases its pixels.
+
'Random Erasing Data Augmentation' by Zhong et al. See https://arxiv.org/abs/1708.04896
+
+
Args:
+
p: probability that the random erasing operation will be performed.
+
scale: range of proportion of erased area against input image.
+
ratio: range of aspect ratio of erased area.
+
value: erasing value. Default is 0. If a single int, it is used to
+
erase all pixels. If a tuple of length 3, it is used to erase
+
R, G, B channels respectively.
+
If a str of 'random', erasing each pixel with random values.
+
inplace: boolean to make this transform inplace. Default set to False.
+
+
Returns:
+
Erased Image.
+
+
    Example:
+
>>> transform = transforms.Compose([
+
>>> transforms.RandomHorizontalFlip(),
+
>>> transforms.ToTensor(),
+
>>> transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
+
>>> transforms.RandomErasing(),
+
>>> ])
+
"""
+
+
def __init__(self, p=0.5, scale=(0.02, 0.33), ratio=(0.3, 3.3), value=0, inplace=False):
+
super().__init__()
+
if not isinstance(value, (numbers.Number, str, tuple, list)):
+
raise TypeError("Argument value should be either a number or str or a sequence")
+
if isinstance(value, str) and value != "random":
+
raise ValueError("If value is str, it should be 'random'")
+
if not isinstance(scale, (tuple, list)):
+
raise TypeError("Scale should be a sequence")
+
if not isinstance(ratio, (tuple, list)):
+
raise TypeError("Ratio should be a sequence")
+
if (scale[0] > scale[1]) or (ratio[0] > ratio[1]):
+
warnings.warn("Scale and ratio should be of kind (min, max)")
+
if scale[0] < 0 or scale[1] > 1:
+
raise ValueError("Scale should be between 0 and 1")
+
if p < 0 or p > 1:
+
raise ValueError("Random erasing probability should be between 0 and 1")
+
+
self.p = p
+
self.scale = scale
+
self.ratio = ratio
+
self.value = value
+
self.inplace = inplace
+
+
    @staticmethod
+
def get_params(
+
img: Tensor, scale: Tuple[float, float], ratio: Tuple[float, float], value: Optional[List[float]] = None
+
) -> Tuple[int, int, int, int, Tensor]:
+
"""Get parameters for ``erase`` for a random erasing.
+
+
Args:
+
img (Tensor): Tensor image of size (C, H, W) to be erased.
+
scale (tuple or list): range of proportion of erased area against input image.
+
ratio (tuple or list): range of aspect ratio of erased area.
+
value (list, optional): erasing value. If None, it is interpreted as "random"
+
(erasing each pixel with random values). If ``len(value)`` is 1, it is interpreted as a number,
+
i.e. ``value[0]``.
+
+
Returns:
+
tuple: params (i, j, h, w, v) to be passed to ``erase`` for random erasing.
+
"""
+
img_c, img_h, img_w = img.shape[-3], img.shape[-2], img.shape[-1]
+
area = img_h * img_w
+
+
for _ in range(10):
+
erase_area = area * torch.empty(1).uniform_(scale[0], scale[1]).item()
+
aspect_ratio = torch.empty(1).uniform_(ratio[0], ratio[1]).item()
+
+
h = int(round(math.sqrt(erase_area * aspect_ratio)))
+
w = int(round(math.sqrt(erase_area / aspect_ratio)))
+
if not (h < img_h and w < img_w):
+
continue
+
+
if value is None:
+
v = torch.empty([img_c, h, w], dtype=torch.float32).normal_()
+
else:
+
v = torch.tensor(value)[:, None, None]
+
+
i = torch.randint(0, img_h - h + 1, size=(1, )).item()
+
j = torch.randint(0, img_w - w + 1, size=(1, )).item()
+
return i, j, h, w, v
+
+
# Return original image
+
return 0, 0, img_h, img_w, img
+
+
    def forward(self, img):
+
"""
+
Args:
+
img (Tensor): Tensor image of size (C, H, W) to be erased.
+
+
Returns:
+
img (Tensor): Erased Tensor image.
+
"""
+
if torch.rand(1) < self.p:
+
+
# cast self.value to script acceptable type
+
if isinstance(self.value, (int, float)):
+
value = [self.value, ]
+
elif isinstance(self.value, str):
+
value = None
+
elif isinstance(self.value, tuple):
+
value = list(self.value)
+
else:
+
value = self.value
+
+
if value is not None and not (len(value) in (1, img.shape[-3])):
+
raise ValueError(
+
"If value is a sequence, it should have either a single value or "
+
"{} (number of input channels)".format(img.shape[-3])
+
)
+
+
x, y, h, w, v = self.get_params(img, scale=self.scale, ratio=self.ratio, value=value)
+
return F.erase(img, x, y, h, w, v, self.inplace)
+
return img
+
+
+class GaussianBlur(torch.nn.Module):
+
"""Blurs image with randomly chosen Gaussian blur.
+
The image can be a PIL Image or a Tensor, in which case it is expected
+
to have [..., 3, H, W] shape, where ... means an arbitrary number of leading
+
dimensions
+
+
Args:
+
kernel_size (int or sequence): Size of the Gaussian kernel.
+
sigma (float or tuple of float (min, max)): Standard deviation to be used for
+
creating kernel to perform blurring. If float, sigma is fixed. If it is tuple
+
of float (min, max), sigma is chosen uniformly at random to lie in the
+
given range.
+
+
Returns:
+
PIL Image or Tensor: Gaussian blurred version of the input image.
+
+
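    Example (illustrative; the kernel size must be odd and positive):
        >>> blurred = GaussianBlur(kernel_size=5, sigma=(0.1, 2.0))(img)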
"""
+
+
def __init__(self, kernel_size, sigma=(0.1, 2.0)):
+
super().__init__()
+
self.kernel_size = _setup_size(kernel_size, "Kernel size should be a tuple/list of two integers")
+
for ks in self.kernel_size:
+
if ks <= 0 or ks % 2 == 0:
+
raise ValueError("Kernel size value should be an odd and positive number.")
+
+
if isinstance(sigma, numbers.Number):
+
if sigma <= 0:
+
raise ValueError("If sigma is a single number, it must be positive.")
+
sigma = (sigma, sigma)
+
elif isinstance(sigma, Sequence) and len(sigma) == 2:
+
if not 0. < sigma[0] <= sigma[1]:
+
raise ValueError("sigma values should be positive and of the form (min, max).")
+
else:
+
raise ValueError("sigma should be a single number or a list/tuple with length 2.")
+
+
self.sigma = sigma
+
+
    @staticmethod
+
def get_params(sigma_min: float, sigma_max: float) -> float:
+
"""Choose sigma for ``gaussian_blur`` for random gaussian blurring.
+
+
Args:
+
sigma_min (float): Minimum standard deviation that can be chosen for blurring kernel.
+
sigma_max (float): Maximum standard deviation that can be chosen for blurring kernel.
+
+
Returns:
+
float: Standard deviation to be passed to calculate kernel for gaussian blurring.
+
"""
+
return torch.empty(1).uniform_(sigma_min, sigma_max).item()
+
+
    def forward(self, img: Tensor) -> Tensor:
+
"""
+
Args:
+
img (PIL Image or Tensor): image of size (C, H, W) to be blurred.
+
+
Returns:
+
PIL Image or Tensor: Gaussian blurred image
+
"""
+
sigma = self.get_params(self.sigma[0], self.sigma[1])
+
return F.gaussian_blur(img, self.kernel_size, [sigma, sigma])
+
+
def __repr__(self):
+
s = '(kernel_size={}, '.format(self.kernel_size)
+
s += 'sigma={})'.format(self.sigma)
+
return self.__class__.__name__ + s
+
+
+def _setup_size(size, error_msg):
+ if isinstance(size, numbers.Number):
+ return int(size), int(size)
+
+ if isinstance(size, Sequence) and len(size) == 1:
+ return size[0], size[0]
+
+ if len(size) != 2:
+ raise ValueError(error_msg)
+
+ return size
+
+
+def _check_sequence_input(x, name, req_sizes):
+ msg = req_sizes[0] if len(req_sizes) < 2 else " or ".join([str(s) for s in req_sizes])
+ if not isinstance(x, Sequence):
+ raise TypeError("{} should be a sequence of length {}.".format(name, msg))
+ if len(x) not in req_sizes:
+ raise ValueError("{} should be sequence of length {}.".format(name, msg))
+
+
+def _setup_angle(x, name, req_sizes=(2, )):
+ if isinstance(x, numbers.Number):
+ if x < 0:
+ raise ValueError("If {} is a single number, it must be positive.".format(name))
+ x = [-x, x]
+ else:
+ _check_sequence_input(x, name, req_sizes)
+
+ return [float(d) for d in x]
+
+
+