@@ -53,27 +53,27 @@ template <typename argT1, typename argT2, typename resT>
53
53
struct FloorDivideFunctor
54
54
{
55
55
56
- using supports_sg_loadstore =
57
- std::negation<std::disjunction<tu_ns::is_complex<argT1>,
58
- tu_ns::is_complex<argT2>>>; // TRUE
59
- using supports_vec = std::negation<std::disjunction<
60
- tu_ns::is_complex<argT1>,
61
- tu_ns::is_complex<argT2>,
62
- std::conjunction<std::is_integral<argT1>, std::is_signed<argT1>>,
63
- std::conjunction<std::is_integral<argT2>, std::is_signed<argT2>>>>;
64
- // no vec overload for signed integers to avoid loop
56
+ using supports_sg_loadstore = std::negation<
57
+ std::disjunction<tu_ns::is_complex<argT1>, tu_ns::is_complex<argT2>>>;
58
+ using supports_vec = std::negation<
59
+ std::disjunction<tu_ns::is_complex<argT1>, tu_ns::is_complex<argT2>>>;
65
60
66
61
resT operator ()(const argT1 &in1, const argT2 &in2)
67
62
{
68
63
auto tmp = in1 / in2;
69
64
if constexpr (std::is_integral_v<decltype (tmp)>) {
70
65
if constexpr (std::is_unsigned_v<decltype (tmp)>) {
71
- return tmp;
66
+ return (in2 == argT2 ( 0 )) ? resT ( 0 ) : tmp;
72
67
}
73
68
else {
74
- auto rem = in1 % in2;
75
- auto corr = (rem != 0 && ((rem < 0 ) != (in2 < 0 )));
76
- return (tmp - corr);
69
+ if (in2 == argT2 (0 )) {
70
+ return resT (0 );
71
+ }
72
+ else {
73
+ auto rem = in1 % in2;
74
+ auto corr = (rem != 0 && ((rem < 0 ) != (in2 < 0 )));
75
+ return (tmp - corr);
76
+ }
77
77
}
78
78
}
79
79
else {
@@ -86,17 +86,37 @@ struct FloorDivideFunctor
86
86
const sycl::vec<argT2, vec_sz> &in2)
87
87
{
88
88
auto tmp = in1 / in2;
89
- if constexpr (std::is_same_v<resT,
90
- typename decltype (tmp)::element_type> &&
91
- std::is_integral_v<resT>)
92
- {
93
- return tmp;
94
- }
95
- else if constexpr (std::is_integral_v<typename decltype (
96
- tmp)::element_type>) {
97
- using dpctl::tensor::type_utils::vec_cast;
98
- return vec_cast<resT, typename decltype (tmp)::element_type, vec_sz>(
99
- tmp);
89
+ using tmpT = typename decltype (tmp)::element_type;
90
+ if constexpr (std::is_integral_v<tmpT>) {
91
+ if constexpr (std::is_signed_v<tmpT>) {
92
+ auto rem_tmp = in1 % in2;
93
+ #pragma unroll
94
+ for (int i = 0 ; i < vec_sz; ++i) {
95
+ if (in2[i] == argT2 (0 )) {
96
+ tmp[i] = tmpT (0 );
97
+ }
98
+ else {
99
+ tmpT corr = (rem_tmp[i] != 0 &&
100
+ ((rem_tmp[i] < 0 ) != (in2[i] < 0 )));
101
+ tmp[i] -= corr;
102
+ }
103
+ }
104
+ }
105
+ else {
106
+ #pragma unroll
107
+ for (int i = 0 ; i < vec_sz; ++i) {
108
+ if (in2[i] == argT2 (0 )) {
109
+ tmp[i] = tmpT (0 );
110
+ }
111
+ }
112
+ }
113
+ if constexpr (std::is_same_v<resT, tmpT>) {
114
+ return tmp;
115
+ }
116
+ else {
117
+ using dpctl::tensor::type_utils::vec_cast;
118
+ return vec_cast<resT, tmpT, vec_sz>(tmp);
119
+ }
100
120
}
101
121
else {
102
122
sycl::vec<resT, vec_sz> res = sycl::floor (tmp);
0 commit comments