Skip to content

Commit 64919ba

Browse files
authored
Audio: allow next(), disallow seek() (#563)
1 parent 28e1503 commit 64919ba

File tree

3 files changed

+28
-4
lines changed

3 files changed

+28
-4
lines changed

src/torchcodec/decoders/_core/VideoDecoder.cpp

Lines changed: 10 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -567,13 +567,15 @@ void VideoDecoder::addAudioStream(int streamIndex) {
567567

568568
VideoDecoder::FrameOutput VideoDecoder::getNextFrame() {
569569
auto output = getNextFrameInternal();
570-
output.data = maybePermuteHWC2CHW(output.data);
570+
if (streamInfos_[activeStreamIndex_].avMediaType == AVMEDIA_TYPE_VIDEO) {
571+
output.data = maybePermuteHWC2CHW(output.data);
572+
}
571573
return output;
572574
}
573575

574576
VideoDecoder::FrameOutput VideoDecoder::getNextFrameInternal(
575577
std::optional<torch::Tensor> preAllocatedOutputTensor) {
576-
validateActiveStream(AVMEDIA_TYPE_VIDEO);
578+
validateActiveStream();
577579
AVFrameStream avFrameStream = decodeAVFrame(
578580
[this](AVFrame* avFrame) { return avFrame->pts >= cursor_; });
579581
return convertAVFrameToFrameOutput(avFrameStream, preAllocatedOutputTensor);
@@ -869,7 +871,7 @@ VideoDecoder::AudioFramesOutput VideoDecoder::getFramesPlayedInRangeAudio(
869871
// If we need to seek backwards, then we have to seek back to the beginning
870872
// of the stream.
871873
// TODO-AUDIO: document why this is needed in a big comment.
872-
setCursorPtsInSeconds(INT64_MIN);
874+
setCursorPtsInSecondsInternal(INT64_MIN);
873875
}
874876

875877
// TODO-AUDIO Pre-allocate a long-enough tensor instead of creating a vec +
@@ -915,6 +917,11 @@ VideoDecoder::AudioFramesOutput VideoDecoder::getFramesPlayedInRangeAudio(
915917
// --------------------------------------------------------------------------
916918

917919
void VideoDecoder::setCursorPtsInSeconds(double seconds) {
920+
validateActiveStream(AVMEDIA_TYPE_VIDEO);
921+
setCursorPtsInSecondsInternal(seconds);
922+
}
923+
924+
void VideoDecoder::setCursorPtsInSecondsInternal(double seconds) {
918925
cursorWasJustSet_ = true;
919926
cursor_ =
920927
secondsToClosestPts(seconds, streamInfos_[activeStreamIndex_].timeBase);

src/torchcodec/decoders/_core/VideoDecoder.h

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -370,6 +370,7 @@ class VideoDecoder {
370370
// DECODING APIS AND RELATED UTILS
371371
// --------------------------------------------------------------------------
372372

373+
void setCursorPtsInSecondsInternal(double seconds);
373374
bool canWeAvoidSeeking() const;
374375

375376
void maybeSeekToBeforeDesiredPts();

test/decoders/test_ops.py

Lines changed: 17 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -626,7 +626,7 @@ class TestAudioOps:
626626
partial(get_frames_in_range, start=4, stop=5),
627627
partial(get_frame_at_pts, seconds=2),
628628
partial(get_frames_by_pts, timestamps=[0, 1.5]),
629-
partial(get_next_frame),
629+
partial(seek_to_pts, seconds=5),
630630
),
631631
)
632632
def test_audio_bad_method(self, method):
@@ -642,6 +642,22 @@ def test_audio_bad_seek_mode(self):
642642
):
643643
add_audio_stream(decoder)
644644

645+
@pytest.mark.parametrize("asset", (NASA_AUDIO, NASA_AUDIO_MP3))
646+
def test_next(self, asset):
647+
decoder = create_from_file(str(asset.path), seek_mode="approximate")
648+
add_audio_stream(decoder)
649+
650+
frame_index = 0
651+
while True:
652+
try:
653+
frame, *_ = get_next_frame(decoder)
654+
except IndexError:
655+
break
656+
torch.testing.assert_close(
657+
frame, asset.get_frame_data_by_index(frame_index)
658+
)
659+
frame_index += 1
660+
645661
@pytest.mark.parametrize(
646662
"range",
647663
(

0 commit comments

Comments
 (0)