Skip to content

Commit 3a04d8a

Browse files
committed
fix: Fix ssm_scan op to match mamba2 branch impl
@compilade mentioned that this op needs to be aligned with the CUDA optimizations done in ggml-org#10558, which I have not tackled here, so that will likely be follow-on work. Branch: BambaAbstractMemory Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>
1 parent b62b1bb commit 3a04d8a

File tree

1 file changed

+130
-54
lines changed

1 file changed

+130
-54
lines changed

ggml/src/ggml-cpu/ops.cpp

Lines changed: 130 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -7502,78 +7502,154 @@ void ggml_compute_forward_ssm_conv(
75027502
}
75037503

75047504
// ggml_compute_forward_ssm_scan
7505-
75067505
static void ggml_compute_forward_ssm_scan_f32(
7507-
const ggml_compute_params * params,
7508-
ggml_tensor * dst) {
7509-
const ggml_tensor * src0 = dst->src[0]; // s
7510-
const ggml_tensor * src1 = dst->src[1]; // x
7511-
const ggml_tensor * src2 = dst->src[2]; // dt
7512-
const ggml_tensor * src3 = dst->src[3]; // A
7513-
const ggml_tensor * src4 = dst->src[4]; // B
7514-
const ggml_tensor * src5 = dst->src[5]; // C
7506+
const struct ggml_compute_params * params,
7507+
struct ggml_tensor * dst) {
7508+
const struct ggml_tensor * src0 = dst->src[0]; // s {d_state, dim, n_head, n_seqs+}
7509+
const struct ggml_tensor * src1 = dst->src[1]; // x {dim, n_head, n_seq_tokens, n_seqs}
7510+
const struct ggml_tensor * src2 = dst->src[2]; // dt {n_head, n_seq_tokens, n_seqs}
7511+
const struct ggml_tensor * src3 = dst->src[3]; // A {d_state, n_head} or {1, n_head}
7512+
const struct ggml_tensor * src4 = dst->src[4]; // B {d_state, n_group, n_seq_tokens, n_seqs}
7513+
const struct ggml_tensor * src5 = dst->src[5]; // C {d_state, n_group, n_seq_tokens, n_seqs}
7514+
const struct ggml_tensor * src6 = dst->src[6]; // ids {n_seqs}
75157515

75167516
const int ith = params->ith;
75177517
const int nth = params->nth;
75187518

7519-
const int64_t nc = src0->ne[0]; // d_state
7520-
const int64_t nr = src0->ne[1]; // d_inner
7521-
const int64_t n_t = src1->ne[1]; // number of tokens per sequence
7522-
const int64_t n_s = src0->ne[2]; // number of sequences in the batch
7519+
const int64_t nc = src0->ne[0]; // d_state
7520+
const int64_t nr = src0->ne[1]; // dim
7521+
const int64_t nh = src1->ne[1]; // n_head
7522+
const int64_t ng = src4->ne[1];
7523+
const int64_t nt = src1->ne[2]; // number of tokens per sequence
7524+
const int64_t ns = src1->ne[3]; // number of sequences in the batch
7525+
7526+
// can't use ggml_nbytes because src1 is not necessarily contiguous
7527+
const int64_t s_off = ggml_nelements(src1) * ggml_element_size(src1);
75237528

7524-
GGML_ASSERT(ggml_nelements(src1) + ggml_nelements(src0) == ggml_nelements(dst));
7529+
GGML_ASSERT(ggml_nelements(src1) + nc*nr*nh*ns == ggml_nelements(dst));
75257530
GGML_ASSERT(src0->nb[0] == sizeof(float));
75267531
GGML_ASSERT(src1->nb[0] == sizeof(float));
75277532
GGML_ASSERT(src2->nb[0] == sizeof(float));
75287533
GGML_ASSERT(src3->nb[0] == sizeof(float));
75297534
GGML_ASSERT(src4->nb[0] == sizeof(float));
75307535
GGML_ASSERT(src5->nb[0] == sizeof(float));
7531-
// required for the dot product between s and C
7532-
GGML_ASSERT(src0->nb[1] == src0->ne[0]*sizeof(float));
7533-
// required for per-sequence offsets for states
7534-
GGML_ASSERT(src0->nb[2] == src0->ne[0]*src0->ne[1]*sizeof(float));
7535-
// required to get correct offset for state destination (i.e. src1->nb[3])
7536-
GGML_ASSERT(src1->nb[3] == src1->ne[0]*src1->ne[1]*src1->ne[2]*sizeof(float));
7536+
GGML_ASSERT(src6->nb[0] == sizeof(int32_t));
7537+
// n_group is asserted to be a power of 2 so that the head-to-group modulo (h % ng) can be computed as h & (ng - 1)
7538+
GGML_ASSERT((ng & -ng) == ng);
7539+
7540+
// heads per thread
7541+
const int dh = (nh + nth - 1)/nth;
7542+
7543+
// head range for this thread
7544+
const int ih0 = dh*ith;
7545+
const int ih1 = MIN(ih0 + dh, nh);
7546+
7547+
const int32_t * ids = (const int32_t *) src6->data;
7548+
7549+
for (int i3 = 0; i3 < ns; ++i3) {
7550+
const float * s0 = (const float *) ((const char *) src0->data + ids[i3]*(src0->nb[3])); // {d_state, dim, nh, ns}
7551+
float * s = ( float *) (( char *) dst->data + i3*(src0->nb[3]) + s_off); // {d_state, dim, nh, ns}
7552+
7553+
for (int i2 = 0; i2 < nt; ++i2) {
7554+
const float * x = (const float *) ((const char *) src1->data + i2*(src1->nb[2]) + i3*(src1->nb[3])); // {dim, nh, nt, ns}
7555+
const float * dt = (const float *) ((const char *) src2->data + i2*(src2->nb[1]) + i3*(src2->nb[2])); // {nh, nt, ns}
7556+
const float * A = (const float *) ((const char *) src3->data); // {d_state, nh} or {1, nh}
7557+
const float * B = (const float *) ((const char *) src4->data + i2*(src4->nb[2]) + i3*(src4->nb[3])); // {d_state, ng, nt, ns}
7558+
const float * C = (const float *) ((const char *) src5->data + i2*(src5->nb[2]) + i3*(src5->nb[3])); // {d_state, ng, nt, ns}
7559+
float * y = ( float *) (( char *) dst->data + i2*(nh*nr*sizeof(float)) + i3*(nt*nh*nr*sizeof(float))); // {dim, nh, nt, ns}
7560+
7561+
if (src3->ne[0] == 1) {
7562+
// Mamba-2 has a scalar decay factor per head; dA can be outside the state-wise loop
7563+
7564+
// n_head
7565+
for (int h = ih0; h < ih1; ++h) {
7566+
// ref: https://github.com/state-spaces/mamba/blob/62db608da60f6fc790b8ed9f4b3225e95ca15fde/mamba_ssm/ops/triton/softplus.py#L16
7567+
const float dt_soft_plus = dt[h] <= 20.0f ? log1pf(expf(dt[h])) : dt[h];
7568+
const float dA = expf(dt_soft_plus * A[h]);
7569+
7570+
// dim
7571+
for (int i1 = 0; i1 < nr; ++i1) {
7572+
const int ii = i1 + h*nr;
7573+
const float x_dt = x[ii] * dt_soft_plus;
7574+
float sumf = 0.0f;
7575+
#if defined(GGML_SIMD)
7576+
const int np = (nc & ~(GGML_F32_STEP - 1));
75377577

7538-
// rows per thread
7539-
const int dr = (nr + nth - 1)/nth;
7578+
GGML_F32_VEC sum[GGML_F32_ARR] = { GGML_F32_VEC_ZERO };
75407579

7541-
// row range for this thread
7542-
const int ir0 = dr*ith;
7543-
const int ir1 = MIN(ir0 + dr, nr);
7544-
const int ir = ir1 - ir0;
7580+
GGML_F32_VEC adA = GGML_F32_VEC_SET1(dA);
7581+
GGML_F32_VEC axdt = GGML_F32_VEC_SET1(x_dt);
75457582

7546-
for (int i3 = 0; i3 < n_s; ++i3) {
7547-
for (int i2 = 0; i2 < n_t; ++i2) {
7548-
const float * s0 = (const float *) ((const char *) src0->data + ir0*(src0->nb[1]) + i3*(src0->nb[2])); // {d_state, d_inner, n_s}
7549-
const float * x = (const float *) ((const char *) src1->data + ir0*(src1->nb[0]) + i2*(src1->nb[1]) + i3*(src1->nb[2])); // {d_inner, n_t, n_s}
7550-
const float * dt = (const float *) ((const char *) src2->data + ir0*(src2->nb[0]) + i2*(src2->nb[1]) + i3*(src2->nb[2])); // {d_inner, n_t, n_s}
7551-
const float * A = (const float *) ((const char *) src3->data + ir0*(src3->nb[1])); // {d_state, d_inner}
7552-
const float * B = (const float *) ((const char *) src4->data + i2*(src4->nb[1]) + i3*(src4->nb[2])); // {d_state, n_t, n_s}
7553-
const float * C = (const float *) ((const char *) src5->data + i2*(src5->nb[1]) + i3*(src5->nb[2])); // {d_state, n_t, n_s}
7554-
float * y = ( float *) (( char *) dst->data + ir0*(src1->nb[0]) + i2*(src1->nb[1]) + i3*(src1->nb[2])); // {d_inner, n_t, n_s}
7555-
float * s = ( float *) (( char *) dst->data + ir0*(src0->nb[1]) + i3*(src0->nb[2]) + src1->nb[3]); // {d_state, d_inner, n_s}
7556-
7557-
// use the output as the source for the next token-wise iterations
7558-
if (i2 > 0) { s0 = s; }
7583+
GGML_F32_VEC ax[GGML_F32_ARR];
7584+
GGML_F32_VEC ay[GGML_F32_ARR];
7585+
GGML_F32_VEC az[GGML_F32_ARR];
75597586

7560-
// d_inner
7561-
for (int i1 = 0; i1 < ir; ++i1) {
7562-
// ref: https://github.com/state-spaces/mamba/blob/34076d664838588a3c97727b263478ab9f621a07/mamba_ssm/ops/triton/selective_state_update.py#L78
7563-
float dt_soft_plus = dt[i1] <= 20.0f ? log1pf(expf(dt[i1])) : dt[i1];
7564-
float x_dt = x[i1] * dt_soft_plus;
7565-
float sumf = 0.0f;
7566-
// d_state
7567-
for (int i0 = 0; i0 < nc; ++i0) {
7568-
int i = i0 + i1*nc;
7569-
// state = prev_state * dA + dB * x
7570-
float state = (s0[i] * expf(dt_soft_plus * A[i])) + (B[i0] * x_dt);
7571-
// y = rowwise_dotprod(state, C)
7572-
sumf += state * C[i0];
7573-
s[i] = state;
7587+
for (int i = 0; i < np; i += GGML_F32_STEP) {
7588+
for (int j = 0; j < GGML_F32_ARR; j++) {
7589+
ax[j] = GGML_F32_VEC_LOAD(s0 + i + j*GGML_F32_EPR + ii*nc);
7590+
ay[j] = GGML_F32_VEC_LOAD(B + i + j*GGML_F32_EPR + (h & (ng - 1))*nc);
7591+
az[j] = GGML_F32_VEC_LOAD(C + i + j*GGML_F32_EPR + (h & (ng - 1))*nc);
7592+
7593+
ax[j] = GGML_F32_VEC_MUL(ax[j], adA);
7594+
ay[j] = GGML_F32_VEC_MUL(ay[j], axdt);
7595+
7596+
ax[j] = GGML_F32_VEC_ADD(ax[j], ay[j]);
7597+
7598+
sum[j] = GGML_F32_VEC_FMA(sum[j], ax[j], az[j]);
7599+
7600+
GGML_F32_VEC_STORE(s + i + j*GGML_F32_EPR + ii*nc, ax[j]);
7601+
}
7602+
}
7603+
7604+
// reduce the GGML_F32_ARR partial vector sums in sum[] into the scalar sumf
7605+
GGML_F32_VEC_REDUCE(sumf, sum);
7606+
#else
7607+
const int np = 0;
7608+
#endif
7609+
// d_state
7610+
for (int i0 = np; i0 < nc; ++i0) {
7611+
const int i = i0 + ii*nc;
7612+
const int ig = i0 + (h & (ng - 1))*nc;
7613+
// state = prev_state * dA + dB * x
7614+
const float state = (s0[i] * dA) + (B[ig] * x_dt);
7615+
// y = rowwise_dotprod(state, C)
7616+
sumf += state * C[ig];
7617+
s[i] = state;
7618+
}
7619+
y[ii] = sumf;
7620+
}
7621+
}
7622+
} else {
7623+
// Mamba-1 has an element-wise decay factor for the states
7624+
7625+
// n_head
7626+
for (int h = ih0; h < ih1; ++h) {
7627+
// ref: https://github.com/state-spaces/mamba/blob/62db608da60f6fc790b8ed9f4b3225e95ca15fde/mamba_ssm/ops/triton/softplus.py#L16
7628+
const float dt_soft_plus = dt[h] <= 20.0f ? log1pf(expf(dt[h])) : dt[h];
7629+
7630+
// dim
7631+
for (int i1 = 0; i1 < nr; ++i1) {
7632+
const int ii = i1 + h*nr;
7633+
const float x_dt = x[ii] * dt_soft_plus;
7634+
float sumf = 0.0f;
7635+
// NOTE: can't really use GGML_SIMD here because d_state is usually 16
7636+
// and also because expf is used within the loop.
7637+
// d_state
7638+
for (int i0 = 0; i0 < nc; ++i0) {
7639+
const int i = i0 + ii*nc;
7640+
const int ig = i0 + (h & (ng - 1))*nc;
7641+
// state = prev_state * dA + dB * x
7642+
const float state = (s0[i] * expf(dt_soft_plus * A[i0 + h*nc])) + (B[ig] * x_dt);
7643+
// y = rowwise_dotprod(state, C)
7644+
sumf += state * C[ig];
7645+
s[i] = state;
7646+
}
7647+
y[ii] = sumf;
7648+
}
75747649
}
7575-
y[i1] = sumf;
75767650
}
7651+
// for every token after the first, read the previous state from the state just written to dst
7652+
s0 = s;
75777653
}
75787654
}
75797655
}

0 commit comments

Comments
 (0)