@@ -9793,6 +9793,93 @@ static SDValue LowerCONCAT_VECTORS(SDValue Op,
9793
9793
// patterns.
9794
9794
//===----------------------------------------------------------------------===//
9795
9795
9796
+ /// Checks whether the vector elements referenced by two shuffle masks are
9797
+ /// equivalent.
9798
+ static bool IsElementEquivalent(int MaskSize, SDValue Op, SDValue ExpectedOp,
9799
+ int Idx, int ExpectedIdx) {
9800
+ assert(0 <= Idx && Idx < MaskSize && 0 <= ExpectedIdx &&
9801
+ ExpectedIdx < MaskSize && "Out of range element index");
9802
+ if (!Op || !ExpectedOp || Op.getOpcode() != ExpectedOp.getOpcode())
9803
+ return false;
9804
+
9805
+ EVT VT = Op.getValueType();
9806
+ switch (Op.getOpcode()) {
9807
+ case ISD::BUILD_VECTOR:
9808
+ // If the values are build vectors, we can look through them to find
9809
+ // equivalent inputs that make the shuffles equivalent.
9810
+ // TODO: Handle MaskSize != Op.getNumOperands()?
9811
+ if (MaskSize == (int)Op.getNumOperands() &&
9812
+ MaskSize == (int)ExpectedOp.getNumOperands())
9813
+ return Op.getOperand(Idx) == ExpectedOp.getOperand(ExpectedIdx);
9814
+ break;
9815
+ case ISD::BITCAST: {
9816
+ SDValue Src = peekThroughBitcasts(Op);
9817
+ EVT SrcVT = Src.getValueType();
9818
+ if (Op == ExpectedOp && SrcVT.isVector() &&
9819
+ (int)VT.getVectorNumElements() == MaskSize) {
9820
+ if ((SrcVT.getScalarSizeInBits() % VT.getScalarSizeInBits()) == 0) {
9821
+ unsigned Scale = SrcVT.getScalarSizeInBits() / VT.getScalarSizeInBits();
9822
+ return (Idx % Scale) == (ExpectedIdx % Scale) &&
9823
+ IsElementEquivalent(SrcVT.getVectorNumElements(), Src, Src,
9824
+ Idx / Scale, ExpectedIdx / Scale);
9825
+ }
9826
+ if ((VT.getScalarSizeInBits() % SrcVT.getScalarSizeInBits()) == 0) {
9827
+ unsigned Scale = VT.getScalarSizeInBits() / SrcVT.getScalarSizeInBits();
9828
+ for (unsigned I = 0; I != Scale; ++I)
9829
+ if (!IsElementEquivalent(SrcVT.getVectorNumElements(), Src, Src,
9830
+ (Idx * Scale) + I,
9831
+ (ExpectedIdx * Scale) + I))
9832
+ return false;
9833
+ return true;
9834
+ }
9835
+ }
9836
+ break;
9837
+ }
9838
+ case ISD::VECTOR_SHUFFLE: {
9839
+ auto *SVN = cast<ShuffleVectorSDNode>(Op);
9840
+ return Op == ExpectedOp && (int)VT.getVectorNumElements() == MaskSize &&
9841
+ SVN->getMaskElt(Idx) == SVN->getMaskElt(ExpectedIdx);
9842
+ }
9843
+ case X86ISD::VBROADCAST:
9844
+ case X86ISD::VBROADCAST_LOAD:
9845
+ // TODO: Handle MaskSize != VT.getVectorNumElements()?
9846
+ return (Op == ExpectedOp && (int)VT.getVectorNumElements() == MaskSize);
9847
+ case X86ISD::SUBV_BROADCAST_LOAD:
9848
+ // TODO: Handle MaskSize != VT.getVectorNumElements()?
9849
+ if (Op == ExpectedOp && (int)VT.getVectorNumElements() == MaskSize) {
9850
+ auto *MemOp = cast<MemSDNode>(Op);
9851
+ unsigned NumMemElts = MemOp->getMemoryVT().getVectorNumElements();
9852
+ return (Idx % NumMemElts) == (ExpectedIdx % NumMemElts);
9853
+ }
9854
+ break;
9855
+ case X86ISD::HADD:
9856
+ case X86ISD::HSUB:
9857
+ case X86ISD::FHADD:
9858
+ case X86ISD::FHSUB:
9859
+ case X86ISD::PACKSS:
9860
+ case X86ISD::PACKUS:
9861
+ // HOP(X,X) can refer to the elt from the lower/upper half of a lane.
9862
+ // TODO: Handle MaskSize != NumElts?
9863
+ // TODO: Handle HOP(X,Y) vs HOP(Y,X) equivalence cases.
9864
+ if (Op == ExpectedOp && Op.getOperand(0) == Op.getOperand(1)) {
9865
+ int NumElts = VT.getVectorNumElements();
9866
+ if (MaskSize == NumElts) {
9867
+ int NumLanes = VT.getSizeInBits() / 128;
9868
+ int NumEltsPerLane = NumElts / NumLanes;
9869
+ int NumHalfEltsPerLane = NumEltsPerLane / 2;
9870
+ bool SameLane =
9871
+ (Idx / NumEltsPerLane) == (ExpectedIdx / NumEltsPerLane);
9872
+ bool SameElt =
9873
+ (Idx % NumHalfEltsPerLane) == (ExpectedIdx % NumHalfEltsPerLane);
9874
+ return SameLane && SameElt;
9875
+ }
9876
+ }
9877
+ break;
9878
+ }
9879
+
9880
+ return false;
9881
+ }
9882
+
9796
9883
/// Tiny helper function to identify a no-op mask.
9797
9884
///
9798
9885
/// This is a somewhat boring predicate function. It checks whether the mask
@@ -9968,93 +10055,6 @@ static bool isRepeatedTargetShuffleMask(unsigned LaneSizeInBits, MVT VT,
9968
10055
Mask, RepeatedMask);
9969
10056
}
9970
10057
9971
- /// Checks whether the vector elements referenced by two shuffle masks are
9972
- /// equivalent.
9973
- static bool IsElementEquivalent(int MaskSize, SDValue Op, SDValue ExpectedOp,
9974
- int Idx, int ExpectedIdx) {
9975
- assert(0 <= Idx && Idx < MaskSize && 0 <= ExpectedIdx &&
9976
- ExpectedIdx < MaskSize && "Out of range element index");
9977
- if (!Op || !ExpectedOp || Op.getOpcode() != ExpectedOp.getOpcode())
9978
- return false;
9979
-
9980
- EVT VT = Op.getValueType();
9981
- switch (Op.getOpcode()) {
9982
- case ISD::BUILD_VECTOR:
9983
- // If the values are build vectors, we can look through them to find
9984
- // equivalent inputs that make the shuffles equivalent.
9985
- // TODO: Handle MaskSize != Op.getNumOperands()?
9986
- if (MaskSize == (int)Op.getNumOperands() &&
9987
- MaskSize == (int)ExpectedOp.getNumOperands())
9988
- return Op.getOperand(Idx) == ExpectedOp.getOperand(ExpectedIdx);
9989
- break;
9990
- case ISD::BITCAST: {
9991
- SDValue Src = peekThroughBitcasts(Op);
9992
- EVT SrcVT = Src.getValueType();
9993
- if (Op == ExpectedOp && SrcVT.isVector() &&
9994
- (int)VT.getVectorNumElements() == MaskSize) {
9995
- if ((SrcVT.getScalarSizeInBits() % VT.getScalarSizeInBits()) == 0) {
9996
- unsigned Scale = SrcVT.getScalarSizeInBits() / VT.getScalarSizeInBits();
9997
- return (Idx % Scale) == (ExpectedIdx % Scale) &&
9998
- IsElementEquivalent(SrcVT.getVectorNumElements(), Src, Src,
9999
- Idx / Scale, ExpectedIdx / Scale);
10000
- }
10001
- if ((VT.getScalarSizeInBits() % SrcVT.getScalarSizeInBits()) == 0) {
10002
- unsigned Scale = VT.getScalarSizeInBits() / SrcVT.getScalarSizeInBits();
10003
- for (unsigned I = 0; I != Scale; ++I)
10004
- if (!IsElementEquivalent(SrcVT.getVectorNumElements(), Src, Src,
10005
- (Idx * Scale) + I,
10006
- (ExpectedIdx * Scale) + I))
10007
- return false;
10008
- return true;
10009
- }
10010
- }
10011
- break;
10012
- }
10013
- case ISD::VECTOR_SHUFFLE: {
10014
- auto *SVN = cast<ShuffleVectorSDNode>(Op);
10015
- return Op == ExpectedOp && (int)VT.getVectorNumElements() == MaskSize &&
10016
- SVN->getMaskElt(Idx) == SVN->getMaskElt(ExpectedIdx);
10017
- }
10018
- case X86ISD::VBROADCAST:
10019
- case X86ISD::VBROADCAST_LOAD:
10020
- // TODO: Handle MaskSize != VT.getVectorNumElements()?
10021
- return (Op == ExpectedOp && (int)VT.getVectorNumElements() == MaskSize);
10022
- case X86ISD::SUBV_BROADCAST_LOAD:
10023
- // TODO: Handle MaskSize != VT.getVectorNumElements()?
10024
- if (Op == ExpectedOp && (int)VT.getVectorNumElements() == MaskSize) {
10025
- auto *MemOp = cast<MemSDNode>(Op);
10026
- unsigned NumMemElts = MemOp->getMemoryVT().getVectorNumElements();
10027
- return (Idx % NumMemElts) == (ExpectedIdx % NumMemElts);
10028
- }
10029
- break;
10030
- case X86ISD::HADD:
10031
- case X86ISD::HSUB:
10032
- case X86ISD::FHADD:
10033
- case X86ISD::FHSUB:
10034
- case X86ISD::PACKSS:
10035
- case X86ISD::PACKUS:
10036
- // HOP(X,X) can refer to the elt from the lower/upper half of a lane.
10037
- // TODO: Handle MaskSize != NumElts?
10038
- // TODO: Handle HOP(X,Y) vs HOP(Y,X) equivalence cases.
10039
- if (Op == ExpectedOp && Op.getOperand(0) == Op.getOperand(1)) {
10040
- int NumElts = VT.getVectorNumElements();
10041
- if (MaskSize == NumElts) {
10042
- int NumLanes = VT.getSizeInBits() / 128;
10043
- int NumEltsPerLane = NumElts / NumLanes;
10044
- int NumHalfEltsPerLane = NumEltsPerLane / 2;
10045
- bool SameLane =
10046
- (Idx / NumEltsPerLane) == (ExpectedIdx / NumEltsPerLane);
10047
- bool SameElt =
10048
- (Idx % NumHalfEltsPerLane) == (ExpectedIdx % NumHalfEltsPerLane);
10049
- return SameLane && SameElt;
10050
- }
10051
- }
10052
- break;
10053
- }
10054
-
10055
- return false;
10056
- }
10057
-
10058
10058
/// Checks whether a shuffle mask is equivalent to an explicit list of
10059
10059
/// arguments.
10060
10060
///
0 commit comments