Skip to content

Encoding: allow user-defined encoded sample rate #700

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 25 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
52d624b
Add num_channels parameter to AudioEncoder
NicolasHug May 21, 2025
aad9c7d
Merge branch 'main' of github.com:pytorch/torchcodec into encoding_nu…
NicolasHug May 22, 2025
2d76a7b
Add validation for num_channels
NicolasHug May 22, 2025
7d643f2
Fix FFmpeg 5.X?
NicolasHug May 22, 2025
5d9eb54
Migrate encoder tests to public Python APIs
NicolasHug May 22, 2025
c40deef
Add output sample rate, WIP
NicolasHug May 22, 2025
96e5e60
Merge branch 'main' of github.com:pytorch/torchcodec into migrate_enc…
NicolasHug May 22, 2025
952af0f
Re-remove
NicolasHug May 22, 2025
88a87c4
Merge branch 'migrate_encoding_test' into encoding_sample_rate_lezzzgo
NicolasHug May 22, 2025
e0ba0c5
Merge branch 'main' of github.com:pytorch/torchcodec into encoding_sa…
NicolasHug May 22, 2025
2c559b2
Use 'output' more consistently
NicolasHug May 23, 2025
70ae1a1
Use AudioStreamOptions in AudioEncoder
NicolasHug May 23, 2025
75e23b9
Merge branch 'main' of github.com:pytorch/torchcodec into use_audioSt…
NicolasHug May 27, 2025
b6e3c27
Merge branch 'use_audioStreamOptions' into encoding_sample_rate_lezzzgo
NicolasHug May 27, 2025
387328a
Merge branch 'main' of github.com:pytorch/torchcodec into encoding_sa…
NicolasHug May 27, 2025
823e7f0
WIP
NicolasHug May 27, 2025
4be2953
Add flushing logic for swresample buffers
NicolasHug May 27, 2025
639d5ab
More tests
NicolasHug May 27, 2025
3ce4612
WIP
NicolasHug May 28, 2025
6c91450
Refactor audio sample conversion in encoder
NicolasHug May 28, 2025
b2eed2f
Merge branch 'move-conversion-out' into encoding_sample_rate_lezzzgo
NicolasHug May 29, 2025
8fdb6ed
wav tests pass
NicolasHug May 29, 2025
3399b34
Merge branch 'main' of github.com:pytorch/torchcodec into encoding_sa…
NicolasHug May 29, 2025
6d7908f
Use intermediate FIFO, WIP
NicolasHug May 30, 2025
f30d0ff
mostly works
NicolasHug May 30, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
170 changes: 126 additions & 44 deletions src/torchcodec/_core/Encoder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,33 @@ AVSampleFormat findBestOutputSampleFormat(const AVCodec& avCodec) {
return avCodec.sample_fmts[0];
}

UniqueAVFrame allocateAVFrame(
int numSamples,
int sampleRate,
int numChannels,
AVSampleFormat sampleFormat) {
auto avFrame = UniqueAVFrame(av_frame_alloc());
TORCH_CHECK(avFrame != nullptr, "Couldn't allocate AVFrame.");

avFrame->nb_samples = numSamples;
avFrame->sample_rate = sampleRate;
av_channel_layout_default(&avFrame->ch_layout, numChannels);
avFrame->format = sampleFormat;
auto status = av_frame_get_buffer(avFrame.get(), 0);

TORCH_CHECK(
status == AVSUCCESS,
"Couldn't allocate avFrame's buffers: ",
getFFMPEGErrorStringFromErrorCode(status));

status = av_frame_make_writable(avFrame.get());
TORCH_CHECK(
status == AVSUCCESS,
"Couldn't make AVFrame writable: ",
getFFMPEGErrorStringFromErrorCode(status));
return avFrame;
}

} // namespace

AudioEncoder::~AudioEncoder() {}
Expand All @@ -105,7 +132,7 @@ AudioEncoder::AudioEncoder(
int sampleRate,
std::string_view fileName,
const AudioStreamOptions& audioStreamOptions)
: samples_(validateSamples(samples)) {
: samples_(validateSamples(samples)), sampleRateInput_(sampleRate) {
setFFmpegLogLevel();
AVFormatContext* avFormatContext = nullptr;
int status = avformat_alloc_output_context2(
Expand All @@ -128,7 +155,7 @@ AudioEncoder::AudioEncoder(
", make sure it's a valid path? ",
getFFMPEGErrorStringFromErrorCode(status));

initializeEncoder(sampleRate, audioStreamOptions);
initializeEncoder(audioStreamOptions);
}

AudioEncoder::AudioEncoder(
Expand All @@ -138,6 +165,7 @@ AudioEncoder::AudioEncoder(
std::unique_ptr<AVIOToTensorContext> avioContextHolder,
const AudioStreamOptions& audioStreamOptions)
: samples_(validateSamples(samples)),
sampleRateInput_(sampleRate),
avioContextHolder_(std::move(avioContextHolder)) {
setFFmpegLogLevel();
AVFormatContext* avFormatContext = nullptr;
Expand All @@ -155,11 +183,10 @@ AudioEncoder::AudioEncoder(

avFormatContext_->pb = avioContextHolder_->getAVIOContext();

initializeEncoder(sampleRate, audioStreamOptions);
initializeEncoder(audioStreamOptions);
}

void AudioEncoder::initializeEncoder(
int sampleRate,
const AudioStreamOptions& audioStreamOptions) {
// We use the AVFormatContext's default codec for that
// specific format/container.
Expand Down Expand Up @@ -187,8 +214,9 @@ void AudioEncoder::initializeEncoder(
// not related to the input sampes.
setDefaultChannelLayout(avCodecContext_, outNumChannels_);

validateSampleRate(*avCodec, sampleRate);
avCodecContext_->sample_rate = sampleRate;
outSampleRate_ = audioStreamOptions.sampleRate.value_or(sampleRateInput_);
validateSampleRate(*avCodec, outSampleRate_);
avCodecContext_->sample_rate = outSampleRate_;

// Input samples are expected to be FLTP. Not all encoders support FLTP, so we
// may need to convert the samples into a supported output sample format,
Expand All @@ -213,6 +241,18 @@ void AudioEncoder::initializeEncoder(
"avcodec_parameters_from_context failed: ",
getFFMPEGErrorStringFromErrorCode(status));
streamIndex_ = avStream->index;

if (((avCodec->capabilities & AV_CODEC_CAP_VARIABLE_FRAME_SIZE) == 0) &&
(sampleRateInput_ != outSampleRate_)) {
// frame_size * 2 is a decent default size. FFmpeg automatically
// re-allocates the fifo if more space is needed.
auto avAudioFifo = av_audio_fifo_alloc(
avCodecContext_->sample_fmt,
outNumChannels_,
avCodecContext_->frame_size * 2);
TORCH_CHECK(avAudioFifo != nullptr, "Couldn't create AVAudioFifo.");
avAudioFifo_.reset(avAudioFifo);
}
}

torch::Tensor AudioEncoder::encodeToTensor() {
Expand All @@ -230,24 +270,15 @@ void AudioEncoder::encode() {
TORCH_CHECK(!encodeWasCalled_, "Cannot call encode() twice.");
encodeWasCalled_ = true;

UniqueAVFrame avFrame(av_frame_alloc());
TORCH_CHECK(avFrame != nullptr, "Couldn't allocate AVFrame.");
// Default to 256 like in torchaudio
int numSamplesAllocatedPerFrame =
avCodecContext_->frame_size > 0 ? avCodecContext_->frame_size : 256;
avFrame->nb_samples = numSamplesAllocatedPerFrame;
avFrame->format = AV_SAMPLE_FMT_FLTP;
avFrame->sample_rate = avCodecContext_->sample_rate;
UniqueAVFrame avFrame = allocateAVFrame(
numSamplesAllocatedPerFrame,
sampleRateInput_,
static_cast<int>(samples_.sizes()[0]),
AV_SAMPLE_FMT_FLTP);
avFrame->pts = 0;
// We set the channel layout of the frame to the default layout corresponding
// to the input samples' number of channels
setDefaultChannelLayout(avFrame, static_cast<int>(samples_.sizes()[0]));

auto status = av_frame_get_buffer(avFrame.get(), 0);
TORCH_CHECK(
status == AVSUCCESS,
"Couldn't allocate avFrame's buffers: ",
getFFMPEGErrorStringFromErrorCode(status));

AutoAVPacket autoAVPacket;

Expand All @@ -257,19 +288,13 @@ void AudioEncoder::encode() {
int numBytesPerSample = static_cast<int>(samples_.element_size());
int numBytesPerChannel = numSamples * numBytesPerSample;

status = avformat_write_header(avFormatContext_.get(), nullptr);
auto status = avformat_write_header(avFormatContext_.get(), nullptr);
TORCH_CHECK(
status == AVSUCCESS,
"Error in avformat_write_header: ",
getFFMPEGErrorStringFromErrorCode(status));

while (numEncodedSamples < numSamples) {
status = av_frame_make_writable(avFrame.get());
TORCH_CHECK(
status == AVSUCCESS,
"Couldn't make AVFrame writable: ",
getFFMPEGErrorStringFromErrorCode(status));

int numSamplesToEncode =
std::min(numSamplesAllocatedPerFrame, numSamples - numEncodedSamples);
int numBytesToEncode = numSamplesToEncode * numBytesPerSample;
Expand All @@ -290,7 +315,7 @@ void AudioEncoder::encode() {
avFrame->nb_samples = numSamplesToEncode;

UniqueAVFrame convertedAVFrame = maybeConvertAVFrame(avFrame);
encodeInnerLoop(autoAVPacket, convertedAVFrame);
sendFrameThroughFifo(autoAVPacket, convertedAVFrame);

numEncodedSamples += numSamplesToEncode;
// TODO-ENCODING set frame pts correctly, and test against it.
Expand All @@ -310,7 +335,8 @@ void AudioEncoder::encode() {
UniqueAVFrame AudioEncoder::maybeConvertAVFrame(const UniqueAVFrame& avFrame) {
if (static_cast<AVSampleFormat>(avFrame->format) ==
avCodecContext_->sample_fmt &&
getNumChannels(avFrame) == outNumChannels_) {
getNumChannels(avFrame) == outNumChannels_ &&
avFrame->sample_rate == outSampleRate_) {
// Note: the clone references the same underlying data, it's a cheap copy.
return UniqueAVFrame(av_frame_clone(avFrame.get()));
}
Expand All @@ -319,29 +345,62 @@ UniqueAVFrame AudioEncoder::maybeConvertAVFrame(const UniqueAVFrame& avFrame) {
swrContext_.reset(createSwrContext(
static_cast<AVSampleFormat>(avFrame->format),
avCodecContext_->sample_fmt,
avFrame->sample_rate, // No sample rate conversion
avFrame->sample_rate,
outSampleRate_,
avFrame,
outNumChannels_));
}
UniqueAVFrame convertedAVFrame = convertAudioAVFrameSamples(
swrContext_,
avFrame,
avCodecContext_->sample_fmt,
avFrame->sample_rate, // No sample rate conversion
outSampleRate_,
outNumChannels_);
TORCH_CHECK(
convertedAVFrame->nb_samples == avFrame->nb_samples,
"convertedAVFrame->nb_samples=",
convertedAVFrame->nb_samples,
" differs from ",
"avFrame->nb_samples=",
avFrame->nb_samples,
"This is unexpected, please report on the TorchCodec bug tracker.");

if (avFrame->sample_rate == outSampleRate_) {
TORCH_CHECK(
convertedAVFrame->nb_samples == avFrame->nb_samples,
"convertedAVFrame->nb_samples=",
convertedAVFrame->nb_samples,
" differs from ",
"avFrame->nb_samples=",
avFrame->nb_samples,
"This is unexpected, please report on the TorchCodec bug tracker.");
}
return convertedAVFrame;
}

void AudioEncoder::encodeInnerLoop(
void AudioEncoder::sendFrameThroughFifo(
AutoAVPacket& autoAVPacket,
const UniqueAVFrame& avFrame,
bool andFlushFifo) {
if (avAudioFifo_ == nullptr) {
encodeFrame(autoAVPacket, avFrame);
return;
}
// TODO static cast
int numSamplesWritten = av_audio_fifo_write(
avAudioFifo_.get(), (void**)avFrame->data, avFrame->nb_samples);
TORCH_CHECK(numSamplesWritten == avFrame->nb_samples, "Tried to write TODO");

UniqueAVFrame newavFrame = allocateAVFrame(
avCodecContext_->frame_size,
outSampleRate_,
outNumChannels_,
avCodecContext_->sample_fmt);

while (av_audio_fifo_size(avAudioFifo_.get()) >=
(andFlushFifo ? 1 : avCodecContext_->frame_size)) {
// TODO cast
int numSamplesRead = av_audio_fifo_read(
avAudioFifo_.get(), (void**)newavFrame->data, newavFrame->nb_samples);
TORCH_CHECK(numSamplesRead > 0, "Tried to read TODO");

encodeFrame(autoAVPacket, newavFrame);
}
}

void AudioEncoder::encodeFrame(
AutoAVPacket& autoAVPacket,
const UniqueAVFrame& avFrame) {
auto status = avcodec_send_frame(avCodecContext_.get(), avFrame.get());
Expand Down Expand Up @@ -382,11 +441,34 @@ void AudioEncoder::encodeInnerLoop(
}
}

void AudioEncoder::maybeFlushSwrBuffers(AutoAVPacket& autoAVPacket) {
// Similar to the decoder's method with the same name, but for encoding this
// time. That is, when sample conversion is invovled, libswresample may have
// buffered some samples that we now need to flush and send to the encoder.
if (swrContext_ == nullptr && sampleRateInput_ == outSampleRate_) {
return;
}
int numRemainingSamples = // this is an upper bound
swr_get_out_samples(swrContext_.get(), 0);
if (numRemainingSamples == 0) {
return;
}

UniqueAVFrame avFrame = allocateAVFrame(
numRemainingSamples,
outSampleRate_,
outNumChannels_,
avCodecContext_->sample_fmt);
int actualNumRemainingSamples = swr_convert(
swrContext_.get(), avFrame->data, avFrame->nb_samples, NULL, 0);
avFrame->nb_samples = actualNumRemainingSamples;

sendFrameThroughFifo(autoAVPacket, avFrame, /*andFlushFifo=*/true);
}

void AudioEncoder::flushBuffers() {
// We flush the main FFmpeg buffers, but not swresample buffers. Flushing
// swresample is only necessary when converting sample rates, which we don't
// do for encoding.
AutoAVPacket autoAVPacket;
encodeInnerLoop(autoAVPacket, UniqueAVFrame(nullptr));
maybeFlushSwrBuffers(autoAVPacket);
encodeFrame(autoAVPacket, UniqueAVFrame(nullptr));
}
} // namespace facebook::torchcodec
15 changes: 10 additions & 5 deletions src/torchcodec/_core/Encoder.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,13 +35,14 @@ class AudioEncoder {
torch::Tensor encodeToTensor();

private:
void initializeEncoder(
int sampleRate,
const AudioStreamOptions& audioStreamOptions);
void initializeEncoder(const AudioStreamOptions& audioStreamOptions);
UniqueAVFrame maybeConvertAVFrame(const UniqueAVFrame& avFrame);
void encodeInnerLoop(
void sendFrameThroughFifo(
AutoAVPacket& autoAVPacket,
const UniqueAVFrame& srcAVFrame);
const UniqueAVFrame& avFrame,
bool andFlushFifo = false);
void encodeFrame(AutoAVPacket& autoAVPacket, const UniqueAVFrame& avFrame);
void maybeFlushSwrBuffers(AutoAVPacket& autoAVPacket);
void flushBuffers();

UniqueEncodingAVFormatContext avFormatContext_;
Expand All @@ -51,8 +52,12 @@ class AudioEncoder {
AudioStreamOptions audioStreamOptions;

int outNumChannels_ = -1;
int outSampleRate_ = -1;

const torch::Tensor samples_;
int sampleRateInput_ = -1;

UniqueAVAudioFifo avAudioFifo_;

// Stores the AVIOContext for the output tensor buffer.
std::unique_ptr<AVIOToTensorContext> avioContextHolder_;
Expand Down
3 changes: 3 additions & 0 deletions src/torchcodec/_core/FFMPEGCommon.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ extern "C" {
#include <libavfilter/avfilter.h>
#include <libavformat/avformat.h>
#include <libavformat/avio.h>
#include <libavutil/audio_fifo.h>
#include <libavutil/avutil.h>
#include <libavutil/dict.h>
#include <libavutil/display.h>
Expand Down Expand Up @@ -73,6 +74,8 @@ using UniqueSwsContext =
std::unique_ptr<SwsContext, Deleter<SwsContext, void, sws_freeContext>>;
using UniqueSwrContext =
std::unique_ptr<SwrContext, Deleterp<SwrContext, void, swr_free>>;
using UniqueAVAudioFifo = std::
unique_ptr<AVAudioFifo, Deleter<AVAudioFifo, void, av_audio_fifo_free>>;

// These 2 classes share the same underlying AVPacket object. They are meant to
// be used in tandem, like so:
Expand Down
12 changes: 8 additions & 4 deletions src/torchcodec/_core/custom_ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,9 @@ TORCH_LIBRARY(torchcodec_ns, m) {
"torchcodec._core.ops", "//pytorch/torchcodec:torchcodec");
m.def("create_from_file(str filename, str? seek_mode=None) -> Tensor");
m.def(
"encode_audio_to_file(Tensor samples, int sample_rate, str filename, int? bit_rate=None, int? num_channels=None) -> ()");
"encode_audio_to_file(Tensor samples, int sample_rate, str filename, int? bit_rate=None, int? num_channels=None, int? desired_sample_rate=None) -> ()");
m.def(
"encode_audio_to_tensor(Tensor samples, int sample_rate, str format, int? bit_rate=None, int? num_channels=None) -> Tensor");
"encode_audio_to_tensor(Tensor samples, int sample_rate, str format, int? bit_rate=None, int? num_channels=None, int? desired_sample_rate=None) -> Tensor");
m.def(
"create_from_tensor(Tensor video_tensor, str? seek_mode=None) -> Tensor");
m.def("_convert_to_tensor(int decoder_ptr) -> Tensor");
Expand Down Expand Up @@ -392,12 +392,14 @@ void encode_audio_to_file(
int64_t sample_rate,
std::string_view file_name,
std::optional<int64_t> bit_rate = std::nullopt,
std::optional<int64_t> num_channels = std::nullopt) {
std::optional<int64_t> num_channels = std::nullopt,
std::optional<int64_t> desired_sample_rate = std::nullopt) {
// TODO Fix implicit int conversion:
// https://github.com/pytorch/torchcodec/issues/679
AudioStreamOptions audioStreamOptions;
audioStreamOptions.bitRate = bit_rate;
audioStreamOptions.numChannels = num_channels;
audioStreamOptions.sampleRate = desired_sample_rate;
AudioEncoder(
samples, validateSampleRate(sample_rate), file_name, audioStreamOptions)
.encode();
Expand All @@ -408,13 +410,15 @@ at::Tensor encode_audio_to_tensor(
int64_t sample_rate,
std::string_view format,
std::optional<int64_t> bit_rate = std::nullopt,
std::optional<int64_t> num_channels = std::nullopt) {
std::optional<int64_t> num_channels = std::nullopt,
std::optional<int64_t> desired_sample_rate = std::nullopt) {
auto avioContextHolder = std::make_unique<AVIOToTensorContext>();
// TODO Fix implicit int conversion:
// https://github.com/pytorch/torchcodec/issues/679
AudioStreamOptions audioStreamOptions;
audioStreamOptions.bitRate = bit_rate;
audioStreamOptions.numChannels = num_channels;
audioStreamOptions.sampleRate = desired_sample_rate;
return AudioEncoder(
samples,
validateSampleRate(sample_rate),
Expand Down
2 changes: 2 additions & 0 deletions src/torchcodec/_core/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,7 @@ def encode_audio_to_file_abstract(
filename: str,
bit_rate: Optional[int] = None,
num_channels: Optional[int] = None,
desired_sample_rate: Optional[int] = None,
) -> None:
return

Expand All @@ -179,6 +180,7 @@ def encode_audio_to_tensor_abstract(
format: str,
bit_rate: Optional[int] = None,
num_channels: Optional[int] = None,
desired_sample_rate: Optional[int] = None,
) -> torch.Tensor:
return torch.empty([], dtype=torch.long)

Expand Down
Loading
Loading