forked from OSchip/llvm-project
[mlir][Linalg] Remove alloc/dealloc pair as a callback.
The alloc dealloc pair generation callback is really central to the bufferization algorithm, it modifies the state in a way that affects correctness. This is not really a configurable option. Moving it to BufferizationState removes what was probably the reason it was added as a callback. Differential Revision: https://reviews.llvm.org/D114417
This commit is contained in:
parent
06d4a76309
commit
0a58982b08
|
@ -212,16 +212,13 @@ struct BufferizationState;
|
|||
// functions in the future.
|
||||
struct AllocationCallbacks {
|
||||
using AllocationFn = std::function<Optional<Value>(
|
||||
OpBuilder &, Location, MemRefType, const SmallVector<Value> &)>;
|
||||
OpBuilder &, Location, MemRefType, ArrayRef<Value>)>;
|
||||
using DeallocationFn = std::function<void(OpBuilder &, Location, Value)>;
|
||||
using MemCpyFn = std::function<void(OpBuilder &, Location, Value, Value)>;
|
||||
using CreateAllocDeallocFn =
|
||||
std::function<Value(OpBuilder &, Location, Value, BufferizationState &)>;
|
||||
|
||||
AllocationCallbacks(AllocationFn allocFn, DeallocationFn deallocFn,
|
||||
MemCpyFn copyFn, CreateAllocDeallocFn allocDeallocFn)
|
||||
: allocationFn(allocFn), deallocationFn(deallocFn), memCpyFn(copyFn),
|
||||
createAllocDeallocFn(allocDeallocFn) {}
|
||||
MemCpyFn copyFn)
|
||||
: allocationFn(allocFn), deallocationFn(deallocFn), memCpyFn(copyFn) {}
|
||||
|
||||
/// A function that allocates memory.
|
||||
AllocationFn allocationFn;
|
||||
|
@ -231,11 +228,6 @@ struct AllocationCallbacks {
|
|||
|
||||
/// A function that copies memory between two allocations.
|
||||
MemCpyFn memCpyFn;
|
||||
|
||||
/// A function that creates an alloc-dealloc pair. This function may perform
|
||||
/// additional optimizations such as buffer allocation hoisting. This function
|
||||
/// calls `allocationFn` and `deallocationFn` to create (de)allocations.
|
||||
CreateAllocDeallocFn createAllocDeallocFn;
|
||||
};
|
||||
|
||||
/// BufferizationState keeps track of bufferization state and provides access to
|
||||
|
@ -247,6 +239,12 @@ struct BufferizationState {
|
|||
// BufferizationState should be passed as a reference.
|
||||
BufferizationState(const BufferizationState &) = delete;
|
||||
|
||||
/// A function that creates an alloc-dealloc pair. This function may perform
|
||||
/// additional optimizations such as buffer allocation hoisting. This function
|
||||
/// calls `allocationFn` and `deallocationFn` to create (de)allocations.
|
||||
Value createAllocDeallocFn(OpBuilder &builder, Location loc,
|
||||
Value shapedValue);
|
||||
|
||||
/// Map tensor values to memref buffers.
|
||||
void mapBuffer(ValueRange tensors, ValueRange buffers);
|
||||
|
||||
|
|
|
@ -7,6 +7,7 @@
|
|||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h"
|
||||
|
||||
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
||||
#include "mlir/IR/AsmState.h"
|
||||
#include "mlir/IR/BlockAndValueMapping.h"
|
||||
|
@ -360,8 +361,7 @@ Value mlir::linalg::comprehensive_bufferize::getResultBuffer(
|
|||
b.setInsertionPointAfter(operandBuffer.getDefiningOp());
|
||||
}
|
||||
// Allocate the result buffer.
|
||||
Value resultBuffer =
|
||||
state.allocationFns.createAllocDeallocFn(b, loc, operandBuffer, state);
|
||||
Value resultBuffer = state.createAllocDeallocFn(b, loc, operandBuffer);
|
||||
bool skipCopy = false;
|
||||
// Do not copy if the last preceding write of `operand` is an op that does
|
||||
// not write (skipping ops that merely create aliases). E.g., InitTensorOp.
|
||||
|
@ -442,6 +442,118 @@ mlir::linalg::comprehensive_bufferize::bufferize(Operation *op,
|
|||
return op->emitError() << "unsupported op with tensors";
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Bufferization-specific scoped alloc/dealloc insertion support.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
/// Move the insertion point of the given builder to the beginning of a
|
||||
/// surrounding block as much as possible, while not crossing any allocation
|
||||
/// hoisting barriers.
|
||||
static void moveInsertionPointToAllocationHoistingBarrier(OpBuilder &b) {
|
||||
Operation *op = b.getInsertionBlock()->getParentOp();
|
||||
while (op) {
|
||||
if (auto bufferizableOp = dyn_cast<BufferizableOpInterface>(op))
|
||||
if (bufferizableOp.isAllocationHoistingBarrier())
|
||||
break;
|
||||
op = op->getParentOp();
|
||||
}
|
||||
|
||||
// FuncOp is an allocation hoisting barrier, so the above loop should never
|
||||
// run out of parents.
|
||||
assert(
|
||||
(op && cast<BufferizableOpInterface>(op).isAllocationHoistingBarrier()) &&
|
||||
"expected traversal to end at allocation hoisting barrier");
|
||||
|
||||
// TODO: Handle cases where allocation hoisting barrier has more than one
|
||||
// region or block.
|
||||
assert(op->getNumRegions() == 1 &&
|
||||
"allocation hoisting barriers with >1 regions not supported");
|
||||
assert(op->getRegion(0).getBlocks().size() == 1 &&
|
||||
"allocation hoisting barriers with >1 blocks not supported");
|
||||
b.setInsertionPointToStart(&(op->getRegion(0).front()));
|
||||
}
|
||||
|
||||
/// Compute the type of the `memref` to use for allocating the buffer for
|
||||
/// `shapedValue`. Also returns (by reference in `dynShape`), the value for the
|
||||
/// dynamic dimensions in the returned `memref` type. The function may also set
|
||||
/// the insertion point to an earlier location, where the allocation should
|
||||
/// happen ("allocation hoisting").
|
||||
static MemRefType getAllocationTypeAndShape(OpBuilder &b, Location loc,
|
||||
Value shapedValue,
|
||||
SmallVectorImpl<Value> &dynShape) {
|
||||
MemRefType allocMemRefType =
|
||||
getContiguousMemRefType(shapedValue.getType().cast<ShapedType>());
|
||||
|
||||
// Compute the dynamic part of the shape.
|
||||
bool reifiedShapes = false;
|
||||
if (auto rankedOp = dyn_cast_or_null<ReifyRankedShapedTypeOpInterface>(
|
||||
shapedValue.getDefiningOp())) {
|
||||
ReifiedRankedShapedTypeDims resultDims;
|
||||
if (succeeded(rankedOp.reifyResultShapes(b, resultDims))) {
|
||||
reifiedShapes = true;
|
||||
OpResult resultValue = shapedValue.dyn_cast<OpResult>();
|
||||
auto &shape = resultDims[resultValue.getResultNumber()];
|
||||
for (auto dim : enumerate(allocMemRefType.getShape()))
|
||||
if (ShapedType::isDynamic(dim.value()))
|
||||
dynShape.push_back(shape[dim.index()]);
|
||||
}
|
||||
}
|
||||
|
||||
if (!reifiedShapes) {
|
||||
for (auto dim : enumerate(allocMemRefType.getShape()))
|
||||
if (ShapedType::isDynamic(dim.value())) {
|
||||
assert((shapedValue.getType().isa<UnrankedMemRefType>() ||
|
||||
shapedValue.getType().isa<MemRefType>()) &&
|
||||
"expected MemRef type");
|
||||
dynShape.push_back(
|
||||
b.create<memref::DimOp>(loc, shapedValue, dim.index()));
|
||||
}
|
||||
}
|
||||
|
||||
// If the buffer is statically shaped, try to hoist it to the first enclosing
|
||||
// parallel region.
|
||||
// TODO: also hoist in the dynamic case. For now this relies on subsequent
|
||||
// calls to LICM and buffer hoisting which will most likely not succeed.
|
||||
// TODO: when packing, allocate a static bounding box which will enable more
|
||||
// hoisting.
|
||||
if (dynShape.empty())
|
||||
moveInsertionPointToAllocationHoistingBarrier(b);
|
||||
|
||||
return allocMemRefType;
|
||||
}
|
||||
|
||||
/// Create an Allocop/DeAllocOp pair, where the AllocOp is after
|
||||
/// `shapedValue.getDefiningOp` (or at the top of the block in case of a
|
||||
/// bbArg) and the DeallocOp is at the end of the block.
|
||||
Value mlir::linalg::comprehensive_bufferize::BufferizationState::
|
||||
createAllocDeallocFn(OpBuilder &b, Location loc, Value shapedValue) {
|
||||
// Take a guard before anything else.
|
||||
OpBuilder::InsertionGuard g(b);
|
||||
|
||||
// 1. Create memory allocation.
|
||||
assert(shapedValue.getType().isa<ShapedType>());
|
||||
MemRefType memRefType = shapedValue.getType().dyn_cast<MemRefType>();
|
||||
SmallVector<Value> dynShape;
|
||||
// Note: getAllocationTypeAndShape also sets the insertion point.
|
||||
MemRefType allocMemRefType =
|
||||
getAllocationTypeAndShape(b, loc, shapedValue, dynShape);
|
||||
Optional<Value> allocated =
|
||||
allocationFns.allocationFn(b, loc, allocMemRefType, dynShape);
|
||||
// TODO: For now just assert the value is returned. Eventually need to
|
||||
// error-propagate.
|
||||
assert(allocated && "allocation failed");
|
||||
Value casted = allocated.getValue();
|
||||
if (memRefType && memRefType != allocMemRefType) {
|
||||
casted = b.create<memref::CastOp>(loc, memRefType, allocated.getValue());
|
||||
aliasInfo.insertNewBufferEquivalence(casted, allocated.getValue());
|
||||
}
|
||||
|
||||
// 2. Create memory deallocation.
|
||||
b.setInsertionPoint(allocated.getValue().getParentBlock()->getTerminator());
|
||||
allocationFns.deallocationFn(b, loc, allocated.getValue());
|
||||
return casted;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Bufferization-specific BlockAndValueMapping support with debugging.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -627,118 +627,6 @@ static FunctionType getOrCreateBufferizedFunctionType(
|
|||
return it2.first->second;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Bufferization-specific scoped alloc/dealloc insertion support.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
/// Move the insertion point of the given builder to the beginning of a
|
||||
/// surrounding block as much as possible, while not crossing any allocation
|
||||
/// hoisting barriers.
|
||||
static void moveInsertionPointToAllocationHoistingBarrier(OpBuilder &b) {
|
||||
Operation *op = b.getInsertionBlock()->getParentOp();
|
||||
while (op) {
|
||||
if (auto bufferizableOp = dyn_cast<BufferizableOpInterface>(op))
|
||||
if (bufferizableOp.isAllocationHoistingBarrier())
|
||||
break;
|
||||
op = op->getParentOp();
|
||||
}
|
||||
|
||||
// FuncOp is an allocation hoisting barrier, so the above loop should never
|
||||
// run out of parents.
|
||||
assert(
|
||||
(op && cast<BufferizableOpInterface>(op).isAllocationHoistingBarrier()) &&
|
||||
"expected traversal to end at allocation hoisting barrier");
|
||||
|
||||
// TODO: Handle cases where allocation hoisting barrier has more than one
|
||||
// region or block.
|
||||
assert(op->getNumRegions() == 1 &&
|
||||
"allocation hoisting barriers with >1 regions not supported");
|
||||
assert(op->getRegion(0).getBlocks().size() == 1 &&
|
||||
"allocation hoisting barriers with >1 blocks not supported");
|
||||
b.setInsertionPointToStart(&(op->getRegion(0).front()));
|
||||
}
|
||||
|
||||
/// Compute the type of the `memref` to use for allocating the buffer for
|
||||
/// `shapedValue`. Also returns (by reference in `dynShape`), the value for the
|
||||
/// dynamic dimensions in the returned `memref` type. The function may also set
|
||||
/// the insertion point to an earlier location, where the allocation should
|
||||
/// happen ("allocation hoisting").
|
||||
static MemRefType getAllocationTypeAndShape(OpBuilder &b, Location loc,
|
||||
Value shapedValue,
|
||||
SmallVectorImpl<Value> &dynShape) {
|
||||
MemRefType allocMemRefType =
|
||||
getContiguousMemRefType(shapedValue.getType().cast<ShapedType>());
|
||||
|
||||
// Compute the dynamic part of the shape.
|
||||
bool reifiedShapes = false;
|
||||
if (auto rankedOp = dyn_cast_or_null<ReifyRankedShapedTypeOpInterface>(
|
||||
shapedValue.getDefiningOp())) {
|
||||
ReifiedRankedShapedTypeDims resultDims;
|
||||
if (succeeded(rankedOp.reifyResultShapes(b, resultDims))) {
|
||||
reifiedShapes = true;
|
||||
OpResult resultValue = shapedValue.dyn_cast<OpResult>();
|
||||
auto &shape = resultDims[resultValue.getResultNumber()];
|
||||
for (auto dim : enumerate(allocMemRefType.getShape()))
|
||||
if (ShapedType::isDynamic(dim.value()))
|
||||
dynShape.push_back(shape[dim.index()]);
|
||||
}
|
||||
}
|
||||
|
||||
if (!reifiedShapes) {
|
||||
for (auto dim : enumerate(allocMemRefType.getShape()))
|
||||
if (ShapedType::isDynamic(dim.value())) {
|
||||
assert((shapedValue.getType().isa<UnrankedMemRefType>() ||
|
||||
shapedValue.getType().isa<MemRefType>()) &&
|
||||
"expected MemRef type");
|
||||
dynShape.push_back(
|
||||
b.create<memref::DimOp>(loc, shapedValue, dim.index()));
|
||||
}
|
||||
}
|
||||
|
||||
// If the buffer is statically shaped, try to hoist it to the first enclosing
|
||||
// parallel region.
|
||||
// TODO: also hoist in the dynamic case. For now this relies on subsequent
|
||||
// calls to LICM and buffer hoisting which will most likely not succeed.
|
||||
// TODO: when packing, allocate a static bounding box which will enable more
|
||||
// hoisting.
|
||||
if (dynShape.empty())
|
||||
moveInsertionPointToAllocationHoistingBarrier(b);
|
||||
|
||||
return allocMemRefType;
|
||||
}
|
||||
|
||||
/// Create an Allocop/DeAllocOp pair, where the AllocOp is after
|
||||
/// `shapedValue.getDefiningOp` (or at the top of the block in case of a
|
||||
/// bbArg) and the DeallocOp is at the end of the block.
|
||||
static Value createNewAllocDeallocPairForShapedValue(
|
||||
OpBuilder &b, Location loc, Value shapedValue, BufferizationState &state) {
|
||||
// Take a guard before anything else.
|
||||
OpBuilder::InsertionGuard g(b);
|
||||
|
||||
// 1. Create memory allocation.
|
||||
assert(shapedValue.getType().isa<ShapedType>());
|
||||
MemRefType memRefType = shapedValue.getType().dyn_cast<MemRefType>();
|
||||
SmallVector<Value> dynShape;
|
||||
// Note: getAllocationTypeAndShape also sets the insertion point.
|
||||
MemRefType allocMemRefType =
|
||||
getAllocationTypeAndShape(b, loc, shapedValue, dynShape);
|
||||
Optional<Value> allocated =
|
||||
state.allocationFns.allocationFn(b, loc, allocMemRefType, dynShape);
|
||||
// TODO: For now just assert the value is returned. Eventually need to
|
||||
// error-propagate.
|
||||
assert(allocated && "allocation failed");
|
||||
Value casted = allocated.getValue();
|
||||
if (memRefType && memRefType != allocMemRefType) {
|
||||
casted = b.create<memref::CastOp>(loc, memRefType, allocated.getValue());
|
||||
state.aliasInfo.insertNewBufferEquivalence(casted, allocated.getValue());
|
||||
}
|
||||
|
||||
// 2. Create memory deallocation.
|
||||
b.setInsertionPoint(allocated.getValue().getParentBlock()->getTerminator());
|
||||
state.allocationFns.deallocationFn(b, loc, allocated.getValue());
|
||||
return casted;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Bufferization as simple BlockAndValueMapping rewrites.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -1358,7 +1246,7 @@ LogicalResult mlir::linalg::comprehensive_bufferize::runComprehensiveBufferize(
|
|||
/// pass. The default currently creates a ranked memref using `memref.alloc`.
|
||||
static Optional<Value> defaultAllocationFn(OpBuilder &b, Location loc,
|
||||
MemRefType type,
|
||||
const SmallVector<Value> &dynShape) {
|
||||
ArrayRef<Value> dynShape) {
|
||||
Value allocated = b.create<memref::AllocOp>(
|
||||
loc, type, dynShape, b.getI64IntegerAttr(kBufferAlignments));
|
||||
return allocated;
|
||||
|
@ -1381,8 +1269,7 @@ static void defaultMemCpyFn(OpBuilder &b, Location loc, Value from, Value to) {
|
|||
std::unique_ptr<AllocationCallbacks>
|
||||
mlir::linalg::comprehensive_bufferize::defaultAllocationCallbacks() {
|
||||
return std::make_unique<AllocationCallbacks>(
|
||||
defaultAllocationFn, defaultDeallocationFn, defaultMemCpyFn,
|
||||
createNewAllocDeallocPairForShapedValue);
|
||||
defaultAllocationFn, defaultDeallocationFn, defaultMemCpyFn);
|
||||
}
|
||||
|
||||
// Default constructor for BufferizationOptions that sets all allocation
|
||||
|
|
|
@ -167,8 +167,8 @@ struct InitTensorOpInterface
|
|||
OpBuilder::InsertionGuard g(b);
|
||||
b.setInsertionPoint(initTensorOp);
|
||||
|
||||
Value alloc = state.allocationFns.createAllocDeallocFn(
|
||||
b, initTensorOp->getLoc(), initTensorOp.result(), state);
|
||||
Value alloc = state.createAllocDeallocFn(b, initTensorOp->getLoc(),
|
||||
initTensorOp.result());
|
||||
state.mapBuffer(initTensorOp.result(), alloc);
|
||||
return success();
|
||||
}
|
||||
|
|
|
@ -154,8 +154,7 @@ struct ExtractSliceOpInterface
|
|||
bool inplace = state.aliasInfo.isInPlace(extractSliceOp->getResult(0));
|
||||
Value alloc;
|
||||
if (!inplace)
|
||||
alloc = state.allocationFns.createAllocDeallocFn(
|
||||
b, loc, extractSliceOp.result(), state);
|
||||
alloc = state.createAllocDeallocFn(b, loc, extractSliceOp.result());
|
||||
|
||||
// Bufferize to subview.
|
||||
auto subviewMemRefType =
|
||||
|
|
|
@ -52,9 +52,9 @@ static void applyEnablingTransformations(ModuleOp moduleOp) {
|
|||
(void)applyPatternsAndFoldGreedily(moduleOp, std::move(patterns));
|
||||
}
|
||||
|
||||
static Optional<Value>
|
||||
allocationFnUsingAlloca(OpBuilder &b, Location loc, MemRefType type,
|
||||
const SmallVector<Value> &dynShape) {
|
||||
static Optional<Value> allocationFnUsingAlloca(OpBuilder &b, Location loc,
|
||||
MemRefType type,
|
||||
ArrayRef<Value> dynShape) {
|
||||
Value allocated = b.create<memref::AllocaOp>(
|
||||
loc, type, dynShape, b.getI64IntegerAttr(kBufferAlignments));
|
||||
return allocated;
|
||||
|
|
Loading…
Reference in New Issue