diff --git a/src/torchcodec/_core/AVIOContextHolder.h b/src/torchcodec/_core/AVIOContextHolder.h
index 69b32f29..3b094c26 100644
--- a/src/torchcodec/_core/AVIOContextHolder.h
+++ b/src/torchcodec/_core/AVIOContextHolder.h
@@ -32,7 +32,7 @@ namespace facebook::torchcodec {
 //    createAVIOContext(), ideally in their constructor.
 // 3. A generic handle for those that just need to manage having access to an
 //    AVIOContext, but aren't necessarily concerned with how it was customized:
-//    typically, the VideoDecoder.
+//    typically, the SingleStreamDecoder.
 class AVIOContextHolder {
  public:
  virtual ~AVIOContextHolder();
diff --git a/src/torchcodec/_core/CMakeLists.txt b/src/torchcodec/_core/CMakeLists.txt
index bb8f9660..918da235 100644
--- a/src/torchcodec/_core/CMakeLists.txt
+++ b/src/torchcodec/_core/CMakeLists.txt
@@ -60,7 +60,7 @@ function(make_torchcodec_libraries
   set(decoder_sources
     AVIOContextHolder.cpp
     FFMPEGCommon.cpp
-    VideoDecoder.cpp
+    SingleStreamDecoder.cpp
   )
 
   if(ENABLE_CUDA)
diff --git a/src/torchcodec/_core/CPUOnlyDevice.cpp b/src/torchcodec/_core/CPUOnlyDevice.cpp
index ad913171..1d5b477d 100644
--- a/src/torchcodec/_core/CPUOnlyDevice.cpp
+++ b/src/torchcodec/_core/CPUOnlyDevice.cpp
@@ -16,9 +16,10 @@ namespace facebook::torchcodec {
 
 void convertAVFrameToFrameOutputOnCuda(
     const torch::Device& device,
-    [[maybe_unused]] const VideoDecoder::VideoStreamOptions& videoStreamOptions,
+    [[maybe_unused]] const SingleStreamDecoder::VideoStreamOptions&
+        videoStreamOptions,
     [[maybe_unused]] UniqueAVFrame& avFrame,
-    [[maybe_unused]] VideoDecoder::FrameOutput& frameOutput,
+    [[maybe_unused]] SingleStreamDecoder::FrameOutput& frameOutput,
     [[maybe_unused]] std::optional<torch::Tensor> preAllocatedOutputTensor) {
   throwUnsupportedDeviceError(device);
 }
diff --git a/src/torchcodec/_core/CudaDevice.cpp b/src/torchcodec/_core/CudaDevice.cpp
index 41596cb8..fd8be9de 100644
--- a/src/torchcodec/_core/CudaDevice.cpp
+++ b/src/torchcodec/_core/CudaDevice.cpp
@@ -6,7 +6,7 @@
 #include "src/torchcodec/_core/DeviceInterface.h"
 #include "src/torchcodec/_core/FFMPEGCommon.h"
-#include "src/torchcodec/_core/VideoDecoder.h"
+#include "src/torchcodec/_core/SingleStreamDecoder.h"
 
 extern "C" {
 #include
@@ -20,9 +20,9 @@ namespace {
 // creating a cuda context is expensive. The cache mechanism is as follows:
 // 1. There is a cache of size MAX_CONTEXTS_PER_GPU_IN_CACHE cuda contexts for
 //    each GPU.
-// 2. When we destroy a VideoDecoder instance we release the cuda context to
+// 2. When we destroy a SingleStreamDecoder instance we release the cuda context
+//    to
 //    the cache if the cache is not full.
-// 3. When we create a VideoDecoder instance we try to get a cuda context from
+// 3. When we create a SingleStreamDecoder instance we try to get a cuda context
+//    from
 //    the cache. If the cache is empty we create a new cuda context.
 // Pytorch can only handle up to 128 GPUs.
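The caching scheme described in the comment above — bounded per-GPU storage, release on decoder destruction, reuse on decoder creation — can be sketched roughly as follows. This is an illustrative sketch only: `MAX_CONTEXTS_PER_GPU_IN_CACHE` and the 128-GPU bound come from the comment, but the cache value (20), the container, the locking, and the function names below are assumptions, not the actual CudaDevice.cpp implementation.

```cpp
// Hypothetical sketch of the per-GPU cuda-context cache described above.
#include <mutex>
#include <vector>

constexpr int MAX_CUDA_GPUS = 128; // "Pytorch can only handle up to 128 GPUs."
constexpr size_t MAX_CONTEXTS_PER_GPU_IN_CACHE = 20; // assumed value

struct PerGpuContextCache {
  std::mutex mutex;
  std::vector<void*> contexts; // stand-in for FFmpeg hw-device contexts
};

static PerGpuContextCache g_cache[MAX_CUDA_GPUS];

// On SingleStreamDecoder destruction: keep the context for reuse unless the
// cache for this GPU is already full. Returns true if the cache took ownership.
bool releaseContextToCache(int gpuIndex, void* ctx) {
  auto& cache = g_cache[gpuIndex];
  std::scoped_lock lock(cache.mutex);
  if (cache.contexts.size() < MAX_CONTEXTS_PER_GPU_IN_CACHE) {
    cache.contexts.push_back(ctx);
    return true;
  }
  return false; // cache full: the caller frees the context itself
}

// On SingleStreamDecoder creation: reuse a cached context if one is available.
// A nullptr result means the caller creates a fresh cuda context.
void* getContextFromCache(int gpuIndex) {
  auto& cache = g_cache[gpuIndex];
  std::scoped_lock lock(cache.mutex);
  if (cache.contexts.empty()) {
    return nullptr;
  }
  void* ctx = cache.contexts.back();
  cache.contexts.pop_back();
  return ctx;
}
```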
@@ -189,9 +191,9 @@ void initializeContextOnCuda(
 
 void convertAVFrameToFrameOutputOnCuda(
     const torch::Device& device,
-    const VideoDecoder::VideoStreamOptions& videoStreamOptions,
+    const SingleStreamDecoder::VideoStreamOptions& videoStreamOptions,
     UniqueAVFrame& avFrame,
-    VideoDecoder::FrameOutput& frameOutput,
+    SingleStreamDecoder::FrameOutput& frameOutput,
     std::optional<torch::Tensor> preAllocatedOutputTensor) {
   TORCH_CHECK(
       avFrame->format == AV_PIX_FMT_CUDA,
diff --git a/src/torchcodec/_core/DeviceInterface.h b/src/torchcodec/_core/DeviceInterface.h
index d2f5940c..352b83d3 100644
--- a/src/torchcodec/_core/DeviceInterface.h
+++ b/src/torchcodec/_core/DeviceInterface.h
@@ -11,13 +11,13 @@
 #include
 #include
 #include "FFMPEGCommon.h"
-#include "src/torchcodec/_core/VideoDecoder.h"
+#include "src/torchcodec/_core/SingleStreamDecoder.h"
 
 namespace facebook::torchcodec {
 
 // Note that all these device functions should only be called if the device is
 // not a CPU device. CPU device functions are already implemented in the
-// VideoDecoder implementation.
+// SingleStreamDecoder implementation.
 // These functions should only be called from within an if block like this:
 // if (device.type() != torch::kCPU) {
 //   deviceFunction(device, ...);
@@ -31,9 +31,9 @@ void initializeContextOnCuda(
 void convertAVFrameToFrameOutputOnCuda(
     const torch::Device& device,
-    const VideoDecoder::VideoStreamOptions& videoStreamOptions,
+    const SingleStreamDecoder::VideoStreamOptions& videoStreamOptions,
     UniqueAVFrame& avFrame,
-    VideoDecoder::FrameOutput& frameOutput,
+    SingleStreamDecoder::FrameOutput& frameOutput,
     std::optional<torch::Tensor> preAllocatedOutputTensor = std::nullopt);
 
 void releaseContextOnCuda(
diff --git a/src/torchcodec/_core/VideoDecoder.cpp b/src/torchcodec/_core/SingleStreamDecoder.cpp
similarity index 94%
rename from src/torchcodec/_core/VideoDecoder.cpp
rename to src/torchcodec/_core/SingleStreamDecoder.cpp
index 08954342..725854a8 100644
--- a/src/torchcodec/_core/VideoDecoder.cpp
+++ b/src/torchcodec/_core/SingleStreamDecoder.cpp
@@ -4,7 +4,7 @@
 // This source code is licensed under the BSD-style license found in the
 // LICENSE file in the root directory of this source tree.
-#include "src/torchcodec/_core/VideoDecoder.h" +#include "src/torchcodec/_core/SingleStreamDecoder.h" #include #include #include @@ -49,7 +49,9 @@ int64_t secondsToClosestPts(double seconds, const AVRational& timeBase) { // CONSTRUCTORS, INITIALIZATION, DESTRUCTORS // -------------------------------------------------------------------------- -VideoDecoder::VideoDecoder(const std::string& videoFilePath, SeekMode seekMode) +SingleStreamDecoder::SingleStreamDecoder( + const std::string& videoFilePath, + SeekMode seekMode) : seekMode_(seekMode) { setFFmpegLogLevel(); @@ -66,7 +68,7 @@ VideoDecoder::VideoDecoder(const std::string& videoFilePath, SeekMode seekMode) initializeDecoder(); } -VideoDecoder::VideoDecoder( +SingleStreamDecoder::SingleStreamDecoder( std::unique_ptr context, SeekMode seekMode) : seekMode_(seekMode), avioContextHolder_(std::move(context)) { @@ -95,7 +97,7 @@ VideoDecoder::VideoDecoder( initializeDecoder(); } -VideoDecoder::~VideoDecoder() { +SingleStreamDecoder::~SingleStreamDecoder() { for (auto& [streamIndex, streamInfo] : streamInfos_) { auto& device = streamInfo.videoStreamOptions.device; if (device.type() == torch::kCPU) { @@ -107,7 +109,7 @@ VideoDecoder::~VideoDecoder() { } } -void VideoDecoder::initializeDecoder() { +void SingleStreamDecoder::initializeDecoder() { TORCH_CHECK(!initialized_, "Attempted double initialization."); // In principle, the AVFormatContext should be filled in by the call to @@ -200,7 +202,7 @@ void VideoDecoder::initializeDecoder() { initialized_ = true; } -void VideoDecoder::setFFmpegLogLevel() { +void SingleStreamDecoder::setFFmpegLogLevel() { auto logLevel = AV_LOG_QUIET; const char* logLevelEnv = std::getenv("TORCHCODEC_FFMPEG_LOG_LEVEL"); if (logLevelEnv != nullptr) { @@ -233,7 +235,7 @@ void VideoDecoder::setFFmpegLogLevel() { av_log_set_level(logLevel); } -int VideoDecoder::getBestStreamIndex(AVMediaType mediaType) { +int SingleStreamDecoder::getBestStreamIndex(AVMediaType mediaType) { AVCodecOnlyUseForCallingAVFindBestStream avCodec = nullptr; int streamIndex = av_find_best_stream(formatContext_.get(), mediaType, -1, -1, &avCodec, 0); @@ -244,7 +246,7 @@ int VideoDecoder::getBestStreamIndex(AVMediaType mediaType) { // VIDEO METADATA QUERY API // -------------------------------------------------------------------------- -void VideoDecoder::scanFileAndUpdateMetadataAndIndex() { +void SingleStreamDecoder::scanFileAndUpdateMetadataAndIndex() { if (scannedAllStreams_) { return; } @@ -365,11 +367,12 @@ void VideoDecoder::scanFileAndUpdateMetadataAndIndex() { scannedAllStreams_ = true; } -VideoDecoder::ContainerMetadata VideoDecoder::getContainerMetadata() const { +SingleStreamDecoder::ContainerMetadata +SingleStreamDecoder::getContainerMetadata() const { return containerMetadata_; } -torch::Tensor VideoDecoder::getKeyFrameIndices() { +torch::Tensor SingleStreamDecoder::getKeyFrameIndices() { validateActiveStream(AVMEDIA_TYPE_VIDEO); validateScannedAllStreams("getKeyFrameIndices"); @@ -388,7 +391,7 @@ torch::Tensor VideoDecoder::getKeyFrameIndices() { // ADDING STREAMS API // -------------------------------------------------------------------------- -void VideoDecoder::addStream( +void SingleStreamDecoder::addStream( int streamIndex, AVMediaType mediaType, const torch::Device& device, @@ -471,7 +474,7 @@ void VideoDecoder::addStream( } } -void VideoDecoder::addVideoStream( +void SingleStreamDecoder::addVideoStream( int streamIndex, const VideoStreamOptions& videoStreamOptions) { TORCH_CHECK( @@ -515,14 +518,14 @@ void 
VideoDecoder::addVideoStream( // https://stackoverflow.com/questions/74351955/turn-off-sw-scale-conversion-to-planar-yuv-32-byte-alignment-requirements // so we fall back to filtergraph if the width is not a multiple of 32. auto defaultLibrary = (width % 32 == 0) - ? VideoDecoder::ColorConversionLibrary::SWSCALE - : VideoDecoder::ColorConversionLibrary::FILTERGRAPH; + ? SingleStreamDecoder::ColorConversionLibrary::SWSCALE + : SingleStreamDecoder::ColorConversionLibrary::FILTERGRAPH; streamInfo.colorConversionLibrary = videoStreamOptions.colorConversionLibrary.value_or(defaultLibrary); } -void VideoDecoder::addAudioStream( +void SingleStreamDecoder::addAudioStream( int streamIndex, const AudioStreamOptions& audioStreamOptions) { TORCH_CHECK( @@ -552,7 +555,7 @@ void VideoDecoder::addAudioStream( // HIGH-LEVEL DECODING ENTRY-POINTS // -------------------------------------------------------------------------- -VideoDecoder::FrameOutput VideoDecoder::getNextFrame() { +SingleStreamDecoder::FrameOutput SingleStreamDecoder::getNextFrame() { auto output = getNextFrameInternal(); if (streamInfos_[activeStreamIndex_].avMediaType == AVMEDIA_TYPE_VIDEO) { output.data = maybePermuteHWC2CHW(output.data); @@ -560,7 +563,7 @@ VideoDecoder::FrameOutput VideoDecoder::getNextFrame() { return output; } -VideoDecoder::FrameOutput VideoDecoder::getNextFrameInternal( +SingleStreamDecoder::FrameOutput SingleStreamDecoder::getNextFrameInternal( std::optional preAllocatedOutputTensor) { validateActiveStream(); UniqueAVFrame avFrame = decodeAVFrame( @@ -568,13 +571,14 @@ VideoDecoder::FrameOutput VideoDecoder::getNextFrameInternal( return convertAVFrameToFrameOutput(avFrame, preAllocatedOutputTensor); } -VideoDecoder::FrameOutput VideoDecoder::getFrameAtIndex(int64_t frameIndex) { +SingleStreamDecoder::FrameOutput SingleStreamDecoder::getFrameAtIndex( + int64_t frameIndex) { auto frameOutput = getFrameAtIndexInternal(frameIndex); frameOutput.data = maybePermuteHWC2CHW(frameOutput.data); return frameOutput; } -VideoDecoder::FrameOutput VideoDecoder::getFrameAtIndexInternal( +SingleStreamDecoder::FrameOutput SingleStreamDecoder::getFrameAtIndexInternal( int64_t frameIndex, std::optional preAllocatedOutputTensor) { validateActiveStream(AVMEDIA_TYPE_VIDEO); @@ -589,7 +593,7 @@ VideoDecoder::FrameOutput VideoDecoder::getFrameAtIndexInternal( return getNextFrameInternal(preAllocatedOutputTensor); } -VideoDecoder::FrameBatchOutput VideoDecoder::getFramesAtIndices( +SingleStreamDecoder::FrameBatchOutput SingleStreamDecoder::getFramesAtIndices( const std::vector& frameIndices) { validateActiveStream(AVMEDIA_TYPE_VIDEO); @@ -648,8 +652,10 @@ VideoDecoder::FrameBatchOutput VideoDecoder::getFramesAtIndices( return frameBatchOutput; } -VideoDecoder::FrameBatchOutput -VideoDecoder::getFramesInRange(int64_t start, int64_t stop, int64_t step) { +SingleStreamDecoder::FrameBatchOutput SingleStreamDecoder::getFramesInRange( + int64_t start, + int64_t stop, + int64_t step) { validateActiveStream(AVMEDIA_TYPE_VIDEO); const auto& streamMetadata = @@ -680,7 +686,8 @@ VideoDecoder::getFramesInRange(int64_t start, int64_t stop, int64_t step) { return frameBatchOutput; } -VideoDecoder::FrameOutput VideoDecoder::getFramePlayedAt(double seconds) { +SingleStreamDecoder::FrameOutput SingleStreamDecoder::getFramePlayedAt( + double seconds) { validateActiveStream(AVMEDIA_TYPE_VIDEO); StreamInfo& streamInfo = streamInfos_[activeStreamIndex_]; double frameStartTime = @@ -720,7 +727,7 @@ VideoDecoder::FrameOutput VideoDecoder::getFramePlayedAt(double 
seconds) { return frameOutput; } -VideoDecoder::FrameBatchOutput VideoDecoder::getFramesPlayedAt( +SingleStreamDecoder::FrameBatchOutput SingleStreamDecoder::getFramesPlayedAt( const std::vector& timestamps) { validateActiveStream(AVMEDIA_TYPE_VIDEO); @@ -750,7 +757,8 @@ VideoDecoder::FrameBatchOutput VideoDecoder::getFramesPlayedAt( return getFramesAtIndices(frameIndices); } -VideoDecoder::FrameBatchOutput VideoDecoder::getFramesPlayedInRange( +SingleStreamDecoder::FrameBatchOutput +SingleStreamDecoder::getFramesPlayedInRange( double startSeconds, double stopSeconds) { validateActiveStream(AVMEDIA_TYPE_VIDEO); @@ -883,7 +891,8 @@ VideoDecoder::FrameBatchOutput VideoDecoder::getFramesPlayedInRange( // [2] If you're brave and curious, you can read the long "Seek offset for // audio" note in https://github.com/pytorch/torchcodec/pull/507/files, which // sums up past (and failed) attemps at working around this issue. -VideoDecoder::AudioFramesOutput VideoDecoder::getFramesPlayedInRangeAudio( +SingleStreamDecoder::AudioFramesOutput +SingleStreamDecoder::getFramesPlayedInRangeAudio( double startSeconds, std::optional stopSecondsOptional) { validateActiveStream(AVMEDIA_TYPE_AUDIO); @@ -966,7 +975,7 @@ VideoDecoder::AudioFramesOutput VideoDecoder::getFramesPlayedInRangeAudio( // SEEKING APIs // -------------------------------------------------------------------------- -void VideoDecoder::setCursorPtsInSeconds(double seconds) { +void SingleStreamDecoder::setCursorPtsInSeconds(double seconds) { // We don't allow public audio decoding APIs to seek, see [Audio Decoding // Design] validateActiveStream(AVMEDIA_TYPE_VIDEO); @@ -974,7 +983,7 @@ void VideoDecoder::setCursorPtsInSeconds(double seconds) { secondsToClosestPts(seconds, streamInfos_[activeStreamIndex_].timeBase)); } -void VideoDecoder::setCursor(int64_t pts) { +void SingleStreamDecoder::setCursor(int64_t pts) { cursorWasJustSet_ = true; cursor_ = pts; } @@ -1004,7 +1013,7 @@ I P P P I P P P I P P I P P I P (2) is more efficient than (1) if there is an I frame between x and y. */ -bool VideoDecoder::canWeAvoidSeeking() const { +bool SingleStreamDecoder::canWeAvoidSeeking() const { const StreamInfo& streamInfo = streamInfos_.at(activeStreamIndex_); if (streamInfo.avMediaType == AVMEDIA_TYPE_AUDIO) { // For audio, we only need to seek if a backwards seek was requested within @@ -1037,7 +1046,7 @@ bool VideoDecoder::canWeAvoidSeeking() const { // This method looks at currentPts and desiredPts and seeks in the // AVFormatContext if it is needed. We can skip seeking in certain cases. See // the comment of canWeAvoidSeeking() for details. 
-void VideoDecoder::maybeSeekToBeforeDesiredPts() {
+void SingleStreamDecoder::maybeSeekToBeforeDesiredPts() {
   validateActiveStream();
   StreamInfo& streamInfo = streamInfos_[activeStreamIndex_];
@@ -1081,7 +1090,7 @@ void VideoDecoder::maybeSeekToBeforeDesiredPts() {
 // LOW-LEVEL DECODING
 // --------------------------------------------------------------------------
 
-UniqueAVFrame VideoDecoder::decodeAVFrame(
+UniqueAVFrame SingleStreamDecoder::decodeAVFrame(
     std::function<bool(const UniqueAVFrame&)> filterFunction) {
   validateActiveStream();
 
@@ -1178,7 +1187,7 @@ UniqueAVFrame VideoDecoder::decodeAVFrame(
 
   if (status < AVSUCCESS) {
     if (reachedEOF || status == AVERROR_EOF) {
-      throw VideoDecoder::EndOfFileException(
+      throw SingleStreamDecoder::EndOfFileException(
           "Requested next frame while there are no more frames left to "
          "decode.");
     }
@@ -1203,7 +1212,8 @@
 // AVFRAME <-> FRAME OUTPUT CONVERSION
 // --------------------------------------------------------------------------
 
-VideoDecoder::FrameOutput VideoDecoder::convertAVFrameToFrameOutput(
+SingleStreamDecoder::FrameOutput
+SingleStreamDecoder::convertAVFrameToFrameOutput(
     UniqueAVFrame& avFrame,
     std::optional<torch::Tensor> preAllocatedOutputTensor) {
   // Convert the frame to tensor.
@@ -1243,7 +1253,7 @@
 // TODO: Figure out whether that's possible!
 // Dimension order of the preAllocatedOutputTensor must be HWC, regardless of
 // `dimension_order` parameter. It's up to callers to re-shape it if needed.
-void VideoDecoder::convertAVFrameToFrameOutputOnCPU(
+void SingleStreamDecoder::convertAVFrameToFrameOutputOnCPU(
     UniqueAVFrame& avFrame,
     FrameOutput& frameOutput,
     std::optional<torch::Tensor> preAllocatedOutputTensor) {
@@ -1343,7 +1353,7 @@ void VideoDecoder::convertAVFrameToFrameOutputOnCPU(
   }
 }
 
-int VideoDecoder::convertAVFrameToTensorUsingSwsScale(
+int SingleStreamDecoder::convertAVFrameToTensorUsingSwsScale(
     const UniqueAVFrame& avFrame,
     torch::Tensor& outputTensor) {
   StreamInfo& activeStreamInfo = streamInfos_[activeStreamIndex_];
@@ -1363,7 +1373,7 @@ int VideoDecoder::convertAVFrameToTensorUsingSwsScale(
   return resultHeight;
 }
 
-torch::Tensor VideoDecoder::convertAVFrameToTensorUsingFilterGraph(
+torch::Tensor SingleStreamDecoder::convertAVFrameToTensorUsingFilterGraph(
     const UniqueAVFrame& avFrame) {
   FilterGraphContext& filterGraphContext =
       streamInfos_[activeStreamIndex_].filterGraphContext;
@@ -1391,7 +1401,7 @@ torch::Tensor VideoDecoder::convertAVFrameToTensorUsingFilterGraph(
       filteredAVFramePtr->data[0], shape, strides, deleter, {torch::kUInt8});
 }
 
-void VideoDecoder::convertAudioAVFrameToFrameOutputOnCPU(
+void SingleStreamDecoder::convertAudioAVFrameToFrameOutputOnCPU(
     UniqueAVFrame& srcAVFrame,
     FrameOutput& frameOutput) {
   AVSampleFormat sourceSampleFormat =
@@ -1446,7 +1456,7 @@ void VideoDecoder::convertAudioAVFrameToFrameOutputOnCPU(
   }
 }
 
-UniqueAVFrame VideoDecoder::convertAudioAVFrameSampleFormatAndSampleRate(
+UniqueAVFrame SingleStreamDecoder::convertAudioAVFrameSampleFormatAndSampleRate(
     const UniqueAVFrame& srcAVFrame,
     AVSampleFormat sourceSampleFormat,
     AVSampleFormat desiredSampleFormat,
@@ -1517,7 +1527,7 @@ UniqueAVFrame VideoDecoder::convertAudioAVFrameSampleFormatAndSampleRate(
   return convertedAVFrame;
 }
 
-std::optional<torch::Tensor> VideoDecoder::maybeFlushSwrBuffers() {
+std::optional<torch::Tensor> SingleStreamDecoder::maybeFlushSwrBuffers() {
   // When sample rate conversion is involved, swresample buffers some of the
   // samples in-between calls to swr_convert (see the libswresample docs).
   // That's because the last few samples in a given frame require future samples
@@ -1558,7 +1568,7 @@ std::optional<torch::Tensor> VideoDecoder::maybeFlushSwrBuffers() {
 // OUTPUT ALLOCATION AND SHAPE CONVERSION
 // --------------------------------------------------------------------------
 
-VideoDecoder::FrameBatchOutput::FrameBatchOutput(
+SingleStreamDecoder::FrameBatchOutput::FrameBatchOutput(
     int64_t numFrames,
     const VideoStreamOptions& videoStreamOptions,
     const StreamMetadata& streamMetadata)
@@ -1598,7 +1608,8 @@ torch::Tensor allocateEmptyHWCTensor(
 // or 4D.
 // Calling permute() is guaranteed to return a view as per the docs:
 // https://pytorch.org/docs/stable/generated/torch.permute.html
-torch::Tensor VideoDecoder::maybePermuteHWC2CHW(torch::Tensor& hwcTensor) {
+torch::Tensor SingleStreamDecoder::maybePermuteHWC2CHW(
+    torch::Tensor& hwcTensor) {
   if (streamInfos_[activeStreamIndex_].videoStreamOptions.dimensionOrder ==
       "NHWC") {
     return hwcTensor;
@@ -1621,8 +1632,8 @@ torch::Tensor VideoDecoder::maybePermuteHWC2CHW(torch::Tensor& hwcTensor) {
 // COLOR CONVERSION UTILS AND INITIALIZERS
 // --------------------------------------------------------------------------
 
-bool VideoDecoder::DecodedFrameContext::operator==(
-    const VideoDecoder::DecodedFrameContext& other) {
+bool SingleStreamDecoder::DecodedFrameContext::operator==(
+    const SingleStreamDecoder::DecodedFrameContext& other) {
   return decodedWidth == other.decodedWidth &&
       decodedHeight == other.decodedHeight &&
       decodedFormat == other.decodedFormat &&
@@ -1630,12 +1641,12 @@ bool VideoDecoder::DecodedFrameContext::operator==(
       expectedHeight == other.expectedHeight;
 }
 
-bool VideoDecoder::DecodedFrameContext::operator!=(
-    const VideoDecoder::DecodedFrameContext& other) {
+bool SingleStreamDecoder::DecodedFrameContext::operator!=(
+    const SingleStreamDecoder::DecodedFrameContext& other) {
   return !(*this == other);
 }
 
-void VideoDecoder::createFilterGraph(
+void SingleStreamDecoder::createFilterGraph(
     StreamInfo& streamInfo,
     int expectedOutputHeight,
     int expectedOutputWidth) {
@@ -1741,7 +1752,7 @@ void VideoDecoder::createFilterGraph(
   }
 }
 
-void VideoDecoder::createSwsContext(
+void SingleStreamDecoder::createSwsContext(
     StreamInfo& streamInfo,
     const DecodedFrameContext& frameContext,
     const enum AVColorSpace colorspace) {
@@ -1787,7 +1798,7 @@ void VideoDecoder::createSwsContext(
   streamInfo.swsContext.reset(swsContext);
 }
 
-void VideoDecoder::createSwrContext(
+void SingleStreamDecoder::createSwrContext(
     StreamInfo& streamInfo,
     AVSampleFormat sourceSampleFormat,
     AVSampleFormat desiredSampleFormat,
@@ -1815,7 +1826,7 @@ void VideoDecoder::createSwrContext(
 // PTS <-> INDEX CONVERSIONS
 // --------------------------------------------------------------------------
 
-int VideoDecoder::getKeyFrameIndexForPts(int64_t pts) const {
+int SingleStreamDecoder::getKeyFrameIndexForPts(int64_t pts) const {
   const StreamInfo& streamInfo = streamInfos_.at(activeStreamIndex_);
   if (streamInfo.keyFrames.empty()) {
     return av_index_search_timestamp(
@@ -1825,14 +1836,14 @@ int VideoDecoder::getKeyFrameIndexForPts(int64_t pts) const {
   }
 }
 
-int VideoDecoder::getKeyFrameIndexForPtsUsingScannedIndex(
-    const std::vector<VideoDecoder::FrameInfo>& keyFrames,
+int SingleStreamDecoder::getKeyFrameIndexForPtsUsingScannedIndex(
+    const std::vector<SingleStreamDecoder::FrameInfo>& keyFrames,
     int64_t pts) const {
   auto upperBound = std::upper_bound(
       keyFrames.begin(),
       keyFrames.end(),
       pts,
-      [](int64_t pts, const VideoDecoder::FrameInfo& frameInfo) {
+      [](int64_t pts, const SingleStreamDecoder::FrameInfo& frameInfo) {
        return pts < frameInfo.pts;
      });
   if (upperBound == keyFrames.begin()) {
@@ -1841,7 +1852,7 @@ int VideoDecoder::getKeyFrameIndexForPtsUsingScannedIndex(
   return upperBound - 1 - keyFrames.begin();
 }
 
-int64_t VideoDecoder::secondsToIndexLowerBound(double seconds) {
+int64_t SingleStreamDecoder::secondsToIndexLowerBound(double seconds) {
   auto& streamInfo = streamInfos_[activeStreamIndex_];
   switch (seekMode_) {
     case SeekMode::exact: {
@@ -1868,7 +1879,7 @@ int64_t VideoDecoder::secondsToIndexLowerBound(double seconds) {
   }
 }
 
-int64_t VideoDecoder::secondsToIndexUpperBound(double seconds) {
+int64_t SingleStreamDecoder::secondsToIndexUpperBound(double seconds) {
   auto& streamInfo = streamInfos_[activeStreamIndex_];
   switch (seekMode_) {
     case SeekMode::exact: {
@@ -1895,7 +1906,7 @@ int64_t VideoDecoder::secondsToIndexUpperBound(double seconds) {
   }
 }
 
-int64_t VideoDecoder::getPts(int64_t frameIndex) {
+int64_t SingleStreamDecoder::getPts(int64_t frameIndex) {
   auto& streamInfo = streamInfos_[activeStreamIndex_];
   switch (seekMode_) {
     case SeekMode::exact:
@@ -1918,7 +1929,8 @@ int64_t VideoDecoder::getPts(int64_t frameIndex) {
 // STREAM AND METADATA APIS
 // --------------------------------------------------------------------------
 
-int64_t VideoDecoder::getNumFrames(const StreamMetadata& streamMetadata) {
+int64_t SingleStreamDecoder::getNumFrames(
+    const StreamMetadata& streamMetadata) {
   switch (seekMode_) {
     case SeekMode::exact:
       return streamMetadata.numFramesFromScan.value();
@@ -1933,7 +1945,8 @@ int64_t VideoDecoder::getNumFrames(const StreamMetadata& streamMetadata) {
   }
 }
 
-double VideoDecoder::getMinSeconds(const StreamMetadata& streamMetadata) {
+double SingleStreamDecoder::getMinSeconds(
+    const StreamMetadata& streamMetadata) {
   switch (seekMode_) {
     case SeekMode::exact:
       return streamMetadata.minPtsSecondsFromScan.value();
@@ -1944,7 +1957,8 @@ double VideoDecoder::getMinSeconds(const StreamMetadata& streamMetadata) {
   }
 }
 
-double VideoDecoder::getMaxSeconds(const StreamMetadata& streamMetadata) {
+double SingleStreamDecoder::getMaxSeconds(
+    const StreamMetadata& streamMetadata) {
   switch (seekMode_) {
     case SeekMode::exact:
       return streamMetadata.maxPtsSecondsFromScan.value();
@@ -1963,7 +1977,7 @@ double VideoDecoder::getMaxSeconds(const StreamMetadata& streamMetadata) {
 // VALIDATION UTILS
 // --------------------------------------------------------------------------
 
-void VideoDecoder::validateActiveStream(
+void SingleStreamDecoder::validateActiveStream(
     std::optional<AVMediaType> avMediaType) {
   auto errorMsg =
       "Provided stream index=" + std::to_string(activeStreamIndex_) +
@@ -1988,14 +2002,14 @@
   }
 }
 
-void VideoDecoder::validateScannedAllStreams(const std::string& msg) {
+void SingleStreamDecoder::validateScannedAllStreams(const std::string& msg) {
   if (!scannedAllStreams_) {
     throw std::runtime_error(
         "Must scan all streams to update metadata before calling " + msg);
   }
 }
 
-void VideoDecoder::validateFrameIndex(
+void SingleStreamDecoder::validateFrameIndex(
     const StreamMetadata& streamMetadata,
     int64_t frameIndex) {
   int64_t numFrames = getNumFrames(streamMetadata);
@@ -2010,13 +2024,13 @@
 // MORALLY PRIVATE UTILS
 // --------------------------------------------------------------------------
 
-VideoDecoder::DecodeStats VideoDecoder::getDecodeStats() const {
+SingleStreamDecoder::DecodeStats SingleStreamDecoder::getDecodeStats() const {
   return decodeStats_;
 }
 
 std::ostream& operator<<(
     std::ostream& os,
-    const VideoDecoder::DecodeStats& stats) {
+    const SingleStreamDecoder::DecodeStats& stats) {
   os << "DecodeStats{"
      << "numFramesReceivedByDecoder=" << stats.numFramesReceivedByDecoder
      << ", numPacketsRead=" << stats.numPacketsRead
@@ -2028,11 +2042,11 @@ std::ostream& operator<<(
   return os;
 }
 
-void VideoDecoder::resetDecodeStats() {
+void SingleStreamDecoder::resetDecodeStats() {
   decodeStats_ = DecodeStats{};
 }
 
-double VideoDecoder::getPtsSecondsForFrame(int64_t frameIndex) {
+double SingleStreamDecoder::getPtsSecondsForFrame(int64_t frameIndex) {
   validateActiveStream(AVMEDIA_TYPE_VIDEO);
   validateScannedAllStreams("getPtsSecondsForFrame");
 
@@ -2054,26 +2068,26 @@ FrameDims getHeightAndWidthFromResizedAVFrame(const AVFrame& resizedAVFrame) {
 }
 
 FrameDims getHeightAndWidthFromOptionsOrMetadata(
-    const VideoDecoder::VideoStreamOptions& videoStreamOptions,
-    const VideoDecoder::StreamMetadata& streamMetadata) {
+    const SingleStreamDecoder::VideoStreamOptions& videoStreamOptions,
+    const SingleStreamDecoder::StreamMetadata& streamMetadata) {
   return FrameDims(
       videoStreamOptions.height.value_or(*streamMetadata.height),
       videoStreamOptions.width.value_or(*streamMetadata.width));
 }
 
 FrameDims getHeightAndWidthFromOptionsOrAVFrame(
-    const VideoDecoder::VideoStreamOptions& videoStreamOptions,
+    const SingleStreamDecoder::VideoStreamOptions& videoStreamOptions,
     const UniqueAVFrame& avFrame) {
   return FrameDims(
       videoStreamOptions.height.value_or(avFrame->height),
       videoStreamOptions.width.value_or(avFrame->width));
 }
 
-VideoDecoder::SeekMode seekModeFromString(std::string_view seekMode) {
+SingleStreamDecoder::SeekMode seekModeFromString(std::string_view seekMode) {
   if (seekMode == "exact") {
-    return VideoDecoder::SeekMode::exact;
+    return SingleStreamDecoder::SeekMode::exact;
   } else if (seekMode == "approximate") {
-    return VideoDecoder::SeekMode::approximate;
+    return SingleStreamDecoder::SeekMode::approximate;
   } else {
     TORCH_CHECK(false, "Invalid seek mode: " + std::string(seekMode));
   }
diff --git a/src/torchcodec/_core/VideoDecoder.h b/src/torchcodec/_core/SingleStreamDecoder.h
similarity index 96%
rename from src/torchcodec/_core/VideoDecoder.h
rename to src/torchcodec/_core/SingleStreamDecoder.h
index dacd09f3..df333d03 100644
--- a/src/torchcodec/_core/VideoDecoder.h
+++ b/src/torchcodec/_core/SingleStreamDecoder.h
@@ -17,12 +17,12 @@
 
 namespace facebook::torchcodec {
 
-// The VideoDecoder class can be used to decode video frames to Tensors.
-// Note that VideoDecoder is not thread-safe.
+// The SingleStreamDecoder class can be used to decode video frames to Tensors.
+// Note that SingleStreamDecoder is not thread-safe.
 // Do not call non-const APIs concurrently on the same object.
-class VideoDecoder {
+class SingleStreamDecoder {
  public:
-  ~VideoDecoder();
+  ~SingleStreamDecoder();
 
   // --------------------------------------------------------------------------
   // CONSTRUCTION API
   // --------------------------------------------------------------------------
@@ -30,16 +30,16 @@ class VideoDecoder {
 
   enum class SeekMode { exact, approximate };
 
-  // Creates a VideoDecoder from the video at videoFilePath.
-  explicit VideoDecoder(
+  // Creates a SingleStreamDecoder from the video at videoFilePath.
+  explicit SingleStreamDecoder(
       const std::string& videoFilePath,
      SeekMode seekMode = SeekMode::exact);
 
-  // Creates a VideoDecoder using the provided AVIOContext inside the
+  // Creates a SingleStreamDecoder using the provided AVIOContext inside the
   // AVIOContextHolder. The AVIOContextHolder is the base class, and the
   // derived class will have specialized how the custom read, seek and writes
   // work.
-  explicit VideoDecoder(
+  explicit SingleStreamDecoder(
       std::unique_ptr<AVIOContextHolder> context,
       SeekMode seekMode = SeekMode::exact);
 
@@ -443,7 +443,7 @@ class VideoDecoder {
   // We build this index by scanning the file in
   // scanFileAndUpdateMetadataAndIndex
   int getKeyFrameIndexForPtsUsingScannedIndex(
-      const std::vector<VideoDecoder::FrameInfo>& keyFrames,
+      const std::vector<SingleStreamDecoder::FrameInfo>& keyFrames,
       int64_t pts) const;
 
   int64_t secondsToIndexLowerBound(double seconds);
@@ -568,11 +568,11 @@ struct FrameDims {
 FrameDims getHeightAndWidthFromResizedAVFrame(const AVFrame& resizedAVFrame);
 
 FrameDims getHeightAndWidthFromOptionsOrMetadata(
-    const VideoDecoder::VideoStreamOptions& videoStreamOptions,
-    const VideoDecoder::StreamMetadata& streamMetadata);
+    const SingleStreamDecoder::VideoStreamOptions& videoStreamOptions,
+    const SingleStreamDecoder::StreamMetadata& streamMetadata);
 
 FrameDims getHeightAndWidthFromOptionsOrAVFrame(
-    const VideoDecoder::VideoStreamOptions& videoStreamOptions,
+    const SingleStreamDecoder::VideoStreamOptions& videoStreamOptions,
     const UniqueAVFrame& avFrame);
 
 torch::Tensor allocateEmptyHWCTensor(
@@ -581,11 +581,11 @@ torch::Tensor allocateEmptyHWCTensor(
     torch::Device device,
     std::optional<int64_t> numFrames = std::nullopt);
 
-// Prints the VideoDecoder::DecodeStats to the ostream.
+// Prints the SingleStreamDecoder::DecodeStats to the ostream.
 std::ostream& operator<<(
     std::ostream& os,
-    const VideoDecoder::DecodeStats& stats);
+    const SingleStreamDecoder::DecodeStats& stats);
 
-VideoDecoder::SeekMode seekModeFromString(std::string_view seekMode);
+SingleStreamDecoder::SeekMode seekModeFromString(std::string_view seekMode);
 
 } // namespace facebook::torchcodec
diff --git a/src/torchcodec/_core/custom_ops.cpp b/src/torchcodec/_core/custom_ops.cpp
index af9602e0..45324908 100644
--- a/src/torchcodec/_core/custom_ops.cpp
+++ b/src/torchcodec/_core/custom_ops.cpp
@@ -11,7 +11,7 @@
 #include "c10/core/SymIntArrayRef.h"
 #include "c10/util/Exception.h"
 #include "src/torchcodec/_core/AVIOBytesContext.h"
-#include "src/torchcodec/_core/VideoDecoder.h"
+#include "src/torchcodec/_core/SingleStreamDecoder.h"
 
 namespace facebook::torchcodec {
 
@@ -66,21 +66,22 @@ TORCH_LIBRARY(torchcodec_ns, m) {
 namespace {
 
 at::Tensor wrapDecoderPointerToTensor(
-    std::unique_ptr<VideoDecoder> uniqueDecoder) {
-  VideoDecoder* decoder = uniqueDecoder.release();
+    std::unique_ptr<SingleStreamDecoder> uniqueDecoder) {
+  SingleStreamDecoder* decoder = uniqueDecoder.release();
 
   auto deleter = [decoder](void*) { delete decoder; };
-  at::Tensor tensor =
-      at::from_blob(decoder, {sizeof(VideoDecoder*)}, deleter, {at::kLong});
-  auto videoDecoder = static_cast<VideoDecoder*>(tensor.mutable_data_ptr());
+  at::Tensor tensor = at::from_blob(
+      decoder, {sizeof(SingleStreamDecoder*)}, deleter, {at::kLong});
+  auto videoDecoder =
+      static_cast<SingleStreamDecoder*>(tensor.mutable_data_ptr());
   TORCH_CHECK_EQ(videoDecoder, decoder) << "videoDecoder=" << videoDecoder;
   return tensor;
 }
 
-VideoDecoder* unwrapTensorToGetDecoder(at::Tensor& tensor) {
+SingleStreamDecoder* unwrapTensorToGetDecoder(at::Tensor& tensor) {
   TORCH_INTERNAL_ASSERT(tensor.is_contiguous());
   void* buffer = tensor.mutable_data_ptr();
-  VideoDecoder* decoder = static_cast<VideoDecoder*>(buffer);
+  SingleStreamDecoder* decoder = static_cast<SingleStreamDecoder*>(buffer);
   return decoder;
 }
 
@@ -92,7 +93,7 @@ VideoDecoder* unwrapTensorToGetDecoder(at::Tensor& tensor) {
 // under torch.compile().
 using OpsFrameOutput = std::tuple<at::Tensor, at::Tensor, at::Tensor>;
 
-OpsFrameOutput makeOpsFrameOutput(VideoDecoder::FrameOutput& frame) {
+OpsFrameOutput makeOpsFrameOutput(SingleStreamDecoder::FrameOutput& frame) {
   return std::make_tuple(
       frame.data,
       torch::tensor(frame.ptsSeconds, torch::dtype(torch::kFloat64)),
@@ -111,7 +112,7 @@ OpsFrameOutput makeOpsFrameOutput(VideoDecoder::FrameOutput& frame) {
 using OpsFrameBatchOutput = std::tuple<at::Tensor, at::Tensor, at::Tensor>;
 
 OpsFrameBatchOutput makeOpsFrameBatchOutput(
-    VideoDecoder::FrameBatchOutput& batch) {
+    SingleStreamDecoder::FrameBatchOutput& batch) {
   return std::make_tuple(batch.data, batch.ptsSeconds, batch.durationSeconds);
 }
 
@@ -122,7 +123,7 @@ OpsFrameBatchOutput makeOpsFrameBatchOutput(
 using OpsAudioFramesOutput = std::tuple<at::Tensor, at::Tensor>;
 
 OpsAudioFramesOutput makeOpsAudioFramesOutput(
-    VideoDecoder::AudioFramesOutput& audioFrames) {
+    SingleStreamDecoder::AudioFramesOutput& audioFrames) {
   return std::make_tuple(
       audioFrames.data,
       torch::tensor(audioFrames.ptsSeconds, torch::dtype(torch::kFloat64)));
@@ -156,25 +157,25 @@ std::string mapToJson(const std::map<std::string, std::string>& metadataMap) {
 // Implementations for the operators
 // ==============================
 
-// Create a VideoDecoder from file and wrap the pointer in a tensor.
+// Create a SingleStreamDecoder from file and wrap the pointer in a tensor.
 at::Tensor create_from_file(
     std::string_view filename,
     std::optional<std::string_view> seek_mode = std::nullopt) {
   std::string filenameStr(filename);
-  VideoDecoder::SeekMode realSeek = VideoDecoder::SeekMode::exact;
+  SingleStreamDecoder::SeekMode realSeek = SingleStreamDecoder::SeekMode::exact;
   if (seek_mode.has_value()) {
     realSeek = seekModeFromString(seek_mode.value());
   }
 
-  std::unique_ptr<VideoDecoder> uniqueDecoder =
-      std::make_unique<VideoDecoder>(filenameStr, realSeek);
+  std::unique_ptr<SingleStreamDecoder> uniqueDecoder =
+      std::make_unique<SingleStreamDecoder>(filenameStr, realSeek);
 
   return wrapDecoderPointerToTensor(std::move(uniqueDecoder));
 }
 
-// Create a VideoDecoder from the actual bytes of a video and wrap the pointer
-// in a tensor. The VideoDecoder will decode the provided bytes.
+// Create a SingleStreamDecoder from the actual bytes of a video and wrap the
+// pointer in a tensor. The SingleStreamDecoder will decode the provided bytes.
 at::Tensor create_from_tensor(
     at::Tensor video_tensor,
     std::optional<std::string_view> seek_mode = std::nullopt) {
@@ -185,21 +186,21 @@ at::Tensor create_from_tensor(
   void* data = video_tensor.mutable_data_ptr();
   size_t length = video_tensor.numel();
 
-  VideoDecoder::SeekMode realSeek = VideoDecoder::SeekMode::exact;
+  SingleStreamDecoder::SeekMode realSeek = SingleStreamDecoder::SeekMode::exact;
   if (seek_mode.has_value()) {
     realSeek = seekModeFromString(seek_mode.value());
   }
 
   auto contextHolder = std::make_unique<AVIOBytesContext>(data, length);
 
-  std::unique_ptr<VideoDecoder> uniqueDecoder =
-      std::make_unique<VideoDecoder>(std::move(contextHolder), realSeek);
+  std::unique_ptr<SingleStreamDecoder> uniqueDecoder =
+      std::make_unique<SingleStreamDecoder>(std::move(contextHolder), realSeek);
   return wrapDecoderPointerToTensor(std::move(uniqueDecoder));
 }
 
 at::Tensor _convert_to_tensor(int64_t decoder_ptr) {
-  auto decoder = reinterpret_cast<VideoDecoder*>(decoder_ptr);
-  std::unique_ptr<VideoDecoder> uniqueDecoder(decoder);
+  auto decoder = reinterpret_cast<SingleStreamDecoder*>(decoder_ptr);
+  std::unique_ptr<SingleStreamDecoder> uniqueDecoder(decoder);
   return wrapDecoderPointerToTensor(std::move(uniqueDecoder));
 }
 
@@ -212,7 +213,7 @@ void _add_video_stream(
     std::optional<int64_t> stream_index = std::nullopt,
     std::optional<std::string_view> device = std::nullopt,
     std::optional<std::string_view> color_conversion_library = std::nullopt) {
-  VideoDecoder::VideoStreamOptions videoStreamOptions;
+  SingleStreamDecoder::VideoStreamOptions videoStreamOptions;
   videoStreamOptions.width = width;
   videoStreamOptions.height = height;
   videoStreamOptions.ffmpegThreadCount = num_threads;
@@ -226,10 +227,10 @@ void _add_video_stream(
     std::string stdColorConversionLibrary{color_conversion_library.value()};
     if (stdColorConversionLibrary == "filtergraph") {
       videoStreamOptions.colorConversionLibrary =
-          VideoDecoder::ColorConversionLibrary::FILTERGRAPH;
+          SingleStreamDecoder::ColorConversionLibrary::FILTERGRAPH;
     } else if (stdColorConversionLibrary == "swscale") {
       videoStreamOptions.colorConversionLibrary =
-          VideoDecoder::ColorConversionLibrary::SWSCALE;
+          SingleStreamDecoder::ColorConversionLibrary::SWSCALE;
     } else {
       throw std::runtime_error(
           "Invalid color_conversion_library=" + stdColorConversionLibrary +
@@ -276,7 +277,7 @@ void add_audio_stream(
     at::Tensor& decoder,
     std::optional<int64_t> stream_index = std::nullopt,
     std::optional<int64_t> sample_rate = std::nullopt) {
-  VideoDecoder::AudioStreamOptions audioStreamOptions;
+  SingleStreamDecoder::AudioStreamOptions audioStreamOptions;
   audioStreamOptions.sampleRate = sample_rate;
 
   auto videoDecoder = unwrapTensorToGetDecoder(decoder);
@@ -285,7 +286,8 @@ void add_audio_stream(
 
 // Seek to a particular presentation timestamp in the video in seconds.
 void seek_to_pts(at::Tensor& decoder, double seconds) {
-  auto videoDecoder = static_cast<VideoDecoder*>(decoder.mutable_data_ptr());
+  auto videoDecoder =
+      static_cast<SingleStreamDecoder*>(decoder.mutable_data_ptr());
   videoDecoder->setCursorPtsInSeconds(seconds);
 }
 
@@ -293,10 +295,10 @@ void seek_to_pts(at::Tensor& decoder, double seconds) {
 // duration as tensors.
 OpsFrameOutput get_next_frame(at::Tensor& decoder) {
   auto videoDecoder = unwrapTensorToGetDecoder(decoder);
-  VideoDecoder::FrameOutput result;
+  SingleStreamDecoder::FrameOutput result;
   try {
     result = videoDecoder->getNextFrame();
-  } catch (const VideoDecoder::EndOfFileException& e) {
+  } catch (const SingleStreamDecoder::EndOfFileException& e) {
     C10_THROW_ERROR(IndexError, e.what());
   }
   return makeOpsFrameOutput(result);
@@ -307,10 +309,10 @@ OpsFrameOutput get_next_frame(at::Tensor& decoder) {
 // given timestamp T has T >= PTS and T < PTS + Duration.
 OpsFrameOutput get_frame_at_pts(at::Tensor& decoder, double seconds) {
   auto videoDecoder = unwrapTensorToGetDecoder(decoder);
-  VideoDecoder::FrameOutput result;
+  SingleStreamDecoder::FrameOutput result;
   try {
     result = videoDecoder->getFramePlayedAt(seconds);
-  } catch (const VideoDecoder::EndOfFileException& e) {
+  } catch (const SingleStreamDecoder::EndOfFileException& e) {
     C10_THROW_ERROR(IndexError, e.what());
   }
   return makeOpsFrameOutput(result);
@@ -406,7 +408,7 @@ torch::Tensor _get_key_frame_indices(at::Tensor& decoder) {
 
 std::string get_json_metadata(at::Tensor& decoder) {
   auto videoDecoder = unwrapTensorToGetDecoder(decoder);
 
-  VideoDecoder::ContainerMetadata videoMetadata =
+  SingleStreamDecoder::ContainerMetadata videoMetadata =
       videoDecoder->getContainerMetadata();
   auto maybeBestVideoStreamIndex = videoMetadata.bestVideoStreamIndex;
diff --git a/src/torchcodec/_core/pybind_ops.cpp b/src/torchcodec/_core/pybind_ops.cpp
index 8c2f0c77..6f873f5a 100644
--- a/src/torchcodec/_core/pybind_ops.cpp
+++ b/src/torchcodec/_core/pybind_ops.cpp
@@ -10,7 +10,7 @@
 #include
 #include "src/torchcodec/_core/AVIOFileLikeContext.h"
-#include "src/torchcodec/_core/VideoDecoder.h"
+#include "src/torchcodec/_core/SingleStreamDecoder.h"
 
 namespace py = pybind11;
 
@@ -26,15 +26,15 @@ namespace facebook::torchcodec {
 int64_t create_from_file_like(
     py::object file_like,
     std::optional<std::string_view> seek_mode) {
-  VideoDecoder::SeekMode realSeek = VideoDecoder::SeekMode::exact;
+  SingleStreamDecoder::SeekMode realSeek = SingleStreamDecoder::SeekMode::exact;
   if (seek_mode.has_value()) {
     realSeek = seekModeFromString(seek_mode.value());
   }
 
   auto avioContextHolder = std::make_unique<AVIOFileLikeContext>(file_like);
 
-  VideoDecoder* decoder =
-      new VideoDecoder(std::move(avioContextHolder), realSeek);
+  SingleStreamDecoder* decoder =
+      new SingleStreamDecoder(std::move(avioContextHolder), realSeek);
   return reinterpret_cast<int64_t>(decoder);
 }
diff --git a/test/decoders/VideoDecoderTest.cpp b/test/decoders/VideoDecoderTest.cpp
index e617afc7..7a81e9dd 100644
--- a/test/decoders/VideoDecoderTest.cpp
+++ b/test/decoders/VideoDecoderTest.cpp
@@ -4,8 +4,8 @@
 // This source code is licensed under the BSD-style license found in the
 // LICENSE file in the root directory of this source tree.
-#include "src/torchcodec/_core/VideoDecoder.h" #include "src/torchcodec/_core/AVIOBytesContext.h" +#include "src/torchcodec/_core/SingleStreamDecoder.h" #include #include @@ -39,9 +39,9 @@ std::string getResourcePath(const std::string& filename) { return filepath; } -class VideoDecoderTest : public testing::TestWithParam { +class SingleStreamDecoderTest : public testing::TestWithParam { protected: - std::unique_ptr createDecoderFromPath( + std::unique_ptr createDecoderFromPath( const std::string& filepath, bool useMemoryBuffer) { if (useMemoryBuffer) { @@ -53,22 +53,23 @@ class VideoDecoderTest : public testing::TestWithParam { void* buffer = content_.data(); size_t length = content_.length(); auto contextHolder = std::make_unique(buffer, length); - return std::make_unique( - std::move(contextHolder), VideoDecoder::SeekMode::approximate); + return std::make_unique( + std::move(contextHolder), SingleStreamDecoder::SeekMode::approximate); } else { - return std::make_unique( - filepath, VideoDecoder::SeekMode::approximate); + return std::make_unique( + filepath, SingleStreamDecoder::SeekMode::approximate); } } std::string content_; }; -TEST_P(VideoDecoderTest, ReturnsFpsAndDurationForVideoInMetadata) { +TEST_P(SingleStreamDecoderTest, ReturnsFpsAndDurationForVideoInMetadata) { std::string path = getResourcePath("nasa_13013.mp4"); - std::unique_ptr decoder = + std::unique_ptr decoder = createDecoderFromPath(path, GetParam()); - VideoDecoder::ContainerMetadata metadata = decoder->getContainerMetadata(); + SingleStreamDecoder::ContainerMetadata metadata = + decoder->getContainerMetadata(); EXPECT_EQ(metadata.numAudioStreams, 2); EXPECT_EQ(metadata.numVideoStreams, 2); #if LIBAVFORMAT_VERSION_MAJOR >= 60 @@ -95,8 +96,8 @@ TEST_P(VideoDecoderTest, ReturnsFpsAndDurationForVideoInMetadata) { EXPECT_EQ(*videoStream1.numFramesFromScan, 390); } -TEST(VideoDecoderTest, MissingVideoFileThrowsException) { - EXPECT_THROW(VideoDecoder("/this/file/does/not/exist"), c10::Error); +TEST(SingleStreamDecoderTest, MissingVideoFileThrowsException) { + EXPECT_THROW(SingleStreamDecoder("/this/file/does/not/exist"), c10::Error); } void dumpTensorToDisk( @@ -139,15 +140,16 @@ double computeAverageCosineSimilarity( // TEST(DecoderOptionsTest, ConvertsFromStringToOptions) { // std::string optionsString = // "ffmpeg_thread_count=3,dimension_order=NCHW,width=100,height=120"; -// VideoDecoder::DecoderOptions options = -// VideoDecoder::DecoderOptions(optionsString); +// SingleStreamDecoder::DecoderOptions options = +// SingleStreamDecoder::DecoderOptions(optionsString); // EXPECT_EQ(options.ffmpegThreadCount, 3); // } -TEST(VideoDecoderTest, RespectsWidthAndHeightFromOptions) { +TEST(SingleStreamDecoderTest, RespectsWidthAndHeightFromOptions) { std::string path = getResourcePath("nasa_13013.mp4"); - std::unique_ptr decoder = std::make_unique(path); - VideoDecoder::VideoStreamOptions videoStreamOptions; + std::unique_ptr decoder = + std::make_unique(path); + SingleStreamDecoder::VideoStreamOptions videoStreamOptions; videoStreamOptions.width = 100; videoStreamOptions.height = 120; decoder->addVideoStream(-1, videoStreamOptions); @@ -155,19 +157,20 @@ TEST(VideoDecoderTest, RespectsWidthAndHeightFromOptions) { EXPECT_EQ(tensor.sizes(), std::vector({3, 120, 100})); } -TEST(VideoDecoderTest, RespectsOutputTensorDimensionOrderFromOptions) { +TEST(SingleStreamDecoderTest, RespectsOutputTensorDimensionOrderFromOptions) { std::string path = getResourcePath("nasa_13013.mp4"); - std::unique_ptr decoder = std::make_unique(path); - 
VideoDecoder::VideoStreamOptions videoStreamOptions; + std::unique_ptr decoder = + std::make_unique(path); + SingleStreamDecoder::VideoStreamOptions videoStreamOptions; videoStreamOptions.dimensionOrder = "NHWC"; decoder->addVideoStream(-1, videoStreamOptions); torch::Tensor tensor = decoder->getNextFrame().data; EXPECT_EQ(tensor.sizes(), std::vector({270, 480, 3})); } -TEST_P(VideoDecoderTest, ReturnsFirstTwoFramesOfVideo) { +TEST_P(SingleStreamDecoderTest, ReturnsFirstTwoFramesOfVideo) { std::string path = getResourcePath("nasa_13013.mp4"); - std::unique_ptr ourDecoder = + std::unique_ptr ourDecoder = createDecoderFromPath(path, GetParam()); ourDecoder->addVideoStream(-1); auto output = ourDecoder->getNextFrame(); @@ -201,9 +204,9 @@ TEST_P(VideoDecoderTest, ReturnsFirstTwoFramesOfVideo) { } } -TEST_P(VideoDecoderTest, DecodesFramesInABatchInNCHW) { +TEST_P(SingleStreamDecoderTest, DecodesFramesInABatchInNCHW) { std::string path = getResourcePath("nasa_13013.mp4"); - std::unique_ptr ourDecoder = + std::unique_ptr ourDecoder = createDecoderFromPath(path, GetParam()); ourDecoder->scanFileAndUpdateMetadataAndIndex(); int bestVideoStreamIndex = @@ -223,14 +226,14 @@ TEST_P(VideoDecoderTest, DecodesFramesInABatchInNCHW) { EXPECT_TRUE(torch::equal(tensor[1], tensorTime6FromFFMPEG)); } -TEST_P(VideoDecoderTest, DecodesFramesInABatchInNHWC) { +TEST_P(SingleStreamDecoderTest, DecodesFramesInABatchInNHWC) { std::string path = getResourcePath("nasa_13013.mp4"); - std::unique_ptr ourDecoder = + std::unique_ptr ourDecoder = createDecoderFromPath(path, GetParam()); ourDecoder->scanFileAndUpdateMetadataAndIndex(); int bestVideoStreamIndex = *ourDecoder->getContainerMetadata().bestVideoStreamIndex; - VideoDecoder::VideoStreamOptions videoStreamOptions; + SingleStreamDecoder::VideoStreamOptions videoStreamOptions; videoStreamOptions.dimensionOrder = "NHWC"; ourDecoder->addVideoStream(bestVideoStreamIndex, videoStreamOptions); // Frame with index 180 corresponds to timestamp 6.006. @@ -248,9 +251,9 @@ TEST_P(VideoDecoderTest, DecodesFramesInABatchInNHWC) { EXPECT_TRUE(torch::equal(tensor[1], tensorTime6FromFFMPEG)); } -TEST_P(VideoDecoderTest, SeeksCloseToEof) { +TEST_P(SingleStreamDecoderTest, SeeksCloseToEof) { std::string path = getResourcePath("nasa_13013.mp4"); - std::unique_ptr ourDecoder = + std::unique_ptr ourDecoder = createDecoderFromPath(path, GetParam()); ourDecoder->addVideoStream(-1); ourDecoder->setCursorPtsInSeconds(388388. 
/ 30'000); @@ -261,9 +264,9 @@ TEST_P(VideoDecoderTest, SeeksCloseToEof) { EXPECT_THROW(ourDecoder->getNextFrame(), std::exception); } -TEST_P(VideoDecoderTest, GetsFramePlayedAtTimestamp) { +TEST_P(SingleStreamDecoderTest, GetsFramePlayedAtTimestamp) { std::string path = getResourcePath("nasa_13013.mp4"); - std::unique_ptr ourDecoder = + std::unique_ptr ourDecoder = createDecoderFromPath(path, GetParam()); ourDecoder->addVideoStream(-1); auto output = ourDecoder->getFramePlayedAt(6.006); @@ -291,9 +294,9 @@ TEST_P(VideoDecoderTest, GetsFramePlayedAtTimestamp) { EXPECT_EQ(output.ptsSeconds, kPtsOfLastFrameInVideoStream); } -TEST_P(VideoDecoderTest, SeeksToFrameWithSpecificPts) { +TEST_P(SingleStreamDecoderTest, SeeksToFrameWithSpecificPts) { std::string path = getResourcePath("nasa_13013.mp4"); - std::unique_ptr ourDecoder = + std::unique_ptr ourDecoder = createDecoderFromPath(path, GetParam()); ourDecoder->addVideoStream(-1); ourDecoder->setCursorPtsInSeconds(6.0); @@ -386,47 +389,48 @@ TEST_P(VideoDecoderTest, SeeksToFrameWithSpecificPts) { } } -TEST_P(VideoDecoderTest, PreAllocatedTensorFilterGraph) { +TEST_P(SingleStreamDecoderTest, PreAllocatedTensorFilterGraph) { std::string path = getResourcePath("nasa_13013.mp4"); auto preAllocatedOutputTensor = torch::empty({270, 480, 3}, {torch::kUInt8}); - std::unique_ptr ourDecoder = - VideoDecoderTest::createDecoderFromPath(path, GetParam()); + std::unique_ptr ourDecoder = + SingleStreamDecoderTest::createDecoderFromPath(path, GetParam()); ourDecoder->scanFileAndUpdateMetadataAndIndex(); int bestVideoStreamIndex = *ourDecoder->getContainerMetadata().bestVideoStreamIndex; - VideoDecoder::VideoStreamOptions videoStreamOptions; + SingleStreamDecoder::VideoStreamOptions videoStreamOptions; videoStreamOptions.colorConversionLibrary = - VideoDecoder::ColorConversionLibrary::FILTERGRAPH; + SingleStreamDecoder::ColorConversionLibrary::FILTERGRAPH; ourDecoder->addVideoStream(bestVideoStreamIndex, videoStreamOptions); auto output = ourDecoder->getFrameAtIndexInternal(0, preAllocatedOutputTensor); EXPECT_EQ(output.data.data_ptr(), preAllocatedOutputTensor.data_ptr()); } -TEST_P(VideoDecoderTest, PreAllocatedTensorSwscale) { +TEST_P(SingleStreamDecoderTest, PreAllocatedTensorSwscale) { std::string path = getResourcePath("nasa_13013.mp4"); auto preAllocatedOutputTensor = torch::empty({270, 480, 3}, {torch::kUInt8}); - std::unique_ptr ourDecoder = - VideoDecoderTest::createDecoderFromPath(path, GetParam()); + std::unique_ptr ourDecoder = + SingleStreamDecoderTest::createDecoderFromPath(path, GetParam()); ourDecoder->scanFileAndUpdateMetadataAndIndex(); int bestVideoStreamIndex = *ourDecoder->getContainerMetadata().bestVideoStreamIndex; - VideoDecoder::VideoStreamOptions videoStreamOptions; + SingleStreamDecoder::VideoStreamOptions videoStreamOptions; videoStreamOptions.colorConversionLibrary = - VideoDecoder::ColorConversionLibrary::SWSCALE; + SingleStreamDecoder::ColorConversionLibrary::SWSCALE; ourDecoder->addVideoStream(bestVideoStreamIndex, videoStreamOptions); auto output = ourDecoder->getFrameAtIndexInternal(0, preAllocatedOutputTensor); EXPECT_EQ(output.data.data_ptr(), preAllocatedOutputTensor.data_ptr()); } -TEST_P(VideoDecoderTest, GetAudioMetadata) { +TEST_P(SingleStreamDecoderTest, GetAudioMetadata) { std::string path = getResourcePath("nasa_13013.mp4.audio.mp3"); - std::unique_ptr decoder = + std::unique_ptr decoder = createDecoderFromPath(path, GetParam()); - VideoDecoder::ContainerMetadata metadata = decoder->getContainerMetadata(); + 
SingleStreamDecoder::ContainerMetadata metadata = + decoder->getContainerMetadata(); EXPECT_EQ(metadata.numAudioStreams, 1); EXPECT_EQ(metadata.numVideoStreams, 0); EXPECT_EQ(metadata.allStreamMetadata.size(), 1); @@ -436,6 +440,9 @@ TEST_P(VideoDecoderTest, GetAudioMetadata) { EXPECT_NEAR(*audioStream.durationSeconds, 13.25, 1e-1); } -INSTANTIATE_TEST_SUITE_P(FromFileAndMemory, VideoDecoderTest, testing::Bool()); +INSTANTIATE_TEST_SUITE_P( + FromFileAndMemory, + SingleStreamDecoderTest, + testing::Bool()); } // namespace facebook::torchcodec
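End to end, the rename is mechanical: every entry point keeps its signature, and only the owning class and its nested types (SeekMode, VideoStreamOptions, FrameOutput, ContainerMetadata, ...) change name. A minimal caller-side sketch, using only signatures that appear in this diff — the input file is borrowed from the tests, and error handling (e.g. SingleStreamDecoder::EndOfFileException) is omitted:

```cpp
#include "src/torchcodec/_core/SingleStreamDecoder.h"

using namespace facebook::torchcodec;

int main() {
  // Before this change this read:
  //   VideoDecoder decoder("nasa_13013.mp4", VideoDecoder::SeekMode::exact);
  SingleStreamDecoder decoder(
      "nasa_13013.mp4", SingleStreamDecoder::SeekMode::exact);

  // Nested types follow the class rename.
  SingleStreamDecoder::ContainerMetadata metadata =
      decoder.getContainerMetadata();
  if (metadata.numVideoStreams == 0) {
    return 1;
  }

  SingleStreamDecoder::VideoStreamOptions videoStreamOptions;
  videoStreamOptions.dimensionOrder = "NHWC";
  decoder.addVideoStream(-1, videoStreamOptions);

  // Decoding entry points are unchanged apart from the owning class name.
  SingleStreamDecoder::FrameOutput frame = decoder.getFramePlayedAt(6.006);
  return frame.data.numel() > 0 ? 0 : 1;
}
```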