[mlir][Inliner] Support recursion in Inliner

This fixes  Bug https://github.com/llvm/llvm-project/issues/53492
 and uses InlineHistory to track recursive inlining.

Reviewed By: rriddle

Differential Revision: https://reviews.llvm.org/D127072
This commit is contained in:
Javed Absar 2022-06-05 13:42:13 +01:00
parent 030b36a44c
commit c2ecf16224
2 changed files with 98 additions and 4 deletions

View File

@ -19,6 +19,7 @@
#include "mlir/Interfaces/CallInterfaces.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Support/DebugStringHelper.h"
#include "mlir/Transforms/InliningUtils.h"
#include "mlir/Transforms/Passes.h"
#include "llvm/ADT/SCCIterator.h"
@ -364,6 +365,31 @@ static void collectCallOps(iterator_range<Region::iterator> blocks,
//===----------------------------------------------------------------------===//
// Inliner
//===----------------------------------------------------------------------===//
#ifndef NDEBUG
static std::string getNodeName(CallOpInterface op) {
if (auto sym = op.getCallableForCallee().dyn_cast<SymbolRefAttr>())
return debugString(op);
return "_unnamed_callee_";
}
#endif
/// Return true if the specified `inlineHistoryID` indicates an inline history
/// that already includes `node`.
static bool inlineHistoryIncludes(
CallGraphNode *node, Optional<size_t> inlineHistoryID,
MutableArrayRef<std::pair<CallGraphNode *, Optional<size_t>>>
inlineHistory) {
while (inlineHistoryID.has_value()) {
assert(inlineHistoryID.value() < inlineHistory.size() &&
"Invalid inline history ID");
if (inlineHistory[inlineHistoryID.value()].first == node)
return true;
inlineHistoryID = inlineHistory[inlineHistoryID.value()].second;
}
return false;
}
namespace {
/// This class provides a specialization of the main inlining interface.
struct Inliner : public InlinerInterface {
@ -454,23 +480,43 @@ static LogicalResult inlineCallsInSCC(Inliner &inliner, CGUseList &useList,
}
}
// When inlining a callee produces new call sites, we want to keep track of
// the fact that they were inlined from the callee. This allows us to avoid
// infinite inlining.
using InlineHistoryT = Optional<size_t>;
SmallVector<std::pair<CallGraphNode *, InlineHistoryT>, 8> inlineHistory;
std::vector<InlineHistoryT> callHistory(calls.size(), InlineHistoryT{});
LLVM_DEBUG({
llvm::dbgs() << "* Inliner: Initial calls in SCC are: {\n";
for (unsigned i = 0, e = calls.size(); i < e; ++i)
llvm::dbgs() << " " << i << ". " << calls[i].call << ",\n";
llvm::dbgs() << "}\n";
});
// Try to inline each of the call operations. Don't cache the end iterator
// here as more calls may be added during inlining.
bool inlinedAnyCalls = false;
for (unsigned i = 0; i != calls.size(); ++i) {
for (unsigned i = 0; i < calls.size(); ++i) {
if (deadNodes.contains(calls[i].sourceNode))
continue;
ResolvedCall it = calls[i];
bool doInline = shouldInline(it);
InlineHistoryT inlineHistoryID = callHistory[i];
bool inHistory =
inlineHistoryIncludes(it.targetNode, inlineHistoryID, inlineHistory);
bool doInline = !inHistory && shouldInline(it);
CallOpInterface call = it.call;
LLVM_DEBUG({
if (doInline)
llvm::dbgs() << "* Inlining call: " << call << "\n";
llvm::dbgs() << "* Inlining call: " << i << ". " << call << "\n";
else
llvm::dbgs() << "* Not inlining call: " << call << "\n";
llvm::dbgs() << "* Not inlining call: " << i << ". " << call << "\n";
});
if (!doInline)
continue;
unsigned prevSize = calls.size();
Region *targetRegion = it.targetNode->getCallableRegion();
// If this is the last call to the target node and the node is discardable,
@ -486,6 +532,29 @@ static LogicalResult inlineCallsInSCC(Inliner &inliner, CGUseList &useList,
}
inlinedAnyCalls = true;
// Create a inline history entry for this inlined call, so that we remember
// that new callsites came about due to inlining Callee.
InlineHistoryT newInlineHistoryID{inlineHistory.size()};
inlineHistory.push_back(std::make_pair(it.targetNode, inlineHistoryID));
auto historyToString = [](InlineHistoryT h) {
return h.has_value() ? std::to_string(h.value()) : "root";
};
(void)historyToString;
LLVM_DEBUG(llvm::dbgs()
<< "* new inlineHistory entry: " << newInlineHistoryID << ". ["
<< getNodeName(call) << ", " << historyToString(inlineHistoryID)
<< "]\n");
for (unsigned k = prevSize; k != calls.size(); ++k) {
callHistory.push_back(newInlineHistoryID);
LLVM_DEBUG(llvm::dbgs() << "* new call " << k << " {" << calls[i].call
<< "}\n with historyID = " << newInlineHistoryID
<< ", added due to inlining of\n call {" << call
<< "}\n with historyID = "
<< historyToString(inlineHistoryID) << "\n");
}
// If the inlining was successful, Merge the new uses into the source node.
useList.dropCallUses(it.sourceNode, call.getOperation(), cg);
useList.mergeUsesAfterInlining(it.targetNode, it.sourceNode);

View File

@ -0,0 +1,25 @@
// RUN: mlir-opt %s -inline='default-pipeline=''' | FileCheck %s
// RUN: mlir-opt %s --mlir-disable-threading -inline='default-pipeline=''' | FileCheck %s
// CHECK-LABEL: func.func @foo0
func.func @foo0(%arg0 : i32) -> i32 {
// CHECK: call @foo1
// CHECK: }
%0 = arith.constant 0 : i32
%1 = arith.cmpi eq, %arg0, %0 : i32
cf.cond_br %1, ^exit, ^tail
^exit:
return %0 : i32
^tail:
%3 = call @foo1(%arg0) : (i32) -> i32
return %3 : i32
}
// CHECK-LABEL: func.func @foo1
func.func @foo1(%arg0 : i32) -> i32 {
// CHECK: call @foo1
%0 = arith.constant 1 : i32
%1 = arith.subi %arg0, %0 : i32
%2 = call @foo0(%1) : (i32) -> i32
return %2 : i32
}