forked from OSchip/llvm-project
[MLIR] Add function to create BFloat16 array attribute
This patch adds a new function `mlirDenseElementsAttrBFloat16Get()`, which accepts the shaped type, the number of BFloat16 values, and a pointer to an array of BFloat16 values, each of which is a `uint16_t` value. Reviewed By: stellaraccident Differential Revision: https://reviews.llvm.org/D123981
This commit is contained in:
parent
0f8c626723
commit
25c218be36
|
@ -379,6 +379,8 @@ MLIR_CAPI_EXPORTED MlirAttribute mlirDenseElementsAttrFloatGet(
|
|||
MlirType shapedType, intptr_t numElements, const float *elements);
|
||||
MLIR_CAPI_EXPORTED MlirAttribute mlirDenseElementsAttrDoubleGet(
|
||||
MlirType shapedType, intptr_t numElements, const double *elements);
|
||||
MLIR_CAPI_EXPORTED MlirAttribute mlirDenseElementsAttrBFloat16Get(
|
||||
MlirType shapedType, intptr_t numElements, const uint16_t *elements);
|
||||
|
||||
/// Creates a dense elements attribute with the given shaped type from string
|
||||
/// elements.
|
||||
|
|
|
@ -474,6 +474,13 @@ MlirAttribute mlirDenseElementsAttrDoubleGet(MlirType shapedType,
|
|||
const double *elements) {
|
||||
return getDenseAttribute(shapedType, numElements, elements);
|
||||
}
|
||||
MlirAttribute mlirDenseElementsAttrBFloat16Get(MlirType shapedType,
|
||||
intptr_t numElements,
|
||||
const uint16_t *elements) {
|
||||
size_t bufferSize = numElements * 2;
|
||||
const void *buffer = static_cast<const void *>(elements);
|
||||
return mlirDenseElementsAttrRawBufferGet(shapedType, bufferSize, buffer);
|
||||
}
|
||||
|
||||
MlirAttribute mlirDenseElementsAttrStringGet(MlirType shapedType,
|
||||
intptr_t numElements,
|
||||
|
|
|
@ -936,6 +936,7 @@ int printBuiltinAttributes(MlirContext ctx) {
|
|||
int64_t ints64[] = {0, 1};
|
||||
float floats[] = {0.0f, 1.0f};
|
||||
double doubles[] = {0.0, 1.0};
|
||||
uint16_t bf16s[] = {0x0, 0x3f80};
|
||||
MlirAttribute encoding = mlirAttributeGetNull();
|
||||
MlirAttribute boolElements = mlirDenseElementsAttrBoolGet(
|
||||
mlirRankedTensorTypeGet(2, shape, mlirIntegerTypeGet(ctx, 1), encoding),
|
||||
|
@ -974,6 +975,9 @@ int printBuiltinAttributes(MlirContext ctx) {
|
|||
MlirAttribute doubleElements = mlirDenseElementsAttrDoubleGet(
|
||||
mlirRankedTensorTypeGet(2, shape, mlirF64TypeGet(ctx), encoding), 2,
|
||||
doubles);
|
||||
MlirAttribute bf16Elements = mlirDenseElementsAttrBFloat16Get(
|
||||
mlirRankedTensorTypeGet(2, shape, mlirBF16TypeGet(ctx), encoding), 2,
|
||||
bf16s);
|
||||
|
||||
if (!mlirAttributeIsADenseElements(boolElements) ||
|
||||
!mlirAttributeIsADenseElements(uint8Elements) ||
|
||||
|
@ -983,7 +987,8 @@ int printBuiltinAttributes(MlirContext ctx) {
|
|||
!mlirAttributeIsADenseElements(uint64Elements) ||
|
||||
!mlirAttributeIsADenseElements(int64Elements) ||
|
||||
!mlirAttributeIsADenseElements(floatElements) ||
|
||||
!mlirAttributeIsADenseElements(doubleElements))
|
||||
!mlirAttributeIsADenseElements(doubleElements) ||
|
||||
!mlirAttributeIsADenseElements(bf16Elements))
|
||||
return 14;
|
||||
|
||||
if (mlirDenseElementsAttrGetBoolValue(boolElements, 1) != 1 ||
|
||||
|
@ -1009,6 +1014,7 @@ int printBuiltinAttributes(MlirContext ctx) {
|
|||
mlirAttributeDump(int64Elements);
|
||||
mlirAttributeDump(floatElements);
|
||||
mlirAttributeDump(doubleElements);
|
||||
mlirAttributeDump(bf16Elements);
|
||||
// CHECK: dense<{{\[}}[false, true]]> : tensor<1x2xi1>
|
||||
// CHECK: dense<{{\[}}[0, 1]]> : tensor<1x2xui8>
|
||||
// CHECK: dense<{{\[}}[0, 1]]> : tensor<1x2xi8>
|
||||
|
@ -1018,6 +1024,7 @@ int printBuiltinAttributes(MlirContext ctx) {
|
|||
// CHECK: dense<{{\[}}[0, 1]]> : tensor<1x2xi64>
|
||||
// CHECK: dense<{{\[}}[0.000000e+00, 1.000000e+00]]> : tensor<1x2xf32>
|
||||
// CHECK: dense<{{\[}}[0.000000e+00, 1.000000e+00]]> : tensor<1x2xf64>
|
||||
// CHECK: dense<{{\[}}[0.000000e+00, 1.000000e+00]]> : tensor<1x2xbf16>
|
||||
|
||||
MlirAttribute splatBool = mlirDenseElementsAttrBoolSplatGet(
|
||||
mlirRankedTensorTypeGet(2, shape, mlirIntegerTypeGet(ctx, 1), encoding),
|
||||
|
@ -1094,12 +1101,15 @@ int printBuiltinAttributes(MlirContext ctx) {
|
|||
float *floatRawData = (float *)mlirDenseElementsAttrGetRawData(floatElements);
|
||||
double *doubleRawData =
|
||||
(double *)mlirDenseElementsAttrGetRawData(doubleElements);
|
||||
uint16_t *bf16RawData =
|
||||
(uint16_t *)mlirDenseElementsAttrGetRawData(bf16Elements);
|
||||
if (uint8RawData[0] != 0u || uint8RawData[1] != 1u || int8RawData[0] != 0 ||
|
||||
int8RawData[1] != 1 || uint32RawData[0] != 0u || uint32RawData[1] != 1u ||
|
||||
int32RawData[0] != 0 || int32RawData[1] != 1 || uint64RawData[0] != 0u ||
|
||||
uint64RawData[1] != 1u || int64RawData[0] != 0 || int64RawData[1] != 1 ||
|
||||
floatRawData[0] != 0.0f || floatRawData[1] != 1.0f ||
|
||||
doubleRawData[0] != 0.0 || doubleRawData[1] != 1.0)
|
||||
doubleRawData[0] != 0.0 || doubleRawData[1] != 1.0 ||
|
||||
bf16RawData[0] != 0 || bf16RawData[1] != 0x3f80)
|
||||
return 18;
|
||||
|
||||
mlirAttributeDump(splatBool);
|
||||
|
@ -1123,8 +1133,10 @@ int printBuiltinAttributes(MlirContext ctx) {
|
|||
|
||||
mlirAttributeDump(mlirElementsAttrGetValue(floatElements, 2, uints64));
|
||||
mlirAttributeDump(mlirElementsAttrGetValue(doubleElements, 2, uints64));
|
||||
mlirAttributeDump(mlirElementsAttrGetValue(bf16Elements, 2, uints64));
|
||||
// CHECK: 1.000000e+00 : f32
|
||||
// CHECK: 1.000000e+00 : f64
|
||||
// CHECK: 1.000000e+00 : bf16
|
||||
|
||||
int64_t indices[] = {0, 1};
|
||||
int64_t one = 1;
|
||||
|
|
Loading…
Reference in New Issue