Skip to content

Commit ea25db2

Browse files
committed
[VPlan] Impl VPlan-based pattern match for ExtendedRed and MulAccRed. NFCI
This patch implement the VPlan-based pattern match for extendedReduction and MulAccReduction. In above reduction patterns, extened instructions and mul instruction can fold into reduction instruction and the cost is free. We add `FoldedRecipes` in the `VPCostContext` to put recipes that can be folded into other recipes. ExtendedReductionPatterns: reduce(ext(...)) MulAccReductionPatterns: reduce.add(mul(...)) reduce.add(mul(ext(...), ext(...))) reduce.add(ext(mul(...))) reduce.add(ext(mul(ext(...), ext(...)))) Ref: Original instruction based implementation: https://reviews.llvm.org/D93476
1 parent 7111d03 commit ea25db2

File tree

3 files changed

+129
-57
lines changed

3 files changed

+129
-57
lines changed

llvm/lib/Transforms/Vectorize/LoopVectorize.cpp

Lines changed: 0 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -7306,51 +7306,6 @@ LoopVectorizationPlanner::precomputeCosts(VPlan &Plan, ElementCount VF,
73067306
Cost += ReductionCost;
73077307
continue;
73087308
}
7309-
7310-
const auto &ChainOps = RdxDesc.getReductionOpChain(RedPhi, OrigLoop);
7311-
SetVector<Instruction *> ChainOpsAndOperands(ChainOps.begin(),
7312-
ChainOps.end());
7313-
auto IsZExtOrSExt = [](const unsigned Opcode) -> bool {
7314-
return Opcode == Instruction::ZExt || Opcode == Instruction::SExt;
7315-
};
7316-
// Also include the operands of instructions in the chain, as the cost-model
7317-
// may mark extends as free.
7318-
//
7319-
// For ARM, some of the instruction can folded into the reducion
7320-
// instruction. So we need to mark all folded instructions free.
7321-
// For example: We can fold reduce(mul(ext(A), ext(B))) into one
7322-
// instruction.
7323-
for (auto *ChainOp : ChainOps) {
7324-
for (Value *Op : ChainOp->operands()) {
7325-
if (auto *I = dyn_cast<Instruction>(Op)) {
7326-
ChainOpsAndOperands.insert(I);
7327-
if (I->getOpcode() == Instruction::Mul) {
7328-
auto *Ext0 = dyn_cast<Instruction>(I->getOperand(0));
7329-
auto *Ext1 = dyn_cast<Instruction>(I->getOperand(1));
7330-
if (Ext0 && IsZExtOrSExt(Ext0->getOpcode()) && Ext1 &&
7331-
Ext0->getOpcode() == Ext1->getOpcode()) {
7332-
ChainOpsAndOperands.insert(Ext0);
7333-
ChainOpsAndOperands.insert(Ext1);
7334-
}
7335-
}
7336-
}
7337-
}
7338-
}
7339-
7340-
// Pre-compute the cost for I, if it has a reduction pattern cost.
7341-
for (Instruction *I : ChainOpsAndOperands) {
7342-
auto ReductionCost = CM.getReductionPatternCost(
7343-
I, VF, ToVectorTy(I->getType(), VF), TTI::TCK_RecipThroughput);
7344-
if (!ReductionCost)
7345-
continue;
7346-
7347-
assert(!CostCtx.SkipCostComputation.contains(I) &&
7348-
"reduction op visited multiple times");
7349-
CostCtx.SkipCostComputation.insert(I);
7350-
LLVM_DEBUG(dbgs() << "Cost of " << ReductionCost << " for VF " << VF
7351-
<< ":\n in-loop reduction " << *I << "\n");
7352-
Cost += *ReductionCost;
7353-
}
73547309
}
73557310

73567311
// Pre-compute the costs for branches except for the backedge, as the number

llvm/lib/Transforms/Vectorize/VPlan.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -682,6 +682,8 @@ struct VPCostContext {
682682
LLVMContext &LLVMCtx;
683683
LoopVectorizationCostModel &CM;
684684
SmallPtrSet<Instruction *, 8> SkipCostComputation;
685+
/// Contains recipes that are folded into other recipes.
686+
SmallDenseMap<ElementCount, SmallPtrSet<VPRecipeBase *, 4>, 4> FoldedRecipes;
685687

686688
VPCostContext(const TargetTransformInfo &TTI, const TargetLibraryInfo &TLI,
687689
Type *CanIVTy, LoopVectorizationCostModel &CM)

llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp

Lines changed: 127 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -270,7 +270,9 @@ InstructionCost VPRecipeBase::cost(ElementCount VF, VPCostContext &Ctx) {
270270
UI = &WidenMem->getIngredient();
271271

272272
InstructionCost RecipeCost;
273-
if (UI && Ctx.skipCostComputation(UI, VF.isVector())) {
273+
if ((UI && Ctx.skipCostComputation(UI, VF.isVector())) ||
274+
(Ctx.FoldedRecipes.contains(VF) &&
275+
Ctx.FoldedRecipes.at(VF).contains(this))) {
274276
RecipeCost = 0;
275277
} else {
276278
RecipeCost = computeCost(VF, Ctx);
@@ -2187,30 +2189,143 @@ InstructionCost VPReductionRecipe::computeCost(ElementCount VF,
21872189
TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
21882190
unsigned Opcode = RdxDesc.getOpcode();
21892191

2190-
// TODO: Support any-of and in-loop reductions.
2192+
// TODO: Support any-of reductions.
21912193
assert(
21922194
(!RecurrenceDescriptor::isAnyOfRecurrenceKind(RdxKind) ||
21932195
ForceTargetInstructionCost.getNumOccurrences() > 0) &&
21942196
"Any-of reduction not implemented in VPlan-based cost model currently.");
2195-
assert(
2196-
(!cast<VPReductionPHIRecipe>(getOperand(0))->isInLoop() ||
2197-
ForceTargetInstructionCost.getNumOccurrences() > 0) &&
2198-
"In-loop reduction not implemented in VPlan-based cost model currently.");
21992197

22002198
assert(ElementTy->getTypeID() == RdxDesc.getRecurrenceType()->getTypeID() &&
22012199
"Inferred type and recurrence type mismatch.");
22022200

2203-
// Cost = Reduction cost + BinOp cost
2204-
InstructionCost Cost =
2201+
// BaseCost = Reduction cost + BinOp cost
2202+
InstructionCost BaseCost =
22052203
Ctx.TTI.getArithmeticInstrCost(Opcode, ElementTy, CostKind);
22062204
if (RecurrenceDescriptor::isMinMaxRecurrenceKind(RdxKind)) {
22072205
Intrinsic::ID Id = getMinMaxReductionIntrinsicOp(RdxKind);
2208-
return Cost + Ctx.TTI.getMinMaxReductionCost(
2209-
Id, VectorTy, RdxDesc.getFastMathFlags(), CostKind);
2206+
BaseCost += Ctx.TTI.getMinMaxReductionCost(
2207+
Id, VectorTy, RdxDesc.getFastMathFlags(), CostKind);
2208+
} else {
2209+
BaseCost += Ctx.TTI.getArithmeticReductionCost(
2210+
Opcode, VectorTy, RdxDesc.getFastMathFlags(), CostKind);
22102211
}
22112212

2212-
return Cost + Ctx.TTI.getArithmeticReductionCost(
2213-
Opcode, VectorTy, RdxDesc.getFastMathFlags(), CostKind);
2213+
using namespace llvm::VPlanPatternMatch;
2214+
auto GetMulAccReductionCost =
2215+
[&](const VPReductionRecipe *Red) -> InstructionCost {
2216+
VPValue *A, *B;
2217+
InstructionCost InnerExt0Cost = 0;
2218+
InstructionCost InnerExt1Cost = 0;
2219+
InstructionCost ExtCost = 0;
2220+
InstructionCost MulCost = 0;
2221+
2222+
VectorType *SrcVecTy = VectorTy;
2223+
Type *InnerExt0Ty;
2224+
Type *InnerExt1Ty;
2225+
Type *MaxInnerExtTy;
2226+
bool IsUnsigned = true;
2227+
bool HasOuterExt = false;
2228+
2229+
auto *Ext = dyn_cast_if_present<VPWidenCastRecipe>(
2230+
Red->getVecOp()->getDefiningRecipe());
2231+
VPRecipeBase *Mul;
2232+
// Try to match outer extend reduce.add(ext(...))
2233+
if (Ext && match(Ext, m_ZExtOrSExt(m_VPValue())) &&
2234+
cast<VPWidenCastRecipe>(Ext)->getNumUsers() == 1) {
2235+
IsUnsigned =
2236+
Ext->getOpcode() == Instruction::CastOps::ZExt ? true : false;
2237+
ExtCost = Ext->computeCost(VF, Ctx);
2238+
Mul = Ext->getOperand(0)->getDefiningRecipe();
2239+
HasOuterExt = true;
2240+
} else {
2241+
Mul = Red->getVecOp()->getDefiningRecipe();
2242+
}
2243+
2244+
// Match reduce.add(mul())
2245+
if (Mul && match(Mul, m_Mul(m_VPValue(A), m_VPValue(B))) &&
2246+
cast<VPWidenRecipe>(Mul)->getNumUsers() == 1) {
2247+
MulCost = cast<VPWidenRecipe>(Mul)->computeCost(VF, Ctx);
2248+
auto *InnerExt0 =
2249+
dyn_cast_if_present<VPWidenCastRecipe>(A->getDefiningRecipe());
2250+
auto *InnerExt1 =
2251+
dyn_cast_if_present<VPWidenCastRecipe>(B->getDefiningRecipe());
2252+
bool HasInnerExt = false;
2253+
// Try to match inner extends.
2254+
if (InnerExt0 && InnerExt1 &&
2255+
match(InnerExt0, m_ZExtOrSExt(m_VPValue())) &&
2256+
match(InnerExt1, m_ZExtOrSExt(m_VPValue())) &&
2257+
InnerExt0->getOpcode() == InnerExt1->getOpcode() &&
2258+
(InnerExt0->getNumUsers() > 0 &&
2259+
!InnerExt0->hasMoreThanOneUniqueUser()) &&
2260+
(InnerExt1->getNumUsers() > 0 &&
2261+
!InnerExt1->hasMoreThanOneUniqueUser())) {
2262+
InnerExt0Cost = InnerExt0->computeCost(VF, Ctx);
2263+
InnerExt1Cost = InnerExt1->computeCost(VF, Ctx);
2264+
Type *InnerExt0Ty = Ctx.Types.inferScalarType(InnerExt0->getOperand(0));
2265+
Type *InnerExt1Ty = Ctx.Types.inferScalarType(InnerExt1->getOperand(0));
2266+
Type *MaxInnerExtTy = InnerExt0Ty->getIntegerBitWidth() >
2267+
InnerExt1Ty->getIntegerBitWidth()
2268+
? InnerExt0Ty
2269+
: InnerExt1Ty;
2270+
SrcVecTy = cast<VectorType>(ToVectorTy(MaxInnerExtTy, VF));
2271+
IsUnsigned = true;
2272+
HasInnerExt = true;
2273+
}
2274+
InstructionCost MulAccRedCost = Ctx.TTI.getMulAccReductionCost(
2275+
IsUnsigned, ElementTy, SrcVecTy, CostKind);
2276+
// Check if folding ext/mul into MulAccReduction is profitable.
2277+
if (MulAccRedCost.isValid() &&
2278+
MulAccRedCost <
2279+
ExtCost + MulCost + InnerExt0Cost + InnerExt1Cost + BaseCost) {
2280+
if (HasInnerExt) {
2281+
Ctx.FoldedRecipes[VF].insert(InnerExt0);
2282+
Ctx.FoldedRecipes[VF].insert(InnerExt1);
2283+
}
2284+
Ctx.FoldedRecipes[VF].insert(Mul);
2285+
if (HasOuterExt)
2286+
Ctx.FoldedRecipes[VF].insert(Ext);
2287+
return MulAccRedCost;
2288+
}
2289+
}
2290+
return InstructionCost::getInvalid();
2291+
};
2292+
2293+
// Match reduce(ext(...))
2294+
auto GetExtendedReductionCost =
2295+
[&](const VPReductionRecipe *Red) -> InstructionCost {
2296+
VPValue *VecOp = Red->getVecOp();
2297+
VPValue *A;
2298+
if (match(VecOp, m_ZExtOrSExt(m_VPValue(A))) && VecOp->getNumUsers() == 1) {
2299+
VPWidenCastRecipe *Ext =
2300+
cast<VPWidenCastRecipe>(VecOp->getDefiningRecipe());
2301+
bool IsUnsigned = Ext->getOpcode() == Instruction::CastOps::ZExt;
2302+
InstructionCost ExtCost = Ext->computeCost(VF, Ctx);
2303+
auto *ExtVecTy =
2304+
cast<VectorType>(ToVectorTy(Ctx.Types.inferScalarType(A), VF));
2305+
InstructionCost ExtendedRedCost = Ctx.TTI.getExtendedReductionCost(
2306+
Opcode, IsUnsigned, ElementTy, ExtVecTy, RdxDesc.getFastMathFlags(),
2307+
CostKind);
2308+
// Check if folding ext into ExtendedReduction is profitable.
2309+
if (ExtendedRedCost.isValid() && ExtendedRedCost < ExtCost + BaseCost) {
2310+
Ctx.FoldedRecipes[VF].insert(Ext);
2311+
return ExtendedRedCost;
2312+
}
2313+
}
2314+
return InstructionCost::getInvalid();
2315+
};
2316+
2317+
// Match MulAccReduction patterns.
2318+
InstructionCost MulAccCost = GetMulAccReductionCost(this);
2319+
if (MulAccCost.isValid())
2320+
return MulAccCost;
2321+
2322+
// Match ExtendedReduction patterns.
2323+
InstructionCost ExtendedCost = GetExtendedReductionCost(this);
2324+
if (ExtendedCost.isValid())
2325+
return ExtendedCost;
2326+
2327+
// Default cost.
2328+
return BaseCost;
22142329
}
22152330

22162331
#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)

0 commit comments

Comments
 (0)