From c538169ee99516c178ecc00a5ec5187d78941fac Mon Sep 17 00:00:00 2001 From: Alex Zinenko Date: Wed, 23 Sep 2020 15:02:47 +0200 Subject: [PATCH] [mlir] Add insert before/after to list-like constructs in C API Blocks in a region and operations in a block are organized in a linked list. The C API only provides functions to append or to insert elements at the specified numeric position in the list. The latter is expensive since it requires to traverse the list. Add insert before/after functionality with low cost that relies on the iplist elements being convertible to iterators. Reviewed By: stellaraccident Differential Revision: https://reviews.llvm.org/D88148 --- mlir/include/mlir-c/IR.h | 32 +++++++++++++++-- mlir/lib/CAPI/IR/IR.cpp | 52 +++++++++++++++++++++++++++ mlir/test/CAPI/ir.c | 78 ++++++++++++++++++++++++++++++++++++++++ 3 files changed, 160 insertions(+), 2 deletions(-) diff --git a/mlir/include/mlir-c/IR.h b/mlir/include/mlir-c/IR.h index b9c5bec3aa44..4aca261868f3 100644 --- a/mlir/include/mlir-c/IR.h +++ b/mlir/include/mlir-c/IR.h @@ -265,10 +265,23 @@ MlirBlock mlirRegionGetFirstBlock(MlirRegion region); void mlirRegionAppendOwnedBlock(MlirRegion region, MlirBlock block); /** Takes a block owned by the caller and inserts it at `pos` to the given - * region. */ + * region. This is an expensive operation that linearly scans the region, prefer + * insertAfter/Before instead. */ void mlirRegionInsertOwnedBlock(MlirRegion region, intptr_t pos, MlirBlock block); +/** Takes a block owned by the caller and inserts it after the (non-owned) + * reference block in the given region. The reference block must belong to the + * region. If the reference block is null, prepends the block to the region. */ +void mlirRegionInsertOwnedBlockAfter(MlirRegion region, MlirBlock reference, + MlirBlock block); + +/** Takes a block owned by the caller and inserts it before the (non-owned) + * reference block in the given region. The reference block must belong to the + * region. If the reference block is null, appends the block to the region. */ +void mlirRegionInsertOwnedBlockBefore(MlirRegion region, MlirBlock reference, + MlirBlock block); + /*============================================================================*/ /* Block API. */ /*============================================================================*/ @@ -294,10 +307,25 @@ MlirOperation mlirBlockGetFirstOperation(MlirBlock block); void mlirBlockAppendOwnedOperation(MlirBlock block, MlirOperation operation); /** Takes an operation owned by the caller and inserts it as `pos` to the block. - */ + This is an expensive operation that scans the block linearly, prefer + insertBefore/After instead. */ void mlirBlockInsertOwnedOperation(MlirBlock block, intptr_t pos, MlirOperation operation); +/** Takes an operation owned by the caller and inserts it after the (non-owned) + * reference operation in the given block. If the reference is null, prepends + * the operation. Otherwise, the reference must belong to the block. */ +void mlirBlockInsertOwnedOperationAfter(MlirBlock block, + MlirOperation reference, + MlirOperation operation); + +/** Takes an operation owned by the caller and inserts it before the (non-owned) + * reference operation in the given block. If the reference is null, appends the + * operation. Otherwise, the reference must belong to the block. */ +void mlirBlockInsertOwnedOperationBefore(MlirBlock block, + MlirOperation reference, + MlirOperation operation); + /** Returns the number of arguments of the block. */ intptr_t mlirBlockGetNumArguments(MlirBlock block); diff --git a/mlir/lib/CAPI/IR/IR.cpp b/mlir/lib/CAPI/IR/IR.cpp index 2265df1c8234..3f5c7cf8986c 100644 --- a/mlir/lib/CAPI/IR/IR.cpp +++ b/mlir/lib/CAPI/IR/IR.cpp @@ -268,6 +268,31 @@ void mlirRegionInsertOwnedBlock(MlirRegion region, intptr_t pos, blockList.insert(std::next(blockList.begin(), pos), unwrap(block)); } +void mlirRegionInsertOwnedBlockAfter(MlirRegion region, MlirBlock reference, + MlirBlock block) { + Region *cppRegion = unwrap(region); + if (mlirBlockIsNull(reference)) { + cppRegion->getBlocks().insert(cppRegion->begin(), unwrap(block)); + return; + } + + assert(unwrap(reference)->getParent() == unwrap(region) && + "expected reference block to belong to the region"); + cppRegion->getBlocks().insertAfter(Region::iterator(unwrap(reference)), + unwrap(block)); +} + +void mlirRegionInsertOwnedBlockBefore(MlirRegion region, MlirBlock reference, + MlirBlock block) { + if (mlirBlockIsNull(reference)) + return mlirRegionAppendOwnedBlock(region, block); + + assert(unwrap(reference)->getParent() == unwrap(region) && + "expected reference block to belong to the region"); + unwrap(region)->getBlocks().insert(Region::iterator(unwrap(reference)), + unwrap(block)); +} + void mlirRegionDestroy(MlirRegion region) { delete static_cast(region.ptr); } @@ -306,6 +331,33 @@ void mlirBlockInsertOwnedOperation(MlirBlock block, intptr_t pos, opList.insert(std::next(opList.begin(), pos), unwrap(operation)); } +void mlirBlockInsertOwnedOperationAfter(MlirBlock block, + MlirOperation reference, + MlirOperation operation) { + Block *cppBlock = unwrap(block); + if (mlirOperationIsNull(reference)) { + cppBlock->getOperations().insert(cppBlock->begin(), unwrap(operation)); + return; + } + + assert(unwrap(reference)->getBlock() == unwrap(block) && + "expected reference operation to belong to the block"); + cppBlock->getOperations().insertAfter(Block::iterator(unwrap(reference)), + unwrap(operation)); +} + +void mlirBlockInsertOwnedOperationBefore(MlirBlock block, + MlirOperation reference, + MlirOperation operation) { + if (mlirOperationIsNull(reference)) + return mlirBlockAppendOwnedOperation(block, operation); + + assert(unwrap(reference)->getBlock() == unwrap(block) && + "expected reference operation to belong to the block"); + unwrap(block)->getOperations().insert(Block::iterator(unwrap(reference)), + unwrap(operation)); +} + void mlirBlockDestroy(MlirBlock block) { delete unwrap(block); } int mlirBlockIsNull(MlirBlock block) { return unwrap(block) == nullptr; } diff --git a/mlir/test/CAPI/ir.c b/mlir/test/CAPI/ir.c index 01b007e71783..4849111986cd 100644 --- a/mlir/test/CAPI/ir.c +++ b/mlir/test/CAPI/ir.c @@ -245,6 +245,68 @@ static void printFirstOfEach(MlirOperation operation) { fprintf(stderr, "\n"); } +/// Creates an operation with a region containing multiple blocks with +/// operations and dumps it. The blocks and operations are inserted using +/// block/operation-relative API and their final order is checked. +static void buildWithInsertionsAndPrint(MlirContext ctx) { + MlirLocation loc = mlirLocationUnknownGet(ctx); + + MlirRegion owningRegion = mlirRegionCreate(); + MlirBlock nullBlock = mlirRegionGetFirstBlock(owningRegion); + MlirOperationState state = mlirOperationStateGet("insertion.order.test", loc); + mlirOperationStateAddOwnedRegions(&state, 1, &owningRegion); + MlirOperation op = mlirOperationCreate(&state); + MlirRegion region = mlirOperationGetRegion(op, 0); + + // Use integer types of different bitwidth as block arguments in order to + // differentiate blocks. + MlirType i1 = mlirIntegerTypeGet(ctx, 1); + MlirType i2 = mlirIntegerTypeGet(ctx, 2); + MlirType i3 = mlirIntegerTypeGet(ctx, 3); + MlirType i4 = mlirIntegerTypeGet(ctx, 4); + MlirBlock block1 = mlirBlockCreate(1, &i1); + MlirBlock block2 = mlirBlockCreate(1, &i2); + MlirBlock block3 = mlirBlockCreate(1, &i3); + MlirBlock block4 = mlirBlockCreate(1, &i4); + // Insert blocks so as to obtain the 1-2-3-4 order, + mlirRegionInsertOwnedBlockBefore(region, nullBlock, block3); + mlirRegionInsertOwnedBlockBefore(region, block3, block2); + mlirRegionInsertOwnedBlockAfter(region, nullBlock, block1); + mlirRegionInsertOwnedBlockAfter(region, block3, block4); + + MlirOperationState op1State = mlirOperationStateGet("dummy.op1", loc); + MlirOperationState op2State = mlirOperationStateGet("dummy.op2", loc); + MlirOperationState op3State = mlirOperationStateGet("dummy.op3", loc); + MlirOperationState op4State = mlirOperationStateGet("dummy.op4", loc); + MlirOperationState op5State = mlirOperationStateGet("dummy.op5", loc); + MlirOperationState op6State = mlirOperationStateGet("dummy.op6", loc); + MlirOperationState op7State = mlirOperationStateGet("dummy.op7", loc); + MlirOperation op1 = mlirOperationCreate(&op1State); + MlirOperation op2 = mlirOperationCreate(&op2State); + MlirOperation op3 = mlirOperationCreate(&op3State); + MlirOperation op4 = mlirOperationCreate(&op4State); + MlirOperation op5 = mlirOperationCreate(&op5State); + MlirOperation op6 = mlirOperationCreate(&op6State); + MlirOperation op7 = mlirOperationCreate(&op7State); + + // Insert operations in the first block so as to obtain the 1-2-3-4 order. + MlirOperation nullOperation = mlirBlockGetFirstOperation(block1); + assert(mlirOperationIsNull(nullOperation)); + mlirBlockInsertOwnedOperationBefore(block1, nullOperation, op3); + mlirBlockInsertOwnedOperationBefore(block1, op3, op2); + mlirBlockInsertOwnedOperationAfter(block1, nullOperation, op1); + mlirBlockInsertOwnedOperationAfter(block1, op3, op4); + + // Append operations to the rest of blocks to make them non-empty and thus + // printable. + mlirBlockAppendOwnedOperation(block2, op5); + mlirBlockAppendOwnedOperation(block3, op6); + mlirBlockAppendOwnedOperation(block4, op7); + + mlirOperationDump(op); + mlirOperationDestroy(op); +} + /// Dumps instances of all standard types to check that C API works correctly. /// Additionally, performs simple identity checks that a standard type /// constructed with C API can be inspected and has the expected type. The @@ -763,6 +825,22 @@ int main() { mlirModuleDestroy(moduleOp); + buildWithInsertionsAndPrint(ctx); + // clang-format off + // CHECK-LABEL: "insertion.order.test" + // CHECK: ^{{.*}}(%{{.*}}: i1 + // CHECK: "dummy.op1" + // CHECK-NEXT: "dummy.op2" + // CHECK-NEXT: "dummy.op3" + // CHECK-NEXT: "dummy.op4" + // CHECK: ^{{.*}}(%{{.*}}: i2 + // CHECK: "dummy.op5" + // CHECK: ^{{.*}}(%{{.*}}: i3 + // CHECK: "dummy.op6" + // CHECK: ^{{.*}}(%{{.*}}: i4 + // CHECK: "dummy.op7" + // clang-format on + // clang-format off // CHECK-LABEL: @types // CHECK: i32