diff --git a/mlir/include/mlir-c/IR.h b/mlir/include/mlir-c/IR.h index 4aca261868f3..82149c7fce06 100644 --- a/mlir/include/mlir-c/IR.h +++ b/mlir/include/mlir-c/IR.h @@ -20,6 +20,8 @@ #include +#include "mlir-c/Support.h" + #ifdef __cplusplus extern "C" { #endif @@ -46,6 +48,7 @@ extern "C" { typedef struct name name DEFINE_C_API_STRUCT(MlirContext, void); +DEFINE_C_API_STRUCT(MlirDialect, void); DEFINE_C_API_STRUCT(MlirOperation, void); DEFINE_C_API_STRUCT(MlirBlock, void); DEFINE_C_API_STRUCT(MlirRegion, void); @@ -97,6 +100,39 @@ void mlirContextSetAllowUnregisteredDialects(MlirContext context, int allow); /** Returns whether the context allows unregistered dialects. */ int mlirContextGetAllowUnregisteredDialects(MlirContext context); +/** Returns the number of dialects registered with the given context. A + * registered dialect will be loaded if needed by the parser. */ +intptr_t mlirContextGetNumRegisteredDialects(MlirContext context); + +/** Returns the number of dialects loaded by the context. + */ +intptr_t mlirContextGetNumLoadedDialects(MlirContext context); + +/** Gets the dialect instance owned by the given context using the dialect + * namespace to identify it, loads (i.e., constructs the instance of) the + * dialect if necessary. If the dialect is not registered with the context, + * returns null. Use mlirContextLoadDialect to load an unregistered + * dialect. */ +MlirDialect mlirContextGetOrLoadDialect(MlirContext context, + MlirStringRef name); + +/*============================================================================*/ +/* Dialect API. */ +/*============================================================================*/ + +/** Returns the context that owns the dialect. */ +MlirContext mlirDialectGetContext(MlirDialect dialect); + +/** Checks if the dialect is null. */ +int mlirDialectIsNull(MlirDialect dialect); + +/** Checks if two dialects that belong to the same context are equal. Dialects + * from different contexts will not compare equal. */ +int mlirDialectEqual(MlirDialect dialect1, MlirDialect dialect2); + +/** Returns the namespace of the given dialect. */ +MlirStringRef mlirDialectGetNamespace(MlirDialect dialect); + /*============================================================================*/ /* Location API. */ /*============================================================================*/ diff --git a/mlir/include/mlir-c/StandardDialect.h b/mlir/include/mlir-c/StandardDialect.h new file mode 100644 index 000000000000..946d14859d5d --- /dev/null +++ b/mlir/include/mlir-c/StandardDialect.h @@ -0,0 +1,42 @@ +/*===-- mlir-c/StandardDialect.h - C API for Standard dialect -----*- C -*-===*\ +|* *| +|* Part of the LLVM Project, under the Apache License v2.0 with LLVM *| +|* Exceptions. *| +|* See https://llvm.org/LICENSE.txt for license information. *| +|* SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception *| +|* *| +|*===----------------------------------------------------------------------===*| +|* *| +|* This header declares the C interface for registering and accessing the *| +|* Standard dialect. A dialect should be registered with a context to make it *| +|* available to users of the context. These users must load the dialect *| +|* before using any of its attributes, operations or types. Parser and pass *| +|* manager can load registered dialects automatically. *| +|* *| +\*===----------------------------------------------------------------------===*/ + +#ifndef MLIR_C_STANDARDDIALECT_H +#define MLIR_C_STANDARDDIALECT_H + +#include "mlir-c/IR.h" + +#ifdef __cplusplus +extern "C" { +#endif + +/** Registers the Standard dialect with the given context. This allows the + * dialect to be loaded dynamically if needed when parsing. */ +void mlirContextRegisterStandardDialect(MlirContext context); + +/** Loads the Standard dialect into the given context. The dialect does _not_ + * have to be registered in advance. */ +MlirDialect mlirContextLoadStandardDialect(MlirContext context); + +/** Returns the namespace of the Standard dialect, suitable for loading it. */ +MlirStringRef mlirStandardDialectGetNamespace(); + +#ifdef __cplusplus +} +#endif + +#endif // MLIR_C_STANDARDDIALECT_H diff --git a/mlir/include/mlir/CAPI/IR.h b/mlir/include/mlir/CAPI/IR.h index 9a60ecf04fc8..dce293d05588 100644 --- a/mlir/include/mlir/CAPI/IR.h +++ b/mlir/include/mlir/CAPI/IR.h @@ -21,6 +21,7 @@ #include "mlir/IR/Operation.h" DEFINE_C_API_PTR_METHODS(MlirContext, mlir::MLIRContext) +DEFINE_C_API_PTR_METHODS(MlirDialect, mlir::Dialect) DEFINE_C_API_PTR_METHODS(MlirOperation, mlir::Operation) DEFINE_C_API_PTR_METHODS(MlirBlock, mlir::Block) DEFINE_C_API_PTR_METHODS(MlirRegion, mlir::Region) diff --git a/mlir/lib/CAPI/CMakeLists.txt b/mlir/lib/CAPI/CMakeLists.txt index 79d472b2d026..b9d2c4601b98 100644 --- a/mlir/lib/CAPI/CMakeLists.txt +++ b/mlir/lib/CAPI/CMakeLists.txt @@ -1,2 +1,3 @@ add_subdirectory(IR) add_subdirectory(Registration) +add_subdirectory(Standard) diff --git a/mlir/lib/CAPI/IR/IR.cpp b/mlir/lib/CAPI/IR/IR.cpp index 3b99f8ac4748..359ee69708eb 100644 --- a/mlir/lib/CAPI/IR/IR.cpp +++ b/mlir/lib/CAPI/IR/IR.cpp @@ -7,8 +7,10 @@ //===----------------------------------------------------------------------===// #include "mlir-c/IR.h" +#include "mlir-c/Support.h" #include "mlir/CAPI/IR.h" +#include "mlir/CAPI/Support.h" #include "mlir/CAPI/Utils.h" #include "mlir/IR/Attributes.h" #include "mlir/IR/Dialect.h" @@ -41,6 +43,40 @@ void mlirContextSetAllowUnregisteredDialects(MlirContext context, int allow) { int mlirContextGetAllowUnregisteredDialects(MlirContext context) { return unwrap(context)->allowsUnregisteredDialects(); } +intptr_t mlirContextGetNumRegisteredDialects(MlirContext context) { + return static_cast(unwrap(context)->getAvailableDialects().size()); +} + +// TODO: expose a cheaper way than constructing + sorting a vector only to take +// its size. +intptr_t mlirContextGetNumLoadedDialects(MlirContext context) { + return static_cast(unwrap(context)->getLoadedDialects().size()); +} + +MlirDialect mlirContextGetOrLoadDialect(MlirContext context, + MlirStringRef name) { + return wrap(unwrap(context)->getOrLoadDialect(unwrap(name))); +} + +/* ========================================================================== */ +/* Dialect API. */ +/* ========================================================================== */ + +MlirContext mlirDialectGetContext(MlirDialect dialect) { + return wrap(unwrap(dialect)->getContext()); +} + +int mlirDialectIsNull(MlirDialect dialect) { + return unwrap(dialect) == nullptr; +} + +int mlirDialectEqual(MlirDialect dialect1, MlirDialect dialect2) { + return unwrap(dialect1) == unwrap(dialect2); +} + +MlirStringRef mlirDialectGetNamespace(MlirDialect dialect) { + return wrap(unwrap(dialect)->getNamespace()); +} /* ========================================================================== */ /* Location API. */ diff --git a/mlir/lib/CAPI/Standard/CMakeLists.txt b/mlir/lib/CAPI/Standard/CMakeLists.txt new file mode 100644 index 000000000000..662841c2d235 --- /dev/null +++ b/mlir/lib/CAPI/Standard/CMakeLists.txt @@ -0,0 +1,11 @@ +add_mlir_library(MLIRCAPIStandard + + StandardDialect.cpp + + ADDITIONAL_HEADER_DIRS + ${MLIR_MAIN_INCLUDE_DIR}/mlir-c + + LINK_LIBS PUBLIC + MLIRCAPIIR + MLIRStandardOps + ) diff --git a/mlir/lib/CAPI/Standard/StandardDialect.cpp b/mlir/lib/CAPI/Standard/StandardDialect.cpp new file mode 100644 index 000000000000..f78c9c916873 --- /dev/null +++ b/mlir/lib/CAPI/Standard/StandardDialect.cpp @@ -0,0 +1,25 @@ +//===- StandardDialect.cpp - C Interface for Standard dialect -------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir-c/StandardDialect.h" +#include "mlir-c/IR.h" +#include "mlir/CAPI/IR.h" +#include "mlir/CAPI/Support.h" +#include "mlir/Dialect/StandardOps/IR/Ops.h" + +void mlirContextRegisterStandardDialect(MlirContext context) { + unwrap(context)->getDialectRegistry().insert(); +} + +MlirDialect mlirContextLoadStandardDialect(MlirContext context) { + return wrap(unwrap(context)->getOrLoadDialect()); +} + +MlirStringRef mlirStandardDialectGetNamespace() { + return wrap(mlir::StandardOpsDialect::getDialectNamespace()); +} diff --git a/mlir/test/CAPI/CMakeLists.txt b/mlir/test/CAPI/CMakeLists.txt index 19deda5e3f11..876d701d7211 100644 --- a/mlir/test/CAPI/CMakeLists.txt +++ b/mlir/test/CAPI/CMakeLists.txt @@ -13,4 +13,5 @@ target_link_libraries(mlir-capi-ir-test PRIVATE MLIRCAPIIR MLIRCAPIRegistration + MLIRCAPIStandard ${dialect_libs}) diff --git a/mlir/test/CAPI/ir.c b/mlir/test/CAPI/ir.c index 909929647a84..ae60d56a22ed 100644 --- a/mlir/test/CAPI/ir.c +++ b/mlir/test/CAPI/ir.c @@ -14,6 +14,7 @@ #include "mlir-c/AffineMap.h" #include "mlir-c/Registration.h" #include "mlir-c/StandardAttributes.h" +#include "mlir-c/StandardDialect.h" #include "mlir-c/StandardTypes.h" #include @@ -790,6 +791,42 @@ int printAffineMap(MlirContext ctx) { return 0; } +int registerOnlyStd() { + MlirContext ctx = mlirContextCreate(); + // The built-in dialect is always loaded. + if (mlirContextGetNumLoadedDialects(ctx) != 1) + return 1; + + MlirDialect std = + mlirContextGetOrLoadDialect(ctx, mlirStandardDialectGetNamespace()); + if (!mlirDialectIsNull(std)) + return 2; + + mlirContextRegisterStandardDialect(ctx); + if (mlirContextGetNumRegisteredDialects(ctx) != 1) + return 3; + if (mlirContextGetNumLoadedDialects(ctx) != 1) + return 4; + + std = mlirContextGetOrLoadDialect(ctx, mlirStandardDialectGetNamespace()); + if (mlirDialectIsNull(std)) + return 5; + if (mlirContextGetNumLoadedDialects(ctx) != 2) + return 6; + + MlirDialect alsoStd = mlirContextLoadStandardDialect(ctx); + if (!mlirDialectEqual(std, alsoStd)) + return 7; + + MlirStringRef stdNs = mlirDialectGetNamespace(std); + MlirStringRef alsoStdNs = mlirStandardDialectGetNamespace(); + if (stdNs.length != alsoStdNs.length || + strncmp(stdNs.data, alsoStdNs.data, stdNs.length)) + return 8; + + return 0; +} + int main() { MlirContext ctx = mlirContextCreate(); mlirRegisterAllDialects(ctx); @@ -935,6 +972,14 @@ int main() { errcode = printAffineMap(ctx); fprintf(stderr, "%d\n", errcode); + fprintf(stderr, "@registration\n"); + errcode = registerOnlyStd(); + fprintf(stderr, "%d\n", errcode); + // clang-format off + // CHECK-LABEL: @registration + // CHECK: 0 + // clang-format on + mlirContextDestroy(ctx); return 0;