[mlir][linalg][bufferize] Compose dialect-specific bufferization state

Use composition instead of inheritance for storing dialect-specific bufferization state. This is in preparation of adding "tensor dialect"-specific bufferization state.

Differential Revision: https://reviews.llvm.org/D114508
This commit is contained in:
Matthias Springer 2021-11-26 11:35:10 +09:00
parent c94b80b438
commit d62b4b08af
2 changed files with 38 additions and 13 deletions

View File

@ -230,6 +230,13 @@ struct AllocationCallbacks {
MemCpyFn memCpyFn;
};
/// Dialect-specific bufferization state. Analysis/bufferization information
/// that is specific to ops from a certain dialect can be stored in derived
/// variants of this struct.
struct DialectBufferizationState {
virtual ~DialectBufferizationState() = default;
};
/// BufferizationState keeps track of bufferization state and provides access to
/// the results of the analysis.
struct BufferizationState {
@ -271,6 +278,14 @@ struct BufferizationState {
/// Erase all ops that were marked obsolete.
void eraseObsoleteOps();
/// Return dialect-specific bufferization state.
template <typename StateT> StateT &getDialectState(StringRef name) {
// Create state if it does not exist yet.
if (!dialectState.count(name))
dialectState[name] = std::make_unique<StateT>();
return static_cast<StateT &>(*dialectState[name]);
}
/// `aliasInfo` keeps track of aliasing and equivalent values.
BufferizationAliasInfo aliasInfo;
@ -284,6 +299,9 @@ struct BufferizationState {
/// Obsolete ops that should be deleted after bufferization.
SmallVector<Operation *> obsoleteOps;
/// Dialect-specific bufferization state.
DenseMap<StringRef, std::unique_ptr<DialectBufferizationState>> dialectState;
};
/// Return the result buffer (memref) for a given OpResult (tensor). Allocate

View File

@ -27,11 +27,9 @@ using namespace tensor;
using namespace comprehensive_bufferize;
namespace {
/// A specialization of BufferizationState that keeps track of additional
/// state required for bufferization of function boundaries.
struct ModuleBufferizationState : public BufferizationState {
using BufferizationState::BufferizationState;
/// Extra bufferization state that is required for bufferization of function
/// boundaries.
struct ModuleBufferizationState : public DialectBufferizationState {
/// A map for looking up bufferized function types.
DenseMap<FuncOp, FunctionType> bufferizedFunctionTypes;
@ -40,6 +38,12 @@ struct ModuleBufferizationState : public BufferizationState {
};
} // namespace
static ModuleBufferizationState &
getModuleBufferizationState(BufferizationState &state) {
return state.getDialectState<ModuleBufferizationState>(
StandardOpsDialect::getDialectNamespace());
}
static bool isaTensor(Type t) { return t.isa<TensorType>(); }
/// If `value` is a memref::CastOp, return its source. Otherwise, return
@ -127,7 +131,9 @@ static FunctionType getOrCreateBufferizedFunctionType(
/// Store function BlockArguments that are equivalent to a returned value in
/// the given ModuleBufferizationState.
static void populateEquivalentFuncOpBBArgs(FuncOp funcOp,
ModuleBufferizationState &state) {
BufferizationState &state) {
ModuleBufferizationState &moduleState = getModuleBufferizationState(state);
// Support only single return-terminated block in the function.
ReturnOp returnOp = getAssumedUniqueReturnOp(funcOp);
assert(returnOp && "expected func with single return op");
@ -137,7 +143,7 @@ static void populateEquivalentFuncOpBBArgs(FuncOp funcOp,
for (BlockArgument bbArg : funcOp.getArguments())
if (bbArg.getType().isa<RankedTensorType>())
if (state.aliasInfo.areEquivalentBufferizedValues(returnVal, bbArg))
state.equivalentReturnValToBBArg[returnVal] = bbArg;
moduleState.equivalentReturnValToBBArg[returnVal] = bbArg;
}
/// Rewrite the `funcOp` arguments analysis return values and terminator into
@ -155,8 +161,9 @@ static void populateEquivalentFuncOpBBArgs(FuncOp funcOp,
/// originate from an op with an Alloc effect, they could be hoisted in the
/// future.
static LogicalResult bufferizeFuncOpBoundary(FuncOp funcOp,
ModuleBufferizationState &state) {
BufferizationState &state) {
LLVM_DEBUG(DBGS() << "Begin bufferizeFuncOpBoundary:\n" << funcOp << "\n");
ModuleBufferizationState &moduleState = getModuleBufferizationState(state);
BufferizationAliasInfo &aliasInfo = state.aliasInfo;
// If nothing to do then we are done.
@ -188,7 +195,7 @@ static LogicalResult bufferizeFuncOpBoundary(FuncOp funcOp,
<< "returns a tensor";
FunctionType bufferizedFuncType = getOrCreateBufferizedFunctionType(
funcOp, funcOp.getType().getInputs(), TypeRange{},
state.bufferizedFunctionTypes);
moduleState.bufferizedFunctionTypes);
funcOp.setType(bufferizedFuncType);
LLVM_DEBUG(DBGS() << "End bufferizeFuncOpBoundary no fun body: " << funcOp);
return success();
@ -210,7 +217,7 @@ static LogicalResult bufferizeFuncOpBoundary(FuncOp funcOp,
}
// If return operand is equivalent to some bbArg, no need to return it.
if (state.equivalentReturnValToBBArg.count(returnVal))
if (moduleState.equivalentReturnValToBBArg.count(returnVal))
continue;
// Cast values at the call site if necessary.
@ -221,7 +228,7 @@ static LogicalResult bufferizeFuncOpBoundary(FuncOp funcOp,
ValueRange retValues{returnValues};
FunctionType bufferizedFuncType = getOrCreateBufferizedFunctionType(
funcOp, funcOp.getType().getInputs(), retValues.getTypes(),
state.bufferizedFunctionTypes);
moduleState.bufferizedFunctionTypes);
OpBuilder b(returnOp);
b.create<ReturnOp>(returnOp.getLoc(), returnValues);
returnOp->erase();
@ -474,7 +481,7 @@ struct CallOpInterface
FuncOp funcOp = getCalledFunction(callOp);
assert(isa<CallOp>(callOp.getOperation()) && funcOp &&
"expected Callop to a FuncOp");
auto &moduleState = static_cast<ModuleBufferizationState &>(state);
ModuleBufferizationState &moduleState = getModuleBufferizationState(state);
// Take a guard before anything else.
OpBuilder::InsertionGuard g(b);
@ -649,7 +656,7 @@ LogicalResult mlir::linalg::comprehensive_bufferize::runComprehensiveBufferize(
if (failed(getFuncOpsOrderedByCalls(moduleOp, orderedFuncOps, callerMap)))
return failure();
ModuleBufferizationState state(moduleOp, *options.allocationFns);
BufferizationState state(moduleOp, *options.allocationFns);
BufferizationAliasInfo &aliasInfo = state.aliasInfo;
// Interestingly, all function args that are not visible outside of a module