From 9b7435fb50230621b5660a8d3dad51c40c6c348d Mon Sep 17 00:00:00 2001 From: Geoffrey Martin-Noble Date: Wed, 25 Sep 2019 10:16:39 -0700 Subject: [PATCH] Add tablegen verification traits for comparing different properties This allows things like comparing the rank of one operand to the size of another that specifies indices into it. PiperOrigin-RevId: 271150439 --- mlir/include/mlir/IR/OpBase.td | 57 ++++++++++++++++------------ mlir/test/lib/TestDialect/TestOps.td | 12 ++++++ mlir/test/mlir-tblgen/types.mlir | 24 ++++++++++++ 3 files changed, 68 insertions(+), 25 deletions(-) diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td index 1139c7fdd9a7..43e9b5e0ae57 100644 --- a/mlir/include/mlir/IR/OpBase.td +++ b/mlir/include/mlir/IR/OpBase.td @@ -1437,6 +1437,38 @@ def HasNoUseOf: Constraint< // TODO(b/135033717): Improve the autogenerated error messages. +class AllMatchPred values> : + CPred<"llvm::is_splat(llvm::makeArrayRef({"# StrJoin.result #"}))">; + +class AllMatch values, string description> : + PredOpTrait>; + +// TODO(b/135032064): Only works for non-variadic. +class AllMatchSameOperatorPred names, string operator> : + AllMatchPred; + +class AllMatchSameOperatorTrait names, string operator, + string description> : + PredOpTrait< + "all of {" # StrJoin.result # "} have same " # description, + AllMatchSameOperatorPred>; + +class AllElementCountsMatch names> : + AllMatchSameOperatorTrait< + names, "$_self.getType().cast().getNumElements()", + "element count">; + +class AllElementTypesMatch names> : + AllMatchSameOperatorTrait; + +class AllRanksMatch names> : + AllMatchSameOperatorTrait< + names, "$_self.getType().cast().getRank()", "rank">; + +class AllTypesMatch names> : + AllMatchSameOperatorTrait; + // Type Constraint operand `idx`'s Element type is `type`. class TCopVTEtIs : And<[ CPred<"$_op.getNumOperands() > " # idx>, @@ -1461,31 +1493,6 @@ class ElementTypeIsPred : And<[ class ElementTypeIs : PredOpTrait< "'" # name # "' is " # type.description, ElementTypeIsPred>; -// TODO(b/135032064): Only works for non-variadic. -class AllMatchPred names, string operator> : - CPred<"llvm::is_splat(llvm::makeArrayRef({" # - StrJoin.result - # "}))">; - -class AllMatchTrait names, string operator, string description> : - PredOpTrait< - "all of {" # StrJoin.result # "} have same " # description, - AllMatchPred>; - -class AllElementCountsMatch names> : - AllMatchTrait().getNumElements()", - "element count">; - -class AllElementTypesMatch names> : - AllMatchTrait; - -class AllRanksMatch names> : - AllMatchTrait().getRank()", "rank">; - -class AllTypesMatch names> : - AllMatchTrait; - // Predicate to verify that the i'th operand and the j'th operand have the same // elemental type. // Type Constraint operand `i`'s Element type is Same As operand `j`'s Element diff --git a/mlir/test/lib/TestDialect/TestOps.td b/mlir/test/lib/TestDialect/TestOps.td index 7b4d4d458e09..862861ae8591 100644 --- a/mlir/test/lib/TestDialect/TestOps.td +++ b/mlir/test/lib/TestDialect/TestOps.td @@ -288,6 +288,18 @@ def Operand1AndResultHaveSameElementCount : let results = (outs AnyTensor:$res); } +def FourEqualsFive : + TEST_Op<"four_equals_five", [AllMatch<["5", "4"], "4 equals 5">]>; + +def OperandRankEqualsResultSize : + TEST_Op<"operand_rank_equals_result_size", + [AllMatch<["$operand.getType().cast().getRank()", + "$result.getType().cast().getNumElements()" + ], "operand rank equals result size">]> { + let arguments = (ins AnyTensor:$operand); + let results = (outs AnyTensor:$result); +} + def IfFirstOperandIsNoneThenSoIsSecond : TEST_Op<"if_first_operand_is_none_then_so_is_second", [PredOpTrait< "has either both none type operands or first is not none", diff --git a/mlir/test/mlir-tblgen/types.mlir b/mlir/test/mlir-tblgen/types.mlir index 4c5b80ea4e77..7b48d0a6da23 100644 --- a/mlir/test/mlir-tblgen/types.mlir +++ b/mlir/test/mlir-tblgen/types.mlir @@ -298,6 +298,30 @@ func @operand1_and_result_have_same_element_count_failure(%arg0: tensor<1xi32>, // ----- +func @four_equals_five() { + // expected-error@+1 {{failed to verify that 4 equals 5}} + "test.four_equals_five"() : () -> () + return +} + +// ----- + +func @operand_rank_equals_result_size_success(%arg : tensor<1x2x3x4xi32>) { + %0 = "test.operand_rank_equals_result_size"(%arg) : (tensor<1x2x3x4xi32>) -> tensor<4xi32> + %1 = "test.operand_rank_equals_result_size"(%arg) : (tensor<1x2x3x4xi32>) -> tensor<2x2xf32> + return +} + +// ----- + +func @operand_rank_equals_result_size_failure(%arg : tensor<1x2x3x4xi32>) { + // expected-error@+1 {{failed to verify that operand rank equals result size}} + %0 = "test.operand_rank_equals_result_size"(%arg) : (tensor<1x2x3x4xi32>) -> tensor<2xi32> + return +} + +// ----- + func @same_types_element_mismatch(%arg0: tensor<* x i32>, %arg1: tensor<* x f32>) { // expected-error@+1 {{all of {x, res} have same type}} %0 = "test.operand_one_and_result_have_same_type"(%arg0, %arg1) : (tensor<* x i32>, tensor<* x f32>) -> tensor<* x f32>