@@ -270,7 +270,9 @@ InstructionCost VPRecipeBase::cost(ElementCount VF, VPCostContext &Ctx) {
     UI = &WidenMem->getIngredient();
 
   InstructionCost RecipeCost;
-  if (UI && Ctx.skipCostComputation(UI, VF.isVector())) {
+  if ((UI && Ctx.skipCostComputation(UI, VF.isVector())) ||
+      (Ctx.FoldedRecipes.contains(VF) &&
+       Ctx.FoldedRecipes.at(VF).contains(this))) {
     RecipeCost = 0;
   } else {
     RecipeCost = computeCost(VF, Ctx);
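Note: the new disjunct consults a FoldedRecipes map on VPCostContext, which the reduction costing below populates; a recipe recorded there has already been priced into a fused reduction, so it is reported as free for that VF. The map's declaration is not part of the hunks shown here; a plausible sketch, inferred only from its uses (contains(VF), at(VF).contains(this), FoldedRecipes[VF].insert(R)):

    // Sketch only: the exact container types are an assumption; the real
    // declaration lives in the VPCostContext definition elsewhere in the patch.
    DenseMap<ElementCount, SmallPtrSet<const VPRecipeBase *, 4>> FoldedRecipes;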
@@ -2187,30 +2189,143 @@ InstructionCost VPReductionRecipe::computeCost(ElementCount VF,
   TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
   unsigned Opcode = RdxDesc.getOpcode();
 
-  // TODO: Support any-of and in-loop reductions.
+  // TODO: Support any-of reductions.
   assert(
       (!RecurrenceDescriptor::isAnyOfRecurrenceKind(RdxKind) ||
        ForceTargetInstructionCost.getNumOccurrences() > 0) &&
       "Any-of reduction not implemented in VPlan-based cost model currently.");
-  assert(
-      (!cast<VPReductionPHIRecipe>(getOperand(0))->isInLoop() ||
-       ForceTargetInstructionCost.getNumOccurrences() > 0) &&
-      "In-loop reduction not implemented in VPlan-based cost model currently.");
 
   assert(ElementTy->getTypeID() == RdxDesc.getRecurrenceType()->getTypeID() &&
          "Inferred type and recurrence type mismatch.");
 
-  // Cost = Reduction cost + BinOp cost
-  InstructionCost Cost =
+  // BaseCost = Reduction cost + BinOp cost
+  InstructionCost BaseCost =
       Ctx.TTI.getArithmeticInstrCost(Opcode, ElementTy, CostKind);
   if (RecurrenceDescriptor::isMinMaxRecurrenceKind(RdxKind)) {
     Intrinsic::ID Id = getMinMaxReductionIntrinsicOp(RdxKind);
-    return Cost + Ctx.TTI.getMinMaxReductionCost(
-                      Id, VectorTy, RdxDesc.getFastMathFlags(), CostKind);
+    BaseCost += Ctx.TTI.getMinMaxReductionCost(
+        Id, VectorTy, RdxDesc.getFastMathFlags(), CostKind);
+  } else {
+    BaseCost += Ctx.TTI.getArithmeticReductionCost(
+        Opcode, VectorTy, RdxDesc.getFastMathFlags(), CostKind);
   }
 
-  return Cost + Ctx.TTI.getArithmeticReductionCost(
-                    Opcode, VectorTy, RdxDesc.getFastMathFlags(), CostKind);
+  using namespace llvm::VPlanPatternMatch;
+  auto GetMulAccReductionCost =
+      [&](const VPReductionRecipe *Red) -> InstructionCost {
+    VPValue *A, *B;
+    InstructionCost InnerExt0Cost = 0;
+    InstructionCost InnerExt1Cost = 0;
+    InstructionCost ExtCost = 0;
+    InstructionCost MulCost = 0;
+
+    VectorType *SrcVecTy = VectorTy;
+    Type *InnerExt0Ty;
+    Type *InnerExt1Ty;
+    Type *MaxInnerExtTy;
+    bool IsUnsigned = true;
+    bool HasOuterExt = false;
+
+    auto *Ext = dyn_cast_if_present<VPWidenCastRecipe>(
+        Red->getVecOp()->getDefiningRecipe());
+    VPRecipeBase *Mul;
+    // Try to match the outer extend: reduce.add(ext(...)).
+    if (Ext && match(Ext, m_ZExtOrSExt(m_VPValue())) &&
+        cast<VPWidenCastRecipe>(Ext)->getNumUsers() == 1) {
+      IsUnsigned = Ext->getOpcode() == Instruction::CastOps::ZExt;
+      ExtCost = Ext->computeCost(VF, Ctx);
+      Mul = Ext->getOperand(0)->getDefiningRecipe();
+      HasOuterExt = true;
+    } else {
+      Mul = Red->getVecOp()->getDefiningRecipe();
+    }
+
+    // Match reduce.add(mul(...)).
+    if (Mul && match(Mul, m_Mul(m_VPValue(A), m_VPValue(B))) &&
+        cast<VPWidenRecipe>(Mul)->getNumUsers() == 1) {
+      MulCost = cast<VPWidenRecipe>(Mul)->computeCost(VF, Ctx);
+      auto *InnerExt0 =
+          dyn_cast_if_present<VPWidenCastRecipe>(A->getDefiningRecipe());
+      auto *InnerExt1 =
+          dyn_cast_if_present<VPWidenCastRecipe>(B->getDefiningRecipe());
+      bool HasInnerExt = false;
+      // Try to match the inner extends.
+      if (InnerExt0 && InnerExt1 &&
+          match(InnerExt0, m_ZExtOrSExt(m_VPValue())) &&
+          match(InnerExt1, m_ZExtOrSExt(m_VPValue())) &&
+          InnerExt0->getOpcode() == InnerExt1->getOpcode() &&
+          (InnerExt0->getNumUsers() > 0 &&
+           !InnerExt0->hasMoreThanOneUniqueUser()) &&
+          (InnerExt1->getNumUsers() > 0 &&
+           !InnerExt1->hasMoreThanOneUniqueUser())) {
+        InnerExt0Cost = InnerExt0->computeCost(VF, Ctx);
+        InnerExt1Cost = InnerExt1->computeCost(VF, Ctx);
+        InnerExt0Ty = Ctx.Types.inferScalarType(InnerExt0->getOperand(0));
+        InnerExt1Ty = Ctx.Types.inferScalarType(InnerExt1->getOperand(0));
+        MaxInnerExtTy = InnerExt0Ty->getIntegerBitWidth() >
+                                InnerExt1Ty->getIntegerBitWidth()
+                            ? InnerExt0Ty
+                            : InnerExt1Ty;
+        SrcVecTy = cast<VectorType>(ToVectorTy(MaxInnerExtTy, VF));
+        IsUnsigned = true;
+        HasInnerExt = true;
+      }
+      InstructionCost MulAccRedCost = Ctx.TTI.getMulAccReductionCost(
+          IsUnsigned, ElementTy, SrcVecTy, CostKind);
+      // Check if folding ext/mul into MulAccReduction is profitable.
+      if (MulAccRedCost.isValid() &&
+          MulAccRedCost <
+              ExtCost + MulCost + InnerExt0Cost + InnerExt1Cost + BaseCost) {
+        if (HasInnerExt) {
+          Ctx.FoldedRecipes[VF].insert(InnerExt0);
+          Ctx.FoldedRecipes[VF].insert(InnerExt1);
+        }
+        Ctx.FoldedRecipes[VF].insert(Mul);
+        if (HasOuterExt)
+          Ctx.FoldedRecipes[VF].insert(Ext);
+        return MulAccRedCost;
+      }
+    }
+    return InstructionCost::getInvalid();
+  };
+
+  // Match reduce(ext(...)).
+  auto GetExtendedReductionCost =
+      [&](const VPReductionRecipe *Red) -> InstructionCost {
+    VPValue *VecOp = Red->getVecOp();
+    VPValue *A;
+    if (match(VecOp, m_ZExtOrSExt(m_VPValue(A))) && VecOp->getNumUsers() == 1) {
+      VPWidenCastRecipe *Ext =
+          cast<VPWidenCastRecipe>(VecOp->getDefiningRecipe());
+      bool IsUnsigned = Ext->getOpcode() == Instruction::CastOps::ZExt;
+      InstructionCost ExtCost = Ext->computeCost(VF, Ctx);
+      auto *ExtVecTy =
+          cast<VectorType>(ToVectorTy(Ctx.Types.inferScalarType(A), VF));
+      InstructionCost ExtendedRedCost = Ctx.TTI.getExtendedReductionCost(
+          Opcode, IsUnsigned, ElementTy, ExtVecTy, RdxDesc.getFastMathFlags(),
+          CostKind);
+      // Check if folding ext into ExtendedReduction is profitable.
+      if (ExtendedRedCost.isValid() && ExtendedRedCost < ExtCost + BaseCost) {
+        Ctx.FoldedRecipes[VF].insert(Ext);
+        return ExtendedRedCost;
+      }
+    }
+    return InstructionCost::getInvalid();
+  };
+
+  // Match MulAccReduction patterns.
+  InstructionCost MulAccCost = GetMulAccReductionCost(this);
+  if (MulAccCost.isValid())
+    return MulAccCost;
+
+  // Match ExtendedReduction patterns.
+  InstructionCost ExtendedCost = GetExtendedReductionCost(this);
+  if (ExtendedCost.isValid())
+    return ExtendedCost;
+
+  // Default cost.
+  return BaseCost;
 }
 
 #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
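Taken together, the two helpers let VPReductionRecipe::computeCost price a dot-product style chain such as reduce.add(zext(mul(zext(A), zext(B)))) with a single TTI query instead of summing per-recipe costs. A small worked example of the profitability test, with invented cost values standing in for the target's TTI answers:

    #include "llvm/Support/InstructionCost.h"
    using llvm::InstructionCost;

    // Hypothetical per-recipe costs for reduce.add(zext(mul(zext(A), zext(B)))).
    InstructionCost ExtCost = 1;          // outer zext
    InstructionCost MulCost = 2;          // widened mul
    InstructionCost InnerExtCost = 1 + 1; // the two inner zexts
    InstructionCost BaseCost = 5;         // plain add reduction + binop
    InstructionCost MulAccRedCost = 4;    // stand-in for getMulAccReductionCost()

    // Mirrors the patch's check: fold only when the fused cost wins.
    bool Profitable =
        MulAccRedCost.isValid() &&
        MulAccRedCost < ExtCost + MulCost + InnerExtCost + BaseCost;
    // Here 4 < 10, so the ext/mul recipes are inserted into
    // Ctx.FoldedRecipes[VF] and VPRecipeBase::cost() later treats them as free.

If the comparison fails, nothing is recorded and computeCost falls through to GetExtendedReductionCost and finally to BaseCost, so every recipe keeps its individual cost.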