forked from OSchip/llvm-project
[MLIR][Shape] Lower `shape.any`
Lower `shape.any` to its first operand. Differential Revision: https://reviews.llvm.org/D83123
This commit is contained in:
parent
07c4c7e795
commit
9df6afbb5c
|
@ -23,6 +23,22 @@ namespace {
|
||||||
#include "ShapeToStandardPatterns.inc"
|
#include "ShapeToStandardPatterns.inc"
|
||||||
|
|
||||||
/// Conversion patterns.
|
/// Conversion patterns.
|
||||||
|
class AnyOpConversion : public OpConversionPattern<AnyOp> {
|
||||||
|
public:
|
||||||
|
using OpConversionPattern<AnyOp>::OpConversionPattern;
|
||||||
|
|
||||||
|
LogicalResult
|
||||||
|
matchAndRewrite(AnyOp op, ArrayRef<Value> operands,
|
||||||
|
ConversionPatternRewriter &rewriter) const override {
|
||||||
|
AnyOp::Adaptor transformed(operands);
|
||||||
|
|
||||||
|
// Replace `any` with its first operand.
|
||||||
|
// Any operand would be a valid substitution.
|
||||||
|
rewriter.replaceOp(op, {transformed.inputs().front()});
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
template <typename SrcOpTy, typename DstOpTy>
|
template <typename SrcOpTy, typename DstOpTy>
|
||||||
class BinaryOpConversion : public OpConversionPattern<SrcOpTy> {
|
class BinaryOpConversion : public OpConversionPattern<SrcOpTy> {
|
||||||
public:
|
public:
|
||||||
|
@ -181,6 +197,7 @@ void mlir::populateShapeToStandardConversionPatterns(
|
||||||
populateWithGenerated(ctx, &patterns);
|
populateWithGenerated(ctx, &patterns);
|
||||||
// clang-format off
|
// clang-format off
|
||||||
patterns.insert<
|
patterns.insert<
|
||||||
|
AnyOpConversion,
|
||||||
BinaryOpConversion<AddOp, AddIOp>,
|
BinaryOpConversion<AddOp, AddIOp>,
|
||||||
BinaryOpConversion<MulOp, MulIOp>,
|
BinaryOpConversion<MulOp, MulIOp>,
|
||||||
ConstSizeOpConverter,
|
ConstSizeOpConverter,
|
||||||
|
|
|
@ -18,4 +18,3 @@ def IndexToSizeOpConversion : Pat<
|
||||||
def SizeToIndexOpConversion : Pat<
|
def SizeToIndexOpConversion : Pat<
|
||||||
(Shape_SizeToIndexOp $arg),
|
(Shape_SizeToIndexOp $arg),
|
||||||
(replaceWithValue $arg)>;
|
(replaceWithValue $arg)>;
|
||||||
|
|
||||||
|
|
|
@ -158,3 +158,26 @@ func @get_extent_from_extent_tensor(%extents : tensor<?xindex>,
|
||||||
return %result : !shape.size
|
return %result : !shape.size
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// Lower `any` to its first operand.
|
||||||
|
// CHECK-LABEL: @any_of_three
|
||||||
|
// CHECK-SAME: (%[[A:.*]]: tensor<?xindex>, %[[B:.*]]: tensor<?xindex>, %[[C:.*]]: tensor<?xindex>) -> tensor<?xindex>
|
||||||
|
func @any_of_three(%a : !shape.shape, %b : !shape.shape, %c : !shape.shape)
|
||||||
|
-> !shape.shape {
|
||||||
|
// CHECK: return %[[A]] : tensor<?xindex>
|
||||||
|
%result = shape.any %a, %b, %c
|
||||||
|
return %result : !shape.shape
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// Lower `any` to its first operand.
|
||||||
|
// CHECK-LABEL: @any_of_one
|
||||||
|
// CHECK-SAME: (%[[A:.*]]: tensor<?xindex>) -> tensor<?xindex>
|
||||||
|
func @any_of_one(%a : !shape.shape) -> !shape.shape {
|
||||||
|
// CHECK: return %[[A]] : tensor<?xindex>
|
||||||
|
%result = shape.any %a
|
||||||
|
return %result : !shape.shape
|
||||||
|
}
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue