Skip to content

Commit 204d9fc

Browse files
committed
Arithmetic reductions no longer use atomics for inexact types
This change is intended to improve the numerical stability of sum and prod
1 parent 2fcde7f commit 204d9fc

File tree

1 file changed

+0
-48
lines changed

1 file changed

+0
-48
lines changed

dpctl/tensor/libtensor/include/kernels/reductions.hpp

Lines changed: 0 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -2524,57 +2524,33 @@ struct TypePairSupportDataForSumReductionAtomic
25242524
td_ns::TypePairDefinedEntry<argTy, bool, outTy, std::uint32_t>,
25252525
td_ns::TypePairDefinedEntry<argTy, bool, outTy, std::int64_t>,
25262526
td_ns::TypePairDefinedEntry<argTy, bool, outTy, std::uint64_t>,
2527-
td_ns::TypePairDefinedEntry<argTy, bool, outTy, float>,
2528-
td_ns::TypePairDefinedEntry<argTy, bool, outTy, double>,
25292527
// input int8
25302528
td_ns::TypePairDefinedEntry<argTy, std::int8_t, outTy, std::int32_t>,
25312529
td_ns::TypePairDefinedEntry<argTy, std::int8_t, outTy, std::int64_t>,
2532-
td_ns::TypePairDefinedEntry<argTy, std::int8_t, outTy, float>,
2533-
td_ns::TypePairDefinedEntry<argTy, std::int8_t, outTy, double>,
25342530
// input uint8
25352531
td_ns::TypePairDefinedEntry<argTy, std::uint8_t, outTy, std::int32_t>,
25362532
td_ns::TypePairDefinedEntry<argTy, std::uint8_t, outTy, std::uint32_t>,
25372533
td_ns::TypePairDefinedEntry<argTy, std::uint8_t, outTy, std::int64_t>,
25382534
td_ns::TypePairDefinedEntry<argTy, std::uint8_t, outTy, std::uint64_t>,
2539-
td_ns::TypePairDefinedEntry<argTy, std::uint8_t, outTy, float>,
2540-
td_ns::TypePairDefinedEntry<argTy, std::uint8_t, outTy, double>,
25412535
// input int16
25422536
td_ns::TypePairDefinedEntry<argTy, std::int16_t, outTy, std::int32_t>,
25432537
td_ns::TypePairDefinedEntry<argTy, std::int16_t, outTy, std::int64_t>,
2544-
td_ns::TypePairDefinedEntry<argTy, std::int16_t, outTy, float>,
2545-
td_ns::TypePairDefinedEntry<argTy, std::int16_t, outTy, double>,
25462538
// input uint16
25472539
td_ns::TypePairDefinedEntry<argTy, std::uint16_t, outTy, std::int32_t>,
25482540
td_ns::TypePairDefinedEntry<argTy, std::uint16_t, outTy, std::uint32_t>,
25492541
td_ns::TypePairDefinedEntry<argTy, std::uint16_t, outTy, std::int64_t>,
25502542
td_ns::TypePairDefinedEntry<argTy, std::uint16_t, outTy, std::uint64_t>,
2551-
td_ns::TypePairDefinedEntry<argTy, std::uint16_t, outTy, float>,
2552-
td_ns::TypePairDefinedEntry<argTy, std::uint16_t, outTy, double>,
25532543
// input int32
25542544
td_ns::TypePairDefinedEntry<argTy, std::int32_t, outTy, std::int32_t>,
25552545
td_ns::TypePairDefinedEntry<argTy, std::int32_t, outTy, std::int64_t>,
2556-
td_ns::TypePairDefinedEntry<argTy, std::int32_t, outTy, float>,
2557-
td_ns::TypePairDefinedEntry<argTy, std::int32_t, outTy, double>,
25582546
// input uint32
25592547
td_ns::TypePairDefinedEntry<argTy, std::uint32_t, outTy, std::uint32_t>,
25602548
td_ns::TypePairDefinedEntry<argTy, std::uint32_t, outTy, std::int64_t>,
25612549
td_ns::TypePairDefinedEntry<argTy, std::uint32_t, outTy, std::uint64_t>,
2562-
td_ns::TypePairDefinedEntry<argTy, std::uint32_t, outTy, float>,
2563-
td_ns::TypePairDefinedEntry<argTy, std::uint32_t, outTy, double>,
25642550
// input int64
25652551
td_ns::TypePairDefinedEntry<argTy, std::int64_t, outTy, std::int64_t>,
2566-
td_ns::TypePairDefinedEntry<argTy, std::int64_t, outTy, double>,
25672552
// input uint64
25682553
td_ns::TypePairDefinedEntry<argTy, std::uint64_t, outTy, std::uint64_t>,
2569-
td_ns::TypePairDefinedEntry<argTy, std::uint64_t, outTy, double>,
2570-
// input half
2571-
td_ns::TypePairDefinedEntry<argTy, sycl::half, outTy, float>,
2572-
td_ns::TypePairDefinedEntry<argTy, float, outTy, double>,
2573-
// input float
2574-
td_ns::TypePairDefinedEntry<argTy, float, outTy, float>,
2575-
td_ns::TypePairDefinedEntry<argTy, float, outTy, double>,
2576-
// input double
2577-
td_ns::TypePairDefinedEntry<argTy, double, outTy, double>,
25782554
// fall-through
25792555
td_ns::NotDefinedEntry>::is_defined;
25802556
};
@@ -2803,57 +2779,33 @@ struct TypePairSupportDataForProductReductionAtomic
28032779
td_ns::TypePairDefinedEntry<argTy, bool, outTy, std::uint32_t>,
28042780
td_ns::TypePairDefinedEntry<argTy, bool, outTy, std::int64_t>,
28052781
td_ns::TypePairDefinedEntry<argTy, bool, outTy, std::uint64_t>,
2806-
td_ns::TypePairDefinedEntry<argTy, bool, outTy, float>,
2807-
td_ns::TypePairDefinedEntry<argTy, bool, outTy, double>,
28082782
// input int8
28092783
td_ns::TypePairDefinedEntry<argTy, std::int8_t, outTy, std::int32_t>,
28102784
td_ns::TypePairDefinedEntry<argTy, std::int8_t, outTy, std::int64_t>,
2811-
td_ns::TypePairDefinedEntry<argTy, std::int8_t, outTy, float>,
2812-
td_ns::TypePairDefinedEntry<argTy, std::int8_t, outTy, double>,
28132785
// input uint8
28142786
td_ns::TypePairDefinedEntry<argTy, std::uint8_t, outTy, std::int32_t>,
28152787
td_ns::TypePairDefinedEntry<argTy, std::uint8_t, outTy, std::uint32_t>,
28162788
td_ns::TypePairDefinedEntry<argTy, std::uint8_t, outTy, std::int64_t>,
28172789
td_ns::TypePairDefinedEntry<argTy, std::uint8_t, outTy, std::uint64_t>,
2818-
td_ns::TypePairDefinedEntry<argTy, std::uint8_t, outTy, float>,
2819-
td_ns::TypePairDefinedEntry<argTy, std::uint8_t, outTy, double>,
28202790
// input int16
28212791
td_ns::TypePairDefinedEntry<argTy, std::int16_t, outTy, std::int32_t>,
28222792
td_ns::TypePairDefinedEntry<argTy, std::int16_t, outTy, std::int64_t>,
2823-
td_ns::TypePairDefinedEntry<argTy, std::int16_t, outTy, float>,
2824-
td_ns::TypePairDefinedEntry<argTy, std::int16_t, outTy, double>,
28252793
// input uint16
28262794
td_ns::TypePairDefinedEntry<argTy, std::uint16_t, outTy, std::int32_t>,
28272795
td_ns::TypePairDefinedEntry<argTy, std::uint16_t, outTy, std::uint32_t>,
28282796
td_ns::TypePairDefinedEntry<argTy, std::uint16_t, outTy, std::int64_t>,
28292797
td_ns::TypePairDefinedEntry<argTy, std::uint16_t, outTy, std::uint64_t>,
2830-
td_ns::TypePairDefinedEntry<argTy, std::uint16_t, outTy, float>,
2831-
td_ns::TypePairDefinedEntry<argTy, std::uint16_t, outTy, double>,
28322798
// input int32
28332799
td_ns::TypePairDefinedEntry<argTy, std::int32_t, outTy, std::int32_t>,
28342800
td_ns::TypePairDefinedEntry<argTy, std::int32_t, outTy, std::int64_t>,
2835-
td_ns::TypePairDefinedEntry<argTy, std::int32_t, outTy, float>,
2836-
td_ns::TypePairDefinedEntry<argTy, std::int32_t, outTy, double>,
28372801
// input uint32
28382802
td_ns::TypePairDefinedEntry<argTy, std::uint32_t, outTy, std::uint32_t>,
28392803
td_ns::TypePairDefinedEntry<argTy, std::uint32_t, outTy, std::int64_t>,
28402804
td_ns::TypePairDefinedEntry<argTy, std::uint32_t, outTy, std::uint64_t>,
2841-
td_ns::TypePairDefinedEntry<argTy, std::uint32_t, outTy, float>,
2842-
td_ns::TypePairDefinedEntry<argTy, std::uint32_t, outTy, double>,
28432805
// input int64
28442806
td_ns::TypePairDefinedEntry<argTy, std::int64_t, outTy, std::int64_t>,
2845-
td_ns::TypePairDefinedEntry<argTy, std::int64_t, outTy, double>,
28462807
// input uint64
28472808
td_ns::TypePairDefinedEntry<argTy, std::uint64_t, outTy, std::uint64_t>,
2848-
td_ns::TypePairDefinedEntry<argTy, std::uint64_t, outTy, double>,
2849-
// input half
2850-
td_ns::TypePairDefinedEntry<argTy, sycl::half, outTy, float>,
2851-
td_ns::TypePairDefinedEntry<argTy, float, outTy, double>,
2852-
// input float
2853-
td_ns::TypePairDefinedEntry<argTy, float, outTy, float>,
2854-
td_ns::TypePairDefinedEntry<argTy, float, outTy, double>,
2855-
// input double
2856-
td_ns::TypePairDefinedEntry<argTy, double, outTy, double>,
28572809
// fall-through
28582810
td_ns::NotDefinedEntry>::is_defined;
28592811
};

0 commit comments

Comments
 (0)