From 28da0cf548ef21058f12838ec0429ecb377ea5ec Mon Sep 17 00:00:00 2001 From: Ramkumar Ramachandra Date: Fri, 11 Oct 2024 11:38:06 +0100 Subject: [PATCH 1/5] IR: introduce CmpInst::is{Eq,Ne}Equivalence Steal impliesEquivalanceIf{True,False} (sic) from GVN, and extend it for floating-point constant vectors. Since InstCombine also performs GVN-like replacements, introduce CmpInst::is{Eq,Ne}Equivalence, and remove the corresponding code in GVN, with the intent of using it in more places. --- llvm/include/llvm/IR/InstrTypes.h | 10 +++++ llvm/include/llvm/IR/PatternMatch.h | 10 +++++ llvm/lib/IR/Instructions.cpp | 49 ++++++++++++++++++++++++ llvm/lib/Transforms/Scalar/GVN.cpp | 59 ++--------------------------- 4 files changed, 72 insertions(+), 56 deletions(-) diff --git a/llvm/include/llvm/IR/InstrTypes.h b/llvm/include/llvm/IR/InstrTypes.h index 99f72792ce402..85e84afda738c 100644 --- a/llvm/include/llvm/IR/InstrTypes.h +++ b/llvm/include/llvm/IR/InstrTypes.h @@ -912,6 +912,16 @@ class CmpInst : public Instruction { /// Determine if this is an equals/not equals predicate. bool isEquality() const { return isEquality(getPredicate()); } + /// Determine if this is an equals predicate that is also an equivalence. This + /// is useful in GVN-like transformations, where we can replace RHS by LHS in + /// the true branch of the CmpInst. + bool isEqEquivalence() const; + + /// Determine if this is a not-equals predicate that is also an equivalence. + /// This is useful in GVN-like transformations, where we can replace RHS by + /// LHS in the false branch of the CmpInst. + bool isNeEquivalence() const; + /// Return true if the predicate is relational (not EQ or NE). static bool isRelational(Predicate P) { return !isEquality(P); } diff --git a/llvm/include/llvm/IR/PatternMatch.h b/llvm/include/llvm/IR/PatternMatch.h index c3349c9772c7a..0d6df72790632 100644 --- a/llvm/include/llvm/IR/PatternMatch.h +++ b/llvm/include/llvm/IR/PatternMatch.h @@ -792,6 +792,16 @@ inline cstfp_pred_ty m_NonZeroFP() { return cstfp_pred_ty(); } +struct is_non_zero_not_denormal_fp { + bool isValue(const APFloat &C) { return !C.isDenormal() && C.isNonZero(); } +}; + +/// Match a floating-point non-zero that is not a denormal. +/// For vectors, this includes constants with undefined elements. +inline cstfp_pred_ty m_NonZeroNotDenormalFP() { + return cstfp_pred_ty(); +} + /////////////////////////////////////////////////////////////////////////////// template struct bind_ty { diff --git a/llvm/lib/IR/Instructions.cpp b/llvm/lib/IR/Instructions.cpp index 009e0c03957c9..98b474f5bbc36 100644 --- a/llvm/lib/IR/Instructions.cpp +++ b/llvm/lib/IR/Instructions.cpp @@ -32,6 +32,7 @@ #include "llvm/IR/Metadata.h" #include "llvm/IR/Module.h" #include "llvm/IR/Operator.h" +#include "llvm/IR/PatternMatch.h" #include "llvm/IR/ProfDataUtils.h" #include "llvm/IR/Type.h" #include "llvm/IR/Value.h" @@ -3471,6 +3472,54 @@ bool CmpInst::isEquality(Predicate P) { llvm_unreachable("Unsupported predicate kind"); } +// Returns true if either operand of CmpInst is a provably non-zero +// floating-point constant. +static bool hasNonZeroFPOperands(const CmpInst *Cmp) { + auto *LHS = dyn_cast(Cmp->getOperand(0)); + auto *RHS = dyn_cast(Cmp->getOperand(1)); + if (auto *Const = LHS ? LHS : RHS) { + using namespace llvm::PatternMatch; + return match(Const, m_NonZeroNotDenormalFP()); + } + return false; +} + +// Floating-point equality is not an equivalence when comparing +0.0 with +// -0.0, when comparing NaN with another value, or when flushing +// denormals-to-zero. +bool CmpInst::isEqEquivalence() const { + switch (getPredicate()) { + case CmpInst::Predicate::ICMP_EQ: + return true; + case CmpInst::Predicate::FCMP_UEQ: + if (!hasNoNaNs()) + return false; + [[fallthrough]]; + case CmpInst::Predicate::FCMP_OEQ: + return hasNonZeroFPOperands(this); + default: + return false; + } +} + +// Floating-point equality is not an equivalence when comparing +0.0 with +// -0.0, when comparing NaN with another value, or when flushing +// denormals-to-zero. +bool CmpInst::isNeEquivalence() const { + switch (getPredicate()) { + case CmpInst::Predicate::ICMP_NE: + return true; + case CmpInst::Predicate::FCMP_ONE: + if (!hasNoNaNs()) + return false; + [[fallthrough]]; + case CmpInst::Predicate::FCMP_UNE: + return hasNonZeroFPOperands(this); + default: + return false; + } +} + CmpInst::Predicate CmpInst::getInversePredicate(Predicate pred) { switch (pred) { default: llvm_unreachable("Unknown cmp predicate!"); diff --git a/llvm/lib/Transforms/Scalar/GVN.cpp b/llvm/lib/Transforms/Scalar/GVN.cpp index 2ba600497e00d..cdd2a9dc06af6 100644 --- a/llvm/lib/Transforms/Scalar/GVN.cpp +++ b/llvm/lib/Transforms/Scalar/GVN.cpp @@ -1989,59 +1989,6 @@ bool GVNPass::processNonLocalLoad(LoadInst *Load) { return Changed; } -static bool impliesEquivalanceIfTrue(CmpInst* Cmp) { - if (Cmp->getPredicate() == CmpInst::Predicate::ICMP_EQ) - return true; - - // Floating point comparisons can be equal, but not equivalent. Cases: - // NaNs for unordered operators - // +0.0 vs 0.0 for all operators - if (Cmp->getPredicate() == CmpInst::Predicate::FCMP_OEQ || - (Cmp->getPredicate() == CmpInst::Predicate::FCMP_UEQ && - Cmp->getFastMathFlags().noNaNs())) { - Value *LHS = Cmp->getOperand(0); - Value *RHS = Cmp->getOperand(1); - // If we can prove either side non-zero, then equality must imply - // equivalence. - // FIXME: We should do this optimization if 'no signed zeros' is - // applicable via an instruction-level fast-math-flag or some other - // indicator that relaxed FP semantics are being used. - if (isa(LHS) && !cast(LHS)->isZero()) - return true; - if (isa(RHS) && !cast(RHS)->isZero()) - return true; - // TODO: Handle vector floating point constants - } - return false; -} - -static bool impliesEquivalanceIfFalse(CmpInst* Cmp) { - if (Cmp->getPredicate() == CmpInst::Predicate::ICMP_NE) - return true; - - // Floating point comparisons can be equal, but not equivelent. Cases: - // NaNs for unordered operators - // +0.0 vs 0.0 for all operators - if ((Cmp->getPredicate() == CmpInst::Predicate::FCMP_ONE && - Cmp->getFastMathFlags().noNaNs()) || - Cmp->getPredicate() == CmpInst::Predicate::FCMP_UNE) { - Value *LHS = Cmp->getOperand(0); - Value *RHS = Cmp->getOperand(1); - // If we can prove either side non-zero, then equality must imply - // equivalence. - // FIXME: We should do this optimization if 'no signed zeros' is - // applicable via an instruction-level fast-math-flag or some other - // indicator that relaxed FP semantics are being used. - if (isa(LHS) && !cast(LHS)->isZero()) - return true; - if (isa(RHS) && !cast(RHS)->isZero()) - return true; - // TODO: Handle vector floating point constants - } - return false; -} - - static bool hasUsersIn(Value *V, BasicBlock *BB) { return llvm::any_of(V->users(), [BB](User *U) { auto *I = dyn_cast(U); @@ -2143,7 +2090,7 @@ bool GVNPass::processAssumeIntrinsic(AssumeInst *IntrinsicI) { // call void @llvm.assume(i1 %cmp) // ret float %load ; will change it to ret float %0 if (auto *CmpI = dyn_cast(V)) { - if (impliesEquivalanceIfTrue(CmpI)) { + if (CmpI->isEqEquivalence()) { Value *CmpLHS = CmpI->getOperand(0); Value *CmpRHS = CmpI->getOperand(1); // Heuristically pick the better replacement -- the choice of heuristic @@ -2567,8 +2514,8 @@ bool GVNPass::propagateEquality(Value *LHS, Value *RHS, // If "A == B" is known true, or "A != B" is known false, then replace // A with B everywhere in the scope. For floating point operations, we // have to be careful since equality does not always imply equivalance. - if ((isKnownTrue && impliesEquivalanceIfTrue(Cmp)) || - (isKnownFalse && impliesEquivalanceIfFalse(Cmp))) + if ((isKnownTrue && Cmp->isEqEquivalence()) || + (isKnownFalse && Cmp->isNeEquivalence())) Worklist.push_back(std::make_pair(Op0, Op1)); // If "A >= B" is known true, replace "A < B" with false everywhere. From 7bcf6c98c876b2c21f5e18b43e7ee158562d2572 Mon Sep 17 00:00:00 2001 From: Ramkumar Ramachandra Date: Tue, 15 Oct 2024 12:29:06 +0100 Subject: [PATCH 2/5] GVN: add test for denormals --- llvm/test/Transforms/GVN/edge.ll | 30 +++++++++++++++++++++++++++++- 1 file changed, 29 insertions(+), 1 deletion(-) diff --git a/llvm/test/Transforms/GVN/edge.ll b/llvm/test/Transforms/GVN/edge.ll index 9703195d3b642..83c4c336f6474 100644 --- a/llvm/test/Transforms/GVN/edge.ll +++ b/llvm/test/Transforms/GVN/edge.ll @@ -224,6 +224,34 @@ return: ret double %retval } +; Denormals may be flushed to zero in some cases by the backend. +; Hence, treat denormals as 0. +define float @fcmp_oeq_denormal(float %x, float %y) { +; CHECK-LABEL: define float @fcmp_oeq_denormal( +; CHECK-SAME: float [[X:%.*]], float [[Y:%.*]]) { +; CHECK-NEXT: [[ENTRY:.*]]: +; CHECK-NEXT: [[CMP:%.*]] = fcmp oeq float [[Y]], 0x3800000000000000 +; CHECK-NEXT: br i1 [[CMP]], label %[[IF:.*]], label %[[RETURN:.*]] +; CHECK: [[IF]]: +; CHECK-NEXT: [[DIV:%.*]] = fdiv float [[X]], [[Y]] +; CHECK-NEXT: br label %[[RETURN]] +; CHECK: [[RETURN]]: +; CHECK-NEXT: [[RETVAL:%.*]] = phi float [ [[DIV]], %[[IF]] ], [ [[X]], %[[ENTRY]] ] +; CHECK-NEXT: ret float [[RETVAL]] +; +entry: + %cmp = fcmp oeq float %y, 0x3800000000000000 + br i1 %cmp, label %if, label %return + +if: + %div = fdiv float %x, %y + br label %return + +return: + %retval = phi float [ %div, %if ], [ %x, %entry ] + ret float %retval +} + define double @fcmp_une_zero(double %x, double %y) { ; CHECK-LABEL: define double @fcmp_une_zero( ; CHECK-SAME: double [[X:%.*]], double [[Y:%.*]]) { @@ -251,7 +279,7 @@ return: } ; We also cannot propagate a value if it's not a constant. -; This is because the value could be 0.0 or -0.0. +; This is because the value could be 0.0, -0.0, or a denormal. define double @fcmp_oeq_maybe_zero(double %x, double %y, double %z1, double %z2) { ; CHECK-LABEL: define double @fcmp_oeq_maybe_zero( From e57f9eb52340f66ff564da6512a0dd16873cf9da Mon Sep 17 00:00:00 2001 From: Ramkumar Ramachandra Date: Mon, 21 Oct 2024 13:01:13 +0100 Subject: [PATCH 3/5] CmpInst: merge functions; address review --- llvm/include/llvm/IR/InstrTypes.h | 14 +++++--------- llvm/lib/IR/Instructions.cpp | 22 ++++++---------------- llvm/lib/Transforms/Scalar/GVN.cpp | 6 +++--- 3 files changed, 14 insertions(+), 28 deletions(-) diff --git a/llvm/include/llvm/IR/InstrTypes.h b/llvm/include/llvm/IR/InstrTypes.h index 85e84afda738c..d93cd5958bc07 100644 --- a/llvm/include/llvm/IR/InstrTypes.h +++ b/llvm/include/llvm/IR/InstrTypes.h @@ -912,15 +912,11 @@ class CmpInst : public Instruction { /// Determine if this is an equals/not equals predicate. bool isEquality() const { return isEquality(getPredicate()); } - /// Determine if this is an equals predicate that is also an equivalence. This - /// is useful in GVN-like transformations, where we can replace RHS by LHS in - /// the true branch of the CmpInst. - bool isEqEquivalence() const; - - /// Determine if this is a not-equals predicate that is also an equivalence. - /// This is useful in GVN-like transformations, where we can replace RHS by - /// LHS in the false branch of the CmpInst. - bool isNeEquivalence() const; + /// Determine if one operand of this compare can always be replaced by the + /// other operand, ignoring provenance considerations. If \p Invert is false, + /// check for equivalence with an equals predicate; otherwise, check for + /// equivalence with a not-equals predicate. + bool isEquivalence(bool Invert = false) const; /// Return true if the predicate is relational (not EQ or NE). static bool isRelational(Predicate P) { return !isEquality(P); } diff --git a/llvm/lib/IR/Instructions.cpp b/llvm/lib/IR/Instructions.cpp index 98b474f5bbc36..63f3568f359b8 100644 --- a/llvm/lib/IR/Instructions.cpp +++ b/llvm/lib/IR/Instructions.cpp @@ -3487,34 +3487,24 @@ static bool hasNonZeroFPOperands(const CmpInst *Cmp) { // Floating-point equality is not an equivalence when comparing +0.0 with // -0.0, when comparing NaN with another value, or when flushing // denormals-to-zero. -bool CmpInst::isEqEquivalence() const { +bool CmpInst::isEquivalence(bool Invert) const { switch (getPredicate()) { case CmpInst::Predicate::ICMP_EQ: - return true; + return !Invert; + case CmpInst::Predicate::ICMP_NE: + return Invert; case CmpInst::Predicate::FCMP_UEQ: if (!hasNoNaNs()) return false; [[fallthrough]]; case CmpInst::Predicate::FCMP_OEQ: - return hasNonZeroFPOperands(this); - default: - return false; - } -} - -// Floating-point equality is not an equivalence when comparing +0.0 with -// -0.0, when comparing NaN with another value, or when flushing -// denormals-to-zero. -bool CmpInst::isNeEquivalence() const { - switch (getPredicate()) { - case CmpInst::Predicate::ICMP_NE: - return true; + return !Invert && hasNonZeroFPOperands(this); case CmpInst::Predicate::FCMP_ONE: if (!hasNoNaNs()) return false; [[fallthrough]]; case CmpInst::Predicate::FCMP_UNE: - return hasNonZeroFPOperands(this); + return Invert && hasNonZeroFPOperands(this); default: return false; } diff --git a/llvm/lib/Transforms/Scalar/GVN.cpp b/llvm/lib/Transforms/Scalar/GVN.cpp index cdd2a9dc06af6..adfac2b5914e8 100644 --- a/llvm/lib/Transforms/Scalar/GVN.cpp +++ b/llvm/lib/Transforms/Scalar/GVN.cpp @@ -2090,7 +2090,7 @@ bool GVNPass::processAssumeIntrinsic(AssumeInst *IntrinsicI) { // call void @llvm.assume(i1 %cmp) // ret float %load ; will change it to ret float %0 if (auto *CmpI = dyn_cast(V)) { - if (CmpI->isEqEquivalence()) { + if (CmpI->isEquivalence()) { Value *CmpLHS = CmpI->getOperand(0); Value *CmpRHS = CmpI->getOperand(1); // Heuristically pick the better replacement -- the choice of heuristic @@ -2514,8 +2514,8 @@ bool GVNPass::propagateEquality(Value *LHS, Value *RHS, // If "A == B" is known true, or "A != B" is known false, then replace // A with B everywhere in the scope. For floating point operations, we // have to be careful since equality does not always imply equivalance. - if ((isKnownTrue && Cmp->isEqEquivalence()) || - (isKnownFalse && Cmp->isNeEquivalence())) + if ((isKnownTrue && Cmp->isEquivalence()) || + (isKnownFalse && Cmp->isEquivalence(/* Invert = */ true))) Worklist.push_back(std::make_pair(Op0, Op1)); // If "A >= B" is known true, replace "A < B" with false everywhere. From 7fe374186cd4e52e828efce894c58afd34b7c4b4 Mon Sep 17 00:00:00 2001 From: Ramkumar Ramachandra Date: Mon, 21 Oct 2024 13:49:25 +0100 Subject: [PATCH 4/5] CmpInst: de-duplicate code; address review --- llvm/include/llvm/IR/InstrTypes.h | 5 ++--- llvm/lib/IR/Instructions.cpp | 14 +++----------- 2 files changed, 5 insertions(+), 14 deletions(-) diff --git a/llvm/include/llvm/IR/InstrTypes.h b/llvm/include/llvm/IR/InstrTypes.h index d93cd5958bc07..1c60eae7f2f85 100644 --- a/llvm/include/llvm/IR/InstrTypes.h +++ b/llvm/include/llvm/IR/InstrTypes.h @@ -913,9 +913,8 @@ class CmpInst : public Instruction { bool isEquality() const { return isEquality(getPredicate()); } /// Determine if one operand of this compare can always be replaced by the - /// other operand, ignoring provenance considerations. If \p Invert is false, - /// check for equivalence with an equals predicate; otherwise, check for - /// equivalence with a not-equals predicate. + /// other operand, ignoring provenance considerations. If \p Invert, check for + /// equivalence with the inverse predicate. bool isEquivalence(bool Invert = false) const; /// Return true if the predicate is relational (not EQ or NE). diff --git a/llvm/lib/IR/Instructions.cpp b/llvm/lib/IR/Instructions.cpp index 63f3568f359b8..05e340ffa20a0 100644 --- a/llvm/lib/IR/Instructions.cpp +++ b/llvm/lib/IR/Instructions.cpp @@ -3488,23 +3488,15 @@ static bool hasNonZeroFPOperands(const CmpInst *Cmp) { // -0.0, when comparing NaN with another value, or when flushing // denormals-to-zero. bool CmpInst::isEquivalence(bool Invert) const { - switch (getPredicate()) { + switch (Invert ? getInversePredicate() : getPredicate()) { case CmpInst::Predicate::ICMP_EQ: - return !Invert; - case CmpInst::Predicate::ICMP_NE: - return Invert; + return true; case CmpInst::Predicate::FCMP_UEQ: if (!hasNoNaNs()) return false; [[fallthrough]]; case CmpInst::Predicate::FCMP_OEQ: - return !Invert && hasNonZeroFPOperands(this); - case CmpInst::Predicate::FCMP_ONE: - if (!hasNoNaNs()) - return false; - [[fallthrough]]; - case CmpInst::Predicate::FCMP_UNE: - return Invert && hasNonZeroFPOperands(this); + return hasNonZeroFPOperands(this); default: return false; } From 476594f9809a7af3ee6361a2b165a671d37c7e6c Mon Sep 17 00:00:00 2001 From: Ramkumar Ramachandra Date: Mon, 4 Nov 2024 13:24:01 +0000 Subject: [PATCH 5/5] GVN: fix nit --- llvm/lib/Transforms/Scalar/GVN.cpp | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/llvm/lib/Transforms/Scalar/GVN.cpp b/llvm/lib/Transforms/Scalar/GVN.cpp index adfac2b5914e8..56b7e374ad69c 100644 --- a/llvm/lib/Transforms/Scalar/GVN.cpp +++ b/llvm/lib/Transforms/Scalar/GVN.cpp @@ -2514,8 +2514,7 @@ bool GVNPass::propagateEquality(Value *LHS, Value *RHS, // If "A == B" is known true, or "A != B" is known false, then replace // A with B everywhere in the scope. For floating point operations, we // have to be careful since equality does not always imply equivalance. - if ((isKnownTrue && Cmp->isEquivalence()) || - (isKnownFalse && Cmp->isEquivalence(/* Invert = */ true))) + if (Cmp->isEquivalence(isKnownFalse)) Worklist.push_back(std::make_pair(Op0, Op1)); // If "A >= B" is known true, replace "A < B" with false everywhere.