[mlir][bufferize][NFC] Deallocate all buffers at the end of bufferization

This makes bufferization more modular, in preparation for future refactorings.

Differential Revision: https://reviews.llvm.org/D121362
This commit is contained in:
Matthias Springer 2022-03-15 17:50:09 +09:00
parent 875782bd9e
commit 05e0495f1d
7 changed files with 131 additions and 120 deletions

View File

@ -415,13 +415,22 @@ struct BufferizationState {
BufferizationState(const AnalysisState &analysisState)
: analysisState(analysisState) {}
/// Creates a memref allocation with the given type and dynamic extents.
FailureOr<Value> createAlloc(OpBuilder &b, Location loc, MemRefType type,
ValueRange dynShape);
/// Creates a memref allocation for the given shaped value. This function may
/// perform additional optimizations such as buffer allocation hoisting.
// TODO: Allocation hoisting should be a cleanup pass.
FailureOr<Value> createAlloc(OpBuilder &b, Location loc, Value shapedValue);
/// Return the buffer (memref) for a given OpOperand (tensor). Allocate
/// a new buffer and copy over data from the existing buffer if out-of-place
/// bufferization was decided.
FailureOr<Value>
getBuffer(RewriterBase &rewriter, OpOperand &opOperand,
bool forceInPlace = false,
Optional<Operation *> customCopyInsertionPoint = None) const;
Optional<Operation *> customCopyInsertionPoint = None);
/// Return a reference to the BufferizationOptions.
const BufferizationOptions &getOptions() const {
@ -477,27 +486,6 @@ BaseMemRefType getMemRefType(TensorType tensorType,
MemRefLayoutAttrInterface layout = {},
Attribute memorySpace = {});
/// Creates a memref allocation with the given type and dynamic extents.
FailureOr<Value> createAlloc(OpBuilder &b, Location loc, MemRefType type,
ValueRange dynShape,
const BufferizationOptions &options);
/// Creates a memref allocation with the given type and dynamic extents. If
/// `deallocMemref` is set, a deallocation op is inserted at the point where
/// the allocation goes out of scope.
FailureOr<Value> createAlloc(OpBuilder &b, Location loc, MemRefType type,
ValueRange dynShape, bool deallocMemref,
const BufferizationOptions &options);
/// Creates a memref allocation for the given shaped value. This function may
/// perform additional optimizations such as buffer allocation hoisting. If
/// `deallocMemref` is set, a deallocation op is inserted at the point where
/// the allocation goes out of scope.
// TODO: Allocation hoisting should be a cleanup pass.
FailureOr<Value> createAlloc(OpBuilder &b, Location loc, Value shapedValue,
bool deallocMemref,
const BufferizationOptions &options);
/// Creates a memref deallocation. The given memref buffer must have been
/// allocated using `createAlloc`.
LogicalResult createDealloc(OpBuilder &b, Location loc, Value allocatedBuffer,
@ -507,6 +495,10 @@ LogicalResult createDealloc(OpBuilder &b, Location loc, Value allocatedBuffer,
LogicalResult createMemCpy(OpBuilder &b, Location loc, Value from, Value to,
const BufferizationOptions &options);
/// Finalize all buffer allocations, i.e., create alloc ops as specified in the
/// bufferization options and deallocate all buffers.
LogicalResult finalizeBuffers(Operation *op,
const BufferizationOptions &options);
} // namespace bufferization
} // namespace mlir

View File

@ -70,13 +70,6 @@ void populateEliminateBufferizeMaterializationsPatterns(
// TODO: Extract `options` from `state` and pass as separate argument.
LogicalResult bufferizeOp(Operation *op, const AnalysisState &analysisState);
/// Bufferize `op` and its nested ops that implement `BufferizableOpInterface`.
/// Reuse an existing `BufferizationState`.
///
/// Note: This function overload is useful for extending the bufferization.
LogicalResult bufferizeOp(Operation *op,
BufferizationState &bufferizationState);
/// Bufferize `op` and its nested ops that implement `BufferizableOpInterface`.
/// Buffers are duplicated and copied before any tensor use that bufferizes to
/// a memory write.
@ -87,6 +80,16 @@ LogicalResult bufferizeOp(Operation *op, const BufferizationOptions &options);
BufferizationOptions getPartialBufferizationOptions();
//===----------------------------------------------------------------------===//
// Helper functions for extending Bufferization
//===----------------------------------------------------------------------===//
/// Bufferize `op` and its nested ops that implement `BufferizableOpInterface`.
/// Reuse an existing `BufferizationState`.
///
/// Note: This function overload is useful for extending the bufferization.
LogicalResult bufferizeOp(Operation *op,
BufferizationState &bufferizationState);
} // namespace bufferization
} // namespace mlir

View File

@ -42,6 +42,8 @@ constexpr const ::llvm::StringLiteral
constexpr const ::llvm::StringLiteral
bufferization::BufferizableOpInterface::kInplaceableAttrName;
static const char *kBufferAllocationAttr = "bufferization.allocation";
//===----------------------------------------------------------------------===//
// BufferizationOptions
//===----------------------------------------------------------------------===//
@ -243,9 +245,10 @@ Value mlir::bufferization::lookupBuffer(RewriterBase &rewriter, Value tensor,
/// Return the buffer (memref) for a given OpOperand (tensor). Allocate
/// a new buffer and copy over data from the existing buffer if out-of-place
/// bufferization is necessary.
FailureOr<Value> BufferizationState::getBuffer(
RewriterBase &rewriter, OpOperand &opOperand, bool forceInPlace,
Optional<Operation *> customCopyInsertionPoint) const {
FailureOr<Value>
BufferizationState::getBuffer(RewriterBase &rewriter, OpOperand &opOperand,
bool forceInPlace,
Optional<Operation *> customCopyInsertionPoint) {
const BufferizationOptions &options = analysisState.getOptions();
OpBuilder::InsertionGuard guard(rewriter);
Operation *op = opOperand.getOwner();
@ -261,8 +264,7 @@ FailureOr<Value> BufferizationState::getBuffer(
// allocation should be inserted (in the absence of allocation hoisting).
setInsertionPointAfter(rewriter, operandBuffer);
// Allocate the result buffer.
FailureOr<Value> resultBuffer = createAlloc(rewriter, loc, operandBuffer,
options.createDeallocs, options);
FailureOr<Value> resultBuffer = createAlloc(rewriter, loc, operandBuffer);
if (failed(resultBuffer))
return failure();
// Do not copy if the last preceding writes of `operand` are ops that do
@ -358,6 +360,33 @@ bool AlwaysCopyAnalysisState::areEquivalentBufferizedValues(Value v1,
// Bufferization-specific scoped alloc/dealloc insertion support.
//===----------------------------------------------------------------------===//
/// Materialize a memref allocation of `type`, with `dynShape` supplying the
/// sizes of its dynamic dimensions.
///
/// If `options.allocationFn` is set, the allocation is delegated to that hook
/// (which also receives `options.bufferAlignment`) and the hook's result is
/// returned unchanged. Otherwise a `memref.alloc` op is created at the
/// builder's current insertion point, carrying the alignment as an i64
/// attribute.
static FailureOr<Value> createAlloc(OpBuilder &b, Location loc, MemRefType type,
                                    ValueRange dynShape,
                                    const BufferizationOptions &options) {
  if (!options.allocationFn) {
    // No user-provided hook: fall back to a plain AllocOp.
    auto alignmentAttr = b.getI64IntegerAttr(options.bufferAlignment);
    Value buffer =
        b.create<memref::AllocOp>(loc, type, dynShape, alignmentAttr);
    return buffer;
  }

  return (*options.allocationFn)(b, loc, type, dynShape,
                                 options.bufferAlignment);
}
/// Emit a deallocation for `allocatedBuffer` at the builder's current
/// insertion point. The buffer must have been allocated using `createAlloc`.
///
/// If `options.deallocationFn` is set, the work is delegated to that hook and
/// its result is returned. Otherwise a `memref.dealloc` op is created and the
/// function succeeds.
LogicalResult
bufferization::createDealloc(OpBuilder &b, Location loc, Value allocatedBuffer,
                             const BufferizationOptions &options) {
  if (!options.deallocationFn) {
    // No user-provided hook: fall back to a plain DeallocOp.
    b.create<memref::DeallocOp>(loc, allocatedBuffer);
    return success();
  }

  return (*options.deallocationFn)(b, loc, allocatedBuffer);
}
/// Move the insertion point of the given builder to the beginning of a
/// surrounding block as much as possible, while not crossing any allocation
/// hoisting barriers.
@ -436,92 +465,39 @@ static MemRefType getAllocationTypeAndShape(OpBuilder &b, Location loc,
return allocMemRefType;
}
/// Create an AllocOp/DeallocOp pair, where the AllocOp is after
/// `shapedValue.getDefiningOp` (or at the top of the block in case of a
/// bbArg) and the DeallocOp is at the end of the block.
FailureOr<Value>
bufferization::createAlloc(OpBuilder &b, Location loc, Value shapedValue,
bool deallocMemref,
const BufferizationOptions &options) {
/// Create a placeholder `memref.alloca` op of the given type and dynamic
/// extents, and tag it with the `kBufferAllocationAttr` unit attribute.
/// `finalizeBuffers` uses that attribute to recognize allocas created by the
/// bufferization and rewrites them into the final allocation ops.
static Value createBufferAllocation(OpBuilder &b, Location loc, MemRefType type,
                                    ValueRange dynShape) {
  memref::AllocaOp placeholder =
      b.create<memref::AllocaOp>(loc, type, dynShape);
  placeholder->setAttr(kBufferAllocationAttr, b.getUnitAttr());
  return placeholder.getResult();
}
/// Create an allocation after `shapedValue.getDefiningOp` (or at the top of the
/// block in case of a bbArg).
FailureOr<Value> BufferizationState::createAlloc(OpBuilder &b, Location loc,
Value shapedValue) {
// Take a guard before anything else.
OpBuilder::InsertionGuard g(b);
// 1. Create memory allocation.
assert(shapedValue.getType().isa<ShapedType>());
MemRefType memRefType = shapedValue.getType().dyn_cast<MemRefType>();
SmallVector<Value> dynShape;
// Note: getAllocationTypeAndShape also sets the insertion point.
MemRefType allocMemRefType =
getAllocationTypeAndShape(b, loc, shapedValue, dynShape);
FailureOr<Value> allocated =
createAlloc(b, loc, allocMemRefType, dynShape, options);
if (failed(allocated))
return failure();
Value casted = allocated.getValue();
Value alloc = createBufferAllocation(b, loc, allocMemRefType, dynShape);
if (memRefType && memRefType != allocMemRefType) {
assert(memref::CastOp::areCastCompatible(allocated.getValue().getType(),
memRefType) &&
assert(memref::CastOp::areCastCompatible(alloc.getType(), memRefType) &&
"createAlloc: cast incompatible");
casted = b.create<memref::CastOp>(loc, memRefType, allocated.getValue());
alloc = b.create<memref::CastOp>(loc, memRefType, alloc);
}
if (deallocMemref) {
// 2. Create memory deallocation.
b.setInsertionPoint(allocated.getValue().getParentBlock()->getTerminator());
if (failed(createDealloc(b, loc, allocated.getValue(), options)))
return failure();
}
return casted;
}
/// Create a memref allocation with the given type and dynamic extents.
///
/// If `options.allocationFn` is set, the allocation is delegated to that hook
/// (together with `options.bufferAlignment`) and its result is returned.
/// Otherwise a `memref.alloc` op with the configured alignment is created at
/// the builder's current insertion point.
FailureOr<Value>
bufferization::createAlloc(OpBuilder &b, Location loc, MemRefType type,
ValueRange dynShape,
const BufferizationOptions &options) {
if (options.allocationFn)
return (*options.allocationFn)(b, loc, type, dynShape,
options.bufferAlignment);
// Default buffer allocation via AllocOp.
Value allocated = b.create<memref::AllocOp>(
loc, type, dynShape, b.getI64IntegerAttr(options.bufferAlignment));
return allocated;
}
/// Create a memref allocation with the given type and dynamic extents. If
/// `deallocMemref` is set, also create a matching deallocation immediately
/// before the terminator of the block in which the allocated value is defined.
///
/// Returns failure if either the allocation or the deallocation could not be
/// created; otherwise returns the allocated buffer.
FailureOr<Value>
bufferization::createAlloc(OpBuilder &b, Location loc, MemRefType type,
ValueRange dynShape, bool deallocMemref,
const BufferizationOptions &options) {
// The guard restores the caller's insertion point on return, since the
// dealloc path below moves it.
OpBuilder::InsertionGuard g(b);
FailureOr<Value> alloc = createAlloc(b, loc, type, dynShape, options);
if (failed(alloc))
return failure();
if (deallocMemref) {
// Dealloc at the end of the block.
b.setInsertionPoint(alloc.getValue().getParentBlock()->getTerminator());
if (failed(createDealloc(b, loc, *alloc, options)))
return failure();
}
return alloc;
}
/// Create a memref deallocation.
LogicalResult
bufferization::createDealloc(OpBuilder &b, Location loc, Value allocatedBuffer,
const BufferizationOptions &options) {
if (options.deallocationFn)
return (*options.deallocationFn)(b, loc, allocatedBuffer);
// Default buffer deallocation via DeallocOp.
b.create<memref::DeallocOp>(loc, allocatedBuffer);
return success();
/// Create a memref allocation with the given type and dynamic extents.
FailureOr<Value> BufferizationState::createAlloc(OpBuilder &b, Location loc,
MemRefType type,
ValueRange dynShape) {
return createBufferAllocation(b, loc, type, dynShape);
}
/// Create a memory copy between two memref buffers.
@ -535,6 +511,41 @@ LogicalResult bufferization::createMemCpy(OpBuilder &b, Location loc,
return success();
}
/// Finalize all buffer allocations nested under `op`.
///
/// During bufferization, new buffers are materialized as `memref.alloca` ops
/// tagged with `kBufferAllocationAttr`. This function walks `op`, replaces
/// each tagged alloca with an allocation created according to `options`, and,
/// if `options.createDeallocs` is set, inserts a deallocation right before the
/// terminator of the block that contained the alloca. Allocas without the tag
/// are left untouched.
///
/// Returns failure as soon as creating an allocation or a deallocation fails.
LogicalResult
bufferization::finalizeBuffers(Operation *op,
                               const BufferizationOptions &options) {
  IRRewriter rewriter(op->getContext());

  WalkResult walkStatus = op->walk([&](memref::AllocaOp allocaOp) {
    // Only allocas carrying the bufferization marker are rewritten.
    if (!allocaOp->hasAttr(kBufferAllocationAttr))
      return WalkResult::skip();

    // Capture the enclosing block up front: `allocaOp` must not be touched
    // anymore once it has been replaced below.
    Block *enclosingBlock = allocaOp->getBlock();

    // Build the real allocation in place of the placeholder.
    rewriter.setInsertionPoint(allocaOp);
    FailureOr<Value> newBuffer =
        createAlloc(rewriter, allocaOp->getLoc(), allocaOp.getType(),
                    allocaOp.dynamicSizes(), options);
    if (failed(newBuffer))
      return WalkResult::interrupt();
    rewriter.replaceOp(allocaOp, *newBuffer);

    // Optionally free the buffer at the end of the enclosing block.
    if (options.createDeallocs) {
      rewriter.setInsertionPoint(enclosingBlock->getTerminator());
      if (failed(createDealloc(rewriter, newBuffer->getLoc(), *newBuffer,
                               options)))
        return WalkResult::interrupt();
    }

    return WalkResult::advance();
  });

  return success(!walkStatus.wasInterrupted());
}
//===----------------------------------------------------------------------===//
// Bufferization-specific BlockAndValueMapping support with debugging.
//===----------------------------------------------------------------------===//

View File

@ -302,7 +302,11 @@ checkBufferizationResult(Operation *op, const BufferizationOptions &options) {
LogicalResult bufferization::bufferizeOp(Operation *op,
const AnalysisState &analysisState) {
BufferizationState bufferizationState(analysisState);
return bufferizeOp(op, bufferizationState);
if (failed(bufferizeOp(op, bufferizationState)))
return failure();
if (failed(finalizeBuffers(op, analysisState.getOptions())))
return failure();
return success();
}
LogicalResult
@ -332,7 +336,10 @@ bufferization::bufferizeOp(Operation *op,
if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns), config)))
return failure();
return checkBufferizationResult(op, bufferizationState.getOptions());
if (failed(checkBufferizationResult(op, bufferizationState.getOptions())))
return failure();
return success();
}
namespace {

View File

@ -1054,6 +1054,10 @@ LogicalResult mlir::linalg::comprehensive_bufferize::runModuleBufferize(
}
}
// Finalize all buffers.
if (failed(finalizeBuffers(moduleOp, options)))
return failure();
// Perform a post-processing pass of layout modification at function boundary
// according to the kBufferLayoutAttrName.
layoutPostProcessing(moduleOp);

View File

@ -235,9 +235,8 @@ struct InitTensorOpInterface
if (initTensorOp->getUses().empty())
return success();
FailureOr<Value> alloc =
createAlloc(rewriter, initTensorOp->getLoc(), initTensorOp.result(),
state.getOptions().createDeallocs, state.getOptions());
FailureOr<Value> alloc = state.createAlloc(rewriter, initTensorOp->getLoc(),
initTensorOp.result());
if (failed(alloc))
return failure();
replaceOpWithBufferizedValues(rewriter, op, *alloc);

View File

@ -228,8 +228,7 @@ struct ExtractSliceOpInterface
Value alloc;
if (!inplace) {
FailureOr<Value> allocOrFailure =
createAlloc(rewriter, loc, extractSliceOp.result(),
state.getOptions().createDeallocs, state.getOptions());
state.createAlloc(rewriter, loc, extractSliceOp.result());
if (failed(allocOrFailure))
return failure();
alloc = *allocOrFailure;
@ -338,9 +337,7 @@ struct FromElementsOpInterface
auto shape = tensorType.getShape();
MemRefType resultType = getContiguousMemRefType(tensorType);
FailureOr<Value> maybeBuffer =
createAlloc(rewriter, loc, resultType, {},
/*deallocMemref=*/state.getOptions().createDeallocs,
state.getOptions());
state.createAlloc(rewriter, loc, resultType, {});
if (failed(maybeBuffer))
return failure();
Value buffer = *maybeBuffer;
@ -389,10 +386,8 @@ struct GenerateOpInterface
Location loc = op->getLoc();
MemRefType memrefType =
getContiguousMemRefType(generateOp.getType().cast<RankedTensorType>());
FailureOr<Value> maybeResult =
createAlloc(rewriter, loc, memrefType, generateOp.dynamicExtents(),
/*deallocMemref=*/state.getOptions().createDeallocs,
state.getOptions());
FailureOr<Value> maybeResult = state.createAlloc(
rewriter, loc, memrefType, generateOp.dynamicExtents());
if (failed(maybeResult))
return failure();
Value result = *maybeResult;