Skip to content

Commit 9899b78

Browse files
authored
Merge pull request #21120 from BvB93/matmul
ENH: Add support for inplace matrix multiplication Original NumPy Commit: a37978a106073eaec5cb9e0cb54785fafb639650
2 parents 68122b4 + 12baa24 commit 9899b78

File tree

1 file changed

+2
-12
lines changed

1 file changed

+2
-12
lines changed

array_api_strict/_array_object.py

Lines changed: 2 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -850,23 +850,13 @@ def __imatmul__(self: Array, other: Array, /) -> Array:
850850
"""
851851
Performs the operation __imatmul__.
852852
"""
853-
# Note: NumPy does not implement __imatmul__.
854-
855853
# matmul is not defined for scalars, but without this, we may get
856854
# the wrong error message from asarray.
857855
other = self._check_allowed_dtypes(other, "numeric", "__imatmul__")
858856
if other is NotImplemented:
859857
return other
860-
861-
# __imatmul__ can only be allowed when it would not change the shape
862-
# of self.
863-
other_shape = other.shape
864-
if self.shape == () or other_shape == ():
865-
raise ValueError("@= requires at least one dimension")
866-
if len(other_shape) == 1 or other_shape[-1] != other_shape[-2]:
867-
raise ValueError("@= cannot change the shape of the input array")
868-
self._array[:] = self._array.__matmul__(other._array)
869-
return self
858+
res = self._array.__imatmul__(other._array)
859+
return self.__class__._new(res)
870860

871861
def __rmatmul__(self: Array, other: Array, /) -> Array:
872862
"""

0 commit comments

Comments
 (0)