@@ -1561,6 +1561,32 @@ static bool llm_load_tensors(
1561
1561
layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0);
1562
1562
}
1563
1563
} break;
1564
+ case LLM_ARCH_COHERE2:
1565
+ {
1566
+ model.tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, 0);
1567
+
1568
+ // output
1569
+ model.output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), { n_embd }, 0);
1570
+ // init output from the input tok embed
1571
+ model.output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab },
1572
+ llama_model_loader::TENSOR_DUPLICATED);
1573
+
1574
+ for (int i = 0; i < n_layer; ++i) {
1575
+ auto & layer = model.layers[i];
1576
+
1577
+ layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), { n_embd }, 0);
1578
+
1579
+ layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), { n_embd, n_embd }, 0);
1580
+ layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), { n_embd, n_embd_gqa }, 0);
1581
+ layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), { n_embd, n_embd_gqa }, 0);
1582
+ layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_embd, n_embd }, 0);
1583
+
1584
+ layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), { n_embd, n_ff }, 0);
1585
+ layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd }, 0);
1586
+ layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), { n_embd, n_ff }, 0);
1587
+ }
1588
+ }
1589
+ break;
1564
1590
case LLM_ARCH_OLMO: // adapted from LLM_ARCH_LLAMA with norm params removed
1565
1591
{
1566
1592
model.tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
@@ -7642,6 +7668,137 @@ struct llm_build_context {
7642
7668
7643
7669
}
7644
7670
7671
+ struct ggml_cgraph * build_cohere2() {
7672
+ struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
7673
+
7674
+ const int64_t n_embd_head = hparams.n_embd_head_v;
7675
+ GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
7676
+ const float f_logit_scale = hparams.f_logit_scale;
7677
+
7678
+ struct ggml_tensor * cur;
7679
+ struct ggml_tensor * inpL;
7680
+
7681
+ inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb);
7682
+
7683
+ // inp_pos - contains the positions
7684
+ struct ggml_tensor * inp_pos = build_inp_pos();
7685
+
7686
+ // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
7687
+ // cohere2 requires different mask for layers using sliding window (SWA)
7688
+ struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
7689
+ struct ggml_tensor * KQ_mask_swa = build_inp_KQ_mask_swa();
7690
+
7691
+ // sliding window switch pattern
7692
+ const int32_t sliding_window_pattern = 4;
7693
+
7694
+ for (int il = 0; il < n_layer; ++il) {
7695
+ // three layers sliding window attention (window size 4096) and ROPE
7696
+ // fourth layer uses global attention without positional embeddings
7697
+ const bool is_sliding = il % sliding_window_pattern < (sliding_window_pattern - 1);
7698
+ struct ggml_tensor * KQ_mask_l = is_sliding ? KQ_mask_swa : KQ_mask;
7699
+
7700
+ // norm
7701
+ cur = llm_build_norm(ctx0, inpL, hparams, model.layers[il].attn_norm, NULL, LLM_NORM, cb, il);
7702
+ cb(cur, "attn_norm", il);
7703
+ struct ggml_tensor * ffn_inp = cur;
7704
+
7705
+ // self-attention
7706
+ {
7707
+ // rope freq factors for 128k context
7708
+ struct ggml_tensor * rope_factors = build_rope_factors(il);
7709
+
7710
+ // compute Q and K and RoPE them
7711
+ struct ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur);
7712
+ cb(Qcur, "Qcur", il);
7713
+ if (model.layers[il].bq) {
7714
+ Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
7715
+ cb(Qcur, "Qcur", il);
7716
+ }
7717
+
7718
+ struct ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur);
7719
+ cb(Kcur, "Kcur", il);
7720
+ if (model.layers[il].bk) {
7721
+ Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
7722
+ cb(Kcur, "Kcur", il);
7723
+ }
7724
+
7725
+ struct ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur);
7726
+ cb(Vcur, "Vcur", il);
7727
+ if (model.layers[il].bv) {
7728
+ Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
7729
+ cb(Vcur, "Vcur", il);
7730
+ }
7731
+
7732
+ if (is_sliding) {
7733
+ Qcur = ggml_rope_ext(ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, rope_factors,
7734
+ n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor,
7735
+ beta_fast, beta_slow);
7736
+ cb(Qcur, "Qcur", il);
7737
+
7738
+ Kcur = ggml_rope_ext(ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos,
7739
+ rope_factors, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, ext_factor,
7740
+ attn_factor, beta_fast, beta_slow);
7741
+ cb(Kcur, "Kcur", il);
7742
+ } else {
7743
+ // For non-sliding layers, just reshape without applying RoPE
7744
+ Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
7745
+ cb(Qcur, "Qcur", il);
7746
+
7747
+ Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
7748
+ cb(Kcur, "Kcur", il);
7749
+ }
7750
+
7751
+ cur = llm_build_kv(ctx0, lctx, kv_self, gf, model.layers[il].wo, model.layers[il].bo, Kcur, Vcur, Qcur,
7752
+ KQ_mask_l, n_tokens, kv_head, n_kv, 1.0f / sqrtf(float(n_embd_head)), cb, il);
7753
+ }
7754
+
7755
+ if (il == n_layer - 1) {
7756
+ // skip computing output for unused tokens
7757
+ struct ggml_tensor * inp_out_ids = build_inp_out_ids();
7758
+ cur = ggml_get_rows(ctx0, cur, inp_out_ids);
7759
+ inpL = ggml_get_rows(ctx0, inpL, inp_out_ids);
7760
+ ffn_inp = ggml_get_rows(ctx0, ffn_inp, inp_out_ids);
7761
+ }
7762
+
7763
+ struct ggml_tensor * attn_out = cur;
7764
+
7765
+ // feed-forward network
7766
+ {
7767
+ cur = llm_build_ffn(ctx0, lctx, ffn_inp, model.layers[il].ffn_up, NULL, NULL, model.layers[il].ffn_gate,
7768
+ NULL, NULL, model.layers[il].ffn_down, NULL, NULL, NULL, LLM_FFN_SILU, LLM_FFN_PAR,
7769
+ cb, il);
7770
+ cb(cur, "ffn_out", il);
7771
+ }
7772
+
7773
+ // add together residual + FFN + self-attention
7774
+ cur = ggml_add(ctx0, cur, inpL);
7775
+ cur = ggml_add(ctx0, cur, attn_out);
7776
+ cur = lctx.cvec.apply_to(ctx0, cur, il);
7777
+ cb(cur, "l_out", il);
7778
+
7779
+ // input for next layer
7780
+ inpL = cur;
7781
+ }
7782
+
7783
+ cur = inpL;
7784
+
7785
+ cur = llm_build_norm(ctx0, cur, hparams, model.output_norm, NULL, LLM_NORM, cb, -1);
7786
+ cb(cur, "result_norm", -1);
7787
+
7788
+ // lm_head
7789
+ cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
7790
+
7791
+ if (f_logit_scale) {
7792
+ cur = ggml_scale(ctx0, cur, f_logit_scale);
7793
+ }
7794
+
7795
+ cb(cur, "result_output", -1);
7796
+
7797
+ ggml_build_forward_expand(gf, cur);
7798
+
7799
+ return gf;
7800
+ }
7801
+
7645
7802
// ref: https://allenai.org/olmo
7646
7803
// based on the original build_llama() function, changes:
7647
7804
// * non-parametric layer norm
@@ -10393,6 +10550,10 @@ static struct ggml_cgraph * llama_build_graph(
10393
10550
{
10394
10551
result = llm.build_command_r();
10395
10552
} break;
10553
+ case LLM_ARCH_COHERE2:
10554
+ {
10555
+ result = llm.build_cohere2();
10556
+ } break;
10396
10557
case LLM_ARCH_DBRX:
10397
10558
{
10398
10559
result = llm.build_dbrx();
0 commit comments