forked from OSchip/llvm-project
[mlir] Add a MemRefCastOp canonicalization pattern.
Summary: This revision adds a conservative canonicalization pattern for MemRefCastOp that are typically inserted during ViewOp and SubViewOp canonicalization. Ideally such canonicalizations would propagate the type to consumers but this is not a local behavior. As a consequence MemRefCastOp are introduced to keep type compatibility but need to be cleaned up later, in the case where more dynamic behavior than necessary is introduced. Differential Revision: https://reviews.llvm.org/D79438
This commit is contained in:
parent
8650b36935
commit
94438c86ad
|
@ -301,6 +301,44 @@ ParseResult parseDimAndSymbolList(OpAsmParser &parser,
|
|||
|
||||
raw_ostream &operator<<(raw_ostream &os, SubViewOp::Range &range);
|
||||
|
||||
/// Determines whether MemRefCastOp casts to a more dynamic version of the
|
||||
/// source memref. This is useful to to fold a memref_cast into a consuming op
|
||||
/// and implement canonicalization patterns for ops in different dialects that
|
||||
/// may consume the results of memref_cast operations. Such foldable memref_cast
|
||||
/// operations are typically inserted as `view` and `subview` ops are
|
||||
/// canonicalized, to preserve the type compatibility of their uses.
|
||||
///
|
||||
/// Returns true when all conditions are met:
|
||||
/// 1. source and result are ranked memrefs with strided semantics and same
|
||||
/// element type and rank.
|
||||
/// 2. each of the source's size, offset or stride has more static information
|
||||
/// than the corresponding result's size, offset or stride.
|
||||
///
|
||||
/// Example 1:
|
||||
/// ```mlir
|
||||
/// %1 = memref_cast %0 : memref<8x16xf32> to memref<?x?xf32>
|
||||
/// %2 = consumer %1 ... : memref<?x?xf32> ...
|
||||
/// ```
|
||||
///
|
||||
/// may fold into:
|
||||
///
|
||||
/// ```mlir
|
||||
/// %2 = consumer %0 ... : memref<8x16xf32> ...
|
||||
/// ```
|
||||
///
|
||||
/// Example 2:
|
||||
/// ```
|
||||
/// %1 = memref_cast %0 : memref<?x16xf32, affine_map<(i, j)->(16 * i + j)>>
|
||||
/// to memref<?x?xf32>
|
||||
/// consumer %1 : memref<?x?xf32> ...
|
||||
/// ```
|
||||
///
|
||||
/// may fold into:
|
||||
///
|
||||
/// ```
|
||||
/// consumer %0 ... : memref<?x16xf32, affine_map<(i, j)->(16 * i + j)>>
|
||||
/// ```
|
||||
bool canFoldIntoConsumerOp(MemRefCastOp castOp);
|
||||
} // end namespace mlir
|
||||
|
||||
#endif // MLIR_DIALECT_IR_STANDARDOPS_IR_OPS_H
|
||||
|
|
|
@ -2606,6 +2606,7 @@ def SubViewOp : Std_Op<"subview", [
|
|||
}];
|
||||
|
||||
let hasCanonicalizer = 1;
|
||||
let hasFolder = 1;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -44,82 +44,16 @@ static ParseResult parseNamedStructuredOp(OpAsmParser &parser,
|
|||
template <typename NamedStructuredOpType>
|
||||
static LogicalResult verifyNamedStructuredOp(NamedStructuredOpType op);
|
||||
|
||||
/// Determines whether it is possible to fold it away in the parent Linalg op:
|
||||
///
|
||||
/// ```mlir
|
||||
/// %1 = memref_cast %0 : memref<8x16xf32> to memref<?x?xf32>
|
||||
/// %2 = linalg.slice %1 ... : memref<?x?xf32> ...
|
||||
/// // or
|
||||
/// %1 = memref_cast %0 : memref<8x16xf32, affine_map<(i, j)->(16 * i + j)>>
|
||||
/// to memref<?x?xf32>
|
||||
/// linalg.generic(%1 ...) : memref<?x?xf32> ...
|
||||
/// ```
|
||||
///
|
||||
/// into
|
||||
///
|
||||
/// ```mlir
|
||||
/// %2 = linalg.slice %0 ... : memref<8x16xf32> ...
|
||||
/// // or
|
||||
/// linalg.generic(%0 ... : memref<8x16xf32, affine_map<(i, j)->(16 * i + j)>>
|
||||
/// ```
|
||||
///
|
||||
static bool canFold(MemRefCastOp castOp) {
|
||||
MemRefType sourceType = castOp.source().getType().dyn_cast<MemRefType>();
|
||||
MemRefType resultType = castOp.getType().dyn_cast<MemRefType>();
|
||||
|
||||
// If we don't have MemRefType as source and destination, bail out.
|
||||
if (!sourceType || !resultType)
|
||||
return false;
|
||||
|
||||
// If resultType has a map, it needs to be the same as the source type to
|
||||
// canonicalize.
|
||||
if (!resultType.getAffineMaps().empty() &&
|
||||
sourceType.getAffineMaps() != resultType.getAffineMaps())
|
||||
return false;
|
||||
|
||||
// Ensure that:
|
||||
// 1. source is static
|
||||
// 2. source and target have the same rank (will be extended when needed)
|
||||
// 3. if result is partially static, ensure sizes match.
|
||||
if (!sourceType.hasStaticShape() ||
|
||||
sourceType.getRank() != resultType.getRank())
|
||||
return false;
|
||||
|
||||
for (auto it : llvm::zip(sourceType.getShape(), resultType.getShape())) {
|
||||
auto sourceSize = std::get<0>(it);
|
||||
auto resultSize = std::get<1>(it);
|
||||
if (ShapedType::isDynamic(resultSize))
|
||||
continue;
|
||||
if (sourceSize != resultSize)
|
||||
return false;
|
||||
}
|
||||
|
||||
// If source has a map, it can only canonicalize if it is the canonical
|
||||
// strided layout map.
|
||||
if (sourceType.getAffineMaps().empty())
|
||||
return true;
|
||||
|
||||
int64_t offset;
|
||||
SmallVector<int64_t, 4> strides;
|
||||
auto res = getStridesAndOffset(sourceType, strides, offset);
|
||||
(void)res;
|
||||
assert(succeeded(res));
|
||||
auto stridedMap =
|
||||
makeStridedLinearLayoutMap(strides, offset, castOp.getContext());
|
||||
AffineMap sourceMap = sourceType.getAffineMaps().front();
|
||||
return sourceMap == stridedMap;
|
||||
}
|
||||
|
||||
/// This is a common class used for patterns of the form
|
||||
/// ```
|
||||
/// someop(memrefcast) -> someop
|
||||
/// ```
|
||||
/// It folds the source of any memref_cast into the root operation directly.
|
||||
/// It folds the source of the memref_cast into the root operation directly.
|
||||
static LogicalResult foldMemRefCast(Operation *op) {
|
||||
bool folded = false;
|
||||
for (OpOperand &operand : op->getOpOperands()) {
|
||||
auto castOp = dyn_cast_or_null<MemRefCastOp>(operand.get().getDefiningOp());
|
||||
if (castOp && canFold(castOp)) {
|
||||
if (castOp && canFoldIntoConsumerOp(castOp)) {
|
||||
operand.set(castOp.getOperand());
|
||||
folded = true;
|
||||
}
|
||||
|
|
|
@ -2519,6 +2519,111 @@ public:
|
|||
|
||||
} // end anonymous namespace
|
||||
|
||||
/// Determines whether MemRefCastOp casts to a more dynamic version of the
|
||||
/// source memref. This is useful to to fold a memref_cast into a consuming op
|
||||
/// and implement canonicalization patterns for ops in different dialects that
|
||||
/// may consume the results of memref_cast operations. Such foldable memref_cast
|
||||
/// operations are typically inserted as `view` and `subview` ops are
|
||||
/// canonicalized, to preserve the type compatibility of their uses.
|
||||
///
|
||||
/// Returns true when all conditions are met:
|
||||
/// 1. source and result are ranked memrefs with strided semantics and same
|
||||
/// element type and rank.
|
||||
/// 2. each of the source's size, offset or stride has more static information
|
||||
/// than the corresponding result's size, offset or stride.
|
||||
///
|
||||
/// Example 1:
|
||||
/// ```mlir
|
||||
/// %1 = memref_cast %0 : memref<8x16xf32> to memref<?x?xf32>
|
||||
/// %2 = consumer %1 ... : memref<?x?xf32> ...
|
||||
/// ```
|
||||
///
|
||||
/// may fold into:
|
||||
///
|
||||
/// ```mlir
|
||||
/// %2 = consumer %0 ... : memref<8x16xf32> ...
|
||||
/// ```
|
||||
///
|
||||
/// Example 2:
|
||||
/// ```
|
||||
/// %1 = memref_cast %0 : memref<?x16xf32, affine_map<(i, j)->(16 * i + j)>>
|
||||
/// to memref<?x?xf32>
|
||||
/// consumer %1 : memref<?x?xf32> ...
|
||||
/// ```
|
||||
///
|
||||
/// may fold into:
|
||||
///
|
||||
/// ```
|
||||
/// consumer %0 ... : memref<?x16xf32, affine_map<(i, j)->(16 * i + j)>>
|
||||
/// ```
|
||||
bool mlir::canFoldIntoConsumerOp(MemRefCastOp castOp) {
|
||||
MemRefType sourceType = castOp.source().getType().dyn_cast<MemRefType>();
|
||||
MemRefType resultType = castOp.getType().dyn_cast<MemRefType>();
|
||||
|
||||
// Requires ranked MemRefType.
|
||||
if (!sourceType || !resultType)
|
||||
return false;
|
||||
|
||||
// Requires same elemental type.
|
||||
if (sourceType.getElementType() != resultType.getElementType())
|
||||
return false;
|
||||
|
||||
// Requires same rank.
|
||||
if (sourceType.getRank() != resultType.getRank())
|
||||
return false;
|
||||
|
||||
// Only fold casts between strided memref forms.
|
||||
int64_t sourceOffset, resultOffset;
|
||||
SmallVector<int64_t, 4> sourceStrides, resultStrides;
|
||||
if (failed(getStridesAndOffset(sourceType, sourceStrides, sourceOffset)) ||
|
||||
failed(getStridesAndOffset(resultType, resultStrides, resultOffset)))
|
||||
return false;
|
||||
|
||||
// If cast is towards more static sizes along any dimension, don't fold.
|
||||
for (auto it : llvm::zip(sourceType.getShape(), resultType.getShape())) {
|
||||
auto ss = std::get<0>(it), st = std::get<1>(it);
|
||||
if (ss != st)
|
||||
if (MemRefType::isDynamic(ss) && !MemRefType::isDynamic(st))
|
||||
return false;
|
||||
}
|
||||
|
||||
// If cast is towards more static offset along any dimension, don't fold.
|
||||
if (sourceOffset != resultOffset)
|
||||
if (MemRefType::isDynamicStrideOrOffset(sourceOffset) &&
|
||||
!MemRefType::isDynamicStrideOrOffset(resultOffset))
|
||||
return false;
|
||||
|
||||
// If cast is towards more static strides along any dimension, don't fold.
|
||||
for (auto it : llvm::zip(sourceStrides, resultStrides)) {
|
||||
auto ss = std::get<0>(it), st = std::get<1>(it);
|
||||
if (ss != st)
|
||||
if (MemRefType::isDynamicStrideOrOffset(ss) &&
|
||||
!MemRefType::isDynamicStrideOrOffset(st))
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
OpFoldResult SubViewOp::fold(ArrayRef<Attribute>) {
|
||||
auto folds = [](Operation *op) {
|
||||
bool folded = false;
|
||||
for (OpOperand &operand : op->getOpOperands()) {
|
||||
auto castOp =
|
||||
dyn_cast_or_null<MemRefCastOp>(operand.get().getDefiningOp());
|
||||
if (castOp && canFoldIntoConsumerOp(castOp)) {
|
||||
operand.set(castOp.getOperand());
|
||||
folded = true;
|
||||
}
|
||||
}
|
||||
return folded ? success() : failure();
|
||||
};
|
||||
|
||||
if (succeeded(folds(*this)))
|
||||
return getResult();
|
||||
return {};
|
||||
}
|
||||
|
||||
void SubViewOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
|
||||
MLIRContext *context) {
|
||||
results.insert<SubViewOpShapeFolder, SubViewOpStrideFolder,
|
||||
|
|
|
@ -919,3 +919,15 @@ func @tensor_divi_unsigned_by_one(%arg0: tensor<4x5xi32>) -> tensor<4x5xi32> {
|
|||
// CHECK: return %[[ARG]]
|
||||
return %res : tensor<4x5xi32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @memref_cast_folding_subview
|
||||
func @memref_cast_folding_subview(%arg0: memref<4x5xf32>, %i: index) -> (memref<?x?xf32, offset:? , strides: [?, ?]>) {
|
||||
%0 = memref_cast %arg0 : memref<4x5xf32> to memref<?x?xf32>
|
||||
// CHECK-NEXT: subview %{{.*}}: memref<4x5xf32>
|
||||
%1 = subview %0[][%i,%i][]: memref<?x?xf32> to memref<?x?xf32, offset:? , strides: [?, ?]>
|
||||
// CHECK-NEXT: return %{{.*}}
|
||||
return %1: memref<?x?xf32, offset:? , strides: [?, ?]>
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue