Skip to content

Commit 29b214d

Browse files
committed
floor_divide now rounds signed integers toward negative infinity
- Resolves #1247
1 parent 179ce15 commit 29b214d

File tree

2 files changed

+39
-3
lines changed

2 files changed

+39
-3
lines changed

dpctl/tensor/libtensor/include/kernels/elementwise_functions/floor_divide.hpp

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -56,14 +56,25 @@ struct FloorDivideFunctor
5656
using supports_sg_loadstore =
5757
std::negation<std::disjunction<tu_ns::is_complex<argT1>,
5858
tu_ns::is_complex<argT2>>>; // TRUE
59-
using supports_vec = std::negation<
60-
std::disjunction<tu_ns::is_complex<argT1>, tu_ns::is_complex<argT2>>>;
59+
using supports_vec = std::negation<std::disjunction<
60+
tu_ns::is_complex<argT1>,
61+
tu_ns::is_complex<argT2>,
62+
std::conjunction<std::is_integral<argT1>, std::is_signed<argT1>>,
63+
std::conjunction<std::is_integral<argT2>, std::is_signed<argT2>>>>;
64+
// no vec overload for signed integers to avoid loop
6165

6266
resT operator()(const argT1 &in1, const argT2 &in2)
6367
{
6468
auto tmp = in1 / in2;
6569
if constexpr (std::is_integral_v<decltype(tmp)>) {
66-
return tmp;
70+
if constexpr (std::is_unsigned_v<decltype(tmp)>) {
71+
return tmp;
72+
}
73+
else {
74+
auto rem = in1 % in2;
75+
auto corr = (rem != 0 && ((rem < 0) != (in2 < 0)));
76+
return (tmp - corr);
77+
}
6778
}
6879
else {
6980
return sycl::floor(tmp);

dpctl/tests/elementwise/test_floor_divide.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -186,3 +186,28 @@ def __sycl_usm_array_interface__(self):
186186
c = Canary()
187187
with pytest.raises(ValueError):
188188
dpt.floor_divide(a, c)
189+
190+
191+
def test_floor_divide_gh_1247():
192+
get_queue_or_skip()
193+
194+
x = dpt.ones(1, dtype="i4")
195+
res = dpt.floor_divide(x, -2)
196+
np.testing.assert_array_equal(
197+
dpt.asnumpy(res), np.full(res.shape, -1, dtype=res.dtype)
198+
)
199+
200+
x = dpt.full(1, -1, dtype="i4")
201+
res = dpt.floor_divide(x, 2)
202+
np.testing.assert_array_equal(
203+
dpt.asnumpy(res), np.full(res.shape, -1, dtype=res.dtype)
204+
)
205+
206+
x = dpt.arange(-5, 6, 1, dtype="i4")
207+
np.testing.assert_array_equal(
208+
dpt.asnumpy(dpt.floor_divide(x, 3)), np.floor_divide(dpt.asnumpy(x), 3)
209+
)
210+
np.testing.assert_array_equal(
211+
dpt.asnumpy(dpt.floor_divide(x, -3)),
212+
np.floor_divide(dpt.asnumpy(x), -3),
213+
)

0 commit comments

Comments
 (0)