From a260436714b35584bbc1cdd39834b5556b991178 Mon Sep 17 00:00:00 2001 From: Geoffrey Martin-Noble Date: Fri, 13 Sep 2019 13:21:57 -0700 Subject: [PATCH] Add tablegen class for memrefs with rank constraints PiperOrigin-RevId: 268968004 --- mlir/include/mlir/IR/OpBase.td | 23 ++++++++++++++--------- mlir/test/lib/TestDialect/TestOps.td | 4 ++++ mlir/test/mlir-tblgen/types.mlir | 27 +++++++++++++++++++++++++++ 3 files changed, 45 insertions(+), 9 deletions(-) diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td index cee4c1d3aab2..5467e2f78215 100644 --- a/mlir/include/mlir/IR/OpBase.td +++ b/mlir/include/mlir/IR/OpBase.td @@ -363,6 +363,15 @@ class ShapedContainerType allowedTypes, Pred containerPred, string de ContainerType, containerPred, "$_self.cast().getElementType()", descr>; +// Whether a shaped type is ranked. +def HasRankPred : CPred<"$_self.cast().hasRank()">; + +// Whether a shaped type has one of the specified ranks. +class HasAnyRankOfPred ranks> : And<[ + HasRankPred, + Or().getRank() == " # rank>)>]>; + // Vector types. class VectorOf allowedTypes> : @@ -396,15 +405,6 @@ def F16Tensor : TensorOf<[F16]>; def F32Tensor : TensorOf<[F32]>; def F64Tensor : TensorOf<[F64]>; -// Whether a shaped type is ranked. -def HasRankPred : CPred<"$_self.cast().hasRank()">; - -// Whether a shaped type has one of the specified ranks. -class HasAnyRankOfPred ranks> : And<[ - HasRankPred, - Or().getRank() == " # rank>)>]>; - // Ranked tensor type with one of the specified types and ranks. class TensorRankOf allowedTypes, list ranks> : Type.predicate, HasAnyRankOfPred]>, @@ -439,6 +439,11 @@ def F32MemRef : MemRefOf<[F32]>; def F64MemRef : MemRefOf<[F64]>; // TODO(b/130064155) Have an easy way to add another constraint to a type. +class MemRefRankOf allowedTypes, list ranks> : + Type.predicate, HasAnyRankOfPred]>, + StrJoin.result # " " # + MemRefOf.description>; + class StaticShapeMemRefOf allowedTypes> : Type.predicate, HasStaticShapePred]>, "statically shaped " # MemRefOf.description>; diff --git a/mlir/test/lib/TestDialect/TestOps.td b/mlir/test/lib/TestDialect/TestOps.td index 1fda7d41356f..fb2c2b5c1976 100644 --- a/mlir/test/lib/TestDialect/TestOps.td +++ b/mlir/test/lib/TestDialect/TestOps.td @@ -50,6 +50,10 @@ def TakesStaticMemRefOp : TEST_Op<"takes_static_memref"> { let arguments = (ins AnyStaticShapeMemRef:$x); } +def RankLessThan2I8F32MemRefOp : TEST_Op<"rank_less_than_2_I8_F32_memref"> { + let results = (outs MemRefRankOf<[I8, F32], [0, 1]>); +} + def NDTensorOfOp : TEST_Op<"nd_tensor_of"> { let arguments = (ins 0DTensorOf<[F32]>:$arg0, diff --git a/mlir/test/mlir-tblgen/types.mlir b/mlir/test/mlir-tblgen/types.mlir index 6f4dfbb1fbcb..7050da17bcfa 100644 --- a/mlir/test/mlir-tblgen/types.mlir +++ b/mlir/test/mlir-tblgen/types.mlir @@ -81,6 +81,33 @@ func @nested_tuple_multi_level_wrong_type() { // ----- +// CHECK-LABEL: func @rank_less_than_2_I8_F32_memref_success +func @rank_less_than_2_I8_F32_memref_success() { + "test.rank_less_than_2_I8_F32_memref"() : () -> (memref) + "test.rank_less_than_2_I8_F32_memref"() : () -> (memref<3xi8>) + "test.rank_less_than_2_I8_F32_memref"() : () -> (memref) + "test.rank_less_than_2_I8_F32_memref"() : () -> (memref<1xf32>) + return +} + +// ----- + +func @rank_less_than_2_I8_F32_memref_bad_type() { + // expected-error@+1 {{must be 0D/1D memref of 8-bit integer or 32-bit float values}} + "test.rank_less_than_2_I8_F32_memref"() : () -> (memref) + return +} + +// ----- + +func @rank_less_than_2_I8_F32_memref_bad_rank() { + // expected-error@+1 {{must be 0D/1D memref of 8-bit integer or 32-bit float values}} + "test.rank_less_than_2_I8_F32_memref"() : () -> (memref<1x2xi8>) + return +} + +// ----- + func @nd_tensor_of_success(%arg0: tensor, %arg1: tensor<10xf32>, %arg2: tensor<20x30xi16>, %arg3: tensor<40x50x60xi16>, %arg4: tensor<70x80x90x100xi16>) { "test.nd_tensor_of"(%arg0, %arg1, %arg2, %arg3, %arg4) : (tensor, tensor<10xf32>, tensor<20x30xi16>, tensor<40x50x60xi16>, tensor<70x80x90x100xi16>) -> () return