[mlir][linalg][bufferize][NFC] Remove special casing of CallOps
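
CallOps were previously bufferized by a dedicated code path in `bufferizeOp`,
which required threading a `DenseMap<FuncOp, FunctionType>` through the entry
points. That logic now lives in the `CallOpInterface` external model, and the
data it needs (the alias info and the cache of bufferized function types) is
owned by `BufferizationState`. As a sketch, the entry point changes as follows
(signatures taken from the headers changed below):

    // Before: callers had to pass the function-type map explicitly.
    LogicalResult bufferizeOp(
        Operation *op, BufferizationState &state,
        DenseMap<FuncOp, FunctionType> *bufferizedFunctionTypes = nullptr);

    // After: no special casing; the map lives in BufferizationState.
    LogicalResult bufferizeOp(Operation *op, BufferizationState &state);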

Differential Revision: https://reviews.llvm.org/D113966
Matthias Springer 2021-11-23 11:12:38 +09:00
parent b1083830d6
commit 8d0994ed21
4 changed files with 164 additions and 187 deletions

@ -11,6 +11,7 @@
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Operation.h"
#include "mlir/Support/LLVM.h"
@ -240,9 +241,8 @@ struct AllocationCallbacks {
/// BufferizationState keeps track of bufferization state and provides access to
/// the results of the analysis.
struct BufferizationState {
BufferizationState(BufferizationAliasInfo &aliasInfo,
AllocationCallbacks &allocationFns)
: aliasInfo(aliasInfo), allocationFns(allocationFns) {}
BufferizationState(ModuleOp moduleOp, AllocationCallbacks &allocationFns)
: aliasInfo(moduleOp), allocationFns(allocationFns) {}
// BufferizationState should be passed as a reference.
BufferizationState(const BufferizationState &) = delete;
@ -270,8 +270,11 @@ struct BufferizationState {
/// Mark `op` as obsolete, so that it is deleted after bufferization.
void markOpObsolete(Operation *op);
/// Erase all ops that were marked obsolete.
void eraseObsoleteOps();
/// `aliasInfo` keeps track of aliasing and equivalent values.
BufferizationAliasInfo &aliasInfo;
BufferizationAliasInfo aliasInfo;
/// `allocationFns` contains helper functions for creating alloc ops, dealloc
/// ops and memcpy ops.
@ -283,6 +286,10 @@ struct BufferizationState {
/// Obsolete ops that should be deleted after bufferization.
SmallVector<Operation *> obsoleteOps;
/// A map for looking up bufferized function types.
// TODO: Disentangle function calls and FuncOps from the remaining bufferization.
DenseMap<FuncOp, FunctionType> bufferizedFunctionTypes;
};
/// Return the result buffer (memref) for a given OpResult (tensor). Allocate
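
For orientation, a minimal sketch of how the reworked `BufferizationState` is
set up and consumed (mirroring `runComprehensiveBufferize` and the CallOp
bufferization later in this diff; `moduleOp` and `options` are assumed to be
in scope):

    // The state now owns the alias info (built from the module) and the
    // cache of bufferized function types used by the CallOp bufferization.
    BufferizationState state(moduleOp, *options.allocationFns);
    BufferizationAliasInfo &aliasInfo = state.aliasInfo;
    DenseMap<FuncOp, FunctionType> &funcTypes = state.bufferizedFunctionTypes;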

@ -25,11 +25,7 @@ static constexpr int64_t kBufferAlignments = 128;
std::unique_ptr<AllocationCallbacks> defaultAllocationCallbacks();
/// Bufferize one particular op.
/// `bufferizedFunctionTypes` (resp. `globalCreator`) are expected to be
/// non-null if `op` is a CallOpInterface (resp. GlobalCreator).
LogicalResult
bufferizeOp(Operation *op, BufferizationState &state,
DenseMap<FuncOp, FunctionType> *bufferizedFunctionTypes = nullptr);
LogicalResult bufferizeOp(Operation *op, BufferizationState &state);
/// Register external models implemented for the `BufferizableOpInterface`.
void registerBufferizableOpInterfaceExternalModels(DialectRegistry &registry);

@ -470,3 +470,10 @@ void mlir::linalg::comprehensive_bufferize::BufferizationState::markOpObsolete(
Operation *op) {
obsoleteOps.push_back(op);
}
void mlir::linalg::comprehensive_bufferize::BufferizationState::
eraseObsoleteOps() {
for (Operation *op : obsoleteOps)
op->erase();
obsoleteOps.clear();
}
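
`eraseObsoleteOps` pairs with `markOpObsolete`: an op bufferization (such as
the CallOp bufferization later in this diff) defers deletion, and the driver
erases after walking the function. A minimal usage sketch:

    // Inside a bufferize() implementation: do not erase mid-walk.
    state.markOpObsolete(callOp);
    // Later, in the driver, after bufferizeFuncOpInternals(funcOp, state):
    state.eraseObsoleteOps(); // erases all marked ops and clears the list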

@ -783,144 +783,6 @@ static Value createNewAllocDeallocPairForShapedValue(
// Bufferization as simple BlockAndValueMapping rewrites.
//===----------------------------------------------------------------------===//
/// As a first approximation, all function arguments of a FuncOp are marked
/// inplaceable. For now, it is the responsibility of the `callOp` bufferization
/// to allow FuncOps that are inplaceable to write in place.
static LogicalResult
bufferize(OpBuilder &b, CallOpInterface callOp, BufferizationState &state,
DenseMap<FuncOp, FunctionType> &bufferizedFunctionTypes) {
FuncOp funcOp = getCalledFunction(callOp);
assert(isa<CallOp>(callOp.getOperation()) && funcOp &&
"expected Callop to a FuncOp");
// If nothing to do then we are done.
if (!llvm::any_of(funcOp.getType().getInputs(), isaTensor) &&
!llvm::any_of(funcOp.getType().getResults(), isaTensor))
return success();
// Take a guard before anything else.
OpBuilder::InsertionGuard g(b);
b.setInsertionPoint(callOp);
// 1. Filter return types:
// - if the callee is bodiless / external, we cannot inspect it and we
// cannot assume anything. We can just assert that it does not return a
// tensor as this would have to bufferize to "return a memref", whose
// semantics is ill-defined.
// - if the callee has a body, we perform inter-procedural equivalence
// analysis. When successful, a result folds onto an operand. When
// unsuccessful, additional work is needed to either:
// * hoist a result into an inplaceable operand or
// * devise a better representation to truly return a buffer.
SmallVector<Type> resultTypes;
SmallVector<Value> hoistedArguments;
if (funcOp.body().empty()) {
if (llvm::any_of(funcOp.getType().getResults(), isaTensor))
return callOp->emitError()
<< "cannot bufferize bodiless function that returns a tensor";
} else {
ReturnOp returnOp = getAssumedUniqueReturnOp(funcOp);
assert(returnOp && "expected func with single return op");
// For each FuncOp result, keep track of which inplace argument it reuses.
for (OpOperand &returnOperand : returnOp->getOpOperands()) {
Type returnType = returnOperand.get().getType();
if (!isaTensor(returnType)) {
resultTypes.push_back(returnType);
continue;
}
// If return operand is equivalent to some bbArg, no need to return it.
Value returnVal = returnOperand.get();
if (BlockArgument bbArg =
getEquivalentEnclosingFuncBBArg(returnVal, state.aliasInfo)) {
Value oldRes = callOp->getResult(returnOperand.getOperandNumber());
int64_t idx = bbArg.getArgNumber();
Value buffer = state.lookupBuffer(callOp->getOperand(idx));
// Add CallOp operand/result equivalence: this is interprocedural info.
state.aliasInfo.insertNewBufferEquivalence(oldRes, buffer);
state.mapBuffer(oldRes, buffer);
// Add a TensorLoadOp to kill all uses of the CallOp return.
// Replace all uses of the CallOp results so we can erase the CallOp.
// This TensorLoadOp must fold/DCE away or bufferization should be
// considered failed.
Value tensorLoad =
b.create<memref::TensorLoadOp>(callOp.getLoc(), buffer);
oldRes.replaceAllUsesWith(tensorLoad);
// Add new op equivalence info.
state.aliasInfo.insertNewBufferEquivalence(tensorLoad, buffer);
state.mapBuffer(tensorLoad, buffer);
continue;
}
// TODO: Need to hoist above function boundary.
if (Operation *allocOp = getEquivalentAlloc(returnVal, state.aliasInfo)) {
hoistedArguments.push_back(allocOp->getResult(0));
continue;
}
// Other cases legitimately need to return a tensor; this is currently not
// supported. For instance, if hoisting across the function boundary has
// failed, it may be due to e.g. data-dependent sizes. In such a case, we
// would need a better type than memref.
resultTypes.push_back(returnType);
int64_t returnIdx = returnOperand.getOperandNumber();
return returnOp->emitError()
<< "buffer result #" << returnIdx << " not produced by an alloc\n";
}
}
// 2. Compute bufferized FunctionType.
SmallVector<Type> argumentTypes{callOp->getOperandTypes()};
ValueRange hoistedArgs{hoistedArguments};
llvm::append_range(argumentTypes, hoistedArgs.getTypes());
// Get the bufferized FunctionType for funcOp or construct it if not yet
// available.
FunctionType bufferizedFuncType = getOrCreateBufferizedFunctionType(
funcOp, argumentTypes, resultTypes, bufferizedFunctionTypes);
// 3. Rewrite tensor operands as memrefs based on `bufferizedFuncType`.
SmallVector<Value> newOperands;
newOperands.reserve(callOp->getNumOperands());
for (OpOperand &opOperand : callOp->getOpOperands()) {
Value tensorOperand = opOperand.get();
// Non-tensor operands are just copied.
if (!tensorOperand.getType().isa<TensorType>()) {
newOperands.push_back(tensorOperand);
continue;
}
// Tensor operands are guaranteed to have been bufferized.
int64_t idx = opOperand.getOperandNumber();
Value buffer = state.lookupBuffer(tensorOperand);
// Caller / callee type mismatch is handled with a CastOp.
auto memRefType = bufferizedFuncType.getInput(idx);
// Since we don't yet have a clear layout story, buffer_cast may
// conservatively turn tensors into more dynamic memrefs than necessary.
// If the buffer does not match the memref type expected by the callee,
// introduce an extra memref.cast that will either canonicalize away or
// fail compilation until we can do something better.
if (buffer.getType() != memRefType) {
Value castBuffer =
b.create<memref::CastOp>(callOp.getLoc(), memRefType, buffer);
// Add new op equivalence info.
state.aliasInfo.insertNewBufferEquivalence(castBuffer, buffer);
state.mapBuffer(tensorOperand, castBuffer);
buffer = castBuffer;
}
newOperands.push_back(buffer);
}
// 4. Create the new CallOp.
Operation *newCallOp = b.create<CallOp>(callOp.getLoc(), funcOp.sym_name(),
resultTypes, newOperands);
newCallOp->setAttrs(callOp->getAttrs());
// Delete the op at the end of bufferization.
return success();
}
/// FuncOp always creates TensorToMemRef ops.
static LogicalResult bufferize(OpBuilder &b, FuncOp funcOp,
BufferizationState &state) {
@ -1065,20 +927,11 @@ inPlaceAnalysisFuncOpBody(FuncOp funcOp, BufferizationAliasInfo &aliasInfo,
// Bufferization entry-point for functions.
//===----------------------------------------------------------------------===//
LogicalResult mlir::linalg::comprehensive_bufferize::bufferizeOp(
Operation *op, BufferizationState &state,
DenseMap<FuncOp, FunctionType> *bufferizedFunctionTypes) {
LogicalResult
mlir::linalg::comprehensive_bufferize::bufferizeOp(Operation *op,
BufferizationState &state) {
OpBuilder b(op->getContext());
// CallOps are handled separately.
if (auto callOp = dyn_cast<CallOpInterface>(op)) {
LDBG("Begin bufferize:\n" << callOp << '\n');
if (!bufferizedFunctionTypes)
llvm_unreachable(
"null bufferizedFunctionTypes when bufferizing CallOpInterface");
return bufferize(b, callOp, state, *bufferizedFunctionTypes);
}
// Skip BufferCast and TensorLoad ops.
if (isa<memref::BufferCastOp, memref::TensorLoadOp>(op))
return success();
@ -1098,9 +951,8 @@ LogicalResult mlir::linalg::comprehensive_bufferize::bufferizeOp(
return op->emitError() << "unsupported op with tensors";
}
static LogicalResult bufferizeFuncOpInternals(
FuncOp funcOp, BufferizationState &state,
DenseMap<FuncOp, FunctionType> &bufferizedFunctionTypes) {
static LogicalResult bufferizeFuncOpInternals(FuncOp funcOp,
BufferizationState &state) {
LLVM_DEBUG(llvm::dbgs() << "\n\n");
LDBG("Begin BufferizeFuncOpInternals:\n" << funcOp << '\n');
OpBuilder b(funcOp->getContext());
@ -1109,19 +961,9 @@ static LogicalResult bufferizeFuncOpInternals(
if (failed(bufferize(b, funcOp, state)))
return failure();
// Cannot erase ops during the traversal. Do that afterwards.
SmallVector<Operation *> toErase;
auto walkFunc = [&](Operation *op) -> WalkResult {
if (failed(bufferizeOp(op, state, &bufferizedFunctionTypes)))
if (failed(bufferizeOp(op, state)))
return failure();
// Register post-walk erasure, if necessary.
if (isa<CallOpInterface>(op))
if (llvm::any_of(op->getOperandTypes(), isaTensor) ||
llvm::any_of(op->getResultTypes(), isaTensor))
toErase.push_back(op);
return success();
};
@ -1133,9 +975,6 @@ static LogicalResult bufferizeFuncOpInternals(
LDBG("End BufferizeFuncOpInternals:\n" << funcOp << '\n');
for (Operation *op : toErase)
op->erase();
return success();
}
@ -1516,12 +1355,12 @@ LogicalResult mlir::linalg::comprehensive_bufferize::runComprehensiveBufferize(
ModuleOp moduleOp, const BufferizationOptions &options) {
SmallVector<FuncOp> orderedFuncOps;
DenseMap<FuncOp, DenseSet<Operation *>> callerMap;
DenseMap<FuncOp, FunctionType> bufferizedFunctionTypes;
if (failed(getFuncOpsOrderedByCalls(moduleOp, orderedFuncOps, callerMap)))
return failure();
DominanceInfo domInfo(moduleOp);
BufferizationAliasInfo aliasInfo(moduleOp);
BufferizationState state(moduleOp, *options.allocationFns);
BufferizationAliasInfo &aliasInfo = state.aliasInfo;
// Interestingly, all function args that are not visible outside of a module
// can be fully bufferized inplace by guaranteeing the CallOp is bufferized
@ -1564,16 +1403,12 @@ LogicalResult mlir::linalg::comprehensive_bufferize::runComprehensiveBufferize(
// Bufferization phase.
if (!options.testAnalysisOnly) {
BufferizationState state(aliasInfo, *options.allocationFns);
// Bufferize all ops in funcOp.
if (failed(
bufferizeFuncOpInternals(funcOp, state, bufferizedFunctionTypes)))
if (failed(bufferizeFuncOpInternals(funcOp, state)))
return failure();
// Erase all obsolete ops.
for (Operation *op : state.obsoleteOps)
op->erase();
state.eraseObsoleteOps();
}
}
// Annotate operations if we only want to report the analysis.
@ -1586,7 +1421,7 @@ LogicalResult mlir::linalg::comprehensive_bufferize::runComprehensiveBufferize(
// Note: It would be good to apply cleanups here but we cannot as aliasInfo
// would be invalidated.
if (failed(bufferizeFuncOpBoundary(funcOp, aliasInfo,
bufferizedFunctionTypes)))
state.bufferizedFunctionTypes)))
return failure();
if (!options.allowReturnMemref &&
@ -1986,10 +1821,142 @@ struct CallOpInterface
return BufferRelation::Equivalent;
}
/// As a first approximation, all function arguments of a FuncOp are
/// marked inplaceable. For now, it is the responsibility of the `callOp`
/// bufferization to allow FuncOps that are inplaceable to write in place.
LogicalResult bufferize(Operation *op, OpBuilder &b,
BufferizationState &state) const {
llvm_unreachable("CallOps are handled separately");
return failure();
CallOp callOp = cast<CallOp>(op);
FuncOp funcOp = getCalledFunction(callOp);
assert(isa<CallOp>(callOp.getOperation()) && funcOp &&
"expected Callop to a FuncOp");
// Take a guard before anything else.
OpBuilder::InsertionGuard g(b);
b.setInsertionPoint(callOp);
// 1. Filter return types:
// - if the callee is bodiless / external, we cannot inspect it and we
// cannot assume anything. We can just assert that it does not return a
// tensor as this would have to bufferize to "return a memref", whose
// semantics is ill-defined.
// - if the callee has a body, we perform inter-procedural equivalence
// analysis. When successful, a result folds onto an operand. When
// unsuccessful, additional work is needed to either:
// * hoist a result into an inplaceable operand or
// * devise a better representation to truly return a buffer.
SmallVector<Type> resultTypes;
SmallVector<Value> hoistedArguments;
if (funcOp.body().empty()) {
if (llvm::any_of(funcOp.getType().getResults(), isaTensor))
return callOp->emitError()
<< "cannot bufferize bodiless function that returns a tensor";
} else {
ReturnOp returnOp = getAssumedUniqueReturnOp(funcOp);
assert(returnOp && "expected func with single return op");
// For each FuncOp result, keep track of which inplace argument it reuses.
for (OpOperand &returnOperand : returnOp->getOpOperands()) {
Type returnType = returnOperand.get().getType();
if (!isaTensor(returnType)) {
resultTypes.push_back(returnType);
continue;
}
// If return operand is equivalent to some bbArg, no need to return it.
Value returnVal = returnOperand.get();
if (BlockArgument bbArg =
getEquivalentEnclosingFuncBBArg(returnVal, state.aliasInfo)) {
Value oldRes = callOp->getResult(returnOperand.getOperandNumber());
int64_t idx = bbArg.getArgNumber();
Value buffer = state.lookupBuffer(callOp->getOperand(idx));
// Add CallOp operand/result equivalence: this is interprocedural
// info.
state.aliasInfo.insertNewBufferEquivalence(oldRes, buffer);
state.mapBuffer(oldRes, buffer);
// Add a TensorLoadOp to kill all uses of the CallOp return.
// Replace all uses of the CallOp results so we can erase the CallOp.
// This TensorLoadOp must fold/DCE away or bufferization should be
// considered failed.
Value tensorLoad =
b.create<memref::TensorLoadOp>(callOp.getLoc(), buffer);
oldRes.replaceAllUsesWith(tensorLoad);
// Add new op equivalence info.
state.aliasInfo.insertNewBufferEquivalence(tensorLoad, buffer);
state.mapBuffer(tensorLoad, buffer);
continue;
}
// TODO: Need to hoist above function boundary.
if (Operation *allocOp =
getEquivalentAlloc(returnVal, state.aliasInfo)) {
hoistedArguments.push_back(allocOp->getResult(0));
continue;
}
// Other cases legitimately need to return a tensor; this is currently
// not supported. For instance, if hoisting across the function boundary
// has failed, it may be due to e.g. data-dependent sizes. In such a case,
// we would need a better type than memref.
resultTypes.push_back(returnType);
int64_t returnIdx = returnOperand.getOperandNumber();
return returnOp->emitError() << "buffer result #" << returnIdx
<< " not produced by an alloc\n";
}
}
// 2. Compute bufferized FunctionType.
SmallVector<Type> argumentTypes{callOp->getOperandTypes()};
ValueRange hoistedArgs{hoistedArguments};
llvm::append_range(argumentTypes, hoistedArgs.getTypes());
// Get the bufferized FunctionType for funcOp or construct it if not yet
// available.
FunctionType bufferizedFuncType = getOrCreateBufferizedFunctionType(
funcOp, argumentTypes, resultTypes, state.bufferizedFunctionTypes);
// 3. Rewrite tensor operands as memrefs based on `bufferizedFuncType`.
SmallVector<Value> newOperands;
newOperands.reserve(callOp->getNumOperands());
for (OpOperand &opOperand : callOp->getOpOperands()) {
Value tensorOperand = opOperand.get();
// Non-tensor operands are just copied.
if (!tensorOperand.getType().isa<TensorType>()) {
newOperands.push_back(tensorOperand);
continue;
}
// Tensor operands are guaranteed to have been bufferized.
int64_t idx = opOperand.getOperandNumber();
Value buffer = state.lookupBuffer(tensorOperand);
// Caller / callee type mismatch is handled with a CastOp.
auto memRefType = bufferizedFuncType.getInput(idx);
// Since we don't yet have a clear layout story, buffer_cast may
// conservatively turn tensors into more dynamic memrefs than necessary.
// If the buffer does not match the memref type expected by the callee,
// introduce an extra memref.cast that will either canonicalize away or
// fail compilation until we can do something better.
if (buffer.getType() != memRefType) {
Value castBuffer =
b.create<memref::CastOp>(callOp.getLoc(), memRefType, buffer);
// Add new op equivalence info.
state.aliasInfo.insertNewBufferEquivalence(castBuffer, buffer);
state.mapBuffer(tensorOperand, castBuffer);
buffer = castBuffer;
}
newOperands.push_back(buffer);
}
// 4. Create the new CallOp.
Operation *newCallOp = b.create<CallOp>(callOp.getLoc(), funcOp.sym_name(),
resultTypes, newOperands);
newCallOp->setAttrs(callOp->getAttrs());
// 5. Delete the op at the end of bufferization.
state.markOpObsolete(callOp);
return success();
}
};
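
With this interface implementation, a CallOp that touches tensors reaches the
code above through the same generic dispatch as any other op. A hedged
paraphrase of that path in `bufferizeOp` (the surrounding checks live in the
unchanged part of this file):

    // Dispatch through BufferizableOpInterface; CallOp is no longer special.
    if (auto bufferizableOp = dyn_cast<BufferizableOpInterface>(op))
      return bufferizableOp.bufferize(b, state);
    return op->emitError() << "unsupported op with tensors";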