
Commit e904b86

remove old flash attention option and switch vae over to attn_ext
1 parent 408cb05 commit e904b86

4 files changed: +10 -43 lines


CMakeLists.txt

Lines changed: 0 additions & 6 deletions
@@ -29,7 +29,6 @@ option(SD_HIPBLAS "sd: rocm backend" OFF)
 option(SD_METAL "sd: metal backend" OFF)
 option(SD_VULKAN "sd: vulkan backend" OFF)
 option(SD_SYCL "sd: sycl backend" OFF)
-option(SD_FLASH_ATTN "sd: use flash attention for x4 less memory usage" OFF)
 option(SD_FAST_SOFTMAX "sd: x1.5 faster softmax, indeterministic (sometimes, same seed don't generate same image), cuda only" OFF)
 option(SD_BUILD_SHARED_LIBS "sd: build shared libs" OFF)
 #option(SD_BUILD_SERVER "sd: build server example" ON)
@@ -61,11 +60,6 @@ if (SD_HIPBLAS)
     endif()
 endif ()
 
-if(SD_FLASH_ATTN)
-    message("-- Use Flash Attention for memory optimization")
-    add_definitions(-DSD_USE_FLASH_ATTENTION)
-endif()
-
 set(SD_LIB stable-diffusion)
 
 file(GLOB SD_LIB_SOURCES
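
For context, the removed build option would previously have been enabled at configure time in the usual CMake way, for example with cmake -B build -DSD_FLASH_ATTN=ON (an illustrative command, not taken from this repo's docs). After this commit the flag no longer exists; whether the flash-attention kernel is used is instead decided by the runtime eligibility checks inside ggml_nn_attention_ext (see ggml_extend.hpp below).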

ggml_extend.hpp

Lines changed: 3 additions & 26 deletions
@@ -667,32 +667,6 @@ __STATIC_INLINE__ std::vector<struct ggml_tensor*> split_qkv(struct ggml_context
     return {q, k, v};
 }
 
-// q: [N * n_head, n_token, d_head]
-// k: [N * n_head, n_k, d_head]
-// v: [N * n_head, d_head, n_k]
-// return: [N * n_head, n_token, d_head]
-__STATIC_INLINE__ struct ggml_tensor* ggml_nn_attention(struct ggml_context* ctx,
-                                                        struct ggml_tensor* q,
-                                                        struct ggml_tensor* k,
-                                                        struct ggml_tensor* v,
-                                                        bool mask = false) {
-#if defined(SD_USE_FLASH_ATTENTION) && !defined(SD_USE_CUBLAS) && !defined(SD_USE_METAL) && !defined(SD_USE_VULKAN) && !defined(SD_USE_SYCL)
-    struct ggml_tensor* kqv = ggml_flash_attn(ctx, q, k, v, false);  // [N * n_head, n_token, d_head]
-#else
-    float d_head = (float)q->ne[0];
-
-    struct ggml_tensor* kq = ggml_mul_mat(ctx, k, q);  // [N * n_head, n_token, n_k]
-    kq = ggml_scale_inplace(ctx, kq, 1.0f / sqrt(d_head));
-    if (mask) {
-        kq = ggml_diag_mask_inf_inplace(ctx, kq, 0);
-    }
-    kq = ggml_soft_max_inplace(ctx, kq);
-
-    struct ggml_tensor* kqv = ggml_mul_mat(ctx, v, kq);  // [N * n_head, n_token, d_head]
-#endif
-    return kqv;
-}
-
 // q: [N, L_q, C] or [N*n_head, L_q, d_head]
 // k: [N, L_k, C] or [N*n_head, L_k, d_head]
 // v: [N, L_k, C] or [N, L_k, n_head, d_head]
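
For reference, the deleted ggml_nn_attention helper computed plain scaled dot-product attention (the ggml_mul_mat argument order reflects ggml's transposed-operand convention), and per the commit message the same computation is now obtained through ggml_nn_attention_ext instead. In standard notation, the removed fallback path evaluated

    \mathrm{Attention}(Q, K, V) = \mathrm{softmax}\!\left(\frac{Q K^{\top}}{\sqrt{d_{\mathrm{head}}}}\right) V

with q of shape [N * n_head, n_token, d_head], k of shape [N * n_head, n_k, d_head], and v pre-transposed to [N * n_head, d_head, n_k], matching the shape comments above.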
@@ -747,6 +721,9 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_nn_attention_ext(struct ggml_context*
     can_use_flash_attn = can_use_flash_attn && L_k % 256 == 0;
     can_use_flash_attn = can_use_flash_attn && d_head % 64 == 0;  // double check
 
+    // cuda max d_head seems to be 256, cpu does seem to work with 512
+    can_use_flash_attn = can_use_flash_attn && d_head <= 256;  // double check
+
     if (mask != nullptr) {
         // TODO: figure out if we can bend t5 to work too
         can_use_flash_attn = can_use_flash_attn && mask->ne[2] == 1;
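
Taken together with the surrounding context lines, the conditions gating the flash-attention path can be restated as one small predicate. This is a sketch for illustration only: the helper name and parameter list are hypothetical, and the real code inlines these checks inside ggml_nn_attention_ext.

    #include <cstdint>
    #include "ggml.h"  // for struct ggml_tensor (assumed include)

    // Hypothetical restatement of the checks from the hunk above; not part of the repo.
    static bool flash_attn_eligible(int64_t L_k, int64_t d_head, const struct ggml_tensor* mask, bool requested) {
        bool ok = requested;              // caller/backend opt-in, set earlier in the real function
        ok = ok && L_k % 256 == 0;        // key length must be a multiple of 256
        ok = ok && d_head % 64 == 0;      // head size must be a multiple of 64
        // cuda max d_head seems to be 256, cpu does seem to work with 512
        ok = ok && d_head <= 256;         // upper bound added by this commit
        if (mask != nullptr) {
            // original TODO: figure out if we can bend t5 to work too
            ok = ok && mask->ne[2] == 1;  // the mask must not be batched along ne[2]
        }
        return ok;
    }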

stable-diffusion.cpp

Lines changed: 1 addition & 7 deletions
@@ -181,13 +181,7 @@ class StableDiffusionGGML {
             LOG_DEBUG("Using CPU backend");
             backend = ggml_backend_cpu_init();
         }
-#ifdef SD_USE_FLASH_ATTENTION
-#if defined(SD_USE_CUBLAS) || defined(SD_USE_METAL) || defined (SD_USE_SYCL) || defined(SD_USE_VULKAN)
-        LOG_WARN("Flash Attention not supported with GPU Backend");
-#else
-        LOG_INFO("Flash Attention enabled");
-#endif
-#endif
+
         ModelLoader model_loader;
 
         vae_tiling = vae_tiling_;

vae.hpp

Lines changed: 6 additions & 4 deletions
@@ -99,10 +99,12 @@ class AttnBlock : public UnaryBlock {
         k = ggml_cont(ctx, ggml_permute(ctx, k, 1, 2, 0, 3));  // [N, h, w, in_channels]
         k = ggml_reshape_3d(ctx, k, c, h * w, n);  // [N, h * w, in_channels]
 
-        auto v = v_proj->forward(ctx, h_);  // [N, in_channels, h, w]
-        v = ggml_reshape_3d(ctx, v, h * w, c, n);  // [N, in_channels, h * w]
+        auto v = v_proj->forward(ctx, h_);  // [N, in_channels, h, w]
+        v = ggml_cont(ctx, ggml_permute(ctx, v, 1, 2, 0, 3));  // [N, h, w, in_channels]
+        v = ggml_reshape_3d(ctx, v, c, h * w, n);  // [N, h * w, in_channels]
 
-        h_ = ggml_nn_attention(ctx, q, k, v, false);  // [N, h * w, in_channels]
+        //h_ = ggml_nn_attention(ctx, q, k, v, false);  // [N, h * w, in_channels]
+        h_ = ggml_nn_attention_ext(ctx, q, k, v, 1, nullptr, false, true, false);
 
         h_ = ggml_cont(ctx, ggml_permute(ctx, h_, 1, 0, 2, 3));  // [N, in_channels, h * w]
         h_ = ggml_reshape_4d(ctx, h_, w, h, c, n);  // [N, in_channels, h, w]
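
Two things change in this block: v now goes through the same permute/reshape as q and k, ending up in the [N, h * w, in_channels] layout rather than the transposed [N, in_channels, h * w] layout the removed ggml_nn_attention expected for its value tensor, and the attention call is routed through ggml_nn_attention_ext. A rough annotation of the new call follows; the argument meanings are inferred from this diff and the declaration comments above, so treat them as assumptions rather than documented behaviour:

    // q, k, v: [N, h * w, in_channels], i.e. a single attention "head" over the h * w spatial tokens
    // 1       -> presumably n_head
    // nullptr -> no attention mask
    // false, true, false -> behaviour flags whose exact meaning is not visible in this diff
    h_ = ggml_nn_attention_ext(ctx, q, k, v, 1, nullptr, false, true, false);  // [N, h * w, in_channels]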
@@ -612,4 +614,4 @@ struct AutoEncoderKL : public GGMLRunner {
 };
 };
 
-#endif
+#endif
