[SDAG] Add partial_reduce_sumla node #141267
We have recently added the partial_reduce_smla and partial_reduce_umla nodes to represent Acc += ext(a) * ext(b), where the two extends must have the same source type and the same extend kind. For riscv64 w/zvqdotq, we have the vqdot and vqdotu instructions, which correspond to the existing nodes, but we also have vqdotsu, which covers the case where the two extends are sign and zero extends, respectively (i.e. not the same kind of extend). This patch adds a partial_reduce_sumla node, which sign-extends A and zero-extends B. The addition is somewhat mechanical, except that it exposes an implementation challenge: AArch64 doesn't have an analogous instruction (that I've found), and the current legalization table assumes that all of the partial_reduce_*mla variants have the same handling for a given type pair. Questions to the AArch64 folks:

- Does aarch64 have a good implementation for this that I missed?
- If not, are you okay with my somewhat hacky custom legalization approach (in this patch)? It does look like there are some small regressions here, but I haven't dug into why.
- If not, any suggestions on how to structure splitting the legalization table? I could add the opcode to the table key; that's probably the easiest.
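For reference, here is a minimal scalar model of the new node's semantics. It is only a sketch: the function name, the fixed 16 x i8 into 4 x i32 shape, and the sequential lane grouping are illustrative assumptions, not part of the patch.

```cpp
#include <array>
#include <cstdint>

// Scalar sketch of PARTIAL_REDUCE_SUMLA(Acc, A, B): each product
// sext(A[i]) * zext(B[i]) is add-reduced into the accumulator. The lane
// grouping used here (sequential, I / 4) is one valid choice; the node
// only pins down the sums, not which input lanes feed which output lane.
std::array<int32_t, 4> partialReduceSumla(std::array<int32_t, 4> Acc,
                                          std::array<int8_t, 16> A,
                                          std::array<uint8_t, 16> B) {
  for (int I = 0; I < 16; ++I)
    Acc[I / 4] += int32_t(A[I]) * int32_t(B[I]); // sext * zext
  return Acc;
}
```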
@llvm/pr-subscribers-backend-risc-v @llvm/pr-subscribers-backend-aarch64

Author: Philip Reames (preames)
Patch is 36.70 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/141267.diff 12 Files Affected:
diff --git a/llvm/include/llvm/CodeGen/ISDOpcodes.h b/llvm/include/llvm/CodeGen/ISDOpcodes.h
index 9f66402e4c820..848631c7ffb03 100644
--- a/llvm/include/llvm/CodeGen/ISDOpcodes.h
+++ b/llvm/include/llvm/CodeGen/ISDOpcodes.h
@@ -1484,8 +1484,9 @@ enum NodeType {
VECREDUCE_UMIN,
// PARTIAL_REDUCE_[U|S]MLA(Accumulator, Input1, Input2)
- // The partial reduction nodes sign or zero extend Input1 and Input2 to the
- // element type of Accumulator before multiplying their results.
+ // The partial reduction nodes sign or zero extend Input1 and Input2
+ // (with the extension kind noted below) to the element type of
+ // Accumulator before multiplying their results.
// This result is concatenated to the Accumulator, and this is then reduced,
// using addition, to the result type.
// The output is only expected to either be given to another partial reduction
@@ -1497,8 +1498,9 @@ enum NodeType {
// multiple of the number of elements in the Accumulator / output type.
// Input1 and Input2 must have an element type which is the same as or smaller
// than the element type of the Accumulator and output.
- PARTIAL_REDUCE_SMLA,
- PARTIAL_REDUCE_UMLA,
+ PARTIAL_REDUCE_SMLA, // sext, sext
+ PARTIAL_REDUCE_UMLA, // zext, zext
+ PARTIAL_REDUCE_SUMLA, // sext, zext
// The `llvm.experimental.stackmap` intrinsic.
// Operands: input chain, glue, <id>, <numShadowBytes>, [live0[, live1...]]
diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index efaa8bd4a7950..df6702c390fc7 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -1991,6 +1991,7 @@ SDValue DAGCombiner::visit(SDNode *N) {
case ISD::EXPERIMENTAL_VECTOR_HISTOGRAM: return visitMHISTOGRAM(N);
case ISD::PARTIAL_REDUCE_SMLA:
case ISD::PARTIAL_REDUCE_UMLA:
+ case ISD::PARTIAL_REDUCE_SUMLA:
return visitPARTIAL_REDUCE_MLA(N);
case ISD::VECTOR_COMPRESS: return visitVECTOR_COMPRESS(N);
case ISD::LIFETIME_END: return visitLIFETIME_END(N);
@@ -12675,19 +12676,19 @@ SDValue DAGCombiner::foldPartialReduceMLAMulOp(SDNode *N) {
TLI.getTypeToTransformTo(*Context, LHSExtOpVT)))
return SDValue();
- bool ExtIsSigned = LHSOpcode == ISD::SIGN_EXTEND;
- unsigned NewOpcode =
- ExtIsSigned ? ISD::PARTIAL_REDUCE_SMLA : ISD::PARTIAL_REDUCE_UMLA;
-
// partial_reduce_*mla(acc, mul(ext(x), splat(C)), splat(1))
// -> partial_reduce_*mla(acc, x, C)
if (ISD::isConstantSplatVector(RHS.getNode(), C)) {
+ // TODO: Make use of partial_reduce_sumla here
APInt CTrunc = C.trunc(LHSExtOpVT.getScalarSizeInBits());
unsigned LHSBits = LHS.getValueType().getScalarSizeInBits();
if ((LHSOpcode != ISD::ZERO_EXTEND || CTrunc.zext(LHSBits) != C) &&
(LHSOpcode != ISD::SIGN_EXTEND || CTrunc.sext(LHSBits) != C))
return SDValue();
+ unsigned NewOpcode = LHSOpcode == ISD::SIGN_EXTEND
+ ? ISD::PARTIAL_REDUCE_SMLA
+ : ISD::PARTIAL_REDUCE_UMLA;
return DAG.getNode(NewOpcode, DL, N->getValueType(0), Acc, LHSExtOp,
DAG.getConstant(CTrunc, DL, LHSExtOpVT));
}
@@ -12697,26 +12698,46 @@ SDValue DAGCombiner::foldPartialReduceMLAMulOp(SDNode *N) {
return SDValue();
SDValue RHSExtOp = RHS->getOperand(0);
- if (LHSExtOpVT != RHSExtOp.getValueType() || LHSOpcode != RHSOpcode)
+ if (LHSExtOpVT != RHSExtOp.getValueType())
return SDValue();
- // For a 2-stage extend the signedness of both of the extends must be the
- // same. This is so the node can be folded into only a signed or unsigned
- // node.
- bool NodeIsSigned = N->getOpcode() == ISD::PARTIAL_REDUCE_SMLA;
+ unsigned NewOpc = ISD::PARTIAL_REDUCE_SMLA;
+ // For a 2-stage extend the signedness of both of the extends must match
+ // If the mul has the same type, there is no outer extend, and thus we
+ // can simply use the inner extends to pick the result node.
EVT AccElemVT = Acc.getValueType().getVectorElementType();
- if (ExtIsSigned != NodeIsSigned &&
- Op1.getValueType().getVectorElementType() != AccElemVT)
- return SDValue();
-
- return DAG.getNode(NewOpcode, DL, N->getValueType(0), Acc, LHSExtOp,
- RHSExtOp);
+ if (Op1.getValueType().getVectorElementType() != AccElemVT) {
+ // TODO: Split this into canonicalization rules
+ if (LHSOpcode == ISD::SIGN_EXTEND && RHSOpcode == ISD::SIGN_EXTEND &&
+ (N->getOpcode() == ISD::PARTIAL_REDUCE_SMLA ||
+ N->getOpcode() == ISD::PARTIAL_REDUCE_SUMLA))
+ NewOpc = ISD::PARTIAL_REDUCE_SMLA;
+ else if (LHSOpcode == ISD::ZERO_EXTEND && RHSOpcode == ISD::ZERO_EXTEND &&
+ N->getOpcode() == ISD::PARTIAL_REDUCE_UMLA)
+ NewOpc = ISD::PARTIAL_REDUCE_UMLA;
+ else
+ return SDValue();
+ } else {
+ // TODO: Add canonicalization rule
+ if (LHSOpcode == ISD::SIGN_EXTEND && RHSOpcode == ISD::SIGN_EXTEND)
+ NewOpc = ISD::PARTIAL_REDUCE_SMLA;
+ else if (LHSOpcode == ISD::ZERO_EXTEND && RHSOpcode == ISD::ZERO_EXTEND)
+ NewOpc = ISD::PARTIAL_REDUCE_UMLA;
+ else if (LHSOpcode == ISD::SIGN_EXTEND && RHSOpcode == ISD::ZERO_EXTEND)
+ NewOpc = ISD::PARTIAL_REDUCE_SUMLA;
+ else
+ // TODO: Handle the swapped sumla case here
+ return SDValue();
+ }
+ return DAG.getNode(NewOpc, DL, N->getValueType(0), Acc, LHSExtOp, RHSExtOp);
}
// partial.reduce.umla(acc, zext(op), splat(1))
// -> partial.reduce.umla(acc, op, splat(trunc(1)))
// partial.reduce.smla(acc, sext(op), splat(1))
// -> partial.reduce.smla(acc, op, splat(trunc(1)))
+// partial.reduce.sumla(acc, sext(op), splat(1))
+// -> partial.reduce.smla(acc, op, splat(trunc(1)))
SDValue DAGCombiner::foldPartialReduceAdd(SDNode *N) {
SDLoc DL(N);
SDValue Acc = N->getOperand(0);
@@ -12738,7 +12759,7 @@ SDValue DAGCombiner::foldPartialReduceAdd(SDNode *N) {
return SDValue();
bool Op1IsSigned = Op1Opcode == ISD::SIGN_EXTEND;
- bool NodeIsSigned = N->getOpcode() == ISD::PARTIAL_REDUCE_SMLA;
+ bool NodeIsSigned = N->getOpcode() != ISD::PARTIAL_REDUCE_UMLA;
EVT AccElemVT = Acc.getValueType().getVectorElementType();
if (Op1IsSigned != NodeIsSigned &&
Op1.getValueType().getVectorElementType() != AccElemVT)
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp
index 90af5f2cd8e70..5eb2f8c9150e9 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp
@@ -166,6 +166,7 @@ void DAGTypeLegalizer::PromoteIntegerResult(SDNode *N, unsigned ResNo) {
case ISD::PARTIAL_REDUCE_UMLA:
case ISD::PARTIAL_REDUCE_SMLA:
+ case ISD::PARTIAL_REDUCE_SUMLA:
Res = PromoteIntRes_PARTIAL_REDUCE_MLA(N);
break;
@@ -2090,6 +2091,7 @@ bool DAGTypeLegalizer::PromoteIntegerOperand(SDNode *N, unsigned OpNo) {
break;
case ISD::PARTIAL_REDUCE_UMLA:
case ISD::PARTIAL_REDUCE_SMLA:
+ case ISD::PARTIAL_REDUCE_SUMLA:
Res = PromoteIntOp_PARTIAL_REDUCE_MLA(N);
break;
}
@@ -2876,12 +2878,21 @@ SDValue DAGTypeLegalizer::PromoteIntOp_VECTOR_FIND_LAST_ACTIVE(SDNode *N,
SDValue DAGTypeLegalizer::PromoteIntOp_PARTIAL_REDUCE_MLA(SDNode *N) {
SmallVector<SDValue, 1> NewOps(N->ops());
- if (N->getOpcode() == ISD::PARTIAL_REDUCE_SMLA) {
+ switch (N->getOpcode()) {
+ case ISD::PARTIAL_REDUCE_SMLA:
NewOps[1] = SExtPromotedInteger(N->getOperand(1));
NewOps[2] = SExtPromotedInteger(N->getOperand(2));
- } else {
+ break;
+ case ISD::PARTIAL_REDUCE_UMLA:
NewOps[1] = ZExtPromotedInteger(N->getOperand(1));
NewOps[2] = ZExtPromotedInteger(N->getOperand(2));
+ break;
+ case ISD::PARTIAL_REDUCE_SUMLA:
+ NewOps[1] = SExtPromotedInteger(N->getOperand(1));
+ NewOps[2] = ZExtPromotedInteger(N->getOperand(2));
+ break;
+ default:
+ llvm_unreachable("unexpected opcode");
}
return SDValue(DAG.UpdateNodeOperands(N, NewOps), 0);
}
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp
index affcd78ea61b0..4a12b76851966 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp
@@ -530,6 +530,7 @@ SDValue VectorLegalizer::LegalizeOp(SDValue Op) {
}
case ISD::PARTIAL_REDUCE_UMLA:
case ISD::PARTIAL_REDUCE_SMLA:
+ case ISD::PARTIAL_REDUCE_SUMLA:
Action = TLI.getPartialReduceMLAAction(Node->getValueType(0),
Node->getOperand(1).getValueType());
break;
@@ -1210,6 +1211,7 @@ void VectorLegalizer::Expand(SDNode *Node, SmallVectorImpl<SDValue> &Results) {
return;
case ISD::PARTIAL_REDUCE_UMLA:
case ISD::PARTIAL_REDUCE_SMLA:
+ case ISD::PARTIAL_REDUCE_SUMLA:
Results.push_back(TLI.expandPartialReduceMLA(Node, DAG));
return;
case ISD::VECREDUCE_SEQ_FADD:
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
index c011a0a61d698..d3200b38c350e 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
@@ -1387,6 +1387,7 @@ void DAGTypeLegalizer::SplitVectorResult(SDNode *N, unsigned ResNo) {
break;
case ISD::PARTIAL_REDUCE_UMLA:
case ISD::PARTIAL_REDUCE_SMLA:
+ case ISD::PARTIAL_REDUCE_SUMLA:
SplitVecRes_PARTIAL_REDUCE_MLA(N, Lo, Hi);
break;
}
@@ -3454,6 +3455,7 @@ bool DAGTypeLegalizer::SplitVectorOperand(SDNode *N, unsigned OpNo) {
break;
case ISD::PARTIAL_REDUCE_UMLA:
case ISD::PARTIAL_REDUCE_SMLA:
+ case ISD::PARTIAL_REDUCE_SUMLA:
Res = SplitVecOp_PARTIAL_REDUCE_MLA(N);
break;
}
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
index 5400f3eaf373d..b8288af53de1e 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
@@ -7967,7 +7967,8 @@ SDValue SelectionDAG::getNode(unsigned Opcode, const SDLoc &DL, EVT VT,
break;
}
case ISD::PARTIAL_REDUCE_UMLA:
- case ISD::PARTIAL_REDUCE_SMLA: {
+ case ISD::PARTIAL_REDUCE_SMLA:
+ case ISD::PARTIAL_REDUCE_SUMLA: {
[[maybe_unused]] EVT AccVT = N1.getValueType();
[[maybe_unused]] EVT Input1VT = N2.getValueType();
[[maybe_unused]] EVT Input2VT = N3.getValueType();
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp
index 803894e298dd5..4fd0b7fd873e6 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp
@@ -584,6 +584,8 @@ std::string SDNode::getOperationName(const SelectionDAG *G) const {
return "partial_reduce_umla";
case ISD::PARTIAL_REDUCE_SMLA:
return "partial_reduce_smla";
+ case ISD::PARTIAL_REDUCE_SUMLA:
+ return "partial_reduce_sumla";
// Vector Predication
#define BEGIN_REGISTER_VP_SDNODE(SDID, LEGALARG, NAME, ...) \
diff --git a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
index 75c9bbaec7603..4d627af0e9bca 100644
--- a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
@@ -11887,13 +11887,23 @@ SDValue TargetLowering::expandPartialReduceMLA(SDNode *N,
EVT ExtMulOpVT =
EVT::getVectorVT(*DAG.getContext(), AccVT.getVectorElementType(),
MulOpVT.getVectorElementCount());
- unsigned ExtOpc = N->getOpcode() == ISD::PARTIAL_REDUCE_SMLA
- ? ISD::SIGN_EXTEND
- : ISD::ZERO_EXTEND;
-
if (ExtMulOpVT != MulOpVT) {
- MulLHS = DAG.getNode(ExtOpc, DL, ExtMulOpVT, MulLHS);
- MulRHS = DAG.getNode(ExtOpc, DL, ExtMulOpVT, MulRHS);
+ switch (N->getOpcode()) {
+ case ISD::PARTIAL_REDUCE_SMLA:
+ MulLHS = DAG.getNode(ISD::SIGN_EXTEND, DL, ExtMulOpVT, MulLHS);
+ MulRHS = DAG.getNode(ISD::SIGN_EXTEND, DL, ExtMulOpVT, MulRHS);
+ break;
+ case ISD::PARTIAL_REDUCE_UMLA:
+ MulLHS = DAG.getNode(ISD::ZERO_EXTEND, DL, ExtMulOpVT, MulLHS);
+ MulRHS = DAG.getNode(ISD::ZERO_EXTEND, DL, ExtMulOpVT, MulRHS);
+ break;
+ case ISD::PARTIAL_REDUCE_SUMLA:
+ MulLHS = DAG.getNode(ISD::SIGN_EXTEND, DL, ExtMulOpVT, MulLHS);
+ MulRHS = DAG.getNode(ISD::ZERO_EXTEND, DL, ExtMulOpVT, MulRHS);
+ break;
+ default:
+ llvm_unreachable("unexpected opcode");
+ }
}
SDValue Input = MulLHS;
APInt ConstantOne;
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index b7f0bcfd015bc..f3e8a6974c25f 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -1874,8 +1874,8 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
if (EnablePartialReduceNodes && Subtarget->isSVEorStreamingSVEAvailable()) {
// Mark known legal pairs as 'Legal' (these will expand to UDOT or SDOT).
// Other pairs will default to 'Expand'.
- setPartialReduceMLAAction(MVT::nxv2i64, MVT::nxv8i16, Legal);
- setPartialReduceMLAAction(MVT::nxv4i32, MVT::nxv16i8, Legal);
+ setPartialReduceMLAAction(MVT::nxv2i64, MVT::nxv8i16, Custom);
+ setPartialReduceMLAAction(MVT::nxv4i32, MVT::nxv16i8, Custom);
setPartialReduceMLAAction(MVT::nxv2i64, MVT::nxv16i8, Custom);
}
@@ -7745,6 +7745,7 @@ SDValue AArch64TargetLowering::LowerOperation(SDValue Op,
return LowerVECTOR_HISTOGRAM(Op, DAG);
case ISD::PARTIAL_REDUCE_SMLA:
case ISD::PARTIAL_REDUCE_UMLA:
+ case ISD::PARTIAL_REDUCE_SUMLA:
return LowerPARTIAL_REDUCE_MLA(Op, DAG);
}
}
@@ -29532,13 +29533,24 @@ SDValue AArch64TargetLowering::LowerVECTOR_HISTOGRAM(SDValue Op,
SDValue
AArch64TargetLowering::LowerPARTIAL_REDUCE_MLA(SDValue Op,
SelectionDAG &DAG) const {
+ // No support for sumla forms, let generic legalization handle them
+ if (Op->getOpcode() == ISD::PARTIAL_REDUCE_SUMLA)
+ return SDValue();
+
SDLoc DL(Op);
SDValue Acc = Op.getOperand(0);
SDValue LHS = Op.getOperand(1);
SDValue RHS = Op.getOperand(2);
EVT ResultVT = Op.getValueType();
- assert(ResultVT == MVT::nxv2i64 && LHS.getValueType() == MVT::nxv16i8);
+ EVT OpVT = LHS.getValueType();
+
+ // These two are legal...
+ if ((ResultVT == MVT::nxv2i64 && OpVT == MVT::nxv8i16) ||
+ (ResultVT == MVT::nxv4i32 && OpVT == MVT::nxv16i8))
+ return Op;
+
+ assert(ResultVT == MVT::nxv2i64 && OpVT == MVT::nxv16i8);
SDValue DotNode = DAG.getNode(Op.getOpcode(), DL, MVT::nxv4i32,
DAG.getConstant(0, DL, MVT::nxv4i32), LHS, RHS);
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index 476596e4e0104..5622b68475305 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -8240,6 +8240,7 @@ SDValue RISCVTargetLowering::LowerOperation(SDValue Op,
return lowerADJUST_TRAMPOLINE(Op, DAG);
case ISD::PARTIAL_REDUCE_UMLA:
case ISD::PARTIAL_REDUCE_SMLA:
+ case ISD::PARTIAL_REDUCE_SUMLA:
return lowerPARTIAL_REDUCE_MLA(Op, DAG);
}
}
@@ -8391,8 +8392,20 @@ SDValue RISCVTargetLowering::lowerPARTIAL_REDUCE_MLA(SDValue Op,
SDValue B = Op.getOperand(2);
assert(A.getSimpleValueType() == B.getSimpleValueType() &&
A.getSimpleValueType().getVectorElementType() == MVT::i8);
- bool IsSigned = Op.getOpcode() == ISD::PARTIAL_REDUCE_SMLA;
- unsigned Opc = IsSigned ? RISCVISD::VQDOT_VL : RISCVISD::VQDOTU_VL;
+ unsigned Opc;
+ switch (Op.getOpcode()) {
+ case ISD::PARTIAL_REDUCE_SMLA:
+ Opc = RISCVISD::VQDOT_VL;
+ break;
+ case ISD::PARTIAL_REDUCE_UMLA:
+ Opc = RISCVISD::VQDOTU_VL;
+ break;
+ case ISD::PARTIAL_REDUCE_SUMLA:
+ Opc = RISCVISD::VQDOTSU_VL;
+ break;
+ default:
+ llvm_unreachable("Unexpected opcode");
+ }
auto [Mask, VL] = getDefaultScalableVLOps(VT, DL, DAG, Subtarget);
return DAG.getNode(Opc, DL, VT, {A, B, Accum, Mask, VL});
}
diff --git a/llvm/test/CodeGen/AArch64/sve-partial-reduce-dot-product.ll b/llvm/test/CodeGen/AArch64/sve-partial-reduce-dot-product.ll
index 5bc9a101b1e44..35b19eee5d983 100644
--- a/llvm/test/CodeGen/AArch64/sve-partial-reduce-dot-product.ll
+++ b/llvm/test/CodeGen/AArch64/sve-partial-reduce-dot-product.ll
@@ -159,26 +159,71 @@ define <vscale x 4 x i32> @sudot(<vscale x 4 x i32> %acc, <vscale x 16 x i8> %a,
; CHECK-NOI8MM-NEXT: mla z0.s, p0/m, z1.s, z2.s
; CHECK-NOI8MM-NEXT: ret
;
-; CHECK-NEWLOWERING-LABEL: sudot:
-; CHECK-NEWLOWERING: // %bb.0: // %entry
-; CHECK-NEWLOWERING-NEXT: sunpklo z3.h, z1.b
-; CHECK-NEWLOWERING-NEXT: uunpklo z4.h, z2.b
-; CHECK-NEWLOWERING-NEXT: ptrue p0.s
-; CHECK-NEWLOWERING-NEXT: sunpkhi z1.h, z1.b
-; CHECK-NEWLOWERING-NEXT: uunpkhi z2.h, z2.b
-; CHECK-NEWLOWERING-NEXT: sunpklo z5.s, z3.h
-; CHECK-NEWLOWERING-NEXT: uunpklo z6.s, z4.h
-; CHECK-NEWLOWERING-NEXT: sunpkhi z3.s, z3.h
-; CHECK-NEWLOWERING-NEXT: uunpkhi z4.s, z4.h
-; CHECK-NEWLOWERING-NEXT: mla z0.s, p0/m, z5.s, z6.s
-; CHECK-NEWLOWERING-NEXT: sunpklo z5.s, z1.h
-; CHECK-NEWLOWERING-NEXT: uunpklo z6.s, z2.h
-; CHECK-NEWLOWERING-NEXT: sunpkhi z1.s, z1.h
-; CHECK-NEWLOWERING-NEXT: uunpkhi z2.s, z2.h
-; CHECK-NEWLOWERING-NEXT: mla z0.s, p0/m, z3.s, z4.s
-; CHECK-NEWLOWERING-NEXT: mla z0.s, p0/m, z5.s, z6.s
-; CHECK-NEWLOWERING-NEXT: mla z0.s, p0/m, z1.s, z2.s
-; CHECK-NEWLOWERING-NEXT: ret
+; CHECK-NEWLOWERING-SVE-LABEL: sudot:
+; CHECK-NEWLOWERING-SVE: // %bb.0: // %entry
+; CHECK-NEWLOWERING-SVE-NEXT: uunpklo z3.h, z2.b
+; CHECK-NEWLOWERING-SVE-NEXT: sunpklo z4.h, z1.b
+; CHECK-NEWLOWERING-SVE-NEXT: uunpkhi z2.h, z2.b
+; CHECK-NEWLOWERING-SVE-NEXT: sunpkhi z1.h, z1.b
+; CHECK-NEWLOWERING-SVE-NEXT: ptrue p0.s
+; CHECK-NEWLOWERING-SVE-NEXT: uunpklo z5.s, z3.h
+; CHECK-NEWLOWERING-SVE-NEXT: sunpklo z6.s, z4.h
+; CHECK-NEWLOWERING-SVE-NEXT: uunpkhi z3.s, z3.h
+; CHECK-NEWLOWERING-SVE-NEXT: sunpkhi z4.s, z4.h
+; CHECK-NEWLOWERING-SVE-NEXT: mla z0.s, p0/m, z6.s, z5.s
+; CHECK-NEWLOWERING-SVE-NEXT: uunpkhi z5.s, z2.h
+; CHECK-NEWLOWERING-SVE-NEXT: sunpkhi z6.s, z1.h
+; CHECK-NEWLOWERING-SVE-NEXT: mul z3.s, p0/m, z3.s, z4.s
+; CHECK-NEWLOWERING-SVE-NEXT: uunpklo z2.s, z2.h
+; CHECK-NEWLOWERING-SVE-NEXT: sunpklo z1.s, z1.h
+; CHECK-NEWLOWERING-SVE-NEXT: mla z0.s, p0/m, z6.s, z5.s
+; CHECK-NEWLOWERING-SVE-NEXT: mad z1.s, p0/m, z2.s, z3.s
+; CHECK-NEWLOWERING-SVE-NEXT: add z0.s, z1.s, z0.s
+; CHECK-NEWLOWERING-SVE-NEXT: ret
+;
+; CHECK-NEWLOWERING-SVE2-LABEL: sudot:
+; CHECK-NEWLOWERING-SVE2: // %bb.0: // %entry
+; CHECK-NEWLOWERING-SVE2-NEXT: uunpklo z3.h, z2.b
+; CHECK-NEWLOWERING-SVE2-NEXT: sunpklo z4.h, z1.b
+; CHECK-NEWLOWERING-SVE2-NEXT: uunpkhi z2.h, z2.b
+; CHECK-NEWLOWERING-SVE2-NEXT: sunpkhi z1.h, z1.b
+; CHECK-NEWLOWERING-SVE2-NEXT: ptrue p0.s
+; CHECK-NEWLOWERING-SVE2-NEXT: uunpklo z5.s, z3.h
+; CHECK-NEWLOWERING-SVE2-NEXT: sunpklo z6.s, z4.h
+; CHECK-NEWLOWERING-SVE2-NEXT: uunpkhi z3.s, z3.h
+; CHECK-NEWLOWERING-SVE2-NEXT: sunpkhi z4.s, z4.h
+; CHECK-NEWLOWERING-SVE2-NEXT: mla z0.s, p0/m, z6.s, z5.s
+; CHECK-NEWLOWERING-SVE2-NEXT: uunpkhi z5.s, z2.h
+; CHECK-NEWLOWERING-SVE2-NEXT: sunpkhi z6.s, z1.h
+; CHECK-NEWLOWERING-SVE2-NEXT: mul z3.s, z4.s, z3.s
+; CHECK-NEWLOWERING-SVE2-NEXT: uunpklo z2.s, z2.h
+; CHECK-NEWLOWERING-SVE2-NEXT: sunpklo z1.s, z1.h
+; CHECK-NEWLOWERING-SVE2-NEXT: mla z0.s, p0/m, z6.s, z5.s
+; CHECK-NEWLOWERING-SVE2-NEXT: mad z1.s, p0/m, z2.s, z3.s
+; CHECK-NEWLOWERING-SVE2-NEXT: add z0.s, z1.s, z0.s
+; CHECK-NEWLOWERING-SVE2-NEXT: ret
+;
+; CHECK-NEWLOWERING-SME-LABEL: sudot:
+; CHECK-NEWLOWERING-SME: // %bb.0: // %entry
+; CHECK-NEWLOWERING-SME-NEXT: uunpklo z3.h, z2.b
+; CHECK-NEWLOWERING-SME-NEXT: sunpklo z4.h, z1.b
+; CHECK-NEWLOWERING-SME-NEXT: uunpkhi z2.h, z2.b
+; CHECK-NEWLOWERING-SME-NEXT: sunpkhi z1.h, z1.b
+; CHECK-NEWLOWERING-SME-NEXT: ptrue p0.s
+; CHECK-NEWLOWERING-SME-NEXT: uunpklo z5.s, z3.h
+; CHECK-NEWLOWERING-SME-NEXT: su...
[truncated]
> Questions to the AArch64 folks:
>
> - Does aarch64 have a good implementation for this that I missed?

There is the USDOT instruction, which does a dot-product by zero-extending the LHS and sign-extending the RHS.

> - If not, any suggestions on how to structure splitting the legalization table? I could add the opcode to the table key; that's probably the easiest.

For AArch64 there is (zext, zext), (sext, sext), (zext, sext) (and, because the operation is commutative, (sext, zext) as well). We need to encode this in the table somehow, because a different set of types is supported for the mixed extends.
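Since the product is commutative, the fourth combination does not need its own opcode; a combine can canonicalize it onto the sumla form. A hypothetical sketch, reusing the variable names from the DAGCombiner hunk above (this corresponds to the "swapped sumla case" TODO in the patch):

```cpp
// zext(a) * sext(b) == sext(b) * zext(a), so swap the operands and emit
// the single mixed-extend opcode rather than adding a fourth node type.
if (LHSOpcode == ISD::ZERO_EXTEND && RHSOpcode == ISD::SIGN_EXTEND) {
  std::swap(LHSExtOp, RHSExtOp); // now (sext, zext)
  NewOpc = ISD::PARTIAL_REDUCE_SUMLA;
}
```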
@@ -1874,8 +1874,8 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
   if (EnablePartialReduceNodes && Subtarget->isSVEorStreamingSVEAvailable()) {
     // Mark known legal pairs as 'Legal' (these will expand to UDOT or SDOT).
     // Other pairs will default to 'Expand'.
-    setPartialReduceMLAAction(MVT::nxv2i64, MVT::nxv8i16, Legal);
-    setPartialReduceMLAAction(MVT::nxv4i32, MVT::nxv16i8, Legal);
+    setPartialReduceMLAAction(MVT::nxv2i64, MVT::nxv8i16, Custom);
Can we change this interface (and internal table) to include the kind of extension being done? (signed/unsigned or mixed)
Just throwing this out there, but perhaps change it to setReduceAction(Opcode, ResultType, OperandType), to make it easier to add more reductions in the future? This might also be a way to relax the result type requirements of the current VECREDUCE_*** nodes.
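As a concrete shape for that suggestion, registrations might look like the following. Everything here is hypothetical: setReduceAction is the proposed name from the comment above, not an existing TargetLowering method, and the stub enums stand in for the real ISD/MVT types.

```cpp
#include <cstdint>

// Stub types so the sketch is self-contained; real code would use ISD
// opcodes, MVTs, and TargetLoweringBase::LegalizeAction instead.
enum LegalizeAction : uint8_t { Legal, Custom, Expand };
enum SimpleVT : uint8_t { nxv4i32, nxv16i8 };
enum ReduceOpc { PARTIAL_REDUCE_SMLA, PARTIAL_REDUCE_UMLA, PARTIAL_REDUCE_SUMLA };

// Hypothetical table setter keyed by opcode as well as the type pair.
void setReduceAction(ReduceOpc Opc, SimpleVT ResVT, SimpleVT OpVT,
                     LegalizeAction A);

void registerPartialReduceActions(bool HasUSDOT) {
  // Same-extend forms lower to SDOT/UDOT, so they can be Legal...
  setReduceAction(PARTIAL_REDUCE_SMLA, nxv4i32, nxv16i8, Legal);
  setReduceAction(PARTIAL_REDUCE_UMLA, nxv4i32, nxv16i8, Legal);
  // ...while the mixed form is only Legal when USDOT (I8MM) is available.
  setReduceAction(PARTIAL_REDUCE_SUMLA, nxv4i32, nxv16i8,
                  HasUSDOT ? Legal : Expand);
}
```

Keying the table by opcode like this is essentially the direction the follow-up patch mentioned below takes.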
On its own, this change should be non-functional. This is a preparatory change for llvm#141267 which adds a new form of PARTIAL_REDUCE_*MLA. As noted in the discussion on that review, AArch64 needs a different set of legal and custom types for the PARTIAL_REDUCE_SUMLA variant than the currently existing PARTIAL_REDUCE_UMLA/SMLA.
I've posted #141970 which will split the legalize table by opcode. Once this lands, I'll rebase. My plan is to leave the AArch64 cases in the default (non-legal) state, and focus only on the RISCV support. AArch64 folks can follow up with the USDOT support if desired.
#141970 has landed, and I've merged main into the dev branch. Still need to improve test coverage of the new node, but this should leave WIP status in a day or so.
✅ With the latest revision this PR passed the C/C++ code formatter.
I've landed tests in main, and merged them into the dev branch. I've gone through and done some cleanup/generalization based on the additional tests. This should be ready for regular review.
We have recently added the partial_reduce_smla and partial_reduce_umla nodes to represent Acc += ext(a) * ext(b), where the two extends must have the same source type and the same extend kind.
For riscv64 w/zvqdotq, we have the vqdot and vqdotu instructions, which correspond to the existing nodes, but we also have vqdotsu, which covers the case where the two extends are sign and zero extends, respectively (i.e. not the same kind of extend).
This patch adds a partial_reduce_sumla node, which sign-extends A and zero-extends B. The addition is somewhat mechanical.
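For illustration, building the new node is an ordinary three-operand getNode call. This helper is a hypothetical example rather than code from the patch, with A as the to-be-sign-extended input and B the to-be-zero-extended one.

```cpp
#include "llvm/CodeGen/ISDOpcodes.h"
#include "llvm/CodeGen/SelectionDAG.h"

using namespace llvm;

// PARTIAL_REDUCE_SUMLA(Acc, A, B) computes Acc += sext(A) * zext(B),
// partially reduced to Acc's type. A and B must share a vector type whose
// element count is a multiple of Acc's and whose elements are no wider
// than Acc's.
static SDValue buildPartialReduceSumla(SelectionDAG &DAG, const SDLoc &DL,
                                       SDValue Acc, SDValue A, SDValue B) {
  return DAG.getNode(ISD::PARTIAL_REDUCE_SUMLA, DL, Acc.getValueType(),
                     Acc, A, B);
}
```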