[mlir] Add initial Python bindings for DenseInt/FPElementsAttr

Enumerating elements in these classes is necessary to enable custom
operand accessors for variadic operands.

Depends On D90919

Reviewed By: stellaraccident

Differential Revision: https://reviews.llvm.org/D90923
This commit is contained in:
Alex Zinenko 2020-11-06 11:59:22 +01:00
parent f0d76275cb
commit 4669ea3bd8
2 changed files with 164 additions and 4 deletions

View File

@ -1621,11 +1621,14 @@ public:
return PyDenseElementsAttribute(contextWrapper->getRef(), elements);
}
intptr_t dunderLen() { return mlirElementsAttrGetNumElements(attr); }
static void bindDerived(ClassTy &c) {
c.def_static("get", PyDenseElementsAttribute::getFromBuffer,
py::arg("array"), py::arg("signless") = true,
py::arg("context") = py::none(),
"Gets from a buffer or ndarray")
c.def("__len__", &PyDenseElementsAttribute::dunderLen)
.def_static("get", PyDenseElementsAttribute::getFromBuffer,
py::arg("array"), py::arg("signless") = true,
py::arg("context") = py::none(),
"Gets from a buffer or ndarray")
.def_static("get_splat", PyDenseElementsAttribute::getSplat,
py::arg("shaped_type"), py::arg("element_attr"),
"Gets a DenseElementsAttr where all values are the same")
@ -1651,6 +1654,101 @@ private:
}
};
/// Refinement of the PyDenseElementsAttribute for attributes containing integer
/// (and boolean) values. Supports element access.
class PyDenseIntElementsAttribute
: public PyConcreteAttribute<PyDenseIntElementsAttribute,
PyDenseElementsAttribute> {
public:
static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseIntElements;
static constexpr const char *pyClassName = "DenseIntElementsAttr";
using PyConcreteAttribute::PyConcreteAttribute;
/// Returns the element at the given linear position. Asserts if the index is
/// out of range.
py::int_ dunderGetItem(intptr_t pos) {
if (pos < 0 || pos >= dunderLen()) {
throw SetPyError(PyExc_IndexError,
"attempt to access out of bounds element");
}
MlirType type = mlirAttributeGetType(attr);
type = mlirShapedTypeGetElementType(type);
assert(mlirTypeIsAInteger(type) &&
"expected integer element type in dense int elements attribute");
// Dispatch element extraction to an appropriate C function based on the
// elemental type of the attribute. py::int_ is implicitly constructible
// from any C++ integral type and handles bitwidth correctly.
// TODO: consider caching the type properties in the constructor to avoid
// querying them on each element access.
unsigned width = mlirIntegerTypeGetWidth(type);
bool isUnsigned = mlirIntegerTypeIsUnsigned(type);
if (isUnsigned) {
if (width == 1) {
return mlirDenseElementsAttrGetBoolValue(attr, pos);
}
if (width == 32) {
return mlirDenseElementsAttrGetUInt32Value(attr, pos);
}
if (width == 64) {
return mlirDenseElementsAttrGetUInt64Value(attr, pos);
}
} else {
if (width == 1) {
return mlirDenseElementsAttrGetBoolValue(attr, pos);
}
if (width == 32) {
return mlirDenseElementsAttrGetInt32Value(attr, pos);
}
if (width == 64) {
return mlirDenseElementsAttrGetInt64Value(attr, pos);
}
}
throw SetPyError(PyExc_TypeError, "Unsupported integer type");
}
static void bindDerived(ClassTy &c) {
c.def("__getitem__", &PyDenseIntElementsAttribute::dunderGetItem);
}
};
/// Refinement of PyDenseElementsAttribute for attributes containing
/// floating-point values. Supports element access.
class PyDenseFPElementsAttribute
: public PyConcreteAttribute<PyDenseFPElementsAttribute,
PyDenseElementsAttribute> {
public:
static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseFPElements;
static constexpr const char *pyClassName = "DenseFPElementsAttr";
using PyConcreteAttribute::PyConcreteAttribute;
py::float_ dunderGetItem(intptr_t pos) {
if (pos < 0 || pos >= dunderLen()) {
throw SetPyError(PyExc_IndexError,
"attempt to access out of bounds element");
}
MlirType type = mlirAttributeGetType(attr);
type = mlirShapedTypeGetElementType(type);
// Dispatch element extraction to an appropriate C function based on the
// elemental type of the attribute. py::float_ is implicitly constructible
// from float and double.
// TODO: consider caching the type properties in the constructor to avoid
// querying them on each element access.
if (mlirTypeIsAF32(type)) {
return mlirDenseElementsAttrGetFloatValue(attr, pos);
}
if (mlirTypeIsAF64(type)) {
return mlirDenseElementsAttrGetDoubleValue(attr, pos);
}
throw SetPyError(PyExc_TypeError, "Unsupported floating-point type");
}
static void bindDerived(ClassTy &c) {
c.def("__getitem__", &PyDenseFPElementsAttribute::dunderGetItem);
}
};
} // namespace
//------------------------------------------------------------------------------
@ -2754,6 +2852,8 @@ void mlir::python::populateIRSubmodule(py::module &m) {
PyBoolAttribute::bind(m);
PyStringAttribute::bind(m);
PyDenseElementsAttribute::bind(m);
PyDenseIntElementsAttribute::bind(m);
PyDenseFPElementsAttribute::bind(m);
//----------------------------------------------------------------------------
// Mapping of PyType.

View File

@ -181,3 +181,63 @@ def testNamedAttr():
print("named:", named)
run(testNamedAttr)
# CHECK-LABEL: TEST: testDenseIntAttr
def testDenseIntAttr():
with Context():
raw = Attribute.parse("dense<[[0,1,2],[3,4,5]]> : vector<2x3xi32>")
# CHECK: attr: dense<[{{\[}}0, 1, 2], [3, 4, 5]]>
print("attr:", raw)
a = DenseIntElementsAttr(raw)
assert len(a) == 6
# CHECK: 0 1 2 3 4 5
for value in a:
print(value, end=" ")
print()
# CHECK: i32
print(ShapedType(a.type).element_type)
raw = Attribute.parse("dense<[true,false,true,false]> : vector<4xi1>")
# CHECK: attr: dense<[true, false, true, false]>
print("attr:", raw)
a = DenseIntElementsAttr(raw)
assert len(a) == 4
# CHECK: 1 0 1 0
for value in a:
print(value, end=" ")
print()
# CHECK: i1
print(ShapedType(a.type).element_type)
run(testDenseIntAttr)
# CHECK-LABEL: TEST: testDenseFPAttr
def testDenseFPAttr():
with Context():
raw = Attribute.parse("dense<[0.0, 1.0, 2.0, 3.0]> : vector<4xf32>")
# CHECK: attr: dense<[0.000000e+00, 1.000000e+00, 2.000000e+00, 3.000000e+00]>
print("attr:", raw)
a = DenseFPElementsAttr(raw)
assert len(a) == 4
# CHECK: 0.0 1.0 2.0 3.0
for value in a:
print(value, end=" ")
print()
# CHECK: f32
print(ShapedType(a.type).element_type)
run(testDenseFPAttr)