[mlir][bufferize] Infer memref types when possible

Instead of recomputing memref types from tensor types, try to infer them when possible. This results in more precise layout maps.

Differential Revision: https://reviews.llvm.org/D125614
This commit is contained in:
Matthias Springer 2022-05-16 01:53:51 +02:00
parent b3077f563d
commit 12e41d9264
4 changed files with 33 additions and 14 deletions

View File

@ -503,6 +503,9 @@ struct BufferizationState {
Optional<Operation *> customCopyInsertionPoint = None);
/// Return the buffer type for a given OpOperand (tensor) after bufferization.
///
/// Note: Op implementations should preferrably call `getBuffer()->getType()`.
/// This function should only be used if `getBuffer` cannot be used.
BaseMemRefType getBufferType(OpOperand &opOperand) const;
/// Return a reference to the BufferizationOptions.
@ -546,9 +549,18 @@ OpTy replaceOpWithNewBufferizedOp(RewriterBase &rewriter, Operation *op,
return newOp;
}
/// Return a MemRefType to which the `tensorType` can be bufferized in a
/// composable fashion. The layout must be the most dynamic possible and
/// canonicalize away once bufferization is finished.
/// Return a MemRefType to which the `tensorType` can be bufferized.
///
/// If possible, op bufferization implementations should not use this function
/// and instead infer precise memref types for tensor results by themselves.
///
/// Unless a layout map was specified, `options` flags determine what kind of
/// layout map will be used. For best composability (without copies), the fully
/// dynamic layout map is used by default.
///
/// Note: Canonicalization patterns could clean up layout maps and infer more
/// precise layout maps after bufferization. However, many possible
/// canonicalizations are currently not implemented.
BaseMemRefType getMemRefType(TensorType tensorType,
const BufferizationOptions &options,
MemRefLayoutAttrInterface layout = {},

View File

@ -82,17 +82,22 @@ struct IndexCastOpInterface
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
BufferizationState &state) const {
auto castOp = cast<arith::IndexCastOp>(op);
auto resultTensorType = castOp.getType().cast<TensorType>();
Value source = *state.getBuffer(rewriter, op->getOpOperand(0) /*in*/);
auto sourceType = source.getType().cast<BaseMemRefType>();
// Result type should have same layout and address space as the source type.
MemRefLayoutAttrInterface layout = {};
if (auto rankedMemRefType = sourceType.dyn_cast<MemRefType>())
layout = rankedMemRefType.getLayout();
Type resultType =
getMemRefType(castOp.getType().cast<TensorType>(), state.getOptions(),
layout, sourceType.getMemorySpace());
BaseMemRefType resultType;
if (auto rankedMemRefType = sourceType.dyn_cast<MemRefType>()) {
resultType = MemRefType::get(
rankedMemRefType.getShape(), resultTensorType.getElementType(),
rankedMemRefType.getLayout(), rankedMemRefType.getMemorySpace());
} else {
auto unrankedMemrefType = sourceType.cast<UnrankedMemRefType>();
resultType = UnrankedMemRefType::get(resultTensorType.getElementType(),
unrankedMemrefType.getMemorySpace());
}
replaceOpWithNewBufferizedOp<arith::IndexCastOp>(rewriter, op, resultType,
source);
@ -146,15 +151,14 @@ struct SelectOpInterface
// both of them to the most dynamic MemRef type.
if (trueBuffer.getType() != falseBuffer.getType()) {
auto trueType = trueBuffer.getType().cast<MemRefType>();
auto tensorType = selectOp.getTrueValue().getType().cast<TensorType>();
int64_t dynamicOffset = ShapedType::kDynamicStrideOrOffset;
SmallVector<int64_t> dynamicStrides(tensorType.getRank(),
SmallVector<int64_t> dynamicStrides(trueType.getRank(),
ShapedType::kDynamicStrideOrOffset);
AffineMap stridedLayout = makeStridedLinearLayoutMap(
dynamicStrides, dynamicOffset, op->getContext());
BaseMemRefType castedType = bufferization::getMemRefType(
tensorType, state.getOptions(), AffineMapAttr::get(stridedLayout),
trueType.getMemorySpace());
auto castedType =
MemRefType::get(trueType.getShape(), trueType.getElementType(),
stridedLayout, trueType.getMemorySpaceAsInt());
trueBuffer = rewriter.create<memref::CastOp>(loc, castedType, trueBuffer);
falseBuffer =
rewriter.create<memref::CastOp>(loc, castedType, falseBuffer);

View File

@ -79,6 +79,7 @@ struct ExecuteRegionOpInterface
SmallVector<Type> newResultTypes;
for (Type type : executeRegionOp->getResultTypes()) {
if (auto tensorType = type.dyn_cast<TensorType>()) {
// TODO: Infer the result type instead of computing it.
newResultTypes.push_back(getMemRefType(tensorType, state.getOptions()));
} else {
newResultTypes.push_back(type);
@ -188,6 +189,7 @@ struct IfOpInterface
SmallVector<Type> newTypes;
for (Type returnType : ifOp->getResultTypes()) {
if (auto tensorType = returnType.dyn_cast<TensorType>()) {
// TODO: Infer the result type instead of computing it.
newTypes.push_back(getMemRefType(tensorType, state.getOptions()));
} else {
newTypes.push_back(returnType);

View File

@ -66,6 +66,7 @@ struct AssumingOpInterface
SmallVector<Type> newResultTypes;
for (Type type : assumingOp->getResultTypes()) {
if (auto tensorType = type.dyn_cast<TensorType>()) {
// TODO: Infer the result type instead of computing it.
newResultTypes.push_back(getMemRefType(tensorType, state.getOptions()));
} else {
newResultTypes.push_back(type);