@@ -10940,7 +10940,6 @@ SDValue SITargetLowering::performAndCombine(SDNode *N,
10940
10940
// performed.
10941
10941
static const std::optional<ByteProvider<SDValue>>
10942
10942
calculateSrcByte(const SDValue Op, uint64_t DestByte, uint64_t SrcIndex = 0,
10943
- std::optional<bool> IsSigned = std::nullopt,
10944
10943
unsigned Depth = 0) {
10945
10944
// We may need to recursively traverse a series of SRLs
10946
10945
if (Depth >= 6)
@@ -10952,16 +10951,12 @@ calculateSrcByte(const SDValue Op, uint64_t DestByte, uint64_t SrcIndex = 0,
10952
10951
10953
10952
switch (Op->getOpcode()) {
10954
10953
case ISD::TRUNCATE: {
10955
- return calculateSrcByte(Op->getOperand(0), DestByte, SrcIndex, IsSigned,
10956
- Depth + 1);
10954
+ return calculateSrcByte(Op->getOperand(0), DestByte, SrcIndex, Depth + 1);
10957
10955
}
10958
10956
10959
10957
case ISD::SIGN_EXTEND:
10960
10958
case ISD::ZERO_EXTEND:
10961
10959
case ISD::SIGN_EXTEND_INREG: {
10962
- IsSigned = IsSigned.value_or(false) ||
10963
- Op->getOpcode() == ISD::SIGN_EXTEND ||
10964
- Op->getOpcode() == ISD::SIGN_EXTEND_INREG;
10965
10960
SDValue NarrowOp = Op->getOperand(0);
10966
10961
auto NarrowVT = NarrowOp.getValueType();
10967
10962
if (Op->getOpcode() == ISD::SIGN_EXTEND_INREG) {
@@ -10974,8 +10969,7 @@ calculateSrcByte(const SDValue Op, uint64_t DestByte, uint64_t SrcIndex = 0,
10974
10969
10975
10970
if (SrcIndex >= NarrowByteWidth)
10976
10971
return std::nullopt;
10977
- return calculateSrcByte(Op->getOperand(0), DestByte, SrcIndex, IsSigned,
10978
- Depth + 1);
10972
+ return calculateSrcByte(Op->getOperand(0), DestByte, SrcIndex, Depth + 1);
10979
10973
}
10980
10974
10981
10975
case ISD::SRA:
@@ -10991,24 +10985,11 @@ calculateSrcByte(const SDValue Op, uint64_t DestByte, uint64_t SrcIndex = 0,
10991
10985
10992
10986
SrcIndex += BitShift / 8;
10993
10987
10994
- return calculateSrcByte(Op->getOperand(0), DestByte, SrcIndex, IsSigned,
10995
- Depth + 1);
10988
+ return calculateSrcByte(Op->getOperand(0), DestByte, SrcIndex, Depth + 1);
10996
10989
}
10997
10990
10998
10991
default: {
10999
- if (isa<AtomicSDNode>(Op) || Op->isMemIntrinsic()) {
11000
- // If this causes us to throw away signedness info, then fail.
11001
- if (IsSigned)
11002
- return std::nullopt;
11003
- return ByteProvider<SDValue>::getSrc(Op, DestByte, SrcIndex);
11004
- }
11005
-
11006
- if (auto L = dyn_cast<LoadSDNode>(Op))
11007
- if (L->getExtensionType() != ISD::NON_EXTLOAD)
11008
- IsSigned =
11009
- IsSigned.value_or(false) || L->getExtensionType() == ISD::SEXTLOAD;
11010
-
11011
- return ByteProvider<SDValue>::getSrc(Op, DestByte, SrcIndex, IsSigned);
10992
+ return ByteProvider<SDValue>::getSrc(Op, DestByte, SrcIndex);
11012
10993
}
11013
10994
}
11014
10995
llvm_unreachable("fully handled switch");
@@ -11022,8 +11003,7 @@ calculateSrcByte(const SDValue Op, uint64_t DestByte, uint64_t SrcIndex = 0,
11022
11003
// performed. \p StartingIndex is the originally requested byte of the Or
11023
11004
static const std::optional<ByteProvider<SDValue>>
11024
11005
calculateByteProvider(const SDValue &Op, unsigned Index, unsigned Depth,
11025
- unsigned StartingIndex = 0,
11026
- std::optional<bool> IsSigned = std::nullopt) {
11006
+ unsigned StartingIndex = 0) {
11027
11007
// Finding Src tree of RHS of or typically requires at least 1 additional
11028
11008
// depth
11029
11009
if (Depth > 6)
@@ -11038,11 +11018,11 @@ calculateByteProvider(const SDValue &Op, unsigned Index, unsigned Depth,
11038
11018
switch (Op.getOpcode()) {
11039
11019
case ISD::OR: {
11040
11020
auto RHS = calculateByteProvider(Op.getOperand(1), Index, Depth + 1,
11041
- StartingIndex, IsSigned );
11021
+ StartingIndex);
11042
11022
if (!RHS)
11043
11023
return std::nullopt;
11044
11024
auto LHS = calculateByteProvider(Op.getOperand(0), Index, Depth + 1,
11045
- StartingIndex, IsSigned );
11025
+ StartingIndex);
11046
11026
if (!LHS)
11047
11027
return std::nullopt;
11048
11028
// A well formed Or will have two ByteProviders for each byte, one of which
@@ -11073,7 +11053,7 @@ calculateByteProvider(const SDValue &Op, unsigned Index, unsigned Depth,
11073
11053
return ByteProvider<SDValue>::getConstantZero();
11074
11054
}
11075
11055
11076
- return calculateSrcByte(Op->getOperand(0), StartingIndex, Index, IsSigned );
11056
+ return calculateSrcByte(Op->getOperand(0), StartingIndex, Index);
11077
11057
}
11078
11058
11079
11059
case ISD::FSHR: {
@@ -11122,7 +11102,7 @@ calculateByteProvider(const SDValue &Op, unsigned Index, unsigned Depth,
11122
11102
// the SRL is Index + ByteShift
11123
11103
return BytesProvided - ByteShift > Index
11124
11104
? calculateSrcByte(Op->getOperand(0), StartingIndex,
11125
- Index + ByteShift, IsSigned )
11105
+ Index + ByteShift)
11126
11106
: ByteProvider<SDValue>::getConstantZero();
11127
11107
}
11128
11108
@@ -11143,7 +11123,7 @@ calculateByteProvider(const SDValue &Op, unsigned Index, unsigned Depth,
11143
11123
return Index < ByteShift
11144
11124
? ByteProvider<SDValue>::getConstantZero()
11145
11125
: calculateByteProvider(Op.getOperand(0), Index - ByteShift,
11146
- Depth + 1, StartingIndex, IsSigned );
11126
+ Depth + 1, StartingIndex);
11147
11127
}
11148
11128
case ISD::ANY_EXTEND:
11149
11129
case ISD::SIGN_EXTEND:
@@ -11163,48 +11143,35 @@ calculateByteProvider(const SDValue &Op, unsigned Index, unsigned Depth,
11163
11143
return std::nullopt;
11164
11144
uint64_t NarrowByteWidth = NarrowBitWidth / 8;
11165
11145
11166
- IsSigned =
11167
- Op->getOpcode() != ISD::ANY_EXTEND
11168
- ? std::optional<bool>(IsSigned.value_or(false) ||
11169
- Op->getOpcode() == ISD::SIGN_EXTEND ||
11170
- Op->getOpcode() == ISD::SIGN_EXTEND_INREG ||
11171
- Op->getOpcode() == ISD::AssertSext)
11172
- : IsSigned;
11173
-
11174
11146
if (Index >= NarrowByteWidth)
11175
11147
return Op.getOpcode() == ISD::ZERO_EXTEND
11176
11148
? std::optional<ByteProvider<SDValue>>(
11177
11149
ByteProvider<SDValue>::getConstantZero())
11178
11150
: std::nullopt;
11179
- return calculateByteProvider(NarrowOp, Index, Depth + 1, StartingIndex,
11180
- IsSigned);
11151
+ return calculateByteProvider(NarrowOp, Index, Depth + 1, StartingIndex);
11181
11152
}
11182
11153
11183
11154
case ISD::TRUNCATE: {
11184
11155
uint64_t NarrowByteWidth = BitWidth / 8;
11185
11156
11186
11157
if (NarrowByteWidth >= Index) {
11187
11158
return calculateByteProvider(Op.getOperand(0), Index, Depth + 1,
11188
- StartingIndex, IsSigned );
11159
+ StartingIndex);
11189
11160
}
11190
11161
11191
11162
return std::nullopt;
11192
11163
}
11193
11164
11194
11165
case ISD::CopyFromReg: {
11195
11166
if (BitWidth / 8 > Index)
11196
- return calculateSrcByte(Op, StartingIndex, Index, IsSigned );
11167
+ return calculateSrcByte(Op, StartingIndex, Index);
11197
11168
11198
11169
return std::nullopt;
11199
11170
}
11200
11171
11201
11172
case ISD::LOAD: {
11202
11173
auto L = cast<LoadSDNode>(Op.getNode());
11203
11174
11204
- // Only set IsSigned if the load is extended.
11205
- if (L->getExtensionType() != ISD::NON_EXTLOAD)
11206
- IsSigned =
11207
- IsSigned.value_or(false) || L->getExtensionType() == ISD::SEXTLOAD;
11208
11175
unsigned NarrowBitWidth = L->getMemoryVT().getSizeInBits();
11209
11176
if (NarrowBitWidth % 8 != 0)
11210
11177
return std::nullopt;
@@ -11221,15 +11188,15 @@ calculateByteProvider(const SDValue &Op, unsigned Index, unsigned Depth,
11221
11188
}
11222
11189
11223
11190
if (NarrowByteWidth > Index) {
11224
- return calculateSrcByte(Op, StartingIndex, Index, IsSigned );
11191
+ return calculateSrcByte(Op, StartingIndex, Index);
11225
11192
}
11226
11193
11227
11194
return std::nullopt;
11228
11195
}
11229
11196
11230
11197
case ISD::BSWAP:
11231
11198
return calculateByteProvider(Op->getOperand(0), BitWidth / 8 - Index - 1,
11232
- Depth + 1, StartingIndex, IsSigned );
11199
+ Depth + 1, StartingIndex);
11233
11200
11234
11201
case ISD::EXTRACT_VECTOR_ELT: {
11235
11202
auto IdxOp = dyn_cast<ConstantSDNode>(Op->getOperand(1));
@@ -11244,7 +11211,7 @@ calculateByteProvider(const SDValue &Op, unsigned Index, unsigned Depth,
11244
11211
}
11245
11212
11246
11213
return calculateSrcByte(ScalarSize == 32 ? Op : Op.getOperand(0),
11247
- StartingIndex, Index, IsSigned );
11214
+ StartingIndex, Index);
11248
11215
}
11249
11216
11250
11217
case AMDGPUISD::PERM: {
@@ -11260,10 +11227,9 @@ calculateByteProvider(const SDValue &Op, unsigned Index, unsigned Depth,
11260
11227
auto NextOp = Op.getOperand(IdxMask > 0x03 ? 0 : 1);
11261
11228
auto NextIndex = IdxMask > 0x03 ? IdxMask % 4 : IdxMask;
11262
11229
11263
- return IdxMask != 0x0c
11264
- ? calculateSrcByte(NextOp, StartingIndex, NextIndex, IsSigned)
11265
- : ByteProvider<SDValue>(
11266
- ByteProvider<SDValue>::getConstantZero());
11230
+ return IdxMask != 0x0c ? calculateSrcByte(NextOp, StartingIndex, NextIndex)
11231
+ : ByteProvider<SDValue>(
11232
+ ByteProvider<SDValue>::getConstantZero());
11267
11233
}
11268
11234
11269
11235
default: {
@@ -13064,32 +13030,67 @@ static bool isMul(const SDValue Op) {
13064
13030
Opcode == AMDGPUISD::MUL_I24);
13065
13031
}
13066
13032
13067
- static std::optional<bool> checkSignedness(const SDValue &N,
13068
- ByteProvider<SDValue> &Src0,
13069
- ByteProvider<SDValue> &Src1) {
13070
- auto MulOpcode = N.getOpcode();
13071
- std::optional<bool> IterIsSigned;
13072
- // Both sides of the tree must have the same signedness semantics.
13073
- if ((Src0.IsSigned != Src1.IsSigned) ||
13074
- (Src0.IsSigned.value_or(false) != Src1.IsSigned.value_or(false)))
13075
- return IterIsSigned;
13076
- // If we have a MUL_U24 op with signed semantics, then fail.
13077
- if (Src0.IsSigned.value_or(false) && MulOpcode == AMDGPUISD::MUL_U24)
13078
- return IterIsSigned;
13079
- // If we have a MUL_I24 op with unsigned semantics, then fail.
13080
- if (!Src0.IsSigned.value_or(true) && MulOpcode == AMDGPUISD::MUL_I24)
13081
- return IterIsSigned;
13082
-
13083
- bool TopLevelSignedness =
13084
- MulOpcode == AMDGPUISD::MUL_I24 ||
13085
- (MulOpcode == ISD::MUL && N.getNode()->getFlags().hasNoSignedWrap() &&
13086
- !N.getNode()->getFlags().hasNoUnsignedWrap());
13087
-
13088
- // In cases where we are accumulating into an i8 (for v_dot4), the
13089
- // ByteProvider will not have signedness info since the MSBs are dont-cares.
13090
- // In this case, we simply use the TopLevelSignedness of the instruction.
13091
- IterIsSigned = Src0.IsSigned.value_or(TopLevelSignedness);
13092
- return IterIsSigned;
13033
+ static std::optional<bool>
13034
+ checkDot4MulSignedness(const SDValue &N, ByteProvider<SDValue> &Src0,
13035
+ ByteProvider<SDValue> &Src1, const SDValue &S0Op,
13036
+ const SDValue &S1Op, const SelectionDAG &DAG) {
13037
+ // If both ops are i8s (pre legalize-dag), then the signedness semantics
13038
+ // of the dot4 is irrelevant.
13039
+ if (S0Op.getValueSizeInBits() == 8 && S1Op.getValueSizeInBits() == 8)
13040
+ return false;
13041
+
13042
+ auto Known0 = DAG.computeKnownBits(S0Op, 0);
13043
+ bool S0IsUnsigned = Known0.countMinLeadingZeros() > 0;
13044
+ bool S0IsSigned = Known0.countMinLeadingOnes() > 0;
13045
+ auto Known1 = DAG.computeKnownBits(S1Op, 0);
13046
+ bool S1IsUnsigned = Known1.countMinLeadingZeros() > 0;
13047
+ bool S1IsSigned = Known1.countMinLeadingOnes() > 0;
13048
+
13049
+ assert(!(S0IsUnsigned && S0IsSigned));
13050
+ assert(!(S1IsUnsigned && S1IsSigned));
13051
+
13052
+ // There are 9 possible permutations of
13053
+ // {S0IsUnsigned, S0IsSigned, S1IsUnsigned, S1IsSigned}
13054
+
13055
+ // In two permutations, the sign bits are known to be the same for both Ops,
13056
+ // so simply return Signed / Unsigned corresponding to the MSB
13057
+
13058
+ if ((S0IsUnsigned && S1IsUnsigned) || (S0IsSigned && S1IsSigned))
13059
+ return S0IsSigned;
13060
+
13061
+ // In another two permutations, the sign bits are known to be opposite. In
13062
+ // this case return std::nullopt to indicate a bad match.
13063
+
13064
+ if ((S0IsUnsigned && S1IsSigned) || (S0IsSigned && S1IsUnsigned))
13065
+ return std::nullopt;
13066
+
13067
+ // In the remaining five permutations, we don't know the value of the sign
13068
+ // bit for at least one Op. Since we have a valid ByteProvider, we know that
13069
+ // the upper bits must be extension bits. Thus, the only way for the sign
13070
+ // bit to be unknown is if it was sign extended from unknown value, or if it
13071
+ // was any extended. In either case, it is correct to use the signed
13072
+ // version of the signedness semantics of dot4
13073
+
13074
+ // In two of such permutations, we know the sign bit is set for
13075
+ // one op, and the other is unknown. It is okay to use the signed version of
13076
+ // dot4.
13077
+ if ((S0IsSigned && !(S1IsSigned || S1IsUnsigned)) ||
13078
+ ((S1IsSigned && !(S0IsSigned || S0IsUnsigned))))
13079
+ return true;
13080
+
13081
+ // In one such permutation, we don't know either of the sign bits. It is okay
13082
+ // to use the signed version of dot4.
13083
+ if ((!(S1IsSigned || S1IsUnsigned) && !(S0IsSigned || S0IsUnsigned)))
13084
+ return true;
13085
+
13086
+ // In two of such permutations, we know the sign bit is unset for
13087
+ // one op, and the other is unknown. Return std::nullopt to indicate a
13088
+ // bad match.
13089
+ if ((S0IsUnsigned && !(S1IsSigned || S1IsUnsigned)) ||
13090
+ ((S1IsUnsigned && !(S0IsSigned || S0IsUnsigned))))
13091
+ return std::nullopt;
13092
+
13093
+ llvm_unreachable("Fully covered condition");
13093
13094
}
13094
13095
13095
13096
SDValue SITargetLowering::performAddCombine(SDNode *N,
@@ -13132,8 +13133,10 @@ SDValue SITargetLowering::performAddCombine(SDNode *N,
13132
13133
if (!Src1)
13133
13134
break;
13134
13135
13135
- auto IterIsSigned =
13136
- checkSignedness(TempNode->getOperand(MulIdx), *Src0, *Src1);
13136
+ auto IterIsSigned = checkDot4MulSignedness(
13137
+ TempNode->getOperand(MulIdx), *Src0, *Src1,
13138
+ TempNode->getOperand(MulIdx)->getOperand(0),
13139
+ TempNode->getOperand(MulIdx)->getOperand(1), DAG);
13137
13140
if (!IterIsSigned)
13138
13141
break;
13139
13142
if (!IsSigned)
@@ -13154,8 +13157,10 @@ SDValue SITargetLowering::performAddCombine(SDNode *N,
13154
13157
handleMulOperand(TempNode->getOperand(AddIdx)->getOperand(1));
13155
13158
if (!Src1)
13156
13159
break;
13157
- auto IterIsSigned =
13158
- checkSignedness(TempNode->getOperand(AddIdx), *Src0, *Src1);
13160
+ auto IterIsSigned = checkDot4MulSignedness(
13161
+ TempNode->getOperand(AddIdx), *Src0, *Src1,
13162
+ TempNode->getOperand(AddIdx)->getOperand(0),
13163
+ TempNode->getOperand(AddIdx)->getOperand(1), DAG);
13159
13164
if (!IterIsSigned)
13160
13165
break;
13161
13166
assert(IsSigned);
0 commit comments