Commit 9c4c257
mamba : multiple sequences, but one at a time

This is a step towards making this Mamba implementation usable with the
server example (the way the system prompt is kept when clearing the client
slots will need to be changed before this can work, though).

The KV cache size for this kind of model is tied to the maximum number of
sequences kept at any single time. For now, this number is obtained from
n_parallel (plus one, to have an extra sequence to dedicate to the system
prompt), but there might be a better way to do this which won't also make
the main example use 2 cells even if only 1 is really used. (For this
specific case, --parallel 0 helps.)

Simultaneous sequence processing will probably require changes to
ggml_ssm_scan, and possibly a new operator for the conv step.

* mamba : support llama_kv_cache_seq_cp

This (mis)uses the logic around K shifts, because tokens in a state can't
be shifted anyway, and because inp_K_shift has the right shape and type.
Using ggml_get_rows is a nice way to do copies, but copy chains can't work.
Fortunately, copy chains don't really seem to be used in the examples.

Each KV cell is dedicated to the sequence ID corresponding to its own index.

* mamba : use a state mask

It's cleaner than the previous heuristic of checking for the pos of the
first token in the batch. inp_KQ_mask could not be re-used for this,
because it has the wrong shape and because it seems more suited to the next
step of simultaneous sequence processing (helping with the problem of
remembering which token belongs to which sequence(s)/state(s)).

* llama : replace the usage of n_ctx with kv_self.size in many places

* mamba : use n_tokens directly instead of n_tok

1 parent 98e6328 commit 9c4c257
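
The two mechanisms described in the message (clearing a cell's recurrent state with a multiplicative mask instead of inspecting the first token's pos, and implementing llama_kv_cache_seq_cp as a row gather over the cells) can be illustrated with a small standalone ggml graph. This is a rough sketch under assumptions, not the graph this commit builds: the toy sizes, the flattened {d_state*d_inner, n_kv} state layout, and the names states, state_mask and copy are illustrative only.

#include "ggml.h"

int main(void) {
    // small scratch context for a toy graph (sizes here are illustrative)
    struct ggml_init_params ip = { 16*1024*1024, NULL, false };
    struct ggml_context * ctx = ggml_init(ip);

    const int64_t d_state = 2, d_inner = 4, n_kv = 3; // toy sizes

    // one flattened recurrent state per KV cell; cell i belongs to seq_id i
    struct ggml_tensor * states = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, d_state*d_inner, n_kv);
    ggml_set_f32(states, 1.0f);

    // state mask: 0.0f clears a cell before its new tokens are processed,
    // 1.0f keeps it (this replaces the "pos of the first token" heuristic)
    struct ggml_tensor * state_mask = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 1, n_kv);
    ggml_set_f32(state_mask, 1.0f);

    // copy map: cell i takes its state from cell copy[i]; a single gather like
    // this can express seq_cp, but not a chain of dependent copies
    struct ggml_tensor * copy = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, n_kv);
    for (int i = 0; i < n_kv; ++i) ggml_set_i32_1d(copy, i, i);

    struct ggml_tensor * masked = ggml_mul(ctx, states, ggml_repeat(ctx, state_mask, states));
    struct ggml_tensor * moved  = ggml_get_rows(ctx, masked, copy);

    struct ggml_cgraph * gf = ggml_new_graph(ctx);
    ggml_build_forward_expand(gf, moved);
    ggml_graph_compute_with_ctx(ctx, gf, 1);

    ggml_free(ctx);
    return 0;
}

Because ggml_get_rows gathers every destination row from the same pre-copy source tensor, chained copies within one update cannot be expressed this way, which is the copy-chain limitation mentioned in the message.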

4 files changed: +253 -88 lines changed

common/common.cpp

Lines changed: 1 addition & 0 deletions
@@ -1176,6 +1176,7 @@ struct llama_context_params llama_context_params_from_gpt_params(const gpt_param
 
     cparams.n_ctx = params.n_ctx;
     cparams.n_batch = params.n_batch;
+    cparams.n_parallel = params.n_parallel;
     cparams.n_threads = params.n_threads;
     cparams.n_threads_batch = params.n_threads_batch == -1 ? params.n_threads : params.n_threads_batch;
     cparams.mul_mat_q = params.mul_mat_q;
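
For context, the forwarded value is what bounds the number of recurrent-state cells for a Mamba model (n_parallel plus one for the system prompt, as described in the commit message). A minimal usage sketch, assuming the llama_context_params::n_parallel field this commit wires up; the model path and the numbers are placeholders:

#include "llama.h"

int main(void) {
    // backend initialization omitted for brevity; the path is a placeholder
    llama_model_params mparams = llama_model_default_params();
    llama_model * model = llama_load_model_from_file("mamba-model.gguf", mparams);
    if (model == NULL) {
        return 1;
    }

    llama_context_params cparams = llama_context_default_params();
    cparams.n_ctx      = 2048;
    cparams.n_parallel = 4; // e.g. 4 client slots -> a Mamba model keeps 4 + 1 state cells

    llama_context * lctx = llama_new_context_with_model(model, cparams);
    if (lctx == NULL) {
        llama_free_model(model);
        return 1;
    }

    // ... decode batches, still one sequence at a time for this commit ...

    llama_free(lctx);
    llama_free_model(model);
    return 0;
}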

ggml.c

Lines changed: 13 additions & 13 deletions
@@ -5905,15 +5905,15 @@ struct ggml_tensor * ggml_ssm_scan(
     GGML_ASSERT(ggml_is_matrix(s)); // the ssm_state should be 2D
 
     {
-        const int64_t d_state = s->ne[0];
-        const int64_t d_inner = s->ne[1];
-        const int64_t n_tok = x->ne[1];
+        const int64_t d_state  = s->ne[0];
+        const int64_t d_inner  = s->ne[1];
+        const int64_t n_tokens = x->ne[1];
 
         GGML_ASSERT(x->ne[0] == d_inner);
         GGML_ASSERT(A->ne[0] == d_state);
         GGML_ASSERT(A->ne[1] == d_inner);
         GGML_ASSERT(B->ne[0] == d_state);
-        GGML_ASSERT(B->ne[1] == n_tok);
+        GGML_ASSERT(B->ne[1] == n_tokens);
     }
 
     bool is_node = false;
@@ -14178,12 +14178,12 @@ static void ggml_compute_forward_ssm_scan_f32(
 
     // first batch
     {
-        float * pdst = (float *) ((char *) dst->data + ir0*( dst->nb[1])); // {d_state, d_inner, n_tok}
+        float * pdst = (float *) ((char *) dst->data + ir0*( dst->nb[1])); // {d_state, d_inner, n_tokens}
         float * s = (float *) ((char *) src0->data + ir0*(src0->nb[1])); // {d_state, d_inner}
-        float * x = (float *) ((char *) src1->data + ir0*(src1->nb[0])); // {d_inner, n_tok}
-        float * dt = (float *) ((char *) src2->data + ir0*(src2->nb[0])); // {d_inner, n_tok}
+        float * x = (float *) ((char *) src1->data + ir0*(src1->nb[0])); // {d_inner, n_tokens}
+        float * dt = (float *) ((char *) src2->data + ir0*(src2->nb[0])); // {d_inner, n_tokens}
         float * A = (float *) ((char *) src3->data + ir0*(src3->nb[1])); // {d_state, d_inner}
-        float * B = (float *) ((char *) src4->data); // {d_state, n_tok}
+        float * B = (float *) ((char *) src4->data); // {d_state, n_tokens}
         // d_inner
         for (int i1 = 0; i1 < ir; ++i1) {
             float dt_soft_plus = log1pf(expf(dt[i1]));
@@ -14199,12 +14199,12 @@ static void ggml_compute_forward_ssm_scan_f32(
 
     // compute state for rest of tokens, previous state comes from dest
     for (int i2 = 1; i2 < n_t; ++i2) {
-        float * pdst = (float *) ((char *) dst->data + ir0*( dst->nb[1]) + i2 *( dst->nb[2])); // {d_state, d_inner, n_tok}
-        float * s = (float *) ((char *) dst->data + ir0*( dst->nb[1]) + (i2-1)*( dst->nb[2])); // {d_state, d_inner, n_tok}
-        float * x = (float *) ((char *) src1->data + ir0*(src1->nb[0]) + i2 *(src1->nb[1])); // {d_inner, n_tok}
-        float * dt = (float *) ((char *) src2->data + ir0*(src2->nb[0]) + i2 *(src2->nb[1])); // {d_inner, n_tok}
+        float * pdst = (float *) ((char *) dst->data + ir0*( dst->nb[1]) + i2 *( dst->nb[2])); // {d_state, d_inner, n_tokens}
+        float * s = (float *) ((char *) dst->data + ir0*( dst->nb[1]) + (i2-1)*( dst->nb[2])); // {d_state, d_inner, n_tokens}
+        float * x = (float *) ((char *) src1->data + ir0*(src1->nb[0]) + i2 *(src1->nb[1])); // {d_inner, n_tokens}
+        float * dt = (float *) ((char *) src2->data + ir0*(src2->nb[0]) + i2 *(src2->nb[1])); // {d_inner, n_tokens}
         float * A = (float *) ((char *) src3->data + ir0*(src3->nb[1])); // {d_state, d_inner}
-        float * B = (float *) ((char *) src4->data + i2*(src4->nb[1])); // {d_state, n_tok}
+        float * B = (float *) ((char *) src4->data + i2*(src4->nb[1])); // {d_state, n_tokens}
         // d_inner
         for (int i1 = 0; i1 < ir; ++i1) {
            float dt_soft_plus = log1pf(expf(dt[i1]));
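
For readers of the hunks above: the pointers give the shapes ({d_state, d_inner} states, per-token x, dt and B), and dt goes through a softplus. The per-token update those loops apply can be written as the scalar sketch below; it is an assumption based on these shapes and the standard Mamba discretization, not the kernel's actual code (in the kernel the first token reads the previous state from src0 and later tokens read it from the previous dst slice):

#include <math.h>

// scalar reference of one token's selective-scan state update:
//   state[i1][i0] <- state[i1][i0] * exp(softplus(dt[i1]) * A[i1][i0])
//                    + softplus(dt[i1]) * B[i0] * x[i1]
void ssm_scan_one_token_ref(
        int d_state, int d_inner,
        float       * state, // {d_state, d_inner}, updated in place
        const float * x,     // {d_inner}, this token
        const float * dt,    // {d_inner}, this token
        const float * A,     // {d_state, d_inner}
        const float * B) {   // {d_state}, this token
    for (int i1 = 0; i1 < d_inner; ++i1) {
        const float dt_soft_plus = log1pf(expf(dt[i1])); // softplus, as in the kernel
        for (int i0 = 0; i0 < d_state; ++i0) {
            const int i = i1*d_state + i0;
            state[i] = state[i] * expf(dt_soft_plus * A[i])
                     + dt_soft_plus * B[i0] * x[i1];
        }
    }
}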
