forked from OSchip/llvm-project
[mlir][linalg][bufferize] Replace remaining bvm usage with new API
* Call `replaceOp` instead of `mapBuffer`.
* Remove bvm and all helper functions around bvm.
* Simplify FuncOp bufferization and rely on existing functionality to generate ToMemrefOps for function BlockArguments.

Differential Revision: https://reviews.llvm.org/D115515
This commit is contained in:
parent 354e5cf776, commit 417014170b
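The shape of the change, condensed into a minimal before/after sketch (taken from the tensor.cast hunk further down in this diff; it is an illustration for readers, not an additional change in this commit):

// Before: create the replacement op, then record the tensor->buffer mapping
// in the BlockAndValueMapping (bvm) so that later lookupBuffer calls find it.
Value res =
    b.create<memref::CastOp>(castOp.getLoc(), memRefType, resultBuffer);
state.mapBuffer(castOp.getResult(), res);

// After: replace the tensor op with the new memref op in one step; the old op
// is deleted and no separate tensor-to-buffer mapping has to be recorded.
state.replaceOpWithNewOp<memref::CastOp>(b, op, memRefType, resultBuffer);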
@@ -268,6 +268,9 @@ private:
llvm::EquivalenceClasses<Value, ValueComparator> equivalentInfo;
};

/// Return `true` if the given value is a BlockArgument of a FuncOp.
bool isFunctionArgument(Value value);

/// Determine which OpOperand* will alias with `result` if the op is bufferized
/// in place. Return an empty vector if the op is not bufferizable.
SmallVector<OpOperand *> getAliasingOpOperand(OpResult result);

@@ -342,18 +345,18 @@ struct DialectBufferizationState {
DialectBufferizationState(const DialectBufferizationState &) = delete;
};

/// BufferizationState keeps track of memory buffers and provides a variety of
/// helper functions for dealing with them. In particular,
/// BufferizationState provides a variety of helper functions for dealing with
/// tensor values and memref buffers. In particular,
/// `BufferizableOpInterface::bufferize` implementation should utilize the
/// following helper functions.
///
/// * `createAlloc` / `createDealloc` / `createAllocDeallocPair` creates ops
/// that allocate and/or deallocate memref buffers.
/// * `mapBuffer` maps a tensor value to a memref buffer during bufferization.
/// * `lookupBuffer` returns the mapped memref buffer of a given tensor value.
/// * `lookupBuffer` returns the memref buffer of a given tensor value.
/// * `getResultBuffer` returns the memref buffer for a given tensor OpResult.
/// Based on inplace bufferization decisions of the analysis, it may either
/// directly return a mapped buffer or allocate a new brand new buffer.
/// * `replaceOp` replaces an op with new values.
class BufferizationState {
public:
BufferizationState(Operation *op, const BufferizationOptions &options)

@@ -378,16 +381,19 @@ public:
/// Creates a memcpy between two given buffers.
void createMemCpy(OpBuilder &b, Location loc, Value from, Value to);

/// Replace an op with replacement values. The op is deleted.
/// Replace an op with replacement values. The op is deleted. Tensor OpResults
/// must be replaced with memref values.
void replaceOp(Operation *op, ValueRange values);

/// Map tensor values to memref buffers.
// TODO: Deprecated. Remove all uses of this op. Use `replaceOp` instead.
void mapBuffer(ValueRange tensors, ValueRange buffers);

/// Map a tensor value to a memref buffer.
// TODO: Deprecated. Remove all uses of this op. Use `replaceOp` instead.
void mapBuffer(Value tensor, Value buffer);
/// Replace an op with a new op. Tensor OpResults must be replaced with memref
/// values.
template <typename OpTy, typename... Args>
OpTy replaceOpWithNewOp(OpBuilder &b, Operation *op, Args &&...args) {
Operation *newOp =
b.create<OpTy>(op->getLoc(), std::forward<Args>(args)...);
replaceOp(op, newOp->getResults());
return cast<OpTy>(newOp);
}

/// Lookup the memref buffer that is associated to the given tensor value.
/// Asserts if no buffer is associated.
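A usage sketch of the helpers declared above, mirroring the tensor.insert hunk later in this diff (together with `getResultBuffer` from the following hunk); this is an illustration only, not new API:

LogicalResult bufferize(Operation *op, OpBuilder &b,
                        BufferizationState &state) const {
  auto insertOp = cast<tensor::InsertOp>(op);
  // Buffer that the tensor result bufferizes into: depending on the inplace
  // analysis this is the operand's buffer or a newly allocated copy.
  Value destMemref = state.getResultBuffer(insertOp->getOpResult(0));
  // Perform the actual write on the buffer.
  b.create<memref::StoreOp>(insertOp.getLoc(), insertOp.scalar(), destMemref,
                            insertOp.indices());
  // Replace the tensor op with the buffer; the op itself is deleted.
  state.replaceOp(op, destMemref);
  return success();
}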
@@ -396,23 +402,11 @@ public:
/// Return `true` if the given OpResult has been decided to bufferize inplace.
bool isInPlace(OpResult opResult) const;

/// Return `true` if the given value is mapped.
// TODO: Deprecated. Remove all uses of this op.
bool isMapped(Value value) const;

/// Return the result buffer (memref) for a given OpResult (tensor). Allocate
/// a new buffer and copy over data from the existing buffer if out-of-place
/// bufferization is necessary.
Value getResultBuffer(OpResult result);

/// Mark `op` as obsolete, so that it is deleted after bufferization.
// TODO: Deprecated. Remove all uses of this op.
void markOpObsolete(Operation *op);

/// Erase all ops that were marked obsolete.
// TODO: Deprecated. Remove all uses of this op.
void eraseObsoleteOps();

/// Return dialect-specific bufferization state.
template <typename StateT> StateT &getDialectState(StringRef name) {
// Create state if it does not exist yet.
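A hypothetical use of `getDialectState` (the struct and key names below are illustrative and not part of this change; any default-constructible type deriving from DialectBufferizationState can be stored, keyed by a dialect name):

// Illustrative dialect-specific state.
struct MyDialectState : public DialectBufferizationState {
  DenseMap<Operation *, int64_t> loopDepth;
};

// Inside a BufferizableOpInterface implementation:
//   auto &myState = state.getDialectState<MyDialectState>("mydialect");
//   myState.loopDepth[op] = 2;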
@@ -441,12 +435,6 @@ private:
/// functions and `runComprehensiveBufferize` may access this object.
BufferizationAliasInfo aliasInfo;

/// The mapping of tensors to buffers.
BlockAndValueMapping mapping;

/// Obsolete ops that should be deleted after bufferization.
SmallVector<Operation *> obsoleteOps;

/// Dialect-specific bufferization state.
DenseMap<StringRef, std::unique_ptr<DialectBufferizationState>> dialectState;

@@ -35,10 +35,8 @@ struct ConstantOpInterface

GlobalCreator globalCreator(moduleOp);
auto globalMemref = globalCreator.getGlobalFor(constantOp);
Value memref = b.create<memref::GetGlobalOp>(
constantOp.getLoc(), globalMemref.type(), globalMemref.getName());
state.mapBuffer(constantOp, memref);
state.replaceOpWithNewOp<memref::GetGlobalOp>(b, op, globalMemref.type(),
globalMemref.getName());
return success();
}
@@ -498,19 +498,6 @@ mlir::linalg::comprehensive_bufferize::bufferize(Operation *op,
if (!state.getOptions().allowUnknownOps)
return op->emitError() << "unsupported op with tensors";

// Replace all OpOperands with "to-tensor casted" bufferized values.
for (OpOperand &operand : op->getOpOperands()) {
if (operand.get().getType().isa<TensorType>() &&
state.isMapped(operand.get())) {
assert(state.getOptions().allowUnknownOps &&
"unsupported op error should have been emitted earlier");
b.setInsertionPoint(op);
Value toTensorOp = b.create<bufferization::ToTensorOp>(
op->getLoc(), state.lookupBuffer(operand.get()));
operand.set(toTensorOp);
}
}

// Bufferize all regions.
for (Region &region : op->getRegions())
if (failed(bufferize(&region, state)))

@@ -654,38 +641,13 @@ void mlir::linalg::comprehensive_bufferize::BufferizationState::createMemCpy(
// Bufferization-specific BlockAndValueMapping support with debugging.
//===----------------------------------------------------------------------===//

/// Wrapper for better debugging.
void mlir::linalg::comprehensive_bufferize::BufferizationState::mapBuffer(
ValueRange tensors, ValueRange buffers) {
assert(!tensors.empty() && "unexpected empty tensors");
#ifndef NDEBUG
for (Value tensor : tensors) {
assert(tensor && "unexpected empty tensor");
assert(tensor.getType().isa<TensorType>() && "unexpected non-tensor type");
}
for (Value buffer : buffers) {
assert(buffer && "unexpected empty buffer");
assert((buffer.getType().isa<MemRefType>() ||
buffer.getType().isa<UnrankedMemRefType>()) &&
"expected that tensor is mapped to memref");
}
#endif // NDEBUG
return mapping.map(tensors, buffers);
bool mlir::linalg::comprehensive_bufferize::isFunctionArgument(Value value) {
auto bbArg = value.dyn_cast<BlockArgument>();
if (!bbArg)
return false;
return isa<FuncOp>(bbArg.getOwner()->getParentOp());
}

/// Wrapper for better debugging.
void mlir::linalg::comprehensive_bufferize::BufferizationState::mapBuffer(
Value tensor, Value buffer) {
assert(tensor && "unexpected empty tensor");
assert(tensor.getType().isa<TensorType>() && "unexpected non-tensor type");
assert(buffer && "unexpected empty buffer");
assert((buffer.getType().isa<MemRefType>() ||
buffer.getType().isa<UnrankedMemRefType>()) &&
"expected that tensor is mapped to memref");
return mapping.map(tensor, buffer);
}

/// Wrapper for better debugging.
Value mlir::linalg::comprehensive_bufferize::BufferizationState::lookupBuffer(
Value tensor) {
assert(tensor.getType().isa<TensorType>() && "unexpected non-tensor type");

@@ -694,37 +656,29 @@ Value mlir::linalg::comprehensive_bufferize::BufferizationState::lookupBuffer(
if (auto toTensorOp = tensor.getDefiningOp<bufferization::ToTensorOp>())
return toTensorOp.memref();

Value buffer = mapping.lookupOrNull(tensor);
if (!buffer) {
if (options.allowUnknownOps) {
// `tensor` was not bufferized yet. This should never happen with
// bufferizable ops.
assert(!options.dynCastBufferizableOp(tensor) && "tensor is not mapped");
// Insert to_memref op.
OpBuilder b(tensor.getContext());
setInsertionPointAfter(b, tensor);
return b.create<bufferization::ToMemrefOp>(
tensor.getLoc(),
getDynamicMemRefType(tensor.getType().cast<RankedTensorType>()),
tensor);
if (!isFunctionArgument(tensor)) {
if (static_cast<bool>(options.dynCastBufferizableOp(tensor))) {
// Dump tensor for easier debugging.
tensor.dump();
llvm_unreachable("op is known, but has not been bufferized yet");
return Value();
}
if (!options.allowUnknownOps) {
// Dump tensor for easier debugging.
tensor.dump();
// Note: An assertion should already have failed earlier.
llvm_unreachable("unknown ops are not allowed");
return Value();
}

// Dump tensor for easier debugging.
tensor.dump();
llvm_unreachable("tensor is not mapped");
return Value();
}

assert((buffer.getType().isa<MemRefType>() ||
buffer.getType().isa<UnrankedMemRefType>()) &&
"expected that tensor is mapped to memref");
return buffer;
}

bool mlir::linalg::comprehensive_bufferize::BufferizationState::isMapped(
Value value) const {
assert(value.getType().isa<TensorType>() && "unexpected non-tensor type");
return mapping.contains(value);
// Insert to_memref op.
OpBuilder &b = getBuilder();
OpBuilder::InsertionGuard g(b);
setInsertionPointAfter(b, tensor);
return b.create<bufferization::ToMemrefOp>(
tensor.getLoc(),
getDynamicMemRefType(tensor.getType().cast<RankedTensorType>()), tensor);
}

bool mlir::linalg::comprehensive_bufferize::BufferizationState::isInPlace(

@@ -732,18 +686,6 @@ bool mlir::linalg::comprehensive_bufferize::BufferizationState::isInPlace(
return aliasInfo.isInPlace(opResult);
}

void mlir::linalg::comprehensive_bufferize::BufferizationState::markOpObsolete(
Operation *op) {
obsoleteOps.push_back(op);
}

void mlir::linalg::comprehensive_bufferize::BufferizationState::
eraseObsoleteOps() {
for (Operation *op : obsoleteOps)
op->erase();
obsoleteOps.clear();
}

MemRefType mlir::linalg::comprehensive_bufferize::getContiguousMemRefType(
ShapedType shapedType, MemRefLayoutAttrInterface layout,
Attribute memorySpace) {
@@ -63,21 +63,9 @@ struct ToMemrefOpInterface
// If a ToMemrefOp's tensor operand has not been bufferized yet, the op
// remains unchanged. All IR up to this ToMemrefOp has already been
// bufferized, unless there were unknown ops that could be bufferized.
if (!state.isMapped(toMemrefOp.tensor())) {
assert(state.getOptions().allowUnknownOps &&
"expected that tensor is mapped");
return success();
}

// If a ToMemrefOp's tensor operand has been bufferized, the op can be
// removed.
Value memref = state.lookupBuffer(toMemrefOp.tensor());
// Do not replace a ToMemrefOp with itself. E.g., when bufferizing a
// function body, ToMemrefOps were inserted before starting bufferization of
// the function body. Such ToMemrefOps are replaced in a separate step after
// the function body has been bufferized.
if (toMemrefOp.getResult() != memref)
toMemrefOp.replaceAllUsesWith(memref);
assert((isFunctionArgument(toMemrefOp.tensor()) ||
state.getOptions().allowUnknownOps) &&
"expected that tensor is mapped");

return success();
}

@@ -98,8 +86,6 @@ struct ToTensorOpInterface
bufferization::ToTensorOp> {
LogicalResult bufferize(Operation *op, OpBuilder &b,
BufferizationState &state) const {
auto tensorLoadOp = cast<bufferization::ToTensorOp>(op);
state.mapBuffer(tensorLoadOp.result(), tensorLoadOp.memref());
return success();
}

@@ -699,8 +699,5 @@ LogicalResult mlir::linalg::comprehensive_bufferize::runComprehensiveBufferize(
if (failed(bufferize(op, state)))
return failure();

// Erase all obsolete ops.
state.eraseObsoleteOps();

return success();
}
@@ -56,7 +56,6 @@ static LogicalResult bufferizeLinalgOp(OpBuilder &b, LinalgOp op,
if (!resultBuffer)
return failure();
newOutputBuffers.push_back(resultBuffer);
state.mapBuffer(opResult, resultBuffer);
}

// Clone the newly bufferized op.

@@ -68,7 +67,9 @@ static LogicalResult bufferizeLinalgOp(OpBuilder &b, LinalgOp op,
auto bufferizedOp = cast<LinalgOp>(
op.clone(b, op.getLoc(), /*resultTypes=*/TypeRange{}, newOperands));

// The original op will be DCE'd away later.
// Replace the results of the old op with the new output buffers.
state.replaceOp(op, newOutputBuffers);

return comprehensive_bufferize::bufferize(bufferizedOp.getBlock(), state);
}

@@ -194,7 +195,7 @@ struct InitTensorOpInterface

Value alloc = state.createAllocDeallocPair(b, initTensorOp->getLoc(),
initTensorOp.result());
state.mapBuffer(initTensorOp.result(), alloc);
state.replaceOp(op, alloc);
return success();
}
};
@@ -551,9 +551,6 @@ struct CallOpInterface
.equivalentFuncArgs[funcOp][returnOperand.getOperandNumber()];
Value oldRes = callOp->getResult(returnOperand.getOperandNumber());
Value buffer = state.lookupBuffer(callOp->getOperand(idx));
// Add CallOp operand/result equivalence: this is interprocedural
// info.
state.mapBuffer(oldRes, buffer);
// Add a ToTensorOp to kill all uses of the CallOp return.
// Replace all uses of the CallOp results so we can erase the CallOp.
// This ToTensorOp must fold/DCE away or bufferization should be

@@ -561,8 +558,6 @@ struct CallOpInterface
Value toTensorOp =
b.create<bufferization::ToTensorOp>(callOp.getLoc(), buffer);
oldRes.replaceAllUsesWith(toTensorOp);
// Add new op equivalence info.
state.mapBuffer(toTensorOp, buffer);
continue;
}

@@ -603,8 +598,6 @@ struct CallOpInterface
if (buffer.getType() != memRefType) {
Value castBuffer =
b.create<memref::CastOp>(callOp.getLoc(), memRefType, buffer);
// Add new op equivalence info.
state.mapBuffer(tensorOperand, castBuffer);
buffer = castBuffer;
}
newOperands.push_back(buffer);

@@ -616,7 +609,7 @@ struct CallOpInterface
newCallOp->setAttrs(callOp->getAttrs());

// 5. Delete the op at the end of bufferization.
state.markOpObsolete(callOp);
callOp->erase();

return success();
}

@@ -651,7 +644,6 @@ struct ReturnOpInterface
Value returnTensor = b.create<bufferization::ToTensorOp>(
returnOp.getLoc(), v);
operand.set(returnTensor);
state.mapBuffer(returnTensor, v);
}
return success();
}
@@ -662,23 +654,6 @@ struct FuncOpInterface
LogicalResult bufferize(Operation *op, OpBuilder &b,
BufferizationState &state) const {
auto funcOp = cast<FuncOp>(op);
b.setInsertionPointToStart(&funcOp.body().front());

// Create BufferCastOps for function args.
for (auto bbArg : funcOp.getArguments()) {
auto tensorType = bbArg.getType().dyn_cast<TensorType>();
if (!tensorType)
continue;
auto rankedTensorType = tensorType.dyn_cast<RankedTensorType>();
// Cast the tensor to the most dynamic buffer possible. Further
// canonicalizations will clean up.
Type memRefType = rankedTensorType
? getDynamicMemRefType(rankedTensorType)
: getContiguousOrUnrankedMemRefType(tensorType);
Value bufferCast = b.create<bufferization::ToMemrefOp>(funcOp.getLoc(),
memRefType, bbArg);
state.mapBuffer(bbArg, bufferCast);
}

// Bufferize function body.
return comprehensive_bufferize::bufferize(&funcOp.body(), state);
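With the per-argument loop above removed, FuncOp bufferization no longer pre-seeds the mapping with `to_memref` casts; the first `lookupBuffer` on a tensor block argument now materializes the `bufferization.to_memref` itself (see the `lookupBuffer` hunk earlier in this diff). A rough sketch of the simplified implementation (illustrative only):

LogicalResult bufferize(Operation *op, OpBuilder &b,
                        BufferizationState &state) const {
  auto funcOp = cast<FuncOp>(op);
  // Tensor arguments need no explicit to_memref casts here anymore; they are
  // created on demand by lookupBuffer.
  return comprehensive_bufferize::bufferize(&funcOp.body(), state);
}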
@@ -78,9 +78,7 @@ struct CastOpInterface
: MemRefLayoutAttrInterface();
Type memRefType = getContiguousOrUnrankedMemRefType(
castOp.getResult().getType(), layout, memorySpace);
Value res =
b.create<memref::CastOp>(castOp.getLoc(), memRefType, resultBuffer);
state.mapBuffer(castOp.getResult(), res);
state.replaceOpWithNewOp<memref::CastOp>(b, op, memRefType, resultBuffer);
return success();
}
};

@@ -103,11 +101,10 @@ struct DimOpInterface
LogicalResult bufferize(Operation *op, OpBuilder &b,
BufferizationState &state) const {
auto dimOp = cast<tensor::DimOp>(op);
if (dimOp.source().getType().isa<RankedTensorType>()) {
Value v = state.lookupBuffer(dimOp.source());
dimOp.result().replaceAllUsesWith(
b.create<memref::DimOp>(dimOp.getLoc(), v, dimOp.index()));
}
if (!dimOp.source().getType().isa<RankedTensorType>())
return dimOp.emitError("unranked tensor not supported");
Value v = state.lookupBuffer(dimOp.source());
state.replaceOpWithNewOp<memref::DimOp>(b, op, v, dimOp.index());
return success();
}
};

@@ -168,7 +165,7 @@ struct ExtractSliceOpInterface
subView = alloc;
}

state.mapBuffer(extractSliceOp.result(), subView);
state.replaceOp(op, subView);
return success();
}
};

@@ -191,10 +188,9 @@ struct ExtractOpInterface
LogicalResult bufferize(Operation *op, OpBuilder &b,
BufferizationState &state) const {
auto extractOp = cast<tensor::ExtractOp>(op);
Location loc = extractOp.getLoc();
Value srcMemref = state.lookupBuffer(extractOp.tensor());
Value l = b.create<memref::LoadOp>(loc, srcMemref, extractOp.indices());
extractOp.replaceAllUsesWith(l);
state.replaceOpWithNewOp<memref::LoadOp>(b, op, srcMemref,
extractOp.indices());
return success();
}
};

@@ -228,7 +224,7 @@ struct InsertOpInterface
Value destMemref = state.getResultBuffer(insertOp->getOpResult(0));
b.create<memref::StoreOp>(loc, insertOp.scalar(), destMemref,
insertOp.indices());
state.mapBuffer(insertOp, destMemref);
state.replaceOp(op, destMemref);
return success();
}

@@ -423,7 +419,7 @@ struct InsertSliceOpInterface
state.createMemCpy(b, insertSliceOp.getLoc(), srcMemref, subView);
}

state.mapBuffer(insertSliceOp.result(), dstMemref);
state.replaceOp(op, dstMemref);
return success();
}
};
@@ -38,13 +38,17 @@ struct TransferReadOpInterface

LogicalResult bufferize(Operation *op, OpBuilder &b,
BufferizationState &state) const {
auto transferReadOp = cast<vector::TransferReadOp>(op);
assert(transferReadOp.getShapedType().isa<TensorType>() &&
auto readOp = cast<vector::TransferReadOp>(op);
assert(readOp.getShapedType().isa<TensorType>() &&
"only tensor types expected");

// TransferReadOp always reads from the bufferized op.source().
Value v = state.lookupBuffer(transferReadOp.source());
transferReadOp.sourceMutable().assign(v);
Value buffer = state.lookupBuffer(readOp.source());
Value read = b.create<vector::TransferReadOp>(
readOp.getLoc(), readOp.getVectorType(), buffer, readOp.indices(),
readOp.permutation_map(), readOp.padding(), readOp.mask(),
readOp.in_boundsAttr());
state.replaceOp(op, read);
return success();
}
};

@@ -90,7 +94,7 @@ struct TransferWriteOpInterface
b.create<vector::TransferWriteOp>(
writeOp.getLoc(), writeOp.vector(), resultBuffer, writeOp.indices(),
writeOp.permutation_mapAttr(), writeOp.in_boundsAttr());
state.mapBuffer(op->getResult(0), resultBuffer);
state.replaceOp(op, resultBuffer);

return success();
}
@@ -30,11 +30,11 @@ func @return_tensor(%A : tensor<?xf32>, %v : vector<4xf32>) -> (tensor<?xf32>) {
// CHECK: %[[dim:.*]] = tensor.dim %[[A]]
// CHECK: %[[alloc:.*]] = memref.alloc(%[[dim]])
// CHECK: %[[casted:.*]] = memref.cast %[[alloc]]
// CHECK: %[[res_tensor:.*]] = bufferization.to_tensor %[[casted]]
// CHECK: memref.copy %[[A_memref]], %[[casted]]
// CHECK: vector.transfer_write %{{.*}}, %[[alloc]]
%0 = vector.transfer_write %v, %A[%c0] : vector<4xf32>, tensor<?xf32>

// CHECK: %[[res_tensor:.*]] = bufferization.to_tensor %[[casted]]
// CHECK: return %[[res_tensor]]
return %0 : tensor<?xf32>
}

@@ -52,10 +52,10 @@ func @use_of_unknown_op_3(%t1: tensor<?xf32> {linalg.inplaceable = true})
-> (vector<5xf32>, vector<5xf32>) {
%idx = arith.constant 0 : index
%cst = arith.constant 0.0 : f32
// CHECK: %[[m1_tensor:.*]] = bufferization.to_tensor %[[m1]]
// CHECK: %[[v1:.*]] = vector.transfer_read %[[m1]]
%1 = vector.transfer_read %t1[%idx], %cst : tensor<?xf32>, vector<5xf32>

// CHECK: %[[m1_tensor:.*]] = bufferization.to_tensor %[[m1]]
// CHECK: %[[dummy:.*]] = "test.dummy_op"(%[[m1_tensor]])
%0 = "test.dummy_op"(%t1) : (tensor<?xf32>) -> tensor<?xf32>
// CHECK: %[[dummy_memref:.*]] = bufferization.to_memref %[[dummy]]

@@ -114,11 +114,11 @@ func @use_of_bufferizable_op_in_unbufferizable_op(
func @unused_unknown_op(%t1 : tensor<?xf32>) -> vector<5xf32> {
%idx = arith.constant 0 : index
%cst = arith.constant 0.0 : f32
// ToTensorOp is inserted to pass in the result of the above bufferized op.
// CHECK: %[[m1_tensor:.*]] = bufferization.to_tensor %[[m1]]
// CHECK: vector.transfer_read %[[m1]]
%1 = vector.transfer_read %t1[%idx], %cst : tensor<?xf32>, vector<5xf32>

// ToTensorOp is inserted to pass in the result of the above bufferized op.
// CHECK: %[[m1_tensor:.*]] = bufferization.to_tensor %[[m1]]
// CHECK: "test.dummy_op"(%[[m1_tensor]])
"test.dummy_op"(%t1) : (tensor<?xf32>) -> ()

@@ -158,10 +158,10 @@ func @simple_tensor_test(%t1 : tensor<?xf32>, %f : f32) -> tensor<?xf32> {
%c0 = arith.constant 0 : index
// CHECK-TENSOR: %[[alloc:.*]] = memref.alloc
// CHECK-TENSOR: %[[casted:.*]] = memref.cast %[[alloc]]
// CHECK-TENSOR: %[[casted_tensor:.*]] = bufferization.to_tensor %[[casted]]
// CHECK-TENSOR: memref.copy %[[t1_memref]], %[[casted]]
// CHECK-TENSOR: memref.store %{{.*}}, %[[alloc]]
%0 = tensor.insert %f into %t1[%c0] : tensor<?xf32>
// CHECK-TENSOR: %[[casted_tensor:.*]] = bufferization.to_tensor %[[casted]]
// CHECK-TENSOR: return %[[casted_tensor]]
return %0 : tensor<?xf32>
}

@@ -168,25 +168,25 @@ func @insert_slice_fun(%A0 : tensor<?xf32>,
-> (tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>)
{
// Hoisted allocs.
// CHECK: %[[REALLOC_A0_2:.*]] = memref.alloc
// CHECK: %[[REALLOC_A0:.*]] = memref.alloc
// CHECK: %[[REALLOC_A1:.*]] = memref.alloc
// CHECK: %[[REALLOC1:.*]] = memref.alloc
// CHECK: %[[REALLOC2:.*]] = memref.alloc
// CHECK: %[[REALLOC3:.*]] = memref.alloc

// Alloc and copy the whole result tensor. Copy the tensor.extract_slice.
// CHECK: linalg.copy(%[[A0]], %[[REALLOC_A0]]
// CHECK: %[[SV_A0:.*]] = memref.subview %[[REALLOC_A0]]
// CHECK: linalg.copy(%[[A0]], %[[REALLOC3]]
// CHECK: %[[SV_A0:.*]] = memref.subview %[[REALLOC3]]
// CHECK: linalg.copy(%[[t0]], %[[SV_A0]])
%r0 = tensor.insert_slice %t0 into %A0[0][4][1] : tensor<4xf32> into tensor<?xf32>

// Alloc and copy the whole result tensor. Copy the tensor.extract_slice.
// CHECK: linalg.copy(%[[A0]]
// CHECK: %[[SV_A0_2:.*]] = memref.subview %[[REALLOC_A0_2]]
// CHECK: %[[SV_A0_2:.*]] = memref.subview %[[REALLOC2]]
// CHECK: linalg.copy(%[[t1]], %[[SV_A0_2]])
%r1 = tensor.insert_slice %t1 into %A0[0][4][1] : tensor<4xf32> into tensor<?xf32>

// Still alloc the large tensor because %A1 is read after. Copy the tensor.extract_slice.
// CHECK: linalg.copy(%[[A1]]
// CHECK: %[[SV_A1:.*]] = memref.subview %[[REALLOC_A1]]
// CHECK: %[[SV_A1:.*]] = memref.subview %[[REALLOC1]]
// CHECK: linalg.copy(%[[t0]], %[[SV_A1]])
%r2 = tensor.insert_slice %t0 into %A1[0][4][1] : tensor<4xf32> into tensor<?xf32>

@@ -196,7 +196,7 @@ func @insert_slice_fun(%A0 : tensor<?xf32>,
// CHECK: linalg.copy(%[[t1]], %[[SV_A1_2]])
%r3 = tensor.insert_slice %t1 into %A1[0][4][1] : tensor<4xf32> into tensor<?xf32>

// CHECK: return %[[REALLOC_A0]], %[[REALLOC_A0_2]], %[[REALLOC_A1]] :
// CHECK: return %[[REALLOC3]], %[[REALLOC2]], %[[REALLOC1]] :
// CHECK-SAME: memref<?xf32>, memref<?xf32>, memref<?xf32>
return %r0, %r1, %r2, %r3: tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>
}