Commit 0ba20ed

llama : compute BERT graph with F16 K, V
ggml-ci
1 parent: 6cdabe6

1 file changed: +2 −2 lines


llama.cpp

Lines changed: 2 additions & 2 deletions
@@ -6175,15 +6175,15 @@ struct llm_build_context {
             }

             struct ggml_tensor * q = ggml_permute(ctx0, Qcur, 0, 2, 1, 3);
-            struct ggml_tensor * k = ggml_cont(ctx0, ggml_permute(ctx0, Kcur, 0, 2, 1, 3));
+            struct ggml_tensor * k = ggml_cast(ctx0, ggml_permute(ctx0, Kcur, 0, 2, 1, 3), GGML_TYPE_F16);

             struct ggml_tensor * kq = ggml_mul_mat(ctx0, k, q);
             cb(kq, "kq", il);

             kq = ggml_soft_max_ext(ctx0, kq, KQ_mask, nullptr, 1.0f/sqrtf(float(n_embd_head)), hparams.f_max_alibi_bias);
             cb(kq, "kq_soft_max_ext", il);

-            struct ggml_tensor * v = ggml_cont(ctx0, ggml_transpose(ctx0, ggml_reshape_2d(ctx0, Vcur, n_embd_gqa, n_tokens)));
+            struct ggml_tensor * v = ggml_cast(ctx0, ggml_transpose(ctx0, ggml_reshape_2d(ctx0, Vcur, n_embd_gqa, n_tokens)), GGML_TYPE_F16);
             cb(v, "v", il);

             struct ggml_tensor * kqv = ggml_mul_mat(ctx0, ggml_reshape_3d(ctx0, v, n_tokens, n_embd_head, n_head_kv), kq);
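For context, below is a minimal standalone sketch (not part of the commit) of the pattern the diff introduces: instead of materializing the permuted K with ggml_cont (which keeps it in F32), the tensor is materialized directly in F16 with ggml_cast and then passed to ggml_mul_mat. The tensor shapes, context size, and thread count are illustrative assumptions, not values taken from llama.cpp.

#include "ggml.h"

#include <stdio.h>

int main(void) {
    // small scratch context for the toy graph (size is an arbitrary choice)
    struct ggml_init_params params = {
        /*.mem_size   =*/ 16*1024*1024,
        /*.mem_buffer =*/ NULL,
        /*.no_alloc   =*/ false,
    };
    struct ggml_context * ctx = ggml_init(params);

    // toy stand-ins for Qcur and Kcur: [head_dim = 64, n = 8, n_head = 4]
    struct ggml_tensor * Qcur = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, 64, 8, 4);
    struct ggml_tensor * Kcur = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, 64, 8, 4);

    // before the commit: ggml_cont() materialized the permuted K as contiguous F32
    // after the commit:  ggml_cast() materializes it directly as F16
    struct ggml_tensor * q = ggml_permute(ctx, Qcur, 0, 2, 1, 3);
    struct ggml_tensor * k = ggml_cast(ctx, ggml_permute(ctx, Kcur, 0, 2, 1, 3), GGML_TYPE_F16);

    // ggml_mul_mat accepts mixed input types, so the F16 K is multiplied with the F32 Q
    struct ggml_tensor * kq = ggml_mul_mat(ctx, k, q);

    struct ggml_cgraph * gf = ggml_new_graph(ctx);
    ggml_build_forward_expand(gf, kq);
    ggml_graph_compute_with_ctx(ctx, gf, /*n_threads =*/ 1);

    // the F16 copy of K takes half the bytes of an F32 ggml_cont() copy
    printf("k bytes (F16): %zu\n", ggml_nbytes(k));
    printf("kq shape: %lld x %lld x %lld\n",
           (long long) kq->ne[0], (long long) kq->ne[1], (long long) kq->ne[2]);

    ggml_free(ctx);
    return 0;
}

The apparent upside of the cast is that the temporary K and V copies use half the memory of the previous F32 ggml_cont() copies, and backends with F16 matrix-multiplication kernels can consume them directly; the softmax and the rest of the BERT graph are unchanged.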
