Add ability to verify type matching between operands/results

This extends and generalizes the functionality for checking that element types match

PiperOrigin-RevId: 253110512
This commit is contained in:
Geoffrey Martin-Noble 2019-06-13 14:52:29 -07:00 committed by Mehdi Amini
parent 54b35cec08
commit d156b83060
3 changed files with 80 additions and 13 deletions

View File

@ -1075,16 +1075,23 @@ class ElementTypeIsPred<string name, Type type> : And<[
class ElementTypeIs<string name, Type type> : PredOpTrait<
"'" # name # "' is " # type.description, ElementTypeIsPred<name, type>>;
// Predicate to verify that all the arguments and results have the same element
// type.
// TODO(b/135032064): Only works for non-variadic.
class AllElementTypesMatchPred<list<string> names> :
CPred<"llvm::is_splat(ArrayRef<Type>{" # !if(!empty(names), "",
!foldl("getElementTypeOrSelf($" # !head(names) # ")", !tail(names),
prev, cur, prev # ", getElementTypeOrSelf($" # cur # ")")) # "})">;
class AllElementTypesMatch<list<string> names> : PredOpTrait<
"all of {" # StrJoin<names>.result # "} have same element type",
AllElementTypesMatchPred<names>>;
class AllMatchPred<list<string> names, string operator> :
CPred<"llvm::is_splat(llvm::makeArrayRef({" #
StrJoin<!foreach(n, names,
!subst("$_self", "$" # n, operator))>.result
# "}))">;
class AllMatchTrait<list<string> names, string operator, string description> :
PredOpTrait<
"all of {" # StrJoin<names>.result # "} have same " # description,
AllMatchPred<names, operator>>;
class AllElementTypesMatch<list<string> names> :
AllMatchTrait<names, "getElementTypeOrSelf($_self)", "element type">;
class AllTypesMatch<list<string> names> :
AllMatchTrait<names, "$_self.getType()", "type">;
// Predicate to verify that the i'th operand and the j'th operand have the same
// elemental type.

View File

@ -134,6 +134,18 @@ def OperandOneAndResultHaveSameElementType : TEST_Op<
let results = (outs AnyTensor:$res);
}
def OperandsHaveSameType :
TEST_Op<"operands_have_same_type", [AllTypesMatch<["x", "y"]>]> {
let arguments = (ins AnyTensor:$x, AnyTensor:$y);
}
def OperandOneAndResultHaveSameType :
TEST_Op<"operand_one_and_result_have_same_type",
[AllTypesMatch<["x", "res"]>]> {
let arguments = (ins AnyTensor:$x, AnyTensor:$y);
let results = (outs AnyTensor:$res);
}
//===----------------------------------------------------------------------===//
// Test Patterns
//===----------------------------------------------------------------------===//

View File

@ -105,7 +105,7 @@ func @fixed_element_types(%arg0: tensor<* x i32>, %arg1: tensor<* x f32>) {
// -----
func @fixed_element_types(%arg0: tensor<* x i32>, %arg1: tensor<* x f32>) {
func @same_element_types(%arg0: tensor<* x i32>, %arg1: tensor<* x f32>) {
// expected-error@+1 {{verify that all of {x, y} have same element type}}
"test.operands_have_same_element_type"(%arg1, %arg0): (tensor<* x f32>, tensor<* x i32>) -> ()
return
@ -113,16 +113,64 @@ func @fixed_element_types(%arg0: tensor<* x i32>, %arg1: tensor<* x f32>) {
// -----
// CHECK-LABEL: same_types
func @same_types(%arg0: tensor<* x i32>, %arg1: tensor<* x f32>) {
// CHECK-LABEL: same_element_types
func @same_element_types(%arg0: tensor<* x i32>, %arg1: tensor<* x f32>) {
%0 = "test.operand_one_and_result_have_same_element_type"(%arg1, %arg0) : (tensor<* x f32>, tensor<* x i32>) -> tensor<* x f32>
return
}
// -----
func @same_types(%arg0: tensor<* x i32>, %arg1: tensor<* x f32>) {
func @same_element_types(%arg0: tensor<* x i32>, %arg1: tensor<* x f32>) {
// expected-error@+1 {{all of {x, res} have same element type}}
%0 = "test.operand_one_and_result_have_same_element_type"(%arg1, %arg0) : (tensor<* x f32>, tensor<* x i32>) -> tensor<* x i32>
return
}
// -----
// CHECK-LABEL: same_types
func @same_types(%arg0: tensor<* x i32>, %arg1: tensor<* x i32>) {
"test.operands_have_same_type"(%arg0, %arg1) : (tensor<* x i32>, tensor<* x i32>) -> ()
return
}
// -----
func @same_types_element_mismatch(%arg0: tensor<* x i32>, %arg1: tensor<* x f32>) {
// expected-error@+1 {{all of {x, y} have same type}}
"test.operands_have_same_type"(%arg0, %arg1) : (tensor<* x i32>, tensor<* x f32>) -> ()
return
}
// -----
func @same_types_shape_mismatch(%arg0: tensor<1x2xi32>, %arg1: tensor<2x1xi32>) {
// expected-error@+1 {{all of {x, y} have same type}}
"test.operands_have_same_type"(%arg0, %arg1) : (tensor<1x2xi32>, tensor<2x1xi32>) -> ()
return
}
// -----
// CHECK-LABEL: same_types
func @same_types(%arg0: tensor<* x i32>, %arg1: tensor<* x f32>) {
%0 = "test.operand_one_and_result_have_same_type"(%arg0, %arg1) : (tensor<* x i32>, tensor<* x f32>) -> tensor<* x i32>
return
}
// -----
func @same_types_element_mismatch(%arg0: tensor<* x i32>, %arg1: tensor<* x f32>) {
// expected-error@+1 {{all of {x, res} have same type}}
%0 = "test.operand_one_and_result_have_same_type"(%arg0, %arg1) : (tensor<* x i32>, tensor<* x f32>) -> tensor<* x f32>
return
}
// -----
func @same_types_shape_mismatch(%arg0: tensor<1x2xi32>, %arg1: tensor<2x1xi32>) {
// expected-error@+1 {{all of {x, res} have same type}}
%0 = "test.operand_one_and_result_have_same_type"(%arg0, %arg1) : (tensor<1x2xi32>, tensor<2x1xi32>) -> tensor<2x1xi32>
return
}