Skip to content

Commit f81119f

Browse files
committed
Implements BinaryElementwiseFunc._inplace_op method
This change permits casting behavior equivalent to `"same_kind"` when using in-place operators by introducing the `_inplace_op` method. It also extends this behavior to `__imatmul__` through use of the already-implemented `dtype` keyword.
1 parent 5d70795 commit f81119f

File tree

3 files changed

+154
-15
lines changed

3 files changed

+154
-15
lines changed

dpctl/tensor/_elementwise_common.py

Lines changed: 125 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737
_all_data_types,
3838
_find_buf_dtype,
3939
_find_buf_dtype2,
40+
_find_buf_dtype_in_place_op,
4041
_resolve_weak_types,
4142
_to_device_supported_dtype,
4243
)
@@ -213,7 +214,7 @@ def __call__(self, x, /, *, out=None, order="K"):
213214

214215
if res_dt != out.dtype:
215216
raise ValueError(
216-
f"Output array of type {res_dt} is needed,"
217+
f"Output array of type {res_dt} is needed, "
217218
f" got {out.dtype}"
218219
)
219220

@@ -650,7 +651,7 @@ def __call__(self, o1, o2, /, *, out=None, order="K"):
650651

651652
if res_dt != out.dtype:
652653
raise ValueError(
653-
f"Output array of type {res_dt} is needed,"
654+
f"Output array of type {res_dt} is needed, "
654655
f"got {out.dtype}"
655656
)
656657

@@ -927,3 +928,125 @@ def __call__(self, o1, o2, /, *, out=None, order="K"):
927928
)
928929
_manager.add_event_pair(ht_, bf_ev)
929930
return out
931+
932+
def _inplace_op(self, o1, o2):
    """Apply this binary function in place, storing the result in ``o1``.

    Used to implement in-place operators (such as ``+=``). Unlike calling
    the function with ``out=o1``, the right-hand operand may be cast to
    the data type of ``o1`` when that cast is allowed by the
    ``"same_kind"`` casting rule; the cast goes through a temporary
    buffer.

    Args:
        o1: Left-hand operand and destination. Must be a writable
            ``dpctl.tensor.usm_ndarray``.
        o2: Right-hand operand. Either a ``usm_ndarray`` or an object
            whose shape and data type can be inferred by ``_get_shape``
            and ``_get_dtype``; non-array operands are converted with
            ``dpt.asarray`` on the execution queue.

    Returns:
        ``o1`` itself. The events of the submitted operations are
        registered with the ``SequentialOrderManager`` of the execution
        queue.

    Raises:
        TypeError: If ``o1`` is not a ``usm_ndarray``, or the shape of
            ``o2`` can not be inferred as a list or tuple.
        ValueError: If ``o1`` is read-only, the operands do not
            broadcast, broadcasting would change the shape of ``o1``,
            ``o2`` has an unsupported data type, or no supported type
            combination produces a result of the data type of ``o1``.
        ExecutionPlacementError: If no common execution queue can be
            determined for the two operands.
    """
    if not isinstance(o1, dpt.usm_ndarray):
        raise TypeError(
            "Expected first argument to be "
            f"dpctl.tensor.usm_ndarray, got {type(o1)}"
        )
    if not o1.flags.writable:
        raise ValueError("provided left-hand side array is read-only")

    # Determine the execution queue and the coerced USM allocation type.
    q1, o1_usm_type = o1.sycl_queue, o1.usm_type
    q2, o2_usm_type = _get_queue_usm_type(o2)
    if q2 is None:
        exec_q = q1
        res_usm_type = o1_usm_type
    else:
        exec_q = dpctl.utils.get_execution_queue((q1, q2))
        if exec_q is None:
            raise ExecutionPlacementError(
                "Execution placement can not be unambiguously inferred "
                "from input arguments."
            )
        res_usm_type = dpctl.utils.get_coerced_usm_type(
            (
                o1_usm_type,
                o2_usm_type,
            )
        )
    dpctl.utils.validate_usm_type(res_usm_type, allow_none=False)

    # The broadcast shape must equal the shape of the left-hand side,
    # because the destination of an in-place operation can not grow.
    o1_shape = o1.shape
    o2_shape = _get_shape(o2)
    if not isinstance(o2_shape, (tuple, list)):
        raise TypeError(
            "Shape of second argument can not be inferred. "
            "Expected list or tuple."
        )
    try:
        res_shape = _broadcast_shape_impl(
            [
                o1_shape,
                o2_shape,
            ]
        )
    except ValueError as exc:
        raise ValueError(
            "operands could not be broadcast together with shapes "
            f"{o1_shape} and {o2_shape}"
        ) from exc
    if res_shape != o1_shape:
        raise ValueError(
            "The shape of the non-broadcastable left-hand "
            f"side {o1_shape} is inconsistent with the "
            f"broadcast shape {res_shape}."
        )

    sycl_dev = exec_q.sycl_device
    o1_dtype = o1.dtype
    o2_dtype = _get_dtype(o2, sycl_dev)
    if not _validate_dtype(o2_dtype):
        raise ValueError("Operand has an unsupported data type")

    o1_dtype, o2_dtype = self.weak_type_resolver_(
        o1_dtype, o2_dtype, sycl_dev
    )

    # buf_dt is None when the kernel accepts (o1_dtype, o2_dtype) directly;
    # otherwise it is the type the right-hand side must be cast to first.
    buf_dt, res_dt = _find_buf_dtype_in_place_op(
        o1_dtype,
        o2_dtype,
        self.result_type_resolver_fn_,
        sycl_dev,
    )

    if res_dt is None:
        raise ValueError(
            f"function '{self.name_}' does not support input types "
            f"({o1_dtype}, {o2_dtype}), "
            "and the inputs could not be safely coerced to any "
            "supported types according to the casting rule 'same_kind'."
        )

    if res_dt != o1_dtype:
        raise ValueError(
            f"Output array of type {res_dt} is needed, got {o1_dtype}"
        )

    _manager = SequentialOrderManager[exec_q]
    if isinstance(o2, dpt.usm_ndarray):
        src2 = o2
        if (
            ti._array_overlap(o2, o1)
            and not ti._same_logical_tensors(o2, o1)
            and buf_dt is None
        ):
            # The right-hand side overlaps the destination without being
            # the same logical tensor: force the temporary-copy path so
            # the kernel does not read memory it is writing.
            buf_dt = o2_dtype
    else:
        src2 = dpt.asarray(o2, dtype=o2_dtype, sycl_queue=exec_q)

    if buf_dt is None:
        # Direct path: no cast and no temporary copy required.
        if src2.shape != res_shape:
            src2 = dpt.broadcast_to(src2, res_shape)
        dep_evs = _manager.submitted_events
        ht_, comp_ev = self.binary_inplace_fn_(
            lhs=o1,
            rhs=src2,
            sycl_queue=exec_q,
            depends=dep_evs,
        )
        _manager.add_event_pair(ht_, comp_ev)
    else:
        # Buffered path: copy (and cast) the right-hand side into a
        # temporary of type buf_dt, then run the in-place kernel on it.
        buf = dpt.empty_like(src2, dtype=buf_dt)
        dep_evs = _manager.submitted_events
        (
            ht_copy_ev,
            copy_ev,
        ) = ti._copy_usm_ndarray_into_usm_ndarray(
            src=src2,
            dst=buf,
            sycl_queue=exec_q,
            depends=dep_evs,
        )
        _manager.add_event_pair(ht_copy_ev, copy_ev)

        buf = dpt.broadcast_to(buf, res_shape)
        ht_, bf_ev = self.binary_inplace_fn_(
            lhs=o1,
            rhs=buf,
            sycl_queue=exec_q,
            depends=[copy_ev],
        )
        _manager.add_event_pair(ht_, bf_ev)

    return o1

dpctl/tensor/_type_utils.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -277,6 +277,21 @@ def _find_buf_dtype2(arg1_dtype, arg2_dtype, query_fn, sycl_dev, acceptance_fn):
277277
return None, None, None
278278

279279

280+
def _find_buf_dtype_in_place_op(arg1_dtype, arg2_dtype, query_fn, sycl_dev):
    """Find the buffer and result data types for an in-place binary op.

    Args:
        arg1_dtype: Data type of the left-hand (destination) operand.
        arg2_dtype: Data type of the right-hand operand.
        query_fn: Callable taking two data types and returning the
            result data type, or a falsy value when the pair is not
            supported.
        sycl_dev: Device whose ``has_aspect_fp16`` / ``has_aspect_fp64``
            attributes are passed to ``_can_cast``.

    Returns:
        A pair ``(buf_dt, res_dt)``:

        * ``(None, res_dt)`` when ``query_fn`` supports
          ``(arg1_dtype, arg2_dtype)`` directly;
        * ``(arg1_dtype, res_dt)`` when the right-hand type can be cast to
          ``arg1_dtype`` under ``"same_kind"`` casting and ``query_fn``
          supports ``(arg1_dtype, arg1_dtype)``;
        * ``(None, None)`` otherwise.
    """
    direct_dt = query_fn(arg1_dtype, arg2_dtype)
    if direct_dt:
        # The pair is supported as-is; no temporary buffer is needed.
        return None, direct_dt

    has_fp16 = sycl_dev.has_aspect_fp16
    has_fp64 = sycl_dev.has_aspect_fp64
    castable = _can_cast(
        arg2_dtype, arg1_dtype, has_fp16, has_fp64, casting="same_kind"
    )
    if not castable:
        return None, None

    # Retry with the right-hand side cast to the left-hand data type.
    coerced_dt = query_fn(arg1_dtype, arg1_dtype)
    if not coerced_dt:
        return None, None
    return arg1_dtype, coerced_dt
293+
294+
280295
class WeakBooleanType:
281296
"Python type representing type of Python boolean objects"
282297

@@ -959,4 +974,5 @@ def _default_accumulation_dtype_fp_types(inp_dt, q):
959974
"WeakComplexType",
960975
"_default_accumulation_dtype",
961976
"_default_accumulation_dtype_fp_types",
977+
"_find_buf_dtype_in_place_op",
962978
]

dpctl/tensor/_usmarray.pyx

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1508,43 +1508,43 @@ cdef class usm_ndarray:
15081508
return dpctl.tensor.bitwise_xor(other, self)
15091509

15101510
def __iadd__(self, other):
1511-
return dpctl.tensor.add(self, other, out=self)
1511+
return dpctl.tensor.add._inplace_op(self, other)
15121512

15131513
def __iand__(self, other):
1514-
return dpctl.tensor.bitwise_and(self, other, out=self)
1514+
return dpctl.tensor.bitwise_and._inplace_op(self, other)
15151515

15161516
def __ifloordiv__(self, other):
1517-
return dpctl.tensor.floor_divide(self, other, out=self)
1517+
return dpctl.tensor.floor_divide._inplace_op(self, other)
15181518

15191519
def __ilshift__(self, other):
1520-
return dpctl.tensor.bitwise_left_shift(self, other, out=self)
1520+
return dpctl.tensor.bitwise_left_shift._inplace_op(self, other)
15211521

15221522
def __imatmul__(self, other):
1523-
return dpctl.tensor.matmul(self, other, out=self)
1523+
return dpctl.tensor.matmul(self, other, out=self, dtype=self.dtype)
15241524

15251525
def __imod__(self, other):
1526-
return dpctl.tensor.remainder(self, other, out=self)
1526+
return dpctl.tensor.remainder._inplace_op(self, other)
15271527

15281528
def __imul__(self, other):
1529-
return dpctl.tensor.multiply(self, other, out=self)
1529+
return dpctl.tensor.multiply._inplace_op(self, other)
15301530

15311531
def __ior__(self, other):
1532-
return dpctl.tensor.bitwise_or(self, other, out=self)
1532+
return dpctl.tensor.bitwise_or._inplace_op(self, other)
15331533

15341534
def __ipow__(self, other):
1535-
return dpctl.tensor.pow(self, other, out=self)
1535+
return dpctl.tensor.pow._inplace_op(self, other)
15361536

15371537
def __irshift__(self, other):
1538-
return dpctl.tensor.bitwise_right_shift(self, other, out=self)
1538+
return dpctl.tensor.bitwise_right_shift._inplace_op(self, other)
15391539

15401540
def __isub__(self, other):
1541-
return dpctl.tensor.subtract(self, other, out=self)
1541+
return dpctl.tensor.subtract._inplace_op(self, other)
15421542

15431543
def __itruediv__(self, other):
1544-
return dpctl.tensor.divide(self, other, out=self)
1544+
return dpctl.tensor.divide._inplace_op(self, other)
15451545

15461546
def __ixor__(self, other):
1547-
return dpctl.tensor.bitwise_xor(self, other, out=self)
1547+
return dpctl.tensor.bitwise_xor._inplace_op(self, other)
15481548

15491549
def __str__(self):
15501550
return usm_ndarray_str(self)

0 commit comments

Comments
 (0)