Skip to content

Commit 3d02b6b

Browse files
authored
Permit "same_kind" casting for element-wise in-place operators (#2170)
The PR proposes to permit `"same_kind"` casting for element-wise in-place operators. The implementation leverages on dpctl changes added in scope of [PR#1827](IntelPython/dpctl#1827). It also adds callbacks to support in-place bit-wise operators (leverages on dpctl changes from [RR#1447](IntelPython/dpctl#1447)). The PR removes a temporary workaround from `dpnp.wrap` which depends on the implemented changes.
1 parent f7c0938 commit 3d02b6b

File tree

7 files changed

+1200
-889
lines changed

7 files changed

+1200
-889
lines changed

dpnp/dpnp_algo/dpnp_elementwise_common.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -335,6 +335,20 @@ def __call__(
335335
"as an argument, but both were provided."
336336
)
337337

338+
x1_usm = dpnp.get_usm_ndarray_or_scalar(x1)
339+
x2_usm = dpnp.get_usm_ndarray_or_scalar(x2)
340+
out_usm = None if out is None else dpnp.get_usm_ndarray(out)
341+
342+
if (
343+
isinstance(x1, dpnp_array)
344+
and x1 is out
345+
and order == "K"
346+
and dtype is None
347+
):
348+
# in-place operation
349+
super()._inplace_op(x1_usm, x2_usm)
350+
return x1
351+
338352
if order is None:
339353
order = "K"
340354
elif order in "afkcAFKC":
@@ -344,9 +358,6 @@ def __call__(
344358
"order must be one of 'C', 'F', 'A', or 'K' (got '{order}')"
345359
)
346360

347-
x1_usm = dpnp.get_usm_ndarray_or_scalar(x1)
348-
x2_usm = dpnp.get_usm_ndarray_or_scalar(x2)
349-
350361
if dtype is not None:
351362
if dpnp.isscalar(x1):
352363
x1_usm = dpt.asarray(
@@ -368,7 +379,6 @@ def __call__(
368379
x1_usm = dpt.astype(x1_usm, dtype, copy=False)
369380
x2_usm = dpt.astype(x2_usm, dtype, copy=False)
370381

371-
out_usm = None if out is None else dpnp.get_usm_ndarray(out)
372382
res_usm = super().__call__(x1_usm, x2_usm, out=out_usm, order=order)
373383

374384
if out is not None and isinstance(out, dpnp_array):

dpnp/dpnp_array.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -399,7 +399,7 @@ def __imatmul__(self, other):
399399
axes = [(-2, -1), (-2, -1), (-2, -1)]
400400

401401
try:
402-
dpnp.matmul(self, other, out=self, axes=axes)
402+
dpnp.matmul(self, other, out=self, dtype=self.dtype, axes=axes)
403403
except AxisError:
404404
# AxisError should indicate that the axes argument didn't work out
405405
# which should mean the second operand not being 2 dimensional.

dpnp/dpnp_iface_bitwise.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -208,6 +208,7 @@ def binary_repr(num, width=None):
208208
ti._bitwise_and_result_type,
209209
ti._bitwise_and,
210210
_BITWISE_AND_DOCSTRING,
211+
binary_inplace_fn=ti._bitwise_and_inplace,
211212
)
212213

213214

@@ -285,6 +286,7 @@ def binary_repr(num, width=None):
285286
ti._bitwise_or_result_type,
286287
ti._bitwise_or,
287288
_BITWISE_OR_DOCSTRING,
289+
binary_inplace_fn=ti._bitwise_or_inplace,
288290
)
289291

290292

@@ -366,6 +368,7 @@ def binary_repr(num, width=None):
366368
ti._bitwise_xor_result_type,
367369
ti._bitwise_xor,
368370
_BITWISE_XOR_DOCSTRING,
371+
binary_inplace_fn=ti._bitwise_xor_inplace,
369372
)
370373

371374

@@ -518,6 +521,7 @@ def binary_repr(num, width=None):
518521
ti._bitwise_left_shift_result_type,
519522
ti._bitwise_left_shift,
520523
_LEFT_SHIFT_DOCSTRING,
524+
binary_inplace_fn=ti._bitwise_left_shift_inplace,
521525
)
522526

523527
bitwise_left_shift = left_shift # bitwise_left_shift is an alias for left_shift
@@ -595,6 +599,7 @@ def binary_repr(num, width=None):
595599
ti._bitwise_right_shift_result_type,
596600
ti._bitwise_right_shift,
597601
_RIGHT_SHIFT_DOCSTRING,
602+
binary_inplace_fn=ti._bitwise_right_shift_inplace,
598603
)
599604

600605
# bitwise_right_shift is an alias for right_shift

dpnp/dpnp_iface_trigonometric.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2450,7 +2450,5 @@ def unwrap(p, discont=None, axis=-1, *, period=2 * dpnp.pi):
24502450

24512451
up = dpnp.astype(p, dtype=dt, copy=True)
24522452
up[slice1] = p[slice1]
2453-
# TODO: replace, once dpctl-1757 resolved
2454-
# up[slice1] += ph_correct.cumsum(axis=axis)
2455-
up[slice1] += ph_correct.cumsum(axis=axis, dtype=dt)
2453+
up[slice1] += ph_correct.cumsum(axis=axis)
24562454
return up

0 commit comments

Comments
 (0)