@@ -710,11 +710,10 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_nn_attention_ext(struct ggml_context*
     float scale = (1.0f / sqrt((float)d_head));
 
-    if (flash_attn) {
-        // TODO: remove before merge
-        LOG_DEBUG("attention_ext L_q:%d L_k:%d n_head:%d C:%d d_head:%d N:%d", L_q, L_k, n_head, C, d_head, N);
-    }
-    // is there anything oddly shaped??
+    // if (flash_attn) {
+    //     LOG_DEBUG("attention_ext L_q:%d L_k:%d n_head:%d C:%d d_head:%d N:%d", L_q, L_k, n_head, C, d_head, N);
+    // }
+    // is there anything oddly shaped?? ping Green-Sky if you can trip this assert
 
     GGML_ASSERT(((L_k % 256 == 0) && L_q == L_k) || !(L_k % 256 == 0));
 
     bool can_use_flash_attn = true;
@@ -725,17 +724,17 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_nn_attention_ext(struct ggml_context*
     can_use_flash_attn = can_use_flash_attn && d_head <= 256;  // double check
 
     if (mask != nullptr) {
-        // TODO: figure out if we can bend t5 to work too
+        // TODO(Green-Sky): figure out if we can bend t5 to work too
         can_use_flash_attn = can_use_flash_attn && mask->ne[2] == 1;
         can_use_flash_attn = can_use_flash_attn && mask->ne[3] == 1;
     }
 
-    // TODO: more pad or disable for funny tensor shapes
+    // TODO(Green-Sky): more pad or disable for funny tensor shapes
 
     ggml_tensor* kqv = nullptr;
     // GGML_ASSERT((flash_attn && can_use_flash_attn) || !flash_attn);
     if (can_use_flash_attn && flash_attn) {
-        LOG_DEBUG("using flash attention");
+        // LOG_DEBUG("using flash attention");
         k = ggml_cast(ctx, k, GGML_TYPE_F16);
 
         v = ggml_cont(ctx, ggml_permute(ctx, v, 0, 2, 1, 3));  // [N, n_head, L_k, d_head]
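
For readers skimming the gate above: the checks amount to a small predicate over the tensor shapes. The sketch below restates them outside of ggml_nn_attention_ext; it is not part of the patch, the helper name and the stand-in dimension struct are hypothetical, and only the conditions visible in the diff are reproduced (the diff enforces the length check with GGML_ASSERT rather than returning false).

// Minimal sketch (hypothetical, not from the patch) of the flash-attention
// eligibility checks shown in the diff, using ggml-style ne[] dimensions.
#include <cstdint>

struct dims4 {
    int64_t ne[4];  // stand-in for ggml_tensor::ne
};

static bool flash_attn_eligible(int64_t L_q, int64_t L_k, int64_t d_head, const dims4* mask) {
    // The GGML_ASSERT in the diff: (L_k % 256 == 0 && L_q == L_k) || !(L_k % 256 == 0),
    // i.e. if L_k is a multiple of 256, the query length must match it.
    // Folded into the predicate here instead of asserting.
    if (L_k % 256 == 0 && L_q != L_k) {
        return false;
    }
    bool ok = true;
    ok = ok && d_head <= 256;         // head-size cap ("double check" in the diff)
    if (mask != nullptr) {
        ok = ok && mask->ne[2] == 1;  // mask must broadcast over heads...
        ok = ok && mask->ne[3] == 1;  // ...and over batch (t5-style masks fail here)
    }
    return ok;
}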