@@ -930,140 +930,138 @@ def __call__(self, o1, o2, /, *, out=None, order="K"):
930
930
return out
931
931
932
932
def _inplace_op (self , o1 , o2 ):
933
- if self .binary_inplace_fn_ is not None :
934
- if not isinstance (o1 , dpt .usm_ndarray ):
935
- raise TypeError (
936
- "Expected first argument to be "
937
- f"dpctl.tensor.usm_ndarray, got { type (o1 )} "
938
- )
939
- if not o1 .flags .writable :
940
- raise ValueError ("provided left-hand side array is read-only" )
941
- q1 , o1_usm_type = o1 .sycl_queue , o1 .usm_type
942
- q2 , o2_usm_type = _get_queue_usm_type (o2 )
943
- if q2 is None :
944
- exec_q = q1
945
- res_usm_type = o1_usm_type
946
- else :
947
- exec_q = dpctl .utils .get_execution_queue ((q1 , q2 ))
948
- if exec_q is None :
949
- raise ExecutionPlacementError (
950
- "Execution placement can not be unambiguously inferred "
951
- "from input arguments."
952
- )
953
- res_usm_type = dpctl .utils .get_coerced_usm_type (
954
- (
955
- o1_usm_type ,
956
- o2_usm_type ,
957
- )
958
- )
959
- dpctl .utils .validate_usm_type (res_usm_type , allow_none = False )
960
- o1_shape = o1 .shape
961
- o2_shape = _get_shape (o2 )
962
- if not isinstance (o2_shape , (tuple , list )):
963
- raise TypeError (
964
- "Shape of second argument can not be inferred. "
965
- "Expected list or tuple."
966
- )
967
- try :
968
- res_shape = _broadcast_shape_impl (
969
- [
970
- o1_shape ,
971
- o2_shape ,
972
- ]
973
- )
974
- except ValueError :
975
- raise ValueError (
976
- "operands could not be broadcast together with shapes "
977
- f"{ o1_shape } and { o2_shape } "
933
+ if self .binary_inplace_fn_ is None :
934
+ raise ValueError (
935
+ "binary function does not have a dedicated in-place "
936
+ "implementation"
937
+ )
938
+ if not isinstance (o1 , dpt .usm_ndarray ):
939
+ raise TypeError (
940
+ "Expected first argument to be "
941
+ f"dpctl.tensor.usm_ndarray, got { type (o1 )} "
942
+ )
943
+ if not o1 .flags .writable :
944
+ raise ValueError ("provided left-hand side array is read-only" )
945
+ q1 , o1_usm_type = o1 .sycl_queue , o1 .usm_type
946
+ q2 , o2_usm_type = _get_queue_usm_type (o2 )
947
+ if q2 is None :
948
+ exec_q = q1
949
+ res_usm_type = o1_usm_type
950
+ else :
951
+ exec_q = dpctl .utils .get_execution_queue ((q1 , q2 ))
952
+ if exec_q is None :
953
+ raise ExecutionPlacementError (
954
+ "Execution placement can not be unambiguously inferred "
955
+ "from input arguments."
978
956
)
979
-
980
- if res_shape != o1_shape :
981
- raise ValueError (
982
- "The shape of the non-broadcastable left-hand "
983
- f"side { o1_shape } is inconsistent with the "
984
- f"broadcast shape { res_shape } ."
957
+ res_usm_type = dpctl .utils .get_coerced_usm_type (
958
+ (
959
+ o1_usm_type ,
960
+ o2_usm_type ,
985
961
)
986
-
987
- sycl_dev = exec_q .sycl_device
988
- o1_dtype = o1 .dtype
989
- o2_dtype = _get_dtype (o2 , sycl_dev )
990
- if not _validate_dtype (o2_dtype ):
991
- raise ValueError ("Operand has an unsupported data type" )
992
-
993
- o1_dtype , o2_dtype = self .weak_type_resolver_ (
994
- o1_dtype , o2_dtype , sycl_dev
962
+ )
963
+ dpctl .utils .validate_usm_type (res_usm_type , allow_none = False )
964
+ o1_shape = o1 .shape
965
+ o2_shape = _get_shape (o2 )
966
+ if not isinstance (o2_shape , (tuple , list )):
967
+ raise TypeError (
968
+ "Shape of second argument can not be inferred. "
969
+ "Expected list or tuple."
970
+ )
971
+ try :
972
+ res_shape = _broadcast_shape_impl (
973
+ [
974
+ o1_shape ,
975
+ o2_shape ,
976
+ ]
977
+ )
978
+ except ValueError :
979
+ raise ValueError (
980
+ "operands could not be broadcast together with shapes "
981
+ f"{ o1_shape } and { o2_shape } "
995
982
)
996
983
997
- buf_dt , res_dt = _find_buf_dtype_in_place_op (
998
- o1_dtype ,
999
- o2_dtype ,
1000
- self . result_type_resolver_fn_ ,
1001
- sycl_dev ,
984
+ if res_shape != o1_shape :
985
+ raise ValueError (
986
+ "The shape of the non-broadcastable left-hand "
987
+ f"side { o1_shape } is inconsistent with the "
988
+ f"broadcast shape { res_shape } ."
1002
989
)
1003
990
1004
- if res_dt is None :
1005
- raise ValueError (
1006
- f"function '{ self .name_ } ' does not support input types "
1007
- f"({ o1_dtype } , { o2_dtype } ), "
1008
- "and the inputs could not be safely coerced to any "
1009
- "supported types according to the casting rule "
1010
- "''same_kind''."
1011
- )
991
+ sycl_dev = exec_q .sycl_device
992
+ o1_dtype = o1 .dtype
993
+ o2_dtype = _get_dtype (o2 , sycl_dev )
994
+ if not _validate_dtype (o2_dtype ):
995
+ raise ValueError ("Operand has an unsupported data type" )
1012
996
1013
- if res_dt != o1_dtype :
1014
- raise ValueError (
1015
- f"Output array of type { res_dt } is needed, "
1016
- f"got { o1_dtype } "
1017
- )
997
+ o1_dtype , o2_dtype = self .weak_type_resolver_ (
998
+ o1_dtype , o2_dtype , sycl_dev
999
+ )
1018
1000
1019
- _manager = SequentialOrderManager [exec_q ]
1020
- if isinstance (o2 , dpt .usm_ndarray ):
1021
- src2 = o2
1022
- if (
1023
- ti ._array_overlap (o2 , o1 )
1024
- and not ti ._same_logical_tensors (o2 , o1 )
1025
- and buf_dt is None
1026
- ):
1027
- buf_dt = o2_dtype
1028
- else :
1029
- src2 = dpt .asarray (o2 , dtype = o2_dtype , sycl_queue = exec_q )
1030
- if buf_dt is None :
1031
- if src2 .shape != res_shape :
1032
- src2 = dpt .broadcast_to (src2 , res_shape )
1033
- dep_evs = _manager .submitted_events
1034
- ht_ , comp_ev = self .binary_inplace_fn_ (
1035
- lhs = o1 ,
1036
- rhs = src2 ,
1037
- sycl_queue = exec_q ,
1038
- depends = dep_evs ,
1039
- )
1040
- _manager .add_event_pair (ht_ , comp_ev )
1041
- else :
1042
- buf = dpt .empty_like (src2 , dtype = buf_dt )
1043
- dep_evs = _manager .submitted_events
1044
- (
1045
- ht_copy_ev ,
1046
- copy_ev ,
1047
- ) = ti ._copy_usm_ndarray_into_usm_ndarray (
1048
- src = src2 ,
1049
- dst = buf ,
1050
- sycl_queue = exec_q ,
1051
- depends = dep_evs ,
1052
- )
1053
- _manager .add_event_pair (ht_copy_ev , copy_ev )
1001
+ buf_dt , res_dt = _find_buf_dtype_in_place_op (
1002
+ o1_dtype ,
1003
+ o2_dtype ,
1004
+ self .result_type_resolver_fn_ ,
1005
+ sycl_dev ,
1006
+ )
1054
1007
1055
- buf = dpt . broadcast_to ( buf , res_shape )
1056
- ht_ , bf_ev = self . binary_inplace_fn_ (
1057
- lhs = o1 ,
1058
- rhs = buf ,
1059
- sycl_queue = exec_q ,
1060
- depends = [ copy_ev ],
1061
- )
1062
- _manager . add_event_pair ( ht_ , bf_ev )
1008
+ if res_dt is None :
1009
+ raise ValueError (
1010
+ f"function ' { self . name_ } ' does not support input types "
1011
+ f"( { o1_dtype } , { o2_dtype } ), "
1012
+ "and the inputs could not be safely coerced to any "
1013
+ "supported types according to the casting rule "
1014
+ "''same_kind''."
1015
+ )
1063
1016
1064
- return o1
1065
- else :
1017
+ if res_dt != o1_dtype :
1066
1018
raise ValueError (
1067
- "binary function does not have a dedicated in-place "
1068
- "implementation"
1019
+ f"Output array of type { res_dt } is needed, " f"got { o1_dtype } "
1069
1020
)
1021
+
1022
+ _manager = SequentialOrderManager [exec_q ]
1023
+ if isinstance (o2 , dpt .usm_ndarray ):
1024
+ src2 = o2
1025
+ if (
1026
+ ti ._array_overlap (o2 , o1 )
1027
+ and not ti ._same_logical_tensors (o2 , o1 )
1028
+ and buf_dt is None
1029
+ ):
1030
+ buf_dt = o2_dtype
1031
+ else :
1032
+ src2 = dpt .asarray (o2 , dtype = o2_dtype , sycl_queue = exec_q )
1033
+ if buf_dt is None :
1034
+ if src2 .shape != res_shape :
1035
+ src2 = dpt .broadcast_to (src2 , res_shape )
1036
+ dep_evs = _manager .submitted_events
1037
+ ht_ , comp_ev = self .binary_inplace_fn_ (
1038
+ lhs = o1 ,
1039
+ rhs = src2 ,
1040
+ sycl_queue = exec_q ,
1041
+ depends = dep_evs ,
1042
+ )
1043
+ _manager .add_event_pair (ht_ , comp_ev )
1044
+ else :
1045
+ buf = dpt .empty_like (src2 , dtype = buf_dt )
1046
+ dep_evs = _manager .submitted_events
1047
+ (
1048
+ ht_copy_ev ,
1049
+ copy_ev ,
1050
+ ) = ti ._copy_usm_ndarray_into_usm_ndarray (
1051
+ src = src2 ,
1052
+ dst = buf ,
1053
+ sycl_queue = exec_q ,
1054
+ depends = dep_evs ,
1055
+ )
1056
+ _manager .add_event_pair (ht_copy_ev , copy_ev )
1057
+
1058
+ buf = dpt .broadcast_to (buf , res_shape )
1059
+ ht_ , bf_ev = self .binary_inplace_fn_ (
1060
+ lhs = o1 ,
1061
+ rhs = buf ,
1062
+ sycl_queue = exec_q ,
1063
+ depends = [copy_ev ],
1064
+ )
1065
+ _manager .add_event_pair (ht_ , bf_ev )
1066
+
1067
+ return o1
0 commit comments