forked from OSchip/llvm-project
[mlir][linalg][bufferize] Support custom insertion point for buffer copies
By default, copies are inserted right before the tensor OpOperand use. With this change, `bufferize` implementation can change the insertion point. This is needed for some ops where it would be illegal to insert a copy right before the use. Differential Revision: https://reviews.llvm.org/D117291
This commit is contained in:
parent
e58e401b79
commit
1eeffcdb7a
|
@ -386,8 +386,10 @@ public:
|
|||
/// Return the buffer (memref) for a given OpOperand (tensor). Allocate
|
||||
/// a new buffer and copy over data from the existing buffer if out-of-place
|
||||
/// bufferization was decided.
|
||||
FailureOr<Value> getBuffer(RewriterBase &rewriter, OpOperand &opOperand,
|
||||
bool forceInPlace = false) const;
|
||||
FailureOr<Value>
|
||||
getBuffer(RewriterBase &rewriter, OpOperand &opOperand,
|
||||
bool forceInPlace = false,
|
||||
Optional<Operation *> customCopyInsertionPoint = None) const;
|
||||
|
||||
/// Return dialect-specific bufferization state.
|
||||
template <typename StateT>
|
||||
|
|
|
@ -377,7 +377,8 @@ static Value lookupBuffer(RewriterBase &rewriter, Value tensor) {
|
|||
/// bufferization is necessary.
|
||||
FailureOr<Value>
|
||||
mlir::linalg::comprehensive_bufferize::BufferizationState::getBuffer(
|
||||
RewriterBase &rewriter, OpOperand &opOperand, bool forceInPlace) const {
|
||||
RewriterBase &rewriter, OpOperand &opOperand, bool forceInPlace,
|
||||
Optional<Operation *> customCopyInsertionPoint) const {
|
||||
OpBuilder::InsertionGuard guard(rewriter);
|
||||
Operation *op = opOperand.getOwner();
|
||||
Location loc = op->getLoc();
|
||||
|
@ -418,9 +419,14 @@ mlir::linalg::comprehensive_bufferize::BufferizationState::getBuffer(
|
|||
if (bufferizesToMemoryWrite(opOperand) && !bufferizesToMemoryRead(opOperand))
|
||||
return resultBuffer;
|
||||
|
||||
// The copy happens right before the op that is bufferized.
|
||||
rewriter.setInsertionPoint(op);
|
||||
if (customCopyInsertionPoint) {
|
||||
rewriter.setInsertionPoint(*customCopyInsertionPoint);
|
||||
} else {
|
||||
// The copy happens right before the op that is bufferized.
|
||||
rewriter.setInsertionPoint(op);
|
||||
}
|
||||
createMemCpy(rewriter, loc, operandBuffer, *resultBuffer);
|
||||
|
||||
return resultBuffer;
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue