Skip to content

Commit 441cb7b

Browse files
ndgrigorian authored and oleksandr-pavlyk committed
Change per PR review by @oleksandr-pavlyk
1 parent c2b3701 commit 441cb7b

File tree

1 file changed

+123
-125
lines changed

1 file changed

+123
-125
lines changed

dpctl/tensor/_elementwise_common.py

Lines changed: 123 additions & 125 deletions
Original file line number | Diff line number | Diff line change
@@ -930,140 +930,138 @@ def __call__(self, o1, o2, /, *, out=None, order="K"):
930930
return out
931931

932932
def _inplace_op(self, o1, o2):
933-
if self.binary_inplace_fn_ is not None:
934-
if not isinstance(o1, dpt.usm_ndarray):
935-
raise TypeError(
936-
"Expected first argument to be "
937-
f"dpctl.tensor.usm_ndarray, got {type(o1)}"
938-
)
939-
if not o1.flags.writable:
940-
raise ValueError("provided left-hand side array is read-only")
941-
q1, o1_usm_type = o1.sycl_queue, o1.usm_type
942-
q2, o2_usm_type = _get_queue_usm_type(o2)
943-
if q2 is None:
944-
exec_q = q1
945-
res_usm_type = o1_usm_type
946-
else:
947-
exec_q = dpctl.utils.get_execution_queue((q1, q2))
948-
if exec_q is None:
949-
raise ExecutionPlacementError(
950-
"Execution placement can not be unambiguously inferred "
951-
"from input arguments."
952-
)
953-
res_usm_type = dpctl.utils.get_coerced_usm_type(
954-
(
955-
o1_usm_type,
956-
o2_usm_type,
957-
)
958-
)
959-
dpctl.utils.validate_usm_type(res_usm_type, allow_none=False)
960-
o1_shape = o1.shape
961-
o2_shape = _get_shape(o2)
962-
if not isinstance(o2_shape, (tuple, list)):
963-
raise TypeError(
964-
"Shape of second argument can not be inferred. "
965-
"Expected list or tuple."
966-
)
967-
try:
968-
res_shape = _broadcast_shape_impl(
969-
[
970-
o1_shape,
971-
o2_shape,
972-
]
973-
)
974-
except ValueError:
975-
raise ValueError(
976-
"operands could not be broadcast together with shapes "
977-
f"{o1_shape} and {o2_shape}"
933+
if self.binary_inplace_fn_ is None:
934+
raise ValueError(
935+
"binary function does not have a dedicated in-place "
936+
"implementation"
937+
)
938+
if not isinstance(o1, dpt.usm_ndarray):
939+
raise TypeError(
940+
"Expected first argument to be "
941+
f"dpctl.tensor.usm_ndarray, got {type(o1)}"
942+
)
943+
if not o1.flags.writable:
944+
raise ValueError("provided left-hand side array is read-only")
945+
q1, o1_usm_type = o1.sycl_queue, o1.usm_type
946+
q2, o2_usm_type = _get_queue_usm_type(o2)
947+
if q2 is None:
948+
exec_q = q1
949+
res_usm_type = o1_usm_type
950+
else:
951+
exec_q = dpctl.utils.get_execution_queue((q1, q2))
952+
if exec_q is None:
953+
raise ExecutionPlacementError(
954+
"Execution placement can not be unambiguously inferred "
955+
"from input arguments."
978956
)
979-
980-
if res_shape != o1_shape:
981-
raise ValueError(
982-
"The shape of the non-broadcastable left-hand "
983-
f"side {o1_shape} is inconsistent with the "
984-
f"broadcast shape {res_shape}."
957+
res_usm_type = dpctl.utils.get_coerced_usm_type(
958+
(
959+
o1_usm_type,
960+
o2_usm_type,
985961
)
986-
987-
sycl_dev = exec_q.sycl_device
988-
o1_dtype = o1.dtype
989-
o2_dtype = _get_dtype(o2, sycl_dev)
990-
if not _validate_dtype(o2_dtype):
991-
raise ValueError("Operand has an unsupported data type")
992-
993-
o1_dtype, o2_dtype = self.weak_type_resolver_(
994-
o1_dtype, o2_dtype, sycl_dev
962+
)
963+
dpctl.utils.validate_usm_type(res_usm_type, allow_none=False)
964+
o1_shape = o1.shape
965+
o2_shape = _get_shape(o2)
966+
if not isinstance(o2_shape, (tuple, list)):
967+
raise TypeError(
968+
"Shape of second argument can not be inferred. "
969+
"Expected list or tuple."
970+
)
971+
try:
972+
res_shape = _broadcast_shape_impl(
973+
[
974+
o1_shape,
975+
o2_shape,
976+
]
977+
)
978+
except ValueError:
979+
raise ValueError(
980+
"operands could not be broadcast together with shapes "
981+
f"{o1_shape} and {o2_shape}"
995982
)
996983

997-
buf_dt, res_dt = _find_buf_dtype_in_place_op(
998-
o1_dtype,
999-
o2_dtype,
1000-
self.result_type_resolver_fn_,
1001-
sycl_dev,
984+
if res_shape != o1_shape:
985+
raise ValueError(
986+
"The shape of the non-broadcastable left-hand "
987+
f"side {o1_shape} is inconsistent with the "
988+
f"broadcast shape {res_shape}."
1002989
)
1003990

1004-
if res_dt is None:
1005-
raise ValueError(
1006-
f"function '{self.name_}' does not support input types "
1007-
f"({o1_dtype}, {o2_dtype}), "
1008-
"and the inputs could not be safely coerced to any "
1009-
"supported types according to the casting rule "
1010-
"''same_kind''."
1011-
)
991+
sycl_dev = exec_q.sycl_device
992+
o1_dtype = o1.dtype
993+
o2_dtype = _get_dtype(o2, sycl_dev)
994+
if not _validate_dtype(o2_dtype):
995+
raise ValueError("Operand has an unsupported data type")
1012996

1013-
if res_dt != o1_dtype:
1014-
raise ValueError(
1015-
f"Output array of type {res_dt} is needed, "
1016-
f"got {o1_dtype}"
1017-
)
997+
o1_dtype, o2_dtype = self.weak_type_resolver_(
998+
o1_dtype, o2_dtype, sycl_dev
999+
)
10181000

1019-
_manager = SequentialOrderManager[exec_q]
1020-
if isinstance(o2, dpt.usm_ndarray):
1021-
src2 = o2
1022-
if (
1023-
ti._array_overlap(o2, o1)
1024-
and not ti._same_logical_tensors(o2, o1)
1025-
and buf_dt is None
1026-
):
1027-
buf_dt = o2_dtype
1028-
else:
1029-
src2 = dpt.asarray(o2, dtype=o2_dtype, sycl_queue=exec_q)
1030-
if buf_dt is None:
1031-
if src2.shape != res_shape:
1032-
src2 = dpt.broadcast_to(src2, res_shape)
1033-
dep_evs = _manager.submitted_events
1034-
ht_, comp_ev = self.binary_inplace_fn_(
1035-
lhs=o1,
1036-
rhs=src2,
1037-
sycl_queue=exec_q,
1038-
depends=dep_evs,
1039-
)
1040-
_manager.add_event_pair(ht_, comp_ev)
1041-
else:
1042-
buf = dpt.empty_like(src2, dtype=buf_dt)
1043-
dep_evs = _manager.submitted_events
1044-
(
1045-
ht_copy_ev,
1046-
copy_ev,
1047-
) = ti._copy_usm_ndarray_into_usm_ndarray(
1048-
src=src2,
1049-
dst=buf,
1050-
sycl_queue=exec_q,
1051-
depends=dep_evs,
1052-
)
1053-
_manager.add_event_pair(ht_copy_ev, copy_ev)
1001+
buf_dt, res_dt = _find_buf_dtype_in_place_op(
1002+
o1_dtype,
1003+
o2_dtype,
1004+
self.result_type_resolver_fn_,
1005+
sycl_dev,
1006+
)
10541007

1055-
buf = dpt.broadcast_to(buf, res_shape)
1056-
ht_, bf_ev = self.binary_inplace_fn_(
1057-
lhs=o1,
1058-
rhs=buf,
1059-
sycl_queue=exec_q,
1060-
depends=[copy_ev],
1061-
)
1062-
_manager.add_event_pair(ht_, bf_ev)
1008+
if res_dt is None:
1009+
raise ValueError(
1010+
f"function '{self.name_}' does not support input types "
1011+
f"({o1_dtype}, {o2_dtype}), "
1012+
"and the inputs could not be safely coerced to any "
1013+
"supported types according to the casting rule "
1014+
"''same_kind''."
1015+
)
10631016

1064-
return o1
1065-
else:
1017+
if res_dt != o1_dtype:
10661018
raise ValueError(
1067-
"binary function does not have a dedicated in-place "
1068-
"implementation"
1019+
f"Output array of type {res_dt} is needed, " f"got {o1_dtype}"
10691020
)
1021+
1022+
_manager = SequentialOrderManager[exec_q]
1023+
if isinstance(o2, dpt.usm_ndarray):
1024+
src2 = o2
1025+
if (
1026+
ti._array_overlap(o2, o1)
1027+
and not ti._same_logical_tensors(o2, o1)
1028+
and buf_dt is None
1029+
):
1030+
buf_dt = o2_dtype
1031+
else:
1032+
src2 = dpt.asarray(o2, dtype=o2_dtype, sycl_queue=exec_q)
1033+
if buf_dt is None:
1034+
if src2.shape != res_shape:
1035+
src2 = dpt.broadcast_to(src2, res_shape)
1036+
dep_evs = _manager.submitted_events
1037+
ht_, comp_ev = self.binary_inplace_fn_(
1038+
lhs=o1,
1039+
rhs=src2,
1040+
sycl_queue=exec_q,
1041+
depends=dep_evs,
1042+
)
1043+
_manager.add_event_pair(ht_, comp_ev)
1044+
else:
1045+
buf = dpt.empty_like(src2, dtype=buf_dt)
1046+
dep_evs = _manager.submitted_events
1047+
(
1048+
ht_copy_ev,
1049+
copy_ev,
1050+
) = ti._copy_usm_ndarray_into_usm_ndarray(
1051+
src=src2,
1052+
dst=buf,
1053+
sycl_queue=exec_q,
1054+
depends=dep_evs,
1055+
)
1056+
_manager.add_event_pair(ht_copy_ev, copy_ev)
1057+
1058+
buf = dpt.broadcast_to(buf, res_shape)
1059+
ht_, bf_ev = self.binary_inplace_fn_(
1060+
lhs=o1,
1061+
rhs=buf,
1062+
sycl_queue=exec_q,
1063+
depends=[copy_ev],
1064+
)
1065+
_manager.add_event_pair(ht_, bf_ev)
1066+
1067+
return o1

0 commit comments

Comments
 (0)