forked from OSchip/llvm-project
[mlir][linalg][bufferize] Compose dialect-specific bufferization state
Use composition instead of inheritance for storing dialect-specific bufferization state. This is in preparation of adding "tensor dialect"-specific bufferization state. Differential Revision: https://reviews.llvm.org/D114508
This commit is contained in:
parent
c94b80b438
commit
d62b4b08af
|
@ -230,6 +230,13 @@ struct AllocationCallbacks {
|
|||
MemCpyFn memCpyFn;
|
||||
};
|
||||
|
||||
/// Dialect-specific bufferization state. Analysis/bufferization information
|
||||
/// that is specific to ops from a certain dialect can be stored in derived
|
||||
/// variants of this struct.
|
||||
struct DialectBufferizationState {
|
||||
virtual ~DialectBufferizationState() = default;
|
||||
};
|
||||
|
||||
/// BufferizationState keeps track of bufferization state and provides access to
|
||||
/// the results of the analysis.
|
||||
struct BufferizationState {
|
||||
|
@ -271,6 +278,14 @@ struct BufferizationState {
|
|||
/// Erase all ops that were marked obsolete.
|
||||
void eraseObsoleteOps();
|
||||
|
||||
/// Return dialect-specific bufferization state.
|
||||
template <typename StateT> StateT &getDialectState(StringRef name) {
|
||||
// Create state if it does not exist yet.
|
||||
if (!dialectState.count(name))
|
||||
dialectState[name] = std::make_unique<StateT>();
|
||||
return static_cast<StateT &>(*dialectState[name]);
|
||||
}
|
||||
|
||||
/// `aliasInfo` keeps track of aliasing and equivalent values.
|
||||
BufferizationAliasInfo aliasInfo;
|
||||
|
||||
|
@ -284,6 +299,9 @@ struct BufferizationState {
|
|||
|
||||
/// Obsolete ops that should be deleted after bufferization.
|
||||
SmallVector<Operation *> obsoleteOps;
|
||||
|
||||
/// Dialect-specific bufferization state.
|
||||
DenseMap<StringRef, std::unique_ptr<DialectBufferizationState>> dialectState;
|
||||
};
|
||||
|
||||
/// Return the result buffer (memref) for a given OpResult (tensor). Allocate
|
||||
|
|
|
@ -27,11 +27,9 @@ using namespace tensor;
|
|||
using namespace comprehensive_bufferize;
|
||||
|
||||
namespace {
|
||||
/// A specialization of BufferizationState that keeps track of additional
|
||||
/// state required for bufferization of function boundaries.
|
||||
struct ModuleBufferizationState : public BufferizationState {
|
||||
using BufferizationState::BufferizationState;
|
||||
|
||||
/// Extra bufferization state that is required for bufferization of function
|
||||
/// boundaries.
|
||||
struct ModuleBufferizationState : public DialectBufferizationState {
|
||||
/// A map for looking up bufferized function types.
|
||||
DenseMap<FuncOp, FunctionType> bufferizedFunctionTypes;
|
||||
|
||||
|
@ -40,6 +38,12 @@ struct ModuleBufferizationState : public BufferizationState {
|
|||
};
|
||||
} // namespace
|
||||
|
||||
static ModuleBufferizationState &
|
||||
getModuleBufferizationState(BufferizationState &state) {
|
||||
return state.getDialectState<ModuleBufferizationState>(
|
||||
StandardOpsDialect::getDialectNamespace());
|
||||
}
|
||||
|
||||
static bool isaTensor(Type t) { return t.isa<TensorType>(); }
|
||||
|
||||
/// If `value` is a memref::CastOp, return its source. Otherwise, return
|
||||
|
@ -127,7 +131,9 @@ static FunctionType getOrCreateBufferizedFunctionType(
|
|||
/// Store function BlockArguments that are equivalent to a returned value in
|
||||
/// the given ModuleBufferizationState.
|
||||
static void populateEquivalentFuncOpBBArgs(FuncOp funcOp,
|
||||
ModuleBufferizationState &state) {
|
||||
BufferizationState &state) {
|
||||
ModuleBufferizationState &moduleState = getModuleBufferizationState(state);
|
||||
|
||||
// Support only single return-terminated block in the function.
|
||||
ReturnOp returnOp = getAssumedUniqueReturnOp(funcOp);
|
||||
assert(returnOp && "expected func with single return op");
|
||||
|
@ -137,7 +143,7 @@ static void populateEquivalentFuncOpBBArgs(FuncOp funcOp,
|
|||
for (BlockArgument bbArg : funcOp.getArguments())
|
||||
if (bbArg.getType().isa<RankedTensorType>())
|
||||
if (state.aliasInfo.areEquivalentBufferizedValues(returnVal, bbArg))
|
||||
state.equivalentReturnValToBBArg[returnVal] = bbArg;
|
||||
moduleState.equivalentReturnValToBBArg[returnVal] = bbArg;
|
||||
}
|
||||
|
||||
/// Rewrite the `funcOp` arguments analysis return values and terminator into
|
||||
|
@ -155,8 +161,9 @@ static void populateEquivalentFuncOpBBArgs(FuncOp funcOp,
|
|||
/// originate from an op with an Alloc effect, they could be hoisted in the
|
||||
/// future.
|
||||
static LogicalResult bufferizeFuncOpBoundary(FuncOp funcOp,
|
||||
ModuleBufferizationState &state) {
|
||||
BufferizationState &state) {
|
||||
LLVM_DEBUG(DBGS() << "Begin bufferizeFuncOpBoundary:\n" << funcOp << "\n");
|
||||
ModuleBufferizationState &moduleState = getModuleBufferizationState(state);
|
||||
BufferizationAliasInfo &aliasInfo = state.aliasInfo;
|
||||
|
||||
// If nothing to do then we are done.
|
||||
|
@ -188,7 +195,7 @@ static LogicalResult bufferizeFuncOpBoundary(FuncOp funcOp,
|
|||
<< "returns a tensor";
|
||||
FunctionType bufferizedFuncType = getOrCreateBufferizedFunctionType(
|
||||
funcOp, funcOp.getType().getInputs(), TypeRange{},
|
||||
state.bufferizedFunctionTypes);
|
||||
moduleState.bufferizedFunctionTypes);
|
||||
funcOp.setType(bufferizedFuncType);
|
||||
LLVM_DEBUG(DBGS() << "End bufferizeFuncOpBoundary no fun body: " << funcOp);
|
||||
return success();
|
||||
|
@ -210,7 +217,7 @@ static LogicalResult bufferizeFuncOpBoundary(FuncOp funcOp,
|
|||
}
|
||||
|
||||
// If return operand is equivalent to some bbArg, no need to return it.
|
||||
if (state.equivalentReturnValToBBArg.count(returnVal))
|
||||
if (moduleState.equivalentReturnValToBBArg.count(returnVal))
|
||||
continue;
|
||||
|
||||
// Cast values at the call site if necessary.
|
||||
|
@ -221,7 +228,7 @@ static LogicalResult bufferizeFuncOpBoundary(FuncOp funcOp,
|
|||
ValueRange retValues{returnValues};
|
||||
FunctionType bufferizedFuncType = getOrCreateBufferizedFunctionType(
|
||||
funcOp, funcOp.getType().getInputs(), retValues.getTypes(),
|
||||
state.bufferizedFunctionTypes);
|
||||
moduleState.bufferizedFunctionTypes);
|
||||
OpBuilder b(returnOp);
|
||||
b.create<ReturnOp>(returnOp.getLoc(), returnValues);
|
||||
returnOp->erase();
|
||||
|
@ -474,7 +481,7 @@ struct CallOpInterface
|
|||
FuncOp funcOp = getCalledFunction(callOp);
|
||||
assert(isa<CallOp>(callOp.getOperation()) && funcOp &&
|
||||
"expected Callop to a FuncOp");
|
||||
auto &moduleState = static_cast<ModuleBufferizationState &>(state);
|
||||
ModuleBufferizationState &moduleState = getModuleBufferizationState(state);
|
||||
|
||||
// Take a guard before anything else.
|
||||
OpBuilder::InsertionGuard g(b);
|
||||
|
@ -649,7 +656,7 @@ LogicalResult mlir::linalg::comprehensive_bufferize::runComprehensiveBufferize(
|
|||
if (failed(getFuncOpsOrderedByCalls(moduleOp, orderedFuncOps, callerMap)))
|
||||
return failure();
|
||||
|
||||
ModuleBufferizationState state(moduleOp, *options.allocationFns);
|
||||
BufferizationState state(moduleOp, *options.allocationFns);
|
||||
BufferizationAliasInfo &aliasInfo = state.aliasInfo;
|
||||
|
||||
// Interestingly, all function args that are not visible outside of a module
|
||||
|
|
Loading…
Reference in New Issue