[mlir][linalg][bufferize] Group helpers in BufferizationState

This simplifies the signature of `bufferize`.

Differential Revision: https://reviews.llvm.org/D113388
Author: Matthias Springer
Date:   2021-11-11 18:23:13 +09:00
Commit: aeb1c8d0ca (parent: 7ac1fd0da9)

5 changed files with 200 additions and 245 deletions
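In short, every implementation of `bufferize` goes from three loose parameters to a single state object; abridged from the hunks below:

    // Before:
    LogicalResult bufferize(Operation *op, OpBuilder &b,
                            BlockAndValueMapping &bvm,
                            BufferizationAliasInfo &aliasInfo,
                            AllocationCallbacks &allocationFn);

    // After:
    LogicalResult bufferize(Operation *op, OpBuilder &b,
                            BufferizationState &state);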

File: mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h

@ -197,6 +197,8 @@ findValueInReverseUseDefChain(Value value,
/// is returned regardless of whether it is a memory write or not.
Value findLastPrecedingWrite(Value value);
struct BufferizationState;
/// Callback functions that are used to allocate/deallocate/copy memory buffers.
/// Comprehensive Bufferize provides default implementations of these functions.
// TODO: Could be replaced with a "bufferization strategy" object with virtual
@ -207,8 +209,7 @@ struct AllocationCallbacks {
using DeallocationFn = std::function<void(OpBuilder &, Location, Value)>;
using MemCpyFn = std::function<void(OpBuilder &, Location, Value, Value)>;
using CreateAllocDeallocFn =
std::function<Value(OpBuilder &, Location, Value,
BufferizationAliasInfo &, AllocationCallbacks &)>;
std::function<Value(OpBuilder &, Location, Value, BufferizationState &)>;
AllocationCallbacks(AllocationFn allocFn, DeallocationFn deallocFn,
MemCpyFn copyFn, CreateAllocDeallocFn allocDeallocFn)
@ -230,13 +231,40 @@ struct AllocationCallbacks {
CreateAllocDeallocFn createAllocDeallocFn;
};
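For context, a client could assemble its own callbacks from lambdas roughly as follows. This is a hedged sketch: the exact `AllocationFn` signature is not visible in this hunk (so `myAllocationFn` is a placeholder), and `createNewAllocDeallocPairForShapedValue` is the file-static helper in ComprehensiveBufferize.cpp that now matches the updated `CreateAllocDeallocFn` typedef:

    // Sketch only; the in-tree defaults come from defaultAllocationCallbacks().
    AllocationCallbacks callbacks(
        /*allocFn=*/myAllocationFn, // hypothetical, must match AllocationFn
        /*deallocFn=*/
        [](OpBuilder &b, Location loc, Value buffer) {
          b.create<memref::DeallocOp>(loc, buffer);
        },
        /*copyFn=*/
        [](OpBuilder &b, Location loc, Value from, Value to) {
          b.create<memref::CopyOp>(loc, from, to); // assumed copy op
        },
        /*allocDeallocFn=*/createNewAllocDeallocPairForShapedValue);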
/// BufferizationState keeps track of bufferization state and provides access to
/// the results of the analysis.
struct BufferizationState {
BufferizationState(BufferizationAliasInfo &aliasInfo,
AllocationCallbacks &allocationFns,
BlockAndValueMapping &tensorToBufferMap)
: aliasInfo(aliasInfo), allocationFns(allocationFns),
tensorToBufferMap(tensorToBufferMap) {}
/// Map tensor values to memref buffers.
void mapBuffer(ValueRange tensors, ValueRange buffers);
/// Map a tensor value to a memref buffer.
void mapBuffer(Value tensor, Value buffer);
/// Lookup the memref buffer that is associated to the given tensor value.
/// Asserts if no buffer is associated.
Value lookupBuffer(Value tensor) const;
/// `aliasInfo` keeps track of aliasing and equivalent values.
BufferizationAliasInfo &aliasInfo;
/// `allocationFns` contains helper functions for creating alloc ops, dealloc
/// ops and memcpy ops.
AllocationCallbacks &allocationFns;
/// The mapping of tensors to buffers.
BlockAndValueMapping &tensorToBufferMap;
};
/// Return the result buffer (memref) for a given OpResult (tensor). Allocate
/// a new buffer and copy over data from the existing buffer if out-of-place
/// bufferization is necessary.
Value getResultBuffer(OpBuilder &b, OpResult result,
const BlockAndValueMapping &bvm,
BufferizationAliasInfo &aliasInfo,
AllocationCallbacks allocationFns);
Value getResultBuffer(OpBuilder &b, OpResult result, BufferizationState &state);
} // namespace comprehensive_bufferize
} // namespace linalg
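Taken together, a typical `bufferize` implementation now reads roughly as below. This is a sketch for a hypothetical single-operand, single-result op, following the patterns in the .cpp hunks later in this diff:

    LogicalResult bufferize(Operation *op, OpBuilder &b,
                            BufferizationState &state) const {
      OpBuilder::InsertionGuard g(b);
      b.setInsertionPoint(op);
      // Operand buffers were mapped when their producers were bufferized;
      // lookupBuffer asserts if the tensor has no buffer yet.
      Value operandBuffer = state.lookupBuffer(op->getOperand(0));
      // In-place: reuse the operand buffer. Out-of-place: alloc and copy.
      Value resultBuffer = getResultBuffer(b, op->getResult(0), state);
      if (!resultBuffer)
        return failure();
      // ... emit the memref-based equivalent of `op` using operandBuffer ...
      state.mapBuffer(op->getResult(0), resultBuffer);
      return success();
    }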
@ -280,9 +308,7 @@ struct AllocationHoistingBarrierOnly
bool isWritable(Operation *op, Value value) const { return false; }
LogicalResult bufferize(Operation *op, OpBuilder &b,
BlockAndValueMapping &bvm,
BufferizationAliasInfo &aliasInfo,
AllocationCallbacks &allocationFn) const {
BufferizationState &state) const {
auto isaTensor = [](Type t) { return t.isa<TensorType>(); };
if (any_of(op->getOperandTypes(), isaTensor) ||
any_of(op->getResultTypes(), isaTensor))

File: mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.td

@ -160,8 +160,6 @@ def BufferizableOpInterface : OpInterface<"BufferizableOpInterface"> {
llvm_unreachable("bufferRelation not implemented");
}]
>,
// TODO: Simplify method signature: Pass an OpBuilder and a
// BufferizationState object.
InterfaceMethod<
/*desc=*/[{
Bufferize this op, i.e., rewrite it into a memref-based equivalent.
@ -171,9 +169,7 @@ def BufferizableOpInterface : OpInterface<"BufferizableOpInterface"> {
/*retType=*/"LogicalResult",
/*methodName=*/"bufferize",
/*args=*/(ins "OpBuilder &":$b,
"BlockAndValueMapping &":$bvm,
"BufferizationAliasInfo &":$aliasInfo,
"AllocationCallbacks &":$allocationFn),
"BufferizationState &":$state),
/*methodBody=*/"",
/*defaultImplementation=*/[{
llvm_unreachable("bufferize not implemented");

File: mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.h

@ -27,6 +27,8 @@ namespace comprehensive_bufferize {
// TODO: from some HW description.
static constexpr int64_t kBufferAlignments = 128;
struct BufferizationState;
/// Analyze the `ops` to determine which OpResults are inplaceable.
LogicalResult inPlaceAnalysis(SmallVector<Operation *> &ops,
BufferizationAliasInfo &aliasInfo,
@ -55,9 +57,7 @@ std::unique_ptr<AllocationCallbacks> defaultAllocationCallbacks();
/// `bufferizedFunctionTypes` (resp. `globalCreator`) are expected to be
/// non-null if `op` is a CallOpInterface (resp. GlobalCreator).
LogicalResult
bufferizeOp(Operation *op, BlockAndValueMapping &bvm,
BufferizationAliasInfo &aliasInfo,
AllocationCallbacks allocationFns,
bufferizeOp(Operation *op, BufferizationState &state,
DenseMap<FuncOp, FunctionType> *bufferizedFunctionTypes = nullptr);
/// Register external models implemented for the `BufferizableOpInterface`.

File: mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp

@ -7,8 +7,11 @@
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h"
#include "mlir/IR/AsmState.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/Value.h"
#include "llvm/Support/Debug.h"
namespace mlir {
@ -319,30 +322,28 @@ Value mlir::linalg::comprehensive_bufferize::findLastPrecedingWrite(
/// a new buffer and copy over data from the existing buffer if out-of-place
/// bufferization is necessary.
Value mlir::linalg::comprehensive_bufferize::getResultBuffer(
OpBuilder &b, OpResult result, const BlockAndValueMapping &bvm,
BufferizationAliasInfo &aliasInfo, AllocationCallbacks allocationFns) {
OpBuilder &b, OpResult result, BufferizationState &state) {
OpBuilder::InsertionGuard guard(b);
Operation *op = result.getOwner();
SmallVector<OpOperand *> aliasingOperands = getAliasingOpOperand(result);
assert(!aliasingOperands.empty() && "could not get aliasing OpOperand");
OpOperand *opOperand = aliasingOperands.front();
Value operand = opOperand->get();
Value operandBuffer = bvm.lookupOrNull(operand);
assert(operandBuffer && "operand buffer not found");
Value operandBuffer = state.lookupBuffer(operand);
// Make sure that all OpOperands are the same buffer. If this is not the case,
// we would have to materialize a memref value.
// TODO: Should be checking for "equivalent buffers" instead of
// operator== here, but equivalent buffers for scf.if yield values are not
// set up yet.
if (!llvm::all_of(aliasingOperands, [&](OpOperand *o) {
return bvm.lookup(o->get()) == operandBuffer;
return state.lookupBuffer(o->get()) == operandBuffer;
})) {
op->emitError("result buffer is ambiguous");
return Value();
}
// If bufferizing out-of-place, allocate a new buffer.
if (!aliasInfo.isInPlace(result)) {
if (!state.aliasInfo.isInPlace(result)) {
// Ops with multiple aliasing operands can currently not bufferize
// out-of-place.
assert(
@ -350,8 +351,8 @@ Value mlir::linalg::comprehensive_bufferize::getResultBuffer(
"ops with multiple aliasing OpOperands cannot bufferize out-of-place");
Location loc = op->getLoc();
// Allocate the result buffer.
Value resultBuffer = allocationFns.createAllocDeallocFn(
b, loc, operand, aliasInfo, allocationFns);
Value resultBuffer =
state.allocationFns.createAllocDeallocFn(b, loc, operand, state);
bool skipCopy = false;
// Do not copy if the last preceding write of `operand` is an op that does
// not write (skipping ops that merely create aliases). E.g., InitTensorOp.
@ -373,7 +374,7 @@ Value mlir::linalg::comprehensive_bufferize::getResultBuffer(
if (!skipCopy) {
// Set insertion point now that potential alloc/dealloc are introduced.
b.setInsertionPoint(op);
allocationFns.memCpyFn(b, loc, operandBuffer, resultBuffer);
state.allocationFns.memCpyFn(b, loc, operandBuffer, resultBuffer);
}
return resultBuffer;
}
@ -381,3 +382,39 @@ Value mlir::linalg::comprehensive_bufferize::getResultBuffer(
// Bufferizing in-place. No need to allocate a new buffer.
return operandBuffer;
}
//===----------------------------------------------------------------------===//
// Bufferization-specific BlockAndValueMapping support with debugging.
//===----------------------------------------------------------------------===//
/// Wrapper for better debugging.
void mlir::linalg::comprehensive_bufferize::BufferizationState::mapBuffer(
ValueRange tensors, ValueRange buffers) {
assert(!tensors.empty() && "unexpected empty tensors");
return tensorToBufferMap.map(tensors, buffers);
}
/// Wrapper for better debugging.
void mlir::linalg::comprehensive_bufferize::BufferizationState::mapBuffer(
Value tensor, Value buffer) {
assert(tensor && "unexpected empty tensor");
assert(tensor.getType().isa<TensorType>() && "unexpected non-tensor type");
return tensorToBufferMap.map(tensor, buffer);
}
/// Wrapper for better debugging.
Value mlir::linalg::comprehensive_bufferize::BufferizationState::lookupBuffer(
Value tensor) const {
// TODO: if key comes from bbArg, forward.
assert(tensor.getType().isa<TensorType>() && "unexpected non-tensor type");
Value v = tensorToBufferMap.lookupOrNull(tensor);
if (!v) {
// Dump tensor for easier debugging.
tensor.dump();
llvm_unreachable("tensor is not mapped");
return Value();
}
return v;
}
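These helpers encode a strict map-before-lookup contract; a minimal illustration with placeholder values:

    state.mapBuffer(tensorVal, bufferVal); // tensorVal must have TensorType
    Value buf = state.lookupBuffer(tensorVal); // ok, returns bufferVal
    // Looking up an unmapped tensor dumps the value and hits
    // llvm_unreachable instead of silently returning a null Value.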

File: mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp

@ -172,47 +172,6 @@ static ReturnOp getAssumedUniqueReturnOp(FuncOp funcOp) {
return returnOp;
}
//===----------------------------------------------------------------------===//
// Bufferization-specific BlockAndValueMapping support with debugging.
//===----------------------------------------------------------------------===//
/// Wrapper for better debugging.
static void map(BlockAndValueMapping &bvm, ValueRange keys, ValueRange values) {
assert(!keys.empty() && "Unexpected empty keys");
LDBG("\n\tMap: " << printValueInfo(keys.front())
<< "\n\tto: " << printValueInfo(values.front()) << '\n');
return bvm.map(keys, values);
}
/// Wrapper for better debugging.
static void map(BlockAndValueMapping &bvm, Value key, Value value) {
LDBG("\n\tMap: " << printValueInfo(key) << "\n\tto: " << printValueInfo(value)
<< '\n');
return bvm.map(key, value);
}
/// Wrapper for better debugging.
static Value lookup(const BlockAndValueMapping &bvm, Value key) {
// TODO: if key comes from bbArg, forward.
assert(key.getType().isa<TensorType>());
Value v = bvm.lookupOrNull(key);
if (v)
return v;
Operation *parentOp;
if (auto bbArg = key.dyn_cast<BlockArgument>()) {
if (isa<FuncOp>(key.getParentBlock()->getParentOp()))
parentOp = key.getParentBlock()->getParentOp();
else
parentOp = key.getParentBlock()->getParentOp()->getParentOfType<FuncOp>();
} else {
parentOp = key.getDefiningOp()->getParentOfType<FuncOp>();
}
LDBG("In func:\n" << *parentOp << "\nNO VALUE FOR KEY: " << key << '\n');
(void)parentOp;
return Value();
}
//===----------------------------------------------------------------------===//
// Bufferization-specific attribute manipulation.
// These are for testing and debugging only. Bufferization information is
@ -878,8 +837,7 @@ static MemRefType getAllocationTypeAndShape(OpBuilder &b, Location loc,
/// `shapedValue.getDefiningOp` (or at the top of the block in case of a
/// bbArg) and the DeallocOp is at the end of the block.
static Value createNewAllocDeallocPairForShapedValue(
OpBuilder &b, Location loc, Value shapedValue,
BufferizationAliasInfo &aliasInfo, AllocationCallbacks &allocationFns) {
OpBuilder &b, Location loc, Value shapedValue, BufferizationState &state) {
// Take a guard before anything else.
OpBuilder::InsertionGuard g(b);
@ -891,19 +849,19 @@ static Value createNewAllocDeallocPairForShapedValue(
MemRefType allocMemRefType =
getAllocationTypeAndShape(b, loc, shapedValue, dynShape);
Optional<Value> allocated =
allocationFns.allocationFn(b, loc, allocMemRefType, dynShape);
state.allocationFns.allocationFn(b, loc, allocMemRefType, dynShape);
// TODO: For now just assert the value is returned. Eventually need to
// error-propagate.
assert(allocated && "allocation failed");
Value casted = allocated.getValue();
if (memRefType && memRefType != allocMemRefType) {
casted = b.create<memref::CastOp>(loc, memRefType, allocated.getValue());
aliasInfo.insertNewBufferEquivalence(casted, allocated.getValue());
state.aliasInfo.insertNewBufferEquivalence(casted, allocated.getValue());
}
// 2. Create memory deallocation.
b.setInsertionPoint(allocated.getValue().getParentBlock()->getTerminator());
allocationFns.deallocationFn(b, loc, allocated.getValue());
state.allocationFns.deallocationFn(b, loc, allocated.getValue());
return casted;
}
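Note that this static helper now matches the `CreateAllocDeallocFn` typedef exactly, so it can be installed as the callback and invoked through the state, as `getResultBuffer` does above:

    // Call-site pattern (hypothetical locals, see getResultBuffer):
    Value newBuffer =
        state.allocationFns.createAllocDeallocFn(b, loc, shapedValue, state);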
@ -915,8 +873,7 @@ static Value createNewAllocDeallocPairForShapedValue(
/// inplaceable. For now, it is the responsibility of the `callOp` bufferization
to allow FuncOps that are inplaceable to write inPlace.
static LogicalResult
bufferize(OpBuilder &b, CallOpInterface callOp, BlockAndValueMapping &bvm,
BufferizationAliasInfo &aliasInfo, AllocationCallbacks &allocationFns,
bufferize(OpBuilder &b, CallOpInterface callOp, BufferizationState &state,
DenseMap<FuncOp, FunctionType> &bufferizedFunctionTypes) {
FuncOp funcOp = getCalledFunction(callOp);
assert(isa<CallOp>(callOp.getOperation()) && funcOp &&
@ -962,14 +919,13 @@ bufferize(OpBuilder &b, CallOpInterface callOp, BlockAndValueMapping &bvm,
// If return operand is equivalent to some bbArg, no need to return it.
Value returnVal = returnOperand.get();
if (BlockArgument bbArg =
getEquivalentEnclosingFuncBBArg(returnVal, aliasInfo)) {
getEquivalentEnclosingFuncBBArg(returnVal, state.aliasInfo)) {
Value oldRes = callOp->getResult(returnOperand.getOperandNumber());
int64_t idx = bbArg.getArgNumber();
Value buffer = lookup(bvm, callOp->getOperand(idx));
assert(buffer && "expected bufferized value");
Value buffer = state.lookupBuffer(callOp->getOperand(idx));
// Add CallOp operand/result equivalence: this is interprocedural info.
aliasInfo.insertNewBufferEquivalence(oldRes, buffer);
map(bvm, oldRes, buffer);
state.aliasInfo.insertNewBufferEquivalence(oldRes, buffer);
state.mapBuffer(oldRes, buffer);
// Add a TensorLoadOp to kill all uses of the CallOp return.
// Replace all uses of the CallOp results so we can erase the CallOp.
// This TensorLoadOp must fold/DCE away or bufferization should be
@ -978,13 +934,13 @@ bufferize(OpBuilder &b, CallOpInterface callOp, BlockAndValueMapping &bvm,
b.create<memref::TensorLoadOp>(callOp.getLoc(), buffer);
oldRes.replaceAllUsesWith(tensorLoad);
// Add new op equivalence info.
aliasInfo.insertNewBufferEquivalence(tensorLoad, buffer);
map(bvm, tensorLoad, buffer);
state.aliasInfo.insertNewBufferEquivalence(tensorLoad, buffer);
state.mapBuffer(tensorLoad, buffer);
continue;
}
// TODO: Need to hoist above function boundary.
if (Operation *allocOp = getEquivalentAlloc(returnVal, aliasInfo)) {
if (Operation *allocOp = getEquivalentAlloc(returnVal, state.aliasInfo)) {
hoistedArguments.push_back(allocOp->getResult(0));
continue;
}
@ -1023,8 +979,7 @@ bufferize(OpBuilder &b, CallOpInterface callOp, BlockAndValueMapping &bvm,
// Tensor operands are guaranteed to have been bufferized.
int64_t idx = opOperand.getOperandNumber();
Value buffer = lookup(bvm, tensorOperand);
assert(buffer && "expected bufferized value");
Value buffer = state.lookupBuffer(tensorOperand);
// Caller / callee type mismatch is handled with a CastOp.
auto memRefType = bufferizedFuncType.getInput(idx);
@ -1037,8 +992,8 @@ bufferize(OpBuilder &b, CallOpInterface callOp, BlockAndValueMapping &bvm,
Value castBuffer =
b.create<memref::CastOp>(callOp.getLoc(), memRefType, buffer);
// Add new op equivalence info.
aliasInfo.insertNewBufferEquivalence(castBuffer, buffer);
map(bvm, tensorOperand, castBuffer);
state.aliasInfo.insertNewBufferEquivalence(castBuffer, buffer);
state.mapBuffer(tensorOperand, castBuffer);
buffer = castBuffer;
}
newOperands.push_back(buffer);
@ -1054,9 +1009,7 @@ bufferize(OpBuilder &b, CallOpInterface callOp, BlockAndValueMapping &bvm,
/// FuncOp always creates TensorToMemRef ops.
static LogicalResult bufferize(OpBuilder &b, FuncOp funcOp,
BlockAndValueMapping &bvm,
BufferizationAliasInfo &aliasInfo,
AllocationCallbacks &allocationFn) {
BufferizationState &state) {
// Take a guard before anything else.
OpBuilder::InsertionGuard g(b);
b.setInsertionPointToStart(&funcOp.body().front());
@ -1072,8 +1025,8 @@ static LogicalResult bufferize(OpBuilder &b, FuncOp funcOp,
: getContiguousOrUnrankedMemRefType(tensorType);
Value bufferCast =
b.create<memref::BufferCastOp>(funcOp.getLoc(), memRefType, bbArg);
aliasInfo.insertNewBufferEquivalence(bufferCast, bbArg);
map(bvm, bbArg, bufferCast);
state.aliasInfo.insertNewBufferEquivalence(bufferCast, bbArg);
state.mapBuffer(bbArg, bufferCast);
}
return success();
}
@ -1230,8 +1183,7 @@ void mlir::linalg::comprehensive_bufferize::defaultMemCpyFn(OpBuilder &b,
}
LogicalResult mlir::linalg::comprehensive_bufferize::bufferizeOp(
Operation *op, BlockAndValueMapping &bvm, BufferizationAliasInfo &aliasInfo,
AllocationCallbacks allocationFns,
Operation *op, BufferizationState &state,
DenseMap<FuncOp, FunctionType> *bufferizedFunctionTypes) {
OpBuilder b(op->getContext());
@ -1241,8 +1193,7 @@ LogicalResult mlir::linalg::comprehensive_bufferize::bufferizeOp(
if (!bufferizedFunctionTypes)
llvm_unreachable(
"null bufferizedFunctionTypes when bufferizing CallOpInterface");
return bufferize(b, callOp, bvm, aliasInfo, allocationFns,
*bufferizedFunctionTypes);
return bufferize(b, callOp, state, *bufferizedFunctionTypes);
}
// Skip BufferCast and TensorLoad ops.
@ -1251,7 +1202,7 @@ LogicalResult mlir::linalg::comprehensive_bufferize::bufferizeOp(
// Bufferize using `BufferizableOpInterface`.
if (auto bufferizableOp = dyn_cast<BufferizableOpInterface>(op))
return bufferizableOp.bufferize(b, bvm, aliasInfo, allocationFns);
return bufferizableOp.bufferize(b, state);
// Other op with tensors. No bufferization method specified.
auto isaTensor = [](Type t) { return t.isa<TensorType>(); };
@ -1262,23 +1213,21 @@ LogicalResult mlir::linalg::comprehensive_bufferize::bufferizeOp(
}
static LogicalResult bufferizeFuncOpInternals(
FuncOp funcOp, BlockAndValueMapping &bvm, BufferizationAliasInfo &aliasInfo,
AllocationCallbacks &allocationFns,
FuncOp funcOp, BufferizationState &state,
DenseMap<FuncOp, FunctionType> &bufferizedFunctionTypes) {
LLVM_DEBUG(llvm::dbgs() << "\n\n");
LDBG("Begin BufferizeFuncOpInternals:\n" << funcOp << '\n');
OpBuilder b(funcOp->getContext());
// Start by bufferizing `funcOp` arguments.
if (failed(bufferize(b, funcOp, bvm, aliasInfo, allocationFns)))
if (failed(bufferize(b, funcOp, state)))
return failure();
// Cannot erase ops during the traversal. Do that afterwards.
SmallVector<Operation *> toErase;
auto walkFunc = [&](Operation *op) -> WalkResult {
if (failed(bufferizeOp(op, bvm, aliasInfo, allocationFns,
&bufferizedFunctionTypes)))
if (failed(bufferizeOp(op, state, &bufferizedFunctionTypes)))
return failure();
// Register post-walk erasure, if necessary.
@ -1852,9 +1801,10 @@ LogicalResult mlir::linalg::comprehensive_bufferize::runComprehensiveBufferize(
// Bufferization phase.
if (!options.testAnalysisOnly) {
BlockAndValueMapping tensorToBufferMap;
if (failed(bufferizeFuncOpInternals(funcOp, tensorToBufferMap, aliasInfo,
*options.allocationFns,
bufferizedFunctionTypes)))
BufferizationState state(aliasInfo, *options.allocationFns,
tensorToBufferMap);
if (failed(
bufferizeFuncOpInternals(funcOp, state, bufferizedFunctionTypes)))
return failure();
}
}
@ -1926,9 +1876,7 @@ struct ConstantOpInterface
}
LogicalResult bufferize(Operation *op, OpBuilder &b,
BlockAndValueMapping &bvm,
BufferizationAliasInfo &aliasInfo,
AllocationCallbacks &allocationFn) const {
BufferizationState &state) const {
auto constantOp = cast<arith::ConstantOp>(op);
if (!isaTensor(constantOp.getResult().getType()))
return success();
@ -1948,8 +1896,8 @@ struct ConstantOpInterface
auto globalMemref = globalCreator.getGlobalFor(constantOp);
Value memref = b.create<memref::GetGlobalOp>(
constantOp.getLoc(), globalMemref.type(), globalMemref.getName());
aliasInfo.insertNewBufferEquivalence(memref, constantOp.getResult());
map(bvm, constantOp, memref);
state.aliasInfo.insertNewBufferEquivalence(memref, constantOp.getResult());
state.mapBuffer(constantOp, memref);
return success();
}
@ -1969,10 +1917,10 @@ namespace linalg_ext {
/// Helper function for LinalgOp bufferization.
/// When allocating a new buffer, analyze whether `op` wants to read from that
/// buffer. Only in that case may a copy of the result buffer be needed.
static LogicalResult allocateBuffersForResults(
OpBuilder &b, Location loc, LinalgOp op,
SmallVectorImpl<Value> &resultBuffers, BlockAndValueMapping &bvm,
BufferizationAliasInfo &aliasInfo, AllocationCallbacks &allocationFns) {
static LogicalResult
allocateBuffersForResults(OpBuilder &b, Location loc, LinalgOp op,
SmallVectorImpl<Value> &resultBuffers,
BufferizationState &state) {
// Take a guard before anything else.
OpBuilder::InsertionGuard g(b);
b.setInsertionPoint(op);
@ -1983,24 +1931,21 @@ static LogicalResult allocateBuffersForResults(
OpResult opResult = cast<BufferizableOpInterface>(op.getOperation())
.getAliasingOpResult(*opOperand);
assert(opResult && "could not find corresponding OpResult");
Value resultBuffer =
getResultBuffer(b, opResult, bvm, aliasInfo, allocationFns);
Value resultBuffer = getResultBuffer(b, opResult, state);
if (!resultBuffer)
return failure();
resultBuffers.push_back(resultBuffer);
}
if (op->getNumResults())
map(bvm, op->getResults(), resultBuffers);
state.mapBuffer(op->getResults(), resultBuffers);
return success();
}
/// Generic conversion for any LinalgOp on tensors.
static LogicalResult bufferizeLinalgOp(OpBuilder &b, LinalgOp op,
BlockAndValueMapping &bvm,
BufferizationAliasInfo &aliasInfo,
AllocationCallbacks &allocationFns) {
BufferizationState &state) {
// Take a guard before anything else.
OpBuilder::InsertionGuard g(b);
@ -2017,13 +1962,11 @@ static LogicalResult bufferizeLinalgOp(OpBuilder &b, LinalgOp op,
newInputBuffers.push_back(opOperand->get());
continue;
}
newInputBuffers.push_back(lookup(bvm, opOperand->get()));
assert(newInputBuffers.back() && "missing buffer");
newInputBuffers.push_back(state.lookupBuffer(opOperand->get()));
}
SmallVector<Value> newOutputBuffers;
// Try to allocate new buffers depending on op's inplace semantics.
if (failed(allocateBuffersForResults(b, loc, op, newOutputBuffers, bvm,
aliasInfo, allocationFns)))
if (failed(allocateBuffersForResults(b, loc, op, newOutputBuffers, state)))
return failure();
// Clone the newly bufferized op.
@ -2036,7 +1979,7 @@ static LogicalResult bufferizeLinalgOp(OpBuilder &b, LinalgOp op,
// Replace the results of the old op with the new output buffers.
if (op->getNumResults())
map(bvm, op->getResults(), newOutputBuffers);
state.mapBuffer(op->getResults(), newOutputBuffers);
// The original op will be DCE'd away later.
@ -2087,11 +2030,8 @@ struct LinalgOpInterface
}
LogicalResult bufferize(Operation *op, OpBuilder &b,
BlockAndValueMapping &bvm,
BufferizationAliasInfo &aliasInfo,
AllocationCallbacks &allocationFn) const {
return bufferizeLinalgOp(b, cast<LinalgOp>(op), bvm, aliasInfo,
allocationFn);
BufferizationState &state) const {
return bufferizeLinalgOp(b, cast<LinalgOp>(op), state);
}
};
@ -2109,9 +2049,7 @@ struct InitTensorOpInterface
}
LogicalResult bufferize(Operation *op, OpBuilder &b,
BlockAndValueMapping &bvm,
BufferizationAliasInfo &aliasInfo,
AllocationCallbacks &allocationFn) const {
BufferizationState &state) const {
auto initTensorOp = cast<linalg::InitTensorOp>(op);
// The InitTensorOp may have been eliminated.
@ -2123,9 +2061,8 @@ struct InitTensorOpInterface
b.setInsertionPoint(initTensorOp);
Value alloc = createNewAllocDeallocPairForShapedValue(
b, initTensorOp->getLoc(), initTensorOp.result(), aliasInfo,
allocationFn);
map(bvm, initTensorOp.result(), alloc);
b, initTensorOp->getLoc(), initTensorOp.result(), state);
state.mapBuffer(initTensorOp.result(), alloc);
return success();
}
};
@ -2178,9 +2115,7 @@ struct TiledLoopOpInterface
bool isAllocationHoistingBarrier(Operation *op) const { return true; }
LogicalResult bufferize(Operation *op, OpBuilder &b,
BlockAndValueMapping &bvm,
BufferizationAliasInfo &aliasInfo,
AllocationCallbacks &allocationFn) const {
BufferizationState &state) const {
auto tiledLoopOp = cast<linalg::TiledLoopOp>(op);
// Take a guard before anything else.
@ -2222,15 +2157,14 @@ struct TiledLoopOpInterface
const OpResult &opResult = tiledLoopOp->getResult(resultIndex);
OpOperand &yieldOperand = yieldOp->getOpOperand(resultIndex);
Value resultBuffer =
getResultBuffer(b, opResult, bvm, aliasInfo, allocationFn);
Value resultBuffer = getResultBuffer(b, opResult, state);
if (!resultBuffer)
return failure();
// Insert mapping and aliasing info.
aliasInfo.createAliasInfoEntry(resultBuffer);
aliasInfo.insertNewBufferEquivalence(opResult, resultBuffer);
map(bvm, opResult, resultBuffer);
state.aliasInfo.createAliasInfoEntry(resultBuffer);
state.aliasInfo.insertNewBufferEquivalence(opResult, resultBuffer);
state.mapBuffer(opResult, resultBuffer);
// Insert new operand and bbArg.
tiledLoopOp->insertOperands(nextOutputOperandIndex, resultBuffer);
@ -2238,9 +2172,10 @@ struct TiledLoopOpInterface
body->insertArgument(nextOutputBBArgIndex, resultBuffer.getType());
BlockArgument oldTensorBBArg = body->getArgument(oldOutputBBArgIndex);
// Insert mapping and aliasing info.
aliasInfo.createAliasInfoEntry(newBufferBBArg);
aliasInfo.insertNewBufferEquivalence(oldTensorBBArg, newBufferBBArg);
map(bvm, oldTensorBBArg, newBufferBBArg);
state.aliasInfo.createAliasInfoEntry(newBufferBBArg);
state.aliasInfo.insertNewBufferEquivalence(oldTensorBBArg,
newBufferBBArg);
state.mapBuffer(oldTensorBBArg, newBufferBBArg);
// Set operand of `linalg.yield` to the bbArg so it just canonicalizes
// away later.
@ -2268,8 +2203,7 @@ struct TiledLoopOpInterface
continue;
}
Value inputBuffer = lookup(bvm, oldInputTensor);
assert(inputBuffer && " missing buffer for operand");
Value inputBuffer = state.lookupBuffer(oldInputTensor);
// Insert new operand and bbArg.
tiledLoopOp->insertOperands(nextInputOperandIndex, inputBuffer);
@ -2278,9 +2212,10 @@ struct TiledLoopOpInterface
BlockArgument oldTensorBBArg = body->getArgument(oldInputBBArgIndex);
// Insert mapping and aliasing info.
aliasInfo.createAliasInfoEntry(newBufferBBArg);
aliasInfo.insertNewBufferEquivalence(oldTensorBBArg, newBufferBBArg);
map(bvm, oldTensorBBArg, newBufferBBArg);
state.aliasInfo.createAliasInfoEntry(newBufferBBArg);
state.aliasInfo.insertNewBufferEquivalence(oldTensorBBArg,
newBufferBBArg);
state.mapBuffer(oldTensorBBArg, newBufferBBArg);
// Increment indices.
++numNewInputBuffers;
@ -2318,9 +2253,7 @@ struct YieldOpInterface
}
LogicalResult bufferize(Operation *op, OpBuilder &b,
BlockAndValueMapping &bvm,
BufferizationAliasInfo &aliasInfo,
AllocationCallbacks &allocationFn) const {
BufferizationState &state) const {
auto yieldOp = cast<linalg::YieldOp>(op);
// Take a guard before anything else.
@ -2394,9 +2327,7 @@ struct IfOpInterface
}
LogicalResult bufferize(Operation *op, OpBuilder &b,
BlockAndValueMapping &bvm,
BufferizationAliasInfo &aliasInfo,
AllocationCallbacks &allocationFn) const {
BufferizationState &state) const {
// scf::IfOp is bufferized after scf::YieldOp in the else branch.
return success();
}
@ -2405,9 +2336,7 @@ struct IfOpInterface
/// Bufferize the scf::IfOp. This function is called after the YieldOp was
/// bufferized.
static LogicalResult bufferizeIfOp(scf::IfOp ifOp, OpBuilder &b,
BlockAndValueMapping &bvm,
BufferizationAliasInfo &aliasInfo,
AllocationCallbacks &allocationFn) {
BufferizationState &state) {
// Take a guard before anything else.
OpBuilder::InsertionGuard g(b);
b.setInsertionPoint(ifOp);
@ -2420,13 +2349,12 @@ static LogicalResult bufferizeIfOp(scf::IfOp ifOp, OpBuilder &b,
assert(opResult.getType().isa<RankedTensorType>() &&
"unsupported unranked tensor");
Value resultBuffer =
getResultBuffer(b, opResult, bvm, aliasInfo, allocationFn);
Value resultBuffer = getResultBuffer(b, opResult, state);
if (!resultBuffer)
return failure();
aliasInfo.createAliasInfoEntry(resultBuffer);
map(bvm, opResult, resultBuffer);
state.aliasInfo.createAliasInfoEntry(resultBuffer);
state.mapBuffer(opResult, resultBuffer);
}
return success();
@ -2477,9 +2405,7 @@ struct ForOpInterface
}
LogicalResult bufferize(Operation *op, OpBuilder &b,
BlockAndValueMapping &bvm,
BufferizationAliasInfo &aliasInfo,
AllocationCallbacks &allocationFn) const {
BufferizationState &state) const {
// Note: This method is just setting up the mappings for the block arguments
// and the result buffer. The op is bufferized after the scf::YieldOp.
@ -2497,17 +2423,16 @@ struct ForOpInterface
"unsupported unranked tensor");
// TODO: More general: Matching bbArg does not bufferize to a read.
Value resultBuffer =
getResultBuffer(b, opResult, bvm, aliasInfo, allocationFn);
Value resultBuffer = getResultBuffer(b, opResult, state);
if (!resultBuffer)
return failure();
OpOperand &opOperand = forOp.getOpOperandForResult(opResult);
BlockArgument bbArg = forOp.getRegionIterArgForOpOperand(opOperand);
aliasInfo.createAliasInfoEntry(resultBuffer);
aliasInfo.insertNewBufferEquivalence(bbArg, resultBuffer);
map(bvm, bbArg, resultBuffer);
map(bvm, opResult, resultBuffer);
state.aliasInfo.createAliasInfoEntry(resultBuffer);
state.aliasInfo.insertNewBufferEquivalence(bbArg, resultBuffer);
state.mapBuffer(bbArg, resultBuffer);
state.mapBuffer(opResult, resultBuffer);
}
return success();
@ -2517,9 +2442,7 @@ struct ForOpInterface
/// Bufferize the scf::ForOp. This function is called after the YieldOp was
/// bufferized.
static LogicalResult bufferizeForOp(scf::ForOp forOp, OpBuilder &b,
BlockAndValueMapping &bvm,
BufferizationAliasInfo &aliasInfo,
AllocationCallbacks &allocationFn) {
BufferizationState &state) {
auto yieldOp = cast<scf::YieldOp>(&forOp.region().front().back());
for (OpOperand &operand : yieldOp->getOpOperands()) {
auto tensorType = operand.get().getType().dyn_cast<TensorType>();
@ -2529,9 +2452,10 @@ static LogicalResult bufferizeForOp(scf::ForOp forOp, OpBuilder &b,
OpOperand &forOperand = forOp.getOpOperandForResult(
forOp->getResult(operand.getOperandNumber()));
auto bbArg = forOp.getRegionIterArgForOpOperand(forOperand);
Value yieldedBuffer = lookup(bvm, operand.get());
Value bbArgBuffer = lookup(bvm, bbArg);
if (!aliasInfo.areEquivalentBufferizedValues(yieldedBuffer, bbArgBuffer)) {
Value yieldedBuffer = state.lookupBuffer(operand.get());
Value bbArgBuffer = state.lookupBuffer(bbArg);
if (!state.aliasInfo.areEquivalentBufferizedValues(yieldedBuffer,
bbArgBuffer)) {
// TODO: this could get resolved with copies but it can also turn into
// swaps so we need to be careful about order of copies.
return yieldOp->emitError()
@ -2567,9 +2491,7 @@ struct YieldOpInterface
}
LogicalResult bufferize(Operation *op, OpBuilder &b,
BlockAndValueMapping &bvm,
BufferizationAliasInfo &aliasInfo,
AllocationCallbacks &allocationFn) const {
BufferizationState &state) const {
auto yieldOp = cast<scf::YieldOp>(op);
if (auto execOp = dyn_cast<scf::ExecuteRegionOp>(yieldOp->getParentOp())) {
@ -2584,12 +2506,12 @@ struct YieldOpInterface
if (auto ifOp = dyn_cast<scf::IfOp>(yieldOp->getParentOp())) {
if (ifOp.elseYield() != yieldOp)
return success();
return bufferizeIfOp(ifOp, b, bvm, aliasInfo, allocationFn);
return bufferizeIfOp(ifOp, b, state);
}
// Bufferize scf::ForOp after bufferizing the scf::YieldOp.
if (auto forOp = dyn_cast<scf::ForOp>(yieldOp->getParentOp()))
return bufferizeForOp(forOp, b, bvm, aliasInfo, allocationFn);
return bufferizeForOp(forOp, b, state);
return yieldOp->emitError("expected scf::ForOp parent for scf::YieldOp");
}
@ -2635,9 +2557,7 @@ struct CallOpInterface
}
LogicalResult bufferize(Operation *op, OpBuilder &b,
BlockAndValueMapping &bvm,
BufferizationAliasInfo &aliasInfo,
AllocationCallbacks &allocationFn) const {
BufferizationState &state) const {
llvm_unreachable("CallOps are handled separately");
return failure();
}
@ -2659,9 +2579,7 @@ struct ReturnOpInterface
}
LogicalResult bufferize(Operation *op, OpBuilder &b,
BlockAndValueMapping &bvm,
BufferizationAliasInfo &aliasInfo,
AllocationCallbacks &allocationFn) const {
BufferizationState &state) const {
auto returnOp = cast<ReturnOp>(op);
// Take a guard before anything else.
@ -2675,12 +2593,11 @@ struct ReturnOpInterface
auto tensorType = operand.get().getType().dyn_cast<TensorType>();
if (!tensorType)
continue;
Value v = lookup(bvm, operand.get());
assert(v && "missing buffer for result");
Value v = state.lookupBuffer(operand.get());
Value returnTensor = b.create<memref::TensorLoadOp>(returnOp.getLoc(), v);
operand.set(returnTensor);
aliasInfo.insertNewBufferEquivalence(returnTensor, v);
map(bvm, returnTensor, v);
state.aliasInfo.insertNewBufferEquivalence(returnTensor, v);
state.mapBuffer(returnTensor, v);
}
return success();
}
@ -2715,17 +2632,14 @@ struct CastOpInterface
}
LogicalResult bufferize(Operation *op, OpBuilder &b,
BlockAndValueMapping &bvm,
BufferizationAliasInfo &aliasInfo,
AllocationCallbacks &allocationFn) const {
BufferizationState &state) const {
auto castOp = cast<tensor::CastOp>(op);
// Take a guard before anything else.
OpBuilder::InsertionGuard g(b);
b.setInsertionPoint(castOp);
Value resultBuffer =
getResultBuffer(b, castOp->getResult(0), bvm, aliasInfo, allocationFn);
Value resultBuffer = getResultBuffer(b, castOp->getResult(0), state);
if (!resultBuffer)
return failure();
Type sourceType = resultBuffer.getType();
@ -2744,8 +2658,8 @@ struct CastOpInterface
castOp.getResult().getType(), layout, memorySpace);
Value res =
b.create<memref::CastOp>(castOp.getLoc(), memRefType, resultBuffer);
aliasInfo.insertNewBufferEquivalence(res, castOp.getResult());
map(bvm, castOp.getResult(), res);
state.aliasInfo.insertNewBufferEquivalence(res, castOp.getResult());
state.mapBuffer(castOp.getResult(), res);
return success();
}
};
@ -2766,9 +2680,7 @@ struct DimOpInterface
}
LogicalResult bufferize(Operation *op, OpBuilder &b,
BlockAndValueMapping &bvm,
BufferizationAliasInfo &aliasInfo,
AllocationCallbacks &allocationFn) const {
BufferizationState &state) const {
auto dimOp = cast<tensor::DimOp>(op);
// Take a guard before anything else.
@ -2776,8 +2688,7 @@ struct DimOpInterface
b.setInsertionPoint(dimOp);
if (dimOp.source().getType().isa<RankedTensorType>()) {
Value v = lookup(bvm, dimOp.source());
assert(v && "missing buffer");
Value v = state.lookupBuffer(dimOp.source());
dimOp.result().replaceAllUsesWith(
b.create<memref::DimOp>(dimOp.getLoc(), v, dimOp.index()));
}
@ -2812,9 +2723,7 @@ struct ExtractSliceOpInterface
}
LogicalResult bufferize(Operation *op, OpBuilder &b,
BlockAndValueMapping &bvm,
BufferizationAliasInfo &aliasInfo,
AllocationCallbacks &allocationFn) const {
BufferizationState &state) const {
auto extractSliceOp = cast<tensor::ExtractSliceOp>(op);
// Take a guard before anything else.
@ -2824,18 +2733,16 @@ struct ExtractSliceOpInterface
Location loc = extractSliceOp.getLoc();
// Bail if source was not bufferized.
Value srcMemref = lookup(bvm, extractSliceOp.source());
if (!srcMemref)
return failure();
Value srcMemref = state.lookupBuffer(extractSliceOp.source());
auto srcMemrefType = srcMemref.getType().cast<MemRefType>();
auto dstTensorType =
extractSliceOp.result().getType().cast<RankedTensorType>();
// If not inplaceable, alloc.
Value alloc;
if (!aliasInfo.isInPlace(extractSliceOp->getResult(0)))
if (!state.aliasInfo.isInPlace(extractSliceOp->getResult(0)))
alloc = createNewAllocDeallocPairForShapedValue(
b, loc, extractSliceOp.result(), aliasInfo, allocationFn);
b, loc, extractSliceOp.result(), state);
// Set insertion point now that potential alloc/dealloc are introduced.
b.setInsertionPoint(extractSliceOp);
@ -2851,17 +2758,18 @@ struct ExtractSliceOpInterface
loc, subviewMemRefType, srcMemref, extractSliceOp.getMixedOffsets(),
extractSliceOp.getMixedSizes(), extractSliceOp.getMixedStrides());
// Insert new alias.
aliasInfo.insertNewBufferAlias(subView, srcMemref);
state.aliasInfo.insertNewBufferAlias(subView, srcMemref);
/// If not inplaceable, copy.
if (alloc) {
// Do not copy if the copied data is never read.
if (isValueRead(extractSliceOp.result()))
allocationFn.memCpyFn(b, extractSliceOp.getLoc(), subView, alloc);
state.allocationFns.memCpyFn(b, extractSliceOp.getLoc(), subView,
alloc);
subView = alloc;
}
map(bvm, extractSliceOp.result(), subView);
state.mapBuffer(extractSliceOp.result(), subView);
return success();
}
};
@ -2882,9 +2790,7 @@ struct ExtractOpInterface
}
LogicalResult bufferize(Operation *op, OpBuilder &b,
BlockAndValueMapping &bvm,
BufferizationAliasInfo &aliasInfo,
AllocationCallbacks &allocationFn) const {
BufferizationState &state) const {
auto extractOp = cast<tensor::ExtractOp>(op);
// Take a guard before anything else.
@ -2892,7 +2798,7 @@ struct ExtractOpInterface
b.setInsertionPoint(extractOp);
Location loc = extractOp.getLoc();
Value srcMemref = lookup(bvm, extractOp.tensor());
Value srcMemref = state.lookupBuffer(extractOp.tensor());
Value l = b.create<memref::LoadOp>(loc, srcMemref, extractOp.indices());
extractOp.replaceAllUsesWith(l);
return success();
@ -2950,9 +2856,7 @@ struct InsertSliceOpInterface
}
LogicalResult bufferize(Operation *op, OpBuilder &b,
BlockAndValueMapping &bvm,
BufferizationAliasInfo &aliasInfo,
AllocationCallbacks &allocationFn) const {
BufferizationState &state) const {
auto insertSliceOp = cast<tensor::InsertSliceOp>(op);
// Take a guard before anything else.
@ -2969,15 +2873,12 @@ struct InsertSliceOpInterface
// TODO: be very loud about it or even consider failing the pass.
// Alloc a copy for `insertSliceOp.dest()`; it will become the result
// buffer.
Value dstMemref = getResultBuffer(b, insertSliceOp->getResult(0), bvm,
aliasInfo, allocationFn);
Value dstMemref = getResultBuffer(b, insertSliceOp->getResult(0), state);
if (!dstMemref)
return failure();
auto dstMemrefType = dstMemref.getType().cast<MemRefType>();
Value srcMemref = lookup(bvm, insertSliceOp.source());
if (!srcMemref)
return failure();
Value srcMemref = state.lookupBuffer(insertSliceOp.source());
auto subviewMemRefType =
memref::SubViewOp::inferRankReducedResultType(
insertSliceOp.getSourceType().getRank(), dstMemrefType,
@ -2991,9 +2892,9 @@ struct InsertSliceOpInterface
// - The result is not inplace. This is the case where the whole tensor is
// cloned and the clone needs to be updated.
// TODO: Is this necessary?
if (!isSourceEquivalentToAMatchingInplaceExtractSliceOp(aliasInfo,
if (!isSourceEquivalentToAMatchingInplaceExtractSliceOp(state.aliasInfo,
insertSliceOp) ||
!aliasInfo.isInPlace(insertSliceOp->getResult(0))) {
!state.aliasInfo.isInPlace(insertSliceOp->getResult(0))) {
LDBG("insert_slice needs extra source copy: " << insertSliceOp.source()
<< " -> copy\n");
// Take a subview of the dst.
@ -3001,11 +2902,12 @@ struct InsertSliceOpInterface
loc, subviewMemRefType, dstMemref, insertSliceOp.getMixedOffsets(),
insertSliceOp.getMixedSizes(), insertSliceOp.getMixedStrides());
// Insert new alias.
aliasInfo.insertNewBufferAlias(subView, dstMemref);
allocationFn.memCpyFn(b, insertSliceOp.getLoc(), srcMemref, subView);
state.aliasInfo.insertNewBufferAlias(subView, dstMemref);
state.allocationFns.memCpyFn(b, insertSliceOp.getLoc(), srcMemref,
subView);
}
map(bvm, insertSliceOp.result(), dstMemref);
state.mapBuffer(insertSliceOp.result(), dstMemref);
return success();
}
@ -3035,9 +2937,7 @@ struct TransferReadOpInterface
}
LogicalResult bufferize(Operation *op, OpBuilder &b,
BlockAndValueMapping &bvm,
BufferizationAliasInfo &aliasInfo,
AllocationCallbacks &allocationFn) const {
BufferizationState &state) const {
auto transferReadOp = cast<vector::TransferReadOp>(op);
// Take a guard before anything else.
@ -3048,8 +2948,7 @@ struct TransferReadOpInterface
return failure();
// TransferReadOp always reads from the bufferized op.source().
Value v = lookup(bvm, transferReadOp.source());
assert(v && "missing buffer");
Value v = state.lookupBuffer(transferReadOp.source());
transferReadOp.sourceMutable().assign(v);
return success();
}
@ -3086,9 +2985,7 @@ struct TransferWriteOpInterface
}
LogicalResult bufferize(Operation *op, OpBuilder &b,
BlockAndValueMapping &bvm,
BufferizationAliasInfo &aliasInfo,
AllocationCallbacks &allocationFn) const {
BufferizationState &state) const {
auto writeOp = cast<vector::TransferWriteOp>(op);
// Take a guard before anything else.
@ -3101,15 +2998,14 @@ struct TransferWriteOpInterface
// Create a new transfer_write on buffer that doesn't have a return value.
// Leave the previous transfer_write to dead code as it still has uses at
// this point.
Value resultBuffer =
getResultBuffer(b, op->getResult(0), bvm, aliasInfo, allocationFn);
Value resultBuffer = getResultBuffer(b, op->getResult(0), state);
if (!resultBuffer)
return failure();
b.create<vector::TransferWriteOp>(
writeOp.getLoc(), writeOp.vector(), resultBuffer, writeOp.indices(),
writeOp.permutation_map(),
writeOp.in_bounds() ? *writeOp.in_bounds() : ArrayAttr());
map(bvm, op->getResult(0), resultBuffer);
state.mapBuffer(op->getResult(0), resultBuffer);
return success();
}