forked from OSchip/llvm-project

commit aeb1c8d0ca (parent 7ac1fd0da9)

[mlir][linalg][bufferize] Group helpers in BufferizationState

This simplifies the signature of `bufferize`.

Differential Revision: https://reviews.llvm.org/D113388
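In short: the three helper objects that were previously threaded through every bufferization entry point (`BlockAndValueMapping`, `BufferizationAliasInfo`, `AllocationCallbacks`) are grouped into a single `BufferizationState`. A condensed before/after of the interface method, extracted from the hunks below:

    // Before: the helpers were passed individually to each bufferize method.
    LogicalResult bufferize(Operation *op, OpBuilder &b,
                            BlockAndValueMapping &bvm,
                            BufferizationAliasInfo &aliasInfo,
                            AllocationCallbacks &allocationFn);

    // After: one state object carries the mapping, alias info, and callbacks.
    LogicalResult bufferize(Operation *op, OpBuilder &b,
                            BufferizationState &state);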
@@ -197,6 +197,8 @@ findValueInReverseUseDefChain(Value value,
 /// is returned regardless of whether it is a memory write or not.
 Value findLastPrecedingWrite(Value value);
 
+struct BufferizationState;
+
 /// Callback functions that are used to allocate/deallocate/copy memory buffers.
 /// Comprehensive Bufferize provides default implementations of these functions.
 // TODO: Could be replaced with a "bufferization strategy" object with virtual
@@ -207,8 +209,7 @@ struct AllocationCallbacks {
   using DeallocationFn = std::function<void(OpBuilder &, Location, Value)>;
   using MemCpyFn = std::function<void(OpBuilder &, Location, Value, Value)>;
   using CreateAllocDeallocFn =
-      std::function<Value(OpBuilder &, Location, Value,
-                          BufferizationAliasInfo &, AllocationCallbacks &)>;
+      std::function<Value(OpBuilder &, Location, Value, BufferizationState &)>;
 
   AllocationCallbacks(AllocationFn allocFn, DeallocationFn deallocFn,
                       MemCpyFn copyFn, CreateAllocDeallocFn allocDeallocFn)
@@ -230,13 +231,40 @@ struct AllocationCallbacks {
   CreateAllocDeallocFn createAllocDeallocFn;
 };
 
+/// BufferizationState keeps track of bufferization state and provides access to
+/// the results of the analysis.
+struct BufferizationState {
+  BufferizationState(BufferizationAliasInfo &aliasInfo,
+                     AllocationCallbacks &allocationFns,
+                     BlockAndValueMapping &tensorToBufferMap)
+      : aliasInfo(aliasInfo), allocationFns(allocationFns),
+        tensorToBufferMap(tensorToBufferMap) {}
+
+  /// Map tensor values to memref buffers.
+  void mapBuffer(ValueRange tensors, ValueRange buffers);
+
+  /// Map a tensor value to a memref buffer.
+  void mapBuffer(Value tensor, Value buffer);
+
+  /// Lookup the memref buffer that is associated to the given tensor value.
+  /// Asserts if no buffer is associated.
+  Value lookupBuffer(Value tensor) const;
+
+  /// `aliasInfo` keeps track of aliasing and equivalent values.
+  BufferizationAliasInfo &aliasInfo;
+
+  /// `allocationFns` contains helper functions for creating alloc ops, dealloc
+  /// ops and memcpy ops.
+  AllocationCallbacks &allocationFns;
+
+  /// The mapping of tensors to buffers.
+  BlockAndValueMapping &tensorToBufferMap;
+};
+
 /// Return the result buffer (memref) for a given OpResult (tensor). Allocate
 /// a new buffer and copy over data from the existing buffer if out-of-place
 /// bufferization is necessary.
-Value getResultBuffer(OpBuilder &b, OpResult result,
-                      const BlockAndValueMapping &bvm,
-                      BufferizationAliasInfo &aliasInfo,
-                      AllocationCallbacks allocationFns);
+Value getResultBuffer(OpBuilder &b, OpResult result, BufferizationState &state);
 
 } // namespace comprehensive_bufferize
 } // namespace linalg
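The construction site in `runComprehensiveBufferize` (further down in this diff) shows how the pieces are bundled once and then threaded through; a condensed sketch, with `funcOp`, `options`, `aliasInfo`, and `bufferizedFunctionTypes` assumed to exist in the surrounding pass as before:

    // Group the existing helpers into one state object...
    BlockAndValueMapping tensorToBufferMap;
    BufferizationState state(aliasInfo, *options.allocationFns,
                             tensorToBufferMap);
    // ...and pass the single object to every bufferization entry point.
    if (failed(bufferizeFuncOpInternals(funcOp, state, bufferizedFunctionTypes)))
      return failure();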
@@ -280,9 +308,7 @@ struct AllocationHoistingBarrierOnly
   bool isWritable(Operation *op, Value value) const { return false; }
 
   LogicalResult bufferize(Operation *op, OpBuilder &b,
-                          BlockAndValueMapping &bvm,
-                          BufferizationAliasInfo &aliasInfo,
-                          AllocationCallbacks &allocationFn) const {
+                          BufferizationState &state) const {
    auto isaTensor = [](Type t) { return t.isa<TensorType>(); };
    if (any_of(op->getOperandTypes(), isaTensor) ||
        any_of(op->getResultTypes(), isaTensor))

@@ -160,8 +160,6 @@ def BufferizableOpInterface : OpInterface<"BufferizableOpInterface"> {
           llvm_unreachable("bufferRelation not implemented");
         }]
       >,
-      // TODO: Simplify method signature: Pass an OpBuilder and a
-      // BufferizationState object.
       InterfaceMethod<
         /*desc=*/[{
           Bufferize this op, i.e., rewrite it into a memref-based equivalent.
@@ -171,9 +169,7 @@ def BufferizableOpInterface : OpInterface<"BufferizableOpInterface"> {
         /*retType=*/"LogicalResult",
         /*methodName=*/"bufferize",
         /*args=*/(ins "OpBuilder &":$b,
-                      "BlockAndValueMapping &":$bvm,
-                      "BufferizationAliasInfo &":$aliasInfo,
-                      "AllocationCallbacks &":$allocationFn),
+                      "BufferizationState &":$state),
         /*methodBody=*/"",
         /*defaultImplementation=*/[{
           llvm_unreachable("bufferize not implemented");

@@ -27,6 +27,8 @@ namespace comprehensive_bufferize {
 // TODO: from some HW description.
 static constexpr int64_t kBufferAlignments = 128;
 
+struct BufferizationState;
+
 /// Analyze the `ops` to determine which OpResults are inplaceable.
 LogicalResult inPlaceAnalysis(SmallVector<Operation *> &ops,
                               BufferizationAliasInfo &aliasInfo,
@@ -55,9 +57,7 @@ std::unique_ptr<AllocationCallbacks> defaultAllocationCallbacks();
 /// `bufferizedFunctionTypes` (resp. `globalCreator`) are expected to be
 /// non-null if `op` is a CallOpInterface (resp. GlobalCreator).
 LogicalResult
-bufferizeOp(Operation *op, BlockAndValueMapping &bvm,
-            BufferizationAliasInfo &aliasInfo,
-            AllocationCallbacks allocationFns,
+bufferizeOp(Operation *op, BufferizationState &state,
             DenseMap<FuncOp, FunctionType> *bufferizedFunctionTypes = nullptr);
 
 /// Register external models implemented for the `BufferizableOpInterface`.

@@ -7,8 +7,11 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h"
+#include "mlir/IR/AsmState.h"
+#include "mlir/IR/BlockAndValueMapping.h"
 #include "mlir/IR/BuiltinOps.h"
 #include "mlir/IR/Operation.h"
 #include "mlir/IR/Value.h"
+#include "llvm/Support/Debug.h"
 
 namespace mlir {
@@ -319,30 +322,28 @@ Value mlir::linalg::comprehensive_bufferize::findLastPrecedingWrite(
 /// a new buffer and copy over data from the existing buffer if out-of-place
 /// bufferization is necessary.
 Value mlir::linalg::comprehensive_bufferize::getResultBuffer(
-    OpBuilder &b, OpResult result, const BlockAndValueMapping &bvm,
-    BufferizationAliasInfo &aliasInfo, AllocationCallbacks allocationFns) {
+    OpBuilder &b, OpResult result, BufferizationState &state) {
   OpBuilder::InsertionGuard guard(b);
   Operation *op = result.getOwner();
   SmallVector<OpOperand *> aliasingOperands = getAliasingOpOperand(result);
   assert(!aliasingOperands.empty() && "could not get aliasing OpOperand");
   OpOperand *opOperand = aliasingOperands.front();
   Value operand = opOperand->get();
-  Value operandBuffer = bvm.lookupOrNull(operand);
-  assert(operandBuffer && "operand buffer not found");
+  Value operandBuffer = state.lookupBuffer(operand);
   // Make sure that all OpOperands are the same buffer. If this is not the case,
   // we would have to materialize a memref value.
   // TODO: Should be looking for checking for "equivalent buffers" instead of
   // operator== here, but equivalent buffers for scf.if yield values are not
   // set up yet.
   if (!llvm::all_of(aliasingOperands, [&](OpOperand *o) {
-        return bvm.lookup(o->get()) == operandBuffer;
+        return state.lookupBuffer(o->get()) == operandBuffer;
       })) {
     op->emitError("result buffer is ambiguous");
     return Value();
   }
 
   // If bufferizing out-of-place, allocate a new buffer.
-  if (!aliasInfo.isInPlace(result)) {
+  if (!state.aliasInfo.isInPlace(result)) {
     // Ops with multiple aliasing operands can currently not bufferize
     // out-of-place.
     assert(
@@ -350,8 +351,8 @@ Value mlir::linalg::comprehensive_bufferize::getResultBuffer(
         "ops with multiple aliasing OpOperands cannot bufferize out-of-place");
     Location loc = op->getLoc();
     // Allocate the result buffer.
-    Value resultBuffer = allocationFns.createAllocDeallocFn(
-        b, loc, operand, aliasInfo, allocationFns);
+    Value resultBuffer =
+        state.allocationFns.createAllocDeallocFn(b, loc, operand, state);
     bool skipCopy = false;
     // Do not copy if the last preceding write of `operand` is an op that does
    // not write (skipping ops that merely create aliases). E.g., InitTensorOp.
@@ -373,7 +374,7 @@ Value mlir::linalg::comprehensive_bufferize::getResultBuffer(
     if (!skipCopy) {
       // Set insertion point now that potential alloc/dealloc are introduced.
       b.setInsertionPoint(op);
-      allocationFns.memCpyFn(b, loc, operandBuffer, resultBuffer);
+      state.allocationFns.memCpyFn(b, loc, operandBuffer, resultBuffer);
     }
     return resultBuffer;
   }
@@ -381,3 +382,39 @@ Value mlir::linalg::comprehensive_bufferize::getResultBuffer(
   // Bufferizing in-place. No need to allocate a new buffer.
   return operandBuffer;
 }
+
+//===----------------------------------------------------------------------===//
+// Bufferization-specific BlockAndValueMapping support with debugging.
+//===----------------------------------------------------------------------===//
+
+/// Wrapper for better debugging.
+void mlir::linalg::comprehensive_bufferize::BufferizationState::mapBuffer(
+    ValueRange tensors, ValueRange buffers) {
+  assert(!tensors.empty() && "unexpected empty tensors");
+  return tensorToBufferMap.map(tensors, buffers);
+}
+
+/// Wrapper for better debugging.
+void mlir::linalg::comprehensive_bufferize::BufferizationState::mapBuffer(
+    Value tensor, Value buffer) {
+  assert(tensor && "unexpected empty tensor");
+  assert(tensor.getType().isa<TensorType>() && "unexpected non-tensor type");
+  return tensorToBufferMap.map(tensor, buffer);
+}
+
+/// Wrapper for better debugging.
+Value mlir::linalg::comprehensive_bufferize::BufferizationState::lookupBuffer(
+    Value tensor) const {
+  // TODO: if key comes from bbArg, forward.
+  assert(tensor.getType().isa<TensorType>() && "unexpected non-tensor type");
+  Value v = tensorToBufferMap.lookupOrNull(tensor);
+
+  if (!v) {
+    // Dump tensor for easier debugging.
+    tensor.dump();
+    llvm_unreachable("tensor is not mapped");
+    return Value();
+  }
+
+  return v;
+}

@@ -172,47 +172,6 @@ static ReturnOp getAssumedUniqueReturnOp(FuncOp funcOp) {
   return returnOp;
 }
 
-//===----------------------------------------------------------------------===//
-// Bufferization-specific BlockAndValueMapping support with debugging.
-//===----------------------------------------------------------------------===//
-
-/// Wrapper for better debugging.
-static void map(BlockAndValueMapping &bvm, ValueRange keys, ValueRange values) {
-  assert(!keys.empty() && "Unexpected empty keys");
-  LDBG("\n\tMap: " << printValueInfo(keys.front())
-                   << "\n\tto: " << printValueInfo(values.front()) << '\n');
-  return bvm.map(keys, values);
-}
-
-/// Wrapper for better debugging.
-static void map(BlockAndValueMapping &bvm, Value key, Value value) {
-  LDBG("\n\tMap: " << printValueInfo(key) << "\n\tto: " << printValueInfo(value)
-                   << '\n');
-  return bvm.map(key, value);
-}
-
-/// Wrapper for better debugging.
-static Value lookup(const BlockAndValueMapping &bvm, Value key) {
-  // TODO: if key comes from bbArg, forward.
-  assert(key.getType().isa<TensorType>());
-  Value v = bvm.lookupOrNull(key);
-  if (v)
-    return v;
-
-  Operation *parentOp;
-  if (auto bbArg = key.dyn_cast<BlockArgument>()) {
-    if (isa<FuncOp>(key.getParentBlock()->getParentOp()))
-      parentOp = key.getParentBlock()->getParentOp();
-    else
-      parentOp = key.getParentBlock()->getParentOp()->getParentOfType<FuncOp>();
-  } else {
-    parentOp = key.getDefiningOp()->getParentOfType<FuncOp>();
-  }
-  LDBG("In func:\n" << *parentOp << "\nNO VALUE FOR KEY: " << key << '\n');
-  (void)parentOp;
-  return Value();
-}
-
 //===----------------------------------------------------------------------===//
 // Bufferization-specific attribute manipulation.
 // These are for testing and debugging only. Bufferization information is
@@ -878,8 +837,7 @@ static MemRefType getAllocationTypeAndShape(OpBuilder &b, Location loc,
 /// `shapedValue.getDefiningOp` (or at the top of the block in case of a
 /// bbArg) and the DeallocOp is at the end of the block.
 static Value createNewAllocDeallocPairForShapedValue(
-    OpBuilder &b, Location loc, Value shapedValue,
-    BufferizationAliasInfo &aliasInfo, AllocationCallbacks &allocationFns) {
+    OpBuilder &b, Location loc, Value shapedValue, BufferizationState &state) {
   // Take a guard before anything else.
   OpBuilder::InsertionGuard g(b);
 
@@ -891,19 +849,19 @@ static Value createNewAllocDeallocPairForShapedValue(
   MemRefType allocMemRefType =
       getAllocationTypeAndShape(b, loc, shapedValue, dynShape);
   Optional<Value> allocated =
-      allocationFns.allocationFn(b, loc, allocMemRefType, dynShape);
+      state.allocationFns.allocationFn(b, loc, allocMemRefType, dynShape);
   // TODO: For now just assert the value is returned. Eventually need to
   // error-propagate.
   assert(allocated && "allocation failed");
   Value casted = allocated.getValue();
   if (memRefType && memRefType != allocMemRefType) {
     casted = b.create<memref::CastOp>(loc, memRefType, allocated.getValue());
-    aliasInfo.insertNewBufferEquivalence(casted, allocated.getValue());
+    state.aliasInfo.insertNewBufferEquivalence(casted, allocated.getValue());
   }
 
   // 2. Create memory deallocation.
   b.setInsertionPoint(allocated.getValue().getParentBlock()->getTerminator());
-  allocationFns.deallocationFn(b, loc, allocated.getValue());
+  state.allocationFns.deallocationFn(b, loc, allocated.getValue());
   return casted;
 }
 
@@ -915,8 +873,7 @@ static Value createNewAllocDeallocPairForShapedValue(
 /// inplaceable. For now, it is the responsibility of the `callOp` bufferization
 /// to allow FuncOp that are inplaceable to write inPlace.
 static LogicalResult
-bufferize(OpBuilder &b, CallOpInterface callOp, BlockAndValueMapping &bvm,
-          BufferizationAliasInfo &aliasInfo, AllocationCallbacks &allocationFns,
+bufferize(OpBuilder &b, CallOpInterface callOp, BufferizationState &state,
           DenseMap<FuncOp, FunctionType> &bufferizedFunctionTypes) {
   FuncOp funcOp = getCalledFunction(callOp);
   assert(isa<CallOp>(callOp.getOperation()) && funcOp &&
@@ -962,14 +919,13 @@ bufferize(OpBuilder &b, CallOpInterface callOp, BlockAndValueMapping &bvm,
     // If return operand is equivalent to some bbArg, no need to return it.
     Value returnVal = returnOperand.get();
     if (BlockArgument bbArg =
-            getEquivalentEnclosingFuncBBArg(returnVal, aliasInfo)) {
+            getEquivalentEnclosingFuncBBArg(returnVal, state.aliasInfo)) {
       Value oldRes = callOp->getResult(returnOperand.getOperandNumber());
       int64_t idx = bbArg.getArgNumber();
-      Value buffer = lookup(bvm, callOp->getOperand(idx));
-      assert(buffer && "expected bufferized value");
+      Value buffer = state.lookupBuffer(callOp->getOperand(idx));
       // Add CallOp operand/result equivalence: this is interprocedural info.
-      aliasInfo.insertNewBufferEquivalence(oldRes, buffer);
-      map(bvm, oldRes, buffer);
+      state.aliasInfo.insertNewBufferEquivalence(oldRes, buffer);
+      state.mapBuffer(oldRes, buffer);
       // Add a TensorLoadOp to kill all uses of the CallOp return.
       // Replace all uses of the CallOp results so we can erase the CallOp.
       // This TensorLoadOp must fold/DCE away or bufferization should be
@@ -978,13 +934,13 @@ bufferize(OpBuilder &b, CallOpInterface callOp, BlockAndValueMapping &bvm,
           b.create<memref::TensorLoadOp>(callOp.getLoc(), buffer);
       oldRes.replaceAllUsesWith(tensorLoad);
       // Add new op equivalence info.
-      aliasInfo.insertNewBufferEquivalence(tensorLoad, buffer);
-      map(bvm, tensorLoad, buffer);
+      state.aliasInfo.insertNewBufferEquivalence(tensorLoad, buffer);
+      state.mapBuffer(tensorLoad, buffer);
       continue;
     }
 
     // TODO: Need to hoist above function boundary.
-    if (Operation *allocOp = getEquivalentAlloc(returnVal, aliasInfo)) {
+    if (Operation *allocOp = getEquivalentAlloc(returnVal, state.aliasInfo)) {
       hoistedArguments.push_back(allocOp->getResult(0));
       continue;
     }
@@ -1023,8 +979,7 @@ bufferize(OpBuilder &b, CallOpInterface callOp, BlockAndValueMapping &bvm,
 
     // Tensor operands are guaranteed to have been buferized.
     int64_t idx = opOperand.getOperandNumber();
-    Value buffer = lookup(bvm, tensorOperand);
-    assert(buffer && "expected bufferized value");
+    Value buffer = state.lookupBuffer(tensorOperand);
 
     // Caller / callee type mistmatch is handled with a CastOp.
     auto memRefType = bufferizedFuncType.getInput(idx);
@@ -1037,8 +992,8 @@ bufferize(OpBuilder &b, CallOpInterface callOp, BlockAndValueMapping &bvm,
       Value castBuffer =
           b.create<memref::CastOp>(callOp.getLoc(), memRefType, buffer);
       // Add new op equivalence info.
-      aliasInfo.insertNewBufferEquivalence(castBuffer, buffer);
-      map(bvm, tensorOperand, castBuffer);
+      state.aliasInfo.insertNewBufferEquivalence(castBuffer, buffer);
+      state.mapBuffer(tensorOperand, castBuffer);
       buffer = castBuffer;
     }
     newOperands.push_back(buffer);
@@ -1054,9 +1009,7 @@ bufferize(OpBuilder &b, CallOpInterface callOp, BlockAndValueMapping &bvm,
 
 /// FuncOp always creates TensorToMemRef ops.
 static LogicalResult bufferize(OpBuilder &b, FuncOp funcOp,
-                               BlockAndValueMapping &bvm,
-                               BufferizationAliasInfo &aliasInfo,
-                               AllocationCallbacks &allocationFn) {
+                               BufferizationState &state) {
   // Take a guard before anything else.
   OpBuilder::InsertionGuard g(b);
   b.setInsertionPointToStart(&funcOp.body().front());
@@ -1072,8 +1025,8 @@ static LogicalResult bufferize(OpBuilder &b, FuncOp funcOp,
             : getContiguousOrUnrankedMemRefType(tensorType);
     Value bufferCast =
         b.create<memref::BufferCastOp>(funcOp.getLoc(), memRefType, bbArg);
-    aliasInfo.insertNewBufferEquivalence(bufferCast, bbArg);
-    map(bvm, bbArg, bufferCast);
+    state.aliasInfo.insertNewBufferEquivalence(bufferCast, bbArg);
+    state.mapBuffer(bbArg, bufferCast);
   }
   return success();
 }
@@ -1230,8 +1183,7 @@ void mlir::linalg::comprehensive_bufferize::defaultMemCpyFn(OpBuilder &b,
 }
 
 LogicalResult mlir::linalg::comprehensive_bufferize::bufferizeOp(
-    Operation *op, BlockAndValueMapping &bvm, BufferizationAliasInfo &aliasInfo,
-    AllocationCallbacks allocationFns,
+    Operation *op, BufferizationState &state,
     DenseMap<FuncOp, FunctionType> *bufferizedFunctionTypes) {
   OpBuilder b(op->getContext());
 
@@ -1241,8 +1193,7 @@ LogicalResult mlir::linalg::comprehensive_bufferize::bufferizeOp(
     if (!bufferizedFunctionTypes)
       llvm_unreachable(
          "null bufferizedFunctionTypes when bufferizing CallOpInterface");
-    return bufferize(b, callOp, bvm, aliasInfo, allocationFns,
-                     *bufferizedFunctionTypes);
+    return bufferize(b, callOp, state, *bufferizedFunctionTypes);
  }
 
   // Skip BufferCast and TensorLoad ops.
@@ -1251,7 +1202,7 @@ LogicalResult mlir::linalg::comprehensive_bufferize::bufferizeOp(
 
   // Bufferize using `BufferizableOpInterface`.
   if (auto bufferizableOp = dyn_cast<BufferizableOpInterface>(op))
-    return bufferizableOp.bufferize(b, bvm, aliasInfo, allocationFns);
+    return bufferizableOp.bufferize(b, state);
 
   // Other op with tensors. No bufferization method specified.
   auto isaTensor = [](Type t) { return t.isa<TensorType>(); };
@@ -1262,23 +1213,21 @@ LogicalResult mlir::linalg::comprehensive_bufferize::bufferizeOp(
 }
 
 static LogicalResult bufferizeFuncOpInternals(
-    FuncOp funcOp, BlockAndValueMapping &bvm, BufferizationAliasInfo &aliasInfo,
-    AllocationCallbacks &allocationFns,
+    FuncOp funcOp, BufferizationState &state,
     DenseMap<FuncOp, FunctionType> &bufferizedFunctionTypes) {
   LLVM_DEBUG(llvm::dbgs() << "\n\n");
   LDBG("Begin BufferizeFuncOpInternals:\n" << funcOp << '\n');
   OpBuilder b(funcOp->getContext());
 
   // Start by bufferizing `funcOp` arguments.
-  if (failed(bufferize(b, funcOp, bvm, aliasInfo, allocationFns)))
+  if (failed(bufferize(b, funcOp, state)))
     return failure();
 
   // Cannot erase ops during the traversal. Do that afterwards.
   SmallVector<Operation *> toErase;
 
   auto walkFunc = [&](Operation *op) -> WalkResult {
-    if (failed(bufferizeOp(op, bvm, aliasInfo, allocationFns,
-                           &bufferizedFunctionTypes)))
+    if (failed(bufferizeOp(op, state, &bufferizedFunctionTypes)))
      return failure();
 
     // Register post-walk erasure, if necessary.
@@ -1852,9 +1801,10 @@ LogicalResult mlir::linalg::comprehensive_bufferize::runComprehensiveBufferize(
   // Bufferization phase.
   if (!options.testAnalysisOnly) {
     BlockAndValueMapping tensorToBufferMap;
-    if (failed(bufferizeFuncOpInternals(funcOp, tensorToBufferMap, aliasInfo,
-                                        *options.allocationFns,
-                                        bufferizedFunctionTypes)))
+    BufferizationState state(aliasInfo, *options.allocationFns,
+                             tensorToBufferMap);
+    if (failed(
+            bufferizeFuncOpInternals(funcOp, state, bufferizedFunctionTypes)))
       return failure();
   }
 }
@@ -1926,9 +1876,7 @@ struct ConstantOpInterface
   }
 
   LogicalResult bufferize(Operation *op, OpBuilder &b,
-                          BlockAndValueMapping &bvm,
-                          BufferizationAliasInfo &aliasInfo,
-                          AllocationCallbacks &allocationFn) const {
+                          BufferizationState &state) const {
     auto constantOp = cast<arith::ConstantOp>(op);
     if (!isaTensor(constantOp.getResult().getType()))
       return success();
@@ -1948,8 +1896,8 @@ struct ConstantOpInterface
     auto globalMemref = globalCreator.getGlobalFor(constantOp);
     Value memref = b.create<memref::GetGlobalOp>(
         constantOp.getLoc(), globalMemref.type(), globalMemref.getName());
-    aliasInfo.insertNewBufferEquivalence(memref, constantOp.getResult());
-    map(bvm, constantOp, memref);
+    state.aliasInfo.insertNewBufferEquivalence(memref, constantOp.getResult());
+    state.mapBuffer(constantOp, memref);
 
     return success();
   }
@@ -1969,10 +1917,10 @@ namespace linalg_ext {
 /// Helper function for LinalgOp bufferization.
 /// When allocating a new buffer, analyze whether `op` wants to read form that
 /// buffer. Only in that case, a copy of the result buffer may be needed.
-static LogicalResult allocateBuffersForResults(
-    OpBuilder &b, Location loc, LinalgOp op,
-    SmallVectorImpl<Value> &resultBuffers, BlockAndValueMapping &bvm,
-    BufferizationAliasInfo &aliasInfo, AllocationCallbacks &allocationFns) {
+static LogicalResult
+allocateBuffersForResults(OpBuilder &b, Location loc, LinalgOp op,
+                          SmallVectorImpl<Value> &resultBuffers,
+                          BufferizationState &state) {
   // Take a guard before anything else.
   OpBuilder::InsertionGuard g(b);
   b.setInsertionPoint(op);
@@ -1983,24 +1931,21 @@ static LogicalResult allocateBuffersForResults(
     OpResult opResult = cast<BufferizableOpInterface>(op.getOperation())
                             .getAliasingOpResult(*opOperand);
     assert(opResult && "could not find correspond OpResult");
-    Value resultBuffer =
-        getResultBuffer(b, opResult, bvm, aliasInfo, allocationFns);
+    Value resultBuffer = getResultBuffer(b, opResult, state);
    if (!resultBuffer)
      return failure();
    resultBuffers.push_back(resultBuffer);
  }
 
   if (op->getNumResults())
-    map(bvm, op->getResults(), resultBuffers);
+    state.mapBuffer(op->getResults(), resultBuffers);
 
   return success();
 }
 
 /// Generic conversion for any LinalgOp on tensors.
 static LogicalResult bufferizeLinalgOp(OpBuilder &b, LinalgOp op,
-                                       BlockAndValueMapping &bvm,
-                                       BufferizationAliasInfo &aliasInfo,
-                                       AllocationCallbacks &allocationFns) {
+                                       BufferizationState &state) {
   // Take a guard before anything else.
   OpBuilder::InsertionGuard g(b);
 
@@ -2017,13 +1962,11 @@ static LogicalResult bufferizeLinalgOp(OpBuilder &b, LinalgOp op,
       newInputBuffers.push_back(opOperand->get());
       continue;
     }
-    newInputBuffers.push_back(lookup(bvm, opOperand->get()));
-    assert(newInputBuffers.back() && "missing buffer");
+    newInputBuffers.push_back(state.lookupBuffer(opOperand->get()));
   }
   SmallVector<Value> newOutputBuffers;
   // Try to allocate new buffers depending on op's inplace semantics.
-  if (failed(allocateBuffersForResults(b, loc, op, newOutputBuffers, bvm,
-                                       aliasInfo, allocationFns)))
+  if (failed(allocateBuffersForResults(b, loc, op, newOutputBuffers, state)))
     return failure();
 
   // Clone the newly bufferized op.
@@ -2036,7 +1979,7 @@ static LogicalResult bufferizeLinalgOp(OpBuilder &b, LinalgOp op,
 
   // Replace the results of the old op with the new output buffers.
   if (op->getNumResults())
-    map(bvm, op->getResults(), newOutputBuffers);
+    state.mapBuffer(op->getResults(), newOutputBuffers);
 
   // The original op will be DCE'd away later.
 
@@ -2087,11 +2030,8 @@ struct LinalgOpInterface
   }
 
   LogicalResult bufferize(Operation *op, OpBuilder &b,
-                          BlockAndValueMapping &bvm,
-                          BufferizationAliasInfo &aliasInfo,
-                          AllocationCallbacks &allocationFn) const {
-    return bufferizeLinalgOp(b, cast<LinalgOp>(op), bvm, aliasInfo,
-                             allocationFn);
+                          BufferizationState &state) const {
+    return bufferizeLinalgOp(b, cast<LinalgOp>(op), state);
   }
 };
 
@@ -2109,9 +2049,7 @@ struct InitTensorOpInterface
   }
 
   LogicalResult bufferize(Operation *op, OpBuilder &b,
-                          BlockAndValueMapping &bvm,
-                          BufferizationAliasInfo &aliasInfo,
-                          AllocationCallbacks &allocationFn) const {
+                          BufferizationState &state) const {
     auto initTensorOp = cast<linalg::InitTensorOp>(op);
 
     // The InitTensorOp may have been eliminated.
@@ -2123,9 +2061,8 @@ struct InitTensorOpInterface
     b.setInsertionPoint(initTensorOp);
 
     Value alloc = createNewAllocDeallocPairForShapedValue(
-        b, initTensorOp->getLoc(), initTensorOp.result(), aliasInfo,
-        allocationFn);
-    map(bvm, initTensorOp.result(), alloc);
+        b, initTensorOp->getLoc(), initTensorOp.result(), state);
+    state.mapBuffer(initTensorOp.result(), alloc);
     return success();
   }
 };
@@ -2178,9 +2115,7 @@ struct TiledLoopOpInterface
   bool isAllocationHoistingBarrier(Operation *op) const { return true; }
 
   LogicalResult bufferize(Operation *op, OpBuilder &b,
-                          BlockAndValueMapping &bvm,
-                          BufferizationAliasInfo &aliasInfo,
-                          AllocationCallbacks &allocationFn) const {
+                          BufferizationState &state) const {
     auto tiledLoopOp = cast<linalg::TiledLoopOp>(op);
 
     // Take a guard before anything else.
@@ -2222,15 +2157,14 @@ struct TiledLoopOpInterface
 
       const OpResult &opResult = tiledLoopOp->getResult(resultIndex);
       OpOperand &yieldOperand = yieldOp->getOpOperand(resultIndex);
-      Value resultBuffer =
-          getResultBuffer(b, opResult, bvm, aliasInfo, allocationFn);
+      Value resultBuffer = getResultBuffer(b, opResult, state);
       if (!resultBuffer)
        return failure();
 
       // Insert mapping and aliasing info.
-      aliasInfo.createAliasInfoEntry(resultBuffer);
-      aliasInfo.insertNewBufferEquivalence(opResult, resultBuffer);
-      map(bvm, opResult, resultBuffer);
+      state.aliasInfo.createAliasInfoEntry(resultBuffer);
+      state.aliasInfo.insertNewBufferEquivalence(opResult, resultBuffer);
+      state.mapBuffer(opResult, resultBuffer);
 
       // Insert new operand and bbArg.
       tiledLoopOp->insertOperands(nextOutputOperandIndex, resultBuffer);
@@ -2238,9 +2172,10 @@ struct TiledLoopOpInterface
           body->insertArgument(nextOutputBBArgIndex, resultBuffer.getType());
       BlockArgument oldTensorBBArg = body->getArgument(oldOutputBBArgIndex);
       // Insert mapping and aliasing info.
-      aliasInfo.createAliasInfoEntry(newBufferBBArg);
-      aliasInfo.insertNewBufferEquivalence(oldTensorBBArg, newBufferBBArg);
-      map(bvm, oldTensorBBArg, newBufferBBArg);
+      state.aliasInfo.createAliasInfoEntry(newBufferBBArg);
+      state.aliasInfo.insertNewBufferEquivalence(oldTensorBBArg,
+                                                 newBufferBBArg);
+      state.mapBuffer(oldTensorBBArg, newBufferBBArg);
 
       // Set operand of `linalg.yield` to the bbArg so it just canonicalizes
       // away later.
@@ -2268,8 +2203,7 @@ struct TiledLoopOpInterface
         continue;
       }
 
-      Value inputBuffer = lookup(bvm, oldInputTensor);
-      assert(inputBuffer && " missing buffer for operand");
+      Value inputBuffer = state.lookupBuffer(oldInputTensor);
 
       // Insert new operand and bbArg.
       tiledLoopOp->insertOperands(nextInputOperandIndex, inputBuffer);
@@ -2278,9 +2212,10 @@ struct TiledLoopOpInterface
       BlockArgument oldTensorBBArg = body->getArgument(oldInputBBArgIndex);
 
       // Insert mapping and aliasing info.
-      aliasInfo.createAliasInfoEntry(newBufferBBArg);
-      aliasInfo.insertNewBufferEquivalence(oldTensorBBArg, newBufferBBArg);
-      map(bvm, oldTensorBBArg, newBufferBBArg);
+      state.aliasInfo.createAliasInfoEntry(newBufferBBArg);
+      state.aliasInfo.insertNewBufferEquivalence(oldTensorBBArg,
+                                                 newBufferBBArg);
+      state.mapBuffer(oldTensorBBArg, newBufferBBArg);
 
       // Increment indices.
       ++numNewInputBuffers;
@@ -2318,9 +2253,7 @@ struct YieldOpInterface
   }
 
   LogicalResult bufferize(Operation *op, OpBuilder &b,
-                          BlockAndValueMapping &bvm,
-                          BufferizationAliasInfo &aliasInfo,
-                          AllocationCallbacks &allocationFn) const {
+                          BufferizationState &state) const {
     auto yieldOp = cast<linalg::YieldOp>(op);
 
     // Take a guard before anything else.
@@ -2394,9 +2327,7 @@ struct IfOpInterface
   }
 
   LogicalResult bufferize(Operation *op, OpBuilder &b,
-                          BlockAndValueMapping &bvm,
-                          BufferizationAliasInfo &aliasInfo,
-                          AllocationCallbacks &allocationFn) const {
+                          BufferizationState &state) const {
     // scf::IfOp is bufferized after scf::YieldOp in the else branch.
     return success();
   }
@@ -2405,9 +2336,7 @@ struct IfOpInterface
 /// Bufferize the scf::IfOp. This function is called after the YieldOp was
 /// bufferized.
 static LogicalResult bufferizeIfOp(scf::IfOp ifOp, OpBuilder &b,
-                                   BlockAndValueMapping &bvm,
-                                   BufferizationAliasInfo &aliasInfo,
-                                   AllocationCallbacks &allocationFn) {
+                                   BufferizationState &state) {
   // Take a guard before anything else.
   OpBuilder::InsertionGuard g(b);
   b.setInsertionPoint(ifOp);
@@ -2420,13 +2349,12 @@ static LogicalResult bufferizeIfOp(scf::IfOp ifOp, OpBuilder &b,
     assert(opResult.getType().isa<RankedTensorType>() &&
            "unsupported unranked tensor");
 
-    Value resultBuffer =
-        getResultBuffer(b, opResult, bvm, aliasInfo, allocationFn);
+    Value resultBuffer = getResultBuffer(b, opResult, state);
    if (!resultBuffer)
      return failure();
 
-    aliasInfo.createAliasInfoEntry(resultBuffer);
-    map(bvm, opResult, resultBuffer);
+    state.aliasInfo.createAliasInfoEntry(resultBuffer);
+    state.mapBuffer(opResult, resultBuffer);
   }
 
   return success();
@@ -2477,9 +2405,7 @@ struct ForOpInterface
   }
 
   LogicalResult bufferize(Operation *op, OpBuilder &b,
-                          BlockAndValueMapping &bvm,
-                          BufferizationAliasInfo &aliasInfo,
-                          AllocationCallbacks &allocationFn) const {
+                          BufferizationState &state) const {
     // Note: This method is just setting up the mappings for the block arguments
     // and the result buffer. The op is bufferized after the scf::YieldOp.
 
@@ -2497,17 +2423,16 @@ struct ForOpInterface
              "unsupported unranked tensor");
 
       // TODO: More general: Matching bbArg does not bufferize to a read.
-      Value resultBuffer =
-          getResultBuffer(b, opResult, bvm, aliasInfo, allocationFn);
+      Value resultBuffer = getResultBuffer(b, opResult, state);
      if (!resultBuffer)
        return failure();
 
       OpOperand &opOperand = forOp.getOpOperandForResult(opResult);
       BlockArgument bbArg = forOp.getRegionIterArgForOpOperand(opOperand);
-      aliasInfo.createAliasInfoEntry(resultBuffer);
-      aliasInfo.insertNewBufferEquivalence(bbArg, resultBuffer);
-      map(bvm, bbArg, resultBuffer);
-      map(bvm, opResult, resultBuffer);
+      state.aliasInfo.createAliasInfoEntry(resultBuffer);
+      state.aliasInfo.insertNewBufferEquivalence(bbArg, resultBuffer);
+      state.mapBuffer(bbArg, resultBuffer);
+      state.mapBuffer(opResult, resultBuffer);
     }
 
     return success();
@@ -2517,9 +2442,7 @@ struct ForOpInterface
 /// Bufferize the scf::ForOp. This function is called after the YieldOp was
 /// bufferized.
 static LogicalResult bufferizeForOp(scf::ForOp forOp, OpBuilder &b,
-                                    BlockAndValueMapping &bvm,
-                                    BufferizationAliasInfo &aliasInfo,
-                                    AllocationCallbacks &allocationFn) {
+                                    BufferizationState &state) {
   auto yieldOp = cast<scf::YieldOp>(&forOp.region().front().back());
   for (OpOperand &operand : yieldOp->getOpOperands()) {
     auto tensorType = operand.get().getType().dyn_cast<TensorType>();
@@ -2529,9 +2452,10 @@ static LogicalResult bufferizeForOp(scf::ForOp forOp, OpBuilder &b,
     OpOperand &forOperand = forOp.getOpOperandForResult(
         forOp->getResult(operand.getOperandNumber()));
     auto bbArg = forOp.getRegionIterArgForOpOperand(forOperand);
-    Value yieldedBuffer = lookup(bvm, operand.get());
-    Value bbArgBuffer = lookup(bvm, bbArg);
-    if (!aliasInfo.areEquivalentBufferizedValues(yieldedBuffer, bbArgBuffer)) {
+    Value yieldedBuffer = state.lookupBuffer(operand.get());
+    Value bbArgBuffer = state.lookupBuffer(bbArg);
+    if (!state.aliasInfo.areEquivalentBufferizedValues(yieldedBuffer,
+                                                       bbArgBuffer)) {
       // TODO: this could get resolved with copies but it can also turn into
       // swaps so we need to be careful about order of copies.
       return yieldOp->emitError()
@@ -2567,9 +2491,7 @@ struct YieldOpInterface
   }
 
   LogicalResult bufferize(Operation *op, OpBuilder &b,
-                          BlockAndValueMapping &bvm,
-                          BufferizationAliasInfo &aliasInfo,
-                          AllocationCallbacks &allocationFn) const {
+                          BufferizationState &state) const {
     auto yieldOp = cast<scf::YieldOp>(op);
 
     if (auto execOp = dyn_cast<scf::ExecuteRegionOp>(yieldOp->getParentOp())) {
@@ -2584,12 +2506,12 @@ struct YieldOpInterface
     if (auto ifOp = dyn_cast<scf::IfOp>(yieldOp->getParentOp())) {
       if (ifOp.elseYield() != yieldOp)
        return success();
-      return bufferizeIfOp(ifOp, b, bvm, aliasInfo, allocationFn);
+      return bufferizeIfOp(ifOp, b, state);
    }
 
     // Bufferize scf::ForOp after bufferizing the scf::YieldOp.
     if (auto forOp = dyn_cast<scf::ForOp>(yieldOp->getParentOp()))
-      return bufferizeForOp(forOp, b, bvm, aliasInfo, allocationFn);
+      return bufferizeForOp(forOp, b, state);
 
     return yieldOp->emitError("expected scf::ForOp parent for scf::YieldOp");
   }
@@ -2635,9 +2557,7 @@ struct CallOpInterface
   }
 
   LogicalResult bufferize(Operation *op, OpBuilder &b,
-                          BlockAndValueMapping &bvm,
-                          BufferizationAliasInfo &aliasInfo,
-                          AllocationCallbacks &allocationFn) const {
+                          BufferizationState &state) const {
     llvm_unreachable("CallOps are handled separately");
     return failure();
   }
@@ -2659,9 +2579,7 @@ struct ReturnOpInterface
   }
 
   LogicalResult bufferize(Operation *op, OpBuilder &b,
-                          BlockAndValueMapping &bvm,
-                          BufferizationAliasInfo &aliasInfo,
-                          AllocationCallbacks &allocationFn) const {
+                          BufferizationState &state) const {
     auto returnOp = cast<ReturnOp>(op);
 
     // Take a guard before anything else.
@@ -2675,12 +2593,11 @@ struct ReturnOpInterface
       auto tensorType = operand.get().getType().dyn_cast<TensorType>();
      if (!tensorType)
        continue;
-      Value v = lookup(bvm, operand.get());
-      assert(v && "missing buffer for result");
+      Value v = state.lookupBuffer(operand.get());
       Value returnTensor = b.create<memref::TensorLoadOp>(returnOp.getLoc(), v);
       operand.set(returnTensor);
-      aliasInfo.insertNewBufferEquivalence(returnTensor, v);
-      map(bvm, returnTensor, v);
+      state.aliasInfo.insertNewBufferEquivalence(returnTensor, v);
+      state.mapBuffer(returnTensor, v);
     }
     return success();
   }
@@ -2715,17 +2632,14 @@ struct CastOpInterface
   }
 
   LogicalResult bufferize(Operation *op, OpBuilder &b,
-                          BlockAndValueMapping &bvm,
-                          BufferizationAliasInfo &aliasInfo,
-                          AllocationCallbacks &allocationFn) const {
+                          BufferizationState &state) const {
     auto castOp = cast<tensor::CastOp>(op);
 
     // Take a guard before anything else.
     OpBuilder::InsertionGuard g(b);
     b.setInsertionPoint(castOp);
 
-    Value resultBuffer =
-        getResultBuffer(b, castOp->getResult(0), bvm, aliasInfo, allocationFn);
+    Value resultBuffer = getResultBuffer(b, castOp->getResult(0), state);
    if (!resultBuffer)
      return failure();
    Type sourceType = resultBuffer.getType();
@@ -2744,8 +2658,8 @@ struct CastOpInterface
         castOp.getResult().getType(), layout, memorySpace);
     Value res =
         b.create<memref::CastOp>(castOp.getLoc(), memRefType, resultBuffer);
-    aliasInfo.insertNewBufferEquivalence(res, castOp.getResult());
-    map(bvm, castOp.getResult(), res);
+    state.aliasInfo.insertNewBufferEquivalence(res, castOp.getResult());
+    state.mapBuffer(castOp.getResult(), res);
     return success();
   }
 };
@@ -2766,9 +2680,7 @@ struct DimOpInterface
   }
 
   LogicalResult bufferize(Operation *op, OpBuilder &b,
-                          BlockAndValueMapping &bvm,
-                          BufferizationAliasInfo &aliasInfo,
-                          AllocationCallbacks &allocationFn) const {
+                          BufferizationState &state) const {
     auto dimOp = cast<tensor::DimOp>(op);
 
     // Take a guard before anything else.
@@ -2776,8 +2688,7 @@ struct DimOpInterface
     b.setInsertionPoint(dimOp);
 
     if (dimOp.source().getType().isa<RankedTensorType>()) {
-      Value v = lookup(bvm, dimOp.source());
-      assert(v && "missing buffer");
+      Value v = state.lookupBuffer(dimOp.source());
       dimOp.result().replaceAllUsesWith(
           b.create<memref::DimOp>(dimOp.getLoc(), v, dimOp.index()));
     }
@@ -2812,9 +2723,7 @@ struct ExtractSliceOpInterface
   }
 
   LogicalResult bufferize(Operation *op, OpBuilder &b,
-                          BlockAndValueMapping &bvm,
-                          BufferizationAliasInfo &aliasInfo,
-                          AllocationCallbacks &allocationFn) const {
+                          BufferizationState &state) const {
     auto extractSliceOp = cast<tensor::ExtractSliceOp>(op);
 
     // Take a guard before anything else.
@@ -2824,18 +2733,16 @@ struct ExtractSliceOpInterface
 
     Location loc = extractSliceOp.getLoc();
     // Bail if source was not bufferized.
-    Value srcMemref = lookup(bvm, extractSliceOp.source());
-    if (!srcMemref)
-      return failure();
+    Value srcMemref = state.lookupBuffer(extractSliceOp.source());
     auto srcMemrefType = srcMemref.getType().cast<MemRefType>();
     auto dstTensorType =
         extractSliceOp.result().getType().cast<RankedTensorType>();
 
     // If not inplaceable, alloc.
     Value alloc;
-    if (!aliasInfo.isInPlace(extractSliceOp->getResult(0)))
+    if (!state.aliasInfo.isInPlace(extractSliceOp->getResult(0)))
       alloc = createNewAllocDeallocPairForShapedValue(
-          b, loc, extractSliceOp.result(), aliasInfo, allocationFn);
+          b, loc, extractSliceOp.result(), state);
 
     // Set insertion point now that potential alloc/dealloc are introduced.
     b.setInsertionPoint(extractSliceOp);
@@ -2851,17 +2758,18 @@ struct ExtractSliceOpInterface
         loc, subviewMemRefType, srcMemref, extractSliceOp.getMixedOffsets(),
         extractSliceOp.getMixedSizes(), extractSliceOp.getMixedStrides());
     // Insert new alias.
-    aliasInfo.insertNewBufferAlias(subView, srcMemref);
+    state.aliasInfo.insertNewBufferAlias(subView, srcMemref);
 
     /// If not inplaceable, copy.
     if (alloc) {
       // Do not copy if the copied data is never read.
       if (isValueRead(extractSliceOp.result()))
-        allocationFn.memCpyFn(b, extractSliceOp.getLoc(), subView, alloc);
+        state.allocationFns.memCpyFn(b, extractSliceOp.getLoc(), subView,
+                                     alloc);
       subView = alloc;
     }
 
-    map(bvm, extractSliceOp.result(), subView);
+    state.mapBuffer(extractSliceOp.result(), subView);
     return success();
   }
 };
@@ -2882,9 +2790,7 @@ struct ExtractOpInterface
   }
 
   LogicalResult bufferize(Operation *op, OpBuilder &b,
-                          BlockAndValueMapping &bvm,
-                          BufferizationAliasInfo &aliasInfo,
-                          AllocationCallbacks &allocationFn) const {
+                          BufferizationState &state) const {
     auto extractOp = cast<tensor::ExtractOp>(op);
 
     // Take a guard before anything else.
@@ -2892,7 +2798,7 @@ struct ExtractOpInterface
     b.setInsertionPoint(extractOp);
 
     Location loc = extractOp.getLoc();
-    Value srcMemref = lookup(bvm, extractOp.tensor());
+    Value srcMemref = state.lookupBuffer(extractOp.tensor());
     Value l = b.create<memref::LoadOp>(loc, srcMemref, extractOp.indices());
     extractOp.replaceAllUsesWith(l);
     return success();
@@ -2950,9 +2856,7 @@ struct InsertSliceOpInterface
   }
 
   LogicalResult bufferize(Operation *op, OpBuilder &b,
-                          BlockAndValueMapping &bvm,
-                          BufferizationAliasInfo &aliasInfo,
-                          AllocationCallbacks &allocationFn) const {
+                          BufferizationState &state) const {
     auto insertSliceOp = cast<tensor::InsertSliceOp>(op);
 
     // Take a guard before anything else.
@@ -2969,15 +2873,12 @@ struct InsertSliceOpInterface
     // TODO: be very loud about it or even consider failing the pass.
     // Alloc a copy for `insertSliceOp.dest()`, it will become the result
     // buffer.
-    Value dstMemref = getResultBuffer(b, insertSliceOp->getResult(0), bvm,
-                                      aliasInfo, allocationFn);
+    Value dstMemref = getResultBuffer(b, insertSliceOp->getResult(0), state);
    if (!dstMemref)
      return failure();
    auto dstMemrefType = dstMemref.getType().cast<MemRefType>();
 
-    Value srcMemref = lookup(bvm, insertSliceOp.source());
-    if (!srcMemref)
-      return failure();
+    Value srcMemref = state.lookupBuffer(insertSliceOp.source());
     auto subviewMemRefType =
         memref::SubViewOp::inferRankReducedResultType(
             insertSliceOp.getSourceType().getRank(), dstMemrefType,
@@ -2991,9 +2892,9 @@ struct InsertSliceOpInterface
     // - The result is not inplace. This is the case where the whole tensor is
     //   cloned and the clone needs to be updated.
     // TODO: Is this necessary?
-    if (!isSourceEquivalentToAMatchingInplaceExtractSliceOp(aliasInfo,
+    if (!isSourceEquivalentToAMatchingInplaceExtractSliceOp(state.aliasInfo,
                                                             insertSliceOp) ||
-        !aliasInfo.isInPlace(insertSliceOp->getResult(0))) {
+        !state.aliasInfo.isInPlace(insertSliceOp->getResult(0))) {
       LDBG("insert_slice needs extra source copy: " << insertSliceOp.source()
                                                     << " -> copy\n");
       // Take a subview of the dst.
@@ -3001,11 +2902,12 @@ struct InsertSliceOpInterface
           loc, subviewMemRefType, dstMemref, insertSliceOp.getMixedOffsets(),
           insertSliceOp.getMixedSizes(), insertSliceOp.getMixedStrides());
       // Insert new alias.
-      aliasInfo.insertNewBufferAlias(subView, dstMemref);
-      allocationFn.memCpyFn(b, insertSliceOp.getLoc(), srcMemref, subView);
+      state.aliasInfo.insertNewBufferAlias(subView, dstMemref);
+      state.allocationFns.memCpyFn(b, insertSliceOp.getLoc(), srcMemref,
+                                   subView);
    }
 
-    map(bvm, insertSliceOp.result(), dstMemref);
+    state.mapBuffer(insertSliceOp.result(), dstMemref);
 
     return success();
   }
@@ -3035,9 +2937,7 @@ struct TransferReadOpInterface
   }
 
   LogicalResult bufferize(Operation *op, OpBuilder &b,
-                          BlockAndValueMapping &bvm,
-                          BufferizationAliasInfo &aliasInfo,
-                          AllocationCallbacks &allocationFn) const {
+                          BufferizationState &state) const {
     auto transferReadOp = cast<vector::TransferReadOp>(op);
 
     // Take a guard before anything else.
@@ -3048,8 +2948,7 @@ struct TransferReadOpInterface
       return failure();
 
     // TransferReadOp always reads from the bufferized op.source().
-    Value v = lookup(bvm, transferReadOp.source());
-    assert(v && "missing buffer");
+    Value v = state.lookupBuffer(transferReadOp.source());
     transferReadOp.sourceMutable().assign(v);
     return success();
   }
@@ -3086,9 +2985,7 @@ struct TransferWriteOpInterface
   }
 
   LogicalResult bufferize(Operation *op, OpBuilder &b,
-                          BlockAndValueMapping &bvm,
-                          BufferizationAliasInfo &aliasInfo,
-                          AllocationCallbacks &allocationFn) const {
+                          BufferizationState &state) const {
     auto writeOp = cast<vector::TransferWriteOp>(op);
 
     // Take a guard before anything else.
@@ -3101,15 +2998,14 @@ struct TransferWriteOpInterface
     // Create a new transfer_write on buffer that doesn't have a return value.
     // Leave the previous transfer_write to dead code as it still has uses at
     // this point.
-    Value resultBuffer =
-        getResultBuffer(b, op->getResult(0), bvm, aliasInfo, allocationFn);
+    Value resultBuffer = getResultBuffer(b, op->getResult(0), state);
    if (!resultBuffer)
      return failure();
    b.create<vector::TransferWriteOp>(
        writeOp.getLoc(), writeOp.vector(), resultBuffer, writeOp.indices(),
        writeOp.permutation_map(),
        writeOp.in_bounds() ? *writeOp.in_bounds() : ArrayAttr());
-    map(bvm, op->getResult(0), resultBuffer);
+    state.mapBuffer(op->getResult(0), resultBuffer);
 
     return success();
   }