forked from OSchip/llvm-project
[mlir][tensor][bufferize][NFC] Remove duplicate code
InsertSliceOp and ParallelInsertSliceOp are very similar and can share some of the bufferization analysis code. Differential Revision: https://reviews.llvm.org/D130465
This commit is contained in:
parent
8cbf4a386b
commit
1defec8730
|
@ -552,29 +552,30 @@ struct InsertOpInterface
|
|||
|
||||
/// Return true if the (ExtractSliceOp, InsertSliceOp) pair match (i.e.
|
||||
/// equivalent operand / result and same offset/sizes/strides specification).
|
||||
///
|
||||
/// This is one particular type of relationship between ops on tensors that
|
||||
/// reduce to an equivalence on buffers. This should be generalized and
|
||||
/// exposed as interfaces on the proper types.
|
||||
template <typename OpTy>
|
||||
static bool areEquivalentExtractSliceOps(const AnalysisState &state,
|
||||
ExtractSliceOp st, InsertSliceOp sti) {
|
||||
if (!st || !sti)
|
||||
ExtractSliceOp extractSliceOp,
|
||||
OpTy insertSliceOp) {
|
||||
if (!extractSliceOp || !insertSliceOp)
|
||||
return false;
|
||||
if (sti != sti &&
|
||||
!state.areEquivalentBufferizedValues(st.getSource(), sti.getDest()))
|
||||
if (extractSliceOp != insertSliceOp &&
|
||||
!state.areEquivalentBufferizedValues(extractSliceOp.getSource(),
|
||||
insertSliceOp.getDest()))
|
||||
return false;
|
||||
if (!sameOffsetsSizesAndStrides(st, sti, isEqualConstantIntOrValue))
|
||||
if (!sameOffsetsSizesAndStrides(extractSliceOp, insertSliceOp,
|
||||
isEqualConstantIntOrValue))
|
||||
return false;
|
||||
return true;
|
||||
}
|
||||
|
||||
/// Return true if `value` is originating from an ExtractSliceOp that matches
|
||||
/// the given InsertSliceOp.
|
||||
template <typename OpTy>
|
||||
static bool hasMatchingExtractSliceOp(const AnalysisState &state, Value value,
|
||||
InsertSliceOp insertOp) {
|
||||
OpTy insertSliceOp) {
|
||||
auto condition = [&](Value val) {
|
||||
if (auto extractOp = val.getDefiningOp<ExtractSliceOp>())
|
||||
if (areEquivalentExtractSliceOps(state, extractOp, insertOp))
|
||||
if (auto extractSliceOp = val.getDefiningOp<ExtractSliceOp>())
|
||||
if (areEquivalentExtractSliceOps(state, extractSliceOp, insertSliceOp))
|
||||
return true;
|
||||
return false;
|
||||
};
|
||||
|
@ -583,6 +584,83 @@ static bool hasMatchingExtractSliceOp(const AnalysisState &state, Value value,
|
|||
condition);
|
||||
}
|
||||
|
||||
template <typename OpTy>
|
||||
static bool isNotConflictingInsertSliceLikeOp(Operation *op, OpOperand *uRead,
|
||||
OpOperand *uConflictingWrite,
|
||||
const AnalysisState &state) {
|
||||
Operation *readingOp = uRead->getOwner();
|
||||
Operation *conflictingWritingOp = uConflictingWrite->getOwner();
|
||||
|
||||
// Special rules for matching ExtractSliceOp/InsertSliceOp pairs. If
|
||||
// uRead is an InsertSliceOp...
|
||||
if (auto insertSliceOp = dyn_cast<OpTy>(readingOp)) {
|
||||
// As an example, consider the following IR.
|
||||
//
|
||||
// %0 = tensor.extract_slice %t[%a, %b][%c, %d][1, 1] {inplace = [true] }
|
||||
// %1 = linalg.fill %cst, %0 {inplace= [true] }
|
||||
// %2 = tensor.insert_slice %1 into %t[%a, %b][%c, %d][1, 1]
|
||||
// {inplace= [true] }
|
||||
|
||||
// TODO: Use insertSliceOp.getDestOpOperand etc. when available.
|
||||
if (uRead == &insertSliceOp->getOpOperand(1) /*dest*/ &&
|
||||
hasMatchingExtractSliceOp(state, uConflictingWrite->get(),
|
||||
insertSliceOp))
|
||||
// Case 1: The main insight is that InsertSliceOp reads only part of
|
||||
// the destination tensor. The overwritten area is not read. If
|
||||
// uConflictingWrite writes into exactly the memory location that is
|
||||
// being read by uRead, this is not a conflict.
|
||||
//
|
||||
// In the above example:
|
||||
// uRead = OpOperand 1 (%t) of tensor.insert_slice
|
||||
// uConflictingWrite = OpOperand 1 (%0) of linalg.fill
|
||||
//
|
||||
// The read of %t does not conflict with the write of the FillOp
|
||||
// (same aliases!) because the area that the FillOp operates on is
|
||||
// exactly the one that is *not* read via %t.
|
||||
return true;
|
||||
|
||||
if (uRead == &insertSliceOp->getOpOperand(0) /*source*/ &&
|
||||
uConflictingWrite == &insertSliceOp->getOpOperand(1) /*dest*/ &&
|
||||
hasMatchingExtractSliceOp(state, uRead->get(), insertSliceOp))
|
||||
// Case 2: The read of the source tensor and the write to the dest
|
||||
// tensor via an InsertSliceOp is not a conflict if the read is
|
||||
// reading exactly that part of an equivalent tensor that the
|
||||
// InsertSliceOp is writing.
|
||||
//
|
||||
// In the above example:
|
||||
// uRead = OpOperand 0 (%1) of tensor.insert_slice
|
||||
// uConflictingWrite = OpOperand 1 (%t) of tensor.insert_slice
|
||||
return true;
|
||||
}
|
||||
|
||||
// If uConflictingWrite is an InsertSliceOp...
|
||||
if (auto insertSliceOp = dyn_cast<OpTy>(conflictingWritingOp))
|
||||
// As an example, consider the following IR.
|
||||
//
|
||||
// %0 = tensor.extract_slice %t[%a, %b][%c, %d][1, 1] {inplace = [true] }
|
||||
// %1 = linalg.fill %cst, %0 {inplace= [true] }
|
||||
// %2 = tensor.insert_slice %1 into %t[%a, %b][%c, %d][1, 1]
|
||||
// {inplace= [true] }
|
||||
// %3 = vector.transfer_read %1, %cst
|
||||
//
|
||||
// In the above example:
|
||||
// uRead = OpOperand 0 (%1) of vector.transfer_read
|
||||
// uConflictingWrite = OpOperand 1 (%t) of tensor.insert_slice
|
||||
// lastWrite = %1
|
||||
//
|
||||
// This is not a conflict because the InsertSliceOp overwrites the
|
||||
// memory segment of %1 with the exact same data. (Effectively, there
|
||||
// is no memory write here.)
|
||||
if (uConflictingWrite == &insertSliceOp->getOpOperand(1) /*dest*/ &&
|
||||
state.areEquivalentBufferizedValues(uRead->get(),
|
||||
insertSliceOp.getSource()) &&
|
||||
hasMatchingExtractSliceOp(state, insertSliceOp.getSource(),
|
||||
insertSliceOp))
|
||||
return true;
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
/// Bufferization of tensor.insert_slice. Replace with a memory copy. Under
|
||||
/// certain circumstances, this op can also be a no-op.
|
||||
struct InsertSliceOpInterface
|
||||
|
@ -613,77 +691,8 @@ struct InsertSliceOpInterface
|
|||
bool isNotConflicting(Operation *op, OpOperand *uRead,
|
||||
OpOperand *uConflictingWrite,
|
||||
const AnalysisState &state) const {
|
||||
Operation *readingOp = uRead->getOwner();
|
||||
Operation *conflictingWritingOp = uConflictingWrite->getOwner();
|
||||
|
||||
// Special rules for matching ExtractSliceOp/InsertSliceOp pairs. If
|
||||
// uRead is an InsertSliceOp...
|
||||
if (auto insertSliceOp = dyn_cast<InsertSliceOp>(readingOp)) {
|
||||
// As an example, consider the following IR.
|
||||
//
|
||||
// %0 = tensor.extract_slice %t[%a, %b][%c, %d][1, 1] {inplace = [true] }
|
||||
// %1 = linalg.fill %cst, %0 {inplace= [true] }
|
||||
// %2 = tensor.insert_slice %1 into %t[%a, %b][%c, %d][1, 1]
|
||||
// {inplace= [true] }
|
||||
|
||||
// TODO: Use insertSliceOp.getDestOpOperand etc. when available.
|
||||
if (uRead == &insertSliceOp->getOpOperand(1) /*dest*/ &&
|
||||
hasMatchingExtractSliceOp(state, uConflictingWrite->get(),
|
||||
insertSliceOp))
|
||||
// Case 1: The main insight is that InsertSliceOp reads only part of
|
||||
// the destination tensor. The overwritten area is not read. If
|
||||
// uConflictingWrite writes into exactly the memory location that is
|
||||
// being read by uRead, this is not a conflict.
|
||||
//
|
||||
// In the above example:
|
||||
// uRead = OpOperand 1 (%t) of tensor.insert_slice
|
||||
// uConflictingWrite = OpOperand 1 (%0) of linalg.fill
|
||||
//
|
||||
// The read of %t does not conflict with the write of the FillOp
|
||||
// (same aliases!) because the area that the FillOp operates on is
|
||||
// exactly the one that is *not* read via %t.
|
||||
return true;
|
||||
|
||||
if (uRead == &insertSliceOp->getOpOperand(0) /*source*/ &&
|
||||
uConflictingWrite == &insertSliceOp->getOpOperand(1) /*dest*/ &&
|
||||
hasMatchingExtractSliceOp(state, uRead->get(), insertSliceOp))
|
||||
// Case 2: The read of the source tensor and the write to the dest
|
||||
// tensor via an InsertSliceOp is not a conflict if the read is
|
||||
// reading exactly that part of an equivalent tensor that the
|
||||
// InsertSliceOp is writing.
|
||||
//
|
||||
// In the above example:
|
||||
// uRead = OpOperand 0 (%1) of tensor.insert_slice
|
||||
// uConflictingWrite = OpOperand 1 (%t) of tensor.insert_slice
|
||||
return true;
|
||||
}
|
||||
|
||||
// If uConflictingWrite is an InsertSliceOp...
|
||||
if (auto insertSliceOp = dyn_cast<InsertSliceOp>(conflictingWritingOp))
|
||||
// As an example, consider the following IR.
|
||||
//
|
||||
// %0 = tensor.extract_slice %t[%a, %b][%c, %d][1, 1] {inplace = [true] }
|
||||
// %1 = linalg.fill %cst, %0 {inplace= [true] }
|
||||
// %2 = tensor.insert_slice %1 into %t[%a, %b][%c, %d][1, 1]
|
||||
// {inplace= [true] }
|
||||
// %3 = vector.transfer_read %1, %cst
|
||||
//
|
||||
// In the above example:
|
||||
// uRead = OpOperand 0 (%1) of vector.transfer_read
|
||||
// uConflictingWrite = OpOperand 1 (%t) of tensor.insert_slice
|
||||
// lastWrite = %1
|
||||
//
|
||||
// This is not a conflict because the InsertSliceOp overwrites the
|
||||
// memory segment of %1 with the exact same data. (Effectively, there
|
||||
// is no memory write here.)
|
||||
if (uConflictingWrite == &insertSliceOp->getOpOperand(1) /*dest*/ &&
|
||||
state.areEquivalentBufferizedValues(uRead->get(),
|
||||
insertSliceOp.getSource()) &&
|
||||
hasMatchingExtractSliceOp(state, insertSliceOp.getSource(),
|
||||
insertSliceOp))
|
||||
return true;
|
||||
|
||||
return false;
|
||||
return isNotConflictingInsertSliceLikeOp<tensor::InsertSliceOp>(
|
||||
op, uRead, uConflictingWrite, state);
|
||||
}
|
||||
|
||||
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
|
||||
|
@ -805,36 +814,6 @@ struct ReshapeOpInterface
|
|||
}
|
||||
};
|
||||
|
||||
/// Return true if the (ExtractSliceOp, ParallelInsertSliceOp) pair match (i.e.
|
||||
/// equivalent operand / result and same offset/sizes/strides specification).
|
||||
static bool areEquivalentExtractSliceOps(const AnalysisState &state,
|
||||
ExtractSliceOp st,
|
||||
ParallelInsertSliceOp sti) {
|
||||
if (!st || !sti)
|
||||
return false;
|
||||
if (st != sti &&
|
||||
!state.areEquivalentBufferizedValues(st.getSource(), sti.getDest()))
|
||||
return false;
|
||||
if (!sameOffsetsSizesAndStrides(st, sti, isEqualConstantIntOrValue))
|
||||
return false;
|
||||
return true;
|
||||
}
|
||||
|
||||
/// Return true if `value` is originating from an ExtractSliceOp that matches
|
||||
/// the given InsertSliceOp.
|
||||
static bool hasMatchingExtractSliceOp(const AnalysisState &state, Value value,
|
||||
ParallelInsertSliceOp insertOp) {
|
||||
auto condition = [&](Value val) {
|
||||
if (auto extractOp = val.getDefiningOp<ExtractSliceOp>())
|
||||
if (areEquivalentExtractSliceOps(state, extractOp, insertOp))
|
||||
return true;
|
||||
return false;
|
||||
};
|
||||
|
||||
return llvm::all_of(state.findValueInReverseUseDefChain(value, condition),
|
||||
condition);
|
||||
}
|
||||
|
||||
/// Analysis of ParallelInsertSliceOp.
|
||||
struct ParallelInsertSliceOpInterface
|
||||
: public BufferizableOpInterface::ExternalModel<
|
||||
|
@ -978,83 +957,11 @@ struct ParallelInsertSliceOpInterface
|
|||
return success();
|
||||
}
|
||||
|
||||
// TODO: This is copied from TensorInterfaceImpl.cpp. Find a way to share
|
||||
// the code.
|
||||
bool isNotConflicting(Operation *op, OpOperand *uRead,
|
||||
OpOperand *uConflictingWrite,
|
||||
const AnalysisState &state) const {
|
||||
Operation *readingOp = uRead->getOwner();
|
||||
Operation *conflictingWritingOp = uConflictingWrite->getOwner();
|
||||
|
||||
// Special rules for matching ExtractSliceOp/InsertSliceOp pairs. If
|
||||
// uRead is an InsertSliceOp...
|
||||
if (auto insertSliceOp = dyn_cast<ParallelInsertSliceOp>(readingOp)) {
|
||||
// As an example, consider the following IR.
|
||||
//
|
||||
// %0 = tensor.extract_slice %t[%a, %b][%c, %d][1, 1] {inplace = [true] }
|
||||
// %1 = linalg.fill %cst, %0 {inplace= [true] }
|
||||
// %2 = tensor.insert_slice %1 into %t[%a, %b][%c, %d][1, 1]
|
||||
// {inplace= [true] }
|
||||
|
||||
// TODO: Use insertSliceOp.getDestOpOperand etc. when available.
|
||||
if (uRead == &insertSliceOp->getOpOperand(1) /*dest*/ &&
|
||||
hasMatchingExtractSliceOp(state, uConflictingWrite->get(),
|
||||
insertSliceOp))
|
||||
// Case 1: The main insight is that InsertSliceOp reads only part of
|
||||
// the destination tensor. The overwritten area is not read. If
|
||||
// uConflictingWrite writes into exactly the memory location that is
|
||||
// being read by uRead, this is not a conflict.
|
||||
//
|
||||
// In the above example:
|
||||
// uRead = OpOperand 1 (%t) of tensor.insert_slice
|
||||
// uConflictingWrite = OpOperand 1 (%0) of linalg.fill
|
||||
//
|
||||
// The read of %t does not conflict with the write of the FillOp
|
||||
// (same aliases!) because the area that the FillOp operates on is
|
||||
// exactly the one that is *not* read via %t.
|
||||
return true;
|
||||
|
||||
if (uRead == &insertSliceOp->getOpOperand(0) /*source*/ &&
|
||||
uConflictingWrite == &insertSliceOp->getOpOperand(1) /*dest*/ &&
|
||||
hasMatchingExtractSliceOp(state, uRead->get(), insertSliceOp))
|
||||
// Case 2: The read of the source tensor and the write to the dest
|
||||
// tensor via an InsertSliceOp is not a conflict if the read is
|
||||
// reading exactly that part of an equivalent tensor that the
|
||||
// InsertSliceOp is writing.
|
||||
//
|
||||
// In the above example:
|
||||
// uRead = OpOperand 0 (%1) of tensor.insert_slice
|
||||
// uConflictingWrite = OpOperand 1 (%t) of tensor.insert_slice
|
||||
return true;
|
||||
}
|
||||
|
||||
// If uConflictingWrite is an InsertSliceOp...
|
||||
if (auto insertSliceOp =
|
||||
dyn_cast<ParallelInsertSliceOp>(conflictingWritingOp))
|
||||
// As an example, consider the following IR.
|
||||
//
|
||||
// %0 = tensor.extract_slice %t[%a, %b][%c, %d][1, 1] {inplace = [true] }
|
||||
// %1 = linalg.fill %cst, %0 {inplace= [true] }
|
||||
// %2 = tensor.insert_slice %1 into %t[%a, %b][%c, %d][1, 1]
|
||||
// {inplace= [true] }
|
||||
// %3 = vector.transfer_read %1, %cst
|
||||
//
|
||||
// In the above example:
|
||||
// uRead = OpOperand 0 (%1) of vector.transfer_read
|
||||
// uConflictingWrite = OpOperand 1 (%t) of tensor.insert_slice
|
||||
// lastWrite = %1
|
||||
//
|
||||
// This is not a conflict because the InsertSliceOp overwrites the
|
||||
// memory segment of %1 with the exact same data. (Effectively, there
|
||||
// is no memory write here.)
|
||||
if (uConflictingWrite == &insertSliceOp->getOpOperand(1) /*dest*/ &&
|
||||
state.areEquivalentBufferizedValues(uRead->get(),
|
||||
insertSliceOp.getSource()) &&
|
||||
hasMatchingExtractSliceOp(state, insertSliceOp.getSource(),
|
||||
insertSliceOp))
|
||||
return true;
|
||||
|
||||
return false;
|
||||
return isNotConflictingInsertSliceLikeOp<tensor::ParallelInsertSliceOp>(
|
||||
op, uRead, uConflictingWrite, state);
|
||||
}
|
||||
};
|
||||
|
||||
|
|
Loading…
Reference in New Issue