Skip to content

Commit 1b02f59

Browse files
authored
[AMDGPU] Rework dot4 signedness checks (#68757)
Using the known/unknown value of the sign bit, reason about the signedness version of the dot4 instruction.
1 parent 83aa725 commit 1b02f59

File tree

3 files changed

+516
-99
lines changed

3 files changed

+516
-99
lines changed

llvm/include/llvm/CodeGen/ByteProvider.h

Lines changed: 0 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -32,11 +32,6 @@ template <typename ISelOp> class ByteProvider {
3232
ByteProvider(std::optional<ISelOp> Src, int64_t DestOffset, int64_t SrcOffset)
3333
: Src(Src), DestOffset(DestOffset), SrcOffset(SrcOffset) {}
3434

35-
ByteProvider(std::optional<ISelOp> Src, int64_t DestOffset, int64_t SrcOffset,
36-
std::optional<bool> IsSigned)
37-
: Src(Src), DestOffset(DestOffset), SrcOffset(SrcOffset),
38-
IsSigned(IsSigned) {}
39-
4035
// TODO -- use constraint in c++20
4136
// Does this type correspond with an operation in selection DAG
4237
template <typename T> class is_op {
@@ -66,9 +61,6 @@ template <typename ISelOp> class ByteProvider {
6661
// DestOffset
6762
int64_t SrcOffset = 0;
6863

69-
// Whether or not the path to this Src involved signed extensions
70-
std::optional<bool> IsSigned;
71-
7264
ByteProvider() = default;
7365

7466
static ByteProvider getSrc(std::optional<ISelOp> Val, int64_t ByteOffset,
@@ -78,14 +70,6 @@ template <typename ISelOp> class ByteProvider {
7870
return ByteProvider(Val, ByteOffset, VectorOffset);
7971
}
8072

81-
static ByteProvider getSrc(std::optional<ISelOp> Val, int64_t ByteOffset,
82-
int64_t VectorOffset,
83-
std::optional<bool> IsSigned) {
84-
static_assert(is_op<ISelOp>().value,
85-
"ByteProviders must contain an operation in selection DAG.");
86-
return ByteProvider(Val, ByteOffset, VectorOffset, IsSigned);
87-
}
88-
8973
static ByteProvider getConstantZero() {
9074
return ByteProvider<ISelOp>(std::nullopt, 0, 0);
9175
}

llvm/lib/Target/AMDGPU/SIISelLowering.cpp

Lines changed: 88 additions & 83 deletions
Original file line numberDiff line numberDiff line change
@@ -10940,7 +10940,6 @@ SDValue SITargetLowering::performAndCombine(SDNode *N,
1094010940
// performed.
1094110941
static const std::optional<ByteProvider<SDValue>>
1094210942
calculateSrcByte(const SDValue Op, uint64_t DestByte, uint64_t SrcIndex = 0,
10943-
std::optional<bool> IsSigned = std::nullopt,
1094410943
unsigned Depth = 0) {
1094510944
// We may need to recursively traverse a series of SRLs
1094610945
if (Depth >= 6)
@@ -10952,16 +10951,12 @@ calculateSrcByte(const SDValue Op, uint64_t DestByte, uint64_t SrcIndex = 0,
1095210951

1095310952
switch (Op->getOpcode()) {
1095410953
case ISD::TRUNCATE: {
10955-
return calculateSrcByte(Op->getOperand(0), DestByte, SrcIndex, IsSigned,
10956-
Depth + 1);
10954+
return calculateSrcByte(Op->getOperand(0), DestByte, SrcIndex, Depth + 1);
1095710955
}
1095810956

1095910957
case ISD::SIGN_EXTEND:
1096010958
case ISD::ZERO_EXTEND:
1096110959
case ISD::SIGN_EXTEND_INREG: {
10962-
IsSigned = IsSigned.value_or(false) ||
10963-
Op->getOpcode() == ISD::SIGN_EXTEND ||
10964-
Op->getOpcode() == ISD::SIGN_EXTEND_INREG;
1096510960
SDValue NarrowOp = Op->getOperand(0);
1096610961
auto NarrowVT = NarrowOp.getValueType();
1096710962
if (Op->getOpcode() == ISD::SIGN_EXTEND_INREG) {
@@ -10974,8 +10969,7 @@ calculateSrcByte(const SDValue Op, uint64_t DestByte, uint64_t SrcIndex = 0,
1097410969

1097510970
if (SrcIndex >= NarrowByteWidth)
1097610971
return std::nullopt;
10977-
return calculateSrcByte(Op->getOperand(0), DestByte, SrcIndex, IsSigned,
10978-
Depth + 1);
10972+
return calculateSrcByte(Op->getOperand(0), DestByte, SrcIndex, Depth + 1);
1097910973
}
1098010974

1098110975
case ISD::SRA:
@@ -10991,24 +10985,11 @@ calculateSrcByte(const SDValue Op, uint64_t DestByte, uint64_t SrcIndex = 0,
1099110985

1099210986
SrcIndex += BitShift / 8;
1099310987

10994-
return calculateSrcByte(Op->getOperand(0), DestByte, SrcIndex, IsSigned,
10995-
Depth + 1);
10988+
return calculateSrcByte(Op->getOperand(0), DestByte, SrcIndex, Depth + 1);
1099610989
}
1099710990

1099810991
default: {
10999-
if (isa<AtomicSDNode>(Op) || Op->isMemIntrinsic()) {
11000-
// If this causes us to throw away signedness info, then fail.
11001-
if (IsSigned)
11002-
return std::nullopt;
11003-
return ByteProvider<SDValue>::getSrc(Op, DestByte, SrcIndex);
11004-
}
11005-
11006-
if (auto L = dyn_cast<LoadSDNode>(Op))
11007-
if (L->getExtensionType() != ISD::NON_EXTLOAD)
11008-
IsSigned =
11009-
IsSigned.value_or(false) || L->getExtensionType() == ISD::SEXTLOAD;
11010-
11011-
return ByteProvider<SDValue>::getSrc(Op, DestByte, SrcIndex, IsSigned);
10992+
return ByteProvider<SDValue>::getSrc(Op, DestByte, SrcIndex);
1101210993
}
1101310994
}
1101410995
llvm_unreachable("fully handled switch");
@@ -11022,8 +11003,7 @@ calculateSrcByte(const SDValue Op, uint64_t DestByte, uint64_t SrcIndex = 0,
1102211003
// performed. \p StartingIndex is the originally requested byte of the Or
1102311004
static const std::optional<ByteProvider<SDValue>>
1102411005
calculateByteProvider(const SDValue &Op, unsigned Index, unsigned Depth,
11025-
unsigned StartingIndex = 0,
11026-
std::optional<bool> IsSigned = std::nullopt) {
11006+
unsigned StartingIndex = 0) {
1102711007
// Finding Src tree of RHS of or typically requires at least 1 additional
1102811008
// depth
1102911009
if (Depth > 6)
@@ -11038,11 +11018,11 @@ calculateByteProvider(const SDValue &Op, unsigned Index, unsigned Depth,
1103811018
switch (Op.getOpcode()) {
1103911019
case ISD::OR: {
1104011020
auto RHS = calculateByteProvider(Op.getOperand(1), Index, Depth + 1,
11041-
StartingIndex, IsSigned);
11021+
StartingIndex);
1104211022
if (!RHS)
1104311023
return std::nullopt;
1104411024
auto LHS = calculateByteProvider(Op.getOperand(0), Index, Depth + 1,
11045-
StartingIndex, IsSigned);
11025+
StartingIndex);
1104611026
if (!LHS)
1104711027
return std::nullopt;
1104811028
// A well formed Or will have two ByteProviders for each byte, one of which
@@ -11073,7 +11053,7 @@ calculateByteProvider(const SDValue &Op, unsigned Index, unsigned Depth,
1107311053
return ByteProvider<SDValue>::getConstantZero();
1107411054
}
1107511055

11076-
return calculateSrcByte(Op->getOperand(0), StartingIndex, Index, IsSigned);
11056+
return calculateSrcByte(Op->getOperand(0), StartingIndex, Index);
1107711057
}
1107811058

1107911059
case ISD::FSHR: {
@@ -11122,7 +11102,7 @@ calculateByteProvider(const SDValue &Op, unsigned Index, unsigned Depth,
1112211102
// the SRL is Index + ByteShift
1112311103
return BytesProvided - ByteShift > Index
1112411104
? calculateSrcByte(Op->getOperand(0), StartingIndex,
11125-
Index + ByteShift, IsSigned)
11105+
Index + ByteShift)
1112611106
: ByteProvider<SDValue>::getConstantZero();
1112711107
}
1112811108

@@ -11143,7 +11123,7 @@ calculateByteProvider(const SDValue &Op, unsigned Index, unsigned Depth,
1114311123
return Index < ByteShift
1114411124
? ByteProvider<SDValue>::getConstantZero()
1114511125
: calculateByteProvider(Op.getOperand(0), Index - ByteShift,
11146-
Depth + 1, StartingIndex, IsSigned);
11126+
Depth + 1, StartingIndex);
1114711127
}
1114811128
case ISD::ANY_EXTEND:
1114911129
case ISD::SIGN_EXTEND:
@@ -11163,48 +11143,35 @@ calculateByteProvider(const SDValue &Op, unsigned Index, unsigned Depth,
1116311143
return std::nullopt;
1116411144
uint64_t NarrowByteWidth = NarrowBitWidth / 8;
1116511145

11166-
IsSigned =
11167-
Op->getOpcode() != ISD::ANY_EXTEND
11168-
? std::optional<bool>(IsSigned.value_or(false) ||
11169-
Op->getOpcode() == ISD::SIGN_EXTEND ||
11170-
Op->getOpcode() == ISD::SIGN_EXTEND_INREG ||
11171-
Op->getOpcode() == ISD::AssertSext)
11172-
: IsSigned;
11173-
1117411146
if (Index >= NarrowByteWidth)
1117511147
return Op.getOpcode() == ISD::ZERO_EXTEND
1117611148
? std::optional<ByteProvider<SDValue>>(
1117711149
ByteProvider<SDValue>::getConstantZero())
1117811150
: std::nullopt;
11179-
return calculateByteProvider(NarrowOp, Index, Depth + 1, StartingIndex,
11180-
IsSigned);
11151+
return calculateByteProvider(NarrowOp, Index, Depth + 1, StartingIndex);
1118111152
}
1118211153

1118311154
case ISD::TRUNCATE: {
1118411155
uint64_t NarrowByteWidth = BitWidth / 8;
1118511156

1118611157
if (NarrowByteWidth >= Index) {
1118711158
return calculateByteProvider(Op.getOperand(0), Index, Depth + 1,
11188-
StartingIndex, IsSigned);
11159+
StartingIndex);
1118911160
}
1119011161

1119111162
return std::nullopt;
1119211163
}
1119311164

1119411165
case ISD::CopyFromReg: {
1119511166
if (BitWidth / 8 > Index)
11196-
return calculateSrcByte(Op, StartingIndex, Index, IsSigned);
11167+
return calculateSrcByte(Op, StartingIndex, Index);
1119711168

1119811169
return std::nullopt;
1119911170
}
1120011171

1120111172
case ISD::LOAD: {
1120211173
auto L = cast<LoadSDNode>(Op.getNode());
1120311174

11204-
// Only set IsSigned if the load is extended.
11205-
if (L->getExtensionType() != ISD::NON_EXTLOAD)
11206-
IsSigned =
11207-
IsSigned.value_or(false) || L->getExtensionType() == ISD::SEXTLOAD;
1120811175
unsigned NarrowBitWidth = L->getMemoryVT().getSizeInBits();
1120911176
if (NarrowBitWidth % 8 != 0)
1121011177
return std::nullopt;
@@ -11221,15 +11188,15 @@ calculateByteProvider(const SDValue &Op, unsigned Index, unsigned Depth,
1122111188
}
1122211189

1122311190
if (NarrowByteWidth > Index) {
11224-
return calculateSrcByte(Op, StartingIndex, Index, IsSigned);
11191+
return calculateSrcByte(Op, StartingIndex, Index);
1122511192
}
1122611193

1122711194
return std::nullopt;
1122811195
}
1122911196

1123011197
case ISD::BSWAP:
1123111198
return calculateByteProvider(Op->getOperand(0), BitWidth / 8 - Index - 1,
11232-
Depth + 1, StartingIndex, IsSigned);
11199+
Depth + 1, StartingIndex);
1123311200

1123411201
case ISD::EXTRACT_VECTOR_ELT: {
1123511202
auto IdxOp = dyn_cast<ConstantSDNode>(Op->getOperand(1));
@@ -11244,7 +11211,7 @@ calculateByteProvider(const SDValue &Op, unsigned Index, unsigned Depth,
1124411211
}
1124511212

1124611213
return calculateSrcByte(ScalarSize == 32 ? Op : Op.getOperand(0),
11247-
StartingIndex, Index, IsSigned);
11214+
StartingIndex, Index);
1124811215
}
1124911216

1125011217
case AMDGPUISD::PERM: {
@@ -11260,10 +11227,9 @@ calculateByteProvider(const SDValue &Op, unsigned Index, unsigned Depth,
1126011227
auto NextOp = Op.getOperand(IdxMask > 0x03 ? 0 : 1);
1126111228
auto NextIndex = IdxMask > 0x03 ? IdxMask % 4 : IdxMask;
1126211229

11263-
return IdxMask != 0x0c
11264-
? calculateSrcByte(NextOp, StartingIndex, NextIndex, IsSigned)
11265-
: ByteProvider<SDValue>(
11266-
ByteProvider<SDValue>::getConstantZero());
11230+
return IdxMask != 0x0c ? calculateSrcByte(NextOp, StartingIndex, NextIndex)
11231+
: ByteProvider<SDValue>(
11232+
ByteProvider<SDValue>::getConstantZero());
1126711233
}
1126811234

1126911235
default: {
@@ -13064,32 +13030,67 @@ static bool isMul(const SDValue Op) {
1306413030
Opcode == AMDGPUISD::MUL_I24);
1306513031
}
1306613032

13067-
static std::optional<bool> checkSignedness(const SDValue &N,
13068-
ByteProvider<SDValue> &Src0,
13069-
ByteProvider<SDValue> &Src1) {
13070-
auto MulOpcode = N.getOpcode();
13071-
std::optional<bool> IterIsSigned;
13072-
// Both sides of the tree must have the same signedness semantics.
13073-
if ((Src0.IsSigned != Src1.IsSigned) ||
13074-
(Src0.IsSigned.value_or(false) != Src1.IsSigned.value_or(false)))
13075-
return IterIsSigned;
13076-
// If we have a MUL_U24 op with signed semantics, then fail.
13077-
if (Src0.IsSigned.value_or(false) && MulOpcode == AMDGPUISD::MUL_U24)
13078-
return IterIsSigned;
13079-
// If we have a MUL_I24 op with unsigned semantics, then fail.
13080-
if (!Src0.IsSigned.value_or(true) && MulOpcode == AMDGPUISD::MUL_I24)
13081-
return IterIsSigned;
13082-
13083-
bool TopLevelSignedness =
13084-
MulOpcode == AMDGPUISD::MUL_I24 ||
13085-
(MulOpcode == ISD::MUL && N.getNode()->getFlags().hasNoSignedWrap() &&
13086-
!N.getNode()->getFlags().hasNoUnsignedWrap());
13087-
13088-
// In cases where we are accumulating into an i8 (for v_dot4), the
13089-
// ByteProvider will not have signedness info since the MSBs are dont-cares.
13090-
// In this case, we simply use the TopLevelSignedness of the instruction.
13091-
IterIsSigned = Src0.IsSigned.value_or(TopLevelSignedness);
13092-
return IterIsSigned;
13033+
static std::optional<bool>
13034+
checkDot4MulSignedness(const SDValue &N, ByteProvider<SDValue> &Src0,
13035+
ByteProvider<SDValue> &Src1, const SDValue &S0Op,
13036+
const SDValue &S1Op, const SelectionDAG &DAG) {
13037+
// If both ops are i8s (pre legalize-dag), then the signedness semantics
13038+
// of the dot4 is irrelevant.
13039+
if (S0Op.getValueSizeInBits() == 8 && S1Op.getValueSizeInBits() == 8)
13040+
return false;
13041+
13042+
auto Known0 = DAG.computeKnownBits(S0Op, 0);
13043+
bool S0IsUnsigned = Known0.countMinLeadingZeros() > 0;
13044+
bool S0IsSigned = Known0.countMinLeadingOnes() > 0;
13045+
auto Known1 = DAG.computeKnownBits(S1Op, 0);
13046+
bool S1IsUnsigned = Known1.countMinLeadingZeros() > 0;
13047+
bool S1IsSigned = Known1.countMinLeadingOnes() > 0;
13048+
13049+
assert(!(S0IsUnsigned && S0IsSigned));
13050+
assert(!(S1IsUnsigned && S1IsSigned));
13051+
13052+
// There are 9 possible permutations of
13053+
// {S0IsUnsigned, S0IsSigned, S1IsUnsigned, S1IsSigned}
13054+
13055+
// In two permutations, the sign bits are known to be the same for both Ops,
13056+
// so simply return Signed / Unsigned corresponding to the MSB
13057+
13058+
if ((S0IsUnsigned && S1IsUnsigned) || (S0IsSigned && S1IsSigned))
13059+
return S0IsSigned;
13060+
13061+
// In another two permutations, the sign bits are known to be opposite. In
13062+
// this case return std::nullopt to indicate a bad match.
13063+
13064+
if ((S0IsUnsigned && S1IsSigned) || (S0IsSigned && S1IsUnsigned))
13065+
return std::nullopt;
13066+
13067+
// In the remaining five permutations, we don't know the value of the sign
13068+
// bit for at least one Op. Since we have a valid ByteProvider, we know that
13069+
// the upper bits must be extension bits. Thus, the only ways for the sign
13070+
// bit to be unknown are if it was sign extended from an unknown value, or if it
13071+
// was any extended. In either case, it is correct to use the signed
13072+
// version of the signedness semantics of dot4
13073+
13074+
// In two of such permutations, we know the sign bit is set for
13075+
// one op, and the other is unknown. It is okay to use the signed version of
13076+
// dot4.
13077+
if ((S0IsSigned && !(S1IsSigned || S1IsUnsigned)) ||
13078+
((S1IsSigned && !(S0IsSigned || S0IsUnsigned))))
13079+
return true;
13080+
13081+
// In one such permutation, we don't know either of the sign bits. It is okay
13082+
// to use the signed version of dot4.
13083+
if ((!(S1IsSigned || S1IsUnsigned) && !(S0IsSigned || S0IsUnsigned)))
13084+
return true;
13085+
13086+
// In two of such permutations, we know the sign bit is unset for
13087+
// one op, and the other is unknown. Return std::nullopt to indicate a
13088+
// bad match.
13089+
if ((S0IsUnsigned && !(S1IsSigned || S1IsUnsigned)) ||
13090+
((S1IsUnsigned && !(S0IsSigned || S0IsUnsigned))))
13091+
return std::nullopt;
13092+
13093+
llvm_unreachable("Fully covered condition");
1309313094
}
1309413095

1309513096
SDValue SITargetLowering::performAddCombine(SDNode *N,
@@ -13132,8 +13133,10 @@ SDValue SITargetLowering::performAddCombine(SDNode *N,
1313213133
if (!Src1)
1313313134
break;
1313413135

13135-
auto IterIsSigned =
13136-
checkSignedness(TempNode->getOperand(MulIdx), *Src0, *Src1);
13136+
auto IterIsSigned = checkDot4MulSignedness(
13137+
TempNode->getOperand(MulIdx), *Src0, *Src1,
13138+
TempNode->getOperand(MulIdx)->getOperand(0),
13139+
TempNode->getOperand(MulIdx)->getOperand(1), DAG);
1313713140
if (!IterIsSigned)
1313813141
break;
1313913142
if (!IsSigned)
@@ -13154,8 +13157,10 @@ SDValue SITargetLowering::performAddCombine(SDNode *N,
1315413157
handleMulOperand(TempNode->getOperand(AddIdx)->getOperand(1));
1315513158
if (!Src1)
1315613159
break;
13157-
auto IterIsSigned =
13158-
checkSignedness(TempNode->getOperand(AddIdx), *Src0, *Src1);
13160+
auto IterIsSigned = checkDot4MulSignedness(
13161+
TempNode->getOperand(AddIdx), *Src0, *Src1,
13162+
TempNode->getOperand(AddIdx)->getOperand(0),
13163+
TempNode->getOperand(AddIdx)->getOperand(1), DAG);
1315913164
if (!IterIsSigned)
1316013165
break;
1316113166
assert(IsSigned);

0 commit comments

Comments
 (0)