Skip to content

Commit 039bdcc

Browse files
committed
[MLIR] Canonicalize sub/add of a constant and another sub/add of a constant
Differential Revision: https://reviews.llvm.org/D101705
1 parent 3ed6a6f commit 039bdcc

File tree

3 files changed

+315
-0
lines changed

3 files changed

+315
-0
lines changed

mlir/include/mlir/Dialect/StandardOps/IR/Ops.td

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -277,6 +277,7 @@ def AddIOp : IntBinaryOp<"addi", [Commutative]> {
277277
```
278278
}];
279279
let hasFolder = 1;
280+
let hasCanonicalizer = 1;
280281
}
281282

282283
//===----------------------------------------------------------------------===//
@@ -1792,6 +1793,7 @@ def SubFOp : FloatBinaryOp<"subf"> {
17921793
def SubIOp : IntBinaryOp<"subi"> {
17931794
let summary = "integer subtraction operation";
17941795
let hasFolder = 1;
1796+
let hasCanonicalizer = 1;
17951797
}
17961798

17971799
//===----------------------------------------------------------------------===//

mlir/lib/Dialect/StandardOps/IR/Ops.cpp

Lines changed: 203 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -283,6 +283,62 @@ static SmallVector<int64_t, 4> extractFromI64ArrayAttr(Attribute attr) {
283283
}));
284284
}
285285

286+
/// Canonicalize a sum of a constant and (constant - something) to simply be
287+
/// a sum of constants minus something. This transformation does similar
288+
/// transformations for additions of a constant with a subtract/add of
289+
/// a constant. This may result in some operations being reordered (but should
290+
/// remain equivalent).
291+
struct AddConstantReorder : public OpRewritePattern<AddIOp> {
292+
using OpRewritePattern<AddIOp>::OpRewritePattern;
293+
294+
LogicalResult matchAndRewrite(AddIOp addop,
295+
PatternRewriter &rewriter) const override {
296+
for (int i = 0; i < 2; i++) {
297+
APInt origConst;
298+
APInt midConst;
299+
if (matchPattern(addop.getOperand(i), m_ConstantInt(&origConst))) {
300+
if (auto midAddOp = addop.getOperand(1 - i).getDefiningOp<AddIOp>()) {
301+
for (int j = 0; j < 2; j++) {
302+
if (matchPattern(midAddOp.getOperand(j),
303+
m_ConstantInt(&midConst))) {
304+
auto nextConstant = rewriter.create<ConstantOp>(
305+
addop.getLoc(), rewriter.getIntegerAttr(
306+
addop.getType(), origConst + midConst));
307+
rewriter.replaceOpWithNewOp<AddIOp>(addop, nextConstant,
308+
midAddOp.getOperand(1 - j));
309+
return success();
310+
}
311+
}
312+
}
313+
if (auto midSubOp = addop.getOperand(1 - i).getDefiningOp<SubIOp>()) {
314+
if (matchPattern(midSubOp.getOperand(0), m_ConstantInt(&midConst))) {
315+
auto nextConstant = rewriter.create<ConstantOp>(
316+
addop.getLoc(),
317+
rewriter.getIntegerAttr(addop.getType(), origConst + midConst));
318+
rewriter.replaceOpWithNewOp<SubIOp>(addop, nextConstant,
319+
midSubOp.getOperand(1));
320+
return success();
321+
}
322+
if (matchPattern(midSubOp.getOperand(1), m_ConstantInt(&midConst))) {
323+
auto nextConstant = rewriter.create<ConstantOp>(
324+
addop.getLoc(),
325+
rewriter.getIntegerAttr(addop.getType(), origConst - midConst));
326+
rewriter.replaceOpWithNewOp<AddIOp>(addop, nextConstant,
327+
midSubOp.getOperand(0));
328+
return success();
329+
}
330+
}
331+
}
332+
}
333+
return failure();
334+
}
335+
};
336+
337+
void AddIOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
338+
MLIRContext *context) {
339+
results.insert<AddConstantReorder>(context);
340+
}
341+
286342
//===----------------------------------------------------------------------===//
287343
// AndOp
288344
//===----------------------------------------------------------------------===//
@@ -1706,6 +1762,153 @@ OpFoldResult SubIOp::fold(ArrayRef<Attribute> operands) {
17061762
[](APInt a, APInt b) { return a - b; });
17071763
}
17081764

1765+
/// Canonicalize a sub of a constant and (constant +/- something) to simply be
1766+
/// a single operation that merges the two constants.
1767+
struct SubConstantReorder : public OpRewritePattern<SubIOp> {
1768+
using OpRewritePattern<SubIOp>::OpRewritePattern;
1769+
1770+
LogicalResult matchAndRewrite(SubIOp subOp,
1771+
PatternRewriter &rewriter) const override {
1772+
APInt origConst;
1773+
APInt midConst;
1774+
1775+
if (matchPattern(subOp.getOperand(0), m_ConstantInt(&origConst))) {
1776+
if (auto midAddOp = subOp.getOperand(1).getDefiningOp<AddIOp>()) {
1777+
// origConst - (midConst + something) == (origConst - midConst) -
1778+
// something
1779+
for (int j = 0; j < 2; j++) {
1780+
if (matchPattern(midAddOp.getOperand(j), m_ConstantInt(&midConst))) {
1781+
auto nextConstant = rewriter.create<ConstantOp>(
1782+
subOp.getLoc(),
1783+
rewriter.getIntegerAttr(subOp.getType(), origConst - midConst));
1784+
rewriter.replaceOpWithNewOp<SubIOp>(subOp, nextConstant,
1785+
midAddOp.getOperand(1 - j));
1786+
return success();
1787+
}
1788+
}
1789+
}
1790+
1791+
if (auto midSubOp = subOp.getOperand(0).getDefiningOp<SubIOp>()) {
1792+
if (matchPattern(midSubOp.getOperand(0), m_ConstantInt(&midConst))) {
1793+
// (midConst - something) - origConst == (midConst - origConst) -
1794+
// something
1795+
auto nextConstant = rewriter.create<ConstantOp>(
1796+
subOp.getLoc(),
1797+
rewriter.getIntegerAttr(subOp.getType(), midConst - origConst));
1798+
rewriter.replaceOpWithNewOp<SubIOp>(subOp, nextConstant,
1799+
midSubOp.getOperand(1));
1800+
return success();
1801+
}
1802+
1803+
if (matchPattern(midSubOp.getOperand(1), m_ConstantInt(&midConst))) {
1804+
// (something - midConst) - origConst == something - (origConst +
1805+
// midConst)
1806+
auto nextConstant = rewriter.create<ConstantOp>(
1807+
subOp.getLoc(),
1808+
rewriter.getIntegerAttr(subOp.getType(), origConst + midConst));
1809+
rewriter.replaceOpWithNewOp<SubIOp>(subOp, midSubOp.getOperand(0),
1810+
nextConstant);
1811+
return success();
1812+
}
1813+
}
1814+
1815+
if (auto midSubOp = subOp.getOperand(1).getDefiningOp<SubIOp>()) {
1816+
if (matchPattern(midSubOp.getOperand(0), m_ConstantInt(&midConst))) {
1817+
// origConst - (midConst - something) == (origConst - midConst) +
1818+
// something
1819+
auto nextConstant = rewriter.create<ConstantOp>(
1820+
subOp.getLoc(),
1821+
rewriter.getIntegerAttr(subOp.getType(), origConst - midConst));
1822+
rewriter.replaceOpWithNewOp<AddIOp>(subOp, nextConstant,
1823+
midSubOp.getOperand(1));
1824+
return success();
1825+
}
1826+
1827+
if (matchPattern(midSubOp.getOperand(1), m_ConstantInt(&midConst))) {
1828+
// origConst - (something - midConst) == (origConst + midConst) -
1829+
// something
1830+
auto nextConstant = rewriter.create<ConstantOp>(
1831+
subOp.getLoc(),
1832+
rewriter.getIntegerAttr(subOp.getType(), origConst + midConst));
1833+
rewriter.replaceOpWithNewOp<SubIOp>(subOp, nextConstant,
1834+
midSubOp.getOperand(0));
1835+
return success();
1836+
}
1837+
}
1838+
}
1839+
1840+
if (matchPattern(subOp.getOperand(1), m_ConstantInt(&origConst))) {
1841+
if (auto midAddOp = subOp.getOperand(0).getDefiningOp<AddIOp>()) {
1842+
// (midConst + something) - origConst == (midConst - origConst) +
1843+
// something
1844+
for (int j = 0; j < 2; j++) {
1845+
if (matchPattern(midAddOp.getOperand(j), m_ConstantInt(&midConst))) {
1846+
auto nextConstant = rewriter.create<ConstantOp>(
1847+
subOp.getLoc(),
1848+
rewriter.getIntegerAttr(subOp.getType(), midConst - origConst));
1849+
rewriter.replaceOpWithNewOp<AddIOp>(subOp, nextConstant,
1850+
midAddOp.getOperand(1 - j));
1851+
return success();
1852+
}
1853+
}
1854+
}
1855+
1856+
if (auto midSubOp = subOp.getOperand(0).getDefiningOp<SubIOp>()) {
1857+
if (matchPattern(midSubOp.getOperand(0), m_ConstantInt(&midConst))) {
1858+
// (midConst - something) - origConst == (midConst - origConst) -
1859+
// something
1860+
auto nextConstant = rewriter.create<ConstantOp>(
1861+
subOp.getLoc(),
1862+
rewriter.getIntegerAttr(subOp.getType(), midConst - origConst));
1863+
rewriter.replaceOpWithNewOp<SubIOp>(subOp, nextConstant,
1864+
midSubOp.getOperand(1));
1865+
return success();
1866+
}
1867+
1868+
if (matchPattern(midSubOp.getOperand(1), m_ConstantInt(&midConst))) {
1869+
// (something - midConst) - origConst == something - (midConst +
1870+
// origConst)
1871+
auto nextConstant = rewriter.create<ConstantOp>(
1872+
subOp.getLoc(),
1873+
rewriter.getIntegerAttr(subOp.getType(), midConst + origConst));
1874+
rewriter.replaceOpWithNewOp<SubIOp>(subOp, midSubOp.getOperand(0),
1875+
nextConstant);
1876+
return success();
1877+
}
1878+
}
1879+
1880+
if (auto midSubOp = subOp.getOperand(1).getDefiningOp<SubIOp>()) {
1881+
if (matchPattern(midSubOp.getOperand(0), m_ConstantInt(&midConst))) {
1882+
// origConst - (midConst - something) == (origConst - midConst) +
1883+
// something
1884+
auto nextConstant = rewriter.create<ConstantOp>(
1885+
subOp.getLoc(),
1886+
rewriter.getIntegerAttr(subOp.getType(), origConst - midConst));
1887+
rewriter.replaceOpWithNewOp<AddIOp>(subOp, nextConstant,
1888+
midSubOp.getOperand(1));
1889+
return success();
1890+
}
1891+
if (matchPattern(midSubOp.getOperand(1), m_ConstantInt(&midConst))) {
1892+
// origConst - (something - midConst) == (origConst - midConst) -
1893+
// something
1894+
auto nextConstant = rewriter.create<ConstantOp>(
1895+
subOp.getLoc(),
1896+
rewriter.getIntegerAttr(subOp.getType(), origConst - midConst));
1897+
rewriter.replaceOpWithNewOp<SubIOp>(subOp, nextConstant,
1898+
midSubOp.getOperand(0));
1899+
return success();
1900+
}
1901+
}
1902+
}
1903+
return failure();
1904+
}
1905+
};
1906+
1907+
void SubIOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
1908+
MLIRContext *context) {
1909+
results.insert<SubConstantReorder>(context);
1910+
}
1911+
17091912
//===----------------------------------------------------------------------===//
17101913
// UIToFPOp
17111914
//===----------------------------------------------------------------------===//

mlir/test/Dialect/Standard/canonicalize.mlir

Lines changed: 110 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -428,3 +428,113 @@ func @truncConstant(%arg0: i8) -> i16 {
428428
%tr = trunci %c-2 : i32 to i16
429429
return %tr : i16
430430
}
431+
432+
// -----
433+
434+
// CHECK-LABEL: @tripleAddAdd
435+
// CHECK: %[[cres:.+]] = constant 59 : index
436+
// CHECK: %[[add:.+]] = addi %arg0, %[[cres]] : index
437+
// CHECK: return %[[add]]
438+
func @tripleAddAdd(%arg0: index) -> index {
439+
%c17 = constant 17 : index
440+
%c42 = constant 42 : index
441+
%add1 = addi %c17, %arg0 : index
442+
%add2 = addi %c42, %add1 : index
443+
return %add2 : index
444+
}
445+
446+
// CHECK-LABEL: @tripleAddSub0
447+
// CHECK: %[[cres:.+]] = constant 59 : index
448+
// CHECK: %[[add:.+]] = subi %[[cres]], %arg0 : index
449+
// CHECK: return %[[add]]
450+
func @tripleAddSub0(%arg0: index) -> index {
451+
%c17 = constant 17 : index
452+
%c42 = constant 42 : index
453+
%add1 = subi %c17, %arg0 : index
454+
%add2 = addi %c42, %add1 : index
455+
return %add2 : index
456+
}
457+
458+
// CHECK-LABEL: @tripleAddSub1
459+
// CHECK: %[[cres:.+]] = constant 25 : index
460+
// CHECK: %[[add:.+]] = addi %arg0, %[[cres]] : index
461+
// CHECK: return %[[add]]
462+
func @tripleAddSub1(%arg0: index) -> index {
463+
%c17 = constant 17 : index
464+
%c42 = constant 42 : index
465+
%add1 = subi %arg0, %c17 : index
466+
%add2 = addi %c42, %add1 : index
467+
return %add2 : index
468+
}
469+
470+
// CHECK-LABEL: @tripleSubAdd0
471+
// CHECK: %[[cres:.+]] = constant 25 : index
472+
// CHECK: %[[add:.+]] = subi %[[cres]], %arg0 : index
473+
// CHECK: return %[[add]]
474+
func @tripleSubAdd0(%arg0: index) -> index {
475+
%c17 = constant 17 : index
476+
%c42 = constant 42 : index
477+
%add1 = addi %c17, %arg0 : index
478+
%add2 = subi %c42, %add1 : index
479+
return %add2 : index
480+
}
481+
482+
// CHECK-LABEL: @tripleSubAdd1
483+
// CHECK: %[[cres:.+]] = constant -25 : index
484+
// CHECK: %[[add:.+]] = addi %arg0, %[[cres]] : index
485+
// CHECK: return %[[add]]
486+
func @tripleSubAdd1(%arg0: index) -> index {
487+
%c17 = constant 17 : index
488+
%c42 = constant 42 : index
489+
%add1 = addi %c17, %arg0 : index
490+
%add2 = subi %add1, %c42 : index
491+
return %add2 : index
492+
}
493+
494+
// CHECK-LABEL: @tripleSubSub0
495+
// CHECK: %[[cres:.+]] = constant 25 : index
496+
// CHECK: %[[add:.+]] = addi %arg0, %[[cres]] : index
497+
// CHECK: return %[[add]]
498+
func @tripleSubSub0(%arg0: index) -> index {
499+
%c17 = constant 17 : index
500+
%c42 = constant 42 : index
501+
%add1 = subi %c17, %arg0 : index
502+
%add2 = subi %c42, %add1 : index
503+
return %add2 : index
504+
}
505+
506+
// CHECK-LABEL: @tripleSubSub1
507+
// CHECK: %[[cres:.+]] = constant -25 : index
508+
// CHECK: %[[add:.+]] = subi %[[cres]], %arg0 : index
509+
// CHECK: return %[[add]]
510+
func @tripleSubSub1(%arg0: index) -> index {
511+
%c17 = constant 17 : index
512+
%c42 = constant 42 : index
513+
%add1 = subi %c17, %arg0 : index
514+
%add2 = subi %add1, %c42 : index
515+
return %add2 : index
516+
}
517+
518+
// CHECK-LABEL: @tripleSubSub2
519+
// CHECK: %[[cres:.+]] = constant 59 : index
520+
// CHECK: %[[add:.+]] = subi %[[cres]], %arg0 : index
521+
// CHECK: return %[[add]]
522+
func @tripleSubSub2(%arg0: index) -> index {
523+
%c17 = constant 17 : index
524+
%c42 = constant 42 : index
525+
%add1 = subi %arg0, %c17 : index
526+
%add2 = subi %c42, %add1 : index
527+
return %add2 : index
528+
}
529+
530+
// CHECK-LABEL: @tripleSubSub3
531+
// CHECK: %[[cres:.+]] = constant 59 : index
532+
// CHECK: %[[add:.+]] = subi %arg0, %[[cres]] : index
533+
// CHECK: return %[[add]]
534+
func @tripleSubSub3(%arg0: index) -> index {
535+
%c17 = constant 17 : index
536+
%c42 = constant 42 : index
537+
%add1 = subi %arg0, %c17 : index
538+
%add2 = subi %add1, %c42 : index
539+
return %add2 : index
540+
}

0 commit comments

Comments
 (0)