[mlir] Add insert before/after to list-like constructs in C API

Blocks in a region and operations in a block are organized in a linked list.
The C API only provides functions to append or to insert elements at the
specified numeric position in the list. The latter is expensive since it
requires to traverse the list. Add insert before/after functionality with low
cost that relies on the iplist elements being convertible to iterators.

Reviewed By: stellaraccident

Differential Revision: https://reviews.llvm.org/D88148
This commit is contained in:
Alex Zinenko 2020-09-23 15:02:47 +02:00
parent 9abd1e8f4e
commit c538169ee9
3 changed files with 160 additions and 2 deletions

View File

@ -265,10 +265,23 @@ MlirBlock mlirRegionGetFirstBlock(MlirRegion region);
void mlirRegionAppendOwnedBlock(MlirRegion region, MlirBlock block);
/** Takes a block owned by the caller and inserts it at `pos` to the given
* region. */
* region. This is an expensive operation that linearly scans the region, prefer
* insertAfter/Before instead. */
void mlirRegionInsertOwnedBlock(MlirRegion region, intptr_t pos,
MlirBlock block);
/** Takes a block owned by the caller and inserts it after the (non-owned)
* reference block in the given region. The reference block must belong to the
* region. If the reference block is null, prepends the block to the region. */
void mlirRegionInsertOwnedBlockAfter(MlirRegion region, MlirBlock reference,
MlirBlock block);
/** Takes a block owned by the caller and inserts it before the (non-owned)
* reference block in the given region. The reference block must belong to the
* region. If the reference block is null, appends the block to the region. */
void mlirRegionInsertOwnedBlockBefore(MlirRegion region, MlirBlock reference,
MlirBlock block);
/*============================================================================*/
/* Block API. */
/*============================================================================*/
@ -294,10 +307,25 @@ MlirOperation mlirBlockGetFirstOperation(MlirBlock block);
void mlirBlockAppendOwnedOperation(MlirBlock block, MlirOperation operation);
/** Takes an operation owned by the caller and inserts it as `pos` to the block.
*/
This is an expensive operation that scans the block linearly, prefer
insertBefore/After instead. */
void mlirBlockInsertOwnedOperation(MlirBlock block, intptr_t pos,
MlirOperation operation);
/** Takes an operation owned by the caller and inserts it after the (non-owned)
* reference operation in the given block. If the reference is null, prepends
* the operation. Otherwise, the reference must belong to the block. */
void mlirBlockInsertOwnedOperationAfter(MlirBlock block,
MlirOperation reference,
MlirOperation operation);
/** Takes an operation owned by the caller and inserts it before the (non-owned)
* reference operation in the given block. If the reference is null, appends the
* operation. Otherwise, the reference must belong to the block. */
void mlirBlockInsertOwnedOperationBefore(MlirBlock block,
MlirOperation reference,
MlirOperation operation);
/** Returns the number of arguments of the block. */
intptr_t mlirBlockGetNumArguments(MlirBlock block);

View File

@ -268,6 +268,31 @@ void mlirRegionInsertOwnedBlock(MlirRegion region, intptr_t pos,
blockList.insert(std::next(blockList.begin(), pos), unwrap(block));
}
void mlirRegionInsertOwnedBlockAfter(MlirRegion region, MlirBlock reference,
MlirBlock block) {
Region *cppRegion = unwrap(region);
if (mlirBlockIsNull(reference)) {
cppRegion->getBlocks().insert(cppRegion->begin(), unwrap(block));
return;
}
assert(unwrap(reference)->getParent() == unwrap(region) &&
"expected reference block to belong to the region");
cppRegion->getBlocks().insertAfter(Region::iterator(unwrap(reference)),
unwrap(block));
}
void mlirRegionInsertOwnedBlockBefore(MlirRegion region, MlirBlock reference,
MlirBlock block) {
if (mlirBlockIsNull(reference))
return mlirRegionAppendOwnedBlock(region, block);
assert(unwrap(reference)->getParent() == unwrap(region) &&
"expected reference block to belong to the region");
unwrap(region)->getBlocks().insert(Region::iterator(unwrap(reference)),
unwrap(block));
}
void mlirRegionDestroy(MlirRegion region) {
delete static_cast<Region *>(region.ptr);
}
@ -306,6 +331,33 @@ void mlirBlockInsertOwnedOperation(MlirBlock block, intptr_t pos,
opList.insert(std::next(opList.begin(), pos), unwrap(operation));
}
void mlirBlockInsertOwnedOperationAfter(MlirBlock block,
MlirOperation reference,
MlirOperation operation) {
Block *cppBlock = unwrap(block);
if (mlirOperationIsNull(reference)) {
cppBlock->getOperations().insert(cppBlock->begin(), unwrap(operation));
return;
}
assert(unwrap(reference)->getBlock() == unwrap(block) &&
"expected reference operation to belong to the block");
cppBlock->getOperations().insertAfter(Block::iterator(unwrap(reference)),
unwrap(operation));
}
void mlirBlockInsertOwnedOperationBefore(MlirBlock block,
MlirOperation reference,
MlirOperation operation) {
if (mlirOperationIsNull(reference))
return mlirBlockAppendOwnedOperation(block, operation);
assert(unwrap(reference)->getBlock() == unwrap(block) &&
"expected reference operation to belong to the block");
unwrap(block)->getOperations().insert(Block::iterator(unwrap(reference)),
unwrap(operation));
}
void mlirBlockDestroy(MlirBlock block) { delete unwrap(block); }
int mlirBlockIsNull(MlirBlock block) { return unwrap(block) == nullptr; }

View File

@ -245,6 +245,68 @@ static void printFirstOfEach(MlirOperation operation) {
fprintf(stderr, "\n");
}
/// Creates an operation with a region containing multiple blocks with
/// operations and dumps it. The blocks and operations are inserted using
/// block/operation-relative API and their final order is checked.
static void buildWithInsertionsAndPrint(MlirContext ctx) {
MlirLocation loc = mlirLocationUnknownGet(ctx);
MlirRegion owningRegion = mlirRegionCreate();
MlirBlock nullBlock = mlirRegionGetFirstBlock(owningRegion);
MlirOperationState state = mlirOperationStateGet("insertion.order.test", loc);
mlirOperationStateAddOwnedRegions(&state, 1, &owningRegion);
MlirOperation op = mlirOperationCreate(&state);
MlirRegion region = mlirOperationGetRegion(op, 0);
// Use integer types of different bitwidth as block arguments in order to
// differentiate blocks.
MlirType i1 = mlirIntegerTypeGet(ctx, 1);
MlirType i2 = mlirIntegerTypeGet(ctx, 2);
MlirType i3 = mlirIntegerTypeGet(ctx, 3);
MlirType i4 = mlirIntegerTypeGet(ctx, 4);
MlirBlock block1 = mlirBlockCreate(1, &i1);
MlirBlock block2 = mlirBlockCreate(1, &i2);
MlirBlock block3 = mlirBlockCreate(1, &i3);
MlirBlock block4 = mlirBlockCreate(1, &i4);
// Insert blocks so as to obtain the 1-2-3-4 order,
mlirRegionInsertOwnedBlockBefore(region, nullBlock, block3);
mlirRegionInsertOwnedBlockBefore(region, block3, block2);
mlirRegionInsertOwnedBlockAfter(region, nullBlock, block1);
mlirRegionInsertOwnedBlockAfter(region, block3, block4);
MlirOperationState op1State = mlirOperationStateGet("dummy.op1", loc);
MlirOperationState op2State = mlirOperationStateGet("dummy.op2", loc);
MlirOperationState op3State = mlirOperationStateGet("dummy.op3", loc);
MlirOperationState op4State = mlirOperationStateGet("dummy.op4", loc);
MlirOperationState op5State = mlirOperationStateGet("dummy.op5", loc);
MlirOperationState op6State = mlirOperationStateGet("dummy.op6", loc);
MlirOperationState op7State = mlirOperationStateGet("dummy.op7", loc);
MlirOperation op1 = mlirOperationCreate(&op1State);
MlirOperation op2 = mlirOperationCreate(&op2State);
MlirOperation op3 = mlirOperationCreate(&op3State);
MlirOperation op4 = mlirOperationCreate(&op4State);
MlirOperation op5 = mlirOperationCreate(&op5State);
MlirOperation op6 = mlirOperationCreate(&op6State);
MlirOperation op7 = mlirOperationCreate(&op7State);
// Insert operations in the first block so as to obtain the 1-2-3-4 order.
MlirOperation nullOperation = mlirBlockGetFirstOperation(block1);
assert(mlirOperationIsNull(nullOperation));
mlirBlockInsertOwnedOperationBefore(block1, nullOperation, op3);
mlirBlockInsertOwnedOperationBefore(block1, op3, op2);
mlirBlockInsertOwnedOperationAfter(block1, nullOperation, op1);
mlirBlockInsertOwnedOperationAfter(block1, op3, op4);
// Append operations to the rest of blocks to make them non-empty and thus
// printable.
mlirBlockAppendOwnedOperation(block2, op5);
mlirBlockAppendOwnedOperation(block3, op6);
mlirBlockAppendOwnedOperation(block4, op7);
mlirOperationDump(op);
mlirOperationDestroy(op);
}
/// Dumps instances of all standard types to check that C API works correctly.
/// Additionally, performs simple identity checks that a standard type
/// constructed with C API can be inspected and has the expected type. The
@ -763,6 +825,22 @@ int main() {
mlirModuleDestroy(moduleOp);
buildWithInsertionsAndPrint(ctx);
// clang-format off
// CHECK-LABEL: "insertion.order.test"
// CHECK: ^{{.*}}(%{{.*}}: i1
// CHECK: "dummy.op1"
// CHECK-NEXT: "dummy.op2"
// CHECK-NEXT: "dummy.op3"
// CHECK-NEXT: "dummy.op4"
// CHECK: ^{{.*}}(%{{.*}}: i2
// CHECK: "dummy.op5"
// CHECK: ^{{.*}}(%{{.*}}: i3
// CHECK: "dummy.op6"
// CHECK: ^{{.*}}(%{{.*}}: i4
// CHECK: "dummy.op7"
// clang-format on
// clang-format off
// CHECK-LABEL: @types
// CHECK: i32