@@ -146,8 +146,84 @@ def decode_and_resize(self, video_file, pts_list, height, width, device):
146
146
return frames
147
147
148
148
149
class OpenCVDecoder(AbstractDecoder):
    """Benchmark decoder that reads video frames through ``cv2.VideoCapture``.

    Frames are returned as ``torch`` tensors converted from OpenCV's BGR
    layout to RGB and transposed to (C, H, W).
    """

    def __init__(self, backend):
        """Import OpenCV lazily and resolve ``backend`` to a capture API id.

        Args:
            backend: Name of the OpenCV capture backend; only ``"FFMPEG"`` is
                mapped. Any other name resolves to ``None``, which is passed
                to ``cv2.VideoCapture`` unchanged.
                NOTE(review): presumably OpenCV then picks its default
                backend -- confirm before relying on it.
        """
        # Imported here so the module can be loaded without OpenCV installed.
        import cv2

        self.cv2 = cv2

        self._available_backends = {"FFMPEG": cv2.CAP_FFMPEG}
        self._backend = self._available_backends.get(backend)

        self._print_each_iteration_time = False

    def decode_frames(self, video_file, pts_list):
        """Decode the frames approximately located at the given timestamps.

        Each timestamp is mapped to a frame index with ``int(pts * fps)``.
        The stream is read sequentially from the start, and only the frames
        whose index was requested are decompressed.

        Args:
            video_file: Path or URL handed to ``cv2.VideoCapture``.
            pts_list: Timestamps in seconds. They may be unsorted and may map
                to the same frame index more than once.

        Returns:
            A list with one RGB (C, H, W) tensor per entry of ``pts_list``,
            in the same order as ``pts_list``. Timestamps that map to the
            same frame index share the same tensor object.

        Raises:
            ValueError: If the stream cannot be opened, or a frame cannot be
                grabbed (including when a requested index lies outside the
                stream) or retrieved.
        """
        cap = self.cv2.VideoCapture(video_file, self._backend)
        try:
            if not cap.isOpened():
                raise ValueError("Could not open video stream")

            fps = cap.get(self.cv2.CAP_PROP_FPS)
            approx_frame_indices = [int(pts * fps) for pts in pts_list]

            # A set gives O(1) membership tests and collapses duplicate
            # indices, so the loop ends once every distinct frame is decoded.
            pending = set(approx_frame_indices)
            decoded = {}
            current_frame = 0
            while pending:
                if not cap.grab():
                    raise ValueError("Could not grab video frame")
                if current_frame in pending:  # only decompress needed
                    ret, frame = cap.retrieve()
                    if not ret:
                        raise ValueError("Could not retrieve video frame")
                    decoded[current_frame] = self.convert_frame_to_rgb_tensor(frame)
                    pending.discard(current_frame)
                current_frame += 1
        finally:
            # Release the capture on error paths as well.
            cap.release()

        # Re-expand to the requested order, duplicates included.
        return [decoded[index] for index in approx_frame_indices]

    def decode_first_n_frames(self, video_file, n):
        """Decode the first ``n`` frames of the video.

        Args:
            video_file: Path or URL handed to ``cv2.VideoCapture``.
            n: Number of leading frames to decode.

        Returns:
            A list of ``n`` RGB (C, H, W) tensors in stream order.

        Raises:
            ValueError: If the stream cannot be opened, or a frame cannot be
                grabbed or retrieved before ``n`` frames were produced.
        """
        cap = self.cv2.VideoCapture(video_file, self._backend)
        try:
            if not cap.isOpened():
                raise ValueError("Could not open video stream")

            frames = []
            for _ in range(n):
                if not cap.grab():
                    raise ValueError("Could not grab video frame")
                ret, frame = cap.retrieve()
                if not ret:
                    # Fail loudly rather than return fewer than n frames.
                    raise ValueError("Could not retrieve video frame")
                frames.append(self.convert_frame_to_rgb_tensor(frame))
        finally:
            cap.release()
        return frames

    def decode_and_resize(self, *args, **kwargs):
        """Always raise: resizing is deliberately unsupported for this decoder.

        Raises:
            ValueError: Unconditionally; the message explains the antialias
                mismatch that makes the comparison unfair.
        """
        raise ValueError(
            "OpenCV doesn't apply antialias while pytorch does by default, this is potentially an unfair comparison"
        )

    def convert_frame_to_rgb_tensor(self, frame):
        """Convert a BGR (H, W, C) OpenCV frame into an RGB (C, H, W) tensor."""
        # OpenCV uses BGR, change to RGB
        frame = self.cv2.cvtColor(frame, self.cv2.COLOR_BGR2RGB)
        # Update to C, H, W
        frame = np.transpose(frame, (2, 0, 1))
        # Convert to tensor
        frame = torch.from_numpy(frame)
        return frame
218
+
219
+
149
220
class TorchCodecCore (AbstractDecoder ):
150
- def __init__ (self , num_threads = None , color_conversion_library = None , device = "cpu" ):
221
+ def __init__ (
222
+ self ,
223
+ num_threads : str | None = None ,
224
+ color_conversion_library = None ,
225
+ device = "cpu" ,
226
+ ):
151
227
self ._num_threads = int (num_threads ) if num_threads else None
152
228
self ._color_conversion_library = color_conversion_library
153
229
self ._device = device
@@ -185,7 +261,12 @@ def decode_first_n_frames(self, video_file, n):
185
261
186
262
187
263
class TorchCodecCoreNonBatch (AbstractDecoder ):
188
- def __init__ (self , num_threads = None , color_conversion_library = None , device = "cpu" ):
264
+ def __init__ (
265
+ self ,
266
+ num_threads : str | None = None ,
267
+ color_conversion_library = None ,
268
+ device = "cpu" ,
269
+ ):
189
270
self ._num_threads = num_threads
190
271
self ._color_conversion_library = color_conversion_library
191
272
self ._device = device
@@ -254,7 +335,12 @@ def decode_and_resize(self, video_file, pts_list, height, width, device):
254
335
255
336
256
337
class TorchCodecCoreBatch (AbstractDecoder ):
257
- def __init__ (self , num_threads = None , color_conversion_library = None , device = "cpu" ):
338
+ def __init__ (
339
+ self ,
340
+ num_threads : str | None = None ,
341
+ color_conversion_library = None ,
342
+ device = "cpu" ,
343
+ ):
258
344
self ._print_each_iteration_time = False
259
345
self ._num_threads = int (num_threads ) if num_threads else None
260
346
self ._color_conversion_library = color_conversion_library
@@ -293,10 +379,17 @@ def decode_first_n_frames(self, video_file, n):
293
379
294
380
295
381
class TorchCodecPublic (AbstractDecoder ):
296
- def __init__ (self , num_ffmpeg_threads = None , device = "cpu" , seek_mode = "exact" ):
382
+ def __init__ (
383
+ self ,
384
+ num_ffmpeg_threads : str | None = None ,
385
+ device = "cpu" ,
386
+ seek_mode = "exact" ,
387
+ stream_index : str | None = None ,
388
+ ):
297
389
self ._num_ffmpeg_threads = num_ffmpeg_threads
298
390
self ._device = device
299
391
self ._seek_mode = seek_mode
392
+ self ._stream_index = int (stream_index ) if stream_index else None
300
393
301
394
from torchvision .transforms import v2 as transforms_v2
302
395
@@ -311,6 +404,7 @@ def decode_frames(self, video_file, pts_list):
311
404
num_ffmpeg_threads = num_ffmpeg_threads ,
312
405
device = self ._device ,
313
406
seek_mode = self ._seek_mode ,
407
+ stream_index = self ._stream_index ,
314
408
)
315
409
return decoder .get_frames_played_at (pts_list )
316
410
@@ -323,6 +417,7 @@ def decode_first_n_frames(self, video_file, n):
323
417
num_ffmpeg_threads = num_ffmpeg_threads ,
324
418
device = self ._device ,
325
419
seek_mode = self ._seek_mode ,
420
+ stream_index = self ._stream_index ,
326
421
)
327
422
frames = []
328
423
count = 0
@@ -342,14 +437,20 @@ def decode_and_resize(self, video_file, pts_list, height, width, device):
342
437
num_ffmpeg_threads = num_ffmpeg_threads ,
343
438
device = self ._device ,
344
439
seek_mode = self ._seek_mode ,
440
+ stream_index = self ._stream_index ,
345
441
)
346
442
frames = decoder .get_frames_played_at (pts_list )
347
443
frames = self .transforms_v2 .functional .resize (frames .data , (height , width ))
348
444
return frames
349
445
350
446
351
447
class TorchCodecPublicNonBatch (AbstractDecoder ):
352
- def __init__ (self , num_ffmpeg_threads = None , device = "cpu" , seek_mode = "approximate" ):
448
+ def __init__ (
449
+ self ,
450
+ num_ffmpeg_threads : str | None = None ,
451
+ device = "cpu" ,
452
+ seek_mode = "approximate" ,
453
+ ):
353
454
self ._num_ffmpeg_threads = num_ffmpeg_threads
354
455
self ._device = device
355
456
self ._seek_mode = seek_mode
@@ -452,19 +553,22 @@ def decode_first_n_frames(self, video_file, n):
452
553
453
554
454
555
class TorchAudioDecoder (AbstractDecoder ):
455
- def __init__ (self ):
556
def __init__(self, stream_index: str | None = None):
    """Import the optional backends lazily and record the stream to decode.

    Args:
        stream_index: Index of the stream to decode, given as a string
            (an int is also accepted). ``None`` or an empty string leaves
            the index unset, stored as ``None``.
    """
    import torchaudio  # noqa: F401

    self.torchaudio = torchaudio

    from torchvision.transforms import v2 as transforms_v2

    self.transforms_v2 = transforms_v2
    # Compare against the "unset" values explicitly: a plain truthiness
    # test would turn a legitimate integer index 0 into None.
    self._stream_index = (
        None if stream_index in (None, "") else int(stream_index)
    )
463
565
464
566
def decode_frames (self , video_file , pts_list ):
465
567
stream_reader = self .torchaudio .io .StreamReader (src = video_file )
466
568
stream_reader .add_basic_video_stream (
467
- frames_per_chunk = 1 , decoder_option = {"threads" : "0" }
569
+ frames_per_chunk = 1 ,
570
+ decoder_option = {"threads" : "0" },
571
+ stream_index = self ._stream_index ,
468
572
)
469
573
frames = []
470
574
for pts in pts_list :
@@ -477,7 +581,9 @@ def decode_frames(self, video_file, pts_list):
477
581
def decode_first_n_frames (self , video_file , n ):
478
582
stream_reader = self .torchaudio .io .StreamReader (src = video_file )
479
583
stream_reader .add_basic_video_stream (
480
- frames_per_chunk = 1 , decoder_option = {"threads" : "0" }
584
+ frames_per_chunk = 1 ,
585
+ decoder_option = {"threads" : "0" },
586
+ stream_index = self ._stream_index ,
481
587
)
482
588
frames = []
483
589
frame_cnt = 0
@@ -492,7 +598,9 @@ def decode_first_n_frames(self, video_file, n):
492
598
def decode_and_resize (self , video_file , pts_list , height , width , device ):
493
599
stream_reader = self .torchaudio .io .StreamReader (src = video_file )
494
600
stream_reader .add_basic_video_stream (
495
- frames_per_chunk = 1 , decoder_option = {"threads" : "1" }
601
+ frames_per_chunk = 1 ,
602
+ decoder_option = {"threads" : "1" },
603
+ stream_index = self ._stream_index ,
496
604
)
497
605
frames = []
498
606
for pts in pts_list :
@@ -745,7 +853,8 @@ def run_benchmarks(
745
853
# are using different random pts values across videos.
746
854
random_pts_list = (torch .rand (num_samples ) * duration ).tolist ()
747
855
748
- for decoder_name , decoder in decoder_dict .items ():
856
+ # The decoder items are sorted to perform and display the benchmarks in a consistent order.
857
+ for decoder_name , decoder in sorted (decoder_dict .items (), key = lambda x : x [0 ]):
749
858
print (f"video={ video_file_path } , decoder={ decoder_name } " )
750
859
751
860
if dataloader_parameters :
0 commit comments