diff --git a/src/torchcodec/decoders/_core/VideoDecoder.cpp b/src/torchcodec/decoders/_core/VideoDecoder.cpp index c0738e57..aeefa3b5 100644 --- a/src/torchcodec/decoders/_core/VideoDecoder.cpp +++ b/src/torchcodec/decoders/_core/VideoDecoder.cpp @@ -850,7 +850,7 @@ torch::Tensor VideoDecoder::getFramesPlayedInRangeAudio( startSeconds <= stopSeconds, "Start seconds (" + std::to_string(startSeconds) + ") must be less than or equal to stop seconds (" + - std::to_string(stopSeconds) + "."); + std::to_string(stopSeconds) + ")."); if (startSeconds == stopSeconds) { // For consistency with video @@ -859,16 +859,14 @@ torch::Tensor VideoDecoder::getFramesPlayedInRangeAudio( StreamInfo& streamInfo = streamInfos_[activeStreamIndex_]; - // TODO-AUDIO This essentially enforce that we don't need to seek (backwards). - // We should remove it and seek back to the stream's beginning when needed. - // See test_multiple_calls - TORCH_CHECK( - streamInfo.lastDecodedAvFramePts + - streamInfo.lastDecodedAvFrameDuration <= - secondsToClosestPts(startSeconds, streamInfo.timeBase), - "Audio decoder cannot seek backwards, or start from the last decoded frame."); - - setCursorPtsInSeconds(startSeconds); + auto startPts = secondsToClosestPts(startSeconds, streamInfo.timeBase); + if (startPts < streamInfo.lastDecodedAvFramePts + + streamInfo.lastDecodedAvFrameDuration) { + // If we need to seek backwards, then we have to seek back to the beginning + // of the stream. + // TODO-AUDIO: document why this is needed in a big comment. + setCursorPtsInSeconds(INT64_MIN); + } // TODO-AUDIO Pre-allocate a long-enough tensor instead of creating a vec + // cat(). This would save a copy. 
We know the duration of the output and the @@ -879,8 +877,8 @@ torch::Tensor VideoDecoder::getFramesPlayedInRangeAudio( auto finished = false; while (!finished) { try { - AVFrameStream avFrameStream = decodeAVFrame([this](AVFrame* avFrame) { - return cursor_ < avFrame->pts + getDuration(avFrame); + AVFrameStream avFrameStream = decodeAVFrame([startPts](AVFrame* avFrame) { + return startPts < avFrame->pts + getDuration(avFrame); }); auto frameOutput = convertAVFrameToFrameOutput(avFrameStream); tensors.push_back(frameOutput.data); @@ -938,7 +936,9 @@ I P P P I P P P I P P I P P I P bool VideoDecoder::canWeAvoidSeeking() const { const StreamInfo& streamInfo = streamInfos_.at(activeStreamIndex_); if (streamInfo.avMediaType == AVMEDIA_TYPE_AUDIO) { - return true; + // For audio, we only need to seek if a backwards seek was requested within + // getFramesPlayedInRangeAudio(), when setCursorPtsInSeconds() was called. + return !cursorWasJustSet_; } int64_t lastDecodedAvFramePts = streamInfos_.at(activeStreamIndex_).lastDecodedAvFramePts; diff --git a/test/decoders/test_ops.py b/test/decoders/test_ops.py index e33b9941..ce74243f 100644 --- a/test/decoders/test_ops.py +++ b/test/decoders/test_ops.py @@ -741,19 +741,18 @@ def test_decode_start_equal_stop(self, asset): @pytest.mark.parametrize("asset", (NASA_AUDIO, NASA_AUDIO_MP3)) def test_multiple_calls(self, asset): - # Ensure that multiple calls are OK as long as we're decoding - # "sequentially", i.e. we don't require a backwards seek. - # And ensure a proper error is raised in such case. - # TODO-AUDIO We shouldn't error, we should just implement the seeking - # back to the beginning of the stream. + # Ensure that multiple calls to get_frames_by_pts_in_range_audio on the + # same decoder are supported and correct, whether it involves forward + # seeks or backwards seeks. 
def get_reference_frames(start_seconds, stop_seconds): - # This stateless helper exists for convenience, to avoid - # complicating this test with pts-to-index conversions. Eventually - # we should remove it and just rely on the asset's methods. - # Using this helper is OK for now: we're comparing a decoder which - # seeks multiple times with a decoder which seeks only once (the one - # here, treated as the reference) + # Usually we get the reference frames from the asset's methods, but + # for this specific test, this helper is more convenient, because + # relying on the asset would force us to convert all timestamps into + # indices. + # Ultimately, this test compares a "stateful decoder" which calls + # `get_frames_by_pts_in_range_audio()` multiple times with a + # "stateless decoder" (the one here, treated as the reference) decoder = create_from_file(str(asset.path), seek_mode="approximate") add_audio_stream(decoder) @@ -794,23 +793,30 @@ def get_reference_frames(start_seconds, stop_seconds): frames, get_reference_frames(start_seconds, stop_seconds) ) - # but starting immediately on the same frame raises - expected_match = "Audio decoder cannot seek backwards" - with pytest.raises(RuntimeError, match=expected_match): - get_frames_by_pts_in_range_audio( - decoder, start_seconds=stop_seconds, stop_seconds=6 - ) + # starting immediately on the same frame is OK + start_seconds, stop_seconds = stop_seconds, 6 + frames = get_frames_by_pts_in_range_audio( + decoder, start_seconds=start_seconds, stop_seconds=stop_seconds + ) + torch.testing.assert_close( + frames, get_reference_frames(start_seconds, stop_seconds) + ) - with pytest.raises(RuntimeError, match=expected_match): - get_frames_by_pts_in_range_audio( - decoder, start_seconds=stop_seconds + 1e-4, stop_seconds=6 - ) + get_frames_by_pts_in_range_audio( + decoder, start_seconds=start_seconds + 1e-4, stop_seconds=stop_seconds + ) + torch.testing.assert_close( + frames, get_reference_frames(start_seconds, 
stop_seconds) + ) - # and seeking backwards doesn't work either - with pytest.raises(RuntimeError, match=expected_match): - frames = get_frames_by_pts_in_range_audio( - decoder, start_seconds=0, stop_seconds=2 - ) + # seeking backwards + start_seconds, stop_seconds = 0, 2 + frames = get_frames_by_pts_in_range_audio( + decoder, start_seconds=start_seconds, stop_seconds=stop_seconds + ) + torch.testing.assert_close( + frames, get_reference_frames(start_seconds, stop_seconds) + ) if __name__ == "__main__":