forked from OSchip/llvm-project
[mlir][Inliner] Properly handle callgraph node deletion
We previously weren't properly updating the SCC iterator when nodes were removed, leading to asan failures in certain situations. This commit adds a CallGraphSCC class and defers operation deletion until inlining has finished. Differential Revision: https://reviews.llvm.org/D81984
This commit is contained in:
parent
55d53d4f54
commit
f4ef77cbb4
|
@ -242,19 +242,47 @@ void CGUseList::decrementDiscardableUses(CGUser &uses) {
|
||||||
// CallGraph traversal
|
// CallGraph traversal
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
/// This class represents a specific callgraph SCC.
|
||||||
|
class CallGraphSCC {
|
||||||
|
public:
|
||||||
|
CallGraphSCC(llvm::scc_iterator<const CallGraph *> &parentIterator)
|
||||||
|
: parentIterator(parentIterator) {}
|
||||||
|
/// Return a range over the nodes within this SCC.
|
||||||
|
std::vector<CallGraphNode *>::iterator begin() { return nodes.begin(); }
|
||||||
|
std::vector<CallGraphNode *>::iterator end() { return nodes.end(); }
|
||||||
|
|
||||||
|
/// Reset the nodes of this SCC with those provided.
|
||||||
|
void reset(const std::vector<CallGraphNode *> &newNodes) { nodes = newNodes; }
|
||||||
|
|
||||||
|
/// Remove the given node from this SCC.
|
||||||
|
void remove(CallGraphNode *node) {
|
||||||
|
auto it = llvm::find(nodes, node);
|
||||||
|
if (it != nodes.end()) {
|
||||||
|
nodes.erase(it);
|
||||||
|
parentIterator.ReplaceNode(node, nullptr);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
std::vector<CallGraphNode *> nodes;
|
||||||
|
llvm::scc_iterator<const CallGraph *> &parentIterator;
|
||||||
|
};
|
||||||
|
} // end anonymous namespace
|
||||||
|
|
||||||
/// Run a given transformation over the SCCs of the callgraph in a bottom up
|
/// Run a given transformation over the SCCs of the callgraph in a bottom up
|
||||||
/// traversal.
|
/// traversal.
|
||||||
static void runTransformOnCGSCCs(
|
static void
|
||||||
const CallGraph &cg,
|
runTransformOnCGSCCs(const CallGraph &cg,
|
||||||
function_ref<void(MutableArrayRef<CallGraphNode *>)> sccTransformer) {
|
function_ref<void(CallGraphSCC &)> sccTransformer) {
|
||||||
std::vector<CallGraphNode *> currentSCCVec;
|
llvm::scc_iterator<const CallGraph *> cgi = llvm::scc_begin(&cg);
|
||||||
auto cgi = llvm::scc_begin(&cg);
|
CallGraphSCC currentSCC(cgi);
|
||||||
while (!cgi.isAtEnd()) {
|
while (!cgi.isAtEnd()) {
|
||||||
// Copy the current SCC and increment so that the transformer can modify the
|
// Copy the current SCC and increment so that the transformer can modify the
|
||||||
// SCC without invalidating our iterator.
|
// SCC without invalidating our iterator.
|
||||||
currentSCCVec = *cgi;
|
currentSCC.reset(*cgi);
|
||||||
++cgi;
|
++cgi;
|
||||||
sccTransformer(currentSCCVec);
|
sccTransformer(currentSCC);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -343,6 +371,19 @@ struct Inliner : public InlinerInterface {
|
||||||
/*traverseNestedCGNodes=*/true);
|
/*traverseNestedCGNodes=*/true);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Mark the given callgraph node for deletion.
|
||||||
|
void markForDeletion(CallGraphNode *node) { deadNodes.insert(node); }
|
||||||
|
|
||||||
|
/// This method properly disposes of callables that became dead during
|
||||||
|
/// inlining. This should not be called while iterating over the SCCs.
|
||||||
|
void eraseDeadCallables() {
|
||||||
|
for (CallGraphNode *node : deadNodes)
|
||||||
|
node->getCallableRegion()->getParentOp()->erase();
|
||||||
|
}
|
||||||
|
|
||||||
|
/// The set of callables known to be dead.
|
||||||
|
SmallPtrSet<CallGraphNode *, 8> deadNodes;
|
||||||
|
|
||||||
/// The current set of call instructions to consider for inlining.
|
/// The current set of call instructions to consider for inlining.
|
||||||
SmallVector<ResolvedCall, 8> calls;
|
SmallVector<ResolvedCall, 8> calls;
|
||||||
|
|
||||||
|
@ -368,27 +409,16 @@ static bool shouldInline(ResolvedCall &resolvedCall) {
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Delete the given node and remove it from the current scc and the callgraph.
|
|
||||||
static void deleteNode(CallGraphNode *node, CGUseList &useList, CallGraph &cg,
|
|
||||||
MutableArrayRef<CallGraphNode *> currentSCC) {
|
|
||||||
// Erase the parent operation and remove it from the various lists.
|
|
||||||
node->getCallableRegion()->getParentOp()->erase();
|
|
||||||
cg.eraseNode(node);
|
|
||||||
|
|
||||||
// Replace this node in the currentSCC with the external node.
|
|
||||||
auto it = llvm::find(currentSCC, node);
|
|
||||||
if (it != currentSCC.end())
|
|
||||||
*it = cg.getExternalNode();
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Attempt to inline calls within the given scc. This function returns
|
/// Attempt to inline calls within the given scc. This function returns
|
||||||
/// success if any calls were inlined, failure otherwise.
|
/// success if any calls were inlined, failure otherwise.
|
||||||
static LogicalResult
|
static LogicalResult inlineCallsInSCC(Inliner &inliner, CGUseList &useList,
|
||||||
inlineCallsInSCC(Inliner &inliner, CGUseList &useList,
|
CallGraphSCC ¤tSCC) {
|
||||||
MutableArrayRef<CallGraphNode *> currentSCC) {
|
|
||||||
CallGraph &cg = inliner.cg;
|
CallGraph &cg = inliner.cg;
|
||||||
auto &calls = inliner.calls;
|
auto &calls = inliner.calls;
|
||||||
|
|
||||||
|
// A set of dead nodes to remove after inlining.
|
||||||
|
SmallVector<CallGraphNode *, 1> deadNodes;
|
||||||
|
|
||||||
// Collect all of the direct calls within the nodes of the current SCC. We
|
// Collect all of the direct calls within the nodes of the current SCC. We
|
||||||
// don't traverse nested callgraph nodes, because they are handled separately
|
// don't traverse nested callgraph nodes, because they are handled separately
|
||||||
// likely within a different SCC.
|
// likely within a different SCC.
|
||||||
|
@ -396,18 +426,13 @@ inlineCallsInSCC(Inliner &inliner, CGUseList &useList,
|
||||||
if (node->isExternal())
|
if (node->isExternal())
|
||||||
continue;
|
continue;
|
||||||
|
|
||||||
// If this node is dead, just delete it now.
|
// Don't collect calls if the node is already dead.
|
||||||
if (useList.isDead(node))
|
if (useList.isDead(node))
|
||||||
deleteNode(node, useList, cg, currentSCC);
|
deadNodes.push_back(node);
|
||||||
else
|
else
|
||||||
collectCallOps(*node->getCallableRegion(), node, cg, calls,
|
collectCallOps(*node->getCallableRegion(), node, cg, calls,
|
||||||
/*traverseNestedCGNodes=*/false);
|
/*traverseNestedCGNodes=*/false);
|
||||||
}
|
}
|
||||||
if (calls.empty())
|
|
||||||
return failure();
|
|
||||||
|
|
||||||
// A set of dead nodes to remove after inlining.
|
|
||||||
SmallVector<CallGraphNode *, 1> deadNodes;
|
|
||||||
|
|
||||||
// Try to inline each of the call operations. Don't cache the end iterator
|
// Try to inline each of the call operations. Don't cache the end iterator
|
||||||
// here as more calls may be added during inlining.
|
// here as more calls may be added during inlining.
|
||||||
|
@ -453,8 +478,10 @@ inlineCallsInSCC(Inliner &inliner, CGUseList &useList,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
for (CallGraphNode *node : deadNodes)
|
for (CallGraphNode *node : deadNodes) {
|
||||||
deleteNode(node, useList, cg, currentSCC);
|
currentSCC.remove(node);
|
||||||
|
inliner.markForDeletion(node);
|
||||||
|
}
|
||||||
calls.clear();
|
calls.clear();
|
||||||
return success(inlinedAnyCalls);
|
return success(inlinedAnyCalls);
|
||||||
}
|
}
|
||||||
|
@ -462,8 +489,7 @@ inlineCallsInSCC(Inliner &inliner, CGUseList &useList,
|
||||||
/// Canonicalize the nodes within the given SCC with the given set of
|
/// Canonicalize the nodes within the given SCC with the given set of
|
||||||
/// canonicalization patterns.
|
/// canonicalization patterns.
|
||||||
static void canonicalizeSCC(CallGraph &cg, CGUseList &useList,
|
static void canonicalizeSCC(CallGraph &cg, CGUseList &useList,
|
||||||
MutableArrayRef<CallGraphNode *> currentSCC,
|
CallGraphSCC ¤tSCC, MLIRContext *context,
|
||||||
MLIRContext *context,
|
|
||||||
const OwningRewritePatternList &canonPatterns) {
|
const OwningRewritePatternList &canonPatterns) {
|
||||||
// Collect the sets of nodes to canonicalize.
|
// Collect the sets of nodes to canonicalize.
|
||||||
SmallVector<CallGraphNode *, 4> nodesToCanonicalize;
|
SmallVector<CallGraphNode *, 4> nodesToCanonicalize;
|
||||||
|
@ -533,8 +559,7 @@ struct InlinerPass : public InlinerBase<InlinerPass> {
|
||||||
/// Attempt to inline calls within the given scc, and run canonicalizations
|
/// Attempt to inline calls within the given scc, and run canonicalizations
|
||||||
/// with the given patterns, until a fixed point is reached. This allows for
|
/// with the given patterns, until a fixed point is reached. This allows for
|
||||||
/// the inlining of newly devirtualized calls.
|
/// the inlining of newly devirtualized calls.
|
||||||
void inlineSCC(Inliner &inliner, CGUseList &useList,
|
void inlineSCC(Inliner &inliner, CGUseList &useList, CallGraphSCC ¤tSCC,
|
||||||
MutableArrayRef<CallGraphNode *> currentSCC,
|
|
||||||
MLIRContext *context,
|
MLIRContext *context,
|
||||||
const OwningRewritePatternList &canonPatterns);
|
const OwningRewritePatternList &canonPatterns);
|
||||||
};
|
};
|
||||||
|
@ -562,14 +587,16 @@ void InlinerPass::runOnOperation() {
|
||||||
// Run the inline transform in post-order over the SCCs in the callgraph.
|
// Run the inline transform in post-order over the SCCs in the callgraph.
|
||||||
Inliner inliner(context, cg);
|
Inliner inliner(context, cg);
|
||||||
CGUseList useList(getOperation(), cg);
|
CGUseList useList(getOperation(), cg);
|
||||||
runTransformOnCGSCCs(cg, [&](MutableArrayRef<CallGraphNode *> scc) {
|
runTransformOnCGSCCs(cg, [&](CallGraphSCC &scc) {
|
||||||
inlineSCC(inliner, useList, scc, context, canonPatterns);
|
inlineSCC(inliner, useList, scc, context, canonPatterns);
|
||||||
});
|
});
|
||||||
|
|
||||||
|
// After inlining, make sure to erase any callables proven to be dead.
|
||||||
|
inliner.eraseDeadCallables();
|
||||||
}
|
}
|
||||||
|
|
||||||
void InlinerPass::inlineSCC(Inliner &inliner, CGUseList &useList,
|
void InlinerPass::inlineSCC(Inliner &inliner, CGUseList &useList,
|
||||||
MutableArrayRef<CallGraphNode *> currentSCC,
|
CallGraphSCC ¤tSCC, MLIRContext *context,
|
||||||
MLIRContext *context,
|
|
||||||
const OwningRewritePatternList &canonPatterns) {
|
const OwningRewritePatternList &canonPatterns) {
|
||||||
// If we successfully inlined any calls, run some simplifications on the
|
// If we successfully inlined any calls, run some simplifications on the
|
||||||
// nodes of the scc. Continue attempting to inline until we reach a fixed
|
// nodes of the scc. Continue attempting to inline until we reach a fixed
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
// RUN: mlir-opt -allow-unregistered-dialect %s -inline | FileCheck %s
|
// RUN: mlir-opt -allow-unregistered-dialect %s -inline -split-input-file | FileCheck %s
|
||||||
|
|
||||||
// This file tests the callgraph dead code elimination performed by the inliner.
|
// This file tests the callgraph dead code elimination performed by the inliner.
|
||||||
|
|
||||||
|
@ -51,3 +51,23 @@ func @live_function_d() attributes {sym_visibility = "private"} {
|
||||||
}
|
}
|
||||||
|
|
||||||
"live.user"() {use = @live_function_d} : () -> ()
|
"live.user"() {use = @live_function_d} : () -> ()
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// This test checks that the inliner can properly handle the deletion of
|
||||||
|
// functions in different SCCs that are referenced by calls materialized during
|
||||||
|
// canonicalization.
|
||||||
|
// CHECK: func @live_function_e
|
||||||
|
func @live_function_e() {
|
||||||
|
call @dead_function_e() : () -> ()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// CHECK-NOT: func @dead_function_e
|
||||||
|
func @dead_function_e() -> () attributes {sym_visibility = "private"} {
|
||||||
|
"test.fold_to_call_op"() {callee=@dead_function_f} : () -> ()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// CHECK-NOT: func @dead_function_f
|
||||||
|
func @dead_function_f() attributes {sym_visibility = "private"} {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
|
@ -173,6 +173,28 @@ TestBranchOp::getMutableSuccessorOperands(unsigned index) {
|
||||||
return targetOperandsMutable();
|
return targetOperandsMutable();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// TestFoldToCallOp
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
struct FoldToCallOpPattern : public OpRewritePattern<FoldToCallOp> {
|
||||||
|
using OpRewritePattern<FoldToCallOp>::OpRewritePattern;
|
||||||
|
|
||||||
|
LogicalResult matchAndRewrite(FoldToCallOp op,
|
||||||
|
PatternRewriter &rewriter) const override {
|
||||||
|
rewriter.replaceOpWithNewOp<CallOp>(op, ArrayRef<Type>(), op.calleeAttr(),
|
||||||
|
ValueRange());
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
} // end anonymous namespace
|
||||||
|
|
||||||
|
void FoldToCallOp::getCanonicalizationPatterns(
|
||||||
|
OwningRewritePatternList &results, MLIRContext *context) {
|
||||||
|
results.insert<FoldToCallOpPattern>(context);
|
||||||
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// Test IsolatedRegionOp - parse passthrough region arguments.
|
// Test IsolatedRegionOp - parse passthrough region arguments.
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
|
@ -321,6 +321,12 @@ def FunctionalRegionOp : TEST_Op<"functional_region_op",
|
||||||
}];
|
}];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def FoldToCallOp : TEST_Op<"fold_to_call_op"> {
|
||||||
|
let arguments = (ins FlatSymbolRefAttr:$callee);
|
||||||
|
let hasCanonicalizer = 1;
|
||||||
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// Test Traits
|
// Test Traits
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
Loading…
Reference in New Issue