From 169b4d6c7daaa0dbd3627a5cdddb6dfbc8dc1666 Mon Sep 17 00:00:00 2001 From: Scott Schneider Date: Thu, 27 Mar 2025 17:28:42 -0700 Subject: [PATCH 1/2] Remove header for custom ops --- .../decoders/_core/VideoDecoderOps.cpp | 169 ++++++++++++------ .../decoders/_core/VideoDecoderOps.h | 162 ----------------- 2 files changed, 115 insertions(+), 216 deletions(-) delete mode 100644 src/torchcodec/decoders/_core/VideoDecoderOps.h diff --git a/src/torchcodec/decoders/_core/VideoDecoderOps.cpp b/src/torchcodec/decoders/_core/VideoDecoderOps.cpp index bd142d70..c0689937 100644 --- a/src/torchcodec/decoders/_core/VideoDecoderOps.cpp +++ b/src/torchcodec/decoders/_core/VideoDecoderOps.cpp @@ -4,7 +4,6 @@ // This source code is licensed under the BSD-style license found in the // LICENSE file in the root directory of this source tree. -#include "src/torchcodec/decoders/_core/VideoDecoderOps.h" #include #include #include @@ -85,6 +84,14 @@ VideoDecoder* unwrapTensorToGetDecoder(at::Tensor& tensor) { return decoder; } +// The elements of this tuple are all tensors that represent a single frame: +// 1. The frame data, which is a multidimensional tensor. +// 2. A single float value for the pts in seconds. +// 3. A single float value for the duration in seconds. +// The reason we use Tensors for the second and third values is so we can run +// under torch.compile(). +using OpsFrameOutput = std::tuple; + OpsFrameOutput makeOpsFrameOutput(VideoDecoder::FrameOutput& frame) { return std::make_tuple( frame.data, @@ -92,26 +99,67 @@ OpsFrameOutput makeOpsFrameOutput(VideoDecoder::FrameOutput& frame) { torch::tensor(frame.durationSeconds, torch::dtype(torch::kFloat64))); } +// All elements of this tuple are tensors of the same leading dimension. The +// tuple represents the frames for N total frames, where N is the dimension of +// each stacked tensor. The elments are: +// 1. Stacked tensor of data for all N frames. Each frame is also a +// multidimensional tensor. +// 2. Tensor of N pts values in seconds, where each pts is a single +// float. +// 3. Tensor of N durationis in seconds, where each duration is a +// single float. +using OpsFrameBatchOutput = std::tuple; + OpsFrameBatchOutput makeOpsFrameBatchOutput( VideoDecoder::FrameBatchOutput& batch) { return std::make_tuple(batch.data, batch.ptsSeconds, batch.durationSeconds); } +// The elements of this tuple are all tensors that represent the concatenation +// of multiple audio frames: +// 1. The frames data (concatenated) +// 2. A single float value for the pts of the first frame, in seconds. +using OpsAudioFramesOutput = std::tuple; + OpsAudioFramesOutput makeOpsAudioFramesOutput( VideoDecoder::AudioFramesOutput& audioFrames) { return std::make_tuple( audioFrames.data, torch::tensor(audioFrames.ptsSeconds, torch::dtype(torch::kFloat64))); } + +std::string quoteValue(const std::string& value) { + return "\"" + value + "\""; +} + +std::string mapToJson(const std::map& metadataMap) { + std::stringstream ss; + ss << "{\n"; + auto it = metadataMap.begin(); + while (it != metadataMap.end()) { + ss << "\"" << it->first << "\": " << it->second; + ++it; + if (it != metadataMap.end()) { + ss << ",\n"; + } else { + ss << "\n"; + } + } + ss << "}"; + + return ss.str(); +} + } // namespace // ============================== // Implementations for the operators // ============================== +// Create a VideoDecoder from file and wrap the pointer in a tensor. at::Tensor create_from_file( std::string_view filename, - std::optional seek_mode) { + std::optional seek_mode = std::nullopt) { std::string filenameStr(filename); VideoDecoder::SeekMode realSeek = VideoDecoder::SeekMode::exact; @@ -125,9 +173,11 @@ at::Tensor create_from_file( return wrapDecoderPointerToTensor(std::move(uniqueDecoder)); } +// Create a VideoDecoder from the actual bytes of a video and wrap the pointer +// in a tensor. The VideoDecoder will decode the provided bytes. at::Tensor create_from_tensor( at::Tensor video_tensor, - std::optional seek_mode) { + std::optional seek_mode = std::nullopt) { TORCH_CHECK(video_tensor.is_contiguous(), "video_tensor must be contiguous"); TORCH_CHECK( video_tensor.scalar_type() == torch::kUInt8, @@ -153,33 +203,15 @@ at::Tensor _convert_to_tensor(int64_t decoder_ptr) { return wrapDecoderPointerToTensor(std::move(uniqueDecoder)); } -void add_video_stream( - at::Tensor& decoder, - std::optional width, - std::optional height, - std::optional num_threads, - std::optional dimension_order, - std::optional stream_index, - std::optional device) { - _add_video_stream( - decoder, - width, - height, - num_threads, - dimension_order, - stream_index, - device); -} - void _add_video_stream( at::Tensor& decoder, - std::optional width, - std::optional height, - std::optional num_threads, - std::optional dimension_order, - std::optional stream_index, - std::optional device, - std::optional color_conversion_library) { + std::optional width = std::nullopt, + std::optional height = std::nullopt, + std::optional num_threads = std::nullopt, + std::optional dimension_order = std::nullopt, + std::optional stream_index = std::nullopt, + std::optional device = std::nullopt, + std::optional color_conversion_library = std::nullopt) { VideoDecoder::VideoStreamOptions videoStreamOptions; videoStreamOptions.width = width; videoStreamOptions.height = height; @@ -221,10 +253,29 @@ void _add_video_stream( videoDecoder->addVideoStream(stream_index.value_or(-1), videoStreamOptions); } +// Add a new video stream at `stream_index` using the provided options. +void add_video_stream( + at::Tensor& decoder, + std::optional width = std::nullopt, + std::optional height = std::nullopt, + std::optional num_threads = std::nullopt, + std::optional dimension_order = std::nullopt, + std::optional stream_index = std::nullopt, + std::optional device = std::nullopt) { + _add_video_stream( + decoder, + width, + height, + num_threads, + dimension_order, + stream_index, + device); +} + void add_audio_stream( at::Tensor& decoder, - std::optional stream_index, - std::optional sample_rate) { + std::optional stream_index = std::nullopt, + std::optional sample_rate = std::nullopt) { VideoDecoder::AudioStreamOptions audioStreamOptions; audioStreamOptions.sampleRate = sample_rate; @@ -232,11 +283,14 @@ void add_audio_stream( videoDecoder->addAudioStream(stream_index.value_or(-1), audioStreamOptions); } +// Seek to a particular presentation timestamp in the video in seconds. void seek_to_pts(at::Tensor& decoder, double seconds) { auto videoDecoder = static_cast(decoder.mutable_data_ptr()); videoDecoder->setCursorPtsInSeconds(seconds); } +// Get the next frame from the video as a tuple that has the frame data, pts and +// duration as tensors. OpsFrameOutput get_next_frame(at::Tensor& decoder) { auto videoDecoder = unwrapTensorToGetDecoder(decoder); VideoDecoder::FrameOutput result; @@ -248,6 +302,9 @@ OpsFrameOutput get_next_frame(at::Tensor& decoder) { return makeOpsFrameOutput(result); } +// Return the frame that is visible at a given timestamp in seconds. Each frame +// in FFMPEG has a presentation timestamp and a duration. The frame visible at a +// given timestamp T has T >= PTS and T < PTS + Duration. OpsFrameOutput get_frame_at_pts(at::Tensor& decoder, double seconds) { auto videoDecoder = unwrapTensorToGetDecoder(decoder); VideoDecoder::FrameOutput result; @@ -259,12 +316,14 @@ OpsFrameOutput get_frame_at_pts(at::Tensor& decoder, double seconds) { return makeOpsFrameOutput(result); } +// Return the frame that is visible at a given index in the video. OpsFrameOutput get_frame_at_index(at::Tensor& decoder, int64_t frame_index) { auto videoDecoder = unwrapTensorToGetDecoder(decoder); auto result = videoDecoder->getFrameAtIndex(frame_index); return makeOpsFrameOutput(result); } +// Return the frames at given indices for a given stream OpsFrameBatchOutput get_frames_at_indices( at::Tensor& decoder, at::IntArrayRef frame_indices) { @@ -275,16 +334,19 @@ OpsFrameBatchOutput get_frames_at_indices( return makeOpsFrameBatchOutput(result); } +// Return the frames inside a range as a single stacked Tensor. The range is +// defined as [start, stop). OpsFrameBatchOutput get_frames_in_range( at::Tensor& decoder, int64_t start, int64_t stop, - std::optional step) { + std::optional step = std::nullopt) { auto videoDecoder = unwrapTensorToGetDecoder(decoder); auto result = videoDecoder->getFramesInRange(start, stop, step.value_or(1)); return makeOpsFrameBatchOutput(result); } +// Return the frames at given ptss for a given stream OpsFrameBatchOutput get_frames_by_pts( at::Tensor& decoder, at::ArrayRef timestamps) { @@ -294,6 +356,9 @@ OpsFrameBatchOutput get_frames_by_pts( return makeOpsFrameBatchOutput(result); } +// Return the frames inside the range as a single stacked Tensor. The range is +// defined as [start_seconds, stop_seconds). The frames are stacked in pts +// order. OpsFrameBatchOutput get_frames_by_pts_in_range( at::Tensor& decoder, double start_seconds, @@ -307,35 +372,22 @@ OpsFrameBatchOutput get_frames_by_pts_in_range( OpsAudioFramesOutput get_frames_by_pts_in_range_audio( at::Tensor& decoder, double start_seconds, - std::optional stop_seconds) { + std::optional stop_seconds = std::nullopt) { auto videoDecoder = unwrapTensorToGetDecoder(decoder); auto result = videoDecoder->getFramesPlayedInRangeAudio(start_seconds, stop_seconds); return makeOpsAudioFramesOutput(result); } -std::string quoteValue(const std::string& value) { - return "\"" + value + "\""; -} - -std::string mapToJson(const std::map& metadataMap) { - std::stringstream ss; - ss << "{\n"; - auto it = metadataMap.begin(); - while (it != metadataMap.end()) { - ss << "\"" << it->first << "\": " << it->second; - ++it; - if (it != metadataMap.end()) { - ss << ",\n"; - } else { - ss << "\n"; - } - } - ss << "}"; - - return ss.str(); -} - +// For testing only. We need to implement this operation as a core library +// function because what we're testing is round-tripping pts values as +// double-precision floating point numbers from C++ to Python and back to C++. +// We want to make sure that the value is preserved exactly, bit-for-bit, during +// this process. +// +// Returns true if for the given decoder, the pts +// value when converted to seconds as a double is exactly pts_seconds_to_test. +// Returns false otherwise. bool _test_frame_pts_equality( at::Tensor& decoder, int64_t frame_index, @@ -350,6 +402,7 @@ torch::Tensor _get_key_frame_indices(at::Tensor& decoder) { return videoDecoder->getKeyFrameIndices(); } +// Get the metadata from the video as a string. std::string get_json_metadata(at::Tensor& decoder) { auto videoDecoder = unwrapTensorToGetDecoder(decoder); @@ -418,6 +471,7 @@ std::string get_json_metadata(at::Tensor& decoder) { return mapToJson(metadataMap); } +// Get the container metadata as a string. std::string get_container_json_metadata(at::Tensor& decoder) { auto videoDecoder = unwrapTensorToGetDecoder(decoder); @@ -448,6 +502,7 @@ std::string get_container_json_metadata(at::Tensor& decoder) { return mapToJson(map); } +// Get the stream metadata as a string. std::string get_stream_json_metadata( at::Tensor& decoder, int64_t stream_index) { @@ -519,6 +574,8 @@ std::string get_stream_json_metadata( return mapToJson(map); } +// Returns version information about the various FFMPEG libraries that are +// loaded in the program's address space. std::string _get_json_ffmpeg_library_versions() { std::stringstream ss; ss << "{\n"; @@ -545,6 +602,10 @@ std::string _get_json_ffmpeg_library_versions() { return ss.str(); } +// Scans video packets to get more accurate metadata like frame count, exact +// keyframe positions, etc. Exact keyframe positions are useful for efficient +// accurate seeking. Note that this function reads the entire video but it does +// not decode frames. Reading a video file is much cheaper than decoding it. void scan_all_streams_to_update_metadata(at::Tensor& decoder) { auto videoDecoder = unwrapTensorToGetDecoder(decoder); videoDecoder->scanFileAndUpdateMetadataAndIndex(); diff --git a/src/torchcodec/decoders/_core/VideoDecoderOps.h b/src/torchcodec/decoders/_core/VideoDecoderOps.h deleted file mode 100644 index bc7f2036..00000000 --- a/src/torchcodec/decoders/_core/VideoDecoderOps.h +++ /dev/null @@ -1,162 +0,0 @@ -// Copyright (c) Meta Platforms, Inc. and affiliates. -// All rights reserved. -// -// This source code is licensed under the BSD-style license found in the -// LICENSE file in the root directory of this source tree. - -#pragma once - -#include -#include - -namespace facebook::torchcodec { - -// The following functions are useful for calling the Pytorch C++ ops from C++ -// code. For example, the decoder can be created like so: -// auto createDecoderOp = -// torch::Dispatcher::singleton() -// .findSchemaOrThrow("torchcodec_ns::create_from_file", "") -// .typed(); -// auto decoderTensor = createDecoderOp.call(videoPath); - -// Create a VideoDecoder from file and wrap the pointer in a tensor. -at::Tensor create_from_file( - std::string_view filename, - std::optional seek_mode = std::nullopt); - -at::Tensor create_from_tensor( - at::Tensor video_tensor, - std::optional seek_mode = std::nullopt); - -// Add a new video stream at `stream_index` using the provided options. -void add_video_stream( - at::Tensor& decoder, - std::optional width = std::nullopt, - std::optional height = std::nullopt, - std::optional num_threads = std::nullopt, - std::optional dimension_order = std::nullopt, - std::optional stream_index = std::nullopt, - std::optional device = std::nullopt); - -void _add_video_stream( - at::Tensor& decoder, - std::optional width = std::nullopt, - std::optional height = std::nullopt, - std::optional num_threads = std::nullopt, - std::optional dimension_order = std::nullopt, - std::optional stream_index = std::nullopt, - std::optional device = std::nullopt, - std::optional color_conversion_library = std::nullopt); - -void add_audio_stream( - at::Tensor& decoder, - std::optional stream_index = std::nullopt, - std::optional sample_rate = std::nullopt); - -// Seek to a particular presentation timestamp in the video in seconds. -void seek_to_pts(at::Tensor& decoder, double seconds); - -// The elements of this tuple are all tensors that represent a single frame: -// 1. The frame data, which is a multidimensional tensor. -// 2. A single float value for the pts in seconds. -// 3. A single float value for the duration in seconds. -// The reason we use Tensors for the second and third values is so we can run -// under torch.compile(). -using OpsFrameOutput = std::tuple; - -// All elements of this tuple are tensors of the same leading dimension. The -// tuple represents the frames for N total frames, where N is the dimension of -// each stacked tensor. The elments are: -// 1. Stacked tensor of data for all N frames. Each frame is also a -// multidimensional tensor. -// 2. Tensor of N pts values in seconds, where each pts is a single -// float. -// 3. Tensor of N durationis in seconds, where each duration is a -// single float. -using OpsFrameBatchOutput = std::tuple; - -// The elements of this tuple are all tensors that represent the concatenation -// of multiple audio frames: -// 1. The frames data (concatenated) -// 2. A single float value for the pts of the first frame, in seconds. -using OpsAudioFramesOutput = std::tuple; - -// Return the frame that is visible at a given timestamp in seconds. Each frame -// in FFMPEG has a presentation timestamp and a duration. The frame visible at a -// given timestamp T has T >= PTS and T < PTS + Duration. -OpsFrameOutput get_frame_at_pts(at::Tensor& decoder, double seconds); - -// Return the frames at given ptss for a given stream -OpsFrameBatchOutput get_frames_by_pts( - at::Tensor& decoder, - at::ArrayRef timestamps); - -// Return the frame that is visible at a given index in the video. -OpsFrameOutput get_frame_at_index(at::Tensor& decoder, int64_t frame_index); - -// Get the next frame from the video as a tuple that has the frame data, pts and -// duration as tensors. -OpsFrameOutput get_next_frame(at::Tensor& decoder); - -// Return the frames at given indices for a given stream -OpsFrameBatchOutput get_frames_at_indices( - at::Tensor& decoder, - at::IntArrayRef frame_indices); - -// Return the frames inside a range as a single stacked Tensor. The range is -// defined as [start, stop). -OpsFrameBatchOutput get_frames_in_range( - at::Tensor& decoder, - int64_t start, - int64_t stop, - std::optional step = std::nullopt); - -// Return the frames inside the range as a single stacked Tensor. The range is -// defined as [start_seconds, stop_seconds). The frames are stacked in pts -// order. -OpsFrameBatchOutput get_frames_by_pts_in_range( - at::Tensor& decoder, - double start_seconds, - double stop_seconds); - -OpsAudioFramesOutput get_frames_by_pts_in_range_audio( - at::Tensor& decoder, - double start_seconds, - std::optional stop_seconds = std::nullopt); - -// For testing only. We need to implement this operation as a core library -// function because what we're testing is round-tripping pts values as -// double-precision floating point numbers from C++ to Python and back to C++. -// We want to make sure that the value is preserved exactly, bit-for-bit, during -// this process. -// -// Returns true if for the given decoder, the pts -// value when converted to seconds as a double is exactly pts_seconds_to_test. -// Returns false otherwise. -bool _test_frame_pts_equality( - at::Tensor& decoder, - int64_t frame_index, - double pts_seconds_to_test); - -torch::Tensor _get_key_frame_indices(at::Tensor& decoder); - -// Get the metadata from the video as a string. -std::string get_json_metadata(at::Tensor& decoder); - -// Get the container metadata as a string. -std::string get_container_json_metadata(at::Tensor& decoder); - -// Get the stream metadata as a string. -std::string get_stream_json_metadata(at::Tensor& decoder, int64_t stream_index); - -// Returns version information about the various FFMPEG libraries that are -// loaded in the program's address space. -std::string _get_json_ffmpeg_library_versions(); - -// Scans video packets to get more accurate metadata like frame count, exact -// keyframe positions, etc. Exact keyframe positions are useful for efficient -// accurate seeking. Note that this function reads the entire video but it does -// not decode frames. Reading a video file is much cheaper than decoding it. -void scan_all_streams_to_update_metadata(at::Tensor& decoder); - -} // namespace facebook::torchcodec From c4ae676943a8daefc01d388ba9e9c1b9179adbf4 Mon Sep 17 00:00:00 2001 From: Scott Schneider Date: Thu, 27 Mar 2025 17:46:14 -0700 Subject: [PATCH 2/2] Rename .cpp files to match the library names --- src/torchcodec/decoders/_core/CMakeLists.txt | 4 ++-- .../decoders/_core/{VideoDecoderOps.cpp => custom_ops.cpp} | 0 .../decoders/_core/{PyBindOps.cpp => pybind_ops.cpp} | 0 3 files changed, 2 insertions(+), 2 deletions(-) rename src/torchcodec/decoders/_core/{VideoDecoderOps.cpp => custom_ops.cpp} (100%) rename src/torchcodec/decoders/_core/{PyBindOps.cpp => pybind_ops.cpp} (100%) diff --git a/src/torchcodec/decoders/_core/CMakeLists.txt b/src/torchcodec/decoders/_core/CMakeLists.txt index f0a8568f..23ef2ca6 100644 --- a/src/torchcodec/decoders/_core/CMakeLists.txt +++ b/src/torchcodec/decoders/_core/CMakeLists.txt @@ -92,7 +92,7 @@ function(make_torchcodec_libraries set(custom_ops_library_name "libtorchcodec_custom_ops${ffmpeg_major_version}") set(custom_ops_sources AVIOBytesContext.cpp - VideoDecoderOps.cpp + custom_ops.cpp ) set(custom_ops_dependencies ${decoder_library_name} @@ -109,7 +109,7 @@ function(make_torchcodec_libraries set(pybind_ops_library_name "libtorchcodec_pybind_ops${ffmpeg_major_version}") set(pybind_ops_sources AVIOFileLikeContext.cpp - PyBindOps.cpp + pybind_ops.cpp ) set(pybind_ops_dependencies ${decoder_library_name} diff --git a/src/torchcodec/decoders/_core/VideoDecoderOps.cpp b/src/torchcodec/decoders/_core/custom_ops.cpp similarity index 100% rename from src/torchcodec/decoders/_core/VideoDecoderOps.cpp rename to src/torchcodec/decoders/_core/custom_ops.cpp diff --git a/src/torchcodec/decoders/_core/PyBindOps.cpp b/src/torchcodec/decoders/_core/pybind_ops.cpp similarity index 100% rename from src/torchcodec/decoders/_core/PyBindOps.cpp rename to src/torchcodec/decoders/_core/pybind_ops.cpp