From 652f58f0429dd1c0171e5bf20579809d83901994 Mon Sep 17 00:00:00 2001 From: Green Sky Date: Sun, 1 Sep 2024 14:52:19 +0200 Subject: [PATCH 1/5] repair flash attention in _ext this does not fix the currently broken fa behind the define, which is only used by VAE Co-authored-by: FSSRepo --- ggml_extend.hpp | 34 +++++++++++++++++++++++++++++----- 1 file changed, 29 insertions(+), 5 deletions(-) diff --git a/ggml_extend.hpp b/ggml_extend.hpp index e50137d5e..fb81c3a62 100644 --- a/ggml_extend.hpp +++ b/ggml_extend.hpp @@ -734,13 +734,35 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_nn_attention_ext(struct ggml_context* float scale = (1.0f / sqrt((float)d_head)); - bool use_flash_attn = false; - ggml_tensor* kqv = NULL; + LOG_DEBUG("attention_ext L_k:%d n_head:%d C:%d d_head:%d", L_k, n_head, C, d_head); + + bool use_flash_attn = true; + // L_k == n_context AND l_k == n_token ???? + use_flash_attn = use_flash_attn && L_k % 256 == 0; + use_flash_attn = use_flash_attn && d_head % 64 == 0; // why + + if (mask != nullptr) { + // TODO: figure out if we can bend t5 to work too + use_flash_attn = use_flash_attn && mask->ne[2] == 1; + use_flash_attn = use_flash_attn && mask->ne[3] == 1; + } + + // TODO: more pad or disable for funny tensor shapes + + ggml_tensor* kqv = nullptr; if (use_flash_attn) { + LOG_DEBUG("using flash attention"); + + k = ggml_cast(ctx, k, GGML_TYPE_F16); + v = ggml_cont(ctx, ggml_permute(ctx, v, 0, 2, 1, 3)); // [N, n_head, L_k, d_head] v = ggml_reshape_3d(ctx, v, d_head, L_k, n_head * N); // [N * n_head, L_k, d_head] - LOG_DEBUG("k->ne[1] == %d", k->ne[1]); + v = ggml_cast(ctx, v, GGML_TYPE_F16); + kqv = ggml_flash_attn_ext(ctx, q, k, v, mask, scale, 0, 0); + ggml_flash_attn_ext_set_prec(kqv, GGML_PREC_F32); + + kqv = ggml_view_3d(ctx, kqv, d_head, n_head, L_k, kqv->nb[1], kqv->nb[2], 0); } else { v = ggml_cont(ctx, ggml_permute(ctx, v, 1, 2, 0, 3)); // [N, n_head, d_head, L_k] v = ggml_reshape_3d(ctx, v, L_k, d_head, n_head * N); // [N * n_head, d_head, L_k] @@ -756,10 +778,12 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_nn_attention_ext(struct ggml_context* kq = ggml_soft_max_inplace(ctx, kq); kqv = ggml_mul_mat(ctx, v, kq); // [N * n_head, L_q, d_head] + + kqv = ggml_reshape_4d(ctx, kqv, d_head, L_q, n_head, N); // [N, n_head, L_q, d_head] + kqv = ggml_permute(ctx, kqv, 0, 2, 1, 3); // [N, L_q, n_head, d_head] } - kqv = ggml_reshape_4d(ctx, kqv, d_head, L_q, n_head, N); // [N, n_head, L_q, d_head] - kqv = ggml_cont(ctx, ggml_permute(ctx, kqv, 0, 2, 1, 3)); // [N, L_q, n_head, d_head] + kqv = ggml_cont(ctx, kqv); kqv = ggml_reshape_3d(ctx, kqv, d_head * n_head, L_q, N); // [N, L_q, C] return kqv; From cc65d7e1bcf8c5efe9baba45b81c08a15468daf1 Mon Sep 17 00:00:00 2001 From: Green Sky Date: Sat, 7 Sep 2024 11:32:46 +0200 Subject: [PATCH 2/5] make flash attention in the diffusion model a runtime flag no support for sd3 or video --- common.hpp | 23 +++++++++++++--------- diffusion_model.hpp | 12 +++++++----- examples/cli/main.cpp | 10 +++++++++- flux.hpp | 44 ++++++++++++++++++++++++++++--------------- ggml_extend.hpp | 28 ++++++++++++++++----------- stable-diffusion.cpp | 19 ++++++++++++++----- stable-diffusion.h | 3 ++- unet.hpp | 11 ++++++----- 8 files changed, 98 insertions(+), 52 deletions(-) diff --git a/common.hpp b/common.hpp index b18ee51f5..784bab32b 100644 --- a/common.hpp +++ b/common.hpp @@ -245,16 +245,19 @@ class CrossAttention : public GGMLBlock { int64_t context_dim; int64_t n_head; int64_t d_head; + bool flash_attn; public: CrossAttention(int64_t query_dim, 
int64_t context_dim, int64_t n_head, - int64_t d_head) + int64_t d_head, + bool flash_attn = false) : n_head(n_head), d_head(d_head), query_dim(query_dim), - context_dim(context_dim) { + context_dim(context_dim), + flash_attn(flash_attn) { int64_t inner_dim = d_head * n_head; blocks["to_q"] = std::shared_ptr(new Linear(query_dim, inner_dim, false)); @@ -283,7 +286,7 @@ class CrossAttention : public GGMLBlock { auto k = to_k->forward(ctx, context); // [N, n_context, inner_dim] auto v = to_v->forward(ctx, context); // [N, n_context, inner_dim] - x = ggml_nn_attention_ext(ctx, q, k, v, n_head, NULL, false); // [N, n_token, inner_dim] + x = ggml_nn_attention_ext(ctx, q, k, v, n_head, NULL, false, false, flash_attn); // [N, n_token, inner_dim] x = to_out_0->forward(ctx, x); // [N, n_token, query_dim] return x; @@ -301,15 +304,16 @@ class BasicTransformerBlock : public GGMLBlock { int64_t n_head, int64_t d_head, int64_t context_dim, - bool ff_in = false) + bool ff_in = false, + bool flash_attn = false) : n_head(n_head), d_head(d_head), ff_in(ff_in) { // disable_self_attn is always False // disable_temporal_crossattention is always False // switch_temporal_ca_to_sa is always False // inner_dim is always None or equal to dim // gated_ff is always True - blocks["attn1"] = std::shared_ptr(new CrossAttention(dim, dim, n_head, d_head)); - blocks["attn2"] = std::shared_ptr(new CrossAttention(dim, context_dim, n_head, d_head)); + blocks["attn1"] = std::shared_ptr(new CrossAttention(dim, dim, n_head, d_head, flash_attn)); + blocks["attn2"] = std::shared_ptr(new CrossAttention(dim, context_dim, n_head, d_head, flash_attn)); blocks["ff"] = std::shared_ptr(new FeedForward(dim, dim)); blocks["norm1"] = std::shared_ptr(new LayerNorm(dim)); blocks["norm2"] = std::shared_ptr(new LayerNorm(dim)); @@ -374,7 +378,8 @@ class SpatialTransformer : public GGMLBlock { int64_t n_head, int64_t d_head, int64_t depth, - int64_t context_dim) + int64_t context_dim, + bool flash_attn = false) : in_channels(in_channels), n_head(n_head), d_head(d_head), @@ -388,7 +393,7 @@ class SpatialTransformer : public GGMLBlock { for (int i = 0; i < depth; i++) { std::string name = "transformer_blocks." 
+ std::to_string(i); - blocks[name] = std::shared_ptr(new BasicTransformerBlock(inner_dim, n_head, d_head, context_dim)); + blocks[name] = std::shared_ptr(new BasicTransformerBlock(inner_dim, n_head, d_head, context_dim, false, flash_attn)); } blocks["proj_out"] = std::shared_ptr(new Conv2d(inner_dim, in_channels, {1, 1})); @@ -511,4 +516,4 @@ class VideoResBlock : public ResBlock { } }; -#endif // __COMMON_HPP__ \ No newline at end of file +#endif // __COMMON_HPP__ diff --git a/diffusion_model.hpp b/diffusion_model.hpp index 2530f7149..7bada4596 100644 --- a/diffusion_model.hpp +++ b/diffusion_model.hpp @@ -31,8 +31,9 @@ struct UNetModel : public DiffusionModel { UNetModel(ggml_backend_t backend, ggml_type wtype, - SDVersion version = VERSION_SD1) - : unet(backend, wtype, version) { + SDVersion version = VERSION_SD1, + bool flash_attn = false) + : unet(backend, wtype, version, flash_attn) { } void alloc_params_buffer() { @@ -129,8 +130,9 @@ struct FluxModel : public DiffusionModel { FluxModel(ggml_backend_t backend, ggml_type wtype, - SDVersion version = VERSION_FLUX_DEV) - : flux(backend, wtype, version) { + SDVersion version = VERSION_FLUX_DEV, + bool flash_attn = false) + : flux(backend, wtype, version, flash_attn) { } void alloc_params_buffer() { @@ -173,4 +175,4 @@ struct FluxModel : public DiffusionModel { } }; -#endif \ No newline at end of file +#endif diff --git a/examples/cli/main.cpp b/examples/cli/main.cpp index f1bdc698b..2bdd88112 100644 --- a/examples/cli/main.cpp +++ b/examples/cli/main.cpp @@ -116,6 +116,7 @@ struct SDParams { bool normalize_input = false; bool clip_on_cpu = false; bool vae_on_cpu = false; + bool diffusion_flash_attn = false; bool canny_preprocess = false; bool color = false; int upscale_repeats = 1; @@ -146,6 +147,7 @@ void print_params(SDParams params) { printf(" clip on cpu: %s\n", params.clip_on_cpu ? "true" : "false"); printf(" controlnet cpu: %s\n", params.control_net_cpu ? "true" : "false"); printf(" vae decoder on cpu:%s\n", params.vae_on_cpu ? "true" : "false"); + printf(" diffusion flash attention:%s\n", params.diffusion_flash_attn ? 
"true" : "false"); printf(" strength(control): %.2f\n", params.control_strength); printf(" prompt: %s\n", params.prompt.c_str()); printf(" negative_prompt: %s\n", params.negative_prompt.c_str()); @@ -215,6 +217,9 @@ void print_usage(int argc, const char* argv[]) { printf(" --vae-tiling process vae in tiles to reduce memory usage\n"); printf(" --vae-on-cpu keep vae in cpu (for low vram)\n"); printf(" --clip-on-cpu keep clip in cpu (for low vram)\n"); + printf(" --diffusion-fa use flash attention in the diffusion model (for low vram)\n"); + printf(" Might lower quality, since it implies converting k and v to f16.\n"); + printf(" This might crash if it is not supported by the backend.\n"); printf(" --control-net-cpu keep controlnet in cpu (for low vram)\n"); printf(" --canny apply canny preprocessor (edge detection)\n"); printf(" --color Colors the logging tags according to level\n"); @@ -465,6 +470,8 @@ void parse_args(int argc, const char** argv, SDParams& params) { params.clip_on_cpu = true; // will slow down get_learned_condiotion but necessary for low MEM GPUs } else if (arg == "--vae-on-cpu") { params.vae_on_cpu = true; // will slow down latent decoding but necessary for low MEM GPUs + } else if (arg == "--diffusion-fa") { + params.diffusion_flash_attn = true; // can reduce MEM significantly } else if (arg == "--canny") { params.canny_preprocess = true; } else if (arg == "-b" || arg == "--batch-count") { @@ -791,7 +798,8 @@ int main(int argc, const char* argv[]) { params.schedule, params.clip_on_cpu, params.control_net_cpu, - params.vae_on_cpu); + params.vae_on_cpu, + params.diffusion_flash_attn); if (sd_ctx == NULL) { printf("new_sd_ctx_t failed\n"); diff --git a/flux.hpp b/flux.hpp index 73bc345a7..e038fd7f9 100644 --- a/flux.hpp +++ b/flux.hpp @@ -115,25 +115,28 @@ namespace Flux { struct ggml_tensor* q, struct ggml_tensor* k, struct ggml_tensor* v, - struct ggml_tensor* pe) { + struct ggml_tensor* pe, + bool flash_attn) { // q,k,v: [N, L, n_head, d_head] // pe: [L, d_head/2, 2, 2] // return: [N, L, n_head*d_head] q = apply_rope(ctx, q, pe); // [N*n_head, L, d_head] k = apply_rope(ctx, k, pe); // [N*n_head, L, d_head] - auto x = ggml_nn_attention_ext(ctx, q, k, v, v->ne[1], NULL, false, true); // [N, L, n_head*d_head] + auto x = ggml_nn_attention_ext(ctx, q, k, v, v->ne[1], NULL, false, true, flash_attn); // [N, L, n_head*d_head] return x; } struct SelfAttention : public GGMLBlock { public: int64_t num_heads; + bool flash_attn; public: SelfAttention(int64_t dim, int64_t num_heads = 8, - bool qkv_bias = false) + bool qkv_bias = false, + bool flash_attn = false) : num_heads(num_heads) { int64_t head_dim = dim / num_heads; blocks["qkv"] = std::shared_ptr(new Linear(dim, dim * 3, qkv_bias)); @@ -168,7 +171,7 @@ namespace Flux { // pe: [n_token, d_head/2, 2, 2] // return [N, n_token, dim] auto qkv = pre_attention(ctx, x); // q,k,v: [N, n_token, n_head, d_head] - x = attention(ctx, qkv[0], qkv[1], qkv[2], pe); // [N, n_token, dim] + x = attention(ctx, qkv[0], qkv[1], qkv[2], pe, flash_attn); // [N, n_token, dim] x = post_attention(ctx, x); // [N, n_token, dim] return x; } @@ -237,15 +240,18 @@ namespace Flux { } struct DoubleStreamBlock : public GGMLBlock { + bool flash_attn; public: DoubleStreamBlock(int64_t hidden_size, int64_t num_heads, float mlp_ratio, - bool qkv_bias = false) { + bool qkv_bias = false, + bool flash_attn = false) + : flash_attn(flash_attn) { int64_t mlp_hidden_dim = hidden_size * mlp_ratio; blocks["img_mod"] = std::shared_ptr(new Modulation(hidden_size, true)); 
blocks["img_norm1"] = std::shared_ptr(new LayerNorm(hidden_size, 1e-6f, false)); - blocks["img_attn"] = std::shared_ptr(new SelfAttention(hidden_size, num_heads, qkv_bias)); + blocks["img_attn"] = std::shared_ptr(new SelfAttention(hidden_size, num_heads, qkv_bias, flash_attn)); blocks["img_norm2"] = std::shared_ptr(new LayerNorm(hidden_size, 1e-6f, false)); blocks["img_mlp.0"] = std::shared_ptr(new Linear(hidden_size, mlp_hidden_dim)); @@ -254,7 +260,7 @@ namespace Flux { blocks["txt_mod"] = std::shared_ptr(new Modulation(hidden_size, true)); blocks["txt_norm1"] = std::shared_ptr(new LayerNorm(hidden_size, 1e-6f, false)); - blocks["txt_attn"] = std::shared_ptr(new SelfAttention(hidden_size, num_heads, qkv_bias)); + blocks["txt_attn"] = std::shared_ptr(new SelfAttention(hidden_size, num_heads, qkv_bias, flash_attn)); blocks["txt_norm2"] = std::shared_ptr(new LayerNorm(hidden_size, 1e-6f, false)); blocks["txt_mlp.0"] = std::shared_ptr(new Linear(hidden_size, mlp_hidden_dim)); @@ -316,7 +322,7 @@ namespace Flux { auto k = ggml_concat(ctx, txt_k, img_k, 2); // [N, n_txt_token + n_img_token, n_head, d_head] auto v = ggml_concat(ctx, txt_v, img_v, 2); // [N, n_txt_token + n_img_token, n_head, d_head] - auto attn = attention(ctx, q, k, v, pe); // [N, n_txt_token + n_img_token, n_head*d_head] + auto attn = attention(ctx, q, k, v, pe, flash_attn); // [N, n_txt_token + n_img_token, n_head*d_head] attn = ggml_cont(ctx, ggml_permute(ctx, attn, 0, 2, 1, 3)); // [n_txt_token + n_img_token, N, hidden_size] auto txt_attn_out = ggml_view_3d(ctx, attn, @@ -364,13 +370,15 @@ namespace Flux { int64_t num_heads; int64_t hidden_size; int64_t mlp_hidden_dim; + bool flash_attn; public: SingleStreamBlock(int64_t hidden_size, int64_t num_heads, float mlp_ratio = 4.0f, - float qk_scale = 0.f) - : hidden_size(hidden_size), num_heads(num_heads) { + float qk_scale = 0.f, + bool flash_attn = false) + : hidden_size(hidden_size), num_heads(num_heads), flash_attn(flash_attn) { int64_t head_dim = hidden_size / num_heads; float scale = qk_scale; if (scale <= 0.f) { @@ -433,7 +441,7 @@ namespace Flux { auto v = ggml_reshape_4d(ctx, qkv_vec[2], head_dim, num_heads, qkv_vec[2]->ne[1], qkv_vec[2]->ne[2]); // [N, n_token, n_head, d_head] q = norm->query_norm(ctx, q); k = norm->key_norm(ctx, k); - auto attn = attention(ctx, q, k, v, pe); // [N, n_token, hidden_size] + auto attn = attention(ctx, q, k, v, pe, flash_attn); // [N, n_token, hidden_size] auto attn_mlp = ggml_concat(ctx, attn, ggml_gelu_inplace(ctx, mlp), 0); // [N, n_token, hidden_size + mlp_hidden_dim] auto output = linear2->forward(ctx, attn_mlp); // [N, n_token, hidden_size] @@ -492,6 +500,7 @@ namespace Flux { int theta = 10000; bool qkv_bias = true; bool guidance_embed = true; + bool flash_attn = true; }; struct Flux : public GGMLBlock { @@ -646,13 +655,16 @@ namespace Flux { blocks["double_blocks." + std::to_string(i)] = std::shared_ptr(new DoubleStreamBlock(params.hidden_size, params.num_heads, params.mlp_ratio, - params.qkv_bias)); + params.qkv_bias, + params.flash_attn)); } for (int i = 0; i < params.depth_single_blocks; i++) { blocks["single_blocks." 
+ std::to_string(i)] = std::shared_ptr(new SingleStreamBlock(params.hidden_size, params.num_heads, - params.mlp_ratio)); + params.mlp_ratio, + 0.f, + params.flash_attn)); } blocks["final_layer"] = std::shared_ptr(new LastLayer(params.hidden_size, 1, out_channels)); @@ -808,8 +820,10 @@ namespace Flux { FluxRunner(ggml_backend_t backend, ggml_type wtype, - SDVersion version = VERSION_FLUX_DEV) + SDVersion version = VERSION_FLUX_DEV, + bool flash_attn = false) : GGMLRunner(backend, wtype) { + flux_params.flash_attn = flash_attn; if (version == VERSION_FLUX_SCHNELL) { flux_params.guidance_embed = false; } @@ -958,4 +972,4 @@ namespace Flux { } // namespace Flux -#endif // __FLUX_HPP__ \ No newline at end of file +#endif // __FLUX_HPP__ diff --git a/ggml_extend.hpp b/ggml_extend.hpp index fb81c3a62..65892f3c2 100644 --- a/ggml_extend.hpp +++ b/ggml_extend.hpp @@ -703,7 +703,8 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_nn_attention_ext(struct ggml_context* int64_t n_head, struct ggml_tensor* mask = NULL, bool diag_mask_inf = false, - bool skip_reshape = false) { + bool skip_reshape = false, + bool flash_attn = false) { int64_t L_q; int64_t L_k; int64_t C; @@ -734,25 +735,29 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_nn_attention_ext(struct ggml_context* float scale = (1.0f / sqrt((float)d_head)); - LOG_DEBUG("attention_ext L_k:%d n_head:%d C:%d d_head:%d", L_k, n_head, C, d_head); + if (flash_attn) { + // TODO: remove before merge + LOG_DEBUG("attention_ext L_q:%d L_k:%d n_head:%d C:%d d_head:%d N:%d", L_q, L_k, n_head, C, d_head, N); + } + // is there anything oddly shaped?? + GGML_ASSERT(((L_k % 256 == 0) && L_q == L_k) || !(L_k % 256 == 0)); - bool use_flash_attn = true; - // L_k == n_context AND l_k == n_token ???? - use_flash_attn = use_flash_attn && L_k % 256 == 0; - use_flash_attn = use_flash_attn && d_head % 64 == 0; // why + bool can_use_flash_attn = true; + can_use_flash_attn = can_use_flash_attn && L_k % 256 == 0; + can_use_flash_attn = can_use_flash_attn && d_head % 64 == 0; // double check if (mask != nullptr) { // TODO: figure out if we can bend t5 to work too - use_flash_attn = use_flash_attn && mask->ne[2] == 1; - use_flash_attn = use_flash_attn && mask->ne[3] == 1; + can_use_flash_attn = can_use_flash_attn && mask->ne[2] == 1; + can_use_flash_attn = can_use_flash_attn && mask->ne[3] == 1; } // TODO: more pad or disable for funny tensor shapes ggml_tensor* kqv = nullptr; - if (use_flash_attn) { + //GGML_ASSERT((flash_attn && can_use_flash_attn) || !flash_attn); + if (can_use_flash_attn && flash_attn) { LOG_DEBUG("using flash attention"); - k = ggml_cast(ctx, k, GGML_TYPE_F16); v = ggml_cont(ctx, ggml_permute(ctx, v, 0, 2, 1, 3)); // [N, n_head, L_k, d_head] @@ -762,7 +767,8 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_nn_attention_ext(struct ggml_context* kqv = ggml_flash_attn_ext(ctx, q, k, v, mask, scale, 0, 0); ggml_flash_attn_ext_set_prec(kqv, GGML_PREC_F32); - kqv = ggml_view_3d(ctx, kqv, d_head, n_head, L_k, kqv->nb[1], kqv->nb[2], 0); + //kqv = ggml_view_3d(ctx, kqv, d_head, n_head, L_k, kqv->nb[1], kqv->nb[2], 0); + kqv = ggml_view_3d(ctx, kqv, d_head, n_head, L_q, kqv->nb[1], kqv->nb[2], 0); } else { v = ggml_cont(ctx, ggml_permute(ctx, v, 1, 2, 0, 3)); // [N, n_head, d_head, L_k] v = ggml_reshape_3d(ctx, v, L_k, d_head, n_head * N); // [N * n_head, d_head, L_k] diff --git a/stable-diffusion.cpp b/stable-diffusion.cpp index 4d28a147b..256cc0635 100644 --- a/stable-diffusion.cpp +++ b/stable-diffusion.cpp @@ -153,7 +153,8 @@ class StableDiffusionGGML { schedule_t 
schedule, bool clip_on_cpu, bool control_net_cpu, - bool vae_on_cpu) { + bool vae_on_cpu, + bool diffusion_flash_attn) { use_tiny_autoencoder = taesd_path.size() > 0; #ifdef SD_USE_CUBLAS LOG_DEBUG("Using CUDA backend"); @@ -322,15 +323,21 @@ class StableDiffusionGGML { LOG_INFO("CLIP: Using CPU backend"); clip_backend = ggml_backend_cpu_init(); } + if (diffusion_flash_attn) { + LOG_INFO("Using flash attention in the diffusion model"); + } if (version == VERSION_SD3_2B || version == VERSION_SD3_5_8B) { + if (diffusion_flash_attn) { + LOG_WARN("flash attention in this diffusion model is currently unsupported!"); + } cond_stage_model = std::make_shared(clip_backend, conditioner_wtype); diffusion_model = std::make_shared(backend, diffusion_model_wtype, version); } else if (version == VERSION_FLUX_DEV || version == VERSION_FLUX_SCHNELL) { cond_stage_model = std::make_shared(clip_backend, conditioner_wtype); - diffusion_model = std::make_shared(backend, diffusion_model_wtype, version); + diffusion_model = std::make_shared(backend, diffusion_model_wtype, version, diffusion_flash_attn); } else { cond_stage_model = std::make_shared(clip_backend, conditioner_wtype, embeddings_path, version); - diffusion_model = std::make_shared(backend, diffusion_model_wtype, version); + diffusion_model = std::make_shared(backend, diffusion_model_wtype, version, diffusion_flash_attn); } cond_stage_model->alloc_params_buffer(); cond_stage_model->get_param_tensors(tensors); @@ -1035,7 +1042,8 @@ sd_ctx_t* new_sd_ctx(const char* model_path_c_str, enum schedule_t s, bool keep_clip_on_cpu, bool keep_control_net_cpu, - bool keep_vae_on_cpu) { + bool keep_vae_on_cpu, + bool diffusion_flash_attn) { sd_ctx_t* sd_ctx = (sd_ctx_t*)malloc(sizeof(sd_ctx_t)); if (sd_ctx == NULL) { return NULL; @@ -1076,7 +1084,8 @@ sd_ctx_t* new_sd_ctx(const char* model_path_c_str, s, keep_clip_on_cpu, keep_control_net_cpu, - keep_vae_on_cpu)) { + keep_vae_on_cpu, + diffusion_flash_attn)) { delete sd_ctx->sd; sd_ctx->sd = NULL; free(sd_ctx); diff --git a/stable-diffusion.h b/stable-diffusion.h index 812e8fc94..06eee40f7 100644 --- a/stable-diffusion.h +++ b/stable-diffusion.h @@ -142,7 +142,8 @@ SD_API sd_ctx_t* new_sd_ctx(const char* model_path, enum schedule_t s, bool keep_clip_on_cpu, bool keep_control_net_cpu, - bool keep_vae_on_cpu); + bool keep_vae_on_cpu, + bool diffusion_flash_attn); SD_API void free_sd_ctx(sd_ctx_t* sd_ctx); diff --git a/unet.hpp b/unet.hpp index 94a8ba46a..611891298 100644 --- a/unet.hpp +++ b/unet.hpp @@ -183,7 +183,7 @@ class UnetModelBlock : public GGMLBlock { int model_channels = 320; int adm_in_channels = 2816; // only for VERSION_SDXL/SVD - UnetModelBlock(SDVersion version = VERSION_SD1) + UnetModelBlock(SDVersion version = VERSION_SD1, bool flash_attn = false) : version(version) { if (version == VERSION_SD2) { context_dim = 1024; @@ -242,7 +242,7 @@ class UnetModelBlock : public GGMLBlock { if (version == VERSION_SVD) { return new SpatialVideoTransformer(in_channels, n_head, d_head, depth, context_dim); } else { - return new SpatialTransformer(in_channels, n_head, d_head, depth, context_dim); + return new SpatialTransformer(in_channels, n_head, d_head, depth, context_dim, flash_attn); } }; @@ -533,8 +533,9 @@ struct UNetModelRunner : public GGMLRunner { UNetModelRunner(ggml_backend_t backend, ggml_type wtype, - SDVersion version = VERSION_SD1) - : GGMLRunner(backend, wtype), unet(version) { + SDVersion version = VERSION_SD1, + bool flash_attn = false) + : GGMLRunner(backend, wtype), unet(version, flash_attn) { 
unet.init(params_ctx, wtype); } @@ -649,4 +650,4 @@ struct UNetModelRunner : public GGMLRunner { } }; -#endif // __UNET_HPP__ \ No newline at end of file +#endif // __UNET_HPP__ From 78aeee8f5655fe43bdd869a8610f4c8e72a10114 Mon Sep 17 00:00:00 2001 From: Green Sky Date: Sat, 7 Sep 2024 12:47:20 +0200 Subject: [PATCH 3/5] remove old flash attention option and switch vae over to attn_ext --- CMakeLists.txt | 6 ------ ggml_extend.hpp | 44 ++++++++++---------------------------------- stable-diffusion.cpp | 8 +------- vae.hpp | 10 ++++++---- 4 files changed, 17 insertions(+), 51 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index c993e7c96..8466ed5d9 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -29,7 +29,6 @@ option(SD_HIPBLAS "sd: rocm backend" OFF) option(SD_METAL "sd: metal backend" OFF) option(SD_VULKAN "sd: vulkan backend" OFF) option(SD_SYCL "sd: sycl backend" OFF) -option(SD_FLASH_ATTN "sd: use flash attention for x4 less memory usage" OFF) option(SD_FAST_SOFTMAX "sd: x1.5 faster softmax, indeterministic (sometimes, same seed don't generate same image), cuda only" OFF) option(SD_BUILD_SHARED_LIBS "sd: build shared libs" OFF) #option(SD_BUILD_SERVER "sd: build server example" ON) @@ -61,11 +60,6 @@ if (SD_HIPBLAS) endif() endif () -if(SD_FLASH_ATTN) - message("-- Use Flash Attention for memory optimization") - add_definitions(-DSD_USE_FLASH_ATTENTION) -endif() - set(SD_LIB stable-diffusion) file(GLOB SD_LIB_SOURCES diff --git a/ggml_extend.hpp b/ggml_extend.hpp index 65892f3c2..3924b5d1b 100644 --- a/ggml_extend.hpp +++ b/ggml_extend.hpp @@ -666,32 +666,6 @@ __STATIC_INLINE__ std::vector split_qkv(struct ggml_context return {q, k, v}; } -// q: [N * n_head, n_token, d_head] -// k: [N * n_head, n_k, d_head] -// v: [N * n_head, d_head, n_k] -// return: [N * n_head, n_token, d_head] -__STATIC_INLINE__ struct ggml_tensor* ggml_nn_attention(struct ggml_context* ctx, - struct ggml_tensor* q, - struct ggml_tensor* k, - struct ggml_tensor* v, - bool mask = false) { -#if defined(SD_USE_FLASH_ATTENTION) && !defined(SD_USE_CUBLAS) && !defined(SD_USE_METAL) && !defined(SD_USE_VULKAN) && !defined(SD_USE_SYCL) - struct ggml_tensor* kqv = ggml_flash_attn(ctx, q, k, v, false); // [N * n_head, n_token, d_head] -#else - float d_head = (float)q->ne[0]; - - struct ggml_tensor* kq = ggml_mul_mat(ctx, k, q); // [N * n_head, n_token, n_k] - kq = ggml_scale_inplace(ctx, kq, 1.0f / sqrt(d_head)); - if (mask) { - kq = ggml_diag_mask_inf_inplace(ctx, kq, 0); - } - kq = ggml_soft_max_inplace(ctx, kq); - - struct ggml_tensor* kqv = ggml_mul_mat(ctx, v, kq); // [N * n_head, n_token, d_head] -#endif - return kqv; -} - // q: [N, L_q, C] or [N*n_head, L_q, d_head] // k: [N, L_k, C] or [N*n_head, L_k, d_head] // v: [N, L_k, C] or [N, L_k, n_head, d_head] @@ -735,29 +709,31 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_nn_attention_ext(struct ggml_context* float scale = (1.0f / sqrt((float)d_head)); - if (flash_attn) { - // TODO: remove before merge - LOG_DEBUG("attention_ext L_q:%d L_k:%d n_head:%d C:%d d_head:%d N:%d", L_q, L_k, n_head, C, d_head, N); - } - // is there anything oddly shaped?? + //if (flash_attn) { + // LOG_DEBUG("attention_ext L_q:%d L_k:%d n_head:%d C:%d d_head:%d N:%d", L_q, L_k, n_head, C, d_head, N); + //} + // is there anything oddly shaped?? 
ping Green-Sky if you can trip this assert GGML_ASSERT(((L_k % 256 == 0) && L_q == L_k) || !(L_k % 256 == 0)); bool can_use_flash_attn = true; can_use_flash_attn = can_use_flash_attn && L_k % 256 == 0; can_use_flash_attn = can_use_flash_attn && d_head % 64 == 0; // double check + // cuda max d_head seems to be 256, cpu does seem to work with 512 + can_use_flash_attn = can_use_flash_attn && d_head <= 256; // double check + if (mask != nullptr) { - // TODO: figure out if we can bend t5 to work too + // TODO(Green-Sky): figure out if we can bend t5 to work too can_use_flash_attn = can_use_flash_attn && mask->ne[2] == 1; can_use_flash_attn = can_use_flash_attn && mask->ne[3] == 1; } - // TODO: more pad or disable for funny tensor shapes + // TODO(Green-Sky): more pad or disable for funny tensor shapes ggml_tensor* kqv = nullptr; //GGML_ASSERT((flash_attn && can_use_flash_attn) || !flash_attn); if (can_use_flash_attn && flash_attn) { - LOG_DEBUG("using flash attention"); + //LOG_DEBUG("using flash attention"); k = ggml_cast(ctx, k, GGML_TYPE_F16); v = ggml_cont(ctx, ggml_permute(ctx, v, 0, 2, 1, 3)); // [N, n_head, L_k, d_head] diff --git a/stable-diffusion.cpp b/stable-diffusion.cpp index 256cc0635..7679618be 100644 --- a/stable-diffusion.cpp +++ b/stable-diffusion.cpp @@ -183,13 +183,7 @@ class StableDiffusionGGML { LOG_DEBUG("Using CPU backend"); backend = ggml_backend_cpu_init(); } -#ifdef SD_USE_FLASH_ATTENTION -#if defined(SD_USE_CUBLAS) || defined(SD_USE_METAL) || defined(SD_USE_SYCL) || defined(SD_USE_VULKAN) - LOG_WARN("Flash Attention not supported with GPU Backend"); -#else - LOG_INFO("Flash Attention enabled"); -#endif -#endif + ModelLoader model_loader; vae_tiling = vae_tiling_; diff --git a/vae.hpp b/vae.hpp index 42b694cd5..749c21aad 100644 --- a/vae.hpp +++ b/vae.hpp @@ -99,10 +99,12 @@ class AttnBlock : public UnaryBlock { k = ggml_cont(ctx, ggml_permute(ctx, k, 1, 2, 0, 3)); // [N, h, w, in_channels] k = ggml_reshape_3d(ctx, k, c, h * w, n); // [N, h * w, in_channels] - auto v = v_proj->forward(ctx, h_); // [N, in_channels, h, w] - v = ggml_reshape_3d(ctx, v, h * w, c, n); // [N, in_channels, h * w] + auto v = v_proj->forward(ctx, h_); // [N, in_channels, h, w] + v = ggml_cont(ctx, ggml_permute(ctx, v, 1, 2, 0, 3)); // [N, h, w, in_channels] + v = ggml_reshape_3d(ctx, v, c, h * w, n); // [N, h * w, in_channels] - h_ = ggml_nn_attention(ctx, q, k, v, false); // [N, h * w, in_channels] + //h_ = ggml_nn_attention(ctx, q, k, v, false); // [N, h * w, in_channels] + h_ = ggml_nn_attention_ext(ctx, q, k, v, 1, nullptr, false, true, false); h_ = ggml_cont(ctx, ggml_permute(ctx, h_, 1, 0, 2, 3)); // [N, in_channels, h * w] h_ = ggml_reshape_4d(ctx, h_, w, h, c, n); // [N, in_channels, h, w] @@ -612,4 +614,4 @@ struct AutoEncoderKL : public GGMLRunner { }; }; -#endif \ No newline at end of file +#endif From cbf0489acc281a89395509453a96fbad941d1b3d Mon Sep 17 00:00:00 2001 From: Green Sky Date: Sun, 8 Sep 2024 09:47:43 +0200 Subject: [PATCH 4/5] update docs --- README.md | 21 +++++++++++++++++---- 1 file changed, 17 insertions(+), 4 deletions(-) diff --git a/README.md b/README.md index c1ba396fe..a6979a3fe 100644 --- a/README.md +++ b/README.md @@ -24,7 +24,7 @@ Inference of Stable Diffusion and Flux in pure C/C++ - Full CUDA, Metal, Vulkan and SYCL backend for GPU acceleration. - Can load ckpt, safetensors and diffusers models/checkpoints. Standalone VAEs models - No need to convert to `.ggml` or `.gguf` anymore! 
-- Flash Attention for memory usage optimization (only cpu for now) +- Flash Attention for memory usage optimization - Original `txt2img` and `img2img` mode - Negative prompt - [stable-diffusion-webui](https://github.com/AUTOMATIC1111/stable-diffusion-webui) style tokenizer (not all the features, only token weighting for now) @@ -182,11 +182,21 @@ Example of text2img by using SYCL backend: ##### Using Flash Attention -Enabling flash attention reduces memory usage by at least 400 MB. At the moment, it is not supported when CUBLAS is enabled because the kernel implementation is missing. +Enabling flash attention for the diffusion model reduces memory usage by varying amounts of MB. +eg.: + - flux 768x768 ~600mb + - SD2 768x768 ~1400mb +For most backends, it slows things down, but for cuda it generally speeds it up too. +At the moment, it is only supported for some models and some backends (like cpu, cuda/rocm, metal). + +Run by adding `--diffusion-fa` to the arguments and watch for: ``` -cmake .. -DSD_FLASH_ATTN=ON -cmake --build . --config Release +[INFO ] stable-diffusion.cpp:312 - Using flash attention in the diffusion model +``` +and the compute buffer shrink in the debug log: +``` +[DEBUG] ggml_extend.hpp:1004 - flux compute buffer size: 650.00 MB(VRAM) ``` ### Run @@ -240,6 +250,9 @@ arguments: --vae-tiling process vae in tiles to reduce memory usage --vae-on-cpu keep vae in cpu (for low vram) --clip-on-cpu keep clip in cpu (for low vram) + --diffusion-fa use flash attention in the diffusion model (for low vram) + Might lower quality, since it implies converting k and v to f16. + This might crash if it is not supported by the backend. --control-net-cpu keep controlnet in cpu (for low vram) --canny apply canny preprocessor (edge detection) --color Colors the logging tags according to level From 5ef1821fcb17fe63b97973e68989efd60398eb8e Mon Sep 17 00:00:00 2001 From: leejet Date: Sat, 23 Nov 2024 12:06:46 +0800 Subject: [PATCH 5/5] format code --- clip.hpp | 18 ++- common.hpp | 2 +- conditioner.hpp | 19 ++- diffusion_model.hpp | 4 +- examples/cli/main.cpp | 2 +- flux.hpp | 11 +- ggml_extend.hpp | 29 ++--- model.cpp | 28 ++-- model.h | 2 +- pmid.hpp | 293 +++++++++++++++++++----------------------- stable-diffusion.cpp | 36 +++--- unet.hpp | 2 +- util.cpp | 5 +- vae.hpp | 2 +- 14 files changed, 212 insertions(+), 241 deletions(-) diff --git a/clip.hpp b/clip.hpp index 7c2705873..46e52ada4 100644 --- a/clip.hpp +++ b/clip.hpp @@ -343,8 +343,7 @@ class CLIPTokenizer { } } - std::string clean_up_tokenization(std::string &text){ - + std::string clean_up_tokenization(std::string& text) { std::regex pattern(R"( ,)"); // Replace " ," with "," std::string result = std::regex_replace(text, pattern, ","); @@ -359,10 +358,10 @@ class CLIPTokenizer { std::u32string ts = decoder[t]; // printf("%d, %s \n", t, utf32_to_utf8(ts).c_str()); std::string s = utf32_to_utf8(ts); - if (s.length() >= 4 ){ - if(ends_with(s, "")) { + if (s.length() >= 4) { + if (ends_with(s, "")) { text += s.replace(s.length() - 4, s.length() - 1, "") + " "; - }else{ + } else { text += s; } } else { @@ -768,8 +767,7 @@ class CLIPVisionModel : public GGMLBlock { blocks["post_layernorm"] = std::shared_ptr(new LayerNorm(hidden_size)); } - struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* pixel_values, - bool return_pooled = true) { + struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* pixel_values, bool return_pooled = true) { // pixel_values: [N, num_channels, image_size, image_size] auto 
embeddings = std::dynamic_pointer_cast(blocks["embeddings"]); auto pre_layernorm = std::dynamic_pointer_cast(blocks["pre_layernorm"]); @@ -779,11 +777,11 @@ class CLIPVisionModel : public GGMLBlock { auto x = embeddings->forward(ctx, pixel_values); // [N, num_positions, embed_dim] x = pre_layernorm->forward(ctx, x); x = encoder->forward(ctx, x, -1, false); - // print_ggml_tensor(x, true, "ClipVisionModel x: "); + // print_ggml_tensor(x, true, "ClipVisionModel x: "); auto last_hidden_state = x; - x = post_layernorm->forward(ctx, x); // [N, n_token, hidden_size] + x = post_layernorm->forward(ctx, x); // [N, n_token, hidden_size] - GGML_ASSERT(x->ne[3] == 1); + GGML_ASSERT(x->ne[3] == 1); if (return_pooled) { ggml_tensor* pooled = ggml_cont(ctx, ggml_view_2d(ctx, x, x->ne[0], x->ne[2], x->nb[2], 0)); return pooled; // [N, hidden_size] diff --git a/common.hpp b/common.hpp index 784bab32b..1ca6b8d0d 100644 --- a/common.hpp +++ b/common.hpp @@ -304,7 +304,7 @@ class BasicTransformerBlock : public GGMLBlock { int64_t n_head, int64_t d_head, int64_t context_dim, - bool ff_in = false, + bool ff_in = false, bool flash_attn = false) : n_head(n_head), d_head(d_head), ff_in(ff_in) { // disable_self_attn is always False diff --git a/conditioner.hpp b/conditioner.hpp index 065f352ec..47fd3eb6e 100644 --- a/conditioner.hpp +++ b/conditioner.hpp @@ -4,7 +4,6 @@ #include "clip.hpp" #include "t5.hpp" - struct SDCondition { struct ggml_tensor* c_crossattn = NULL; // aka context struct ggml_tensor* c_vector = NULL; // aka y @@ -44,7 +43,7 @@ struct Conditioner { // ldm.modules.encoders.modules.FrozenCLIPEmbedder // Ref: https://github.com/AUTOMATIC1111/stable-diffusion-webui/blob/cad87bf4e3e0b0a759afa94e933527c3123d59bc/modules/sd_hijack_clip.py#L283 struct FrozenCLIPEmbedderWithCustomWords : public Conditioner { - SDVersion version = VERSION_SD1; + SDVersion version = VERSION_SD1; PMVersion pm_version = VERSION_1; CLIPTokenizer tokenizer; ggml_type wtype; @@ -61,7 +60,7 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner { ggml_type wtype, const std::string& embd_dir, SDVersion version = VERSION_SD1, - PMVersion pv = VERSION_1, + PMVersion pv = VERSION_1, int clip_skip = -1) : version(version), pm_version(pv), tokenizer(version == VERSION_SD2 ? 0 : 49407), embd_dir(embd_dir), wtype(wtype) { if (clip_skip <= 0) { @@ -162,7 +161,7 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner { tokenize_with_trigger_token(std::string text, int num_input_imgs, int32_t image_token, - bool padding = false){ + bool padding = false) { return tokenize_with_trigger_token(text, num_input_imgs, image_token, text_model->model.n_token, padding); } @@ -271,7 +270,7 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner { std::vector clean_input_ids_tmp; for (uint32_t i = 0; i < class_token_index[0]; i++) clean_input_ids_tmp.push_back(clean_input_ids[i]); - for (uint32_t i = 0; i < (pm_version == VERSION_2 ? 2*num_input_imgs: num_input_imgs); i++) + for (uint32_t i = 0; i < (pm_version == VERSION_2 ? 2 * num_input_imgs : num_input_imgs); i++) clean_input_ids_tmp.push_back(class_token); for (uint32_t i = class_token_index[0] + 1; i < clean_input_ids.size(); i++) clean_input_ids_tmp.push_back(clean_input_ids[i]); @@ -287,11 +286,11 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner { // weights.insert(weights.begin(), 1.0); tokenizer.pad_tokens(tokens, weights, max_length, padding); - int offset = pm_version == VERSION_2 ? 
2*num_input_imgs: num_input_imgs; + int offset = pm_version == VERSION_2 ? 2 * num_input_imgs : num_input_imgs; for (uint32_t i = 0; i < tokens.size(); i++) { // if (class_idx + 1 <= i && i < class_idx + 1 + 2*num_input_imgs) // photomaker V2 has num_tokens(=2)*num_input_imgs - if (class_idx + 1 <= i && i < class_idx + 1 + offset) // photomaker V2 has num_tokens(=2)*num_input_imgs - // hardcode for now + if (class_idx + 1 <= i && i < class_idx + 1 + offset) // photomaker V2 has num_tokens(=2)*num_input_imgs + // hardcode for now class_token_mask.push_back(true); else class_token_mask.push_back(false); @@ -536,7 +535,7 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner { int height, int num_input_imgs, int adm_in_channels = -1, - bool force_zero_embeddings = false){ + bool force_zero_embeddings = false) { auto image_tokens = convert_token_to_id(trigger_word); // if(image_tokens.size() == 1){ // printf(" image token id is: %d \n", image_tokens[0]); @@ -964,7 +963,7 @@ struct SD3CLIPEmbedder : public Conditioner { int height, int num_input_imgs, int adm_in_channels = -1, - bool force_zero_embeddings = false){ + bool force_zero_embeddings = false) { GGML_ASSERT(0 && "Not implemented yet!"); } diff --git a/diffusion_model.hpp b/diffusion_model.hpp index 0189f1805..eb433b614 100644 --- a/diffusion_model.hpp +++ b/diffusion_model.hpp @@ -33,7 +33,7 @@ struct UNetModel : public DiffusionModel { UNetModel(ggml_backend_t backend, ggml_type wtype, SDVersion version = VERSION_SD1, - bool flash_attn = false) + bool flash_attn = false) : unet(backend, wtype, version, flash_attn) { } @@ -135,7 +135,7 @@ struct FluxModel : public DiffusionModel { FluxModel(ggml_backend_t backend, ggml_type wtype, SDVersion version = VERSION_FLUX_DEV, - bool flash_attn = false) + bool flash_attn = false) : flux(backend, wtype, version, flash_attn) { } diff --git a/examples/cli/main.cpp b/examples/cli/main.cpp index 39b45c33e..9f25245e3 100644 --- a/examples/cli/main.cpp +++ b/examples/cli/main.cpp @@ -483,7 +483,7 @@ void parse_args(int argc, const char** argv, SDParams& params) { } else if (arg == "--vae-on-cpu") { params.vae_on_cpu = true; // will slow down latent decoding but necessary for low MEM GPUs } else if (arg == "--diffusion-fa") { - params.diffusion_flash_attn = true; // can reduce MEM significantly + params.diffusion_flash_attn = true; // can reduce MEM significantly } else if (arg == "--canny") { params.canny_preprocess = true; } else if (arg == "-b" || arg == "--batch-count") { diff --git a/flux.hpp b/flux.hpp index 09ee93c30..b2d0f57c2 100644 --- a/flux.hpp +++ b/flux.hpp @@ -170,9 +170,9 @@ namespace Flux { // x: [N, n_token, dim] // pe: [n_token, d_head/2, 2, 2] // return [N, n_token, dim] - auto qkv = pre_attention(ctx, x); // q,k,v: [N, n_token, n_head, d_head] + auto qkv = pre_attention(ctx, x); // q,k,v: [N, n_token, n_head, d_head] x = attention(ctx, qkv[0], qkv[1], qkv[2], pe, flash_attn); // [N, n_token, dim] - x = post_attention(ctx, x); // [N, n_token, dim] + x = post_attention(ctx, x); // [N, n_token, dim] return x; } }; @@ -241,11 +241,12 @@ namespace Flux { struct DoubleStreamBlock : public GGMLBlock { bool flash_attn; + public: DoubleStreamBlock(int64_t hidden_size, int64_t num_heads, float mlp_ratio, - bool qkv_bias = false, + bool qkv_bias = false, bool flash_attn = false) : flash_attn(flash_attn) { int64_t mlp_hidden_dim = hidden_size * mlp_ratio; @@ -322,7 +323,7 @@ namespace Flux { auto k = ggml_concat(ctx, txt_k, img_k, 2); // [N, n_txt_token + n_img_token, n_head, d_head] 
auto v = ggml_concat(ctx, txt_v, img_v, 2); // [N, n_txt_token + n_img_token, n_head, d_head] - auto attn = attention(ctx, q, k, v, pe, flash_attn); // [N, n_txt_token + n_img_token, n_head*d_head] + auto attn = attention(ctx, q, k, v, pe, flash_attn); // [N, n_txt_token + n_img_token, n_head*d_head] attn = ggml_cont(ctx, ggml_permute(ctx, attn, 0, 2, 1, 3)); // [n_txt_token + n_img_token, N, hidden_size] auto txt_attn_out = ggml_view_3d(ctx, attn, @@ -830,7 +831,7 @@ namespace Flux { FluxRunner(ggml_backend_t backend, ggml_type wtype, SDVersion version = VERSION_FLUX_DEV, - bool flash_attn = false) + bool flash_attn = false) : GGMLRunner(backend, wtype) { flux_params.flash_attn = flash_attn; if (version == VERSION_FLUX_SCHNELL) { diff --git a/ggml_extend.hpp b/ggml_extend.hpp index 65c05e8b4..75ad04142 100644 --- a/ggml_extend.hpp +++ b/ggml_extend.hpp @@ -709,18 +709,18 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_nn_attention_ext(struct ggml_context* float scale = (1.0f / sqrt((float)d_head)); - //if (flash_attn) { - // LOG_DEBUG("attention_ext L_q:%d L_k:%d n_head:%d C:%d d_head:%d N:%d", L_q, L_k, n_head, C, d_head, N); - //} - // is there anything oddly shaped?? ping Green-Sky if you can trip this assert + // if (flash_attn) { + // LOG_DEBUG("attention_ext L_q:%d L_k:%d n_head:%d C:%d d_head:%d N:%d", L_q, L_k, n_head, C, d_head, N); + // } + // is there anything oddly shaped?? ping Green-Sky if you can trip this assert GGML_ASSERT(((L_k % 256 == 0) && L_q == L_k) || !(L_k % 256 == 0)); bool can_use_flash_attn = true; - can_use_flash_attn = can_use_flash_attn && L_k % 256 == 0; - can_use_flash_attn = can_use_flash_attn && d_head % 64 == 0; // double check + can_use_flash_attn = can_use_flash_attn && L_k % 256 == 0; + can_use_flash_attn = can_use_flash_attn && d_head % 64 == 0; // double check // cuda max d_head seems to be 256, cpu does seem to work with 512 - can_use_flash_attn = can_use_flash_attn && d_head <= 256; // double check + can_use_flash_attn = can_use_flash_attn && d_head <= 256; // double check if (mask != nullptr) { // TODO(Green-Sky): figure out if we can bend t5 to work too @@ -731,9 +731,9 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_nn_attention_ext(struct ggml_context* // TODO(Green-Sky): more pad or disable for funny tensor shapes ggml_tensor* kqv = nullptr; - //GGML_ASSERT((flash_attn && can_use_flash_attn) || !flash_attn); + // GGML_ASSERT((flash_attn && can_use_flash_attn) || !flash_attn); if (can_use_flash_attn && flash_attn) { - //LOG_DEBUG("using flash attention"); + // LOG_DEBUG("using flash attention"); k = ggml_cast(ctx, k, GGML_TYPE_F16); v = ggml_cont(ctx, ggml_permute(ctx, v, 0, 2, 1, 3)); // [N, n_head, L_k, d_head] @@ -743,7 +743,7 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_nn_attention_ext(struct ggml_context* kqv = ggml_flash_attn_ext(ctx, q, k, v, mask, scale, 0, 0); ggml_flash_attn_ext_set_prec(kqv, GGML_PREC_F32); - //kqv = ggml_view_3d(ctx, kqv, d_head, n_head, L_k, kqv->nb[1], kqv->nb[2], 0); + // kqv = ggml_view_3d(ctx, kqv, d_head, n_head, L_k, kqv->nb[1], kqv->nb[2], 0); kqv = ggml_view_3d(ctx, kqv, d_head, n_head, L_q, kqv->nb[1], kqv->nb[2], 0); } else { v = ggml_cont(ctx, ggml_permute(ctx, v, 1, 2, 0, 3)); // [N, n_head, d_head, L_k] @@ -761,8 +761,8 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_nn_attention_ext(struct ggml_context* kqv = ggml_mul_mat(ctx, v, kq); // [N * n_head, L_q, d_head] - kqv = ggml_reshape_4d(ctx, kqv, d_head, L_q, n_head, N); // [N, n_head, L_q, d_head] - kqv = ggml_permute(ctx, kqv, 0, 2, 1, 3); // [N, L_q, 
n_head, d_head] + kqv = ggml_reshape_4d(ctx, kqv, d_head, L_q, n_head, N); // [N, n_head, L_q, d_head] + kqv = ggml_permute(ctx, kqv, 0, 2, 1, 3); // [N, L_q, n_head, d_head] } kqv = ggml_cont(ctx, kqv); @@ -1057,7 +1057,7 @@ struct GGMLRunner { // get_desc().c_str(), // params_buffer_size / (1024.0 * 1024.0), // ggml_backend_is_cpu(backend) ? "RAM" : "VRAM", - // num_tensors); + // num_tensors); return true; } @@ -1227,8 +1227,7 @@ class Linear : public UnaryBlock { params["weight"] = ggml_new_tensor_2d(ctx, wtype, in_features, out_features); if (bias) { params["bias"] = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, out_features); - } - + } } public: diff --git a/model.cpp b/model.cpp index 5f1e6e160..2719f63c0 100644 --- a/model.cpp +++ b/model.cpp @@ -148,19 +148,19 @@ std::unordered_map vae_decoder_name_map = { std::unordered_map pmid_v2_name_map = { {"pmid.qformer_perceiver.perceiver_resampler.layers.0.1.1.weight", - "pmid.qformer_perceiver.perceiver_resampler.layers.0.1.1.fc1.weight"}, + "pmid.qformer_perceiver.perceiver_resampler.layers.0.1.1.fc1.weight"}, {"pmid.qformer_perceiver.perceiver_resampler.layers.0.1.3.weight", - "pmid.qformer_perceiver.perceiver_resampler.layers.0.1.1.fc2.weight"}, + "pmid.qformer_perceiver.perceiver_resampler.layers.0.1.1.fc2.weight"}, {"pmid.qformer_perceiver.perceiver_resampler.layers.1.1.1.weight", - "pmid.qformer_perceiver.perceiver_resampler.layers.1.1.1.fc1.weight"}, + "pmid.qformer_perceiver.perceiver_resampler.layers.1.1.1.fc1.weight"}, {"pmid.qformer_perceiver.perceiver_resampler.layers.1.1.3.weight", "pmid.qformer_perceiver.perceiver_resampler.layers.1.1.1.fc2.weight"}, {"pmid.qformer_perceiver.perceiver_resampler.layers.2.1.1.weight", - "pmid.qformer_perceiver.perceiver_resampler.layers.2.1.1.fc1.weight"}, + "pmid.qformer_perceiver.perceiver_resampler.layers.2.1.1.fc1.weight"}, {"pmid.qformer_perceiver.perceiver_resampler.layers.2.1.3.weight", - "pmid.qformer_perceiver.perceiver_resampler.layers.2.1.1.fc2.weight"}, + "pmid.qformer_perceiver.perceiver_resampler.layers.2.1.1.fc2.weight"}, {"pmid.qformer_perceiver.perceiver_resampler.layers.3.1.1.weight", - "pmid.qformer_perceiver.perceiver_resampler.layers.3.1.1.fc1.weight"}, + "pmid.qformer_perceiver.perceiver_resampler.layers.3.1.1.fc1.weight"}, {"pmid.qformer_perceiver.perceiver_resampler.layers.3.1.3.weight", "pmid.qformer_perceiver.perceiver_resampler.layers.3.1.1.fc2.weight"}, {"pmid.qformer_perceiver.token_proj.0.bias", @@ -650,9 +650,8 @@ uint16_t f8_e4m3_to_f16(uint8_t f8) { return ggml_fp32_to_fp16(*reinterpret_cast(&result)); } - uint16_t f8_e5m2_to_f16(uint8_t fp8) { - uint8_t sign = (fp8 >> 7) & 0x1; + uint8_t sign = (fp8 >> 7) & 0x1; uint8_t exponent = (fp8 >> 2) & 0x1F; uint8_t mantissa = fp8 & 0x3; @@ -660,23 +659,23 @@ uint16_t f8_e5m2_to_f16(uint8_t fp8) { uint16_t fp16_exponent; uint16_t fp16_mantissa; - if (exponent == 0 && mantissa == 0) { //zero + if (exponent == 0 && mantissa == 0) { // zero return fp16_sign; } - if (exponent == 0x1F) { //NAN and INF + if (exponent == 0x1F) { // NAN and INF fp16_exponent = 0x1F; fp16_mantissa = mantissa ? 
(mantissa << 8) : 0; return fp16_sign | (fp16_exponent << 10) | fp16_mantissa; } - if (exponent == 0) { //subnormal numbers + if (exponent == 0) { // subnormal numbers fp16_exponent = 0; fp16_mantissa = (mantissa << 8); return fp16_sign | fp16_mantissa; } - //normal numbers + // normal numbers int16_t true_exponent = (int16_t)exponent - 15 + 15; if (true_exponent <= 0) { fp16_exponent = 0; @@ -1051,7 +1050,7 @@ bool ModelLoader::init_from_safetensors_file(const std::string& file_path, const } TensorStorage tensor_storage(prefix + name, type, ne, n_dims, file_index, ST_HEADER_SIZE_LEN + header_size_ + begin); - tensor_storage.reverse_ne(); + tensor_storage.reverse_ne(); size_t tensor_data_size = end - begin; @@ -1434,10 +1433,9 @@ bool ModelLoader::init_from_ckpt_file(const std::string& file_path, const std::s std::string name = zip_entry_name(zip); size_t pos = name.find("data.pkl"); if (pos != std::string::npos) { - std::string dir = name.substr(0, pos); printf("ZIP %d, name = %s, dir = %s \n", i, name.c_str(), dir.c_str()); - void* pkl_data = NULL; + void* pkl_data = NULL; size_t pkl_size; zip_entry_read(zip, &pkl_data, &pkl_size); diff --git a/model.h b/model.h index 77841e82c..552a2ccd8 100644 --- a/model.h +++ b/model.h @@ -167,7 +167,7 @@ class ModelLoader { bool load_tensors(std::map& tensors, ggml_backend_t backend, std::set ignore_tensors = {}); - + bool save_to_gguf_file(const std::string& file_path, ggml_type type); bool tensor_should_be_converted(const TensorStorage& tensor_storage, ggml_type type); int64_t get_params_mem_size(ggml_backend_t backend, ggml_type type = GGML_TYPE_COUNT); diff --git a/pmid.hpp b/pmid.hpp index bde03cc92..b8555eb68 100644 --- a/pmid.hpp +++ b/pmid.hpp @@ -6,7 +6,6 @@ #include "clip.hpp" #include "lora.hpp" - struct FuseBlock : public GGMLBlock { // network hparams int in_dim; @@ -74,26 +73,24 @@ class QFormerPerceiver(nn.Module): x = self.token_norm(x) # cls token out = self.perceiver_resampler(x, last_hidden_state) # retrieve from patch tokens if self.use_residual: # TODO: if use_residual is not true - out = x + 1.0 * out + out = x + 1.0 * out return out */ - struct PMFeedForward : public GGMLBlock { // network hparams int dim; public: - PMFeedForward(int d, int multi=4) - : dim(d) { + PMFeedForward(int d, int multi = 4) + : dim(d) { int inner_dim = dim * multi; blocks["0"] = std::shared_ptr(new LayerNorm(dim)); blocks["1"] = std::shared_ptr(new Mlp(dim, inner_dim, dim, false)); } struct ggml_tensor* forward(struct ggml_context* ctx, - struct ggml_tensor* x){ - + struct ggml_tensor* x) { auto norm = std::dynamic_pointer_cast(blocks["0"]); auto ff = std::dynamic_pointer_cast(blocks["1"]); @@ -101,37 +98,35 @@ struct PMFeedForward : public GGMLBlock { x = ff->forward(ctx, x); return x; } - }; struct PerceiverAttention : public GGMLBlock { // network hparams - float scale; // = dim_head**-0.5 - int dim_head; // = dim_head - int heads; // = heads + float scale; // = dim_head**-0.5 + int dim_head; // = dim_head + int heads; // = heads public: - PerceiverAttention(int dim, int dim_h=64, int h=8) - : scale(powf(dim_h, -0.5)), dim_head(dim_h), heads(h) { - - int inner_dim = dim_head * heads; + PerceiverAttention(int dim, int dim_h = 64, int h = 8) + : scale(powf(dim_h, -0.5)), dim_head(dim_h), heads(h) { + int inner_dim = dim_head * heads; blocks["norm1"] = std::shared_ptr(new LayerNorm(dim)); blocks["norm2"] = std::shared_ptr(new LayerNorm(dim)); blocks["to_q"] = std::shared_ptr(new Linear(dim, inner_dim, false)); - blocks["to_kv"] = std::shared_ptr(new 
Linear(dim, inner_dim*2, false)); + blocks["to_kv"] = std::shared_ptr(new Linear(dim, inner_dim * 2, false)); blocks["to_out"] = std::shared_ptr(new Linear(inner_dim, dim, false)); } struct ggml_tensor* reshape_tensor(struct ggml_context* ctx, - struct ggml_tensor* x, - int heads) { + struct ggml_tensor* x, + int heads) { int64_t ne[4]; - for(int i = 0; i < 4; ++i) - ne[i] = x->ne[i]; + for (int i = 0; i < 4; ++i) + ne[i] = x->ne[i]; // print_ggml_tensor(x, true, "PerceiverAttention reshape x 0: "); // printf("heads = %d \n", heads); // x = ggml_view_4d(ctx, x, x->ne[0], x->ne[1], heads, x->ne[2]/heads, // x->nb[1], x->nb[2], x->nb[3], 0); - x = ggml_reshape_4d(ctx, x, x->ne[0]/heads, heads, x->ne[1], x->ne[2]); + x = ggml_reshape_4d(ctx, x, x->ne[0] / heads, heads, x->ne[1], x->ne[2]); // x = ggml_view_4d(ctx, x, x->ne[0]/heads, heads, x->ne[1], x->ne[2], // x->nb[1], x->nb[2], x->nb[3], 0); // x = ggml_cont(ctx, x); @@ -142,49 +137,46 @@ struct PerceiverAttention : public GGMLBlock { } std::vector chunk_half(struct ggml_context* ctx, - struct ggml_tensor* x){ - - auto tlo = ggml_view_4d(ctx, x, x->ne[0]/2, x->ne[1], x->ne[2], x->ne[3], x->nb[1], x->nb[2], x->nb[3], 0); - auto tli = ggml_view_4d(ctx, x, x->ne[0]/2, x->ne[1], x->ne[2], x->ne[3], x->nb[1], x->nb[2], x->nb[3], x->nb[0]*x->ne[0]/2); + struct ggml_tensor* x) { + auto tlo = ggml_view_4d(ctx, x, x->ne[0] / 2, x->ne[1], x->ne[2], x->ne[3], x->nb[1], x->nb[2], x->nb[3], 0); + auto tli = ggml_view_4d(ctx, x, x->ne[0] / 2, x->ne[1], x->ne[2], x->ne[3], x->nb[1], x->nb[2], x->nb[3], x->nb[0] * x->ne[0] / 2); return {ggml_cont(ctx, tlo), ggml_cont(ctx, tli)}; - } struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x, - struct ggml_tensor* latents){ - + struct ggml_tensor* latents) { // x (torch.Tensor): image features // shape (b, n1, D) // latent (torch.Tensor): latent features // shape (b, n2, D) int64_t ne[4]; - for(int i = 0; i < 4; ++i) - ne[i] = latents->ne[i]; + for (int i = 0; i < 4; ++i) + ne[i] = latents->ne[i]; - auto norm1 = std::dynamic_pointer_cast(blocks["norm1"]); - auto norm2 = std::dynamic_pointer_cast(blocks["norm2"]); - x = norm1->forward(ctx, x); - latents = norm2->forward(ctx, latents); - auto to_q = std::dynamic_pointer_cast(blocks["to_q"]); - auto q = to_q->forward(ctx, latents); + auto norm1 = std::dynamic_pointer_cast(blocks["norm1"]); + auto norm2 = std::dynamic_pointer_cast(blocks["norm2"]); + x = norm1->forward(ctx, x); + latents = norm2->forward(ctx, latents); + auto to_q = std::dynamic_pointer_cast(blocks["to_q"]); + auto q = to_q->forward(ctx, latents); auto kv_input = ggml_concat(ctx, x, latents, 1); auto to_kv = std::dynamic_pointer_cast(blocks["to_kv"]); - auto kv = to_kv->forward(ctx, kv_input); - auto k = ggml_view_4d(ctx, kv, kv->ne[0]/2, kv->ne[1], kv->ne[2], kv->ne[3], kv->nb[1]/2, kv->nb[2]/2, kv->nb[3]/2, 0); - auto v = ggml_view_4d(ctx, kv, kv->ne[0]/2, kv->ne[1], kv->ne[2], kv->ne[3], kv->nb[1]/2, kv->nb[2]/2, kv->nb[3]/2, kv->nb[0]*(kv->ne[0]/2)); - k = ggml_cont(ctx, k); - v = ggml_cont(ctx, v); - q = reshape_tensor(ctx, q, heads); - k = reshape_tensor(ctx, k, heads); - v = reshape_tensor(ctx, v, heads); - scale = 1.f / sqrt(sqrt((float)dim_head)); - k = ggml_scale_inplace(ctx, k, scale); - q = ggml_scale_inplace(ctx, q, scale); + auto kv = to_kv->forward(ctx, kv_input); + auto k = ggml_view_4d(ctx, kv, kv->ne[0] / 2, kv->ne[1], kv->ne[2], kv->ne[3], kv->nb[1] / 2, kv->nb[2] / 2, kv->nb[3] / 2, 0); + auto v = ggml_view_4d(ctx, kv, kv->ne[0] / 2, kv->ne[1], kv->ne[2], 
kv->ne[3], kv->nb[1] / 2, kv->nb[2] / 2, kv->nb[3] / 2, kv->nb[0] * (kv->ne[0] / 2)); + k = ggml_cont(ctx, k); + v = ggml_cont(ctx, v); + q = reshape_tensor(ctx, q, heads); + k = reshape_tensor(ctx, k, heads); + v = reshape_tensor(ctx, v, heads); + scale = 1.f / sqrt(sqrt((float)dim_head)); + k = ggml_scale_inplace(ctx, k, scale); + q = ggml_scale_inplace(ctx, q, scale); // auto weight = ggml_mul_mat(ctx, q, k); - auto weight = ggml_mul_mat(ctx, k, q); // NOTE order of mul is opposite to pytorch + auto weight = ggml_mul_mat(ctx, k, q); // NOTE order of mul is opposite to pytorch // GGML's softmax() is equivalent to pytorch's softmax(x, dim=-1) // in this case, dimension along which Softmax will be computed is the last dim @@ -192,13 +184,13 @@ struct PerceiverAttention : public GGMLBlock { // last dimension (varying most rapidly) corresponds to GGML's first (varying most rapidly). // weight = ggml_soft_max(ctx, weight); weight = ggml_soft_max_inplace(ctx, weight); - v = ggml_cont(ctx, ggml_transpose(ctx, v)); + v = ggml_cont(ctx, ggml_transpose(ctx, v)); // auto out = ggml_mul_mat(ctx, weight, v); - auto out = ggml_mul_mat(ctx, v, weight); // NOTE order of mul is opposite to pytorch - out = ggml_cont(ctx, ggml_permute(ctx, out, 0, 2, 1, 3)); - out = ggml_reshape_3d(ctx, out, ne[0], ne[1], ggml_nelements(out)/(ne[0]*ne[1])); - auto to_out = std::dynamic_pointer_cast(blocks["to_out"]); - out = to_out->forward(ctx, out); + auto out = ggml_mul_mat(ctx, v, weight); // NOTE order of mul is opposite to pytorch + out = ggml_cont(ctx, ggml_permute(ctx, out, 0, 2, 1, 3)); + out = ggml_reshape_3d(ctx, out, ne[0], ne[1], ggml_nelements(out) / (ne[0] * ne[1])); + auto to_out = std::dynamic_pointer_cast(blocks["to_out"]); + out = to_out->forward(ctx, out); return out; } }; @@ -206,45 +198,46 @@ struct PerceiverAttention : public GGMLBlock { struct FacePerceiverResampler : public GGMLBlock { // network hparams int depth; + public: - FacePerceiverResampler( int dim=768, - int d=4, - int dim_head=64, - int heads=16, - int embedding_dim=1280, - int output_dim=768, - int ff_mult=4) - : depth(d) { - blocks["proj_in"] = std::shared_ptr(new Linear(embedding_dim, dim, true)); + FacePerceiverResampler(int dim = 768, + int d = 4, + int dim_head = 64, + int heads = 16, + int embedding_dim = 1280, + int output_dim = 768, + int ff_mult = 4) + : depth(d) { + blocks["proj_in"] = std::shared_ptr(new Linear(embedding_dim, dim, true)); blocks["proj_out"] = std::shared_ptr(new Linear(dim, output_dim, true)); blocks["norm_out"] = std::shared_ptr(new LayerNorm(output_dim)); for (int i = 0; i < depth; i++) { std::string name = "layers." + std::to_string(i) + ".0"; blocks[name] = std::shared_ptr(new PerceiverAttention(dim, dim_head, heads)); - name = "layers." + std::to_string(i) + ".1"; + name = "layers." 
+ std::to_string(i) + ".1"; blocks[name] = std::shared_ptr(new PMFeedForward(dim, ff_mult)); } } struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* latents, - struct ggml_tensor* x){ + struct ggml_tensor* x) { // x: [N, channels, h, w] - auto proj_in = std::dynamic_pointer_cast(blocks["proj_in"]); - auto proj_out = std::dynamic_pointer_cast(blocks["proj_out"]); - auto norm_out = std::dynamic_pointer_cast(blocks["norm_out"]); + auto proj_in = std::dynamic_pointer_cast(blocks["proj_in"]); + auto proj_out = std::dynamic_pointer_cast(blocks["proj_out"]); + auto norm_out = std::dynamic_pointer_cast(blocks["norm_out"]); x = proj_in->forward(ctx, x); for (int i = 0; i < depth; i++) { std::string name = "layers." + std::to_string(i) + ".0"; - auto attn = std::dynamic_pointer_cast(blocks[name]); - name = "layers." + std::to_string(i) + ".1"; - auto ff = std::dynamic_pointer_cast(blocks[name]); - auto t = attn->forward(ctx, x, latents); - latents = ggml_add(ctx, t, latents); - t = ff->forward(ctx, latents); - latents = ggml_add(ctx, t, latents); + auto attn = std::dynamic_pointer_cast(blocks[name]); + name = "layers." + std::to_string(i) + ".1"; + auto ff = std::dynamic_pointer_cast(blocks[name]); + auto t = attn->forward(ctx, x, latents); + latents = ggml_add(ctx, t, latents); + t = ff->forward(ctx, latents); + latents = ggml_add(ctx, t, latents); } latents = proj_out->forward(ctx, latents); latents = norm_out->forward(ctx, latents); @@ -258,51 +251,49 @@ struct QFormerPerceiver : public GGMLBlock { int cross_attention_dim; bool use_residul; - public: - QFormerPerceiver(int id_embeddings_dim, int cross_attention_d, int num_t, int embedding_dim=1024, - bool use_r=true, int ratio=4) - : cross_attention_dim(cross_attention_d), num_tokens(num_t), use_residul(use_r) { - blocks["token_proj"] = std::shared_ptr(new Mlp(id_embeddings_dim, - id_embeddings_dim*ratio, - cross_attention_dim*num_tokens, - true)); - blocks["token_norm"] = std::shared_ptr(new LayerNorm(cross_attention_d)); + QFormerPerceiver(int id_embeddings_dim, int cross_attention_d, int num_t, int embedding_dim = 1024, bool use_r = true, int ratio = 4) + : cross_attention_dim(cross_attention_d), num_tokens(num_t), use_residul(use_r) { + blocks["token_proj"] = std::shared_ptr(new Mlp(id_embeddings_dim, + id_embeddings_dim * ratio, + cross_attention_dim * num_tokens, + true)); + blocks["token_norm"] = std::shared_ptr(new LayerNorm(cross_attention_d)); blocks["perceiver_resampler"] = std::shared_ptr(new FacePerceiverResampler( - cross_attention_dim, - 4, - 128, - cross_attention_dim / 128, - embedding_dim, - cross_attention_dim, - 4)); + cross_attention_dim, + 4, + 128, + cross_attention_dim / 128, + embedding_dim, + cross_attention_dim, + 4)); } - /* + /* def forward(self, x, last_hidden_state): x = self.token_proj(x) x = x.reshape(-1, self.num_tokens, self.cross_attention_dim) x = self.token_norm(x) # cls token out = self.perceiver_resampler(x, last_hidden_state) # retrieve from patch tokens if self.use_residual: # TODO: if use_residual is not true - out = x + 1.0 * out + out = x + 1.0 * out return out */ struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x, - struct ggml_tensor* last_hidden_state){ + struct ggml_tensor* last_hidden_state) { // x: [N, channels, h, w] - auto token_proj = std::dynamic_pointer_cast(blocks["token_proj"]); - auto token_norm = std::dynamic_pointer_cast(blocks["token_norm"]); + auto token_proj = std::dynamic_pointer_cast(blocks["token_proj"]); + auto token_norm = 
std::dynamic_pointer_cast(blocks["token_norm"]); auto perceiver_resampler = std::dynamic_pointer_cast(blocks["perceiver_resampler"]); - x = token_proj->forward(ctx, x); - int64_t nel = ggml_nelements(x); - x = ggml_reshape_3d(ctx, x, cross_attention_dim, num_tokens, nel/(cross_attention_dim*num_tokens)); - x = token_norm->forward(ctx, x); + x = token_proj->forward(ctx, x); + int64_t nel = ggml_nelements(x); + x = ggml_reshape_3d(ctx, x, cross_attention_dim, num_tokens, nel / (cross_attention_dim * num_tokens)); + x = token_norm->forward(ctx, x); struct ggml_tensor* out = perceiver_resampler->forward(ctx, x, last_hidden_state); - if(use_residul) + if (use_residul) out = ggml_add(ctx, x, out); return out; } @@ -322,7 +313,7 @@ class FacePerceiverResampler(torch.nn.Module): ff_mult=4, ): super().__init__() - + self.proj_in = torch.nn.Linear(embedding_dim, dim) self.proj_out = torch.nn.Linear(dim, output_dim) self.norm_out = torch.nn.LayerNorm(output_dim) @@ -346,8 +337,6 @@ class FacePerceiverResampler(torch.nn.Module): return self.norm_out(latents) */ - - /* def FeedForward(dim, mult=4): @@ -417,9 +406,6 @@ class PerceiverAttention(nn.Module): */ - - - struct FuseModule : public GGMLBlock { // network hparams int embed_dim; @@ -485,9 +471,9 @@ struct FuseModule : public GGMLBlock { // print_ggml_tensor(class_tokens_mask_pos, true, "class_tokens_mask_pos"); struct ggml_tensor* image_token_embeds = ggml_get_rows(ctx, prompt_embeds, class_tokens_mask_pos); ggml_set_name(image_token_embeds, "image_token_embeds"); - valid_id_embeds = ggml_reshape_2d(ctx, valid_id_embeds, valid_id_embeds->ne[0], - ggml_nelements(valid_id_embeds)/valid_id_embeds->ne[0]); - struct ggml_tensor* stacked_id_embeds = fuse_fn(ctx, image_token_embeds, valid_id_embeds); + valid_id_embeds = ggml_reshape_2d(ctx, valid_id_embeds, valid_id_embeds->ne[0], + ggml_nelements(valid_id_embeds) / valid_id_embeds->ne[0]); + struct ggml_tensor* stacked_id_embeds = fuse_fn(ctx, image_token_embeds, valid_id_embeds); // stacked_id_embeds = ggml_cont(ctx, ggml_permute(ctx, stacked_id_embeds, 0, 2, 1, 3)); // print_ggml_tensor(stacked_id_embeds, true, "AA stacked_id_embeds"); @@ -555,14 +541,13 @@ struct PhotoMakerIDEncoderBlock : public CLIPVisionModelProjection { }; struct PhotoMakerIDEncoder_CLIPInsightfaceExtendtokenBlock : public CLIPVisionModelProjection { - int cross_attention_dim; int num_tokens; - PhotoMakerIDEncoder_CLIPInsightfaceExtendtokenBlock(int id_embeddings_dim=512) + PhotoMakerIDEncoder_CLIPInsightfaceExtendtokenBlock(int id_embeddings_dim = 512) : CLIPVisionModelProjection(OPENAI_CLIP_VIT_L_14), - cross_attention_dim (2048), - num_tokens(2) { + cross_attention_dim(2048), + num_tokens(2) { blocks["visual_projection_2"] = std::shared_ptr(new Linear(1024, 1280, false)); blocks["fuse_module"] = std::shared_ptr(new FuseModule(2048)); /* @@ -571,14 +556,13 @@ struct PhotoMakerIDEncoder_CLIPInsightfaceExtendtokenBlock : public CLIPVisionMo self.num_tokens = 2 self.cross_attention_dim = cross_attention_dim self.qformer_perceiver = QFormerPerceiver( - id_embeddings_dim, - cross_attention_dim, + id_embeddings_dim, + cross_attention_dim, self.num_tokens, )*/ - blocks["qformer_perceiver"] = std::shared_ptr(new QFormerPerceiver(id_embeddings_dim, - cross_attention_dim, - num_tokens)); - + blocks["qformer_perceiver"] = std::shared_ptr(new QFormerPerceiver(id_embeddings_dim, + cross_attention_dim, + num_tokens)); } /* @@ -603,14 +587,14 @@ struct PhotoMakerIDEncoder_CLIPInsightfaceExtendtokenBlock : public CLIPVisionMo struct 
ggml_tensor* left, struct ggml_tensor* right) { // x: [N, channels, h, w] - auto vision_model = std::dynamic_pointer_cast(blocks["vision_model"]); - auto fuse_module = std::dynamic_pointer_cast(blocks["fuse_module"]); - auto qformer_perceiver = std::dynamic_pointer_cast(blocks["qformer_perceiver"]); + auto vision_model = std::dynamic_pointer_cast(blocks["vision_model"]); + auto fuse_module = std::dynamic_pointer_cast(blocks["fuse_module"]); + auto qformer_perceiver = std::dynamic_pointer_cast(blocks["qformer_perceiver"]); // struct ggml_tensor* last_hidden_state = vision_model->forward(ctx, id_pixel_values); // [N, hidden_size] - struct ggml_tensor* last_hidden_state = vision_model->forward(ctx, id_pixel_values, false); // [N, hidden_size] - id_embeds = qformer_perceiver->forward(ctx, id_embeds, last_hidden_state); - + struct ggml_tensor* last_hidden_state = vision_model->forward(ctx, id_pixel_values, false); // [N, hidden_size] + id_embeds = qformer_perceiver->forward(ctx, id_embeds, last_hidden_state); + struct ggml_tensor* updated_prompt_embeds = fuse_module->forward(ctx, prompt_embeds, id_embeds, @@ -623,7 +607,7 @@ struct PhotoMakerIDEncoder_CLIPInsightfaceExtendtokenBlock : public CLIPVisionMo struct PhotoMakerIDEncoder : public GGMLRunner { public: - SDVersion version = VERSION_SDXL; + SDVersion version = VERSION_SDXL; PMVersion pm_version = VERSION_1; PhotoMakerIDEncoderBlock id_encoder; PhotoMakerIDEncoder_CLIPInsightfaceExtendtokenBlock id_encoder2; @@ -639,15 +623,14 @@ struct PhotoMakerIDEncoder : public GGMLRunner { std::vector zeros_right; public: - PhotoMakerIDEncoder(ggml_backend_t backend, ggml_type wtype, SDVersion version = VERSION_SDXL, - PMVersion pm_v = VERSION_1, float sty = 20.f) + PhotoMakerIDEncoder(ggml_backend_t backend, ggml_type wtype, SDVersion version = VERSION_SDXL, PMVersion pm_v = VERSION_1, float sty = 20.f) : GGMLRunner(backend, wtype), version(version), pm_version(pm_v), style_strength(sty) { - if(pm_version == VERSION_1){ + if (pm_version == VERSION_1) { id_encoder.init(params_ctx, wtype); - }else if(pm_version == VERSION_2){ + } else if (pm_version == VERSION_2) { id_encoder2.init(params_ctx, wtype); } } @@ -656,17 +639,15 @@ struct PhotoMakerIDEncoder : public GGMLRunner { return "pmid"; } - PMVersion get_version() const{ + PMVersion get_version() const { return pm_version; } - void get_param_tensors(std::map& tensors, const std::string prefix) { - if(pm_version == VERSION_1) + if (pm_version == VERSION_1) id_encoder.get_param_tensors(tensors, prefix); - else if(pm_version == VERSION_2) + else if (pm_version == VERSION_2) id_encoder2.get_param_tensors(tensors, prefix); - } struct ggml_cgraph* build_graph( // struct ggml_allocr* allocr, @@ -753,14 +734,14 @@ struct PhotoMakerIDEncoder : public GGMLRunner { } } struct ggml_tensor* updated_prompt_embeds = NULL; - if(pm_version == VERSION_1) + if (pm_version == VERSION_1) updated_prompt_embeds = id_encoder.forward(ctx0, - id_pixel_values_d, - prompt_embeds_d, - class_tokens_mask_d, - class_tokens_mask_pos, - left, right); - else if(pm_version == VERSION_2) + id_pixel_values_d, + prompt_embeds_d, + class_tokens_mask_d, + class_tokens_mask_pos, + left, right); + else if (pm_version == VERSION_2) updated_prompt_embeds = id_encoder2.forward(ctx0, id_pixel_values_d, prompt_embeds_d, @@ -791,22 +772,19 @@ struct PhotoMakerIDEncoder : public GGMLRunner { } }; - struct PhotoMakerIDEmbed : public GGMLRunner { - std::map tensors; std::string file_path; - ModelLoader *model_loader; + ModelLoader* model_loader; 
bool load_failed = false; bool applied = false; PhotoMakerIDEmbed(ggml_backend_t backend, - ggml_type wtype, - ModelLoader *ml, - const std::string& file_path = "", - const std::string& prefix = "") - : file_path(file_path), GGMLRunner(backend, wtype), - model_loader(ml) { + ggml_type wtype, + ModelLoader* ml, + const std::string& file_path = "", + const std::string& prefix = "") + : file_path(file_path), GGMLRunner(backend, wtype), model_loader(ml) { if (!model_loader->init_from_file(file_path, prefix)) { load_failed = true; } @@ -831,13 +809,13 @@ struct PhotoMakerIDEmbed : public GGMLRunner { if (filter_tensor && !contains(name, "pmid.id_embeds")) { // LOG_INFO("skipping LoRA tesnor '%s'", name.c_str()); return true; - } + } if (dry_run) { struct ggml_tensor* real = ggml_new_tensor(params_ctx, tensor_storage.type, tensor_storage.n_dims, tensor_storage.ne); - tensors[name] = real; + tensors[name] = real; } else { auto real = tensors[name]; *dst_tensor = real; @@ -856,13 +834,12 @@ struct PhotoMakerIDEmbed : public GGMLRunner { return true; } - - struct ggml_tensor* get(){ + struct ggml_tensor* get() { std::map::iterator pos; pos = tensors.find("pmid.id_embeds"); - if(pos != tensors.end()) + if (pos != tensors.end()) return pos->second; - return NULL; + return NULL; } }; diff --git a/stable-diffusion.cpp b/stable-diffusion.cpp index 1a46ee824..c722b6539 100644 --- a/stable-diffusion.cpp +++ b/stable-diffusion.cpp @@ -333,12 +333,12 @@ class StableDiffusionGGML { cond_stage_model = std::make_shared(clip_backend, conditioner_wtype); diffusion_model = std::make_shared(backend, diffusion_model_wtype, version, diffusion_flash_attn); } else { - if(id_embeddings_path.find("v2") != std::string::npos) { + if (id_embeddings_path.find("v2") != std::string::npos) { cond_stage_model = std::make_shared(clip_backend, conditioner_wtype, embeddings_path, version, VERSION_2); - }else{ + } else { cond_stage_model = std::make_shared(clip_backend, conditioner_wtype, embeddings_path, version); - } - diffusion_model = std::make_shared(backend, diffusion_model_wtype, version, diffusion_flash_attn); + } + diffusion_model = std::make_shared(backend, diffusion_model_wtype, version, diffusion_flash_attn); } cond_stage_model->alloc_params_buffer(); cond_stage_model->get_param_tensors(tensors); @@ -372,7 +372,7 @@ class StableDiffusionGGML { control_net = std::make_shared(controlnet_backend, diffusion_model_wtype, version); } - if(id_embeddings_path.find("v2") != std::string::npos) { + if (id_embeddings_path.find("v2") != std::string::npos) { pmid_model = std::make_shared(backend, model_wtype, version, VERSION_2); LOG_INFO("using PhotoMaker Version 2"); } else { @@ -1220,9 +1220,9 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx, for (std::string img_file : img_files) { int c = 0; int width, height; - if(ends_with(img_file, "safetensors")){ + if (ends_with(img_file, "safetensors")) { continue; - } + } uint8_t* input_image_buffer = stbi_load(img_file.c_str(), &width, &height, &c, 3); if (input_image_buffer == NULL) { LOG_ERROR("PhotoMaker load image from '%s' failed", img_file.c_str()); @@ -1260,18 +1260,18 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx, else sd_mul_images_to_tensor(init_image->data, init_img, i, NULL, NULL); } - t0 = ggml_time_ms(); - auto cond_tup = sd_ctx->sd->cond_stage_model->get_learned_condition_with_trigger(work_ctx, - sd_ctx->sd->n_threads, prompt, - clip_skip, - width, - height, - num_input_images, - sd_ctx->sd->diffusion_model->get_adm_in_channels()); - id_cond = std::get<0>(cond_tup); - 
class_tokens_mask = std::get<1>(cond_tup); // + t0 = ggml_time_ms(); + auto cond_tup = sd_ctx->sd->cond_stage_model->get_learned_condition_with_trigger(work_ctx, + sd_ctx->sd->n_threads, prompt, + clip_skip, + width, + height, + num_input_images, + sd_ctx->sd->diffusion_model->get_adm_in_channels()); + id_cond = std::get<0>(cond_tup); + class_tokens_mask = std::get<1>(cond_tup); // struct ggml_tensor* id_embeds = NULL; - if(pmv2){ + if (pmv2) { // id_embeds = sd_ctx->sd->pmid_id_embeds->get(); id_embeds = load_tensor_from_file(work_ctx, path_join(input_id_images_path, "id_embeds.bin")); // print_ggml_tensor(id_embeds, true, "id_embeds:"); diff --git a/unet.hpp b/unet.hpp index 611891298..79f702c4d 100644 --- a/unet.hpp +++ b/unet.hpp @@ -534,7 +534,7 @@ struct UNetModelRunner : public GGMLRunner { UNetModelRunner(ggml_backend_t backend, ggml_type wtype, SDVersion version = VERSION_SD1, - bool flash_attn = false) + bool flash_attn = false) : GGMLRunner(backend, wtype), unet(version, flash_attn) { unet.init(params_ctx, wtype); } diff --git a/util.cpp b/util.cpp index cd058bb07..b8a65e7d9 100644 --- a/util.cpp +++ b/util.cpp @@ -279,12 +279,12 @@ std::string path_join(const std::string& p1, const std::string& p2) { std::vector splitString(const std::string& str, char delimiter) { std::vector result; size_t start = 0; - size_t end = str.find(delimiter); + size_t end = str.find(delimiter); while (end != std::string::npos) { result.push_back(str.substr(start, end - start)); start = end + 1; - end = str.find(delimiter, start); + end = str.find(delimiter, start); } // Add the last segment after the last delimiter @@ -293,7 +293,6 @@ std::vector splitString(const std::string& str, char delimiter) { return result; } - sd_image_t* preprocess_id_image(sd_image_t* img) { int shortest_edge = 224; int size = shortest_edge; diff --git a/vae.hpp b/vae.hpp index e3d6d3231..c32846a30 100644 --- a/vae.hpp +++ b/vae.hpp @@ -103,7 +103,7 @@ class AttnBlock : public UnaryBlock { v = ggml_cont(ctx, ggml_permute(ctx, v, 1, 2, 0, 3)); // [N, h, w, in_channels] v = ggml_reshape_3d(ctx, v, c, h * w, n); // [N, h * w, in_channels] - //h_ = ggml_nn_attention(ctx, q, k, v, false); // [N, h * w, in_channels] + // h_ = ggml_nn_attention(ctx, q, k, v, false); // [N, h * w, in_channels] h_ = ggml_nn_attention_ext(ctx, q, k, v, 1, nullptr, false, true, false); h_ = ggml_cont(ctx, ggml_permute(ctx, h_, 1, 0, 2, 3)); // [N, in_channels, h * w]
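
A note on the argument ordering remarked on in the PerceiverAttention hunk above ("order of mul is opposite to pytorch"): ggml_mul_mat(ctx, a, b) contracts over ne[0] of both operands and yields a result with ne[0] = a->ne[1] and ne[1] = b->ne[1], so ggml_mul_mat(ctx, k, q) plays the role of PyTorch's q @ k.transpose(-1, -2), and the following soft_max, which runs along GGML's first (fastest-varying) dimension, matches PyTorch's softmax(dim=-1). The sketch below is a minimal standalone illustration of that convention, assuming only the public ggml.h API; the tensor sizes and context size are illustrative and not taken from the model.

#include "ggml.h"
#include <cstdio>

int main() {
    struct ggml_init_params params = {/*.mem_size   =*/ 16 * 1024 * 1024,
                                      /*.mem_buffer =*/ nullptr,
                                      /*.no_alloc   =*/ false};
    struct ggml_context* ctx = ggml_init(params);

    // Illustrative shapes only: d_head = 64, L_k = 10 key tokens, L_q = 7 query tokens.
    // In ggml, ne[0] is the fastest-varying dimension, so both tensors are [d_head, n_tokens].
    struct ggml_tensor* k = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 64, 10);
    struct ggml_tensor* q = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 64, 7);

    // ggml_mul_mat contracts over ne[0], so this is the analogue of q @ k^T in pytorch:
    // weight has ne[0] = L_k and ne[1] = L_q, i.e. one row of key scores per query token.
    struct ggml_tensor* weight = ggml_mul_mat(ctx, k, q);

    // ggml's softmax runs along ne[0], i.e. over the key scores, matching softmax(dim=-1).
    struct ggml_tensor* attn = ggml_soft_max(ctx, weight);

    printf("weight: %lld x %lld\n", (long long)weight->ne[0], (long long)weight->ne[1]);
    printf("attn:   %lld x %lld\n", (long long)attn->ne[0], (long long)attn->ne[1]);

    ggml_free(ctx);
    return 0;
}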