Skip to content

Commit bea7360

Browse files
authored
Audio encoding - part 1 of N (#524)
1 parent 9ebac73 commit bea7360

File tree

12 files changed

+474
-5
lines changed

12 files changed

+474
-5
lines changed

src/torchcodec/_core/CMakeLists.txt

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,9 @@ function(make_torchcodec_libraries
6161
AVIOContextHolder.cpp
6262
FFMPEGCommon.cpp
6363
SingleStreamDecoder.cpp
64+
# TODO: lib name should probably not be "*_decoder*" now that it also
65+
# contains an encoder
66+
Encoder.cpp
6467
)
6568

6669
if(ENABLE_CUDA)

src/torchcodec/_core/Encoder.cpp

Lines changed: 219 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,219 @@
1+
#include "src/torchcodec/_core/Encoder.h"

#include <algorithm>
#include <cstring>
#include <string>

#include "torch/types.h"
3+
4+
namespace facebook::torchcodec {
5+
6+
// Closes the output file opened by the constructor.
//
// avFormatContext_'s deleter is avformat_free_context(), which frees the
// context and its streams but does not close the AVIOContext that the
// constructor opened with avio_open(). We close it here, before the members
// are destroyed, so that the underlying file handle is not leaked.
AudioEncoder::~AudioEncoder() {
  if (avFormatContext_ != nullptr && avFormatContext_->pb != nullptr) {
    // avio_closep() flushes, closes, and resets the pointer to nullptr. Its
    // return value is deliberately ignored: destructors must not throw.
    avio_closep(&avFormatContext_->pb);
  }
}
7+
8+
// TODO-ENCODING: disable ffmpeg logs by default
9+
10+
// Sets up everything needed to encode the waveform `wf` into the file
// `fileName`: the output format context, the output file, the codec context
// and the output stream. No samples are encoded here, see encode().
//
// Parameters:
//   wf:         contiguous float32 tensor with 2 dimensions. Dim 0 is used as
//               the number of channels, dim 1 as the number of samples per
//               channel.
//   sampleRate: sample rate set on the codec context.
//   fileName:   path of the output file. It is also handed to FFmpeg to pick
//               the output format.
//
// Raises (through TORCH_CHECK) if the waveform is invalid or if any of the
// FFmpeg setup calls fail.
AudioEncoder::AudioEncoder(
    const torch::Tensor wf,
    int sampleRate,
    std::string_view fileName)
    : wf_(wf), sampleRate_(sampleRate) {
  TORCH_CHECK(
      wf_.dtype() == torch::kFloat32,
      "waveform must have float32 dtype, got ",
      wf_.dtype());
  TORCH_CHECK(
      wf_.dim() == 2, "waveform must have 2 dimensions, got ", wf_.dim());
  // encode() reads the raw data pointer assuming that each channel is a
  // contiguous run of samples and that channels follow each other in memory.
  // That only holds for a contiguous tensor.
  TORCH_CHECK(wf_.is_contiguous(), "waveform must be contiguous.");

  // std::string_view::data() is not guaranteed to be null-terminated, but the
  // FFmpeg functions below expect a C string. Make an owned, null-terminated
  // copy.
  const std::string fileNameStr(fileName);

  AVFormatContext* avFormatContext = nullptr;
  auto status = avformat_alloc_output_context2(
      &avFormatContext, nullptr, nullptr, fileNameStr.c_str());
  TORCH_CHECK(
      avFormatContext != nullptr,
      "Couldn't allocate AVFormatContext. ",
      "Check the desired extension? ",
      getFFMPEGErrorStringFromErrorCode(status));
  avFormatContext_.reset(avFormatContext);

  // TODO-ENCODING: Should also support encoding into bytes (use
  // AVIOBytesContext)
  TORCH_CHECK(
      !(avFormatContext_->oformat->flags & AVFMT_NOFILE),
      "AVFMT_NOFILE is set. We only support writing to a file.");
  // NOTE(review): if one of the checks below throws, the AVIOContext opened
  // here is not closed: the destructor body does not run for a partially
  // constructed object and avformat_free_context() does not close it.
  // Consider a scope guard.
  status =
      avio_open(&avFormatContext_->pb, fileNameStr.c_str(), AVIO_FLAG_WRITE);
  TORCH_CHECK(
      status >= 0,
      "avio_open failed: ",
      getFFMPEGErrorStringFromErrorCode(status));

  // We use the AVFormatContext's default codec for that
  // specific format/container.
  const AVCodec* avCodec =
      avcodec_find_encoder(avFormatContext_->oformat->audio_codec);
  TORCH_CHECK(avCodec != nullptr, "Codec not found");

  AVCodecContext* avCodecContext = avcodec_alloc_context3(avCodec);
  TORCH_CHECK(avCodecContext != nullptr, "Couldn't allocate codec context.");
  avCodecContext_.reset(avCodecContext);

  // TODO-ENCODING I think this sets the bit rate to the minimum supported.
  // That's not what the ffmpeg CLI would choose by default, so we should try to
  // do the same.
  // TODO-ENCODING Should also let user choose for compressed formats like mp3.
  avCodecContext_->bit_rate = 0;

  avCodecContext_->sample_rate = sampleRate_;

  // Note: This is the format of the **input** waveform. This doesn't determine
  // the output.
  // TODO-ENCODING If the encoder doesn't support FLTP (like flac), FFmpeg will
  // raise. We need to handle this, probably converting the format with
  // libswresample.
  avCodecContext_->sample_fmt = AV_SAMPLE_FMT_FLTP;

  int numChannels = static_cast<int>(wf_.sizes()[0]);
  TORCH_CHECK(
      // TODO-ENCODING is this even true / needed? We can probably support more
      // with non-planar data?
      numChannels <= AV_NUM_DATA_POINTERS,
      "Trying to encode ",
      numChannels,
      " channels, but FFmpeg only supports ",
      AV_NUM_DATA_POINTERS,
      " channels per frame.");

  setDefaultChannelLayout(avCodecContext_, numChannels);

  status = avcodec_open2(avCodecContext_.get(), avCodec, nullptr);
  TORCH_CHECK(
      status == AVSUCCESS,
      "avcodec_open2 failed: ",
      getFFMPEGErrorStringFromErrorCode(status));

  // encode() uses frame_size as the number of samples per frame, so it must
  // be strictly positive.
  TORCH_CHECK(
      avCodecContext_->frame_size > 0,
      "frame_size is ",
      avCodecContext_->frame_size,
      ". Cannot encode. This should probably never happen?");

  // We're allocating the stream here. Streams are meant to be freed by
  // avformat_free_context(avFormatContext), which we call in the
  // avFormatContext_'s destructor.
  AVStream* avStream = avformat_new_stream(avFormatContext_.get(), nullptr);
  TORCH_CHECK(avStream != nullptr, "Couldn't create new stream.");
  status = avcodec_parameters_from_context(
      avStream->codecpar, avCodecContext_.get());
  TORCH_CHECK(
      status == AVSUCCESS,
      "avcodec_parameters_from_context failed: ",
      getFFMPEGErrorStringFromErrorCode(status));
  streamIndex_ = avStream->index;
}
107+
108+
void AudioEncoder::encode() {
109+
UniqueAVFrame avFrame(av_frame_alloc());
110+
TORCH_CHECK(avFrame != nullptr, "Couldn't allocate AVFrame.");
111+
avFrame->nb_samples = avCodecContext_->frame_size;
112+
avFrame->format = avCodecContext_->sample_fmt;
113+
avFrame->sample_rate = avCodecContext_->sample_rate;
114+
avFrame->pts = 0;
115+
setChannelLayout(avFrame, avCodecContext_);
116+
117+
auto status = av_frame_get_buffer(avFrame.get(), 0);
118+
TORCH_CHECK(
119+
status == AVSUCCESS,
120+
"Couldn't allocate avFrame's buffers: ",
121+
getFFMPEGErrorStringFromErrorCode(status));
122+
123+
AutoAVPacket autoAVPacket;
124+
125+
uint8_t* pwf = static_cast<uint8_t*>(wf_.data_ptr());
126+
int numSamples = static_cast<int>(wf_.sizes()[1]); // per channel
127+
int numEncodedSamples = 0; // per channel
128+
int numSamplesPerFrame = avCodecContext_->frame_size; // per channel
129+
int numBytesPerSample = static_cast<int>(wf_.element_size());
130+
int numBytesPerChannel = numSamples * numBytesPerSample;
131+
132+
status = avformat_write_header(avFormatContext_.get(), nullptr);
133+
TORCH_CHECK(
134+
status == AVSUCCESS,
135+
"Error in avformat_write_header: ",
136+
getFFMPEGErrorStringFromErrorCode(status));
137+
138+
while (numEncodedSamples < numSamples) {
139+
status = av_frame_make_writable(avFrame.get());
140+
TORCH_CHECK(
141+
status == AVSUCCESS,
142+
"Couldn't make AVFrame writable: ",
143+
getFFMPEGErrorStringFromErrorCode(status));
144+
145+
int numSamplesToEncode =
146+
std::min(numSamplesPerFrame, numSamples - numEncodedSamples);
147+
int numBytesToEncode = numSamplesToEncode * numBytesPerSample;
148+
149+
for (int ch = 0; ch < wf_.sizes()[0]; ch++) {
150+
std::memcpy(
151+
avFrame->data[ch], pwf + ch * numBytesPerChannel, numBytesToEncode);
152+
}
153+
pwf += numBytesToEncode;
154+
155+
// Above, we set the AVFrame's .nb_samples to AVCodecContext.frame_size so
156+
// that the frame buffers are allocated to a big enough size. Here, we reset
157+
// it to the exact number of samples that need to be encoded, otherwise the
158+
// encoded frame would contain more samples than necessary and our results
159+
// wouldn't match the ffmpeg CLI.
160+
avFrame->nb_samples = numSamplesToEncode;
161+
encodeInnerLoop(autoAVPacket, avFrame);
162+
163+
avFrame->pts += static_cast<int64_t>(numSamplesToEncode);
164+
numEncodedSamples += numSamplesToEncode;
165+
}
166+
TORCH_CHECK(numEncodedSamples == numSamples, "Hmmmmmm something went wrong.");
167+
168+
flushBuffers();
169+
170+
status = av_write_trailer(avFormatContext_.get());
171+
TORCH_CHECK(
172+
status == AVSUCCESS,
173+
"Error in: av_write_trailer",
174+
getFFMPEGErrorStringFromErrorCode(status));
175+
}
176+
177+
// Sends `avFrame` to the encoder, then writes every packet the encoder hands
// back to the output, until the encoder reports EAGAIN or EOF.
//
// Raises (through TORCH_CHECK) if sending the frame, receiving a packet or
// writing a packet fails.
void AudioEncoder::encodeInnerLoop(
    AutoAVPacket& autoAVPacket,
    const UniqueAVFrame& avFrame) {
  const auto sendStatus =
      avcodec_send_frame(avCodecContext_.get(), avFrame.get());
  TORCH_CHECK(
      sendStatus == AVSUCCESS,
      "Error while sending frame: ",
      getFFMPEGErrorStringFromErrorCode(sendStatus));

  // Every failure below throws, so the only way out of this loop is the
  // EAGAIN / EOF early return.
  for (;;) {
    ReferenceAVPacket packet(autoAVPacket);
    const auto receiveStatus =
        avcodec_receive_packet(avCodecContext_.get(), packet.get());

    const bool noPacketAvailable =
        receiveStatus == AVERROR(EAGAIN) || receiveStatus == AVERROR_EOF;
    if (noPacketAvailable) {
      // TODO-ENCODING this is from TorchAudio, probably needed, but not sure.
      //   if (status == AVERROR_EOF) {
      //     status = av_interleaved_write_frame(avFormatContext_.get(),
      //     nullptr); TORCH_CHECK(
      //         status == AVSUCCESS,
      //         "Failed to flush packet ",
      //         getFFMPEGErrorStringFromErrorCode(status));
      //   }
      return;
    }
    TORCH_CHECK(
        receiveStatus >= 0,
        "Error receiving packet: ",
        getFFMPEGErrorStringFromErrorCode(receiveStatus));

    packet->stream_index = streamIndex_;

    const auto writeStatus =
        av_interleaved_write_frame(avFormatContext_.get(), packet.get());
    TORCH_CHECK(
        writeStatus == AVSUCCESS,
        "Error in av_interleaved_write_frame: ",
        getFFMPEGErrorStringFromErrorCode(writeStatus));
  }
}
214+
215+
void AudioEncoder::flushBuffers() {
216+
AutoAVPacket autoAVPacket;
217+
encodeInnerLoop(autoAVPacket, UniqueAVFrame(nullptr));
218+
}
219+
} // namespace facebook::torchcodec

src/torchcodec/_core/Encoder.h

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
#pragma once
2+
#include <torch/types.h>
3+
#include "src/torchcodec/_core/FFMPEGCommon.h"
4+
5+
namespace facebook::torchcodec {
6+
class AudioEncoder {
7+
public:
8+
~AudioEncoder();
9+
10+
AudioEncoder(
11+
const torch::Tensor wf,
12+
int sampleRate,
13+
std::string_view fileName);
14+
void encode();
15+
16+
private:
17+
void encodeInnerLoop(
18+
AutoAVPacket& autoAVPacket,
19+
const UniqueAVFrame& avFrame);
20+
void flushBuffers();
21+
22+
UniqueEncodingAVFormatContext avFormatContext_;
23+
UniqueAVCodecContext avCodecContext_;
24+
int streamIndex_;
25+
26+
const torch::Tensor wf_;
27+
// The *output* sample rate. We can't really decide for the user what it
28+
// should be. Particularly, the sample rate of the input waveform should match
29+
// this, and that's up to the user. If sample rates don't match, encoding will
30+
// still work but audio will be distorted.
31+
// We technically could let the user also specify the input sample rate, and
32+
// resample the waveform internally to match them, but that's not in scope for
33+
// an initial version (if at all).
34+
int sampleRate_;
35+
};
36+
} // namespace facebook::torchcodec

src/torchcodec/_core/FFMPEGCommon.cpp

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,38 @@ int getNumChannels(const UniqueAVCodecContext& avCodecContext) {
7474
#endif
7575
}
7676

77+
// Sets the channel layout of `avCodecContext` to FFmpeg's default layout for
// `numChannels` channels.
void setDefaultChannelLayout(
    UniqueAVCodecContext& avCodecContext,
    int numChannels) {
#if LIBAVFILTER_VERSION_MAJOR > 7 // FFmpeg > 4
  // Write the layout straight into the codec context instead of going through
  // an uninitialized local: when FFmpeg has no default layout for this channel
  // count, av_channel_layout_default() only sets the order and the number of
  // channels, and the other fields of a local would be copied over as garbage.
  // av_channel_layout_uninit() resets every field first.
  av_channel_layout_uninit(&avCodecContext->ch_layout);
  av_channel_layout_default(&avCodecContext->ch_layout, numChannels);
#else
  avCodecContext->channel_layout = av_get_default_channel_layout(numChannels);
  avCodecContext->channels = numChannels;
#endif
}
91+
92+
// Gives `dstAVFrame` the same channel layout as `avCodecContext`.
// On the newer FFmpeg API, raises (through TORCH_CHECK) if the copy fails.
void setChannelLayout(
    UniqueAVFrame& dstAVFrame,
    const UniqueAVCodecContext& avCodecContext) {
#if LIBAVFILTER_VERSION_MAJOR > 7 // FFmpeg > 4
  const int copyStatus = av_channel_layout_copy(
      &dstAVFrame->ch_layout, &avCodecContext->ch_layout);
  TORCH_CHECK(
      copyStatus == AVSUCCESS,
      "Couldn't copy channel layout to avFrame: ",
      getFFMPEGErrorStringFromErrorCode(copyStatus));
#else
  // Older API: the layout is a plain bitmask plus a channel count.
  dstAVFrame->channels = avCodecContext->channels;
  dstAVFrame->channel_layout = avCodecContext->channel_layout;
#endif
}
108+
77109
void setChannelLayout(
78110
UniqueAVFrame& dstAVFrame,
79111
const UniqueAVFrame& srcAVFrame) {

src/torchcodec/_core/FFMPEGCommon.h

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,9 +50,12 @@ struct Deleter {
5050
};
5151

5252
// Unique pointers for FFMPEG structures.
53-
using UniqueAVFormatContext = std::unique_ptr<
53+
using UniqueDecodingAVFormatContext = std::unique_ptr<
5454
AVFormatContext,
5555
Deleterp<AVFormatContext, void, avformat_close_input>>;
56+
using UniqueEncodingAVFormatContext = std::unique_ptr<
57+
AVFormatContext,
58+
Deleter<AVFormatContext, void, avformat_free_context>>;
5659
using UniqueAVCodecContext = std::unique_ptr<
5760
AVCodecContext,
5861
Deleterp<AVCodecContext, void, avcodec_free_context>>;
@@ -144,6 +147,14 @@ int64_t getDuration(const UniqueAVFrame& frame);
144147
int getNumChannels(const UniqueAVFrame& avFrame);
145148
int getNumChannels(const UniqueAVCodecContext& avCodecContext);
146149

150+
void setDefaultChannelLayout(
151+
UniqueAVCodecContext& avCodecContext,
152+
int numChannels);
153+
154+
void setChannelLayout(
155+
UniqueAVFrame& dstAVFrame,
156+
const UniqueAVCodecContext& avCodecContext);
157+
147158
void setChannelLayout(
148159
UniqueAVFrame& dstAVFrame,
149160
const UniqueAVFrame& srcAVFrame);

src/torchcodec/_core/SingleStreamDecoder.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1443,7 +1443,7 @@ void SingleStreamDecoder::convertAudioAVFrameToFrameOutputOnCPU(
14431443
auto numBytesPerChannel = numSamples * av_get_bytes_per_sample(format);
14441444
for (auto channel = 0; channel < numChannels;
14451445
++channel, outputChannelData += numBytesPerChannel) {
1446-
memcpy(
1446+
std::memcpy(
14471447
outputChannelData,
14481448
avFrame->extended_data[channel],
14491449
numBytesPerChannel);

src/torchcodec/_core/SingleStreamDecoder.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -492,7 +492,7 @@ class SingleStreamDecoder {
492492

493493
SeekMode seekMode_;
494494
ContainerMetadata containerMetadata_;
495-
UniqueAVFormatContext formatContext_;
495+
UniqueDecodingAVFormatContext formatContext_;
496496
std::map<int, StreamInfo> streamInfos_;
497497
const int NO_ACTIVE_STREAM = -2;
498498
int activeStreamIndex_ = NO_ACTIVE_STREAM;

src/torchcodec/_core/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,10 +18,12 @@
1818
_test_frame_pts_equality,
1919
add_audio_stream,
2020
add_video_stream,
21+
create_audio_encoder,
2122
create_from_bytes,
2223
create_from_file,
2324
create_from_file_like,
2425
create_from_tensor,
26+
encode_audio,
2527
get_ffmpeg_library_versions,
2628
get_frame_at_index,
2729
get_frame_at_pts,

0 commit comments

Comments
 (0)