From 59e428ea86ade687806cd690e042636e4e61cd0b Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Fri, 14 Mar 2025 16:01:00 +0000 Subject: [PATCH 1/2] Audio: allow next(), disallow seek() --- src/torchcodec/decoders/_core/VideoDecoder.cpp | 13 ++++++++++--- src/torchcodec/decoders/_core/VideoDecoder.h | 1 + test/decoders/test_ops.py | 16 +++++++++++++++- 3 files changed, 26 insertions(+), 4 deletions(-) diff --git a/src/torchcodec/decoders/_core/VideoDecoder.cpp b/src/torchcodec/decoders/_core/VideoDecoder.cpp index 79bad294..af82f2fe 100644 --- a/src/torchcodec/decoders/_core/VideoDecoder.cpp +++ b/src/torchcodec/decoders/_core/VideoDecoder.cpp @@ -566,13 +566,15 @@ void VideoDecoder::addAudioStream(int streamIndex) { VideoDecoder::FrameOutput VideoDecoder::getNextFrame() { auto output = getNextFrameInternal(); - output.data = maybePermuteHWC2CHW(output.data); + if (streamInfos_[activeStreamIndex_].avMediaType == AVMEDIA_TYPE_VIDEO) { + output.data = maybePermuteHWC2CHW(output.data); + } return output; } VideoDecoder::FrameOutput VideoDecoder::getNextFrameInternal( std::optional preAllocatedOutputTensor) { - validateActiveStream(AVMEDIA_TYPE_VIDEO); + validateActiveStream(); AVFrameStream avFrameStream = decodeAVFrame( [this](AVFrame* avFrame) { return avFrame->pts >= cursor_; }); return convertAVFrameToFrameOutput(avFrameStream, preAllocatedOutputTensor); @@ -868,7 +870,7 @@ VideoDecoder::AudioFramesOutput VideoDecoder::getFramesPlayedInRangeAudio( // If we need to seek backwards, then we have to seek back to the beginning // of the stream. // TODO-AUDIO: document why this is needed in a big comment. - setCursorPtsInSeconds(INT64_MIN); + setCursorPtsInSecondsInternal(INT64_MIN); } // TODO-AUDIO Pre-allocate a long-enough tensor instead of creating a vec + @@ -914,6 +916,11 @@ VideoDecoder::AudioFramesOutput VideoDecoder::getFramesPlayedInRangeAudio( // -------------------------------------------------------------------------- void VideoDecoder::setCursorPtsInSeconds(double seconds) { + validateActiveStream(AVMEDIA_TYPE_VIDEO); + setCursorPtsInSecondsInternal(seconds); +} + +void VideoDecoder::setCursorPtsInSecondsInternal(double seconds) { cursorWasJustSet_ = true; cursor_ = secondsToClosestPts(seconds, streamInfos_[activeStreamIndex_].timeBase); diff --git a/src/torchcodec/decoders/_core/VideoDecoder.h b/src/torchcodec/decoders/_core/VideoDecoder.h index 51a780fb..63e01899 100644 --- a/src/torchcodec/decoders/_core/VideoDecoder.h +++ b/src/torchcodec/decoders/_core/VideoDecoder.h @@ -370,6 +370,7 @@ class VideoDecoder { // DECODING APIS AND RELATED UTILS // -------------------------------------------------------------------------- + void setCursorPtsInSecondsInternal(double seconds); bool canWeAvoidSeeking() const; void maybeSeekToBeforeDesiredPts(); diff --git a/test/decoders/test_ops.py b/test/decoders/test_ops.py index 724eff62..7faf0488 100644 --- a/test/decoders/test_ops.py +++ b/test/decoders/test_ops.py @@ -626,7 +626,7 @@ class TestAudioOps: partial(get_frames_in_range, start=4, stop=5), partial(get_frame_at_pts, seconds=2), partial(get_frames_by_pts, timestamps=[0, 1.5]), - partial(get_next_frame), + partial(seek_to_pts, seconds=5), ), ) def test_audio_bad_method(self, method): @@ -642,6 +642,20 @@ def test_audio_bad_seek_mode(self): ): add_audio_stream(decoder) + @pytest.mark.parametrize("asset", (NASA_AUDIO, NASA_AUDIO_MP3)) + def test_next(self, asset): + decoder = create_from_file(str(asset.path), seek_mode="approximate") + add_audio_stream(decoder) + + frame_index = 0 + while True: + try: + frame, *_ = get_next_frame(decoder) + except IndexError: + break + torch.testing.assert_close(frame, asset.get_frame_data_by_index(frame_index)) + frame_index += 1 + @pytest.mark.parametrize( "range", ( From ca7d554769d985b17061e16e7776df00a6639a19 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Fri, 14 Mar 2025 17:11:10 +0000 Subject: [PATCH 2/2] Lint --- src/torchcodec/decoders/_core/VideoDecoder.cpp | 4 ++-- test/decoders/test_ops.py | 4 +++- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/src/torchcodec/decoders/_core/VideoDecoder.cpp b/src/torchcodec/decoders/_core/VideoDecoder.cpp index af82f2fe..f2191b0d 100644 --- a/src/torchcodec/decoders/_core/VideoDecoder.cpp +++ b/src/torchcodec/decoders/_core/VideoDecoder.cpp @@ -916,8 +916,8 @@ VideoDecoder::AudioFramesOutput VideoDecoder::getFramesPlayedInRangeAudio( // -------------------------------------------------------------------------- void VideoDecoder::setCursorPtsInSeconds(double seconds) { - validateActiveStream(AVMEDIA_TYPE_VIDEO); - setCursorPtsInSecondsInternal(seconds); + validateActiveStream(AVMEDIA_TYPE_VIDEO); + setCursorPtsInSecondsInternal(seconds); } void VideoDecoder::setCursorPtsInSecondsInternal(double seconds) { diff --git a/test/decoders/test_ops.py b/test/decoders/test_ops.py index 7faf0488..41986a0f 100644 --- a/test/decoders/test_ops.py +++ b/test/decoders/test_ops.py @@ -653,7 +653,9 @@ def test_next(self, asset): frame, *_ = get_next_frame(decoder) except IndexError: break - torch.testing.assert_close(frame, asset.get_frame_data_by_index(frame_index)) + torch.testing.assert_close( + frame, asset.get_frame_data_by_index(frame_index) + ) frame_index += 1 @pytest.mark.parametrize(