@@ -10207,7 +10207,37 @@ static void ggml_compute_forward_mul_f32(
10207
10207
GGML_ASSERT( nb0 == sizeof(float));
10208
10208
GGML_ASSERT(nb00 == sizeof(float));
10209
10209
10210
- if (nb10 == sizeof(float)) {
10210
+ if (ne00 > 1 && ne10 == 1) {
10211
+ // fast broadcast path
10212
+ for (int64_t ir = ith; ir < nr; ir += nth) {
10213
+ // src0 and dst are same shape => same indices
10214
+ const int64_t i03 = ir/(ne02*ne01);
10215
+ const int64_t i02 = (ir - i03*ne02*ne01)/ne01;
10216
+ const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01);
10217
+
10218
+ const int64_t i13 = i03 % ne13;
10219
+ const int64_t i12 = i02 % ne12;
10220
+ const int64_t i11 = i01 % ne11;
10221
+
10222
+ float * src1_ptr = (float *) ((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11);
10223
+
10224
+ const float scale = src1_ptr[0];
10225
+
10226
+ if (scale == 0.0f) {
10227
+ // NOTE: this also sets NANs to zero, which is not compliant with IEEE754,
10228
+ // but it is useful when resetting the state of recurrent models.
10229
+ memset((char *)dst->data + ir*nb1, 0, nb1);
10230
+ } else {
10231
+ if (dst->data != src0->data) {
10232
+ // src0 is same shape as dst => same indices
10233
+ memcpy((char *)dst->data + ir*nb1, (char *)src0->data + ir*nb01, ne0 * sizeof(float));
10234
+ }
10235
+ if (scale != 1.0f) {
10236
+ ggml_vec_scale_f32(ne0, (float *) ((char *) dst->data + ir*nb1), scale);
10237
+ }
10238
+ }
10239
+ }
10240
+ } else if (nb10 == sizeof(float)) {
10211
10241
for (int64_t ir = ith; ir < nr; ir += nth) {
10212
10242
// src0 and dst are same shape => same indices
10213
10243
const int64_t i03 = ir/(ne02*ne01);
@@ -15919,23 +15949,56 @@ static void ggml_compute_forward_ssm_scan_f32(
15919
15949
const float dt_soft_plus = dt[h] <= 20.0f ? log1pf(expf(dt[h])) : dt[h];
15920
15950
const float dA = expf(dt_soft_plus * A[h]);
15921
15951
15922
- // TODO: SIMD implementation
15923
15952
// dim
15924
15953
for (int i1 = 0; i1 < nr; ++i1) {
15925
- const int i = i1 + h*nr;
15926
- const float x_dt = x[i ] * dt_soft_plus;
15954
+ const int ii = i1 + h*nr;
15955
+ const float x_dt = x[ii ] * dt_soft_plus;
15927
15956
float sumf = 0.0f;
15957
+ #if defined(GGML_SIMD)
15958
+ const int np = (nc & ~(GGML_F32_STEP - 1));
15959
+
15960
+ GGML_F32_VEC sum[GGML_F32_ARR] = { GGML_F32_VEC_ZERO };
15961
+
15962
+ GGML_F32_VEC adA = GGML_F32_VEC_SET1(dA);
15963
+ GGML_F32_VEC axdt = GGML_F32_VEC_SET1(x_dt);
15964
+
15965
+ GGML_F32_VEC ax[GGML_F32_ARR];
15966
+ GGML_F32_VEC ay[GGML_F32_ARR];
15967
+ GGML_F32_VEC az[GGML_F32_ARR];
15968
+
15969
+ for (int i = 0; i < np; i += GGML_F32_STEP) {
15970
+ for (int j = 0; j < GGML_F32_ARR; j++) {
15971
+ ax[j] = GGML_F32_VEC_LOAD(s0 + i + j*GGML_F32_EPR + ii*nc);
15972
+ ay[j] = GGML_F32_VEC_LOAD(B + i + j*GGML_F32_EPR + (h & (ng - 1))*nc);
15973
+ az[j] = GGML_F32_VEC_LOAD(C + i + j*GGML_F32_EPR + (h & (ng - 1))*nc);
15974
+
15975
+ ax[j] = GGML_F32_VEC_MUL(ax[j], adA);
15976
+ ay[j] = GGML_F32_VEC_MUL(ay[j], axdt);
15977
+
15978
+ ax[j] = GGML_F32_VEC_ADD(ax[j], ay[j]);
15979
+
15980
+ sum[j] = GGML_F32_VEC_FMA(sum[j], ax[j], az[j]);
15981
+
15982
+ GGML_F32_VEC_STORE(s + i + j*GGML_F32_EPR + ii*nc, ax[j]);
15983
+ }
15984
+ }
15985
+
15986
+ // reduce the GGML_F32_ARR partial accumulators in sum[] into the scalar sumf
15987
+ GGML_F32_VEC_REDUCE(sumf, sum);
15988
+ #else
15989
+ const int np = 0;
15990
+ #endif
15928
15991
// d_state
15929
- for (int i0 = 0 ; i0 < nc; ++i0) {
15930
- const int ii = i0 + i *nc;
15992
+ for (int i0 = np ; i0 < nc; ++i0) {
15993
+ const int i = i0 + ii *nc;
15931
15994
const int ig = i0 + (h & (ng - 1))*nc;
15932
15995
// state = prev_state * dA + dB * x
15933
- const float state = (s0[ii ] * dA) + (B[ig] * x_dt);
15996
+ const float state = (s0[i ] * dA) + (B[ig] * x_dt);
15934
15997
// y = rowwise_dotprod(state, C)
15935
15998
sumf += state * C[ig];
15936
- s[ii ] = state;
15999
+ s[i ] = state;
15937
16000
}
15938
- y[i ] = sumf + x[i ] * D[h];
16001
+ y[ii ] = sumf + x[ii ] * D[h];
15939
16002
}
15940
16003
}
15941
16004
} else {
@@ -15948,20 +16011,22 @@ static void ggml_compute_forward_ssm_scan_f32(
15948
16011
15949
16012
// dim
15950
16013
for (int i1 = 0; i1 < nr; ++i1) {
15951
- const int i = i1 + h*nr;
15952
- const float x_dt = x[i ] * dt_soft_plus;
16014
+ const int ii = i1 + h*nr;
16015
+ const float x_dt = x[ii ] * dt_soft_plus;
15953
16016
float sumf = 0.0f;
16017
+ // NOTE: can't really use GGML_SIMD here because d_state is usually 16
16018
+ // and also because expf is used within the loop.
15954
16019
// d_state
15955
16020
for (int i0 = 0; i0 < nc; ++i0) {
15956
- const int ii = i0 + i *nc;
16021
+ const int i = i0 + ii *nc;
15957
16022
const int ig = i0 + (h & (ng - 1))*nc;
15958
16023
// state = prev_state * dA + dB * x
15959
- const float state = (s0[ii ] * expf(dt_soft_plus * A[i0 + h*nc])) + (B[ig] * x_dt);
16024
+ const float state = (s0[i ] * expf(dt_soft_plus * A[i0 + h*nc])) + (B[ig] * x_dt);
15960
16025
// y = rowwise_dotprod(state, C)
15961
16026
sumf += state * C[ig];
15962
- s[ii ] = state;
16027
+ s[i ] = state;
15963
16028
}
15964
- y[i ] = sumf + x[i ] * D[h];
16029
+ y[ii ] = sumf + x[ii ] * D[h];
15965
16030
}
15966
16031
}
15967
16032
}
0 commit comments