forked from OSchip/llvm-project
[mlir][linalg][bufferize][NFC] Pass BufferizationState as const reference
This is mostly for documentation purposes: Passing the object as a const reference signifies that analysis decisions cannot be changed after the analysis. Differential Revision: https://reviews.llvm.org/D116742
This commit is contained in:
parent
f2277e60f4
commit
2975407bd4
|
@ -304,29 +304,29 @@ public:
|
||||||
|
|
||||||
/// Determine which OpOperand* will alias with `result` if the op is
|
/// Determine which OpOperand* will alias with `result` if the op is
|
||||||
/// bufferized in place. Return an empty vector if the op is not bufferizable.
|
/// bufferized in place. Return an empty vector if the op is not bufferizable.
|
||||||
SmallVector<OpOperand *> getAliasingOpOperand(OpResult result);
|
SmallVector<OpOperand *> getAliasingOpOperand(OpResult result) const;
|
||||||
|
|
||||||
/// Determine which OpResult will alias with `opOperand` if the op is
|
/// Determine which OpResult will alias with `opOperand` if the op is
|
||||||
/// bufferized in place. Return an empty OpResult if the op is not
|
/// bufferized in place. Return an empty OpResult if the op is not
|
||||||
/// bufferizable.
|
/// bufferizable.
|
||||||
OpResult getAliasingOpResult(OpOperand &opOperand);
|
OpResult getAliasingOpResult(OpOperand &opOperand) const;
|
||||||
|
|
||||||
/// Return true if `opOperand` bufferizes to a memory read. Return `true` if
|
/// Return true if `opOperand` bufferizes to a memory read. Return `true` if
|
||||||
/// the op is not bufferizable.
|
/// the op is not bufferizable.
|
||||||
bool bufferizesToMemoryRead(OpOperand &opOperand);
|
bool bufferizesToMemoryRead(OpOperand &opOperand) const;
|
||||||
|
|
||||||
/// Return true if `opOperand` bufferizes to a memory write. Return true` if
|
/// Return true if `opOperand` bufferizes to a memory write. Return true` if
|
||||||
/// the op is not bufferizable.
|
/// the op is not bufferizable.
|
||||||
bool bufferizesToMemoryWrite(OpOperand &opOperand);
|
bool bufferizesToMemoryWrite(OpOperand &opOperand) const;
|
||||||
|
|
||||||
/// Return true if `opOperand` does neither read nor write but bufferizes to
|
/// Return true if `opOperand` does neither read nor write but bufferizes to
|
||||||
/// an alias. Return false if the op is not bufferizable.
|
/// an alias. Return false if the op is not bufferizable.
|
||||||
bool bufferizesToAliasOnly(OpOperand &opOperand);
|
bool bufferizesToAliasOnly(OpOperand &opOperand) const;
|
||||||
|
|
||||||
/// Return true if the given value is read by an op that bufferizes to a
|
/// Return true if the given value is read by an op that bufferizes to a
|
||||||
/// memory read. Also takes into account ops that create an alias but do not
|
/// memory read. Also takes into account ops that create an alias but do not
|
||||||
/// read by themselves (e.g., ExtractSliceOp).
|
/// read by themselves (e.g., ExtractSliceOp).
|
||||||
bool isValueRead(Value value);
|
bool isValueRead(Value value) const;
|
||||||
|
|
||||||
/// Starting from `value`, follow the use-def chain in reverse, always
|
/// Starting from `value`, follow the use-def chain in reverse, always
|
||||||
/// selecting the aliasing OpOperands. Find and return Values for which
|
/// selecting the aliasing OpOperands. Find and return Values for which
|
||||||
|
@ -351,9 +351,8 @@ public:
|
||||||
/// In the above example, Values with a star satisfy the condition. When
|
/// In the above example, Values with a star satisfy the condition. When
|
||||||
/// starting the traversal from Value 1, the resulting SetVector is:
|
/// starting the traversal from Value 1, the resulting SetVector is:
|
||||||
/// { 2, 7, 8, 5 }
|
/// { 2, 7, 8, 5 }
|
||||||
llvm::SetVector<Value>
|
llvm::SetVector<Value> findValueInReverseUseDefChain(
|
||||||
findValueInReverseUseDefChain(Value value,
|
Value value, llvm::function_ref<bool(Value)> condition) const;
|
||||||
llvm::function_ref<bool(Value)> condition);
|
|
||||||
|
|
||||||
/// Find the Value of the last preceding write of a given Value.
|
/// Find the Value of the last preceding write of a given Value.
|
||||||
///
|
///
|
||||||
|
@ -363,33 +362,34 @@ public:
|
||||||
///
|
///
|
||||||
/// Note: When reaching an end of the reverse SSA use-def chain, that value
|
/// Note: When reaching an end of the reverse SSA use-def chain, that value
|
||||||
/// is returned regardless of whether it is a memory write or not.
|
/// is returned regardless of whether it is a memory write or not.
|
||||||
Value findLastPrecedingWrite(Value value);
|
Value findLastPrecedingWrite(Value value) const;
|
||||||
|
|
||||||
/// Creates a memref allocation.
|
/// Creates a memref allocation.
|
||||||
Optional<Value> createAlloc(OpBuilder &b, Location loc, MemRefType type,
|
Optional<Value> createAlloc(OpBuilder &b, Location loc, MemRefType type,
|
||||||
ArrayRef<Value> dynShape);
|
ArrayRef<Value> dynShape) const;
|
||||||
|
|
||||||
/// Creates an alloc-dealloc pair. This function may perform additional
|
/// Creates an alloc-dealloc pair. This function may perform additional
|
||||||
/// optimizations such as buffer allocation hoisting.
|
/// optimizations such as buffer allocation hoisting.
|
||||||
Value createAllocDeallocPair(OpBuilder &builder, Location loc,
|
Value createAllocDeallocPair(OpBuilder &builder, Location loc,
|
||||||
Value shapedValue);
|
Value shapedValue) const;
|
||||||
|
|
||||||
/// Creates a memref deallocation. The given memref buffer must have been
|
/// Creates a memref deallocation. The given memref buffer must have been
|
||||||
/// allocated using `createAlloc`.
|
/// allocated using `createAlloc`.
|
||||||
void createDealloc(OpBuilder &b, Location loc, Value allocatedBuffer);
|
void createDealloc(OpBuilder &b, Location loc, Value allocatedBuffer) const;
|
||||||
|
|
||||||
/// Creates a memcpy between two given buffers.
|
/// Creates a memcpy between two given buffers.
|
||||||
void createMemCpy(OpBuilder &b, Location loc, Value from, Value to);
|
void createMemCpy(OpBuilder &b, Location loc, Value from, Value to) const;
|
||||||
|
|
||||||
/// Replace an op with replacement values. The op is deleted. Tensor OpResults
|
/// Replace an op with replacement values. The op is deleted. Tensor OpResults
|
||||||
/// must be replaced with memref values.
|
/// must be replaced with memref values.
|
||||||
void replaceOp(RewriterBase &rewriter, Operation *op, ValueRange values);
|
void replaceOp(RewriterBase &rewriter, Operation *op,
|
||||||
|
ValueRange values) const;
|
||||||
|
|
||||||
/// Replace an op with a new op. Tensor OpResults must be replaced with memref
|
/// Replace an op with a new op. Tensor OpResults must be replaced with memref
|
||||||
/// values.
|
/// values.
|
||||||
template <typename OpTy, typename... Args>
|
template <typename OpTy, typename... Args>
|
||||||
OpTy replaceOpWithNewOp(RewriterBase &rewriter, Operation *op,
|
OpTy replaceOpWithNewOp(RewriterBase &rewriter, Operation *op,
|
||||||
Args &&...args) {
|
Args &&...args) const {
|
||||||
Operation *newOp =
|
Operation *newOp =
|
||||||
rewriter.create<OpTy>(op->getLoc(), std::forward<Args>(args)...);
|
rewriter.create<OpTy>(op->getLoc(), std::forward<Args>(args)...);
|
||||||
replaceOp(rewriter, op, newOp->getResults());
|
replaceOp(rewriter, op, newOp->getResults());
|
||||||
|
@ -398,7 +398,7 @@ public:
|
||||||
|
|
||||||
/// Lookup the memref buffer that is associated to the given tensor value.
|
/// Lookup the memref buffer that is associated to the given tensor value.
|
||||||
/// Asserts if no buffer is associated.
|
/// Asserts if no buffer is associated.
|
||||||
Value lookupBuffer(RewriterBase &rewriter, Value tensor);
|
Value lookupBuffer(RewriterBase &rewriter, Value tensor) const;
|
||||||
|
|
||||||
/// Return `true` if the given OpResult has been decided to bufferize inplace.
|
/// Return `true` if the given OpResult has been decided to bufferize inplace.
|
||||||
bool isInPlace(OpResult opResult) const;
|
bool isInPlace(OpResult opResult) const;
|
||||||
|
@ -406,10 +406,19 @@ public:
|
||||||
/// Return the result buffer (memref) for a given OpResult (tensor). Allocate
|
/// Return the result buffer (memref) for a given OpResult (tensor). Allocate
|
||||||
/// a new buffer and copy over data from the existing buffer if out-of-place
|
/// a new buffer and copy over data from the existing buffer if out-of-place
|
||||||
/// bufferization is necessary.
|
/// bufferization is necessary.
|
||||||
Value getResultBuffer(RewriterBase &rewriter, OpResult result);
|
Value getResultBuffer(RewriterBase &rewriter, OpResult result) const;
|
||||||
|
|
||||||
/// Return dialect-specific bufferization state.
|
/// Return dialect-specific bufferization state.
|
||||||
template <typename StateT> StateT &getDialectState(StringRef name) {
|
template <typename StateT>
|
||||||
|
Optional<const StateT *> getDialectState(StringRef name) const {
|
||||||
|
auto it = dialectState.find(name);
|
||||||
|
if (it == dialectState.end())
|
||||||
|
return None;
|
||||||
|
return static_cast<const StateT *>(it->getSecond().get());
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Return dialect-specific bufferization state or create one if none exists.
|
||||||
|
template <typename StateT> StateT &getOrCreateDialectState(StringRef name) {
|
||||||
// Create state if it does not exist yet.
|
// Create state if it does not exist yet.
|
||||||
if (!dialectState.count(name))
|
if (!dialectState.count(name))
|
||||||
dialectState[name] = std::make_unique<StateT>();
|
dialectState[name] = std::make_unique<StateT>();
|
||||||
|
@ -419,15 +428,10 @@ public:
|
||||||
/// Return a reference to the BufferizationOptions.
|
/// Return a reference to the BufferizationOptions.
|
||||||
const BufferizationOptions &getOptions() const { return options; }
|
const BufferizationOptions &getOptions() const { return options; }
|
||||||
|
|
||||||
|
/// Return a reference to the BufferizationAliasInfo.
|
||||||
|
BufferizationAliasInfo &getAliasInfo() { return aliasInfo; }
|
||||||
|
|
||||||
private:
|
private:
|
||||||
friend LogicalResult
|
|
||||||
runComprehensiveBufferize(Operation *op, const BufferizationOptions &options,
|
|
||||||
BufferizationState &state);
|
|
||||||
|
|
||||||
friend LogicalResult
|
|
||||||
runComprehensiveBufferize(ModuleOp moduleOp,
|
|
||||||
std::unique_ptr<BufferizationOptions> options);
|
|
||||||
|
|
||||||
/// `aliasInfo` keeps track of aliasing and equivalent values. Only internal
|
/// `aliasInfo` keeps track of aliasing and equivalent values. Only internal
|
||||||
/// functions and `runComprehensiveBufferize` may access this object.
|
/// functions and `runComprehensiveBufferize` may access this object.
|
||||||
BufferizationAliasInfo aliasInfo;
|
BufferizationAliasInfo aliasInfo;
|
||||||
|
@ -441,17 +445,17 @@ private:
|
||||||
|
|
||||||
/// Bufferize all ops in the given region.
|
/// Bufferize all ops in the given region.
|
||||||
LogicalResult bufferize(RewriterBase &rewriter, Region *region,
|
LogicalResult bufferize(RewriterBase &rewriter, Region *region,
|
||||||
BufferizationState &state);
|
const BufferizationState &state);
|
||||||
|
|
||||||
/// Bufferize all ops in the given block.
|
/// Bufferize all ops in the given block.
|
||||||
LogicalResult bufferize(RewriterBase &rewriter, Block *block,
|
LogicalResult bufferize(RewriterBase &rewriter, Block *block,
|
||||||
BufferizationState &state);
|
const BufferizationState &state);
|
||||||
|
|
||||||
/// Bufferize the given op. If the op has no tensor OpOperands/OpResults, this
|
/// Bufferize the given op. If the op has no tensor OpOperands/OpResults, this
|
||||||
/// function returns immediately. Otherwise, it calls the `bufferize` interface
|
/// function returns immediately. Otherwise, it calls the `bufferize` interface
|
||||||
/// method of `BufferizableOpInterface`.
|
/// method of `BufferizableOpInterface`.
|
||||||
LogicalResult bufferize(RewriterBase &rewriter, Operation *op,
|
LogicalResult bufferize(RewriterBase &rewriter, Operation *op,
|
||||||
BufferizationState &state);
|
const BufferizationState &state);
|
||||||
|
|
||||||
/// Return a contiguous MemRefType (i.e. with canonical/empty layout map)
|
/// Return a contiguous MemRefType (i.e. with canonical/empty layout map)
|
||||||
/// with the same shape as `shapedType` and specified `layout` and
|
/// with the same shape as `shapedType` and specified `layout` and
|
||||||
|
@ -492,38 +496,39 @@ struct AllocationHoistingBarrierOnly
|
||||||
: public BufferizableOpInterface::ExternalModel<
|
: public BufferizableOpInterface::ExternalModel<
|
||||||
AllocationHoistingBarrierOnly<OpTy>, OpTy> {
|
AllocationHoistingBarrierOnly<OpTy>, OpTy> {
|
||||||
bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
|
bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
|
||||||
BufferizationState &state) const {
|
const BufferizationState &state) const {
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
|
bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
|
||||||
BufferizationState &state) const {
|
const BufferizationState &state) const {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
SmallVector<OpOperand *>
|
SmallVector<OpOperand *>
|
||||||
getAliasingOpOperand(Operation *op, OpResult opResult,
|
getAliasingOpOperand(Operation *op, OpResult opResult,
|
||||||
BufferizationState &state) const {
|
const BufferizationState &state) const {
|
||||||
return {};
|
return {};
|
||||||
}
|
}
|
||||||
|
|
||||||
OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand,
|
OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand,
|
||||||
BufferizationState &state) const {
|
const BufferizationState &state) const {
|
||||||
return OpResult();
|
return OpResult();
|
||||||
}
|
}
|
||||||
|
|
||||||
BufferRelation bufferRelation(Operation *op, OpResult opResult,
|
BufferRelation bufferRelation(Operation *op, OpResult opResult,
|
||||||
const BufferizationAliasInfo &aliasInfo,
|
const BufferizationAliasInfo &aliasInfo,
|
||||||
BufferizationState &state) const {
|
const BufferizationState &state) const {
|
||||||
return BufferRelation::None;
|
return BufferRelation::None;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool isWritable(Operation *op, Value value, BufferizationState &state) const {
|
bool isWritable(Operation *op, Value value,
|
||||||
|
const BufferizationState &state) const {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
|
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
|
||||||
BufferizationState &state) const {
|
const BufferizationState &state) const {
|
||||||
auto isaTensor = [](Type t) { return t.isa<TensorType>(); };
|
auto isaTensor = [](Type t) { return t.isa<TensorType>(); };
|
||||||
if (any_of(op->getOperandTypes(), isaTensor) ||
|
if (any_of(op->getOperandTypes(), isaTensor) ||
|
||||||
any_of(op->getResultTypes(), isaTensor))
|
any_of(op->getResultTypes(), isaTensor))
|
||||||
|
|
|
@ -33,7 +33,7 @@ def BufferizableOpInterface : OpInterface<"BufferizableOpInterface"> {
|
||||||
/*retType=*/"bool",
|
/*retType=*/"bool",
|
||||||
/*methodName=*/"bufferizesToMemoryRead",
|
/*methodName=*/"bufferizesToMemoryRead",
|
||||||
/*args=*/(ins "OpOperand &":$opOperand,
|
/*args=*/(ins "OpOperand &":$opOperand,
|
||||||
"BufferizationState &":$state),
|
"const BufferizationState &":$state),
|
||||||
/*methodBody=*/"",
|
/*methodBody=*/"",
|
||||||
/*defaultImplementation=*/[{
|
/*defaultImplementation=*/[{
|
||||||
// Does not have to be implemented for ops without tensor OpOperands.
|
// Does not have to be implemented for ops without tensor OpOperands.
|
||||||
|
@ -62,7 +62,7 @@ def BufferizableOpInterface : OpInterface<"BufferizableOpInterface"> {
|
||||||
/*retType=*/"bool",
|
/*retType=*/"bool",
|
||||||
/*methodName=*/"bufferizesToMemoryWrite",
|
/*methodName=*/"bufferizesToMemoryWrite",
|
||||||
/*args=*/(ins "OpOperand &":$opOperand,
|
/*args=*/(ins "OpOperand &":$opOperand,
|
||||||
"BufferizationState &":$state),
|
"const BufferizationState &":$state),
|
||||||
/*methodBody=*/"",
|
/*methodBody=*/"",
|
||||||
/*defaultImplementation=*/[{
|
/*defaultImplementation=*/[{
|
||||||
// Does not have to be implemented for ops without tensor OpOperands.
|
// Does not have to be implemented for ops without tensor OpOperands.
|
||||||
|
@ -85,7 +85,7 @@ def BufferizableOpInterface : OpInterface<"BufferizableOpInterface"> {
|
||||||
/*retType=*/"bool",
|
/*retType=*/"bool",
|
||||||
/*methodName=*/"isMemoryWrite",
|
/*methodName=*/"isMemoryWrite",
|
||||||
/*args=*/(ins "OpResult":$opResult,
|
/*args=*/(ins "OpResult":$opResult,
|
||||||
"BufferizationState &":$state),
|
"const BufferizationState &":$state),
|
||||||
/*methodBody=*/"",
|
/*methodBody=*/"",
|
||||||
/*defaultImplementation=*/[{
|
/*defaultImplementation=*/[{
|
||||||
auto bufferizableOp =
|
auto bufferizableOp =
|
||||||
|
@ -116,7 +116,7 @@ def BufferizableOpInterface : OpInterface<"BufferizableOpInterface"> {
|
||||||
/*retType=*/"bool",
|
/*retType=*/"bool",
|
||||||
/*methodName=*/"mustBufferizeInPlace",
|
/*methodName=*/"mustBufferizeInPlace",
|
||||||
/*args=*/(ins "OpResult":$opResult,
|
/*args=*/(ins "OpResult":$opResult,
|
||||||
"BufferizationState &":$state),
|
"const BufferizationState &":$state),
|
||||||
/*methodBody=*/"",
|
/*methodBody=*/"",
|
||||||
/*defaultImplementation=*/[{
|
/*defaultImplementation=*/[{
|
||||||
return false;
|
return false;
|
||||||
|
@ -131,7 +131,7 @@ def BufferizableOpInterface : OpInterface<"BufferizableOpInterface"> {
|
||||||
/*retType=*/"OpResult",
|
/*retType=*/"OpResult",
|
||||||
/*methodName=*/"getAliasingOpResult",
|
/*methodName=*/"getAliasingOpResult",
|
||||||
/*args=*/(ins "OpOperand &":$opOperand,
|
/*args=*/(ins "OpOperand &":$opOperand,
|
||||||
"BufferizationState &":$state),
|
"const BufferizationState &":$state),
|
||||||
/*methodBody=*/"",
|
/*methodBody=*/"",
|
||||||
/*defaultImplementation=*/[{
|
/*defaultImplementation=*/[{
|
||||||
// Does not have to be implemented for ops without tensor OpOperands.
|
// Does not have to be implemented for ops without tensor OpOperands.
|
||||||
|
@ -155,7 +155,7 @@ def BufferizableOpInterface : OpInterface<"BufferizableOpInterface"> {
|
||||||
/*retType=*/"SmallVector<OpOperand *>",
|
/*retType=*/"SmallVector<OpOperand *>",
|
||||||
/*methodName=*/"getAliasingOpOperand",
|
/*methodName=*/"getAliasingOpOperand",
|
||||||
/*args=*/(ins "OpResult":$opResult,
|
/*args=*/(ins "OpResult":$opResult,
|
||||||
"BufferizationState &":$state),
|
"const BufferizationState &":$state),
|
||||||
/*methodBody=*/"",
|
/*methodBody=*/"",
|
||||||
/*defaultImplementation=*/[{
|
/*defaultImplementation=*/[{
|
||||||
assert(opResult.getType().isa<TensorType>() &&
|
assert(opResult.getType().isa<TensorType>() &&
|
||||||
|
@ -188,7 +188,7 @@ def BufferizableOpInterface : OpInterface<"BufferizableOpInterface"> {
|
||||||
/*methodName=*/"bufferRelation",
|
/*methodName=*/"bufferRelation",
|
||||||
/*args=*/(ins "OpResult":$opResult,
|
/*args=*/(ins "OpResult":$opResult,
|
||||||
"const BufferizationAliasInfo &":$aliasInfo,
|
"const BufferizationAliasInfo &":$aliasInfo,
|
||||||
"BufferizationState &":$state),
|
"const BufferizationState &":$state),
|
||||||
/*methodBody=*/"",
|
/*methodBody=*/"",
|
||||||
/*defaultImplementation=*/[{
|
/*defaultImplementation=*/[{
|
||||||
// Does not have to be implemented for ops without tensor OpResults
|
// Does not have to be implemented for ops without tensor OpResults
|
||||||
|
@ -210,7 +210,7 @@ def BufferizableOpInterface : OpInterface<"BufferizableOpInterface"> {
|
||||||
/*retType=*/"LogicalResult",
|
/*retType=*/"LogicalResult",
|
||||||
/*methodName=*/"bufferize",
|
/*methodName=*/"bufferize",
|
||||||
/*args=*/(ins "RewriterBase &":$rewriter,
|
/*args=*/(ins "RewriterBase &":$rewriter,
|
||||||
"BufferizationState &":$state),
|
"const BufferizationState &":$state),
|
||||||
/*methodBody=*/"",
|
/*methodBody=*/"",
|
||||||
/*defaultImplementation=*/[{
|
/*defaultImplementation=*/[{
|
||||||
llvm_unreachable("bufferize not implemented");
|
llvm_unreachable("bufferize not implemented");
|
||||||
|
@ -236,7 +236,7 @@ def BufferizableOpInterface : OpInterface<"BufferizableOpInterface"> {
|
||||||
/*retType=*/"bool",
|
/*retType=*/"bool",
|
||||||
/*methodName=*/"isWritable",
|
/*methodName=*/"isWritable",
|
||||||
/*args=*/(ins "Value":$value,
|
/*args=*/(ins "Value":$value,
|
||||||
"BufferizationState &":$state),
|
"const BufferizationState &":$state),
|
||||||
/*methodBody=*/"",
|
/*methodBody=*/"",
|
||||||
/*defaultImplementation=*/[{
|
/*defaultImplementation=*/[{
|
||||||
return value.isa<OpResult>();
|
return value.isa<OpResult>();
|
||||||
|
@ -275,7 +275,7 @@ def BufferizableOpInterface : OpInterface<"BufferizableOpInterface"> {
|
||||||
/*methodName=*/"isNotConflicting",
|
/*methodName=*/"isNotConflicting",
|
||||||
/*args=*/(ins "OpOperand *":$uRead,
|
/*args=*/(ins "OpOperand *":$uRead,
|
||||||
"OpOperand *":$uWrite,
|
"OpOperand *":$uWrite,
|
||||||
"BufferizationState &":$state,
|
"const BufferizationState &":$state,
|
||||||
"const BufferizationAliasInfo &":$aliasInfo),
|
"const BufferizationAliasInfo &":$aliasInfo),
|
||||||
/*methodBody=*/"",
|
/*methodBody=*/"",
|
||||||
/*defaultImplementation=*/[{
|
/*defaultImplementation=*/[{
|
||||||
|
@ -292,7 +292,7 @@ def BufferizableOpInterface : OpInterface<"BufferizableOpInterface"> {
|
||||||
///
|
///
|
||||||
/// Examples of such ops are `tensor.extract_slice` and `tensor.cast`.
|
/// Examples of such ops are `tensor.extract_slice` and `tensor.cast`.
|
||||||
bool bufferizesToAliasOnly(OpOperand &opOperand,
|
bool bufferizesToAliasOnly(OpOperand &opOperand,
|
||||||
BufferizationState &state) {
|
const BufferizationState &state) {
|
||||||
auto bufferizableOp =
|
auto bufferizableOp =
|
||||||
cast<BufferizableOpInterface>(getOperation());
|
cast<BufferizableOpInterface>(getOperation());
|
||||||
return !bufferizableOp.bufferizesToMemoryRead(opOperand, state)
|
return !bufferizableOp.bufferizesToMemoryRead(opOperand, state)
|
||||||
|
|
|
@ -24,7 +24,7 @@ struct ConstantOpInterface
|
||||||
: public BufferizableOpInterface::ExternalModel<ConstantOpInterface,
|
: public BufferizableOpInterface::ExternalModel<ConstantOpInterface,
|
||||||
arith::ConstantOp> {
|
arith::ConstantOp> {
|
||||||
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
|
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
|
||||||
BufferizationState &state) const {
|
const BufferizationState &state) const {
|
||||||
auto constantOp = cast<arith::ConstantOp>(op);
|
auto constantOp = cast<arith::ConstantOp>(op);
|
||||||
assert(constantOp.getType().dyn_cast<RankedTensorType>() &&
|
assert(constantOp.getType().dyn_cast<RankedTensorType>() &&
|
||||||
"not a constant ranked tensor");
|
"not a constant ranked tensor");
|
||||||
|
@ -40,7 +40,8 @@ struct ConstantOpInterface
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
bool isWritable(Operation *op, Value value, BufferizationState &state) const {
|
bool isWritable(Operation *op, Value value,
|
||||||
|
const BufferizationState &state) const {
|
||||||
// Memory locations returned by memref::GetGlobalOp may not be written to.
|
// Memory locations returned by memref::GetGlobalOp may not be written to.
|
||||||
assert(value.isa<OpResult>());
|
assert(value.isa<OpResult>());
|
||||||
return false;
|
return false;
|
||||||
|
|
|
@ -199,7 +199,7 @@ BufferizableOpInterface mlir::linalg::comprehensive_bufferize::
|
||||||
/// in place. Return an empty vector if the op is not bufferizable.
|
/// in place. Return an empty vector if the op is not bufferizable.
|
||||||
SmallVector<OpOperand *>
|
SmallVector<OpOperand *>
|
||||||
mlir::linalg::comprehensive_bufferize::BufferizationState::getAliasingOpOperand(
|
mlir::linalg::comprehensive_bufferize::BufferizationState::getAliasingOpOperand(
|
||||||
OpResult result) {
|
OpResult result) const {
|
||||||
if (Operation *op = result.getDefiningOp())
|
if (Operation *op = result.getDefiningOp())
|
||||||
if (auto bufferizableOp = dyn_cast<BufferizableOpInterface>(op))
|
if (auto bufferizableOp = dyn_cast<BufferizableOpInterface>(op))
|
||||||
return bufferizableOp.getAliasingOpOperand(result, *this);
|
return bufferizableOp.getAliasingOpOperand(result, *this);
|
||||||
|
@ -210,7 +210,7 @@ mlir::linalg::comprehensive_bufferize::BufferizationState::getAliasingOpOperand(
|
||||||
/// in place. Return an empty OpResult if the op is not bufferizable.
|
/// in place. Return an empty OpResult if the op is not bufferizable.
|
||||||
OpResult
|
OpResult
|
||||||
mlir::linalg::comprehensive_bufferize::BufferizationState::getAliasingOpResult(
|
mlir::linalg::comprehensive_bufferize::BufferizationState::getAliasingOpResult(
|
||||||
OpOperand &opOperand) {
|
OpOperand &opOperand) const {
|
||||||
if (auto bufferizableOp =
|
if (auto bufferizableOp =
|
||||||
dyn_cast<BufferizableOpInterface>(opOperand.getOwner()))
|
dyn_cast<BufferizableOpInterface>(opOperand.getOwner()))
|
||||||
return bufferizableOp.getAliasingOpResult(opOperand, *this);
|
return bufferizableOp.getAliasingOpResult(opOperand, *this);
|
||||||
|
@ -220,7 +220,7 @@ mlir::linalg::comprehensive_bufferize::BufferizationState::getAliasingOpResult(
|
||||||
/// Return true if `opOperand` bufferizes to a memory read. Return `true` if the
|
/// Return true if `opOperand` bufferizes to a memory read. Return `true` if the
|
||||||
/// op is not bufferizable.
|
/// op is not bufferizable.
|
||||||
bool mlir::linalg::comprehensive_bufferize::BufferizationState::
|
bool mlir::linalg::comprehensive_bufferize::BufferizationState::
|
||||||
bufferizesToMemoryRead(OpOperand &opOperand) {
|
bufferizesToMemoryRead(OpOperand &opOperand) const {
|
||||||
if (auto bufferizableOp =
|
if (auto bufferizableOp =
|
||||||
dyn_cast<BufferizableOpInterface>(opOperand.getOwner()))
|
dyn_cast<BufferizableOpInterface>(opOperand.getOwner()))
|
||||||
return bufferizableOp.bufferizesToMemoryRead(opOperand, *this);
|
return bufferizableOp.bufferizesToMemoryRead(opOperand, *this);
|
||||||
|
@ -233,7 +233,7 @@ bool mlir::linalg::comprehensive_bufferize::BufferizationState::
|
||||||
/// Return true if `opOperand` bufferizes to a memory write. Return
|
/// Return true if `opOperand` bufferizes to a memory write. Return
|
||||||
/// `true` if the op is not bufferizable.
|
/// `true` if the op is not bufferizable.
|
||||||
bool mlir::linalg::comprehensive_bufferize::BufferizationState::
|
bool mlir::linalg::comprehensive_bufferize::BufferizationState::
|
||||||
bufferizesToMemoryWrite(OpOperand &opOperand) {
|
bufferizesToMemoryWrite(OpOperand &opOperand) const {
|
||||||
if (auto bufferizableOp =
|
if (auto bufferizableOp =
|
||||||
dyn_cast<BufferizableOpInterface>(opOperand.getOwner()))
|
dyn_cast<BufferizableOpInterface>(opOperand.getOwner()))
|
||||||
return bufferizableOp.bufferizesToMemoryWrite(opOperand, *this);
|
return bufferizableOp.bufferizesToMemoryWrite(opOperand, *this);
|
||||||
|
@ -246,7 +246,7 @@ bool mlir::linalg::comprehensive_bufferize::BufferizationState::
|
||||||
/// Return true if `opOperand` does neither read nor write but bufferizes to an
|
/// Return true if `opOperand` does neither read nor write but bufferizes to an
|
||||||
/// alias. Return false if the op is not bufferizable.
|
/// alias. Return false if the op is not bufferizable.
|
||||||
bool mlir::linalg::comprehensive_bufferize::BufferizationState::
|
bool mlir::linalg::comprehensive_bufferize::BufferizationState::
|
||||||
bufferizesToAliasOnly(OpOperand &opOperand) {
|
bufferizesToAliasOnly(OpOperand &opOperand) const {
|
||||||
if (auto bufferizableOp =
|
if (auto bufferizableOp =
|
||||||
dyn_cast<BufferizableOpInterface>(opOperand.getOwner()))
|
dyn_cast<BufferizableOpInterface>(opOperand.getOwner()))
|
||||||
return bufferizableOp.bufferizesToAliasOnly(opOperand, *this);
|
return bufferizableOp.bufferizesToAliasOnly(opOperand, *this);
|
||||||
|
@ -260,7 +260,7 @@ bool mlir::linalg::comprehensive_bufferize::BufferizationState::
|
||||||
/// read. Also takes into account ops that create an alias but do not read by
|
/// read. Also takes into account ops that create an alias but do not read by
|
||||||
/// themselves (e.g., ExtractSliceOp).
|
/// themselves (e.g., ExtractSliceOp).
|
||||||
bool mlir::linalg::comprehensive_bufferize::BufferizationState::isValueRead(
|
bool mlir::linalg::comprehensive_bufferize::BufferizationState::isValueRead(
|
||||||
Value value) {
|
Value value) const {
|
||||||
SmallVector<OpOperand *> workingSet;
|
SmallVector<OpOperand *> workingSet;
|
||||||
for (OpOperand &use : value.getUses())
|
for (OpOperand &use : value.getUses())
|
||||||
workingSet.push_back(&use);
|
workingSet.push_back(&use);
|
||||||
|
@ -282,10 +282,9 @@ bool mlir::linalg::comprehensive_bufferize::BufferizationState::isValueRead(
|
||||||
// the aliasing OpOperands. Find and return Values for which `condition`
|
// the aliasing OpOperands. Find and return Values for which `condition`
|
||||||
// evaluates to true. OpOperands of such matching Values are not traversed any
|
// evaluates to true. OpOperands of such matching Values are not traversed any
|
||||||
// further.
|
// further.
|
||||||
llvm::SetVector<Value>
|
llvm::SetVector<Value> mlir::linalg::comprehensive_bufferize::
|
||||||
mlir::linalg::comprehensive_bufferize::BufferizationState::
|
BufferizationState::findValueInReverseUseDefChain(
|
||||||
findValueInReverseUseDefChain(Value value,
|
Value value, llvm::function_ref<bool(Value)> condition) const {
|
||||||
llvm::function_ref<bool(Value)> condition) {
|
|
||||||
llvm::SetVector<Value> result, workingSet;
|
llvm::SetVector<Value> result, workingSet;
|
||||||
workingSet.insert(value);
|
workingSet.insert(value);
|
||||||
|
|
||||||
|
@ -312,7 +311,7 @@ mlir::linalg::comprehensive_bufferize::BufferizationState::
|
||||||
|
|
||||||
// Find the Value of the last preceding write of a given Value.
|
// Find the Value of the last preceding write of a given Value.
|
||||||
Value mlir::linalg::comprehensive_bufferize::BufferizationState::
|
Value mlir::linalg::comprehensive_bufferize::BufferizationState::
|
||||||
findLastPrecedingWrite(Value value) {
|
findLastPrecedingWrite(Value value) const {
|
||||||
SetVector<Value> result =
|
SetVector<Value> result =
|
||||||
findValueInReverseUseDefChain(value, [&](Value value) {
|
findValueInReverseUseDefChain(value, [&](Value value) {
|
||||||
Operation *op = value.getDefiningOp();
|
Operation *op = value.getDefiningOp();
|
||||||
|
@ -360,7 +359,7 @@ mlir::linalg::comprehensive_bufferize::BufferizationState::BufferizationState(
|
||||||
/// a new buffer and copy over data from the existing buffer if out-of-place
|
/// a new buffer and copy over data from the existing buffer if out-of-place
|
||||||
/// bufferization is necessary.
|
/// bufferization is necessary.
|
||||||
Value mlir::linalg::comprehensive_bufferize::BufferizationState::
|
Value mlir::linalg::comprehensive_bufferize::BufferizationState::
|
||||||
getResultBuffer(RewriterBase &rewriter, OpResult result) {
|
getResultBuffer(RewriterBase &rewriter, OpResult result) const {
|
||||||
OpBuilder::InsertionGuard guard(rewriter);
|
OpBuilder::InsertionGuard guard(rewriter);
|
||||||
Operation *op = result.getOwner();
|
Operation *op = result.getOwner();
|
||||||
SmallVector<OpOperand *> aliasingOperands = getAliasingOpOperand(result);
|
SmallVector<OpOperand *> aliasingOperands = getAliasingOpOperand(result);
|
||||||
|
@ -424,7 +423,7 @@ Value mlir::linalg::comprehensive_bufferize::BufferizationState::
|
||||||
}
|
}
|
||||||
|
|
||||||
void mlir::linalg::comprehensive_bufferize::BufferizationState::replaceOp(
|
void mlir::linalg::comprehensive_bufferize::BufferizationState::replaceOp(
|
||||||
RewriterBase &rewriter, Operation *op, ValueRange values) {
|
RewriterBase &rewriter, Operation *op, ValueRange values) const {
|
||||||
OpBuilder::InsertionGuard g(rewriter);
|
OpBuilder::InsertionGuard g(rewriter);
|
||||||
|
|
||||||
// Replace all OpResults with the given values.
|
// Replace all OpResults with the given values.
|
||||||
|
@ -454,7 +453,7 @@ void mlir::linalg::comprehensive_bufferize::BufferizationState::replaceOp(
|
||||||
}
|
}
|
||||||
|
|
||||||
LogicalResult mlir::linalg::comprehensive_bufferize::bufferize(
|
LogicalResult mlir::linalg::comprehensive_bufferize::bufferize(
|
||||||
RewriterBase &rewriter, Region *region, BufferizationState &state) {
|
RewriterBase &rewriter, Region *region, const BufferizationState &state) {
|
||||||
for (Block &block : *region)
|
for (Block &block : *region)
|
||||||
if (failed(bufferize(rewriter, &block, state)))
|
if (failed(bufferize(rewriter, &block, state)))
|
||||||
return failure();
|
return failure();
|
||||||
|
@ -462,7 +461,7 @@ LogicalResult mlir::linalg::comprehensive_bufferize::bufferize(
|
||||||
}
|
}
|
||||||
|
|
||||||
LogicalResult mlir::linalg::comprehensive_bufferize::bufferize(
|
LogicalResult mlir::linalg::comprehensive_bufferize::bufferize(
|
||||||
RewriterBase &rewriter, Block *block, BufferizationState &state) {
|
RewriterBase &rewriter, Block *block, const BufferizationState &state) {
|
||||||
// Ops may get deleted during the traversal, so do not iterate over `block`
|
// Ops may get deleted during the traversal, so do not iterate over `block`
|
||||||
// directly.
|
// directly.
|
||||||
SmallVector<Operation *> ops;
|
SmallVector<Operation *> ops;
|
||||||
|
@ -476,7 +475,7 @@ LogicalResult mlir::linalg::comprehensive_bufferize::bufferize(
|
||||||
}
|
}
|
||||||
|
|
||||||
LogicalResult mlir::linalg::comprehensive_bufferize::bufferize(
|
LogicalResult mlir::linalg::comprehensive_bufferize::bufferize(
|
||||||
RewriterBase &rewriter, Operation *op, BufferizationState &state) {
|
RewriterBase &rewriter, Operation *op, const BufferizationState &state) {
|
||||||
// Check if op has tensor results or operands.
|
// Check if op has tensor results or operands.
|
||||||
auto isaTensor = [](Type t) { return t.isa<TensorType>(); };
|
auto isaTensor = [](Type t) { return t.isa<TensorType>(); };
|
||||||
bool hasTensorResult = any_of(op->getResultTypes(), isaTensor);
|
bool hasTensorResult = any_of(op->getResultTypes(), isaTensor);
|
||||||
|
@ -592,7 +591,8 @@ static MemRefType getAllocationTypeAndShape(OpBuilder &b, Location loc,
|
||||||
/// `shapedValue.getDefiningOp` (or at the top of the block in case of a
|
/// `shapedValue.getDefiningOp` (or at the top of the block in case of a
|
||||||
/// bbArg) and the DeallocOp is at the end of the block.
|
/// bbArg) and the DeallocOp is at the end of the block.
|
||||||
Value mlir::linalg::comprehensive_bufferize::BufferizationState::
|
Value mlir::linalg::comprehensive_bufferize::BufferizationState::
|
||||||
createAllocDeallocPair(OpBuilder &b, Location loc, Value shapedValue) {
|
createAllocDeallocPair(OpBuilder &b, Location loc,
|
||||||
|
Value shapedValue) const {
|
||||||
// Take a guard before anything else.
|
// Take a guard before anything else.
|
||||||
OpBuilder::InsertionGuard g(b);
|
OpBuilder::InsertionGuard g(b);
|
||||||
|
|
||||||
|
@ -621,19 +621,20 @@ Value mlir::linalg::comprehensive_bufferize::BufferizationState::
|
||||||
/// Create a memref allocation.
|
/// Create a memref allocation.
|
||||||
Optional<Value>
|
Optional<Value>
|
||||||
mlir::linalg::comprehensive_bufferize::BufferizationState::createAlloc(
|
mlir::linalg::comprehensive_bufferize::BufferizationState::createAlloc(
|
||||||
OpBuilder &b, Location loc, MemRefType type, ArrayRef<Value> dynShape) {
|
OpBuilder &b, Location loc, MemRefType type,
|
||||||
|
ArrayRef<Value> dynShape) const {
|
||||||
return options.allocationFns->allocationFn(b, loc, type, dynShape);
|
return options.allocationFns->allocationFn(b, loc, type, dynShape);
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Create a memref deallocation.
|
/// Create a memref deallocation.
|
||||||
void mlir::linalg::comprehensive_bufferize::BufferizationState::createDealloc(
|
void mlir::linalg::comprehensive_bufferize::BufferizationState::createDealloc(
|
||||||
OpBuilder &b, Location loc, Value allocatedBuffer) {
|
OpBuilder &b, Location loc, Value allocatedBuffer) const {
|
||||||
return options.allocationFns->deallocationFn(b, loc, allocatedBuffer);
|
return options.allocationFns->deallocationFn(b, loc, allocatedBuffer);
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Create a memory copy between two memref buffers.
|
/// Create a memory copy between two memref buffers.
|
||||||
void mlir::linalg::comprehensive_bufferize::BufferizationState::createMemCpy(
|
void mlir::linalg::comprehensive_bufferize::BufferizationState::createMemCpy(
|
||||||
OpBuilder &b, Location loc, Value from, Value to) {
|
OpBuilder &b, Location loc, Value from, Value to) const {
|
||||||
return options.allocationFns->memCpyFn(b, loc, from, to);
|
return options.allocationFns->memCpyFn(b, loc, from, to);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -649,7 +650,7 @@ bool mlir::linalg::comprehensive_bufferize::isFunctionArgument(Value value) {
|
||||||
}
|
}
|
||||||
|
|
||||||
Value mlir::linalg::comprehensive_bufferize::BufferizationState::lookupBuffer(
|
Value mlir::linalg::comprehensive_bufferize::BufferizationState::lookupBuffer(
|
||||||
RewriterBase &rewriter, Value tensor) {
|
RewriterBase &rewriter, Value tensor) const {
|
||||||
assert(tensor.getType().isa<TensorType>() && "unexpected non-tensor type");
|
assert(tensor.getType().isa<TensorType>() && "unexpected non-tensor type");
|
||||||
|
|
||||||
// Replace "%t = to_tensor %m" with %m.
|
// Replace "%t = to_tensor %m" with %m.
|
||||||
|
|
|
@ -40,18 +40,18 @@ struct ToMemrefOpInterface
|
||||||
: public BufferizableOpInterface::ExternalModel<ToMemrefOpInterface,
|
: public BufferizableOpInterface::ExternalModel<ToMemrefOpInterface,
|
||||||
bufferization::ToMemrefOp> {
|
bufferization::ToMemrefOp> {
|
||||||
bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
|
bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
|
||||||
BufferizationState &state) const {
|
const BufferizationState &state) const {
|
||||||
// It is unknown whether the resulting MemRef will be read or not.
|
// It is unknown whether the resulting MemRef will be read or not.
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand,
|
OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand,
|
||||||
BufferizationState &state) const {
|
const BufferizationState &state) const {
|
||||||
return OpResult();
|
return OpResult();
|
||||||
}
|
}
|
||||||
|
|
||||||
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
|
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
|
||||||
BufferizationState &state) const {
|
const BufferizationState &state) const {
|
||||||
auto toMemrefOp = cast<bufferization::ToMemrefOp>(op);
|
auto toMemrefOp = cast<bufferization::ToMemrefOp>(op);
|
||||||
|
|
||||||
// Fold to_memref(to_tensor(x)) to x.
|
// Fold to_memref(to_tensor(x)) to x.
|
||||||
|
@ -86,11 +86,12 @@ struct ToTensorOpInterface
|
||||||
: public BufferizableOpInterface::ExternalModel<ToTensorOpInterface,
|
: public BufferizableOpInterface::ExternalModel<ToTensorOpInterface,
|
||||||
bufferization::ToTensorOp> {
|
bufferization::ToTensorOp> {
|
||||||
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
|
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
|
||||||
BufferizationState &state) const {
|
const BufferizationState &state) const {
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
bool isWritable(Operation *op, Value value, BufferizationState &state) const {
|
bool isWritable(Operation *op, Value value,
|
||||||
|
const BufferizationState &state) const {
|
||||||
// It is unknown whether the MemRef operand is writable or not.
|
// It is unknown whether the MemRef operand is writable or not.
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
|
@ -661,7 +661,7 @@ LogicalResult mlir::linalg::comprehensive_bufferize::runComprehensiveBufferize(
|
||||||
|
|
||||||
IRRewriter rewriter(op->getContext());
|
IRRewriter rewriter(op->getContext());
|
||||||
DominanceInfo domInfo(op);
|
DominanceInfo domInfo(op);
|
||||||
BufferizationAliasInfo &aliasInfo = state.aliasInfo;
|
BufferizationAliasInfo &aliasInfo = state.getAliasInfo();
|
||||||
|
|
||||||
if (failed(checkAliasInfoConsistency(op, domInfo, state, aliasInfo)))
|
if (failed(checkAliasInfoConsistency(op, domInfo, state, aliasInfo)))
|
||||||
return failure();
|
return failure();
|
||||||
|
|
|
@ -24,7 +24,7 @@ namespace {
|
||||||
|
|
||||||
/// Generic conversion for any LinalgOp on tensors.
|
/// Generic conversion for any LinalgOp on tensors.
|
||||||
static LogicalResult bufferizeLinalgOp(RewriterBase &rewriter, LinalgOp op,
|
static LogicalResult bufferizeLinalgOp(RewriterBase &rewriter, LinalgOp op,
|
||||||
BufferizationState &state) {
|
const BufferizationState &state) {
|
||||||
// Take a guard before anything else.
|
// Take a guard before anything else.
|
||||||
OpBuilder::InsertionGuard g(rewriter);
|
OpBuilder::InsertionGuard g(rewriter);
|
||||||
rewriter.setInsertionPoint(op);
|
rewriter.setInsertionPoint(op);
|
||||||
|
@ -142,13 +142,13 @@ struct LinalgOpInterface
|
||||||
: public BufferizableOpInterface::ExternalModel<LinalgOpInterface<OpTy>,
|
: public BufferizableOpInterface::ExternalModel<LinalgOpInterface<OpTy>,
|
||||||
OpTy> {
|
OpTy> {
|
||||||
bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
|
bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
|
||||||
BufferizationState &state) const {
|
const BufferizationState &state) const {
|
||||||
auto genericOp = cast<linalg::LinalgOp>(op);
|
auto genericOp = cast<linalg::LinalgOp>(op);
|
||||||
return genericOp.payloadUsesValueFromOperand(&opOperand);
|
return genericOp.payloadUsesValueFromOperand(&opOperand);
|
||||||
}
|
}
|
||||||
|
|
||||||
bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
|
bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
|
||||||
BufferizationState &state) const {
|
const BufferizationState &state) const {
|
||||||
auto bufferizableOp = cast<BufferizableOpInterface>(op);
|
auto bufferizableOp = cast<BufferizableOpInterface>(op);
|
||||||
return static_cast<bool>(
|
return static_cast<bool>(
|
||||||
bufferizableOp.getAliasingOpResult(opOperand, state));
|
bufferizableOp.getAliasingOpResult(opOperand, state));
|
||||||
|
@ -156,7 +156,7 @@ struct LinalgOpInterface
|
||||||
|
|
||||||
SmallVector<OpOperand *>
|
SmallVector<OpOperand *>
|
||||||
getAliasingOpOperand(Operation *op, OpResult opResult,
|
getAliasingOpOperand(Operation *op, OpResult opResult,
|
||||||
BufferizationState &state) const {
|
const BufferizationState &state) const {
|
||||||
auto genericOp = cast<linalg::LinalgOp>(op);
|
auto genericOp = cast<linalg::LinalgOp>(op);
|
||||||
DenseMap<OpOperand *, OpResult> pairs = computeAliasingPairs(genericOp);
|
DenseMap<OpOperand *, OpResult> pairs = computeAliasingPairs(genericOp);
|
||||||
for (OpOperand *opOperand : genericOp.getInputAndOutputOperands())
|
for (OpOperand *opOperand : genericOp.getInputAndOutputOperands())
|
||||||
|
@ -166,7 +166,7 @@ struct LinalgOpInterface
|
||||||
}
|
}
|
||||||
|
|
||||||
OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand,
|
OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand,
|
||||||
BufferizationState &state) const {
|
const BufferizationState &state) const {
|
||||||
auto genericOp = cast<linalg::LinalgOp>(op);
|
auto genericOp = cast<linalg::LinalgOp>(op);
|
||||||
DenseMap<OpOperand *, OpResult> pairs = computeAliasingPairs(genericOp);
|
DenseMap<OpOperand *, OpResult> pairs = computeAliasingPairs(genericOp);
|
||||||
return pairs[&opOperand];
|
return pairs[&opOperand];
|
||||||
|
@ -174,12 +174,12 @@ struct LinalgOpInterface
|
||||||
|
|
||||||
BufferRelation bufferRelation(Operation *op, OpResult opResult,
|
BufferRelation bufferRelation(Operation *op, OpResult opResult,
|
||||||
const BufferizationAliasInfo &aliasInfo,
|
const BufferizationAliasInfo &aliasInfo,
|
||||||
BufferizationState &state) const {
|
const BufferizationState &state) const {
|
||||||
return BufferRelation::Equivalent;
|
return BufferRelation::Equivalent;
|
||||||
}
|
}
|
||||||
|
|
||||||
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
|
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
|
||||||
BufferizationState &state) const {
|
const BufferizationState &state) const {
|
||||||
return bufferizeLinalgOp(rewriter, cast<LinalgOp>(op), state);
|
return bufferizeLinalgOp(rewriter, cast<LinalgOp>(op), state);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
@ -188,13 +188,13 @@ struct InitTensorOpInterface
|
||||||
: public BufferizableOpInterface::ExternalModel<InitTensorOpInterface,
|
: public BufferizableOpInterface::ExternalModel<InitTensorOpInterface,
|
||||||
linalg::InitTensorOp> {
|
linalg::InitTensorOp> {
|
||||||
bool isMemoryWrite(Operation *op, OpResult opResult,
|
bool isMemoryWrite(Operation *op, OpResult opResult,
|
||||||
BufferizationState &state) const {
|
const BufferizationState &state) const {
|
||||||
// InitTensorOps allocate but do not write.
|
// InitTensorOps allocate but do not write.
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
|
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
|
||||||
BufferizationState &state) const {
|
const BufferizationState &state) const {
|
||||||
auto initTensorOp = cast<linalg::InitTensorOp>(op);
|
auto initTensorOp = cast<linalg::InitTensorOp>(op);
|
||||||
|
|
||||||
// The InitTensorOp may have been eliminated.
|
// The InitTensorOp may have been eliminated.
|
||||||
|
@ -212,7 +212,7 @@ struct TiledLoopOpInterface
|
||||||
: public BufferizableOpInterface::ExternalModel<TiledLoopOpInterface,
|
: public BufferizableOpInterface::ExternalModel<TiledLoopOpInterface,
|
||||||
linalg::TiledLoopOp> {
|
linalg::TiledLoopOp> {
|
||||||
bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
|
bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
|
||||||
BufferizationState &state) const {
|
const BufferizationState &state) const {
|
||||||
// TiledLoop alone doesn't bufferize to a memory read, one of the uses of
|
// TiledLoop alone doesn't bufferize to a memory read, one of the uses of
|
||||||
// its matching bbArg may.
|
// its matching bbArg may.
|
||||||
auto tiledLoopOp = cast<linalg::TiledLoopOp>(op);
|
auto tiledLoopOp = cast<linalg::TiledLoopOp>(op);
|
||||||
|
@ -220,7 +220,7 @@ struct TiledLoopOpInterface
|
||||||
}
|
}
|
||||||
|
|
||||||
bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
|
bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
|
||||||
BufferizationState &state) const {
|
const BufferizationState &state) const {
|
||||||
// TiledLoop alone doesn't bufferize to a memory write, one of the uses of
|
// TiledLoop alone doesn't bufferize to a memory write, one of the uses of
|
||||||
// its matching bbArg may.
|
// its matching bbArg may.
|
||||||
auto bufferizableOp = cast<BufferizableOpInterface>(op);
|
auto bufferizableOp = cast<BufferizableOpInterface>(op);
|
||||||
|
@ -229,18 +229,19 @@ struct TiledLoopOpInterface
|
||||||
}
|
}
|
||||||
|
|
||||||
OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand,
|
OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand,
|
||||||
BufferizationState &state) const {
|
const BufferizationState &state) const {
|
||||||
auto tiledLoopOp = cast<linalg::TiledLoopOp>(op);
|
auto tiledLoopOp = cast<linalg::TiledLoopOp>(op);
|
||||||
return tiledLoopOp.getTiedOpResult(opOperand);
|
return tiledLoopOp.getTiedOpResult(opOperand);
|
||||||
}
|
}
|
||||||
|
|
||||||
BufferRelation bufferRelation(Operation *op, OpResult opResult,
|
BufferRelation bufferRelation(Operation *op, OpResult opResult,
|
||||||
const BufferizationAliasInfo &aliasInfo,
|
const BufferizationAliasInfo &aliasInfo,
|
||||||
BufferizationState &state) const {
|
const BufferizationState &state) const {
|
||||||
return BufferRelation::Equivalent;
|
return BufferRelation::Equivalent;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool isWritable(Operation *op, Value value, BufferizationState &state) const {
|
bool isWritable(Operation *op, Value value,
|
||||||
|
const BufferizationState &state) const {
|
||||||
// Interestingly, linalg::TiledLoopOp's bbArg can **always** be viewed
|
// Interestingly, linalg::TiledLoopOp's bbArg can **always** be viewed
|
||||||
// inplace from the perspective of ops nested under:
|
// inplace from the perspective of ops nested under:
|
||||||
// 1. Either the matching iter operand is not bufferized inplace and an
|
// 1. Either the matching iter operand is not bufferized inplace and an
|
||||||
|
@ -253,7 +254,7 @@ struct TiledLoopOpInterface
|
||||||
bool isAllocationHoistingBarrier(Operation *op) const { return true; }
|
bool isAllocationHoistingBarrier(Operation *op) const { return true; }
|
||||||
|
|
||||||
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
|
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
|
||||||
BufferizationState &state) const {
|
const BufferizationState &state) const {
|
||||||
auto tiledLoopOp = cast<linalg::TiledLoopOp>(op);
|
auto tiledLoopOp = cast<linalg::TiledLoopOp>(op);
|
||||||
|
|
||||||
// Compute new inputs, outputs and results.
|
// Compute new inputs, outputs and results.
|
||||||
|
@ -355,22 +356,22 @@ struct YieldOpInterface
|
||||||
: public BufferizableOpInterface::ExternalModel<YieldOpInterface,
|
: public BufferizableOpInterface::ExternalModel<YieldOpInterface,
|
||||||
linalg::YieldOp> {
|
linalg::YieldOp> {
|
||||||
bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
|
bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
|
||||||
BufferizationState &state) const {
|
const BufferizationState &state) const {
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
|
bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
|
||||||
BufferizationState &state) const {
|
const BufferizationState &state) const {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand,
|
OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand,
|
||||||
BufferizationState &state) const {
|
const BufferizationState &state) const {
|
||||||
return OpResult();
|
return OpResult();
|
||||||
}
|
}
|
||||||
|
|
||||||
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
|
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
|
||||||
BufferizationState &state) const {
|
const BufferizationState &state) const {
|
||||||
auto yieldOp = cast<linalg::YieldOp>(op);
|
auto yieldOp = cast<linalg::YieldOp>(op);
|
||||||
|
|
||||||
if (!yieldOp->getParentOfType<TiledLoopOp>())
|
if (!yieldOp->getParentOfType<TiledLoopOp>())
|
||||||
|
|
|
@ -34,9 +34,20 @@ struct ModuleBufferizationState : public DialectBufferizationState {
|
||||||
};
|
};
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
|
/// Get ModuleBufferizationState.
|
||||||
|
static const ModuleBufferizationState &
|
||||||
|
getModuleBufferizationState(const BufferizationState &state) {
|
||||||
|
Optional<const ModuleBufferizationState *> maybeState =
|
||||||
|
state.getDialectState<ModuleBufferizationState>(
|
||||||
|
StandardOpsDialect::getDialectNamespace());
|
||||||
|
assert(maybeState.hasValue() && "ModuleBufferizationState does not exist");
|
||||||
|
return **maybeState;
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Get or create ModuleBufferizationState.
|
||||||
static ModuleBufferizationState &
|
static ModuleBufferizationState &
|
||||||
getModuleBufferizationState(BufferizationState &state) {
|
getModuleBufferizationState(BufferizationState &state) {
|
||||||
return state.getDialectState<ModuleBufferizationState>(
|
return state.getOrCreateDialectState<ModuleBufferizationState>(
|
||||||
StandardOpsDialect::getDialectNamespace());
|
StandardOpsDialect::getDialectNamespace());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -471,19 +482,25 @@ namespace std_ext {
|
||||||
/// Return the index of the bbArg in the given FuncOp that is equivalent to the
|
/// Return the index of the bbArg in the given FuncOp that is equivalent to the
|
||||||
/// specified return value (if any).
|
/// specified return value (if any).
|
||||||
static Optional<int64_t>
|
static Optional<int64_t>
|
||||||
getEquivalentFuncArgIdx(FuncOp funcOp, ModuleBufferizationState &state,
|
getEquivalentFuncArgIdx(FuncOp funcOp, const ModuleBufferizationState &state,
|
||||||
int64_t returnValIdx) {
|
int64_t returnValIdx) {
|
||||||
if (!state.equivalentFuncArgs[funcOp].count(returnValIdx))
|
if (!state.equivalentFuncArgs.count(funcOp))
|
||||||
|
// No equivalence info stores for funcOp.
|
||||||
|
return None;
|
||||||
|
|
||||||
|
const DenseMap<int64_t, int64_t> &equivFuncArgs =
|
||||||
|
state.equivalentFuncArgs.lookup(funcOp);
|
||||||
|
if (!equivFuncArgs.count(returnValIdx))
|
||||||
// Return value has no equivalent bbArg.
|
// Return value has no equivalent bbArg.
|
||||||
return None;
|
return None;
|
||||||
|
|
||||||
return state.equivalentFuncArgs[funcOp][returnValIdx];
|
return equivFuncArgs.lookup(returnValIdx);
|
||||||
}
|
}
|
||||||
|
|
||||||
struct CallOpInterface
|
struct CallOpInterface
|
||||||
: public BufferizableOpInterface::ExternalModel<CallOpInterface, CallOp> {
|
: public BufferizableOpInterface::ExternalModel<CallOpInterface, CallOp> {
|
||||||
bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
|
bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
|
||||||
BufferizationState &state) const {
|
const BufferizationState &state) const {
|
||||||
// CallOpInterface alone doesn't bufferize to a memory read, one of the uses
|
// CallOpInterface alone doesn't bufferize to a memory read, one of the uses
|
||||||
// of the matching bbArg may. It is the responsibility of the caller to
|
// of the matching bbArg may. It is the responsibility of the caller to
|
||||||
// inspect bbArgs. In the absence of a BufferizationAliasInfo, we need to be
|
// inspect bbArgs. In the absence of a BufferizationAliasInfo, we need to be
|
||||||
|
@ -492,7 +509,7 @@ struct CallOpInterface
|
||||||
}
|
}
|
||||||
|
|
||||||
OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand,
|
OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand,
|
||||||
BufferizationState &state) const {
|
const BufferizationState &state) const {
|
||||||
// CallOpInterface is special, it needs to wait for the callee to be
|
// CallOpInterface is special, it needs to wait for the callee to be
|
||||||
// bufferized and needs to inspect the BufferAliasInfo object. It can't
|
// bufferized and needs to inspect the BufferAliasInfo object. It can't
|
||||||
// make a proper determination by itself and needs to be conservative.
|
// make a proper determination by itself and needs to be conservative.
|
||||||
|
@ -503,14 +520,15 @@ struct CallOpInterface
|
||||||
/// marked inplaceable. For now, it is the responsibility of the `callOp`
|
/// marked inplaceable. For now, it is the responsibility of the `callOp`
|
||||||
/// bufferization to allow FuncOp that are inplaceable to write inPlace.
|
/// bufferization to allow FuncOp that are inplaceable to write inPlace.
|
||||||
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
|
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
|
||||||
BufferizationState &state) const {
|
const BufferizationState &state) const {
|
||||||
CallOp callOp = cast<CallOp>(op);
|
CallOp callOp = cast<CallOp>(op);
|
||||||
unsigned numResults = callOp.getNumResults();
|
unsigned numResults = callOp.getNumResults();
|
||||||
unsigned numOperands = callOp->getNumOperands();
|
unsigned numOperands = callOp->getNumOperands();
|
||||||
FuncOp funcOp = getCalledFunction(callOp);
|
FuncOp funcOp = getCalledFunction(callOp);
|
||||||
assert(isa<CallOp>(callOp.getOperation()) && funcOp &&
|
assert(isa<CallOp>(callOp.getOperation()) && funcOp &&
|
||||||
"expected CallOp to a FuncOp");
|
"expected CallOp to a FuncOp");
|
||||||
ModuleBufferizationState &moduleState = getModuleBufferizationState(state);
|
const ModuleBufferizationState &moduleState =
|
||||||
|
getModuleBufferizationState(state);
|
||||||
|
|
||||||
// Result types of the bufferized CallOp.
|
// Result types of the bufferized CallOp.
|
||||||
SmallVector<Type> resultTypes;
|
SmallVector<Type> resultTypes;
|
||||||
|
@ -626,22 +644,22 @@ struct ReturnOpInterface
|
||||||
: public BufferizableOpInterface::ExternalModel<ReturnOpInterface,
|
: public BufferizableOpInterface::ExternalModel<ReturnOpInterface,
|
||||||
ReturnOp> {
|
ReturnOp> {
|
||||||
bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
|
bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
|
||||||
BufferizationState &state) const {
|
const BufferizationState &state) const {
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
|
bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
|
||||||
BufferizationState &state) const {
|
const BufferizationState &state) const {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand,
|
OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand,
|
||||||
BufferizationState &state) const {
|
const BufferizationState &state) const {
|
||||||
return OpResult();
|
return OpResult();
|
||||||
}
|
}
|
||||||
|
|
||||||
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
|
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
|
||||||
BufferizationState &state) const {
|
const BufferizationState &state) const {
|
||||||
auto returnOp = cast<ReturnOp>(op);
|
auto returnOp = cast<ReturnOp>(op);
|
||||||
assert(isa<FuncOp>(returnOp->getParentOp()) &&
|
assert(isa<FuncOp>(returnOp->getParentOp()) &&
|
||||||
"only support FuncOp parent for ReturnOp");
|
"only support FuncOp parent for ReturnOp");
|
||||||
|
@ -662,7 +680,7 @@ struct ReturnOpInterface
|
||||||
struct FuncOpInterface
|
struct FuncOpInterface
|
||||||
: public BufferizableOpInterface::ExternalModel<FuncOpInterface, FuncOp> {
|
: public BufferizableOpInterface::ExternalModel<FuncOpInterface, FuncOp> {
|
||||||
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
|
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
|
||||||
BufferizationState &state) const {
|
const BufferizationState &state) const {
|
||||||
auto funcOp = cast<FuncOp>(op);
|
auto funcOp = cast<FuncOp>(op);
|
||||||
|
|
||||||
// Bufferize function body.
|
// Bufferize function body.
|
||||||
|
@ -670,11 +688,13 @@ struct FuncOpInterface
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Return `true` if the given function argument is writable.
|
/// Return `true` if the given function argument is writable.
|
||||||
bool isWritable(Operation *op, Value value, BufferizationState &state) const {
|
bool isWritable(Operation *op, Value value,
|
||||||
|
const BufferizationState &state) const {
|
||||||
auto funcOp = cast<FuncOp>(op);
|
auto funcOp = cast<FuncOp>(op);
|
||||||
BlockArgument bbArg = value.dyn_cast<BlockArgument>();
|
BlockArgument bbArg = value.dyn_cast<BlockArgument>();
|
||||||
assert(bbArg && "expected BlockArgument");
|
assert(bbArg && "expected BlockArgument");
|
||||||
ModuleBufferizationState &moduleState = getModuleBufferizationState(state);
|
const ModuleBufferizationState &moduleState =
|
||||||
|
getModuleBufferizationState(state);
|
||||||
|
|
||||||
// In a first approximation:
|
// In a first approximation:
|
||||||
// =========================
|
// =========================
|
||||||
|
@ -720,8 +740,9 @@ static void setInPlaceFuncArgument(BlockArgument bbArg, bool inPlace) {
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Annotate the IR with the result of the analysis. For testing/debugging only.
|
/// Annotate the IR with the result of the analysis. For testing/debugging only.
|
||||||
static void annotateOpsWithBufferizationMarkers(FuncOp funcOp,
|
static void
|
||||||
BufferizationState &state) {
|
annotateOpsWithBufferizationMarkers(FuncOp funcOp,
|
||||||
|
const BufferizationState &state) {
|
||||||
auto bufferizableOp = cast<BufferizableOpInterface>(funcOp.getOperation());
|
auto bufferizableOp = cast<BufferizableOpInterface>(funcOp.getOperation());
|
||||||
for (BlockArgument bbArg : funcOp.getArguments())
|
for (BlockArgument bbArg : funcOp.getArguments())
|
||||||
if (bbArg.getType().isa<TensorType>())
|
if (bbArg.getType().isa<TensorType>())
|
||||||
|
@ -733,7 +754,7 @@ LogicalResult mlir::linalg::comprehensive_bufferize::runComprehensiveBufferize(
|
||||||
IRRewriter rewriter(moduleOp.getContext());
|
IRRewriter rewriter(moduleOp.getContext());
|
||||||
BufferizationState state(moduleOp, *options);
|
BufferizationState state(moduleOp, *options);
|
||||||
ModuleBufferizationState &moduleState = getModuleBufferizationState(state);
|
ModuleBufferizationState &moduleState = getModuleBufferizationState(state);
|
||||||
BufferizationAliasInfo &aliasInfo = state.aliasInfo;
|
BufferizationAliasInfo &aliasInfo = state.getAliasInfo();
|
||||||
|
|
||||||
if (failed(getFuncOpsOrderedByCalls(moduleOp, moduleState.orderedFuncOps,
|
if (failed(getFuncOpsOrderedByCalls(moduleOp, moduleState.orderedFuncOps,
|
||||||
moduleState.callerMap)))
|
moduleState.callerMap)))
|
||||||
|
|
|
@ -24,7 +24,7 @@ struct ExecuteRegionOpInterface
|
||||||
scf::ExecuteRegionOp> {
|
scf::ExecuteRegionOp> {
|
||||||
SmallVector<OpOperand *>
|
SmallVector<OpOperand *>
|
||||||
getAliasingOpOperand(Operation *op, OpResult opResult,
|
getAliasingOpOperand(Operation *op, OpResult opResult,
|
||||||
BufferizationState &state) const {
|
const BufferizationState &state) const {
|
||||||
// ExecuteRegionOps do not have tensor OpOperands. The yielded value can be
|
// ExecuteRegionOps do not have tensor OpOperands. The yielded value can be
|
||||||
// any SSA value that is in scope. To allow for use-def chain traversal
|
// any SSA value that is in scope. To allow for use-def chain traversal
|
||||||
// through ExecuteRegionOps in the analysis, the corresponding yield value
|
// through ExecuteRegionOps in the analysis, the corresponding yield value
|
||||||
|
@ -41,7 +41,7 @@ struct ExecuteRegionOpInterface
|
||||||
}
|
}
|
||||||
|
|
||||||
bool mustBufferizeInPlace(Operation *op, OpResult opResult,
|
bool mustBufferizeInPlace(Operation *op, OpResult opResult,
|
||||||
BufferizationState &state) const {
|
const BufferizationState &state) const {
|
||||||
// ExecuteRegionOp results always bufferize in-place. Since they have no
|
// ExecuteRegionOp results always bufferize in-place. Since they have no
|
||||||
// OpOperands, they are mostly ignored by the analysis once alias sets are
|
// OpOperands, they are mostly ignored by the analysis once alias sets are
|
||||||
// set up.
|
// set up.
|
||||||
|
@ -51,7 +51,7 @@ struct ExecuteRegionOpInterface
|
||||||
// TODO: For better bufferization results, this could return `true` only if
|
// TODO: For better bufferization results, this could return `true` only if
|
||||||
// there is a memory write in the region.
|
// there is a memory write in the region.
|
||||||
bool isMemoryWrite(Operation *op, OpResult opResult,
|
bool isMemoryWrite(Operation *op, OpResult opResult,
|
||||||
BufferizationState &state) const {
|
const BufferizationState &state) const {
|
||||||
// Similar to scf.if, results of this op are always considered memory writes
|
// Similar to scf.if, results of this op are always considered memory writes
|
||||||
// in the analysis. This is a useful pattern for all ops that have tensor
|
// in the analysis. This is a useful pattern for all ops that have tensor
|
||||||
// OpResults but no tensor OpOperands. By default, `isMemoryWrite` is
|
// OpResults but no tensor OpOperands. By default, `isMemoryWrite` is
|
||||||
|
@ -61,7 +61,7 @@ struct ExecuteRegionOpInterface
|
||||||
}
|
}
|
||||||
|
|
||||||
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
|
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
|
||||||
BufferizationState &state) const {
|
const BufferizationState &state) const {
|
||||||
// TODO: Add bufferization support when needed. scf.execute_region should be
|
// TODO: Add bufferization support when needed. scf.execute_region should be
|
||||||
// bufferized similar to scf.if.
|
// bufferized similar to scf.if.
|
||||||
auto executeRegionOp = cast<scf::ExecuteRegionOp>(op);
|
auto executeRegionOp = cast<scf::ExecuteRegionOp>(op);
|
||||||
|
@ -76,7 +76,7 @@ struct ExecuteRegionOpInterface
|
||||||
|
|
||||||
BufferRelation bufferRelation(Operation *op, OpResult opResult,
|
BufferRelation bufferRelation(Operation *op, OpResult opResult,
|
||||||
const BufferizationAliasInfo &aliasInfo,
|
const BufferizationAliasInfo &aliasInfo,
|
||||||
BufferizationState &state) const {
|
const BufferizationState &state) const {
|
||||||
return BufferRelation::Equivalent;
|
return BufferRelation::Equivalent;
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
@ -85,7 +85,7 @@ struct IfOpInterface
|
||||||
: public BufferizableOpInterface::ExternalModel<IfOpInterface, scf::IfOp> {
|
: public BufferizableOpInterface::ExternalModel<IfOpInterface, scf::IfOp> {
|
||||||
SmallVector<OpOperand *>
|
SmallVector<OpOperand *>
|
||||||
getAliasingOpOperand(Operation *op, OpResult opResult,
|
getAliasingOpOperand(Operation *op, OpResult opResult,
|
||||||
BufferizationState &state) const {
|
const BufferizationState &state) const {
|
||||||
// IfOps do not have tensor OpOperands. The yielded value can be any SSA
|
// IfOps do not have tensor OpOperands. The yielded value can be any SSA
|
||||||
// value that is in scope. To allow for use-def chain traversal through
|
// value that is in scope. To allow for use-def chain traversal through
|
||||||
// IfOps in the analysis, both corresponding yield values from the then/else
|
// IfOps in the analysis, both corresponding yield values from the then/else
|
||||||
|
@ -102,7 +102,7 @@ struct IfOpInterface
|
||||||
// allowed at the moment, we should never encounter scf.ifs that yield
|
// allowed at the moment, we should never encounter scf.ifs that yield
|
||||||
// unmodified tensors. Such scf.yield ops could just fold away.
|
// unmodified tensors. Such scf.yield ops could just fold away.
|
||||||
bool isMemoryWrite(Operation *op, OpResult opResult,
|
bool isMemoryWrite(Operation *op, OpResult opResult,
|
||||||
BufferizationState &state) const {
|
const BufferizationState &state) const {
|
||||||
// IfOp results are always considered memory writes in the analysis. This
|
// IfOp results are always considered memory writes in the analysis. This
|
||||||
// design decision simplifies the analysis considerably. E.g., consider the
|
// design decision simplifies the analysis considerably. E.g., consider the
|
||||||
// following test case:
|
// following test case:
|
||||||
|
@ -129,14 +129,14 @@ struct IfOpInterface
|
||||||
}
|
}
|
||||||
|
|
||||||
bool mustBufferizeInPlace(Operation *op, OpResult opResult,
|
bool mustBufferizeInPlace(Operation *op, OpResult opResult,
|
||||||
BufferizationState &state) const {
|
const BufferizationState &state) const {
|
||||||
// IfOp results always bufferize in-place. Since they have no OpOperands,
|
// IfOp results always bufferize in-place. Since they have no OpOperands,
|
||||||
// they are mostly ignored by the analysis once alias sets are set up.
|
// they are mostly ignored by the analysis once alias sets are set up.
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
|
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
|
||||||
BufferizationState &state) const {
|
const BufferizationState &state) const {
|
||||||
auto ifOp = cast<scf::IfOp>(op);
|
auto ifOp = cast<scf::IfOp>(op);
|
||||||
|
|
||||||
// Compute new types of the bufferized scf.if op.
|
// Compute new types of the bufferized scf.if op.
|
||||||
|
@ -209,7 +209,7 @@ struct IfOpInterface
|
||||||
|
|
||||||
BufferRelation bufferRelation(Operation *op, OpResult opResult,
|
BufferRelation bufferRelation(Operation *op, OpResult opResult,
|
||||||
const BufferizationAliasInfo &aliasInfo,
|
const BufferizationAliasInfo &aliasInfo,
|
||||||
BufferizationState &state) const {
|
const BufferizationState &state) const {
|
||||||
// IfOp results are equivalent to their corresponding yield values if both
|
// IfOp results are equivalent to their corresponding yield values if both
|
||||||
// yield values are equivalent to each other.
|
// yield values are equivalent to each other.
|
||||||
auto bufferizableOp = cast<BufferizableOpInterface>(op);
|
auto bufferizableOp = cast<BufferizableOpInterface>(op);
|
||||||
|
@ -226,7 +226,7 @@ struct ForOpInterface
|
||||||
: public BufferizableOpInterface::ExternalModel<ForOpInterface,
|
: public BufferizableOpInterface::ExternalModel<ForOpInterface,
|
||||||
scf::ForOp> {
|
scf::ForOp> {
|
||||||
bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
|
bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
|
||||||
BufferizationState &state) const {
|
const BufferizationState &state) const {
|
||||||
// scf::ForOp alone doesn't bufferize to a memory read, one of the uses of
|
// scf::ForOp alone doesn't bufferize to a memory read, one of the uses of
|
||||||
// its matching bbArg may.
|
// its matching bbArg may.
|
||||||
auto forOp = cast<scf::ForOp>(op);
|
auto forOp = cast<scf::ForOp>(op);
|
||||||
|
@ -234,7 +234,7 @@ struct ForOpInterface
|
||||||
}
|
}
|
||||||
|
|
||||||
bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
|
bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
|
||||||
BufferizationState &state) const {
|
const BufferizationState &state) const {
|
||||||
// Tensor iter_args of scf::ForOps are always considered as a write. This is
|
// Tensor iter_args of scf::ForOps are always considered as a write. This is
|
||||||
// to simplify the analysis.
|
// to simplify the analysis.
|
||||||
// TODO: Consider doing sth. like isValueWritten.
|
// TODO: Consider doing sth. like isValueWritten.
|
||||||
|
@ -242,7 +242,7 @@ struct ForOpInterface
|
||||||
}
|
}
|
||||||
|
|
||||||
OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand,
|
OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand,
|
||||||
BufferizationState &state) const {
|
const BufferizationState &state) const {
|
||||||
auto forOp = cast<scf::ForOp>(op);
|
auto forOp = cast<scf::ForOp>(op);
|
||||||
if (!opOperand.get().getType().isa<RankedTensorType>())
|
if (!opOperand.get().getType().isa<RankedTensorType>())
|
||||||
return OpResult();
|
return OpResult();
|
||||||
|
@ -251,7 +251,7 @@ struct ForOpInterface
|
||||||
|
|
||||||
BufferRelation bufferRelation(Operation *op, OpResult opResult,
|
BufferRelation bufferRelation(Operation *op, OpResult opResult,
|
||||||
const BufferizationAliasInfo &aliasInfo,
|
const BufferizationAliasInfo &aliasInfo,
|
||||||
BufferizationState &state) const {
|
const BufferizationState &state) const {
|
||||||
// ForOp results are equivalent to their corresponding init_args if the
|
// ForOp results are equivalent to their corresponding init_args if the
|
||||||
// corresponding iter_args and yield values are equivalent.
|
// corresponding iter_args and yield values are equivalent.
|
||||||
auto forOp = cast<scf::ForOp>(op);
|
auto forOp = cast<scf::ForOp>(op);
|
||||||
|
@ -263,7 +263,8 @@ struct ForOpInterface
|
||||||
return equivalentYield ? BufferRelation::Equivalent : BufferRelation::None;
|
return equivalentYield ? BufferRelation::Equivalent : BufferRelation::None;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool isWritable(Operation *op, Value value, BufferizationState &state) const {
|
bool isWritable(Operation *op, Value value,
|
||||||
|
const BufferizationState &state) const {
|
||||||
// Interestingly, scf::ForOp's bbArg can **always** be viewed
|
// Interestingly, scf::ForOp's bbArg can **always** be viewed
|
||||||
// inplace from the perspective of ops nested under:
|
// inplace from the perspective of ops nested under:
|
||||||
// 1. Either the matching iter operand is not bufferized inplace and an
|
// 1. Either the matching iter operand is not bufferized inplace and an
|
||||||
|
@ -274,7 +275,7 @@ struct ForOpInterface
|
||||||
}
|
}
|
||||||
|
|
||||||
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
|
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
|
||||||
BufferizationState &state) const {
|
const BufferizationState &state) const {
|
||||||
auto forOp = cast<scf::ForOp>(op);
|
auto forOp = cast<scf::ForOp>(op);
|
||||||
Block *oldLoopBody = &forOp.getLoopBody().front();
|
Block *oldLoopBody = &forOp.getLoopBody().front();
|
||||||
|
|
||||||
|
@ -416,22 +417,22 @@ struct YieldOpInterface
|
||||||
: public BufferizableOpInterface::ExternalModel<YieldOpInterface,
|
: public BufferizableOpInterface::ExternalModel<YieldOpInterface,
|
||||||
scf::YieldOp> {
|
scf::YieldOp> {
|
||||||
bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
|
bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
|
||||||
BufferizationState &state) const {
|
const BufferizationState &state) const {
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
|
bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
|
||||||
BufferizationState &state) const {
|
const BufferizationState &state) const {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand,
|
OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand,
|
||||||
BufferizationState &state) const {
|
const BufferizationState &state) const {
|
||||||
return OpResult();
|
return OpResult();
|
||||||
}
|
}
|
||||||
|
|
||||||
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
|
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
|
||||||
BufferizationState &state) const {
|
const BufferizationState &state) const {
|
||||||
auto yieldOp = cast<scf::YieldOp>(op);
|
auto yieldOp = cast<scf::YieldOp>(op);
|
||||||
if (!isa<scf::ExecuteRegionOp, scf::IfOp, scf::ForOp>(
|
if (!isa<scf::ExecuteRegionOp, scf::IfOp, scf::ForOp>(
|
||||||
yieldOp->getParentOp()))
|
yieldOp->getParentOp()))
|
||||||
|
|
|
@ -27,28 +27,28 @@ struct CastOpInterface
|
||||||
: public BufferizableOpInterface::ExternalModel<CastOpInterface,
|
: public BufferizableOpInterface::ExternalModel<CastOpInterface,
|
||||||
tensor::CastOp> {
|
tensor::CastOp> {
|
||||||
bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
|
bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
|
||||||
BufferizationState &state) const {
|
const BufferizationState &state) const {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
|
bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
|
||||||
BufferizationState &state) const {
|
const BufferizationState &state) const {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand,
|
OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand,
|
||||||
BufferizationState &state) const {
|
const BufferizationState &state) const {
|
||||||
return op->getResult(0);
|
return op->getResult(0);
|
||||||
}
|
}
|
||||||
|
|
||||||
BufferRelation bufferRelation(Operation *op, OpResult opResult,
|
BufferRelation bufferRelation(Operation *op, OpResult opResult,
|
||||||
const BufferizationAliasInfo &aliasInfo,
|
const BufferizationAliasInfo &aliasInfo,
|
||||||
BufferizationState &state) const {
|
const BufferizationState &state) const {
|
||||||
return BufferRelation::Equivalent;
|
return BufferRelation::Equivalent;
|
||||||
}
|
}
|
||||||
|
|
||||||
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
|
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
|
||||||
BufferizationState &state) const {
|
const BufferizationState &state) const {
|
||||||
auto castOp = cast<tensor::CastOp>(op);
|
auto castOp = cast<tensor::CastOp>(op);
|
||||||
|
|
||||||
Value resultBuffer = state.getResultBuffer(rewriter, castOp->getResult(0));
|
Value resultBuffer = state.getResultBuffer(rewriter, castOp->getResult(0));
|
||||||
|
@ -78,22 +78,22 @@ struct DimOpInterface
|
||||||
: public BufferizableOpInterface::ExternalModel<DimOpInterface,
|
: public BufferizableOpInterface::ExternalModel<DimOpInterface,
|
||||||
tensor::DimOp> {
|
tensor::DimOp> {
|
||||||
bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
|
bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
|
||||||
BufferizationState &state) const {
|
const BufferizationState &state) const {
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
|
bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
|
||||||
BufferizationState &state) const {
|
const BufferizationState &state) const {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand,
|
OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand,
|
||||||
BufferizationState &state) const {
|
const BufferizationState &state) const {
|
||||||
return OpResult();
|
return OpResult();
|
||||||
}
|
}
|
||||||
|
|
||||||
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
|
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
|
||||||
BufferizationState &state) const {
|
const BufferizationState &state) const {
|
||||||
auto dimOp = cast<tensor::DimOp>(op);
|
auto dimOp = cast<tensor::DimOp>(op);
|
||||||
if (!dimOp.source().getType().isa<RankedTensorType>())
|
if (!dimOp.source().getType().isa<RankedTensorType>())
|
||||||
return dimOp.emitError("unranked tensor not supported");
|
return dimOp.emitError("unranked tensor not supported");
|
||||||
|
@ -107,17 +107,17 @@ struct ExtractSliceOpInterface
|
||||||
: public BufferizableOpInterface::ExternalModel<ExtractSliceOpInterface,
|
: public BufferizableOpInterface::ExternalModel<ExtractSliceOpInterface,
|
||||||
tensor::ExtractSliceOp> {
|
tensor::ExtractSliceOp> {
|
||||||
bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
|
bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
|
||||||
BufferizationState &state) const {
|
const BufferizationState &state) const {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
|
bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
|
||||||
BufferizationState &state) const {
|
const BufferizationState &state) const {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand,
|
OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand,
|
||||||
BufferizationState &state) const {
|
const BufferizationState &state) const {
|
||||||
return &opOperand == &op->getOpOperand(0) /*source*/
|
return &opOperand == &op->getOpOperand(0) /*source*/
|
||||||
? op->getResult(0)
|
? op->getResult(0)
|
||||||
: OpResult();
|
: OpResult();
|
||||||
|
@ -125,12 +125,12 @@ struct ExtractSliceOpInterface
|
||||||
|
|
||||||
BufferRelation bufferRelation(Operation *op, OpResult opResult,
|
BufferRelation bufferRelation(Operation *op, OpResult opResult,
|
||||||
const BufferizationAliasInfo &aliasInfo,
|
const BufferizationAliasInfo &aliasInfo,
|
||||||
BufferizationState &state) const {
|
const BufferizationState &state) const {
|
||||||
return BufferRelation::None;
|
return BufferRelation::None;
|
||||||
}
|
}
|
||||||
|
|
||||||
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
|
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
|
||||||
BufferizationState &state) const {
|
const BufferizationState &state) const {
|
||||||
auto extractSliceOp = cast<tensor::ExtractSliceOp>(op);
|
auto extractSliceOp = cast<tensor::ExtractSliceOp>(op);
|
||||||
Location loc = extractSliceOp.getLoc();
|
Location loc = extractSliceOp.getLoc();
|
||||||
Value srcMemref = state.lookupBuffer(rewriter, extractSliceOp.source());
|
Value srcMemref = state.lookupBuffer(rewriter, extractSliceOp.source());
|
||||||
|
@ -173,22 +173,22 @@ struct ExtractOpInterface
|
||||||
: public BufferizableOpInterface::ExternalModel<ExtractOpInterface,
|
: public BufferizableOpInterface::ExternalModel<ExtractOpInterface,
|
||||||
tensor::ExtractOp> {
|
tensor::ExtractOp> {
|
||||||
bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
|
bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
|
||||||
BufferizationState &state) const {
|
const BufferizationState &state) const {
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
|
bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
|
||||||
BufferizationState &state) const {
|
const BufferizationState &state) const {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand,
|
OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand,
|
||||||
BufferizationState &state) const {
|
const BufferizationState &state) const {
|
||||||
return OpResult();
|
return OpResult();
|
||||||
}
|
}
|
||||||
|
|
||||||
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
|
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
|
||||||
BufferizationState &state) const {
|
const BufferizationState &state) const {
|
||||||
auto extractOp = cast<tensor::ExtractOp>(op);
|
auto extractOp = cast<tensor::ExtractOp>(op);
|
||||||
Value srcMemref = state.lookupBuffer(rewriter, extractOp.tensor());
|
Value srcMemref = state.lookupBuffer(rewriter, extractOp.tensor());
|
||||||
state.replaceOpWithNewOp<memref::LoadOp>(rewriter, op, srcMemref,
|
state.replaceOpWithNewOp<memref::LoadOp>(rewriter, op, srcMemref,
|
||||||
|
@ -201,17 +201,17 @@ struct InsertOpInterface
|
||||||
: public BufferizableOpInterface::ExternalModel<InsertOpInterface,
|
: public BufferizableOpInterface::ExternalModel<InsertOpInterface,
|
||||||
tensor::InsertOp> {
|
tensor::InsertOp> {
|
||||||
bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
|
bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
|
||||||
BufferizationState &state) const {
|
const BufferizationState &state) const {
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
|
bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
|
||||||
BufferizationState &state) const {
|
const BufferizationState &state) const {
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand,
|
OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand,
|
||||||
BufferizationState &state) const {
|
const BufferizationState &state) const {
|
||||||
assert(&opOperand == &op->getOpOperand(1) /*dest*/ &&
|
assert(&opOperand == &op->getOpOperand(1) /*dest*/ &&
|
||||||
"expected dest OpOperand");
|
"expected dest OpOperand");
|
||||||
return op->getOpResult(0);
|
return op->getOpResult(0);
|
||||||
|
@ -219,12 +219,12 @@ struct InsertOpInterface
|
||||||
|
|
||||||
SmallVector<OpOperand *>
|
SmallVector<OpOperand *>
|
||||||
getAliasingOpOperand(Operation *op, OpResult opResult,
|
getAliasingOpOperand(Operation *op, OpResult opResult,
|
||||||
BufferizationState &state) const {
|
const BufferizationState &state) const {
|
||||||
return {&op->getOpOperand(1) /*dest*/};
|
return {&op->getOpOperand(1) /*dest*/};
|
||||||
}
|
}
|
||||||
|
|
||||||
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
|
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
|
||||||
BufferizationState &state) const {
|
const BufferizationState &state) const {
|
||||||
auto insertOp = cast<tensor::InsertOp>(op);
|
auto insertOp = cast<tensor::InsertOp>(op);
|
||||||
Location loc = insertOp.getLoc();
|
Location loc = insertOp.getLoc();
|
||||||
Value destMemref =
|
Value destMemref =
|
||||||
|
@ -237,7 +237,7 @@ struct InsertOpInterface
|
||||||
|
|
||||||
BufferRelation bufferRelation(Operation *op, OpResult opResult,
|
BufferRelation bufferRelation(Operation *op, OpResult opResult,
|
||||||
const BufferizationAliasInfo &aliasInfo,
|
const BufferizationAliasInfo &aliasInfo,
|
||||||
BufferizationState &state) const {
|
const BufferizationState &state) const {
|
||||||
return BufferRelation::Equivalent;
|
return BufferRelation::Equivalent;
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
@ -263,8 +263,8 @@ areEquivalentExtractSliceOps(const BufferizationAliasInfo &aliasInfo,
|
||||||
/// Return true if `value` is originating from an ExtractSliceOp that matches
|
/// Return true if `value` is originating from an ExtractSliceOp that matches
|
||||||
/// the given InsertSliceOp.
|
/// the given InsertSliceOp.
|
||||||
static bool hasMatchingExtractSliceOp(const BufferizationAliasInfo &aliasInfo,
|
static bool hasMatchingExtractSliceOp(const BufferizationAliasInfo &aliasInfo,
|
||||||
BufferizationState &state, Value value,
|
const BufferizationState &state,
|
||||||
InsertSliceOp insertOp) {
|
Value value, InsertSliceOp insertOp) {
|
||||||
auto condition = [&](Value val) {
|
auto condition = [&](Value val) {
|
||||||
if (auto extractOp = val.getDefiningOp<ExtractSliceOp>())
|
if (auto extractOp = val.getDefiningOp<ExtractSliceOp>())
|
||||||
if (areEquivalentExtractSliceOps(aliasInfo, extractOp, insertOp))
|
if (areEquivalentExtractSliceOps(aliasInfo, extractOp, insertOp))
|
||||||
|
@ -280,17 +280,17 @@ struct InsertSliceOpInterface
|
||||||
: public BufferizableOpInterface::ExternalModel<InsertSliceOpInterface,
|
: public BufferizableOpInterface::ExternalModel<InsertSliceOpInterface,
|
||||||
tensor::InsertSliceOp> {
|
tensor::InsertSliceOp> {
|
||||||
bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
|
bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
|
||||||
BufferizationState &state) const {
|
const BufferizationState &state) const {
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
|
bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
|
||||||
BufferizationState &state) const {
|
const BufferizationState &state) const {
|
||||||
return &opOperand == &op->getOpOperand(1) /*dest*/;
|
return &opOperand == &op->getOpOperand(1) /*dest*/;
|
||||||
}
|
}
|
||||||
|
|
||||||
OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand,
|
OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand,
|
||||||
BufferizationState &state) const {
|
const BufferizationState &state) const {
|
||||||
return &opOperand == &op->getOpOperand(1) /*dest*/
|
return &opOperand == &op->getOpOperand(1) /*dest*/
|
||||||
? op->getResult(0)
|
? op->getResult(0)
|
||||||
: OpResult();
|
: OpResult();
|
||||||
|
@ -298,12 +298,13 @@ struct InsertSliceOpInterface
|
||||||
|
|
||||||
BufferRelation bufferRelation(Operation *op, OpResult opResult,
|
BufferRelation bufferRelation(Operation *op, OpResult opResult,
|
||||||
const BufferizationAliasInfo &aliasInfo,
|
const BufferizationAliasInfo &aliasInfo,
|
||||||
BufferizationState &state) const {
|
const BufferizationState &state) const {
|
||||||
return BufferRelation::Equivalent;
|
return BufferRelation::Equivalent;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool isNotConflicting(Operation *op, OpOperand *uRead,
|
bool isNotConflicting(Operation *op, OpOperand *uRead,
|
||||||
OpOperand *uConflictingWrite, BufferizationState &state,
|
OpOperand *uConflictingWrite,
|
||||||
|
const BufferizationState &state,
|
||||||
const BufferizationAliasInfo &aliasInfo) const {
|
const BufferizationAliasInfo &aliasInfo) const {
|
||||||
Operation *readingOp = uRead->getOwner();
|
Operation *readingOp = uRead->getOwner();
|
||||||
Operation *conflictingWritingOp = uConflictingWrite->getOwner();
|
Operation *conflictingWritingOp = uConflictingWrite->getOwner();
|
||||||
|
@ -380,7 +381,7 @@ struct InsertSliceOpInterface
|
||||||
}
|
}
|
||||||
|
|
||||||
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
|
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
|
||||||
BufferizationState &state) const {
|
const BufferizationState &state) const {
|
||||||
// insert_slice ops arise from tiling and bufferizing them out-of-place is
|
// insert_slice ops arise from tiling and bufferizing them out-of-place is
|
||||||
// generally a deal breaker. When used with loops, this ends up cloning the
|
// generally a deal breaker. When used with loops, this ends up cloning the
|
||||||
// whole tensor on every single iteration and is a symptom of a
|
// whole tensor on every single iteration and is a symptom of a
|
||||||
|
|
|
@ -21,26 +21,26 @@ struct TransferReadOpInterface
|
||||||
: public BufferizableOpInterface::ExternalModel<TransferReadOpInterface,
|
: public BufferizableOpInterface::ExternalModel<TransferReadOpInterface,
|
||||||
vector::TransferReadOp> {
|
vector::TransferReadOp> {
|
||||||
bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
|
bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
|
||||||
BufferizationState &state) const {
|
const BufferizationState &state) const {
|
||||||
assert(opOperand.get().getType().isa<RankedTensorType>() &&
|
assert(opOperand.get().getType().isa<RankedTensorType>() &&
|
||||||
"only tensor types expected");
|
"only tensor types expected");
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
|
bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
|
||||||
BufferizationState &state) const {
|
const BufferizationState &state) const {
|
||||||
assert(opOperand.get().getType().isa<RankedTensorType>() &&
|
assert(opOperand.get().getType().isa<RankedTensorType>() &&
|
||||||
"only tensor types expected");
|
"only tensor types expected");
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand,
|
OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand,
|
||||||
BufferizationState &state) const {
|
const BufferizationState &state) const {
|
||||||
return OpResult();
|
return OpResult();
|
||||||
}
|
}
|
||||||
|
|
||||||
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
|
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
|
||||||
BufferizationState &state) const {
|
const BufferizationState &state) const {
|
||||||
auto readOp = cast<vector::TransferReadOp>(op);
|
auto readOp = cast<vector::TransferReadOp>(op);
|
||||||
assert(readOp.getShapedType().isa<TensorType>() &&
|
assert(readOp.getShapedType().isa<TensorType>() &&
|
||||||
"only tensor types expected");
|
"only tensor types expected");
|
||||||
|
@ -60,21 +60,21 @@ struct TransferWriteOpInterface
|
||||||
: public BufferizableOpInterface::ExternalModel<TransferWriteOpInterface,
|
: public BufferizableOpInterface::ExternalModel<TransferWriteOpInterface,
|
||||||
vector::TransferWriteOp> {
|
vector::TransferWriteOp> {
|
||||||
bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
|
bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
|
||||||
BufferizationState &state) const {
|
const BufferizationState &state) const {
|
||||||
assert(opOperand.get().getType().isa<TensorType>() &&
|
assert(opOperand.get().getType().isa<TensorType>() &&
|
||||||
"only tensor types expected");
|
"only tensor types expected");
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
|
bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
|
||||||
BufferizationState &state) const {
|
const BufferizationState &state) const {
|
||||||
assert(opOperand.get().getType().isa<TensorType>() &&
|
assert(opOperand.get().getType().isa<TensorType>() &&
|
||||||
"only tensor types expected");
|
"only tensor types expected");
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand,
|
OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand,
|
||||||
BufferizationState &state) const {
|
const BufferizationState &state) const {
|
||||||
assert(opOperand.get().getType().isa<TensorType>() &&
|
assert(opOperand.get().getType().isa<TensorType>() &&
|
||||||
"only tensor types expected");
|
"only tensor types expected");
|
||||||
return op->getOpResult(0);
|
return op->getOpResult(0);
|
||||||
|
@ -82,12 +82,12 @@ struct TransferWriteOpInterface
|
||||||
|
|
||||||
BufferRelation bufferRelation(Operation *op, OpResult opResult,
|
BufferRelation bufferRelation(Operation *op, OpResult opResult,
|
||||||
const BufferizationAliasInfo &aliasInfo,
|
const BufferizationAliasInfo &aliasInfo,
|
||||||
BufferizationState &state) const {
|
const BufferizationState &state) const {
|
||||||
return BufferRelation::Equivalent;
|
return BufferRelation::Equivalent;
|
||||||
}
|
}
|
||||||
|
|
||||||
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
|
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
|
||||||
BufferizationState &state) const {
|
const BufferizationState &state) const {
|
||||||
auto writeOp = cast<vector::TransferWriteOp>(op);
|
auto writeOp = cast<vector::TransferWriteOp>(op);
|
||||||
assert(writeOp.getShapedType().isa<TensorType>() &&
|
assert(writeOp.getShapedType().isa<TensorType>() &&
|
||||||
"only tensor types expected");
|
"only tensor types expected");
|
||||||
|
|
Loading…
Reference in New Issue