forked from OSchip/llvm-project
[mlir][linalg][bufferize][NFC] Change allocationFn return type to FailureOr<Value>
In addition, all functions that call `allocationFn` now return FailureOr<Value>. This resolves a few TODOs in the code base. Differential Revision: https://reviews.llvm.org/D116452
This commit is contained in:
parent
4a661602ef
commit
698896cd6c
|
@ -41,7 +41,7 @@ struct PostAnalysisStep;
|
|||
// TODO: Could be replaced with a "bufferization strategy" object with virtual
|
||||
// functions in the future.
|
||||
struct AllocationCallbacks {
|
||||
using AllocationFn = std::function<Optional<Value>(
|
||||
using AllocationFn = std::function<FailureOr<Value>(
|
||||
OpBuilder &, Location, MemRefType, ArrayRef<Value>)>;
|
||||
using DeallocationFn = std::function<void(OpBuilder &, Location, Value)>;
|
||||
using MemCpyFn = std::function<void(OpBuilder &, Location, Value, Value)>;
|
||||
|
@ -360,15 +360,15 @@ public:
|
|||
Value findLastPrecedingWrite(Value value) const;
|
||||
|
||||
/// Creates a memref allocation.
|
||||
Optional<Value> createAlloc(OpBuilder &b, Location loc, MemRefType type,
|
||||
ArrayRef<Value> dynShape) const;
|
||||
FailureOr<Value> createAlloc(OpBuilder &b, Location loc, MemRefType type,
|
||||
ArrayRef<Value> dynShape) const;
|
||||
|
||||
/// Creates a memref allocation for the given shaped value. This function may
|
||||
/// perform additional optimizations such as buffer allocation hoisting. If
|
||||
/// `deallocMemref` is set, a deallocation op is inserted at the point where the
|
||||
/// allocation goes out of scope.
|
||||
Value createAlloc(OpBuilder &b, Location loc, Value shapedValue,
|
||||
bool deallocMemref) const;
|
||||
FailureOr<Value> createAlloc(OpBuilder &b, Location loc, Value shapedValue,
|
||||
bool deallocMemref) const;
|
||||
|
||||
/// Creates a memref deallocation. The given memref buffer must have been
|
||||
/// allocated using `createAlloc`.
|
||||
|
|
|
@ -41,9 +41,9 @@ using namespace linalg::comprehensive_bufferize;
|
|||
|
||||
/// Default allocation function that is used by the comprehensive bufferization
|
||||
/// pass. The default currently creates a ranked memref using `memref.alloc`.
|
||||
static Optional<Value> defaultAllocationFn(OpBuilder &b, Location loc,
|
||||
MemRefType type,
|
||||
ArrayRef<Value> dynShape) {
|
||||
static FailureOr<Value> defaultAllocationFn(OpBuilder &b, Location loc,
|
||||
MemRefType type,
|
||||
ArrayRef<Value> dynShape) {
|
||||
Value allocated = b.create<memref::AllocOp>(
|
||||
loc, type, dynShape, b.getI64IntegerAttr(kBufferAlignments));
|
||||
return allocated;
|
||||
|
@ -391,8 +391,10 @@ mlir::linalg::comprehensive_bufferize::BufferizationState::getResultBuffer(
|
|||
// allocation should be inserted (in the absence of allocation hoisting).
|
||||
setInsertionPointAfter(rewriter, operandBuffer);
|
||||
// Allocate the result buffer.
|
||||
Value resultBuffer =
|
||||
FailureOr<Value> resultBuffer =
|
||||
createAlloc(rewriter, loc, operandBuffer, options.createDeallocs);
|
||||
if (failed(resultBuffer))
|
||||
return failure();
|
||||
bool skipCopy = false;
|
||||
// Do not copy if the last preceding write of `operand` is an op that does
|
||||
// not write (skipping ops that merely create aliases). E.g., InitTensorOp.
|
||||
|
@ -413,7 +415,7 @@ mlir::linalg::comprehensive_bufferize::BufferizationState::getResultBuffer(
|
|||
if (!skipCopy) {
|
||||
// The copy happens right before the op that is bufferized.
|
||||
rewriter.setInsertionPoint(op);
|
||||
createMemCpy(rewriter, loc, operandBuffer, resultBuffer);
|
||||
createMemCpy(rewriter, loc, operandBuffer, *resultBuffer);
|
||||
}
|
||||
return resultBuffer;
|
||||
}
|
||||
|
@ -537,7 +539,8 @@ static MemRefType getAllocationTypeAndShape(OpBuilder &b, Location loc,
|
|||
/// Create an AllocOp/DeallocOp pair, where the AllocOp is after
|
||||
/// `shapedValue.getDefiningOp` (or at the top of the block in case of a
|
||||
/// bbArg) and the DeallocOp is at the end of the block.
|
||||
Value mlir::linalg::comprehensive_bufferize::BufferizationState::createAlloc(
|
||||
FailureOr<Value>
|
||||
mlir::linalg::comprehensive_bufferize::BufferizationState::createAlloc(
|
||||
OpBuilder &b, Location loc, Value shapedValue, bool deallocMemref) const {
|
||||
// Take a guard before anything else.
|
||||
OpBuilder::InsertionGuard g(b);
|
||||
|
@ -549,10 +552,9 @@ Value mlir::linalg::comprehensive_bufferize::BufferizationState::createAlloc(
|
|||
// Note: getAllocationTypeAndShape also sets the insertion point.
|
||||
MemRefType allocMemRefType =
|
||||
getAllocationTypeAndShape(b, loc, shapedValue, dynShape);
|
||||
Optional<Value> allocated = createAlloc(b, loc, allocMemRefType, dynShape);
|
||||
// TODO: For now just assert the value is returned. Eventually need to
|
||||
// error-propagate.
|
||||
assert(allocated && "allocation failed");
|
||||
FailureOr<Value> allocated = createAlloc(b, loc, allocMemRefType, dynShape);
|
||||
if (failed(allocated))
|
||||
return failure();
|
||||
Value casted = allocated.getValue();
|
||||
if (memRefType && memRefType != allocMemRefType) {
|
||||
casted = b.create<memref::CastOp>(loc, memRefType, allocated.getValue());
|
||||
|
@ -568,7 +570,7 @@ Value mlir::linalg::comprehensive_bufferize::BufferizationState::createAlloc(
|
|||
}
|
||||
|
||||
/// Create a memref allocation.
|
||||
Optional<Value>
|
||||
FailureOr<Value>
|
||||
mlir::linalg::comprehensive_bufferize::BufferizationState::createAlloc(
|
||||
OpBuilder &b, Location loc, MemRefType type,
|
||||
ArrayRef<Value> dynShape) const {
|
||||
|
|
|
@ -55,6 +55,8 @@ static LogicalResult bufferizeLinalgOp(RewriterBase &rewriter, LinalgOp op,
|
|||
OpResult opResult = op.getTiedOpResult(opOperand);
|
||||
assert(opResult && "could not find correspond OpResult");
|
||||
FailureOr<Value> resultBuffer = state.getResultBuffer(rewriter, opResult);
|
||||
if (failed(resultBuffer))
|
||||
return failure();
|
||||
newOutputBuffers.push_back(*resultBuffer);
|
||||
}
|
||||
|
||||
|
@ -210,10 +212,12 @@ struct InitTensorOpInterface
|
|||
if (initTensorOp->getUses().empty())
|
||||
return success();
|
||||
|
||||
Value alloc = state.createAlloc(rewriter, initTensorOp->getLoc(),
|
||||
initTensorOp.result(),
|
||||
state.getOptions().createDeallocs);
|
||||
replaceOpWithBufferizedValues(rewriter, op, alloc);
|
||||
FailureOr<Value> alloc = state.createAlloc(
|
||||
rewriter, initTensorOp->getLoc(), initTensorOp.result(),
|
||||
state.getOptions().createDeallocs);
|
||||
if (failed(alloc))
|
||||
return failure();
|
||||
replaceOpWithBufferizedValues(rewriter, op, *alloc);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
@ -287,6 +291,8 @@ struct TiledLoopOpInterface
|
|||
if (value.getType().isa<TensorType>()) {
|
||||
FailureOr<Value> buffer = state.getResultBuffer(
|
||||
rewriter, tiledLoopOp->getResult(nextResultNum++));
|
||||
if (failed(buffer))
|
||||
return failure();
|
||||
newOutputs.push_back(*buffer);
|
||||
newResults.push_back(*buffer);
|
||||
} else {
|
||||
|
|
|
@ -295,10 +295,19 @@ struct ForOpInterface
|
|||
};
|
||||
|
||||
// Construct a new scf.for op with memref instead of tensor values.
|
||||
bool resultBufferFailure = false;
|
||||
SmallVector<Value> initArgs =
|
||||
convert(forOp.getInitArgs(), [&](Value val, int64_t index) {
|
||||
return *state.getResultBuffer(rewriter, forOp->getOpResult(index));
|
||||
FailureOr<Value> resultBuffer =
|
||||
state.getResultBuffer(rewriter, forOp->getOpResult(index));
|
||||
if (failed(resultBuffer)) {
|
||||
resultBufferFailure = true;
|
||||
return Value();
|
||||
}
|
||||
return *resultBuffer;
|
||||
});
|
||||
if (resultBufferFailure)
|
||||
return failure();
|
||||
auto newForOp = rewriter.create<scf::ForOp>(
|
||||
forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
|
||||
forOp.getStep(), initArgs);
|
||||
|
|
|
@ -54,6 +54,8 @@ struct CastOpInterface
|
|||
// The result buffer still has the old (pre-cast) type.
|
||||
FailureOr<Value> resultBuffer =
|
||||
state.getResultBuffer(rewriter, castOp->getResult(0));
|
||||
if (failed(resultBuffer))
|
||||
return failure();
|
||||
auto sourceMemRefType = resultBuffer->getType().cast<BaseMemRefType>();
|
||||
Attribute memorySpace = sourceMemRefType.getMemorySpace();
|
||||
TensorType resultTensorType =
|
||||
|
@ -149,9 +151,14 @@ struct ExtractSliceOpInterface
|
|||
// If not inplaceable, alloc.
|
||||
bool inplace = state.isInPlace(extractSliceOp->getResult(0));
|
||||
Value alloc;
|
||||
if (!inplace)
|
||||
alloc = state.createAlloc(rewriter, loc, extractSliceOp.result(),
|
||||
state.getOptions().createDeallocs);
|
||||
if (!inplace) {
|
||||
FailureOr<Value> allocOrFailure =
|
||||
state.createAlloc(rewriter, loc, extractSliceOp.result(),
|
||||
state.getOptions().createDeallocs);
|
||||
if (failed(allocOrFailure))
|
||||
return failure();
|
||||
alloc = *allocOrFailure;
|
||||
}
|
||||
|
||||
// Bufferize to subview.
|
||||
auto subviewMemRefType =
|
||||
|
@ -238,6 +245,8 @@ struct InsertOpInterface
|
|||
auto insertOp = cast<tensor::InsertOp>(op);
|
||||
FailureOr<Value> destMemref =
|
||||
state.getResultBuffer(rewriter, insertOp->getOpResult(0));
|
||||
if (failed(destMemref))
|
||||
return failure();
|
||||
rewriter.create<memref::StoreOp>(insertOp.getLoc(), insertOp.scalar(),
|
||||
*destMemref, insertOp.indices());
|
||||
replaceOpWithBufferizedValues(rewriter, op, *destMemref);
|
||||
|
@ -404,6 +413,8 @@ struct InsertSliceOpInterface
|
|||
// When bufferizing out-of-place, `getResultBuffer` allocates.
|
||||
FailureOr<Value> dstMemref =
|
||||
state.getResultBuffer(rewriter, insertSliceOp->getResult(0));
|
||||
if (failed(dstMemref))
|
||||
return failure();
|
||||
|
||||
// Take a subview of the dst.
|
||||
auto dstMemrefType = dstMemref->getType().cast<MemRefType>();
|
||||
|
|
|
@ -100,6 +100,8 @@ struct TransferWriteOpInterface
|
|||
// this point.
|
||||
FailureOr<Value> resultBuffer =
|
||||
state.getResultBuffer(rewriter, op->getResult(0));
|
||||
if (failed(resultBuffer))
|
||||
return failure();
|
||||
rewriter.create<vector::TransferWriteOp>(
|
||||
writeOp.getLoc(), writeOp.vector(), *resultBuffer, writeOp.indices(),
|
||||
writeOp.permutation_mapAttr(), writeOp.in_boundsAttr());
|
||||
|
|
|
@ -64,9 +64,9 @@ static void applyEnablingTransformations(ModuleOp moduleOp) {
|
|||
(void)applyPatternsAndFoldGreedily(moduleOp, std::move(patterns));
|
||||
}
|
||||
|
||||
static Optional<Value> allocationFnUsingAlloca(OpBuilder &b, Location loc,
|
||||
MemRefType type,
|
||||
ArrayRef<Value> dynShape) {
|
||||
static FailureOr<Value> allocationFnUsingAlloca(OpBuilder &b, Location loc,
|
||||
MemRefType type,
|
||||
ArrayRef<Value> dynShape) {
|
||||
Value allocated = b.create<memref::AllocaOp>(
|
||||
loc, type, dynShape, b.getI64IntegerAttr(kBufferAlignments));
|
||||
return allocated;
|
||||
|
|
Loading…
Reference in New Issue