
Commit 7b8d29b

address comments - first round
1 parent d45cb4a commit 7b8d29b

7 files changed: +78 −35 lines

dpnp/backend/extensions/blas/blas_py.cpp

Lines changed: 9 additions & 3 deletions

@@ -38,8 +38,8 @@ namespace py = pybind11;
 // populate dispatch tables
 void init_dispatch_tables(void)
 {
-    blas_ext::init_gemm_dispatch_table();
     blas_ext::init_gemm_batch_dispatch_table();
+    blas_ext::init_gemm_dispatch_table();
 }

 PYBIND11_MODULE(_blas_impl, m)
@@ -51,12 +51,18 @@ PYBIND11_MODULE(_blas_impl, m)
               "Call `gemm` from OneMKL LAPACK library to return "
               "the matrix-matrix product with 2-D matrices.",
               py::arg("sycl_queue"), py::arg("matrixA"), py::arg("matrixB"),
-              py::arg("matrixC"), py::arg("depends") = py::list());
+              py::arg("result"), py::arg("depends") = py::list());
     }

     {
         m.def("_gemm_batch", &blas_ext::gemm_batch,
               "Call `gemm_batch` from OneMKL LAPACK library to return "
-              "the matrix-matrix product with general matrices.");
+              "the matrix-matrix product with general matrices.",
+              py::arg("sycl_queue"), py::arg("matrixA"), py::arg("matrixB"),
+              py::arg("result"), py::arg("m"), py::arg("n"), py::arg("k"),
+              py::arg("batch_size"), py::arg("ld_array_1"),
+              py::arg("ld_array_2"), py::arg("ld_result"), py::arg("stridea"),
+              py::arg("strideb"), py::arg("stridec"), py::arg("transA_int"),
+              py::arg("transB_int"), py::arg("depends") = py::list());
     }
 }
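
With the keyword names registered above, the _gemm binding can be called with explicit arguments. A minimal sketch of such a call, assuming the module is importable as dpnp.backend.extensions.blas._blas_impl and that dpctl.tensor arrays are passed directly; the import path and the direct use of the private binding are assumptions, not part of this commit (regular users would go through dpnp.matmul instead):

    # Illustrative sketch only; see the lead-in for the assumptions made here.
    import dpctl
    import dpctl.tensor as dpt
    import dpnp.backend.extensions.blas._blas_impl as bi  # assumed import path

    q = dpctl.SyclQueue()
    a = dpt.ones((4, 3), dtype="f4", sycl_queue=q)
    b = dpt.ones((3, 5), dtype="f4", sycl_queue=q)
    res = dpt.empty((4, 5), dtype="f4", sycl_queue=q)

    # Keyword names follow the py::arg(...) declarations added above.
    events = bi._gemm(sycl_queue=q, matrixA=a, matrixB=b, result=res, depends=[])
    for ev in events:  # the binding returns a pair of SYCL events
        ev.wait()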

dpnp/backend/extensions/blas/gemm.cpp

Lines changed: 6 additions & 5 deletions

@@ -93,9 +93,9 @@ static sycl::event gemm_impl(sycl::queue exec_q,
     try {
         gemm_event = mkl_blas::row_major::gemm(
             exec_q,
-            transA, // Parameter indicating whether matrix A is not
-                    // transposed ('N'), transposed ('T'),
-                    // or conjugate transposed ('C').
+            transA, // Defines the transpose operation for matrix A:
+                    // 'N' indicates no transpose, 'T' for transpose,
+                    // or 'C' for a conjugate transpose.
             transB, // Same as transA but for matrix B.
             m,      // Number of rows in matrices A and C.
             n,      // Number of columns in matrices B and C.
@@ -106,7 +106,7 @@ static sycl::event gemm_impl(sycl::queue exec_q,
                     // stride between successive rows (for row major
                     // layout).
             b,      // Pointer to matrix B.
-            ldb,    // Leading dimension of matrix B, similar to lda
+            ldb,    // Leading dimension of matrix B, similar to lda.
             Tab(0), // Scaling factor for matrix C.
             res,    // Pointer to matrix C, where the result is stored.
             ldc,    // Leading dimension of matrix C.
@@ -198,7 +198,8 @@ std::pair<sycl::event, sycl::event>
     gemm_impl_fn_ptr_t gemm_fn =
         gemm_dispatch_table[matrixAB_type_id][resultC_type_id];
     if (gemm_fn == nullptr) {
-        throw py::value_error("Type dispatch ran into trouble.");
+        throw py::value_error(
+            "Types of input matrices and result matrix are mismatched.");
     }

     char *a_typeless_ptr = matrixA.get_data();
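
The comments above spell out the usual gemm conventions. As a plain NumPy sketch of what the row-major call computes with the scaling factors used in this diff (Tab(1) for alpha, Tab(0) for beta) and no transposition; the concrete sizes are only illustrative:

    import numpy as np

    m, k, n = 4, 3, 5
    A = np.arange(m * k, dtype=np.float32).reshape(m, k)  # lda == k (row-major, no transpose)
    B = np.arange(k * n, dtype=np.float32).reshape(k, n)  # ldb == n
    C = A @ B                                             # m x n result, ldc == n
    assert C.shape == (m, n)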

dpnp/backend/extensions/blas/gemm.hpp

Lines changed: 1 addition & 2 deletions

@@ -45,9 +45,8 @@ extern std::pair<sycl::event, sycl::event>
                dpctl::tensor::usm_ndarray resultC,
                const std::vector<sycl::event> &depends);

-// extern sycl::event
 extern std::pair<sycl::event, sycl::event>
-    gemm_batch(sycl::queue q,
+    gemm_batch(sycl::queue exec_q,
                dpctl::tensor::usm_ndarray matrixA,
                dpctl::tensor::usm_ndarray matrixB,
                dpctl::tensor::usm_ndarray resultC,

dpnp/backend/extensions/blas/gemm_batch.cpp

Lines changed: 3 additions & 2 deletions

@@ -100,7 +100,7 @@ static sycl::event gemm_batch_impl(sycl::queue exec_q,

     sycl::event gemm_batch_event;
     try {
-        gemm_batch_event = oneapi::mkl::blas::row_major::gemm_batch(
+        gemm_batch_event = mkl_blas::row_major::gemm_batch(
             exec_q, transA, transB, m, n, k, Tab(1), a, ld_array_1, stridea, b,
             ld_array_2, strideb, Tab(0), res, ld_result, stridec, batch_size,
             depends);
@@ -171,7 +171,8 @@ std::pair<sycl::event, sycl::event>
     gemm_batch_impl_fn_ptr_t gemm_batch_fn =
         gemm_batch_dispatch_table[matrixAB_type_id][resultC_type_id];
     if (gemm_batch_fn == nullptr) {
-        throw py::value_error("Type dispatch ran into trouble.");
+        throw py::value_error(
+            "Types of input matrices and result matrix are mismatched.");
     }

     char *a_typeless_ptr = matrixA.get_data();
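
A rough NumPy model of the strided gemm_batch call above may help read the argument list: each matrix of a batch starts stridea, strideb, or stridec elements after the previous one, and alpha/beta are again Tab(1) and Tab(0). The sizes and the densely packed strides below are illustrative assumptions:

    import numpy as np

    batch_size, m, k, n = 3, 4, 5, 2
    a = np.random.rand(batch_size * m * k).astype(np.float32)
    b = np.random.rand(batch_size * k * n).astype(np.float32)
    res = np.zeros(batch_size * m * n, dtype=np.float32)

    # Element strides between consecutive matrices in each batch.
    stridea, strideb, stridec = m * k, k * n, m * n

    for i in range(batch_size):
        A = a[i * stridea : i * stridea + m * k].reshape(m, k)
        B = b[i * strideb : i * strideb + k * n].reshape(k, n)
        res[i * stridec : i * stridec + m * n] = (A @ B).ravel()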

dpnp/backend/extensions/blas/types_matrix.hpp

Lines changed: 8 additions & 0 deletions

@@ -73,6 +73,14 @@ struct GemmTypePairSupportFactory
         dpctl_td_ns::NotDefinedEntry>::is_defined;
 };

+/**
+ * @brief A factory to define pairs of supported types for which
+ * MKL BLAS library provides support in
+ * oneapi::mkl::blas::gemm_batch<Tab, Tc> function.
+ *
+ * @tparam Tab Type of arrays containing input matrices A and B.
+ * @tparam Tc Type of array containing output matrix C.
+ */
 template <typename Tab, typename Tc>
 struct GemmBatchTypePairSupportFactory
 {
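
For readers unfamiliar with the dispatch machinery, a loose Python analogue of what the support factory and the dispatch table accomplish, including the clearer error message introduced in this commit; the set of supported pairs shown here is a hypothetical placeholder, not the exact list from types_matrix.hpp:

    import numpy as np

    # Hypothetical (Tab, Tc) pairs; the real list is generated from the factory above.
    supported_pairs = {
        (np.float32, np.float32),
        (np.float64, np.float64),
        (np.complex64, np.complex64),
        (np.complex128, np.complex128),
    }

    def lookup_gemm_batch(tab_dtype, tc_dtype):
        """Return an implementation tag, or fail the way the C++ code now does."""
        if (tab_dtype, tc_dtype) not in supported_pairs:
            raise ValueError(
                "Types of input matrices and result matrix are mismatched."
            )
        return f"gemm_batch_impl<{tab_dtype.__name__}, {tc_dtype.__name__}>"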

dpnp/dpnp_iface_linearalgebra.py

Lines changed: 51 additions & 20 deletions

@@ -285,25 +285,26 @@ def matmul(

     Examples
     --------
-    >>> import dpnp as np
-    >>> a = np.ones([9, 5, 7, 4])
-    >>> c = np.ones([9, 5, 4, 3])
-    >>> np.matmul(a, c).shape
-    (9, 5, 7, 3)
+    For 2-D arrays it is the matrix product:

+    >>> import dpnp as np
     >>> a = np.array([[1, 0], [0, 1]])
     >>> b = np.array([[4, 1], [2, 2]])
     >>> np.matmul(a, b)
     array([[4, 1],
            [2, 2]])

+    For 2-D mixed with 1-D, the result is the usual.
+
     >>> a = np.array([[1, 0], [0, 1]])
     >>> b = np.array([1, 2])
     >>> np.matmul(a, b)
     array([1, 2])
     >>> np.matmul(b, a)
     array([1, 2])

+    Broadcasting is conventional for stacks of arrays
+
     >>> a = np.arange(2 * 2 * 4).reshape((2, 2, 4))
     >>> b = np.arange(2 * 2 * 4).reshape((2, 4, 2))
     >>> np.matmul(a,b).shape
@@ -313,11 +314,16 @@ def matmul(
     >>> np.sum(a[0, 1, :] * b[0 , :, 1])
     array(98)

-    The ``@`` operator can be used as a shorthand for ``matmul`` on
-    :class:`dpnp.ndarray`.
+    Vector, vector returns the scalar inner product, but neither argument is complex-conjugated:

     >>> x1 = np.array([2j, 3j])
     >>> x2 = np.array([2j, 3j])
+    >>> np.matmul(x1, x2)
+    array(-13+0j)
+
+    The ``@`` operator can be used as a shorthand for ``matmul`` on
+    :class:`dpnp.ndarray`.
+
     >>> x1 @ x2
     array(-13+0j)

@@ -590,27 +596,52 @@ def dpnp_matmul_batch(


 def _gemm_res_dtype(*arrays, casting):
-    dtype = dpnp.result_type(*arrays)
-    default = dpnp.default_float_type(device=arrays[0].device)
-    if dpnp.issubdtype(dtype, dpnp.complexfloating):
-        default = dpnp.complex64 if default == dpnp.float32 else dpnp.complex128
+    """
+    Determines the data types for matmul operation and the output array of matmul operation.
+
+    The output array data type is determined based on the Promotion Type Rule
+    and device capabilities. The data type used in matmul operation is an 'inexact' data type
+    determined based on the output data type and device capabilities.
+    Both data types are determined based on the fact that the output array data type can be cast
+    to the other data type according to casting rule specified, otherwise a ``TypeError`` is raised.
+
+    Parameters
+    ----------
+    arrays : {dpnp_array, usm_ndarray}
+        Input arrays.
+    casting : {'no', 'equiv', 'safe', 'same_kind', 'unsafe'}, optional
+        Controls what kind of data casting may occur.
+
+    Returns
+    -------
+    gemm_dtype, res_dtype :
+        The appropriate data types for performing matmul operation and presenting output array.
+
+    """
+
+    res_dtype = dpnp.result_type(*arrays)
+    gemm_dtype = dpnp.default_float_type(device=arrays[0].device)
+    if dpnp.issubdtype(res_dtype, dpnp.complexfloating):
+        gemm_dtype = (
+            dpnp.complex64 if gemm_dtype == dpnp.float32 else dpnp.complex128
+        )

-    if dpnp.can_cast(dtype, default, casting):
-        if dtype in [
+    if dpnp.can_cast(res_dtype, gemm_dtype, casting):
+        if res_dtype in [
             dpnp.float64,
             dpnp.complex128,
-        ]:  # in case device does not support fp64 (default)
-            return default, default
-        elif dtype in [
+        ]:  # in case device does not support fp64
+            return gemm_dtype, gemm_dtype
+        elif res_dtype in [
             dpnp.float32,
             dpnp.complex64,
-        ]:  # needed dtype is fp32 but device supports fp64 (default)
-            return dtype, dtype
+        ]:  # needed dtype is fp32 but device supports fp64
+            return res_dtype, res_dtype
         else:
-            return default, dtype
+            return gemm_dtype, res_dtype
     else:
         raise TypeError(
-            f"Cannot cast ufunc 'matmul' output from dtype({dtype}) to dtype({default}) with casting rule {casting}"
+            f"Cannot cast ufunc 'matmul' output from dtype({res_dtype}) to dtype({gemm_dtype}) with casting rule {casting}"
         )

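
A short usage sketch of the behaviour the docstring describes: with integer inputs, the output keeps the promoted type from dpnp.result_type while the underlying gemm call runs in the floating-point type chosen by _gemm_res_dtype (exact dtypes depend on the device, so the expected value below is an assumption):

    import dpnp as np

    a = np.ones((2, 3), dtype=np.int32)
    b = np.ones((3, 4), dtype=np.int32)

    c = np.matmul(a, b)
    # The result dtype follows the Promotion Type Rule for the inputs; the
    # computation itself is performed in float32/float64 internally.
    print(c.dtype)  # expected: int32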

tests/third_party/cupy/math_tests/test_matmul.py

Lines changed: 0 additions & 3 deletions

@@ -57,7 +57,6 @@
         }
     )
 )
-@testing.gpu
 class TestMatmul(unittest.TestCase):
     @testing.for_all_dtypes(name="dtype1")
     @testing.numpy_cupy_allclose(rtol=1e-3, atol=1e-3)  # required for uint8
@@ -93,7 +92,6 @@ def test_cupy_matmul(self, xp, dtype1):
         }
     )
 )
-@testing.gpu
 class TestMatmulLarge(unittest.TestCase):
     # Avoid overflow
     skip_dtypes = {
@@ -149,7 +147,6 @@ def test_cupy_matmul(self, xp, dtype1):
         }
     )
 )
-@testing.gpu
 class TestMatmulInvalidShape(unittest.TestCase):
     def test_invalid_shape(self):
         for xp in (numpy, dpnp):
