[mlir][Linalg] Add a ComprehensiveModuleBufferizePass and support for CallOp analysis (9/n)

This revision adds the minimal plumbing to create a simple ComprehensiveModuleBufferizePass that can behave conservatively in the presence of CallOps.

A topological sort of the module's functions is performed in callee-before-caller order and, if the call graph is cycle-free, the analysis can proceed.
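The ordering is a standard Kahn-style elimination over the call graph. Below is a minimal standalone sketch of the idea (illustration only, using plain STL containers and hypothetical names; the pass itself works on FuncOp/CallOpInterface, see getFuncOpsOrderedByCalls in the diff):

#include <algorithm>
#include <map>
#include <optional>
#include <set>
#include <string>
#include <vector>

// Sketch only: order functions callee-before-caller, failing on cycles.
// Assumes every function appears as a key in `calls` (possibly with an
// empty callee set).
std::optional<std::vector<std::string>>
orderCalleesFirst(const std::map<std::string, std::set<std::string>> &calls) {
  // Remaining number of distinct callees per function.
  std::map<std::string, unsigned> numCallees;
  // Reverse edges: callee -> set of callers.
  std::map<std::string, std::set<std::string>> callers;
  for (const auto &entry : calls) {
    numCallees[entry.first] = entry.second.size();
    for (const std::string &callee : entry.second)
      callers[callee].insert(entry.first);
  }
  std::vector<std::string> ordered;
  while (!numCallees.empty()) {
    // Pick any function whose remaining callees have all been emitted.
    auto it = std::find_if(numCallees.begin(), numCallees.end(),
                           [](const auto &e) { return e.second == 0; });
    if (it == numCallees.end())
      return std::nullopt; // Every remaining function waits on another: a cycle.
    ordered.push_back(it->first);
    for (const std::string &caller : callers[it->first])
      --numCallees[caller];
    numCallees.erase(it);
  }
  return ordered;
}

For example, orderCalleesFirst({{"main", {"helper"}}, {"helper", {}}}) yields ["helper", "main"], while the mutually recursive @foo/@bar pair from the new test below yields std::nullopt, which the pass surfaces as the circular-dependency error.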

Differential revision: https://reviews.llvm.org/D104859
Nicolas Vasilache 2021-06-29 15:39:14 +00:00
parent 90dfd05919
commit a77524cd2c
6 changed files with 336 additions and 38 deletions


@@ -62,6 +62,14 @@ std::unique_ptr<OperationPass<FuncOp>> createConvertLinalgToAffineLoopsPass();
/// b) whose buffer uses would be free of memory hazards.
std::unique_ptr<Pass> createLinalgComprehensiveFuncBufferizePass();
/// This pass implements a cross-dialect bufferization approach and performs an
/// analysis to determine which op operands and results may be bufferized in the
/// same buffers. The analysis is performed on topologically sorted CallOp and
/// FuncOp within a module. It provides analyses and bufferization across
/// function boundaries. Within a single function body, the bufferization used
/// is that provided by `LinalgComprehensiveFuncBufferizePass`.
std::unique_ptr<Pass> createLinalgComprehensiveModuleBufferizePass();
/// Create a pass to convert Linalg operations which work on tensors to use
/// buffers instead.
std::unique_ptr<OperationPass<FuncOp>> createLinalgBufferizePass();
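A brief usage sketch, assuming the declarations above live in the usual Linalg Passes.h header and the standard PassManager APIs; because the new pass is anchored on ModuleOp, it is added at the top level of the pipeline rather than nested under a per-function pass manager:

#include "mlir/Dialect/Linalg/Passes.h" // assumed location of the declarations
#include "mlir/IR/BuiltinOps.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Support/LogicalResult.h"

// Sketch: bufferize a whole module, analyzing across call boundaries.
static mlir::LogicalResult bufferizeModule(mlir::ModuleOp module) {
  mlir::PassManager pm(module.getContext());
  pm.addPass(mlir::createLinalgComprehensiveModuleBufferizePass());
  return pm.run(module);
}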


@@ -32,7 +32,7 @@ def LinalgComprehensiveFuncBufferize :
This pass implements a cross-dialect bufferization approach and performs an
analysis to determine which op operands and results may be bufferized in the
same buffers. The analysis is performed on SSA use-def chains starting from
function operands that are annotated with the 'inplaceable' attribute
function operands that are annotated with the 'inplaceable' attribute.
}];
let options = [
Option<"testAnalysisOnly", "test-analysis-only", "bool",
@@ -42,6 +42,25 @@ def LinalgComprehensiveFuncBufferize :
let constructor = "mlir::createLinalgComprehensiveFuncBufferizePass()";
}
def LinalgComprehensiveModuleBufferize :
Pass<"linalg-comprehensive-module-bufferize", "ModuleOp"> {
let summary = "Bufferize (tensor into memref) for a Module.";
let description = [{
This pass implements a cross-dialect bufferization approach and performs an
analysis to determine which op operands and results may be bufferized in the
same buffers. The analysis is performed on topologically sorted CallOp and
FuncOp within a module. It provides analyses and bufferization across
function boundaries. Within a single function body, the bufferization used
is that provided by `-linalg-comprehensive-func-bufferize`.
}];
let options = [
Option<"testAnalysisOnly", "test-analysis-only", "bool",
/*default=*/"false",
"Only runs inplaceability analysis (for testing purposes only)">
];
let constructor = "mlir::createLinalgComprehensiveModuleBufferizePass()";
}
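// Invocation sketch (mirrors the RUN lines of the tests added below):
//   mlir-opt foo.mlir -linalg-comprehensive-module-bufferize
//   mlir-opt foo.mlir -linalg-comprehensive-module-bufferize=test-analysis-only
// With test-analysis-only, ops are only annotated with
// __inplace_results_attr__ so FileCheck can inspect the analysis result.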
def LinalgFoldUnitExtentDims : FunctionPass<"linalg-fold-unit-extent-dims"> {
let summary = "Remove unit-extent dimension in Linalg ops on tensors";
let constructor = "mlir::createLinalgFoldUnitExtentDimsPass()";


@@ -375,6 +375,10 @@ public:
/// attribute that was erased, or nullptr if there was no attribute with such
/// name.
Attribute removeArgAttr(unsigned index, Identifier name);
Attribute removeArgAttr(unsigned index, StringRef name) {
return removeArgAttr(
index, Identifier::get(name, this->getOperation()->getContext()));
}
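// Example (hypothetical): with the new overload, the caller no longer needs
// to materialize the Identifier explicitly. The two calls below are
// equivalent:
//   funcOp.removeArgAttr(0, Identifier::get("linalg.inplaceable",
//                                           funcOp.getContext()));
//   funcOp.removeArgAttr(0, "linalg.inplaceable");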
//===--------------------------------------------------------------------===//
// Result Attributes


@@ -16,7 +16,7 @@
// Composability with an extensible set of ops is not a first-class concern.
//
// Bufferization occurs by:
// a. performing an inPlace analysis `inPlaceAnalysisFuncOpInternals`
// a. performing an inPlace analysis `inPlaceAnalysisFuncOpBody`
// which marks each operation within the function with the
// `kInPlaceResultsAttrName` attribute.
// b. traversing each operation in the function and rewriting it in
@@ -132,6 +132,19 @@ using namespace tensor;
#define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")
#define LDBG(X) LLVM_DEBUG(DBGS() << X)
//===----------------------------------------------------------------------===//
// Generic helpers.
//===----------------------------------------------------------------------===//
/// Return the FuncOp called by `callOp`.
static FuncOp getCalledFunction(CallOpInterface callOp) {
SymbolRefAttr sym = callOp.getCallableForCallee().dyn_cast<SymbolRefAttr>();
if (!sym)
return nullptr;
return dyn_cast_or_null<FuncOp>(
SymbolTable::lookupNearestSymbolFrom(callOp, sym));
}
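// Example (hypothetical, not part of this commit): because indirect calls
// carry an SSA callee rather than a SymbolRefAttr, getCalledFunction returns
// nullptr for them, so call sites must guard:
//   moduleOp.walk([&](CallOpInterface callOp) {
//     if (FuncOp callee = getCalledFunction(callOp))
//       /* ...direct call with a resolved FuncOp... */;
//   });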
//===----------------------------------------------------------------------===//
// Bufferization-specific BlockAndValueMapping support with debugging.
//===----------------------------------------------------------------------===//
@@ -167,6 +180,7 @@ static Value lookup(BlockAndValueMapping &bvm, Value key) {
parentOp = key.getDefiningOp()->getParentOfType<FuncOp>();
}
LDBG("In func:\n" << *parentOp << "NO VALUE FOR KEY: " << key << '\n');
(void)parentOp;
return Value();
}
@@ -276,6 +290,25 @@ static InPlaceSpec getInPlace(BlockArgument bbArg) {
return InPlaceSpec::None;
}
/// Set the attribute that triggers inplace bufferization on a FuncOp argument
/// `bbArg`.
static void
setInPlaceFuncArgument(BlockArgument bbArg,
InPlaceSpec inPlaceSpec = InPlaceSpec::True) {
auto funcOp = cast<FuncOp>(bbArg.getOwner()->getParentOp());
funcOp.setArgAttr(
bbArg.getArgNumber(), LinalgDialect::kInplaceableAttrName,
BoolAttr::get(bbArg.getContext(), inPlaceSpec == InPlaceSpec::True));
}
/// Remove the attribute that triggers inplace bufferization on a FuncOp
/// argument `bbArg`.
static void removeInPlaceFuncArgument(BlockArgument bbArg) {
auto funcOp = cast<FuncOp>(bbArg.getOwner()->getParentOp());
funcOp.removeArgAttr(bbArg.getArgNumber(),
LinalgDialect::kInplaceableAttrName);
}
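// Illustration (hypothetical): after setInPlaceFuncArgument(bbArg), the
// owning function carries the marker on that argument:
//   func @f(%arg0: tensor<64xf32> {linalg.inplaceable = true})
// and removeInPlaceFuncArgument(bbArg) erases it again, which the module
// pass below uses for post-analysis cleanup.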
LLVM_ATTRIBUTE_UNUSED static InPlaceSpec getInPlace(Value v) {
if (auto bbArg = v.dyn_cast<BlockArgument>())
return getInPlace(bbArg);
@@ -305,7 +338,8 @@ LLVM_ATTRIBUTE_UNUSED static InPlaceSpec getInPlace(Value v) {
static bool hasKnownBufferizationAliasingBehavior(Operation *op) {
return
// clang-format off
isa<scf::ForOp,
isa<CallOpInterface,
scf::ForOp,
LinalgOp,
ReturnOp,
ExtractSliceOp,
@@ -386,6 +420,10 @@ static OpResult getInplaceableOpResult(OpOperand &opOperand) {
// ExtractSliceOp is special, when bufferized inplace it just returns an
// alias to its operand. Its result is never inplaceable on its operand.
.Case([&](ExtractSliceOp op) { return OpResult(); })
// CallOpInterface is special: it needs to wait for the callee to be
// bufferized and to inspect the BufferizationAliasInfo object. It can't
// make a proper determination by itself and needs to be conservative.
.Case([&](CallOpInterface op) { return OpResult(); })
// Other ops.
.Default([&](Operation *op) { return OpResult(); });
// clang-format on
@@ -458,6 +496,12 @@ static bool bufferizesToMemoryRead(OpOperand &opOperand) {
// matching bbArg may.
if (isa<scf::ForOp>(opOperand.getOwner()))
return false;
// CallOpInterface alone doesn't bufferize to a memory read; one of the uses
// of the matching bbArg in the callee may. It is the responsibility of the
// caller to inspect bbArgs. In the absence of a BufferizationAliasInfo, we
// need to be conservative.
if (auto callOp = dyn_cast<CallOpInterface>(opOperand.getOwner()))
return true;
if (auto linalgOp = dyn_cast<LinalgOp>(opOperand.getOwner()))
return linalgOp.isInputTensor(&opOperand) ||
linalgOp.isInitTensor(&opOperand);
@@ -473,6 +517,19 @@ static bool bufferizesToMemoryRead(OpOperand &opOperand) {
static bool
bufferizesToMemoryWrite(OpOperand &opOperand,
InPlaceSpec inPlaceSpec = InPlaceSpec::None) {
// These terminators are not writes.
if (isa<ReturnOp, linalg::YieldOp, scf::YieldOp>(opOperand.getOwner()))
return false;
// ExtractSliceOp alone doesn't bufferize to a memory write, one of its uses
// may.
if (isa<ExtractSliceOp>(opOperand.getOwner()))
return false;
// CallOpInterface alone doesn't bufferize to a memory write; one of the uses
// of the matching bbArg in the callee may. It is the responsibility of the
// caller to inspect bbArgs. In the absence of a BufferizationAliasInfo, we
// need to be conservative.
if (auto callOp = dyn_cast<CallOpInterface>(opOperand.getOwner()))
return true;
Optional<OpResult> maybeOpResult = getAliasingOpResult(opOperand);
// Unknown op that returns a tensor. The inplace analysis does not support
// it. Conservatively return true.
@@ -482,13 +539,6 @@ bufferizesToMemoryWrite(OpOperand &opOperand,
// This does not bufferize to a write.
if (!*maybeOpResult)
return false;
// These terminators are not writes.
if (isa<ReturnOp, linalg::YieldOp, scf::YieldOp>(opOperand.getOwner()))
return false;
// ExtractSliceOp alone doesn't bufferize to a memory write, one of its uses
// may.
if (maybeOpResult->getDefiningOp<ExtractSliceOp>())
return false;
// If we have a matching OpResult, this is a write.
// Additionally allow to restrict to only inPlace write, if so specified.
return inPlaceSpec == InPlaceSpec::None ||
@@ -521,7 +571,11 @@ public:
Equivalent
};
explicit BufferizationAliasInfo(FuncOp funcOp);
explicit BufferizationAliasInfo(Operation *rootOp);
/// Add a new entry for `v` in the `aliasInfo` and `equivalentInfo`. In the
/// beginning the alias and equivalence sets only contain `v` itself.
void createAliasInfoEntry(Value v);
/// Return true if the buffer to which `operand` would bufferize aliases a
/// buffer that is known to not be writeable. This implies that the matching
@@ -664,33 +718,28 @@ private:
};
} // namespace
BufferizationAliasInfo::BufferizationAliasInfo(FuncOp funcOp) {
funcOp.walk([&](Operation *op) {
for (Value v : op->getResults()) {
if (!v.getType().isa<TensorType>())
continue;
assert(getInPlace(v) == InPlaceSpec::None &&
"unexpected inplace in analysis.");
DenseSet<Value> selfSet;
selfSet.insert(v);
aliasInfo.try_emplace(v, selfSet);
equivalentInfo.insert(v);
}
for (Region &r : op->getRegions()) {
for (Block &b : r.getBlocks()) {
for (auto bbArg : b.getArguments()) {
if (!bbArg.getType().isa<TensorType>())
continue;
DenseSet<Value> selfSet;
selfSet.insert(bbArg);
aliasInfo.try_emplace(bbArg, selfSet);
equivalentInfo.insert(bbArg);
}
}
}
BufferizationAliasInfo::BufferizationAliasInfo(Operation *rootOp) {
rootOp->walk([&](Operation *op) {
for (Value v : op->getResults())
if (v.getType().isa<TensorType>())
createAliasInfoEntry(v);
for (Region &r : op->getRegions())
for (Block &b : r.getBlocks())
for (auto bbArg : b.getArguments())
if (bbArg.getType().isa<TensorType>())
createAliasInfoEntry(bbArg);
});
}
/// Add a new entry for `v` in the `aliasInfo` and `equivalentInfo`. In the
/// beginning the alias and equivalence sets only contain `v` itself.
void BufferizationAliasInfo::createAliasInfoEntry(Value v) {
DenseSet<Value> selfSet;
selfSet.insert(v);
aliasInfo.try_emplace(v, selfSet);
equivalentInfo.insert(v);
}
/// Return true if the buffer to which `operand` would bufferize aliases a
/// buffer that is known to not be writeable. This implies that the matching
/// OpResult cannot be bufferized inplace.
@@ -1679,8 +1728,8 @@ bufferizationSanityCheck(scf::YieldOp yieldOp,
/// Analyze the `funcOp` body to determine which OpResults are inplaceable.
static LogicalResult
inPlaceAnalysisFuncOpInternals(FuncOp funcOp, BufferizationAliasInfo &aliasInfo,
const DominanceInfo &domInfo) {
inPlaceAnalysisFuncOpBody(FuncOp funcOp, BufferizationAliasInfo &aliasInfo,
const DominanceInfo &domInfo) {
LLVM_DEBUG(llvm::dbgs() << "\n\n");
LDBG("Begin InPlaceAnalysisFuncOpInternals:\n" << funcOp << '\n');
assert(funcOp && funcOp->getNumRegions() > 0 && !funcOp.body().empty() &&
@@ -1816,7 +1865,7 @@ void LinalgComprehensiveFuncBufferize::runOnFunction() {
BufferizationAliasInfo aliasInfo(funcOp);
// If the analysis fails, just return. This is expected to reset the IR and no
// single OpResult should be marked inPlace.
if (failed(inPlaceAnalysisFuncOpInternals(funcOp, aliasInfo, domInfo))) {
if (failed(inPlaceAnalysisFuncOpBody(funcOp, aliasInfo, domInfo))) {
signalPassFailure();
return;
}
@@ -1836,3 +1885,122 @@ void LinalgComprehensiveFuncBufferize::runOnFunction() {
std::unique_ptr<Pass> mlir::createLinalgComprehensiveFuncBufferizePass() {
return std::make_unique<LinalgComprehensiveFuncBufferize>();
}
//===----------------------------------------------------------------------===//
// Bufferization entry-point for modules.
//===----------------------------------------------------------------------===//
/// Store all functions of the `moduleOp` in `orderedFuncOps`, sorted by
/// callee-before-caller order (i.e. functions that call no other function
/// come first).
/// Store the map of FuncOp to all its callers in `callerMap`.
/// Return `failure()` if a cycle of calls is detected or if we are unable to
/// retrieve the called FuncOp from any CallOpInterface.
static LogicalResult
getFuncOpsOrderedByCalls(ModuleOp moduleOp,
SmallVectorImpl<FuncOp> &orderedFuncOps,
DenseMap<FuncOp, DenseSet<Operation *>> &callerMap) {
// For each FuncOp, the set of FuncOps that call it (i.e. the set of its
// callers).
DenseMap<FuncOp, DenseSet<FuncOp>> calledBy;
// For each FuncOp, the number of distinct FuncOps it calls.
DenseMap<FuncOp, unsigned> numberCallOpsContainedInFuncOp;
WalkResult res = moduleOp.walk([&](FuncOp funcOp) {
numberCallOpsContainedInFuncOp[funcOp] = 0;
return funcOp.walk([&](CallOpInterface callOp) {
FuncOp calledFunction = getCalledFunction(callOp);
if (!calledFunction)
return WalkResult::interrupt();
auto it = callerMap.try_emplace(calledFunction, DenseSet<Operation *>{});
it.first->getSecond().insert(callOp);
if (calledBy[calledFunction].count(funcOp) == 0) {
calledBy[calledFunction].insert(funcOp);
numberCallOpsContainedInFuncOp[funcOp]++;
}
return WalkResult::advance();
});
});
if (res.wasInterrupted())
return failure();
// Iteratively remove functions that do not call any of the functions
// remaining in `numberCallOpsContainedInFuncOp` and append them to the
// ordered list.
while (!numberCallOpsContainedInFuncOp.empty()) {
auto it = llvm::find_if(numberCallOpsContainedInFuncOp,
[](auto entry) { return entry.getSecond() == 0; });
if (it == numberCallOpsContainedInFuncOp.end())
return moduleOp.emitOpError(
"expected callgraph to be free of circular dependencies.");
orderedFuncOps.push_back(it->getFirst());
for (auto caller : calledBy[it->getFirst()])
numberCallOpsContainedInFuncOp[caller]--;
numberCallOpsContainedInFuncOp.erase(it);
}
return success();
}
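// Worked example (illustration only): for a module where @main calls @helper
// and @helper calls nothing, numberCallOpsContainedInFuncOp starts as
// {@main: 1, @helper: 0}. @helper is emitted first, @main's count drops to 0,
// and the loop produces orderedFuncOps = [@helper, @main], with
// callerMap[@helper] holding the single call site in @main.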
namespace {
struct LinalgComprehensiveModuleBufferize
: public LinalgComprehensiveModuleBufferizeBase<
LinalgComprehensiveModuleBufferize> {
void runOnOperation() override;
void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<linalg::LinalgDialect, memref::MemRefDialect>();
}
};
} // end namespace
void LinalgComprehensiveModuleBufferize::runOnOperation() {
ModuleOp moduleOp = getOperation();
SmallVector<FuncOp> orderedFuncOps;
DenseMap<FuncOp, DenseSet<Operation *>> callerMap;
if (failed(getFuncOpsOrderedByCalls(moduleOp, orderedFuncOps, callerMap)))
return signalPassFailure();
DominanceInfo domInfo(moduleOp);
BufferizationAliasInfo aliasInfo(moduleOp);
// Interestingly, all function args that are not visible outside of a module
// can be fully bufferized inplace by guaranteeing the CallOp is bufferized
// inplace. Therefore, we just bufferize funcOp as if none of its results were
// inplaceable, detect which operands are cloned internally and decide what to
// do at call sites.
for (FuncOp funcOp : orderedFuncOps) {
// No body => no analysis.
if (funcOp.body().empty())
continue;
// In a first approximation:
// =========================
// If the function is called, we can allocate on the caller side which lets
// us force inplace arguments at function boundaries.
// TODO: do not rely on this behavior.
if (callerMap.find(funcOp) != callerMap.end())
for (BlockArgument bbArg : funcOp.getArguments())
if (bbArg.getType().isa<TensorType>())
setInPlaceFuncArgument(bbArg);
// If the analysis fails, just return.
if (failed(inPlaceAnalysisFuncOpBody(funcOp, aliasInfo, domInfo))) {
signalPassFailure();
return;
}
// TODO: Bufferization phase.
}
// Don't drop the attributes if we only want to report the analysis.
if (testAnalysisOnly)
return;
// Post-pass cleanup of inplaceable attributes.
moduleOp.walk(
[&](Operation *op) { op->removeAttr(kInPlaceResultsAttrName); });
moduleOp.walk([&](FuncOp op) {
for (BlockArgument bbArg : op.getArguments())
removeInPlaceFuncArgument(bbArg);
});
}
std::unique_ptr<Pass> mlir::createLinalgComprehensiveModuleBufferizePass() {
return std::make_unique<LinalgComprehensiveModuleBufferize>();
}


@@ -0,0 +1,84 @@
// RUN: mlir-opt %s -linalg-comprehensive-module-bufferize=test-analysis-only -split-input-file | FileCheck %s
func private @foo(tensor<64xf32>)
// CHECK-LABEL: dependence_through_call
func @dependence_through_call(%I : tensor<64xf32> {linalg.inplaceable = true}) {
%f1 = constant 1.000000e+00 : f32
%f2 = constant 2.000000e+00 : f32
// 2. %B already bufferizes inplace, %A would alias and have a different
// value. The calls to `foo` are conservatively assumed to read, so %A
// cannot bufferize inplace.
// CHECK: fill
// CHECK-SAME: {__inplace_results_attr__ = ["false"]}
%A = linalg.fill(%f1, %I) : f32, tensor<64xf32> -> tensor<64xf32>
// 1. Bufferizes inplace: no alias to %A is yet possible.
// CHECK: fill
// CHECK-SAME: {__inplace_results_attr__ = ["true"]}
%B = linalg.fill(%f2, %I) : f32, tensor<64xf32> -> tensor<64xf32>
call @foo(%A) : (tensor<64xf32>) -> ()
call @foo(%B) : (tensor<64xf32>) -> ()
return
}
// -----
func private @foo(tensor<64xf32>)
func private @bar(%A : tensor<64xf32>) {
call @foo(%A) : (tensor<64xf32>) -> ()
return
}
func @read_dependence_through_scf_and_call(
%I : tensor<64xf32> {linalg.inplaceable = true},
%I2 : tensor<64xf32> {linalg.inplaceable = true}) {
%c0 = constant 0 : index
%c1 = constant 1 : index
%c10 = constant 10 : index
%f1 = constant 1.000000e+00 : f32
%f2 = constant 2.000000e+00 : f32
// 5. %B bufferizes inplace, %A would alias and have a different value.
// The calls to `foo` are conservatively assumed to read, so %A cannot
// bufferize inplace.
// CHECK: fill
// CHECK-SAME: {__inplace_results_attr__ = ["false"]}
%A = linalg.fill(%f1, %I) : f32, tensor<64xf32> -> tensor<64xf32>
// 4. Bufferizes inplace: no alias to %A is yet possible.
// CHECK: fill
// CHECK-SAME: {__inplace_results_attr__ = ["true"]}
%B = linalg.fill(%f2, %I) : f32, tensor<64xf32> -> tensor<64xf32>
// 3. Does not read or write, bufferizes inplace.
// CHECK: scf.for
// CHECK: {__inplace_results_attr__ = ["true", "true"]}
%r:2 = scf.for %i = %c0 to %c10 step %c1 iter_args(%0 = %A, %1 = %B)
-> (tensor<64xf32>, tensor<64xf32>)
{
scf.yield %0, %1 : tensor<64xf32>, tensor<64xf32>
}
call @foo(%r#0) : (tensor<64xf32>) -> ()
call @foo(%r#1) : (tensor<64xf32>) -> ()
// 2. %B2 already bufferizes inplace, %A2 would alias and have a different
// value. The calls to `bar` are conservatively assumed to read, so %A2
// cannot bufferize inplace.
// CHECK: fill
// CHECK-SAME: {__inplace_results_attr__ = ["false"]}
%A2 = linalg.fill(%f1, %I2) : f32, tensor<64xf32> -> tensor<64xf32>
// 1. Bufferizes inplace: no alias to %A2 is yet possible.
// CHECK: fill
// CHECK-SAME: {__inplace_results_attr__ = ["true"]}
%B2 = linalg.fill(%f2, %I2) : f32, tensor<64xf32> -> tensor<64xf32>
call @bar(%A2) : (tensor<64xf32>) -> ()
call @bar(%B2) : (tensor<64xf32>) -> ()
return
}


@@ -0,0 +1,15 @@
// RUN: mlir-opt %s -linalg-comprehensive-module-bufferize -split-input-file -verify-diagnostics
// -----
// expected-error @-3 {{expected callgraph to be free of circular dependencies}}
func @foo() {
call @bar() : () -> ()
return
}
func @bar() {
call @foo() : () -> ()
return
}