Skip to content

Commit 23c73ea

Browse files
authored
Use FrameOutput for audio as well (#574)
1 parent b830480 commit 23c73ea

File tree

3 files changed

+10
-8
lines changed

3 files changed

+10
-8
lines changed

src/torchcodec/decoders/_core/VideoDecoder.cpp

Lines changed: 0 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -895,10 +895,6 @@ VideoDecoder::AudioFramesOutput VideoDecoder::getFramesPlayedInRangeAudio(
895895
decodeAVFrame([startPts](const UniqueAVFrame& avFrame) {
896896
return startPts < avFrame->pts + getDuration(avFrame);
897897
});
898-
// TODO: it's not great that we are getting a FrameOutput, which is
899-
// intended for videos. We should consider bypassing
900-
// convertAVFrameToFrameOutput and directly call
901-
// convertAudioAVFrameToFrameOutputOnCPU.
902898
auto frameOutput = convertAVFrameToFrameOutput(avFrame);
903899
firstFramePtsSeconds =
904900
std::min(firstFramePtsSeconds, frameOutput.ptsSeconds);
@@ -1163,7 +1159,6 @@ VideoDecoder::FrameOutput VideoDecoder::convertAVFrameToFrameOutput(
11631159
std::optional<torch::Tensor> preAllocatedOutputTensor) {
11641160
// Convert the frame to tensor.
11651161
FrameOutput frameOutput;
1166-
frameOutput.streamIndex = activeStreamIndex_;
11671162
auto& streamInfo = streamInfos_[activeStreamIndex_];
11681163
frameOutput.ptsSeconds = ptsToSeconds(
11691164
avFrame->pts, formatContext_->streams[activeStreamIndex_]->time_base);

src/torchcodec/decoders/_core/VideoDecoder.h

Lines changed: 6 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -153,9 +153,13 @@ class VideoDecoder {
153153
// They are the equivalent of the user-facing Frame and FrameBatch classes in
154154
// Python. They contain RGB decoded frames along with some associated data
155155
// like PTS and duration.
156+
// FrameOutput is also relevant for audio decoding, typically as the output of
157+
// getNextFrame(), or as a temporary output variable.
156158
struct FrameOutput {
157-
torch::Tensor data; // 3D: of shape CHW or HWC.
158-
int streamIndex;
159+
// data shape is:
160+
// - 3D (C, H, W) or (H, W, C) for videos
161+
// - 2D (numChannels, numSamples) for audio
162+
torch::Tensor data;
159163
double ptsSeconds;
160164
double durationSeconds;
161165
};

test/decoders/test_ops.py

Lines changed: 4 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -650,12 +650,15 @@ def test_next(self, asset):
650650
frame_index = 0
651651
while True:
652652
try:
653-
frame, *_ = get_next_frame(decoder)
653+
frame, pts_seconds, duration_seconds = get_next_frame(decoder)
654654
except IndexError:
655655
break
656656
torch.testing.assert_close(
657657
frame, asset.get_frame_data_by_index(frame_index)
658658
)
659+
frame_info = asset.get_frame_info(frame_index)
660+
assert pts_seconds == frame_info.pts_seconds
661+
assert duration_seconds == frame_info.duration_seconds
659662
frame_index += 1
660663

661664
@pytest.mark.parametrize(

0 commit comments

Comments
 (0)