diff --git a/mlir/include/mlir-c/IR.h b/mlir/include/mlir-c/IR.h index d807cd46dd58..048bd46679db 100644 --- a/mlir/include/mlir-c/IR.h +++ b/mlir/include/mlir-c/IR.h @@ -119,6 +119,13 @@ mlirContextGetNumLoadedDialects(MlirContext context); MLIR_CAPI_EXPORTED MlirDialect mlirContextGetOrLoadDialect(MlirContext context, MlirStringRef name); +/// Returns whether the given fully-qualified operation (i.e. +/// 'dialect.operation') is registered with the context. This will return true +/// if the dialect is loaded and the operation is registered within the +/// dialect. +MLIR_CAPI_EXPORTED bool mlirContextIsRegisteredOperation(MlirContext context, + MlirStringRef name); + //===----------------------------------------------------------------------===// // Dialect API. //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp index 0a4c5fcb40c3..5046eedb1194 100644 --- a/mlir/lib/Bindings/Python/IRCore.cpp +++ b/mlir/lib/Bindings/Python/IRCore.cpp @@ -1752,7 +1752,12 @@ void mlir::python::populateIRCore(py::module &m) { }, [](PyMlirContext &self, bool value) { mlirContextSetAllowUnregisteredDialects(self.get(), value); - }); + }) + .def("is_registered_operation", + [](PyMlirContext &self, std::string &name) { + return mlirContextIsRegisteredOperation( + self.get(), MlirStringRef{name.data(), name.size()}); + }); //---------------------------------------------------------------------------- // Mapping of PyDialectDescriptor diff --git a/mlir/lib/CAPI/IR/IR.cpp b/mlir/lib/CAPI/IR/IR.cpp index 67032a4b5540..14cde9633f52 100644 --- a/mlir/lib/CAPI/IR/IR.cpp +++ b/mlir/lib/CAPI/IR/IR.cpp @@ -60,6 +60,10 @@ MlirDialect mlirContextGetOrLoadDialect(MlirContext context, return wrap(unwrap(context)->getOrLoadDialect(unwrap(name))); } +bool mlirContextIsRegisteredOperation(MlirContext context, MlirStringRef name) { + return unwrap(context)->isOperationRegistered(unwrap(name)); +} + //===----------------------------------------------------------------------===// // Dialect API. //===----------------------------------------------------------------------===// diff --git a/mlir/test/Bindings/Python/dialects.py b/mlir/test/Bindings/Python/dialects.py index 41f4239e2b66..d5f5bee7f4b0 100644 --- a/mlir/test/Bindings/Python/dialects.py +++ b/mlir/test/Bindings/Python/dialects.py @@ -3,14 +3,17 @@ import gc from mlir.ir import * + def run(f): print("\nTEST:", f.__name__) f() gc.collect() assert Context._get_live_count() == 0 + return f # CHECK-LABEL: TEST: testDialectDescriptor +@run def testDialectDescriptor(): ctx = Context() d = ctx.get_dialect_descriptor("std") @@ -25,10 +28,9 @@ def testDialectDescriptor(): else: assert False, "Expected exception" -run(testDialectDescriptor) - # CHECK-LABEL: TEST: testUserDialectClass +@run def testUserDialectClass(): ctx = Context() # Access using attribute. @@ -60,14 +62,14 @@ def testUserDialectClass(): # CHECK: print(d) -run(testUserDialectClass) - # CHECK-LABEL: TEST: testCustomOpView # This test uses the standard dialect AddFOp as an example of a user op. # TODO: Op creation and access is still quite verbose: simplify this test as # additional capabilities come online. +@run def testCustomOpView(): + def createInput(): op = Operation.create("pytest_dummy.intinput", results=[f32]) # TODO: Auto result cast from operation @@ -95,4 +97,12 @@ def testCustomOpView(): m.operation.print() -run(testCustomOpView) +# CHECK-LABEL: TEST: testIsRegisteredOperation +@run +def testIsRegisteredOperation(): + ctx = Context() + + # CHECK: std.cond_br: True + print(f"std.cond_br: {ctx.is_registered_operation('std.cond_br')}") + # CHECK: std.not_existing: False + print(f"std.not_existing: {ctx.is_registered_operation('std.not_existing')}") diff --git a/mlir/test/CAPI/ir.c b/mlir/test/CAPI/ir.c index 40ef39b19d26..5ce496c8a0e2 100644 --- a/mlir/test/CAPI/ir.c +++ b/mlir/test/CAPI/ir.c @@ -1442,6 +1442,22 @@ int registerOnlyStd() { fprintf(stderr, "@registration\n"); // CHECK-LABEL: @registration + // CHECK: std.cond_br is_registered: 1 + fprintf(stderr, "std.cond_br is_registered: %d\n", + mlirContextIsRegisteredOperation( + ctx, mlirStringRefCreateFromCString("std.cond_br"))); + + // CHECK: std.not_existing_op is_registered: 0 + fprintf(stderr, "std.not_existing_op is_registered: %d\n", + mlirContextIsRegisteredOperation( + ctx, mlirStringRefCreateFromCString("std.not_existing_op"))); + + // CHECK: not_existing_dialect.not_existing_op is_registered: 0 + fprintf(stderr, "not_existing_dialect.not_existing_op is_registered: %d\n", + mlirContextIsRegisteredOperation( + ctx, mlirStringRefCreateFromCString( + "not_existing_dialect.not_existing_op"))); + return 0; }