[mlir][linalg][bufferize][NFC] Simplify buffer API of BufferizationState

Instead of `lookupBuffer` and `getResultBuffer`, there is now a single `getBuffer` function. This simplifies the `BufferizableOpInterface` API and is less confusing to users. They could previously have called the wrong function.

Furthermore, since `getBuffer` now takes an `OpOperand &` instead of a `Value`, users can no longer accidentally use one of the previous two functions incorrectly, which would have resulted in missing buffer copies.

Differential Revision: https://reviews.llvm.org/D116455
This commit is contained in:
Matthias Springer 2022-01-08 01:10:15 +09:00
parent 8e2b6aac32
commit d9184ab1a5
7 changed files with 117 additions and 137 deletions

View File

@ -377,18 +377,14 @@ public:
/// Creates a memcpy between two given buffers.
void createMemCpy(OpBuilder &b, Location loc, Value from, Value to) const;
/// Lookup the memref buffer that is associated to the given tensor value.
/// Asserts if no buffer is associated.
Value lookupBuffer(RewriterBase &rewriter, Value tensor) const;
/// Return `true` if the given OpResult has been decided to bufferize inplace.
bool isInPlace(OpOperand &opOperand) const;
/// Return the result buffer (memref) for a given OpResult (tensor). Allocate
/// Return the buffer (memref) for a given OpOperand (tensor). Allocate
/// a new buffer and copy over data from the existing buffer if out-of-place
/// bufferization is necessary.
FailureOr<Value> getResultBuffer(RewriterBase &rewriter,
OpResult result) const;
/// bufferization was decided.
FailureOr<Value> getBuffer(RewriterBase &rewriter, OpOperand &opOperand,
bool forceInPlace = false) const;
/// Return dialect-specific bufferization state.
template <typename StateT>

View File

@ -347,74 +347,73 @@ mlir::linalg::comprehensive_bufferize::BufferizationState::BufferizationState(
});
}
static Value lookupBuffer(RewriterBase &rewriter, Value tensor) {
assert(tensor.getType().isa<TensorType>() && "unexpected non-tensor type");
// Replace "%t = to_tensor %m" with %m.
if (auto toTensorOp = tensor.getDefiningOp<bufferization::ToTensorOp>())
return toTensorOp.memref();
// Insert to_memref op.
OpBuilder::InsertionGuard g(rewriter);
setInsertionPointAfter(rewriter, tensor);
Type memrefType;
if (auto rankedTensorType = tensor.getType().dyn_cast<RankedTensorType>()) {
memrefType = getDynamicMemRefType(rankedTensorType);
} else {
memrefType = getUnrankedMemRefType(
tensor.getType().cast<TensorType>().getElementType());
}
return rewriter.create<bufferization::ToMemrefOp>(tensor.getLoc(), memrefType,
tensor);
}
/// Return the result buffer (memref) for a given OpResult (tensor). Allocate
/// a new buffer and copy over data from the existing buffer if out-of-place
/// bufferization is necessary.
FailureOr<Value>
mlir::linalg::comprehensive_bufferize::BufferizationState::getResultBuffer(
RewriterBase &rewriter, OpResult result) const {
mlir::linalg::comprehensive_bufferize::BufferizationState::getBuffer(
RewriterBase &rewriter, OpOperand &opOperand, bool forceInPlace) const {
OpBuilder::InsertionGuard guard(rewriter);
Operation *op = result.getOwner();
SmallVector<OpOperand *> aliasingOperands = getAliasingOpOperand(result);
assert(!aliasingOperands.empty() && "could not get aliasing OpOperand");
OpOperand *opOperand = aliasingOperands.front();
Value operand = opOperand->get();
Operation *op = opOperand.getOwner();
Location loc = op->getLoc();
Value operand = opOperand.get();
Value operandBuffer = lookupBuffer(rewriter, operand);
// Make sure that all OpOperands are the same buffer. If this is not the case,
// we would have to materialize a memref value.
// TODO: Should be looking for checking for "equivalent buffers" instead of
// operator== here, but equivalent buffers for scf.if yield values are not
// set up yet.
if (aliasingOperands.size() > 1 &&
!llvm::all_of(aliasingOperands, [&](OpOperand *o) {
return lookupBuffer(rewriter, o->get()) == operandBuffer;
}))
return FailureOr<Value>(op->emitError("result buffer is ambiguous"));
// If bufferizing out-of-place, allocate a new buffer.
if (!aliasInfo.isInPlace(*opOperand)) {
// Ops with multiple aliasing operands can currently not bufferize
// out-of-place.
assert(
aliasingOperands.size() == 1 &&
"ops with multiple aliasing OpOperands cannot bufferize out-of-place");
Location loc = op->getLoc();
// Move insertion point right after `operandBuffer`. That is where the
// allocation should be inserted (in the absence of allocation hoisting).
setInsertionPointAfter(rewriter, operandBuffer);
// Allocate the result buffer.
FailureOr<Value> resultBuffer =
createAlloc(rewriter, loc, operandBuffer, options.createDeallocs);
if (failed(resultBuffer))
return failure();
bool skipCopy = false;
// Do not copy if the last preceding write of `operand` is an op that does
// not write (skipping ops that merely create aliases). E.g., InitTensorOp.
// Note: If `findLastPrecedingWrite` reaches the end of the reverse SSA
// use-def chain, it returns that value, regardless of whether it is a
// memory write or not.
Value lastWrite = findLastPrecedingWrite(operand);
if (auto bufferizableOp = options.dynCastBufferizableOp(lastWrite))
if (!bufferizableOp.isMemoryWrite(lastWrite.cast<OpResult>(), *this))
skipCopy = true;
// Do not copy if the copied data is never read. (Neither by this op nor by
// any following op.)
if (!bufferizesToMemoryRead(*opOperand) && !isValueRead(result))
skipCopy = true;
// Do not copy if this op does not read the data, but writes it.
if (bufferizesToMemoryWrite(*opOperand) &&
!bufferizesToMemoryRead(*opOperand))
skipCopy = true;
if (!skipCopy) {
// The copy happens right before the op that is bufferized.
rewriter.setInsertionPoint(op);
createMemCpy(rewriter, loc, operandBuffer, *resultBuffer);
}
if (forceInPlace || aliasInfo.isInPlace(opOperand))
return operandBuffer;
// Bufferizing out-of-place: Allocate a new buffer.
// Move insertion point right after `operandBuffer`. That is where the
// allocation should be inserted (in the absence of allocation hoisting).
setInsertionPointAfter(rewriter, operandBuffer);
// Allocate the result buffer.
FailureOr<Value> resultBuffer =
createAlloc(rewriter, loc, operandBuffer, options.createDeallocs);
if (failed(resultBuffer))
return failure();
// Do not copy if the last preceding write of `operand` is an op that does
// not write (skipping ops that merely create aliases). E.g., InitTensorOp.
// Note: If `findLastPrecedingWrite` reaches the end of the reverse SSA
// use-def chain, it returns that value, regardless of whether it is a
// memory write or not.
Value lastWrite = findLastPrecedingWrite(operand);
if (auto bufferizableOp = options.dynCastBufferizableOp(lastWrite))
if (!bufferizableOp.isMemoryWrite(lastWrite.cast<OpResult>(), *this))
return resultBuffer;
// Do not copy if the copied data is never read.
OpResult aliasingOpResult = getAliasingOpResult(opOperand);
if (aliasingOpResult && !bufferizesToMemoryRead(opOperand) &&
!isValueRead(aliasingOpResult))
return resultBuffer;
// Do not copy if this op does not read the data, but writes it.
if (bufferizesToMemoryWrite(opOperand) && !bufferizesToMemoryRead(opOperand))
return resultBuffer;
}
// Bufferizing in-place. No need to allocate a new buffer.
return operandBuffer;
// The copy happens right before the op that is bufferized.
rewriter.setInsertionPoint(op);
createMemCpy(rewriter, loc, operandBuffer, *resultBuffer);
return resultBuffer;
}
void mlir::linalg::comprehensive_bufferize::replaceOpWithBufferizedValues(
@ -593,28 +592,6 @@ bool mlir::linalg::comprehensive_bufferize::isFunctionArgument(Value value) {
return isa<FuncOp>(bbArg.getOwner()->getParentOp());
}
Value mlir::linalg::comprehensive_bufferize::BufferizationState::lookupBuffer(
RewriterBase &rewriter, Value tensor) const {
assert(tensor.getType().isa<TensorType>() && "unexpected non-tensor type");
// Replace "%t = to_tensor %m" with %m.
if (auto toTensorOp = tensor.getDefiningOp<bufferization::ToTensorOp>())
return toTensorOp.memref();
// Insert to_memref op.
OpBuilder::InsertionGuard g(rewriter);
setInsertionPointAfter(rewriter, tensor);
Type memrefType;
if (auto rankedTensorType = tensor.getType().dyn_cast<RankedTensorType>()) {
memrefType = getDynamicMemRefType(rankedTensorType);
} else {
memrefType = getUnrankedMemRefType(
tensor.getType().cast<TensorType>().getElementType());
}
return rewriter.create<bufferization::ToMemrefOp>(tensor.getLoc(), memrefType,
tensor);
}
bool mlir::linalg::comprehensive_bufferize::BufferizationState::isInPlace(
OpOperand &opOperand) const {
return aliasInfo.isInPlace(opOperand);

View File

@ -46,15 +46,19 @@ static LogicalResult bufferizeLinalgOp(RewriterBase &rewriter, LinalgOp op,
newInputBuffers.push_back(opOperand->get());
continue;
}
newInputBuffers.push_back(state.lookupBuffer(rewriter, opOperand->get()));
// Input operands are never written to.
newInputBuffers.push_back(
*state.getBuffer(rewriter, *opOperand, /*forceInPlace=*/true));
}
// New output operands for the cloned op.
SmallVector<Value> newOutputBuffers;
for (OpOperand *opOperand : op.getOutputOperands()) {
OpResult opResult = op.getTiedOpResult(opOperand);
assert(opResult && "could not find correspond OpResult");
FailureOr<Value> resultBuffer = state.getResultBuffer(rewriter, opResult);
for (OpResult opResult : op->getOpResults()) {
SmallVector<OpOperand *> aliasingOpOperands =
state.getAliasingOpOperand(opResult);
assert(aliasingOpOperands.size() == 1 && "expected 1 OpOperand");
FailureOr<Value> resultBuffer =
state.getBuffer(rewriter, *aliasingOpOperands.front());
if (failed(resultBuffer))
return failure();
newOutputBuffers.push_back(*resultBuffer);
@ -284,24 +288,23 @@ struct TiledLoopOpInterface
// Compute new inputs, outputs and results.
SmallVector<Value> newInputs, newOutputs, newResults;
for (Value value : tiledLoopOp.inputs()) {
if (value.getType().isa<TensorType>()) {
newInputs.push_back(state.lookupBuffer(rewriter, value));
} else {
newInputs.push_back(value);
}
}
int nextResultNum = 0;
for (Value value : tiledLoopOp.outputs()) {
if (value.getType().isa<TensorType>()) {
FailureOr<Value> buffer = state.getResultBuffer(
rewriter, tiledLoopOp->getResult(nextResultNum++));
if (failed(buffer))
for (int i = tiledLoopOp.getNumControlOperands();
i < tiledLoopOp->getNumOperands(); ++i) {
OpOperand &operand = tiledLoopOp->getOpOperand(i);
Value rewrittenValue = operand.get();
if (rewrittenValue.getType().isa<TensorType>()) {
FailureOr<Value> bufferOrFailure = state.getBuffer(rewriter, operand);
if (failed(bufferOrFailure))
return failure();
newOutputs.push_back(*buffer);
newResults.push_back(*buffer);
rewrittenValue = *bufferOrFailure;
}
if (i <
tiledLoopOp.getNumControlOperands() + tiledLoopOp.getNumInputs()) {
newInputs.push_back(rewrittenValue);
} else {
newOutputs.push_back(value);
newOutputs.push_back(rewrittenValue);
if (operand.get().getType().isa<TensorType>())
newResults.push_back(rewrittenValue);
}
}

View File

@ -351,7 +351,7 @@ static LogicalResult bufferizeFuncOpBoundary(FuncOp funcOp,
// Cast values at the call site if necessary.
returnValues.push_back(
getNonCastedValue(state.lookupBuffer(rewriter, returnVal)));
getNonCastedValue(*state.getBuffer(rewriter, returnOperand)));
}
// 2. Rewrite the terminator without the inPlace bufferizable values.
@ -659,7 +659,8 @@ struct CallOpInterface
// Return operands that are equivalent to some bbArg, are not
// returned.
Value buffer =
state.lookupBuffer(rewriter, callOp->getOperand(*bbArgIdx));
*state.getBuffer(rewriter, callOp->getOpOperand(*bbArgIdx),
/*forceInPlace=*/true);
replacementValues[returnValIdx] = buffer;
newOperands[*bbArgIdx] = buffer;
continue;
@ -690,9 +691,9 @@ struct CallOpInterface
// Retrieve buffers for tensor operands. Tensor operand buffers, who's
// corresponding FuncOp bbArgs are equivalent to a returned tensor, were
// already stored in `newOperands` during Step 1.
Value buffer = newOperands[idx]
? newOperands[idx]
: state.lookupBuffer(rewriter, tensorOperand);
Value buffer = newOperands[idx] ? newOperands[idx]
: *state.getBuffer(rewriter, opOperand,
/*forceInPlace=*/true);
// Caller / callee type mistmatch is handled with a CastOp.
auto memRefType = bufferizedFuncType.getInput(idx);

View File

@ -280,19 +280,17 @@ struct ForOpInterface
};
// Construct a new scf.for op with memref instead of tensor values.
bool resultBufferFailure = false;
SmallVector<Value> initArgs =
convert(forOp.getInitArgs(), [&](Value val, int64_t index) {
FailureOr<Value> resultBuffer =
state.getResultBuffer(rewriter, forOp->getOpResult(index));
if (failed(resultBuffer)) {
resultBufferFailure = true;
return Value();
}
return *resultBuffer;
});
if (resultBufferFailure)
return failure();
SmallVector<Value> initArgs;
for (OpOperand &opOperand : forOp.getIterOpOperands()) {
if (opOperand.get().getType().isa<TensorType>()) {
FailureOr<Value> resultBuffer = state.getBuffer(rewriter, opOperand);
if (failed(resultBuffer))
return failure();
initArgs.push_back(*resultBuffer);
} else {
initArgs.push_back(opOperand.get());
}
}
auto newForOp = rewriter.create<scf::ForOp>(
forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
forOp.getStep(), initArgs);

View File

@ -53,7 +53,7 @@ struct CastOpInterface
// The result buffer still has the old (pre-cast) type.
FailureOr<Value> resultBuffer =
state.getResultBuffer(rewriter, castOp->getResult(0));
state.getBuffer(rewriter, castOp->getOpOperand(0) /*source*/);
if (failed(resultBuffer))
return failure();
auto sourceMemRefType = resultBuffer->getType().cast<BaseMemRefType>();
@ -106,7 +106,7 @@ struct DimOpInterface
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
const BufferizationState &state) const {
auto dimOp = cast<tensor::DimOp>(op);
Value v = state.lookupBuffer(rewriter, dimOp.source());
Value v = *state.getBuffer(rewriter, dimOp->getOpOperand(0) /*source*/);
replaceOpWithNewBufferizedOp<memref::DimOp>(rewriter, op, v, dimOp.index());
return success();
}
@ -143,7 +143,9 @@ struct ExtractSliceOpInterface
const BufferizationState &state) const {
auto extractSliceOp = cast<tensor::ExtractSliceOp>(op);
Location loc = extractSliceOp.getLoc();
Value srcMemref = state.lookupBuffer(rewriter, extractSliceOp.source());
Value srcMemref =
*state.getBuffer(rewriter, extractSliceOp->getOpOperand(0) /*source*/,
/*forceInPlace=*/true);
auto srcMemrefType = srcMemref.getType().cast<MemRefType>();
auto dstTensorType =
extractSliceOp.result().getType().cast<RankedTensorType>();
@ -206,7 +208,8 @@ struct ExtractOpInterface
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
const BufferizationState &state) const {
auto extractOp = cast<tensor::ExtractOp>(op);
Value srcMemref = state.lookupBuffer(rewriter, extractOp.tensor());
Value srcMemref =
*state.getBuffer(rewriter, extractOp->getOpOperand(0) /*tensor*/);
replaceOpWithNewBufferizedOp<memref::LoadOp>(rewriter, op, srcMemref,
extractOp.indices());
return success();
@ -244,7 +247,7 @@ struct InsertOpInterface
const BufferizationState &state) const {
auto insertOp = cast<tensor::InsertOp>(op);
FailureOr<Value> destMemref =
state.getResultBuffer(rewriter, insertOp->getOpResult(0));
state.getBuffer(rewriter, insertOp->getOpOperand(1) /*dest*/);
if (failed(destMemref))
return failure();
rewriter.create<memref::StoreOp>(insertOp.getLoc(), insertOp.scalar(),
@ -412,7 +415,7 @@ struct InsertSliceOpInterface
// When bufferizing out-of-place, `getResultBuffer` allocates.
FailureOr<Value> dstMemref =
state.getResultBuffer(rewriter, insertSliceOp->getResult(0));
state.getBuffer(rewriter, insertSliceOp->getOpOperand(1) /*dest*/);
if (failed(dstMemref))
return failure();
@ -430,7 +433,8 @@ struct InsertSliceOpInterface
// Copy tensor. If this tensor.insert_slice has a matching
// tensor.extract_slice, the copy operation will eventually fold away.
Value srcMemref = state.lookupBuffer(rewriter, insertSliceOp.source());
Value srcMemref =
*state.getBuffer(rewriter, insertSliceOp->getOpOperand(0) /*source*/);
state.createMemCpy(rewriter, loc, srcMemref, subView);
replaceOpWithBufferizedValues(rewriter, op, *dstMemref);

View File

@ -48,7 +48,8 @@ struct TransferReadOpInterface
"only tensor types expected");
// TransferReadOp always reads from the bufferized op.source().
Value buffer = state.lookupBuffer(rewriter, readOp.source());
Value buffer =
*state.getBuffer(rewriter, readOp->getOpOperand(0) /*source*/);
replaceOpWithNewBufferizedOp<vector::TransferReadOp>(
rewriter, readOp, readOp.getVectorType(), buffer, readOp.indices(),
readOp.permutation_map(), readOp.padding(), readOp.mask(),
@ -99,7 +100,7 @@ struct TransferWriteOpInterface
// Leave the previous transfer_write to dead code as it still has uses at
// this point.
FailureOr<Value> resultBuffer =
state.getResultBuffer(rewriter, op->getResult(0));
state.getBuffer(rewriter, op->getOpOperand(1) /*source*/);
if (failed(resultBuffer))
return failure();
rewriter.create<vector::TransferWriteOp>(