diff --git a/mlir/include/mlir-c/IR.h b/mlir/include/mlir-c/IR.h index 8bee618213ff..65c097a7604e 100644 --- a/mlir/include/mlir-c/IR.h +++ b/mlir/include/mlir-c/IR.h @@ -322,6 +322,9 @@ static inline bool mlirOperationIsNull(MlirOperation op) { return !op.ptr; } MLIR_CAPI_EXPORTED bool mlirOperationEqual(MlirOperation op, MlirOperation other); +/// Gets the context this operation is associated with +MLIR_CAPI_EXPORTED MlirContext mlirOperationGetContext(MlirOperation op); + /// Gets the name of the operation as an identifier. MLIR_CAPI_EXPORTED MlirIdentifier mlirOperationGetName(MlirOperation op); @@ -467,6 +470,9 @@ static inline bool mlirBlockIsNull(MlirBlock block) { return !block.ptr; } /// perform deep comparison. MLIR_CAPI_EXPORTED bool mlirBlockEqual(MlirBlock block, MlirBlock other); +/// Returns the closest surrounding operation that contains this block. +MLIR_CAPI_EXPORTED MlirOperation mlirBlockGetParentOperation(MlirBlock); + /// Returns the block immediately following the given block in its parent /// region. MLIR_CAPI_EXPORTED MlirBlock mlirBlockGetNextInRegion(MlirBlock block); diff --git a/mlir/lib/CAPI/IR/IR.cpp b/mlir/lib/CAPI/IR/IR.cpp index fdb830e4940f..87c09944c77c 100644 --- a/mlir/lib/CAPI/IR/IR.cpp +++ b/mlir/lib/CAPI/IR/IR.cpp @@ -305,6 +305,10 @@ bool mlirOperationEqual(MlirOperation op, MlirOperation other) { return unwrap(op) == unwrap(other); } +MlirContext mlirOperationGetContext(MlirOperation op) { + return wrap(unwrap(op)->getContext()); +} + MlirIdentifier mlirOperationGetName(MlirOperation op) { return wrap(unwrap(op)->getName().getIdentifier()); } @@ -461,6 +465,10 @@ bool mlirBlockEqual(MlirBlock block, MlirBlock other) { return unwrap(block) == unwrap(other); } +MlirOperation mlirBlockGetParentOperation(MlirBlock block) { + return wrap(unwrap(block)->getParentOp()); +} + MlirBlock mlirBlockGetNextInRegion(MlirBlock block) { return wrap(unwrap(block)->getNextNode()); } diff --git a/mlir/test/CAPI/ir.c b/mlir/test/CAPI/ir.c index c2b13d55473f..2f81d13160c2 100644 --- a/mlir/test/CAPI/ir.c +++ b/mlir/test/CAPI/ir.c @@ -1439,6 +1439,37 @@ int registerOnlyStd() { return 0; } +/// Tests backreference APIs +static int testBackreferences() { + fprintf(stderr, "@test_backreferences\n"); + + MlirContext ctx = mlirContextCreate(); + mlirContextSetAllowUnregisteredDialects(ctx, true); + MlirLocation loc = mlirLocationUnknownGet(ctx); + + MlirOperationState opState = mlirOperationStateGet(mlirStringRefCreateFromCString("invalid.op"), loc); + MlirRegion region = mlirRegionCreate(); + MlirBlock block = mlirBlockCreate(0, NULL); + mlirRegionAppendOwnedBlock(region, block); + mlirOperationStateAddOwnedRegions(&opState, 1, ®ion); + MlirOperation op = mlirOperationCreate(&opState); + + if (!mlirContextEqual(ctx, mlirOperationGetContext(op))) { + fprintf(stderr, "ERROR: Getting context from operation failed\n"); + return 1; + } + if (!mlirOperationEqual(op, mlirBlockGetParentOperation(block))) { + fprintf(stderr, "ERROR: Getting parent operation from block failed\n"); + return 2; + } + + mlirOperationDestroy(op); + mlirContextDestroy(ctx); + + // CHECK-LABEL: @test_backreferences + return 0; +} + // Wraps a diagnostic into additional text we can match against. MlirLogicalResult errorHandler(MlirDiagnostic diagnostic, void *userData) { fprintf(stderr, "processing diagnostic (userData: %ld) <<\n", (long)userData); @@ -1514,6 +1545,8 @@ int main() { return 8; if (registerOnlyStd()) return 9; + if (testBackreferences()) + return 10; mlirContextDestroy(ctx);