forked from OSchip/llvm-project
Add type constraints for shaped types with same rank and element count
PiperOrigin-RevId: 269000237
This commit is contained in:
parent
113aadddf9
commit
efbd3e4610
|
@ -1410,9 +1410,16 @@ class AllMatchTrait<list<string> names, string operator, string description> :
|
|||
"all of {" # StrJoin<names>.result # "} have same " # description,
|
||||
AllMatchPred<names, operator>>;
|
||||
|
||||
class AllElementCountsMatch<list<string> names> :
|
||||
AllMatchTrait<names, "$_self.getType().cast<ShapedType>().getNumElements()",
|
||||
"element count">;
|
||||
|
||||
class AllElementTypesMatch<list<string> names> :
|
||||
AllMatchTrait<names, "getElementTypeOrSelf($_self)", "element type">;
|
||||
|
||||
class AllRanksMatch<list<string> names> :
|
||||
AllMatchTrait<names, "$_self.getType().cast<ShapedType>().getRank()", "rank">;
|
||||
|
||||
class AllTypesMatch<list<string> names> :
|
||||
AllMatchTrait<names, "$_self.getType()", "type">;
|
||||
|
||||
|
|
|
@ -246,6 +246,25 @@ def OperandOneAndResultHaveSameType :
|
|||
let results = (outs AnyTensor:$res);
|
||||
}
|
||||
|
||||
def OperandsHaveSameRank :
|
||||
TEST_Op<"operands_have_same_rank", [AllRanksMatch<["x", "y"]>]> {
|
||||
let arguments = (ins AnyTensor:$x, AnyTensor:$y);
|
||||
}
|
||||
|
||||
def Operand1AndResultHaveSameRank :
|
||||
TEST_Op<"operand1_and_result_have_same_rank",
|
||||
[AllRanksMatch<["x", "res"]>]> {
|
||||
let arguments = (ins AnyTensor:$x, AnyTensor:$y);
|
||||
let results = (outs AnyTensor:$res);
|
||||
}
|
||||
|
||||
def Operand1AndResultHaveSameElementCount :
|
||||
TEST_Op<"operand1_and_result_have_same_element_count",
|
||||
[AllElementCountsMatch<["x", "res"]>]> {
|
||||
let arguments = (ins AnyTensor:$x, AnyTensor:$y);
|
||||
let results = (outs AnyTensor:$res);
|
||||
}
|
||||
|
||||
def IfFirstOperandIsNoneThenSoIsSecond :
|
||||
TEST_Op<"if_first_operand_is_none_then_so_is_second", [PredOpTrait<
|
||||
"has either both none type operands or first is not none",
|
||||
|
|
|
@ -239,6 +239,65 @@ func @same_types(%arg0: tensor<* x i32>, %arg1: tensor<* x f32>) {
|
|||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: operands_have_same_rank_success
|
||||
func @operands_have_same_rank_success(%arg0: tensor<1xi32>, %arg1: tensor<2xf32>) {
|
||||
"test.operands_have_same_rank"(%arg0, %arg1) : (tensor<1xi32>, tensor<2xf32>) -> ()
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @operands_have_same_rank_failure(%arg0: tensor<1xi32>, %arg1: tensor<1x2xf32>) {
|
||||
// expected-error@+1 {{all of {x, y} have same rank}}
|
||||
"test.operands_have_same_rank"(%arg0, %arg1) : (tensor<1xi32>, tensor<1x2xf32>) -> ()
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: operand1_and_result_have_same_rank_success
|
||||
func @operand1_and_result_have_same_rank_success(%arg0: tensor<1xi32>, %arg1: tensor<1x2xf32>, %arg3: tensor<1x2xi32>) {
|
||||
"test.operand1_and_result_have_same_rank"(%arg0, %arg1) : (tensor<1xi32>, tensor<1x2xf32>) -> (tensor<3xf32>)
|
||||
"test.operand1_and_result_have_same_rank"(%arg3, %arg1) : (tensor<1x2xi32>, tensor<1x2xf32>) -> (tensor<3x3xf64>)
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @operand1_and_result_have_same_rank_failure(%arg0: tensor<1xi32>, %arg1: tensor<1x2xf32>) {
|
||||
// expected-error@+1 {{all of {x, res} have same rank}}
|
||||
"test.operand1_and_result_have_same_rank"(%arg0, %arg1) : (tensor<1xi32>, tensor<1x2xf32>) -> (tensor<i32>)
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @operand1_and_result_have_same_rank_failure(%arg0: tensor<1x2xi32>, %arg1: tensor<1x2xf32>) {
|
||||
// expected-error@+1 {{all of {x, res} have same rank}}
|
||||
"test.operand1_and_result_have_same_rank"(%arg0, %arg1) : (tensor<1x2xi32>, tensor<1x2xf32>) -> (tensor<3xi32>)
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: operand1_and_result_have_same_element_count_success
|
||||
func @operand1_and_result_have_same_element_count_success(%arg0: tensor<36xi32>, %arg1: tensor<1x2xf32>, %arg3: tensor<f32>) {
|
||||
"test.operand1_and_result_have_same_element_count"(%arg0, %arg1) : (tensor<36xi32>, tensor<1x2xf32>) -> (tensor<3x4x3xf32>)
|
||||
"test.operand1_and_result_have_same_element_count"(%arg0, %arg1) : (tensor<36xi32>, tensor<1x2xf32>) -> (tensor<3x12xf64>)
|
||||
"test.operand1_and_result_have_same_element_count"(%arg3, %arg1) : (tensor<f32>, tensor<1x2xf32>) -> (tensor<1x1x1xi32>)
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @operand1_and_result_have_same_element_count_failure(%arg0: tensor<1xi32>, %arg1: tensor<1x2xf32>) {
|
||||
// expected-error@+1 {{all of {x, res} have same element count}}
|
||||
"test.operand1_and_result_have_same_element_count"(%arg0, %arg1) : (tensor<1xi32>, tensor<1x2xf32>) -> (tensor<2xi32>)
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @same_types_element_mismatch(%arg0: tensor<* x i32>, %arg1: tensor<* x f32>) {
|
||||
// expected-error@+1 {{all of {x, res} have same type}}
|
||||
%0 = "test.operand_one_and_result_have_same_type"(%arg0, %arg1) : (tensor<* x i32>, tensor<* x f32>) -> tensor<* x f32>
|
||||
|
|
Loading…
Reference in New Issue