diff --git a/src/torchcodec/decoders/_core/VideoDecoder.cpp b/src/torchcodec/decoders/_core/VideoDecoder.cpp
index aeefa3b5..0e287a5b 100644
--- a/src/torchcodec/decoders/_core/VideoDecoder.cpp
+++ b/src/torchcodec/decoders/_core/VideoDecoder.cpp
@@ -838,7 +838,7 @@ VideoDecoder::FrameBatchOutput VideoDecoder::getFramesPlayedInRange(
   return frameBatchOutput;
 }
 
-torch::Tensor VideoDecoder::getFramesPlayedInRangeAudio(
+VideoDecoder::AudioFramesOutput VideoDecoder::getFramesPlayedInRangeAudio(
     double startSeconds,
     std::optional<double> stopSecondsOptional) {
   validateActiveStream(AVMEDIA_TYPE_AUDIO);
@@ -854,7 +854,7 @@ torch::Tensor VideoDecoder::getFramesPlayedInRangeAudio(
 
   if (startSeconds == stopSeconds) {
     // For consistency with video
-    return torch::empty({0});
+    return AudioFramesOutput{torch::empty({0}), 0.0};
   }
 
   StreamInfo& streamInfo = streamInfos_[activeStreamIndex_];
@@ -871,8 +871,9 @@ torch::Tensor VideoDecoder::getFramesPlayedInRangeAudio(
   // TODO-AUDIO Pre-allocate a long-enough tensor instead of creating a vec +
   // cat(). This would save a copy. We know the duration of the output and the
   // sample rate, so in theory we know the number of output samples.
-  std::vector<torch::Tensor> tensors;
+  std::vector<torch::Tensor> frames;
+  double firstFramePtsSeconds = std::numeric_limits<double>::max();
 
   auto stopPts = secondsToClosestPts(stopSeconds, streamInfo.timeBase);
   auto finished = false;
   while (!finished) {
@@ -880,8 +881,14 @@ torch::Tensor VideoDecoder::getFramesPlayedInRangeAudio(
       AVFrameStream avFrameStream = decodeAVFrame([startPts](AVFrame* avFrame) {
         return startPts < avFrame->pts + getDuration(avFrame);
       });
+      // TODO: it's not great that we are getting a FrameOutput, which is
+      // intended for videos. We should consider bypassing
+      // convertAVFrameToFrameOutput and directly call
+      // convertAudioAVFrameToFrameOutputOnCPU.
       auto frameOutput = convertAVFrameToFrameOutput(avFrameStream);
-      tensors.push_back(frameOutput.data);
+      firstFramePtsSeconds =
+          std::min(firstFramePtsSeconds, frameOutput.ptsSeconds);
+      frames.push_back(frameOutput.data);
     } catch (const EndOfFileException& e) {
       finished = true;
     }
@@ -895,7 +902,8 @@ torch::Tensor VideoDecoder::getFramesPlayedInRangeAudio(
     finished |= (streamInfo.lastDecodedAvFramePts) <= stopPts &&
         (stopPts <= lastDecodedAvFrameEnd);
   }
-  return torch::cat(tensors, 1);
+
+  return AudioFramesOutput{torch::cat(frames, 1), firstFramePtsSeconds};
 }
 
 // --------------------------------------------------------------------------
diff --git a/src/torchcodec/decoders/_core/VideoDecoder.h b/src/torchcodec/decoders/_core/VideoDecoder.h
index 66b9d93c..60c10055 100644
--- a/src/torchcodec/decoders/_core/VideoDecoder.h
+++ b/src/torchcodec/decoders/_core/VideoDecoder.h
@@ -170,6 +170,11 @@ class VideoDecoder {
       const StreamMetadata& streamMetadata);
   };
 
+  struct AudioFramesOutput {
+    torch::Tensor data; // shape is (numChannels, numSamples)
+    double ptsSeconds;
+  };
+
   // Places the cursor at the first frame on or after the position in seconds.
   // Calling getNextFrame() will return the first frame at
   // or after this position.
@@ -222,7 +227,7 @@ class VideoDecoder {
       double stopSeconds);
 
   // TODO-AUDIO: Should accept sampleRate
-  torch::Tensor getFramesPlayedInRangeAudio(
+  AudioFramesOutput getFramesPlayedInRangeAudio(
       double startSeconds,
       std::optional<double> stopSecondsOptional = std::nullopt);
 
diff --git a/src/torchcodec/decoders/_core/VideoDecoderOps.cpp b/src/torchcodec/decoders/_core/VideoDecoderOps.cpp
index 9eb61ac2..7472c3de 100644
--- a/src/torchcodec/decoders/_core/VideoDecoderOps.cpp
+++ b/src/torchcodec/decoders/_core/VideoDecoderOps.cpp
@@ -48,7 +48,7 @@ TORCH_LIBRARY(torchcodec_ns, m) {
   m.def(
      "get_frames_by_pts_in_range(Tensor(a!) decoder, *, float start_seconds, float stop_seconds) -> (Tensor, Tensor, Tensor)");
   m.def(
-      "get_frames_by_pts_in_range_audio(Tensor(a!) decoder, *, float start_seconds, float? stop_seconds) -> Tensor");
+      "get_frames_by_pts_in_range_audio(Tensor(a!) decoder, *, float start_seconds, float? stop_seconds) -> (Tensor, Tensor)");
   m.def(
       "get_frames_by_pts(Tensor(a!) decoder, *, float[] timestamps) -> (Tensor, Tensor, Tensor)");
   m.def("_get_key_frame_indices(Tensor(a!) decoder) -> Tensor");
@@ -94,6 +94,13 @@ OpsFrameBatchOutput makeOpsFrameBatchOutput(
   return std::make_tuple(batch.data, batch.ptsSeconds, batch.durationSeconds);
 }
 
+OpsAudioFramesOutput makeOpsAudioFramesOutput(
+    VideoDecoder::AudioFramesOutput& audioFrames) {
+  return std::make_tuple(
+      audioFrames.data,
+      torch::tensor(audioFrames.ptsSeconds, torch::dtype(torch::kFloat64)));
+}
+
 VideoDecoder::SeekMode seekModeFromString(std::string_view seekMode) {
   if (seekMode == "exact") {
     return VideoDecoder::SeekMode::exact;
@@ -290,12 +297,14 @@ OpsFrameBatchOutput get_frames_by_pts_in_range(
   return makeOpsFrameBatchOutput(result);
 }
 
-torch::Tensor get_frames_by_pts_in_range_audio(
+OpsAudioFramesOutput get_frames_by_pts_in_range_audio(
     at::Tensor& decoder,
     double start_seconds,
     std::optional<double> stop_seconds) {
   auto videoDecoder = unwrapTensorToGetDecoder(decoder);
-  return videoDecoder->getFramesPlayedInRangeAudio(start_seconds, stop_seconds);
+  auto result =
+      videoDecoder->getFramesPlayedInRangeAudio(start_seconds, stop_seconds);
+  return makeOpsAudioFramesOutput(result);
 }
 
 std::string quoteValue(const std::string& value) {
diff --git a/src/torchcodec/decoders/_core/VideoDecoderOps.h b/src/torchcodec/decoders/_core/VideoDecoderOps.h
index c8d32407..a77dec66 100644
--- a/src/torchcodec/decoders/_core/VideoDecoderOps.h
+++ b/src/torchcodec/decoders/_core/VideoDecoderOps.h
@@ -74,6 +74,12 @@ using OpsFrameOutput = std::tuple<torch::Tensor, torch::Tensor, torch::Tensor>;
 // single float.
 using OpsFrameBatchOutput = std::tuple<torch::Tensor, torch::Tensor, torch::Tensor>;
 
+// The elements of this tuple are all tensors that represent the concatenation
+// of multiple audio frames:
+// 1. The frames data (concatenated)
+// 2. A single float value for the pts of the first frame, in seconds.
+using OpsAudioFramesOutput = std::tuple<torch::Tensor, torch::Tensor>;
+
 // Return the frame that is visible at a given timestamp in seconds. Each frame
 // in FFMPEG has a presentation timestamp and a duration. The frame visible at a
 // given timestamp T has T >= PTS and T < PTS + Duration.
@@ -112,7 +118,7 @@ OpsFrameBatchOutput get_frames_by_pts_in_range(
     double start_seconds,
     double stop_seconds);
 
-torch::Tensor get_frames_by_pts_in_range_audio(
+OpsAudioFramesOutput get_frames_by_pts_in_range_audio(
     at::Tensor& decoder,
     double start_seconds,
     std::optional<double> stop_seconds = std::nullopt);
 
diff --git a/src/torchcodec/decoders/_core/ops.py b/src/torchcodec/decoders/_core/ops.py
index 74796a17..e8efa45f 100644
--- a/src/torchcodec/decoders/_core/ops.py
+++ b/src/torchcodec/decoders/_core/ops.py
@@ -271,9 +271,9 @@ def get_frames_by_pts_in_range_audio_abstract(
     *,
     start_seconds: float,
     stop_seconds: Optional[float] = None,
-) -> torch.Tensor:
+) -> Tuple[torch.Tensor, torch.Tensor]:
     image_size = [get_ctx().new_dynamic_size() for _ in range(4)]
-    return torch.empty(image_size)
+    return (torch.empty(image_size), torch.empty([], dtype=torch.float))
 
 
 @register_fake("torchcodec_ns::_get_key_frame_indices")
diff --git a/test/decoders/test_ops.py b/test/decoders/test_ops.py
index ce74243f..b510c8ef 100644
--- a/test/decoders/test_ops.py
+++ b/test/decoders/test_ops.py
@@ -691,19 +691,23 @@ def test_get_frames_by_pts_in_range_audio(self, range, asset):
         decoder = create_from_file(str(asset.path), seek_mode="approximate")
         add_audio_stream(decoder)
 
-        frames = get_frames_by_pts_in_range_audio(
+        frames, pts_seconds = get_frames_by_pts_in_range_audio(
             decoder, start_seconds=start_seconds, stop_seconds=stop_seconds
         )
 
         torch.testing.assert_close(frames, reference_frames)
 
+        if range == "at_frames_boundaries":
+            assert pts_seconds == start_seconds
+        elif range == "not_at_frames_boundaries":
+            assert pts_seconds == start_frame_info.pts_seconds
+
     @pytest.mark.parametrize("asset", (NASA_AUDIO, NASA_AUDIO_MP3))
     def test_decode_epsilon_range(self, asset):
         decoder = create_from_file(str(asset.path), seek_mode="approximate")
         add_audio_stream(decoder)
 
         start_seconds = 5
-        frames = get_frames_by_pts_in_range_audio(
+        frames, *_ = get_frames_by_pts_in_range_audio(
             decoder, start_seconds=start_seconds, stop_seconds=start_seconds + 1e-5
         )
         torch.testing.assert_close(
@@ -720,7 +724,7 @@ def test_decode_just_one_frame_at_boundaries(self, asset):
         start_seconds = asset.get_frame_info(idx=10).pts_seconds
         stop_seconds = asset.get_frame_info(idx=11).pts_seconds
 
-        frames = get_frames_by_pts_in_range_audio(
+        frames, pts_seconds = get_frames_by_pts_in_range_audio(
             decoder, start_seconds=start_seconds, stop_seconds=stop_seconds
         )
         torch.testing.assert_close(
@@ -729,15 +733,17 @@ def test_decode_just_one_frame_at_boundaries(self, asset):
                 asset.get_frame_index(pts_seconds=start_seconds)
             ),
         )
+        assert pts_seconds == start_seconds
 
     @pytest.mark.parametrize("asset", (NASA_AUDIO, NASA_AUDIO_MP3))
     def test_decode_start_equal_stop(self, asset):
         decoder = create_from_file(str(asset.path), seek_mode="approximate")
         add_audio_stream(decoder)
-        frames = get_frames_by_pts_in_range_audio(
+        frames, pts_seconds = get_frames_by_pts_in_range_audio(
             decoder, start_seconds=1, stop_seconds=1
         )
         assert frames.shape == (0,)
+        assert pts_seconds == 0
 
     @pytest.mark.parametrize("asset", (NASA_AUDIO, NASA_AUDIO_MP3))
     def test_multiple_calls(self, asset):