Add tablegen class for memrefs with rank constraints

PiperOrigin-RevId: 268968004
This commit is contained in:
Geoffrey Martin-Noble 2019-09-13 13:21:57 -07:00 committed by A. Unique TensorFlower
parent 8a1cdeb31b
commit a260436714
3 changed files with 45 additions and 9 deletions

View File

@ -363,6 +363,15 @@ class ShapedContainerType<list<Type> allowedTypes, Pred containerPred, string de
ContainerType<AnyTypeOf<allowedTypes>, containerPred,
"$_self.cast<ShapedType>().getElementType()", descr>;
// Whether a shaped type is ranked.
def HasRankPred : CPred<"$_self.cast<ShapedType>().hasRank()">;
// Whether a shaped type has one of the specified ranks.
class HasAnyRankOfPred<list<int> ranks> : And<[
HasRankPred,
Or<!foreach(rank, ranks,
CPred<"$_self.cast<ShapedType>().getRank() == " # rank>)>]>;
// Vector types.
class VectorOf<list<Type> allowedTypes> :
@ -396,15 +405,6 @@ def F16Tensor : TensorOf<[F16]>;
def F32Tensor : TensorOf<[F32]>;
def F64Tensor : TensorOf<[F64]>;
// Whether a shaped type is ranked.
def HasRankPred : CPred<"$_self.cast<ShapedType>().hasRank()">;
// Whether a shaped type has one of the specified ranks.
class HasAnyRankOfPred<list<int> ranks> : And<[
HasRankPred,
Or<!foreach(rank, ranks,
CPred<"$_self.cast<ShapedType>().getRank() == " # rank>)>]>;
// Ranked tensor type with one of the specified types and ranks.
class TensorRankOf<list<Type> allowedTypes, list<int> ranks> :
Type<And<[TensorOf<allowedTypes>.predicate, HasAnyRankOfPred<ranks>]>,
@ -439,6 +439,11 @@ def F32MemRef : MemRefOf<[F32]>;
def F64MemRef : MemRefOf<[F64]>;
// TODO(b/130064155) Have an easy way to add another constraint to a type.
class MemRefRankOf<list<Type> allowedTypes, list<int> ranks> :
Type<And<[MemRefOf<allowedTypes>.predicate, HasAnyRankOfPred<ranks>]>,
StrJoin<!foreach(rank, ranks, rank # "D"), "/">.result # " " #
MemRefOf<allowedTypes>.description>;
class StaticShapeMemRefOf<list<Type> allowedTypes>
: Type<And<[MemRefOf<allowedTypes>.predicate, HasStaticShapePred]>,
"statically shaped " # MemRefOf<allowedTypes>.description>;

View File

@ -50,6 +50,10 @@ def TakesStaticMemRefOp : TEST_Op<"takes_static_memref"> {
let arguments = (ins AnyStaticShapeMemRef:$x);
}
def RankLessThan2I8F32MemRefOp : TEST_Op<"rank_less_than_2_I8_F32_memref"> {
let results = (outs MemRefRankOf<[I8, F32], [0, 1]>);
}
def NDTensorOfOp : TEST_Op<"nd_tensor_of"> {
let arguments = (ins
0DTensorOf<[F32]>:$arg0,

View File

@ -81,6 +81,33 @@ func @nested_tuple_multi_level_wrong_type() {
// -----
// CHECK-LABEL: func @rank_less_than_2_I8_F32_memref_success
func @rank_less_than_2_I8_F32_memref_success() {
"test.rank_less_than_2_I8_F32_memref"() : () -> (memref<i8>)
"test.rank_less_than_2_I8_F32_memref"() : () -> (memref<3xi8>)
"test.rank_less_than_2_I8_F32_memref"() : () -> (memref<f32>)
"test.rank_less_than_2_I8_F32_memref"() : () -> (memref<1xf32>)
return
}
// -----
func @rank_less_than_2_I8_F32_memref_bad_type() {
// expected-error@+1 {{must be 0D/1D memref of 8-bit integer or 32-bit float values}}
"test.rank_less_than_2_I8_F32_memref"() : () -> (memref<i16>)
return
}
// -----
func @rank_less_than_2_I8_F32_memref_bad_rank() {
// expected-error@+1 {{must be 0D/1D memref of 8-bit integer or 32-bit float values}}
"test.rank_less_than_2_I8_F32_memref"() : () -> (memref<1x2xi8>)
return
}
// -----
func @nd_tensor_of_success(%arg0: tensor<f32>, %arg1: tensor<10xf32>, %arg2: tensor<20x30xi16>, %arg3: tensor<40x50x60xi16>, %arg4: tensor<70x80x90x100xi16>) {
"test.nd_tensor_of"(%arg0, %arg1, %arg2, %arg3, %arg4) : (tensor<f32>, tensor<10xf32>, tensor<20x30xi16>, tensor<40x50x60xi16>, tensor<70x80x90x100xi16>) -> ()
return