@@ -599,11 +599,16 @@ def matmul(x1, x2, out=None, dtype=None, order="K"):
599
599
matrices on which to perform matrix multiplication.
600
600
out (Optional[usm_ndarray]):
601
601
the array into which the result of the matrix product is written.
602
- If `None` then a new array is returned.
602
+ The data type of `out` must match the expected data type of the
603
+ result or (if provided) `dtype`.
604
+ If `None` then a new array is returned. Default: `None`.
605
+ dtype (Optional[dtype]):
606
+ data type of the returned array. If `None`, the data type of the
607
+ returned array is determined by the Type Promotion Rules.
608
+ Default: `None`.
603
609
order (["K", "C", "F", "A"]):
604
610
memory layout of the output array, if `out` is `None`, otherwise
605
- the `order` parameter value is not used.
606
-
611
+ the `order` parameter value is not used. Default: `K`.
607
612
Returns:
608
613
usm_ndarray:
609
614
* if both `x1` and `x2` are one-dimensional arrays with shape
@@ -613,8 +618,8 @@ def matmul(x1, x2, out=None, dtype=None, order="K"):
613
618
a two-dimensional array with shape `(K, N)`, returned array is a
614
619
two-dimensional array with shape `(M, N)` and contains the
615
620
conventional matrix product.
616
- * if `x1` is a one-dimensinal array with shape `(K,)` and `x2` is an
617
- array with shape `(..., K, N)`, returned array contains the
621
+ * if `x1` is a one-dimensional array with shape `(K,)` and `x2` is
622
+ an array with shape `(..., K, N)`, returned array contains the
618
623
conventional matrix product and has shape `(..., N)`.
619
624
* if `x1` is an array with shape `(..., M, K)` and `x2` is a
620
625
one-dimensional array with shape `(K,)`, returned array has shape
@@ -741,12 +746,21 @@ def matmul(x1, x2, out=None, dtype=None, order="K"):
741
746
if not out .flags .writable :
742
747
raise ValueError ("provided `out` array is read-only" )
743
748
744
- if out .shape != res_shape :
749
+ final_res_shape = tuple (
750
+ res_shape [i ]
751
+ for i in range (- len (res_shape ), 0 )
752
+ if i not in appended_axes
753
+ )
754
+ if out .shape != final_res_shape :
745
755
raise ValueError (
746
756
"The shape of input and output arrays are inconsistent. "
747
- f"Expected output shape is { res_shape } , got { out .shape } "
757
+ f"Expected output shape is { final_res_shape } , got { out .shape } "
748
758
)
749
759
760
+ if appended_axes :
761
+ out = dpt .expand_dims (out , appended_axes )
762
+ orig_out = out
763
+
750
764
if res_dt != out .dtype :
751
765
raise ValueError (
752
766
f"Output array of type { res_dt } is needed," f"got { out .dtype } "
0 commit comments