[mlir][linalg][bufferize][NFC] Decouple ComprehensiveBufferize from tensor dialect

Add a new BufferizableOpInterface method `isNotConflicting` that can be used to implement custom analysis rules.

Differential Revision: https://reviews.llvm.org/D113961
This commit is contained in:
Matthias Springer 2021-11-18 16:10:10 +09:00
parent 0c7890c844
commit 26e90423f4
2 changed files with 147 additions and 100 deletions

View File

@ -215,6 +215,29 @@ def BufferizableOpInterface : OpInterface<"BufferizableOpInterface"> {
/*defaultImplementation=*/[{
return false;
}]
>,
InterfaceMethod<
/*desc=*/[{
Return `true` if the `uRead` and `uWrite` do not constitute a RaW
conflict. If they are conflicting or if it is unknown whether they are
conflicting, return `false`. This method will never be called with
OpOperands that do not have a tensor type. At least one of the two
given OpOperands belongs to this operation.
This method can be implemented to specify custom RaW analysis rules.
If this method returns `true` the given OpOperands are not considered
to be conflicting and do not force out-of-place bufferization. (There
may still be other conflicts that do.)
}],
/*retType=*/"bool",
/*methodName=*/"isNotConflicting",
/*args=*/(ins "OpOperand *":$uRead,
"OpOperand *":$uWrite,
"const BufferizationAliasInfo &":$aliasInfo),
/*methodBody=*/"",
/*defaultImplementation=*/[{
return false;
}]
>
];

View File

@ -281,24 +281,6 @@ static std::string printValueInfo(Value value, bool prefix) {
// Bufferization-specific alias analysis.
//===----------------------------------------------------------------------===//
/// Return true if the (ExtractSliceOp, InsertSliceOp) pair match (i.e.
/// equivalent operand / result and same offset/sizes/strides specification).
///
/// This is one particular type of relationship between ops on tensors that
/// reduce to an equivalence on buffers. This should be generalized and
/// exposed as interfaces on the proper types.
static bool
areEquivalentExtractSliceOps(const BufferizationAliasInfo &aliasInfo,
ExtractSliceOp st, InsertSliceOp sti) {
if (!st || !sti)
return false;
if (!aliasInfo.areEquivalentBufferizedValues(st.source(), sti.dest()))
return false;
if (!sameOffsetsSizesAndStrides(st, sti, isEqualConstantIntOrValue))
return false;
return true;
}
/// Return true if opOperand has been decided to bufferize in-place.
static bool isInplaceMemoryWrite(OpOperand &opOperand,
const BufferizationAliasInfo &aliasInfo) {
@ -368,21 +350,6 @@ static bool aliasesInPlaceWrite(Value value,
return foundInplaceWrite;
}
/// Return true if `value` is originating from an ExtractSliceOp that matches
/// the given InsertSliceOp.
static bool hasMatchingExtractSliceOp(const BufferizationAliasInfo &aliasInfo,
Value value, InsertSliceOp insertOp) {
auto condition = [&](Value val) {
if (auto extractOp = val.getDefiningOp<ExtractSliceOp>())
if (areEquivalentExtractSliceOps(aliasInfo, extractOp, insertOp))
return true;
return false;
};
return llvm::all_of(findValueInReverseUseDefChain(value, condition),
condition);
}
/// Return true if `a` happens before `b`, i.e., `a` or one of its ancestors
/// properly dominates `b` and `b` is not inside `a`.
static bool happensBefore(Operation *a, Operation *b,
@ -450,6 +417,21 @@ hasReadAfterWriteInterference(const DenseSet<OpOperand *> &usesRead,
if (uConflictingWrite == uRead)
continue;
// No conflict if the op interface says so.
if (auto bufferizableOp = dyn_cast<BufferizableOpInterface>(readingOp))
if (bufferizableOp.isNotConflicting(uRead, uConflictingWrite,
aliasInfo))
continue;
if (conflictingWritingOp != readingOp)
if (auto bufferizableOp =
dyn_cast<BufferizableOpInterface>(conflictingWritingOp))
if (bufferizableOp.isNotConflicting(uRead, uConflictingWrite,
aliasInfo))
continue;
// Special rules for branches.
// TODO: Use an interface.
if (scf::insideMutuallyExclusiveBranches(readingOp, conflictingWritingOp))
continue;
@ -478,73 +460,6 @@ hasReadAfterWriteInterference(const DenseSet<OpOperand *> &usesRead,
if (getAliasingOpResult(*uConflictingWrite) == lastWrite)
continue;
// Special rules for matching ExtractSliceOp/InsertSliceOp pairs. If
// uRead is an InsertSliceOp...
if (auto insertSliceOp = dyn_cast<InsertSliceOp>(readingOp)) {
// As an example, consider the following IR.
//
// %0 = tensor.extract_slice %t[%a, %b][%c, %d][1, 1] {inplace= [true] }
// %1 = linalg.fill %cst, %0 {inplace= [true] }
// %2 = tensor.insert_slice %1 into %t[%a, %b][%c, %d][1, 1]
// {inplace= [true] }
// TODO: Use insertSliceOp.getDestOpOperand etc. when available.
if (uRead == &insertSliceOp->getOpOperand(1) /*dest*/ &&
hasMatchingExtractSliceOp(aliasInfo, uConflictingWrite->get(),
insertSliceOp))
// Case 1: The main insight is that InsertSliceOp reads only part of
// the destination tensor. The overwritten area is not read. If
// uConflictingWrite writes into exactly the memory location that is
// being read by uRead, this is not a conflict.
//
// In the above example:
// uRead = OpOperand 1 (%t) of tensor.insert_slice
// uConflictingWrite = OpOperand 1 (%0) of linalg.fill
//
// The read of %t does not conflict with the write of the FillOp
// (same aliases!) because the area that the FillOp operates on is
// exactly the one that is *not* read via %t.
continue;
if (uRead == &insertSliceOp->getOpOperand(0) /*source*/ &&
uConflictingWrite == &insertSliceOp->getOpOperand(1) /*dest*/ &&
hasMatchingExtractSliceOp(aliasInfo, uRead->get(), insertSliceOp))
// Case 2: The read of the source tensor and the write to the dest
// tensor via an InsertSliceOp is not a conflict if the read is
// reading exactly that part of an equivalent tensor that the
// InsertSliceOp is writing.
//
// In the above example:
// uRead = OpOperand 0 (%1) of tensor.insert_slice
// uConflictingWrite = OpOperand 1 (%t) of tensor.insert_slice
continue;
}
// If uConflictingWrite is an InsertSliceOp...
if (auto insertSliceOp = dyn_cast<InsertSliceOp>(conflictingWritingOp))
// As an example, consider the following IR.
//
// %0 = tensor.extract_slice %t[%a, %b][%c, %d][1, 1] {inplace= [true] }
// %1 = linalg.fill %cst, %0 {inplace= [true] }
// %2 = tensor.insert_slice %1 into %t[%a, %b][%c, %d][1, 1]
// {inplace= [true] }
// %3 = vector.transfer_read %1, %cst
//
// In the above example:
// uRead = OpOperand 0 (%1) of vector.transfer_read
// uConflictingWrite = OpOperand 1 (%t) of tensor.insert_slice
// lastWrite = %1
//
// This is not a conflict because the InsertSliceOp overwrites the
// memory segment of %1 with the exact same data. (Effectively, there
// is no memory write here.)
if (uConflictingWrite == &insertSliceOp->getOpOperand(1) /*dest*/ &&
aliasInfo.areEquivalentBufferizedValues(uRead->get(),
insertSliceOp.source()) &&
hasMatchingExtractSliceOp(aliasInfo, insertSliceOp.source(),
insertSliceOp))
continue;
// All requirements are met. Conflict found!
LDBG("CONFLICT CONFIRMED!\n\n");
return true;
@ -2321,6 +2236,24 @@ struct ExtractOpInterface
}
};
/// Return true if the (ExtractSliceOp, InsertSliceOp) pair match (i.e.
/// equivalent operand / result and same offset/sizes/strides specification).
///
/// This is one particular type of relationship between ops on tensors that
/// reduce to an equivalence on buffers. This should be generalized and
/// exposed as interfaces on the proper types.
static bool
areEquivalentExtractSliceOps(const BufferizationAliasInfo &aliasInfo,
ExtractSliceOp st, InsertSliceOp sti) {
if (!st || !sti)
return false;
if (!aliasInfo.areEquivalentBufferizedValues(st.source(), sti.dest()))
return false;
if (!sameOffsetsSizesAndStrides(st, sti, isEqualConstantIntOrValue))
return false;
return true;
}
/// Return true if the source of a `insertSliceOp` bufferizes to an
/// equivalent ExtractSliceOp that bufferizes inplace.
static bool isSourceEquivalentToAMatchingInplaceExtractSliceOp(
@ -2345,6 +2278,21 @@ static bool isSourceEquivalentToAMatchingInplaceExtractSliceOp(
return foundOp;
}
/// Return true if `value` is originating from an ExtractSliceOp that matches
/// the given InsertSliceOp.
static bool hasMatchingExtractSliceOp(const BufferizationAliasInfo &aliasInfo,
Value value, InsertSliceOp insertOp) {
auto condition = [&](Value val) {
if (auto extractOp = val.getDefiningOp<ExtractSliceOp>())
if (areEquivalentExtractSliceOps(aliasInfo, extractOp, insertOp))
return true;
return false;
};
return llvm::all_of(findValueInReverseUseDefChain(value, condition),
condition);
}
struct InsertSliceOpInterface
: public BufferizableOpInterface::ExternalModel<InsertSliceOpInterface,
tensor::InsertSliceOp> {
@ -2371,6 +2319,82 @@ struct InsertSliceOpInterface
return BufferRelation::Equivalent;
}
bool isNotConflicting(Operation *op, OpOperand *uRead,
OpOperand *uConflictingWrite,
const BufferizationAliasInfo &aliasInfo) const {
Operation *readingOp = uRead->getOwner();
Operation *conflictingWritingOp = uConflictingWrite->getOwner();
// Special rules for matching ExtractSliceOp/InsertSliceOp pairs. If
// uRead is an InsertSliceOp...
if (auto insertSliceOp = dyn_cast<InsertSliceOp>(readingOp)) {
// As an example, consider the following IR.
//
// %0 = tensor.extract_slice %t[%a, %b][%c, %d][1, 1] {inplace = [true] }
// %1 = linalg.fill %cst, %0 {inplace= [true] }
// %2 = tensor.insert_slice %1 into %t[%a, %b][%c, %d][1, 1]
// {inplace= [true] }
// TODO: Use insertSliceOp.getDestOpOperand etc. when available.
if (uRead == &insertSliceOp->getOpOperand(1) /*dest*/ &&
hasMatchingExtractSliceOp(aliasInfo, uConflictingWrite->get(),
insertSliceOp))
// Case 1: The main insight is that InsertSliceOp reads only part of
// the destination tensor. The overwritten area is not read. If
// uConflictingWrite writes into exactly the memory location that is
// being read by uRead, this is not a conflict.
//
// In the above example:
// uRead = OpOperand 1 (%t) of tensor.insert_slice
// uConflictingWrite = OpOperand 1 (%0) of linalg.fill
//
// The read of %t does not conflict with the write of the FillOp
// (same aliases!) because the area that the FillOp operates on is
// exactly the one that is *not* read via %t.
return true;
if (uRead == &insertSliceOp->getOpOperand(0) /*source*/ &&
uConflictingWrite == &insertSliceOp->getOpOperand(1) /*dest*/ &&
hasMatchingExtractSliceOp(aliasInfo, uRead->get(), insertSliceOp))
// Case 2: The read of the source tensor and the write to the dest
// tensor via an InsertSliceOp is not a conflict if the read is
// reading exactly that part of an equivalent tensor that the
// InsertSliceOp is writing.
//
// In the above example:
// uRead = OpOperand 0 (%1) of tensor.insert_slice
// uConflictingWrite = OpOperand 1 (%t) of tensor.insert_slice
return true;
}
// If uConflictingWrite is an InsertSliceOp...
if (auto insertSliceOp = dyn_cast<InsertSliceOp>(conflictingWritingOp))
// As an example, consider the following IR.
//
// %0 = tensor.extract_slice %t[%a, %b][%c, %d][1, 1] {inplace = [true] }
// %1 = linalg.fill %cst, %0 {inplace= [true] }
// %2 = tensor.insert_slice %1 into %t[%a, %b][%c, %d][1, 1]
// {inplace= [true] }
// %3 = vector.transfer_read %1, %cst
//
// In the above example:
// uRead = OpOperand 0 (%1) of vector.transfer_read
// uConflictingWrite = OpOperand 1 (%t) of tensor.insert_slice
// lastWrite = %1
//
// This is not a conflict because the InsertSliceOp overwrites the
// memory segment of %1 with the exact same data. (Effectively, there
// is no memory write here.)
if (uConflictingWrite == &insertSliceOp->getOpOperand(1) /*dest*/ &&
aliasInfo.areEquivalentBufferizedValues(uRead->get(),
insertSliceOp.source()) &&
hasMatchingExtractSliceOp(aliasInfo, insertSliceOp.source(),
insertSliceOp))
return true;
return false;
}
LogicalResult bufferize(Operation *op, OpBuilder &b,
BufferizationState &state) const {
// insert_slice ops arise from tiling and bufferizing them out-of-place is