forked from OSchip/llvm-project
[mlir][linalg][bufferize] Compose dialect-specific bufferization state
Use composition instead of inheritance for storing dialect-specific bufferization state. This is in preparation of adding "tensor dialect"-specific bufferization state. Differential Revision: https://reviews.llvm.org/D114508
This commit is contained in:
parent
c94b80b438
commit
d62b4b08af
|
@ -230,6 +230,13 @@ struct AllocationCallbacks {
|
||||||
MemCpyFn memCpyFn;
|
MemCpyFn memCpyFn;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
/// Dialect-specific bufferization state. Analysis/bufferization information
|
||||||
|
/// that is specific to ops from a certain dialect can be stored in derived
|
||||||
|
/// variants of this struct.
|
||||||
|
struct DialectBufferizationState {
|
||||||
|
virtual ~DialectBufferizationState() = default;
|
||||||
|
};
|
||||||
|
|
||||||
/// BufferizationState keeps track of bufferization state and provides access to
|
/// BufferizationState keeps track of bufferization state and provides access to
|
||||||
/// the results of the analysis.
|
/// the results of the analysis.
|
||||||
struct BufferizationState {
|
struct BufferizationState {
|
||||||
|
@ -271,6 +278,14 @@ struct BufferizationState {
|
||||||
/// Erase all ops that were marked obsolete.
|
/// Erase all ops that were marked obsolete.
|
||||||
void eraseObsoleteOps();
|
void eraseObsoleteOps();
|
||||||
|
|
||||||
|
/// Return dialect-specific bufferization state.
|
||||||
|
template <typename StateT> StateT &getDialectState(StringRef name) {
|
||||||
|
// Create state if it does not exist yet.
|
||||||
|
if (!dialectState.count(name))
|
||||||
|
dialectState[name] = std::make_unique<StateT>();
|
||||||
|
return static_cast<StateT &>(*dialectState[name]);
|
||||||
|
}
|
||||||
|
|
||||||
/// `aliasInfo` keeps track of aliasing and equivalent values.
|
/// `aliasInfo` keeps track of aliasing and equivalent values.
|
||||||
BufferizationAliasInfo aliasInfo;
|
BufferizationAliasInfo aliasInfo;
|
||||||
|
|
||||||
|
@ -284,6 +299,9 @@ struct BufferizationState {
|
||||||
|
|
||||||
/// Obsolete ops that should be deleted after bufferization.
|
/// Obsolete ops that should be deleted after bufferization.
|
||||||
SmallVector<Operation *> obsoleteOps;
|
SmallVector<Operation *> obsoleteOps;
|
||||||
|
|
||||||
|
/// Dialect-specific bufferization state.
|
||||||
|
DenseMap<StringRef, std::unique_ptr<DialectBufferizationState>> dialectState;
|
||||||
};
|
};
|
||||||
|
|
||||||
/// Return the result buffer (memref) for a given OpResult (tensor). Allocate
|
/// Return the result buffer (memref) for a given OpResult (tensor). Allocate
|
||||||
|
|
|
@ -27,11 +27,9 @@ using namespace tensor;
|
||||||
using namespace comprehensive_bufferize;
|
using namespace comprehensive_bufferize;
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
/// A specialization of BufferizationState that keeps track of additional
|
/// Extra bufferization state that is required for bufferization of function
|
||||||
/// state required for bufferization of function boundaries.
|
/// boundaries.
|
||||||
struct ModuleBufferizationState : public BufferizationState {
|
struct ModuleBufferizationState : public DialectBufferizationState {
|
||||||
using BufferizationState::BufferizationState;
|
|
||||||
|
|
||||||
/// A map for looking up bufferized function types.
|
/// A map for looking up bufferized function types.
|
||||||
DenseMap<FuncOp, FunctionType> bufferizedFunctionTypes;
|
DenseMap<FuncOp, FunctionType> bufferizedFunctionTypes;
|
||||||
|
|
||||||
|
@ -40,6 +38,12 @@ struct ModuleBufferizationState : public BufferizationState {
|
||||||
};
|
};
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
|
static ModuleBufferizationState &
|
||||||
|
getModuleBufferizationState(BufferizationState &state) {
|
||||||
|
return state.getDialectState<ModuleBufferizationState>(
|
||||||
|
StandardOpsDialect::getDialectNamespace());
|
||||||
|
}
|
||||||
|
|
||||||
static bool isaTensor(Type t) { return t.isa<TensorType>(); }
|
static bool isaTensor(Type t) { return t.isa<TensorType>(); }
|
||||||
|
|
||||||
/// If `value` is a memref::CastOp, return its source. Otherwise, return
|
/// If `value` is a memref::CastOp, return its source. Otherwise, return
|
||||||
|
@ -127,7 +131,9 @@ static FunctionType getOrCreateBufferizedFunctionType(
|
||||||
/// Store function BlockArguments that are equivalent to a returned value in
|
/// Store function BlockArguments that are equivalent to a returned value in
|
||||||
/// the given ModuleBufferizationState.
|
/// the given ModuleBufferizationState.
|
||||||
static void populateEquivalentFuncOpBBArgs(FuncOp funcOp,
|
static void populateEquivalentFuncOpBBArgs(FuncOp funcOp,
|
||||||
ModuleBufferizationState &state) {
|
BufferizationState &state) {
|
||||||
|
ModuleBufferizationState &moduleState = getModuleBufferizationState(state);
|
||||||
|
|
||||||
// Support only single return-terminated block in the function.
|
// Support only single return-terminated block in the function.
|
||||||
ReturnOp returnOp = getAssumedUniqueReturnOp(funcOp);
|
ReturnOp returnOp = getAssumedUniqueReturnOp(funcOp);
|
||||||
assert(returnOp && "expected func with single return op");
|
assert(returnOp && "expected func with single return op");
|
||||||
|
@ -137,7 +143,7 @@ static void populateEquivalentFuncOpBBArgs(FuncOp funcOp,
|
||||||
for (BlockArgument bbArg : funcOp.getArguments())
|
for (BlockArgument bbArg : funcOp.getArguments())
|
||||||
if (bbArg.getType().isa<RankedTensorType>())
|
if (bbArg.getType().isa<RankedTensorType>())
|
||||||
if (state.aliasInfo.areEquivalentBufferizedValues(returnVal, bbArg))
|
if (state.aliasInfo.areEquivalentBufferizedValues(returnVal, bbArg))
|
||||||
state.equivalentReturnValToBBArg[returnVal] = bbArg;
|
moduleState.equivalentReturnValToBBArg[returnVal] = bbArg;
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Rewrite the `funcOp` arguments analysis return values and terminator into
|
/// Rewrite the `funcOp` arguments analysis return values and terminator into
|
||||||
|
@ -155,8 +161,9 @@ static void populateEquivalentFuncOpBBArgs(FuncOp funcOp,
|
||||||
/// originate from an op with an Alloc effect, they could be hoisted in the
|
/// originate from an op with an Alloc effect, they could be hoisted in the
|
||||||
/// future.
|
/// future.
|
||||||
static LogicalResult bufferizeFuncOpBoundary(FuncOp funcOp,
|
static LogicalResult bufferizeFuncOpBoundary(FuncOp funcOp,
|
||||||
ModuleBufferizationState &state) {
|
BufferizationState &state) {
|
||||||
LLVM_DEBUG(DBGS() << "Begin bufferizeFuncOpBoundary:\n" << funcOp << "\n");
|
LLVM_DEBUG(DBGS() << "Begin bufferizeFuncOpBoundary:\n" << funcOp << "\n");
|
||||||
|
ModuleBufferizationState &moduleState = getModuleBufferizationState(state);
|
||||||
BufferizationAliasInfo &aliasInfo = state.aliasInfo;
|
BufferizationAliasInfo &aliasInfo = state.aliasInfo;
|
||||||
|
|
||||||
// If nothing to do then we are done.
|
// If nothing to do then we are done.
|
||||||
|
@ -188,7 +195,7 @@ static LogicalResult bufferizeFuncOpBoundary(FuncOp funcOp,
|
||||||
<< "returns a tensor";
|
<< "returns a tensor";
|
||||||
FunctionType bufferizedFuncType = getOrCreateBufferizedFunctionType(
|
FunctionType bufferizedFuncType = getOrCreateBufferizedFunctionType(
|
||||||
funcOp, funcOp.getType().getInputs(), TypeRange{},
|
funcOp, funcOp.getType().getInputs(), TypeRange{},
|
||||||
state.bufferizedFunctionTypes);
|
moduleState.bufferizedFunctionTypes);
|
||||||
funcOp.setType(bufferizedFuncType);
|
funcOp.setType(bufferizedFuncType);
|
||||||
LLVM_DEBUG(DBGS() << "End bufferizeFuncOpBoundary no fun body: " << funcOp);
|
LLVM_DEBUG(DBGS() << "End bufferizeFuncOpBoundary no fun body: " << funcOp);
|
||||||
return success();
|
return success();
|
||||||
|
@ -210,7 +217,7 @@ static LogicalResult bufferizeFuncOpBoundary(FuncOp funcOp,
|
||||||
}
|
}
|
||||||
|
|
||||||
// If return operand is equivalent to some bbArg, no need to return it.
|
// If return operand is equivalent to some bbArg, no need to return it.
|
||||||
if (state.equivalentReturnValToBBArg.count(returnVal))
|
if (moduleState.equivalentReturnValToBBArg.count(returnVal))
|
||||||
continue;
|
continue;
|
||||||
|
|
||||||
// Cast values at the call site if necessary.
|
// Cast values at the call site if necessary.
|
||||||
|
@ -221,7 +228,7 @@ static LogicalResult bufferizeFuncOpBoundary(FuncOp funcOp,
|
||||||
ValueRange retValues{returnValues};
|
ValueRange retValues{returnValues};
|
||||||
FunctionType bufferizedFuncType = getOrCreateBufferizedFunctionType(
|
FunctionType bufferizedFuncType = getOrCreateBufferizedFunctionType(
|
||||||
funcOp, funcOp.getType().getInputs(), retValues.getTypes(),
|
funcOp, funcOp.getType().getInputs(), retValues.getTypes(),
|
||||||
state.bufferizedFunctionTypes);
|
moduleState.bufferizedFunctionTypes);
|
||||||
OpBuilder b(returnOp);
|
OpBuilder b(returnOp);
|
||||||
b.create<ReturnOp>(returnOp.getLoc(), returnValues);
|
b.create<ReturnOp>(returnOp.getLoc(), returnValues);
|
||||||
returnOp->erase();
|
returnOp->erase();
|
||||||
|
@ -474,7 +481,7 @@ struct CallOpInterface
|
||||||
FuncOp funcOp = getCalledFunction(callOp);
|
FuncOp funcOp = getCalledFunction(callOp);
|
||||||
assert(isa<CallOp>(callOp.getOperation()) && funcOp &&
|
assert(isa<CallOp>(callOp.getOperation()) && funcOp &&
|
||||||
"expected Callop to a FuncOp");
|
"expected Callop to a FuncOp");
|
||||||
auto &moduleState = static_cast<ModuleBufferizationState &>(state);
|
ModuleBufferizationState &moduleState = getModuleBufferizationState(state);
|
||||||
|
|
||||||
// Take a guard before anything else.
|
// Take a guard before anything else.
|
||||||
OpBuilder::InsertionGuard g(b);
|
OpBuilder::InsertionGuard g(b);
|
||||||
|
@ -649,7 +656,7 @@ LogicalResult mlir::linalg::comprehensive_bufferize::runComprehensiveBufferize(
|
||||||
if (failed(getFuncOpsOrderedByCalls(moduleOp, orderedFuncOps, callerMap)))
|
if (failed(getFuncOpsOrderedByCalls(moduleOp, orderedFuncOps, callerMap)))
|
||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
ModuleBufferizationState state(moduleOp, *options.allocationFns);
|
BufferizationState state(moduleOp, *options.allocationFns);
|
||||||
BufferizationAliasInfo &aliasInfo = state.aliasInfo;
|
BufferizationAliasInfo &aliasInfo = state.aliasInfo;
|
||||||
|
|
||||||
// Interestingly, all function args that are not visible outside of a module
|
// Interestingly, all function args that are not visible outside of a module
|
||||||
|
|
Loading…
Reference in New Issue