Add tablegen verification traits for comparing different properties

This allows things like comparing the rank of one operand to the size of another that specifies indices into it.

PiperOrigin-RevId: 271150439
This commit is contained in:
Geoffrey Martin-Noble 2019-09-25 10:16:39 -07:00 committed by A. Unique TensorFlower
parent b76c4f8780
commit 9b7435fb50
3 changed files with 68 additions and 25 deletions

View File

@ -1437,6 +1437,38 @@ def HasNoUseOf: Constraint<
// TODO(b/135033717): Improve the autogenerated error messages.
class AllMatchPred<list<string> values> :
CPred<"llvm::is_splat(llvm::makeArrayRef({"# StrJoin<values>.result #"}))">;
class AllMatch<list<string> values, string description> :
PredOpTrait<description, AllMatchPred<values>>;
// TODO(b/135032064): Only works for non-variadic.
class AllMatchSameOperatorPred<list<string> names, string operator> :
AllMatchPred<!foreach(n, names, !subst("$_self", "$" # n, operator))>;
class AllMatchSameOperatorTrait<list<string> names, string operator,
string description> :
PredOpTrait<
"all of {" # StrJoin<names>.result # "} have same " # description,
AllMatchSameOperatorPred<names, operator>>;
class AllElementCountsMatch<list<string> names> :
AllMatchSameOperatorTrait<
names, "$_self.getType().cast<ShapedType>().getNumElements()",
"element count">;
class AllElementTypesMatch<list<string> names> :
AllMatchSameOperatorTrait<names,
"getElementTypeOrSelf($_self)", "element type">;
class AllRanksMatch<list<string> names> :
AllMatchSameOperatorTrait<
names, "$_self.getType().cast<ShapedType>().getRank()", "rank">;
class AllTypesMatch<list<string> names> :
AllMatchSameOperatorTrait<names, "$_self.getType()", "type">;
// Type Constraint operand `idx`'s Element type is `type`.
class TCopVTEtIs<int idx, Type type> : And<[
CPred<"$_op.getNumOperands() > " # idx>,
@ -1461,31 +1493,6 @@ class ElementTypeIsPred<string name, Type type> : And<[
class ElementTypeIs<string name, Type type> : PredOpTrait<
"'" # name # "' is " # type.description, ElementTypeIsPred<name, type>>;
// TODO(b/135032064): Only works for non-variadic.
class AllMatchPred<list<string> names, string operator> :
CPred<"llvm::is_splat(llvm::makeArrayRef({" #
StrJoin<!foreach(n, names,
!subst("$_self", "$" # n, operator))>.result
# "}))">;
class AllMatchTrait<list<string> names, string operator, string description> :
PredOpTrait<
"all of {" # StrJoin<names>.result # "} have same " # description,
AllMatchPred<names, operator>>;
class AllElementCountsMatch<list<string> names> :
AllMatchTrait<names, "$_self.getType().cast<ShapedType>().getNumElements()",
"element count">;
class AllElementTypesMatch<list<string> names> :
AllMatchTrait<names, "getElementTypeOrSelf($_self)", "element type">;
class AllRanksMatch<list<string> names> :
AllMatchTrait<names, "$_self.getType().cast<ShapedType>().getRank()", "rank">;
class AllTypesMatch<list<string> names> :
AllMatchTrait<names, "$_self.getType()", "type">;
// Predicate to verify that the i'th operand and the j'th operand have the same
// elemental type.
// Type Constraint operand `i`'s Element type is Same As operand `j`'s Element

View File

@ -288,6 +288,18 @@ def Operand1AndResultHaveSameElementCount :
let results = (outs AnyTensor:$res);
}
def FourEqualsFive :
TEST_Op<"four_equals_five", [AllMatch<["5", "4"], "4 equals 5">]>;
def OperandRankEqualsResultSize :
TEST_Op<"operand_rank_equals_result_size",
[AllMatch<["$operand.getType().cast<ShapedType>().getRank()",
"$result.getType().cast<ShapedType>().getNumElements()"
], "operand rank equals result size">]> {
let arguments = (ins AnyTensor:$operand);
let results = (outs AnyTensor:$result);
}
def IfFirstOperandIsNoneThenSoIsSecond :
TEST_Op<"if_first_operand_is_none_then_so_is_second", [PredOpTrait<
"has either both none type operands or first is not none",

View File

@ -298,6 +298,30 @@ func @operand1_and_result_have_same_element_count_failure(%arg0: tensor<1xi32>,
// -----
func @four_equals_five() {
// expected-error@+1 {{failed to verify that 4 equals 5}}
"test.four_equals_five"() : () -> ()
return
}
// -----
func @operand_rank_equals_result_size_success(%arg : tensor<1x2x3x4xi32>) {
%0 = "test.operand_rank_equals_result_size"(%arg) : (tensor<1x2x3x4xi32>) -> tensor<4xi32>
%1 = "test.operand_rank_equals_result_size"(%arg) : (tensor<1x2x3x4xi32>) -> tensor<2x2xf32>
return
}
// -----
func @operand_rank_equals_result_size_failure(%arg : tensor<1x2x3x4xi32>) {
// expected-error@+1 {{failed to verify that operand rank equals result size}}
%0 = "test.operand_rank_equals_result_size"(%arg) : (tensor<1x2x3x4xi32>) -> tensor<2xi32>
return
}
// -----
func @same_types_element_mismatch(%arg0: tensor<* x i32>, %arg1: tensor<* x f32>) {
// expected-error@+1 {{all of {x, res} have same type}}
%0 = "test.operand_one_and_result_have_same_type"(%arg0, %arg1) : (tensor<* x i32>, tensor<* x f32>) -> tensor<* x f32>