diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td index 5467e2f78215..1eb17dabab00 100644 --- a/mlir/include/mlir/IR/OpBase.td +++ b/mlir/include/mlir/IR/OpBase.td @@ -1410,9 +1410,16 @@ class AllMatchTrait names, string operator, string description> : "all of {" # StrJoin.result # "} have same " # description, AllMatchPred>; +class AllElementCountsMatch names> : + AllMatchTrait().getNumElements()", + "element count">; + class AllElementTypesMatch names> : AllMatchTrait; +class AllRanksMatch names> : + AllMatchTrait().getRank()", "rank">; + class AllTypesMatch names> : AllMatchTrait; diff --git a/mlir/test/lib/TestDialect/TestOps.td b/mlir/test/lib/TestDialect/TestOps.td index fb2c2b5c1976..ff04bfdd8dda 100644 --- a/mlir/test/lib/TestDialect/TestOps.td +++ b/mlir/test/lib/TestDialect/TestOps.td @@ -246,6 +246,25 @@ def OperandOneAndResultHaveSameType : let results = (outs AnyTensor:$res); } +def OperandsHaveSameRank : + TEST_Op<"operands_have_same_rank", [AllRanksMatch<["x", "y"]>]> { + let arguments = (ins AnyTensor:$x, AnyTensor:$y); +} + +def Operand1AndResultHaveSameRank : + TEST_Op<"operand1_and_result_have_same_rank", + [AllRanksMatch<["x", "res"]>]> { + let arguments = (ins AnyTensor:$x, AnyTensor:$y); + let results = (outs AnyTensor:$res); +} + +def Operand1AndResultHaveSameElementCount : + TEST_Op<"operand1_and_result_have_same_element_count", + [AllElementCountsMatch<["x", "res"]>]> { + let arguments = (ins AnyTensor:$x, AnyTensor:$y); + let results = (outs AnyTensor:$res); +} + def IfFirstOperandIsNoneThenSoIsSecond : TEST_Op<"if_first_operand_is_none_then_so_is_second", [PredOpTrait< "has either both none type operands or first is not none", diff --git a/mlir/test/mlir-tblgen/types.mlir b/mlir/test/mlir-tblgen/types.mlir index 7050da17bcfa..4c5b80ea4e77 100644 --- a/mlir/test/mlir-tblgen/types.mlir +++ b/mlir/test/mlir-tblgen/types.mlir @@ -239,6 +239,65 @@ func @same_types(%arg0: tensor<* x i32>, %arg1: tensor<* x f32>) { // ----- +// CHECK-LABEL: operands_have_same_rank_success +func @operands_have_same_rank_success(%arg0: tensor<1xi32>, %arg1: tensor<2xf32>) { + "test.operands_have_same_rank"(%arg0, %arg1) : (tensor<1xi32>, tensor<2xf32>) -> () + return +} + +// ----- + +func @operands_have_same_rank_failure(%arg0: tensor<1xi32>, %arg1: tensor<1x2xf32>) { + // expected-error@+1 {{all of {x, y} have same rank}} + "test.operands_have_same_rank"(%arg0, %arg1) : (tensor<1xi32>, tensor<1x2xf32>) -> () + return +} + +// ----- + +// CHECK-LABEL: operand1_and_result_have_same_rank_success +func @operand1_and_result_have_same_rank_success(%arg0: tensor<1xi32>, %arg1: tensor<1x2xf32>, %arg3: tensor<1x2xi32>) { + "test.operand1_and_result_have_same_rank"(%arg0, %arg1) : (tensor<1xi32>, tensor<1x2xf32>) -> (tensor<3xf32>) + "test.operand1_and_result_have_same_rank"(%arg3, %arg1) : (tensor<1x2xi32>, tensor<1x2xf32>) -> (tensor<3x3xf64>) + return +} + +// ----- + +func @operand1_and_result_have_same_rank_failure(%arg0: tensor<1xi32>, %arg1: tensor<1x2xf32>) { + // expected-error@+1 {{all of {x, res} have same rank}} + "test.operand1_and_result_have_same_rank"(%arg0, %arg1) : (tensor<1xi32>, tensor<1x2xf32>) -> (tensor) + return +} + +// ----- + +func @operand1_and_result_have_same_rank_failure(%arg0: tensor<1x2xi32>, %arg1: tensor<1x2xf32>) { + // expected-error@+1 {{all of {x, res} have same rank}} + "test.operand1_and_result_have_same_rank"(%arg0, %arg1) : (tensor<1x2xi32>, tensor<1x2xf32>) -> (tensor<3xi32>) + return +} + +// ----- + +// CHECK-LABEL: operand1_and_result_have_same_element_count_success +func @operand1_and_result_have_same_element_count_success(%arg0: tensor<36xi32>, %arg1: tensor<1x2xf32>, %arg3: tensor) { + "test.operand1_and_result_have_same_element_count"(%arg0, %arg1) : (tensor<36xi32>, tensor<1x2xf32>) -> (tensor<3x4x3xf32>) + "test.operand1_and_result_have_same_element_count"(%arg0, %arg1) : (tensor<36xi32>, tensor<1x2xf32>) -> (tensor<3x12xf64>) + "test.operand1_and_result_have_same_element_count"(%arg3, %arg1) : (tensor, tensor<1x2xf32>) -> (tensor<1x1x1xi32>) + return +} + +// ----- + +func @operand1_and_result_have_same_element_count_failure(%arg0: tensor<1xi32>, %arg1: tensor<1x2xf32>) { + // expected-error@+1 {{all of {x, res} have same element count}} + "test.operand1_and_result_have_same_element_count"(%arg0, %arg1) : (tensor<1xi32>, tensor<1x2xf32>) -> (tensor<2xi32>) + return +} + +// ----- + func @same_types_element_mismatch(%arg0: tensor<* x i32>, %arg1: tensor<* x f32>) { // expected-error@+1 {{all of {x, res} have same type}} %0 = "test.operand_one_and_result_have_same_type"(%arg0, %arg1) : (tensor<* x i32>, tensor<* x f32>) -> tensor<* x f32>