From d5fe9968a3f4546df3a2e7979fa7861d25ceecb3 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Wed, 26 Feb 2025 10:55:57 +0000 Subject: [PATCH 01/18] Super WIP encoder --- .../decoders/_core/VideoDecoder.cpp | 127 ++++++++++++++++++ src/torchcodec/decoders/_core/VideoDecoder.h | 20 +++ .../decoders/_core/VideoDecoderOps.cpp | 32 +++++ src/torchcodec/decoders/_core/__init__.py | 2 + .../decoders/_core/video_decoder_ops.py | 14 ++ 5 files changed, 195 insertions(+) diff --git a/src/torchcodec/decoders/_core/VideoDecoder.cpp b/src/torchcodec/decoders/_core/VideoDecoder.cpp index 116379eb..bf577008 100644 --- a/src/torchcodec/decoders/_core/VideoDecoder.cpp +++ b/src/torchcodec/decoders/_core/VideoDecoder.cpp @@ -1695,4 +1695,131 @@ FrameDims getHeightAndWidthFromOptionsOrAVFrame( videoStreamOptions.width.value_or(avFrame.width)); } +Encoder::~Encoder() { + fclose(f_); +} + +Encoder::Encoder(torch::Tensor& wf) : wf_(wf) { + f_ = fopen("./coutput", "wb"); + TORCH_CHECK(f_, "Could not open file"); + const AVCodec* avCodec = avcodec_find_encoder(AV_CODEC_ID_MP3); + TORCH_CHECK(avCodec != nullptr, "Codec not found"); + + AVCodecContext* avCodecContext = avcodec_alloc_context3(avCodec); + TORCH_CHECK(avCodecContext != nullptr, "Couldn't allocate codec context."); + avCodecContext_.reset(avCodecContext); + + avCodecContext_->bit_rate = 0; // TODO + avCodecContext_->sample_fmt = AV_SAMPLE_FMT_FLTP; // TODO + avCodecContext_->sample_rate = 16000; // TODO + AVChannelLayout channel_layout; + av_channel_layout_default(&channel_layout, 2); + avCodecContext_->ch_layout = channel_layout; + + auto ffmpegRet = avcodec_open2(avCodecContext_.get(), avCodec, nullptr); + TORCH_CHECK( + ffmpegRet == AVSUCCESS, getFFMPEGErrorStringFromErrorCode(ffmpegRet)); + + AVFrame* avFrame = av_frame_alloc(); + TORCH_CHECK(avFrame != nullptr, "Couldn't allocate AVFrame."); + avFrame_.reset(avFrame); + avFrame_->nb_samples = avCodecContext_->frame_size; + avFrame_->format = avCodecContext_->sample_fmt; + avFrame_->sample_rate = avCodecContext_->sample_rate; + + ffmpegRet = + av_channel_layout_copy(&avFrame_->ch_layout, &avCodecContext_->ch_layout); + TORCH_CHECK( + ffmpegRet == AVSUCCESS, + "Couldn't copy channel layout to avFrame: ", + getFFMPEGErrorStringFromErrorCode(ffmpegRet)); + ffmpegRet = av_frame_get_buffer(avFrame_.get(), 0); + TORCH_CHECK( + ffmpegRet == AVSUCCESS, + "Couldn't allocate avFrame's buffers: ", + getFFMPEGErrorStringFromErrorCode(ffmpegRet)); +} + +torch::Tensor Encoder::encode() { + AVPacket* pkt = av_packet_alloc(); + if (!pkt) { + fprintf(stderr, "Could not allocate audio packet\n"); + exit(1); + } + + auto MAX_NUM_BYTES = 10000000; // 10Mb. TODO find a way not to pre-allocate. + int numEncodedBytes = 0; + torch::Tensor outputTensor = torch::empty({MAX_NUM_BYTES}, torch::kUInt8); + uint8_t* pOutputTensor = + static_cast(outputTensor.data_ptr()); + + uint8_t* pWf = static_cast(wf_.data_ptr()); + auto numBytesWeWroteFromWF = 0; + auto numBytesPerSample = wf_.element_size(); + auto numBytesPerChannel = wf_.sizes()[1] * numBytesPerSample; + + // TODO need simpler/cleaner while loop condition. + while (numBytesWeWroteFromWF < numBytesPerChannel) { + auto ffmpegRet = av_frame_make_writable(avFrame_.get()); + TORCH_CHECK( + ffmpegRet == AVSUCCESS, + "Couldn't make AVFrame writable: ", + getFFMPEGErrorStringFromErrorCode(ffmpegRet)); + + auto numBytesToWrite = numBytesPerSample * avCodecContext_->frame_size; + if (numBytesWeWroteFromWF + numBytesToWrite > numBytesPerChannel) { + numBytesToWrite = numBytesPerChannel - numBytesWeWroteFromWF; + } + for (int ch = 0; ch < 2; ch++) { + memcpy( + avFrame_->data[ch], pWf + ch * numBytesPerChannel, numBytesToWrite); + } + pWf += numBytesToWrite; + numBytesWeWroteFromWF += numBytesToWrite; + encode_inner_loop(pkt, pOutputTensor, &numEncodedBytes, false); + } + encode_inner_loop(pkt, pOutputTensor, &numEncodedBytes, true); + + return outputTensor.narrow( + /*dim=*/0, /*start=*/0, /*length=*/numEncodedBytes); + // return outputTensor; +} + +void Encoder::encode_inner_loop( + AVPacket* pkt, + uint8_t* pOutputTensor, + int* numEncodedBytes, + bool flush) { + int ffmpegRet = 0; + + // TODO ewwww + if (flush) { + ffmpegRet = avcodec_send_frame(avCodecContext_.get(), nullptr); + } else { + ffmpegRet = avcodec_send_frame(avCodecContext_.get(), avFrame_.get()); + } + TORCH_CHECK( + ffmpegRet == AVSUCCESS, + "Error while sending frame: ", + getFFMPEGErrorStringFromErrorCode(ffmpegRet)); + + while ((ffmpegRet = avcodec_receive_packet(avCodecContext_.get(), pkt)) >= + 0) { + if (ffmpegRet == AVERROR(EAGAIN) || ffmpegRet == AVERROR_EOF) { + return; + } + TORCH_CHECK( + ffmpegRet >= 0, + "Error receiving packet: ", + getFFMPEGErrorStringFromErrorCode(ffmpegRet)); + + fwrite(pkt->data, 1, pkt->size, f_); + + memcpy(pOutputTensor + *numEncodedBytes, pkt->data, pkt->size); + *numEncodedBytes += pkt->size; + + av_packet_unref(pkt); + } +} + } // namespace facebook::torchcodec diff --git a/src/torchcodec/decoders/_core/VideoDecoder.h b/src/torchcodec/decoders/_core/VideoDecoder.h index e7197385..5bc97c13 100644 --- a/src/torchcodec/decoders/_core/VideoDecoder.h +++ b/src/torchcodec/decoders/_core/VideoDecoder.h @@ -545,4 +545,24 @@ std::ostream& operator<<( std::ostream& os, const VideoDecoder::DecodeStats& stats); +class Encoder { + public: + ~Encoder(); + + explicit Encoder(torch::Tensor& wf); + torch::Tensor encode(); + + private: + void encode_inner_loop( + AVPacket* pkt, + uint8_t* pOutputTensor, + int* numEncodedBytes, + bool flush); + + torch::Tensor wf_; + UniqueAVCodecContext avCodecContext_; + UniqueAVFrame avFrame_; + FILE* f_; +}; + } // namespace facebook::torchcodec diff --git a/src/torchcodec/decoders/_core/VideoDecoderOps.cpp b/src/torchcodec/decoders/_core/VideoDecoderOps.cpp index d0fefcd9..e87a81d2 100644 --- a/src/torchcodec/decoders/_core/VideoDecoderOps.cpp +++ b/src/torchcodec/decoders/_core/VideoDecoderOps.cpp @@ -28,6 +28,8 @@ TORCH_LIBRARY(torchcodec_ns, m) { "torchcodec.decoders._core.video_decoder_ops", "//pytorch/torchcodec:torchcodec"); m.def("create_from_file(str filename, str? seek_mode=None) -> Tensor"); + m.def("create_encoder(Tensor wf) -> Tensor"); + m.def("encode(Tensor(a!) encoder) -> Tensor"); m.def( "create_from_tensor(Tensor video_tensor, str? seek_mode=None) -> Tensor"); m.def( @@ -72,6 +74,17 @@ at::Tensor wrapDecoderPointerToTensor( return tensor; } +at::Tensor wrapEncoderPointerToTensor(std::unique_ptr uniqueEncoder) { + Encoder* encoder = uniqueEncoder.release(); + + auto deleter = [encoder](void*) { delete encoder; }; + at::Tensor tensor = + at::from_blob(encoder, {sizeof(Encoder)}, deleter, {at::kLong}); + auto encoder_ = static_cast(tensor.mutable_data_ptr()); + TORCH_CHECK_EQ(encoder_, encoder) << "Encoder=" << encoder_; + return tensor; +} + VideoDecoder* unwrapTensorToGetDecoder(at::Tensor& tensor) { TORCH_INTERNAL_ASSERT(tensor.is_contiguous()); void* buffer = tensor.mutable_data_ptr(); @@ -79,6 +92,13 @@ VideoDecoder* unwrapTensorToGetDecoder(at::Tensor& tensor) { return decoder; } +Encoder* unwrapTensorToGetEncoder(at::Tensor& tensor) { + TORCH_INTERNAL_ASSERT(tensor.is_contiguous()); + void* buffer = tensor.mutable_data_ptr(); + Encoder* encoder = static_cast(buffer); + return encoder; +} + OpsFrameOutput makeOpsFrameOutput(VideoDecoder::FrameOutput& frame) { return std::make_tuple( frame.data, @@ -123,6 +143,16 @@ at::Tensor create_from_file( return wrapDecoderPointerToTensor(std::move(uniqueDecoder)); } +at::Tensor create_encoder(torch::Tensor& wf) { + std::unique_ptr uniqueEncoder = std::make_unique(wf); + return wrapEncoderPointerToTensor(std::move(uniqueEncoder)); +} + +at::Tensor encode(at::Tensor& encoder) { + auto encoder_ = unwrapTensorToGetEncoder(encoder); + return encoder_->encode(); +} + at::Tensor create_from_tensor( at::Tensor video_tensor, std::optional seek_mode) { @@ -512,12 +542,14 @@ void scan_all_streams_to_update_metadata(at::Tensor& decoder) { TORCH_LIBRARY_IMPL(torchcodec_ns, BackendSelect, m) { m.impl("create_from_file", &create_from_file); + m.impl("create_encoder", &create_encoder); m.impl("create_from_tensor", &create_from_tensor); m.impl( "_get_json_ffmpeg_library_versions", &_get_json_ffmpeg_library_versions); } TORCH_LIBRARY_IMPL(torchcodec_ns, CPU, m) { + m.impl("encode", &encode); m.impl("seek_to_pts", &seek_to_pts); m.impl("add_video_stream", &add_video_stream); m.impl("_add_video_stream", &_add_video_stream); diff --git a/src/torchcodec/decoders/_core/__init__.py b/src/torchcodec/decoders/_core/__init__.py index d39d3d23..e774a43c 100644 --- a/src/torchcodec/decoders/_core/__init__.py +++ b/src/torchcodec/decoders/_core/__init__.py @@ -16,9 +16,11 @@ _get_key_frame_indices, _test_frame_pts_equality, add_video_stream, + create_encoder, create_from_bytes, create_from_file, create_from_tensor, + encode, get_ffmpeg_library_versions, get_frame_at_index, get_frame_at_pts, diff --git a/src/torchcodec/decoders/_core/video_decoder_ops.py b/src/torchcodec/decoders/_core/video_decoder_ops.py index 40216304..cba7492f 100644 --- a/src/torchcodec/decoders/_core/video_decoder_ops.py +++ b/src/torchcodec/decoders/_core/video_decoder_ops.py @@ -64,6 +64,10 @@ def load_torchcodec_extension(): create_from_file = torch._dynamo.disallow_in_graph( torch.ops.torchcodec_ns.create_from_file.default ) +create_encoder = torch._dynamo.disallow_in_graph( + torch.ops.torchcodec_ns.create_encoder.default +) +encode = torch._dynamo.disallow_in_graph(torch.ops.torchcodec_ns.encode.default) create_from_tensor = torch._dynamo.disallow_in_graph( torch.ops.torchcodec_ns.create_from_tensor.default ) @@ -114,6 +118,16 @@ def create_from_file_abstract(filename: str, seek_mode: Optional[str]) -> torch. return torch.empty([], dtype=torch.long) +@register_fake("torchcodec_ns::create_encoder") +def create_encoder_abstract(wf: torch.Tensor) -> torch.Tensor: + return torch.empty([], dtype=torch.long) + + +@register_fake("torchcodec_ns::encode") +def encode_abstract(encoder: torch.Tensor) -> torch.Tensor: + return torch.empty([], dtype=torch.long) + + @register_fake("torchcodec_ns::create_from_tensor") def create_from_tensor_abstract( video_tensor: torch.Tensor, seek_mode: Optional[str] From 779f19ee812d02420d3d272019b1c1371cf99e78 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Thu, 27 Feb 2025 18:25:24 +0000 Subject: [PATCH 02/18] Write output file through AVFormatContext --- src/torchcodec/decoders/_core/CMakeLists.txt | 3 +- .../decoders/_core/VideoDecoder.cpp | 116 ++++++++++++++++-- src/torchcodec/decoders/_core/VideoDecoder.h | 20 ++- .../decoders/_core/VideoDecoderOps.cpp | 13 +- .../decoders/_core/video_decoder_ops.py | 4 +- 5 files changed, 131 insertions(+), 25 deletions(-) diff --git a/src/torchcodec/decoders/_core/CMakeLists.txt b/src/torchcodec/decoders/_core/CMakeLists.txt index 688a249d..cefa11e3 100644 --- a/src/torchcodec/decoders/_core/CMakeLists.txt +++ b/src/torchcodec/decoders/_core/CMakeLists.txt @@ -4,7 +4,8 @@ set(CMAKE_CXX_STANDARD 17) set(CMAKE_CXX_STANDARD_REQUIRED ON) find_package(Torch REQUIRED) -set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wall -Wextra -pedantic -Werror ${TORCH_CXX_FLAGS}") +# set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wall -Wextra -pedantic -Werror ${TORCH_CXX_FLAGS}") +set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wall ${TORCH_CXX_FLAGS}") find_package(Python3 ${PYTHON_VERSION} EXACT COMPONENTS Development) function(make_torchcodec_library library_name ffmpeg_target) diff --git a/src/torchcodec/decoders/_core/VideoDecoder.cpp b/src/torchcodec/decoders/_core/VideoDecoder.cpp index bf577008..5ff1b3be 100644 --- a/src/torchcodec/decoders/_core/VideoDecoder.cpp +++ b/src/torchcodec/decoders/_core/VideoDecoder.cpp @@ -1697,42 +1697,87 @@ FrameDims getHeightAndWidthFromOptionsOrAVFrame( Encoder::~Encoder() { fclose(f_); + // TODO NEED TO CALL THIS + // avformat_free_context(avFormatContext_.get()); } -Encoder::Encoder(torch::Tensor& wf) : wf_(wf) { +Encoder::Encoder(int sampleRate, std::string_view fileName) + : sampleRate_(sampleRate) { f_ = fopen("./coutput", "wb"); TORCH_CHECK(f_, "Could not open file"); - const AVCodec* avCodec = avcodec_find_encoder(AV_CODEC_ID_MP3); + + AVFormatContext* avFormatContext = nullptr; + avformat_alloc_output_context2(&avFormatContext, NULL, NULL, fileName.data()); + TORCH_CHECK(avFormatContext != nullptr, "Couldn't allocate AVFormatContext."); + avFormatContext_.reset(avFormatContext); + + TORCH_CHECK( + !(avFormatContext->oformat->flags & AVFMT_NOFILE), + "AVFMT_NOFILE is set. We only support writing to a file."); + auto ffmpegRet = + avio_open(&avFormatContext_->pb, fileName.data(), AVIO_FLAG_WRITE); + TORCH_CHECK( + ffmpegRet >= 0, + "avio_open failed: ", + getFFMPEGErrorStringFromErrorCode(ffmpegRet)); + + // We use the AVFormatContext's default codec for that + // specificavcodec_parameters_from_context format/container. + const AVCodec* avCodec = + avcodec_find_encoder(avFormatContext_->oformat->audio_codec); TORCH_CHECK(avCodec != nullptr, "Codec not found"); AVCodecContext* avCodecContext = avcodec_alloc_context3(avCodec); TORCH_CHECK(avCodecContext != nullptr, "Couldn't allocate codec context."); avCodecContext_.reset(avCodecContext); - avCodecContext_->bit_rate = 0; // TODO - avCodecContext_->sample_fmt = AV_SAMPLE_FMT_FLTP; // TODO - avCodecContext_->sample_rate = 16000; // TODO + // I think this will use the default. TODO Should let user choose for + // compressed formats like mp3. + avCodecContext_->bit_rate = 0; + + // TODO A given encoder only supports a finite set of output sample rates. + // FFmpeg raises informative error message. Are we happy with that, or do we + // run our own checks by checking against avCodec->supported_samplerates? + avCodecContext_->sample_rate = sampleRate_; + + // Note: This is the format of the **input** waveform. This doesn't determine + // the output. TODO check contiguity of the input wf to ensure that it is + // indeed planar. + // TODO What if the encoder doesn't support FLTP? Like flac? + avCodecContext_->sample_fmt = AV_SAMPLE_FMT_FLTP; + AVChannelLayout channel_layout; av_channel_layout_default(&channel_layout, 2); avCodecContext_->ch_layout = channel_layout; - auto ffmpegRet = avcodec_open2(avCodecContext_.get(), avCodec, nullptr); + ffmpegRet = avcodec_open2(avCodecContext_.get(), avCodec, nullptr); TORCH_CHECK( ffmpegRet == AVSUCCESS, getFFMPEGErrorStringFromErrorCode(ffmpegRet)); + TORCH_CHECK( + avCodecContext_->frame_size > 0, + "frame_size is ", + avCodecContext_->frame_size, + ". Cannot encode. This should probably never happen?"); + + avStream_ = avformat_new_stream(avFormatContext_.get(), NULL); + TORCH_CHECK(avStream_ != nullptr, "Couldn't create new stream."); + avcodec_parameters_from_context(avStream_->codecpar, avCodecContext_.get()); + AVFrame* avFrame = av_frame_alloc(); TORCH_CHECK(avFrame != nullptr, "Couldn't allocate AVFrame."); avFrame_.reset(avFrame); avFrame_->nb_samples = avCodecContext_->frame_size; avFrame_->format = avCodecContext_->sample_fmt; avFrame_->sample_rate = avCodecContext_->sample_rate; - + avFrame_->pts = 0; ffmpegRet = av_channel_layout_copy(&avFrame_->ch_layout, &avCodecContext_->ch_layout); TORCH_CHECK( ffmpegRet == AVSUCCESS, "Couldn't copy channel layout to avFrame: ", getFFMPEGErrorStringFromErrorCode(ffmpegRet)); + ffmpegRet = av_frame_get_buffer(avFrame_.get(), 0); TORCH_CHECK( ffmpegRet == AVSUCCESS, @@ -1740,7 +1785,7 @@ Encoder::Encoder(torch::Tensor& wf) : wf_(wf) { getFFMPEGErrorStringFromErrorCode(ffmpegRet)); } -torch::Tensor Encoder::encode() { +torch::Tensor Encoder::encode(const torch::Tensor& wf) { AVPacket* pkt = av_packet_alloc(); if (!pkt) { fprintf(stderr, "Could not allocate audio packet\n"); @@ -1753,14 +1798,31 @@ torch::Tensor Encoder::encode() { uint8_t* pOutputTensor = static_cast(outputTensor.data_ptr()); - uint8_t* pWf = static_cast(wf_.data_ptr()); + uint8_t* pWf = static_cast(wf.data_ptr()); auto numBytesWeWroteFromWF = 0; - auto numBytesPerSample = wf_.element_size(); - auto numBytesPerChannel = wf_.sizes()[1] * numBytesPerSample; + auto numBytesPerSample = wf.element_size(); + auto numBytesPerChannel = wf.sizes()[1] * numBytesPerSample; + auto numChannels = wf.sizes()[0]; + + TORCH_CHECK( + // TODO is this even true / needed? We can probably support more with + // non-planar data? + numChannels <= AV_NUM_DATA_POINTERS, + "Trying to encode ", + numChannels, + " channels, but FFmpeg only supports ", + AV_NUM_DATA_POINTERS, + " channels per frame."); + + auto ffmpegRet = avformat_write_header(avFormatContext_.get(), NULL); + TORCH_CHECK( + ffmpegRet == AVSUCCESS, + "Error in avformat_write_header: ", + getFFMPEGErrorStringFromErrorCode(ffmpegRet)); // TODO need simpler/cleaner while loop condition. while (numBytesWeWroteFromWF < numBytesPerChannel) { - auto ffmpegRet = av_frame_make_writable(avFrame_.get()); + ffmpegRet = av_frame_make_writable(avFrame_.get()); TORCH_CHECK( ffmpegRet == AVSUCCESS, "Couldn't make AVFrame writable: ", @@ -1770,16 +1832,24 @@ torch::Tensor Encoder::encode() { if (numBytesWeWroteFromWF + numBytesToWrite > numBytesPerChannel) { numBytesToWrite = numBytesPerChannel - numBytesWeWroteFromWF; } - for (int ch = 0; ch < 2; ch++) { + + for (int ch = 0; ch < numChannels; ch++) { memcpy( avFrame_->data[ch], pWf + ch * numBytesPerChannel, numBytesToWrite); } pWf += numBytesToWrite; numBytesWeWroteFromWF += numBytesToWrite; encode_inner_loop(pkt, pOutputTensor, &numEncodedBytes, false); + avFrame_->pts += avFrame_->nb_samples; } encode_inner_loop(pkt, pOutputTensor, &numEncodedBytes, true); + ffmpegRet = av_write_trailer(avFormatContext_.get()); + TORCH_CHECK( + ffmpegRet == AVSUCCESS, + "Error in : av_write_trailer", + getFFMPEGErrorStringFromErrorCode(ffmpegRet)); + return outputTensor.narrow( /*dim=*/0, /*start=*/0, /*length=*/numEncodedBytes); // return outputTensor; @@ -1806,6 +1876,14 @@ void Encoder::encode_inner_loop( while ((ffmpegRet = avcodec_receive_packet(avCodecContext_.get(), pkt)) >= 0) { if (ffmpegRet == AVERROR(EAGAIN) || ffmpegRet == AVERROR_EOF) { + // TODO this is from TorchAudio, probably needed, but not sure. + // if (ffmpegRet == AVERROR_EOF) { + // ffmpegRet = av_interleaved_write_frame(avFormatContext_.get(), + // nullptr); TORCH_CHECK( + // ffmpegRet == AVSUCCESS, + // "Failed to flush packet ", + // getFFMPEGErrorStringFromErrorCode(ffmpegRet)); + // } return; } TORCH_CHECK( @@ -1813,6 +1891,18 @@ void Encoder::encode_inner_loop( "Error receiving packet: ", getFFMPEGErrorStringFromErrorCode(ffmpegRet)); + // TODO why are these 2 lines needed?? + // av_packet_rescale_ts(pkt, avCodecContext_->time_base, + // avStream_->time_base); + pkt->stream_index = avStream_->index; + printf("PACKET PTS %d\n", pkt->pts); + + ffmpegRet = av_interleaved_write_frame(avFormatContext_.get(), pkt); + TORCH_CHECK( + ffmpegRet == AVSUCCESS, + "Error in av_interleaved_write_frame: ", + getFFMPEGErrorStringFromErrorCode(ffmpegRet)); + fwrite(pkt->data, 1, pkt->size, f_); memcpy(pOutputTensor + *numEncodedBytes, pkt->data, pkt->size); diff --git a/src/torchcodec/decoders/_core/VideoDecoder.h b/src/torchcodec/decoders/_core/VideoDecoder.h index 5bc97c13..70145ffe 100644 --- a/src/torchcodec/decoders/_core/VideoDecoder.h +++ b/src/torchcodec/decoders/_core/VideoDecoder.h @@ -549,8 +549,12 @@ class Encoder { public: ~Encoder(); - explicit Encoder(torch::Tensor& wf); - torch::Tensor encode(); + // TODO Are we OK passing a string_view to the constructor? + // TODO fileName should be optional. + // TODO doesn't make much sense to pass fileName and the wf tensor in 2 + // different calls. Same with sampleRate. + Encoder(int sampleRate, std::string_view fileName); + torch::Tensor encode(const torch::Tensor& wf); private: void encode_inner_loop( @@ -559,9 +563,19 @@ class Encoder { int* numEncodedBytes, bool flush); - torch::Tensor wf_; + UniqueAVFormatContext avFormatContext_; UniqueAVCodecContext avCodecContext_; UniqueAVFrame avFrame_; + AVStream* avStream_; + + // The *output* sample rate. We can't really decide for the user what it + // should be. Particularly, the sample rate of the input waveform should match + // this, and that's up to the user. If sample rates don't match, encoding will + // still work but audio will be distorted. + // We technically could let the user also specify the input sample rate, and + // resample the waveform internally to match them, but that's not in scope for + // an initial version (if at all). + int sampleRate_; FILE* f_; }; diff --git a/src/torchcodec/decoders/_core/VideoDecoderOps.cpp b/src/torchcodec/decoders/_core/VideoDecoderOps.cpp index e87a81d2..fee94fe3 100644 --- a/src/torchcodec/decoders/_core/VideoDecoderOps.cpp +++ b/src/torchcodec/decoders/_core/VideoDecoderOps.cpp @@ -28,8 +28,8 @@ TORCH_LIBRARY(torchcodec_ns, m) { "torchcodec.decoders._core.video_decoder_ops", "//pytorch/torchcodec:torchcodec"); m.def("create_from_file(str filename, str? seek_mode=None) -> Tensor"); - m.def("create_encoder(Tensor wf) -> Tensor"); - m.def("encode(Tensor(a!) encoder) -> Tensor"); + m.def("create_encoder(int sample_rate, str filename) -> Tensor"); + m.def("encode(Tensor(a!) encoder, Tensor wf) -> Tensor"); m.def( "create_from_tensor(Tensor video_tensor, str? seek_mode=None) -> Tensor"); m.def( @@ -143,14 +143,15 @@ at::Tensor create_from_file( return wrapDecoderPointerToTensor(std::move(uniqueDecoder)); } -at::Tensor create_encoder(torch::Tensor& wf) { - std::unique_ptr uniqueEncoder = std::make_unique(wf); +at::Tensor create_encoder(int64_t sample_rate, std::string_view file_name) { + std::unique_ptr uniqueEncoder = + std::make_unique(static_cast(sample_rate), file_name); return wrapEncoderPointerToTensor(std::move(uniqueEncoder)); } -at::Tensor encode(at::Tensor& encoder) { +at::Tensor encode(at::Tensor& encoder, const at::Tensor& wf) { auto encoder_ = unwrapTensorToGetEncoder(encoder); - return encoder_->encode(); + return encoder_->encode(wf); } at::Tensor create_from_tensor( diff --git a/src/torchcodec/decoders/_core/video_decoder_ops.py b/src/torchcodec/decoders/_core/video_decoder_ops.py index cba7492f..480afc80 100644 --- a/src/torchcodec/decoders/_core/video_decoder_ops.py +++ b/src/torchcodec/decoders/_core/video_decoder_ops.py @@ -119,12 +119,12 @@ def create_from_file_abstract(filename: str, seek_mode: Optional[str]) -> torch. @register_fake("torchcodec_ns::create_encoder") -def create_encoder_abstract(wf: torch.Tensor) -> torch.Tensor: +def create_encoder_abstract(sample_rate: int, filename: str) -> torch.Tensor: return torch.empty([], dtype=torch.long) @register_fake("torchcodec_ns::encode") -def encode_abstract(encoder: torch.Tensor) -> torch.Tensor: +def encode_abstract(encoder: torch.Tensor, wf: torch.Tensor) -> torch.Tensor: return torch.empty([], dtype=torch.long) From b110dacb9b781cf0a09e69749192839aaea29f5f Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Fri, 28 Feb 2025 10:24:53 +0000 Subject: [PATCH 03/18] Cleanup --- .../decoders/_core/VideoDecoder.cpp | 117 +++++++----------- src/torchcodec/decoders/_core/VideoDecoder.h | 8 +- 2 files changed, 45 insertions(+), 80 deletions(-) diff --git a/src/torchcodec/decoders/_core/VideoDecoder.cpp b/src/torchcodec/decoders/_core/VideoDecoder.cpp index 5ff1b3be..7f26c6fc 100644 --- a/src/torchcodec/decoders/_core/VideoDecoder.cpp +++ b/src/torchcodec/decoders/_core/VideoDecoder.cpp @@ -1696,16 +1696,12 @@ FrameDims getHeightAndWidthFromOptionsOrAVFrame( } Encoder::~Encoder() { - fclose(f_); // TODO NEED TO CALL THIS // avformat_free_context(avFormatContext_.get()); } Encoder::Encoder(int sampleRate, std::string_view fileName) : sampleRate_(sampleRate) { - f_ = fopen("./coutput", "wb"); - TORCH_CHECK(f_, "Could not open file"); - AVFormatContext* avFormatContext = nullptr; avformat_alloc_output_context2(&avFormatContext, NULL, NULL, fileName.data()); TORCH_CHECK(avFormatContext != nullptr, "Couldn't allocate AVFormatContext."); @@ -1763,46 +1759,38 @@ Encoder::Encoder(int sampleRate, std::string_view fileName) avStream_ = avformat_new_stream(avFormatContext_.get(), NULL); TORCH_CHECK(avStream_ != nullptr, "Couldn't create new stream."); avcodec_parameters_from_context(avStream_->codecpar, avCodecContext_.get()); +} - AVFrame* avFrame = av_frame_alloc(); +torch::Tensor Encoder::encode(const torch::Tensor& wf) { + UniqueAVFrame avFrame(av_frame_alloc()); TORCH_CHECK(avFrame != nullptr, "Couldn't allocate AVFrame."); - avFrame_.reset(avFrame); - avFrame_->nb_samples = avCodecContext_->frame_size; - avFrame_->format = avCodecContext_->sample_fmt; - avFrame_->sample_rate = avCodecContext_->sample_rate; - avFrame_->pts = 0; - ffmpegRet = - av_channel_layout_copy(&avFrame_->ch_layout, &avCodecContext_->ch_layout); + avFrame->nb_samples = avCodecContext_->frame_size; + avFrame->format = avCodecContext_->sample_fmt; + avFrame->sample_rate = avCodecContext_->sample_rate; + avFrame->pts = 0; + auto ffmpegRet = + av_channel_layout_copy(&avFrame->ch_layout, &avCodecContext_->ch_layout); TORCH_CHECK( ffmpegRet == AVSUCCESS, "Couldn't copy channel layout to avFrame: ", getFFMPEGErrorStringFromErrorCode(ffmpegRet)); - ffmpegRet = av_frame_get_buffer(avFrame_.get(), 0); + ffmpegRet = av_frame_get_buffer(avFrame.get(), 0); TORCH_CHECK( ffmpegRet == AVSUCCESS, "Couldn't allocate avFrame's buffers: ", getFFMPEGErrorStringFromErrorCode(ffmpegRet)); -} -torch::Tensor Encoder::encode(const torch::Tensor& wf) { - AVPacket* pkt = av_packet_alloc(); - if (!pkt) { - fprintf(stderr, "Could not allocate audio packet\n"); - exit(1); - } - - auto MAX_NUM_BYTES = 10000000; // 10Mb. TODO find a way not to pre-allocate. - int numEncodedBytes = 0; - torch::Tensor outputTensor = torch::empty({MAX_NUM_BYTES}, torch::kUInt8); - uint8_t* pOutputTensor = - static_cast(outputTensor.data_ptr()); + AutoAVPacket autoAVPacket; uint8_t* pWf = static_cast(wf.data_ptr()); - auto numBytesWeWroteFromWF = 0; + auto numChannels = wf.sizes()[0]; + auto numSamples = wf.sizes()[1]; // per channel + auto numEncodedSamples = 0; // per channel + auto numSamplesPerFrame = + static_cast(avCodecContext_->frame_size); // per channel auto numBytesPerSample = wf.element_size(); auto numBytesPerChannel = wf.sizes()[1] * numBytesPerSample; - auto numChannels = wf.sizes()[0]; TORCH_CHECK( // TODO is this even true / needed? We can probably support more with @@ -1814,67 +1802,57 @@ torch::Tensor Encoder::encode(const torch::Tensor& wf) { AV_NUM_DATA_POINTERS, " channels per frame."); - auto ffmpegRet = avformat_write_header(avFormatContext_.get(), NULL); + ffmpegRet = avformat_write_header(avFormatContext_.get(), NULL); TORCH_CHECK( ffmpegRet == AVSUCCESS, "Error in avformat_write_header: ", getFFMPEGErrorStringFromErrorCode(ffmpegRet)); - // TODO need simpler/cleaner while loop condition. - while (numBytesWeWroteFromWF < numBytesPerChannel) { - ffmpegRet = av_frame_make_writable(avFrame_.get()); + while (numEncodedSamples < numSamples) { + ffmpegRet = av_frame_make_writable(avFrame.get()); TORCH_CHECK( ffmpegRet == AVSUCCESS, "Couldn't make AVFrame writable: ", getFFMPEGErrorStringFromErrorCode(ffmpegRet)); - auto numBytesToWrite = numBytesPerSample * avCodecContext_->frame_size; - if (numBytesWeWroteFromWF + numBytesToWrite > numBytesPerChannel) { - numBytesToWrite = numBytesPerChannel - numBytesWeWroteFromWF; - } + auto numSamplesToEncode = + std::min(numSamplesPerFrame, numSamples - numEncodedSamples); + auto numBytesToEncode = numSamplesToEncode * numBytesPerSample; for (int ch = 0; ch < numChannels; ch++) { memcpy( - avFrame_->data[ch], pWf + ch * numBytesPerChannel, numBytesToWrite); + avFrame->data[ch], pWf + ch * numBytesPerChannel, numBytesToEncode); } - pWf += numBytesToWrite; - numBytesWeWroteFromWF += numBytesToWrite; - encode_inner_loop(pkt, pOutputTensor, &numEncodedBytes, false); - avFrame_->pts += avFrame_->nb_samples; + pWf += numBytesToEncode; + encode_inner_loop(autoAVPacket, avFrame.get()); + + avFrame->pts += avFrame->nb_samples; + numEncodedSamples += numSamplesToEncode; } - encode_inner_loop(pkt, pOutputTensor, &numEncodedBytes, true); + encode_inner_loop(autoAVPacket, nullptr); // flush + + TORCH_CHECK(numEncodedSamples == numSamples, "Hmmmmmm something went wrong."); ffmpegRet = av_write_trailer(avFormatContext_.get()); TORCH_CHECK( ffmpegRet == AVSUCCESS, - "Error in : av_write_trailer", + "Error in: av_write_trailer", getFFMPEGErrorStringFromErrorCode(ffmpegRet)); - return outputTensor.narrow( - /*dim=*/0, /*start=*/0, /*length=*/numEncodedBytes); - // return outputTensor; + // TODO handle writing to output uint8 tensor with AVIO logic. + return torch::empty({10}); } -void Encoder::encode_inner_loop( - AVPacket* pkt, - uint8_t* pOutputTensor, - int* numEncodedBytes, - bool flush) { - int ffmpegRet = 0; - - // TODO ewwww - if (flush) { - ffmpegRet = avcodec_send_frame(avCodecContext_.get(), nullptr); - } else { - ffmpegRet = avcodec_send_frame(avCodecContext_.get(), avFrame_.get()); - } +void Encoder::encode_inner_loop(AutoAVPacket& autoAVPacket, AVFrame* avFrame) { + auto ffmpegRet = avcodec_send_frame(avCodecContext_.get(), avFrame); TORCH_CHECK( ffmpegRet == AVSUCCESS, "Error while sending frame: ", getFFMPEGErrorStringFromErrorCode(ffmpegRet)); - while ((ffmpegRet = avcodec_receive_packet(avCodecContext_.get(), pkt)) >= - 0) { + while (ffmpegRet >= 0) { + ReferenceAVPacket packet(autoAVPacket); + ffmpegRet = avcodec_receive_packet(avCodecContext_.get(), packet.get()); if (ffmpegRet == AVERROR(EAGAIN) || ffmpegRet == AVERROR_EOF) { // TODO this is from TorchAudio, probably needed, but not sure. // if (ffmpegRet == AVERROR_EOF) { @@ -1892,23 +1870,16 @@ void Encoder::encode_inner_loop( getFFMPEGErrorStringFromErrorCode(ffmpegRet)); // TODO why are these 2 lines needed?? - // av_packet_rescale_ts(pkt, avCodecContext_->time_base, - // avStream_->time_base); - pkt->stream_index = avStream_->index; - printf("PACKET PTS %d\n", pkt->pts); + av_packet_rescale_ts( + packet.get(), avCodecContext_->time_base, avStream_->time_base); + packet->stream_index = avStream_->index; - ffmpegRet = av_interleaved_write_frame(avFormatContext_.get(), pkt); + ffmpegRet = + av_interleaved_write_frame(avFormatContext_.get(), packet.get()); TORCH_CHECK( ffmpegRet == AVSUCCESS, "Error in av_interleaved_write_frame: ", getFFMPEGErrorStringFromErrorCode(ffmpegRet)); - - fwrite(pkt->data, 1, pkt->size, f_); - - memcpy(pOutputTensor + *numEncodedBytes, pkt->data, pkt->size); - *numEncodedBytes += pkt->size; - - av_packet_unref(pkt); } } diff --git a/src/torchcodec/decoders/_core/VideoDecoder.h b/src/torchcodec/decoders/_core/VideoDecoder.h index 70145ffe..9f6df4a8 100644 --- a/src/torchcodec/decoders/_core/VideoDecoder.h +++ b/src/torchcodec/decoders/_core/VideoDecoder.h @@ -557,15 +557,10 @@ class Encoder { torch::Tensor encode(const torch::Tensor& wf); private: - void encode_inner_loop( - AVPacket* pkt, - uint8_t* pOutputTensor, - int* numEncodedBytes, - bool flush); + void encode_inner_loop(AutoAVPacket& autoAVPacket, AVFrame* avFrame); UniqueAVFormatContext avFormatContext_; UniqueAVCodecContext avCodecContext_; - UniqueAVFrame avFrame_; AVStream* avStream_; // The *output* sample rate. We can't really decide for the user what it @@ -576,7 +571,6 @@ class Encoder { // resample the waveform internally to match them, but that's not in scope for // an initial version (if at all). int sampleRate_; - FILE* f_; }; } // namespace facebook::torchcodec From 0906fb3d0fcc87408e822cbe796ff39f1ea03aba Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Fri, 28 Feb 2025 11:00:53 +0000 Subject: [PATCH 04/18] Properly free AVFormatContext and streams --- src/torchcodec/decoders/_core/FFMPEGCommon.h | 5 ++++- src/torchcodec/decoders/_core/VideoDecoder.cpp | 14 +++++++------- src/torchcodec/decoders/_core/VideoDecoder.h | 4 ++-- 3 files changed, 13 insertions(+), 10 deletions(-) diff --git a/src/torchcodec/decoders/_core/FFMPEGCommon.h b/src/torchcodec/decoders/_core/FFMPEGCommon.h index deabae52..86cd716d 100644 --- a/src/torchcodec/decoders/_core/FFMPEGCommon.h +++ b/src/torchcodec/decoders/_core/FFMPEGCommon.h @@ -49,9 +49,12 @@ struct Deleter { }; // Unique pointers for FFMPEG structures. -using UniqueAVFormatContext = std::unique_ptr< +using UniqueAVFormatContextForDecoding = std::unique_ptr< AVFormatContext, Deleterp>; +using UniqueAVFormatContextForEncoding = std::unique_ptr< + AVFormatContext, + Deleter>; using UniqueAVCodecContext = std::unique_ptr< AVCodecContext, Deleterp>; diff --git a/src/torchcodec/decoders/_core/VideoDecoder.cpp b/src/torchcodec/decoders/_core/VideoDecoder.cpp index 7f26c6fc..a4f5d10f 100644 --- a/src/torchcodec/decoders/_core/VideoDecoder.cpp +++ b/src/torchcodec/decoders/_core/VideoDecoder.cpp @@ -41,7 +41,7 @@ int64_t secondsToClosestPts(double seconds, const AVRational& timeBase) { } struct AVInput { - UniqueAVFormatContext formatContext; + UniqueAVFormatContextForDecoding formatContext; std::unique_ptr ioBytesContext; }; @@ -1695,10 +1695,7 @@ FrameDims getHeightAndWidthFromOptionsOrAVFrame( videoStreamOptions.width.value_or(avFrame.width)); } -Encoder::~Encoder() { - // TODO NEED TO CALL THIS - // avformat_free_context(avFormatContext_.get()); -} +Encoder::~Encoder() {} Encoder::Encoder(int sampleRate, std::string_view fileName) : sampleRate_(sampleRate) { @@ -1756,6 +1753,9 @@ Encoder::Encoder(int sampleRate, std::string_view fileName) avCodecContext_->frame_size, ". Cannot encode. This should probably never happen?"); + // We're allocating the stream here. Streams are meant to be freed by + // avformat_free_context(avFormatContext), which we call in the + // avFormatContext_'s destructor. avStream_ = avformat_new_stream(avFormatContext_.get(), NULL); TORCH_CHECK(avStream_ != nullptr, "Couldn't create new stream."); avcodec_parameters_from_context(avStream_->codecpar, avCodecContext_.get()); @@ -1829,10 +1829,10 @@ torch::Tensor Encoder::encode(const torch::Tensor& wf) { avFrame->pts += avFrame->nb_samples; numEncodedSamples += numSamplesToEncode; } - encode_inner_loop(autoAVPacket, nullptr); // flush - TORCH_CHECK(numEncodedSamples == numSamples, "Hmmmmmm something went wrong."); + encode_inner_loop(autoAVPacket, nullptr); // flush + ffmpegRet = av_write_trailer(avFormatContext_.get()); TORCH_CHECK( ffmpegRet == AVSUCCESS, diff --git a/src/torchcodec/decoders/_core/VideoDecoder.h b/src/torchcodec/decoders/_core/VideoDecoder.h index 9f6df4a8..7b3158fc 100644 --- a/src/torchcodec/decoders/_core/VideoDecoder.h +++ b/src/torchcodec/decoders/_core/VideoDecoder.h @@ -453,7 +453,7 @@ class VideoDecoder { SeekMode seekMode_; ContainerMetadata containerMetadata_; - UniqueAVFormatContext formatContext_; + UniqueAVFormatContextForDecoding formatContext_; std::map streamInfos_; const int NO_ACTIVE_STREAM = -2; int activeStreamIndex_ = NO_ACTIVE_STREAM; @@ -559,7 +559,7 @@ class Encoder { private: void encode_inner_loop(AutoAVPacket& autoAVPacket, AVFrame* avFrame); - UniqueAVFormatContext avFormatContext_; + UniqueAVFormatContextForEncoding avFormatContext_; UniqueAVCodecContext avCodecContext_; AVStream* avStream_; From 3890227214a3ca413b355c1f716e0536d945e7b3 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Tue, 1 Apr 2025 13:40:46 +0100 Subject: [PATCH 05/18] don't return encoded bytes for now --- src/torchcodec/decoders/_core/VideoDecoder.cpp | 5 +---- src/torchcodec/decoders/_core/VideoDecoder.h | 2 +- src/torchcodec/decoders/_core/VideoDecoderOps.cpp | 6 +++--- 3 files changed, 5 insertions(+), 8 deletions(-) diff --git a/src/torchcodec/decoders/_core/VideoDecoder.cpp b/src/torchcodec/decoders/_core/VideoDecoder.cpp index 5d8ffe2e..b18520c0 100644 --- a/src/torchcodec/decoders/_core/VideoDecoder.cpp +++ b/src/torchcodec/decoders/_core/VideoDecoder.cpp @@ -2145,7 +2145,7 @@ Encoder::Encoder(int sampleRate, std::string_view fileName) avcodec_parameters_from_context(avStream_->codecpar, avCodecContext_.get()); } -torch::Tensor Encoder::encode(const torch::Tensor& wf) { +void Encoder::encode(const torch::Tensor& wf) { UniqueAVFrame avFrame(av_frame_alloc()); TORCH_CHECK(avFrame != nullptr, "Couldn't allocate AVFrame."); avFrame->nb_samples = avCodecContext_->frame_size; @@ -2222,9 +2222,6 @@ torch::Tensor Encoder::encode(const torch::Tensor& wf) { ffmpegRet == AVSUCCESS, "Error in: av_write_trailer", getFFMPEGErrorStringFromErrorCode(ffmpegRet)); - - // TODO handle writing to output uint8 tensor with AVIO logic. - return torch::empty({10}); } void Encoder::encode_inner_loop(AutoAVPacket& autoAVPacket, AVFrame* avFrame) { diff --git a/src/torchcodec/decoders/_core/VideoDecoder.h b/src/torchcodec/decoders/_core/VideoDecoder.h index 1479e9a3..8c91bb1b 100644 --- a/src/torchcodec/decoders/_core/VideoDecoder.h +++ b/src/torchcodec/decoders/_core/VideoDecoder.h @@ -597,7 +597,7 @@ class Encoder { // TODO doesn't make much sense to pass fileName and the wf tensor in 2 // different calls. Same with sampleRate. Encoder(int sampleRate, std::string_view fileName); - torch::Tensor encode(const torch::Tensor& wf); + void encode(const torch::Tensor& wf); private: void encode_inner_loop(AutoAVPacket& autoAVPacket, AVFrame* avFrame); diff --git a/src/torchcodec/decoders/_core/VideoDecoderOps.cpp b/src/torchcodec/decoders/_core/VideoDecoderOps.cpp index d057112a..4760d158 100644 --- a/src/torchcodec/decoders/_core/VideoDecoderOps.cpp +++ b/src/torchcodec/decoders/_core/VideoDecoderOps.cpp @@ -29,7 +29,7 @@ TORCH_LIBRARY(torchcodec_ns, m) { "torchcodec.decoders._core.ops", "//pytorch/torchcodec:torchcodec"); m.def("create_from_file(str filename, str? seek_mode=None) -> Tensor"); m.def("create_encoder(int sample_rate, str filename) -> Tensor"); - m.def("encode(Tensor(a!) encoder, Tensor wf) -> Tensor"); + m.def("encode(Tensor(a!) encoder, Tensor wf) -> ()"); m.def( "create_from_tensor(Tensor video_tensor, str? seek_mode=None) -> Tensor"); m.def("_convert_to_tensor(int decoder_ptr) -> Tensor"); @@ -151,9 +151,9 @@ at::Tensor create_encoder(int64_t sample_rate, std::string_view file_name) { return wrapEncoderPointerToTensor(std::move(uniqueEncoder)); } -at::Tensor encode(at::Tensor& encoder, const at::Tensor& wf) { +void encode(at::Tensor& encoder, const at::Tensor& wf) { auto encoder_ = unwrapTensorToGetEncoder(encoder); - return encoder_->encode(wf); + encoder_->encode(wf); } at::Tensor create_from_tensor( From 52d1753463346abe3dd90b795608a96b0a247dda Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Tue, 1 Apr 2025 13:51:49 +0100 Subject: [PATCH 06/18] Write TODOs, avoid raw pointers --- .../decoders/_core/VideoDecoder.cpp | 36 ++++++++++--------- src/torchcodec/decoders/_core/VideoDecoder.h | 4 ++- 2 files changed, 23 insertions(+), 17 deletions(-) diff --git a/src/torchcodec/decoders/_core/VideoDecoder.cpp b/src/torchcodec/decoders/_core/VideoDecoder.cpp index b18520c0..1a311f07 100644 --- a/src/torchcodec/decoders/_core/VideoDecoder.cpp +++ b/src/torchcodec/decoders/_core/VideoDecoder.cpp @@ -2108,19 +2108,21 @@ Encoder::Encoder(int sampleRate, std::string_view fileName) TORCH_CHECK(avCodecContext != nullptr, "Couldn't allocate codec context."); avCodecContext_.reset(avCodecContext); - // I think this will use the default. TODO Should let user choose for - // compressed formats like mp3. + // This will use the default bit rate + // TODO-ENCODING Should let user choose for compressed formats like mp3. avCodecContext_->bit_rate = 0; - // TODO A given encoder only supports a finite set of output sample rates. - // FFmpeg raises informative error message. Are we happy with that, or do we - // run our own checks by checking against avCodec->supported_samplerates? + // FFmpeg will raise a reasonably informative error if the desired sample rate + // isn't supported by the encoder. avCodecContext_->sample_rate = sampleRate_; // Note: This is the format of the **input** waveform. This doesn't determine - // the output. TODO check contiguity of the input wf to ensure that it is - // indeed planar. - // TODO What if the encoder doesn't support FLTP? Like flac? + // the output. + // TODO-ENCODING check contiguity of the input wf to ensure that it is indeed + // planar. + // TODO-ENCODING If the encoder doesn't support FLTP (like flac), FFmpeg will + // raise. We need to handle this, probably converting the format with + // libswresample. avCodecContext_->sample_fmt = AV_SAMPLE_FMT_FLTP; AVChannelLayout channel_layout; @@ -2177,8 +2179,8 @@ void Encoder::encode(const torch::Tensor& wf) { auto numBytesPerChannel = wf.sizes()[1] * numBytesPerSample; TORCH_CHECK( - // TODO is this even true / needed? We can probably support more with - // non-planar data? + // TODO-ENCODING is this even true / needed? We can probably support more + // with non-planar data? numChannels <= AV_NUM_DATA_POINTERS, "Trying to encode ", numChannels, @@ -2208,14 +2210,14 @@ void Encoder::encode(const torch::Tensor& wf) { avFrame->data[ch], pWf + ch * numBytesPerChannel, numBytesToEncode); } pWf += numBytesToEncode; - encode_inner_loop(autoAVPacket, avFrame.get()); + encode_inner_loop(autoAVPacket, avFrame); avFrame->pts += avFrame->nb_samples; numEncodedSamples += numSamplesToEncode; } TORCH_CHECK(numEncodedSamples == numSamples, "Hmmmmmm something went wrong."); - encode_inner_loop(autoAVPacket, nullptr); // flush + encode_inner_loop(autoAVPacket, UniqueAVFrame(nullptr)); // flush ffmpegRet = av_write_trailer(avFormatContext_.get()); TORCH_CHECK( @@ -2224,8 +2226,10 @@ void Encoder::encode(const torch::Tensor& wf) { getFFMPEGErrorStringFromErrorCode(ffmpegRet)); } -void Encoder::encode_inner_loop(AutoAVPacket& autoAVPacket, AVFrame* avFrame) { - auto ffmpegRet = avcodec_send_frame(avCodecContext_.get(), avFrame); +void Encoder::encode_inner_loop( + AutoAVPacket& autoAVPacket, + const UniqueAVFrame& avFrame) { + auto ffmpegRet = avcodec_send_frame(avCodecContext_.get(), avFrame.get()); TORCH_CHECK( ffmpegRet == AVSUCCESS, "Error while sending frame: ", @@ -2235,7 +2239,7 @@ void Encoder::encode_inner_loop(AutoAVPacket& autoAVPacket, AVFrame* avFrame) { ReferenceAVPacket packet(autoAVPacket); ffmpegRet = avcodec_receive_packet(avCodecContext_.get(), packet.get()); if (ffmpegRet == AVERROR(EAGAIN) || ffmpegRet == AVERROR_EOF) { - // TODO this is from TorchAudio, probably needed, but not sure. + // TODO-ENCODING this is from TorchAudio, probably needed, but not sure. // if (ffmpegRet == AVERROR_EOF) { // ffmpegRet = av_interleaved_write_frame(avFormatContext_.get(), // nullptr); TORCH_CHECK( @@ -2250,7 +2254,7 @@ void Encoder::encode_inner_loop(AutoAVPacket& autoAVPacket, AVFrame* avFrame) { "Error receiving packet: ", getFFMPEGErrorStringFromErrorCode(ffmpegRet)); - // TODO why are these 2 lines needed?? + // TODO-ENCODING why are these 2 lines needed?? av_packet_rescale_ts( packet.get(), avCodecContext_->time_base, avStream_->time_base); packet->stream_index = avStream_->index; diff --git a/src/torchcodec/decoders/_core/VideoDecoder.h b/src/torchcodec/decoders/_core/VideoDecoder.h index 8c91bb1b..0587e268 100644 --- a/src/torchcodec/decoders/_core/VideoDecoder.h +++ b/src/torchcodec/decoders/_core/VideoDecoder.h @@ -600,7 +600,9 @@ class Encoder { void encode(const torch::Tensor& wf); private: - void encode_inner_loop(AutoAVPacket& autoAVPacket, AVFrame* avFrame); + void encode_inner_loop( + AutoAVPacket& autoAVPacket, + const UniqueAVFrame& avFrame); UniqueAVFormatContextForEncoding avFormatContext_; UniqueAVCodecContext avCodecContext_; From 45fd0eceaaa497ab2a6fe3bd866b7e2ea08c3b0d Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Tue, 1 Apr 2025 14:14:00 +0100 Subject: [PATCH 07/18] Add (failing) round-trip test --- test/decoders/test_ops.py | 35 +++++++++++++++++++++++++++++++++-- test/utils.py | 2 ++ 2 files changed, 35 insertions(+), 2 deletions(-) diff --git a/test/decoders/test_ops.py b/test/decoders/test_ops.py index 9efb33f3..216d28aa 100644 --- a/test/decoders/test_ops.py +++ b/test/decoders/test_ops.py @@ -21,10 +21,12 @@ _test_frame_pts_equality, add_audio_stream, add_video_stream, + create_encoder, create_from_bytes, create_from_file, create_from_file_like, create_from_tensor, + encode, get_ffmpeg_library_versions, get_frame_at_index, get_frame_at_pts, @@ -48,6 +50,7 @@ SINE_MONO_S32, SINE_MONO_S32_44100, SINE_MONO_S32_8000, + TestContainerFile, ) torch._dynamo.config.capture_dynamic_output_shape_ops = True @@ -55,7 +58,7 @@ INDEX_OF_FRAME_AT_6_SECONDS = 180 -class TestVideoOps: +class TestVideoDecoderOps: @pytest.mark.parametrize("device", cpu_and_cuda()) def test_seek_and_next(self, device): decoder = create_from_file(str(NASA_VIDEO.path)) @@ -632,7 +635,7 @@ def test_cuda_decoder(self): ) -class TestAudioOps: +class TestAudioDecoderOps: @pytest.mark.parametrize( "method", ( @@ -923,5 +926,33 @@ def get_all_frames(asset, sample_rate=None, stop_seconds=None): torch.testing.assert_close(frames_downsampled_to_8000, frames_8000_native) +class TestAudioEncoderOps: + + def decode(self, source) -> torch.Tensor: + if isinstance(source, TestContainerFile): + source = str(source.path) + else: + source = str(source) + decoder = create_from_file(source, seek_mode="approximate") + add_audio_stream(decoder) + frames, *_ = get_frames_by_pts_in_range_audio( + decoder, start_seconds=0, stop_seconds=None + ) + return frames + + def test_round_trip(self, tmp_path): + asset = SINE_MONO_S32 + source_samples = self.decode(asset) + + output_file = tmp_path / "output.mp3" + encoder = create_encoder( + sample_rate=asset.sample_rate, filename=str(output_file) + ) + encode(encoder, source_samples) + + round_trip_samples = self.decode(output_file) + torch.testing.assert_close(source_samples, round_trip_samples) + + if __name__ == "__main__": pytest.main() diff --git a/test/utils.py b/test/utils.py index 70f32bfb..dd438785 100644 --- a/test/utils.py +++ b/test/utils.py @@ -114,6 +114,8 @@ class TestAudioStreamInfo: @dataclass class TestContainerFile: + __test__ = False # prevents pytest from thinking this is a test class + filename: str default_stream_index: int From 75b099b6c0ca2696857bdaa61484818ea4b80cf7 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Wed, 2 Apr 2025 10:55:08 +0100 Subject: [PATCH 08/18] Create new file --- src/torchcodec/_core/CMakeLists.txt | 3 + src/torchcodec/_core/Encoder.cpp | 199 ++++++++++++++++++++++++++ src/torchcodec/_core/Encoder.h | 35 +++++ src/torchcodec/_core/VideoDecoder.cpp | 189 ------------------------ src/torchcodec/_core/VideoDecoder.h | 30 ---- src/torchcodec/_core/custom_ops.cpp | 1 + 6 files changed, 238 insertions(+), 219 deletions(-) create mode 100644 src/torchcodec/_core/Encoder.cpp create mode 100644 src/torchcodec/_core/Encoder.h diff --git a/src/torchcodec/_core/CMakeLists.txt b/src/torchcodec/_core/CMakeLists.txt index 9e2fb4e8..465e893d 100644 --- a/src/torchcodec/_core/CMakeLists.txt +++ b/src/torchcodec/_core/CMakeLists.txt @@ -62,6 +62,9 @@ function(make_torchcodec_libraries AVIOContextHolder.cpp FFMPEGCommon.cpp VideoDecoder.cpp + # TODO: lib name should probably not be "*_decoder*" now that it also + # contains an encoder + Encoder.cpp ) if(ENABLE_CUDA) diff --git a/src/torchcodec/_core/Encoder.cpp b/src/torchcodec/_core/Encoder.cpp new file mode 100644 index 00000000..e5c4cf18 --- /dev/null +++ b/src/torchcodec/_core/Encoder.cpp @@ -0,0 +1,199 @@ +#include "src/torchcodec/_core/Encoder.h" +#include "torch/types.h" + +extern "C" { +#include +#include +} + +namespace facebook::torchcodec { + +Encoder::~Encoder() {} + +Encoder::Encoder(int sampleRate, std::string_view fileName) + : sampleRate_(sampleRate) { + AVFormatContext* avFormatContext = nullptr; + avformat_alloc_output_context2(&avFormatContext, NULL, NULL, fileName.data()); + TORCH_CHECK(avFormatContext != nullptr, "Couldn't allocate AVFormatContext."); + avFormatContext_.reset(avFormatContext); + + TORCH_CHECK( + !(avFormatContext->oformat->flags & AVFMT_NOFILE), + "AVFMT_NOFILE is set. We only support writing to a file."); + auto ffmpegRet = + avio_open(&avFormatContext_->pb, fileName.data(), AVIO_FLAG_WRITE); + TORCH_CHECK( + ffmpegRet >= 0, + "avio_open failed: ", + getFFMPEGErrorStringFromErrorCode(ffmpegRet)); + + // We use the AVFormatContext's default codec for that + // specificavcodec_parameters_from_context format/container. + const AVCodec* avCodec = + avcodec_find_encoder(avFormatContext_->oformat->audio_codec); + TORCH_CHECK(avCodec != nullptr, "Codec not found"); + + AVCodecContext* avCodecContext = avcodec_alloc_context3(avCodec); + TORCH_CHECK(avCodecContext != nullptr, "Couldn't allocate codec context."); + avCodecContext_.reset(avCodecContext); + + // This will use the default bit rate + // TODO-ENCODING Should let user choose for compressed formats like mp3. + avCodecContext_->bit_rate = 0; + + // FFmpeg will raise a reasonably informative error if the desired sample rate + // isn't supported by the encoder. + avCodecContext_->sample_rate = sampleRate_; + + // Note: This is the format of the **input** waveform. This doesn't determine + // the output. + // TODO-ENCODING check contiguity of the input wf to ensure that it is indeed + // planar. + // TODO-ENCODING If the encoder doesn't support FLTP (like flac), FFmpeg will + // raise. We need to handle this, probably converting the format with + // libswresample. + avCodecContext_->sample_fmt = AV_SAMPLE_FMT_FLTP; + + AVChannelLayout channel_layout; + av_channel_layout_default(&channel_layout, 2); + avCodecContext_->ch_layout = channel_layout; + + ffmpegRet = avcodec_open2(avCodecContext_.get(), avCodec, nullptr); + TORCH_CHECK( + ffmpegRet == AVSUCCESS, getFFMPEGErrorStringFromErrorCode(ffmpegRet)); + + TORCH_CHECK( + avCodecContext_->frame_size > 0, + "frame_size is ", + avCodecContext_->frame_size, + ". Cannot encode. This should probably never happen?"); + + // We're allocating the stream here. Streams are meant to be freed by + // avformat_free_context(avFormatContext), which we call in the + // avFormatContext_'s destructor. + avStream_ = avformat_new_stream(avFormatContext_.get(), NULL); + TORCH_CHECK(avStream_ != nullptr, "Couldn't create new stream."); + avcodec_parameters_from_context(avStream_->codecpar, avCodecContext_.get()); +} + +void Encoder::encode(const torch::Tensor& wf) { + UniqueAVFrame avFrame(av_frame_alloc()); + TORCH_CHECK(avFrame != nullptr, "Couldn't allocate AVFrame."); + avFrame->nb_samples = avCodecContext_->frame_size; + avFrame->format = avCodecContext_->sample_fmt; + avFrame->sample_rate = avCodecContext_->sample_rate; + avFrame->pts = 0; + auto ffmpegRet = + av_channel_layout_copy(&avFrame->ch_layout, &avCodecContext_->ch_layout); + TORCH_CHECK( + ffmpegRet == AVSUCCESS, + "Couldn't copy channel layout to avFrame: ", + getFFMPEGErrorStringFromErrorCode(ffmpegRet)); + + ffmpegRet = av_frame_get_buffer(avFrame.get(), 0); + TORCH_CHECK( + ffmpegRet == AVSUCCESS, + "Couldn't allocate avFrame's buffers: ", + getFFMPEGErrorStringFromErrorCode(ffmpegRet)); + + AutoAVPacket autoAVPacket; + + uint8_t* pWf = static_cast(wf.data_ptr()); + auto numChannels = wf.sizes()[0]; + auto numSamples = wf.sizes()[1]; // per channel + auto numEncodedSamples = 0; // per channel + auto numSamplesPerFrame = + static_cast(avCodecContext_->frame_size); // per channel + auto numBytesPerSample = wf.element_size(); + auto numBytesPerChannel = wf.sizes()[1] * numBytesPerSample; + + TORCH_CHECK( + // TODO-ENCODING is this even true / needed? We can probably support more + // with non-planar data? + numChannels <= AV_NUM_DATA_POINTERS, + "Trying to encode ", + numChannels, + " channels, but FFmpeg only supports ", + AV_NUM_DATA_POINTERS, + " channels per frame."); + + ffmpegRet = avformat_write_header(avFormatContext_.get(), NULL); + TORCH_CHECK( + ffmpegRet == AVSUCCESS, + "Error in avformat_write_header: ", + getFFMPEGErrorStringFromErrorCode(ffmpegRet)); + + while (numEncodedSamples < numSamples) { + ffmpegRet = av_frame_make_writable(avFrame.get()); + TORCH_CHECK( + ffmpegRet == AVSUCCESS, + "Couldn't make AVFrame writable: ", + getFFMPEGErrorStringFromErrorCode(ffmpegRet)); + + auto numSamplesToEncode = + std::min(numSamplesPerFrame, numSamples - numEncodedSamples); + auto numBytesToEncode = numSamplesToEncode * numBytesPerSample; + + for (int ch = 0; ch < numChannels; ch++) { + memcpy( + avFrame->data[ch], pWf + ch * numBytesPerChannel, numBytesToEncode); + } + pWf += numBytesToEncode; + encode_inner_loop(autoAVPacket, avFrame); + + avFrame->pts += avFrame->nb_samples; + numEncodedSamples += numSamplesToEncode; + } + TORCH_CHECK(numEncodedSamples == numSamples, "Hmmmmmm something went wrong."); + + encode_inner_loop(autoAVPacket, UniqueAVFrame(nullptr)); // flush + + ffmpegRet = av_write_trailer(avFormatContext_.get()); + TORCH_CHECK( + ffmpegRet == AVSUCCESS, + "Error in: av_write_trailer", + getFFMPEGErrorStringFromErrorCode(ffmpegRet)); +} + +void Encoder::encode_inner_loop( + AutoAVPacket& autoAVPacket, + const UniqueAVFrame& avFrame) { + auto ffmpegRet = avcodec_send_frame(avCodecContext_.get(), avFrame.get()); + TORCH_CHECK( + ffmpegRet == AVSUCCESS, + "Error while sending frame: ", + getFFMPEGErrorStringFromErrorCode(ffmpegRet)); + + while (ffmpegRet >= 0) { + ReferenceAVPacket packet(autoAVPacket); + ffmpegRet = avcodec_receive_packet(avCodecContext_.get(), packet.get()); + if (ffmpegRet == AVERROR(EAGAIN) || ffmpegRet == AVERROR_EOF) { + // TODO-ENCODING this is from TorchAudio, probably needed, but not sure. + // if (ffmpegRet == AVERROR_EOF) { + // ffmpegRet = av_interleaved_write_frame(avFormatContext_.get(), + // nullptr); TORCH_CHECK( + // ffmpegRet == AVSUCCESS, + // "Failed to flush packet ", + // getFFMPEGErrorStringFromErrorCode(ffmpegRet)); + // } + return; + } + TORCH_CHECK( + ffmpegRet >= 0, + "Error receiving packet: ", + getFFMPEGErrorStringFromErrorCode(ffmpegRet)); + + // TODO-ENCODING why are these 2 lines needed?? + av_packet_rescale_ts( + packet.get(), avCodecContext_->time_base, avStream_->time_base); + packet->stream_index = avStream_->index; + + ffmpegRet = + av_interleaved_write_frame(avFormatContext_.get(), packet.get()); + TORCH_CHECK( + ffmpegRet == AVSUCCESS, + "Error in av_interleaved_write_frame: ", + getFFMPEGErrorStringFromErrorCode(ffmpegRet)); + } +} +} // namespace facebook::torchcodec diff --git a/src/torchcodec/_core/Encoder.h b/src/torchcodec/_core/Encoder.h new file mode 100644 index 00000000..85176ae1 --- /dev/null +++ b/src/torchcodec/_core/Encoder.h @@ -0,0 +1,35 @@ +#pragma once +#include +#include "src/torchcodec/_core/FFMPEGCommon.h" + +namespace facebook::torchcodec { +class Encoder { + public: + ~Encoder(); + + // TODO Are we OK passing a string_view to the constructor? + // TODO fileName should be optional. + // TODO doesn't make much sense to pass fileName and the wf tensor in 2 + // different calls. Same with sampleRate. + Encoder(int sampleRate, std::string_view fileName); + void encode(const torch::Tensor& wf); + + private: + void encode_inner_loop( + AutoAVPacket& autoAVPacket, + const UniqueAVFrame& avFrame); + + UniqueAVFormatContextForEncoding avFormatContext_; + UniqueAVCodecContext avCodecContext_; + AVStream* avStream_; + + // The *output* sample rate. We can't really decide for the user what it + // should be. Particularly, the sample rate of the input waveform should match + // this, and that's up to the user. If sample rates don't match, encoding will + // still work but audio will be distorted. + // We technically could let the user also specify the input sample rate, and + // resample the waveform internally to match them, but that's not in scope for + // an initial version (if at all). + int sampleRate_; +}; +} // namespace facebook::torchcodec diff --git a/src/torchcodec/_core/VideoDecoder.cpp b/src/torchcodec/_core/VideoDecoder.cpp index 1f3694f1..08954342 100644 --- a/src/torchcodec/_core/VideoDecoder.cpp +++ b/src/torchcodec/_core/VideoDecoder.cpp @@ -2079,193 +2079,4 @@ VideoDecoder::SeekMode seekModeFromString(std::string_view seekMode) { } } -Encoder::~Encoder() {} - -Encoder::Encoder(int sampleRate, std::string_view fileName) - : sampleRate_(sampleRate) { - AVFormatContext* avFormatContext = nullptr; - avformat_alloc_output_context2(&avFormatContext, NULL, NULL, fileName.data()); - TORCH_CHECK(avFormatContext != nullptr, "Couldn't allocate AVFormatContext."); - avFormatContext_.reset(avFormatContext); - - TORCH_CHECK( - !(avFormatContext->oformat->flags & AVFMT_NOFILE), - "AVFMT_NOFILE is set. We only support writing to a file."); - auto ffmpegRet = - avio_open(&avFormatContext_->pb, fileName.data(), AVIO_FLAG_WRITE); - TORCH_CHECK( - ffmpegRet >= 0, - "avio_open failed: ", - getFFMPEGErrorStringFromErrorCode(ffmpegRet)); - - // We use the AVFormatContext's default codec for that - // specificavcodec_parameters_from_context format/container. - const AVCodec* avCodec = - avcodec_find_encoder(avFormatContext_->oformat->audio_codec); - TORCH_CHECK(avCodec != nullptr, "Codec not found"); - - AVCodecContext* avCodecContext = avcodec_alloc_context3(avCodec); - TORCH_CHECK(avCodecContext != nullptr, "Couldn't allocate codec context."); - avCodecContext_.reset(avCodecContext); - - // This will use the default bit rate - // TODO-ENCODING Should let user choose for compressed formats like mp3. - avCodecContext_->bit_rate = 0; - - // FFmpeg will raise a reasonably informative error if the desired sample rate - // isn't supported by the encoder. - avCodecContext_->sample_rate = sampleRate_; - - // Note: This is the format of the **input** waveform. This doesn't determine - // the output. - // TODO-ENCODING check contiguity of the input wf to ensure that it is indeed - // planar. - // TODO-ENCODING If the encoder doesn't support FLTP (like flac), FFmpeg will - // raise. We need to handle this, probably converting the format with - // libswresample. - avCodecContext_->sample_fmt = AV_SAMPLE_FMT_FLTP; - - AVChannelLayout channel_layout; - av_channel_layout_default(&channel_layout, 2); - avCodecContext_->ch_layout = channel_layout; - - ffmpegRet = avcodec_open2(avCodecContext_.get(), avCodec, nullptr); - TORCH_CHECK( - ffmpegRet == AVSUCCESS, getFFMPEGErrorStringFromErrorCode(ffmpegRet)); - - TORCH_CHECK( - avCodecContext_->frame_size > 0, - "frame_size is ", - avCodecContext_->frame_size, - ". Cannot encode. This should probably never happen?"); - - // We're allocating the stream here. Streams are meant to be freed by - // avformat_free_context(avFormatContext), which we call in the - // avFormatContext_'s destructor. - avStream_ = avformat_new_stream(avFormatContext_.get(), NULL); - TORCH_CHECK(avStream_ != nullptr, "Couldn't create new stream."); - avcodec_parameters_from_context(avStream_->codecpar, avCodecContext_.get()); -} - -void Encoder::encode(const torch::Tensor& wf) { - UniqueAVFrame avFrame(av_frame_alloc()); - TORCH_CHECK(avFrame != nullptr, "Couldn't allocate AVFrame."); - avFrame->nb_samples = avCodecContext_->frame_size; - avFrame->format = avCodecContext_->sample_fmt; - avFrame->sample_rate = avCodecContext_->sample_rate; - avFrame->pts = 0; - auto ffmpegRet = - av_channel_layout_copy(&avFrame->ch_layout, &avCodecContext_->ch_layout); - TORCH_CHECK( - ffmpegRet == AVSUCCESS, - "Couldn't copy channel layout to avFrame: ", - getFFMPEGErrorStringFromErrorCode(ffmpegRet)); - - ffmpegRet = av_frame_get_buffer(avFrame.get(), 0); - TORCH_CHECK( - ffmpegRet == AVSUCCESS, - "Couldn't allocate avFrame's buffers: ", - getFFMPEGErrorStringFromErrorCode(ffmpegRet)); - - AutoAVPacket autoAVPacket; - - uint8_t* pWf = static_cast(wf.data_ptr()); - auto numChannels = wf.sizes()[0]; - auto numSamples = wf.sizes()[1]; // per channel - auto numEncodedSamples = 0; // per channel - auto numSamplesPerFrame = - static_cast(avCodecContext_->frame_size); // per channel - auto numBytesPerSample = wf.element_size(); - auto numBytesPerChannel = wf.sizes()[1] * numBytesPerSample; - - TORCH_CHECK( - // TODO-ENCODING is this even true / needed? We can probably support more - // with non-planar data? - numChannels <= AV_NUM_DATA_POINTERS, - "Trying to encode ", - numChannels, - " channels, but FFmpeg only supports ", - AV_NUM_DATA_POINTERS, - " channels per frame."); - - ffmpegRet = avformat_write_header(avFormatContext_.get(), NULL); - TORCH_CHECK( - ffmpegRet == AVSUCCESS, - "Error in avformat_write_header: ", - getFFMPEGErrorStringFromErrorCode(ffmpegRet)); - - while (numEncodedSamples < numSamples) { - ffmpegRet = av_frame_make_writable(avFrame.get()); - TORCH_CHECK( - ffmpegRet == AVSUCCESS, - "Couldn't make AVFrame writable: ", - getFFMPEGErrorStringFromErrorCode(ffmpegRet)); - - auto numSamplesToEncode = - std::min(numSamplesPerFrame, numSamples - numEncodedSamples); - auto numBytesToEncode = numSamplesToEncode * numBytesPerSample; - - for (int ch = 0; ch < numChannels; ch++) { - memcpy( - avFrame->data[ch], pWf + ch * numBytesPerChannel, numBytesToEncode); - } - pWf += numBytesToEncode; - encode_inner_loop(autoAVPacket, avFrame); - - avFrame->pts += avFrame->nb_samples; - numEncodedSamples += numSamplesToEncode; - } - TORCH_CHECK(numEncodedSamples == numSamples, "Hmmmmmm something went wrong."); - - encode_inner_loop(autoAVPacket, UniqueAVFrame(nullptr)); // flush - - ffmpegRet = av_write_trailer(avFormatContext_.get()); - TORCH_CHECK( - ffmpegRet == AVSUCCESS, - "Error in: av_write_trailer", - getFFMPEGErrorStringFromErrorCode(ffmpegRet)); -} - -void Encoder::encode_inner_loop( - AutoAVPacket& autoAVPacket, - const UniqueAVFrame& avFrame) { - auto ffmpegRet = avcodec_send_frame(avCodecContext_.get(), avFrame.get()); - TORCH_CHECK( - ffmpegRet == AVSUCCESS, - "Error while sending frame: ", - getFFMPEGErrorStringFromErrorCode(ffmpegRet)); - - while (ffmpegRet >= 0) { - ReferenceAVPacket packet(autoAVPacket); - ffmpegRet = avcodec_receive_packet(avCodecContext_.get(), packet.get()); - if (ffmpegRet == AVERROR(EAGAIN) || ffmpegRet == AVERROR_EOF) { - // TODO-ENCODING this is from TorchAudio, probably needed, but not sure. - // if (ffmpegRet == AVERROR_EOF) { - // ffmpegRet = av_interleaved_write_frame(avFormatContext_.get(), - // nullptr); TORCH_CHECK( - // ffmpegRet == AVSUCCESS, - // "Failed to flush packet ", - // getFFMPEGErrorStringFromErrorCode(ffmpegRet)); - // } - return; - } - TORCH_CHECK( - ffmpegRet >= 0, - "Error receiving packet: ", - getFFMPEGErrorStringFromErrorCode(ffmpegRet)); - - // TODO-ENCODING why are these 2 lines needed?? - av_packet_rescale_ts( - packet.get(), avCodecContext_->time_base, avStream_->time_base); - packet->stream_index = avStream_->index; - - ffmpegRet = - av_interleaved_write_frame(avFormatContext_.get(), packet.get()); - TORCH_CHECK( - ffmpegRet == AVSUCCESS, - "Error in av_interleaved_write_frame: ", - getFFMPEGErrorStringFromErrorCode(ffmpegRet)); - } -} - } // namespace facebook::torchcodec diff --git a/src/torchcodec/_core/VideoDecoder.h b/src/torchcodec/_core/VideoDecoder.h index 5b95a648..1a596e9e 100644 --- a/src/torchcodec/_core/VideoDecoder.h +++ b/src/torchcodec/_core/VideoDecoder.h @@ -588,34 +588,4 @@ std::ostream& operator<<( VideoDecoder::SeekMode seekModeFromString(std::string_view seekMode); -class Encoder { - public: - ~Encoder(); - - // TODO Are we OK passing a string_view to the constructor? - // TODO fileName should be optional. - // TODO doesn't make much sense to pass fileName and the wf tensor in 2 - // different calls. Same with sampleRate. - Encoder(int sampleRate, std::string_view fileName); - void encode(const torch::Tensor& wf); - - private: - void encode_inner_loop( - AutoAVPacket& autoAVPacket, - const UniqueAVFrame& avFrame); - - UniqueAVFormatContextForEncoding avFormatContext_; - UniqueAVCodecContext avCodecContext_; - AVStream* avStream_; - - // The *output* sample rate. We can't really decide for the user what it - // should be. Particularly, the sample rate of the input waveform should match - // this, and that's up to the user. If sample rates don't match, encoding will - // still work but audio will be distorted. - // We technically could let the user also specify the input sample rate, and - // resample the waveform internally to match them, but that's not in scope for - // an initial version (if at all). - int sampleRate_; -}; - } // namespace facebook::torchcodec diff --git a/src/torchcodec/_core/custom_ops.cpp b/src/torchcodec/_core/custom_ops.cpp index d7206931..e37618cc 100644 --- a/src/torchcodec/_core/custom_ops.cpp +++ b/src/torchcodec/_core/custom_ops.cpp @@ -11,6 +11,7 @@ #include "c10/core/SymIntArrayRef.h" #include "c10/util/Exception.h" #include "src/torchcodec/_core/AVIOBytesContext.h" +#include "src/torchcodec/_core/Encoder.h" #include "src/torchcodec/_core/VideoDecoder.h" namespace facebook::torchcodec { From 01dc1b1563b727a9d725e3b9137617f091936117 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Wed, 2 Apr 2025 10:57:47 +0100 Subject: [PATCH 09/18] NULL -> nullptr --- src/torchcodec/_core/Encoder.cpp | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/torchcodec/_core/Encoder.cpp b/src/torchcodec/_core/Encoder.cpp index e5c4cf18..67def23f 100644 --- a/src/torchcodec/_core/Encoder.cpp +++ b/src/torchcodec/_core/Encoder.cpp @@ -13,7 +13,7 @@ Encoder::~Encoder() {} Encoder::Encoder(int sampleRate, std::string_view fileName) : sampleRate_(sampleRate) { AVFormatContext* avFormatContext = nullptr; - avformat_alloc_output_context2(&avFormatContext, NULL, NULL, fileName.data()); + avformat_alloc_output_context2(&avFormatContext, nullptr, nullptr, fileName.data()); TORCH_CHECK(avFormatContext != nullptr, "Couldn't allocate AVFormatContext."); avFormatContext_.reset(avFormatContext); @@ -71,7 +71,7 @@ Encoder::Encoder(int sampleRate, std::string_view fileName) // We're allocating the stream here. Streams are meant to be freed by // avformat_free_context(avFormatContext), which we call in the // avFormatContext_'s destructor. - avStream_ = avformat_new_stream(avFormatContext_.get(), NULL); + avStream_ = avformat_new_stream(avFormatContext_.get(), nullptr); TORCH_CHECK(avStream_ != nullptr, "Couldn't create new stream."); avcodec_parameters_from_context(avStream_->codecpar, avCodecContext_.get()); } @@ -117,7 +117,7 @@ void Encoder::encode(const torch::Tensor& wf) { AV_NUM_DATA_POINTERS, " channels per frame."); - ffmpegRet = avformat_write_header(avFormatContext_.get(), NULL); + ffmpegRet = avformat_write_header(avFormatContext_.get(), nullptr); TORCH_CHECK( ffmpegRet == AVSUCCESS, "Error in avformat_write_header: ", From 691dde73a02b8a67ccca1a1ea2d769f288dc4b9a Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Wed, 2 Apr 2025 11:17:43 +0100 Subject: [PATCH 10/18] Use 'status' instead of ffmpegRet --- src/torchcodec/_core/Encoder.cpp | 78 ++++++++++++++++---------------- test/decoders/test_ops.py | 29 +++++++++--- 2 files changed, 62 insertions(+), 45 deletions(-) diff --git a/src/torchcodec/_core/Encoder.cpp b/src/torchcodec/_core/Encoder.cpp index 67def23f..fdee4c5f 100644 --- a/src/torchcodec/_core/Encoder.cpp +++ b/src/torchcodec/_core/Encoder.cpp @@ -13,19 +13,20 @@ Encoder::~Encoder() {} Encoder::Encoder(int sampleRate, std::string_view fileName) : sampleRate_(sampleRate) { AVFormatContext* avFormatContext = nullptr; - avformat_alloc_output_context2(&avFormatContext, nullptr, nullptr, fileName.data()); + avformat_alloc_output_context2( + &avFormatContext, nullptr, nullptr, fileName.data()); TORCH_CHECK(avFormatContext != nullptr, "Couldn't allocate AVFormatContext."); avFormatContext_.reset(avFormatContext); TORCH_CHECK( !(avFormatContext->oformat->flags & AVFMT_NOFILE), "AVFMT_NOFILE is set. We only support writing to a file."); - auto ffmpegRet = + auto status = avio_open(&avFormatContext_->pb, fileName.data(), AVIO_FLAG_WRITE); TORCH_CHECK( - ffmpegRet >= 0, + status >= 0, "avio_open failed: ", - getFFMPEGErrorStringFromErrorCode(ffmpegRet)); + getFFMPEGErrorStringFromErrorCode(status)); // We use the AVFormatContext's default codec for that // specificavcodec_parameters_from_context format/container. @@ -39,7 +40,8 @@ Encoder::Encoder(int sampleRate, std::string_view fileName) // This will use the default bit rate // TODO-ENCODING Should let user choose for compressed formats like mp3. - avCodecContext_->bit_rate = 0; + // avCodecContext_->bit_rate = 0; + avCodecContext_->bit_rate = 24000; // FFmpeg will raise a reasonably informative error if the desired sample rate // isn't supported by the encoder. @@ -58,9 +60,8 @@ Encoder::Encoder(int sampleRate, std::string_view fileName) av_channel_layout_default(&channel_layout, 2); avCodecContext_->ch_layout = channel_layout; - ffmpegRet = avcodec_open2(avCodecContext_.get(), avCodec, nullptr); - TORCH_CHECK( - ffmpegRet == AVSUCCESS, getFFMPEGErrorStringFromErrorCode(ffmpegRet)); + status = avcodec_open2(avCodecContext_.get(), avCodec, nullptr); + TORCH_CHECK(status == AVSUCCESS, getFFMPEGErrorStringFromErrorCode(status)); TORCH_CHECK( avCodecContext_->frame_size > 0, @@ -83,18 +84,18 @@ void Encoder::encode(const torch::Tensor& wf) { avFrame->format = avCodecContext_->sample_fmt; avFrame->sample_rate = avCodecContext_->sample_rate; avFrame->pts = 0; - auto ffmpegRet = + auto status = av_channel_layout_copy(&avFrame->ch_layout, &avCodecContext_->ch_layout); TORCH_CHECK( - ffmpegRet == AVSUCCESS, + status == AVSUCCESS, "Couldn't copy channel layout to avFrame: ", - getFFMPEGErrorStringFromErrorCode(ffmpegRet)); + getFFMPEGErrorStringFromErrorCode(status)); - ffmpegRet = av_frame_get_buffer(avFrame.get(), 0); + status = av_frame_get_buffer(avFrame.get(), 0); TORCH_CHECK( - ffmpegRet == AVSUCCESS, + status == AVSUCCESS, "Couldn't allocate avFrame's buffers: ", - getFFMPEGErrorStringFromErrorCode(ffmpegRet)); + getFFMPEGErrorStringFromErrorCode(status)); AutoAVPacket autoAVPacket; @@ -117,18 +118,18 @@ void Encoder::encode(const torch::Tensor& wf) { AV_NUM_DATA_POINTERS, " channels per frame."); - ffmpegRet = avformat_write_header(avFormatContext_.get(), nullptr); + status = avformat_write_header(avFormatContext_.get(), nullptr); TORCH_CHECK( - ffmpegRet == AVSUCCESS, + status == AVSUCCESS, "Error in avformat_write_header: ", - getFFMPEGErrorStringFromErrorCode(ffmpegRet)); + getFFMPEGErrorStringFromErrorCode(status)); while (numEncodedSamples < numSamples) { - ffmpegRet = av_frame_make_writable(avFrame.get()); + status = av_frame_make_writable(avFrame.get()); TORCH_CHECK( - ffmpegRet == AVSUCCESS, + status == AVSUCCESS, "Couldn't make AVFrame writable: ", - getFFMPEGErrorStringFromErrorCode(ffmpegRet)); + getFFMPEGErrorStringFromErrorCode(status)); auto numSamplesToEncode = std::min(numSamplesPerFrame, numSamples - numEncodedSamples); @@ -148,52 +149,51 @@ void Encoder::encode(const torch::Tensor& wf) { encode_inner_loop(autoAVPacket, UniqueAVFrame(nullptr)); // flush - ffmpegRet = av_write_trailer(avFormatContext_.get()); + status = av_write_trailer(avFormatContext_.get()); TORCH_CHECK( - ffmpegRet == AVSUCCESS, + status == AVSUCCESS, "Error in: av_write_trailer", - getFFMPEGErrorStringFromErrorCode(ffmpegRet)); + getFFMPEGErrorStringFromErrorCode(status)); } void Encoder::encode_inner_loop( AutoAVPacket& autoAVPacket, const UniqueAVFrame& avFrame) { - auto ffmpegRet = avcodec_send_frame(avCodecContext_.get(), avFrame.get()); + auto status = avcodec_send_frame(avCodecContext_.get(), avFrame.get()); TORCH_CHECK( - ffmpegRet == AVSUCCESS, + status == AVSUCCESS, "Error while sending frame: ", - getFFMPEGErrorStringFromErrorCode(ffmpegRet)); + getFFMPEGErrorStringFromErrorCode(status)); - while (ffmpegRet >= 0) { + while (status >= 0) { ReferenceAVPacket packet(autoAVPacket); - ffmpegRet = avcodec_receive_packet(avCodecContext_.get(), packet.get()); - if (ffmpegRet == AVERROR(EAGAIN) || ffmpegRet == AVERROR_EOF) { + status = avcodec_receive_packet(avCodecContext_.get(), packet.get()); + if (status == AVERROR(EAGAIN) || status == AVERROR_EOF) { // TODO-ENCODING this is from TorchAudio, probably needed, but not sure. - // if (ffmpegRet == AVERROR_EOF) { - // ffmpegRet = av_interleaved_write_frame(avFormatContext_.get(), + // if (status == AVERROR_EOF) { + // status = av_interleaved_write_frame(avFormatContext_.get(), // nullptr); TORCH_CHECK( - // ffmpegRet == AVSUCCESS, + // status == AVSUCCESS, // "Failed to flush packet ", - // getFFMPEGErrorStringFromErrorCode(ffmpegRet)); + // getFFMPEGErrorStringFromErrorCode(status)); // } return; } TORCH_CHECK( - ffmpegRet >= 0, + status >= 0, "Error receiving packet: ", - getFFMPEGErrorStringFromErrorCode(ffmpegRet)); + getFFMPEGErrorStringFromErrorCode(status)); // TODO-ENCODING why are these 2 lines needed?? av_packet_rescale_ts( packet.get(), avCodecContext_->time_base, avStream_->time_base); packet->stream_index = avStream_->index; - ffmpegRet = - av_interleaved_write_frame(avFormatContext_.get(), packet.get()); + status = av_interleaved_write_frame(avFormatContext_.get(), packet.get()); TORCH_CHECK( - ffmpegRet == AVSUCCESS, + status == AVSUCCESS, "Error in av_interleaved_write_frame: ", - getFFMPEGErrorStringFromErrorCode(ffmpegRet)); + getFFMPEGErrorStringFromErrorCode(status)); } } } // namespace facebook::torchcodec diff --git a/test/decoders/test_ops.py b/test/decoders/test_ops.py index 71bebd95..7cb61349 100644 --- a/test/decoders/test_ops.py +++ b/test/decoders/test_ops.py @@ -942,16 +942,33 @@ def decode(self, source) -> torch.Tensor: def test_round_trip(self, tmp_path): asset = SINE_MONO_S32 - source_samples = self.decode(asset) - output_file = tmp_path / "output.mp3" + encoded_by_ffmpeg = tmp_path / "ffmpeg_output.mp3" + encoded_by_us = tmp_path / "our_output.mp3" + + command = [ + "ffmpeg", + "-i", + str(asset.path), + # '-vn', + # '-ar', '44100', # Set audio sampling rate + # '-ac', '2', # Set number of audio channels + # '-b:a', '192k', # Set audio bitrate + str(encoded_by_ffmpeg), + ] + subprocess.run(command, check=True) + encoder = create_encoder( - sample_rate=asset.sample_rate, filename=str(output_file) + sample_rate=asset.sample_rate, filename=str(encoded_by_us) ) - encode(encoder, source_samples) - round_trip_samples = self.decode(output_file) - torch.testing.assert_close(source_samples, round_trip_samples) + encode(encoder, self.decode(asset)) + + print(encoded_by_ffmpeg) + print(encoded_by_us) + from_ffmpeg = self.decode(encoded_by_ffmpeg) + from_us = self.decode(encoded_by_us) + torch.testing.assert_close(from_us, from_ffmpeg) if __name__ == "__main__": From eb2a86cfe3f3503309baaa9e25f577cca36d6e3c Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Wed, 2 Apr 2025 13:05:29 +0100 Subject: [PATCH 11/18] Stuff --- src/torchcodec/_core/Encoder.cpp | 12 ++++++++++-- test/decoders/test_ops.py | 23 ++++++++++++++++++----- 2 files changed, 28 insertions(+), 7 deletions(-) diff --git a/src/torchcodec/_core/Encoder.cpp b/src/torchcodec/_core/Encoder.cpp index fdee4c5f..f537b887 100644 --- a/src/torchcodec/_core/Encoder.cpp +++ b/src/torchcodec/_core/Encoder.cpp @@ -10,6 +10,8 @@ namespace facebook::torchcodec { Encoder::~Encoder() {} +// TODO-ENCODING: disable ffmpeg logs by default + Encoder::Encoder(int sampleRate, std::string_view fileName) : sampleRate_(sampleRate) { AVFormatContext* avFormatContext = nullptr; @@ -40,8 +42,8 @@ Encoder::Encoder(int sampleRate, std::string_view fileName) // This will use the default bit rate // TODO-ENCODING Should let user choose for compressed formats like mp3. - // avCodecContext_->bit_rate = 0; - avCodecContext_->bit_rate = 24000; + // avCodecContext_->bit_rate = 64000; + avCodecContext_->bit_rate = 0; // FFmpeg will raise a reasonably informative error if the desired sample rate // isn't supported by the encoder. @@ -134,6 +136,7 @@ void Encoder::encode(const torch::Tensor& wf) { auto numSamplesToEncode = std::min(numSamplesPerFrame, numSamples - numEncodedSamples); auto numBytesToEncode = numSamplesToEncode * numBytesPerSample; + avFrame->nb_samples = std::min(static_cast(avCodecContext_->frame_size), numSamplesToEncode); for (int ch = 0; ch < numChannels; ch++) { memcpy( @@ -160,6 +163,11 @@ void Encoder::encode_inner_loop( AutoAVPacket& autoAVPacket, const UniqueAVFrame& avFrame) { auto status = avcodec_send_frame(avCodecContext_.get(), avFrame.get()); +// if (avFrame.get()) { +// printf("Sending frame with %d samples\n", avFrame->nb_samples); +// } else { +// printf("Flushing\n"); +// } TORCH_CHECK( status == AVSUCCESS, "Error while sending frame: ", diff --git a/test/decoders/test_ops.py b/test/decoders/test_ops.py index 7cb61349..8b2ce243 100644 --- a/test/decoders/test_ops.py +++ b/test/decoders/test_ops.py @@ -940,8 +940,22 @@ def decode(self, source) -> torch.Tensor: ) return frames - def test_round_trip(self, tmp_path): - asset = SINE_MONO_S32 + # def test_round_trip(self, tmp_path): + # asset = NASA_AUDIO_MP3 + + # encoded_path = tmp_path / "output.mp3" + # encoder = create_encoder( + # sample_rate=asset.sample_rate, filename=str(encoded_path) + # ) + + # source_samples = self.decode(asset) + # encode(encoder, source_samples) + + # torch.testing.assert_close(self.decode(encoded_path), source_samples) + + def test_against_cli(self, tmp_path): + + asset = NASA_AUDIO_MP3 encoded_by_ffmpeg = tmp_path / "ffmpeg_output.mp3" encoded_by_us = tmp_path / "our_output.mp3" @@ -951,9 +965,10 @@ def test_round_trip(self, tmp_path): "-i", str(asset.path), # '-vn', - # '-ar', '44100', # Set audio sampling rate + # '-ar', '16000', # Set audio sampling rate # '-ac', '2', # Set number of audio channels # '-b:a', '192k', # Set audio bitrate + '-b:a', '0', # Set audio bitrate str(encoded_by_ffmpeg), ] subprocess.run(command, check=True) @@ -964,8 +979,6 @@ def test_round_trip(self, tmp_path): encode(encoder, self.decode(asset)) - print(encoded_by_ffmpeg) - print(encoded_by_us) from_ffmpeg = self.decode(encoded_by_ffmpeg) from_us = self.decode(encoded_by_us) torch.testing.assert_close(from_us, from_ffmpeg) From 3cec7611b887b32704facd89ffa6361b39fc6cd9 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Wed, 2 Apr 2025 14:03:11 +0100 Subject: [PATCH 12/18] Add tests --- src/torchcodec/_core/Encoder.cpp | 73 ++++++++++++++++------------- src/torchcodec/_core/Encoder.h | 9 ++-- src/torchcodec/_core/custom_ops.cpp | 15 +++--- src/torchcodec/_core/ops.py | 6 ++- test/decoders/test_ops.py | 69 +++++++++++++++------------ 5 files changed, 94 insertions(+), 78 deletions(-) diff --git a/src/torchcodec/_core/Encoder.cpp b/src/torchcodec/_core/Encoder.cpp index f537b887..bdeaf420 100644 --- a/src/torchcodec/_core/Encoder.cpp +++ b/src/torchcodec/_core/Encoder.cpp @@ -12,14 +12,19 @@ Encoder::~Encoder() {} // TODO-ENCODING: disable ffmpeg logs by default -Encoder::Encoder(int sampleRate, std::string_view fileName) - : sampleRate_(sampleRate) { +Encoder::Encoder( + const torch::Tensor wf, + int sampleRate, + std::string_view fileName) + : wf_(wf), sampleRate_(sampleRate) { AVFormatContext* avFormatContext = nullptr; avformat_alloc_output_context2( &avFormatContext, nullptr, nullptr, fileName.data()); TORCH_CHECK(avFormatContext != nullptr, "Couldn't allocate AVFormatContext."); avFormatContext_.reset(avFormatContext); + // TODO-ENCODING: Should also support encoding into bytes (use + // AVIOBytesContext) TORCH_CHECK( !(avFormatContext->oformat->flags & AVFMT_NOFILE), "AVFMT_NOFILE is set. We only support writing to a file."); @@ -31,7 +36,7 @@ Encoder::Encoder(int sampleRate, std::string_view fileName) getFFMPEGErrorStringFromErrorCode(status)); // We use the AVFormatContext's default codec for that - // specificavcodec_parameters_from_context format/container. + // specific format/container. const AVCodec* avCodec = avcodec_find_encoder(avFormatContext_->oformat->audio_codec); TORCH_CHECK(avCodec != nullptr, "Codec not found"); @@ -40,9 +45,10 @@ Encoder::Encoder(int sampleRate, std::string_view fileName) TORCH_CHECK(avCodecContext != nullptr, "Couldn't allocate codec context."); avCodecContext_.reset(avCodecContext); - // This will use the default bit rate - // TODO-ENCODING Should let user choose for compressed formats like mp3. - // avCodecContext_->bit_rate = 64000; + // TODO-ENCODING I think this sets the bit rate to the minimum supported. + // That's not what the ffmpeg CLI would choose by default, so we should try to + // do the same. + // TODO-ENCODING Should also let user choose for compressed formats like mp3. avCodecContext_->bit_rate = 0; // FFmpeg will raise a reasonably informative error if the desired sample rate @@ -58,8 +64,19 @@ Encoder::Encoder(int sampleRate, std::string_view fileName) // libswresample. avCodecContext_->sample_fmt = AV_SAMPLE_FMT_FLTP; + auto numChannels = wf_.sizes()[0]; + TORCH_CHECK( + // TODO-ENCODING is this even true / needed? We can probably support more + // with non-planar data? + numChannels <= AV_NUM_DATA_POINTERS, + "Trying to encode ", + numChannels, + " channels, but FFmpeg only supports ", + AV_NUM_DATA_POINTERS, + " channels per frame."); + AVChannelLayout channel_layout; - av_channel_layout_default(&channel_layout, 2); + av_channel_layout_default(&channel_layout, numChannels); avCodecContext_->ch_layout = channel_layout; status = avcodec_open2(avCodecContext_.get(), avCodec, nullptr); @@ -79,7 +96,7 @@ Encoder::Encoder(int sampleRate, std::string_view fileName) avcodec_parameters_from_context(avStream_->codecpar, avCodecContext_.get()); } -void Encoder::encode(const torch::Tensor& wf) { +void Encoder::encode() { UniqueAVFrame avFrame(av_frame_alloc()); TORCH_CHECK(avFrame != nullptr, "Couldn't allocate AVFrame."); avFrame->nb_samples = avCodecContext_->frame_size; @@ -101,24 +118,13 @@ void Encoder::encode(const torch::Tensor& wf) { AutoAVPacket autoAVPacket; - uint8_t* pWf = static_cast(wf.data_ptr()); - auto numChannels = wf.sizes()[0]; - auto numSamples = wf.sizes()[1]; // per channel + uint8_t* pwf = static_cast(wf_.data_ptr()); + auto numSamples = wf_.sizes()[1]; // per channel auto numEncodedSamples = 0; // per channel auto numSamplesPerFrame = static_cast(avCodecContext_->frame_size); // per channel - auto numBytesPerSample = wf.element_size(); - auto numBytesPerChannel = wf.sizes()[1] * numBytesPerSample; - - TORCH_CHECK( - // TODO-ENCODING is this even true / needed? We can probably support more - // with non-planar data? - numChannels <= AV_NUM_DATA_POINTERS, - "Trying to encode ", - numChannels, - " channels, but FFmpeg only supports ", - AV_NUM_DATA_POINTERS, - " channels per frame."); + auto numBytesPerSample = wf_.element_size(); + auto numBytesPerChannel = numSamples * numBytesPerSample; status = avformat_write_header(avFormatContext_.get(), nullptr); TORCH_CHECK( @@ -136,16 +142,22 @@ void Encoder::encode(const torch::Tensor& wf) { auto numSamplesToEncode = std::min(numSamplesPerFrame, numSamples - numEncodedSamples); auto numBytesToEncode = numSamplesToEncode * numBytesPerSample; - avFrame->nb_samples = std::min(static_cast(avCodecContext_->frame_size), numSamplesToEncode); - for (int ch = 0; ch < numChannels; ch++) { + for (int ch = 0; ch < wf_.sizes()[0]; ch++) { memcpy( - avFrame->data[ch], pWf + ch * numBytesPerChannel, numBytesToEncode); + avFrame->data[ch], pwf + ch * numBytesPerChannel, numBytesToEncode); } - pWf += numBytesToEncode; + pwf += numBytesToEncode; + + // Above, we set the AVFrame's .nb_samples to AVCodecContext.frame_size so + // that the frame buffers are allocated to a big enough size. Here, we reset + // it to the exact number of samples that need to be encoded, otherwise the + // encoded frame would contain more samples than necessary and our results + // wouldn't match the ffmpeg CLI. + avFrame->nb_samples = numSamplesToEncode; encode_inner_loop(autoAVPacket, avFrame); - avFrame->pts += avFrame->nb_samples; + avFrame->pts += numSamplesToEncode; numEncodedSamples += numSamplesToEncode; } TORCH_CHECK(numEncodedSamples == numSamples, "Hmmmmmm something went wrong."); @@ -163,11 +175,6 @@ void Encoder::encode_inner_loop( AutoAVPacket& autoAVPacket, const UniqueAVFrame& avFrame) { auto status = avcodec_send_frame(avCodecContext_.get(), avFrame.get()); -// if (avFrame.get()) { -// printf("Sending frame with %d samples\n", avFrame->nb_samples); -// } else { -// printf("Flushing\n"); -// } TORCH_CHECK( status == AVSUCCESS, "Error while sending frame: ", diff --git a/src/torchcodec/_core/Encoder.h b/src/torchcodec/_core/Encoder.h index 85176ae1..7e833f29 100644 --- a/src/torchcodec/_core/Encoder.h +++ b/src/torchcodec/_core/Encoder.h @@ -7,12 +7,8 @@ class Encoder { public: ~Encoder(); - // TODO Are we OK passing a string_view to the constructor? - // TODO fileName should be optional. - // TODO doesn't make much sense to pass fileName and the wf tensor in 2 - // different calls. Same with sampleRate. - Encoder(int sampleRate, std::string_view fileName); - void encode(const torch::Tensor& wf); + Encoder(const torch::Tensor wf, int sampleRate, std::string_view fileName); + void encode(); private: void encode_inner_loop( @@ -31,5 +27,6 @@ class Encoder { // resample the waveform internally to match them, but that's not in scope for // an initial version (if at all). int sampleRate_; + const torch::Tensor wf_; }; } // namespace facebook::torchcodec diff --git a/src/torchcodec/_core/custom_ops.cpp b/src/torchcodec/_core/custom_ops.cpp index e37618cc..aae9a1e7 100644 --- a/src/torchcodec/_core/custom_ops.cpp +++ b/src/torchcodec/_core/custom_ops.cpp @@ -28,8 +28,8 @@ TORCH_LIBRARY(torchcodec_ns, m) { m.impl_abstract_pystub( "torchcodec._core.ops", "//pytorch/torchcodec:torchcodec"); m.def("create_from_file(str filename, str? seek_mode=None) -> Tensor"); - m.def("create_encoder(int sample_rate, str filename) -> Tensor"); - m.def("encode(Tensor(a!) encoder, Tensor wf) -> ()"); + m.def("create_encoder(Tensor wf, int sample_rate, str filename) -> Tensor"); + m.def("encode(Tensor(a!) encoder) -> ()"); m.def( "create_from_tensor(Tensor video_tensor, str? seek_mode=None) -> Tensor"); m.def("_convert_to_tensor(int decoder_ptr) -> Tensor"); @@ -194,15 +194,18 @@ at::Tensor create_from_file( return wrapDecoderPointerToTensor(std::move(uniqueDecoder)); } -at::Tensor create_encoder(int64_t sample_rate, std::string_view file_name) { +at::Tensor create_encoder( + const at::Tensor wf, + int64_t sample_rate, + std::string_view file_name) { std::unique_ptr uniqueEncoder = - std::make_unique(static_cast(sample_rate), file_name); + std::make_unique(wf, static_cast(sample_rate), file_name); return wrapEncoderPointerToTensor(std::move(uniqueEncoder)); } -void encode(at::Tensor& encoder, const at::Tensor& wf) { +void encode(at::Tensor& encoder) { auto encoder_ = unwrapTensorToGetEncoder(encoder); - encoder_->encode(wf); + encoder_->encode(); } // Create a VideoDecoder from the actual bytes of a video and wrap the pointer diff --git a/src/torchcodec/_core/ops.py b/src/torchcodec/_core/ops.py index 529a0bb8..47063aff 100644 --- a/src/torchcodec/_core/ops.py +++ b/src/torchcodec/_core/ops.py @@ -160,12 +160,14 @@ def create_from_file_abstract(filename: str, seek_mode: Optional[str]) -> torch. @register_fake("torchcodec_ns::create_encoder") -def create_encoder_abstract(sample_rate: int, filename: str) -> torch.Tensor: +def create_encoder_abstract( + wf: torch.Tensor, sample_rate: int, filename: str +) -> torch.Tensor: return torch.empty([], dtype=torch.long) @register_fake("torchcodec_ns::encode") -def encode_abstract(encoder: torch.Tensor, wf: torch.Tensor) -> torch.Tensor: +def encode_abstract(encoder: torch.Tensor) -> torch.Tensor: return torch.empty([], dtype=torch.long) diff --git a/test/decoders/test_ops.py b/test/decoders/test_ops.py index 8b2ce243..3a6af0fb 100644 --- a/test/decoders/test_ops.py +++ b/test/decoders/test_ops.py @@ -940,48 +940,55 @@ def decode(self, source) -> torch.Tensor: ) return frames - # def test_round_trip(self, tmp_path): - # asset = NASA_AUDIO_MP3 - - # encoded_path = tmp_path / "output.mp3" - # encoder = create_encoder( - # sample_rate=asset.sample_rate, filename=str(encoded_path) - # ) - - # source_samples = self.decode(asset) - # encode(encoder, source_samples) + def test_round_trip(self, tmp_path): + # Check that decode(encode(samples)) == samples + asset = NASA_AUDIO_MP3 + source_samples = self.decode(asset) - # torch.testing.assert_close(self.decode(encoded_path), source_samples) + encoded_path = tmp_path / "output.mp3" + encoder = create_encoder( + wf=source_samples, sample_rate=asset.sample_rate, filename=str(encoded_path) + ) + encode(encoder) - def test_against_cli(self, tmp_path): + # TODO-ENCODING: tol should be stricter. We need to increase the encoded + # bitrate, and / or encode into a lossless format. + torch.testing.assert_close( + self.decode(encoded_path), source_samples, rtol=0, atol=0.07 + ) - asset = NASA_AUDIO_MP3 + # TODO-ENCODING: test more encoding formats + @pytest.mark.parametrize("asset", (NASA_AUDIO_MP3, SINE_MONO_S32)) + def test_against_cli(self, asset, tmp_path): + # Encodes samples with our encoder and with the FFmpeg CLI, and checks + # that both decoded outputs are equal encoded_by_ffmpeg = tmp_path / "ffmpeg_output.mp3" encoded_by_us = tmp_path / "our_output.mp3" - command = [ - "ffmpeg", - "-i", - str(asset.path), - # '-vn', - # '-ar', '16000', # Set audio sampling rate - # '-ac', '2', # Set number of audio channels - # '-b:a', '192k', # Set audio bitrate - '-b:a', '0', # Set audio bitrate - str(encoded_by_ffmpeg), - ] - subprocess.run(command, check=True) + subprocess.run( + [ + "ffmpeg", + "-i", + str(asset.path), + "-b:a", + "0", # bitrate hardcoded to 0, see corresponding TODO. + str(encoded_by_ffmpeg), + ], + capture_output=True, + check=True, + ) encoder = create_encoder( - sample_rate=asset.sample_rate, filename=str(encoded_by_us) + wf=self.decode(asset), + sample_rate=asset.sample_rate, + filename=str(encoded_by_us), ) + encode(encoder) - encode(encoder, self.decode(asset)) - - from_ffmpeg = self.decode(encoded_by_ffmpeg) - from_us = self.decode(encoded_by_us) - torch.testing.assert_close(from_us, from_ffmpeg) + torch.testing.assert_close( + self.decode(encoded_by_ffmpeg), self.decode(encoded_by_us) + ) if __name__ == "__main__": From 8c5479c6d57f0ccf435a5194e91d67ca16234fec Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Wed, 2 Apr 2025 14:09:06 +0100 Subject: [PATCH 13/18] Flags --- src/torchcodec/_core/CMakeLists.txt | 3 +-- src/torchcodec/_core/Encoder.h | 2 +- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/src/torchcodec/_core/CMakeLists.txt b/src/torchcodec/_core/CMakeLists.txt index 465e893d..464d185a 100644 --- a/src/torchcodec/_core/CMakeLists.txt +++ b/src/torchcodec/_core/CMakeLists.txt @@ -8,8 +8,7 @@ find_package(pybind11 REQUIRED) find_package(Torch REQUIRED) find_package(Python3 ${PYTHON_VERSION} EXACT COMPONENTS Development) -# set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wall -Wextra -pedantic -Werror ${TORCH_CXX_FLAGS}") -set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${TORCH_CXX_FLAGS}") +set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wall -Wextra -pedantic -Werror ${TORCH_CXX_FLAGS}") function(make_torchcodec_sublibrary library_name diff --git a/src/torchcodec/_core/Encoder.h b/src/torchcodec/_core/Encoder.h index 7e833f29..112ac83f 100644 --- a/src/torchcodec/_core/Encoder.h +++ b/src/torchcodec/_core/Encoder.h @@ -19,6 +19,7 @@ class Encoder { UniqueAVCodecContext avCodecContext_; AVStream* avStream_; + const torch::Tensor wf_; // The *output* sample rate. We can't really decide for the user what it // should be. Particularly, the sample rate of the input waveform should match // this, and that's up to the user. If sample rates don't match, encoding will @@ -27,6 +28,5 @@ class Encoder { // resample the waveform internally to match them, but that's not in scope for // an initial version (if at all). int sampleRate_; - const torch::Tensor wf_; }; } // namespace facebook::torchcodec From f609052fdb96678ac6377303bbc68fd16007a664 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Wed, 2 Apr 2025 14:45:06 +0100 Subject: [PATCH 14/18] hopefully fix ffmpeg4 --- src/torchcodec/_core/Encoder.cpp | 15 ++++--------- src/torchcodec/_core/FFMPEGCommon.cpp | 32 +++++++++++++++++++++++++++ src/torchcodec/_core/FFMPEGCommon.h | 8 +++++++ 3 files changed, 44 insertions(+), 11 deletions(-) diff --git a/src/torchcodec/_core/Encoder.cpp b/src/torchcodec/_core/Encoder.cpp index bdeaf420..f0fc8464 100644 --- a/src/torchcodec/_core/Encoder.cpp +++ b/src/torchcodec/_core/Encoder.cpp @@ -64,7 +64,7 @@ Encoder::Encoder( // libswresample. avCodecContext_->sample_fmt = AV_SAMPLE_FMT_FLTP; - auto numChannels = wf_.sizes()[0]; + int numChannels = static_cast(wf_.sizes()[0]); TORCH_CHECK( // TODO-ENCODING is this even true / needed? We can probably support more // with non-planar data? @@ -75,9 +75,7 @@ Encoder::Encoder( AV_NUM_DATA_POINTERS, " channels per frame."); - AVChannelLayout channel_layout; - av_channel_layout_default(&channel_layout, numChannels); - avCodecContext_->ch_layout = channel_layout; + setDefaultChannelLayout(avCodecContext_, numChannels); status = avcodec_open2(avCodecContext_.get(), avCodec, nullptr); TORCH_CHECK(status == AVSUCCESS, getFFMPEGErrorStringFromErrorCode(status)); @@ -103,14 +101,9 @@ void Encoder::encode() { avFrame->format = avCodecContext_->sample_fmt; avFrame->sample_rate = avCodecContext_->sample_rate; avFrame->pts = 0; - auto status = - av_channel_layout_copy(&avFrame->ch_layout, &avCodecContext_->ch_layout); - TORCH_CHECK( - status == AVSUCCESS, - "Couldn't copy channel layout to avFrame: ", - getFFMPEGErrorStringFromErrorCode(status)); + setChannelLayout(avFrame, avCodecContext_); - status = av_frame_get_buffer(avFrame.get(), 0); + auto status = av_frame_get_buffer(avFrame.get(), 0); TORCH_CHECK( status == AVSUCCESS, "Couldn't allocate avFrame's buffers: ", diff --git a/src/torchcodec/_core/FFMPEGCommon.cpp b/src/torchcodec/_core/FFMPEGCommon.cpp index 33c8b484..96517d3b 100644 --- a/src/torchcodec/_core/FFMPEGCommon.cpp +++ b/src/torchcodec/_core/FFMPEGCommon.cpp @@ -74,6 +74,38 @@ int getNumChannels(const UniqueAVCodecContext& avCodecContext) { #endif } +void setDefaultChannelLayout( + UniqueAVCodecContext& avCodecContext, + int numChannels) { +#if LIBAVFILTER_VERSION_MAJOR > 7 // FFmpeg > 4 + AVChannelLayout channel_layout; + av_channel_layout_default(&channel_layout, numChannels); + avCodecContext->ch_layout = channel_layout; + +#else + uint64_t channel_layout = av_get_default_channel_layout(numChannels); + avCodecContext->channel_layout = channel_layout; + avCodecContext->channels = numChannels; +#endif +} + +void setChannelLayout( + UniqueAVFrame& dstAVFrame, + const UniqueAVCodecContext& avCodecContext) { +#if LIBAVFILTER_VERSION_MAJOR > 7 // FFmpeg > 4 + auto status = av_channel_layout_copy( + &dstAvFrame->ch_layout, &avCodecContext->ch_layout); + TORCH_CHECK( + status == AVSUCCESS, + "Couldn't copy channel layout to avFrame: ", + getFFMPEGErrorStringFromErrorCode(status)); +#else + dstAVFrame->channel_layout = avCodecContext->channel_layout; + dstAVFrame->channels = avCodecContext->channels; + +#endif +} + void setChannelLayout( UniqueAVFrame& dstAVFrame, const UniqueAVFrame& srcAVFrame) { diff --git a/src/torchcodec/_core/FFMPEGCommon.h b/src/torchcodec/_core/FFMPEGCommon.h index 9fe9ea4f..790a4043 100644 --- a/src/torchcodec/_core/FFMPEGCommon.h +++ b/src/torchcodec/_core/FFMPEGCommon.h @@ -147,6 +147,14 @@ int64_t getDuration(const UniqueAVFrame& frame); int getNumChannels(const UniqueAVFrame& avFrame); int getNumChannels(const UniqueAVCodecContext& avCodecContext); +void setDefaultChannelLayout( + UniqueAVCodecContext& avCodecContext, + int numChannels); + +void setChannelLayout( + UniqueAVFrame& dstAVFrame, + const UniqueAVCodecContext& avCodecContext); + void setChannelLayout( UniqueAVFrame& dstAVFrame, const UniqueAVFrame& srcAVFrame); From 42f5160df0c9e8fa57ca390f11e42048b7743f9d Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Wed, 2 Apr 2025 15:06:16 +0100 Subject: [PATCH 15/18] Fix MacOS build?? --- src/torchcodec/_core/Encoder.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/torchcodec/_core/Encoder.cpp b/src/torchcodec/_core/Encoder.cpp index f0fc8464..3544dc6d 100644 --- a/src/torchcodec/_core/Encoder.cpp +++ b/src/torchcodec/_core/Encoder.cpp @@ -132,8 +132,8 @@ void Encoder::encode() { "Couldn't make AVFrame writable: ", getFFMPEGErrorStringFromErrorCode(status)); - auto numSamplesToEncode = - std::min(numSamplesPerFrame, numSamples - numEncodedSamples); + auto numSamplesToEncode = std::min( + numSamplesPerFrame, static_cast(numSamples - numEncodedSamples)); auto numBytesToEncode = numSamplesToEncode * numBytesPerSample; for (int ch = 0; ch < wf_.sizes()[0]; ch++) { From 52c4d5409b8d41132ce3c8c242ca5bfe07c7175b Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Wed, 2 Apr 2025 15:34:40 +0100 Subject: [PATCH 16/18] more tests --- src/torchcodec/_core/Encoder.cpp | 13 ++++++++++--- test/decoders/test_ops.py | 31 +++++++++++++++++++++++++++++++ 2 files changed, 41 insertions(+), 3 deletions(-) diff --git a/src/torchcodec/_core/Encoder.cpp b/src/torchcodec/_core/Encoder.cpp index 3544dc6d..fd513628 100644 --- a/src/torchcodec/_core/Encoder.cpp +++ b/src/torchcodec/_core/Encoder.cpp @@ -17,10 +17,19 @@ Encoder::Encoder( int sampleRate, std::string_view fileName) : wf_(wf), sampleRate_(sampleRate) { + TORCH_CHECK( + wf_.dtype() == torch::kFloat32, + "waveform must have float32 dtype, got ", + wf_.dtype()); + TORCH_CHECK( + wf_.dim() == 2, "waveform must have 2 dimensions, got ", wf_.dim()); AVFormatContext* avFormatContext = nullptr; avformat_alloc_output_context2( &avFormatContext, nullptr, nullptr, fileName.data()); - TORCH_CHECK(avFormatContext != nullptr, "Couldn't allocate AVFormatContext."); + TORCH_CHECK( + avFormatContext != nullptr, + "Couldn't allocate AVFormatContext. ", + "Check the desired extension?"); avFormatContext_.reset(avFormatContext); // TODO-ENCODING: Should also support encoding into bytes (use @@ -51,8 +60,6 @@ Encoder::Encoder( // TODO-ENCODING Should also let user choose for compressed formats like mp3. avCodecContext_->bit_rate = 0; - // FFmpeg will raise a reasonably informative error if the desired sample rate - // isn't supported by the encoder. avCodecContext_->sample_rate = sampleRate_; // Note: This is the format of the **input** waveform. This doesn't determine diff --git a/test/decoders/test_ops.py b/test/decoders/test_ops.py index 3c86d815..db2fcaed 100644 --- a/test/decoders/test_ops.py +++ b/test/decoders/test_ops.py @@ -1011,6 +1011,37 @@ def decode(self, source) -> torch.Tensor: ) return frames + def test_bad_input(self, tmp_path): + + valid_output_file = str(tmp_path / ".mp3") + + with pytest.raises(RuntimeError, match="must have float32 dtype, got int"): + create_encoder( + wf=torch.arange(10, dtype=torch.int), + sample_rate=10, + filename=valid_output_file, + ) + with pytest.raises(RuntimeError, match="must have 2 dimensions, got 1"): + create_encoder(wf=torch.rand(3), sample_rate=10, filename=valid_output_file) + + with pytest.raises(RuntimeError, match="No such file or directory"): + create_encoder( + wf=torch.rand(10, 10), sample_rate=10, filename="./bad/path.mp3" + ) + with pytest.raises(RuntimeError, match="Check the desired extension"): + create_encoder( + wf=torch.rand(10, 10), sample_rate=10, filename="./file.bad_extension" + ) + + # TODO-ENCODING: raise more informative error message when sample rate + # isn't supported + with pytest.raises(RuntimeError, match="Invalid argument"): + create_encoder( + wf=self.decode(NASA_AUDIO_MP3), + sample_rate=10, + filename=valid_output_file, + ) + def test_round_trip(self, tmp_path): # Check that decode(encode(samples)) == samples asset = NASA_AUDIO_MP3 From 061c60f5848b99b7c671f5e6a46c100a4d5f07a8 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Thu, 3 Apr 2025 10:20:49 +0100 Subject: [PATCH 17/18] Address some comments --- src/torchcodec/_core/Encoder.cpp | 73 +++++++++++--------- src/torchcodec/_core/Encoder.h | 16 +++-- src/torchcodec/_core/FFMPEGCommon.h | 4 +- src/torchcodec/_core/SingleStreamDecoder.cpp | 2 +- src/torchcodec/_core/SingleStreamDecoder.h | 2 +- src/torchcodec/_core/__init__.py | 4 +- src/torchcodec/_core/custom_ops.cpp | 46 +++++++----- src/torchcodec/_core/ops.py | 16 +++-- test/test_ops.py | 24 ++++--- 9 files changed, 104 insertions(+), 83 deletions(-) diff --git a/src/torchcodec/_core/Encoder.cpp b/src/torchcodec/_core/Encoder.cpp index fd513628..a86423a2 100644 --- a/src/torchcodec/_core/Encoder.cpp +++ b/src/torchcodec/_core/Encoder.cpp @@ -1,18 +1,13 @@ #include "src/torchcodec/_core/Encoder.h" #include "torch/types.h" -extern "C" { -#include -#include -} - namespace facebook::torchcodec { -Encoder::~Encoder() {} +AudioEncoder::~AudioEncoder() {} // TODO-ENCODING: disable ffmpeg logs by default -Encoder::Encoder( +AudioEncoder::AudioEncoder( const torch::Tensor wf, int sampleRate, std::string_view fileName) @@ -24,12 +19,13 @@ Encoder::Encoder( TORCH_CHECK( wf_.dim() == 2, "waveform must have 2 dimensions, got ", wf_.dim()); AVFormatContext* avFormatContext = nullptr; - avformat_alloc_output_context2( + auto status = avformat_alloc_output_context2( &avFormatContext, nullptr, nullptr, fileName.data()); TORCH_CHECK( avFormatContext != nullptr, "Couldn't allocate AVFormatContext. ", - "Check the desired extension?"); + "Check the desired extension? ", + getFFMPEGErrorStringFromErrorCode(status)); avFormatContext_.reset(avFormatContext); // TODO-ENCODING: Should also support encoding into bytes (use @@ -37,8 +33,7 @@ Encoder::Encoder( TORCH_CHECK( !(avFormatContext->oformat->flags & AVFMT_NOFILE), "AVFMT_NOFILE is set. We only support writing to a file."); - auto status = - avio_open(&avFormatContext_->pb, fileName.data(), AVIO_FLAG_WRITE); + status = avio_open(&avFormatContext_->pb, fileName.data(), AVIO_FLAG_WRITE); TORCH_CHECK( status >= 0, "avio_open failed: ", @@ -85,7 +80,10 @@ Encoder::Encoder( setDefaultChannelLayout(avCodecContext_, numChannels); status = avcodec_open2(avCodecContext_.get(), avCodec, nullptr); - TORCH_CHECK(status == AVSUCCESS, getFFMPEGErrorStringFromErrorCode(status)); + TORCH_CHECK( + status == AVSUCCESS, + "avcodec_open2 failed: ", + getFFMPEGErrorStringFromErrorCode(status)); TORCH_CHECK( avCodecContext_->frame_size > 0, @@ -96,12 +94,18 @@ Encoder::Encoder( // We're allocating the stream here. Streams are meant to be freed by // avformat_free_context(avFormatContext), which we call in the // avFormatContext_'s destructor. - avStream_ = avformat_new_stream(avFormatContext_.get(), nullptr); - TORCH_CHECK(avStream_ != nullptr, "Couldn't create new stream."); - avcodec_parameters_from_context(avStream_->codecpar, avCodecContext_.get()); + AVStream* avStream = avformat_new_stream(avFormatContext_.get(), nullptr); + TORCH_CHECK(avStream != nullptr, "Couldn't create new stream."); + status = avcodec_parameters_from_context( + avStream->codecpar, avCodecContext_.get()); + TORCH_CHECK( + status == AVSUCCESS, + "avcodec_parameters_from_context failed: ", + getFFMPEGErrorStringFromErrorCode(status)); + streamIndex_ = avStream->index; } -void Encoder::encode() { +void AudioEncoder::encode() { UniqueAVFrame avFrame(av_frame_alloc()); TORCH_CHECK(avFrame != nullptr, "Couldn't allocate AVFrame."); avFrame->nb_samples = avCodecContext_->frame_size; @@ -119,12 +123,11 @@ void Encoder::encode() { AutoAVPacket autoAVPacket; uint8_t* pwf = static_cast(wf_.data_ptr()); - auto numSamples = wf_.sizes()[1]; // per channel - auto numEncodedSamples = 0; // per channel - auto numSamplesPerFrame = - static_cast(avCodecContext_->frame_size); // per channel - auto numBytesPerSample = wf_.element_size(); - auto numBytesPerChannel = numSamples * numBytesPerSample; + int numSamples = static_cast(wf_.sizes()[1]); // per channel + int numEncodedSamples = 0; // per channel + int numSamplesPerFrame = avCodecContext_->frame_size; // per channel + int numBytesPerSample = wf_.element_size(); + int numBytesPerChannel = numSamples * numBytesPerSample; status = avformat_write_header(avFormatContext_.get(), nullptr); TORCH_CHECK( @@ -139,12 +142,12 @@ void Encoder::encode() { "Couldn't make AVFrame writable: ", getFFMPEGErrorStringFromErrorCode(status)); - auto numSamplesToEncode = std::min( - numSamplesPerFrame, static_cast(numSamples - numEncodedSamples)); - auto numBytesToEncode = numSamplesToEncode * numBytesPerSample; + int numSamplesToEncode = + std::min(numSamplesPerFrame, numSamples - numEncodedSamples); + int numBytesToEncode = numSamplesToEncode * numBytesPerSample; for (int ch = 0; ch < wf_.sizes()[0]; ch++) { - memcpy( + std::memcpy( avFrame->data[ch], pwf + ch * numBytesPerChannel, numBytesToEncode); } pwf += numBytesToEncode; @@ -155,14 +158,14 @@ void Encoder::encode() { // encoded frame would contain more samples than necessary and our results // wouldn't match the ffmpeg CLI. avFrame->nb_samples = numSamplesToEncode; - encode_inner_loop(autoAVPacket, avFrame); + encodeInnerLoop(autoAVPacket, avFrame); - avFrame->pts += numSamplesToEncode; + avFrame->pts += static_cast(numSamplesToEncode); numEncodedSamples += numSamplesToEncode; } TORCH_CHECK(numEncodedSamples == numSamples, "Hmmmmmm something went wrong."); - encode_inner_loop(autoAVPacket, UniqueAVFrame(nullptr)); // flush + flushBuffers(); status = av_write_trailer(avFormatContext_.get()); TORCH_CHECK( @@ -171,7 +174,7 @@ void Encoder::encode() { getFFMPEGErrorStringFromErrorCode(status)); } -void Encoder::encode_inner_loop( +void AudioEncoder::encodeInnerLoop( AutoAVPacket& autoAVPacket, const UniqueAVFrame& avFrame) { auto status = avcodec_send_frame(avCodecContext_.get(), avFrame.get()); @@ -199,10 +202,7 @@ void Encoder::encode_inner_loop( "Error receiving packet: ", getFFMPEGErrorStringFromErrorCode(status)); - // TODO-ENCODING why are these 2 lines needed?? - av_packet_rescale_ts( - packet.get(), avCodecContext_->time_base, avStream_->time_base); - packet->stream_index = avStream_->index; + packet->stream_index = streamIndex_; status = av_interleaved_write_frame(avFormatContext_.get(), packet.get()); TORCH_CHECK( @@ -211,4 +211,9 @@ void Encoder::encode_inner_loop( getFFMPEGErrorStringFromErrorCode(status)); } } + +void AudioEncoder::flushBuffers() { + AutoAVPacket autoAVPacket; + encodeInnerLoop(autoAVPacket, UniqueAVFrame(nullptr)); +} } // namespace facebook::torchcodec diff --git a/src/torchcodec/_core/Encoder.h b/src/torchcodec/_core/Encoder.h index 112ac83f..f0621fe5 100644 --- a/src/torchcodec/_core/Encoder.h +++ b/src/torchcodec/_core/Encoder.h @@ -3,21 +3,25 @@ #include "src/torchcodec/_core/FFMPEGCommon.h" namespace facebook::torchcodec { -class Encoder { +class AudioEncoder { public: - ~Encoder(); + ~AudioEncoder(); - Encoder(const torch::Tensor wf, int sampleRate, std::string_view fileName); + AudioEncoder( + const torch::Tensor wf, + int sampleRate, + std::string_view fileName); void encode(); private: - void encode_inner_loop( + void encodeInnerLoop( AutoAVPacket& autoAVPacket, const UniqueAVFrame& avFrame); + void flushBuffers(); - UniqueAVFormatContextForEncoding avFormatContext_; + UniqueEncodingAVFormatContext avFormatContext_; UniqueAVCodecContext avCodecContext_; - AVStream* avStream_; + int streamIndex_; const torch::Tensor wf_; // The *output* sample rate. We can't really decide for the user what it diff --git a/src/torchcodec/_core/FFMPEGCommon.h b/src/torchcodec/_core/FFMPEGCommon.h index 790a4043..fdb30962 100644 --- a/src/torchcodec/_core/FFMPEGCommon.h +++ b/src/torchcodec/_core/FFMPEGCommon.h @@ -50,10 +50,10 @@ struct Deleter { }; // Unique pointers for FFMPEG structures. -using UniqueAVFormatContextForDecoding = std::unique_ptr< +using UniqueDecodingAVFormatContext = std::unique_ptr< AVFormatContext, Deleterp>; -using UniqueAVFormatContextForEncoding = std::unique_ptr< +using UniqueEncodingAVFormatContext = std::unique_ptr< AVFormatContext, Deleter>; using UniqueAVCodecContext = std::unique_ptr< diff --git a/src/torchcodec/_core/SingleStreamDecoder.cpp b/src/torchcodec/_core/SingleStreamDecoder.cpp index efd93498..b7438f19 100644 --- a/src/torchcodec/_core/SingleStreamDecoder.cpp +++ b/src/torchcodec/_core/SingleStreamDecoder.cpp @@ -1443,7 +1443,7 @@ void SingleStreamDecoder::convertAudioAVFrameToFrameOutputOnCPU( auto numBytesPerChannel = numSamples * av_get_bytes_per_sample(format); for (auto channel = 0; channel < numChannels; ++channel, outputChannelData += numBytesPerChannel) { - memcpy( + std::memcpy( outputChannelData, avFrame->extended_data[channel], numBytesPerChannel); diff --git a/src/torchcodec/_core/SingleStreamDecoder.h b/src/torchcodec/_core/SingleStreamDecoder.h index 4b7a7dbf..f712cdbb 100644 --- a/src/torchcodec/_core/SingleStreamDecoder.h +++ b/src/torchcodec/_core/SingleStreamDecoder.h @@ -492,7 +492,7 @@ class SingleStreamDecoder { SeekMode seekMode_; ContainerMetadata containerMetadata_; - UniqueAVFormatContextForDecoding formatContext_; + UniqueDecodingAVFormatContext formatContext_; std::map streamInfos_; const int NO_ACTIVE_STREAM = -2; int activeStreamIndex_ = NO_ACTIVE_STREAM; diff --git a/src/torchcodec/_core/__init__.py b/src/torchcodec/_core/__init__.py index 2dccbca5..4be8a7de 100644 --- a/src/torchcodec/_core/__init__.py +++ b/src/torchcodec/_core/__init__.py @@ -18,12 +18,12 @@ _test_frame_pts_equality, add_audio_stream, add_video_stream, - create_encoder, + create_audio_encoder, create_from_bytes, create_from_file, create_from_file_like, create_from_tensor, - encode, + encode_audio, get_ffmpeg_library_versions, get_frame_at_index, get_frame_at_pts, diff --git a/src/torchcodec/_core/custom_ops.cpp b/src/torchcodec/_core/custom_ops.cpp index baaa8622..596412a8 100644 --- a/src/torchcodec/_core/custom_ops.cpp +++ b/src/torchcodec/_core/custom_ops.cpp @@ -28,8 +28,9 @@ TORCH_LIBRARY(torchcodec_ns, m) { m.impl_abstract_pystub( "torchcodec._core.ops", "//pytorch/torchcodec:torchcodec"); m.def("create_from_file(str filename, str? seek_mode=None) -> Tensor"); - m.def("create_encoder(Tensor wf, int sample_rate, str filename) -> Tensor"); - m.def("encode(Tensor(a!) encoder) -> ()"); + m.def( + "create_audio_encoder(Tensor wf, int sample_rate, str filename) -> Tensor"); + m.def("encode_audio(Tensor(a!) encoder) -> ()"); m.def( "create_from_tensor(Tensor video_tensor, str? seek_mode=None) -> Tensor"); m.def("_convert_to_tensor(int decoder_ptr) -> Tensor"); @@ -384,35 +385,42 @@ OpsAudioFramesOutput get_frames_by_pts_in_range_audio( return makeOpsAudioFramesOutput(result); } -at::Tensor wrapEncoderPointerToTensor(std::unique_ptr uniqueEncoder) { - Encoder* encoder = uniqueEncoder.release(); +at::Tensor wrapAudioEncoderPointerToTensor( + std::unique_ptr uniqueAudioEncoder) { + AudioEncoder* encoder = uniqueAudioEncoder.release(); auto deleter = [encoder](void*) { delete encoder; }; at::Tensor tensor = - at::from_blob(encoder, {sizeof(Encoder)}, deleter, {at::kLong}); - auto encoder_ = static_cast(tensor.mutable_data_ptr()); - TORCH_CHECK_EQ(encoder_, encoder) << "Encoder=" << encoder_; + at::from_blob(encoder, {sizeof(AudioEncoder*)}, deleter, {at::kLong}); + auto encoder_ = static_cast(tensor.mutable_data_ptr()); + TORCH_CHECK_EQ(encoder_, encoder) << "AudioEncoder=" << encoder_; return tensor; } -Encoder* unwrapTensorToGetEncoder(at::Tensor& tensor) { +AudioEncoder* unwrapTensorToGetAudioEncoder(at::Tensor& tensor) { TORCH_INTERNAL_ASSERT(tensor.is_contiguous()); void* buffer = tensor.mutable_data_ptr(); - Encoder* encoder = static_cast(buffer); + AudioEncoder* encoder = static_cast(buffer); return encoder; } -at::Tensor create_encoder( +at::Tensor create_audio_encoder( const at::Tensor wf, int64_t sample_rate, std::string_view file_name) { - std::unique_ptr uniqueEncoder = - std::make_unique(wf, static_cast(sample_rate), file_name); - return wrapEncoderPointerToTensor(std::move(uniqueEncoder)); -} - -void encode(at::Tensor& encoder) { - auto encoder_ = unwrapTensorToGetEncoder(encoder); + TORCH_CHECK( + sample_rate <= std::numeric_limits::max(), + "sample_rate=", + sample_rate, + " is too large to be cast to an int."); + std::unique_ptr uniqueAudioEncoder = + std::make_unique( + wf, static_cast(sample_rate), file_name); + return wrapAudioEncoderPointerToTensor(std::move(uniqueAudioEncoder)); +} + +void encode_audio(at::Tensor& encoder) { + auto encoder_ = unwrapTensorToGetAudioEncoder(encoder); encoder_->encode(); } @@ -650,7 +658,7 @@ void scan_all_streams_to_update_metadata(at::Tensor& decoder) { TORCH_LIBRARY_IMPL(torchcodec_ns, BackendSelect, m) { m.impl("create_from_file", &create_from_file); - m.impl("create_encoder", &create_encoder); + m.impl("create_audio_encoder", &create_audio_encoder); m.impl("create_from_tensor", &create_from_tensor); m.impl("_convert_to_tensor", &_convert_to_tensor); m.impl( @@ -658,7 +666,7 @@ TORCH_LIBRARY_IMPL(torchcodec_ns, BackendSelect, m) { } TORCH_LIBRARY_IMPL(torchcodec_ns, CPU, m) { - m.impl("encode", &encode); + m.impl("encode_audio", &encode_audio); m.impl("seek_to_pts", &seek_to_pts); m.impl("add_video_stream", &add_video_stream); m.impl("_add_video_stream", &_add_video_stream); diff --git a/src/torchcodec/_core/ops.py b/src/torchcodec/_core/ops.py index 47063aff..d910fcad 100644 --- a/src/torchcodec/_core/ops.py +++ b/src/torchcodec/_core/ops.py @@ -91,10 +91,12 @@ def load_torchcodec_shared_libraries(): create_from_file = torch._dynamo.disallow_in_graph( torch.ops.torchcodec_ns.create_from_file.default ) -create_encoder = torch._dynamo.disallow_in_graph( - torch.ops.torchcodec_ns.create_encoder.default +create_audio_encoder = torch._dynamo.disallow_in_graph( + torch.ops.torchcodec_ns.create_audio_encoder.default +) +encode_audio = torch._dynamo.disallow_in_graph( + torch.ops.torchcodec_ns.encode_audio.default ) -encode = torch._dynamo.disallow_in_graph(torch.ops.torchcodec_ns.encode.default) create_from_tensor = torch._dynamo.disallow_in_graph( torch.ops.torchcodec_ns.create_from_tensor.default ) @@ -159,15 +161,15 @@ def create_from_file_abstract(filename: str, seek_mode: Optional[str]) -> torch. return torch.empty([], dtype=torch.long) -@register_fake("torchcodec_ns::create_encoder") -def create_encoder_abstract( +@register_fake("torchcodec_ns::create_audio_encoder") +def create_audio_encoder_abstract( wf: torch.Tensor, sample_rate: int, filename: str ) -> torch.Tensor: return torch.empty([], dtype=torch.long) -@register_fake("torchcodec_ns::encode") -def encode_abstract(encoder: torch.Tensor) -> torch.Tensor: +@register_fake("torchcodec_ns::encode_audio") +def encode_audio_abstract(encoder: torch.Tensor) -> torch.Tensor: return torch.empty([], dtype=torch.long) diff --git a/test/test_ops.py b/test/test_ops.py index 44003392..e301701a 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -22,12 +22,12 @@ _test_frame_pts_equality, add_audio_stream, add_video_stream, - create_encoder, + create_audio_encoder, create_from_bytes, create_from_file, create_from_file_like, create_from_tensor, - encode, + encode_audio, get_ffmpeg_library_versions, get_frame_at_index, get_frame_at_pts, @@ -1088,27 +1088,29 @@ def test_bad_input(self, tmp_path): valid_output_file = str(tmp_path / ".mp3") with pytest.raises(RuntimeError, match="must have float32 dtype, got int"): - create_encoder( + create_audio_encoder( wf=torch.arange(10, dtype=torch.int), sample_rate=10, filename=valid_output_file, ) with pytest.raises(RuntimeError, match="must have 2 dimensions, got 1"): - create_encoder(wf=torch.rand(3), sample_rate=10, filename=valid_output_file) + create_audio_encoder( + wf=torch.rand(3), sample_rate=10, filename=valid_output_file + ) with pytest.raises(RuntimeError, match="No such file or directory"): - create_encoder( + create_audio_encoder( wf=torch.rand(10, 10), sample_rate=10, filename="./bad/path.mp3" ) with pytest.raises(RuntimeError, match="Check the desired extension"): - create_encoder( + create_audio_encoder( wf=torch.rand(10, 10), sample_rate=10, filename="./file.bad_extension" ) # TODO-ENCODING: raise more informative error message when sample rate # isn't supported with pytest.raises(RuntimeError, match="Invalid argument"): - create_encoder( + create_audio_encoder( wf=self.decode(NASA_AUDIO_MP3), sample_rate=10, filename=valid_output_file, @@ -1120,10 +1122,10 @@ def test_round_trip(self, tmp_path): source_samples = self.decode(asset) encoded_path = tmp_path / "output.mp3" - encoder = create_encoder( + encoder = create_audio_encoder( wf=source_samples, sample_rate=asset.sample_rate, filename=str(encoded_path) ) - encode(encoder) + encode_audio(encoder) # TODO-ENCODING: tol should be stricter. We need to increase the encoded # bitrate, and / or encode into a lossless format. @@ -1153,12 +1155,12 @@ def test_against_cli(self, asset, tmp_path): check=True, ) - encoder = create_encoder( + encoder = create_audio_encoder( wf=self.decode(asset), sample_rate=asset.sample_rate, filename=str(encoded_by_us), ) - encode(encoder) + encode_audio(encoder) torch.testing.assert_close( self.decode(encoded_by_ffmpeg), self.decode(encoded_by_us) From 9b90894cd247a925b717b6bb5c4f26469bb54053 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Thu, 3 Apr 2025 15:59:50 +0100 Subject: [PATCH 18/18] cast --- src/torchcodec/_core/Encoder.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/torchcodec/_core/Encoder.cpp b/src/torchcodec/_core/Encoder.cpp index a86423a2..9d5c1dea 100644 --- a/src/torchcodec/_core/Encoder.cpp +++ b/src/torchcodec/_core/Encoder.cpp @@ -126,7 +126,7 @@ void AudioEncoder::encode() { int numSamples = static_cast(wf_.sizes()[1]); // per channel int numEncodedSamples = 0; // per channel int numSamplesPerFrame = avCodecContext_->frame_size; // per channel - int numBytesPerSample = wf_.element_size(); + int numBytesPerSample = static_cast(wf_.element_size()); int numBytesPerChannel = numSamples * numBytesPerSample; status = avformat_write_header(avFormatContext_.get(), nullptr);