
Commit e413b86

[MLIR][Shape] Combine cstr_eq only if they share shape operands
Differential Revision: https://reviews.llvm.org/D100198
1 parent: 2450369 · commit: e413b86
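With this change, the AssumingAllToCstrEqCanonicalization pattern collapses a `shape.assuming_all` of `shape.cstr_eq` ops only when the equality constraints share shape operands. A minimal sketch of the still-collapsible case, assuming operand names and types like those in the tests below; the operand list of the merged op follows the order in which the pattern appends shapes, so the exact output may differ:

  // %0 and %1 share %b, so the witness can still be folded into one constraint.
  %0 = shape.cstr_eq %a, %b : !shape.shape, tensor<?xindex>
  %1 = shape.cstr_eq %b, %c : tensor<?xindex>, tensor<3xindex>
  %w = shape.assuming_all %0, %1
  // Expected to canonicalize to roughly:
  //   %w = shape.cstr_eq %a, %b, %b, %c
  //        : !shape.shape, tensor<?xindex>, tensor<?xindex>, tensor<3xindex>

If the constraints have no operands in common, as in the new negative test, the assuming_all is left untouched.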

File tree

2 files changed: +25 -5 lines changed

mlir/lib/Dialect/Shape/IR/Shape.cpp

Lines changed: 8 additions & 4 deletions
@@ -389,12 +389,16 @@ struct AssumingAllToCstrEqCanonicalization
   LogicalResult matchAndRewrite(AssumingAllOp op,
                                 PatternRewriter &rewriter) const override {
     SmallVector<Value, 8> shapes;
-    for (Value v : op.inputs()) {
-      auto cstrEqOp = v.getDefiningOp<CstrEqOp>();
+    for (Value w : op.inputs()) {
+      auto cstrEqOp = w.getDefiningOp<CstrEqOp>();
       if (!cstrEqOp)
         return failure();
-      auto range = cstrEqOp.shapes();
-      shapes.append(range.begin(), range.end());
+      bool disjointShapes = llvm::none_of(cstrEqOp.shapes(), [&](Value s) {
+        return llvm::is_contained(shapes, s);
+      });
+      if (!shapes.empty() && !cstrEqOp.shapes().empty() && disjointShapes)
+        return failure();
+      shapes.append(cstrEqOp.shapes().begin(), cstrEqOp.shapes().end());
     }
     rewriter.replaceOpWithNewOp<CstrEqOp>(op, shapes);
     return success();

mlir/test/Dialect/Shape/canonicalize.mlir

Lines changed: 17 additions & 1 deletion
@@ -434,7 +434,7 @@ func @cstr_require_no_fold(%arg0: i1) {
 }
 
 // -----
-// `assuming_all` with all `cstr_eq` can be collapsed.
+// `assuming_all` with all `cstr_eq` and shared operands can be collapsed.
 // CHECK-LABEL: func @assuming_all_to_cstr_eq
 // CHECK-SAME: (%[[A:.*]]: !shape.shape, %[[B:.*]]: tensor<?xindex>, %[[C:.*]]: tensor<3xindex>)
 func @assuming_all_to_cstr_eq(%a : !shape.shape, %b : tensor<?xindex>,
@@ -447,6 +447,22 @@ func @assuming_all_to_cstr_eq(%a : !shape.shape, %b : tensor<?xindex>,
   return %2 : !shape.witness
 }
 
+// -----
+// `assuming_all` with all `cstr_eq` but disjoint operands cannot be collapsed.
+// CHECK-LABEL: func @assuming_all_to_cstr_eq
+// CHECK-SAME: (%[[A:.*]]: !shape.shape, %[[B:.*]]: tensor<?xindex>, %[[C:.*]]: tensor<3xindex>, %[[D:.*]]: tensor<3xindex>)
+func @assuming_all_to_cstr_eq(%a : !shape.shape, %b : tensor<?xindex>,
+                              %c : tensor<3xindex>, %d : tensor<3xindex>) -> !shape.witness {
+  // CHECK: %[[EQ0:.*]] = shape.cstr_eq %[[A]], %[[B]]
+  // CHECK: %[[EQ1:.*]] = shape.cstr_eq %[[C]], %[[D]]
+  // CHECK: %[[RESULT:.*]] = shape.assuming_all %[[EQ0]], %[[EQ1]]
+  // CHECK: return %[[RESULT]]
+  %0 = shape.cstr_eq %a, %b : !shape.shape, tensor<?xindex>
+  %1 = shape.cstr_eq %c, %d : tensor<3xindex>, tensor<3xindex>
+  %2 = shape.assuming_all %0, %1
+  return %2 : !shape.witness
+}
+
 // -----
 // assuming_all with known passing witnesses can be folded
 // CHECK-LABEL: func @f
