@@ -3069,53 +3069,6 @@ struct ProductOverAxis0TempsContigFactory
3069
3069
}
3070
3070
};
3071
3071
3072
- /* @brief Types supported by hypot-reduction code based on atomic_ref */
3073
- template <typename argTy, typename outTy>
3074
- struct TypePairSupportDataForHypotReductionAtomic
3075
- {
3076
-
3077
- /* value if true a kernel for <argTy, outTy> must be instantiated, false
3078
- * otherwise */
3079
- static constexpr bool is_defined = std::disjunction< // disjunction is C++17
3080
- // feature, supported
3081
- // by DPC++ input bool
3082
- // input bool
3083
- td_ns::TypePairDefinedEntry<argTy, bool , outTy, float >,
3084
- td_ns::TypePairDefinedEntry<argTy, bool , outTy, double >,
3085
- // input int8
3086
- td_ns::TypePairDefinedEntry<argTy, std::int8_t , outTy, float >,
3087
- td_ns::TypePairDefinedEntry<argTy, std::int8_t , outTy, double >,
3088
- // input uint8
3089
- td_ns::TypePairDefinedEntry<argTy, std::uint8_t , outTy, float >,
3090
- td_ns::TypePairDefinedEntry<argTy, std::uint8_t , outTy, double >,
3091
- // input int16
3092
- td_ns::TypePairDefinedEntry<argTy, std::int16_t , outTy, float >,
3093
- td_ns::TypePairDefinedEntry<argTy, std::int16_t , outTy, double >,
3094
- // input uint16
3095
- td_ns::TypePairDefinedEntry<argTy, std::uint16_t , outTy, float >,
3096
- td_ns::TypePairDefinedEntry<argTy, std::uint16_t , outTy, double >,
3097
- // input int32
3098
- td_ns::TypePairDefinedEntry<argTy, std::int32_t , outTy, float >,
3099
- td_ns::TypePairDefinedEntry<argTy, std::int32_t , outTy, double >,
3100
- // input uint32
3101
- td_ns::TypePairDefinedEntry<argTy, std::uint32_t , outTy, float >,
3102
- td_ns::TypePairDefinedEntry<argTy, std::uint32_t , outTy, double >,
3103
- // input int64
3104
- td_ns::TypePairDefinedEntry<argTy, std::int64_t , outTy, double >,
3105
- // input uint64
3106
- td_ns::TypePairDefinedEntry<argTy, std::uint64_t , outTy, double >,
3107
- // input half
3108
- td_ns::TypePairDefinedEntry<argTy, sycl::half, outTy, float >,
3109
- td_ns::TypePairDefinedEntry<argTy, float , outTy, double >,
3110
- // input float
3111
- td_ns::TypePairDefinedEntry<argTy, float , outTy, float >,
3112
- td_ns::TypePairDefinedEntry<argTy, float , outTy, double >,
3113
- // input double
3114
- td_ns::TypePairDefinedEntry<argTy, double , outTy, double >,
3115
- // fall-through
3116
- td_ns::NotDefinedEntry>::is_defined;
3117
- };
3118
-
3119
3072
template <typename argTy, typename outTy>
3120
3073
struct TypePairSupportDataForHypotReductionTemps
3121
3074
{
@@ -3177,25 +3130,6 @@ struct TypePairSupportDataForHypotReductionTemps
3177
3130
td_ns::NotDefinedEntry>::is_defined;
3178
3131
};
3179
3132
3180
- template <typename fnT, typename srcTy, typename dstTy>
3181
- struct HypotOverAxisAtomicStridedFactory
3182
- {
3183
- fnT get () const
3184
- {
3185
- if constexpr (TypePairSupportDataForHypotReductionAtomic<
3186
- srcTy, dstTy>::is_defined)
3187
- {
3188
- using ReductionOpT = su_ns::Hypot<dstTy>;
3189
- return dpctl::tensor::kernels::
3190
- reduction_over_group_with_atomics_strided_impl<srcTy, dstTy,
3191
- ReductionOpT>;
3192
- }
3193
- else {
3194
- return nullptr ;
3195
- }
3196
- }
3197
- };
3198
-
3199
3133
template <typename fnT, typename srcTy, typename dstTy>
3200
3134
struct HypotOverAxisTempsStridedFactory
3201
3135
{
@@ -3215,44 +3149,6 @@ struct HypotOverAxisTempsStridedFactory
3215
3149
}
3216
3150
};
3217
3151
3218
- template <typename fnT, typename srcTy, typename dstTy>
3219
- struct HypotOverAxis1AtomicContigFactory
3220
- {
3221
- fnT get () const
3222
- {
3223
- if constexpr (TypePairSupportDataForHypotReductionAtomic<
3224
- srcTy, dstTy>::is_defined)
3225
- {
3226
- using ReductionOpT = su_ns::Hypot<dstTy>;
3227
- return dpctl::tensor::kernels::
3228
- reduction_axis1_over_group_with_atomics_contig_impl<
3229
- srcTy, dstTy, ReductionOpT>;
3230
- }
3231
- else {
3232
- return nullptr ;
3233
- }
3234
- }
3235
- };
3236
-
3237
- template <typename fnT, typename srcTy, typename dstTy>
3238
- struct HypotOverAxis0AtomicContigFactory
3239
- {
3240
- fnT get () const
3241
- {
3242
- if constexpr (TypePairSupportDataForHypotReductionAtomic<
3243
- srcTy, dstTy>::is_defined)
3244
- {
3245
- using ReductionOpT = su_ns::Hypot<dstTy>;
3246
- return dpctl::tensor::kernels::
3247
- reduction_axis0_over_group_with_atomics_contig_impl<
3248
- srcTy, dstTy, ReductionOpT>;
3249
- }
3250
- else {
3251
- return nullptr ;
3252
- }
3253
- }
3254
- };
3255
-
3256
3152
template <typename fnT, typename srcTy, typename dstTy>
3257
3153
struct HypotOverAxis1TempsContigFactory
3258
3154
{
@@ -3291,53 +3187,6 @@ struct HypotOverAxis0TempsContigFactory
3291
3187
}
3292
3188
};
3293
3189
3294
- /* @brief Types supported by logsumexp-reduction code based on atomic_ref */
3295
- template <typename argTy, typename outTy>
3296
- struct TypePairSupportDataForLogSumExpReductionAtomic
3297
- {
3298
-
3299
- /* value if true a kernel for <argTy, outTy> must be instantiated, false
3300
- * otherwise */
3301
- static constexpr bool is_defined = std::disjunction< // disjunction is C++17
3302
- // feature, supported
3303
- // by DPC++ input bool
3304
- // input bool
3305
- td_ns::TypePairDefinedEntry<argTy, bool , outTy, float >,
3306
- td_ns::TypePairDefinedEntry<argTy, bool , outTy, double >,
3307
- // input int8
3308
- td_ns::TypePairDefinedEntry<argTy, std::int8_t , outTy, float >,
3309
- td_ns::TypePairDefinedEntry<argTy, std::int8_t , outTy, double >,
3310
- // input uint8
3311
- td_ns::TypePairDefinedEntry<argTy, std::uint8_t , outTy, float >,
3312
- td_ns::TypePairDefinedEntry<argTy, std::uint8_t , outTy, double >,
3313
- // input int16
3314
- td_ns::TypePairDefinedEntry<argTy, std::int16_t , outTy, float >,
3315
- td_ns::TypePairDefinedEntry<argTy, std::int16_t , outTy, double >,
3316
- // input uint16
3317
- td_ns::TypePairDefinedEntry<argTy, std::uint16_t , outTy, float >,
3318
- td_ns::TypePairDefinedEntry<argTy, std::uint16_t , outTy, double >,
3319
- // input int32
3320
- td_ns::TypePairDefinedEntry<argTy, std::int32_t , outTy, float >,
3321
- td_ns::TypePairDefinedEntry<argTy, std::int32_t , outTy, double >,
3322
- // input uint32
3323
- td_ns::TypePairDefinedEntry<argTy, std::uint32_t , outTy, float >,
3324
- td_ns::TypePairDefinedEntry<argTy, std::uint32_t , outTy, double >,
3325
- // input int64
3326
- td_ns::TypePairDefinedEntry<argTy, std::int64_t , outTy, double >,
3327
- // input uint64
3328
- td_ns::TypePairDefinedEntry<argTy, std::uint64_t , outTy, double >,
3329
- // input half
3330
- td_ns::TypePairDefinedEntry<argTy, sycl::half, outTy, float >,
3331
- td_ns::TypePairDefinedEntry<argTy, float , outTy, double >,
3332
- // input float
3333
- td_ns::TypePairDefinedEntry<argTy, float , outTy, float >,
3334
- td_ns::TypePairDefinedEntry<argTy, float , outTy, double >,
3335
- // input double
3336
- td_ns::TypePairDefinedEntry<argTy, double , outTy, double >,
3337
- // fall-through
3338
- td_ns::NotDefinedEntry>::is_defined;
3339
- };
3340
-
3341
3190
template <typename argTy, typename outTy>
3342
3191
struct TypePairSupportDataForLogSumExpReductionTemps
3343
3192
{
@@ -3399,25 +3248,6 @@ struct TypePairSupportDataForLogSumExpReductionTemps
3399
3248
td_ns::NotDefinedEntry>::is_defined;
3400
3249
};
3401
3250
3402
- template <typename fnT, typename srcTy, typename dstTy>
3403
- struct LogSumExpOverAxisAtomicStridedFactory
3404
- {
3405
- fnT get () const
3406
- {
3407
- if constexpr (TypePairSupportDataForLogSumExpReductionAtomic<
3408
- srcTy, dstTy>::is_defined)
3409
- {
3410
- using ReductionOpT = su_ns::LogSumExp<dstTy>;
3411
- return dpctl::tensor::kernels::
3412
- reduction_over_group_with_atomics_strided_impl<srcTy, dstTy,
3413
- ReductionOpT>;
3414
- }
3415
- else {
3416
- return nullptr ;
3417
- }
3418
- }
3419
- };
3420
-
3421
3251
template <typename fnT, typename srcTy, typename dstTy>
3422
3252
struct LogSumExpOverAxisTempsStridedFactory
3423
3253
{
@@ -3437,44 +3267,6 @@ struct LogSumExpOverAxisTempsStridedFactory
3437
3267
}
3438
3268
};
3439
3269
3440
- template <typename fnT, typename srcTy, typename dstTy>
3441
- struct LogSumExpOverAxis1AtomicContigFactory
3442
- {
3443
- fnT get () const
3444
- {
3445
- if constexpr (TypePairSupportDataForLogSumExpReductionAtomic<
3446
- srcTy, dstTy>::is_defined)
3447
- {
3448
- using ReductionOpT = su_ns::LogSumExp<dstTy>;
3449
- return dpctl::tensor::kernels::
3450
- reduction_axis1_over_group_with_atomics_contig_impl<
3451
- srcTy, dstTy, ReductionOpT>;
3452
- }
3453
- else {
3454
- return nullptr ;
3455
- }
3456
- }
3457
- };
3458
-
3459
- template <typename fnT, typename srcTy, typename dstTy>
3460
- struct LogSumExpOverAxis0AtomicContigFactory
3461
- {
3462
- fnT get () const
3463
- {
3464
- if constexpr (TypePairSupportDataForLogSumExpReductionAtomic<
3465
- srcTy, dstTy>::is_defined)
3466
- {
3467
- using ReductionOpT = su_ns::LogSumExp<dstTy>;
3468
- return dpctl::tensor::kernels::
3469
- reduction_axis0_over_group_with_atomics_contig_impl<
3470
- srcTy, dstTy, ReductionOpT>;
3471
- }
3472
- else {
3473
- return nullptr ;
3474
- }
3475
- }
3476
- };
3477
-
3478
3270
template <typename fnT, typename srcTy, typename dstTy>
3479
3271
struct LogSumExpOverAxis1TempsContigFactory
3480
3272
{
0 commit comments