@@ -323,49 +323,75 @@ DPCTLSyclEventRef dpnp_dot_c(DPCTLSyclQueueRef q_ref,
323
323
// check if GEMM can be executed (strides)
324
324
// TODO: rewrite the condition in general case for ndims > 2
325
325
// (looks like there are such another cases)
326
- if ((ext_input1_ndim == 2 && ext_input2_ndim == 2 ) &&
327
- (ext_input1_strides[0 ] == 1 || ext_input1_strides[1 ] == 1 ) &&
328
- (ext_input2_strides[0 ] == 1 || ext_input2_strides[1 ] == 1 ))
326
+
327
+ if (ext_input1_ndim == 2 && ext_input2_ndim == 2 )
329
328
{
330
329
// there is a difference of behavior with trans and sizes params in previous version of GEMM
331
330
// only new version is supported, in case of old version computation goes in common way
332
331
#if INTEL_MKL_VERSION >= 20210004
333
- oneapi::mkl::transpose trans1 =
334
- ext_input1_strides[0 ] == 1 ? oneapi::mkl::transpose::trans : oneapi::mkl::transpose::nontrans;
335
- oneapi::mkl::transpose trans2 =
336
- ext_input2_strides[0 ] == 1 ? oneapi::mkl::transpose::trans : oneapi::mkl::transpose::nontrans;
337
-
338
- const size_t size_m = ext_input1_shape[0 ];
339
- const size_t size_n = ext_input2_shape[1 ];
340
- const size_t size_k = ext_input1_shape[1 ];
341
-
342
- const std::int64_t lda =
343
- trans1 == oneapi::mkl::transpose::nontrans ? ext_input1_strides[0 ] : ext_input1_strides[1 ];
344
- const std::int64_t ldb =
345
- trans2 == oneapi::mkl::transpose::nontrans ? ext_input2_strides[0 ] : ext_input2_strides[1 ];
346
- ;
347
- // defenition of ldc will be another for result with non-standard (c-contiguous) strides
348
- // const std::int64_t ldc = result_strides[0] == 1 ? result_strides[1] : result_strides[0];
349
- const std::int64_t ldc = size_n;
350
-
351
- sycl::event event = mkl_blas_rm::gemm (q,
352
- trans1,
353
- trans2,
354
- size_m,
355
- size_n,
356
- size_k,
357
- _DataType_output (1 ), // alpha
358
- input1,
359
- lda,
360
- input2,
361
- ldb,
362
- _DataType_output (0 ), // beta
363
- result,
364
- ldc);
365
- event.wait ();
366
- return event_ref;
332
+ // is mat1 F-contiguous, C-contiguous
333
+ bool mat1_f_contig = (
334
+ ((ext_input1_shape[0 ] == 1 ) || (ext_input1_strides[0 ] == 1 )) &&
335
+ ((ext_input1_shape[1 ] == 1 ) || (ext_input1_strides[1 ] == ext_input1_shape[0 ])));
336
+ bool mat1_c_contig = (
337
+ ((ext_input1_shape[1 ] == 1 ) || (ext_input1_strides[1 ] == 1 )) &&
338
+ ((ext_input1_shape[0 ] == 1 ) || (ext_input1_strides[0 ] == ext_input1_shape[1 ])));
339
+ // is mat2 F-contiguous, C-contiguous
340
+ bool mat2_f_contig = (
341
+ ((ext_input2_shape[0 ] == 1 ) || (ext_input2_strides[0 ] == 1 )) &&
342
+ ((ext_input2_shape[1 ] == 1 ) || (ext_input2_strides[1 ] == ext_input2_shape[0 ])));
343
+ bool mat2_c_contig = (
344
+ ((ext_input2_shape[1 ] == 1 ) || (ext_input2_strides[1 ] == 1 )) &&
345
+ ((ext_input2_shape[0 ] == 1 ) || (ext_input2_strides[0 ] == ext_input2_shape[1 ])));
346
+
347
+ if ((mat1_f_contig || mat1_c_contig) && (mat2_f_contig || mat2_c_contig)) {
348
+ oneapi::mkl::transpose trans1 =
349
+ (mat1_f_contig && !mat1_c_contig) ? oneapi::mkl::transpose::trans : oneapi::mkl::transpose::nontrans;
350
+ oneapi::mkl::transpose trans2 =
351
+ (mat2_f_contig && !mat2_c_contig) ? oneapi::mkl::transpose::trans : oneapi::mkl::transpose::nontrans;
352
+
353
+ const size_t size_m = ext_input1_shape[0 ];
354
+ const size_t size_n = ext_input2_shape[1 ];
355
+ const size_t size_k = ext_input1_shape[1 ];
356
+
357
+ const std::int64_t lda =
358
+ trans1 == oneapi::mkl::transpose::nontrans ? ext_input1_strides[0 ] : ext_input1_strides[1 ];
359
+ const std::int64_t ldb =
360
+ trans2 == oneapi::mkl::transpose::nontrans ? ext_input2_strides[0 ] : ext_input2_strides[1 ];
361
+
362
+ // definition of ldc will be another for result with non-standard (c-contiguous) strides
363
+ // const std::int64_t ldc = result_strides[0] == 1 ? result_strides[1] : result_strides[0];
364
+ const std::int64_t ldc = size_n;
365
+
366
+ try {
367
+ sycl::event event = mkl_blas_rm::gemm (q,
368
+ trans1,
369
+ trans2,
370
+ size_m,
371
+ size_n,
372
+ size_k,
373
+ _DataType_output (1 ), // alpha
374
+ input1,
375
+ lda,
376
+ input2,
377
+ ldb,
378
+ _DataType_output (0 ), // beta
379
+ result,
380
+ ldc);
381
+ event.wait ();
382
+ delete[] ext_input1_shape;
383
+ delete[] ext_input1_strides;
384
+ delete[] ext_input2_shape;
385
+ delete[] ext_input2_strides;
386
+ delete[] ext_result_shape;
387
+
388
+ return event_ref;
389
+ } catch (const std::exception &e) {
390
+ // do nothing, proceed to general case
391
+ }
367
392
#endif
368
- }
393
+ }
394
+ }
369
395
}
370
396
371
397
std::vector<sycl::event> dot_events;
0 commit comments