[mlir][linalg][bufferize][NFC] Simplify buffer API of BufferizationState
Instead of `lookupBuffer` and `getResultBuffer`, there is now a single `getBuffer` function. This simplifies the `BufferizableOpInterface` API and is less confusing to users, who could previously call the wrong function. Furthermore, since `getBuffer` takes an `OpOperand &` instead of a `Value`, callers can no longer accidentally take the `lookupBuffer` shortcut on an operand that needs a copy, a mistake that previously resulted in missing buffer copies.

Differential Revision: https://reviews.llvm.org/D116455
commit d9184ab1a5
parent 8e2b6aac32
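For downstream callers the migration is mechanical. A minimal before/after sketch; the `op`, `state`, and `rewriter` names are illustrative placeholders, not code from this patch:

    // Before: two entry points; calling lookupBuffer on an operand that is
    // written to silently skipped the out-of-place copy.
    Value readOnly = state.lookupBuffer(rewriter, op->getOperand(0));
    FailureOr<Value> written =
        state.getResultBuffer(rewriter, op->getOpResult(0));

    // After: one entry point, keyed on the OpOperand. A read-only operand
    // passes forceInPlace=true; a potentially written operand goes through
    // the copy-on-write logic automatically.
    FailureOr<Value> readOnlyBuffer =
        state.getBuffer(rewriter, op->getOpOperand(0), /*forceInPlace=*/true);
    FailureOr<Value> writtenBuffer =
        state.getBuffer(rewriter, op->getOpOperand(1));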
@@ -377,18 +377,14 @@ public:
   /// Creates a memcpy between two given buffers.
   void createMemCpy(OpBuilder &b, Location loc, Value from, Value to) const;
 
-  /// Lookup the memref buffer that is associated to the given tensor value.
-  /// Asserts if no buffer is associated.
-  Value lookupBuffer(RewriterBase &rewriter, Value tensor) const;
-
   /// Return `true` if the given OpResult has been decided to bufferize inplace.
   bool isInPlace(OpOperand &opOperand) const;
 
-  /// Return the result buffer (memref) for a given OpResult (tensor). Allocate
+  /// Return the buffer (memref) for a given OpOperand (tensor). Allocate
   /// a new buffer and copy over data from the existing buffer if out-of-place
-  /// bufferization is necessary.
-  FailureOr<Value> getResultBuffer(RewriterBase &rewriter,
-                                   OpResult result) const;
+  /// bufferization was decided.
+  FailureOr<Value> getBuffer(RewriterBase &rewriter, OpOperand &opOperand,
+                             bool forceInPlace = false) const;
 
   /// Return dialect-specific bufferization state.
   template <typename StateT>
@@ -347,74 +347,73 @@ mlir::linalg::comprehensive_bufferize::BufferizationState::BufferizationState(
   });
 }
 
+static Value lookupBuffer(RewriterBase &rewriter, Value tensor) {
+  assert(tensor.getType().isa<TensorType>() && "unexpected non-tensor type");
+
+  // Replace "%t = to_tensor %m" with %m.
+  if (auto toTensorOp = tensor.getDefiningOp<bufferization::ToTensorOp>())
+    return toTensorOp.memref();
+
+  // Insert to_memref op.
+  OpBuilder::InsertionGuard g(rewriter);
+  setInsertionPointAfter(rewriter, tensor);
+  Type memrefType;
+  if (auto rankedTensorType = tensor.getType().dyn_cast<RankedTensorType>()) {
+    memrefType = getDynamicMemRefType(rankedTensorType);
+  } else {
+    memrefType = getUnrankedMemRefType(
+        tensor.getType().cast<TensorType>().getElementType());
+  }
+  return rewriter.create<bufferization::ToMemrefOp>(tensor.getLoc(), memrefType,
+                                                    tensor);
+}
+
 /// Return the result buffer (memref) for a given OpResult (tensor). Allocate
 /// a new buffer and copy over data from the existing buffer if out-of-place
 /// bufferization is necessary.
 FailureOr<Value>
-mlir::linalg::comprehensive_bufferize::BufferizationState::getResultBuffer(
-    RewriterBase &rewriter, OpResult result) const {
+mlir::linalg::comprehensive_bufferize::BufferizationState::getBuffer(
+    RewriterBase &rewriter, OpOperand &opOperand, bool forceInPlace) const {
   OpBuilder::InsertionGuard guard(rewriter);
-  Operation *op = result.getOwner();
-  SmallVector<OpOperand *> aliasingOperands = getAliasingOpOperand(result);
-  assert(!aliasingOperands.empty() && "could not get aliasing OpOperand");
-  OpOperand *opOperand = aliasingOperands.front();
-  Value operand = opOperand->get();
+  Operation *op = opOperand.getOwner();
+  Location loc = op->getLoc();
+  Value operand = opOperand.get();
   Value operandBuffer = lookupBuffer(rewriter, operand);
-  // Make sure that all OpOperands are the same buffer. If this is not the case,
-  // we would have to materialize a memref value.
-  // TODO: Should be looking for checking for "equivalent buffers" instead of
-  // operator== here, but equivalent buffers for scf.if yield values are not
-  // set up yet.
-  if (aliasingOperands.size() > 1 &&
-      !llvm::all_of(aliasingOperands, [&](OpOperand *o) {
-        return lookupBuffer(rewriter, o->get()) == operandBuffer;
-      }))
-    return FailureOr<Value>(op->emitError("result buffer is ambiguous"));
-
-  // If bufferizing out-of-place, allocate a new buffer.
-  if (!aliasInfo.isInPlace(*opOperand)) {
-    // Ops with multiple aliasing operands can currently not bufferize
-    // out-of-place.
-    assert(
-        aliasingOperands.size() == 1 &&
-        "ops with multiple aliasing OpOperands cannot bufferize out-of-place");
-    Location loc = op->getLoc();
-    // Move insertion point right after `operandBuffer`. That is where the
-    // allocation should be inserted (in the absence of allocation hoisting).
-    setInsertionPointAfter(rewriter, operandBuffer);
-    // Allocate the result buffer.
-    FailureOr<Value> resultBuffer =
-        createAlloc(rewriter, loc, operandBuffer, options.createDeallocs);
-    if (failed(resultBuffer))
-      return failure();
-    bool skipCopy = false;
-    // Do not copy if the last preceding write of `operand` is an op that does
-    // not write (skipping ops that merely create aliases). E.g., InitTensorOp.
-    // Note: If `findLastPrecedingWrite` reaches the end of the reverse SSA
-    // use-def chain, it returns that value, regardless of whether it is a
-    // memory write or not.
-    Value lastWrite = findLastPrecedingWrite(operand);
-    if (auto bufferizableOp = options.dynCastBufferizableOp(lastWrite))
-      if (!bufferizableOp.isMemoryWrite(lastWrite.cast<OpResult>(), *this))
-        skipCopy = true;
-    // Do not copy if the copied data is never read. (Neither by this op nor by
-    // any following op.)
-    if (!bufferizesToMemoryRead(*opOperand) && !isValueRead(result))
-      skipCopy = true;
-    // Do not copy if this op does not read the data, but writes it.
-    if (bufferizesToMemoryWrite(*opOperand) &&
-        !bufferizesToMemoryRead(*opOperand))
-      skipCopy = true;
-    if (!skipCopy) {
-      // The copy happens right before the op that is bufferized.
-      rewriter.setInsertionPoint(op);
-      createMemCpy(rewriter, loc, operandBuffer, *resultBuffer);
-    }
-    return resultBuffer;
-  }
-
-  // Bufferizing in-place. No need to allocate a new buffer.
-  return operandBuffer;
+
+  if (forceInPlace || aliasInfo.isInPlace(opOperand))
+    return operandBuffer;
+
+  // Bufferizing out-of-place: Allocate a new buffer.
+  // Move insertion point right after `operandBuffer`. That is where the
+  // allocation should be inserted (in the absence of allocation hoisting).
+  setInsertionPointAfter(rewriter, operandBuffer);
+  // Allocate the result buffer.
+  FailureOr<Value> resultBuffer =
+      createAlloc(rewriter, loc, operandBuffer, options.createDeallocs);
+  if (failed(resultBuffer))
+    return failure();
+  // Do not copy if the last preceding write of `operand` is an op that does
+  // not write (skipping ops that merely create aliases). E.g., InitTensorOp.
+  // Note: If `findLastPrecedingWrite` reaches the end of the reverse SSA
+  // use-def chain, it returns that value, regardless of whether it is a
+  // memory write or not.
+  Value lastWrite = findLastPrecedingWrite(operand);
+  if (auto bufferizableOp = options.dynCastBufferizableOp(lastWrite))
+    if (!bufferizableOp.isMemoryWrite(lastWrite.cast<OpResult>(), *this))
+      return resultBuffer;
+  // Do not copy if the copied data is never read.
+  OpResult aliasingOpResult = getAliasingOpResult(opOperand);
+  if (aliasingOpResult && !bufferizesToMemoryRead(opOperand) &&
+      !isValueRead(aliasingOpResult))
+    return resultBuffer;
+  // Do not copy if this op does not read the data, but writes it.
+  if (bufferizesToMemoryWrite(opOperand) && !bufferizesToMemoryRead(opOperand))
+    return resultBuffer;
+
+  // The copy happens right before the op that is bufferized.
+  rewriter.setInsertionPoint(op);
+  createMemCpy(rewriter, loc, operandBuffer, *resultBuffer);
+
+  return resultBuffer;
 }
 
 void mlir::linalg::comprehensive_bufferize::replaceOpWithBufferizedValues(
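The old skipCopy flag becomes three early `return resultBuffer;` exits, and the in-place case now returns first, so the rewritten function reads top-down: in-place (or forced) reuse, then allocation, then the copy-elision checks, then the copy. The caller-side contract, sketched at a hypothetical call site:

    // Fetch the buffer for an operand that may be written to.
    FailureOr<Value> buffer = state.getBuffer(rewriter, opOperand);
    if (failed(buffer))
      return failure(); // allocating an out-of-place buffer can fail
    // *buffer is the operand's existing memref when bufferizing in-place;
    // otherwise it is a fresh allocation, initialized with a memcpy only if
    // the previous contents could still be read.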
@@ -593,28 +592,6 @@ bool mlir::linalg::comprehensive_bufferize::isFunctionArgument(Value value) {
   return isa<FuncOp>(bbArg.getOwner()->getParentOp());
 }
 
-Value mlir::linalg::comprehensive_bufferize::BufferizationState::lookupBuffer(
-    RewriterBase &rewriter, Value tensor) const {
-  assert(tensor.getType().isa<TensorType>() && "unexpected non-tensor type");
-
-  // Replace "%t = to_tensor %m" with %m.
-  if (auto toTensorOp = tensor.getDefiningOp<bufferization::ToTensorOp>())
-    return toTensorOp.memref();
-
-  // Insert to_memref op.
-  OpBuilder::InsertionGuard g(rewriter);
-  setInsertionPointAfter(rewriter, tensor);
-  Type memrefType;
-  if (auto rankedTensorType = tensor.getType().dyn_cast<RankedTensorType>()) {
-    memrefType = getDynamicMemRefType(rankedTensorType);
-  } else {
-    memrefType = getUnrankedMemRefType(
-        tensor.getType().cast<TensorType>().getElementType());
-  }
-  return rewriter.create<bufferization::ToMemrefOp>(tensor.getLoc(), memrefType,
-                                                    tensor);
-}
-
 bool mlir::linalg::comprehensive_bufferize::BufferizationState::isInPlace(
     OpOperand &opOperand) const {
   return aliasInfo.isInPlace(opOperand);
@@ -46,15 +46,19 @@ static LogicalResult bufferizeLinalgOp(RewriterBase &rewriter, LinalgOp op,
       newInputBuffers.push_back(opOperand->get());
       continue;
     }
-    newInputBuffers.push_back(state.lookupBuffer(rewriter, opOperand->get()));
+    // Input operands are never written to.
+    newInputBuffers.push_back(
+        *state.getBuffer(rewriter, *opOperand, /*forceInPlace=*/true));
   }
 
   // New output operands for the cloned op.
   SmallVector<Value> newOutputBuffers;
-  for (OpOperand *opOperand : op.getOutputOperands()) {
-    OpResult opResult = op.getTiedOpResult(opOperand);
-    assert(opResult && "could not find correspond OpResult");
-    FailureOr<Value> resultBuffer = state.getResultBuffer(rewriter, opResult);
+  for (OpResult opResult : op->getOpResults()) {
+    SmallVector<OpOperand *> aliasingOpOperands =
+        state.getAliasingOpOperand(opResult);
+    assert(aliasingOpOperands.size() == 1 && "expected 1 OpOperand");
+    FailureOr<Value> resultBuffer =
+        state.getBuffer(rewriter, *aliasingOpOperands.front());
     if (failed(resultBuffer))
       return failure();
     newOutputBuffers.push_back(*resultBuffer);
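Note on the asymmetry above: linalg input operands are never written, so their buffers are taken directly with forceInPlace=true, while each output buffer is obtained through the unique OpOperand tied to the corresponding OpResult. A sketch of that equivalence, using only calls visible in this patch (the index 0 is illustrative):

    // Each tensor result of a linalg op aliases exactly one output operand,
    // which is what the assert above encodes.
    OpResult opResult = op->getOpResult(0);
    OpOperand *tiedOperand = state.getAliasingOpOperand(opResult).front();
    // state.getBuffer(rewriter, *tiedOperand) is then a one-for-one
    // replacement for the old state.getResultBuffer(rewriter, opResult).
    FailureOr<Value> buffer = state.getBuffer(rewriter, *tiedOperand);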
@@ -284,24 +288,23 @@ struct TiledLoopOpInterface
 
     // Compute new inputs, outputs and results.
    SmallVector<Value> newInputs, newOutputs, newResults;
-    for (Value value : tiledLoopOp.inputs()) {
-      if (value.getType().isa<TensorType>()) {
-        newInputs.push_back(state.lookupBuffer(rewriter, value));
-      } else {
-        newInputs.push_back(value);
-      }
-    }
-    int nextResultNum = 0;
-    for (Value value : tiledLoopOp.outputs()) {
-      if (value.getType().isa<TensorType>()) {
-        FailureOr<Value> buffer = state.getResultBuffer(
-            rewriter, tiledLoopOp->getResult(nextResultNum++));
-        if (failed(buffer))
+    for (int i = tiledLoopOp.getNumControlOperands();
+         i < tiledLoopOp->getNumOperands(); ++i) {
+      OpOperand &operand = tiledLoopOp->getOpOperand(i);
+      Value rewrittenValue = operand.get();
+      if (rewrittenValue.getType().isa<TensorType>()) {
+        FailureOr<Value> bufferOrFailure = state.getBuffer(rewriter, operand);
+        if (failed(bufferOrFailure))
           return failure();
-        newOutputs.push_back(*buffer);
-        newResults.push_back(*buffer);
+        rewrittenValue = *bufferOrFailure;
+      }
+      if (i <
+          tiledLoopOp.getNumControlOperands() + tiledLoopOp.getNumInputs()) {
+        newInputs.push_back(rewrittenValue);
       } else {
-        newOutputs.push_back(value);
+        newOutputs.push_back(rewrittenValue);
+        if (operand.get().getType().isa<TensorType>())
+          newResults.push_back(rewrittenValue);
       }
     }
 
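The single loop replaces the separate input and output walks and leans on `linalg.tiled_loop`'s operand order: control operands (bounds and steps) first, then inputs, then outputs. The two `if` conditions encode this index arithmetic; a sketch, assuming that layout:

    unsigned firstInput = tiledLoopOp.getNumControlOperands();
    unsigned firstOutput = firstInput + tiledLoopOp.getNumInputs();
    // i in [firstInput, firstOutput)       -> newInputs
    // i in [firstOutput, getNumOperands()) -> newOutputs, plus newResults
    //                                         for tensor-typed operands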
@@ -351,7 +351,7 @@ static LogicalResult bufferizeFuncOpBoundary(FuncOp funcOp,
 
     // Cast values at the call site if necessary.
     returnValues.push_back(
-        getNonCastedValue(state.lookupBuffer(rewriter, returnVal)));
+        getNonCastedValue(*state.getBuffer(rewriter, returnOperand)));
   }
 
   // 2. Rewrite the terminator without the inPlace bufferizable values.
@@ -659,7 +659,8 @@ struct CallOpInterface
           // Return operands that are equivalent to some bbArg, are not
           // returned.
           Value buffer =
-              state.lookupBuffer(rewriter, callOp->getOperand(*bbArgIdx));
+              *state.getBuffer(rewriter, callOp->getOpOperand(*bbArgIdx),
+                               /*forceInPlace=*/true);
           replacementValues[returnValIdx] = buffer;
           newOperands[*bbArgIdx] = buffer;
           continue;
@@ -690,9 +691,9 @@ struct CallOpInterface
       // Retrieve buffers for tensor operands. Tensor operand buffers, who's
       // corresponding FuncOp bbArgs are equivalent to a returned tensor, were
      // already stored in `newOperands` during Step 1.
-      Value buffer = newOperands[idx]
-                         ? newOperands[idx]
-                         : state.lookupBuffer(rewriter, tensorOperand);
+      Value buffer = newOperands[idx] ? newOperands[idx]
+                                      : *state.getBuffer(rewriter, opOperand,
+                                                         /*forceInPlace=*/true);
 
       // Caller / callee type mistmatch is handled with a CastOp.
       auto memRefType = bufferizedFuncType.getInput(idx);
@@ -280,19 +280,17 @@ struct ForOpInterface
     };
 
     // Construct a new scf.for op with memref instead of tensor values.
-    bool resultBufferFailure = false;
-    SmallVector<Value> initArgs =
-        convert(forOp.getInitArgs(), [&](Value val, int64_t index) {
-          FailureOr<Value> resultBuffer =
-              state.getResultBuffer(rewriter, forOp->getOpResult(index));
-          if (failed(resultBuffer)) {
-            resultBufferFailure = true;
-            return Value();
-          }
-          return *resultBuffer;
-        });
-    if (resultBufferFailure)
-      return failure();
+    SmallVector<Value> initArgs;
+    for (OpOperand &opOperand : forOp.getIterOpOperands()) {
+      if (opOperand.get().getType().isa<TensorType>()) {
+        FailureOr<Value> resultBuffer = state.getBuffer(rewriter, opOperand);
+        if (failed(resultBuffer))
+          return failure();
+        initArgs.push_back(*resultBuffer);
+      } else {
+        initArgs.push_back(opOperand.get());
+      }
+    }
     auto newForOp = rewriter.create<scf::ForOp>(
         forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
         forOp.getStep(), initArgs);
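On `scf.for`, the i-th iter operand is tied to the i-th result, which is why keying the lookup on the OpOperand selects the same buffer the old OpResult-keyed code did; the rewritten loop also passes non-tensor iter_args through unchanged. Sketch of the correspondence (the index `i` is illustrative):

    // The i-th iter_arg and the i-th loop result alias the same tensor:
    OpOperand &iterOperand = forOp.getIterOpOperands()[i];
    OpResult tiedResult = forOp->getOpResult(i);
    FailureOr<Value> buffer = state.getBuffer(rewriter, iterOperand);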
@@ -53,7 +53,7 @@ struct CastOpInterface
 
     // The result buffer still has the old (pre-cast) type.
     FailureOr<Value> resultBuffer =
-        state.getResultBuffer(rewriter, castOp->getResult(0));
+        state.getBuffer(rewriter, castOp->getOpOperand(0) /*source*/);
     if (failed(resultBuffer))
       return failure();
     auto sourceMemRefType = resultBuffer->getType().cast<BaseMemRefType>();
@@ -106,7 +106,7 @@ struct DimOpInterface
   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                           const BufferizationState &state) const {
     auto dimOp = cast<tensor::DimOp>(op);
-    Value v = state.lookupBuffer(rewriter, dimOp.source());
+    Value v = *state.getBuffer(rewriter, dimOp->getOpOperand(0) /*source*/);
     replaceOpWithNewBufferizedOp<memref::DimOp>(rewriter, op, v, dimOp.index());
     return success();
   }
@@ -143,7 +143,9 @@ struct ExtractSliceOpInterface
                           const BufferizationState &state) const {
     auto extractSliceOp = cast<tensor::ExtractSliceOp>(op);
     Location loc = extractSliceOp.getLoc();
-    Value srcMemref = state.lookupBuffer(rewriter, extractSliceOp.source());
+    Value srcMemref =
+        *state.getBuffer(rewriter, extractSliceOp->getOpOperand(0) /*source*/,
+                         /*forceInPlace=*/true);
     auto srcMemrefType = srcMemref.getType().cast<MemRefType>();
     auto dstTensorType =
         extractSliceOp.result().getType().cast<RankedTensorType>();
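`tensor.extract_slice` only reads its source and bufferizes to a `memref.subview` of it, so no copy can ever be required at the source; `forceInPlace=true` makes that explicit and skips the copy machinery. Restated as a comment-level sketch:

    // extract_slice: %src is only read, and the bufferized form is a
    // memref.subview of the source buffer, so reuse it unconditionally.
    Value srcMemref =
        *state.getBuffer(rewriter, extractSliceOp->getOpOperand(0) /*source*/,
                         /*forceInPlace=*/true);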
@@ -206,7 +208,8 @@ struct ExtractOpInterface
   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                           const BufferizationState &state) const {
     auto extractOp = cast<tensor::ExtractOp>(op);
-    Value srcMemref = state.lookupBuffer(rewriter, extractOp.tensor());
+    Value srcMemref =
+        *state.getBuffer(rewriter, extractOp->getOpOperand(0) /*tensor*/);
     replaceOpWithNewBufferizedOp<memref::LoadOp>(rewriter, op, srcMemref,
                                                  extractOp.indices());
     return success();
@@ -244,7 +247,7 @@ struct InsertOpInterface
                           const BufferizationState &state) const {
     auto insertOp = cast<tensor::InsertOp>(op);
     FailureOr<Value> destMemref =
-        state.getResultBuffer(rewriter, insertOp->getOpResult(0));
+        state.getBuffer(rewriter, insertOp->getOpOperand(1) /*dest*/);
     if (failed(destMemref))
       return failure();
     rewriter.create<memref::StoreOp>(insertOp.getLoc(), insertOp.scalar(),
@@ -412,7 +415,7 @@ struct InsertSliceOpInterface
 
     // When bufferizing out-of-place, `getResultBuffer` allocates.
     FailureOr<Value> dstMemref =
-        state.getResultBuffer(rewriter, insertSliceOp->getResult(0));
+        state.getBuffer(rewriter, insertSliceOp->getOpOperand(1) /*dest*/);
     if (failed(dstMemref))
       return failure();
 
@@ -430,7 +433,8 @@ struct InsertSliceOpInterface
 
     // Copy tensor. If this tensor.insert_slice has a matching
     // tensor.extract_slice, the copy operation will eventually fold away.
-    Value srcMemref = state.lookupBuffer(rewriter, insertSliceOp.source());
+    Value srcMemref =
+        *state.getBuffer(rewriter, insertSliceOp->getOpOperand(0) /*source*/);
     state.createMemCpy(rewriter, loc, srcMemref, subView);
 
     replaceOpWithBufferizedValues(rewriter, op, *dstMemref);
@@ -48,7 +48,8 @@ struct TransferReadOpInterface
            "only tensor types expected");
 
     // TransferReadOp always reads from the bufferized op.source().
-    Value buffer = state.lookupBuffer(rewriter, readOp.source());
+    Value buffer =
+        *state.getBuffer(rewriter, readOp->getOpOperand(0) /*source*/);
     replaceOpWithNewBufferizedOp<vector::TransferReadOp>(
         rewriter, readOp, readOp.getVectorType(), buffer, readOp.indices(),
         readOp.permutation_map(), readOp.padding(), readOp.mask(),
@@ -99,7 +100,7 @@ struct TransferWriteOpInterface
     // Leave the previous transfer_write to dead code as it still has uses at
     // this point.
     FailureOr<Value> resultBuffer =
-        state.getResultBuffer(rewriter, op->getResult(0));
+        state.getBuffer(rewriter, op->getOpOperand(1) /*source*/);
     if (failed(resultBuffer))
       return failure();
     rewriter.create<vector::TransferWriteOp>(
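A reading aid for the `/*...*/` operand comments used throughout these hunks: the index is the positional operand number on the op. For `vector.transfer_write`, the written-to tensor is operand #1 (the vector value is operand #0), and its ODS operand name happens to be `source`. Sketch:

    // vector.transfer_write %vec, %tensor[...]:
    //   operand #0: the vector being written
    //   operand #1: the destination tensor (ODS name: "source")
    FailureOr<Value> resultBuffer =
        state.getBuffer(rewriter, op->getOpOperand(1) /*source*/);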