[mlir][Inliner] Properly handle callgraph node deletion

We previously weren't properly updating the SCC iterator when nodes were removed, leading to asan failures in certain situations. This commit adds a CallGraphSCC class and defers operation deletion until inlining has finished.

Differential Revision: https://reviews.llvm.org/D81984
This commit is contained in:
River Riddle 2020-06-17 13:13:48 -07:00
parent 55d53d4f54
commit f4ef77cbb4
4 changed files with 115 additions and 40 deletions

View File

@ -242,19 +242,47 @@ void CGUseList::decrementDiscardableUses(CGUser &uses) {
// CallGraph traversal
//===----------------------------------------------------------------------===//
namespace {
/// This class represents a specific callgraph SCC.
class CallGraphSCC {
public:
CallGraphSCC(llvm::scc_iterator<const CallGraph *> &parentIterator)
: parentIterator(parentIterator) {}
/// Return a range over the nodes within this SCC.
std::vector<CallGraphNode *>::iterator begin() { return nodes.begin(); }
std::vector<CallGraphNode *>::iterator end() { return nodes.end(); }
/// Reset the nodes of this SCC with those provided.
void reset(const std::vector<CallGraphNode *> &newNodes) { nodes = newNodes; }
/// Remove the given node from this SCC.
void remove(CallGraphNode *node) {
auto it = llvm::find(nodes, node);
if (it != nodes.end()) {
nodes.erase(it);
parentIterator.ReplaceNode(node, nullptr);
}
}
private:
std::vector<CallGraphNode *> nodes;
llvm::scc_iterator<const CallGraph *> &parentIterator;
};
} // end anonymous namespace
/// Run a given transformation over the SCCs of the callgraph in a bottom up
/// traversal.
static void runTransformOnCGSCCs(
const CallGraph &cg,
function_ref<void(MutableArrayRef<CallGraphNode *>)> sccTransformer) {
std::vector<CallGraphNode *> currentSCCVec;
auto cgi = llvm::scc_begin(&cg);
static void
runTransformOnCGSCCs(const CallGraph &cg,
function_ref<void(CallGraphSCC &)> sccTransformer) {
llvm::scc_iterator<const CallGraph *> cgi = llvm::scc_begin(&cg);
CallGraphSCC currentSCC(cgi);
while (!cgi.isAtEnd()) {
// Copy the current SCC and increment so that the transformer can modify the
// SCC without invalidating our iterator.
currentSCCVec = *cgi;
currentSCC.reset(*cgi);
++cgi;
sccTransformer(currentSCCVec);
sccTransformer(currentSCC);
}
}
@ -343,6 +371,19 @@ struct Inliner : public InlinerInterface {
/*traverseNestedCGNodes=*/true);
}
/// Mark the given callgraph node for deletion.
void markForDeletion(CallGraphNode *node) { deadNodes.insert(node); }
/// This method properly disposes of callables that became dead during
/// inlining. This should not be called while iterating over the SCCs.
void eraseDeadCallables() {
for (CallGraphNode *node : deadNodes)
node->getCallableRegion()->getParentOp()->erase();
}
/// The set of callables known to be dead.
SmallPtrSet<CallGraphNode *, 8> deadNodes;
/// The current set of call instructions to consider for inlining.
SmallVector<ResolvedCall, 8> calls;
@ -368,27 +409,16 @@ static bool shouldInline(ResolvedCall &resolvedCall) {
return true;
}
/// Delete the given node and remove it from the current scc and the callgraph.
static void deleteNode(CallGraphNode *node, CGUseList &useList, CallGraph &cg,
MutableArrayRef<CallGraphNode *> currentSCC) {
// Erase the parent operation and remove it from the various lists.
node->getCallableRegion()->getParentOp()->erase();
cg.eraseNode(node);
// Replace this node in the currentSCC with the external node.
auto it = llvm::find(currentSCC, node);
if (it != currentSCC.end())
*it = cg.getExternalNode();
}
/// Attempt to inline calls within the given scc. This function returns
/// success if any calls were inlined, failure otherwise.
static LogicalResult
inlineCallsInSCC(Inliner &inliner, CGUseList &useList,
MutableArrayRef<CallGraphNode *> currentSCC) {
static LogicalResult inlineCallsInSCC(Inliner &inliner, CGUseList &useList,
CallGraphSCC &currentSCC) {
CallGraph &cg = inliner.cg;
auto &calls = inliner.calls;
// A set of dead nodes to remove after inlining.
SmallVector<CallGraphNode *, 1> deadNodes;
// Collect all of the direct calls within the nodes of the current SCC. We
// don't traverse nested callgraph nodes, because they are handled separately
// likely within a different SCC.
@ -396,18 +426,13 @@ inlineCallsInSCC(Inliner &inliner, CGUseList &useList,
if (node->isExternal())
continue;
// If this node is dead, just delete it now.
// Don't collect calls if the node is already dead.
if (useList.isDead(node))
deleteNode(node, useList, cg, currentSCC);
deadNodes.push_back(node);
else
collectCallOps(*node->getCallableRegion(), node, cg, calls,
/*traverseNestedCGNodes=*/false);
}
if (calls.empty())
return failure();
// A set of dead nodes to remove after inlining.
SmallVector<CallGraphNode *, 1> deadNodes;
// Try to inline each of the call operations. Don't cache the end iterator
// here as more calls may be added during inlining.
@ -453,8 +478,10 @@ inlineCallsInSCC(Inliner &inliner, CGUseList &useList,
}
}
for (CallGraphNode *node : deadNodes)
deleteNode(node, useList, cg, currentSCC);
for (CallGraphNode *node : deadNodes) {
currentSCC.remove(node);
inliner.markForDeletion(node);
}
calls.clear();
return success(inlinedAnyCalls);
}
@ -462,8 +489,7 @@ inlineCallsInSCC(Inliner &inliner, CGUseList &useList,
/// Canonicalize the nodes within the given SCC with the given set of
/// canonicalization patterns.
static void canonicalizeSCC(CallGraph &cg, CGUseList &useList,
MutableArrayRef<CallGraphNode *> currentSCC,
MLIRContext *context,
CallGraphSCC &currentSCC, MLIRContext *context,
const OwningRewritePatternList &canonPatterns) {
// Collect the sets of nodes to canonicalize.
SmallVector<CallGraphNode *, 4> nodesToCanonicalize;
@ -533,8 +559,7 @@ struct InlinerPass : public InlinerBase<InlinerPass> {
/// Attempt to inline calls within the given scc, and run canonicalizations
/// with the given patterns, until a fixed point is reached. This allows for
/// the inlining of newly devirtualized calls.
void inlineSCC(Inliner &inliner, CGUseList &useList,
MutableArrayRef<CallGraphNode *> currentSCC,
void inlineSCC(Inliner &inliner, CGUseList &useList, CallGraphSCC &currentSCC,
MLIRContext *context,
const OwningRewritePatternList &canonPatterns);
};
@ -562,14 +587,16 @@ void InlinerPass::runOnOperation() {
// Run the inline transform in post-order over the SCCs in the callgraph.
Inliner inliner(context, cg);
CGUseList useList(getOperation(), cg);
runTransformOnCGSCCs(cg, [&](MutableArrayRef<CallGraphNode *> scc) {
runTransformOnCGSCCs(cg, [&](CallGraphSCC &scc) {
inlineSCC(inliner, useList, scc, context, canonPatterns);
});
// After inlining, make sure to erase any callables proven to be dead.
inliner.eraseDeadCallables();
}
void InlinerPass::inlineSCC(Inliner &inliner, CGUseList &useList,
MutableArrayRef<CallGraphNode *> currentSCC,
MLIRContext *context,
CallGraphSCC &currentSCC, MLIRContext *context,
const OwningRewritePatternList &canonPatterns) {
// If we successfully inlined any calls, run some simplifications on the
// nodes of the scc. Continue attempting to inline until we reach a fixed

View File

@ -1,4 +1,4 @@
// RUN: mlir-opt -allow-unregistered-dialect %s -inline | FileCheck %s
// RUN: mlir-opt -allow-unregistered-dialect %s -inline -split-input-file | FileCheck %s
// This file tests the callgraph dead code elimination performed by the inliner.
@ -51,3 +51,23 @@ func @live_function_d() attributes {sym_visibility = "private"} {
}
"live.user"() {use = @live_function_d} : () -> ()
// -----
// This test checks that the inliner can properly handle the deletion of
// functions in different SCCs that are referenced by calls materialized during
// canonicalization.
// CHECK: func @live_function_e
func @live_function_e() {
call @dead_function_e() : () -> ()
return
}
// CHECK-NOT: func @dead_function_e
func @dead_function_e() -> () attributes {sym_visibility = "private"} {
"test.fold_to_call_op"() {callee=@dead_function_f} : () -> ()
return
}
// CHECK-NOT: func @dead_function_f
func @dead_function_f() attributes {sym_visibility = "private"} {
return
}

View File

@ -173,6 +173,28 @@ TestBranchOp::getMutableSuccessorOperands(unsigned index) {
return targetOperandsMutable();
}
//===----------------------------------------------------------------------===//
// TestFoldToCallOp
//===----------------------------------------------------------------------===//
namespace {
struct FoldToCallOpPattern : public OpRewritePattern<FoldToCallOp> {
using OpRewritePattern<FoldToCallOp>::OpRewritePattern;
LogicalResult matchAndRewrite(FoldToCallOp op,
PatternRewriter &rewriter) const override {
rewriter.replaceOpWithNewOp<CallOp>(op, ArrayRef<Type>(), op.calleeAttr(),
ValueRange());
return success();
}
};
} // end anonymous namespace
void FoldToCallOp::getCanonicalizationPatterns(
OwningRewritePatternList &results, MLIRContext *context) {
results.insert<FoldToCallOpPattern>(context);
}
//===----------------------------------------------------------------------===//
// Test IsolatedRegionOp - parse passthrough region arguments.
//===----------------------------------------------------------------------===//

View File

@ -321,6 +321,12 @@ def FunctionalRegionOp : TEST_Op<"functional_region_op",
}];
}
def FoldToCallOp : TEST_Op<"fold_to_call_op"> {
let arguments = (ins FlatSymbolRefAttr:$callee);
let hasCanonicalizer = 1;
}
//===----------------------------------------------------------------------===//
// Test Traits
//===----------------------------------------------------------------------===//