From ddf2d62c7dddf1e4a9012d96819ff1ed005fbb05 Mon Sep 17 00:00:00 2001 From: Michal Terepeta Date: Wed, 17 Nov 2021 14:57:55 +0000 Subject: [PATCH] [mlir][Vector] First step for 0D vector type There seems to be a consensus that we should allow 0D vectors: https://llvm.discourse.group/t/should-we-have-0-d-vectors/3097 This commit is only the first step: it changes the verifier and the parser to allow vectors like `vector` (but does not allow explicit 0 dimensions, i.e., `vector<0xf32>` is not allowed). Reviewed By: nicolasvasilache Differential Revision: https://reviews.llvm.org/D114086 --- mlir/include/mlir/IR/BuiltinTypes.td | 6 +++--- mlir/lib/IR/BuiltinTypes.cpp | 3 --- mlir/lib/Parser/TypeParser.cpp | 6 +----- mlir/test/IR/invalid.mlir | 2 +- mlir/test/IR/parser.mlir | 4 ++-- 5 files changed, 7 insertions(+), 14 deletions(-) diff --git a/mlir/include/mlir/IR/BuiltinTypes.td b/mlir/include/mlir/IR/BuiltinTypes.td index 41577193c778..af47f1d0f16c 100644 --- a/mlir/include/mlir/IR/BuiltinTypes.td +++ b/mlir/include/mlir/IR/BuiltinTypes.td @@ -895,7 +895,7 @@ def Builtin_Vector : Builtin_Type<"Vector", [ vector-type ::= `vector` `<` static-dimension-list vector-element-type `>` vector-element-type ::= float-type | integer-type | index-type - static-dimension-list ::= (decimal-literal `x`)+ + static-dimension-list ::= (decimal-literal `x`)* ``` The vector type represents a SIMD style vector, used by target-specific @@ -903,13 +903,13 @@ def Builtin_Vector : Builtin_Type<"Vector", [ vector<16 x f32>) we also support multidimensional registers on targets that support them (like TPUs). - Vector shapes must be positive decimal integers. + Vector shapes must be positive decimal integers. 0D vectors are allowed by + omitting the dimension: `vector`. Note: hexadecimal integer literals are not allowed in vector type declarations, `vector<0x42xi32>` is invalid because it is interpreted as a 2D vector with shape `(0, 42)` and zero shapes are not allowed. - Examples: ```mlir diff --git a/mlir/lib/IR/BuiltinTypes.cpp b/mlir/lib/IR/BuiltinTypes.cpp index e71f335d882c..64dceaaa4480 100644 --- a/mlir/lib/IR/BuiltinTypes.cpp +++ b/mlir/lib/IR/BuiltinTypes.cpp @@ -441,9 +441,6 @@ bool ShapedType::hasStaticShape(ArrayRef shape) const { LogicalResult VectorType::verify(function_ref emitError, ArrayRef shape, Type elementType) { - if (shape.empty()) - return emitError() << "vector types must have at least one dimension"; - if (!isValidElementType(elementType)) return emitError() << "vector elements must be int/index/float type but got " diff --git a/mlir/lib/Parser/TypeParser.cpp b/mlir/lib/Parser/TypeParser.cpp index 256d6c0a96a2..57442e76360f 100644 --- a/mlir/lib/Parser/TypeParser.cpp +++ b/mlir/lib/Parser/TypeParser.cpp @@ -442,9 +442,7 @@ Type Parser::parseTupleType() { /// Parse a vector type. /// -/// vector-type ::= `vector` `<` non-empty-static-dimension-list type `>` -/// non-empty-static-dimension-list ::= decimal-literal `x` -/// static-dimension-list +/// vector-type ::= `vector` `<` static-dimension-list type `>` /// static-dimension-list ::= (decimal-literal `x`)* /// VectorType Parser::parseVectorType() { @@ -456,8 +454,6 @@ VectorType Parser::parseVectorType() { SmallVector dimensions; if (parseDimensionListRanked(dimensions, /*allowDynamic=*/false)) return nullptr; - if (dimensions.empty()) - return (emitError("expected dimension size in vector type"), nullptr); if (any_of(dimensions, [](int64_t i) { return i <= 0; })) return emitError(getToken().getLoc(), "vector types must have positive constant sizes"), diff --git a/mlir/test/IR/invalid.mlir b/mlir/test/IR/invalid.mlir index b9187dc96673..01082fc336ba 100644 --- a/mlir/test/IR/invalid.mlir +++ b/mlir/test/IR/invalid.mlir @@ -949,7 +949,7 @@ func @zero_in_vector_type() -> vector<1x0xi32> // ----- -// expected-error @+1 {{expected dimension size in vector type}} +// expected-error @+1 {{expected non-function type}} func @negative_vector_size() -> vector<-1xi32> // ----- diff --git a/mlir/test/IR/parser.mlir b/mlir/test/IR/parser.mlir index d1f4ab64f7f3..74c7320c1ab9 100644 --- a/mlir/test/IR/parser.mlir +++ b/mlir/test/IR/parser.mlir @@ -67,8 +67,8 @@ func private @uint_types(ui2, ui4) -> (ui7, ui1023) // CHECK: func private @float_types(f80, f128) func private @float_types(f80, f128) -// CHECK: func private @vectors(vector<1xf32>, vector<2x4xf32>) -func private @vectors(vector<1 x f32>, vector<2x4xf32>) +// CHECK: func private @vectors(vector, vector<1xf32>, vector<2x4xf32>) +func private @vectors(vector, vector<1 x f32>, vector<2x4xf32>) // CHECK: func private @tensors(tensor<*xf32>, tensor<*xvector<2x4xf32>>, tensor<1x?x4x?x?xi32>, tensor) func private @tensors(tensor<* x f32>, tensor<* x vector<2x4xf32>>,