@@ -9611,10 +9611,12 @@ static void ggml_compute_forward_out_prod_f32(
9611
9611
const int ith = params->ith;
9612
9612
const int nth = params->nth;
9613
9613
9614
+ GGML_ASSERT(ne0 == ne00);
9615
+ GGML_ASSERT(ne1 == ne10);
9616
+ GGML_ASSERT(ne2 == ne02);
9614
9617
GGML_ASSERT(ne02 == ne12);
9615
- GGML_ASSERT(ne03 == ne13);
9616
- GGML_ASSERT(ne2 == ne12);
9617
9618
GGML_ASSERT(ne3 == ne13);
9619
+ GGML_ASSERT(ne03 == ne13);
9618
9620
9619
9621
// we don't support permuted src0 or src1
9620
9622
GGML_ASSERT(nb00 == sizeof(float));
@@ -9625,18 +9627,25 @@ static void ggml_compute_forward_out_prod_f32(
9625
9627
// GGML_ASSERT(nb1 <= nb2);
9626
9628
// GGML_ASSERT(nb2 <= nb3);
9627
9629
9628
- GGML_ASSERT(ne0 == ne00);
9629
- GGML_ASSERT(ne1 == ne10);
9630
- GGML_ASSERT(ne2 == ne02);
9631
- GGML_ASSERT(ne3 == ne03);
9632
-
9633
9630
// nb01 >= nb00 - src0 is not transposed
9634
9631
// compute by src0 rows
9635
9632
9636
9633
// TODO: #if defined(GGML_USE_CUBLAS) ggml_cuda_out_prod
9637
- // TODO: #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CLBLAST)
9634
+ // TODO: #if defined(GGML_USE_CLBLAST)
9635
+
9636
+ #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
9637
+ bool use_blas = ggml_is_matrix(src0) &&
9638
+ ggml_is_matrix(src1) &&
9639
+ ggml_is_contiguous(src0) &&
9640
+ (ggml_is_contiguous(src1) || ggml_is_transposed(src1));
9641
+ #endif
9638
9642
9639
9643
if (params->type == GGML_TASK_INIT) {
9644
+ #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) // gemm beta will zero dst
9645
+ if (use_blas) {
9646
+ return;
9647
+ }
9648
+ #endif
9640
9649
ggml_vec_set_f32(ne0*ne1*ne2*ne3, dst->data, 0);
9641
9650
return;
9642
9651
}
@@ -9645,6 +9654,50 @@ static void ggml_compute_forward_out_prod_f32(
9645
9654
return;
9646
9655
}
9647
9656
9657
+ #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
9658
+ if (use_blas) {
9659
+ if (params->ith != 0) { // All threads other than the first do no work.
9660
+ return;
9661
+ }
9662
+ // Arguments to ggml_compute_forward_out_prod, with shapes expressed as (major,minor) = (ne[1],ne[0])
9663
+ // src0: (k,n)
9664
+ // src1: (k,m)
9665
+ // dst: (m,n)
9666
+ //
9667
+ // Arguments to sgemm (see https://github.com/Reference-LAPACK/lapack/blob/master/BLAS/SRC/sgemm.f)
9668
+ // Also expressed as (major,minor)
9669
+ // a: (m,k): so src1 transposed
9670
+ // b: (k,n): so src0
9671
+ // c: (m,n)
9672
+ //
9673
+ // However, if ggml_is_transposed(src1) is true, then
9674
+ // src1->data already contains a transposed version, so sgemm mustn't
9675
+ // transpose it further.
9676
+
9677
+ int n = src0->ne[0];
9678
+ int k = src0->ne[1];
9679
+ int m = src1->ne[0];
9680
+
9681
+ int transposeA, lda;
9682
+
9683
+ if (!ggml_is_transposed(src1)) {
9684
+ transposeA = CblasTrans;
9685
+ lda = m;
9686
+ } else {
9687
+ transposeA = CblasNoTrans;
9688
+ lda = k;
9689
+ }
9690
+
9691
+ float * a = (float *) ((char *) src1->data);
9692
+ float * b = (float *) ((char *) src0->data);
9693
+ float * c = (float *) ((char *) dst->data);
9694
+
9695
+ cblas_sgemm(CblasRowMajor, transposeA, CblasNoTrans, m, n, k, 1.0, a, lda, b, n, 0.0, c, n);
9696
+
9697
+ return;
9698
+ }
9699
+ #endif
9700
+
9648
9701
// dst[:,:,:,:] = 0
9649
9702
// for i2,i3:
9650
9703
// for i1:
0 commit comments