Skip to content

Allow audio decoder to seek backwards #550

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Mar 12, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 14 additions & 14 deletions src/torchcodec/decoders/_core/VideoDecoder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -850,7 +850,7 @@ torch::Tensor VideoDecoder::getFramesPlayedInRangeAudio(
startSeconds <= stopSeconds,
"Start seconds (" + std::to_string(startSeconds) +
") must be less than or equal to stop seconds (" +
std::to_string(stopSeconds) + ".");
std::to_string(stopSeconds) + ").");

if (startSeconds == stopSeconds) {
// For consistency with video
Expand All @@ -859,16 +859,14 @@ torch::Tensor VideoDecoder::getFramesPlayedInRangeAudio(

StreamInfo& streamInfo = streamInfos_[activeStreamIndex_];

// TODO-AUDIO This essentially enforce that we don't need to seek (backwards).
// We should remove it and seek back to the stream's beginning when needed.
// See test_multiple_calls
TORCH_CHECK(
streamInfo.lastDecodedAvFramePts +
streamInfo.lastDecodedAvFrameDuration <=
secondsToClosestPts(startSeconds, streamInfo.timeBase),
"Audio decoder cannot seek backwards, or start from the last decoded frame.");

setCursorPtsInSeconds(startSeconds);
auto startPts = secondsToClosestPts(startSeconds, streamInfo.timeBase);
if (startPts < streamInfo.lastDecodedAvFramePts +
streamInfo.lastDecodedAvFrameDuration) {
// If we need to seek backwards, then we have to seek back to the beginning
// of the stream.
// TODO-AUDIO: document why this is needed in a big comment.
setCursorPtsInSeconds(INT64_MIN);
}
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note that this is INT64_MIN and not 0, because some packets actually start before 0. In one of our assets the first packet is at -1024.
I noticed that passing an arbitrary low value like -999999 makes FFmpeg unhappy and raise and error, but INT64_MIN seems to be understood and correct (although I haven't found docs on this).


// TODO-AUDIO Pre-allocate a long-enough tensor instead of creating a vec +
// cat(). This would save a copy. We know the duration of the output and the
Expand All @@ -879,8 +877,8 @@ torch::Tensor VideoDecoder::getFramesPlayedInRangeAudio(
auto finished = false;
while (!finished) {
try {
AVFrameStream avFrameStream = decodeAVFrame([this](AVFrame* avFrame) {
return cursor_ < avFrame->pts + getDuration(avFrame);
AVFrameStream avFrameStream = decodeAVFrame([startPts](AVFrame* avFrame) {
return startPts < avFrame->pts + getDuration(avFrame);
});
auto frameOutput = convertAVFrameToFrameOutput(avFrameStream);
tensors.push_back(frameOutput.data);
Expand Down Expand Up @@ -938,7 +936,9 @@ I P P P I P P P I P P I P P I P
bool VideoDecoder::canWeAvoidSeeking() const {
const StreamInfo& streamInfo = streamInfos_.at(activeStreamIndex_);
if (streamInfo.avMediaType == AVMEDIA_TYPE_AUDIO) {
return true;
// For audio, we only need to seek if a backwards seek was requested within
// getFramesPlayedInRangeAudio(), when setCursorPtsInSeconds() was called.
return !cursorWasJustSet_;
}
int64_t lastDecodedAvFramePts =
streamInfos_.at(activeStreamIndex_).lastDecodedAvFramePts;
Expand Down
58 changes: 32 additions & 26 deletions test/decoders/test_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -741,19 +741,18 @@ def test_decode_start_equal_stop(self, asset):

@pytest.mark.parametrize("asset", (NASA_AUDIO, NASA_AUDIO_MP3))
def test_multiple_calls(self, asset):
# Ensure that multiple calls are OK as long as we're decoding
# "sequentially", i.e. we don't require a backwards seek.
# And ensure a proper error is raised in such case.
# TODO-AUDIO We shouldn't error, we should just implement the seeking
# back to the beginning of the stream.
# Ensure that multiple calls to get_frames_by_pts_in_range_audio on the
# same decoder are supported and correct, whether it involves forward
# seeks or backwards seeks.

def get_reference_frames(start_seconds, stop_seconds):
# This stateless helper exists for convenience, to avoid
# complicating this test with pts-to-index conversions. Eventually
# we should remove it and just rely on the asset's methods.
# Using this helper is OK for now: we're comparing a decoder which
# seeks multiple times with a decoder which seeks only once (the one
# here, treated as the reference)
# Usually we get the reference frames from the asset's methods, but
# for this specific test, this helper is more convenient, because
# relying on the asset would force us to convert all timestamps into
# indices.
# Ultimately, this test compares a "stateful decoder" which calls
# `get_frames_by_pts_in_range_audio()`` multiple times with a
# "stateless decoder" (the one here, treated as the reference)
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've convinced myself that we should actually keep this helper instead of doing the conversion. Thoughts?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think that's fine - then what we're testing is that decoder behaves the same when it seeks to a location "fresh" versus having to seek from some given location, including backwards. That seems reasonable.

decoder = create_from_file(str(asset.path), seek_mode="approximate")
add_audio_stream(decoder)

Expand Down Expand Up @@ -794,23 +793,30 @@ def get_reference_frames(start_seconds, stop_seconds):
frames, get_reference_frames(start_seconds, stop_seconds)
)

# but starting immediately on the same frame raises
expected_match = "Audio decoder cannot seek backwards"
with pytest.raises(RuntimeError, match=expected_match):
get_frames_by_pts_in_range_audio(
decoder, start_seconds=stop_seconds, stop_seconds=6
)
# starting immediately on the same frame is OK
start_seconds, stop_seconds = stop_seconds, 6
frames = get_frames_by_pts_in_range_audio(
decoder, start_seconds=start_seconds, stop_seconds=stop_seconds
)
torch.testing.assert_close(
frames, get_reference_frames(start_seconds, stop_seconds)
)

with pytest.raises(RuntimeError, match=expected_match):
get_frames_by_pts_in_range_audio(
decoder, start_seconds=stop_seconds + 1e-4, stop_seconds=6
)
get_frames_by_pts_in_range_audio(
decoder, start_seconds=start_seconds + 1e-4, stop_seconds=stop_seconds
)
torch.testing.assert_close(
frames, get_reference_frames(start_seconds, stop_seconds)
)

# and seeking backwards doesn't work either
with pytest.raises(RuntimeError, match=expected_match):
frames = get_frames_by_pts_in_range_audio(
decoder, start_seconds=0, stop_seconds=2
)
# seeking backwards
start_seconds, stop_seconds = 0, 2
frames = get_frames_by_pts_in_range_audio(
decoder, start_seconds=start_seconds, stop_seconds=stop_seconds
)
torch.testing.assert_close(
frames, get_reference_frames(start_seconds, stop_seconds)
)


if __name__ == "__main__":
Expand Down
Loading