forked from OSchip/llvm-project
[mlir][bufferize][NFC] Deallocate all buffers at the end of bufferization
This makes bufferization more modular. This is in preparation for future refactorings. Differential Revision: https://reviews.llvm.org/D121362
This commit is contained in:
parent
875782bd9e
commit
05e0495f1d
|
@ -415,13 +415,22 @@ struct BufferizationState {
|
|||
BufferizationState(const AnalysisState &analysisState)
|
||||
: analysisState(analysisState) {}
|
||||
|
||||
/// Creates a memref allocation with the given type and dynamic extents.
|
||||
FailureOr<Value> createAlloc(OpBuilder &b, Location loc, MemRefType type,
|
||||
ValueRange dynShape);
|
||||
|
||||
/// Creates a memref allocation for the given shaped value. This function may
|
||||
/// perform additional optimizations such as buffer allocation hoisting.
|
||||
// TODO: Allocation hoisting should be a cleanup pass.
|
||||
FailureOr<Value> createAlloc(OpBuilder &b, Location loc, Value shapedValue);
|
||||
|
||||
/// Return the buffer (memref) for a given OpOperand (tensor). Allocate
|
||||
/// a new buffer and copy over data from the existing buffer if out-of-place
|
||||
/// bufferization was decided.
|
||||
FailureOr<Value>
|
||||
getBuffer(RewriterBase &rewriter, OpOperand &opOperand,
|
||||
bool forceInPlace = false,
|
||||
Optional<Operation *> customCopyInsertionPoint = None) const;
|
||||
Optional<Operation *> customCopyInsertionPoint = None);
|
||||
|
||||
/// Return a reference to the BufferizationOptions.
|
||||
const BufferizationOptions &getOptions() const {
|
||||
|
@ -477,27 +486,6 @@ BaseMemRefType getMemRefType(TensorType tensorType,
|
|||
MemRefLayoutAttrInterface layout = {},
|
||||
Attribute memorySpace = {});
|
||||
|
||||
/// Creates a memref allocation with the given type and dynamic extents.
|
||||
FailureOr<Value> createAlloc(OpBuilder &b, Location loc, MemRefType type,
|
||||
ValueRange dynShape,
|
||||
const BufferizationOptions &options);
|
||||
|
||||
/// Creates a memref allocation with the given type and dynamic extents. If
|
||||
/// `createDealloc`, a deallocation op is inserted at the point where the
|
||||
/// allocation goes out of scope.
|
||||
FailureOr<Value> createAlloc(OpBuilder &b, Location loc, MemRefType type,
|
||||
ValueRange dynShape, bool deallocMemref,
|
||||
const BufferizationOptions &options);
|
||||
|
||||
/// Creates a memref allocation for the given shaped value. This function may
|
||||
/// perform additional optimizations such as buffer allocation hoisting. If
|
||||
/// `createDealloc`, a deallocation op is inserted at the point where the
|
||||
/// allocation goes out of scope.
|
||||
// TODO: Allocation hoisting should be a cleanup pass.
|
||||
FailureOr<Value> createAlloc(OpBuilder &b, Location loc, Value shapedValue,
|
||||
bool deallocMemref,
|
||||
const BufferizationOptions &options);
|
||||
|
||||
/// Creates a memref deallocation. The given memref buffer must have been
|
||||
/// allocated using `createAlloc`.
|
||||
LogicalResult createDealloc(OpBuilder &b, Location loc, Value allocatedBuffer,
|
||||
|
@ -507,6 +495,10 @@ LogicalResult createDealloc(OpBuilder &b, Location loc, Value allocatedBuffer,
|
|||
LogicalResult createMemCpy(OpBuilder &b, Location loc, Value from, Value to,
|
||||
const BufferizationOptions &options);
|
||||
|
||||
/// Finalize all buffer allocations, i.e., create alloc ops as specified in the
|
||||
/// bufferization options and deallocate all buffers.
|
||||
LogicalResult finalizeBuffers(Operation *op,
|
||||
const BufferizationOptions &options);
|
||||
} // namespace bufferization
|
||||
} // namespace mlir
|
||||
|
||||
|
|
|
@ -70,13 +70,6 @@ void populateEliminateBufferizeMaterializationsPatterns(
|
|||
// TODO: Extract `options` from `state` and pass as separate argument.
|
||||
LogicalResult bufferizeOp(Operation *op, const AnalysisState &analysisState);
|
||||
|
||||
/// Bufferize `op` and its nested ops that implement `BufferizableOpInterface`.
|
||||
/// Reuse an existing `BufferizationState`.
|
||||
///
|
||||
/// Note: This function overload is useful for extending the bufferization.
|
||||
LogicalResult bufferizeOp(Operation *op,
|
||||
BufferizationState &bufferizationState);
|
||||
|
||||
/// Bufferize `op` and its nested ops that implement `BufferizableOpInterface`.
|
||||
/// Buffers are duplicated and copied before any tensor use that bufferizes to
|
||||
/// a memory write.
|
||||
|
@ -87,6 +80,16 @@ LogicalResult bufferizeOp(Operation *op, const BufferizationOptions &options);
|
|||
|
||||
BufferizationOptions getPartialBufferizationOptions();
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Helper functions for extending Bufferization
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
/// Bufferize `op` and its nested ops that implement `BufferizableOpInterface`.
|
||||
/// Reuse an existing `BufferizationState`.
|
||||
///
|
||||
/// Note: This function overload is useful for extending the bufferization.
|
||||
LogicalResult bufferizeOp(Operation *op,
|
||||
BufferizationState &bufferizationState);
|
||||
} // namespace bufferization
|
||||
} // namespace mlir
|
||||
|
||||
|
|
|
@ -42,6 +42,8 @@ constexpr const ::llvm::StringLiteral
|
|||
constexpr const ::llvm::StringLiteral
|
||||
bufferization::BufferizableOpInterface::kInplaceableAttrName;
|
||||
|
||||
static const char *kBufferAllocationAttr = "bufferization.allocation";
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// BufferizationOptions
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -243,9 +245,10 @@ Value mlir::bufferization::lookupBuffer(RewriterBase &rewriter, Value tensor,
|
|||
/// Return the buffer (memref) for a given OpOperand (tensor). Allocate
|
||||
/// a new buffer and copy over data from the existing buffer if out-of-place
|
||||
/// bufferization is necessary.
|
||||
FailureOr<Value> BufferizationState::getBuffer(
|
||||
RewriterBase &rewriter, OpOperand &opOperand, bool forceInPlace,
|
||||
Optional<Operation *> customCopyInsertionPoint) const {
|
||||
FailureOr<Value>
|
||||
BufferizationState::getBuffer(RewriterBase &rewriter, OpOperand &opOperand,
|
||||
bool forceInPlace,
|
||||
Optional<Operation *> customCopyInsertionPoint) {
|
||||
const BufferizationOptions &options = analysisState.getOptions();
|
||||
OpBuilder::InsertionGuard guard(rewriter);
|
||||
Operation *op = opOperand.getOwner();
|
||||
|
@ -261,8 +264,7 @@ FailureOr<Value> BufferizationState::getBuffer(
|
|||
// allocation should be inserted (in the absence of allocation hoisting).
|
||||
setInsertionPointAfter(rewriter, operandBuffer);
|
||||
// Allocate the result buffer.
|
||||
FailureOr<Value> resultBuffer = createAlloc(rewriter, loc, operandBuffer,
|
||||
options.createDeallocs, options);
|
||||
FailureOr<Value> resultBuffer = createAlloc(rewriter, loc, operandBuffer);
|
||||
if (failed(resultBuffer))
|
||||
return failure();
|
||||
// Do not copy if the last preceding writes of `operand` are ops that do
|
||||
|
@ -358,6 +360,33 @@ bool AlwaysCopyAnalysisState::areEquivalentBufferizedValues(Value v1,
|
|||
// Bufferization-specific scoped alloc/dealloc insertion support.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
/// Create a memref allocation with the given type and dynamic extents.
|
||||
static FailureOr<Value> createAlloc(OpBuilder &b, Location loc, MemRefType type,
|
||||
ValueRange dynShape,
|
||||
const BufferizationOptions &options) {
|
||||
if (options.allocationFn)
|
||||
return (*options.allocationFn)(b, loc, type, dynShape,
|
||||
options.bufferAlignment);
|
||||
|
||||
// Default bufferallocation via AllocOp.
|
||||
Value allocated = b.create<memref::AllocOp>(
|
||||
loc, type, dynShape, b.getI64IntegerAttr(options.bufferAlignment));
|
||||
return allocated;
|
||||
}
|
||||
|
||||
/// Creates a memref deallocation. The given memref buffer must have been
|
||||
/// allocated using `createAlloc`.
|
||||
LogicalResult
|
||||
bufferization::createDealloc(OpBuilder &b, Location loc, Value allocatedBuffer,
|
||||
const BufferizationOptions &options) {
|
||||
if (options.deallocationFn)
|
||||
return (*options.deallocationFn)(b, loc, allocatedBuffer);
|
||||
|
||||
// Default buffer deallocation via DeallocOp.
|
||||
b.create<memref::DeallocOp>(loc, allocatedBuffer);
|
||||
return success();
|
||||
}
|
||||
|
||||
/// Move the insertion point of the given builder to the beginning of a
|
||||
/// surrounding block as much as possible, while not crossing any allocation
|
||||
/// hoisting barriers.
|
||||
|
@ -436,92 +465,39 @@ static MemRefType getAllocationTypeAndShape(OpBuilder &b, Location loc,
|
|||
return allocMemRefType;
|
||||
}
|
||||
|
||||
/// Create an AllocOp/DeallocOp pair, where the AllocOp is after
|
||||
/// `shapedValue.getDefiningOp` (or at the top of the block in case of a
|
||||
/// bbArg) and the DeallocOp is at the end of the block.
|
||||
FailureOr<Value>
|
||||
bufferization::createAlloc(OpBuilder &b, Location loc, Value shapedValue,
|
||||
bool deallocMemref,
|
||||
const BufferizationOptions &options) {
|
||||
/// Build a `memref::AllocaOp` of the given type and dynamic extents and tag it
/// with `kBufferAllocationAttr`, so that `finalizeBuffers` can tell it apart
/// from alloca ops that were not created by the bufferization.
static Value createBufferAllocation(OpBuilder &b, Location loc, MemRefType type,
                                    ValueRange dynShape) {
  memref::AllocaOp placeholder =
      b.create<memref::AllocaOp>(loc, type, dynShape);
  placeholder->setAttr(kBufferAllocationAttr, b.getUnitAttr());
  Value buffer = placeholder.getResult();
  return buffer;
}
|
||||
|
||||
/// Create an allocation after `shapedValue.getDefiningOp` (or at the top of the
|
||||
/// block in case of a bbArg).
|
||||
FailureOr<Value> BufferizationState::createAlloc(OpBuilder &b, Location loc,
|
||||
Value shapedValue) {
|
||||
// Take a guard before anything else.
|
||||
OpBuilder::InsertionGuard g(b);
|
||||
|
||||
// 1. Create memory allocation.
|
||||
assert(shapedValue.getType().isa<ShapedType>());
|
||||
MemRefType memRefType = shapedValue.getType().dyn_cast<MemRefType>();
|
||||
SmallVector<Value> dynShape;
|
||||
// Note: getAllocationTypeAndShape also sets the insertion point.
|
||||
MemRefType allocMemRefType =
|
||||
getAllocationTypeAndShape(b, loc, shapedValue, dynShape);
|
||||
FailureOr<Value> allocated =
|
||||
createAlloc(b, loc, allocMemRefType, dynShape, options);
|
||||
if (failed(allocated))
|
||||
return failure();
|
||||
Value casted = allocated.getValue();
|
||||
Value alloc = createBufferAllocation(b, loc, allocMemRefType, dynShape);
|
||||
if (memRefType && memRefType != allocMemRefType) {
|
||||
assert(memref::CastOp::areCastCompatible(allocated.getValue().getType(),
|
||||
memRefType) &&
|
||||
assert(memref::CastOp::areCastCompatible(alloc.getType(), memRefType) &&
|
||||
"createAlloc: cast incompatible");
|
||||
casted = b.create<memref::CastOp>(loc, memRefType, allocated.getValue());
|
||||
alloc = b.create<memref::CastOp>(loc, memRefType, alloc);
|
||||
}
|
||||
|
||||
if (deallocMemref) {
|
||||
// 2. Create memory deallocation.
|
||||
b.setInsertionPoint(allocated.getValue().getParentBlock()->getTerminator());
|
||||
if (failed(createDealloc(b, loc, allocated.getValue(), options)))
|
||||
return failure();
|
||||
}
|
||||
|
||||
return casted;
|
||||
}
|
||||
|
||||
/// Create a memref allocation with the given type and dynamic extents.
|
||||
FailureOr<Value>
|
||||
bufferization::createAlloc(OpBuilder &b, Location loc, MemRefType type,
|
||||
ValueRange dynShape,
|
||||
const BufferizationOptions &options) {
|
||||
if (options.allocationFn)
|
||||
return (*options.allocationFn)(b, loc, type, dynShape,
|
||||
options.bufferAlignment);
|
||||
|
||||
// Default bufferallocation via AllocOp.
|
||||
Value allocated = b.create<memref::AllocOp>(
|
||||
loc, type, dynShape, b.getI64IntegerAttr(options.bufferAlignment));
|
||||
return allocated;
|
||||
}
|
||||
|
||||
/// Create a memref allocation with the given type and dynamic extents. May also
|
||||
/// deallocate the memref again.
|
||||
FailureOr<Value>
|
||||
bufferization::createAlloc(OpBuilder &b, Location loc, MemRefType type,
|
||||
ValueRange dynShape, bool deallocMemref,
|
||||
const BufferizationOptions &options) {
|
||||
OpBuilder::InsertionGuard g(b);
|
||||
|
||||
FailureOr<Value> alloc = createAlloc(b, loc, type, dynShape, options);
|
||||
if (failed(alloc))
|
||||
return failure();
|
||||
|
||||
if (deallocMemref) {
|
||||
// Dealloc at the end of the block.
|
||||
b.setInsertionPoint(alloc.getValue().getParentBlock()->getTerminator());
|
||||
if (failed(createDealloc(b, loc, *alloc, options)))
|
||||
return failure();
|
||||
}
|
||||
|
||||
return alloc;
|
||||
}
|
||||
|
||||
/// Create a memref deallocation.
|
||||
LogicalResult
|
||||
bufferization::createDealloc(OpBuilder &b, Location loc, Value allocatedBuffer,
|
||||
const BufferizationOptions &options) {
|
||||
if (options.deallocationFn)
|
||||
return (*options.deallocationFn)(b, loc, allocatedBuffer);
|
||||
|
||||
// Default buffer deallocation via DeallocOp.
|
||||
b.create<memref::DeallocOp>(loc, allocatedBuffer);
|
||||
return success();
|
||||
/// Create a memref allocation with the given type and dynamic extents.
|
||||
FailureOr<Value> BufferizationState::createAlloc(OpBuilder &b, Location loc,
|
||||
MemRefType type,
|
||||
ValueRange dynShape) {
|
||||
return createBufferAllocation(b, loc, type, dynShape);
|
||||
}
|
||||
|
||||
/// Create a memory copy between two memref buffers.
|
||||
|
@ -535,6 +511,41 @@ LogicalResult bufferization::createMemCpy(OpBuilder &b, Location loc,
|
|||
return success();
|
||||
}
|
||||
|
||||
/// Rewrite every bufferization-created `memref.alloca` nested under `op` into
/// an allocation built by `createAlloc`, and, if `options.createDeallocs` is
/// set, insert a matching deallocation before the terminator of the block that
/// contained the alloca. Returns failure if any allocation or deallocation
/// could not be created.
LogicalResult
bufferization::finalizeBuffers(Operation *op,
                               const BufferizationOptions &options) {
  IRRewriter rewriter(op->getContext());

  WalkResult walkStatus = op->walk([&](memref::AllocaOp placeholder)
                                       -> WalkResult {
    // Only alloca ops tagged by the bufferization are rewritten.
    if (!placeholder->hasAttr(kBufferAllocationAttr))
      return WalkResult::skip();

    // Remember the enclosing block before the placeholder is replaced.
    Block *parentBlock = placeholder->getBlock();

    // Build the real allocation right where the placeholder lives.
    rewriter.setInsertionPoint(placeholder);
    FailureOr<Value> buffer =
        createAlloc(rewriter, placeholder->getLoc(), placeholder.getType(),
                    placeholder.dynamicSizes(), options);
    if (failed(buffer))
      return WalkResult::interrupt();
    rewriter.replaceOp(placeholder, *buffer);

    // Deallocate at the end of the block, unless deallocations are disabled.
    if (options.createDeallocs) {
      rewriter.setInsertionPoint(parentBlock->getTerminator());
      if (failed(createDealloc(rewriter, buffer->getLoc(), *buffer, options)))
        return WalkResult::interrupt();
    }

    return WalkResult::advance();
  });

  if (walkStatus.wasInterrupted())
    return failure();
  return success();
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Bufferization-specific BlockAndValueMapping support with debugging.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -302,7 +302,11 @@ checkBufferizationResult(Operation *op, const BufferizationOptions &options) {
|
|||
LogicalResult bufferization::bufferizeOp(Operation *op,
|
||||
const AnalysisState &analysisState) {
|
||||
BufferizationState bufferizationState(analysisState);
|
||||
return bufferizeOp(op, bufferizationState);
|
||||
if (failed(bufferizeOp(op, bufferizationState)))
|
||||
return failure();
|
||||
if (failed(finalizeBuffers(op, analysisState.getOptions())))
|
||||
return failure();
|
||||
return success();
|
||||
}
|
||||
|
||||
LogicalResult
|
||||
|
@ -332,7 +336,10 @@ bufferization::bufferizeOp(Operation *op,
|
|||
if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns), config)))
|
||||
return failure();
|
||||
|
||||
return checkBufferizationResult(op, bufferizationState.getOptions());
|
||||
if (failed(checkBufferizationResult(op, bufferizationState.getOptions())))
|
||||
return failure();
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
namespace {
|
||||
|
|
|
@ -1054,6 +1054,10 @@ LogicalResult mlir::linalg::comprehensive_bufferize::runModuleBufferize(
|
|||
}
|
||||
}
|
||||
|
||||
// Finalize all buffers.
|
||||
if (failed(finalizeBuffers(moduleOp, options)))
|
||||
return failure();
|
||||
|
||||
// Perform a post-processing pass of layout modification at function boundary
|
||||
// according to the kBufferLayoutAttrName.
|
||||
layoutPostProcessing(moduleOp);
|
||||
|
|
|
@ -235,9 +235,8 @@ struct InitTensorOpInterface
|
|||
if (initTensorOp->getUses().empty())
|
||||
return success();
|
||||
|
||||
FailureOr<Value> alloc =
|
||||
createAlloc(rewriter, initTensorOp->getLoc(), initTensorOp.result(),
|
||||
state.getOptions().createDeallocs, state.getOptions());
|
||||
FailureOr<Value> alloc = state.createAlloc(rewriter, initTensorOp->getLoc(),
|
||||
initTensorOp.result());
|
||||
if (failed(alloc))
|
||||
return failure();
|
||||
replaceOpWithBufferizedValues(rewriter, op, *alloc);
|
||||
|
|
|
@ -228,8 +228,7 @@ struct ExtractSliceOpInterface
|
|||
Value alloc;
|
||||
if (!inplace) {
|
||||
FailureOr<Value> allocOrFailure =
|
||||
createAlloc(rewriter, loc, extractSliceOp.result(),
|
||||
state.getOptions().createDeallocs, state.getOptions());
|
||||
state.createAlloc(rewriter, loc, extractSliceOp.result());
|
||||
if (failed(allocOrFailure))
|
||||
return failure();
|
||||
alloc = *allocOrFailure;
|
||||
|
@ -338,9 +337,7 @@ struct FromElementsOpInterface
|
|||
auto shape = tensorType.getShape();
|
||||
MemRefType resultType = getContiguousMemRefType(tensorType);
|
||||
FailureOr<Value> maybeBuffer =
|
||||
createAlloc(rewriter, loc, resultType, {},
|
||||
/*deallocMemref=*/state.getOptions().createDeallocs,
|
||||
state.getOptions());
|
||||
state.createAlloc(rewriter, loc, resultType, {});
|
||||
if (failed(maybeBuffer))
|
||||
return failure();
|
||||
Value buffer = *maybeBuffer;
|
||||
|
@ -389,10 +386,8 @@ struct GenerateOpInterface
|
|||
Location loc = op->getLoc();
|
||||
MemRefType memrefType =
|
||||
getContiguousMemRefType(generateOp.getType().cast<RankedTensorType>());
|
||||
FailureOr<Value> maybeResult =
|
||||
createAlloc(rewriter, loc, memrefType, generateOp.dynamicExtents(),
|
||||
/*deallocMemref=*/state.getOptions().createDeallocs,
|
||||
state.getOptions());
|
||||
FailureOr<Value> maybeResult = state.createAlloc(
|
||||
rewriter, loc, memrefType, generateOp.dynamicExtents());
|
||||
if (failed(maybeResult))
|
||||
return failure();
|
||||
Value result = *maybeResult;
|
||||
|
|
Loading…
Reference in New Issue