Skip to content

Commit af1968c

Browse files
committed
build_inp_attn_scale()
1 parent e6a2809 commit af1968c

File tree

3 files changed

+15
-5
lines changed

3 files changed

+15
-5
lines changed

src/llama-graph.cpp

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1024,6 +1024,19 @@ ggml_tensor * llm_graph_context::build_inp_pos() const {
10241024
return cur;
10251025
}
10261026

1027+
ggml_tensor * llm_graph_context::build_inp_attn_scale() const {
1028+
auto inp = std::make_unique<llm_graph_input_attn_temp>(n_pos_per_token());
1029+
1030+
auto & cur = inp->attn_scale;
1031+
1032+
cur = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, 1, 1, n_tokens*n_pos_per_token());
1033+
ggml_set_input(cur);
1034+
1035+
res->add_input(std::move(inp));
1036+
1037+
return cur;
1038+
}
1039+
10271040
ggml_tensor * llm_graph_context::build_inp_out_ids() const {
10281041
auto inp = std::make_unique<llm_graph_input_out_ids>(hparams, cparams, n_outputs);
10291042

src/llama-graph.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -487,6 +487,7 @@ struct llm_graph_context {
487487

488488
ggml_tensor * build_inp_embd(ggml_tensor * tok_embd) const;
489489
ggml_tensor * build_inp_pos() const;
490+
ggml_tensor * build_inp_attn_scale() const;
490491
ggml_tensor * build_inp_out_ids() const;
491492
ggml_tensor * build_inp_mean() const;
492493
ggml_tensor * build_inp_cls() const;

src/llama-model.cpp

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4279,11 +4279,7 @@ struct llm_build_llama : public llm_graph_context {
42794279
// temperature tuning
42804280
ggml_tensor * inp_attn_scale = nullptr;
42814281
if (arch == LLM_ARCH_LLAMA4) {
4282-
auto inp = std::make_unique<llm_graph_input_attn_temp>(n_pos_per_token(), hparams.n_attn_temp_floor_scale, hparams.f_attn_temp_scale);
4283-
inp_attn_scale = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, 1, 1, n_tokens*n_pos_per_token());
4284-
ggml_set_input(inp_attn_scale);
4285-
inp->attn_scale = inp_attn_scale;
4286-
res->add_input(std::move(inp));
4282+
inp_attn_scale = build_inp_attn_scale();
42874283
}
42884284

42894285
auto * inp_attn = build_attn_inp_kv_unified();

0 commit comments

Comments
 (0)