Skip to content

Commit 2599750

Browse files
dranger003NeoZhangJianyu
authored andcommitted
llama : add support for the cohere2 model architecture (ggml-org#10900)
1 parent a64d62e commit 2599750

File tree

6 files changed

+221
-0
lines changed

6 files changed

+221
-0
lines changed

convert_hf_to_gguf.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3373,6 +3373,24 @@ def set_gguf_parameters(self):
33733373
self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.NONE)
33743374

33753375

3376+
@Model.register("Cohere2ForCausalLM")
3377+
class Cohere2Model(Model):
3378+
model_arch = gguf.MODEL_ARCH.COHERE2
3379+
3380+
def set_gguf_parameters(self):
3381+
super().set_gguf_parameters()
3382+
3383+
self.gguf_writer.add_logit_scale(self.hparams["logit_scale"])
3384+
self.gguf_writer.add_sliding_window(self.hparams["sliding_window"])
3385+
self.gguf_writer.add_vocab_size(self.hparams["vocab_size"])
3386+
3387+
rotary_pct = self.hparams["rotary_pct"]
3388+
hidden_size = self.hparams["hidden_size"]
3389+
num_attention_heads = self.hparams["num_attention_heads"]
3390+
self.gguf_writer.add_rope_dimension_count(int(rotary_pct * (hidden_size // num_attention_heads)))
3391+
self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.NONE)
3392+
3393+
33763394
@Model.register("OlmoForCausalLM")
33773395
@Model.register("OLMoForCausalLM")
33783396
class OlmoModel(Model):

gguf-py/gguf/constants.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -255,6 +255,7 @@ class MODEL_ARCH(IntEnum):
255255
MAMBA = auto()
256256
XVERSE = auto()
257257
COMMAND_R = auto()
258+
COHERE2 = auto()
258259
DBRX = auto()
259260
OLMO = auto()
260261
OLMO2 = auto()
@@ -437,6 +438,7 @@ class MODEL_TENSOR(IntEnum):
437438
MODEL_ARCH.MAMBA: "mamba",
438439
MODEL_ARCH.XVERSE: "xverse",
439440
MODEL_ARCH.COMMAND_R: "command-r",
441+
MODEL_ARCH.COHERE2: "cohere2",
440442
MODEL_ARCH.DBRX: "dbrx",
441443
MODEL_ARCH.OLMO: "olmo",
442444
MODEL_ARCH.OLMO2: "olmo2",
@@ -1136,6 +1138,18 @@ class MODEL_TENSOR(IntEnum):
11361138
MODEL_TENSOR.ATTN_K_NORM,
11371139
MODEL_TENSOR.ATTN_Q_NORM,
11381140
],
1141+
MODEL_ARCH.COHERE2: [
1142+
MODEL_TENSOR.TOKEN_EMBD,
1143+
MODEL_TENSOR.OUTPUT_NORM,
1144+
MODEL_TENSOR.ATTN_NORM,
1145+
MODEL_TENSOR.ATTN_Q,
1146+
MODEL_TENSOR.ATTN_K,
1147+
MODEL_TENSOR.ATTN_V,
1148+
MODEL_TENSOR.ATTN_OUT,
1149+
MODEL_TENSOR.FFN_GATE,
1150+
MODEL_TENSOR.FFN_DOWN,
1151+
MODEL_TENSOR.FFN_UP,
1152+
],
11391153
MODEL_ARCH.DBRX: [
11401154
MODEL_TENSOR.TOKEN_EMBD,
11411155
MODEL_TENSOR.OUTPUT_NORM,

src/llama-arch.cpp

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
3939
{ LLM_ARCH_MAMBA, "mamba" },
4040
{ LLM_ARCH_XVERSE, "xverse" },
4141
{ LLM_ARCH_COMMAND_R, "command-r" },
42+
{ LLM_ARCH_COHERE2, "cohere2" },
4243
{ LLM_ARCH_DBRX, "dbrx" },
4344
{ LLM_ARCH_OLMO, "olmo" },
4445
{ LLM_ARCH_OLMO2, "olmo2" },
@@ -807,6 +808,21 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
807808
{ LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" },
808809
},
809810
},
811+
{
812+
LLM_ARCH_COHERE2,
813+
{
814+
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
815+
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
816+
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
817+
{ LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
818+
{ LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
819+
{ LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
820+
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
821+
{ LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
822+
{ LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
823+
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
824+
},
825+
},
810826
{
811827
LLM_ARCH_DBRX,
812828
{

src/llama-arch.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ enum llm_arch {
4343
LLM_ARCH_MAMBA,
4444
LLM_ARCH_XVERSE,
4545
LLM_ARCH_COMMAND_R,
46+
LLM_ARCH_COHERE2,
4647
LLM_ARCH_DBRX,
4748
LLM_ARCH_OLMO,
4849
LLM_ARCH_OLMO2,

src/llama-model.cpp

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -786,6 +786,16 @@ void llm_load_hparams(llama_model_loader & ml, llama_model & model) {
786786
default: model.type = e_model::MODEL_UNKNOWN;
787787
}
788788
} break;
789+
case LLM_ARCH_COHERE2:
790+
{
791+
ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa);
792+
ml.get_key(LLM_KV_LOGIT_SCALE, hparams.f_logit_scale);
793+
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps);
794+
switch (hparams.n_layer) {
795+
case 32: model.type = e_model::MODEL_8B; break;
796+
default: model.type = e_model::MODEL_UNKNOWN;
797+
}
798+
} break;
789799
case LLM_ARCH_DBRX:
790800
{
791801
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps);
@@ -2031,6 +2041,7 @@ enum llama_rope_type llama_rope_type(const struct llama_model * model) {
20312041
case LLM_ARCH_MINICPM:
20322042
case LLM_ARCH_XVERSE:
20332043
case LLM_ARCH_COMMAND_R:
2044+
case LLM_ARCH_COHERE2:
20342045
case LLM_ARCH_OLMO:
20352046
case LLM_ARCH_ARCTIC:
20362047
case LLM_ARCH_DEEPSEEK:

src/llama.cpp

Lines changed: 161 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1561,6 +1561,32 @@ static bool llm_load_tensors(
15611561
layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0);
15621562
}
15631563
} break;
1564+
case LLM_ARCH_COHERE2:
1565+
{
1566+
model.tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, 0);
1567+
1568+
// output
1569+
model.output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), { n_embd }, 0);
1570+
// init output from the input tok embed
1571+
model.output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab },
1572+
llama_model_loader::TENSOR_DUPLICATED);
1573+
1574+
for (int i = 0; i < n_layer; ++i) {
1575+
auto & layer = model.layers[i];
1576+
1577+
layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), { n_embd }, 0);
1578+
1579+
layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), { n_embd, n_embd }, 0);
1580+
layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), { n_embd, n_embd_gqa }, 0);
1581+
layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), { n_embd, n_embd_gqa }, 0);
1582+
layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_embd, n_embd }, 0);
1583+
1584+
layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), { n_embd, n_ff }, 0);
1585+
layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd }, 0);
1586+
layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), { n_embd, n_ff }, 0);
1587+
}
1588+
}
1589+
break;
15641590
case LLM_ARCH_OLMO: // adapted from LLM_ARCH_LLAMA with norm params removed
15651591
{
15661592
model.tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
@@ -7642,6 +7668,137 @@ struct llm_build_context {
76427668

76437669
}
76447670

7671+
struct ggml_cgraph * build_cohere2() {
7672+
struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
7673+
7674+
const int64_t n_embd_head = hparams.n_embd_head_v;
7675+
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
7676+
const float f_logit_scale = hparams.f_logit_scale;
7677+
7678+
struct ggml_tensor * cur;
7679+
struct ggml_tensor * inpL;
7680+
7681+
inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb);
7682+
7683+
// inp_pos - contains the positions
7684+
struct ggml_tensor * inp_pos = build_inp_pos();
7685+
7686+
// KQ_mask (mask for 1 head, it will be broadcasted to all heads)
7687+
// cohere2 requires different mask for layers using sliding window (SWA)
7688+
struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
7689+
struct ggml_tensor * KQ_mask_swa = build_inp_KQ_mask_swa();
7690+
7691+
// sliding window switch pattern
7692+
const int32_t sliding_window_pattern = 4;
7693+
7694+
for (int il = 0; il < n_layer; ++il) {
7695+
// three layers sliding window attention (window size 4096) and ROPE
7696+
// fourth layer uses global attention without positional embeddings
7697+
const bool is_sliding = il % sliding_window_pattern < (sliding_window_pattern - 1);
7698+
struct ggml_tensor * KQ_mask_l = is_sliding ? KQ_mask_swa : KQ_mask;
7699+
7700+
// norm
7701+
cur = llm_build_norm(ctx0, inpL, hparams, model.layers[il].attn_norm, NULL, LLM_NORM, cb, il);
7702+
cb(cur, "attn_norm", il);
7703+
struct ggml_tensor * ffn_inp = cur;
7704+
7705+
// self-attention
7706+
{
7707+
// rope freq factors for 128k context
7708+
struct ggml_tensor * rope_factors = build_rope_factors(il);
7709+
7710+
// compute Q and K and RoPE them
7711+
struct ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur);
7712+
cb(Qcur, "Qcur", il);
7713+
if (model.layers[il].bq) {
7714+
Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
7715+
cb(Qcur, "Qcur", il);
7716+
}
7717+
7718+
struct ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur);
7719+
cb(Kcur, "Kcur", il);
7720+
if (model.layers[il].bk) {
7721+
Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
7722+
cb(Kcur, "Kcur", il);
7723+
}
7724+
7725+
struct ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur);
7726+
cb(Vcur, "Vcur", il);
7727+
if (model.layers[il].bv) {
7728+
Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
7729+
cb(Vcur, "Vcur", il);
7730+
}
7731+
7732+
if (is_sliding) {
7733+
Qcur = ggml_rope_ext(ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, rope_factors,
7734+
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor,
7735+
beta_fast, beta_slow);
7736+
cb(Qcur, "Qcur", il);
7737+
7738+
Kcur = ggml_rope_ext(ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos,
7739+
rope_factors, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, ext_factor,
7740+
attn_factor, beta_fast, beta_slow);
7741+
cb(Kcur, "Kcur", il);
7742+
} else {
7743+
// For non-sliding layers, just reshape without applying RoPE
7744+
Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
7745+
cb(Qcur, "Qcur", il);
7746+
7747+
Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
7748+
cb(Kcur, "Kcur", il);
7749+
}
7750+
7751+
cur = llm_build_kv(ctx0, lctx, kv_self, gf, model.layers[il].wo, model.layers[il].bo, Kcur, Vcur, Qcur,
7752+
KQ_mask_l, n_tokens, kv_head, n_kv, 1.0f / sqrtf(float(n_embd_head)), cb, il);
7753+
}
7754+
7755+
if (il == n_layer - 1) {
7756+
// skip computing output for unused tokens
7757+
struct ggml_tensor * inp_out_ids = build_inp_out_ids();
7758+
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
7759+
inpL = ggml_get_rows(ctx0, inpL, inp_out_ids);
7760+
ffn_inp = ggml_get_rows(ctx0, ffn_inp, inp_out_ids);
7761+
}
7762+
7763+
struct ggml_tensor * attn_out = cur;
7764+
7765+
// feed-forward network
7766+
{
7767+
cur = llm_build_ffn(ctx0, lctx, ffn_inp, model.layers[il].ffn_up, NULL, NULL, model.layers[il].ffn_gate,
7768+
NULL, NULL, model.layers[il].ffn_down, NULL, NULL, NULL, LLM_FFN_SILU, LLM_FFN_PAR,
7769+
cb, il);
7770+
cb(cur, "ffn_out", il);
7771+
}
7772+
7773+
// add together residual + FFN + self-attention
7774+
cur = ggml_add(ctx0, cur, inpL);
7775+
cur = ggml_add(ctx0, cur, attn_out);
7776+
cur = lctx.cvec.apply_to(ctx0, cur, il);
7777+
cb(cur, "l_out", il);
7778+
7779+
// input for next layer
7780+
inpL = cur;
7781+
}
7782+
7783+
cur = inpL;
7784+
7785+
cur = llm_build_norm(ctx0, cur, hparams, model.output_norm, NULL, LLM_NORM, cb, -1);
7786+
cb(cur, "result_norm", -1);
7787+
7788+
// lm_head
7789+
cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
7790+
7791+
if (f_logit_scale) {
7792+
cur = ggml_scale(ctx0, cur, f_logit_scale);
7793+
}
7794+
7795+
cb(cur, "result_output", -1);
7796+
7797+
ggml_build_forward_expand(gf, cur);
7798+
7799+
return gf;
7800+
}
7801+
76457802
// ref: https://allenai.org/olmo
76467803
// based on the original build_llama() function, changes:
76477804
// * non-parametric layer norm
@@ -10393,6 +10550,10 @@ static struct ggml_cgraph * llama_build_graph(
1039310550
{
1039410551
result = llm.build_command_r();
1039510552
} break;
10553+
case LLM_ARCH_COHERE2:
10554+
{
10555+
result = llm.build_cohere2();
10556+
} break;
1039610557
case LLM_ARCH_DBRX:
1039710558
{
1039810559
result = llm.build_dbrx();

0 commit comments

Comments
 (0)