From 2a74c4a3f0f43ed8a904908dd0b70ad626ee8210 Mon Sep 17 00:00:00 2001 From: FSSRepo Date: Tue, 5 Dec 2023 10:42:39 -0500 Subject: [PATCH 01/24] add esrgan upscaler --- README.md | 3 +- examples/cli/main.cpp | 22 +- ggml | 2 +- stable-diffusion.cpp | 650 +++++++++++++++++++++++++++++++++++++++++- stable-diffusion.h | 1 + 5 files changed, 672 insertions(+), 6 deletions(-) diff --git a/README.md b/README.md index f2c588e1e..117cd0cac 100644 --- a/README.md +++ b/README.md @@ -26,6 +26,7 @@ Inference of [Stable Diffusion](https://github.com/CompVis/stable-diffusion) in - LoRA support, same as [stable-diffusion-webui](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Features#lora) - Latent Consistency Models support (LCM/LCM-LoRA) - Faster and memory efficient latent decoding with [TAESD](https://github.com/madebyollin/taesd) +- Upscale images generated with [ESRGAN](https://github.com/xinntao/Real-ESRGAN) - Sampling method - `Euler A` - `Euler` @@ -51,7 +52,6 @@ Inference of [Stable Diffusion](https://github.com/CompVis/stable-diffusion) in - Implement Winograd Convolution 2D for 3x3 kernel filtering - [ ] Continuing to reduce memory usage (quantizing the weights of ggml_conv_2d) - [ ] Implement BPE Tokenizer -- [ ] Implement [Real-ESRGAN](https://github.com/xinntao/Real-ESRGAN/tree/master) upscaler - [ ] k-quants support ## Usage @@ -134,6 +134,7 @@ arguments: -m, --model [MODEL] path to model --vae [VAE] path to vae --taesd [TAESD_PATH] path to taesd. Using Tiny AutoEncoder for fast decoding (low quality) + -um, --upscale-model [ESRGAN_PATH] path to esrgan model. Upscale images after generate, just RealESRGAN_x4plus_anime_6B supported by now. --type [TYPE] weight type (f32, f16, q4_0, q4_1, q5_0, q5_1, q8_0) If not specified, the default is the type of the weight file. --lora-model-dir [DIR] lora model directory diff --git a/examples/cli/main.cpp b/examples/cli/main.cpp index 68824dd9e..e6e2e55fe 100644 --- a/examples/cli/main.cpp +++ b/examples/cli/main.cpp @@ -59,6 +59,7 @@ struct SDParams { std::string model_path; std::string vae_path; std::string taesd_path; + std::string esrgan_path; ggml_type wtype = GGML_TYPE_COUNT; std::string lora_model_dir; std::string output_path = "output.png"; @@ -115,6 +116,7 @@ void print_usage(int argc, const char* argv[]) { printf(" -m, --model [MODEL] path to model\n"); printf(" --vae [VAE] path to vae\n"); printf(" --taesd [TAESD_PATH] path to taesd. Using Tiny AutoEncoder for fast decoding (low quality)\n"); + printf(" -um, --upscale-model [ESRGAN_PATH] path to esrgan model. Upscale images after generate, just RealESRGAN_x4plus_anime_6B supported by now.\n"); printf(" --type [TYPE] weight type (f32, f16, q4_0, q4_1, q5_0, q5_1, q8_0)\n"); printf(" If not specified, the default is the type of the weight file.\n"); printf(" --lora-model-dir [DIR] lora model directory\n"); @@ -185,6 +187,12 @@ void parse_args(int argc, const char** argv, SDParams& params) { break; } params.taesd_path = argv[i]; + } else if (arg == "--upscale-model" || arg == "-um") { + if (++i >= argc) { + invalid_arg = true; + break; + } + params.esrgan_path = argv[i]; } else if (arg == "--type") { if (++i >= argc) { invalid_arg = true; @@ -458,7 +466,7 @@ int main(int argc, const char* argv[]) { } } - StableDiffusion sd(params.n_threads, vae_decode_only, params.taesd_path, true, params.lora_model_dir, params.rng_type); + StableDiffusion sd(params.n_threads, vae_decode_only, params.taesd_path, params.esrgan_path, true, params.lora_model_dir, params.rng_type); if (!sd.load_from_file(params.model_path, params.vae_path, params.wtype, params.schedule)) { return 1; @@ -488,6 +496,18 @@ int main(int argc, const char* argv[]) { params.seed); } + if(params.esrgan_path.size() > 0) { + /* hardcoded scale factor because just RealESRGAN_x4plus_anime_6B is compatible + See also: https://github.com/xinntao/Real-ESRGAN/blob/master/inference_realesrgan.py + + To avoid this, the upscaler needs to be separated from the stable diffusion pipeline. + However, a considerable amount of work would be required for this. It might be better + to opt for a complete project refactoring that facilitates the easier assignment of parameters. + */ + params.width *= 4; + params.height *= 4; + } + if (results.size() == 0 || results.size() != params.batch_count) { LOG_ERROR("generate failed"); return 1; diff --git a/ggml b/ggml index 70474c689..f7a51f1b5 160000 --- a/ggml +++ b/ggml @@ -1 +1 @@ -Subproject commit 70474c6890c015b53dc10a2300ae35246cc73589 +Subproject commit f7a51f1b53e85f9ac47ae522bb655963023c8776 diff --git a/stable-diffusion.cpp b/stable-diffusion.cpp index 53609c872..8374304de 100644 --- a/stable-diffusion.cpp +++ b/stable-diffusion.cpp @@ -335,6 +335,39 @@ void sd_image_to_tensor(const uint8_t* image_data, } } + +void sd_split_chunk(struct ggml_tensor* input, + struct ggml_tensor* output, int x, int y) { + int64_t width = output->ne[0]; + int64_t height = output->ne[1]; + int64_t channels = output->ne[2]; + GGML_ASSERT(channels == 3 && output->type == GGML_TYPE_F32); + for (int iy = 0; iy < height; iy++) { + for (int ix = 0; ix < width; ix++) { + for (int k = 0; k < channels; k++) { + float value = ggml_tensor_get_f32(input, ix + x * width, iy + y * height, k); + ggml_tensor_set_f32(output, value, ix, iy, k); + } + } + } +} + +void sd_merge_chunk(struct ggml_tensor* input, + struct ggml_tensor* output, int x, int y) { + int64_t width = input->ne[0]; + int64_t height = input->ne[1]; + int64_t channels = input->ne[2]; + GGML_ASSERT(channels == 3 && input->type == GGML_TYPE_F32); + for (int iy = 0; iy < height; iy++) { + for (int ix = 0; ix < width; ix++) { + for (int k = 0; k < channels; k++) { + float value = ggml_tensor_get_f32(input, ix, iy, k); + ggml_tensor_set_f32(output, value, ix + x * width, iy + y * height, k); + } + } + } +} + float sd_mean(struct ggml_tensor* src) { float mean = 0.0f; int64_t nelements = ggml_nelements(src); @@ -3282,7 +3315,7 @@ struct AutoEncoderKL { }; /* - + =================================== TinyAutoEncoder =================================== References: https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/autoencoder_tiny.py https://github.com/madebyollin/taesd/blob/main/taesd.py @@ -3971,6 +4004,539 @@ struct TinyAutoEncoder { } }; +/* + =================================== ESRGAN =================================== + References: + https://github.com/xinntao/Real-ESRGAN/blob/master/inference_realesrgan.py + https://github.com/XPixelGroup/BasicSR/blob/v1.4.2/basicsr/archs/rrdbnet_arch.py + +*/ + +struct ResidualDenseBlock { + int num_features; + int num_grow_ch; + ggml_tensor* conv1_w; // [num_grow_ch, num_features, 3, 3] + ggml_tensor* conv1_b; // [num_grow_ch] + + ggml_tensor* conv2_w; // [num_grow_ch, num_features + num_grow_ch, 3, 3] + ggml_tensor* conv2_b; // [num_grow_ch] + + ggml_tensor* conv3_w; // [num_grow_ch, num_features + 2 * num_grow_ch, 3, 3] + ggml_tensor* conv3_b; // [num_grow_ch] + + ggml_tensor* conv4_w; // [num_grow_ch, num_features + 3 * num_grow_ch, 3, 3] + ggml_tensor* conv4_b; // [num_grow_ch] + + ggml_tensor* conv5_w; // [num_features, num_features + 4 * num_grow_ch, 3, 3] + ggml_tensor* conv5_b; // [num_features] + + ResidualDenseBlock() {} + + ResidualDenseBlock(int num_feat, int n_grow_ch) { + num_features = num_feat; + num_grow_ch = n_grow_ch; + } + + size_t calculate_mem_size() { + size_t mem_size = num_features * num_grow_ch * 3 * 3 * ggml_type_size(GGML_TYPE_F16); // conv1_w + mem_size += num_grow_ch * ggml_type_size(GGML_TYPE_F32); // conv1_b + + mem_size += (num_features + num_grow_ch) * num_grow_ch * 3 * 3 * ggml_type_size(GGML_TYPE_F16); // conv2_w + mem_size += num_grow_ch * ggml_type_size(GGML_TYPE_F32); // conv2_b + + mem_size += (num_features + 2*num_grow_ch) * num_grow_ch * 3 * 3 * ggml_type_size(GGML_TYPE_F16); // conv3_w + mem_size += num_grow_ch * ggml_type_size(GGML_TYPE_F32); // conv3_w + + mem_size += (num_features + 3*num_grow_ch) * num_grow_ch * 3 * 3 * ggml_type_size(GGML_TYPE_F16); // conv4_w + mem_size += num_grow_ch * ggml_type_size(GGML_TYPE_F32); // conv4_w + + mem_size += (num_features + 4*num_grow_ch) * num_features * 3 * 3 * ggml_type_size(GGML_TYPE_F16); // conv5_w + mem_size += num_features * ggml_type_size(GGML_TYPE_F32); // conv5_w + + return mem_size; + } + + int getNumTensors() { + int num_tensors = 10; + return num_tensors; + } + + void init_params(ggml_context* ctx) { + conv1_w = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, 3, 3, num_features, num_grow_ch); + conv1_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, num_grow_ch); + conv2_w = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, 3, 3, num_features + num_grow_ch, num_grow_ch); + conv2_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, num_grow_ch); + conv3_w = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, 3, 3, num_features + 2 * num_grow_ch, num_grow_ch); + conv3_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, num_grow_ch); + conv4_w = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, 3, 3, num_features + 3 * num_grow_ch, num_grow_ch); + conv4_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, num_grow_ch); + conv5_w = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, 3, 3, num_features + 4 * num_grow_ch, num_features); + conv5_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, num_features); + + } + + void map_by_name(std::map & tensors, std::string prefix) { + tensors[prefix + "conv1.weight"] = conv1_w; + tensors[prefix + "conv1.bias"] = conv1_b; + + tensors[prefix + "conv2.weight"] = conv2_w; + tensors[prefix + "conv2.bias"] = conv2_b; + + tensors[prefix + "conv3.weight"] = conv3_w; + tensors[prefix + "conv3.bias"] = conv3_b; + + tensors[prefix + "conv4.weight"] = conv4_w; + tensors[prefix + "conv4.bias"] = conv4_b; + + tensors[prefix + "conv5.weight"] = conv5_w; + tensors[prefix + "conv5.bias"] = conv5_b; + } + + ggml_tensor* forward(ggml_context* ctx, ggml_tensor* out_scale, ggml_tensor* x /* feat */) { + // x1 = self.lrelu(self.conv1(x)) + ggml_tensor* x1 = ggml_conv_2d(ctx, conv1_w, x, 1, 1, 1, 1, 1, 1); + x1 = ggml_add(ctx, x1, ggml_reshape_4d(ctx, conv1_b, 1, 1, conv1_b->ne[0], 1)); + x1 = ggml_leaky_relu(ctx, x1, 0.2f, true); + + // x2 = self.lrelu(self.conv2(torch.cat((x, x1), 1))) + ggml_tensor* x_cat = ggml_concat(ctx, x, x1); + ggml_tensor* x2 = ggml_conv_2d(ctx, conv2_w, x_cat, 1, 1, 1, 1, 1, 1); + x2 = ggml_add(ctx, x2, ggml_reshape_4d(ctx, conv2_b, 1, 1, conv2_b->ne[0], 1)); + x2 = ggml_leaky_relu(ctx, x2, 0.2f, true); + + // x3 = self.lrelu(self.conv3(torch.cat((x, x1, x2), 1))) + x_cat = ggml_concat(ctx, x_cat, x2); + ggml_tensor* x3 = ggml_conv_2d(ctx, conv3_w, x_cat, 1, 1, 1, 1, 1, 1); + x3 = ggml_add(ctx, x3, ggml_reshape_4d(ctx, conv3_b, 1, 1, conv3_b->ne[0], 1)); + x3 = ggml_leaky_relu(ctx, x3, 0.2f, true); + + // x4 = self.lrelu(self.conv4(torch.cat((x, x1, x2, x3), 1))) + x_cat = ggml_concat(ctx, x_cat, x3); + ggml_tensor* x4 = ggml_conv_2d(ctx, conv4_w, x_cat, 1, 1, 1, 1, 1, 1); + x4 = ggml_add(ctx, x4, ggml_reshape_4d(ctx, conv4_b, 1, 1, conv4_b->ne[0], 1)); + x4 = ggml_leaky_relu(ctx, x4, 0.2f, true); + + // self.conv5(torch.cat((x, x1, x2, x3, x4), 1)) + x_cat = ggml_concat(ctx, x_cat, x4); + ggml_tensor* x5 = ggml_conv_2d(ctx, conv5_w, x_cat, 1, 1, 1, 1, 1, 1); + x5 = ggml_add(ctx, x5, ggml_reshape_4d(ctx, conv5_b, 1, 1, conv5_b->ne[0], 1)); + + // return x5 * 0.2 + x + x5 = ggml_add(ctx, ggml_scale(ctx, x5, out_scale), x); + return x5; + } +}; + +struct EsrganBlock { + ResidualDenseBlock rd_blocks[3]; + int num_residual_blocks = 3; + + EsrganBlock() {} + + EsrganBlock(int num_feat, int num_grow_ch) { + for(int i = 0; i < num_residual_blocks; i++) { + rd_blocks[i] = ResidualDenseBlock(num_feat, num_grow_ch); + } + } + + int getNumTensors() { + int num_tensors = 0; + for(int i = 0; i < num_residual_blocks; i++) { + num_tensors += rd_blocks[i].getNumTensors(); + } + return num_tensors; + } + + size_t calculate_mem_size() { + size_t mem_size = 0; + for(int i = 0; i < num_residual_blocks; i++) { + mem_size += rd_blocks[i].calculate_mem_size(); + } + return mem_size; + } + + + void init_params(ggml_context* ctx) { + for(int i = 0; i < num_residual_blocks; i++) { + rd_blocks[i].init_params(ctx); + } + } + + void map_by_name(std::map & tensors, std::string prefix) { + for(int i = 0; i < num_residual_blocks; i++) { + rd_blocks[i].map_by_name(tensors, prefix + "rdb" + std::to_string(i + 1) +"."); + } + } + + ggml_tensor* forward(ggml_context* ctx, ggml_tensor* out_scale, ggml_tensor* x) { + ggml_tensor* out = x; + for(int i = 0; i < num_residual_blocks; i++) { + // out = self.rdb...(x) + out = rd_blocks[i].forward(ctx, out_scale, out); + } + // return out * 0.2 + x + out = ggml_add(ctx, ggml_scale(ctx, out, out_scale), x); + return out; + } +}; + +struct ESRGAN { + int scale = 4; // default RealESRGAN_x4plus_anime_6B + int num_blocks = 6; // default RealESRGAN_x4plus_anime_6B + int in_channels = 3; + int out_channels = 3; + int num_features = 64; // default RealESRGAN_x4plus_anime_6B + int num_grow_ch = 32; // default RealESRGAN_x4plus_anime_6B + int tile_size = 128; // avoid cuda OOM for 4gb VRAM + + ggml_tensor* conv_first_w; // [num_features, in_channels, 3, 3] + ggml_tensor* conv_first_b; // [num_features] + + EsrganBlock body_blocks[6]; + ggml_tensor* conv_body_w; // [num_features, num_features, 3, 3] + ggml_tensor* conv_body_b; // [num_features] + + // upsample + ggml_tensor* conv_up1_w; // [num_features, num_features, 3, 3] + ggml_tensor* conv_up1_b; // [num_features] + ggml_tensor* conv_up2_w; // [num_features, num_features, 3, 3] + ggml_tensor* conv_up2_b; // [num_features] + + ggml_tensor* conv_hr_w; // [num_features, num_features, 3, 3] + ggml_tensor* conv_hr_b; // [num_features] + ggml_tensor* conv_last_w; // [out_channels, num_features, 3, 3] + ggml_tensor* conv_last_b; // [out_channels] + + ggml_context* ctx; + bool decode_only = false; + ggml_backend_buffer_t params_buffer; + ggml_backend_buffer_t compute_buffer; // for compute + struct ggml_allocr * compute_alloc = NULL; + + int memory_buffer_size = 0; + ggml_type wtype; + ggml_backend_t backend = NULL; + + ESRGAN() { + for(int i = 0; i < num_blocks; i++) { + body_blocks[i] = EsrganBlock(num_features, num_grow_ch); + } + } + + size_t calculate_mem_size() { + size_t mem_size = num_features * in_channels * 3 * 3 * ggml_type_size(GGML_TYPE_F16); // conv_first_w + mem_size += num_features * ggml_type_size(GGML_TYPE_F32); // conv_first_b + + for(int i = 0; i < num_blocks; i++) { + mem_size += body_blocks[i].calculate_mem_size(); + } + + mem_size += num_features * num_features * 3 * 3 * ggml_type_size(GGML_TYPE_F16); // conv_body_w + mem_size += num_features * ggml_type_size(GGML_TYPE_F32); // conv_body_w + + // upsample + mem_size += num_features * num_features * 3 * 3 * ggml_type_size(GGML_TYPE_F16); // conv_up1_w + mem_size += num_features * ggml_type_size(GGML_TYPE_F32); // conv_up1_b + + mem_size += num_features * num_features * 3 * 3 * ggml_type_size(GGML_TYPE_F16); // conv_up2_w + mem_size += num_features * ggml_type_size(GGML_TYPE_F32); // conv_up2_b + + mem_size += num_features * num_features * 3 * 3 * ggml_type_size(GGML_TYPE_F16); // conv_hr_w + mem_size += num_features * ggml_type_size(GGML_TYPE_F32); // conv_hr_b + + mem_size += out_channels * num_features * 3 * 3 * ggml_type_size(GGML_TYPE_F16); // conv_last_w + mem_size += out_channels * ggml_type_size(GGML_TYPE_F32); // conv_last_b + return mem_size; + } + + int getNumTensors() { + int num_tensors = 12; + for(int i = 0; i < num_blocks; i++) { + num_tensors += body_blocks[i].getNumTensors(); + } + return num_tensors; + } + + bool init(ggml_backend_t backend_) { + this->backend = backend_; + memory_buffer_size = calculate_mem_size(); + memory_buffer_size += 1024; // overhead + int num_tensors = getNumTensors(); + + LOG_DEBUG("ESRGAN params backend buffer size = % 6.2f MB (%i tensors)", memory_buffer_size / (1024.0 * 1024.0), num_tensors); + + struct ggml_init_params params; + params.mem_size = static_cast(num_tensors * ggml_tensor_overhead()); + params.mem_buffer = NULL; + params.no_alloc = true; + + params_buffer = ggml_backend_alloc_buffer(backend, memory_buffer_size); + + ctx = ggml_init(params); + if (!ctx) { + LOG_ERROR("ggml_init() failed"); + return false; + } + return true; + } + + void alloc_params() { + ggml_allocr* alloc = ggml_allocr_new_from_buffer(params_buffer); + conv_first_w = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, 3, 3, in_channels, num_features); + conv_first_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, num_features); + conv_body_w = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, 3, 3, num_features, num_features); + conv_body_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, num_features); + conv_up1_w = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, 3, 3, num_features, num_features); + conv_up1_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, num_features); + conv_up2_w = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, 3, 3, num_features, num_features); + conv_up2_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, num_features); + conv_hr_w = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, 3, 3, num_features, num_features); + conv_hr_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, num_features); + conv_last_w = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, 3, 3, num_features, out_channels); + conv_last_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, out_channels); + + for(int i = 0; i < num_blocks; i++) { + body_blocks[i].init_params(ctx); + } + + // alloc all tensors linked to this context + for (struct ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) { + if(t->data == NULL) { + ggml_allocr_alloc(alloc, t); + } + } + ggml_allocr_free(alloc); + } + + bool load_from_file(const std::string& file_path, ggml_backend_t backend) { + LOG_INFO("loading esrgan from '%s'", file_path.c_str()); + + if (!init(backend)) { + return false; + } + + std::map esrgan_tensors; + + ModelLoader model_loader; + if (!model_loader.init_from_file(file_path)) { + LOG_ERROR("init esrgan model loader from file failed: '%s'", file_path.c_str()); + return false; + } + + // prepare memory for the weights + { + alloc_params(); + map_by_name(esrgan_tensors); + } + + std::set tensor_names_in_file; + + auto on_new_tensor_cb = [&](const TensorStorage& tensor_storage, ggml_tensor** dst_tensor) -> bool { + const std::string& name = tensor_storage.name; + tensor_names_in_file.insert(name); + + struct ggml_tensor* real; + if (esrgan_tensors.find(name) != esrgan_tensors.end()) { + real = esrgan_tensors[name]; + } else { + LOG_ERROR("unknown tensor '%s' in model file", name.data()); + return true; + } + + if ( + real->ne[0] != tensor_storage.ne[0] || + real->ne[1] != tensor_storage.ne[1] || + real->ne[2] != tensor_storage.ne[2] || + real->ne[3] != tensor_storage.ne[3]) { + LOG_ERROR( + "tensor '%s' has wrong shape in model file: " + "got [%d, %d, %d, %d], expected [%d, %d, %d, %d]", + name.c_str(), + (int)tensor_storage.ne[0], (int)tensor_storage.ne[1], (int)tensor_storage.ne[2], (int)tensor_storage.ne[3], + (int)real->ne[0], (int)real->ne[1], (int)real->ne[2], (int)real->ne[3]); + return false; + } + + *dst_tensor = real; + + return true; + }; + + bool success = model_loader.load_tensors(on_new_tensor_cb); + + bool some_tensor_not_init = false; + + for (auto pair : esrgan_tensors) { + if (tensor_names_in_file.find(pair.first) == tensor_names_in_file.end()) { + LOG_ERROR("tensor '%s' not in model file", pair.first.c_str()); + some_tensor_not_init = true; + } + } + + if (some_tensor_not_init) { + return false; + } + + LOG_INFO("esrgan model loaded"); + return success; + } + + void map_by_name(std::map & tensors) { + tensors["conv_first.weight"] = conv_first_w; + tensors["conv_first.bias"] = conv_first_b; + + for(int i = 0; i < num_blocks; i++) { + body_blocks[i].map_by_name(tensors, "body." + std::to_string(i) +"."); + } + + tensors["conv_body.weight"] = conv_body_w; + tensors["conv_body.bias"] = conv_body_b; + + tensors["conv_up1.weight"] = conv_up1_w; + tensors["conv_up1.bias"] = conv_up1_b; + tensors["conv_up2.weight"] = conv_up2_w; + tensors["conv_up2.bias"] = conv_up2_b; + tensors["conv_hr.weight"] = conv_hr_w; + tensors["conv_hr.bias"] = conv_hr_b; + + tensors["conv_last.weight"] = conv_last_w; + tensors["conv_last.bias"] = conv_last_b; + } + + ggml_tensor* forward(ggml_context* ctx0, ggml_tensor* out_scale, ggml_tensor* x /* feat */) { + // feat = self.conv_first(feat) + auto h = ggml_conv_2d(ctx0, conv_first_w, x, 1, 1, 1, 1, 1, 1); + h = ggml_add(ctx0, h, ggml_reshape_4d(ctx0, conv_first_b, 1, 1, conv_first_b->ne[0], 1)); + + auto body_h = h; + // self.body(feat) + for(int i = 0; i < num_blocks; i++) { + body_h = body_blocks[i].forward(ctx0, out_scale, body_h); + } + + // body_feat = self.conv_body(self.body(feat)) + body_h = ggml_conv_2d(ctx0, conv_body_w, body_h, 1, 1, 1, 1, 1, 1); + body_h = ggml_add(ctx0, body_h, ggml_reshape_4d(ctx0, conv_body_b, 1, 1, conv_body_b->ne[0], 1)); + + // feat = feat + body_feat + h = ggml_add(ctx0, h, body_h); + + // upsample + // feat = self.lrelu(self.conv_up1(F.interpolate(feat, scale_factor=2, mode='nearest'))) + h = ggml_upscale(ctx0, h, 2); + h = ggml_conv_2d(ctx0, conv_up1_w, h, 1, 1, 1, 1, 1, 1); + h = ggml_add(ctx0, h, ggml_reshape_4d(ctx0, conv_up1_b, 1, 1, conv_up1_b->ne[0], 1)); + h = ggml_leaky_relu(ctx0, h, 0.2f, true); + + // feat = self.lrelu(self.conv_up2(F.interpolate(feat, scale_factor=2, mode='nearest'))) + h = ggml_upscale(ctx0, h, 2); + h = ggml_conv_2d(ctx0, conv_up2_w, h, 1, 1, 1, 1, 1, 1); + h = ggml_add(ctx0, h, ggml_reshape_4d(ctx0, conv_up2_b, 1, 1, conv_up2_b->ne[0], 1)); + h = ggml_leaky_relu(ctx0, h, 0.2f, true); + + // self.lrelu(self.conv_hr(feat)) + h = ggml_conv_2d(ctx0, conv_hr_w, h, 1, 1, 1, 1, 1, 1); + h = ggml_add(ctx0, h, ggml_reshape_4d(ctx0, conv_hr_b, 1, 1, conv_hr_b->ne[0], 1)); + h = ggml_leaky_relu(ctx0, h, 0.2f, true); + + // out = self.conv_last(self.lrelu(self.conv_hr(feat))) + h = ggml_conv_2d(ctx0, conv_last_w, h, 1, 1, 1, 1, 1, 1); + h = ggml_add(ctx0, h, ggml_reshape_4d(ctx0, conv_last_b, 1, 1, conv_last_b->ne[0], 1)); + return h; + } + + struct ggml_cgraph * build_graph(struct ggml_tensor* x) { + // since we are using ggml-alloc, this buffer only needs enough space to hold the ggml_tensor and ggml_cgraph structs, but not the tensor data + static size_t buf_size = ggml_tensor_overhead() * GGML_DEFAULT_GRAPH_SIZE + ggml_graph_overhead(); + static std::vector buf(buf_size); + + struct ggml_init_params params = { + /*.mem_size =*/ buf_size, + /*.mem_buffer =*/ buf.data(), + /*.no_alloc =*/ true, // the tensors will be allocated later by ggml_allocr_alloc_graph() + }; + + struct ggml_context * ctx0 = ggml_init(params); + + struct ggml_cgraph * gf = ggml_new_graph(ctx0); + + struct ggml_tensor* x_ = NULL; + struct ggml_tensor* os = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1); + ggml_allocr_alloc(compute_alloc, os); + if(!ggml_allocr_is_measure(compute_alloc)) { + float scale = 0.2f; + ggml_backend_tensor_set(os, &scale, 0, sizeof(scale)); + } + + // it's performing a compute, check if backend isn't cpu + if(!ggml_backend_is_cpu(backend)) { + // pass input tensors to gpu memory + x_ = ggml_dup_tensor(ctx0, x); + ggml_allocr_alloc(compute_alloc, x_); + + // pass data to device backend + if(!ggml_allocr_is_measure(compute_alloc)) { + ggml_backend_tensor_set(x_, x->data, 0, ggml_nbytes(x)); + } + } else { + x_ = x; + } + + struct ggml_tensor* out = forward(ctx0, os, x); + + ggml_build_forward_expand(gf, out); + ggml_free(ctx0); + + return gf; + } + + void begin(struct ggml_tensor* x) { + // calculate the amount of memory required + // alignment required by the backend + compute_alloc = ggml_allocr_new_measure_from_backend(backend); + + struct ggml_cgraph * gf = build_graph(x); + + // compute the required memory + size_t compute_memory_buffer_size = ggml_allocr_alloc_graph(compute_alloc, gf); + + // recreate the allocator with the required memory + ggml_allocr_free(compute_alloc); + + LOG_DEBUG("ESRGAN compute buffer size: %.2f MB", compute_memory_buffer_size / 1024.0 / 1024.0); + + compute_buffer = ggml_backend_alloc_buffer(backend, compute_memory_buffer_size); + compute_alloc = ggml_allocr_new_from_buffer(compute_buffer); + } + + void compute(struct ggml_tensor* work_result, const int n_threads, struct ggml_tensor* x) { + ggml_allocr_reset(compute_alloc); + + struct ggml_cgraph * gf = build_graph(x); + ggml_allocr_alloc_graph(compute_alloc, gf); + + if (ggml_backend_is_cpu(backend)) { + ggml_backend_cpu_set_n_threads(backend, n_threads); + } + + ggml_backend_graph_compute(backend, gf); + +#ifdef GGML_PERF + ggml_graph_print(gf); +#endif + ggml_tensor* out = gf->nodes[gf->n_nodes - 1]; + ggml_backend_tensor_get(out, work_result->data, 0, ggml_nbytes(out)); + } + + void end() { + ggml_allocr_free(compute_alloc); + ggml_backend_buffer_free(compute_buffer); + compute_alloc = NULL; + } +}; + + + float ggml_backend_tensor_get_f32(ggml_tensor* tensor) { GGML_ASSERT(tensor->type == GGML_TYPE_F32 || tensor->type == GGML_TYPE_F16); float value; @@ -4336,6 +4902,10 @@ class StableDiffusionGGML { TinyAutoEncoder tae_first_stage; std::string taesd_path; + ESRGAN esrgan_upscaler; + std::string esrgan_path; + bool upscale_output = false; + StableDiffusionGGML() = default; StableDiffusionGGML(int n_threads, @@ -4611,6 +5181,11 @@ class StableDiffusionGGML { } LOG_DEBUG("finished loaded file"); ggml_free(ctx); + if(upscale_output) { + if(!esrgan_upscaler.load_from_file(esrgan_path, backend)) { + return false; + } + } if (use_tiny_autoencoder) { return tae_first_stage.load_from_file(taesd_path, backend); } @@ -5282,6 +5857,64 @@ class StableDiffusionGGML { } return result; } + + uint8_t* upscale(ggml_tensor* image) { + int input_width = image->ne[0]; + int input_height = image->ne[1]; + int scale = esrgan_upscaler.scale; + int tile_size = esrgan_upscaler.tile_size; + GGML_ASSERT(input_width % 2 == 0 && input_height % 2 == 0); // should be multiple of 2 + int output_width = input_width * scale; + int output_height = input_height * scale; + int tile_width = (input_width < tile_size) ? input_width : tile_size; + int tile_height = (input_height < tile_size) ? input_height : tile_size; + + struct ggml_init_params params; + params.mem_size = output_width * output_height * 3 * sizeof(float); // upscaled + params.mem_size += tile_width * tile_height * 3 * sizeof(float); // input chunk + params.mem_size += (tile_width * scale) * (tile_height * scale) * 3 * sizeof(float); // output chunk + params.mem_size += 4 * ggml_tensor_overhead(); + params.mem_buffer = NULL; + params.no_alloc = false; + + // draft context + struct ggml_context* upscale_ctx = ggml_init(params); + if (!upscale_ctx) { + LOG_ERROR("ggml_init() failed"); + return NULL; + } + + LOG_DEBUG("upscaling from (%i x %i) to (%i x %i)", input_width, input_height, output_width, output_height); + LOG_DEBUG("upscale work buffer size: %.2f MB", params.mem_size / 1024.f / 1024.f); + // tiling + int tiles_x = (input_width + tile_size - 1) / tile_size; + int tiles_y = (input_height + tile_size - 1) / tile_size; + ggml_tensor* input_chunk = ggml_new_tensor_4d(upscale_ctx, GGML_TYPE_F32, tile_width, tile_height, image->ne[2], 1); + ggml_tensor* output_chunk = ggml_new_tensor_4d(upscale_ctx, GGML_TYPE_F32, tile_width * scale, tile_height * scale, image->ne[2], 1); + ggml_tensor* upscaled_image = ggml_new_tensor_4d(upscale_ctx, GGML_TYPE_F32, output_width, output_height, image->ne[2], 1); + esrgan_upscaler.begin(input_chunk); + int num_tiles = tiles_x * tiles_y; + LOG_INFO("processing %i tiles", num_tiles); + pretty_progress(1, num_tiles, 0.0f); + int64_t t0 = ggml_time_ms(); + for(int y = 0; y < tiles_y; y ++) { + for(int x = 0; x < tiles_x; x++) { + int64_t t1 = ggml_time_ms(); + sd_split_chunk(image, input_chunk, x, y); + esrgan_upscaler.compute(output_chunk, n_threads, input_chunk); + sd_merge_chunk(output_chunk, upscaled_image, x, y); + int64_t t2 = ggml_time_ms(); + pretty_progress(x + y * tiles_x + 1, num_tiles, (t2 - t1) / 1000.0f); + } + } + esrgan_upscaler.end(); + sd_clamp(upscaled_image, 0.f, 1.f); + uint8_t* upscaled_data = sd_tensor_to_image(upscaled_image); + ggml_free(upscale_ctx); + int64_t t3 = ggml_time_ms(); + LOG_INFO("image upscaled, taking %.2fs", (t3 - t0) / 1000.0f); + return upscaled_data; + } }; /*================================================= StableDiffusion ==================================================*/ @@ -5289,6 +5922,7 @@ class StableDiffusionGGML { StableDiffusion::StableDiffusion(int n_threads, bool vae_decode_only, std::string taesd_path, + std::string esrgan_path, bool free_params_immediately, std::string lora_model_dir, RNGType rng_type) { @@ -5299,6 +5933,8 @@ StableDiffusion::StableDiffusion(int n_threads, rng_type); sd->use_tiny_autoencoder = taesd_path.size() > 0; sd->taesd_path = taesd_path; + sd->upscale_output = esrgan_path.size() > 0; + sd->esrgan_path = esrgan_path; } bool StableDiffusion::load_from_file(const std::string& model_path, @@ -5406,7 +6042,11 @@ std::vector StableDiffusion::txt2img(std::string prompt, t1 = ggml_time_ms(); struct ggml_tensor* img = sd->compute_first_stage(work_ctx, final_latents[i] /* x_0 */, true); if (img != NULL) { - results.push_back(sd_tensor_to_image(img)); + if(sd->upscale_output) { + results.push_back(sd->upscale(img)); + } else { + results.push_back(sd_tensor_to_image(img)); + } } int64_t t2 = ggml_time_ms(); LOG_INFO("latent %" PRId64 " decoded, taking %.2fs", i + 1, (t2 - t1) * 1.0f / 1000); @@ -5519,7 +6159,11 @@ std::vector StableDiffusion::img2img(const uint8_t* init_img_data, struct ggml_tensor* img = sd->compute_first_stage(work_ctx, x_0, true); if (img != NULL) { - result.push_back(sd_tensor_to_image(img)); + if(sd->upscale_output) { + result.push_back(sd->upscale(img)); + } else { + result.push_back(sd_tensor_to_image(img)); + } } int64_t t4 = ggml_time_ms(); diff --git a/stable-diffusion.h b/stable-diffusion.h index 095016c0f..5de14f407 100644 --- a/stable-diffusion.h +++ b/stable-diffusion.h @@ -39,6 +39,7 @@ class StableDiffusion { StableDiffusion(int n_threads = -1, bool vae_decode_only = false, std::string taesd_path = "", + std::string esrgan_path = "", bool free_params_immediately = false, std::string lora_model_dir = "", RNGType rng_type = STD_DEFAULT_RNG); From f140532741d4d7737e1be423de7cd0af341df34c Mon Sep 17 00:00:00 2001 From: FSSRepo Date: Fri, 8 Dec 2023 15:16:58 -0500 Subject: [PATCH 02/24] add sd_tiling --- ggml | 2 +- model.cpp | 6 +- model.h | 2 +- stable-diffusion.cpp | 135 ++++++++++++++++++++++++++++--------------- stable-diffusion.h | 1 + 5 files changed, 95 insertions(+), 51 deletions(-) diff --git a/ggml b/ggml index f7a51f1b5..b3e6e664b 160000 --- a/ggml +++ b/ggml @@ -1 +1 @@ -Subproject commit f7a51f1b53e85f9ac47ae522bb655963023c8776 +Subproject commit b3e6e664b34666cebeecc43fd8b1bb93c8639a9b diff --git a/model.cpp b/model.cpp index 3adbec9f8..b316879b7 100644 --- a/model.cpp +++ b/model.cpp @@ -1201,7 +1201,7 @@ bool ModelLoader::load_vocab(on_new_token_cb_t on_new_token_cb) { return true; } -bool ModelLoader::load_tensors(on_new_tensor_cb_t on_new_tensor_cb) { +bool ModelLoader::load_tensors(on_new_tensor_cb_t on_new_tensor_cb, ggml_backend_t backend) { bool success = true; for (size_t file_index = 0; file_index < file_paths_.size(); file_index++) { std::string file_path = file_paths_[file_index]; @@ -1289,11 +1289,9 @@ bool ModelLoader::load_tensors(on_new_tensor_cb_t on_new_tensor_cb) { continue; } - ggml_backend_t backend = ggml_get_backend(dst_tensor); - size_t nbytes_to_read = tensor_storage.nbytes_to_read(); - if (backend == NULL || ggml_backend_is_cpu(backend)) { + if (dst_tensor->buffer == NULL || ggml_backend_is_cpu(backend)) { // for the CPU and Metal backend, we can copy directly into the tensor if (tensor_storage.type == dst_tensor->type) { GGML_ASSERT(ggml_nbytes(dst_tensor) == tensor_storage.nbytes()); diff --git a/model.h b/model.h index 6f27cdbf9..8a97cbc70 100644 --- a/model.h +++ b/model.h @@ -116,7 +116,7 @@ class ModelLoader { SDVersion get_sd_version(); ggml_type get_sd_wtype(); bool load_vocab(on_new_token_cb_t on_new_token_cb); - bool load_tensors(on_new_tensor_cb_t on_new_tensor_cb); + bool load_tensors(on_new_tensor_cb_t on_new_tensor_cb, ggml_backend_t backend); int64_t cal_mem_size(ggml_backend_t backend); ~ModelLoader() = default; }; diff --git a/stable-diffusion.cpp b/stable-diffusion.cpp index 8374304de..9940fa673 100644 --- a/stable-diffusion.cpp +++ b/stable-diffusion.cpp @@ -341,7 +341,7 @@ void sd_split_chunk(struct ggml_tensor* input, int64_t width = output->ne[0]; int64_t height = output->ne[1]; int64_t channels = output->ne[2]; - GGML_ASSERT(channels == 3 && output->type == GGML_TYPE_F32); + GGML_ASSERT(input->type == GGML_TYPE_F32 && output->type == GGML_TYPE_F32); for (int iy = 0; iy < height; iy++) { for (int ix = 0; ix < width; ix++) { for (int k = 0; k < channels; k++) { @@ -357,7 +357,7 @@ void sd_merge_chunk(struct ggml_tensor* input, int64_t width = input->ne[0]; int64_t height = input->ne[1]; int64_t channels = input->ne[2]; - GGML_ASSERT(channels == 3 && input->type == GGML_TYPE_F32); + GGML_ASSERT(input->type == GGML_TYPE_F32 && output->type == GGML_TYPE_F32); for (int iy = 0; iy < height; iy++) { for (int ix = 0; ix < width; ix++) { for (int k = 0; k < channels; k++) { @@ -415,6 +415,58 @@ void sd_convert_output(struct ggml_tensor* src) { } } +typedef std::function on_tile_process; + +// Tiling +void sd_tiling(ggml_tensor* input, ggml_tensor* output, const int scale, const int tile_size, on_tile_process proc_tile) { + int input_width = input->ne[0]; + int input_height = input->ne[1]; + int output_width = output->ne[0]; + int output_height = output->ne[1]; + GGML_ASSERT(input_width % 2 == 0 && input_height % 2 == 0 && output_width % 2 == 0 && output_height % 2 == 0); // should be multiple of 2 + + int tile_width = (input_width < tile_size) ? input_width : tile_size; + int tile_height = (input_height < tile_size) ? input_height : tile_size; + LOG_DEBUG("tile size(%ix%i)", tile_width, tile_height); + + struct ggml_init_params params = {}; + params.mem_size += tile_width * tile_height * input->ne[2] * sizeof(float); // input chunk + params.mem_size += (tile_width * scale) * (tile_height * scale) * output->ne[2] * sizeof(float); // output chunk + params.mem_size += 3 * ggml_tensor_overhead(); + params.mem_buffer = NULL; + params.no_alloc = false; + + LOG_DEBUG("tile work buffer size: %.2f MB", params.mem_size / 1024.f / 1024.f); + + // draft context + struct ggml_context* tiles_ctx = ggml_init(params); + if (!tiles_ctx) { + LOG_ERROR("ggml_init() failed"); + return; + } + + // tiling + int tiles_x = (input_width + tile_size - 1) / tile_size; + int tiles_y = (input_height + tile_size - 1) / tile_size; + ggml_tensor* input_tile = ggml_new_tensor_4d(tiles_ctx, GGML_TYPE_F32, tile_width, tile_height, input->ne[2], 1); + ggml_tensor* output_tile = ggml_new_tensor_4d(tiles_ctx, GGML_TYPE_F32, tile_width * scale, tile_height * scale, output->ne[2], 1); + proc_tile(input_tile, NULL, true); + + int num_tiles = tiles_x * tiles_y; + LOG_INFO("processing %i tiles", num_tiles); + pretty_progress(1, num_tiles, 0.0f); + for(int y = 0; y < tiles_y; y ++) { + for(int x = 0; x < tiles_x; x++) { + int64_t t1 = ggml_time_ms(); + sd_split_chunk(input, input_tile, x, y); + proc_tile(input_tile, output_tile, false); + sd_merge_chunk(output_tile, output, x, y); + int64_t t2 = ggml_time_ms(); + pretty_progress(x + y * tiles_x + 1, num_tiles, (t2 - t1) / 1000.0f); + } + } +} + struct ggml_tensor* ggml_group_norm_32(struct ggml_context* ctx, struct ggml_tensor* a) { return ggml_group_norm(ctx, a, 32); @@ -3901,7 +3953,7 @@ struct TinyAutoEncoder { return true; }; - bool success = model_loader.load_tensors(on_new_tensor_cb); + bool success = model_loader.load_tensors(on_new_tensor_cb, backend); bool some_tensor_not_init = false; @@ -4362,7 +4414,7 @@ struct ESRGAN { return true; }; - bool success = model_loader.load_tensors(on_new_tensor_cb); + bool success = model_loader.load_tensors(on_new_tensor_cb, backend); bool some_tensor_not_init = false; @@ -4600,7 +4652,7 @@ struct LoraModel { return true; }; - model_loader.load_tensors(on_new_tensor_cb); + model_loader.load_tensors(on_new_tensor_cb, backend); LOG_DEBUG("finished loaded lora"); ggml_allocr_free(alloc); @@ -4941,7 +4993,7 @@ class StableDiffusionGGML { Schedule schedule) { #ifdef SD_USE_CUBLAS LOG_DEBUG("Using CUDA backend"); - backend = ggml_backend_cuda_init(); + backend = ggml_backend_cuda_init(0); #endif if (!backend) { LOG_DEBUG("Using CPU backend"); @@ -5093,7 +5145,7 @@ class StableDiffusionGGML { // print_ggml_tensor(alphas_cumprod_tensor); - success = model_loader.load_tensors(on_new_tensor_cb); + success = model_loader.load_tensors(on_new_tensor_cb, backend); if (!success) { LOG_ERROR("load tensors from file failed"); ggml_free(ctx); @@ -5839,6 +5891,14 @@ class StableDiffusionGGML { } else { sd_convert_input(x); } + // auto tiling = [&](ggml_tensor* in, ggml_tensor* out, bool init) { + // if(init) { + // first_stage_model.begin(in, decode); + // } else { + // first_stage_model.compute(out, n_threads, in, decode); + // } + // }; + // sd_tiling(x, result, 8, 32, tiling); first_stage_model.begin(x, decode); first_stage_model.compute(result, n_threads, x, decode); first_stage_model.end(); @@ -5846,6 +5906,14 @@ class StableDiffusionGGML { sd_convert_output(result); } } else { + // auto tiling = [&](ggml_tensor* in, ggml_tensor* out, bool init) { + // if(init) { + // tae_first_stage.begin(in, decode); + // } else { + // tae_first_stage.compute(out, n_threads, in, decode); + // } + // }; + // sd_tiling(x, result, 8, 32, tiling); tae_first_stage.begin(x, decode); tae_first_stage.compute(result, n_threads, x, decode); tae_first_stage.end(); @@ -5859,57 +5927,34 @@ class StableDiffusionGGML { } uint8_t* upscale(ggml_tensor* image) { - int input_width = image->ne[0]; - int input_height = image->ne[1]; - int scale = esrgan_upscaler.scale; - int tile_size = esrgan_upscaler.tile_size; - GGML_ASSERT(input_width % 2 == 0 && input_height % 2 == 0); // should be multiple of 2 - int output_width = input_width * scale; - int output_height = input_height * scale; - int tile_width = (input_width < tile_size) ? input_width : tile_size; - int tile_height = (input_height < tile_size) ? input_height : tile_size; - + int output_width = image->ne[0] * esrgan_upscaler.scale; + int output_height = image->ne[1] * esrgan_upscaler.scale; + LOG_INFO("upscaling from (%i x %i) to (%i x %i)", image->ne[0], image->ne[1], output_width, output_height); struct ggml_init_params params; params.mem_size = output_width * output_height * 3 * sizeof(float); // upscaled - params.mem_size += tile_width * tile_height * 3 * sizeof(float); // input chunk - params.mem_size += (tile_width * scale) * (tile_height * scale) * 3 * sizeof(float); // output chunk - params.mem_size += 4 * ggml_tensor_overhead(); + params.mem_size += 1 * ggml_tensor_overhead(); params.mem_buffer = NULL; params.no_alloc = false; - // draft context struct ggml_context* upscale_ctx = ggml_init(params); if (!upscale_ctx) { LOG_ERROR("ggml_init() failed"); return NULL; } - - LOG_DEBUG("upscaling from (%i x %i) to (%i x %i)", input_width, input_height, output_width, output_height); LOG_DEBUG("upscale work buffer size: %.2f MB", params.mem_size / 1024.f / 1024.f); - // tiling - int tiles_x = (input_width + tile_size - 1) / tile_size; - int tiles_y = (input_height + tile_size - 1) / tile_size; - ggml_tensor* input_chunk = ggml_new_tensor_4d(upscale_ctx, GGML_TYPE_F32, tile_width, tile_height, image->ne[2], 1); - ggml_tensor* output_chunk = ggml_new_tensor_4d(upscale_ctx, GGML_TYPE_F32, tile_width * scale, tile_height * scale, image->ne[2], 1); - ggml_tensor* upscaled_image = ggml_new_tensor_4d(upscale_ctx, GGML_TYPE_F32, output_width, output_height, image->ne[2], 1); - esrgan_upscaler.begin(input_chunk); - int num_tiles = tiles_x * tiles_y; - LOG_INFO("processing %i tiles", num_tiles); - pretty_progress(1, num_tiles, 0.0f); - int64_t t0 = ggml_time_ms(); - for(int y = 0; y < tiles_y; y ++) { - for(int x = 0; x < tiles_x; x++) { - int64_t t1 = ggml_time_ms(); - sd_split_chunk(image, input_chunk, x, y); - esrgan_upscaler.compute(output_chunk, n_threads, input_chunk); - sd_merge_chunk(output_chunk, upscaled_image, x, y); - int64_t t2 = ggml_time_ms(); - pretty_progress(x + y * tiles_x + 1, num_tiles, (t2 - t1) / 1000.0f); + ggml_tensor* upscaled = ggml_new_tensor_4d(upscale_ctx, GGML_TYPE_F32, output_width, output_height, image->ne[2], 1); + auto tiling = [&](ggml_tensor* in, ggml_tensor* out, bool init) { + if(init) { + esrgan_upscaler.begin(in); + } else { + esrgan_upscaler.compute(out, n_threads, in); } - } + }; + int64_t t0 = ggml_time_ms(); + sd_tiling(image, upscaled, esrgan_upscaler.scale, esrgan_upscaler.tile_size, tiling); esrgan_upscaler.end(); - sd_clamp(upscaled_image, 0.f, 1.f); - uint8_t* upscaled_data = sd_tensor_to_image(upscaled_image); + sd_clamp(upscaled, 0.f, 1.f); + uint8_t* upscaled_data = sd_tensor_to_image(upscaled); ggml_free(upscale_ctx); int64_t t3 = ggml_time_ms(); LOG_INFO("image upscaled, taking %.2fs", (t3 - t0) / 1000.0f); diff --git a/stable-diffusion.h b/stable-diffusion.h index 5de14f407..a43a886b9 100644 --- a/stable-diffusion.h +++ b/stable-diffusion.h @@ -4,6 +4,7 @@ #include #include #include +#include "ggml/ggml.h" enum RNGType { STD_DEFAULT_RNG, From b5ade20cd5f2ced8f70b6d65d18d27f70354ab43 Mon Sep 17 00:00:00 2001 From: FSSRepo Date: Fri, 8 Dec 2023 15:26:01 -0500 Subject: [PATCH 03/24] ggml: adapt to new backend --- ggml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ggml b/ggml index b3e6e664b..793a5c490 160000 --- a/ggml +++ b/ggml @@ -1 +1 @@ -Subproject commit b3e6e664b34666cebeecc43fd8b1bb93c8639a9b +Subproject commit 793a5c49031a7e968fa1e67af5a91c7cdec68be3 From f83b742ef2e035f5b5cb462ff24bd304d4946546 Mon Sep 17 00:00:00 2001 From: FSSRepo Date: Sat, 9 Dec 2023 13:20:42 -0500 Subject: [PATCH 04/24] fix some conflicts --- stable-diffusion.cpp | 38 +------------------------------------- 1 file changed, 1 insertion(+), 37 deletions(-) diff --git a/stable-diffusion.cpp b/stable-diffusion.cpp index eedce4822..510f32392 100644 --- a/stable-diffusion.cpp +++ b/stable-diffusion.cpp @@ -335,7 +335,6 @@ void sd_image_to_tensor(const uint8_t* image_data, } } - void sd_split_chunk(struct ggml_tensor* input, struct ggml_tensor* output, int x, int y) { int64_t width = output->ne[0]; @@ -368,40 +367,6 @@ void sd_merge_chunk(struct ggml_tensor* input, } } - -void sd_split_chunk(struct ggml_tensor* input, - struct ggml_tensor* output, int x, int y) { - int64_t width = output->ne[0]; - int64_t height = output->ne[1]; - int64_t channels = output->ne[2]; - GGML_ASSERT(input->type == GGML_TYPE_F32 && output->type == GGML_TYPE_F32); - for (int iy = 0; iy < height; iy++) { - for (int ix = 0; ix < width; ix++) { - for (int k = 0; k < channels; k++) { - float value = ggml_tensor_get_f32(input, ix + x * width, iy + y * height, k); - ggml_tensor_set_f32(output, value, ix, iy, k); - } - } - } -} - -void sd_merge_chunk(struct ggml_tensor* input, - struct ggml_tensor* output, int x, int y) { - int64_t width = input->ne[0]; - int64_t height = input->ne[1]; - int64_t channels = input->ne[2]; - GGML_ASSERT(input->type == GGML_TYPE_F32 && output->type == GGML_TYPE_F32); - for (int iy = 0; iy < height; iy++) { - for (int ix = 0; ix < width; ix++) { - for (int k = 0; k < channels; k++) { - float value = ggml_tensor_get_f32(input, ix, iy, k); - ggml_tensor_set_f32(output, value, ix + x * width, iy + y * height, k); - } - } - } -} - - float ggml_tensor_mean(struct ggml_tensor* src) { float mean = 0.0f; int64_t nelements = ggml_nelements(src); @@ -472,7 +437,6 @@ void sd_tiling(ggml_tensor* input, ggml_tensor* output, const int scale, const i int tile_width = (input_width < tile_size) ? input_width : tile_size; int tile_height = (input_height < tile_size) ? input_height : tile_size; - LOG_DEBUG("tile size(%ix%i)", tile_width, tile_height); struct ggml_init_params params = {}; params.mem_size += tile_width * tile_height * input->ne[2] * sizeof(float); // input chunk @@ -5901,7 +5865,7 @@ class StableDiffusionGGML { int64_t t0 = ggml_time_ms(); sd_tiling(image, upscaled, esrgan_upscaler.scale, esrgan_upscaler.tile_size, tiling); esrgan_upscaler.end(); - sd_clamp(upscaled, 0.f, 1.f); + ggml_tensor_clamp(upscaled, 0.f, 1.f); uint8_t* upscaled_data = sd_tensor_to_image(upscaled); ggml_free(upscale_ctx); int64_t t3 = ggml_time_ms(); From 136474df54ba95be7a67612a6f0bdb4a2855a989 Mon Sep 17 00:00:00 2001 From: FSSRepo Date: Sun, 10 Dec 2023 17:50:55 -0500 Subject: [PATCH 05/24] sd_tiling support overlapping + vae tiling --- README.md | 2 + examples/cli/main.cpp | 17 ++++- stable-diffusion.cpp | 140 +++++++++++++++++++++++++++--------------- stable-diffusion.h | 4 +- test.py | 23 +++++++ 5 files changed, 133 insertions(+), 53 deletions(-) create mode 100644 test.py diff --git a/README.md b/README.md index 117cd0cac..63b52ecf4 100644 --- a/README.md +++ b/README.md @@ -52,6 +52,8 @@ Inference of [Stable Diffusion](https://github.com/CompVis/stable-diffusion) in - Implement Winograd Convolution 2D for 3x3 kernel filtering - [ ] Continuing to reduce memory usage (quantizing the weights of ggml_conv_2d) - [ ] Implement BPE Tokenizer +- [ ] Implement Textual Inversion (embeddings) +- [ ] Implement Inpainting support - [ ] k-quants support ## Usage diff --git a/examples/cli/main.cpp b/examples/cli/main.cpp index e6e2e55fe..2d29ae9ba 100644 --- a/examples/cli/main.cpp +++ b/examples/cli/main.cpp @@ -68,6 +68,7 @@ struct SDParams { std::string prompt; std::string negative_prompt; float cfg_scale = 7.0f; + int clip_skip_layers = 0; int width = 512; int height = 512; int batch_count = 1; @@ -79,6 +80,7 @@ struct SDParams { RNGType rng_type = CUDA_RNG; int64_t seed = 42; bool verbose = false; + bool vae_tiling = false; }; void print_params(SDParams params) { @@ -136,6 +138,8 @@ void print_usage(int argc, const char* argv[]) { printf(" -s SEED, --seed SEED RNG seed (default: 42, use random seed for < 0)\n"); printf(" -b, --batch-count COUNT number of images to generate.\n"); printf(" --schedule {discrete, karras} Denoiser sigma schedule (default: discrete)\n"); + printf(" -cs, --clip-skip N number of layers to skip of clip model (default: 0)\n"); + printf(" -vt, --vae-tiling process vae in tiles to reduce memory usage\n"); printf(" -v, --verbose print extra info\n"); } @@ -278,6 +282,14 @@ void parse_args(int argc, const char** argv, SDParams& params) { break; } params.sample_steps = std::stoi(argv[i]); + } else if (arg == "-cs" || arg == "--clip-skip") { + if (++i >= argc) { + invalid_arg = true; + break; + } + params.clip_skip_layers = std::stoi(argv[i]); + } else if (arg == "-vt" || arg == "--vae-tiling") { + params.vae_tiling = true; } else if (arg == "-b" || arg == "--batch-count") { if (++i >= argc) { invalid_arg = true; @@ -466,9 +478,9 @@ int main(int argc, const char* argv[]) { } } - StableDiffusion sd(params.n_threads, vae_decode_only, params.taesd_path, params.esrgan_path, true, params.lora_model_dir, params.rng_type); + StableDiffusion sd(params.n_threads, vae_decode_only, params.taesd_path, params.esrgan_path, true, params.vae_tiling, params.lora_model_dir, params.rng_type); - if (!sd.load_from_file(params.model_path, params.vae_path, params.wtype, params.schedule)) { + if (!sd.load_from_file(params.model_path, params.vae_path, params.wtype, params.schedule, params.clip_skip_layers)) { return 1; } @@ -497,6 +509,7 @@ int main(int argc, const char* argv[]) { } if(params.esrgan_path.size() > 0) { + // TODO: support more ESRGAN models, making it easier to set up ESRGAN models. /* hardcoded scale factor because just RealESRGAN_x4plus_anime_6B is compatible See also: https://github.com/xinntao/Real-ESRGAN/blob/master/inference_realesrgan.py diff --git a/stable-diffusion.cpp b/stable-diffusion.cpp index 510f32392..22b37ab18 100644 --- a/stable-diffusion.cpp +++ b/stable-diffusion.cpp @@ -335,7 +335,7 @@ void sd_image_to_tensor(const uint8_t* image_data, } } -void sd_split_chunk(struct ggml_tensor* input, +void ggml_split_tensor_2d(struct ggml_tensor* input, struct ggml_tensor* output, int x, int y) { int64_t width = output->ne[0]; int64_t height = output->ne[1]; @@ -344,15 +344,15 @@ void sd_split_chunk(struct ggml_tensor* input, for (int iy = 0; iy < height; iy++) { for (int ix = 0; ix < width; ix++) { for (int k = 0; k < channels; k++) { - float value = ggml_tensor_get_f32(input, ix + x * width, iy + y * height, k); + float value = ggml_tensor_get_f32(input, ix + x, iy + y, k); ggml_tensor_set_f32(output, value, ix, iy, k); } } } } -void sd_merge_chunk(struct ggml_tensor* input, - struct ggml_tensor* output, int x, int y) { +void ggml_merge_tensor_2d(struct ggml_tensor* input, + struct ggml_tensor* output, int x, int y, int overlap) { int64_t width = input->ne[0]; int64_t height = input->ne[1]; int64_t channels = input->ne[2]; @@ -360,8 +360,19 @@ void sd_merge_chunk(struct ggml_tensor* input, for (int iy = 0; iy < height; iy++) { for (int ix = 0; ix < width; ix++) { for (int k = 0; k < channels; k++) { - float value = ggml_tensor_get_f32(input, ix, iy, k); - ggml_tensor_set_f32(output, value, ix + x * width, iy + y * height, k); + float new_value = ggml_tensor_get_f32(input, ix, iy, k); + if(overlap > 0) { // blend colors in overlapped area + float old_value = ggml_tensor_get_f32(output, x + ix, y + iy, k); + if(x > 0 && ix < overlap) { // in overlapped horizontal + ggml_tensor_set_f32(output, old_value + (new_value - old_value) * (ix / (1.0f * overlap)), x + ix,y + iy, k); + continue; + } + if(y > 0 && iy < overlap) { // in overlapped vertical + ggml_tensor_set_f32(output, old_value + (new_value - old_value) * (iy / (1.0f * overlap)), x + ix,y + iy, k); + continue; + } + } + ggml_tensor_set_f32(output, new_value, x + ix, y + iy, k); } } } @@ -428,19 +439,19 @@ void ggml_tensor_scale_output(struct ggml_tensor* src) { typedef std::function on_tile_process; // Tiling -void sd_tiling(ggml_tensor* input, ggml_tensor* output, const int scale, const int tile_size, on_tile_process proc_tile) { +void sd_tiling(ggml_tensor* input, ggml_tensor* output, const int scale, const int tile_size, const float tile_overlap_factor, on_tile_process on_processing) { int input_width = input->ne[0]; int input_height = input->ne[1]; int output_width = output->ne[0]; int output_height = output->ne[1]; GGML_ASSERT(input_width % 2 == 0 && input_height % 2 == 0 && output_width % 2 == 0 && output_height % 2 == 0); // should be multiple of 2 - int tile_width = (input_width < tile_size) ? input_width : tile_size; - int tile_height = (input_height < tile_size) ? input_height : tile_size; + int tile_overlap = (int32_t)(tile_size * tile_overlap_factor); + int non_tile_overlap = tile_size - tile_overlap; struct ggml_init_params params = {}; - params.mem_size += tile_width * tile_height * input->ne[2] * sizeof(float); // input chunk - params.mem_size += (tile_width * scale) * (tile_height * scale) * output->ne[2] * sizeof(float); // output chunk + params.mem_size += tile_size * tile_size * input->ne[2] * sizeof(float); // input chunk + params.mem_size += (tile_size * scale) * (tile_size * scale) * output->ne[2] * sizeof(float); // output chunk params.mem_size += 3 * ggml_tensor_overhead(); params.mem_buffer = NULL; params.no_alloc = false; @@ -455,24 +466,38 @@ void sd_tiling(ggml_tensor* input, ggml_tensor* output, const int scale, const i } // tiling - int tiles_x = (input_width + tile_size - 1) / tile_size; - int tiles_y = (input_height + tile_size - 1) / tile_size; - ggml_tensor* input_tile = ggml_new_tensor_4d(tiles_ctx, GGML_TYPE_F32, tile_width, tile_height, input->ne[2], 1); - ggml_tensor* output_tile = ggml_new_tensor_4d(tiles_ctx, GGML_TYPE_F32, tile_width * scale, tile_height * scale, output->ne[2], 1); - proc_tile(input_tile, NULL, true); - - int num_tiles = tiles_x * tiles_y; + ggml_tensor* input_tile = ggml_new_tensor_4d(tiles_ctx, GGML_TYPE_F32, tile_size, tile_size, input->ne[2], 1); + ggml_tensor* output_tile = ggml_new_tensor_4d(tiles_ctx, GGML_TYPE_F32, tile_size * scale, tile_size * scale, output->ne[2], 1); + on_processing(input_tile, NULL, true); + int num_tiles = (input_width * input_height) / (non_tile_overlap * non_tile_overlap); LOG_INFO("processing %i tiles", num_tiles); pretty_progress(1, num_tiles, 0.0f); - for(int y = 0; y < tiles_y; y ++) { - for(int x = 0; x < tiles_x; x++) { + int tile_count = 1; + bool last_y = false, last_x = false; + float last_time = 0.0f; + for(int y = 0; y < input_height && !last_y; y += non_tile_overlap) { + if (y + tile_size >= input_height) { + y = input_height - tile_size; + last_y = true; + } + for(int x = 0; x < input_width && !last_x; x += non_tile_overlap) { + if (x + tile_size >= input_width) { + x = input_width - tile_size; + last_x = true; + } int64_t t1 = ggml_time_ms(); - sd_split_chunk(input, input_tile, x, y); - proc_tile(input_tile, output_tile, false); - sd_merge_chunk(output_tile, output, x, y); + ggml_split_tensor_2d(input, input_tile, x, y); + on_processing(input_tile, output_tile, false); + ggml_merge_tensor_2d(output_tile, output, x * scale, y * scale, tile_overlap * scale); int64_t t2 = ggml_time_ms(); - pretty_progress(x + y * tiles_x + 1, num_tiles, (t2 - t1) / 1000.0f); + last_time = (t2 - t1) / 1000.0f; + pretty_progress(tile_count, num_tiles, last_time); + tile_count++; } + last_x = false; + } + if(tile_count < num_tiles) { + pretty_progress(num_tiles, num_tiles, last_time); } } @@ -962,6 +987,7 @@ struct CLIPTextModel { int32_t intermediate_size = 3072; // 4096 for SD 2.x int32_t n_head = 12; // num_attention_heads, 16 for SD 2.x int32_t num_hidden_layers = 12; // 24 for SD 2.x + int32_t skip_layers = 0; // embeddings struct ggml_tensor* position_ids; @@ -1121,7 +1147,7 @@ struct CLIPTextModel { // transformer for (int i = 0; i < num_hidden_layers; i++) { - if (version == VERSION_2_x && i == num_hidden_layers - 1) { // layer: "penultimate" + if (version == VERSION_2_x && i == num_hidden_layers - 1 || i > (num_hidden_layers - skip_layers - 1)) { // layer: "penultimate" break; } x = resblocks[i].forward(ctx0, x); // [N, n_token, hidden_size] @@ -1207,6 +1233,7 @@ struct CLIPTextModel { ggml_backend_buffer_free(compute_buffer); compute_alloc = NULL; compute_memory_buffer_size = -1; + work_output = NULL; } }; @@ -4844,6 +4871,7 @@ class StableDiffusionGGML { UNetModel diffusion_model; AutoEncoderKL first_stage_model; bool use_tiny_autoencoder = false; + bool vae_tiling = false; std::map tensors; @@ -4895,7 +4923,7 @@ class StableDiffusionGGML { bool load_from_file(const std::string& model_path, const std::string& vae_path, ggml_type wtype, - Schedule schedule) { + Schedule schedule, int clip_skip_layers) { #ifdef SD_USE_CUBLAS LOG_DEBUG("Using CUDA backend"); backend = ggml_backend_cuda_init(0); @@ -4932,6 +4960,7 @@ class StableDiffusionGGML { return false; } cond_stage_model = FrozenCLIPEmbedderWithCustomWords(version); + cond_stage_model.text_model.skip_layers = clip_skip_layers; diffusion_model = UNetModel(version); LOG_INFO("Stable Diffusion %s ", model_version_to_str[version]); if (wtype == GGML_TYPE_COUNT) { @@ -5803,31 +5832,39 @@ class StableDiffusionGGML { } else { ggml_tensor_scale_input(x); } - // auto tiling = [&](ggml_tensor* in, ggml_tensor* out, bool init) { - // if(init) { - // first_stage_model.begin(in, decode); - // } else { - // first_stage_model.compute(out, n_threads, in, decode); - // } - // }; - // sd_tiling(x, result, 8, 32, tiling); - first_stage_model.begin(x, decode); - first_stage_model.compute(result, n_threads, x, decode); + if(vae_tiling && decode) { // TODO: support tiling vae encode + // split latent in 32x32 tiles and compute in several steps + auto on_tiling = [&](ggml_tensor* in, ggml_tensor* out, bool init) { + if(init) { + first_stage_model.begin(in, decode); + } else { + first_stage_model.compute(out, n_threads, in, decode); + } + }; + sd_tiling(x, result, 8, 32, 0.5f, on_tiling); + } else { + first_stage_model.begin(x, decode); + first_stage_model.compute(result, n_threads, x, decode); + } first_stage_model.end(); if (decode) { ggml_tensor_scale_output(result); } } else { - // auto tiling = [&](ggml_tensor* in, ggml_tensor* out, bool init) { - // if(init) { - // tae_first_stage.begin(in, decode); - // } else { - // tae_first_stage.compute(out, n_threads, in, decode); - // } - // }; - // sd_tiling(x, result, 8, 32, tiling); - tae_first_stage.begin(x, decode); - tae_first_stage.compute(result, n_threads, x, decode); + if(vae_tiling && decode) { // TODO: support tiling vae encode + // split latent in 64x64 tiles and compute in several steps + auto on_tiling = [&](ggml_tensor* in, ggml_tensor* out, bool init) { + if(init) { + tae_first_stage.begin(in, decode); + } else { + tae_first_stage.compute(out, n_threads, in, decode); + } + }; + sd_tiling(x, result, 8, 64, 0.5f, on_tiling); + } else { + tae_first_stage.begin(x, decode); + tae_first_stage.compute(result, n_threads, x, decode); + } tae_first_stage.end(); } int64_t t1 = ggml_time_ms(); @@ -5855,7 +5892,7 @@ class StableDiffusionGGML { } LOG_DEBUG("upscale work buffer size: %.2f MB", params.mem_size / 1024.f / 1024.f); ggml_tensor* upscaled = ggml_new_tensor_4d(upscale_ctx, GGML_TYPE_F32, output_width, output_height, image->ne[2], 1); - auto tiling = [&](ggml_tensor* in, ggml_tensor* out, bool init) { + auto on_tiling = [&](ggml_tensor* in, ggml_tensor* out, bool init) { if(init) { esrgan_upscaler.begin(in); } else { @@ -5863,7 +5900,7 @@ class StableDiffusionGGML { } }; int64_t t0 = ggml_time_ms(); - sd_tiling(image, upscaled, esrgan_upscaler.scale, esrgan_upscaler.tile_size, tiling); + sd_tiling(image, upscaled, esrgan_upscaler.scale, esrgan_upscaler.tile_size, 0.25f, on_tiling); esrgan_upscaler.end(); ggml_tensor_clamp(upscaled, 0.f, 1.f); uint8_t* upscaled_data = sd_tensor_to_image(upscaled); @@ -5889,6 +5926,7 @@ StableDiffusion::StableDiffusion(int n_threads, std::string taesd_path, std::string esrgan_path, bool free_params_immediately, + bool vae_tiling, std::string lora_model_dir, RNGType rng_type) { sd = std::make_shared(n_threads, @@ -5900,13 +5938,15 @@ StableDiffusion::StableDiffusion(int n_threads, sd->taesd_path = taesd_path; sd->upscale_output = esrgan_path.size() > 0; sd->esrgan_path = esrgan_path; + sd->vae_tiling = vae_tiling; } bool StableDiffusion::load_from_file(const std::string& model_path, const std::string& vae_path, ggml_type wtype, - Schedule s) { - return sd->load_from_file(model_path, vae_path, wtype, s); + Schedule s, + int clip_skip_layers) { + return sd->load_from_file(model_path, vae_path, wtype, s, clip_skip_layers); } std::vector StableDiffusion::txt2img(std::string prompt, diff --git a/stable-diffusion.h b/stable-diffusion.h index a43a886b9..ac2b60b90 100644 --- a/stable-diffusion.h +++ b/stable-diffusion.h @@ -42,12 +42,14 @@ class StableDiffusion { std::string taesd_path = "", std::string esrgan_path = "", bool free_params_immediately = false, + bool vae_tiling = false, std::string lora_model_dir = "", RNGType rng_type = STD_DEFAULT_RNG); bool load_from_file(const std::string& model_path, const std::string& vae_path, ggml_type wtype, - Schedule d = DEFAULT); + Schedule d = DEFAULT, + int clip_skip_layers = 0); std::vector txt2img( std::string prompt, std::string negative_prompt, diff --git a/test.py b/test.py new file mode 100644 index 000000000..83d5e68b0 --- /dev/null +++ b/test.py @@ -0,0 +1,23 @@ +import math + +def split_grid(image, tile_w=512, tile_h=512, overlap=64): + w = image["width"] + h = image["height"] + non_overlap_width = tile_w - overlap + non_overlap_height = tile_h - overlap + cols = math.ceil((w - overlap) / non_overlap_width) + rows = math.ceil((h - overlap) / non_overlap_height) + dx = (w - tile_w) / (cols - 1) if cols > 1 else 0 + dy = (h - tile_h) / (rows - 1) if rows > 1 else 0 + for row in range(rows): + row_images = [] + y = int(row * dy) + if y + tile_h >= h: + y = h - tile_h + for col in range(cols): + x = int(col * dx) + if x + tile_w >= w: + x = w - tile_w + print(f"cursor({x}, {y})") + +split_grid({"width": 512, "height": 512}, 128, 128, 64) \ No newline at end of file From 2bcca30192c204f449596100330331dc6cf4772f Mon Sep 17 00:00:00 2001 From: FSSRepo Date: Sun, 10 Dec 2023 17:58:26 -0500 Subject: [PATCH 06/24] prepare to use metal as backend --- CMakeLists.txt | 7 +++++++ stable-diffusion.cpp | 22 ++++++++++++++++------ test.py | 23 ----------------------- 3 files changed, 23 insertions(+), 29 deletions(-) delete mode 100644 test.py diff --git a/CMakeLists.txt b/CMakeLists.txt index b119ee6ec..788f0d896 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -25,6 +25,7 @@ endif() #option(SD_BUILD_TESTS "sd: build tests" ${SD_STANDALONE}) option(SD_BUILD_EXAMPLES "sd: build examples" ${SD_STANDALONE}) option(SD_CUBLAS "sd: cuda backend" OFF) +option(SD_METAL "sd: metal backend" OFF) option(SD_FLASH_ATTN "sd: use flash attention for x4 less memory usage" OFF) option(BUILD_SHARED_LIBS "sd: build shared libs" OFF) #option(SD_BUILD_SERVER "sd: build server example" ON) @@ -35,6 +36,12 @@ if(SD_CUBLAS) add_definitions(-DSD_USE_CUBLAS) endif() +if(SD_METAL) + message("Use Metal as backend stable-diffusion") + set(GGML_METAL ON) + add_definitions(-DSD_USE_METAL) +endif() + if(SD_FLASH_ATTN) message("Use Flash Attention for memory optimization") add_definitions(-DSD_USE_FLASH_ATTENTION) diff --git a/stable-diffusion.cpp b/stable-diffusion.cpp index 22b37ab18..cdcdd0700 100644 --- a/stable-diffusion.cpp +++ b/stable-diffusion.cpp @@ -23,6 +23,10 @@ #include "ggml-cuda.h" #endif +#ifdef SD_USE_METAL +#include "ggml-metal.h" +#endif + #include "model.h" #include "rng.h" #include "rng_philox.h" @@ -1628,7 +1632,7 @@ struct SpatialTransformer { { x = ggml_reshape_2d(ctx, x, c, h * w * n); // [N * h * w, in_channels] struct ggml_tensor* q = ggml_mul_mat(ctx, transformer.attn1_q_w, x); // [N * h * w, in_channels] -#if !defined(SD_USE_FLASH_ATTENTION) || defined(SD_USE_CUBLAS) +#if !defined(SD_USE_FLASH_ATTENTION) || defined(SD_USE_CUBLAS) || defined(SD_USE_METAL) q = ggml_scale_inplace(ctx, q, attn_scale); #endif q = ggml_reshape_4d(ctx, q, d_head, n_head, h * w, n); // [N, h * w, n_head, d_head] @@ -1645,7 +1649,7 @@ struct SpatialTransformer { v = ggml_cont(ctx, ggml_permute(ctx, v, 1, 2, 0, 3)); // [N, n_head, d_head, h * w] v = ggml_reshape_3d(ctx, v, h * w, d_head, n_head * n); // [N * n_head, d_head, h * w] -#if defined(SD_USE_FLASH_ATTENTION) && !defined(SD_USE_CUBLAS) +#if defined(SD_USE_FLASH_ATTENTION) && !defined(SD_USE_CUBLAS) && !defined(SD_USE_METAL) struct ggml_tensor* kqv = ggml_flash_attn(ctx, q, k, v, false); // [N * n_head, h * w, d_head] #else struct ggml_tensor* kq = ggml_mul_mat(ctx, k, q); // [N * n_head, h * w, h * w] @@ -1676,7 +1680,7 @@ struct SpatialTransformer { x = ggml_reshape_2d(ctx, x, c, h * w * n); // [N * h * w, in_channels] context = ggml_reshape_2d(ctx, context, context->ne[0], context->ne[1] * context->ne[2]); // [N * max_position, hidden_size] struct ggml_tensor* q = ggml_mul_mat(ctx, transformer.attn2_q_w, x); // [N * h * w, in_channels] -#if !defined(SD_USE_FLASH_ATTENTION) || defined(SD_USE_CUBLAS) +#if !defined(SD_USE_FLASH_ATTENTION) || defined(SD_USE_CUBLAS) || defined(SD_USE_METAL) q = ggml_scale_inplace(ctx, q, attn_scale); #endif q = ggml_reshape_4d(ctx, q, d_head, n_head, h * w, n); // [N, h * w, n_head, d_head] @@ -1692,7 +1696,7 @@ struct SpatialTransformer { v = ggml_reshape_4d(ctx, v, d_head, n_head, max_position, n); // [N, max_position, n_head, d_head] v = ggml_cont(ctx, ggml_permute(ctx, v, 1, 2, 0, 3)); // [N, n_head, d_head, max_position] v = ggml_reshape_3d(ctx, v, max_position, d_head, n_head * n); // [N * n_head, d_head, max_position] -#if defined(SD_USE_FLASH_ATTENTION) && !defined(SD_USE_CUBLAS) +#if defined(SD_USE_FLASH_ATTENTION) && !defined(SD_USE_CUBLAS) && !defined(SD_USE_METAL) struct ggml_tensor* kqv = ggml_flash_attn(ctx, q, k, v, false); // [N * n_head, h * w, d_head] #else struct ggml_tensor* kq = ggml_mul_mat(ctx, k, q); // [N * n_head, h * w, max_position] @@ -4928,13 +4932,19 @@ class StableDiffusionGGML { LOG_DEBUG("Using CUDA backend"); backend = ggml_backend_cuda_init(0); #endif +#ifdef SD_USE_METAL + LOG_DEBUG("Using Metal backend"); + ggml_metal_log_set_callback(ggml_log_callback_default, nullptr); + backend = ggml_backend_metal_init(); +#endif + if (!backend) { LOG_DEBUG("Using CPU backend"); backend = ggml_backend_cpu_init(); } #ifdef SD_USE_FLASH_ATTENTION -#ifdef SD_USE_CUBLAS - LOG_WARN("Flash Attention not supported with CUDA"); +#if defined(SD_USE_CUBLAS) || defined(SD_USE_METAL) + LOG_WARN("Flash Attention not supported with GPU Backend"); #else LOG_INFO("Flash Attention enabled"); #endif diff --git a/test.py b/test.py deleted file mode 100644 index 83d5e68b0..000000000 --- a/test.py +++ /dev/null @@ -1,23 +0,0 @@ -import math - -def split_grid(image, tile_w=512, tile_h=512, overlap=64): - w = image["width"] - h = image["height"] - non_overlap_width = tile_w - overlap - non_overlap_height = tile_h - overlap - cols = math.ceil((w - overlap) / non_overlap_width) - rows = math.ceil((h - overlap) / non_overlap_height) - dx = (w - tile_w) / (cols - 1) if cols > 1 else 0 - dy = (h - tile_h) / (rows - 1) if rows > 1 else 0 - for row in range(rows): - row_images = [] - y = int(row * dy) - if y + tile_h >= h: - y = h - tile_h - for col in range(cols): - x = int(col * dx) - if x + tile_w >= w: - x = w - tile_w - print(f"cursor({x}, {y})") - -split_grid({"width": 512, "height": 512}, 128, 128, 64) \ No newline at end of file From 05101e4bf0ae895a353ee5526d098d2d7140f469 Mon Sep 17 00:00:00 2001 From: FSSRepo Date: Sun, 10 Dec 2023 18:16:57 -0500 Subject: [PATCH 07/24] support metal backend --- ggml | 2 +- model.cpp | 6 +++++- stable-diffusion.cpp | 37 +++++++++++++++++++++++++++++++++++++ 3 files changed, 43 insertions(+), 2 deletions(-) diff --git a/ggml b/ggml index 793a5c490..4a6f0e50f 160000 --- a/ggml +++ b/ggml @@ -1 +1 @@ -Subproject commit 793a5c49031a7e968fa1e67af5a91c7cdec68be3 +Subproject commit 4a6f0e50f3659fd054534240f949f916ea9cf6cf diff --git a/model.cpp b/model.cpp index d3673df57..22f1cf920 100644 --- a/model.cpp +++ b/model.cpp @@ -1298,7 +1298,11 @@ bool ModelLoader::load_tensors(on_new_tensor_cb_t on_new_tensor_cb, ggml_backend size_t nbytes_to_read = tensor_storage.nbytes_to_read(); - if (dst_tensor->buffer == NULL || ggml_backend_is_cpu(backend)) { + if (dst_tensor->buffer == NULL || ggml_backend_is_cpu(backend) +#ifdef SD_USE_METAL + || ggml_backend_is_metal(model.backend) +#endif + ) { // for the CPU and Metal backend, we can copy directly into the tensor if (tensor_storage.type == dst_tensor->type) { GGML_ASSERT(ggml_nbytes(dst_tensor) == tensor_storage.nbytes()); diff --git a/stable-diffusion.cpp b/stable-diffusion.cpp index cdcdd0700..5df7767f4 100644 --- a/stable-diffusion.cpp +++ b/stable-diffusion.cpp @@ -1223,6 +1223,12 @@ struct CLIPTextModel { ggml_backend_cpu_set_n_threads(backend, n_threads); } +#ifdef SD_USE_METAL + if (ggml_backend_is_metal(backend)) { + ggml_backend_metal_set_n_cb(backend, n_threads); + } +#endif + ggml_backend_graph_compute(backend, gf); #ifdef GGML_PERF @@ -2494,6 +2500,12 @@ struct UNetModel { ggml_backend_cpu_set_n_threads(backend, n_threads); } +#ifdef SD_USE_METAL + if (ggml_backend_is_metal(backend)) { + ggml_backend_metal_set_n_cb(backend, n_threads); + } +#endif + ggml_backend_graph_compute(backend, gf); #ifdef GGML_PERF @@ -3299,6 +3311,12 @@ struct AutoEncoderKL { ggml_backend_cpu_set_n_threads(backend, n_threads); } +#ifdef SD_USE_METAL + if (ggml_backend_is_metal(backend)) { + ggml_backend_metal_set_n_cb(backend, n_threads); + } +#endif + ggml_backend_graph_compute(backend, gf); #ifdef GGML_PERF @@ -3976,6 +3994,12 @@ struct TinyAutoEncoder { ggml_backend_cpu_set_n_threads(backend, n_threads); } +#ifdef SD_USE_METAL + if (ggml_backend_is_metal(backend)) { + ggml_backend_metal_set_n_cb(backend, n_threads); + } +#endif + ggml_backend_graph_compute(backend, gf); #ifdef GGML_PERF @@ -4507,6 +4531,12 @@ struct ESRGAN { ggml_backend_cpu_set_n_threads(backend, n_threads); } +#ifdef SD_USE_METAL + if (ggml_backend_is_metal(backend)) { + ggml_backend_metal_set_n_cb(backend, n_threads); + } +#endif + ggml_backend_graph_compute(backend, gf); #ifdef GGML_PERF @@ -4724,6 +4754,13 @@ struct LoraModel { if (ggml_backend_is_cpu(backend)) { ggml_backend_cpu_set_n_threads(backend, n_threads); } + +#ifdef SD_USE_METAL + if (ggml_backend_is_metal(backend)) { + ggml_backend_metal_set_n_cb(backend, n_threads); + } +#endif + ggml_backend_graph_compute(backend, gf); ggml_allocr_free(compute_alloc); ggml_backend_buffer_free(buffer_compute_lora); From e641b8cd453adfea33a4c269f3a1ac981fd52d6f Mon Sep 17 00:00:00 2001 From: FSSRepo Date: Sun, 10 Dec 2023 18:23:57 -0500 Subject: [PATCH 08/24] fix submodule --- ggml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ggml b/ggml index 4a6f0e50f..f7a51f1b5 160000 --- a/ggml +++ b/ggml @@ -1 +1 @@ -Subproject commit 4a6f0e50f3659fd054534240f949f916ea9cf6cf +Subproject commit f7a51f1b53e85f9ac47ae522bb655963023c8776 From e1f5a1ce4a03021d7bd428ea5c1b6f933b5c9a8b Mon Sep 17 00:00:00 2001 From: FSSRepo Date: Sun, 10 Dec 2023 18:35:24 -0500 Subject: [PATCH 09/24] fix metal ggml-submodule --- ggml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ggml b/ggml index f7a51f1b5..e7584f77a 160000 --- a/ggml +++ b/ggml @@ -1 +1 @@ -Subproject commit f7a51f1b53e85f9ac47ae522bb655963023c8776 +Subproject commit e7584f77a2e26504c3f172784d67de5b253f6609 From 5759b5347c9577b48349816b590b808f890727d1 Mon Sep 17 00:00:00 2001 From: FSSRepo Date: Sun, 10 Dec 2023 18:50:31 -0500 Subject: [PATCH 10/24] fix possible error --- model.cpp | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/model.cpp b/model.cpp index 22f1cf920..b73c949ce 100644 --- a/model.cpp +++ b/model.cpp @@ -14,6 +14,10 @@ #include "ggml/ggml-backend.h" #include "ggml/ggml.h" +#ifdef SD_USE_METAL +#include "ggml-metal.h" +#endif + #define ST_HEADER_SIZE_LEN 8 uint64_t read_u64(uint8_t* buffer) { From 4d4674758710a1f7efa16937120faaf13a498e5b Mon Sep 17 00:00:00 2001 From: FSSRepo Date: Mon, 11 Dec 2023 09:17:44 -0500 Subject: [PATCH 11/24] fix metal compilation errors --- examples/cli/main.cpp | 2 +- model.cpp | 2 +- stable-diffusion.cpp | 7 +++++++ 3 files changed, 9 insertions(+), 2 deletions(-) diff --git a/examples/cli/main.cpp b/examples/cli/main.cpp index 2d29ae9ba..8d1defe08 100644 --- a/examples/cli/main.cpp +++ b/examples/cli/main.cpp @@ -478,7 +478,7 @@ int main(int argc, const char* argv[]) { } } - StableDiffusion sd(params.n_threads, vae_decode_only, params.taesd_path, params.esrgan_path, true, params.vae_tiling, params.lora_model_dir, params.rng_type); + StableDiffusion sd(params.n_threads, vae_decode_only, params.taesd_path, params.esrgan_path, false, params.vae_tiling, params.lora_model_dir, params.rng_type); if (!sd.load_from_file(params.model_path, params.vae_path, params.wtype, params.schedule, params.clip_skip_layers)) { return 1; diff --git a/model.cpp b/model.cpp index b73c949ce..15672ab87 100644 --- a/model.cpp +++ b/model.cpp @@ -1304,7 +1304,7 @@ bool ModelLoader::load_tensors(on_new_tensor_cb_t on_new_tensor_cb, ggml_backend if (dst_tensor->buffer == NULL || ggml_backend_is_cpu(backend) #ifdef SD_USE_METAL - || ggml_backend_is_metal(model.backend) + || ggml_backend_is_metal(backend) #endif ) { // for the CPU and Metal backend, we can copy directly into the tensor diff --git a/stable-diffusion.cpp b/stable-diffusion.cpp index 5df7767f4..76e95d598 100644 --- a/stable-diffusion.cpp +++ b/stable-diffusion.cpp @@ -79,6 +79,13 @@ std::string sd_get_system_info() { return ss.str(); } +static void ggml_log_callback_default(ggml_log_level level, const char * text, void * user_data) { + (void) level; + (void) user_data; + fputs(text, stderr); + fflush(stderr); +} + void ggml_tensor_set_f32_randn(struct ggml_tensor* tensor, std::shared_ptr rng) { uint32_t n = (uint32_t)ggml_nelements(tensor); std::vector random_numbers = rng->randn(n); From dfe6abb791bd30de7ce903a2ef5a92d0f0401e0f Mon Sep 17 00:00:00 2001 From: FSSRepo Date: Mon, 11 Dec 2023 09:19:26 -0500 Subject: [PATCH 12/24] restore free memory param --- examples/cli/main.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/cli/main.cpp b/examples/cli/main.cpp index 8d1defe08..2d29ae9ba 100644 --- a/examples/cli/main.cpp +++ b/examples/cli/main.cpp @@ -478,7 +478,7 @@ int main(int argc, const char* argv[]) { } } - StableDiffusion sd(params.n_threads, vae_decode_only, params.taesd_path, params.esrgan_path, false, params.vae_tiling, params.lora_model_dir, params.rng_type); + StableDiffusion sd(params.n_threads, vae_decode_only, params.taesd_path, params.esrgan_path, true, params.vae_tiling, params.lora_model_dir, params.rng_type); if (!sd.load_from_file(params.model_path, params.vae_path, params.wtype, params.schedule, params.clip_skip_layers)) { return 1; From e8d4cb0bd3ace8410980df531801b9b0e3f79ae6 Mon Sep 17 00:00:00 2001 From: FSSRepo Date: Tue, 12 Dec 2023 12:18:05 -0500 Subject: [PATCH 13/24] fix metal backendissue --- ggml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ggml b/ggml index e7584f77a..01135e2f8 160000 --- a/ggml +++ b/ggml @@ -1 +1 @@ -Subproject commit e7584f77a2e26504c3f172784d67de5b253f6609 +Subproject commit 01135e2f87059a39c5651a6bc9b2b3a2ba9caf3c From 8dc966ecc26387251828d40e19cb6c369013abfa Mon Sep 17 00:00:00 2001 From: FSSRepo Date: Wed, 13 Dec 2023 15:25:04 -0500 Subject: [PATCH 14/24] cuda: fast softmax cmake option --- CMakeLists.txt | 4 ++++ ggml | 2 +- stable-diffusion.cpp | 1 + 3 files changed, 6 insertions(+), 1 deletion(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 788f0d896..c58db32cd 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -27,6 +27,7 @@ option(SD_BUILD_EXAMPLES "sd: build examples" ${SD_STANDALONE}) option(SD_CUBLAS "sd: cuda backend" OFF) option(SD_METAL "sd: metal backend" OFF) option(SD_FLASH_ATTN "sd: use flash attention for x4 less memory usage" OFF) +option(SD_FAST_SOFTMAX "sd: x1.5 faster softmax, slight variation in the results, cuda only" OFF) option(BUILD_SHARED_LIBS "sd: build shared libs" OFF) #option(SD_BUILD_SERVER "sd: build server example" ON) @@ -34,6 +35,9 @@ if(SD_CUBLAS) message("Use CUBLAS as backend stable-diffusion") set(GGML_CUBLAS ON) add_definitions(-DSD_USE_CUBLAS) + if(SD_FAST_SOFTMAX) + set(GGML_CUDA_FAST_SOFTMAX ON) + endif() endif() if(SD_METAL) diff --git a/ggml b/ggml index 01135e2f8..11dc0eed5 160000 --- a/ggml +++ b/ggml @@ -1 +1 @@ -Subproject commit 01135e2f87059a39c5651a6bc9b2b3a2ba9caf3c +Subproject commit 11dc0eed5cc86fb0bf0c8c72f3817f50edb27659 diff --git a/stable-diffusion.cpp b/stable-diffusion.cpp index 76e95d598..bd9bebd13 100644 --- a/stable-diffusion.cpp +++ b/stable-diffusion.cpp @@ -5334,6 +5334,7 @@ class StableDiffusionGGML { int64_t t0 = ggml_time_ms(); cond_stage_model.text_model.begin(work_ctx, (int)tokens.size()); struct ggml_tensor* hidden_states = cond_stage_model.text_model.compute(n_threads, tokens); // [N, n_token, hidden_size] + print_ggml_tensor(hidden_states); cond_stage_model.text_model.end(); int64_t t1 = ggml_time_ms(); LOG_DEBUG("computing condition graph completed, taking %" PRId64 " ms", t1 - t0); From 8cb12d2eba01c3f50205c127425c3da0814d7b67 Mon Sep 17 00:00:00 2001 From: FSSRepo Date: Wed, 13 Dec 2023 15:47:00 -0500 Subject: [PATCH 15/24] update readme --- README.md | 26 ++++++++++++++++++++++++-- 1 file changed, 24 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 257de0235..088356581 100644 --- a/README.md +++ b/README.md @@ -17,7 +17,7 @@ Inference of [Stable Diffusion](https://github.com/CompVis/stable-diffusion) in - Accelerated memory-efficient CPU inference - Only requires ~2.3GB when using txt2img with fp16 precision to generate a 512x512 image, enabling Flash Attention just requires ~1.8GB. - AVX, AVX2 and AVX512 support for x86 architectures -- Full CUDA backend for GPU acceleration. +- Full CUDA and Metal backend for GPU acceleration. - Can load ckpt, safetensors and diffusers models/checkpoints. Standalone VAEs models - No need to convert to `.ggml` or `.gguf` anymore! - Flash Attention for memory usage optimization (only cpu for now) @@ -28,6 +28,7 @@ Inference of [Stable Diffusion](https://github.com/CompVis/stable-diffusion) in - Latent Consistency Models support (LCM/LCM-LoRA) - Faster and memory efficient latent decoding with [TAESD](https://github.com/madebyollin/taesd) - Upscale images generated with [ESRGAN](https://github.com/xinntao/Real-ESRGAN) +- VAE tiling processing for reduce memory usage - Sampling method - `Euler A` - `Euler` @@ -115,6 +116,15 @@ cmake .. -DSD_CUBLAS=ON cmake --build . --config Release ``` +##### Using Metal + +On MacOS, Metal is enabled by default. Using Metal makes the computation run on the GPU. Currently, there are some issues with Metal when performing operations on very large matrices, making it highly inefficient at the moment. Performance improvements are expected in the near future. + +``` +cmake .. -DSD_METAL=ON +cmake --build . --config Release +``` + ### Using Flash Attention Enabling flash attention reduces memory usage by at least 400 MB. At the moment, it is not supported when CUBLAS is enabled because the kernel implementation is missing. @@ -127,7 +137,7 @@ cmake --build . --config Release ### Run ``` -usage: sd [arguments] +usage: sd.exe [arguments] arguments: -h, --help show this help message and exit @@ -157,6 +167,8 @@ arguments: -s SEED, --seed SEED RNG seed (default: 42, use random seed for < 0) -b, --batch-count COUNT number of images to generate. --schedule {discrete, karras} Denoiser sigma schedule (default: discrete) + -cs, --clip-skip N number of layers to skip of clip model (default: 0) + -vt, --vae-tiling process vae in tiles to reduce memory usage -v, --verbose print extra info ``` @@ -244,6 +256,16 @@ curl -L -O https://huggingface.co/madebyollin/taesd/blob/main/diffusion_pytorch_ sd -m ../models/v1-5-pruned-emaonly.safetensors -p "a lovely cat" --taesd ../models/diffusion_pytorch_model.safetensors ``` +## Using ESRGAN to upscale results + +You can use ESRGAN to upscale the generated images. At the moment, only the [RealESRGAN_x4plus_anime_6B.pth](https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth) model is supported. Support for more models of this architecture will be added soon. + +- Specify the model path using the `--upscale-model PATH` parameter. example: + +```bash +sd -m ../models/v1-5-pruned-emaonly.safetensors -p "a lovely cat" --upscale-model ../models/RealESRGAN_x4plus_anime_6B.pth +``` + ### Docker #### Building using Docker From f97ff96812021341e5ac2885e1a70041658dde77 Mon Sep 17 00:00:00 2001 From: FSSRepo Date: Wed, 13 Dec 2023 16:07:30 -0500 Subject: [PATCH 16/24] improve upscale log info --- README.md | 4 ++-- stable-diffusion.cpp | 18 ++++++++++++------ 2 files changed, 14 insertions(+), 8 deletions(-) diff --git a/README.md b/README.md index 088356581..86a36380e 100644 --- a/README.md +++ b/README.md @@ -118,7 +118,7 @@ cmake --build . --config Release ##### Using Metal -On MacOS, Metal is enabled by default. Using Metal makes the computation run on the GPU. Currently, there are some issues with Metal when performing operations on very large matrices, making it highly inefficient at the moment. Performance improvements are expected in the near future. +Using Metal makes the computation run on the GPU. Currently, there are some issues with Metal when performing operations on very large matrices, making it highly inefficient at the moment. Performance improvements are expected in the near future. ``` cmake .. -DSD_METAL=ON @@ -137,7 +137,7 @@ cmake --build . --config Release ### Run ``` -usage: sd.exe [arguments] +usage: sd [arguments] arguments: -h, --help show this help message and exit diff --git a/stable-diffusion.cpp b/stable-diffusion.cpp index bd9bebd13..6e2ac3345 100644 --- a/stable-diffusion.cpp +++ b/stable-diffusion.cpp @@ -5334,7 +5334,6 @@ class StableDiffusionGGML { int64_t t0 = ggml_time_ms(); cond_stage_model.text_model.begin(work_ctx, (int)tokens.size()); struct ggml_tensor* hidden_states = cond_stage_model.text_model.compute(n_threads, tokens); // [N, n_token, hidden_size] - print_ggml_tensor(hidden_states); cond_stage_model.text_model.end(); int64_t t1 = ggml_time_ms(); LOG_DEBUG("computing condition graph completed, taking %" PRId64 " ms", t1 - t0); @@ -6098,15 +6097,12 @@ std::vector StableDiffusion::txt2img(std::string prompt, LOG_INFO("generating %" PRId64 " latent images completed, taking %.2fs", final_latents.size(), (t3 - t1) * 1.0f / 1000); LOG_INFO("decoding %zu latents", final_latents.size()); + std::vector decoded_images; // collect decoded images for (size_t i = 0; i < final_latents.size(); i++) { t1 = ggml_time_ms(); struct ggml_tensor* img = sd->decode_first_stage(work_ctx, final_latents[i] /* x_0 */); if (img != NULL) { - if(sd->upscale_output) { - results.push_back(sd->upscale(img)); - } else { - results.push_back(sd_tensor_to_image(img)); - } + decoded_images.push_back(img); } int64_t t2 = ggml_time_ms(); LOG_INFO("latent %" PRId64 " decoded, taking %.2fs", i + 1, (t2 - t1) * 1.0f / 1000); @@ -6117,6 +6113,16 @@ std::vector StableDiffusion::txt2img(std::string prompt, if (sd->free_params_immediately && !sd->use_tiny_autoencoder) { sd->first_stage_model.destroy(); } + if(sd->upscale_output) { + LOG_INFO("upscaling %" PRId64 " images", decoded_images.size()); + } + for (size_t i = 0; i < decoded_images.size(); i++) { + if(sd->upscale_output) { + results.push_back(sd->upscale(decoded_images[i])); + } else { + results.push_back(sd_tensor_to_image(decoded_images[i])); + } + } ggml_free(work_ctx); LOG_INFO( "txt2img completed in %.2fs", From 6bdbd2591d6c1763791601de8a6a735616692937 Mon Sep 17 00:00:00 2001 From: leejet Date: Thu, 14 Dec 2023 23:37:24 +0800 Subject: [PATCH 17/24] standardize naming conventions employ PascalCase for class/struct/enumnames and underscore_case for variables/functions/methods. --- README.md | 8 ++++---- examples/cli/main.cpp | 15 +++++++++------ stable-diffusion.cpp | 12 ++++++------ 3 files changed, 19 insertions(+), 16 deletions(-) diff --git a/README.md b/README.md index 86a36380e..3c6a9d97b 100644 --- a/README.md +++ b/README.md @@ -137,7 +137,7 @@ cmake --build . --config Release ### Run ``` -usage: sd [arguments] +usage: ./bin/sd [arguments] arguments: -h, --help show this help message and exit @@ -147,7 +147,7 @@ arguments: -m, --model [MODEL] path to model --vae [VAE] path to vae --taesd [TAESD_PATH] path to taesd. Using Tiny AutoEncoder for fast decoding (low quality) - -um, --upscale-model [ESRGAN_PATH] path to esrgan model. Upscale images after generate, just RealESRGAN_x4plus_anime_6B supported by now. + --upscale-model [ESRGAN_PATH] path to esrgan model. Upscale images after generate, just RealESRGAN_x4plus_anime_6B supported by now. --type [TYPE] weight type (f32, f16, q4_0, q4_1, q5_0, q5_1, q8_0) If not specified, the default is the type of the weight file. --lora-model-dir [DIR] lora model directory @@ -167,8 +167,8 @@ arguments: -s SEED, --seed SEED RNG seed (default: 42, use random seed for < 0) -b, --batch-count COUNT number of images to generate. --schedule {discrete, karras} Denoiser sigma schedule (default: discrete) - -cs, --clip-skip N number of layers to skip of clip model (default: 0) - -vt, --vae-tiling process vae in tiles to reduce memory usage + --clip-skip N number of layers to skip of clip model (default: 0) + --vae-tiling process vae in tiles to reduce memory usage -v, --verbose print extra info ``` diff --git a/examples/cli/main.cpp b/examples/cli/main.cpp index 2d29ae9ba..5aa38d27a 100644 --- a/examples/cli/main.cpp +++ b/examples/cli/main.cpp @@ -91,11 +91,13 @@ void print_params(SDParams params) { printf(" wtype: %s\n", params.wtype < GGML_TYPE_COUNT ? ggml_type_name(params.wtype) : "unspecified"); printf(" vae_path: %s\n", params.vae_path.c_str()); printf(" taesd_path: %s\n", params.taesd_path.c_str()); + printf(" esrgan_path: %s\n", params.esrgan_path.c_str()); printf(" output_path: %s\n", params.output_path.c_str()); printf(" init_img: %s\n", params.input_path.c_str()); printf(" prompt: %s\n", params.prompt.c_str()); printf(" negative_prompt: %s\n", params.negative_prompt.c_str()); printf(" cfg_scale: %.2f\n", params.cfg_scale); + printf(" clip_skip_layers: %d\n", params.clip_skip_layers); printf(" width: %d\n", params.width); printf(" height: %d\n", params.height); printf(" sample_method: %s\n", sample_method_str[params.sample_method]); @@ -105,6 +107,7 @@ void print_params(SDParams params) { printf(" rng: %s\n", rng_type_to_str[params.rng_type]); printf(" seed: %ld\n", params.seed); printf(" batch_count: %d\n", params.batch_count); + printf(" vae_tiling: %s\n", params.vae_tiling ? "true" : "false"); } void print_usage(int argc, const char* argv[]) { @@ -118,7 +121,7 @@ void print_usage(int argc, const char* argv[]) { printf(" -m, --model [MODEL] path to model\n"); printf(" --vae [VAE] path to vae\n"); printf(" --taesd [TAESD_PATH] path to taesd. Using Tiny AutoEncoder for fast decoding (low quality)\n"); - printf(" -um, --upscale-model [ESRGAN_PATH] path to esrgan model. Upscale images after generate, just RealESRGAN_x4plus_anime_6B supported by now.\n"); + printf(" --upscale-model [ESRGAN_PATH] path to esrgan model. Upscale images after generate, just RealESRGAN_x4plus_anime_6B supported by now.\n"); printf(" --type [TYPE] weight type (f32, f16, q4_0, q4_1, q5_0, q5_1, q8_0)\n"); printf(" If not specified, the default is the type of the weight file.\n"); printf(" --lora-model-dir [DIR] lora model directory\n"); @@ -138,8 +141,8 @@ void print_usage(int argc, const char* argv[]) { printf(" -s SEED, --seed SEED RNG seed (default: 42, use random seed for < 0)\n"); printf(" -b, --batch-count COUNT number of images to generate.\n"); printf(" --schedule {discrete, karras} Denoiser sigma schedule (default: discrete)\n"); - printf(" -cs, --clip-skip N number of layers to skip of clip model (default: 0)\n"); - printf(" -vt, --vae-tiling process vae in tiles to reduce memory usage\n"); + printf(" --clip-skip N number of layers to skip of clip model (default: 0)\n"); + printf(" --vae-tiling process vae in tiles to reduce memory usage\n"); printf(" -v, --verbose print extra info\n"); } @@ -191,7 +194,7 @@ void parse_args(int argc, const char** argv, SDParams& params) { break; } params.taesd_path = argv[i]; - } else if (arg == "--upscale-model" || arg == "-um") { + } else if (arg == "--upscale-model") { if (++i >= argc) { invalid_arg = true; break; @@ -282,13 +285,13 @@ void parse_args(int argc, const char** argv, SDParams& params) { break; } params.sample_steps = std::stoi(argv[i]); - } else if (arg == "-cs" || arg == "--clip-skip") { + } else if (arg == "--clip-skip") { if (++i >= argc) { invalid_arg = true; break; } params.clip_skip_layers = std::stoi(argv[i]); - } else if (arg == "-vt" || arg == "--vae-tiling") { + } else if (arg == "--vae-tiling") { params.vae_tiling = true; } else if (arg == "-b" || arg == "--batch-count") { if (++i >= argc) { diff --git a/stable-diffusion.cpp b/stable-diffusion.cpp index 6e2ac3345..4d38dc5d0 100644 --- a/stable-diffusion.cpp +++ b/stable-diffusion.cpp @@ -4075,7 +4075,7 @@ struct ResidualDenseBlock { return mem_size; } - int getNumTensors() { + int get_num_tensors() { int num_tensors = 10; return num_tensors; } @@ -4158,10 +4158,10 @@ struct EsrganBlock { } } - int getNumTensors() { + int get_num_tensors() { int num_tensors = 0; for(int i = 0; i < num_residual_blocks; i++) { - num_tensors += rd_blocks[i].getNumTensors(); + num_tensors += rd_blocks[i].get_num_tensors(); } return num_tensors; } @@ -4268,10 +4268,10 @@ struct ESRGAN { return mem_size; } - int getNumTensors() { + int get_num_tensors() { int num_tensors = 12; for(int i = 0; i < num_blocks; i++) { - num_tensors += body_blocks[i].getNumTensors(); + num_tensors += body_blocks[i].get_num_tensors(); } return num_tensors; } @@ -4280,7 +4280,7 @@ struct ESRGAN { this->backend = backend_; memory_buffer_size = calculate_mem_size(); memory_buffer_size += 1024; // overhead - int num_tensors = getNumTensors(); + int num_tensors = get_num_tensors(); LOG_DEBUG("ESRGAN params backend buffer size = % 6.2f MB (%i tensors)", memory_buffer_size / (1024.0 * 1024.0), num_tensors); From 56e64742a5a526d259755b94fd032f5161b63e66 Mon Sep 17 00:00:00 2001 From: leejet Date: Thu, 14 Dec 2023 23:43:45 +0800 Subject: [PATCH 18/24] format code --- examples/cli/main.cpp | 10 +- model.cpp | 2 +- stable-diffusion.cpp | 370 +++++++++++++++++++++--------------------- 3 files changed, 192 insertions(+), 190 deletions(-) diff --git a/examples/cli/main.cpp b/examples/cli/main.cpp index 5aa38d27a..b5c7ec31b 100644 --- a/examples/cli/main.cpp +++ b/examples/cli/main.cpp @@ -67,11 +67,11 @@ struct SDParams { std::string prompt; std::string negative_prompt; - float cfg_scale = 7.0f; + float cfg_scale = 7.0f; int clip_skip_layers = 0; - int width = 512; - int height = 512; - int batch_count = 1; + int width = 512; + int height = 512; + int batch_count = 1; SampleMethod sample_method = EULER_A; Schedule schedule = DEFAULT; @@ -511,7 +511,7 @@ int main(int argc, const char* argv[]) { params.seed); } - if(params.esrgan_path.size() > 0) { + if (params.esrgan_path.size() > 0) { // TODO: support more ESRGAN models, making it easier to set up ESRGAN models. /* hardcoded scale factor because just RealESRGAN_x4plus_anime_6B is compatible See also: https://github.com/xinntao/Real-ESRGAN/blob/master/inference_realesrgan.py diff --git a/model.cpp b/model.cpp index 15672ab87..abd245a35 100644 --- a/model.cpp +++ b/model.cpp @@ -1304,7 +1304,7 @@ bool ModelLoader::load_tensors(on_new_tensor_cb_t on_new_tensor_cb, ggml_backend if (dst_tensor->buffer == NULL || ggml_backend_is_cpu(backend) #ifdef SD_USE_METAL - || ggml_backend_is_metal(backend) + || ggml_backend_is_metal(backend) #endif ) { // for the CPU and Metal backend, we can copy directly into the tensor diff --git a/stable-diffusion.cpp b/stable-diffusion.cpp index 4d38dc5d0..487df05c9 100644 --- a/stable-diffusion.cpp +++ b/stable-diffusion.cpp @@ -79,9 +79,9 @@ std::string sd_get_system_info() { return ss.str(); } -static void ggml_log_callback_default(ggml_log_level level, const char * text, void * user_data) { - (void) level; - (void) user_data; +static void ggml_log_callback_default(ggml_log_level level, const char* text, void* user_data) { + (void)level; + (void)user_data; fputs(text, stderr); fflush(stderr); } @@ -347,10 +347,12 @@ void sd_image_to_tensor(const uint8_t* image_data, } void ggml_split_tensor_2d(struct ggml_tensor* input, - struct ggml_tensor* output, int x, int y) { - int64_t width = output->ne[0]; - int64_t height = output->ne[1]; - int64_t channels = output->ne[2]; + struct ggml_tensor* output, + int x, + int y) { + int64_t width = output->ne[0]; + int64_t height = output->ne[1]; + int64_t channels = output->ne[2]; GGML_ASSERT(input->type == GGML_TYPE_F32 && output->type == GGML_TYPE_F32); for (int iy = 0; iy < height; iy++) { for (int ix = 0; ix < width; ix++) { @@ -363,23 +365,26 @@ void ggml_split_tensor_2d(struct ggml_tensor* input, } void ggml_merge_tensor_2d(struct ggml_tensor* input, - struct ggml_tensor* output, int x, int y, int overlap) { - int64_t width = input->ne[0]; - int64_t height = input->ne[1]; - int64_t channels = input->ne[2]; + struct ggml_tensor* output, + int x, + int y, + int overlap) { + int64_t width = input->ne[0]; + int64_t height = input->ne[1]; + int64_t channels = input->ne[2]; GGML_ASSERT(input->type == GGML_TYPE_F32 && output->type == GGML_TYPE_F32); for (int iy = 0; iy < height; iy++) { for (int ix = 0; ix < width; ix++) { for (int k = 0; k < channels; k++) { float new_value = ggml_tensor_get_f32(input, ix, iy, k); - if(overlap > 0) { // blend colors in overlapped area + if (overlap > 0) { // blend colors in overlapped area float old_value = ggml_tensor_get_f32(output, x + ix, y + iy, k); - if(x > 0 && ix < overlap) { // in overlapped horizontal - ggml_tensor_set_f32(output, old_value + (new_value - old_value) * (ix / (1.0f * overlap)), x + ix,y + iy, k); + if (x > 0 && ix < overlap) { // in overlapped horizontal + ggml_tensor_set_f32(output, old_value + (new_value - old_value) * (ix / (1.0f * overlap)), x + ix, y + iy, k); continue; } - if(y > 0 && iy < overlap) { // in overlapped vertical - ggml_tensor_set_f32(output, old_value + (new_value - old_value) * (iy / (1.0f * overlap)), x + ix,y + iy, k); + if (y > 0 && iy < overlap) { // in overlapped vertical + ggml_tensor_set_f32(output, old_value + (new_value - old_value) * (iy / (1.0f * overlap)), x + ix, y + iy, k); continue; } } @@ -447,22 +452,22 @@ void ggml_tensor_scale_output(struct ggml_tensor* src) { } } -typedef std::function on_tile_process; +typedef std::function on_tile_process; // Tiling void sd_tiling(ggml_tensor* input, ggml_tensor* output, const int scale, const int tile_size, const float tile_overlap_factor, on_tile_process on_processing) { - int input_width = input->ne[0]; - int input_height = input->ne[1]; - int output_width = output->ne[0]; + int input_width = input->ne[0]; + int input_height = input->ne[1]; + int output_width = output->ne[0]; int output_height = output->ne[1]; - GGML_ASSERT(input_width % 2 == 0 && input_height % 2 == 0 && output_width % 2 == 0 && output_height % 2 == 0); // should be multiple of 2 + GGML_ASSERT(input_width % 2 == 0 && input_height % 2 == 0 && output_width % 2 == 0 && output_height % 2 == 0); // should be multiple of 2 - int tile_overlap = (int32_t)(tile_size * tile_overlap_factor); + int tile_overlap = (int32_t)(tile_size * tile_overlap_factor); int non_tile_overlap = tile_size - tile_overlap; struct ggml_init_params params = {}; - params.mem_size += tile_size * tile_size * input->ne[2] * sizeof(float); // input chunk - params.mem_size += (tile_size * scale) * (tile_size * scale) * output->ne[2] * sizeof(float); // output chunk + params.mem_size += tile_size * tile_size * input->ne[2] * sizeof(float); // input chunk + params.mem_size += (tile_size * scale) * (tile_size * scale) * output->ne[2] * sizeof(float); // output chunk params.mem_size += 3 * ggml_tensor_overhead(); params.mem_buffer = NULL; params.no_alloc = false; @@ -477,7 +482,7 @@ void sd_tiling(ggml_tensor* input, ggml_tensor* output, const int scale, const i } // tiling - ggml_tensor* input_tile = ggml_new_tensor_4d(tiles_ctx, GGML_TYPE_F32, tile_size, tile_size, input->ne[2], 1); + ggml_tensor* input_tile = ggml_new_tensor_4d(tiles_ctx, GGML_TYPE_F32, tile_size, tile_size, input->ne[2], 1); ggml_tensor* output_tile = ggml_new_tensor_4d(tiles_ctx, GGML_TYPE_F32, tile_size * scale, tile_size * scale, output->ne[2], 1); on_processing(input_tile, NULL, true); int num_tiles = (input_width * input_height) / (non_tile_overlap * non_tile_overlap); @@ -486,14 +491,14 @@ void sd_tiling(ggml_tensor* input, ggml_tensor* output, const int scale, const i int tile_count = 1; bool last_y = false, last_x = false; float last_time = 0.0f; - for(int y = 0; y < input_height && !last_y; y += non_tile_overlap) { + for (int y = 0; y < input_height && !last_y; y += non_tile_overlap) { if (y + tile_size >= input_height) { - y = input_height - tile_size; + y = input_height - tile_size; last_y = true; } - for(int x = 0; x < input_width && !last_x; x += non_tile_overlap) { + for (int x = 0; x < input_width && !last_x; x += non_tile_overlap) { if (x + tile_size >= input_width) { - x = input_width - tile_size; + x = input_width - tile_size; last_x = true; } int64_t t1 = ggml_time_ms(); @@ -501,13 +506,13 @@ void sd_tiling(ggml_tensor* input, ggml_tensor* output, const int scale, const i on_processing(input_tile, output_tile, false); ggml_merge_tensor_2d(output_tile, output, x * scale, y * scale, tile_overlap * scale); int64_t t2 = ggml_time_ms(); - last_time = (t2 - t1) / 1000.0f; + last_time = (t2 - t1) / 1000.0f; pretty_progress(tile_count, num_tiles, last_time); tile_count++; } last_x = false; } - if(tile_count < num_tiles) { + if (tile_count < num_tiles) { pretty_progress(num_tiles, num_tiles, last_time); } } @@ -998,7 +1003,7 @@ struct CLIPTextModel { int32_t intermediate_size = 3072; // 4096 for SD 2.x int32_t n_head = 12; // num_attention_heads, 16 for SD 2.x int32_t num_hidden_layers = 12; // 24 for SD 2.x - int32_t skip_layers = 0; + int32_t skip_layers = 0; // embeddings struct ggml_tensor* position_ids; @@ -4034,43 +4039,43 @@ struct TinyAutoEncoder { struct ResidualDenseBlock { int num_features; int num_grow_ch; - ggml_tensor* conv1_w; // [num_grow_ch, num_features, 3, 3] - ggml_tensor* conv1_b; // [num_grow_ch] + ggml_tensor* conv1_w; // [num_grow_ch, num_features, 3, 3] + ggml_tensor* conv1_b; // [num_grow_ch] - ggml_tensor* conv2_w; // [num_grow_ch, num_features + num_grow_ch, 3, 3] - ggml_tensor* conv2_b; // [num_grow_ch] + ggml_tensor* conv2_w; // [num_grow_ch, num_features + num_grow_ch, 3, 3] + ggml_tensor* conv2_b; // [num_grow_ch] - ggml_tensor* conv3_w; // [num_grow_ch, num_features + 2 * num_grow_ch, 3, 3] - ggml_tensor* conv3_b; // [num_grow_ch] + ggml_tensor* conv3_w; // [num_grow_ch, num_features + 2 * num_grow_ch, 3, 3] + ggml_tensor* conv3_b; // [num_grow_ch] - ggml_tensor* conv4_w; // [num_grow_ch, num_features + 3 * num_grow_ch, 3, 3] - ggml_tensor* conv4_b; // [num_grow_ch] + ggml_tensor* conv4_w; // [num_grow_ch, num_features + 3 * num_grow_ch, 3, 3] + ggml_tensor* conv4_b; // [num_grow_ch] - ggml_tensor* conv5_w; // [num_features, num_features + 4 * num_grow_ch, 3, 3] - ggml_tensor* conv5_b; // [num_features] + ggml_tensor* conv5_w; // [num_features, num_features + 4 * num_grow_ch, 3, 3] + ggml_tensor* conv5_b; // [num_features] ResidualDenseBlock() {} ResidualDenseBlock(int num_feat, int n_grow_ch) { num_features = num_feat; - num_grow_ch = n_grow_ch; + num_grow_ch = n_grow_ch; } size_t calculate_mem_size() { - size_t mem_size = num_features * num_grow_ch * 3 * 3 * ggml_type_size(GGML_TYPE_F16); // conv1_w - mem_size += num_grow_ch * ggml_type_size(GGML_TYPE_F32); // conv1_b + size_t mem_size = num_features * num_grow_ch * 3 * 3 * ggml_type_size(GGML_TYPE_F16); // conv1_w + mem_size += num_grow_ch * ggml_type_size(GGML_TYPE_F32); // conv1_b - mem_size += (num_features + num_grow_ch) * num_grow_ch * 3 * 3 * ggml_type_size(GGML_TYPE_F16); // conv2_w - mem_size += num_grow_ch * ggml_type_size(GGML_TYPE_F32); // conv2_b + mem_size += (num_features + num_grow_ch) * num_grow_ch * 3 * 3 * ggml_type_size(GGML_TYPE_F16); // conv2_w + mem_size += num_grow_ch * ggml_type_size(GGML_TYPE_F32); // conv2_b - mem_size += (num_features + 2*num_grow_ch) * num_grow_ch * 3 * 3 * ggml_type_size(GGML_TYPE_F16); // conv3_w - mem_size += num_grow_ch * ggml_type_size(GGML_TYPE_F32); // conv3_w + mem_size += (num_features + 2 * num_grow_ch) * num_grow_ch * 3 * 3 * ggml_type_size(GGML_TYPE_F16); // conv3_w + mem_size += num_grow_ch * ggml_type_size(GGML_TYPE_F32); // conv3_w - mem_size += (num_features + 3*num_grow_ch) * num_grow_ch * 3 * 3 * ggml_type_size(GGML_TYPE_F16); // conv4_w - mem_size += num_grow_ch * ggml_type_size(GGML_TYPE_F32); // conv4_w + mem_size += (num_features + 3 * num_grow_ch) * num_grow_ch * 3 * 3 * ggml_type_size(GGML_TYPE_F16); // conv4_w + mem_size += num_grow_ch * ggml_type_size(GGML_TYPE_F32); // conv4_w - mem_size += (num_features + 4*num_grow_ch) * num_features * 3 * 3 * ggml_type_size(GGML_TYPE_F16); // conv5_w - mem_size += num_features * ggml_type_size(GGML_TYPE_F32); // conv5_w + mem_size += (num_features + 4 * num_grow_ch) * num_features * 3 * 3 * ggml_type_size(GGML_TYPE_F16); // conv5_w + mem_size += num_features * ggml_type_size(GGML_TYPE_F32); // conv5_w return mem_size; } @@ -4091,54 +4096,53 @@ struct ResidualDenseBlock { conv4_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, num_grow_ch); conv5_w = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, 3, 3, num_features + 4 * num_grow_ch, num_features); conv5_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, num_features); - } - void map_by_name(std::map & tensors, std::string prefix) { + void map_by_name(std::map& tensors, std::string prefix) { tensors[prefix + "conv1.weight"] = conv1_w; - tensors[prefix + "conv1.bias"] = conv1_b; + tensors[prefix + "conv1.bias"] = conv1_b; tensors[prefix + "conv2.weight"] = conv2_w; - tensors[prefix + "conv2.bias"] = conv2_b; + tensors[prefix + "conv2.bias"] = conv2_b; tensors[prefix + "conv3.weight"] = conv3_w; - tensors[prefix + "conv3.bias"] = conv3_b; + tensors[prefix + "conv3.bias"] = conv3_b; tensors[prefix + "conv4.weight"] = conv4_w; - tensors[prefix + "conv4.bias"] = conv4_b; + tensors[prefix + "conv4.bias"] = conv4_b; tensors[prefix + "conv5.weight"] = conv5_w; - tensors[prefix + "conv5.bias"] = conv5_b; + tensors[prefix + "conv5.bias"] = conv5_b; } - ggml_tensor* forward(ggml_context* ctx, ggml_tensor* out_scale, ggml_tensor* x /* feat */) { + ggml_tensor* forward(ggml_context* ctx, ggml_tensor* out_scale, ggml_tensor* x /* feat */) { // x1 = self.lrelu(self.conv1(x)) ggml_tensor* x1 = ggml_conv_2d(ctx, conv1_w, x, 1, 1, 1, 1, 1, 1); - x1 = ggml_add(ctx, x1, ggml_reshape_4d(ctx, conv1_b, 1, 1, conv1_b->ne[0], 1)); - x1 = ggml_leaky_relu(ctx, x1, 0.2f, true); + x1 = ggml_add(ctx, x1, ggml_reshape_4d(ctx, conv1_b, 1, 1, conv1_b->ne[0], 1)); + x1 = ggml_leaky_relu(ctx, x1, 0.2f, true); // x2 = self.lrelu(self.conv2(torch.cat((x, x1), 1))) ggml_tensor* x_cat = ggml_concat(ctx, x, x1); - ggml_tensor* x2 = ggml_conv_2d(ctx, conv2_w, x_cat, 1, 1, 1, 1, 1, 1); - x2 = ggml_add(ctx, x2, ggml_reshape_4d(ctx, conv2_b, 1, 1, conv2_b->ne[0], 1)); - x2 = ggml_leaky_relu(ctx, x2, 0.2f, true); + ggml_tensor* x2 = ggml_conv_2d(ctx, conv2_w, x_cat, 1, 1, 1, 1, 1, 1); + x2 = ggml_add(ctx, x2, ggml_reshape_4d(ctx, conv2_b, 1, 1, conv2_b->ne[0], 1)); + x2 = ggml_leaky_relu(ctx, x2, 0.2f, true); // x3 = self.lrelu(self.conv3(torch.cat((x, x1, x2), 1))) - x_cat = ggml_concat(ctx, x_cat, x2); + x_cat = ggml_concat(ctx, x_cat, x2); ggml_tensor* x3 = ggml_conv_2d(ctx, conv3_w, x_cat, 1, 1, 1, 1, 1, 1); - x3 = ggml_add(ctx, x3, ggml_reshape_4d(ctx, conv3_b, 1, 1, conv3_b->ne[0], 1)); - x3 = ggml_leaky_relu(ctx, x3, 0.2f, true); + x3 = ggml_add(ctx, x3, ggml_reshape_4d(ctx, conv3_b, 1, 1, conv3_b->ne[0], 1)); + x3 = ggml_leaky_relu(ctx, x3, 0.2f, true); // x4 = self.lrelu(self.conv4(torch.cat((x, x1, x2, x3), 1))) - x_cat = ggml_concat(ctx, x_cat, x3); + x_cat = ggml_concat(ctx, x_cat, x3); ggml_tensor* x4 = ggml_conv_2d(ctx, conv4_w, x_cat, 1, 1, 1, 1, 1, 1); - x4 = ggml_add(ctx, x4, ggml_reshape_4d(ctx, conv4_b, 1, 1, conv4_b->ne[0], 1)); - x4 = ggml_leaky_relu(ctx, x4, 0.2f, true); + x4 = ggml_add(ctx, x4, ggml_reshape_4d(ctx, conv4_b, 1, 1, conv4_b->ne[0], 1)); + x4 = ggml_leaky_relu(ctx, x4, 0.2f, true); // self.conv5(torch.cat((x, x1, x2, x3, x4), 1)) - x_cat = ggml_concat(ctx, x_cat, x4); + x_cat = ggml_concat(ctx, x_cat, x4); ggml_tensor* x5 = ggml_conv_2d(ctx, conv5_w, x_cat, 1, 1, 1, 1, 1, 1); - x5 = ggml_add(ctx, x5, ggml_reshape_4d(ctx, conv5_b, 1, 1, conv5_b->ne[0], 1)); + x5 = ggml_add(ctx, x5, ggml_reshape_4d(ctx, conv5_b, 1, 1, conv5_b->ne[0], 1)); // return x5 * 0.2 + x x5 = ggml_add(ctx, ggml_scale(ctx, x5, out_scale), x); @@ -4153,14 +4157,14 @@ struct EsrganBlock { EsrganBlock() {} EsrganBlock(int num_feat, int num_grow_ch) { - for(int i = 0; i < num_residual_blocks; i++) { + for (int i = 0; i < num_residual_blocks; i++) { rd_blocks[i] = ResidualDenseBlock(num_feat, num_grow_ch); } } int get_num_tensors() { int num_tensors = 0; - for(int i = 0; i < num_residual_blocks; i++) { + for (int i = 0; i < num_residual_blocks; i++) { num_tensors += rd_blocks[i].get_num_tensors(); } return num_tensors; @@ -4168,28 +4172,27 @@ struct EsrganBlock { size_t calculate_mem_size() { size_t mem_size = 0; - for(int i = 0; i < num_residual_blocks; i++) { + for (int i = 0; i < num_residual_blocks; i++) { mem_size += rd_blocks[i].calculate_mem_size(); } return mem_size; } - void init_params(ggml_context* ctx) { - for(int i = 0; i < num_residual_blocks; i++) { + for (int i = 0; i < num_residual_blocks; i++) { rd_blocks[i].init_params(ctx); } } - void map_by_name(std::map & tensors, std::string prefix) { - for(int i = 0; i < num_residual_blocks; i++) { - rd_blocks[i].map_by_name(tensors, prefix + "rdb" + std::to_string(i + 1) +"."); + void map_by_name(std::map& tensors, std::string prefix) { + for (int i = 0; i < num_residual_blocks; i++) { + rd_blocks[i].map_by_name(tensors, prefix + "rdb" + std::to_string(i + 1) + "."); } } - ggml_tensor* forward(ggml_context* ctx, ggml_tensor* out_scale, ggml_tensor* x) { + ggml_tensor* forward(ggml_context* ctx, ggml_tensor* out_scale, ggml_tensor* x) { ggml_tensor* out = x; - for(int i = 0; i < num_residual_blocks; i++) { + for (int i = 0; i < num_residual_blocks; i++) { // out = self.rdb...(x) out = rd_blocks[i].forward(ctx, out_scale, out); } @@ -4200,94 +4203,94 @@ struct EsrganBlock { }; struct ESRGAN { - int scale = 4; // default RealESRGAN_x4plus_anime_6B - int num_blocks = 6; // default RealESRGAN_x4plus_anime_6B - int in_channels = 3; + int scale = 4; // default RealESRGAN_x4plus_anime_6B + int num_blocks = 6; // default RealESRGAN_x4plus_anime_6B + int in_channels = 3; int out_channels = 3; - int num_features = 64; // default RealESRGAN_x4plus_anime_6B - int num_grow_ch = 32; // default RealESRGAN_x4plus_anime_6B - int tile_size = 128; // avoid cuda OOM for 4gb VRAM + int num_features = 64; // default RealESRGAN_x4plus_anime_6B + int num_grow_ch = 32; // default RealESRGAN_x4plus_anime_6B + int tile_size = 128; // avoid cuda OOM for 4gb VRAM - ggml_tensor* conv_first_w; // [num_features, in_channels, 3, 3] - ggml_tensor* conv_first_b; // [num_features] + ggml_tensor* conv_first_w; // [num_features, in_channels, 3, 3] + ggml_tensor* conv_first_b; // [num_features] EsrganBlock body_blocks[6]; - ggml_tensor* conv_body_w; // [num_features, num_features, 3, 3] - ggml_tensor* conv_body_b; // [num_features] + ggml_tensor* conv_body_w; // [num_features, num_features, 3, 3] + ggml_tensor* conv_body_b; // [num_features] // upsample - ggml_tensor* conv_up1_w; // [num_features, num_features, 3, 3] - ggml_tensor* conv_up1_b; // [num_features] - ggml_tensor* conv_up2_w; // [num_features, num_features, 3, 3] - ggml_tensor* conv_up2_b; // [num_features] + ggml_tensor* conv_up1_w; // [num_features, num_features, 3, 3] + ggml_tensor* conv_up1_b; // [num_features] + ggml_tensor* conv_up2_w; // [num_features, num_features, 3, 3] + ggml_tensor* conv_up2_b; // [num_features] - ggml_tensor* conv_hr_w; // [num_features, num_features, 3, 3] - ggml_tensor* conv_hr_b; // [num_features] - ggml_tensor* conv_last_w; // [out_channels, num_features, 3, 3] - ggml_tensor* conv_last_b; // [out_channels] + ggml_tensor* conv_hr_w; // [num_features, num_features, 3, 3] + ggml_tensor* conv_hr_b; // [num_features] + ggml_tensor* conv_last_w; // [out_channels, num_features, 3, 3] + ggml_tensor* conv_last_b; // [out_channels] ggml_context* ctx; bool decode_only = false; ggml_backend_buffer_t params_buffer; - ggml_backend_buffer_t compute_buffer; // for compute - struct ggml_allocr * compute_alloc = NULL; + ggml_backend_buffer_t compute_buffer; // for compute + struct ggml_allocr* compute_alloc = NULL; int memory_buffer_size = 0; ggml_type wtype; ggml_backend_t backend = NULL; ESRGAN() { - for(int i = 0; i < num_blocks; i++) { + for (int i = 0; i < num_blocks; i++) { body_blocks[i] = EsrganBlock(num_features, num_grow_ch); } } size_t calculate_mem_size() { - size_t mem_size = num_features * in_channels * 3 * 3 * ggml_type_size(GGML_TYPE_F16); // conv_first_w - mem_size += num_features * ggml_type_size(GGML_TYPE_F32); // conv_first_b + size_t mem_size = num_features * in_channels * 3 * 3 * ggml_type_size(GGML_TYPE_F16); // conv_first_w + mem_size += num_features * ggml_type_size(GGML_TYPE_F32); // conv_first_b - for(int i = 0; i < num_blocks; i++) { + for (int i = 0; i < num_blocks; i++) { mem_size += body_blocks[i].calculate_mem_size(); } - mem_size += num_features * num_features * 3 * 3 * ggml_type_size(GGML_TYPE_F16); // conv_body_w - mem_size += num_features * ggml_type_size(GGML_TYPE_F32); // conv_body_w + mem_size += num_features * num_features * 3 * 3 * ggml_type_size(GGML_TYPE_F16); // conv_body_w + mem_size += num_features * ggml_type_size(GGML_TYPE_F32); // conv_body_w // upsample - mem_size += num_features * num_features * 3 * 3 * ggml_type_size(GGML_TYPE_F16); // conv_up1_w - mem_size += num_features * ggml_type_size(GGML_TYPE_F32); // conv_up1_b + mem_size += num_features * num_features * 3 * 3 * ggml_type_size(GGML_TYPE_F16); // conv_up1_w + mem_size += num_features * ggml_type_size(GGML_TYPE_F32); // conv_up1_b - mem_size += num_features * num_features * 3 * 3 * ggml_type_size(GGML_TYPE_F16); // conv_up2_w - mem_size += num_features * ggml_type_size(GGML_TYPE_F32); // conv_up2_b + mem_size += num_features * num_features * 3 * 3 * ggml_type_size(GGML_TYPE_F16); // conv_up2_w + mem_size += num_features * ggml_type_size(GGML_TYPE_F32); // conv_up2_b - mem_size += num_features * num_features * 3 * 3 * ggml_type_size(GGML_TYPE_F16); // conv_hr_w - mem_size += num_features * ggml_type_size(GGML_TYPE_F32); // conv_hr_b + mem_size += num_features * num_features * 3 * 3 * ggml_type_size(GGML_TYPE_F16); // conv_hr_w + mem_size += num_features * ggml_type_size(GGML_TYPE_F32); // conv_hr_b - mem_size += out_channels * num_features * 3 * 3 * ggml_type_size(GGML_TYPE_F16); // conv_last_w - mem_size += out_channels * ggml_type_size(GGML_TYPE_F32); // conv_last_b + mem_size += out_channels * num_features * 3 * 3 * ggml_type_size(GGML_TYPE_F16); // conv_last_w + mem_size += out_channels * ggml_type_size(GGML_TYPE_F32); // conv_last_b return mem_size; } int get_num_tensors() { int num_tensors = 12; - for(int i = 0; i < num_blocks; i++) { + for (int i = 0; i < num_blocks; i++) { num_tensors += body_blocks[i].get_num_tensors(); } return num_tensors; } bool init(ggml_backend_t backend_) { - this->backend = backend_; + this->backend = backend_; memory_buffer_size = calculate_mem_size(); - memory_buffer_size += 1024; // overhead + memory_buffer_size += 1024; // overhead int num_tensors = get_num_tensors(); LOG_DEBUG("ESRGAN params backend buffer size = % 6.2f MB (%i tensors)", memory_buffer_size / (1024.0 * 1024.0), num_tensors); struct ggml_init_params params; - params.mem_size = static_cast(num_tensors * ggml_tensor_overhead()); + params.mem_size = static_cast(num_tensors * ggml_tensor_overhead()); params.mem_buffer = NULL; - params.no_alloc = true; + params.no_alloc = true; params_buffer = ggml_backend_alloc_buffer(backend, memory_buffer_size); @@ -4301,26 +4304,26 @@ struct ESRGAN { void alloc_params() { ggml_allocr* alloc = ggml_allocr_new_from_buffer(params_buffer); - conv_first_w = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, 3, 3, in_channels, num_features); - conv_first_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, num_features); - conv_body_w = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, 3, 3, num_features, num_features); - conv_body_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, num_features); - conv_up1_w = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, 3, 3, num_features, num_features); - conv_up1_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, num_features); - conv_up2_w = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, 3, 3, num_features, num_features); - conv_up2_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, num_features); - conv_hr_w = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, 3, 3, num_features, num_features); - conv_hr_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, num_features); - conv_last_w = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, 3, 3, num_features, out_channels); - conv_last_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, out_channels); - - for(int i = 0; i < num_blocks; i++) { + conv_first_w = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, 3, 3, in_channels, num_features); + conv_first_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, num_features); + conv_body_w = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, 3, 3, num_features, num_features); + conv_body_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, num_features); + conv_up1_w = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, 3, 3, num_features, num_features); + conv_up1_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, num_features); + conv_up2_w = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, 3, 3, num_features, num_features); + conv_up2_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, num_features); + conv_hr_w = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, 3, 3, num_features, num_features); + conv_hr_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, num_features); + conv_last_w = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, 3, 3, num_features, out_channels); + conv_last_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, out_channels); + + for (int i = 0; i < num_blocks; i++) { body_blocks[i].init_params(ctx); } // alloc all tensors linked to this context - for (struct ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) { - if(t->data == NULL) { + for (struct ggml_tensor* t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) { + if (t->data == NULL) { ggml_allocr_alloc(alloc, t); } } @@ -4400,36 +4403,36 @@ struct ESRGAN { return success; } - void map_by_name(std::map & tensors) { + void map_by_name(std::map& tensors) { tensors["conv_first.weight"] = conv_first_w; - tensors["conv_first.bias"] = conv_first_b; + tensors["conv_first.bias"] = conv_first_b; - for(int i = 0; i < num_blocks; i++) { - body_blocks[i].map_by_name(tensors, "body." + std::to_string(i) +"."); + for (int i = 0; i < num_blocks; i++) { + body_blocks[i].map_by_name(tensors, "body." + std::to_string(i) + "."); } tensors["conv_body.weight"] = conv_body_w; - tensors["conv_body.bias"] = conv_body_b; + tensors["conv_body.bias"] = conv_body_b; tensors["conv_up1.weight"] = conv_up1_w; - tensors["conv_up1.bias"] = conv_up1_b; + tensors["conv_up1.bias"] = conv_up1_b; tensors["conv_up2.weight"] = conv_up2_w; - tensors["conv_up2.bias"] = conv_up2_b; - tensors["conv_hr.weight"] = conv_hr_w; - tensors["conv_hr.bias"] = conv_hr_b; + tensors["conv_up2.bias"] = conv_up2_b; + tensors["conv_hr.weight"] = conv_hr_w; + tensors["conv_hr.bias"] = conv_hr_b; tensors["conv_last.weight"] = conv_last_w; - tensors["conv_last.bias"] = conv_last_b; + tensors["conv_last.bias"] = conv_last_b; } ggml_tensor* forward(ggml_context* ctx0, ggml_tensor* out_scale, ggml_tensor* x /* feat */) { // feat = self.conv_first(feat) auto h = ggml_conv_2d(ctx0, conv_first_w, x, 1, 1, 1, 1, 1, 1); - h = ggml_add(ctx0, h, ggml_reshape_4d(ctx0, conv_first_b, 1, 1, conv_first_b->ne[0], 1)); + h = ggml_add(ctx0, h, ggml_reshape_4d(ctx0, conv_first_b, 1, 1, conv_first_b->ne[0], 1)); auto body_h = h; // self.body(feat) - for(int i = 0; i < num_blocks; i++) { + for (int i = 0; i < num_blocks; i++) { body_h = body_blocks[i].forward(ctx0, out_scale, body_h); } @@ -4464,37 +4467,37 @@ struct ESRGAN { return h; } - struct ggml_cgraph * build_graph(struct ggml_tensor* x) { + struct ggml_cgraph* build_graph(struct ggml_tensor* x) { // since we are using ggml-alloc, this buffer only needs enough space to hold the ggml_tensor and ggml_cgraph structs, but not the tensor data static size_t buf_size = ggml_tensor_overhead() * GGML_DEFAULT_GRAPH_SIZE + ggml_graph_overhead(); static std::vector buf(buf_size); struct ggml_init_params params = { - /*.mem_size =*/ buf_size, - /*.mem_buffer =*/ buf.data(), - /*.no_alloc =*/ true, // the tensors will be allocated later by ggml_allocr_alloc_graph() + /*.mem_size =*/buf_size, + /*.mem_buffer =*/buf.data(), + /*.no_alloc =*/true, // the tensors will be allocated later by ggml_allocr_alloc_graph() }; - struct ggml_context * ctx0 = ggml_init(params); + struct ggml_context* ctx0 = ggml_init(params); - struct ggml_cgraph * gf = ggml_new_graph(ctx0); + struct ggml_cgraph* gf = ggml_new_graph(ctx0); struct ggml_tensor* x_ = NULL; struct ggml_tensor* os = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1); ggml_allocr_alloc(compute_alloc, os); - if(!ggml_allocr_is_measure(compute_alloc)) { + if (!ggml_allocr_is_measure(compute_alloc)) { float scale = 0.2f; ggml_backend_tensor_set(os, &scale, 0, sizeof(scale)); } - // it's performing a compute, check if backend isn't cpu - if(!ggml_backend_is_cpu(backend)) { + // it's performing a compute, check if backend isn't cpu + if (!ggml_backend_is_cpu(backend)) { // pass input tensors to gpu memory x_ = ggml_dup_tensor(ctx0, x); ggml_allocr_alloc(compute_alloc, x_); // pass data to device backend - if(!ggml_allocr_is_measure(compute_alloc)) { + if (!ggml_allocr_is_measure(compute_alloc)) { ggml_backend_tensor_set(x_, x->data, 0, ggml_nbytes(x)); } } else { @@ -4514,7 +4517,7 @@ struct ESRGAN { // alignment required by the backend compute_alloc = ggml_allocr_new_measure_from_backend(backend); - struct ggml_cgraph * gf = build_graph(x); + struct ggml_cgraph* gf = build_graph(x); // compute the required memory size_t compute_memory_buffer_size = ggml_allocr_alloc_graph(compute_alloc, gf); @@ -4525,13 +4528,13 @@ struct ESRGAN { LOG_DEBUG("ESRGAN compute buffer size: %.2f MB", compute_memory_buffer_size / 1024.0 / 1024.0); compute_buffer = ggml_backend_alloc_buffer(backend, compute_memory_buffer_size); - compute_alloc = ggml_allocr_new_from_buffer(compute_buffer); + compute_alloc = ggml_allocr_new_from_buffer(compute_buffer); } void compute(struct ggml_tensor* work_result, const int n_threads, struct ggml_tensor* x) { ggml_allocr_reset(compute_alloc); - struct ggml_cgraph * gf = build_graph(x); + struct ggml_cgraph* gf = build_graph(x); ggml_allocr_alloc_graph(compute_alloc, gf); if (ggml_backend_is_cpu(backend)) { @@ -4560,8 +4563,6 @@ struct ESRGAN { } }; - - float ggml_backend_tensor_get_f32(ggml_tensor* tensor) { GGML_ASSERT(tensor->type == GGML_TYPE_F32 || tensor->type == GGML_TYPE_F16); float value; @@ -4919,7 +4920,7 @@ class StableDiffusionGGML { UNetModel diffusion_model; AutoEncoderKL first_stage_model; bool use_tiny_autoencoder = false; - bool vae_tiling = false; + bool vae_tiling = false; std::map tensors; @@ -4971,7 +4972,8 @@ class StableDiffusionGGML { bool load_from_file(const std::string& model_path, const std::string& vae_path, ggml_type wtype, - Schedule schedule, int clip_skip_layers) { + Schedule schedule, + int clip_skip_layers) { #ifdef SD_USE_CUBLAS LOG_DEBUG("Using CUDA backend"); backend = ggml_backend_cuda_init(0); @@ -5013,9 +5015,9 @@ class StableDiffusionGGML { LOG_ERROR("get sd version from file failed: '%s'", model_path.c_str()); return false; } - cond_stage_model = FrozenCLIPEmbedderWithCustomWords(version); + cond_stage_model = FrozenCLIPEmbedderWithCustomWords(version); cond_stage_model.text_model.skip_layers = clip_skip_layers; - diffusion_model = UNetModel(version); + diffusion_model = UNetModel(version); LOG_INFO("Stable Diffusion %s ", model_version_to_str[version]); if (wtype == GGML_TYPE_COUNT) { model_data_type = model_loader.get_sd_wtype(); @@ -5221,8 +5223,8 @@ class StableDiffusionGGML { } LOG_DEBUG("finished loaded file"); ggml_free(ctx); - if(upscale_output) { - if(!esrgan_upscaler.load_from_file(esrgan_path, backend)) { + if (upscale_output) { + if (!esrgan_upscaler.load_from_file(esrgan_path, backend)) { return false; } } @@ -5886,10 +5888,10 @@ class StableDiffusionGGML { } else { ggml_tensor_scale_input(x); } - if(vae_tiling && decode) { // TODO: support tiling vae encode + if (vae_tiling && decode) { // TODO: support tiling vae encode // split latent in 32x32 tiles and compute in several steps auto on_tiling = [&](ggml_tensor* in, ggml_tensor* out, bool init) { - if(init) { + if (init) { first_stage_model.begin(in, decode); } else { first_stage_model.compute(out, n_threads, in, decode); @@ -5905,10 +5907,10 @@ class StableDiffusionGGML { ggml_tensor_scale_output(result); } } else { - if(vae_tiling && decode) { // TODO: support tiling vae encode + if (vae_tiling && decode) { // TODO: support tiling vae encode // split latent in 64x64 tiles and compute in several steps auto on_tiling = [&](ggml_tensor* in, ggml_tensor* out, bool init) { - if(init) { + if (init) { tae_first_stage.begin(in, decode); } else { tae_first_stage.compute(out, n_threads, in, decode); @@ -5930,11 +5932,11 @@ class StableDiffusionGGML { } uint8_t* upscale(ggml_tensor* image) { - int output_width = image->ne[0] * esrgan_upscaler.scale; + int output_width = image->ne[0] * esrgan_upscaler.scale; int output_height = image->ne[1] * esrgan_upscaler.scale; LOG_INFO("upscaling from (%i x %i) to (%i x %i)", image->ne[0], image->ne[1], output_width, output_height); struct ggml_init_params params; - params.mem_size = output_width * output_height * 3 * sizeof(float); // upscaled + params.mem_size = output_width * output_height * 3 * sizeof(float); // upscaled params.mem_size += 1 * ggml_tensor_overhead(); params.mem_buffer = NULL; params.no_alloc = false; @@ -5946,8 +5948,8 @@ class StableDiffusionGGML { } LOG_DEBUG("upscale work buffer size: %.2f MB", params.mem_size / 1024.f / 1024.f); ggml_tensor* upscaled = ggml_new_tensor_4d(upscale_ctx, GGML_TYPE_F32, output_width, output_height, image->ne[2], 1); - auto on_tiling = [&](ggml_tensor* in, ggml_tensor* out, bool init) { - if(init) { + auto on_tiling = [&](ggml_tensor* in, ggml_tensor* out, bool init) { + if (init) { esrgan_upscaler.begin(in); } else { esrgan_upscaler.compute(out, n_threads, in); @@ -6097,7 +6099,7 @@ std::vector StableDiffusion::txt2img(std::string prompt, LOG_INFO("generating %" PRId64 " latent images completed, taking %.2fs", final_latents.size(), (t3 - t1) * 1.0f / 1000); LOG_INFO("decoding %zu latents", final_latents.size()); - std::vector decoded_images; // collect decoded images + std::vector decoded_images; // collect decoded images for (size_t i = 0; i < final_latents.size(); i++) { t1 = ggml_time_ms(); struct ggml_tensor* img = sd->decode_first_stage(work_ctx, final_latents[i] /* x_0 */); @@ -6113,11 +6115,11 @@ std::vector StableDiffusion::txt2img(std::string prompt, if (sd->free_params_immediately && !sd->use_tiny_autoencoder) { sd->first_stage_model.destroy(); } - if(sd->upscale_output) { + if (sd->upscale_output) { LOG_INFO("upscaling %" PRId64 " images", decoded_images.size()); } for (size_t i = 0; i < decoded_images.size(); i++) { - if(sd->upscale_output) { + if (sd->upscale_output) { results.push_back(sd->upscale(decoded_images[i])); } else { results.push_back(sd_tensor_to_image(decoded_images[i])); @@ -6229,7 +6231,7 @@ std::vector StableDiffusion::img2img(const uint8_t* init_img_data, struct ggml_tensor* img = sd->decode_first_stage(work_ctx, x_0); if (img != NULL) { - if(sd->upscale_output) { + if (sd->upscale_output) { result.push_back(sd->upscale(img)); } else { result.push_back(sd_tensor_to_image(img)); From 62dd027814391e6fa85de1cc32f14e336de586ba Mon Sep 17 00:00:00 2001 From: leejet Date: Fri, 15 Dec 2023 00:03:33 +0800 Subject: [PATCH 19/24] simplify esrgan code --- stable-diffusion.cpp | 36 ++++++++++++------------------------ 1 file changed, 12 insertions(+), 24 deletions(-) diff --git a/stable-diffusion.cpp b/stable-diffusion.cpp index 487df05c9..55ac85afa 100644 --- a/stable-diffusion.cpp +++ b/stable-diffusion.cpp @@ -4117,32 +4117,27 @@ struct ResidualDenseBlock { ggml_tensor* forward(ggml_context* ctx, ggml_tensor* out_scale, ggml_tensor* x /* feat */) { // x1 = self.lrelu(self.conv1(x)) - ggml_tensor* x1 = ggml_conv_2d(ctx, conv1_w, x, 1, 1, 1, 1, 1, 1); - x1 = ggml_add(ctx, x1, ggml_reshape_4d(ctx, conv1_b, 1, 1, conv1_b->ne[0], 1)); + ggml_tensor* x1 = ggml_nn_conv_2d(ctx, x, conv1_w, conv1_b, 1, 1, 1, 1); x1 = ggml_leaky_relu(ctx, x1, 0.2f, true); // x2 = self.lrelu(self.conv2(torch.cat((x, x1), 1))) ggml_tensor* x_cat = ggml_concat(ctx, x, x1); - ggml_tensor* x2 = ggml_conv_2d(ctx, conv2_w, x_cat, 1, 1, 1, 1, 1, 1); - x2 = ggml_add(ctx, x2, ggml_reshape_4d(ctx, conv2_b, 1, 1, conv2_b->ne[0], 1)); + ggml_tensor* x2 = ggml_nn_conv_2d(ctx, x_cat, conv2_w, conv2_b, 1, 1, 1, 1); x2 = ggml_leaky_relu(ctx, x2, 0.2f, true); // x3 = self.lrelu(self.conv3(torch.cat((x, x1, x2), 1))) x_cat = ggml_concat(ctx, x_cat, x2); - ggml_tensor* x3 = ggml_conv_2d(ctx, conv3_w, x_cat, 1, 1, 1, 1, 1, 1); - x3 = ggml_add(ctx, x3, ggml_reshape_4d(ctx, conv3_b, 1, 1, conv3_b->ne[0], 1)); + ggml_tensor* x3 = ggml_nn_conv_2d(ctx, x_cat, conv3_w, conv3_b, 1, 1, 1, 1); x3 = ggml_leaky_relu(ctx, x3, 0.2f, true); // x4 = self.lrelu(self.conv4(torch.cat((x, x1, x2, x3), 1))) x_cat = ggml_concat(ctx, x_cat, x3); - ggml_tensor* x4 = ggml_conv_2d(ctx, conv4_w, x_cat, 1, 1, 1, 1, 1, 1); - x4 = ggml_add(ctx, x4, ggml_reshape_4d(ctx, conv4_b, 1, 1, conv4_b->ne[0], 1)); + ggml_tensor* x4 = ggml_nn_conv_2d(ctx, x_cat, conv4_w, conv4_b, 1, 1, 1, 1); x4 = ggml_leaky_relu(ctx, x4, 0.2f, true); // self.conv5(torch.cat((x, x1, x2, x3, x4), 1)) x_cat = ggml_concat(ctx, x_cat, x4); - ggml_tensor* x5 = ggml_conv_2d(ctx, conv5_w, x_cat, 1, 1, 1, 1, 1, 1); - x5 = ggml_add(ctx, x5, ggml_reshape_4d(ctx, conv5_b, 1, 1, conv5_b->ne[0], 1)); + ggml_tensor* x5 = ggml_nn_conv_2d(ctx, x_cat, conv5_w, conv5_b, 1, 1, 1, 1); // return x5 * 0.2 + x x5 = ggml_add(ctx, ggml_scale(ctx, x5, out_scale), x); @@ -4427,8 +4422,7 @@ struct ESRGAN { ggml_tensor* forward(ggml_context* ctx0, ggml_tensor* out_scale, ggml_tensor* x /* feat */) { // feat = self.conv_first(feat) - auto h = ggml_conv_2d(ctx0, conv_first_w, x, 1, 1, 1, 1, 1, 1); - h = ggml_add(ctx0, h, ggml_reshape_4d(ctx0, conv_first_b, 1, 1, conv_first_b->ne[0], 1)); + auto h = ggml_nn_conv_2d(ctx0, x, conv_first_w, conv_first_b, 1, 1, 1, 1); auto body_h = h; // self.body(feat) @@ -4437,8 +4431,7 @@ struct ESRGAN { } // body_feat = self.conv_body(self.body(feat)) - body_h = ggml_conv_2d(ctx0, conv_body_w, body_h, 1, 1, 1, 1, 1, 1); - body_h = ggml_add(ctx0, body_h, ggml_reshape_4d(ctx0, conv_body_b, 1, 1, conv_body_b->ne[0], 1)); + body_h = ggml_nn_conv_2d(ctx0, body_h, conv_body_w, conv_body_b, 1, 1, 1, 1); // feat = feat + body_feat h = ggml_add(ctx0, h, body_h); @@ -4446,24 +4439,19 @@ struct ESRGAN { // upsample // feat = self.lrelu(self.conv_up1(F.interpolate(feat, scale_factor=2, mode='nearest'))) h = ggml_upscale(ctx0, h, 2); - h = ggml_conv_2d(ctx0, conv_up1_w, h, 1, 1, 1, 1, 1, 1); - h = ggml_add(ctx0, h, ggml_reshape_4d(ctx0, conv_up1_b, 1, 1, conv_up1_b->ne[0], 1)); + h = ggml_nn_conv_2d(ctx0, h, conv_up1_w, conv_up1_b, 1, 1, 1, 1); h = ggml_leaky_relu(ctx0, h, 0.2f, true); // feat = self.lrelu(self.conv_up2(F.interpolate(feat, scale_factor=2, mode='nearest'))) h = ggml_upscale(ctx0, h, 2); - h = ggml_conv_2d(ctx0, conv_up2_w, h, 1, 1, 1, 1, 1, 1); - h = ggml_add(ctx0, h, ggml_reshape_4d(ctx0, conv_up2_b, 1, 1, conv_up2_b->ne[0], 1)); + h = ggml_nn_conv_2d(ctx0, h, conv_up2_w, conv_up2_b, 1, 1, 1, 1); h = ggml_leaky_relu(ctx0, h, 0.2f, true); - // self.lrelu(self.conv_hr(feat)) - h = ggml_conv_2d(ctx0, conv_hr_w, h, 1, 1, 1, 1, 1, 1); - h = ggml_add(ctx0, h, ggml_reshape_4d(ctx0, conv_hr_b, 1, 1, conv_hr_b->ne[0], 1)); + // out = self.conv_last(self.lrelu(self.conv_hr(feat))) + h = ggml_nn_conv_2d(ctx0, h, conv_hr_w, conv_hr_b, 1, 1, 1, 1); h = ggml_leaky_relu(ctx0, h, 0.2f, true); - // out = self.conv_last(self.lrelu(self.conv_hr(feat))) - h = ggml_conv_2d(ctx0, conv_last_w, h, 1, 1, 1, 1, 1, 1); - h = ggml_add(ctx0, h, ggml_reshape_4d(ctx0, conv_last_b, 1, 1, conv_last_b->ne[0], 1)); + h = ggml_nn_conv_2d(ctx0, h, conv_last_w, conv_last_b, 1, 1, 1, 1); return h; } From ccdec9dcb11721a53ed50a177734c15d0ec1d60b Mon Sep 17 00:00:00 2001 From: FSSRepo Date: Fri, 15 Dec 2023 11:05:08 -0500 Subject: [PATCH 20/24] indeterministic results fast softmax --- CMakeLists.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index c58db32cd..95d59d613 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -27,7 +27,7 @@ option(SD_BUILD_EXAMPLES "sd: build examples" ${SD_STANDALONE}) option(SD_CUBLAS "sd: cuda backend" OFF) option(SD_METAL "sd: metal backend" OFF) option(SD_FLASH_ATTN "sd: use flash attention for x4 less memory usage" OFF) -option(SD_FAST_SOFTMAX "sd: x1.5 faster softmax, slight variation in the results, cuda only" OFF) +option(SD_FAST_SOFTMAX "sd: x1.5 faster softmax, indeterministic (sometimes, same seed don't generate same image), cuda only" OFF) option(BUILD_SHARED_LIBS "sd: build shared libs" OFF) #option(SD_BUILD_SERVER "sd: build server example" ON) From 1e3797af210ad3625761523834536b3de7626de6 Mon Sep 17 00:00:00 2001 From: leejet Date: Mon, 18 Dec 2023 23:11:24 +0800 Subject: [PATCH 21/24] fix clip_skip --- examples/cli/main.cpp | 19 ++++++++++--------- stable-diffusion.cpp | 23 +++++++++++++++-------- stable-diffusion.h | 2 +- 3 files changed, 26 insertions(+), 18 deletions(-) diff --git a/examples/cli/main.cpp b/examples/cli/main.cpp index b5c7ec31b..6264d6e2f 100644 --- a/examples/cli/main.cpp +++ b/examples/cli/main.cpp @@ -67,11 +67,11 @@ struct SDParams { std::string prompt; std::string negative_prompt; - float cfg_scale = 7.0f; - int clip_skip_layers = 0; - int width = 512; - int height = 512; - int batch_count = 1; + float cfg_scale = 7.0f; + int clip_skip = -1; // <= 0 represents unspecified + int width = 512; + int height = 512; + int batch_count = 1; SampleMethod sample_method = EULER_A; Schedule schedule = DEFAULT; @@ -97,7 +97,7 @@ void print_params(SDParams params) { printf(" prompt: %s\n", params.prompt.c_str()); printf(" negative_prompt: %s\n", params.negative_prompt.c_str()); printf(" cfg_scale: %.2f\n", params.cfg_scale); - printf(" clip_skip_layers: %d\n", params.clip_skip_layers); + printf(" clip_skip: %d\n", params.clip_skip); printf(" width: %d\n", params.width); printf(" height: %d\n", params.height); printf(" sample_method: %s\n", sample_method_str[params.sample_method]); @@ -141,7 +141,8 @@ void print_usage(int argc, const char* argv[]) { printf(" -s SEED, --seed SEED RNG seed (default: 42, use random seed for < 0)\n"); printf(" -b, --batch-count COUNT number of images to generate.\n"); printf(" --schedule {discrete, karras} Denoiser sigma schedule (default: discrete)\n"); - printf(" --clip-skip N number of layers to skip of clip model (default: 0)\n"); + printf(" --clip-skip N ignore last layers of CLIP network; 1 ignores none, 2 ignores one layer (default: -1)\n"); + printf(" <= 0 represents unspecified, will be 1 for SD1.x, 2 for SD2.x\n"); printf(" --vae-tiling process vae in tiles to reduce memory usage\n"); printf(" -v, --verbose print extra info\n"); } @@ -290,7 +291,7 @@ void parse_args(int argc, const char** argv, SDParams& params) { invalid_arg = true; break; } - params.clip_skip_layers = std::stoi(argv[i]); + params.clip_skip = std::stoi(argv[i]); } else if (arg == "--vae-tiling") { params.vae_tiling = true; } else if (arg == "-b" || arg == "--batch-count") { @@ -483,7 +484,7 @@ int main(int argc, const char* argv[]) { StableDiffusion sd(params.n_threads, vae_decode_only, params.taesd_path, params.esrgan_path, true, params.vae_tiling, params.lora_model_dir, params.rng_type); - if (!sd.load_from_file(params.model_path, params.vae_path, params.wtype, params.schedule, params.clip_skip_layers)) { + if (!sd.load_from_file(params.model_path, params.vae_path, params.wtype, params.schedule, params.clip_skip)) { return 1; } diff --git a/stable-diffusion.cpp b/stable-diffusion.cpp index 55ac85afa..dc33eab14 100644 --- a/stable-diffusion.cpp +++ b/stable-diffusion.cpp @@ -1003,7 +1003,7 @@ struct CLIPTextModel { int32_t intermediate_size = 3072; // 4096 for SD 2.x int32_t n_head = 12; // num_attention_heads, 16 for SD 2.x int32_t num_hidden_layers = 12; // 24 for SD 2.x - int32_t skip_layers = 0; + int32_t clip_skip = 1; // embeddings struct ggml_tensor* position_ids; @@ -1162,8 +1162,9 @@ struct CLIPTextModel { ggml_view_1d(ctx0, position_ids, input_ids->ne[0], 0))); // [N, n_token, hidden_size] // transformer + int layer_idx = num_hidden_layers - clip_skip; for (int i = 0; i < num_hidden_layers; i++) { - if (version == VERSION_2_x && i == num_hidden_layers - 1 || i > (num_hidden_layers - skip_layers - 1)) { // layer: "penultimate" + if (i == layer_idx + 1) { break; } x = resblocks[i].forward(ctx0, x); // [N, n_token, hidden_size] @@ -4961,7 +4962,7 @@ class StableDiffusionGGML { const std::string& vae_path, ggml_type wtype, Schedule schedule, - int clip_skip_layers) { + int clip_skip) { #ifdef SD_USE_CUBLAS LOG_DEBUG("Using CUDA backend"); backend = ggml_backend_cuda_init(0); @@ -5003,9 +5004,15 @@ class StableDiffusionGGML { LOG_ERROR("get sd version from file failed: '%s'", model_path.c_str()); return false; } - cond_stage_model = FrozenCLIPEmbedderWithCustomWords(version); - cond_stage_model.text_model.skip_layers = clip_skip_layers; - diffusion_model = UNetModel(version); + if (clip_skip <= 0) { + clip_skip = 1; + if (version == VERSION_2_x) { + clip_skip = 2; + } + } + cond_stage_model = FrozenCLIPEmbedderWithCustomWords(version); + cond_stage_model.text_model.clip_skip = clip_skip; + diffusion_model = UNetModel(version); LOG_INFO("Stable Diffusion %s ", model_version_to_str[version]); if (wtype == GGML_TYPE_COUNT) { model_data_type = model_loader.get_sd_wtype(); @@ -5989,8 +5996,8 @@ bool StableDiffusion::load_from_file(const std::string& model_path, const std::string& vae_path, ggml_type wtype, Schedule s, - int clip_skip_layers) { - return sd->load_from_file(model_path, vae_path, wtype, s, clip_skip_layers); + int clip_skip) { + return sd->load_from_file(model_path, vae_path, wtype, s, clip_skip); } std::vector StableDiffusion::txt2img(std::string prompt, diff --git a/stable-diffusion.h b/stable-diffusion.h index ac2b60b90..13373da6d 100644 --- a/stable-diffusion.h +++ b/stable-diffusion.h @@ -49,7 +49,7 @@ class StableDiffusion { const std::string& vae_path, ggml_type wtype, Schedule d = DEFAULT, - int clip_skip_layers = 0); + int clip_skip = -1); std::vector txt2img( std::string prompt, std::string negative_prompt, From 6dce7ea49135236707969935d64402b36d8a5653 Mon Sep 17 00:00:00 2001 From: FSSRepo Date: Mon, 18 Dec 2023 12:21:48 -0500 Subject: [PATCH 22/24] fix cuda sync buffers --- ggml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ggml b/ggml index 11dc0eed5..a0c2ec77a 160000 --- a/ggml +++ b/ggml @@ -1 +1 @@ -Subproject commit 11dc0eed5cc86fb0bf0c8c72f3817f50edb27659 +Subproject commit a0c2ec77a5ef8e630aff65bc535d13b9805cb929 From 9ea2bcd231ea3524f60b47631e3964adfdc280b7 Mon Sep 17 00:00:00 2001 From: leejet Date: Sat, 23 Dec 2023 13:03:02 +0800 Subject: [PATCH 23/24] synchronize after get tensor from backend --- stable-diffusion.cpp | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/stable-diffusion.cpp b/stable-diffusion.cpp index 5782bafa9..99b4e1d7f 100644 --- a/stable-diffusion.cpp +++ b/stable-diffusion.cpp @@ -605,6 +605,11 @@ std::pair, std::string> extract_and_remov return std::make_pair(filename2multiplier, text); } +void ggml_backend_tensor_get_and_sync(ggml_backend_t backend, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size) { + ggml_backend_tensor_get(tensor, data, offset, size); + ggml_backend_synchronize(backend); +} + /*================================================== CLIPTokenizer ===================================================*/ const std::string UNK_TOKEN = "<|endoftext|>"; @@ -1410,7 +1415,7 @@ struct CLIPTextModel { #ifdef GGML_PERF ggml_graph_print(gf); #endif - ggml_backend_tensor_get(gf->nodes[gf->n_nodes - 1], work_output->data, 0, ggml_nbytes(work_output)); + ggml_backend_tensor_get_and_sync(backend, gf->nodes[gf->n_nodes - 1], work_output->data, 0, ggml_nbytes(work_output)); return work_output; } @@ -2688,7 +2693,7 @@ struct UNetModel { ggml_graph_print(gf); #endif - ggml_backend_tensor_get(gf->nodes[gf->n_nodes - 1], work_latent->data, 0, ggml_nbytes(work_latent)); + ggml_backend_tensor_get_and_sync(backend, gf->nodes[gf->n_nodes - 1], work_latent->data, 0, ggml_nbytes(work_latent)); } void end() { @@ -3499,7 +3504,7 @@ struct AutoEncoderKL { ggml_graph_print(gf); #endif - ggml_backend_tensor_get(gf->nodes[gf->n_nodes - 1], work_result->data, 0, ggml_nbytes(work_result)); + ggml_backend_tensor_get_and_sync(backend, gf->nodes[gf->n_nodes - 1], work_result->data, 0, ggml_nbytes(work_result)); } void end() { @@ -4182,7 +4187,7 @@ struct TinyAutoEncoder { ggml_graph_print(gf); #endif - ggml_backend_tensor_get(gf->nodes[gf->n_nodes - 1], work_result->data, 0, ggml_nbytes(work_result)); + ggml_backend_tensor_get_and_sync(backend, gf->nodes[gf->n_nodes - 1], work_result->data, 0, ggml_nbytes(work_result)); } void end() { @@ -4705,7 +4710,7 @@ struct ESRGAN { ggml_graph_print(gf); #endif ggml_tensor* out = gf->nodes[gf->n_nodes - 1]; - ggml_backend_tensor_get(out, work_result->data, 0, ggml_nbytes(out)); + ggml_backend_tensor_get_and_sync(backend, out, work_result->data, 0, ggml_nbytes(out)); } void end() { From 421e39b0f0aa30dcc35151f69028a7b224943196 Mon Sep 17 00:00:00 2001 From: leejet Date: Sat, 23 Dec 2023 14:38:04 +0800 Subject: [PATCH 24/24] use ggml_backend_tensor_get_async and sync for cuda backend --- stable-diffusion.cpp | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/stable-diffusion.cpp b/stable-diffusion.cpp index 99b4e1d7f..70cd79a7b 100644 --- a/stable-diffusion.cpp +++ b/stable-diffusion.cpp @@ -606,8 +606,12 @@ std::pair, std::string> extract_and_remov } void ggml_backend_tensor_get_and_sync(ggml_backend_t backend, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size) { - ggml_backend_tensor_get(tensor, data, offset, size); - ggml_backend_synchronize(backend); + #ifdef SD_USE_CUBLAS + ggml_backend_tensor_get_async(backend, tensor, data, offset, size); + ggml_backend_synchronize(backend); + #else + ggml_backend_tensor_get(tensor, data, offset, size); + #endif } /*================================================== CLIPTokenizer ===================================================*/