import bisect
import math
import warnings
from fractions import Fraction
from typing import List
import torch
from torchvision.io import (
_probe_video_from_file,
_read_video_from_file,
read_video,
read_video_timestamps,
)
from .utils import tqdm
def pts_convert(pts, timebase_from, timebase_to, round_func=math.floor):
"""convert pts between different time bases
Args:
pts: presentation timestamp, float
timebase_from: original timebase. Fraction
timebase_to: new timebase. Fraction
round_func: rounding function.
"""
new_pts = Fraction(pts, 1) * timebase_from / timebase_to
return round_func(new_pts)
def unfold(tensor, size, step, dilation=1):
"""
similar to tensor.unfold, but with the dilation
and specialized for 1d tensors
Returns all consecutive windows of `size` elements, with
`step` between windows. The distance between each element
in a window is given by `dilation`.
"""
assert tensor.dim() == 1
o_stride = tensor.stride(0)
numel = tensor.numel()
new_stride = (step * o_stride, dilation * o_stride)
new_size = ((numel - (dilation * (size - 1) + 1)) // step + 1, size)
if new_size[0] < 1:
new_size = (0, size)
return torch.as_strided(tensor, new_size, new_stride)
class _VideoTimestampsDataset(object):
"""
Dataset used to parallelize the reading of the timestamps
of a list of videos, given their paths in the filesystem.
Used in VideoClips and defined at top level so it can be
pickled when forking.
"""
def __init__(self, video_paths: List[str]):
self.video_paths = video_paths
def __len__(self):
return len(self.video_paths)
def __getitem__(self, idx):
return read_video_timestamps(self.video_paths[idx])
def _collate_fn(x):
"""
Dummy collate function to be used with _VideoTimestampsDataset
"""
return x
class VideoClips(object):
"""
Given a list of video files, computes all consecutive subvideos of size
`clip_length_in_frames`, where the distance between each subvideo in the
same video is defined by `frames_between_clips`.
If `frame_rate` is specified, it will also resample all the videos to have
the same frame rate, and the clips will refer to this frame rate.
Creating this instance the first time is time-consuming, as it needs to
decode all the videos in `video_paths`. It is recommended that you
cache the results after instantiation of the class.
Recreating the clips for different clip lengths is fast, and can be done
with the `compute_clips` method.
Args:
video_paths (List[str]): paths to the video files
clip_length_in_frames (int): size of a clip in number of frames
frames_between_clips (int): step (in frames) between each clip
frame_rate (int, optional): if specified, it will resample the video
so that it has `frame_rate`, and then the clips will be defined
on the resampled video
num_workers (int): how many subprocesses to use for data loading.
0 means that the data will be loaded in the main process. (default: 0)
"""
def __init__(
self,
video_paths,
clip_length_in_frames=16,
frames_between_clips=1,
frame_rate=None,
_precomputed_metadata=None,
num_workers=0,
_video_width=0,
_video_height=0,
_video_min_dimension=0,
_video_max_dimension=0,
_audio_samples=0,
_audio_channels=0,
):
self.video_paths = video_paths
self.num_workers = num_workers
# these options are not valid for pyav backend
self._video_width = _video_width
self._video_height = _video_height
self._video_min_dimension = _video_min_dimension
self._video_max_dimension = _video_max_dimension
self._audio_samples = _audio_samples
self._audio_channels = _audio_channels
if _precomputed_metadata is None:
self._compute_frame_pts()
else:
self._init_from_metadata(_precomputed_metadata)
self.compute_clips(clip_length_in_frames, frames_between_clips, frame_rate)
def _compute_frame_pts(self):
self.video_pts = []
self.video_fps = []
# strategy: use a DataLoader to parallelize read_video_timestamps
# so need to create a dummy dataset first
import torch.utils.data
dl = torch.utils.data.DataLoader(
_VideoTimestampsDataset(self.video_paths),
batch_size=16,
num_workers=self.num_workers,
collate_fn=_collate_fn,
)
with tqdm(total=len(dl)) as pbar:
for batch in dl:
pbar.update(1)
clips, fps = list(zip(*batch))
# we need to specify dtype=torch.long because for empty list,
# torch.as_tensor will use torch.float as default dtype. This
# happens when decoding fails and no pts is returned in the list.
clips = [torch.as_tensor(c, dtype=torch.long) for c in clips]
self.video_pts.extend(clips)
self.video_fps.extend(fps)
def _init_from_metadata(self, metadata):
self.video_paths = metadata["video_paths"]
assert len(self.video_paths) == len(metadata["video_pts"])
self.video_pts = metadata["video_pts"]
assert len(self.video_paths) == len(metadata["video_fps"])
self.video_fps = metadata["video_fps"]
@property
def metadata(self):
_metadata = {
"video_paths": self.video_paths,
"video_pts": self.video_pts,
"video_fps": self.video_fps,
}
return _metadata
def subset(self, indices):
video_paths = [self.video_paths[i] for i in indices]
video_pts = [self.video_pts[i] for i in indices]
video_fps = [self.video_fps[i] for i in indices]
metadata = {
"video_paths": video_paths,
"video_pts": video_pts,
"video_fps": video_fps,
}
return type(self)(
video_paths,
self.num_frames,
self.step,
self.frame_rate,
_precomputed_metadata=metadata,
num_workers=self.num_workers,
_video_width=self._video_width,
_video_height=self._video_height,
_video_min_dimension=self._video_min_dimension,
_video_max_dimension=self._video_max_dimension,
_audio_samples=self._audio_samples,
_audio_channels=self._audio_channels,
)
@staticmethod
def compute_clips_for_video(video_pts, num_frames, step, fps, frame_rate):
if fps is None:
# if for some reason the video doesn't have fps (because doesn't have a video stream)
# set the fps to 1. The value doesn't matter, because video_pts is empty anyway
fps = 1
if frame_rate is None:
frame_rate = fps
total_frames = len(video_pts) * (float(frame_rate) / fps)
idxs = VideoClips._resample_video_idx(
int(math.floor(total_frames)), fps, frame_rate
)
video_pts = video_pts[idxs]
clips = unfold(video_pts, num_frames, step)
if not clips.numel():
warnings.warn("There aren't enough frames in the current video to get a clip for the given clip length and "
"frames between clips. The video (and potentially others) will be skipped.")
if isinstance(idxs, slice):
idxs = [idxs] * len(clips)
else:
idxs = unfold(idxs, num_frames, step)
return clips, idxs
def compute_clips(self, num_frames, step, frame_rate=None):
"""
Compute all consecutive sequences of clips from video_pts.
Always returns clips of size `num_frames`, meaning that the
last few frames in a video can potentially be dropped.
Args:
num_frames (int): number of frames for the clip
step (int): distance between two clips
frame_rate (int, optional): The frame rate
"""
self.num_frames = num_frames
self.step = step
self.frame_rate = frame_rate
self.clips = []
self.resampling_idxs = []
for video_pts, fps in zip(self.video_pts, self.video_fps):
clips, idxs = self.compute_clips_for_video(
video_pts, num_frames, step, fps, frame_rate
)
self.clips.append(clips)
self.resampling_idxs.append(idxs)
clip_lengths = torch.as_tensor([len(v) for v in self.clips])
self.cumulative_sizes = clip_lengths.cumsum(0).tolist()
def __len__(self):
return self.num_clips()
def num_videos(self):
return len(self.video_paths)
def num_clips(self):
"""
Number of subclips that are available in the video list.
"""
return self.cumulative_sizes[-1]
def get_clip_location(self, idx):
"""
Converts a flattened representation of the indices into a video_idx, clip_idx
representation.
"""
video_idx = bisect.bisect_right(self.cumulative_sizes, idx)
if video_idx == 0:
clip_idx = idx
else:
clip_idx = idx - self.cumulative_sizes[video_idx - 1]
return video_idx, clip_idx
@staticmethod
def _resample_video_idx(num_frames, original_fps, new_fps):
step = float(original_fps) / new_fps
if step.is_integer():
# optimization: if step is integer, don't need to perform
# advanced indexing
step = int(step)
return slice(None, None, step)
idxs = torch.arange(num_frames, dtype=torch.float32) * step
idxs = idxs.floor().to(torch.int64)
return idxs
def get_clip(self, idx):
"""
Gets a subclip from a list of videos.
Args:
idx (int): index of the subclip. Must be between 0 and num_clips().
Returns:
video (Tensor)
audio (Tensor)
info (Dict)
video_idx (int): index of the video in `video_paths`
"""
if idx >= self.num_clips():
raise IndexError(
"Index {} out of range "
"({} number of clips)".format(idx, self.num_clips())
)
video_idx, clip_idx = self.get_clip_location(idx)
video_path = self.video_paths[video_idx]
clip_pts = self.clips[video_idx][clip_idx]
from torchvision import get_video_backend
backend = get_video_backend()
if backend == "pyav":
# check for invalid options
if self._video_width != 0:
raise ValueError("pyav backend doesn't support _video_width != 0")
if self._video_height != 0:
raise ValueError("pyav backend doesn't support _video_height != 0")
if self._video_min_dimension != 0:
raise ValueError(
"pyav backend doesn't support _video_min_dimension != 0"
)
if self._video_max_dimension != 0:
raise ValueError(
"pyav backend doesn't support _video_max_dimension != 0"
)
if self._audio_samples != 0:
raise ValueError("pyav backend doesn't support _audio_samples != 0")
if backend == "pyav":
start_pts = clip_pts[0].item()
end_pts = clip_pts[-1].item()
video, audio, info = read_video(video_path, start_pts, end_pts)
else:
info = _probe_video_from_file(video_path)
video_fps = info.video_fps
audio_fps = None
video_start_pts = clip_pts[0].item()
video_end_pts = clip_pts[-1].item()
audio_start_pts, audio_end_pts = 0, -1
audio_timebase = Fraction(0, 1)
video_timebase = Fraction(
info.video_timebase.numerator, info.video_timebase.denominator
)
if info.has_audio:
audio_timebase = Fraction(
Loading ...