diff --git a/src/torchcodec/_core/CudaDevice.cpp b/src/torchcodec/_core/CudaDevice.cpp index 5bde4106..4f6c7407 100644 --- a/src/torchcodec/_core/CudaDevice.cpp +++ b/src/torchcodec/_core/CudaDevice.cpp @@ -190,9 +190,9 @@ void CudaDevice::initializeContext(AVCodecContext* codecContext) { } void CudaDevice::convertAVFrameToFrameOutput( - const SingleStreamDecoder::VideoStreamOptions& videoStreamOptions, + const VideoStreamOptions& videoStreamOptions, UniqueAVFrame& avFrame, - SingleStreamDecoder::FrameOutput& frameOutput, + FrameOutput& frameOutput, std::optional preAllocatedOutputTensor) { TORCH_CHECK( avFrame->format == AV_PIX_FMT_CUDA, diff --git a/src/torchcodec/_core/CudaDevice.h b/src/torchcodec/_core/CudaDevice.h index 0ed53859..3aee6e2b 100644 --- a/src/torchcodec/_core/CudaDevice.h +++ b/src/torchcodec/_core/CudaDevice.h @@ -21,9 +21,9 @@ class CudaDevice : public DeviceInterface { void initializeContext(AVCodecContext* codecContext) override; void convertAVFrameToFrameOutput( - const SingleStreamDecoder::VideoStreamOptions& videoStreamOptions, + const VideoStreamOptions& videoStreamOptions, UniqueAVFrame& avFrame, - SingleStreamDecoder::FrameOutput& frameOutput, + FrameOutput& frameOutput, std::optional preAllocatedOutputTensor = std::nullopt) override; diff --git a/src/torchcodec/_core/DeviceInterface.h b/src/torchcodec/_core/DeviceInterface.h index a5b0e365..b4197d7d 100644 --- a/src/torchcodec/_core/DeviceInterface.h +++ b/src/torchcodec/_core/DeviceInterface.h @@ -12,7 +12,8 @@ #include #include #include "FFMPEGCommon.h" -#include "src/torchcodec/_core/SingleStreamDecoder.h" +#include "src/torchcodec/_core/Frame.h" +#include "src/torchcodec/_core/StreamOptions.h" namespace facebook::torchcodec { @@ -41,9 +42,9 @@ class DeviceInterface { virtual void initializeContext(AVCodecContext* codecContext) = 0; virtual void convertAVFrameToFrameOutput( - const SingleStreamDecoder::VideoStreamOptions& videoStreamOptions, + const VideoStreamOptions& videoStreamOptions, UniqueAVFrame& avFrame, - SingleStreamDecoder::FrameOutput& frameOutput, + FrameOutput& frameOutput, std::optional preAllocatedOutputTensor = std::nullopt) = 0; protected: diff --git a/src/torchcodec/_core/Frame.h b/src/torchcodec/_core/Frame.h new file mode 100644 index 00000000..aa689734 --- /dev/null +++ b/src/torchcodec/_core/Frame.h @@ -0,0 +1,47 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// All rights reserved. +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. + +#pragma once + +#include +#include "src/torchcodec/_core/Metadata.h" +#include "src/torchcodec/_core/StreamOptions.h" + +namespace facebook::torchcodec { + +// All public video decoding entry points return either a FrameOutput or a +// FrameBatchOutput. +// They are the equivalent of the user-facing Frame and FrameBatch classes in +// Python. They contain RGB decoded frames along with some associated data +// like PTS and duration. +// FrameOutput is also relevant for audio decoding, typically as the output of +// getNextFrame(), or as a temporary output variable. +struct FrameOutput { + // data shape is: + // - 3D (C, H, W) or (H, W, C) for videos + // - 2D (numChannels, numSamples) for audio + torch::Tensor data; + double ptsSeconds; + double durationSeconds; +}; + +struct FrameBatchOutput { + torch::Tensor data; // 4D: of shape NCHW or NHWC. + torch::Tensor ptsSeconds; // 1D of shape (N,) + torch::Tensor durationSeconds; // 1D of shape (N,) + + explicit FrameBatchOutput( + int64_t numFrames, + const VideoStreamOptions& videoStreamOptions, + const StreamMetadata& streamMetadata); +}; + +struct AudioFramesOutput { + torch::Tensor data; // shape is (numChannels, numSamples) + double ptsSeconds; +}; + +} // namespace facebook::torchcodec diff --git a/src/torchcodec/_core/Metadata.h b/src/torchcodec/_core/Metadata.h new file mode 100644 index 00000000..a8f300f4 --- /dev/null +++ b/src/torchcodec/_core/Metadata.h @@ -0,0 +1,70 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// All rights reserved. +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. + +#pragma once + +#include +#include +#include + +extern "C" { +#include +#include +} + +namespace facebook::torchcodec { + +struct StreamMetadata { + // Common (video and audio) fields derived from the AVStream. + int streamIndex; + // See this link for what various values are available: + // https://ffmpeg.org/doxygen/trunk/group__lavu__misc.html#ga9a84bba4713dfced21a1a56163be1f48 + AVMediaType mediaType; + std::optional codecId; + std::optional codecName; + std::optional durationSeconds; + std::optional beginStreamFromHeader; + std::optional numFrames; + std::optional numKeyFrames; + std::optional averageFps; + std::optional bitRate; + + // More accurate duration, obtained by scanning the file. + // These presentation timestamps are in time base. + std::optional minPtsFromScan; + std::optional maxPtsFromScan; + // These presentation timestamps are in seconds. + std::optional minPtsSecondsFromScan; + std::optional maxPtsSecondsFromScan; + // This can be useful for index-based seeking. + std::optional numFramesFromScan; + + // Video-only fields derived from the AVCodecContext. + std::optional width; + std::optional height; + + // Audio-only fields + std::optional sampleRate; + std::optional numChannels; + std::optional sampleFormat; +}; + +struct ContainerMetadata { + std::vector allStreamMetadata; + int numAudioStreams = 0; + int numVideoStreams = 0; + // Note that this is the container-level duration, which is usually the max + // of all stream durations available in the container. + std::optional durationSeconds; + // Total BitRate level information at the container level in bit/s + std::optional bitRate; + // If set, this is the index to the default audio stream. + std::optional bestAudioStreamIndex; + // If set, this is the index to the default video stream. + std::optional bestVideoStreamIndex; +}; + +} // namespace facebook::torchcodec diff --git a/src/torchcodec/_core/SingleStreamDecoder.cpp b/src/torchcodec/_core/SingleStreamDecoder.cpp index c7c714da..c389242c 100644 --- a/src/torchcodec/_core/SingleStreamDecoder.cpp +++ b/src/torchcodec/_core/SingleStreamDecoder.cpp @@ -13,7 +13,6 @@ #include #include #include -#include "src/torchcodec/_core/DeviceInterface.h" #include "torch/types.h" extern "C" { @@ -350,8 +349,7 @@ void SingleStreamDecoder::scanFileAndUpdateMetadataAndIndex() { scannedAllStreams_ = true; } -SingleStreamDecoder::ContainerMetadata -SingleStreamDecoder::getContainerMetadata() const { +ContainerMetadata SingleStreamDecoder::getContainerMetadata() const { return containerMetadata_; } @@ -406,7 +404,7 @@ void SingleStreamDecoder::addStream( streamInfo.stream = formatContext_->streams[activeStreamIndex_]; streamInfo.avMediaType = mediaType; - deviceInterface = createDeviceInterface(device); + deviceInterface_ = createDeviceInterface(device); // This should never happen, checking just to be safe. TORCH_CHECK( @@ -418,9 +416,9 @@ void SingleStreamDecoder::addStream( // TODO_CODE_QUALITY it's pretty meh to have a video-specific logic within // addStream() which is supposed to be generic if (mediaType == AVMEDIA_TYPE_VIDEO) { - if (deviceInterface) { + if (deviceInterface_) { avCodec = makeAVCodecOnlyUseForCallingAVFindBestStream( - deviceInterface->findCodec(streamInfo.stream->codecpar->codec_id) + deviceInterface_->findCodec(streamInfo.stream->codecpar->codec_id) .value_or(avCodec)); } } @@ -438,8 +436,8 @@ void SingleStreamDecoder::addStream( // TODO_CODE_QUALITY same as above. if (mediaType == AVMEDIA_TYPE_VIDEO) { - if (deviceInterface) { - deviceInterface->initializeContext(codecContext); + if (deviceInterface_) { + deviceInterface_->initializeContext(codecContext); } } @@ -501,9 +499,8 @@ void SingleStreamDecoder::addVideoStream( // swscale requires widths to be multiples of 32: // https://stackoverflow.com/questions/74351955/turn-off-sw-scale-conversion-to-planar-yuv-32-byte-alignment-requirements // so we fall back to filtergraph if the width is not a multiple of 32. - auto defaultLibrary = (width % 32 == 0) - ? SingleStreamDecoder::ColorConversionLibrary::SWSCALE - : SingleStreamDecoder::ColorConversionLibrary::FILTERGRAPH; + auto defaultLibrary = (width % 32 == 0) ? ColorConversionLibrary::SWSCALE + : ColorConversionLibrary::FILTERGRAPH; streamInfo.colorConversionLibrary = videoStreamOptions.colorConversionLibrary.value_or(defaultLibrary); @@ -539,7 +536,7 @@ void SingleStreamDecoder::addAudioStream( // HIGH-LEVEL DECODING ENTRY-POINTS // -------------------------------------------------------------------------- -SingleStreamDecoder::FrameOutput SingleStreamDecoder::getNextFrame() { +FrameOutput SingleStreamDecoder::getNextFrame() { auto output = getNextFrameInternal(); if (streamInfos_[activeStreamIndex_].avMediaType == AVMEDIA_TYPE_VIDEO) { output.data = maybePermuteHWC2CHW(output.data); @@ -547,7 +544,7 @@ SingleStreamDecoder::FrameOutput SingleStreamDecoder::getNextFrame() { return output; } -SingleStreamDecoder::FrameOutput SingleStreamDecoder::getNextFrameInternal( +FrameOutput SingleStreamDecoder::getNextFrameInternal( std::optional preAllocatedOutputTensor) { validateActiveStream(); UniqueAVFrame avFrame = decodeAVFrame( @@ -555,14 +552,13 @@ SingleStreamDecoder::FrameOutput SingleStreamDecoder::getNextFrameInternal( return convertAVFrameToFrameOutput(avFrame, preAllocatedOutputTensor); } -SingleStreamDecoder::FrameOutput SingleStreamDecoder::getFrameAtIndex( - int64_t frameIndex) { +FrameOutput SingleStreamDecoder::getFrameAtIndex(int64_t frameIndex) { auto frameOutput = getFrameAtIndexInternal(frameIndex); frameOutput.data = maybePermuteHWC2CHW(frameOutput.data); return frameOutput; } -SingleStreamDecoder::FrameOutput SingleStreamDecoder::getFrameAtIndexInternal( +FrameOutput SingleStreamDecoder::getFrameAtIndexInternal( int64_t frameIndex, std::optional preAllocatedOutputTensor) { validateActiveStream(AVMEDIA_TYPE_VIDEO); @@ -577,7 +573,7 @@ SingleStreamDecoder::FrameOutput SingleStreamDecoder::getFrameAtIndexInternal( return getNextFrameInternal(preAllocatedOutputTensor); } -SingleStreamDecoder::FrameBatchOutput SingleStreamDecoder::getFramesAtIndices( +FrameBatchOutput SingleStreamDecoder::getFramesAtIndices( const std::vector& frameIndices) { validateActiveStream(AVMEDIA_TYPE_VIDEO); @@ -636,7 +632,7 @@ SingleStreamDecoder::FrameBatchOutput SingleStreamDecoder::getFramesAtIndices( return frameBatchOutput; } -SingleStreamDecoder::FrameBatchOutput SingleStreamDecoder::getFramesInRange( +FrameBatchOutput SingleStreamDecoder::getFramesInRange( int64_t start, int64_t stop, int64_t step) { @@ -670,8 +666,7 @@ SingleStreamDecoder::FrameBatchOutput SingleStreamDecoder::getFramesInRange( return frameBatchOutput; } -SingleStreamDecoder::FrameOutput SingleStreamDecoder::getFramePlayedAt( - double seconds) { +FrameOutput SingleStreamDecoder::getFramePlayedAt(double seconds) { validateActiveStream(AVMEDIA_TYPE_VIDEO); StreamInfo& streamInfo = streamInfos_[activeStreamIndex_]; double frameStartTime = @@ -711,7 +706,7 @@ SingleStreamDecoder::FrameOutput SingleStreamDecoder::getFramePlayedAt( return frameOutput; } -SingleStreamDecoder::FrameBatchOutput SingleStreamDecoder::getFramesPlayedAt( +FrameBatchOutput SingleStreamDecoder::getFramesPlayedAt( const std::vector& timestamps) { validateActiveStream(AVMEDIA_TYPE_VIDEO); @@ -741,8 +736,7 @@ SingleStreamDecoder::FrameBatchOutput SingleStreamDecoder::getFramesPlayedAt( return getFramesAtIndices(frameIndices); } -SingleStreamDecoder::FrameBatchOutput -SingleStreamDecoder::getFramesPlayedInRange( +FrameBatchOutput SingleStreamDecoder::getFramesPlayedInRange( double startSeconds, double stopSeconds) { validateActiveStream(AVMEDIA_TYPE_VIDEO); @@ -875,8 +869,7 @@ SingleStreamDecoder::getFramesPlayedInRange( // [2] If you're brave and curious, you can read the long "Seek offset for // audio" note in https://github.com/pytorch/torchcodec/pull/507/files, which // sums up past (and failed) attemps at working around this issue. -SingleStreamDecoder::AudioFramesOutput -SingleStreamDecoder::getFramesPlayedInRangeAudio( +AudioFramesOutput SingleStreamDecoder::getFramesPlayedInRangeAudio( double startSeconds, std::optional stopSecondsOptional) { validateActiveStream(AVMEDIA_TYPE_AUDIO); @@ -1196,8 +1189,7 @@ UniqueAVFrame SingleStreamDecoder::decodeAVFrame( // AVFRAME <-> FRAME OUTPUT CONVERSION // -------------------------------------------------------------------------- -SingleStreamDecoder::FrameOutput -SingleStreamDecoder::convertAVFrameToFrameOutput( +FrameOutput SingleStreamDecoder::convertAVFrameToFrameOutput( UniqueAVFrame& avFrame, std::optional preAllocatedOutputTensor) { // Convert the frame to tensor. @@ -1210,11 +1202,11 @@ SingleStreamDecoder::convertAVFrameToFrameOutput( formatContext_->streams[activeStreamIndex_]->time_base); if (streamInfo.avMediaType == AVMEDIA_TYPE_AUDIO) { convertAudioAVFrameToFrameOutputOnCPU(avFrame, frameOutput); - } else if (!deviceInterface) { + } else if (!deviceInterface_) { convertAVFrameToFrameOutputOnCPU( avFrame, frameOutput, preAllocatedOutputTensor); - } else if (deviceInterface) { - deviceInterface->convertAVFrameToFrameOutput( + } else if (deviceInterface_) { + deviceInterface_->convertAVFrameToFrameOutput( streamInfo.videoStreamOptions, avFrame, frameOutput, @@ -1547,7 +1539,7 @@ std::optional SingleStreamDecoder::maybeFlushSwrBuffers() { // OUTPUT ALLOCATION AND SHAPE CONVERSION // -------------------------------------------------------------------------- -SingleStreamDecoder::FrameBatchOutput::FrameBatchOutput( +FrameBatchOutput::FrameBatchOutput( int64_t numFrames, const VideoStreamOptions& videoStreamOptions, const StreamMetadata& streamMetadata) @@ -2047,15 +2039,15 @@ FrameDims getHeightAndWidthFromResizedAVFrame(const AVFrame& resizedAVFrame) { } FrameDims getHeightAndWidthFromOptionsOrMetadata( - const SingleStreamDecoder::VideoStreamOptions& videoStreamOptions, - const SingleStreamDecoder::StreamMetadata& streamMetadata) { + const VideoStreamOptions& videoStreamOptions, + const StreamMetadata& streamMetadata) { return FrameDims( videoStreamOptions.height.value_or(*streamMetadata.height), videoStreamOptions.width.value_or(*streamMetadata.width)); } FrameDims getHeightAndWidthFromOptionsOrAVFrame( - const SingleStreamDecoder::VideoStreamOptions& videoStreamOptions, + const VideoStreamOptions& videoStreamOptions, const UniqueAVFrame& avFrame) { return FrameDims( videoStreamOptions.height.value_or(avFrame->height), diff --git a/src/torchcodec/_core/SingleStreamDecoder.h b/src/torchcodec/_core/SingleStreamDecoder.h index 4879a3b7..7b275a20 100644 --- a/src/torchcodec/_core/SingleStreamDecoder.h +++ b/src/torchcodec/_core/SingleStreamDecoder.h @@ -13,10 +13,12 @@ #include #include "src/torchcodec/_core/AVIOContextHolder.h" +#include "src/torchcodec/_core/DeviceInterface.h" #include "src/torchcodec/_core/FFMPEGCommon.h" +#include "src/torchcodec/_core/Frame.h" +#include "src/torchcodec/_core/StreamOptions.h" namespace facebook::torchcodec { -class DeviceInterface; // The SingleStreamDecoder class can be used to decode video frames to Tensors. // Note that SingleStreamDecoder is not thread-safe. @@ -51,56 +53,6 @@ class SingleStreamDecoder { // the allFrames and keyFrames vectors. void scanFileAndUpdateMetadataAndIndex(); - struct StreamMetadata { - // Common (video and audio) fields derived from the AVStream. - int streamIndex; - // See this link for what various values are available: - // https://ffmpeg.org/doxygen/trunk/group__lavu__misc.html#ga9a84bba4713dfced21a1a56163be1f48 - AVMediaType mediaType; - std::optional codecId; - std::optional codecName; - std::optional durationSeconds; - std::optional beginStreamFromHeader; - std::optional numFrames; - std::optional numKeyFrames; - std::optional averageFps; - std::optional bitRate; - - // More accurate duration, obtained by scanning the file. - // These presentation timestamps are in time base. - std::optional minPtsFromScan; - std::optional maxPtsFromScan; - // These presentation timestamps are in seconds. - std::optional minPtsSecondsFromScan; - std::optional maxPtsSecondsFromScan; - // This can be useful for index-based seeking. - std::optional numFramesFromScan; - - // Video-only fields derived from the AVCodecContext. - std::optional width; - std::optional height; - - // Audio-only fields - std::optional sampleRate; - std::optional numChannels; - std::optional sampleFormat; - }; - - struct ContainerMetadata { - std::vector allStreamMetadata; - int numAudioStreams = 0; - int numVideoStreams = 0; - // Note that this is the container-level duration, which is usually the max - // of all stream durations available in the container. - std::optional durationSeconds; - // Total BitRate level information at the container level in bit/s - std::optional bitRate; - // If set, this is the index to the default audio stream. - std::optional bestAudioStreamIndex; - // If set, this is the index to the default video stream. - std::optional bestVideoStreamIndex; - }; - // Returns the metadata for the container. ContainerMetadata getContainerMetadata() const; @@ -112,40 +64,6 @@ class SingleStreamDecoder { // ADDING STREAMS API // -------------------------------------------------------------------------- - enum ColorConversionLibrary { - // TODO: Add an AUTO option later. - // Use the libavfilter library for color conversion. - FILTERGRAPH, - // Use the libswscale library for color conversion. - SWSCALE - }; - - struct VideoStreamOptions { - VideoStreamOptions() {} - - // Number of threads we pass to FFMPEG for decoding. - // 0 means FFMPEG will choose the number of threads automatically to fully - // utilize all cores. If not set, it will be the default FFMPEG behavior for - // the given codec. - std::optional ffmpegThreadCount; - // Currently the dimension order can be either NHWC or NCHW. - // H=height, W=width, C=channel. - std::string dimensionOrder = "NCHW"; - // The output height and width of the frame. If not specified, the output - // is the same as the original video. - std::optional width; - std::optional height; - std::optional colorConversionLibrary; - // By default we use CPU for decoding for both C++ and python users. - torch::Device device = torch::kCPU; - }; - - struct AudioStreamOptions { - AudioStreamOptions() {} - - std::optional sampleRate; - }; - void addVideoStream( int streamIndex, const VideoStreamOptions& videoStreamOptions = VideoStreamOptions()); @@ -157,38 +75,6 @@ class SingleStreamDecoder { // DECODING AND SEEKING APIs // -------------------------------------------------------------------------- - // All public video decoding entry points return either a FrameOutput or a - // FrameBatchOutput. - // They are the equivalent of the user-facing Frame and FrameBatch classes in - // Python. They contain RGB decoded frames along with some associated data - // like PTS and duration. - // FrameOutput is also relevant for audio decoding, typically as the output of - // getNextFrame(), or as a temporary output variable. - struct FrameOutput { - // data shape is: - // - 3D (C, H, W) or (H, W, C) for videos - // - 2D (numChannels, numSamples) for audio - torch::Tensor data; - double ptsSeconds; - double durationSeconds; - }; - - struct FrameBatchOutput { - torch::Tensor data; // 4D: of shape NCHW or NHWC. - torch::Tensor ptsSeconds; // 1D of shape (N,) - torch::Tensor durationSeconds; // 1D of shape (N,) - - explicit FrameBatchOutput( - int64_t numFrames, - const VideoStreamOptions& videoStreamOptions, - const StreamMetadata& streamMetadata); - }; - - struct AudioFramesOutput { - torch::Tensor data; // shape is (numChannels, numSamples) - double ptsSeconds; - }; - // Places the cursor at the first frame on or after the position in seconds. // Calling getNextFrame() will return the first frame at // or after this position. @@ -492,7 +378,7 @@ class SingleStreamDecoder { SeekMode seekMode_; ContainerMetadata containerMetadata_; UniqueDecodingAVFormatContext formatContext_; - std::unique_ptr deviceInterface; + std::unique_ptr deviceInterface_; std::map streamInfos_; const int NO_ACTIVE_STREAM = -2; int activeStreamIndex_ = NO_ACTIVE_STREAM; @@ -568,11 +454,11 @@ struct FrameDims { FrameDims getHeightAndWidthFromResizedAVFrame(const AVFrame& resizedAVFrame); FrameDims getHeightAndWidthFromOptionsOrMetadata( - const SingleStreamDecoder::VideoStreamOptions& videoStreamOptions, - const SingleStreamDecoder::StreamMetadata& streamMetadata); + const VideoStreamOptions& videoStreamOptions, + const StreamMetadata& streamMetadata); FrameDims getHeightAndWidthFromOptionsOrAVFrame( - const SingleStreamDecoder::VideoStreamOptions& videoStreamOptions, + const VideoStreamOptions& videoStreamOptions, const UniqueAVFrame& avFrame); torch::Tensor allocateEmptyHWCTensor( diff --git a/src/torchcodec/_core/StreamOptions.h b/src/torchcodec/_core/StreamOptions.h new file mode 100644 index 00000000..38e51209 --- /dev/null +++ b/src/torchcodec/_core/StreamOptions.h @@ -0,0 +1,49 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// All rights reserved. +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. + +#pragma once + +#include +#include +#include + +namespace facebook::torchcodec { + +enum ColorConversionLibrary { + // TODO: Add an AUTO option later. + // Use the libavfilter library for color conversion. + FILTERGRAPH, + // Use the libswscale library for color conversion. + SWSCALE +}; + +struct VideoStreamOptions { + VideoStreamOptions() {} + + // Number of threads we pass to FFMPEG for decoding. + // 0 means FFMPEG will choose the number of threads automatically to fully + // utilize all cores. If not set, it will be the default FFMPEG behavior for + // the given codec. + std::optional ffmpegThreadCount; + // Currently the dimension order can be either NHWC or NCHW. + // H=height, W=width, C=channel. + std::string dimensionOrder = "NCHW"; + // The output height and width of the frame. If not specified, the output + // is the same as the original video. + std::optional width; + std::optional height; + std::optional colorConversionLibrary; + // By default we use CPU for decoding for both C++ and python users. + torch::Device device = torch::kCPU; +}; + +struct AudioStreamOptions { + AudioStreamOptions() {} + + std::optional sampleRate; +}; + +} // namespace facebook::torchcodec diff --git a/src/torchcodec/_core/custom_ops.cpp b/src/torchcodec/_core/custom_ops.cpp index 05a6390d..9a9b8776 100644 --- a/src/torchcodec/_core/custom_ops.cpp +++ b/src/torchcodec/_core/custom_ops.cpp @@ -11,7 +11,6 @@ #include "c10/core/SymIntArrayRef.h" #include "c10/util/Exception.h" #include "src/torchcodec/_core/AVIOBytesContext.h" -#include "src/torchcodec/_core/DeviceInterface.h" #include "src/torchcodec/_core/Encoder.h" #include "src/torchcodec/_core/SingleStreamDecoder.h" @@ -98,7 +97,7 @@ SingleStreamDecoder* unwrapTensorToGetDecoder(at::Tensor& tensor) { // under torch.compile(). using OpsFrameOutput = std::tuple; -OpsFrameOutput makeOpsFrameOutput(SingleStreamDecoder::FrameOutput& frame) { +OpsFrameOutput makeOpsFrameOutput(FrameOutput& frame) { return std::make_tuple( frame.data, torch::tensor(frame.ptsSeconds, torch::dtype(torch::kFloat64)), @@ -116,8 +115,7 @@ OpsFrameOutput makeOpsFrameOutput(SingleStreamDecoder::FrameOutput& frame) { // single float. using OpsFrameBatchOutput = std::tuple; -OpsFrameBatchOutput makeOpsFrameBatchOutput( - SingleStreamDecoder::FrameBatchOutput& batch) { +OpsFrameBatchOutput makeOpsFrameBatchOutput(FrameBatchOutput& batch) { return std::make_tuple(batch.data, batch.ptsSeconds, batch.durationSeconds); } @@ -127,8 +125,7 @@ OpsFrameBatchOutput makeOpsFrameBatchOutput( // 2. A single float value for the pts of the first frame, in seconds. using OpsAudioFramesOutput = std::tuple; -OpsAudioFramesOutput makeOpsAudioFramesOutput( - SingleStreamDecoder::AudioFramesOutput& audioFrames) { +OpsAudioFramesOutput makeOpsAudioFramesOutput(AudioFramesOutput& audioFrames) { return std::make_tuple( audioFrames.data, torch::tensor(audioFrames.ptsSeconds, torch::dtype(torch::kFloat64))); @@ -218,7 +215,7 @@ void _add_video_stream( std::optional stream_index = std::nullopt, std::optional device = std::nullopt, std::optional color_conversion_library = std::nullopt) { - SingleStreamDecoder::VideoStreamOptions videoStreamOptions; + VideoStreamOptions videoStreamOptions; videoStreamOptions.width = width; videoStreamOptions.height = height; videoStreamOptions.ffmpegThreadCount = num_threads; @@ -232,10 +229,10 @@ void _add_video_stream( std::string stdColorConversionLibrary{color_conversion_library.value()}; if (stdColorConversionLibrary == "filtergraph") { videoStreamOptions.colorConversionLibrary = - SingleStreamDecoder::ColorConversionLibrary::FILTERGRAPH; + ColorConversionLibrary::FILTERGRAPH; } else if (stdColorConversionLibrary == "swscale") { videoStreamOptions.colorConversionLibrary = - SingleStreamDecoder::ColorConversionLibrary::SWSCALE; + ColorConversionLibrary::SWSCALE; } else { throw std::runtime_error( "Invalid color_conversion_library=" + stdColorConversionLibrary + @@ -273,7 +270,7 @@ void add_audio_stream( at::Tensor& decoder, std::optional stream_index = std::nullopt, std::optional sample_rate = std::nullopt) { - SingleStreamDecoder::AudioStreamOptions audioStreamOptions; + AudioStreamOptions audioStreamOptions; audioStreamOptions.sampleRate = sample_rate; auto videoDecoder = unwrapTensorToGetDecoder(decoder); @@ -291,7 +288,7 @@ void seek_to_pts(at::Tensor& decoder, double seconds) { // duration as tensors. OpsFrameOutput get_next_frame(at::Tensor& decoder) { auto videoDecoder = unwrapTensorToGetDecoder(decoder); - SingleStreamDecoder::FrameOutput result; + FrameOutput result; try { result = videoDecoder->getNextFrame(); } catch (const SingleStreamDecoder::EndOfFileException& e) { @@ -305,7 +302,7 @@ OpsFrameOutput get_next_frame(at::Tensor& decoder) { // given timestamp T has T >= PTS and T < PTS + Duration. OpsFrameOutput get_frame_at_pts(at::Tensor& decoder, double seconds) { auto videoDecoder = unwrapTensorToGetDecoder(decoder); - SingleStreamDecoder::FrameOutput result; + FrameOutput result; try { result = videoDecoder->getFramePlayedAt(seconds); } catch (const SingleStreamDecoder::EndOfFileException& e) { @@ -443,8 +440,7 @@ torch::Tensor _get_key_frame_indices(at::Tensor& decoder) { std::string get_json_metadata(at::Tensor& decoder) { auto videoDecoder = unwrapTensorToGetDecoder(decoder); - SingleStreamDecoder::ContainerMetadata videoMetadata = - videoDecoder->getContainerMetadata(); + ContainerMetadata videoMetadata = videoDecoder->getContainerMetadata(); auto maybeBestVideoStreamIndex = videoMetadata.bestVideoStreamIndex; std::map metadataMap; diff --git a/test/VideoDecoderTest.cpp b/test/VideoDecoderTest.cpp index 1937ff97..a30609c2 100644 --- a/test/VideoDecoderTest.cpp +++ b/test/VideoDecoderTest.cpp @@ -5,7 +5,6 @@ // LICENSE file in the root directory of this source tree. #include "src/torchcodec/_core/AVIOBytesContext.h" -#include "src/torchcodec/_core/DeviceInterface.h" #include "src/torchcodec/_core/SingleStreamDecoder.h" #include @@ -69,8 +68,7 @@ TEST_P(SingleStreamDecoderTest, ReturnsFpsAndDurationForVideoInMetadata) { std::string path = getResourcePath("nasa_13013.mp4"); std::unique_ptr decoder = createDecoderFromPath(path, GetParam()); - SingleStreamDecoder::ContainerMetadata metadata = - decoder->getContainerMetadata(); + ContainerMetadata metadata = decoder->getContainerMetadata(); EXPECT_EQ(metadata.numAudioStreams, 2); EXPECT_EQ(metadata.numVideoStreams, 2); #if LIBAVFORMAT_VERSION_MAJOR >= 60 @@ -150,7 +148,7 @@ TEST(SingleStreamDecoderTest, RespectsWidthAndHeightFromOptions) { std::string path = getResourcePath("nasa_13013.mp4"); std::unique_ptr decoder = std::make_unique(path); - SingleStreamDecoder::VideoStreamOptions videoStreamOptions; + VideoStreamOptions videoStreamOptions; videoStreamOptions.width = 100; videoStreamOptions.height = 120; decoder->addVideoStream(-1, videoStreamOptions); @@ -162,7 +160,7 @@ TEST(SingleStreamDecoderTest, RespectsOutputTensorDimensionOrderFromOptions) { std::string path = getResourcePath("nasa_13013.mp4"); std::unique_ptr decoder = std::make_unique(path); - SingleStreamDecoder::VideoStreamOptions videoStreamOptions; + VideoStreamOptions videoStreamOptions; videoStreamOptions.dimensionOrder = "NHWC"; decoder->addVideoStream(-1, videoStreamOptions); torch::Tensor tensor = decoder->getNextFrame().data; @@ -234,7 +232,7 @@ TEST_P(SingleStreamDecoderTest, DecodesFramesInABatchInNHWC) { ourDecoder->scanFileAndUpdateMetadataAndIndex(); int bestVideoStreamIndex = *ourDecoder->getContainerMetadata().bestVideoStreamIndex; - SingleStreamDecoder::VideoStreamOptions videoStreamOptions; + VideoStreamOptions videoStreamOptions; videoStreamOptions.dimensionOrder = "NHWC"; ourDecoder->addVideoStream(bestVideoStreamIndex, videoStreamOptions); // Frame with index 180 corresponds to timestamp 6.006. @@ -399,9 +397,9 @@ TEST_P(SingleStreamDecoderTest, PreAllocatedTensorFilterGraph) { ourDecoder->scanFileAndUpdateMetadataAndIndex(); int bestVideoStreamIndex = *ourDecoder->getContainerMetadata().bestVideoStreamIndex; - SingleStreamDecoder::VideoStreamOptions videoStreamOptions; + VideoStreamOptions videoStreamOptions; videoStreamOptions.colorConversionLibrary = - SingleStreamDecoder::ColorConversionLibrary::FILTERGRAPH; + ColorConversionLibrary::FILTERGRAPH; ourDecoder->addVideoStream(bestVideoStreamIndex, videoStreamOptions); auto output = ourDecoder->getFrameAtIndexInternal(0, preAllocatedOutputTensor); @@ -417,9 +415,8 @@ TEST_P(SingleStreamDecoderTest, PreAllocatedTensorSwscale) { ourDecoder->scanFileAndUpdateMetadataAndIndex(); int bestVideoStreamIndex = *ourDecoder->getContainerMetadata().bestVideoStreamIndex; - SingleStreamDecoder::VideoStreamOptions videoStreamOptions; - videoStreamOptions.colorConversionLibrary = - SingleStreamDecoder::ColorConversionLibrary::SWSCALE; + VideoStreamOptions videoStreamOptions; + videoStreamOptions.colorConversionLibrary = ColorConversionLibrary::SWSCALE; ourDecoder->addVideoStream(bestVideoStreamIndex, videoStreamOptions); auto output = ourDecoder->getFrameAtIndexInternal(0, preAllocatedOutputTensor); @@ -430,8 +427,7 @@ TEST_P(SingleStreamDecoderTest, GetAudioMetadata) { std::string path = getResourcePath("nasa_13013.mp4.audio.mp3"); std::unique_ptr decoder = createDecoderFromPath(path, GetParam()); - SingleStreamDecoder::ContainerMetadata metadata = - decoder->getContainerMetadata(); + ContainerMetadata metadata = decoder->getContainerMetadata(); EXPECT_EQ(metadata.numAudioStreams, 1); EXPECT_EQ(metadata.numVideoStreams, 0); EXPECT_EQ(metadata.allStreamMetadata.size(), 1);