From e12db3ed997de473b2b7189781dbec7a239a3994 Mon Sep 17 00:00:00 2001 From: Stephan Herhut Date: Thu, 30 Jul 2020 14:02:46 +0200 Subject: [PATCH] [mlir] Allow index as element type of memref Differential Revision: https://reviews.llvm.org/D84934 --- mlir/docs/Rationale/Rationale.md | 27 +++++-------------- mlir/lib/IR/StandardTypes.cpp | 2 +- mlir/lib/Parser/TypeParser.cpp | 2 +- .../StandardToLLVM/convert-to-llvmir.mlir | 14 ++++++++++ mlir/test/IR/invalid.mlir | 4 --- mlir/test/IR/parser.mlir | 3 +++ 6 files changed, 26 insertions(+), 26 deletions(-) diff --git a/mlir/docs/Rationale/Rationale.md b/mlir/docs/Rationale/Rationale.md index 22e21383e903..e906559a14a7 100644 --- a/mlir/docs/Rationale/Rationale.md +++ b/mlir/docs/Rationale/Rationale.md @@ -202,32 +202,19 @@ and described in interest [starts here](https://www.google.com/url?q=https://youtu.be/Ntj8ab-5cvE?t%3D596&sa=D&ust=1529450150971000&usg=AFQjCNFQHEWL7m8q3eO-1DiKw9zqC2v24Q). -### Index type disallowed in vector/memref types +### Index type disallowed in vector types -Index types are not allowed as elements of `vector` and `memref` types. Index +Index types are not allowed as elements of `vector` types. Index types are intended to be used for platform-specific "size" values and may appear in subscripts, sizes of aggregate types and affine expressions. They are also tightly coupled with `affine.apply` and affine.load/store operations; having `index` type is a necessary precondition of a value to be acceptable by these -operations. While it may be useful to have `memref` to express indirect -accesses, e.g. sparse matrix manipulations or lookup tables, it creates problems -MLIR is not ready to address yet. MLIR needs to internally store constants of -aggregate types and emit code operating on values of those types, which are -subject to target-specific size and alignment constraints. Since MLIR does not -have a target description mechanism at the moment, it cannot reliably emit such -code. Moreover, some platforms may not support vectors of type equivalent to -`index`. +operations. -Indirect access use cases can be alternatively supported by providing and -`index_cast` instruction that allows for conversion between `index` and -fixed-width integer types, at the SSA value level. It has an additional benefit -of supporting smaller integer types, e.g. `i8` or `i16`, for small indices -instead of (presumably larger) `index` type. - -Index types are allowed as element types of `tensor` types. The `tensor` type -specifically abstracts the target-specific aspects that intersect with the -code-generation-related/lowering-related concerns explained above. In fact, the -`tensor` type even allows dialect-specific types as element types. +We allow `index` types in tensors and memrefs as a code generation strategy has +to map `index` to an implementation type and hence needs to be able to +materialize corresponding values. However, the target might lack support for +`vector` values with the target specfic equivalent of the `index` type. ### Bit width of a non-primitive type and `index` is undefined diff --git a/mlir/lib/IR/StandardTypes.cpp b/mlir/lib/IR/StandardTypes.cpp index 5a9d22148b76..2d1f8d8eb6f0 100644 --- a/mlir/lib/IR/StandardTypes.cpp +++ b/mlir/lib/IR/StandardTypes.cpp @@ -398,7 +398,7 @@ MemRefType MemRefType::getImpl(ArrayRef shape, Type elementType, auto *context = elementType.getContext(); // Check that memref is formed from allowed types. - if (!elementType.isIntOrFloat() && + if (!elementType.isIntOrIndexOrFloat() && !elementType.isa()) return emitOptionalError(location, "invalid memref element type"), MemRefType(); diff --git a/mlir/lib/Parser/TypeParser.cpp b/mlir/lib/Parser/TypeParser.cpp index 9d8d198aa1c8..f5c98f3c6f9d 100644 --- a/mlir/lib/Parser/TypeParser.cpp +++ b/mlir/lib/Parser/TypeParser.cpp @@ -217,7 +217,7 @@ Type Parser::parseMemRefType() { return nullptr; // Check that memref is formed from allowed types. - if (!elementType.isIntOrFloat() && + if (!elementType.isIntOrIndexOrFloat() && !elementType.isa()) return emitError(typeLoc, "invalid memref element type"), nullptr; diff --git a/mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir b/mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir index 2129cf6819a9..c1ec558da86f 100644 --- a/mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir +++ b/mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir @@ -1277,3 +1277,17 @@ func @bfloat(%arg0: bf16) -> bf16 { return %arg0 : bf16 } // CHECK-NEXT: return %{{.*}} : !llvm.bfloat + +// ----- + +// CHECK-LABEL: func @memref_index +// CHECK-SAME: %arg0: !llvm<"i64*">, %arg1: !llvm<"i64*">, +// CHECK-SAME: %arg2: !llvm.i64, %arg3: !llvm.i64, %arg4: !llvm.i64) +// CHECK-SAME: -> !llvm<"{ i64*, i64*, i64, [1 x i64], [1 x i64] }"> +// CHECK32-LABEL: func @memref_index +// CHECK32-SAME: %arg0: !llvm<"i32*">, %arg1: !llvm<"i32*">, +// CHECK32-SAME: %arg2: !llvm.i32, %arg3: !llvm.i32, %arg4: !llvm.i32) +// CHECK32-SAME: -> !llvm<"{ i32*, i32*, i32, [1 x i32], [1 x i32] }"> +func @memref_index(%arg0: memref<32xindex>) -> memref<32xindex> { + return %arg0 : memref<32xindex> +} diff --git a/mlir/test/IR/invalid.mlir b/mlir/test/IR/invalid.mlir index 2d8474c655f6..dcf04735c901 100644 --- a/mlir/test/IR/invalid.mlir +++ b/mlir/test/IR/invalid.mlir @@ -19,10 +19,6 @@ func @nestedtensor(tensor>) -> () // expected-error {{invalid tensor func @indexvector(vector<4 x index>) -> () // expected-error {{vector elements must be int or float type}} -// ----- - -func @indexmemref(memref) -> () // expected-error {{invalid memref element type}} - // ----- // Test no map in memref type. func @memrefs(memref<2x4xi8, >) // expected-error {{expected list element}} diff --git a/mlir/test/IR/parser.mlir b/mlir/test/IR/parser.mlir index 93db23fd5d0d..8d3d161ef27e 100644 --- a/mlir/test/IR/parser.mlir +++ b/mlir/test/IR/parser.mlir @@ -140,6 +140,9 @@ func @memrefs_compose_with_id(memref<2x2xi8, affine_map<(d0, d1) -> (d0, d1)>, func @complex_types(complex) -> complex +// CHECK: func @memref_with_index_elems(memref<1x?xindex>) +func @memref_with_index_elems(memref<1x?xindex>) + // CHECK: func @memref_with_complex_elems(memref<1x?xcomplex>) func @memref_with_complex_elems(memref<1x?xcomplex>)