@@ -3,8 +3,9 @@

 import logging
 import math
-import numpy as np
 import random
+
+import numpy as np
 import torch
 import torchvision.io as io

@@ -84,7 +85,7 @@ def get_multiple_start_end_idx(
     num_clips_uniform,
     min_delta=0,
     max_delta=math.inf,
-    use_offset=False,
+    use_offset=False
 ):
     """
     Sample a clip of size clip_size from a video of size video_size and
@@ -114,7 +115,7 @@ def sample_clips(
     min_delta=0,
     max_delta=math.inf,
     num_retries=100,
-    use_offset=False,
+    use_offset=False
 ):
     se_inds = np.empty((0, 2))
     dt = np.empty((0))
@@ -125,15 +126,13 @@ def sample_clips(
         if clip_idx == -1:
             # Random temporal sampling.
             start_idx = random.uniform(0, max_start)
-        else:  # Uniformly sample the clip with the given index.
+        else:  # Uniformly sample the clip with the given index.
             if use_offset:
                 if num_clips_uniform == 1:
                     # Take the center clip if num_clips is 1.
                     start_idx = math.floor(max_start / 2)
                 else:
-                    start_idx = clip_idx * math.floor(
-                        max_start / (num_clips_uniform - 1)
-                    )
+                    start_idx = clip_idx * math.floor(max_start / (num_clips_uniform - 1))
             else:
                 start_idx = max_start * clip_idx / num_clips_uniform

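For a concrete sense of where the two uniform-sampling formulas above place clips, here is a minimal sketch in plain Python. All values are made up, and max_start stands for the last valid start index, as in the surrounding code:

import math

max_start = 90           # hypothetical last valid start index
num_clips_uniform = 4

# Offset-style spacing: clips span [0, max_start] inclusive at both ends.
offset_starts = [
    k * math.floor(max_start / (num_clips_uniform - 1))
    for k in range(num_clips_uniform)
]  # -> [0, 30, 60, 90]

# Plain uniform spacing: the last clip starts short of max_start.
plain_starts = [
    max_start * k / num_clips_uniform for k in range(num_clips_uniform)
]  # -> [0.0, 22.5, 45.0, 67.5]
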
@@ -304,7 +303,10 @@ def torchvision_decode(
         decode_all_video = False  # try selective decoding

         clip_sizes = [
-            np.maximum(1.0, sampling_rate[i] * num_frames[i] / target_fps * fps)
+            np.maximum(
+                1.0,
+                sampling_rate[i] * num_frames[i] / target_fps * fps
+            )
             for i in range(len(sampling_rate))
         ]
         start_end_delta_time = get_multiple_start_end_idx(
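The clip_sizes expression here (and again in pyav_decode and decode below) converts a request for num_frames frames sampled every sampling_rate frames at target_fps into a span of frames at the video's native fps. A small worked sketch, with all values assumed for illustration:

import numpy as np

sampling_rate = [2]      # assumed temporal stride between sampled frames
num_frames = [8]         # assumed frames wanted per clip
target_fps = 30
fps = 60                 # assumed native fps of the video being decoded

clip_sizes = [
    np.maximum(1.0, sampling_rate[i] * num_frames[i] / target_fps * fps)
    for i in range(len(sampling_rate))
]  # (2 * 8 / 30) * 60 = 32.0 source frames for the single requested clip
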
@@ -381,6 +383,10 @@ def pyav_decode(
     num_clips_uniform=10,
     target_fps=30,
     use_offset=False,
+    modalities=("visual",),
+    max_spatial_scale=0,
+    min_delta=-math.inf,
+    max_delta=math.inf,
 ):
     """
     Convert the video from its original fps to the target_fps. If the video
@@ -418,38 +424,69 @@ def pyav_decode(
         # If failed to fetch the decoding information, decode the entire video.
         decode_all_video = True
         video_start_pts, video_end_pts = 0, math.inf
+        start_end_delta_time = None
+
+        frames = None
+        if container.streams.video:
+            video_frames, max_pts = pyav_decode_stream(
+                container,
+                video_start_pts,
+                video_end_pts,
+                container.streams.video[0],
+                {"video": 0},
+            )
+            container.close()
+
+            frames = [frame.to_rgb().to_ndarray() for frame in video_frames]
+            frames = torch.as_tensor(np.stack(frames))
+            frames_out = [frames]
+
     else:
         # Perform selective decoding.
         decode_all_video = False
-        clip_size = np.maximum(
-            1.0, np.ceil(sampling_rate * (num_frames - 1) / target_fps * fps)
-        )
-        start_idx, end_idx, fraction = get_start_end_idx(
+        clip_sizes = [
+            np.maximum(
+                1.0,
+                np.ceil(
+                    sampling_rate[i] * (num_frames[i] - 1) / target_fps * fps
+                ),
+            )
+            for i in range(len(sampling_rate))
+        ]
+        start_end_delta_time = get_multiple_start_end_idx(
             frames_length,
-            clip_size,
+            clip_sizes,
             clip_idx,
             num_clips_uniform,
-            use_offset=use_offset,
-        )
-        timebase = duration / frames_length
-        video_start_pts = int(start_idx * timebase)
-        video_end_pts = int(end_idx * timebase)
-
-    frames = None
-    # If video stream was found, fetch video frames from the video.
-    if container.streams.video:
-        video_frames, max_pts = pyav_decode_stream(
-            container,
-            video_start_pts,
-            video_end_pts,
-            container.streams.video[0],
-            {"video": 0},
+            min_delta=min_delta,
+            max_delta=max_delta,
         )
+        frames_out = [None] * len(num_frames)
+        for k in range(len(num_frames)):
+            start_idx = start_end_delta_time[k, 0]
+            end_idx = start_end_delta_time[k, 1]
+            timebase = duration / frames_length
+            video_start_pts = int(start_idx * timebase)
+            video_end_pts = int(end_idx * timebase)
+
+            frames = None
+            # If video stream was found, fetch video frames from the video.
+            if container.streams.video:
+                video_frames, max_pts = pyav_decode_stream(
+                    container,
+                    video_start_pts,
+                    video_end_pts,
+                    container.streams.video[0],
+                    {"video": 0},
+                )
+
+                frames = [frame.to_rgb().to_ndarray() for frame in video_frames]
+                frames = torch.as_tensor(np.stack(frames))
+
+                frames_out[k] = frames
         container.close()

-    frames = [frame.to_rgb().to_ndarray() for frame in video_frames]
-    frames = torch.as_tensor(np.stack(frames))
-    return frames, fps, decode_all_video
+    return frames_out, fps, decode_all_video, start_end_delta_time


 def decode(
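With this hunk, pyav_decode returns one tensor per requested clip plus the sampled index table instead of a single tensor. A hedged sketch of how a caller might unpack the new return value; the argument values are illustrative, and the per-clip list convention for sampling_rate and num_frames is inferred from the hunks above:

# Sketch only: assumes `container` is an already opened PyAV container.
frames_out, fps, decode_all_video, start_end_delta_time = pyav_decode(
    container,
    [2, 2],                 # sampling_rate, one entry per clip
    [8, 8],                 # num_frames, one entry per clip
    -1,                     # clip_idx; -1 means random temporal sampling
    num_clips_uniform=10,
    target_fps=30,
)
# frames_out: list of (T, H, W, C) uint8 tensors, one per clip
# start_end_delta_time: None when the whole video had to be decoded,
# otherwise one row per clip holding its sampled start and end indices
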
@@ -509,17 +546,20 @@ def decode(
     )  # clips come temporally ordered from decoder
     try:
         if backend == "pyav":
-            assert (
-                min_delta == -math.inf and max_delta == math.inf
-            ), "delta sampling not supported in pyav"
-            frames_decoded, fps, decode_all_video = pyav_decode(
+            assert min_delta == -math.inf and max_delta == math.inf, \
+                "delta sampling not supported in pyav"
+            frames_decoded, fps, decode_all_video, start_end_delta_time = pyav_decode(
                 container,
                 sampling_rate,
                 num_frames,
                 clip_idx,
                 num_clips_uniform,
                 target_fps,
                 use_offset=use_offset,
+                modalities=("visual",),
+                max_spatial_scale=max_spatial_scale,
+                min_delta=min_delta,
+                max_delta=max_delta,
             )
         elif backend == "torchvision":
             (
@@ -557,7 +597,10 @@ def decode(
         frames_decoded = [frames_decoded]
     num_decoded = len(frames_decoded)
     clip_sizes = [
-        np.maximum(1.0, sampling_rate[i] * num_frames[i] / target_fps * fps)
+        np.maximum(
+            1.0,
+            sampling_rate[i] * num_frames[i] / target_fps * fps
+        )
         for i in range(len(sampling_rate))
     ]

@@ -621,4 +664,4 @@ def decode(
         for i in range(num_decode)
     )

-    return frames_out, start_end_delta_time, time_diff_aug
+    return frames_out, start_end_delta_time, time_diff_aug
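The start_end_delta_time value threaded through decode() is indexed as start_end_delta_time[k, 0] and start_end_delta_time[k, 1] in the pyav_decode hunk, i.e. one row per decoded clip. A purely illustrative sketch of consuming such a table; the third column is assumed, from the variable name alone, to carry the time delta between sampled clips:

import numpy as np

# Hypothetical table for two clips: columns (start_idx, end_idx, delta).
start_end_delta_time = np.array(
    [
        [0.0, 31.0, 0.0],
        [45.0, 76.0, 14.0],
    ]
)

for k in range(start_end_delta_time.shape[0]):
    start_idx, end_idx = start_end_delta_time[k, 0], start_end_delta_time[k, 1]
    print(f"clip {k}: source frames {start_idx:.0f} to {end_idx:.0f}")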