forked from OSchip/llvm-project
[mlir] Make MemRef element type extensible
Historically, MemRef only supported a restricted list of element types that were known to be storable in memory. This is unnecessarily restrictive given the open nature of MLIR's type system. Allow types to opt into being used as MemRef elements by implementing a type interface. For now, the interface is merely a declaration with no methods. Later, methods to query, e.g., the type size or whether a type can alias elements of another type may be added. Harden the "standard"-to-LLVM conversion against memrefs with non-builtin types. See https://llvm.discourse.group/t/rfc-memref-of-custom-types/3558. Depends On D103826 Reviewed By: rriddle Differential Revision: https://reviews.llvm.org/D103827
This commit is contained in:
parent
3c70a82e28
commit
ada9aa5a22
|
@ -30,3 +30,7 @@ Operations.
|
|||
## Types
|
||||
|
||||
[include "Dialects/BuiltinTypes.md"]
|
||||
|
||||
## Type Interfaces
|
||||
|
||||
[include "Dialects/BuiltinTypeInterfaces.md"]
|
||||
|
|
|
@ -192,6 +192,12 @@ public:
|
|||
#define GET_TYPEDEF_CLASSES
|
||||
#include "mlir/IR/BuiltinTypes.h.inc"
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Tablegen Interface Declarations
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "mlir/IR/BuiltinTypeInterfaces.h.inc"
|
||||
|
||||
namespace mlir {
|
||||
//===----------------------------------------------------------------------===//
|
||||
// MemRefType
|
||||
|
@ -266,7 +272,8 @@ inline bool BaseMemRefType::classof(Type type) {
|
|||
}
|
||||
|
||||
inline bool BaseMemRefType::isValidElementType(Type type) {
|
||||
return type.isIntOrIndexOrFloat() || type.isa<ComplexType, VectorType>();
|
||||
return type.isIntOrIndexOrFloat() || type.isa<ComplexType, VectorType>() ||
|
||||
type.isa<MemRefElementTypeInterface>();
|
||||
}
|
||||
|
||||
inline bool FloatType::classof(Type type) {
|
||||
|
|
|
@ -248,6 +248,31 @@ def Builtin_Integer : Builtin_Type<"Integer"> {
|
|||
}];
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// MemRefElementTypeInterface
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
def MemRefElementTypeInterface : TypeInterface<"MemRefElementTypeInterface"> {
|
||||
let cppNamespace = "::mlir";
|
||||
let description = [{
|
||||
Indication that this type can be used as element in memref types.
|
||||
|
||||
Implementing this interface establishes a contract between this type and the
|
||||
memref type indicating that this type can be used as element of ranked or
|
||||
unranked memrefs. The type is expected to:
|
||||
|
||||
- model an entity stored in memory;
|
||||
- have non-zero size.
|
||||
|
||||
For example, scalar values such as integers can implement this interface,
|
||||
but indicator types such as `void` or `unit` should not.
|
||||
|
||||
The interface currently has no methods and is used by types to opt into
|
||||
being memref elements. This may change in the future, in particular to
|
||||
require types to provide their size or alignment given a data layout.
|
||||
}];
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// MemRefType
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -282,6 +307,14 @@ def Builtin_MemRef : Builtin_Type<"MemRef", "BaseMemRefType"> {
|
|||
on the rank. Other uses of this type are disallowed or will have undefined
|
||||
behavior.
|
||||
|
||||
Are accepted as elements:
|
||||
|
||||
- built-in integer types;
|
||||
- built-in index type;
|
||||
- built-in floating point types;
|
||||
- built-in vector types with elements of the above types;
|
||||
- any other type implementing `MemRefElementTypeInterface`.
|
||||
|
||||
##### Codegen of Unranked Memref
|
||||
|
||||
Using unranked memref in codegen besides the case mentioned above is highly
|
||||
|
|
|
@ -24,6 +24,8 @@ add_public_tablegen_target(MLIRBuiltinOpsIncGen)
|
|||
set(LLVM_TARGET_DEFINITIONS BuiltinTypes.td)
|
||||
mlir_tablegen(BuiltinTypes.h.inc -gen-typedef-decls)
|
||||
mlir_tablegen(BuiltinTypes.cpp.inc -gen-typedef-defs)
|
||||
mlir_tablegen(BuiltinTypeInterfaces.h.inc -gen-type-interface-decls)
|
||||
mlir_tablegen(BuiltinTypeInterfaces.cpp.inc -gen-type-interface-defs)
|
||||
add_public_tablegen_target(MLIRBuiltinTypesIncGen)
|
||||
|
||||
set(LLVM_TARGET_DEFINITIONS TensorEncoding.td)
|
||||
|
@ -35,3 +37,4 @@ add_mlir_doc(BuiltinAttributes BuiltinAttributes Dialects/ -gen-attrdef-doc)
|
|||
add_mlir_doc(BuiltinLocationAttributes BuiltinLocationAttributes Dialects/ -gen-attrdef-doc)
|
||||
add_mlir_doc(BuiltinOps BuiltinOps Dialects/ -gen-op-doc)
|
||||
add_mlir_doc(BuiltinTypes BuiltinTypes Dialects/ -gen-typedef-doc)
|
||||
add_mlir_doc(BuiltinTypes BuiltinTypeInterfaces Dialects/ -gen-type-interface-docs)
|
||||
|
|
|
@ -349,6 +349,8 @@ Type LLVMTypeConverter::convertMemRefType(MemRefType type) {
|
|||
// unpack the `sizes` and `strides` arrays.
|
||||
SmallVector<Type, 5> types =
|
||||
getMemRefDescriptorFields(type, /*unpackAggregates=*/false);
|
||||
if (types.empty())
|
||||
return {};
|
||||
return LLVM::LLVMStructType::getLiteral(&getContext(), types);
|
||||
}
|
||||
|
||||
|
@ -368,6 +370,8 @@ SmallVector<Type, 2> LLVMTypeConverter::getUnrankedMemRefDescriptorFields() {
|
|||
}
|
||||
|
||||
Type LLVMTypeConverter::convertUnrankedMemRefType(UnrankedMemRefType type) {
|
||||
if (!convertType(type.getElementType()))
|
||||
return {};
|
||||
return LLVM::LLVMStructType::getLiteral(&getContext(),
|
||||
getUnrankedMemRefDescriptorFields());
|
||||
}
|
||||
|
|
|
@ -31,6 +31,12 @@ using namespace mlir::detail;
|
|||
#define GET_TYPEDEF_CLASSES
|
||||
#include "mlir/IR/BuiltinTypes.cpp.inc"
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
/// Tablegen Interface Definitions
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "mlir/IR/BuiltinTypeInterfaces.cpp.inc"
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// BuiltinDialect
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -427,3 +427,23 @@ module attributes { dlti.dl_spec = #dlti.dl_spec<#dlti.dl_entry<index, 32>> } {
|
|||
return
|
||||
}
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Should not convert memrefs with unsupported types in any convention.
|
||||
|
||||
// CHECK: @unsupported_memref_element_type
|
||||
// CHECK-SAME: memref<
|
||||
// CHECK-NOT: !llvm.struct
|
||||
// BAREPTR: @unsupported_memref_element_type
|
||||
// BAREPTR-SAME: memref<
|
||||
// BAREPTR-NOT: !llvm.ptr
|
||||
func private @unsupported_memref_element_type() -> memref<42 x !test.memref_element>
|
||||
|
||||
// CHECK: @unsupported_unranked_memref_element_type
|
||||
// CHECK-SAME: memref<
|
||||
// CHECK-NOT: !llvm.struct
|
||||
// BAREPTR: @unsupported_unranked_memref_element_type
|
||||
// BAREPTR-SAME: memref<
|
||||
// BAREPTR-NOT: !llvm.ptr
|
||||
func private @unsupported_unranked_memref_element_type() -> memref<* x !test.memref_element>
|
||||
|
|
|
@ -6,3 +6,4 @@ func private @unsupported_signature() -> tensor<10 x i32>
|
|||
// -----
|
||||
|
||||
func private @partially_supported_signature() -> (vector<10 x i32>, tensor<10 x i32>)
|
||||
|
||||
|
|
|
@ -178,6 +178,9 @@ func private @memref_with_complex_elems(memref<1x?xcomplex<f32>>)
|
|||
// CHECK: func private @memref_with_vector_elems(memref<1x?xvector<10xf32>>)
|
||||
func private @memref_with_vector_elems(memref<1x?xvector<10xf32>>)
|
||||
|
||||
// CHECK: func private @memref_with_custom_elem(memref<1x?x!test.memref_element>)
|
||||
func private @memref_with_custom_elem(memref<1x?x!test.memref_element>)
|
||||
|
||||
// CHECK: func private @unranked_memref_with_complex_elems(memref<*xcomplex<f32>>)
|
||||
func private @unranked_memref_with_complex_elems(memref<*xcomplex<f32>>)
|
||||
|
||||
|
|
|
@ -17,8 +17,8 @@ mlir_tablegen(TestAttrDefs.cpp.inc -gen-attrdef-defs)
|
|||
add_public_tablegen_target(MLIRTestAttrDefIncGen)
|
||||
|
||||
set(LLVM_TARGET_DEFINITIONS TestTypeDefs.td)
|
||||
mlir_tablegen(TestTypeDefs.h.inc -gen-typedef-decls)
|
||||
mlir_tablegen(TestTypeDefs.cpp.inc -gen-typedef-defs)
|
||||
mlir_tablegen(TestTypeDefs.h.inc -gen-typedef-decls -typedefs-dialect=test)
|
||||
mlir_tablegen(TestTypeDefs.cpp.inc -gen-typedef-defs -typedefs-dialect=test)
|
||||
add_public_tablegen_target(MLIRTestTypeDefIncGen)
|
||||
|
||||
|
||||
|
|
|
@ -15,6 +15,7 @@
|
|||
|
||||
// To get the test dialect def.
|
||||
include "TestOps.td"
|
||||
include "mlir/IR/BuiltinTypes.td"
|
||||
include "mlir/Interfaces/DataLayoutInterfaces.td"
|
||||
|
||||
// All of the types will extend this class.
|
||||
|
@ -176,4 +177,9 @@ def TestTypeWithLayoutType : Test_Type<"TestTypeWithLayout", [
|
|||
}];
|
||||
}
|
||||
|
||||
def TestMemRefElementType : Test_Type<"TestMemRefElementType",
|
||||
[MemRefElementTypeInterface]> {
|
||||
let mnemonic = "memref_element";
|
||||
}
|
||||
|
||||
#endif // TEST_TYPEDEFS
|
||||
|
|
Loading…
Reference in New Issue