[mlir] Expose operation attributes to Python bindings

Operations in a MLIR have a dictionary of attributes attached. Expose
those to Python bindings through a pseudo-container that can be indexed
either by attribute name, producing a PyAttribute, or by a contiguous
index for enumeration purposes, producing a PyNamedAttribute.

Depends On D90917

Reviewed By: stellaraccident

Differential Revision: https://reviews.llvm.org/D90919
This commit is contained in:
Alex Zinenko 2020-11-06 11:25:41 +01:00
parent 885d3f4129
commit c3a6e7c9b7
2 changed files with 94 additions and 0 deletions

View File

@ -1282,6 +1282,47 @@ private:
PyOperationRef operation;
};
/// A list of operation attributes. Can be indexed by name, producing
/// attributes, or by index, producing named attributes.
class PyOpAttributeMap {
public:
PyOpAttributeMap(PyOperationRef operation) : operation(operation) {}
PyAttribute dunderGetItemNamed(const std::string &name) {
MlirAttribute attr =
mlirOperationGetAttributeByName(operation->get(), name.c_str());
if (mlirAttributeIsNull(attr)) {
throw SetPyError(PyExc_KeyError,
"attempt to access a non-existent attribute");
}
return PyAttribute(operation->getContext(), attr);
}
PyNamedAttribute dunderGetItemIndexed(intptr_t index) {
if (index < 0 || index >= dunderLen()) {
throw SetPyError(PyExc_IndexError,
"attempt to access out of bounds attribute");
}
MlirNamedAttribute namedAttr =
mlirOperationGetAttribute(operation->get(), index);
return PyNamedAttribute(namedAttr.attribute, std::string(namedAttr.name));
}
intptr_t dunderLen() {
return mlirOperationGetNumAttributes(operation->get());
}
static void bind(py::module &m) {
py::class_<PyOpAttributeMap>(m, "OpAttributeMap")
.def("__len__", &PyOpAttributeMap::dunderLen)
.def("__getitem__", &PyOpAttributeMap::dunderGetItemNamed)
.def("__getitem__", &PyOpAttributeMap::dunderGetItemIndexed);
}
private:
PyOperationRef operation;
};
} // end namespace
//------------------------------------------------------------------------------
@ -2436,6 +2477,11 @@ void mlir::python::populateIRSubmodule(py::module &m) {
})
.def("__eq__",
[](PyOperationBase &self, py::object other) { return false; })
.def_property_readonly("attributes",
[](PyOperationBase &self) {
return PyOpAttributeMap(
self.getOperation().getRef());
})
.def_property_readonly("operands",
[](PyOperationBase &self) {
return PyOpOperandList(
@ -2810,6 +2856,7 @@ void mlir::python::populateIRSubmodule(py::module &m) {
PyBlockList::bind(m);
PyOperationIterator::bind(m);
PyOperationList::bind(m);
PyOpAttributeMap::bind(m);
PyOpOperandList::bind(m);
PyOpResultList::bind(m);
PyRegionIterator::bind(m);

View File

@ -277,6 +277,53 @@ def testOperationResultList():
run(testOperationResultList)
# CHECK-LABEL: TEST: testOperationAttributes
def testOperationAttributes():
ctx = Context()
ctx.allow_unregistered_dialects = True
module = Module.parse(r"""
"some.op"() { some.attribute = 1 : i8,
other.attribute = 3.0,
dependent = "text" } : () -> ()
""", ctx)
op = module.body.operations[0]
assert len(op.attributes) == 3
iattr = IntegerAttr(op.attributes["some.attribute"])
fattr = FloatAttr(op.attributes["other.attribute"])
sattr = StringAttr(op.attributes["dependent"])
# CHECK: Attribute type i8, value 1
print(f"Attribute type {iattr.type}, value {iattr.value}")
# CHECK: Attribute type f64, value 3.0
print(f"Attribute type {fattr.type}, value {fattr.value}")
# CHECK: Attribute value text
print(f"Attribute value {sattr.value}")
# We don't know in which order the attributes are stored.
# CHECK-DAG: NamedAttribute(dependent="text")
# CHECK-DAG: NamedAttribute(other.attribute=3.000000e+00 : f64)
# CHECK-DAG: NamedAttribute(some.attribute=1 : i8)
for attr in op.attributes:
print(str(attr))
# Check that exceptions are raised as expected.
try:
op.attributes["does_not_exist"]
except KeyError:
pass
else:
assert False, "expected KeyError on accessing a non-existent attribute"
try:
op.attributes[42]
except IndexError:
pass
else:
assert False, "expected IndexError on accessing an out-of-bounds attribute"
run(testOperationAttributes)
# CHECK-LABEL: TEST: testOperationPrint
def testOperationPrint():
ctx = Context()