[mlir][linalg][bufferize] Remove special scf::IfOp rules

Remove some of the special rules for scf::IfOp (not all of them) and encode them in the op interface. This is in preparation of decoupling analysis, bufferization and dialects.

Differential Revision: https://reviews.llvm.org/D112901
This commit is contained in:
Matthias Springer 2021-11-10 18:39:16 +09:00
parent 007e55133e
commit be98b20b9d
1 changed files with 46 additions and 7 deletions

View File

@ -546,6 +546,10 @@ BufferizationAliasInfo::BufferizationAliasInfo(Operation *rootOp) {
aliasInfo.unionSets(std::get<0>(it), std::get<1>(it));
aliasInfo.unionSets(std::get<0>(it), std::get<2>(it));
}
// scf::IfOp always bufferizes in-place.
for (OpResult opResult : ifOp->getResults())
setInPlaceOpResult(opResult, InPlaceSpec::True);
}
});
}
@ -732,10 +736,9 @@ static Value findLastPrecedingWrite(Value value) {
auto bufferizableOp = dyn_cast<BufferizableOpInterface>(op);
if (!bufferizableOp)
return true;
if (isa<scf::IfOp>(op))
return true;
return bufferizableOp.isMemoryWrite(value.cast<OpResult>());
});
assert(result.size() == 1 && "expected exactly one result");
return result.front();
}
@ -1344,10 +1347,10 @@ static Value getResultBuffer(OpBuilder &b, OpResult result,
}
// If bufferizing out-of-place, allocate a new buffer.
bool needCopy =
getInPlace(result) != InPlaceSpec::True && !isa<scf::IfOp>(op);
bool needCopy = getInPlace(result) != InPlaceSpec::True;
if (needCopy) {
// Ops such as scf::IfOp can currently not bufferize out-of-place.
// Ops with multiple aliasing operands can currently not bufferize
// out-of-place.
assert(
aliasingOperands.size() == 1 &&
"ops with multiple aliasing OpOperands cannot bufferize out-of-place");
@ -2750,15 +2753,47 @@ struct IfOpInterface
: public BufferizableOpInterface::ExternalModel<IfOpInterface, scf::IfOp> {
SmallVector<OpOperand *> getAliasingOpOperand(Operation *op,
OpResult opResult) const {
// IfOps do not have tensor OpOperands. The yielded value can be any SSA
// value that is in scope. To allow for use-def chain traversal through
// IfOps in the analysis, both corresponding yield values from the then/else
// branches are considered to be aliasing with the result.
auto ifOp = cast<scf::IfOp>(op);
// Either one of the corresponding yield values from the then/else branches
// may alias with the result.
size_t resultNum = std::distance(op->getOpResults().begin(),
llvm::find(op->getOpResults(), opResult));
return {&ifOp.thenYield()->getOpOperand(resultNum),
&ifOp.elseYield()->getOpOperand(resultNum)};
}
// TODO: For better bufferization results, this could return `true` only if
// there is a memory write in one (or both) of the branches. Since this is not
// allowed at the moment, we should never encounter scf.ifs that yield
// unmodified tensors. Such scf.yield ops could just fold away.
bool isMemoryWrite(Operation *op, OpResult opResult) const {
// IfOp results are always considered memory writes in the analysis. This
// design decision simplifies the analysis considerably. E.g., consider the
// following test case:
//
// %0 = "some_writing_op" : tensor<?xf32>
// %r = scf.if %c -> (tensor<?xf32>) {
// scf.yield %0
// } else {
// %1 = "another_writing_op"(%0) : tensor<?xf32>
// }
// "some_reading_op"(%r)
//
// "another_writing_op" in the above example should be able to bufferize
// inplace in the absence of another read of %0. However, if the scf.if op
// would not be considered a "write", the analysis would detect the
// following conflict:
//
// * read = some_reading_op
// * lastWrite = %0 (Note: The last write of %r would be a set: {%0, %1}.)
// * conflictingWrite = %1
//
// For more details, check the "scf.IfOp" section of the design document.
return true;
}
LogicalResult bufferize(Operation *op, OpBuilder &b,
BlockAndValueMapping &bvm,
BufferizationAliasInfo &aliasInfo,
@ -2873,6 +2908,10 @@ struct YieldOpInterface
return OpResult();
}
BufferRelation bufferRelation(Operation *op, OpOperand &opOperand) const {
return BufferRelation::Equivalent;
}
LogicalResult bufferize(Operation *op, OpBuilder &b,
BlockAndValueMapping &bvm,
BufferizationAliasInfo &aliasInfo,