[mlir] Add a MemRefCastOp canonicalization pattern.

Summary:
This revision adds a conservative canonicalization pattern for MemRefCastOp that are typically inserted during ViewOp and SubViewOp canonicalization.
Ideally such canonicalizations would propagate the type to consumers but this is not a local behavior. As a consequence MemRefCastOp are introduced to keep type compatibility but need to be cleaned up later, in the case where more dynamic behavior than necessary is introduced.

Differential Revision: https://reviews.llvm.org/D79438
This commit is contained in:
Nicolas Vasilache 2020-05-06 09:05:15 -04:00
parent 8650b36935
commit 94438c86ad
5 changed files with 158 additions and 68 deletions

View File

@ -301,6 +301,44 @@ ParseResult parseDimAndSymbolList(OpAsmParser &parser,
raw_ostream &operator<<(raw_ostream &os, SubViewOp::Range &range); raw_ostream &operator<<(raw_ostream &os, SubViewOp::Range &range);
/// Determines whether MemRefCastOp casts to a more dynamic version of the
/// source memref. This is useful to to fold a memref_cast into a consuming op
/// and implement canonicalization patterns for ops in different dialects that
/// may consume the results of memref_cast operations. Such foldable memref_cast
/// operations are typically inserted as `view` and `subview` ops are
/// canonicalized, to preserve the type compatibility of their uses.
///
/// Returns true when all conditions are met:
/// 1. source and result are ranked memrefs with strided semantics and same
/// element type and rank.
/// 2. each of the source's size, offset or stride has more static information
/// than the corresponding result's size, offset or stride.
///
/// Example 1:
/// ```mlir
/// %1 = memref_cast %0 : memref<8x16xf32> to memref<?x?xf32>
/// %2 = consumer %1 ... : memref<?x?xf32> ...
/// ```
///
/// may fold into:
///
/// ```mlir
/// %2 = consumer %0 ... : memref<8x16xf32> ...
/// ```
///
/// Example 2:
/// ```
/// %1 = memref_cast %0 : memref<?x16xf32, affine_map<(i, j)->(16 * i + j)>>
/// to memref<?x?xf32>
/// consumer %1 : memref<?x?xf32> ...
/// ```
///
/// may fold into:
///
/// ```
/// consumer %0 ... : memref<?x16xf32, affine_map<(i, j)->(16 * i + j)>>
/// ```
bool canFoldIntoConsumerOp(MemRefCastOp castOp);
} // end namespace mlir } // end namespace mlir
#endif // MLIR_DIALECT_IR_STANDARDOPS_IR_OPS_H #endif // MLIR_DIALECT_IR_STANDARDOPS_IR_OPS_H

View File

@ -2606,6 +2606,7 @@ def SubViewOp : Std_Op<"subview", [
}]; }];
let hasCanonicalizer = 1; let hasCanonicalizer = 1;
let hasFolder = 1;
} }
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//

View File

@ -44,82 +44,16 @@ static ParseResult parseNamedStructuredOp(OpAsmParser &parser,
template <typename NamedStructuredOpType> template <typename NamedStructuredOpType>
static LogicalResult verifyNamedStructuredOp(NamedStructuredOpType op); static LogicalResult verifyNamedStructuredOp(NamedStructuredOpType op);
/// Determines whether it is possible to fold it away in the parent Linalg op:
///
/// ```mlir
/// %1 = memref_cast %0 : memref<8x16xf32> to memref<?x?xf32>
/// %2 = linalg.slice %1 ... : memref<?x?xf32> ...
/// // or
/// %1 = memref_cast %0 : memref<8x16xf32, affine_map<(i, j)->(16 * i + j)>>
/// to memref<?x?xf32>
/// linalg.generic(%1 ...) : memref<?x?xf32> ...
/// ```
///
/// into
///
/// ```mlir
/// %2 = linalg.slice %0 ... : memref<8x16xf32> ...
/// // or
/// linalg.generic(%0 ... : memref<8x16xf32, affine_map<(i, j)->(16 * i + j)>>
/// ```
///
static bool canFold(MemRefCastOp castOp) {
MemRefType sourceType = castOp.source().getType().dyn_cast<MemRefType>();
MemRefType resultType = castOp.getType().dyn_cast<MemRefType>();
// If we don't have MemRefType as source and destination, bail out.
if (!sourceType || !resultType)
return false;
// If resultType has a map, it needs to be the same as the source type to
// canonicalize.
if (!resultType.getAffineMaps().empty() &&
sourceType.getAffineMaps() != resultType.getAffineMaps())
return false;
// Ensure that:
// 1. source is static
// 2. source and target have the same rank (will be extended when needed)
// 3. if result is partially static, ensure sizes match.
if (!sourceType.hasStaticShape() ||
sourceType.getRank() != resultType.getRank())
return false;
for (auto it : llvm::zip(sourceType.getShape(), resultType.getShape())) {
auto sourceSize = std::get<0>(it);
auto resultSize = std::get<1>(it);
if (ShapedType::isDynamic(resultSize))
continue;
if (sourceSize != resultSize)
return false;
}
// If source has a map, it can only canonicalize if it is the canonical
// strided layout map.
if (sourceType.getAffineMaps().empty())
return true;
int64_t offset;
SmallVector<int64_t, 4> strides;
auto res = getStridesAndOffset(sourceType, strides, offset);
(void)res;
assert(succeeded(res));
auto stridedMap =
makeStridedLinearLayoutMap(strides, offset, castOp.getContext());
AffineMap sourceMap = sourceType.getAffineMaps().front();
return sourceMap == stridedMap;
}
/// This is a common class used for patterns of the form /// This is a common class used for patterns of the form
/// ``` /// ```
/// someop(memrefcast) -> someop /// someop(memrefcast) -> someop
/// ``` /// ```
/// It folds the source of any memref_cast into the root operation directly. /// It folds the source of the memref_cast into the root operation directly.
static LogicalResult foldMemRefCast(Operation *op) { static LogicalResult foldMemRefCast(Operation *op) {
bool folded = false; bool folded = false;
for (OpOperand &operand : op->getOpOperands()) { for (OpOperand &operand : op->getOpOperands()) {
auto castOp = dyn_cast_or_null<MemRefCastOp>(operand.get().getDefiningOp()); auto castOp = dyn_cast_or_null<MemRefCastOp>(operand.get().getDefiningOp());
if (castOp && canFold(castOp)) { if (castOp && canFoldIntoConsumerOp(castOp)) {
operand.set(castOp.getOperand()); operand.set(castOp.getOperand());
folded = true; folded = true;
} }

View File

@ -2519,6 +2519,111 @@ public:
} // end anonymous namespace } // end anonymous namespace
/// Determines whether MemRefCastOp casts to a more dynamic version of the
/// source memref. This is useful to to fold a memref_cast into a consuming op
/// and implement canonicalization patterns for ops in different dialects that
/// may consume the results of memref_cast operations. Such foldable memref_cast
/// operations are typically inserted as `view` and `subview` ops are
/// canonicalized, to preserve the type compatibility of their uses.
///
/// Returns true when all conditions are met:
/// 1. source and result are ranked memrefs with strided semantics and same
/// element type and rank.
/// 2. each of the source's size, offset or stride has more static information
/// than the corresponding result's size, offset or stride.
///
/// Example 1:
/// ```mlir
/// %1 = memref_cast %0 : memref<8x16xf32> to memref<?x?xf32>
/// %2 = consumer %1 ... : memref<?x?xf32> ...
/// ```
///
/// may fold into:
///
/// ```mlir
/// %2 = consumer %0 ... : memref<8x16xf32> ...
/// ```
///
/// Example 2:
/// ```
/// %1 = memref_cast %0 : memref<?x16xf32, affine_map<(i, j)->(16 * i + j)>>
/// to memref<?x?xf32>
/// consumer %1 : memref<?x?xf32> ...
/// ```
///
/// may fold into:
///
/// ```
/// consumer %0 ... : memref<?x16xf32, affine_map<(i, j)->(16 * i + j)>>
/// ```
bool mlir::canFoldIntoConsumerOp(MemRefCastOp castOp) {
MemRefType sourceType = castOp.source().getType().dyn_cast<MemRefType>();
MemRefType resultType = castOp.getType().dyn_cast<MemRefType>();
// Requires ranked MemRefType.
if (!sourceType || !resultType)
return false;
// Requires same elemental type.
if (sourceType.getElementType() != resultType.getElementType())
return false;
// Requires same rank.
if (sourceType.getRank() != resultType.getRank())
return false;
// Only fold casts between strided memref forms.
int64_t sourceOffset, resultOffset;
SmallVector<int64_t, 4> sourceStrides, resultStrides;
if (failed(getStridesAndOffset(sourceType, sourceStrides, sourceOffset)) ||
failed(getStridesAndOffset(resultType, resultStrides, resultOffset)))
return false;
// If cast is towards more static sizes along any dimension, don't fold.
for (auto it : llvm::zip(sourceType.getShape(), resultType.getShape())) {
auto ss = std::get<0>(it), st = std::get<1>(it);
if (ss != st)
if (MemRefType::isDynamic(ss) && !MemRefType::isDynamic(st))
return false;
}
// If cast is towards more static offset along any dimension, don't fold.
if (sourceOffset != resultOffset)
if (MemRefType::isDynamicStrideOrOffset(sourceOffset) &&
!MemRefType::isDynamicStrideOrOffset(resultOffset))
return false;
// If cast is towards more static strides along any dimension, don't fold.
for (auto it : llvm::zip(sourceStrides, resultStrides)) {
auto ss = std::get<0>(it), st = std::get<1>(it);
if (ss != st)
if (MemRefType::isDynamicStrideOrOffset(ss) &&
!MemRefType::isDynamicStrideOrOffset(st))
return false;
}
return true;
}
OpFoldResult SubViewOp::fold(ArrayRef<Attribute>) {
auto folds = [](Operation *op) {
bool folded = false;
for (OpOperand &operand : op->getOpOperands()) {
auto castOp =
dyn_cast_or_null<MemRefCastOp>(operand.get().getDefiningOp());
if (castOp && canFoldIntoConsumerOp(castOp)) {
operand.set(castOp.getOperand());
folded = true;
}
}
return folded ? success() : failure();
};
if (succeeded(folds(*this)))
return getResult();
return {};
}
void SubViewOp::getCanonicalizationPatterns(OwningRewritePatternList &results, void SubViewOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
MLIRContext *context) { MLIRContext *context) {
results.insert<SubViewOpShapeFolder, SubViewOpStrideFolder, results.insert<SubViewOpShapeFolder, SubViewOpStrideFolder,

View File

@ -919,3 +919,15 @@ func @tensor_divi_unsigned_by_one(%arg0: tensor<4x5xi32>) -> tensor<4x5xi32> {
// CHECK: return %[[ARG]] // CHECK: return %[[ARG]]
return %res : tensor<4x5xi32> return %res : tensor<4x5xi32>
} }
// -----
// CHECK-LABEL: func @memref_cast_folding_subview
func @memref_cast_folding_subview(%arg0: memref<4x5xf32>, %i: index) -> (memref<?x?xf32, offset:? , strides: [?, ?]>) {
%0 = memref_cast %arg0 : memref<4x5xf32> to memref<?x?xf32>
// CHECK-NEXT: subview %{{.*}}: memref<4x5xf32>
%1 = subview %0[][%i,%i][]: memref<?x?xf32> to memref<?x?xf32, offset:? , strides: [?, ?]>
// CHECK-NEXT: return %{{.*}}
return %1: memref<?x?xf32, offset:? , strides: [?, ?]>
}