forked from OSchip/llvm-project
[mlir][capi] Add DialectRegistry to MLIR C-API
Exposes mlir::DialectRegistry to the C API as MlirDialectRegistry along with helper functions. A hook has been added to MlirDialectHandle that inserts the dialect into a registry. A future possible change is removing mlirDialectHandleRegisterDialect in favor of using mlirDialectHandleInsertDialect, which it is now implemented with. Differential Revision: https://reviews.llvm.org/D118293
This commit is contained in:
parent
79606ee85c
commit
97fc568211
|
@ -50,6 +50,7 @@ extern "C" {
|
|||
|
||||
DEFINE_C_API_STRUCT(MlirContext, void);
|
||||
DEFINE_C_API_STRUCT(MlirDialect, void);
|
||||
DEFINE_C_API_STRUCT(MlirDialectRegistry, void);
|
||||
DEFINE_C_API_STRUCT(MlirOperation, void);
|
||||
DEFINE_C_API_STRUCT(MlirOpPrintingFlags, void);
|
||||
DEFINE_C_API_STRUCT(MlirBlock, void);
|
||||
|
@ -108,6 +109,11 @@ mlirContextGetAllowUnregisteredDialects(MlirContext context);
|
|||
MLIR_CAPI_EXPORTED intptr_t
|
||||
mlirContextGetNumRegisteredDialects(MlirContext context);
|
||||
|
||||
/// Append the contents of the given dialect registry to the registry associated
|
||||
/// with the context.
|
||||
MLIR_CAPI_EXPORTED void
|
||||
mlirContextAppendDialectRegistry(MlirContext ctx, MlirDialectRegistry registry);
|
||||
|
||||
/// Returns the number of dialects loaded by the context.
|
||||
|
||||
MLIR_CAPI_EXPORTED intptr_t
|
||||
|
@ -152,6 +158,22 @@ MLIR_CAPI_EXPORTED bool mlirDialectEqual(MlirDialect dialect1,
|
|||
/// Returns the namespace of the given dialect.
|
||||
MLIR_CAPI_EXPORTED MlirStringRef mlirDialectGetNamespace(MlirDialect dialect);
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// DialectRegistry API.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
/// Creates a dialect registry and transfers its ownership to the caller.
|
||||
MLIR_CAPI_EXPORTED MlirDialectRegistry mlirDialectRegistryCreate();
|
||||
|
||||
/// Checks if the dialect registry is null.
|
||||
static inline bool mlirDialectRegistryIsNull(MlirDialectRegistry registry) {
|
||||
return !registry.ptr;
|
||||
}
|
||||
|
||||
/// Takes a dialect registry owned by the caller and destroys it.
|
||||
MLIR_CAPI_EXPORTED void
|
||||
mlirDialectRegistryDestroy(MlirDialectRegistry registry);
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Location API.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -44,6 +44,11 @@ typedef struct MlirDialectHandle MlirDialectHandle;
|
|||
MLIR_CAPI_EXPORTED
|
||||
MlirStringRef mlirDialectHandleGetNamespace(MlirDialectHandle);
|
||||
|
||||
/// Inserts the dialect associated with the provided dialect handle into the
|
||||
/// provided dialect registry
|
||||
MLIR_CAPI_EXPORTED void mlirDialectHandleInsertDialect(MlirDialectHandle,
|
||||
MlirDialectRegistry);
|
||||
|
||||
/// Registers the dialect associated with the provided dialect handle.
|
||||
MLIR_CAPI_EXPORTED void mlirDialectHandleRegisterDialect(MlirDialectHandle,
|
||||
MlirContext);
|
||||
|
|
|
@ -22,6 +22,7 @@
|
|||
|
||||
DEFINE_C_API_PTR_METHODS(MlirContext, mlir::MLIRContext)
|
||||
DEFINE_C_API_PTR_METHODS(MlirDialect, mlir::Dialect)
|
||||
DEFINE_C_API_PTR_METHODS(MlirDialectRegistry, mlir::DialectRegistry)
|
||||
DEFINE_C_API_PTR_METHODS(MlirOperation, mlir::Operation)
|
||||
DEFINE_C_API_PTR_METHODS(MlirBlock, mlir::Block)
|
||||
DEFINE_C_API_PTR_METHODS(MlirOpPrintingFlags, mlir::OpPrintingFlags)
|
||||
|
|
|
@ -21,23 +21,23 @@
|
|||
//===----------------------------------------------------------------------===//
|
||||
|
||||
/// Hooks for dynamic discovery of dialects.
|
||||
typedef void (*MlirContextRegisterDialectHook)(MlirContext context);
|
||||
typedef void (*MlirDialectRegistryInsertDialectHook)(
|
||||
MlirDialectRegistry registry);
|
||||
typedef MlirDialect (*MlirContextLoadDialectHook)(MlirContext context);
|
||||
typedef MlirStringRef (*MlirDialectGetNamespaceHook)();
|
||||
|
||||
/// Structure of dialect registration hooks.
|
||||
struct MlirDialectRegistrationHooks {
|
||||
MlirContextRegisterDialectHook registerHook;
|
||||
MlirDialectRegistryInsertDialectHook insertHook;
|
||||
MlirContextLoadDialectHook loadHook;
|
||||
MlirDialectGetNamespaceHook getNamespaceHook;
|
||||
};
|
||||
typedef struct MlirDialectRegistrationHooks MlirDialectRegistrationHooks;
|
||||
|
||||
#define MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(Name, Namespace, ClassName) \
|
||||
static void mlirContextRegister##Name##Dialect(MlirContext context) { \
|
||||
mlir::DialectRegistry registry; \
|
||||
registry.insert<ClassName>(); \
|
||||
unwrap(context)->appendDialectRegistry(registry); \
|
||||
static void mlirDialectRegistryInsert##Name##Dialect( \
|
||||
MlirDialectRegistry registry) { \
|
||||
unwrap(registry)->insert<ClassName>(); \
|
||||
} \
|
||||
static MlirDialect mlirContextLoad##Name##Dialect(MlirContext context) { \
|
||||
return wrap(unwrap(context)->getOrLoadDialect<ClassName>()); \
|
||||
|
@ -47,8 +47,8 @@ typedef struct MlirDialectRegistrationHooks MlirDialectRegistrationHooks;
|
|||
} \
|
||||
MlirDialectHandle mlirGetDialectHandle__##Namespace##__() { \
|
||||
static MlirDialectRegistrationHooks hooks = { \
|
||||
mlirContextRegister##Name##Dialect, mlirContextLoad##Name##Dialect, \
|
||||
mlir##Name##DialectGetNamespace}; \
|
||||
mlirDialectRegistryInsert##Name##Dialect, \
|
||||
mlirContextLoad##Name##Dialect, mlir##Name##DialectGetNamespace}; \
|
||||
return MlirDialectHandle{&hooks}; \
|
||||
}
|
||||
|
||||
|
|
|
@ -17,9 +17,16 @@ MlirStringRef mlirDialectHandleGetNamespace(MlirDialectHandle handle) {
|
|||
return unwrap(handle)->getNamespaceHook();
|
||||
}
|
||||
|
||||
void mlirDialectHandleInsertDialect(MlirDialectHandle handle,
|
||||
MlirDialectRegistry registry) {
|
||||
unwrap(handle)->insertHook(registry);
|
||||
}
|
||||
|
||||
void mlirDialectHandleRegisterDialect(MlirDialectHandle handle,
|
||||
MlirContext ctx) {
|
||||
unwrap(handle)->registerHook(ctx);
|
||||
mlir::DialectRegistry registry;
|
||||
mlirDialectHandleInsertDialect(handle, wrap(®istry));
|
||||
unwrap(ctx)->appendDialectRegistry(registry);
|
||||
}
|
||||
|
||||
MlirDialect mlirDialectHandleLoadDialect(MlirDialectHandle handle,
|
||||
|
|
|
@ -53,6 +53,11 @@ intptr_t mlirContextGetNumRegisteredDialects(MlirContext context) {
|
|||
return static_cast<intptr_t>(unwrap(context)->getAvailableDialects().size());
|
||||
}
|
||||
|
||||
void mlirContextAppendDialectRegistry(MlirContext ctx,
|
||||
MlirDialectRegistry registry) {
|
||||
unwrap(ctx)->appendDialectRegistry(*unwrap(registry));
|
||||
}
|
||||
|
||||
// TODO: expose a cheaper way than constructing + sorting a vector only to take
|
||||
// its size.
|
||||
intptr_t mlirContextGetNumLoadedDialects(MlirContext context) {
|
||||
|
@ -88,6 +93,18 @@ MlirStringRef mlirDialectGetNamespace(MlirDialect dialect) {
|
|||
return wrap(unwrap(dialect)->getNamespace());
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// DialectRegistry API.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
MlirDialectRegistry mlirDialectRegistryCreate() {
|
||||
return wrap(new DialectRegistry());
|
||||
}
|
||||
|
||||
void mlirDialectRegistryDestroy(MlirDialectRegistry registry) {
|
||||
delete unwrap(registry);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Printing flags API.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -1904,6 +1904,36 @@ int testSymbolTable(MlirContext ctx) {
|
|||
return 0;
|
||||
}
|
||||
|
||||
int testDialectRegistry() {
|
||||
fprintf(stderr, "@testDialectRegistry\n");
|
||||
|
||||
MlirDialectRegistry registry = mlirDialectRegistryCreate();
|
||||
if (mlirDialectRegistryIsNull(registry)) {
|
||||
fprintf(stderr, "ERROR: Expected registry to be present\n");
|
||||
return 1;
|
||||
}
|
||||
|
||||
MlirDialectHandle stdHandle = mlirGetDialectHandle__std__();
|
||||
mlirDialectHandleInsertDialect(stdHandle, registry);
|
||||
|
||||
MlirContext ctx = mlirContextCreate();
|
||||
if (mlirContextGetNumRegisteredDialects(ctx) != 0) {
|
||||
fprintf(stderr,
|
||||
"ERROR: Expected no dialects to be registered to new context\n");
|
||||
}
|
||||
|
||||
mlirContextAppendDialectRegistry(ctx, registry);
|
||||
if (mlirContextGetNumRegisteredDialects(ctx) != 1) {
|
||||
fprintf(stderr, "ERROR: Expected the dialect in the registry to be "
|
||||
"registered to the context\n");
|
||||
}
|
||||
|
||||
mlirContextDestroy(ctx);
|
||||
mlirDialectRegistryDestroy(registry);
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
void testDiagnostics() {
|
||||
MlirContext ctx = mlirContextCreate();
|
||||
MlirDiagnosticHandlerID id = mlirContextAttachDiagnosticHandler(
|
||||
|
@ -1988,6 +2018,8 @@ int main() {
|
|||
return 13;
|
||||
if (testSymbolTable(ctx))
|
||||
return 14;
|
||||
if (testDialectRegistry())
|
||||
return 15;
|
||||
|
||||
mlirContextDestroy(ctx);
|
||||
|
||||
|
|
Loading…
Reference in New Issue