Skip to content

llama : support Jamba hybrid Transformer-Mamba models #7531

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 41 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from 14 commits
Commits
Show all changes
41 commits
Select commit Hold shift + click to select a range
271104c
wip: llama : separate recurrent states from the KV cache
compilade Apr 3, 2024
8db1e4d
llama : use std::find for seq_nodes in llama_rs_cache
compilade Apr 4, 2024
0028010
llama : state checkpoints for recurrent models
compilade Apr 8, 2024
0c8b3b2
llama : correctly handle more edge cases for the rs cache
compilade Apr 9, 2024
d66849f
Merge branch 'master' into compilade/refactor-kv-cache
compilade Apr 10, 2024
a09db95
llama : rename many llama_kv_cache_* functions
compilade Apr 29, 2024
c460ff1
Merge branch 'master' into compilade/refactor-kv-cache
compilade Apr 29, 2024
b6fafd1
llama : remove useless return value for some llama_cache_* functions
compilade Apr 29, 2024
b7ec12e
Merge branch 'master' into compilade/refactor-kv-cache
compilade May 12, 2024
3b57b55
Merge branch 'master' into compilade/refactor-kv-cache
compilade May 22, 2024
7e13f19
llama : rethink recurrent state cell counts
compilade May 24, 2024
cbc743e
llama : support Jamba
compilade May 24, 2024
0fd13e9
Merge branch 'master' into compilade/refactor-kv-cache
compilade May 24, 2024
61a88a1
llama : fix BERT inference without KV cache
compilade May 25, 2024
ea2e63e
convert-hf : check for unprocessed Jamba experts
compilade May 25, 2024
fc59407
convert-hf : support Mini-Jamba conversion
compilade May 25, 2024
181dadf
llama : fix Jamba quantization sanity checks
compilade May 28, 2024
3a414b0
llama : sequence-length-aware batch splitting
compilade May 28, 2024
4e4c41e
Merge branch 'master' into compilade/refactor-kv-cache
compilade May 28, 2024
3587a94
llama : use equal-sequence-length sub-batches for recurrent models
compilade Jun 1, 2024
5d3c7b9
Merge branch 'master' into compilade/refactor-kv-cache
compilade Jun 1, 2024
72eea49
llama : fix batch split output count for embeddings
compilade Jun 1, 2024
18d1c14
llama : minimize swaps when reordering logits
compilade Jun 1, 2024
61200ef
llama : fix edge case finding batch seq_id of split recurrent cell
compilade Jun 1, 2024
eb589d5
llama : avoid copies for simple batch splits
compilade Jun 2, 2024
8fb57ac
llama : use im2col and mul_mat to perform convolution for Mamba
compilade Jun 3, 2024
17f6c1e
llama : fix .base() compilation error on Windows
compilade Jun 3, 2024
fee3c1d
llama : allow doing the equivalent of SSM_CONV with SUM_ROWS and MUL
compilade Jun 3, 2024
6840ac0
Merge branch 'master' into compilade/refactor-kv-cache
compilade Jun 8, 2024
372482d
llama : rename llama_cache to llama_past
compilade Jun 8, 2024
43d8d4b
examples : replace llama_kv_cache_seq_* with llama_past_seq_*
compilade Jun 10, 2024
ff794f5
Merge branch 'master' into compilade/refactor-kv-cache
compilade Jun 12, 2024
33425a7
mamba : fix non-contiguous usage of ggml_silu
compilade Jun 12, 2024
10c3c41
Merge branch 'master' into compilade/refactor-kv-cache
compilade Jun 30, 2024
9b38f8b
Merge branch 'master' into compilade/refactor-kv-cache
compilade Jul 4, 2024
bc320ef
Merge branch 'master' into compilade/refactor-kv-cache
compilade Sep 1, 2024
fcb889c
llama : session saving and reloading for hybrid models
compilade Sep 2, 2024
a03e32a
Merge branch 'master' into compilade/refactor-kv-cache
compilade Sep 2, 2024
9d3f44d
convert_hf : fix Jamba conversion
compilade Sep 2, 2024
5f62db7
llama : fix mixed signedness comparison
compilade Sep 2, 2024
375de5b
llama : use unused n_embd_k_gqa in k_shift
compilade Sep 2, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
103 changes: 102 additions & 1 deletion convert-hf-to-gguf.py
Original file line number Diff line number Diff line change
Expand Up @@ -2338,7 +2338,7 @@ def set_gguf_parameters(self):
self.gguf_writer.add_embedding_length(d_model)
self.gguf_writer.add_feed_forward_length(0) # unused, but seemingly required when loading
self.gguf_writer.add_head_count(0) # unused, but seemingly required when loading
self.gguf_writer.add_block_count(self.hparams["n_layer"])
self.gguf_writer.add_block_count(self.block_count)
self.gguf_writer.add_ssm_conv_kernel(d_conv)
self.gguf_writer.add_ssm_inner_size(d_inner)
self.gguf_writer.add_ssm_state_size(d_state)
Expand Down Expand Up @@ -2384,6 +2384,107 @@ def extra_f32_tensors(self, name: str, new_name: str, bid: int | None, n_dims: i
)


@Model.register("JambaForCausalLM")
class JambaModel(Model):
model_arch = gguf.MODEL_ARCH.JAMBA

def get_vocab_base_pre(self, tokenizer) -> str:
del tokenizer # unused

return "gpt-2"

def set_gguf_parameters(self):
d_model = self.find_hparam(["hidden_size", "mamba_d_model"])
d_conv = self.find_hparam(["mamba_d_conv"], optional=True) or 4
d_inner = self.hparams["mamba_expand"] * d_model
d_state = self.find_hparam(["mamba_d_state"], optional=True) or 16
# ceiling division
# ref: https://stackoverflow.com/a/17511341/22827863
# ref: https://github.com/state-spaces/mamba/blob/ce59daea3a090d011d6476c6e5b97f6d58ddad8b/mamba_ssm/modules/mamba_simple.py#L58
dt_rank = self.find_hparam(["mamba_dt_rank"], optional=True) or -(d_model // -16)
rms_norm_eps = self.find_hparam(["layer_norm_epsilon", "rms_norm_eps"], optional=True) or 1e-6
n_kv_head = self.hparams["num_key_value_heads"]
attn_offset = self.hparams["attn_layer_offset"]
attn_period = self.hparams["attn_layer_period"]
n_kv_vec = [0 for _ in range(attn_offset)] + [
n_kv_head if (i - attn_offset) % attn_period == 0 else 0 for i in range(attn_offset, self.block_count)
]

self.gguf_writer.add_name(self.dir_model.name)
self.gguf_writer.add_block_count(self.block_count)
self.gguf_writer.add_context_length(self.hparams["max_position_embeddings"])
self.gguf_writer.add_embedding_length(d_model)
self.gguf_writer.add_feed_forward_length(self.hparams["intermediate_size"])
self.gguf_writer.add_head_count(self.hparams["num_attention_heads"])
self.gguf_writer.add_head_count_kv(n_kv_vec)
self.gguf_writer.add_ssm_conv_kernel(d_conv)
self.gguf_writer.add_ssm_inner_size(d_inner)
self.gguf_writer.add_ssm_state_size(d_state)
self.gguf_writer.add_ssm_time_step_rank(dt_rank)
self.gguf_writer.add_layer_norm_rms_eps(rms_norm_eps)
self.gguf_writer.add_expert_count(self.hparams["num_experts"])
self.gguf_writer.add_expert_used_count(self.hparams["num_experts_per_tok"])
self.gguf_writer.add_file_type(self.ftype)

_experts: list[dict[str, Tensor]] | None = None

def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:

# process the experts separately
if ".feed_forward.experts." in name:
n_experts = self.hparams["num_experts"]

assert bid is not None

if self._experts is None:
self._experts = [{} for _ in range(self.block_count)]

self._experts[bid][name] = data_torch

if len(self._experts[bid]) >= n_experts * 3:

# merge the experts into a single 3d tensor
for wid in ["down_proj", "gate_proj", "up_proj"]:
datas: list[Tensor] = []

for xid in range(n_experts):
ename = f"model.layers.{bid}.feed_forward.experts.{xid}.{wid}.weight"
datas.append(self._experts[bid][ename])
del self._experts[bid][ename]

data_torch = torch.stack(datas, dim=0)

# using the same merged name as qwen2moe
merged_name = f"model.layers.{bid}.mlp.experts.{wid}.weight"

new_name = self.map_tensor_name(merged_name)

yield new_name, data_torch
return

new_name = self.map_tensor_name(name)

if name.endswith(".A_log"):
logger.debug("A_log --> A ==> " + new_name)
data_torch = -torch.exp(data_torch)

yield new_name, data_torch

# same as Mamba
def extra_f32_tensors(self, name: str, new_name: str, bid: int | None, n_dims: int) -> bool:
del n_dims # unused

return bid is not None and new_name in (
self.format_tensor_name(n, bid, ".weight" if name.endswith(".weight") else "") for n in [
gguf.MODEL_TENSOR.SSM_CONV1D,
gguf.MODEL_TENSOR.SSM_X,
gguf.MODEL_TENSOR.SSM_DT,
gguf.MODEL_TENSOR.SSM_A,
gguf.MODEL_TENSOR.SSM_D,
]
)


@Model.register("CohereForCausalLM")
class CommandR2Model(Model):
model_arch = gguf.MODEL_ARCH.COMMAND_R
Expand Down
96 changes: 38 additions & 58 deletions ggml.c
Original file line number Diff line number Diff line change
Expand Up @@ -7094,19 +7094,18 @@ struct ggml_tensor * ggml_ssm_conv(
GGML_ASSERT(ggml_is_3d(s));
GGML_ASSERT(ggml_is_matrix(x));
GGML_ASSERT(ggml_is_matrix(c));
GGML_ASSERT(ggml_is_matrix(sq));
GGML_ASSERT(ggml_is_vector(sq));
GGML_ASSERT(sq->type == GGML_TYPE_I32);

const int64_t d_conv = c->ne[0];
const int64_t d_inner = c->ne[1];
const int64_t n_tokens = x->ne[1];
const int64_t n_kv = s->ne[2];
const int64_t n_rs = s->ne[2];

GGML_ASSERT( s->ne[0] == d_conv - 1);
GGML_ASSERT( s->ne[1] == d_inner);
GGML_ASSERT( x->ne[0] == d_inner);
GGML_ASSERT(sq->ne[0] == n_kv);
GGML_ASSERT(sq->ne[1] == n_tokens);
GGML_ASSERT(sq->ne[0] == n_tokens);

bool is_node = false;

Expand All @@ -7115,8 +7114,8 @@ struct ggml_tensor * ggml_ssm_conv(
is_node = true;
}

// 2-in-1 concatenated x and conv_states, {d_inner, n_tokens} with {d_conv, d_inner, n_kv}
struct ggml_tensor * result = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, (d_inner*n_tokens) + (d_conv*d_inner*n_kv));
// 2-in-1 concatenated x and conv_states, {d_inner, n_tokens} with {d_conv, d_inner, n_rs}
struct ggml_tensor * result = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, (d_inner*n_tokens) + (d_conv*d_inner*n_rs));

result->op = GGML_OP_SSM_CONV;
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
Expand Down Expand Up @@ -7169,7 +7168,7 @@ struct ggml_tensor * ggml_ssm_scan(
is_node = true;
}

// 2-in-1 concatenated y and ssm_states, {d_inner, n_tokens} with {d_state, d_inner, n_kv}
// 2-in-1 concatenated y and ssm_states, {d_inner, n_tokens} with {d_state, d_inner, n_rs}
struct ggml_tensor * result = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, ggml_nelements(x) + ggml_nelements(s));

result->op = GGML_OP_SSM_SCAN;
Expand Down Expand Up @@ -16241,9 +16240,9 @@ static void ggml_compute_forward_ssm_conv_f32(
const int nc = src2->ne[0]; // d_conv
const int nr = src0->ne[1]; // d_inner
const int n_t = src1->ne[1]; // n_tokens
const int n_kv = src0->ne[2]; // max number of sequences in the batch
const int n_rs = src0->ne[2]; // max number of sequences in the batch

GGML_ASSERT((nr*n_t) + (nc*nr*n_kv) == ggml_nelements(dst));
GGML_ASSERT((nr*n_t) + (nc*nr*n_rs) == ggml_nelements(dst));
GGML_ASSERT(src0->nb[0] == sizeof(float));
GGML_ASSERT(src1->nb[0] == sizeof(float));
GGML_ASSERT(src2->nb[0] == sizeof(float));
Expand All @@ -16260,10 +16259,12 @@ static void ggml_compute_forward_ssm_conv_f32(
const int ir1 = MIN(ir0 + dr, nr);
const int ir = ir1 - ir0;

if (n_kv > 1) {
const int32_t * sq = src3->data; // {n_tokens}

if (n_rs > 1) {
// multiple sequences means it's hard to know when it's the first time a state is read,
// so copy them all over to the destination, just to be sure.
for (int i3 = 0; i3 < n_kv; ++i3) {
for (int i3 = 0; i3 < n_rs; ++i3) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm looking at adding the missing Metal kernels for SSM_CONV and SSM_SCAN. I'm wondering if this part of the kernels where we copy src0 -> dst could be extracted outside of the operation via ggml_cpy + ggml_view or ggml_acc? Would simplify the implementation

Also, I still haven't understood the details of the computation, but if we find a way to express these ops via existing ops all together (e.g. using ggml_conv, ggml_mul_mat, ...), it would be preferred to do so, in order to reduce the amount of kernels that we have to write.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm wondering if this part of the kernels where we copy src0 -> dst could be extracted outside of the operation via ggml_cpy + ggml_view or ggml_acc? Would simplify the implementation

Yes, this is definitely possible. I'll find a way to extract the copies outside.

if we find a way to express these ops via existing ops all together (e.g. using ggml_conv, ggml_mul_mat, ...), it would be preferred to do so, in order to reduce the amount of kernels that we have to write.

For SSM_SCAN, I think there's a way to fully express it in terms of other ops, though it will use much more memory because of the big intermediate tensors, and new operators like SOFT_PLUS and EXP would be needed instead. But different lengths of simultaneous sequences might make a custom operator still necessary. I'll think about ways to make it simpler, especially since other recurrent architectures (like RWKV) will also need to work on multiple sequences per batch.

For simplifying SSM_CONV, I don't think ggml_conv supports working on independent 1D rolling windows with varying sequence lengths.

When working on a single sequence, though, it's quite simple to do the equivalent of ggml_ssm_conv with a self-overlapping view, as I did in my original implementation which I described in more detail in #5328 (comment):

https://github.com/ggerganov/llama.cpp/blob/64fbce052373faf07a36b599528f8fe1cb1d62fb/llama.cpp#L6973-L6982

Setting nb[2] to the element size makes the view self-overlapping.

But this would create too many nodes in the compute graph when done with multiple sequences (unless they're always all the same length in which case the 4th dimension could be used), so a custom operator is necessary.

Copy link
Member

@ggerganov ggerganov May 26, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

One idea that we might consider is to unfuse the n_rs dimension from the SSM ops and make them work per 1 recurrent state. Then, during inference and right before the SSM operations, we split the batch into same-sequence chunks and SSM them individually. After that we concat back the results into the full hidden state for the batch

The main goal would be to simplify the SSM operators, and potentially express them as other existing ops if possible. But additionally, I'm considering a similar processing mode for the standard transformer KV cache in which we don't rely on a "unified" buffer for all the sequences, but instead each sequence has it's own separate KV cache buffer. In that mode, we would do a similar same-sequence batch split before the attention. The main purpose of supporting this mode would be to achieve reproducible results during parallel decoding (currently, decoding the same sequence in parallel can yield slightly different results due to the unified KV cache).

Just throwing some thoughts that I have so far - will continue looking at the PR in the next days

Edit: I was writing this comment before I saw you posted - will take a look tomorrow

Copy link
Collaborator Author

@compilade compilade May 26, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

One idea that we might consider is to unfuse the n_rs dimension from the SSM ops and make them work per 1 recurrent state. Then, during inference and right before the SSM operations, we split the batch into same-sequence chunks and SSM them individually. After that we concat back the results into the full hidden state for the batch

Yes, this would be doable, but would make the number of compute graph nodes scale with the number of sequences. (EDIT: if it's split when making ubatches, then the number of compute graph nodes can stay constant)

Another way would be to make all sequences have the same number of new tokens in a ubatch, to allow using another dimension instead of having to loop when building the compute graphs. This would still allow batching multiple sequences with recurrent models, but without the need for new custom operators for each architecture, and still with a constant number of compute graph nodes.

The recurrent steps are simpler for ubatches with sequence lengths of 1, but prompt processing performance would be much slower than with a per-recurrent-architecture operator for longer sequences. Still thinking about ways to generalize this while keeping good performance.

But additionally, I'm considering a similar processing mode for the standard transformer KV cache in which we don't rely on a "unified" buffer for all the sequences, but instead each sequence has it's own separate KV cache buffer. In that mode, we would do a similar same-sequence batch split before the attention.

For the transformer KV cache, if there's logic to make all sequences within a ubatch to have the same number of new tokens, I think a mode to split batches sequence-wise will be simpler and could re-use much of the same code.

I also think there's a way to keep the unified KV cache (one buffer) and chunk it to make each sequence have their own independent contiguous reserved cells. Batching sequences together might still be possible though, if the KQ mask gets another dimension (the number of sequences in the ubatch, and the number of new tokens per sequence instead of the batch size) so that these equal-sized "chunks" get processed independently in parallel. But this might not work (because the newly-calculated KV cells have to be copied in a bunch of not-regularly-spaced places), unless... unless maybe with some kind of ggml_set_rows? Not sure about the transposed V cache, though.

A sequence-wise processing mode is likely simpler, although it's not really parallel processing then (the model weights are all read at each ubatch).

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if it's split when making ubatches, then the number of compute graph nodes can stay constant

No, it has to be split only for the attention so that the rest of the ops are still batched. Otherwise we will sacrifice a lot of performance

Another way would be to make all sequences have the same number of new tokens in a ubatch, to allow using another dimension instead of having to loop when building the compute graphs. This would still allow batching multiple sequences with recurrent models, but without the need for new custom operators for each architecture, and still with a constant number of compute graph nodes.

For the transformer KV cache, if there's logic to make all sequences within a ubatch to have the same number of new tokens, I think a mode to split batches sequence-wise will be simpler and could re-use much of the same code.

Not sure how that would work. Adding dummy tokens sounds too much overhead (at least
in the case of the regular transformer). Any other ideas?

A sequence-wise processing mode is likely simpler, although it's not really parallel processing then (the model weights are all read at each ubatch).

From a broad PoV, if we have an implementation that works with a single-sequence and any batch size, then to extend it to multi-sequence batches we can split the batch into same-sequence tokens right before the attention and merge it back after the attention. Each split will do what we already do for the single-sequence solution, using separate cache for each sequence. I didn't consider the number of nodes until you noted - so that might be a problem indeed.

I'm currently working on a big refactor of how Mamba (and Jamba) works to make all sequences of a sub-batch be of the same length (initially only for models with recurrent states), and to make recurrent state slots contiguous, with the goal of simplifying the SSM operations (and removing GGML_OP_SSM_CONV), so that GPU support will be much easier to implement after that.

Looking forward to this!

Copy link
Collaborator Author

@compilade compilade May 30, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No, it has to be split only for the attention so that the rest of the ops are still batched. Otherwise we will sacrifice a lot of performance

It will sacrifice some performance, but only in the cases where a batch contains an unequal number of tokens for each affected sequence. So this should not affect large prompt processing or parallel text generation, if both are not done in the same batch.

Not sure how that would work. Adding dummy tokens sounds too much overhead (at least
in the case of the regular transformer). Any other ideas?

This is not about adding dummy tokens, but about making the number of new tokens in each ubatch the same per sequence. I think the overhead will be minmal, though there is still some.

Let me illustrate.

Let's say there's a batch with new tokens for 4 sequences of length 16, 7, 1, 1, respectively.

0: ################
1: #######
2: #
3: #

Splitting that into equal-length sequences would make 3 ubatches, like so:

0: #
1: #
2: #
3: #
0: ######
1: ######
0: #########

Each of these shapes are nice and rectangular, which is good for recurrent architectures because their operations can be more easily batched across sequences this way.

But I'm not yet sure if it would also benefit Transformers, which is why I'm thinking of initially only enabling the equal-length splitting for recurrent (or hybrid) model architectures.

From a broad PoV, if we have an implementation that works with a single-sequence and any batch size, then to extend it to multi-sequence batches we can split the batch into same-sequence tokens right before the attention and merge it back after the attention. Each split will do what we already do for the single-sequence solution, using separate cache for each sequence. I didn't consider the number of nodes until you noted - so that might be a problem indeed.

Doing this with a constant number of graph nodes is pretty much what using same-length sequences (as illustrated above) allows, because the split into same-sequence tokens can then simply become another tensor dimension.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Aha, got it. Good idea. I'm also not sure if this can help Transformers, but it's something to think about 👍

float * s0 = (float *) ((char *) src0->data + ir0*(src0->nb[1]) + i3*(src0->nb[2]));
float * s = (float *) ((char *) dst->data + ir0*(src2->nb[1]) + i3*(src2->nb[2]) + nr*n_t*sizeof(float));
// can't use memcpy because of d_conv vs d_conv - 1
Expand All @@ -16277,19 +16278,19 @@ static void ggml_compute_forward_ssm_conv_f32(
}

for (int i2 = 0; i2 < n_t; ++i2) {
int32_t * sq = (int32_t *) ((char *) src3->data + i2*(src3->nb[1])); // {n_kv, n_tokens}
float * x = (float *) ((char *) dst->data + ir0*sizeof(float) + i2*(nr*sizeof(float))); // {d_inner, n_tokens}
float * s = (float *) ((char *) dst->data + ir0*(src2->nb[1]) + sq[0]*(src2->nb[2]) + nr*n_t*sizeof(float)); // {d_conv, d_inner, n_kv}
float * s0; // {d_conv - 1, d_inner, n_kv}
float * x0 = (float *) ((char *) src1->data + ir0*(src1->nb[0]) + i2*(src1->nb[1])); // {d_inner, n_tokens}
float * c = (float *) ((char *) src2->data + ir0*(src2->nb[1])); // {d_conv, d_inner}
int32_t sq_i = sq[i2];
float * x = (float *) ((char *) dst->data + ir0*sizeof(float) + i2*(nr*sizeof(float))); // {d_inner, n_tokens}
float * s = (float *) ((char *) dst->data + ir0*(src2->nb[1]) + sq_i*(src2->nb[2]) + nr*n_t*sizeof(float)); // {d_conv, d_inner, n_rs}
float * s0; // {d_conv - 1, d_inner, n_rs}
float * x0 = (float *) ((char *) src1->data + ir0*(src1->nb[0]) + i2*(src1->nb[1])); // {d_inner, n_tokens}
float * c = (float *) ((char *) src2->data + ir0*(src2->nb[1])); // {d_conv, d_inner}
int ne0s0;

GGML_ASSERT(0 <= sq[0] && sq[0] < n_kv);
GGML_ASSERT(0 <= sq_i && sq_i < n_rs);

// avoid needing to copy the state for the first token
if (i2 == 0) {
s0 = (float *) ((char *) src0->data + ir0*(src0->nb[1]) + sq[0]*(src0->nb[2])); // {d_conv - 1, d_inner, n_kv}
s0 = (float *) ((char *) src0->data + ir0*(src0->nb[1]) + sq_i*(src0->nb[2])); // {d_conv - 1, d_inner, n_rs}
ne0s0 = src0->ne[0];
} else {
// the source is the last (d_conv - 1) columns of the destination
Expand All @@ -16307,18 +16308,6 @@ static void ggml_compute_forward_ssm_conv_f32(
s[(nc - 1) + i1*nc] = x0[i1];
}

// handle copies when there are multiple output states
for (int i3 = 1; i3 < n_kv; ++i3) {
int32_t seq = sq[i3];
if (0 <= seq && seq < n_kv) {
float * s1 = s + (seq - sq[0])*nc*nr;
memcpy(s1, s, nc*ir*sizeof(float));
} else {
// stop at negative or too big seq_ids
break;
}
}

// it seems a little faster when this is separate from the state shift
for (int i1 = 0; i1 < ir; ++i1) {
// rowwise dot product
Expand Down Expand Up @@ -16370,7 +16359,7 @@ static void ggml_compute_forward_ssm_scan_f32(
const int64_t nc = src0->ne[0]; // d_state
const int64_t nr = src0->ne[1]; // d_inner
const int64_t n_t = src1->ne[1]; // number of tokens in the batch
const int64_t n_kv = src0->ne[2]; // max number of sequences in the batch
const int64_t n_rs = src0->ne[2]; // max number of sequences in the batch

GGML_ASSERT(ggml_nelements(src1) + ggml_nelements(src0) == ggml_nelements(dst));
GGML_ASSERT(src0->nb[0] == sizeof(float));
Expand All @@ -16379,6 +16368,7 @@ static void ggml_compute_forward_ssm_scan_f32(
GGML_ASSERT(src3->nb[0] == sizeof(float));
GGML_ASSERT(src4->nb[0] == sizeof(float));
GGML_ASSERT(src5->nb[0] == sizeof(float));
GGML_ASSERT(src6->nb[0] == sizeof(int32_t));
// required for the dot product between s and C, and when copying the states
GGML_ASSERT(src0->nb[1] == src0->ne[0]*sizeof(float));
// required for per-sequence offsets for states
Expand All @@ -16394,32 +16384,34 @@ static void ggml_compute_forward_ssm_scan_f32(
const int ir1 = MIN(ir0 + dr, nr);
const int ir = ir1 - ir0;

if (n_kv > 1) {
const int32_t * sq = src6->data; // {n_tokens}

if (n_rs > 1) {
// it's hard to know if the source states have already been copied
// when there are multiple, so copy them already.
for (int i3 = 0; i3 < n_kv; ++i3) {
for (int i3 = 0; i3 < n_rs; ++i3) {
float * s0 = (float *) ((char *) src0->data + ir0*(src0->nb[1]) + i3*(src0->nb[2]));
float * s = (float *) ((char *) dst->data + ir0*(src0->nb[1]) + i3*(src0->nb[2]) + src1->nb[2]);
memcpy(s, s0, nc*ir*sizeof(float));
}
}

for (int i2 = 0; i2 < n_t; ++i2) {
int32_t * sq = (int32_t *) ((char *) src6->data + i2*(src6->nb[1])); // {n_kv, n_tokens}
float * y = (float *) ((char *) dst->data + ir0*(src1->nb[0]) + i2*(src1->nb[1])); // {d_inner, n_tokens}
float * s = (float *) ((char *) dst->data + ir0*(src0->nb[1]) + sq[0]*(src0->nb[2]) + src1->nb[2]); // {d_state, d_inner, n_kv}
float * s0;
float * x = (float *) ((char *) src1->data + ir0*(src1->nb[0]) + i2*(src1->nb[1])); // {d_inner, n_tokens}
float * dt = (float *) ((char *) src2->data + ir0*(src2->nb[0]) + i2*(src2->nb[1])); // {d_inner, n_tokens}
float * A = (float *) ((char *) src3->data + ir0*(src3->nb[1])); // {d_state, d_inner}
float * B = (float *) ((char *) src4->data + i2*(src4->nb[1])); // {d_state, n_tokens}
float * C = (float *) ((char *) src5->data + i2*(src5->nb[1])); // {d_state, n_tokens}

GGML_ASSERT(0 <= sq[0] && sq[0] < n_kv);
int32_t sq_i = sq[i2];
float * y = (float *) ((char *) dst->data + ir0*(src1->nb[0]) + i2*(src1->nb[1])); // {d_inner, n_tokens}
float * s = (float *) ((char *) dst->data + ir0*(src0->nb[1]) + sq_i*(src0->nb[2]) + src1->nb[2]); // {d_state, d_inner, n_rs}
float * s0;
float * x = (float *) ((char *) src1->data + ir0*(src1->nb[0]) + i2*(src1->nb[1])); // {d_inner, n_tokens}
float * dt = (float *) ((char *) src2->data + ir0*(src2->nb[0]) + i2*(src2->nb[1])); // {d_inner, n_tokens}
float * A = (float *) ((char *) src3->data + ir0*(src3->nb[1])); // {d_state, d_inner}
float * B = (float *) ((char *) src4->data + i2*(src4->nb[1])); // {d_state, n_tokens}
float * C = (float *) ((char *) src5->data + i2*(src5->nb[1])); // {d_state, n_tokens}

GGML_ASSERT(0 <= sq_i && sq_i < n_rs);

// avoid needing to copy the state for the first token
if (i2 == 0) {
s0 = (float *) ((char *) src0->data + ir0*(src0->nb[1]) + sq[0]*(src0->nb[2])); // {d_state, d_inner, n_kv}
s0 = (float *) ((char *) src0->data + ir0*(src0->nb[1]) + sq_i*(src0->nb[2])); // {d_state, d_inner, n_rs}
} else {
// otherwise the source is the same as the destination
s0 = s;
Expand All @@ -16442,18 +16434,6 @@ static void ggml_compute_forward_ssm_scan_f32(
}
y[i1] = sumf;
}

// handle copies when there are multiple output states
for (int i3 = 1; i3 < n_kv; ++i3) {
int32_t seq = sq[i3];
if (0 <= seq && seq < n_kv) {
float * s1 = s + (seq - sq[0])*nc*nr;
memcpy(s1, s, nc*ir*sizeof(float));
} else {
// stop at negative or too big seq_ids
break;
}
}
}
}

Expand Down
36 changes: 36 additions & 0 deletions gguf-py/gguf/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,7 @@ class MODEL_ARCH(IntEnum):
GEMMA = auto()
STARCODER2 = auto()
MAMBA = auto()
JAMBA = auto()
XVERSE = auto()
COMMAND_R = auto()
DBRX = auto()
Expand Down Expand Up @@ -182,7 +183,10 @@ class MODEL_TENSOR(IntEnum):
SSM_CONV1D = auto()
SSM_X = auto()
SSM_DT = auto()
SSM_DT_NORM = auto()
SSM_A = auto()
SSM_B_NORM = auto()
SSM_C_NORM = auto()
SSM_D = auto()
SSM_OUT = auto()

Expand Down Expand Up @@ -216,6 +220,7 @@ class MODEL_TENSOR(IntEnum):
MODEL_ARCH.GEMMA: "gemma",
MODEL_ARCH.STARCODER2: "starcoder2",
MODEL_ARCH.MAMBA: "mamba",
MODEL_ARCH.JAMBA: "jamba",
MODEL_ARCH.XVERSE: "xverse",
MODEL_ARCH.COMMAND_R: "command-r",
MODEL_ARCH.DBRX: "dbrx",
Expand Down Expand Up @@ -263,7 +268,10 @@ class MODEL_TENSOR(IntEnum):
MODEL_TENSOR.SSM_CONV1D: "blk.{bid}.ssm_conv1d",
MODEL_TENSOR.SSM_X: "blk.{bid}.ssm_x",
MODEL_TENSOR.SSM_DT: "blk.{bid}.ssm_dt",
MODEL_TENSOR.SSM_DT_NORM: "blk.{bid}.ssm_dt_norm",
MODEL_TENSOR.SSM_A: "blk.{bid}.ssm_a",
MODEL_TENSOR.SSM_B_NORM: "blk.{bid}.ssm_b_norm",
MODEL_TENSOR.SSM_C_NORM: "blk.{bid}.ssm_c_norm",
MODEL_TENSOR.SSM_D: "blk.{bid}.ssm_d",
MODEL_TENSOR.SSM_OUT: "blk.{bid}.ssm_out",
}
Expand Down Expand Up @@ -682,6 +690,34 @@ class MODEL_TENSOR(IntEnum):
MODEL_TENSOR.SSM_D,
MODEL_TENSOR.SSM_OUT,
],
MODEL_ARCH.JAMBA: [
MODEL_TENSOR.TOKEN_EMBD,
MODEL_TENSOR.OUTPUT_NORM,
MODEL_TENSOR.OUTPUT,
MODEL_TENSOR.ATTN_NORM,
MODEL_TENSOR.ATTN_Q,
MODEL_TENSOR.ATTN_K,
MODEL_TENSOR.ATTN_V,
MODEL_TENSOR.ATTN_OUT,
MODEL_TENSOR.SSM_IN,
MODEL_TENSOR.SSM_CONV1D,
MODEL_TENSOR.SSM_X,
MODEL_TENSOR.SSM_DT,
MODEL_TENSOR.SSM_DT_NORM,
MODEL_TENSOR.SSM_A,
MODEL_TENSOR.SSM_B_NORM,
MODEL_TENSOR.SSM_C_NORM,
MODEL_TENSOR.SSM_D,
MODEL_TENSOR.SSM_OUT,
MODEL_TENSOR.FFN_GATE_INP,
MODEL_TENSOR.FFN_NORM,
MODEL_TENSOR.FFN_GATE,
MODEL_TENSOR.FFN_DOWN,
MODEL_TENSOR.FFN_UP,
MODEL_TENSOR.FFN_GATE_EXP,
MODEL_TENSOR.FFN_DOWN_EXP,
MODEL_TENSOR.FFN_UP_EXP,
],
MODEL_ARCH.XVERSE: [
MODEL_TENSOR.TOKEN_EMBD,
MODEL_TENSOR.OUTPUT_NORM,
Expand Down
Loading
Loading