forked from OSchip/llvm-project
[mlir][linalg][bufferize] Support custom insertion point for buffer copies
By default, copies are inserted right before the tensor OpOperand use. With this change, `bufferize` implementation can change the insertion point. This is needed for some ops where it would be illegal to insert a copy right before the use. Differential Revision: https://reviews.llvm.org/D117291
This commit is contained in:
parent
e58e401b79
commit
1eeffcdb7a
|
@ -386,8 +386,10 @@ public:
|
||||||
/// Return the buffer (memref) for a given OpOperand (tensor). Allocate
|
/// Return the buffer (memref) for a given OpOperand (tensor). Allocate
|
||||||
/// a new buffer and copy over data from the existing buffer if out-of-place
|
/// a new buffer and copy over data from the existing buffer if out-of-place
|
||||||
/// bufferization was decided.
|
/// bufferization was decided.
|
||||||
FailureOr<Value> getBuffer(RewriterBase &rewriter, OpOperand &opOperand,
|
FailureOr<Value>
|
||||||
bool forceInPlace = false) const;
|
getBuffer(RewriterBase &rewriter, OpOperand &opOperand,
|
||||||
|
bool forceInPlace = false,
|
||||||
|
Optional<Operation *> customCopyInsertionPoint = None) const;
|
||||||
|
|
||||||
/// Return dialect-specific bufferization state.
|
/// Return dialect-specific bufferization state.
|
||||||
template <typename StateT>
|
template <typename StateT>
|
||||||
|
|
|
@ -377,7 +377,8 @@ static Value lookupBuffer(RewriterBase &rewriter, Value tensor) {
|
||||||
/// bufferization is necessary.
|
/// bufferization is necessary.
|
||||||
FailureOr<Value>
|
FailureOr<Value>
|
||||||
mlir::linalg::comprehensive_bufferize::BufferizationState::getBuffer(
|
mlir::linalg::comprehensive_bufferize::BufferizationState::getBuffer(
|
||||||
RewriterBase &rewriter, OpOperand &opOperand, bool forceInPlace) const {
|
RewriterBase &rewriter, OpOperand &opOperand, bool forceInPlace,
|
||||||
|
Optional<Operation *> customCopyInsertionPoint) const {
|
||||||
OpBuilder::InsertionGuard guard(rewriter);
|
OpBuilder::InsertionGuard guard(rewriter);
|
||||||
Operation *op = opOperand.getOwner();
|
Operation *op = opOperand.getOwner();
|
||||||
Location loc = op->getLoc();
|
Location loc = op->getLoc();
|
||||||
|
@ -418,9 +419,14 @@ mlir::linalg::comprehensive_bufferize::BufferizationState::getBuffer(
|
||||||
if (bufferizesToMemoryWrite(opOperand) && !bufferizesToMemoryRead(opOperand))
|
if (bufferizesToMemoryWrite(opOperand) && !bufferizesToMemoryRead(opOperand))
|
||||||
return resultBuffer;
|
return resultBuffer;
|
||||||
|
|
||||||
|
if (customCopyInsertionPoint) {
|
||||||
|
rewriter.setInsertionPoint(*customCopyInsertionPoint);
|
||||||
|
} else {
|
||||||
// The copy happens right before the op that is bufferized.
|
// The copy happens right before the op that is bufferized.
|
||||||
rewriter.setInsertionPoint(op);
|
rewriter.setInsertionPoint(op);
|
||||||
|
}
|
||||||
createMemCpy(rewriter, loc, operandBuffer, *resultBuffer);
|
createMemCpy(rewriter, loc, operandBuffer, *resultBuffer);
|
||||||
|
|
||||||
return resultBuffer;
|
return resultBuffer;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue