[mlir] Add a MemRefCastOp canonicalization pattern.
Summary: This revision adds a conservative canonicalization pattern for the MemRefCastOps that are typically inserted during ViewOp and SubViewOp canonicalization. Ideally such canonicalizations would propagate the more static type to consumers, but that is not a local rewrite. As a consequence, MemRefCastOps are introduced to keep types compatible and need to be cleaned up later whenever they introduce more dynamic behavior than necessary.

Differential Revision: https://reviews.llvm.org/D79438
This commit is contained in: parent 8650b36935, commit 94438c86ad
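A minimal sketch of the cleanup this enables, modeled on the new canonicalize.mlir test at the end of this diff (the function name and IR below are illustrative, not part of the patch):

```mlir
// Hypothetical input: the memref_cast only erases static shape information
// that its consumer could use.
func @fold_cast_into_subview(%arg0: memref<4x5xf32>, %i: index)
    -> memref<?x?xf32, offset: ?, strides: [?, ?]> {
  %0 = memref_cast %arg0 : memref<4x5xf32> to memref<?x?xf32>
  %1 = subview %0[][%i, %i][]
      : memref<?x?xf32> to memref<?x?xf32, offset: ?, strides: [?, ?]>
  return %1 : memref<?x?xf32, offset: ?, strides: [?, ?]>
}

// After canonicalization, the subview consumes the original static memref
// and the cast becomes dead:
//   %1 = subview %arg0[][%i, %i][]
//       : memref<4x5xf32> to memref<?x?xf32, offset: ?, strides: [?, ?]>
```

The same helper lets ops in other dialects (e.g. the Linalg folder touched below) fold such casts in their own canonicalizations instead of duplicating the logic.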
@@ -301,6 +301,44 @@ ParseResult parseDimAndSymbolList(OpAsmParser &parser,
 raw_ostream &operator<<(raw_ostream &os, SubViewOp::Range &range);
 
+/// Determines whether MemRefCastOp casts to a more dynamic version of the
+/// source memref. This is useful to fold a memref_cast into a consuming op
+/// and implement canonicalization patterns for ops in different dialects that
+/// may consume the results of memref_cast operations. Such foldable memref_cast
+/// operations are typically inserted as `view` and `subview` ops are
+/// canonicalized, to preserve the type compatibility of their uses.
+///
+/// Returns true when all conditions are met:
+/// 1. source and result are ranked memrefs with strided semantics and same
+///    element type and rank.
+/// 2. each of the source's size, offset or stride has more static information
+///    than the corresponding result's size, offset or stride.
+///
+/// Example 1:
+/// ```mlir
+///   %1 = memref_cast %0 : memref<8x16xf32> to memref<?x?xf32>
+///   %2 = consumer %1 ... : memref<?x?xf32> ...
+/// ```
+///
+/// may fold into:
+///
+/// ```mlir
+///   %2 = consumer %0 ... : memref<8x16xf32> ...
+/// ```
+///
+/// Example 2:
+/// ```
+///   %1 = memref_cast %0 : memref<?x16xf32, affine_map<(i, j)->(16 * i + j)>>
+///          to memref<?x?xf32>
+///   consumer %1 : memref<?x?xf32> ...
+/// ```
+///
+/// may fold into:
+///
+/// ```
+///   consumer %0 ... : memref<?x16xf32, affine_map<(i, j)->(16 * i + j)>>
+/// ```
+bool canFoldIntoConsumerOp(MemRefCastOp castOp);
+
 } // end namespace mlir
 
 #endif // MLIR_DIALECT_IR_STANDARDOPS_IR_OPS_H
@@ -2606,6 +2606,7 @@ def SubViewOp : Std_Op<"subview", [
   }];
 
   let hasCanonicalizer = 1;
+  let hasFolder = 1;
 }
 
 //===----------------------------------------------------------------------===//
@@ -44,82 +44,16 @@ static ParseResult parseNamedStructuredOp(OpAsmParser &parser,
 template <typename NamedStructuredOpType>
 static LogicalResult verifyNamedStructuredOp(NamedStructuredOpType op);
 
-/// Determines whether it is possible to fold it away in the parent Linalg op:
-///
-/// ```mlir
-///   %1 = memref_cast %0 : memref<8x16xf32> to memref<?x?xf32>
-///   %2 = linalg.slice %1 ... : memref<?x?xf32> ...
-///   // or
-///   %1 = memref_cast %0 : memref<8x16xf32, affine_map<(i, j)->(16 * i + j)>>
-///          to memref<?x?xf32>
-///   linalg.generic(%1 ...) : memref<?x?xf32> ...
-/// ```
-///
-/// into
-///
-/// ```mlir
-///   %2 = linalg.slice %0 ... : memref<8x16xf32> ...
-///   // or
-///   linalg.generic(%0 ... : memref<8x16xf32, affine_map<(i, j)->(16 * i + j)>>
-/// ```
-///
-static bool canFold(MemRefCastOp castOp) {
-  MemRefType sourceType = castOp.source().getType().dyn_cast<MemRefType>();
-  MemRefType resultType = castOp.getType().dyn_cast<MemRefType>();
-
-  // If we don't have MemRefType as source and destination, bail out.
-  if (!sourceType || !resultType)
-    return false;
-
-  // If resultType has a map, it needs to be the same as the source type to
-  // canonicalize.
-  if (!resultType.getAffineMaps().empty() &&
-      sourceType.getAffineMaps() != resultType.getAffineMaps())
-    return false;
-
-  // Ensure that:
-  // 1. source is static
-  // 2. source and target have the same rank (will be extended when needed)
-  // 3. if result is partially static, ensure sizes match.
-  if (!sourceType.hasStaticShape() ||
-      sourceType.getRank() != resultType.getRank())
-    return false;
-
-  for (auto it : llvm::zip(sourceType.getShape(), resultType.getShape())) {
-    auto sourceSize = std::get<0>(it);
-    auto resultSize = std::get<1>(it);
-    if (ShapedType::isDynamic(resultSize))
-      continue;
-    if (sourceSize != resultSize)
-      return false;
-  }
-
-  // If source has a map, it can only canonicalize if it is the canonical
-  // strided layout map.
-  if (sourceType.getAffineMaps().empty())
-    return true;
-
-  int64_t offset;
-  SmallVector<int64_t, 4> strides;
-  auto res = getStridesAndOffset(sourceType, strides, offset);
-  (void)res;
-  assert(succeeded(res));
-  auto stridedMap =
-      makeStridedLinearLayoutMap(strides, offset, castOp.getContext());
-  AffineMap sourceMap = sourceType.getAffineMaps().front();
-  return sourceMap == stridedMap;
-}
-
 /// This is a common class used for patterns of the form
 /// ```
 ///    someop(memrefcast) -> someop
 /// ```
-/// It folds the source of any memref_cast into the root operation directly.
+/// It folds the source of the memref_cast into the root operation directly.
 static LogicalResult foldMemRefCast(Operation *op) {
   bool folded = false;
   for (OpOperand &operand : op->getOpOperands()) {
     auto castOp = dyn_cast_or_null<MemRefCastOp>(operand.get().getDefiningOp());
-    if (castOp && canFold(castOp)) {
+    if (castOp && canFoldIntoConsumerOp(castOp)) {
       operand.set(castOp.getOperand());
       folded = true;
     }
@@ -2519,6 +2519,111 @@ public:
 
 } // end anonymous namespace
 
+/// Determines whether MemRefCastOp casts to a more dynamic version of the
+/// source memref. This is useful to fold a memref_cast into a consuming op
+/// and implement canonicalization patterns for ops in different dialects that
+/// may consume the results of memref_cast operations. Such foldable memref_cast
+/// operations are typically inserted as `view` and `subview` ops are
+/// canonicalized, to preserve the type compatibility of their uses.
+///
+/// Returns true when all conditions are met:
+/// 1. source and result are ranked memrefs with strided semantics and same
+///    element type and rank.
+/// 2. each of the source's size, offset or stride has more static information
+///    than the corresponding result's size, offset or stride.
+///
+/// Example 1:
+/// ```mlir
+///   %1 = memref_cast %0 : memref<8x16xf32> to memref<?x?xf32>
+///   %2 = consumer %1 ... : memref<?x?xf32> ...
+/// ```
+///
+/// may fold into:
+///
+/// ```mlir
+///   %2 = consumer %0 ... : memref<8x16xf32> ...
+/// ```
+///
+/// Example 2:
+/// ```
+///   %1 = memref_cast %0 : memref<?x16xf32, affine_map<(i, j)->(16 * i + j)>>
+///          to memref<?x?xf32>
+///   consumer %1 : memref<?x?xf32> ...
+/// ```
+///
+/// may fold into:
+///
+/// ```
+///   consumer %0 ... : memref<?x16xf32, affine_map<(i, j)->(16 * i + j)>>
+/// ```
+bool mlir::canFoldIntoConsumerOp(MemRefCastOp castOp) {
+  MemRefType sourceType = castOp.source().getType().dyn_cast<MemRefType>();
+  MemRefType resultType = castOp.getType().dyn_cast<MemRefType>();
+
+  // Requires ranked MemRefType.
+  if (!sourceType || !resultType)
+    return false;
+
+  // Requires same elemental type.
+  if (sourceType.getElementType() != resultType.getElementType())
+    return false;
+
+  // Requires same rank.
+  if (sourceType.getRank() != resultType.getRank())
+    return false;
+
+  // Only fold casts between strided memref forms.
+  int64_t sourceOffset, resultOffset;
+  SmallVector<int64_t, 4> sourceStrides, resultStrides;
+  if (failed(getStridesAndOffset(sourceType, sourceStrides, sourceOffset)) ||
+      failed(getStridesAndOffset(resultType, resultStrides, resultOffset)))
+    return false;
+
+  // If cast is towards more static sizes along any dimension, don't fold.
+  for (auto it : llvm::zip(sourceType.getShape(), resultType.getShape())) {
+    auto ss = std::get<0>(it), st = std::get<1>(it);
+    if (ss != st)
+      if (MemRefType::isDynamic(ss) && !MemRefType::isDynamic(st))
+        return false;
+  }
+
+  // If cast is towards more static offset along any dimension, don't fold.
+  if (sourceOffset != resultOffset)
+    if (MemRefType::isDynamicStrideOrOffset(sourceOffset) &&
+        !MemRefType::isDynamicStrideOrOffset(resultOffset))
+      return false;
+
+  // If cast is towards more static strides along any dimension, don't fold.
+  for (auto it : llvm::zip(sourceStrides, resultStrides)) {
+    auto ss = std::get<0>(it), st = std::get<1>(it);
+    if (ss != st)
+      if (MemRefType::isDynamicStrideOrOffset(ss) &&
+          !MemRefType::isDynamicStrideOrOffset(st))
+        return false;
+  }
+
+  return true;
+}
+
+OpFoldResult SubViewOp::fold(ArrayRef<Attribute>) {
+  auto folds = [](Operation *op) {
+    bool folded = false;
+    for (OpOperand &operand : op->getOpOperands()) {
+      auto castOp =
+          dyn_cast_or_null<MemRefCastOp>(operand.get().getDefiningOp());
+      if (castOp && canFoldIntoConsumerOp(castOp)) {
+        operand.set(castOp.getOperand());
+        folded = true;
+      }
+    }
+    return folded ? success() : failure();
+  };
+
+  if (succeeded(folds(*this)))
+    return getResult();
+  return {};
+}
+
 void SubViewOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
                                             MLIRContext *context) {
   results.insert<SubViewOpShapeFolder, SubViewOpStrideFolder,
@@ -919,3 +919,15 @@ func @tensor_divi_unsigned_by_one(%arg0: tensor<4x5xi32>) -> tensor<4x5xi32> {
   // CHECK: return %[[ARG]]
   return %res : tensor<4x5xi32>
 }
+
+// -----
+
+// CHECK-LABEL: func @memref_cast_folding_subview
+func @memref_cast_folding_subview(%arg0: memref<4x5xf32>, %i: index) -> (memref<?x?xf32, offset:? , strides: [?, ?]>) {
+  %0 = memref_cast %arg0 : memref<4x5xf32> to memref<?x?xf32>
+  // CHECK-NEXT: subview %{{.*}}: memref<4x5xf32>
+  %1 = subview %0[][%i,%i][]: memref<?x?xf32> to memref<?x?xf32, offset:? , strides: [?, ?]>
+  // CHECK-NEXT: return %{{.*}}
+  return %1: memref<?x?xf32, offset:? , strides: [?, ?]>
+}