diff --git a/CHANGELOG.md b/CHANGELOG.md index 1d2fdd39a8..193ecae948 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -12,6 +12,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Changed +* Support for Boolean data-type is added to `dpctl.tensor.ceil`, `dpctl.tensor.floor`, and `dpctl.tensor.trunc` [gh-2033](https://github.com/IntelPython/dpctl/pull/2033) + ### Fixed ## [0.19.0] - Feb. 26, 2025 diff --git a/dpctl/tensor/_elementwise_funcs.py b/dpctl/tensor/_elementwise_funcs.py index 001a53ab35..4731ec8631 100644 --- a/dpctl/tensor/_elementwise_funcs.py +++ b/dpctl/tensor/_elementwise_funcs.py @@ -528,7 +528,7 @@ Args: x (usm_ndarray): - Input array, expected to have a real-valued data type. + Input array, expected to have a boolean or real-valued data type. out (Union[usm_ndarray, None], optional): Output array to populate. Array must have the correct shape and the expected data type. @@ -767,7 +767,7 @@ Args: x (usm_ndarray): - Input array, expected to have a real-valued data type. + Input array, expected to have a boolean or real-valued data type. out (Union[usm_ndarray, None], optional): Output array to populate. Array must have the correct shape and the expected data type. @@ -2017,7 +2017,7 @@ Args: x (usm_ndarray): - Input array, expected to have a real-valued data type. + Input array, expected to have a boolean or real-valued data type. out (Union[usm_ndarray, None], optional): Output array to populate. Array must have the correct shape and the expected data type. diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/ceil.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/ceil.hpp index 1328df3f4b..c587ba6767 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/ceil.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/ceil.hpp @@ -99,7 +99,8 @@ using CeilStridedFunctor = elementwise_common:: template struct CeilOutputType { using value_type = - typename std::disjunction, + typename std::disjunction, + td_ns::TypeMapResultEntry, td_ns::TypeMapResultEntry, td_ns::TypeMapResultEntry, td_ns::TypeMapResultEntry, diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/floor.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/floor.hpp index aaa81b77b9..5bc6b888ef 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/floor.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/floor.hpp @@ -99,7 +99,8 @@ using FloorStridedFunctor = elementwise_common:: template struct FloorOutputType { using value_type = - typename std::disjunction, + typename std::disjunction, + td_ns::TypeMapResultEntry, td_ns::TypeMapResultEntry, td_ns::TypeMapResultEntry, td_ns::TypeMapResultEntry, diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/trunc.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/trunc.hpp index 008c5f59b1..4c776e3560 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/trunc.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/trunc.hpp @@ -96,7 +96,8 @@ using TruncStridedFunctor = elementwise_common:: template struct TruncOutputType { using value_type = - typename std::disjunction, + typename std::disjunction, + td_ns::TypeMapResultEntry, td_ns::TypeMapResultEntry, td_ns::TypeMapResultEntry, td_ns::TypeMapResultEntry, diff --git a/dpctl/tests/elementwise/test_floor_ceil_trunc.py b/dpctl/tests/elementwise/test_floor_ceil_trunc.py index a6bf956a78..20bb739b2c 100644 --- a/dpctl/tests/elementwise/test_floor_ceil_trunc.py +++ b/dpctl/tests/elementwise/test_floor_ceil_trunc.py @@ -24,13 +24,13 @@ import dpctl.tensor as dpt from dpctl.tests.helper import get_queue_or_skip, skip_if_dtype_not_supported -from .utils import _map_to_device_dtype, _real_value_dtypes +from .utils import _map_to_device_dtype, _no_complex_dtypes, _real_value_dtypes _all_funcs = [(np.floor, dpt.floor), (np.ceil, dpt.ceil), (np.trunc, dpt.trunc)] @pytest.mark.parametrize("dpt_call", [dpt.floor, dpt.ceil, dpt.trunc]) -@pytest.mark.parametrize("dtype", _real_value_dtypes) +@pytest.mark.parametrize("dtype", _no_complex_dtypes) def test_floor_ceil_trunc_out_type(dpt_call, dtype): q = get_queue_or_skip() skip_if_dtype_not_supported(dtype, q) @@ -69,7 +69,7 @@ def test_floor_ceil_trunc_usm_type(np_call, dpt_call, usm_type): @pytest.mark.parametrize("np_call, dpt_call", _all_funcs) -@pytest.mark.parametrize("dtype", _real_value_dtypes) +@pytest.mark.parametrize("dtype", _no_complex_dtypes) def test_floor_ceil_trunc_order(np_call, dpt_call, dtype): q = get_queue_or_skip() skip_if_dtype_not_supported(dtype, q) @@ -102,7 +102,7 @@ def test_floor_ceil_trunc_error_dtype(dpt_call, dtype): @pytest.mark.parametrize("np_call, dpt_call", _all_funcs) -@pytest.mark.parametrize("dtype", _real_value_dtypes) +@pytest.mark.parametrize("dtype", _no_complex_dtypes) def test_floor_ceil_trunc_contig(np_call, dpt_call, dtype): q = get_queue_or_skip() skip_if_dtype_not_supported(dtype, q) @@ -123,7 +123,7 @@ def test_floor_ceil_trunc_contig(np_call, dpt_call, dtype): @pytest.mark.parametrize("np_call, dpt_call", _all_funcs) -@pytest.mark.parametrize("dtype", _real_value_dtypes) +@pytest.mark.parametrize("dtype", _no_complex_dtypes) def test_floor_ceil_trunc_strided(np_call, dpt_call, dtype): q = get_queue_or_skip() skip_if_dtype_not_supported(dtype, q)