Use a single active stream in VideoDecoder.

Replaces the std::set of active stream indices (activeStreamIndices_) with one
int member, activeStreamIndex_, initialised to the NO_ACTIVE_STREAM sentinel
(-2). addVideoStreamDecoder() now fails a TORCH_CHECK ("Can only add one single
stream.") when a stream is already active; the per-stream loops in
maybeSeekToBeforeDesiredPts() and decodeAVFrame() are removed; and the filter
callback passed to decodeAVFrame() no longer receives a stream index.

Review notes:
- This patch text was recovered from a copy whose line breaks were collapsed.
  Indentation, blank-line placement and the template arguments of
  std::function, std::optional, std::map and std::set were reconstructed from
  the hunk line counts and from how the values are used -- verify against the
  repository before applying.
- NO_ACTIVE_STREAM is declared `static constexpr int` rather than as a
  non-static `const int` member: a non-static const data member makes the
  class's implicitly-declared assignment operators deleted and stores a copy
  of the constant in every instance.

diff --git a/src/torchcodec/decoders/_core/VideoDecoder.cpp b/src/torchcodec/decoders/_core/VideoDecoder.cpp
index b773e018..778a1b3e 100644
--- a/src/torchcodec/decoders/_core/VideoDecoder.cpp
+++ b/src/torchcodec/decoders/_core/VideoDecoder.cpp
@@ -435,11 +435,9 @@ int VideoDecoder::getBestStreamIndex(AVMediaType mediaType) {
 void VideoDecoder::addVideoStreamDecoder(
     int preferredStreamIndex,
     const VideoStreamOptions& videoStreamOptions) {
-  if (activeStreamIndices_.count(preferredStreamIndex) > 0) {
-    throw std::invalid_argument(
-        "Stream with index " + std::to_string(preferredStreamIndex) +
-        " is already active.");
-  }
+  TORCH_CHECK(
+      activeStreamIndex_ == NO_ACTIVE_STREAM,
+      "Can only add one single stream.");
   TORCH_CHECK(formatContext_.get() != nullptr);
 
   AVCodecOnlyUseForCallingAVFindBestStream avCodec = nullptr;
@@ -506,7 +504,7 @@ void VideoDecoder::addVideoStreamDecoder(
   }
 
   codecContext->time_base = streamInfo.stream->time_base;
-  activeStreamIndices_.insert(streamIndex);
+  activeStreamIndex_ = streamIndex;
   updateMetadataWithCodecContext(streamInfo.streamIndex, codecContext);
   streamInfo.videoStreamOptions = videoStreamOptions;
 
@@ -726,53 +724,39 @@ bool VideoDecoder::canWeAvoidSeekingForStream(
 // AVFormatContext if it is needed. We can skip seeking in certain cases. See
 // the comment of canWeAvoidSeeking() for details.
 void VideoDecoder::maybeSeekToBeforeDesiredPts() {
-  if (activeStreamIndices_.size() == 0) {
+  if (activeStreamIndex_ == NO_ACTIVE_STREAM) {
     return;
   }
-  for (int streamIndex : activeStreamIndices_) {
-    StreamInfo& streamInfo = streamInfos_[streamIndex];
-    // clang-format off: clang format clashes
-    streamInfo.discardFramesBeforePts = secondsToClosestPts(*desiredPtsSeconds_, streamInfo.timeBase);
-    // clang-format on
-  }
+  StreamInfo& streamInfo = streamInfos_[activeStreamIndex_];
+  streamInfo.discardFramesBeforePts =
+      secondsToClosestPts(*desiredPtsSeconds_, streamInfo.timeBase);
 
   decodeStats_.numSeeksAttempted++;
-  // See comment for canWeAvoidSeeking() for details on why this optimization
-  // works.
-  bool mustSeek = false;
-  for (int streamIndex : activeStreamIndices_) {
-    StreamInfo& streamInfo = streamInfos_[streamIndex];
-    int64_t desiredPtsForStream = *desiredPtsSeconds_ * streamInfo.timeBase.den;
-    if (!canWeAvoidSeekingForStream(
-            streamInfo, streamInfo.currentPts, desiredPtsForStream)) {
-      mustSeek = true;
-      break;
-    }
-  }
-  if (!mustSeek) {
+
+  int64_t desiredPtsForStream = *desiredPtsSeconds_ * streamInfo.timeBase.den;
+  if (canWeAvoidSeekingForStream(
+          streamInfo, streamInfo.currentPts, desiredPtsForStream)) {
     decodeStats_.numSeeksSkipped++;
     return;
   }
-  int firstActiveStreamIndex = *activeStreamIndices_.begin();
-  const auto& firstStreamInfo = streamInfos_[firstActiveStreamIndex];
   int64_t desiredPts =
-      secondsToClosestPts(*desiredPtsSeconds_, firstStreamInfo.timeBase);
+      secondsToClosestPts(*desiredPtsSeconds_, streamInfo.timeBase);
 
   // For some encodings like H265, FFMPEG sometimes seeks past the point we
   // set as the max_ts. So we use our own index to give it the exact pts of
   // the key frame that we want to seek to.
   // See https://github.com/pytorch/torchcodec/issues/179 for more details.
   // See https://trac.ffmpeg.org/ticket/11137 for the underlying ffmpeg bug.
-  if (!firstStreamInfo.keyFrames.empty()) {
+  if (!streamInfo.keyFrames.empty()) {
     int desiredKeyFrameIndex = getKeyFrameIndexForPtsUsingScannedIndex(
-        firstStreamInfo.keyFrames, desiredPts);
+        streamInfo.keyFrames, desiredPts);
     desiredKeyFrameIndex = std::max(desiredKeyFrameIndex, 0);
-    desiredPts = firstStreamInfo.keyFrames[desiredKeyFrameIndex].pts;
+    desiredPts = streamInfo.keyFrames[desiredKeyFrameIndex].pts;
   }
 
   int ffmepgStatus = avformat_seek_file(
       formatContext_.get(),
-      firstStreamInfo.streamIndex,
+      streamInfo.streamIndex,
       INT64_MIN,
       desiredPts,
       desiredPts,
@@ -783,15 +767,12 @@ void VideoDecoder::maybeSeekToBeforeDesiredPts() {
         getFFMPEGErrorStringFromErrorCode(ffmepgStatus));
   }
   decodeStats_.numFlushes++;
-  for (int streamIndex : activeStreamIndices_) {
-    StreamInfo& streamInfo = streamInfos_[streamIndex];
-    avcodec_flush_buffers(streamInfo.codecContext.get());
-  }
+  avcodec_flush_buffers(streamInfo.codecContext.get());
 }
 
 VideoDecoder::AVFrameStream VideoDecoder::decodeAVFrame(
-    std::function<bool(int, AVFrame*)> filterFunction) {
-  if (activeStreamIndices_.size() == 0) {
+    std::function<bool(AVFrame*)> filterFunction) {
+  if (activeStreamIndex_ == NO_ACTIVE_STREAM) {
     throw std::runtime_error("No active streams configured.");
   }
 
@@ -803,44 +784,25 @@ VideoDecoder::AVFrameStream VideoDecoder::decodeAVFrame(
     desiredPtsSeconds_ = std::nullopt;
   }
 
+  StreamInfo& streamInfo = streamInfos_[activeStreamIndex_];
+
   // Need to get the next frame or error from PopFrame.
   UniqueAVFrame avFrame(av_frame_alloc());
   AutoAVPacket autoAVPacket;
   int ffmpegStatus = AVSUCCESS;
   bool reachedEOF = false;
-  int frameStreamIndex = -1;
   while (true) {
-    frameStreamIndex = -1;
-    bool gotPermanentErrorOnAnyActiveStream = false;
-
-    // Get a frame on an active stream. Note that we don't know ahead of time
-    // which streams have frames to receive, so we linearly try the active
-    // streams.
-    for (int streamIndex : activeStreamIndices_) {
-      StreamInfo& streamInfo = streamInfos_[streamIndex];
-      ffmpegStatus =
-          avcodec_receive_frame(streamInfo.codecContext.get(), avFrame.get());
-
-      if (ffmpegStatus != AVSUCCESS && ffmpegStatus != AVERROR(EAGAIN)) {
-        gotPermanentErrorOnAnyActiveStream = true;
-        break;
-      }
+    ffmpegStatus =
+        avcodec_receive_frame(streamInfo.codecContext.get(), avFrame.get());
 
-      if (ffmpegStatus == AVSUCCESS) {
-        frameStreamIndex = streamIndex;
-        break;
-      }
-    }
-
-    if (gotPermanentErrorOnAnyActiveStream) {
+    if (ffmpegStatus != AVSUCCESS && ffmpegStatus != AVERROR(EAGAIN)) {
+      // Non-retriable error
       break;
     }
 
     decodeStats_.numFramesReceivedByDecoder++;
 
-    // Is this the kind of frame we're looking for?
-    if (ffmpegStatus == AVSUCCESS &&
-        filterFunction(frameStreamIndex, avFrame.get())) {
+    if (ffmpegStatus == AVSUCCESS && filterFunction(avFrame.get())) {
       // Yes, this is the frame we'll return; break out of the decoding loop.
       break;
     } else if (ffmpegStatus == AVSUCCESS) {
@@ -865,18 +827,15 @@ VideoDecoder::AVFrameStream VideoDecoder::decodeAVFrame(
     decodeStats_.numPacketsRead++;
 
     if (ffmpegStatus == AVERROR_EOF) {
-      // End of file reached. We must drain all codecs by sending a nullptr
+      // End of file reached. We must drain the codec by sending a nullptr
       // packet.
-      for (int streamIndex : activeStreamIndices_) {
-        StreamInfo& streamInfo = streamInfos_[streamIndex];
-        ffmpegStatus = avcodec_send_packet(
-            streamInfo.codecContext.get(),
-            /*avpkt=*/nullptr);
-        if (ffmpegStatus < AVSUCCESS) {
-          throw std::runtime_error(
-              "Could not flush decoder: " +
-              getFFMPEGErrorStringFromErrorCode(ffmpegStatus));
-        }
+      ffmpegStatus = avcodec_send_packet(
+          streamInfo.codecContext.get(),
+          /*avpkt=*/nullptr);
+      if (ffmpegStatus < AVSUCCESS) {
+        throw std::runtime_error(
+            "Could not flush decoder: " +
+            getFFMPEGErrorStringFromErrorCode(ffmpegStatus));
       }
 
       // We've reached the end of file so we can't read any more packets from
@@ -892,15 +851,14 @@ VideoDecoder::AVFrameStream VideoDecoder::decodeAVFrame(
           getFFMPEGErrorStringFromErrorCode(ffmpegStatus));
     }
 
-    if (activeStreamIndices_.count(packet->stream_index) == 0) {
-      // This packet is not for any of the active streams.
+    if (packet->stream_index != activeStreamIndex_) {
       continue;
     }
 
     // We got a valid packet. Send it to the decoder, and we'll receive it in
     // the next iteration.
-    ffmpegStatus = avcodec_send_packet(
-        streamInfos_[packet->stream_index].codecContext.get(), packet.get());
+    ffmpegStatus =
+        avcodec_send_packet(streamInfo.codecContext.get(), packet.get());
     if (ffmpegStatus < AVSUCCESS) {
       throw std::runtime_error(
           "Could not push packet to decoder: " +
@@ -927,11 +885,10 @@ VideoDecoder::AVFrameStream VideoDecoder::decodeAVFrame(
   // haven't received as frames. Eventually we will either hit AVERROR_EOF from
   // av_receive_frame() or the user will have seeked to a different location in
   // the file and that will flush the decoder.
-  StreamInfo& activeStreamInfo = streamInfos_[frameStreamIndex];
-  activeStreamInfo.currentPts = avFrame->pts;
-  activeStreamInfo.currentDuration = getDuration(avFrame);
+  streamInfo.currentPts = avFrame->pts;
+  streamInfo.currentDuration = getDuration(avFrame);
 
-  return AVFrameStream(std::move(avFrame), frameStreamIndex);
+  return AVFrameStream(std::move(avFrame), activeStreamIndex_);
 }
 
 VideoDecoder::FrameOutput VideoDecoder::convertAVFrameToFrameOutput(
@@ -1096,8 +1053,8 @@ VideoDecoder::FrameOutput VideoDecoder::getFramePlayedAtNoDemux(
 
   setCursorPtsInSeconds(seconds);
   AVFrameStream avFrameStream =
-      decodeAVFrame([seconds, this](int frameStreamIndex, AVFrame* avFrame) {
-        StreamInfo& streamInfo = streamInfos_[frameStreamIndex];
+      decodeAVFrame([seconds, this](AVFrame* avFrame) {
+        StreamInfo& streamInfo = streamInfos_[activeStreamIndex_];
         double frameStartTime = ptsToSeconds(avFrame->pts, streamInfo.timeBase);
         double frameEndTime = ptsToSeconds(
             avFrame->pts + getDuration(avFrame), streamInfo.timeBase);
@@ -1496,11 +1453,10 @@ VideoDecoder::FrameOutput VideoDecoder::getNextFrameNoDemux() {
 
 VideoDecoder::FrameOutput VideoDecoder::getNextFrameNoDemuxInternal(
     std::optional<torch::Tensor> preAllocatedOutputTensor) {
-  AVFrameStream avFrameStream =
-      decodeAVFrame([this](int frameStreamIndex, AVFrame* avFrame) {
-        StreamInfo& activeStreamInfo = streamInfos_[frameStreamIndex];
-        return avFrame->pts >= activeStreamInfo.discardFramesBeforePts;
-      });
+  AVFrameStream avFrameStream = decodeAVFrame([this](AVFrame* avFrame) {
+    StreamInfo& activeStreamInfo = streamInfos_[activeStreamIndex_];
+    return avFrame->pts >= activeStreamInfo.discardFramesBeforePts;
+  });
   return convertAVFrameToFrameOutput(avFrameStream, preAllocatedOutputTensor);
 }
 
diff --git a/src/torchcodec/decoders/_core/VideoDecoder.h b/src/torchcodec/decoders/_core/VideoDecoder.h
index 0d4bfb1c..696b2fa2 100644
--- a/src/torchcodec/decoders/_core/VideoDecoder.h
+++ b/src/torchcodec/decoders/_core/VideoDecoder.h
@@ -361,8 +361,7 @@ class VideoDecoder {
 
   void maybeSeekToBeforeDesiredPts();
 
-  AVFrameStream decodeAVFrame(
-      std::function<bool(int, AVFrame*)> filterFunction);
+  AVFrameStream decodeAVFrame(std::function<bool(AVFrame*)> filterFunction);
 
   FrameOutput getNextFrameNoDemuxInternal(
       std::optional<torch::Tensor> preAllocatedOutputTensor = std::nullopt);
@@ -469,9 +468,8 @@ class VideoDecoder {
   ContainerMetadata containerMetadata_;
   UniqueAVFormatContext formatContext_;
   std::map<int, StreamInfo> streamInfos_;
-  // Stores the stream indices of the active streams, i.e. the streams we are
-  // decoding and returning to the user.
-  std::set<int> activeStreamIndices_;
+  static constexpr int NO_ACTIVE_STREAM = -2;
+  int activeStreamIndex_ = NO_ACTIVE_STREAM;
   // Set when the user wants to seek and stores the desired pts that the user
   // wants to seek to.
   std::optional<double> desiredPtsSeconds_;