Skip to content

Commit 4a3edb4

Browse files
committed
add decoder from facebookresearch#541
1 parent d52aa76 commit 4a3edb4

File tree

1 file changed

+80
-37
lines changed

1 file changed

+80
-37
lines changed

slowfast/datasets/decoder.py

Lines changed: 80 additions & 37 deletions
Original file line number | Diff line number | Diff line change
@@ -3,8 +3,9 @@
33

44
import logging
55
import math
6-
import numpy as np
76
import random
7+
8+
import numpy as np
89
import torch
910
import torchvision.io as io
1011

@@ -84,7 +85,7 @@ def get_multiple_start_end_idx(
8485
num_clips_uniform,
8586
min_delta=0,
8687
max_delta=math.inf,
87-
use_offset=False,
88+
use_offset=False
8889
):
8990
"""
9091
Sample a clip of size clip_size from a video of size video_size and
@@ -114,7 +115,7 @@ def sample_clips(
114115
min_delta=0,
115116
max_delta=math.inf,
116117
num_retries=100,
117-
use_offset=False,
118+
use_offset=False
118119
):
119120
se_inds = np.empty((0, 2))
120121
dt = np.empty((0))
@@ -125,15 +126,13 @@ def sample_clips(
125126
if clip_idx == -1:
126127
# Random temporal sampling.
127128
start_idx = random.uniform(0, max_start)
128-
else: # Uniformly sample the clip with the given index.
129+
else: # Uniformly sample the clip with the given index.
129130
if use_offset:
130131
if num_clips_uniform == 1:
131132
# Take the center clip if num_clips is 1.
132133
start_idx = math.floor(max_start / 2)
133134
else:
134-
start_idx = clip_idx * math.floor(
135-
max_start / (num_clips_uniform - 1)
136-
)
135+
start_idx = clip_idx * math.floor(max_start / (num_clips_uniform - 1))
137136
else:
138137
start_idx = max_start * clip_idx / num_clips_uniform
139138

@@ -304,7 +303,10 @@ def torchvision_decode(
304303
decode_all_video = False # try selective decoding
305304

306305
clip_sizes = [
307-
np.maximum(1.0, sampling_rate[i] * num_frames[i] / target_fps * fps)
306+
np.maximum(
307+
1.0,
308+
sampling_rate[i] * num_frames[i] / target_fps * fps
309+
)
308310
for i in range(len(sampling_rate))
309311
]
310312
start_end_delta_time = get_multiple_start_end_idx(
@@ -381,6 +383,10 @@ def pyav_decode(
381383
num_clips_uniform=10,
382384
target_fps=30,
383385
use_offset=False,
386+
modalities=("visual",),
387+
max_spatial_scale=0,
388+
min_delta=-math.inf,
389+
max_delta=math.inf,
384390
):
385391
"""
386392
Convert the video from its original fps to the target_fps. If the video
@@ -418,38 +424,69 @@ def pyav_decode(
418424
# If failed to fetch the decoding information, decode the entire video.
419425
decode_all_video = True
420426
video_start_pts, video_end_pts = 0, math.inf
427+
start_end_delta_time = None
428+
429+
frames = None
430+
if container.streams.video:
431+
video_frames, max_pts = pyav_decode_stream(
432+
container,
433+
video_start_pts,
434+
video_end_pts,
435+
container.streams.video[0],
436+
{"video": 0},
437+
)
438+
container.close()
439+
440+
frames = [frame.to_rgb().to_ndarray() for frame in video_frames]
441+
frames = torch.as_tensor(np.stack(frames))
442+
frames_out = [frames]
443+
421444
else:
422445
# Perform selective decoding.
423446
decode_all_video = False
424-
clip_size = np.maximum(
425-
1.0, np.ceil(sampling_rate * (num_frames - 1) / target_fps * fps)
426-
)
427-
start_idx, end_idx, fraction = get_start_end_idx(
447+
clip_sizes = [
448+
np.maximum(
449+
1.0,
450+
np.ceil(
451+
sampling_rate[i] * (num_frames[i] - 1) / target_fps * fps
452+
),
453+
)
454+
for i in range(len(sampling_rate))
455+
]
456+
start_end_delta_time = get_multiple_start_end_idx(
428457
frames_length,
429-
clip_size,
458+
clip_sizes,
430459
clip_idx,
431460
num_clips_uniform,
432-
use_offset=use_offset,
433-
)
434-
timebase = duration / frames_length
435-
video_start_pts = int(start_idx * timebase)
436-
video_end_pts = int(end_idx * timebase)
437-
438-
frames = None
439-
# If video stream was found, fetch video frames from the video.
440-
if container.streams.video:
441-
video_frames, max_pts = pyav_decode_stream(
442-
container,
443-
video_start_pts,
444-
video_end_pts,
445-
container.streams.video[0],
446-
{"video": 0},
461+
min_delta=min_delta,
462+
max_delta=max_delta,
447463
)
464+
frames_out = [None] * len(num_frames)
465+
for k in range(len(num_frames)):
466+
start_idx = start_end_delta_time[k, 0]
467+
end_idx = start_end_delta_time[k, 1]
468+
timebase = duration / frames_length
469+
video_start_pts = int(start_idx * timebase)
470+
video_end_pts = int(end_idx * timebase)
471+
472+
frames = None
473+
# If video stream was found, fetch video frames from the video.
474+
if container.streams.video:
475+
video_frames, max_pts = pyav_decode_stream(
476+
container,
477+
video_start_pts,
478+
video_end_pts,
479+
container.streams.video[0],
480+
{"video": 0},
481+
)
482+
483+
frames = [frame.to_rgb().to_ndarray() for frame in video_frames]
484+
frames = torch.as_tensor(np.stack(frames))
485+
486+
frames_out[k] = frames
448487
container.close()
449488

450-
frames = [frame.to_rgb().to_ndarray() for frame in video_frames]
451-
frames = torch.as_tensor(np.stack(frames))
452-
return frames, fps, decode_all_video
489+
return frames_out, fps, decode_all_video, start_end_delta_time
453490

454491

455492
def decode(
@@ -509,17 +546,20 @@ def decode(
509546
) # clips come temporally ordered from decoder
510547
try:
511548
if backend == "pyav":
512-
assert (
513-
min_delta == -math.inf and max_delta == math.inf
514-
), "delta sampling not supported in pyav"
515-
frames_decoded, fps, decode_all_video = pyav_decode(
549+
assert min_delta == -math.inf and max_delta == math.inf, \
550+
"delta sampling not supported in pyav"
551+
frames_decoded, fps, decode_all_video, start_end_delta_time = pyav_decode(
516552
container,
517553
sampling_rate,
518554
num_frames,
519555
clip_idx,
520556
num_clips_uniform,
521557
target_fps,
522558
use_offset=use_offset,
559+
modalities=("visual",),
560+
max_spatial_scale=max_spatial_scale,
561+
min_delta=min_delta,
562+
max_delta=max_delta,
523563
)
524564
elif backend == "torchvision":
525565
(
@@ -557,7 +597,10 @@ def decode(
557597
frames_decoded = [frames_decoded]
558598
num_decoded = len(frames_decoded)
559599
clip_sizes = [
560-
np.maximum(1.0, sampling_rate[i] * num_frames[i] / target_fps * fps)
600+
np.maximum(
601+
1.0,
602+
sampling_rate[i] * num_frames[i] / target_fps * fps
603+
)
561604
for i in range(len(sampling_rate))
562605
]
563606

@@ -621,4 +664,4 @@ def decode(
621664
for i in range(num_decode)
622665
)
623666

624-
return frames_out, start_end_delta_time, time_diff_aug
667+
return frames_out, start_end_delta_time, time_diff_aug

0 commit comments

Comments
 (0)