Skip to content

Commit 62fe67f

Browse files
author
Jeff Niu
committed
[mlir][DCA] Fix visiting call ops when run at function scopes
When dead-code analysis is run at the scope of a function, call ops to other functions at the same level were being marked as unreachable, since the analysis optimistically assumes the call op to have no known predecessors and that all predecessors are known, but the callee would never get visited. This patch fixes the bug by checking if a referenced function is above the top-level op of the analysis, and is thus considered an external callable. Fixes #56830 Reviewed By: zero9178 Differential Revision: https://reviews.llvm.org/D130829
1 parent 68b0aaa commit 62fe67f

File tree

3 files changed

+48
-6
lines changed

3 files changed

+48
-6
lines changed

mlir/include/mlir/Analysis/DataFlow/DeadCodeAnalysis.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -234,6 +234,11 @@ class DeadCodeAnalysis : public DataFlowAnalysis {
234234
/// any of the operand lattices are uninitialized.
235235
Optional<SmallVector<Attribute>> getOperandValues(Operation *op);
236236

237+
/// The top-level operation the analysis is running on. This is used to detect
238+
/// if a callable is outside the scope of the analysis and thus must be
239+
/// considered an external callable.
240+
Operation *analysisScope;
241+
237242
/// A symbol table used for O(1) symbol lookups during simplification.
238243
SymbolTableCollection symbolTable;
239244
};

mlir/lib/Analysis/DataFlow/DeadCodeAnalysis.cpp

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,7 @@ LogicalResult DeadCodeAnalysis::initialize(Operation *top) {
118118
}
119119

120120
void DeadCodeAnalysis::initializeSymbolCallables(Operation *top) {
121+
analysisScope = top;
121122
auto walkFn = [&](Operation *symTable, bool allUsesVisible) {
122123
Region &symbolTableRegion = symTable->getRegion(0);
123124
Block *symbolTableBlock = &symbolTableRegion.front();
@@ -278,14 +279,14 @@ LogicalResult DeadCodeAnalysis::visit(ProgramPoint point) {
278279
}
279280

280281
void DeadCodeAnalysis::visitCallOperation(CallOpInterface call) {
281-
Operation *callableOp = nullptr;
282-
if (Value callableValue = call.getCallableForCallee().dyn_cast<Value>())
283-
callableOp = callableValue.getDefiningOp();
284-
else
285-
callableOp = call.resolveCallable(&symbolTable);
282+
Operation *callableOp = call.resolveCallable(&symbolTable);
286283

287284
// A call to a externally-defined callable has unknown predecessors.
288-
const auto isExternalCallable = [](Operation *op) {
285+
const auto isExternalCallable = [this](Operation *op) {
286+
// A callable outside the analysis scope is an external callable.
287+
if (!analysisScope->isAncestor(op))
288+
return true;
289+
// Otherwise, check if the callable region is defined.
289290
if (auto callable = dyn_cast<CallableOpInterface>(op))
290291
return !callable.getCallableRegion();
291292
return false;

mlir/test/Transforms/sccp-callgraph.mlir

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
// RUN: mlir-opt -allow-unregistered-dialect %s -sccp -split-input-file | FileCheck %s
22
// RUN: mlir-opt -allow-unregistered-dialect %s -pass-pipeline="builtin.module(sccp)" -split-input-file | FileCheck %s --check-prefix=NESTED
3+
// RUN: mlir-opt -allow-unregistered-dialect %s -pass-pipeline="func.func(sccp)" -split-input-file | FileCheck %s --check-prefix=FUNC
34

45
/// Check that a constant is properly propagated through the arguments and
56
/// results of a private function.
@@ -270,3 +271,38 @@ func.func private @unreferenced_private_function() -> i32 {
270271
%result = arith.select %true, %cst0, %cst1 : i32
271272
return %result : i32
272273
}
274+
275+
// -----
276+
277+
/// Check that callables outside the analysis scope are marked as external.
278+
279+
func.func private @foo() -> index {
280+
%0 = arith.constant 10 : index
281+
return %0 : index
282+
}
283+
284+
// CHECK-LABEL: func @bar
285+
// FUNC-LABEL: func @bar
286+
func.func @bar(%arg0: index) -> index {
287+
// CHECK: %[[C10:.*]] = arith.constant 10
288+
%c0 = arith.constant 0 : index
289+
%1 = arith.constant 420 : index
290+
%7 = arith.cmpi eq, %arg0, %c0 : index
291+
cf.cond_br %7, ^bb1(%1 : index), ^bb2
292+
293+
// CHECK: ^bb1(%[[ARG:.*]]: index):
294+
// FUNC: ^bb1(%[[ARG:.*]]: index):
295+
^bb1(%8: index): // 2 preds: ^bb0, ^bb4
296+
// CHECK-NEXT: return %[[ARG]]
297+
// FUNC-NEXT: return %[[ARG]]
298+
return %8 : index
299+
300+
// CHECK: ^bb2
301+
// FUNC: ^bb2
302+
^bb2:
303+
// FUNC-NEXT: %[[FOO:.*]] = call @foo
304+
%13 = call @foo() : () -> index
305+
// CHECK: cf.br ^bb1(%[[C10]]
306+
// FUNC: cf.br ^bb1(%[[FOO]]
307+
cf.br ^bb1(%13 : index)
308+
}

0 commit comments

Comments
 (0)