Return pts of first frame in audio API #552

Merged (3 commits) on Mar 12, 2025

18 changes: 13 additions & 5 deletions src/torchcodec/decoders/_core/VideoDecoder.cpp
@@ -838,7 +838,7 @@ VideoDecoder::FrameBatchOutput VideoDecoder::getFramesPlayedInRange(
return frameBatchOutput;
}

torch::Tensor VideoDecoder::getFramesPlayedInRangeAudio(
VideoDecoder::AudioFramesOutput VideoDecoder::getFramesPlayedInRangeAudio(
Member (Author) commented:

I decided to create a new AudioFramesOutput struct instead of relying on the existing FrameOutput, which contains unnecessary fields like streamIndex and durationSeconds. For now, those aren't needed for audio. No super strong opinion though.

Contributor replied:

I think that's the right call, since the tensors are going to be very different.

double startSeconds,
std::optional<double> stopSecondsOptional) {
validateActiveStream(AVMEDIA_TYPE_AUDIO);
@@ -854,7 +854,7 @@ torch::Tensor VideoDecoder::getFramesPlayedInRangeAudio(

if (startSeconds == stopSeconds) {
// For consistency with video
return torch::empty({0});
return AudioFramesOutput{torch::empty({0}), 0.0};
}

StreamInfo& streamInfo = streamInfos_[activeStreamIndex_];
@@ -871,17 +871,24 @@ torch::Tensor VideoDecoder::getFramesPlayedInRangeAudio(
// TODO-AUDIO Pre-allocate a long-enough tensor instead of creating a vec +
// cat(). This would save a copy. We know the duration of the output and the
// sample rate, so in theory we know the number of output samples.
std::vector<torch::Tensor> tensors;
std::vector<torch::Tensor> frames;

double firstFramePtsSeconds = std::numeric_limits<double>::max();
auto stopPts = secondsToClosestPts(stopSeconds, streamInfo.timeBase);
auto finished = false;
while (!finished) {
try {
AVFrameStream avFrameStream = decodeAVFrame([startPts](AVFrame* avFrame) {
return startPts < avFrame->pts + getDuration(avFrame);
});
// TODO: it's not great that we are getting a FrameOutput, which is
// intended for videos. We should consider bypassing
// convertAVFrameToFrameOutput and directly call
// convertAudioAVFrameToFrameOutputOnCPU.
auto frameOutput = convertAVFrameToFrameOutput(avFrameStream);
tensors.push_back(frameOutput.data);
firstFramePtsSeconds =
std::min(firstFramePtsSeconds, frameOutput.ptsSeconds);
frames.push_back(frameOutput.data);
Member (Author) commented:

It's now a bit ugly that we are manipulating both a FrameOutput and an AudioFramesOutput here in this function.

This is mostly because convertAVFrameToFrameOutput returns a FrameOutput, but maybe we could bypass it and directly call convertAudioAVFrameToFrameOutputOnCPU. I'll write a TODO to investigate this.

Contributor replied:

Oh, yeah, that's cumbersome. If we're in an audio-only call, we shouldn't have to deal with video structures.

} catch (const EndOfFileException& e) {
finished = true;
}
@@ -895,7 +902,8 @@ torch::Tensor VideoDecoder::getFramesPlayedInRangeAudio(
finished |= (streamInfo.lastDecodedAvFramePts) <= stopPts &&
(stopPts <= lastDecodedAvFrameEnd);
}
return torch::cat(tensors, 1);

return AudioFramesOutput{torch::cat(frames, 1), firstFramePtsSeconds};
}

// --------------------------------------------------------------------------
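
The TODO in the loop above, and the review thread attached to it, suggest bypassing the video-oriented convertAVFrameToFrameOutput for audio-only calls. A rough sketch of that direction is below; convertAudioAVFrameToData is a hypothetical helper name used only for illustration, and the existing convertAudioAVFrameToFrameOutputOnCPU may well have a different signature.

// Hypothetical audio-only loop body: decode, then convert straight to the
// (data, ptsSeconds) pair that AudioFramesOutput needs, without going through
// the video-oriented FrameOutput. convertAudioAVFrameToData is an
// illustrative name, not an existing function in this repository.
AVFrameStream avFrameStream = decodeAVFrame([startPts](AVFrame* avFrame) {
  return startPts < avFrame->pts + getDuration(avFrame);
});
auto [data, ptsSeconds] = convertAudioAVFrameToData(avFrameStream);
firstFramePtsSeconds = std::min(firstFramePtsSeconds, ptsSeconds);
frames.push_back(data);
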
7 changes: 6 additions & 1 deletion src/torchcodec/decoders/_core/VideoDecoder.h
@@ -170,6 +170,11 @@ class VideoDecoder {
const StreamMetadata& streamMetadata);
};

struct AudioFramesOutput {
torch::Tensor data; // shape is (numChannels, numSamples)
double ptsSeconds;
};

// Places the cursor at the first frame on or after the position in seconds.
// Calling getNextFrame() will return the first frame at
// or after this position.
@@ -222,7 +227,7 @@
double stopSeconds);

// TODO-AUDIO: Should accept sampleRate
torch::Tensor getFramesPlayedInRangeAudio(
AudioFramesOutput getFramesPlayedInRangeAudio(
double startSeconds,
std::optional<double> stopSecondsOptional = std::nullopt);
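
For reference, the review thread on VideoDecoder.cpp contrasted this new AudioFramesOutput with the existing per-frame FrameOutput. The sketch below puts the two side by side; FrameOutput's field list and types are inferred from that discussion and from how the struct is used in this diff, so treat it as an approximation rather than a copy of the header.

// Approximate shape of the existing per-frame FrameOutput (fields inferred,
// not copied from VideoDecoder.h) next to the new batched AudioFramesOutput.
struct FrameOutput {
  torch::Tensor data; // one decoded frame
  int streamIndex; // not meaningful once audio frames are concatenated
  double ptsSeconds;
  double durationSeconds; // not needed for audio, per the discussion
};

struct AudioFramesOutput {
  torch::Tensor data; // (numChannels, numSamples), all frames concatenated
  double ptsSeconds; // pts of the first decoded frame, in seconds
};
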

15 changes: 12 additions & 3 deletions src/torchcodec/decoders/_core/VideoDecoderOps.cpp
@@ -48,7 +48,7 @@ TORCH_LIBRARY(torchcodec_ns, m) {
m.def(
"get_frames_by_pts_in_range(Tensor(a!) decoder, *, float start_seconds, float stop_seconds) -> (Tensor, Tensor, Tensor)");
m.def(
"get_frames_by_pts_in_range_audio(Tensor(a!) decoder, *, float start_seconds, float? stop_seconds) -> Tensor");
"get_frames_by_pts_in_range_audio(Tensor(a!) decoder, *, float start_seconds, float? stop_seconds) -> (Tensor, Tensor)");
m.def(
"get_frames_by_pts(Tensor(a!) decoder, *, float[] timestamps) -> (Tensor, Tensor, Tensor)");
m.def("_get_key_frame_indices(Tensor(a!) decoder) -> Tensor");
@@ -94,6 +94,13 @@ OpsFrameBatchOutput makeOpsFrameBatchOutput(
return std::make_tuple(batch.data, batch.ptsSeconds, batch.durationSeconds);
}

OpsAudioFramesOutput makeOpsAudioFramesOutput(
VideoDecoder::AudioFramesOutput& audioFrames) {
return std::make_tuple(
audioFrames.data,
torch::tensor(audioFrames.ptsSeconds, torch::dtype(torch::kFloat64)));
}

VideoDecoder::SeekMode seekModeFromString(std::string_view seekMode) {
if (seekMode == "exact") {
return VideoDecoder::SeekMode::exact;
@@ -290,12 +297,14 @@ OpsFrameBatchOutput get_frames_by_pts_in_range(
return makeOpsFrameBatchOutput(result);
}

torch::Tensor get_frames_by_pts_in_range_audio(
OpsAudioFramesOutput get_frames_by_pts_in_range_audio(
at::Tensor& decoder,
double start_seconds,
std::optional<double> stop_seconds) {
auto videoDecoder = unwrapTensorToGetDecoder(decoder);
return videoDecoder->getFramesPlayedInRangeAudio(start_seconds, stop_seconds);
auto result =
videoDecoder->getFramesPlayedInRangeAudio(start_seconds, stop_seconds);
return makeOpsAudioFramesOutput(result);
}

std::string quoteValue(const std::string& value) {
8 changes: 7 additions & 1 deletion src/torchcodec/decoders/_core/VideoDecoderOps.h
@@ -74,6 +74,12 @@ using OpsFrameOutput = std::tuple<at::Tensor, at::Tensor, at::Tensor>;
// single float.
using OpsFrameBatchOutput = std::tuple<at::Tensor, at::Tensor, at::Tensor>;

// The elements of this tuple are all tensors that represent the concatenation
// of multiple audio frames:
// 1. The frames data (concatenated)
// 2. A single float value for the pts of the first frame, in seconds.
using OpsAudioFramesOutput = std::tuple<at::Tensor, at::Tensor>;

// Return the frame that is visible at a given timestamp in seconds. Each frame
// in FFMPEG has a presentation timestamp and a duration. The frame visible at a
// given timestamp T has T >= PTS and T < PTS + Duration.
@@ -112,7 +118,7 @@ OpsFrameBatchOutput get_frames_by_pts_in_range(
double start_seconds,
double stop_seconds);

torch::Tensor get_frames_by_pts_in_range_audio(
OpsAudioFramesOutput get_frames_by_pts_in_range_audio(
at::Tensor& decoder,
double start_seconds,
std::optional<double> stop_seconds = std::nullopt);
4 changes: 2 additions & 2 deletions src/torchcodec/decoders/_core/ops.py
@@ -271,9 +271,9 @@ def get_frames_by_pts_in_range_audio_abstract(
*,
start_seconds: float,
stop_seconds: Optional[float] = None,
) -> torch.Tensor:
) -> Tuple[torch.Tensor, torch.Tensor]:
image_size = [get_ctx().new_dynamic_size() for _ in range(4)]
return torch.empty(image_size)
return (torch.empty(image_size), torch.empty([], dtype=torch.float))


@register_fake("torchcodec_ns::_get_key_frame_indices")
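
To tie the C++ and Python sides together, here is a minimal sketch of how a caller might consume the new (frames, pts_seconds) pair, for example to trim the decoded samples back to the requested start. The import path, file name, and sample_rate are assumptions for illustration and are not part of this diff.

# Sketch only: import path, file name, and sample_rate are assumed.
from torchcodec.decoders._core import (
    add_audio_stream,
    create_from_file,
    get_frames_by_pts_in_range_audio,
)

decoder = create_from_file("some_audio.mp4", seek_mode="approximate")
add_audio_stream(decoder)

start_seconds, stop_seconds = 5.0, 6.0
frames, pts_seconds = get_frames_by_pts_in_range_audio(
    decoder, start_seconds=start_seconds, stop_seconds=stop_seconds
)

# pts_seconds is a tensor holding a single float64 value: the pts of the first
# decoded frame, which can be earlier than start_seconds when start_seconds
# falls inside a frame. A caller that knows the sample rate could trim:
sample_rate = 16_000  # assumed; would normally come from stream metadata
offset = round((start_seconds - pts_seconds.item()) * sample_rate)
trimmed = frames[:, offset:]  # samples from start_seconds onward
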
16 changes: 11 additions & 5 deletions test/decoders/test_ops.py
@@ -691,19 +691,23 @@ def test_get_frames_by_pts_in_range_audio(self, range, asset):
decoder = create_from_file(str(asset.path), seek_mode="approximate")
add_audio_stream(decoder)

frames = get_frames_by_pts_in_range_audio(
frames, pts_seconds = get_frames_by_pts_in_range_audio(
decoder, start_seconds=start_seconds, stop_seconds=stop_seconds
)

torch.testing.assert_close(frames, reference_frames)

if range == "at_frames_boundaries":
assert pts_seconds == start_seconds
elif range == "not_at_frames_boundaries":
assert pts_seconds == start_frame_info.pts_seconds

@pytest.mark.parametrize("asset", (NASA_AUDIO, NASA_AUDIO_MP3))
def test_decode_epsilon_range(self, asset):
decoder = create_from_file(str(asset.path), seek_mode="approximate")
add_audio_stream(decoder)

start_seconds = 5
frames = get_frames_by_pts_in_range_audio(
frames, *_ = get_frames_by_pts_in_range_audio(
decoder, start_seconds=start_seconds, stop_seconds=start_seconds + 1e-5
)
torch.testing.assert_close(
@@ -720,7 +724,7 @@ def test_decode_just_one_frame_at_boundaries(self, asset):

start_seconds = asset.get_frame_info(idx=10).pts_seconds
stop_seconds = asset.get_frame_info(idx=11).pts_seconds
frames = get_frames_by_pts_in_range_audio(
frames, pts_seconds = get_frames_by_pts_in_range_audio(
decoder, start_seconds=start_seconds, stop_seconds=stop_seconds
)
torch.testing.assert_close(
Expand All @@ -729,15 +733,17 @@ def test_decode_just_one_frame_at_boundaries(self, asset):
asset.get_frame_index(pts_seconds=start_seconds)
),
)
assert pts_seconds == start_seconds

@pytest.mark.parametrize("asset", (NASA_AUDIO, NASA_AUDIO_MP3))
def test_decode_start_equal_stop(self, asset):
decoder = create_from_file(str(asset.path), seek_mode="approximate")
add_audio_stream(decoder)
frames = get_frames_by_pts_in_range_audio(
frames, pts_seconds = get_frames_by_pts_in_range_audio(
decoder, start_seconds=1, stop_seconds=1
)
assert frames.shape == (0,)
assert pts_seconds == 0

@pytest.mark.parametrize("asset", (NASA_AUDIO, NASA_AUDIO_MP3))
def test_multiple_calls(self, asset):