[MLIR] Add block detach func to CAPI and use it in Python bindings

Adds `mlirBlockDetach` to the CAPI to remove a block from its parent
region. Use it in the Python bindings to implement
`Block.append_to(region)`.

Reviewed By: mehdi_amini

Differential Revision: https://reviews.llvm.org/D123165
This commit is contained in:
John Demme 2022-04-06 10:06:30 -07:00
parent 51f6caf2fb
commit 8d8738f6fe
5 changed files with 54 additions and 0 deletions

View File

@ -558,6 +558,9 @@ MLIR_CAPI_EXPORTED MlirBlock mlirBlockCreate(intptr_t nArgs,
/// Takes a block owned by the caller and destroys it.
MLIR_CAPI_EXPORTED void mlirBlockDestroy(MlirBlock block);
/// Detach a block from the owning region and assume ownership.
MLIR_CAPI_EXPORTED void mlirBlockDetach(MlirBlock block);
/// Checks whether a block is null.
static inline bool mlirBlockIsNull(MlirBlock block) { return !block.ptr; }

View File

@ -2755,6 +2755,15 @@ void mlir::python::populateIRCore(py::module &m) {
py::arg("parent"), py::arg("arg_types") = py::list(),
"Creates and returns a new Block at the beginning of the given "
"region (with given argument types).")
.def(
"append_to",
[](PyBlock &self, PyRegion &region) {
MlirBlock b = self.get();
if (!mlirRegionIsNull(mlirBlockGetParentRegion(b)))
mlirBlockDetach(b);
mlirRegionAppendOwnedBlock(region.get(), b);
},
"Append this block to a region, transferring ownership if necessary")
.def(
"create_before",
[](PyBlock &self, py::args pyArgTypes) {

View File

@ -634,6 +634,11 @@ void mlirBlockInsertOwnedOperationBefore(MlirBlock block,
void mlirBlockDestroy(MlirBlock block) { delete unwrap(block); }
void mlirBlockDetach(MlirBlock block) {
Block *b = unwrap(block);
b->getParent()->getBlocks().remove(b);
}
intptr_t mlirBlockGetNumArguments(MlirBlock block) {
return static_cast<intptr_t>(unwrap(block)->getNumArguments());
}

View File

@ -510,15 +510,18 @@ static void buildWithInsertionsAndPrint(MlirContext ctx) {
MlirType i2 = mlirIntegerTypeGet(ctx, 2);
MlirType i3 = mlirIntegerTypeGet(ctx, 3);
MlirType i4 = mlirIntegerTypeGet(ctx, 4);
MlirType i5 = mlirIntegerTypeGet(ctx, 5);
MlirBlock block1 = mlirBlockCreate(1, &i1, &loc);
MlirBlock block2 = mlirBlockCreate(1, &i2, &loc);
MlirBlock block3 = mlirBlockCreate(1, &i3, &loc);
MlirBlock block4 = mlirBlockCreate(1, &i4, &loc);
MlirBlock block5 = mlirBlockCreate(1, &i5, &loc);
// Insert blocks so as to obtain the 1-2-3-4 order,
mlirRegionInsertOwnedBlockBefore(region, nullBlock, block3);
mlirRegionInsertOwnedBlockBefore(region, block3, block2);
mlirRegionInsertOwnedBlockAfter(region, nullBlock, block1);
mlirRegionInsertOwnedBlockAfter(region, block3, block4);
mlirRegionInsertOwnedBlockBefore(region, block3, block5);
MlirOperationState op1State =
mlirOperationStateGet(mlirStringRefCreateFromCString("dummy.op1"), loc);
@ -534,6 +537,8 @@ static void buildWithInsertionsAndPrint(MlirContext ctx) {
mlirOperationStateGet(mlirStringRefCreateFromCString("dummy.op6"), loc);
MlirOperationState op7State =
mlirOperationStateGet(mlirStringRefCreateFromCString("dummy.op7"), loc);
MlirOperationState op8State =
mlirOperationStateGet(mlirStringRefCreateFromCString("dummy.op8"), loc);
MlirOperation op1 = mlirOperationCreate(&op1State);
MlirOperation op2 = mlirOperationCreate(&op2State);
MlirOperation op3 = mlirOperationCreate(&op3State);
@ -541,6 +546,7 @@ static void buildWithInsertionsAndPrint(MlirContext ctx) {
MlirOperation op5 = mlirOperationCreate(&op5State);
MlirOperation op6 = mlirOperationCreate(&op6State);
MlirOperation op7 = mlirOperationCreate(&op7State);
MlirOperation op8 = mlirOperationCreate(&op8State);
// Insert operations in the first block so as to obtain the 1-2-3-4 order.
MlirOperation nullOperation = mlirBlockGetFirstOperation(block1);
@ -555,6 +561,11 @@ static void buildWithInsertionsAndPrint(MlirContext ctx) {
mlirBlockAppendOwnedOperation(block2, op5);
mlirBlockAppendOwnedOperation(block3, op6);
mlirBlockAppendOwnedOperation(block4, op7);
mlirBlockAppendOwnedOperation(block5, op8);
// Remove block5.
mlirBlockDetach(block5);
mlirBlockDestroy(block5);
mlirOperationDump(op);
mlirOperationDestroy(op);
@ -568,6 +579,8 @@ static void buildWithInsertionsAndPrint(MlirContext ctx) {
// CHECK-NEXT: "dummy.op4"
// CHECK: ^{{.*}}(%{{.*}}: i2
// CHECK: "dummy.op5"
// CHECK-NOT: ^{{.*}}(%{{.*}}: i5
// CHECK-NOT: "dummy.op8"
// CHECK: ^{{.*}}(%{{.*}}: i3
// CHECK: "dummy.op6"
// CHECK: ^{{.*}}(%{{.*}}: i4

View File

@ -70,3 +70,27 @@ def testFirstBlockCreation():
print(module)
assert module.operation.verify()
assert f.body.blocks[0] == entry_block
# CHECK-LABEL: TEST: testBlockMove
# CHECK: %0 = "realop"() ({
# CHECK: ^bb0([[ARG0:%.+]]: f32):
# CHECK: "ret"([[ARG0]]) : (f32) -> ()
# CHECK: }) : () -> f32
@run
def testBlockMove():
with Context() as ctx, Location.unknown():
ctx.allow_unregistered_dialects = True
module = Module.create()
f32 = F32Type.get()
with InsertionPoint(module.body):
dummy = Operation.create("dummy", regions=1)
block = Block.create_at_start(dummy.operation.regions[0], [f32])
with InsertionPoint(block):
ret_op = Operation.create("ret", operands=[block.arguments[0]])
realop = Operation.create("realop",
results=[r.type for r in ret_op.operands],
regions=1)
block.append_to(realop.operation.regions[0])
dummy.operation.erase()
print(module)