forked from OSchip/llvm-project
[mlir][linalg][bufferize][NFC] Pass BufferizationState as const reference
This is mostly for documentation purposes: Passing the object as a const reference signifies that analysis decisions cannot be changed after the analysis. Differential Revision: https://reviews.llvm.org/D116742
This commit is contained in:
parent
f2277e60f4
commit
2975407bd4
|
@ -304,29 +304,29 @@ public:
|
|||
|
||||
/// Determine which OpOperand* will alias with `result` if the op is
|
||||
/// bufferized in place. Return an empty vector if the op is not bufferizable.
|
||||
SmallVector<OpOperand *> getAliasingOpOperand(OpResult result);
|
||||
SmallVector<OpOperand *> getAliasingOpOperand(OpResult result) const;
|
||||
|
||||
/// Determine which OpResult will alias with `opOperand` if the op is
|
||||
/// bufferized in place. Return an empty OpResult if the op is not
|
||||
/// bufferizable.
|
||||
OpResult getAliasingOpResult(OpOperand &opOperand);
|
||||
OpResult getAliasingOpResult(OpOperand &opOperand) const;
|
||||
|
||||
/// Return true if `opOperand` bufferizes to a memory read. Return `true` if
|
||||
/// the op is not bufferizable.
|
||||
bool bufferizesToMemoryRead(OpOperand &opOperand);
|
||||
bool bufferizesToMemoryRead(OpOperand &opOperand) const;
|
||||
|
||||
/// Return true if `opOperand` bufferizes to a memory write. Return true` if
|
||||
/// the op is not bufferizable.
|
||||
bool bufferizesToMemoryWrite(OpOperand &opOperand);
|
||||
bool bufferizesToMemoryWrite(OpOperand &opOperand) const;
|
||||
|
||||
/// Return true if `opOperand` does neither read nor write but bufferizes to
|
||||
/// an alias. Return false if the op is not bufferizable.
|
||||
bool bufferizesToAliasOnly(OpOperand &opOperand);
|
||||
bool bufferizesToAliasOnly(OpOperand &opOperand) const;
|
||||
|
||||
/// Return true if the given value is read by an op that bufferizes to a
|
||||
/// memory read. Also takes into account ops that create an alias but do not
|
||||
/// read by themselves (e.g., ExtractSliceOp).
|
||||
bool isValueRead(Value value);
|
||||
bool isValueRead(Value value) const;
|
||||
|
||||
/// Starting from `value`, follow the use-def chain in reverse, always
|
||||
/// selecting the aliasing OpOperands. Find and return Values for which
|
||||
|
@ -351,9 +351,8 @@ public:
|
|||
/// In the above example, Values with a star satisfy the condition. When
|
||||
/// starting the traversal from Value 1, the resulting SetVector is:
|
||||
/// { 2, 7, 8, 5 }
|
||||
llvm::SetVector<Value>
|
||||
findValueInReverseUseDefChain(Value value,
|
||||
llvm::function_ref<bool(Value)> condition);
|
||||
llvm::SetVector<Value> findValueInReverseUseDefChain(
|
||||
Value value, llvm::function_ref<bool(Value)> condition) const;
|
||||
|
||||
/// Find the Value of the last preceding write of a given Value.
|
||||
///
|
||||
|
@ -363,33 +362,34 @@ public:
|
|||
///
|
||||
/// Note: When reaching an end of the reverse SSA use-def chain, that value
|
||||
/// is returned regardless of whether it is a memory write or not.
|
||||
Value findLastPrecedingWrite(Value value);
|
||||
Value findLastPrecedingWrite(Value value) const;
|
||||
|
||||
/// Creates a memref allocation.
|
||||
Optional<Value> createAlloc(OpBuilder &b, Location loc, MemRefType type,
|
||||
ArrayRef<Value> dynShape);
|
||||
ArrayRef<Value> dynShape) const;
|
||||
|
||||
/// Creates an alloc-dealloc pair. This function may perform additional
|
||||
/// optimizations such as buffer allocation hoisting.
|
||||
Value createAllocDeallocPair(OpBuilder &builder, Location loc,
|
||||
Value shapedValue);
|
||||
Value shapedValue) const;
|
||||
|
||||
/// Creates a memref deallocation. The given memref buffer must have been
|
||||
/// allocated using `createAlloc`.
|
||||
void createDealloc(OpBuilder &b, Location loc, Value allocatedBuffer);
|
||||
void createDealloc(OpBuilder &b, Location loc, Value allocatedBuffer) const;
|
||||
|
||||
/// Creates a memcpy between two given buffers.
|
||||
void createMemCpy(OpBuilder &b, Location loc, Value from, Value to);
|
||||
void createMemCpy(OpBuilder &b, Location loc, Value from, Value to) const;
|
||||
|
||||
/// Replace an op with replacement values. The op is deleted. Tensor OpResults
|
||||
/// must be replaced with memref values.
|
||||
void replaceOp(RewriterBase &rewriter, Operation *op, ValueRange values);
|
||||
void replaceOp(RewriterBase &rewriter, Operation *op,
|
||||
ValueRange values) const;
|
||||
|
||||
/// Replace an op with a new op. Tensor OpResults must be replaced with memref
|
||||
/// values.
|
||||
template <typename OpTy, typename... Args>
|
||||
OpTy replaceOpWithNewOp(RewriterBase &rewriter, Operation *op,
|
||||
Args &&...args) {
|
||||
Args &&...args) const {
|
||||
Operation *newOp =
|
||||
rewriter.create<OpTy>(op->getLoc(), std::forward<Args>(args)...);
|
||||
replaceOp(rewriter, op, newOp->getResults());
|
||||
|
@ -398,7 +398,7 @@ public:
|
|||
|
||||
/// Lookup the memref buffer that is associated to the given tensor value.
|
||||
/// Asserts if no buffer is associated.
|
||||
Value lookupBuffer(RewriterBase &rewriter, Value tensor);
|
||||
Value lookupBuffer(RewriterBase &rewriter, Value tensor) const;
|
||||
|
||||
/// Return `true` if the given OpResult has been decided to bufferize inplace.
|
||||
bool isInPlace(OpResult opResult) const;
|
||||
|
@ -406,10 +406,19 @@ public:
|
|||
/// Return the result buffer (memref) for a given OpResult (tensor). Allocate
|
||||
/// a new buffer and copy over data from the existing buffer if out-of-place
|
||||
/// bufferization is necessary.
|
||||
Value getResultBuffer(RewriterBase &rewriter, OpResult result);
|
||||
Value getResultBuffer(RewriterBase &rewriter, OpResult result) const;
|
||||
|
||||
/// Return dialect-specific bufferization state.
|
||||
template <typename StateT> StateT &getDialectState(StringRef name) {
|
||||
template <typename StateT>
|
||||
Optional<const StateT *> getDialectState(StringRef name) const {
|
||||
auto it = dialectState.find(name);
|
||||
if (it == dialectState.end())
|
||||
return None;
|
||||
return static_cast<const StateT *>(it->getSecond().get());
|
||||
}
|
||||
|
||||
/// Return dialect-specific bufferization state or create one if none exists.
|
||||
template <typename StateT> StateT &getOrCreateDialectState(StringRef name) {
|
||||
// Create state if it does not exist yet.
|
||||
if (!dialectState.count(name))
|
||||
dialectState[name] = std::make_unique<StateT>();
|
||||
|
@ -419,15 +428,10 @@ public:
|
|||
/// Return a reference to the BufferizationOptions.
|
||||
const BufferizationOptions &getOptions() const { return options; }
|
||||
|
||||
/// Return a reference to the BufferizationAliasInfo.
|
||||
BufferizationAliasInfo &getAliasInfo() { return aliasInfo; }
|
||||
|
||||
private:
|
||||
friend LogicalResult
|
||||
runComprehensiveBufferize(Operation *op, const BufferizationOptions &options,
|
||||
BufferizationState &state);
|
||||
|
||||
friend LogicalResult
|
||||
runComprehensiveBufferize(ModuleOp moduleOp,
|
||||
std::unique_ptr<BufferizationOptions> options);
|
||||
|
||||
/// `aliasInfo` keeps track of aliasing and equivalent values. Only internal
|
||||
/// functions and `runComprehensiveBufferize` may access this object.
|
||||
BufferizationAliasInfo aliasInfo;
|
||||
|
@ -441,17 +445,17 @@ private:
|
|||
|
||||
/// Bufferize all ops in the given region.
|
||||
LogicalResult bufferize(RewriterBase &rewriter, Region *region,
|
||||
BufferizationState &state);
|
||||
const BufferizationState &state);
|
||||
|
||||
/// Bufferize all ops in the given block.
|
||||
LogicalResult bufferize(RewriterBase &rewriter, Block *block,
|
||||
BufferizationState &state);
|
||||
const BufferizationState &state);
|
||||
|
||||
/// Bufferize the given op. If the op has no tensor OpOperands/OpResults, this
|
||||
/// function returns immediately. Otherwise, it calls the `bufferize` interface
|
||||
/// method of `BufferizableOpInterface`.
|
||||
LogicalResult bufferize(RewriterBase &rewriter, Operation *op,
|
||||
BufferizationState &state);
|
||||
const BufferizationState &state);
|
||||
|
||||
/// Return a contiguous MemRefType (i.e. with canonical/empty layout map)
|
||||
/// with the same shape as `shapedType` and specified `layout` and
|
||||
|
@ -492,38 +496,39 @@ struct AllocationHoistingBarrierOnly
|
|||
: public BufferizableOpInterface::ExternalModel<
|
||||
AllocationHoistingBarrierOnly<OpTy>, OpTy> {
|
||||
bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
|
||||
BufferizationState &state) const {
|
||||
const BufferizationState &state) const {
|
||||
return true;
|
||||
}
|
||||
|
||||
bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
|
||||
BufferizationState &state) const {
|
||||
const BufferizationState &state) const {
|
||||
return false;
|
||||
}
|
||||
|
||||
SmallVector<OpOperand *>
|
||||
getAliasingOpOperand(Operation *op, OpResult opResult,
|
||||
BufferizationState &state) const {
|
||||
const BufferizationState &state) const {
|
||||
return {};
|
||||
}
|
||||
|
||||
OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand,
|
||||
BufferizationState &state) const {
|
||||
const BufferizationState &state) const {
|
||||
return OpResult();
|
||||
}
|
||||
|
||||
BufferRelation bufferRelation(Operation *op, OpResult opResult,
|
||||
const BufferizationAliasInfo &aliasInfo,
|
||||
BufferizationState &state) const {
|
||||
const BufferizationState &state) const {
|
||||
return BufferRelation::None;
|
||||
}
|
||||
|
||||
bool isWritable(Operation *op, Value value, BufferizationState &state) const {
|
||||
bool isWritable(Operation *op, Value value,
|
||||
const BufferizationState &state) const {
|
||||
return false;
|
||||
}
|
||||
|
||||
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
|
||||
BufferizationState &state) const {
|
||||
const BufferizationState &state) const {
|
||||
auto isaTensor = [](Type t) { return t.isa<TensorType>(); };
|
||||
if (any_of(op->getOperandTypes(), isaTensor) ||
|
||||
any_of(op->getResultTypes(), isaTensor))
|
||||
|
|
|
@ -33,7 +33,7 @@ def BufferizableOpInterface : OpInterface<"BufferizableOpInterface"> {
|
|||
/*retType=*/"bool",
|
||||
/*methodName=*/"bufferizesToMemoryRead",
|
||||
/*args=*/(ins "OpOperand &":$opOperand,
|
||||
"BufferizationState &":$state),
|
||||
"const BufferizationState &":$state),
|
||||
/*methodBody=*/"",
|
||||
/*defaultImplementation=*/[{
|
||||
// Does not have to be implemented for ops without tensor OpOperands.
|
||||
|
@ -62,7 +62,7 @@ def BufferizableOpInterface : OpInterface<"BufferizableOpInterface"> {
|
|||
/*retType=*/"bool",
|
||||
/*methodName=*/"bufferizesToMemoryWrite",
|
||||
/*args=*/(ins "OpOperand &":$opOperand,
|
||||
"BufferizationState &":$state),
|
||||
"const BufferizationState &":$state),
|
||||
/*methodBody=*/"",
|
||||
/*defaultImplementation=*/[{
|
||||
// Does not have to be implemented for ops without tensor OpOperands.
|
||||
|
@ -85,7 +85,7 @@ def BufferizableOpInterface : OpInterface<"BufferizableOpInterface"> {
|
|||
/*retType=*/"bool",
|
||||
/*methodName=*/"isMemoryWrite",
|
||||
/*args=*/(ins "OpResult":$opResult,
|
||||
"BufferizationState &":$state),
|
||||
"const BufferizationState &":$state),
|
||||
/*methodBody=*/"",
|
||||
/*defaultImplementation=*/[{
|
||||
auto bufferizableOp =
|
||||
|
@ -116,7 +116,7 @@ def BufferizableOpInterface : OpInterface<"BufferizableOpInterface"> {
|
|||
/*retType=*/"bool",
|
||||
/*methodName=*/"mustBufferizeInPlace",
|
||||
/*args=*/(ins "OpResult":$opResult,
|
||||
"BufferizationState &":$state),
|
||||
"const BufferizationState &":$state),
|
||||
/*methodBody=*/"",
|
||||
/*defaultImplementation=*/[{
|
||||
return false;
|
||||
|
@ -131,7 +131,7 @@ def BufferizableOpInterface : OpInterface<"BufferizableOpInterface"> {
|
|||
/*retType=*/"OpResult",
|
||||
/*methodName=*/"getAliasingOpResult",
|
||||
/*args=*/(ins "OpOperand &":$opOperand,
|
||||
"BufferizationState &":$state),
|
||||
"const BufferizationState &":$state),
|
||||
/*methodBody=*/"",
|
||||
/*defaultImplementation=*/[{
|
||||
// Does not have to be implemented for ops without tensor OpOperands.
|
||||
|
@ -155,7 +155,7 @@ def BufferizableOpInterface : OpInterface<"BufferizableOpInterface"> {
|
|||
/*retType=*/"SmallVector<OpOperand *>",
|
||||
/*methodName=*/"getAliasingOpOperand",
|
||||
/*args=*/(ins "OpResult":$opResult,
|
||||
"BufferizationState &":$state),
|
||||
"const BufferizationState &":$state),
|
||||
/*methodBody=*/"",
|
||||
/*defaultImplementation=*/[{
|
||||
assert(opResult.getType().isa<TensorType>() &&
|
||||
|
@ -188,7 +188,7 @@ def BufferizableOpInterface : OpInterface<"BufferizableOpInterface"> {
|
|||
/*methodName=*/"bufferRelation",
|
||||
/*args=*/(ins "OpResult":$opResult,
|
||||
"const BufferizationAliasInfo &":$aliasInfo,
|
||||
"BufferizationState &":$state),
|
||||
"const BufferizationState &":$state),
|
||||
/*methodBody=*/"",
|
||||
/*defaultImplementation=*/[{
|
||||
// Does not have to be implemented for ops without tensor OpResults
|
||||
|
@ -210,7 +210,7 @@ def BufferizableOpInterface : OpInterface<"BufferizableOpInterface"> {
|
|||
/*retType=*/"LogicalResult",
|
||||
/*methodName=*/"bufferize",
|
||||
/*args=*/(ins "RewriterBase &":$rewriter,
|
||||
"BufferizationState &":$state),
|
||||
"const BufferizationState &":$state),
|
||||
/*methodBody=*/"",
|
||||
/*defaultImplementation=*/[{
|
||||
llvm_unreachable("bufferize not implemented");
|
||||
|
@ -236,7 +236,7 @@ def BufferizableOpInterface : OpInterface<"BufferizableOpInterface"> {
|
|||
/*retType=*/"bool",
|
||||
/*methodName=*/"isWritable",
|
||||
/*args=*/(ins "Value":$value,
|
||||
"BufferizationState &":$state),
|
||||
"const BufferizationState &":$state),
|
||||
/*methodBody=*/"",
|
||||
/*defaultImplementation=*/[{
|
||||
return value.isa<OpResult>();
|
||||
|
@ -275,7 +275,7 @@ def BufferizableOpInterface : OpInterface<"BufferizableOpInterface"> {
|
|||
/*methodName=*/"isNotConflicting",
|
||||
/*args=*/(ins "OpOperand *":$uRead,
|
||||
"OpOperand *":$uWrite,
|
||||
"BufferizationState &":$state,
|
||||
"const BufferizationState &":$state,
|
||||
"const BufferizationAliasInfo &":$aliasInfo),
|
||||
/*methodBody=*/"",
|
||||
/*defaultImplementation=*/[{
|
||||
|
@ -292,7 +292,7 @@ def BufferizableOpInterface : OpInterface<"BufferizableOpInterface"> {
|
|||
///
|
||||
/// Examples of such ops are `tensor.extract_slice` and `tensor.cast`.
|
||||
bool bufferizesToAliasOnly(OpOperand &opOperand,
|
||||
BufferizationState &state) {
|
||||
const BufferizationState &state) {
|
||||
auto bufferizableOp =
|
||||
cast<BufferizableOpInterface>(getOperation());
|
||||
return !bufferizableOp.bufferizesToMemoryRead(opOperand, state)
|
||||
|
|
|
@ -24,7 +24,7 @@ struct ConstantOpInterface
|
|||
: public BufferizableOpInterface::ExternalModel<ConstantOpInterface,
|
||||
arith::ConstantOp> {
|
||||
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
|
||||
BufferizationState &state) const {
|
||||
const BufferizationState &state) const {
|
||||
auto constantOp = cast<arith::ConstantOp>(op);
|
||||
assert(constantOp.getType().dyn_cast<RankedTensorType>() &&
|
||||
"not a constant ranked tensor");
|
||||
|
@ -40,7 +40,8 @@ struct ConstantOpInterface
|
|||
return success();
|
||||
}
|
||||
|
||||
bool isWritable(Operation *op, Value value, BufferizationState &state) const {
|
||||
bool isWritable(Operation *op, Value value,
|
||||
const BufferizationState &state) const {
|
||||
// Memory locations returned by memref::GetGlobalOp may not be written to.
|
||||
assert(value.isa<OpResult>());
|
||||
return false;
|
||||
|
|
|
@ -199,7 +199,7 @@ BufferizableOpInterface mlir::linalg::comprehensive_bufferize::
|
|||
/// in place. Return an empty vector if the op is not bufferizable.
|
||||
SmallVector<OpOperand *>
|
||||
mlir::linalg::comprehensive_bufferize::BufferizationState::getAliasingOpOperand(
|
||||
OpResult result) {
|
||||
OpResult result) const {
|
||||
if (Operation *op = result.getDefiningOp())
|
||||
if (auto bufferizableOp = dyn_cast<BufferizableOpInterface>(op))
|
||||
return bufferizableOp.getAliasingOpOperand(result, *this);
|
||||
|
@ -210,7 +210,7 @@ mlir::linalg::comprehensive_bufferize::BufferizationState::getAliasingOpOperand(
|
|||
/// in place. Return an empty OpResult if the op is not bufferizable.
|
||||
OpResult
|
||||
mlir::linalg::comprehensive_bufferize::BufferizationState::getAliasingOpResult(
|
||||
OpOperand &opOperand) {
|
||||
OpOperand &opOperand) const {
|
||||
if (auto bufferizableOp =
|
||||
dyn_cast<BufferizableOpInterface>(opOperand.getOwner()))
|
||||
return bufferizableOp.getAliasingOpResult(opOperand, *this);
|
||||
|
@ -220,7 +220,7 @@ mlir::linalg::comprehensive_bufferize::BufferizationState::getAliasingOpResult(
|
|||
/// Return true if `opOperand` bufferizes to a memory read. Return `true` if the
|
||||
/// op is not bufferizable.
|
||||
bool mlir::linalg::comprehensive_bufferize::BufferizationState::
|
||||
bufferizesToMemoryRead(OpOperand &opOperand) {
|
||||
bufferizesToMemoryRead(OpOperand &opOperand) const {
|
||||
if (auto bufferizableOp =
|
||||
dyn_cast<BufferizableOpInterface>(opOperand.getOwner()))
|
||||
return bufferizableOp.bufferizesToMemoryRead(opOperand, *this);
|
||||
|
@ -233,7 +233,7 @@ bool mlir::linalg::comprehensive_bufferize::BufferizationState::
|
|||
/// Return true if `opOperand` bufferizes to a memory write. Return
|
||||
/// `true` if the op is not bufferizable.
|
||||
bool mlir::linalg::comprehensive_bufferize::BufferizationState::
|
||||
bufferizesToMemoryWrite(OpOperand &opOperand) {
|
||||
bufferizesToMemoryWrite(OpOperand &opOperand) const {
|
||||
if (auto bufferizableOp =
|
||||
dyn_cast<BufferizableOpInterface>(opOperand.getOwner()))
|
||||
return bufferizableOp.bufferizesToMemoryWrite(opOperand, *this);
|
||||
|
@ -246,7 +246,7 @@ bool mlir::linalg::comprehensive_bufferize::BufferizationState::
|
|||
/// Return true if `opOperand` does neither read nor write but bufferizes to an
|
||||
/// alias. Return false if the op is not bufferizable.
|
||||
bool mlir::linalg::comprehensive_bufferize::BufferizationState::
|
||||
bufferizesToAliasOnly(OpOperand &opOperand) {
|
||||
bufferizesToAliasOnly(OpOperand &opOperand) const {
|
||||
if (auto bufferizableOp =
|
||||
dyn_cast<BufferizableOpInterface>(opOperand.getOwner()))
|
||||
return bufferizableOp.bufferizesToAliasOnly(opOperand, *this);
|
||||
|
@ -260,7 +260,7 @@ bool mlir::linalg::comprehensive_bufferize::BufferizationState::
|
|||
/// read. Also takes into account ops that create an alias but do not read by
|
||||
/// themselves (e.g., ExtractSliceOp).
|
||||
bool mlir::linalg::comprehensive_bufferize::BufferizationState::isValueRead(
|
||||
Value value) {
|
||||
Value value) const {
|
||||
SmallVector<OpOperand *> workingSet;
|
||||
for (OpOperand &use : value.getUses())
|
||||
workingSet.push_back(&use);
|
||||
|
@ -282,10 +282,9 @@ bool mlir::linalg::comprehensive_bufferize::BufferizationState::isValueRead(
|
|||
// the aliasing OpOperands. Find and return Values for which `condition`
|
||||
// evaluates to true. OpOperands of such matching Values are not traversed any
|
||||
// further.
|
||||
llvm::SetVector<Value>
|
||||
mlir::linalg::comprehensive_bufferize::BufferizationState::
|
||||
findValueInReverseUseDefChain(Value value,
|
||||
llvm::function_ref<bool(Value)> condition) {
|
||||
llvm::SetVector<Value> mlir::linalg::comprehensive_bufferize::
|
||||
BufferizationState::findValueInReverseUseDefChain(
|
||||
Value value, llvm::function_ref<bool(Value)> condition) const {
|
||||
llvm::SetVector<Value> result, workingSet;
|
||||
workingSet.insert(value);
|
||||
|
||||
|
@ -312,7 +311,7 @@ mlir::linalg::comprehensive_bufferize::BufferizationState::
|
|||
|
||||
// Find the Value of the last preceding write of a given Value.
|
||||
Value mlir::linalg::comprehensive_bufferize::BufferizationState::
|
||||
findLastPrecedingWrite(Value value) {
|
||||
findLastPrecedingWrite(Value value) const {
|
||||
SetVector<Value> result =
|
||||
findValueInReverseUseDefChain(value, [&](Value value) {
|
||||
Operation *op = value.getDefiningOp();
|
||||
|
@ -360,7 +359,7 @@ mlir::linalg::comprehensive_bufferize::BufferizationState::BufferizationState(
|
|||
/// a new buffer and copy over data from the existing buffer if out-of-place
|
||||
/// bufferization is necessary.
|
||||
Value mlir::linalg::comprehensive_bufferize::BufferizationState::
|
||||
getResultBuffer(RewriterBase &rewriter, OpResult result) {
|
||||
getResultBuffer(RewriterBase &rewriter, OpResult result) const {
|
||||
OpBuilder::InsertionGuard guard(rewriter);
|
||||
Operation *op = result.getOwner();
|
||||
SmallVector<OpOperand *> aliasingOperands = getAliasingOpOperand(result);
|
||||
|
@ -424,7 +423,7 @@ Value mlir::linalg::comprehensive_bufferize::BufferizationState::
|
|||
}
|
||||
|
||||
void mlir::linalg::comprehensive_bufferize::BufferizationState::replaceOp(
|
||||
RewriterBase &rewriter, Operation *op, ValueRange values) {
|
||||
RewriterBase &rewriter, Operation *op, ValueRange values) const {
|
||||
OpBuilder::InsertionGuard g(rewriter);
|
||||
|
||||
// Replace all OpResults with the given values.
|
||||
|
@ -454,7 +453,7 @@ void mlir::linalg::comprehensive_bufferize::BufferizationState::replaceOp(
|
|||
}
|
||||
|
||||
LogicalResult mlir::linalg::comprehensive_bufferize::bufferize(
|
||||
RewriterBase &rewriter, Region *region, BufferizationState &state) {
|
||||
RewriterBase &rewriter, Region *region, const BufferizationState &state) {
|
||||
for (Block &block : *region)
|
||||
if (failed(bufferize(rewriter, &block, state)))
|
||||
return failure();
|
||||
|
@ -462,7 +461,7 @@ LogicalResult mlir::linalg::comprehensive_bufferize::bufferize(
|
|||
}
|
||||
|
||||
LogicalResult mlir::linalg::comprehensive_bufferize::bufferize(
|
||||
RewriterBase &rewriter, Block *block, BufferizationState &state) {
|
||||
RewriterBase &rewriter, Block *block, const BufferizationState &state) {
|
||||
// Ops may get deleted during the traversal, so do not iterate over `block`
|
||||
// directly.
|
||||
SmallVector<Operation *> ops;
|
||||
|
@ -476,7 +475,7 @@ LogicalResult mlir::linalg::comprehensive_bufferize::bufferize(
|
|||
}
|
||||
|
||||
LogicalResult mlir::linalg::comprehensive_bufferize::bufferize(
|
||||
RewriterBase &rewriter, Operation *op, BufferizationState &state) {
|
||||
RewriterBase &rewriter, Operation *op, const BufferizationState &state) {
|
||||
// Check if op has tensor results or operands.
|
||||
auto isaTensor = [](Type t) { return t.isa<TensorType>(); };
|
||||
bool hasTensorResult = any_of(op->getResultTypes(), isaTensor);
|
||||
|
@ -592,7 +591,8 @@ static MemRefType getAllocationTypeAndShape(OpBuilder &b, Location loc,
|
|||
/// `shapedValue.getDefiningOp` (or at the top of the block in case of a
|
||||
/// bbArg) and the DeallocOp is at the end of the block.
|
||||
Value mlir::linalg::comprehensive_bufferize::BufferizationState::
|
||||
createAllocDeallocPair(OpBuilder &b, Location loc, Value shapedValue) {
|
||||
createAllocDeallocPair(OpBuilder &b, Location loc,
|
||||
Value shapedValue) const {
|
||||
// Take a guard before anything else.
|
||||
OpBuilder::InsertionGuard g(b);
|
||||
|
||||
|
@ -621,19 +621,20 @@ Value mlir::linalg::comprehensive_bufferize::BufferizationState::
|
|||
/// Create a memref allocation.
|
||||
Optional<Value>
|
||||
mlir::linalg::comprehensive_bufferize::BufferizationState::createAlloc(
|
||||
OpBuilder &b, Location loc, MemRefType type, ArrayRef<Value> dynShape) {
|
||||
OpBuilder &b, Location loc, MemRefType type,
|
||||
ArrayRef<Value> dynShape) const {
|
||||
return options.allocationFns->allocationFn(b, loc, type, dynShape);
|
||||
}
|
||||
|
||||
/// Create a memref deallocation.
|
||||
void mlir::linalg::comprehensive_bufferize::BufferizationState::createDealloc(
|
||||
OpBuilder &b, Location loc, Value allocatedBuffer) {
|
||||
OpBuilder &b, Location loc, Value allocatedBuffer) const {
|
||||
return options.allocationFns->deallocationFn(b, loc, allocatedBuffer);
|
||||
}
|
||||
|
||||
/// Create a memory copy between two memref buffers.
|
||||
void mlir::linalg::comprehensive_bufferize::BufferizationState::createMemCpy(
|
||||
OpBuilder &b, Location loc, Value from, Value to) {
|
||||
OpBuilder &b, Location loc, Value from, Value to) const {
|
||||
return options.allocationFns->memCpyFn(b, loc, from, to);
|
||||
}
|
||||
|
||||
|
@ -649,7 +650,7 @@ bool mlir::linalg::comprehensive_bufferize::isFunctionArgument(Value value) {
|
|||
}
|
||||
|
||||
Value mlir::linalg::comprehensive_bufferize::BufferizationState::lookupBuffer(
|
||||
RewriterBase &rewriter, Value tensor) {
|
||||
RewriterBase &rewriter, Value tensor) const {
|
||||
assert(tensor.getType().isa<TensorType>() && "unexpected non-tensor type");
|
||||
|
||||
// Replace "%t = to_tensor %m" with %m.
|
||||
|
|
|
@ -40,18 +40,18 @@ struct ToMemrefOpInterface
|
|||
: public BufferizableOpInterface::ExternalModel<ToMemrefOpInterface,
|
||||
bufferization::ToMemrefOp> {
|
||||
bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
|
||||
BufferizationState &state) const {
|
||||
const BufferizationState &state) const {
|
||||
// It is unknown whether the resulting MemRef will be read or not.
|
||||
return true;
|
||||
}
|
||||
|
||||
OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand,
|
||||
BufferizationState &state) const {
|
||||
const BufferizationState &state) const {
|
||||
return OpResult();
|
||||
}
|
||||
|
||||
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
|
||||
BufferizationState &state) const {
|
||||
const BufferizationState &state) const {
|
||||
auto toMemrefOp = cast<bufferization::ToMemrefOp>(op);
|
||||
|
||||
// Fold to_memref(to_tensor(x)) to x.
|
||||
|
@ -86,11 +86,12 @@ struct ToTensorOpInterface
|
|||
: public BufferizableOpInterface::ExternalModel<ToTensorOpInterface,
|
||||
bufferization::ToTensorOp> {
|
||||
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
|
||||
BufferizationState &state) const {
|
||||
const BufferizationState &state) const {
|
||||
return success();
|
||||
}
|
||||
|
||||
bool isWritable(Operation *op, Value value, BufferizationState &state) const {
|
||||
bool isWritable(Operation *op, Value value,
|
||||
const BufferizationState &state) const {
|
||||
// It is unknown whether the MemRef operand is writable or not.
|
||||
return false;
|
||||
}
|
||||
|
|
|
@ -661,7 +661,7 @@ LogicalResult mlir::linalg::comprehensive_bufferize::runComprehensiveBufferize(
|
|||
|
||||
IRRewriter rewriter(op->getContext());
|
||||
DominanceInfo domInfo(op);
|
||||
BufferizationAliasInfo &aliasInfo = state.aliasInfo;
|
||||
BufferizationAliasInfo &aliasInfo = state.getAliasInfo();
|
||||
|
||||
if (failed(checkAliasInfoConsistency(op, domInfo, state, aliasInfo)))
|
||||
return failure();
|
||||
|
|
|
@ -24,7 +24,7 @@ namespace {
|
|||
|
||||
/// Generic conversion for any LinalgOp on tensors.
|
||||
static LogicalResult bufferizeLinalgOp(RewriterBase &rewriter, LinalgOp op,
|
||||
BufferizationState &state) {
|
||||
const BufferizationState &state) {
|
||||
// Take a guard before anything else.
|
||||
OpBuilder::InsertionGuard g(rewriter);
|
||||
rewriter.setInsertionPoint(op);
|
||||
|
@ -142,13 +142,13 @@ struct LinalgOpInterface
|
|||
: public BufferizableOpInterface::ExternalModel<LinalgOpInterface<OpTy>,
|
||||
OpTy> {
|
||||
bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
|
||||
BufferizationState &state) const {
|
||||
const BufferizationState &state) const {
|
||||
auto genericOp = cast<linalg::LinalgOp>(op);
|
||||
return genericOp.payloadUsesValueFromOperand(&opOperand);
|
||||
}
|
||||
|
||||
bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
|
||||
BufferizationState &state) const {
|
||||
const BufferizationState &state) const {
|
||||
auto bufferizableOp = cast<BufferizableOpInterface>(op);
|
||||
return static_cast<bool>(
|
||||
bufferizableOp.getAliasingOpResult(opOperand, state));
|
||||
|
@ -156,7 +156,7 @@ struct LinalgOpInterface
|
|||
|
||||
SmallVector<OpOperand *>
|
||||
getAliasingOpOperand(Operation *op, OpResult opResult,
|
||||
BufferizationState &state) const {
|
||||
const BufferizationState &state) const {
|
||||
auto genericOp = cast<linalg::LinalgOp>(op);
|
||||
DenseMap<OpOperand *, OpResult> pairs = computeAliasingPairs(genericOp);
|
||||
for (OpOperand *opOperand : genericOp.getInputAndOutputOperands())
|
||||
|
@ -166,7 +166,7 @@ struct LinalgOpInterface
|
|||
}
|
||||
|
||||
OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand,
|
||||
BufferizationState &state) const {
|
||||
const BufferizationState &state) const {
|
||||
auto genericOp = cast<linalg::LinalgOp>(op);
|
||||
DenseMap<OpOperand *, OpResult> pairs = computeAliasingPairs(genericOp);
|
||||
return pairs[&opOperand];
|
||||
|
@ -174,12 +174,12 @@ struct LinalgOpInterface
|
|||
|
||||
BufferRelation bufferRelation(Operation *op, OpResult opResult,
|
||||
const BufferizationAliasInfo &aliasInfo,
|
||||
BufferizationState &state) const {
|
||||
const BufferizationState &state) const {
|
||||
return BufferRelation::Equivalent;
|
||||
}
|
||||
|
||||
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
|
||||
BufferizationState &state) const {
|
||||
const BufferizationState &state) const {
|
||||
return bufferizeLinalgOp(rewriter, cast<LinalgOp>(op), state);
|
||||
}
|
||||
};
|
||||
|
@ -188,13 +188,13 @@ struct InitTensorOpInterface
|
|||
: public BufferizableOpInterface::ExternalModel<InitTensorOpInterface,
|
||||
linalg::InitTensorOp> {
|
||||
bool isMemoryWrite(Operation *op, OpResult opResult,
|
||||
BufferizationState &state) const {
|
||||
const BufferizationState &state) const {
|
||||
// InitTensorOps allocate but do not write.
|
||||
return false;
|
||||
}
|
||||
|
||||
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
|
||||
BufferizationState &state) const {
|
||||
const BufferizationState &state) const {
|
||||
auto initTensorOp = cast<linalg::InitTensorOp>(op);
|
||||
|
||||
// The InitTensorOp may have been eliminated.
|
||||
|
@ -212,7 +212,7 @@ struct TiledLoopOpInterface
|
|||
: public BufferizableOpInterface::ExternalModel<TiledLoopOpInterface,
|
||||
linalg::TiledLoopOp> {
|
||||
bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
|
||||
BufferizationState &state) const {
|
||||
const BufferizationState &state) const {
|
||||
// TiledLoop alone doesn't bufferize to a memory read, one of the uses of
|
||||
// its matching bbArg may.
|
||||
auto tiledLoopOp = cast<linalg::TiledLoopOp>(op);
|
||||
|
@ -220,7 +220,7 @@ struct TiledLoopOpInterface
|
|||
}
|
||||
|
||||
bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
|
||||
BufferizationState &state) const {
|
||||
const BufferizationState &state) const {
|
||||
// TiledLoop alone doesn't bufferize to a memory write, one of the uses of
|
||||
// its matching bbArg may.
|
||||
auto bufferizableOp = cast<BufferizableOpInterface>(op);
|
||||
|
@ -229,18 +229,19 @@ struct TiledLoopOpInterface
|
|||
}
|
||||
|
||||
OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand,
|
||||
BufferizationState &state) const {
|
||||
const BufferizationState &state) const {
|
||||
auto tiledLoopOp = cast<linalg::TiledLoopOp>(op);
|
||||
return tiledLoopOp.getTiedOpResult(opOperand);
|
||||
}
|
||||
|
||||
BufferRelation bufferRelation(Operation *op, OpResult opResult,
|
||||
const BufferizationAliasInfo &aliasInfo,
|
||||
BufferizationState &state) const {
|
||||
const BufferizationState &state) const {
|
||||
return BufferRelation::Equivalent;
|
||||
}
|
||||
|
||||
bool isWritable(Operation *op, Value value, BufferizationState &state) const {
|
||||
bool isWritable(Operation *op, Value value,
|
||||
const BufferizationState &state) const {
|
||||
// Interestingly, linalg::TiledLoopOp's bbArg can **always** be viewed
|
||||
// inplace from the perspective of ops nested under:
|
||||
// 1. Either the matching iter operand is not bufferized inplace and an
|
||||
|
@ -253,7 +254,7 @@ struct TiledLoopOpInterface
|
|||
bool isAllocationHoistingBarrier(Operation *op) const { return true; }
|
||||
|
||||
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
|
||||
BufferizationState &state) const {
|
||||
const BufferizationState &state) const {
|
||||
auto tiledLoopOp = cast<linalg::TiledLoopOp>(op);
|
||||
|
||||
// Compute new inputs, outputs and results.
|
||||
|
@ -355,22 +356,22 @@ struct YieldOpInterface
|
|||
: public BufferizableOpInterface::ExternalModel<YieldOpInterface,
|
||||
linalg::YieldOp> {
|
||||
bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
|
||||
BufferizationState &state) const {
|
||||
const BufferizationState &state) const {
|
||||
return true;
|
||||
}
|
||||
|
||||
bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
|
||||
BufferizationState &state) const {
|
||||
const BufferizationState &state) const {
|
||||
return false;
|
||||
}
|
||||
|
||||
OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand,
|
||||
BufferizationState &state) const {
|
||||
const BufferizationState &state) const {
|
||||
return OpResult();
|
||||
}
|
||||
|
||||
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
|
||||
BufferizationState &state) const {
|
||||
const BufferizationState &state) const {
|
||||
auto yieldOp = cast<linalg::YieldOp>(op);
|
||||
|
||||
if (!yieldOp->getParentOfType<TiledLoopOp>())
|
||||
|
|
|
@ -34,9 +34,20 @@ struct ModuleBufferizationState : public DialectBufferizationState {
|
|||
};
|
||||
} // namespace
|
||||
|
||||
/// Get ModuleBufferizationState.
|
||||
static const ModuleBufferizationState &
|
||||
getModuleBufferizationState(const BufferizationState &state) {
|
||||
Optional<const ModuleBufferizationState *> maybeState =
|
||||
state.getDialectState<ModuleBufferizationState>(
|
||||
StandardOpsDialect::getDialectNamespace());
|
||||
assert(maybeState.hasValue() && "ModuleBufferizationState does not exist");
|
||||
return **maybeState;
|
||||
}
|
||||
|
||||
/// Get or create ModuleBufferizationState.
|
||||
static ModuleBufferizationState &
|
||||
getModuleBufferizationState(BufferizationState &state) {
|
||||
return state.getDialectState<ModuleBufferizationState>(
|
||||
return state.getOrCreateDialectState<ModuleBufferizationState>(
|
||||
StandardOpsDialect::getDialectNamespace());
|
||||
}
|
||||
|
||||
|
@ -471,19 +482,25 @@ namespace std_ext {
|
|||
/// Return the index of the bbArg in the given FuncOp that is equivalent to the
|
||||
/// specified return value (if any).
|
||||
static Optional<int64_t>
|
||||
getEquivalentFuncArgIdx(FuncOp funcOp, ModuleBufferizationState &state,
|
||||
getEquivalentFuncArgIdx(FuncOp funcOp, const ModuleBufferizationState &state,
|
||||
int64_t returnValIdx) {
|
||||
if (!state.equivalentFuncArgs[funcOp].count(returnValIdx))
|
||||
if (!state.equivalentFuncArgs.count(funcOp))
|
||||
// No equivalence info stores for funcOp.
|
||||
return None;
|
||||
|
||||
const DenseMap<int64_t, int64_t> &equivFuncArgs =
|
||||
state.equivalentFuncArgs.lookup(funcOp);
|
||||
if (!equivFuncArgs.count(returnValIdx))
|
||||
// Return value has no equivalent bbArg.
|
||||
return None;
|
||||
|
||||
return state.equivalentFuncArgs[funcOp][returnValIdx];
|
||||
return equivFuncArgs.lookup(returnValIdx);
|
||||
}
|
||||
|
||||
struct CallOpInterface
|
||||
: public BufferizableOpInterface::ExternalModel<CallOpInterface, CallOp> {
|
||||
bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
|
||||
BufferizationState &state) const {
|
||||
const BufferizationState &state) const {
|
||||
// CallOpInterface alone doesn't bufferize to a memory read, one of the uses
|
||||
// of the matching bbArg may. It is the responsibility of the caller to
|
||||
// inspect bbArgs. In the absence of a BufferizationAliasInfo, we need to be
|
||||
|
@ -492,7 +509,7 @@ struct CallOpInterface
|
|||
}
|
||||
|
||||
OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand,
|
||||
BufferizationState &state) const {
|
||||
const BufferizationState &state) const {
|
||||
// CallOpInterface is special, it needs to wait for the callee to be
|
||||
// bufferized and needs to inspect the BufferAliasInfo object. It can't
|
||||
// make a proper determination by itself and needs to be conservative.
|
||||
|
@ -503,14 +520,15 @@ struct CallOpInterface
|
|||
/// marked inplaceable. For now, it is the responsibility of the `callOp`
|
||||
/// bufferization to allow FuncOp that are inplaceable to write inPlace.
|
||||
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
|
||||
BufferizationState &state) const {
|
||||
const BufferizationState &state) const {
|
||||
CallOp callOp = cast<CallOp>(op);
|
||||
unsigned numResults = callOp.getNumResults();
|
||||
unsigned numOperands = callOp->getNumOperands();
|
||||
FuncOp funcOp = getCalledFunction(callOp);
|
||||
assert(isa<CallOp>(callOp.getOperation()) && funcOp &&
|
||||
"expected CallOp to a FuncOp");
|
||||
ModuleBufferizationState &moduleState = getModuleBufferizationState(state);
|
||||
const ModuleBufferizationState &moduleState =
|
||||
getModuleBufferizationState(state);
|
||||
|
||||
// Result types of the bufferized CallOp.
|
||||
SmallVector<Type> resultTypes;
|
||||
|
@ -626,22 +644,22 @@ struct ReturnOpInterface
|
|||
: public BufferizableOpInterface::ExternalModel<ReturnOpInterface,
|
||||
ReturnOp> {
|
||||
bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
|
||||
BufferizationState &state) const {
|
||||
const BufferizationState &state) const {
|
||||
return true;
|
||||
}
|
||||
|
||||
bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
|
||||
BufferizationState &state) const {
|
||||
const BufferizationState &state) const {
|
||||
return false;
|
||||
}
|
||||
|
||||
OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand,
|
||||
BufferizationState &state) const {
|
||||
const BufferizationState &state) const {
|
||||
return OpResult();
|
||||
}
|
||||
|
||||
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
|
||||
BufferizationState &state) const {
|
||||
const BufferizationState &state) const {
|
||||
auto returnOp = cast<ReturnOp>(op);
|
||||
assert(isa<FuncOp>(returnOp->getParentOp()) &&
|
||||
"only support FuncOp parent for ReturnOp");
|
||||
|
@ -662,7 +680,7 @@ struct ReturnOpInterface
|
|||
struct FuncOpInterface
|
||||
: public BufferizableOpInterface::ExternalModel<FuncOpInterface, FuncOp> {
|
||||
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
|
||||
BufferizationState &state) const {
|
||||
const BufferizationState &state) const {
|
||||
auto funcOp = cast<FuncOp>(op);
|
||||
|
||||
// Bufferize function body.
|
||||
|
@ -670,11 +688,13 @@ struct FuncOpInterface
|
|||
}
|
||||
|
||||
/// Return `true` if the given function argument is writable.
|
||||
bool isWritable(Operation *op, Value value, BufferizationState &state) const {
|
||||
bool isWritable(Operation *op, Value value,
|
||||
const BufferizationState &state) const {
|
||||
auto funcOp = cast<FuncOp>(op);
|
||||
BlockArgument bbArg = value.dyn_cast<BlockArgument>();
|
||||
assert(bbArg && "expected BlockArgument");
|
||||
ModuleBufferizationState &moduleState = getModuleBufferizationState(state);
|
||||
const ModuleBufferizationState &moduleState =
|
||||
getModuleBufferizationState(state);
|
||||
|
||||
// In a first approximation:
|
||||
// =========================
|
||||
|
@ -720,8 +740,9 @@ static void setInPlaceFuncArgument(BlockArgument bbArg, bool inPlace) {
|
|||
}
|
||||
|
||||
/// Annotate the IR with the result of the analysis. For testing/debugging only.
|
||||
static void annotateOpsWithBufferizationMarkers(FuncOp funcOp,
|
||||
BufferizationState &state) {
|
||||
static void
|
||||
annotateOpsWithBufferizationMarkers(FuncOp funcOp,
|
||||
const BufferizationState &state) {
|
||||
auto bufferizableOp = cast<BufferizableOpInterface>(funcOp.getOperation());
|
||||
for (BlockArgument bbArg : funcOp.getArguments())
|
||||
if (bbArg.getType().isa<TensorType>())
|
||||
|
@ -733,7 +754,7 @@ LogicalResult mlir::linalg::comprehensive_bufferize::runComprehensiveBufferize(
|
|||
IRRewriter rewriter(moduleOp.getContext());
|
||||
BufferizationState state(moduleOp, *options);
|
||||
ModuleBufferizationState &moduleState = getModuleBufferizationState(state);
|
||||
BufferizationAliasInfo &aliasInfo = state.aliasInfo;
|
||||
BufferizationAliasInfo &aliasInfo = state.getAliasInfo();
|
||||
|
||||
if (failed(getFuncOpsOrderedByCalls(moduleOp, moduleState.orderedFuncOps,
|
||||
moduleState.callerMap)))
|
||||
|
|
|
@ -24,7 +24,7 @@ struct ExecuteRegionOpInterface
|
|||
scf::ExecuteRegionOp> {
|
||||
SmallVector<OpOperand *>
|
||||
getAliasingOpOperand(Operation *op, OpResult opResult,
|
||||
BufferizationState &state) const {
|
||||
const BufferizationState &state) const {
|
||||
// ExecuteRegionOps do not have tensor OpOperands. The yielded value can be
|
||||
// any SSA value that is in scope. To allow for use-def chain traversal
|
||||
// through ExecuteRegionOps in the analysis, the corresponding yield value
|
||||
|
@ -41,7 +41,7 @@ struct ExecuteRegionOpInterface
|
|||
}
|
||||
|
||||
bool mustBufferizeInPlace(Operation *op, OpResult opResult,
|
||||
BufferizationState &state) const {
|
||||
const BufferizationState &state) const {
|
||||
// ExecuteRegionOp results always bufferize in-place. Since they have no
|
||||
// OpOperands, they are mostly ignored by the analysis once alias sets are
|
||||
// set up.
|
||||
|
@ -51,7 +51,7 @@ struct ExecuteRegionOpInterface
|
|||
// TODO: For better bufferization results, this could return `true` only if
|
||||
// there is a memory write in the region.
|
||||
bool isMemoryWrite(Operation *op, OpResult opResult,
|
||||
BufferizationState &state) const {
|
||||
const BufferizationState &state) const {
|
||||
// Similar to scf.if, results of this op are always considered memory writes
|
||||
// in the analysis. This is a useful pattern for all ops that have tensor
|
||||
// OpResults but no tensor OpOperands. By default, `isMemoryWrite` is
|
||||
|
@ -61,7 +61,7 @@ struct ExecuteRegionOpInterface
|
|||
}
|
||||
|
||||
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
|
||||
BufferizationState &state) const {
|
||||
const BufferizationState &state) const {
|
||||
// TODO: Add bufferization support when needed. scf.execute_region should be
|
||||
// bufferized similar to scf.if.
|
||||
auto executeRegionOp = cast<scf::ExecuteRegionOp>(op);
|
||||
|
@ -76,7 +76,7 @@ struct ExecuteRegionOpInterface
|
|||
|
||||
BufferRelation bufferRelation(Operation *op, OpResult opResult,
|
||||
const BufferizationAliasInfo &aliasInfo,
|
||||
BufferizationState &state) const {
|
||||
const BufferizationState &state) const {
|
||||
return BufferRelation::Equivalent;
|
||||
}
|
||||
};
|
||||
|
@ -85,7 +85,7 @@ struct IfOpInterface
|
|||
: public BufferizableOpInterface::ExternalModel<IfOpInterface, scf::IfOp> {
|
||||
SmallVector<OpOperand *>
|
||||
getAliasingOpOperand(Operation *op, OpResult opResult,
|
||||
BufferizationState &state) const {
|
||||
const BufferizationState &state) const {
|
||||
// IfOps do not have tensor OpOperands. The yielded value can be any SSA
|
||||
// value that is in scope. To allow for use-def chain traversal through
|
||||
// IfOps in the analysis, both corresponding yield values from the then/else
|
||||
|
@ -102,7 +102,7 @@ struct IfOpInterface
|
|||
// allowed at the moment, we should never encounter scf.ifs that yield
|
||||
// unmodified tensors. Such scf.yield ops could just fold away.
|
||||
bool isMemoryWrite(Operation *op, OpResult opResult,
|
||||
BufferizationState &state) const {
|
||||
const BufferizationState &state) const {
|
||||
// IfOp results are always considered memory writes in the analysis. This
|
||||
// design decision simplifies the analysis considerably. E.g., consider the
|
||||
// following test case:
|
||||
|
@ -129,14 +129,14 @@ struct IfOpInterface
|
|||
}
|
||||
|
||||
bool mustBufferizeInPlace(Operation *op, OpResult opResult,
|
||||
BufferizationState &state) const {
|
||||
const BufferizationState &state) const {
|
||||
// IfOp results always bufferize in-place. Since they have no OpOperands,
|
||||
// they are mostly ignored by the analysis once alias sets are set up.
|
||||
return true;
|
||||
}
|
||||
|
||||
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
|
||||
BufferizationState &state) const {
|
||||
const BufferizationState &state) const {
|
||||
auto ifOp = cast<scf::IfOp>(op);
|
||||
|
||||
// Compute new types of the bufferized scf.if op.
|
||||
|
@ -209,7 +209,7 @@ struct IfOpInterface
|
|||
|
||||
BufferRelation bufferRelation(Operation *op, OpResult opResult,
|
||||
const BufferizationAliasInfo &aliasInfo,
|
||||
BufferizationState &state) const {
|
||||
const BufferizationState &state) const {
|
||||
// IfOp results are equivalent to their corresponding yield values if both
|
||||
// yield values are equivalent to each other.
|
||||
auto bufferizableOp = cast<BufferizableOpInterface>(op);
|
||||
|
@ -226,7 +226,7 @@ struct ForOpInterface
|
|||
: public BufferizableOpInterface::ExternalModel<ForOpInterface,
|
||||
scf::ForOp> {
|
||||
bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
|
||||
BufferizationState &state) const {
|
||||
const BufferizationState &state) const {
|
||||
// scf::ForOp alone doesn't bufferize to a memory read, one of the uses of
|
||||
// its matching bbArg may.
|
||||
auto forOp = cast<scf::ForOp>(op);
|
||||
|
@ -234,7 +234,7 @@ struct ForOpInterface
|
|||
}
|
||||
|
||||
bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
|
||||
BufferizationState &state) const {
|
||||
const BufferizationState &state) const {
|
||||
// Tensor iter_args of scf::ForOps are always considered as a write. This is
|
||||
// to simplify the analysis.
|
||||
// TODO: Consider doing sth. like isValueWritten.
|
||||
|
@ -242,7 +242,7 @@ struct ForOpInterface
|
|||
}
|
||||
|
||||
OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand,
|
||||
BufferizationState &state) const {
|
||||
const BufferizationState &state) const {
|
||||
auto forOp = cast<scf::ForOp>(op);
|
||||
if (!opOperand.get().getType().isa<RankedTensorType>())
|
||||
return OpResult();
|
||||
|
@ -251,7 +251,7 @@ struct ForOpInterface
|
|||
|
||||
BufferRelation bufferRelation(Operation *op, OpResult opResult,
|
||||
const BufferizationAliasInfo &aliasInfo,
|
||||
BufferizationState &state) const {
|
||||
const BufferizationState &state) const {
|
||||
// ForOp results are equivalent to their corresponding init_args if the
|
||||
// corresponding iter_args and yield values are equivalent.
|
||||
auto forOp = cast<scf::ForOp>(op);
|
||||
|
@ -263,7 +263,8 @@ struct ForOpInterface
|
|||
return equivalentYield ? BufferRelation::Equivalent : BufferRelation::None;
|
||||
}
|
||||
|
||||
bool isWritable(Operation *op, Value value, BufferizationState &state) const {
|
||||
bool isWritable(Operation *op, Value value,
|
||||
const BufferizationState &state) const {
|
||||
// Interestingly, scf::ForOp's bbArg can **always** be viewed
|
||||
// inplace from the perspective of ops nested under:
|
||||
// 1. Either the matching iter operand is not bufferized inplace and an
|
||||
|
@ -274,7 +275,7 @@ struct ForOpInterface
|
|||
}
|
||||
|
||||
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
|
||||
BufferizationState &state) const {
|
||||
const BufferizationState &state) const {
|
||||
auto forOp = cast<scf::ForOp>(op);
|
||||
Block *oldLoopBody = &forOp.getLoopBody().front();
|
||||
|
||||
|
@ -416,22 +417,22 @@ struct YieldOpInterface
|
|||
: public BufferizableOpInterface::ExternalModel<YieldOpInterface,
|
||||
scf::YieldOp> {
|
||||
bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
|
||||
BufferizationState &state) const {
|
||||
const BufferizationState &state) const {
|
||||
return true;
|
||||
}
|
||||
|
||||
bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
|
||||
BufferizationState &state) const {
|
||||
const BufferizationState &state) const {
|
||||
return false;
|
||||
}
|
||||
|
||||
OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand,
|
||||
BufferizationState &state) const {
|
||||
const BufferizationState &state) const {
|
||||
return OpResult();
|
||||
}
|
||||
|
||||
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
|
||||
BufferizationState &state) const {
|
||||
const BufferizationState &state) const {
|
||||
auto yieldOp = cast<scf::YieldOp>(op);
|
||||
if (!isa<scf::ExecuteRegionOp, scf::IfOp, scf::ForOp>(
|
||||
yieldOp->getParentOp()))
|
||||
|
|
|
@ -27,28 +27,28 @@ struct CastOpInterface
|
|||
: public BufferizableOpInterface::ExternalModel<CastOpInterface,
|
||||
tensor::CastOp> {
|
||||
bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
|
||||
BufferizationState &state) const {
|
||||
const BufferizationState &state) const {
|
||||
return false;
|
||||
}
|
||||
|
||||
bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
|
||||
BufferizationState &state) const {
|
||||
const BufferizationState &state) const {
|
||||
return false;
|
||||
}
|
||||
|
||||
OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand,
|
||||
BufferizationState &state) const {
|
||||
const BufferizationState &state) const {
|
||||
return op->getResult(0);
|
||||
}
|
||||
|
||||
BufferRelation bufferRelation(Operation *op, OpResult opResult,
|
||||
const BufferizationAliasInfo &aliasInfo,
|
||||
BufferizationState &state) const {
|
||||
const BufferizationState &state) const {
|
||||
return BufferRelation::Equivalent;
|
||||
}
|
||||
|
||||
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
|
||||
BufferizationState &state) const {
|
||||
const BufferizationState &state) const {
|
||||
auto castOp = cast<tensor::CastOp>(op);
|
||||
|
||||
Value resultBuffer = state.getResultBuffer(rewriter, castOp->getResult(0));
|
||||
|
@ -78,22 +78,22 @@ struct DimOpInterface
|
|||
: public BufferizableOpInterface::ExternalModel<DimOpInterface,
|
||||
tensor::DimOp> {
|
||||
bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
|
||||
BufferizationState &state) const {
|
||||
const BufferizationState &state) const {
|
||||
return true;
|
||||
}
|
||||
|
||||
bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
|
||||
BufferizationState &state) const {
|
||||
const BufferizationState &state) const {
|
||||
return false;
|
||||
}
|
||||
|
||||
OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand,
|
||||
BufferizationState &state) const {
|
||||
const BufferizationState &state) const {
|
||||
return OpResult();
|
||||
}
|
||||
|
||||
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
|
||||
BufferizationState &state) const {
|
||||
const BufferizationState &state) const {
|
||||
auto dimOp = cast<tensor::DimOp>(op);
|
||||
if (!dimOp.source().getType().isa<RankedTensorType>())
|
||||
return dimOp.emitError("unranked tensor not supported");
|
||||
|
@ -107,17 +107,17 @@ struct ExtractSliceOpInterface
|
|||
: public BufferizableOpInterface::ExternalModel<ExtractSliceOpInterface,
|
||||
tensor::ExtractSliceOp> {
|
||||
bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
|
||||
BufferizationState &state) const {
|
||||
const BufferizationState &state) const {
|
||||
return false;
|
||||
}
|
||||
|
||||
bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
|
||||
BufferizationState &state) const {
|
||||
const BufferizationState &state) const {
|
||||
return false;
|
||||
}
|
||||
|
||||
OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand,
|
||||
BufferizationState &state) const {
|
||||
const BufferizationState &state) const {
|
||||
return &opOperand == &op->getOpOperand(0) /*source*/
|
||||
? op->getResult(0)
|
||||
: OpResult();
|
||||
|
@ -125,12 +125,12 @@ struct ExtractSliceOpInterface
|
|||
|
||||
BufferRelation bufferRelation(Operation *op, OpResult opResult,
|
||||
const BufferizationAliasInfo &aliasInfo,
|
||||
BufferizationState &state) const {
|
||||
const BufferizationState &state) const {
|
||||
return BufferRelation::None;
|
||||
}
|
||||
|
||||
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
|
||||
BufferizationState &state) const {
|
||||
const BufferizationState &state) const {
|
||||
auto extractSliceOp = cast<tensor::ExtractSliceOp>(op);
|
||||
Location loc = extractSliceOp.getLoc();
|
||||
Value srcMemref = state.lookupBuffer(rewriter, extractSliceOp.source());
|
||||
|
@ -173,22 +173,22 @@ struct ExtractOpInterface
|
|||
: public BufferizableOpInterface::ExternalModel<ExtractOpInterface,
|
||||
tensor::ExtractOp> {
|
||||
bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
|
||||
BufferizationState &state) const {
|
||||
const BufferizationState &state) const {
|
||||
return true;
|
||||
}
|
||||
|
||||
bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
|
||||
BufferizationState &state) const {
|
||||
const BufferizationState &state) const {
|
||||
return false;
|
||||
}
|
||||
|
||||
OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand,
|
||||
BufferizationState &state) const {
|
||||
const BufferizationState &state) const {
|
||||
return OpResult();
|
||||
}
|
||||
|
||||
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
|
||||
BufferizationState &state) const {
|
||||
const BufferizationState &state) const {
|
||||
auto extractOp = cast<tensor::ExtractOp>(op);
|
||||
Value srcMemref = state.lookupBuffer(rewriter, extractOp.tensor());
|
||||
state.replaceOpWithNewOp<memref::LoadOp>(rewriter, op, srcMemref,
|
||||
|
@ -201,17 +201,17 @@ struct InsertOpInterface
|
|||
: public BufferizableOpInterface::ExternalModel<InsertOpInterface,
|
||||
tensor::InsertOp> {
|
||||
bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
|
||||
BufferizationState &state) const {
|
||||
const BufferizationState &state) const {
|
||||
return true;
|
||||
}
|
||||
|
||||
bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
|
||||
BufferizationState &state) const {
|
||||
const BufferizationState &state) const {
|
||||
return true;
|
||||
}
|
||||
|
||||
OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand,
|
||||
BufferizationState &state) const {
|
||||
const BufferizationState &state) const {
|
||||
assert(&opOperand == &op->getOpOperand(1) /*dest*/ &&
|
||||
"expected dest OpOperand");
|
||||
return op->getOpResult(0);
|
||||
|
@ -219,12 +219,12 @@ struct InsertOpInterface
|
|||
|
||||
SmallVector<OpOperand *>
|
||||
getAliasingOpOperand(Operation *op, OpResult opResult,
|
||||
BufferizationState &state) const {
|
||||
const BufferizationState &state) const {
|
||||
return {&op->getOpOperand(1) /*dest*/};
|
||||
}
|
||||
|
||||
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
|
||||
BufferizationState &state) const {
|
||||
const BufferizationState &state) const {
|
||||
auto insertOp = cast<tensor::InsertOp>(op);
|
||||
Location loc = insertOp.getLoc();
|
||||
Value destMemref =
|
||||
|
@ -237,7 +237,7 @@ struct InsertOpInterface
|
|||
|
||||
BufferRelation bufferRelation(Operation *op, OpResult opResult,
|
||||
const BufferizationAliasInfo &aliasInfo,
|
||||
BufferizationState &state) const {
|
||||
const BufferizationState &state) const {
|
||||
return BufferRelation::Equivalent;
|
||||
}
|
||||
};
|
||||
|
@ -263,8 +263,8 @@ areEquivalentExtractSliceOps(const BufferizationAliasInfo &aliasInfo,
|
|||
/// Return true if `value` is originating from an ExtractSliceOp that matches
|
||||
/// the given InsertSliceOp.
|
||||
static bool hasMatchingExtractSliceOp(const BufferizationAliasInfo &aliasInfo,
|
||||
BufferizationState &state, Value value,
|
||||
InsertSliceOp insertOp) {
|
||||
const BufferizationState &state,
|
||||
Value value, InsertSliceOp insertOp) {
|
||||
auto condition = [&](Value val) {
|
||||
if (auto extractOp = val.getDefiningOp<ExtractSliceOp>())
|
||||
if (areEquivalentExtractSliceOps(aliasInfo, extractOp, insertOp))
|
||||
|
@ -280,17 +280,17 @@ struct InsertSliceOpInterface
|
|||
: public BufferizableOpInterface::ExternalModel<InsertSliceOpInterface,
|
||||
tensor::InsertSliceOp> {
|
||||
bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
|
||||
BufferizationState &state) const {
|
||||
const BufferizationState &state) const {
|
||||
return true;
|
||||
}
|
||||
|
||||
bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
|
||||
BufferizationState &state) const {
|
||||
const BufferizationState &state) const {
|
||||
return &opOperand == &op->getOpOperand(1) /*dest*/;
|
||||
}
|
||||
|
||||
OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand,
|
||||
BufferizationState &state) const {
|
||||
const BufferizationState &state) const {
|
||||
return &opOperand == &op->getOpOperand(1) /*dest*/
|
||||
? op->getResult(0)
|
||||
: OpResult();
|
||||
|
@ -298,12 +298,13 @@ struct InsertSliceOpInterface
|
|||
|
||||
BufferRelation bufferRelation(Operation *op, OpResult opResult,
|
||||
const BufferizationAliasInfo &aliasInfo,
|
||||
BufferizationState &state) const {
|
||||
const BufferizationState &state) const {
|
||||
return BufferRelation::Equivalent;
|
||||
}
|
||||
|
||||
bool isNotConflicting(Operation *op, OpOperand *uRead,
|
||||
OpOperand *uConflictingWrite, BufferizationState &state,
|
||||
OpOperand *uConflictingWrite,
|
||||
const BufferizationState &state,
|
||||
const BufferizationAliasInfo &aliasInfo) const {
|
||||
Operation *readingOp = uRead->getOwner();
|
||||
Operation *conflictingWritingOp = uConflictingWrite->getOwner();
|
||||
|
@ -380,7 +381,7 @@ struct InsertSliceOpInterface
|
|||
}
|
||||
|
||||
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
|
||||
BufferizationState &state) const {
|
||||
const BufferizationState &state) const {
|
||||
// insert_slice ops arise from tiling and bufferizing them out-of-place is
|
||||
// generally a deal breaker. When used with loops, this ends up cloning the
|
||||
// whole tensor on every single iteration and is a symptom of a
|
||||
|
|
|
@ -21,26 +21,26 @@ struct TransferReadOpInterface
|
|||
: public BufferizableOpInterface::ExternalModel<TransferReadOpInterface,
|
||||
vector::TransferReadOp> {
|
||||
bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
|
||||
BufferizationState &state) const {
|
||||
const BufferizationState &state) const {
|
||||
assert(opOperand.get().getType().isa<RankedTensorType>() &&
|
||||
"only tensor types expected");
|
||||
return true;
|
||||
}
|
||||
|
||||
bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
|
||||
BufferizationState &state) const {
|
||||
const BufferizationState &state) const {
|
||||
assert(opOperand.get().getType().isa<RankedTensorType>() &&
|
||||
"only tensor types expected");
|
||||
return false;
|
||||
}
|
||||
|
||||
OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand,
|
||||
BufferizationState &state) const {
|
||||
const BufferizationState &state) const {
|
||||
return OpResult();
|
||||
}
|
||||
|
||||
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
|
||||
BufferizationState &state) const {
|
||||
const BufferizationState &state) const {
|
||||
auto readOp = cast<vector::TransferReadOp>(op);
|
||||
assert(readOp.getShapedType().isa<TensorType>() &&
|
||||
"only tensor types expected");
|
||||
|
@ -60,21 +60,21 @@ struct TransferWriteOpInterface
|
|||
: public BufferizableOpInterface::ExternalModel<TransferWriteOpInterface,
|
||||
vector::TransferWriteOp> {
|
||||
bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
|
||||
BufferizationState &state) const {
|
||||
const BufferizationState &state) const {
|
||||
assert(opOperand.get().getType().isa<TensorType>() &&
|
||||
"only tensor types expected");
|
||||
return true;
|
||||
}
|
||||
|
||||
bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
|
||||
BufferizationState &state) const {
|
||||
const BufferizationState &state) const {
|
||||
assert(opOperand.get().getType().isa<TensorType>() &&
|
||||
"only tensor types expected");
|
||||
return true;
|
||||
}
|
||||
|
||||
OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand,
|
||||
BufferizationState &state) const {
|
||||
const BufferizationState &state) const {
|
||||
assert(opOperand.get().getType().isa<TensorType>() &&
|
||||
"only tensor types expected");
|
||||
return op->getOpResult(0);
|
||||
|
@ -82,12 +82,12 @@ struct TransferWriteOpInterface
|
|||
|
||||
BufferRelation bufferRelation(Operation *op, OpResult opResult,
|
||||
const BufferizationAliasInfo &aliasInfo,
|
||||
BufferizationState &state) const {
|
||||
const BufferizationState &state) const {
|
||||
return BufferRelation::Equivalent;
|
||||
}
|
||||
|
||||
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
|
||||
BufferizationState &state) const {
|
||||
const BufferizationState &state) const {
|
||||
auto writeOp = cast<vector::TransferWriteOp>(op);
|
||||
assert(writeOp.getShapedType().isa<TensorType>() &&
|
||||
"only tensor types expected");
|
||||
|
|
Loading…
Reference in New Issue