forked from OSchip/llvm-project
Add trait for specified shapes matching
PiperOrigin-RevId: 274046434
This commit is contained in:
parent
6b1cc3c6ea
commit
736f80d0dd
|
@ -1493,6 +1493,9 @@ def HasNoUseOf: Constraint<
|
|||
class Rank<string name> :
|
||||
StrFunc<"$" # name # ".getType().cast<ShapedType>().getRank()">;
|
||||
|
||||
class Shape<string name> :
|
||||
StrFunc<"$" # name # ".getType().cast<ShapedType>().getShape()">;
|
||||
|
||||
class ElementCount<string name> :
|
||||
StrFunc<"$" # name # ".getType().cast<ShapedType>().getNumElements()">;
|
||||
|
||||
|
@ -1525,6 +1528,9 @@ class AllElementTypesMatch<list<string> names> :
|
|||
class AllRanksMatch<list<string> names> :
|
||||
AllMatchSameOperatorTrait<names, Rank<"_self">.result, "rank">;
|
||||
|
||||
class AllShapesMatch<list<string> names> :
|
||||
AllMatchSameOperatorTrait<names, Shape<"_self">.result, "shape">;
|
||||
|
||||
class AllTypesMatch<list<string> names> :
|
||||
AllMatchSameOperatorTrait<names, "$_self.getType()", "type">;
|
||||
|
||||
|
|
|
@ -309,6 +309,13 @@ def OperandZeroAndResultHaveSameRank :
|
|||
let results = (outs AnyShaped:$res);
|
||||
}
|
||||
|
||||
def OperandZeroAndResultHaveSameShape :
|
||||
TEST_Op<"operand0_and_result_have_same_shape",
|
||||
[AllShapesMatch<["x", "res"]>]> {
|
||||
let arguments = (ins AnyShaped:$x, AnyShaped:$y);
|
||||
let results = (outs AnyShaped:$res);
|
||||
}
|
||||
|
||||
def OperandZeroAndResultHaveSameElementCount :
|
||||
TEST_Op<"operand0_and_result_have_same_element_count",
|
||||
[AllElementCountsMatch<["x", "res"]>]> {
|
||||
|
|
|
@ -289,6 +289,24 @@ func @same_rank_failure(%arg0: tensor<1x2xi32>, %arg1: tensor<1x2xf32>) {
|
|||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: same_shape_success
|
||||
func @same_shape_success(%t2x3: tensor<2x3xi32>, %m2x3: memref<2x3xf32>, %v2x3 : vector<2x3xi32>, %t4x5 : tensor<4x5xi32>) {
|
||||
"test.operand0_and_result_have_same_shape"(%t2x3, %t4x5) : (tensor<2x3xi32>, tensor<4x5xi32>) -> (tensor<2x3xf32>)
|
||||
"test.operand0_and_result_have_same_shape"(%t2x3, %t4x5) : (tensor<2x3xi32>, tensor<4x5xi32>) -> (memref<2x3xf32>)
|
||||
"test.operand0_and_result_have_same_shape"(%t2x3, %t4x5) : (tensor<2x3xi32>, tensor<4x5xi32>) -> (vector<2x3xf32>)
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @same_shape_failure(%t2x3: tensor<2x3xi32>, %t4x5 : tensor<4x5xi32>) {
|
||||
// expected-error@+1 {{all of {x, res} have same shape}}
|
||||
"test.operand0_and_result_have_same_shape"(%t2x3, %t4x5) : (tensor<2x3xi32>, tensor<4x5xi32>) -> (tensor<1x3xf32>)
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: same_element_count_success
|
||||
func @same_element_count_success(%arg0: tensor<36xi32>, %arg1: tensor<1x2xf32>, %arg3: tensor<f32>) {
|
||||
"test.operand0_and_result_have_same_element_count"(%arg0, %arg1) : (tensor<36xi32>, tensor<1x2xf32>) -> (tensor<3x4x3xf32>)
|
||||
|
|
Loading…
Reference in New Issue