[mlir][Linalg] Remove alloc/dealloc pair as a callback.

The alloc dealloc pair generation callback is really central to the
bufferization algorithm, it modifies the state in a way that affects
correctness. This is not really a configurable option. Moving it to
BufferizationState removes what was probably the reason it was added
as a callback.

Differential Revision: https://reviews.llvm.org/D114417
This commit is contained in:
MaheshRavishankar 2021-11-24 09:26:11 -08:00
parent 06d4a76309
commit 0a58982b08
6 changed files with 131 additions and 135 deletions

View File

@ -212,16 +212,13 @@ struct BufferizationState;
// functions in the future.
struct AllocationCallbacks {
using AllocationFn = std::function<Optional<Value>(
OpBuilder &, Location, MemRefType, const SmallVector<Value> &)>;
OpBuilder &, Location, MemRefType, ArrayRef<Value>)>;
using DeallocationFn = std::function<void(OpBuilder &, Location, Value)>;
using MemCpyFn = std::function<void(OpBuilder &, Location, Value, Value)>;
using CreateAllocDeallocFn =
std::function<Value(OpBuilder &, Location, Value, BufferizationState &)>;
AllocationCallbacks(AllocationFn allocFn, DeallocationFn deallocFn,
MemCpyFn copyFn, CreateAllocDeallocFn allocDeallocFn)
: allocationFn(allocFn), deallocationFn(deallocFn), memCpyFn(copyFn),
createAllocDeallocFn(allocDeallocFn) {}
MemCpyFn copyFn)
: allocationFn(allocFn), deallocationFn(deallocFn), memCpyFn(copyFn) {}
/// A function that allocates memory.
AllocationFn allocationFn;
@ -231,11 +228,6 @@ struct AllocationCallbacks {
/// A function that copies memory between two allocations.
MemCpyFn memCpyFn;
/// A function that creates an alloc-dealloc pair. This function may perform
/// additional optimizations such as buffer allocation hoisting. This function
/// calls `allocationFn` and `deallocationFn` to create (de)allocations.
CreateAllocDeallocFn createAllocDeallocFn;
};
/// BufferizationState keeps track of bufferization state and provides access to
@ -247,6 +239,12 @@ struct BufferizationState {
// BufferizationState should be passed as a reference.
BufferizationState(const BufferizationState &) = delete;
/// A function that creates an alloc-dealloc pair. This function may perform
/// additional optimizations such as buffer allocation hoisting. This function
/// calls `allocationFn` and `deallocationFn` to create (de)allocations.
Value createAllocDeallocFn(OpBuilder &builder, Location loc,
Value shapedValue);
/// Map tensor values to memref buffers.
void mapBuffer(ValueRange tensors, ValueRange buffers);

View File

@ -7,6 +7,7 @@
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/AsmState.h"
#include "mlir/IR/BlockAndValueMapping.h"
@ -360,8 +361,7 @@ Value mlir::linalg::comprehensive_bufferize::getResultBuffer(
b.setInsertionPointAfter(operandBuffer.getDefiningOp());
}
// Allocate the result buffer.
Value resultBuffer =
state.allocationFns.createAllocDeallocFn(b, loc, operandBuffer, state);
Value resultBuffer = state.createAllocDeallocFn(b, loc, operandBuffer);
bool skipCopy = false;
// Do not copy if the last preceding write of `operand` is an op that does
// not write (skipping ops that merely create aliases). E.g., InitTensorOp.
@ -442,6 +442,118 @@ mlir::linalg::comprehensive_bufferize::bufferize(Operation *op,
return op->emitError() << "unsupported op with tensors";
}
//===----------------------------------------------------------------------===//
// Bufferization-specific scoped alloc/dealloc insertion support.
//===----------------------------------------------------------------------===//
/// Move the insertion point of the given builder to the beginning of a
/// surrounding block as much as possible, while not crossing any allocation
/// hoisting barriers.
static void moveInsertionPointToAllocationHoistingBarrier(OpBuilder &b) {
Operation *op = b.getInsertionBlock()->getParentOp();
while (op) {
if (auto bufferizableOp = dyn_cast<BufferizableOpInterface>(op))
if (bufferizableOp.isAllocationHoistingBarrier())
break;
op = op->getParentOp();
}
// FuncOp is an allocation hoisting barrier, so the above loop should never
// run out of parents.
assert(
(op && cast<BufferizableOpInterface>(op).isAllocationHoistingBarrier()) &&
"expected traversal to end at allocation hoisting barrier");
// TODO: Handle cases where allocation hoisting barrier has more than one
// region or block.
assert(op->getNumRegions() == 1 &&
"allocation hoisting barriers with >1 regions not supported");
assert(op->getRegion(0).getBlocks().size() == 1 &&
"allocation hoisting barriers with >1 blocks not supported");
b.setInsertionPointToStart(&(op->getRegion(0).front()));
}
/// Compute the type of the `memref` to use for allocating the buffer for
/// `shapedValue`. Also returns (by reference in `dynShape`), the value for the
/// dynamic dimensions in the returned `memref` type. The function may also set
/// the insertion point to an earlier location, where the allocation should
/// happen ("allocation hoisting").
static MemRefType getAllocationTypeAndShape(OpBuilder &b, Location loc,
Value shapedValue,
SmallVectorImpl<Value> &dynShape) {
MemRefType allocMemRefType =
getContiguousMemRefType(shapedValue.getType().cast<ShapedType>());
// Compute the dynamic part of the shape.
bool reifiedShapes = false;
if (auto rankedOp = dyn_cast_or_null<ReifyRankedShapedTypeOpInterface>(
shapedValue.getDefiningOp())) {
ReifiedRankedShapedTypeDims resultDims;
if (succeeded(rankedOp.reifyResultShapes(b, resultDims))) {
reifiedShapes = true;
OpResult resultValue = shapedValue.dyn_cast<OpResult>();
auto &shape = resultDims[resultValue.getResultNumber()];
for (auto dim : enumerate(allocMemRefType.getShape()))
if (ShapedType::isDynamic(dim.value()))
dynShape.push_back(shape[dim.index()]);
}
}
if (!reifiedShapes) {
for (auto dim : enumerate(allocMemRefType.getShape()))
if (ShapedType::isDynamic(dim.value())) {
assert((shapedValue.getType().isa<UnrankedMemRefType>() ||
shapedValue.getType().isa<MemRefType>()) &&
"expected MemRef type");
dynShape.push_back(
b.create<memref::DimOp>(loc, shapedValue, dim.index()));
}
}
// If the buffer is statically shaped, try to hoist it to the first enclosing
// parallel region.
// TODO: also hoist in the dynamic case. For now this relies on subsequent
// calls to LICM and buffer hoisting which will most likely not succeed.
// TODO: when packing, allocate a static bounding box which will enable more
// hoisting.
if (dynShape.empty())
moveInsertionPointToAllocationHoistingBarrier(b);
return allocMemRefType;
}
/// Create an Allocop/DeAllocOp pair, where the AllocOp is after
/// `shapedValue.getDefiningOp` (or at the top of the block in case of a
/// bbArg) and the DeallocOp is at the end of the block.
Value mlir::linalg::comprehensive_bufferize::BufferizationState::
createAllocDeallocFn(OpBuilder &b, Location loc, Value shapedValue) {
// Take a guard before anything else.
OpBuilder::InsertionGuard g(b);
// 1. Create memory allocation.
assert(shapedValue.getType().isa<ShapedType>());
MemRefType memRefType = shapedValue.getType().dyn_cast<MemRefType>();
SmallVector<Value> dynShape;
// Note: getAllocationTypeAndShape also sets the insertion point.
MemRefType allocMemRefType =
getAllocationTypeAndShape(b, loc, shapedValue, dynShape);
Optional<Value> allocated =
allocationFns.allocationFn(b, loc, allocMemRefType, dynShape);
// TODO: For now just assert the value is returned. Eventually need to
// error-propagate.
assert(allocated && "allocation failed");
Value casted = allocated.getValue();
if (memRefType && memRefType != allocMemRefType) {
casted = b.create<memref::CastOp>(loc, memRefType, allocated.getValue());
aliasInfo.insertNewBufferEquivalence(casted, allocated.getValue());
}
// 2. Create memory deallocation.
b.setInsertionPoint(allocated.getValue().getParentBlock()->getTerminator());
allocationFns.deallocationFn(b, loc, allocated.getValue());
return casted;
}
//===----------------------------------------------------------------------===//
// Bufferization-specific BlockAndValueMapping support with debugging.
//===----------------------------------------------------------------------===//

View File

@ -627,118 +627,6 @@ static FunctionType getOrCreateBufferizedFunctionType(
return it2.first->second;
}
//===----------------------------------------------------------------------===//
// Bufferization-specific scoped alloc/dealloc insertion support.
//===----------------------------------------------------------------------===//
/// Move the insertion point of the given builder to the beginning of a
/// surrounding block as much as possible, while not crossing any allocation
/// hoisting barriers.
static void moveInsertionPointToAllocationHoistingBarrier(OpBuilder &b) {
Operation *op = b.getInsertionBlock()->getParentOp();
while (op) {
if (auto bufferizableOp = dyn_cast<BufferizableOpInterface>(op))
if (bufferizableOp.isAllocationHoistingBarrier())
break;
op = op->getParentOp();
}
// FuncOp is an allocation hoisting barrier, so the above loop should never
// run out of parents.
assert(
(op && cast<BufferizableOpInterface>(op).isAllocationHoistingBarrier()) &&
"expected traversal to end at allocation hoisting barrier");
// TODO: Handle cases where allocation hoisting barrier has more than one
// region or block.
assert(op->getNumRegions() == 1 &&
"allocation hoisting barriers with >1 regions not supported");
assert(op->getRegion(0).getBlocks().size() == 1 &&
"allocation hoisting barriers with >1 blocks not supported");
b.setInsertionPointToStart(&(op->getRegion(0).front()));
}
/// Compute the type of the `memref` to use for allocating the buffer for
/// `shapedValue`. Also returns (by reference in `dynShape`), the value for the
/// dynamic dimensions in the returned `memref` type. The function may also set
/// the insertion point to an earlier location, where the allocation should
/// happen ("allocation hoisting").
static MemRefType getAllocationTypeAndShape(OpBuilder &b, Location loc,
Value shapedValue,
SmallVectorImpl<Value> &dynShape) {
MemRefType allocMemRefType =
getContiguousMemRefType(shapedValue.getType().cast<ShapedType>());
// Compute the dynamic part of the shape.
bool reifiedShapes = false;
if (auto rankedOp = dyn_cast_or_null<ReifyRankedShapedTypeOpInterface>(
shapedValue.getDefiningOp())) {
ReifiedRankedShapedTypeDims resultDims;
if (succeeded(rankedOp.reifyResultShapes(b, resultDims))) {
reifiedShapes = true;
OpResult resultValue = shapedValue.dyn_cast<OpResult>();
auto &shape = resultDims[resultValue.getResultNumber()];
for (auto dim : enumerate(allocMemRefType.getShape()))
if (ShapedType::isDynamic(dim.value()))
dynShape.push_back(shape[dim.index()]);
}
}
if (!reifiedShapes) {
for (auto dim : enumerate(allocMemRefType.getShape()))
if (ShapedType::isDynamic(dim.value())) {
assert((shapedValue.getType().isa<UnrankedMemRefType>() ||
shapedValue.getType().isa<MemRefType>()) &&
"expected MemRef type");
dynShape.push_back(
b.create<memref::DimOp>(loc, shapedValue, dim.index()));
}
}
// If the buffer is statically shaped, try to hoist it to the first enclosing
// parallel region.
// TODO: also hoist in the dynamic case. For now this relies on subsequent
// calls to LICM and buffer hoisting which will most likely not succeed.
// TODO: when packing, allocate a static bounding box which will enable more
// hoisting.
if (dynShape.empty())
moveInsertionPointToAllocationHoistingBarrier(b);
return allocMemRefType;
}
/// Create an Allocop/DeAllocOp pair, where the AllocOp is after
/// `shapedValue.getDefiningOp` (or at the top of the block in case of a
/// bbArg) and the DeallocOp is at the end of the block.
static Value createNewAllocDeallocPairForShapedValue(
OpBuilder &b, Location loc, Value shapedValue, BufferizationState &state) {
// Take a guard before anything else.
OpBuilder::InsertionGuard g(b);
// 1. Create memory allocation.
assert(shapedValue.getType().isa<ShapedType>());
MemRefType memRefType = shapedValue.getType().dyn_cast<MemRefType>();
SmallVector<Value> dynShape;
// Note: getAllocationTypeAndShape also sets the insertion point.
MemRefType allocMemRefType =
getAllocationTypeAndShape(b, loc, shapedValue, dynShape);
Optional<Value> allocated =
state.allocationFns.allocationFn(b, loc, allocMemRefType, dynShape);
// TODO: For now just assert the value is returned. Eventually need to
// error-propagate.
assert(allocated && "allocation failed");
Value casted = allocated.getValue();
if (memRefType && memRefType != allocMemRefType) {
casted = b.create<memref::CastOp>(loc, memRefType, allocated.getValue());
state.aliasInfo.insertNewBufferEquivalence(casted, allocated.getValue());
}
// 2. Create memory deallocation.
b.setInsertionPoint(allocated.getValue().getParentBlock()->getTerminator());
state.allocationFns.deallocationFn(b, loc, allocated.getValue());
return casted;
}
//===----------------------------------------------------------------------===//
// Bufferization as simple BlockAndValueMapping rewrites.
//===----------------------------------------------------------------------===//
@ -1358,7 +1246,7 @@ LogicalResult mlir::linalg::comprehensive_bufferize::runComprehensiveBufferize(
/// pass. The default currently creates a ranked memref using `memref.alloc`.
static Optional<Value> defaultAllocationFn(OpBuilder &b, Location loc,
MemRefType type,
const SmallVector<Value> &dynShape) {
ArrayRef<Value> dynShape) {
Value allocated = b.create<memref::AllocOp>(
loc, type, dynShape, b.getI64IntegerAttr(kBufferAlignments));
return allocated;
@ -1381,8 +1269,7 @@ static void defaultMemCpyFn(OpBuilder &b, Location loc, Value from, Value to) {
std::unique_ptr<AllocationCallbacks>
mlir::linalg::comprehensive_bufferize::defaultAllocationCallbacks() {
return std::make_unique<AllocationCallbacks>(
defaultAllocationFn, defaultDeallocationFn, defaultMemCpyFn,
createNewAllocDeallocPairForShapedValue);
defaultAllocationFn, defaultDeallocationFn, defaultMemCpyFn);
}
// Default constructor for BufferizationOptions that sets all allocation

View File

@ -167,8 +167,8 @@ struct InitTensorOpInterface
OpBuilder::InsertionGuard g(b);
b.setInsertionPoint(initTensorOp);
Value alloc = state.allocationFns.createAllocDeallocFn(
b, initTensorOp->getLoc(), initTensorOp.result(), state);
Value alloc = state.createAllocDeallocFn(b, initTensorOp->getLoc(),
initTensorOp.result());
state.mapBuffer(initTensorOp.result(), alloc);
return success();
}

View File

@ -154,8 +154,7 @@ struct ExtractSliceOpInterface
bool inplace = state.aliasInfo.isInPlace(extractSliceOp->getResult(0));
Value alloc;
if (!inplace)
alloc = state.allocationFns.createAllocDeallocFn(
b, loc, extractSliceOp.result(), state);
alloc = state.createAllocDeallocFn(b, loc, extractSliceOp.result());
// Bufferize to subview.
auto subviewMemRefType =

View File

@ -52,9 +52,9 @@ static void applyEnablingTransformations(ModuleOp moduleOp) {
(void)applyPatternsAndFoldGreedily(moduleOp, std::move(patterns));
}
static Optional<Value>
allocationFnUsingAlloca(OpBuilder &b, Location loc, MemRefType type,
const SmallVector<Value> &dynShape) {
static Optional<Value> allocationFnUsingAlloca(OpBuilder &b, Location loc,
MemRefType type,
ArrayRef<Value> dynShape) {
Value allocated = b.create<memref::AllocaOp>(
loc, type, dynShape, b.getI64IntegerAttr(kBufferAlignments));
return allocated;