Add tablegen class for memrefs with rank constraints
PiperOrigin-RevId: 268968004
parent 8a1cdeb31b
commit a260436714
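In short: this change adds a MemRefRankOf type constraint to the ODS TableGen definitions, mirroring the existing TensorRankOf, so a memref operand or result can be restricted to given element types and a given set of ranks. A minimal usage sketch (the op and argument names here are hypothetical; only the MemRefRankOf constraint comes from this commit):

def Rank1F32MemRefOp : TEST_Op<"rank_1_f32_memref"> {
  // Accepts only a 1-D memref of 32-bit floats (illustrative only).
  let arguments = (ins MemRefRankOf<[F32], [1]>:$input);
}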
@@ -363,6 +363,15 @@ class ShapedContainerType<list<Type> allowedTypes, Pred containerPred, string descr
     ContainerType<AnyTypeOf<allowedTypes>, containerPred,
                   "$_self.cast<ShapedType>().getElementType()", descr>;
 
+// Whether a shaped type is ranked.
+def HasRankPred : CPred<"$_self.cast<ShapedType>().hasRank()">;
+
+// Whether a shaped type has one of the specified ranks.
+class HasAnyRankOfPred<list<int> ranks> : And<[
+    HasRankPred,
+    Or<!foreach(rank, ranks,
+                CPred<"$_self.cast<ShapedType>().getRank() == " # rank>)>]>;
+
 // Vector types.
 
 class VectorOf<list<Type> allowedTypes> :
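For reference, the !foreach above simply expands into one rank comparison per requested rank; e.g. HasAnyRankOfPred<[0, 1]> is roughly equivalent to the hand-written form below (shown only to illustrate the expansion):

And<[HasRankPred,
     Or<[CPred<"$_self.cast<ShapedType>().getRank() == 0">,
         CPred<"$_self.cast<ShapedType>().getRank() == 1">]>]>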
@@ -396,15 +405,6 @@ def F16Tensor : TensorOf<[F16]>;
 def F32Tensor : TensorOf<[F32]>;
 def F64Tensor : TensorOf<[F64]>;
 
-// Whether a shaped type is ranked.
-def HasRankPred : CPred<"$_self.cast<ShapedType>().hasRank()">;
-
-// Whether a shaped type has one of the specified ranks.
-class HasAnyRankOfPred<list<int> ranks> : And<[
-    HasRankPred,
-    Or<!foreach(rank, ranks,
-                CPred<"$_self.cast<ShapedType>().getRank() == " # rank>)>]>;
-
 // Ranked tensor type with one of the specified types and ranks.
 class TensorRankOf<list<Type> allowedTypes, list<int> ranks> :
     Type<And<[TensorOf<allowedTypes>.predicate, HasAnyRankOfPred<ranks>]>,
@@ -439,6 +439,11 @@ def F32MemRef : MemRefOf<[F32]>;
 def F64MemRef : MemRefOf<[F64]>;
 
 // TODO(b/130064155) Have an easy way to add another constraint to a type.
+class MemRefRankOf<list<Type> allowedTypes, list<int> ranks> :
+    Type<And<[MemRefOf<allowedTypes>.predicate, HasAnyRankOfPred<ranks>]>,
+         StrJoin<!foreach(rank, ranks, rank # "D"), "/">.result # " " #
+         MemRefOf<allowedTypes>.description>;
+
 class StaticShapeMemRefOf<list<Type> allowedTypes>
     : Type<And<[MemRefOf<allowedTypes>.predicate, HasStaticShapePred]>,
            "statically shaped " # MemRefOf<allowedTypes>.description>;
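Note that the StrJoin over the rank list is what builds the constraint's human-readable description: MemRefRankOf<[I8, F32], [0, 1]>, as used in the tests below, describes itself as "0D/1D memref of 8-bit integer or 32-bit float values", which is exactly the text the expected-error checks match. If a constraint is reused across ops, it can also be given a name with an ordinary def, much like the AnyStaticShapeMemRef constraint used in the test op below (the alias name here is hypothetical):

def I8OrF32MemRefOfRankUpTo1 : MemRefRankOf<[I8, F32], [0, 1]>;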
@@ -50,6 +50,10 @@ def TakesStaticMemRefOp : TEST_Op<"takes_static_memref"> {
   let arguments = (ins AnyStaticShapeMemRef:$x);
 }
 
+def RankLessThan2I8F32MemRefOp : TEST_Op<"rank_less_than_2_I8_F32_memref"> {
+  let results = (outs MemRefRankOf<[I8, F32], [0, 1]>);
+}
+
 def NDTensorOfOp : TEST_Op<"nd_tensor_of"> {
   let arguments = (ins
     0DTensorOf<[F32]>:$arg0,
@@ -81,6 +81,33 @@ func @nested_tuple_multi_level_wrong_type() {
 
 // -----
 
+// CHECK-LABEL: func @rank_less_than_2_I8_F32_memref_success
+func @rank_less_than_2_I8_F32_memref_success() {
+  "test.rank_less_than_2_I8_F32_memref"() : () -> (memref<i8>)
+  "test.rank_less_than_2_I8_F32_memref"() : () -> (memref<3xi8>)
+  "test.rank_less_than_2_I8_F32_memref"() : () -> (memref<f32>)
+  "test.rank_less_than_2_I8_F32_memref"() : () -> (memref<1xf32>)
+  return
+}
+
+// -----
+
+func @rank_less_than_2_I8_F32_memref_bad_type() {
+  // expected-error@+1 {{must be 0D/1D memref of 8-bit integer or 32-bit float values}}
+  "test.rank_less_than_2_I8_F32_memref"() : () -> (memref<i16>)
+  return
+}
+
+// -----
+
+func @rank_less_than_2_I8_F32_memref_bad_rank() {
+  // expected-error@+1 {{must be 0D/1D memref of 8-bit integer or 32-bit float values}}
+  "test.rank_less_than_2_I8_F32_memref"() : () -> (memref<1x2xi8>)
+  return
+}
+
+// -----
+
 func @nd_tensor_of_success(%arg0: tensor<f32>, %arg1: tensor<10xf32>, %arg2: tensor<20x30xi16>, %arg3: tensor<40x50x60xi16>, %arg4: tensor<70x80x90x100xi16>) {
   "test.nd_tensor_of"(%arg0, %arg1, %arg2, %arg3, %arg4) : (tensor<f32>, tensor<10xf32>, tensor<20x30xi16>, tensor<40x50x60xi16>, tensor<70x80x90x100xi16>) -> ()
   return