Skip to content

Commit 5f9cd09

Browse files
committed
[MLIR] Fix affine LICM pass for unknown region-holding ops
Fix the affine LICM pass for unknown region-holding ops. The logic was completely ignoring the regions of unknown ops, leading to the generation of invalid IR on hoisting. Handle the affine.parallel op as one of the supported region-holding ops. Reviewed By: ftynse. Differential Revision: https://reviews.llvm.org/D140738
1 parent aa7aac9 commit 5f9cd09

File tree

2 files changed

+65
-0
lines changed

2 files changed

+65
-0
lines changed

mlir/lib/Dialect/Affine/Transforms/AffineLoopInvariantCodeMotion.cpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,8 +84,17 @@ bool isOpLoopInvariant(Operation &op, Value indVar, ValueRange iterArgs,
8484
if (!areAllOpsInTheBlockListInvariant(forOp.getLoopBody(), indVar, iterArgs,
8585
opsWithUsers, opsToHoist))
8686
return false;
87+
} else if (auto parOp = dyn_cast<AffineParallelOp>(op)) {
88+
if (!areAllOpsInTheBlockListInvariant(parOp.getLoopBody(), indVar, iterArgs,
89+
opsWithUsers, opsToHoist))
90+
return false;
8791
} else if (isa<AffineDmaStartOp, AffineDmaWaitOp>(op)) {
8892
// TODO: Support DMA ops.
93+
// FIXME: This should be fixed to not special-case these affine DMA ops but
94+
// instead rely on side effects.
95+
return false;
96+
} else if (op.getNumRegions() > 0) {
97+
// We can't handle region-holding ops we don't know about.
8998
return false;
9099
} else if (!matchPattern(&op, m_Constant())) {
91100
// Register op in the set of ops that have users.

mlir/test/Dialect/Affine/affine-loop-invariant-code-motion.mlir

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -752,3 +752,59 @@ func.func @use_of_iter_args_not_invariant(%m : memref<10xindex>) {
752752
// CHECK-NEXT: affine.for
753753
// CHECK-NEXT: arith.addi
754754
// CHECK-NEXT: affine.yield
755+
756+
#map = affine_map<(d0) -> (64, d0 * -64 + 1020)>
757+
// CHECK-LABEL: func.func @affine_parallel
758+
func.func @affine_parallel(%memref_8: memref<4090x2040xf32>, %x: index) {
759+
%cst = arith.constant 0.000000e+00 : f32
760+
affine.parallel (%arg3) = (0) to (32) {
761+
affine.for %arg4 = 0 to 16 {
762+
affine.parallel (%arg5, %arg6) = (0, 0) to (min(128, 122), min(64, %arg3 * -64 + 2040)) {
763+
affine.for %arg7 = 0 to min #map(%arg4) {
764+
affine.store %cst, %memref_8[%arg5 + 3968, %arg6 + %arg3 * 64] : memref<4090x2040xf32>
765+
}
766+
}
767+
}
768+
}
769+
// CHECK: affine.parallel
770+
// CHECK-NEXT: affine.for
771+
// CHECK-NEXT: affine.parallel
772+
// CHECK-NEXT: affine.store
773+
// CHECK-NEXT: affine.for
774+
775+
%c0 = arith.constant 0 : index
776+
%c1 = arith.constant 1 : index
777+
%c32 = arith.constant 32 : index
778+
scf.parallel (%arg3) = (%c0) to (%c32) step (%c1) {
779+
affine.for %arg4 = 0 to 16 {
780+
affine.parallel (%arg5, %arg6) = (0, 0) to (min(128, 122), min(64, %x * -64 + 2040)) {
781+
affine.for %arg7 = 0 to min #map(%arg4) {
782+
affine.store %cst, %memref_8[%arg5 + 3968, %arg6] : memref<4090x2040xf32>
783+
}
784+
}
785+
}
786+
}
787+
// CHECK: scf.parallel
788+
// CHECK-NEXT: affine.for
789+
// CHECK-NEXT: affine.parallel
790+
// CHECK-NEXT: affine.store
791+
// CHECK-NEXT: affine.for
792+
793+
affine.for %arg3 = 0 to 32 {
794+
affine.for %arg4 = 0 to 16 {
795+
affine.parallel (%arg5, %arg6) = (0, 0) to (min(128, 122), min(64, %arg3 * -64 + 2040)) {
796+
// Unknown region-holding op for this pass.
797+
scf.for %arg7 = %c0 to %x step %c1 {
798+
affine.store %cst, %memref_8[%arg5 + 3968, %arg6 + %arg3 * 64] : memref<4090x2040xf32>
799+
}
800+
}
801+
}
802+
}
803+
// CHECK: affine.for
804+
// CHECK-NEXT: affine.for
805+
// CHECK-NEXT: affine.parallel
806+
// CHECK-NEXT: scf.for
807+
// CHECK-NEXT: affine.store
808+
809+
return
810+
}

0 commit comments

Comments
 (0)