[mlir] Allow users of `PromoteBuffersToStackPass` to customize `small buffer` func.

Differential Revision: https://reviews.llvm.org/D96579
This commit is contained in:
Alexander Belyaev 2021-02-12 10:04:41 +01:00
parent 0c118831a3
commit 16213e1f50
2 changed files with 33 additions and 13 deletions

View File

@ -48,6 +48,11 @@ createPromoteBuffersToStackPass(unsigned maxAllocSizeInBytes = 1024,
unsigned bitwidthOfIndexType = 64,
unsigned maxRankOfAllocatedMemRef = 1);
/// Creates a pass that promotes heap-based allocations to stack-based ones.
/// Only buffers smaller with `isSmallAlloc(alloc) == true` are promoted.
std::unique_ptr<Pass>
createPromoteBuffersToStackPass(std::function<bool(Value)> isSmallAlloc);
/// Creates a pass that finalizes a partial bufferization by removing remaining
/// tensor_load and tensor_to_memref operations.
std::unique_ptr<FunctionPass> createFinalizingBufferizePass();

View File

@ -29,9 +29,9 @@ static bool isKnownControlFlowInterface(Operation *op) {
/// Check if the size of the allocation is less than the given size. The
/// transformation is only applied to small buffers since large buffers could
/// exceed the stack space.
static bool isSmallAlloc(Value alloc, unsigned maximumSizeInBytes,
unsigned bitwidthOfIndexType,
unsigned maxRankOfAllocatedMemRef) {
static bool defaultIsSmallAlloc(Value alloc, unsigned maximumSizeInBytes,
unsigned bitwidthOfIndexType,
unsigned maxRankOfAllocatedMemRef) {
auto type = alloc.getType().dyn_cast<ShapedType>();
if (!type || !alloc.getDefiningOp<AllocOp>())
return false;
@ -299,8 +299,7 @@ public:
: BufferPlacementTransformationBase(op) {}
/// Promote buffers to stack-based allocations.
void promote(unsigned maximumSize, unsigned bitwidthOfIndexType,
unsigned maxRankOfAllocatedMemRef) {
void promote(function_ref<bool(Value)> isSmallAlloc) {
for (BufferPlacementAllocs::AllocEntry &entry : allocs) {
Value alloc = std::get<0>(entry);
Operation *dealloc = std::get<1>(entry);
@ -308,9 +307,8 @@ public:
// The transformation is done if the allocation is limited to a given
// size. Furthermore, a deallocation must not be defined for this
// allocation entry and a parent allocation scope must exist.
if (!isSmallAlloc(alloc, maximumSize, bitwidthOfIndexType,
maxRankOfAllocatedMemRef) ||
dealloc || !hasAllocationScope(alloc, aliases))
if (!isSmallAlloc(alloc) || dealloc ||
!hasAllocationScope(alloc, aliases))
continue;
Operation *startOperation = BufferPlacementAllocs::getStartOperation(
@ -359,9 +357,9 @@ struct BufferLoopHoistingPass : BufferLoopHoistingBase<BufferLoopHoistingPass> {
/// The promote buffer to stack pass that tries to convert alloc nodes into
/// alloca nodes.
struct PromoteBuffersToStackPass
: PromoteBuffersToStackBase<PromoteBuffersToStackPass> {
class PromoteBuffersToStackPass
: public PromoteBuffersToStackBase<PromoteBuffersToStackPass> {
public:
PromoteBuffersToStackPass(unsigned maxAllocSizeInBytes,
unsigned bitwidthOfIndexType,
unsigned maxRankOfAllocatedMemRef) {
@ -370,12 +368,24 @@ struct PromoteBuffersToStackPass
this->maxRankOfAllocatedMemRef = maxRankOfAllocatedMemRef;
}
explicit PromoteBuffersToStackPass(std::function<bool(Value)> isSmallAlloc)
: isSmallAlloc(std::move(isSmallAlloc)) {}
void runOnFunction() override {
// Move all allocation nodes and convert candidates into allocas.
BufferPlacementPromotion optimizer(getFunction());
optimizer.promote(this->maxAllocSizeInBytes, this->bitwidthOfIndexType,
this->maxRankOfAllocatedMemRef);
if (isSmallAlloc == nullptr) {
isSmallAlloc = [=](Value alloc) {
return defaultIsSmallAlloc(alloc, maxAllocSizeInBytes,
bitwidthOfIndexType,
maxRankOfAllocatedMemRef);
};
}
optimizer.promote(isSmallAlloc);
}
private:
std::function<bool(Value)> isSmallAlloc;
};
} // end anonymous namespace
@ -395,3 +405,8 @@ mlir::createPromoteBuffersToStackPass(unsigned maxAllocSizeInBytes,
return std::make_unique<PromoteBuffersToStackPass>(
maxAllocSizeInBytes, bitwidthOfIndexType, maxRankOfAllocatedMemRef);
}
std::unique_ptr<Pass>
mlir::createPromoteBuffersToStackPass(std::function<bool(Value)> isSmallAlloc) {
return std::make_unique<PromoteBuffersToStackPass>(std::move(isSmallAlloc));
}