[mlir][Linalg] Allow comprehensive bufferization to use callbacks for alloc/dealloc.
Using callbacks for allocation/deallocation allows users to override the default behavior. This change also adds an option to the comprehensive bufferization pass to use `alloca` instead of `alloc` when creating buffers. Note that this option is intended for testing only: it does not work well with the option that allows returning memrefs.

Differential Revision: https://reviews.llvm.org/D112166
parent 416fd03708
commit c86f218fe4
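Since the callbacks are plain `std::function` members with defaults, downstream code can swap in its own allocation strategy without modifying the pass. A minimal sketch against the API added in this commit (the wrapper function, allocator body, and include path are illustrative assumptions, not part of the change):

    // Sketch only: overriding the allocation callback. Assumes the
    // declarations added in ComprehensiveBufferize.h (diffed below) are
    // visible via this (assumed) header path.
    #include "mlir/Dialect/Linalg/Transforms/ComprehensiveBufferize.h"

    using namespace mlir;
    using namespace mlir::linalg;

    // Hypothetical custom allocator; here it simply defers to the default,
    // but it could create any op that produces a memref instead.
    static Optional<Value> myAllocationFn(OpBuilder &b, Location loc,
                                          Value shapedValue) {
      return defaultAllocationFn(b, loc, shapedValue);
    }

    static LogicalResult
    bufferizeWithCustomAlloc(Operation *op, BlockAndValueMapping &bvm,
                             BufferizationAliasInfo &aliasInfo) {
      AllocationCallbacks callbacks;
      callbacks.allocationFn = myAllocationFn;
      // deallocationFn keeps its default (defaultDeallocationFn).
      return bufferizeOp(op, bvm, aliasInfo, callbacks);
    }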
mlir/include/mlir/Dialect/Linalg/Passes.td
@@ -40,7 +40,10 @@ def LinalgComprehensiveModuleBufferize :
            "Only runs inplaceability analysis (for testing purposes only)">,
     Option<"allowReturnMemref", "allow-return-memref", "bool",
            /*default=*/"false",
-           "Allows the return of memrefs (for testing purposes only)">
+           "Allows the return of memrefs (for testing purposes only)">,
+    Option<"useAlloca", "use-alloca", "bool",
+           /*default=*/"false",
+           "Use stack allocations for memrefs (for testing purposes only)">
   ];
   let constructor = "mlir::createLinalgComprehensiveModuleBufferizePass()";
 }
mlir/include/mlir/Dialect/Linalg/Transforms/ComprehensiveBufferize.h
@@ -175,14 +175,36 @@ LogicalResult inPlaceAnalysis(SmallVector<Operation *> &ops,
                               BufferizationAliasInfo &aliasInfo,
                               const DominanceInfo &domInfo);

+/// Default allocation function that is used by the comprehensive bufferization
+/// pass. The default currently creates a ranked memref using `memref.alloc`.
+Optional<Value> defaultAllocationFn(OpBuilder &b, Location loc,
+                                    Value shapedValue);
+
+/// Default deallocation function that is used by the comprehensive
+/// bufferization pass. It expects to receive back the value returned by the
+/// `defaultAllocationFn`.
+void defaultDeallocationFn(OpBuilder &b, Location loc, Value allocatedBuffer);
+
+/// Callback functions that are used by the comprehensive bufferization pass to
+/// allocate/deallocate memory. These default to use the
+/// `defaultAllocationFn`/`defaultDeallocationFn`, but can be overridden by the
+/// caller. The `deallocationFn` is guaranteed to receive the `Value` returned
+/// by the `allocationFn`.
+struct AllocationCallbacks {
+  std::function<Optional<Value>(OpBuilder &b, Location loc, Value shapedValue)>
+      allocationFn = defaultAllocationFn;
+  std::function<void(OpBuilder &b, Location loc, Value v)> deallocationFn =
+      defaultDeallocationFn;
+};
+
 /// Bufferize one particular op.
 /// `bufferizedFunctionTypes` (resp. `globalCreator`) are expected to be
 /// non-null if `op` is a CallOpInterface (resp. GlobalCreator).
 LogicalResult
 bufferizeOp(Operation *op, BlockAndValueMapping &bvm,
             BufferizationAliasInfo &aliasInfo,
-            DenseMap<FuncOp, FunctionType> *bufferizedFunctionTypes = nullptr,
-            GlobalCreator *globalCreator = nullptr);
+            AllocationCallbacks allocationFns,
+            DenseMap<FuncOp, FunctionType> *bufferizedFunctionTypes = nullptr);

 } // namespace linalg
 } // namespace mlir
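One property of the struct documented above is load-bearing: `deallocationFn` is guaranteed to receive exactly the `Value` returned by `allocationFn`, so a callback pair may keep private bookkeeping across the two calls. A sketch that leans on that guarantee (the tracking set and wrapper are hypothetical; same includes/usings as the sketch above):

    // Sketch: a stateful callback pair; sound only because deallocationFn
    // is guaranteed to see the Value produced by allocationFn.
    static void bufferizeWithTracking(Operation *op, BlockAndValueMapping &bvm,
                                      BufferizationAliasInfo &aliasInfo) {
      llvm::DenseSet<Value> trackedAllocs;
      AllocationCallbacks cb;
      cb.allocationFn = [&](OpBuilder &b, Location loc,
                            Value shapedValue) -> Optional<Value> {
        Optional<Value> buffer = defaultAllocationFn(b, loc, shapedValue);
        if (buffer)
          trackedAllocs.insert(*buffer); // remember what we handed out
        return buffer;
      };
      cb.deallocationFn = [&](OpBuilder &b, Location loc, Value v) {
        assert(trackedAllocs.count(v) && "dealloc of an untracked buffer");
        defaultDeallocationFn(b, loc, v);
      };
      (void)bufferizeOp(op, bvm, aliasInfo, cb);
    }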
mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
@@ -33,6 +33,7 @@ add_mlir_dialect_library(MLIRLinalgTransforms
   MLIRAnalysis
   MLIRArithmetic
   MLIRComplex
+  MLIRInferTypeOpInterface
   MLIRIR
   MLIRMemRef
   MLIRLinalgAnalysis
mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp
@@ -118,12 +118,12 @@
 #include "mlir/Dialect/Vector/VectorOps.h"
 #include "mlir/IR/AsmState.h"
 #include "mlir/IR/Operation.h"
+#include "mlir/Interfaces/InferTypeOpInterface.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Pass/PassManager.h"
 #include "mlir/Transforms/BufferUtils.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 #include "mlir/Transforms/Passes.h"

 #include "llvm/ADT/DenseSet.h"
 #include "llvm/ADT/ScopeExit.h"
 #include "llvm/ADT/SetVector.h"

@@ -983,7 +983,6 @@ bool BufferizationAliasInfo::hasReadAfterWriteInterference(
     const DenseSet<OpOperand *> &usesRead,
     const DenseSet<OpOperand *> &usesWrite,
     const DominanceInfo &domInfo) const {
-
   for (OpOperand *uRead : usesRead) {
     Operation *readingOp = uRead->getOwner();

@@ -1415,66 +1414,27 @@ Operation *getFirstParentOfType(Value v) {
 /// Create an AllocOp/DeallocOp pair, where the AllocOp is after
 /// `shapedValue.getDefiningOp` (or at the top of the block in case of a
 /// bbArg) and the DeallocOp is at the end of the block.
-static Value
-createNewAllocDeallocPairForShapedValue(OpBuilder &b, Location loc,
-                                        Value shapedValue,
-                                        BufferizationAliasInfo &aliasInfo) {
+static Value createNewAllocDeallocPairForShapedValue(
+    OpBuilder &b, Location loc, Value shapedValue,
+    BufferizationAliasInfo &aliasInfo, AllocationCallbacks &allocationFns) {
   // Take a guard before anything else.
   OpBuilder::InsertionGuard g(b);

-  // TODO: non-zero address space.
-  // TODO: layout information if relevant.
-  // Cannot allocate an unranked memref so just always go for the contiguous
-  // form.
-  MemRefType allocMemRefType =
-      getContiguousMemRefType(shapedValue.getType().cast<ShapedType>());
   assert(shapedValue.getType().isa<ShapedType>());
   MemRefType memRefType = shapedValue.getType().dyn_cast<MemRefType>();
-  memRefType = memRefType ? memRefType : allocMemRefType;

-  if (auto bbArg = shapedValue.dyn_cast<BlockArgument>()) {
-    b.setInsertionPointToStart(bbArg.getOwner());
-    loc = bbArg.getOwner()->getParentOp()->getLoc();
-  } else {
-    b.setInsertionPoint(shapedValue.getDefiningOp());
-    loc = shapedValue.getDefiningOp()->getLoc();
+  Optional<Value> allocated = allocationFns.allocationFn(b, loc, shapedValue);
+  // TODO: For now just assert the value is returned. Eventually need to
+  // error-propagate.
+  assert(allocated && "allocation failed");
+  Value casted = allocated.getValue();
+  MemRefType allocMemRefType = allocated->getType().cast<MemRefType>();
+  if (memRefType && memRefType != allocMemRefType) {
+    casted = b.create<memref::CastOp>(loc, memRefType, allocated.getValue());
+    aliasInfo.insertNewBufferEquivalence(casted, allocated.getValue());
   }

-  // Compute the dynamic part of the shape.
-  SmallVector<Value> dynShape;
-  for (auto dim : enumerate(memRefType.getShape()))
-    if (dim.value() == ShapedType::kDynamicSize)
-      dynShape.push_back(createOrFoldDimOp(b, loc, shapedValue, dim.index()));
-
-  // If the buffer is statically shaped, try to hoist it to the first enclosing
-  // parallel region.
-  // TODO: this concept of parallel region and threadlocal needs interfaces.
-  // TODO: also hoist in the dynamic case. For now this relies on subsequent
-  // calls to LICM and buffer hoisting which will most likely not succeed.
-  // TODO: when packing, allocate a static bounding box which will enable more
-  // hoisting.
-  Value allocated;
-  { // Guarded insertion point to potentially hoist the AllocOp.
-    OpBuilder::InsertionGuard g(b);
-    if (dynShape.empty()) {
-      Operation *parent =
-          getFirstParentOfType<FuncOp, TiledLoopOp, scf::ParallelOp,
-                               AffineParallelOp>(shapedValue);
-      if (parent)
-        b.setInsertionPointToStart(&(parent->getRegion(0).front()));
-    }
-    allocated = b.create<memref::AllocOp>(
-        loc, allocMemRefType, dynShape, b.getI64IntegerAttr(kBufferAlignments));
-    aliasInfo.createAliasInfoEntry(allocated);
-  }
-  Value casted = allocated;
-  if (memRefType != allocMemRefType) {
-    casted = b.create<memref::CastOp>(loc, memRefType, allocated);
-    aliasInfo.insertNewBufferEquivalence(casted, allocated);
-  }
-  b.setInsertionPoint(allocated.getParentBlock()->getTerminator());
-  b.create<memref::DeallocOp>(loc, allocated);

+  allocationFns.deallocationFn(b, loc, allocated.getValue());
   return casted;
 }
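Note the error-propagation TODO in the new code above: `allocationFn` returns `Optional<Value>`, but a `None` result currently trips an assert rather than failing gracefully. Until that is wired up, a callback that can fail should emit its own diagnostic first; a hedged sketch (the function is hypothetical, same usings as the earlier sketches):

    // Sketch: a fallible allocation callback. The pass currently asserts
    // on None (see the TODO above), so diagnose before returning it.
    static Optional<Value> checkedAllocationFn(OpBuilder &b, Location loc,
                                               Value shapedValue) {
      auto shapedType = shapedValue.getType().dyn_cast<ShapedType>();
      if (!shapedType || !shapedType.hasRank()) {
        emitError(loc, "cannot allocate a buffer for an unranked value");
        return llvm::None;
      }
      return defaultAllocationFn(b, loc, shapedValue);
    }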
mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp (continued)
@@ -1488,6 +1448,7 @@ createNewAllocDeallocPairForShapedValue(OpBuilder &b, Location loc,
 static Value getResultBuffer(OpBuilder &b, OpResult result,
                              const BlockAndValueMapping &bvm,
                              BufferizationAliasInfo &aliasInfo,
+                             AllocationCallbacks allocationFns,
                              bool skipCopy = false) {
   OpBuilder::InsertionGuard guard(b);
   Operation *op = result.getOwner();

@@ -1515,8 +1476,8 @@ static Value getResultBuffer(OpBuilder &b, OpResult result,
          "ops with multiple aliasing OpOperands cannot bufferize out-of-place");
   Location loc = op->getLoc();
   // Allocate the result buffer.
-  Value resultBuffer =
-      createNewAllocDeallocPairForShapedValue(b, loc, operand, aliasInfo);
+  Value resultBuffer = createNewAllocDeallocPairForShapedValue(
+      b, loc, operand, aliasInfo, allocationFns);
   // Do not copy the result of an InitTensorOp.
   if (isInitTensorOp(operand))
     skipCopy = true;

@@ -1538,11 +1499,10 @@ static Value getResultBuffer(OpBuilder &b, OpResult result,
 /// Helper function for LinalgOp bufferization.
 /// When allocating a new buffer, analyze whether `op` wants to read from that
 /// buffer. Only in that case, a copy of the result buffer may be needed.
-static LogicalResult
-allocateBuffersForResults(OpBuilder &b, Location loc, LinalgOp op,
-                          SmallVectorImpl<Value> &resultBuffers,
-                          BlockAndValueMapping &bvm,
-                          BufferizationAliasInfo &aliasInfo) {
+static LogicalResult allocateBuffersForResults(
+    OpBuilder &b, Location loc, LinalgOp op,
+    SmallVectorImpl<Value> &resultBuffers, BlockAndValueMapping &bvm,
+    BufferizationAliasInfo &aliasInfo, AllocationCallbacks &allocationFns) {
   // Take a guard before anything else.
   OpBuilder::InsertionGuard g(b);
   b.setInsertionPoint(op);

@@ -1553,7 +1513,8 @@ allocateBuffersForResults(OpBuilder &b, Location loc, LinalgOp op,
     OpResult opResult = getInplaceableOpResult(*opOperand);
     assert(opResult && "could not find corresponding OpResult");
     bool skipCopy = !op.payloadUsesValueFromOperand(opOperand);
-    Value resultBuffer = getResultBuffer(b, opResult, bvm, aliasInfo, skipCopy);
+    Value resultBuffer =
+        getResultBuffer(b, opResult, bvm, aliasInfo, allocationFns, skipCopy);
     if (!resultBuffer)
       return failure();
     resultBuffers.push_back(resultBuffer);

@@ -1568,7 +1529,8 @@ allocateBuffersForResults(OpBuilder &b, Location loc, LinalgOp op,
 /// Generic conversion for any LinalgOp on tensors.
 static LogicalResult bufferize(OpBuilder &b, LinalgOp op,
                                BlockAndValueMapping &bvm,
-                               BufferizationAliasInfo &aliasInfo) {
+                               BufferizationAliasInfo &aliasInfo,
+                               AllocationCallbacks &allocationFns) {
   // Take a guard before anything else.
   OpBuilder::InsertionGuard g(b);

@@ -1591,7 +1553,7 @@ static LogicalResult bufferize(OpBuilder &b, LinalgOp op,
   SmallVector<Value> newOutputBuffers;
   // Try to allocate new buffers depending on op's inplace semantics.
   if (failed(allocateBuffersForResults(b, loc, op, newOutputBuffers, bvm,
-                                       aliasInfo)))
+                                       aliasInfo, allocationFns)))
     return failure();

   // Clone the newly bufferized op.

@@ -1616,7 +1578,7 @@ static LogicalResult bufferize(OpBuilder &b, LinalgOp op,
 /// to allow FuncOp that are inplaceable to write inPlace.
 static LogicalResult
 bufferize(OpBuilder &b, CallOpInterface callOp, BlockAndValueMapping &bvm,
-          BufferizationAliasInfo &aliasInfo,
+          BufferizationAliasInfo &aliasInfo, AllocationCallbacks &allocationFns,
           DenseMap<FuncOp, FunctionType> &bufferizedFunctionTypes) {
   FuncOp funcOp = getCalledFunction(callOp);
   assert(isa<CallOp>(callOp.getOperation()) && funcOp &&

@@ -1755,12 +1717,14 @@ bufferize(OpBuilder &b, CallOpInterface callOp, BlockAndValueMapping &bvm,
 /// tensor::CastOp bufferizes to memref::CastOp.
 static LogicalResult bufferize(OpBuilder &b, tensor::CastOp castOp,
                                BlockAndValueMapping &bvm,
-                               BufferizationAliasInfo &aliasInfo) {
+                               BufferizationAliasInfo &aliasInfo,
+                               AllocationCallbacks &allocationFn) {
   // Take a guard before anything else.
   OpBuilder::InsertionGuard g(b);
   b.setInsertionPoint(castOp);

-  Value resultBuffer = getResultBuffer(b, castOp->getResult(0), bvm, aliasInfo);
+  Value resultBuffer =
+      getResultBuffer(b, castOp->getResult(0), bvm, aliasInfo, allocationFn);
   if (!resultBuffer)
     return failure();
   Type sourceType = resultBuffer.getType();

@@ -1786,10 +1750,15 @@ static LogicalResult bufferize(OpBuilder &b, tensor::CastOp castOp,

 static LogicalResult bufferize(OpBuilder &b, arith::ConstantOp constantOp,
                                BlockAndValueMapping &bvm,
-                               BufferizationAliasInfo &aliasInfo,
-                               GlobalCreator &globalCreator) {
+                               BufferizationAliasInfo &aliasInfo) {
   assert(constantOp.getType().dyn_cast<RankedTensorType>() &&
          "not a constant ranked tensor");
+  auto moduleOp = constantOp->getParentOfType<ModuleOp>();
+  if (!moduleOp) {
+    return constantOp.emitError(
+        "cannot bufferize constants not within builtin.module op");
+  }
+  GlobalCreator globalCreator(moduleOp);

   // Take a guard before anything else.
   OpBuilder::InsertionGuard g(b);

@@ -1824,7 +1793,8 @@ static LogicalResult bufferize(OpBuilder &b, tensor::DimOp dimOp,

 static LogicalResult bufferize(OpBuilder &b, scf::ForOp forOp,
                                BlockAndValueMapping &bvm,
-                               BufferizationAliasInfo &aliasInfo) {
+                               BufferizationAliasInfo &aliasInfo,
+                               AllocationCallbacks &allocationFn) {
   // Take a guard before anything else.
   OpBuilder::InsertionGuard g(b);

@@ -1837,7 +1807,8 @@ static LogicalResult bufferize(OpBuilder &b, scf::ForOp forOp,
            "unsupported unranked tensor");

     // TODO: More general: Matching bbArg does not bufferize to a read.
-    Value resultBuffer = getResultBuffer(b, opResult, bvm, aliasInfo);
+    Value resultBuffer =
+        getResultBuffer(b, opResult, bvm, aliasInfo, allocationFn);
     if (!resultBuffer)
       return failure();

@@ -1880,7 +1851,8 @@ static LogicalResult bufferize(OpBuilder &b, scf::IfOp ifOp,
 /// FuncOp always creates TensorToMemRef ops.
 static LogicalResult bufferize(OpBuilder &b, FuncOp funcOp,
                                BlockAndValueMapping &bvm,
-                               BufferizationAliasInfo &aliasInfo) {
+                               BufferizationAliasInfo &aliasInfo,
+                               AllocationCallbacks &allocationFn) {
   // Take a guard before anything else.
   OpBuilder::InsertionGuard g(b);
   b.setInsertionPointToStart(&funcOp.body().front());

@@ -1906,7 +1878,8 @@ static LogicalResult bufferize(OpBuilder &b, FuncOp funcOp,
 /// TODO: consider hoisting across function boundaries prior to bufferization.
 static LogicalResult bufferize(OpBuilder &b, InitTensorOp initTensorOp,
                                BlockAndValueMapping &bvm,
-                               BufferizationAliasInfo &aliasInfo) {
+                               BufferizationAliasInfo &aliasInfo,
+                               AllocationCallbacks &allocationFn) {
   // The InitTensorOp may have been eliminated.
   if (initTensorOp->getUses().empty())
     return success();

@@ -1916,7 +1889,8 @@ static LogicalResult bufferize(OpBuilder &b, InitTensorOp initTensorOp,
   b.setInsertionPoint(initTensorOp);

   Value alloc = createNewAllocDeallocPairForShapedValue(
-      b, initTensorOp->getLoc(), initTensorOp.result(), aliasInfo);
+      b, initTensorOp->getLoc(), initTensorOp.result(), aliasInfo,
+      allocationFn);
   map(bvm, initTensorOp.result(), alloc);
   return success();
 }

@@ -1949,7 +1923,8 @@ static LogicalResult bufferize(OpBuilder &b, ReturnOp returnOp,
 /// Bufferization for TiledLoopOp.
 static LogicalResult bufferize(OpBuilder &b, TiledLoopOp tiledLoopOp,
                                BlockAndValueMapping &bvm,
-                               BufferizationAliasInfo &aliasInfo) {
+                               BufferizationAliasInfo &aliasInfo,
+                               AllocationCallbacks &allocationFn) {
   // Take a guard before anything else.
   OpBuilder::InsertionGuard g(b);

@@ -1989,7 +1964,8 @@ static LogicalResult bufferize(OpBuilder &b, TiledLoopOp tiledLoopOp,

     const OpResult &opResult = tiledLoopOp->getResult(resultIndex);
     OpOperand &yieldOperand = yieldOp->getOpOperand(resultIndex);
-    Value resultBuffer = getResultBuffer(b, opResult, bvm, aliasInfo);
+    Value resultBuffer =
+        getResultBuffer(b, opResult, bvm, aliasInfo, allocationFn);
     if (!resultBuffer)
       return failure();

@@ -2073,7 +2049,8 @@ static LogicalResult bufferize(OpBuilder &b, TiledLoopOp tiledLoopOp,
 /// isolation.
 static LogicalResult bufferize(OpBuilder &b, ExtractSliceOp extractSliceOp,
                                BlockAndValueMapping &bvm,
-                               BufferizationAliasInfo &aliasInfo) {
+                               BufferizationAliasInfo &aliasInfo,
+                               AllocationCallbacks &allocationFn) {
   // Take a guard before anything else.
   OpBuilder::InsertionGuard g(b);

@@ -2093,7 +2070,7 @@ static LogicalResult bufferize(OpBuilder &b, ExtractSliceOp extractSliceOp,
   auto inPlace = getInPlace(extractSliceOp->getResult(0));
   if (inPlace != InPlaceSpec::True)
     alloc = createNewAllocDeallocPairForShapedValue(
-        b, loc, extractSliceOp.result(), aliasInfo);
+        b, loc, extractSliceOp.result(), aliasInfo, allocationFn);

   // Set insertion point now that potential alloc/dealloc are introduced.
   b.setInsertionPoint(extractSliceOp);

@@ -2125,7 +2102,8 @@ static LogicalResult bufferize(OpBuilder &b, ExtractSliceOp extractSliceOp,

 static LogicalResult bufferize(OpBuilder &b, InsertSliceOp insertSliceOp,
                                BlockAndValueMapping &bvm,
-                               BufferizationAliasInfo &aliasInfo) {
+                               BufferizationAliasInfo &aliasInfo,
+                               AllocationCallbacks &allocationFn) {
   // Take a guard before anything else.
   OpBuilder::InsertionGuard g(b);
   b.setInsertionPoint(insertSliceOp);

@@ -2140,8 +2118,8 @@ static LogicalResult bufferize(OpBuilder &b, InsertSliceOp insertSliceOp,
   // TODO: be very loud about it or even consider failing the pass.
   // Alloc a copy for `insertSliceOp.dest()`, it will become the result
   // buffer.
-  Value dstMemref =
-      getResultBuffer(b, insertSliceOp->getResult(0), bvm, aliasInfo);
+  Value dstMemref = getResultBuffer(b, insertSliceOp->getResult(0), bvm,
+                                    aliasInfo, allocationFn);
   if (!dstMemref)
     return failure();
   auto dstMemrefType = dstMemref.getType().cast<MemRefType>();

@@ -2184,7 +2162,8 @@ static LogicalResult bufferize(OpBuilder &b, InsertSliceOp insertSliceOp,

 static LogicalResult bufferize(OpBuilder &b, VectorTransferOpInterface op,
                                BlockAndValueMapping &bvm,
-                               BufferizationAliasInfo &aliasInfo) {
+                               BufferizationAliasInfo &aliasInfo,
+                               AllocationCallbacks &allocationFn) {
   // Take a guard before anything else.
   OpBuilder::InsertionGuard g(b);
   b.setInsertionPoint(op);

@@ -2205,7 +2184,8 @@ static LogicalResult bufferize(OpBuilder &b, VectorTransferOpInterface op,
   // Leave the previous transfer_write to dead code as it still has uses at
   // this point.
   auto writeOp = cast<vector::TransferWriteOp>(op.getOperation());
-  Value resultBuffer = getResultBuffer(b, op->getResult(0), bvm, aliasInfo);
+  Value resultBuffer =
+      getResultBuffer(b, op->getResult(0), bvm, aliasInfo, allocationFn);
   if (!resultBuffer)
     return failure();
   b.create<vector::TransferWriteOp>(

@@ -2436,18 +2416,107 @@ inPlaceAnalysisFuncOpBody(FuncOp funcOp, BufferizationAliasInfo &aliasInfo,
 // Bufferization entry-point for functions.
 //===----------------------------------------------------------------------===//

+/// Compute the type of the `memref` to use for allocating the buffer for
+/// `shapedValue`. Also returns (by reference in `dynShape`) the values for the
+/// dynamic dimensions in the returned `memref` type. The function also sets
+/// the insertion point of the builder `b` to the position where the allocation
+/// is to be inserted.
+static MemRefType getAllocationTypeAndShape(OpBuilder &b, Location loc,
+                                            Value shapedValue,
+                                            SmallVectorImpl<Value> &dynShape) {
+  MemRefType allocMemRefType =
+      getContiguousMemRefType(shapedValue.getType().cast<ShapedType>());
+  if (auto bbArg = shapedValue.dyn_cast<BlockArgument>()) {
+    b.setInsertionPointToStart(bbArg.getOwner());
+    loc = bbArg.getOwner()->getParentOp()->getLoc();
+  } else {
+    b.setInsertionPoint(shapedValue.getDefiningOp());
+    loc = shapedValue.getDefiningOp()->getLoc();
+  }
+
+  // Compute the dynamic part of the shape.
+  bool foundDynamicShapes = false;
+  if (auto rankedOp = dyn_cast_or_null<ReifyRankedShapedTypeOpInterface>(
+          shapedValue.getDefiningOp())) {
+    ReifiedRankedShapedTypeDims resultDims;
+    if (succeeded(rankedOp.reifyResultShapes(b, resultDims))) {
+      foundDynamicShapes = true;
+      OpResult resultValue = shapedValue.dyn_cast<OpResult>();
+      auto &shape = resultDims[resultValue.getResultNumber()];
+      for (auto dim : enumerate(allocMemRefType.getShape()))
+        if (dim.value() == ShapedType::kDynamicSize)
+          dynShape.push_back(shape[dim.index()]);
+    }
+  }
+  if (!foundDynamicShapes) {
+    for (auto dim : enumerate(allocMemRefType.getShape()))
+      if (dim.value() == ShapedType::kDynamicSize)
+        dynShape.push_back(createOrFoldDimOp(b, loc, shapedValue, dim.index()));
+  }
+
+  // If the buffer is statically shaped, try to hoist it to the first enclosing
+  // parallel region.
+  // TODO: this concept of parallel region and threadlocal needs interfaces.
+  // TODO: also hoist in the dynamic case. For now this relies on subsequent
+  // calls to LICM and buffer hoisting which will most likely not succeed.
+  // TODO: when packing, allocate a static bounding box which will enable more
+  // hoisting.
+  if (dynShape.empty()) {
+    Operation *parent =
+        getFirstParentOfType<FuncOp, TiledLoopOp, scf::ParallelOp,
+                             AffineParallelOp>(shapedValue);
+    if (parent)
+      b.setInsertionPointToStart(&(parent->getRegion(0).front()));
+  }
+  return allocMemRefType;
+}
+
+Optional<Value> mlir::linalg::defaultAllocationFn(OpBuilder &b, Location loc,
+                                                  Value shapedValue) {
+  // Take a guard before anything else.
+  OpBuilder::InsertionGuard g(b);
+  SmallVector<Value> dynShape;
+  MemRefType allocMemRefType =
+      getAllocationTypeAndShape(b, loc, shapedValue, dynShape);
+  Value allocated = b.create<memref::AllocOp>(
+      loc, allocMemRefType, dynShape, b.getI64IntegerAttr(kBufferAlignments));
+  return allocated;
+}
+
+static Optional<Value> allocationFnUsingAlloca(OpBuilder &b, Location loc,
+                                               Value shapedValue) {
+  OpBuilder::InsertionGuard g(b);
+  SmallVector<Value> dynShape;
+  MemRefType allocMemRefType =
+      getAllocationTypeAndShape(b, loc, shapedValue, dynShape);
+  Value allocated = b.create<memref::AllocaOp>(
+      loc, allocMemRefType, dynShape, b.getI64IntegerAttr(kBufferAlignments));
+  return allocated;
+}
+
+void mlir::linalg::defaultDeallocationFn(OpBuilder &b, Location loc,
+                                         Value allocatedBuffer) {
+  OpBuilder::InsertionGuard g(b);
+  b.setInsertionPoint(allocatedBuffer.getParentBlock()->getTerminator());
+  b.create<memref::DeallocOp>(loc, allocatedBuffer);
+}
+
 LogicalResult mlir::linalg::bufferizeOp(
     Operation *op, BlockAndValueMapping &bvm, BufferizationAliasInfo &aliasInfo,
-    DenseMap<FuncOp, FunctionType> *bufferizedFunctionTypes,
-    GlobalCreator *globalCreator) {
+    AllocationCallbacks allocationFns,
+    DenseMap<FuncOp, FunctionType> *bufferizedFunctionTypes) {
   OpBuilder b(op->getContext());
   return TypeSwitch<Operation *, LogicalResult>(op)
       // Skip BufferCast and TensorLoad ops.
       .Case<memref::BufferCastOp, memref::TensorLoadOp>(
           [&](auto) { return success(); })
-      .Case<tensor::CastOp, tensor::DimOp, ExtractSliceOp, scf::ForOp,
-            InitTensorOp, InsertSliceOp, tensor::ExtractOp, LinalgOp, ReturnOp,
-            TiledLoopOp, VectorTransferOpInterface, linalg::YieldOp,
+      .Case<ExtractSliceOp, InitTensorOp, InsertSliceOp, LinalgOp, scf::ForOp,
+            tensor::CastOp, TiledLoopOp, VectorTransferOpInterface>(
+          [&](auto op) {
+            LDBG("Begin bufferize:\n" << op << '\n');
+            return bufferize(b, op, bvm, aliasInfo, allocationFns);
+          })
+      .Case<tensor::DimOp, tensor::ExtractOp, ReturnOp, linalg::YieldOp,
            scf::YieldOp>([&](auto op) {
        LDBG("Begin bufferize:\n" << op << '\n');
        return bufferize(b, op, bvm, aliasInfo);
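With the new `bufferizeOp` signature above, the simplest call site just default-constructs the callbacks. A sketch of the entry point in use (the wrapper is hypothetical; setup mirrors `bufferizeFuncOpInternals` below):

    // Sketch: invoking the new entry point with the default callbacks.
    // bufferizedFunctionTypes may stay at its nullptr default only when no
    // CallOpInterface ops are being bufferized.
    static LogicalResult runDefaultBufferization(
        Operation *op, BlockAndValueMapping &bvm,
        BufferizationAliasInfo &aliasInfo) {
      // Defaults resolve to defaultAllocationFn / defaultDeallocationFn.
      AllocationCallbacks defaultCallbacks;
      return bufferizeOp(op, bvm, aliasInfo, defaultCallbacks);
    }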
mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp (continued)
@@ -2464,15 +2533,14 @@ LogicalResult mlir::linalg::bufferizeOp(
         if (!bufferizedFunctionTypes)
           llvm_unreachable(
               "null bufferizedFunctionTypes when bufferizing CallOpInterface");
-        return bufferize(b, op, bvm, aliasInfo, *bufferizedFunctionTypes);
+        return bufferize(b, op, bvm, aliasInfo, allocationFns,
+                         *bufferizedFunctionTypes);
       })
       .Case([&](arith::ConstantOp op) {
         if (!isaTensor(op.getResult().getType()))
           return success();
         LDBG("Begin bufferize:\n" << op << '\n');
-        if (!globalCreator)
-          llvm_unreachable("null globalCreator when bufferizing ConstantOp");
-        return bufferize(b, op, bvm, aliasInfo, *globalCreator);
+        return bufferize(b, op, bvm, aliasInfo);
       })
       .Default([&](Operation *op) -> LogicalResult {
         auto isaTensor = [](Type t) { return t.isa<TensorType>(); };

@@ -2485,15 +2553,13 @@
 static LogicalResult bufferizeFuncOpInternals(
     FuncOp funcOp, BlockAndValueMapping &bvm, BufferizationAliasInfo &aliasInfo,
-    DenseMap<FuncOp, FunctionType> &bufferizedFunctionTypes,
-    GlobalCreator &globalCreator) {
-
+    AllocationCallbacks &allocationFns,
+    DenseMap<FuncOp, FunctionType> &bufferizedFunctionTypes) {
   LLVM_DEBUG(llvm::dbgs() << "\n\n");
   LDBG("Begin BufferizeFuncOpInternals:\n" << funcOp << '\n');
   OpBuilder b(funcOp->getContext());

-  // Start by bufferizing `funcOp` arguments.
-  if (failed(bufferize(b, funcOp, bvm, aliasInfo)))
+  /// Start by bufferizing `funcOp` arguments.
+  if (failed(bufferize(b, funcOp, bvm, aliasInfo, allocationFns)))
     return failure();

   // Cannot erase ops during the traversal. Do that afterwards.

@@ -2516,13 +2582,13 @@ static LogicalResult bufferizeFuncOpInternals(
   }

   for (Operation *op : llvm::reverse(preorderBufferize))
-    if (failed(bufferizeOp(op, bvm, aliasInfo, &bufferizedFunctionTypes,
-                           &globalCreator)))
+    if (failed(bufferizeOp(op, bvm, aliasInfo, allocationFns,
+                           &bufferizedFunctionTypes)))
      return failure();

    if (!bufferizedOps.contains(op) &&
-       failed(bufferizeOp(op, bvm, aliasInfo, &bufferizedFunctionTypes,
-                          &globalCreator)))
+       failed(bufferizeOp(op, bvm, aliasInfo, allocationFns,
+                          &bufferizedFunctionTypes)))
      return failure();

    // Register post-walk erasure, if necessary.

@@ -2793,12 +2859,19 @@ namespace {
 struct LinalgComprehensiveModuleBufferize
     : public LinalgComprehensiveModuleBufferizeBase<
           LinalgComprehensiveModuleBufferize> {
+  LinalgComprehensiveModuleBufferize() {}
+
+  LinalgComprehensiveModuleBufferize(
+      const LinalgComprehensiveModuleBufferize &p) {}
+
   void runOnOperation() override;

   void getDependentDialects(DialectRegistry &registry) const override {
     registry.insert<linalg::LinalgDialect, memref::MemRefDialect>();
   }
+
+private:
+  std::unique_ptr<AllocationCallbacks> allocationFns;
 };
 } // end namespace

@@ -2983,6 +3056,22 @@ static LogicalResult runInitTensorElimination(FuncOp funcOp,
 }

 void LinalgComprehensiveModuleBufferize::runOnOperation() {
+  if (!allocationFns) {
+    // The allocation functions to use need to be set here. The flag for the
+    // pass and the flag for the use of alloca map to LLVM command line
+    // options; being static global objects, these have no set order in which
+    // they are defined. So ideally this should be in the constructor, but the
+    // constructor might be called before the flag is initialized from the
+    // command line option. So this is set up at the start of the pass.
+    if (useAlloca) {
+      AllocationCallbacks allocaAllocationFns = {
+          allocationFnUsingAlloca, [](OpBuilder &b, Location loc, Value v) {}};
+      allocationFns =
+          std::make_unique<AllocationCallbacks>(std::move(allocaAllocationFns));
+    } else {
+      allocationFns = std::make_unique<AllocationCallbacks>();
+    }
+  }
   ModuleOp moduleOp = getOperation();
   applyEnablingTransformations(moduleOp);

@@ -2992,7 +3081,6 @@ void LinalgComprehensiveModuleBufferize::runOnOperation() {
   if (failed(getFuncOpsOrderedByCalls(moduleOp, orderedFuncOps, callerMap)))
     return signalPassFailure();

-  GlobalCreator globalCreator(moduleOp);
   DominanceInfo domInfo(moduleOp);
   BufferizationAliasInfo aliasInfo(moduleOp);
   // Interestingly, all function args that are not visible outside of a module

@@ -3032,8 +3120,8 @@ void LinalgComprehensiveModuleBufferize::runOnOperation() {
     if (!testAnalysisOnly) {
       BlockAndValueMapping tensorToBufferMap;
       if (failed(bufferizeFuncOpInternals(funcOp, tensorToBufferMap, aliasInfo,
-                                          bufferizedFunctionTypes,
-                                          globalCreator))) {
+                                          *allocationFns,
+                                          bufferizedFunctionTypes))) {
         signalPassFailure();
         return;
       }
mlir/test/Dialect/Linalg/comprehensive-module-bufferize-alloca.mlir (new file)
@@ -0,0 +1,65 @@
+// RUN: mlir-opt %s -pass-pipeline="linalg-comprehensive-module-bufferize{allow-return-memref use-alloca}" -split-input-file | FileCheck %s
+
+// CHECK-DAG: #[[$DYN_0D_MAP:.*]] = affine_map<()[s0] -> (s0)>
+// CHECK-DAG: #[[$DYN_1D_MAP:.*]] = affine_map<(d0)[s0, s1] -> (d0 * s1 + s0)>
+
+// CHECK: func @init_and_dot(
+// CHECK-SAME:   %[[A:[a-zA-Z0-9]*]]: memref<64xf32, #[[$DYN_1D_MAP]]>
+// CHECK-SAME:   %[[B:[a-zA-Z0-9]*]]: memref<64xf32, #[[$DYN_1D_MAP]]>
+// CHECK-SAME:   %[[C:[a-zA-Z0-9]*]]: memref<f32, #[[$DYN_0D_MAP]]>
+func @init_and_dot(%a: tensor<64xf32>, %b: tensor<64xf32>, %c: tensor<f32>) -> tensor<f32> {
+  // CHECK-NEXT: %[[C0:.*]] = arith.constant 0{{.*}} : f32
+  %v0 = arith.constant 0.0 : f32
+
+  // CHECK-NEXT: linalg.fill(%[[C0]], %[[C]]) : f32, memref<f32, #[[$DYN_0D_MAP]]>
+  %d = linalg.fill(%v0, %c) : f32, tensor<f32> -> tensor<f32>
+
+  // CHECK-NEXT: linalg.dot ins(%[[A]], %[[B]] : memref<64xf32, #[[$DYN_1D_MAP]]>, memref<64xf32, #[[$DYN_1D_MAP]]>) outs(%[[C]] : memref<f32, #[[$DYN_0D_MAP]]>)
+  %e = linalg.dot ins(%a, %b : tensor<64xf32>, tensor<64xf32>)
+                 outs(%d: tensor<f32>) -> tensor<f32>
+
+  // CHECK-NEXT: return
+  return %e : tensor<f32>
+}
+
+// CHECK: func @main()
+func @main() {
+  // CHECK-DAG: %[[C0:.*]] = arith.constant 0{{.*}} : f32
+  // CHECK-DAG: %[[C1:.*]] = arith.constant 1{{.*}} : f32
+  // CHECK-DAG: %[[C2:.*]] = arith.constant 2{{.*}} : f32
+  %v0 = arith.constant 0.0 : f32
+  %v1 = arith.constant 1.0 : f32
+  %v2 = arith.constant 2.0 : f32
+
+  // CHECK-NEXT: %[[C:.*]] = memref.alloca() {alignment = 128 : i64} : memref<f32>
+  // CHECK-NEXT: %[[B:.*]] = memref.alloca() {alignment = 128 : i64} : memref<64xf32>
+  // CHECK-NEXT: %[[A:.*]] = memref.alloca() {alignment = 128 : i64} : memref<64xf32>
+  %A = linalg.init_tensor [64] : tensor<64xf32>
+  %B = linalg.init_tensor [64] : tensor<64xf32>
+  %C = linalg.init_tensor [] : tensor<f32>
+
+  // CHECK-NEXT: linalg.fill(%[[C1]], %[[A]]) : f32, memref<64xf32>
+  // CHECK-NEXT: linalg.fill(%[[C2]], %[[B]]) : f32, memref<64xf32>
+  // CHECK-NEXT: linalg.fill(%[[C0]], %[[C]]) : f32, memref<f32>
+  %AA = linalg.fill(%v1, %A) : f32, tensor<64xf32> -> tensor<64xf32>
+  %BB = linalg.fill(%v2, %B) : f32, tensor<64xf32> -> tensor<64xf32>
+  %CC = linalg.fill(%v0, %C) : f32, tensor<f32> -> tensor<f32>
+
+  // CHECK-NEXT: %[[cA:.*]] = memref.cast %[[A]] : memref<64xf32> to memref<64xf32, #[[$DYN_1D_MAP]]>
+  // CHECK-NEXT: %[[cB:.*]] = memref.cast %[[B]] : memref<64xf32> to memref<64xf32, #[[$DYN_1D_MAP]]>
+  // CHECK-NEXT: %[[cC:.*]] = memref.cast %[[C]] : memref<f32> to memref<f32, #[[$DYN_0D_MAP]]>
+  // CHECK-NEXT: call @init_and_dot(%[[cA]], %[[cB]], %[[cC]])
+  %res = call @init_and_dot(%AA, %BB, %CC) :
+    (tensor<64xf32>, tensor<64xf32>, tensor<f32>) -> tensor<f32>
+
+  // CHECK-NEXT: %[[dC:.*]] = memref.cast %[[C]] : memref<f32> to memref<*xf32>
+  %res2 = tensor.cast %res: tensor<f32> to tensor<*xf32>
+
+  // CHECK-NEXT: call @print_memref_f32(%[[dC]]) : (memref<*xf32>) -> ()
+  call @print_memref_f32(%res2) : (tensor<*xf32>) -> ()
+
+  return
+}
+
+// CHECK: func private @print_memref_f32(memref<*xf32>)
+func private @print_memref_f32(tensor<*xf32>)
utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -6314,6 +6314,7 @@ cc_library(
         ":ComplexDialect",
         ":DialectUtils",
         ":IR",
+        ":InferTypeOpInterface",
        ":LinalgOps",
        ":LinalgPassIncGen",
        ":LinalgStructuredOpsIncGen",