diff --git a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td index 0b8c26dc9156..41e6f8a2a562 100644 --- a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td +++ b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td @@ -416,6 +416,8 @@ def Shape_MaxOp : Shape_Op<"max", [Commutative, NoSideEffect]> { let assemblyFormat = [{ $lhs `,` $rhs attr-dict `:` type($lhs) `,` type($rhs) `->` type($result) }]; + + let hasFolder = 1; } def Shape_MinOp : Shape_Op<"min", [Commutative, NoSideEffect]> { @@ -433,6 +435,8 @@ def Shape_MinOp : Shape_Op<"min", [Commutative, NoSideEffect]> { let assemblyFormat = [{ $lhs `,` $rhs attr-dict `:` type($lhs) `,` type($rhs) `->` type($result) }]; + + let hasFolder = 1; } def Shape_MulOp : Shape_Op<"mul", [Commutative, NoSideEffect]> { diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp index bb7ed5cf05ce..388a3a5763b1 100644 --- a/mlir/lib/Dialect/Shape/IR/Shape.cpp +++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp @@ -937,6 +937,28 @@ void NumElementsOp::build(OpBuilder &builder, OperationState &result, return build(builder, result, type, shape); } +//===----------------------------------------------------------------------===// +// MaxOp +//===----------------------------------------------------------------------===// + +OpFoldResult MaxOp::fold(llvm::ArrayRef operands) { + // If operands are equal, just propagate one. + if (lhs() == rhs()) + return lhs(); + return nullptr; +} + +//===----------------------------------------------------------------------===// +// MinOp +//===----------------------------------------------------------------------===// + +OpFoldResult MinOp::fold(llvm::ArrayRef operands) { + // If operands are equal, just propagate one. + if (lhs() == rhs()) + return lhs(); + return nullptr; +} + //===----------------------------------------------------------------------===// // MulOp //===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/Shape/canonicalize.mlir b/mlir/test/Dialect/Shape/canonicalize.mlir index b0c12ea0b149..86ac4c9af963 100644 --- a/mlir/test/Dialect/Shape/canonicalize.mlir +++ b/mlir/test/Dialect/Shape/canonicalize.mlir @@ -1188,3 +1188,23 @@ func @casted_extent_tensor(%arg : tensor<*xf32>) -> tensor<3xindex> { %1 = tensor.cast %0 : tensor to tensor<3xindex> return %1 : tensor<3xindex> } + +// ---- + +// CHECK-LABEL: max_same_arg +// CHECK-SAME: (%[[SHAPE:.*]]: !shape.shape) +func @max_same_arg(%a: !shape.shape) -> !shape.shape { + %1 = shape.max %a, %a : !shape.shape, !shape.shape -> !shape.shape + // CHECK: return %[[SHAPE]] + return %1 : !shape.shape +} + +// ---- + +// CHECK-LABEL: min_same_arg +// CHECK-SAME: (%[[SHAPE:.*]]: !shape.shape) +func @min_same_arg(%a: !shape.shape) -> !shape.shape { + %1 = shape.min %a, %a : !shape.shape, !shape.shape -> !shape.shape + // CHECK: return %[[SHAPE]] + return %1 : !shape.shape +}