[mlir][CAPI] Attribute set/remove on operations.

* New functions: mlirOperationSetAttributeByName, mlirOperationRemoveAttributeByName
* Also adds some *IsNull checks and standardizes the rest to use "static inline" form, which makes them all non-opaque and not part of the ABI (which is desirable).
* Changes needed to resolve TODOs in npcomp PyTorch capture.

Differential Revision: https://reviews.llvm.org/D88946
This commit is contained in:
Stella Laurenzo 2020-10-06 23:01:20 -07:00
parent 3bba91f64e
commit 4aa217160e
3 changed files with 85 additions and 27 deletions

View File

@ -92,7 +92,9 @@ MlirContext mlirContextCreate();
int mlirContextEqual(MlirContext ctx1, MlirContext ctx2);
/** Checks whether a context is null. */
inline int mlirContextIsNull(MlirContext context) { return !context.ptr; }
static inline int mlirContextIsNull(MlirContext context) {
return !context.ptr;
}
/** Takes an MLIR context owned by the caller and destroys it. */
void mlirContextDestroy(MlirContext context);
@ -127,7 +129,9 @@ MlirDialect mlirContextGetOrLoadDialect(MlirContext context,
MlirContext mlirDialectGetContext(MlirDialect dialect);
/** Checks if the dialect is null. */
int mlirDialectIsNull(MlirDialect dialect);
static inline int mlirDialectIsNull(MlirDialect dialect) {
return !dialect.ptr;
}
/** Checks if two dialects that belong to the same context are equal. Dialects
* from different contexts will not compare equal. */
@ -171,7 +175,7 @@ MlirModule mlirModuleCreateParse(MlirContext context, const char *module);
MlirContext mlirModuleGetContext(MlirModule module);
/** Checks whether a module is null. */
inline int mlirModuleIsNull(MlirModule module) { return !module.ptr; }
static inline int mlirModuleIsNull(MlirModule module) { return !module.ptr; }
/** Takes a module owned by the caller and deletes it. */
void mlirModuleDestroy(MlirModule module);
@ -235,7 +239,7 @@ MlirOperation mlirOperationCreate(const MlirOperationState *state);
void mlirOperationDestroy(MlirOperation op);
/** Checks whether the underlying operation is null. */
int mlirOperationIsNull(MlirOperation op);
static inline int mlirOperationIsNull(MlirOperation op) { return !op.ptr; }
/** Returns the number of regions attached to the given operation. */
intptr_t mlirOperationGetNumRegions(MlirOperation op);
@ -275,6 +279,15 @@ MlirNamedAttribute mlirOperationGetAttribute(MlirOperation op, intptr_t pos);
MlirAttribute mlirOperationGetAttributeByName(MlirOperation op,
const char *name);
/** Sets an attribute by name, replacing the existing if it exists or
* adding a new one otherwise. */
void mlirOperationSetAttributeByName(MlirOperation op, const char *name,
MlirAttribute attr);
/** Removes an attribute by name. Returns 0 if the attribute was not found
* and !0 if removed. */
int mlirOperationRemoveAttributeByName(MlirOperation op, const char *name);
/** Prints an operation by sending chunks of the string representation and
* forwarding `userData to `callback`. Note that the callback may be called
* several times with consecutive chunks of the string. */
@ -295,7 +308,7 @@ MlirRegion mlirRegionCreate();
void mlirRegionDestroy(MlirRegion region);
/** Checks whether a region is null. */
int mlirRegionIsNull(MlirRegion region);
static inline int mlirRegionIsNull(MlirRegion region) { return !region.ptr; }
/** Gets the first block in the region. */
MlirBlock mlirRegionGetFirstBlock(MlirRegion region);
@ -333,7 +346,7 @@ MlirBlock mlirBlockCreate(intptr_t nArgs, MlirType *args);
void mlirBlockDestroy(MlirBlock block);
/** Checks whether a block is null. */
int mlirBlockIsNull(MlirBlock block);
static inline int mlirBlockIsNull(MlirBlock block) { return !block.ptr; }
/** Returns the block immediately following the given block in its parent
* region. */
@ -381,6 +394,9 @@ void mlirBlockPrint(MlirBlock block, MlirStringCallback callback,
/* Value API. */
/*============================================================================*/
/** Returns whether the value is null. */
static inline int mlirValueIsNull(MlirValue value) { return !value.ptr; }
/** Returns the type of the value. */
MlirType mlirValueGetType(MlirValue value);
@ -401,7 +417,7 @@ MlirType mlirTypeParseGet(MlirContext context, const char *type);
MlirContext mlirTypeGetContext(MlirType type);
/** Checks whether a type is null. */
inline int mlirTypeIsNull(MlirType type) { return !type.ptr; }
static inline int mlirTypeIsNull(MlirType type) { return !type.ptr; }
/** Checks if two types are equal. */
int mlirTypeEqual(MlirType t1, MlirType t2);
@ -425,7 +441,7 @@ MlirAttribute mlirAttributeParseGet(MlirContext context, const char *attr);
MlirContext mlirAttributeGetContext(MlirAttribute attribute);
/** Checks whether an attribute is null. */
inline int mlirAttributeIsNull(MlirAttribute attr) { return !attr.ptr; }
static inline int mlirAttributeIsNull(MlirAttribute attr) { return !attr.ptr; }
/** Checks if two attributes are equal. */
int mlirAttributeEqual(MlirAttribute a1, MlirAttribute a2);

View File

@ -66,10 +66,6 @@ MlirContext mlirDialectGetContext(MlirDialect dialect) {
return wrap(unwrap(dialect)->getContext());
}
int mlirDialectIsNull(MlirDialect dialect) {
return unwrap(dialect) == nullptr;
}
int mlirDialectEqual(MlirDialect dialect1, MlirDialect dialect2) {
return unwrap(dialect1) == unwrap(dialect2);
}
@ -215,8 +211,6 @@ MlirOperation mlirOperationCreate(const MlirOperationState *state) {
void mlirOperationDestroy(MlirOperation op) { unwrap(op)->erase(); }
int mlirOperationIsNull(MlirOperation op) { return unwrap(op) == nullptr; }
intptr_t mlirOperationGetNumRegions(MlirOperation op) {
return static_cast<intptr_t>(unwrap(op)->getNumRegions());
}
@ -267,6 +261,16 @@ MlirAttribute mlirOperationGetAttributeByName(MlirOperation op,
return wrap(unwrap(op)->getAttr(name));
}
void mlirOperationSetAttributeByName(MlirOperation op, const char *name,
MlirAttribute attr) {
unwrap(op)->setAttr(name, unwrap(attr));
}
int mlirOperationRemoveAttributeByName(MlirOperation op, const char *name) {
auto removeResult = unwrap(op)->removeAttr(name);
return removeResult == MutableDictionaryAttr::RemoveResult::Removed;
}
void mlirOperationPrint(MlirOperation op, MlirStringCallback callback,
void *userData) {
detail::CallbackOstream stream(callback, userData);
@ -328,8 +332,6 @@ void mlirRegionDestroy(MlirRegion region) {
delete static_cast<Region *>(region.ptr);
}
int mlirRegionIsNull(MlirRegion region) { return unwrap(region) == nullptr; }
/* ========================================================================== */
/* Block API. */
/* ========================================================================== */
@ -391,8 +393,6 @@ void mlirBlockInsertOwnedOperationBefore(MlirBlock block,
void mlirBlockDestroy(MlirBlock block) { delete unwrap(block); }
int mlirBlockIsNull(MlirBlock block) { return unwrap(block) == nullptr; }
intptr_t mlirBlockGetNumArguments(MlirBlock block) {
return static_cast<intptr_t>(unwrap(block)->getNumArguments());
}

View File

@ -216,7 +216,7 @@ static void printToStderr(const char *str, intptr_t len, void *userData) {
fwrite(str, 1, len, stderr);
}
static void printFirstOfEach(MlirOperation operation) {
static void printFirstOfEach(MlirContext ctx, MlirOperation operation) {
// Assuming we are given a module, go to the first operation of the first
// function.
MlirRegion region = mlirOperationGetRegion(operation, 0);
@ -227,24 +227,59 @@ static void printFirstOfEach(MlirOperation operation) {
operation = mlirBlockGetFirstOperation(block);
// In the module we created, the first operation of the first function is an
// "std.dim", which has an attribute an a single result that we can use to
// "std.dim", which has an attribute and a single result that we can use to
// test the printing mechanism.
mlirBlockPrint(block, printToStderr, NULL);
fprintf(stderr, "\n");
fprintf(stderr, "First operation: ");
mlirOperationPrint(operation, printToStderr, NULL);
fprintf(stderr, "\n");
MlirNamedAttribute namedAttr = mlirOperationGetAttribute(operation, 0);
mlirAttributePrint(namedAttr.attribute, printToStderr, NULL);
// Get the attribute by index.
MlirNamedAttribute namedAttr0 = mlirOperationGetAttribute(operation, 0);
fprintf(stderr, "Get attr 0: ");
mlirAttributePrint(namedAttr0.attribute, printToStderr, NULL);
fprintf(stderr, "\n");
// Now re-get the attribute by name.
MlirAttribute attr0ByName =
mlirOperationGetAttributeByName(operation, namedAttr0.name);
fprintf(stderr, "Get attr 0 by name: ");
mlirAttributePrint(attr0ByName, printToStderr, NULL);
fprintf(stderr, "\n");
// Get a non-existing attribute and assert that it is null (sanity).
fprintf(stderr, "does_not_exist is null: %d\n",
mlirAttributeIsNull(
mlirOperationGetAttributeByName(operation, "does_not_exist")));
// Get result 0 and its type.
MlirValue value = mlirOperationGetResult(operation, 0);
fprintf(stderr, "Result 0: ");
mlirValuePrint(value, printToStderr, NULL);
fprintf(stderr, "\n");
fprintf(stderr, "Value is null: %d\n", mlirValueIsNull(value));
MlirType type = mlirValueGetType(value);
fprintf(stderr, "Result 0 type: ");
mlirTypePrint(type, printToStderr, NULL);
fprintf(stderr, "\n");
// Set a custom attribute.
mlirOperationSetAttributeByName(operation, "custom_attr",
mlirBoolAttrGet(ctx, 1));
fprintf(stderr, "Op with set attr: ");
mlirOperationPrint(operation, printToStderr, NULL);
fprintf(stderr, "\n");
// Remove the attribute.
fprintf(stderr, "Remove attr: %d\n",
mlirOperationRemoveAttributeByName(operation, "custom_attr"));
fprintf(stderr, "Remove attr again: %d\n",
mlirOperationRemoveAttributeByName(operation, "custom_attr"));
fprintf(stderr, "Removed attr is null: %d\n",
mlirAttributeIsNull(
mlirOperationGetAttributeByName(operation, "custom_attr")));
}
/// Creates an operation with a region containing multiple blocks with
@ -884,7 +919,7 @@ int main() {
// CHECK: Number of values: 9
// clang-format on
printFirstOfEach(module);
printFirstOfEach(ctx, module);
// clang-format off
// CHECK: %[[C0:.*]] = constant 0 : index
// CHECK: %[[DIM:.*]] = dim %{{.*}}, %[[C0]] : memref<?xf32>
@ -896,10 +931,17 @@ int main() {
// CHECK: store %[[SUM]], %{{.*}}[%[[I]]] : memref<?xf32>
// CHECK: }
// CHECK: return
// CHECK: constant 0 : index
// CHECK: 0 : index
// CHECK: constant 0 : index
// CHECK: index
// CHECK: First operation: {{.*}} = constant 0 : index
// CHECK: Get attr 0: 0 : index
// CHECK: Get attr 0 by name: 0 : index
// CHECK: does_not_exist is null: 1
// CHECK: Result 0: {{.*}} = constant 0 : index
// CHECK: Value is null: 0
// CHECK: Result 0 type: index
// CHECK: Op with set attr: {{.*}} {custom_attr = true}
// CHECK: Remove attr: 1
// CHECK: Remove attr again: 0
// CHECK: Removed attr is null: 1
// clang-format on
mlirModuleDestroy(moduleOp);