[mlir][linalg][bufferize][NFC] Change allocationFn return type to FailureOr<Value>

In addition, all functions that call `allocationFn` now return FailureOr<Value>. This resolves a few TODOs in the code base.

Differential Revision: https://reviews.llvm.org/D116452
This commit is contained in:
Matthias Springer 2022-01-07 06:32:35 +09:00
parent 4a661602ef
commit 698896cd6c
7 changed files with 57 additions and 27 deletions

View File

@ -41,7 +41,7 @@ struct PostAnalysisStep;
// TODO: Could be replaced with a "bufferization strategy" object with virtual // TODO: Could be replaced with a "bufferization strategy" object with virtual
// functions in the future. // functions in the future.
struct AllocationCallbacks { struct AllocationCallbacks {
using AllocationFn = std::function<Optional<Value>( using AllocationFn = std::function<FailureOr<Value>(
OpBuilder &, Location, MemRefType, ArrayRef<Value>)>; OpBuilder &, Location, MemRefType, ArrayRef<Value>)>;
using DeallocationFn = std::function<void(OpBuilder &, Location, Value)>; using DeallocationFn = std::function<void(OpBuilder &, Location, Value)>;
using MemCpyFn = std::function<void(OpBuilder &, Location, Value, Value)>; using MemCpyFn = std::function<void(OpBuilder &, Location, Value, Value)>;
@ -360,15 +360,15 @@ public:
Value findLastPrecedingWrite(Value value) const; Value findLastPrecedingWrite(Value value) const;
/// Creates a memref allocation. /// Creates a memref allocation.
Optional<Value> createAlloc(OpBuilder &b, Location loc, MemRefType type, FailureOr<Value> createAlloc(OpBuilder &b, Location loc, MemRefType type,
ArrayRef<Value> dynShape) const; ArrayRef<Value> dynShape) const;
/// Creates a memref allocation for the given shaped value. This function may /// Creates a memref allocation for the given shaped value. This function may
/// perform additional optimizations such as buffer allocation hoisting. If /// perform additional optimizations such as buffer allocation hoisting. If
/// `createDealloc`, a deallocation op is inserted at the point where the /// `createDealloc`, a deallocation op is inserted at the point where the
/// allocation goes out of scope. /// allocation goes out of scope.
Value createAlloc(OpBuilder &b, Location loc, Value shapedValue, FailureOr<Value> createAlloc(OpBuilder &b, Location loc, Value shapedValue,
bool deallocMemref) const; bool deallocMemref) const;
/// Creates a memref deallocation. The given memref buffer must have been /// Creates a memref deallocation. The given memref buffer must have been
/// allocated using `createAlloc`. /// allocated using `createAlloc`.

View File

@ -41,9 +41,9 @@ using namespace linalg::comprehensive_bufferize;
/// Default allocation function that is used by the comprehensive bufferization /// Default allocation function that is used by the comprehensive bufferization
/// pass. The default currently creates a ranked memref using `memref.alloc`. /// pass. The default currently creates a ranked memref using `memref.alloc`.
static Optional<Value> defaultAllocationFn(OpBuilder &b, Location loc, static FailureOr<Value> defaultAllocationFn(OpBuilder &b, Location loc,
MemRefType type, MemRefType type,
ArrayRef<Value> dynShape) { ArrayRef<Value> dynShape) {
Value allocated = b.create<memref::AllocOp>( Value allocated = b.create<memref::AllocOp>(
loc, type, dynShape, b.getI64IntegerAttr(kBufferAlignments)); loc, type, dynShape, b.getI64IntegerAttr(kBufferAlignments));
return allocated; return allocated;
@ -391,8 +391,10 @@ mlir::linalg::comprehensive_bufferize::BufferizationState::getResultBuffer(
// allocation should be inserted (in the absence of allocation hoisting). // allocation should be inserted (in the absence of allocation hoisting).
setInsertionPointAfter(rewriter, operandBuffer); setInsertionPointAfter(rewriter, operandBuffer);
// Allocate the result buffer. // Allocate the result buffer.
Value resultBuffer = FailureOr<Value> resultBuffer =
createAlloc(rewriter, loc, operandBuffer, options.createDeallocs); createAlloc(rewriter, loc, operandBuffer, options.createDeallocs);
if (failed(resultBuffer))
return failure();
bool skipCopy = false; bool skipCopy = false;
// Do not copy if the last preceding write of `operand` is an op that does // Do not copy if the last preceding write of `operand` is an op that does
// not write (skipping ops that merely create aliases). E.g., InitTensorOp. // not write (skipping ops that merely create aliases). E.g., InitTensorOp.
@ -413,7 +415,7 @@ mlir::linalg::comprehensive_bufferize::BufferizationState::getResultBuffer(
if (!skipCopy) { if (!skipCopy) {
// The copy happens right before the op that is bufferized. // The copy happens right before the op that is bufferized.
rewriter.setInsertionPoint(op); rewriter.setInsertionPoint(op);
createMemCpy(rewriter, loc, operandBuffer, resultBuffer); createMemCpy(rewriter, loc, operandBuffer, *resultBuffer);
} }
return resultBuffer; return resultBuffer;
} }
@ -537,7 +539,8 @@ static MemRefType getAllocationTypeAndShape(OpBuilder &b, Location loc,
/// Create an AllocOp/DeallocOp pair, where the AllocOp is after /// Create an AllocOp/DeallocOp pair, where the AllocOp is after
/// `shapedValue.getDefiningOp` (or at the top of the block in case of a /// `shapedValue.getDefiningOp` (or at the top of the block in case of a
/// bbArg) and the DeallocOp is at the end of the block. /// bbArg) and the DeallocOp is at the end of the block.
Value mlir::linalg::comprehensive_bufferize::BufferizationState::createAlloc( FailureOr<Value>
mlir::linalg::comprehensive_bufferize::BufferizationState::createAlloc(
OpBuilder &b, Location loc, Value shapedValue, bool deallocMemref) const { OpBuilder &b, Location loc, Value shapedValue, bool deallocMemref) const {
// Take a guard before anything else. // Take a guard before anything else.
OpBuilder::InsertionGuard g(b); OpBuilder::InsertionGuard g(b);
@ -549,10 +552,9 @@ Value mlir::linalg::comprehensive_bufferize::BufferizationState::createAlloc(
// Note: getAllocationTypeAndShape also sets the insertion point. // Note: getAllocationTypeAndShape also sets the insertion point.
MemRefType allocMemRefType = MemRefType allocMemRefType =
getAllocationTypeAndShape(b, loc, shapedValue, dynShape); getAllocationTypeAndShape(b, loc, shapedValue, dynShape);
Optional<Value> allocated = createAlloc(b, loc, allocMemRefType, dynShape); FailureOr<Value> allocated = createAlloc(b, loc, allocMemRefType, dynShape);
// TODO: For now just assert the value is returned. Eventually need to if (failed(allocated))
// error-propagate. return failure();
assert(allocated && "allocation failed");
Value casted = allocated.getValue(); Value casted = allocated.getValue();
if (memRefType && memRefType != allocMemRefType) { if (memRefType && memRefType != allocMemRefType) {
casted = b.create<memref::CastOp>(loc, memRefType, allocated.getValue()); casted = b.create<memref::CastOp>(loc, memRefType, allocated.getValue());
@ -568,7 +570,7 @@ Value mlir::linalg::comprehensive_bufferize::BufferizationState::createAlloc(
} }
/// Create a memref allocation. /// Create a memref allocation.
Optional<Value> FailureOr<Value>
mlir::linalg::comprehensive_bufferize::BufferizationState::createAlloc( mlir::linalg::comprehensive_bufferize::BufferizationState::createAlloc(
OpBuilder &b, Location loc, MemRefType type, OpBuilder &b, Location loc, MemRefType type,
ArrayRef<Value> dynShape) const { ArrayRef<Value> dynShape) const {

View File

@ -55,6 +55,8 @@ static LogicalResult bufferizeLinalgOp(RewriterBase &rewriter, LinalgOp op,
OpResult opResult = op.getTiedOpResult(opOperand); OpResult opResult = op.getTiedOpResult(opOperand);
assert(opResult && "could not find correspond OpResult"); assert(opResult && "could not find correspond OpResult");
FailureOr<Value> resultBuffer = state.getResultBuffer(rewriter, opResult); FailureOr<Value> resultBuffer = state.getResultBuffer(rewriter, opResult);
if (failed(resultBuffer))
return failure();
newOutputBuffers.push_back(*resultBuffer); newOutputBuffers.push_back(*resultBuffer);
} }
@ -210,10 +212,12 @@ struct InitTensorOpInterface
if (initTensorOp->getUses().empty()) if (initTensorOp->getUses().empty())
return success(); return success();
Value alloc = state.createAlloc(rewriter, initTensorOp->getLoc(), FailureOr<Value> alloc = state.createAlloc(
initTensorOp.result(), rewriter, initTensorOp->getLoc(), initTensorOp.result(),
state.getOptions().createDeallocs); state.getOptions().createDeallocs);
replaceOpWithBufferizedValues(rewriter, op, alloc); if (failed(alloc))
return failure();
replaceOpWithBufferizedValues(rewriter, op, *alloc);
return success(); return success();
} }
}; };
@ -287,6 +291,8 @@ struct TiledLoopOpInterface
if (value.getType().isa<TensorType>()) { if (value.getType().isa<TensorType>()) {
FailureOr<Value> buffer = state.getResultBuffer( FailureOr<Value> buffer = state.getResultBuffer(
rewriter, tiledLoopOp->getResult(nextResultNum++)); rewriter, tiledLoopOp->getResult(nextResultNum++));
if (failed(buffer))
return failure();
newOutputs.push_back(*buffer); newOutputs.push_back(*buffer);
newResults.push_back(*buffer); newResults.push_back(*buffer);
} else { } else {

View File

@ -295,10 +295,19 @@ struct ForOpInterface
}; };
// Construct a new scf.for op with memref instead of tensor values. // Construct a new scf.for op with memref instead of tensor values.
bool resultBufferFailure = false;
SmallVector<Value> initArgs = SmallVector<Value> initArgs =
convert(forOp.getInitArgs(), [&](Value val, int64_t index) { convert(forOp.getInitArgs(), [&](Value val, int64_t index) {
return *state.getResultBuffer(rewriter, forOp->getOpResult(index)); FailureOr<Value> resultBuffer =
state.getResultBuffer(rewriter, forOp->getOpResult(index));
if (failed(resultBuffer)) {
resultBufferFailure = true;
return Value();
}
return *resultBuffer;
}); });
if (resultBufferFailure)
return failure();
auto newForOp = rewriter.create<scf::ForOp>( auto newForOp = rewriter.create<scf::ForOp>(
forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(), forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
forOp.getStep(), initArgs); forOp.getStep(), initArgs);

View File

@ -54,6 +54,8 @@ struct CastOpInterface
// The result buffer still has the old (pre-cast) type. // The result buffer still has the old (pre-cast) type.
FailureOr<Value> resultBuffer = FailureOr<Value> resultBuffer =
state.getResultBuffer(rewriter, castOp->getResult(0)); state.getResultBuffer(rewriter, castOp->getResult(0));
if (failed(resultBuffer))
return failure();
auto sourceMemRefType = resultBuffer->getType().cast<BaseMemRefType>(); auto sourceMemRefType = resultBuffer->getType().cast<BaseMemRefType>();
Attribute memorySpace = sourceMemRefType.getMemorySpace(); Attribute memorySpace = sourceMemRefType.getMemorySpace();
TensorType resultTensorType = TensorType resultTensorType =
@ -149,9 +151,14 @@ struct ExtractSliceOpInterface
// If not inplaceable, alloc. // If not inplaceable, alloc.
bool inplace = state.isInPlace(extractSliceOp->getResult(0)); bool inplace = state.isInPlace(extractSliceOp->getResult(0));
Value alloc; Value alloc;
if (!inplace) if (!inplace) {
alloc = state.createAlloc(rewriter, loc, extractSliceOp.result(), FailureOr<Value> allocOrFailure =
state.getOptions().createDeallocs); state.createAlloc(rewriter, loc, extractSliceOp.result(),
state.getOptions().createDeallocs);
if (failed(allocOrFailure))
return failure();
alloc = *allocOrFailure;
}
// Bufferize to subview. // Bufferize to subview.
auto subviewMemRefType = auto subviewMemRefType =
@ -238,6 +245,8 @@ struct InsertOpInterface
auto insertOp = cast<tensor::InsertOp>(op); auto insertOp = cast<tensor::InsertOp>(op);
FailureOr<Value> destMemref = FailureOr<Value> destMemref =
state.getResultBuffer(rewriter, insertOp->getOpResult(0)); state.getResultBuffer(rewriter, insertOp->getOpResult(0));
if (failed(destMemref))
return failure();
rewriter.create<memref::StoreOp>(insertOp.getLoc(), insertOp.scalar(), rewriter.create<memref::StoreOp>(insertOp.getLoc(), insertOp.scalar(),
*destMemref, insertOp.indices()); *destMemref, insertOp.indices());
replaceOpWithBufferizedValues(rewriter, op, *destMemref); replaceOpWithBufferizedValues(rewriter, op, *destMemref);
@ -404,6 +413,8 @@ struct InsertSliceOpInterface
// When bufferizing out-of-place, `getResultBuffer` allocates. // When bufferizing out-of-place, `getResultBuffer` allocates.
FailureOr<Value> dstMemref = FailureOr<Value> dstMemref =
state.getResultBuffer(rewriter, insertSliceOp->getResult(0)); state.getResultBuffer(rewriter, insertSliceOp->getResult(0));
if (failed(dstMemref))
return failure();
// Take a subview of the dst. // Take a subview of the dst.
auto dstMemrefType = dstMemref->getType().cast<MemRefType>(); auto dstMemrefType = dstMemref->getType().cast<MemRefType>();

View File

@ -100,6 +100,8 @@ struct TransferWriteOpInterface
// this point. // this point.
FailureOr<Value> resultBuffer = FailureOr<Value> resultBuffer =
state.getResultBuffer(rewriter, op->getResult(0)); state.getResultBuffer(rewriter, op->getResult(0));
if (failed(resultBuffer))
return failure();
rewriter.create<vector::TransferWriteOp>( rewriter.create<vector::TransferWriteOp>(
writeOp.getLoc(), writeOp.vector(), *resultBuffer, writeOp.indices(), writeOp.getLoc(), writeOp.vector(), *resultBuffer, writeOp.indices(),
writeOp.permutation_mapAttr(), writeOp.in_boundsAttr()); writeOp.permutation_mapAttr(), writeOp.in_boundsAttr());

View File

@ -64,9 +64,9 @@ static void applyEnablingTransformations(ModuleOp moduleOp) {
(void)applyPatternsAndFoldGreedily(moduleOp, std::move(patterns)); (void)applyPatternsAndFoldGreedily(moduleOp, std::move(patterns));
} }
static Optional<Value> allocationFnUsingAlloca(OpBuilder &b, Location loc, static FailureOr<Value> allocationFnUsingAlloca(OpBuilder &b, Location loc,
MemRefType type, MemRefType type,
ArrayRef<Value> dynShape) { ArrayRef<Value> dynShape) {
Value allocated = b.create<memref::AllocaOp>( Value allocated = b.create<memref::AllocaOp>(
loc, type, dynShape, b.getI64IntegerAttr(kBufferAlignments)); loc, type, dynShape, b.getI64IntegerAttr(kBufferAlignments));
return allocated; return allocated;