Skip to content

Commit 28b0de0

Browse files
authored
Remove multi-stream related code (#483)
1 parent 298c9c1 commit 28b0de0

File tree

2 files changed

+50
-96
lines changed

2 files changed

+50
-96
lines changed

src/torchcodec/decoders/_core/VideoDecoder.cpp

Lines changed: 47 additions & 91 deletions
Original file line number | Diff line number | Diff line change
@@ -435,11 +435,9 @@ int VideoDecoder::getBestStreamIndex(AVMediaType mediaType) {
435435
void VideoDecoder::addVideoStreamDecoder(
436436
int preferredStreamIndex,
437437
const VideoStreamOptions& videoStreamOptions) {
438-
if (activeStreamIndices_.count(preferredStreamIndex) > 0) {
439-
throw std::invalid_argument(
440-
"Stream with index " + std::to_string(preferredStreamIndex) +
441-
" is already active.");
442-
}
438+
TORCH_CHECK(
439+
activeStreamIndex_ == NO_ACTIVE_STREAM,
440+
"Can only add one single stream.");
443441
TORCH_CHECK(formatContext_.get() != nullptr);
444442

445443
AVCodecOnlyUseForCallingAVFindBestStream avCodec = nullptr;
@@ -506,7 +504,7 @@ void VideoDecoder::addVideoStreamDecoder(
506504
}
507505

508506
codecContext->time_base = streamInfo.stream->time_base;
509-
activeStreamIndices_.insert(streamIndex);
507+
activeStreamIndex_ = streamIndex;
510508
updateMetadataWithCodecContext(streamInfo.streamIndex, codecContext);
511509
streamInfo.videoStreamOptions = videoStreamOptions;
512510

@@ -754,53 +752,39 @@ bool VideoDecoder::canWeAvoidSeekingForStream(
754752
// AVFormatContext if it is needed. We can skip seeking in certain cases. See
755753
// the comment of canWeAvoidSeekingForStream() for details.
756754
void VideoDecoder::maybeSeekToBeforeDesiredPts() {
757-
if (activeStreamIndices_.size() == 0) {
755+
if (activeStreamIndex_ == NO_ACTIVE_STREAM) {
758756
return;
759757
}
760-
for (int streamIndex : activeStreamIndices_) {
761-
StreamInfo& streamInfo = streamInfos_[streamIndex];
762-
// clang-format off: clang format clashes
763-
streamInfo.discardFramesBeforePts = secondsToClosestPts(*desiredPtsSeconds_, streamInfo.timeBase);
764-
// clang-format on
765-
}
758+
StreamInfo& streamInfo = streamInfos_[activeStreamIndex_];
759+
streamInfo.discardFramesBeforePts =
760+
secondsToClosestPts(*desiredPtsSeconds_, streamInfo.timeBase);
766761

767762
decodeStats_.numSeeksAttempted++;
768-
// See comment for canWeAvoidSeeking() for details on why this optimization
769-
// works.
770-
bool mustSeek = false;
771-
for (int streamIndex : activeStreamIndices_) {
772-
StreamInfo& streamInfo = streamInfos_[streamIndex];
773-
int64_t desiredPtsForStream = *desiredPtsSeconds_ * streamInfo.timeBase.den;
774-
if (!canWeAvoidSeekingForStream(
775-
streamInfo, streamInfo.currentPts, desiredPtsForStream)) {
776-
mustSeek = true;
777-
break;
778-
}
779-
}
780-
if (!mustSeek) {
763+
764+
int64_t desiredPtsForStream = *desiredPtsSeconds_ * streamInfo.timeBase.den;
765+
if (canWeAvoidSeekingForStream(
766+
streamInfo, streamInfo.currentPts, desiredPtsForStream)) {
781767
decodeStats_.numSeeksSkipped++;
782768
return;
783769
}
784-
int firstActiveStreamIndex = *activeStreamIndices_.begin();
785-
const auto& firstStreamInfo = streamInfos_[firstActiveStreamIndex];
786770
int64_t desiredPts =
787-
secondsToClosestPts(*desiredPtsSeconds_, firstStreamInfo.timeBase);
771+
secondsToClosestPts(*desiredPtsSeconds_, streamInfo.timeBase);
788772

789773
// For some encodings like H265, FFMPEG sometimes seeks past the point we
790774
// set as the max_ts. So we use our own index to give it the exact pts of
791775
// the key frame that we want to seek to.
792776
// See https://github.com/pytorch/torchcodec/issues/179 for more details.
793777
// See https://trac.ffmpeg.org/ticket/11137 for the underlying ffmpeg bug.
794-
if (!firstStreamInfo.keyFrames.empty()) {
778+
if (!streamInfo.keyFrames.empty()) {
795779
int desiredKeyFrameIndex = getKeyFrameIndexForPtsUsingScannedIndex(
796-
firstStreamInfo.keyFrames, desiredPts);
780+
streamInfo.keyFrames, desiredPts);
797781
desiredKeyFrameIndex = std::max(desiredKeyFrameIndex, 0);
798-
desiredPts = firstStreamInfo.keyFrames[desiredKeyFrameIndex].pts;
782+
desiredPts = streamInfo.keyFrames[desiredKeyFrameIndex].pts;
799783
}
800784

801785
int ffmepgStatus = avformat_seek_file(
802786
formatContext_.get(),
803-
firstStreamInfo.streamIndex,
787+
streamInfo.streamIndex,
804788
INT64_MIN,
805789
desiredPts,
806790
desiredPts,
@@ -811,15 +795,12 @@ void VideoDecoder::maybeSeekToBeforeDesiredPts() {
811795
getFFMPEGErrorStringFromErrorCode(ffmepgStatus));
812796
}
813797
decodeStats_.numFlushes++;
814-
for (int streamIndex : activeStreamIndices_) {
815-
StreamInfo& streamInfo = streamInfos_[streamIndex];
816-
avcodec_flush_buffers(streamInfo.codecContext.get());
817-
}
798+
avcodec_flush_buffers(streamInfo.codecContext.get());
818799
}
819800

820801
VideoDecoder::AVFrameStream VideoDecoder::decodeAVFrame(
821-
std::function<bool(int, AVFrame*)> filterFunction) {
822-
if (activeStreamIndices_.size() == 0) {
802+
std::function<bool(AVFrame*)> filterFunction) {
803+
if (activeStreamIndex_ == NO_ACTIVE_STREAM) {
823804
throw std::runtime_error("No active streams configured.");
824805
}
825806

@@ -831,44 +812,25 @@ VideoDecoder::AVFrameStream VideoDecoder::decodeAVFrame(
831812
desiredPtsSeconds_ = std::nullopt;
832813
}
833814

815+
StreamInfo& streamInfo = streamInfos_[activeStreamIndex_];
816+
834817
// Need to get the next frame or error from PopFrame.
835818
UniqueAVFrame avFrame(av_frame_alloc());
836819
AutoAVPacket autoAVPacket;
837820
int ffmpegStatus = AVSUCCESS;
838821
bool reachedEOF = false;
839-
int frameStreamIndex = -1;
840822
while (true) {
841-
frameStreamIndex = -1;
842-
bool gotPermanentErrorOnAnyActiveStream = false;
843-
844-
// Get a frame on an active stream. Note that we don't know ahead of time
845-
// which streams have frames to receive, so we linearly try the active
846-
// streams.
847-
for (int streamIndex : activeStreamIndices_) {
848-
StreamInfo& streamInfo = streamInfos_[streamIndex];
849-
ffmpegStatus =
850-
avcodec_receive_frame(streamInfo.codecContext.get(), avFrame.get());
851-
852-
if (ffmpegStatus != AVSUCCESS && ffmpegStatus != AVERROR(EAGAIN)) {
853-
gotPermanentErrorOnAnyActiveStream = true;
854-
break;
855-
}
823+
ffmpegStatus =
824+
avcodec_receive_frame(streamInfo.codecContext.get(), avFrame.get());
856825

857-
if (ffmpegStatus == AVSUCCESS) {
858-
frameStreamIndex = streamIndex;
859-
break;
860-
}
861-
}
862-
863-
if (gotPermanentErrorOnAnyActiveStream) {
826+
if (ffmpegStatus != AVSUCCESS && ffmpegStatus != AVERROR(EAGAIN)) {
827+
// Non-retriable error
864828
break;
865829
}
866830

867831
decodeStats_.numFramesReceivedByDecoder++;
868-
869832
// Is this the kind of frame we're looking for?
870-
if (ffmpegStatus == AVSUCCESS &&
871-
filterFunction(frameStreamIndex, avFrame.get())) {
833+
if (ffmpegStatus == AVSUCCESS && filterFunction(avFrame.get())) {
872834
// Yes, this is the frame we'll return; break out of the decoding loop.
873835
break;
874836
} else if (ffmpegStatus == AVSUCCESS) {
@@ -893,18 +855,15 @@ VideoDecoder::AVFrameStream VideoDecoder::decodeAVFrame(
893855
decodeStats_.numPacketsRead++;
894856

895857
if (ffmpegStatus == AVERROR_EOF) {
896-
// End of file reached. We must drain all codecs by sending a nullptr
858+
// End of file reached. We must drain the codec by sending a nullptr
897859
// packet.
898-
for (int streamIndex : activeStreamIndices_) {
899-
StreamInfo& streamInfo = streamInfos_[streamIndex];
900-
ffmpegStatus = avcodec_send_packet(
901-
streamInfo.codecContext.get(),
902-
/*avpkt=*/nullptr);
903-
if (ffmpegStatus < AVSUCCESS) {
904-
throw std::runtime_error(
905-
"Could not flush decoder: " +
906-
getFFMPEGErrorStringFromErrorCode(ffmpegStatus));
907-
}
860+
ffmpegStatus = avcodec_send_packet(
861+
streamInfo.codecContext.get(),
862+
/*avpkt=*/nullptr);
863+
if (ffmpegStatus < AVSUCCESS) {
864+
throw std::runtime_error(
865+
"Could not flush decoder: " +
866+
getFFMPEGErrorStringFromErrorCode(ffmpegStatus));
908867
}
909868

910869
// We've reached the end of file so we can't read any more packets from
@@ -920,15 +879,14 @@ VideoDecoder::AVFrameStream VideoDecoder::decodeAVFrame(
920879
getFFMPEGErrorStringFromErrorCode(ffmpegStatus));
921880
}
922881

923-
if (activeStreamIndices_.count(packet->stream_index) == 0) {
924-
// This packet is not for any of the active streams.
882+
if (packet->stream_index != activeStreamIndex_) {
925883
continue;
926884
}
927885

928886
// We got a valid packet. Send it to the decoder, and we'll receive it in
929887
// the next iteration.
930-
ffmpegStatus = avcodec_send_packet(
931-
streamInfos_[packet->stream_index].codecContext.get(), packet.get());
888+
ffmpegStatus =
889+
avcodec_send_packet(streamInfo.codecContext.get(), packet.get());
932890
if (ffmpegStatus < AVSUCCESS) {
933891
throw std::runtime_error(
934892
"Could not push packet to decoder: " +
@@ -955,11 +913,10 @@ VideoDecoder::AVFrameStream VideoDecoder::decodeAVFrame(
955913
// haven't received as frames. Eventually we will either hit AVERROR_EOF from
956914
// avcodec_receive_frame() or the user will have seeked to a different location in
957915
// the file and that will flush the decoder.
958-
StreamInfo& activeStreamInfo = streamInfos_[frameStreamIndex];
959-
activeStreamInfo.currentPts = avFrame->pts;
960-
activeStreamInfo.currentDuration = getDuration(avFrame);
916+
streamInfo.currentPts = avFrame->pts;
917+
streamInfo.currentDuration = getDuration(avFrame);
961918

962-
return AVFrameStream(std::move(avFrame), frameStreamIndex);
919+
return AVFrameStream(std::move(avFrame), activeStreamIndex_);
963920
}
964921

965922
VideoDecoder::FrameOutput VideoDecoder::convertAVFrameToFrameOutput(
@@ -1124,8 +1081,8 @@ VideoDecoder::FrameOutput VideoDecoder::getFramePlayedAtNoDemux(
11241081

11251082
setCursorPtsInSeconds(seconds);
11261083
AVFrameStream avFrameStream =
1127-
decodeAVFrame([seconds, this](int frameStreamIndex, AVFrame* avFrame) {
1128-
StreamInfo& streamInfo = streamInfos_[frameStreamIndex];
1084+
decodeAVFrame([seconds, this](AVFrame* avFrame) {
1085+
StreamInfo& streamInfo = streamInfos_[activeStreamIndex_];
11291086
double frameStartTime = ptsToSeconds(avFrame->pts, streamInfo.timeBase);
11301087
double frameEndTime = ptsToSeconds(
11311088
avFrame->pts + getDuration(avFrame), streamInfo.timeBase);
@@ -1524,11 +1481,10 @@ VideoDecoder::FrameOutput VideoDecoder::getNextFrameNoDemux() {
15241481

15251482
VideoDecoder::FrameOutput VideoDecoder::getNextFrameNoDemuxInternal(
15261483
std::optional<torch::Tensor> preAllocatedOutputTensor) {
1527-
AVFrameStream avFrameStream =
1528-
decodeAVFrame([this](int frameStreamIndex, AVFrame* avFrame) {
1529-
StreamInfo& activeStreamInfo = streamInfos_[frameStreamIndex];
1530-
return avFrame->pts >= activeStreamInfo.discardFramesBeforePts;
1531-
});
1484+
AVFrameStream avFrameStream = decodeAVFrame([this](AVFrame* avFrame) {
1485+
StreamInfo& activeStreamInfo = streamInfos_[activeStreamIndex_];
1486+
return avFrame->pts >= activeStreamInfo.discardFramesBeforePts;
1487+
});
15321488
return convertAVFrameToFrameOutput(avFrameStream, preAllocatedOutputTensor);
15331489
}
15341490

src/torchcodec/decoders/_core/VideoDecoder.h

Lines changed: 3 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -372,8 +372,7 @@ class VideoDecoder {
372372

373373
void maybeSeekToBeforeDesiredPts();
374374

375-
AVFrameStream decodeAVFrame(
376-
std::function<bool(int, AVFrame*)> filterFunction);
375+
AVFrameStream decodeAVFrame(std::function<bool(AVFrame*)> filterFunction);
377376

378377
FrameOutput getNextFrameNoDemuxInternal(
379378
std::optional<torch::Tensor> preAllocatedOutputTensor = std::nullopt);
@@ -480,9 +479,8 @@ class VideoDecoder {
480479
ContainerMetadata containerMetadata_;
481480
UniqueAVFormatContext formatContext_;
482481
std::map<int, StreamInfo> streamInfos_;
483-
// Stores the stream indices of the active streams, i.e. the streams we are
484-
// decoding and returning to the user.
485-
std::set<int> activeStreamIndices_;
482+
const int NO_ACTIVE_STREAM = -2;
483+
int activeStreamIndex_ = NO_ACTIVE_STREAM;
486484
// Set when the user wants to seek and stores the desired pts that the user
487485
// wants to seek to.
488486
std::optional<double> desiredPtsSeconds_;

0 commit comments

Comments
 (0)