Skip to content

Commit b6dd961

Browse files
authored
Add bit_rate parameter to encoder (#623)
1 parent 5402e7d commit b6dd961

File tree

5 files changed

+35
-20
lines changed

5 files changed

+35
-20
lines changed

src/torchcodec/_core/Encoder.cpp

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,8 @@ AudioEncoder::~AudioEncoder() {}
1010
AudioEncoder::AudioEncoder(
1111
const torch::Tensor wf,
1212
int sampleRate,
13-
std::string_view fileName)
13+
std::string_view fileName,
14+
std::optional<int64_t> bit_rate)
1415
: wf_(wf), sampleRate_(sampleRate) {
1516
TORCH_CHECK(
1617
wf_.dtype() == torch::kFloat32,
@@ -49,11 +50,12 @@ AudioEncoder::AudioEncoder(
4950
TORCH_CHECK(avCodecContext != nullptr, "Couldn't allocate codec context.");
5051
avCodecContext_.reset(avCodecContext);
5152

52-
// TODO-ENCODING I think this sets the bit rate to the minimum supported.
53-
// That's not what the ffmpeg CLI would choose by default, so we should try to
54-
// do the same.
55-
// TODO-ENCODING Should also let user choose for compressed formats like mp3.
56-
avCodecContext_->bit_rate = 0;
53+
if (bit_rate.has_value()) {
54+
TORCH_CHECK(*bit_rate >= 0, "bit_rate=", *bit_rate, " must be >= 0.");
55+
}
56+
// bit_rate=None defaults to 0, which is what the FFmpeg CLI seems to use as
57+
// well when "-b:a" isn't specified.
58+
avCodecContext_->bit_rate = bit_rate.value_or(0);
5759

5860
avCodecContext_->sample_rate = sampleRate_;
5961

src/torchcodec/_core/Encoder.h

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,16 @@ class AudioEncoder {
77
public:
88
~AudioEncoder();
99

10+
// TODO-ENCODING: document in public docs that bit_rate value is only
11+
// best-effort, matching to the closest supported bit_rate. I.e. passing 1 is
12+
// like passing 0, which results in choosing the minimum supported bit rate.
13+
// Passing 44_100 could result in output being 44000 if only 44000 is
14+
// supported.
1015
AudioEncoder(
1116
const torch::Tensor wf,
1217
int sampleRate,
13-
std::string_view fileName);
18+
std::string_view fileName,
19+
std::optional<int64_t> bit_rate = std::nullopt);
1420
void encode();
1521

1622
private:

src/torchcodec/_core/custom_ops.cpp

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ TORCH_LIBRARY(torchcodec_ns, m) {
2929
"torchcodec._core.ops", "//pytorch/torchcodec:torchcodec");
3030
m.def("create_from_file(str filename, str? seek_mode=None) -> Tensor");
3131
m.def(
32-
"create_audio_encoder(Tensor wf, int sample_rate, str filename) -> Tensor");
32+
"create_audio_encoder(Tensor wf, int sample_rate, str filename, int? bit_rate=None) -> Tensor");
3333
m.def("encode_audio(Tensor(a!) encoder) -> ()");
3434
m.def(
3535
"create_from_tensor(Tensor video_tensor, str? seek_mode=None) -> Tensor");
@@ -396,15 +396,16 @@ AudioEncoder* unwrapTensorToGetAudioEncoder(at::Tensor& tensor) {
396396
at::Tensor create_audio_encoder(
397397
const at::Tensor wf,
398398
int64_t sample_rate,
399-
std::string_view file_name) {
399+
std::string_view file_name,
400+
std::optional<int64_t> bit_rate = std::nullopt) {
400401
TORCH_CHECK(
401402
sample_rate <= std::numeric_limits<int>::max(),
402403
"sample_rate=",
403404
sample_rate,
404405
" is too large to be cast to an int.");
405406
std::unique_ptr<AudioEncoder> uniqueAudioEncoder =
406407
std::make_unique<AudioEncoder>(
407-
wf, static_cast<int>(sample_rate), file_name);
408+
wf, static_cast<int>(sample_rate), file_name, bit_rate);
408409
return wrapAudioEncoderPointerToTensor(std::move(uniqueAudioEncoder));
409410
}
410411

src/torchcodec/_core/ops.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -163,7 +163,7 @@ def create_from_file_abstract(filename: str, seek_mode: Optional[str]) -> torch.
163163

164164
@register_fake("torchcodec_ns::create_audio_encoder")
165165
def create_audio_encoder_abstract(
166-
wf: torch.Tensor, sample_rate: int, filename: str
166+
wf: torch.Tensor, sample_rate: int, filename: str, bit_rate: Optional[int] = None
167167
) -> torch.Tensor:
168168
return torch.empty([], dtype=torch.long)
169169

test/test_ops.py

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1115,6 +1115,13 @@ def test_bad_input(self, tmp_path):
11151115
sample_rate=10,
11161116
filename=valid_output_file,
11171117
)
1118+
with pytest.raises(RuntimeError, match="bit_rate=-1 must be >= 0"):
1119+
create_audio_encoder(
1120+
wf=self.decode(NASA_AUDIO_MP3),
1121+
sample_rate=NASA_AUDIO_MP3.sample_rate,
1122+
filename=valid_output_file,
1123+
bit_rate=-1, # bad
1124+
)
11181125

11191126
def test_round_trip(self, tmp_path):
11201127
# Check that decode(encode(samples)) == samples
@@ -1127,28 +1134,26 @@ def test_round_trip(self, tmp_path):
11271134
)
11281135
encode_audio(encoder)
11291136

1130-
# TODO-ENCODING: tol should be stricter. We need to increase the encoded
1131-
# bitrate, and / or encode into a lossless format.
1137+
# TODO-ENCODING: tol should be stricter. We probably need to encode
1138+
# into a lossless format.
11321139
torch.testing.assert_close(
11331140
self.decode(encoded_path), source_samples, rtol=0, atol=0.07
11341141
)
11351142

11361143
# TODO-ENCODING: test more encoding formats
11371144
@pytest.mark.parametrize("asset", (NASA_AUDIO_MP3, SINE_MONO_S32))
1138-
def test_against_cli(self, asset, tmp_path):
1145+
@pytest.mark.parametrize("bit_rate", (None, 0, 44_100, 999_999_999))
1146+
def test_against_cli(self, asset, bit_rate, tmp_path):
11391147
# Encodes samples with our encoder and with the FFmpeg CLI, and checks
11401148
# that both decoded outputs are equal
11411149

11421150
encoded_by_ffmpeg = tmp_path / "ffmpeg_output.mp3"
11431151
encoded_by_us = tmp_path / "our_output.mp3"
11441152

11451153
subprocess.run(
1146-
[
1147-
"ffmpeg",
1148-
"-i",
1149-
str(asset.path),
1150-
"-b:a",
1151-
"0", # bitrate hardcoded to 0, see corresponding TODO.
1154+
["ffmpeg", "-i", str(asset.path)]
1155+
+ (["-b:a", f"{bit_rate}"] if bit_rate is not None else [])
1156+
+ [
11521157
str(encoded_by_ffmpeg),
11531158
],
11541159
capture_output=True,
@@ -1159,6 +1164,7 @@ def test_against_cli(self, asset, tmp_path):
11591164
wf=self.decode(asset),
11601165
sample_rate=asset.sample_rate,
11611166
filename=str(encoded_by_us),
1167+
bit_rate=bit_rate,
11621168
)
11631169
encode_audio(encoder)
11641170

0 commit comments

Comments
 (0)