From dfbb5a087e20ea1c14300eef600e52360320b390 Mon Sep 17 00:00:00 2001 From: Sean Silva Date: Thu, 12 Nov 2020 17:08:56 -0800 Subject: [PATCH] [mlir] Remove SameOperandsAndResultShape when redundant with ElementwiseMappable SameOperandsAndResultShape and ElementwiseMappable have similar verification, but in general neither is strictly redundant with the other. Examples: - SameOperandsAndResultShape allows `"foo"(%0) : tensor<2xf32> -> tensor but ElementwiseMappable does not. - ElementwiseMappable allows `select %scalar_pred, %true_tensor, %false_tensor` but SameOperandsAndResultShape does not. SameOperandsAndResultShape is redundant with ElementwiseMappable when we can prove that the mixed scalar/non-scalar case cannot happen. In those situations, `ElementwiseMappable & SameOperandsAndResultShape == ElementwiseMappable`: - Ops with 1 operand: the case of mixed scalar and non-scalar operands cannot happen since there is only one operand. - When SameTypeOperands is also present, the mixed scalar/non-scalar operand case cannot happen. Differential Revision: https://reviews.llvm.org/D91396 --- .../mlir/Dialect/StandardOps/IR/Ops.td | 22 ++++++++----------- mlir/test/IR/invalid-ops.mlir | 4 ++-- 2 files changed, 11 insertions(+), 15 deletions(-) diff --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td index c44d99b1620d..441cff497ed2 100644 --- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td +++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td @@ -991,10 +991,10 @@ def CmpFPredicateAttr : I64EnumAttr< } def CmpFOp : Std_Op<"cmpf", - [NoSideEffect, SameTypeOperands, - SameOperandsAndResultShape, TypesMatchWith< + [NoSideEffect, SameTypeOperands, ElementwiseMappable, + TypesMatchWith< "result type has i1 element type and same shape as operands", - "lhs", "result", "getI1SameShape($_self)">, ElementwiseMappable]> { + "lhs", "result", "getI1SameShape($_self)">]> { let summary = "floating-point comparison operation"; let description = [{ The `cmpf` operation compares its two operands according to the float @@ -1075,10 +1075,10 @@ def CmpIPredicateAttr : I64EnumAttr< } def CmpIOp : Std_Op<"cmpi", - [NoSideEffect, SameTypeOperands, - SameOperandsAndResultShape, TypesMatchWith< + [NoSideEffect, SameTypeOperands, ElementwiseMappable, + TypesMatchWith< "result type has i1 element type and same shape as operands", - "lhs", "result", "getI1SameShape($_self)">, ElementwiseMappable]> { + "lhs", "result", "getI1SameShape($_self)">]> { let summary = "integer comparison operation"; let description = [{ The `cmpi` operation is a generic comparison for integer-like types. Its two @@ -2799,7 +2799,7 @@ def SignedShiftRightOp : IntArithmeticOp<"shift_right_signed"> { //===----------------------------------------------------------------------===// def SignExtendIOp : Std_Op<"sexti", - [NoSideEffect, SameOperandsAndResultShape, ElementwiseMappable]> { + [NoSideEffect, ElementwiseMappable]> { let summary = "integer sign extension operation"; let description = [{ The integer sign extension operation takes an integer input of @@ -3665,9 +3665,7 @@ def TransposeOp : Std_Op<"transpose", [NoSideEffect]>, // TruncateIOp //===----------------------------------------------------------------------===// -def TruncateIOp : Std_Op<"trunci", [NoSideEffect, - SameOperandsAndResultShape, - ElementwiseMappable]> { +def TruncateIOp : Std_Op<"trunci", [NoSideEffect, ElementwiseMappable]> { let summary = "integer truncation operation"; let description = [{ The integer truncation operation takes an integer input of @@ -3934,9 +3932,7 @@ def XOrOp : IntArithmeticOp<"xor", [Commutative]> { // ZeroExtendIOp //===----------------------------------------------------------------------===// -def ZeroExtendIOp : Std_Op<"zexti", [NoSideEffect, - SameOperandsAndResultShape, - ElementwiseMappable]> { +def ZeroExtendIOp : Std_Op<"zexti", [NoSideEffect, ElementwiseMappable]> { let summary = "integer zero extension operation"; let description = [{ The integer zero extension operation takes an integer input of diff --git a/mlir/test/IR/invalid-ops.mlir b/mlir/test/IR/invalid-ops.mlir index 3d9fe45959ed..1731c9c1aeb9 100644 --- a/mlir/test/IR/invalid-ops.mlir +++ b/mlir/test/IR/invalid-ops.mlir @@ -236,7 +236,7 @@ func @func_with_ops(i32, i32) { func @func_with_ops() { ^bb0: %c = constant dense<0> : vector<42 x i32> - // expected-error@+1 {{op requires the same shape for all operands and results}} + // expected-error@+1 {{all non-scalar operands/results must have the same shape and base type: found 'vector<41xi1>' and 'vector<42xi32>'}} %r = "std.cmpi"(%c, %c) {predicate = 0} : (vector<42 x i32>, vector<42 x i32>) -> vector<41 x i1> } @@ -514,7 +514,7 @@ func @cmpf_canonical_wrong_result_type(%a : f32, %b : f32) -> f32 { // ----- func @cmpf_result_shape_mismatch(%a : vector<42xf32>) { - // expected-error@+1 {{op requires the same shape for all operands and results}} + // expected-error@+1 {{all non-scalar operands/results must have the same shape and base type: found 'vector<41xi1>' and 'vector<42xf32>'}} %r = "std.cmpf"(%a, %a) {predicate = 0} : (vector<42 x f32>, vector<42 x f32>) -> vector<41 x i1> }