Skip to content

Commit 36a7cd7

Browse files
authored
Merge pull request #1260 from IntelPython/floor-divide-negative-fix
Round signed integers toward negative infinity in dpctl.tensor.floor_divide
2 parents 7c1d147 + 3f436b2 commit 36a7cd7

File tree

2 files changed

+93
-15
lines changed

2 files changed

+93
-15
lines changed

dpctl/tensor/libtensor/include/kernels/elementwise_functions/floor_divide.hpp

Lines changed: 46 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -53,17 +53,28 @@ template <typename argT1, typename argT2, typename resT>
5353
struct FloorDivideFunctor
5454
{
5555

56-
using supports_sg_loadstore =
57-
std::negation<std::disjunction<tu_ns::is_complex<argT1>,
58-
tu_ns::is_complex<argT2>>>; // TRUE
56+
using supports_sg_loadstore = std::negation<
57+
std::disjunction<tu_ns::is_complex<argT1>, tu_ns::is_complex<argT2>>>;
5958
using supports_vec = std::negation<
6059
std::disjunction<tu_ns::is_complex<argT1>, tu_ns::is_complex<argT2>>>;
6160

6261
resT operator()(const argT1 &in1, const argT2 &in2)
6362
{
6463
auto tmp = in1 / in2;
6564
if constexpr (std::is_integral_v<decltype(tmp)>) {
66-
return tmp;
65+
if constexpr (std::is_unsigned_v<decltype(tmp)>) {
66+
return (in2 == argT2(0)) ? resT(0) : tmp;
67+
}
68+
else {
69+
if (in2 == argT2(0)) {
70+
return resT(0);
71+
}
72+
else {
73+
auto rem = in1 % in2;
74+
auto corr = (rem != 0 && ((rem < 0) != (in2 < 0)));
75+
return (tmp - corr);
76+
}
77+
}
6778
}
6879
else {
6980
return sycl::floor(tmp);
@@ -75,17 +86,37 @@ struct FloorDivideFunctor
7586
const sycl::vec<argT2, vec_sz> &in2)
7687
{
7788
auto tmp = in1 / in2;
78-
if constexpr (std::is_same_v<resT,
79-
typename decltype(tmp)::element_type> &&
80-
std::is_integral_v<resT>)
81-
{
82-
return tmp;
83-
}
84-
else if constexpr (std::is_integral_v<typename decltype(
85-
tmp)::element_type>) {
86-
using dpctl::tensor::type_utils::vec_cast;
87-
return vec_cast<resT, typename decltype(tmp)::element_type, vec_sz>(
88-
tmp);
89+
using tmpT = typename decltype(tmp)::element_type;
90+
if constexpr (std::is_integral_v<tmpT>) {
91+
if constexpr (std::is_signed_v<tmpT>) {
92+
auto rem_tmp = in1 % in2;
93+
#pragma unroll
94+
for (int i = 0; i < vec_sz; ++i) {
95+
if (in2[i] == argT2(0)) {
96+
tmp[i] = tmpT(0);
97+
}
98+
else {
99+
tmpT corr = (rem_tmp[i] != 0 &&
100+
((rem_tmp[i] < 0) != (in2[i] < 0)));
101+
tmp[i] -= corr;
102+
}
103+
}
104+
}
105+
else {
106+
#pragma unroll
107+
for (int i = 0; i < vec_sz; ++i) {
108+
if (in2[i] == argT2(0)) {
109+
tmp[i] = tmpT(0);
110+
}
111+
}
112+
}
113+
if constexpr (std::is_same_v<resT, tmpT>) {
114+
return tmp;
115+
}
116+
else {
117+
using dpctl::tensor::type_utils::vec_cast;
118+
return vec_cast<resT, tmpT, vec_sz>(tmp);
119+
}
89120
}
90121
else {
91122
sycl::vec<resT, vec_sz> res = sycl::floor(tmp);

dpctl/tests/elementwise/test_floor_divide.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -186,3 +186,50 @@ def __sycl_usm_array_interface__(self):
186186
c = Canary()
187187
with pytest.raises(ValueError):
188188
dpt.floor_divide(a, c)
189+
190+
191+
def test_floor_divide_gh_1247():
192+
get_queue_or_skip()
193+
194+
x = dpt.ones(1, dtype="i4")
195+
res = dpt.floor_divide(x, -2)
196+
np.testing.assert_array_equal(
197+
dpt.asnumpy(res), np.full(res.shape, -1, dtype=res.dtype)
198+
)
199+
200+
x = dpt.full(1, -1, dtype="i4")
201+
res = dpt.floor_divide(x, 2)
202+
np.testing.assert_array_equal(
203+
dpt.asnumpy(res), np.full(res.shape, -1, dtype=res.dtype)
204+
)
205+
206+
# attempt to invoke sycl::vec overload using a larger array
207+
x = dpt.arange(-64, 65, 1, dtype="i4")
208+
np.testing.assert_array_equal(
209+
dpt.asnumpy(dpt.floor_divide(x, 3)), np.floor_divide(dpt.asnumpy(x), 3)
210+
)
211+
np.testing.assert_array_equal(
212+
dpt.asnumpy(dpt.floor_divide(x, -3)),
213+
np.floor_divide(dpt.asnumpy(x), -3),
214+
)
215+
216+
217+
@pytest.mark.parametrize("dtype", _no_complex_dtypes[1:9])
218+
def test_floor_divide_integer_zero(dtype):
219+
q = get_queue_or_skip()
220+
skip_if_dtype_not_supported(dtype, q)
221+
222+
x = dpt.arange(10, dtype=dtype, sycl_queue=q)
223+
y = dpt.zeros_like(x, sycl_queue=q)
224+
res = dpt.floor_divide(x, y)
225+
np.testing.assert_array_equal(
226+
dpt.asnumpy(res), np.zeros(x.shape, dtype=res.dtype)
227+
)
228+
229+
# attempt to invoke sycl::vec overload using a larger array
230+
x = dpt.arange(129, dtype=dtype, sycl_queue=q)
231+
y = dpt.zeros_like(x, sycl_queue=q)
232+
res = dpt.floor_divide(x, y)
233+
np.testing.assert_array_equal(
234+
dpt.asnumpy(res), np.zeros(x.shape, dtype=res.dtype)
235+
)

0 commit comments

Comments
 (0)