expose dilation argument in VideoClips class #2385

Status: Open. Wants to merge 1 commit into base branch fbsync.

torchvision/datasets/kinetics.py (4 changes: 3 additions & 1 deletion)

@@ -39,7 +39,8 @@ class Kinetics400(VisionDataset):
     def __init__(self, root, frames_per_clip, step_between_clips=1, frame_rate=None,
                  extensions=('avi',), transform=None, _precomputed_metadata=None,
                  num_workers=1, _video_width=0, _video_height=0,
-                 _video_min_dimension=0, _audio_samples=0, _audio_channels=0):
+                 _video_min_dimension=0, _audio_samples=0, _audio_channels=0,
+                 dilation=1):
         super(Kinetics400, self).__init__(root)
 
         classes = list(sorted(list_dir(root)))
@@ -59,6 +60,7 @@ def __init__(self, root, frames_per_clip, step_between_clips=1, frame_rate=None,
             _video_min_dimension=_video_min_dimension,
             _audio_samples=_audio_samples,
             _audio_channels=_audio_channels,
+            dilation=dilation,
         )
         self.transform = transform
 
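The new argument spaces out the frames taken for each clip: with dilation=d, a clip still contains frames_per_clip frames, but consecutive frames within the clip are d source frames apart. A minimal usage sketch, assuming this change is merged; the dataset path is hypothetical:

from torchvision.datasets import Kinetics400

# Illustrative example: 16-frame clips that keep every second frame,
# so each clip covers roughly twice the temporal extent.
dataset = Kinetics400(
    root="./kinetics/train",   # assumed local dataset location
    frames_per_clip=16,
    step_between_clips=16,
    dilation=2,                # argument introduced by this PR
)
video, audio, label = dataset[0]  # video: (T, H, W, C) tensor with 16 frames
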
torchvision/datasets/video_utils.py (18 changes: 10 additions & 8 deletions)

@@ -101,6 +101,7 @@ def __init__(
         _video_max_dimension=0,
         _audio_samples=0,
         _audio_channels=0,
+        dilation=1,
     ):
 
         self.video_paths = video_paths
@@ -118,7 +119,7 @@ def __init__(
             self._compute_frame_pts()
         else:
             self._init_from_metadata(_precomputed_metadata)
-        self.compute_clips(clip_length_in_frames, frames_between_clips, frame_rate)
+        self.compute_clips(clip_length_in_frames, frames_between_clips, frame_rate, dilation)
 
     def _collate_fn(self, x):
         return x
@@ -190,7 +191,7 @@ def subset(self, indices):
         )
 
     @staticmethod
-    def compute_clips_for_video(video_pts, num_frames, step, fps, frame_rate):
+    def compute_clips_for_video(video_pts, num_frames, step, fps, frame_rate, dilation=1):
         if fps is None:
             # if for some reason the video doesn't have fps (because doesn't have a video stream)
             # set the fps to 1. The value doesn't matter, because video_pts is empty anyway
@@ -202,14 +203,14 @@ def compute_clips_for_video(video_pts, num_frames, step, fps, frame_rate):
             int(math.floor(total_frames)), fps, frame_rate
         )
         video_pts = video_pts[idxs]
-        clips = unfold(video_pts, num_frames, step)
+        clips = unfold(video_pts, num_frames, step, dilation)
         if isinstance(idxs, slice):
-            idxs = [idxs] * len(clips)
+            idxs = [slice(None, None, idxs.step * dilation)] * len(clips)
         else:
-            idxs = unfold(idxs, num_frames, step)
+            idxs = unfold(idxs, num_frames, step, dilation)
         return clips, idxs
 
-    def compute_clips(self, num_frames, step, frame_rate=None):
+    def compute_clips(self, num_frames, step, frame_rate=None, dilation=1):
         """
         Compute all consecutive sequences of clips from video_pts.
         Always returns clips of size `num_frames`, meaning that the
@@ -222,11 +223,12 @@ def compute_clips(self, num_frames, step, frame_rate=None):
         self.num_frames = num_frames
         self.step = step
         self.frame_rate = frame_rate
+        self.dilation = dilation
         self.clips = []
         self.resampling_idxs = []
         for video_pts, fps in zip(self.video_pts, self.video_fps):
             clips, idxs = self.compute_clips_for_video(
-                video_pts, num_frames, step, fps, frame_rate
+                video_pts, num_frames, step, fps, frame_rate, dilation
             )
             self.clips.append(clips)
             self.resampling_idxs.append(idxs)
@@ -409,4 +411,4 @@ def __setstate__(self, d):
         d["video_pts"] = video_pts
         self.__dict__ = d
         # recompute attributes "clips", "resampling_idxs" and other derivative ones
-        self.compute_clips(self.num_frames, self.step, self.frame_rate)
+        self.compute_clips(self.num_frames, self.step, self.frame_rate, self.dilation)
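
For reference, dilation here follows the usual strided-window semantics of the unfold helper: each window still has num_frames elements, but neighbouring elements within a window are dilation indices apart, while neighbouring windows start step indices apart. A self-contained sketch of that behaviour (an illustrative re-implementation for clarity, not the actual torchvision helper):

import torch

def dilated_unfold(tensor, size, step, dilation=1):
    # Sliding windows over a 1-D tensor: `size` elements per window,
    # elements within a window `dilation` apart, windows `step` apart.
    span = dilation * (size - 1) + 1
    n_windows = max((tensor.size(0) - span) // step + 1, 0)
    stride = tensor.stride(0)
    return torch.as_strided(tensor, (n_windows, size), (step * stride, dilation * stride))

pts = torch.arange(10)
print(dilated_unfold(pts, size=3, step=2, dilation=2))
# tensor([[0, 2, 4],
#         [2, 4, 6],
#         [4, 6, 8]])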