[mlir] Make MemRef element type extensible

Historically, MemRef only supported a restricted list of element types that
were known to be storable in memory. This is unnecessarily restrictive given
the open nature of MLIR's type system. Allow types to opt into being used as
MemRef elements by implementing a type interface. For now, the interface is
merely a declaration with no methods. Later, methods to query, e.g., the type
size or whether a type can alias elements of another type may be added.

Harden the "standard"-to-LLVM conversion against memrefs with non-builtin
types.

See https://llvm.discourse.group/t/rfc-memref-of-custom-types/3558.

Depends On D103826

Reviewed By: rriddle

Differential Revision: https://reviews.llvm.org/D103827
This commit is contained in:
Alex Zinenko 2021-06-07 18:33:29 +02:00
parent 3c70a82e28
commit ada9aa5a22
11 changed files with 90 additions and 3 deletions

View File

@ -30,3 +30,7 @@ Operations.
## Types
[include "Dialects/BuiltinTypes.md"]
## Type Interfaces
[include "Dialects/BuiltinTypeInterfaces.md"]

View File

@ -192,6 +192,12 @@ public:
#define GET_TYPEDEF_CLASSES
#include "mlir/IR/BuiltinTypes.h.inc"
//===----------------------------------------------------------------------===//
// Tablegen Interface Declarations
//===----------------------------------------------------------------------===//
#include "mlir/IR/BuiltinTypeInterfaces.h.inc"
namespace mlir {
//===----------------------------------------------------------------------===//
// MemRefType
@ -266,7 +272,8 @@ inline bool BaseMemRefType::classof(Type type) {
}
inline bool BaseMemRefType::isValidElementType(Type type) {
return type.isIntOrIndexOrFloat() || type.isa<ComplexType, VectorType>();
return type.isIntOrIndexOrFloat() || type.isa<ComplexType, VectorType>() ||
type.isa<MemRefElementTypeInterface>();
}
inline bool FloatType::classof(Type type) {

View File

@ -248,6 +248,31 @@ def Builtin_Integer : Builtin_Type<"Integer"> {
}];
}
//===----------------------------------------------------------------------===//
// MemRefElementTypeInterface
//===----------------------------------------------------------------------===//
def MemRefElementTypeInterface : TypeInterface<"MemRefElementTypeInterface"> {
let cppNamespace = "::mlir";
let description = [{
Indication that this type can be used as element in memref types.
Implementing this interface establishes a contract between this type and the
memref type indicating that this type can be used as element of ranked or
unranked memrefs. The type is expected to:
- model an entity stored in memory;
- have non-zero size.
For example, scalar values such as integers can implement this interface,
but indicator types such as `void` or `unit` should not.
The interface currently has no methods and is used by types to opt into
being memref elements. This may change in the future, in particular to
require types to provide their size or alignment given a data layout.
}];
}
//===----------------------------------------------------------------------===//
// MemRefType
//===----------------------------------------------------------------------===//
@ -282,6 +307,14 @@ def Builtin_MemRef : Builtin_Type<"MemRef", "BaseMemRefType"> {
on the rank. Other uses of this type are disallowed or will have undefined
behavior.
Are accepted as elements:
- built-in integer types;
- built-in index type;
- built-in floating point types;
- built-in vector types with elements of the above types;
- any other type implementing `MemRefElementTypeInterface`.
##### Codegen of Unranked Memref
Using unranked memref in codegen besides the case mentioned above is highly

View File

@ -24,6 +24,8 @@ add_public_tablegen_target(MLIRBuiltinOpsIncGen)
set(LLVM_TARGET_DEFINITIONS BuiltinTypes.td)
mlir_tablegen(BuiltinTypes.h.inc -gen-typedef-decls)
mlir_tablegen(BuiltinTypes.cpp.inc -gen-typedef-defs)
mlir_tablegen(BuiltinTypeInterfaces.h.inc -gen-type-interface-decls)
mlir_tablegen(BuiltinTypeInterfaces.cpp.inc -gen-type-interface-defs)
add_public_tablegen_target(MLIRBuiltinTypesIncGen)
set(LLVM_TARGET_DEFINITIONS TensorEncoding.td)
@ -35,3 +37,4 @@ add_mlir_doc(BuiltinAttributes BuiltinAttributes Dialects/ -gen-attrdef-doc)
add_mlir_doc(BuiltinLocationAttributes BuiltinLocationAttributes Dialects/ -gen-attrdef-doc)
add_mlir_doc(BuiltinOps BuiltinOps Dialects/ -gen-op-doc)
add_mlir_doc(BuiltinTypes BuiltinTypes Dialects/ -gen-typedef-doc)
add_mlir_doc(BuiltinTypes BuiltinTypeInterfaces Dialects/ -gen-type-interface-docs)

View File

@ -349,6 +349,8 @@ Type LLVMTypeConverter::convertMemRefType(MemRefType type) {
// unpack the `sizes` and `strides` arrays.
SmallVector<Type, 5> types =
getMemRefDescriptorFields(type, /*unpackAggregates=*/false);
if (types.empty())
return {};
return LLVM::LLVMStructType::getLiteral(&getContext(), types);
}
@ -368,6 +370,8 @@ SmallVector<Type, 2> LLVMTypeConverter::getUnrankedMemRefDescriptorFields() {
}
Type LLVMTypeConverter::convertUnrankedMemRefType(UnrankedMemRefType type) {
if (!convertType(type.getElementType()))
return {};
return LLVM::LLVMStructType::getLiteral(&getContext(),
getUnrankedMemRefDescriptorFields());
}

View File

@ -31,6 +31,12 @@ using namespace mlir::detail;
#define GET_TYPEDEF_CLASSES
#include "mlir/IR/BuiltinTypes.cpp.inc"
//===----------------------------------------------------------------------===//
/// Tablegen Interface Definitions
//===----------------------------------------------------------------------===//
#include "mlir/IR/BuiltinTypeInterfaces.cpp.inc"
//===----------------------------------------------------------------------===//
// BuiltinDialect
//===----------------------------------------------------------------------===//

View File

@ -427,3 +427,23 @@ module attributes { dlti.dl_spec = #dlti.dl_spec<#dlti.dl_entry<index, 32>> } {
return
}
}
// -----
// Should not convert memrefs with unsupported types in any convention.
// CHECK: @unsupported_memref_element_type
// CHECK-SAME: memref<
// CHECK-NOT: !llvm.struct
// BAREPTR: @unsupported_memref_element_type
// BAREPTR-SAME: memref<
// BAREPTR-NOT: !llvm.ptr
func private @unsupported_memref_element_type() -> memref<42 x !test.memref_element>
// CHECK: @unsupported_unranked_memref_element_type
// CHECK-SAME: memref<
// CHECK-NOT: !llvm.struct
// BAREPTR: @unsupported_unranked_memref_element_type
// BAREPTR-SAME: memref<
// BAREPTR-NOT: !llvm.ptr
func private @unsupported_unranked_memref_element_type() -> memref<* x !test.memref_element>

View File

@ -6,3 +6,4 @@ func private @unsupported_signature() -> tensor<10 x i32>
// -----
func private @partially_supported_signature() -> (vector<10 x i32>, tensor<10 x i32>)

View File

@ -178,6 +178,9 @@ func private @memref_with_complex_elems(memref<1x?xcomplex<f32>>)
// CHECK: func private @memref_with_vector_elems(memref<1x?xvector<10xf32>>)
func private @memref_with_vector_elems(memref<1x?xvector<10xf32>>)
// CHECK: func private @memref_with_custom_elem(memref<1x?x!test.memref_element>)
func private @memref_with_custom_elem(memref<1x?x!test.memref_element>)
// CHECK: func private @unranked_memref_with_complex_elems(memref<*xcomplex<f32>>)
func private @unranked_memref_with_complex_elems(memref<*xcomplex<f32>>)

View File

@ -17,8 +17,8 @@ mlir_tablegen(TestAttrDefs.cpp.inc -gen-attrdef-defs)
add_public_tablegen_target(MLIRTestAttrDefIncGen)
set(LLVM_TARGET_DEFINITIONS TestTypeDefs.td)
mlir_tablegen(TestTypeDefs.h.inc -gen-typedef-decls)
mlir_tablegen(TestTypeDefs.cpp.inc -gen-typedef-defs)
mlir_tablegen(TestTypeDefs.h.inc -gen-typedef-decls -typedefs-dialect=test)
mlir_tablegen(TestTypeDefs.cpp.inc -gen-typedef-defs -typedefs-dialect=test)
add_public_tablegen_target(MLIRTestTypeDefIncGen)

View File

@ -15,6 +15,7 @@
// To get the test dialect def.
include "TestOps.td"
include "mlir/IR/BuiltinTypes.td"
include "mlir/Interfaces/DataLayoutInterfaces.td"
// All of the types will extend this class.
@ -176,4 +177,9 @@ def TestTypeWithLayoutType : Test_Type<"TestTypeWithLayout", [
}];
}
def TestMemRefElementType : Test_Type<"TestMemRefElementType",
[MemRefElementTypeInterface]> {
let mnemonic = "memref_element";
}
#endif // TEST_TYPEDEFS