Skip to content

Commit 3146183

Browse files
authored
Merge pull request #1958 from IntelPython/resolve-gh-1944
Add dedicated reduction kernels for sums and products of boolean arrays
2 parents a7ca491 + f9f4d6c commit 3146183

File tree

3 files changed

+36
-8
lines changed

3 files changed

+36
-8
lines changed

dpctl/tensor/libtensor/source/reductions/prod.cpp

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,7 @@ struct TypePairSupportDataForProductReductionTemps
120120
{
121121

122122
static constexpr bool is_defined = std::disjunction<
123+
td_ns::TypePairDefinedEntry<argTy, bool, outTy, bool>,
123124
td_ns::TypePairDefinedEntry<argTy, bool, outTy, std::int8_t>,
124125
td_ns::TypePairDefinedEntry<argTy, bool, outTy, std::uint8_t>,
125126
td_ns::TypePairDefinedEntry<argTy, bool, outTy, std::int16_t>,
@@ -224,7 +225,7 @@ struct TypePairSupportDataForProductReductionTemps
224225
outTy,
225226
std::complex<double>>,
226227

227-
// fall-throug
228+
// fall-through
228229
td_ns::NotDefinedEntry>::is_defined;
229230
};
230231

@@ -255,7 +256,9 @@ struct ProductOverAxisTempsStridedFactory
255256
if constexpr (TypePairSupportDataForProductReductionTemps<
256257
srcTy, dstTy>::is_defined)
257258
{
258-
using ReductionOpT = sycl::multiplies<dstTy>;
259+
using ReductionOpT = std::conditional_t<std::is_same_v<dstTy, bool>,
260+
sycl::logical_and<dstTy>,
261+
sycl::multiplies<dstTy>>;
259262
return dpctl::tensor::kernels::
260263
reduction_over_group_temps_strided_impl<srcTy, dstTy,
261264
ReductionOpT>;
@@ -312,7 +315,9 @@ struct ProductOverAxis1TempsContigFactory
312315
if constexpr (TypePairSupportDataForProductReductionTemps<
313316
srcTy, dstTy>::is_defined)
314317
{
315-
using ReductionOpT = sycl::multiplies<dstTy>;
318+
using ReductionOpT = std::conditional_t<std::is_same_v<dstTy, bool>,
319+
sycl::logical_and<dstTy>,
320+
sycl::multiplies<dstTy>>;
316321
return dpctl::tensor::kernels::
317322
reduction_axis1_over_group_temps_contig_impl<srcTy, dstTy,
318323
ReductionOpT>;
@@ -331,7 +336,9 @@ struct ProductOverAxis0TempsContigFactory
331336
if constexpr (TypePairSupportDataForProductReductionTemps<
332337
srcTy, dstTy>::is_defined)
333338
{
334-
using ReductionOpT = sycl::multiplies<dstTy>;
339+
using ReductionOpT = std::conditional_t<std::is_same_v<dstTy, bool>,
340+
sycl::logical_and<dstTy>,
341+
sycl::multiplies<dstTy>>;
335342
return dpctl::tensor::kernels::
336343
reduction_axis0_over_group_temps_contig_impl<srcTy, dstTy,
337344
ReductionOpT>;

dpctl/tensor/libtensor/source/reductions/sum.cpp

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,7 @@ struct TypePairSupportDataForSumReductionTemps
120120
{
121121

122122
static constexpr bool is_defined = std::disjunction<
123+
td_ns::TypePairDefinedEntry<argTy, bool, outTy, bool>,
123124
td_ns::TypePairDefinedEntry<argTy, bool, outTy, std::int8_t>,
124125
td_ns::TypePairDefinedEntry<argTy, bool, outTy, std::uint8_t>,
125126
td_ns::TypePairDefinedEntry<argTy, bool, outTy, std::int16_t>,
@@ -224,7 +225,7 @@ struct TypePairSupportDataForSumReductionTemps
224225
outTy,
225226
std::complex<double>>,
226227

227-
// fall-throug
228+
// fall-through
228229
td_ns::NotDefinedEntry>::is_defined;
229230
};
230231

@@ -255,7 +256,9 @@ struct SumOverAxisTempsStridedFactory
255256
if constexpr (TypePairSupportDataForSumReductionTemps<
256257
srcTy, dstTy>::is_defined)
257258
{
258-
using ReductionOpT = sycl::plus<dstTy>;
259+
using ReductionOpT =
260+
std::conditional_t<std::is_same_v<dstTy, bool>,
261+
sycl::logical_or<dstTy>, sycl::plus<dstTy>>;
259262
return dpctl::tensor::kernels::
260263
reduction_over_group_temps_strided_impl<srcTy, dstTy,
261264
ReductionOpT>;
@@ -312,7 +315,9 @@ struct SumOverAxis1TempsContigFactory
312315
if constexpr (TypePairSupportDataForSumReductionTemps<
313316
srcTy, dstTy>::is_defined)
314317
{
315-
using ReductionOpT = sycl::plus<dstTy>;
318+
using ReductionOpT =
319+
std::conditional_t<std::is_same_v<dstTy, bool>,
320+
sycl::logical_or<dstTy>, sycl::plus<dstTy>>;
316321
return dpctl::tensor::kernels::
317322
reduction_axis1_over_group_temps_contig_impl<srcTy, dstTy,
318323
ReductionOpT>;
@@ -331,7 +336,9 @@ struct SumOverAxis0TempsContigFactory
331336
if constexpr (TypePairSupportDataForSumReductionTemps<
332337
srcTy, dstTy>::is_defined)
333338
{
334-
using ReductionOpT = sycl::plus<dstTy>;
339+
using ReductionOpT =
340+
std::conditional_t<std::is_same_v<dstTy, bool>,
341+
sycl::logical_or<dstTy>, sycl::plus<dstTy>>;
335342
return dpctl::tensor::kernels::
336343
reduction_axis0_over_group_temps_contig_impl<srcTy, dstTy,
337344
ReductionOpT>;

dpctl/tests/test_tensor_sum.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -316,3 +316,17 @@ def test_gh_1468():
316316
a = dpt.full((2, 3, 4), 123456789, dtype=dpt.int32)
317317
t = dpt.sum(a, dtype="f4")
318318
assert t > 0
319+
320+
321+
@pytest.mark.parametrize(
322+
"dt", ["i1", "i2", "i4", "i8", "f2", "f4", "f8", "c8", "c16"]
323+
)
324+
def test_gh_1944(dt):
325+
"See https://github.com/IntelPython/dpctl/issues/1944"
326+
q = get_queue_or_skip()
327+
skip_if_dtype_not_supported(dt, q)
328+
x = dpt.asarray([-1, 1], dtype=dpt.dtype(dt), sycl_queue=q)
329+
r = dpt.sum(x, dtype="?")
330+
# reduction must be performed in the requested dtype
331+
# if performed in the input type, result is False
332+
assert r

0 commit comments

Comments
 (0)