Skip to content

Use FrameOutput for audio as well #574

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Mar 19, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 0 additions & 5 deletions src/torchcodec/decoders/_core/VideoDecoder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -895,10 +895,6 @@ VideoDecoder::AudioFramesOutput VideoDecoder::getFramesPlayedInRangeAudio(
decodeAVFrame([startPts](const UniqueAVFrame& avFrame) {
return startPts < avFrame->pts + getDuration(avFrame);
});
// TODO: it's not great that we are getting a FrameOutput, which is
// intended for videos. We should consider bypassing
// convertAVFrameToFrameOutput and directly call
// convertAudioAVFrameToFrameOutputOnCPU.
Copy link
Member Author

@NicolasHug NicolasHug Mar 18, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So this PR is about not addressing this TODO. I tried to address it, but then realized that we do call convertAVFrameToFrameOutput() in getNextFrameInternal():

VideoDecoder::FrameOutput VideoDecoder::getNextFrameInternal(
std::optional<torch::Tensor> preAllocatedOutputTensor) {
validateActiveStream();
UniqueAVFrame avFrame = decodeAVFrame(
[this](const UniqueAVFrame& avFrame) { return avFrame->pts >= cursor_; });
return convertAVFrameToFrameOutput(avFrame, preAllocatedOutputTensor);
}

If we were to address the TODO, we would need separated audio/video logics in getNextFrameInternal(): one calling convertAVFrameToFrameOutput() (for video), and the other calling the audio-conversion stuff.

And we can't get around the fact that getNextFrameInternal() returns a FrameOutput, can we?

So, I guess this PR is about accepting FrameOutput as a "universal frame output", for both audio and videos.

auto frameOutput = convertAVFrameToFrameOutput(avFrame);
firstFramePtsSeconds =
std::min(firstFramePtsSeconds, frameOutput.ptsSeconds);
Expand Down Expand Up @@ -1163,7 +1159,6 @@ VideoDecoder::FrameOutput VideoDecoder::convertAVFrameToFrameOutput(
std::optional<torch::Tensor> preAllocatedOutputTensor) {
// Convert the frame to tensor.
FrameOutput frameOutput;
frameOutput.streamIndex = activeStreamIndex_;
auto& streamInfo = streamInfos_[activeStreamIndex_];
frameOutput.ptsSeconds = ptsToSeconds(
avFrame->pts, formatContext_->streams[activeStreamIndex_]->time_base);
Expand Down
8 changes: 6 additions & 2 deletions src/torchcodec/decoders/_core/VideoDecoder.h
Original file line number Diff line number Diff line change
Expand Up @@ -153,9 +153,13 @@ class VideoDecoder {
// They are the equivalent of the user-facing Frame and FrameBatch classes in
// Python. They contain RGB decoded frames along with some associated data
// like PTS and duration.
// FrameOutput is also relevant for audio decoding, typically as the output of
// getNextFrame(), or as a temporary output variable.
struct FrameOutput {
torch::Tensor data; // 3D: of shape CHW or HWC.
int streamIndex;
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Driveby: removing unused streamIndex field

// data shape is:
// - 3D (C, H, W) or (H, W, C) for videos
// - 2D (numChannels, numSamples) for audio
torch::Tensor data;
double ptsSeconds;
double durationSeconds;
};
Expand Down
5 changes: 4 additions & 1 deletion test/decoders/test_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -650,12 +650,15 @@ def test_next(self, asset):
frame_index = 0
while True:
try:
frame, *_ = get_next_frame(decoder)
frame, pts_seconds, duration_seconds = get_next_frame(decoder)
except IndexError:
break
torch.testing.assert_close(
frame, asset.get_frame_data_by_index(frame_index)
)
frame_info = asset.get_frame_info(frame_index)
assert pts_seconds == frame_info.pts_seconds
assert duration_seconds == frame_info.duration_seconds
frame_index += 1

@pytest.mark.parametrize(
Expand Down
Loading