Skip to content

Commit e22412f

Browse files
authored
Fixes out kwarg in matmul when axes are appended to inputs (#1610)
* Fixes `out` keyword in `matmul` for cases where axes are appended to inputs * Adds test for fixed matmul `out` kwarg * Fix typo in matmul docstring and adds documentation for dtype kwarg
1 parent d7c54e4 commit e22412f

File tree

2 files changed

+46
-7
lines changed

2 files changed

+46
-7
lines changed

dpctl/tensor/_linear_algebra_functions.py

Lines changed: 21 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -599,11 +599,16 @@ def matmul(x1, x2, out=None, dtype=None, order="K"):
599599
matrices on which to perform matrix multiplication.
600600
out (Optional[usm_ndarray]):
601601
the array into which the result of the matrix product is written.
602-
If `None` then a new array is returned.
602+
The data type of `out` must match the expected data type of the
603+
result or (if provided) `dtype`.
604+
If `None` then a new array is returned. Default: `None`.
605+
dtype (Optional[dtype]):
606+
data type of the returned array. If `None`, the data type of the
607+
returned array is determined by the Type Promotion Rules.
608+
Default: `None`.
603609
order (["K", "C", "F", "A"]):
604610
memory layout of the output array, if `out` is `None`, otherwise
605-
the `order` parameter value is not used.
606-
611+
the `order` parameter value is not used. Default: `K`.
607612
Returns:
608613
usm_ndarray:
609614
* if both `x1` and `x2` are one-dimensional arrays with shape
@@ -613,8 +618,8 @@ def matmul(x1, x2, out=None, dtype=None, order="K"):
613618
a two-dimensional array with shape `(K, N)`, returned array is a
614619
two-dimensional array with shape `(M, N)` and contains the
615620
conventional matrix product.
616-
* if `x1` is a one-dimensinal array with shape `(K,)` and `x2` is an
617-
array with shape `(..., K, N)`, returned array contains the
621+
* if `x1` is a one-dimensional array with shape `(K,)` and `x2` is
622+
an array with shape `(..., K, N)`, returned array contains the
618623
conventional matrix product and has shape `(..., N)`.
619624
* if `x1` is an array with shape `(..., M, K)` and `x2` is a
620625
one-dimensional array with shape `(K,)`, returned array has shape
@@ -741,12 +746,21 @@ def matmul(x1, x2, out=None, dtype=None, order="K"):
741746
if not out.flags.writable:
742747
raise ValueError("provided `out` array is read-only")
743748

744-
if out.shape != res_shape:
749+
final_res_shape = tuple(
750+
res_shape[i]
751+
for i in range(-len(res_shape), 0)
752+
if i not in appended_axes
753+
)
754+
if out.shape != final_res_shape:
745755
raise ValueError(
746756
"The shape of input and output arrays are inconsistent. "
747-
f"Expected output shape is {res_shape}, got {out.shape}"
757+
f"Expected output shape is {final_res_shape}, got {out.shape}"
748758
)
749759

760+
if appended_axes:
761+
out = dpt.expand_dims(out, appended_axes)
762+
orig_out = out
763+
750764
if res_dt != out.dtype:
751765
raise ValueError(
752766
f"Output array of type {res_dt} is needed," f"got {out.dtype}"

dpctl/tests/test_usm_ndarray_linalg.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -980,3 +980,28 @@ def test_vecdot_contig_small():
980980
res = dpt.vecdot(v1, v2)
981981
assert dpt.all(res[:-1] == 0)
982982
assert res[-1] == n
983+
984+
985+
def test_matmul_out_appended_axes():
986+
get_queue_or_skip()
987+
988+
n0, n1, n2 = 4, 10, 5
989+
# vm
990+
x1 = dpt.ones(n1, dtype="i4")
991+
x2 = dpt.ones((n0, n1, n2), dtype="i4")
992+
out = dpt.empty((n0, n2), dtype="i4")
993+
994+
dpt.matmul(x1, x2, out=out)
995+
assert dpt.all(out == n1)
996+
997+
# mv
998+
x2 = x2.mT
999+
x1, x2 = x2, x1
1000+
dpt.matmul(x1, x2, out=out)
1001+
assert dpt.all(out == n1)
1002+
1003+
# vv
1004+
x1 = dpt.ones(n1, dtype="i4")
1005+
out = dpt.empty((), dtype="i4")
1006+
dpt.matmul(x1, x2, out=out)
1007+
assert out == n1

0 commit comments

Comments
 (0)