forked from OSchip/llvm-project
[mlir][python] 8b/16b DenseIntElements access
This extends dense attribute element access to support 8b and 16b ints. Also extends the corresponding parts of the C api. Reviewed By: ftynse Differential Revision: https://reviews.llvm.org/D117731
This commit is contained in:
parent
26167cae45
commit
308d8b8c66
|
@ -355,6 +355,10 @@ MLIR_CAPI_EXPORTED MlirAttribute mlirDenseElementsAttrUInt8Get(
|
|||
MlirType shapedType, intptr_t numElements, const uint8_t *elements);
|
||||
MLIR_CAPI_EXPORTED MlirAttribute mlirDenseElementsAttrInt8Get(
|
||||
MlirType shapedType, intptr_t numElements, const int8_t *elements);
|
||||
MLIR_CAPI_EXPORTED MlirAttribute mlirDenseElementsAttrUInt16Get(
|
||||
MlirType shapedType, intptr_t numElements, const uint16_t *elements);
|
||||
MLIR_CAPI_EXPORTED MlirAttribute mlirDenseElementsAttrInt16Get(
|
||||
MlirType shapedType, intptr_t numElements, const int16_t *elements);
|
||||
MLIR_CAPI_EXPORTED MlirAttribute mlirDenseElementsAttrUInt32Get(
|
||||
MlirType shapedType, intptr_t numElements, const uint32_t *elements);
|
||||
MLIR_CAPI_EXPORTED MlirAttribute mlirDenseElementsAttrInt32Get(
|
||||
|
@ -416,6 +420,10 @@ MLIR_CAPI_EXPORTED int8_t mlirDenseElementsAttrGetInt8Value(MlirAttribute attr,
|
|||
intptr_t pos);
|
||||
MLIR_CAPI_EXPORTED uint8_t
|
||||
mlirDenseElementsAttrGetUInt8Value(MlirAttribute attr, intptr_t pos);
|
||||
MLIR_CAPI_EXPORTED int16_t
|
||||
mlirDenseElementsAttrGetInt16Value(MlirAttribute attr, intptr_t pos);
|
||||
MLIR_CAPI_EXPORTED uint16_t
|
||||
mlirDenseElementsAttrGetUInt16Value(MlirAttribute attr, intptr_t pos);
|
||||
MLIR_CAPI_EXPORTED int32_t
|
||||
mlirDenseElementsAttrGetInt32Value(MlirAttribute attr, intptr_t pos);
|
||||
MLIR_CAPI_EXPORTED uint32_t
|
||||
|
|
|
@ -673,6 +673,12 @@ public:
|
|||
if (width == 1) {
|
||||
return mlirDenseElementsAttrGetBoolValue(*this, pos);
|
||||
}
|
||||
if (width == 8) {
|
||||
return mlirDenseElementsAttrGetUInt8Value(*this, pos);
|
||||
}
|
||||
if (width == 16) {
|
||||
return mlirDenseElementsAttrGetUInt16Value(*this, pos);
|
||||
}
|
||||
if (width == 32) {
|
||||
return mlirDenseElementsAttrGetUInt32Value(*this, pos);
|
||||
}
|
||||
|
@ -683,6 +689,12 @@ public:
|
|||
if (width == 1) {
|
||||
return mlirDenseElementsAttrGetBoolValue(*this, pos);
|
||||
}
|
||||
if (width == 8) {
|
||||
return mlirDenseElementsAttrGetInt8Value(*this, pos);
|
||||
}
|
||||
if (width == 16) {
|
||||
return mlirDenseElementsAttrGetInt16Value(*this, pos);
|
||||
}
|
||||
if (width == 32) {
|
||||
return mlirDenseElementsAttrGetInt32Value(*this, pos);
|
||||
}
|
||||
|
|
|
@ -426,6 +426,16 @@ MlirAttribute mlirDenseElementsAttrInt8Get(MlirType shapedType,
|
|||
const int8_t *elements) {
|
||||
return getDenseAttribute(shapedType, numElements, elements);
|
||||
}
|
||||
MlirAttribute mlirDenseElementsAttrUInt16Get(MlirType shapedType,
|
||||
intptr_t numElements,
|
||||
const uint16_t *elements) {
|
||||
return getDenseAttribute(shapedType, numElements, elements);
|
||||
}
|
||||
MlirAttribute mlirDenseElementsAttrInt16Get(MlirType shapedType,
|
||||
intptr_t numElements,
|
||||
const int16_t *elements) {
|
||||
return getDenseAttribute(shapedType, numElements, elements);
|
||||
}
|
||||
MlirAttribute mlirDenseElementsAttrUInt32Get(MlirType shapedType,
|
||||
intptr_t numElements,
|
||||
const uint32_t *elements) {
|
||||
|
@ -530,6 +540,12 @@ int8_t mlirDenseElementsAttrGetInt8Value(MlirAttribute attr, intptr_t pos) {
|
|||
uint8_t mlirDenseElementsAttrGetUInt8Value(MlirAttribute attr, intptr_t pos) {
|
||||
return unwrap(attr).cast<DenseElementsAttr>().getValues<uint8_t>()[pos];
|
||||
}
|
||||
int16_t mlirDenseElementsAttrGetInt16Value(MlirAttribute attr, intptr_t pos) {
|
||||
return unwrap(attr).cast<DenseElementsAttr>().getValues<int16_t>()[pos];
|
||||
}
|
||||
uint16_t mlirDenseElementsAttrGetUInt16Value(MlirAttribute attr, intptr_t pos) {
|
||||
return unwrap(attr).cast<DenseElementsAttr>().getValues<uint16_t>()[pos];
|
||||
}
|
||||
int32_t mlirDenseElementsAttrGetInt32Value(MlirAttribute attr, intptr_t pos) {
|
||||
return unwrap(attr).cast<DenseElementsAttr>().getValues<int32_t>()[pos];
|
||||
}
|
||||
|
|
|
@ -904,6 +904,8 @@ int printBuiltinAttributes(MlirContext ctx) {
|
|||
int bools[] = {0, 1};
|
||||
uint8_t uints8[] = {0u, 1u};
|
||||
int8_t ints8[] = {0, 1};
|
||||
uint16_t uints16[] = {0u, 1u};
|
||||
int16_t ints16[] = {0, 1};
|
||||
uint32_t uints32[] = {0u, 1u};
|
||||
int32_t ints32[] = {0, 1};
|
||||
uint64_t uints64[] = {0u, 1u};
|
||||
|
@ -921,6 +923,13 @@ int printBuiltinAttributes(MlirContext ctx) {
|
|||
MlirAttribute int8Elements = mlirDenseElementsAttrInt8Get(
|
||||
mlirRankedTensorTypeGet(2, shape, mlirIntegerTypeGet(ctx, 8), encoding),
|
||||
2, ints8);
|
||||
MlirAttribute uint16Elements = mlirDenseElementsAttrUInt16Get(
|
||||
mlirRankedTensorTypeGet(2, shape, mlirIntegerTypeUnsignedGet(ctx, 16),
|
||||
encoding),
|
||||
2, uints16);
|
||||
MlirAttribute int16Elements = mlirDenseElementsAttrInt16Get(
|
||||
mlirRankedTensorTypeGet(2, shape, mlirIntegerTypeGet(ctx, 16), encoding),
|
||||
2, ints16);
|
||||
MlirAttribute uint32Elements = mlirDenseElementsAttrUInt32Get(
|
||||
mlirRankedTensorTypeGet(2, shape, mlirIntegerTypeUnsignedGet(ctx, 32),
|
||||
encoding),
|
||||
|
@ -956,6 +965,8 @@ int printBuiltinAttributes(MlirContext ctx) {
|
|||
if (mlirDenseElementsAttrGetBoolValue(boolElements, 1) != 1 ||
|
||||
mlirDenseElementsAttrGetUInt8Value(uint8Elements, 1) != 1 ||
|
||||
mlirDenseElementsAttrGetInt8Value(int8Elements, 1) != 1 ||
|
||||
mlirDenseElementsAttrGetUInt16Value(uint16Elements, 1) != 1 ||
|
||||
mlirDenseElementsAttrGetInt16Value(int16Elements, 1) != 1 ||
|
||||
mlirDenseElementsAttrGetUInt32Value(uint32Elements, 1) != 1 ||
|
||||
mlirDenseElementsAttrGetInt32Value(int32Elements, 1) != 1 ||
|
||||
mlirDenseElementsAttrGetUInt64Value(uint64Elements, 1) != 1 ||
|
||||
|
|
|
@ -292,6 +292,50 @@ def testDenseIntAttr():
|
|||
print(ShapedType(a.type).element_type)
|
||||
|
||||
|
||||
# CHECK-LABEL: TEST: testDenseIntAttrGetItem
|
||||
@run
|
||||
def testDenseIntAttrGetItem():
|
||||
def print_item(attr_asm):
|
||||
attr = DenseIntElementsAttr(Attribute.parse(attr_asm))
|
||||
dtype = ShapedType(attr.type).element_type
|
||||
try:
|
||||
item = attr[0]
|
||||
print(f"{dtype}:", item)
|
||||
except TypeError as e:
|
||||
print(f"{dtype}:", e)
|
||||
|
||||
with Context():
|
||||
# CHECK: i1: 1
|
||||
print_item("dense<true> : tensor<i1>")
|
||||
# CHECK: i8: 123
|
||||
print_item("dense<123> : tensor<i8>")
|
||||
# CHECK: i16: 123
|
||||
print_item("dense<123> : tensor<i16>")
|
||||
# CHECK: i32: 123
|
||||
print_item("dense<123> : tensor<i32>")
|
||||
# CHECK: i64: 123
|
||||
print_item("dense<123> : tensor<i64>")
|
||||
# CHECK: ui8: 123
|
||||
print_item("dense<123> : tensor<ui8>")
|
||||
# CHECK: ui16: 123
|
||||
print_item("dense<123> : tensor<ui16>")
|
||||
# CHECK: ui32: 123
|
||||
print_item("dense<123> : tensor<ui32>")
|
||||
# CHECK: ui64: 123
|
||||
print_item("dense<123> : tensor<ui64>")
|
||||
# CHECK: si8: -123
|
||||
print_item("dense<-123> : tensor<si8>")
|
||||
# CHECK: si16: -123
|
||||
print_item("dense<-123> : tensor<si16>")
|
||||
# CHECK: si32: -123
|
||||
print_item("dense<-123> : tensor<si32>")
|
||||
# CHECK: si64: -123
|
||||
print_item("dense<-123> : tensor<si64>")
|
||||
|
||||
# CHECK: i7: Unsupported integer type
|
||||
print_item("dense<123> : tensor<i7>")
|
||||
|
||||
|
||||
# CHECK-LABEL: TEST: testDenseFPAttr
|
||||
@run
|
||||
def testDenseFPAttr():
|
||||
|
|
Loading…
Reference in New Issue