llama : support Jamba hybrid Transformer-Mamba models #7531
Draft: compilade wants to merge 41 commits into master from compilade/refactor-kv-cache
Commits (41, all by compilade)
271104c  wip: llama : separate recurrent states from the KV cache
8db1e4d  llama : use std::find for seq_nodes in llama_rs_cache
0028010  llama : state checkpoints for recurrent models
0c8b3b2  llama : correctly handle more edge cases for the rs cache
d66849f  Merge branch 'master' into compilade/refactor-kv-cache
a09db95  llama : rename many llama_kv_cache_* functions
c460ff1  Merge branch 'master' into compilade/refactor-kv-cache
b6fafd1  llama : remove useless return value for some llama_cache_* functions
b7ec12e  Merge branch 'master' into compilade/refactor-kv-cache
3b57b55  Merge branch 'master' into compilade/refactor-kv-cache
7e13f19  llama : rethink recurrent state cell counts
cbc743e  llama : support Jamba
0fd13e9  Merge branch 'master' into compilade/refactor-kv-cache
61a88a1  llama : fix BERT inference without KV cache
ea2e63e  convert-hf : check for unprocessed Jamba experts
fc59407  convert-hf : support Mini-Jamba conversion
181dadf  llama : fix Jamba quantization sanity checks
3a414b0  llama : sequence-length-aware batch splitting
4e4c41e  Merge branch 'master' into compilade/refactor-kv-cache
3587a94  llama : use equal-sequence-length sub-batches for recurrent models
5d3c7b9  Merge branch 'master' into compilade/refactor-kv-cache
72eea49  llama : fix batch split output count for embeddings
18d1c14  llama : minimize swaps when reordering logits
61200ef  llama : fix edge case finding batch seq_id of split recurrent cell
eb589d5  llama : avoid copies for simple batch splits
8fb57ac  llama : use im2col and mul_mat to perform convolution for Mamba
17f6c1e  llama : fix .base() compilation error on Windows
fee3c1d  llama : allow doing the equivalent of SSM_CONV with SUM_ROWS and MUL
6840ac0  Merge branch 'master' into compilade/refactor-kv-cache
372482d  llama : rename llama_cache to llama_past
43d8d4b  examples : replace llama_kv_cache_seq_* with llama_past_seq_*
ff794f5  Merge branch 'master' into compilade/refactor-kv-cache
33425a7  mamba : fix non-contiguous usage of ggml_silu
10c3c41  Merge branch 'master' into compilade/refactor-kv-cache
9b38f8b  Merge branch 'master' into compilade/refactor-kv-cache
bc320ef  Merge branch 'master' into compilade/refactor-kv-cache
fcb889c  llama : session saving and reloading for hybrid models
a03e32a  Merge branch 'master' into compilade/refactor-kv-cache
9d3f44d  convert_hf : fix Jamba conversion
5f62db7  llama : fix mixed signedness comparison
375de5b  llama : use unused n_embd_k_gqa in k_shift
Conversations
I'm looking at adding the missing Metal kernels for `SSM_CONV` and `SSM_SCAN`. I'm wondering if the part of the kernels where we copy `src0` -> `dst` could be extracted outside of the operation via `ggml_cpy` + `ggml_view`, or `ggml_acc`? That would simplify the implementation.

Also, I still haven't understood the details of the computation, but if we find a way to express these ops entirely via existing ops (e.g. using `ggml_conv`, `ggml_mul_mat`, ...), it would be preferred to do so, in order to reduce the number of kernels that we have to write.
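A rough sketch of what extracting the state copy could look like, assuming a `conv_x` buffer of shape `(d_conv - 1 + n_tokens, d_inner)` and a `conv_states` cache tensor (names like `conv_states` and `rs_head` are illustrative, not the PR's actual code):

```cpp
// Take a view of the last (d_conv - 1) columns of the conv buffer for each channel...
struct ggml_tensor * last_conv = ggml_view_2d(ctx, conv_x,
        d_conv - 1, d_inner,
        conv_x->nb[1],
        n_tokens*ggml_element_size(conv_x));

// ...and copy it into the recurrent-state buffer as an explicit node in the graph,
// instead of having the SSM kernel write the updated state into dst itself.
ggml_build_forward_expand(gf,
        ggml_cpy(ctx, last_conv,
                ggml_view_1d(ctx, conv_states, (d_conv - 1)*d_inner,
                        rs_head*(d_conv - 1)*d_inner*ggml_element_size(conv_states))));
```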
Yes, this is definitely possible. I'll find a way to extract the copies outside.

For `SSM_SCAN`, I think there's a way to fully express it in terms of other ops, though it would use much more memory because of the big intermediate tensors, and new operators like `SOFT_PLUS` and `EXP` would be needed instead. But different lengths of simultaneous sequences might still make a custom operator necessary. I'll think about ways to make it simpler, especially since other recurrent architectures (like RWKV) will also need to work on multiple sequences per batch.

For simplifying `SSM_CONV`, I don't think `ggml_conv` supports working on independent 1D rolling windows with varying sequence lengths.

When working on a single sequence, though, it's quite simple to do the equivalent of `ggml_ssm_conv` with a self-overlapping view, as I did in my original implementation, described in more detail in #5328 (comment): https://github.com/ggerganov/llama.cpp/blob/64fbce052373faf07a36b599528f8fe1cb1d62fb/llama.cpp#L6973-L6982

Setting `nb[2]` to the element size makes the view self-overlapping. But this would create too many nodes in the compute graph when done with multiple sequences (unless they're always all the same length, in which case the 4th dimension could be used), so a custom operator is necessary.
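For reference, a minimal sketch of that self-overlapping view for a single sequence, assuming `conv_x` has shape `(d_conv - 1 + n_tokens, d_inner)` and the per-channel conv weights `conv1d_w` have shape `(d_conv, d_inner)` (illustrative names, not the linked code verbatim):

```cpp
// View conv_x as (d_conv, d_inner, n_tokens): nb[2] is the element size, so the
// window for token t starts one element after the window for token t-1 (they overlap).
struct ggml_tensor * conv_view = ggml_view_3d(ctx, conv_x,
        d_conv, d_inner, n_tokens,
        conv_x->nb[1],               // stride between channels (rows)
        ggml_element_size(conv_x),   // stride between successive windows: one element
        0);

// The depthwise convolution then becomes an element-wise mul (broadcast over dim 2)
// followed by a sum over the d_conv dimension, giving a (1, d_inner, n_tokens) result.
struct ggml_tensor * x = ggml_sum_rows(ctx, ggml_mul(ctx, conv_view, conv1d_w));
```

This is in the same spirit as the `SUM_ROWS` and `MUL` commit listed above.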
One idea that we might consider is to unfuse the `n_rs` dimension from the SSM ops and make them work on a single recurrent state. Then, during inference, right before the SSM operations, we split the batch into same-sequence chunks and run the SSM on each of them individually. After that we concat the results back into the full hidden state for the batch.

The main goal would be to simplify the SSM operators, and potentially express them as other existing ops if possible. But additionally, I'm considering a similar processing mode for the standard transformer KV cache in which we don't rely on a "unified" buffer for all the sequences, but instead each sequence has its own separate KV cache buffer. In that mode, we would do a similar same-sequence batch split before the attention. The main purpose of supporting this mode would be to achieve reproducible results during parallel decoding (currently, decoding the same sequence in parallel can yield slightly different results due to the unified KV cache).
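A rough graph-level sketch of that per-sequence split, assuming the tokens of each sequence are contiguous in the ubatch, that `ggml_concat` takes a `dim` argument, and with hypothetical helpers like `build_ssm_single_state`:

```cpp
// Split the batch into same-sequence chunks right before the SSM ops,
// run the single-state SSM block on each chunk, then concat the results back.
struct ggml_tensor * out = NULL;
for (int s = 0; s < n_seqs; ++s) {
    struct ggml_tensor * x_s = ggml_view_2d(ctx, x,
            n_embd, n_tokens_of[s],
            x->nb[1], token_offset_of[s]*x->nb[1]); // tokens belonging to sequence s
    struct ggml_tensor * y_s = build_ssm_single_state(ctx, model, x_s, s); // hypothetical
    out = out == NULL ? y_s : ggml_concat(ctx, out, y_s, 1); // concat along the token dim
}
// Note: this builds O(n_seqs) nodes per layer, which is the scaling concern raised below.
```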
Just throwing some thoughts that I have so far - will continue looking at the PR in the next days
Edit: I was writing this comment before I saw you posted - will take a look tomorrow
Yes, this would be doable, but it would make the number of compute graph nodes scale with the number of sequences. (EDIT: if it's split when making ubatches, then the number of compute graph nodes can stay constant.)

Another way would be to make all sequences have the same number of new tokens in a ubatch, to allow using another dimension instead of having to loop when building the compute graphs. This would still allow batching multiple sequences with recurrent models, but without the need for new custom operators for each architecture, and still with a constant number of compute graph nodes.

The recurrent steps are simpler for ubatches with sequence lengths of 1, but prompt processing performance would be much slower than with a per-recurrent-architecture operator for longer sequences. Still thinking about ways to generalize this while keeping good performance.

For the transformer KV cache, if there's logic to make all sequences within a ubatch have the same number of new tokens, I think a mode to split batches sequence-wise would be simpler and could re-use much of the same code.

I also think there's a way to keep the unified KV cache (one buffer) and chunk it so that each sequence has its own independent, contiguous reserved cells. Batching sequences together might still be possible, though, if the KQ mask gets another dimension (the number of sequences in the ubatch, and the number of new tokens per sequence instead of the batch size) so that these equal-sized "chunks" get processed independently in parallel. But this might not work (because the newly-calculated KV cells have to be copied into a bunch of not-regularly-spaced places), unless... unless maybe with some kind of `ggml_set_rows`? Not sure about the transposed V cache, though.

A sequence-wise processing mode is likely simpler, although it's not really parallel processing then (the model weights are all read at each ubatch).
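A minimal sketch of how the extra dimension could look for a recurrent block, assuming a ubatch that guarantees `n_seq_tokens` new tokens for each of `n_seqs` sequences (`build_recurrent_block` is a hypothetical helper):

```cpp
// x holds all new tokens of the ubatch, shape (n_embd, n_seq_tokens * n_seqs).
// With equal sequence lengths, the per-sequence split is just a reshape, so the
// graph keeps a constant number of nodes regardless of how many sequences there are.
struct ggml_tensor * x3d = ggml_reshape_3d(ctx, x, n_embd, n_seq_tokens, n_seqs);

// The recurrent op can treat dimension 2 as "one recurrent state per sequence"
// instead of looping over sequences when building the graph.
struct ggml_tensor * y3d = build_recurrent_block(ctx, model, x3d, rs_states); // hypothetical

// Flatten back so the rest of the layer sees the usual 2D activations.
struct ggml_tensor * y = ggml_reshape_2d(ctx, y3d, n_embd, n_seq_tokens * n_seqs);
```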
No, it has to be split only for the attention so that the rest of the ops are still batched. Otherwise we will sacrifice a lot of performance.

Not sure how that would work. Adding dummy tokens sounds like too much overhead (at least in the case of the regular transformer). Any other ideas?

From a broad PoV, if we have an implementation that works with a single sequence and any batch size, then to extend it to multi-sequence batches we can split the batch into same-sequence tokens right before the attention and merge it back after the attention. Each split will do what we already do for the single-sequence solution, using a separate cache for each sequence. I didn't consider the number of nodes until you noted it - so that might be a problem indeed.
Looking forward to this!
It will sacrifice some performance, but only in cases where a batch contains an unequal number of tokens for each affected sequence. So this should not affect large prompt processing or parallel text generation, as long as both are not done in the same batch.

This is not about adding dummy tokens, but about making the number of new tokens in each ubatch the same for every sequence it contains. I think the overhead will be minimal, though there is still some.
Let me illustrate.
Let's say there's a batch with new tokens for 4 sequences of length 16, 7, 1, 1, respectively.
Splitting that into equal-length sequences would make 3 ubatches, like so:
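One possible split, assuming each ubatch takes the shortest remaining sequence length:

```
ubatch 1: 4 sequences × 1 token each    (16, 7, 1, 1) → (15, 6, 0, 0) remaining
ubatch 2: 2 sequences × 6 tokens each   (15, 6, 0, 0) → (9, 0, 0, 0)  remaining
ubatch 3: 1 sequence  × 9 tokens        (9, 0, 0, 0)  → (0, 0, 0, 0)  remaining
```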
Each of these shapes is nice and rectangular, which is good for recurrent architectures, because their operations can be more easily batched across sequences this way.
But I'm not yet sure if it would also benefit Transformers, which is why I'm thinking of initially only enabling the equal-length splitting for recurrent (or hybrid) model architectures.
Doing this with a constant number of graph nodes is pretty much what using same-length sequences (as illustrated above) allows, because the split into same-sequence tokens can then simply become another tensor dimension.
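A self-contained sketch of that splitting rule, assuming each ubatch greedily takes the shortest remaining sequence length (illustrative code, not the PR's actual implementation):

```cpp
#include <algorithm>
#include <cstdio>
#include <utility>
#include <vector>

// Split per-sequence token counts into "rectangular" ubatches: each ubatch takes
// the same number of tokens from every sequence that still has tokens left.
static std::vector<std::pair<int, int>> split_equal(std::vector<int> remaining) {
    std::vector<std::pair<int, int>> ubatches; // (n_seqs, n_seq_tokens)
    for (;;) {
        int n_seqs   = 0;
        int shortest = 0;
        for (int r : remaining) {
            if (r > 0) {
                n_seqs++;
                shortest = shortest == 0 ? r : std::min(shortest, r);
            }
        }
        if (n_seqs == 0) break;
        ubatches.push_back({n_seqs, shortest});
        for (int & r : remaining) {
            if (r > 0) r -= shortest;
        }
    }
    return ubatches;
}

int main() {
    // the example from above: 4 sequences with 16, 7, 1 and 1 new tokens
    for (auto [n_seqs, n_tok] : split_equal({16, 7, 1, 1})) {
        std::printf("%d sequence(s) x %d token(s)\n", n_seqs, n_tok);
    }
    // prints: 4 x 1, 2 x 6, 1 x 9  -> 3 ubatches
}
```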
Aha, got it. Good idea. I'm also not sure if this can help Transformers, but it's something to think about 👍