Skip to content

Commit 0c74b04

Browse files
authored
vulkan: fix NaN issue in flash attention shader (#12776)
Use -FLT_MAX/2 rather than -inf as the initial value for computing the maximum.
1 parent 80b717d commit 0c74b04

File tree

1 file changed

+5
-2
lines changed

1 file changed

+5
-2
lines changed

ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -227,8 +227,11 @@ void main() {
227227

228228
coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> L, M;
229229

230+
// Use -FLT_MAX/2 rather than -inf to reduce the possibility of NaNs, e.g. when computing Mold-M.
231+
const float NEG_FLT_MAX_OVER_2 = uintBitsToFloat(0xFEFFFFFF);
232+
230233
L = coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(0);
231-
M = coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(-1.0/0.0);
234+
M = coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(NEG_FLT_MAX_OVER_2);
232235

233236
coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> slopeMat = coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(1.0);
234237

@@ -278,7 +281,7 @@ void main() {
278281
uint R = ((i + 1) * Br > N) ? (N % Br) : Br;
279282
uint C = ((j + 1) * Bc > KV) ? (KV % Bc) : Bc;
280283

281-
coopMatPerElementNV(S, S, replacePadding, ACC_TYPE(-1.0/0.0), R, C);
284+
coopMatPerElementNV(S, S, replacePadding, ACC_TYPE(NEG_FLT_MAX_OVER_2), R, C);
282285
}
283286

284287
coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> rowmax, P, rowsum, eM;

0 commit comments

Comments
 (0)