forked from OSchip/llvm-project
[mlir][Inliner] Properly handle callgraph node deletion
We previously weren't properly updating the SCC iterator when nodes were removed, leading to ASan (AddressSanitizer) failures in certain situations. This commit adds a CallGraphSCC class and defers operation deletion until inlining has finished. Differential Revision: https://reviews.llvm.org/D81984
This commit is contained in:
parent
55d53d4f54
commit
f4ef77cbb4
|
@ -242,19 +242,47 @@ void CGUseList::decrementDiscardableUses(CGUser &uses) {
|
|||
// CallGraph traversal
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
namespace {
|
||||
/// This class represents a specific callgraph SCC.
|
||||
class CallGraphSCC {
|
||||
public:
|
||||
CallGraphSCC(llvm::scc_iterator<const CallGraph *> &parentIterator)
|
||||
: parentIterator(parentIterator) {}
|
||||
/// Return a range over the nodes within this SCC.
|
||||
std::vector<CallGraphNode *>::iterator begin() { return nodes.begin(); }
|
||||
std::vector<CallGraphNode *>::iterator end() { return nodes.end(); }
|
||||
|
||||
/// Reset the nodes of this SCC with those provided.
|
||||
void reset(const std::vector<CallGraphNode *> &newNodes) { nodes = newNodes; }
|
||||
|
||||
/// Remove the given node from this SCC.
|
||||
void remove(CallGraphNode *node) {
|
||||
auto it = llvm::find(nodes, node);
|
||||
if (it != nodes.end()) {
|
||||
nodes.erase(it);
|
||||
parentIterator.ReplaceNode(node, nullptr);
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
std::vector<CallGraphNode *> nodes;
|
||||
llvm::scc_iterator<const CallGraph *> &parentIterator;
|
||||
};
|
||||
} // end anonymous namespace
|
||||
|
||||
/// Run a given transformation over the SCCs of the callgraph in a bottom up
|
||||
/// traversal.
|
||||
static void runTransformOnCGSCCs(
|
||||
const CallGraph &cg,
|
||||
function_ref<void(MutableArrayRef<CallGraphNode *>)> sccTransformer) {
|
||||
std::vector<CallGraphNode *> currentSCCVec;
|
||||
auto cgi = llvm::scc_begin(&cg);
|
||||
static void
|
||||
runTransformOnCGSCCs(const CallGraph &cg,
|
||||
function_ref<void(CallGraphSCC &)> sccTransformer) {
|
||||
llvm::scc_iterator<const CallGraph *> cgi = llvm::scc_begin(&cg);
|
||||
CallGraphSCC currentSCC(cgi);
|
||||
while (!cgi.isAtEnd()) {
|
||||
// Copy the current SCC and increment so that the transformer can modify the
|
||||
// SCC without invalidating our iterator.
|
||||
currentSCCVec = *cgi;
|
||||
currentSCC.reset(*cgi);
|
||||
++cgi;
|
||||
sccTransformer(currentSCCVec);
|
||||
sccTransformer(currentSCC);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -343,6 +371,19 @@ struct Inliner : public InlinerInterface {
|
|||
/*traverseNestedCGNodes=*/true);
|
||||
}
|
||||
|
||||
/// Mark the given callgraph node for deletion.
|
||||
void markForDeletion(CallGraphNode *node) { deadNodes.insert(node); }
|
||||
|
||||
/// This method properly disposes of callables that became dead during
|
||||
/// inlining. This should not be called while iterating over the SCCs.
|
||||
void eraseDeadCallables() {
|
||||
for (CallGraphNode *node : deadNodes)
|
||||
node->getCallableRegion()->getParentOp()->erase();
|
||||
}
|
||||
|
||||
/// The set of callables known to be dead.
|
||||
SmallPtrSet<CallGraphNode *, 8> deadNodes;
|
||||
|
||||
/// The current set of call instructions to consider for inlining.
|
||||
SmallVector<ResolvedCall, 8> calls;
|
||||
|
||||
|
@ -368,27 +409,16 @@ static bool shouldInline(ResolvedCall &resolvedCall) {
|
|||
return true;
|
||||
}
|
||||
|
||||
/// Delete the given node and remove it from the current scc and the callgraph.
|
||||
static void deleteNode(CallGraphNode *node, CGUseList &useList, CallGraph &cg,
|
||||
MutableArrayRef<CallGraphNode *> currentSCC) {
|
||||
// Erase the parent operation and remove it from the various lists.
|
||||
node->getCallableRegion()->getParentOp()->erase();
|
||||
cg.eraseNode(node);
|
||||
|
||||
// Replace this node in the currentSCC with the external node.
|
||||
auto it = llvm::find(currentSCC, node);
|
||||
if (it != currentSCC.end())
|
||||
*it = cg.getExternalNode();
|
||||
}
|
||||
|
||||
/// Attempt to inline calls within the given scc. This function returns
|
||||
/// success if any calls were inlined, failure otherwise.
|
||||
static LogicalResult
|
||||
inlineCallsInSCC(Inliner &inliner, CGUseList &useList,
|
||||
MutableArrayRef<CallGraphNode *> currentSCC) {
|
||||
static LogicalResult inlineCallsInSCC(Inliner &inliner, CGUseList &useList,
|
||||
CallGraphSCC &currentSCC) {
|
||||
CallGraph &cg = inliner.cg;
|
||||
auto &calls = inliner.calls;
|
||||
|
||||
// A set of dead nodes to remove after inlining.
|
||||
SmallVector<CallGraphNode *, 1> deadNodes;
|
||||
|
||||
// Collect all of the direct calls within the nodes of the current SCC. We
|
||||
// don't traverse nested callgraph nodes, because they are handled separately
|
||||
// likely within a different SCC.
|
||||
|
@ -396,18 +426,13 @@ inlineCallsInSCC(Inliner &inliner, CGUseList &useList,
|
|||
if (node->isExternal())
|
||||
continue;
|
||||
|
||||
// If this node is dead, just delete it now.
|
||||
// Don't collect calls if the node is already dead.
|
||||
if (useList.isDead(node))
|
||||
deleteNode(node, useList, cg, currentSCC);
|
||||
deadNodes.push_back(node);
|
||||
else
|
||||
collectCallOps(*node->getCallableRegion(), node, cg, calls,
|
||||
/*traverseNestedCGNodes=*/false);
|
||||
}
|
||||
if (calls.empty())
|
||||
return failure();
|
||||
|
||||
// A set of dead nodes to remove after inlining.
|
||||
SmallVector<CallGraphNode *, 1> deadNodes;
|
||||
|
||||
// Try to inline each of the call operations. Don't cache the end iterator
|
||||
// here as more calls may be added during inlining.
|
||||
|
@ -453,8 +478,10 @@ inlineCallsInSCC(Inliner &inliner, CGUseList &useList,
|
|||
}
|
||||
}
|
||||
|
||||
for (CallGraphNode *node : deadNodes)
|
||||
deleteNode(node, useList, cg, currentSCC);
|
||||
for (CallGraphNode *node : deadNodes) {
|
||||
currentSCC.remove(node);
|
||||
inliner.markForDeletion(node);
|
||||
}
|
||||
calls.clear();
|
||||
return success(inlinedAnyCalls);
|
||||
}
|
||||
|
@ -462,8 +489,7 @@ inlineCallsInSCC(Inliner &inliner, CGUseList &useList,
|
|||
/// Canonicalize the nodes within the given SCC with the given set of
|
||||
/// canonicalization patterns.
|
||||
static void canonicalizeSCC(CallGraph &cg, CGUseList &useList,
|
||||
MutableArrayRef<CallGraphNode *> currentSCC,
|
||||
MLIRContext *context,
|
||||
CallGraphSCC &currentSCC, MLIRContext *context,
|
||||
const OwningRewritePatternList &canonPatterns) {
|
||||
// Collect the sets of nodes to canonicalize.
|
||||
SmallVector<CallGraphNode *, 4> nodesToCanonicalize;
|
||||
|
@ -533,8 +559,7 @@ struct InlinerPass : public InlinerBase<InlinerPass> {
|
|||
/// Attempt to inline calls within the given scc, and run canonicalizations
|
||||
/// with the given patterns, until a fixed point is reached. This allows for
|
||||
/// the inlining of newly devirtualized calls.
|
||||
void inlineSCC(Inliner &inliner, CGUseList &useList,
|
||||
MutableArrayRef<CallGraphNode *> currentSCC,
|
||||
void inlineSCC(Inliner &inliner, CGUseList &useList, CallGraphSCC &currentSCC,
|
||||
MLIRContext *context,
|
||||
const OwningRewritePatternList &canonPatterns);
|
||||
};
|
||||
|
@ -562,14 +587,16 @@ void InlinerPass::runOnOperation() {
|
|||
// Run the inline transform in post-order over the SCCs in the callgraph.
|
||||
Inliner inliner(context, cg);
|
||||
CGUseList useList(getOperation(), cg);
|
||||
runTransformOnCGSCCs(cg, [&](MutableArrayRef<CallGraphNode *> scc) {
|
||||
runTransformOnCGSCCs(cg, [&](CallGraphSCC &scc) {
|
||||
inlineSCC(inliner, useList, scc, context, canonPatterns);
|
||||
});
|
||||
|
||||
// After inlining, make sure to erase any callables proven to be dead.
|
||||
inliner.eraseDeadCallables();
|
||||
}
|
||||
|
||||
void InlinerPass::inlineSCC(Inliner &inliner, CGUseList &useList,
|
||||
MutableArrayRef<CallGraphNode *> currentSCC,
|
||||
MLIRContext *context,
|
||||
CallGraphSCC &currentSCC, MLIRContext *context,
|
||||
const OwningRewritePatternList &canonPatterns) {
|
||||
// If we successfully inlined any calls, run some simplifications on the
|
||||
// nodes of the scc. Continue attempting to inline until we reach a fixed
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
// RUN: mlir-opt -allow-unregistered-dialect %s -inline | FileCheck %s
|
||||
// RUN: mlir-opt -allow-unregistered-dialect %s -inline -split-input-file | FileCheck %s
|
||||
|
||||
// This file tests the callgraph dead code elimination performed by the inliner.
|
||||
|
||||
|
@ -51,3 +51,23 @@ func @live_function_d() attributes {sym_visibility = "private"} {
|
|||
}
|
||||
|
||||
"live.user"() {use = @live_function_d} : () -> ()
|
||||
|
||||
// -----
|
||||
|
||||
// This test checks that the inliner can properly handle the deletion of
|
||||
// functions in different SCCs that are referenced by calls materialized during
|
||||
// canonicalization.
|
||||
// CHECK: func @live_function_e
|
||||
func @live_function_e() {
|
||||
call @dead_function_e() : () -> ()
|
||||
return
|
||||
}
|
||||
// CHECK-NOT: func @dead_function_e
|
||||
func @dead_function_e() -> () attributes {sym_visibility = "private"} {
|
||||
"test.fold_to_call_op"() {callee=@dead_function_f} : () -> ()
|
||||
return
|
||||
}
|
||||
// CHECK-NOT: func @dead_function_f
|
||||
func @dead_function_f() attributes {sym_visibility = "private"} {
|
||||
return
|
||||
}
|
||||
|
|
|
@ -173,6 +173,28 @@ TestBranchOp::getMutableSuccessorOperands(unsigned index) {
|
|||
return targetOperandsMutable();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// TestFoldToCallOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
namespace {
|
||||
struct FoldToCallOpPattern : public OpRewritePattern<FoldToCallOp> {
|
||||
using OpRewritePattern<FoldToCallOp>::OpRewritePattern;
|
||||
|
||||
LogicalResult matchAndRewrite(FoldToCallOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
rewriter.replaceOpWithNewOp<CallOp>(op, ArrayRef<Type>(), op.calleeAttr(),
|
||||
ValueRange());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
} // end anonymous namespace
|
||||
|
||||
void FoldToCallOp::getCanonicalizationPatterns(
|
||||
OwningRewritePatternList &results, MLIRContext *context) {
|
||||
results.insert<FoldToCallOpPattern>(context);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Test IsolatedRegionOp - parse passthrough region arguments.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -321,6 +321,12 @@ def FunctionalRegionOp : TEST_Op<"functional_region_op",
|
|||
}];
|
||||
}
|
||||
|
||||
|
||||
def FoldToCallOp : TEST_Op<"fold_to_call_op"> {
|
||||
let arguments = (ins FlatSymbolRefAttr:$callee);
|
||||
let hasCanonicalizer = 1;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Test Traits
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
Loading…
Reference in New Issue