[mlir][linalg][bufferize] Support custom insertion point for buffer copies

By default, copies are inserted right before the tensor OpOperand use. With this change, `bufferize` implementation can change the insertion point. This is needed for some ops where it would be illegal to insert a copy right before the use.

Differential Revision: https://reviews.llvm.org/D117291
This commit is contained in:
Matthias Springer 2022-01-14 22:46:52 +09:00
parent e58e401b79
commit 1eeffcdb7a
2 changed files with 13 additions and 5 deletions

View File

@ -386,8 +386,10 @@ public:
/// Return the buffer (memref) for a given OpOperand (tensor). Allocate
/// a new buffer and copy over data from the existing buffer if out-of-place
/// bufferization was decided.
FailureOr<Value> getBuffer(RewriterBase &rewriter, OpOperand &opOperand,
bool forceInPlace = false) const;
FailureOr<Value>
getBuffer(RewriterBase &rewriter, OpOperand &opOperand,
bool forceInPlace = false,
Optional<Operation *> customCopyInsertionPoint = None) const;
/// Return dialect-specific bufferization state.
template <typename StateT>

View File

@ -377,7 +377,8 @@ static Value lookupBuffer(RewriterBase &rewriter, Value tensor) {
/// bufferization is necessary.
FailureOr<Value>
mlir::linalg::comprehensive_bufferize::BufferizationState::getBuffer(
RewriterBase &rewriter, OpOperand &opOperand, bool forceInPlace) const {
RewriterBase &rewriter, OpOperand &opOperand, bool forceInPlace,
Optional<Operation *> customCopyInsertionPoint) const {
OpBuilder::InsertionGuard guard(rewriter);
Operation *op = opOperand.getOwner();
Location loc = op->getLoc();
@ -418,9 +419,14 @@ mlir::linalg::comprehensive_bufferize::BufferizationState::getBuffer(
if (bufferizesToMemoryWrite(opOperand) && !bufferizesToMemoryRead(opOperand))
return resultBuffer;
if (customCopyInsertionPoint) {
rewriter.setInsertionPoint(*customCopyInsertionPoint);
} else {
// The copy happens right before the op that is bufferized.
rewriter.setInsertionPoint(op);
}
createMemCpy(rewriter, loc, operandBuffer, *resultBuffer);
return resultBuffer;
}