forked from OSchip/llvm-project
[mlir] Expose Dialect class and registration/loading to C API
- Add a minimalist C API for mlir::Dialect. - Allow one to query the context about registered and loaded dialects. - Add API for loading dialects. - Provide functions to register the Standard dialect. When used naively, this will require to separately register each dialect. When we have more than one exposed, we can add variadic macros that expand to individual calls. Reviewed By: mehdi_amini Differential Revision: https://reviews.llvm.org/D88162
This commit is contained in:
parent
042f22bda5
commit
64c0c9f015
|
@ -20,6 +20,8 @@
|
|||
|
||||
#include <stdint.h>
|
||||
|
||||
#include "mlir-c/Support.h"
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
|
@ -46,6 +48,7 @@ extern "C" {
|
|||
typedef struct name name
|
||||
|
||||
DEFINE_C_API_STRUCT(MlirContext, void);
|
||||
DEFINE_C_API_STRUCT(MlirDialect, void);
|
||||
DEFINE_C_API_STRUCT(MlirOperation, void);
|
||||
DEFINE_C_API_STRUCT(MlirBlock, void);
|
||||
DEFINE_C_API_STRUCT(MlirRegion, void);
|
||||
|
@ -97,6 +100,39 @@ void mlirContextSetAllowUnregisteredDialects(MlirContext context, int allow);
|
|||
/** Returns whether the context allows unregistered dialects. */
|
||||
int mlirContextGetAllowUnregisteredDialects(MlirContext context);
|
||||
|
||||
/** Returns the number of dialects registered with the given context. A
|
||||
* registered dialect will be loaded if needed by the parser. */
|
||||
intptr_t mlirContextGetNumRegisteredDialects(MlirContext context);
|
||||
|
||||
/** Returns the number of dialects loaded by the context.
|
||||
*/
|
||||
intptr_t mlirContextGetNumLoadedDialects(MlirContext context);
|
||||
|
||||
/** Gets the dialect instance owned by the given context using the dialect
|
||||
* namespace to identify it, loads (i.e., constructs the instance of) the
|
||||
* dialect if necessary. If the dialect is not registered with the context,
|
||||
* returns null. Use mlirContextLoad<Name>Dialect to load an unregistered
|
||||
* dialect. */
|
||||
MlirDialect mlirContextGetOrLoadDialect(MlirContext context,
|
||||
MlirStringRef name);
|
||||
|
||||
/*============================================================================*/
|
||||
/* Dialect API. */
|
||||
/*============================================================================*/
|
||||
|
||||
/** Returns the context that owns the dialect. */
|
||||
MlirContext mlirDialectGetContext(MlirDialect dialect);
|
||||
|
||||
/** Checks if the dialect is null. */
|
||||
int mlirDialectIsNull(MlirDialect dialect);
|
||||
|
||||
/** Checks if two dialects that belong to the same context are equal. Dialects
|
||||
* from different contexts will not compare equal. */
|
||||
int mlirDialectEqual(MlirDialect dialect1, MlirDialect dialect2);
|
||||
|
||||
/** Returns the namespace of the given dialect. */
|
||||
MlirStringRef mlirDialectGetNamespace(MlirDialect dialect);
|
||||
|
||||
/*============================================================================*/
|
||||
/* Location API. */
|
||||
/*============================================================================*/
|
||||
|
|
|
@ -0,0 +1,42 @@
|
|||
/*===-- mlir-c/StandardDialect.h - C API for Standard dialect -----*- C -*-===*\
|
||||
|* *|
|
||||
|* Part of the LLVM Project, under the Apache License v2.0 with LLVM *|
|
||||
|* Exceptions. *|
|
||||
|* See https://llvm.org/LICENSE.txt for license information. *|
|
||||
|* SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception *|
|
||||
|* *|
|
||||
|*===----------------------------------------------------------------------===*|
|
||||
|* *|
|
||||
|* This header declares the C interface for registering and accessing the *|
|
||||
|* Standard dialect. A dialect should be registered with a context to make it *|
|
||||
|* available to users of the context. These users must load the dialect *|
|
||||
|* before using any of its attributes, operations or types. Parser and pass *|
|
||||
|* manager can load registered dialects automatically. *|
|
||||
|* *|
|
||||
\*===----------------------------------------------------------------------===*/
|
||||
|
||||
#ifndef MLIR_C_STANDARDDIALECT_H
|
||||
#define MLIR_C_STANDARDDIALECT_H
|
||||
|
||||
#include "mlir-c/IR.h"
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
/** Registers the Standard dialect with the given context. This allows the
|
||||
* dialect to be loaded dynamically if needed when parsing. */
|
||||
void mlirContextRegisterStandardDialect(MlirContext context);
|
||||
|
||||
/** Loads the Standard dialect into the given context. The dialect does _not_
|
||||
* have to be registered in advance. */
|
||||
MlirDialect mlirContextLoadStandardDialect(MlirContext context);
|
||||
|
||||
/** Returns the namespace of the Standard dialect, suitable for loading it. */
|
||||
MlirStringRef mlirStandardDialectGetNamespace();
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
||||
#endif // MLIR_C_STANDARDDIALECT_H
|
|
@ -21,6 +21,7 @@
|
|||
#include "mlir/IR/Operation.h"
|
||||
|
||||
DEFINE_C_API_PTR_METHODS(MlirContext, mlir::MLIRContext)
|
||||
DEFINE_C_API_PTR_METHODS(MlirDialect, mlir::Dialect)
|
||||
DEFINE_C_API_PTR_METHODS(MlirOperation, mlir::Operation)
|
||||
DEFINE_C_API_PTR_METHODS(MlirBlock, mlir::Block)
|
||||
DEFINE_C_API_PTR_METHODS(MlirRegion, mlir::Region)
|
||||
|
|
|
@ -1,2 +1,3 @@
|
|||
add_subdirectory(IR)
|
||||
add_subdirectory(Registration)
|
||||
add_subdirectory(Standard)
|
||||
|
|
|
@ -7,8 +7,10 @@
|
|||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "mlir-c/IR.h"
|
||||
#include "mlir-c/Support.h"
|
||||
|
||||
#include "mlir/CAPI/IR.h"
|
||||
#include "mlir/CAPI/Support.h"
|
||||
#include "mlir/CAPI/Utils.h"
|
||||
#include "mlir/IR/Attributes.h"
|
||||
#include "mlir/IR/Dialect.h"
|
||||
|
@ -41,6 +43,40 @@ void mlirContextSetAllowUnregisteredDialects(MlirContext context, int allow) {
|
|||
int mlirContextGetAllowUnregisteredDialects(MlirContext context) {
|
||||
return unwrap(context)->allowsUnregisteredDialects();
|
||||
}
|
||||
intptr_t mlirContextGetNumRegisteredDialects(MlirContext context) {
|
||||
return static_cast<intptr_t>(unwrap(context)->getAvailableDialects().size());
|
||||
}
|
||||
|
||||
// TODO: expose a cheaper way than constructing + sorting a vector only to take
|
||||
// its size.
|
||||
intptr_t mlirContextGetNumLoadedDialects(MlirContext context) {
|
||||
return static_cast<intptr_t>(unwrap(context)->getLoadedDialects().size());
|
||||
}
|
||||
|
||||
MlirDialect mlirContextGetOrLoadDialect(MlirContext context,
|
||||
MlirStringRef name) {
|
||||
return wrap(unwrap(context)->getOrLoadDialect(unwrap(name)));
|
||||
}
|
||||
|
||||
/* ========================================================================== */
|
||||
/* Dialect API. */
|
||||
/* ========================================================================== */
|
||||
|
||||
MlirContext mlirDialectGetContext(MlirDialect dialect) {
|
||||
return wrap(unwrap(dialect)->getContext());
|
||||
}
|
||||
|
||||
int mlirDialectIsNull(MlirDialect dialect) {
|
||||
return unwrap(dialect) == nullptr;
|
||||
}
|
||||
|
||||
int mlirDialectEqual(MlirDialect dialect1, MlirDialect dialect2) {
|
||||
return unwrap(dialect1) == unwrap(dialect2);
|
||||
}
|
||||
|
||||
MlirStringRef mlirDialectGetNamespace(MlirDialect dialect) {
|
||||
return wrap(unwrap(dialect)->getNamespace());
|
||||
}
|
||||
|
||||
/* ========================================================================== */
|
||||
/* Location API. */
|
||||
|
|
|
@ -0,0 +1,11 @@
|
|||
add_mlir_library(MLIRCAPIStandard
|
||||
|
||||
StandardDialect.cpp
|
||||
|
||||
ADDITIONAL_HEADER_DIRS
|
||||
${MLIR_MAIN_INCLUDE_DIR}/mlir-c
|
||||
|
||||
LINK_LIBS PUBLIC
|
||||
MLIRCAPIIR
|
||||
MLIRStandardOps
|
||||
)
|
|
@ -0,0 +1,25 @@
|
|||
//===- StandardDialect.cpp - C Interface for Standard dialect -------------===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "mlir-c/StandardDialect.h"
|
||||
#include "mlir-c/IR.h"
|
||||
#include "mlir/CAPI/IR.h"
|
||||
#include "mlir/CAPI/Support.h"
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
||||
|
||||
void mlirContextRegisterStandardDialect(MlirContext context) {
|
||||
unwrap(context)->getDialectRegistry().insert<mlir::StandardOpsDialect>();
|
||||
}
|
||||
|
||||
MlirDialect mlirContextLoadStandardDialect(MlirContext context) {
|
||||
return wrap(unwrap(context)->getOrLoadDialect<mlir::StandardOpsDialect>());
|
||||
}
|
||||
|
||||
MlirStringRef mlirStandardDialectGetNamespace() {
|
||||
return wrap(mlir::StandardOpsDialect::getDialectNamespace());
|
||||
}
|
|
@ -13,4 +13,5 @@ target_link_libraries(mlir-capi-ir-test
|
|||
PRIVATE
|
||||
MLIRCAPIIR
|
||||
MLIRCAPIRegistration
|
||||
MLIRCAPIStandard
|
||||
${dialect_libs})
|
||||
|
|
|
@ -14,6 +14,7 @@
|
|||
#include "mlir-c/AffineMap.h"
|
||||
#include "mlir-c/Registration.h"
|
||||
#include "mlir-c/StandardAttributes.h"
|
||||
#include "mlir-c/StandardDialect.h"
|
||||
#include "mlir-c/StandardTypes.h"
|
||||
|
||||
#include <assert.h>
|
||||
|
@ -790,6 +791,42 @@ int printAffineMap(MlirContext ctx) {
|
|||
return 0;
|
||||
}
|
||||
|
||||
int registerOnlyStd() {
|
||||
MlirContext ctx = mlirContextCreate();
|
||||
// The built-in dialect is always loaded.
|
||||
if (mlirContextGetNumLoadedDialects(ctx) != 1)
|
||||
return 1;
|
||||
|
||||
MlirDialect std =
|
||||
mlirContextGetOrLoadDialect(ctx, mlirStandardDialectGetNamespace());
|
||||
if (!mlirDialectIsNull(std))
|
||||
return 2;
|
||||
|
||||
mlirContextRegisterStandardDialect(ctx);
|
||||
if (mlirContextGetNumRegisteredDialects(ctx) != 1)
|
||||
return 3;
|
||||
if (mlirContextGetNumLoadedDialects(ctx) != 1)
|
||||
return 4;
|
||||
|
||||
std = mlirContextGetOrLoadDialect(ctx, mlirStandardDialectGetNamespace());
|
||||
if (mlirDialectIsNull(std))
|
||||
return 5;
|
||||
if (mlirContextGetNumLoadedDialects(ctx) != 2)
|
||||
return 6;
|
||||
|
||||
MlirDialect alsoStd = mlirContextLoadStandardDialect(ctx);
|
||||
if (!mlirDialectEqual(std, alsoStd))
|
||||
return 7;
|
||||
|
||||
MlirStringRef stdNs = mlirDialectGetNamespace(std);
|
||||
MlirStringRef alsoStdNs = mlirStandardDialectGetNamespace();
|
||||
if (stdNs.length != alsoStdNs.length ||
|
||||
strncmp(stdNs.data, alsoStdNs.data, stdNs.length))
|
||||
return 8;
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
int main() {
|
||||
MlirContext ctx = mlirContextCreate();
|
||||
mlirRegisterAllDialects(ctx);
|
||||
|
@ -935,6 +972,14 @@ int main() {
|
|||
errcode = printAffineMap(ctx);
|
||||
fprintf(stderr, "%d\n", errcode);
|
||||
|
||||
fprintf(stderr, "@registration\n");
|
||||
errcode = registerOnlyStd();
|
||||
fprintf(stderr, "%d\n", errcode);
|
||||
// clang-format off
|
||||
// CHECK-LABEL: @registration
|
||||
// CHECK: 0
|
||||
// clang-format on
|
||||
|
||||
mlirContextDestroy(ctx);
|
||||
|
||||
return 0;
|
||||
|
|
Loading…
Reference in New Issue