@@ -9580,16 +9580,11 @@ static bool ggml_compute_forward_mul_mat_use_blas(
9580
9580
}
9581
9581
#endif
9582
9582
9583
- // off1 = offset in i11 and i1
9584
- // cne1 = ne11 and ne1
9585
- // in a normal matrix multiplication, off1 = 0 and cne1 = ne1
9586
- // during GGML_TASK_INIT, the full src1 is converted regardless of off1 and cne1
9587
9583
static void ggml_compute_forward_mul_mat(
9588
9584
const struct ggml_compute_params * params,
9589
9585
const struct ggml_tensor * src0,
9590
9586
const struct ggml_tensor * src1,
9591
- struct ggml_tensor * dst,
9592
- int64_t off1, int64_t cne1) {
9587
+ struct ggml_tensor * dst) {
9593
9588
int64_t t0 = ggml_perf_time_us();
9594
9589
UNUSED(t0);
9595
9590
@@ -9657,9 +9652,9 @@ static void ggml_compute_forward_mul_mat(
9657
9652
const int64_t i03 = i13/r3;
9658
9653
const int64_t i02 = i12/r2;
9659
9654
9660
- const void * x = (char *) src0->data + i02*nb02 + i03*nb03;
9661
- const float * y = (float *) ((char *) src1->data + off1*nb11 + i12*nb12 + i13*nb13);
9662
- float * d = (float *) ((char *) dst->data + off1*nb1 + i12*nb2 + i13*nb3);
9655
+ const void * x = (char *) src0->data + i02*nb02 + i03*nb03;
9656
+ const float * y = (float *) ((char *) src1->data + i12*nb12 + i13*nb13);
9657
+ float * d = (float *) ((char *) dst->data + i12*nb2 + i13*nb3);
9663
9658
9664
9659
if (type != GGML_TYPE_F32) {
9665
9660
float * const wdata = params->wdata;
@@ -9676,7 +9671,7 @@ static void ggml_compute_forward_mul_mat(
9676
9671
}
9677
9672
9678
9673
cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans,
9679
- cne1 , ne01, ne10,
9674
+ ne1 , ne01, ne10,
9680
9675
1.0f, y, ne10,
9681
9676
x, ne00,
9682
9677
0.0f, d, ne01);
@@ -9717,8 +9712,8 @@ static void ggml_compute_forward_mul_mat(
9717
9712
const void * wdata = (src1->type == vec_dot_type) ? src1->data : params->wdata;
9718
9713
const size_t row_size = ggml_row_size(vec_dot_type, ne10);
9719
9714
9720
- const int64_t nr0 = ne01; // src0 rows
9721
- const int64_t nr1 = cne1 *ne12*ne13; // src1 rows
9715
+ const int64_t nr0 = ne01; // src0 rows
9716
+ const int64_t nr1 = ne1 *ne12*ne13; // src1 rows
9722
9717
9723
9718
//printf("nr0 = %lld, nr1 = %lld\n", nr0, nr1);
9724
9719
@@ -9760,9 +9755,9 @@ static void ggml_compute_forward_mul_mat(
9760
9755
for (int64_t iir1 = ir110; iir1 < ir111; iir1 += blck_1) {
9761
9756
for (int64_t iir0 = ir010; iir0 < ir011; iir0 += blck_0) {
9762
9757
for (int64_t ir1 = iir1; ir1 < iir1 + blck_1 && ir1 < ir111; ++ir1) {
9763
- const int64_t i13 = (ir1/(ne12*cne1 ));
9764
- const int64_t i12 = (ir1 - i13*ne12*cne1)/cne1 ;
9765
- const int64_t i11 = (ir1 - i13*ne12*cne1 - i12*cne1) + off1 ;
9758
+ const int64_t i13 = (ir1/(ne12*ne1 ));
9759
+ const int64_t i12 = (ir1 - i13*ne12*ne1)/ne1 ;
9760
+ const int64_t i11 = (ir1 - i13*ne12*ne1 - i12*ne1) ;
9766
9761
9767
9762
// broadcast src0 into src1
9768
9763
const int64_t i03 = i13/r3;
@@ -9802,28 +9797,191 @@ static void ggml_compute_forward_mul_mat(
9802
9797
9803
9798
static void ggml_compute_forward_mul_mat_id(
9804
9799
const struct ggml_compute_params * params,
9805
- const struct ggml_tensor * src0 ,
9800
+ const struct ggml_tensor * ids ,
9806
9801
const struct ggml_tensor * src1,
9807
9802
struct ggml_tensor * dst) {
9808
9803
9809
- if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
9810
- // during GGML_TASK_INIT the entire src1 is converted to vec_dot_type
9811
- ggml_compute_forward_mul_mat(params, dst->src[2], src1, dst, 0, dst->ne[1]);
9812
- return;
9813
- }
9804
+ const struct ggml_tensor * src0 = dst->src[2]; // only for GGML_TENSOR_BINARY_OP_LOCALS
9805
+
9806
+ GGML_TENSOR_BINARY_OP_LOCALS
9807
+
9808
+ const int ith = params->ith;
9809
+ const int nth = params->nth;
9810
+
9811
+ const enum ggml_type type = src0->type;
9812
+
9813
+ const bool src1_cont = ggml_is_contiguous(src1);
9814
+
9815
+ ggml_vec_dot_t const vec_dot = type_traits[type].vec_dot;
9816
+ enum ggml_type const vec_dot_type = type_traits[type].vec_dot_type;
9817
+ ggml_from_float_t const from_float_to_vec_dot = type_traits[vec_dot_type].from_float;
9818
+
9819
+ GGML_ASSERT(ne0 == ne01);
9820
+ GGML_ASSERT(ne1 == ne11);
9821
+ GGML_ASSERT(ne2 == ne12);
9822
+ GGML_ASSERT(ne3 == ne13);
9823
+
9824
+ // we don't support permuted src0 or src1
9825
+ GGML_ASSERT(nb00 == ggml_type_size(type));
9826
+ GGML_ASSERT(nb10 == ggml_type_size(src1->type));
9827
+
9828
+ // dst cannot be transposed or permuted
9829
+ GGML_ASSERT(nb0 == sizeof(float));
9830
+ GGML_ASSERT(nb0 <= nb1);
9831
+ GGML_ASSERT(nb1 <= nb2);
9832
+ GGML_ASSERT(nb2 <= nb3);
9814
9833
9815
- const struct ggml_tensor * ids = src0;
9834
+ // broadcast factors
9835
+ const int64_t r2 = ne12/ne02;
9836
+ const int64_t r3 = ne13/ne03;
9837
+
9838
+ // row groups
9816
9839
const int id = ggml_get_op_params_i32(dst, 0);
9817
9840
const int n_as = ggml_get_op_params_i32(dst, 1);
9818
9841
9819
- for (int64_t i01 = 0; i01 < ids->ne[1]; i01++) {
9820
- const int32_t row_id = *(const int32_t *) ((const char *) ids->data + i01*ids->nb[1] + id*ids->nb[0]);
9842
+ char * wdata_src1_end = (src1->type == vec_dot_type) ?
9843
+ (char *) params->wdata :
9844
+ (char *) params->wdata + GGML_PAD(ggml_row_size(vec_dot_type, ggml_nelements(src1)), sizeof(int64_t));
9845
+
9846
+ int64_t * matrix_row_counts = (int64_t *) (wdata_src1_end); // [n_as]
9847
+ int64_t * matrix_rows = matrix_row_counts + n_as; // [n_as][ne11]
9848
+
9849
+ #define MMID_MATRIX_ROW(row_id, i1) matrix_rows[(row_id)*ne11 + (i1)]
9821
9850
9822
- GGML_ASSERT(row_id >= 0 && row_id < n_as);
9851
+ if (params->type == GGML_TASK_INIT) {
9852
+ char * wdata = params->wdata;
9853
+ if (src1->type != vec_dot_type) {
9854
+ const size_t row_size = ggml_row_size(vec_dot_type, ne10);
9823
9855
9824
- const struct ggml_tensor * src0_row = dst->src[row_id + 2];
9825
- ggml_compute_forward_mul_mat(params, src0_row, src1, dst, i01, 1);
9856
+ assert(params->wsize >= ne11*ne12*ne13*row_size);
9857
+ assert(src1->type == GGML_TYPE_F32);
9858
+
9859
+ for (int64_t i13 = 0; i13 < ne13; ++i13) {
9860
+ for (int64_t i12 = 0; i12 < ne12; ++i12) {
9861
+ for (int64_t i11 = 0; i11 < ne11; ++i11) {
9862
+ from_float_to_vec_dot((float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11), (void *) wdata, ne10);
9863
+ wdata += row_size;
9864
+ }
9865
+ }
9866
+ }
9867
+ }
9868
+
9869
+ // initialize matrix_row_counts
9870
+ GGML_ASSERT(wdata == wdata_src1_end);
9871
+ memset(matrix_row_counts, 0, n_as*sizeof(int64_t));
9872
+
9873
+ // group rows by src0 matrix
9874
+ for (int64_t i01 = 0; i01 < ids->ne[1]; i01++) {
9875
+ const int32_t row_id = *(const int32_t *) ((const char *) ids->data + i01*ids->nb[1] + id*ids->nb[0]);
9876
+
9877
+ GGML_ASSERT(row_id >= 0 && row_id < n_as);
9878
+ MMID_MATRIX_ROW(row_id, matrix_row_counts[row_id]) = i01;
9879
+ matrix_row_counts[row_id] += 1;
9880
+ }
9881
+
9882
+ return;
9826
9883
}
9884
+
9885
+ if (params->type == GGML_TASK_FINALIZE) {
9886
+ return;
9887
+ }
9888
+
9889
+ // compute each matrix multiplication in sequence
9890
+ for (int cur_a = 0; cur_a < n_as; ++cur_a) {
9891
+ const int64_t cne1 = matrix_row_counts[cur_a];
9892
+
9893
+ if (cne1 == 0) {
9894
+ continue;
9895
+ }
9896
+
9897
+ const struct ggml_tensor * src0_cur = dst->src[cur_a + 2];
9898
+
9899
+ const void * wdata = (src1->type == vec_dot_type) ? src1->data : params->wdata;
9900
+ const size_t row_size = ggml_row_size(vec_dot_type, ne10);
9901
+
9902
+ const int64_t nr0 = ne01; // src0 rows
9903
+ const int64_t nr1 = cne1*ne12*ne13; // src1 rows
9904
+
9905
+ //printf("nr0 = %lld, nr1 = %lld\n", nr0, nr1);
9906
+
9907
+ // distribute the thread work across the inner or outer loop based on which one is larger
9908
+
9909
+ const int64_t nth0 = nr0 > nr1 ? nth : 1; // parallelize by src0 rows
9910
+ const int64_t nth1 = nr0 > nr1 ? 1 : nth; // parallelize by src1 rows
9911
+
9912
+ const int64_t ith0 = ith % nth0;
9913
+ const int64_t ith1 = ith / nth0;
9914
+
9915
+ const int64_t dr0 = (nr0 + nth0 - 1)/nth0;
9916
+ const int64_t dr1 = (nr1 + nth1 - 1)/nth1;
9917
+
9918
+ const int64_t ir010 = dr0*ith0;
9919
+ const int64_t ir011 = MIN(ir010 + dr0, nr0);
9920
+
9921
+ const int64_t ir110 = dr1*ith1;
9922
+ const int64_t ir111 = MIN(ir110 + dr1, nr1);
9923
+
9924
+ //printf("ir010 = %6lld, ir011 = %6lld, ir110 = %6lld, ir111 = %6lld\n", ir010, ir011, ir110, ir111);
9925
+
9926
+ // threads with no work simply yield (not sure if it helps)
9927
+ if (ir010 >= ir011 || ir110 >= ir111) {
9928
+ sched_yield();
9929
+ continue;
9930
+ }
9931
+
9932
+ assert(ne12 % ne02 == 0);
9933
+ assert(ne13 % ne03 == 0);
9934
+
9935
+ // block-tiling attempt
9936
+ const int64_t blck_0 = 16;
9937
+ const int64_t blck_1 = 16;
9938
+
9939
+ // attempt to reduce false-sharing (does not seem to make a difference)
9940
+ float tmp[16];
9941
+
9942
+ for (int64_t iir1 = ir110; iir1 < ir111; iir1 += blck_1) {
9943
+ for (int64_t iir0 = ir010; iir0 < ir011; iir0 += blck_0) {
9944
+ for (int64_t ir1 = iir1; ir1 < iir1 + blck_1 && ir1 < ir111; ++ir1) {
9945
+ const int64_t i13 = (ir1/(ne12*cne1)); // Note: currently, src1 is always a matrix
9946
+ const int64_t i12 = (ir1 - i13*ne12*cne1)/cne1;
9947
+ const int64_t _i11 = (ir1 - i13*ne12*cne1 - i12*cne1);
9948
+ const int64_t i11 = MMID_MATRIX_ROW(cur_a, _i11);
9949
+
9950
+ // broadcast src0 into src1
9951
+ const int64_t i03 = i13/r3;
9952
+ const int64_t i02 = i12/r2;
9953
+
9954
+ const int64_t i1 = i11;
9955
+ const int64_t i2 = i12;
9956
+ const int64_t i3 = i13;
9957
+
9958
+ const char * src0_row = (const char *) src0_cur->data + (0 + i02*nb02 + i03*nb03);
9959
+
9960
+ // desc: when src1 is not a contiguous memory block we have to calculate the offset using the strides
9961
+ // if it is, then we have either copied the data to params->wdata and made it contiguous or we are using
9962
+ // the original src1 data pointer, so we should index using the indices directly
9963
+ // TODO: this is a bit of a hack, we should probably have a better way to handle this
9964
+ const char * src1_col = (const char *) wdata +
9965
+ (src1_cont || src1->type != vec_dot_type
9966
+ ? (i11 + i12*ne11 + i13*ne12*ne11)*row_size
9967
+ : (i11*nb11 + i12*nb12 + i13*nb13));
9968
+
9969
+ float * dst_col = (float *) ((char *) dst->data + (i1*nb1 + i2*nb2 + i3*nb3));
9970
+
9971
+ //for (int64_t ir0 = iir0; ir0 < iir0 + blck_0 && ir0 < ir011; ++ir0) {
9972
+ // vec_dot(ne00, &dst_col[ir0], src0_row + ir0*nb01, src1_col);
9973
+ //}
9974
+
9975
+ for (int64_t ir0 = iir0; ir0 < iir0 + blck_0 && ir0 < ir011; ++ir0) {
9976
+ vec_dot(ne00, &tmp[ir0 - iir0], src0_row + ir0*nb01, src1_col);
9977
+ }
9978
+ memcpy(&dst_col[iir0], tmp, (MIN(iir0 + blck_0, ir011) - iir0)*sizeof(float));
9979
+ }
9980
+ }
9981
+ }
9982
+ }
9983
+
9984
+ #undef MMID_MATRIX_ROW
9827
9985
}
9828
9986
9829
9987
// ggml_compute_forward_out_prod
@@ -14191,7 +14349,7 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
14191
14349
} break;
14192
14350
case GGML_OP_MUL_MAT:
14193
14351
{
14194
- ggml_compute_forward_mul_mat(params, tensor->src[0], tensor->src[1], tensor, 0, tensor->ne[1] );
14352
+ ggml_compute_forward_mul_mat(params, tensor->src[0], tensor->src[1], tensor);
14195
14353
} break;
14196
14354
case GGML_OP_MUL_MAT_ID:
14197
14355
{
@@ -15991,7 +16149,6 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
15991
16149
} break;
15992
16150
case GGML_OP_MUL_MAT_ID:
15993
16151
{
15994
- // FIXME: blas
15995
16152
n_tasks = n_threads;
15996
16153
} break;
15997
16154
case GGML_OP_OUT_PROD:
@@ -16325,20 +16482,16 @@ struct ggml_cplan ggml_graph_plan(struct ggml_cgraph * cgraph, int n_threads) {
16325
16482
} break;
16326
16483
case GGML_OP_MUL_MAT_ID:
16327
16484
{
16328
- const struct ggml_tensor * a = node->src[2];
16329
- const struct ggml_tensor * b = node->src[1];
16330
- const enum ggml_type vec_dot_type = type_traits[a->type].vec_dot_type;
16331
- #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
16332
- if (ggml_compute_forward_mul_mat_use_blas(a, b, node)) {
16333
- if (a->type != GGML_TYPE_F32) {
16334
- // here we need memory just for single 2D matrix from src0
16335
- cur = ggml_type_size(GGML_TYPE_F32)*(a->ne[0]*a->ne[1]);
16336
- }
16337
- } else
16338
- #endif
16339
- if (b->type != vec_dot_type) {
16340
- cur = ggml_row_size(vec_dot_type, ggml_nelements(b));
16485
+ const struct ggml_tensor * src0 = node->src[2];
16486
+ const struct ggml_tensor * src1 = node->src[1];
16487
+ const enum ggml_type vec_dot_type = type_traits[src0->type].vec_dot_type;
16488
+ if (src1->type != vec_dot_type) {
16489
+ cur = ggml_row_size(vec_dot_type, ggml_nelements(src1));
16341
16490
}
16491
+ const int n_as = ggml_get_op_params_i32(node, 1);
16492
+ cur = GGML_PAD(cur, sizeof(int64_t)); // align
16493
+ cur += n_as * sizeof(int64_t); // matrix_row_counts
16494
+ cur += n_as * src1->ne[1] * sizeof(int64_t); // matrix_rows
16342
16495
} break;
16343
16496
case GGML_OP_OUT_PROD:
16344
16497
{
0 commit comments