forked from OSchip/llvm-project
[mlir] Expose operation attributes to Python bindings
Operations in a MLIR have a dictionary of attributes attached. Expose those to Python bindings through a pseudo-container that can be indexed either by attribute name, producing a PyAttribute, or by a contiguous index for enumeration purposes, producing a PyNamedAttribute. Depends On D90917 Reviewed By: stellaraccident Differential Revision: https://reviews.llvm.org/D90919
This commit is contained in:
parent
885d3f4129
commit
c3a6e7c9b7
|
@ -1282,6 +1282,47 @@ private:
|
|||
PyOperationRef operation;
|
||||
};
|
||||
|
||||
/// A list of operation attributes. Can be indexed by name, producing
|
||||
/// attributes, or by index, producing named attributes.
|
||||
class PyOpAttributeMap {
|
||||
public:
|
||||
PyOpAttributeMap(PyOperationRef operation) : operation(operation) {}
|
||||
|
||||
PyAttribute dunderGetItemNamed(const std::string &name) {
|
||||
MlirAttribute attr =
|
||||
mlirOperationGetAttributeByName(operation->get(), name.c_str());
|
||||
if (mlirAttributeIsNull(attr)) {
|
||||
throw SetPyError(PyExc_KeyError,
|
||||
"attempt to access a non-existent attribute");
|
||||
}
|
||||
return PyAttribute(operation->getContext(), attr);
|
||||
}
|
||||
|
||||
PyNamedAttribute dunderGetItemIndexed(intptr_t index) {
|
||||
if (index < 0 || index >= dunderLen()) {
|
||||
throw SetPyError(PyExc_IndexError,
|
||||
"attempt to access out of bounds attribute");
|
||||
}
|
||||
MlirNamedAttribute namedAttr =
|
||||
mlirOperationGetAttribute(operation->get(), index);
|
||||
return PyNamedAttribute(namedAttr.attribute, std::string(namedAttr.name));
|
||||
}
|
||||
|
||||
intptr_t dunderLen() {
|
||||
return mlirOperationGetNumAttributes(operation->get());
|
||||
}
|
||||
|
||||
static void bind(py::module &m) {
|
||||
py::class_<PyOpAttributeMap>(m, "OpAttributeMap")
|
||||
.def("__len__", &PyOpAttributeMap::dunderLen)
|
||||
.def("__getitem__", &PyOpAttributeMap::dunderGetItemNamed)
|
||||
.def("__getitem__", &PyOpAttributeMap::dunderGetItemIndexed);
|
||||
}
|
||||
|
||||
private:
|
||||
PyOperationRef operation;
|
||||
};
|
||||
|
||||
} // end namespace
|
||||
|
||||
//------------------------------------------------------------------------------
|
||||
|
@ -2436,6 +2477,11 @@ void mlir::python::populateIRSubmodule(py::module &m) {
|
|||
})
|
||||
.def("__eq__",
|
||||
[](PyOperationBase &self, py::object other) { return false; })
|
||||
.def_property_readonly("attributes",
|
||||
[](PyOperationBase &self) {
|
||||
return PyOpAttributeMap(
|
||||
self.getOperation().getRef());
|
||||
})
|
||||
.def_property_readonly("operands",
|
||||
[](PyOperationBase &self) {
|
||||
return PyOpOperandList(
|
||||
|
@ -2810,6 +2856,7 @@ void mlir::python::populateIRSubmodule(py::module &m) {
|
|||
PyBlockList::bind(m);
|
||||
PyOperationIterator::bind(m);
|
||||
PyOperationList::bind(m);
|
||||
PyOpAttributeMap::bind(m);
|
||||
PyOpOperandList::bind(m);
|
||||
PyOpResultList::bind(m);
|
||||
PyRegionIterator::bind(m);
|
||||
|
|
|
@ -277,6 +277,53 @@ def testOperationResultList():
|
|||
run(testOperationResultList)
|
||||
|
||||
|
||||
# CHECK-LABEL: TEST: testOperationAttributes
|
||||
def testOperationAttributes():
|
||||
ctx = Context()
|
||||
ctx.allow_unregistered_dialects = True
|
||||
module = Module.parse(r"""
|
||||
"some.op"() { some.attribute = 1 : i8,
|
||||
other.attribute = 3.0,
|
||||
dependent = "text" } : () -> ()
|
||||
""", ctx)
|
||||
op = module.body.operations[0]
|
||||
assert len(op.attributes) == 3
|
||||
iattr = IntegerAttr(op.attributes["some.attribute"])
|
||||
fattr = FloatAttr(op.attributes["other.attribute"])
|
||||
sattr = StringAttr(op.attributes["dependent"])
|
||||
# CHECK: Attribute type i8, value 1
|
||||
print(f"Attribute type {iattr.type}, value {iattr.value}")
|
||||
# CHECK: Attribute type f64, value 3.0
|
||||
print(f"Attribute type {fattr.type}, value {fattr.value}")
|
||||
# CHECK: Attribute value text
|
||||
print(f"Attribute value {sattr.value}")
|
||||
|
||||
# We don't know in which order the attributes are stored.
|
||||
# CHECK-DAG: NamedAttribute(dependent="text")
|
||||
# CHECK-DAG: NamedAttribute(other.attribute=3.000000e+00 : f64)
|
||||
# CHECK-DAG: NamedAttribute(some.attribute=1 : i8)
|
||||
for attr in op.attributes:
|
||||
print(str(attr))
|
||||
|
||||
# Check that exceptions are raised as expected.
|
||||
try:
|
||||
op.attributes["does_not_exist"]
|
||||
except KeyError:
|
||||
pass
|
||||
else:
|
||||
assert False, "expected KeyError on accessing a non-existent attribute"
|
||||
|
||||
try:
|
||||
op.attributes[42]
|
||||
except IndexError:
|
||||
pass
|
||||
else:
|
||||
assert False, "expected IndexError on accessing an out-of-bounds attribute"
|
||||
|
||||
|
||||
run(testOperationAttributes)
|
||||
|
||||
|
||||
# CHECK-LABEL: TEST: testOperationPrint
|
||||
def testOperationPrint():
|
||||
ctx = Context()
|
||||
|
|
Loading…
Reference in New Issue