forked from OSchip/llvm-project

[mlir][linalg][bufferize][NFC] Move tensor interface impl to new build target

This makes ComprehensiveBufferize entirely independent of the tensor dialect.

Differential Revision: https://reviews.llvm.org/D114217

This commit is contained in:
parent 8ef460fc51
commit bb273a35a0
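Reader's note (not part of the commit): after this split, a client that drives Comprehensive Bufferize registers the tensor external models explicitly and links the new MLIRTensorBufferizableOpInterfaceImpl target, mirroring the update to LinalgComprehensiveModuleBufferize shown further down in the diff. A minimal C++ sketch of such a registration hook follows; the helper name registerAllBufferizationModels is made up for illustration, only the three register calls are taken from this commit.

#include "mlir/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.h"
#include "mlir/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.h"
#include "mlir/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.h"
#include "mlir/IR/Dialect.h"

// Hypothetical downstream helper, typically called from a pass's
// getDependentDialects(). Each call pulls in the external models provided by
// one of the build targets touched by this commit.
static void registerAllBufferizationModels(mlir::DialectRegistry &registry) {
  using namespace mlir::linalg::comprehensive_bufferize;
  // Core models (scf/std/vector), still built into MLIRComprehensiveBufferize.
  registerBufferizableOpInterfaceExternalModels(registry);
  // Linalg models from MLIRLinalgBufferizableOpInterfaceImpl.
  linalg_ext::registerBufferizableOpInterfaceExternalModels(registry);
  // Tensor models, now provided by the new MLIRTensorBufferizableOpInterfaceImpl.
  tensor_ext::registerBufferizableOpInterfaceExternalModels(registry);
}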
@@ -322,6 +322,26 @@ struct PostAnalysisStep {
                            SmallVector<Operation *> &newOps) = 0;
};

/// Return a contiguous MemRefType (i.e. with canonical/empty layout map)
/// with the same shape as `shapedType` and specified `layout` and
/// `addressSpace`.
MemRefType getContiguousMemRefType(ShapedType shapedType,
                                   MemRefLayoutAttrInterface layout = {},
                                   Attribute memorySpace = {});

/// Return a contiguous MemRefType (i.e. with canonical/empty layout map)
/// with the same shape as `shapedType` and specified `layout` and
/// `addressSpace` or an UnrankedMemRefType otherwise.
Type getContiguousOrUnrankedMemRefType(Type type,
                                       MemRefLayoutAttrInterface layout = {},
                                       Attribute memorySpace = {});

/// Return a MemRefType to which the `tensorType` can be bufferized in a
/// composable fashion. The layout must be the most dynamic possible and
/// canonicalize away once bufferization is finished.
MemRefType getDynamicMemRefType(RankedTensorType tensorType,
                                unsigned addressSpace = 0);

} // namespace comprehensive_bufferize
} // namespace linalg
} // namespace mlir
@@ -1,3 +1,11 @@
//===- LinalgInterfaceImpl.h - Linalg Impl. of BufferizableOpInterface ----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_DIALECT_LINALG_COMPREHENSIVEBUFFERIZE_LINALG_INTERFACE_IMPL_H
#define MLIR_DIALECT_LINALG_COMPREHENSIVEBUFFERIZE_LINALG_INTERFACE_IMPL_H
@@ -0,0 +1,27 @@
//===- LinalgInterfaceImpl.h - Linalg Impl. of BufferizableOpInterface ----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_DIALECT_LINALG_COMPREHENSIVEBUFFERIZE_TENSOR_INTERFACE_IMPL_H
#define MLIR_DIALECT_LINALG_COMPREHENSIVEBUFFERIZE_TENSOR_INTERFACE_IMPL_H

namespace mlir {

class DialectRegistry;

namespace linalg {
namespace comprehensive_bufferize {
namespace tensor_ext {

void registerBufferizableOpInterfaceExternalModels(DialectRegistry &registry);

} // namespace tensor_ext
} // namespace comprehensive_bufferize
} // namespace linalg
} // namespace mlir

#endif // MLIR_DIALECT_LINALG_COMPREHENSIVEBUFFERIZE_TENSOR_INTERFACE_IMPL_H
@@ -12,6 +12,7 @@
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/IR/Value.h"
#include "llvm/Support/Debug.h"
@@ -528,3 +529,31 @@ void mlir::linalg::comprehensive_bufferize::BufferizationState::
    op->erase();
  obsoleteOps.clear();
}

MemRefType mlir::linalg::comprehensive_bufferize::getContiguousMemRefType(
    ShapedType shapedType, MemRefLayoutAttrInterface layout,
    Attribute memorySpace) {
  return MemRefType::get(shapedType.getShape(), shapedType.getElementType(),
                         layout, memorySpace);
}

Type mlir::linalg::comprehensive_bufferize::getContiguousOrUnrankedMemRefType(
    Type type, MemRefLayoutAttrInterface layout, Attribute memorySpace) {
  if (type.isa<RankedTensorType, MemRefType>())
    return getContiguousMemRefType(type.cast<ShapedType>(), layout,
                                   memorySpace);
  assert(!layout && "expected empty layout with UnrankedMemRefType");
  return UnrankedMemRefType::get(getElementTypeOrSelf(type), memorySpace);
}

MemRefType mlir::linalg::comprehensive_bufferize::getDynamicMemRefType(
    RankedTensorType tensorType, unsigned addressSpace) {
  // TODO: address space decisions to connect with the actual alloc.
  int64_t dynamicOffset = ShapedType::kDynamicStrideOrOffset;
  SmallVector<int64_t> dynamicStrides(tensorType.getRank(),
                                      ShapedType::kDynamicStrideOrOffset);
  AffineMap stridedLayout = makeStridedLinearLayoutMap(
      dynamicStrides, dynamicOffset, tensorType.getContext());
  return MemRefType::get(tensorType.getShape(), tensorType.getElementType(),
                         stridedLayout, addressSpace);
}
@@ -2,6 +2,7 @@ set(LLVM_OPTIONAL_SOURCES
  BufferizableOpInterface.cpp
  ComprehensiveBufferize.cpp
  LinalgInterfaceImpl.cpp
  TensorInterfaceImpl.cpp
)

add_mlir_dialect_library(MLIRBufferizableOpInterface
@@ -25,6 +26,16 @@ add_mlir_dialect_library(MLIRLinalgBufferizableOpInterfaceImpl
  MLIRTensor
)

add_mlir_dialect_library(MLIRTensorBufferizableOpInterfaceImpl
  TensorInterfaceImpl.cpp

  LINK_LIBS PUBLIC
  MLIRBufferizableOpInterface
  MLIRIR
  MLIRMemRef
  MLIRTensor
)

add_mlir_dialect_library(MLIRComprehensiveBufferize
  ComprehensiveBufferize.cpp
@@ -37,6 +48,5 @@ add_mlir_dialect_library(MLIRComprehensiveBufferize
  MLIRSCF
  MLIRStandard
  MLIRStandardOpsTransforms
  MLIRTensor
  MLIRVector
)
@@ -587,45 +587,6 @@ getEquivalentEnclosingFuncBBArg(Value v,
// Bufferization-specific MemRefType support.
//===----------------------------------------------------------------------===//

/// Return a contiguous MemRefType (i.e. with canonical/empty layout map)
/// with the same shape as `shapedType` and specified `layout` and
/// `addressSpace`.
static MemRefType getContiguousMemRefType(ShapedType shapedType,
                                          MemRefLayoutAttrInterface layout = {},
                                          Attribute memorySpace = {}) {
  return MemRefType::get(shapedType.getShape(), shapedType.getElementType(),
                         layout, memorySpace);
}

/// Return a contiguous MemRefType (i.e. with canonical/empty layout map)
/// with the same shape as `shapedType` and specified `layout` and
/// `addressSpace` or an UnrankedMemRefType otherwise.
static Type
getContiguousOrUnrankedMemRefType(Type type,
                                  MemRefLayoutAttrInterface layout = {},
                                  Attribute memorySpace = {}) {
  if (type.isa<RankedTensorType, MemRefType>())
    return getContiguousMemRefType(type.cast<ShapedType>(), layout,
                                   memorySpace);
  assert(!layout && "expected empty layout with UnrankedMemRefType");
  return UnrankedMemRefType::get(getElementTypeOrSelf(type), memorySpace);
}

/// Return a MemRefType to which the `tensorType` can be bufferized in a
/// composable fashion. The layout must be the most dynamic possible and
/// canonicalize away once bufferization is finished.
static MemRefType getDynamicMemRefType(RankedTensorType tensorType,
                                       unsigned addressSpace = 0) {
  // TODO: address space decisions to connect with the actual alloc.
  int64_t dynamicOffset = ShapedType::kDynamicStrideOrOffset;
  SmallVector<int64_t> dynamicStrides(tensorType.getRank(),
                                      ShapedType::kDynamicStrideOrOffset);
  AffineMap stridedLayout = makeStridedLinearLayoutMap(
      dynamicStrides, dynamicOffset, tensorType.getContext());
  return MemRefType::get(tensorType.getShape(), tensorType.getElementType(),
                         stridedLayout, addressSpace);
}

/// Return the FunctionType with `argumentTypes` and `resultTypes` where each
/// tensor is replaced by the corresponding buffer type.
/// In order for all the callers to agree, this *must* bufferize to the most
@@ -1965,420 +1926,6 @@ struct ReturnOpInterface

} // namespace std_ext

namespace tensor_ext {

struct CastOpInterface
    : public BufferizableOpInterface::ExternalModel<CastOpInterface,
                                                    tensor::CastOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand) const {
    return false;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand) const {
    return false;
  }

  SmallVector<OpOperand *> getAliasingOpOperand(Operation *op,
                                                OpResult opResult) const {
    return {&op->getOpOperand(0)};
  }

  OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand) const {
    return op->getResult(0);
  }

  BufferRelation bufferRelation(Operation *op, OpOperand &opOperand) const {
    return BufferRelation::Equivalent;
  }

  LogicalResult bufferize(Operation *op, OpBuilder &b,
                          BufferizationState &state) const {
    auto castOp = cast<tensor::CastOp>(op);

    // Take a guard before anything else.
    OpBuilder::InsertionGuard g(b);
    b.setInsertionPoint(castOp);

    Value resultBuffer = getResultBuffer(b, castOp->getResult(0), state);
    if (!resultBuffer)
      return failure();
    Type sourceType = resultBuffer.getType();
    auto rankedMemRefType = sourceType.dyn_cast<MemRefType>();
    auto unrankedMemRefType = sourceType.dyn_cast<UnrankedMemRefType>();
    assert(rankedMemRefType || unrankedMemRefType);
    Attribute memorySpace = rankedMemRefType
                                ? rankedMemRefType.getMemorySpace()
                                : unrankedMemRefType.getMemorySpace();
    TensorType tensorType = castOp.getResult().getType().cast<TensorType>();
    MemRefLayoutAttrInterface layout =
        rankedMemRefType && tensorType.isa<RankedTensorType>()
            ? rankedMemRefType.getLayout()
            : MemRefLayoutAttrInterface();
    Type memRefType = getContiguousOrUnrankedMemRefType(
        castOp.getResult().getType(), layout, memorySpace);
    Value res =
        b.create<memref::CastOp>(castOp.getLoc(), memRefType, resultBuffer);
    state.aliasInfo.insertNewBufferEquivalence(res, castOp.getResult());
    state.mapBuffer(castOp.getResult(), res);
    return success();
  }
};

struct DimOpInterface
    : public BufferizableOpInterface::ExternalModel<DimOpInterface,
                                                    tensor::DimOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand) const {
    return true;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand) const {
    return false;
  }

  OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand) const {
    return OpResult();
  }

  LogicalResult bufferize(Operation *op, OpBuilder &b,
                          BufferizationState &state) const {
    auto dimOp = cast<tensor::DimOp>(op);

    // Take a guard before anything else.
    OpBuilder::InsertionGuard g(b);
    b.setInsertionPoint(dimOp);

    if (dimOp.source().getType().isa<RankedTensorType>()) {
      Value v = state.lookupBuffer(dimOp.source());
      dimOp.result().replaceAllUsesWith(
          b.create<memref::DimOp>(dimOp.getLoc(), v, dimOp.index()));
    }
    return success();
  }
};

struct ExtractSliceOpInterface
    : public BufferizableOpInterface::ExternalModel<ExtractSliceOpInterface,
                                                    tensor::ExtractSliceOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand) const {
    return false;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand) const {
    return false;
  }

  SmallVector<OpOperand *> getAliasingOpOperand(Operation *op,
                                                OpResult opResult) const {
    return {&op->getOpOperand(0) /*source*/};
  }

  OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand) const {
    return &opOperand == &op->getOpOperand(0) /*source*/
               ? op->getResult(0)
               : OpResult();
  }

  BufferRelation bufferRelation(Operation *op, OpOperand &opOperand) const {
    return BufferRelation::None;
  }

  LogicalResult bufferize(Operation *op, OpBuilder &b,
                          BufferizationState &state) const {
    auto extractSliceOp = cast<tensor::ExtractSliceOp>(op);
    LDBG("bufferize: " << *extractSliceOp << '\n');

    // Take a guard before anything else.
    OpBuilder::InsertionGuard g(b);
    b.setInsertionPoint(extractSliceOp);

    Location loc = extractSliceOp.getLoc();
    Value srcMemref = state.lookupBuffer(extractSliceOp.source());
    auto srcMemrefType = srcMemref.getType().cast<MemRefType>();
    auto dstTensorType =
        extractSliceOp.result().getType().cast<RankedTensorType>();

    // If not inplaceable, alloc.
    bool inplace = state.aliasInfo.isInPlace(extractSliceOp->getResult(0));
    Value alloc;
    if (!inplace)
      alloc = createNewAllocDeallocPairForShapedValue(
          b, loc, extractSliceOp.result(), state);

    // Bufferize to subview.
    auto subviewMemRefType =
        memref::SubViewOp::inferRankReducedResultType(
            dstTensorType.getRank(), srcMemrefType,
            extractSliceOp.getMixedOffsets(), extractSliceOp.getMixedSizes(),
            extractSliceOp.getMixedStrides())
            .cast<MemRefType>();
    Value subView = b.create<memref::SubViewOp>(
        loc, subviewMemRefType, srcMemref, extractSliceOp.getMixedOffsets(),
        extractSliceOp.getMixedSizes(), extractSliceOp.getMixedStrides());
    // Insert new alias.
    state.aliasInfo.insertNewBufferAlias(subView, srcMemref);

    /// If not inplaceable, copy.
    if (!inplace) {
      // Do not copy if the copied data is never read.
      if (isValueRead(extractSliceOp.result()))
        state.allocationFns.memCpyFn(b, extractSliceOp.getLoc(), subView,
                                     alloc);
      subView = alloc;
    }

    state.mapBuffer(extractSliceOp.result(), subView);
    return success();
  }
};

struct ExtractOpInterface
    : public BufferizableOpInterface::ExternalModel<ExtractOpInterface,
                                                    tensor::ExtractOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand) const {
    return true;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand) const {
    return false;
  }

  OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand) const {
    return OpResult();
  }

  LogicalResult bufferize(Operation *op, OpBuilder &b,
                          BufferizationState &state) const {
    auto extractOp = cast<tensor::ExtractOp>(op);

    // Take a guard before anything else.
    OpBuilder::InsertionGuard g(b);
    b.setInsertionPoint(extractOp);

    Location loc = extractOp.getLoc();
    Value srcMemref = state.lookupBuffer(extractOp.tensor());
    Value l = b.create<memref::LoadOp>(loc, srcMemref, extractOp.indices());
    extractOp.replaceAllUsesWith(l);
    return success();
  }
};

/// Return true if the (ExtractSliceOp, InsertSliceOp) pair match (i.e.
/// equivalent operand / result and same offset/sizes/strides specification).
///
/// This is one particular type of relationship between ops on tensors that
/// reduce to an equivalence on buffers. This should be generalized and
/// exposed as interfaces on the proper types.
static bool
areEquivalentExtractSliceOps(const BufferizationAliasInfo &aliasInfo,
                             ExtractSliceOp st, InsertSliceOp sti) {
  if (!st || !sti)
    return false;
  if (!aliasInfo.areEquivalentBufferizedValues(st.source(), sti.dest()))
    return false;
  if (!sameOffsetsSizesAndStrides(st, sti, isEqualConstantIntOrValue))
    return false;
  return true;
}

/// Return true if the source of a `insertSliceOp` bufferizes to an
/// equivalent ExtractSliceOp that bufferizes inplace.
static bool isSourceEquivalentToAMatchingInplaceExtractSliceOp(
    const BufferizationAliasInfo &aliasInfo, InsertSliceOp insertSliceOp) {
  LDBG("isSourceEquivalentToAMatchingInplaceExtractSliceOp: " << *insertSliceOp
                                                              << '\n');
  bool foundOp = false;
  aliasInfo.applyOnEquivalenceClass(insertSliceOp.source(), [&](Value value) {
    auto extractSliceOp = value.getDefiningOp<ExtractSliceOp>();
    if (extractSliceOp &&
        areEquivalentExtractSliceOps(aliasInfo, extractSliceOp,
                                     insertSliceOp) &&
        aliasInfo.isInPlace(extractSliceOp->getResult(0))) {
      LDBG("\tfound: " << extractSliceOp.getOperation() << '\n');
      foundOp = true;
    }
  });

  if (!foundOp)
    LDBG("\tnot equivalent\n");

  return foundOp;
}

/// Return true if `value` is originating from an ExtractSliceOp that matches
/// the given InsertSliceOp.
static bool hasMatchingExtractSliceOp(const BufferizationAliasInfo &aliasInfo,
                                      Value value, InsertSliceOp insertOp) {
  auto condition = [&](Value val) {
    if (auto extractOp = val.getDefiningOp<ExtractSliceOp>())
      if (areEquivalentExtractSliceOps(aliasInfo, extractOp, insertOp))
        return true;
    return false;
  };

  return llvm::all_of(findValueInReverseUseDefChain(value, condition),
                      condition);
}

struct InsertSliceOpInterface
    : public BufferizableOpInterface::ExternalModel<InsertSliceOpInterface,
                                                    tensor::InsertSliceOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand) const {
    return true;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand) const {
    return &opOperand == &op->getOpOperand(1) /*dest*/;
  }

  SmallVector<OpOperand *> getAliasingOpOperand(Operation *op,
                                                OpResult opResult) const {
    return {&op->getOpOperand(1) /*dest*/};
  }

  OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand) const {
    return &opOperand == &op->getOpOperand(1) /*dest*/
               ? op->getResult(0)
               : OpResult();
  }

  BufferRelation bufferRelation(Operation *op, OpOperand &opOperand) const {
    return BufferRelation::Equivalent;
  }

  bool isNotConflicting(Operation *op, OpOperand *uRead,
                        OpOperand *uConflictingWrite,
                        const BufferizationAliasInfo &aliasInfo) const {
    Operation *readingOp = uRead->getOwner();
    Operation *conflictingWritingOp = uConflictingWrite->getOwner();

    // Special rules for matching ExtractSliceOp/InsertSliceOp pairs. If
    // uRead is an InsertSliceOp...
    if (auto insertSliceOp = dyn_cast<InsertSliceOp>(readingOp)) {
      // As an example, consider the following IR.
      //
      // %0 = tensor.extract_slice %t[%a, %b][%c, %d][1, 1] {inplace = [true] }
      // %1 = linalg.fill %cst, %0 {inplace= [true] }
      // %2 = tensor.insert_slice %1 into %t[%a, %b][%c, %d][1, 1]
      //     {inplace= [true] }

      // TODO: Use insertSliceOp.getDestOpOperand etc. when available.
      if (uRead == &insertSliceOp->getOpOperand(1) /*dest*/ &&
          hasMatchingExtractSliceOp(aliasInfo, uConflictingWrite->get(),
                                    insertSliceOp))
        // Case 1: The main insight is that InsertSliceOp reads only part of
        // the destination tensor. The overwritten area is not read. If
        // uConflictingWrite writes into exactly the memory location that is
        // being read by uRead, this is not a conflict.
        //
        // In the above example:
        // uRead = OpOperand 1 (%t) of tensor.insert_slice
        // uConflictingWrite = OpOperand 1 (%0) of linalg.fill
        //
        // The read of %t does not conflict with the write of the FillOp
        // (same aliases!) because the area that the FillOp operates on is
        // exactly the one that is *not* read via %t.
        return true;

      if (uRead == &insertSliceOp->getOpOperand(0) /*source*/ &&
          uConflictingWrite == &insertSliceOp->getOpOperand(1) /*dest*/ &&
          hasMatchingExtractSliceOp(aliasInfo, uRead->get(), insertSliceOp))
        // Case 2: The read of the source tensor and the write to the dest
        // tensor via an InsertSliceOp is not a conflict if the read is
        // reading exactly that part of an equivalent tensor that the
        // InsertSliceOp is writing.
        //
        // In the above example:
        // uRead = OpOperand 0 (%1) of tensor.insert_slice
        // uConflictingWrite = OpOperand 1 (%t) of tensor.insert_slice
        return true;
    }

    // If uConflictingWrite is an InsertSliceOp...
    if (auto insertSliceOp = dyn_cast<InsertSliceOp>(conflictingWritingOp))
      // As an example, consider the following IR.
      //
      // %0 = tensor.extract_slice %t[%a, %b][%c, %d][1, 1] {inplace = [true] }
      // %1 = linalg.fill %cst, %0 {inplace= [true] }
      // %2 = tensor.insert_slice %1 into %t[%a, %b][%c, %d][1, 1]
      //     {inplace= [true] }
      // %3 = vector.transfer_read %1, %cst
      //
      // In the above example:
      // uRead = OpOperand 0 (%1) of vector.transfer_read
      // uConflictingWrite = OpOperand 1 (%t) of tensor.insert_slice
      // lastWrite = %1
      //
      // This is not a conflict because the InsertSliceOp overwrites the
      // memory segment of %1 with the exact same data. (Effectively, there
      // is no memory write here.)
      if (uConflictingWrite == &insertSliceOp->getOpOperand(1) /*dest*/ &&
          aliasInfo.areEquivalentBufferizedValues(uRead->get(),
                                                  insertSliceOp.source()) &&
          hasMatchingExtractSliceOp(aliasInfo, insertSliceOp.source(),
                                    insertSliceOp))
        return true;

    return false;
  }

  LogicalResult bufferize(Operation *op, OpBuilder &b,
                          BufferizationState &state) const {
    // insert_slice ops arise from tiling and bufferizing them out-of-place is
    // generally a deal breaker. When used with loops, this ends up cloning the
    // whole tensor on every single iteration and is a symptom of a
    // catastrophically bad scheduling decision.
    // TODO: be very loud about it or even consider failing the pass.
    auto insertSliceOp = cast<tensor::InsertSliceOp>(op);
    LDBG("bufferize: " << *insertSliceOp << '\n');

    // Take a guard before anything else.
    OpBuilder::InsertionGuard g(b);
    b.setInsertionPoint(insertSliceOp);
    Location loc = insertSliceOp.getLoc();

    // When bufferizing out-of-place, `getResultBuffer` allocates.
    Value dstMemref = getResultBuffer(b, insertSliceOp->getResult(0), state);
    if (!dstMemref)
      return failure();

    // A copy of the source buffer is needed if either:
    // - The producer of `source` is not inplace. This is the case where a
    //   slice is computed out of place into the inplace full tensor.
    // - The result is not inplace. This is the case where the whole tensor is
    //   cloned and the clone needs to be updated.
    // TODO: Is this necessary?
    bool needCopy = !isSourceEquivalentToAMatchingInplaceExtractSliceOp(
                        state.aliasInfo, insertSliceOp) ||
                    !state.aliasInfo.isInPlace(insertSliceOp->getResult(0));
    if (needCopy) {
      LDBG("insert_slice needs extra source copy: " << insertSliceOp.source()
                                                    << " -> copy\n");
      // Take a subview of the dst.
      auto dstMemrefType = dstMemref.getType().cast<MemRefType>();
      auto subviewMemRefType =
          memref::SubViewOp::inferRankReducedResultType(
              insertSliceOp.getSourceType().getRank(), dstMemrefType,
              insertSliceOp.getMixedOffsets(), insertSliceOp.getMixedSizes(),
              insertSliceOp.getMixedStrides())
              .cast<MemRefType>();
      Value subView = b.create<memref::SubViewOp>(
          loc, subviewMemRefType, dstMemref, insertSliceOp.getMixedOffsets(),
          insertSliceOp.getMixedSizes(), insertSliceOp.getMixedStrides());
      // Insert new alias.
      state.aliasInfo.insertNewBufferAlias(subView, dstMemref);
      // Copy tensor.
      Value srcMemref = state.lookupBuffer(insertSliceOp.source());
      state.allocationFns.memCpyFn(b, insertSliceOp.getLoc(), srcMemref,
                                   subView);
    }

    state.mapBuffer(insertSliceOp.result(), dstMemref);
    return success();
  }
};

} // namespace tensor_ext

namespace vector_ext {

struct TransferReadOpInterface
@@ -2484,13 +2031,6 @@ void registerBufferizableOpInterfaceExternalModels(DialectRegistry &registry) {
  registry.addOpInterface<scf::YieldOp, scf_ext::YieldOpInterface>();
  registry.addOpInterface<CallOp, std_ext::CallOpInterface>();
  registry.addOpInterface<ReturnOp, std_ext::ReturnOpInterface>();
  registry.addOpInterface<tensor::CastOp, tensor_ext::CastOpInterface>();
  registry.addOpInterface<tensor::DimOp, tensor_ext::DimOpInterface>();
  registry.addOpInterface<tensor::ExtractSliceOp,
                          tensor_ext::ExtractSliceOpInterface>();
  registry.addOpInterface<tensor::ExtractOp, tensor_ext::ExtractOpInterface>();
  registry.addOpInterface<tensor::InsertSliceOp,
                          tensor_ext::InsertSliceOpInterface>();
  registry.addOpInterface<vector::TransferReadOp,
                          vector_ext::TransferReadOpInterface>();
  registry.addOpInterface<vector::TransferWriteOp,
@@ -0,0 +1,437 @@
//===- TensorInterfaceImpl.cpp - Tensor Impl. of BufferizableOpInterface --===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.h"
#include "mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/Operation.h"

namespace mlir {
namespace linalg {
namespace comprehensive_bufferize {
namespace tensor_ext {

using tensor::ExtractSliceOp;
using tensor::InsertSliceOp;

struct CastOpInterface
    : public BufferizableOpInterface::ExternalModel<CastOpInterface,
                                                    tensor::CastOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand) const {
    return false;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand) const {
    return false;
  }

  SmallVector<OpOperand *> getAliasingOpOperand(Operation *op,
                                                OpResult opResult) const {
    return {&op->getOpOperand(0)};
  }

  OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand) const {
    return op->getResult(0);
  }

  BufferRelation bufferRelation(Operation *op, OpOperand &opOperand) const {
    return BufferRelation::Equivalent;
  }

  LogicalResult bufferize(Operation *op, OpBuilder &b,
                          BufferizationState &state) const {
    auto castOp = cast<tensor::CastOp>(op);

    // Take a guard before anything else.
    OpBuilder::InsertionGuard g(b);
    b.setInsertionPoint(castOp);

    Value resultBuffer = getResultBuffer(b, castOp->getResult(0), state);
    if (!resultBuffer)
      return failure();
    Type sourceType = resultBuffer.getType();
    auto rankedMemRefType = sourceType.dyn_cast<MemRefType>();
    auto unrankedMemRefType = sourceType.dyn_cast<UnrankedMemRefType>();
    assert(rankedMemRefType || unrankedMemRefType);
    Attribute memorySpace = rankedMemRefType
                                ? rankedMemRefType.getMemorySpace()
                                : unrankedMemRefType.getMemorySpace();
    TensorType tensorType = castOp.getResult().getType().cast<TensorType>();
    MemRefLayoutAttrInterface layout =
        rankedMemRefType && tensorType.isa<RankedTensorType>()
            ? rankedMemRefType.getLayout()
            : MemRefLayoutAttrInterface();
    Type memRefType = getContiguousOrUnrankedMemRefType(
        castOp.getResult().getType(), layout, memorySpace);
    Value res =
        b.create<memref::CastOp>(castOp.getLoc(), memRefType, resultBuffer);
    state.aliasInfo.insertNewBufferEquivalence(res, castOp.getResult());
    state.mapBuffer(castOp.getResult(), res);
    return success();
  }
};

struct DimOpInterface
    : public BufferizableOpInterface::ExternalModel<DimOpInterface,
                                                    tensor::DimOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand) const {
    return true;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand) const {
    return false;
  }

  OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand) const {
    return OpResult();
  }

  LogicalResult bufferize(Operation *op, OpBuilder &b,
                          BufferizationState &state) const {
    auto dimOp = cast<tensor::DimOp>(op);

    // Take a guard before anything else.
    OpBuilder::InsertionGuard g(b);
    b.setInsertionPoint(dimOp);

    if (dimOp.source().getType().isa<RankedTensorType>()) {
      Value v = state.lookupBuffer(dimOp.source());
      dimOp.result().replaceAllUsesWith(
          b.create<memref::DimOp>(dimOp.getLoc(), v, dimOp.index()));
    }
    return success();
  }
};

struct ExtractSliceOpInterface
    : public BufferizableOpInterface::ExternalModel<ExtractSliceOpInterface,
                                                    tensor::ExtractSliceOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand) const {
    return false;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand) const {
    return false;
  }

  SmallVector<OpOperand *> getAliasingOpOperand(Operation *op,
                                                OpResult opResult) const {
    return {&op->getOpOperand(0) /*source*/};
  }

  OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand) const {
    return &opOperand == &op->getOpOperand(0) /*source*/
               ? op->getResult(0)
               : OpResult();
  }

  BufferRelation bufferRelation(Operation *op, OpOperand &opOperand) const {
    return BufferRelation::None;
  }

  LogicalResult bufferize(Operation *op, OpBuilder &b,
                          BufferizationState &state) const {
    auto extractSliceOp = cast<tensor::ExtractSliceOp>(op);

    // Take a guard before anything else.
    OpBuilder::InsertionGuard g(b);
    b.setInsertionPoint(extractSliceOp);

    Location loc = extractSliceOp.getLoc();
    Value srcMemref = state.lookupBuffer(extractSliceOp.source());
    auto srcMemrefType = srcMemref.getType().cast<MemRefType>();
    auto dstTensorType =
        extractSliceOp.result().getType().cast<RankedTensorType>();

    // If not inplaceable, alloc.
    bool inplace = state.aliasInfo.isInPlace(extractSliceOp->getResult(0));
    Value alloc;
    if (!inplace)
      alloc = state.allocationFns.createAllocDeallocFn(
          b, loc, extractSliceOp.result(), state);

    // Bufferize to subview.
    auto subviewMemRefType =
        memref::SubViewOp::inferRankReducedResultType(
            dstTensorType.getRank(), srcMemrefType,
            extractSliceOp.getMixedOffsets(), extractSliceOp.getMixedSizes(),
            extractSliceOp.getMixedStrides())
            .cast<MemRefType>();
    Value subView = b.create<memref::SubViewOp>(
        loc, subviewMemRefType, srcMemref, extractSliceOp.getMixedOffsets(),
        extractSliceOp.getMixedSizes(), extractSliceOp.getMixedStrides());
    // Insert new alias.
    state.aliasInfo.insertNewBufferAlias(subView, srcMemref);

    /// If not inplaceable, copy.
    if (!inplace) {
      // Do not copy if the copied data is never read.
      if (isValueRead(extractSliceOp.result()))
        state.allocationFns.memCpyFn(b, extractSliceOp.getLoc(), subView,
                                     alloc);
      subView = alloc;
    }

    state.mapBuffer(extractSliceOp.result(), subView);
    return success();
  }
};

struct ExtractOpInterface
    : public BufferizableOpInterface::ExternalModel<ExtractOpInterface,
                                                    tensor::ExtractOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand) const {
    return true;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand) const {
    return false;
  }

  OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand) const {
    return OpResult();
  }

  LogicalResult bufferize(Operation *op, OpBuilder &b,
                          BufferizationState &state) const {
    auto extractOp = cast<tensor::ExtractOp>(op);

    // Take a guard before anything else.
    OpBuilder::InsertionGuard g(b);
    b.setInsertionPoint(extractOp);

    Location loc = extractOp.getLoc();
    Value srcMemref = state.lookupBuffer(extractOp.tensor());
    Value l = b.create<memref::LoadOp>(loc, srcMemref, extractOp.indices());
    extractOp.replaceAllUsesWith(l);
    return success();
  }
};

/// Return true if the (ExtractSliceOp, InsertSliceOp) pair match (i.e.
/// equivalent operand / result and same offset/sizes/strides specification).
///
/// This is one particular type of relationship between ops on tensors that
/// reduce to an equivalence on buffers. This should be generalized and
/// exposed as interfaces on the proper types.
static bool
areEquivalentExtractSliceOps(const BufferizationAliasInfo &aliasInfo,
                             ExtractSliceOp st, InsertSliceOp sti) {
  if (!st || !sti)
    return false;
  if (!aliasInfo.areEquivalentBufferizedValues(st.source(), sti.dest()))
    return false;
  if (!sameOffsetsSizesAndStrides(st, sti, isEqualConstantIntOrValue))
    return false;
  return true;
}

/// Return true if the source of a `insertSliceOp` bufferizes to an
/// equivalent ExtractSliceOp that bufferizes inplace.
static bool isSourceEquivalentToAMatchingInplaceExtractSliceOp(
    const BufferizationAliasInfo &aliasInfo, InsertSliceOp insertSliceOp) {
  bool foundOp = false;
  aliasInfo.applyOnEquivalenceClass(insertSliceOp.source(), [&](Value value) {
    auto extractSliceOp = value.getDefiningOp<ExtractSliceOp>();
    if (extractSliceOp &&
        areEquivalentExtractSliceOps(aliasInfo, extractSliceOp,
                                     insertSliceOp) &&
        aliasInfo.isInPlace(extractSliceOp->getResult(0))) {
      foundOp = true;
    }
  });
  return foundOp;
}

/// Return true if `value` is originating from an ExtractSliceOp that matches
/// the given InsertSliceOp.
static bool hasMatchingExtractSliceOp(const BufferizationAliasInfo &aliasInfo,
                                      Value value, InsertSliceOp insertOp) {
  auto condition = [&](Value val) {
    if (auto extractOp = val.getDefiningOp<ExtractSliceOp>())
      if (areEquivalentExtractSliceOps(aliasInfo, extractOp, insertOp))
        return true;
    return false;
  };

  return llvm::all_of(findValueInReverseUseDefChain(value, condition),
                      condition);
}

struct InsertSliceOpInterface
    : public BufferizableOpInterface::ExternalModel<InsertSliceOpInterface,
                                                    tensor::InsertSliceOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand) const {
    return true;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand) const {
    return &opOperand == &op->getOpOperand(1) /*dest*/;
  }

  SmallVector<OpOperand *> getAliasingOpOperand(Operation *op,
                                                OpResult opResult) const {
    return {&op->getOpOperand(1) /*dest*/};
  }

  OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand) const {
    return &opOperand == &op->getOpOperand(1) /*dest*/
               ? op->getResult(0)
               : OpResult();
  }

  BufferRelation bufferRelation(Operation *op, OpOperand &opOperand) const {
    return BufferRelation::Equivalent;
  }

  bool isNotConflicting(Operation *op, OpOperand *uRead,
                        OpOperand *uConflictingWrite,
                        const BufferizationAliasInfo &aliasInfo) const {
    Operation *readingOp = uRead->getOwner();
    Operation *conflictingWritingOp = uConflictingWrite->getOwner();

    // Special rules for matching ExtractSliceOp/InsertSliceOp pairs. If
    // uRead is an InsertSliceOp...
    if (auto insertSliceOp = dyn_cast<InsertSliceOp>(readingOp)) {
      // As an example, consider the following IR.
      //
      // %0 = tensor.extract_slice %t[%a, %b][%c, %d][1, 1] {inplace = [true] }
      // %1 = linalg.fill %cst, %0 {inplace= [true] }
      // %2 = tensor.insert_slice %1 into %t[%a, %b][%c, %d][1, 1]
      //     {inplace= [true] }

      // TODO: Use insertSliceOp.getDestOpOperand etc. when available.
      if (uRead == &insertSliceOp->getOpOperand(1) /*dest*/ &&
          hasMatchingExtractSliceOp(aliasInfo, uConflictingWrite->get(),
                                    insertSliceOp))
        // Case 1: The main insight is that InsertSliceOp reads only part of
        // the destination tensor. The overwritten area is not read. If
        // uConflictingWrite writes into exactly the memory location that is
        // being read by uRead, this is not a conflict.
        //
        // In the above example:
        // uRead = OpOperand 1 (%t) of tensor.insert_slice
        // uConflictingWrite = OpOperand 1 (%0) of linalg.fill
        //
        // The read of %t does not conflict with the write of the FillOp
        // (same aliases!) because the area that the FillOp operates on is
        // exactly the one that is *not* read via %t.
        return true;

      if (uRead == &insertSliceOp->getOpOperand(0) /*source*/ &&
          uConflictingWrite == &insertSliceOp->getOpOperand(1) /*dest*/ &&
          hasMatchingExtractSliceOp(aliasInfo, uRead->get(), insertSliceOp))
        // Case 2: The read of the source tensor and the write to the dest
        // tensor via an InsertSliceOp is not a conflict if the read is
        // reading exactly that part of an equivalent tensor that the
        // InsertSliceOp is writing.
        //
        // In the above example:
        // uRead = OpOperand 0 (%1) of tensor.insert_slice
        // uConflictingWrite = OpOperand 1 (%t) of tensor.insert_slice
        return true;
    }

    // If uConflictingWrite is an InsertSliceOp...
    if (auto insertSliceOp = dyn_cast<InsertSliceOp>(conflictingWritingOp))
      // As an example, consider the following IR.
      //
      // %0 = tensor.extract_slice %t[%a, %b][%c, %d][1, 1] {inplace = [true] }
      // %1 = linalg.fill %cst, %0 {inplace= [true] }
      // %2 = tensor.insert_slice %1 into %t[%a, %b][%c, %d][1, 1]
      //     {inplace= [true] }
      // %3 = vector.transfer_read %1, %cst
      //
      // In the above example:
      // uRead = OpOperand 0 (%1) of vector.transfer_read
      // uConflictingWrite = OpOperand 1 (%t) of tensor.insert_slice
      // lastWrite = %1
      //
      // This is not a conflict because the InsertSliceOp overwrites the
      // memory segment of %1 with the exact same data. (Effectively, there
      // is no memory write here.)
      if (uConflictingWrite == &insertSliceOp->getOpOperand(1) /*dest*/ &&
          aliasInfo.areEquivalentBufferizedValues(uRead->get(),
                                                  insertSliceOp.source()) &&
          hasMatchingExtractSliceOp(aliasInfo, insertSliceOp.source(),
                                    insertSliceOp))
        return true;

    return false;
  }

  LogicalResult bufferize(Operation *op, OpBuilder &b,
                          BufferizationState &state) const {
    // insert_slice ops arise from tiling and bufferizing them out-of-place is
    // generally a deal breaker. When used with loops, this ends up cloning the
    // whole tensor on every single iteration and is a symptom of a
    // catastrophically bad scheduling decision.
    // TODO: be very loud about it or even consider failing the pass.
    auto insertSliceOp = cast<tensor::InsertSliceOp>(op);

    // Take a guard before anything else.
    OpBuilder::InsertionGuard g(b);
    b.setInsertionPoint(insertSliceOp);
    Location loc = insertSliceOp.getLoc();

    // When bufferizing out-of-place, `getResultBuffer` allocates.
    Value dstMemref = getResultBuffer(b, insertSliceOp->getResult(0), state);
    if (!dstMemref)
      return failure();

    // A copy of the source buffer is needed if either:
    // - The producer of `source` is not inplace. This is the case where a
    //   slice is computed out of place into the inplace full tensor.
    // - The result is not inplace. This is the case where the whole tensor is
    //   cloned and the clone needs to be updated.
    // TODO: Is this necessary?
    bool needCopy = !isSourceEquivalentToAMatchingInplaceExtractSliceOp(
                        state.aliasInfo, insertSliceOp) ||
                    !state.aliasInfo.isInPlace(insertSliceOp->getResult(0));
    if (needCopy) {
      // Take a subview of the dst.
      auto dstMemrefType = dstMemref.getType().cast<MemRefType>();
      auto subviewMemRefType =
          memref::SubViewOp::inferRankReducedResultType(
              insertSliceOp.getSourceType().getRank(), dstMemrefType,
              insertSliceOp.getMixedOffsets(), insertSliceOp.getMixedSizes(),
              insertSliceOp.getMixedStrides())
              .cast<MemRefType>();
      Value subView = b.create<memref::SubViewOp>(
          loc, subviewMemRefType, dstMemref, insertSliceOp.getMixedOffsets(),
          insertSliceOp.getMixedSizes(), insertSliceOp.getMixedStrides());
      // Insert new alias.
      state.aliasInfo.insertNewBufferAlias(subView, dstMemref);
      // Copy tensor.
      Value srcMemref = state.lookupBuffer(insertSliceOp.source());
      state.allocationFns.memCpyFn(b, insertSliceOp.getLoc(), srcMemref,
                                   subView);
    }

    state.mapBuffer(insertSliceOp.result(), dstMemref);
    return success();
  }
};

} // namespace tensor_ext
} // namespace comprehensive_bufferize
} // namespace linalg
} // namespace mlir

void mlir::linalg::comprehensive_bufferize::tensor_ext::
    registerBufferizableOpInterfaceExternalModels(DialectRegistry &registry) {
  registry.addOpInterface<tensor::CastOp, tensor_ext::CastOpInterface>();
  registry.addOpInterface<tensor::DimOp, tensor_ext::DimOpInterface>();
  registry.addOpInterface<tensor::ExtractSliceOp,
                          tensor_ext::ExtractSliceOpInterface>();
  registry.addOpInterface<tensor::ExtractOp, tensor_ext::ExtractOpInterface>();
  registry.addOpInterface<tensor::InsertSliceOp,
                          tensor_ext::InsertSliceOpInterface>();
}
@@ -49,6 +49,7 @@ add_mlir_dialect_library(MLIRLinalgTransforms
  MLIRStandardOpsTransforms
  MLIRStandardToLLVM
  MLIRTensor
  MLIRTensorBufferizableOpInterfaceImpl
  MLIRTransforms
  MLIRTransformUtils
  MLIRVector
@@ -10,6 +10,7 @@
#include "mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h"
#include "mlir/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.h"
#include "mlir/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.h"
#include "mlir/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.h"
#include "mlir/Dialect/Linalg/Passes.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassManager.h"
@@ -38,6 +39,7 @@ struct LinalgComprehensiveModuleBufferize
                arith::ArithmeticDialect, StandardOpsDialect, AffineDialect>();
    registerBufferizableOpInterfaceExternalModels(registry);
    linalg_ext::registerBufferizableOpInterfaceExternalModels(registry);
    tensor_ext::registerBufferizableOpInterfaceExternalModels(registry);
  }
};
} // end namespace
@@ -6326,6 +6326,25 @@ cc_library(
    ],
)

cc_library(
    name = "TensorBufferizableOpInterfaceImpl",
    srcs = [
        "lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp",
    ],
    hdrs = [
        "include/mlir/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.h",
    ],
    includes = ["include"],
    deps = [
        ":BufferizableOpInterface",
        ":IR",
        ":MemRefDialect",
        ":Support",
        ":TensorDialect",
        "//llvm:Support",
    ],
)

td_library(
    name = "LinalgDocTdFiles",
    srcs = ["include/mlir/Dialect/Linalg/IR/LinalgDoc.td"],
@@ -6545,6 +6564,7 @@ cc_library(
        ":StandardOps",
        ":StandardOpsTransforms",
        ":Support",
        ":TensorBufferizableOpInterfaceImpl",
        ":TensorDialect",
        ":TransformUtils",
        ":VectorOps",
@@ -6575,7 +6595,6 @@ cc_library(
        ":SCFDialect",
        ":StandardOps",
        ":Support",
        ":TensorDialect",
        ":TransformUtils",
        ":VectorOps",
        "//llvm:Support",