forked from OSchip/llvm-project
[mlir] Add initial Python bindings for DenseInt/FPElementsAttr
Enumerating elements in these classes is necessary to enable custom operand accessors for variadic operands. Depends On D90919 Reviewed By: stellaraccident Differential Revision: https://reviews.llvm.org/D90923
This commit is contained in:
parent
f0d76275cb
commit
4669ea3bd8
|
@ -1621,11 +1621,14 @@ public:
|
|||
return PyDenseElementsAttribute(contextWrapper->getRef(), elements);
|
||||
}
|
||||
|
||||
intptr_t dunderLen() { return mlirElementsAttrGetNumElements(attr); }
|
||||
|
||||
static void bindDerived(ClassTy &c) {
|
||||
c.def_static("get", PyDenseElementsAttribute::getFromBuffer,
|
||||
py::arg("array"), py::arg("signless") = true,
|
||||
py::arg("context") = py::none(),
|
||||
"Gets from a buffer or ndarray")
|
||||
c.def("__len__", &PyDenseElementsAttribute::dunderLen)
|
||||
.def_static("get", PyDenseElementsAttribute::getFromBuffer,
|
||||
py::arg("array"), py::arg("signless") = true,
|
||||
py::arg("context") = py::none(),
|
||||
"Gets from a buffer or ndarray")
|
||||
.def_static("get_splat", PyDenseElementsAttribute::getSplat,
|
||||
py::arg("shaped_type"), py::arg("element_attr"),
|
||||
"Gets a DenseElementsAttr where all values are the same")
|
||||
|
@ -1651,6 +1654,101 @@ private:
|
|||
}
|
||||
};
|
||||
|
||||
/// Refinement of the PyDenseElementsAttribute for attributes containing integer
|
||||
/// (and boolean) values. Supports element access.
|
||||
class PyDenseIntElementsAttribute
|
||||
: public PyConcreteAttribute<PyDenseIntElementsAttribute,
|
||||
PyDenseElementsAttribute> {
|
||||
public:
|
||||
static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseIntElements;
|
||||
static constexpr const char *pyClassName = "DenseIntElementsAttr";
|
||||
using PyConcreteAttribute::PyConcreteAttribute;
|
||||
|
||||
/// Returns the element at the given linear position. Asserts if the index is
|
||||
/// out of range.
|
||||
py::int_ dunderGetItem(intptr_t pos) {
|
||||
if (pos < 0 || pos >= dunderLen()) {
|
||||
throw SetPyError(PyExc_IndexError,
|
||||
"attempt to access out of bounds element");
|
||||
}
|
||||
|
||||
MlirType type = mlirAttributeGetType(attr);
|
||||
type = mlirShapedTypeGetElementType(type);
|
||||
assert(mlirTypeIsAInteger(type) &&
|
||||
"expected integer element type in dense int elements attribute");
|
||||
// Dispatch element extraction to an appropriate C function based on the
|
||||
// elemental type of the attribute. py::int_ is implicitly constructible
|
||||
// from any C++ integral type and handles bitwidth correctly.
|
||||
// TODO: consider caching the type properties in the constructor to avoid
|
||||
// querying them on each element access.
|
||||
unsigned width = mlirIntegerTypeGetWidth(type);
|
||||
bool isUnsigned = mlirIntegerTypeIsUnsigned(type);
|
||||
if (isUnsigned) {
|
||||
if (width == 1) {
|
||||
return mlirDenseElementsAttrGetBoolValue(attr, pos);
|
||||
}
|
||||
if (width == 32) {
|
||||
return mlirDenseElementsAttrGetUInt32Value(attr, pos);
|
||||
}
|
||||
if (width == 64) {
|
||||
return mlirDenseElementsAttrGetUInt64Value(attr, pos);
|
||||
}
|
||||
} else {
|
||||
if (width == 1) {
|
||||
return mlirDenseElementsAttrGetBoolValue(attr, pos);
|
||||
}
|
||||
if (width == 32) {
|
||||
return mlirDenseElementsAttrGetInt32Value(attr, pos);
|
||||
}
|
||||
if (width == 64) {
|
||||
return mlirDenseElementsAttrGetInt64Value(attr, pos);
|
||||
}
|
||||
}
|
||||
throw SetPyError(PyExc_TypeError, "Unsupported integer type");
|
||||
}
|
||||
|
||||
static void bindDerived(ClassTy &c) {
|
||||
c.def("__getitem__", &PyDenseIntElementsAttribute::dunderGetItem);
|
||||
}
|
||||
};
|
||||
|
||||
/// Refinement of PyDenseElementsAttribute for attributes containing
|
||||
/// floating-point values. Supports element access.
|
||||
class PyDenseFPElementsAttribute
|
||||
: public PyConcreteAttribute<PyDenseFPElementsAttribute,
|
||||
PyDenseElementsAttribute> {
|
||||
public:
|
||||
static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseFPElements;
|
||||
static constexpr const char *pyClassName = "DenseFPElementsAttr";
|
||||
using PyConcreteAttribute::PyConcreteAttribute;
|
||||
|
||||
py::float_ dunderGetItem(intptr_t pos) {
|
||||
if (pos < 0 || pos >= dunderLen()) {
|
||||
throw SetPyError(PyExc_IndexError,
|
||||
"attempt to access out of bounds element");
|
||||
}
|
||||
|
||||
MlirType type = mlirAttributeGetType(attr);
|
||||
type = mlirShapedTypeGetElementType(type);
|
||||
// Dispatch element extraction to an appropriate C function based on the
|
||||
// elemental type of the attribute. py::float_ is implicitly constructible
|
||||
// from float and double.
|
||||
// TODO: consider caching the type properties in the constructor to avoid
|
||||
// querying them on each element access.
|
||||
if (mlirTypeIsAF32(type)) {
|
||||
return mlirDenseElementsAttrGetFloatValue(attr, pos);
|
||||
}
|
||||
if (mlirTypeIsAF64(type)) {
|
||||
return mlirDenseElementsAttrGetDoubleValue(attr, pos);
|
||||
}
|
||||
throw SetPyError(PyExc_TypeError, "Unsupported floating-point type");
|
||||
}
|
||||
|
||||
static void bindDerived(ClassTy &c) {
|
||||
c.def("__getitem__", &PyDenseFPElementsAttribute::dunderGetItem);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
//------------------------------------------------------------------------------
|
||||
|
@ -2754,6 +2852,8 @@ void mlir::python::populateIRSubmodule(py::module &m) {
|
|||
PyBoolAttribute::bind(m);
|
||||
PyStringAttribute::bind(m);
|
||||
PyDenseElementsAttribute::bind(m);
|
||||
PyDenseIntElementsAttribute::bind(m);
|
||||
PyDenseFPElementsAttribute::bind(m);
|
||||
|
||||
//----------------------------------------------------------------------------
|
||||
// Mapping of PyType.
|
||||
|
|
|
@ -181,3 +181,63 @@ def testNamedAttr():
|
|||
print("named:", named)
|
||||
|
||||
run(testNamedAttr)
|
||||
|
||||
|
||||
# CHECK-LABEL: TEST: testDenseIntAttr
|
||||
def testDenseIntAttr():
|
||||
with Context():
|
||||
raw = Attribute.parse("dense<[[0,1,2],[3,4,5]]> : vector<2x3xi32>")
|
||||
# CHECK: attr: dense<[{{\[}}0, 1, 2], [3, 4, 5]]>
|
||||
print("attr:", raw)
|
||||
|
||||
a = DenseIntElementsAttr(raw)
|
||||
assert len(a) == 6
|
||||
|
||||
# CHECK: 0 1 2 3 4 5
|
||||
for value in a:
|
||||
print(value, end=" ")
|
||||
print()
|
||||
|
||||
# CHECK: i32
|
||||
print(ShapedType(a.type).element_type)
|
||||
|
||||
raw = Attribute.parse("dense<[true,false,true,false]> : vector<4xi1>")
|
||||
# CHECK: attr: dense<[true, false, true, false]>
|
||||
print("attr:", raw)
|
||||
|
||||
a = DenseIntElementsAttr(raw)
|
||||
assert len(a) == 4
|
||||
|
||||
# CHECK: 1 0 1 0
|
||||
for value in a:
|
||||
print(value, end=" ")
|
||||
print()
|
||||
|
||||
# CHECK: i1
|
||||
print(ShapedType(a.type).element_type)
|
||||
|
||||
|
||||
run(testDenseIntAttr)
|
||||
|
||||
|
||||
# CHECK-LABEL: TEST: testDenseFPAttr
|
||||
def testDenseFPAttr():
|
||||
with Context():
|
||||
raw = Attribute.parse("dense<[0.0, 1.0, 2.0, 3.0]> : vector<4xf32>")
|
||||
# CHECK: attr: dense<[0.000000e+00, 1.000000e+00, 2.000000e+00, 3.000000e+00]>
|
||||
|
||||
print("attr:", raw)
|
||||
|
||||
a = DenseFPElementsAttr(raw)
|
||||
assert len(a) == 4
|
||||
|
||||
# CHECK: 0.0 1.0 2.0 3.0
|
||||
for value in a:
|
||||
print(value, end=" ")
|
||||
print()
|
||||
|
||||
# CHECK: f32
|
||||
print(ShapedType(a.type).element_type)
|
||||
|
||||
|
||||
run(testDenseFPAttr)
|
||||
|
|
Loading…
Reference in New Issue