forked from OSchip/llvm-project
[mlir][CAPI] Attribute set/remove on operations.
* New functions: mlirOperationSetAttributeByName, mlirOperationRemoveAttributeByName * Also adds some *IsNull checks and standardizes the rest to use "static inline" form, which makes them all non-opaque and not part of the ABI (which is desirable). * Changes needed to resolve TODOs in npcomp PyTorch capture. Differential Revision: https://reviews.llvm.org/D88946
This commit is contained in:
parent
3bba91f64e
commit
4aa217160e
|
@ -92,7 +92,9 @@ MlirContext mlirContextCreate();
|
|||
int mlirContextEqual(MlirContext ctx1, MlirContext ctx2);
|
||||
|
||||
/** Checks whether a context is null. */
|
||||
inline int mlirContextIsNull(MlirContext context) { return !context.ptr; }
|
||||
static inline int mlirContextIsNull(MlirContext context) {
|
||||
return !context.ptr;
|
||||
}
|
||||
|
||||
/** Takes an MLIR context owned by the caller and destroys it. */
|
||||
void mlirContextDestroy(MlirContext context);
|
||||
|
@ -127,7 +129,9 @@ MlirDialect mlirContextGetOrLoadDialect(MlirContext context,
|
|||
MlirContext mlirDialectGetContext(MlirDialect dialect);
|
||||
|
||||
/** Checks if the dialect is null. */
|
||||
int mlirDialectIsNull(MlirDialect dialect);
|
||||
static inline int mlirDialectIsNull(MlirDialect dialect) {
|
||||
return !dialect.ptr;
|
||||
}
|
||||
|
||||
/** Checks if two dialects that belong to the same context are equal. Dialects
|
||||
* from different contexts will not compare equal. */
|
||||
|
@ -171,7 +175,7 @@ MlirModule mlirModuleCreateParse(MlirContext context, const char *module);
|
|||
MlirContext mlirModuleGetContext(MlirModule module);
|
||||
|
||||
/** Checks whether a module is null. */
|
||||
inline int mlirModuleIsNull(MlirModule module) { return !module.ptr; }
|
||||
static inline int mlirModuleIsNull(MlirModule module) { return !module.ptr; }
|
||||
|
||||
/** Takes a module owned by the caller and deletes it. */
|
||||
void mlirModuleDestroy(MlirModule module);
|
||||
|
@ -235,7 +239,7 @@ MlirOperation mlirOperationCreate(const MlirOperationState *state);
|
|||
void mlirOperationDestroy(MlirOperation op);
|
||||
|
||||
/** Checks whether the underlying operation is null. */
|
||||
int mlirOperationIsNull(MlirOperation op);
|
||||
static inline int mlirOperationIsNull(MlirOperation op) { return !op.ptr; }
|
||||
|
||||
/** Returns the number of regions attached to the given operation. */
|
||||
intptr_t mlirOperationGetNumRegions(MlirOperation op);
|
||||
|
@ -275,6 +279,15 @@ MlirNamedAttribute mlirOperationGetAttribute(MlirOperation op, intptr_t pos);
|
|||
MlirAttribute mlirOperationGetAttributeByName(MlirOperation op,
|
||||
const char *name);
|
||||
|
||||
/** Sets an attribute by name, replacing the existing if it exists or
|
||||
* adding a new one otherwise. */
|
||||
void mlirOperationSetAttributeByName(MlirOperation op, const char *name,
|
||||
MlirAttribute attr);
|
||||
|
||||
/** Removes an attribute by name. Returns 0 if the attribute was not found
|
||||
* and !0 if removed. */
|
||||
int mlirOperationRemoveAttributeByName(MlirOperation op, const char *name);
|
||||
|
||||
/** Prints an operation by sending chunks of the string representation and
|
||||
* forwarding `userData to `callback`. Note that the callback may be called
|
||||
* several times with consecutive chunks of the string. */
|
||||
|
@ -295,7 +308,7 @@ MlirRegion mlirRegionCreate();
|
|||
void mlirRegionDestroy(MlirRegion region);
|
||||
|
||||
/** Checks whether a region is null. */
|
||||
int mlirRegionIsNull(MlirRegion region);
|
||||
static inline int mlirRegionIsNull(MlirRegion region) { return !region.ptr; }
|
||||
|
||||
/** Gets the first block in the region. */
|
||||
MlirBlock mlirRegionGetFirstBlock(MlirRegion region);
|
||||
|
@ -333,7 +346,7 @@ MlirBlock mlirBlockCreate(intptr_t nArgs, MlirType *args);
|
|||
void mlirBlockDestroy(MlirBlock block);
|
||||
|
||||
/** Checks whether a block is null. */
|
||||
int mlirBlockIsNull(MlirBlock block);
|
||||
static inline int mlirBlockIsNull(MlirBlock block) { return !block.ptr; }
|
||||
|
||||
/** Returns the block immediately following the given block in its parent
|
||||
* region. */
|
||||
|
@ -381,6 +394,9 @@ void mlirBlockPrint(MlirBlock block, MlirStringCallback callback,
|
|||
/* Value API. */
|
||||
/*============================================================================*/
|
||||
|
||||
/** Returns whether the value is null. */
|
||||
static inline int mlirValueIsNull(MlirValue value) { return !value.ptr; }
|
||||
|
||||
/** Returns the type of the value. */
|
||||
MlirType mlirValueGetType(MlirValue value);
|
||||
|
||||
|
@ -401,7 +417,7 @@ MlirType mlirTypeParseGet(MlirContext context, const char *type);
|
|||
MlirContext mlirTypeGetContext(MlirType type);
|
||||
|
||||
/** Checks whether a type is null. */
|
||||
inline int mlirTypeIsNull(MlirType type) { return !type.ptr; }
|
||||
static inline int mlirTypeIsNull(MlirType type) { return !type.ptr; }
|
||||
|
||||
/** Checks if two types are equal. */
|
||||
int mlirTypeEqual(MlirType t1, MlirType t2);
|
||||
|
@ -425,7 +441,7 @@ MlirAttribute mlirAttributeParseGet(MlirContext context, const char *attr);
|
|||
MlirContext mlirAttributeGetContext(MlirAttribute attribute);
|
||||
|
||||
/** Checks whether an attribute is null. */
|
||||
inline int mlirAttributeIsNull(MlirAttribute attr) { return !attr.ptr; }
|
||||
static inline int mlirAttributeIsNull(MlirAttribute attr) { return !attr.ptr; }
|
||||
|
||||
/** Checks if two attributes are equal. */
|
||||
int mlirAttributeEqual(MlirAttribute a1, MlirAttribute a2);
|
||||
|
|
|
@ -66,10 +66,6 @@ MlirContext mlirDialectGetContext(MlirDialect dialect) {
|
|||
return wrap(unwrap(dialect)->getContext());
|
||||
}
|
||||
|
||||
int mlirDialectIsNull(MlirDialect dialect) {
|
||||
return unwrap(dialect) == nullptr;
|
||||
}
|
||||
|
||||
int mlirDialectEqual(MlirDialect dialect1, MlirDialect dialect2) {
|
||||
return unwrap(dialect1) == unwrap(dialect2);
|
||||
}
|
||||
|
@ -215,8 +211,6 @@ MlirOperation mlirOperationCreate(const MlirOperationState *state) {
|
|||
|
||||
void mlirOperationDestroy(MlirOperation op) { unwrap(op)->erase(); }
|
||||
|
||||
int mlirOperationIsNull(MlirOperation op) { return unwrap(op) == nullptr; }
|
||||
|
||||
intptr_t mlirOperationGetNumRegions(MlirOperation op) {
|
||||
return static_cast<intptr_t>(unwrap(op)->getNumRegions());
|
||||
}
|
||||
|
@ -267,6 +261,16 @@ MlirAttribute mlirOperationGetAttributeByName(MlirOperation op,
|
|||
return wrap(unwrap(op)->getAttr(name));
|
||||
}
|
||||
|
||||
void mlirOperationSetAttributeByName(MlirOperation op, const char *name,
|
||||
MlirAttribute attr) {
|
||||
unwrap(op)->setAttr(name, unwrap(attr));
|
||||
}
|
||||
|
||||
int mlirOperationRemoveAttributeByName(MlirOperation op, const char *name) {
|
||||
auto removeResult = unwrap(op)->removeAttr(name);
|
||||
return removeResult == MutableDictionaryAttr::RemoveResult::Removed;
|
||||
}
|
||||
|
||||
void mlirOperationPrint(MlirOperation op, MlirStringCallback callback,
|
||||
void *userData) {
|
||||
detail::CallbackOstream stream(callback, userData);
|
||||
|
@ -328,8 +332,6 @@ void mlirRegionDestroy(MlirRegion region) {
|
|||
delete static_cast<Region *>(region.ptr);
|
||||
}
|
||||
|
||||
int mlirRegionIsNull(MlirRegion region) { return unwrap(region) == nullptr; }
|
||||
|
||||
/* ========================================================================== */
|
||||
/* Block API. */
|
||||
/* ========================================================================== */
|
||||
|
@ -391,8 +393,6 @@ void mlirBlockInsertOwnedOperationBefore(MlirBlock block,
|
|||
|
||||
void mlirBlockDestroy(MlirBlock block) { delete unwrap(block); }
|
||||
|
||||
int mlirBlockIsNull(MlirBlock block) { return unwrap(block) == nullptr; }
|
||||
|
||||
intptr_t mlirBlockGetNumArguments(MlirBlock block) {
|
||||
return static_cast<intptr_t>(unwrap(block)->getNumArguments());
|
||||
}
|
||||
|
|
|
@ -216,7 +216,7 @@ static void printToStderr(const char *str, intptr_t len, void *userData) {
|
|||
fwrite(str, 1, len, stderr);
|
||||
}
|
||||
|
||||
static void printFirstOfEach(MlirOperation operation) {
|
||||
static void printFirstOfEach(MlirContext ctx, MlirOperation operation) {
|
||||
// Assuming we are given a module, go to the first operation of the first
|
||||
// function.
|
||||
MlirRegion region = mlirOperationGetRegion(operation, 0);
|
||||
|
@ -227,24 +227,59 @@ static void printFirstOfEach(MlirOperation operation) {
|
|||
operation = mlirBlockGetFirstOperation(block);
|
||||
|
||||
// In the module we created, the first operation of the first function is an
|
||||
// "std.dim", which has an attribute an a single result that we can use to
|
||||
// "std.dim", which has an attribute and a single result that we can use to
|
||||
// test the printing mechanism.
|
||||
mlirBlockPrint(block, printToStderr, NULL);
|
||||
fprintf(stderr, "\n");
|
||||
fprintf(stderr, "First operation: ");
|
||||
mlirOperationPrint(operation, printToStderr, NULL);
|
||||
fprintf(stderr, "\n");
|
||||
|
||||
MlirNamedAttribute namedAttr = mlirOperationGetAttribute(operation, 0);
|
||||
mlirAttributePrint(namedAttr.attribute, printToStderr, NULL);
|
||||
// Get the attribute by index.
|
||||
MlirNamedAttribute namedAttr0 = mlirOperationGetAttribute(operation, 0);
|
||||
fprintf(stderr, "Get attr 0: ");
|
||||
mlirAttributePrint(namedAttr0.attribute, printToStderr, NULL);
|
||||
fprintf(stderr, "\n");
|
||||
|
||||
// Now re-get the attribute by name.
|
||||
MlirAttribute attr0ByName =
|
||||
mlirOperationGetAttributeByName(operation, namedAttr0.name);
|
||||
fprintf(stderr, "Get attr 0 by name: ");
|
||||
mlirAttributePrint(attr0ByName, printToStderr, NULL);
|
||||
fprintf(stderr, "\n");
|
||||
|
||||
// Get a non-existing attribute and assert that it is null (sanity).
|
||||
fprintf(stderr, "does_not_exist is null: %d\n",
|
||||
mlirAttributeIsNull(
|
||||
mlirOperationGetAttributeByName(operation, "does_not_exist")));
|
||||
|
||||
// Get result 0 and its type.
|
||||
MlirValue value = mlirOperationGetResult(operation, 0);
|
||||
fprintf(stderr, "Result 0: ");
|
||||
mlirValuePrint(value, printToStderr, NULL);
|
||||
fprintf(stderr, "\n");
|
||||
fprintf(stderr, "Value is null: %d\n", mlirValueIsNull(value));
|
||||
|
||||
MlirType type = mlirValueGetType(value);
|
||||
fprintf(stderr, "Result 0 type: ");
|
||||
mlirTypePrint(type, printToStderr, NULL);
|
||||
fprintf(stderr, "\n");
|
||||
|
||||
// Set a custom attribute.
|
||||
mlirOperationSetAttributeByName(operation, "custom_attr",
|
||||
mlirBoolAttrGet(ctx, 1));
|
||||
fprintf(stderr, "Op with set attr: ");
|
||||
mlirOperationPrint(operation, printToStderr, NULL);
|
||||
fprintf(stderr, "\n");
|
||||
|
||||
// Remove the attribute.
|
||||
fprintf(stderr, "Remove attr: %d\n",
|
||||
mlirOperationRemoveAttributeByName(operation, "custom_attr"));
|
||||
fprintf(stderr, "Remove attr again: %d\n",
|
||||
mlirOperationRemoveAttributeByName(operation, "custom_attr"));
|
||||
fprintf(stderr, "Removed attr is null: %d\n",
|
||||
mlirAttributeIsNull(
|
||||
mlirOperationGetAttributeByName(operation, "custom_attr")));
|
||||
}
|
||||
|
||||
/// Creates an operation with a region containing multiple blocks with
|
||||
|
@ -884,7 +919,7 @@ int main() {
|
|||
// CHECK: Number of values: 9
|
||||
// clang-format on
|
||||
|
||||
printFirstOfEach(module);
|
||||
printFirstOfEach(ctx, module);
|
||||
// clang-format off
|
||||
// CHECK: %[[C0:.*]] = constant 0 : index
|
||||
// CHECK: %[[DIM:.*]] = dim %{{.*}}, %[[C0]] : memref<?xf32>
|
||||
|
@ -896,10 +931,17 @@ int main() {
|
|||
// CHECK: store %[[SUM]], %{{.*}}[%[[I]]] : memref<?xf32>
|
||||
// CHECK: }
|
||||
// CHECK: return
|
||||
// CHECK: constant 0 : index
|
||||
// CHECK: 0 : index
|
||||
// CHECK: constant 0 : index
|
||||
// CHECK: index
|
||||
// CHECK: First operation: {{.*}} = constant 0 : index
|
||||
// CHECK: Get attr 0: 0 : index
|
||||
// CHECK: Get attr 0 by name: 0 : index
|
||||
// CHECK: does_not_exist is null: 1
|
||||
// CHECK: Result 0: {{.*}} = constant 0 : index
|
||||
// CHECK: Value is null: 0
|
||||
// CHECK: Result 0 type: index
|
||||
// CHECK: Op with set attr: {{.*}} {custom_attr = true}
|
||||
// CHECK: Remove attr: 1
|
||||
// CHECK: Remove attr again: 0
|
||||
// CHECK: Removed attr is null: 1
|
||||
// clang-format on
|
||||
|
||||
mlirModuleDestroy(moduleOp);
|
||||
|
|
Loading…
Reference in New Issue