[mlir][Inliner] Properly handle callgraph node deletion

We previously weren't properly updating the SCC iterator when nodes were removed, leading to asan failures in certain situations. This commit adds a CallGraphSCC class and defers operation deletion until inlining has finished.

Differential Revision: https://reviews.llvm.org/D81984
This commit is contained in:
River Riddle 2020-06-17 13:13:48 -07:00
parent 55d53d4f54
commit f4ef77cbb4
4 changed files with 115 additions and 40 deletions

View File

@ -242,19 +242,47 @@ void CGUseList::decrementDiscardableUses(CGUser &uses) {
// CallGraph traversal // CallGraph traversal
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
namespace {
/// This class represents a specific callgraph SCC.
class CallGraphSCC {
public:
CallGraphSCC(llvm::scc_iterator<const CallGraph *> &parentIterator)
: parentIterator(parentIterator) {}
/// Return a range over the nodes within this SCC.
std::vector<CallGraphNode *>::iterator begin() { return nodes.begin(); }
std::vector<CallGraphNode *>::iterator end() { return nodes.end(); }
/// Reset the nodes of this SCC with those provided.
void reset(const std::vector<CallGraphNode *> &newNodes) { nodes = newNodes; }
/// Remove the given node from this SCC.
void remove(CallGraphNode *node) {
auto it = llvm::find(nodes, node);
if (it != nodes.end()) {
nodes.erase(it);
parentIterator.ReplaceNode(node, nullptr);
}
}
private:
std::vector<CallGraphNode *> nodes;
llvm::scc_iterator<const CallGraph *> &parentIterator;
};
} // end anonymous namespace
/// Run a given transformation over the SCCs of the callgraph in a bottom up /// Run a given transformation over the SCCs of the callgraph in a bottom up
/// traversal. /// traversal.
static void runTransformOnCGSCCs( static void
const CallGraph &cg, runTransformOnCGSCCs(const CallGraph &cg,
function_ref<void(MutableArrayRef<CallGraphNode *>)> sccTransformer) { function_ref<void(CallGraphSCC &)> sccTransformer) {
std::vector<CallGraphNode *> currentSCCVec; llvm::scc_iterator<const CallGraph *> cgi = llvm::scc_begin(&cg);
auto cgi = llvm::scc_begin(&cg); CallGraphSCC currentSCC(cgi);
while (!cgi.isAtEnd()) { while (!cgi.isAtEnd()) {
// Copy the current SCC and increment so that the transformer can modify the // Copy the current SCC and increment so that the transformer can modify the
// SCC without invalidating our iterator. // SCC without invalidating our iterator.
currentSCCVec = *cgi; currentSCC.reset(*cgi);
++cgi; ++cgi;
sccTransformer(currentSCCVec); sccTransformer(currentSCC);
} }
} }
@ -343,6 +371,19 @@ struct Inliner : public InlinerInterface {
/*traverseNestedCGNodes=*/true); /*traverseNestedCGNodes=*/true);
} }
/// Mark the given callgraph node for deletion.
void markForDeletion(CallGraphNode *node) { deadNodes.insert(node); }
/// This method properly disposes of callables that became dead during
/// inlining. This should not be called while iterating over the SCCs.
void eraseDeadCallables() {
for (CallGraphNode *node : deadNodes)
node->getCallableRegion()->getParentOp()->erase();
}
/// The set of callables known to be dead.
SmallPtrSet<CallGraphNode *, 8> deadNodes;
/// The current set of call instructions to consider for inlining. /// The current set of call instructions to consider for inlining.
SmallVector<ResolvedCall, 8> calls; SmallVector<ResolvedCall, 8> calls;
@ -368,27 +409,16 @@ static bool shouldInline(ResolvedCall &resolvedCall) {
return true; return true;
} }
/// Delete the given node and remove it from the current scc and the callgraph.
static void deleteNode(CallGraphNode *node, CGUseList &useList, CallGraph &cg,
MutableArrayRef<CallGraphNode *> currentSCC) {
// Erase the parent operation and remove it from the various lists.
node->getCallableRegion()->getParentOp()->erase();
cg.eraseNode(node);
// Replace this node in the currentSCC with the external node.
auto it = llvm::find(currentSCC, node);
if (it != currentSCC.end())
*it = cg.getExternalNode();
}
/// Attempt to inline calls within the given scc. This function returns /// Attempt to inline calls within the given scc. This function returns
/// success if any calls were inlined, failure otherwise. /// success if any calls were inlined, failure otherwise.
static LogicalResult static LogicalResult inlineCallsInSCC(Inliner &inliner, CGUseList &useList,
inlineCallsInSCC(Inliner &inliner, CGUseList &useList, CallGraphSCC &currentSCC) {
MutableArrayRef<CallGraphNode *> currentSCC) {
CallGraph &cg = inliner.cg; CallGraph &cg = inliner.cg;
auto &calls = inliner.calls; auto &calls = inliner.calls;
// A set of dead nodes to remove after inlining.
SmallVector<CallGraphNode *, 1> deadNodes;
// Collect all of the direct calls within the nodes of the current SCC. We // Collect all of the direct calls within the nodes of the current SCC. We
// don't traverse nested callgraph nodes, because they are handled separately // don't traverse nested callgraph nodes, because they are handled separately
// likely within a different SCC. // likely within a different SCC.
@ -396,18 +426,13 @@ inlineCallsInSCC(Inliner &inliner, CGUseList &useList,
if (node->isExternal()) if (node->isExternal())
continue; continue;
// If this node is dead, just delete it now. // Don't collect calls if the node is already dead.
if (useList.isDead(node)) if (useList.isDead(node))
deleteNode(node, useList, cg, currentSCC); deadNodes.push_back(node);
else else
collectCallOps(*node->getCallableRegion(), node, cg, calls, collectCallOps(*node->getCallableRegion(), node, cg, calls,
/*traverseNestedCGNodes=*/false); /*traverseNestedCGNodes=*/false);
} }
if (calls.empty())
return failure();
// A set of dead nodes to remove after inlining.
SmallVector<CallGraphNode *, 1> deadNodes;
// Try to inline each of the call operations. Don't cache the end iterator // Try to inline each of the call operations. Don't cache the end iterator
// here as more calls may be added during inlining. // here as more calls may be added during inlining.
@ -453,8 +478,10 @@ inlineCallsInSCC(Inliner &inliner, CGUseList &useList,
} }
} }
for (CallGraphNode *node : deadNodes) for (CallGraphNode *node : deadNodes) {
deleteNode(node, useList, cg, currentSCC); currentSCC.remove(node);
inliner.markForDeletion(node);
}
calls.clear(); calls.clear();
return success(inlinedAnyCalls); return success(inlinedAnyCalls);
} }
@ -462,8 +489,7 @@ inlineCallsInSCC(Inliner &inliner, CGUseList &useList,
/// Canonicalize the nodes within the given SCC with the given set of /// Canonicalize the nodes within the given SCC with the given set of
/// canonicalization patterns. /// canonicalization patterns.
static void canonicalizeSCC(CallGraph &cg, CGUseList &useList, static void canonicalizeSCC(CallGraph &cg, CGUseList &useList,
MutableArrayRef<CallGraphNode *> currentSCC, CallGraphSCC &currentSCC, MLIRContext *context,
MLIRContext *context,
const OwningRewritePatternList &canonPatterns) { const OwningRewritePatternList &canonPatterns) {
// Collect the sets of nodes to canonicalize. // Collect the sets of nodes to canonicalize.
SmallVector<CallGraphNode *, 4> nodesToCanonicalize; SmallVector<CallGraphNode *, 4> nodesToCanonicalize;
@ -533,8 +559,7 @@ struct InlinerPass : public InlinerBase<InlinerPass> {
/// Attempt to inline calls within the given scc, and run canonicalizations /// Attempt to inline calls within the given scc, and run canonicalizations
/// with the given patterns, until a fixed point is reached. This allows for /// with the given patterns, until a fixed point is reached. This allows for
/// the inlining of newly devirtualized calls. /// the inlining of newly devirtualized calls.
void inlineSCC(Inliner &inliner, CGUseList &useList, void inlineSCC(Inliner &inliner, CGUseList &useList, CallGraphSCC &currentSCC,
MutableArrayRef<CallGraphNode *> currentSCC,
MLIRContext *context, MLIRContext *context,
const OwningRewritePatternList &canonPatterns); const OwningRewritePatternList &canonPatterns);
}; };
@ -562,14 +587,16 @@ void InlinerPass::runOnOperation() {
// Run the inline transform in post-order over the SCCs in the callgraph. // Run the inline transform in post-order over the SCCs in the callgraph.
Inliner inliner(context, cg); Inliner inliner(context, cg);
CGUseList useList(getOperation(), cg); CGUseList useList(getOperation(), cg);
runTransformOnCGSCCs(cg, [&](MutableArrayRef<CallGraphNode *> scc) { runTransformOnCGSCCs(cg, [&](CallGraphSCC &scc) {
inlineSCC(inliner, useList, scc, context, canonPatterns); inlineSCC(inliner, useList, scc, context, canonPatterns);
}); });
// After inlining, make sure to erase any callables proven to be dead.
inliner.eraseDeadCallables();
} }
void InlinerPass::inlineSCC(Inliner &inliner, CGUseList &useList, void InlinerPass::inlineSCC(Inliner &inliner, CGUseList &useList,
MutableArrayRef<CallGraphNode *> currentSCC, CallGraphSCC &currentSCC, MLIRContext *context,
MLIRContext *context,
const OwningRewritePatternList &canonPatterns) { const OwningRewritePatternList &canonPatterns) {
// If we successfully inlined any calls, run some simplifications on the // If we successfully inlined any calls, run some simplifications on the
// nodes of the scc. Continue attempting to inline until we reach a fixed // nodes of the scc. Continue attempting to inline until we reach a fixed

View File

@ -1,4 +1,4 @@
// RUN: mlir-opt -allow-unregistered-dialect %s -inline | FileCheck %s // RUN: mlir-opt -allow-unregistered-dialect %s -inline -split-input-file | FileCheck %s
// This file tests the callgraph dead code elimination performed by the inliner. // This file tests the callgraph dead code elimination performed by the inliner.
@ -51,3 +51,23 @@ func @live_function_d() attributes {sym_visibility = "private"} {
} }
"live.user"() {use = @live_function_d} : () -> () "live.user"() {use = @live_function_d} : () -> ()
// -----
// This test checks that the inliner can properly handle the deletion of
// functions in different SCCs that are referenced by calls materialized during
// canonicalization.
// CHECK: func @live_function_e
func @live_function_e() {
call @dead_function_e() : () -> ()
return
}
// CHECK-NOT: func @dead_function_e
func @dead_function_e() -> () attributes {sym_visibility = "private"} {
"test.fold_to_call_op"() {callee=@dead_function_f} : () -> ()
return
}
// CHECK-NOT: func @dead_function_f
func @dead_function_f() attributes {sym_visibility = "private"} {
return
}

View File

@ -173,6 +173,28 @@ TestBranchOp::getMutableSuccessorOperands(unsigned index) {
return targetOperandsMutable(); return targetOperandsMutable();
} }
//===----------------------------------------------------------------------===//
// TestFoldToCallOp
//===----------------------------------------------------------------------===//
namespace {
struct FoldToCallOpPattern : public OpRewritePattern<FoldToCallOp> {
using OpRewritePattern<FoldToCallOp>::OpRewritePattern;
LogicalResult matchAndRewrite(FoldToCallOp op,
PatternRewriter &rewriter) const override {
rewriter.replaceOpWithNewOp<CallOp>(op, ArrayRef<Type>(), op.calleeAttr(),
ValueRange());
return success();
}
};
} // end anonymous namespace
void FoldToCallOp::getCanonicalizationPatterns(
OwningRewritePatternList &results, MLIRContext *context) {
results.insert<FoldToCallOpPattern>(context);
}
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// Test IsolatedRegionOp - parse passthrough region arguments. // Test IsolatedRegionOp - parse passthrough region arguments.
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//

View File

@ -321,6 +321,12 @@ def FunctionalRegionOp : TEST_Op<"functional_region_op",
}]; }];
} }
def FoldToCallOp : TEST_Op<"fold_to_call_op"> {
let arguments = (ins FlatSymbolRefAttr:$callee);
let hasCanonicalizer = 1;
}
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// Test Traits // Test Traits
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//