Use MlirStringRef in StandardAttributes.h

This commit is contained in:
George 2020-12-03 16:01:32 -08:00
parent 99b823c2eb
commit 5f65c4a8e6
3 changed files with 51 additions and 53 deletions

View File

@ -79,7 +79,7 @@ mlirDictionaryAttrGetElement(MlirAttribute attr, intptr_t pos);
/** Returns the dictionary attribute element with the given name or NULL if the
* given name does not exist in the dictionary. */
MLIR_CAPI_EXPORTED MlirAttribute
mlirDictionaryAttrGetElementByName(MlirAttribute attr, const char *name);
mlirDictionaryAttrGetElementByName(MlirAttribute attr, MlirStringRef name);
//===----------------------------------------------------------------------===//
// Floating point attribute.
@ -155,15 +155,13 @@ MLIR_CAPI_EXPORTED bool mlirAttributeIsAOpaque(MlirAttribute attr);
/** Creates an opaque attribute in the given context associated with the dialect
* identified by its namespace. The attribute contains opaque byte data of the
* specified length (data need not be null-terminated). */
MLIR_CAPI_EXPORTED MlirAttribute mlirOpaqueAttrGet(MlirContext ctx,
const char *dialectNamespace,
intptr_t dataLength,
const char *data,
MlirType type);
MLIR_CAPI_EXPORTED MlirAttribute
mlirOpaqueAttrGet(MlirContext ctx, MlirStringRef dialectNamespace,
intptr_t dataLength, const char *data, MlirType type);
/** Returns the namespace of the dialect with which the given opaque attribute
* is associated. The namespace string is owned by the context. */
MLIR_CAPI_EXPORTED const char *
MLIR_CAPI_EXPORTED MlirStringRef
mlirOpaqueAttrGetDialectNamespace(MlirAttribute attr);
/** Returns the raw data as a string reference. The data remains live as long as
@ -178,17 +176,14 @@ MLIR_CAPI_EXPORTED MlirStringRef mlirOpaqueAttrGetData(MlirAttribute attr);
MLIR_CAPI_EXPORTED bool mlirAttributeIsAString(MlirAttribute attr);
/** Creates a string attribute in the given context containing the given string.
* The string need not be null-terminated and its length must be specified. */
*/
MLIR_CAPI_EXPORTED MlirAttribute mlirStringAttrGet(MlirContext ctx,
intptr_t length,
const char *data);
MlirStringRef str);
/** Creates a string attribute in the given context containing the given string.
* The string need not be null-terminated and its length must be specified.
* Additionally, the attribute has the given type. */
MLIR_CAPI_EXPORTED MlirAttribute mlirStringAttrTypedGet(MlirType type,
intptr_t length,
const char *data);
MlirStringRef str);
/** Returns the attribute values as a string reference. The data remains live as
* long as the context in which the attribute lives. */
@ -203,10 +198,9 @@ MLIR_CAPI_EXPORTED bool mlirAttributeIsASymbolRef(MlirAttribute attr);
/** Creates a symbol reference attribute in the given context referencing a
* symbol identified by the given string inside a list of nested references.
* Each of the references in the list must not be nested. The string need not be
* null-terminated and its length must be specified. */
* Each of the references in the list must not be nested. */
MLIR_CAPI_EXPORTED MlirAttribute
mlirSymbolRefAttrGet(MlirContext ctx, intptr_t length, const char *symbol,
mlirSymbolRefAttrGet(MlirContext ctx, MlirStringRef symbol,
intptr_t numReferences, MlirAttribute const *references);
/** Returns the string reference to the root referenced symbol. The data remains
@ -236,11 +230,9 @@ mlirSymbolRefAttrGetNestedReference(MlirAttribute attr, intptr_t pos);
MLIR_CAPI_EXPORTED bool mlirAttributeIsAFlatSymbolRef(MlirAttribute attr);
/** Creates a flat symbol reference attribute in the given context referencing a
* symbol identified by the given string. The string need not be null-terminated
* and its length must be specified. */
* symbol identified by the given string. */
MLIR_CAPI_EXPORTED MlirAttribute mlirFlatSymbolRefAttrGet(MlirContext ctx,
intptr_t length,
const char *symbol);
MlirStringRef symbol);
/** Returns the referenced symbol as a string reference. The data remains live
* as long as the context in which the attribute lives. */
@ -349,11 +341,10 @@ MLIR_CAPI_EXPORTED MlirAttribute mlirDenseElementsAttrDoubleGet(
MlirType shapedType, intptr_t numElements, const double *elements);
/** Creates a dense elements attribute with the given shaped type from string
* elements. The strings need not be null-terminated and their lengths are
* provided as a separate argument co-indexed with the strs argument. */
MLIR_CAPI_EXPORTED MlirAttribute
mlirDenseElementsAttrStringGet(MlirType shapedType, intptr_t numElements,
intptr_t const *strLengths, const char **strs);
* elements. */
MLIR_CAPI_EXPORTED MlirAttribute mlirDenseElementsAttrStringGet(
MlirType shapedType, intptr_t numElements, MlirStringRef *strs);
/** Creates a dense elements attribute that has the same data as the given dense
* elements attribute and a different shaped type. The new type must have the
* same total number of elements. */

View File

@ -86,8 +86,8 @@ MlirNamedAttribute mlirDictionaryAttrGetElement(MlirAttribute attr,
}
MlirAttribute mlirDictionaryAttrGetElementByName(MlirAttribute attr,
const char *name) {
return wrap(unwrap(attr).cast<DictionaryAttr>().get(name));
MlirStringRef name) {
return wrap(unwrap(attr).cast<DictionaryAttr>().get(unwrap(name)));
}
//===----------------------------------------------------------------------===//
@ -160,16 +160,16 @@ bool mlirAttributeIsAOpaque(MlirAttribute attr) {
return unwrap(attr).isa<OpaqueAttr>();
}
MlirAttribute mlirOpaqueAttrGet(MlirContext ctx, const char *dialectNamespace,
MlirAttribute mlirOpaqueAttrGet(MlirContext ctx, MlirStringRef dialectNamespace,
intptr_t dataLength, const char *data,
MlirType type) {
return wrap(OpaqueAttr::get(Identifier::get(dialectNamespace, unwrap(ctx)),
StringRef(data, dataLength), unwrap(type),
unwrap(ctx)));
return wrap(
OpaqueAttr::get(Identifier::get(unwrap(dialectNamespace), unwrap(ctx)),
StringRef(data, dataLength), unwrap(type), unwrap(ctx)));
}
const char *mlirOpaqueAttrGetDialectNamespace(MlirAttribute attr) {
return unwrap(attr).cast<OpaqueAttr>().getDialectNamespace().c_str();
MlirStringRef mlirOpaqueAttrGetDialectNamespace(MlirAttribute attr) {
return wrap(unwrap(attr).cast<OpaqueAttr>().getDialectNamespace().strref());
}
MlirStringRef mlirOpaqueAttrGetData(MlirAttribute attr) {
@ -184,14 +184,12 @@ bool mlirAttributeIsAString(MlirAttribute attr) {
return unwrap(attr).isa<StringAttr>();
}
MlirAttribute mlirStringAttrGet(MlirContext ctx, intptr_t length,
const char *data) {
return wrap(StringAttr::get(StringRef(data, length), unwrap(ctx)));
MlirAttribute mlirStringAttrGet(MlirContext ctx, MlirStringRef str) {
return wrap(StringAttr::get(unwrap(str), unwrap(ctx)));
}
MlirAttribute mlirStringAttrTypedGet(MlirType type, intptr_t length,
const char *data) {
return wrap(StringAttr::get(StringRef(data, length), unwrap(type)));
MlirAttribute mlirStringAttrTypedGet(MlirType type, MlirStringRef str) {
return wrap(StringAttr::get(unwrap(str), unwrap(type)));
}
MlirStringRef mlirStringAttrGetValue(MlirAttribute attr) {
@ -206,14 +204,14 @@ bool mlirAttributeIsASymbolRef(MlirAttribute attr) {
return unwrap(attr).isa<SymbolRefAttr>();
}
MlirAttribute mlirSymbolRefAttrGet(MlirContext ctx, intptr_t length,
const char *symbol, intptr_t numReferences,
MlirAttribute mlirSymbolRefAttrGet(MlirContext ctx, MlirStringRef symbol,
intptr_t numReferences,
MlirAttribute const *references) {
SmallVector<FlatSymbolRefAttr, 4> refs;
refs.reserve(numReferences);
for (intptr_t i = 0; i < numReferences; ++i)
refs.push_back(unwrap(references[i]).cast<FlatSymbolRefAttr>());
return wrap(SymbolRefAttr::get(StringRef(symbol, length), refs, unwrap(ctx)));
return wrap(SymbolRefAttr::get(unwrap(symbol), refs, unwrap(ctx)));
}
MlirStringRef mlirSymbolRefAttrGetRootReference(MlirAttribute attr) {
@ -242,9 +240,8 @@ bool mlirAttributeIsAFlatSymbolRef(MlirAttribute attr) {
return unwrap(attr).isa<FlatSymbolRefAttr>();
}
MlirAttribute mlirFlatSymbolRefAttrGet(MlirContext ctx, intptr_t length,
const char *symbol) {
return wrap(FlatSymbolRefAttr::get(StringRef(symbol, length), unwrap(ctx)));
MlirAttribute mlirFlatSymbolRefAttrGet(MlirContext ctx, MlirStringRef symbol) {
return wrap(FlatSymbolRefAttr::get(unwrap(symbol), unwrap(ctx)));
}
MlirStringRef mlirFlatSymbolRefAttrGetValue(MlirAttribute attr) {
@ -424,12 +421,11 @@ MlirAttribute mlirDenseElementsAttrDoubleGet(MlirType shapedType,
MlirAttribute mlirDenseElementsAttrStringGet(MlirType shapedType,
intptr_t numElements,
intptr_t const *strLengths,
const char **strs) {
MlirStringRef *strs) {
SmallVector<StringRef, 8> values;
values.reserve(numElements);
for (intptr_t i = 0; i < numElements; ++i)
values.push_back(StringRef(strs[i], strLengths[i]));
values.push_back(unwrap(strs[i]));
return wrap(
DenseElementsAttr::get(unwrap(shapedType).cast<ShapedType>(), values));

View File

@ -732,6 +732,13 @@ void callbackSetFixedLengthString(const char *data, intptr_t len,
strncpy(userData, data, len);
}
bool stringIsEqual(const char *lhs, MlirStringRef rhs) {
if (strlen(lhs) != rhs.length) {
return false;
}
return !strncmp(lhs, rhs.data, rhs.length);
}
int printStandardAttributes(MlirContext ctx) {
MlirAttribute floating =
mlirFloatAttrDoubleGet(ctx, mlirF64TypeGet(ctx), 2.0);
@ -763,9 +770,10 @@ int printStandardAttributes(MlirContext ctx) {
const char data[] = "abcdefghijklmnopqestuvwxyz";
MlirAttribute opaque =
mlirOpaqueAttrGet(ctx, "std", 3, data, mlirNoneTypeGet(ctx));
mlirOpaqueAttrGet(ctx, mlirStringRefCreateFromCString("std"), 3, data,
mlirNoneTypeGet(ctx));
if (!mlirAttributeIsAOpaque(opaque) ||
strcmp("std", mlirOpaqueAttrGetDialectNamespace(opaque)))
!stringIsEqual("std", mlirOpaqueAttrGetDialectNamespace(opaque)))
return 4;
MlirStringRef opaqueData = mlirOpaqueAttrGetData(opaque);
@ -775,7 +783,8 @@ int printStandardAttributes(MlirContext ctx) {
mlirAttributeDump(opaque);
// CHECK: #std.abc
MlirAttribute string = mlirStringAttrGet(ctx, 2, data + 3);
MlirAttribute string =
mlirStringAttrGet(ctx, mlirStringRefCreate(data + 3, 2));
if (!mlirAttributeIsAString(string))
return 6;
@ -786,7 +795,8 @@ int printStandardAttributes(MlirContext ctx) {
mlirAttributeDump(string);
// CHECK: "de"
MlirAttribute flatSymbolRef = mlirFlatSymbolRefAttrGet(ctx, 3, data + 5);
MlirAttribute flatSymbolRef =
mlirFlatSymbolRefAttrGet(ctx, mlirStringRefCreate(data + 5, 3));
if (!mlirAttributeIsAFlatSymbolRef(flatSymbolRef))
return 8;
@ -799,7 +809,8 @@ int printStandardAttributes(MlirContext ctx) {
// CHECK: @fgh
MlirAttribute symbols[] = {flatSymbolRef, flatSymbolRef};
MlirAttribute symbolRef = mlirSymbolRefAttrGet(ctx, 2, data + 8, 2, symbols);
MlirAttribute symbolRef =
mlirSymbolRefAttrGet(ctx, mlirStringRefCreate(data + 8, 2), 2, symbols);
if (!mlirAttributeIsASymbolRef(symbolRef) ||
mlirSymbolRefAttrGetNumNestedReferences(symbolRef) != 2 ||
!mlirAttributeEqual(mlirSymbolRefAttrGetNestedReference(symbolRef, 0),