@@ -8824,54 +8824,68 @@ static SDValue LowerEXTRACT_VECTOR_ELT(SDValue Op, SelectionDAG &DAG,
8824
8824
8825
8825
/// Custom lowering for a CONCAT_VECTORS node that produces a vector of i1
/// (an MVE predicate).
///
/// The operands are merged two at a time until a single value is left. Each
/// merge promotes both predicate halves to integer vectors, copies their lanes
/// into one integer vector with twice as many lanes, and compares that vector
/// against zero to obtain the combined predicate.
///
/// \param Op  The CONCAT_VECTORS node; asserted to have i1 elements and a
///            power-of-two number of operands.
/// \param DAG The SelectionDAG in which the replacement nodes are created.
/// \param ST  The subtarget; asserted to have MVE integer operations.
/// \returns The value that replaces \p Op.
static SDValue LowerCONCAT_VECTORS_i1(SDValue Op, SelectionDAG &DAG,
                                      const ARMSubtarget *ST) {
  SDLoc dl(Op);
  assert(Op.getValueType().getScalarSizeInBits() == 1 &&
         "Unexpected custom CONCAT_VECTORS lowering");
  assert(isPowerOf2_32(Op.getNumOperands()) &&
         "Unexpected custom CONCAT_VECTORS lowering");
  assert(ST->hasMVEIntegerOps() &&
         "CONCAT_VECTORS lowering only supported for MVE");

  // Copy every lane of Src into Dst, writing to consecutive lanes of Dst
  // beginning at Pos. Pos is advanced past the last lane written so that a
  // second call continues where the first one stopped.
  auto AppendLanes = [&DAG, &dl](SDValue Src, SDValue Dst, unsigned &Pos) {
    const EVT DstVT = Dst.getValueType();
    const unsigned SrcLanes = Src.getValueType().getVectorNumElements();
    for (unsigned Lane = 0; Lane < SrcLanes; ++Lane, ++Pos) {
      SDValue Elt = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, MVT::i32, Src,
                                DAG.getIntPtrConstant(Lane, dl));
      Dst = DAG.getNode(ISD::INSERT_VECTOR_ELT, dl, DstVT, Dst, Elt,
                        DAG.getConstant(Pos, dl, MVT::i32));
    }
    return Dst;
  };

  // Join two predicate vectors of the same type into one predicate vector
  // with twice the number of lanes.
  auto ConcatPair = [&](SDValue Lo, SDValue Hi) {
    EVT HalfVT = Lo.getValueType();
    assert(HalfVT == Hi.getValueType() && "Operand types don't match!");
    EVT PredVT = HalfVT.getDoubleNumVectorElementsVT(*DAG.getContext());

    // Promote both halves to vectors of integers (v8i1 becomes v8i16, etc.).
    SDValue PromotedLo = PromoteMVEPredVector(dl, Lo, HalfVT, DAG);
    SDValue PromotedHi = PromoteMVEPredVector(dl, Hi, Hi.getValueType(), DAG);

    // The integer vector that stands in for the result uses the lane type of
    // the promoted *result* predicate. For example, a v4i1 half is promoted to
    // v4i32, while the v8i1 result is promoted to v8i16, so each i32 lane
    // taken from a half has to be narrowed to i16 as it is placed into the
    // result.
    MVT LaneTy =
        getVectorTyFromPredicateVector(PredVT).getScalarType().getSimpleVT();
    unsigned TotalLanes = 2 * HalfVT.getVectorNumElements();
    EVT WideVT = MVT::getVectorVT(LaneTy, TotalLanes);

    SDValue Wide = DAG.getNode(ISD::UNDEF, dl, WideVT);
    unsigned Pos = 0;
    Wide = AppendLanes(PromotedLo, Wide, Pos);
    Wide = AppendLanes(PromotedHi, Wide, Pos);

    // Comparing the integer vector with zero yields a real predicate again,
    // i.e. v4i1, v8i1 or v16i1.
    return DAG.getNode(ARMISD::VCMPZ, dl, PredVT, Wide,
                       DAG.getConstant(ARMCC::NE, dl, MVT::i32));
  };

  // Repeatedly merge neighbouring operands, storing each merged value in the
  // front half of the list, then drop the back half. One value remains at the
  // end.
  SmallVector<SDValue> Pending(Op->op_begin(), Op->op_end());
  while (Pending.size() > 1) {
    const unsigned Count = Pending.size();
    for (unsigned Src = 0; Src != Count; Src += 2) {
      SDValue Merged = ConcatPair(Pending[Src], Pending[Src + 1]);
      Pending[Src / 2] = Merged;
    }
    Pending.resize(Count / 2);
  }
  return Pending[0];
}
8876
8890
8877
8891
static SDValue LowerCONCAT_VECTORS(SDValue Op, SelectionDAG &DAG,
0 commit comments