[mlir][linalg][bufferize] Add mustBufferizeInPlace to op interface

This is useful for ops such as scf::IfOp, which always bufferize in-place.

This commit is in preparation of decoupling BufferizationAliasInfo from the SCF dialect.

Differential Revision: https://reviews.llvm.org/D113339
This commit is contained in:
Matthias Springer 2021-11-10 19:22:42 +09:00
parent e7861449ea
commit 8f6119128f
2 changed files with 69 additions and 14 deletions

View File

@ -89,6 +89,25 @@ def BufferizableOpInterface : OpInterface<"BufferizableOpInterface"> {
});
}]
>,
InterfaceMethod<
/*desc=*/[{
Return `true` if the given OpResult must bufferize in-place with its
corresponding aliasing OpOperand. Alias sets and inplace attributes
will be set up accordingly before making any other bufferization
decisions. This method will never be called on OpResults that do not
have a tensor type.
Note: This method may not return `true` if the given OpResult does not
have an aliasing OpOperand.
}],
/*retType=*/"bool",
/*methodName=*/"mustBufferizeInPlace",
/*args=*/(ins "OpResult":$opResult),
/*methodBody=*/"",
/*defaultImplementation=*/[{
return false;
}]
>,
InterfaceMethod<
/*desc=*/[{
Return the OpResult that aliases with a given OpOperand when

View File

@ -538,18 +538,20 @@ BufferizationAliasInfo::BufferizationAliasInfo(Operation *rootOp) {
createAliasInfoEntry(bbArg);
});
// The return value of an scf::IfOp aliases with both yield values.
rootOp->walk([&](scf::IfOp ifOp) {
if (ifOp->getNumResults() > 0) {
for (auto it : llvm::zip(ifOp.thenYield().results(),
ifOp.elseYield().results(), ifOp.results())) {
aliasInfo.unionSets(std::get<0>(it), std::get<1>(it));
aliasInfo.unionSets(std::get<0>(it), std::get<2>(it));
}
// scf::IfOp always bufferizes in-place.
for (OpResult opResult : ifOp->getResults())
setInPlaceOpResult(opResult, InPlaceSpec::True);
// Set up alias sets for OpResults that must bufferize in-place. This should
// be done before making any other bufferization decisions.
rootOp->walk([&](BufferizableOpInterface bufferizableOp) {
for (OpResult opResult : bufferizableOp->getOpResults()) {
if (opResult.getType().isa<TensorType>())
if (bufferizableOp.mustBufferizeInPlace(opResult)) {
SmallVector<OpOperand *> operands =
bufferizableOp.getAliasingOpOperand(opResult);
assert(!operands.empty() &&
"expected that OpResult has aliasing OpOperand");
for (OpOperand *operand : operands)
aliasInfo.unionSets(operand->get(), opResult);
setInPlaceOpResult(opResult, InPlaceSpec::True);
}
}
});
}
@ -951,9 +953,14 @@ hasReadAfterWriteInterference(const DenseSet<OpOperand *> &usesRead,
/// * However, adding an alias {%0, %t} would mean that the second
/// TransferWriteOp overwrites the first one. Therefore, the TransferReadOp
/// would no longer be reading the result of %1.
///
/// If `checkConsistencyOnly` is true, this function checks if there is a
/// read-after-write conflict without bufferizing `operand` inplace. This would
/// indicate a problem with the current inplace bufferization decisions.
bool wouldCreateReadAfterWriteInterference(
OpOperand &operand, OpResult result, const DominanceInfo &domInfo,
const BufferizationAliasInfo &aliasInfo) {
const BufferizationAliasInfo &aliasInfo,
bool checkConsistencyOnly = false) {
#ifndef NDEBUG
SmallVector<OpOperand *> opOperands = getAliasingOpOperand(result);
assert(llvm::find(opOperands, &operand) != opOperands.end() &&
@ -986,7 +993,7 @@ bool wouldCreateReadAfterWriteInterference(
getAliasingReads(usesRead, result);
getAliasingInplaceWrites(usesWrite, operand.get());
getAliasingInplaceWrites(usesWrite, result);
if (bufferizesToMemoryWrite(operand))
if (!checkConsistencyOnly && bufferizesToMemoryWrite(operand))
usesWrite.insert(&operand);
return hasReadAfterWriteInterference(usesRead, usesWrite, domInfo, aliasInfo);
@ -2229,6 +2236,24 @@ LogicalResult mlir::linalg::eliminateInsertSliceAnchoredInitTensorOps(
});
}
/// Assert that the current bufferization decisions are consistent.
static void checkAliasInfoConsistency(FuncOp funcOp,
const DominanceInfo &domInfo,
const BufferizationAliasInfo &aliasInfo) {
funcOp.walk([&](Operation *op) {
if (auto bufferizableOp = dyn_cast<BufferizableOpInterface>(op))
for (OpOperand &opOperand : op->getOpOperands())
if (opOperand.get().getType().isa<TensorType>())
if (OpResult opResult = bufferizableOp.getAliasingOpResult(opOperand))
// If this assertion fails, there is probably an inconsistent
// combination of "mustBufferizeInPlace" decisions.
assert(!wouldCreateReadAfterWriteInterference(
opOperand, opResult, domInfo, aliasInfo,
/*checkConsistencyOnly=*/true) &&
"found read after write conflict before running analysis");
});
}
LogicalResult
mlir::linalg::runComprehensiveBufferize(ModuleOp moduleOp,
const BufferizationOptions &options) {
@ -2240,6 +2265,7 @@ mlir::linalg::runComprehensiveBufferize(ModuleOp moduleOp,
DominanceInfo domInfo(moduleOp);
BufferizationAliasInfo aliasInfo(moduleOp);
// Interestingly, all function args that are not visible outside of a module
// can be fully bufferized inplace by guaranteeing the CallOp is bufferized
// inplace. Therefore, we just bufferize funcOp as if none of its results were
@ -2260,6 +2286,10 @@ mlir::linalg::runComprehensiveBufferize(ModuleOp moduleOp,
if (bbArg.getType().isa<TensorType>())
setInPlaceFuncArgument(bbArg);
#ifndef NDEBUG
checkAliasInfoConsistency(funcOp, domInfo, aliasInfo);
#endif // NDEBUG
// If the analysis fails, just return.
if (failed(inPlaceAnalysisFuncOpBody(funcOp, aliasInfo, domInfo,
options.analysisFuzzerSeed)))
@ -2778,6 +2808,12 @@ struct IfOpInterface
return true;
}
bool mustBufferizeInPlace(Operation *op, OpResult opResult) const {
// IfOp results always bufferize in-place. Since they have no OpOperands,
// they are mostly ignored by the analysis once alias sets are set up.
return true;
}
LogicalResult bufferize(Operation *op, OpBuilder &b,
BlockAndValueMapping &bvm,
BufferizationAliasInfo &aliasInfo,