|
37 | 37 | _all_data_types,
|
38 | 38 | _find_buf_dtype,
|
39 | 39 | _find_buf_dtype2,
|
| 40 | + _find_buf_dtype_in_place_op, |
40 | 41 | _resolve_weak_types,
|
41 | 42 | _to_device_supported_dtype,
|
42 | 43 | )
|
@@ -213,7 +214,7 @@ def __call__(self, x, /, *, out=None, order="K"):
|
213 | 214 |
|
214 | 215 | if res_dt != out.dtype:
|
215 | 216 | raise ValueError(
|
216 |
| - f"Output array of type {res_dt} is needed," |
| 217 | + f"Output array of type {res_dt} is needed, " |
217 | 218 | f" got {out.dtype}"
|
218 | 219 | )
|
219 | 220 |
|
@@ -650,7 +651,7 @@ def __call__(self, o1, o2, /, *, out=None, order="K"):
|
650 | 651 |
|
651 | 652 | if res_dt != out.dtype:
|
652 | 653 | raise ValueError(
|
653 |
| - f"Output array of type {res_dt} is needed," |
| 654 | + f"Output array of type {res_dt} is needed, " |
654 | 655 | f"got {out.dtype}"
|
655 | 656 | )
|
656 | 657 |
|
@@ -927,3 +928,125 @@ def __call__(self, o1, o2, /, *, out=None, order="K"):
|
927 | 928 | )
|
928 | 929 | _manager.add_event_pair(ht_, bf_ev)
|
929 | 930 | return out
|
| 931 | + |
| 932 | + def _inplace_op(self, o1, o2): |
| 933 | + if not isinstance(o1, dpt.usm_ndarray): |
| 934 | + raise TypeError( |
| 935 | + "Expected first argument to be " |
| 936 | + f"dpctl.tensor.usm_ndarray, got {type(o1)}" |
| 937 | + ) |
| 938 | + if not o1.flags.writable: |
| 939 | + raise ValueError("provided left-hand side array is read-only") |
| 940 | + q1, o1_usm_type = o1.sycl_queue, o1.usm_type |
| 941 | + q2, o2_usm_type = _get_queue_usm_type(o2) |
| 942 | + if q2 is None: |
| 943 | + exec_q = q1 |
| 944 | + res_usm_type = o1_usm_type |
| 945 | + else: |
| 946 | + exec_q = dpctl.utils.get_execution_queue((q1, q2)) |
| 947 | + if exec_q is None: |
| 948 | + raise ExecutionPlacementError( |
| 949 | + "Execution placement can not be unambiguously inferred " |
| 950 | + "from input arguments." |
| 951 | + ) |
| 952 | + res_usm_type = dpctl.utils.get_coerced_usm_type( |
| 953 | + ( |
| 954 | + o1_usm_type, |
| 955 | + o2_usm_type, |
| 956 | + ) |
| 957 | + ) |
| 958 | + dpctl.utils.validate_usm_type(res_usm_type, allow_none=False) |
| 959 | + o1_shape = o1.shape |
| 960 | + o2_shape = _get_shape(o2) |
| 961 | + if not isinstance(o2_shape, (tuple, list)): |
| 962 | + raise TypeError( |
| 963 | + "Shape of second argument can not be inferred. " |
| 964 | + "Expected list or tuple." |
| 965 | + ) |
| 966 | + try: |
| 967 | + res_shape = _broadcast_shape_impl( |
| 968 | + [ |
| 969 | + o1_shape, |
| 970 | + o2_shape, |
| 971 | + ] |
| 972 | + ) |
| 973 | + except ValueError: |
| 974 | + raise ValueError( |
| 975 | + "operands could not be broadcast together with shapes " |
| 976 | + f"{o1_shape} and {o2_shape}" |
| 977 | + ) |
| 978 | + if res_shape != o1_shape: |
| 979 | + raise ValueError("") |
| 980 | + sycl_dev = exec_q.sycl_device |
| 981 | + o1_dtype = o1.dtype |
| 982 | + o2_dtype = _get_dtype(o2, sycl_dev) |
| 983 | + if not _validate_dtype(o2_dtype): |
| 984 | + raise ValueError("Operand has an unsupported data type") |
| 985 | + |
| 986 | + o1_dtype, o2_dtype = self.weak_type_resolver_( |
| 987 | + o1_dtype, o2_dtype, sycl_dev |
| 988 | + ) |
| 989 | + |
| 990 | + buf_dt, res_dt = _find_buf_dtype_in_place_op( |
| 991 | + o1_dtype, |
| 992 | + o2_dtype, |
| 993 | + self.result_type_resolver_fn_, |
| 994 | + sycl_dev, |
| 995 | + ) |
| 996 | + |
| 997 | + if res_dt is None: |
| 998 | + raise ValueError( |
| 999 | + f"function '{self.name_}' does not support input types " |
| 1000 | + f"({o1_dtype}, {o2_dtype}), " |
| 1001 | + "and the inputs could not be safely coerced to any " |
| 1002 | + "supported types according to the casting rule ''same_kind''." |
| 1003 | + ) |
| 1004 | + |
| 1005 | + if res_dt != o1_dtype: |
| 1006 | + raise ValueError( |
| 1007 | + f"Output array of type {res_dt} is needed, " f"got {o1_dtype}" |
| 1008 | + ) |
| 1009 | + |
| 1010 | + _manager = SequentialOrderManager[exec_q] |
| 1011 | + if isinstance(o2, dpt.usm_ndarray): |
| 1012 | + src2 = o2 |
| 1013 | + if ( |
| 1014 | + ti._array_overlap(o2, o1) |
| 1015 | + and not ti._same_logical_tensors(o2, o1) |
| 1016 | + and buf_dt is None |
| 1017 | + ): |
| 1018 | + buf_dt = o2_dtype |
| 1019 | + else: |
| 1020 | + src2 = dpt.asarray(o2, dtype=o2_dtype, sycl_queue=exec_q) |
| 1021 | + if buf_dt is None: |
| 1022 | + if src2.shape != res_shape: |
| 1023 | + src2 = dpt.broadcast_to(src2, res_shape) |
| 1024 | + dep_evs = _manager.submitted_events |
| 1025 | + ht_, comp_ev = self.binary_inplace_fn_( |
| 1026 | + lhs=o1, |
| 1027 | + rhs=src2, |
| 1028 | + sycl_queue=exec_q, |
| 1029 | + depends=dep_evs, |
| 1030 | + ) |
| 1031 | + _manager.add_event_pair(ht_, comp_ev) |
| 1032 | + else: |
| 1033 | + buf = dpt.empty_like(src2, dtype=buf_dt) |
| 1034 | + dep_evs = _manager.submitted_events |
| 1035 | + (ht_copy_ev, copy_ev,) = ti._copy_usm_ndarray_into_usm_ndarray( |
| 1036 | + src=src2, |
| 1037 | + dst=buf, |
| 1038 | + sycl_queue=exec_q, |
| 1039 | + depends=dep_evs, |
| 1040 | + ) |
| 1041 | + _manager.add_event_pair(ht_copy_ev, copy_ev) |
| 1042 | + |
| 1043 | + buf = dpt.broadcast_to(buf, res_shape) |
| 1044 | + ht_, bf_ev = self.binary_inplace_fn_( |
| 1045 | + lhs=o1, |
| 1046 | + rhs=buf, |
| 1047 | + sycl_queue=exec_q, |
| 1048 | + depends=[copy_ev], |
| 1049 | + ) |
| 1050 | + _manager.add_event_pair(ht_, bf_ev) |
| 1051 | + |
| 1052 | + return o1 |
0 commit comments