Skip to content

Commit c6de04a

Browse files
authored
Allow audio decoder to seek backwards (#550)
1 parent 8b2ad5b commit c6de04a

File tree

2 files changed

+46
-40
lines changed

2 files changed

+46
-40
lines changed

src/torchcodec/decoders/_core/VideoDecoder.cpp

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -850,7 +850,7 @@ torch::Tensor VideoDecoder::getFramesPlayedInRangeAudio(
850850
startSeconds <= stopSeconds,
851851
"Start seconds (" + std::to_string(startSeconds) +
852852
") must be less than or equal to stop seconds (" +
853-
std::to_string(stopSeconds) + ".");
853+
std::to_string(stopSeconds) + ").");
854854

855855
if (startSeconds == stopSeconds) {
856856
// For consistency with video
@@ -859,16 +859,14 @@ torch::Tensor VideoDecoder::getFramesPlayedInRangeAudio(
859859

860860
StreamInfo& streamInfo = streamInfos_[activeStreamIndex_];
861861

862-
// TODO-AUDIO This essentially enforces that we don't need to seek (backwards).
863-
// We should remove it and seek back to the stream's beginning when needed.
864-
// See test_multiple_calls
865-
TORCH_CHECK(
866-
streamInfo.lastDecodedAvFramePts +
867-
streamInfo.lastDecodedAvFrameDuration <=
868-
secondsToClosestPts(startSeconds, streamInfo.timeBase),
869-
"Audio decoder cannot seek backwards, or start from the last decoded frame.");
870-
871-
setCursorPtsInSeconds(startSeconds);
862+
auto startPts = secondsToClosestPts(startSeconds, streamInfo.timeBase);
863+
if (startPts < streamInfo.lastDecodedAvFramePts +
864+
streamInfo.lastDecodedAvFrameDuration) {
865+
// If we need to seek backwards, then we have to seek back to the beginning
866+
// of the stream.
867+
// TODO-AUDIO: document why this is needed in a big comment.
868+
setCursorPtsInSeconds(INT64_MIN);
869+
}
872870

873871
// TODO-AUDIO Pre-allocate a long-enough tensor instead of creating a vec +
874872
// cat(). This would save a copy. We know the duration of the output and the
@@ -879,8 +877,8 @@ torch::Tensor VideoDecoder::getFramesPlayedInRangeAudio(
879877
auto finished = false;
880878
while (!finished) {
881879
try {
882-
AVFrameStream avFrameStream = decodeAVFrame([this](AVFrame* avFrame) {
883-
return cursor_ < avFrame->pts + getDuration(avFrame);
880+
AVFrameStream avFrameStream = decodeAVFrame([startPts](AVFrame* avFrame) {
881+
return startPts < avFrame->pts + getDuration(avFrame);
884882
});
885883
auto frameOutput = convertAVFrameToFrameOutput(avFrameStream);
886884
tensors.push_back(frameOutput.data);
@@ -938,7 +936,9 @@ I P P P I P P P I P P I P P I P
938936
bool VideoDecoder::canWeAvoidSeeking() const {
939937
const StreamInfo& streamInfo = streamInfos_.at(activeStreamIndex_);
940938
if (streamInfo.avMediaType == AVMEDIA_TYPE_AUDIO) {
941-
return true;
939+
// For audio, we only need to seek if a backwards seek was requested within
940+
// getFramesPlayedInRangeAudio(), when setCursorPtsInSeconds() was called.
941+
return !cursorWasJustSet_;
942942
}
943943
int64_t lastDecodedAvFramePts =
944944
streamInfos_.at(activeStreamIndex_).lastDecodedAvFramePts;

test/decoders/test_ops.py

Lines changed: 32 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -741,19 +741,18 @@ def test_decode_start_equal_stop(self, asset):
741741

742742
@pytest.mark.parametrize("asset", (NASA_AUDIO, NASA_AUDIO_MP3))
743743
def test_multiple_calls(self, asset):
744-
# Ensure that multiple calls are OK as long as we're decoding
745-
# "sequentially", i.e. we don't require a backwards seek.
746-
# And ensure a proper error is raised in such case.
747-
# TODO-AUDIO We shouldn't error, we should just implement the seeking
748-
# back to the beginning of the stream.
744+
# Ensure that multiple calls to get_frames_by_pts_in_range_audio on the
745+
# same decoder are supported and correct, whether it involves forward
746+
# seeks or backwards seeks.
749747

750748
def get_reference_frames(start_seconds, stop_seconds):
751-
# This stateless helper exists for convenience, to avoid
752-
# complicating this test with pts-to-index conversions. Eventually
753-
# we should remove it and just rely on the asset's methods.
754-
# Using this helper is OK for now: we're comparing a decoder which
755-
# seeks multiple times with a decoder which seeks only once (the one
756-
# here, treated as the reference)
749+
# Usually we get the reference frames from the asset's methods, but
750+
# for this specific test, this helper is more convenient, because
751+
# relying on the asset would force us to convert all timestamps into
752+
# indices.
753+
# Ultimately, this test compares a "stateful decoder" which calls
754+
# `get_frames_by_pts_in_range_audio()` multiple times with a
755+
# "stateless decoder" (the one here, treated as the reference)
757756
decoder = create_from_file(str(asset.path), seek_mode="approximate")
758757
add_audio_stream(decoder)
759758

@@ -794,23 +793,30 @@ def get_reference_frames(start_seconds, stop_seconds):
794793
frames, get_reference_frames(start_seconds, stop_seconds)
795794
)
796795

797-
# but starting immediately on the same frame raises
798-
expected_match = "Audio decoder cannot seek backwards"
799-
with pytest.raises(RuntimeError, match=expected_match):
800-
get_frames_by_pts_in_range_audio(
801-
decoder, start_seconds=stop_seconds, stop_seconds=6
802-
)
796+
# starting immediately on the same frame is OK
797+
start_seconds, stop_seconds = stop_seconds, 6
798+
frames = get_frames_by_pts_in_range_audio(
799+
decoder, start_seconds=start_seconds, stop_seconds=stop_seconds
800+
)
801+
torch.testing.assert_close(
802+
frames, get_reference_frames(start_seconds, stop_seconds)
803+
)
803804

804-
with pytest.raises(RuntimeError, match=expected_match):
805-
get_frames_by_pts_in_range_audio(
806-
decoder, start_seconds=stop_seconds + 1e-4, stop_seconds=6
807-
)
805+
get_frames_by_pts_in_range_audio(
806+
decoder, start_seconds=start_seconds + 1e-4, stop_seconds=stop_seconds
807+
)
808+
torch.testing.assert_close(
809+
frames, get_reference_frames(start_seconds, stop_seconds)
810+
)
808811

809-
# and seeking backwards doesn't work either
810-
with pytest.raises(RuntimeError, match=expected_match):
811-
frames = get_frames_by_pts_in_range_audio(
812-
decoder, start_seconds=0, stop_seconds=2
813-
)
812+
# seeking backwards
813+
start_seconds, stop_seconds = 0, 2
814+
frames = get_frames_by_pts_in_range_audio(
815+
decoder, start_seconds=start_seconds, stop_seconds=stop_seconds
816+
)
817+
torch.testing.assert_close(
818+
frames, get_reference_frames(start_seconds, stop_seconds)
819+
)
814820

815821

816822
if __name__ == "__main__":

0 commit comments

Comments
 (0)