Skip to content

Commit 79002b0

Browse files
committed
Create VecOperandInfo
1 parent e15a937 commit 79002b0

File tree

3 files changed

+58
-60
lines changed

3 files changed

+58
-60
lines changed

llvm/lib/Transforms/Vectorize/VPlan.h

Lines changed: 30 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -2689,13 +2689,6 @@ class VPExtendedReductionRecipe : public VPReductionRecipe {
26892689
/// and needs to be lowered to concrete recipes before codegen. The operands are
26902690
/// {ChainOp, VecOp1, VecOp2, [Condition]}.
26912691
class VPMulAccumulateReductionRecipe : public VPReductionRecipe {
2692-
/// Opcodes of the extend recipes.
2693-
Instruction::CastOps ExtOp0;
2694-
Instruction::CastOps ExtOp1;
2695-
2696-
/// Non-neg flags of the extend recipe.
2697-
bool IsNonNeg0 = false;
2698-
bool IsNonNeg1 = false;
26992692

27002693
/// The scalar type after extending.
27012694
Type *ResultTy = nullptr;
@@ -2712,12 +2705,12 @@ class VPMulAccumulateReductionRecipe : public VPReductionRecipe {
27122705
MulAcc->getCondOp(), MulAcc->isOrdered(),
27132706
WrapFlagsTy(MulAcc->hasNoUnsignedWrap(), MulAcc->hasNoSignedWrap()),
27142707
MulAcc->getDebugLoc()),
2715-
ExtOp0(MulAcc->getExt0Opcode()), ExtOp1(MulAcc->getExt1Opcode()),
2716-
IsNonNeg0(MulAcc->isNonNeg0()), IsNonNeg1(MulAcc->isNonNeg1()),
27172708
ResultTy(MulAcc->getResultType()),
27182709
VFScaleFactor(MulAcc->getVFScaleFactor()) {
27192710
transferFlags(*MulAcc);
27202711
setUnderlyingValue(MulAcc->getUnderlyingValue());
2712+
VecOpInfo[0] = MulAcc->getVecOp0Info();
2713+
VecOpInfo[1] = MulAcc->getVecOp1Info();
27212714
}
27222715

27232716
public:
@@ -2731,23 +2724,22 @@ class VPMulAccumulateReductionRecipe : public VPReductionRecipe {
27312724
R->getCondOp(), R->isOrdered(),
27322725
WrapFlagsTy(Mul->hasNoUnsignedWrap(), Mul->hasNoSignedWrap()),
27332726
R->getDebugLoc()),
2734-
ExtOp0(Ext0->getOpcode()), ExtOp1(Ext1->getOpcode()),
2735-
IsNonNeg0(Ext0->hasNonNegFlag() && Ext0->isNonNeg()), IsNonNeg1(Ext1->hasNonNegFlag() && Ext1->isNonNeg()),
2736-
ResultTy(ResultTy),
2737-
VFScaleFactor(ScaleFactor) {
2727+
ResultTy(ResultTy), VFScaleFactor(ScaleFactor) {
27382728
assert(RecurrenceDescriptor::getOpcode(getRecurrenceKind()) ==
27392729
Instruction::Add &&
27402730
"The reduction instruction in MulAccumulateteReductionRecipe must "
27412731
"be Add");
2742-
assert(((ExtOp0 == Instruction::CastOps::ZExt ||
2743-
ExtOp0 == Instruction::CastOps::SExt) && (ExtOp1 == Instruction::CastOps::ZExt || ExtOp1 == Instruction::CastOps::SExt)) &&
2744-
"VPMulAccumulateReductionRecipe only supports zext and sext.");
27452732
setUnderlyingValue(R->getUnderlyingValue());
2746-
// Only set the non-negative flag if the original recipe contains.
2747-
if (Ext0->hasNonNegFlag())
2748-
IsNonNeg0 = Ext0->isNonNeg();
2749-
if (Ext1->hasNonNegFlag())
2750-
IsNonNeg1 = Ext1->isNonNeg();
2733+
// Only set the non-negative flag if the original recipe contains one.
2734+
VecOpInfo[0] = {Ext0->getOpcode(),
2735+
Ext0->hasNonNegFlag() && Ext0->isNonNeg()};
2736+
VecOpInfo[1] = {Ext1->getOpcode(),
2737+
Ext1->hasNonNegFlag() && Ext1->isNonNeg()};
2738+
assert(((Ext0->getOpcode() == Instruction::CastOps::ZExt ||
2739+
Ext0->getOpcode() == Instruction::CastOps::SExt) &&
2740+
(Ext1->getOpcode() == Instruction::CastOps::ZExt ||
2741+
Ext1->getOpcode() == Instruction::CastOps::SExt)) &&
2742+
"VPMulAccumulateReductionRecipe only supports zext and sext.");
27512743
}
27522744

27532745
VPMulAccumulateReductionRecipe(VPReductionRecipe *R, VPWidenRecipe *Mul,
@@ -2758,15 +2750,21 @@ class VPMulAccumulateReductionRecipe : public VPReductionRecipe {
27582750
R->getCondOp(), R->isOrdered(),
27592751
WrapFlagsTy(Mul->hasNoUnsignedWrap(), Mul->hasNoSignedWrap()),
27602752
R->getDebugLoc()),
2761-
ExtOp0(Instruction::CastOps::CastOpsEnd),
2762-
ExtOp1(Instruction::CastOps::CastOpsEnd), ResultTy(ResultTy) {
2753+
ResultTy(ResultTy) {
27632754
assert(RecurrenceDescriptor::getOpcode(getRecurrenceKind()) ==
27642755
Instruction::Add &&
27652756
"The reduction instruction in MulAccumulateReductionRecipe must be "
27662757
"Add");
27672758
setUnderlyingValue(R->getUnderlyingValue());
27682759
}
27692760

2761+
struct VecOperandInfo {
2762+
/// The operand's extend opcode.
2763+
Instruction::CastOps ExtOp{Instruction::CastOps::CastOpsEnd};
2764+
/// Non-neg portion of the operand's flags.
2765+
bool IsNonNeg = false;
2766+
};
2767+
27702768
~VPMulAccumulateReductionRecipe() override = default;
27712769

27722770
VPMulAccumulateReductionRecipe *clone() override {
@@ -2800,30 +2798,22 @@ class VPMulAccumulateReductionRecipe : public VPReductionRecipe {
28002798
VPValue *getVecOp1() const { return getOperand(2); }
28012799

28022800
/// Return true if this recipe contains extended operands.
2803-
bool isExtended() const { return ExtOp0 != Instruction::CastOps::CastOpsEnd; }
2801+
bool isExtended() const {
2802+
return getVecOp0Info().ExtOp != Instruction::CastOps::CastOpsEnd;
2803+
}
28042804

28052805
/// Return true if both operands of the mul instruction come from the same extend.
28062806
bool isSameExtendVal() const { return getVecOp0() == getVecOp1(); }
28072807

2808-
/// Return the opcode of the extends for the operands.
2809-
Instruction::CastOps getExt0Opcode() const { return ExtOp0; }
2810-
Instruction::CastOps getExt1Opcode() const { return ExtOp1; }
2811-
2812-
/// Return if the first extend's opcode is ZExt.
2813-
bool isZExt0() const { return ExtOp0 == Instruction::CastOps::ZExt; }
2814-
2815-
/// Return if the second extend's opcode is ZExt.
2816-
bool isZExt1() const { return ExtOp1 == Instruction::CastOps::ZExt; }
2817-
2818-
/// Return true if the first operand extend has the non-negative flag.
2819-
bool isNonNeg0() const { return IsNonNeg0; }
2820-
2821-
/// Return true if the second operand extend has the non-negative flag.
2822-
bool isNonNeg1() const { return IsNonNeg1; }
2823-
28242808
/// Return the scaling factor that the VF is divided by to form the recipe's
28252809
/// output
28262810
unsigned getVFScaleFactor() const { return VFScaleFactor; }
2811+
2812+
VecOperandInfo getVecOp0Info() const { return VecOpInfo[0]; }
2813+
VecOperandInfo getVecOp1Info() const { return VecOpInfo[1]; }
2814+
2815+
protected:
2816+
VecOperandInfo VecOpInfo[2];
28272817
};
28282818

28292819
/// VPReplicateRecipe replicates a given instruction producing multiple scalar

llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2572,19 +2572,22 @@ VPExtendedReductionRecipe::computeCost(ElementCount VF,
25722572
InstructionCost
25732573
VPMulAccumulateReductionRecipe::computeCost(ElementCount VF,
25742574
VPCostContext &Ctx) const {
2575+
VecOperandInfo Op0Info = getVecOp0Info();
2576+
VecOperandInfo Op1Info = getVecOp1Info();
25752577
if (getVFScaleFactor() > 1) {
25762578
return Ctx.TTI.getPartialReductionCost(
25772579
Instruction::Add, Ctx.Types.inferScalarType(getVecOp0()),
25782580
Ctx.Types.inferScalarType(getVecOp1()), getResultType(), VF,
2579-
TTI::getPartialReductionExtendKind(getExt0Opcode()),
2580-
TTI::getPartialReductionExtendKind(getExt1Opcode()), Instruction::Mul);
2581+
TTI::getPartialReductionExtendKind(Op0Info.ExtOp),
2582+
TTI::getPartialReductionExtendKind(Op1Info.ExtOp), Instruction::Mul);
25812583
}
25822584

25832585
Type *RedTy = Ctx.Types.inferScalarType(this);
25842586
auto *SrcVecTy =
25852587
cast<VectorType>(toVectorTy(Ctx.Types.inferScalarType(getVecOp0()), VF));
2586-
return Ctx.TTI.getMulAccReductionCost(isZExt0(), RedTy, SrcVecTy,
2587-
Ctx.CostKind);
2588+
return Ctx.TTI.getMulAccReductionCost(Op0Info.ExtOp ==
2589+
Instruction::CastOps::ZExt,
2590+
RedTy, SrcVecTy, Ctx.CostKind);
25882591
}
25892592

25902593
#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
@@ -2653,6 +2656,8 @@ void VPExtendedReductionRecipe::print(raw_ostream &O, const Twine &Indent,
26532656

26542657
void VPMulAccumulateReductionRecipe::print(raw_ostream &O, const Twine &Indent,
26552658
VPSlotTracker &SlotTracker) const {
2659+
VecOperandInfo Op0Info = getVecOp0Info();
2660+
VecOperandInfo Op1Info = getVecOp1Info();
26562661
O << Indent << "MULACC-REDUCE ";
26572662
printAsOperand(O, SlotTracker);
26582663
O << " = ";
@@ -2670,14 +2675,14 @@ void VPMulAccumulateReductionRecipe::print(raw_ostream &O, const Twine &Indent,
26702675
O << "(";
26712676
getVecOp0()->printAsOperand(O, SlotTracker);
26722677
if (isExtended()) {
2673-
O << " " << Instruction::getOpcodeName(ExtOp0) << " to " << *getResultType()
2674-
<< "), (";
2678+
O << " " << Instruction::getOpcodeName(Op0Info.ExtOp) << " to "
2679+
<< *getResultType() << "), (";
26752680
} else
26762681
O << ", ";
26772682
getVecOp1()->printAsOperand(O, SlotTracker);
26782683
if (isExtended()) {
2679-
O << " " << Instruction::getOpcodeName(ExtOp1) << " to " << *getResultType()
2680-
<< ")";
2684+
O << " " << Instruction::getOpcodeName(Op1Info.ExtOp) << " to "
2685+
<< *getResultType() << ")";
26812686
}
26822687
if (isConditional()) {
26832688
O << ", ";

llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp

Lines changed: 15 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
#include "llvm/Analysis/InstSimplifyFolder.h"
3131
#include "llvm/Analysis/LoopInfo.h"
3232
#include "llvm/Analysis/VectorUtils.h"
33+
#include "llvm/IR/Instruction.h"
3334
#include "llvm/IR/Intrinsics.h"
3435
#include "llvm/IR/PatternMatch.h"
3536
#include "llvm/Support/Casting.h"
@@ -2545,29 +2546,31 @@ expandVPMulAccumulateReduction(VPMulAccumulateReductionRecipe *MulAcc) {
25452546
// reduce.add(ext(mul(ext, ext))) to reduce.add(mul(ext, ext)).
25462547
VPValue *Op0, *Op1;
25472548
if (MulAcc->isExtended()) {
2549+
VPMulAccumulateReductionRecipe::VecOperandInfo Op0Info =
2550+
MulAcc->getVecOp0Info();
2551+
VPMulAccumulateReductionRecipe::VecOperandInfo Op1Info =
2552+
MulAcc->getVecOp1Info();
25482553
Type *RedTy = MulAcc->getResultType();
2549-
if (MulAcc->isZExt0())
2550-
Op0 = new VPWidenCastRecipe(MulAcc->getExt0Opcode(), MulAcc->getVecOp0(),
2551-
RedTy, VPIRFlags::NonNegFlagsTy(MulAcc->isNonNeg0()),
2554+
if (Op0Info.ExtOp == Instruction::CastOps::ZExt)
2555+
Op0 = new VPWidenCastRecipe(Op0Info.ExtOp, MulAcc->getVecOp0(), RedTy,
2556+
VPIRFlags::NonNegFlagsTy(Op0Info.IsNonNeg),
25522557
MulAcc->getDebugLoc());
25532558
else
2554-
Op0 = new VPWidenCastRecipe(MulAcc->getExt0Opcode(), MulAcc->getVecOp0(),
2555-
RedTy, {}, MulAcc->getDebugLoc());
2559+
Op0 = new VPWidenCastRecipe(Op0Info.ExtOp, MulAcc->getVecOp0(), RedTy, {},
2560+
MulAcc->getDebugLoc());
25562561
Op0->getDefiningRecipe()->insertBefore(MulAcc);
25572562
// Prevent reduce.add(mul(ext(A), ext(A))) from generating a duplicate
25582563
// VPWidenCastRecipe.
25592564
if (MulAcc->getVecOp0() == MulAcc->getVecOp1()) {
25602565
Op1 = Op0;
25612566
} else {
2562-
if (MulAcc->isZExt1())
2563-
Op1 = new VPWidenCastRecipe(MulAcc->getExt1Opcode(),
2564-
MulAcc->getVecOp1(), RedTy,
2565-
VPIRFlags::NonNegFlagsTy(MulAcc->isNonNeg1()),
2567+
if (Op1Info.ExtOp == Instruction::CastOps::ZExt)
2568+
Op1 = new VPWidenCastRecipe(Op1Info.ExtOp, MulAcc->getVecOp1(), RedTy,
2569+
VPIRFlags::NonNegFlagsTy(Op1Info.IsNonNeg),
25662570
MulAcc->getDebugLoc());
25672571
else
2568-
Op1 =
2569-
new VPWidenCastRecipe(MulAcc->getExt1Opcode(), MulAcc->getVecOp1(),
2570-
RedTy, {}, MulAcc->getDebugLoc());
2572+
Op1 = new VPWidenCastRecipe(Op1Info.ExtOp, MulAcc->getVecOp1(), RedTy,
2573+
{}, MulAcc->getDebugLoc());
25712574
Op1->getDefiningRecipe()->insertBefore(MulAcc);
25722575
}
25732576
} else {

0 commit comments

Comments
 (0)