[mlir][linalg][bufferize][NFC] Change allocationFn return type to FailureOr<Value>

In addition, all functions that call `allocationFn` now return FailureOr<Value>. This resolves a few TODOs in the code base.

Differential Revision: https://reviews.llvm.org/D116452
This commit is contained in:
Matthias Springer 2022-01-07 06:32:35 +09:00
parent 4a661602ef
commit 698896cd6c
7 changed files with 57 additions and 27 deletions

View File

@ -41,7 +41,7 @@ struct PostAnalysisStep;
// TODO: Could be replaced with a "bufferization strategy" object with virtual
// functions in the future.
struct AllocationCallbacks {
using AllocationFn = std::function<Optional<Value>(
using AllocationFn = std::function<FailureOr<Value>(
OpBuilder &, Location, MemRefType, ArrayRef<Value>)>;
using DeallocationFn = std::function<void(OpBuilder &, Location, Value)>;
using MemCpyFn = std::function<void(OpBuilder &, Location, Value, Value)>;
@ -360,15 +360,15 @@ public:
Value findLastPrecedingWrite(Value value) const;
/// Creates a memref allocation.
Optional<Value> createAlloc(OpBuilder &b, Location loc, MemRefType type,
ArrayRef<Value> dynShape) const;
FailureOr<Value> createAlloc(OpBuilder &b, Location loc, MemRefType type,
ArrayRef<Value> dynShape) const;
/// Creates a memref allocation for the given shaped value. This function may
/// perform additional optimizations such as buffer allocation hoisting. If
/// `createDealloc`, a deallocation op is inserted at the point where the
/// allocation goes out of scope.
Value createAlloc(OpBuilder &b, Location loc, Value shapedValue,
bool deallocMemref) const;
FailureOr<Value> createAlloc(OpBuilder &b, Location loc, Value shapedValue,
bool deallocMemref) const;
/// Creates a memref deallocation. The given memref buffer must have been
/// allocated using `createAlloc`.

View File

@ -41,9 +41,9 @@ using namespace linalg::comprehensive_bufferize;
/// Default allocation function that is used by the comprehensive bufferization
/// pass. The default currently creates a ranked memref using `memref.alloc`.
static Optional<Value> defaultAllocationFn(OpBuilder &b, Location loc,
MemRefType type,
ArrayRef<Value> dynShape) {
static FailureOr<Value> defaultAllocationFn(OpBuilder &b, Location loc,
MemRefType type,
ArrayRef<Value> dynShape) {
Value allocated = b.create<memref::AllocOp>(
loc, type, dynShape, b.getI64IntegerAttr(kBufferAlignments));
return allocated;
@ -391,8 +391,10 @@ mlir::linalg::comprehensive_bufferize::BufferizationState::getResultBuffer(
// allocation should be inserted (in the absence of allocation hoisting).
setInsertionPointAfter(rewriter, operandBuffer);
// Allocate the result buffer.
Value resultBuffer =
FailureOr<Value> resultBuffer =
createAlloc(rewriter, loc, operandBuffer, options.createDeallocs);
if (failed(resultBuffer))
return failure();
bool skipCopy = false;
// Do not copy if the last preceding write of `operand` is an op that does
// not write (skipping ops that merely create aliases). E.g., InitTensorOp.
@ -413,7 +415,7 @@ mlir::linalg::comprehensive_bufferize::BufferizationState::getResultBuffer(
if (!skipCopy) {
// The copy happens right before the op that is bufferized.
rewriter.setInsertionPoint(op);
createMemCpy(rewriter, loc, operandBuffer, resultBuffer);
createMemCpy(rewriter, loc, operandBuffer, *resultBuffer);
}
return resultBuffer;
}
@ -537,7 +539,8 @@ static MemRefType getAllocationTypeAndShape(OpBuilder &b, Location loc,
/// Create an AllocOp/DeallocOp pair, where the AllocOp is after
/// `shapedValue.getDefiningOp` (or at the top of the block in case of a
/// bbArg) and the DeallocOp is at the end of the block.
Value mlir::linalg::comprehensive_bufferize::BufferizationState::createAlloc(
FailureOr<Value>
mlir::linalg::comprehensive_bufferize::BufferizationState::createAlloc(
OpBuilder &b, Location loc, Value shapedValue, bool deallocMemref) const {
// Take a guard before anything else.
OpBuilder::InsertionGuard g(b);
@ -549,10 +552,9 @@ Value mlir::linalg::comprehensive_bufferize::BufferizationState::createAlloc(
// Note: getAllocationTypeAndShape also sets the insertion point.
MemRefType allocMemRefType =
getAllocationTypeAndShape(b, loc, shapedValue, dynShape);
Optional<Value> allocated = createAlloc(b, loc, allocMemRefType, dynShape);
// TODO: For now just assert the value is returned. Eventually need to
// error-propagate.
assert(allocated && "allocation failed");
FailureOr<Value> allocated = createAlloc(b, loc, allocMemRefType, dynShape);
if (failed(allocated))
return failure();
Value casted = allocated.getValue();
if (memRefType && memRefType != allocMemRefType) {
casted = b.create<memref::CastOp>(loc, memRefType, allocated.getValue());
@ -568,7 +570,7 @@ Value mlir::linalg::comprehensive_bufferize::BufferizationState::createAlloc(
}
/// Create a memref allocation.
Optional<Value>
FailureOr<Value>
mlir::linalg::comprehensive_bufferize::BufferizationState::createAlloc(
OpBuilder &b, Location loc, MemRefType type,
ArrayRef<Value> dynShape) const {

View File

@ -55,6 +55,8 @@ static LogicalResult bufferizeLinalgOp(RewriterBase &rewriter, LinalgOp op,
OpResult opResult = op.getTiedOpResult(opOperand);
assert(opResult && "could not find correspond OpResult");
FailureOr<Value> resultBuffer = state.getResultBuffer(rewriter, opResult);
if (failed(resultBuffer))
return failure();
newOutputBuffers.push_back(*resultBuffer);
}
@ -210,10 +212,12 @@ struct InitTensorOpInterface
if (initTensorOp->getUses().empty())
return success();
Value alloc = state.createAlloc(rewriter, initTensorOp->getLoc(),
initTensorOp.result(),
state.getOptions().createDeallocs);
replaceOpWithBufferizedValues(rewriter, op, alloc);
FailureOr<Value> alloc = state.createAlloc(
rewriter, initTensorOp->getLoc(), initTensorOp.result(),
state.getOptions().createDeallocs);
if (failed(alloc))
return failure();
replaceOpWithBufferizedValues(rewriter, op, *alloc);
return success();
}
};
@ -287,6 +291,8 @@ struct TiledLoopOpInterface
if (value.getType().isa<TensorType>()) {
FailureOr<Value> buffer = state.getResultBuffer(
rewriter, tiledLoopOp->getResult(nextResultNum++));
if (failed(buffer))
return failure();
newOutputs.push_back(*buffer);
newResults.push_back(*buffer);
} else {

View File

@ -295,10 +295,19 @@ struct ForOpInterface
};
// Construct a new scf.for op with memref instead of tensor values.
bool resultBufferFailure = false;
SmallVector<Value> initArgs =
convert(forOp.getInitArgs(), [&](Value val, int64_t index) {
return *state.getResultBuffer(rewriter, forOp->getOpResult(index));
FailureOr<Value> resultBuffer =
state.getResultBuffer(rewriter, forOp->getOpResult(index));
if (failed(resultBuffer)) {
resultBufferFailure = true;
return Value();
}
return *resultBuffer;
});
if (resultBufferFailure)
return failure();
auto newForOp = rewriter.create<scf::ForOp>(
forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
forOp.getStep(), initArgs);

View File

@ -54,6 +54,8 @@ struct CastOpInterface
// The result buffer still has the old (pre-cast) type.
FailureOr<Value> resultBuffer =
state.getResultBuffer(rewriter, castOp->getResult(0));
if (failed(resultBuffer))
return failure();
auto sourceMemRefType = resultBuffer->getType().cast<BaseMemRefType>();
Attribute memorySpace = sourceMemRefType.getMemorySpace();
TensorType resultTensorType =
@ -149,9 +151,14 @@ struct ExtractSliceOpInterface
// If not inplaceable, alloc.
bool inplace = state.isInPlace(extractSliceOp->getResult(0));
Value alloc;
if (!inplace)
alloc = state.createAlloc(rewriter, loc, extractSliceOp.result(),
state.getOptions().createDeallocs);
if (!inplace) {
FailureOr<Value> allocOrFailure =
state.createAlloc(rewriter, loc, extractSliceOp.result(),
state.getOptions().createDeallocs);
if (failed(allocOrFailure))
return failure();
alloc = *allocOrFailure;
}
// Bufferize to subview.
auto subviewMemRefType =
@ -238,6 +245,8 @@ struct InsertOpInterface
auto insertOp = cast<tensor::InsertOp>(op);
FailureOr<Value> destMemref =
state.getResultBuffer(rewriter, insertOp->getOpResult(0));
if (failed(destMemref))
return failure();
rewriter.create<memref::StoreOp>(insertOp.getLoc(), insertOp.scalar(),
*destMemref, insertOp.indices());
replaceOpWithBufferizedValues(rewriter, op, *destMemref);
@ -404,6 +413,8 @@ struct InsertSliceOpInterface
// When bufferizing out-of-place, `getResultBuffer` allocates.
FailureOr<Value> dstMemref =
state.getResultBuffer(rewriter, insertSliceOp->getResult(0));
if (failed(dstMemref))
return failure();
// Take a subview of the dst.
auto dstMemrefType = dstMemref->getType().cast<MemRefType>();

View File

@ -100,6 +100,8 @@ struct TransferWriteOpInterface
// this point.
FailureOr<Value> resultBuffer =
state.getResultBuffer(rewriter, op->getResult(0));
if (failed(resultBuffer))
return failure();
rewriter.create<vector::TransferWriteOp>(
writeOp.getLoc(), writeOp.vector(), *resultBuffer, writeOp.indices(),
writeOp.permutation_mapAttr(), writeOp.in_boundsAttr());

View File

@ -64,9 +64,9 @@ static void applyEnablingTransformations(ModuleOp moduleOp) {
(void)applyPatternsAndFoldGreedily(moduleOp, std::move(patterns));
}
static Optional<Value> allocationFnUsingAlloca(OpBuilder &b, Location loc,
MemRefType type,
ArrayRef<Value> dynShape) {
static FailureOr<Value> allocationFnUsingAlloca(OpBuilder &b, Location loc,
MemRefType type,
ArrayRef<Value> dynShape) {
Value allocated = b.create<memref::AllocaOp>(
loc, type, dynShape, b.getI64IntegerAttr(kBufferAlignments));
return allocated;