Add type constraints for shaped types with same rank and element count

PiperOrigin-RevId: 269000237
This commit is contained in:
Geoffrey Martin-Noble 2019-09-13 16:05:06 -07:00 committed by A. Unique TensorFlower
parent 113aadddf9
commit efbd3e4610
3 changed files with 85 additions and 0 deletions

View File

@ -1410,9 +1410,16 @@ class AllMatchTrait<list<string> names, string operator, string description> :
"all of {" # StrJoin<names>.result # "} have same " # description,
AllMatchPred<names, operator>>;
class AllElementCountsMatch<list<string> names> :
AllMatchTrait<names, "$_self.getType().cast<ShapedType>().getNumElements()",
"element count">;
class AllElementTypesMatch<list<string> names> :
AllMatchTrait<names, "getElementTypeOrSelf($_self)", "element type">;
class AllRanksMatch<list<string> names> :
AllMatchTrait<names, "$_self.getType().cast<ShapedType>().getRank()", "rank">;
class AllTypesMatch<list<string> names> :
AllMatchTrait<names, "$_self.getType()", "type">;

View File

@ -246,6 +246,25 @@ def OperandOneAndResultHaveSameType :
let results = (outs AnyTensor:$res);
}
def OperandsHaveSameRank :
TEST_Op<"operands_have_same_rank", [AllRanksMatch<["x", "y"]>]> {
let arguments = (ins AnyTensor:$x, AnyTensor:$y);
}
def Operand1AndResultHaveSameRank :
TEST_Op<"operand1_and_result_have_same_rank",
[AllRanksMatch<["x", "res"]>]> {
let arguments = (ins AnyTensor:$x, AnyTensor:$y);
let results = (outs AnyTensor:$res);
}
def Operand1AndResultHaveSameElementCount :
TEST_Op<"operand1_and_result_have_same_element_count",
[AllElementCountsMatch<["x", "res"]>]> {
let arguments = (ins AnyTensor:$x, AnyTensor:$y);
let results = (outs AnyTensor:$res);
}
def IfFirstOperandIsNoneThenSoIsSecond :
TEST_Op<"if_first_operand_is_none_then_so_is_second", [PredOpTrait<
"has either both none type operands or first is not none",

View File

@ -239,6 +239,65 @@ func @same_types(%arg0: tensor<* x i32>, %arg1: tensor<* x f32>) {
// -----
// CHECK-LABEL: operands_have_same_rank_success
func @operands_have_same_rank_success(%arg0: tensor<1xi32>, %arg1: tensor<2xf32>) {
"test.operands_have_same_rank"(%arg0, %arg1) : (tensor<1xi32>, tensor<2xf32>) -> ()
return
}
// -----
func @operands_have_same_rank_failure(%arg0: tensor<1xi32>, %arg1: tensor<1x2xf32>) {
// expected-error@+1 {{all of {x, y} have same rank}}
"test.operands_have_same_rank"(%arg0, %arg1) : (tensor<1xi32>, tensor<1x2xf32>) -> ()
return
}
// -----
// CHECK-LABEL: operand1_and_result_have_same_rank_success
func @operand1_and_result_have_same_rank_success(%arg0: tensor<1xi32>, %arg1: tensor<1x2xf32>, %arg3: tensor<1x2xi32>) {
"test.operand1_and_result_have_same_rank"(%arg0, %arg1) : (tensor<1xi32>, tensor<1x2xf32>) -> (tensor<3xf32>)
"test.operand1_and_result_have_same_rank"(%arg3, %arg1) : (tensor<1x2xi32>, tensor<1x2xf32>) -> (tensor<3x3xf64>)
return
}
// -----
func @operand1_and_result_have_same_rank_failure(%arg0: tensor<1xi32>, %arg1: tensor<1x2xf32>) {
// expected-error@+1 {{all of {x, res} have same rank}}
"test.operand1_and_result_have_same_rank"(%arg0, %arg1) : (tensor<1xi32>, tensor<1x2xf32>) -> (tensor<i32>)
return
}
// -----
func @operand1_and_result_have_same_rank_failure(%arg0: tensor<1x2xi32>, %arg1: tensor<1x2xf32>) {
// expected-error@+1 {{all of {x, res} have same rank}}
"test.operand1_and_result_have_same_rank"(%arg0, %arg1) : (tensor<1x2xi32>, tensor<1x2xf32>) -> (tensor<3xi32>)
return
}
// -----
// CHECK-LABEL: operand1_and_result_have_same_element_count_success
func @operand1_and_result_have_same_element_count_success(%arg0: tensor<36xi32>, %arg1: tensor<1x2xf32>, %arg3: tensor<f32>) {
"test.operand1_and_result_have_same_element_count"(%arg0, %arg1) : (tensor<36xi32>, tensor<1x2xf32>) -> (tensor<3x4x3xf32>)
"test.operand1_and_result_have_same_element_count"(%arg0, %arg1) : (tensor<36xi32>, tensor<1x2xf32>) -> (tensor<3x12xf64>)
"test.operand1_and_result_have_same_element_count"(%arg3, %arg1) : (tensor<f32>, tensor<1x2xf32>) -> (tensor<1x1x1xi32>)
return
}
// -----
func @operand1_and_result_have_same_element_count_failure(%arg0: tensor<1xi32>, %arg1: tensor<1x2xf32>) {
// expected-error@+1 {{all of {x, res} have same element count}}
"test.operand1_and_result_have_same_element_count"(%arg0, %arg1) : (tensor<1xi32>, tensor<1x2xf32>) -> (tensor<2xi32>)
return
}
// -----
func @same_types_element_mismatch(%arg0: tensor<* x i32>, %arg1: tensor<* x f32>) {
// expected-error@+1 {{all of {x, res} have same type}}
%0 = "test.operand_one_and_result_have_same_type"(%arg0, %arg1) : (tensor<* x i32>, tensor<* x f32>) -> tensor<* x f32>