From 27a5c3827d7db39411de4e9d34a886b94044f269 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Tue, 29 Apr 2025 14:15:47 +0100 Subject: [PATCH] AudioDecoder: Fix output when stop < start of first frame --- src/torchcodec/_core/SingleStreamDecoder.cpp | 14 ++++++--- test/test_decoders.py | 32 +++++++++++++++++++- test/test_ops.py | 17 ----------- 3 files changed, 40 insertions(+), 23 deletions(-) diff --git a/src/torchcodec/_core/SingleStreamDecoder.cpp b/src/torchcodec/_core/SingleStreamDecoder.cpp index 17e1301d..4bde9986 100644 --- a/src/torchcodec/_core/SingleStreamDecoder.cpp +++ b/src/torchcodec/_core/SingleStreamDecoder.cpp @@ -877,8 +877,9 @@ AudioFramesOutput SingleStreamDecoder::getFramesPlayedInRangeAudio( while (!finished) { try { UniqueAVFrame avFrame = - decodeAVFrame([startPts](const UniqueAVFrame& avFrame) { - return startPts < avFrame->pts + getDuration(avFrame); + decodeAVFrame([startPts, stopPts](const UniqueAVFrame& avFrame) { + return startPts < avFrame->pts + getDuration(avFrame) && + stopPts > avFrame->pts; }); auto frameOutput = convertAVFrameToFrameOutput(avFrame); if (!firstFramePtsSeconds.has_value()) { @@ -907,9 +908,12 @@ AudioFramesOutput SingleStreamDecoder::getFramesPlayedInRangeAudio( TORCH_CHECK( frames.size() > 0 && firstFramePtsSeconds.has_value(), "No audio frames were decoded. ", - "This is probably because start_seconds is too high? ", - "Current value is ", - startSeconds); + "This is probably because start_seconds is too high(", + startSeconds, + "),", + "or because stop_seconds(", + stopSecondsOptional, + ") is too low."); return AudioFramesOutput{torch::cat(frames, 1), *firstFramePtsSeconds}; } diff --git a/test/test_decoders.py b/test/test_decoders.py index c68e1ace..a2859bca 100644 --- a/test/test_decoders.py +++ b/test/test_decoders.py @@ -1112,7 +1112,8 @@ def test_start_equals_stop(self, asset): def test_frame_start_is_not_zero(self): # For NASA_AUDIO_MP3, the first frame is not at 0, it's at 0.138125. - # So if we request start = 0.05, we shouldn't be truncating anything. + # So if we request (start, stop) = (0.05, None), we shouldn't be + # truncating anything. asset = NASA_AUDIO_MP3 start_seconds = 0.05 # this is less than the first frame's pts @@ -1128,6 +1129,35 @@ def test_frame_start_is_not_zero(self): reference_frames = asset.get_frame_data_by_range(start=0, stop=stop_frame_index) torch.testing.assert_close(samples.data, reference_frames) + # Non-regression test for https://github.com/pytorch/torchcodec/issues/567 + # If we ask for start < stop <= first_frame_pts, we should raise. + with pytest.raises(RuntimeError, match="No audio frames were decoded"): + decoder.get_samples_played_in_range(start_seconds=0, stop_seconds=0.05) + + first_frame_pts_seconds = asset.get_frame_info(idx=0).pts_seconds + with pytest.raises(RuntimeError, match="No audio frames were decoded"): + decoder.get_samples_played_in_range( + start_seconds=0, stop_seconds=first_frame_pts_seconds + ) + + # Documenting an edge case: we ask for samples barely beyond the start + # of the first frame. The C++ decoder returns the first frame, which + # gets (correctly!) truncated by the AudioDecoder, and we end up with + # empty data. + samples = decoder.get_samples_played_in_range( + start_seconds=0, stop_seconds=first_frame_pts_seconds + 1e-5 + ) + assert samples.data.shape == (2, 0) + assert samples.pts_seconds == first_frame_pts_seconds + assert samples.duration_seconds == 0 + + # if we ask for a little bit more samples, we get non-empty data + samples = decoder.get_samples_played_in_range( + start_seconds=0, stop_seconds=first_frame_pts_seconds + 1e-3 + ) + assert samples.data.shape == (2, 8) + assert samples.pts_seconds == first_frame_pts_seconds + def test_single_channel(self): asset = SINE_MONO_S32 decoder = AudioDecoder(asset.path) diff --git a/test/test_ops.py b/test/test_ops.py index b41be538..b1735153 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -884,23 +884,6 @@ def test_pts(self, asset): assert pts_seconds == start_seconds - def test_decode_before_frame_start(self): - # Test illustrating bug described in - # https://github.com/pytorch/torchcodec/issues/567 - asset = NASA_AUDIO_MP3 - - decoder = create_from_file(str(asset.path), seek_mode="approximate") - add_audio_stream(decoder) - - frames, *_ = get_frames_by_pts_in_range_audio( - decoder, start_seconds=0, stop_seconds=0.05 - ) - all_frames, *_ = get_frames_by_pts_in_range_audio( - decoder, start_seconds=0, stop_seconds=None - ) - # TODO fix this. `frames` should be empty. - torch.testing.assert_close(frames, all_frames) - def test_sample_rate_conversion(self): def get_all_frames(asset, sample_rate=None, stop_seconds=None): decoder = create_from_file(str(asset.path), seek_mode="approximate")