@@ -382,6 +382,10 @@ def pyav_decode(
382
382
num_clips_uniform = 10 ,
383
383
target_fps = 30 ,
384
384
use_offset = False ,
385
+ modalities = ("visual" ,),
386
+ max_spatial_scale = 0 ,
387
+ min_delta = - math .inf ,
388
+ max_delta = math .inf ,
385
389
):
386
390
"""
387
391
Convert the video from its original fps to the target_fps. If the video
@@ -419,38 +423,69 @@ def pyav_decode(
419
423
# If failed to fetch the decoding information, decode the entire video.
420
424
decode_all_video = True
421
425
video_start_pts , video_end_pts = 0 , math .inf
426
+ start_end_delta_time = None
427
+
428
+ frames = None
429
+ if container .streams .video :
430
+ video_frames , max_pts = pyav_decode_stream (
431
+ container ,
432
+ video_start_pts ,
433
+ video_end_pts ,
434
+ container .streams .video [0 ],
435
+ {"video" : 0 },
436
+ )
437
+ container .close ()
438
+
439
+ frames = [frame .to_rgb ().to_ndarray () for frame in video_frames ]
440
+ frames = torch .as_tensor (np .stack (frames ))
441
+ frames_out = [frames ]
442
+
422
443
else :
423
444
# Perform selective decoding.
424
445
decode_all_video = False
425
- clip_size = np .maximum (
426
- 1.0 , np .ceil (sampling_rate * (num_frames - 1 ) / target_fps * fps )
427
- )
428
- start_idx , end_idx , fraction = get_start_end_idx (
446
+ clip_sizes = [
447
+ np .maximum (
448
+ 1.0 ,
449
+ np .ceil (
450
+ sampling_rate [i ] * (num_frames [i ] - 1 ) / target_fps * fps
451
+ ),
452
+ )
453
+ for i in range (len (sampling_rate ))
454
+ ]
455
+ start_end_delta_time = get_multiple_start_end_idx (
429
456
frames_length ,
430
- clip_size ,
457
+ clip_sizes ,
431
458
clip_idx ,
432
459
num_clips_uniform ,
433
- use_offset = use_offset ,
434
- )
435
- timebase = duration / frames_length
436
- video_start_pts = int (start_idx * timebase )
437
- video_end_pts = int (end_idx * timebase )
438
-
439
- frames = None
440
- # If video stream was found, fetch video frames from the video.
441
- if container .streams .video :
442
- video_frames , max_pts = pyav_decode_stream (
443
- container ,
444
- video_start_pts ,
445
- video_end_pts ,
446
- container .streams .video [0 ],
447
- {"video" : 0 },
460
+ min_delta = min_delta ,
461
+ max_delta = max_delta ,
448
462
)
463
+ frames_out = [None ] * len (num_frames )
464
+ for k in range (len (num_frames )):
465
+ start_idx = start_end_delta_time [k , 0 ]
466
+ end_idx = start_end_delta_time [k , 1 ]
467
+ timebase = duration / frames_length
468
+ video_start_pts = int (start_idx * timebase )
469
+ video_end_pts = int (end_idx * timebase )
470
+
471
+ frames = None
472
+ # If video stream was found, fetch video frames from the video.
473
+ if container .streams .video :
474
+ video_frames , max_pts = pyav_decode_stream (
475
+ container ,
476
+ video_start_pts ,
477
+ video_end_pts ,
478
+ container .streams .video [0 ],
479
+ {"video" : 0 },
480
+ )
481
+
482
+ frames = [frame .to_rgb ().to_ndarray () for frame in video_frames ]
483
+ frames = torch .as_tensor (np .stack (frames ))
484
+
485
+ frames_out [k ] = frames
449
486
container .close ()
450
487
451
- frames = [frame .to_rgb ().to_ndarray () for frame in video_frames ]
452
- frames = torch .as_tensor (np .stack (frames ))
453
- return frames , fps , decode_all_video
488
+ return frames_out , fps , decode_all_video , start_end_delta_time
454
489
455
490
456
491
def decode (
@@ -510,17 +545,20 @@ def decode(
510
545
) # clips come temporally ordered from decoder
511
546
try :
512
547
if backend == "pyav" :
513
- assert (
514
- min_delta == - math .inf and max_delta == math .inf
515
- ), "delta sampling not supported in pyav"
516
- frames_decoded , fps , decode_all_video = pyav_decode (
548
+ assert min_delta == - math .inf and max_delta == math .inf , \
549
+ "delta sampling not supported in pyav"
550
+ frames_decoded , fps , decode_all_video , start_end_delta_time = pyav_decode (
517
551
container ,
518
552
sampling_rate ,
519
553
num_frames ,
520
554
clip_idx ,
521
555
num_clips_uniform ,
522
556
target_fps ,
523
557
use_offset = use_offset ,
558
+ modalities = ("visual" ,),
559
+ max_spatial_scale = max_spatial_scale ,
560
+ min_delta = min_delta ,
561
+ max_delta = max_delta ,
524
562
)
525
563
elif backend == "torchvision" :
526
564
(
@@ -558,12 +596,12 @@ def decode(
558
596
frames_decoded = [frames_decoded ]
559
597
num_decoded = len (frames_decoded )
560
598
clip_sizes = [
561
- np .maximum (
562
- 1.0 ,
563
- sampling_rate [i ] * num_frames [i ] / target_fps * fps
564
- )
565
- for i in range (len (sampling_rate ))
566
- ]
599
+ np .maximum (
600
+ 1.0 ,
601
+ sampling_rate [i ] * num_frames [i ] / target_fps * fps
602
+ )
603
+ for i in range (len (sampling_rate ))
604
+ ]
567
605
568
606
if decode_all_video : # full video was decoded (not trimmed yet)
569
607
assert num_decoded == 1 and start_end_delta_time is None
0 commit comments