[mlir] expose standard attributes to C API

Provide C API for MLIR standard attributes. Since standard attributes live
under lib/IR in core MLIR, place the C APIs in the IR library as well (standard
ops will go in a separate library).

Affine map and integer set attributes are only exposed as placeholder types
with IsA support due to the lack of C APIs for the corresponding types.

Integer and floating point attribute APIs expecting APInt and APFloat are not
exposed pending decision on how to support APInt and APFloat.

Reviewed By: stellaraccident

Differential Revision: https://reviews.llvm.org/D86143
This commit is contained in:
Alex Zinenko 2020-08-19 18:38:56 +02:00
parent 0f95e73190
commit da56297462
8 changed files with 1288 additions and 20 deletions

View File

@ -97,10 +97,27 @@ as follows.
its first argument is `Y`, and it is the responsibility of the caller to its first argument is `Y`, and it is the responsibility of the caller to
ensure it is indeed the case. ensure it is indeed the case.
### Returning String References
Numerous MLIR functions return instances of `StringRef` to refer to a non-owning
segment of a string. This segment may or may not be null-terminated. In C API,
these functions take an additional callback argument of type
`MlirStringCallback` (pointer to a function with signature `void (*)(const char
*, intptr_t, void *)`) and a pointer to user-defined data. This callback is
invoked with a pointer to the string segment, its size and is forwarded the
user-defined data. The caller is in charge of managing the string segment
according to its memory model: for strings owned by the object (e.g., string
attributes), the caller can store the pointer and the size and use them directly
as long as the parent object is live or copy the string to a new location with a
null terminator if expected; for generated strings (e.g., in printing), the
caller is expected to copy the string segment if it intends to use it later.
**Note:** this interface may be revised in the near future.
### Conversion To String and Printing ### Conversion To String and Printing
IR objects can be converted to a string representation, for example for IR objects can be converted to a string representation, for example for
printing, using `mlirXPrint(MlirX, MlirPrintCallback, void *)` functions. These printing, using `mlirXPrint(MlirX, MlirStringCallback, void *)` functions. These
functions accept take arguments a callback with signature `void (*)(const char functions accept take arguments a callback with signature `void (*)(const char
*, intptr_t, void *)` and a pointer to user-defined data. They call the callback *, intptr_t, void *)` and a pointer to user-defined data. They call the callback
and supply it with chunks of the string representation, provided as a pointer to and supply it with chunks of the string representation, provided as a pointer to

View File

@ -67,16 +67,16 @@ struct MlirNamedAttribute {
}; };
typedef struct MlirNamedAttribute MlirNamedAttribute; typedef struct MlirNamedAttribute MlirNamedAttribute;
/** A callback for printing to IR objects. /** A callback for returning string referenes.
* *
* This function is called back by the printing functions with the following * This function is called back by the functions that need to return a reference
* arguments: * to the portion of the string with the following arguments:
* - a pointer to the beginning of a string; * - a pointer to the beginning of a string;
* - the length of the string (the pointer may point to a larger buffer, not * - the length of the string (the pointer may point to a larger buffer, not
* necessarily null-terminated); * necessarily null-terminated);
* - a pointer to user data forwarded from the printing call. * - a pointer to user data forwarded from the printing call.
*/ */
typedef void (*MlirPrintCallback)(const char *, intptr_t, void *); typedef void (*MlirStringCallback)(const char *, intptr_t, void *);
/*============================================================================*/ /*============================================================================*/
/* Context API. */ /* Context API. */
@ -103,7 +103,7 @@ MlirLocation mlirLocationUnknownGet(MlirContext context);
/** Prints a location by sending chunks of the string representation and /** Prints a location by sending chunks of the string representation and
* forwarding `userData to `callback`. Note that the callback may be called * forwarding `userData to `callback`. Note that the callback may be called
* several times with consecutive chunks of the string. */ * several times with consecutive chunks of the string. */
void mlirLocationPrint(MlirLocation location, MlirPrintCallback callback, void mlirLocationPrint(MlirLocation location, MlirStringCallback callback,
void *userData); void *userData);
/*============================================================================*/ /*============================================================================*/
@ -224,7 +224,7 @@ MlirAttribute mlirOperationGetAttributeByName(MlirOperation op,
/** Prints an operation by sending chunks of the string representation and /** Prints an operation by sending chunks of the string representation and
* forwarding `userData to `callback`. Note that the callback may be called * forwarding `userData to `callback`. Note that the callback may be called
* several times with consecutive chunks of the string. */ * several times with consecutive chunks of the string. */
void mlirOperationPrint(MlirOperation op, MlirPrintCallback callback, void mlirOperationPrint(MlirOperation op, MlirStringCallback callback,
void *userData); void *userData);
/** Prints an operation to stderr. */ /** Prints an operation to stderr. */
@ -292,7 +292,7 @@ MlirValue mlirBlockGetArgument(MlirBlock block, intptr_t pos);
/** Prints a block by sending chunks of the string representation and /** Prints a block by sending chunks of the string representation and
* forwarding `userData to `callback`. Note that the callback may be called * forwarding `userData to `callback`. Note that the callback may be called
* several times with consecutive chunks of the string. */ * several times with consecutive chunks of the string. */
void mlirBlockPrint(MlirBlock block, MlirPrintCallback callback, void mlirBlockPrint(MlirBlock block, MlirStringCallback callback,
void *userData); void *userData);
/*============================================================================*/ /*============================================================================*/
@ -305,7 +305,7 @@ MlirType mlirValueGetType(MlirValue value);
/** Prints a value by sending chunks of the string representation and /** Prints a value by sending chunks of the string representation and
* forwarding `userData to `callback`. Note that the callback may be called * forwarding `userData to `callback`. Note that the callback may be called
* several times with consecutive chunks of the string. */ * several times with consecutive chunks of the string. */
void mlirValuePrint(MlirValue value, MlirPrintCallback callback, void mlirValuePrint(MlirValue value, MlirStringCallback callback,
void *userData); void *userData);
/*============================================================================*/ /*============================================================================*/
@ -324,7 +324,7 @@ int mlirTypeEqual(MlirType t1, MlirType t2);
/** Prints a location by sending chunks of the string representation and /** Prints a location by sending chunks of the string representation and
* forwarding `userData to `callback`. Note that the callback may be called * forwarding `userData to `callback`. Note that the callback may be called
* several times with consecutive chunks of the string. */ * several times with consecutive chunks of the string. */
void mlirTypePrint(MlirType type, MlirPrintCallback callback, void *userData); void mlirTypePrint(MlirType type, MlirStringCallback callback, void *userData);
/** Prints the type to the standard error stream. */ /** Prints the type to the standard error stream. */
void mlirTypeDump(MlirType type); void mlirTypeDump(MlirType type);
@ -336,10 +336,13 @@ void mlirTypeDump(MlirType type);
/** Parses an attribute. The attribute is owned by the context. */ /** Parses an attribute. The attribute is owned by the context. */
MlirAttribute mlirAttributeParseGet(MlirContext context, const char *attr); MlirAttribute mlirAttributeParseGet(MlirContext context, const char *attr);
/** Checks if two attributes are equal. */
int mlirAttributeEqual(MlirAttribute a1, MlirAttribute a2);
/** Prints an attribute by sending chunks of the string representation and /** Prints an attribute by sending chunks of the string representation and
* forwarding `userData to `callback`. Note that the callback may be called * forwarding `userData to `callback`. Note that the callback may be called
* several times with consecutive chunks of the string. */ * several times with consecutive chunks of the string. */
void mlirAttributePrint(MlirAttribute attr, MlirPrintCallback callback, void mlirAttributePrint(MlirAttribute attr, MlirStringCallback callback,
void *userData); void *userData);
/** Prints the attrbute to the standard error stream. */ /** Prints the attrbute to the standard error stream. */

View File

@ -0,0 +1,442 @@
/*===-- mlir-c/StandardAttributes.h - C API for Std Attributes-----*- C -*-===*\
|* *|
|* Part of the LLVM Project, under the Apache License v2.0 with LLVM *|
|* Exceptions. *|
|* See https://llvm.org/LICENSE.txt for license information. *|
|* SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception *|
|* *|
|*===----------------------------------------------------------------------===*|
|* *|
|* This header declares the C interface to MLIR Standard attributes. *|
|* *|
\*===----------------------------------------------------------------------===*/
#ifndef MLIR_C_STANDARDATTRIBUTES_H
#define MLIR_C_STANDARDATTRIBUTES_H
#include "mlir-c/AffineMap.h"
#include "mlir-c/IR.h"
#ifdef __cplusplus
extern "C" {
#endif
/*============================================================================*/
/* Affine map attribute. */
/*============================================================================*/
/** Checks whether the given attribute is an affine map attribute. */
int mlirAttributeIsAAffineMap(MlirAttribute attr);
/** Creates an affine map attribute wrapping the given map. The attribute
* belongs to the same context as the affine map. */
MlirAttribute mlirAffineMapAttrGet(MlirAffineMap map);
/** Returns the affine map wrapped in the given affine map attribute. */
MlirAffineMap mlirAffineMapAttrGetValue(MlirAttribute attr);
/*============================================================================*/
/* Array attribute. */
/*============================================================================*/
/** Checks whether the given attribute is an array attribute. */
int mlirAttributeIsAArray(MlirAttribute attr);
/** Creates an array element containing the given list of elements in the given
* context. */
MlirAttribute mlirArrayAttrGet(MlirContext ctx, intptr_t numElements,
MlirAttribute *elements);
/** Returns the number of elements stored in the given array attribute. */
intptr_t mlirArrayAttrGetNumElements(MlirAttribute attr);
/** Returns pos-th element stored in the given array attribute. */
MlirAttribute mlirArrayAttrGetElement(MlirAttribute attr, intptr_t pos);
/*============================================================================*/
/* Dictionary attribute. */
/*============================================================================*/
/** Checks whether the given attribute is a dictionary attribute. */
int mlirAttributeIsADictionary(MlirAttribute attr);
/** Creates a dictionary attribute containing the given list of elements in the
* provided context. */
MlirAttribute mlirDictionaryAttrGet(MlirContext ctx, intptr_t numElements,
MlirNamedAttribute *elements);
/** Returns the number of attributes contained in a dictionary attribute. */
intptr_t mlirDictionaryAttrGetNumElements(MlirAttribute attr);
/** Returns pos-th element of the given dictionary attribute. */
MlirNamedAttribute mlirDictionaryAttrGetElement(MlirAttribute attr,
intptr_t pos);
/** Returns the dictionary attribute element with the given name or NULL if the
* given name does not exist in the dictionary. */
MlirAttribute mlirDictionaryAttrGetElementByName(MlirAttribute attr,
const char *name);
/*============================================================================*/
/* Floating point attribute. */
/*============================================================================*/
/* TODO: add support for APFloat and APInt to LLVM IR C API, then expose the
* relevant functions here. */
/** Checks whether the given attribute is a floating point attribute. */
int mlirAttributeIsAFloat(MlirAttribute attr);
/** Creates a floating point attribute in the given context with the given
* double value and double-precision FP semantics. */
MlirAttribute mlirFloatAttrDoubleGet(MlirContext ctx, MlirType type,
double value);
/** Returns the value stored in the given floating point attribute, interpreting
* the value as double. */
double mlirFloatAttrGetValueDouble(MlirAttribute attr);
/*============================================================================*/
/* Integer attribute. */
/*============================================================================*/
/* TODO: add support for APFloat and APInt to LLVM IR C API, then expose the
* relevant functions here. */
/** Checks whether the given attribute is an integer attribute. */
int mlirAttributeIsAInteger(MlirAttribute attr);
/** Creates an integer attribute of the given type with the given integer
* value. */
MlirAttribute mlirIntegerAttrGet(MlirType type, int64_t value);
/** Returns the value stored in the given integer attribute, assuming the value
* fits into a 64-bit integer. */
int64_t mlirIntegerAttrGetValueInt(MlirAttribute attr);
/*============================================================================*/
/* Bool attribute. */
/*============================================================================*/
/** Checks whether the given attribute is a bool attribute. */
int mlirAttributeIsABool(MlirAttribute attr);
/** Creates a bool attribute in the given context with the given value. */
MlirAttribute mlirBoolAttrGet(MlirContext ctx, int value);
/** Returns the value stored in the given bool attribute. */
int mlirBoolAttrGetValue(MlirAttribute attr);
/*============================================================================*/
/* Integer set attribute. */
/*============================================================================*/
/** Checks whether the given attribute is an integer set attribute. */
int mlirAttributeIsAIntegerSet(MlirAttribute attr);
/*============================================================================*/
/* Opaque attribute. */
/*============================================================================*/
/** Checks whether the given attribute is an opaque attribute. */
int mlirAttributeIsAOpaque(MlirAttribute attr);
/** Creates an opaque attribute in the given context associated with the dialect
* identified by its namespace. The attribute contains opaque byte data of the
* specified length (data need not be null-terminated). */
MlirAttribute mlirOpaqueAttrGet(MlirContext ctx, const char *dialectNamespace,
intptr_t dataLength, const char *data,
MlirType type);
/** Returns the namepsace of the dialect with which the given opaque attribute
* is associated. The namespace string is owned by the context. */
const char *mlirOpaqueAttrGetDialectNamespace(MlirAttribute attr);
/** Calls the provided callback with the opaque byte data stored in the given
* opaque attribute. The callback is invoked once, and the data it receives is
* not necessarily null terminated. The data remains live as long as the context
* in which the attribute lives. */
/* TODO: consider exposing StringRef and using it instead of the callback. */
void mlirOpaqueAttrGetData(MlirAttribute attr, MlirStringCallback callback,
void *userData);
/*============================================================================*/
/* String attribute. */
/*============================================================================*/
/** Checks whether the given attribute is a string attribute. */
int mlirAttributeIsAString(MlirAttribute attr);
/** Creates a string attribute in the given context containing the given string.
* The string need not be null-terminated and its length must be specified. */
MlirAttribute mlirStringAttrGet(MlirContext ctx, intptr_t length,
const char *data);
/** Creates a string attribute in the given context containing the given string.
* The string need not be null-terminated and its length must be specified.
* Additionally, the attribute has the given type. */
MlirAttribute mlirStringAttrTypedGet(MlirType type, intptr_t length,
const char *data);
/** Calls the provided callback with the string stored in the given string
* attribute. The callback is invoked once, and the data it receives is not
* necessarily null terminated. The data remains live as long as the context in
* which the attribute lives. */
/* TODO: consider exposing StringRef and using it instead of the callback. */
void mlirStringAttrGetValue(MlirAttribute attr, MlirStringCallback callback,
void *userData);
/*============================================================================*/
/* SymbolRef attribute. */
/*============================================================================*/
/** Checks whether the given attribute is a symbol reference attribute. */
int mlirAttributeIsASymbolRef(MlirAttribute attr);
/** Creates a symbol reference attribute in the given context referencing a
* symbol identified by the given string inside a list of nested references.
* Each of the references in the list must not be nested. The string need not be
* null-terminated and its length must be specified. */
MlirAttribute mlirSymbolRefAttrGet(MlirContext ctx, intptr_t length,
const char *symbol, intptr_t numReferences,
MlirAttribute *references);
/** Calls the provided callback with the string containing the root referenced
* symbol. The callback is invoked once, and the data it receives is not
* necessarily null terminated. The data remains live as long as the context in
* which the attribute lives. */
/* TODO: consider exposing StringRef and using it instead of the callback. */
void mlirSymbolRefAttrGetRootReference(MlirAttribute attr,
MlirStringCallback callback,
void *userData);
/** Calls the provided callback with the string containing the leaf referenced
* symbol. The callback is invoked once, and the data it receives is not
* necessarily null terminated. The data remains live as long as the context in
* which the attribute lives. */
/* TODO: consider exposing StringRef and using it instead of the callback. */
void mlirSymbolRefAttrGetLeafReference(MlirAttribute attr,
MlirStringCallback callback,
void *userData);
/** Returns the number of references nested in the given symbol reference
* attribute. */
intptr_t mlirSymbolRefAttrGetNumNestedReferences(MlirAttribute attr);
/** Returns pos-th reference nested in the given symbol reference attribute. */
MlirAttribute mlirSymbolRefAttrGetNestedReference(MlirAttribute attr,
intptr_t pos);
/*============================================================================*/
/* Flat SymbolRef attribute. */
/*============================================================================*/
/** Checks whether the given attribute is a flat symbol reference attribute. */
int mlirAttributeIsAFlatSymbolRef(MlirAttribute attr);
/** Creates a flat symbol reference attribute in the given context referencing a
* symbol identified by the given string. The string need not be null-terminated
* and its length must be specified. */
MlirAttribute mlirFlatSymbolRefAttrGet(MlirContext ctx, intptr_t length,
const char *symbol);
/** Calls the provided callback with the string containing the referenced
* symbol. The callback is invoked once, and the data it receives is not
* necessarily null terminated. The data remains live as long as the context in
* which the attribute lives. */
/* TODO: consider exposing StringRef and using it instead of the callback. */
void mlirFloatSymbolRefAttrGetValue(MlirAttribute attr,
MlirStringCallback callback,
void *userData);
/*============================================================================*/
/* Type attribute. */
/*============================================================================*/
/** Checks whether the given attribute is a type attribute. */
int mlirAttributeIsAType(MlirAttribute attr);
/** Creates a type attribute wrapping the given type in the same context as the
* type. */
MlirAttribute mlirTypeAttrGet(MlirType type);
/** Returns the type stored in the given type attribute. */
MlirType mlirTypeAttrGetValue(MlirAttribute attr);
/*============================================================================*/
/* Unit attribute. */
/*============================================================================*/
/** Checks whether the given attribute is a unit attribute. */
int mlirAttributeIsAUnit(MlirAttribute attr);
/** Creates a unit attribute in the given context. */
MlirAttribute mlirUnitAttrGet(MlirContext ctx);
/*============================================================================*/
/* Elements attributes. */
/*============================================================================*/
/** Checks whether the given attribute is an elements attribute. */
int mlirAttributeIsAElements(MlirAttribute attr);
/** Returns the element at the given rank-dimensional index. */
MlirAttribute mlirElementsAttrGetValue(MlirAttribute attr, intptr_t rank,
uint64_t *idxs);
/** Checks whether the given rank-dimensional index is valid in the given
* elements attribute. */
int mlirElementsAttrIsValidIndex(MlirAttribute attr, intptr_t rank,
uint64_t *idxs);
/** Gets the total number of elements in the given elements attribute. In order
* to iterate over the attribute, obtain its type, which must be a statically
* shaped type and use its sizes to build a multi-dimensional index. */
int64_t mlirElementsAttrGetNumElements(MlirAttribute attr);
/*============================================================================*/
/* Dense elements attribute. */
/*============================================================================*/
/* TODO: decide on the interface and add support for complex elements. */
/* TODO: add support for APFloat and APInt to LLVM IR C API, then expose the
* relevant functions here. */
/** Checks whether the given attribute is a dense elements attribute. */
int mlirAttributeIsADenseElements(MlirAttribute attr);
int mlirAttributeIsADenseIntElements(MlirAttribute attr);
int mlirAttributeIsADenseFPElements(MlirAttribute attr);
/** Creates a dense elements attribute with the given Shaped type and elements
* in the same context as the type. */
MlirAttribute mlirDenseElementsAttrGet(MlirType shapedType,
intptr_t numElements,
MlirAttribute *elements);
/** Creates a dense elements attribute with the given Shaped type containing a
* single replicated element (splat). */
MlirAttribute mlirDenseElementsAttrSplatGet(MlirType shapedType,
MlirAttribute element);
MlirAttribute mlirDenseElementsAttrBoolSplatGet(MlirType shapedType,
int element);
MlirAttribute mlirDenseElementsAttrUInt32SplatGet(MlirType shapedType,
uint32_t element);
MlirAttribute mlirDenseElementsAttrInt32SplatGet(MlirType shapedType,
int32_t element);
MlirAttribute mlirDenseElementsAttrUInt64SplatGet(MlirType shapedType,
uint64_t element);
MlirAttribute mlirDenseElementsAttrInt64SplatGet(MlirType shapedType,
int64_t element);
MlirAttribute mlirDenseElementsAttrFloatSplatGet(MlirType shapedType,
float element);
MlirAttribute mlirDenseElementsAttrDoubleSplatGet(MlirType shapedType,
double element);
/** Creates a dense elements attribute with the given shaped type from elements
* of a specific type. Expects the element type of the shaped type to match the
* data element type. */
MlirAttribute mlirDenseElementsAttrBoolGet(MlirType shapedType,
intptr_t numElements, int *elements);
MlirAttribute mlirDenseElementsAttrUInt32Get(MlirType shapedType,
intptr_t numElements,
uint32_t *elements);
MlirAttribute mlirDenseElementsAttrInt32Get(MlirType shapedType,
intptr_t numElements,
int32_t *elements);
MlirAttribute mlirDenseElementsAttrUInt64Get(MlirType shapedType,
intptr_t numElements,
uint64_t *elements);
MlirAttribute mlirDenseElementsAttrInt64Get(MlirType shapedType,
intptr_t numElements,
int64_t *elements);
MlirAttribute mlirDenseElementsAttrFloatGet(MlirType shapedType,
intptr_t numElements,
float *elements);
MlirAttribute mlirDenseElementsAttrDoubleGet(MlirType shapedType,
intptr_t numElements,
double *elements);
/** Creates a dense elements attribute with the given shaped type from string
* elements. The strings need not be null-terminated and their lengths are
* provided as a separate argument co-indexed with the strs argument. */
MlirAttribute mlirDenseElementsAttrStringGet(MlirType shapedType,
intptr_t numElements,
intptr_t *strLengths,
const char **strs);
/** Creates a dense elements attribute that has the same data as the given dense
* elements attribute and a different shaped type. The new type must have the
* same total number of elements. */
MlirAttribute mlirDenseElementsAttrReshapeGet(MlirAttribute attr,
MlirType shapedType);
/** Checks whether the given dense elements attribute contains a single
* replicated value (splat). */
int mlirDenseElementsAttrIsSplat(MlirAttribute attr);
/** Returns the single replicated value (splat) of a specific type contained by
* the given dense elements attribute. */
MlirAttribute mlirDenseElementsAttrGetSplatValue(MlirAttribute attr);
int mlirDenseElementsAttrGetBoolSplatValue(MlirAttribute attr);
int32_t mlirDenseElementsAttrGetInt32SplatValue(MlirAttribute attr);
uint32_t mlirDenseElementsAttrGetUInt32SplatValue(MlirAttribute attr);
int64_t mlirDenseElementsAttrGetInt64SplatValue(MlirAttribute attr);
uint64_t mlirDenseElementsAttrGetUInt64SplatValue(MlirAttribute attr);
float mlirDenseElementsAttrGetFloatSplatValue(MlirAttribute attr);
double mlirDenseElementsAttrGetDoubleSplatValue(MlirAttribute attr);
/* TODO: consider exposing StringRef and using it instead of the callback. */
void mlirDenseElementsAttrGetStringSplatValue(MlirAttribute attr,
MlirStringCallback callback,
void *userData);
/** Returns the pos-th value (flat contiguous indexing) of a specific type
* contained by the given dense elements attribute. */
int mlirDenseElementsAttrGetBoolValue(MlirAttribute attr, intptr_t pos);
int32_t mlirDenseElementsAttrGetInt32Value(MlirAttribute attr, intptr_t pos);
uint32_t mlirDenseElementsAttrGetUInt32Value(MlirAttribute attr, intptr_t pos);
int64_t mlirDenseElementsAttrGetInt64Value(MlirAttribute attr, intptr_t pos);
uint64_t mlirDenseElementsAttrGetUInt64Value(MlirAttribute attr, intptr_t pos);
float mlirDenseElementsAttrGetFloatValue(MlirAttribute attr, intptr_t pos);
double mlirDenseElementsAttrGetDoubleValue(MlirAttribute attr, intptr_t pos);
/* TODO: consider exposing StringRef and using it instead of the callback. */
void mlirDenseElementsAttrGetStringValue(MlirAttribute attr, intptr_t pos,
MlirStringCallback callback,
void *userData);
/*============================================================================*/
/* Opaque elements attribute. */
/*============================================================================*/
/* TODO: expose Dialect to the bindings and implement accessors here. */
/** Checks whether the given attribute is an opaque elements attribute. */
int mlirAttributeIsAOpaqueElements(MlirAttribute attr);
/*============================================================================*/
/* Sparse elements attribute. */
/*============================================================================*/
/** Checks whether the given attribute is a sparse elements attribute. */
int mlirAttributeIsASparseElements(MlirAttribute attr);
/** Creates a sparse elements attribute of the given shape from a list of
* indices and a list of associated values. Both lists are expected to be dense
* elements attributes with the same number of elements. The list of indices is
* expected to contain 64-bit integers. The attribute is created in the same
* context as the type. */
MlirAttribute mlirSparseElementsAttribute(MlirType shapedType,
MlirAttribute denseIndices,
MlirAttribute denseValues);
/** Returns the dense elements attribute containing 64-bit integer indices of
* non-null elements in the given sparse elements attribute. */
MlirAttribute mlirSparseElementsAttrGetIndices(MlirAttribute attr);
/** Returns the dense elements attribute containing the non-null elements in the
* given sparse elements attribute. */
MlirAttribute mlirSparseElementsAttrGetValues(MlirAttribute attr);
#ifdef __cplusplus
}
#endif
#endif // MLIR_C_STANDARDATTRIBUTES_H

View File

@ -55,13 +55,13 @@ static const char kDumpDocstring[] =
namespace { namespace {
/// Accumulates into a python string from a method that accepts an /// Accumulates into a python string from a method that accepts an
/// MlirPrintCallback. /// MlirStringCallback.
struct PyPrintAccumulator { struct PyPrintAccumulator {
py::list parts; py::list parts;
void *getUserData() { return this; } void *getUserData() { return this; }
MlirPrintCallback getCallback() { MlirStringCallback getCallback() {
return [](const char *part, intptr_t size, void *userData) { return [](const char *part, intptr_t size, void *userData) {
PyPrintAccumulator *printAccum = PyPrintAccumulator *printAccum =
static_cast<PyPrintAccumulator *>(userData); static_cast<PyPrintAccumulator *>(userData);

View File

@ -2,6 +2,7 @@
add_mlir_library(MLIRCAPIIR add_mlir_library(MLIRCAPIIR
AffineMap.cpp AffineMap.cpp
IR.cpp IR.cpp
StandardAttributes.cpp
StandardTypes.cpp StandardTypes.cpp
EXCLUDE_FROM_LIBMLIR EXCLUDE_FROM_LIBMLIR

View File

@ -71,7 +71,7 @@ MlirLocation mlirLocationUnknownGet(MlirContext context) {
return wrap(UnknownLoc::get(unwrap(context))); return wrap(UnknownLoc::get(unwrap(context)));
} }
void mlirLocationPrint(MlirLocation location, MlirPrintCallback callback, void mlirLocationPrint(MlirLocation location, MlirStringCallback callback,
void *userData) { void *userData) {
CallbackOstream stream(callback, userData); CallbackOstream stream(callback, userData);
unwrap(location).print(stream); unwrap(location).print(stream);
@ -238,7 +238,7 @@ MlirAttribute mlirOperationGetAttributeByName(MlirOperation op,
return wrap(unwrap(op)->getAttr(name)); return wrap(unwrap(op)->getAttr(name));
} }
void mlirOperationPrint(MlirOperation op, MlirPrintCallback callback, void mlirOperationPrint(MlirOperation op, MlirStringCallback callback,
void *userData) { void *userData) {
CallbackOstream stream(callback, userData); CallbackOstream stream(callback, userData);
unwrap(op)->print(stream); unwrap(op)->print(stream);
@ -320,7 +320,7 @@ MlirValue mlirBlockGetArgument(MlirBlock block, intptr_t pos) {
return wrap(unwrap(block)->getArgument(static_cast<unsigned>(pos))); return wrap(unwrap(block)->getArgument(static_cast<unsigned>(pos)));
} }
void mlirBlockPrint(MlirBlock block, MlirPrintCallback callback, void mlirBlockPrint(MlirBlock block, MlirStringCallback callback,
void *userData) { void *userData) {
CallbackOstream stream(callback, userData); CallbackOstream stream(callback, userData);
unwrap(block)->print(stream); unwrap(block)->print(stream);
@ -335,7 +335,7 @@ MlirType mlirValueGetType(MlirValue value) {
return wrap(unwrap(value).getType()); return wrap(unwrap(value).getType());
} }
void mlirValuePrint(MlirValue value, MlirPrintCallback callback, void mlirValuePrint(MlirValue value, MlirStringCallback callback,
void *userData) { void *userData) {
CallbackOstream stream(callback, userData); CallbackOstream stream(callback, userData);
unwrap(value).print(stream); unwrap(value).print(stream);
@ -352,7 +352,7 @@ MlirType mlirTypeParseGet(MlirContext context, const char *type) {
int mlirTypeEqual(MlirType t1, MlirType t2) { return unwrap(t1) == unwrap(t2); } int mlirTypeEqual(MlirType t1, MlirType t2) { return unwrap(t1) == unwrap(t2); }
void mlirTypePrint(MlirType type, MlirPrintCallback callback, void *userData) { void mlirTypePrint(MlirType type, MlirStringCallback callback, void *userData) {
CallbackOstream stream(callback, userData); CallbackOstream stream(callback, userData);
unwrap(type).print(stream); unwrap(type).print(stream);
stream.flush(); stream.flush();
@ -368,7 +368,11 @@ MlirAttribute mlirAttributeParseGet(MlirContext context, const char *attr) {
return wrap(mlir::parseAttribute(attr, unwrap(context))); return wrap(mlir::parseAttribute(attr, unwrap(context)));
} }
void mlirAttributePrint(MlirAttribute attr, MlirPrintCallback callback, int mlirAttributeEqual(MlirAttribute a1, MlirAttribute a2) {
return unwrap(a1) == unwrap(a2);
}
void mlirAttributePrint(MlirAttribute attr, MlirStringCallback callback,
void *userData) { void *userData) {
CallbackOstream stream(callback, userData); CallbackOstream stream(callback, userData);
unwrap(attr).print(stream); unwrap(attr).print(stream);

View File

@ -0,0 +1,561 @@
//===- StandardAttributes.cpp - C Interface to MLIR Standard Attributes ---===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir-c/StandardAttributes.h"
#include "mlir/CAPI/AffineMap.h"
#include "mlir/CAPI/IR.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/StandardTypes.h"
using namespace mlir;
/*============================================================================*/
/* Affine map attribute. */
/*============================================================================*/
int mlirAttributeIsAAffineMap(MlirAttribute attr) {
return unwrap(attr).isa<AffineMapAttr>();
}
MlirAttribute mlirAffineMapAttrGet(MlirAffineMap map) {
return wrap(AffineMapAttr::get(unwrap(map)));
}
MlirAffineMap mlirAffineMapAttrGetValue(MlirAttribute attr) {
return wrap(unwrap(attr).cast<AffineMapAttr>().getValue());
}
/*============================================================================*/
/* Array attribute. */
/*============================================================================*/
int mlirAttributeIsAArray(MlirAttribute attr) {
return unwrap(attr).isa<ArrayAttr>();
}
MlirAttribute mlirArrayAttrGet(MlirContext ctx, intptr_t numElements,
MlirAttribute *elements) {
SmallVector<Attribute, 8> attrs;
return wrap(ArrayAttr::get(
unwrapList(static_cast<size_t>(numElements), elements, attrs),
unwrap(ctx)));
}
intptr_t mlirArrayAttrGetNumElements(MlirAttribute attr) {
return static_cast<intptr_t>(unwrap(attr).cast<ArrayAttr>().size());
}
MlirAttribute mlirArrayAttrGetElement(MlirAttribute attr, intptr_t pos) {
return wrap(unwrap(attr).cast<ArrayAttr>().getValue()[pos]);
}
/*============================================================================*/
/* Dictionary attribute. */
/*============================================================================*/
int mlirAttributeIsADictionary(MlirAttribute attr) {
return unwrap(attr).isa<DictionaryAttr>();
}
MlirAttribute mlirDictionaryAttrGet(MlirContext ctx, intptr_t numElements,
MlirNamedAttribute *elements) {
SmallVector<NamedAttribute, 8> attributes;
attributes.reserve(numElements);
for (intptr_t i = 0; i < numElements; ++i)
attributes.emplace_back(Identifier::get(elements[i].name, unwrap(ctx)),
unwrap(elements[i].attribute));
return wrap(DictionaryAttr::get(attributes, unwrap(ctx)));
}
intptr_t mlirDictionaryAttrGetNumElements(MlirAttribute attr) {
return static_cast<intptr_t>(unwrap(attr).cast<DictionaryAttr>().size());
}
MlirNamedAttribute mlirDictionaryAttrGetElement(MlirAttribute attr,
intptr_t pos) {
NamedAttribute attribute =
unwrap(attr).cast<DictionaryAttr>().getValue()[pos];
return {attribute.first.c_str(), wrap(attribute.second)};
}
MlirAttribute mlirDictionaryAttrGetElementByName(MlirAttribute attr,
const char *name) {
return wrap(unwrap(attr).cast<DictionaryAttr>().get(name));
}
/*============================================================================*/
/* Floating point attribute. */
/*============================================================================*/
int mlirAttributeIsAFloat(MlirAttribute attr) {
return unwrap(attr).isa<FloatAttr>();
}
MlirAttribute mlirFloatAttrDoubleGet(MlirContext ctx, MlirType type,
double value) {
return wrap(FloatAttr::get(unwrap(type), value));
}
double mlirFloatAttrGetValueDouble(MlirAttribute attr) {
return unwrap(attr).cast<FloatAttr>().getValueAsDouble();
}
/*============================================================================*/
/* Integer attribute. */
/*============================================================================*/
int mlirAttributeIsAInteger(MlirAttribute attr) {
return unwrap(attr).isa<IntegerAttr>();
}
MlirAttribute mlirIntegerAttrGet(MlirType type, int64_t value) {
return wrap(IntegerAttr::get(unwrap(type), value));
}
int64_t mlirIntegerAttrGetValueInt(MlirAttribute attr) {
return unwrap(attr).cast<IntegerAttr>().getInt();
}
/*============================================================================*/
/* Bool attribute. */
/*============================================================================*/
int mlirAttributeIsABool(MlirAttribute attr) {
return unwrap(attr).isa<BoolAttr>();
}
MlirAttribute mlirBoolAttrGet(MlirContext ctx, int value) {
return wrap(BoolAttr::get(value, unwrap(ctx)));
}
int mlirBoolAttrGetValue(MlirAttribute attr) {
return unwrap(attr).cast<BoolAttr>().getValue();
}
/*============================================================================*/
/* Integer set attribute. */
/*============================================================================*/
int mlirAttributeIsAIntegerSet(MlirAttribute attr) {
return unwrap(attr).isa<IntegerSetAttr>();
}
/*============================================================================*/
/* Opaque attribute. */
/*============================================================================*/
int mlirAttributeIsAOpaque(MlirAttribute attr) {
return unwrap(attr).isa<OpaqueAttr>();
}
MlirAttribute mlirOpaqueAttrGet(MlirContext ctx, const char *dialectNamespace,
intptr_t dataLength, const char *data,
MlirType type) {
return wrap(OpaqueAttr::get(Identifier::get(dialectNamespace, unwrap(ctx)),
StringRef(data, dataLength), unwrap(type),
unwrap(ctx)));
}
const char *mlirOpaqueAttrGetDialectNamespace(MlirAttribute attr) {
return unwrap(attr).cast<OpaqueAttr>().getDialectNamespace().c_str();
}
void mlirOpaqueAttrGetData(MlirAttribute attr, MlirStringCallback callback,
void *userData) {
StringRef data = unwrap(attr).cast<OpaqueAttr>().getAttrData();
callback(data.data(), static_cast<intptr_t>(data.size()), userData);
}
/*============================================================================*/
/* String attribute. */
/*============================================================================*/
int mlirAttributeIsAString(MlirAttribute attr) {
return unwrap(attr).isa<StringAttr>();
}
MlirAttribute mlirStringAttrGet(MlirContext ctx, intptr_t length,
const char *data) {
return wrap(StringAttr::get(StringRef(data, length), unwrap(ctx)));
}
MlirAttribute mlirStringAttrTypedGet(MlirType type, intptr_t length,
const char *data) {
return wrap(StringAttr::get(StringRef(data, length), unwrap(type)));
}
void mlirStringAttrGetValue(MlirAttribute attr, MlirStringCallback callback,
void *userData) {
StringRef data = unwrap(attr).cast<StringAttr>().getValue();
callback(data.data(), static_cast<intptr_t>(data.size()), userData);
}
/*============================================================================*/
/* SymbolRef attribute. */
/*============================================================================*/
int mlirAttributeIsASymbolRef(MlirAttribute attr) {
return unwrap(attr).isa<SymbolRefAttr>();
}
MlirAttribute mlirSymbolRefAttrGet(MlirContext ctx, intptr_t length,
const char *symbol, intptr_t numReferences,
MlirAttribute *references) {
SmallVector<FlatSymbolRefAttr, 4> refs;
refs.reserve(numReferences);
for (intptr_t i = 0; i < numReferences; ++i)
refs.push_back(unwrap(references[i]).cast<FlatSymbolRefAttr>());
return wrap(SymbolRefAttr::get(StringRef(symbol, length), refs, unwrap(ctx)));
}
void mlirSymbolRefAttrGetRootReference(MlirAttribute attr,
MlirStringCallback callback,
void *userData) {
StringRef ref = unwrap(attr).cast<SymbolRefAttr>().getRootReference();
callback(ref.data(), ref.size(), userData);
}
void mlirSymbolRefAttrGetLeafReference(MlirAttribute attr,
MlirStringCallback callback,
void *userData) {
StringRef ref = unwrap(attr).cast<SymbolRefAttr>().getLeafReference();
callback(ref.data(), ref.size(), userData);
}
intptr_t mlirSymbolRefAttrGetNumNestedReferences(MlirAttribute attr) {
return static_cast<intptr_t>(
unwrap(attr).cast<SymbolRefAttr>().getNestedReferences().size());
}
MlirAttribute mlirSymbolRefAttrGetNestedReference(MlirAttribute attr,
intptr_t pos) {
return wrap(unwrap(attr).cast<SymbolRefAttr>().getNestedReferences()[pos]);
}
/*============================================================================*/
/* Flat SymbolRef attribute. */
/*============================================================================*/
int mlirAttributeIsAFlatSymbolRef(MlirAttribute attr) {
return unwrap(attr).isa<FlatSymbolRefAttr>();
}
MlirAttribute mlirFlatSymbolRefAttrGet(MlirContext ctx, intptr_t length,
const char *symbol) {
return wrap(FlatSymbolRefAttr::get(StringRef(symbol, length), unwrap(ctx)));
}
void mlirFloatSymbolRefAttrGetValue(MlirAttribute attr,
MlirStringCallback callback,
void *userData) {
StringRef symbol = unwrap(attr).cast<FlatSymbolRefAttr>().getValue();
callback(symbol.data(), symbol.size(), userData);
}
/*============================================================================*/
/* Type attribute. */
/*============================================================================*/
int mlirAttributeIsAType(MlirAttribute attr) {
return unwrap(attr).isa<TypeAttr>();
}
MlirAttribute mlirTypeAttrGet(MlirType type) {
return wrap(TypeAttr::get(unwrap(type)));
}
MlirType mlirTypeAttrGetValue(MlirAttribute attr) {
return wrap(unwrap(attr).cast<TypeAttr>().getValue());
}
/*============================================================================*/
/* Unit attribute. */
/*============================================================================*/
int mlirAttributeIsAUnit(MlirAttribute attr) {
return unwrap(attr).isa<UnitAttr>();
}
MlirAttribute mlirUnitAttrGet(MlirContext ctx) {
return wrap(UnitAttr::get(unwrap(ctx)));
}
/*============================================================================*/
/* Elements attributes. */
/*============================================================================*/
int mlirAttributeIsAElements(MlirAttribute attr) {
return unwrap(attr).isa<ElementsAttr>();
}
MlirAttribute mlirElementsAttrGetValue(MlirAttribute attr, intptr_t rank,
uint64_t *idxs) {
return wrap(unwrap(attr).cast<ElementsAttr>().getValue(
llvm::makeArrayRef(idxs, rank)));
}
int mlirElementsAttrIsValidIndex(MlirAttribute attr, intptr_t rank,
uint64_t *idxs) {
return unwrap(attr).cast<ElementsAttr>().isValidIndex(
llvm::makeArrayRef(idxs, rank));
}
int64_t mlirElementsAttrGetNumElements(MlirAttribute attr) {
return unwrap(attr).cast<ElementsAttr>().getNumElements();
}
/*============================================================================*/
/* Dense elements attribute. */
/*============================================================================*/
//===----------------------------------------------------------------------===//
// IsA support.
int mlirAttributeIsADenseElements(MlirAttribute attr) {
return unwrap(attr).isa<DenseElementsAttr>();
}
int mlirAttributeIsADenseIntElements(MlirAttribute attr) {
return unwrap(attr).isa<DenseIntElementsAttr>();
}
int mlirAttributeIsADenseFPElements(MlirAttribute attr) {
return unwrap(attr).isa<DenseFPElementsAttr>();
}
//===----------------------------------------------------------------------===//
// Constructors.
MlirAttribute mlirDenseElementsAttrGet(MlirType shapedType,
intptr_t numElements,
MlirAttribute *elements) {
SmallVector<Attribute, 8> attributes;
return wrap(
DenseElementsAttr::get(unwrap(shapedType).cast<ShapedType>(),
unwrapList(numElements, elements, attributes)));
}
MlirAttribute mlirDenseElementsAttrSplatGet(MlirType shapedType,
MlirAttribute element) {
return wrap(DenseElementsAttr::get(unwrap(shapedType).cast<ShapedType>(),
unwrap(element)));
}
MlirAttribute mlirDenseElementsAttrBoolSplatGet(MlirType shapedType,
int element) {
return wrap(DenseElementsAttr::get(unwrap(shapedType).cast<ShapedType>(),
static_cast<bool>(element)));
}
MlirAttribute mlirDenseElementsAttrUInt32SplatGet(MlirType shapedType,
uint32_t element) {
return wrap(
DenseElementsAttr::get(unwrap(shapedType).cast<ShapedType>(), element));
}
MlirAttribute mlirDenseElementsAttrInt32SplatGet(MlirType shapedType,
int32_t element) {
return wrap(
DenseElementsAttr::get(unwrap(shapedType).cast<ShapedType>(), element));
}
MlirAttribute mlirDenseElementsAttrUInt64SplatGet(MlirType shapedType,
uint64_t element) {
return wrap(
DenseElementsAttr::get(unwrap(shapedType).cast<ShapedType>(), element));
}
MlirAttribute mlirDenseElementsAttrInt64SplatGet(MlirType shapedType,
int64_t element) {
return wrap(
DenseElementsAttr::get(unwrap(shapedType).cast<ShapedType>(), element));
}
MlirAttribute mlirDenseElementsAttrFloatSplatGet(MlirType shapedType,
float element) {
return wrap(
DenseElementsAttr::get(unwrap(shapedType).cast<ShapedType>(), element));
}
MlirAttribute mlirDenseElementsAttrDoubleSplatGet(MlirType shapedType,
double element) {
return wrap(
DenseElementsAttr::get(unwrap(shapedType).cast<ShapedType>(), element));
}
MlirAttribute mlirDenseElementsAttrBoolGet(MlirType shapedType,
intptr_t numElements,
int *elements) {
SmallVector<bool, 8> values(elements, elements + numElements);
return wrap(
DenseElementsAttr::get(unwrap(shapedType).cast<ShapedType>(), values));
}
/// Creates a dense attribute with elements of the type deduced by templates.
template <typename T>
static MlirAttribute getDenseAttribute(MlirType shapedType,
intptr_t numElements, T *elements) {
return wrap(
DenseElementsAttr::get(unwrap(shapedType).cast<ShapedType>(),
llvm::makeArrayRef(elements, numElements)));
}
MlirAttribute mlirDenseElementsAttrUInt32Get(MlirType shapedType,
intptr_t numElements,
uint32_t *elements) {
return getDenseAttribute(shapedType, numElements, elements);
}
MlirAttribute mlirDenseElementsAttrInt32Get(MlirType shapedType,
intptr_t numElements,
int32_t *elements) {
return getDenseAttribute(shapedType, numElements, elements);
}
MlirAttribute mlirDenseElementsAttrUInt64Get(MlirType shapedType,
intptr_t numElements,
uint64_t *elements) {
return getDenseAttribute(shapedType, numElements, elements);
}
MlirAttribute mlirDenseElementsAttrInt64Get(MlirType shapedType,
intptr_t numElements,
int64_t *elements) {
return getDenseAttribute(shapedType, numElements, elements);
}
MlirAttribute mlirDenseElementsAttrFloatGet(MlirType shapedType,
intptr_t numElements,
float *elements) {
return getDenseAttribute(shapedType, numElements, elements);
}
MlirAttribute mlirDenseElementsAttrDoubleGet(MlirType shapedType,
intptr_t numElements,
double *elements) {
return getDenseAttribute(shapedType, numElements, elements);
}
MlirAttribute mlirDenseElementsAttrStringGet(MlirType shapedType,
intptr_t numElements,
intptr_t *strLengths,
const char **strs) {
SmallVector<StringRef, 8> values;
values.reserve(numElements);
for (intptr_t i = 0; i < numElements; ++i)
values.push_back(StringRef(strs[i], strLengths[i]));
return wrap(
DenseElementsAttr::get(unwrap(shapedType).cast<ShapedType>(), values));
}
MlirAttribute mlirDenseElementsAttrReshapeGet(MlirAttribute attr,
MlirType shapedType) {
return wrap(unwrap(attr).cast<DenseElementsAttr>().reshape(
unwrap(shapedType).cast<ShapedType>()));
}
//===----------------------------------------------------------------------===//
// Splat accessors.
int mlirDenseElementsAttrIsSplat(MlirAttribute attr) {
return unwrap(attr).cast<DenseElementsAttr>().isSplat();
}
MlirAttribute mlirDenseElementsAttrGetSplatValue(MlirAttribute attr) {
return wrap(unwrap(attr).cast<DenseElementsAttr>().getSplatValue());
}
int mlirDenseElementsAttrGetBoolSplatValue(MlirAttribute attr) {
return unwrap(attr).cast<DenseElementsAttr>().getSplatValue<bool>();
}
int32_t mlirDenseElementsAttrGetInt32SplatValue(MlirAttribute attr) {
return unwrap(attr).cast<DenseElementsAttr>().getSplatValue<int32_t>();
}
uint32_t mlirDenseElementsAttrGetUInt32SplatValue(MlirAttribute attr) {
return unwrap(attr).cast<DenseElementsAttr>().getSplatValue<uint32_t>();
}
int64_t mlirDenseElementsAttrGetInt64SplatValue(MlirAttribute attr) {
return unwrap(attr).cast<DenseElementsAttr>().getSplatValue<int64_t>();
}
uint64_t mlirDenseElementsAttrGetUInt64SplatValue(MlirAttribute attr) {
return unwrap(attr).cast<DenseElementsAttr>().getSplatValue<uint64_t>();
}
float mlirDenseElementsAttrGetFloatSplatValue(MlirAttribute attr) {
return unwrap(attr).cast<DenseElementsAttr>().getSplatValue<float>();
}
double mlirDenseElementsAttrGetDoubleSplatValue(MlirAttribute attr) {
return unwrap(attr).cast<DenseElementsAttr>().getSplatValue<double>();
}
void mlirDenseElementsAttrGetStringSplatValue(MlirAttribute attr,
MlirStringCallback callback,
void *userData) {
StringRef str =
unwrap(attr).cast<DenseElementsAttr>().getSplatValue<StringRef>();
callback(str.data(), str.size(), userData);
}
//===----------------------------------------------------------------------===//
// Indexed accessors.
int mlirDenseElementsAttrGetBoolValue(MlirAttribute attr, intptr_t pos) {
return *(unwrap(attr).cast<DenseElementsAttr>().getValues<bool>().begin() +
pos);
}
int32_t mlirDenseElementsAttrGetInt32Value(MlirAttribute attr, intptr_t pos) {
return *(unwrap(attr).cast<DenseElementsAttr>().getValues<int32_t>().begin() +
pos);
}
uint32_t mlirDenseElementsAttrGetUInt32Value(MlirAttribute attr, intptr_t pos) {
return *(
unwrap(attr).cast<DenseElementsAttr>().getValues<uint32_t>().begin() +
pos);
}
int64_t mlirDenseElementsAttrGetInt64Value(MlirAttribute attr, intptr_t pos) {
return *(unwrap(attr).cast<DenseElementsAttr>().getValues<int64_t>().begin() +
pos);
}
uint64_t mlirDenseElementsAttrGetUInt64Value(MlirAttribute attr, intptr_t pos) {
return *(
unwrap(attr).cast<DenseElementsAttr>().getValues<uint64_t>().begin() +
pos);
}
float mlirDenseElementsAttrGetFloatValue(MlirAttribute attr, intptr_t pos) {
return *(unwrap(attr).cast<DenseElementsAttr>().getValues<float>().begin() +
pos);
}
double mlirDenseElementsAttrGetDoubleValue(MlirAttribute attr, intptr_t pos) {
return *(unwrap(attr).cast<DenseElementsAttr>().getValues<double>().begin() +
pos);
}
void mlirDenseElementsAttrGetStringValue(MlirAttribute attr, intptr_t pos,
MlirStringCallback callback,
void *userData) {
StringRef str =
*(unwrap(attr).cast<DenseElementsAttr>().getValues<StringRef>().begin() +
pos);
callback(str.data(), str.size(), userData);
}
/*============================================================================*/
/* Opaque elements attribute. */
/*============================================================================*/
int mlirAttributeIsAOpaqueElements(MlirAttribute attr) {
return unwrap(attr).isa<OpaqueElementsAttr>();
}
/*============================================================================*/
/* Sparse elements attribute. */
/*============================================================================*/
int mlirAttributeIsASparseElements(MlirAttribute attr) {
return unwrap(attr).isa<SparseElementsAttr>();
}
MlirAttribute mlirSparseElementsAttribute(MlirType shapedType,
MlirAttribute denseIndices,
MlirAttribute denseValues) {
return wrap(
SparseElementsAttr::get(unwrap(shapedType).cast<ShapedType>(),
unwrap(denseIndices).cast<DenseElementsAttr>(),
unwrap(denseValues).cast<DenseElementsAttr>()));
}
MlirAttribute mlirSparseElementsAttrGetIndices(MlirAttribute attr) {
return wrap(unwrap(attr).cast<SparseElementsAttr>().getIndices());
}
MlirAttribute mlirSparseElementsAttrGetValues(MlirAttribute attr) {
return wrap(unwrap(attr).cast<SparseElementsAttr>().getValues());
}

View File

@ -12,11 +12,14 @@
#include "mlir-c/IR.h" #include "mlir-c/IR.h"
#include "mlir-c/Registration.h" #include "mlir-c/Registration.h"
#include "mlir-c/StandardAttributes.h"
#include "mlir-c/StandardTypes.h" #include "mlir-c/StandardTypes.h"
#include <assert.h> #include <assert.h>
#include <math.h>
#include <stdio.h> #include <stdio.h>
#include <stdlib.h> #include <stdlib.h>
#include <string.h>
void populateLoopBody(MlirContext ctx, MlirBlock loopBody, void populateLoopBody(MlirContext ctx, MlirBlock loopBody,
MlirLocation location, MlirBlock funcBody) { MlirLocation location, MlirBlock funcBody) {
@ -380,6 +383,210 @@ static int printStandardTypes(MlirContext ctx) {
return 0; return 0;
} }
void callbackSetFixedLengthString(const char *data, intptr_t len,
void *userData) {
strncpy(userData, data, len);
}
int printStandardAttributes(MlirContext ctx) {
MlirAttribute floating =
mlirFloatAttrDoubleGet(ctx, mlirF64TypeGet(ctx), 2.0);
if (!mlirAttributeIsAFloat(floating) ||
fabs(mlirFloatAttrGetValueDouble(floating) - 2.0) > 1E-6)
return 1;
mlirAttributeDump(floating);
MlirAttribute integer = mlirIntegerAttrGet(mlirIntegerTypeGet(ctx, 32), 42);
if (!mlirAttributeIsAInteger(integer) ||
mlirIntegerAttrGetValueInt(integer) != 42)
return 2;
mlirAttributeDump(integer);
MlirAttribute boolean = mlirBoolAttrGet(ctx, 1);
if (!mlirAttributeIsABool(boolean) || !mlirBoolAttrGetValue(boolean))
return 3;
mlirAttributeDump(boolean);
const char data[] = "abcdefghijklmnopqestuvwxyz";
char buffer[10];
MlirAttribute opaque =
mlirOpaqueAttrGet(ctx, "std", 3, data, mlirNoneTypeGet(ctx));
if (!mlirAttributeIsAOpaque(opaque) ||
strcmp("std", mlirOpaqueAttrGetDialectNamespace(opaque)))
return 4;
mlirOpaqueAttrGetData(opaque, callbackSetFixedLengthString, buffer);
if (buffer[0] != 'a' || buffer[1] != 'b' || buffer[2] != 'c')
return 5;
mlirAttributeDump(opaque);
MlirAttribute string = mlirStringAttrGet(ctx, 2, data + 3);
if (!mlirAttributeIsAString(string))
return 6;
mlirStringAttrGetValue(string, callbackSetFixedLengthString, buffer);
if (buffer[0] != 'd' || buffer[1] != 'e')
return 7;
mlirAttributeDump(string);
MlirAttribute flatSymbolRef = mlirFlatSymbolRefAttrGet(ctx, 3, data + 5);
if (!mlirAttributeIsAFlatSymbolRef(flatSymbolRef))
return 8;
mlirFloatSymbolRefAttrGetValue(flatSymbolRef, callbackSetFixedLengthString,
buffer);
if (buffer[0] != 'f' || buffer[1] != 'g' || buffer[2] != 'h')
return 9;
mlirAttributeDump(flatSymbolRef);
MlirAttribute symbols[] = {flatSymbolRef, flatSymbolRef};
MlirAttribute symbolRef = mlirSymbolRefAttrGet(ctx, 2, data + 8, 2, symbols);
if (!mlirAttributeIsASymbolRef(symbolRef) ||
mlirSymbolRefAttrGetNumNestedReferences(symbolRef) != 2 ||
!mlirAttributeEqual(mlirSymbolRefAttrGetNestedReference(symbolRef, 0),
flatSymbolRef) ||
!mlirAttributeEqual(mlirSymbolRefAttrGetNestedReference(symbolRef, 1),
flatSymbolRef))
return 10;
mlirSymbolRefAttrGetLeafReference(symbolRef, callbackSetFixedLengthString,
buffer);
mlirSymbolRefAttrGetRootReference(symbolRef, callbackSetFixedLengthString,
buffer + 3);
if (buffer[0] != 'f' || buffer[1] != 'g' || buffer[2] != 'h' ||
buffer[3] != 'i' || buffer[4] != 'j')
return 11;
mlirAttributeDump(symbolRef);
MlirAttribute type = mlirTypeAttrGet(mlirF32TypeGet(ctx));
if (!mlirAttributeIsAType(type) ||
!mlirTypeEqual(mlirF32TypeGet(ctx), mlirTypeAttrGetValue(type)))
return 12;
mlirAttributeDump(type);
MlirAttribute unit = mlirUnitAttrGet(ctx);
if (!mlirAttributeIsAUnit(unit))
return 13;
mlirAttributeDump(unit);
int64_t shape[] = {1, 2};
int bools[] = {0, 1};
uint32_t uints32[] = {0u, 1u};
int32_t ints32[] = {0, 1};
uint64_t uints64[] = {0u, 1u};
int64_t ints64[] = {0, 1};
float floats[] = {0.0f, 1.0f};
double doubles[] = {0.0, 1.0};
MlirAttribute boolElements = mlirDenseElementsAttrBoolGet(
mlirRankedTensorTypeGet(2, shape, mlirIntegerTypeGet(ctx, 1)), 2, bools);
MlirAttribute uint32Elements = mlirDenseElementsAttrUInt32Get(
mlirRankedTensorTypeGet(2, shape, mlirIntegerTypeUnsignedGet(ctx, 32)), 2,
uints32);
MlirAttribute int32Elements = mlirDenseElementsAttrInt32Get(
mlirRankedTensorTypeGet(2, shape, mlirIntegerTypeGet(ctx, 32)), 2,
ints32);
MlirAttribute uint64Elements = mlirDenseElementsAttrUInt64Get(
mlirRankedTensorTypeGet(2, shape, mlirIntegerTypeUnsignedGet(ctx, 64)), 2,
uints64);
MlirAttribute int64Elements = mlirDenseElementsAttrInt64Get(
mlirRankedTensorTypeGet(2, shape, mlirIntegerTypeGet(ctx, 64)), 2,
ints64);
MlirAttribute floatElements = mlirDenseElementsAttrFloatGet(
mlirRankedTensorTypeGet(2, shape, mlirF32TypeGet(ctx)), 2, floats);
MlirAttribute doubleElements = mlirDenseElementsAttrDoubleGet(
mlirRankedTensorTypeGet(2, shape, mlirF64TypeGet(ctx)), 2, doubles);
if (!mlirAttributeIsADenseElements(boolElements) ||
!mlirAttributeIsADenseElements(uint32Elements) ||
!mlirAttributeIsADenseElements(int32Elements) ||
!mlirAttributeIsADenseElements(uint64Elements) ||
!mlirAttributeIsADenseElements(int64Elements) ||
!mlirAttributeIsADenseElements(floatElements) ||
!mlirAttributeIsADenseElements(doubleElements))
return 14;
if (mlirDenseElementsAttrGetBoolValue(boolElements, 1) != 1 ||
mlirDenseElementsAttrGetUInt32Value(uint32Elements, 1) != 1 ||
mlirDenseElementsAttrGetInt32Value(int32Elements, 1) != 1 ||
mlirDenseElementsAttrGetUInt64Value(uint64Elements, 1) != 1 ||
mlirDenseElementsAttrGetInt64Value(int64Elements, 1) != 1 ||
fabsf(mlirDenseElementsAttrGetFloatValue(floatElements, 1) - 1.0f) >
1E-6f ||
fabs(mlirDenseElementsAttrGetDoubleValue(doubleElements, 1) - 1.0) > 1E-6)
return 15;
mlirAttributeDump(boolElements);
mlirAttributeDump(uint32Elements);
mlirAttributeDump(int32Elements);
mlirAttributeDump(uint64Elements);
mlirAttributeDump(int64Elements);
mlirAttributeDump(floatElements);
mlirAttributeDump(doubleElements);
MlirAttribute splatBool = mlirDenseElementsAttrBoolSplatGet(
mlirRankedTensorTypeGet(2, shape, mlirIntegerTypeGet(ctx, 1)), 1);
MlirAttribute splatUInt32 = mlirDenseElementsAttrUInt32SplatGet(
mlirRankedTensorTypeGet(2, shape, mlirIntegerTypeGet(ctx, 32)), 1);
MlirAttribute splatInt32 = mlirDenseElementsAttrInt32SplatGet(
mlirRankedTensorTypeGet(2, shape, mlirIntegerTypeGet(ctx, 32)), 1);
MlirAttribute splatUInt64 = mlirDenseElementsAttrUInt64SplatGet(
mlirRankedTensorTypeGet(2, shape, mlirIntegerTypeGet(ctx, 64)), 1);
MlirAttribute splatInt64 = mlirDenseElementsAttrInt64SplatGet(
mlirRankedTensorTypeGet(2, shape, mlirIntegerTypeGet(ctx, 64)), 1);
MlirAttribute splatFloat = mlirDenseElementsAttrFloatSplatGet(
mlirRankedTensorTypeGet(2, shape, mlirF32TypeGet(ctx)), 1.0f);
MlirAttribute splatDouble = mlirDenseElementsAttrDoubleSplatGet(
mlirRankedTensorTypeGet(2, shape, mlirF64TypeGet(ctx)), 1.0);
if (!mlirAttributeIsADenseElements(splatBool) ||
!mlirDenseElementsAttrIsSplat(splatBool) ||
!mlirAttributeIsADenseElements(splatUInt32) ||
!mlirDenseElementsAttrIsSplat(splatUInt32) ||
!mlirAttributeIsADenseElements(splatInt32) ||
!mlirDenseElementsAttrIsSplat(splatInt32) ||
!mlirAttributeIsADenseElements(splatUInt64) ||
!mlirDenseElementsAttrIsSplat(splatUInt64) ||
!mlirAttributeIsADenseElements(splatInt64) ||
!mlirDenseElementsAttrIsSplat(splatInt64) ||
!mlirAttributeIsADenseElements(splatFloat) ||
!mlirDenseElementsAttrIsSplat(splatFloat) ||
!mlirAttributeIsADenseElements(splatDouble) ||
!mlirDenseElementsAttrIsSplat(splatDouble))
return 16;
if (mlirDenseElementsAttrGetBoolSplatValue(splatBool) != 1 ||
mlirDenseElementsAttrGetUInt32SplatValue(splatUInt32) != 1 ||
mlirDenseElementsAttrGetInt32SplatValue(splatInt32) != 1 ||
mlirDenseElementsAttrGetUInt64SplatValue(splatUInt64) != 1 ||
mlirDenseElementsAttrGetInt64SplatValue(splatInt64) != 1 ||
fabsf(mlirDenseElementsAttrGetFloatSplatValue(splatFloat) - 1.0f) >
1E-6f ||
fabs(mlirDenseElementsAttrGetDoubleSplatValue(splatDouble) - 1.0) > 1E-6)
return 17;
mlirAttributeDump(splatBool);
mlirAttributeDump(splatUInt32);
mlirAttributeDump(splatInt32);
mlirAttributeDump(splatUInt64);
mlirAttributeDump(splatInt64);
mlirAttributeDump(splatFloat);
mlirAttributeDump(splatDouble);
mlirAttributeDump(mlirElementsAttrGetValue(floatElements, 2, uints64));
mlirAttributeDump(mlirElementsAttrGetValue(doubleElements, 2, uints64));
int64_t indices[] = {4, 7};
int64_t two = 2;
MlirAttribute indicesAttr = mlirDenseElementsAttrInt64Get(
mlirRankedTensorTypeGet(1, &two, mlirIntegerTypeGet(ctx, 64)), 2,
indices);
MlirAttribute valuesAttr = mlirDenseElementsAttrFloatGet(
mlirRankedTensorTypeGet(1, &two, mlirF32TypeGet(ctx)), 2, floats);
MlirAttribute sparseAttr = mlirSparseElementsAttribute(
mlirRankedTensorTypeGet(2, shape, mlirF32TypeGet(ctx)), indicesAttr,
valuesAttr);
mlirAttributeDump(sparseAttr);
return 0;
}
int main() { int main() {
MlirContext ctx = mlirContextCreate(); MlirContext ctx = mlirContextCreate();
mlirRegisterAllDialects(ctx); mlirRegisterAllDialects(ctx);
@ -454,10 +661,43 @@ int main() {
// CHECK: tuple<memref<*xf32, 4>, f32> // CHECK: tuple<memref<*xf32, 4>, f32>
// CHECK: 0 // CHECK: 0
// clang-format on // clang-format on
fprintf(stderr, "@types"); fprintf(stderr, "@types\n");
int errcode = printStandardTypes(ctx); int errcode = printStandardTypes(ctx);
fprintf(stderr, "%d\n", errcode); fprintf(stderr, "%d\n", errcode);
// clang-format off
// CHECK-LABEL: @attrs
// CHECK: 2.000000e+00 : f64
// CHECK: 42 : i32
// CHECK: true
// CHECK: #std.abc
// CHECK: "de"
// CHECK: @fgh
// CHECK: @ij::@fgh::@fgh
// CHECK: f32
// CHECK: unit
// CHECK: dense<{{\[}}[false, true]]> : tensor<1x2xi1>
// CHECK: dense<{{\[}}[0, 1]]> : tensor<1x2xui32>
// CHECK: dense<{{\[}}[0, 1]]> : tensor<1x2xi32>
// CHECK: dense<{{\[}}[0, 1]]> : tensor<1x2xui64>
// CHECK: dense<{{\[}}[0, 1]]> : tensor<1x2xi64>
// CHECK: dense<{{\[}}[0.000000e+00, 1.000000e+00]]> : tensor<1x2xf32>
// CHECK: dense<{{\[}}[0.000000e+00, 1.000000e+00]]> : tensor<1x2xf64>
// CHECK: dense<true> : tensor<1x2xi1>
// CHECK: dense<1> : tensor<1x2xi32>
// CHECK: dense<1> : tensor<1x2xi32>
// CHECK: dense<1> : tensor<1x2xi64>
// CHECK: dense<1> : tensor<1x2xi64>
// CHECK: dense<1.000000e+00> : tensor<1x2xf32>
// CHECK: dense<1.000000e+00> : tensor<1x2xf64>
// CHECK: 1.000000e+00 : f32
// CHECK: 1.000000e+00 : f64
// CHECK: sparse<[4, 7], [0.000000e+00, 1.000000e+00]> : tensor<1x2xf32>
// clang-format on
fprintf(stderr, "@attrs\n");
errcode = printStandardAttributes(ctx);
fprintf(stderr, "%d\n", errcode);
mlirContextDestroy(ctx); mlirContextDestroy(ctx);
return 0; return 0;