From e4f05ce1df1911e6594165bcd107fd5535f97fa1 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Tue, 29 Apr 2025 17:48:23 +0100 Subject: [PATCH 1/2] Enforce that encode() cannot be called twice --- src/torchcodec/_core/Encoder.cpp | 9 ++++++--- src/torchcodec/_core/Encoder.h | 2 ++ 2 files changed, 8 insertions(+), 3 deletions(-) diff --git a/src/torchcodec/_core/Encoder.cpp b/src/torchcodec/_core/Encoder.cpp index 114e8600..9b75f4fa 100644 --- a/src/torchcodec/_core/Encoder.cpp +++ b/src/torchcodec/_core/Encoder.cpp @@ -208,9 +208,12 @@ torch::Tensor AudioEncoder::encodeToTensor() { } void AudioEncoder::encode() { - // TODO-ENCODING: Need to check, but consecutive calls to encode() are - // probably invalid. We can address this once we (re)design the public and - // private encoding APIs. + // To be on the safe side we enforce that encode() can only be called once on + // an encoder object. Whether this is actually necessary is unknown, so this + // may be relaxed if needed. + TORCH_CHECK(!encodeWasCalled_, "Cannot call encode() twice."); + encodeWasCalled_ = true; + UniqueAVFrame avFrame(av_frame_alloc()); TORCH_CHECK(avFrame != nullptr, "Couldn't allocate AVFrame."); // Default to 256 like in torchaudio diff --git a/src/torchcodec/_core/Encoder.h b/src/torchcodec/_core/Encoder.h index 17f09d59..bf31c31b 100644 --- a/src/torchcodec/_core/Encoder.h +++ b/src/torchcodec/_core/Encoder.h @@ -49,5 +49,7 @@ class AudioEncoder { // Stores the AVIOContext for the output tensor buffer. std::unique_ptr avioContextHolder_; + + bool encodeWasCalled_ = false; }; } // namespace facebook::torchcodec From 8a02d262f7d63ce89d9799c4916dd00e2d782e29 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Wed, 30 Apr 2025 11:14:35 +0100 Subject: [PATCH 2/2] Address other todos --- src/torchcodec/_core/Encoder.cpp | 44 ++++++++++++++++------------- src/torchcodec/_core/custom_ops.cpp | 2 -- test/test_ops.py | 9 ++++-- 3 files changed, 31 insertions(+), 24 deletions(-) diff --git a/src/torchcodec/_core/Encoder.cpp b/src/torchcodec/_core/Encoder.cpp index f5ecfbf4..1c876f4e 100644 --- a/src/torchcodec/_core/Encoder.cpp +++ b/src/torchcodec/_core/Encoder.cpp @@ -14,6 +14,18 @@ torch::Tensor validateWf(torch::Tensor wf) { "waveform must have float32 dtype, got ", wf.dtype()); TORCH_CHECK(wf.dim() == 2, "waveform must have 2 dimensions, got ", wf.dim()); + + // We enforce this, but if we get user reports we should investigate whether + // that's actually needed. + int numChannels = static_cast(wf.sizes()[0]); + TORCH_CHECK( + numChannels <= AV_NUM_DATA_POINTERS, + "Trying to encode ", + numChannels, + " channels, but FFmpeg only supports ", + AV_NUM_DATA_POINTERS, + " channels per frame."); + return wf.contiguous(); } @@ -164,18 +176,7 @@ void AudioEncoder::initializeEncoder( // what the `.sample_fmt` defines. avCodecContext_->sample_fmt = findBestOutputSampleFormat(*avCodec); - int numChannels = static_cast(wf_.sizes()[0]); - TORCH_CHECK( - // TODO-ENCODING is this even true / needed? We can probably support more - // with non-planar data? - numChannels <= AV_NUM_DATA_POINTERS, - "Trying to encode ", - numChannels, - " channels, but FFmpeg only supports ", - AV_NUM_DATA_POINTERS, - " channels per frame."); - - setDefaultChannelLayout(avCodecContext_, numChannels); + setDefaultChannelLayout(avCodecContext_, static_cast(wf_.sizes()[0])); int status = avcodec_open2(avCodecContext_.get(), avCodec, nullptr); TORCH_CHECK( @@ -325,14 +326,17 @@ void AudioEncoder::encodeInnerLoop( ReferenceAVPacket packet(autoAVPacket); status = avcodec_receive_packet(avCodecContext_.get(), packet.get()); if (status == AVERROR(EAGAIN) || status == AVERROR_EOF) { - // TODO-ENCODING this is from TorchAudio, probably needed, but not sure. - // if (status == AVERROR_EOF) { - // status = av_interleaved_write_frame(avFormatContext_.get(), - // nullptr); TORCH_CHECK( - // status == AVSUCCESS, - // "Failed to flush packet ", - // getFFMPEGErrorStringFromErrorCode(status)); - // } + if (status == AVERROR_EOF) { + // Flush the packets that were potentially buffered by + // av_interleaved_write_frame(). See corresponding block in + // TorchAudio: + // https://github.com/pytorch/audio/blob/d60ce09e2c532d5bf2e05619e700ab520543465e/src/libtorio/ffmpeg/stream_writer/encoder.cpp#L21 + status = av_interleaved_write_frame(avFormatContext_.get(), nullptr); + TORCH_CHECK( + status == AVSUCCESS, + "Failed to flush packet: ", + getFFMPEGErrorStringFromErrorCode(status)); + } return; } TORCH_CHECK( diff --git a/src/torchcodec/_core/custom_ops.cpp b/src/torchcodec/_core/custom_ops.cpp index 2f470617..813c53a7 100644 --- a/src/torchcodec/_core/custom_ops.cpp +++ b/src/torchcodec/_core/custom_ops.cpp @@ -394,8 +394,6 @@ void encode_audio_to_file( .encode(); } -// TODO-ENCODING is "format" a good parameter name?? It kinda conflicts with -// "sample_format" which we may eventually want to expose. at::Tensor encode_audio_to_tensor( const at::Tensor wf, int64_t sample_rate, diff --git a/test/test_ops.py b/test/test_ops.py index ddca330a..6e53d27b 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -1132,11 +1132,11 @@ def test_bad_input(self, tmp_path): with pytest.raises(RuntimeError, match="No such file or directory"): encode_audio_to_file( - wf=torch.rand(10, 10), sample_rate=10, filename="./bad/path.mp3" + wf=torch.rand(2, 10), sample_rate=10, filename="./bad/path.mp3" ) with pytest.raises(RuntimeError, match="Check the desired extension"): encode_audio_to_file( - wf=torch.rand(10, 10), sample_rate=10, filename="./file.bad_extension" + wf=torch.rand(2, 10), sample_rate=10, filename="./file.bad_extension" ) with pytest.raises(RuntimeError, match="invalid sample rate=10"): @@ -1153,6 +1153,11 @@ def test_bad_input(self, tmp_path): bit_rate=-1, # bad ) + with pytest.raises(RuntimeError, match="Trying to encode 10 channels"): + encode_audio_to_file( + wf=torch.rand(10, 20), sample_rate=10, filename="doesnt_matter" + ) + @pytest.mark.parametrize( "encode_method", (encode_audio_to_file, encode_audio_to_tensor) )