forked from OSchip/llvm-project
[mlir][linalg][bufferize][NFC] Decouple ComprehensiveBufferize from tensor dialect
Add a new BufferizableOpInterface method `isNotConflicting` that can be used to implement custom analysis rules. Differential Revision: https://reviews.llvm.org/D113961
This commit is contained in:
parent
0c7890c844
commit
26e90423f4
|
@ -215,6 +215,29 @@ def BufferizableOpInterface : OpInterface<"BufferizableOpInterface"> {
|
|||
/*defaultImplementation=*/[{
|
||||
return false;
|
||||
}]
|
||||
>,
|
||||
InterfaceMethod<
|
||||
/*desc=*/[{
|
||||
Return `true` if the `uRead` and `uWrite` do not constitute a RaW
|
||||
conflict. If they are conflicting or if it is unknown whether they are
|
||||
conflicting, return `false`. This method will never be called with
|
||||
OpOperands that do not have a tensor type. At least one of the two
|
||||
given OpOperands belongs to this operation.
|
||||
|
||||
This method can be implemented to specify custom RaW analysis rules.
|
||||
If this method returns `true` the given OpOperands are not considered
|
||||
to be conflicting and do not force out-of-place bufferization. (There
|
||||
may still be other conflicts that do.)
|
||||
}],
|
||||
/*retType=*/"bool",
|
||||
/*methodName=*/"isNotConflicting",
|
||||
/*args=*/(ins "OpOperand *":$uRead,
|
||||
"OpOperand *":$uWrite,
|
||||
"const BufferizationAliasInfo &":$aliasInfo),
|
||||
/*methodBody=*/"",
|
||||
/*defaultImplementation=*/[{
|
||||
return false;
|
||||
}]
|
||||
>
|
||||
];
|
||||
|
||||
|
|
|
@ -281,24 +281,6 @@ static std::string printValueInfo(Value value, bool prefix) {
|
|||
// Bufferization-specific alias analysis.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
/// Return true if the (ExtractSliceOp, InsertSliceOp) pair match (i.e.
|
||||
/// equivalent operand / result and same offset/sizes/strides specification).
|
||||
///
|
||||
/// This is one particular type of relationship between ops on tensors that
|
||||
/// reduce to an equivalence on buffers. This should be generalized and
|
||||
/// exposed as interfaces on the proper types.
|
||||
static bool
|
||||
areEquivalentExtractSliceOps(const BufferizationAliasInfo &aliasInfo,
|
||||
ExtractSliceOp st, InsertSliceOp sti) {
|
||||
if (!st || !sti)
|
||||
return false;
|
||||
if (!aliasInfo.areEquivalentBufferizedValues(st.source(), sti.dest()))
|
||||
return false;
|
||||
if (!sameOffsetsSizesAndStrides(st, sti, isEqualConstantIntOrValue))
|
||||
return false;
|
||||
return true;
|
||||
}
|
||||
|
||||
/// Return true if opOperand has been decided to bufferize in-place.
|
||||
static bool isInplaceMemoryWrite(OpOperand &opOperand,
|
||||
const BufferizationAliasInfo &aliasInfo) {
|
||||
|
@ -368,21 +350,6 @@ static bool aliasesInPlaceWrite(Value value,
|
|||
return foundInplaceWrite;
|
||||
}
|
||||
|
||||
/// Return true if `value` is originating from an ExtractSliceOp that matches
|
||||
/// the given InsertSliceOp.
|
||||
static bool hasMatchingExtractSliceOp(const BufferizationAliasInfo &aliasInfo,
|
||||
Value value, InsertSliceOp insertOp) {
|
||||
auto condition = [&](Value val) {
|
||||
if (auto extractOp = val.getDefiningOp<ExtractSliceOp>())
|
||||
if (areEquivalentExtractSliceOps(aliasInfo, extractOp, insertOp))
|
||||
return true;
|
||||
return false;
|
||||
};
|
||||
|
||||
return llvm::all_of(findValueInReverseUseDefChain(value, condition),
|
||||
condition);
|
||||
}
|
||||
|
||||
/// Return true if `a` happens before `b`, i.e., `a` or one of its ancestors
|
||||
/// properly dominates `b` and `b` is not inside `a`.
|
||||
static bool happensBefore(Operation *a, Operation *b,
|
||||
|
@ -450,6 +417,21 @@ hasReadAfterWriteInterference(const DenseSet<OpOperand *> &usesRead,
|
|||
if (uConflictingWrite == uRead)
|
||||
continue;
|
||||
|
||||
// No conflict if the op interface says so.
|
||||
if (auto bufferizableOp = dyn_cast<BufferizableOpInterface>(readingOp))
|
||||
if (bufferizableOp.isNotConflicting(uRead, uConflictingWrite,
|
||||
aliasInfo))
|
||||
continue;
|
||||
|
||||
if (conflictingWritingOp != readingOp)
|
||||
if (auto bufferizableOp =
|
||||
dyn_cast<BufferizableOpInterface>(conflictingWritingOp))
|
||||
if (bufferizableOp.isNotConflicting(uRead, uConflictingWrite,
|
||||
aliasInfo))
|
||||
continue;
|
||||
|
||||
// Special rules for branches.
|
||||
// TODO: Use an interface.
|
||||
if (scf::insideMutuallyExclusiveBranches(readingOp, conflictingWritingOp))
|
||||
continue;
|
||||
|
||||
|
@ -478,73 +460,6 @@ hasReadAfterWriteInterference(const DenseSet<OpOperand *> &usesRead,
|
|||
if (getAliasingOpResult(*uConflictingWrite) == lastWrite)
|
||||
continue;
|
||||
|
||||
// Special rules for matching ExtractSliceOp/InsertSliceOp pairs. If
|
||||
// uRead is an InsertSliceOp...
|
||||
if (auto insertSliceOp = dyn_cast<InsertSliceOp>(readingOp)) {
|
||||
// As an example, consider the following IR.
|
||||
//
|
||||
// %0 = tensor.extract_slice %t[%a, %b][%c, %d][1, 1] {inplace= [true] }
|
||||
// %1 = linalg.fill %cst, %0 {inplace= [true] }
|
||||
// %2 = tensor.insert_slice %1 into %t[%a, %b][%c, %d][1, 1]
|
||||
// {inplace= [true] }
|
||||
|
||||
// TODO: Use insertSliceOp.getDestOpOperand etc. when available.
|
||||
if (uRead == &insertSliceOp->getOpOperand(1) /*dest*/ &&
|
||||
hasMatchingExtractSliceOp(aliasInfo, uConflictingWrite->get(),
|
||||
insertSliceOp))
|
||||
// Case 1: The main insight is that InsertSliceOp reads only part of
|
||||
// the destination tensor. The overwritten area is not read. If
|
||||
// uConflictingWrite writes into exactly the memory location that is
|
||||
// being read by uRead, this is not a conflict.
|
||||
//
|
||||
// In the above example:
|
||||
// uRead = OpOperand 1 (%t) of tensor.insert_slice
|
||||
// uConflictingWrite = OpOperand 1 (%0) of linalg.fill
|
||||
//
|
||||
// The read of %t does not conflict with the write of the FillOp
|
||||
// (same aliases!) because the area that the FillOp operates on is
|
||||
// exactly the one that is *not* read via %t.
|
||||
continue;
|
||||
|
||||
if (uRead == &insertSliceOp->getOpOperand(0) /*source*/ &&
|
||||
uConflictingWrite == &insertSliceOp->getOpOperand(1) /*dest*/ &&
|
||||
hasMatchingExtractSliceOp(aliasInfo, uRead->get(), insertSliceOp))
|
||||
// Case 2: The read of the source tensor and the write to the dest
|
||||
// tensor via an InsertSliceOp is not a conflict if the read is
|
||||
// reading exactly that part of an equivalent tensor that the
|
||||
// InsertSliceOp is writing.
|
||||
//
|
||||
// In the above example:
|
||||
// uRead = OpOperand 0 (%1) of tensor.insert_slice
|
||||
// uConflictingWrite = OpOperand 1 (%t) of tensor.insert_slice
|
||||
continue;
|
||||
}
|
||||
|
||||
// If uConflictingWrite is an InsertSliceOp...
|
||||
if (auto insertSliceOp = dyn_cast<InsertSliceOp>(conflictingWritingOp))
|
||||
// As an example, consider the following IR.
|
||||
//
|
||||
// %0 = tensor.extract_slice %t[%a, %b][%c, %d][1, 1] {inplace= [true] }
|
||||
// %1 = linalg.fill %cst, %0 {inplace= [true] }
|
||||
// %2 = tensor.insert_slice %1 into %t[%a, %b][%c, %d][1, 1]
|
||||
// {inplace= [true] }
|
||||
// %3 = vector.transfer_read %1, %cst
|
||||
//
|
||||
// In the above example:
|
||||
// uRead = OpOperand 0 (%1) of vector.transfer_read
|
||||
// uConflictingWrite = OpOperand 1 (%t) of tensor.insert_slice
|
||||
// lastWrite = %1
|
||||
//
|
||||
// This is not a conflict because the InsertSliceOp overwrites the
|
||||
// memory segment of %1 with the exact same data. (Effectively, there
|
||||
// is no memory write here.)
|
||||
if (uConflictingWrite == &insertSliceOp->getOpOperand(1) /*dest*/ &&
|
||||
aliasInfo.areEquivalentBufferizedValues(uRead->get(),
|
||||
insertSliceOp.source()) &&
|
||||
hasMatchingExtractSliceOp(aliasInfo, insertSliceOp.source(),
|
||||
insertSliceOp))
|
||||
continue;
|
||||
|
||||
// All requirements are met. Conflict found!
|
||||
LDBG("CONFLICT CONFIRMED!\n\n");
|
||||
return true;
|
||||
|
@ -2321,6 +2236,24 @@ struct ExtractOpInterface
|
|||
}
|
||||
};
|
||||
|
||||
/// Return true if the (ExtractSliceOp, InsertSliceOp) pair match (i.e.
|
||||
/// equivalent operand / result and same offset/sizes/strides specification).
|
||||
///
|
||||
/// This is one particular type of relationship between ops on tensors that
|
||||
/// reduce to an equivalence on buffers. This should be generalized and
|
||||
/// exposed as interfaces on the proper types.
|
||||
static bool
|
||||
areEquivalentExtractSliceOps(const BufferizationAliasInfo &aliasInfo,
|
||||
ExtractSliceOp st, InsertSliceOp sti) {
|
||||
if (!st || !sti)
|
||||
return false;
|
||||
if (!aliasInfo.areEquivalentBufferizedValues(st.source(), sti.dest()))
|
||||
return false;
|
||||
if (!sameOffsetsSizesAndStrides(st, sti, isEqualConstantIntOrValue))
|
||||
return false;
|
||||
return true;
|
||||
}
|
||||
|
||||
/// Return true if the source of a `insertSliceOp` bufferizes to an
|
||||
/// equivalent ExtractSliceOp that bufferizes inplace.
|
||||
static bool isSourceEquivalentToAMatchingInplaceExtractSliceOp(
|
||||
|
@ -2345,6 +2278,21 @@ static bool isSourceEquivalentToAMatchingInplaceExtractSliceOp(
|
|||
return foundOp;
|
||||
}
|
||||
|
||||
/// Return true if `value` is originating from an ExtractSliceOp that matches
|
||||
/// the given InsertSliceOp.
|
||||
static bool hasMatchingExtractSliceOp(const BufferizationAliasInfo &aliasInfo,
|
||||
Value value, InsertSliceOp insertOp) {
|
||||
auto condition = [&](Value val) {
|
||||
if (auto extractOp = val.getDefiningOp<ExtractSliceOp>())
|
||||
if (areEquivalentExtractSliceOps(aliasInfo, extractOp, insertOp))
|
||||
return true;
|
||||
return false;
|
||||
};
|
||||
|
||||
return llvm::all_of(findValueInReverseUseDefChain(value, condition),
|
||||
condition);
|
||||
}
|
||||
|
||||
struct InsertSliceOpInterface
|
||||
: public BufferizableOpInterface::ExternalModel<InsertSliceOpInterface,
|
||||
tensor::InsertSliceOp> {
|
||||
|
@ -2371,6 +2319,82 @@ struct InsertSliceOpInterface
|
|||
return BufferRelation::Equivalent;
|
||||
}
|
||||
|
||||
bool isNotConflicting(Operation *op, OpOperand *uRead,
|
||||
OpOperand *uConflictingWrite,
|
||||
const BufferizationAliasInfo &aliasInfo) const {
|
||||
Operation *readingOp = uRead->getOwner();
|
||||
Operation *conflictingWritingOp = uConflictingWrite->getOwner();
|
||||
|
||||
// Special rules for matching ExtractSliceOp/InsertSliceOp pairs. If
|
||||
// uRead is an InsertSliceOp...
|
||||
if (auto insertSliceOp = dyn_cast<InsertSliceOp>(readingOp)) {
|
||||
// As an example, consider the following IR.
|
||||
//
|
||||
// %0 = tensor.extract_slice %t[%a, %b][%c, %d][1, 1] {inplace = [true] }
|
||||
// %1 = linalg.fill %cst, %0 {inplace= [true] }
|
||||
// %2 = tensor.insert_slice %1 into %t[%a, %b][%c, %d][1, 1]
|
||||
// {inplace= [true] }
|
||||
|
||||
// TODO: Use insertSliceOp.getDestOpOperand etc. when available.
|
||||
if (uRead == &insertSliceOp->getOpOperand(1) /*dest*/ &&
|
||||
hasMatchingExtractSliceOp(aliasInfo, uConflictingWrite->get(),
|
||||
insertSliceOp))
|
||||
// Case 1: The main insight is that InsertSliceOp reads only part of
|
||||
// the destination tensor. The overwritten area is not read. If
|
||||
// uConflictingWrite writes into exactly the memory location that is
|
||||
// being read by uRead, this is not a conflict.
|
||||
//
|
||||
// In the above example:
|
||||
// uRead = OpOperand 1 (%t) of tensor.insert_slice
|
||||
// uConflictingWrite = OpOperand 1 (%0) of linalg.fill
|
||||
//
|
||||
// The read of %t does not conflict with the write of the FillOp
|
||||
// (same aliases!) because the area that the FillOp operates on is
|
||||
// exactly the one that is *not* read via %t.
|
||||
return true;
|
||||
|
||||
if (uRead == &insertSliceOp->getOpOperand(0) /*source*/ &&
|
||||
uConflictingWrite == &insertSliceOp->getOpOperand(1) /*dest*/ &&
|
||||
hasMatchingExtractSliceOp(aliasInfo, uRead->get(), insertSliceOp))
|
||||
// Case 2: The read of the source tensor and the write to the dest
|
||||
// tensor via an InsertSliceOp is not a conflict if the read is
|
||||
// reading exactly that part of an equivalent tensor that the
|
||||
// InsertSliceOp is writing.
|
||||
//
|
||||
// In the above example:
|
||||
// uRead = OpOperand 0 (%1) of tensor.insert_slice
|
||||
// uConflictingWrite = OpOperand 1 (%t) of tensor.insert_slice
|
||||
return true;
|
||||
}
|
||||
|
||||
// If uConflictingWrite is an InsertSliceOp...
|
||||
if (auto insertSliceOp = dyn_cast<InsertSliceOp>(conflictingWritingOp))
|
||||
// As an example, consider the following IR.
|
||||
//
|
||||
// %0 = tensor.extract_slice %t[%a, %b][%c, %d][1, 1] {inplace = [true] }
|
||||
// %1 = linalg.fill %cst, %0 {inplace= [true] }
|
||||
// %2 = tensor.insert_slice %1 into %t[%a, %b][%c, %d][1, 1]
|
||||
// {inplace= [true] }
|
||||
// %3 = vector.transfer_read %1, %cst
|
||||
//
|
||||
// In the above example:
|
||||
// uRead = OpOperand 0 (%1) of vector.transfer_read
|
||||
// uConflictingWrite = OpOperand 1 (%t) of tensor.insert_slice
|
||||
// lastWrite = %1
|
||||
//
|
||||
// This is not a conflict because the InsertSliceOp overwrites the
|
||||
// memory segment of %1 with the exact same data. (Effectively, there
|
||||
// is no memory write here.)
|
||||
if (uConflictingWrite == &insertSliceOp->getOpOperand(1) /*dest*/ &&
|
||||
aliasInfo.areEquivalentBufferizedValues(uRead->get(),
|
||||
insertSliceOp.source()) &&
|
||||
hasMatchingExtractSliceOp(aliasInfo, insertSliceOp.source(),
|
||||
insertSliceOp))
|
||||
return true;
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
LogicalResult bufferize(Operation *op, OpBuilder &b,
|
||||
BufferizationState &state) const {
|
||||
// insert_slice ops arise from tiling and bufferizing them out-of-place is
|
||||
|
|
Loading…
Reference in New Issue