Skip to content

Commit 3e916a0

Browse files
authored
finetune : speed-up ggml_compute_forward_out_prod_f32 via BLAS (#4079)
* Remove logically superfluous assertions and order by dimension * Use cblas_sgemm() to implement ggml_compute_forward_out_prod() * Remove ggml_compute_forward_out_prod_use_blas(), fix compiling errors on cmake/zig, remove trailing whitespace * Add openBLAS support for sgemm() in compute_forward_out_prod()
1 parent 947f64f commit 3e916a0

File tree

1 file changed

+61
-8
lines changed

1 file changed

+61
-8
lines changed

ggml.c

Lines changed: 61 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -9611,10 +9611,12 @@ static void ggml_compute_forward_out_prod_f32(
96119611
const int ith = params->ith;
96129612
const int nth = params->nth;
96139613

9614+
GGML_ASSERT(ne0 == ne00);
9615+
GGML_ASSERT(ne1 == ne10);
9616+
GGML_ASSERT(ne2 == ne02);
96149617
GGML_ASSERT(ne02 == ne12);
9615-
GGML_ASSERT(ne03 == ne13);
9616-
GGML_ASSERT(ne2 == ne12);
96179618
GGML_ASSERT(ne3 == ne13);
9619+
GGML_ASSERT(ne03 == ne13);
96189620

96199621
// we don't support permuted src0 or src1
96209622
GGML_ASSERT(nb00 == sizeof(float));
@@ -9625,18 +9627,25 @@ static void ggml_compute_forward_out_prod_f32(
96259627
// GGML_ASSERT(nb1 <= nb2);
96269628
// GGML_ASSERT(nb2 <= nb3);
96279629

9628-
GGML_ASSERT(ne0 == ne00);
9629-
GGML_ASSERT(ne1 == ne10);
9630-
GGML_ASSERT(ne2 == ne02);
9631-
GGML_ASSERT(ne3 == ne03);
9632-
96339630
// nb01 >= nb00 - src0 is not transposed
96349631
// compute by src0 rows
96359632

96369633
// TODO: #if defined(GGML_USE_CUBLAS) ggml_cuda_out_prod
9637-
// TODO: #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CLBLAST)
9634+
// TODO: #if defined(GGML_USE_CLBLAST)
9635+
9636+
#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
9637+
bool use_blas = ggml_is_matrix(src0) &&
9638+
ggml_is_matrix(src1) &&
9639+
ggml_is_contiguous(src0) &&
9640+
(ggml_is_contiguous(src1) || ggml_is_transposed(src1));
9641+
#endif
96389642

96399643
if (params->type == GGML_TASK_INIT) {
9644+
#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) // gemm beta will zero dst
9645+
if (use_blas) {
9646+
return;
9647+
}
9648+
#endif
96409649
ggml_vec_set_f32(ne0*ne1*ne2*ne3, dst->data, 0);
96419650
return;
96429651
}
@@ -9645,6 +9654,50 @@ static void ggml_compute_forward_out_prod_f32(
96459654
return;
96469655
}
96479656

9657+
#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
9658+
if (use_blas) {
9659+
if (params->ith != 0) { // All threads other than the first do no work.
9660+
return;
9661+
}
9662+
// Arguments to ggml_compute_forward_out_prod (expressed as major,minor)
9663+
// src0: (k,n)
9664+
// src1: (k,m)
9665+
// dst: (m,n)
9666+
//
9667+
// Arguments to sgemm (see https://github.com/Reference-LAPACK/lapack/blob/master/BLAS/SRC/sgemm.f)
9668+
// Also expressed as (major,minor)
9669+
// a: (m,k): so src1 transposed
9670+
// b: (k,n): so src0
9671+
// c: (m,n)
9672+
//
9673+
// However, if ggml_is_transposed(src1) is true, then
9674+
// src1->data already contains a transposed version, so sgemm mustn't
9675+
// transpose it further.
9676+
9677+
int n = src0->ne[0];
9678+
int k = src0->ne[1];
9679+
int m = src1->ne[0];
9680+
9681+
int transposeA, lda;
9682+
9683+
if (!ggml_is_transposed(src1)) {
9684+
transposeA = CblasTrans;
9685+
lda = m;
9686+
} else {
9687+
transposeA = CblasNoTrans;
9688+
lda = k;
9689+
}
9690+
9691+
float * a = (float *) ((char *) src1->data);
9692+
float * b = (float *) ((char *) src0->data);
9693+
float * c = (float *) ((char *) dst->data);
9694+
9695+
cblas_sgemm(CblasRowMajor, transposeA, CblasNoTrans, m, n, k, 1.0, a, lda, b, n, 0.0, c, n);
9696+
9697+
return;
9698+
}
9699+
#endif
9700+
96489701
// dst[:,:,:,:] = 0
96499702
// for i2,i3:
96509703
// for i1:

0 commit comments

Comments
 (0)