[mlir] Remove SameOperandsAndResultShape when redundant with ElementwiseMappable

SameOperandsAndResultShape and ElementwiseMappable have similar
verification, but in general neither is strictly redundant with the
other.

Examples:
- SameOperandsAndResultShape allows
  `"foo"(%0) : tensor<2xf32> -> tensor<?xf32> but ElementwiseMappable
  does not.
- ElementwiseMappable allows
  `select %scalar_pred, %true_tensor, %false_tensor` but
  SameOperandsAndResultShape does not.

SameOperandsAndResultShape is redundant with ElementwiseMappable when
we can prove that the mixed scalar/non-scalar case cannot happen. In
those situations, `ElementwiseMappable & SameOperandsAndResultShape ==
ElementwiseMappable`:
- Ops with 1 operand: the case of mixed scalar and non-scalar operands
  cannot happen since there is only one operand.
- When SameTypeOperands is also present, the mixed scalar/non-scalar
  operand case cannot happen.

Differential Revision: https://reviews.llvm.org/D91396
This commit is contained in:
Sean Silva 2020-11-12 17:08:56 -08:00
parent b228e2bd92
commit dfbb5a087e
2 changed files with 11 additions and 15 deletions

View File

@ -991,10 +991,10 @@ def CmpFPredicateAttr : I64EnumAttr<
}
def CmpFOp : Std_Op<"cmpf",
[NoSideEffect, SameTypeOperands,
SameOperandsAndResultShape, TypesMatchWith<
[NoSideEffect, SameTypeOperands, ElementwiseMappable,
TypesMatchWith<
"result type has i1 element type and same shape as operands",
"lhs", "result", "getI1SameShape($_self)">, ElementwiseMappable]> {
"lhs", "result", "getI1SameShape($_self)">]> {
let summary = "floating-point comparison operation";
let description = [{
The `cmpf` operation compares its two operands according to the float
@ -1075,10 +1075,10 @@ def CmpIPredicateAttr : I64EnumAttr<
}
def CmpIOp : Std_Op<"cmpi",
[NoSideEffect, SameTypeOperands,
SameOperandsAndResultShape, TypesMatchWith<
[NoSideEffect, SameTypeOperands, ElementwiseMappable,
TypesMatchWith<
"result type has i1 element type and same shape as operands",
"lhs", "result", "getI1SameShape($_self)">, ElementwiseMappable]> {
"lhs", "result", "getI1SameShape($_self)">]> {
let summary = "integer comparison operation";
let description = [{
The `cmpi` operation is a generic comparison for integer-like types. Its two
@ -2799,7 +2799,7 @@ def SignedShiftRightOp : IntArithmeticOp<"shift_right_signed"> {
//===----------------------------------------------------------------------===//
def SignExtendIOp : Std_Op<"sexti",
[NoSideEffect, SameOperandsAndResultShape, ElementwiseMappable]> {
[NoSideEffect, ElementwiseMappable]> {
let summary = "integer sign extension operation";
let description = [{
The integer sign extension operation takes an integer input of
@ -3665,9 +3665,7 @@ def TransposeOp : Std_Op<"transpose", [NoSideEffect]>,
// TruncateIOp
//===----------------------------------------------------------------------===//
def TruncateIOp : Std_Op<"trunci", [NoSideEffect,
SameOperandsAndResultShape,
ElementwiseMappable]> {
def TruncateIOp : Std_Op<"trunci", [NoSideEffect, ElementwiseMappable]> {
let summary = "integer truncation operation";
let description = [{
The integer truncation operation takes an integer input of
@ -3934,9 +3932,7 @@ def XOrOp : IntArithmeticOp<"xor", [Commutative]> {
// ZeroExtendIOp
//===----------------------------------------------------------------------===//
def ZeroExtendIOp : Std_Op<"zexti", [NoSideEffect,
SameOperandsAndResultShape,
ElementwiseMappable]> {
def ZeroExtendIOp : Std_Op<"zexti", [NoSideEffect, ElementwiseMappable]> {
let summary = "integer zero extension operation";
let description = [{
The integer zero extension operation takes an integer input of

View File

@ -236,7 +236,7 @@ func @func_with_ops(i32, i32) {
func @func_with_ops() {
^bb0:
%c = constant dense<0> : vector<42 x i32>
// expected-error@+1 {{op requires the same shape for all operands and results}}
// expected-error@+1 {{all non-scalar operands/results must have the same shape and base type: found 'vector<41xi1>' and 'vector<42xi32>'}}
%r = "std.cmpi"(%c, %c) {predicate = 0} : (vector<42 x i32>, vector<42 x i32>) -> vector<41 x i1>
}
@ -514,7 +514,7 @@ func @cmpf_canonical_wrong_result_type(%a : f32, %b : f32) -> f32 {
// -----
func @cmpf_result_shape_mismatch(%a : vector<42xf32>) {
// expected-error@+1 {{op requires the same shape for all operands and results}}
// expected-error@+1 {{all non-scalar operands/results must have the same shape and base type: found 'vector<41xi1>' and 'vector<42xf32>'}}
%r = "std.cmpf"(%a, %a) {predicate = 0} : (vector<42 x f32>, vector<42 x f32>) -> vector<41 x i1>
}