[mlir][linalg][bufferize][NFC] Simplify buffer API of BufferizationState

Instead of `lookupBuffer` and `getResultBuffer`, there is now a single `getBuffer` function. This simplifies the `BufferizableOpInterface` API and is less confusing to users, who could previously have called the wrong one of the two functions.

Furthermore, since `getBuffer` takes an `OpOperand &` instead of a `Value`, users can no longer accidentally misuse either of the previous two functions in a way that would have resulted in missing buffer copies.

Differential Revision: https://reviews.llvm.org/D116455
This commit is contained in:
Matthias Springer 2022-01-08 01:10:15 +09:00
parent 8e2b6aac32
commit d9184ab1a5
7 changed files with 117 additions and 137 deletions

View File

@ -377,18 +377,14 @@ public:
/// Creates a memcpy between two given buffers. /// Creates a memcpy between two given buffers.
void createMemCpy(OpBuilder &b, Location loc, Value from, Value to) const; void createMemCpy(OpBuilder &b, Location loc, Value from, Value to) const;
/// Lookup the memref buffer that is associated to the given tensor value.
/// Asserts if no buffer is associated.
Value lookupBuffer(RewriterBase &rewriter, Value tensor) const;
/// Return `true` if the given OpResult has been decided to bufferize inplace. /// Return `true` if the given OpResult has been decided to bufferize inplace.
bool isInPlace(OpOperand &opOperand) const; bool isInPlace(OpOperand &opOperand) const;
/// Return the result buffer (memref) for a given OpResult (tensor). Allocate /// Return the buffer (memref) for a given OpOperand (tensor). Allocate
/// a new buffer and copy over data from the existing buffer if out-of-place /// a new buffer and copy over data from the existing buffer if out-of-place
/// bufferization is necessary. /// bufferization was decided.
FailureOr<Value> getResultBuffer(RewriterBase &rewriter, FailureOr<Value> getBuffer(RewriterBase &rewriter, OpOperand &opOperand,
OpResult result) const; bool forceInPlace = false) const;
/// Return dialect-specific bufferization state. /// Return dialect-specific bufferization state.
template <typename StateT> template <typename StateT>

View File

@ -347,74 +347,73 @@ mlir::linalg::comprehensive_bufferize::BufferizationState::BufferizationState(
}); });
} }
static Value lookupBuffer(RewriterBase &rewriter, Value tensor) {
assert(tensor.getType().isa<TensorType>() && "unexpected non-tensor type");
// Replace "%t = to_tensor %m" with %m.
if (auto toTensorOp = tensor.getDefiningOp<bufferization::ToTensorOp>())
return toTensorOp.memref();
// Insert to_memref op.
OpBuilder::InsertionGuard g(rewriter);
setInsertionPointAfter(rewriter, tensor);
Type memrefType;
if (auto rankedTensorType = tensor.getType().dyn_cast<RankedTensorType>()) {
memrefType = getDynamicMemRefType(rankedTensorType);
} else {
memrefType = getUnrankedMemRefType(
tensor.getType().cast<TensorType>().getElementType());
}
return rewriter.create<bufferization::ToMemrefOp>(tensor.getLoc(), memrefType,
tensor);
}
/// Return the result buffer (memref) for a given OpResult (tensor). Allocate /// Return the result buffer (memref) for a given OpResult (tensor). Allocate
/// a new buffer and copy over data from the existing buffer if out-of-place /// a new buffer and copy over data from the existing buffer if out-of-place
/// bufferization is necessary. /// bufferization is necessary.
FailureOr<Value> FailureOr<Value>
mlir::linalg::comprehensive_bufferize::BufferizationState::getResultBuffer( mlir::linalg::comprehensive_bufferize::BufferizationState::getBuffer(
RewriterBase &rewriter, OpResult result) const { RewriterBase &rewriter, OpOperand &opOperand, bool forceInPlace) const {
OpBuilder::InsertionGuard guard(rewriter); OpBuilder::InsertionGuard guard(rewriter);
Operation *op = result.getOwner(); Operation *op = opOperand.getOwner();
SmallVector<OpOperand *> aliasingOperands = getAliasingOpOperand(result); Location loc = op->getLoc();
assert(!aliasingOperands.empty() && "could not get aliasing OpOperand"); Value operand = opOperand.get();
OpOperand *opOperand = aliasingOperands.front();
Value operand = opOperand->get();
Value operandBuffer = lookupBuffer(rewriter, operand); Value operandBuffer = lookupBuffer(rewriter, operand);
// Make sure that all OpOperands are the same buffer. If this is not the case,
// we would have to materialize a memref value.
// TODO: Should be looking for checking for "equivalent buffers" instead of
// operator== here, but equivalent buffers for scf.if yield values are not
// set up yet.
if (aliasingOperands.size() > 1 &&
!llvm::all_of(aliasingOperands, [&](OpOperand *o) {
return lookupBuffer(rewriter, o->get()) == operandBuffer;
}))
return FailureOr<Value>(op->emitError("result buffer is ambiguous"));
// If bufferizing out-of-place, allocate a new buffer. if (forceInPlace || aliasInfo.isInPlace(opOperand))
if (!aliasInfo.isInPlace(*opOperand)) { return operandBuffer;
// Ops with multiple aliasing operands can currently not bufferize
// out-of-place. // Bufferizing out-of-place: Allocate a new buffer.
assert( // Move insertion point right after `operandBuffer`. That is where the
aliasingOperands.size() == 1 && // allocation should be inserted (in the absence of allocation hoisting).
"ops with multiple aliasing OpOperands cannot bufferize out-of-place"); setInsertionPointAfter(rewriter, operandBuffer);
Location loc = op->getLoc(); // Allocate the result buffer.
// Move insertion point right after `operandBuffer`. That is where the FailureOr<Value> resultBuffer =
// allocation should be inserted (in the absence of allocation hoisting). createAlloc(rewriter, loc, operandBuffer, options.createDeallocs);
setInsertionPointAfter(rewriter, operandBuffer); if (failed(resultBuffer))
// Allocate the result buffer. return failure();
FailureOr<Value> resultBuffer = // Do not copy if the last preceding write of `operand` is an op that does
createAlloc(rewriter, loc, operandBuffer, options.createDeallocs); // not write (skipping ops that merely create aliases). E.g., InitTensorOp.
if (failed(resultBuffer)) // Note: If `findLastPrecedingWrite` reaches the end of the reverse SSA
return failure(); // use-def chain, it returns that value, regardless of whether it is a
bool skipCopy = false; // memory write or not.
// Do not copy if the last preceding write of `operand` is an op that does Value lastWrite = findLastPrecedingWrite(operand);
// not write (skipping ops that merely create aliases). E.g., InitTensorOp. if (auto bufferizableOp = options.dynCastBufferizableOp(lastWrite))
// Note: If `findLastPrecedingWrite` reaches the end of the reverse SSA if (!bufferizableOp.isMemoryWrite(lastWrite.cast<OpResult>(), *this))
// use-def chain, it returns that value, regardless of whether it is a return resultBuffer;
// memory write or not. // Do not copy if the copied data is never read.
Value lastWrite = findLastPrecedingWrite(operand); OpResult aliasingOpResult = getAliasingOpResult(opOperand);
if (auto bufferizableOp = options.dynCastBufferizableOp(lastWrite)) if (aliasingOpResult && !bufferizesToMemoryRead(opOperand) &&
if (!bufferizableOp.isMemoryWrite(lastWrite.cast<OpResult>(), *this)) !isValueRead(aliasingOpResult))
skipCopy = true; return resultBuffer;
// Do not copy if the copied data is never read. (Neither by this op nor by // Do not copy if this op does not read the data, but writes it.
// any following op.) if (bufferizesToMemoryWrite(opOperand) && !bufferizesToMemoryRead(opOperand))
if (!bufferizesToMemoryRead(*opOperand) && !isValueRead(result))
skipCopy = true;
// Do not copy if this op does not read the data, but writes it.
if (bufferizesToMemoryWrite(*opOperand) &&
!bufferizesToMemoryRead(*opOperand))
skipCopy = true;
if (!skipCopy) {
// The copy happens right before the op that is bufferized.
rewriter.setInsertionPoint(op);
createMemCpy(rewriter, loc, operandBuffer, *resultBuffer);
}
return resultBuffer; return resultBuffer;
}
// Bufferizing in-place. No need to allocate a new buffer. // The copy happens right before the op that is bufferized.
return operandBuffer; rewriter.setInsertionPoint(op);
createMemCpy(rewriter, loc, operandBuffer, *resultBuffer);
return resultBuffer;
} }
void mlir::linalg::comprehensive_bufferize::replaceOpWithBufferizedValues( void mlir::linalg::comprehensive_bufferize::replaceOpWithBufferizedValues(
@ -593,28 +592,6 @@ bool mlir::linalg::comprehensive_bufferize::isFunctionArgument(Value value) {
return isa<FuncOp>(bbArg.getOwner()->getParentOp()); return isa<FuncOp>(bbArg.getOwner()->getParentOp());
} }
Value mlir::linalg::comprehensive_bufferize::BufferizationState::lookupBuffer(
RewriterBase &rewriter, Value tensor) const {
assert(tensor.getType().isa<TensorType>() && "unexpected non-tensor type");
// Replace "%t = to_tensor %m" with %m.
if (auto toTensorOp = tensor.getDefiningOp<bufferization::ToTensorOp>())
return toTensorOp.memref();
// Insert to_memref op.
OpBuilder::InsertionGuard g(rewriter);
setInsertionPointAfter(rewriter, tensor);
Type memrefType;
if (auto rankedTensorType = tensor.getType().dyn_cast<RankedTensorType>()) {
memrefType = getDynamicMemRefType(rankedTensorType);
} else {
memrefType = getUnrankedMemRefType(
tensor.getType().cast<TensorType>().getElementType());
}
return rewriter.create<bufferization::ToMemrefOp>(tensor.getLoc(), memrefType,
tensor);
}
bool mlir::linalg::comprehensive_bufferize::BufferizationState::isInPlace( bool mlir::linalg::comprehensive_bufferize::BufferizationState::isInPlace(
OpOperand &opOperand) const { OpOperand &opOperand) const {
return aliasInfo.isInPlace(opOperand); return aliasInfo.isInPlace(opOperand);

View File

@ -46,15 +46,19 @@ static LogicalResult bufferizeLinalgOp(RewriterBase &rewriter, LinalgOp op,
newInputBuffers.push_back(opOperand->get()); newInputBuffers.push_back(opOperand->get());
continue; continue;
} }
newInputBuffers.push_back(state.lookupBuffer(rewriter, opOperand->get())); // Input operands are never written to.
newInputBuffers.push_back(
*state.getBuffer(rewriter, *opOperand, /*forceInPlace=*/true));
} }
// New output operands for the cloned op. // New output operands for the cloned op.
SmallVector<Value> newOutputBuffers; SmallVector<Value> newOutputBuffers;
for (OpOperand *opOperand : op.getOutputOperands()) { for (OpResult opResult : op->getOpResults()) {
OpResult opResult = op.getTiedOpResult(opOperand); SmallVector<OpOperand *> aliasingOpOperands =
assert(opResult && "could not find correspond OpResult"); state.getAliasingOpOperand(opResult);
FailureOr<Value> resultBuffer = state.getResultBuffer(rewriter, opResult); assert(aliasingOpOperands.size() == 1 && "expected 1 OpOperand");
FailureOr<Value> resultBuffer =
state.getBuffer(rewriter, *aliasingOpOperands.front());
if (failed(resultBuffer)) if (failed(resultBuffer))
return failure(); return failure();
newOutputBuffers.push_back(*resultBuffer); newOutputBuffers.push_back(*resultBuffer);
@ -284,24 +288,23 @@ struct TiledLoopOpInterface
// Compute new inputs, outputs and results. // Compute new inputs, outputs and results.
SmallVector<Value> newInputs, newOutputs, newResults; SmallVector<Value> newInputs, newOutputs, newResults;
for (Value value : tiledLoopOp.inputs()) { for (int i = tiledLoopOp.getNumControlOperands();
if (value.getType().isa<TensorType>()) { i < tiledLoopOp->getNumOperands(); ++i) {
newInputs.push_back(state.lookupBuffer(rewriter, value)); OpOperand &operand = tiledLoopOp->getOpOperand(i);
} else { Value rewrittenValue = operand.get();
newInputs.push_back(value); if (rewrittenValue.getType().isa<TensorType>()) {
} FailureOr<Value> bufferOrFailure = state.getBuffer(rewriter, operand);
} if (failed(bufferOrFailure))
int nextResultNum = 0;
for (Value value : tiledLoopOp.outputs()) {
if (value.getType().isa<TensorType>()) {
FailureOr<Value> buffer = state.getResultBuffer(
rewriter, tiledLoopOp->getResult(nextResultNum++));
if (failed(buffer))
return failure(); return failure();
newOutputs.push_back(*buffer); rewrittenValue = *bufferOrFailure;
newResults.push_back(*buffer); }
if (i <
tiledLoopOp.getNumControlOperands() + tiledLoopOp.getNumInputs()) {
newInputs.push_back(rewrittenValue);
} else { } else {
newOutputs.push_back(value); newOutputs.push_back(rewrittenValue);
if (operand.get().getType().isa<TensorType>())
newResults.push_back(rewrittenValue);
} }
} }

View File

@ -351,7 +351,7 @@ static LogicalResult bufferizeFuncOpBoundary(FuncOp funcOp,
// Cast values at the call site if necessary. // Cast values at the call site if necessary.
returnValues.push_back( returnValues.push_back(
getNonCastedValue(state.lookupBuffer(rewriter, returnVal))); getNonCastedValue(*state.getBuffer(rewriter, returnOperand)));
} }
// 2. Rewrite the terminator without the inPlace bufferizable values. // 2. Rewrite the terminator without the inPlace bufferizable values.
@ -659,7 +659,8 @@ struct CallOpInterface
// Return operands that are equivalent to some bbArg, are not // Return operands that are equivalent to some bbArg, are not
// returned. // returned.
Value buffer = Value buffer =
state.lookupBuffer(rewriter, callOp->getOperand(*bbArgIdx)); *state.getBuffer(rewriter, callOp->getOpOperand(*bbArgIdx),
/*forceInPlace=*/true);
replacementValues[returnValIdx] = buffer; replacementValues[returnValIdx] = buffer;
newOperands[*bbArgIdx] = buffer; newOperands[*bbArgIdx] = buffer;
continue; continue;
@ -690,9 +691,9 @@ struct CallOpInterface
// Retrieve buffers for tensor operands. Tensor operand buffers, who's // Retrieve buffers for tensor operands. Tensor operand buffers, who's
// corresponding FuncOp bbArgs are equivalent to a returned tensor, were // corresponding FuncOp bbArgs are equivalent to a returned tensor, were
// already stored in `newOperands` during Step 1. // already stored in `newOperands` during Step 1.
Value buffer = newOperands[idx] Value buffer = newOperands[idx] ? newOperands[idx]
? newOperands[idx] : *state.getBuffer(rewriter, opOperand,
: state.lookupBuffer(rewriter, tensorOperand); /*forceInPlace=*/true);
// Caller / callee type mistmatch is handled with a CastOp. // Caller / callee type mistmatch is handled with a CastOp.
auto memRefType = bufferizedFuncType.getInput(idx); auto memRefType = bufferizedFuncType.getInput(idx);

View File

@ -280,19 +280,17 @@ struct ForOpInterface
}; };
// Construct a new scf.for op with memref instead of tensor values. // Construct a new scf.for op with memref instead of tensor values.
bool resultBufferFailure = false; SmallVector<Value> initArgs;
SmallVector<Value> initArgs = for (OpOperand &opOperand : forOp.getIterOpOperands()) {
convert(forOp.getInitArgs(), [&](Value val, int64_t index) { if (opOperand.get().getType().isa<TensorType>()) {
FailureOr<Value> resultBuffer = FailureOr<Value> resultBuffer = state.getBuffer(rewriter, opOperand);
state.getResultBuffer(rewriter, forOp->getOpResult(index)); if (failed(resultBuffer))
if (failed(resultBuffer)) { return failure();
resultBufferFailure = true; initArgs.push_back(*resultBuffer);
return Value(); } else {
} initArgs.push_back(opOperand.get());
return *resultBuffer; }
}); }
if (resultBufferFailure)
return failure();
auto newForOp = rewriter.create<scf::ForOp>( auto newForOp = rewriter.create<scf::ForOp>(
forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(), forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
forOp.getStep(), initArgs); forOp.getStep(), initArgs);

View File

@ -53,7 +53,7 @@ struct CastOpInterface
// The result buffer still has the old (pre-cast) type. // The result buffer still has the old (pre-cast) type.
FailureOr<Value> resultBuffer = FailureOr<Value> resultBuffer =
state.getResultBuffer(rewriter, castOp->getResult(0)); state.getBuffer(rewriter, castOp->getOpOperand(0) /*source*/);
if (failed(resultBuffer)) if (failed(resultBuffer))
return failure(); return failure();
auto sourceMemRefType = resultBuffer->getType().cast<BaseMemRefType>(); auto sourceMemRefType = resultBuffer->getType().cast<BaseMemRefType>();
@ -106,7 +106,7 @@ struct DimOpInterface
LogicalResult bufferize(Operation *op, RewriterBase &rewriter, LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
const BufferizationState &state) const { const BufferizationState &state) const {
auto dimOp = cast<tensor::DimOp>(op); auto dimOp = cast<tensor::DimOp>(op);
Value v = state.lookupBuffer(rewriter, dimOp.source()); Value v = *state.getBuffer(rewriter, dimOp->getOpOperand(0) /*source*/);
replaceOpWithNewBufferizedOp<memref::DimOp>(rewriter, op, v, dimOp.index()); replaceOpWithNewBufferizedOp<memref::DimOp>(rewriter, op, v, dimOp.index());
return success(); return success();
} }
@ -143,7 +143,9 @@ struct ExtractSliceOpInterface
const BufferizationState &state) const { const BufferizationState &state) const {
auto extractSliceOp = cast<tensor::ExtractSliceOp>(op); auto extractSliceOp = cast<tensor::ExtractSliceOp>(op);
Location loc = extractSliceOp.getLoc(); Location loc = extractSliceOp.getLoc();
Value srcMemref = state.lookupBuffer(rewriter, extractSliceOp.source()); Value srcMemref =
*state.getBuffer(rewriter, extractSliceOp->getOpOperand(0) /*source*/,
/*forceInPlace=*/true);
auto srcMemrefType = srcMemref.getType().cast<MemRefType>(); auto srcMemrefType = srcMemref.getType().cast<MemRefType>();
auto dstTensorType = auto dstTensorType =
extractSliceOp.result().getType().cast<RankedTensorType>(); extractSliceOp.result().getType().cast<RankedTensorType>();
@ -206,7 +208,8 @@ struct ExtractOpInterface
LogicalResult bufferize(Operation *op, RewriterBase &rewriter, LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
const BufferizationState &state) const { const BufferizationState &state) const {
auto extractOp = cast<tensor::ExtractOp>(op); auto extractOp = cast<tensor::ExtractOp>(op);
Value srcMemref = state.lookupBuffer(rewriter, extractOp.tensor()); Value srcMemref =
*state.getBuffer(rewriter, extractOp->getOpOperand(0) /*tensor*/);
replaceOpWithNewBufferizedOp<memref::LoadOp>(rewriter, op, srcMemref, replaceOpWithNewBufferizedOp<memref::LoadOp>(rewriter, op, srcMemref,
extractOp.indices()); extractOp.indices());
return success(); return success();
@ -244,7 +247,7 @@ struct InsertOpInterface
const BufferizationState &state) const { const BufferizationState &state) const {
auto insertOp = cast<tensor::InsertOp>(op); auto insertOp = cast<tensor::InsertOp>(op);
FailureOr<Value> destMemref = FailureOr<Value> destMemref =
state.getResultBuffer(rewriter, insertOp->getOpResult(0)); state.getBuffer(rewriter, insertOp->getOpOperand(1) /*dest*/);
if (failed(destMemref)) if (failed(destMemref))
return failure(); return failure();
rewriter.create<memref::StoreOp>(insertOp.getLoc(), insertOp.scalar(), rewriter.create<memref::StoreOp>(insertOp.getLoc(), insertOp.scalar(),
@ -412,7 +415,7 @@ struct InsertSliceOpInterface
// When bufferizing out-of-place, `getResultBuffer` allocates. // When bufferizing out-of-place, `getResultBuffer` allocates.
FailureOr<Value> dstMemref = FailureOr<Value> dstMemref =
state.getResultBuffer(rewriter, insertSliceOp->getResult(0)); state.getBuffer(rewriter, insertSliceOp->getOpOperand(1) /*dest*/);
if (failed(dstMemref)) if (failed(dstMemref))
return failure(); return failure();
@ -430,7 +433,8 @@ struct InsertSliceOpInterface
// Copy tensor. If this tensor.insert_slice has a matching // Copy tensor. If this tensor.insert_slice has a matching
// tensor.extract_slice, the copy operation will eventually fold away. // tensor.extract_slice, the copy operation will eventually fold away.
Value srcMemref = state.lookupBuffer(rewriter, insertSliceOp.source()); Value srcMemref =
*state.getBuffer(rewriter, insertSliceOp->getOpOperand(0) /*source*/);
state.createMemCpy(rewriter, loc, srcMemref, subView); state.createMemCpy(rewriter, loc, srcMemref, subView);
replaceOpWithBufferizedValues(rewriter, op, *dstMemref); replaceOpWithBufferizedValues(rewriter, op, *dstMemref);

View File

@ -48,7 +48,8 @@ struct TransferReadOpInterface
"only tensor types expected"); "only tensor types expected");
// TransferReadOp always reads from the bufferized op.source(). // TransferReadOp always reads from the bufferized op.source().
Value buffer = state.lookupBuffer(rewriter, readOp.source()); Value buffer =
*state.getBuffer(rewriter, readOp->getOpOperand(0) /*source*/);
replaceOpWithNewBufferizedOp<vector::TransferReadOp>( replaceOpWithNewBufferizedOp<vector::TransferReadOp>(
rewriter, readOp, readOp.getVectorType(), buffer, readOp.indices(), rewriter, readOp, readOp.getVectorType(), buffer, readOp.indices(),
readOp.permutation_map(), readOp.padding(), readOp.mask(), readOp.permutation_map(), readOp.padding(), readOp.mask(),
@ -99,7 +100,7 @@ struct TransferWriteOpInterface
// Leave the previous transfer_write to dead code as it still has uses at // Leave the previous transfer_write to dead code as it still has uses at
// this point. // this point.
FailureOr<Value> resultBuffer = FailureOr<Value> resultBuffer =
state.getResultBuffer(rewriter, op->getResult(0)); state.getBuffer(rewriter, op->getOpOperand(1) /*source*/);
if (failed(resultBuffer)) if (failed(resultBuffer))
return failure(); return failure();
rewriter.create<vector::TransferWriteOp>( rewriter.create<vector::TransferWriteOp>(