Skip to content

Commit 3f436b2

Browse files
committed
Integer division by 0 in floor_divide now handled properly
- Also fully enables sycl::vec overload for floor_divide
- Added a test for integer division by 0 behavior
1 parent 7b702bc commit 3f436b2

File tree

2 files changed

+67
-25
lines changed

2 files changed

+67
-25
lines changed

dpctl/tensor/libtensor/include/kernels/elementwise_functions/floor_divide.hpp

Lines changed: 44 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -53,27 +53,27 @@ template <typename argT1, typename argT2, typename resT>
5353
struct FloorDivideFunctor
5454
{
5555

56-
using supports_sg_loadstore =
57-
std::negation<std::disjunction<tu_ns::is_complex<argT1>,
58-
tu_ns::is_complex<argT2>>>; // TRUE
59-
using supports_vec = std::negation<std::disjunction<
60-
tu_ns::is_complex<argT1>,
61-
tu_ns::is_complex<argT2>,
62-
std::conjunction<std::is_integral<argT1>, std::is_signed<argT1>>,
63-
std::conjunction<std::is_integral<argT2>, std::is_signed<argT2>>>>;
64-
// no vec overload for signed integers to avoid loop
56+
using supports_sg_loadstore = std::negation<
57+
std::disjunction<tu_ns::is_complex<argT1>, tu_ns::is_complex<argT2>>>;
58+
using supports_vec = std::negation<
59+
std::disjunction<tu_ns::is_complex<argT1>, tu_ns::is_complex<argT2>>>;
6560

6661
resT operator()(const argT1 &in1, const argT2 &in2)
6762
{
6863
auto tmp = in1 / in2;
6964
if constexpr (std::is_integral_v<decltype(tmp)>) {
7065
if constexpr (std::is_unsigned_v<decltype(tmp)>) {
71-
return tmp;
66+
return (in2 == argT2(0)) ? resT(0) : tmp;
7267
}
7368
else {
74-
auto rem = in1 % in2;
75-
auto corr = (rem != 0 && ((rem < 0) != (in2 < 0)));
76-
return (tmp - corr);
69+
if (in2 == argT2(0)) {
70+
return resT(0);
71+
}
72+
else {
73+
auto rem = in1 % in2;
74+
auto corr = (rem != 0 && ((rem < 0) != (in2 < 0)));
75+
return (tmp - corr);
76+
}
7777
}
7878
}
7979
else {
@@ -86,17 +86,37 @@ struct FloorDivideFunctor
8686
const sycl::vec<argT2, vec_sz> &in2)
8787
{
8888
auto tmp = in1 / in2;
89-
if constexpr (std::is_same_v<resT,
90-
typename decltype(tmp)::element_type> &&
91-
std::is_integral_v<resT>)
92-
{
93-
return tmp;
94-
}
95-
else if constexpr (std::is_integral_v<typename decltype(
96-
tmp)::element_type>) {
97-
using dpctl::tensor::type_utils::vec_cast;
98-
return vec_cast<resT, typename decltype(tmp)::element_type, vec_sz>(
99-
tmp);
89+
using tmpT = typename decltype(tmp)::element_type;
90+
if constexpr (std::is_integral_v<tmpT>) {
91+
if constexpr (std::is_signed_v<tmpT>) {
92+
auto rem_tmp = in1 % in2;
93+
#pragma unroll
94+
for (int i = 0; i < vec_sz; ++i) {
95+
if (in2[i] == argT2(0)) {
96+
tmp[i] = tmpT(0);
97+
}
98+
else {
99+
tmpT corr = (rem_tmp[i] != 0 &&
100+
((rem_tmp[i] < 0) != (in2[i] < 0)));
101+
tmp[i] -= corr;
102+
}
103+
}
104+
}
105+
else {
106+
#pragma unroll
107+
for (int i = 0; i < vec_sz; ++i) {
108+
if (in2[i] == argT2(0)) {
109+
tmp[i] = tmpT(0);
110+
}
111+
}
112+
}
113+
if constexpr (std::is_same_v<resT, tmpT>) {
114+
return tmp;
115+
}
116+
else {
117+
using dpctl::tensor::type_utils::vec_cast;
118+
return vec_cast<resT, tmpT, vec_sz>(tmp);
119+
}
100120
}
101121
else {
102122
sycl::vec<resT, vec_sz> res = sycl::floor(tmp);

dpctl/tests/elementwise/test_floor_divide.py

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -203,11 +203,33 @@ def test_floor_divide_gh_1247():
203203
dpt.asnumpy(res), np.full(res.shape, -1, dtype=res.dtype)
204204
)
205205

206-
x = dpt.arange(-5, 6, 1, dtype="i4")
206+
# attempt to invoke sycl::vec overload using a larger array
207+
x = dpt.arange(-64, 65, 1, dtype="i4")
207208
np.testing.assert_array_equal(
208209
dpt.asnumpy(dpt.floor_divide(x, 3)), np.floor_divide(dpt.asnumpy(x), 3)
209210
)
210211
np.testing.assert_array_equal(
211212
dpt.asnumpy(dpt.floor_divide(x, -3)),
212213
np.floor_divide(dpt.asnumpy(x), -3),
213214
)
215+
216+
217+
@pytest.mark.parametrize("dtype", _no_complex_dtypes[1:9])
218+
def test_floor_divide_integer_zero(dtype):
219+
q = get_queue_or_skip()
220+
skip_if_dtype_not_supported(dtype, q)
221+
222+
x = dpt.arange(10, dtype=dtype, sycl_queue=q)
223+
y = dpt.zeros_like(x, sycl_queue=q)
224+
res = dpt.floor_divide(x, y)
225+
np.testing.assert_array_equal(
226+
dpt.asnumpy(res), np.zeros(x.shape, dtype=res.dtype)
227+
)
228+
229+
# attempt to invoke sycl::vec overload using a larger array
230+
x = dpt.arange(129, dtype=dtype, sycl_queue=q)
231+
y = dpt.zeros_like(x, sycl_queue=q)
232+
res = dpt.floor_divide(x, y)
233+
np.testing.assert_array_equal(
234+
dpt.asnumpy(res), np.zeros(x.shape, dtype=res.dtype)
235+
)

0 commit comments

Comments
 (0)