forked from OSchip/llvm-project
Add ability to verify type matching between operands/results
This extends and generalizes the functionality for checking that element types match PiperOrigin-RevId: 253110512
This commit is contained in:
parent
54b35cec08
commit
d156b83060
|
@ -1075,16 +1075,23 @@ class ElementTypeIsPred<string name, Type type> : And<[
|
|||
class ElementTypeIs<string name, Type type> : PredOpTrait<
|
||||
"'" # name # "' is " # type.description, ElementTypeIsPred<name, type>>;
|
||||
|
||||
// Predicate to verify that all the arguments and results have the same element
|
||||
// type.
|
||||
// TODO(b/135032064): Only works for non-variadic.
|
||||
class AllElementTypesMatchPred<list<string> names> :
|
||||
CPred<"llvm::is_splat(ArrayRef<Type>{" # !if(!empty(names), "",
|
||||
!foldl("getElementTypeOrSelf($" # !head(names) # ")", !tail(names),
|
||||
prev, cur, prev # ", getElementTypeOrSelf($" # cur # ")")) # "})">;
|
||||
class AllElementTypesMatch<list<string> names> : PredOpTrait<
|
||||
"all of {" # StrJoin<names>.result # "} have same element type",
|
||||
AllElementTypesMatchPred<names>>;
|
||||
class AllMatchPred<list<string> names, string operator> :
|
||||
CPred<"llvm::is_splat(llvm::makeArrayRef({" #
|
||||
StrJoin<!foreach(n, names,
|
||||
!subst("$_self", "$" # n, operator))>.result
|
||||
# "}))">;
|
||||
|
||||
class AllMatchTrait<list<string> names, string operator, string description> :
|
||||
PredOpTrait<
|
||||
"all of {" # StrJoin<names>.result # "} have same " # description,
|
||||
AllMatchPred<names, operator>>;
|
||||
|
||||
class AllElementTypesMatch<list<string> names> :
|
||||
AllMatchTrait<names, "getElementTypeOrSelf($_self)", "element type">;
|
||||
|
||||
class AllTypesMatch<list<string> names> :
|
||||
AllMatchTrait<names, "$_self.getType()", "type">;
|
||||
|
||||
// Predicate to verify that the i'th operand and the j'th operand have the same
|
||||
// elemental type.
|
||||
|
|
|
@ -134,6 +134,18 @@ def OperandOneAndResultHaveSameElementType : TEST_Op<
|
|||
let results = (outs AnyTensor:$res);
|
||||
}
|
||||
|
||||
def OperandsHaveSameType :
|
||||
TEST_Op<"operands_have_same_type", [AllTypesMatch<["x", "y"]>]> {
|
||||
let arguments = (ins AnyTensor:$x, AnyTensor:$y);
|
||||
}
|
||||
|
||||
def OperandOneAndResultHaveSameType :
|
||||
TEST_Op<"operand_one_and_result_have_same_type",
|
||||
[AllTypesMatch<["x", "res"]>]> {
|
||||
let arguments = (ins AnyTensor:$x, AnyTensor:$y);
|
||||
let results = (outs AnyTensor:$res);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Test Patterns
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -105,7 +105,7 @@ func @fixed_element_types(%arg0: tensor<* x i32>, %arg1: tensor<* x f32>) {
|
|||
|
||||
// -----
|
||||
|
||||
func @fixed_element_types(%arg0: tensor<* x i32>, %arg1: tensor<* x f32>) {
|
||||
func @same_element_types(%arg0: tensor<* x i32>, %arg1: tensor<* x f32>) {
|
||||
// expected-error@+1 {{verify that all of {x, y} have same element type}}
|
||||
"test.operands_have_same_element_type"(%arg1, %arg0): (tensor<* x f32>, tensor<* x i32>) -> ()
|
||||
return
|
||||
|
@ -113,16 +113,64 @@ func @fixed_element_types(%arg0: tensor<* x i32>, %arg1: tensor<* x f32>) {
|
|||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: same_types
|
||||
func @same_types(%arg0: tensor<* x i32>, %arg1: tensor<* x f32>) {
|
||||
// CHECK-LABEL: same_element_types
|
||||
func @same_element_types(%arg0: tensor<* x i32>, %arg1: tensor<* x f32>) {
|
||||
%0 = "test.operand_one_and_result_have_same_element_type"(%arg1, %arg0) : (tensor<* x f32>, tensor<* x i32>) -> tensor<* x f32>
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @same_types(%arg0: tensor<* x i32>, %arg1: tensor<* x f32>) {
|
||||
func @same_element_types(%arg0: tensor<* x i32>, %arg1: tensor<* x f32>) {
|
||||
// expected-error@+1 {{all of {x, res} have same element type}}
|
||||
%0 = "test.operand_one_and_result_have_same_element_type"(%arg1, %arg0) : (tensor<* x f32>, tensor<* x i32>) -> tensor<* x i32>
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: same_types
|
||||
func @same_types(%arg0: tensor<* x i32>, %arg1: tensor<* x i32>) {
|
||||
"test.operands_have_same_type"(%arg0, %arg1) : (tensor<* x i32>, tensor<* x i32>) -> ()
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @same_types_element_mismatch(%arg0: tensor<* x i32>, %arg1: tensor<* x f32>) {
|
||||
// expected-error@+1 {{all of {x, y} have same type}}
|
||||
"test.operands_have_same_type"(%arg0, %arg1) : (tensor<* x i32>, tensor<* x f32>) -> ()
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @same_types_shape_mismatch(%arg0: tensor<1x2xi32>, %arg1: tensor<2x1xi32>) {
|
||||
// expected-error@+1 {{all of {x, y} have same type}}
|
||||
"test.operands_have_same_type"(%arg0, %arg1) : (tensor<1x2xi32>, tensor<2x1xi32>) -> ()
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: same_types
|
||||
func @same_types(%arg0: tensor<* x i32>, %arg1: tensor<* x f32>) {
|
||||
%0 = "test.operand_one_and_result_have_same_type"(%arg0, %arg1) : (tensor<* x i32>, tensor<* x f32>) -> tensor<* x i32>
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @same_types_element_mismatch(%arg0: tensor<* x i32>, %arg1: tensor<* x f32>) {
|
||||
// expected-error@+1 {{all of {x, res} have same type}}
|
||||
%0 = "test.operand_one_and_result_have_same_type"(%arg0, %arg1) : (tensor<* x i32>, tensor<* x f32>) -> tensor<* x f32>
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @same_types_shape_mismatch(%arg0: tensor<1x2xi32>, %arg1: tensor<2x1xi32>) {
|
||||
// expected-error@+1 {{all of {x, res} have same type}}
|
||||
%0 = "test.operand_one_and_result_have_same_type"(%arg0, %arg1) : (tensor<1x2xi32>, tensor<2x1xi32>) -> tensor<2x1xi32>
|
||||
return
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue