diff --git a/src/torchcodec/decoders/_core/VideoDecoder.cpp b/src/torchcodec/decoders/_core/VideoDecoder.cpp index e9036416..25cd81c0 100644 --- a/src/torchcodec/decoders/_core/VideoDecoder.cpp +++ b/src/torchcodec/decoders/_core/VideoDecoder.cpp @@ -895,10 +895,6 @@ VideoDecoder::AudioFramesOutput VideoDecoder::getFramesPlayedInRangeAudio( decodeAVFrame([startPts](const UniqueAVFrame& avFrame) { return startPts < avFrame->pts + getDuration(avFrame); }); - // TODO: it's not great that we are getting a FrameOutput, which is - // intended for videos. We should consider bypassing - // convertAVFrameToFrameOutput and directly call - // convertAudioAVFrameToFrameOutputOnCPU. auto frameOutput = convertAVFrameToFrameOutput(avFrame); firstFramePtsSeconds = std::min(firstFramePtsSeconds, frameOutput.ptsSeconds); @@ -1163,7 +1159,6 @@ VideoDecoder::FrameOutput VideoDecoder::convertAVFrameToFrameOutput( std::optional<torch::Tensor> preAllocatedOutputTensor) { // Convert the frame to tensor. FrameOutput frameOutput; - frameOutput.streamIndex = activeStreamIndex_; auto& streamInfo = streamInfos_[activeStreamIndex_]; frameOutput.ptsSeconds = ptsToSeconds( avFrame->pts, formatContext_->streams[activeStreamIndex_]->time_base); diff --git a/src/torchcodec/decoders/_core/VideoDecoder.h b/src/torchcodec/decoders/_core/VideoDecoder.h index 68311961..ea0c6179 100644 --- a/src/torchcodec/decoders/_core/VideoDecoder.h +++ b/src/torchcodec/decoders/_core/VideoDecoder.h @@ -153,9 +153,13 @@ class VideoDecoder { // They are the equivalent of the user-facing Frame and FrameBatch classes in // Python. They contain RGB decoded frames along with some associated data // like PTS and duration. + // FrameOutput is also relevant for audio decoding, typically as the output of + // getNextFrame(), or as a temporary output variable. struct FrameOutput { - torch::Tensor data; // 3D: of shape CHW or HWC. 
- int streamIndex; + // data shape is: + // - 3D (C, H, W) or (H, W, C) for videos + // - 2D (numChannels, numSamples) for audio + torch::Tensor data; double ptsSeconds; double durationSeconds; }; diff --git a/test/decoders/test_ops.py b/test/decoders/test_ops.py index a8d23eb0..9088539f 100644 --- a/test/decoders/test_ops.py +++ b/test/decoders/test_ops.py @@ -650,12 +650,15 @@ def test_next(self, asset): frame_index = 0 while True: try: - frame, *_ = get_next_frame(decoder) + frame, pts_seconds, duration_seconds = get_next_frame(decoder) except IndexError: break torch.testing.assert_close( frame, asset.get_frame_data_by_index(frame_index) ) + frame_info = asset.get_frame_info(frame_index) + assert pts_seconds == frame_info.pts_seconds + assert duration_seconds == frame_info.duration_seconds frame_index += 1 @pytest.mark.parametrize(