Skip to content

Commit 2034f2f

Browse files
authored
[mlir][intrange] Use nsw,nuw flags in inference (#92642)
This patch includes the "no signed wrap" and "no unsigned wrap" flags, which can be used to annotate some Ops in the `arith` dialect and also in LLVMIR, in the integer range inference. The general approach is to use saturating arithmetic operations to infer bounds which are assumed to not wrap and use overflowing arithmetic operations in the normal case. If overflow is detected in the normal case, special handling makes sure that we don't underestimate the result range.
1 parent 1015f51 commit 2034f2f

File tree

6 files changed

+250
-53
lines changed

6 files changed

+250
-53
lines changed

mlir/include/mlir/Interfaces/Utils/InferIntRangeCommon.h

Lines changed: 21 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
#include "mlir/Interfaces/InferIntRangeInterface.h"
1818
#include "llvm/ADT/ArrayRef.h"
19+
#include "llvm/ADT/BitmaskEnum.h"
1920
#include <optional>
2021

2122
namespace mlir {
@@ -31,6 +32,18 @@ static constexpr unsigned indexMaxWidth = 64;
3132

3233
enum class CmpMode : uint32_t { Both, Signed, Unsigned };
3334

35+
enum class OverflowFlags : uint32_t {
36+
None = 0,
37+
Nsw = 1,
38+
Nuw = 2,
39+
LLVM_MARK_AS_BITMASK_ENUM(Nuw)
40+
};
41+
42+
/// Function that performs inference on an array of `ConstantIntRanges` while
43+
/// taking special overflow behavior into account.
44+
using InferRangeWithOvfFlagsFn =
45+
function_ref<ConstantIntRanges(ArrayRef<ConstantIntRanges>, OverflowFlags)>;
46+
3447
/// Compute `inferFn` on `ranges`, whose size should be the index storage
3548
/// bitwidth. Then, compute the function on `argRanges` again after truncating
3649
/// the ranges to 32 bits. Finally, if the truncation of the 64-bit result is
@@ -60,11 +73,14 @@ ConstantIntRanges extSIRange(const ConstantIntRanges &range,
6073
ConstantIntRanges truncRange(const ConstantIntRanges &range,
6174
unsigned destWidth);
6275

63-
ConstantIntRanges inferAdd(ArrayRef<ConstantIntRanges> argRanges);
76+
ConstantIntRanges inferAdd(ArrayRef<ConstantIntRanges> argRanges,
77+
OverflowFlags ovfFlags = OverflowFlags::None);
6478

65-
ConstantIntRanges inferSub(ArrayRef<ConstantIntRanges> argRanges);
79+
ConstantIntRanges inferSub(ArrayRef<ConstantIntRanges> argRanges,
80+
OverflowFlags ovfFlags = OverflowFlags::None);
6681

67-
ConstantIntRanges inferMul(ArrayRef<ConstantIntRanges> argRanges);
82+
ConstantIntRanges inferMul(ArrayRef<ConstantIntRanges> argRanges,
83+
OverflowFlags ovfFlags = OverflowFlags::None);
6884

6985
ConstantIntRanges inferDivS(ArrayRef<ConstantIntRanges> argRanges);
7086

@@ -94,7 +110,8 @@ ConstantIntRanges inferOr(ArrayRef<ConstantIntRanges> argRanges);
94110

95111
ConstantIntRanges inferXor(ArrayRef<ConstantIntRanges> argRanges);
96112

97-
ConstantIntRanges inferShl(ArrayRef<ConstantIntRanges> argRanges);
113+
ConstantIntRanges inferShl(ArrayRef<ConstantIntRanges> argRanges,
114+
OverflowFlags ovfFlags = OverflowFlags::None);
98115

99116
ConstantIntRanges inferShrS(ArrayRef<ConstantIntRanges> argRanges);
100117

mlir/lib/Dialect/Arith/IR/InferIntRangeInterfaceImpls.cpp

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,16 @@ using namespace mlir;
1919
using namespace mlir::arith;
2020
using namespace mlir::intrange;
2121

22+
static intrange::OverflowFlags
23+
convertArithOverflowFlags(arith::IntegerOverflowFlags flags) {
24+
intrange::OverflowFlags retFlags = intrange::OverflowFlags::None;
25+
if (bitEnumContainsAny(flags, arith::IntegerOverflowFlags::nsw))
26+
retFlags |= intrange::OverflowFlags::Nsw;
27+
if (bitEnumContainsAny(flags, arith::IntegerOverflowFlags::nuw))
28+
retFlags |= intrange::OverflowFlags::Nuw;
29+
return retFlags;
30+
}
31+
2232
//===----------------------------------------------------------------------===//
2333
// ConstantOp
2434
//===----------------------------------------------------------------------===//
@@ -38,7 +48,8 @@ void arith::ConstantOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
3848

3949
void arith::AddIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
4050
SetIntRangeFn setResultRange) {
41-
setResultRange(getResult(), inferAdd(argRanges));
51+
setResultRange(getResult(), inferAdd(argRanges, convertArithOverflowFlags(
52+
getOverflowFlags())));
4253
}
4354

4455
//===----------------------------------------------------------------------===//
@@ -47,7 +58,8 @@ void arith::AddIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
4758

4859
void arith::SubIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
4960
SetIntRangeFn setResultRange) {
50-
setResultRange(getResult(), inferSub(argRanges));
61+
setResultRange(getResult(), inferSub(argRanges, convertArithOverflowFlags(
62+
getOverflowFlags())));
5163
}
5264

5365
//===----------------------------------------------------------------------===//
@@ -56,7 +68,8 @@ void arith::SubIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
5668

5769
void arith::MulIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
5870
SetIntRangeFn setResultRange) {
59-
setResultRange(getResult(), inferMul(argRanges));
71+
setResultRange(getResult(), inferMul(argRanges, convertArithOverflowFlags(
72+
getOverflowFlags())));
6073
}
6174

6275
//===----------------------------------------------------------------------===//
@@ -302,7 +315,8 @@ void arith::SelectOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
302315

303316
void arith::ShLIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
304317
SetIntRangeFn setResultRange) {
305-
setResultRange(getResult(), inferShl(argRanges));
318+
setResultRange(getResult(), inferShl(argRanges, convertArithOverflowFlags(
319+
getOverflowFlags())));
306320
}
307321

308322
//===----------------------------------------------------------------------===//

mlir/lib/Dialect/Index/IR/InferIntRangeInterfaceImpls.cpp

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -44,19 +44,32 @@ void BoolConstantOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
4444
// we take the 64-bit result).
4545
//===----------------------------------------------------------------------===//
4646

47+
// Some arithmetic inference functions allow specifying special overflow / wrap
48+
// behavior. We do not require this for the IndexOps and use this helper to call
49+
// the inference function without any `OverflowFlags`.
50+
static std::function<ConstantIntRanges(ArrayRef<ConstantIntRanges>)>
51+
inferWithoutOverflowFlags(InferRangeWithOvfFlagsFn inferWithOvfFn) {
52+
return [inferWithOvfFn](ArrayRef<ConstantIntRanges> argRanges) {
53+
return inferWithOvfFn(argRanges, OverflowFlags::None);
54+
};
55+
}
56+
4757
void AddOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
4858
SetIntRangeFn setResultRange) {
49-
setResultRange(getResult(), inferIndexOp(inferAdd, argRanges, CmpMode::Both));
59+
setResultRange(getResult(), inferIndexOp(inferWithoutOverflowFlags(inferAdd),
60+
argRanges, CmpMode::Both));
5061
}
5162

5263
void SubOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
5364
SetIntRangeFn setResultRange) {
54-
setResultRange(getResult(), inferIndexOp(inferSub, argRanges, CmpMode::Both));
65+
setResultRange(getResult(), inferIndexOp(inferWithoutOverflowFlags(inferSub),
66+
argRanges, CmpMode::Both));
5567
}
5668

5769
void MulOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
5870
SetIntRangeFn setResultRange) {
59-
setResultRange(getResult(), inferIndexOp(inferMul, argRanges, CmpMode::Both));
71+
setResultRange(getResult(), inferIndexOp(inferWithoutOverflowFlags(inferMul),
72+
argRanges, CmpMode::Both));
6073
}
6174

6275
void DivUOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
@@ -127,7 +140,8 @@ void MinUOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
127140

128141
void ShlOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
129142
SetIntRangeFn setResultRange) {
130-
setResultRange(getResult(), inferIndexOp(inferShl, argRanges, CmpMode::Both));
143+
setResultRange(getResult(), inferIndexOp(inferWithoutOverflowFlags(inferShl),
144+
argRanges, CmpMode::Both));
131145
}
132146

133147
void ShrSOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,

mlir/lib/Interfaces/Utils/InferIntRangeCommon.cpp

Lines changed: 59 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -178,18 +178,24 @@ ConstantIntRanges mlir::intrange::truncRange(const ConstantIntRanges &range,
178178
//===----------------------------------------------------------------------===//
179179

180180
ConstantIntRanges
181-
mlir::intrange::inferAdd(ArrayRef<ConstantIntRanges> argRanges) {
181+
mlir::intrange::inferAdd(ArrayRef<ConstantIntRanges> argRanges,
182+
OverflowFlags ovfFlags) {
182183
const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
183-
ConstArithFn uadd = [](const APInt &a,
184-
const APInt &b) -> std::optional<APInt> {
184+
185+
std::function uadd = [=](const APInt &a,
186+
const APInt &b) -> std::optional<APInt> {
185187
bool overflowed = false;
186-
APInt result = a.uadd_ov(b, overflowed);
188+
APInt result = any(ovfFlags & OverflowFlags::Nuw)
189+
? a.uadd_sat(b)
190+
: a.uadd_ov(b, overflowed);
187191
return overflowed ? std::optional<APInt>() : result;
188192
};
189-
ConstArithFn sadd = [](const APInt &a,
190-
const APInt &b) -> std::optional<APInt> {
193+
std::function sadd = [=](const APInt &a,
194+
const APInt &b) -> std::optional<APInt> {
191195
bool overflowed = false;
192-
APInt result = a.sadd_ov(b, overflowed);
196+
APInt result = any(ovfFlags & OverflowFlags::Nsw)
197+
? a.sadd_sat(b)
198+
: a.sadd_ov(b, overflowed);
193199
return overflowed ? std::optional<APInt>() : result;
194200
};
195201

@@ -205,19 +211,24 @@ mlir::intrange::inferAdd(ArrayRef<ConstantIntRanges> argRanges) {
205211
//===----------------------------------------------------------------------===//
206212

207213
ConstantIntRanges
208-
mlir::intrange::inferSub(ArrayRef<ConstantIntRanges> argRanges) {
214+
mlir::intrange::inferSub(ArrayRef<ConstantIntRanges> argRanges,
215+
OverflowFlags ovfFlags) {
209216
const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
210217

211-
ConstArithFn usub = [](const APInt &a,
212-
const APInt &b) -> std::optional<APInt> {
218+
std::function usub = [=](const APInt &a,
219+
const APInt &b) -> std::optional<APInt> {
213220
bool overflowed = false;
214-
APInt result = a.usub_ov(b, overflowed);
221+
APInt result = any(ovfFlags & OverflowFlags::Nuw)
222+
? a.usub_sat(b)
223+
: a.usub_ov(b, overflowed);
215224
return overflowed ? std::optional<APInt>() : result;
216225
};
217-
ConstArithFn ssub = [](const APInt &a,
218-
const APInt &b) -> std::optional<APInt> {
226+
std::function ssub = [=](const APInt &a,
227+
const APInt &b) -> std::optional<APInt> {
219228
bool overflowed = false;
220-
APInt result = a.ssub_ov(b, overflowed);
229+
APInt result = any(ovfFlags & OverflowFlags::Nsw)
230+
? a.ssub_sat(b)
231+
: a.ssub_ov(b, overflowed);
221232
return overflowed ? std::optional<APInt>() : result;
222233
};
223234
ConstantIntRanges urange = computeBoundsBy(
@@ -232,19 +243,24 @@ mlir::intrange::inferSub(ArrayRef<ConstantIntRanges> argRanges) {
232243
//===----------------------------------------------------------------------===//
233244

234245
ConstantIntRanges
235-
mlir::intrange::inferMul(ArrayRef<ConstantIntRanges> argRanges) {
246+
mlir::intrange::inferMul(ArrayRef<ConstantIntRanges> argRanges,
247+
OverflowFlags ovfFlags) {
236248
const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
237249

238-
ConstArithFn umul = [](const APInt &a,
239-
const APInt &b) -> std::optional<APInt> {
250+
std::function umul = [=](const APInt &a,
251+
const APInt &b) -> std::optional<APInt> {
240252
bool overflowed = false;
241-
APInt result = a.umul_ov(b, overflowed);
253+
APInt result = any(ovfFlags & OverflowFlags::Nuw)
254+
? a.umul_sat(b)
255+
: a.umul_ov(b, overflowed);
242256
return overflowed ? std::optional<APInt>() : result;
243257
};
244-
ConstArithFn smul = [](const APInt &a,
245-
const APInt &b) -> std::optional<APInt> {
258+
std::function smul = [=](const APInt &a,
259+
const APInt &b) -> std::optional<APInt> {
246260
bool overflowed = false;
247-
APInt result = a.smul_ov(b, overflowed);
261+
APInt result = any(ovfFlags & OverflowFlags::Nsw)
262+
? a.smul_sat(b)
263+
: a.smul_ov(b, overflowed);
248264
return overflowed ? std::optional<APInt>() : result;
249265
};
250266

@@ -542,32 +558,35 @@ mlir::intrange::inferXor(ArrayRef<ConstantIntRanges> argRanges) {
542558
//===----------------------------------------------------------------------===//
543559

544560
ConstantIntRanges
545-
mlir::intrange::inferShl(ArrayRef<ConstantIntRanges> argRanges) {
561+
mlir::intrange::inferShl(ArrayRef<ConstantIntRanges> argRanges,
562+
OverflowFlags ovfFlags) {
546563
const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
547-
const APInt &lhsSMin = lhs.smin(), &lhsSMax = lhs.smax(),
548-
&lhsUMax = lhs.umax(), &rhsUMin = rhs.umin(),
549-
&rhsUMax = rhs.umax();
564+
const APInt &rhsUMin = rhs.umin(), &rhsUMax = rhs.umax();
550565

551-
ConstArithFn shl = [](const APInt &l,
552-
const APInt &r) -> std::optional<APInt> {
553-
return r.uge(r.getBitWidth()) ? std::optional<APInt>() : l.shl(r);
566+
// The signed/unsigned overflow behavior of shl by `rhs` matches a mul with
567+
// 2^rhs.
568+
std::function ushl = [=](const APInt &l,
569+
const APInt &r) -> std::optional<APInt> {
570+
bool overflowed = false;
571+
APInt result = any(ovfFlags & OverflowFlags::Nuw)
572+
? l.ushl_sat(r)
573+
: l.ushl_ov(r, overflowed);
574+
return overflowed ? std::optional<APInt>() : result;
575+
};
576+
std::function sshl = [=](const APInt &l,
577+
const APInt &r) -> std::optional<APInt> {
578+
bool overflowed = false;
579+
APInt result = any(ovfFlags & OverflowFlags::Nsw)
580+
? l.sshl_sat(r)
581+
: l.sshl_ov(r, overflowed);
582+
return overflowed ? std::optional<APInt>() : result;
554583
};
555-
556-
// The minMax inference does not work when there is danger of overflow. In the
557-
// signed case, this leads to the obvious problem that the sign bit might
558-
// change. In the unsigned case, it also leads to problems because the largest
559-
// LHS shifted by the largest RHS does not necessarily result in the largest
560-
// result anymore.
561-
assert(rhsUMax.isNonNegative() && "Unexpected negative shift count");
562-
if (rhsUMax.uge(lhsSMin.getNumSignBits()) ||
563-
rhsUMax.uge(lhsSMax.getNumSignBits()))
564-
return ConstantIntRanges::maxRange(lhsUMax.getBitWidth());
565584

566585
ConstantIntRanges urange =
567-
minMaxBy(shl, {lhs.umin(), lhsUMax}, {rhsUMin, rhsUMax},
586+
minMaxBy(ushl, {lhs.umin(), lhs.umax()}, {rhsUMin, rhsUMax},
568587
/*isSigned=*/false);
569588
ConstantIntRanges srange =
570-
minMaxBy(shl, {lhsSMin, lhsSMax}, {rhsUMin, rhsUMax},
589+
minMaxBy(sshl, {lhs.smin(), lhs.smax()}, {rhsUMin, rhsUMax},
571590
/*isSigned=*/true);
572591
return urange.intersection(srange);
573592
}

0 commit comments

Comments
 (0)