import torch
def _is_tensor_video_clip(clip):
if not torch.is_tensor(clip):
raise TypeError("clip should be Tesnor. Got %s" % type(clip))
if not clip.ndimension() == 4:
raise ValueError("clip should be 4D. Got %dD" % clip.dim())
return True
def crop(clip, i, j, h, w):
"""
Args:
clip (torch.tensor): Video clip to be cropped. Size is (C, T, H, W)
"""
assert len(clip.size()) == 4, "clip should be a 4D tensor"
return clip[..., i:i + h, j:j + w]
def resize(clip, target_size, interpolation_mode):
assert len(target_size) == 2, "target size should be tuple (height, width)"
return torch.nn.functional.interpolate(
clip, size=target_size, mode=interpolation_mode, align_corners=False
)
def resized_crop(clip, i, j, h, w, size, interpolation_mode="bilinear"):
"""
Do spatial cropping and resizing to the video clip
Args:
clip (torch.tensor): Video clip to be cropped. Size is (C, T, H, W)
i (int): i in (i,j) i.e coordinates of the upper left corner.
j (int): j in (i,j) i.e coordinates of the upper left corner.
h (int): Height of the cropped region.
w (int): Width of the cropped region.
size (tuple(int, int)): height and width of resized clip
Returns:
clip (torch.tensor): Resized and cropped clip. Size is (C, T, H, W)
"""
assert _is_tensor_video_clip(clip), "clip should be a 4D torch.tensor"
clip = crop(clip, i, j, h, w)
clip = resize(clip, size, interpolation_mode)
return clip
def center_crop(clip, crop_size):
assert _is_tensor_video_clip(clip), "clip should be a 4D torch.tensor"
h, w = clip.size(-2), clip.size(-1)
th, tw = crop_size
assert h >= th and w >= tw, "height and width must be no smaller than crop_size"
i = int(round((h - th) / 2.0))
j = int(round((w - tw) / 2.0))
return crop(clip, i, j, th, tw)
def to_tensor(clip):
"""
Convert tensor data type from uint8 to float, divide value by 255.0 and
permute the dimensions of clip tensor
Args:
clip (torch.tensor, dtype=torch.uint8): Size is (T, H, W, C)
Return:
clip (torch.tensor, dtype=torch.float): Size is (C, T, H, W)
"""
_is_tensor_video_clip(clip)
if not clip.dtype == torch.uint8:
raise TypeError("clip tensor should have data type uint8. Got %s" % str(clip.dtype))
return clip.float().permute(3, 0, 1, 2) / 255.0
def normalize(clip, mean, std, inplace=False):
"""
Args:
clip (torch.tensor): Video clip to be normalized. Size is (C, T, H, W)
mean (tuple): pixel RGB mean. Size is (3)
std (tuple): pixel standard deviation. Size is (3)
Returns:
normalized clip (torch.tensor): Size is (C, T, H, W)
"""
assert _is_tensor_video_clip(clip), "clip should be a 4D torch.tensor"
if not inplace:
clip = clip.clone()
mean = torch.as_tensor(mean, dtype=clip.dtype, device=clip.device)
std = torch.as_tensor(std, dtype=clip.dtype, device=clip.device)
clip.sub_(mean[:, None, None, None]).div_(std[:, None, None, None])
return clip
def hflip(clip):
"""
Args:
clip (torch.tensor): Video clip to be normalized. Size is (C, T, H, W)
Returns:
flipped clip (torch.tensor): Size is (C, T, H, W)
"""
assert _is_tensor_video_clip(clip), "clip should be a 4D torch.tensor"
return clip.flip((-1))