forked from OSchip/llvm-project
[mlir][Inliner] Support recursion in Inliner
This fixes Bug https://github.com/llvm/llvm-project/issues/53492 and uses InlineHistory to track recursive inlining. Reviewed By: rriddle Differential Revision: https://reviews.llvm.org/D127072
This commit is contained in:
parent
030b36a44c
commit
c2ecf16224
|
@ -19,6 +19,7 @@
|
|||
#include "mlir/Interfaces/CallInterfaces.h"
|
||||
#include "mlir/Interfaces/SideEffectInterfaces.h"
|
||||
#include "mlir/Pass/PassManager.h"
|
||||
#include "mlir/Support/DebugStringHelper.h"
|
||||
#include "mlir/Transforms/InliningUtils.h"
|
||||
#include "mlir/Transforms/Passes.h"
|
||||
#include "llvm/ADT/SCCIterator.h"
|
||||
|
@ -364,6 +365,31 @@ static void collectCallOps(iterator_range<Region::iterator> blocks,
|
|||
//===----------------------------------------------------------------------===//
|
||||
// Inliner
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifndef NDEBUG
|
||||
static std::string getNodeName(CallOpInterface op) {
|
||||
if (auto sym = op.getCallableForCallee().dyn_cast<SymbolRefAttr>())
|
||||
return debugString(op);
|
||||
return "_unnamed_callee_";
|
||||
}
|
||||
#endif
|
||||
|
||||
/// Return true if the specified `inlineHistoryID` indicates an inline history
|
||||
/// that already includes `node`.
|
||||
static bool inlineHistoryIncludes(
|
||||
CallGraphNode *node, Optional<size_t> inlineHistoryID,
|
||||
MutableArrayRef<std::pair<CallGraphNode *, Optional<size_t>>>
|
||||
inlineHistory) {
|
||||
while (inlineHistoryID.has_value()) {
|
||||
assert(inlineHistoryID.value() < inlineHistory.size() &&
|
||||
"Invalid inline history ID");
|
||||
if (inlineHistory[inlineHistoryID.value()].first == node)
|
||||
return true;
|
||||
inlineHistoryID = inlineHistory[inlineHistoryID.value()].second;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
namespace {
|
||||
/// This class provides a specialization of the main inlining interface.
|
||||
struct Inliner : public InlinerInterface {
|
||||
|
@ -454,23 +480,43 @@ static LogicalResult inlineCallsInSCC(Inliner &inliner, CGUseList &useList,
|
|||
}
|
||||
}
|
||||
|
||||
// When inlining a callee produces new call sites, we want to keep track of
|
||||
// the fact that they were inlined from the callee. This allows us to avoid
|
||||
// infinite inlining.
|
||||
using InlineHistoryT = Optional<size_t>;
|
||||
SmallVector<std::pair<CallGraphNode *, InlineHistoryT>, 8> inlineHistory;
|
||||
std::vector<InlineHistoryT> callHistory(calls.size(), InlineHistoryT{});
|
||||
|
||||
LLVM_DEBUG({
|
||||
llvm::dbgs() << "* Inliner: Initial calls in SCC are: {\n";
|
||||
for (unsigned i = 0, e = calls.size(); i < e; ++i)
|
||||
llvm::dbgs() << " " << i << ". " << calls[i].call << ",\n";
|
||||
llvm::dbgs() << "}\n";
|
||||
});
|
||||
|
||||
// Try to inline each of the call operations. Don't cache the end iterator
|
||||
// here as more calls may be added during inlining.
|
||||
bool inlinedAnyCalls = false;
|
||||
for (unsigned i = 0; i != calls.size(); ++i) {
|
||||
for (unsigned i = 0; i < calls.size(); ++i) {
|
||||
if (deadNodes.contains(calls[i].sourceNode))
|
||||
continue;
|
||||
ResolvedCall it = calls[i];
|
||||
bool doInline = shouldInline(it);
|
||||
|
||||
InlineHistoryT inlineHistoryID = callHistory[i];
|
||||
bool inHistory =
|
||||
inlineHistoryIncludes(it.targetNode, inlineHistoryID, inlineHistory);
|
||||
bool doInline = !inHistory && shouldInline(it);
|
||||
CallOpInterface call = it.call;
|
||||
LLVM_DEBUG({
|
||||
if (doInline)
|
||||
llvm::dbgs() << "* Inlining call: " << call << "\n";
|
||||
llvm::dbgs() << "* Inlining call: " << i << ". " << call << "\n";
|
||||
else
|
||||
llvm::dbgs() << "* Not inlining call: " << call << "\n";
|
||||
llvm::dbgs() << "* Not inlining call: " << i << ". " << call << "\n";
|
||||
});
|
||||
if (!doInline)
|
||||
continue;
|
||||
|
||||
unsigned prevSize = calls.size();
|
||||
Region *targetRegion = it.targetNode->getCallableRegion();
|
||||
|
||||
// If this is the last call to the target node and the node is discardable,
|
||||
|
@ -486,6 +532,29 @@ static LogicalResult inlineCallsInSCC(Inliner &inliner, CGUseList &useList,
|
|||
}
|
||||
inlinedAnyCalls = true;
|
||||
|
||||
// Create a inline history entry for this inlined call, so that we remember
|
||||
// that new callsites came about due to inlining Callee.
|
||||
InlineHistoryT newInlineHistoryID{inlineHistory.size()};
|
||||
inlineHistory.push_back(std::make_pair(it.targetNode, inlineHistoryID));
|
||||
|
||||
auto historyToString = [](InlineHistoryT h) {
|
||||
return h.has_value() ? std::to_string(h.value()) : "root";
|
||||
};
|
||||
(void)historyToString;
|
||||
LLVM_DEBUG(llvm::dbgs()
|
||||
<< "* new inlineHistory entry: " << newInlineHistoryID << ". ["
|
||||
<< getNodeName(call) << ", " << historyToString(inlineHistoryID)
|
||||
<< "]\n");
|
||||
|
||||
for (unsigned k = prevSize; k != calls.size(); ++k) {
|
||||
callHistory.push_back(newInlineHistoryID);
|
||||
LLVM_DEBUG(llvm::dbgs() << "* new call " << k << " {" << calls[i].call
|
||||
<< "}\n with historyID = " << newInlineHistoryID
|
||||
<< ", added due to inlining of\n call {" << call
|
||||
<< "}\n with historyID = "
|
||||
<< historyToString(inlineHistoryID) << "\n");
|
||||
}
|
||||
|
||||
// If the inlining was successful, Merge the new uses into the source node.
|
||||
useList.dropCallUses(it.sourceNode, call.getOperation(), cg);
|
||||
useList.mergeUsesAfterInlining(it.targetNode, it.sourceNode);
|
||||
|
|
|
@ -0,0 +1,25 @@
|
|||
// RUN: mlir-opt %s -inline='default-pipeline=''' | FileCheck %s
|
||||
// RUN: mlir-opt %s --mlir-disable-threading -inline='default-pipeline=''' | FileCheck %s
|
||||
|
||||
// CHECK-LABEL: func.func @foo0
|
||||
func.func @foo0(%arg0 : i32) -> i32 {
|
||||
// CHECK: call @foo1
|
||||
// CHECK: }
|
||||
%0 = arith.constant 0 : i32
|
||||
%1 = arith.cmpi eq, %arg0, %0 : i32
|
||||
cf.cond_br %1, ^exit, ^tail
|
||||
^exit:
|
||||
return %0 : i32
|
||||
^tail:
|
||||
%3 = call @foo1(%arg0) : (i32) -> i32
|
||||
return %3 : i32
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func.func @foo1
|
||||
func.func @foo1(%arg0 : i32) -> i32 {
|
||||
// CHECK: call @foo1
|
||||
%0 = arith.constant 1 : i32
|
||||
%1 = arith.subi %arg0, %0 : i32
|
||||
%2 = call @foo0(%1) : (i32) -> i32
|
||||
return %2 : i32
|
||||
}
|
Loading…
Reference in New Issue