[mlir][linalg][bufferize] Decouple BufferizationAliasInfo

Move dialect-specific and analysis-specific function out of BufferizationAliasInfo. BufferizationAliasInfo's only job now is to keep track of aliases.

This is in preparation of futher decoupling ComprehensiveBufferize from various dialects.

Differential Revision: https://reviews.llvm.org/D112992
This commit is contained in:
Matthias Springer 2021-11-05 11:40:12 +09:00
parent c8f4005b0c
commit 37317f5bd2
2 changed files with 115 additions and 122 deletions

View File

@ -48,14 +48,6 @@ public:
/// `alias`. Additionally, merge their equivalence classes.
void insertNewBufferEquivalence(Value newValue, Value alias);
/// Return true if, under current bufferization decisions, the buffer of
/// `value` is not writable.
bool aliasesNonWritableBuffer(Value value) const;
/// Return true if the buffer to which `operand` would bufferize is equivalent
/// to some buffer write.
bool aliasesInPlaceWrite(Value v) const;
/// Set the inPlace bufferization spec to true.
/// Merge result's and operand's aliasing sets and iterate to a fixed point.
void bufferizeInPlace(OpResult result, OpOperand &operand);
@ -63,23 +55,6 @@ public:
/// Set the inPlace bufferization spec to false.
void bufferizeOutOfPlace(OpResult result);
/// Return true if `value` has an ExtractSliceOp matching the given
/// InsertSliceOp in its reverse SSA use-def chain.
bool hasMatchingExtractSliceOp(Value value,
tensor::InsertSliceOp insertOp) const;
/// Return true if bufferizing `opOperand` inplace with `opResult` would
/// create a write to a non-writable buffer.
bool wouldCreateWriteToNonWritableBuffer(OpOperand &opOperand,
OpResult opResult) const;
/// Assume that result bufferizes in-place with one of the operation's
/// operands. Return true if it is possible to find an inplace write W that
/// creates a conflict.
bool
wouldCreateReadAfterWriteInterference(OpOperand &operand, OpResult result,
const DominanceInfo &domInfo) const;
/// Return true if `v1` and `v2` bufferize to equivalent buffers.
bool areEquivalentBufferizedValues(Value v1, Value v2) const {
// Return `false` if we have no information about `v1` or `v2`.
@ -91,14 +66,13 @@ public:
equivalentInfo.getLeaderValue(v2);
}
/// Return true if the source of an `insertSliceOp` bufferizes to an
/// equivalent ExtractSliceOp.
bool isSourceEquivalentToAMatchingInplaceExtractSliceOp(
tensor::InsertSliceOp insertSliceOp) const;
/// Apply `fun` to all the members of the equivalence class of `v`.
void applyOnEquivalenceClass(Value v, function_ref<void(Value)> fun) const;
/// Apply `fun` to all aliases of `v`.
void applyOnAliases(Value v, function_ref<void(Value)> fun) const;
// TODO: Move these out of BufferizationAliasInfo.
/// Return true if the value is known to bufferize to writable memory.
bool bufferizesToWritableMemory(Value v) const;
@ -128,22 +102,6 @@ private:
/// Check that aliasInfo for `v` exists and return a reference to it.
EquivalenceClassRangeType getAliases(Value v) const;
/// Return true if the (ExtractSliceOp, InsertSliceOp) pair match (i.e.
/// equivalent operand / result and same offset/sizes/strides specification).
///
/// This is one particular type of relationship between ops on tensors that
/// reduce to an equivalence on buffers. This should be generalized and
/// exposed as interfaces on the proper types.
bool areEquivalentExtractSliceOps(tensor::ExtractSliceOp st,
tensor::InsertSliceOp sti) const;
/// Given sets of uses and writes, return true if there is a RaW conflict
/// under the assumption that all given reads/writes alias the same buffer and
/// that all given writes bufferize inplace.
bool hasReadAfterWriteInterference(const DenseSet<OpOperand *> &usesRead,
const DenseSet<OpOperand *> &usesWrite,
const DominanceInfo &domInfo) const;
/// Set of tensors that are known to bufferize to writable memory.
llvm::DenseSet<Value> bufferizeToWritableMemory;

View File

@ -508,6 +508,24 @@ static BufferRelation bufferRelation(OpOperand &opOperand) {
// Bufferization-specific alias analysis.
//===----------------------------------------------------------------------===//
/// Return true if the (ExtractSliceOp, InsertSliceOp) pair match (i.e.
/// equivalent operand / result and same offset/sizes/strides specification).
///
/// This is one particular type of relationship between ops on tensors that
/// reduce to an equivalence on buffers. This should be generalized and
/// exposed as interfaces on the proper types.
static bool
areEquivalentExtractSliceOps(const BufferizationAliasInfo &aliasInfo,
ExtractSliceOp st, InsertSliceOp sti) {
if (!st || !sti)
return false;
if (!aliasInfo.areEquivalentBufferizedValues(st.source(), sti.dest()))
return false;
if (!sameOffsetsSizesAndStrides(st, sti, isEqualConstantIntOrValue))
return false;
return true;
}
/// Return true if opOperand has been decided to bufferize in-place.
static bool isInplaceMemoryWrite(OpOperand &opOperand) {
// Ops that do not bufferize to a memory write, cannot be write in-place.
@ -567,24 +585,27 @@ void BufferizationAliasInfo::insertNewBufferEquivalence(Value newValue,
/// Return true if, under current bufferization decisions, the buffer of `value`
/// is not writable.
bool BufferizationAliasInfo::aliasesNonWritableBuffer(Value value) const {
static bool aliasesNonWritableBuffer(Value value,
const BufferizationAliasInfo &aliasInfo) {
LDBG("----Start aliasesNonWritableBuffer\n");
for (Value v : getAliases(value)) {
bool foundNonWritableBuffer = false;
aliasInfo.applyOnAliases(value, [&](Value v) {
LDBG("-----------examine: " << printValueInfo(v) << '\n');
if (bufferizesToWritableMemory(v)) {
if (aliasInfo.bufferizesToWritableMemory(v)) {
LDBG("-----------Value is known to be writable -> skip: "
<< printValueInfo(v) << '\n');
continue;
return;
}
if (auto bbArg = v.dyn_cast<BlockArgument>()) {
if (getInPlace(bbArg) == InPlaceSpec::True) {
LDBG("-----------bbArg is writable -> skip: " << printValueInfo(bbArg)
<< '\n');
continue;
return;
}
LDBG("-----------notWritable bbArg\n");
return true;
foundNonWritableBuffer = true;
return;
}
auto bufferizableOp = dyn_cast<BufferizableOpInterface>(v.getDefiningOp());
@ -592,11 +613,15 @@ bool BufferizationAliasInfo::aliasesNonWritableBuffer(Value value) const {
// Unknown ops are treated conservatively: Assume that it is illegal to
// write to their OpResults in-place.
LDBG("-----------notWritable op\n");
return true;
foundNonWritableBuffer = true;
return;
}
}
LDBG("---->value is writable\n");
return false;
});
if (!foundNonWritableBuffer)
LDBG("---->value is writable\n");
return foundNonWritableBuffer;
}
bool BufferizationAliasInfo::bufferizesToWritableMemory(Value v) const {
@ -610,20 +635,26 @@ void BufferizationAliasInfo::setBufferizesToWritableMemory(Value v) {
/// Return true if the buffer to which `operand` would bufferize is equivalent
/// to some buffer write.
bool BufferizationAliasInfo::aliasesInPlaceWrite(Value value) const {
static bool aliasesInPlaceWrite(Value value,
const BufferizationAliasInfo &aliasInfo) {
LDBG("----Start aliasesInPlaceWrite\n");
LDBG("-------for : " << printValueInfo(value) << '\n');
for (Value v : getAliases(value)) {
bool foundInplaceWrite = false;
aliasInfo.applyOnAliases(value, [&](Value v) {
for (auto &use : v.getUses()) {
if (isInplaceMemoryWrite(use)) {
LDBG("-----------wants to bufferize to inPlace write: "
<< printOperationInfo(use.getOwner()) << '\n');
return true;
foundInplaceWrite = true;
return;
}
}
}
LDBG("----------->does not alias an inplace write\n");
return false;
});
if (!foundInplaceWrite)
LDBG("----------->does not alias an inplace write\n");
return foundInplaceWrite;
}
/// Set the inPlace bufferization spec to true.
@ -731,11 +762,11 @@ static Value findLastPrecedingWrite(Value value) {
/// Return true if `value` is originating from an ExtractSliceOp that matches
/// the given InsertSliceOp.
bool BufferizationAliasInfo::hasMatchingExtractSliceOp(
Value value, InsertSliceOp insertOp) const {
static bool hasMatchingExtractSliceOp(const BufferizationAliasInfo &aliasInfo,
Value value, InsertSliceOp insertOp) {
auto condition = [&](Value val) {
if (auto extractOp = val.getDefiningOp<ExtractSliceOp>())
if (areEquivalentExtractSliceOps(extractOp, insertOp))
if (areEquivalentExtractSliceOps(aliasInfo, extractOp, insertOp))
return true;
return false;
};
@ -766,10 +797,11 @@ static bool happensBefore(Operation *a, Operation *b,
/// A conflict is: According to SSA use-def chains, a read R is supposed to read
/// the result of a write W1. But because of bufferization decisions, R actually
/// reads another write W2.
bool BufferizationAliasInfo::hasReadAfterWriteInterference(
const DenseSet<OpOperand *> &usesRead,
const DenseSet<OpOperand *> &usesWrite,
const DominanceInfo &domInfo) const {
static bool
hasReadAfterWriteInterference(const DenseSet<OpOperand *> &usesRead,
const DenseSet<OpOperand *> &usesWrite,
const DominanceInfo &domInfo,
const BufferizationAliasInfo &aliasInfo) {
for (OpOperand *uRead : usesRead) {
Operation *readingOp = uRead->getOwner();
@ -850,7 +882,8 @@ bool BufferizationAliasInfo::hasReadAfterWriteInterference(
// TODO: Use insertSliceOp.getDestOpOperand etc. when available.
if (uRead == &insertSliceOp->getOpOperand(1) /*dest*/ &&
hasMatchingExtractSliceOp(uConflictingWrite->get(), insertSliceOp))
hasMatchingExtractSliceOp(aliasInfo, uConflictingWrite->get(),
insertSliceOp))
// Case 1: The main insight is that InsertSliceOp reads only part of
// the destination tensor. The overwritten area is not read. If
// uConflictingWrite writes into exactly the memory location that is
@ -867,7 +900,7 @@ bool BufferizationAliasInfo::hasReadAfterWriteInterference(
if (uRead == &insertSliceOp->getOpOperand(0) /*source*/ &&
uConflictingWrite == &insertSliceOp->getOpOperand(1) /*dest*/ &&
hasMatchingExtractSliceOp(uRead->get(), insertSliceOp))
hasMatchingExtractSliceOp(aliasInfo, uRead->get(), insertSliceOp))
// Case 2: The read of the source tensor and the write to the dest
// tensor via an InsertSliceOp is not a conflict if the read is
// reading exactly that part of an equivalent tensor that the
@ -910,8 +943,9 @@ bool BufferizationAliasInfo::hasReadAfterWriteInterference(
/// * However, adding an alias {%0, %t} would mean that the second
/// TransferWriteOp overwrites the first one. Therefore, the TransferReadOp
/// would no longer be reading the result of %1.
bool BufferizationAliasInfo::wouldCreateReadAfterWriteInterference(
OpOperand &operand, OpResult result, const DominanceInfo &domInfo) const {
bool wouldCreateReadAfterWriteInterference(
OpOperand &operand, OpResult result, const DominanceInfo &domInfo,
const BufferizationAliasInfo &aliasInfo) {
#ifndef NDEBUG
SmallVector<OpOperand *> opOperands = getAliasingOpOperand(result);
assert(llvm::find(opOperands, &operand) != opOperands.end() &&
@ -920,20 +954,22 @@ bool BufferizationAliasInfo::wouldCreateReadAfterWriteInterference(
// Helper function to iterate on aliases of `root` and capture the reads.
auto getAliasingReads = [&](DenseSet<OpOperand *> &res, Value root) {
for (Value alias : getAliases(root))
aliasInfo.applyOnAliases(root, [&](Value alias) {
for (auto &use : alias.getUses())
// Read to a value that aliases root.
if (bufferizesToMemoryRead(use))
res.insert(&use);
});
};
// Helper function to iterate on aliases of `root` and capture the writes.
auto getAliasingInplaceWrites = [&](DenseSet<OpOperand *> &res, Value root) {
for (Value alias : getAliases(root))
aliasInfo.applyOnAliases(root, [&](Value alias) {
for (auto &use : alias.getUses())
// Inplace write to a value that aliases root.
if (isInplaceMemoryWrite(use))
res.insert(&use);
});
};
// Collect reads and writes of all aliases of OpOperand and OpResult.
@ -945,13 +981,14 @@ bool BufferizationAliasInfo::wouldCreateReadAfterWriteInterference(
if (bufferizesToMemoryWrite(operand))
usesWrite.insert(&operand);
return hasReadAfterWriteInterference(usesRead, usesWrite, domInfo);
return hasReadAfterWriteInterference(usesRead, usesWrite, domInfo, aliasInfo);
}
/// Return true if bufferizing `opOperand` inplace with `opResult` would create
/// a write to a non-writable buffer.
bool BufferizationAliasInfo::wouldCreateWriteToNonWritableBuffer(
OpOperand &opOperand, OpResult opResult) const {
static bool
wouldCreateWriteToNonWritableBuffer(OpOperand &opOperand, OpResult opResult,
const BufferizationAliasInfo &aliasInfo) {
#ifndef NDEBUG
SmallVector<OpOperand *> opOperands = getAliasingOpOperand(opResult);
assert(llvm::find(opOperands, &opOperand) != opOperands.end() &&
@ -961,15 +998,15 @@ bool BufferizationAliasInfo::wouldCreateWriteToNonWritableBuffer(
// Certain buffers are not writeable:
// 1. A function bbArg that is not inplaceable or
// 2. A constant op.
assert(!aliasesNonWritableBuffer(opResult) &&
assert(!aliasesNonWritableBuffer(opResult, aliasInfo) &&
"expected that opResult does not alias non-writable buffer");
bool nonWritable = aliasesNonWritableBuffer(opOperand.get());
bool nonWritable = aliasesNonWritableBuffer(opOperand.get(), aliasInfo);
if (!nonWritable)
return false;
// This is a problem only if the buffer is written to via some alias.
bool hasWrite = aliasesInPlaceWrite(opResult) ||
aliasesInPlaceWrite(opOperand.get()) ||
bool hasWrite = aliasesInPlaceWrite(opResult, aliasInfo) ||
aliasesInPlaceWrite(opOperand.get(), aliasInfo) ||
bufferizesToMemoryWrite(opOperand);
if (!hasWrite)
return false;
@ -978,28 +1015,6 @@ bool BufferizationAliasInfo::wouldCreateWriteToNonWritableBuffer(
return true;
}
/// Return true if the source of a `insertSliceOp` bufferizes to an
/// equivalent ExtractSliceOp that bufferizes inplace.
bool BufferizationAliasInfo::isSourceEquivalentToAMatchingInplaceExtractSliceOp(
InsertSliceOp insertSliceOp) const {
LDBG("isSourceEquivalentToAMatchingInplaceExtractSliceOp: " << *insertSliceOp
<< '\n');
auto leaderIt = equivalentInfo.findLeader(insertSliceOp.source());
for (auto mit = leaderIt, meit = equivalentInfo.member_end(); mit != meit;
++mit) {
auto extractSliceOp =
dyn_cast_or_null<ExtractSliceOp>(mit->getDefiningOp());
if (extractSliceOp &&
areEquivalentExtractSliceOps(extractSliceOp, insertSliceOp) &&
getInPlace(extractSliceOp.result()) == InPlaceSpec::True) {
LDBG("\tfound: " << *mit->getDefiningOp() << '\n');
return true;
}
}
LDBG("\tnot equivalent\n");
return false;
}
/// Apply `fun` to all the members of the equivalence class of `v`.
void BufferizationAliasInfo::applyOnEquivalenceClass(
Value v, function_ref<void(Value)> fun) const {
@ -1010,6 +1025,15 @@ void BufferizationAliasInfo::applyOnEquivalenceClass(
}
}
/// Apply `fun` to all aliases of `v`.
void BufferizationAliasInfo::applyOnAliases(
Value v, function_ref<void(Value)> fun) const {
auto leaderIt = aliasInfo.findLeader(v);
for (auto mit = leaderIt, meit = aliasInfo.member_end(); mit != meit; ++mit) {
fun(*mit);
}
}
void BufferizationAliasInfo::printAliases(raw_ostream &os) const {
os << "\n/===================== AliasInfo =====================\n";
for (auto it = aliasInfo.begin(), eit = aliasInfo.end(); it != eit; ++it) {
@ -1066,20 +1090,6 @@ void BufferizationAliasInfo::dumpEquivalences() const {
printEquivalences(llvm::errs());
}
/// This is one particular type of relationship between ops on tensors that
/// reduce to an equivalence on buffers. This should be generalized and exposed
/// as interfaces on the proper types.
bool BufferizationAliasInfo::areEquivalentExtractSliceOps(
ExtractSliceOp st, InsertSliceOp sti) const {
if (!st || !sti)
return false;
if (!equivalentInfo.isEquivalent(st.source(), sti.dest()))
return false;
if (!sameOffsetsSizesAndStrides(st, sti, isEqualConstantIntOrValue))
return false;
return true;
}
//===----------------------------------------------------------------------===//
// Forward declarations.
//===----------------------------------------------------------------------===//
@ -1475,8 +1485,9 @@ bufferizableInPlaceAnalysisImpl(OpOperand &operand, OpResult result,
<< printValueInfo(result) << '\n');
bool foundInterference =
aliasInfo.wouldCreateWriteToNonWritableBuffer(operand, result) ||
aliasInfo.wouldCreateReadAfterWriteInterference(operand, result, domInfo);
wouldCreateWriteToNonWritableBuffer(operand, result, aliasInfo) ||
wouldCreateReadAfterWriteInterference(operand, result, domInfo,
aliasInfo);
if (foundInterference)
aliasInfo.bufferizeOutOfPlace(result);
@ -3276,6 +3287,30 @@ struct ExtractOpInterface
}
};
/// Return true if the source of a `insertSliceOp` bufferizes to an
/// equivalent ExtractSliceOp that bufferizes inplace.
static bool isSourceEquivalentToAMatchingInplaceExtractSliceOp(
const BufferizationAliasInfo &aliasInfo, InsertSliceOp insertSliceOp) {
LDBG("isSourceEquivalentToAMatchingInplaceExtractSliceOp: " << *insertSliceOp
<< '\n');
bool foundOp = false;
aliasInfo.applyOnEquivalenceClass(insertSliceOp.source(), [&](Value value) {
auto extractSliceOp = value.getDefiningOp<ExtractSliceOp>();
if (extractSliceOp &&
areEquivalentExtractSliceOps(aliasInfo, extractSliceOp,
insertSliceOp) &&
getInPlace(extractSliceOp.result()) == InPlaceSpec::True) {
LDBG("\tfound: " << extractSliceOp.getOperation() << '\n');
foundOp = true;
}
});
if (!foundOp)
LDBG("\tnot equivalent\n");
return foundOp;
}
struct InsertSliceOpInterface
: public BufferizableOpInterface::ExternalModel<InsertSliceOpInterface,
tensor::InsertSliceOp> {
@ -3345,8 +3380,8 @@ struct InsertSliceOpInterface
// cloned and the clone needs to be updated.
auto inPlace = getInPlace(insertSliceOp->getResult(0));
// TODO: Is this necessary?
if (!aliasInfo.isSourceEquivalentToAMatchingInplaceExtractSliceOp(
insertSliceOp) ||
if (!isSourceEquivalentToAMatchingInplaceExtractSliceOp(aliasInfo,
insertSliceOp) ||
inPlace != InPlaceSpec::True) {
LDBG("insert_slice needs extra source copy: " << insertSliceOp.source()
<< " -> copy\n");