[mlir][Linalg] Allow comprehensive bufferization to use callbacks for alloc/dealloc.

Using callbacks for allocation/deallocation allows users to override
the default behavior.
Also add an option to the comprehensive bufferization pass to use
`alloca` instead of `alloc`. Note that this option is for testing
purposes only; it does not work well with the option that allows
returning memrefs.

Differential Revision: https://reviews.llvm.org/D112166
MaheshRavishankar 2021-10-24 22:30:10 -07:00
parent 416fd03708
commit c86f218fe4
6 changed files with 288 additions and 108 deletions
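As an illustration of the new hook, a caller can swap in its own allocator before bufferizing. The sketch below is editorial, not part of the patch: it overrides the `AllocationCallbacks` struct introduced in this change to allocate with `memref.alloca`, and the lambda body (including the static-shape shortcut) is an assumption for brevity.

  AllocationCallbacks callbacks;
  callbacks.allocationFn = [](OpBuilder &b, Location loc,
                              Value shapedValue) -> Optional<Value> {
    // Assumes a statically shaped value; a real callback must also
    // materialize dynamic dimensions (cf. getAllocationTypeAndShape below).
    auto shapedType = shapedValue.getType().cast<ShapedType>();
    auto memRefType =
        MemRefType::get(shapedType.getShape(), shapedType.getElementType());
    return b.create<memref::AllocaOp>(loc, memRefType).getResult();
  };
  // Stack allocations are freed on function exit; nothing to deallocate.
  callbacks.deallocationFn = [](OpBuilder &b, Location loc, Value v) {};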


@@ -40,7 +40,10 @@ def LinalgComprehensiveModuleBufferize :
"Only runs inplaceability analysis (for testing purposes only)">,
Option<"allowReturnMemref", "allow-return-memref", "bool",
/*default=*/"false",
"Allows the return of memrefs (for testing purposes only)">
"Allows the return of memrefs (for testing purposes only)">,
Option<"useAlloca", "use-alloca", "bool",
/*default=*/"false",
"Use stack allocations for memrefs (for testing purposes only)">
];
let constructor = "mlir::createLinalgComprehensiveModuleBufferizePass()";
}


@@ -175,14 +175,36 @@ LogicalResult inPlaceAnalysis(SmallVector<Operation *> &ops,
BufferizationAliasInfo &aliasInfo,
const DominanceInfo &domInfo);
/// Default allocation function that is used by the comprehensive bufferization
/// pass. The default currently creates a ranked memref using `memref.alloc`.
Optional<Value> defaultAllocationFn(OpBuilder &b, Location loc,
Value shapedValue);
/// Default deallocation function that is used by the comprehensive
/// bufferization pass. It expects to receive the value returned by the
/// `defaultAllocationFn`.
void defaultDeallocationFn(OpBuilder &b, Location loc, Value allocatedBuffer);
/// Callback functions that are used by the comprehensive bufferization pass to
/// allocate/deallocate memory. These default to use the
/// `defaultAllocationFn`/`defaultDeallocationFn`, but can be overridden by the
/// caller. The `deallocationFn` is guaranteed to receive the `Value` returned
/// by the `allocationFn`.
struct AllocationCallbacks {
std::function<Optional<Value>(OpBuilder &b, Location loc, Value shapedValue)>
allocationFn = defaultAllocationFn;
std::function<void(OpBuilder &b, Location loc, Value v)> deallocationFn =
defaultDeallocationFn;
};
/// Bufferize one particular op.
/// `bufferizedFunctionTypes` (resp. `globalCreator`) are expected to be
/// non-null if `op` is a CallOpInterface (resp. GlobalCreator).
LogicalResult
bufferizeOp(Operation *op, BlockAndValueMapping &bvm,
BufferizationAliasInfo &aliasInfo,
DenseMap<FuncOp, FunctionType> *bufferizedFunctionTypes = nullptr,
GlobalCreator *globalCreator = nullptr);
AllocationCallbacks allocationFns,
DenseMap<FuncOp, FunctionType> *bufferizedFunctionTypes = nullptr);
} // namespace linalg
} // namespace mlir
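
Since both members of `AllocationCallbacks` default to the new free functions, a caller that wants the stock behavior only needs to thread a default-constructed struct through. A hedged sketch of an updated call site (the surrounding state — `op`, `bvm`, `aliasInfo`, `bufferizedFunctionTypes` — is assumed to exist as in the pass below):

  AllocationCallbacks allocationFns; // defaultAllocationFn / defaultDeallocationFn
  if (failed(bufferizeOp(op, bvm, aliasInfo, allocationFns,
                         &bufferizedFunctionTypes)))
    return failure();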


@@ -33,6 +33,7 @@ add_mlir_dialect_library(MLIRLinalgTransforms
MLIRAnalysis
MLIRArithmetic
MLIRComplex
MLIRInferTypeOpInterface
MLIRIR
MLIRMemRef
MLIRLinalgAnalysis


@@ -118,12 +118,12 @@
#include "mlir/Dialect/Vector/VectorOps.h"
#include "mlir/IR/AsmState.h"
#include "mlir/IR/Operation.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/BufferUtils.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Transforms/Passes.h"
#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/ScopeExit.h"
#include "llvm/ADT/SetVector.h"
@@ -983,7 +983,6 @@ bool BufferizationAliasInfo::hasReadAfterWriteInterference(
const DenseSet<OpOperand *> &usesRead,
const DenseSet<OpOperand *> &usesWrite,
const DominanceInfo &domInfo) const {
for (OpOperand *uRead : usesRead) {
Operation *readingOp = uRead->getOwner();
@@ -1415,66 +1414,27 @@ Operation *getFirstParentOfType(Value v) {
/// Create an AllocOp/DeallocOp pair, where the AllocOp is after
/// `shapedValue.getDefiningOp` (or at the top of the block in case of a
/// bbArg) and the DeallocOp is at the end of the block.
static Value
createNewAllocDeallocPairForShapedValue(OpBuilder &b, Location loc,
Value shapedValue,
BufferizationAliasInfo &aliasInfo) {
static Value createNewAllocDeallocPairForShapedValue(
OpBuilder &b, Location loc, Value shapedValue,
BufferizationAliasInfo &aliasInfo, AllocationCallbacks &allocationFns) {
// Take a guard before anything else.
OpBuilder::InsertionGuard g(b);
// TODO: non-zero address space.
// TODO: layout information if relevant.
// Cannot allocate an unranked memref so just always go for the contiguous
// form.
MemRefType allocMemRefType =
getContiguousMemRefType(shapedValue.getType().cast<ShapedType>());
assert(shapedValue.getType().isa<ShapedType>());
MemRefType memRefType = shapedValue.getType().dyn_cast<MemRefType>();
memRefType = memRefType ? memRefType : allocMemRefType;
if (auto bbArg = shapedValue.dyn_cast<BlockArgument>()) {
b.setInsertionPointToStart(bbArg.getOwner());
loc = bbArg.getOwner()->getParentOp()->getLoc();
} else {
b.setInsertionPoint(shapedValue.getDefiningOp());
loc = shapedValue.getDefiningOp()->getLoc();
Optional<Value> allocated = allocationFns.allocationFn(b, loc, shapedValue);
// TODO: For now just assert the value is returned. Eventually need to
// error-propagate.
assert(allocated && "allocation failed");
Value casted = allocated.getValue();
MemRefType allocMemRefType = allocated->getType().cast<MemRefType>();
if (memRefType && memRefType != allocMemRefType) {
casted = b.create<memref::CastOp>(loc, memRefType, allocated.getValue());
aliasInfo.insertNewBufferEquivalence(casted, allocated.getValue());
}
// Compute the dynamic part of the shape.
SmallVector<Value> dynShape;
for (auto dim : enumerate(memRefType.getShape()))
if (dim.value() == ShapedType::kDynamicSize)
dynShape.push_back(createOrFoldDimOp(b, loc, shapedValue, dim.index()));
// If the buffer is statically shaped, try to hoist it to the first enclosing
// parallel region.
// TODO: this concept of parallel region and threadlocal needs interfaces.
// TODO: also hoist in the dynamic case. For now this relies on subsequent
// calls to LICM and buffer hoisting which will most likely not succeed.
// TODO: when packing, allocate a static bounding box which will enable more
// hoisting.
Value allocated;
{ // Guarded insertion point to potentially hoist the AllocOp.
OpBuilder::InsertionGuard g(b);
if (dynShape.empty()) {
Operation *parent =
getFirstParentOfType<FuncOp, TiledLoopOp, scf::ParallelOp,
AffineParallelOp>(shapedValue);
if (parent)
b.setInsertionPointToStart(&(parent->getRegion(0).front()));
}
allocated = b.create<memref::AllocOp>(
loc, allocMemRefType, dynShape, b.getI64IntegerAttr(kBufferAlignments));
aliasInfo.createAliasInfoEntry(allocated);
}
Value casted = allocated;
if (memRefType != allocMemRefType) {
casted = b.create<memref::CastOp>(loc, memRefType, allocated);
aliasInfo.insertNewBufferEquivalence(casted, allocated);
}
b.setInsertionPoint(allocated.getParentBlock()->getTerminator());
b.create<memref::DeallocOp>(loc, allocated);
allocationFns.deallocationFn(b, loc, allocated.getValue());
return casted;
}
@@ -1488,6 +1448,7 @@ createNewAllocDeallocPairForShapedValue(OpBuilder &b, Location loc,
static Value getResultBuffer(OpBuilder &b, OpResult result,
const BlockAndValueMapping &bvm,
BufferizationAliasInfo &aliasInfo,
AllocationCallbacks allocationFns,
bool skipCopy = false) {
OpBuilder::InsertionGuard guard(b);
Operation *op = result.getOwner();
@@ -1515,8 +1476,8 @@ static Value getResultBuffer(OpBuilder &b, OpResult result,
"ops with multiple aliasing OpOperands cannot bufferize out-of-place");
Location loc = op->getLoc();
// Allocate the result buffer.
Value resultBuffer =
createNewAllocDeallocPairForShapedValue(b, loc, operand, aliasInfo);
Value resultBuffer = createNewAllocDeallocPairForShapedValue(
b, loc, operand, aliasInfo, allocationFns);
// Do not copy the result of an InitTensorOp.
if (isInitTensorOp(operand))
skipCopy = true;
@@ -1538,11 +1499,10 @@ static Value getResultBuffer(OpBuilder &b, OpResult result,
/// Helper function for LinalgOp bufferization.
/// When allocating a new buffer, analyze whether `op` wants to read from that
/// buffer. Only in that case, a copy of the result buffer may be needed.
static LogicalResult
allocateBuffersForResults(OpBuilder &b, Location loc, LinalgOp op,
SmallVectorImpl<Value> &resultBuffers,
BlockAndValueMapping &bvm,
BufferizationAliasInfo &aliasInfo) {
static LogicalResult allocateBuffersForResults(
OpBuilder &b, Location loc, LinalgOp op,
SmallVectorImpl<Value> &resultBuffers, BlockAndValueMapping &bvm,
BufferizationAliasInfo &aliasInfo, AllocationCallbacks &allocationFns) {
// Take a guard before anything else.
OpBuilder::InsertionGuard g(b);
b.setInsertionPoint(op);
@@ -1553,7 +1513,8 @@ allocateBuffersForResults(OpBuilder &b, Location loc, LinalgOp op,
OpResult opResult = getInplaceableOpResult(*opOperand);
assert(opResult && "could not find corresponding OpResult");
bool skipCopy = !op.payloadUsesValueFromOperand(opOperand);
Value resultBuffer = getResultBuffer(b, opResult, bvm, aliasInfo, skipCopy);
Value resultBuffer =
getResultBuffer(b, opResult, bvm, aliasInfo, allocationFns, skipCopy);
if (!resultBuffer)
return failure();
resultBuffers.push_back(resultBuffer);
@@ -1568,7 +1529,8 @@ allocateBuffersForResults(OpBuilder &b, Location loc, LinalgOp op,
/// Generic conversion for any LinalgOp on tensors.
static LogicalResult bufferize(OpBuilder &b, LinalgOp op,
BlockAndValueMapping &bvm,
BufferizationAliasInfo &aliasInfo) {
BufferizationAliasInfo &aliasInfo,
AllocationCallbacks &allocationFns) {
// Take a guard before anything else.
OpBuilder::InsertionGuard g(b);
@@ -1591,7 +1553,7 @@ static LogicalResult bufferize(OpBuilder &b, LinalgOp op,
SmallVector<Value> newOutputBuffers;
// Try to allocate new buffers depending on op's inplace semantics.
if (failed(allocateBuffersForResults(b, loc, op, newOutputBuffers, bvm,
aliasInfo)))
aliasInfo, allocationFns)))
return failure();
// Clone the newly bufferized op.
@@ -1616,7 +1578,7 @@ static LogicalResult bufferize(OpBuilder &b, LinalgOp op,
/// to allow FuncOp that are inplaceable to write inPlace.
static LogicalResult
bufferize(OpBuilder &b, CallOpInterface callOp, BlockAndValueMapping &bvm,
BufferizationAliasInfo &aliasInfo,
BufferizationAliasInfo &aliasInfo, AllocationCallbacks &allocationFns,
DenseMap<FuncOp, FunctionType> &bufferizedFunctionTypes) {
FuncOp funcOp = getCalledFunction(callOp);
assert(isa<CallOp>(callOp.getOperation()) && funcOp &&
@@ -1755,12 +1717,14 @@ bufferize(OpBuilder &b, CallOpInterface callOp, BlockAndValueMapping &bvm,
/// tensor::CastOp bufferizes to memref::CastOp.
static LogicalResult bufferize(OpBuilder &b, tensor::CastOp castOp,
BlockAndValueMapping &bvm,
BufferizationAliasInfo &aliasInfo) {
BufferizationAliasInfo &aliasInfo,
AllocationCallbacks &allocationFn) {
// Take a guard before anything else.
OpBuilder::InsertionGuard g(b);
b.setInsertionPoint(castOp);
Value resultBuffer = getResultBuffer(b, castOp->getResult(0), bvm, aliasInfo);
Value resultBuffer =
getResultBuffer(b, castOp->getResult(0), bvm, aliasInfo, allocationFn);
if (!resultBuffer)
return failure();
Type sourceType = resultBuffer.getType();
@@ -1786,10 +1750,15 @@ static LogicalResult bufferize(OpBuilder &b, tensor::CastOp castOp,
static LogicalResult bufferize(OpBuilder &b, arith::ConstantOp constantOp,
BlockAndValueMapping &bvm,
BufferizationAliasInfo &aliasInfo,
GlobalCreator &globalCreator) {
BufferizationAliasInfo &aliasInfo) {
assert(constantOp.getType().dyn_cast<RankedTensorType>() &&
"not a constant ranked tensor");
auto moduleOp = constantOp->getParentOfType<ModuleOp>();
if (!moduleOp) {
return constantOp.emitError(
"cannot bufferize constants not within builtin.module op");
}
GlobalCreator globalCreator(moduleOp);
// Take a guard before anything else.
OpBuilder::InsertionGuard g(b);
@@ -1824,7 +1793,8 @@ static LogicalResult bufferize(OpBuilder &b, tensor::DimOp dimOp,
static LogicalResult bufferize(OpBuilder &b, scf::ForOp forOp,
BlockAndValueMapping &bvm,
BufferizationAliasInfo &aliasInfo) {
BufferizationAliasInfo &aliasInfo,
AllocationCallbacks &allocationFn) {
// Take a guard before anything else.
OpBuilder::InsertionGuard g(b);
@@ -1837,7 +1807,8 @@ static LogicalResult bufferize(OpBuilder &b, scf::ForOp forOp,
"unsupported unranked tensor");
// TODO: More general: Matching bbArg does not bufferize to a read.
Value resultBuffer = getResultBuffer(b, opResult, bvm, aliasInfo);
Value resultBuffer =
getResultBuffer(b, opResult, bvm, aliasInfo, allocationFn);
if (!resultBuffer)
return failure();
@@ -1880,7 +1851,8 @@ static LogicalResult bufferize(OpBuilder &b, scf::IfOp ifOp,
/// FuncOp always creates TensorToMemRef ops.
static LogicalResult bufferize(OpBuilder &b, FuncOp funcOp,
BlockAndValueMapping &bvm,
BufferizationAliasInfo &aliasInfo) {
BufferizationAliasInfo &aliasInfo,
AllocationCallbacks &allocationFn) {
// Take a guard before anything else.
OpBuilder::InsertionGuard g(b);
b.setInsertionPointToStart(&funcOp.body().front());
@@ -1906,7 +1878,8 @@ static LogicalResult bufferize(OpBuilder &b, FuncOp funcOp,
/// TODO: consider hoisting across function boundaries prior to bufferization.
static LogicalResult bufferize(OpBuilder &b, InitTensorOp initTensorOp,
BlockAndValueMapping &bvm,
BufferizationAliasInfo &aliasInfo) {
BufferizationAliasInfo &aliasInfo,
AllocationCallbacks &allocationFn) {
// The InitTensorOp may have been eliminated.
if (initTensorOp->getUses().empty())
return success();
@@ -1916,7 +1889,8 @@ static LogicalResult bufferize(OpBuilder &b, InitTensorOp initTensorOp,
b.setInsertionPoint(initTensorOp);
Value alloc = createNewAllocDeallocPairForShapedValue(
b, initTensorOp->getLoc(), initTensorOp.result(), aliasInfo);
b, initTensorOp->getLoc(), initTensorOp.result(), aliasInfo,
allocationFn);
map(bvm, initTensorOp.result(), alloc);
return success();
}
@@ -1949,7 +1923,8 @@ static LogicalResult bufferize(OpBuilder &b, ReturnOp returnOp,
/// Bufferization for TiledLoopOp.
static LogicalResult bufferize(OpBuilder &b, TiledLoopOp tiledLoopOp,
BlockAndValueMapping &bvm,
BufferizationAliasInfo &aliasInfo) {
BufferizationAliasInfo &aliasInfo,
AllocationCallbacks &allocationFn) {
// Take a guard before anything else.
OpBuilder::InsertionGuard g(b);
@@ -1989,7 +1964,8 @@ static LogicalResult bufferize(OpBuilder &b, TiledLoopOp tiledLoopOp,
const OpResult &opResult = tiledLoopOp->getResult(resultIndex);
OpOperand &yieldOperand = yieldOp->getOpOperand(resultIndex);
Value resultBuffer = getResultBuffer(b, opResult, bvm, aliasInfo);
Value resultBuffer =
getResultBuffer(b, opResult, bvm, aliasInfo, allocationFn);
if (!resultBuffer)
return failure();
@@ -2073,7 +2049,8 @@ static LogicalResult bufferize(OpBuilder &b, TiledLoopOp tiledLoopOp,
/// isolation.
static LogicalResult bufferize(OpBuilder &b, ExtractSliceOp extractSliceOp,
BlockAndValueMapping &bvm,
BufferizationAliasInfo &aliasInfo) {
BufferizationAliasInfo &aliasInfo,
AllocationCallbacks &allocationFn) {
// Take a guard before anything else.
OpBuilder::InsertionGuard g(b);
@@ -2093,7 +2070,7 @@ static LogicalResult bufferize(OpBuilder &b, ExtractSliceOp extractSliceOp,
auto inPlace = getInPlace(extractSliceOp->getResult(0));
if (inPlace != InPlaceSpec::True)
alloc = createNewAllocDeallocPairForShapedValue(
b, loc, extractSliceOp.result(), aliasInfo);
b, loc, extractSliceOp.result(), aliasInfo, allocationFn);
// Set insertion point now that potential alloc/dealloc are introduced.
b.setInsertionPoint(extractSliceOp);
@@ -2125,7 +2102,8 @@ static LogicalResult bufferize(OpBuilder &b, ExtractSliceOp extractSliceOp,
static LogicalResult bufferize(OpBuilder &b, InsertSliceOp insertSliceOp,
BlockAndValueMapping &bvm,
BufferizationAliasInfo &aliasInfo) {
BufferizationAliasInfo &aliasInfo,
AllocationCallbacks &allocationFn) {
// Take a guard before anything else.
OpBuilder::InsertionGuard g(b);
b.setInsertionPoint(insertSliceOp);
@@ -2140,8 +2118,8 @@ static LogicalResult bufferize(OpBuilder &b, InsertSliceOp insertSliceOp,
// TODO: be very loud about it or even consider failing the pass.
// Alloc a copy for `insertSliceOp.dest()`, it will become the result
// buffer.
Value dstMemref =
getResultBuffer(b, insertSliceOp->getResult(0), bvm, aliasInfo);
Value dstMemref = getResultBuffer(b, insertSliceOp->getResult(0), bvm,
aliasInfo, allocationFn);
if (!dstMemref)
return failure();
auto dstMemrefType = dstMemref.getType().cast<MemRefType>();
@@ -2184,7 +2162,8 @@ static LogicalResult bufferize(OpBuilder &b, InsertSliceOp insertSliceOp,
static LogicalResult bufferize(OpBuilder &b, VectorTransferOpInterface op,
BlockAndValueMapping &bvm,
BufferizationAliasInfo &aliasInfo) {
BufferizationAliasInfo &aliasInfo,
AllocationCallbacks &allocationFn) {
// Take a guard before anything else.
OpBuilder::InsertionGuard g(b);
b.setInsertionPoint(op);
@@ -2205,7 +2184,8 @@ static LogicalResult bufferize(OpBuilder &b, VectorTransferOpInterface op,
// Leave the previous transfer_write to dead code as it still has uses at
// this point.
auto writeOp = cast<vector::TransferWriteOp>(op.getOperation());
Value resultBuffer = getResultBuffer(b, op->getResult(0), bvm, aliasInfo);
Value resultBuffer =
getResultBuffer(b, op->getResult(0), bvm, aliasInfo, allocationFn);
if (!resultBuffer)
return failure();
b.create<vector::TransferWriteOp>(
@@ -2436,18 +2416,107 @@ inPlaceAnalysisFuncOpBody(FuncOp funcOp, BufferizationAliasInfo &aliasInfo,
// Bufferization entry-point for functions.
//===----------------------------------------------------------------------===//
/// Compute the type of the `memref` to use for allocating the buffer for
/// `shapedValue`. Also returns (by reference in `dynShape`) the values for the
/// dynamic dimensions in the returned `memref` type. The function also sets the
/// insertion point of the builder `b` to the position where the allocation is
/// to be inserted.
static MemRefType getAllocationTypeAndShape(OpBuilder &b, Location loc,
Value shapedValue,
SmallVectorImpl<Value> &dynShape) {
MemRefType allocMemRefType =
getContiguousMemRefType(shapedValue.getType().cast<ShapedType>());
if (auto bbArg = shapedValue.dyn_cast<BlockArgument>()) {
b.setInsertionPointToStart(bbArg.getOwner());
loc = bbArg.getOwner()->getParentOp()->getLoc();
} else {
b.setInsertionPoint(shapedValue.getDefiningOp());
loc = shapedValue.getDefiningOp()->getLoc();
}
// Compute the dynamic part of the shape.
bool foundDynamicShapes = false;
if (auto rankedOp = dyn_cast_or_null<ReifyRankedShapedTypeOpInterface>(
shapedValue.getDefiningOp())) {
ReifiedRankedShapedTypeDims resultDims;
if (succeeded(rankedOp.reifyResultShapes(b, resultDims))) {
foundDynamicShapes = true;
OpResult resultValue = shapedValue.dyn_cast<OpResult>();
auto &shape = resultDims[resultValue.getResultNumber()];
for (auto dim : enumerate(allocMemRefType.getShape()))
if (dim.value() == ShapedType::kDynamicSize)
dynShape.push_back(shape[dim.index()]);
}
}
if (!foundDynamicShapes) {
for (auto dim : enumerate(allocMemRefType.getShape()))
if (dim.value() == ShapedType::kDynamicSize)
dynShape.push_back(createOrFoldDimOp(b, loc, shapedValue, dim.index()));
}
// If the buffer is statically shaped, try to hoist it to the first enclosing
// parallel region.
// TODO: this concept of parallel region and threadlocal needs interfaces.
// TODO: also hoist in the dynamic case. For now this relies on subsequent
// calls to LICM and buffer hoisting which will most likely not succeed.
// TODO: when packing, allocate a static bounding box which will enable more
// hoisting.
if (dynShape.empty()) {
Operation *parent =
getFirstParentOfType<FuncOp, TiledLoopOp, scf::ParallelOp,
AffineParallelOp>(shapedValue);
if (parent)
b.setInsertionPointToStart(&(parent->getRegion(0).front()));
}
return allocMemRefType;
}
Optional<Value> mlir::linalg::defaultAllocationFn(OpBuilder &b, Location loc,
Value shapedValue) {
// Take a guard before anything else.
OpBuilder::InsertionGuard g(b);
SmallVector<Value> dynShape;
MemRefType allocMemRefType =
getAllocationTypeAndShape(b, loc, shapedValue, dynShape);
Value allocated = b.create<memref::AllocOp>(
loc, allocMemRefType, dynShape, b.getI64IntegerAttr(kBufferAlignments));
return allocated;
}
static Optional<Value> allocationFnUsingAlloca(OpBuilder &b, Location loc,
Value shapedValue) {
OpBuilder::InsertionGuard g(b);
SmallVector<Value> dynShape;
MemRefType allocMemRefType =
getAllocationTypeAndShape(b, loc, shapedValue, dynShape);
Value allocated = b.create<memref::AllocaOp>(
loc, allocMemRefType, dynShape, b.getI64IntegerAttr(kBufferAlignments));
return allocated;
}
void mlir::linalg::defaultDeallocationFn(OpBuilder &b, Location loc,
Value allocatedBuffer) {
OpBuilder::InsertionGuard g(b);
b.setInsertionPoint(allocatedBuffer.getParentBlock()->getTerminator());
b.create<memref::DeallocOp>(loc, allocatedBuffer);
}
LogicalResult mlir::linalg::bufferizeOp(
Operation *op, BlockAndValueMapping &bvm, BufferizationAliasInfo &aliasInfo,
DenseMap<FuncOp, FunctionType> *bufferizedFunctionTypes,
GlobalCreator *globalCreator) {
AllocationCallbacks allocationFns,
DenseMap<FuncOp, FunctionType> *bufferizedFunctionTypes) {
OpBuilder b(op->getContext());
return TypeSwitch<Operation *, LogicalResult>(op)
// Skip BufferCast and TensorLoad ops.
.Case<memref::BufferCastOp, memref::TensorLoadOp>(
[&](auto) { return success(); })
.Case<tensor::CastOp, tensor::DimOp, ExtractSliceOp, scf::ForOp,
InitTensorOp, InsertSliceOp, tensor::ExtractOp, LinalgOp, ReturnOp,
TiledLoopOp, VectorTransferOpInterface, linalg::YieldOp,
.Case<ExtractSliceOp, InitTensorOp, InsertSliceOp, LinalgOp, scf::ForOp,
tensor::CastOp, TiledLoopOp, VectorTransferOpInterface>(
[&](auto op) {
LDBG("Begin bufferize:\n" << op << '\n');
return bufferize(b, op, bvm, aliasInfo, allocationFns);
})
.Case<tensor::DimOp, tensor::ExtractOp, ReturnOp, linalg::YieldOp,
scf::YieldOp>([&](auto op) {
LDBG("Begin bufferize:\n" << op << '\n');
return bufferize(b, op, bvm, aliasInfo);
@@ -2464,15 +2533,14 @@ LogicalResult mlir::linalg::bufferizeOp(
if (!bufferizedFunctionTypes)
llvm_unreachable(
"null bufferizedFunctionTypes when bufferizing CallOpInterface");
return bufferize(b, op, bvm, aliasInfo, *bufferizedFunctionTypes);
return bufferize(b, op, bvm, aliasInfo, allocationFns,
*bufferizedFunctionTypes);
})
.Case([&](arith::ConstantOp op) {
if (!isaTensor(op.getResult().getType()))
return success();
LDBG("Begin bufferize:\n" << op << '\n');
if (!globalCreator)
llvm_unreachable("null globalCreator when bufferizing ConstantOp");
return bufferize(b, op, bvm, aliasInfo, *globalCreator);
return bufferize(b, op, bvm, aliasInfo);
})
.Default([&](Operation *op) -> LogicalResult {
auto isaTensor = [](Type t) { return t.isa<TensorType>(); };
@@ -2485,15 +2553,13 @@ LogicalResult mlir::linalg::bufferizeOp(
static LogicalResult bufferizeFuncOpInternals(
FuncOp funcOp, BlockAndValueMapping &bvm, BufferizationAliasInfo &aliasInfo,
DenseMap<FuncOp, FunctionType> &bufferizedFunctionTypes,
GlobalCreator &globalCreator) {
AllocationCallbacks &allocationFns,
DenseMap<FuncOp, FunctionType> &bufferizedFunctionTypes) {
LLVM_DEBUG(llvm::dbgs() << "\n\n");
LDBG("Begin BufferizeFuncOpInternals:\n" << funcOp << '\n');
OpBuilder b(funcOp->getContext());
// Start by bufferizing `funcOp` arguments.
if (failed(bufferize(b, funcOp, bvm, aliasInfo)))
/// Start by bufferizing `funcOp` arguments.
if (failed(bufferize(b, funcOp, bvm, aliasInfo, allocationFns)))
return failure();
// Cannot erase ops during the traversal. Do that afterwards.
@@ -2516,13 +2582,13 @@ static LogicalResult bufferizeFuncOpInternals(
}
for (Operation *op : llvm::reverse(preorderBufferize))
if (failed(bufferizeOp(op, bvm, aliasInfo, &bufferizedFunctionTypes,
&globalCreator)))
if (failed(bufferizeOp(op, bvm, aliasInfo, allocationFns,
&bufferizedFunctionTypes)))
return failure();
if (!bufferizedOps.contains(op) &&
failed(bufferizeOp(op, bvm, aliasInfo, &bufferizedFunctionTypes,
&globalCreator)))
failed(bufferizeOp(op, bvm, aliasInfo, allocationFns,
&bufferizedFunctionTypes)))
return failure();
// Register post-walk erasure, if necessary.
@@ -2793,12 +2859,19 @@ namespace {
struct LinalgComprehensiveModuleBufferize
: public LinalgComprehensiveModuleBufferizeBase<
LinalgComprehensiveModuleBufferize> {
LinalgComprehensiveModuleBufferize() {}
LinalgComprehensiveModuleBufferize(
const LinalgComprehensiveModuleBufferize &p) {}
void runOnOperation() override;
void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<linalg::LinalgDialect, memref::MemRefDialect>();
}
private:
std::unique_ptr<AllocationCallbacks> allocationFns;
};
} // end namespace
@@ -2983,6 +3056,22 @@ static LogicalResult runInitTensorElimination(FuncOp funcOp,
}
void LinalgComprehensiveModuleBufferize::runOnOperation() {
if (!allocationFns) {
// The allocation functions to use need to be set here. The flag for the
// pass and the flag for the use of alloca map to LLVM command line
// options, which are static global objects with no guaranteed order of
// initialization. Ideally this would be done in the constructor, but the
// constructor might be called before the flag is initialized from the
// command line option, so it is set up at the start of the pass.
if (useAlloca) {
AllocationCallbacks allocaAllocationFns = {
allocationFnUsingAlloca, [](OpBuilder &b, Location loc, Value v) {}};
allocationFns =
std::make_unique<AllocationCallbacks>(std::move(allocaAllocationFns));
} else {
allocationFns = std::make_unique<AllocationCallbacks>();
}
}
ModuleOp moduleOp = getOperation();
applyEnablingTransformations(moduleOp);
@@ -2992,7 +3081,6 @@ void LinalgComprehensiveModuleBufferize::runOnOperation() {
if (failed(getFuncOpsOrderedByCalls(moduleOp, orderedFuncOps, callerMap)))
return signalPassFailure();
GlobalCreator globalCreator(moduleOp);
DominanceInfo domInfo(moduleOp);
BufferizationAliasInfo aliasInfo(moduleOp);
// Interestingly, all function args that are not visible outside of a module
@@ -3032,8 +3120,8 @@ void LinalgComprehensiveModuleBufferize::runOnOperation() {
if (!testAnalysisOnly) {
BlockAndValueMapping tensorToBufferMap;
if (failed(bufferizeFuncOpInternals(funcOp, tensorToBufferMap, aliasInfo,
bufferizedFunctionTypes,
globalCreator))) {
*allocationFns,
bufferizedFunctionTypes))) {
signalPassFailure();
return;
}


@@ -0,0 +1,65 @@
// RUN: mlir-opt %s -pass-pipeline="linalg-comprehensive-module-bufferize{allow-return-memref use-alloca}" -split-input-file | FileCheck %s
// CHECK-DAG: #[[$DYN_0D_MAP:.*]] = affine_map<()[s0] -> (s0)>
// CHECK-DAG: #[[$DYN_1D_MAP:.*]] = affine_map<(d0)[s0, s1] -> (d0 * s1 + s0)>
// CHECK: func @init_and_dot(
// CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: memref<64xf32, #[[$DYN_1D_MAP]]>
// CHECK-SAME: %[[B:[a-zA-Z0-9]*]]: memref<64xf32, #[[$DYN_1D_MAP]]>
// CHECK-SAME: %[[C:[a-zA-Z0-9]*]]: memref<f32, #[[$DYN_0D_MAP]]>
func @init_and_dot(%a: tensor<64xf32>, %b: tensor<64xf32>, %c: tensor<f32>) -> tensor<f32> {
// CHECK-NEXT: %[[C0:.*]] = arith.constant 0{{.*}} : f32
%v0 = arith.constant 0.0 : f32
// CHECK-NEXT: linalg.fill(%[[C0]], %[[C]]) : f32, memref<f32, #[[$DYN_0D_MAP]]>
%d = linalg.fill(%v0, %c) : f32, tensor<f32> -> tensor<f32>
// CHECK-NEXT: linalg.dot ins(%[[A]], %[[B]] : memref<64xf32, #[[$DYN_1D_MAP]]>, memref<64xf32, #[[$DYN_1D_MAP]]>) outs(%[[C]] : memref<f32, #[[$DYN_0D_MAP]]>)
%e = linalg.dot ins(%a, %b : tensor<64xf32>,tensor<64xf32>)
outs(%d: tensor<f32>) -> tensor<f32>
// CHECK-NEXT: return
return %e : tensor<f32>
}
// CHECK: func @main()
func @main() {
// CHECK-DAG: %[[C0:.*]] = arith.constant 0{{.*}} : f32
// CHECK-DAG: %[[C1:.*]] = arith.constant 1{{.*}} : f32
// CHECK-DAG: %[[C2:.*]] = arith.constant 2{{.*}} : f32
%v0 = arith.constant 0.0 : f32
%v1 = arith.constant 1.0 : f32
%v2 = arith.constant 2.0 : f32
// CHECK-NEXT: %[[C:.*]] = memref.alloca() {alignment = 128 : i64} : memref<f32>
// CHECK-NEXT: %[[B:.*]] = memref.alloca() {alignment = 128 : i64} : memref<64xf32>
// CHECK-NEXT: %[[A:.*]] = memref.alloca() {alignment = 128 : i64} : memref<64xf32>
%A = linalg.init_tensor [64] : tensor<64xf32>
%B = linalg.init_tensor [64] : tensor<64xf32>
%C = linalg.init_tensor [] : tensor<f32>
// CHECK-NEXT: linalg.fill(%[[C1]], %[[A]]) : f32, memref<64xf32>
// CHECK-NEXT: linalg.fill(%[[C2]], %[[B]]) : f32, memref<64xf32>
// CHECK-NEXT: linalg.fill(%[[C0]], %[[C]]) : f32, memref<f32>
%AA = linalg.fill(%v1, %A) : f32, tensor<64xf32> -> tensor<64xf32>
%BB = linalg.fill(%v2, %B) : f32, tensor<64xf32> -> tensor<64xf32>
%CC = linalg.fill(%v0, %C) : f32, tensor<f32> -> tensor<f32>
// CHECK-NEXT: %[[cA:.*]] = memref.cast %[[A]] : memref<64xf32> to memref<64xf32, #[[$DYN_1D_MAP]]>
// CHECK-NEXT: %[[cB:.*]] = memref.cast %[[B]] : memref<64xf32> to memref<64xf32, #[[$DYN_1D_MAP]]>
// CHECK-NEXT: %[[cC:.*]] = memref.cast %[[C]] : memref<f32> to memref<f32, #[[$DYN_0D_MAP]]>
// CHECK-NEXT: call @init_and_dot(%[[cA]], %[[cB]], %[[cC]])
%res = call @init_and_dot(%AA, %BB, %CC) :
(tensor<64xf32>, tensor<64xf32>, tensor<f32>) -> tensor<f32>
// CHECK-NEXT: %[[dC:.*]] = memref.cast %[[C]] : memref<f32> to memref<*xf32>
%res2 = tensor.cast %res: tensor<f32> to tensor<*xf32>
// CHECK-NEXT: call @print_memref_f32(%[[dC]]) : (memref<*xf32>) -> ()
call @print_memref_f32(%res2) : (tensor<*xf32>) -> ()
return
}
// CHECK: func private @print_memref_f32(memref<*xf32>)
func private @print_memref_f32(tensor<*xf32>)


@@ -6314,6 +6314,7 @@ cc_library(
":ComplexDialect",
":DialectUtils",
":IR",
":InferTypeOpInterface",
":LinalgOps",
":LinalgPassIncGen",
":LinalgStructuredOpsIncGen",