Skip to content

Commit 40c0685

Browse files
committed
ggml : SIMD ggml_ssm_scan for Mamba-2
* ggml : improve ggml_mul speed when masking recurrent states
1 parent a79e4e8 commit 40c0685

File tree

1 file changed

+80
-15
lines changed

1 file changed

+80
-15
lines changed

ggml/src/ggml.c

Lines changed: 80 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -10207,7 +10207,37 @@ static void ggml_compute_forward_mul_f32(
1020710207
GGML_ASSERT( nb0 == sizeof(float));
1020810208
GGML_ASSERT(nb00 == sizeof(float));
1020910209

10210-
if (nb10 == sizeof(float)) {
10210+
if (ne00 > 1 && ne10 == 1) {
10211+
// fast broadcast path
10212+
for (int64_t ir = ith; ir < nr; ir += nth) {
10213+
// src0 and dst are same shape => same indices
10214+
const int64_t i03 = ir/(ne02*ne01);
10215+
const int64_t i02 = (ir - i03*ne02*ne01)/ne01;
10216+
const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01);
10217+
10218+
const int64_t i13 = i03 % ne13;
10219+
const int64_t i12 = i02 % ne12;
10220+
const int64_t i11 = i01 % ne11;
10221+
10222+
float * src1_ptr = (float *) ((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11);
10223+
10224+
const float scale = src1_ptr[0];
10225+
10226+
if (scale == 0.0f) {
10227+
// NOTE: this also sets NANs to zero, which is not compliant with IEEE754,
10228+
// but it is useful when resetting the state of recurrent models.
10229+
memset((char *)dst->data + ir*nb1, 0, nb1);
10230+
} else {
10231+
if (dst->data != src0->data) {
10232+
// src0 is same shape as dst => same indices
10233+
memcpy((char *)dst->data + ir*nb1, (char *)src0->data + ir*nb01, ne0 * sizeof(float));
10234+
}
10235+
if (scale != 1.0f) {
10236+
ggml_vec_scale_f32(ne0, (float *) ((char *) dst->data + ir*nb1), scale);
10237+
}
10238+
}
10239+
}
10240+
} else if (nb10 == sizeof(float)) {
1021110241
for (int64_t ir = ith; ir < nr; ir += nth) {
1021210242
// src0 and dst are same shape => same indices
1021310243
const int64_t i03 = ir/(ne02*ne01);
@@ -15919,23 +15949,56 @@ static void ggml_compute_forward_ssm_scan_f32(
1591915949
const float dt_soft_plus = dt[h] <= 20.0f ? log1pf(expf(dt[h])) : dt[h];
1592015950
const float dA = expf(dt_soft_plus * A[h]);
1592115951

15922-
// TODO: SIMD implementation
1592315952
// dim
1592415953
for (int i1 = 0; i1 < nr; ++i1) {
15925-
const int i = i1 + h*nr;
15926-
const float x_dt = x[i] * dt_soft_plus;
15954+
const int ii = i1 + h*nr;
15955+
const float x_dt = x[ii] * dt_soft_plus;
1592715956
float sumf = 0.0f;
15957+
#if defined(GGML_SIMD)
15958+
const int np = (nc & ~(GGML_F32_STEP - 1));
15959+
15960+
GGML_F32_VEC sum[GGML_F32_ARR] = { GGML_F32_VEC_ZERO };
15961+
15962+
GGML_F32_VEC adA = GGML_F32_VEC_SET1(dA);
15963+
GGML_F32_VEC axdt = GGML_F32_VEC_SET1(x_dt);
15964+
15965+
GGML_F32_VEC ax[GGML_F32_ARR];
15966+
GGML_F32_VEC ay[GGML_F32_ARR];
15967+
GGML_F32_VEC az[GGML_F32_ARR];
15968+
15969+
for (int i = 0; i < np; i += GGML_F32_STEP) {
15970+
for (int j = 0; j < GGML_F32_ARR; j++) {
15971+
ax[j] = GGML_F32_VEC_LOAD(s0 + i + j*GGML_F32_EPR + ii*nc);
15972+
ay[j] = GGML_F32_VEC_LOAD(B + i + j*GGML_F32_EPR + (h & (ng - 1))*nc);
15973+
az[j] = GGML_F32_VEC_LOAD(C + i + j*GGML_F32_EPR + (h & (ng - 1))*nc);
15974+
15975+
ax[j] = GGML_F32_VEC_MUL(ax[j], adA);
15976+
ay[j] = GGML_F32_VEC_MUL(ay[j], axdt);
15977+
15978+
ax[j] = GGML_F32_VEC_ADD(ax[j], ay[j]);
15979+
15980+
sum[j] = GGML_F32_VEC_FMA(sum[j], ax[j], az[j]);
15981+
15982+
GGML_F32_VEC_STORE(s + i + j*GGML_F32_EPR + ii*nc, ax[j]);
15983+
}
15984+
}
15985+
15986+
// reduce sum0..sum3 to sum0
15987+
GGML_F32_VEC_REDUCE(sumf, sum);
15988+
#else
15989+
const int np = 0;
15990+
#endif
1592815991
// d_state
15929-
for (int i0 = 0; i0 < nc; ++i0) {
15930-
const int ii = i0 + i*nc;
15992+
for (int i0 = np; i0 < nc; ++i0) {
15993+
const int i = i0 + ii*nc;
1593115994
const int ig = i0 + (h & (ng - 1))*nc;
1593215995
// state = prev_state * dA + dB * x
15933-
const float state = (s0[ii] * dA) + (B[ig] * x_dt);
15996+
const float state = (s0[i] * dA) + (B[ig] * x_dt);
1593415997
// y = rowwise_dotprod(state, C)
1593515998
sumf += state * C[ig];
15936-
s[ii] = state;
15999+
s[i] = state;
1593716000
}
15938-
y[i] = sumf + x[i] * D[h];
16001+
y[ii] = sumf + x[ii] * D[h];
1593916002
}
1594016003
}
1594116004
} else {
@@ -15948,20 +16011,22 @@ static void ggml_compute_forward_ssm_scan_f32(
1594816011

1594916012
// dim
1595016013
for (int i1 = 0; i1 < nr; ++i1) {
15951-
const int i = i1 + h*nr;
15952-
const float x_dt = x[i] * dt_soft_plus;
16014+
const int ii = i1 + h*nr;
16015+
const float x_dt = x[ii] * dt_soft_plus;
1595316016
float sumf = 0.0f;
16017+
// NOTE: can't really use GGML_SIMD here because d_state is usually 16
16018+
// and also because expf is used within the loop.
1595416019
// d_state
1595516020
for (int i0 = 0; i0 < nc; ++i0) {
15956-
const int ii = i0 + i*nc;
16021+
const int i = i0 + ii*nc;
1595716022
const int ig = i0 + (h & (ng - 1))*nc;
1595816023
// state = prev_state * dA + dB * x
15959-
const float state = (s0[ii] * expf(dt_soft_plus * A[i0 + h*nc])) + (B[ig] * x_dt);
16024+
const float state = (s0[i] * expf(dt_soft_plus * A[i0 + h*nc])) + (B[ig] * x_dt);
1596016025
// y = rowwise_dotprod(state, C)
1596116026
sumf += state * C[ig];
15962-
s[ii] = state;
16027+
s[i] = state;
1596316028
}
15964-
y[i] = sumf + x[i] * D[h];
16029+
y[ii] = sumf + x[ii] * D[h];
1596516030
}
1596616031
}
1596716032
}

0 commit comments

Comments
 (0)