[mlir][capi] Add DialectRegistry to MLIR C-API

Exposes mlir::DialectRegistry to the C API as MlirDialectRegistry along with
helper functions. A hook has been added to MlirDialectHandle that inserts
the dialect into a registry.

A future possible change is removing mlirDialectHandleRegisterDialect in
favor of using mlirDialectHandleInsertDialect, which it is now implemented with.

Differential Revision: https://reviews.llvm.org/D118293
This commit is contained in:
Daniel Resnick 2022-01-26 17:13:24 -07:00
parent 79606ee85c
commit 97fc568211
7 changed files with 93 additions and 9 deletions

View File

@ -50,6 +50,7 @@ extern "C" {
DEFINE_C_API_STRUCT(MlirContext, void);
DEFINE_C_API_STRUCT(MlirDialect, void);
DEFINE_C_API_STRUCT(MlirDialectRegistry, void);
DEFINE_C_API_STRUCT(MlirOperation, void);
DEFINE_C_API_STRUCT(MlirOpPrintingFlags, void);
DEFINE_C_API_STRUCT(MlirBlock, void);
@ -108,6 +109,11 @@ mlirContextGetAllowUnregisteredDialects(MlirContext context);
MLIR_CAPI_EXPORTED intptr_t
mlirContextGetNumRegisteredDialects(MlirContext context);
/// Append the contents of the given dialect registry to the registry associated
/// with the context.
MLIR_CAPI_EXPORTED void
mlirContextAppendDialectRegistry(MlirContext ctx, MlirDialectRegistry registry);
/// Returns the number of dialects loaded by the context.
MLIR_CAPI_EXPORTED intptr_t
@ -152,6 +158,22 @@ MLIR_CAPI_EXPORTED bool mlirDialectEqual(MlirDialect dialect1,
/// Returns the namespace of the given dialect.
MLIR_CAPI_EXPORTED MlirStringRef mlirDialectGetNamespace(MlirDialect dialect);
//===----------------------------------------------------------------------===//
// DialectRegistry API.
//===----------------------------------------------------------------------===//
/// Creates a dialect registry and transfers its ownership to the caller.
MLIR_CAPI_EXPORTED MlirDialectRegistry mlirDialectRegistryCreate();
/// Checks if the dialect registry is null.
static inline bool mlirDialectRegistryIsNull(MlirDialectRegistry registry) {
return !registry.ptr;
}
/// Takes a dialect registry owned by the caller and destroys it.
MLIR_CAPI_EXPORTED void
mlirDialectRegistryDestroy(MlirDialectRegistry registry);
//===----------------------------------------------------------------------===//
// Location API.
//===----------------------------------------------------------------------===//

View File

@ -44,6 +44,11 @@ typedef struct MlirDialectHandle MlirDialectHandle;
MLIR_CAPI_EXPORTED
MlirStringRef mlirDialectHandleGetNamespace(MlirDialectHandle);
/// Inserts the dialect associated with the provided dialect handle into the
/// provided dialect registry
MLIR_CAPI_EXPORTED void mlirDialectHandleInsertDialect(MlirDialectHandle,
MlirDialectRegistry);
/// Registers the dialect associated with the provided dialect handle.
MLIR_CAPI_EXPORTED void mlirDialectHandleRegisterDialect(MlirDialectHandle,
MlirContext);

View File

@ -22,6 +22,7 @@
DEFINE_C_API_PTR_METHODS(MlirContext, mlir::MLIRContext)
DEFINE_C_API_PTR_METHODS(MlirDialect, mlir::Dialect)
DEFINE_C_API_PTR_METHODS(MlirDialectRegistry, mlir::DialectRegistry)
DEFINE_C_API_PTR_METHODS(MlirOperation, mlir::Operation)
DEFINE_C_API_PTR_METHODS(MlirBlock, mlir::Block)
DEFINE_C_API_PTR_METHODS(MlirOpPrintingFlags, mlir::OpPrintingFlags)

View File

@ -21,23 +21,23 @@
//===----------------------------------------------------------------------===//
/// Hooks for dynamic discovery of dialects.
typedef void (*MlirContextRegisterDialectHook)(MlirContext context);
typedef void (*MlirDialectRegistryInsertDialectHook)(
MlirDialectRegistry registry);
typedef MlirDialect (*MlirContextLoadDialectHook)(MlirContext context);
typedef MlirStringRef (*MlirDialectGetNamespaceHook)();
/// Structure of dialect registration hooks.
struct MlirDialectRegistrationHooks {
MlirContextRegisterDialectHook registerHook;
MlirDialectRegistryInsertDialectHook insertHook;
MlirContextLoadDialectHook loadHook;
MlirDialectGetNamespaceHook getNamespaceHook;
};
typedef struct MlirDialectRegistrationHooks MlirDialectRegistrationHooks;
#define MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(Name, Namespace, ClassName) \
static void mlirContextRegister##Name##Dialect(MlirContext context) { \
mlir::DialectRegistry registry; \
registry.insert<ClassName>(); \
unwrap(context)->appendDialectRegistry(registry); \
static void mlirDialectRegistryInsert##Name##Dialect( \
MlirDialectRegistry registry) { \
unwrap(registry)->insert<ClassName>(); \
} \
static MlirDialect mlirContextLoad##Name##Dialect(MlirContext context) { \
return wrap(unwrap(context)->getOrLoadDialect<ClassName>()); \
@ -47,8 +47,8 @@ typedef struct MlirDialectRegistrationHooks MlirDialectRegistrationHooks;
} \
MlirDialectHandle mlirGetDialectHandle__##Namespace##__() { \
static MlirDialectRegistrationHooks hooks = { \
mlirContextRegister##Name##Dialect, mlirContextLoad##Name##Dialect, \
mlir##Name##DialectGetNamespace}; \
mlirDialectRegistryInsert##Name##Dialect, \
mlirContextLoad##Name##Dialect, mlir##Name##DialectGetNamespace}; \
return MlirDialectHandle{&hooks}; \
}

View File

@ -17,9 +17,16 @@ MlirStringRef mlirDialectHandleGetNamespace(MlirDialectHandle handle) {
return unwrap(handle)->getNamespaceHook();
}
void mlirDialectHandleInsertDialect(MlirDialectHandle handle,
MlirDialectRegistry registry) {
unwrap(handle)->insertHook(registry);
}
void mlirDialectHandleRegisterDialect(MlirDialectHandle handle,
MlirContext ctx) {
unwrap(handle)->registerHook(ctx);
mlir::DialectRegistry registry;
mlirDialectHandleInsertDialect(handle, wrap(&registry));
unwrap(ctx)->appendDialectRegistry(registry);
}
MlirDialect mlirDialectHandleLoadDialect(MlirDialectHandle handle,

View File

@ -53,6 +53,11 @@ intptr_t mlirContextGetNumRegisteredDialects(MlirContext context) {
return static_cast<intptr_t>(unwrap(context)->getAvailableDialects().size());
}
void mlirContextAppendDialectRegistry(MlirContext ctx,
MlirDialectRegistry registry) {
unwrap(ctx)->appendDialectRegistry(*unwrap(registry));
}
// TODO: expose a cheaper way than constructing + sorting a vector only to take
// its size.
intptr_t mlirContextGetNumLoadedDialects(MlirContext context) {
@ -88,6 +93,18 @@ MlirStringRef mlirDialectGetNamespace(MlirDialect dialect) {
return wrap(unwrap(dialect)->getNamespace());
}
//===----------------------------------------------------------------------===//
// DialectRegistry API.
//===----------------------------------------------------------------------===//
MlirDialectRegistry mlirDialectRegistryCreate() {
return wrap(new DialectRegistry());
}
void mlirDialectRegistryDestroy(MlirDialectRegistry registry) {
delete unwrap(registry);
}
//===----------------------------------------------------------------------===//
// Printing flags API.
//===----------------------------------------------------------------------===//

View File

@ -1904,6 +1904,36 @@ int testSymbolTable(MlirContext ctx) {
return 0;
}
int testDialectRegistry() {
fprintf(stderr, "@testDialectRegistry\n");
MlirDialectRegistry registry = mlirDialectRegistryCreate();
if (mlirDialectRegistryIsNull(registry)) {
fprintf(stderr, "ERROR: Expected registry to be present\n");
return 1;
}
MlirDialectHandle stdHandle = mlirGetDialectHandle__std__();
mlirDialectHandleInsertDialect(stdHandle, registry);
MlirContext ctx = mlirContextCreate();
if (mlirContextGetNumRegisteredDialects(ctx) != 0) {
fprintf(stderr,
"ERROR: Expected no dialects to be registered to new context\n");
}
mlirContextAppendDialectRegistry(ctx, registry);
if (mlirContextGetNumRegisteredDialects(ctx) != 1) {
fprintf(stderr, "ERROR: Expected the dialect in the registry to be "
"registered to the context\n");
}
mlirContextDestroy(ctx);
mlirDialectRegistryDestroy(registry);
return 0;
}
void testDiagnostics() {
MlirContext ctx = mlirContextCreate();
MlirDiagnosticHandlerID id = mlirContextAttachDiagnosticHandler(
@ -1988,6 +2018,8 @@ int main() {
return 13;
if (testSymbolTable(ctx))
return 14;
if (testDialectRegistry())
return 15;
mlirContextDestroy(ctx);