Add trait for specified shapes matching

PiperOrigin-RevId: 274046434
This commit is contained in:
Geoffrey Martin-Noble 2019-10-10 15:01:34 -07:00 committed by Jacques Pienaar
parent 6b1cc3c6ea
commit 736f80d0dd
3 changed files with 31 additions and 0 deletions

View File

@ -1493,6 +1493,9 @@ def HasNoUseOf: Constraint<
class Rank<string name> :
StrFunc<"$" # name # ".getType().cast<ShapedType>().getRank()">;
class Shape<string name> :
StrFunc<"$" # name # ".getType().cast<ShapedType>().getShape()">;
class ElementCount<string name> :
StrFunc<"$" # name # ".getType().cast<ShapedType>().getNumElements()">;
@ -1525,6 +1528,9 @@ class AllElementTypesMatch<list<string> names> :
class AllRanksMatch<list<string> names> :
AllMatchSameOperatorTrait<names, Rank<"_self">.result, "rank">;
class AllShapesMatch<list<string> names> :
AllMatchSameOperatorTrait<names, Shape<"_self">.result, "shape">;
class AllTypesMatch<list<string> names> :
AllMatchSameOperatorTrait<names, "$_self.getType()", "type">;

View File

@ -309,6 +309,13 @@ def OperandZeroAndResultHaveSameRank :
let results = (outs AnyShaped:$res);
}
def OperandZeroAndResultHaveSameShape :
TEST_Op<"operand0_and_result_have_same_shape",
[AllShapesMatch<["x", "res"]>]> {
let arguments = (ins AnyShaped:$x, AnyShaped:$y);
let results = (outs AnyShaped:$res);
}
def OperandZeroAndResultHaveSameElementCount :
TEST_Op<"operand0_and_result_have_same_element_count",
[AllElementCountsMatch<["x", "res"]>]> {

View File

@ -289,6 +289,24 @@ func @same_rank_failure(%arg0: tensor<1x2xi32>, %arg1: tensor<1x2xf32>) {
// -----
// CHECK-LABEL: same_shape_success
func @same_shape_success(%t2x3: tensor<2x3xi32>, %m2x3: memref<2x3xf32>, %v2x3 : vector<2x3xi32>, %t4x5 : tensor<4x5xi32>) {
"test.operand0_and_result_have_same_shape"(%t2x3, %t4x5) : (tensor<2x3xi32>, tensor<4x5xi32>) -> (tensor<2x3xf32>)
"test.operand0_and_result_have_same_shape"(%t2x3, %t4x5) : (tensor<2x3xi32>, tensor<4x5xi32>) -> (memref<2x3xf32>)
"test.operand0_and_result_have_same_shape"(%t2x3, %t4x5) : (tensor<2x3xi32>, tensor<4x5xi32>) -> (vector<2x3xf32>)
return
}
// -----
func @same_shape_failure(%t2x3: tensor<2x3xi32>, %t4x5 : tensor<4x5xi32>) {
// expected-error@+1 {{all of {x, res} have same shape}}
"test.operand0_and_result_have_same_shape"(%t2x3, %t4x5) : (tensor<2x3xi32>, tensor<4x5xi32>) -> (tensor<1x3xf32>)
return
}
// -----
// CHECK-LABEL: same_element_count_success
func @same_element_count_success(%arg0: tensor<36xi32>, %arg1: tensor<1x2xf32>, %arg3: tensor<f32>) {
"test.operand0_and_result_have_same_element_count"(%arg0, %arg1) : (tensor<36xi32>, tensor<1x2xf32>) -> (tensor<3x4x3xf32>)