diff --git a/src/torchcodec/_core/CMakeLists.txt b/src/torchcodec/_core/CMakeLists.txt index 918da235..abec1d21 100644 --- a/src/torchcodec/_core/CMakeLists.txt +++ b/src/torchcodec/_core/CMakeLists.txt @@ -61,6 +61,9 @@ function(make_torchcodec_libraries AVIOContextHolder.cpp FFMPEGCommon.cpp SingleStreamDecoder.cpp + # TODO: lib name should probably not be "*_decoder*" now that it also + # contains an encoder + Encoder.cpp ) if(ENABLE_CUDA) diff --git a/src/torchcodec/_core/Encoder.cpp b/src/torchcodec/_core/Encoder.cpp new file mode 100644 index 00000000..9d5c1dea --- /dev/null +++ b/src/torchcodec/_core/Encoder.cpp @@ -0,0 +1,219 @@ +#include "src/torchcodec/_core/Encoder.h" +#include "torch/types.h" + +namespace facebook::torchcodec { + +AudioEncoder::~AudioEncoder() {} + +// TODO-ENCODING: disable ffmpeg logs by default + +AudioEncoder::AudioEncoder( + const torch::Tensor wf, + int sampleRate, + std::string_view fileName) + : wf_(wf), sampleRate_(sampleRate) { + TORCH_CHECK( + wf_.dtype() == torch::kFloat32, + "waveform must have float32 dtype, got ", + wf_.dtype()); + TORCH_CHECK( + wf_.dim() == 2, "waveform must have 2 dimensions, got ", wf_.dim()); + AVFormatContext* avFormatContext = nullptr; + auto status = avformat_alloc_output_context2( + &avFormatContext, nullptr, nullptr, fileName.data()); + TORCH_CHECK( + avFormatContext != nullptr, + "Couldn't allocate AVFormatContext. ", + "Check the desired extension? ", + getFFMPEGErrorStringFromErrorCode(status)); + avFormatContext_.reset(avFormatContext); + + // TODO-ENCODING: Should also support encoding into bytes (use + // AVIOBytesContext) + TORCH_CHECK( + !(avFormatContext->oformat->flags & AVFMT_NOFILE), + "AVFMT_NOFILE is set. 
We only support writing to a file."); + status = avio_open(&avFormatContext_->pb, fileName.data(), AVIO_FLAG_WRITE); + TORCH_CHECK( + status >= 0, + "avio_open failed: ", + getFFMPEGErrorStringFromErrorCode(status)); + + // We use the AVFormatContext's default codec for that + // specific format/container. + const AVCodec* avCodec = + avcodec_find_encoder(avFormatContext_->oformat->audio_codec); + TORCH_CHECK(avCodec != nullptr, "Codec not found"); + + AVCodecContext* avCodecContext = avcodec_alloc_context3(avCodec); + TORCH_CHECK(avCodecContext != nullptr, "Couldn't allocate codec context."); + avCodecContext_.reset(avCodecContext); + + // TODO-ENCODING I think this sets the bit rate to the minimum supported. + // That's not what the ffmpeg CLI would choose by default, so we should try to + // do the same. + // TODO-ENCODING Should also let user choose for compressed formats like mp3. + avCodecContext_->bit_rate = 0; + + avCodecContext_->sample_rate = sampleRate_; + + // Note: This is the format of the **input** waveform. This doesn't determine + // the output. + // TODO-ENCODING check contiguity of the input wf to ensure that it is indeed + // planar. + // TODO-ENCODING If the encoder doesn't support FLTP (like flac), FFmpeg will + // raise. We need to handle this, probably converting the format with + // libswresample. + avCodecContext_->sample_fmt = AV_SAMPLE_FMT_FLTP; + + int numChannels = static_cast(wf_.sizes()[0]); + TORCH_CHECK( + // TODO-ENCODING is this even true / needed? We can probably support more + // with non-planar data? 
+ numChannels <= AV_NUM_DATA_POINTERS, + "Trying to encode ", + numChannels, + " channels, but FFmpeg only supports ", + AV_NUM_DATA_POINTERS, + " channels per frame."); + + setDefaultChannelLayout(avCodecContext_, numChannels); + + status = avcodec_open2(avCodecContext_.get(), avCodec, nullptr); + TORCH_CHECK( + status == AVSUCCESS, + "avcodec_open2 failed: ", + getFFMPEGErrorStringFromErrorCode(status)); + + TORCH_CHECK( + avCodecContext_->frame_size > 0, + "frame_size is ", + avCodecContext_->frame_size, + ". Cannot encode. This should probably never happen?"); + + // We're allocating the stream here. Streams are meant to be freed by + // avformat_free_context(avFormatContext), which we call in the + // avFormatContext_'s destructor. + AVStream* avStream = avformat_new_stream(avFormatContext_.get(), nullptr); + TORCH_CHECK(avStream != nullptr, "Couldn't create new stream."); + status = avcodec_parameters_from_context( + avStream->codecpar, avCodecContext_.get()); + TORCH_CHECK( + status == AVSUCCESS, + "avcodec_parameters_from_context failed: ", + getFFMPEGErrorStringFromErrorCode(status)); + streamIndex_ = avStream->index; +} + +void AudioEncoder::encode() { + UniqueAVFrame avFrame(av_frame_alloc()); + TORCH_CHECK(avFrame != nullptr, "Couldn't allocate AVFrame."); + avFrame->nb_samples = avCodecContext_->frame_size; + avFrame->format = avCodecContext_->sample_fmt; + avFrame->sample_rate = avCodecContext_->sample_rate; + avFrame->pts = 0; + setChannelLayout(avFrame, avCodecContext_); + + auto status = av_frame_get_buffer(avFrame.get(), 0); + TORCH_CHECK( + status == AVSUCCESS, + "Couldn't allocate avFrame's buffers: ", + getFFMPEGErrorStringFromErrorCode(status)); + + AutoAVPacket autoAVPacket; + + uint8_t* pwf = static_cast(wf_.data_ptr()); + int numSamples = static_cast(wf_.sizes()[1]); // per channel + int numEncodedSamples = 0; // per channel + int numSamplesPerFrame = avCodecContext_->frame_size; // per channel + int numBytesPerSample = 
static_cast(wf_.element_size()); + int numBytesPerChannel = numSamples * numBytesPerSample; + + status = avformat_write_header(avFormatContext_.get(), nullptr); + TORCH_CHECK( + status == AVSUCCESS, + "Error in avformat_write_header: ", + getFFMPEGErrorStringFromErrorCode(status)); + + while (numEncodedSamples < numSamples) { + status = av_frame_make_writable(avFrame.get()); + TORCH_CHECK( + status == AVSUCCESS, + "Couldn't make AVFrame writable: ", + getFFMPEGErrorStringFromErrorCode(status)); + + int numSamplesToEncode = + std::min(numSamplesPerFrame, numSamples - numEncodedSamples); + int numBytesToEncode = numSamplesToEncode * numBytesPerSample; + + for (int ch = 0; ch < wf_.sizes()[0]; ch++) { + std::memcpy( + avFrame->data[ch], pwf + ch * numBytesPerChannel, numBytesToEncode); + } + pwf += numBytesToEncode; + + // Above, we set the AVFrame's .nb_samples to AVCodecContext.frame_size so + // that the frame buffers are allocated to a big enough size. Here, we reset + // it to the exact number of samples that need to be encoded, otherwise the + // encoded frame would contain more samples than necessary and our results + // wouldn't match the ffmpeg CLI. 
+ avFrame->nb_samples = numSamplesToEncode; + encodeInnerLoop(autoAVPacket, avFrame); + + avFrame->pts += static_cast(numSamplesToEncode); + numEncodedSamples += numSamplesToEncode; + } + TORCH_CHECK(numEncodedSamples == numSamples, "Hmmmmmm something went wrong."); + + flushBuffers(); + + status = av_write_trailer(avFormatContext_.get()); + TORCH_CHECK( + status == AVSUCCESS, + "Error in: av_write_trailer", + getFFMPEGErrorStringFromErrorCode(status)); +} + +void AudioEncoder::encodeInnerLoop( + AutoAVPacket& autoAVPacket, + const UniqueAVFrame& avFrame) { + auto status = avcodec_send_frame(avCodecContext_.get(), avFrame.get()); + TORCH_CHECK( + status == AVSUCCESS, + "Error while sending frame: ", + getFFMPEGErrorStringFromErrorCode(status)); + + while (status >= 0) { + ReferenceAVPacket packet(autoAVPacket); + status = avcodec_receive_packet(avCodecContext_.get(), packet.get()); + if (status == AVERROR(EAGAIN) || status == AVERROR_EOF) { + // TODO-ENCODING this is from TorchAudio, probably needed, but not sure. 
+ // if (status == AVERROR_EOF) { + // status = av_interleaved_write_frame(avFormatContext_.get(), + // nullptr); TORCH_CHECK( + // status == AVSUCCESS, + // "Failed to flush packet ", + // getFFMPEGErrorStringFromErrorCode(status)); + // } + return; + } + TORCH_CHECK( + status >= 0, + "Error receiving packet: ", + getFFMPEGErrorStringFromErrorCode(status)); + + packet->stream_index = streamIndex_; + + status = av_interleaved_write_frame(avFormatContext_.get(), packet.get()); + TORCH_CHECK( + status == AVSUCCESS, + "Error in av_interleaved_write_frame: ", + getFFMPEGErrorStringFromErrorCode(status)); + } +} + +void AudioEncoder::flushBuffers() { + AutoAVPacket autoAVPacket; + encodeInnerLoop(autoAVPacket, UniqueAVFrame(nullptr)); +} +} // namespace facebook::torchcodec diff --git a/src/torchcodec/_core/Encoder.h b/src/torchcodec/_core/Encoder.h new file mode 100644 index 00000000..f0621fe5 --- /dev/null +++ b/src/torchcodec/_core/Encoder.h @@ -0,0 +1,36 @@ +#pragma once +#include +#include "src/torchcodec/_core/FFMPEGCommon.h" + +namespace facebook::torchcodec { +class AudioEncoder { + public: + ~AudioEncoder(); + + AudioEncoder( + const torch::Tensor wf, + int sampleRate, + std::string_view fileName); + void encode(); + + private: + void encodeInnerLoop( + AutoAVPacket& autoAVPacket, + const UniqueAVFrame& avFrame); + void flushBuffers(); + + UniqueEncodingAVFormatContext avFormatContext_; + UniqueAVCodecContext avCodecContext_; + int streamIndex_; + + const torch::Tensor wf_; + // The *output* sample rate. We can't really decide for the user what it + // should be. Particularly, the sample rate of the input waveform should match + // this, and that's up to the user. If sample rates don't match, encoding will + // still work but audio will be distorted. + // We technically could let the user also specify the input sample rate, and + // resample the waveform internally to match them, but that's not in scope for + // an initial version (if at all). 
+ int sampleRate_; +}; +} // namespace facebook::torchcodec diff --git a/src/torchcodec/_core/FFMPEGCommon.cpp b/src/torchcodec/_core/FFMPEGCommon.cpp index 33c8b484..64e4da70 100644 --- a/src/torchcodec/_core/FFMPEGCommon.cpp +++ b/src/torchcodec/_core/FFMPEGCommon.cpp @@ -74,6 +74,38 @@ int getNumChannels(const UniqueAVCodecContext& avCodecContext) { #endif } +void setDefaultChannelLayout( + UniqueAVCodecContext& avCodecContext, + int numChannels) { +#if LIBAVFILTER_VERSION_MAJOR > 7 // FFmpeg > 4 + AVChannelLayout channel_layout; + av_channel_layout_default(&channel_layout, numChannels); + avCodecContext->ch_layout = channel_layout; + +#else + uint64_t channel_layout = av_get_default_channel_layout(numChannels); + avCodecContext->channel_layout = channel_layout; + avCodecContext->channels = numChannels; +#endif +} + +void setChannelLayout( + UniqueAVFrame& dstAVFrame, + const UniqueAVCodecContext& avCodecContext) { +#if LIBAVFILTER_VERSION_MAJOR > 7 // FFmpeg > 4 + auto status = av_channel_layout_copy( + &dstAVFrame->ch_layout, &avCodecContext->ch_layout); + TORCH_CHECK( + status == AVSUCCESS, + "Couldn't copy channel layout to avFrame: ", + getFFMPEGErrorStringFromErrorCode(status)); +#else + dstAVFrame->channel_layout = avCodecContext->channel_layout; + dstAVFrame->channels = avCodecContext->channels; + +#endif +} + void setChannelLayout( UniqueAVFrame& dstAVFrame, const UniqueAVFrame& srcAVFrame) { diff --git a/src/torchcodec/_core/FFMPEGCommon.h b/src/torchcodec/_core/FFMPEGCommon.h index 0309bf93..fdb30962 100644 --- a/src/torchcodec/_core/FFMPEGCommon.h +++ b/src/torchcodec/_core/FFMPEGCommon.h @@ -50,9 +50,12 @@ struct Deleter { }; // Unique pointers for FFMPEG structures. 
-using UniqueAVFormatContext = std::unique_ptr< +using UniqueDecodingAVFormatContext = std::unique_ptr< AVFormatContext, Deleterp>; +using UniqueEncodingAVFormatContext = std::unique_ptr< + AVFormatContext, + Deleter>; using UniqueAVCodecContext = std::unique_ptr< AVCodecContext, Deleterp>; @@ -144,6 +147,14 @@ int64_t getDuration(const UniqueAVFrame& frame); int getNumChannels(const UniqueAVFrame& avFrame); int getNumChannels(const UniqueAVCodecContext& avCodecContext); +void setDefaultChannelLayout( + UniqueAVCodecContext& avCodecContext, + int numChannels); + +void setChannelLayout( + UniqueAVFrame& dstAVFrame, + const UniqueAVCodecContext& avCodecContext); + void setChannelLayout( UniqueAVFrame& dstAVFrame, const UniqueAVFrame& srcAVFrame); diff --git a/src/torchcodec/_core/SingleStreamDecoder.cpp b/src/torchcodec/_core/SingleStreamDecoder.cpp index efd93498..b7438f19 100644 --- a/src/torchcodec/_core/SingleStreamDecoder.cpp +++ b/src/torchcodec/_core/SingleStreamDecoder.cpp @@ -1443,7 +1443,7 @@ void SingleStreamDecoder::convertAudioAVFrameToFrameOutputOnCPU( auto numBytesPerChannel = numSamples * av_get_bytes_per_sample(format); for (auto channel = 0; channel < numChannels; ++channel, outputChannelData += numBytesPerChannel) { - memcpy( + std::memcpy( outputChannelData, avFrame->extended_data[channel], numBytesPerChannel); diff --git a/src/torchcodec/_core/SingleStreamDecoder.h b/src/torchcodec/_core/SingleStreamDecoder.h index df333d03..f712cdbb 100644 --- a/src/torchcodec/_core/SingleStreamDecoder.h +++ b/src/torchcodec/_core/SingleStreamDecoder.h @@ -492,7 +492,7 @@ class SingleStreamDecoder { SeekMode seekMode_; ContainerMetadata containerMetadata_; - UniqueAVFormatContext formatContext_; + UniqueDecodingAVFormatContext formatContext_; std::map streamInfos_; const int NO_ACTIVE_STREAM = -2; int activeStreamIndex_ = NO_ACTIVE_STREAM; diff --git a/src/torchcodec/_core/__init__.py b/src/torchcodec/_core/__init__.py index 9de779f6..4be8a7de 100644 --- 
a/src/torchcodec/_core/__init__.py +++ b/src/torchcodec/_core/__init__.py @@ -18,10 +18,12 @@ _test_frame_pts_equality, add_audio_stream, add_video_stream, + create_audio_encoder, create_from_bytes, create_from_file, create_from_file_like, create_from_tensor, + encode_audio, get_ffmpeg_library_versions, get_frame_at_index, get_frame_at_pts, diff --git a/src/torchcodec/_core/custom_ops.cpp b/src/torchcodec/_core/custom_ops.cpp index 45324908..596412a8 100644 --- a/src/torchcodec/_core/custom_ops.cpp +++ b/src/torchcodec/_core/custom_ops.cpp @@ -11,6 +11,7 @@ #include "c10/core/SymIntArrayRef.h" #include "c10/util/Exception.h" #include "src/torchcodec/_core/AVIOBytesContext.h" +#include "src/torchcodec/_core/Encoder.h" #include "src/torchcodec/_core/SingleStreamDecoder.h" namespace facebook::torchcodec { @@ -27,6 +28,9 @@ TORCH_LIBRARY(torchcodec_ns, m) { m.impl_abstract_pystub( "torchcodec._core.ops", "//pytorch/torchcodec:torchcodec"); m.def("create_from_file(str filename, str? seek_mode=None) -> Tensor"); + m.def( + "create_audio_encoder(Tensor wf, int sample_rate, str filename) -> Tensor"); + m.def("encode_audio(Tensor(a!) encoder) -> ()"); m.def( "create_from_tensor(Tensor video_tensor, str? 
seek_mode=None) -> Tensor"); m.def("_convert_to_tensor(int decoder_ptr) -> Tensor"); @@ -381,6 +385,45 @@ OpsAudioFramesOutput get_frames_by_pts_in_range_audio( return makeOpsAudioFramesOutput(result); } +at::Tensor wrapAudioEncoderPointerToTensor( + std::unique_ptr uniqueAudioEncoder) { + AudioEncoder* encoder = uniqueAudioEncoder.release(); + + auto deleter = [encoder](void*) { delete encoder; }; + at::Tensor tensor = + at::from_blob(encoder, {sizeof(AudioEncoder*)}, deleter, {at::kLong}); + auto encoder_ = static_cast(tensor.mutable_data_ptr()); + TORCH_CHECK_EQ(encoder_, encoder) << "AudioEncoder=" << encoder_; + return tensor; +} + +AudioEncoder* unwrapTensorToGetAudioEncoder(at::Tensor& tensor) { + TORCH_INTERNAL_ASSERT(tensor.is_contiguous()); + void* buffer = tensor.mutable_data_ptr(); + AudioEncoder* encoder = static_cast(buffer); + return encoder; +} + +at::Tensor create_audio_encoder( + const at::Tensor wf, + int64_t sample_rate, + std::string_view file_name) { + TORCH_CHECK( + sample_rate <= std::numeric_limits::max(), + "sample_rate=", + sample_rate, + " is too large to be cast to an int."); + std::unique_ptr uniqueAudioEncoder = + std::make_unique( + wf, static_cast(sample_rate), file_name); + return wrapAudioEncoderPointerToTensor(std::move(uniqueAudioEncoder)); +} + +void encode_audio(at::Tensor& encoder) { + auto encoder_ = unwrapTensorToGetAudioEncoder(encoder); + encoder_->encode(); +} + // For testing only. We need to implement this operation as a core library // function because what we're testing is round-tripping pts values as // double-precision floating point numbers from C++ to Python and back to C++. 
@@ -615,6 +658,7 @@ void scan_all_streams_to_update_metadata(at::Tensor& decoder) { TORCH_LIBRARY_IMPL(torchcodec_ns, BackendSelect, m) { m.impl("create_from_file", &create_from_file); + m.impl("create_audio_encoder", &create_audio_encoder); m.impl("create_from_tensor", &create_from_tensor); m.impl("_convert_to_tensor", &_convert_to_tensor); m.impl( @@ -622,6 +666,7 @@ TORCH_LIBRARY_IMPL(torchcodec_ns, BackendSelect, m) { } TORCH_LIBRARY_IMPL(torchcodec_ns, CPU, m) { + m.impl("encode_audio", &encode_audio); m.impl("seek_to_pts", &seek_to_pts); m.impl("add_video_stream", &add_video_stream); m.impl("_add_video_stream", &_add_video_stream); diff --git a/src/torchcodec/_core/ops.py b/src/torchcodec/_core/ops.py index 0f0bdfe2..d910fcad 100644 --- a/src/torchcodec/_core/ops.py +++ b/src/torchcodec/_core/ops.py @@ -91,6 +91,12 @@ def load_torchcodec_shared_libraries(): create_from_file = torch._dynamo.disallow_in_graph( torch.ops.torchcodec_ns.create_from_file.default ) +create_audio_encoder = torch._dynamo.disallow_in_graph( + torch.ops.torchcodec_ns.create_audio_encoder.default +) +encode_audio = torch._dynamo.disallow_in_graph( + torch.ops.torchcodec_ns.encode_audio.default +) create_from_tensor = torch._dynamo.disallow_in_graph( torch.ops.torchcodec_ns.create_from_tensor.default ) @@ -155,6 +161,18 @@ def create_from_file_abstract(filename: str, seek_mode: Optional[str]) -> torch. 
@register_fake("torchcodec_ns::create_audio_encoder")
def create_audio_encoder_abstract(
    wf: torch.Tensor, sample_rate: int, filename: str
) -> torch.Tensor:
    # The real op returns a long tensor wrapping the encoder pointer.
    return torch.empty([], dtype=torch.long)


@register_fake("torchcodec_ns::encode_audio")
def encode_audio_abstract(encoder: torch.Tensor) -> None:
    # The op is registered as "encode_audio(Tensor(a!) encoder) -> ()": it
    # returns nothing, so the fake impl must return None. Returning a tensor
    # here would contradict the declared schema.
    return None
valid_output_file = str(tmp_path / ".mp3") + + with pytest.raises(RuntimeError, match="must have float32 dtype, got int"): + create_audio_encoder( + wf=torch.arange(10, dtype=torch.int), + sample_rate=10, + filename=valid_output_file, + ) + with pytest.raises(RuntimeError, match="must have 2 dimensions, got 1"): + create_audio_encoder( + wf=torch.rand(3), sample_rate=10, filename=valid_output_file + ) + + with pytest.raises(RuntimeError, match="No such file or directory"): + create_audio_encoder( + wf=torch.rand(10, 10), sample_rate=10, filename="./bad/path.mp3" + ) + with pytest.raises(RuntimeError, match="Check the desired extension"): + create_audio_encoder( + wf=torch.rand(10, 10), sample_rate=10, filename="./file.bad_extension" + ) + + # TODO-ENCODING: raise more informative error message when sample rate + # isn't supported + with pytest.raises(RuntimeError, match="Invalid argument"): + create_audio_encoder( + wf=self.decode(NASA_AUDIO_MP3), + sample_rate=10, + filename=valid_output_file, + ) + + def test_round_trip(self, tmp_path): + # Check that decode(encode(samples)) == samples + asset = NASA_AUDIO_MP3 + source_samples = self.decode(asset) + + encoded_path = tmp_path / "output.mp3" + encoder = create_audio_encoder( + wf=source_samples, sample_rate=asset.sample_rate, filename=str(encoded_path) + ) + encode_audio(encoder) + + # TODO-ENCODING: tol should be stricter. We need to increase the encoded + # bitrate, and / or encode into a lossless format. 
+ torch.testing.assert_close( + self.decode(encoded_path), source_samples, rtol=0, atol=0.07 + ) + + # TODO-ENCODING: test more encoding formats + @pytest.mark.parametrize("asset", (NASA_AUDIO_MP3, SINE_MONO_S32)) + def test_against_cli(self, asset, tmp_path): + # Encodes samples with our encoder and with the FFmpeg CLI, and checks + # that both decoded outputs are equal + + encoded_by_ffmpeg = tmp_path / "ffmpeg_output.mp3" + encoded_by_us = tmp_path / "our_output.mp3" + + subprocess.run( + [ + "ffmpeg", + "-i", + str(asset.path), + "-b:a", + "0", # bitrate hardcoded to 0, see corresponding TODO. + str(encoded_by_ffmpeg), + ], + capture_output=True, + check=True, + ) + + encoder = create_audio_encoder( + wf=self.decode(asset), + sample_rate=asset.sample_rate, + filename=str(encoded_by_us), + ) + encode_audio(encoder) + + torch.testing.assert_close( + self.decode(encoded_by_ffmpeg), self.decode(encoded_by_us) + ) + + if __name__ == "__main__": pytest.main() diff --git a/test/utils.py b/test/utils.py index ea3e96c6..e7ce12e5 100644 --- a/test/utils.py +++ b/test/utils.py @@ -114,6 +114,8 @@ class TestAudioStreamInfo: @dataclass class TestContainerFile: + __test__ = False # prevents pytest from thinking this is a test class + filename: str default_stream_index: int