Skip to content

Commit 28da0cf

Browse files
committed
IR: introduce CmpInst::is{Eq,Ne}Equivalence
Steal impliesEquivalanceIf{True,False} (sic) from GVN, and extend it for floating-point constant vectors. Since InstCombine also performs GVN-like replacements, introduce CmpInst::is{Eq,Ne}Equivalence, and remove the corresponding code in GVN, with the intent of using it in more places.
1 parent a705838 commit 28da0cf

File tree

4 files changed

+72
-56
lines changed

4 files changed

+72
-56
lines changed

llvm/include/llvm/IR/InstrTypes.h

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -912,6 +912,16 @@ class CmpInst : public Instruction {
912912
/// Determine if this is an equals/not equals predicate.
913913
bool isEquality() const { return isEquality(getPredicate()); }
914914

915+
/// Determine if this is an equals predicate that is also an equivalence. This
916+
/// is useful in GVN-like transformations, where we can replace RHS by LHS in
917+
/// the true branch of the CmpInst.
918+
bool isEqEquivalence() const;
919+
920+
/// Determine if this is a not-equals predicate that is also an equivalence.
921+
/// This is useful in GVN-like transformations, where we can replace RHS by
922+
/// LHS in the false branch of the CmpInst.
923+
bool isNeEquivalence() const;
924+
915925
/// Return true if the predicate is relational (not EQ or NE).
916926
static bool isRelational(Predicate P) { return !isEquality(P); }
917927

llvm/include/llvm/IR/PatternMatch.h

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -792,6 +792,16 @@ inline cstfp_pred_ty<is_non_zero_fp> m_NonZeroFP() {
792792
return cstfp_pred_ty<is_non_zero_fp>();
793793
}
794794

795+
struct is_non_zero_not_denormal_fp {
796+
bool isValue(const APFloat &C) { return !C.isDenormal() && C.isNonZero(); }
797+
};
798+
799+
/// Match a floating-point non-zero that is not a denormal.
800+
/// For vectors, this includes constants with undefined elements.
801+
inline cstfp_pred_ty<is_non_zero_not_denormal_fp> m_NonZeroNotDenormalFP() {
802+
return cstfp_pred_ty<is_non_zero_not_denormal_fp>();
803+
}
804+
795805
///////////////////////////////////////////////////////////////////////////////
796806

797807
template <typename Class> struct bind_ty {

llvm/lib/IR/Instructions.cpp

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
#include "llvm/IR/Metadata.h"
3333
#include "llvm/IR/Module.h"
3434
#include "llvm/IR/Operator.h"
35+
#include "llvm/IR/PatternMatch.h"
3536
#include "llvm/IR/ProfDataUtils.h"
3637
#include "llvm/IR/Type.h"
3738
#include "llvm/IR/Value.h"
@@ -3471,6 +3472,54 @@ bool CmpInst::isEquality(Predicate P) {
34713472
llvm_unreachable("Unsupported predicate kind");
34723473
}
34733474

3475+
// Returns true if either operand of CmpInst is a provably non-zero
3476+
// floating-point constant.
3477+
static bool hasNonZeroFPOperands(const CmpInst *Cmp) {
3478+
auto *LHS = dyn_cast<Constant>(Cmp->getOperand(0));
3479+
auto *RHS = dyn_cast<Constant>(Cmp->getOperand(1));
3480+
if (auto *Const = LHS ? LHS : RHS) {
3481+
using namespace llvm::PatternMatch;
3482+
return match(Const, m_NonZeroNotDenormalFP());
3483+
}
3484+
return false;
3485+
}
3486+
3487+
// Floating-point equality is not an equivalence when comparing +0.0 with
3488+
// -0.0, when comparing NaN with another value, or when flushing
3489+
// denormals-to-zero.
3490+
bool CmpInst::isEqEquivalence() const {
3491+
switch (getPredicate()) {
3492+
case CmpInst::Predicate::ICMP_EQ:
3493+
return true;
3494+
case CmpInst::Predicate::FCMP_UEQ:
3495+
if (!hasNoNaNs())
3496+
return false;
3497+
[[fallthrough]];
3498+
case CmpInst::Predicate::FCMP_OEQ:
3499+
return hasNonZeroFPOperands(this);
3500+
default:
3501+
return false;
3502+
}
3503+
}
3504+
3505+
// Floating-point equality is not an equivalence when comparing +0.0 with
3506+
// -0.0, when comparing NaN with another value, or when flushing
3507+
// denormals-to-zero.
3508+
bool CmpInst::isNeEquivalence() const {
3509+
switch (getPredicate()) {
3510+
case CmpInst::Predicate::ICMP_NE:
3511+
return true;
3512+
case CmpInst::Predicate::FCMP_ONE:
3513+
if (!hasNoNaNs())
3514+
return false;
3515+
[[fallthrough]];
3516+
case CmpInst::Predicate::FCMP_UNE:
3517+
return hasNonZeroFPOperands(this);
3518+
default:
3519+
return false;
3520+
}
3521+
}
3522+
34743523
CmpInst::Predicate CmpInst::getInversePredicate(Predicate pred) {
34753524
switch (pred) {
34763525
default: llvm_unreachable("Unknown cmp predicate!");

llvm/lib/Transforms/Scalar/GVN.cpp

Lines changed: 3 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -1989,59 +1989,6 @@ bool GVNPass::processNonLocalLoad(LoadInst *Load) {
19891989
return Changed;
19901990
}
19911991

1992-
static bool impliesEquivalanceIfTrue(CmpInst* Cmp) {
1993-
if (Cmp->getPredicate() == CmpInst::Predicate::ICMP_EQ)
1994-
return true;
1995-
1996-
// Floating point comparisons can be equal, but not equivalent. Cases:
1997-
// NaNs for unordered operators
1998-
// +0.0 vs 0.0 for all operators
1999-
if (Cmp->getPredicate() == CmpInst::Predicate::FCMP_OEQ ||
2000-
(Cmp->getPredicate() == CmpInst::Predicate::FCMP_UEQ &&
2001-
Cmp->getFastMathFlags().noNaNs())) {
2002-
Value *LHS = Cmp->getOperand(0);
2003-
Value *RHS = Cmp->getOperand(1);
2004-
// If we can prove either side non-zero, then equality must imply
2005-
// equivalence.
2006-
// FIXME: We should do this optimization if 'no signed zeros' is
2007-
// applicable via an instruction-level fast-math-flag or some other
2008-
// indicator that relaxed FP semantics are being used.
2009-
if (isa<ConstantFP>(LHS) && !cast<ConstantFP>(LHS)->isZero())
2010-
return true;
2011-
if (isa<ConstantFP>(RHS) && !cast<ConstantFP>(RHS)->isZero())
2012-
return true;
2013-
// TODO: Handle vector floating point constants
2014-
}
2015-
return false;
2016-
}
2017-
2018-
static bool impliesEquivalanceIfFalse(CmpInst* Cmp) {
2019-
if (Cmp->getPredicate() == CmpInst::Predicate::ICMP_NE)
2020-
return true;
2021-
2022-
// Floating point comparisons can be equal, but not equivelent. Cases:
2023-
// NaNs for unordered operators
2024-
// +0.0 vs 0.0 for all operators
2025-
if ((Cmp->getPredicate() == CmpInst::Predicate::FCMP_ONE &&
2026-
Cmp->getFastMathFlags().noNaNs()) ||
2027-
Cmp->getPredicate() == CmpInst::Predicate::FCMP_UNE) {
2028-
Value *LHS = Cmp->getOperand(0);
2029-
Value *RHS = Cmp->getOperand(1);
2030-
// If we can prove either side non-zero, then equality must imply
2031-
// equivalence.
2032-
// FIXME: We should do this optimization if 'no signed zeros' is
2033-
// applicable via an instruction-level fast-math-flag or some other
2034-
// indicator that relaxed FP semantics are being used.
2035-
if (isa<ConstantFP>(LHS) && !cast<ConstantFP>(LHS)->isZero())
2036-
return true;
2037-
if (isa<ConstantFP>(RHS) && !cast<ConstantFP>(RHS)->isZero())
2038-
return true;
2039-
// TODO: Handle vector floating point constants
2040-
}
2041-
return false;
2042-
}
2043-
2044-
20451992
static bool hasUsersIn(Value *V, BasicBlock *BB) {
20461993
return llvm::any_of(V->users(), [BB](User *U) {
20471994
auto *I = dyn_cast<Instruction>(U);
@@ -2143,7 +2090,7 @@ bool GVNPass::processAssumeIntrinsic(AssumeInst *IntrinsicI) {
21432090
// call void @llvm.assume(i1 %cmp)
21442091
// ret float %load ; will change it to ret float %0
21452092
if (auto *CmpI = dyn_cast<CmpInst>(V)) {
2146-
if (impliesEquivalanceIfTrue(CmpI)) {
2093+
if (CmpI->isEqEquivalence()) {
21472094
Value *CmpLHS = CmpI->getOperand(0);
21482095
Value *CmpRHS = CmpI->getOperand(1);
21492096
// Heuristically pick the better replacement -- the choice of heuristic
@@ -2567,8 +2514,8 @@ bool GVNPass::propagateEquality(Value *LHS, Value *RHS,
25672514
// If "A == B" is known true, or "A != B" is known false, then replace
25682515
// A with B everywhere in the scope. For floating point operations, we
25692516
// have to be careful since equality does not always imply equivalance.
2570-
if ((isKnownTrue && impliesEquivalanceIfTrue(Cmp)) ||
2571-
(isKnownFalse && impliesEquivalanceIfFalse(Cmp)))
2517+
if ((isKnownTrue && Cmp->isEqEquivalence()) ||
2518+
(isKnownFalse && Cmp->isNeEquivalence()))
25722519
Worklist.push_back(std::make_pair(Op0, Op1));
25732520

25742521
// If "A >= B" is known true, replace "A < B" with false everywhere.

0 commit comments

Comments
 (0)