diff --git a/src/torchcodec/decoders/_core/CPUOnlyDevice.cpp b/src/torchcodec/decoders/_core/CPUOnlyDevice.cpp
index 50889677..81746109 100644
--- a/src/torchcodec/decoders/_core/CPUOnlyDevice.cpp
+++ b/src/torchcodec/decoders/_core/CPUOnlyDevice.cpp
@@ -17,7 +17,7 @@ namespace facebook::torchcodec {
 void convertAVFrameToFrameOutputOnCuda(
     const torch::Device& device,
     [[maybe_unused]] const VideoDecoder::VideoStreamOptions& videoStreamOptions,
-    [[maybe_unused]] VideoDecoder::AVFrameStream& avFrameStream,
+    [[maybe_unused]] UniqueAVFrame& avFrame,
     [[maybe_unused]] VideoDecoder::FrameOutput& frameOutput,
     [[maybe_unused]] std::optional<torch::Tensor> preAllocatedOutputTensor) {
   throwUnsupportedDeviceError(device);
diff --git a/src/torchcodec/decoders/_core/CudaDevice.cpp b/src/torchcodec/decoders/_core/CudaDevice.cpp
index 56de03b6..7ef7b82c 100644
--- a/src/torchcodec/decoders/_core/CudaDevice.cpp
+++ b/src/torchcodec/decoders/_core/CudaDevice.cpp
@@ -190,17 +190,15 @@ void initializeContextOnCuda(
 void convertAVFrameToFrameOutputOnCuda(
     const torch::Device& device,
     const VideoDecoder::VideoStreamOptions& videoStreamOptions,
-    VideoDecoder::AVFrameStream& avFrameStream,
+    UniqueAVFrame& avFrame,
     VideoDecoder::FrameOutput& frameOutput,
     std::optional<torch::Tensor> preAllocatedOutputTensor) {
-  AVFrame* avFrame = avFrameStream.avFrame.get();
-
   TORCH_CHECK(
       avFrame->format == AV_PIX_FMT_CUDA,
       "Expected format to be AV_PIX_FMT_CUDA, got " +
           std::string(av_get_pix_fmt_name((AVPixelFormat)avFrame->format)));
   auto frameDims =
-      getHeightAndWidthFromOptionsOrAVFrame(videoStreamOptions, *avFrame);
+      getHeightAndWidthFromOptionsOrAVFrame(videoStreamOptions, avFrame);
   int height = frameDims.height;
   int width = frameDims.width;
   torch::Tensor& dst = frameOutput.data;
diff --git a/src/torchcodec/decoders/_core/DeviceInterface.h b/src/torchcodec/decoders/_core/DeviceInterface.h
index d65afe3d..49aea802 100644
--- a/src/torchcodec/decoders/_core/DeviceInterface.h
+++ b/src/torchcodec/decoders/_core/DeviceInterface.h
@@ -32,7 +32,7 @@ void initializeContextOnCuda(
 void convertAVFrameToFrameOutputOnCuda(
     const torch::Device& device,
     const VideoDecoder::VideoStreamOptions& videoStreamOptions,
-    VideoDecoder::AVFrameStream& avFrameStream,
+    UniqueAVFrame& avFrame,
     VideoDecoder::FrameOutput& frameOutput,
     std::optional<torch::Tensor> preAllocatedOutputTensor = std::nullopt);
 
diff --git a/src/torchcodec/decoders/_core/FFMPEGCommon.cpp b/src/torchcodec/decoders/_core/FFMPEGCommon.cpp
index cb0152f0..eb82c5a2 100644
--- a/src/torchcodec/decoders/_core/FFMPEGCommon.cpp
+++ b/src/torchcodec/decoders/_core/FFMPEGCommon.cpp
@@ -48,15 +48,11 @@ std::string getFFMPEGErrorStringFromErrorCode(int errorCode) {
   return std::string(errorBuffer);
 }
 
-int64_t getDuration(const UniqueAVFrame& frame) {
-  return getDuration(frame.get());
-}
-
-int64_t getDuration(const AVFrame* frame) {
+int64_t getDuration(const UniqueAVFrame& avFrame) {
 #if LIBAVUTIL_VERSION_MAJOR < 58
-  return frame->pkt_duration;
+  return avFrame->pkt_duration;
 #else
-  return frame->duration;
+  return avFrame->duration;
 #endif
 }
 
diff --git a/src/torchcodec/decoders/_core/FFMPEGCommon.h b/src/torchcodec/decoders/_core/FFMPEGCommon.h
index 955ea82d..382563aa 100644
--- a/src/torchcodec/decoders/_core/FFMPEGCommon.h
+++ b/src/torchcodec/decoders/_core/FFMPEGCommon.h
@@ -140,7 +140,6 @@ std::string getFFMPEGErrorStringFromErrorCode(int errorCode);
 // struct member representing duration has changed across the versions we
 // support.
 int64_t getDuration(const UniqueAVFrame& frame);
-int64_t getDuration(const AVFrame* frame);
 
 int getNumChannels(const UniqueAVFrame& avFrame);
 int getNumChannels(const UniqueAVCodecContext& avCodecContext);
diff --git a/src/torchcodec/decoders/_core/VideoDecoder.cpp b/src/torchcodec/decoders/_core/VideoDecoder.cpp
index 9871db64..e9036416 100644
--- a/src/torchcodec/decoders/_core/VideoDecoder.cpp
+++ b/src/torchcodec/decoders/_core/VideoDecoder.cpp
@@ -583,9 +583,9 @@ VideoDecoder::FrameOutput VideoDecoder::getNextFrame() {
 VideoDecoder::FrameOutput VideoDecoder::getNextFrameInternal(
     std::optional<torch::Tensor> preAllocatedOutputTensor) {
   validateActiveStream();
-  AVFrameStream avFrameStream = decodeAVFrame(
-      [this](AVFrame* avFrame) { return avFrame->pts >= cursor_; });
-  return convertAVFrameToFrameOutput(avFrameStream, preAllocatedOutputTensor);
+  UniqueAVFrame avFrame = decodeAVFrame(
+      [this](const UniqueAVFrame& avFrame) { return avFrame->pts >= cursor_; });
+  return convertAVFrameToFrameOutput(avFrame, preAllocatedOutputTensor);
 }
 
 VideoDecoder::FrameOutput VideoDecoder::getFrameAtIndex(int64_t frameIndex) {
@@ -715,8 +715,8 @@ VideoDecoder::FrameOutput VideoDecoder::getFramePlayedAt(double seconds) {
   }
 
   setCursorPtsInSeconds(seconds);
-  AVFrameStream avFrameStream =
-      decodeAVFrame([seconds, this](AVFrame* avFrame) {
+  UniqueAVFrame avFrame =
+      decodeAVFrame([seconds, this](const UniqueAVFrame& avFrame) {
         StreamInfo& streamInfo = streamInfos_[activeStreamIndex_];
         double frameStartTime = ptsToSeconds(avFrame->pts, streamInfo.timeBase);
         double frameEndTime = ptsToSeconds(
@@ -735,7 +735,7 @@ VideoDecoder::FrameOutput VideoDecoder::getFramePlayedAt(double seconds) {
   });
 
   // Convert the frame to tensor.
-  FrameOutput frameOutput = convertAVFrameToFrameOutput(avFrameStream);
+  FrameOutput frameOutput = convertAVFrameToFrameOutput(avFrame);
   frameOutput.data = maybePermuteHWC2CHW(frameOutput.data);
   return frameOutput;
 }
@@ -891,14 +891,15 @@ VideoDecoder::AudioFramesOutput VideoDecoder::getFramesPlayedInRangeAudio(
   auto finished = false;
   while (!finished) {
     try {
-      AVFrameStream avFrameStream = decodeAVFrame([startPts](AVFrame* avFrame) {
-        return startPts < avFrame->pts + getDuration(avFrame);
-      });
+      UniqueAVFrame avFrame =
+          decodeAVFrame([startPts](const UniqueAVFrame& avFrame) {
+            return startPts < avFrame->pts + getDuration(avFrame);
+          });
       // TODO: it's not great that we are getting a FrameOutput, which is
       // intended for videos. We should consider bypassing
       // convertAVFrameToFrameOutput and directly call
       // convertAudioAVFrameToFrameOutputOnCPU.
-      auto frameOutput = convertAVFrameToFrameOutput(avFrameStream);
+      auto frameOutput = convertAVFrameToFrameOutput(avFrame);
       firstFramePtsSeconds =
           std::min(firstFramePtsSeconds, frameOutput.ptsSeconds);
       frames.push_back(frameOutput.data);
@@ -1035,8 +1036,8 @@ void VideoDecoder::maybeSeekToBeforeDesiredPts() {
 // LOW-LEVEL DECODING
 // --------------------------------------------------------------------------
 
-VideoDecoder::AVFrameStream VideoDecoder::decodeAVFrame(
-    std::function<bool(AVFrame*)> filterFunction) {
+UniqueAVFrame VideoDecoder::decodeAVFrame(
+    std::function<bool(const UniqueAVFrame&)> filterFunction) {
   validateActiveStream();
   resetDecodeStats();
 
@@ -1064,7 +1065,7 @@ VideoDecoder::AVFrameStream VideoDecoder::decodeAVFrame(
       decodeStats_.numFramesReceivedByDecoder++;
 
      // Is this the kind of frame we're looking for?
-      if (status == AVSUCCESS && filterFunction(avFrame.get())) {
+      if (status == AVSUCCESS && filterFunction(avFrame)) {
         // Yes, this is the frame we'll return; break out of the decoding loop.
         break;
       } else if (status == AVSUCCESS) {
@@ -1150,7 +1151,7 @@ VideoDecoder::AVFrameStream VideoDecoder::decodeAVFrame(
   streamInfo.lastDecodedAvFramePts = avFrame->pts;
   streamInfo.lastDecodedAvFrameDuration = getDuration(avFrame);
 
-  return AVFrameStream(std::move(avFrame), activeStreamIndex_);
+  return avFrame;
 }
 
 // --------------------------------------------------------------------------
@@ -1158,29 +1159,28 @@ VideoDecoder::AVFrameStream VideoDecoder::decodeAVFrame(
 // --------------------------------------------------------------------------
 
 VideoDecoder::FrameOutput VideoDecoder::convertAVFrameToFrameOutput(
-    VideoDecoder::AVFrameStream& avFrameStream,
+    UniqueAVFrame& avFrame,
     std::optional<torch::Tensor> preAllocatedOutputTensor) {
   // Convert the frame to tensor.
   FrameOutput frameOutput;
-  int streamIndex = avFrameStream.streamIndex;
-  AVFrame* avFrame = avFrameStream.avFrame.get();
-  frameOutput.streamIndex = streamIndex;
-  auto& streamInfo = streamInfos_[streamIndex];
+  frameOutput.streamIndex = activeStreamIndex_;
+  auto& streamInfo = streamInfos_[activeStreamIndex_];
   frameOutput.ptsSeconds = ptsToSeconds(
-      avFrame->pts, formatContext_->streams[streamIndex]->time_base);
+      avFrame->pts, formatContext_->streams[activeStreamIndex_]->time_base);
   frameOutput.durationSeconds = ptsToSeconds(
-      getDuration(avFrame), formatContext_->streams[streamIndex]->time_base);
+      getDuration(avFrame),
+      formatContext_->streams[activeStreamIndex_]->time_base);
   if (streamInfo.avMediaType == AVMEDIA_TYPE_AUDIO) {
     convertAudioAVFrameToFrameOutputOnCPU(
-        avFrameStream, frameOutput, preAllocatedOutputTensor);
+        avFrame, frameOutput, preAllocatedOutputTensor);
   } else if (streamInfo.videoStreamOptions.device.type() == torch::kCPU) {
     convertAVFrameToFrameOutputOnCPU(
-        avFrameStream, frameOutput, preAllocatedOutputTensor);
+        avFrame, frameOutput, preAllocatedOutputTensor);
   } else if (streamInfo.videoStreamOptions.device.type() == torch::kCUDA) {
     convertAVFrameToFrameOutputOnCuda(
         streamInfo.videoStreamOptions.device,
         streamInfo.videoStreamOptions,
-        avFrameStream,
+        avFrame,
         frameOutput,
         preAllocatedOutputTensor);
   } else {
@@ -1201,14 +1201,13 @@ VideoDecoder::FrameOutput VideoDecoder::convertAVFrameToFrameOutput(
 // Dimension order of the preAllocatedOutputTensor must be HWC, regardless of
 // `dimension_order` parameter. It's up to callers to re-shape it if needed.
 void VideoDecoder::convertAVFrameToFrameOutputOnCPU(
-    VideoDecoder::AVFrameStream& avFrameStream,
+    UniqueAVFrame& avFrame,
     FrameOutput& frameOutput,
     std::optional<torch::Tensor> preAllocatedOutputTensor) {
-  AVFrame* avFrame = avFrameStream.avFrame.get();
   auto& streamInfo = streamInfos_[activeStreamIndex_];
 
   auto frameDims = getHeightAndWidthFromOptionsOrAVFrame(
-      streamInfo.videoStreamOptions, *avFrame);
+      streamInfo.videoStreamOptions, avFrame);
   int expectedOutputHeight = frameDims.height;
   int expectedOutputWidth = frameDims.width;
 
@@ -1302,7 +1301,7 @@ void VideoDecoder::convertAVFrameToFrameOutputOnCPU(
 }
 
 int VideoDecoder::convertAVFrameToTensorUsingSwsScale(
-    const AVFrame* avFrame,
+    const UniqueAVFrame& avFrame,
     torch::Tensor& outputTensor) {
   StreamInfo& activeStreamInfo = streamInfos_[activeStreamIndex_];
   SwsContext* swsContext = activeStreamInfo.swsContext.get();
@@ -1322,11 +1321,11 @@ int VideoDecoder::convertAVFrameToTensorUsingSwsScale(
 }
 
 torch::Tensor VideoDecoder::convertAVFrameToTensorUsingFilterGraph(
-    const AVFrame* avFrame) {
+    const UniqueAVFrame& avFrame) {
   FilterGraphContext& filterGraphContext =
       streamInfos_[activeStreamIndex_].filterGraphContext;
   int status =
-      av_buffersrc_write_frame(filterGraphContext.sourceContext, avFrame);
+      av_buffersrc_write_frame(filterGraphContext.sourceContext, avFrame.get());
   if (status < AVSUCCESS) {
     throw std::runtime_error("Failed to add frame to buffer source context");
   }
@@ -1350,7 +1349,7 @@ torch::Tensor VideoDecoder::convertAVFrameToTensorUsingFilterGraph(
 }
 
 void VideoDecoder::convertAudioAVFrameToFrameOutputOnCPU(
-    VideoDecoder::AVFrameStream& avFrameStream,
+    UniqueAVFrame& srcAVFrame,
     FrameOutput& frameOutput,
     std::optional<torch::Tensor> preAllocatedOutputTensor) {
   TORCH_CHECK(
@@ -1358,17 +1357,17 @@ void VideoDecoder::convertAudioAVFrameToFrameOutputOnCPU(
       "pre-allocated audio tensor not supported yet.");
 
   AVSampleFormat sourceSampleFormat =
-      static_cast<AVSampleFormat>(avFrameStream.avFrame->format);
+      static_cast<AVSampleFormat>(srcAVFrame->format);
   AVSampleFormat desiredSampleFormat = AV_SAMPLE_FMT_FLTP;
 
   UniqueAVFrame convertedAVFrame;
   if (sourceSampleFormat != desiredSampleFormat) {
     convertedAVFrame = convertAudioAVFrameSampleFormat(
-        avFrameStream.avFrame, sourceSampleFormat, desiredSampleFormat);
+        srcAVFrame, sourceSampleFormat, desiredSampleFormat);
   }
   const UniqueAVFrame& avFrame = (sourceSampleFormat != desiredSampleFormat)
       ? convertedAVFrame
-      : avFrameStream.avFrame;
+      : srcAVFrame;
 
   AVSampleFormat format = static_cast<AVSampleFormat>(avFrame->format);
   TORCH_CHECK(
@@ -1944,10 +1943,10 @@ FrameDims getHeightAndWidthFromOptionsOrMetadata(
 
 FrameDims getHeightAndWidthFromOptionsOrAVFrame(
     const VideoDecoder::VideoStreamOptions& videoStreamOptions,
-    const AVFrame& avFrame) {
+    const UniqueAVFrame& avFrame) {
   return FrameDims(
-      videoStreamOptions.height.value_or(avFrame.height),
-      videoStreamOptions.width.value_or(avFrame.width));
+      videoStreamOptions.height.value_or(avFrame->height),
+      videoStreamOptions.width.value_or(avFrame->width));
 }
 
 } // namespace facebook::torchcodec
diff --git a/src/torchcodec/decoders/_core/VideoDecoder.h b/src/torchcodec/decoders/_core/VideoDecoder.h
index f72f31ab..68311961 100644
--- a/src/torchcodec/decoders/_core/VideoDecoder.h
+++ b/src/torchcodec/decoders/_core/VideoDecoder.h
@@ -244,23 +244,6 @@ class VideoDecoder {
   // These are APIs that should be private, but that are effectively exposed for
   // practical reasons, typically for testing purposes.
 
-  // This struct is needed because AVFrame doesn't retain the streamIndex. Only
-  // the AVPacket knows its stream. This is what the low-level private decoding
-  // entry points return. The AVFrameStream is then converted to a FrameOutput
-  // with convertAVFrameToFrameOutput. It should be private, but is currently
-  // used by DeviceInterface.
-  struct AVFrameStream {
-    // The actual decoded output as a unique pointer to an AVFrame.
-    // Usually, this is a YUV frame. It'll be converted to RGB in
-    // convertAVFrameToFrameOutput.
-    UniqueAVFrame avFrame;
-    // The stream index of the decoded frame.
-    int streamIndex;
-
-    explicit AVFrameStream(UniqueAVFrame&& a, int s)
-        : avFrame(std::move(a)), streamIndex(s) {}
-  };
-
   // Once getFrameAtIndex supports the preAllocatedOutputTensor parameter, we
   // can move it back to private.
   FrameOutput getFrameAtIndexInternal(
@@ -376,7 +359,8 @@ class VideoDecoder {
 
   void maybeSeekToBeforeDesiredPts();
 
-  AVFrameStream decodeAVFrame(std::function<bool(AVFrame*)> filterFunction);
+  UniqueAVFrame decodeAVFrame(
+      std::function<bool(const UniqueAVFrame&)> filterFunction);
 
   FrameOutput getNextFrameInternal(
       std::optional<torch::Tensor> preAllocatedOutputTensor = std::nullopt);
@@ -384,23 +368,24 @@ class VideoDecoder {
   torch::Tensor maybePermuteHWC2CHW(torch::Tensor& hwcTensor);
 
   FrameOutput convertAVFrameToFrameOutput(
-      AVFrameStream& avFrameStream,
+      UniqueAVFrame& avFrame,
       std::optional<torch::Tensor> preAllocatedOutputTensor = std::nullopt);
 
   void convertAVFrameToFrameOutputOnCPU(
-      AVFrameStream& avFrameStream,
+      UniqueAVFrame& avFrame,
      FrameOutput& frameOutput,
       std::optional<torch::Tensor> preAllocatedOutputTensor = std::nullopt);
 
   void convertAudioAVFrameToFrameOutputOnCPU(
-      AVFrameStream& avFrameStream,
+      UniqueAVFrame& srcAVFrame,
       FrameOutput& frameOutput,
       std::optional<torch::Tensor> preAllocatedOutputTensor = std::nullopt);
 
-  torch::Tensor convertAVFrameToTensorUsingFilterGraph(const AVFrame* avFrame);
+  torch::Tensor convertAVFrameToTensorUsingFilterGraph(
+      const UniqueAVFrame& avFrame);
 
   int convertAVFrameToTensorUsingSwsScale(
-      const AVFrame* avFrame,
+      const UniqueAVFrame& avFrame,
       torch::Tensor& outputTensor);
 
   UniqueAVFrame convertAudioAVFrameSampleFormat(
@@ -568,7 +553,7 @@ FrameDims getHeightAndWidthFromOptionsOrMetadata(
 
 FrameDims getHeightAndWidthFromOptionsOrAVFrame(
     const VideoDecoder::VideoStreamOptions& videoStreamOptions,
-    const AVFrame& avFrame);
+    const UniqueAVFrame& avFrame);
 
 torch::Tensor allocateEmptyHWCTensor(
     int height,
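
Note on the new signatures: decodeAVFrame() now returns a UniqueAVFrame directly (the stream index is implied by activeStreamIndex_), and filter callbacks take const UniqueAVFrame& instead of AVFrame*. Below is a minimal, self-contained sketch of that ownership model, assuming UniqueAVFrame is a std::unique_ptr whose deleter calls av_frame_free, which is consistent with the avFrame.get() and avFrame->pts usage throughout this diff; AVFrameDeleter, isAtOrPastCursor, and cursorPts are illustrative names, not part of the patch (the real alias lives in FFMPEGCommon.h).

extern "C" {
#include <libavutil/frame.h>
}
#include <cstdint>
#include <functional>
#include <memory>

// Assumed shape of the UniqueAVFrame alias: frees the frame on scope exit.
struct AVFrameDeleter {
  void operator()(AVFrame* avFrame) const {
    av_frame_free(&avFrame); // av_frame_free takes AVFrame** and nulls it
  }
};
using UniqueAVFrame = std::unique_ptr<AVFrame, AVFrameDeleter>;

// A filter predicate in the new style: it borrows the decoded frame through
// the smart pointer and never takes ownership.
bool isAtOrPastCursor(const UniqueAVFrame& avFrame, int64_t cursorPts) {
  return avFrame->pts >= cursorPts;
}

int main() {
  UniqueAVFrame avFrame(av_frame_alloc());
  avFrame->pts = 42;
  std::function<bool(const UniqueAVFrame&)> filter =
      [](const UniqueAVFrame& f) { return isAtOrPastCursor(f, 40); };
  return filter(avFrame) ? 0 : 1; // frame freed automatically on scope exit
}

One consequence of the stricter types worth noting: because std::unique_ptr is move-only, the single owner of each decoded frame is now explicit, and the compiler rejects the accidental aliasing that raw AVFrame* allowed.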