Skip to content

Commit d6779ee

Browse files
committed
add support for boolean dtypes for dpt.ceil, dpt.floor, and dpt.trunc
1 parent 63f5129 commit d6779ee

File tree

4 files changed

+11
-8
lines changed

4 files changed

+11
-8
lines changed

dpctl/tensor/libtensor/include/kernels/elementwise_functions/ceil.hpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,8 @@ using CeilStridedFunctor = elementwise_common::
9999
template <typename T> struct CeilOutputType
100100
{
101101
using value_type =
102-
typename std::disjunction<td_ns::TypeMapResultEntry<T, std::uint8_t>,
102+
typename std::disjunction<td_ns::TypeMapResultEntry<T, bool>,
103+
td_ns::TypeMapResultEntry<T, std::uint8_t>,
103104
td_ns::TypeMapResultEntry<T, std::uint16_t>,
104105
td_ns::TypeMapResultEntry<T, std::uint32_t>,
105106
td_ns::TypeMapResultEntry<T, std::uint64_t>,

dpctl/tensor/libtensor/include/kernels/elementwise_functions/floor.hpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,8 @@ using FloorStridedFunctor = elementwise_common::
9999
template <typename T> struct FloorOutputType
100100
{
101101
using value_type =
102-
typename std::disjunction<td_ns::TypeMapResultEntry<T, std::uint8_t>,
102+
typename std::disjunction<td_ns::TypeMapResultEntry<T, bool>,
103+
td_ns::TypeMapResultEntry<T, std::uint8_t>,
103104
td_ns::TypeMapResultEntry<T, std::uint16_t>,
104105
td_ns::TypeMapResultEntry<T, std::uint32_t>,
105106
td_ns::TypeMapResultEntry<T, std::uint64_t>,

dpctl/tensor/libtensor/include/kernels/elementwise_functions/trunc.hpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,8 @@ using TruncStridedFunctor = elementwise_common::
9696
template <typename T> struct TruncOutputType
9797
{
9898
using value_type =
99-
typename std::disjunction<td_ns::TypeMapResultEntry<T, std::uint8_t>,
99+
typename std::disjunction<td_ns::TypeMapResultEntry<T, bool>,
100+
td_ns::TypeMapResultEntry<T, std::uint8_t>,
100101
td_ns::TypeMapResultEntry<T, std::uint16_t>,
101102
td_ns::TypeMapResultEntry<T, std::uint32_t>,
102103
td_ns::TypeMapResultEntry<T, std::uint64_t>,

dpctl/tests/elementwise/test_floor_ceil_trunc.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -24,13 +24,13 @@
2424
import dpctl.tensor as dpt
2525
from dpctl.tests.helper import get_queue_or_skip, skip_if_dtype_not_supported
2626

27-
from .utils import _map_to_device_dtype, _real_value_dtypes
27+
from .utils import _map_to_device_dtype, _no_complex_dtypes, _real_value_dtypes
2828

2929
_all_funcs = [(np.floor, dpt.floor), (np.ceil, dpt.ceil), (np.trunc, dpt.trunc)]
3030

3131

3232
@pytest.mark.parametrize("dpt_call", [dpt.floor, dpt.ceil, dpt.trunc])
33-
@pytest.mark.parametrize("dtype", _real_value_dtypes)
33+
@pytest.mark.parametrize("dtype", _no_complex_dtypes)
3434
def test_floor_ceil_trunc_out_type(dpt_call, dtype):
3535
q = get_queue_or_skip()
3636
skip_if_dtype_not_supported(dtype, q)
@@ -69,7 +69,7 @@ def test_floor_ceil_trunc_usm_type(np_call, dpt_call, usm_type):
6969

7070

7171
@pytest.mark.parametrize("np_call, dpt_call", _all_funcs)
72-
@pytest.mark.parametrize("dtype", _real_value_dtypes)
72+
@pytest.mark.parametrize("dtype", _no_complex_dtypes)
7373
def test_floor_ceil_trunc_order(np_call, dpt_call, dtype):
7474
q = get_queue_or_skip()
7575
skip_if_dtype_not_supported(dtype, q)
@@ -102,7 +102,7 @@ def test_floor_ceil_trunc_error_dtype(dpt_call, dtype):
102102

103103

104104
@pytest.mark.parametrize("np_call, dpt_call", _all_funcs)
105-
@pytest.mark.parametrize("dtype", _real_value_dtypes)
105+
@pytest.mark.parametrize("dtype", _no_complex_dtypes)
106106
def test_floor_ceil_trunc_contig(np_call, dpt_call, dtype):
107107
q = get_queue_or_skip()
108108
skip_if_dtype_not_supported(dtype, q)
@@ -123,7 +123,7 @@ def test_floor_ceil_trunc_contig(np_call, dpt_call, dtype):
123123

124124

125125
@pytest.mark.parametrize("np_call, dpt_call", _all_funcs)
126-
@pytest.mark.parametrize("dtype", _real_value_dtypes)
126+
@pytest.mark.parametrize("dtype", _no_complex_dtypes)
127127
def test_floor_ceil_trunc_strided(np_call, dpt_call, dtype):
128128
q = get_queue_or_skip()
129129
skip_if_dtype_not_supported(dtype, q)

0 commit comments

Comments
 (0)