From 1f0fea70fb761d10e2264cbdcf4852ed32706c89 Mon Sep 17 00:00:00 2001 From: Francis Couture-Harpin Date: Thu, 1 Aug 2024 10:43:42 -0400 Subject: [PATCH 01/37] llama : initial Mamba-2 support --- convert_hf_to_gguf.py | 67 ++++++++ ggml/include/ggml.h | 3 +- ggml/src/ggml.c | 193 ++++++++++++++-------- gguf-py/gguf/constants.py | 19 +++ gguf-py/gguf/gguf_writer.py | 3 + gguf-py/gguf/tensor_mapping.py | 6 +- src/llama.cpp | 291 +++++++++++++++++++++++++++++++-- 7 files changed, 495 insertions(+), 87 deletions(-) diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index 108c822cff5d2..0ac64574a3043 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -2788,6 +2788,73 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter return [(new_name, data_torch)] +@Model.register("Mamba2ForCausalLM") +class Mamba2Model(Model): + model_arch = gguf.MODEL_ARCH.MAMBA2 + + def set_vocab(self): + vocab_size = self.hparams["vocab_size"] + # Round vocab size to next multiple of 16 + pad_vocab = self.hparams.get("pad_vocab_size_multiple", 16) + # pad using ceiling division + # ref: https://stackoverflow.com/a/17511341/22827863 + vocab_size = -(vocab_size // -pad_vocab) * pad_vocab + self.hparams["vocab_size"] = vocab_size + + if (self.dir_model / "tokenizer.json").is_file(): + self._set_vocab_gpt2() + elif (self.dir_model / "tokenizer.model").is_file(): + self._set_vocab_sentencepiece() + elif (self.dir_model / "tokenizer.model.v3").is_file(): + # mamba-codestral + raise NotImplementedError(f"Please rename {self.dir_model / 'tokenizer.model.v3'} to {self.dir_model / 'tokenizer.model'}") + else: + # Use the GPT-NeoX tokenizer when no tokenizer files are present + self._set_vocab_builtin("gpt-neox", vocab_size) + + def set_gguf_parameters(self): + d_model = self.find_hparam(["hidden_size", "d_model", "dim"]) + d_conv = self.find_hparam(["conv_kernel", "d_conv"], optional=True) or 4 + d_inner = self.find_hparam(["intermediate_size", "d_inner"], optional=True) or 2 * d_model + d_state = self.find_hparam(["state_size", "d_state"], optional=True) or 128 + head_dim = self.find_hparam(["head_dim"], optional=True) or 64 + n_group = self.find_hparam(["n_groups"], optional=True) or 1 + + rms_norm_eps = self.find_hparam(["layer_norm_epsilon", "rms_norm_eps"], optional=True) or 1e-5 + + # Fail early for models which don't have a block expansion factor of 2 + # TODO: does this really matter? + assert d_inner == 2 * d_model + assert d_inner % head_dim == 0 + + self.gguf_writer.add_context_length(2**20) # arbitrary value; for those who use the default + self.gguf_writer.add_embedding_length(d_model) + self.gguf_writer.add_feed_forward_length(0) # unused, but seemingly required when loading + self.gguf_writer.add_head_count(0) # unused, but seemingly required when loading + self.gguf_writer.add_block_count(self.block_count) + self.gguf_writer.add_ssm_conv_kernel(d_conv) + self.gguf_writer.add_ssm_inner_size(d_inner) + self.gguf_writer.add_ssm_state_size(d_state) + self.gguf_writer.add_ssm_time_step_rank(d_inner // head_dim) + self.gguf_writer.add_ssm_group_count(n_group) + self.gguf_writer.add_layer_norm_rms_eps(rms_norm_eps) + self.gguf_writer.add_file_type(self.ftype) + + def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + del bid # unused + + if name.endswith(".dt_bias"): + name = name.rpartition(".dt_bias")[0] + ".dt_proj.bias" + + new_name = self.map_tensor_name(name) + + if name.endswith(".A_log"): + logger.debug("A_log --> A ==> " + new_name) + data_torch = -torch.exp(data_torch) + + yield (new_name, data_torch) + + @Model.register("CohereForCausalLM") class CommandR2Model(Model): model_arch = gguf.MODEL_ARCH.COMMAND_R diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h index b8a21a2ccc3f0..59e0022dd4286 100644 --- a/ggml/include/ggml.h +++ b/ggml/include/ggml.h @@ -1787,7 +1787,8 @@ extern "C" { struct ggml_tensor * dt, struct ggml_tensor * A, struct ggml_tensor * B, - struct ggml_tensor * C); + struct ggml_tensor * C, + struct ggml_tensor * D); // partition into non-overlapping windows with padding if needed // example: diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index d63c917a5705a..6668209081b6c 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -7270,32 +7270,48 @@ struct ggml_tensor * ggml_ssm_scan( struct ggml_tensor * dt, struct ggml_tensor * A, struct ggml_tensor * B, - struct ggml_tensor * C) { + struct ggml_tensor * C, + struct ggml_tensor * D) { GGML_ASSERT(ggml_is_contiguous(s)); - GGML_ASSERT(ggml_is_contiguous(x)); GGML_ASSERT(ggml_is_contiguous(dt)); GGML_ASSERT(ggml_is_contiguous(A)); - GGML_ASSERT(ggml_is_matrix(A)); - GGML_ASSERT(ggml_is_3d(B)); - GGML_ASSERT(ggml_is_3d(s)); + GGML_ASSERT(x->nb[0] == ggml_type_size(x->type)); GGML_ASSERT(B->nb[0] == ggml_type_size(B->type)); GGML_ASSERT(C->nb[0] == ggml_type_size(C->type)); - GGML_ASSERT(ggml_are_same_shape(x, dt)); + GGML_ASSERT(x->nb[1] == x->ne[0]*x->nb[0]); + GGML_ASSERT(B->nb[1] == B->ne[0]*B->nb[0]); + GGML_ASSERT(C->nb[1] == C->ne[0]*C->nb[0]); GGML_ASSERT(ggml_are_same_shape(B, C)); { const int64_t d_state = s->ne[0]; - const int64_t d_inner = s->ne[1]; - const int64_t n_seq_tokens = x->ne[1]; - const int64_t n_seqs = x->ne[2]; - - GGML_ASSERT(s->ne[2] == n_seqs); - GGML_ASSERT(x->ne[0] == d_inner); - GGML_ASSERT(A->ne[0] == d_state); - GGML_ASSERT(A->ne[1] == d_inner); + const int64_t head_dim = x->ne[0]; + const int64_t n_head = x->ne[1]; + const int64_t n_seq_tokens = x->ne[2]; + const int64_t n_seqs = x->ne[3]; + + GGML_ASSERT(dt->ne[0] == n_head); + GGML_ASSERT(dt->ne[1] == n_seq_tokens); + GGML_ASSERT(dt->ne[2] == n_seqs); + GGML_ASSERT(ggml_is_3d(dt)); + GGML_ASSERT(s->ne[1] == head_dim); + GGML_ASSERT(s->ne[2] == n_head); + GGML_ASSERT(s->ne[3] == n_seqs); GGML_ASSERT(B->ne[0] == d_state); - GGML_ASSERT(B->ne[1] == n_seq_tokens); - GGML_ASSERT(B->ne[2] == n_seqs); + GGML_ASSERT(B->ne[2] == n_seq_tokens); + GGML_ASSERT(B->ne[3] == n_seqs); + GGML_ASSERT(D->ne[0] == n_head); + GGML_ASSERT(ggml_is_vector(D)); + + if (ggml_is_vector(A)) { + // Mamba-2 + GGML_ASSERT(A->ne[0] == n_head); + } else { + // Mamba-1 + GGML_ASSERT(A->ne[0] == d_state); + GGML_ASSERT(A->ne[1] == n_head); + GGML_ASSERT(ggml_is_matrix(A)); + } } bool is_node = false; @@ -7316,6 +7332,7 @@ struct ggml_tensor * ggml_ssm_scan( result->src[3] = A; result->src[4] = B; result->src[5] = C; + result->src[6] = D; return result; } @@ -15840,20 +15857,25 @@ static void ggml_compute_forward_ssm_conv( static void ggml_compute_forward_ssm_scan_f32( const struct ggml_compute_params * params, struct ggml_tensor * dst) { - const struct ggml_tensor * src0 = dst->src[0]; // s - const struct ggml_tensor * src1 = dst->src[1]; // x - const struct ggml_tensor * src2 = dst->src[2]; // dt - const struct ggml_tensor * src3 = dst->src[3]; // A - const struct ggml_tensor * src4 = dst->src[4]; // B - const struct ggml_tensor * src5 = dst->src[5]; // C + const struct ggml_tensor * src0 = dst->src[0]; // s {d_state, dim, n_head, n_seqs} + const struct ggml_tensor * src1 = dst->src[1]; // x {dim, n_head, n_seq_tokens, n_seqs} + const struct ggml_tensor * src2 = dst->src[2]; // dt {n_head, n_seq_tokens, n_seqs} + const struct ggml_tensor * src3 = dst->src[3]; // A {d_state, n_head} or {n_head} + const struct ggml_tensor * src4 = dst->src[4]; // B {d_state, n_group, n_seq_tokens, n_seqs} + const struct ggml_tensor * src5 = dst->src[5]; // C {d_state, n_group, n_seq_tokens, n_seqs} + const struct ggml_tensor * src6 = dst->src[6]; // D {n_head} const int ith = params->ith; const int nth = params->nth; - const int64_t nc = src0->ne[0]; // d_state - const int64_t nr = src0->ne[1]; // d_inner - const int64_t n_t = src1->ne[1]; // number of tokens per sequence - const int64_t n_s = src0->ne[2]; // number of sequences in the batch + const int64_t nc = src0->ne[0]; // d_state + const int64_t nr = src0->ne[1]; // dim + const int64_t nh = src1->ne[1]; // n_head + const int64_t ng = src4->ne[1]; + const int64_t nt = src1->ne[2]; // number of tokens per sequence + const int64_t ns = src0->ne[3]; // number of sequences in the batch + + const int64_t s_off = ggml_element_size(src1) * ggml_nelements(src1); GGML_ASSERT(ggml_nelements(src1) + ggml_nelements(src0) == ggml_nelements(dst)); GGML_ASSERT(src0->nb[0] == sizeof(float)); @@ -15862,51 +15884,86 @@ static void ggml_compute_forward_ssm_scan_f32( GGML_ASSERT(src3->nb[0] == sizeof(float)); GGML_ASSERT(src4->nb[0] == sizeof(float)); GGML_ASSERT(src5->nb[0] == sizeof(float)); - // required for the dot product between s and C - GGML_ASSERT(src0->nb[1] == src0->ne[0]*sizeof(float)); - // required for per-sequence offsets for states - GGML_ASSERT(src0->nb[2] == src0->ne[0]*src0->ne[1]*sizeof(float)); - // required to get correct offset for state destination (i.e. src1->nb[3]) - GGML_ASSERT(src1->nb[3] == src1->ne[0]*src1->ne[1]*src1->ne[2]*sizeof(float)); - - // rows per thread - const int dr = (nr + nth - 1)/nth; - - // row range for this thread - const int ir0 = dr*ith; - const int ir1 = MIN(ir0 + dr, nr); - const int ir = ir1 - ir0; - - for (int i3 = 0; i3 < n_s; ++i3) { - for (int i2 = 0; i2 < n_t; ++i2) { - const float * s0 = (const float *) ((const char *) src0->data + ir0*(src0->nb[1]) + i3*(src0->nb[2])); // {d_state, d_inner, n_s} - const float * x = (const float *) ((const char *) src1->data + ir0*(src1->nb[0]) + i2*(src1->nb[1]) + i3*(src1->nb[2])); // {d_inner, n_t, n_s} - const float * dt = (const float *) ((const char *) src2->data + ir0*(src2->nb[0]) + i2*(src2->nb[1]) + i3*(src2->nb[2])); // {d_inner, n_t, n_s} - const float * A = (const float *) ((const char *) src3->data + ir0*(src3->nb[1])); // {d_state, d_inner} - const float * B = (const float *) ((const char *) src4->data + i2*(src4->nb[1]) + i3*(src4->nb[2])); // {d_state, n_t, n_s} - const float * C = (const float *) ((const char *) src5->data + i2*(src5->nb[1]) + i3*(src5->nb[2])); // {d_state, n_t, n_s} - float * y = (float *) ((char *) dst->data + ir0*(src1->nb[0]) + i2*(src1->nb[1]) + i3*(src1->nb[2])); // {d_inner, n_t, n_s} - float * s = (float *) ((char *) dst->data + ir0*(src0->nb[1]) + i3*(src0->nb[2]) + src1->nb[3]); // {d_state, d_inner, n_s} - - // use the output as the source for the next token-wise iterations + GGML_ASSERT(src6->nb[0] == sizeof(float)); + // allows optimizing the modulo since n_group should be a power of 2 + GGML_ASSERT((ng & -ng) == ng); + + // heads per thread + const int dh = (nh + nth - 1)/nth; + + // head range for this thread + const int ih0 = dh*ith; + const int ih1 = MIN(ih0 + dh, nh); + + for (int i3 = 0; i3 < ns; ++i3) { + for (int i2 = 0; i2 < nt; ++i2) { + const float * s0 = (const float *) ((const char *) src0->data + i3*(src0->nb[3])); // {d_state, dim, nh, ns} + const float * x = (const float *) ((const char *) src1->data + i2*(src1->nb[2]) + i3*(src1->nb[3])); // {dim, nh, nt, ns} + const float * dt = (const float *) ((const char *) src2->data + i2*(src2->nb[1]) + i3*(src2->nb[2])); // {nh, nt, ns} + const float * A = (const float *) ((const char *) src3->data); // {d_state, nh} or {nh} + const float * B = (const float *) ((const char *) src4->data + i2*(src4->nb[2]) + i3*(src4->nb[3])); // {d_state, ng, nt, ns} + const float * C = (const float *) ((const char *) src5->data + i2*(src5->nb[2]) + i3*(src5->nb[3])); // {d_state, ng, nt, ns} + const float * D = (const float *) ((const char *) src6->data); // {nh} + float * y = (float *) ((char *) dst->data + i2*(nh*nr*sizeof(float)) + i3*(nt*nh*nr*sizeof(float))); // {dim, nh, nt, ns} + float * s = (float *) ((char *) dst->data + i3*(src0->nb[3]) + s_off); // {d_state, dim, nh, ns} + + // use the output as the source when it's not the first token-wise iteration if (i2 > 0) { s0 = s; } - // d_inner - for (int i1 = 0; i1 < ir; ++i1) { - // ref: https://github.com/state-spaces/mamba/blob/34076d664838588a3c97727b263478ab9f621a07/mamba_ssm/ops/triton/selective_state_update.py#L78 - float dt_soft_plus = dt[i1] <= 20.0f ? log1pf(expf(dt[i1])) : dt[i1]; - float x_dt = x[i1] * dt_soft_plus; - float sumf = 0.0f; - // d_state - for (int i0 = 0; i0 < nc; ++i0) { - int i = i0 + i1*nc; - // state = prev_state * dA + dB * x - float state = (s0[i] * expf(dt_soft_plus * A[i])) + (B[i0] * x_dt); - // y = rowwise_dotprod(state, C) - sumf += state * C[i0]; - s[i] = state; + if (ggml_is_vector(src3)) { + // Mamba-2 has a scalar decay factor per head; dA can be outside the state-wise loop + + // n_head + for (int h = ih0; h < ih1; ++h) { + // ref: https://github.com/state-spaces/mamba/blob/62db608da60f6fc790b8ed9f4b3225e95ca15fde/mamba_ssm/ops/triton/softplus.py#L16 + const float dt_soft_plus = dt[h] <= 20.0f ? log1pf(expf(dt[h])) : dt[h]; + const float dA = expf(dt_soft_plus * A[h]); + + // TODO: SIMD implementation + // dim + for (int i1 = 0; i1 < nr; ++i1) { + const int i = i1 + h*nr; + const float x_dt = x[i] * dt_soft_plus; + float sumf = 0.0f; + // d_state + for (int i0 = 0; i0 < nc; ++i0) { + const int ii = i0 + i*nc; + const int ig = i0 + (h & (ng - 1))*nc; + // state = prev_state * dA + dB * x + const float state = (s0[ii] * dA) + (B[ig] * x_dt); + // y = rowwise_dotprod(state, C) + sumf += state * C[ig]; + s[ii] = state; + } + y[i] = sumf + x[i] * D[h]; + } + } + } else { + // Mamba-1 has an element-wise decay factor for the states + + // n_head + for (int h = ih0; h < ih1; ++h) { + // ref: https://github.com/state-spaces/mamba/blob/62db608da60f6fc790b8ed9f4b3225e95ca15fde/mamba_ssm/ops/triton/softplus.py#L16 + const float dt_soft_plus = dt[h] <= 20.0f ? log1pf(expf(dt[h])) : dt[h]; + + // dim + for (int i1 = 0; i1 < nr; ++i1) { + const int i = i1 + h*nr; + const float x_dt = x[i] * dt_soft_plus; + float sumf = 0.0f; + // d_state + for (int i0 = 0; i0 < nc; ++i0) { + const int ii = i0 + i*nc; + const int ig = i0 + (h & (ng - 1))*nc; + // state = prev_state * dA + dB * x + const float state = (s0[ii] * expf(dt_soft_plus * A[i0 + h*nc])) + (B[ig] * x_dt); + // y = rowwise_dotprod(state, C) + sumf += state * C[ig]; + s[ii] = state; + } + y[i] = sumf + x[i] * D[h]; + } } - y[i1] = sumf; } } } diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index b55effa9907b1..32a2fb20f84b9 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -130,6 +130,7 @@ class SSM: INNER_SIZE = "{arch}.ssm.inner_size" STATE_SIZE = "{arch}.ssm.state_size" TIME_STEP_RANK = "{arch}.ssm.time_step_rank" + GROUP_COUNT = "{arch}.ssm.group_count" DT_B_C_RMS = "{arch}.ssm.dt_b_c_rms" class Tokenizer: @@ -208,6 +209,7 @@ class MODEL_ARCH(IntEnum): GEMMA2 = auto() STARCODER2 = auto() MAMBA = auto() + MAMBA2 = auto() XVERSE = auto() COMMAND_R = auto() DBRX = auto() @@ -269,6 +271,7 @@ class MODEL_TENSOR(IntEnum): SSM_DT = auto() SSM_A = auto() SSM_D = auto() + SSM_NORM = auto() SSM_OUT = auto() ATTN_Q_A = auto() ATTN_Q_B = auto() @@ -338,6 +341,7 @@ class MODEL_TENSOR(IntEnum): MODEL_ARCH.GEMMA2: "gemma2", MODEL_ARCH.STARCODER2: "starcoder2", MODEL_ARCH.MAMBA: "mamba", + MODEL_ARCH.MAMBA2: "mamba2", MODEL_ARCH.XVERSE: "xverse", MODEL_ARCH.COMMAND_R: "command-r", MODEL_ARCH.DBRX: "dbrx", @@ -399,6 +403,7 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.SSM_DT: "blk.{bid}.ssm_dt", MODEL_TENSOR.SSM_A: "blk.{bid}.ssm_a", MODEL_TENSOR.SSM_D: "blk.{bid}.ssm_d", + MODEL_TENSOR.SSM_NORM: "blk.{bid}.ssm_norm", MODEL_TENSOR.SSM_OUT: "blk.{bid}.ssm_out", MODEL_TENSOR.ATTN_Q_A: "blk.{bid}.attn_q_a", MODEL_TENSOR.ATTN_Q_B: "blk.{bid}.attn_q_b", @@ -869,6 +874,19 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.SSM_D, MODEL_TENSOR.SSM_OUT, ], + MODEL_ARCH.MAMBA2: [ + MODEL_TENSOR.TOKEN_EMBD, + MODEL_TENSOR.OUTPUT_NORM, + MODEL_TENSOR.OUTPUT, + MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.SSM_IN, + MODEL_TENSOR.SSM_CONV1D, + MODEL_TENSOR.SSM_DT, + MODEL_TENSOR.SSM_A, + MODEL_TENSOR.SSM_D, + MODEL_TENSOR.SSM_NORM, + MODEL_TENSOR.SSM_OUT, + ], MODEL_ARCH.XVERSE: [ MODEL_TENSOR.TOKEN_EMBD, MODEL_TENSOR.OUTPUT_NORM, @@ -1373,6 +1391,7 @@ def get_type(val: Any) -> GGUFValueType: KEY_SSM_INNER_SIZE = Keys.SSM.INNER_SIZE KEY_SSM_STATE_SIZE = Keys.SSM.STATE_SIZE KEY_SSM_TIME_STEP_RANK = Keys.SSM.TIME_STEP_RANK +KEY_SSM_GROUP_COUNT = Keys.SSM.GROUP_COUNT KEY_SSM_DT_B_C_RMS = Keys.SSM.DT_B_C_RMS # tokenization diff --git a/gguf-py/gguf/gguf_writer.py b/gguf-py/gguf/gguf_writer.py index af3b98c679b0b..ea788918dbf2c 100644 --- a/gguf-py/gguf/gguf_writer.py +++ b/gguf-py/gguf/gguf_writer.py @@ -730,6 +730,9 @@ def add_ssm_state_size(self, value: int) -> None: def add_ssm_time_step_rank(self, value: int) -> None: self.add_uint32(Keys.SSM.TIME_STEP_RANK.format(arch=self.arch), value) + def add_ssm_group_count(self, value: int) -> None: + self.add_uint32(Keys.SSM.GROUP_COUNT.format(arch=self.arch), value) + def add_ssm_dt_b_c_rms(self, value: bool) -> None: self.add_bool(Keys.SSM.DT_B_C_RMS.format(arch=self.arch), value) diff --git a/gguf-py/gguf/tensor_mapping.py b/gguf-py/gguf/tensor_mapping.py index a4f185c0658a3..8593a80a5ab8f 100644 --- a/gguf-py/gguf/tensor_mapping.py +++ b/gguf-py/gguf/tensor_mapping.py @@ -396,7 +396,7 @@ class TensorNameMap: "encoder.layers.{bid}.norm2", # nomic-bert "transformer.decoder_layer.{bid}.rms_norm_3", # Grok "encoder.layer.{bid}.mlp.layernorm", # jina-bert-v2 - "encoder.layer.{bid}.layer_norm_2" # jina-v2-code + "encoder.layer.{bid}.layer_norm_2", # jina-v2-code ), MODEL_TENSOR.SSM_IN: ( @@ -429,6 +429,10 @@ class TensorNameMap: "backbone.layers.{bid}.mixer.D", ), + MODEL_TENSOR.SSM_NORM: ( + "backbone.layers.{bid}.mixer.norm", # mamba2 + ), + MODEL_TENSOR.SSM_OUT: ( "model.layers.{bid}.out_proj", "backbone.layers.{bid}.mixer.out_proj", diff --git a/src/llama.cpp b/src/llama.cpp index bd7f1508b2644..5be0ef7a2ac7a 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -198,6 +198,7 @@ enum llm_arch { LLM_ARCH_GEMMA2, LLM_ARCH_STARCODER2, LLM_ARCH_MAMBA, + LLM_ARCH_MAMBA2, LLM_ARCH_XVERSE, LLM_ARCH_COMMAND_R, LLM_ARCH_DBRX, @@ -245,6 +246,7 @@ static const std::map LLM_ARCH_NAMES = { { LLM_ARCH_GEMMA2, "gemma2" }, { LLM_ARCH_STARCODER2, "starcoder2" }, { LLM_ARCH_MAMBA, "mamba" }, + { LLM_ARCH_MAMBA2, "mamba2" }, { LLM_ARCH_XVERSE, "xverse" }, { LLM_ARCH_COMMAND_R, "command-r" }, { LLM_ARCH_DBRX, "dbrx" }, @@ -328,6 +330,7 @@ enum llm_kv { LLM_KV_SSM_CONV_KERNEL, LLM_KV_SSM_STATE_SIZE, LLM_KV_SSM_TIME_STEP_RANK, + LLM_KV_SSM_GROUP_COUNT, LLM_KV_SSM_DT_B_C_RMS, LLM_KV_TOKENIZER_MODEL, @@ -427,7 +430,8 @@ static const std::map LLM_KV_NAMES = { { LLM_KV_SSM_INNER_SIZE, "%s.ssm.inner_size" }, { LLM_KV_SSM_STATE_SIZE, "%s.ssm.state_size" }, { LLM_KV_SSM_TIME_STEP_RANK, "%s.ssm.time_step_rank" }, - { LLM_KV_SSM_DT_B_C_RMS, "%s.ssm.dt_b_c_rms" }, + { LLM_KV_SSM_GROUP_COUNT, "%s.ssm.group_count" }, + { LLM_KV_SSM_DT_B_C_RMS, "%s.ssm.dt_b_c_rms" }, { LLM_KV_TOKENIZER_MODEL, "tokenizer.ggml.model" }, { LLM_KV_TOKENIZER_PRE, "tokenizer.ggml.pre" }, @@ -517,6 +521,7 @@ enum llm_tensor { LLM_TENSOR_SSM_DT, LLM_TENSOR_SSM_A, LLM_TENSOR_SSM_D, + LLM_TENSOR_SSM_NORM, LLM_TENSOR_SSM_OUT, LLM_TENSOR_ATTN_Q_A, LLM_TENSOR_ATTN_Q_B, @@ -1068,6 +1073,22 @@ static const std::map> LLM_TENSOR_NA { LLM_TENSOR_SSM_OUT, "blk.%d.ssm_out" }, }, }, + { + LLM_ARCH_MAMBA2, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_SSM_IN, "blk.%d.ssm_in" }, + { LLM_TENSOR_SSM_CONV1D, "blk.%d.ssm_conv1d" }, + { LLM_TENSOR_SSM_DT, "blk.%d.ssm_dt" }, + { LLM_TENSOR_SSM_A, "blk.%d.ssm_a" }, + { LLM_TENSOR_SSM_D, "blk.%d.ssm_d" }, + { LLM_TENSOR_SSM_NORM, "blk.%d.ssm_norm" }, + { LLM_TENSOR_SSM_OUT, "blk.%d.ssm_out" }, + }, + }, { LLM_ARCH_XVERSE, { @@ -2239,6 +2260,7 @@ struct llama_hparams { uint32_t ssm_d_inner = 0; uint32_t ssm_d_state = 0; uint32_t ssm_dt_rank = 0; + uint32_t ssm_n_group = 0; bool ssm_dt_b_c_rms = false; float f_clamp_kqv = 0.0f; @@ -2289,6 +2311,7 @@ struct llama_hparams { if (this->ssm_d_inner != other.ssm_d_inner) return true; if (this->ssm_d_state != other.ssm_d_state) return true; if (this->ssm_dt_rank != other.ssm_dt_rank) return true; + if (this->ssm_n_group != other.ssm_n_group) return true; if (this->ssm_dt_b_c_rms != other.ssm_dt_b_c_rms) return true; if (this->dec_start_token_id != other.dec_start_token_id) return true; @@ -2357,7 +2380,7 @@ struct llama_hparams { // corresponds to Mamba's conv_states size // TODO: maybe support other convolution strides than 1 // NOTE: since the first column of the conv_state is shifted out each time, it's not actually needed - return (ssm_d_conv > 0 ? ssm_d_conv - 1 : 0) * ssm_d_inner; + return (ssm_d_conv > 0 ? ssm_d_conv - 1 : 0) * (ssm_d_inner + 2*ssm_n_group*ssm_d_state); } uint32_t n_embd_v_s() const { // dimension of the recurrent state embeddings @@ -2419,6 +2442,7 @@ struct llama_layer { struct ggml_tensor * ffn_sub_norm; struct ggml_tensor * attn_norm_cross; struct ggml_tensor * attn_norm_enc; + struct ggml_tensor * ssm_norm; // attention struct ggml_tensor * wq; @@ -5573,6 +5597,38 @@ static void llm_load_hparams( default: model.type = e_model::MODEL_UNKNOWN; } } break; + case LLM_ARCH_MAMBA2: + { + ml.get_key(LLM_KV_SSM_CONV_KERNEL, hparams.ssm_d_conv); + ml.get_key(LLM_KV_SSM_INNER_SIZE, hparams.ssm_d_inner); + ml.get_key(LLM_KV_SSM_STATE_SIZE, hparams.ssm_d_state); + ml.get_key(LLM_KV_SSM_TIME_STEP_RANK, hparams.ssm_dt_rank); + ml.get_key(LLM_KV_SSM_GROUP_COUNT, hparams.ssm_n_group); + + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + + switch (hparams.n_layer) { + case 24: + switch (hparams.n_embd) { + case 768: model.type = e_model::MODEL_SMALL; break; + default: model.type = e_model::MODEL_UNKNOWN; + } break; + case 48: + switch (hparams.n_embd) { + case 1024: model.type = e_model::MODEL_MEDIUM; break; + case 1536: model.type = e_model::MODEL_LARGE; break; + case 2048: model.type = e_model::MODEL_XL; break; + default: model.type = e_model::MODEL_UNKNOWN; + } break; + case 64: + switch (hparams.n_embd) { + case 2560: model.type = e_model::MODEL_3B; break; + case 4096: model.type = e_model::MODEL_7B; break; + default: model.type = e_model::MODEL_UNKNOWN; + } break; + default: model.type = e_model::MODEL_UNKNOWN; + } + } case LLM_ARCH_XVERSE: { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); @@ -6404,6 +6460,7 @@ static void llm_load_print_meta(llama_model_loader & ml, llama_model & model) { LLAMA_LOG_INFO("%s: ssm_d_inner = %u\n", __func__, hparams.ssm_d_inner); LLAMA_LOG_INFO("%s: ssm_d_state = %u\n", __func__, hparams.ssm_d_state); LLAMA_LOG_INFO("%s: ssm_dt_rank = %u\n", __func__, hparams.ssm_dt_rank); + LLAMA_LOG_INFO("%s: ssm_n_group = %u\n", __func__, hparams.ssm_n_group); LLAMA_LOG_INFO("%s: ssm_dt_b_c_rms = %d\n", __func__, hparams.ssm_dt_b_c_rms); } @@ -7639,7 +7696,7 @@ static bool llm_load_tensors( layer.ssm_in = ml.create_tensor(ctx_split, tn(LLM_TENSOR_SSM_IN, "weight", i), {n_embd, 2*d_inner}); - layer.ssm_conv1d = ml.create_tensor(ctx_split, tn(LLM_TENSOR_SSM_CONV1D, "weight", i), {d_conv, d_inner}); + layer.ssm_conv1d = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_SSM_CONV1D, "weight", i), {d_conv, d_inner}); layer.ssm_conv1d_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_SSM_CONV1D, "bias", i), {d_inner}); layer.ssm_x = ml.create_tensor(ctx_split, tn(LLM_TENSOR_SSM_X, "weight", i), {d_inner, dt_rank + 2*d_state}); @@ -7648,9 +7705,61 @@ static bool llm_load_tensors( layer.ssm_dt_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_SSM_DT, "bias", i), {d_inner}); // no "weight" suffix for these - layer.ssm_a = ml.create_tensor(ctx_split, tn(LLM_TENSOR_SSM_A, i), {d_state, d_inner}); + layer.ssm_a = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_SSM_A, i), {d_state, d_inner}); layer.ssm_d = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_SSM_D, i), {d_inner}); + // out_proj + layer.ssm_out = ml.create_tensor(ctx_split, tn(LLM_TENSOR_SSM_OUT, "weight", i), {d_inner, n_embd}); + } + } break; + case LLM_ARCH_MAMBA2: + { + const int64_t d_conv = hparams.ssm_d_conv; + const int64_t d_inner = hparams.ssm_d_inner; + const int64_t d_state = hparams.ssm_d_state; + const int64_t n_head = hparams.ssm_dt_rank; + const int64_t head_dim = n_embd / n_head; + const int64_t n_group = hparams.ssm_n_group; + const int64_t d_in_proj = 2*d_inner + 2*n_group*d_state + n_head; + + // only an expansion factor of 2 is supported for now + GGML_ASSERT(2 * n_embd == d_inner); + + model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}); + + // output + { + model.output_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}); + + model.output = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED); + // if output is NULL, init from the input tok embed, duplicated to allow offloading + if (model.output == NULL) { + model.output = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED); + } + } + + for (int i = 0; i < n_layer; ++i) { + ggml_context * ctx_layer = ctx_for_layer(i); + ggml_context * ctx_split = ctx_for_layer_split(i); + + auto & layer = model.layers[i]; + + // norm + layer.attn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}); + + layer.ssm_in = ml.create_tensor(ctx_split, tn(LLM_TENSOR_SSM_IN, "weight", i), {n_embd, d_in_proj}); + + layer.ssm_conv1d = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_SSM_CONV1D, "weight", i), {d_conv, d_inner + 2*n_group*d_state}); + layer.ssm_conv1d_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_SSM_CONV1D, "bias", i), {d_inner + 2*n_group*d_state}); + + layer.ssm_dt_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_SSM_DT, "bias", i), {n_head}); + + // no "weight" suffix for these + layer.ssm_a = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_SSM_A, i), {n_head}); + layer.ssm_d = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_SSM_D, i), {n_head}); + + layer.ssm_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_SSM_NORM, "weight", i), {d_inner}); + // out_proj layer.ssm_out = ml.create_tensor(ctx_split, tn(LLM_TENSOR_SSM_OUT, "weight", i), {d_inner, n_embd}); } @@ -9041,6 +9150,8 @@ static struct ggml_tensor * llm_build_mamba( const int64_t d_inner = hparams.ssm_d_inner; const int64_t d_state = hparams.ssm_d_state; const int64_t dt_rank = hparams.ssm_dt_rank; + const int64_t n_head = d_inner; + const int64_t head_dim = 1; const int64_t n_seqs = batch.n_seqs; // Some variants of Mamba arch (e.g. FalconMamba do apply layer norm on B and Dt layers) const bool ssm_dt_b_c_rms = hparams.ssm_dt_b_c_rms; @@ -9064,7 +9175,7 @@ static struct ggml_tensor * llm_build_mamba( struct ggml_tensor * ssm = llm_build_copy_mask_state(ctx, graph, ssm_states_all, state_copy, state_mask, hparams.n_embd_v_s(), kv.size, kv_head, n_kv, n_seqs); - ssm = ggml_reshape_3d(ctx, ssm, d_state, d_inner, n_seqs); + ssm = ggml_reshape_4d(ctx, ssm, d_state, head_dim, n_head, n_seqs); // {n_embd, n_tokens} => {n_embd, n_seq_tokens, n_seqs} cur = ggml_reshape_3d(ctx, cur, cur->ne[0], n_seq_tokens, n_seqs); @@ -9113,8 +9224,8 @@ static struct ggml_tensor * llm_build_mamba( struct ggml_tensor * x_db = llm_build_lora_mm(lctx, ctx, model.layers[il].ssm_x, x); // split struct ggml_tensor * dt = ggml_view_3d(ctx, x_db, dt_rank, n_seq_tokens, n_seqs, x_db->nb[1], x_db->nb[2], 0); - struct ggml_tensor * B = ggml_view_3d(ctx, x_db, d_state, n_seq_tokens, n_seqs, x_db->nb[1], x_db->nb[2], ggml_element_size(x_db)*dt_rank); - struct ggml_tensor * C = ggml_view_3d(ctx, x_db, d_state, n_seq_tokens, n_seqs, x_db->nb[1], x_db->nb[2], ggml_element_size(x_db)*(dt_rank+d_state)); + struct ggml_tensor * B = ggml_view_4d(ctx, x_db, d_state, /* n_group */ 1, n_seq_tokens, n_seqs, d_state*x_db->nb[0], x_db->nb[1], x_db->nb[2], ggml_element_size(x_db)*dt_rank); + struct ggml_tensor * C = ggml_view_4d(ctx, x_db, d_state, /* n_group */ 1, n_seq_tokens, n_seqs, d_state*x_db->nb[0], x_db->nb[1], x_db->nb[2], ggml_element_size(x_db)*(dt_rank+d_state)); // Some Mamba variants (e.g. FalconMamba) apply RMS norm in B, C & Dt layers if (ssm_dt_b_c_rms) { @@ -9127,23 +9238,23 @@ static struct ggml_tensor * llm_build_mamba( dt = llm_build_lora_mm(lctx, ctx, model.layers[il].ssm_dt, dt); dt = ggml_add(ctx, dt, model.layers[il].ssm_dt_b); + x = ggml_reshape_4d(ctx, x, head_dim, n_head, n_seq_tokens, n_seqs); + // Custom operator to optimize the parallel associative scan // as described in the Annex D of the Mamba paper. // => {d_inner, n_seq_tokens, n_seqs} and {d_state, d_inner, n_seqs} - struct ggml_tensor * y_ssm = ggml_ssm_scan(ctx, ssm, x, dt, model.layers[il].ssm_a, B, C); + struct ggml_tensor * y_ssm = ggml_ssm_scan(ctx, ssm, x, dt, model.layers[il].ssm_a, B, C, model.layers[il].ssm_d); // store last states ggml_build_forward_expand(graph, ggml_cpy(ctx, - ggml_view_1d(ctx, y_ssm, d_state*d_inner*n_seqs, x->nb[3]), + ggml_view_1d(ctx, y_ssm, d_state*d_inner*n_seqs, x->nb[3]*x->ne[3]), ggml_view_1d(ctx, ssm_states_all, d_state*d_inner*n_seqs, kv_head*d_state*d_inner*ggml_element_size(ssm_states_all)))); - struct ggml_tensor * y = ggml_view_3d(ctx, y_ssm, d_inner, n_seq_tokens, n_seqs, x->nb[1], x->nb[2], 0); + struct ggml_tensor * y = ggml_view_3d(ctx, y_ssm, d_inner, n_seq_tokens, n_seqs, x->nb[2], x->nb[3], 0); // TODO: skip computing output earlier for unused tokens - // {d_inner, n_seq_tokens, n_seqs} * {d_inner} => {d_inner, n_seq_tokens, n_seqs} - y = ggml_add(ctx, y, ggml_mul(ctx, x, model.layers[il].ssm_d)); y = ggml_mul(ctx, y, ggml_silu(ctx, ggml_cont(ctx, z))); // {d_inner, n_embd} @ {d_inner, n_seq_tokens, n_seqs} => {n_embd, n_seq_tokens, n_seqs} @@ -9157,6 +9268,136 @@ static struct ggml_tensor * llm_build_mamba( return cur; } +static struct ggml_tensor * llm_build_mamba2( + struct ggml_context * ctx, + struct llama_context & lctx, + const llama_ubatch & batch, + struct ggml_cgraph * graph, + struct ggml_tensor * cur, + struct ggml_tensor * state_copy, + struct ggml_tensor * state_mask, + int32_t kv_head, + int32_t n_kv, + const llm_build_cb & cb, + int il) { + const llama_model & model = lctx.model; + const llama_hparams & hparams = model.hparams; + const llama_kv_cache & kv = lctx.kv_self; + const int64_t d_conv = hparams.ssm_d_conv; + const int64_t d_inner = hparams.ssm_d_inner; + const int64_t d_state = hparams.ssm_d_state; + const int64_t n_head = hparams.ssm_dt_rank; + const int64_t head_dim = d_inner / n_head; // FIXME + const int64_t n_group = hparams.ssm_n_group; + const int64_t n_seqs = batch.n_seqs; + + const int64_t n_seq_tokens = batch.n_seq_tokens; + + GGML_ASSERT(n_seqs != 0); + GGML_ASSERT(batch.equal_seqs); + GGML_ASSERT(batch.n_tokens == n_seq_tokens * n_seqs); + + struct ggml_tensor * conv_states_all = kv.k_l[il]; + struct ggml_tensor * ssm_states_all = kv.v_l[il]; + + // (ab)using the KV cache to store the states + struct ggml_tensor * conv = llm_build_copy_mask_state(ctx, + graph, conv_states_all, state_copy, state_mask, + hparams.n_embd_k_s(), kv.size, kv_head, n_kv, n_seqs); + conv = ggml_reshape_3d(ctx, conv, d_conv - 1, d_inner + 2*n_group*d_state, n_seqs); + struct ggml_tensor * ssm = llm_build_copy_mask_state(ctx, + graph, ssm_states_all, state_copy, state_mask, + hparams.n_embd_v_s(), kv.size, kv_head, n_kv, n_seqs); + ssm = ggml_reshape_4d(ctx, ssm, d_state, head_dim, n_head, n_seqs); + + // {n_embd, n_tokens} => {n_embd, n_seq_tokens, n_seqs} + cur = ggml_reshape_3d(ctx, cur, cur->ne[0], n_seq_tokens, n_seqs); + + // d_in_proj = 2 * self.d_inner + 2 * self.ngroups * self.d_state + self.nheads + + // {n_embd, d_in_proj} @ {n_embd, n_seq_tokens, n_seqs} => {d_in_proj, n_seq_tokens, n_seqs} + struct ggml_tensor * zxBCdt = llm_build_lora_mm(lctx, ctx, model.layers[il].ssm_in, cur); + + // split the above in three + struct ggml_tensor * z = ggml_view_3d(ctx, zxBCdt, d_inner, n_seq_tokens, n_seqs, zxBCdt->nb[1], zxBCdt->nb[2], 0); + struct ggml_tensor * xBC = ggml_view_3d(ctx, zxBCdt, d_inner + 2*n_group*d_state, n_seq_tokens, n_seqs, zxBCdt->nb[1], zxBCdt->nb[2], d_inner*ggml_element_size(zxBCdt)); + struct ggml_tensor * dt = ggml_view_3d(ctx, zxBCdt, n_head, n_seq_tokens, n_seqs, zxBCdt->nb[1], zxBCdt->nb[2], (2*d_inner + 2*n_group*d_state)*ggml_element_size(zxBCdt)); + + // conv + { + // => {d_conv - 1 + n_seq_tokens, d_inner + 2*n_group*d_state, n_seqs} + struct ggml_tensor * conv_x = ggml_concat(ctx, conv, ggml_transpose(ctx, xBC), 0); + + // copy last (d_conv - 1) columns back into the state cache + struct ggml_tensor * last_conv = ggml_view_3d(ctx, conv_x, d_conv - 1, d_inner + 2*n_group*d_state, n_seqs, conv_x->nb[1], conv_x->nb[2], n_seq_tokens*(conv_x->nb[0])); + + ggml_build_forward_expand(graph, + ggml_cpy(ctx, last_conv, + ggml_view_1d(ctx, conv_states_all, + (d_conv - 1)*(d_inner + 2*n_group*d_state)*(n_seqs), + kv_head*(d_conv - 1)*(d_inner)*ggml_element_size(conv_states_all)))); + + // 1D convolution + // The equivalent is to make a self-overlapping view of conv_x + // over d_conv columns at each stride in the 3rd dimension, + // then element-wise multiply that with the conv1d weight, + // then sum the elements of each row, + // (the last two steps are a dot product over rows (also doable with mul_mat)) + // then permute away the ne[0] dimension, + // and then you're left with the resulting x tensor. + // For simultaneous sequences, all sequences need to have the same length. + xBC = ggml_ssm_conv(ctx, conv_x, model.layers[il].ssm_conv1d); + + // bias + xBC = ggml_add(ctx, xBC, model.layers[il].ssm_conv1d_b); + + xBC = ggml_silu(ctx, xBC); + } + + // ssm + { + // These correspond to V K Q in SSM/attention duality + struct ggml_tensor * x = ggml_view_4d(ctx, xBC, head_dim, n_head, n_seq_tokens, n_seqs, head_dim*xBC->nb[0], xBC->nb[1], xBC->nb[2], 0); + struct ggml_tensor * B = ggml_view_4d(ctx, xBC, d_state, n_group, n_seq_tokens, n_seqs, d_state*xBC->nb[0], xBC->nb[1], xBC->nb[2], d_inner*ggml_element_size(xBC)); + struct ggml_tensor * C = ggml_view_4d(ctx, xBC, d_state, n_group, n_seq_tokens, n_seqs, d_state*xBC->nb[0], xBC->nb[1], xBC->nb[2], (d_inner + n_group*d_state)*ggml_element_size(xBC)); + + // {n_head, n_seq_tokens, n_seqs} + dt = ggml_add(ctx, dt, model.layers[il].ssm_dt_b); + + // TODO: use semistructured matrices to implement state-space duality + // => {d_inner, n_seq_tokens, n_seqs} and {d_state, d_inner, n_seqs} + struct ggml_tensor * y_ssm = ggml_ssm_scan(ctx, ssm, x, dt, model.layers[il].ssm_a, B, C, model.layers[il].ssm_d); + + // store last states + ggml_build_forward_expand(graph, + ggml_cpy(ctx, + ggml_view_1d(ctx, y_ssm, d_state*d_inner*n_seqs, ggml_nelements(x)*x->nb[0]), + ggml_view_1d(ctx, ssm_states_all, d_state*d_inner*n_seqs, kv_head*d_state*d_inner*ggml_element_size(ssm_states_all)))); + + struct ggml_tensor * y = ggml_view_3d(ctx, y_ssm, d_inner, n_seq_tokens, n_seqs, n_head*x->nb[1], n_seq_tokens*n_head*x->nb[1], 0); + + // TODO: skip computing output earlier for unused tokens + + y = ggml_mul(ctx, y, ggml_silu(ctx, ggml_cont(ctx, z))); + + // grouped RMS norm + y = ggml_reshape_4d(ctx, y, d_inner / n_group, n_group, n_seq_tokens, n_seqs); + y = llm_build_norm(ctx, y, hparams, + model.layers[il].ssm_norm, NULL, + LLM_NORM_RMS, cb, il); + y = ggml_reshape_3d(ctx, y, d_inner, n_seq_tokens, n_seqs); + + // {d_inner, n_embd} @ {d_inner, n_seq_tokens, n_seqs} => {n_embd, n_seq_tokens, n_seqs} + cur = llm_build_lora_mm(lctx, ctx, model.layers[il].ssm_out, y); + } + + // {n_embd, n_seq_tokens, n_seqs} => {n_embd, n_tokens} + cur = ggml_reshape_2d(ctx, cur, cur->ne[0], n_seq_tokens * n_seqs); + cb(cur, "mamba_out", il); + + return cur; +} + struct llm_build_context { const llama_model & model; llama_context & lctx; @@ -12788,7 +13029,7 @@ struct llm_build_context { return gf; } - struct ggml_cgraph * build_mamba() { + struct ggml_cgraph * build_mamba(int32_t version = 1) { struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false); struct ggml_tensor * cur; @@ -12807,9 +13048,19 @@ struct llm_build_context { LLM_NORM_RMS, cb, il); cb(cur, "attn_norm", il); - cur = llm_build_mamba(ctx0, lctx, batch, gf, cur, - state_copy, state_mask, - kv_head, n_kv, cb, il); + switch (version) { + case 2: + cur = llm_build_mamba2(ctx0, lctx, batch, gf, cur, + state_copy, state_mask, + kv_head, n_kv, cb, il); + break; + case 1: + default: + cur = llm_build_mamba(ctx0, lctx, batch, gf, cur, + state_copy, state_mask, + kv_head, n_kv, cb, il); + break; + } if (il == n_layer - 1) { // skip computing output for unused tokens @@ -14858,7 +15109,11 @@ static struct ggml_cgraph * llama_build_graph( } break; case LLM_ARCH_MAMBA: { - result = llm.build_mamba(); + result = llm.build_mamba(/* version */ 1); + } break; + case LLM_ARCH_MAMBA2: + { + result = llm.build_mamba(/* version */ 2); } break; case LLM_ARCH_XVERSE: { @@ -17954,6 +18209,7 @@ enum llama_rope_type llama_rope_type(const struct llama_model * model) { case LLM_ARCH_REFACT: case LLM_ARCH_BLOOM: case LLM_ARCH_MAMBA: + case LLM_ARCH_MAMBA2: case LLM_ARCH_JINA_BERT_V2: case LLM_ARCH_T5: case LLM_ARCH_T5ENCODER: @@ -18125,6 +18381,7 @@ llama_token llama_model_decoder_start_token(const struct llama_model * model) { bool llama_model_is_recurrent(const struct llama_model * model) { switch (model->arch) { + case LLM_ARCH_MAMBA2: case LLM_ARCH_MAMBA: return true; default: return false; } From dceff23faec99945d3161d24ea209a0c433546db Mon Sep 17 00:00:00 2001 From: Francis Couture-Harpin Date: Sun, 18 Aug 2024 21:49:39 -0400 Subject: [PATCH 02/37] ggml : SIMD ggml_ssm_scan for Mamba-2 * ggml : improve ggml_mul speed when masking recurrent states --- ggml/src/ggml.c | 95 +++++++++++++++++++++++++++++++++++++++++-------- 1 file changed, 80 insertions(+), 15 deletions(-) diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index 6668209081b6c..f8e708088b357 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -10207,7 +10207,37 @@ static void ggml_compute_forward_mul_f32( GGML_ASSERT( nb0 == sizeof(float)); GGML_ASSERT(nb00 == sizeof(float)); - if (nb10 == sizeof(float)) { + if (ne00 > 1 && ne10 == 1) { + // fast broadcast path + for (int64_t ir = ith; ir < nr; ir += nth) { + // src0 and dst are same shape => same indices + const int64_t i03 = ir/(ne02*ne01); + const int64_t i02 = (ir - i03*ne02*ne01)/ne01; + const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01); + + const int64_t i13 = i03 % ne13; + const int64_t i12 = i02 % ne12; + const int64_t i11 = i01 % ne11; + + float * src1_ptr = (float *) ((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11); + + const float scale = src1_ptr[0]; + + if (scale == 0.0f) { + // NOTE: this also sets NANs to zero, which is not compliant with IEEE754, + // but it is useful when resetting the state of recurrent models. + memset((char *)dst->data + ir*nb1, 0, nb1); + } else { + if (dst->data != src0->data) { + // src0 is same shape as dst => same indices + memcpy((char *)dst->data + ir*nb1, (char *)src0->data + ir*nb01, ne0 * sizeof(float)); + } + if (scale != 1.0f) { + ggml_vec_scale_f32(ne0, (float *) ((char *) dst->data + ir*nb1), scale); + } + } + } + } else if (nb10 == sizeof(float)) { for (int64_t ir = ith; ir < nr; ir += nth) { // src0 and dst are same shape => same indices const int64_t i03 = ir/(ne02*ne01); @@ -15919,23 +15949,56 @@ static void ggml_compute_forward_ssm_scan_f32( const float dt_soft_plus = dt[h] <= 20.0f ? log1pf(expf(dt[h])) : dt[h]; const float dA = expf(dt_soft_plus * A[h]); - // TODO: SIMD implementation // dim for (int i1 = 0; i1 < nr; ++i1) { - const int i = i1 + h*nr; - const float x_dt = x[i] * dt_soft_plus; + const int ii = i1 + h*nr; + const float x_dt = x[ii] * dt_soft_plus; float sumf = 0.0f; +#if defined(GGML_SIMD) + const int np = (nc & ~(GGML_F32_STEP - 1)); + + GGML_F32_VEC sum[GGML_F32_ARR] = { GGML_F32_VEC_ZERO }; + + GGML_F32_VEC adA = GGML_F32_VEC_SET1(dA); + GGML_F32_VEC axdt = GGML_F32_VEC_SET1(x_dt); + + GGML_F32_VEC ax[GGML_F32_ARR]; + GGML_F32_VEC ay[GGML_F32_ARR]; + GGML_F32_VEC az[GGML_F32_ARR]; + + for (int i = 0; i < np; i += GGML_F32_STEP) { + for (int j = 0; j < GGML_F32_ARR; j++) { + ax[j] = GGML_F32_VEC_LOAD(s0 + i + j*GGML_F32_EPR + ii*nc); + ay[j] = GGML_F32_VEC_LOAD(B + i + j*GGML_F32_EPR + (h & (ng - 1))*nc); + az[j] = GGML_F32_VEC_LOAD(C + i + j*GGML_F32_EPR + (h & (ng - 1))*nc); + + ax[j] = GGML_F32_VEC_MUL(ax[j], adA); + ay[j] = GGML_F32_VEC_MUL(ay[j], axdt); + + ax[j] = GGML_F32_VEC_ADD(ax[j], ay[j]); + + sum[j] = GGML_F32_VEC_FMA(sum[j], ax[j], az[j]); + + GGML_F32_VEC_STORE(s + i + j*GGML_F32_EPR + ii*nc, ax[j]); + } + } + + // reduce sum0..sum3 to sum0 + GGML_F32_VEC_REDUCE(sumf, sum); +#else + const int np = 0; +#endif // d_state - for (int i0 = 0; i0 < nc; ++i0) { - const int ii = i0 + i*nc; + for (int i0 = np; i0 < nc; ++i0) { + const int i = i0 + ii*nc; const int ig = i0 + (h & (ng - 1))*nc; // state = prev_state * dA + dB * x - const float state = (s0[ii] * dA) + (B[ig] * x_dt); + const float state = (s0[i] * dA) + (B[ig] * x_dt); // y = rowwise_dotprod(state, C) sumf += state * C[ig]; - s[ii] = state; + s[i] = state; } - y[i] = sumf + x[i] * D[h]; + y[ii] = sumf + x[ii] * D[h]; } } } else { @@ -15948,20 +16011,22 @@ static void ggml_compute_forward_ssm_scan_f32( // dim for (int i1 = 0; i1 < nr; ++i1) { - const int i = i1 + h*nr; - const float x_dt = x[i] * dt_soft_plus; + const int ii = i1 + h*nr; + const float x_dt = x[ii] * dt_soft_plus; float sumf = 0.0f; + // NOTE: can't really use GGML_SIMD here because d_state is usually 16 + // and also because expf is used within the loop. // d_state for (int i0 = 0; i0 < nc; ++i0) { - const int ii = i0 + i*nc; + const int i = i0 + ii*nc; const int ig = i0 + (h & (ng - 1))*nc; // state = prev_state * dA + dB * x - const float state = (s0[ii] * expf(dt_soft_plus * A[i0 + h*nc])) + (B[ig] * x_dt); + const float state = (s0[i] * expf(dt_soft_plus * A[i0 + h*nc])) + (B[ig] * x_dt); // y = rowwise_dotprod(state, C) sumf += state * C[ig]; - s[ii] = state; + s[i] = state; } - y[i] = sumf + x[i] * D[h]; + y[ii] = sumf + x[ii] * D[h]; } } } From 2bfe9de6d3a3598d4b778f9b144bb8ac33c2797b Mon Sep 17 00:00:00 2001 From: Francis Couture-Harpin Date: Sun, 18 Aug 2024 22:43:39 -0400 Subject: [PATCH 03/37] llama : support running Mamba-Codestral-7B-v0.1 --- convert_hf_to_gguf.py | 4 ++++ src/llama.cpp | 2 +- 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index 0ac64574a3043..a5bdd5def2029 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -2843,6 +2843,10 @@ def set_gguf_parameters(self): def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: del bid # unused + if name.startswith("model.backbone") or name.startswith("model.lm_head"): + # map Mamba-Codestral-7B-v0.1 tensor names to the names used by Mamba-2 + name = name.removeprefix("model.") + if name.endswith(".dt_bias"): name = name.rpartition(".dt_bias")[0] + ".dt_proj.bias" diff --git a/src/llama.cpp b/src/llama.cpp index 5be0ef7a2ac7a..fd80361bd7605 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -9383,7 +9383,7 @@ static struct ggml_tensor * llm_build_mamba2( // grouped RMS norm y = ggml_reshape_4d(ctx, y, d_inner / n_group, n_group, n_seq_tokens, n_seqs); y = llm_build_norm(ctx, y, hparams, - model.layers[il].ssm_norm, NULL, + ggml_reshape_2d(ctx, model.layers[il].ssm_norm, d_inner / n_group, n_group), NULL, LLM_NORM_RMS, cb, il); y = ggml_reshape_3d(ctx, y, d_inner, n_seq_tokens, n_seqs); From aff96920f972d8e042dfdef6dc08644cd8df0234 Mon Sep 17 00:00:00 2001 From: Francis Couture-Harpin Date: Wed, 21 Aug 2024 16:28:07 -0400 Subject: [PATCH 04/37] llama : fix Mamba-2 conv state saving * ggml : make the ggml_mul fast broadcast path more consistently formatted --- ggml/src/ggml.c | 4 ++-- src/llama.cpp | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index f8e708088b357..415fa6901304a 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -10226,11 +10226,11 @@ static void ggml_compute_forward_mul_f32( if (scale == 0.0f) { // NOTE: this also sets NANs to zero, which is not compliant with IEEE754, // but it is useful when resetting the state of recurrent models. - memset((char *)dst->data + ir*nb1, 0, nb1); + memset((char *) dst->data + ir*nb1, 0, ne0 * sizeof(float)); } else { if (dst->data != src0->data) { // src0 is same shape as dst => same indices - memcpy((char *)dst->data + ir*nb1, (char *)src0->data + ir*nb01, ne0 * sizeof(float)); + memcpy((char *) dst->data + ir*nb1, (char *) src0->data + ir*nb01, ne0 * sizeof(float)); } if (scale != 1.0f) { ggml_vec_scale_f32(ne0, (float *) ((char *) dst->data + ir*nb1), scale); diff --git a/src/llama.cpp b/src/llama.cpp index fd80361bd7605..03f93164a89e8 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -9335,7 +9335,7 @@ static struct ggml_tensor * llm_build_mamba2( ggml_cpy(ctx, last_conv, ggml_view_1d(ctx, conv_states_all, (d_conv - 1)*(d_inner + 2*n_group*d_state)*(n_seqs), - kv_head*(d_conv - 1)*(d_inner)*ggml_element_size(conv_states_all)))); + kv_head*(d_conv - 1)*(d_inner + 2*n_group*d_state)*ggml_element_size(conv_states_all)))); // 1D convolution // The equivalent is to make a self-overlapping view of conv_x From e04910dc48966f1cbc7309d12b8e1b55bdd33df2 Mon Sep 17 00:00:00 2001 From: Francis Couture-Harpin Date: Wed, 21 Aug 2024 23:06:22 -0400 Subject: [PATCH 05/37] llama : remove unused variable --- src/llama.cpp | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/llama.cpp b/src/llama.cpp index 03f93164a89e8..dda3d51b017d6 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -7718,7 +7718,6 @@ static bool llm_load_tensors( const int64_t d_inner = hparams.ssm_d_inner; const int64_t d_state = hparams.ssm_d_state; const int64_t n_head = hparams.ssm_dt_rank; - const int64_t head_dim = n_embd / n_head; const int64_t n_group = hparams.ssm_n_group; const int64_t d_in_proj = 2*d_inner + 2*n_group*d_state + n_head; @@ -9287,7 +9286,7 @@ static struct ggml_tensor * llm_build_mamba2( const int64_t d_inner = hparams.ssm_d_inner; const int64_t d_state = hparams.ssm_d_state; const int64_t n_head = hparams.ssm_dt_rank; - const int64_t head_dim = d_inner / n_head; // FIXME + const int64_t head_dim = d_inner / n_head; const int64_t n_group = hparams.ssm_n_group; const int64_t n_seqs = batch.n_seqs; From fa358e707132ace9012cb90880abe86fd32464a6 Mon Sep 17 00:00:00 2001 From: Francis Couture-Harpin Date: Thu, 22 Aug 2024 01:13:43 -0400 Subject: [PATCH 06/37] llama : add missing break --- src/llama.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/llama.cpp b/src/llama.cpp index dda3d51b017d6..5b6b6707a1c95 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -5628,7 +5628,7 @@ static void llm_load_hparams( } break; default: model.type = e_model::MODEL_UNKNOWN; } - } + } break; case LLM_ARCH_XVERSE: { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); From 38913dc8ddd1e119df0e0cfcacfb260b9b1f5c02 Mon Sep 17 00:00:00 2001 From: Francis Couture-Harpin Date: Thu, 22 Aug 2024 14:31:12 -0400 Subject: [PATCH 07/37] convert_hf : prefer SentencePiece tokenizer for Mamba-2 when present The tokenzier.json of Mamba-Codestral-7B-v0.1 otherwise requires workarounds to work correctly. --- convert_hf_to_gguf.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index a5bdd5def2029..4851926b7b98f 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -2801,13 +2801,13 @@ def set_vocab(self): vocab_size = -(vocab_size // -pad_vocab) * pad_vocab self.hparams["vocab_size"] = vocab_size - if (self.dir_model / "tokenizer.json").is_file(): - self._set_vocab_gpt2() - elif (self.dir_model / "tokenizer.model").is_file(): + if (self.dir_model / "tokenizer.model").is_file(): self._set_vocab_sentencepiece() elif (self.dir_model / "tokenizer.model.v3").is_file(): # mamba-codestral raise NotImplementedError(f"Please rename {self.dir_model / 'tokenizer.model.v3'} to {self.dir_model / 'tokenizer.model'}") + elif (self.dir_model / "tokenizer.json").is_file(): + self._set_vocab_gpt2() else: # Use the GPT-NeoX tokenizer when no tokenizer files are present self._set_vocab_builtin("gpt-neox", vocab_size) From 273e7a495ad8c93bb9ba8123c1a3de3c68f93cf9 Mon Sep 17 00:00:00 2001 From: Francis Couture-Harpin Date: Mon, 30 Sep 2024 15:52:42 -0400 Subject: [PATCH 08/37] llama : avoid redundant state copy for Mamba 1 and 2 --- ggml/include/ggml.h | 3 +- ggml/src/ggml.c | 50 ++++++------ src/llama.cpp | 154 +++++++++++++++++-------------------- tests/test-backend-ops.cpp | 54 ++++++++++--- 4 files changed, 142 insertions(+), 119 deletions(-) diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h index fec6798ff6d06..1fc53bebebf30 100644 --- a/ggml/include/ggml.h +++ b/ggml/include/ggml.h @@ -1833,7 +1833,8 @@ extern "C" { struct ggml_tensor * A, struct ggml_tensor * B, struct ggml_tensor * C, - struct ggml_tensor * D); + struct ggml_tensor * D, + struct ggml_tensor * ids); // partition into non-overlapping windows with padding if needed // example: diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index 12e4f26942f86..1c4c393e55d06 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -7598,7 +7598,8 @@ struct ggml_tensor * ggml_ssm_scan( struct ggml_tensor * A, struct ggml_tensor * B, struct ggml_tensor * C, - struct ggml_tensor * D) { + struct ggml_tensor * D, + struct ggml_tensor * ids) { GGML_ASSERT(ggml_is_contiguous(s)); GGML_ASSERT(ggml_is_contiguous(dt)); GGML_ASSERT(ggml_is_contiguous(A)); @@ -7609,6 +7610,7 @@ struct ggml_tensor * ggml_ssm_scan( GGML_ASSERT(B->nb[1] == B->ne[0]*B->nb[0]); GGML_ASSERT(C->nb[1] == C->ne[0]*C->nb[0]); GGML_ASSERT(ggml_are_same_shape(B, C)); + GGML_ASSERT(ids->type == GGML_TYPE_I32); { const int64_t d_state = s->ne[0]; @@ -7623,21 +7625,19 @@ struct ggml_tensor * ggml_ssm_scan( GGML_ASSERT(ggml_is_3d(dt)); GGML_ASSERT(s->ne[1] == head_dim); GGML_ASSERT(s->ne[2] == n_head); - GGML_ASSERT(s->ne[3] == n_seqs); GGML_ASSERT(B->ne[0] == d_state); GGML_ASSERT(B->ne[2] == n_seq_tokens); GGML_ASSERT(B->ne[3] == n_seqs); GGML_ASSERT(D->ne[0] == n_head); GGML_ASSERT(ggml_is_vector(D)); + GGML_ASSERT(ids->ne[0] == n_seqs); + GGML_ASSERT(ggml_is_vector(ids)); + GGML_ASSERT(A->ne[1] == n_head); + GGML_ASSERT(ggml_is_matrix(A)); - if (ggml_is_vector(A)) { - // Mamba-2 - GGML_ASSERT(A->ne[0] == n_head); - } else { - // Mamba-1 + if (A->ne[0] != 1) { + // Mamba-1 has more granular decay factors GGML_ASSERT(A->ne[0] == d_state); - GGML_ASSERT(A->ne[1] == n_head); - GGML_ASSERT(ggml_is_matrix(A)); } } @@ -7649,7 +7649,7 @@ struct ggml_tensor * ggml_ssm_scan( } // concatenated y + ssm_states - struct ggml_tensor * result = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, ggml_nelements(x) + ggml_nelements(s)); + struct ggml_tensor * result = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, ggml_nelements(x) + s->ne[0]*s->ne[1]*s->ne[2]*ids->ne[0]); result->op = GGML_OP_SSM_SCAN; result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; @@ -7660,6 +7660,7 @@ struct ggml_tensor * ggml_ssm_scan( result->src[4] = B; result->src[5] = C; result->src[6] = D; + result->src[7] = ids; return result; } @@ -16635,13 +16636,14 @@ static void ggml_compute_forward_ssm_conv( static void ggml_compute_forward_ssm_scan_f32( const struct ggml_compute_params * params, struct ggml_tensor * dst) { - const struct ggml_tensor * src0 = dst->src[0]; // s {d_state, dim, n_head, n_seqs} + const struct ggml_tensor * src0 = dst->src[0]; // s {d_state, dim, n_head, n_seqs+} const struct ggml_tensor * src1 = dst->src[1]; // x {dim, n_head, n_seq_tokens, n_seqs} const struct ggml_tensor * src2 = dst->src[2]; // dt {n_head, n_seq_tokens, n_seqs} - const struct ggml_tensor * src3 = dst->src[3]; // A {d_state, n_head} or {n_head} + const struct ggml_tensor * src3 = dst->src[3]; // A {d_state, n_head} or {1, n_head} const struct ggml_tensor * src4 = dst->src[4]; // B {d_state, n_group, n_seq_tokens, n_seqs} const struct ggml_tensor * src5 = dst->src[5]; // C {d_state, n_group, n_seq_tokens, n_seqs} const struct ggml_tensor * src6 = dst->src[6]; // D {n_head} + const struct ggml_tensor * src7 = dst->src[7]; // ids {n_seqs} const int ith = params->ith; const int nth = params->nth; @@ -16651,11 +16653,12 @@ static void ggml_compute_forward_ssm_scan_f32( const int64_t nh = src1->ne[1]; // n_head const int64_t ng = src4->ne[1]; const int64_t nt = src1->ne[2]; // number of tokens per sequence - const int64_t ns = src0->ne[3]; // number of sequences in the batch + const int64_t ns = src1->ne[3]; // number of sequences in the batch - const int64_t s_off = ggml_element_size(src1) * ggml_nelements(src1); + // can't use ggml_nbytes because src1 is not necessarily contiguous + const int64_t s_off = ggml_nelements(src1) * ggml_element_size(src1); - GGML_ASSERT(ggml_nelements(src1) + ggml_nelements(src0) == ggml_nelements(dst)); + GGML_ASSERT(ggml_nelements(src1) + nc*nr*nh*ns == ggml_nelements(dst)); GGML_ASSERT(src0->nb[0] == sizeof(float)); GGML_ASSERT(src1->nb[0] == sizeof(float)); GGML_ASSERT(src2->nb[0] == sizeof(float)); @@ -16663,6 +16666,7 @@ static void ggml_compute_forward_ssm_scan_f32( GGML_ASSERT(src4->nb[0] == sizeof(float)); GGML_ASSERT(src5->nb[0] == sizeof(float)); GGML_ASSERT(src6->nb[0] == sizeof(float)); + GGML_ASSERT(src7->nb[0] == sizeof(int32_t)); // allows optimizing the modulo since n_group should be a power of 2 GGML_ASSERT((ng & -ng) == ng); @@ -16673,22 +16677,22 @@ static void ggml_compute_forward_ssm_scan_f32( const int ih0 = dh*ith; const int ih1 = MIN(ih0 + dh, nh); + const int32_t * ids = (const int32_t *) src7->data; + for (int i3 = 0; i3 < ns; ++i3) { + const float * s0 = (const float *) ((const char *) src0->data + ids[i3]*(src0->nb[3])); // {d_state, dim, nh, ns} + float * s = ( float *) (( char *) dst->data + i3*(src0->nb[3]) + s_off); // {d_state, dim, nh, ns} + for (int i2 = 0; i2 < nt; ++i2) { - const float * s0 = (const float *) ((const char *) src0->data + i3*(src0->nb[3])); // {d_state, dim, nh, ns} const float * x = (const float *) ((const char *) src1->data + i2*(src1->nb[2]) + i3*(src1->nb[3])); // {dim, nh, nt, ns} const float * dt = (const float *) ((const char *) src2->data + i2*(src2->nb[1]) + i3*(src2->nb[2])); // {nh, nt, ns} - const float * A = (const float *) ((const char *) src3->data); // {d_state, nh} or {nh} + const float * A = (const float *) ((const char *) src3->data); // {d_state, nh} or {1, nh} const float * B = (const float *) ((const char *) src4->data + i2*(src4->nb[2]) + i3*(src4->nb[3])); // {d_state, ng, nt, ns} const float * C = (const float *) ((const char *) src5->data + i2*(src5->nb[2]) + i3*(src5->nb[3])); // {d_state, ng, nt, ns} const float * D = (const float *) ((const char *) src6->data); // {nh} float * y = ( float *) (( char *) dst->data + i2*(nh*nr*sizeof(float)) + i3*(nt*nh*nr*sizeof(float))); // {dim, nh, nt, ns} - float * s = ( float *) (( char *) dst->data + i3*(src0->nb[3]) + s_off); // {d_state, dim, nh, ns} - - // use the output as the source when it's not the first token-wise iteration - if (i2 > 0) { s0 = s; } - if (ggml_is_vector(src3)) { + if (src3->ne[0] == 1) { // Mamba-2 has a scalar decay factor per head; dA can be outside the state-wise loop // n_head @@ -16778,6 +16782,8 @@ static void ggml_compute_forward_ssm_scan_f32( } } } + // use the output as the source when it's not the first token-wise iteration + s0 = s; } } } diff --git a/src/llama.cpp b/src/llama.cpp index c11472112f8fb..3e1f8755ffb85 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -2801,6 +2801,10 @@ struct llama_kv_cache { // computed before each graph build uint32_t n = 0; + // first zero-ed state + // NOTE: only used by recurrent models + int32_t rs_z = -1; + ggml_type type_k = GGML_TYPE_F16; ggml_type type_v = GGML_TYPE_F16; @@ -3381,8 +3385,6 @@ struct llama_context { struct ggml_tensor * inp_mean; // F32 [n_batch, n_batch] struct ggml_tensor * inp_cls; // I32 [n_batch] struct ggml_tensor * inp_s_copy; // I32 [kv_size] - struct ggml_tensor * inp_s_mask; // F32 [1, n_kv] - struct ggml_tensor * inp_s_seq; // I32 [n_kv, n_batch] struct ggml_tensor * inp_pos_bucket; // I32 [n_batch|n_kv, n_batch] struct ggml_tensor * inp_embd_enc; // F32 [n_embd, n_outputs_enc] struct ggml_tensor * inp_KQ_mask_cross; // F32 [n_outputs_enc, n_batch] @@ -3813,6 +3815,15 @@ static bool llama_kv_cache_find_slot( } } + // Find first to-be-cleared cell + cache.rs_z = -1; + for (int i = min; i <= max; ++i) { + if (cache.cells[i].src == -1) { + cache.rs_z = i; + break; + } + } + // allow getting the range of used cells, from head to head + n cache.head = min; cache.n = max - min + 1; @@ -9569,36 +9580,42 @@ static struct ggml_tensor * llm_build_kv( return cur; } -static struct ggml_tensor * llm_build_copy_mask_state( +static struct ggml_tensor * llm_build_rs( struct ggml_context * ctx, struct ggml_cgraph * graph, struct ggml_tensor * s, struct ggml_tensor * state_copy, - struct ggml_tensor * state_mask, + int32_t rs_zero, int32_t n_state, int32_t kv_size, int32_t kv_head, int32_t n_kv, - int32_t n_seqs) { + int32_t n_seqs, + bool avoid_copies = false) { struct ggml_tensor * states = ggml_reshape_2d(ctx, s, n_state, kv_size); - // copy states - // NOTE: assuming the copy destinations are ALL contained between kv_head and kv_head + n_kv - // this shrinks the tensors's ne[1] to n_kv - states = ggml_get_rows(ctx, states, state_copy); - - // clear states of sequences which are starting at the beginning of this batch - // FIXME: zero-out NANs? - states = ggml_mul(ctx, states, state_mask); + // Clear a single state which will then be copied to the other cleared states. + // Note that this is a no-op when the view is zero-sized. + struct ggml_tensor * state_zero = ggml_view_1d(ctx, states, n_state*(rs_zero >= 0), rs_zero*states->nb[1]*(rs_zero >= 0)); + ggml_build_forward_expand(graph, ggml_scale_inplace(ctx, state_zero, 0)); // copy states which won't be changed further (between n_seqs and n_kv) + struct ggml_tensor * states_extra = ggml_get_rows(ctx, states, ggml_view_1d(ctx, state_copy, n_kv - n_seqs, n_seqs*state_copy->nb[0])); ggml_build_forward_expand(graph, ggml_cpy(ctx, - ggml_view_1d(ctx, states, n_state*(n_kv - n_seqs), n_seqs*n_state*ggml_element_size(states)), + states_extra, ggml_view_1d(ctx, s, n_state*(n_kv - n_seqs), (kv_head + n_seqs)*n_state*ggml_element_size(s)))); - // the part of the states that will be used and modified - return ggml_view_2d(ctx, states, n_state, n_seqs, states->nb[1], 0); + if (!avoid_copies) { + // copy states + // NOTE: assuming the copy destinations are ALL contained between kv_head and kv_head + n_kv + // this shrinks the tensors's ne[1] to n_kv + states = ggml_get_rows(ctx, states, ggml_view_1d(ctx, state_copy, n_seqs, 0)); + // the part of the states that will be used and modified + states = ggml_view_2d(ctx, states, n_state, n_seqs, states->nb[1], 0); + } + + return states; } // TODO: split @@ -9609,7 +9626,7 @@ static struct ggml_tensor * llm_build_mamba( struct ggml_cgraph * graph, struct ggml_tensor * cur, struct ggml_tensor * state_copy, - struct ggml_tensor * state_mask, + int32_t rs_zero, int32_t kv_head, int32_t n_kv, const llm_build_cb & cb, @@ -9639,14 +9656,14 @@ static struct ggml_tensor * llm_build_mamba( struct ggml_tensor * ssm_states_all = kv.v_l[il]; // (ab)using the KV cache to store the states - struct ggml_tensor * conv = llm_build_copy_mask_state(ctx, - graph, conv_states_all, state_copy, state_mask, + struct ggml_tensor * conv = llm_build_rs(ctx, + graph, conv_states_all, state_copy, rs_zero, hparams.n_embd_k_s(), kv.size, kv_head, n_kv, n_seqs); conv = ggml_reshape_3d(ctx, conv, d_conv - 1, d_inner, n_seqs); - struct ggml_tensor * ssm = llm_build_copy_mask_state(ctx, - graph, ssm_states_all, state_copy, state_mask, - hparams.n_embd_v_s(), kv.size, kv_head, n_kv, n_seqs); - ssm = ggml_reshape_4d(ctx, ssm, d_state, head_dim, n_head, n_seqs); + struct ggml_tensor * ssm = llm_build_rs(ctx, + graph, ssm_states_all, state_copy, rs_zero, + hparams.n_embd_v_s(), kv.size, kv_head, n_kv, n_seqs, true); + ssm = ggml_reshape_4d(ctx, ssm, d_state, head_dim, n_head, kv.size); // {n_embd, n_tokens} => {n_embd, n_seq_tokens, n_seqs} cur = ggml_reshape_3d(ctx, cur, cur->ne[0], n_seq_tokens, n_seqs); @@ -9711,10 +9728,11 @@ static struct ggml_tensor * llm_build_mamba( x = ggml_reshape_4d(ctx, x, head_dim, n_head, n_seq_tokens, n_seqs); + struct ggml_tensor * ssm_ids = ggml_view_1d(ctx, state_copy, n_seqs, 0); // Custom operator to optimize the parallel associative scan // as described in the Annex D of the Mamba paper. // => {d_inner, n_seq_tokens, n_seqs} and {d_state, d_inner, n_seqs} - struct ggml_tensor * y_ssm = ggml_ssm_scan(ctx, ssm, x, dt, model.layers[il].ssm_a, B, C, model.layers[il].ssm_d); + struct ggml_tensor * y_ssm = ggml_ssm_scan(ctx, ssm, x, dt, model.layers[il].ssm_a, B, C, model.layers[il].ssm_d, ssm_ids); // store last states ggml_build_forward_expand(graph, @@ -9746,7 +9764,7 @@ static struct ggml_tensor * llm_build_mamba2( struct ggml_cgraph * graph, struct ggml_tensor * cur, struct ggml_tensor * state_copy, - struct ggml_tensor * state_mask, + int32_t rs_zero, int32_t kv_head, int32_t n_kv, const llm_build_cb & cb, @@ -9772,14 +9790,14 @@ static struct ggml_tensor * llm_build_mamba2( struct ggml_tensor * ssm_states_all = kv.v_l[il]; // (ab)using the KV cache to store the states - struct ggml_tensor * conv = llm_build_copy_mask_state(ctx, - graph, conv_states_all, state_copy, state_mask, + struct ggml_tensor * conv = llm_build_rs(ctx, + graph, conv_states_all, state_copy, rs_zero, hparams.n_embd_k_s(), kv.size, kv_head, n_kv, n_seqs); conv = ggml_reshape_3d(ctx, conv, d_conv - 1, d_inner + 2*n_group*d_state, n_seqs); - struct ggml_tensor * ssm = llm_build_copy_mask_state(ctx, - graph, ssm_states_all, state_copy, state_mask, - hparams.n_embd_v_s(), kv.size, kv_head, n_kv, n_seqs); - ssm = ggml_reshape_4d(ctx, ssm, d_state, head_dim, n_head, n_seqs); + struct ggml_tensor * ssm = llm_build_rs(ctx, + graph, ssm_states_all, state_copy, rs_zero, + hparams.n_embd_v_s(), kv.size, kv_head, n_kv, n_seqs, true); + ssm = ggml_reshape_4d(ctx, ssm, d_state, head_dim, n_head, kv.size); // {n_embd, n_tokens} => {n_embd, n_seq_tokens, n_seqs} cur = ggml_reshape_3d(ctx, cur, cur->ne[0], n_seq_tokens, n_seqs); @@ -9835,9 +9853,12 @@ static struct ggml_tensor * llm_build_mamba2( // {n_head, n_seq_tokens, n_seqs} dt = ggml_add(ctx, dt, model.layers[il].ssm_dt_b); + struct ggml_tensor * ssm_ids = ggml_view_1d(ctx, state_copy, n_seqs, 0); + // Use the same shape semantics for A as Mamba-1 + struct ggml_tensor * A = ggml_reshape_2d(ctx, model.layers[il].ssm_a, 1, n_head); // TODO: use semistructured matrices to implement state-space duality // => {d_inner, n_seq_tokens, n_seqs} and {d_state, d_inner, n_seqs} - struct ggml_tensor * y_ssm = ggml_ssm_scan(ctx, ssm, x, dt, model.layers[il].ssm_a, B, C, model.layers[il].ssm_d); + struct ggml_tensor * y_ssm = ggml_ssm_scan(ctx, ssm, x, dt, A, B, C, model.layers[il].ssm_d, ssm_ids); // store last states ggml_build_forward_expand(graph, @@ -10069,6 +10090,7 @@ struct llm_build_context { const int32_t n_outputs; const int32_t n_outputs_enc; const int32_t kv_head; // index of where we store new KV data in the cache + const int32_t rs_zero; // the first zero-ed recurrent state const int32_t n_ctx_orig; const bool flash_attn; @@ -10119,6 +10141,7 @@ struct llm_build_context { n_outputs (worst_case ? n_tokens : lctx.n_outputs), n_outputs_enc (worst_case ? n_tokens : lctx.embd_enc.size() / hparams.n_embd), kv_head (worst_case ? (kv_self.recurrent ? 0 : kv_self.size - n_tokens) : kv_self.head), + rs_zero (kv_self.rs_z), n_ctx_orig (cparams.n_ctx_orig_yarn), flash_attn (cparams.flash_attn), pooling_type (cparams.pooling_type), @@ -10147,8 +10170,6 @@ struct llm_build_context { lctx.inp_mean = nullptr; lctx.inp_cls = nullptr; lctx.inp_s_copy = nullptr; - lctx.inp_s_mask = nullptr; - lctx.inp_s_seq = nullptr; lctx.inp_pos_bucket = nullptr; lctx.inp_embd_enc = nullptr; lctx.inp_KQ_mask_cross = nullptr; @@ -10332,13 +10353,6 @@ struct llm_build_context { return lctx.inp_s_copy; } - struct ggml_tensor * build_inp_s_mask() { - lctx.inp_s_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, 1, n_kv); - cb(lctx.inp_s_mask, "inp_s_mask", -1); - ggml_set_input(lctx.inp_s_mask); - return lctx.inp_s_mask; - } - struct ggml_cgraph * append_pooling(struct ggml_cgraph * gf) { // find result_norm tensor for input struct ggml_tensor * inp = nullptr; @@ -13901,7 +13915,6 @@ struct llm_build_context { inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb); struct ggml_tensor * state_copy = build_inp_s_copy(); - struct ggml_tensor * state_mask = build_inp_s_mask(); for (int il = 0; il < n_layer; ++il) { // norm @@ -13912,15 +13925,13 @@ struct llm_build_context { switch (version) { case 2: - cur = llm_build_mamba2(ctx0, lctx, batch, gf, cur, - state_copy, state_mask, - kv_head, n_kv, cb, il); + cur = llm_build_mamba2(ctx0, lctx, batch, gf, cur, state_copy, + rs_zero, kv_head, n_kv, cb, il); break; case 1: default: - cur = llm_build_mamba(ctx0, lctx, batch, gf, cur, - state_copy, state_mask, - kv_head, n_kv, cb, il); + cur = llm_build_mamba(ctx0, lctx, batch, gf, cur, state_copy, + rs_zero, kv_head, n_kv, cb, il); break; } @@ -15946,7 +15957,6 @@ struct llm_build_context { struct ggml_tensor * cur; struct ggml_tensor * inpL; struct ggml_tensor * state_copy = build_inp_s_copy(); - struct ggml_tensor * state_mask = build_inp_s_mask(); inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb); inpL = llm_build_norm(ctx0, inpL, hparams, model.tok_norm, model.tok_norm_b, LLM_NORM, cb, -1); @@ -15955,11 +15965,11 @@ struct llm_build_context { const llama_layer * layer = &model.layers[il]; // (ab)using the KV cache to store the states - struct ggml_tensor * token_shift = llm_build_copy_mask_state(ctx0, - gf, kv_self.k_l[il], state_copy, state_mask, + struct ggml_tensor * token_shift = llm_build_rs(ctx0, + gf, kv_self.k_l[il], state_copy, rs_zero, hparams.n_embd_k_s(), kv_self.size, kv_head, n_kv, n_seqs); - struct ggml_tensor * wkv_states = llm_build_copy_mask_state(ctx0, - gf, kv_self.v_l[il], state_copy, state_mask, + struct ggml_tensor * wkv_states = llm_build_rs(ctx0, + gf, kv_self.v_l[il], state_copy, rs_zero, hparams.n_embd_v_s(), kv_self.size, kv_head, n_kv, n_seqs); cur = ggml_reshape_3d(ctx0, inpL, n_embd, n_seq_tokens, n_seqs); @@ -16329,18 +16339,6 @@ static void llama_set_k_shift(llama_context & lctx) { } } -static void llama_set_s_copy(llama_context & lctx) { - const int64_t kv_size = lctx.kv_self.size; - - assert(ggml_backend_buffer_is_host(lctx.inp_s_copy->buffer)); - - int32_t * data = (int32_t *) lctx.inp_s_copy->data; - - for (int i = 0; i < kv_size; ++i) { - data[i] = lctx.kv_self.cells[i].src; - } -} - static int32_t llama_relative_position_bucket(llama_pos x, llama_pos y, uint64_t n_buckets, bool bidirectional) { // TODO move to hparams if a T5 variant appears that uses a different value const int64_t max_distance = 128; @@ -16656,24 +16654,6 @@ static void llama_set_inputs(llama_context & lctx, const llama_ubatch & batch) { if (kv_self.recurrent) { const int64_t n_kv = kv_self.n; - if (lctx.inp_s_mask) { - GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_s_mask->buffer)); - float * data = (float *) lctx.inp_s_mask->data; - - // clear unused states - for (int i = 0; i < n_kv; ++i) { - const uint32_t cell_id = i + kv_self.head; - llama_kv_cell & kv_cell = lctx.kv_self.cells[cell_id]; - - data[i] = (float) (kv_cell.src >= 0); - - // only clear once - if (kv_cell.src < 0) { - kv_cell.src = cell_id; - } - } - } - if (lctx.inp_s_copy) { GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_s_copy->buffer)); int32_t * data = (int32_t *) lctx.inp_s_copy->data; @@ -16683,8 +16663,12 @@ static void llama_set_inputs(llama_context & lctx, const llama_ubatch & batch) { const uint32_t cell_id = i + kv_self.head; llama_kv_cell & kv_cell = lctx.kv_self.cells[cell_id]; - // prevent out-of-bound sources - if (kv_cell.src < 0 || (uint32_t) kv_cell.src >= kv_self.size) { + if (kv_cell.src < 0) { + GGML_ASSERT(kv_self.rs_z >= 0); // Need a valid zero-ed cell as a source + kv_cell.src = kv_self.rs_z; + } + if ((uint32_t) kv_cell.src >= kv_self.size) { + // ignore out-of-bound sources kv_cell.src = cell_id; } diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index aa7896defdad0..092639eed42e1 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -1530,27 +1530,58 @@ struct test_ssm_scan : public test_case { const int64_t d_state; const int64_t d_inner; + const int64_t n_head; + const int64_t n_group; const int64_t n_seq_tokens; const int64_t n_seqs; std::string vars() override { - return VARS_TO_STR5(type, d_state, d_inner, n_seq_tokens, n_seqs); + return VARS_TO_STR7(type, d_state, d_inner, n_head, n_group, n_seq_tokens, n_seqs); } test_ssm_scan(ggml_type type = GGML_TYPE_F32, - int64_t d_state = 32, int64_t d_inner = 32, int64_t n_seq_tokens = 32, int64_t n_seqs = 32) - : type(type), d_state(d_state), d_inner(d_inner), n_seq_tokens(n_seq_tokens), n_seqs(n_seqs) {} + int64_t d_state = 32, + int64_t d_inner = 1, // non-zero for Mamba-2 + int64_t n_head = 32, + int64_t n_group = 1, + int64_t n_seq_tokens = 32, + int64_t n_seqs = 32) + : type(type), d_state(d_state), d_inner(d_inner), n_head(n_head), n_group(n_group), n_seq_tokens(n_seq_tokens), n_seqs(n_seqs) {} ggml_tensor * build_graph(ggml_context * ctx) override { - ggml_tensor * s = ggml_new_tensor(ctx, type, 4, std::vector{ d_state, d_inner, n_seqs, 1 }.data()); - ggml_tensor * x = ggml_new_tensor(ctx, type, 4, std::vector{ d_inner, n_seq_tokens, n_seqs, 1 }.data()); - ggml_tensor * dt = ggml_new_tensor(ctx, type, 4, std::vector{ d_inner, n_seq_tokens, n_seqs, 1 }.data()); - ggml_tensor * A = ggml_new_tensor(ctx, type, 4, std::vector{ d_state, d_inner, 1 , 1 }.data()); - ggml_tensor * B = ggml_new_tensor(ctx, type, 4, std::vector{ d_state, n_seq_tokens, n_seqs, 1 }.data()); - ggml_tensor * C = ggml_new_tensor(ctx, type, 4, std::vector{ d_state, n_seq_tokens, n_seqs, 1 }.data()); - ggml_tensor * out = ggml_ssm_scan(ctx, s, x, dt, A, B, C); + ggml_tensor * s = ggml_new_tensor_4d(ctx, type, d_state, d_inner, n_head, n_seqs); + ggml_tensor * x = ggml_new_tensor_4d(ctx, type, d_inner, n_head, n_seq_tokens, n_seqs); + ggml_tensor * dt = ggml_new_tensor_3d(ctx, type, n_head, n_seq_tokens, n_seqs); + ggml_tensor * A = ggml_new_tensor_2d(ctx, type, (d_inner > 1) ? 1 : d_state, n_head); + ggml_tensor * B = ggml_new_tensor_4d(ctx, type, d_state, n_group, n_seq_tokens, n_seqs); + ggml_tensor * C = ggml_new_tensor_4d(ctx, type, d_state, n_group, n_seq_tokens, n_seqs); + ggml_tensor * D = ggml_new_tensor_1d(ctx, type, n_head); + ggml_tensor * ids = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, n_seqs); + ggml_tensor * out = ggml_ssm_scan(ctx, s, x, dt, A, B, C, D, ids); return out; } + + // similar to test_mul_mat_id + void initialize_tensors(ggml_context * ctx) override { + std::random_device rd; + std::default_random_engine rng(rd()); + for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) { + if (t->type == GGML_TYPE_I32) { + if (ggml_is_view_op(t->op)) { continue; } + // ids + for (int64_t r = 0; r < ggml_nrows(t); r++) { + std::vector data(t->ne[0]); + for (int i = 0; i < t->ne[0]; i++) { + data[i] = i; + } + std::shuffle(data.begin(), data.end(), rng); + ggml_backend_tensor_set(t, data.data(), r * t->nb[1], t->ne[0] * sizeof(int32_t)); + } + } else { + init_tensor_uniform(t); + } + } + } }; // GGML_OP_MUL_MAT @@ -3255,7 +3286,8 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op test_cases.emplace_back(new test_ssm_conv(GGML_TYPE_F32, {8, 1536, 1, 1}, {4, 1536, 1, 1})); test_cases.emplace_back(new test_ssm_conv(GGML_TYPE_F32, {4, 1536, 4, 1}, {4, 1536, 1, 1})); - test_cases.emplace_back(new test_ssm_scan(GGML_TYPE_F32, 16, 1024, 32, 4)); + test_cases.emplace_back(new test_ssm_scan(GGML_TYPE_F32, 16, 1, 1024, 1, 32, 4)); // Mamba-1 + test_cases.emplace_back(new test_ssm_scan(GGML_TYPE_F32, 128, 32, 32, 2, 32, 4)); // Mamba-2 #if 1 for (ggml_type type_a : base_types) { From 2c77d799f9387f5971289139aaca23b4ce37c435 Mon Sep 17 00:00:00 2001 From: Francis Couture-Harpin Date: Wed, 2 Oct 2024 10:36:22 -0400 Subject: [PATCH 09/37] metal : attempt to adapt SSM_SCAN for Mamba-2 --- ggml/src/ggml-metal.m | 107 ++++++++++++++++++++-------- ggml/src/ggml-metal.metal | 146 ++++++++++++++++++++++++++++++++------ 2 files changed, 202 insertions(+), 51 deletions(-) diff --git a/ggml/src/ggml-metal.m b/ggml/src/ggml-metal.m index 9da08fe2e9771..5d5b98307d264 100644 --- a/ggml/src/ggml-metal.m +++ b/ggml/src/ggml-metal.m @@ -95,6 +95,7 @@ GGML_METAL_KERNEL_TYPE_NORM, GGML_METAL_KERNEL_TYPE_SSM_CONV_F32, GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32, + GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32_GROUP, GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32, GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F16, GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32, @@ -591,6 +592,7 @@ static void ggml_metal_log(enum ggml_log_level level, const char * format, ...){ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_NORM, norm, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SSM_CONV_F32, ssm_conv_f32, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32, ssm_scan_f32, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32_GROUP, ssm_scan_f32_group, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32, mul_mv_f32_f32, ctx->support_simdgroup_reduction); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F16, mul_mv_f16_f16, ctx->support_simdgroup_reduction); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32, mul_mv_f16_f32, ctx->support_simdgroup_reduction); @@ -1629,47 +1631,74 @@ static void ggml_metal_encode_node( struct ggml_tensor * src3 = node->src[3]; struct ggml_tensor * src4 = node->src[4]; struct ggml_tensor * src5 = node->src[5]; + struct ggml_tensor * src6 = node->src[6]; + struct ggml_tensor * src7 = node->src[7]; GGML_ASSERT(src3); GGML_ASSERT(src4); GGML_ASSERT(src5); + GGML_ASSERT(src6); + GGML_ASSERT(src7); size_t offs_src3 = 0; size_t offs_src4 = 0; size_t offs_src5 = 0; + size_t offs_src6 = 0; + size_t offs_src7 = 0; id id_src3 = src3 ? ggml_metal_get_buffer(src3, &offs_src3) : nil; id id_src4 = src4 ? ggml_metal_get_buffer(src4, &offs_src4) : nil; id id_src5 = src5 ? ggml_metal_get_buffer(src5, &offs_src5) : nil; + id id_src6 = src6 ? ggml_metal_get_buffer(src6, &offs_src6) : nil; + id id_src7 = src7 ? ggml_metal_get_buffer(src7, &offs_src7) : nil; - const int64_t ne30 = src3->ne[0]; GGML_UNUSED(ne30); + const int64_t ne30 = src3->ne[0]; const int64_t ne31 = src3->ne[1]; GGML_UNUSED(ne31); const uint64_t nb30 = src3->nb[0]; const uint64_t nb31 = src3->nb[1]; const int64_t ne40 = src4->ne[0]; GGML_UNUSED(ne40); - const int64_t ne41 = src4->ne[1]; GGML_UNUSED(ne41); + const int64_t ne41 = src4->ne[1]; const int64_t ne42 = src4->ne[2]; GGML_UNUSED(ne42); + const int64_t ne43 = src4->ne[3]; GGML_UNUSED(ne43); const uint64_t nb40 = src4->nb[0]; const uint64_t nb41 = src4->nb[1]; const uint64_t nb42 = src4->nb[2]; + const uint64_t nb43 = src4->nb[3]; const int64_t ne50 = src5->ne[0]; GGML_UNUSED(ne50); const int64_t ne51 = src5->ne[1]; GGML_UNUSED(ne51); const int64_t ne52 = src5->ne[2]; GGML_UNUSED(ne52); + const int64_t ne53 = src5->ne[3]; GGML_UNUSED(ne53); const uint64_t nb50 = src5->nb[0]; const uint64_t nb51 = src5->nb[1]; const uint64_t nb52 = src5->nb[2]; + const uint64_t nb53 = src5->nb[3]; + + const int64_t ne60 = src6->ne[0]; GGML_UNUSED(ne60); + + const uint64_t nb60 = src6->nb[0]; + + const int64_t ne70 = src7->ne[0]; GGML_UNUSED(ne70); + + const uint64_t nb70 = src7->nb[0]; const int64_t d_state = ne00; const int64_t d_inner = ne01; + const int64_t n_head = ne02; + const int64_t n_group = ne41; const int64_t n_seq_tokens = ne11; - const int64_t n_seqs = ne02; + const int64_t n_seqs = ne13; - id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32].pipeline; + if (ne30 == 1) { + // Mamba-2 + id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32_GROUP].pipeline; + } else { + id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32].pipeline; + } [encoder setComputePipelineState:pipeline]; [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; @@ -1678,33 +1707,49 @@ static void ggml_metal_encode_node( [encoder setBuffer:id_src3 offset:offs_src3 atIndex:3]; [encoder setBuffer:id_src4 offset:offs_src4 atIndex:4]; [encoder setBuffer:id_src5 offset:offs_src5 atIndex:5]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:6]; - - [encoder setBytes:&d_state length:sizeof(d_state) atIndex:7]; - [encoder setBytes:&d_inner length:sizeof(d_inner) atIndex:8]; - [encoder setBytes:&n_seq_tokens length:sizeof(n_seq_tokens) atIndex:9]; - [encoder setBytes:&n_seqs length:sizeof(n_seqs) atIndex:10]; - - [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:11]; - [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:12]; - [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:13]; - [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:14]; - [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:15]; - [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:16]; - [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:17]; - [encoder setBytes:&nb20 length:sizeof(nb20) atIndex:18]; - [encoder setBytes:&nb21 length:sizeof(nb21) atIndex:19]; - [encoder setBytes:&nb22 length:sizeof(nb22) atIndex:20]; - [encoder setBytes:&nb30 length:sizeof(nb30) atIndex:21]; - [encoder setBytes:&nb31 length:sizeof(nb31) atIndex:22]; - [encoder setBytes:&nb40 length:sizeof(nb40) atIndex:23]; - [encoder setBytes:&nb41 length:sizeof(nb41) atIndex:24]; - [encoder setBytes:&nb42 length:sizeof(nb42) atIndex:25]; - [encoder setBytes:&nb50 length:sizeof(nb50) atIndex:26]; - [encoder setBytes:&nb51 length:sizeof(nb51) atIndex:27]; - [encoder setBytes:&nb52 length:sizeof(nb52) atIndex:28]; - - [encoder dispatchThreadgroups:MTLSizeMake(d_inner, n_seqs, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; + [encoder setBuffer:id_src6 offset:offs_src6 atIndex:6]; + [encoder setBuffer:id_src7 offset:offs_src7 atIndex:7]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:8]; + + [encoder setBytes:&d_state length:sizeof(d_state) atIndex:9]; + [encoder setBytes:&d_inner length:sizeof(d_inner) atIndex:10]; + [encoder setBytes:&n_head length:sizeof(n_head) atIndex:11]; + [encoder setBytes:&n_group length:sizeof(n_group) atIndex:12]; + [encoder setBytes:&n_seq_tokens length:sizeof(n_seq_tokens) atIndex:13]; + [encoder setBytes:&n_seqs length:sizeof(n_seqs) atIndex:14]; + + [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:15]; + [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:16]; + [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:17]; + [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:18]; + [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:19]; + [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:20]; + [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:21]; + [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:22]; + [encoder setBytes:&nb20 length:sizeof(nb20) atIndex:23]; + [encoder setBytes:&nb21 length:sizeof(nb21) atIndex:24]; + [encoder setBytes:&nb22 length:sizeof(nb22) atIndex:25]; + [encoder setBytes:&nb23 length:sizeof(nb23) atIndex:26]; + [encoder setBytes:&nb30 length:sizeof(nb30) atIndex:27]; + [encoder setBytes:&nb31 length:sizeof(nb31) atIndex:28]; + [encoder setBytes:&nb40 length:sizeof(nb40) atIndex:29]; + [encoder setBytes:&nb41 length:sizeof(nb41) atIndex:30]; + [encoder setBytes:&nb42 length:sizeof(nb42) atIndex:31]; + [encoder setBytes:&nb43 length:sizeof(nb43) atIndex:32]; + [encoder setBytes:&nb50 length:sizeof(nb50) atIndex:33]; + [encoder setBytes:&nb51 length:sizeof(nb51) atIndex:34]; + [encoder setBytes:&nb52 length:sizeof(nb52) atIndex:35]; + [encoder setBytes:&nb53 length:sizeof(nb53) atIndex:36]; + [encoder setBytes:&nb60 length:sizeof(nb60) atIndex:37]; + [encoder setBytes:&nb70 length:sizeof(nb70) atIndex:38]; + + if (ne30 == 1) { + // Mamba-2 + [encoder dispatchThreadgroups:MTLSizeMake(d_inner, n_head, n_seqs) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; + } else { + GGML_ASSERT(d_inner == 1); + [encoder dispatchThreadgroups:MTLSizeMake(n_head, n_seqs, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; + } } break; case GGML_OP_MUL_MAT: { diff --git a/ggml/src/ggml-metal.metal b/ggml/src/ggml-metal.metal index 2b200032394b1..c75fa25c34e7d 100644 --- a/ggml/src/ggml-metal.metal +++ b/ggml/src/ggml-metal.metal @@ -795,7 +795,7 @@ kernel void kernel_ssm_conv_f32( x[0] = sumf; } -// ref: ggml.c:ggml_compute_forward_ssm_scan_f32 +// ref: ggml.c:ggml_compute_forward_ssm_scan_f32, Mamba-1 part // TODO: optimize kernel void kernel_ssm_scan_f32( device const void * src0, @@ -804,14 +804,19 @@ kernel void kernel_ssm_scan_f32( device const void * src3, device const void * src4, device const void * src5, + device const void * src6, + device const void * src7, device float * dst, constant int64_t & d_state, constant int64_t & d_inner, + constant int64_t & n_head, + constant int64_t & n_group, constant int64_t & n_seq_tokens, constant int64_t & n_seqs, constant uint64_t & nb00, constant uint64_t & nb01, constant uint64_t & nb02, + constant uint64_t & nb03, constant uint64_t & nb10, constant uint64_t & nb11, constant uint64_t & nb12, @@ -824,47 +829,148 @@ kernel void kernel_ssm_scan_f32( constant uint64_t & nb40, constant uint64_t & nb41, constant uint64_t & nb42, + constant uint64_t & nb43, constant uint64_t & nb50, constant uint64_t & nb51, constant uint64_t & nb52, + constant uint64_t & nb53, + constant uint64_t & nb60, + constant uint64_t & nb70, uint3 tgpig[[threadgroup_position_in_grid]], uint3 tpitg[[thread_position_in_threadgroup]], uint3 ntg[[threads_per_threadgroup]]) { - const int64_t ir = tgpig.x; - const int64_t i3 = tgpig.y; + const int64_t i1 = 0; + const int64_t ir = tgpig.x; // current head + const int64_t i3 = tgpig.y; // current seq const int64_t nc = d_state; const int64_t nr = d_inner; + const int64_t nh = n_head; + const int64_t ng = n_group; const int64_t n_t = n_seq_tokens; const int64_t n_s = n_seqs; + const int64_t s_off = d_inner * n_head * n_seq_tokens * n_seqs * sizeof(float); + + device const int32_t * ids = (device const int32_t *) src7; + + device const float * s0 = (device const float *) ((device const char *) src0 + ir*nb01 + ids[i3]*nb03); + device float * s = (device float *) ((device char *) dst + ir*nb01 + i3*nb03 + s_off); + for (int64_t i2 = 0; i2 < n_t; ++i2) { - device const float * s0 = (device const float *) ((device const char *) src0 + ir*nb01 + i3*nb02); - device const float * x = (device const float *) ((device const char *) src1 + ir*nb10 + i2*nb11 + i3*nb12); - device const float * dt = (device const float *) ((device const char *) src2 + ir*nb20 + i2*nb21 + i3*nb22); - device const float * A = (device const float *) ((device const char *) src3 + ir*nb31); - device const float * B = (device const float *) ((device const char *) src4 + i2*nb41 + i3*nb42); - device const float * C = (device const float *) ((device const char *) src5 + i2*nb51 + i3*nb52); - device float * y = (device float *) ((device char *) dst + ir*nb10 + i2*nb11 + i3*nb12); // TODO: do not use src1 strides - device float * s = (device float *) ((device char *) dst + ir*nb01 + i3*nb02 + nb13); - - if (i2 > 0) { - s0 = s; + device const float * x = (device const float *) ((device const char *) src1 + i1*nb10 + ir*nb11 + i2*nb12 + i3*nb13); // {dim, nh, nt, ns} + device const float * dt = (device const float *) ((device const char *) src2 + ir*nb20 + i2*nb21 + i3*nb22); // {nh, nt, ns} + device const float * A = (device const float *) ((device const char *) src3 + ir*nb31); // {d_state, nh} + device const float * B = (device const float *) ((device const char *) src4 + (ir & (ng - 1))*nb41 + i2*nb42 + i3*nb43); // {d_state, ng, nt, ns} + device const float * C = (device const float *) ((device const char *) src5 + (ir & (ng - 1))*nb51 + i2*nb52 + i3*nb53); // {d_state, ng, nt, ns} + device const float * D = (device const float *) ((device const char *) src6 + ir*nb60); // {nh} + device float * y = (device float *) ((device char *) dst + (i1 + ir*(nr) + i2*(nh*nr) + i3*(n_t*nh*nr))*sizeof(float)); // {dim, nh, nt, ns} + + const float dt_soft_plus = dt[0] <= 20.0f ? log1pf(expf(dt[0])) : dt[0]; + const float x_dt = x[0] * dt_soft_plus; + float sumf = 0.0f; + + for (int64_t i0 = 0; i0 < nc; ++i0) { + const int64_t i = i0 + i1*nc; + const float state = (s0[i] * expf(dt_soft_plus * A[i0])) + (B[i0] * x_dt); + sumf += state * C[i0]; + s[i] = state; } - // i1 == 0 - float dt_soft_plus = dt[0] <= 20.0f ? log(1.0f + exp(dt[0])) : dt[0]; - float x_dt = x[0] * dt_soft_plus; + y[0] = sumf + x[0] * D[0]; + + // recurse + s0 = s; + } +} + +// ref: ggml.c:ggml_compute_forward_ssm_scan_f32, Mamba-2 part +// TODO: optimize (e.g. by parallelizing over d_state) +kernel void kernel_ssm_scan_f32_group( + device const void * src0, + device const void * src1, + device const void * src2, + device const void * src3, + device const void * src4, + device const void * src5, + device const void * src6, + device const void * src7, + device float * dst, + constant int64_t & d_state, + constant int64_t & d_inner, + constant int64_t & n_head, + constant int64_t & n_group, + constant int64_t & n_seq_tokens, + constant int64_t & n_seqs, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant uint64_t & nb03, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant uint64_t & nb13, + constant uint64_t & nb20, + constant uint64_t & nb21, + constant uint64_t & nb22, + constant uint64_t & nb30, + constant uint64_t & nb31, + constant uint64_t & nb40, + constant uint64_t & nb41, + constant uint64_t & nb42, + constant uint64_t & nb43, + constant uint64_t & nb50, + constant uint64_t & nb51, + constant uint64_t & nb52, + constant uint64_t & nb53, + constant uint64_t & nb60, + constant uint64_t & nb70, + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tpitg[[thread_position_in_threadgroup]], + uint3 ntg[[threads_per_threadgroup]]) { + const int64_t i1 = tgpig.x; + const int64_t ir = tgpig.y; // current head + const int64_t i3 = tgpig.z; // current seq + + const int64_t nc = d_state; + const int64_t nr = d_inner; + const int64_t nh = n_head; + const int64_t ng = n_group; + const int64_t n_t = n_seq_tokens; + const int64_t n_s = n_seqs; + + const int64_t s_off = d_inner * n_head * n_seq_tokens * n_seqs * sizeof(float); + + device const int32_t * ids = (device const int32_t *) src7; + + device const float * s0 = (device const float *) ((device const char *) src0 + ir*nb01 + ids[i3]*nb03); + device float * s = (device float *) ((device char *) dst + ir*nb01 + i3*nb03 + s_off); + + for (int64_t i2 = 0; i2 < n_t; ++i2) { + device const float * x = (device const float *) ((device const char *) src1 + i1*nb10 + ir*nb11 + i2*nb12 + i3*nb13); // {dim, nh, nt, ns} + device const float * dt = (device const float *) ((device const char *) src2 + ir*nb20 + i2*nb21 + i3*nb22); // {nh, nt, ns} + device const float * A = (device const float *) ((device const char *) src3 + ir*nb31); // {1, nh} + device const float * B = (device const float *) ((device const char *) src4 + (ir & (ng - 1))*nb41 + i2*nb42 + i3*nb43); // {d_state, ng, nt, ns} + device const float * C = (device const float *) ((device const char *) src5 + (ir & (ng - 1))*nb51 + i2*nb52 + i3*nb53); // {d_state, ng, nt, ns} + device const float * D = (device const float *) ((device const char *) src6 + ir*nb60); // {nh} + device float * y = (device float *) ((device char *) dst + (i1 + ir*(nr) + i2*(nh*nr) + i3*(n_t*nh*nr))*sizeof(float)); // {dim, nh, nt, ns} + + const float dt_soft_plus = dt[0] <= 20.0f ? log1pf(expf(dt[0])) : dt[0]; + const float x_dt = x[0] * dt_soft_plus; + const float dA = expf(dt_soft_plus * A[0]); float sumf = 0.0f; for (int64_t i0 = 0; i0 < nc; ++i0) { - int64_t i = i0; - float state = (s0[i] * exp(dt_soft_plus * A[i])) + (B[i0] * x_dt); + const int64_t i = i0 + i1*nc; + const float state = (s0[i] * dA) + (B[i0] * x_dt); sumf += state * C[i0]; s[i] = state; } - y[0] = sumf; + y[0] = sumf + x[0] * D[0]; + + // recurse + s0 = s; } } From 87b97d08f43652c7a2e73929e34432ae5f9e8713 Mon Sep 17 00:00:00 2001 From: Francis Couture-Harpin Date: Wed, 2 Oct 2024 10:41:10 -0400 Subject: [PATCH 10/37] metal : fix SSM_SCAN pipeline scope --- ggml/src/ggml-metal.m | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/ggml/src/ggml-metal.m b/ggml/src/ggml-metal.m index 5d5b98307d264..477f720a0e32f 100644 --- a/ggml/src/ggml-metal.m +++ b/ggml/src/ggml-metal.m @@ -1693,11 +1693,13 @@ static void ggml_metal_encode_node( const int64_t n_seq_tokens = ne11; const int64_t n_seqs = ne13; + id pipeline = nil; + if (ne30 == 1) { // Mamba-2 - id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32_GROUP].pipeline; + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32_GROUP].pipeline; } else { - id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32].pipeline; + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32].pipeline; } [encoder setComputePipelineState:pipeline]; From 03d0e6eabe6172a56a7d470bfd844012f2c2b291 Mon Sep 17 00:00:00 2001 From: Francis Couture-Harpin Date: Wed, 2 Oct 2024 10:58:41 -0400 Subject: [PATCH 11/37] metal : use log and exp instead of log1pf and expf in SSM_SCAN --- ggml/src/ggml-metal.metal | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/ggml/src/ggml-metal.metal b/ggml/src/ggml-metal.metal index c75fa25c34e7d..cee9980a75619 100644 --- a/ggml/src/ggml-metal.metal +++ b/ggml/src/ggml-metal.metal @@ -866,13 +866,13 @@ kernel void kernel_ssm_scan_f32( device const float * D = (device const float *) ((device const char *) src6 + ir*nb60); // {nh} device float * y = (device float *) ((device char *) dst + (i1 + ir*(nr) + i2*(nh*nr) + i3*(n_t*nh*nr))*sizeof(float)); // {dim, nh, nt, ns} - const float dt_soft_plus = dt[0] <= 20.0f ? log1pf(expf(dt[0])) : dt[0]; + const float dt_soft_plus = dt[0] <= 20.0f ? log(1.0f + exp(dt[0])) : dt[0]; const float x_dt = x[0] * dt_soft_plus; float sumf = 0.0f; for (int64_t i0 = 0; i0 < nc; ++i0) { const int64_t i = i0 + i1*nc; - const float state = (s0[i] * expf(dt_soft_plus * A[i0])) + (B[i0] * x_dt); + const float state = (s0[i] * exp(dt_soft_plus * A[i0])) + (B[i0] * x_dt); sumf += state * C[i0]; s[i] = state; } @@ -955,9 +955,9 @@ kernel void kernel_ssm_scan_f32_group( device const float * D = (device const float *) ((device const char *) src6 + ir*nb60); // {nh} device float * y = (device float *) ((device char *) dst + (i1 + ir*(nr) + i2*(nh*nr) + i3*(n_t*nh*nr))*sizeof(float)); // {dim, nh, nt, ns} - const float dt_soft_plus = dt[0] <= 20.0f ? log1pf(expf(dt[0])) : dt[0]; + const float dt_soft_plus = dt[0] <= 20.0f ? log(1.0f + exp(dt[0])) : dt[0]; const float x_dt = x[0] * dt_soft_plus; - const float dA = expf(dt_soft_plus * A[0]); + const float dA = exp(dt_soft_plus * A[0]); float sumf = 0.0f; for (int64_t i0 = 0; i0 < nc; ++i0) { From 7a351abc28e36aeb73d1fd8ce172db56fbb3ebcb Mon Sep 17 00:00:00 2001 From: Francis Couture-Harpin Date: Wed, 2 Oct 2024 11:28:16 -0400 Subject: [PATCH 12/37] metal : remove unused arguments for SSM_SCAN The max index is 31, so trimming the arguments is necessary. --- ggml/src/ggml-metal.m | 53 ++++++++++++++++----------------------- ggml/src/ggml-metal.metal | 34 +++++++++---------------- 2 files changed, 34 insertions(+), 53 deletions(-) diff --git a/ggml/src/ggml-metal.m b/ggml/src/ggml-metal.m index 477f720a0e32f..5127b34f8edaa 100644 --- a/ggml/src/ggml-metal.m +++ b/ggml/src/ggml-metal.m @@ -1655,7 +1655,7 @@ static void ggml_metal_encode_node( const int64_t ne30 = src3->ne[0]; const int64_t ne31 = src3->ne[1]; GGML_UNUSED(ne31); - const uint64_t nb30 = src3->nb[0]; + const uint64_t nb30 = src3->nb[0]; GGML_UNUSED(nb30); const uint64_t nb31 = src3->nb[1]; const int64_t ne40 = src4->ne[0]; GGML_UNUSED(ne40); @@ -1663,7 +1663,7 @@ static void ggml_metal_encode_node( const int64_t ne42 = src4->ne[2]; GGML_UNUSED(ne42); const int64_t ne43 = src4->ne[3]; GGML_UNUSED(ne43); - const uint64_t nb40 = src4->nb[0]; + const uint64_t nb40 = src4->nb[0]; GGML_UNUSED(nb40); const uint64_t nb41 = src4->nb[1]; const uint64_t nb42 = src4->nb[2]; const uint64_t nb43 = src4->nb[3]; @@ -1673,18 +1673,18 @@ static void ggml_metal_encode_node( const int64_t ne52 = src5->ne[2]; GGML_UNUSED(ne52); const int64_t ne53 = src5->ne[3]; GGML_UNUSED(ne53); - const uint64_t nb50 = src5->nb[0]; + const uint64_t nb50 = src5->nb[0]; GGML_UNUSED(nb50); const uint64_t nb51 = src5->nb[1]; const uint64_t nb52 = src5->nb[2]; const uint64_t nb53 = src5->nb[3]; const int64_t ne60 = src6->ne[0]; GGML_UNUSED(ne60); - const uint64_t nb60 = src6->nb[0]; + const uint64_t nb60 = src6->nb[0]; GGML_UNUSED(nb60); const int64_t ne70 = src7->ne[0]; GGML_UNUSED(ne70); - const uint64_t nb70 = src7->nb[0]; + const uint64_t nb70 = src7->nb[0]; GGML_UNUSED(nb70); const int64_t d_state = ne00; const int64_t d_inner = ne01; @@ -1718,32 +1718,23 @@ static void ggml_metal_encode_node( [encoder setBytes:&n_head length:sizeof(n_head) atIndex:11]; [encoder setBytes:&n_group length:sizeof(n_group) atIndex:12]; [encoder setBytes:&n_seq_tokens length:sizeof(n_seq_tokens) atIndex:13]; - [encoder setBytes:&n_seqs length:sizeof(n_seqs) atIndex:14]; - - [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:15]; - [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:16]; - [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:17]; - [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:18]; - [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:19]; - [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:20]; - [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:21]; - [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:22]; - [encoder setBytes:&nb20 length:sizeof(nb20) atIndex:23]; - [encoder setBytes:&nb21 length:sizeof(nb21) atIndex:24]; - [encoder setBytes:&nb22 length:sizeof(nb22) atIndex:25]; - [encoder setBytes:&nb23 length:sizeof(nb23) atIndex:26]; - [encoder setBytes:&nb30 length:sizeof(nb30) atIndex:27]; - [encoder setBytes:&nb31 length:sizeof(nb31) atIndex:28]; - [encoder setBytes:&nb40 length:sizeof(nb40) atIndex:29]; - [encoder setBytes:&nb41 length:sizeof(nb41) atIndex:30]; - [encoder setBytes:&nb42 length:sizeof(nb42) atIndex:31]; - [encoder setBytes:&nb43 length:sizeof(nb43) atIndex:32]; - [encoder setBytes:&nb50 length:sizeof(nb50) atIndex:33]; - [encoder setBytes:&nb51 length:sizeof(nb51) atIndex:34]; - [encoder setBytes:&nb52 length:sizeof(nb52) atIndex:35]; - [encoder setBytes:&nb53 length:sizeof(nb53) atIndex:36]; - [encoder setBytes:&nb60 length:sizeof(nb60) atIndex:37]; - [encoder setBytes:&nb70 length:sizeof(nb70) atIndex:38]; + + [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:14]; + [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:15]; + [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:16]; + [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:17]; + [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:18]; + [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:19]; + [encoder setBytes:&nb21 length:sizeof(nb21) atIndex:20]; + [encoder setBytes:&nb22 length:sizeof(nb22) atIndex:21]; + [encoder setBytes:&nb31 length:sizeof(nb31) atIndex:22]; + [encoder setBytes:&nb41 length:sizeof(nb41) atIndex:23]; + [encoder setBytes:&nb42 length:sizeof(nb42) atIndex:24]; + [encoder setBytes:&nb43 length:sizeof(nb43) atIndex:25]; + [encoder setBytes:&nb51 length:sizeof(nb51) atIndex:26]; + [encoder setBytes:&nb52 length:sizeof(nb52) atIndex:27]; + [encoder setBytes:&nb53 length:sizeof(nb53) atIndex:28]; + // NOTE: max index is 31 if (ne30 == 1) { // Mamba-2 diff --git a/ggml/src/ggml-metal.metal b/ggml/src/ggml-metal.metal index cee9980a75619..3745f2f225512 100644 --- a/ggml/src/ggml-metal.metal +++ b/ggml/src/ggml-metal.metal @@ -812,30 +812,21 @@ kernel void kernel_ssm_scan_f32( constant int64_t & n_head, constant int64_t & n_group, constant int64_t & n_seq_tokens, - constant int64_t & n_seqs, - constant uint64_t & nb00, constant uint64_t & nb01, constant uint64_t & nb02, constant uint64_t & nb03, - constant uint64_t & nb10, constant uint64_t & nb11, constant uint64_t & nb12, constant uint64_t & nb13, - constant uint64_t & nb20, constant uint64_t & nb21, constant uint64_t & nb22, - constant uint64_t & nb30, constant uint64_t & nb31, - constant uint64_t & nb40, constant uint64_t & nb41, constant uint64_t & nb42, constant uint64_t & nb43, - constant uint64_t & nb50, constant uint64_t & nb51, constant uint64_t & nb52, constant uint64_t & nb53, - constant uint64_t & nb60, - constant uint64_t & nb70, uint3 tgpig[[threadgroup_position_in_grid]], uint3 tpitg[[thread_position_in_threadgroup]], uint3 ntg[[threads_per_threadgroup]]) { @@ -843,12 +834,16 @@ kernel void kernel_ssm_scan_f32( const int64_t ir = tgpig.x; // current head const int64_t i3 = tgpig.y; // current seq + const uint64_t nb00 = sizeof(float); + const uint64_t nb10 = sizeof(float); + const uint64_t nb20 = sizeof(float); + const uint64_t nb60 = sizeof(float); + const int64_t nc = d_state; const int64_t nr = d_inner; const int64_t nh = n_head; const int64_t ng = n_group; const int64_t n_t = n_seq_tokens; - const int64_t n_s = n_seqs; const int64_t s_off = d_inner * n_head * n_seq_tokens * n_seqs * sizeof(float); @@ -864,7 +859,7 @@ kernel void kernel_ssm_scan_f32( device const float * B = (device const float *) ((device const char *) src4 + (ir & (ng - 1))*nb41 + i2*nb42 + i3*nb43); // {d_state, ng, nt, ns} device const float * C = (device const float *) ((device const char *) src5 + (ir & (ng - 1))*nb51 + i2*nb52 + i3*nb53); // {d_state, ng, nt, ns} device const float * D = (device const float *) ((device const char *) src6 + ir*nb60); // {nh} - device float * y = (device float *) ((device char *) dst + (i1 + ir*(nr) + i2*(nh*nr) + i3*(n_t*nh*nr))*sizeof(float)); // {dim, nh, nt, ns} + device float * y = (device float *) ((device char *) dst + (i1 + ir*(nr) + i2*(nh*nr) + i3*(n_t*nh*nr))*nb00); // {dim, nh, nt, ns} const float dt_soft_plus = dt[0] <= 20.0f ? log(1.0f + exp(dt[0])) : dt[0]; const float x_dt = x[0] * dt_soft_plus; @@ -901,30 +896,21 @@ kernel void kernel_ssm_scan_f32_group( constant int64_t & n_head, constant int64_t & n_group, constant int64_t & n_seq_tokens, - constant int64_t & n_seqs, - constant uint64_t & nb00, constant uint64_t & nb01, constant uint64_t & nb02, constant uint64_t & nb03, - constant uint64_t & nb10, constant uint64_t & nb11, constant uint64_t & nb12, constant uint64_t & nb13, - constant uint64_t & nb20, constant uint64_t & nb21, constant uint64_t & nb22, - constant uint64_t & nb30, constant uint64_t & nb31, - constant uint64_t & nb40, constant uint64_t & nb41, constant uint64_t & nb42, constant uint64_t & nb43, - constant uint64_t & nb50, constant uint64_t & nb51, constant uint64_t & nb52, constant uint64_t & nb53, - constant uint64_t & nb60, - constant uint64_t & nb70, uint3 tgpig[[threadgroup_position_in_grid]], uint3 tpitg[[thread_position_in_threadgroup]], uint3 ntg[[threads_per_threadgroup]]) { @@ -932,12 +918,16 @@ kernel void kernel_ssm_scan_f32_group( const int64_t ir = tgpig.y; // current head const int64_t i3 = tgpig.z; // current seq + const uint64_t nb00 = sizeof(float); + const uint64_t nb10 = sizeof(float); + const uint64_t nb20 = sizeof(float); + const uint64_t nb60 = sizeof(float); + const int64_t nc = d_state; const int64_t nr = d_inner; const int64_t nh = n_head; const int64_t ng = n_group; const int64_t n_t = n_seq_tokens; - const int64_t n_s = n_seqs; const int64_t s_off = d_inner * n_head * n_seq_tokens * n_seqs * sizeof(float); @@ -953,7 +943,7 @@ kernel void kernel_ssm_scan_f32_group( device const float * B = (device const float *) ((device const char *) src4 + (ir & (ng - 1))*nb41 + i2*nb42 + i3*nb43); // {d_state, ng, nt, ns} device const float * C = (device const float *) ((device const char *) src5 + (ir & (ng - 1))*nb51 + i2*nb52 + i3*nb53); // {d_state, ng, nt, ns} device const float * D = (device const float *) ((device const char *) src6 + ir*nb60); // {nh} - device float * y = (device float *) ((device char *) dst + (i1 + ir*(nr) + i2*(nh*nr) + i3*(n_t*nh*nr))*sizeof(float)); // {dim, nh, nt, ns} + device float * y = (device float *) ((device char *) dst + (i1 + ir*(nr) + i2*(nh*nr) + i3*(n_t*nh*nr))*nb00); // {dim, nh, nt, ns} const float dt_soft_plus = dt[0] <= 20.0f ? log(1.0f + exp(dt[0])) : dt[0]; const float x_dt = x[0] * dt_soft_plus; From 8b15bc6fa0fbb7a0d831b90955430c0a9e281ac2 Mon Sep 17 00:00:00 2001 From: Francis Couture-Harpin Date: Wed, 2 Oct 2024 11:47:56 -0400 Subject: [PATCH 13/37] metal : add back n_seqs to SSM_SCAN args Whoops, this is needed for the offset in the concatenated output. --- ggml/src/ggml-metal.m | 33 +++++++++++++++++---------------- ggml/src/ggml-metal.metal | 2 ++ 2 files changed, 19 insertions(+), 16 deletions(-) diff --git a/ggml/src/ggml-metal.m b/ggml/src/ggml-metal.m index 5127b34f8edaa..3f7183060d83d 100644 --- a/ggml/src/ggml-metal.m +++ b/ggml/src/ggml-metal.m @@ -1718,22 +1718,23 @@ static void ggml_metal_encode_node( [encoder setBytes:&n_head length:sizeof(n_head) atIndex:11]; [encoder setBytes:&n_group length:sizeof(n_group) atIndex:12]; [encoder setBytes:&n_seq_tokens length:sizeof(n_seq_tokens) atIndex:13]; - - [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:14]; - [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:15]; - [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:16]; - [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:17]; - [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:18]; - [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:19]; - [encoder setBytes:&nb21 length:sizeof(nb21) atIndex:20]; - [encoder setBytes:&nb22 length:sizeof(nb22) atIndex:21]; - [encoder setBytes:&nb31 length:sizeof(nb31) atIndex:22]; - [encoder setBytes:&nb41 length:sizeof(nb41) atIndex:23]; - [encoder setBytes:&nb42 length:sizeof(nb42) atIndex:24]; - [encoder setBytes:&nb43 length:sizeof(nb43) atIndex:25]; - [encoder setBytes:&nb51 length:sizeof(nb51) atIndex:26]; - [encoder setBytes:&nb52 length:sizeof(nb52) atIndex:27]; - [encoder setBytes:&nb53 length:sizeof(nb53) atIndex:28]; + [encoder setBytes:&n_seqs length:sizeof(n_seqs) atIndex:14]; + + [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:15]; + [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:16]; + [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:17]; + [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:18]; + [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:19]; + [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:20]; + [encoder setBytes:&nb21 length:sizeof(nb21) atIndex:21]; + [encoder setBytes:&nb22 length:sizeof(nb22) atIndex:22]; + [encoder setBytes:&nb31 length:sizeof(nb31) atIndex:23]; + [encoder setBytes:&nb41 length:sizeof(nb41) atIndex:24]; + [encoder setBytes:&nb42 length:sizeof(nb42) atIndex:25]; + [encoder setBytes:&nb43 length:sizeof(nb43) atIndex:26]; + [encoder setBytes:&nb51 length:sizeof(nb51) atIndex:27]; + [encoder setBytes:&nb52 length:sizeof(nb52) atIndex:28]; + [encoder setBytes:&nb53 length:sizeof(nb53) atIndex:29]; // NOTE: max index is 31 if (ne30 == 1) { diff --git a/ggml/src/ggml-metal.metal b/ggml/src/ggml-metal.metal index 3745f2f225512..c36eedb010de1 100644 --- a/ggml/src/ggml-metal.metal +++ b/ggml/src/ggml-metal.metal @@ -812,6 +812,7 @@ kernel void kernel_ssm_scan_f32( constant int64_t & n_head, constant int64_t & n_group, constant int64_t & n_seq_tokens, + constant int64_t & n_seqs, constant uint64_t & nb01, constant uint64_t & nb02, constant uint64_t & nb03, @@ -896,6 +897,7 @@ kernel void kernel_ssm_scan_f32_group( constant int64_t & n_head, constant int64_t & n_group, constant int64_t & n_seq_tokens, + constant int64_t & n_seqs, constant uint64_t & nb01, constant uint64_t & nb02, constant uint64_t & nb03, From 5b8ec2b978b84dfdb05e6fca4def928f72b1090c Mon Sep 17 00:00:00 2001 From: Francis Couture-Harpin Date: Wed, 2 Oct 2024 12:11:45 -0400 Subject: [PATCH 14/37] metal : fix SSM_SCAN state head offset --- ggml/src/ggml-metal.metal | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/ggml/src/ggml-metal.metal b/ggml/src/ggml-metal.metal index c36eedb010de1..9e1d14ff5d8b5 100644 --- a/ggml/src/ggml-metal.metal +++ b/ggml/src/ggml-metal.metal @@ -850,8 +850,8 @@ kernel void kernel_ssm_scan_f32( device const int32_t * ids = (device const int32_t *) src7; - device const float * s0 = (device const float *) ((device const char *) src0 + ir*nb01 + ids[i3]*nb03); - device float * s = (device float *) ((device char *) dst + ir*nb01 + i3*nb03 + s_off); + device const float * s0 = (device const float *) ((device const char *) src0 + ir*nb02 + ids[i3]*nb03); + device float * s = (device float *) ((device char *) dst + ir*nb02 + i3*nb03 + s_off); for (int64_t i2 = 0; i2 < n_t; ++i2) { device const float * x = (device const float *) ((device const char *) src1 + i1*nb10 + ir*nb11 + i2*nb12 + i3*nb13); // {dim, nh, nt, ns} @@ -935,8 +935,8 @@ kernel void kernel_ssm_scan_f32_group( device const int32_t * ids = (device const int32_t *) src7; - device const float * s0 = (device const float *) ((device const char *) src0 + ir*nb01 + ids[i3]*nb03); - device float * s = (device float *) ((device char *) dst + ir*nb01 + i3*nb03 + s_off); + device const float * s0 = (device const float *) ((device const char *) src0 + ir*nb02 + ids[i3]*nb03); + device float * s = (device float *) ((device char *) dst + ir*nb02 + i3*nb03 + s_off); for (int64_t i2 = 0; i2 < n_t; ++i2) { device const float * x = (device const float *) ((device const char *) src1 + i1*nb10 + ir*nb11 + i2*nb12 + i3*nb13); // {dim, nh, nt, ns} From 62b09b343c6c4e35486368f1a7b653c9ae58574a Mon Sep 17 00:00:00 2001 From: Francis Couture-Harpin Date: Wed, 2 Oct 2024 21:35:50 -0400 Subject: [PATCH 15/37] metal : fix wrong number of tokens per sequence in SSM_SCAN --- ggml/src/ggml-metal.m | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ggml/src/ggml-metal.m b/ggml/src/ggml-metal.m index 3f7183060d83d..a39770bd4ed1b 100644 --- a/ggml/src/ggml-metal.m +++ b/ggml/src/ggml-metal.m @@ -1690,7 +1690,7 @@ static void ggml_metal_encode_node( const int64_t d_inner = ne01; const int64_t n_head = ne02; const int64_t n_group = ne41; - const int64_t n_seq_tokens = ne11; + const int64_t n_seq_tokens = ne12; const int64_t n_seqs = ne13; id pipeline = nil; From 805512a73b9876853f0e7d0cd612259806fa5d93 Mon Sep 17 00:00:00 2001 From: Francis Couture-Harpin Date: Sat, 12 Oct 2024 16:20:26 -0400 Subject: [PATCH 16/37] ggml : remove unused fast broadcast path in GGML_MUL This was initially added because states were masked with ggml_mul, but this is no longer done and so this "optimisation" is no longer necessary, or at least not worth the additional code complexity. --- ggml/src/ggml.c | 32 +------------------------------- 1 file changed, 1 insertion(+), 31 deletions(-) diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index e8a5e3d153548..8fd335270dd5a 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -10173,37 +10173,7 @@ static void ggml_compute_forward_mul_f32( GGML_ASSERT( nb0 == sizeof(float)); GGML_ASSERT(nb00 == sizeof(float)); - if (ne00 > 1 && ne10 == 1) { - // fast broadcast path - for (int64_t ir = ith; ir < nr; ir += nth) { - // src0 and dst are same shape => same indices - const int64_t i03 = ir/(ne02*ne01); - const int64_t i02 = (ir - i03*ne02*ne01)/ne01; - const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01); - - const int64_t i13 = i03 % ne13; - const int64_t i12 = i02 % ne12; - const int64_t i11 = i01 % ne11; - - float * src1_ptr = (float *) ((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11); - - const float scale = src1_ptr[0]; - - if (scale == 0.0f) { - // NOTE: this also sets NANs to zero, which is not compliant with IEEE754, - // but it is useful when resetting the state of recurrent models. - memset((char *) dst->data + ir*nb1, 0, ne0 * sizeof(float)); - } else { - if (dst->data != src0->data) { - // src0 is same shape as dst => same indices - memcpy((char *) dst->data + ir*nb1, (char *) src0->data + ir*nb01, ne0 * sizeof(float)); - } - if (scale != 1.0f) { - ggml_vec_scale_f32(ne0, (float *) ((char *) dst->data + ir*nb1), scale); - } - } - } - } else if (nb10 == sizeof(float)) { + if (nb10 == sizeof(float)) { for (int64_t ir = ith; ir < nr; ir += nth) { // src0 and dst are same shape => same indices const int64_t i03 = ir/(ne02*ne01); From 3bc7103d2ef1c41cd380a1ad8d918cf9c26694d8 Mon Sep 17 00:00:00 2001 From: Francis Couture-Harpin Date: Mon, 4 Nov 2024 11:36:37 -0500 Subject: [PATCH 17/37] ggml : avoid multiply by D in GGML_OP_SSM_SCAN This makes the weight buft detection in src/llama.cpp simpler. * convert : transpose Mamba-2 A, D and reshape SSM_NORM This breaks existing conversions of Mamba-2 models to avoid some reshapes. Not sure if it's a good idea, but it makes the graph slightly cleaner. * llama : more appropriate SSM_SCAN and SSM_CONV buft support checks --- convert_hf_to_gguf.py | 26 ++++++++++++++++- ggml/include/ggml.h | 1 - ggml/src/ggml-metal.m | 57 ++++++++++++++++---------------------- ggml/src/ggml-metal.metal | 14 +++------- ggml/src/ggml.c | 20 ++++--------- src/llama.cpp | 54 +++++++++++++++++++----------------- tests/test-backend-ops.cpp | 25 ++++++++--------- 7 files changed, 100 insertions(+), 97 deletions(-) diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index f307b1ac69202..f0a63d921d65f 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -264,6 +264,12 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter return [(self.map_tensor_name(name), data_torch)] + # TODO: merge into modify_tensors? (need to check tensor shapes for all arches before doing that) + def reshape_tensors(self, data_torch: Tensor, new_name: str, bid: int | None) -> Tensor: + del new_name, bid # unused + + return data_torch.squeeze() + def tensor_force_quant(self, name: str, new_name: str, bid: int | None, n_dims: int) -> gguf.GGMLQuantizationType | bool: del name, new_name, bid, n_dims # unused @@ -295,7 +301,7 @@ def prepare_tensors(self): break for new_name, data_torch in (self.modify_tensors(data_torch, name, bid)): - data = data_torch.squeeze().numpy() + data = self.reshape_tensors(data_torch, new_name, bid).numpy() # if data ends up empty, it means data_torch was a scalar tensor -> restore if len(data.shape) == 0: @@ -3063,6 +3069,24 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter yield (new_name, data_torch) + def reshape_tensors(self, data_torch: Tensor, new_name: str, bid: int | None) -> Tensor: + if any(self.match_model_tensor_name(new_name, t, bid, suffix="") for t in [ + gguf.MODEL_TENSOR.SSM_A, + gguf.MODEL_TENSOR.SSM_D, + ]): + # unsqueeze A to use similar shape semantics as Mamba-1 + # (D is also unsqueezed, but for more straightforward broadcast internally) + return data_torch.reshape((*data_torch.shape, 1)) + + elif self.match_model_tensor_name(new_name, gguf.MODEL_TENSOR.SSM_NORM, bid): + d_model = self.find_hparam(["hidden_size", "d_model", "dim"]) + d_inner = self.find_hparam(["intermediate_size", "d_inner"], optional=True) or 2 * d_model + n_group = self.hparams.get("n_groups", 1) + return data_torch.reshape((n_group, d_inner // n_group)) + + return data_torch.squeeze() + + @Model.register("CohereForCausalLM") class CommandR2Model(Model): diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h index 0d2e5cb011a3b..735f56b005a28 100644 --- a/ggml/include/ggml.h +++ b/ggml/include/ggml.h @@ -1828,7 +1828,6 @@ extern "C" { struct ggml_tensor * A, struct ggml_tensor * B, struct ggml_tensor * C, - struct ggml_tensor * D, struct ggml_tensor * ids); // partition into non-overlapping windows with padding if needed diff --git a/ggml/src/ggml-metal.m b/ggml/src/ggml-metal.m index 73e2fedc36544..902728d8e6b55 100644 --- a/ggml/src/ggml-metal.m +++ b/ggml/src/ggml-metal.m @@ -1649,25 +1649,21 @@ static void ggml_metal_encode_node( struct ggml_tensor * src4 = node->src[4]; struct ggml_tensor * src5 = node->src[5]; struct ggml_tensor * src6 = node->src[6]; - struct ggml_tensor * src7 = node->src[7]; GGML_ASSERT(src3); GGML_ASSERT(src4); GGML_ASSERT(src5); GGML_ASSERT(src6); - GGML_ASSERT(src7); size_t offs_src3 = 0; size_t offs_src4 = 0; size_t offs_src5 = 0; size_t offs_src6 = 0; - size_t offs_src7 = 0; id id_src3 = src3 ? ggml_metal_get_buffer(src3, &offs_src3) : nil; id id_src4 = src4 ? ggml_metal_get_buffer(src4, &offs_src4) : nil; id id_src5 = src5 ? ggml_metal_get_buffer(src5, &offs_src5) : nil; id id_src6 = src6 ? ggml_metal_get_buffer(src6, &offs_src6) : nil; - id id_src7 = src7 ? ggml_metal_get_buffer(src7, &offs_src7) : nil; const int64_t ne30 = src3->ne[0]; const int64_t ne31 = src3->ne[1]; GGML_UNUSED(ne31); @@ -1699,10 +1695,6 @@ static void ggml_metal_encode_node( const uint64_t nb60 = src6->nb[0]; GGML_UNUSED(nb60); - const int64_t ne70 = src7->ne[0]; GGML_UNUSED(ne70); - - const uint64_t nb70 = src7->nb[0]; GGML_UNUSED(nb70); - const int64_t d_state = ne00; const int64_t d_inner = ne01; const int64_t n_head = ne02; @@ -1727,31 +1719,30 @@ static void ggml_metal_encode_node( [encoder setBuffer:id_src4 offset:offs_src4 atIndex:4]; [encoder setBuffer:id_src5 offset:offs_src5 atIndex:5]; [encoder setBuffer:id_src6 offset:offs_src6 atIndex:6]; - [encoder setBuffer:id_src7 offset:offs_src7 atIndex:7]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:8]; - - [encoder setBytes:&d_state length:sizeof(d_state) atIndex:9]; - [encoder setBytes:&d_inner length:sizeof(d_inner) atIndex:10]; - [encoder setBytes:&n_head length:sizeof(n_head) atIndex:11]; - [encoder setBytes:&n_group length:sizeof(n_group) atIndex:12]; - [encoder setBytes:&n_seq_tokens length:sizeof(n_seq_tokens) atIndex:13]; - [encoder setBytes:&n_seqs length:sizeof(n_seqs) atIndex:14]; - - [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:15]; - [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:16]; - [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:17]; - [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:18]; - [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:19]; - [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:20]; - [encoder setBytes:&nb21 length:sizeof(nb21) atIndex:21]; - [encoder setBytes:&nb22 length:sizeof(nb22) atIndex:22]; - [encoder setBytes:&nb31 length:sizeof(nb31) atIndex:23]; - [encoder setBytes:&nb41 length:sizeof(nb41) atIndex:24]; - [encoder setBytes:&nb42 length:sizeof(nb42) atIndex:25]; - [encoder setBytes:&nb43 length:sizeof(nb43) atIndex:26]; - [encoder setBytes:&nb51 length:sizeof(nb51) atIndex:27]; - [encoder setBytes:&nb52 length:sizeof(nb52) atIndex:28]; - [encoder setBytes:&nb53 length:sizeof(nb53) atIndex:29]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:7]; + + [encoder setBytes:&d_state length:sizeof(d_state) atIndex:8]; + [encoder setBytes:&d_inner length:sizeof(d_inner) atIndex:9]; + [encoder setBytes:&n_head length:sizeof(n_head) atIndex:10]; + [encoder setBytes:&n_group length:sizeof(n_group) atIndex:11]; + [encoder setBytes:&n_seq_tokens length:sizeof(n_seq_tokens) atIndex:12]; + [encoder setBytes:&n_seqs length:sizeof(n_seqs) atIndex:13]; + + [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:14]; + [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:15]; + [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:16]; + [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:17]; + [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:18]; + [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:19]; + [encoder setBytes:&nb21 length:sizeof(nb21) atIndex:20]; + [encoder setBytes:&nb22 length:sizeof(nb22) atIndex:21]; + [encoder setBytes:&nb31 length:sizeof(nb31) atIndex:22]; + [encoder setBytes:&nb41 length:sizeof(nb41) atIndex:23]; + [encoder setBytes:&nb42 length:sizeof(nb42) atIndex:24]; + [encoder setBytes:&nb43 length:sizeof(nb43) atIndex:25]; + [encoder setBytes:&nb51 length:sizeof(nb51) atIndex:26]; + [encoder setBytes:&nb52 length:sizeof(nb52) atIndex:27]; + [encoder setBytes:&nb53 length:sizeof(nb53) atIndex:28]; // NOTE: max index is 31 if (ne30 == 1) { diff --git a/ggml/src/ggml-metal.metal b/ggml/src/ggml-metal.metal index 2f5a4d12eeec3..05d04e8f3fdbf 100644 --- a/ggml/src/ggml-metal.metal +++ b/ggml/src/ggml-metal.metal @@ -805,7 +805,6 @@ kernel void kernel_ssm_scan_f32( device const void * src4, device const void * src5, device const void * src6, - device const void * src7, device float * dst, constant int64_t & d_state, constant int64_t & d_inner, @@ -838,7 +837,6 @@ kernel void kernel_ssm_scan_f32( const uint64_t nb00 = sizeof(float); const uint64_t nb10 = sizeof(float); const uint64_t nb20 = sizeof(float); - const uint64_t nb60 = sizeof(float); const int64_t nc = d_state; const int64_t nr = d_inner; @@ -848,7 +846,7 @@ kernel void kernel_ssm_scan_f32( const int64_t s_off = d_inner * n_head * n_seq_tokens * n_seqs * sizeof(float); - device const int32_t * ids = (device const int32_t *) src7; + device const int32_t * ids = (device const int32_t *) src6; device const float * s0 = (device const float *) ((device const char *) src0 + ir*nb02 + ids[i3]*nb03); device float * s = (device float *) ((device char *) dst + ir*nb02 + i3*nb03 + s_off); @@ -859,7 +857,6 @@ kernel void kernel_ssm_scan_f32( device const float * A = (device const float *) ((device const char *) src3 + ir*nb31); // {d_state, nh} device const float * B = (device const float *) ((device const char *) src4 + (ir & (ng - 1))*nb41 + i2*nb42 + i3*nb43); // {d_state, ng, nt, ns} device const float * C = (device const float *) ((device const char *) src5 + (ir & (ng - 1))*nb51 + i2*nb52 + i3*nb53); // {d_state, ng, nt, ns} - device const float * D = (device const float *) ((device const char *) src6 + ir*nb60); // {nh} device float * y = (device float *) ((device char *) dst + (i1 + ir*(nr) + i2*(nh*nr) + i3*(n_t*nh*nr))*nb00); // {dim, nh, nt, ns} const float dt_soft_plus = dt[0] <= 20.0f ? log(1.0f + exp(dt[0])) : dt[0]; @@ -873,7 +870,7 @@ kernel void kernel_ssm_scan_f32( s[i] = state; } - y[0] = sumf + x[0] * D[0]; + y[0] = sumf; // recurse s0 = s; @@ -890,7 +887,6 @@ kernel void kernel_ssm_scan_f32_group( device const void * src4, device const void * src5, device const void * src6, - device const void * src7, device float * dst, constant int64_t & d_state, constant int64_t & d_inner, @@ -923,7 +919,6 @@ kernel void kernel_ssm_scan_f32_group( const uint64_t nb00 = sizeof(float); const uint64_t nb10 = sizeof(float); const uint64_t nb20 = sizeof(float); - const uint64_t nb60 = sizeof(float); const int64_t nc = d_state; const int64_t nr = d_inner; @@ -933,7 +928,7 @@ kernel void kernel_ssm_scan_f32_group( const int64_t s_off = d_inner * n_head * n_seq_tokens * n_seqs * sizeof(float); - device const int32_t * ids = (device const int32_t *) src7; + device const int32_t * ids = (device const int32_t *) src6; device const float * s0 = (device const float *) ((device const char *) src0 + ir*nb02 + ids[i3]*nb03); device float * s = (device float *) ((device char *) dst + ir*nb02 + i3*nb03 + s_off); @@ -944,7 +939,6 @@ kernel void kernel_ssm_scan_f32_group( device const float * A = (device const float *) ((device const char *) src3 + ir*nb31); // {1, nh} device const float * B = (device const float *) ((device const char *) src4 + (ir & (ng - 1))*nb41 + i2*nb42 + i3*nb43); // {d_state, ng, nt, ns} device const float * C = (device const float *) ((device const char *) src5 + (ir & (ng - 1))*nb51 + i2*nb52 + i3*nb53); // {d_state, ng, nt, ns} - device const float * D = (device const float *) ((device const char *) src6 + ir*nb60); // {nh} device float * y = (device float *) ((device char *) dst + (i1 + ir*(nr) + i2*(nh*nr) + i3*(n_t*nh*nr))*nb00); // {dim, nh, nt, ns} const float dt_soft_plus = dt[0] <= 20.0f ? log(1.0f + exp(dt[0])) : dt[0]; @@ -959,7 +953,7 @@ kernel void kernel_ssm_scan_f32_group( s[i] = state; } - y[0] = sumf + x[0] * D[0]; + y[0] = sumf; // recurse s0 = s; diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index 91b256a4c25f0..9036fc0be9858 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -7181,7 +7181,6 @@ struct ggml_tensor * ggml_ssm_conv( const int64_t n_s = sx->ne[2]; // TODO: maybe support other strides than 1? - // FIXME: this is always true? GGML_ASSERT(sx->ne[0] == d_conv - 1 + n_t); GGML_ASSERT(sx->ne[1] == d_inner); GGML_ASSERT(n_t >= 0); @@ -7205,7 +7204,6 @@ struct ggml_tensor * ggml_ssm_scan( struct ggml_tensor * A, struct ggml_tensor * B, struct ggml_tensor * C, - struct ggml_tensor * D, struct ggml_tensor * ids) { GGML_ASSERT(ggml_is_contiguous(s)); GGML_ASSERT(ggml_is_contiguous(dt)); @@ -7235,8 +7233,6 @@ struct ggml_tensor * ggml_ssm_scan( GGML_ASSERT(B->ne[0] == d_state); GGML_ASSERT(B->ne[2] == n_seq_tokens); GGML_ASSERT(B->ne[3] == n_seqs); - GGML_ASSERT(D->ne[0] == n_head); - GGML_ASSERT(ggml_is_vector(D)); GGML_ASSERT(ids->ne[0] == n_seqs); GGML_ASSERT(ggml_is_vector(ids)); GGML_ASSERT(A->ne[1] == n_head); @@ -7258,8 +7254,7 @@ struct ggml_tensor * ggml_ssm_scan( result->src[3] = A; result->src[4] = B; result->src[5] = C; - result->src[6] = D; - result->src[7] = ids; + result->src[6] = ids; return result; } @@ -16217,8 +16212,7 @@ static void ggml_compute_forward_ssm_scan_f32( const struct ggml_tensor * src3 = dst->src[3]; // A {d_state, n_head} or {1, n_head} const struct ggml_tensor * src4 = dst->src[4]; // B {d_state, n_group, n_seq_tokens, n_seqs} const struct ggml_tensor * src5 = dst->src[5]; // C {d_state, n_group, n_seq_tokens, n_seqs} - const struct ggml_tensor * src6 = dst->src[6]; // D {n_head} - const struct ggml_tensor * src7 = dst->src[7]; // ids {n_seqs} + const struct ggml_tensor * src6 = dst->src[6]; // ids {n_seqs} const int ith = params->ith; const int nth = params->nth; @@ -16240,8 +16234,7 @@ static void ggml_compute_forward_ssm_scan_f32( GGML_ASSERT(src3->nb[0] == sizeof(float)); GGML_ASSERT(src4->nb[0] == sizeof(float)); GGML_ASSERT(src5->nb[0] == sizeof(float)); - GGML_ASSERT(src6->nb[0] == sizeof(float)); - GGML_ASSERT(src7->nb[0] == sizeof(int32_t)); + GGML_ASSERT(src6->nb[0] == sizeof(int32_t)); // allows optimizing the modulo since n_group should be a power of 2 GGML_ASSERT((ng & -ng) == ng); @@ -16252,7 +16245,7 @@ static void ggml_compute_forward_ssm_scan_f32( const int ih0 = dh*ith; const int ih1 = MIN(ih0 + dh, nh); - const int32_t * ids = (const int32_t *) src7->data; + const int32_t * ids = (const int32_t *) src6->data; for (int i3 = 0; i3 < ns; ++i3) { const float * s0 = (const float *) ((const char *) src0->data + ids[i3]*(src0->nb[3])); // {d_state, dim, nh, ns} @@ -16264,7 +16257,6 @@ static void ggml_compute_forward_ssm_scan_f32( const float * A = (const float *) ((const char *) src3->data); // {d_state, nh} or {1, nh} const float * B = (const float *) ((const char *) src4->data + i2*(src4->nb[2]) + i3*(src4->nb[3])); // {d_state, ng, nt, ns} const float * C = (const float *) ((const char *) src5->data + i2*(src5->nb[2]) + i3*(src5->nb[3])); // {d_state, ng, nt, ns} - const float * D = (const float *) ((const char *) src6->data); // {nh} float * y = ( float *) (( char *) dst->data + i2*(nh*nr*sizeof(float)) + i3*(nt*nh*nr*sizeof(float))); // {dim, nh, nt, ns} if (src3->ne[0] == 1) { @@ -16325,7 +16317,7 @@ static void ggml_compute_forward_ssm_scan_f32( sumf += state * C[ig]; s[i] = state; } - y[ii] = sumf + x[ii] * D[h]; + y[ii] = sumf; } } } else { @@ -16353,7 +16345,7 @@ static void ggml_compute_forward_ssm_scan_f32( sumf += state * C[ig]; s[i] = state; } - y[ii] = sumf + x[ii] * D[h]; + y[ii] = sumf; } } } diff --git a/src/llama.cpp b/src/llama.cpp index e84510ce8ffd1..52052caf250b1 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -7120,6 +7120,7 @@ static const std::map llm_tensor_info_mapping = { {LLM_TENSOR_SSM_CONV1D, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_SSM_CONV}}, {LLM_TENSOR_SSM_A, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_SSM_SCAN}}, {LLM_TENSOR_SSM_D, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, + {LLM_TENSOR_SSM_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, {LLM_TENSOR_TIME_MIX_LERP_X, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, {LLM_TENSOR_TIME_MIX_LN, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, {LLM_TENSOR_CHANNEL_MIX_LERP_K, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, @@ -7227,23 +7228,27 @@ static bool weight_buft_supported(const llama_hparams & hparams, ggml_tensor * w } break; case GGML_OP_SSM_CONV: { - // FIXME - ggml_tensor * conv_x = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, 12345, w->ne[1], 6789); + const int64_t n_seq_tokens = 512; + const int64_t n_seqs = 3; + ggml_tensor * conv_x = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, w->ne[0] - 1 + n_seq_tokens, w->ne[1], n_seqs); op_tensor = ggml_ssm_conv(ctx, conv_x, w); } break; case GGML_OP_SSM_SCAN: { - // FIXME - const int64_t d_state = w->ne[0]; - const int64_t d_inner = w->ne[1]; + // w is ssm_a + const int64_t d_state = w->ne[0] == 1 ? hparams.ssm_d_state : w->ne[0]; + const int64_t n_head = w->ne[1]; + const int64_t head_dim = hparams.ssm_d_inner / n_head; + const int64_t n_group = hparams.ssm_n_group; const int64_t n_seq_tokens = 512; - const int64_t n_seqs = 1; - ggml_tensor * s = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, d_state, d_inner, n_seqs); - ggml_tensor * x = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, d_inner, n_seq_tokens, n_seqs); - ggml_tensor * dt = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, d_inner, n_seq_tokens, n_seqs); - ggml_tensor * B = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, d_state, n_seq_tokens, n_seqs); - ggml_tensor * C = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, d_state, n_seq_tokens, n_seqs); - op_tensor = ggml_ssm_scan(ctx, s, x, dt, w, B, C); + const int64_t n_seqs = 3; + ggml_tensor * s = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, d_state, head_dim, n_head, n_seqs); + ggml_tensor * x = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, head_dim, n_head, n_seq_tokens, n_seqs); + ggml_tensor * dt = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, n_head, n_seq_tokens, n_seqs); + ggml_tensor * B = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, d_state, n_group, n_seq_tokens, n_seqs); + ggml_tensor * C = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, d_state, n_group, n_seq_tokens, n_seqs); + ggml_tensor * ids = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, n_seqs); + op_tensor = ggml_ssm_scan(ctx, s, x, dt, w, B, C, ids); } break; case GGML_OP_RWKV_WKV: { @@ -8572,10 +8577,10 @@ static bool llm_load_tensors( layer.ssm_dt_b = create_tensor(tn(LLM_TENSOR_SSM_DT, "bias", i), {n_head}, 0); // no "weight" suffix for these - layer.ssm_a = create_tensor(tn(LLM_TENSOR_SSM_A, i), {n_head}, 0); - layer.ssm_d = create_tensor(tn(LLM_TENSOR_SSM_D, i), {n_head}, 0); + layer.ssm_a = create_tensor(tn(LLM_TENSOR_SSM_A, i), {1, n_head}, 0); + layer.ssm_d = create_tensor(tn(LLM_TENSOR_SSM_D, i), {1, n_head}, 0); - layer.ssm_norm = create_tensor(tn(LLM_TENSOR_SSM_NORM, "weight", i), {d_inner}, 0); + layer.ssm_norm = create_tensor(tn(LLM_TENSOR_SSM_NORM, "weight", i), {d_inner / n_group, n_group}, 0); // out_proj layer.ssm_out = create_tensor(tn(LLM_TENSOR_SSM_OUT, "weight", i), {d_inner, n_embd}, 0); @@ -9994,7 +9999,7 @@ static struct ggml_tensor * llm_build_rs( return states; } -// TODO: split +// TODO: split conv and ssm static struct ggml_tensor * llm_build_mamba( struct ggml_context * ctx, struct llama_context & lctx, @@ -10102,13 +10107,14 @@ static struct ggml_tensor * llm_build_mamba( dt = llm_build_lora_mm(lctx, ctx, model.layers[il].ssm_dt, dt); dt = ggml_add(ctx, dt, model.layers[il].ssm_dt_b); + cur = x; x = ggml_reshape_4d(ctx, x, head_dim, n_head, n_seq_tokens, n_seqs); struct ggml_tensor * ssm_ids = ggml_view_1d(ctx, state_copy, n_seqs, 0); // Custom operator to optimize the parallel associative scan // as described in the Annex D of the Mamba paper. // => {d_inner, n_seq_tokens, n_seqs} and {d_state, d_inner, n_seqs} - struct ggml_tensor * y_ssm = ggml_ssm_scan(ctx, ssm, x, dt, model.layers[il].ssm_a, B, C, model.layers[il].ssm_d, ssm_ids); + struct ggml_tensor * y_ssm = ggml_ssm_scan(ctx, ssm, x, dt, model.layers[il].ssm_a, B, C, ssm_ids); // store last states ggml_build_forward_expand(graph, @@ -10120,6 +10126,7 @@ static struct ggml_tensor * llm_build_mamba( // TODO: skip computing output earlier for unused tokens + y = ggml_add(ctx, y, ggml_mul(ctx, cur, model.layers[il].ssm_d)); y = ggml_mul(ctx, y, ggml_silu(ctx, ggml_cont(ctx, z))); // {d_inner, n_embd} @ {d_inner, n_seq_tokens, n_seqs} => {n_embd, n_seq_tokens, n_seqs} @@ -10184,7 +10191,7 @@ static struct ggml_tensor * llm_build_mamba2( struct ggml_tensor * zxBCdt = llm_build_lora_mm(lctx, ctx, model.layers[il].ssm_in, cur); // split the above in three - struct ggml_tensor * z = ggml_view_3d(ctx, zxBCdt, d_inner, n_seq_tokens, n_seqs, zxBCdt->nb[1], zxBCdt->nb[2], 0); + struct ggml_tensor * z = ggml_view_4d(ctx, zxBCdt, head_dim, n_head, n_seq_tokens, n_seqs, head_dim*zxBCdt->nb[0], zxBCdt->nb[1], zxBCdt->nb[2], 0); struct ggml_tensor * xBC = ggml_view_3d(ctx, zxBCdt, d_inner + 2*n_group*d_state, n_seq_tokens, n_seqs, zxBCdt->nb[1], zxBCdt->nb[2], d_inner*ggml_element_size(zxBCdt)); struct ggml_tensor * dt = ggml_view_3d(ctx, zxBCdt, n_head, n_seq_tokens, n_seqs, zxBCdt->nb[1], zxBCdt->nb[2], (2*d_inner + 2*n_group*d_state)*ggml_element_size(zxBCdt)); @@ -10230,11 +10237,9 @@ static struct ggml_tensor * llm_build_mamba2( dt = ggml_add(ctx, dt, model.layers[il].ssm_dt_b); struct ggml_tensor * ssm_ids = ggml_view_1d(ctx, state_copy, n_seqs, 0); - // Use the same shape semantics for A as Mamba-1 - struct ggml_tensor * A = ggml_reshape_2d(ctx, model.layers[il].ssm_a, 1, n_head); // TODO: use semistructured matrices to implement state-space duality // => {d_inner, n_seq_tokens, n_seqs} and {d_state, d_inner, n_seqs} - struct ggml_tensor * y_ssm = ggml_ssm_scan(ctx, ssm, x, dt, A, B, C, model.layers[il].ssm_d, ssm_ids); + struct ggml_tensor * y_ssm = ggml_ssm_scan(ctx, ssm, x, dt, model.layers[il].ssm_a, B, C, ssm_ids); // store last states ggml_build_forward_expand(graph, @@ -10242,17 +10247,16 @@ static struct ggml_tensor * llm_build_mamba2( ggml_view_1d(ctx, y_ssm, d_state*d_inner*n_seqs, ggml_nelements(x)*x->nb[0]), ggml_view_1d(ctx, ssm_states_all, d_state*d_inner*n_seqs, kv_head*d_state*d_inner*ggml_element_size(ssm_states_all)))); - struct ggml_tensor * y = ggml_view_3d(ctx, y_ssm, d_inner, n_seq_tokens, n_seqs, n_head*x->nb[1], n_seq_tokens*n_head*x->nb[1], 0); + struct ggml_tensor * y = ggml_view_4d(ctx, y_ssm, head_dim, n_head, n_seq_tokens, n_seqs, x->nb[1], n_head*x->nb[1], n_seq_tokens*n_head*x->nb[1], 0); // TODO: skip computing output earlier for unused tokens + y = ggml_add(ctx, y, ggml_mul(ctx, x, model.layers[il].ssm_d)); y = ggml_mul(ctx, y, ggml_silu(ctx, ggml_cont(ctx, z))); // grouped RMS norm y = ggml_reshape_4d(ctx, y, d_inner / n_group, n_group, n_seq_tokens, n_seqs); - y = llm_build_norm(ctx, y, hparams, - ggml_reshape_2d(ctx, model.layers[il].ssm_norm, d_inner / n_group, n_group), NULL, - LLM_NORM_RMS, cb, il); + y = llm_build_norm(ctx, y, hparams, model.layers[il].ssm_norm, NULL, LLM_NORM_RMS, cb, il); y = ggml_reshape_3d(ctx, y, d_inner, n_seq_tokens, n_seqs); // {d_inner, n_embd} @ {d_inner, n_seq_tokens, n_seqs} => {n_embd, n_seq_tokens, n_seqs} diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index 6ca254a45f23f..95f8abbd80968 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -1589,35 +1589,34 @@ struct test_ssm_scan : public test_case { const ggml_type type; const int64_t d_state; - const int64_t d_inner; + const int64_t head_dim; const int64_t n_head; const int64_t n_group; const int64_t n_seq_tokens; const int64_t n_seqs; std::string vars() override { - return VARS_TO_STR7(type, d_state, d_inner, n_head, n_group, n_seq_tokens, n_seqs); + return VARS_TO_STR7(type, d_state, head_dim, n_head, n_group, n_seq_tokens, n_seqs); } test_ssm_scan(ggml_type type = GGML_TYPE_F32, int64_t d_state = 32, - int64_t d_inner = 1, // non-zero for Mamba-2 + int64_t head_dim = 1, // non-zero for Mamba-2 int64_t n_head = 32, int64_t n_group = 1, int64_t n_seq_tokens = 32, int64_t n_seqs = 32) - : type(type), d_state(d_state), d_inner(d_inner), n_head(n_head), n_group(n_group), n_seq_tokens(n_seq_tokens), n_seqs(n_seqs) {} + : type(type), d_state(d_state), head_dim(head_dim), n_head(n_head), n_group(n_group), n_seq_tokens(n_seq_tokens), n_seqs(n_seqs) {} ggml_tensor * build_graph(ggml_context * ctx) override { - ggml_tensor * s = ggml_new_tensor_4d(ctx, type, d_state, d_inner, n_head, n_seqs); - ggml_tensor * x = ggml_new_tensor_4d(ctx, type, d_inner, n_head, n_seq_tokens, n_seqs); - ggml_tensor * dt = ggml_new_tensor_3d(ctx, type, n_head, n_seq_tokens, n_seqs); - ggml_tensor * A = ggml_new_tensor_2d(ctx, type, (d_inner > 1) ? 1 : d_state, n_head); - ggml_tensor * B = ggml_new_tensor_4d(ctx, type, d_state, n_group, n_seq_tokens, n_seqs); - ggml_tensor * C = ggml_new_tensor_4d(ctx, type, d_state, n_group, n_seq_tokens, n_seqs); - ggml_tensor * D = ggml_new_tensor_1d(ctx, type, n_head); - ggml_tensor * ids = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, n_seqs); - ggml_tensor * out = ggml_ssm_scan(ctx, s, x, dt, A, B, C, D, ids); + ggml_tensor * s = ggml_new_tensor_4d(ctx, type, d_state, head_dim, n_head, n_seqs); + ggml_tensor * x = ggml_new_tensor_4d(ctx, type, head_dim, n_head, n_seq_tokens, n_seqs); + ggml_tensor * dt = ggml_new_tensor_3d(ctx, type, n_head, n_seq_tokens, n_seqs); + ggml_tensor * A = ggml_new_tensor_2d(ctx, type, (head_dim > 1) ? 1 : d_state, n_head); + ggml_tensor * B = ggml_new_tensor_4d(ctx, type, d_state, n_group, n_seq_tokens, n_seqs); + ggml_tensor * C = ggml_new_tensor_4d(ctx, type, d_state, n_group, n_seq_tokens, n_seqs); + ggml_tensor * ids = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, n_seqs); + ggml_tensor * out = ggml_ssm_scan(ctx, s, x, dt, A, B, C, ids); return out; } From b4e9c5998dea2d657cfd22bc2e6fa0630fba2fa9 Mon Sep 17 00:00:00 2001 From: Francis Couture-Harpin Date: Mon, 4 Nov 2024 15:26:15 -0500 Subject: [PATCH 18/37] convert : fix flake8 lint --- convert_hf_to_gguf.py | 1 - 1 file changed, 1 deletion(-) diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index f0efe5d5b0c7c..019e7b7ef93b6 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -3088,7 +3088,6 @@ def reshape_tensors(self, data_torch: Tensor, new_name: str, bid: int | None) -> return data_torch.squeeze() - @Model.register("CohereForCausalLM") class CommandR2Model(Model): model_arch = gguf.MODEL_ARCH.COMMAND_R From 9a68f7537b39541afd771c96389ba3740ad8be4b Mon Sep 17 00:00:00 2001 From: Gabe Goodhart Date: Tue, 26 Nov 2024 14:29:24 -0700 Subject: [PATCH 19/37] feat(jamba): First pass at GGUF conversion for Jamba models There are likely still some missing hparams, but the tensor mapping should be correct Branch: BambaArchitecture Signed-off-by: Gabe Goodhart --- convert_hf_to_gguf.py | 73 +++++++++++++++++++++++++++++++++- gguf-py/gguf/constants.py | 27 +++++++++++++ gguf-py/gguf/tensor_mapping.py | 13 +++++- 3 files changed, 110 insertions(+), 3 deletions(-) diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index e14ad9b01c659..fa67350ab14d8 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -516,7 +516,10 @@ def get_vocab_base(self) -> tuple[list[str], list[int], str]: from transformers import AutoTokenizer tokenizer = AutoTokenizer.from_pretrained(self.dir_model) - vocab_size = self.hparams.get("vocab_size", len(tokenizer.vocab)) + vocab_size = max( + self.hparams.get("vocab_size", len(tokenizer.vocab)), + len(tokenizer.vocab) + ) assert max(tokenizer.vocab.values()) < vocab_size tokpre = self.get_vocab_base_pre(tokenizer) @@ -3036,7 +3039,7 @@ def set_gguf_parameters(self): # Fail early for models which don't have a block expansion factor of 2 # TODO: does this really matter? - assert d_inner == 2 * d_model + # assert d_inner == 2 * d_model assert d_inner % head_dim == 0 self.gguf_writer.add_context_length(2**20) # arbitrary value; for those who use the default @@ -3088,6 +3091,72 @@ def reshape_tensors(self, data_torch: Tensor, new_name: str, bid: int | None) -> return data_torch.squeeze() +@Model.register("JambaForCausalLM") +class JambaModel(Model): + """Jamba is a hybrid SSM + Attention model and can support either Mamba or + Mamba2 style SSMs + """ + model_arch = gguf.MODEL_ARCH.JAMBA + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + # Determine if this is using Mamba or Mamba2 + self._mamba_version = self.hparams.get("mamba_version", "v1") + self._mamba_model_class: type[Model] = { + "v1": MambaModel, + "v2": Mamba2Model, + }.get(self._mamba_version, Model) + assert ( + self._mamba_model_class is not Model + ), f"Unsupported mamba_version: {self._mamba_version}" + + # Use Llama conversion for attention / FF / MoE + self._transformer_model_class: type[Model] = LlamaModel + + # Lists of which layers use ssm vs attention + self._attn_layers = self.hparams.get("attn_layer_indices", []) + if not self._attn_layers: + attn_period = self.hparams.get("attn_layer_period") + assert attn_period, "Didn't find attn_layer_indices or attn_layer_period" + attn_offset = self.hparams.get("attn_layer_offset") + assert attn_offset is not None, "No attention layer offset set with attn_layer_period" + self._attn_layers = [ + i for i in range(self.block_count) + if i % attn_period == attn_offset + ] + self._ssm_layers = [ + i for i in range(self.block_count) + if i not in self._attn_layers + ] + + def set_vocab(self): + self._mamba_model_class.set_vocab(self) + + def set_gguf_parameters(self): + # Set the mamba-type parameters + self._mamba_model_class.set_gguf_parameters(self) + + # TODO: All the rest! + + def modify_tensors( + self, data_torch: Tensor, name: str, bid: int | None + ) -> Iterable[tuple[str, Tensor]]: + # Determine whether this is a mamaba layer or an attention layer + if bid in self._ssm_layers: + for mamba_new_name, data_torch in self._mamba_model_class.modify_tensors( + self, data_torch, name, bid + ): + yield mamba_new_name, data_torch + elif bid in self._attn_layers: + for llama_new_name, data_torch in self._transformer_model_class.modify_tensors( + self, data_torch, name, bid + ): + yield llama_new_name, data_torch + else: + yield name, data_torch + + @Model.register("CohereForCausalLM") class CommandR2Model(Model): model_arch = gguf.MODEL_ARCH.COMMAND_R diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index c7bd9acd952a0..4739024f8495a 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -241,6 +241,7 @@ class MODEL_ARCH(IntEnum): RWKV6 = auto() MAMBA = auto() MAMBA2 = auto() + JAMBA = auto() XVERSE = auto() COMMAND_R = auto() DBRX = auto() @@ -405,6 +406,7 @@ class MODEL_TENSOR(IntEnum): MODEL_ARCH.RWKV6: "rwkv6", MODEL_ARCH.MAMBA: "mamba", MODEL_ARCH.MAMBA2: "mamba2", + MODEL_ARCH.JAMBA: "jamba", MODEL_ARCH.XVERSE: "xverse", MODEL_ARCH.COMMAND_R: "command-r", MODEL_ARCH.DBRX: "dbrx", @@ -1035,6 +1037,31 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.SSM_NORM, MODEL_TENSOR.SSM_OUT, ], + MODEL_ARCH.JAMBA: [ + MODEL_TENSOR.TOKEN_EMBD, + MODEL_TENSOR.OUTPUT_NORM, + MODEL_TENSOR.OUTPUT, + MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.SSM_IN, + MODEL_TENSOR.SSM_CONV1D, + MODEL_TENSOR.SSM_DT, + MODEL_TENSOR.SSM_A, + MODEL_TENSOR.SSM_D, + MODEL_TENSOR.SSM_NORM, + MODEL_TENSOR.SSM_OUT, + MODEL_TENSOR.ATTN_Q, + MODEL_TENSOR.ATTN_K, + MODEL_TENSOR.ATTN_V, + MODEL_TENSOR.ATTN_OUT, + MODEL_TENSOR.FFN_NORM, + MODEL_TENSOR.FFN_GATE, + MODEL_TENSOR.FFN_DOWN, + MODEL_TENSOR.FFN_UP, + MODEL_TENSOR.FFN_GATE_INP, + MODEL_TENSOR.FFN_GATE_EXP, + MODEL_TENSOR.FFN_DOWN_EXP, + MODEL_TENSOR.FFN_UP_EXP, + ], MODEL_ARCH.XVERSE: [ MODEL_TENSOR.TOKEN_EMBD, MODEL_TENSOR.OUTPUT_NORM, diff --git a/gguf-py/gguf/tensor_mapping.py b/gguf-py/gguf/tensor_mapping.py index fe4bfa3d00175..7e126d9bfeff5 100644 --- a/gguf-py/gguf/tensor_mapping.py +++ b/gguf-py/gguf/tensor_mapping.py @@ -101,7 +101,7 @@ class TensorNameMap: "transformer.h.{bid}.input_layernorm", # falcon7b "h.{bid}.input_layernorm", # bloom "transformer.h.{bid}.ln_mlp", # falcon40b - "model.layers.{bid}.input_layernorm", # llama-hf nemotron olmoe + "model.layers.{bid}.input_layernorm", # llama-hf nemotron olmoe jamba "layers.{bid}.attention_norm", # llama-pth "language_model.encoder.layers.{bid}.input_layernorm", # persimmon "model.layers.{bid}.ln1", # yi @@ -241,6 +241,7 @@ class TensorNameMap: "transformer.decoder_layer.{bid}.rms_norm_2", # Grok "encoder.layers.{bid}.post_attention_layernorm", # chatglm "transformer.layers.{bid}.ffn_norm", # openelm + "model.layers.{bid}.pre_ff_layernorm.weight", # jamba ), # Post feed-forward norm @@ -293,6 +294,7 @@ class TensorNameMap: "model.layers.{bid}.residual_mlp.w3", # arctic "encoder.layers.{bid}.mlp.dense_h_to_4h", # chatglm "transformer.h.{bid}.mlp.c_fc_1", # exaone + "model.layers.{bid}.feed_forward.up_proj", # jamba ), MODEL_TENSOR.FFN_UP_EXP: ( @@ -325,6 +327,7 @@ class TensorNameMap: "transformer.h.{bid}.mlp.linear_1", # refact "model.layers.{bid}.residual_mlp.w1", # arctic "transformer.h.{bid}.mlp.c_fc_0", # exaone + "model.layers.{bid}.feed_forward.gate_proj", # jamba ), MODEL_TENSOR.FFN_GATE_EXP: ( @@ -365,6 +368,7 @@ class TensorNameMap: "encoder.layer.{bid}.mlp.down_layer", # jina-bert-v2 "encoder.layers.{bid}.mlp.dense_4h_to_h", # chatglm "model.layers.h.{bid}.mlp.c_proj", # exaone + "model.layers.{bid}.feed_forward.down_proj", # jamba ), MODEL_TENSOR.FFN_DOWN_EXP: ( @@ -413,11 +417,13 @@ class TensorNameMap: MODEL_TENSOR.SSM_IN: ( "model.layers.{bid}.in_proj", "backbone.layers.{bid}.mixer.in_proj", + "model.layers.{bid}.mamba.in_proj", # jamba ), MODEL_TENSOR.SSM_CONV1D: ( "model.layers.{bid}.conv1d", "backbone.layers.{bid}.mixer.conv1d", + "model.layers.{bid}.mamba.conv1d", # jamba ), MODEL_TENSOR.SSM_X: ( @@ -428,25 +434,30 @@ class TensorNameMap: MODEL_TENSOR.SSM_DT: ( "model.layers.{bid}.dt_proj", "backbone.layers.{bid}.mixer.dt_proj", + "model.layers.{bid}.mamba.dt_proj", # jamba ), MODEL_TENSOR.SSM_A: ( "model.layers.{bid}.A_log", "backbone.layers.{bid}.mixer.A_log", + "model.layers.{bid}.mamba.A_log", # jamba ), MODEL_TENSOR.SSM_D: ( "model.layers.{bid}.D", "backbone.layers.{bid}.mixer.D", + "model.layers.{bid}.mamba.D", # jamba ), MODEL_TENSOR.SSM_NORM: ( "backbone.layers.{bid}.mixer.norm", # mamba2 + "model.layers.{bid}.mamba.norm", # jamba ), MODEL_TENSOR.SSM_OUT: ( "model.layers.{bid}.out_proj", "backbone.layers.{bid}.mixer.out_proj", + "model.layers.{bid}.mamba.out_proj", # jamba ), MODEL_TENSOR.TIME_MIX_W1: ( From 246dfdba65fc5708256028d472beb33240d5639b Mon Sep 17 00:00:00 2001 From: Gabe Goodhart Date: Tue, 26 Nov 2024 14:30:12 -0700 Subject: [PATCH 20/37] feat(jamba): Add jamba architecture to llama.cpp enums Branch: BambaArchitecture Signed-off-by: Gabe Goodhart --- src/llama.cpp | 34 ++++++++++++++++++++++++++++++++++ 1 file changed, 34 insertions(+) diff --git a/src/llama.cpp b/src/llama.cpp index 4168a392c797d..803d2136367d6 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -176,6 +176,7 @@ enum llm_arch { LLM_ARCH_STARCODER2, LLM_ARCH_MAMBA, LLM_ARCH_MAMBA2, + LLM_ARCH_JAMBA, LLM_ARCH_XVERSE, LLM_ARCH_COMMAND_R, LLM_ARCH_DBRX, @@ -231,6 +232,7 @@ static const std::map LLM_ARCH_NAMES = { { LLM_ARCH_STARCODER2, "starcoder2" }, { LLM_ARCH_MAMBA, "mamba" }, { LLM_ARCH_MAMBA2, "mamba2" }, + { LLM_ARCH_JAMBA, "jamba" }, { LLM_ARCH_XVERSE, "xverse" }, { LLM_ARCH_COMMAND_R, "command-r" }, { LLM_ARCH_DBRX, "dbrx" }, @@ -1164,6 +1166,38 @@ static const std::map> LLM_TENSOR_N { LLM_TENSOR_SSM_OUT, "blk.%d.ssm_out" }, }, }, + { + LLM_ARCH_JAMBA, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + // mamba(2) ssm layers + { LLM_TENSOR_SSM_IN, "blk.%d.ssm_in" }, + { LLM_TENSOR_SSM_CONV1D, "blk.%d.ssm_conv1d" }, + { LLM_TENSOR_SSM_DT, "blk.%d.ssm_dt" }, + { LLM_TENSOR_SSM_A, "blk.%d.ssm_a" }, + { LLM_TENSOR_SSM_D, "blk.%d.ssm_d" }, + { LLM_TENSOR_SSM_NORM, "blk.%d.ssm_norm" }, + { LLM_TENSOR_SSM_OUT, "blk.%d.ssm_out" }, + // attention layers + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + // non-moe FFN + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + // moe FFN + { LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" }, + { LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" }, + { LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" }, + { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" }, + }, + }, { LLM_ARCH_XVERSE, { From e3525e9e5046449e78c444d05bb26f6212c835e9 Mon Sep 17 00:00:00 2001 From: Gabe Goodhart Date: Mon, 2 Dec 2024 16:27:19 -0700 Subject: [PATCH 21/37] feat(convert): Full pass at hparam conversion Branch: BambaArchitecture Signed-off-by: Gabe Goodhart --- convert_hf_to_gguf.py | 120 ++++++++++++++++++++++++++++++------ gguf-py/gguf/constants.py | 9 +++ gguf-py/gguf/gguf_writer.py | 21 +++++++ 3 files changed, 131 insertions(+), 19 deletions(-) diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index fa67350ab14d8..da4526ee6d173 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -3007,6 +3007,14 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter class Mamba2Model(Model): model_arch = gguf.MODEL_ARCH.MAMBA2 + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + # n_groups and d_inner are used during reshaping + self.d_model = self.find_hparam(["hidden_size", "d_model", "dim"]) + self.n_group = self.find_hparam(["n_groups"], optional=True) or 1 + self.d_inner = self.find_hparam(["intermediate_size", "d_inner"], optional=True) or 2 * self.d_model + def set_vocab(self): vocab_size = self.hparams["vocab_size"] # Round vocab size to next multiple of 16 @@ -3028,30 +3036,27 @@ def set_vocab(self): self._set_vocab_builtin("gpt-neox", vocab_size) def set_gguf_parameters(self): - d_model = self.find_hparam(["hidden_size", "d_model", "dim"]) - d_conv = self.find_hparam(["conv_kernel", "d_conv"], optional=True) or 4 - d_inner = self.find_hparam(["intermediate_size", "d_inner"], optional=True) or 2 * d_model - d_state = self.find_hparam(["state_size", "d_state"], optional=True) or 128 - head_dim = self.find_hparam(["head_dim"], optional=True) or 64 - n_group = self.find_hparam(["n_groups"], optional=True) or 1 + d_conv = self.find_hparam(["conv_kernel", "d_conv"], optional=True) or 4 + d_state = self.find_hparam(["state_size", "d_state"], optional=True) or 128 + head_dim = self.find_hparam(["head_dim"], optional=True) or 64 rms_norm_eps = self.find_hparam(["layer_norm_epsilon", "rms_norm_eps"], optional=True) or 1e-5 # Fail early for models which don't have a block expansion factor of 2 # TODO: does this really matter? - # assert d_inner == 2 * d_model - assert d_inner % head_dim == 0 + assert self.d_inner == 2 * self.d_model + assert self.d_inner % head_dim == 0 self.gguf_writer.add_context_length(2**20) # arbitrary value; for those who use the default - self.gguf_writer.add_embedding_length(d_model) + self.gguf_writer.add_embedding_length(self.d_model) self.gguf_writer.add_feed_forward_length(0) # unused, but seemingly required when loading self.gguf_writer.add_head_count(0) # unused, but seemingly required when loading self.gguf_writer.add_block_count(self.block_count) self.gguf_writer.add_ssm_conv_kernel(d_conv) - self.gguf_writer.add_ssm_inner_size(d_inner) + self.gguf_writer.add_ssm_inner_size(self.d_inner) self.gguf_writer.add_ssm_state_size(d_state) - self.gguf_writer.add_ssm_time_step_rank(d_inner // head_dim) - self.gguf_writer.add_ssm_group_count(n_group) + self.gguf_writer.add_ssm_time_step_rank(self.d_inner // head_dim) + self.gguf_writer.add_ssm_group_count(self.n_group) self.gguf_writer.add_layer_norm_rms_eps(rms_norm_eps) self.gguf_writer.add_file_type(self.ftype) @@ -3083,10 +3088,7 @@ def reshape_tensors(self, data_torch: Tensor, new_name: str, bid: int | None) -> return data_torch.reshape((*data_torch.shape, 1)) elif self.match_model_tensor_name(new_name, gguf.MODEL_TENSOR.SSM_NORM, bid): - d_model = self.find_hparam(["hidden_size", "d_model", "dim"]) - d_inner = self.find_hparam(["intermediate_size", "d_inner"], optional=True) or 2 * d_model - n_group = self.hparams.get("n_groups", 1) - return data_torch.reshape((n_group, d_inner // n_group)) + return data_torch.reshape((self.n_group, self.d_inner // self.n_group)) return data_torch.squeeze() @@ -3099,6 +3101,11 @@ class JambaModel(Model): model_arch = gguf.MODEL_ARCH.JAMBA def __init__(self, *args, **kwargs): + + # Hybrid mamba models use a prefix for the mamba-specific params. + # TODO: Extend this if the prefix(es) need to be configurable + self.hparam_prefixes = ["mamba"] + super().__init__(*args, **kwargs) # Determine if this is using Mamba or Mamba2 @@ -3130,14 +3137,73 @@ def __init__(self, *args, **kwargs): if i not in self._attn_layers ] + # n_group and d_inner are used during reshape_tensors for mamaba2 + self.d_model = self.find_hparam(["hidden_size", "d_model"]) + self.n_group = self.find_hparam(["n_groups"]) + self.d_inner = self.find_hparam(["expand"]) * self.d_model + + def find_hparam(self, keys: Iterable[str], *args, **kwargs) -> Any: + prefixed = [] + for pfx in self.hparam_prefixes: + prefixed.extend( + "_".join([pfx, k]) + for k in keys + ) + keys = list(keys) + prefixed + return super().find_hparam(keys, *args, **kwargs) + def set_vocab(self): self._mamba_model_class.set_vocab(self) def set_gguf_parameters(self): - # Set the mamba-type parameters - self._mamba_model_class.set_gguf_parameters(self) - # TODO: All the rest! + ## General Params ## + self.gguf_writer.add_embedding_length(self.d_model) + self.gguf_writer.add_mamba_version(self._mamba_version) + self.gguf_writer.add_block_count(self.block_count) + self.gguf_writer.add_context_length(self.hparams.get("max_position_embeddings", 0)) + self.gguf_writer.add_vocab_size(self.hparams["vocab_size"]) + + ## Mamba mixer params ## + self.gguf_writer.add_ssm_conv_kernel(self.find_hparam(["conv_kernel", "d_conv"])) + self.gguf_writer.add_ssm_state_size(self.find_hparam(["state_size", "d_state"])) + self.gguf_writer.add_ssm_group_count(self.n_group) + self.gguf_writer.add_ssm_time_step_rank(self.find_hparam(["time_step_rank", "dt_rank"])) + self.gguf_writer.add_ssm_inner_size(self.d_inner) + self.gguf_writer.add_ssm_head_count(self.find_hparam(["n_heads"])) + self.gguf_writer.add_ssm_head_dim(d_head := self.find_hparam(["d_head"])) + self.gguf_writer.add_ssm_conv_bias(self.find_hparam(["conv_bias"], optional=True) or False) + self.gguf_writer.add_ssm_proj_bias(self.find_hparam(["proj_bias"], optional=True) or False) + self.gguf_writer.add_ssm_chunk_size(self.find_hparam(["chunk_size"])) + # TODO: I think this will always be true if available? + # "use_mamba_kernels": true, + + ## Attention params ## + self.gguf_writer.add_attn_layer_indices(self._attn_layers) + self.gguf_writer.add_rope_dimension_count(self.hparams["attn_rotary_emb"]) + self.gguf_writer.add_head_count(self.hparams["num_attention_heads"]) + self.gguf_writer.add_head_count_kv(self.find_hparam(["num_key_value_heads", "n_head_kv"])) + + ## Feed Forward Params ## + rms_norm_eps = self.find_hparam(["layer_norm_epsilon", "rms_norm_eps"], optional=True) or 1e-5 + + ## Validation ## + assert self.hparams.get("hidden_act") in [None, "silu"], "Only SILU activation supported" + assert self.d_inner % d_head == 0, f"SSM inner size {self.d_inner} not a multiple of head dim {d_head}" + # TODO: Support MoE FFN configurations + # "num_experts" + # "num_experts_per_tok" + # "expert_layer_offset" + # "expert_layer_period" + assert self.hparams.get("num_experts") in [None, 1], "MoE not currently supported" + + ## UNUSED?? ## + # "tie_word_embeddings" <-- Implied by presence of output weights + # "router_aux_loss_coef" <-- Only used if outputting router logits + # "num_logits_to_keep" <-- Always only keep final token logits + # "output_router_logits" <-- Never output router logits since only doing generate + # "use_cache" <-- KV Cache always enabled + # "sliding_window" <-- Used for flash attention in transformers def modify_tensors( self, data_torch: Tensor, name: str, bid: int | None @@ -3157,6 +3223,22 @@ def modify_tensors( yield name, data_torch + def reshape_tensors( + self, + data_torch: Tensor, + new_name: str, bid: int | None, + ) -> Tensor: + if bid in self._ssm_layers: + return self._mamba_model_class.reshape_tensors( + self, data_torch, new_name, bid + ) + elif bid in self._attn_layers: + return self._transformer_model_class.reshape_tensors( + self, data_torch, new_name, bid + ) + return data_torch + + @Model.register("CohereForCausalLM") class CommandR2Model(Model): model_arch = gguf.MODEL_ARCH.COMMAND_R diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index 4739024f8495a..3e5c928635f00 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -151,6 +151,15 @@ class SSM: TIME_STEP_RANK = "{arch}.ssm.time_step_rank" GROUP_COUNT = "{arch}.ssm.group_count" DT_B_C_RMS = "{arch}.ssm.dt_b_c_rms" + HEAD_COUNT = "{arch}.ssm.head_count" + HEAD_DIM = "{arch}.ssm.head_dim" + CHUNK_SIZE = "{arch}.ssm.chunk_size" + CONV_BIAS = "{arch}.ssm.conv_bias" + PROJ_BIAS = "{arch}.ssm.proj_bias" + + class HybridMamba: + MAMBA_VERSION = "{arch}.mamba.version" + ATTN_LAYER_INDICES = "{arch}.attn.layers" class WKV: HEAD_SIZE = "{arch}.wkv.head_size" diff --git a/gguf-py/gguf/gguf_writer.py b/gguf-py/gguf/gguf_writer.py index 084dd25e2d5c1..7b0126ce1857b 100644 --- a/gguf-py/gguf/gguf_writer.py +++ b/gguf-py/gguf/gguf_writer.py @@ -790,6 +790,27 @@ def add_ssm_group_count(self, value: int) -> None: def add_ssm_dt_b_c_rms(self, value: bool) -> None: self.add_bool(Keys.SSM.DT_B_C_RMS.format(arch=self.arch), value) + def add_ssm_head_count(self, value: int) -> None: + self.add_uint32(Keys.SSM.HEAD_COUNT.format(arch=self.arch), value) + + def add_ssm_head_dim(self, value: int) -> None: + self.add_uint32(Keys.SSM.HEAD_DIM.format(arch=self.arch), value) + + def add_ssm_chunk_size(self, value: int) -> None: + self.add_uint32(Keys.SSM.CHUNK_SIZE.format(arch=self.arch), value) + + def add_ssm_conv_bias(self, value: bool) -> None: + self.add_bool(Keys.SSM.CONV_BIAS.format(arch=self.arch), value) + + def add_ssm_proj_bias(self, value: bool) -> None: + self.add_bool(Keys.SSM.PROJ_BIAS.format(arch=self.arch), value) + + def add_mamba_version(self, value: str) -> None: + self.add_string(Keys.HybridMamba.MAMBA_VERSION.format(arch=self.arch), value) + + def add_attn_layer_indices(self, values: list[int]) -> None: + self.add_array(Keys.HybridMamba.ATTN_LAYER_INDICES.format(arch=self.arch), values) + def add_tokenizer_model(self, model: str) -> None: self.add_string(Keys.Tokenizer.MODEL, model) From fd98682ec387b4bd8b50e20085028b9e4fb8168e Mon Sep 17 00:00:00 2001 From: Gabe Goodhart Date: Tue, 3 Dec 2024 16:27:29 -0700 Subject: [PATCH 22/37] fix(bamba conv): Jamba -> Bamba Branch: BambaArchitecture Signed-off-by: Gabe Goodhart --- convert_hf_to_gguf.py | 57 +++++++++++-------------------------- gguf-py/gguf/constants.py | 9 +++--- gguf-py/gguf/gguf_writer.py | 3 -- 3 files changed, 21 insertions(+), 48 deletions(-) diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index da4526ee6d173..13f3b570a4fba 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -3093,12 +3093,12 @@ def reshape_tensors(self, data_torch: Tensor, new_name: str, bid: int | None) -> return data_torch.squeeze() +# TODO: Switch to BambaForCausalLM once ready in transformers +# @Model.register("BambaForCausalLM") @Model.register("JambaForCausalLM") -class JambaModel(Model): - """Jamba is a hybrid SSM + Attention model and can support either Mamba or - Mamba2 style SSMs - """ - model_arch = gguf.MODEL_ARCH.JAMBA +class BambaModel(Mamba2Model): + """Bamba is a hybrid SSM + Attention model that uses Mamba2 SSM layers""" + model_arch = gguf.MODEL_ARCH.BAMBA def __init__(self, *args, **kwargs): @@ -3108,17 +3108,7 @@ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - # Determine if this is using Mamba or Mamba2 - self._mamba_version = self.hparams.get("mamba_version", "v1") - self._mamba_model_class: type[Model] = { - "v1": MambaModel, - "v2": Mamba2Model, - }.get(self._mamba_version, Model) - assert ( - self._mamba_model_class is not Model - ), f"Unsupported mamba_version: {self._mamba_version}" - - # Use Llama conversion for attention / FF / MoE + # Use Llama conversion for attention self._transformer_model_class: type[Model] = LlamaModel # Lists of which layers use ssm vs attention @@ -3152,17 +3142,14 @@ def find_hparam(self, keys: Iterable[str], *args, **kwargs) -> Any: keys = list(keys) + prefixed return super().find_hparam(keys, *args, **kwargs) - def set_vocab(self): - self._mamba_model_class.set_vocab(self) - def set_gguf_parameters(self): ## General Params ## self.gguf_writer.add_embedding_length(self.d_model) - self.gguf_writer.add_mamba_version(self._mamba_version) self.gguf_writer.add_block_count(self.block_count) self.gguf_writer.add_context_length(self.hparams.get("max_position_embeddings", 0)) self.gguf_writer.add_vocab_size(self.hparams["vocab_size"]) + self.gguf_writer.add_feed_forward_length(self.hparams["intermediate_size"]) ## Mamba mixer params ## self.gguf_writer.add_ssm_conv_kernel(self.find_hparam(["conv_kernel", "d_conv"])) @@ -3175,8 +3162,6 @@ def set_gguf_parameters(self): self.gguf_writer.add_ssm_conv_bias(self.find_hparam(["conv_bias"], optional=True) or False) self.gguf_writer.add_ssm_proj_bias(self.find_hparam(["proj_bias"], optional=True) or False) self.gguf_writer.add_ssm_chunk_size(self.find_hparam(["chunk_size"])) - # TODO: I think this will always be true if available? - # "use_mamba_kernels": true, ## Attention params ## self.gguf_writer.add_attn_layer_indices(self._attn_layers) @@ -3185,33 +3170,27 @@ def set_gguf_parameters(self): self.gguf_writer.add_head_count_kv(self.find_hparam(["num_key_value_heads", "n_head_kv"])) ## Feed Forward Params ## - rms_norm_eps = self.find_hparam(["layer_norm_epsilon", "rms_norm_eps"], optional=True) or 1e-5 + self.gguf_writer.add_layer_norm_rms_eps( + self.find_hparam(["layer_norm_epsilon", "rms_norm_eps"], optional=True) or 1e-5 + ) ## Validation ## assert self.hparams.get("hidden_act") in [None, "silu"], "Only SILU activation supported" assert self.d_inner % d_head == 0, f"SSM inner size {self.d_inner} not a multiple of head dim {d_head}" - # TODO: Support MoE FFN configurations - # "num_experts" - # "num_experts_per_tok" - # "expert_layer_offset" - # "expert_layer_period" - assert self.hparams.get("num_experts") in [None, 1], "MoE not currently supported" ## UNUSED?? ## - # "tie_word_embeddings" <-- Implied by presence of output weights - # "router_aux_loss_coef" <-- Only used if outputting router logits - # "num_logits_to_keep" <-- Always only keep final token logits - # "output_router_logits" <-- Never output router logits since only doing generate - # "use_cache" <-- KV Cache always enabled - # "sliding_window" <-- Used for flash attention in transformers + # "tie_word_embeddings" <-- Implied by presence of output weights + # "num_logits_to_keep" <-- Always only keep final token logits + # "use_cache" <-- KV Cache always enabled + # "use_mamba_kernels" <-- I think this will always be true if available? def modify_tensors( self, data_torch: Tensor, name: str, bid: int | None ) -> Iterable[tuple[str, Tensor]]: # Determine whether this is a mamaba layer or an attention layer if bid in self._ssm_layers: - for mamba_new_name, data_torch in self._mamba_model_class.modify_tensors( - self, data_torch, name, bid + for mamba_new_name, data_torch in super().modify_tensors( + data_torch, name, bid ): yield mamba_new_name, data_torch elif bid in self._attn_layers: @@ -3229,9 +3208,7 @@ def reshape_tensors( new_name: str, bid: int | None, ) -> Tensor: if bid in self._ssm_layers: - return self._mamba_model_class.reshape_tensors( - self, data_torch, new_name, bid - ) + return super().reshape_tensors(data_torch, new_name, bid) elif bid in self._attn_layers: return self._transformer_model_class.reshape_tensors( self, data_torch, new_name, bid diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index 3e5c928635f00..c2d363309c3d5 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -158,8 +158,7 @@ class SSM: PROJ_BIAS = "{arch}.ssm.proj_bias" class HybridMamba: - MAMBA_VERSION = "{arch}.mamba.version" - ATTN_LAYER_INDICES = "{arch}.attn.layers" + ATTN_LAYER_INDICES = "{arch}.attention.layer_indices" class WKV: HEAD_SIZE = "{arch}.wkv.head_size" @@ -250,7 +249,7 @@ class MODEL_ARCH(IntEnum): RWKV6 = auto() MAMBA = auto() MAMBA2 = auto() - JAMBA = auto() + BAMBA = auto() XVERSE = auto() COMMAND_R = auto() DBRX = auto() @@ -415,7 +414,7 @@ class MODEL_TENSOR(IntEnum): MODEL_ARCH.RWKV6: "rwkv6", MODEL_ARCH.MAMBA: "mamba", MODEL_ARCH.MAMBA2: "mamba2", - MODEL_ARCH.JAMBA: "jamba", + MODEL_ARCH.BAMBA: "bamba", MODEL_ARCH.XVERSE: "xverse", MODEL_ARCH.COMMAND_R: "command-r", MODEL_ARCH.DBRX: "dbrx", @@ -1046,7 +1045,7 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.SSM_NORM, MODEL_TENSOR.SSM_OUT, ], - MODEL_ARCH.JAMBA: [ + MODEL_ARCH.BAMBA: [ MODEL_TENSOR.TOKEN_EMBD, MODEL_TENSOR.OUTPUT_NORM, MODEL_TENSOR.OUTPUT, diff --git a/gguf-py/gguf/gguf_writer.py b/gguf-py/gguf/gguf_writer.py index 7b0126ce1857b..399887b2e5817 100644 --- a/gguf-py/gguf/gguf_writer.py +++ b/gguf-py/gguf/gguf_writer.py @@ -805,9 +805,6 @@ def add_ssm_conv_bias(self, value: bool) -> None: def add_ssm_proj_bias(self, value: bool) -> None: self.add_bool(Keys.SSM.PROJ_BIAS.format(arch=self.arch), value) - def add_mamba_version(self, value: str) -> None: - self.add_string(Keys.HybridMamba.MAMBA_VERSION.format(arch=self.arch), value) - def add_attn_layer_indices(self, values: list[int]) -> None: self.add_array(Keys.HybridMamba.ATTN_LAYER_INDICES.format(arch=self.arch), values) From 1c1e0080ed8a728150912fdfbdf2e3d8ef5bec6a Mon Sep 17 00:00:00 2001 From: Gabe Goodhart Date: Tue, 3 Dec 2024 16:29:13 -0700 Subject: [PATCH 23/37] fix(bamba): Jamba->Bamba in llama.cpp Branch: BambaArchitecture Signed-off-by: Gabe Goodhart --- src/llama.cpp | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/llama.cpp b/src/llama.cpp index 803d2136367d6..ba3fa8360505d 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -176,7 +176,7 @@ enum llm_arch { LLM_ARCH_STARCODER2, LLM_ARCH_MAMBA, LLM_ARCH_MAMBA2, - LLM_ARCH_JAMBA, + LLM_ARCH_BAMBA, LLM_ARCH_XVERSE, LLM_ARCH_COMMAND_R, LLM_ARCH_DBRX, @@ -232,7 +232,7 @@ static const std::map LLM_ARCH_NAMES = { { LLM_ARCH_STARCODER2, "starcoder2" }, { LLM_ARCH_MAMBA, "mamba" }, { LLM_ARCH_MAMBA2, "mamba2" }, - { LLM_ARCH_JAMBA, "jamba" }, + { LLM_ARCH_BAMBA, "bamba" }, { LLM_ARCH_XVERSE, "xverse" }, { LLM_ARCH_COMMAND_R, "command-r" }, { LLM_ARCH_DBRX, "dbrx" }, @@ -1167,7 +1167,7 @@ static const std::map> LLM_TENSOR_N }, }, { - LLM_ARCH_JAMBA, + LLM_ARCH_BAMBA, { { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, From e0af809b05bb365f8c080c9dc99645f44751706d Mon Sep 17 00:00:00 2001 From: Gabe Goodhart Date: Tue, 3 Dec 2024 16:29:32 -0700 Subject: [PATCH 24/37] feat(bamba): hparam parsing in llama.cpp Branch: BambaArchitecture Signed-off-by: Gabe Goodhart --- src/llama.cpp | 80 +++++++++++++++++++++++++++++++++++++++++++++++---- 1 file changed, 74 insertions(+), 6 deletions(-) diff --git a/src/llama.cpp b/src/llama.cpp index ba3fa8360505d..001e3aea91dc8 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -310,6 +310,8 @@ enum llm_kv { LLM_KV_ATTENTION_RELATIVE_BUCKETS_COUNT, LLM_KV_ATTENTION_SLIDING_WINDOW, LLM_KV_ATTENTION_SCALE, + LLM_KV_ATTENTION_LAYER_COUNT, + LLM_KV_ATTENTION_LAYER_INDICES, LLM_KV_ROPE_DIMENSION_COUNT, LLM_KV_ROPE_FREQ_BASE, @@ -331,6 +333,11 @@ enum llm_kv { LLM_KV_SSM_TIME_STEP_RANK, LLM_KV_SSM_GROUP_COUNT, LLM_KV_SSM_DT_B_C_RMS, + LLM_KV_SSM_HEAD_COUNT, + LLM_KV_SSM_HEAD_DIM, + LLM_KV_SSM_CHUNK_SIZE, + LLM_KV_SSM_CONV_BIAS, + LLM_KV_SSM_PROJ_BIAS, LLM_KV_WKV_HEAD_SIZE, @@ -427,6 +434,7 @@ static const std::map LLM_KV_NAMES = { { LLM_KV_ATTENTION_RELATIVE_BUCKETS_COUNT, "%s.attention.relative_buckets_count" }, { LLM_KV_ATTENTION_SLIDING_WINDOW, "%s.attention.sliding_window" }, { LLM_KV_ATTENTION_SCALE, "%s.attention.scale" }, + { LLM_KV_ATTENTION_LAYER_INDICES, "%s.attention.layer_indices" }, { LLM_KV_ROPE_DIMENSION_COUNT, "%s.rope.dimension_count" }, { LLM_KV_ROPE_FREQ_BASE, "%s.rope.freq_base" }, @@ -448,6 +456,11 @@ static const std::map LLM_KV_NAMES = { { LLM_KV_SSM_TIME_STEP_RANK, "%s.ssm.time_step_rank" }, { LLM_KV_SSM_GROUP_COUNT, "%s.ssm.group_count" }, { LLM_KV_SSM_DT_B_C_RMS, "%s.ssm.dt_b_c_rms" }, + { LLM_KV_SSM_HEAD_COUNT, "%s.ssm.head_count" }, + { LLM_KV_SSM_HEAD_DIM, "%s.ssm.head_dim" }, + { LLM_KV_SSM_CHUNK_SIZE, "%s.ssm.chunk_size" }, + { LLM_KV_SSM_CONV_BIAS, "%s.ssm.conv_bias" }, + { LLM_KV_SSM_PROJ_BIAS, "%s.ssm.proj_bias" }, { LLM_KV_WKV_HEAD_SIZE, "%s.wkv.head_size" }, @@ -2471,12 +2484,20 @@ struct llama_hparams { float rope_yarn_log_mul; // for State Space Models - uint32_t ssm_d_conv = 0; - uint32_t ssm_d_inner = 0; - uint32_t ssm_d_state = 0; - uint32_t ssm_dt_rank = 0; - uint32_t ssm_n_group = 0; - bool ssm_dt_b_c_rms = false; + uint32_t ssm_d_conv = 0; + uint32_t ssm_d_inner = 0; + uint32_t ssm_d_state = 0; + uint32_t ssm_dt_rank = 0; + uint32_t ssm_n_group = 0; + bool ssm_dt_b_c_rms = false; + uint32_t ssm_head_count = 0; + uint32_t ssm_head_dim = 0; + uint32_t ssm_chunk_size = 0; + bool ssm_conv_bias = false; + bool ssm_proj_bias = false; + + // for hybrid state space models + std::array ssm_layer_arr; float f_clamp_kqv = 0.0f; float f_max_alibi_bias = 0.0f; @@ -2533,6 +2554,13 @@ struct llama_hparams { if (this->ssm_dt_rank != other.ssm_dt_rank) return true; if (this->ssm_n_group != other.ssm_n_group) return true; if (this->ssm_dt_b_c_rms != other.ssm_dt_b_c_rms) return true; + if (this->ssm_head_count != other.ssm_head_count) return true; + if (this->ssm_head_dim != other.ssm_head_dim) return true; + if (this->ssm_chunk_size != other.ssm_chunk_size) return true; + if (this->ssm_conv_bias != other.ssm_conv_bias) return true; + if (this->ssm_proj_bias != other.ssm_proj_bias) return true; + + if (this->ssm_layer_arr != other.ssm_layer_arr) return true; if (this->rescale_every_n_layers != other.rescale_every_n_layers) return true; if (this->time_mix_extra_dim != other.time_mix_extra_dim) return true; @@ -2625,6 +2653,10 @@ struct llama_hparams { return ssm_d_state * ssm_d_inner; } } + + bool ssm_layer(uint32_t il) const { + return ssm_layer_arr[il]; + } }; static_assert(std::is_trivially_copyable::value, "llama_hparams must be trivially copyable"); @@ -5492,6 +5524,7 @@ static void llm_load_hparams( std::fill(hparams.n_head_arr.begin(), hparams.n_head_arr.end(), 0); std::fill(hparams.n_head_kv_arr.begin(), hparams.n_head_kv_arr.end(), 0); std::fill(hparams.n_ff_arr.begin(), hparams.n_ff_arr.end(), 0); + std::fill(hparams.ssm_layer_arr.begin(), hparams.ssm_layer_arr.end(), false); ml.get_key_or_arr(LLM_KV_FEED_FORWARD_LENGTH, hparams.n_ff_arr, hparams.n_layer); ml.get_key_or_arr(LLM_KV_ATTENTION_HEAD_COUNT, hparams.n_head_arr, hparams.n_layer); @@ -5960,6 +5993,32 @@ static void llm_load_hparams( default: model.type = e_model::MODEL_UNKNOWN; } } break; + case LLM_ARCH_BAMBA: + { + // Mamba2 parameters + ml.get_key(LLM_KV_SSM_CONV_KERNEL, hparams.ssm_d_conv); + ml.get_key(LLM_KV_SSM_INNER_SIZE, hparams.ssm_d_inner); + ml.get_key(LLM_KV_SSM_STATE_SIZE, hparams.ssm_d_state); + ml.get_key(LLM_KV_SSM_TIME_STEP_RANK, hparams.ssm_dt_rank); + ml.get_key(LLM_KV_SSM_GROUP_COUNT, hparams.ssm_n_group); + ml.get_key(LLM_KV_SSM_HEAD_COUNT, hparams.ssm_head_count); + ml.get_key(LLM_KV_SSM_HEAD_DIM, hparams.ssm_head_dim); + ml.get_key(LLM_KV_SSM_CHUNK_SIZE, hparams.ssm_chunk_size); + ml.get_key(LLM_KV_SSM_CONV_BIAS, hparams.ssm_conv_bias); + ml.get_key(LLM_KV_SSM_PROJ_BIAS, hparams.ssm_proj_bias); + + // Attention params + std::fill(hparams.ssm_layer_arr.begin(), hparams.ssm_layer_arr.end(), true); + std::vector attn_layer_indices; + ml.get_arr(LLM_KV_ATTENTION_LAYER_INDICES, attn_layer_indices); + for (const auto attn_idx : attn_layer_indices) { + GGML_ASSERT(attn_idx < hparams.n_layer); + hparams.ssm_layer_arr[attn_idx] = false; + } + + // Feed forward params + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + } break; case LLM_ARCH_XVERSE: { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); @@ -7038,6 +7097,11 @@ static void llm_load_print_meta(llama_model_loader & ml, llama_model & model) { LLAMA_LOG_INFO("%s: ssm_dt_rank = %u\n", __func__, hparams.ssm_dt_rank); LLAMA_LOG_INFO("%s: ssm_n_group = %u\n", __func__, hparams.ssm_n_group); LLAMA_LOG_INFO("%s: ssm_dt_b_c_rms = %d\n", __func__, hparams.ssm_dt_b_c_rms); + LLAMA_LOG_INFO("%s: ssm_head_count = %d\n", __func__, hparams.ssm_head_count); + LLAMA_LOG_INFO("%s: ssm_head_dim = %d\n", __func__, hparams.ssm_head_dim); + LLAMA_LOG_INFO("%s: ssm_chunk_size = %d\n", __func__, hparams.ssm_chunk_size); + LLAMA_LOG_INFO("%s: ssm_conv_bias = %d\n", __func__, hparams.ssm_conv_bias); + LLAMA_LOG_INFO("%s: ssm_proj_bias = %d\n", __func__, hparams.ssm_proj_bias); } LLAMA_LOG_INFO("%s: model type = %s\n", __func__, llama_model_type_name(model.type)); @@ -7106,6 +7170,10 @@ static void llm_load_print_meta(llama_model_loader & ml, llama_model & model) { LLAMA_LOG_INFO("%s: f_residual_scale = %f\n", __func__, hparams.f_residual_scale); LLAMA_LOG_INFO("%s: f_attention_scale = %f\n", __func__, hparams.f_attention_scale); } + + if (model.arch == LLM_ARCH_BAMBA) { + LLAMA_LOG_INFO("%s: ssm_layer_arr = %s\n", __func__, print_f([&](uint32_t il) { return uint32_t(hparams.ssm_layer(il)); }, hparams.n_layer).c_str()); + } } enum llm_tensor_layer { From fd3bb30118a8c192f89c2faf998b41353f781e60 Mon Sep 17 00:00:00 2001 From: Gabe Goodhart Date: Wed, 4 Dec 2024 12:00:46 -0700 Subject: [PATCH 25/37] fix(bamba conv): Fizes in tensor name and hparam conversion for llama.cpp parsing Branch: BambaArchitecture Signed-off-by: Gabe Goodhart --- convert_hf_to_gguf.py | 12 ++++-------- gguf-py/gguf/constants.py | 2 -- gguf-py/gguf/gguf_writer.py | 6 ------ gguf-py/gguf/tensor_mapping.py | 26 +++++++++++++------------- 4 files changed, 17 insertions(+), 29 deletions(-) diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index 13f3b570a4fba..916686838bb88 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -3094,8 +3094,7 @@ def reshape_tensors(self, data_torch: Tensor, new_name: str, bid: int | None) -> # TODO: Switch to BambaForCausalLM once ready in transformers -# @Model.register("BambaForCausalLM") -@Model.register("JambaForCausalLM") +@Model.register("BambaForCausalLM") class BambaModel(Mamba2Model): """Bamba is a hybrid SSM + Attention model that uses Mamba2 SSM layers""" model_arch = gguf.MODEL_ARCH.BAMBA @@ -3159,8 +3158,6 @@ def set_gguf_parameters(self): self.gguf_writer.add_ssm_inner_size(self.d_inner) self.gguf_writer.add_ssm_head_count(self.find_hparam(["n_heads"])) self.gguf_writer.add_ssm_head_dim(d_head := self.find_hparam(["d_head"])) - self.gguf_writer.add_ssm_conv_bias(self.find_hparam(["conv_bias"], optional=True) or False) - self.gguf_writer.add_ssm_proj_bias(self.find_hparam(["proj_bias"], optional=True) or False) self.gguf_writer.add_ssm_chunk_size(self.find_hparam(["chunk_size"])) ## Attention params ## @@ -3187,6 +3184,7 @@ def set_gguf_parameters(self): def modify_tensors( self, data_torch: Tensor, name: str, bid: int | None ) -> Iterable[tuple[str, Tensor]]: + # Determine whether this is a mamaba layer or an attention layer if bid in self._ssm_layers: for mamba_new_name, data_torch in super().modify_tensors( @@ -3199,13 +3197,11 @@ def modify_tensors( ): yield llama_new_name, data_torch else: - yield name, data_torch + yield self.map_tensor_name(name), data_torch def reshape_tensors( - self, - data_torch: Tensor, - new_name: str, bid: int | None, + self, data_torch: Tensor, new_name: str, bid: int | None, ) -> Tensor: if bid in self._ssm_layers: return super().reshape_tensors(data_torch, new_name, bid) diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index c2d363309c3d5..1db6e5a4dfa7f 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -154,8 +154,6 @@ class SSM: HEAD_COUNT = "{arch}.ssm.head_count" HEAD_DIM = "{arch}.ssm.head_dim" CHUNK_SIZE = "{arch}.ssm.chunk_size" - CONV_BIAS = "{arch}.ssm.conv_bias" - PROJ_BIAS = "{arch}.ssm.proj_bias" class HybridMamba: ATTN_LAYER_INDICES = "{arch}.attention.layer_indices" diff --git a/gguf-py/gguf/gguf_writer.py b/gguf-py/gguf/gguf_writer.py index 399887b2e5817..6e9c61d9dff77 100644 --- a/gguf-py/gguf/gguf_writer.py +++ b/gguf-py/gguf/gguf_writer.py @@ -799,12 +799,6 @@ def add_ssm_head_dim(self, value: int) -> None: def add_ssm_chunk_size(self, value: int) -> None: self.add_uint32(Keys.SSM.CHUNK_SIZE.format(arch=self.arch), value) - def add_ssm_conv_bias(self, value: bool) -> None: - self.add_bool(Keys.SSM.CONV_BIAS.format(arch=self.arch), value) - - def add_ssm_proj_bias(self, value: bool) -> None: - self.add_bool(Keys.SSM.PROJ_BIAS.format(arch=self.arch), value) - def add_attn_layer_indices(self, values: list[int]) -> None: self.add_array(Keys.HybridMamba.ATTN_LAYER_INDICES.format(arch=self.arch), values) diff --git a/gguf-py/gguf/tensor_mapping.py b/gguf-py/gguf/tensor_mapping.py index 7e126d9bfeff5..4bd5af362ec2b 100644 --- a/gguf-py/gguf/tensor_mapping.py +++ b/gguf-py/gguf/tensor_mapping.py @@ -13,7 +13,7 @@ class TensorNameMap: "transformer.wte", # gpt2 gpt-j mpt refact qwen dbrx jais exaone "transformer.word_embeddings", # falcon "word_embeddings", # bloom - "model.embed_tokens", # llama-hf nemotron olmoe olmo_1124 + "model.embed_tokens", # llama-hf nemotron olmoe olmo_1124 bamba "tok_embeddings", # llama-pth "embeddings.word_embeddings", # bert nomic-bert "language_model.embedding.word_embeddings", # persimmon @@ -101,7 +101,7 @@ class TensorNameMap: "transformer.h.{bid}.input_layernorm", # falcon7b "h.{bid}.input_layernorm", # bloom "transformer.h.{bid}.ln_mlp", # falcon40b - "model.layers.{bid}.input_layernorm", # llama-hf nemotron olmoe jamba + "model.layers.{bid}.input_layernorm", # llama-hf nemotron olmoe bamba "layers.{bid}.attention_norm", # llama-pth "language_model.encoder.layers.{bid}.input_layernorm", # persimmon "model.layers.{bid}.ln1", # yi @@ -241,7 +241,7 @@ class TensorNameMap: "transformer.decoder_layer.{bid}.rms_norm_2", # Grok "encoder.layers.{bid}.post_attention_layernorm", # chatglm "transformer.layers.{bid}.ffn_norm", # openelm - "model.layers.{bid}.pre_ff_layernorm.weight", # jamba + "model.layers.{bid}.pre_ff_layernorm", # bamba ), # Post feed-forward norm @@ -294,7 +294,7 @@ class TensorNameMap: "model.layers.{bid}.residual_mlp.w3", # arctic "encoder.layers.{bid}.mlp.dense_h_to_4h", # chatglm "transformer.h.{bid}.mlp.c_fc_1", # exaone - "model.layers.{bid}.feed_forward.up_proj", # jamba + "model.layers.{bid}.feed_forward.up_proj", # bamba ), MODEL_TENSOR.FFN_UP_EXP: ( @@ -327,7 +327,7 @@ class TensorNameMap: "transformer.h.{bid}.mlp.linear_1", # refact "model.layers.{bid}.residual_mlp.w1", # arctic "transformer.h.{bid}.mlp.c_fc_0", # exaone - "model.layers.{bid}.feed_forward.gate_proj", # jamba + "model.layers.{bid}.feed_forward.gate_proj", # bamba ), MODEL_TENSOR.FFN_GATE_EXP: ( @@ -368,7 +368,7 @@ class TensorNameMap: "encoder.layer.{bid}.mlp.down_layer", # jina-bert-v2 "encoder.layers.{bid}.mlp.dense_4h_to_h", # chatglm "model.layers.h.{bid}.mlp.c_proj", # exaone - "model.layers.{bid}.feed_forward.down_proj", # jamba + "model.layers.{bid}.feed_forward.down_proj", # bamba ), MODEL_TENSOR.FFN_DOWN_EXP: ( @@ -417,13 +417,13 @@ class TensorNameMap: MODEL_TENSOR.SSM_IN: ( "model.layers.{bid}.in_proj", "backbone.layers.{bid}.mixer.in_proj", - "model.layers.{bid}.mamba.in_proj", # jamba + "model.layers.{bid}.mamba.in_proj", # bamba ), MODEL_TENSOR.SSM_CONV1D: ( "model.layers.{bid}.conv1d", "backbone.layers.{bid}.mixer.conv1d", - "model.layers.{bid}.mamba.conv1d", # jamba + "model.layers.{bid}.mamba.conv1d", # bamba ), MODEL_TENSOR.SSM_X: ( @@ -434,30 +434,30 @@ class TensorNameMap: MODEL_TENSOR.SSM_DT: ( "model.layers.{bid}.dt_proj", "backbone.layers.{bid}.mixer.dt_proj", - "model.layers.{bid}.mamba.dt_proj", # jamba + "model.layers.{bid}.mamba.dt_proj", # bamba ), MODEL_TENSOR.SSM_A: ( "model.layers.{bid}.A_log", "backbone.layers.{bid}.mixer.A_log", - "model.layers.{bid}.mamba.A_log", # jamba + "model.layers.{bid}.mamba.A_log", # bamba ), MODEL_TENSOR.SSM_D: ( "model.layers.{bid}.D", "backbone.layers.{bid}.mixer.D", - "model.layers.{bid}.mamba.D", # jamba + "model.layers.{bid}.mamba.D", # bamba ), MODEL_TENSOR.SSM_NORM: ( "backbone.layers.{bid}.mixer.norm", # mamba2 - "model.layers.{bid}.mamba.norm", # jamba + "model.layers.{bid}.mamba.norm", # bamba ), MODEL_TENSOR.SSM_OUT: ( "model.layers.{bid}.out_proj", "backbone.layers.{bid}.mixer.out_proj", - "model.layers.{bid}.mamba.out_proj", # jamba + "model.layers.{bid}.mamba.out_proj", # bamba ), MODEL_TENSOR.TIME_MIX_W1: ( From 3ee0ae3b9062726cf17362601dfc45622659175a Mon Sep 17 00:00:00 2001 From: Gabe Goodhart Date: Wed, 4 Dec 2024 12:01:45 -0700 Subject: [PATCH 26/37] feat(bamba): Full tensor parsing for bamba Branch: BambaArchitecture Signed-off-by: Gabe Goodhart --- src/llama.cpp | 104 +++++++++++++++++++++++++++++++++++++++++++------- 1 file changed, 91 insertions(+), 13 deletions(-) diff --git a/src/llama.cpp b/src/llama.cpp index 001e3aea91dc8..7a91fc3bbe647 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -336,8 +336,6 @@ enum llm_kv { LLM_KV_SSM_HEAD_COUNT, LLM_KV_SSM_HEAD_DIM, LLM_KV_SSM_CHUNK_SIZE, - LLM_KV_SSM_CONV_BIAS, - LLM_KV_SSM_PROJ_BIAS, LLM_KV_WKV_HEAD_SIZE, @@ -459,8 +457,6 @@ static const std::map LLM_KV_NAMES = { { LLM_KV_SSM_HEAD_COUNT, "%s.ssm.head_count" }, { LLM_KV_SSM_HEAD_DIM, "%s.ssm.head_dim" }, { LLM_KV_SSM_CHUNK_SIZE, "%s.ssm.chunk_size" }, - { LLM_KV_SSM_CONV_BIAS, "%s.ssm.conv_bias" }, - { LLM_KV_SSM_PROJ_BIAS, "%s.ssm.proj_bias" }, { LLM_KV_WKV_HEAD_SIZE, "%s.wkv.head_size" }, @@ -2493,8 +2489,6 @@ struct llama_hparams { uint32_t ssm_head_count = 0; uint32_t ssm_head_dim = 0; uint32_t ssm_chunk_size = 0; - bool ssm_conv_bias = false; - bool ssm_proj_bias = false; // for hybrid state space models std::array ssm_layer_arr; @@ -2557,8 +2551,6 @@ struct llama_hparams { if (this->ssm_head_count != other.ssm_head_count) return true; if (this->ssm_head_dim != other.ssm_head_dim) return true; if (this->ssm_chunk_size != other.ssm_chunk_size) return true; - if (this->ssm_conv_bias != other.ssm_conv_bias) return true; - if (this->ssm_proj_bias != other.ssm_proj_bias) return true; if (this->ssm_layer_arr != other.ssm_layer_arr) return true; @@ -2800,6 +2792,7 @@ struct llama_layer { // mamba bias struct ggml_tensor * ssm_conv1d_b; struct ggml_tensor * ssm_dt_b; + struct ggml_tensor * ssm_in_b; // rwkv struct ggml_tensor * time_mix_w1; @@ -6004,8 +5997,6 @@ static void llm_load_hparams( ml.get_key(LLM_KV_SSM_HEAD_COUNT, hparams.ssm_head_count); ml.get_key(LLM_KV_SSM_HEAD_DIM, hparams.ssm_head_dim); ml.get_key(LLM_KV_SSM_CHUNK_SIZE, hparams.ssm_chunk_size); - ml.get_key(LLM_KV_SSM_CONV_BIAS, hparams.ssm_conv_bias); - ml.get_key(LLM_KV_SSM_PROJ_BIAS, hparams.ssm_proj_bias); // Attention params std::fill(hparams.ssm_layer_arr.begin(), hparams.ssm_layer_arr.end(), true); @@ -7100,8 +7091,6 @@ static void llm_load_print_meta(llama_model_loader & ml, llama_model & model) { LLAMA_LOG_INFO("%s: ssm_head_count = %d\n", __func__, hparams.ssm_head_count); LLAMA_LOG_INFO("%s: ssm_head_dim = %d\n", __func__, hparams.ssm_head_dim); LLAMA_LOG_INFO("%s: ssm_chunk_size = %d\n", __func__, hparams.ssm_chunk_size); - LLAMA_LOG_INFO("%s: ssm_conv_bias = %d\n", __func__, hparams.ssm_conv_bias); - LLAMA_LOG_INFO("%s: ssm_proj_bias = %d\n", __func__, hparams.ssm_proj_bias); } LLAMA_LOG_INFO("%s: model type = %s\n", __func__, llama_model_type_name(model.type)); @@ -7761,6 +7750,12 @@ static bool llm_load_tensors( model.layers.resize(n_layer); + // Log out tensor names for verbose debugging + LLAMA_LOG_DEBUG("%s: TENSORS\n", __func__); + for (const auto& entry : ml.weights_map) { + LLAMA_LOG_DEBUG("%s: %s\n", __func__, entry.first.c_str()); + } + // TODO: move to a separate function const auto tn = LLM_TN(model.arch); switch (model.arch) { @@ -8740,6 +8735,83 @@ static bool llm_load_tensors( layer.ssm_out = create_tensor(tn(LLM_TENSOR_SSM_OUT, "weight", i), {d_inner, n_embd}, 0); } } break; + case LLM_ARCH_BAMBA: + { + // mamba2 Mixer SSM params + // TODO: Why are these int64_t and not uint32_t? + const int64_t d_conv = hparams.ssm_d_conv; + const int64_t d_inner = hparams.ssm_d_inner; + const int64_t d_state = hparams.ssm_d_state; + const int64_t n_group = hparams.ssm_n_group; + const int64_t head_count = hparams.ssm_head_count; + const int64_t head_dim = hparams.ssm_head_dim; + const int64_t chunk_size = hparams.ssm_chunk_size; + const int64_t d_in_proj = 2*d_inner + 2*n_group*d_state + head_count; + + // only an expansion factor of 2 is supported for now + GGML_ASSERT(2 * n_embd == d_inner); + + // embeddings + model.tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + { + model.output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + model.output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED); + // if output is NULL, init from the input tok embed, duplicated to allow offloading + if (model.output == NULL) { + model.output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED); + } + } + + for (int i = 0; i < n_layer; ++i) { + auto & layer = model.layers[i]; + + // norm + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + if (hparams.ssm_layer(i)) { + // ssm layers + layer.ssm_in = create_tensor(tn(LLM_TENSOR_SSM_IN, "weight", i), {n_embd, d_in_proj}, 0); + layer.ssm_in_b = create_tensor(tn(LLM_TENSOR_SSM_IN, "bias", i), {n_embd, d_in_proj}, llama_model_loader::TENSOR_NOT_REQUIRED); + + layer.ssm_conv1d = create_tensor(tn(LLM_TENSOR_SSM_CONV1D, "weight", i), {d_conv, d_inner + 2*n_group*d_state}, 0); + layer.ssm_conv1d_b = create_tensor(tn(LLM_TENSOR_SSM_CONV1D, "bias", i), {d_inner + 2*n_group*d_state}, llama_model_loader::TENSOR_NOT_REQUIRED); + + layer.ssm_dt_b = create_tensor(tn(LLM_TENSOR_SSM_DT, "bias", i), {head_count}, 0); + + // no "weight" suffix for these + layer.ssm_a = create_tensor(tn(LLM_TENSOR_SSM_A, i), {1, head_count}, 0); + layer.ssm_d = create_tensor(tn(LLM_TENSOR_SSM_D, i), {1, head_count}, 0); + + layer.ssm_norm = create_tensor(tn(LLM_TENSOR_SSM_NORM, "weight", i), {d_inner / n_group, n_group}, 0); + + // out_proj + layer.ssm_out = create_tensor(tn(LLM_TENSOR_SSM_OUT, "weight", i), {d_inner, n_embd}, 0); + } else { + // attention layers (with optional bias) + layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}, 0); + layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa}, 0); + layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa}, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0); + layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED); + layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd_gqa}, llama_model_loader::TENSOR_NOT_REQUIRED); + layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_gqa}, llama_model_loader::TENSOR_NOT_REQUIRED); + layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED); + + } + + // feed forward (w/ optional biases) + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + layer.rope_freqs = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight", i), {n_rot/2}, llama_model_loader::TENSOR_NOT_REQUIRED | (i != 0 ? llama_model_loader::TENSOR_DUPLICATED : 0)); + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_gate_b = create_tensor(tn(LLM_TENSOR_FFN_GATE, "bias", i), {n_ff}, llama_model_loader::TENSOR_NOT_REQUIRED); + layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED); + layer.ffn_up_b = create_tensor(tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}, llama_model_loader::TENSOR_NOT_REQUIRED); + } + } break; case LLM_ARCH_XVERSE: { model.tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); @@ -9546,7 +9618,12 @@ static int llama_model_load(const std::string & fname, llama_model & model, llam if (model.vocab.type != LLAMA_VOCAB_TYPE_NONE && model.hparams.n_vocab != model.vocab.id_to_token.size()) { - throw std::runtime_error("vocab size mismatch"); + std::stringstream ss; + ss << "vocab size mismatch. " + << model.hparams.n_vocab + << " != " + << model.vocab.id_to_token.size(); + throw std::runtime_error(ss.str()); } if (params.vocab_only) { @@ -20407,6 +20484,7 @@ enum llama_rope_type llama_rope_type(const struct llama_model * model) { case LLM_ARCH_GRANITE: case LLM_ARCH_GRANITE_MOE: case LLM_ARCH_CHAMELEON: + case LLM_ARCH_BAMBA: return LLAMA_ROPE_TYPE_NORM; // the pairs of head values are offset by n_rot/2 From dfe8d3ddb887316bd979d6c68c4f1d6f24140064 Mon Sep 17 00:00:00 2001 From: Gabe Goodhart Date: Thu, 5 Dec 2024 10:59:21 -0700 Subject: [PATCH 27/37] fix(bamba conv): Remove chunk size and consolidate head count w/ time step rank head count and time step rank are used for the same purpose in the model, so we stick with the existing key. Chunk size is not used in this impl because of the way the mixer is implemented without chunking. Branch: BambaArchitecture Signed-off-by: Gabe Goodhart --- convert_hf_to_gguf.py | 9 +++++---- gguf-py/gguf/constants.py | 2 -- gguf-py/gguf/gguf_writer.py | 6 ------ 3 files changed, 5 insertions(+), 12 deletions(-) diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index 916686838bb88..0bf9b0fbc4b6f 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -3154,11 +3154,11 @@ def set_gguf_parameters(self): self.gguf_writer.add_ssm_conv_kernel(self.find_hparam(["conv_kernel", "d_conv"])) self.gguf_writer.add_ssm_state_size(self.find_hparam(["state_size", "d_state"])) self.gguf_writer.add_ssm_group_count(self.n_group) - self.gguf_writer.add_ssm_time_step_rank(self.find_hparam(["time_step_rank", "dt_rank"])) self.gguf_writer.add_ssm_inner_size(self.d_inner) - self.gguf_writer.add_ssm_head_count(self.find_hparam(["n_heads"])) self.gguf_writer.add_ssm_head_dim(d_head := self.find_hparam(["d_head"])) - self.gguf_writer.add_ssm_chunk_size(self.find_hparam(["chunk_size"])) + # NOTE: The mamba_dt_rank is _not_ the right field for how this is used + # in llama.cpp + self.gguf_writer.add_ssm_time_step_rank(self.find_hparam(["n_heads"])) ## Attention params ## self.gguf_writer.add_attn_layer_indices(self._attn_layers) @@ -3175,11 +3175,12 @@ def set_gguf_parameters(self): assert self.hparams.get("hidden_act") in [None, "silu"], "Only SILU activation supported" assert self.d_inner % d_head == 0, f"SSM inner size {self.d_inner} not a multiple of head dim {d_head}" - ## UNUSED?? ## + ## UNUSED ## # "tie_word_embeddings" <-- Implied by presence of output weights # "num_logits_to_keep" <-- Always only keep final token logits # "use_cache" <-- KV Cache always enabled # "use_mamba_kernels" <-- I think this will always be true if available? + # "chunk_size" <-- This is used in the mixer implementation in transformers, but not here def modify_tensors( self, data_torch: Tensor, name: str, bid: int | None diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index 1db6e5a4dfa7f..166694a1f0bb0 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -151,9 +151,7 @@ class SSM: TIME_STEP_RANK = "{arch}.ssm.time_step_rank" GROUP_COUNT = "{arch}.ssm.group_count" DT_B_C_RMS = "{arch}.ssm.dt_b_c_rms" - HEAD_COUNT = "{arch}.ssm.head_count" HEAD_DIM = "{arch}.ssm.head_dim" - CHUNK_SIZE = "{arch}.ssm.chunk_size" class HybridMamba: ATTN_LAYER_INDICES = "{arch}.attention.layer_indices" diff --git a/gguf-py/gguf/gguf_writer.py b/gguf-py/gguf/gguf_writer.py index 6e9c61d9dff77..d2cd1d531b4f5 100644 --- a/gguf-py/gguf/gguf_writer.py +++ b/gguf-py/gguf/gguf_writer.py @@ -790,15 +790,9 @@ def add_ssm_group_count(self, value: int) -> None: def add_ssm_dt_b_c_rms(self, value: bool) -> None: self.add_bool(Keys.SSM.DT_B_C_RMS.format(arch=self.arch), value) - def add_ssm_head_count(self, value: int) -> None: - self.add_uint32(Keys.SSM.HEAD_COUNT.format(arch=self.arch), value) - def add_ssm_head_dim(self, value: int) -> None: self.add_uint32(Keys.SSM.HEAD_DIM.format(arch=self.arch), value) - def add_ssm_chunk_size(self, value: int) -> None: - self.add_uint32(Keys.SSM.CHUNK_SIZE.format(arch=self.arch), value) - def add_attn_layer_indices(self, values: list[int]) -> None: self.add_array(Keys.HybridMamba.ATTN_LAYER_INDICES.format(arch=self.arch), values) From 41fc019057f0e8986e3336b64e2a298fc9b56285 Mon Sep 17 00:00:00 2001 From: Gabe Goodhart Date: Thu, 5 Dec 2024 11:01:02 -0700 Subject: [PATCH 28/37] fix(bamba): Remove ssm_head_count and ssm_chunk_size in llama.cpp Not necessary despite their presence in the model config. Branch: BambaArchitecture Signed-off-by: Gabe Goodhart --- src/llama.cpp | 12 ------------ 1 file changed, 12 deletions(-) diff --git a/src/llama.cpp b/src/llama.cpp index 7a91fc3bbe647..0e568779b9120 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -333,9 +333,7 @@ enum llm_kv { LLM_KV_SSM_TIME_STEP_RANK, LLM_KV_SSM_GROUP_COUNT, LLM_KV_SSM_DT_B_C_RMS, - LLM_KV_SSM_HEAD_COUNT, LLM_KV_SSM_HEAD_DIM, - LLM_KV_SSM_CHUNK_SIZE, LLM_KV_WKV_HEAD_SIZE, @@ -454,9 +452,7 @@ static const std::map LLM_KV_NAMES = { { LLM_KV_SSM_TIME_STEP_RANK, "%s.ssm.time_step_rank" }, { LLM_KV_SSM_GROUP_COUNT, "%s.ssm.group_count" }, { LLM_KV_SSM_DT_B_C_RMS, "%s.ssm.dt_b_c_rms" }, - { LLM_KV_SSM_HEAD_COUNT, "%s.ssm.head_count" }, { LLM_KV_SSM_HEAD_DIM, "%s.ssm.head_dim" }, - { LLM_KV_SSM_CHUNK_SIZE, "%s.ssm.chunk_size" }, { LLM_KV_WKV_HEAD_SIZE, "%s.wkv.head_size" }, @@ -2486,9 +2482,7 @@ struct llama_hparams { uint32_t ssm_dt_rank = 0; uint32_t ssm_n_group = 0; bool ssm_dt_b_c_rms = false; - uint32_t ssm_head_count = 0; uint32_t ssm_head_dim = 0; - uint32_t ssm_chunk_size = 0; // for hybrid state space models std::array ssm_layer_arr; @@ -2548,9 +2542,7 @@ struct llama_hparams { if (this->ssm_dt_rank != other.ssm_dt_rank) return true; if (this->ssm_n_group != other.ssm_n_group) return true; if (this->ssm_dt_b_c_rms != other.ssm_dt_b_c_rms) return true; - if (this->ssm_head_count != other.ssm_head_count) return true; if (this->ssm_head_dim != other.ssm_head_dim) return true; - if (this->ssm_chunk_size != other.ssm_chunk_size) return true; if (this->ssm_layer_arr != other.ssm_layer_arr) return true; @@ -5994,9 +5986,7 @@ static void llm_load_hparams( ml.get_key(LLM_KV_SSM_STATE_SIZE, hparams.ssm_d_state); ml.get_key(LLM_KV_SSM_TIME_STEP_RANK, hparams.ssm_dt_rank); ml.get_key(LLM_KV_SSM_GROUP_COUNT, hparams.ssm_n_group); - ml.get_key(LLM_KV_SSM_HEAD_COUNT, hparams.ssm_head_count); ml.get_key(LLM_KV_SSM_HEAD_DIM, hparams.ssm_head_dim); - ml.get_key(LLM_KV_SSM_CHUNK_SIZE, hparams.ssm_chunk_size); // Attention params std::fill(hparams.ssm_layer_arr.begin(), hparams.ssm_layer_arr.end(), true); @@ -7088,9 +7078,7 @@ static void llm_load_print_meta(llama_model_loader & ml, llama_model & model) { LLAMA_LOG_INFO("%s: ssm_dt_rank = %u\n", __func__, hparams.ssm_dt_rank); LLAMA_LOG_INFO("%s: ssm_n_group = %u\n", __func__, hparams.ssm_n_group); LLAMA_LOG_INFO("%s: ssm_dt_b_c_rms = %d\n", __func__, hparams.ssm_dt_b_c_rms); - LLAMA_LOG_INFO("%s: ssm_head_count = %d\n", __func__, hparams.ssm_head_count); LLAMA_LOG_INFO("%s: ssm_head_dim = %d\n", __func__, hparams.ssm_head_dim); - LLAMA_LOG_INFO("%s: ssm_chunk_size = %d\n", __func__, hparams.ssm_chunk_size); } LLAMA_LOG_INFO("%s: model type = %s\n", __func__, llama_model_type_name(model.type)); From e7b1abbc0a0e6aaf34ecbc3545cbaabd8f7e3592 Mon Sep 17 00:00:00 2001 From: Gabe Goodhart Date: Thu, 5 Dec 2024 11:04:54 -0700 Subject: [PATCH 29/37] feat(bamba): Partially complete work on constructing the forward graph There are still problems at inference around matrix dimensions not lining up, so there are likely still places where the per-layer sizes are not being used correctly. Branch: BambaArchitecture Signed-off-by: Gabe Goodhart --- src/llama.cpp | 181 +++++++++++++++++++++++++++++++++++++++++++++----- 1 file changed, 164 insertions(+), 17 deletions(-) diff --git a/src/llama.cpp b/src/llama.cpp index 0e568779b9120..ade7e52f3eab9 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -5988,6 +5988,16 @@ static void llm_load_hparams( ml.get_key(LLM_KV_SSM_GROUP_COUNT, hparams.ssm_n_group); ml.get_key(LLM_KV_SSM_HEAD_DIM, hparams.ssm_head_dim); + // Zero-out n_head_arr and n_head_kv_arr since SSM layers don't + // have attention heads. We'll set them correctly below once we + // know which layers are attention layers + // NOTE: It's important that this happens after n_embd_head_[kv] + // are set above! + const auto n_head_attn = hparams.n_head(); + const auto n_head_kv_attn = hparams.n_head_kv(); + std::fill(hparams.n_head_arr.begin(), hparams.n_head_arr.end(), 0); + std::fill(hparams.n_head_kv_arr.begin(), hparams.n_head_kv_arr.end(), 0); + // Attention params std::fill(hparams.ssm_layer_arr.begin(), hparams.ssm_layer_arr.end(), true); std::vector attn_layer_indices; @@ -5995,6 +6005,9 @@ static void llm_load_hparams( for (const auto attn_idx : attn_layer_indices) { GGML_ASSERT(attn_idx < hparams.n_layer); hparams.ssm_layer_arr[attn_idx] = false; + // Correctly set n_head and n_head_kv for attention layers + hparams.n_head_arr[attn_idx] = n_head_attn; + hparams.n_head_kv_arr[attn_idx] = n_head_kv_attn; } // Feed forward params @@ -8726,15 +8739,13 @@ static bool llm_load_tensors( case LLM_ARCH_BAMBA: { // mamba2 Mixer SSM params - // TODO: Why are these int64_t and not uint32_t? + // NOTE: int64_t for tensor dimensions const int64_t d_conv = hparams.ssm_d_conv; const int64_t d_inner = hparams.ssm_d_inner; const int64_t d_state = hparams.ssm_d_state; + const int64_t n_ssm_head = hparams.ssm_dt_rank; const int64_t n_group = hparams.ssm_n_group; - const int64_t head_count = hparams.ssm_head_count; - const int64_t head_dim = hparams.ssm_head_dim; - const int64_t chunk_size = hparams.ssm_chunk_size; - const int64_t d_in_proj = 2*d_inner + 2*n_group*d_state + head_count; + const int64_t d_in_proj = 2*d_inner + 2*n_group*d_state + n_ssm_head; // only an expansion factor of 2 is supported for now GGML_ASSERT(2 * n_embd == d_inner); @@ -8766,11 +8777,11 @@ static bool llm_load_tensors( layer.ssm_conv1d = create_tensor(tn(LLM_TENSOR_SSM_CONV1D, "weight", i), {d_conv, d_inner + 2*n_group*d_state}, 0); layer.ssm_conv1d_b = create_tensor(tn(LLM_TENSOR_SSM_CONV1D, "bias", i), {d_inner + 2*n_group*d_state}, llama_model_loader::TENSOR_NOT_REQUIRED); - layer.ssm_dt_b = create_tensor(tn(LLM_TENSOR_SSM_DT, "bias", i), {head_count}, 0); + layer.ssm_dt_b = create_tensor(tn(LLM_TENSOR_SSM_DT, "bias", i), {n_ssm_head}, 0); // no "weight" suffix for these - layer.ssm_a = create_tensor(tn(LLM_TENSOR_SSM_A, i), {1, head_count}, 0); - layer.ssm_d = create_tensor(tn(LLM_TENSOR_SSM_D, i), {1, head_count}, 0); + layer.ssm_a = create_tensor(tn(LLM_TENSOR_SSM_A, i), {1, n_ssm_head}, 0); + layer.ssm_d = create_tensor(tn(LLM_TENSOR_SSM_D, i), {1, n_ssm_head}, 0); layer.ssm_norm = create_tensor(tn(LLM_TENSOR_SSM_NORM, "weight", i), {d_inner / n_group, n_group}, 0); @@ -8778,14 +8789,17 @@ static bool llm_load_tensors( layer.ssm_out = create_tensor(tn(LLM_TENSOR_SSM_OUT, "weight", i), {d_inner, n_embd}, 0); } else { // attention layers (with optional bias) - layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}, 0); - layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa}, 0); - layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa}, 0); - layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0); - layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED); - layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd_gqa}, llama_model_loader::TENSOR_NOT_REQUIRED); - layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_gqa}, llama_model_loader::TENSOR_NOT_REQUIRED); - layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED); + const int64_t n_head_i = hparams.n_head(i); + const int64_t n_embd_k_gqa_i = hparams.n_embd_k_gqa(i); + const int64_t n_embd_v_gqa_i = hparams.n_embd_v_gqa(i); + layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head_i}, 0); + layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa_i}, 0); + layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa_i}, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head_i, n_embd}, 0); + layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED); + layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd_k_gqa_i}, llama_model_loader::TENSOR_NOT_REQUIRED); + layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_v_gqa_i}, llama_model_loader::TENSOR_NOT_REQUIRED); + layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED); } @@ -10408,7 +10422,7 @@ static struct ggml_tensor * llm_build_mamba2( const int64_t d_inner = hparams.ssm_d_inner; const int64_t d_state = hparams.ssm_d_state; const int64_t n_head = hparams.ssm_dt_rank; - const int64_t head_dim = d_inner / n_head; + const int64_t head_dim = hparams.ssm_head_dim == 0 ? d_inner / n_head : hparams.ssm_head_dim; const int64_t n_group = hparams.ssm_n_group; const int64_t n_seqs = batch.n_seqs; @@ -14633,6 +14647,134 @@ struct llm_build_context { return gf; } + struct ggml_cgraph * build_bamba() { + struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false); + + struct ggml_tensor * cur; + struct ggml_tensor * inpL; + + // {n_embd, n_tokens} + inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb); + + struct ggml_tensor * state_copy = build_inp_s_copy(); + + const int64_t n_embd_head = hparams.n_embd_head_v; + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + + // inp_pos - contains the positions + struct ggml_tensor * inp_pos = build_inp_pos(); + + // KQ_mask (mask for 1 head, it will be broadcasted to all heads) + struct ggml_tensor * KQ_mask = build_inp_KQ_mask(); + + const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f/sqrtf(float(n_embd_head)) : hparams.f_attention_scale; + for (int il = 0; il < n_layer; ++il) { + struct ggml_tensor * inpSA = inpL; + + // norm + cur = llm_build_norm(ctx0, inpL, hparams, + model.layers[il].attn_norm, NULL, + LLM_NORM_RMS, cb, il); + cb(cur, "attn_norm", il); + + if (hparams.ssm_layer(il)) { + // ssm layer + cur = llm_build_mamba2(ctx0, lctx, ubatch, gf, cur, state_copy, + rs_zero, kv_head, n_kv, cb, il); + } else { + // attention layer // + + // rope freq factors + struct ggml_tensor * rope_factors = build_rope_factors(il); + + // compute Q and K and RoPE them + struct ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur); + cb(Qcur, "Qcur", il); + if (model.layers[il].bq) { + Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); + cb(Qcur, "Qcur", il); + } + + struct ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur); + cb(Kcur, "Kcur", il); + if (model.layers[il].bk) { + Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); + cb(Kcur, "Kcur", il); + } + + struct ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur); + cb(Vcur, "Vcur", il); + if (model.layers[il].bv) { + Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); + cb(Vcur, "Vcur", il); + } + + Qcur = ggml_rope_ext( + ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, rope_factors, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + cb(Qcur, "Qcur", il); + + Kcur = ggml_rope_ext( + ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, rope_factors, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + cb(Kcur, "Kcur", il); + + cur = llm_build_kv(ctx0, lctx, kv_self, gf, + model.layers[il].wo, model.layers[il].bo, + Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, kq_scale, cb, il); + } + + if (il == n_layer - 1) { + // skip computing output for unused tokens + struct ggml_tensor * inp_out_ids = build_inp_out_ids(); + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); + } + + struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); + cb(ffn_inp, "ffn_inp", il); + + // feed forward + cur = llm_build_norm(ctx0, ffn_inp, hparams, + model.layers[il].ffn_norm, NULL, + LLM_NORM_RMS, cb, il); + cb(cur, "ffn_norm", il); + cur = llm_build_ffn(ctx0, lctx, cur, + model.layers[il].ffn_up, model.layers[il].ffn_up_b, NULL, + model.layers[il].ffn_gate, model.layers[il].ffn_gate_b, NULL, + model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL, + NULL, + LLM_FFN_SILU, LLM_FFN_PAR, cb, il); + cb(cur, "ffn_out", il); + + // residual + cur = ggml_add(ctx0, cur, ffn_inp); + cur = lctx.cvec.apply_to(ctx0, cur, il); + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; + } + + // final rmsnorm + cur = llm_build_norm(ctx0, inpL, hparams, + model.output_norm, NULL, + LLM_NORM_RMS, cb, -1); + cb(cur, "result_norm", -1); + + // lm_head + cur = llm_build_lora_mm(lctx, ctx0, model.output, cur); + cb(cur, "result_output", -1); + + ggml_build_forward_expand(gf, cur); + + return gf; + } + struct ggml_cgraph * build_command_r() { struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false); @@ -17215,6 +17357,10 @@ static struct ggml_cgraph * llama_build_graph( { result = llm.build_mamba(/* version */ 2); } break; + case LLM_ARCH_BAMBA: + { + result = llm.build_bamba(); + } break; case LLM_ARCH_XVERSE: { result = llm.build_xverse(); @@ -20601,6 +20747,7 @@ bool llama_model_is_recurrent(const struct llama_model * model) { switch (model->arch) { case LLM_ARCH_MAMBA: case LLM_ARCH_MAMBA2: + case LLM_ARCH_BAMBA: case LLM_ARCH_RWKV6: return true; default: From f2478bcab58d08fd528d9b579e57fd85570ee1cf Mon Sep 17 00:00:00 2001 From: Gabe Goodhart Date: Mon, 9 Dec 2024 13:43:27 -0700 Subject: [PATCH 30/37] fix: Get n_head_kv per-layer in build_bamba Branch: BambaArchitecture Signed-off-by: Gabe Goodhart --- src/llama.cpp | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/llama.cpp b/src/llama.cpp index ade7e52f3eab9..0d97e54c37ec6 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -14710,14 +14710,15 @@ struct llm_build_context { } Qcur = ggml_rope_ext( - ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, rope_factors, + ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, hparams.n_head(il), n_tokens), inp_pos, rope_factors, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow ); cb(Qcur, "Qcur", il); + LLAMA_LOG_DEBUG("%s[%d]: 9. ggml_rope_ext\n", __func__, il); Kcur = ggml_rope_ext( - ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, rope_factors, + ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, hparams.n_head_kv(il), n_tokens), inp_pos, rope_factors, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow ); From d3a34e0282579ba08d773bb7760f4a6a2060cb8a Mon Sep 17 00:00:00 2001 From: Gabe Goodhart Date: Mon, 9 Dec 2024 15:51:32 -0700 Subject: [PATCH 31/37] fix: per-layer recurrent embd_[kv]_s For hybrid models, this value should be 0 for the non-recurrent layers Branch: BambaArchitecture Signed-off-by: Gabe Goodhart --- src/llama.cpp | 49 ++++++++++++++++++++++++++++--------------------- 1 file changed, 28 insertions(+), 21 deletions(-) diff --git a/src/llama.cpp b/src/llama.cpp index 0d97e54c37ec6..80f76728252a3 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -2485,7 +2485,7 @@ struct llama_hparams { uint32_t ssm_head_dim = 0; // for hybrid state space models - std::array ssm_layer_arr; + std::array recurrent_layer_arr; float f_clamp_kqv = 0.0f; float f_max_alibi_bias = 0.0f; @@ -2544,7 +2544,7 @@ struct llama_hparams { if (this->ssm_dt_b_c_rms != other.ssm_dt_b_c_rms) return true; if (this->ssm_head_dim != other.ssm_head_dim) return true; - if (this->ssm_layer_arr != other.ssm_layer_arr) return true; + if (this->recurrent_layer_arr != other.recurrent_layer_arr) return true; if (this->rescale_every_n_layers != other.rescale_every_n_layers) return true; if (this->time_mix_extra_dim != other.time_mix_extra_dim) return true; @@ -2616,30 +2616,34 @@ struct llama_hparams { return n_embd_head_v * n_head_kv; } - uint32_t n_embd_k_s() const { // dimension of the rolling state embeddings + uint32_t n_embd_k_s(uint32_t il = 0) const { // dimension of the rolling state embeddings // corresponds to Mamba's conv_states size or RWKV's token_shift states size + if (!recurrent_layer(il)) { + return 0; + } if (wkv_head_size != 0) { // for RWKV models return 2 * n_embd; - } else { - // TODO: maybe support other convolution strides than 1 - // NOTE: since the first column of the conv_state is shifted out each time, it's not actually needed - return (ssm_d_conv > 0 ? ssm_d_conv - 1 : 0) * (ssm_d_inner + 2*ssm_n_group*ssm_d_state); } + // TODO: maybe support other convolution strides than 1 + // NOTE: since the first column of the conv_state is shifted out each time, it's not actually needed + return (ssm_d_conv > 0 ? ssm_d_conv - 1 : 0) * (ssm_d_inner + 2*ssm_n_group*ssm_d_state); } - uint32_t n_embd_v_s() const { // dimension of the recurrent state embeddings + uint32_t n_embd_v_s(uint32_t il = 0) const { // dimension of the recurrent state embeddings + if (!recurrent_layer(il)) { + return 0; + } if (wkv_head_size != 0) { // corresponds to RWKV's wkv_states size return n_embd * wkv_head_size; - } else { - // corresponds to Mamba's ssm_states size - return ssm_d_state * ssm_d_inner; } + // corresponds to Mamba's ssm_states size + return ssm_d_state * ssm_d_inner; } - bool ssm_layer(uint32_t il) const { - return ssm_layer_arr[il]; + bool recurrent_layer(uint32_t il) const { + return recurrent_layer_arr[il]; } }; @@ -3555,8 +3559,8 @@ static bool llama_kv_cache_init( cache.v_l.reserve(n_layer); for (int i = 0; i < (int) n_layer; i++) { - const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(i) + hparams.n_embd_k_s(); - const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(i) + hparams.n_embd_v_s(); + const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(i) + hparams.n_embd_k_s(i); + const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(i) + hparams.n_embd_v_s(i); ggml_backend_buffer_type_t buft; if (offload) { @@ -5509,7 +5513,10 @@ static void llm_load_hparams( std::fill(hparams.n_head_arr.begin(), hparams.n_head_arr.end(), 0); std::fill(hparams.n_head_kv_arr.begin(), hparams.n_head_kv_arr.end(), 0); std::fill(hparams.n_ff_arr.begin(), hparams.n_ff_arr.end(), 0); - std::fill(hparams.ssm_layer_arr.begin(), hparams.ssm_layer_arr.end(), false); + std::fill( + hparams.recurrent_layer_arr.begin(), + hparams.recurrent_layer_arr.end(), + llama_model_is_recurrent(&model)); ml.get_key_or_arr(LLM_KV_FEED_FORWARD_LENGTH, hparams.n_ff_arr, hparams.n_layer); ml.get_key_or_arr(LLM_KV_ATTENTION_HEAD_COUNT, hparams.n_head_arr, hparams.n_layer); @@ -5999,12 +6006,12 @@ static void llm_load_hparams( std::fill(hparams.n_head_kv_arr.begin(), hparams.n_head_kv_arr.end(), 0); // Attention params - std::fill(hparams.ssm_layer_arr.begin(), hparams.ssm_layer_arr.end(), true); + std::fill(hparams.recurrent_layer_arr.begin(), hparams.recurrent_layer_arr.end(), true); std::vector attn_layer_indices; ml.get_arr(LLM_KV_ATTENTION_LAYER_INDICES, attn_layer_indices); for (const auto attn_idx : attn_layer_indices) { GGML_ASSERT(attn_idx < hparams.n_layer); - hparams.ssm_layer_arr[attn_idx] = false; + hparams.recurrent_layer_arr[attn_idx] = false; // Correctly set n_head and n_head_kv for attention layers hparams.n_head_arr[attn_idx] = n_head_attn; hparams.n_head_kv_arr[attn_idx] = n_head_kv_attn; @@ -7162,7 +7169,7 @@ static void llm_load_print_meta(llama_model_loader & ml, llama_model & model) { } if (model.arch == LLM_ARCH_BAMBA) { - LLAMA_LOG_INFO("%s: ssm_layer_arr = %s\n", __func__, print_f([&](uint32_t il) { return uint32_t(hparams.ssm_layer(il)); }, hparams.n_layer).c_str()); + LLAMA_LOG_INFO("%s: recurrent_layer_arr = %s\n", __func__, print_f([&](uint32_t il) { return uint32_t(hparams.recurrent_layer(il)); }, hparams.n_layer).c_str()); } } @@ -8769,7 +8776,7 @@ static bool llm_load_tensors( // norm layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - if (hparams.ssm_layer(i)) { + if (hparams.recurrent_layer(i)) { // ssm layers layer.ssm_in = create_tensor(tn(LLM_TENSOR_SSM_IN, "weight", i), {n_embd, d_in_proj}, 0); layer.ssm_in_b = create_tensor(tn(LLM_TENSOR_SSM_IN, "bias", i), {n_embd, d_in_proj}, llama_model_loader::TENSOR_NOT_REQUIRED); @@ -14677,7 +14684,7 @@ struct llm_build_context { LLM_NORM_RMS, cb, il); cb(cur, "attn_norm", il); - if (hparams.ssm_layer(il)) { + if (hparams.recurrent_layer(il)) { // ssm layer cur = llm_build_mamba2(ctx0, lctx, ubatch, gf, cur, state_copy, rs_zero, kv_head, n_kv, cb, il); From 92653d05fdf8039439810380f323f32d7425ba96 Mon Sep 17 00:00:00 2001 From: Gabe Goodhart Date: Mon, 9 Dec 2024 13:42:54 -0700 Subject: [PATCH 32/37] WIP: Partial work towards separate hybrid cache This also seems like not _quite_ the right direction Branch: BambaArchitecture Signed-off-by: Gabe Goodhart --- include/llama.h | 2 ++ src/llama.cpp | 83 ++++++++++++++++++++++++++++++++++++++++--------- 2 files changed, 70 insertions(+), 15 deletions(-) diff --git a/include/llama.h b/include/llama.h index 90791d5f5ea12..3f9e72f3dfd1f 100644 --- a/include/llama.h +++ b/include/llama.h @@ -489,6 +489,8 @@ extern "C" { // Returns true if the model is recurrent (like Mamba, RWKV, etc.) LLAMA_API bool llama_model_is_recurrent(const struct llama_model * model); + // Returns true if the model is hybrid (like Bamba, etc.) + LLAMA_API bool llama_model_is_hybrid(const struct llama_model * model); // Returns 0 on success LLAMA_API uint32_t llama_model_quantize( diff --git a/src/llama.cpp b/src/llama.cpp index 80f76728252a3..2d0fe1e4d76d5 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -3348,6 +3348,10 @@ struct llama_context { struct llama_kv_cache kv_self; struct llama_control_vector cvec; + // Hybrid attention/ssm models use kv cache differently for attention/ssm + // layers with different kv_size values + struct llama_kv_cache kv_hybrid; + std::unordered_map lora_adapters; std::vector backends; @@ -3511,7 +3515,8 @@ static bool llama_kv_cache_init( ggml_type type_k, ggml_type type_v, uint32_t kv_size, - bool offload) { + bool offload, + bool recurrent) { const llama_model & model = ctx->model; const llama_cparams & cparams = ctx->cparams; @@ -3521,7 +3526,7 @@ static bool llama_kv_cache_init( cache.has_shift = false; - cache.recurrent = llama_model_is_recurrent(&model); + cache.recurrent = recurrent; cache.v_trans = !cache.recurrent && !cparams.flash_attn; cache.head = 0; @@ -9749,7 +9754,7 @@ static void llm_build_kv_store( } else { // note: the V cache is transposed when not using flash attention v_cache_view = ggml_view_2d(ctx, kv.v_l[il], n_tokens, n_embd_v_gqa, - ( n_ctx)*ggml_element_size(kv.v_l[il]), + (kv.size)*ggml_element_size(kv.v_l[il]), (kv_head)*ggml_element_size(kv.v_l[il])); v_cur = ggml_transpose(ctx, v_cur); @@ -10421,10 +10426,11 @@ static struct ggml_tensor * llm_build_mamba2( int32_t kv_head, int32_t n_kv, const llm_build_cb & cb, - int il) { + int il, + bool hybrid = false) { const llama_model & model = lctx.model; const llama_hparams & hparams = model.hparams; - const llama_kv_cache & kv = lctx.kv_self; + const llama_kv_cache & kv = hybrid ? lctx.kv_hybrid : lctx.kv_self; const int64_t d_conv = hparams.ssm_d_conv; const int64_t d_inner = hparams.ssm_d_inner; const int64_t d_state = hparams.ssm_d_state; @@ -10712,6 +10718,7 @@ struct llm_build_context { const llama_cparams & cparams; const llama_ubatch & ubatch; const llama_kv_cache & kv_self; + const llama_kv_cache & kv_hybrid; const int64_t n_embd; const int64_t n_layer; @@ -10736,11 +10743,14 @@ struct llm_build_context { const float norm_rms_eps; const int32_t n_tokens; - const int32_t n_kv; // size of KV cache to consider (n_kv <= kv_self.size) + const int32_t n_kv; // size of KV cache to consider (n_kv <= kv_self.size) + const int32_t n_kv_hybrid; // size of KV cache to consider (n_kv_hybrid <= kv_hybrid.size) const int32_t n_outputs; const int32_t n_outputs_enc; - const int32_t kv_head; // index of where we store new KV data in the cache - const int32_t rs_zero; // the first zero-ed recurrent state + const int32_t kv_head; // index of where we store new KV data in the cache + const int32_t kv_head_hybrid; // index of where we store new KV data in the hybrid cache + const int32_t rs_zero; // the first zero-ed recurrent state + const int32_t rs_zero_hybrid; // the first zero-ed recurrent state const int32_t n_ctx_orig; const bool flash_attn; @@ -10766,6 +10776,7 @@ struct llm_build_context { cparams (lctx.cparams), ubatch (ubatch), kv_self (lctx.kv_self), + kv_hybrid (lctx.kv_hybrid), n_embd (hparams.n_embd), n_layer (hparams.n_layer), n_rot (hparams.n_rot), @@ -10788,10 +10799,13 @@ struct llm_build_context { norm_rms_eps (hparams.f_norm_rms_eps), n_tokens (ubatch.n_tokens), n_kv (worst_case ? kv_self.size : kv_self.n), + n_kv_hybrid (worst_case ? kv_hybrid.size : kv_self.n), n_outputs (worst_case ? n_tokens : lctx.n_outputs), n_outputs_enc (worst_case ? n_tokens : lctx.embd_enc.size() / hparams.n_embd), kv_head (worst_case ? (kv_self.recurrent ? 0 : kv_self.size - n_tokens) : kv_self.head), + kv_head_hybrid (worst_case ? (kv_hybrid.recurrent ? 0 : kv_hybrid.size - n_tokens) : kv_hybrid.head), rs_zero (kv_self.rs_z), + rs_zero_hybrid (kv_hybrid.rs_z), n_ctx_orig (cparams.n_ctx_orig_yarn), flash_attn (cparams.flash_attn), pooling_type (cparams.pooling_type), @@ -14687,7 +14701,8 @@ struct llm_build_context { if (hparams.recurrent_layer(il)) { // ssm layer cur = llm_build_mamba2(ctx0, lctx, ubatch, gf, cur, state_copy, - rs_zero, kv_head, n_kv, cb, il); + rs_zero_hybrid, kv_head_hybrid, n_kv_hybrid, cb, il, true); + cb(cur, "mamba_out", il); } else { // attention layer // @@ -20325,14 +20340,23 @@ struct llama_context * llama_new_context_with_model( ctx->is_encoding = llama_model_has_encoder(model); uint32_t kv_size = cparams.n_ctx; + uint32_t kv_size_hybrid = 0; ggml_type type_k = params.type_k; ggml_type type_v = params.type_v; + const bool recurrent = llama_model_is_recurrent(model); + const bool hybrid = llama_model_is_hybrid(model); // Mamba only needs a constant number of KV cache cells per sequence - if (llama_model_is_recurrent(model)) { + if (recurrent) { // Mamba needs at least as many KV cells as there are sequences kept at any time - kv_size = std::max((uint32_t) 1, params.n_seq_max); + // NOTE: Hybrid models will use the hybrid cache for the SSM layers + if (hybrid) { + kv_size_hybrid = std::max((uint32_t) 1, params.n_seq_max); + } else { + kv_size = std::max((uint32_t) 1, params.n_seq_max); + } // it's probably best to keep as much precision as possible for the states + // TODO: should types be different for the two caches? type_k = GGML_TYPE_F32; // required by ggml_ssm_conv for Mamba's conv_states type_v = GGML_TYPE_F32; // required by ggml_ssm_scan for Mamba's ssm_states } @@ -20389,24 +20413,44 @@ struct llama_context * llama_new_context_with_model( llama_set_abort_callback(ctx, params.abort_callback, params.abort_callback_data); - if (!llama_kv_cache_init(ctx->kv_self, ctx, type_k, type_v, kv_size, cparams.offload_kqv)) { + // the self cache is recurrent IFF the model is recurrent, but not hybrid + if (!llama_kv_cache_init(ctx->kv_self, ctx, type_k, type_v, kv_size, cparams.offload_kqv, recurrent && !hybrid)) { LLAMA_LOG_ERROR("%s: llama_kv_cache_init() failed for self-attention cache\n", __func__); llama_free(ctx); return nullptr; } - { + // Log cache memory usage size_t memory_size_k = 0; size_t memory_size_v = 0; - for (auto & k : ctx->kv_self.k_l) { memory_size_k += ggml_nbytes(k); } - for (auto & v : ctx->kv_self.v_l) { memory_size_v += ggml_nbytes(v); } + LLAMA_LOG_INFO("%s: KV self size = %7.2f MiB, K (%s): %7.2f MiB, V (%s): %7.2f MiB\n", __func__, + (float)(memory_size_k + memory_size_v) / (1024.0f * 1024.0f), + ggml_type_name(type_k), (float)memory_size_k / (1024.0f * 1024.0f), + ggml_type_name(type_v), (float)memory_size_v / (1024.0f * 1024.0f)); + } + // For hybrid models, initialize the hybrid kv cache + if (kv_size_hybrid > 0 && !llama_kv_cache_init(ctx->kv_hybrid, ctx, type_k, type_v, kv_size_hybrid, cparams.offload_kqv, true)) { + LLAMA_LOG_ERROR("%s: llama_kv_cache_init() failed for self-attention cache\n", __func__); + llama_free(ctx); + return nullptr; + } + { + // Log hybrid cache memory usage + size_t memory_size_k = 0; + size_t memory_size_v = 0; + for (auto & k : ctx->kv_hybrid.k_l) { + memory_size_k += ggml_nbytes(k); + } + for (auto & v : ctx->kv_hybrid.v_l) { + memory_size_v += ggml_nbytes(v); + } LLAMA_LOG_INFO("%s: KV self size = %7.2f MiB, K (%s): %7.2f MiB, V (%s): %7.2f MiB\n", __func__, (float)(memory_size_k + memory_size_v) / (1024.0f * 1024.0f), ggml_type_name(type_k), (float)memory_size_k / (1024.0f * 1024.0f), @@ -20763,6 +20807,15 @@ bool llama_model_is_recurrent(const struct llama_model * model) { } } +bool llama_model_is_hybrid(const struct llama_model * model) { + switch (model->arch) { + case LLM_ARCH_BAMBA: + return true; + default: + return false; + } +} + uint32_t llama_model_quantize( const char * fname_inp, const char * fname_out, From 44bf431ab42e35e878e6537272d2c88d56a44ec2 Mon Sep 17 00:00:00 2001 From: Gabe Goodhart Date: Tue, 10 Dec 2024 10:48:48 -0700 Subject: [PATCH 33/37] fix: Only allocate kv cache tensors for the appropriate layers in hybrid models Branch: BambaArchitecture Signed-off-by: Gabe Goodhart --- src/llama.cpp | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/src/llama.cpp b/src/llama.cpp index 2d0fe1e4d76d5..e248646140a3e 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -3581,8 +3581,17 @@ static bool llama_kv_cache_init( return false; } - ggml_tensor * k = ggml_new_tensor_1d(ctx, type_k, n_embd_k_gqa*kv_size); - ggml_tensor * v = ggml_new_tensor_1d(ctx, type_v, n_embd_v_gqa*kv_size); + // If this is a hybrid model, there will be two caches, one for + // recurrent layers and one for attention layers. The tensors in the + // cache only need to be fully allocated for the correct layers. + const uint32_t tensor_dim = ( + (cache.recurrent && hparams.recurrent_layer(i)) || + (!cache.recurrent && !hparams.recurrent_layer(i)) + ? kv_size : 0 + ); + + ggml_tensor * k = ggml_new_tensor_1d(ctx, type_k, n_embd_k_gqa*tensor_dim); + ggml_tensor * v = ggml_new_tensor_1d(ctx, type_v, n_embd_v_gqa*tensor_dim); ggml_format_name(k, "cache_k_l%d", i); ggml_format_name(v, "cache_v_l%d", i); cache.k_l.push_back(k); @@ -20451,7 +20460,7 @@ struct llama_context * llama_new_context_with_model( for (auto & v : ctx->kv_hybrid.v_l) { memory_size_v += ggml_nbytes(v); } - LLAMA_LOG_INFO("%s: KV self size = %7.2f MiB, K (%s): %7.2f MiB, V (%s): %7.2f MiB\n", __func__, + LLAMA_LOG_INFO("%s: KV hybrid size = %7.2f MiB, K (%s): %7.2f MiB, V (%s): %7.2f MiB\n", __func__, (float)(memory_size_k + memory_size_v) / (1024.0f * 1024.0f), ggml_type_name(type_k), (float)memory_size_k / (1024.0f * 1024.0f), ggml_type_name(type_v), (float)memory_size_v / (1024.0f * 1024.0f)); From 4543ed56402cb4e3e6f60aa655422e1a7cafd9e1 Mon Sep 17 00:00:00 2001 From: Gabe Goodhart Date: Tue, 10 Dec 2024 11:00:01 -0700 Subject: [PATCH 34/37] feat: Update the logic in llama_decode_internal for kv_hybrid cache Branch: BambaArchitecture Signed-off-by: Gabe Goodhart --- src/llama.cpp | 34 ++++++++++++++++++++++++++++++++-- 1 file changed, 32 insertions(+), 2 deletions(-) diff --git a/src/llama.cpp b/src/llama.cpp index e248646140a3e..c09471aaf696f 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -18104,6 +18104,11 @@ static int llama_decode_internal( auto & kv_self = lctx.kv_self; llama_kv_slot_restorer kv_slot_restorer(kv_self); + // Only used for hybrid-recurrent models (e.g. Bamba) + const bool hybrid = llama_model_is_hybrid(&model); + auto & kv_hybrid = lctx.kv_hybrid; + llama_kv_slot_restorer kv_slot_restorer_hybrid(kv_hybrid); + const int64_t n_embd = hparams.n_embd; const int64_t n_vocab = hparams.n_vocab; @@ -18192,7 +18197,15 @@ static int llama_decode_internal( return 1; } kv_slot_restorer.save(slot); + if (hybrid) { + const auto slot_hybrid = llama_kv_cache_find_slot(kv_hybrid, ubatch); + if (!slot_hybrid) { + return 1; + } + kv_slot_restorer_hybrid.save(slot_hybrid); + } + // TODO: Update this clause for hybrid recurrent models if (!kv_self.recurrent) { // a heuristic, to avoid attending the full cache if it is not yet utilized // after enough generations, the benefit from this heuristic disappears @@ -18241,6 +18254,9 @@ static int llama_decode_internal( const auto compute_status = llama_graph_compute(lctx, gf, n_threads, threadpool); if (compute_status != GGML_STATUS_SUCCESS) { kv_slot_restorer.restore(kv_self); + if (hybrid) { + kv_slot_restorer_hybrid.restore(kv_hybrid); + } switch (compute_status) { case GGML_STATUS_ABORTED: return 2; @@ -18252,7 +18268,7 @@ static int llama_decode_internal( } } - // update the kv ring buffer + // update the kv ring buffer(s) { kv_self.head += n_tokens; @@ -18260,6 +18276,13 @@ static int llama_decode_internal( if (kv_self.head >= kv_self.size) { kv_self.head = 0; } + + if (hybrid) { + kv_hybrid.head += n_tokens; + if (kv_hybrid.head >= kv_hybrid.size) { + kv_hybrid.head = 0; + } + } } // plot the computation graph in dot format (for debugging purposes) @@ -18366,7 +18389,7 @@ static int llama_decode_internal( // wait for the computation to finish (automatically done when obtaining the model output) //llama_synchronize(&lctx); - // decide if we need to defrag the kv cache + // decide if we need to defrag the kv cache(s) if (cparams.causal_attn && cparams.defrag_thold >= 0.0f) { const float fragmentation = kv_self.n >= 128 ? 1.0f - float(kv_self.used)/float(kv_self.n) : 0.0f; @@ -18376,6 +18399,13 @@ static int llama_decode_internal( llama_kv_cache_defrag(kv_self); } + + if (hybrid) { + const float fragmentation = kv_hybrid.n >= 128 ? 1.0f - float(kv_hybrid.used)/float(kv_hybrid.n) : 0.0f; + if (fragmentation > cparams.defrag_thold) { + llama_kv_cache_defrag(kv_hybrid); + } + } } // Reset state for the next token before backend sync, to allow the CPU activities in the reset to From 204e78fba132dc9d74c0f16ec0f0b925649d22fe Mon Sep 17 00:00:00 2001 From: Gabe Goodhart Date: Tue, 10 Dec 2024 15:34:53 -0700 Subject: [PATCH 35/37] fix: A number of places where hybrid needs to be handled Still not fully working, but worth committing these: * per-layer n_embd_[kv]_s (probably a no-op since first layer is ssm) * fix setting n_kv_hybrid when not worst_case * Use the right n_kv for build_inp_s_copy when hybrid * Use the right n_kv for recurrent section of llama_set_inputs * Use the right logic to determine batch splitting for hybrid Branch: BambaArchitecture Signed-off-by: Gabe Goodhart --- src/llama.cpp | 36 ++++++++++++++++++++---------------- 1 file changed, 20 insertions(+), 16 deletions(-) diff --git a/src/llama.cpp b/src/llama.cpp index c09471aaf696f..04a01b253058e 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -10460,11 +10460,11 @@ static struct ggml_tensor * llm_build_mamba2( // (ab)using the KV cache to store the states struct ggml_tensor * conv = llm_build_rs(ctx, graph, conv_states_all, state_copy, rs_zero, - hparams.n_embd_k_s(), kv.size, kv_head, n_kv, n_seqs); + hparams.n_embd_k_s(il), kv.size, kv_head, n_kv, n_seqs); conv = ggml_reshape_3d(ctx, conv, d_conv - 1, d_inner + 2*n_group*d_state, n_seqs); struct ggml_tensor * ssm = llm_build_rs(ctx, graph, ssm_states_all, state_copy, rs_zero, - hparams.n_embd_v_s(), kv.size, kv_head, n_kv, n_seqs, true); + hparams.n_embd_v_s(il), kv.size, kv_head, n_kv, n_seqs, true); ssm = ggml_reshape_4d(ctx, ssm, d_state, head_dim, n_head, kv.size); // {n_embd, n_tokens} => {n_embd, n_seq_tokens, n_seqs} @@ -10808,7 +10808,7 @@ struct llm_build_context { norm_rms_eps (hparams.f_norm_rms_eps), n_tokens (ubatch.n_tokens), n_kv (worst_case ? kv_self.size : kv_self.n), - n_kv_hybrid (worst_case ? kv_hybrid.size : kv_self.n), + n_kv_hybrid (worst_case ? kv_hybrid.size : kv_hybrid.n), n_outputs (worst_case ? n_tokens : lctx.n_outputs), n_outputs_enc (worst_case ? n_tokens : lctx.embd_enc.size() / hparams.n_embd), kv_head (worst_case ? (kv_self.recurrent ? 0 : kv_self.size - n_tokens) : kv_self.head), @@ -11036,8 +11036,8 @@ struct llm_build_context { return lctx.inp_cls; } - struct ggml_tensor * build_inp_s_copy() { - lctx.inp_s_copy = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_kv); + struct ggml_tensor * build_inp_s_copy(bool hybrid = false) { + lctx.inp_s_copy = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, hybrid ? n_kv_hybrid : n_kv); cb(lctx.inp_s_copy, "inp_s_copy", -1); ggml_set_input(lctx.inp_s_copy); return lctx.inp_s_copy; @@ -14686,7 +14686,7 @@ struct llm_build_context { // {n_embd, n_tokens} inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb); - struct ggml_tensor * state_copy = build_inp_s_copy(); + struct ggml_tensor * state_copy = build_inp_s_copy(/* hybrid */true); const int64_t n_embd_head = hparams.n_embd_head_v; GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); @@ -14710,7 +14710,8 @@ struct llm_build_context { if (hparams.recurrent_layer(il)) { // ssm layer cur = llm_build_mamba2(ctx0, lctx, ubatch, gf, cur, state_copy, - rs_zero_hybrid, kv_head_hybrid, n_kv_hybrid, cb, il, true); + rs_zero_hybrid, kv_head_hybrid, n_kv_hybrid, cb, il, + /* hybrid */ true); cb(cur, "mamba_out", il); } else { // attention layer // @@ -17813,8 +17814,11 @@ static void llama_set_inputs(llama_context & lctx, const llama_ubatch & ubatch) } } - if (kv_self.recurrent) { - const int64_t n_kv = kv_self.n; + const bool hybrid = llama_model_is_hybrid(&lctx.model); + auto& kv_hybrid = lctx.kv_hybrid; + if (kv_self.recurrent || (hybrid && kv_hybrid.recurrent)) { + auto& kv_recurrent = hybrid ? kv_hybrid : lctx.kv_self; + const int64_t n_kv = kv_recurrent.n; if (lctx.inp_s_copy) { GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_s_copy->buffer)); @@ -17822,14 +17826,14 @@ static void llama_set_inputs(llama_context & lctx, const llama_ubatch & ubatch) // assuming copy destinations ALWAYS happen ONLY on the cells between head and head+n for (uint32_t i = 0; i < n_kv; ++i) { - const uint32_t cell_id = i + kv_self.head; - llama_kv_cell & kv_cell = lctx.kv_self.cells[cell_id]; + const uint32_t cell_id = i + kv_recurrent.head; + llama_kv_cell & kv_cell = kv_recurrent.cells[cell_id]; if (kv_cell.src < 0) { - GGML_ASSERT(kv_self.rs_z >= 0); // Need a valid zero-ed cell as a source - kv_cell.src = kv_self.rs_z; + GGML_ASSERT(kv_recurrent.rs_z >= 0); // Need a valid zero-ed cell as a source + kv_cell.src = kv_recurrent.rs_z; } - if ((uint32_t) kv_cell.src >= kv_self.size) { + if ((uint32_t) kv_cell.src >= kv_recurrent.size) { // ignore out-of-bound sources kv_cell.src = cell_id; } @@ -18135,7 +18139,7 @@ static int llama_decode_internal( } lctx.sbatch.from_batch(batch, n_embd, - /* simple_split */ !kv_self.recurrent, + /* simple_split */ !(kv_self.recurrent || (hybrid && kv_hybrid.recurrent)), /* logits_all */ n_outputs == n_tokens_all); // reserve output buffer @@ -18146,7 +18150,7 @@ static int llama_decode_internal( while (lctx.sbatch.n_tokens > 0) { llama_ubatch ubatch; - if (kv_self.recurrent) { + if (kv_self.recurrent || (hybrid && kv_hybrid.recurrent)) { if (embd_pooled) { // Pooled embeddings cannot be split across ubatches (yet) ubatch = lctx.sbatch.split_seq(n_ubatch); From 97e6ba8d99c89389b6b769bedea7304e28793a2c Mon Sep 17 00:00:00 2001 From: Gabe Goodhart Date: Thu, 12 Dec 2024 15:02:05 -0700 Subject: [PATCH 36/37] fix: Remove outdated TODO in convrsion script Branch: BambaArchitecture Signed-off-by: Gabe Goodhart --- convert_hf_to_gguf.py | 1 - 1 file changed, 1 deletion(-) diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index 0bf9b0fbc4b6f..ece57527f0dd1 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -3093,7 +3093,6 @@ def reshape_tensors(self, data_torch: Tensor, new_name: str, bid: int | None) -> return data_torch.squeeze() -# TODO: Switch to BambaForCausalLM once ready in transformers @Model.register("BambaForCausalLM") class BambaModel(Mamba2Model): """Bamba is a hybrid SSM + Attention model that uses Mamba2 SSM layers""" From b83e9a6cd2d0ba3e8e2eaf7465e2e36e3c373a5e Mon Sep 17 00:00:00 2001 From: Gabe Goodhart Date: Thu, 12 Dec 2024 15:02:38 -0700 Subject: [PATCH 37/37] fix: Remove unused LLM_KV_ATTENTION_LAYER_COUNT I'd added this at one point, but it's not actually needed Branch: BambaArchitecture Signed-off-by: Gabe Goodhart --- src/llama.cpp | 1 - 1 file changed, 1 deletion(-) diff --git a/src/llama.cpp b/src/llama.cpp index 04a01b253058e..8fd054c0e29c6 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -310,7 +310,6 @@ enum llm_kv { LLM_KV_ATTENTION_RELATIVE_BUCKETS_COUNT, LLM_KV_ATTENTION_SLIDING_WINDOW, LLM_KV_ATTENTION_SCALE, - LLM_KV_ATTENTION_LAYER_COUNT, LLM_KV_ATTENTION_LAYER_INDICES, LLM_KV_ROPE_DIMENSION_COUNT,