diff --git a/dpctl/tensor/_clip.py b/dpctl/tensor/_clip.py index eeed87b404..f2bc326e82 100644 --- a/dpctl/tensor/_clip.py +++ b/dpctl/tensor/_clip.py @@ -24,21 +24,24 @@ _empty_like_triple_orderK, ) from dpctl.tensor._elementwise_common import ( - WeakBooleanType, - WeakComplexType, - WeakFloatingType, - WeakIntegralType, _get_dtype, _get_queue_usm_type, _get_shape, - _strong_dtype_num_kind, _validate_dtype, - _weak_type_num_kind, ) from dpctl.tensor._manipulation_functions import _broadcast_shape_impl from dpctl.tensor._type_utils import _can_cast, _to_device_supported_dtype from dpctl.utils import ExecutionPlacementError +from ._type_utils import ( + WeakBooleanType, + WeakComplexType, + WeakFloatingType, + WeakIntegralType, + _strong_dtype_num_kind, + _weak_type_num_kind, +) + def _resolve_one_strong_two_weak_types(st_dtype, dtype1, dtype2, dev): "Resolves weak data types per NEP-0050," diff --git a/dpctl/tensor/_elementwise_common.py b/dpctl/tensor/_elementwise_common.py index 8369c8af3e..80abc9baa4 100644 --- a/dpctl/tensor/_elementwise_common.py +++ b/dpctl/tensor/_elementwise_common.py @@ -28,11 +28,16 @@ from ._copy_utils import _empty_like_orderK, _empty_like_pair_orderK from ._type_utils import ( + WeakBooleanType, + WeakComplexType, + WeakFloatingType, + WeakIntegralType, _acceptance_fn_default_binary, _acceptance_fn_default_unary, _all_data_types, _find_buf_dtype, _find_buf_dtype2, + _resolve_weak_types, _to_device_supported_dtype, ) @@ -286,46 +291,6 @@ def _get_queue_usm_type(o): return None, None -class WeakBooleanType: - "Python type representing type of Python boolean objects" - - def __init__(self, o): - self.o_ = o - - def get(self): - return self.o_ - - -class WeakIntegralType: - "Python type representing type of Python integral objects" - - def __init__(self, o): - self.o_ = o - - def get(self): - return self.o_ - - -class WeakFloatingType: - """Python type representing type of Python floating point objects""" - - def __init__(self, o): - self.o_ = o - - def get(self): - return self.o_ - - -class WeakComplexType: - """Python type representing type of Python complex floating point objects""" - - def __init__(self, o): - self.o_ = o - - def get(self): - return self.o_ - - def _get_dtype(o, dev): if isinstance(o, dpt.usm_ndarray): return o.dtype @@ -375,87 +340,6 @@ def _validate_dtype(dt) -> bool: ) -def _weak_type_num_kind(o): - _map = {"?": 0, "i": 1, "f": 2, "c": 3} - if isinstance(o, WeakBooleanType): - return _map["?"] - if isinstance(o, WeakIntegralType): - return _map["i"] - if isinstance(o, WeakFloatingType): - return _map["f"] - if isinstance(o, WeakComplexType): - return _map["c"] - raise TypeError( - f"Unexpected type {o} while expecting " - "`WeakBooleanType`, `WeakIntegralType`," - "`WeakFloatingType`, or `WeakComplexType`." - ) - - -def _strong_dtype_num_kind(o): - _map = {"b": 0, "i": 1, "u": 1, "f": 2, "c": 3} - if not isinstance(o, dpt.dtype): - raise TypeError - k = o.kind - if k in _map: - return _map[k] - raise ValueError(f"Unrecognized kind {k} for dtype {o}") - - -def _resolve_weak_types(o1_dtype, o2_dtype, dev): - "Resolves weak data type per NEP-0050" - if isinstance( - o1_dtype, - (WeakBooleanType, WeakIntegralType, WeakFloatingType, WeakComplexType), - ): - if isinstance( - o2_dtype, - ( - WeakBooleanType, - WeakIntegralType, - WeakFloatingType, - WeakComplexType, - ), - ): - raise ValueError - o1_kind_num = _weak_type_num_kind(o1_dtype) - o2_kind_num = _strong_dtype_num_kind(o2_dtype) - if o1_kind_num > o2_kind_num: - if isinstance(o1_dtype, WeakIntegralType): - return dpt.dtype(ti.default_device_int_type(dev)), o2_dtype - if isinstance(o1_dtype, WeakComplexType): - if o2_dtype is dpt.float16 or o2_dtype is dpt.float32: - return dpt.complex64, o2_dtype - return ( - _to_device_supported_dtype(dpt.complex128, dev), - o2_dtype, - ) - return _to_device_supported_dtype(dpt.float64, dev), o2_dtype - else: - return o2_dtype, o2_dtype - elif isinstance( - o2_dtype, - (WeakBooleanType, WeakIntegralType, WeakFloatingType, WeakComplexType), - ): - o1_kind_num = _strong_dtype_num_kind(o1_dtype) - o2_kind_num = _weak_type_num_kind(o2_dtype) - if o2_kind_num > o1_kind_num: - if isinstance(o2_dtype, WeakIntegralType): - return o1_dtype, dpt.dtype(ti.default_device_int_type(dev)) - if isinstance(o2_dtype, WeakComplexType): - if o1_dtype is dpt.float16 or o1_dtype is dpt.float32: - return o1_dtype, dpt.complex64 - return o1_dtype, _to_device_supported_dtype(dpt.complex128, dev) - return ( - o1_dtype, - _to_device_supported_dtype(dpt.float64, dev), - ) - else: - return o1_dtype, o1_dtype - else: - return o1_dtype, o2_dtype - - def _get_shape(o): if isinstance(o, dpt.usm_ndarray): return o.shape diff --git a/dpctl/tensor/_type_utils.py b/dpctl/tensor/_type_utils.py index c1f6027ccf..3021db1841 100644 --- a/dpctl/tensor/_type_utils.py +++ b/dpctl/tensor/_type_utils.py @@ -252,6 +252,127 @@ def _find_buf_dtype2(arg1_dtype, arg2_dtype, query_fn, sycl_dev, acceptance_fn): return None, None, None +class WeakBooleanType: + "Python type representing type of Python boolean objects" + + def __init__(self, o): + self.o_ = o + + def get(self): + return self.o_ + + +class WeakIntegralType: + "Python type representing type of Python integral objects" + + def __init__(self, o): + self.o_ = o + + def get(self): + return self.o_ + + +class WeakFloatingType: + """Python type representing type of Python floating point objects""" + + def __init__(self, o): + self.o_ = o + + def get(self): + return self.o_ + + +class WeakComplexType: + """Python type representing type of Python complex floating point objects""" + + def __init__(self, o): + self.o_ = o + + def get(self): + return self.o_ + + +def _weak_type_num_kind(o): + _map = {"?": 0, "i": 1, "f": 2, "c": 3} + if isinstance(o, WeakBooleanType): + return _map["?"] + if isinstance(o, WeakIntegralType): + return _map["i"] + if isinstance(o, WeakFloatingType): + return _map["f"] + if isinstance(o, WeakComplexType): + return _map["c"] + raise TypeError( + f"Unexpected type {o} while expecting " + "`WeakBooleanType`, `WeakIntegralType`," + "`WeakFloatingType`, or `WeakComplexType`." + ) + + +def _strong_dtype_num_kind(o): + _map = {"b": 0, "i": 1, "u": 1, "f": 2, "c": 3} + if not isinstance(o, dpt.dtype): + raise TypeError + k = o.kind + if k in _map: + return _map[k] + raise ValueError(f"Unrecognized kind {k} for dtype {o}") + + +def _resolve_weak_types(o1_dtype, o2_dtype, dev): + "Resolves weak data type per NEP-0050" + if isinstance( + o1_dtype, + (WeakBooleanType, WeakIntegralType, WeakFloatingType, WeakComplexType), + ): + if isinstance( + o2_dtype, + ( + WeakBooleanType, + WeakIntegralType, + WeakFloatingType, + WeakComplexType, + ), + ): + raise ValueError + o1_kind_num = _weak_type_num_kind(o1_dtype) + o2_kind_num = _strong_dtype_num_kind(o2_dtype) + if o1_kind_num > o2_kind_num: + if isinstance(o1_dtype, WeakIntegralType): + return dpt.dtype(ti.default_device_int_type(dev)), o2_dtype + if isinstance(o1_dtype, WeakComplexType): + if o2_dtype is dpt.float16 or o2_dtype is dpt.float32: + return dpt.complex64, o2_dtype + return ( + _to_device_supported_dtype(dpt.complex128, dev), + o2_dtype, + ) + return _to_device_supported_dtype(dpt.float64, dev), o2_dtype + else: + return o2_dtype, o2_dtype + elif isinstance( + o2_dtype, + (WeakBooleanType, WeakIntegralType, WeakFloatingType, WeakComplexType), + ): + o1_kind_num = _strong_dtype_num_kind(o1_dtype) + o2_kind_num = _weak_type_num_kind(o2_dtype) + if o2_kind_num > o1_kind_num: + if isinstance(o2_dtype, WeakIntegralType): + return o1_dtype, dpt.dtype(ti.default_device_int_type(dev)) + if isinstance(o2_dtype, WeakComplexType): + if o1_dtype is dpt.float16 or o1_dtype is dpt.float32: + return o1_dtype, dpt.complex64 + return o1_dtype, _to_device_supported_dtype(dpt.complex128, dev) + return ( + o1_dtype, + _to_device_supported_dtype(dpt.float64, dev), + ) + else: + return o1_dtype, o1_dtype + else: + return o1_dtype, o2_dtype + + class finfo_object: """ `numpy.finfo` subclass which returns Python floating-point scalars for @@ -407,10 +528,19 @@ def result_type(*arrays_and_dtypes): """ dtypes = [] devices = [] + weak_dtypes = [] for arg_i in arrays_and_dtypes: if isinstance(arg_i, dpt.usm_ndarray): devices.append(arg_i.sycl_device) dtypes.append(arg_i.dtype) + elif isinstance(arg_i, int): + weak_dtypes.append(WeakIntegralType(arg_i)) + elif isinstance(arg_i, float): + weak_dtypes.append(WeakFloatingType(arg_i)) + elif isinstance(arg_i, complex): + weak_dtypes.append(WeakComplexType(arg_i)) + elif isinstance(arg_i, bool): + weak_dtypes.append(WeakBooleanType(arg_i)) else: dt = dpt.dtype(arg_i) _supported_dtype([dt]) @@ -418,6 +548,7 @@ def result_type(*arrays_and_dtypes): has_fp16 = True has_fp64 = True + target_dev = None if devices: inspected = False for d in devices: @@ -435,17 +566,28 @@ def result_type(*arrays_and_dtypes): else: has_fp16 = d.has_aspect_fp16 has_fp64 = d.has_aspect_fp64 + target_dev = d inspected = True if not (has_fp16 and has_fp64): for dt in dtypes: if not _dtype_supported_by_device_impl(dt, has_fp16, has_fp64): - raise ValueError(f"Argument {dt} is not supported by ") + raise ValueError( + f"Argument {dt} is not supported by the device" + ) res_dt = np.result_type(*dtypes) res_dt = _to_device_supported_dtype_impl(res_dt, has_fp16, has_fp64) - return res_dt + for wdt in weak_dtypes: + pair = _resolve_weak_types(wdt, res_dt, target_dev) + res_dt = np.result_type(*pair) + res_dt = _to_device_supported_dtype_impl(res_dt, has_fp16, has_fp64) + else: + res_dt = np.result_type(*dtypes) + if weak_dtypes: + weak_dt_obj = [wdt.get() for wdt in weak_dtypes] + res_dt = np.result_type(res_dt, *weak_dt_obj) - return np.result_type(*dtypes) + return res_dt def iinfo(dtype): @@ -528,8 +670,15 @@ def _supported_dtype(dtypes): "_acceptance_fn_reciprocal", "_acceptance_fn_default_binary", "_acceptance_fn_divide", + "_resolve_weak_types", + "_weak_type_num_kind", + "_strong_dtype_num_kind", "can_cast", "finfo", "iinfo", "result_type", + "WeakBooleanType", + "WeakIntegralType", + "WeakFloatingType", + "WeakComplexType", ] diff --git a/dpctl/tests/elementwise/test_type_utils.py b/dpctl/tests/elementwise/test_type_utils.py index 9407195bd9..25f986806c 100644 --- a/dpctl/tests/elementwise/test_type_utils.py +++ b/dpctl/tests/elementwise/test_type_utils.py @@ -14,6 +14,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import numpy as np import pytest import dpctl @@ -236,3 +237,44 @@ def test_can_cast_device(): # can't safely cast inexact type to inexact type of lesser precision assert not tu._can_cast(dpt.float32, dpt.float16, True, False) assert not tu._can_cast(dpt.float64, dpt.float32, False, True) + + +def test_acceptance_fns(): + """Check type promotion acceptance functions""" + dev = dpctl.SyclDevice() + assert tu._acceptance_fn_reciprocal( + dpt.float32, dpt.float32, dpt.float32, dev + ) + + +def test_weak_types(): + wbt = tu.WeakBooleanType(True) + assert wbt.get() + assert tu._weak_type_num_kind(wbt) == 0 + + wit = tu.WeakIntegralType(7) + assert wit.get() == 7 + assert tu._weak_type_num_kind(wit) == 1 + + wft = tu.WeakFloatingType(3.1415926) + assert wft.get() == 3.1415926 + assert tu._weak_type_num_kind(wft) == 2 + + wct = tu.WeakComplexType(2.0 + 3.0j) + assert wct.get() == 2 + 3j + assert tu._weak_type_num_kind(wct) == 3 + + +def test_arg_validation(): + with pytest.raises(TypeError): + tu._weak_type_num_kind(dict()) + + with pytest.raises(TypeError): + tu._strong_dtype_num_kind(Ellipsis) + + with pytest.raises(ValueError): + tu._strong_dtype_num_kind(np.dtype("O")) + + wt = tu.WeakFloatingType(2.0) + with pytest.raises(ValueError): + tu._resolve_weak_types(wt, wt, None) diff --git a/dpctl/tests/test_usm_ndarray_manipulation.py b/dpctl/tests/test_usm_ndarray_manipulation.py index 54c2c2380a..5a7799994f 100644 --- a/dpctl/tests/test_usm_ndarray_manipulation.py +++ b/dpctl/tests/test_usm_ndarray_manipulation.py @@ -935,11 +935,39 @@ def test_can_cast(): def test_result_type(): q = get_queue_or_skip() - X = [dpt.ones((2), dtype=dpt.int16, sycl_queue=q), dpt.int32, "int64"] - X_np = [np.ones((2), dtype=np.int16), np.int32, "int64"] + usm_ar = dpt.ones((2), dtype=dpt.int16, sycl_queue=q) + np_ar = dpt.asnumpy(usm_ar) + + X = [usm_ar, dpt.int32, "int64", usm_ar] + X_np = [np_ar, np.int32, "int64", np_ar] + + assert dpt.result_type(*X) == np.result_type(*X_np) + + X = [usm_ar, dpt.int32, "int64", True] + X_np = [np_ar, np.int32, "int64", True] + + assert dpt.result_type(*X) == np.result_type(*X_np) + + X = [usm_ar, dpt.int32, "int64", 2] + X_np = [np_ar, np.int32, "int64", 2] assert dpt.result_type(*X) == np.result_type(*X_np) + X = [dpt.int32, "int64", 2] + X_np = [np.int32, "int64", 2] + + assert dpt.result_type(*X) == np.result_type(*X_np) + + X = [usm_ar, dpt.int32, "int64", 2.0] + X_np = [np_ar, np.int32, "int64", 2.0] + + assert dpt.result_type(*X).kind == np.result_type(*X_np).kind + + X = [usm_ar, dpt.int32, "int64", 2.0 + 1j] + X_np = [np_ar, np.int32, "int64", 2.0 + 1j] + + assert dpt.result_type(*X).kind == np.result_type(*X_np).kind + def test_swapaxes_1d(): get_queue_or_skip()