From d89602ed62f3e5f47781659059db6a8cc11122fe Mon Sep 17 00:00:00 2001 From: Adam Paszke Date: Mon, 17 May 2021 10:14:02 +0000 Subject: [PATCH] Add `mlirModuleFromOperation` to C API At the moment `MlirModule`s can be converted to `MlirOperation`s, but not the other way around (at least not without going around the C API). This makes it impossible to e.g. run passes over a `ModuleOp` created through `mlirOperationCreate`. Reviewed By: nicolasvasilache, mehdi_amini Differential Revision: https://reviews.llvm.org/D102497 --- mlir/include/mlir-c/IR.h | 4 ++++ mlir/lib/CAPI/IR/IR.cpp | 4 ++++ mlir/test/CAPI/ir.c | 2 ++ 3 files changed, 10 insertions(+) diff --git a/mlir/include/mlir-c/IR.h b/mlir/include/mlir-c/IR.h index 1b243165cbb3..638eea9b86c9 100644 --- a/mlir/include/mlir-c/IR.h +++ b/mlir/include/mlir-c/IR.h @@ -209,6 +209,10 @@ MLIR_CAPI_EXPORTED void mlirModuleDestroy(MlirModule module); /// Views the module as a generic operation. MLIR_CAPI_EXPORTED MlirOperation mlirModuleGetOperation(MlirModule module); +/// Views the generic operation as a module. +/// The returned module is null when the input operation was not a ModuleOp. +MLIR_CAPI_EXPORTED MlirModule mlirModuleFromOperation(MlirOperation op); + //===----------------------------------------------------------------------===// // Operation state. //===----------------------------------------------------------------------===// diff --git a/mlir/lib/CAPI/IR/IR.cpp b/mlir/lib/CAPI/IR/IR.cpp index 4e21835164ab..ebabd6899e06 100644 --- a/mlir/lib/CAPI/IR/IR.cpp +++ b/mlir/lib/CAPI/IR/IR.cpp @@ -181,6 +181,10 @@ MlirOperation mlirModuleGetOperation(MlirModule module) { return wrap(unwrap(module).getOperation()); } +MlirModule mlirModuleFromOperation(MlirOperation op) { + return wrap(dyn_cast(unwrap(op))); +} + //===----------------------------------------------------------------------===// // Operation state API. //===----------------------------------------------------------------------===// diff --git a/mlir/test/CAPI/ir.c b/mlir/test/CAPI/ir.c index 7176cbb2625f..42dfa532727e 100644 --- a/mlir/test/CAPI/ir.c +++ b/mlir/test/CAPI/ir.c @@ -319,6 +319,7 @@ static void printFirstOfEach(MlirContext ctx, MlirOperation operation) { MlirOperation parentOperation = operation; block = mlirRegionGetFirstBlock(region); operation = mlirBlockGetFirstOperation(block); + assert(mlirModuleIsNull(mlirModuleFromOperation(operation))); // Verify that parent operation and block report correctly. fprintf(stderr, "Parent operation eq: %d\n", @@ -460,6 +461,7 @@ static int constructAndTraverseIr(MlirContext ctx) { MlirModule moduleOp = makeAndDumpAdd(ctx, location); MlirOperation module = mlirModuleGetOperation(moduleOp); + assert(!mlirModuleIsNull(mlirModuleFromOperation(module))); int errcode = collectStats(module); if (errcode)