Skip to content

Commit d5be1e3

Browse files
committed
Support MulAccRecipe
1 parent bf92b2d commit d5be1e3

File tree

4 files changed

+237
-69
lines changed

4 files changed

+237
-69
lines changed

llvm/lib/Transforms/Vectorize/LoopVectorize.cpp

Lines changed: 27 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -7664,8 +7664,6 @@ DenseMap<const SCEV *, Value *> LoopVectorizationPlanner::executePlan(
76647664

76657665
// TODO: Rebase to fhahn's implementation.
76667666
VPlanTransforms::prepareExecute(BestVPlan);
7667-
dbgs() << "\n\n print plan\n";
7668-
BestVPlan.print(dbgs());
76697667
BestVPlan.execute(&State);
76707668

76717669
// 2.5 Collect reduction resume values.
@@ -9379,11 +9377,34 @@ void LoopVectorizationPlanner::adjustRecipesForReductions(
93799377
if (CM.blockNeedsPredicationForAnyReason(BB))
93809378
CondOp = RecipeBuilder.getBlockInMask(BB);
93819379

9382-
// VPWidenCastRecipes can folded into VPReductionRecipe
9383-
VPValue *A;
9380+
VPValue *A, *B;
93849381
VPSingleDefRecipe *RedRecipe;
9385-
if (match(VecOp, m_ZExtOrSExt(m_VPValue(A))) &&
9386-
!VecOp->hasMoreThanOneUniqueUser()) {
9382+
// reduce.add(mul(ext, ext)) can folded into VPMulAccRecipe
9383+
if (RdxDesc.getOpcode() == Instruction::Add &&
9384+
match(VecOp, m_Mul(m_VPValue(A), m_VPValue(B)))) {
9385+
VPRecipeBase *RecipeA = A->getDefiningRecipe();
9386+
VPRecipeBase *RecipeB = B->getDefiningRecipe();
9387+
if (RecipeA && RecipeB && match(RecipeA, m_ZExtOrSExt(m_VPValue())) &&
9388+
match(RecipeB, m_ZExtOrSExt(m_VPValue())) &&
9389+
cast<VPWidenCastRecipe>(RecipeA)->getOpcode() ==
9390+
cast<VPWidenCastRecipe>(RecipeB)->getOpcode() &&
9391+
!A->hasMoreThanOneUniqueUser() && !B->hasMoreThanOneUniqueUser()) {
9392+
RedRecipe = new VPMulAccRecipe(
9393+
RdxDesc, CurrentLinkI, PreviousLink, CondOp,
9394+
CM.useOrderedReductions(RdxDesc),
9395+
cast<VPWidenRecipe>(VecOp->getDefiningRecipe()),
9396+
cast<VPWidenCastRecipe>(RecipeA),
9397+
cast<VPWidenCastRecipe>(RecipeB));
9398+
} else {
9399+
RedRecipe = new VPMulAccRecipe(
9400+
RdxDesc, CurrentLinkI, PreviousLink, CondOp,
9401+
CM.useOrderedReductions(RdxDesc),
9402+
cast<VPWidenRecipe>(VecOp->getDefiningRecipe()));
9403+
}
9404+
}
9405+
// VPWidenCastRecipes can folded into VPReductionRecipe
9406+
else if (match(VecOp, m_ZExtOrSExt(m_VPValue(A))) &&
9407+
!VecOp->hasMoreThanOneUniqueUser()) {
93879408
RedRecipe = new VPExtendedReductionRecipe(
93889409
RdxDesc, CurrentLinkI,
93899410
cast<CastInst>(

llvm/lib/Transforms/Vectorize/VPlan.h

Lines changed: 57 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -2770,60 +2770,64 @@ class VPMulAccRecipe : public VPSingleDefRecipe {
27702770
/// Whether the reduction is conditional.
27712771
bool IsConditional = false;
27722772
/// Type after extend.
2773-
Type *ResultTy;
2774-
/// Type for mul.
2775-
Type *MulTy;
2776-
/// reduce.add(OuterExt(mul(InnerExt(), InnerExt())))
2777-
Instruction::CastOps OuterExtOp;
2778-
Instruction::CastOps InnerExtOp;
2773+
Type *ResultType;
2774+
/// reduce.add(mul(Ext(), Ext()))
2775+
Instruction::CastOps ExtOp;
2776+
2777+
Instruction *MulInstr;
2778+
CastInst *Ext0Instr;
2779+
CastInst *Ext1Instr;
27792780

2780-
Instruction *MulI;
2781-
Instruction *OuterExtI;
2782-
Instruction *InnerExt0I;
2783-
Instruction *InnerExt1I;
2781+
bool IsExtended;
27842782

27852783
protected:
27862784
VPMulAccRecipe(const unsigned char SC, const RecurrenceDescriptor &R,
2787-
Instruction *RedI, Instruction::CastOps OuterExtOp,
2788-
Instruction *OuterExtI, Instruction *MulI,
2789-
Instruction::CastOps InnerExtOp, Instruction *InnerExt0I,
2790-
Instruction *InnerExt1I, ArrayRef<VPValue *> Operands,
2791-
VPValue *CondOp, bool IsOrdered, Type *ResultTy, Type *MulTy)
2785+
Instruction *RedI, Instruction *MulInstr,
2786+
Instruction::CastOps ExtOp, Instruction *Ext0Instr,
2787+
Instruction *Ext1Instr, ArrayRef<VPValue *> Operands,
2788+
VPValue *CondOp, bool IsOrdered, Type *ResultType)
2789+
: VPSingleDefRecipe(SC, Operands, RedI), RdxDesc(R), IsOrdered(IsOrdered),
2790+
ResultType(ResultType), ExtOp(ExtOp), MulInstr(MulInstr),
2791+
Ext0Instr(cast<CastInst>(Ext0Instr)),
2792+
Ext1Instr(cast<CastInst>(Ext1Instr)) {
2793+
if (CondOp) {
2794+
IsConditional = true;
2795+
addOperand(CondOp);
2796+
}
2797+
IsExtended = true;
2798+
}
2799+
2800+
VPMulAccRecipe(const unsigned char SC, const RecurrenceDescriptor &R,
2801+
Instruction *RedI, Instruction *MulInstr,
2802+
ArrayRef<VPValue *> Operands, VPValue *CondOp, bool IsOrdered)
27922803
: VPSingleDefRecipe(SC, Operands, RedI), RdxDesc(R), IsOrdered(IsOrdered),
2793-
ResultTy(ResultTy), MulTy(MulTy), OuterExtOp(OuterExtOp),
2794-
InnerExtOp(InnerExtOp), MulI(MulI), OuterExtI(OuterExtI),
2795-
InnerExt0I(InnerExt0I), InnerExt1I(InnerExt1I) {
2804+
MulInstr(MulInstr) {
27962805
if (CondOp) {
27972806
IsConditional = true;
27982807
addOperand(CondOp);
27992808
}
2809+
IsExtended = false;
28002810
}
28012811

28022812
public:
28032813
VPMulAccRecipe(const RecurrenceDescriptor &R, Instruction *RedI,
2804-
Instruction *OuterExt, Instruction *Mul,
2805-
Instruction *InnerExt0, Instruction *InnerExt1,
2806-
VPValue *ChainOp, VPValue *InnerExt0Op, VPValue *InnerExt1Op,
2807-
VPValue *CondOp, bool IsOrdered, Type *ResultTy, Type *MulTy)
2808-
: VPMulAccRecipe(
2809-
VPDef::VPMulAccSC, R, RedI, cast<CastInst>(OuterExt)->getOpcode(),
2810-
OuterExt, Mul, cast<CastInst>(InnerExt0)->getOpcode(), InnerExt0,
2811-
InnerExt1, ArrayRef<VPValue *>({ChainOp, InnerExt0Op, InnerExt1Op}),
2812-
CondOp, IsOrdered, ResultTy, MulTy) {}
2813-
2814-
VPMulAccRecipe(VPReductionRecipe *Red, VPWidenCastRecipe *OuterExt,
2815-
VPWidenRecipe *Mul, VPWidenCastRecipe *InnerExt0,
2816-
VPWidenCastRecipe *InnerExt1)
2817-
: VPMulAccRecipe(
2818-
VPDef::VPMulAccSC, Red->getRecurrenceDescriptor(),
2819-
Red->getUnderlyingInstr(), OuterExt->getOpcode(),
2820-
OuterExt->getUnderlyingInstr(), Mul->getUnderlyingInstr(),
2821-
InnerExt0->getOpcode(), InnerExt0->getUnderlyingInstr(),
2822-
InnerExt1->getUnderlyingInstr(),
2823-
ArrayRef<VPValue *>({Red->getChainOp(), InnerExt0->getOperand(0),
2824-
InnerExt1->getOperand(0)}),
2825-
Red->getCondOp(), Red->isOrdered(), OuterExt->getResultType(),
2826-
InnerExt0->getResultType()) {}
2814+
VPValue *ChainOp, VPValue *CondOp, bool IsOrdered,
2815+
VPWidenRecipe *Mul, VPWidenCastRecipe *Ext0,
2816+
VPWidenCastRecipe *Ext1)
2817+
: VPMulAccRecipe(VPDef::VPMulAccSC, R, RedI, Mul->getUnderlyingInstr(),
2818+
Ext0->getOpcode(), Ext0->getUnderlyingInstr(),
2819+
Ext1->getUnderlyingInstr(),
2820+
ArrayRef<VPValue *>(
2821+
{ChainOp, Ext0->getOperand(0), Ext1->getOperand(0)}),
2822+
CondOp, IsOrdered, Ext0->getResultType()) {}
2823+
2824+
VPMulAccRecipe(const RecurrenceDescriptor &R, Instruction *RedI,
2825+
VPValue *ChainOp, VPValue *CondOp, bool IsOrdered,
2826+
VPWidenRecipe *Mul)
2827+
: VPMulAccRecipe(VPDef::VPMulAccSC, R, RedI, Mul->getUnderlyingInstr(),
2828+
ArrayRef<VPValue *>(
2829+
{ChainOp, Mul->getOperand(0), Mul->getOperand(0)}),
2830+
CondOp, IsOrdered) {}
28272831

28282832
~VPMulAccRecipe() override = default;
28292833

@@ -2839,7 +2843,10 @@ class VPMulAccRecipe : public VPSingleDefRecipe {
28392843
}
28402844

28412845
/// Generate the reduction in the loop
2842-
void execute(VPTransformState &State) override;
2846+
void execute(VPTransformState &State) override {
2847+
llvm_unreachable("VPMulAccRecipe should transform to VPWidenCastRecipe + "
2848+
"VPWidenRecipe + VPReductionRecipe before execution");
2849+
}
28432850

28442851
/// Return the cost of VPExtendedReductionRecipe.
28452852
InstructionCost computeCost(ElementCount VF,
@@ -2862,14 +2869,18 @@ class VPMulAccRecipe : public VPSingleDefRecipe {
28622869
/// The VPValue of the scalar Chain being accumulated.
28632870
VPValue *getChainOp() const { return getOperand(0); }
28642871
/// The VPValue of the vector value to be extended and reduced.
2865-
VPValue *getVecOp() const { return getOperand(1); }
2872+
VPValue *getVecOp0() const { return getOperand(1); }
2873+
VPValue *getVecOp1() const { return getOperand(2); }
28662874
/// The VPValue of the condition for the block.
28672875
VPValue *getCondOp() const {
28682876
return isConditional() ? getOperand(getNumOperands() - 1) : nullptr;
28692877
}
2870-
Type *getResultTy() const { return ResultTy; };
2871-
Instruction::CastOps getOuterExtOpcode() const { return OuterExtOp; };
2872-
Instruction::CastOps getInnerExtOpcode() const { return InnerExtOp; };
2878+
Type *getResultType() const { return ResultType; };
2879+
Instruction::CastOps getExtOpcode() const { return ExtOp; };
2880+
Instruction *getMulInstr() const { return MulInstr; };
2881+
CastInst *getExt0Instr() const { return Ext0Instr; };
2882+
CastInst *getExt1Instr() const { return Ext1Instr; };
2883+
bool isExtended() const { return IsExtended; };
28732884
};
28742885

28752886
/// VPReplicateRecipe replicates a given instruction producing multiple scalar

llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp

Lines changed: 111 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -270,9 +270,7 @@ InstructionCost VPRecipeBase::cost(ElementCount VF, VPCostContext &Ctx) {
270270
UI = &WidenMem->getIngredient();
271271

272272
InstructionCost RecipeCost;
273-
if ((UI && Ctx.skipCostComputation(UI, VF.isVector())) ||
274-
(Ctx.FoldedRecipes.contains(VF) &&
275-
Ctx.FoldedRecipes.at(VF).contains(this))) {
273+
if ((UI && Ctx.skipCostComputation(UI, VF.isVector()))) {
276274
RecipeCost = 0;
277275
} else {
278276
RecipeCost = computeCost(VF, Ctx);
@@ -2376,6 +2374,85 @@ VPExtendedReductionRecipe::computeCost(ElementCount VF,
23762374
return ExtendedCost + ReductionCost;
23772375
}
23782376

2377+
InstructionCost VPMulAccRecipe::computeCost(ElementCount VF,
2378+
VPCostContext &Ctx) const {
2379+
Type *ElementTy =
2380+
IsExtended ? getResultType() : Ctx.Types.inferScalarType(getVecOp0());
2381+
auto *VectorTy = cast<VectorType>(ToVectorTy(ElementTy, VF));
2382+
TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
2383+
unsigned Opcode = RdxDesc.getOpcode();
2384+
2385+
assert(ElementTy->getTypeID() == RdxDesc.getRecurrenceType()->getTypeID() &&
2386+
"Inferred type and recurrence type mismatch.");
2387+
2388+
// BaseCost = Reduction cost + BinOp cost
2389+
InstructionCost ReductionCost =
2390+
Ctx.TTI.getArithmeticInstrCost(Opcode, ElementTy, CostKind);
2391+
ReductionCost += Ctx.TTI.getArithmeticReductionCost(
2392+
Opcode, VectorTy, RdxDesc.getFastMathFlags(), CostKind);
2393+
2394+
// Extended cost
2395+
InstructionCost ExtendedCost = 0;
2396+
if (IsExtended) {
2397+
auto *SrcTy = cast<VectorType>(
2398+
ToVectorTy(Ctx.Types.inferScalarType(getVecOp0()), VF));
2399+
auto *DestTy = cast<VectorType>(ToVectorTy(getResultType(), VF));
2400+
TTI::CastContextHint CCH0 =
2401+
computeCCH(getVecOp0()->getDefiningRecipe(), VF);
2402+
// Arm TTI will use the underlying instruction to determine the cost.
2403+
ExtendedCost = Ctx.TTI.getCastInstrCost(
2404+
ExtOp, DestTy, SrcTy, CCH0, TTI::TCK_RecipThroughput,
2405+
dyn_cast_if_present<Instruction>(getExt0Instr()));
2406+
TTI::CastContextHint CCH1 =
2407+
computeCCH(getVecOp0()->getDefiningRecipe(), VF);
2408+
ExtendedCost += Ctx.TTI.getCastInstrCost(
2409+
ExtOp, DestTy, SrcTy, CCH1, TTI::TCK_RecipThroughput,
2410+
dyn_cast_if_present<Instruction>(getExt1Instr()));
2411+
}
2412+
2413+
// Mul cost
2414+
InstructionCost MulCost;
2415+
SmallVector<const Value *, 4> Operands;
2416+
Operands.append(MulInstr->value_op_begin(), MulInstr->value_op_end());
2417+
if (IsExtended)
2418+
MulCost = Ctx.TTI.getArithmeticInstrCost(
2419+
Instruction::Mul, VectorTy, CostKind,
2420+
{TargetTransformInfo::OK_AnyValue, TargetTransformInfo::OP_None},
2421+
{TargetTransformInfo::OK_AnyValue, TargetTransformInfo::OP_None},
2422+
Operands, MulInstr, &Ctx.TLI);
2423+
else {
2424+
VPValue *RHS = getVecOp1();
2425+
// Certain instructions can be cheaper to vectorize if they have a constant
2426+
// second vector operand. One example of this are shifts on x86.
2427+
TargetTransformInfo::OperandValueInfo RHSInfo = {
2428+
TargetTransformInfo::OK_AnyValue, TargetTransformInfo::OP_None};
2429+
if (RHS->isLiveIn())
2430+
RHSInfo = Ctx.TTI.getOperandInfo(RHS->getLiveInIRValue());
2431+
2432+
if (RHSInfo.Kind == TargetTransformInfo::OK_AnyValue &&
2433+
RHS->isDefinedOutsideLoopRegions())
2434+
RHSInfo.Kind = TargetTransformInfo::OK_UniformValue;
2435+
MulCost = Ctx.TTI.getArithmeticInstrCost(
2436+
Instruction::Mul, VectorTy, CostKind,
2437+
{TargetTransformInfo::OK_AnyValue, TargetTransformInfo::OP_None},
2438+
RHSInfo, Operands, MulInstr, &Ctx.TLI);
2439+
}
2440+
2441+
// ExtendedReduction Cost
2442+
VectorType *SrcVecTy =
2443+
cast<VectorType>(ToVectorTy(Ctx.Types.inferScalarType(getVecOp0()), VF));
2444+
InstructionCost MulAccCost = Ctx.TTI.getMulAccReductionCost(
2445+
getExtOpcode() == Instruction::CastOps::ZExt, ElementTy, SrcVecTy,
2446+
CostKind);
2447+
2448+
// Check if folding ext into ExtendedReduction is profitable.
2449+
if (MulAccCost.isValid() &&
2450+
MulAccCost < ExtendedCost + ReductionCost + MulCost) {
2451+
return MulAccCost;
2452+
}
2453+
return ExtendedCost + ReductionCost + MulCost;
2454+
}
2455+
23792456
#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
23802457
void VPReductionRecipe::print(raw_ostream &O, const Twine &Indent,
23812458
VPSlotTracker &SlotTracker) const {
@@ -2443,6 +2520,37 @@ void VPExtendedReductionRecipe::print(raw_ostream &O, const Twine &Indent,
24432520
O << " (with final reduction value stored in invariant address sank "
24442521
"outside of loop)";
24452522
}
2523+
2524+
void VPMulAccRecipe::print(raw_ostream &O, const Twine &Indent,
2525+
VPSlotTracker &SlotTracker) const {
2526+
O << Indent << "MULACC-REDUCE ";
2527+
printAsOperand(O, SlotTracker);
2528+
O << " = ";
2529+
getChainOp()->printAsOperand(O, SlotTracker);
2530+
O << " +";
2531+
if (isa<FPMathOperator>(getUnderlyingInstr()))
2532+
O << getUnderlyingInstr()->getFastMathFlags();
2533+
O << " reduce." << Instruction::getOpcodeName(RdxDesc.getOpcode()) << " (";
2534+
O << " mul ";
2535+
if (IsExtended)
2536+
O << "(";
2537+
getVecOp0()->printAsOperand(O, SlotTracker);
2538+
if (IsExtended)
2539+
O << " extended to " << *getResultType() << ")";
2540+
if (IsExtended)
2541+
O << "(";
2542+
getVecOp1()->printAsOperand(O, SlotTracker);
2543+
if (IsExtended)
2544+
O << " extended to " << *getResultType() << ")";
2545+
if (isConditional()) {
2546+
O << ", ";
2547+
getCondOp()->printAsOperand(O, SlotTracker);
2548+
}
2549+
O << ")";
2550+
if (RdxDesc.IntermediateStore)
2551+
O << " (with final reduction value stored in invariant address sank "
2552+
"outside of loop)";
2553+
}
24462554
#endif
24472555

24482556
bool VPReplicateRecipe::shouldPack() const {

llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp

Lines changed: 42 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -520,25 +520,53 @@ void VPlanTransforms::removeDeadRecipes(VPlan &Plan) {
520520
}
521521

522522
void VPlanTransforms::prepareExecute(VPlan &Plan) {
523-
errs() << "\n\n\n!!Prepare to execute\n";
524523
ReversePostOrderTraversal<VPBlockDeepTraversalWrapper<VPBlockBase *>> RPOT(
525524
Plan.getVectorLoopRegion());
526525
for (VPBasicBlock *VPBB : VPBlockUtils::blocksOnly<VPBasicBlock>(
527526
vp_depth_first_deep(Plan.getEntry()))) {
528527
for (VPRecipeBase &R : make_early_inc_range(*VPBB)) {
529-
if (!isa<VPExtendedReductionRecipe>(&R))
530-
continue;
531-
auto *ExtRed = cast<VPExtendedReductionRecipe>(&R);
532-
auto *Ext = new VPWidenCastRecipe(
533-
ExtRed->getExtOpcode(), ExtRed->getVecOp(), ExtRed->getResultType(),
534-
*ExtRed->getExtInstr());
535-
auto *Red = new VPReductionRecipe(
536-
ExtRed->getRecurrenceDescriptor(), ExtRed->getUnderlyingInstr(),
537-
ExtRed->getChainOp(), Ext, ExtRed->getCondOp(), ExtRed->isOrdered());
538-
Ext->insertBefore(ExtRed);
539-
Red->insertBefore(ExtRed);
540-
ExtRed->replaceAllUsesWith(Red);
541-
ExtRed->eraseFromParent();
528+
if (isa<VPExtendedReductionRecipe>(&R)) {
529+
auto *ExtRed = cast<VPExtendedReductionRecipe>(&R);
530+
auto *Ext = new VPWidenCastRecipe(
531+
ExtRed->getExtOpcode(), ExtRed->getVecOp(), ExtRed->getResultType(),
532+
*ExtRed->getExtInstr());
533+
auto *Red = new VPReductionRecipe(
534+
ExtRed->getRecurrenceDescriptor(), ExtRed->getUnderlyingInstr(),
535+
ExtRed->getChainOp(), Ext, ExtRed->getCondOp(),
536+
ExtRed->isOrdered());
537+
Ext->insertBefore(ExtRed);
538+
Red->insertBefore(ExtRed);
539+
ExtRed->replaceAllUsesWith(Red);
540+
ExtRed->eraseFromParent();
541+
} else if (isa<VPMulAccRecipe>(&R)) {
542+
auto *MulAcc = cast<VPMulAccRecipe>(&R);
543+
VPValue *Op0, *Op1;
544+
if (MulAcc->isExtended()) {
545+
Op0 = new VPWidenCastRecipe(
546+
MulAcc->getExtOpcode(), MulAcc->getVecOp0(),
547+
MulAcc->getResultType(), *MulAcc->getExt0Instr());
548+
Op1 = new VPWidenCastRecipe(
549+
MulAcc->getExtOpcode(), MulAcc->getVecOp1(),
550+
MulAcc->getResultType(), *MulAcc->getExt1Instr());
551+
Op0->getDefiningRecipe()->insertBefore(MulAcc);
552+
Op1->getDefiningRecipe()->insertBefore(MulAcc);
553+
} else {
554+
Op0 = MulAcc->getVecOp0();
555+
Op1 = MulAcc->getVecOp1();
556+
}
557+
Instruction *MulInstr = MulAcc->getMulInstr();
558+
SmallVector<VPValue *, 2> MulOps = {Op0, Op1};
559+
auto *Mul = new VPWidenRecipe(*MulInstr,
560+
make_range(MulOps.begin(), MulOps.end()));
561+
auto *Red = new VPReductionRecipe(
562+
MulAcc->getRecurrenceDescriptor(), MulAcc->getUnderlyingInstr(),
563+
MulAcc->getChainOp(), Mul, MulAcc->getCondOp(),
564+
MulAcc->isOrdered());
565+
Mul->insertBefore(MulAcc);
566+
Red->insertBefore(MulAcc);
567+
MulAcc->replaceAllUsesWith(Red);
568+
MulAcc->eraseFromParent();
569+
}
542570
}
543571
}
544572
}

0 commit comments

Comments
 (0)