From 698896cd6c8cc5e865e1715e7c9d82295f82745b Mon Sep 17 00:00:00 2001 From: Matthias Springer Date: Fri, 7 Jan 2022 06:32:35 +0900 Subject: [PATCH] [mlir][linalg][bufferize][NFC] Change allocationFn return type to FailureOr In addition, all functions that call `allocationFn` now return FailureOr. This resolves a few TODOs in the code base. Differential Revision: https://reviews.llvm.org/D116452 --- .../BufferizableOpInterface.h | 10 ++++---- .../BufferizableOpInterface.cpp | 24 ++++++++++--------- .../LinalgInterfaceImpl.cpp | 14 +++++++---- .../SCFInterfaceImpl.cpp | 11 ++++++++- .../TensorInterfaceImpl.cpp | 17 ++++++++++--- .../VectorInterfaceImpl.cpp | 2 ++ .../Transforms/ComprehensiveBufferizePass.cpp | 6 ++--- 7 files changed, 57 insertions(+), 27 deletions(-) diff --git a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h index 921353a23ea7..c18f7f9fc5e9 100644 --- a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h +++ b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h @@ -41,7 +41,7 @@ struct PostAnalysisStep; // TODO: Could be replaced with a "bufferization strategy" object with virtual // functions in the future. struct AllocationCallbacks { - using AllocationFn = std::function( + using AllocationFn = std::function( OpBuilder &, Location, MemRefType, ArrayRef)>; using DeallocationFn = std::function; using MemCpyFn = std::function; @@ -360,15 +360,15 @@ public: Value findLastPrecedingWrite(Value value) const; /// Creates a memref allocation. - Optional createAlloc(OpBuilder &b, Location loc, MemRefType type, - ArrayRef dynShape) const; + FailureOr createAlloc(OpBuilder &b, Location loc, MemRefType type, + ArrayRef dynShape) const; /// Creates a memref allocation for the given shaped value. This function may /// perform additional optimizations such as buffer allocation hoisting. If /// `createDealloc`, a deallocation op is inserted at the point where the /// allocation goes out of scope. - Value createAlloc(OpBuilder &b, Location loc, Value shapedValue, - bool deallocMemref) const; + FailureOr createAlloc(OpBuilder &b, Location loc, Value shapedValue, + bool deallocMemref) const; /// Creates a memref deallocation. The given memref buffer must have been /// allocated using `createAlloc`. diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp index 118e25a23148..b2a58069e85a 100644 --- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp +++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp @@ -41,9 +41,9 @@ using namespace linalg::comprehensive_bufferize; /// Default allocation function that is used by the comprehensive bufferization /// pass. The default currently creates a ranked memref using `memref.alloc`. -static Optional defaultAllocationFn(OpBuilder &b, Location loc, - MemRefType type, - ArrayRef dynShape) { +static FailureOr defaultAllocationFn(OpBuilder &b, Location loc, + MemRefType type, + ArrayRef dynShape) { Value allocated = b.create( loc, type, dynShape, b.getI64IntegerAttr(kBufferAlignments)); return allocated; @@ -391,8 +391,10 @@ mlir::linalg::comprehensive_bufferize::BufferizationState::getResultBuffer( // allocation should be inserted (in the absence of allocation hoisting). setInsertionPointAfter(rewriter, operandBuffer); // Allocate the result buffer. - Value resultBuffer = + FailureOr resultBuffer = createAlloc(rewriter, loc, operandBuffer, options.createDeallocs); + if (failed(resultBuffer)) + return failure(); bool skipCopy = false; // Do not copy if the last preceding write of `operand` is an op that does // not write (skipping ops that merely create aliases). E.g., InitTensorOp. @@ -413,7 +415,7 @@ mlir::linalg::comprehensive_bufferize::BufferizationState::getResultBuffer( if (!skipCopy) { // The copy happens right before the op that is bufferized. rewriter.setInsertionPoint(op); - createMemCpy(rewriter, loc, operandBuffer, resultBuffer); + createMemCpy(rewriter, loc, operandBuffer, *resultBuffer); } return resultBuffer; } @@ -537,7 +539,8 @@ static MemRefType getAllocationTypeAndShape(OpBuilder &b, Location loc, /// Create an AllocOp/DeallocOp pair, where the AllocOp is after /// `shapedValue.getDefiningOp` (or at the top of the block in case of a /// bbArg) and the DeallocOp is at the end of the block. -Value mlir::linalg::comprehensive_bufferize::BufferizationState::createAlloc( +FailureOr +mlir::linalg::comprehensive_bufferize::BufferizationState::createAlloc( OpBuilder &b, Location loc, Value shapedValue, bool deallocMemref) const { // Take a guard before anything else. OpBuilder::InsertionGuard g(b); @@ -549,10 +552,9 @@ Value mlir::linalg::comprehensive_bufferize::BufferizationState::createAlloc( // Note: getAllocationTypeAndShape also sets the insertion point. MemRefType allocMemRefType = getAllocationTypeAndShape(b, loc, shapedValue, dynShape); - Optional allocated = createAlloc(b, loc, allocMemRefType, dynShape); - // TODO: For now just assert the value is returned. Eventually need to - // error-propagate. - assert(allocated && "allocation failed"); + FailureOr allocated = createAlloc(b, loc, allocMemRefType, dynShape); + if (failed(allocated)) + return failure(); Value casted = allocated.getValue(); if (memRefType && memRefType != allocMemRefType) { casted = b.create(loc, memRefType, allocated.getValue()); @@ -568,7 +570,7 @@ Value mlir::linalg::comprehensive_bufferize::BufferizationState::createAlloc( } /// Create a memref allocation. -Optional +FailureOr mlir::linalg::comprehensive_bufferize::BufferizationState::createAlloc( OpBuilder &b, Location loc, MemRefType type, ArrayRef dynShape) const { diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp index c4f42afb9828..a3cb3c36065b 100644 --- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp @@ -55,6 +55,8 @@ static LogicalResult bufferizeLinalgOp(RewriterBase &rewriter, LinalgOp op, OpResult opResult = op.getTiedOpResult(opOperand); assert(opResult && "could not find correspond OpResult"); FailureOr resultBuffer = state.getResultBuffer(rewriter, opResult); + if (failed(resultBuffer)) + return failure(); newOutputBuffers.push_back(*resultBuffer); } @@ -210,10 +212,12 @@ struct InitTensorOpInterface if (initTensorOp->getUses().empty()) return success(); - Value alloc = state.createAlloc(rewriter, initTensorOp->getLoc(), - initTensorOp.result(), - state.getOptions().createDeallocs); - replaceOpWithBufferizedValues(rewriter, op, alloc); + FailureOr alloc = state.createAlloc( + rewriter, initTensorOp->getLoc(), initTensorOp.result(), + state.getOptions().createDeallocs); + if (failed(alloc)) + return failure(); + replaceOpWithBufferizedValues(rewriter, op, *alloc); return success(); } }; @@ -287,6 +291,8 @@ struct TiledLoopOpInterface if (value.getType().isa()) { FailureOr buffer = state.getResultBuffer( rewriter, tiledLoopOp->getResult(nextResultNum++)); + if (failed(buffer)) + return failure(); newOutputs.push_back(*buffer); newResults.push_back(*buffer); } else { diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.cpp index 5983d421aaed..1d62c7880a31 100644 --- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.cpp @@ -295,10 +295,19 @@ struct ForOpInterface }; // Construct a new scf.for op with memref instead of tensor values. + bool resultBufferFailure = false; SmallVector initArgs = convert(forOp.getInitArgs(), [&](Value val, int64_t index) { - return *state.getResultBuffer(rewriter, forOp->getOpResult(index)); + FailureOr resultBuffer = + state.getResultBuffer(rewriter, forOp->getOpResult(index)); + if (failed(resultBuffer)) { + resultBufferFailure = true; + return Value(); + } + return *resultBuffer; }); + if (resultBufferFailure) + return failure(); auto newForOp = rewriter.create( forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(), forOp.getStep(), initArgs); diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp index 6b8b8983972a..b6ee0fc63471 100644 --- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp @@ -54,6 +54,8 @@ struct CastOpInterface // The result buffer still has the old (pre-cast) type. FailureOr resultBuffer = state.getResultBuffer(rewriter, castOp->getResult(0)); + if (failed(resultBuffer)) + return failure(); auto sourceMemRefType = resultBuffer->getType().cast(); Attribute memorySpace = sourceMemRefType.getMemorySpace(); TensorType resultTensorType = @@ -149,9 +151,14 @@ struct ExtractSliceOpInterface // If not inplaceable, alloc. bool inplace = state.isInPlace(extractSliceOp->getResult(0)); Value alloc; - if (!inplace) - alloc = state.createAlloc(rewriter, loc, extractSliceOp.result(), - state.getOptions().createDeallocs); + if (!inplace) { + FailureOr allocOrFailure = + state.createAlloc(rewriter, loc, extractSliceOp.result(), + state.getOptions().createDeallocs); + if (failed(allocOrFailure)) + return failure(); + alloc = *allocOrFailure; + } // Bufferize to subview. auto subviewMemRefType = @@ -238,6 +245,8 @@ struct InsertOpInterface auto insertOp = cast(op); FailureOr destMemref = state.getResultBuffer(rewriter, insertOp->getOpResult(0)); + if (failed(destMemref)) + return failure(); rewriter.create(insertOp.getLoc(), insertOp.scalar(), *destMemref, insertOp.indices()); replaceOpWithBufferizedValues(rewriter, op, *destMemref); @@ -404,6 +413,8 @@ struct InsertSliceOpInterface // When bufferizing out-of-place, `getResultBuffer` allocates. FailureOr dstMemref = state.getResultBuffer(rewriter, insertSliceOp->getResult(0)); + if (failed(dstMemref)) + return failure(); // Take a subview of the dst. auto dstMemrefType = dstMemref->getType().cast(); diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/VectorInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/VectorInterfaceImpl.cpp index 3c8d6a9c96e5..58013323cb70 100644 --- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/VectorInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/VectorInterfaceImpl.cpp @@ -100,6 +100,8 @@ struct TransferWriteOpInterface // this point. FailureOr resultBuffer = state.getResultBuffer(rewriter, op->getResult(0)); + if (failed(resultBuffer)) + return failure(); rewriter.create( writeOp.getLoc(), writeOp.vector(), *resultBuffer, writeOp.indices(), writeOp.permutation_mapAttr(), writeOp.in_boundsAttr()); diff --git a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp index 13e18001d82e..21d7c4e62a45 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp @@ -64,9 +64,9 @@ static void applyEnablingTransformations(ModuleOp moduleOp) { (void)applyPatternsAndFoldGreedily(moduleOp, std::move(patterns)); } -static Optional allocationFnUsingAlloca(OpBuilder &b, Location loc, - MemRefType type, - ArrayRef dynShape) { +static FailureOr allocationFnUsingAlloca(OpBuilder &b, Location loc, + MemRefType type, + ArrayRef dynShape) { Value allocated = b.create( loc, type, dynShape, b.getI64IntegerAttr(kBufferAlignments)); return allocated;