Skip to content

add opencv benchmark #711

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 14 commits into from
Jun 12, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions benchmarks/decoders/benchmark_decoders.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
AbstractDecoder,
DecordAccurate,
DecordAccurateBatch,
OpenCVDecoder,
plot_data,
run_benchmarks,
TorchAudioDecoder,
Expand Down Expand Up @@ -61,6 +62,9 @@ class DecoderKind:
{"backend": "video_reader"},
),
"torchaudio": DecoderKind("TorchAudio", TorchAudioDecoder),
"opencv": DecoderKind(
"OpenCV[backend=FFMPEG]", OpenCVDecoder, {"backend": "FFMPEG"}
),
}


Expand Down
129 changes: 119 additions & 10 deletions benchmarks/decoders/benchmark_decoders_library.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,8 +146,84 @@ def decode_and_resize(self, video_file, pts_list, height, width, device):
return frames


class OpenCVDecoder(AbstractDecoder):
    """Decode video frames with OpenCV's ``VideoCapture`` API.

    Seeking is approximate: a pts value (in seconds) is mapped to a frame
    index by assuming a constant frame rate, so the returned frames may
    differ slightly from those of exact-seeking decoders.
    """

    def __init__(self, backend):
        # Imported lazily so the benchmark suite can run without OpenCV
        # installed when this decoder is not selected.
        import cv2

        self.cv2 = cv2

        # Map of supported backend names to OpenCV capture-backend ids.
        self._available_backends = {"FFMPEG": cv2.CAP_FFMPEG}
        if backend not in self._available_backends:
            # Fail loudly here instead of silently passing None to
            # VideoCapture later, which produces a confusing error.
            raise ValueError(
                f"Unsupported OpenCV backend {backend!r}; "
                f"supported backends: {sorted(self._available_backends)}"
            )
        self._backend = self._available_backends[backend]

        self._print_each_iteration_time = False

    def decode_frames(self, video_file, pts_list):
        """Return RGB CHW tensors for the frames played at ``pts_list``.

        Frames are located by grabbing (cheap, no decompression) until an
        approximate frame index is reached, then retrieved (decompressed).
        """
        cap = self.cv2.VideoCapture(video_file, self._backend)
        if not cap.isOpened():
            raise ValueError("Could not open video stream")
        try:
            fps = cap.get(self.cv2.CAP_PROP_FPS)
            # Approximate each pts (seconds) with a frame index, assuming
            # a constant frame rate.
            approx_frame_indices = [int(pts * fps) for pts in pts_list]

            # Count how many times each index is requested: duplicate pts
            # values can round to the same index, and each request must
            # yield a frame. (A plain list membership test both scanned
            # O(n) per grabbed frame and under-counted duplicates, making
            # the loop run off the end of the video.)
            remaining = {}
            for index in approx_frame_indices:
                remaining[index] = remaining.get(index, 0) + 1

            frames = []
            current_frame = 0
            while len(frames) < len(approx_frame_indices):
                if not cap.grab():
                    raise ValueError("Could not grab video frame")
                count = remaining.pop(current_frame, 0)
                if count:
                    # Only decompress the frames we actually need.
                    ret, frame = cap.retrieve()
                    if ret:
                        tensor = self.convert_frame_to_rgb_tensor(frame)
                        frames.extend([tensor] * count)
                current_frame += 1
        finally:
            # Release the capture even if grabbing/retrieving failed.
            cap.release()
        assert len(frames) == len(approx_frame_indices)
        return frames

    def decode_first_n_frames(self, video_file, n):
        """Return RGB CHW tensors for the first ``n`` frames of the video."""
        cap = self.cv2.VideoCapture(video_file, self._backend)
        if not cap.isOpened():
            raise ValueError("Could not open video stream")
        try:
            frames = []
            for _ in range(n):
                if not cap.grab():
                    raise ValueError("Could not grab video frame")
                ret, frame = cap.retrieve()
                if ret:
                    frames.append(self.convert_frame_to_rgb_tensor(frame))
        finally:
            # Release the capture even if grabbing/retrieving failed.
            cap.release()
        assert len(frames) == n
        return frames

    def decode_and_resize(self, *args, **kwargs):
        # Intentionally unsupported; see the error message for the reason.
        raise ValueError(
            "OpenCV doesn't apply antialias while pytorch does by default, this is potentially an unfair comparison"
        )

    def convert_frame_to_rgb_tensor(self, frame):
        """Convert one BGR HWC numpy frame into an RGB CHW torch tensor."""
        # OpenCV uses BGR, change to RGB
        frame = self.cv2.cvtColor(frame, self.cv2.COLOR_BGR2RGB)
        # Update to C, H, W
        frame = np.transpose(frame, (2, 0, 1))
        # Convert to tensor
        frame = torch.from_numpy(frame)
        return frame


class TorchCodecCore(AbstractDecoder):
def __init__(self, num_threads=None, color_conversion_library=None, device="cpu"):
def __init__(
self,
num_threads: str | None = None,
color_conversion_library=None,
device="cpu",
):
self._num_threads = int(num_threads) if num_threads else None
self._color_conversion_library = color_conversion_library
self._device = device
Expand Down Expand Up @@ -185,7 +261,12 @@ def decode_first_n_frames(self, video_file, n):


class TorchCodecCoreNonBatch(AbstractDecoder):
def __init__(self, num_threads=None, color_conversion_library=None, device="cpu"):
def __init__(
self,
num_threads: str | None = None,
color_conversion_library=None,
device="cpu",
):
self._num_threads = num_threads
self._color_conversion_library = color_conversion_library
self._device = device
Expand Down Expand Up @@ -254,7 +335,12 @@ def decode_and_resize(self, video_file, pts_list, height, width, device):


class TorchCodecCoreBatch(AbstractDecoder):
def __init__(self, num_threads=None, color_conversion_library=None, device="cpu"):
def __init__(
self,
num_threads: str | None = None,
color_conversion_library=None,
device="cpu",
):
self._print_each_iteration_time = False
self._num_threads = int(num_threads) if num_threads else None
self._color_conversion_library = color_conversion_library
Expand Down Expand Up @@ -293,10 +379,17 @@ def decode_first_n_frames(self, video_file, n):


class TorchCodecPublic(AbstractDecoder):
def __init__(self, num_ffmpeg_threads=None, device="cpu", seek_mode="exact"):
def __init__(
self,
num_ffmpeg_threads: str | None = None,
device="cpu",
seek_mode="exact",
stream_index: str | None = None,
):
self._num_ffmpeg_threads = num_ffmpeg_threads
self._device = device
self._seek_mode = seek_mode
self._stream_index = int(stream_index) if stream_index else None
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was a bit surprised to see this written like this, because if stream_index is 0 then it becomes None, which isn't what we want. Then I realized stream_index is a str (!), not an int, so the logic is sound.

I rarely advocate for type annotations but in this case, I think it could help to annotate the stream_index parameter of __init__ as Optional[str]. We might as well do the same for num_ffmpeg_threads. (same below)


from torchvision.transforms import v2 as transforms_v2

Expand All @@ -311,6 +404,7 @@ def decode_frames(self, video_file, pts_list):
num_ffmpeg_threads=num_ffmpeg_threads,
device=self._device,
seek_mode=self._seek_mode,
stream_index=self._stream_index,
)
return decoder.get_frames_played_at(pts_list)

Expand All @@ -323,6 +417,7 @@ def decode_first_n_frames(self, video_file, n):
num_ffmpeg_threads=num_ffmpeg_threads,
device=self._device,
seek_mode=self._seek_mode,
stream_index=self._stream_index,
)
frames = []
count = 0
Expand All @@ -342,14 +437,20 @@ def decode_and_resize(self, video_file, pts_list, height, width, device):
num_ffmpeg_threads=num_ffmpeg_threads,
device=self._device,
seek_mode=self._seek_mode,
stream_index=self._stream_index,
)
frames = decoder.get_frames_played_at(pts_list)
frames = self.transforms_v2.functional.resize(frames.data, (height, width))
return frames


class TorchCodecPublicNonBatch(AbstractDecoder):
def __init__(self, num_ffmpeg_threads=None, device="cpu", seek_mode="approximate"):
def __init__(
self,
num_ffmpeg_threads: str | None = None,
device="cpu",
seek_mode="approximate",
):
self._num_ffmpeg_threads = num_ffmpeg_threads
self._device = device
self._seek_mode = seek_mode
Expand Down Expand Up @@ -452,19 +553,22 @@ def decode_first_n_frames(self, video_file, n):


class TorchAudioDecoder(AbstractDecoder):
def __init__(self):
def __init__(self, stream_index: str | None = None):
import torchaudio # noqa: F401

self.torchaudio = torchaudio

from torchvision.transforms import v2 as transforms_v2

self.transforms_v2 = transforms_v2
self._stream_index = int(stream_index) if stream_index else None

def decode_frames(self, video_file, pts_list):
stream_reader = self.torchaudio.io.StreamReader(src=video_file)
stream_reader.add_basic_video_stream(
frames_per_chunk=1, decoder_option={"threads": "0"}
frames_per_chunk=1,
decoder_option={"threads": "0"},
stream_index=self._stream_index,
)
frames = []
for pts in pts_list:
Expand All @@ -477,7 +581,9 @@ def decode_frames(self, video_file, pts_list):
def decode_first_n_frames(self, video_file, n):
stream_reader = self.torchaudio.io.StreamReader(src=video_file)
stream_reader.add_basic_video_stream(
frames_per_chunk=1, decoder_option={"threads": "0"}
frames_per_chunk=1,
decoder_option={"threads": "0"},
stream_index=self._stream_index,
)
frames = []
frame_cnt = 0
Expand All @@ -492,7 +598,9 @@ def decode_first_n_frames(self, video_file, n):
def decode_and_resize(self, video_file, pts_list, height, width, device):
stream_reader = self.torchaudio.io.StreamReader(src=video_file)
stream_reader.add_basic_video_stream(
frames_per_chunk=1, decoder_option={"threads": "1"}
frames_per_chunk=1,
decoder_option={"threads": "1"},
stream_index=self._stream_index,
)
frames = []
for pts in pts_list:
Expand Down Expand Up @@ -745,7 +853,8 @@ def run_benchmarks(
# are using different random pts values across videos.
random_pts_list = (torch.rand(num_samples) * duration).tolist()

for decoder_name, decoder in decoder_dict.items():
# The decoder items are sorted to perform and display the benchmarks in a consistent order.
for decoder_name, decoder in sorted(decoder_dict.items(), key=lambda x: x[0]):
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This change was added to make it easier to compare benchmarks, I'm open to alternative approaches to sorting or removing it entirely.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's keep this, as it ensures that we always get the results printed in the same order. But a comment stating that is good, since it may not be obvious why. (We're actually doing the experiments in the same order every time, which means we populate the results in the same way, which makes them displayed in the same way.)

print(f"video={video_file_path}, decoder={decoder_name}")

if dataloader_parameters:
Expand Down
Loading