[MLIR][Shape] Lower `shape.any`

Lower `shape.any` to its first operand.

Differential Revision: https://reviews.llvm.org/D83123
This commit is contained in:
Frederik Gossen 2020-07-13 08:28:13 +00:00
parent 07c4c7e795
commit 9df6afbb5c
3 changed files with 40 additions and 1 deletions

View File

@ -23,6 +23,22 @@ namespace {
#include "ShapeToStandardPatterns.inc" #include "ShapeToStandardPatterns.inc"
/// Conversion patterns. /// Conversion patterns.
class AnyOpConversion : public OpConversionPattern<AnyOp> {
public:
using OpConversionPattern<AnyOp>::OpConversionPattern;
LogicalResult
matchAndRewrite(AnyOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
AnyOp::Adaptor transformed(operands);
// Replace `any` with its first operand.
// Any operand would be a valid substitution.
rewriter.replaceOp(op, {transformed.inputs().front()});
return success();
}
};
template <typename SrcOpTy, typename DstOpTy> template <typename SrcOpTy, typename DstOpTy>
class BinaryOpConversion : public OpConversionPattern<SrcOpTy> { class BinaryOpConversion : public OpConversionPattern<SrcOpTy> {
public: public:
@ -181,6 +197,7 @@ void mlir::populateShapeToStandardConversionPatterns(
populateWithGenerated(ctx, &patterns); populateWithGenerated(ctx, &patterns);
// clang-format off // clang-format off
patterns.insert< patterns.insert<
AnyOpConversion,
BinaryOpConversion<AddOp, AddIOp>, BinaryOpConversion<AddOp, AddIOp>,
BinaryOpConversion<MulOp, MulIOp>, BinaryOpConversion<MulOp, MulIOp>,
ConstSizeOpConverter, ConstSizeOpConverter,

View File

@ -18,4 +18,3 @@ def IndexToSizeOpConversion : Pat<
def SizeToIndexOpConversion : Pat< def SizeToIndexOpConversion : Pat<
(Shape_SizeToIndexOp $arg), (Shape_SizeToIndexOp $arg),
(replaceWithValue $arg)>; (replaceWithValue $arg)>;

View File

@ -158,3 +158,26 @@ func @get_extent_from_extent_tensor(%extents : tensor<?xindex>,
return %result : !shape.size return %result : !shape.size
} }
// -----
// Lower `any` to its first operand.
// CHECK-LABEL: @any_of_three
// CHECK-SAME: (%[[A:.*]]: tensor<?xindex>, %[[B:.*]]: tensor<?xindex>, %[[C:.*]]: tensor<?xindex>) -> tensor<?xindex>
func @any_of_three(%a : !shape.shape, %b : !shape.shape, %c : !shape.shape)
-> !shape.shape {
// CHECK: return %[[A]] : tensor<?xindex>
%result = shape.any %a, %b, %c
return %result : !shape.shape
}
// -----
// Lower `any` to its first operand.
// CHECK-LABEL: @any_of_one
// CHECK-SAME: (%[[A:.*]]: tensor<?xindex>) -> tensor<?xindex>
func @any_of_one(%a : !shape.shape) -> !shape.shape {
// CHECK: return %[[A]] : tensor<?xindex>
%result = shape.any %a
return %result : !shape.shape
}