Skip to content

Commit b6a45f1

Browse files
author
Yi Li
committed
Merge pull request facebookresearch#541
2 parents 67c61a9 + e0d247a commit b6a45f1

File tree

1 file changed

+71
-33
lines changed

1 file changed

+71
-33
lines changed

slowfast/datasets/decoder.py

Lines changed: 71 additions & 33 deletions
Original file line number | Diff line number | Diff line change
@@ -382,6 +382,10 @@ def pyav_decode(
382382
num_clips_uniform=10,
383383
target_fps=30,
384384
use_offset=False,
385+
modalities=("visual",),
386+
max_spatial_scale=0,
387+
min_delta=-math.inf,
388+
max_delta=math.inf,
385389
):
386390
"""
387391
Convert the video from its original fps to the target_fps. If the video
@@ -419,38 +423,69 @@ def pyav_decode(
419423
# If failed to fetch the decoding information, decode the entire video.
420424
decode_all_video = True
421425
video_start_pts, video_end_pts = 0, math.inf
426+
start_end_delta_time = None
427+
428+
frames = None
429+
if container.streams.video:
430+
video_frames, max_pts = pyav_decode_stream(
431+
container,
432+
video_start_pts,
433+
video_end_pts,
434+
container.streams.video[0],
435+
{"video": 0},
436+
)
437+
container.close()
438+
439+
frames = [frame.to_rgb().to_ndarray() for frame in video_frames]
440+
frames = torch.as_tensor(np.stack(frames))
441+
frames_out = [frames]
442+
422443
else:
423444
# Perform selective decoding.
424445
decode_all_video = False
425-
clip_size = np.maximum(
426-
1.0, np.ceil(sampling_rate * (num_frames - 1) / target_fps * fps)
427-
)
428-
start_idx, end_idx, fraction = get_start_end_idx(
446+
clip_sizes = [
447+
np.maximum(
448+
1.0,
449+
np.ceil(
450+
sampling_rate[i] * (num_frames[i] - 1) / target_fps * fps
451+
),
452+
)
453+
for i in range(len(sampling_rate))
454+
]
455+
start_end_delta_time = get_multiple_start_end_idx(
429456
frames_length,
430-
clip_size,
457+
clip_sizes,
431458
clip_idx,
432459
num_clips_uniform,
433-
use_offset=use_offset,
434-
)
435-
timebase = duration / frames_length
436-
video_start_pts = int(start_idx * timebase)
437-
video_end_pts = int(end_idx * timebase)
438-
439-
frames = None
440-
# If video stream was found, fetch video frames from the video.
441-
if container.streams.video:
442-
video_frames, max_pts = pyav_decode_stream(
443-
container,
444-
video_start_pts,
445-
video_end_pts,
446-
container.streams.video[0],
447-
{"video": 0},
460+
min_delta=min_delta,
461+
max_delta=max_delta,
448462
)
463+
frames_out = [None] * len(num_frames)
464+
for k in range(len(num_frames)):
465+
start_idx = start_end_delta_time[k, 0]
466+
end_idx = start_end_delta_time[k, 1]
467+
timebase = duration / frames_length
468+
video_start_pts = int(start_idx * timebase)
469+
video_end_pts = int(end_idx * timebase)
470+
471+
frames = None
472+
# If video stream was found, fetch video frames from the video.
473+
if container.streams.video:
474+
video_frames, max_pts = pyav_decode_stream(
475+
container,
476+
video_start_pts,
477+
video_end_pts,
478+
container.streams.video[0],
479+
{"video": 0},
480+
)
481+
482+
frames = [frame.to_rgb().to_ndarray() for frame in video_frames]
483+
frames = torch.as_tensor(np.stack(frames))
484+
485+
frames_out[k] = frames
449486
container.close()
450487

451-
frames = [frame.to_rgb().to_ndarray() for frame in video_frames]
452-
frames = torch.as_tensor(np.stack(frames))
453-
return frames, fps, decode_all_video
488+
return frames_out, fps, decode_all_video, start_end_delta_time
454489

455490

456491
def decode(
@@ -510,17 +545,20 @@ def decode(
510545
) # clips come temporally ordered from decoder
511546
try:
512547
if backend == "pyav":
513-
assert (
514-
min_delta == -math.inf and max_delta == math.inf
515-
), "delta sampling not supported in pyav"
516-
frames_decoded, fps, decode_all_video = pyav_decode(
548+
assert min_delta == -math.inf and max_delta == math.inf, \
549+
"delta sampling not supported in pyav"
550+
frames_decoded, fps, decode_all_video, start_end_delta_time = pyav_decode(
517551
container,
518552
sampling_rate,
519553
num_frames,
520554
clip_idx,
521555
num_clips_uniform,
522556
target_fps,
523557
use_offset=use_offset,
558+
modalities=("visual",),
559+
max_spatial_scale=max_spatial_scale,
560+
min_delta=min_delta,
561+
max_delta=max_delta,
524562
)
525563
elif backend == "torchvision":
526564
(
@@ -558,12 +596,12 @@ def decode(
558596
frames_decoded = [frames_decoded]
559597
num_decoded = len(frames_decoded)
560598
clip_sizes = [
561-
np.maximum(
562-
1.0,
563-
sampling_rate[i] * num_frames[i] / target_fps * fps
564-
)
565-
for i in range(len(sampling_rate))
566-
]
599+
np.maximum(
600+
1.0,
601+
sampling_rate[i] * num_frames[i] / target_fps * fps
602+
)
603+
for i in range(len(sampling_rate))
604+
]
567605

568606
if decode_all_video: # full video was decoded (not trimmed yet)
569607
assert num_decoded == 1 and start_end_delta_time is None

0 commit comments

Comments
 (0)