@@ -178,18 +178,24 @@ ConstantIntRanges mlir::intrange::truncRange(const ConstantIntRanges &range,
178
178
// ===----------------------------------------------------------------------===//
179
179
180
180
ConstantIntRanges
181
- mlir::intrange::inferAdd (ArrayRef<ConstantIntRanges> argRanges) {
181
+ mlir::intrange::inferAdd (ArrayRef<ConstantIntRanges> argRanges,
182
+ OverflowFlags ovfFlags) {
182
183
const ConstantIntRanges &lhs = argRanges[0 ], &rhs = argRanges[1 ];
183
- ConstArithFn uadd = [](const APInt &a,
184
- const APInt &b) -> std::optional<APInt> {
184
+
185
+ std::function uadd = [=](const APInt &a,
186
+ const APInt &b) -> std::optional<APInt> {
185
187
bool overflowed = false ;
186
- APInt result = a.uadd_ov (b, overflowed);
188
+ APInt result = any (ovfFlags & OverflowFlags::Nuw)
189
+ ? a.uadd_sat (b)
190
+ : a.uadd_ov (b, overflowed);
187
191
return overflowed ? std::optional<APInt>() : result;
188
192
};
189
- ConstArithFn sadd = [](const APInt &a,
190
- const APInt &b) -> std::optional<APInt> {
193
+ std::function sadd = [= ](const APInt &a,
194
+ const APInt &b) -> std::optional<APInt> {
191
195
bool overflowed = false ;
192
- APInt result = a.sadd_ov (b, overflowed);
196
+ APInt result = any (ovfFlags & OverflowFlags::Nsw)
197
+ ? a.sadd_sat (b)
198
+ : a.sadd_ov (b, overflowed);
193
199
return overflowed ? std::optional<APInt>() : result;
194
200
};
195
201
@@ -205,19 +211,24 @@ mlir::intrange::inferAdd(ArrayRef<ConstantIntRanges> argRanges) {
205
211
// ===----------------------------------------------------------------------===//
206
212
207
213
ConstantIntRanges
208
- mlir::intrange::inferSub (ArrayRef<ConstantIntRanges> argRanges) {
214
+ mlir::intrange::inferSub (ArrayRef<ConstantIntRanges> argRanges,
215
+ OverflowFlags ovfFlags) {
209
216
const ConstantIntRanges &lhs = argRanges[0 ], &rhs = argRanges[1 ];
210
217
211
- ConstArithFn usub = [](const APInt &a,
212
- const APInt &b) -> std::optional<APInt> {
218
+ std::function usub = [= ](const APInt &a,
219
+ const APInt &b) -> std::optional<APInt> {
213
220
bool overflowed = false ;
214
- APInt result = a.usub_ov (b, overflowed);
221
+ APInt result = any (ovfFlags & OverflowFlags::Nuw)
222
+ ? a.usub_sat (b)
223
+ : a.usub_ov (b, overflowed);
215
224
return overflowed ? std::optional<APInt>() : result;
216
225
};
217
- ConstArithFn ssub = [](const APInt &a,
218
- const APInt &b) -> std::optional<APInt> {
226
+ std::function ssub = [= ](const APInt &a,
227
+ const APInt &b) -> std::optional<APInt> {
219
228
bool overflowed = false ;
220
- APInt result = a.ssub_ov (b, overflowed);
229
+ APInt result = any (ovfFlags & OverflowFlags::Nsw)
230
+ ? a.ssub_sat (b)
231
+ : a.ssub_ov (b, overflowed);
221
232
return overflowed ? std::optional<APInt>() : result;
222
233
};
223
234
ConstantIntRanges urange = computeBoundsBy (
@@ -232,19 +243,24 @@ mlir::intrange::inferSub(ArrayRef<ConstantIntRanges> argRanges) {
232
243
// ===----------------------------------------------------------------------===//
233
244
234
245
ConstantIntRanges
235
- mlir::intrange::inferMul (ArrayRef<ConstantIntRanges> argRanges) {
246
+ mlir::intrange::inferMul (ArrayRef<ConstantIntRanges> argRanges,
247
+ OverflowFlags ovfFlags) {
236
248
const ConstantIntRanges &lhs = argRanges[0 ], &rhs = argRanges[1 ];
237
249
238
- ConstArithFn umul = [](const APInt &a,
239
- const APInt &b) -> std::optional<APInt> {
250
+ std::function umul = [= ](const APInt &a,
251
+ const APInt &b) -> std::optional<APInt> {
240
252
bool overflowed = false ;
241
- APInt result = a.umul_ov (b, overflowed);
253
+ APInt result = any (ovfFlags & OverflowFlags::Nuw)
254
+ ? a.umul_sat (b)
255
+ : a.umul_ov (b, overflowed);
242
256
return overflowed ? std::optional<APInt>() : result;
243
257
};
244
- ConstArithFn smul = [](const APInt &a,
245
- const APInt &b) -> std::optional<APInt> {
258
+ std::function smul = [= ](const APInt &a,
259
+ const APInt &b) -> std::optional<APInt> {
246
260
bool overflowed = false ;
247
- APInt result = a.smul_ov (b, overflowed);
261
+ APInt result = any (ovfFlags & OverflowFlags::Nsw)
262
+ ? a.smul_sat (b)
263
+ : a.smul_ov (b, overflowed);
248
264
return overflowed ? std::optional<APInt>() : result;
249
265
};
250
266
@@ -542,32 +558,35 @@ mlir::intrange::inferXor(ArrayRef<ConstantIntRanges> argRanges) {
542
558
// ===----------------------------------------------------------------------===//
543
559
544
560
ConstantIntRanges
545
- mlir::intrange::inferShl (ArrayRef<ConstantIntRanges> argRanges) {
561
+ mlir::intrange::inferShl (ArrayRef<ConstantIntRanges> argRanges,
562
+ OverflowFlags ovfFlags) {
546
563
const ConstantIntRanges &lhs = argRanges[0 ], &rhs = argRanges[1 ];
547
- const APInt &lhsSMin = lhs.smin (), &lhsSMax = lhs.smax (),
548
- &lhsUMax = lhs.umax (), &rhsUMin = rhs.umin (),
549
- &rhsUMax = rhs.umax ();
564
+ const APInt &rhsUMin = rhs.umin (), &rhsUMax = rhs.umax ();
550
565
551
- ConstArithFn shl = [](const APInt &l,
552
- const APInt &r) -> std::optional<APInt> {
553
- return r.uge (r.getBitWidth ()) ? std::optional<APInt>() : l.shl (r);
566
+ // The signed/unsigned overflow behavior of shl by `rhs` matches a mul with
567
+ // 2^rhs.
568
+ std::function ushl = [=](const APInt &l,
569
+ const APInt &r) -> std::optional<APInt> {
570
+ bool overflowed = false ;
571
+ APInt result = any (ovfFlags & OverflowFlags::Nuw)
572
+ ? l.ushl_sat (r)
573
+ : l.ushl_ov (r, overflowed);
574
+ return overflowed ? std::optional<APInt>() : result;
575
+ };
576
+ std::function sshl = [=](const APInt &l,
577
+ const APInt &r) -> std::optional<APInt> {
578
+ bool overflowed = false ;
579
+ APInt result = any (ovfFlags & OverflowFlags::Nsw)
580
+ ? l.sshl_sat (r)
581
+ : l.sshl_ov (r, overflowed);
582
+ return overflowed ? std::optional<APInt>() : result;
554
583
};
555
-
556
- // The minMax inference does not work when there is danger of overflow. In the
557
- // signed case, this leads to the obvious problem that the sign bit might
558
- // change. In the unsigned case, it also leads to problems because the largest
559
- // LHS shifted by the largest RHS does not necessarily result in the largest
560
- // result anymore.
561
- assert (rhsUMax.isNonNegative () && " Unexpected negative shift count" );
562
- if (rhsUMax.uge (lhsSMin.getNumSignBits ()) ||
563
- rhsUMax.uge (lhsSMax.getNumSignBits ()))
564
- return ConstantIntRanges::maxRange (lhsUMax.getBitWidth ());
565
584
566
585
ConstantIntRanges urange =
567
- minMaxBy (shl , {lhs.umin (), lhsUMax }, {rhsUMin, rhsUMax},
586
+ minMaxBy (ushl , {lhs.umin (), lhs. umax () }, {rhsUMin, rhsUMax},
568
587
/* isSigned=*/ false );
569
588
ConstantIntRanges srange =
570
- minMaxBy (shl , {lhsSMin, lhsSMax }, {rhsUMin, rhsUMax},
589
+ minMaxBy (sshl , {lhs. smin (), lhs. smax () }, {rhsUMin, rhsUMax},
571
590
/* isSigned=*/ true );
572
591
return urange.intersection (srange);
573
592
}
0 commit comments