@@ -283,6 +283,62 @@ static SmallVector<int64_t, 4> extractFromI64ArrayAttr(Attribute attr) {
283
283
}));
284
284
}
285
285
286
+ // / Canonicalize a sum of a constant and (constant - something) to simply be
287
+ // / a sum of constants minus something. This transformation does similar
288
+ // / transformations for additions of a constant with a subtract/add of
289
+ // / a constant. This may result in some operations being reordered (but should
290
+ // / remain equivalent).
291
+ struct AddConstantReorder : public OpRewritePattern <AddIOp> {
292
+ using OpRewritePattern<AddIOp>::OpRewritePattern;
293
+
294
+ LogicalResult matchAndRewrite (AddIOp addop,
295
+ PatternRewriter &rewriter) const override {
296
+ for (int i = 0 ; i < 2 ; i++) {
297
+ APInt origConst;
298
+ APInt midConst;
299
+ if (matchPattern (addop.getOperand (i), m_ConstantInt (&origConst))) {
300
+ if (auto midAddOp = addop.getOperand (1 - i).getDefiningOp <AddIOp>()) {
301
+ for (int j = 0 ; j < 2 ; j++) {
302
+ if (matchPattern (midAddOp.getOperand (j),
303
+ m_ConstantInt (&midConst))) {
304
+ auto nextConstant = rewriter.create <ConstantOp>(
305
+ addop.getLoc (), rewriter.getIntegerAttr (
306
+ addop.getType (), origConst + midConst));
307
+ rewriter.replaceOpWithNewOp <AddIOp>(addop, nextConstant,
308
+ midAddOp.getOperand (1 - j));
309
+ return success ();
310
+ }
311
+ }
312
+ }
313
+ if (auto midSubOp = addop.getOperand (1 - i).getDefiningOp <SubIOp>()) {
314
+ if (matchPattern (midSubOp.getOperand (0 ), m_ConstantInt (&midConst))) {
315
+ auto nextConstant = rewriter.create <ConstantOp>(
316
+ addop.getLoc (),
317
+ rewriter.getIntegerAttr (addop.getType (), origConst + midConst));
318
+ rewriter.replaceOpWithNewOp <SubIOp>(addop, nextConstant,
319
+ midSubOp.getOperand (1 ));
320
+ return success ();
321
+ }
322
+ if (matchPattern (midSubOp.getOperand (1 ), m_ConstantInt (&midConst))) {
323
+ auto nextConstant = rewriter.create <ConstantOp>(
324
+ addop.getLoc (),
325
+ rewriter.getIntegerAttr (addop.getType (), origConst - midConst));
326
+ rewriter.replaceOpWithNewOp <AddIOp>(addop, nextConstant,
327
+ midSubOp.getOperand (0 ));
328
+ return success ();
329
+ }
330
+ }
331
+ }
332
+ }
333
+ return failure ();
334
+ }
335
+ };
336
+
337
+ void AddIOp::getCanonicalizationPatterns (OwningRewritePatternList &results,
338
+ MLIRContext *context) {
339
+ results.insert <AddConstantReorder>(context);
340
+ }
341
+
286
342
// ===----------------------------------------------------------------------===//
287
343
// AndOp
288
344
// ===----------------------------------------------------------------------===//
@@ -1706,6 +1762,153 @@ OpFoldResult SubIOp::fold(ArrayRef<Attribute> operands) {
1706
1762
[](APInt a, APInt b) { return a - b; });
1707
1763
}
1708
1764
1765
+ // / Canonicalize a sub of a constant and (constant +/- something) to simply be
1766
+ // / a single operation that merges the two constants.
1767
+ struct SubConstantReorder : public OpRewritePattern <SubIOp> {
1768
+ using OpRewritePattern<SubIOp>::OpRewritePattern;
1769
+
1770
+ LogicalResult matchAndRewrite (SubIOp subOp,
1771
+ PatternRewriter &rewriter) const override {
1772
+ APInt origConst;
1773
+ APInt midConst;
1774
+
1775
+ if (matchPattern (subOp.getOperand (0 ), m_ConstantInt (&origConst))) {
1776
+ if (auto midAddOp = subOp.getOperand (1 ).getDefiningOp <AddIOp>()) {
1777
+ // origConst - (midConst + something) == (origConst - midConst) -
1778
+ // something
1779
+ for (int j = 0 ; j < 2 ; j++) {
1780
+ if (matchPattern (midAddOp.getOperand (j), m_ConstantInt (&midConst))) {
1781
+ auto nextConstant = rewriter.create <ConstantOp>(
1782
+ subOp.getLoc (),
1783
+ rewriter.getIntegerAttr (subOp.getType (), origConst - midConst));
1784
+ rewriter.replaceOpWithNewOp <SubIOp>(subOp, nextConstant,
1785
+ midAddOp.getOperand (1 - j));
1786
+ return success ();
1787
+ }
1788
+ }
1789
+ }
1790
+
1791
+ if (auto midSubOp = subOp.getOperand (0 ).getDefiningOp <SubIOp>()) {
1792
+ if (matchPattern (midSubOp.getOperand (0 ), m_ConstantInt (&midConst))) {
1793
+ // (midConst - something) - origConst == (midConst - origConst) -
1794
+ // something
1795
+ auto nextConstant = rewriter.create <ConstantOp>(
1796
+ subOp.getLoc (),
1797
+ rewriter.getIntegerAttr (subOp.getType (), midConst - origConst));
1798
+ rewriter.replaceOpWithNewOp <SubIOp>(subOp, nextConstant,
1799
+ midSubOp.getOperand (1 ));
1800
+ return success ();
1801
+ }
1802
+
1803
+ if (matchPattern (midSubOp.getOperand (1 ), m_ConstantInt (&midConst))) {
1804
+ // (something - midConst) - origConst == something - (origConst +
1805
+ // midConst)
1806
+ auto nextConstant = rewriter.create <ConstantOp>(
1807
+ subOp.getLoc (),
1808
+ rewriter.getIntegerAttr (subOp.getType (), origConst + midConst));
1809
+ rewriter.replaceOpWithNewOp <SubIOp>(subOp, midSubOp.getOperand (0 ),
1810
+ nextConstant);
1811
+ return success ();
1812
+ }
1813
+ }
1814
+
1815
+ if (auto midSubOp = subOp.getOperand (1 ).getDefiningOp <SubIOp>()) {
1816
+ if (matchPattern (midSubOp.getOperand (0 ), m_ConstantInt (&midConst))) {
1817
+ // origConst - (midConst - something) == (origConst - midConst) +
1818
+ // something
1819
+ auto nextConstant = rewriter.create <ConstantOp>(
1820
+ subOp.getLoc (),
1821
+ rewriter.getIntegerAttr (subOp.getType (), origConst - midConst));
1822
+ rewriter.replaceOpWithNewOp <AddIOp>(subOp, nextConstant,
1823
+ midSubOp.getOperand (1 ));
1824
+ return success ();
1825
+ }
1826
+
1827
+ if (matchPattern (midSubOp.getOperand (1 ), m_ConstantInt (&midConst))) {
1828
+ // origConst - (something - midConst) == (origConst + midConst) -
1829
+ // something
1830
+ auto nextConstant = rewriter.create <ConstantOp>(
1831
+ subOp.getLoc (),
1832
+ rewriter.getIntegerAttr (subOp.getType (), origConst + midConst));
1833
+ rewriter.replaceOpWithNewOp <SubIOp>(subOp, nextConstant,
1834
+ midSubOp.getOperand (0 ));
1835
+ return success ();
1836
+ }
1837
+ }
1838
+ }
1839
+
1840
+ if (matchPattern (subOp.getOperand (1 ), m_ConstantInt (&origConst))) {
1841
+ if (auto midAddOp = subOp.getOperand (0 ).getDefiningOp <AddIOp>()) {
1842
+ // (midConst + something) - origConst == (midConst - origConst) +
1843
+ // something
1844
+ for (int j = 0 ; j < 2 ; j++) {
1845
+ if (matchPattern (midAddOp.getOperand (j), m_ConstantInt (&midConst))) {
1846
+ auto nextConstant = rewriter.create <ConstantOp>(
1847
+ subOp.getLoc (),
1848
+ rewriter.getIntegerAttr (subOp.getType (), midConst - origConst));
1849
+ rewriter.replaceOpWithNewOp <AddIOp>(subOp, nextConstant,
1850
+ midAddOp.getOperand (1 - j));
1851
+ return success ();
1852
+ }
1853
+ }
1854
+ }
1855
+
1856
+ if (auto midSubOp = subOp.getOperand (0 ).getDefiningOp <SubIOp>()) {
1857
+ if (matchPattern (midSubOp.getOperand (0 ), m_ConstantInt (&midConst))) {
1858
+ // (midConst - something) - origConst == (midConst - origConst) -
1859
+ // something
1860
+ auto nextConstant = rewriter.create <ConstantOp>(
1861
+ subOp.getLoc (),
1862
+ rewriter.getIntegerAttr (subOp.getType (), midConst - origConst));
1863
+ rewriter.replaceOpWithNewOp <SubIOp>(subOp, nextConstant,
1864
+ midSubOp.getOperand (1 ));
1865
+ return success ();
1866
+ }
1867
+
1868
+ if (matchPattern (midSubOp.getOperand (1 ), m_ConstantInt (&midConst))) {
1869
+ // (something - midConst) - origConst == something - (midConst +
1870
+ // origConst)
1871
+ auto nextConstant = rewriter.create <ConstantOp>(
1872
+ subOp.getLoc (),
1873
+ rewriter.getIntegerAttr (subOp.getType (), midConst + origConst));
1874
+ rewriter.replaceOpWithNewOp <SubIOp>(subOp, midSubOp.getOperand (0 ),
1875
+ nextConstant);
1876
+ return success ();
1877
+ }
1878
+ }
1879
+
1880
+ if (auto midSubOp = subOp.getOperand (1 ).getDefiningOp <SubIOp>()) {
1881
+ if (matchPattern (midSubOp.getOperand (0 ), m_ConstantInt (&midConst))) {
1882
+ // origConst - (midConst - something) == (origConst - midConst) +
1883
+ // something
1884
+ auto nextConstant = rewriter.create <ConstantOp>(
1885
+ subOp.getLoc (),
1886
+ rewriter.getIntegerAttr (subOp.getType (), origConst - midConst));
1887
+ rewriter.replaceOpWithNewOp <AddIOp>(subOp, nextConstant,
1888
+ midSubOp.getOperand (1 ));
1889
+ return success ();
1890
+ }
1891
+ if (matchPattern (midSubOp.getOperand (1 ), m_ConstantInt (&midConst))) {
1892
+ // origConst - (something - midConst) == (origConst - midConst) -
1893
+ // something
1894
+ auto nextConstant = rewriter.create <ConstantOp>(
1895
+ subOp.getLoc (),
1896
+ rewriter.getIntegerAttr (subOp.getType (), origConst - midConst));
1897
+ rewriter.replaceOpWithNewOp <SubIOp>(subOp, nextConstant,
1898
+ midSubOp.getOperand (0 ));
1899
+ return success ();
1900
+ }
1901
+ }
1902
+ }
1903
+ return failure ();
1904
+ }
1905
+ };
1906
+
1907
+ void SubIOp::getCanonicalizationPatterns (OwningRewritePatternList &results,
1908
+ MLIRContext *context) {
1909
+ results.insert <SubConstantReorder>(context);
1910
+ }
1911
+
1709
1912
// ===----------------------------------------------------------------------===//
1710
1913
// UIToFPOp
1711
1914
// ===----------------------------------------------------------------------===//
0 commit comments