Skip to content

Commit cd16b07

Browse files
authored
IR: introduce CmpInst::isEquivalence (#111979)
Steal impliesEquivalanceIf{True,False} (sic) from GVN, extend it to handle floating-point constant vectors, and account for denormal values. Since InstCombine also performs GVN-like replacements, introduce CmpInst::isEquivalence and remove the corresponding code in GVN, with the intent of using it in more places. The code in GVN also has a misleading FIXME saying that the optimization may be valid in the nsz case, but it is not. Alive2 proof: https://alive2.llvm.org/ce/z/vEaK8M
1 parent 3e8a8fc commit cd16b07

File tree

5 files changed

+77
-57
lines changed

5 files changed

+77
-57
lines changed

llvm/include/llvm/IR/InstrTypes.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -912,6 +912,11 @@ class CmpInst : public Instruction {
912912
/// Determine if this is an equals/not equals predicate.
913913
bool isEquality() const { return isEquality(getPredicate()); }
914914

915+
/// Determine if one operand of this compare can always be replaced by the
916+
/// other operand, ignoring provenance considerations. If \p Invert, check for
917+
/// equivalence with the inverse predicate.
918+
bool isEquivalence(bool Invert = false) const;
919+
915920
/// Return true if the predicate is relational (not EQ or NE).
916921
static bool isRelational(Predicate P) { return !isEquality(P); }
917922

llvm/include/llvm/IR/PatternMatch.h

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -792,6 +792,16 @@ inline cstfp_pred_ty<is_non_zero_fp> m_NonZeroFP() {
792792
return cstfp_pred_ty<is_non_zero_fp>();
793793
}
794794

795+
/// Predicate for cstfp_pred_ty: accepts an APFloat that is non-zero and is
/// not a denormal.
struct is_non_zero_not_denormal_fp {
796+
bool isValue(const APFloat &C) { return !C.isDenormal() && C.isNonZero(); }
797+
};
798+
799+
/// Match a floating-point constant that is non-zero and not a denormal.
800+
/// For vectors, this includes constants with undefined elements.
801+
inline cstfp_pred_ty<is_non_zero_not_denormal_fp> m_NonZeroNotDenormalFP() {
802+
return cstfp_pred_ty<is_non_zero_not_denormal_fp>();
803+
}
804+
795805
///////////////////////////////////////////////////////////////////////////////
796806

797807
template <typename Class> struct bind_ty {

llvm/lib/IR/Instructions.cpp

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
#include "llvm/IR/Metadata.h"
3333
#include "llvm/IR/Module.h"
3434
#include "llvm/IR/Operator.h"
35+
#include "llvm/IR/PatternMatch.h"
3536
#include "llvm/IR/ProfDataUtils.h"
3637
#include "llvm/IR/Type.h"
3738
#include "llvm/IR/Value.h"
@@ -3471,6 +3472,36 @@ bool CmpInst::isEquality(Predicate P) {
34713472
llvm_unreachable("Unsupported predicate kind");
34723473
}
34733474

3475+
// Returns true if a constant operand of Cmp is a provably non-zero,
3476+
// non-denormal floating-point constant.
// Only one operand is inspected: operand 0 when it is a Constant, otherwise
// operand 1. Returns false when neither operand is a Constant.
3477+
static bool hasNonZeroFPOperands(const CmpInst *Cmp) {
3478+
auto *LHS = dyn_cast<Constant>(Cmp->getOperand(0));
3479+
auto *RHS = dyn_cast<Constant>(Cmp->getOperand(1));
3480+
// Prefer the LHS constant; fall back to the RHS one (which may be null).
if (auto *Const = LHS ? LHS : RHS) {
3481+
using namespace llvm::PatternMatch;
3482+
return match(Const, m_NonZeroNotDenormalFP());
3483+
}
3484+
return false;
3485+
}
3486+
3487+
// Floating-point equality is not an equivalence when comparing +0.0 with
3488+
// -0.0, when comparing NaN with another value, or when flushing
3489+
// denormals-to-zero.
// The predicate examined is getPredicate(), or getInversePredicate() when
// Invert is set. Only ICMP_EQ, FCMP_OEQ and FCMP_UEQ can yield true.
3490+
bool CmpInst::isEquivalence(bool Invert) const {
3491+
switch (Invert ? getInversePredicate() : getPredicate()) {
3492+
// Integer equality is always treated as an equivalence.
case CmpInst::Predicate::ICMP_EQ:
3493+
return true;
3494+
// Unordered equality qualifies only with the no-NaNs flag; it then has the
// same requirement on the operands as ordered equality.
case CmpInst::Predicate::FCMP_UEQ:
3495+
if (!hasNoNaNs())
3496+
return false;
3497+
[[fallthrough]];
3498+
// Ordered equality qualifies only when a constant operand is provably
// non-zero and not a denormal.
case CmpInst::Predicate::FCMP_OEQ:
3499+
return hasNonZeroFPOperands(this);
3500+
default:
3501+
return false;
3502+
}
3503+
}
3504+
34743505
CmpInst::Predicate CmpInst::getInversePredicate(Predicate pred) {
34753506
switch (pred) {
34763507
default: llvm_unreachable("Unknown cmp predicate!");

llvm/lib/Transforms/Scalar/GVN.cpp

Lines changed: 2 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -1989,59 +1989,6 @@ bool GVNPass::processNonLocalLoad(LoadInst *Load) {
19891989
return Changed;
19901990
}
19911991

1992-
static bool impliesEquivalanceIfTrue(CmpInst* Cmp) {
1993-
if (Cmp->getPredicate() == CmpInst::Predicate::ICMP_EQ)
1994-
return true;
1995-
1996-
// Floating point comparisons can be equal, but not equivalent. Cases:
1997-
// NaNs for unordered operators
1998-
// +0.0 vs -0.0 for all operators
1999-
if (Cmp->getPredicate() == CmpInst::Predicate::FCMP_OEQ ||
2000-
(Cmp->getPredicate() == CmpInst::Predicate::FCMP_UEQ &&
2001-
Cmp->getFastMathFlags().noNaNs())) {
2002-
Value *LHS = Cmp->getOperand(0);
2003-
Value *RHS = Cmp->getOperand(1);
2004-
// If we can prove either side non-zero, then equality must imply
2005-
// equivalence.
2006-
// FIXME: We should do this optimization if 'no signed zeros' is
2007-
// applicable via an instruction-level fast-math-flag or some other
2008-
// indicator that relaxed FP semantics are being used.
2009-
if (isa<ConstantFP>(LHS) && !cast<ConstantFP>(LHS)->isZero())
2010-
return true;
2011-
if (isa<ConstantFP>(RHS) && !cast<ConstantFP>(RHS)->isZero())
2012-
return true;
2013-
// TODO: Handle vector floating point constants
2014-
}
2015-
return false;
2016-
}
2017-
2018-
static bool impliesEquivalanceIfFalse(CmpInst* Cmp) {
2019-
if (Cmp->getPredicate() == CmpInst::Predicate::ICMP_NE)
2020-
return true;
2021-
2022-
// Floating point comparisons can be equal, but not equivalent. Cases:
2023-
// NaNs for unordered operators
2024-
// +0.0 vs -0.0 for all operators
2025-
if ((Cmp->getPredicate() == CmpInst::Predicate::FCMP_ONE &&
2026-
Cmp->getFastMathFlags().noNaNs()) ||
2027-
Cmp->getPredicate() == CmpInst::Predicate::FCMP_UNE) {
2028-
Value *LHS = Cmp->getOperand(0);
2029-
Value *RHS = Cmp->getOperand(1);
2030-
// If we can prove either side non-zero, then equality must imply
2031-
// equivalence.
2032-
// FIXME: We should do this optimization if 'no signed zeros' is
2033-
// applicable via an instruction-level fast-math-flag or some other
2034-
// indicator that relaxed FP semantics are being used.
2035-
if (isa<ConstantFP>(LHS) && !cast<ConstantFP>(LHS)->isZero())
2036-
return true;
2037-
if (isa<ConstantFP>(RHS) && !cast<ConstantFP>(RHS)->isZero())
2038-
return true;
2039-
// TODO: Handle vector floating point constants
2040-
}
2041-
return false;
2042-
}
2043-
2044-
20451992
static bool hasUsersIn(Value *V, BasicBlock *BB) {
20461993
return llvm::any_of(V->users(), [BB](User *U) {
20471994
auto *I = dyn_cast<Instruction>(U);
@@ -2143,7 +2090,7 @@ bool GVNPass::processAssumeIntrinsic(AssumeInst *IntrinsicI) {
21432090
// call void @llvm.assume(i1 %cmp)
21442091
// ret float %load ; will change it to ret float %0
21452092
if (auto *CmpI = dyn_cast<CmpInst>(V)) {
2146-
if (impliesEquivalanceIfTrue(CmpI)) {
2093+
if (CmpI->isEquivalence()) {
21472094
Value *CmpLHS = CmpI->getOperand(0);
21482095
Value *CmpRHS = CmpI->getOperand(1);
21492096
// Heuristically pick the better replacement -- the choice of heuristic
@@ -2577,8 +2524,7 @@ bool GVNPass::propagateEquality(Value *LHS, Value *RHS,
25772524
// If "A == B" is known true, or "A != B" is known false, then replace
25782525
// A with B everywhere in the scope. For floating point operations, we
25792526
// have to be careful since equality does not always imply equivalence.
2580-
if ((isKnownTrue && impliesEquivalanceIfTrue(Cmp)) ||
2581-
(isKnownFalse && impliesEquivalanceIfFalse(Cmp)))
2527+
if (Cmp->isEquivalence(isKnownFalse))
25822528
Worklist.push_back(std::make_pair(Op0, Op1));
25832529

25842530
// If "A >= B" is known true, replace "A < B" with false everywhere.

llvm/test/Transforms/GVN/edge.ll

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -224,6 +224,34 @@ return:
224224
ret double %retval
225225
}
226226

227+
; Denormals may be flushed to zero in some cases by the backend.
228+
; Hence, treat denormals as 0.
229+
define float @fcmp_oeq_denormal(float %x, float %y) {
230+
; CHECK-LABEL: define float @fcmp_oeq_denormal(
231+
; CHECK-SAME: float [[X:%.*]], float [[Y:%.*]]) {
232+
; CHECK-NEXT: [[ENTRY:.*]]:
233+
; CHECK-NEXT: [[CMP:%.*]] = fcmp oeq float [[Y]], 0x3800000000000000
234+
; CHECK-NEXT: br i1 [[CMP]], label %[[IF:.*]], label %[[RETURN:.*]]
235+
; CHECK: [[IF]]:
236+
; CHECK-NEXT: [[DIV:%.*]] = fdiv float [[X]], [[Y]]
237+
; CHECK-NEXT: br label %[[RETURN]]
238+
; CHECK: [[RETURN]]:
239+
; CHECK-NEXT: [[RETVAL:%.*]] = phi float [ [[DIV]], %[[IF]] ], [ [[X]], %[[ENTRY]] ]
240+
; CHECK-NEXT: ret float [[RETVAL]]
241+
;
242+
entry:
243+
%cmp = fcmp oeq float %y, 0x3800000000000000
244+
br i1 %cmp, label %if, label %return
245+
246+
if:
247+
%div = fdiv float %x, %y
248+
br label %return
249+
250+
return:
251+
%retval = phi float [ %div, %if ], [ %x, %entry ]
252+
ret float %retval
253+
}
254+
227255
define double @fcmp_une_zero(double %x, double %y) {
228256
; CHECK-LABEL: define double @fcmp_une_zero(
229257
; CHECK-SAME: double [[X:%.*]], double [[Y:%.*]]) {
@@ -251,7 +279,7 @@ return:
251279
}
252280

253281
; We also cannot propagate a value if it's not a constant.
254-
; This is because the value could be 0.0 or -0.0.
282+
; This is because the value could be 0.0, -0.0, or a denormal.
255283

256284
define double @fcmp_oeq_maybe_zero(double %x, double %y, double %z1, double %z2) {
257285
; CHECK-LABEL: define double @fcmp_oeq_maybe_zero(

0 commit comments

Comments
 (0)