@@ -435,11 +435,9 @@ int VideoDecoder::getBestStreamIndex(AVMediaType mediaType) {
 void VideoDecoder::addVideoStreamDecoder(
     int preferredStreamIndex,
     const VideoStreamOptions& videoStreamOptions) {
-  if (activeStreamIndices_.count(preferredStreamIndex) > 0) {
-    throw std::invalid_argument(
-        "Stream with index " + std::to_string(preferredStreamIndex) +
-        " is already active.");
-  }
+  TORCH_CHECK(
+      activeStreamIndex_ == NO_ACTIVE_STREAM,
+      "Can only add one single stream.");
   TORCH_CHECK(formatContext_.get() != nullptr);

   AVCodecOnlyUseForCallingAVFindBestStream avCodec = nullptr;
@@ -506,7 +504,7 @@ void VideoDecoder::addVideoStreamDecoder(
   }

   codecContext->time_base = streamInfo.stream->time_base;
-  activeStreamIndices_.insert(streamIndex);
+  activeStreamIndex_ = streamIndex;
   updateMetadataWithCodecContext(streamInfo.streamIndex, codecContext);
   streamInfo.videoStreamOptions = videoStreamOptions;

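
For context: the two hunks above replace the set of active stream indices with a single activeStreamIndex_ member guarded by a NO_ACTIVE_STREAM sentinel. A minimal sketch of that invariant, assuming the sentinel is just a reserved int value (the class and method names below are illustrative, not the actual VideoDecoder definition):

#include <stdexcept>

// Illustrative only: tracks at most one active stream index.
class SingleStreamTracker {
 public:
  static constexpr int NO_ACTIVE_STREAM = -1;

  void addStream(int streamIndex) {
    // Mirrors the TORCH_CHECK above: adding a second stream is rejected.
    if (activeStreamIndex_ != NO_ACTIVE_STREAM) {
      throw std::runtime_error("Can only add one single stream.");
    }
    activeStreamIndex_ = streamIndex;
  }

  int activeStreamIndex() const {
    return activeStreamIndex_;
  }

 private:
  int activeStreamIndex_ = NO_ACTIVE_STREAM;
};
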
@@ -754,53 +752,39 @@ bool VideoDecoder::canWeAvoidSeekingForStream(
 // AVFormatContext if it is needed. We can skip seeking in certain cases. See
 // the comment of canWeAvoidSeeking() for details.
 void VideoDecoder::maybeSeekToBeforeDesiredPts() {
-  if (activeStreamIndices_.size() == 0) {
+  if (activeStreamIndex_ == NO_ACTIVE_STREAM) {
     return;
   }
-  for (int streamIndex : activeStreamIndices_) {
-    StreamInfo& streamInfo = streamInfos_[streamIndex];
-    // clang-format off: clang format clashes
-    streamInfo.discardFramesBeforePts = secondsToClosestPts(*desiredPtsSeconds_, streamInfo.timeBase);
-    // clang-format on
-  }
+  StreamInfo& streamInfo = streamInfos_[activeStreamIndex_];
+  streamInfo.discardFramesBeforePts =
+      secondsToClosestPts(*desiredPtsSeconds_, streamInfo.timeBase);

   decodeStats_.numSeeksAttempted++;
-  // See comment for canWeAvoidSeeking() for details on why this optimization
-  // works.
-  bool mustSeek = false;
-  for (int streamIndex : activeStreamIndices_) {
-    StreamInfo& streamInfo = streamInfos_[streamIndex];
-    int64_t desiredPtsForStream = *desiredPtsSeconds_ * streamInfo.timeBase.den;
-    if (!canWeAvoidSeekingForStream(
-            streamInfo, streamInfo.currentPts, desiredPtsForStream)) {
-      mustSeek = true;
-      break;
-    }
-  }
-  if (!mustSeek) {
+
+  int64_t desiredPtsForStream = *desiredPtsSeconds_ * streamInfo.timeBase.den;
+  if (canWeAvoidSeekingForStream(
+          streamInfo, streamInfo.currentPts, desiredPtsForStream)) {
     decodeStats_.numSeeksSkipped++;
     return;
   }
-  int firstActiveStreamIndex = *activeStreamIndices_.begin();
-  const auto& firstStreamInfo = streamInfos_[firstActiveStreamIndex];
   int64_t desiredPts =
-      secondsToClosestPts(*desiredPtsSeconds_, firstStreamInfo.timeBase);
+      secondsToClosestPts(*desiredPtsSeconds_, streamInfo.timeBase);

   // For some encodings like H265, FFMPEG sometimes seeks past the point we
   // set as the max_ts. So we use our own index to give it the exact pts of
   // the key frame that we want to seek to.
   // See https://github.com/pytorch/torchcodec/issues/179 for more details.
   // See https://trac.ffmpeg.org/ticket/11137 for the underlying ffmpeg bug.
-  if (!firstStreamInfo.keyFrames.empty()) {
+  if (!streamInfo.keyFrames.empty()) {
     int desiredKeyFrameIndex = getKeyFrameIndexForPtsUsingScannedIndex(
-        firstStreamInfo.keyFrames, desiredPts);
+        streamInfo.keyFrames, desiredPts);
     desiredKeyFrameIndex = std::max(desiredKeyFrameIndex, 0);
-    desiredPts = firstStreamInfo.keyFrames[desiredKeyFrameIndex].pts;
+    desiredPts = streamInfo.keyFrames[desiredKeyFrameIndex].pts;
   }

   int ffmepgStatus = avformat_seek_file(
       formatContext_.get(),
-      firstStreamInfo.streamIndex,
+      streamInfo.streamIndex,
       INT64_MIN,
       desiredPts,
       desiredPts,
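
The comment in this hunk explains why the seek target is snapped to a key frame taken from the decoder's own scanned index rather than trusting avformat_seek_file's max_ts. As a rough illustration of that lookup, assuming the index is a pts-sorted list (the real getKeyFrameIndexForPtsUsingScannedIndex() may be implemented differently):

#include <algorithm>
#include <cstdint>
#include <vector>

// Illustrative stand-in: return the index of the last key frame whose pts is
// <= desiredPts, or -1 if every key frame starts after desiredPts.
int keyFrameIndexForPts(const std::vector<int64_t>& keyFramePts, int64_t desiredPts) {
  auto it = std::upper_bound(keyFramePts.begin(), keyFramePts.end(), desiredPts);
  return static_cast<int>(it - keyFramePts.begin()) - 1;
}

A -1 result corresponds to the std::max(desiredKeyFrameIndex, 0) clamp above: if the requested pts falls before the first key frame, the seek simply targets the first key frame.
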
@@ -811,15 +795,12 @@ void VideoDecoder::maybeSeekToBeforeDesiredPts() {
         getFFMPEGErrorStringFromErrorCode(ffmepgStatus));
   }
   decodeStats_.numFlushes++;
-  for (int streamIndex : activeStreamIndices_) {
-    StreamInfo& streamInfo = streamInfos_[streamIndex];
-    avcodec_flush_buffers(streamInfo.codecContext.get());
-  }
+  avcodec_flush_buffers(streamInfo.codecContext.get());
 }

 VideoDecoder::AVFrameStream VideoDecoder::decodeAVFrame(
-    std::function<bool(int, AVFrame*)> filterFunction) {
-  if (activeStreamIndices_.size() == 0) {
+    std::function<bool(AVFrame*)> filterFunction) {
+  if (activeStreamIndex_ == NO_ACTIVE_STREAM) {
     throw std::runtime_error("No active streams configured.");
   }

@@ -831,44 +812,25 @@ VideoDecoder::AVFrameStream VideoDecoder::decodeAVFrame(
     desiredPtsSeconds_ = std::nullopt;
   }

+  StreamInfo& streamInfo = streamInfos_[activeStreamIndex_];
+
   // Need to get the next frame or error from PopFrame.
   UniqueAVFrame avFrame(av_frame_alloc());
   AutoAVPacket autoAVPacket;
   int ffmpegStatus = AVSUCCESS;
   bool reachedEOF = false;
-  int frameStreamIndex = -1;
   while (true) {
-    frameStreamIndex = -1;
-    bool gotPermanentErrorOnAnyActiveStream = false;
-
-    // Get a frame on an active stream. Note that we don't know ahead of time
-    // which streams have frames to receive, so we linearly try the active
-    // streams.
-    for (int streamIndex : activeStreamIndices_) {
-      StreamInfo& streamInfo = streamInfos_[streamIndex];
-      ffmpegStatus =
-          avcodec_receive_frame(streamInfo.codecContext.get(), avFrame.get());
-
-      if (ffmpegStatus != AVSUCCESS && ffmpegStatus != AVERROR(EAGAIN)) {
-        gotPermanentErrorOnAnyActiveStream = true;
-        break;
-      }
+    ffmpegStatus =
+        avcodec_receive_frame(streamInfo.codecContext.get(), avFrame.get());

-      if (ffmpegStatus == AVSUCCESS) {
-        frameStreamIndex = streamIndex;
-        break;
-      }
-    }
-
-    if (gotPermanentErrorOnAnyActiveStream) {
+    if (ffmpegStatus != AVSUCCESS && ffmpegStatus != AVERROR(EAGAIN)) {
+      // Non-retriable error
      break;
    }

    decodeStats_.numFramesReceivedByDecoder++;
-
    // Is this the kind of frame we're looking for?
-    if (ffmpegStatus == AVSUCCESS &&
-        filterFunction(frameStreamIndex, avFrame.get())) {
+    if (ffmpegStatus == AVSUCCESS && filterFunction(avFrame.get())) {
      // Yes, this is the frame we'll return; break out of the decoding loop.
      break;
    } else if (ffmpegStatus == AVSUCCESS) {
@@ -893,18 +855,15 @@ VideoDecoder::AVFrameStream VideoDecoder::decodeAVFrame(
     decodeStats_.numPacketsRead++;

     if (ffmpegStatus == AVERROR_EOF) {
-      // End of file reached. We must drain all codecs by sending a nullptr
+      // End of file reached. We must drain the codec by sending a nullptr
       // packet.
-      for (int streamIndex : activeStreamIndices_) {
-        StreamInfo& streamInfo = streamInfos_[streamIndex];
-        ffmpegStatus = avcodec_send_packet(
-            streamInfo.codecContext.get(),
-            /*avpkt=*/nullptr);
-        if (ffmpegStatus < AVSUCCESS) {
-          throw std::runtime_error(
-              "Could not flush decoder: " +
-              getFFMPEGErrorStringFromErrorCode(ffmpegStatus));
-        }
+      ffmpegStatus = avcodec_send_packet(
+          streamInfo.codecContext.get(),
+          /*avpkt=*/nullptr);
+      if (ffmpegStatus < AVSUCCESS) {
+        throw std::runtime_error(
+            "Could not flush decoder: " +
+            getFFMPEGErrorStringFromErrorCode(ffmpegStatus));
       }

       // We've reached the end of file so we can't read any more packets from
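
The EOF branch above now drains only the single active codec. A standalone sketch of the standard FFmpeg drain idiom it follows (error handling simplified; the function name is illustrative):

extern "C" {
#include <libavcodec/avcodec.h>
}

// Send a nullptr packet to signal end of input, then keep receiving until the
// decoder has handed back every frame it still buffers internally.
void drainCodec(AVCodecContext* codecContext, AVFrame* avFrame) {
  avcodec_send_packet(codecContext, /*avpkt=*/nullptr);
  while (avcodec_receive_frame(codecContext, avFrame) == 0) {
    // ... consume the drained frame here ...
    av_frame_unref(avFrame);
  }
}
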
@@ -920,15 +879,14 @@ VideoDecoder::AVFrameStream VideoDecoder::decodeAVFrame(
           getFFMPEGErrorStringFromErrorCode(ffmpegStatus));
     }

-    if (activeStreamIndices_.count(packet->stream_index) == 0) {
-      // This packet is not for any of the active streams.
+    if (packet->stream_index != activeStreamIndex_) {
       continue;
     }

     // We got a valid packet. Send it to the decoder, and we'll receive it in
     // the next iteration.
-    ffmpegStatus = avcodec_send_packet(
-        streamInfos_[packet->stream_index].codecContext.get(), packet.get());
+    ffmpegStatus =
+        avcodec_send_packet(streamInfo.codecContext.get(), packet.get());
     if (ffmpegStatus < AVSUCCESS) {
       throw std::runtime_error(
           "Could not push packet to decoder: " +
@@ -955,11 +913,10 @@ VideoDecoder::AVFrameStream VideoDecoder::decodeAVFrame(
   // haven't received as frames. Eventually we will either hit AVERROR_EOF from
   // av_receive_frame() or the user will have seeked to a different location in
   // the file and that will flush the decoder.
-  StreamInfo& activeStreamInfo = streamInfos_[frameStreamIndex];
-  activeStreamInfo.currentPts = avFrame->pts;
-  activeStreamInfo.currentDuration = getDuration(avFrame);
+  streamInfo.currentPts = avFrame->pts;
+  streamInfo.currentDuration = getDuration(avFrame);

-  return AVFrameStream(std::move(avFrame), frameStreamIndex);
+  return AVFrameStream(std::move(avFrame), activeStreamIndex_);
 }

 VideoDecoder::FrameOutput VideoDecoder::convertAVFrameToFrameOutput(
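
With a single active stream, decodeAVFrame() reduces to the plain FFmpeg send/receive loop. A self-contained sketch of that loop under the same assumption (one stream of interest, simplified error handling; names are illustrative, not the actual implementation):

extern "C" {
#include <libavcodec/avcodec.h>
#include <libavformat/avformat.h>
}

// Returns the next decoded frame from the given stream, or nullptr once the
// decoder is fully drained or hits a permanent error.
AVFrame* decodeNextFrame(
    AVFormatContext* formatContext,
    AVCodecContext* codecContext,
    int streamIndex) {
  AVFrame* frame = av_frame_alloc();
  AVPacket* packet = av_packet_alloc();
  while (true) {
    int status = avcodec_receive_frame(codecContext, frame);
    if (status == 0) {
      break;  // Got a decoded frame.
    }
    if (status != AVERROR(EAGAIN)) {
      av_frame_free(&frame);  // Permanent error or fully drained.
      break;
    }
    // The decoder wants more input: read packets until one is for our stream.
    if (av_read_frame(formatContext, packet) < 0) {
      avcodec_send_packet(codecContext, nullptr);  // EOF: start draining.
      continue;
    }
    if (packet->stream_index != streamIndex) {
      av_packet_unref(packet);  // Packet belongs to another stream; skip it.
      continue;
    }
    avcodec_send_packet(codecContext, packet);
    av_packet_unref(packet);
  }
  av_packet_free(&packet);
  return frame;
}
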
@@ -1124,8 +1081,8 @@ VideoDecoder::FrameOutput VideoDecoder::getFramePlayedAtNoDemux(

   setCursorPtsInSeconds(seconds);
   AVFrameStream avFrameStream =
-      decodeAVFrame([seconds, this](int frameStreamIndex, AVFrame* avFrame) {
-        StreamInfo& streamInfo = streamInfos_[frameStreamIndex];
+      decodeAVFrame([seconds, this](AVFrame* avFrame) {
+        StreamInfo& streamInfo = streamInfos_[activeStreamIndex_];
         double frameStartTime = ptsToSeconds(avFrame->pts, streamInfo.timeBase);
         double frameEndTime = ptsToSeconds(
             avFrame->pts + getDuration(avFrame), streamInfo.timeBase);
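
The lambda above keeps the frame whose [start, end) window in seconds contains the requested timestamp. As a rough sketch of the conversions involved, assuming ptsToSeconds() and secondsToClosestPts() are plain AVRational time-base scalings (the actual helpers may round differently):

#include <cmath>
#include <cstdint>
extern "C" {
#include <libavutil/rational.h>
}

// pts -> seconds: scale by the stream time base (num/den).
double ptsToSecondsSketch(int64_t pts, AVRational timeBase) {
  return static_cast<double>(pts) * timeBase.num / timeBase.den;
}

// seconds -> pts: divide by the time base and round to the nearest tick.
int64_t secondsToClosestPtsSketch(double seconds, AVRational timeBase) {
  return static_cast<int64_t>(std::round(seconds * timeBase.den / timeBase.num));
}
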
@@ -1524,11 +1481,10 @@ VideoDecoder::FrameOutput VideoDecoder::getNextFrameNoDemux() {

 VideoDecoder::FrameOutput VideoDecoder::getNextFrameNoDemuxInternal(
     std::optional<torch::Tensor> preAllocatedOutputTensor) {
-  AVFrameStream avFrameStream =
-      decodeAVFrame([this](int frameStreamIndex, AVFrame* avFrame) {
-        StreamInfo& activeStreamInfo = streamInfos_[frameStreamIndex];
-        return avFrame->pts >= activeStreamInfo.discardFramesBeforePts;
-      });
+  AVFrameStream avFrameStream = decodeAVFrame([this](AVFrame* avFrame) {
+    StreamInfo& activeStreamInfo = streamInfos_[activeStreamIndex_];
+    return avFrame->pts >= activeStreamInfo.discardFramesBeforePts;
+  });
   return convertAVFrameToFrameOutput(avFrameStream, preAllocatedOutputTensor);
 }
