Skip to content

Commit cc098b1

Browse files
Merge pull request #1160 from IntelPython/dot-crash
GEMM to use trans only if matrix if not C-contig
2 parents 51ee8c8 + fa647a9 commit cc098b1

File tree

1 file changed

+64
-38
lines changed

1 file changed

+64
-38
lines changed

dpnp/backend/kernels/dpnp_krnl_common.cpp

Lines changed: 64 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -323,49 +323,75 @@ DPCTLSyclEventRef dpnp_dot_c(DPCTLSyclQueueRef q_ref,
323323
// check if GEMM can be executed (strides)
324324
// TODO: rewrite the condition in general case for ndims > 2
325325
// (looks like there are such another cases)
326-
if ((ext_input1_ndim == 2 && ext_input2_ndim == 2) &&
327-
(ext_input1_strides[0] == 1 || ext_input1_strides[1] == 1) &&
328-
(ext_input2_strides[0] == 1 || ext_input2_strides[1] == 1))
326+
327+
if (ext_input1_ndim == 2 && ext_input2_ndim == 2)
329328
{
330329
// there is a difference of behavior with trans and sizes params in previous version of GEMM
331330
// only new version is supported, in case of old version computation goes in common way
332331
#if INTEL_MKL_VERSION >= 20210004
333-
oneapi::mkl::transpose trans1 =
334-
ext_input1_strides[0] == 1 ? oneapi::mkl::transpose::trans : oneapi::mkl::transpose::nontrans;
335-
oneapi::mkl::transpose trans2 =
336-
ext_input2_strides[0] == 1 ? oneapi::mkl::transpose::trans : oneapi::mkl::transpose::nontrans;
337-
338-
const size_t size_m = ext_input1_shape[0];
339-
const size_t size_n = ext_input2_shape[1];
340-
const size_t size_k = ext_input1_shape[1];
341-
342-
const std::int64_t lda =
343-
trans1 == oneapi::mkl::transpose::nontrans ? ext_input1_strides[0] : ext_input1_strides[1];
344-
const std::int64_t ldb =
345-
trans2 == oneapi::mkl::transpose::nontrans ? ext_input2_strides[0] : ext_input2_strides[1];
346-
;
347-
// defenition of ldc will be another for result with non-standard (c-contiguous) strides
348-
// const std::int64_t ldc = result_strides[0] == 1 ? result_strides[1] : result_strides[0];
349-
const std::int64_t ldc = size_n;
350-
351-
sycl::event event = mkl_blas_rm::gemm(q,
352-
trans1,
353-
trans2,
354-
size_m,
355-
size_n,
356-
size_k,
357-
_DataType_output(1), // alpha
358-
input1,
359-
lda,
360-
input2,
361-
ldb,
362-
_DataType_output(0), // beta
363-
result,
364-
ldc);
365-
event.wait();
366-
return event_ref;
332+
// is mat1 F-contiguous, C-contiguous
333+
bool mat1_f_contig = (
334+
((ext_input1_shape[0] == 1) || (ext_input1_strides[0] == 1)) &&
335+
((ext_input1_shape[1] == 1) || (ext_input1_strides[1] == ext_input1_shape[0])));
336+
bool mat1_c_contig = (
337+
((ext_input1_shape[1] == 1) || (ext_input1_strides[1] == 1)) &&
338+
((ext_input1_shape[0] == 1) || (ext_input1_strides[0] == ext_input1_shape[1])));
339+
// is mat2 F-contiguous, C-contiguous
340+
bool mat2_f_contig = (
341+
((ext_input2_shape[0] == 1) || (ext_input2_strides[0] == 1)) &&
342+
((ext_input2_shape[1] == 1) || (ext_input2_strides[1] == ext_input2_shape[0])));
343+
bool mat2_c_contig = (
344+
((ext_input2_shape[1] == 1) || (ext_input2_strides[1] == 1)) &&
345+
((ext_input2_shape[0] == 1) || (ext_input2_strides[0] == ext_input2_shape[1])));
346+
347+
if ((mat1_f_contig || mat1_c_contig) && (mat2_f_contig || mat2_c_contig)) {
348+
oneapi::mkl::transpose trans1 =
349+
(mat1_f_contig && !mat1_c_contig) ? oneapi::mkl::transpose::trans : oneapi::mkl::transpose::nontrans;
350+
oneapi::mkl::transpose trans2 =
351+
(mat2_f_contig && !mat2_c_contig) ? oneapi::mkl::transpose::trans : oneapi::mkl::transpose::nontrans;
352+
353+
const size_t size_m = ext_input1_shape[0];
354+
const size_t size_n = ext_input2_shape[1];
355+
const size_t size_k = ext_input1_shape[1];
356+
357+
const std::int64_t lda =
358+
trans1 == oneapi::mkl::transpose::nontrans ? ext_input1_strides[0] : ext_input1_strides[1];
359+
const std::int64_t ldb =
360+
trans2 == oneapi::mkl::transpose::nontrans ? ext_input2_strides[0] : ext_input2_strides[1];
361+
362+
// definition of ldc will be another for result with non-standard (c-contiguous) strides
363+
// const std::int64_t ldc = result_strides[0] == 1 ? result_strides[1] : result_strides[0];
364+
const std::int64_t ldc = size_n;
365+
366+
try {
367+
sycl::event event = mkl_blas_rm::gemm(q,
368+
trans1,
369+
trans2,
370+
size_m,
371+
size_n,
372+
size_k,
373+
_DataType_output(1), // alpha
374+
input1,
375+
lda,
376+
input2,
377+
ldb,
378+
_DataType_output(0), // beta
379+
result,
380+
ldc);
381+
event.wait();
382+
delete[] ext_input1_shape;
383+
delete[] ext_input1_strides;
384+
delete[] ext_input2_shape;
385+
delete[] ext_input2_strides;
386+
delete[] ext_result_shape;
387+
388+
return event_ref;
389+
} catch (const std::exception &e) {
390+
// do nothing, proceed to general case
391+
}
367392
#endif
368-
}
393+
}
394+
}
369395
}
370396

371397
std::vector<sycl::event> dot_events;

0 commit comments

Comments
 (0)