Skip to content

Commit 730f8f9

Browse files
committed
Super WIP encoder
1 parent 90d7409 commit 730f8f9

File tree

5 files changed

+195
-0
lines changed

5 files changed

+195
-0
lines changed

src/torchcodec/decoders/_core/VideoDecoder.cpp

Lines changed: 127 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1942,4 +1942,131 @@ FrameDims getHeightAndWidthFromOptionsOrAVFrame(
19421942
videoStreamOptions.width.value_or(avFrame.width));
19431943
}
19441944

1945+
Encoder::~Encoder() {
1946+
fclose(f_);
1947+
}
1948+
1949+
Encoder::Encoder(torch::Tensor& wf) : wf_(wf) {
1950+
f_ = fopen("./coutput", "wb");
1951+
TORCH_CHECK(f_, "Could not open file");
1952+
const AVCodec* avCodec = avcodec_find_encoder(AV_CODEC_ID_MP3);
1953+
TORCH_CHECK(avCodec != nullptr, "Codec not found");
1954+
1955+
AVCodecContext* avCodecContext = avcodec_alloc_context3(avCodec);
1956+
TORCH_CHECK(avCodecContext != nullptr, "Couldn't allocate codec context.");
1957+
avCodecContext_.reset(avCodecContext);
1958+
1959+
avCodecContext_->bit_rate = 0; // TODO
1960+
avCodecContext_->sample_fmt = AV_SAMPLE_FMT_FLTP; // TODO
1961+
avCodecContext_->sample_rate = 16000; // TODO
1962+
AVChannelLayout channel_layout;
1963+
av_channel_layout_default(&channel_layout, 2);
1964+
avCodecContext_->ch_layout = channel_layout;
1965+
1966+
auto ffmpegRet = avcodec_open2(avCodecContext_.get(), avCodec, nullptr);
1967+
TORCH_CHECK(
1968+
ffmpegRet == AVSUCCESS, getFFMPEGErrorStringFromErrorCode(ffmpegRet));
1969+
1970+
AVFrame* avFrame = av_frame_alloc();
1971+
TORCH_CHECK(avFrame != nullptr, "Couldn't allocate AVFrame.");
1972+
avFrame_.reset(avFrame);
1973+
avFrame_->nb_samples = avCodecContext_->frame_size;
1974+
avFrame_->format = avCodecContext_->sample_fmt;
1975+
avFrame_->sample_rate = avCodecContext_->sample_rate;
1976+
1977+
ffmpegRet =
1978+
av_channel_layout_copy(&avFrame_->ch_layout, &avCodecContext_->ch_layout);
1979+
TORCH_CHECK(
1980+
ffmpegRet == AVSUCCESS,
1981+
"Couldn't copy channel layout to avFrame: ",
1982+
getFFMPEGErrorStringFromErrorCode(ffmpegRet));
1983+
ffmpegRet = av_frame_get_buffer(avFrame_.get(), 0);
1984+
TORCH_CHECK(
1985+
ffmpegRet == AVSUCCESS,
1986+
"Couldn't allocate avFrame's buffers: ",
1987+
getFFMPEGErrorStringFromErrorCode(ffmpegRet));
1988+
}
1989+
1990+
torch::Tensor Encoder::encode() {
1991+
AVPacket* pkt = av_packet_alloc();
1992+
if (!pkt) {
1993+
fprintf(stderr, "Could not allocate audio packet\n");
1994+
exit(1);
1995+
}
1996+
1997+
auto MAX_NUM_BYTES = 10000000; // 10Mb. TODO find a way not to pre-allocate.
1998+
int numEncodedBytes = 0;
1999+
torch::Tensor outputTensor = torch::empty({MAX_NUM_BYTES}, torch::kUInt8);
2000+
uint8_t* pOutputTensor =
2001+
static_cast<uint8_t*>(outputTensor.data_ptr<uint8_t>());
2002+
2003+
uint8_t* pWf = static_cast<uint8_t*>(wf_.data_ptr());
2004+
auto numBytesWeWroteFromWF = 0;
2005+
auto numBytesPerSample = wf_.element_size();
2006+
auto numBytesPerChannel = wf_.sizes()[1] * numBytesPerSample;
2007+
2008+
// TODO need simpler/cleaner while loop condition.
2009+
while (numBytesWeWroteFromWF < numBytesPerChannel) {
2010+
auto ffmpegRet = av_frame_make_writable(avFrame_.get());
2011+
TORCH_CHECK(
2012+
ffmpegRet == AVSUCCESS,
2013+
"Couldn't make AVFrame writable: ",
2014+
getFFMPEGErrorStringFromErrorCode(ffmpegRet));
2015+
2016+
auto numBytesToWrite = numBytesPerSample * avCodecContext_->frame_size;
2017+
if (numBytesWeWroteFromWF + numBytesToWrite > numBytesPerChannel) {
2018+
numBytesToWrite = numBytesPerChannel - numBytesWeWroteFromWF;
2019+
}
2020+
for (int ch = 0; ch < 2; ch++) {
2021+
memcpy(
2022+
avFrame_->data[ch], pWf + ch * numBytesPerChannel, numBytesToWrite);
2023+
}
2024+
pWf += numBytesToWrite;
2025+
numBytesWeWroteFromWF += numBytesToWrite;
2026+
encode_inner_loop(pkt, pOutputTensor, &numEncodedBytes, false);
2027+
}
2028+
encode_inner_loop(pkt, pOutputTensor, &numEncodedBytes, true);
2029+
2030+
return outputTensor.narrow(
2031+
/*dim=*/0, /*start=*/0, /*length=*/numEncodedBytes);
2032+
// return outputTensor;
2033+
}
2034+
2035+
void Encoder::encode_inner_loop(
2036+
AVPacket* pkt,
2037+
uint8_t* pOutputTensor,
2038+
int* numEncodedBytes,
2039+
bool flush) {
2040+
int ffmpegRet = 0;
2041+
2042+
// TODO ewwww
2043+
if (flush) {
2044+
ffmpegRet = avcodec_send_frame(avCodecContext_.get(), nullptr);
2045+
} else {
2046+
ffmpegRet = avcodec_send_frame(avCodecContext_.get(), avFrame_.get());
2047+
}
2048+
TORCH_CHECK(
2049+
ffmpegRet == AVSUCCESS,
2050+
"Error while sending frame: ",
2051+
getFFMPEGErrorStringFromErrorCode(ffmpegRet));
2052+
2053+
while ((ffmpegRet = avcodec_receive_packet(avCodecContext_.get(), pkt)) >=
2054+
0) {
2055+
if (ffmpegRet == AVERROR(EAGAIN) || ffmpegRet == AVERROR_EOF) {
2056+
return;
2057+
}
2058+
TORCH_CHECK(
2059+
ffmpegRet >= 0,
2060+
"Error receiving packet: ",
2061+
getFFMPEGErrorStringFromErrorCode(ffmpegRet));
2062+
2063+
fwrite(pkt->data, 1, pkt->size, f_);
2064+
2065+
memcpy(pOutputTensor + *numEncodedBytes, pkt->data, pkt->size);
2066+
*numEncodedBytes += pkt->size;
2067+
2068+
av_packet_unref(pkt);
2069+
}
2070+
}
2071+
19452072
} // namespace facebook::torchcodec

src/torchcodec/decoders/_core/VideoDecoder.h

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -563,4 +563,24 @@ std::ostream& operator<<(
563563
std::ostream& os,
564564
const VideoDecoder::DecodeStats& stats);
565565

566+
class Encoder {
567+
public:
568+
~Encoder();
569+
570+
explicit Encoder(torch::Tensor& wf);
571+
torch::Tensor encode();
572+
573+
private:
574+
void encode_inner_loop(
575+
AVPacket* pkt,
576+
uint8_t* pOutputTensor,
577+
int* numEncodedBytes,
578+
bool flush);
579+
580+
torch::Tensor wf_;
581+
UniqueAVCodecContext avCodecContext_;
582+
UniqueAVFrame avFrame_;
583+
FILE* f_;
584+
};
585+
566586
} // namespace facebook::torchcodec

src/torchcodec/decoders/_core/VideoDecoderOps.cpp

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,8 @@ TORCH_LIBRARY(torchcodec_ns, m) {
2828
"torchcodec.decoders._core.video_decoder_ops",
2929
"//pytorch/torchcodec:torchcodec");
3030
m.def("create_from_file(str filename, str? seek_mode=None) -> Tensor");
31+
m.def("create_encoder(Tensor wf) -> Tensor");
32+
m.def("encode(Tensor(a!) encoder) -> Tensor");
3133
m.def(
3234
"create_from_tensor(Tensor video_tensor, str? seek_mode=None) -> Tensor");
3335
m.def(
@@ -74,13 +76,31 @@ at::Tensor wrapDecoderPointerToTensor(
7476
return tensor;
7577
}
7678

79+
at::Tensor wrapEncoderPointerToTensor(std::unique_ptr<Encoder> uniqueEncoder) {
80+
Encoder* encoder = uniqueEncoder.release();
81+
82+
auto deleter = [encoder](void*) { delete encoder; };
83+
at::Tensor tensor =
84+
at::from_blob(encoder, {sizeof(Encoder)}, deleter, {at::kLong});
85+
auto encoder_ = static_cast<Encoder*>(tensor.mutable_data_ptr());
86+
TORCH_CHECK_EQ(encoder_, encoder) << "Encoder=" << encoder_;
87+
return tensor;
88+
}
89+
7790
VideoDecoder* unwrapTensorToGetDecoder(at::Tensor& tensor) {
7891
TORCH_INTERNAL_ASSERT(tensor.is_contiguous());
7992
void* buffer = tensor.mutable_data_ptr();
8093
VideoDecoder* decoder = static_cast<VideoDecoder*>(buffer);
8194
return decoder;
8295
}
8396

97+
Encoder* unwrapTensorToGetEncoder(at::Tensor& tensor) {
98+
TORCH_INTERNAL_ASSERT(tensor.is_contiguous());
99+
void* buffer = tensor.mutable_data_ptr();
100+
Encoder* encoder = static_cast<Encoder*>(buffer);
101+
return encoder;
102+
}
103+
84104
OpsFrameOutput makeOpsFrameOutput(VideoDecoder::FrameOutput& frame) {
85105
return std::make_tuple(
86106
frame.data,
@@ -125,6 +145,16 @@ at::Tensor create_from_file(
125145
return wrapDecoderPointerToTensor(std::move(uniqueDecoder));
126146
}
127147

148+
at::Tensor create_encoder(torch::Tensor& wf) {
149+
std::unique_ptr<Encoder> uniqueEncoder = std::make_unique<Encoder>(wf);
150+
return wrapEncoderPointerToTensor(std::move(uniqueEncoder));
151+
}
152+
153+
at::Tensor encode(at::Tensor& encoder) {
154+
auto encoder_ = unwrapTensorToGetEncoder(encoder);
155+
return encoder_->encode();
156+
}
157+
128158
at::Tensor create_from_tensor(
129159
at::Tensor video_tensor,
130160
std::optional<std::string_view> seek_mode) {
@@ -516,12 +546,14 @@ void scan_all_streams_to_update_metadata(at::Tensor& decoder) {
516546

517547
TORCH_LIBRARY_IMPL(torchcodec_ns, BackendSelect, m) {
518548
m.impl("create_from_file", &create_from_file);
549+
m.impl("create_encoder", &create_encoder);
519550
m.impl("create_from_tensor", &create_from_tensor);
520551
m.impl(
521552
"_get_json_ffmpeg_library_versions", &_get_json_ffmpeg_library_versions);
522553
}
523554

524555
TORCH_LIBRARY_IMPL(torchcodec_ns, CPU, m) {
556+
m.impl("encode", &encode);
525557
m.impl("seek_to_pts", &seek_to_pts);
526558
m.impl("add_video_stream", &add_video_stream);
527559
m.impl("_add_video_stream", &_add_video_stream);

src/torchcodec/decoders/_core/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,11 @@
1717
_test_frame_pts_equality,
1818
add_audio_stream,
1919
add_video_stream,
20+
create_encoder,
2021
create_from_bytes,
2122
create_from_file,
2223
create_from_tensor,
24+
encode,
2325
get_ffmpeg_library_versions,
2426
get_frame_at_index,
2527
get_frame_at_pts,

src/torchcodec/decoders/_core/video_decoder_ops.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,10 @@ def load_torchcodec_extension():
6464
create_from_file = torch._dynamo.disallow_in_graph(
6565
torch.ops.torchcodec_ns.create_from_file.default
6666
)
67+
create_encoder = torch._dynamo.disallow_in_graph(
68+
torch.ops.torchcodec_ns.create_encoder.default
69+
)
70+
encode = torch._dynamo.disallow_in_graph(torch.ops.torchcodec_ns.encode.default)
6771
create_from_tensor = torch._dynamo.disallow_in_graph(
6872
torch.ops.torchcodec_ns.create_from_tensor.default
6973
)
@@ -115,6 +119,16 @@ def create_from_file_abstract(filename: str, seek_mode: Optional[str]) -> torch.
115119
return torch.empty([], dtype=torch.long)
116120

117121

122+
@register_fake("torchcodec_ns::create_encoder")
123+
def create_encoder_abstract(wf: torch.Tensor) -> torch.Tensor:
124+
return torch.empty([], dtype=torch.long)
125+
126+
127+
@register_fake("torchcodec_ns::encode")
128+
def encode_abstract(encoder: torch.Tensor) -> torch.Tensor:
129+
return torch.empty([], dtype=torch.long)
130+
131+
118132
@register_fake("torchcodec_ns::create_from_tensor")
119133
def create_from_tensor_abstract(
120134
video_tensor: torch.Tensor, seek_mode: Optional[str]

0 commit comments

Comments
 (0)