[mlir][bufferize] Infer memref types when possible
Instead of recomputing memref types from tensor types, try to infer them when possible. This results in more precise layout maps.

Differential Revision: https://reviews.llvm.org/D125614
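For illustration, a minimal sketch (not part of the commit; the helper name `inferResultType` is hypothetical) of the difference between recomputing a buffer type from the tensor type and inferring it from the already-bufferized source value, using only API names that appear in the diff below:

    #include "mlir/IR/BuiltinTypes.h"
    using namespace mlir;

    // Hypothetical helper: given the already-bufferized source value and the
    // op's tensor result type, infer the result memref type instead of
    // recomputing it via `getMemRefType` (which would produce a fully dynamic
    // layout map by default).
    static BaseMemRefType inferResultType(Value source,
                                          TensorType resultTensorType) {
      auto sourceType = source.getType().cast<BaseMemRefType>();
      if (auto rankedType = sourceType.dyn_cast<MemRefType>())
        // Keep the source's exact layout and memory space; only the element
        // type comes from the tensor result type.
        return MemRefType::get(rankedType.getShape(),
                               resultTensorType.getElementType(),
                               rankedType.getLayout(),
                               rankedType.getMemorySpace());
      return UnrankedMemRefType::get(
          resultTensorType.getElementType(),
          sourceType.cast<UnrankedMemRefType>().getMemorySpace());
    }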
This commit is contained in:
parent b3077f563d
commit 12e41d9264
@@ -503,6 +503,9 @@ struct BufferizationState {
                             Optional<Operation *> customCopyInsertionPoint = None);
 
+  /// Return the buffer type for a given OpOperand (tensor) after bufferization.
+  ///
+  /// Note: Op implementations should preferably call `getBuffer()->getType()`.
+  /// This function should only be used if `getBuffer` cannot be used.
+  BaseMemRefType getBufferType(OpOperand &opOperand) const;
+
   /// Return a reference to the BufferizationOptions.
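A sketch of the usage preference stated in the new doc comment, given some `OpOperand &opOperand` and assuming `getBuffer` returns a `FailureOr<Value>` (it is dereferenced with `*` in the arith hunk below):

    // Preferred: query the type of the actual buffer.
    FailureOr<Value> buffer = state.getBuffer(rewriter, opOperand);
    if (succeeded(buffer)) {
      Type bufferType = buffer->getType();
    } else {
      // Fallback, for situations where `getBuffer` cannot be used.
      BaseMemRefType bufferType = state.getBufferType(opOperand);
    }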
@@ -546,9 +549,18 @@ OpTy replaceOpWithNewBufferizedOp(RewriterBase &rewriter, Operation *op,
   return newOp;
 }
 
-/// Return a MemRefType to which the `tensorType` can be bufferized in a
-/// composable fashion. The layout must be the most dynamic possible and
-/// canonicalize away once bufferization is finished.
+/// Return a MemRefType to which the `tensorType` can be bufferized.
+///
+/// If possible, op bufferization implementations should not use this function
+/// and instead infer precise memref types for tensor results by themselves.
+///
+/// Unless a layout map was specified, `options` flags determine what kind of
+/// layout map will be used. For best composability (without copies), the fully
+/// dynamic layout map is used by default.
+///
+/// Note: Canonicalization patterns could clean up layout maps and infer more
+/// precise layout maps after bufferization. However, many possible
+/// canonicalizations are currently not implemented.
 BaseMemRefType getMemRefType(TensorType tensorType,
                              const BufferizationOptions &options,
                              MemRefLayoutAttrInterface layout = {},
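To make "fully dynamic layout map" concrete, a sketch assuming an `MLIRContext *ctx`; the helpers used here all appear elsewhere in this diff:

    // Fully dynamic strided layout for a rank-2 memref: every stride and the
    // offset are symbols, so any strided memref of the same shape and element
    // type can be memref.cast to this type without a copy.
    SmallVector<int64_t> strides(/*rank=*/2, ShapedType::kDynamicStrideOrOffset);
    AffineMap fullyDynamic = makeStridedLinearLayoutMap(
        strides, /*offset=*/ShapedType::kDynamicStrideOrOffset, ctx);
    // Printed form (for tensor<?x?xf32>):
    //   memref<?x?xf32, affine_map<(d0, d1)[s0, s1, s2] -> (d0 * s1 + s0 + d1 * s2)>>
    auto fullyDynamicType = MemRefType::get(
        {ShapedType::kDynamicSize, ShapedType::kDynamicSize},
        FloatType::getF32(ctx), AffineMapAttr::get(fullyDynamic));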
@@ -82,17 +82,22 @@ struct IndexCastOpInterface
   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                           BufferizationState &state) const {
     auto castOp = cast<arith::IndexCastOp>(op);
+    auto resultTensorType = castOp.getType().cast<TensorType>();
 
     Value source = *state.getBuffer(rewriter, op->getOpOperand(0) /*in*/);
     auto sourceType = source.getType().cast<BaseMemRefType>();
 
     // Result type should have same layout and address space as the source type.
-    MemRefLayoutAttrInterface layout = {};
-    if (auto rankedMemRefType = sourceType.dyn_cast<MemRefType>())
-      layout = rankedMemRefType.getLayout();
-    Type resultType =
-        getMemRefType(castOp.getType().cast<TensorType>(), state.getOptions(),
-                      layout, sourceType.getMemorySpace());
+    BaseMemRefType resultType;
+    if (auto rankedMemRefType = sourceType.dyn_cast<MemRefType>()) {
+      resultType = MemRefType::get(
+          rankedMemRefType.getShape(), resultTensorType.getElementType(),
+          rankedMemRefType.getLayout(), rankedMemRefType.getMemorySpace());
+    } else {
+      auto unrankedMemrefType = sourceType.cast<UnrankedMemRefType>();
+      resultType = UnrankedMemRefType::get(resultTensorType.getElementType(),
+                                           unrankedMemrefType.getMemorySpace());
+    }
 
     replaceOpWithNewBufferizedOp<arith::IndexCastOp>(rewriter, op, resultType,
                                                      source);
@@ -146,15 +151,14 @@ struct SelectOpInterface
     // both of them to the most dynamic MemRef type.
     if (trueBuffer.getType() != falseBuffer.getType()) {
       auto trueType = trueBuffer.getType().cast<MemRefType>();
-      auto tensorType = selectOp.getTrueValue().getType().cast<TensorType>();
       int64_t dynamicOffset = ShapedType::kDynamicStrideOrOffset;
-      SmallVector<int64_t> dynamicStrides(tensorType.getRank(),
+      SmallVector<int64_t> dynamicStrides(trueType.getRank(),
                                           ShapedType::kDynamicStrideOrOffset);
       AffineMap stridedLayout = makeStridedLinearLayoutMap(
           dynamicStrides, dynamicOffset, op->getContext());
-      BaseMemRefType castedType = bufferization::getMemRefType(
-          tensorType, state.getOptions(), AffineMapAttr::get(stridedLayout),
-          trueType.getMemorySpace());
+      auto castedType =
+          MemRefType::get(trueType.getShape(), trueType.getElementType(),
+                          stridedLayout, trueType.getMemorySpaceAsInt());
       trueBuffer = rewriter.create<memref::CastOp>(loc, castedType, trueBuffer);
       falseBuffer =
           rewriter.create<memref::CastOp>(loc, castedType, falseBuffer);
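Why casting both arms to the most dynamic type composes without copies, sketched as hypothetical IR (not part of the commit):

    // Both arms share shape, element type, and memory space; they may differ
    // only in layout. A memref.cast to the fully dynamic strided layout is
    // valid from any strided source layout and only erases static layout
    // information, so no copy is needed:
    //
    //   %t = memref.cast %true  : memref<4xf32>
    //          to memref<4xf32, affine_map<(d0)[s0, s1] -> (d0 * s1 + s0)>>
    //   %f = memref.cast %false : memref<4xf32, #some_strided_layout>
    //          to memref<4xf32, affine_map<(d0)[s0, s1] -> (d0 * s1 + s0)>>
    //   %r = arith.select %cond, %t, %f
    //          : memref<4xf32, affine_map<(d0)[s0, s1] -> (d0 * s1 + s0)>>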
@@ -79,6 +79,7 @@ struct ExecuteRegionOpInterface
     SmallVector<Type> newResultTypes;
     for (Type type : executeRegionOp->getResultTypes()) {
       if (auto tensorType = type.dyn_cast<TensorType>()) {
+        // TODO: Infer the result type instead of computing it.
         newResultTypes.push_back(getMemRefType(tensorType, state.getOptions()));
       } else {
         newResultTypes.push_back(type);
@@ -188,6 +189,7 @@ struct IfOpInterface
     SmallVector<Type> newTypes;
     for (Type returnType : ifOp->getResultTypes()) {
       if (auto tensorType = returnType.dyn_cast<TensorType>()) {
+        // TODO: Infer the result type instead of computing it.
         newTypes.push_back(getMemRefType(tensorType, state.getOptions()));
       } else {
         newTypes.push_back(returnType);
@@ -66,6 +66,7 @@ struct AssumingOpInterface
     SmallVector<Type> newResultTypes;
     for (Type type : assumingOp->getResultTypes()) {
       if (auto tensorType = type.dyn_cast<TensorType>()) {
+        // TODO: Infer the result type instead of computing it.
         newResultTypes.push_back(getMemRefType(tensorType, state.getOptions()));
       } else {
         newResultTypes.push_back(type);