@@ -7270,32 +7270,48 @@ struct ggml_tensor * ggml_ssm_scan(
7270
7270
struct ggml_tensor * dt,
7271
7271
struct ggml_tensor * A,
7272
7272
struct ggml_tensor * B,
7273
- struct ggml_tensor * C) {
7273
+ struct ggml_tensor * C,
7274
+ struct ggml_tensor * D) {
7274
7275
GGML_ASSERT(ggml_is_contiguous(s));
7275
- GGML_ASSERT(ggml_is_contiguous(x));
7276
7276
GGML_ASSERT(ggml_is_contiguous(dt));
7277
7277
GGML_ASSERT(ggml_is_contiguous(A));
7278
- GGML_ASSERT(ggml_is_matrix(A));
7279
- GGML_ASSERT(ggml_is_3d(B));
7280
- GGML_ASSERT(ggml_is_3d(s));
7278
+ GGML_ASSERT(x->nb[0] == ggml_type_size(x->type));
7281
7279
GGML_ASSERT(B->nb[0] == ggml_type_size(B->type));
7282
7280
GGML_ASSERT(C->nb[0] == ggml_type_size(C->type));
7283
- GGML_ASSERT(ggml_are_same_shape(x, dt));
7281
+ GGML_ASSERT(x->nb[1] == x->ne[0]*x->nb[0]);
7282
+ GGML_ASSERT(B->nb[1] == B->ne[0]*B->nb[0]);
7283
+ GGML_ASSERT(C->nb[1] == C->ne[0]*C->nb[0]);
7284
7284
GGML_ASSERT(ggml_are_same_shape(B, C));
7285
7285
7286
7286
{
7287
7287
const int64_t d_state = s->ne[0];
7288
- const int64_t d_inner = s->ne[1];
7289
- const int64_t n_seq_tokens = x->ne[1];
7290
- const int64_t n_seqs = x->ne[2];
7291
-
7292
- GGML_ASSERT(s->ne[2] == n_seqs);
7293
- GGML_ASSERT(x->ne[0] == d_inner);
7294
- GGML_ASSERT(A->ne[0] == d_state);
7295
- GGML_ASSERT(A->ne[1] == d_inner);
7288
+ const int64_t head_dim = x->ne[0];
7289
+ const int64_t n_head = x->ne[1];
7290
+ const int64_t n_seq_tokens = x->ne[2];
7291
+ const int64_t n_seqs = x->ne[3];
7292
+
7293
+ GGML_ASSERT(dt->ne[0] == n_head);
7294
+ GGML_ASSERT(dt->ne[1] == n_seq_tokens);
7295
+ GGML_ASSERT(dt->ne[2] == n_seqs);
7296
+ GGML_ASSERT(ggml_is_3d(dt));
7297
+ GGML_ASSERT(s->ne[1] == head_dim);
7298
+ GGML_ASSERT(s->ne[2] == n_head);
7299
+ GGML_ASSERT(s->ne[3] == n_seqs);
7296
7300
GGML_ASSERT(B->ne[0] == d_state);
7297
- GGML_ASSERT(B->ne[1] == n_seq_tokens);
7298
- GGML_ASSERT(B->ne[2] == n_seqs);
7301
+ GGML_ASSERT(B->ne[2] == n_seq_tokens);
7302
+ GGML_ASSERT(B->ne[3] == n_seqs);
7303
+ GGML_ASSERT(D->ne[0] == n_head);
7304
+ GGML_ASSERT(ggml_is_vector(D));
7305
+
7306
+ if (ggml_is_vector(A)) {
7307
+ // Mamba-2
7308
+ GGML_ASSERT(A->ne[0] == n_head);
7309
+ } else {
7310
+ // Mamba-1
7311
+ GGML_ASSERT(A->ne[0] == d_state);
7312
+ GGML_ASSERT(A->ne[1] == n_head);
7313
+ GGML_ASSERT(ggml_is_matrix(A));
7314
+ }
7299
7315
}
7300
7316
7301
7317
bool is_node = false;
@@ -7316,6 +7332,7 @@ struct ggml_tensor * ggml_ssm_scan(
7316
7332
result->src[3] = A;
7317
7333
result->src[4] = B;
7318
7334
result->src[5] = C;
7335
+ result->src[6] = D;
7319
7336
7320
7337
return result;
7321
7338
}
@@ -15840,20 +15857,25 @@ static void ggml_compute_forward_ssm_conv(
15840
15857
static void ggml_compute_forward_ssm_scan_f32(
15841
15858
const struct ggml_compute_params * params,
15842
15859
struct ggml_tensor * dst) {
15843
- const struct ggml_tensor * src0 = dst->src[0]; // s
15844
- const struct ggml_tensor * src1 = dst->src[1]; // x
15845
- const struct ggml_tensor * src2 = dst->src[2]; // dt
15846
- const struct ggml_tensor * src3 = dst->src[3]; // A
15847
- const struct ggml_tensor * src4 = dst->src[4]; // B
15848
- const struct ggml_tensor * src5 = dst->src[5]; // C
15860
+ const struct ggml_tensor * src0 = dst->src[0]; // s {d_state, dim, n_head, n_seqs}
15861
+ const struct ggml_tensor * src1 = dst->src[1]; // x {dim, n_head, n_seq_tokens, n_seqs}
15862
+ const struct ggml_tensor * src2 = dst->src[2]; // dt {n_head, n_seq_tokens, n_seqs}
15863
+ const struct ggml_tensor * src3 = dst->src[3]; // A {d_state, n_head} or {n_head}
15864
+ const struct ggml_tensor * src4 = dst->src[4]; // B {d_state, n_group, n_seq_tokens, n_seqs}
15865
+ const struct ggml_tensor * src5 = dst->src[5]; // C {d_state, n_group, n_seq_tokens, n_seqs}
15866
+ const struct ggml_tensor * src6 = dst->src[6]; // D {n_head}
15849
15867
15850
15868
const int ith = params->ith;
15851
15869
const int nth = params->nth;
15852
15870
15853
- const int64_t nc = src0->ne[0]; // d_state
15854
- const int64_t nr = src0->ne[1]; // d_inner
15855
- const int64_t n_t = src1->ne[1]; // number of tokens per sequence
15856
- const int64_t n_s = src0->ne[2]; // number of sequences in the batch
15871
+ const int64_t nc = src0->ne[0]; // d_state
15872
+ const int64_t nr = src0->ne[1]; // dim
15873
+ const int64_t nh = src1->ne[1]; // n_head
15874
+ const int64_t ng = src4->ne[1];
15875
+ const int64_t nt = src1->ne[2]; // number of tokens per sequence
15876
+ const int64_t ns = src0->ne[3]; // number of sequences in the batch
15877
+
15878
+ const int64_t s_off = ggml_element_size(src1) * ggml_nelements(src1);
15857
15879
15858
15880
GGML_ASSERT(ggml_nelements(src1) + ggml_nelements(src0) == ggml_nelements(dst));
15859
15881
GGML_ASSERT(src0->nb[0] == sizeof(float));
@@ -15862,51 +15884,86 @@ static void ggml_compute_forward_ssm_scan_f32(
15862
15884
GGML_ASSERT(src3->nb[0] == sizeof(float));
15863
15885
GGML_ASSERT(src4->nb[0] == sizeof(float));
15864
15886
GGML_ASSERT(src5->nb[0] == sizeof(float));
15865
- // required for the dot product between s and C
15866
- GGML_ASSERT(src0->nb[1] == src0->ne[0]*sizeof(float));
15867
- // required for per-sequence offsets for states
15868
- GGML_ASSERT(src0->nb[2] == src0->ne[0]*src0->ne[1]*sizeof(float));
15869
- // required to get correct offset for state destination (i.e. src1->nb[3])
15870
- GGML_ASSERT(src1->nb[3] == src1->ne[0]*src1->ne[1]*src1->ne[2]*sizeof(float));
15871
-
15872
- // rows per thread
15873
- const int dr = (nr + nth - 1)/nth;
15874
-
15875
- // row range for this thread
15876
- const int ir0 = dr*ith;
15877
- const int ir1 = MIN(ir0 + dr, nr);
15878
- const int ir = ir1 - ir0;
15879
-
15880
- for (int i3 = 0; i3 < n_s; ++i3) {
15881
- for (int i2 = 0; i2 < n_t; ++i2) {
15882
- const float * s0 = (const float *) ((const char *) src0->data + ir0*(src0->nb[1]) + i3*(src0->nb[2])); // {d_state, d_inner, n_s}
15883
- const float * x = (const float *) ((const char *) src1->data + ir0*(src1->nb[0]) + i2*(src1->nb[1]) + i3*(src1->nb[2])); // {d_inner, n_t, n_s}
15884
- const float * dt = (const float *) ((const char *) src2->data + ir0*(src2->nb[0]) + i2*(src2->nb[1]) + i3*(src2->nb[2])); // {d_inner, n_t, n_s}
15885
- const float * A = (const float *) ((const char *) src3->data + ir0*(src3->nb[1])); // {d_state, d_inner}
15886
- const float * B = (const float *) ((const char *) src4->data + i2*(src4->nb[1]) + i3*(src4->nb[2])); // {d_state, n_t, n_s}
15887
- const float * C = (const float *) ((const char *) src5->data + i2*(src5->nb[1]) + i3*(src5->nb[2])); // {d_state, n_t, n_s}
15888
- float * y = (float *) ((char *) dst->data + ir0*(src1->nb[0]) + i2*(src1->nb[1]) + i3*(src1->nb[2])); // {d_inner, n_t, n_s}
15889
- float * s = (float *) ((char *) dst->data + ir0*(src0->nb[1]) + i3*(src0->nb[2]) + src1->nb[3]); // {d_state, d_inner, n_s}
15890
-
15891
- // use the output as the source for the next token-wise iterations
15887
+ GGML_ASSERT(src6->nb[0] == sizeof(float));
15888
+ // allows optimizing the modulo since n_group should be a power of 2
15889
+ GGML_ASSERT((ng & -ng) == ng);
15890
+
15891
+ // heads per thread
15892
+ const int dh = (nh + nth - 1)/nth;
15893
+
15894
+ // head range for this thread
15895
+ const int ih0 = dh*ith;
15896
+ const int ih1 = MIN(ih0 + dh, nh);
15897
+
15898
+ for (int i3 = 0; i3 < ns; ++i3) {
15899
+ for (int i2 = 0; i2 < nt; ++i2) {
15900
+ const float * s0 = (const float *) ((const char *) src0->data + i3*(src0->nb[3])); // {d_state, dim, nh, ns}
15901
+ const float * x = (const float *) ((const char *) src1->data + i2*(src1->nb[2]) + i3*(src1->nb[3])); // {dim, nh, nt, ns}
15902
+ const float * dt = (const float *) ((const char *) src2->data + i2*(src2->nb[1]) + i3*(src2->nb[2])); // {nh, nt, ns}
15903
+ const float * A = (const float *) ((const char *) src3->data); // {d_state, nh} or {nh}
15904
+ const float * B = (const float *) ((const char *) src4->data + i2*(src4->nb[2]) + i3*(src4->nb[3])); // {d_state, ng, nt, ns}
15905
+ const float * C = (const float *) ((const char *) src5->data + i2*(src5->nb[2]) + i3*(src5->nb[3])); // {d_state, ng, nt, ns}
15906
+ const float * D = (const float *) ((const char *) src6->data); // {nh}
15907
+ float * y = (float *) ((char *) dst->data + i2*(nh*nr*sizeof(float)) + i3*(nt*nh*nr*sizeof(float))); // {dim, nh, nt, ns}
15908
+ float * s = (float *) ((char *) dst->data + i3*(src0->nb[3]) + s_off); // {d_state, dim, nh, ns}
15909
+
15910
+ // use the output as the source when it's not the first token-wise iteration
15892
15911
if (i2 > 0) { s0 = s; }
15893
15912
15894
- // d_inner
15895
- for (int i1 = 0; i1 < ir; ++i1) {
15896
- // ref: https://github.com/state-spaces/mamba/blob/34076d664838588a3c97727b263478ab9f621a07/mamba_ssm/ops/triton/selective_state_update.py#L78
15897
- float dt_soft_plus = dt[i1] <= 20.0f ? log1pf(expf(dt[i1])) : dt[i1];
15898
- float x_dt = x[i1] * dt_soft_plus;
15899
- float sumf = 0.0f;
15900
- // d_state
15901
- for (int i0 = 0; i0 < nc; ++i0) {
15902
- int i = i0 + i1*nc;
15903
- // state = prev_state * dA + dB * x
15904
- float state = (s0[i] * expf(dt_soft_plus * A[i])) + (B[i0] * x_dt);
15905
- // y = rowwise_dotprod(state, C)
15906
- sumf += state * C[i0];
15907
- s[i] = state;
15913
+ if (ggml_is_vector(src3)) {
15914
+ // Mamba-2 has a scalar decay factor per head; dA can be outside the state-wise loop
15915
+
15916
+ // n_head
15917
+ for (int h = ih0; h < ih1; ++h) {
15918
+ // ref: https://github.com/state-spaces/mamba/blob/62db608da60f6fc790b8ed9f4b3225e95ca15fde/mamba_ssm/ops/triton/softplus.py#L16
15919
+ const float dt_soft_plus = dt[h] <= 20.0f ? log1pf(expf(dt[h])) : dt[h];
15920
+ const float dA = expf(dt_soft_plus * A[h]);
15921
+
15922
+ // TODO: SIMD implementation
15923
+ // dim
15924
+ for (int i1 = 0; i1 < nr; ++i1) {
15925
+ const int i = i1 + h*nr;
15926
+ const float x_dt = x[i] * dt_soft_plus;
15927
+ float sumf = 0.0f;
15928
+ // d_state
15929
+ for (int i0 = 0; i0 < nc; ++i0) {
15930
+ const int ii = i0 + i*nc;
15931
+ const int ig = i0 + (h & (ng - 1))*nc;
15932
+ // state = prev_state * dA + dB * x
15933
+ const float state = (s0[ii] * dA) + (B[ig] * x_dt);
15934
+ // y = rowwise_dotprod(state, C)
15935
+ sumf += state * C[ig];
15936
+ s[ii] = state;
15937
+ }
15938
+ y[i] = sumf + x[i] * D[h];
15939
+ }
15940
+ }
15941
+ } else {
15942
+ // Mamba-1 has an element-wise decay factor for the states
15943
+
15944
+ // n_head
15945
+ for (int h = ih0; h < ih1; ++h) {
15946
+ // ref: https://github.com/state-spaces/mamba/blob/62db608da60f6fc790b8ed9f4b3225e95ca15fde/mamba_ssm/ops/triton/softplus.py#L16
15947
+ const float dt_soft_plus = dt[h] <= 20.0f ? log1pf(expf(dt[h])) : dt[h];
15948
+
15949
+ // dim
15950
+ for (int i1 = 0; i1 < nr; ++i1) {
15951
+ const int i = i1 + h*nr;
15952
+ const float x_dt = x[i] * dt_soft_plus;
15953
+ float sumf = 0.0f;
15954
+ // d_state
15955
+ for (int i0 = 0; i0 < nc; ++i0) {
15956
+ const int ii = i0 + i*nc;
15957
+ const int ig = i0 + (h & (ng - 1))*nc;
15958
+ // state = prev_state * dA + dB * x
15959
+ const float state = (s0[ii] * expf(dt_soft_plus * A[i0 + h*nc])) + (B[ig] * x_dt);
15960
+ // y = rowwise_dotprod(state, C)
15961
+ sumf += state * C[ig];
15962
+ s[ii] = state;
15963
+ }
15964
+ y[i] = sumf + x[i] * D[h];
15965
+ }
15908
15966
}
15909
- y[i1] = sumf;
15910
15967
}
15911
15968
}
15912
15969
}
0 commit comments