Move stream options and frame output structs to dedicated headers #620

Merged · 6 commits · Apr 7, 2025 · Changes shown from all commits

src/torchcodec/_core/CudaDevice.cpp (2 additions, 2 deletions)

```diff
@@ -190,9 +190,9 @@ void CudaDevice::initializeContext(AVCodecContext* codecContext) {
 }
 
 void CudaDevice::convertAVFrameToFrameOutput(
-    const SingleStreamDecoder::VideoStreamOptions& videoStreamOptions,
+    const VideoStreamOptions& videoStreamOptions,
     UniqueAVFrame& avFrame,
-    SingleStreamDecoder::FrameOutput& frameOutput,
+    FrameOutput& frameOutput,
     std::optional<torch::Tensor> preAllocatedOutputTensor) {
   TORCH_CHECK(
       avFrame->format == AV_PIX_FMT_CUDA,
```

src/torchcodec/_core/CudaDevice.h (2 additions, 2 deletions)

```diff
@@ -21,9 +21,9 @@ class CudaDevice : public DeviceInterface {
   void initializeContext(AVCodecContext* codecContext) override;
 
   void convertAVFrameToFrameOutput(
-      const SingleStreamDecoder::VideoStreamOptions& videoStreamOptions,
+      const VideoStreamOptions& videoStreamOptions,
       UniqueAVFrame& avFrame,
-      SingleStreamDecoder::FrameOutput& frameOutput,
+      FrameOutput& frameOutput,
       std::optional<torch::Tensor> preAllocatedOutputTensor =
          std::nullopt) override;
```

src/torchcodec/_core/DeviceInterface.h (4 additions, 3 deletions)

```diff
@@ -12,7 +12,8 @@
 #include <stdexcept>
 #include <string>
 #include "FFMPEGCommon.h"
-#include "src/torchcodec/_core/SingleStreamDecoder.h"
+#include "src/torchcodec/_core/Frame.h"
+#include "src/torchcodec/_core/StreamOptions.h"
 
 namespace facebook::torchcodec {
 
@@ -41,9 +42,9 @@ class DeviceInterface {
  virtual void initializeContext(AVCodecContext* codecContext) = 0;
 
  virtual void convertAVFrameToFrameOutput(
-      const SingleStreamDecoder::VideoStreamOptions& videoStreamOptions,
+      const VideoStreamOptions& videoStreamOptions,
       UniqueAVFrame& avFrame,
-      SingleStreamDecoder::FrameOutput& frameOutput,
+      FrameOutput& frameOutput,
       std::optional<torch::Tensor> preAllocatedOutputTensor = std::nullopt) = 0;
 
 protected:
```

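With the structs at namespace scope, a device backend can now be written against DeviceInterface.h alone instead of pulling in all of SingleStreamDecoder.h. The following is a minimal sketch of a hypothetical backend, only to illustrate the new unqualified signatures; the class name and the dummy tensor allocation are assumptions, and members not shown in this diff (such as findCodec or any construction/registration machinery) are omitted:

```cpp
#include "src/torchcodec/_core/DeviceInterface.h"

namespace facebook::torchcodec {

// Hypothetical backend, for illustration only; not part of this PR.
class FakeDevice : public DeviceInterface {
 public:
  void initializeContext(AVCodecContext* codecContext) override {
    // A real backend would attach a hardware device context here;
    // this sketch leaves the codec context untouched.
    (void)codecContext;
  }

  void convertAVFrameToFrameOutput(
      const VideoStreamOptions& videoStreamOptions, // was SingleStreamDecoder::VideoStreamOptions
      UniqueAVFrame& avFrame,
      FrameOutput& frameOutput, // was SingleStreamDecoder::FrameOutput
      std::optional<torch::Tensor> preAllocatedOutputTensor =
          std::nullopt) override {
    (void)videoStreamOptions;
    // Dummy conversion: reuse the pre-allocated tensor if given, otherwise
    // allocate an HWC uint8 tensor matching the frame's dimensions.
    frameOutput.data = preAllocatedOutputTensor.value_or(
        torch::empty({avFrame->height, avFrame->width, 3}, torch::kUInt8));
  }
};

} // namespace facebook::torchcodec
```
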
src/torchcodec/_core/Frame.h (new file, 47 additions)

```cpp
// Copyright (c) Meta Platforms, Inc. and affiliates.
// All rights reserved.
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

#pragma once

#include <torch/types.h>
#include "src/torchcodec/_core/Metadata.h"
#include "src/torchcodec/_core/StreamOptions.h"

namespace facebook::torchcodec {

// All public video decoding entry points return either a FrameOutput or a
// FrameBatchOutput.
// They are the equivalent of the user-facing Frame and FrameBatch classes in
// Python. They contain RGB decoded frames along with some associated data
// like PTS and duration.
// FrameOutput is also relevant for audio decoding, typically as the output of
// getNextFrame(), or as a temporary output variable.
struct FrameOutput {
  // data shape is:
  // - 3D (C, H, W) or (H, W, C) for videos
  // - 2D (numChannels, numSamples) for audio
  torch::Tensor data;
  double ptsSeconds;
  double durationSeconds;
};

struct FrameBatchOutput {
  torch::Tensor data; // 4D: of shape NCHW or NHWC.
  torch::Tensor ptsSeconds; // 1D of shape (N,)
  torch::Tensor durationSeconds; // 1D of shape (N,)

  explicit FrameBatchOutput(
      int64_t numFrames,
      const VideoStreamOptions& videoStreamOptions,
      const StreamMetadata& streamMetadata);
};

struct AudioFramesOutput {
  torch::Tensor data; // shape is (numChannels, numSamples)
  double ptsSeconds;
};

} // namespace facebook::torchcodec
```

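For orientation, a small usage sketch of the struct callers now receive; it takes an already-configured decoder by reference, since decoder construction is not part of this diff:

```cpp
#include <iostream>
#include "src/torchcodec/_core/Frame.h"
#include "src/torchcodec/_core/SingleStreamDecoder.h"

using namespace facebook::torchcodec;

void printNextFrame(SingleStreamDecoder& decoder) {
  // Before this PR, the return type was spelled
  // SingleStreamDecoder::FrameOutput; it is now simply FrameOutput.
  FrameOutput frame = decoder.getNextFrame();
  std::cout << "shape=" << frame.data.sizes() // (C, H, W) for video
            << " pts=" << frame.ptsSeconds << "s"
            << " duration=" << frame.durationSeconds << "s\n";
}
```
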
src/torchcodec/_core/Metadata.h (new file, 70 additions)

```cpp
// Copyright (c) Meta Platforms, Inc. and affiliates.
// All rights reserved.
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

#pragma once

#include <optional>
#include <string>
#include <vector>

extern "C" {
#include <libavcodec/avcodec.h>
#include <libavutil/avutil.h>
}

namespace facebook::torchcodec {

struct StreamMetadata {
  // Common (video and audio) fields derived from the AVStream.
  int streamIndex;
  // See this link for what various values are available:
  // https://ffmpeg.org/doxygen/trunk/group__lavu__misc.html#ga9a84bba4713dfced21a1a56163be1f48
  AVMediaType mediaType;
  std::optional<AVCodecID> codecId;
  std::optional<std::string> codecName;
  std::optional<double> durationSeconds;
  std::optional<double> beginStreamFromHeader;
  std::optional<int64_t> numFrames;
  std::optional<int64_t> numKeyFrames;
  std::optional<double> averageFps;
  std::optional<double> bitRate;

  // More accurate duration, obtained by scanning the file.
  // These presentation timestamps are in time base.
  std::optional<int64_t> minPtsFromScan;
  std::optional<int64_t> maxPtsFromScan;
  // These presentation timestamps are in seconds.
  std::optional<double> minPtsSecondsFromScan;
  std::optional<double> maxPtsSecondsFromScan;
  // This can be useful for index-based seeking.
  std::optional<int64_t> numFramesFromScan;

  // Video-only fields derived from the AVCodecContext.
  std::optional<int64_t> width;
  std::optional<int64_t> height;

  // Audio-only fields
  std::optional<int64_t> sampleRate;
  std::optional<int64_t> numChannels;
  std::optional<std::string> sampleFormat;
};

struct ContainerMetadata {
  std::vector<StreamMetadata> allStreamMetadata;
  int numAudioStreams = 0;
  int numVideoStreams = 0;
  // Note that this is the container-level duration, which is usually the max
  // of all stream durations available in the container.
  std::optional<double> durationSeconds;
  // Total BitRate level information at the container level in bit/s
  std::optional<double> bitRate;
  // If set, this is the index to the default audio stream.
  std::optional<int> bestAudioStreamIndex;
  // If set, this is the index to the default video stream.
  std::optional<int> bestVideoStreamIndex;
};

} // namespace facebook::torchcodec
```

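A short sketch of how a caller might walk this metadata, with the optional fields handled explicitly; getContainerMetadata() appears in the SingleStreamDecoder.cpp diff below, and everything else here is standard C++:

```cpp
#include <iostream>
#include "src/torchcodec/_core/Metadata.h"

using namespace facebook::torchcodec;

void summarize(const ContainerMetadata& container) {
  std::cout << container.numVideoStreams << " video / "
            << container.numAudioStreams << " audio streams\n";
  for (const StreamMetadata& stream : container.allStreamMetadata) {
    std::cout << "stream " << stream.streamIndex << ": "
              << stream.codecName.value_or("<unknown codec>");
    // Most fields are std::optional because they may be absent from the
    // header; prefer the scan-derived frame count when it exists.
    if (stream.numFramesFromScan.has_value()) {
      std::cout << ", " << *stream.numFramesFromScan << " frames (scanned)";
    } else if (stream.numFrames.has_value()) {
      std::cout << ", " << *stream.numFrames << " frames (from header)";
    }
    std::cout << "\n";
  }
}
```
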
src/torchcodec/_core/SingleStreamDecoder.cpp (26 additions, 34 deletions)

```diff
@@ -13,7 +13,6 @@
 #include <sstream>
 #include <stdexcept>
 #include <string_view>
-#include "src/torchcodec/_core/DeviceInterface.h"
 #include "torch/types.h"
 
 extern "C" {
@@ -350,8 +349,7 @@ void SingleStreamDecoder::scanFileAndUpdateMetadataAndIndex() {
   scannedAllStreams_ = true;
 }
 
-SingleStreamDecoder::ContainerMetadata
-SingleStreamDecoder::getContainerMetadata() const {
+ContainerMetadata SingleStreamDecoder::getContainerMetadata() const {
   return containerMetadata_;
 }
 
@@ -406,7 +404,7 @@ void SingleStreamDecoder::addStream(
   streamInfo.stream = formatContext_->streams[activeStreamIndex_];
   streamInfo.avMediaType = mediaType;
 
-  deviceInterface = createDeviceInterface(device);
+  deviceInterface_ = createDeviceInterface(device);
 
   // This should never happen, checking just to be safe.
   TORCH_CHECK(
@@ -418,9 +416,9 @@
   // TODO_CODE_QUALITY it's pretty meh to have a video-specific logic within
   // addStream() which is supposed to be generic
   if (mediaType == AVMEDIA_TYPE_VIDEO) {
-    if (deviceInterface) {
+    if (deviceInterface_) {
       avCodec = makeAVCodecOnlyUseForCallingAVFindBestStream(
-          deviceInterface->findCodec(streamInfo.stream->codecpar->codec_id)
+          deviceInterface_->findCodec(streamInfo.stream->codecpar->codec_id)
              .value_or(avCodec));
    }
  }
@@ -438,8 +436,8 @@
 
  // TODO_CODE_QUALITY same as above.
  if (mediaType == AVMEDIA_TYPE_VIDEO) {
-    if (deviceInterface) {
-      deviceInterface->initializeContext(codecContext);
+    if (deviceInterface_) {
+      deviceInterface_->initializeContext(codecContext);
    }
  }
 
@@ -501,9 +499,8 @@ void SingleStreamDecoder::addVideoStream(
   // swscale requires widths to be multiples of 32:
   // https://stackoverflow.com/questions/74351955/turn-off-sw-scale-conversion-to-planar-yuv-32-byte-alignment-requirements
   // so we fall back to filtergraph if the width is not a multiple of 32.
-  auto defaultLibrary = (width % 32 == 0)
-      ? SingleStreamDecoder::ColorConversionLibrary::SWSCALE
-      : SingleStreamDecoder::ColorConversionLibrary::FILTERGRAPH;
+  auto defaultLibrary = (width % 32 == 0) ? ColorConversionLibrary::SWSCALE
+                                          : ColorConversionLibrary::FILTERGRAPH;
 
   streamInfo.colorConversionLibrary =
       videoStreamOptions.colorConversionLibrary.value_or(defaultLibrary);
```

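The default above can be illustrated in isolation (a standalone sketch; the widths are arbitrary examples, not taken from the PR):

```cpp
// Sketch of the default color-conversion choice made in addVideoStream().
auto defaultLibraryFor = [](int width) {
  return (width % 32 == 0) ? ColorConversionLibrary::SWSCALE
                           : ColorConversionLibrary::FILTERGRAPH;
};
// defaultLibraryFor(1920) == SWSCALE      (1920 = 60 * 32)
// defaultLibraryFor(1918) == FILTERGRAPH  (1918 % 32 == 30)
```
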
src/torchcodec/_core/SingleStreamDecoder.cpp (continued)

```diff
@@ -539,30 +536,29 @@ void SingleStreamDecoder::addAudioStream(
 // HIGH-LEVEL DECODING ENTRY-POINTS
 // --------------------------------------------------------------------------
 
-SingleStreamDecoder::FrameOutput SingleStreamDecoder::getNextFrame() {
+FrameOutput SingleStreamDecoder::getNextFrame() {
   auto output = getNextFrameInternal();
   if (streamInfos_[activeStreamIndex_].avMediaType == AVMEDIA_TYPE_VIDEO) {
     output.data = maybePermuteHWC2CHW(output.data);
   }
   return output;
 }
 
-SingleStreamDecoder::FrameOutput SingleStreamDecoder::getNextFrameInternal(
+FrameOutput SingleStreamDecoder::getNextFrameInternal(
     std::optional<torch::Tensor> preAllocatedOutputTensor) {
   validateActiveStream();
   UniqueAVFrame avFrame = decodeAVFrame(
       [this](const UniqueAVFrame& avFrame) { return avFrame->pts >= cursor_; });
   return convertAVFrameToFrameOutput(avFrame, preAllocatedOutputTensor);
 }
 
-SingleStreamDecoder::FrameOutput SingleStreamDecoder::getFrameAtIndex(
-    int64_t frameIndex) {
+FrameOutput SingleStreamDecoder::getFrameAtIndex(int64_t frameIndex) {
   auto frameOutput = getFrameAtIndexInternal(frameIndex);
   frameOutput.data = maybePermuteHWC2CHW(frameOutput.data);
   return frameOutput;
 }
 
-SingleStreamDecoder::FrameOutput SingleStreamDecoder::getFrameAtIndexInternal(
+FrameOutput SingleStreamDecoder::getFrameAtIndexInternal(
     int64_t frameIndex,
     std::optional<torch::Tensor> preAllocatedOutputTensor) {
   validateActiveStream(AVMEDIA_TYPE_VIDEO);
@@ -577,7 +573,7 @@ SingleStreamDecoder::FrameOutput SingleStreamDecoder::getFrameAtIndexInternal(
   return getNextFrameInternal(preAllocatedOutputTensor);
 }
 
-SingleStreamDecoder::FrameBatchOutput SingleStreamDecoder::getFramesAtIndices(
+FrameBatchOutput SingleStreamDecoder::getFramesAtIndices(
     const std::vector<int64_t>& frameIndices) {
   validateActiveStream(AVMEDIA_TYPE_VIDEO);
 
@@ -636,7 +632,7 @@ SingleStreamDecoder::FrameBatchOutput SingleStreamDecoder::getFramesAtIndices(
   return frameBatchOutput;
 }
 
-SingleStreamDecoder::FrameBatchOutput SingleStreamDecoder::getFramesInRange(
+FrameBatchOutput SingleStreamDecoder::getFramesInRange(
     int64_t start,
     int64_t stop,
     int64_t step) {
@@ -670,8 +666,7 @@ SingleStreamDecoder::FrameBatchOutput SingleStreamDecoder::getFramesInRange(
   return frameBatchOutput;
 }
 
-SingleStreamDecoder::FrameOutput SingleStreamDecoder::getFramePlayedAt(
-    double seconds) {
+FrameOutput SingleStreamDecoder::getFramePlayedAt(double seconds) {
   validateActiveStream(AVMEDIA_TYPE_VIDEO);
   StreamInfo& streamInfo = streamInfos_[activeStreamIndex_];
   double frameStartTime =
@@ -711,7 +706,7 @@ SingleStreamDecoder::FrameOutput SingleStreamDecoder::getFramePlayedAt(
   return frameOutput;
 }
 
-SingleStreamDecoder::FrameBatchOutput SingleStreamDecoder::getFramesPlayedAt(
+FrameBatchOutput SingleStreamDecoder::getFramesPlayedAt(
     const std::vector<double>& timestamps) {
   validateActiveStream(AVMEDIA_TYPE_VIDEO);
 
@@ -741,8 +736,7 @@ SingleStreamDecoder::FrameBatchOutput SingleStreamDecoder::getFramesPlayedAt(
   return getFramesAtIndices(frameIndices);
 }
 
-SingleStreamDecoder::FrameBatchOutput
-SingleStreamDecoder::getFramesPlayedInRange(
+FrameBatchOutput SingleStreamDecoder::getFramesPlayedInRange(
     double startSeconds,
     double stopSeconds) {
   validateActiveStream(AVMEDIA_TYPE_VIDEO);
@@ -875,8 +869,7 @@ SingleStreamDecoder::getFramesPlayedInRange(
 // [2] If you're brave and curious, you can read the long "Seek offset for
 // audio" note in https://github.com/pytorch/torchcodec/pull/507/files, which
 // sums up past (and failed) attemps at working around this issue.
-SingleStreamDecoder::AudioFramesOutput
-SingleStreamDecoder::getFramesPlayedInRangeAudio(
+AudioFramesOutput SingleStreamDecoder::getFramesPlayedInRangeAudio(
     double startSeconds,
     std::optional<double> stopSecondsOptional) {
   validateActiveStream(AVMEDIA_TYPE_AUDIO);
@@ -1196,8 +1189,7 @@ UniqueAVFrame SingleStreamDecoder::decodeAVFrame(
 // AVFRAME <-> FRAME OUTPUT CONVERSION
 // --------------------------------------------------------------------------
 
-SingleStreamDecoder::FrameOutput
-SingleStreamDecoder::convertAVFrameToFrameOutput(
+FrameOutput SingleStreamDecoder::convertAVFrameToFrameOutput(
     UniqueAVFrame& avFrame,
     std::optional<torch::Tensor> preAllocatedOutputTensor) {
   // Convert the frame to tensor.
@@ -1210,11 +1202,11 @@ SingleStreamDecoder::convertAVFrameToFrameOutput(
       formatContext_->streams[activeStreamIndex_]->time_base);
   if (streamInfo.avMediaType == AVMEDIA_TYPE_AUDIO) {
     convertAudioAVFrameToFrameOutputOnCPU(avFrame, frameOutput);
-  } else if (!deviceInterface) {
+  } else if (!deviceInterface_) {
     convertAVFrameToFrameOutputOnCPU(
         avFrame, frameOutput, preAllocatedOutputTensor);
-  } else if (deviceInterface) {
-    deviceInterface->convertAVFrameToFrameOutput(
+  } else if (deviceInterface_) {
+    deviceInterface_->convertAVFrameToFrameOutput(
         streamInfo.videoStreamOptions,
         avFrame,
         frameOutput,
@@ -1547,7 +1539,7 @@ std::optional<torch::Tensor> SingleStreamDecoder::maybeFlushSwrBuffers() {
 // OUTPUT ALLOCATION AND SHAPE CONVERSION
 // --------------------------------------------------------------------------
 
-SingleStreamDecoder::FrameBatchOutput::FrameBatchOutput(
+FrameBatchOutput::FrameBatchOutput(
     int64_t numFrames,
     const VideoStreamOptions& videoStreamOptions,
     const StreamMetadata& streamMetadata)
@@ -2047,15 +2039,15 @@ FrameDims getHeightAndWidthFromResizedAVFrame(const AVFrame& resizedAVFrame) {
 }
 
 FrameDims getHeightAndWidthFromOptionsOrMetadata(
-    const SingleStreamDecoder::VideoStreamOptions& videoStreamOptions,
-    const SingleStreamDecoder::StreamMetadata& streamMetadata) {
+    const VideoStreamOptions& videoStreamOptions,
+    const StreamMetadata& streamMetadata) {
   return FrameDims(
       videoStreamOptions.height.value_or(*streamMetadata.height),
       videoStreamOptions.width.value_or(*streamMetadata.width));
 }
 
 FrameDims getHeightAndWidthFromOptionsOrAVFrame(
-    const SingleStreamDecoder::VideoStreamOptions& videoStreamOptions,
+    const VideoStreamOptions& videoStreamOptions,
     const UniqueAVFrame& avFrame) {
   return FrameDims(
       videoStreamOptions.height.value_or(avFrame->height),
```