diff --git a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td index 66ec6e7ac551..b4b287e55074 100644 --- a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td +++ b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td @@ -565,7 +565,7 @@ def Shape_CstrBroadcastableOp : Shape_Op<"cstr_broadcastable", [Commutative]> { let hasFolder = 1; } -def Shape_CstrEqOp : Shape_Op<"cstr_eq", [Commutative]> { +def Shape_CstrEqOp : Shape_Op<"cstr_eq", []> { let summary = "Determines if all input shapes are equal"; let description = [{ Given 1 or more input shapes, determine if all shapes are the exact same. @@ -582,9 +582,6 @@ def Shape_CstrEqOp : Shape_Op<"cstr_eq", [Commutative]> { let results = (outs Shape_WitnessType:$result); let assemblyFormat = "$inputs attr-dict"; - - let hasCanonicalizer = 1; - let hasFolder = 1; } def Shape_ConstWitnessOp : Shape_Op<"const_witness", [ConstantLike, NoSideEffect]> { diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp index 04b1a51e986e..25fe3d9d3b22 100644 --- a/mlir/lib/Dialect/Shape/IR/Shape.cpp +++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp @@ -325,27 +325,6 @@ OpFoldResult CstrBroadcastableOp::fold(ArrayRef operands) { return nullptr; } -//===----------------------------------------------------------------------===// -// CstrEqOp -//===----------------------------------------------------------------------===// - -void CstrEqOp::getCanonicalizationPatterns(OwningRewritePatternList &patterns, - MLIRContext *context) { - // If inputs are equal, return passing witness - patterns.insert(context); -} - -OpFoldResult CstrEqOp::fold(ArrayRef operands) { - if (llvm::all_of(operands, - [&](Attribute a) { return a && a == operands[0]; })) - return BoolAttr::get(true, getContext()); - - // Because a failing witness result here represents an eventual assertion - // failure, we do not try to replace it with a constant witness. Similarly, we - // cannot if there are any non-const inputs. - return nullptr; -} - //===----------------------------------------------------------------------===// // ConstSizeOp //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/Shape/IR/ShapeCanonicalization.td b/mlir/lib/Dialect/Shape/IR/ShapeCanonicalization.td index 78c9119f1292..9a73a8847779 100644 --- a/mlir/lib/Dialect/Shape/IR/ShapeCanonicalization.td +++ b/mlir/lib/Dialect/Shape/IR/ShapeCanonicalization.td @@ -2,17 +2,7 @@ include "mlir/Dialect/Shape/IR/ShapeOps.td" def EqualBinaryOperands : Constraint>; -def AllInputShapesEq : Constraint>; - // Canonicalization patterns. def CstrBroadcastableEqOps : Pat<(Shape_CstrBroadcastableOp:$op $lhs, $rhs), (Shape_ConstWitnessOp ConstBoolAttrTrue), [(EqualBinaryOperands $lhs, $rhs)]>; - -def CstrEqEqOps : Pat<(Shape_CstrEqOp:$op $shapes), - (Shape_ConstWitnessOp ConstBoolAttrTrue), - [(AllInputShapesEq $shapes)]>; diff --git a/mlir/test/Dialect/Shape/canonicalize.mlir b/mlir/test/Dialect/Shape/canonicalize.mlir index 7c90753e255e..5ebca6784c0e 100644 --- a/mlir/test/Dialect/Shape/canonicalize.mlir +++ b/mlir/test/Dialect/Shape/canonicalize.mlir @@ -213,62 +213,6 @@ func @not_const(%arg0: !shape.shape) -> !shape.size { return %0 : !shape.size } - -// ----- -// cstr_eq with non-constant but known equal shapes can be removed. -// CHECK-LABEL: func @f -func @f(%arg0 : !shape.shape) { - // CHECK-NEXT: shape.const_witness true - // CHECK-NEXT: consume.witness - // CHECK-NEXT: return - %0 = shape.cstr_eq %arg0, %arg0, %arg0 - "consume.witness"(%0) : (!shape.witness) -> () - return -} - -// ----- -// cstr_eq with equal const_shapes can be folded -// CHECK-LABEL: func @f -func @f() { - // CHECK-NEXT: shape.const_witness true - // CHECK-NEXT: consume.witness - // CHECK-NEXT: return - %cs0 = shape.const_shape [0, 1] - %cs1 = shape.const_shape [0, 1] - %cs2 = shape.const_shape [0, 1] - %0 = shape.cstr_eq %cs0, %cs1, %cs2 - "consume.witness"(%0) : (!shape.witness) -> () - return -} - -// ----- -// cstr_eq with unequal const_shapes cannot be folded -// CHECK-LABEL: func @f -func @f() { - // CHECK-NEXT: shape.const_shape - // CHECK-NEXT: shape.const_shape - // CHECK-NEXT: shape.cstr_eq - // CHECK-NEXT: consume.witness - // CHECK-NEXT: return - %cs0 = shape.const_shape [0, 1] - %cs1 = shape.const_shape [3, 1] - %0 = shape.cstr_eq %cs0, %cs1 - "consume.witness"(%0) : (!shape.witness) -> () - return -} - -// ----- -// cstr_eq without const_shapes cannot be folded -// CHECK-LABEL: func @f -func @f(%arg0: !shape.shape, %arg1: !shape.shape) { - // CHECK-NEXT: shape.cstr_eq - // CHECK-NEXT: consume.witness - // CHECK-NEXT: return - %0 = shape.cstr_eq %arg0, %arg1 - "consume.witness"(%0) : (!shape.witness) -> () - return -} - // ----- // assuming_all with known passing witnesses can be folded // CHECK-LABEL: func @f