diff --git a/src/torchcodec/_core/Encoder.cpp b/src/torchcodec/_core/Encoder.cpp index 9d5c1dea..9696c2a5 100644 --- a/src/torchcodec/_core/Encoder.cpp +++ b/src/torchcodec/_core/Encoder.cpp @@ -10,7 +10,8 @@ AudioEncoder::~AudioEncoder() {} AudioEncoder::AudioEncoder( const torch::Tensor wf, int sampleRate, - std::string_view fileName) + std::string_view fileName, + std::optional bit_rate) : wf_(wf), sampleRate_(sampleRate) { TORCH_CHECK( wf_.dtype() == torch::kFloat32, @@ -49,11 +50,12 @@ AudioEncoder::AudioEncoder( TORCH_CHECK(avCodecContext != nullptr, "Couldn't allocate codec context."); avCodecContext_.reset(avCodecContext); - // TODO-ENCODING I think this sets the bit rate to the minimum supported. - // That's not what the ffmpeg CLI would choose by default, so we should try to - // do the same. - // TODO-ENCODING Should also let user choose for compressed formats like mp3. - avCodecContext_->bit_rate = 0; + if (bit_rate.has_value()) { + TORCH_CHECK(*bit_rate >= 0, "bit_rate=", *bit_rate, " must be >= 0."); + } + // bit_rate=None defaults to 0, which is what the FFmpeg CLI seems to use as + // well when "-b:a" isn't specified. + avCodecContext_->bit_rate = bit_rate.value_or(0); avCodecContext_->sample_rate = sampleRate_; diff --git a/src/torchcodec/_core/Encoder.h b/src/torchcodec/_core/Encoder.h index f0621fe5..8e715d3c 100644 --- a/src/torchcodec/_core/Encoder.h +++ b/src/torchcodec/_core/Encoder.h @@ -7,10 +7,16 @@ class AudioEncoder { public: ~AudioEncoder(); + // TODO-ENCODING: document in public docs that bit_rate value is only + // best-effort, matching to the closest supported bit_rate. I.e. passing 1 is + // like passing 0, which results in choosing the minimum supported bit rate. + // Passing 44_100 could result in output being 44000 if only 44000 is + // supported. AudioEncoder( const torch::Tensor wf, int sampleRate, - std::string_view fileName); + std::string_view fileName, + std::optional bit_rate = std::nullopt); void encode(); private: diff --git a/src/torchcodec/_core/custom_ops.cpp b/src/torchcodec/_core/custom_ops.cpp index 05a6390d..d7ec2b9f 100644 --- a/src/torchcodec/_core/custom_ops.cpp +++ b/src/torchcodec/_core/custom_ops.cpp @@ -30,7 +30,7 @@ TORCH_LIBRARY(torchcodec_ns, m) { "torchcodec._core.ops", "//pytorch/torchcodec:torchcodec"); m.def("create_from_file(str filename, str? seek_mode=None) -> Tensor"); m.def( - "create_audio_encoder(Tensor wf, int sample_rate, str filename) -> Tensor"); + "create_audio_encoder(Tensor wf, int sample_rate, str filename, int? bit_rate=None) -> Tensor"); m.def("encode_audio(Tensor(a!) encoder) -> ()"); m.def( "create_from_tensor(Tensor video_tensor, str? seek_mode=None) -> Tensor"); @@ -399,7 +399,8 @@ AudioEncoder* unwrapTensorToGetAudioEncoder(at::Tensor& tensor) { at::Tensor create_audio_encoder( const at::Tensor wf, int64_t sample_rate, - std::string_view file_name) { + std::string_view file_name, + std::optional bit_rate = std::nullopt) { TORCH_CHECK( sample_rate <= std::numeric_limits::max(), "sample_rate=", @@ -407,7 +408,7 @@ at::Tensor create_audio_encoder( " is too large to be cast to an int."); std::unique_ptr uniqueAudioEncoder = std::make_unique( - wf, static_cast(sample_rate), file_name); + wf, static_cast(sample_rate), file_name, bit_rate); return wrapAudioEncoderPointerToTensor(std::move(uniqueAudioEncoder)); } diff --git a/src/torchcodec/_core/ops.py b/src/torchcodec/_core/ops.py index 241df9a0..166ebe55 100644 --- a/src/torchcodec/_core/ops.py +++ b/src/torchcodec/_core/ops.py @@ -163,7 +163,7 @@ def create_from_file_abstract(filename: str, seek_mode: Optional[str]) -> torch. @register_fake("torchcodec_ns::create_audio_encoder") def create_audio_encoder_abstract( - wf: torch.Tensor, sample_rate: int, filename: str + wf: torch.Tensor, sample_rate: int, filename: str, bit_rate: Optional[int] = None ) -> torch.Tensor: return torch.empty([], dtype=torch.long) diff --git a/test/test_ops.py b/test/test_ops.py index e301701a..515970d5 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -1115,6 +1115,13 @@ def test_bad_input(self, tmp_path): sample_rate=10, filename=valid_output_file, ) + with pytest.raises(RuntimeError, match="bit_rate=-1 must be >= 0"): + create_audio_encoder( + wf=self.decode(NASA_AUDIO_MP3), + sample_rate=NASA_AUDIO_MP3.sample_rate, + filename=valid_output_file, + bit_rate=-1, # bad + ) def test_round_trip(self, tmp_path): # Check that decode(encode(samples)) == samples @@ -1127,15 +1134,16 @@ def test_round_trip(self, tmp_path): ) encode_audio(encoder) - # TODO-ENCODING: tol should be stricter. We need to increase the encoded - # bitrate, and / or encode into a lossless format. + # TODO-ENCODING: tol should be stricter. We probably need to encode + # into a lossless format. torch.testing.assert_close( self.decode(encoded_path), source_samples, rtol=0, atol=0.07 ) # TODO-ENCODING: test more encoding formats @pytest.mark.parametrize("asset", (NASA_AUDIO_MP3, SINE_MONO_S32)) - def test_against_cli(self, asset, tmp_path): + @pytest.mark.parametrize("bit_rate", (None, 0, 44_100, 999_999_999)) + def test_against_cli(self, asset, bit_rate, tmp_path): # Encodes samples with our encoder and with the FFmpeg CLI, and checks # that both decoded outputs are equal @@ -1143,12 +1151,9 @@ def test_against_cli(self, asset, tmp_path): encoded_by_us = tmp_path / "our_output.mp3" subprocess.run( - [ - "ffmpeg", - "-i", - str(asset.path), - "-b:a", - "0", # bitrate hardcoded to 0, see corresponding TODO. + ["ffmpeg", "-i", str(asset.path)] + + (["-b:a", f"{bit_rate}"] if bit_rate is not None else []) + + [ str(encoded_by_ffmpeg), ], capture_output=True, @@ -1159,6 +1164,7 @@ def test_against_cli(self, asset, tmp_path): wf=self.decode(asset), sample_rate=asset.sample_rate, filename=str(encoded_by_us), + bit_rate=bit_rate, ) encode_audio(encoder)