Skip to content

Commit c1a3a9c

Browse files
llama: rwkv6: Use ggml_norm instead of ggml_group_norm
Co-authored-by: compilade <git@compilade.net>
1 parent e29b446 commit c1a3a9c

File tree

1 file changed

+4
-4
lines changed

1 file changed

+4
-4
lines changed

src/llama.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9496,10 +9496,10 @@ static struct ggml_tensor * llm_build_time_mix_rwkv6(
 9496  9496          cur = ggml_view_1d(ctx, wkv_output, n_embed * n_tokens, 0);
 9497  9497          *wkv_state = ggml_view_1d(ctx, wkv_output, n_embed * head_size * n_seqs, n_embed * n_tokens * sizeof(float));
 9498  9498
 9499        -        // ggml_group_norm considers groups in the third dimension.
 9500        -        cur = ggml_reshape_4d(ctx, cur, n_embed / head_count, 1, head_count, n_tokens);
 9501        -        cur = ggml_group_norm(ctx, cur, head_count, 64e-5f);
 9502        -        // Convert back to a regular vector.
       9499  +        // group norm with head_count groups
       9500  +        cur = ggml_reshape_3d(ctx, cur, n_embed / head_count, head_count, n_tokens);
       9501  +        cur = ggml_norm(ctx, cur, 64e-5f);
       9502  +        // Convert back to regular vectors.
 9503  9503          cur = ggml_reshape_2d(ctx, cur, n_embed, n_tokens);
 9504  9504          cur = ggml_add(ctx, ggml_mul(ctx, cur, layer->time_mix_ln), layer->time_mix_ln_b);
 9505  9505

0 commit comments

Comments
 (0)