[mlir] Expose Value hierarchy to C API

The Value hierarchy consists of BlockArgument and OpResult, both of which
derive Value. Introduce IsA functions and functions specific to each class,
similarly to other class hierarchies. Also, introduce functions for
pointer-comparison of Block and Operation that are necessary for testing and
are generally useful.

Reviewed By: stellaraccident, mehdi_amini

Differential Revision: https://reviews.llvm.org/D89714
This commit is contained in:
Alex Zinenko 2020-10-19 19:17:51 +02:00
parent 595c615606
commit 39613c2cbc
3 changed files with 121 additions and 6 deletions

View File

@ -241,6 +241,10 @@ void mlirOperationDestroy(MlirOperation op);
/** Checks whether the underlying operation is null. */
static inline int mlirOperationIsNull(MlirOperation op) { return !op.ptr; }
/** Checks whether two operation handles point to the same operation. This does
* not perform deep comparison. */
int mlirOperationEqual(MlirOperation op, MlirOperation other);
/** Returns the number of regions attached to the given operation. */
intptr_t mlirOperationGetNumRegions(MlirOperation op);
@ -348,6 +352,10 @@ void mlirBlockDestroy(MlirBlock block);
/** Checks whether a block is null. */
static inline int mlirBlockIsNull(MlirBlock block) { return !block.ptr; }
/** Checks whether two blocks handles point to the same block. This does not
* perform deep comparison. */
int mlirBlockEqual(MlirBlock block, MlirBlock other);
/** Returns the block immediately following the given block in its parent
* region. */
MlirBlock mlirBlockGetNextInRegion(MlirBlock block);
@ -397,6 +405,30 @@ void mlirBlockPrint(MlirBlock block, MlirStringCallback callback,
/** Returns whether the value is null. */
static inline int mlirValueIsNull(MlirValue value) { return !value.ptr; }
/** Returns 1 if the value is a block argument, 0 otherwise. */
int mlirValueIsABlockArgument(MlirValue value);
/** Returns 1 if the value is an operation result, 0 otherwise. */
int mlirValueIsAOpResult(MlirValue value);
/** Returns the block in which this value is defined as an argument. Asserts if
* the value is not a block argument. */
MlirBlock mlirBlockArgumentGetOwner(MlirValue value);
/** Returns the position of the value in the argument list of its block. */
intptr_t mlirBlockArgumentGetArgNumber(MlirValue value);
/** Sets the type of the block argument to the given type. */
void mlirBlockArgumentSetType(MlirValue value, MlirType type);
/** Returns an operation that produced this value as its result. Asserts if the
* value is not an op result. */
MlirOperation mlirOpResultGetOwner(MlirValue value);
/** Returns the position of the value in the list of results of the operation
* that produced it. */
intptr_t mlirOpResultGetResultNumber(MlirValue value);
/** Returns the type of the value. */
MlirType mlirValueGetType(MlirValue value);

View File

@ -211,6 +211,10 @@ MlirOperation mlirOperationCreate(const MlirOperationState *state) {
void mlirOperationDestroy(MlirOperation op) { unwrap(op)->erase(); }
int mlirOperationEqual(MlirOperation op, MlirOperation other) {
return unwrap(op) == unwrap(other);
}
intptr_t mlirOperationGetNumRegions(MlirOperation op) {
return static_cast<intptr_t>(unwrap(op)->getNumRegions());
}
@ -343,6 +347,10 @@ MlirBlock mlirBlockCreate(intptr_t nArgs, MlirType *args) {
return wrap(b);
}
int mlirBlockEqual(MlirBlock block, MlirBlock other) {
return unwrap(block) == unwrap(other);
}
MlirBlock mlirBlockGetNextInRegion(MlirBlock block) {
return wrap(unwrap(block)->getNextNode());
}
@ -412,6 +420,36 @@ void mlirBlockPrint(MlirBlock block, MlirStringCallback callback,
/* Value API. */
/* ========================================================================== */
int mlirValueIsABlockArgument(MlirValue value) {
return unwrap(value).isa<BlockArgument>();
}
int mlirValueIsAOpResult(MlirValue value) {
return unwrap(value).isa<OpResult>();
}
MlirBlock mlirBlockArgumentGetOwner(MlirValue value) {
return wrap(unwrap(value).cast<BlockArgument>().getOwner());
}
intptr_t mlirBlockArgumentGetArgNumber(MlirValue value) {
return static_cast<intptr_t>(
unwrap(value).cast<BlockArgument>().getArgNumber());
}
void mlirBlockArgumentSetType(MlirValue value, MlirType type) {
unwrap(value).cast<BlockArgument>().setType(unwrap(type));
}
MlirOperation mlirOpResultGetOwner(MlirValue value) {
return wrap(unwrap(value).cast<OpResult>().getOwner());
}
intptr_t mlirOpResultGetResultNumber(MlirValue value) {
return static_cast<intptr_t>(
unwrap(value).cast<OpResult>().getResultNumber());
}
MlirType mlirValueGetType(MlirValue value) {
return wrap(unwrap(value).getType());
}

View File

@ -153,10 +153,12 @@ struct ModuleStats {
unsigned numBlocks;
unsigned numRegions;
unsigned numValues;
unsigned numBlockArguments;
unsigned numOpResults;
};
typedef struct ModuleStats ModuleStats;
void collectStatsSingle(OpListNode *head, ModuleStats *stats) {
int collectStatsSingle(OpListNode *head, ModuleStats *stats) {
MlirOperation operation = head->op;
stats->numOperations += 1;
stats->numValues += mlirOperationGetNumResults(operation);
@ -166,12 +168,39 @@ void collectStatsSingle(OpListNode *head, ModuleStats *stats) {
stats->numRegions += numRegions;
intptr_t numResults = mlirOperationGetNumResults(operation);
for (intptr_t i = 0; i < numResults; ++i) {
MlirValue result = mlirOperationGetResult(operation, i);
if (!mlirValueIsAOpResult(result))
return 1;
if (mlirValueIsABlockArgument(result))
return 2;
if (!mlirOperationEqual(operation, mlirOpResultGetOwner(result)))
return 3;
if (i != mlirOpResultGetResultNumber(result))
return 4;
++stats->numOpResults;
}
for (unsigned i = 0; i < numRegions; ++i) {
MlirRegion region = mlirOperationGetRegion(operation, i);
for (MlirBlock block = mlirRegionGetFirstBlock(region);
!mlirBlockIsNull(block); block = mlirBlockGetNextInRegion(block)) {
++stats->numBlocks;
stats->numValues += mlirBlockGetNumArguments(block);
intptr_t numArgs = mlirBlockGetNumArguments(block);
stats->numValues += numArgs;
for (intptr_t j = 0; j < numArgs; ++j) {
MlirValue arg = mlirBlockGetArgument(block, j);
if (!mlirValueIsABlockArgument(arg))
return 5;
if (mlirValueIsAOpResult(arg))
return 6;
if (!mlirBlockEqual(block, mlirBlockArgumentGetOwner(arg)))
return 7;
if (j != mlirBlockArgumentGetArgNumber(arg))
return 8;
++stats->numBlockArguments;
}
for (MlirOperation child = mlirBlockGetFirstOperation(block);
!mlirOperationIsNull(child);
@ -183,9 +212,10 @@ void collectStatsSingle(OpListNode *head, ModuleStats *stats) {
}
}
}
return 0;
}
void collectStats(MlirOperation operation) {
int collectStats(MlirOperation operation) {
OpListNode *head = malloc(sizeof(OpListNode));
head->op = operation;
head->next = NULL;
@ -196,9 +226,13 @@ void collectStats(MlirOperation operation) {
stats.numBlocks = 0;
stats.numRegions = 0;
stats.numValues = 0;
stats.numBlockArguments = 0;
stats.numOpResults = 0;
do {
collectStatsSingle(head, &stats);
int retval = collectStatsSingle(head, &stats);
if (retval)
return retval;
OpListNode *next = head->next;
free(head);
head = next;
@ -209,6 +243,11 @@ void collectStats(MlirOperation operation) {
fprintf(stderr, "Number of blocks: %u\n", stats.numBlocks);
fprintf(stderr, "Number of regions: %u\n", stats.numRegions);
fprintf(stderr, "Number of values: %u\n", stats.numValues);
fprintf(stderr, "Number of block arguments: %u\n", stats.numBlockArguments);
fprintf(stderr, "Number of op results: %u\n", stats.numOpResults);
if (stats.numValues != stats.numBlockArguments + stats.numOpResults)
return 100;
return 0;
}
static void printToStderr(const char *str, intptr_t len, void *userData) {
@ -914,13 +953,19 @@ int main() {
// CHECK: }
// clang-format on
collectStats(module);
fprintf(stderr, "@stats\n");
int errcode = collectStats(module);
fprintf(stderr, "%d\n", errcode);
// clang-format off
// CHECK-LABEL: @stats
// CHECK: Number of operations: 13
// CHECK: Number of attributes: 4
// CHECK: Number of blocks: 3
// CHECK: Number of regions: 3
// CHECK: Number of values: 9
// CHECK: Number of block arguments: 3
// CHECK: Number of op results: 6
// CHECK: 0
// clang-format on
printFirstOfEach(ctx, module);
@ -988,7 +1033,7 @@ int main() {
// CHECK: 0
// clang-format on
fprintf(stderr, "@types\n");
int errcode = printStandardTypes(ctx);
errcode = printStandardTypes(ctx);
fprintf(stderr, "%d\n", errcode);
// clang-format off