[mlir][linalg][bufferize][NFC] Move tensor interface impl to new build target

This makes ComprehensiveBufferize entirely independent of the tensor dialect.
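For users of ComprehensiveBufferize that bufferize tensor ops, the tensor_ext
external models must now be registered explicitly; the core registration hook
no longer covers them. A minimal sketch (the pass class is hypothetical; the
two registration functions are the ones touched by this change):

  #include "mlir/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.h"
  #include "mlir/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.h"

  using namespace mlir;
  using namespace mlir::linalg;

  void MyBufferizePass::getDependentDialects(DialectRegistry &registry) const {
    // Core external models (scf, std, vector, ...). After this change, tensor
    // ops are no longer registered here.
    comprehensive_bufferize::registerBufferizableOpInterfaceExternalModels(
        registry);
    // The tensor dialect models now come from their own registration hook
    // (and build target).
    comprehensive_bufferize::tensor_ext::
        registerBufferizableOpInterfaceExternalModels(registry);
  }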

Differential Revision: https://reviews.llvm.org/D114217
Matthias Springer 2021-11-24 18:20:00 +09:00
parent 8ef460fc51
commit bb273a35a0
10 changed files with 555 additions and 462 deletions


@@ -322,6 +322,26 @@ struct PostAnalysisStep {
SmallVector<Operation *> &newOps) = 0;
};
/// Return a contiguous MemRefType (i.e. with canonical/empty layout map)
/// with the same shape as `shapedType` and the specified `layout` and
/// `memorySpace`.
MemRefType getContiguousMemRefType(ShapedType shapedType,
MemRefLayoutAttrInterface layout = {},
Attribute memorySpace = {});
/// Return a contiguous MemRefType (i.e. with canonical/empty layout map)
/// with the same shape as `type` and the specified `layout` and
/// `memorySpace` if `type` is a ranked tensor or memref type; return an
/// UnrankedMemRefType otherwise.
Type getContiguousOrUnrankedMemRefType(Type type,
MemRefLayoutAttrInterface layout = {},
Attribute memorySpace = {});
/// Return a MemRefType to which `tensorType` can be bufferized in a
/// composable fashion. The layout is the most dynamic one possible, so that
/// it canonicalizes away once bufferization is finished.
MemRefType getDynamicMemRefType(RankedTensorType tensorType,
unsigned addressSpace = 0);
} // namespace comprehensive_bufferize
} // namespace linalg
} // namespace mlir
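To make the three declarations concrete, here is what they produce for a
simple ranked case (an illustrative sketch, not part of the change; the MLIR
types in the comments assume the default memory space and the canonical form
produced by makeStridedLinearLayoutMap):

  // Let `tensorType` be the RankedTensorType for tensor<4x?xf32>.
  MemRefType contiguous = getContiguousMemRefType(tensorType);
  // contiguous: memref<4x?xf32> (identity layout map)
  Type unranked = getContiguousOrUnrankedMemRefType(
      UnrankedTensorType::get(FloatType::getF32(ctx)));
  // unranked: memref<*xf32>
  MemRefType dynamic = getDynamicMemRefType(tensorType);
  // dynamic: memref<4x?xf32,
  //   affine_map<(d0, d1)[s0, s1, s2] -> (d0 * s1 + s0 + d1 * s2)>>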


@@ -1,3 +1,11 @@
//===- LinalgInterfaceImpl.h - Linalg Impl. of BufferizableOpInterface ----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_DIALECT_LINALG_COMPREHENSIVEBUFFERIZE_LINALG_INTERFACE_IMPL_H
#define MLIR_DIALECT_LINALG_COMPREHENSIVEBUFFERIZE_LINALG_INTERFACE_IMPL_H


@@ -0,0 +1,27 @@
//===- TensorInterfaceImpl.h - Tensor Impl. of BufferizableOpInterface ---===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_DIALECT_LINALG_COMPREHENSIVEBUFFERIZE_TENSOR_INTERFACE_IMPL_H
#define MLIR_DIALECT_LINALG_COMPREHENSIVEBUFFERIZE_TENSOR_INTERFACE_IMPL_H
namespace mlir {
class DialectRegistry;
namespace linalg {
namespace comprehensive_bufferize {
namespace tensor_ext {
void registerBufferizableOpInterfaceExternalModels(DialectRegistry &registry);
} // namespace tensor_ext
} // namespace comprehensive_bufferize
} // namespace linalg
} // namespace mlir
#endif // MLIR_DIALECT_LINALG_COMPREHENSIVEBUFFERIZE_TENSOR_INTERFACE_IMPL_H


@@ -12,6 +12,7 @@
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/IR/Value.h"
#include "llvm/Support/Debug.h"
@@ -528,3 +529,31 @@ void mlir::linalg::comprehensive_bufferize::BufferizationState::
op->erase();
obsoleteOps.clear();
}
MemRefType mlir::linalg::comprehensive_bufferize::getContiguousMemRefType(
ShapedType shapedType, MemRefLayoutAttrInterface layout,
Attribute memorySpace) {
return MemRefType::get(shapedType.getShape(), shapedType.getElementType(),
layout, memorySpace);
}
Type mlir::linalg::comprehensive_bufferize::getContiguousOrUnrankedMemRefType(
Type type, MemRefLayoutAttrInterface layout, Attribute memorySpace) {
if (type.isa<RankedTensorType, MemRefType>())
return getContiguousMemRefType(type.cast<ShapedType>(), layout,
memorySpace);
assert(!layout && "expected empty layout with UnrankedMemRefType");
return UnrankedMemRefType::get(getElementTypeOrSelf(type), memorySpace);
}
MemRefType mlir::linalg::comprehensive_bufferize::getDynamicMemRefType(
RankedTensorType tensorType, unsigned addressSpace) {
// TODO: Address space decisions should be connected with the actual alloc.
int64_t dynamicOffset = ShapedType::kDynamicStrideOrOffset;
SmallVector<int64_t> dynamicStrides(tensorType.getRank(),
ShapedType::kDynamicStrideOrOffset);
AffineMap stridedLayout = makeStridedLinearLayoutMap(
dynamicStrides, dynamicOffset, tensorType.getContext());
return MemRefType::get(tensorType.getShape(), tensorType.getElementType(),
stridedLayout, addressSpace);
}


@@ -2,6 +2,7 @@ set(LLVM_OPTIONAL_SOURCES
BufferizableOpInterface.cpp
ComprehensiveBufferize.cpp
LinalgInterfaceImpl.cpp
TensorInterfaceImpl.cpp
)
add_mlir_dialect_library(MLIRBufferizableOpInterface
@@ -25,6 +26,16 @@ add_mlir_dialect_library(MLIRLinalgBufferizableOpInterfaceImpl
MLIRTensor
)
add_mlir_dialect_library(MLIRTensorBufferizableOpInterfaceImpl
TensorInterfaceImpl.cpp
LINK_LIBS PUBLIC
MLIRBufferizableOpInterface
MLIRIR
MLIRMemRef
MLIRTensor
)
add_mlir_dialect_library(MLIRComprehensiveBufferize
ComprehensiveBufferize.cpp
@@ -37,6 +48,5 @@ add_mlir_dialect_library(MLIRComprehensiveBufferize
MLIRSCF
MLIRStandard
MLIRStandardOpsTransforms
MLIRTensor
MLIRVector
)
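With MLIRTensor dropped from MLIRComprehensiveBufferize, targets that need the
tensor bufferization patterns link the new library directly, as
MLIRLinalgTransforms does further down in this diff. A hypothetical
out-of-tree consumer would do the same:

  add_mlir_library(MyComprehensiveBufferizePass
    MyComprehensiveBufferizePass.cpp

    LINK_LIBS PUBLIC
    MLIRComprehensiveBufferize
    MLIRTensorBufferizableOpInterfaceImpl
  )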


@@ -587,45 +587,6 @@ getEquivalentEnclosingFuncBBArg(Value v,
// Bufferization-specific MemRefType support.
//===----------------------------------------------------------------------===//
/// Return a contiguous MemRefType (i.e. with canonical/empty layout map)
/// with the same shape as `shapedType` and specified `layout` and
/// `addressSpace`.
static MemRefType getContiguousMemRefType(ShapedType shapedType,
MemRefLayoutAttrInterface layout = {},
Attribute memorySpace = {}) {
return MemRefType::get(shapedType.getShape(), shapedType.getElementType(),
layout, memorySpace);
}
/// Return a contiguous MemRefType (i.e. with canonical/empty layout map)
/// with the same shape as `shapedType` and specified `layout` and
/// `addressSpace` or an UnrankedMemRefType otherwise.
static Type
getContiguousOrUnrankedMemRefType(Type type,
MemRefLayoutAttrInterface layout = {},
Attribute memorySpace = {}) {
if (type.isa<RankedTensorType, MemRefType>())
return getContiguousMemRefType(type.cast<ShapedType>(), layout,
memorySpace);
assert(!layout && "expected empty layout with UnrankedMemRefType");
return UnrankedMemRefType::get(getElementTypeOrSelf(type), memorySpace);
}
/// Return a MemRefType to which the `tensorType` can be bufferized in a
/// composable fashion. The layout must be the most dynamic possible and
/// canonicalize away once bufferization is finished.
static MemRefType getDynamicMemRefType(RankedTensorType tensorType,
unsigned addressSpace = 0) {
// TODO: address space decisions to connect with the actual alloc.
int64_t dynamicOffset = ShapedType::kDynamicStrideOrOffset;
SmallVector<int64_t> dynamicStrides(tensorType.getRank(),
ShapedType::kDynamicStrideOrOffset);
AffineMap stridedLayout = makeStridedLinearLayoutMap(
dynamicStrides, dynamicOffset, tensorType.getContext());
return MemRefType::get(tensorType.getShape(), tensorType.getElementType(),
stridedLayout, addressSpace);
}
/// Return the FunctionType with `argumentTypes` and `resultTypes` where each
/// tensor is replaced by the corresponding buffer type.
/// In order for all the callers to agree, this *must* bufferize to the most
@@ -1965,420 +1926,6 @@ struct ReturnOpInterface
} // namespace std_ext
namespace tensor_ext {
struct CastOpInterface
: public BufferizableOpInterface::ExternalModel<CastOpInterface,
tensor::CastOp> {
bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand) const {
return false;
}
bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand) const {
return false;
}
SmallVector<OpOperand *> getAliasingOpOperand(Operation *op,
OpResult opResult) const {
return {&op->getOpOperand(0)};
}
OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand) const {
return op->getResult(0);
}
BufferRelation bufferRelation(Operation *op, OpOperand &opOperand) const {
return BufferRelation::Equivalent;
}
LogicalResult bufferize(Operation *op, OpBuilder &b,
BufferizationState &state) const {
auto castOp = cast<tensor::CastOp>(op);
// Take a guard before anything else.
OpBuilder::InsertionGuard g(b);
b.setInsertionPoint(castOp);
Value resultBuffer = getResultBuffer(b, castOp->getResult(0), state);
if (!resultBuffer)
return failure();
Type sourceType = resultBuffer.getType();
auto rankedMemRefType = sourceType.dyn_cast<MemRefType>();
auto unrankedMemRefType = sourceType.dyn_cast<UnrankedMemRefType>();
assert(rankedMemRefType || unrankedMemRefType);
Attribute memorySpace = rankedMemRefType
? rankedMemRefType.getMemorySpace()
: unrankedMemRefType.getMemorySpace();
TensorType tensorType = castOp.getResult().getType().cast<TensorType>();
MemRefLayoutAttrInterface layout =
rankedMemRefType && tensorType.isa<RankedTensorType>()
? rankedMemRefType.getLayout()
: MemRefLayoutAttrInterface();
Type memRefType = getContiguousOrUnrankedMemRefType(
castOp.getResult().getType(), layout, memorySpace);
Value res =
b.create<memref::CastOp>(castOp.getLoc(), memRefType, resultBuffer);
state.aliasInfo.insertNewBufferEquivalence(res, castOp.getResult());
state.mapBuffer(castOp.getResult(), res);
return success();
}
};
struct DimOpInterface
: public BufferizableOpInterface::ExternalModel<DimOpInterface,
tensor::DimOp> {
bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand) const {
return true;
}
bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand) const {
return false;
}
OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand) const {
return OpResult();
}
LogicalResult bufferize(Operation *op, OpBuilder &b,
BufferizationState &state) const {
auto dimOp = cast<tensor::DimOp>(op);
// Take a guard before anything else.
OpBuilder::InsertionGuard g(b);
b.setInsertionPoint(dimOp);
if (dimOp.source().getType().isa<RankedTensorType>()) {
Value v = state.lookupBuffer(dimOp.source());
dimOp.result().replaceAllUsesWith(
b.create<memref::DimOp>(dimOp.getLoc(), v, dimOp.index()));
}
return success();
}
};
struct ExtractSliceOpInterface
: public BufferizableOpInterface::ExternalModel<ExtractSliceOpInterface,
tensor::ExtractSliceOp> {
bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand) const {
return false;
}
bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand) const {
return false;
}
SmallVector<OpOperand *> getAliasingOpOperand(Operation *op,
OpResult opResult) const {
return {&op->getOpOperand(0) /*source*/};
}
OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand) const {
return &opOperand == &op->getOpOperand(0) /*source*/
? op->getResult(0)
: OpResult();
}
BufferRelation bufferRelation(Operation *op, OpOperand &opOperand) const {
return BufferRelation::None;
}
LogicalResult bufferize(Operation *op, OpBuilder &b,
BufferizationState &state) const {
auto extractSliceOp = cast<tensor::ExtractSliceOp>(op);
LDBG("bufferize: " << *extractSliceOp << '\n');
// Take a guard before anything else.
OpBuilder::InsertionGuard g(b);
b.setInsertionPoint(extractSliceOp);
Location loc = extractSliceOp.getLoc();
Value srcMemref = state.lookupBuffer(extractSliceOp.source());
auto srcMemrefType = srcMemref.getType().cast<MemRefType>();
auto dstTensorType =
extractSliceOp.result().getType().cast<RankedTensorType>();
// If not inplaceable, alloc.
bool inplace = state.aliasInfo.isInPlace(extractSliceOp->getResult(0));
Value alloc;
if (!inplace)
alloc = createNewAllocDeallocPairForShapedValue(
b, loc, extractSliceOp.result(), state);
// Bufferize to subview.
auto subviewMemRefType =
memref::SubViewOp::inferRankReducedResultType(
dstTensorType.getRank(), srcMemrefType,
extractSliceOp.getMixedOffsets(), extractSliceOp.getMixedSizes(),
extractSliceOp.getMixedStrides())
.cast<MemRefType>();
Value subView = b.create<memref::SubViewOp>(
loc, subviewMemRefType, srcMemref, extractSliceOp.getMixedOffsets(),
extractSliceOp.getMixedSizes(), extractSliceOp.getMixedStrides());
// Insert new alias.
state.aliasInfo.insertNewBufferAlias(subView, srcMemref);
/// If not inplaceable, copy.
if (!inplace) {
// Do not copy if the copied data is never read.
if (isValueRead(extractSliceOp.result()))
state.allocationFns.memCpyFn(b, extractSliceOp.getLoc(), subView,
alloc);
subView = alloc;
}
state.mapBuffer(extractSliceOp.result(), subView);
return success();
}
};
struct ExtractOpInterface
: public BufferizableOpInterface::ExternalModel<ExtractOpInterface,
tensor::ExtractOp> {
bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand) const {
return true;
}
bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand) const {
return false;
}
OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand) const {
return OpResult();
}
LogicalResult bufferize(Operation *op, OpBuilder &b,
BufferizationState &state) const {
auto extractOp = cast<tensor::ExtractOp>(op);
// Take a guard before anything else.
OpBuilder::InsertionGuard g(b);
b.setInsertionPoint(extractOp);
Location loc = extractOp.getLoc();
Value srcMemref = state.lookupBuffer(extractOp.tensor());
Value l = b.create<memref::LoadOp>(loc, srcMemref, extractOp.indices());
extractOp.replaceAllUsesWith(l);
return success();
}
};
/// Return true if the (ExtractSliceOp, InsertSliceOp) pair match (i.e.
/// equivalent operand / result and same offset/sizes/strides specification).
///
/// This is one particular type of relationship between ops on tensors that
/// reduce to an equivalence on buffers. This should be generalized and
/// exposed as interfaces on the proper types.
static bool
areEquivalentExtractSliceOps(const BufferizationAliasInfo &aliasInfo,
ExtractSliceOp st, InsertSliceOp sti) {
if (!st || !sti)
return false;
if (!aliasInfo.areEquivalentBufferizedValues(st.source(), sti.dest()))
return false;
if (!sameOffsetsSizesAndStrides(st, sti, isEqualConstantIntOrValue))
return false;
return true;
}
/// Return true if the source of a `insertSliceOp` bufferizes to an
/// equivalent ExtractSliceOp that bufferizes inplace.
static bool isSourceEquivalentToAMatchingInplaceExtractSliceOp(
const BufferizationAliasInfo &aliasInfo, InsertSliceOp insertSliceOp) {
LDBG("isSourceEquivalentToAMatchingInplaceExtractSliceOp: " << *insertSliceOp
<< '\n');
bool foundOp = false;
aliasInfo.applyOnEquivalenceClass(insertSliceOp.source(), [&](Value value) {
auto extractSliceOp = value.getDefiningOp<ExtractSliceOp>();
if (extractSliceOp &&
areEquivalentExtractSliceOps(aliasInfo, extractSliceOp,
insertSliceOp) &&
aliasInfo.isInPlace(extractSliceOp->getResult(0))) {
LDBG("\tfound: " << extractSliceOp.getOperation() << '\n');
foundOp = true;
}
});
if (!foundOp)
LDBG("\tnot equivalent\n");
return foundOp;
}
/// Return true if `value` is originating from an ExtractSliceOp that matches
/// the given InsertSliceOp.
static bool hasMatchingExtractSliceOp(const BufferizationAliasInfo &aliasInfo,
Value value, InsertSliceOp insertOp) {
auto condition = [&](Value val) {
if (auto extractOp = val.getDefiningOp<ExtractSliceOp>())
if (areEquivalentExtractSliceOps(aliasInfo, extractOp, insertOp))
return true;
return false;
};
return llvm::all_of(findValueInReverseUseDefChain(value, condition),
condition);
}
struct InsertSliceOpInterface
: public BufferizableOpInterface::ExternalModel<InsertSliceOpInterface,
tensor::InsertSliceOp> {
bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand) const {
return true;
}
bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand) const {
return &opOperand == &op->getOpOperand(1) /*dest*/;
}
SmallVector<OpOperand *> getAliasingOpOperand(Operation *op,
OpResult opResult) const {
return {&op->getOpOperand(1) /*dest*/};
}
OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand) const {
return &opOperand == &op->getOpOperand(1) /*dest*/
? op->getResult(0)
: OpResult();
}
BufferRelation bufferRelation(Operation *op, OpOperand &opOperand) const {
return BufferRelation::Equivalent;
}
bool isNotConflicting(Operation *op, OpOperand *uRead,
OpOperand *uConflictingWrite,
const BufferizationAliasInfo &aliasInfo) const {
Operation *readingOp = uRead->getOwner();
Operation *conflictingWritingOp = uConflictingWrite->getOwner();
// Special rules for matching ExtractSliceOp/InsertSliceOp pairs. If
// uRead is an InsertSliceOp...
if (auto insertSliceOp = dyn_cast<InsertSliceOp>(readingOp)) {
// As an example, consider the following IR.
//
// %0 = tensor.extract_slice %t[%a, %b][%c, %d][1, 1] {inplace = [true] }
// %1 = linalg.fill %cst, %0 {inplace= [true] }
// %2 = tensor.insert_slice %1 into %t[%a, %b][%c, %d][1, 1]
// {inplace= [true] }
// TODO: Use insertSliceOp.getDestOpOperand etc. when available.
if (uRead == &insertSliceOp->getOpOperand(1) /*dest*/ &&
hasMatchingExtractSliceOp(aliasInfo, uConflictingWrite->get(),
insertSliceOp))
// Case 1: The main insight is that InsertSliceOp reads only part of
// the destination tensor. The overwritten area is not read. If
// uConflictingWrite writes into exactly the memory location that is
// being read by uRead, this is not a conflict.
//
// In the above example:
// uRead = OpOperand 1 (%t) of tensor.insert_slice
// uConflictingWrite = OpOperand 1 (%0) of linalg.fill
//
// The read of %t does not conflict with the write of the FillOp
// (same aliases!) because the area that the FillOp operates on is
// exactly the one that is *not* read via %t.
return true;
if (uRead == &insertSliceOp->getOpOperand(0) /*source*/ &&
uConflictingWrite == &insertSliceOp->getOpOperand(1) /*dest*/ &&
hasMatchingExtractSliceOp(aliasInfo, uRead->get(), insertSliceOp))
// Case 2: The read of the source tensor and the write to the dest
// tensor via an InsertSliceOp is not a conflict if the read is
// reading exactly that part of an equivalent tensor that the
// InsertSliceOp is writing.
//
// In the above example:
// uRead = OpOperand 0 (%1) of tensor.insert_slice
// uConflictingWrite = OpOperand 1 (%t) of tensor.insert_slice
return true;
}
// If uConflictingWrite is an InsertSliceOp...
if (auto insertSliceOp = dyn_cast<InsertSliceOp>(conflictingWritingOp))
// As an example, consider the following IR.
//
// %0 = tensor.extract_slice %t[%a, %b][%c, %d][1, 1] {inplace = [true] }
// %1 = linalg.fill %cst, %0 {inplace= [true] }
// %2 = tensor.insert_slice %1 into %t[%a, %b][%c, %d][1, 1]
// {inplace= [true] }
// %3 = vector.transfer_read %1, %cst
//
// In the above example:
// uRead = OpOperand 0 (%1) of vector.transfer_read
// uConflictingWrite = OpOperand 1 (%t) of tensor.insert_slice
// lastWrite = %1
//
// This is not a conflict because the InsertSliceOp overwrites the
// memory segment of %1 with the exact same data. (Effectively, there
// is no memory write here.)
if (uConflictingWrite == &insertSliceOp->getOpOperand(1) /*dest*/ &&
aliasInfo.areEquivalentBufferizedValues(uRead->get(),
insertSliceOp.source()) &&
hasMatchingExtractSliceOp(aliasInfo, insertSliceOp.source(),
insertSliceOp))
return true;
return false;
}
LogicalResult bufferize(Operation *op, OpBuilder &b,
BufferizationState &state) const {
// insert_slice ops arise from tiling and bufferizing them out-of-place is
// generally a deal breaker. When used with loops, this ends up cloning the
// whole tensor on every single iteration and is a symptom of a
// catastrophically bad scheduling decision.
// TODO: be very loud about it or even consider failing the pass.
auto insertSliceOp = cast<tensor::InsertSliceOp>(op);
LDBG("bufferize: " << *insertSliceOp << '\n');
// Take a guard before anything else.
OpBuilder::InsertionGuard g(b);
b.setInsertionPoint(insertSliceOp);
Location loc = insertSliceOp.getLoc();
// When bufferizing out-of-place, `getResultBuffer` allocates.
Value dstMemref = getResultBuffer(b, insertSliceOp->getResult(0), state);
if (!dstMemref)
return failure();
// A copy of the source buffer is needed if either:
// - The producer of `source` is not inplace. This is the case where a
// slice is computed out of place into the inplace full tensor.
// - The result is not inplace. This is the case where the whole tensor is
// cloned and the clone needs to be updated.
// TODO: Is this necessary?
bool needCopy = !isSourceEquivalentToAMatchingInplaceExtractSliceOp(
state.aliasInfo, insertSliceOp) ||
!state.aliasInfo.isInPlace(insertSliceOp->getResult(0));
if (needCopy) {
LDBG("insert_slice needs extra source copy: " << insertSliceOp.source()
<< " -> copy\n");
// Take a subview of the dst.
auto dstMemrefType = dstMemref.getType().cast<MemRefType>();
auto subviewMemRefType =
memref::SubViewOp::inferRankReducedResultType(
insertSliceOp.getSourceType().getRank(), dstMemrefType,
insertSliceOp.getMixedOffsets(), insertSliceOp.getMixedSizes(),
insertSliceOp.getMixedStrides())
.cast<MemRefType>();
Value subView = b.create<memref::SubViewOp>(
loc, subviewMemRefType, dstMemref, insertSliceOp.getMixedOffsets(),
insertSliceOp.getMixedSizes(), insertSliceOp.getMixedStrides());
// Insert new alias.
state.aliasInfo.insertNewBufferAlias(subView, dstMemref);
// Copy tensor.
Value srcMemref = state.lookupBuffer(insertSliceOp.source());
state.allocationFns.memCpyFn(b, insertSliceOp.getLoc(), srcMemref,
subView);
}
state.mapBuffer(insertSliceOp.result(), dstMemref);
return success();
}
};
} // namespace tensor_ext
namespace vector_ext {
struct TransferReadOpInterface
@@ -2484,13 +2031,6 @@ void registerBufferizableOpInterfaceExternalModels(DialectRegistry &registry) {
registry.addOpInterface<scf::YieldOp, scf_ext::YieldOpInterface>();
registry.addOpInterface<CallOp, std_ext::CallOpInterface>();
registry.addOpInterface<ReturnOp, std_ext::ReturnOpInterface>();
registry.addOpInterface<tensor::CastOp, tensor_ext::CastOpInterface>();
registry.addOpInterface<tensor::DimOp, tensor_ext::DimOpInterface>();
registry.addOpInterface<tensor::ExtractSliceOp,
tensor_ext::ExtractSliceOpInterface>();
registry.addOpInterface<tensor::ExtractOp, tensor_ext::ExtractOpInterface>();
registry.addOpInterface<tensor::InsertSliceOp,
tensor_ext::InsertSliceOpInterface>();
registry.addOpInterface<vector::TransferReadOp,
vector_ext::TransferReadOpInterface>();
registry.addOpInterface<vector::TransferWriteOp,


@@ -0,0 +1,437 @@
//===- TensorInterfaceImpl.cpp - Tensor Impl. of BufferizableOpInterface --===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.h"
#include "mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/Operation.h"
namespace mlir {
namespace linalg {
namespace comprehensive_bufferize {
namespace tensor_ext {
using tensor::ExtractSliceOp;
using tensor::InsertSliceOp;
struct CastOpInterface
: public BufferizableOpInterface::ExternalModel<CastOpInterface,
tensor::CastOp> {
bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand) const {
return false;
}
bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand) const {
return false;
}
SmallVector<OpOperand *> getAliasingOpOperand(Operation *op,
OpResult opResult) const {
return {&op->getOpOperand(0)};
}
OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand) const {
return op->getResult(0);
}
BufferRelation bufferRelation(Operation *op, OpOperand &opOperand) const {
return BufferRelation::Equivalent;
}
LogicalResult bufferize(Operation *op, OpBuilder &b,
BufferizationState &state) const {
auto castOp = cast<tensor::CastOp>(op);
// Take a guard before anything else.
OpBuilder::InsertionGuard g(b);
b.setInsertionPoint(castOp);
Value resultBuffer = getResultBuffer(b, castOp->getResult(0), state);
if (!resultBuffer)
return failure();
Type sourceType = resultBuffer.getType();
auto rankedMemRefType = sourceType.dyn_cast<MemRefType>();
auto unrankedMemRefType = sourceType.dyn_cast<UnrankedMemRefType>();
assert(rankedMemRefType || unrankedMemRefType);
Attribute memorySpace = rankedMemRefType
? rankedMemRefType.getMemorySpace()
: unrankedMemRefType.getMemorySpace();
TensorType tensorType = castOp.getResult().getType().cast<TensorType>();
MemRefLayoutAttrInterface layout =
rankedMemRefType && tensorType.isa<RankedTensorType>()
? rankedMemRefType.getLayout()
: MemRefLayoutAttrInterface();
Type memRefType = getContiguousOrUnrankedMemRefType(
castOp.getResult().getType(), layout, memorySpace);
Value res =
b.create<memref::CastOp>(castOp.getLoc(), memRefType, resultBuffer);
state.aliasInfo.insertNewBufferEquivalence(res, castOp.getResult());
state.mapBuffer(castOp.getResult(), res);
return success();
}
};
struct DimOpInterface
: public BufferizableOpInterface::ExternalModel<DimOpInterface,
tensor::DimOp> {
bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand) const {
return true;
}
bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand) const {
return false;
}
OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand) const {
return OpResult();
}
LogicalResult bufferize(Operation *op, OpBuilder &b,
BufferizationState &state) const {
auto dimOp = cast<tensor::DimOp>(op);
// Take a guard before anything else.
OpBuilder::InsertionGuard g(b);
b.setInsertionPoint(dimOp);
if (dimOp.source().getType().isa<RankedTensorType>()) {
Value v = state.lookupBuffer(dimOp.source());
dimOp.result().replaceAllUsesWith(
b.create<memref::DimOp>(dimOp.getLoc(), v, dimOp.index()));
}
return success();
}
};
struct ExtractSliceOpInterface
: public BufferizableOpInterface::ExternalModel<ExtractSliceOpInterface,
tensor::ExtractSliceOp> {
bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand) const {
return false;
}
bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand) const {
return false;
}
SmallVector<OpOperand *> getAliasingOpOperand(Operation *op,
OpResult opResult) const {
return {&op->getOpOperand(0) /*source*/};
}
OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand) const {
return &opOperand == &op->getOpOperand(0) /*source*/
? op->getResult(0)
: OpResult();
}
BufferRelation bufferRelation(Operation *op, OpOperand &opOperand) const {
return BufferRelation::None;
}
LogicalResult bufferize(Operation *op, OpBuilder &b,
BufferizationState &state) const {
auto extractSliceOp = cast<tensor::ExtractSliceOp>(op);
// Take a guard before anything else.
OpBuilder::InsertionGuard g(b);
b.setInsertionPoint(extractSliceOp);
Location loc = extractSliceOp.getLoc();
Value srcMemref = state.lookupBuffer(extractSliceOp.source());
auto srcMemrefType = srcMemref.getType().cast<MemRefType>();
auto dstTensorType =
extractSliceOp.result().getType().cast<RankedTensorType>();
// If not inplaceable, alloc.
bool inplace = state.aliasInfo.isInPlace(extractSliceOp->getResult(0));
Value alloc;
if (!inplace)
alloc = state.allocationFns.createAllocDeallocFn(
b, loc, extractSliceOp.result(), state);
// Bufferize to subview.
auto subviewMemRefType =
memref::SubViewOp::inferRankReducedResultType(
dstTensorType.getRank(), srcMemrefType,
extractSliceOp.getMixedOffsets(), extractSliceOp.getMixedSizes(),
extractSliceOp.getMixedStrides())
.cast<MemRefType>();
Value subView = b.create<memref::SubViewOp>(
loc, subviewMemRefType, srcMemref, extractSliceOp.getMixedOffsets(),
extractSliceOp.getMixedSizes(), extractSliceOp.getMixedStrides());
// Insert new alias.
state.aliasInfo.insertNewBufferAlias(subView, srcMemref);
// If not inplaceable, copy.
if (!inplace) {
// Do not copy if the copied data is never read.
if (isValueRead(extractSliceOp.result()))
state.allocationFns.memCpyFn(b, extractSliceOp.getLoc(), subView,
alloc);
subView = alloc;
}
state.mapBuffer(extractSliceOp.result(), subView);
return success();
}
};
struct ExtractOpInterface
: public BufferizableOpInterface::ExternalModel<ExtractOpInterface,
tensor::ExtractOp> {
bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand) const {
return true;
}
bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand) const {
return false;
}
OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand) const {
return OpResult();
}
LogicalResult bufferize(Operation *op, OpBuilder &b,
BufferizationState &state) const {
auto extractOp = cast<tensor::ExtractOp>(op);
// Take a guard before anything else.
OpBuilder::InsertionGuard g(b);
b.setInsertionPoint(extractOp);
Location loc = extractOp.getLoc();
Value srcMemref = state.lookupBuffer(extractOp.tensor());
Value l = b.create<memref::LoadOp>(loc, srcMemref, extractOp.indices());
extractOp.replaceAllUsesWith(l);
return success();
}
};
/// Return true if the (ExtractSliceOp, InsertSliceOp) pair matches (i.e.,
/// equivalent operand/result and the same offsets/sizes/strides
/// specification).
///
/// This is one particular type of relationship between ops on tensors that
/// reduces to an equivalence on buffers. This should be generalized and
/// exposed as interfaces on the proper types.
static bool
areEquivalentExtractSliceOps(const BufferizationAliasInfo &aliasInfo,
ExtractSliceOp st, InsertSliceOp sti) {
if (!st || !sti)
return false;
if (!aliasInfo.areEquivalentBufferizedValues(st.source(), sti.dest()))
return false;
if (!sameOffsetsSizesAndStrides(st, sti, isEqualConstantIntOrValue))
return false;
return true;
}
/// Return true if the source of `insertSliceOp` bufferizes to an
/// equivalent ExtractSliceOp that bufferizes inplace.
static bool isSourceEquivalentToAMatchingInplaceExtractSliceOp(
const BufferizationAliasInfo &aliasInfo, InsertSliceOp insertSliceOp) {
bool foundOp = false;
aliasInfo.applyOnEquivalenceClass(insertSliceOp.source(), [&](Value value) {
auto extractSliceOp = value.getDefiningOp<ExtractSliceOp>();
if (extractSliceOp &&
areEquivalentExtractSliceOps(aliasInfo, extractSliceOp,
insertSliceOp) &&
aliasInfo.isInPlace(extractSliceOp->getResult(0))) {
foundOp = true;
}
});
return foundOp;
}
/// Return true if `value` originates from an ExtractSliceOp that matches
/// the given InsertSliceOp.
static bool hasMatchingExtractSliceOp(const BufferizationAliasInfo &aliasInfo,
Value value, InsertSliceOp insertOp) {
auto condition = [&](Value val) {
if (auto extractOp = val.getDefiningOp<ExtractSliceOp>())
if (areEquivalentExtractSliceOps(aliasInfo, extractOp, insertOp))
return true;
return false;
};
return llvm::all_of(findValueInReverseUseDefChain(value, condition),
condition);
}
struct InsertSliceOpInterface
: public BufferizableOpInterface::ExternalModel<InsertSliceOpInterface,
tensor::InsertSliceOp> {
bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand) const {
return true;
}
bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand) const {
return &opOperand == &op->getOpOperand(1) /*dest*/;
}
SmallVector<OpOperand *> getAliasingOpOperand(Operation *op,
OpResult opResult) const {
return {&op->getOpOperand(1) /*dest*/};
}
OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand) const {
return &opOperand == &op->getOpOperand(1) /*dest*/
? op->getResult(0)
: OpResult();
}
BufferRelation bufferRelation(Operation *op, OpOperand &opOperand) const {
return BufferRelation::Equivalent;
}
bool isNotConflicting(Operation *op, OpOperand *uRead,
OpOperand *uConflictingWrite,
const BufferizationAliasInfo &aliasInfo) const {
Operation *readingOp = uRead->getOwner();
Operation *conflictingWritingOp = uConflictingWrite->getOwner();
// Special rules for matching ExtractSliceOp/InsertSliceOp pairs. If
// uRead is an InsertSliceOp...
if (auto insertSliceOp = dyn_cast<InsertSliceOp>(readingOp)) {
// As an example, consider the following IR.
//
// %0 = tensor.extract_slice %t[%a, %b][%c, %d][1, 1] {inplace = [true] }
// %1 = linalg.fill %cst, %0 {inplace= [true] }
// %2 = tensor.insert_slice %1 into %t[%a, %b][%c, %d][1, 1]
// {inplace= [true] }
// TODO: Use insertSliceOp.getDestOpOperand etc. when available.
if (uRead == &insertSliceOp->getOpOperand(1) /*dest*/ &&
hasMatchingExtractSliceOp(aliasInfo, uConflictingWrite->get(),
insertSliceOp))
// Case 1: The main insight is that InsertSliceOp reads only part of
// the destination tensor. The overwritten area is not read. If
// uConflictingWrite writes into exactly the memory location that is
// being read by uRead, this is not a conflict.
//
// In the above example:
// uRead = OpOperand 1 (%t) of tensor.insert_slice
// uConflictingWrite = OpOperand 1 (%0) of linalg.fill
//
// The read of %t does not conflict with the write of the FillOp
// (same aliases!) because the area that the FillOp operates on is
// exactly the one that is *not* read via %t.
return true;
if (uRead == &insertSliceOp->getOpOperand(0) /*source*/ &&
uConflictingWrite == &insertSliceOp->getOpOperand(1) /*dest*/ &&
hasMatchingExtractSliceOp(aliasInfo, uRead->get(), insertSliceOp))
// Case 2: The read of the source tensor and the write to the dest
// tensor via an InsertSliceOp is not a conflict if the read is
// reading exactly that part of an equivalent tensor that the
// InsertSliceOp is writing.
//
// In the above example:
// uRead = OpOperand 0 (%1) of tensor.insert_slice
// uConflictingWrite = OpOperand 1 (%t) of tensor.insert_slice
return true;
}
// If uConflictingWrite is an InsertSliceOp...
if (auto insertSliceOp = dyn_cast<InsertSliceOp>(conflictingWritingOp))
// As an example, consider the following IR.
//
// %0 = tensor.extract_slice %t[%a, %b][%c, %d][1, 1] {inplace = [true] }
// %1 = linalg.fill %cst, %0 {inplace= [true] }
// %2 = tensor.insert_slice %1 into %t[%a, %b][%c, %d][1, 1]
// {inplace= [true] }
// %3 = vector.transfer_read %1, %cst
//
// In the above example:
// uRead = OpOperand 0 (%1) of vector.transfer_read
// uConflictingWrite = OpOperand 1 (%t) of tensor.insert_slice
// lastWrite = %1
//
// This is not a conflict because the InsertSliceOp overwrites the
// memory segment of %1 with the exact same data. (Effectively, there
// is no memory write here.)
if (uConflictingWrite == &insertSliceOp->getOpOperand(1) /*dest*/ &&
aliasInfo.areEquivalentBufferizedValues(uRead->get(),
insertSliceOp.source()) &&
hasMatchingExtractSliceOp(aliasInfo, insertSliceOp.source(),
insertSliceOp))
return true;
return false;
}
LogicalResult bufferize(Operation *op, OpBuilder &b,
BufferizationState &state) const {
// insert_slice ops arise from tiling and bufferizing them out-of-place is
// generally a deal breaker. When used with loops, this ends up cloning the
// whole tensor on every single iteration and is a symptom of a
// catastrophically bad scheduling decision.
// TODO: be very loud about it or even consider failing the pass.
auto insertSliceOp = cast<tensor::InsertSliceOp>(op);
// Take a guard before anything else.
OpBuilder::InsertionGuard g(b);
b.setInsertionPoint(insertSliceOp);
Location loc = insertSliceOp.getLoc();
// When bufferizing out-of-place, `getResultBuffer` allocates.
Value dstMemref = getResultBuffer(b, insertSliceOp->getResult(0), state);
if (!dstMemref)
return failure();
// A copy of the source buffer is needed if either:
// - The producer of `source` is not inplace. This is the case where a
// slice is computed out of place into the inplace full tensor.
// - The result is not inplace. This is the case where the whole tensor is
// cloned and the clone needs to be updated.
// TODO: Is this necessary?
bool needCopy = !isSourceEquivalentToAMatchingInplaceExtractSliceOp(
state.aliasInfo, insertSliceOp) ||
!state.aliasInfo.isInPlace(insertSliceOp->getResult(0));
if (needCopy) {
// Take a subview of the dst.
auto dstMemrefType = dstMemref.getType().cast<MemRefType>();
auto subviewMemRefType =
memref::SubViewOp::inferRankReducedResultType(
insertSliceOp.getSourceType().getRank(), dstMemrefType,
insertSliceOp.getMixedOffsets(), insertSliceOp.getMixedSizes(),
insertSliceOp.getMixedStrides())
.cast<MemRefType>();
Value subView = b.create<memref::SubViewOp>(
loc, subviewMemRefType, dstMemref, insertSliceOp.getMixedOffsets(),
insertSliceOp.getMixedSizes(), insertSliceOp.getMixedStrides());
// Insert new alias.
state.aliasInfo.insertNewBufferAlias(subView, dstMemref);
// Copy tensor.
Value srcMemref = state.lookupBuffer(insertSliceOp.source());
state.allocationFns.memCpyFn(b, insertSliceOp.getLoc(), srcMemref,
subView);
}
state.mapBuffer(insertSliceOp.result(), dstMemref);
return success();
}
};
} // namespace tensor_ext
} // namespace comprehensive_bufferize
} // namespace linalg
} // namespace mlir
void mlir::linalg::comprehensive_bufferize::tensor_ext::
registerBufferizableOpInterfaceExternalModels(DialectRegistry &registry) {
registry.addOpInterface<tensor::CastOp, tensor_ext::CastOpInterface>();
registry.addOpInterface<tensor::DimOp, tensor_ext::DimOpInterface>();
registry.addOpInterface<tensor::ExtractSliceOp,
tensor_ext::ExtractSliceOpInterface>();
registry.addOpInterface<tensor::ExtractOp, tensor_ext::ExtractOpInterface>();
registry.addOpInterface<tensor::InsertSliceOp,
tensor_ext::InsertSliceOpInterface>();
}


@@ -49,6 +49,7 @@ add_mlir_dialect_library(MLIRLinalgTransforms
MLIRStandardOpsTransforms
MLIRStandardToLLVM
MLIRTensor
MLIRTensorBufferizableOpInterfaceImpl
MLIRTransforms
MLIRTransformUtils
MLIRVector


@@ -10,6 +10,7 @@
#include "mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h"
#include "mlir/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.h"
#include "mlir/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.h"
#include "mlir/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.h"
#include "mlir/Dialect/Linalg/Passes.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassManager.h"
@@ -38,6 +39,7 @@ struct LinalgComprehensiveModuleBufferize
arith::ArithmeticDialect, StandardOpsDialect, AffineDialect>();
registerBufferizableOpInterfaceExternalModels(registry);
linalg_ext::registerBufferizableOpInterfaceExternalModels(registry);
tensor_ext::registerBufferizableOpInterfaceExternalModels(registry);
}
};
} // end namespace


@@ -6326,6 +6326,25 @@ cc_library(
],
)
cc_library(
name = "TensorBufferizableOpInterfaceImpl",
srcs = [
"lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp",
],
hdrs = [
"include/mlir/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.h",
],
includes = ["include"],
deps = [
":BufferizableOpInterface",
":IR",
":MemRefDialect",
":Support",
":TensorDialect",
"//llvm:Support",
],
)
td_library(
name = "LinalgDocTdFiles",
srcs = ["include/mlir/Dialect/Linalg/IR/LinalgDoc.td"],
@@ -6545,6 +6564,7 @@ cc_library(
":StandardOps",
":StandardOpsTransforms",
":Support",
":TensorBufferizableOpInterfaceImpl",
":TensorDialect",
":TransformUtils",
":VectorOps",
@@ -6575,7 +6595,6 @@ cc_library(
":SCFDialect",
":StandardOps",
":Support",
":TensorDialect",
":TransformUtils",
":VectorOps",
"//llvm:Support",