Skip to content

Commit d6fe8d3

Browse files
committed
[DAG] Fold concat_vectors(concat_vectors(x,y),concat_vectors(a,b)) -> concat_vectors(x,y,a,b)
Follow-up to D107068, attempt to fold nested concat_vectors/undefs, as long as both the vector and inner subvector types are legal. This exposed the same issue in ARM's MVE LowerCONCAT_VECTORS_i1 (raised as PR51365) and AArch64's performConcatVectorsCombine which both assumed concat_vectors only took 2 subvector operands. Differential Revision: https://reviews.llvm.org/D107597
1 parent b4a1f44 commit d6fe8d3

11 files changed

+359
-420
lines changed

llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp

Lines changed: 46 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19865,6 +19865,44 @@ static SDValue combineConcatVectorOfScalars(SDNode *N, SelectionDAG &DAG) {
1986519865
return DAG.getBitcast(VT, DAG.getBuildVector(VecVT, DL, Ops));
1986619866
}
1986719867

19868+
// Attempt to merge nested concat_vectors/undefs.
19869+
// Fold concat_vectors(concat_vectors(x,y,z,w),u,u,concat_vectors(a,b,c,d))
19870+
// --> concat_vectors(x,y,z,w,u,u,u,u,u,u,u,u,a,b,c,d)
19871+
static SDValue combineConcatVectorOfConcatVectors(SDNode *N,
19872+
SelectionDAG &DAG) {
19873+
EVT VT = N->getValueType(0);
19874+
19875+
// Ensure we're concatenating UNDEF and CONCAT_VECTORS nodes of similar types.
19876+
EVT SubVT;
19877+
SDValue FirstConcat;
19878+
for (const SDValue &Op : N->ops()) {
19879+
if (Op.isUndef())
19880+
continue;
19881+
if (Op.getOpcode() != ISD::CONCAT_VECTORS)
19882+
return SDValue();
19883+
if (!FirstConcat) {
19884+
SubVT = Op.getOperand(0).getValueType();
19885+
if (!DAG.getTargetLoweringInfo().isTypeLegal(SubVT))
19886+
return SDValue();
19887+
FirstConcat = Op;
19888+
continue;
19889+
}
19890+
if (SubVT != Op.getOperand(0).getValueType())
19891+
return SDValue();
19892+
}
19893+
assert(FirstConcat && "Concat of all-undefs found");
19894+
19895+
SmallVector<SDValue> ConcatOps;
19896+
for (const SDValue &Op : N->ops()) {
19897+
if (Op.isUndef()) {
19898+
ConcatOps.append(FirstConcat->getNumOperands(), DAG.getUNDEF(SubVT));
19899+
continue;
19900+
}
19901+
ConcatOps.append(Op->op_begin(), Op->op_end());
19902+
}
19903+
return DAG.getNode(ISD::CONCAT_VECTORS, SDLoc(N), VT, ConcatOps);
19904+
}
19905+
1986819906
// Check to see if this is a CONCAT_VECTORS of a bunch of EXTRACT_SUBVECTOR
1986919907
// operations. If so, and if the EXTRACT_SUBVECTOR vector inputs come from at
1987019908
// most two distinct vectors the same size as the result, attempt to turn this
@@ -20124,13 +20162,19 @@ SDValue DAGCombiner::visitCONCAT_VECTORS(SDNode *N) {
2012420162
}
2012520163

2012620164
// Fold CONCAT_VECTORS of only bitcast scalars (or undef) to BUILD_VECTOR.
20165+
// FIXME: Add support for concat_vectors(bitcast(vec0),bitcast(vec1),...).
2012720166
if (SDValue V = combineConcatVectorOfScalars(N, DAG))
2012820167
return V;
2012920168

20130-
// Fold CONCAT_VECTORS of EXTRACT_SUBVECTOR (or undef) to VECTOR_SHUFFLE.
20131-
if (Level < AfterLegalizeVectorOps && TLI.isTypeLegal(VT))
20169+
if (Level < AfterLegalizeVectorOps && TLI.isTypeLegal(VT)) {
20170+
// Fold CONCAT_VECTORS of CONCAT_VECTORS (or undef) to one flat CONCAT_VECTORS.
20171+
if (SDValue V = combineConcatVectorOfConcatVectors(N, DAG))
20172+
return V;
20173+
20174+
// Fold CONCAT_VECTORS of EXTRACT_SUBVECTOR (or undef) to VECTOR_SHUFFLE.
2013220175
if (SDValue V = combineConcatVectorOfExtracts(N, DAG))
2013320176
return V;
20177+
}
2013420178

2013520179
if (SDValue V = combineConcatVectorOfCasts(N, DAG))
2013620180
return V;

llvm/lib/Target/AArch64/AArch64ISelLowering.cpp

Lines changed: 25 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10459,8 +10459,29 @@ SDValue AArch64TargetLowering::LowerCONCAT_VECTORS(SDValue Op,
1045910459
isTypeLegal(Op.getValueType()) &&
1046010460
"Expected legal scalable vector type!");
1046110461

10462-
if (isTypeLegal(Op.getOperand(0).getValueType()) && Op.getNumOperands() == 2)
10463-
return Op;
10462+
if (isTypeLegal(Op.getOperand(0).getValueType())) {
10463+
unsigned NumOperands = Op->getNumOperands();
10464+
assert(NumOperands > 1 && isPowerOf2_32(NumOperands) &&
10465+
"Unexpected number of operands in CONCAT_VECTORS");
10466+
10467+
if (Op.getNumOperands() == 2)
10468+
return Op;
10469+
10470+
// Concat each pair of subvectors and pack into the lower half of the array.
10471+
SmallVector<SDValue> ConcatOps(Op->op_begin(), Op->op_end());
10472+
while (ConcatOps.size() > 1) {
10473+
for (unsigned I = 0, E = ConcatOps.size(); I != E; I += 2) {
10474+
SDValue V1 = ConcatOps[I];
10475+
SDValue V2 = ConcatOps[I + 1];
10476+
EVT SubVT = V1.getValueType();
10477+
EVT PairVT = SubVT.getDoubleNumVectorElementsVT(*DAG.getContext());
10478+
ConcatOps[I / 2] =
10479+
DAG.getNode(ISD::CONCAT_VECTORS, SDLoc(Op), PairVT, V1, V2);
10480+
}
10481+
ConcatOps.resize(ConcatOps.size() / 2);
10482+
}
10483+
return ConcatOps[0];
10484+
}
1046410485

1046510486
return SDValue();
1046610487
}
@@ -13621,7 +13642,7 @@ static SDValue performConcatVectorsCombine(SDNode *N,
1362113642
// If we see a (concat_vectors (v1x64 A), (v1x64 A)) it's really a vector
1362213643
// splat. The indexed instructions are going to be expecting a DUPLANE64, so
1362313644
// canonicalise to that.
13624-
if (N0 == N1 && VT.getVectorNumElements() == 2) {
13645+
if (N->getNumOperands() == 2 && N0 == N1 && VT.getVectorNumElements() == 2) {
1362513646
assert(VT.getScalarSizeInBits() == 64);
1362613647
return DAG.getNode(AArch64ISD::DUPLANE64, dl, VT, WidenVector(N0, DAG),
1362713648
DAG.getConstant(0, dl, MVT::i64));
@@ -13636,7 +13657,7 @@ static SDValue performConcatVectorsCombine(SDNode *N,
1363613657
// becomes
1363713658
// (bitconvert (concat_vectors (v4i16 (bitconvert LHS)), RHS))
1363813659

13639-
if (N1Opc != ISD::BITCAST)
13660+
if (N->getNumOperands() != 2 || N1Opc != ISD::BITCAST)
1364013661
return SDValue();
1364113662
SDValue RHS = N1->getOperand(0);
1364213663
MVT RHSTy = RHS.getValueType().getSimpleVT();

llvm/lib/Target/ARM/ARMISelLowering.cpp

Lines changed: 55 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -8824,54 +8824,68 @@ static SDValue LowerEXTRACT_VECTOR_ELT(SDValue Op, SelectionDAG &DAG,
88248824

88258825
static SDValue LowerCONCAT_VECTORS_i1(SDValue Op, SelectionDAG &DAG,
88268826
const ARMSubtarget *ST) {
8827-
SDValue V1 = Op.getOperand(0);
8828-
SDValue V2 = Op.getOperand(1);
88298827
SDLoc dl(Op);
8830-
EVT VT = Op.getValueType();
8831-
EVT Op1VT = V1.getValueType();
8832-
EVT Op2VT = V2.getValueType();
8833-
unsigned NumElts = VT.getVectorNumElements();
8834-
8835-
assert(Op1VT == Op2VT && "Operand types don't match!");
8836-
assert(VT.getScalarSizeInBits() == 1 &&
8828+
assert(Op.getValueType().getScalarSizeInBits() == 1 &&
8829+
"Unexpected custom CONCAT_VECTORS lowering");
8830+
assert(isPowerOf2_32(Op.getNumOperands()) &&
88378831
"Unexpected custom CONCAT_VECTORS lowering");
88388832
assert(ST->hasMVEIntegerOps() &&
88398833
"CONCAT_VECTORS lowering only supported for MVE");
88408834

8841-
SDValue NewV1 = PromoteMVEPredVector(dl, V1, Op1VT, DAG);
8842-
SDValue NewV2 = PromoteMVEPredVector(dl, V2, Op2VT, DAG);
8843-
8844-
// We now have Op1 + Op2 promoted to vectors of integers, where v8i1 gets
8845-
// promoted to v8i16, etc.
8846-
8847-
MVT ElType = getVectorTyFromPredicateVector(VT).getScalarType().getSimpleVT();
8848-
8849-
// Extract the vector elements from Op1 and Op2 one by one and truncate them
8850-
// to be the right size for the destination. For example, if Op1 is v4i1 then
8851-
// the promoted vector is v4i32. The result of concatentation gives a v8i1,
8852-
// which when promoted is v8i16. That means each i32 element from Op1 needs
8853-
// truncating to i16 and inserting in the result.
8854-
EVT ConcatVT = MVT::getVectorVT(ElType, NumElts);
8855-
SDValue ConVec = DAG.getNode(ISD::UNDEF, dl, ConcatVT);
8856-
auto ExractInto = [&DAG, &dl](SDValue NewV, SDValue ConVec, unsigned &j) {
8857-
EVT NewVT = NewV.getValueType();
8858-
EVT ConcatVT = ConVec.getValueType();
8859-
for (unsigned i = 0, e = NewVT.getVectorNumElements(); i < e; i++, j++) {
8860-
SDValue Elt = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, MVT::i32, NewV,
8861-
DAG.getIntPtrConstant(i, dl));
8862-
ConVec = DAG.getNode(ISD::INSERT_VECTOR_ELT, dl, ConcatVT, ConVec, Elt,
8863-
DAG.getConstant(j, dl, MVT::i32));
8864-
}
8865-
return ConVec;
8835+
auto ConcatPair = [&](SDValue V1, SDValue V2) {
8836+
EVT Op1VT = V1.getValueType();
8837+
EVT Op2VT = V2.getValueType();
8838+
assert(Op1VT == Op2VT && "Operand types don't match!");
8839+
EVT VT = Op1VT.getDoubleNumVectorElementsVT(*DAG.getContext());
8840+
8841+
SDValue NewV1 = PromoteMVEPredVector(dl, V1, Op1VT, DAG);
8842+
SDValue NewV2 = PromoteMVEPredVector(dl, V2, Op2VT, DAG);
8843+
8844+
// We now have Op1 + Op2 promoted to vectors of integers, where v8i1 gets
8845+
// promoted to v8i16, etc.
8846+
MVT ElType =
8847+
getVectorTyFromPredicateVector(VT).getScalarType().getSimpleVT();
8848+
unsigned NumElts = 2 * Op1VT.getVectorNumElements();
8849+
8850+
// Extract the vector elements from Op1 and Op2 one by one and truncate them
8851+
// to be the right size for the destination. For example, if Op1 is v4i1
8852+
// then the promoted vector is v4i32. The result of concatenation gives a
8853+
// v8i1, which when promoted is v8i16. That means each i32 element from Op1
8854+
// needs truncating to i16 and inserting in the result.
8855+
EVT ConcatVT = MVT::getVectorVT(ElType, NumElts);
8856+
SDValue ConVec = DAG.getNode(ISD::UNDEF, dl, ConcatVT);
8857+
auto ExtractInto = [&DAG, &dl](SDValue NewV, SDValue ConVec, unsigned &j) {
8858+
EVT NewVT = NewV.getValueType();
8859+
EVT ConcatVT = ConVec.getValueType();
8860+
for (unsigned i = 0, e = NewVT.getVectorNumElements(); i < e; i++, j++) {
8861+
SDValue Elt = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, MVT::i32, NewV,
8862+
DAG.getIntPtrConstant(i, dl));
8863+
ConVec = DAG.getNode(ISD::INSERT_VECTOR_ELT, dl, ConcatVT, ConVec, Elt,
8864+
DAG.getConstant(j, dl, MVT::i32));
8865+
}
8866+
return ConVec;
8867+
};
8868+
unsigned j = 0;
8869+
ConVec = ExtractInto(NewV1, ConVec, j);
8870+
ConVec = ExtractInto(NewV2, ConVec, j);
8871+
8872+
// Now return the result of comparing the subvector with zero,
8873+
// which will generate a real predicate, i.e. v4i1, v8i1 or v16i1.
8874+
return DAG.getNode(ARMISD::VCMPZ, dl, VT, ConVec,
8875+
DAG.getConstant(ARMCC::NE, dl, MVT::i32));
88668876
};
8867-
unsigned j = 0;
8868-
ConVec = ExractInto(NewV1, ConVec, j);
8869-
ConVec = ExractInto(NewV2, ConVec, j);
88708877

8871-
// Now return the result of comparing the subvector with zero,
8872-
// which will generate a real predicate, i.e. v4i1, v8i1 or v16i1.
8873-
return DAG.getNode(ARMISD::VCMPZ, dl, VT, ConVec,
8874-
DAG.getConstant(ARMCC::NE, dl, MVT::i32));
8878+
// Concat each pair of subvectors and pack into the lower half of the array.
8879+
SmallVector<SDValue> ConcatOps(Op->op_begin(), Op->op_end());
8880+
while (ConcatOps.size() > 1) {
8881+
for (unsigned I = 0, E = ConcatOps.size(); I != E; I += 2) {
8882+
SDValue V1 = ConcatOps[I];
8883+
SDValue V2 = ConcatOps[I + 1];
8884+
ConcatOps[I / 2] = ConcatPair(V1, V2);
8885+
}
8886+
ConcatOps.resize(ConcatOps.size() / 2);
8887+
}
8888+
return ConcatOps[0];
88758889
}
88768890

88778891
static SDValue LowerCONCAT_VECTORS(SDValue Op, SelectionDAG &DAG,

0 commit comments

Comments
 (0)