Add Python binding for MLIR Type Attribute

Differential Revision: https://reviews.llvm.org/D92711
This commit is contained in:
Mehdi Amini 2020-12-05 02:08:38 +00:00
parent e15ae454b4
commit e56f398dd3
2 changed files with 37 additions and 0 deletions

View File

@ -1922,6 +1922,28 @@ public:
}
};
class PyTypeAttribute : public PyConcreteAttribute<PyTypeAttribute> {
public:
static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAType;
static constexpr const char *pyClassName = "TypeAttr";
using PyConcreteAttribute::PyConcreteAttribute;
static void bindDerived(ClassTy &c) {
c.def_static(
"get",
[](PyType value, DefaultingPyMlirContext context) {
MlirAttribute attr = mlirTypeAttrGet(value.get());
return PyTypeAttribute(context->getRef(), attr);
},
py::arg("value"), py::arg("context") = py::none(),
"Gets a uniqued Type attribute");
c.def_property_readonly("value", [](PyTypeAttribute &self) {
return PyType(self.getContext()->getRef(),
mlirTypeAttrGetValue(self.get()));
});
}
};
/// Unit Attribute subclass. Unit attributes don't have values.
class PyUnitAttribute : public PyConcreteAttribute<PyUnitAttribute> {
public:
@ -3073,6 +3095,7 @@ void mlir::python::populateIRSubmodule(py::module &m) {
PyDenseElementsAttribute::bind(m);
PyDenseIntElementsAttribute::bind(m);
PyDenseFPElementsAttribute::bind(m);
PyTypeAttribute::bind(m);
PyUnitAttribute::bind(m);
//----------------------------------------------------------------------------

View File

@ -255,3 +255,17 @@ def testDenseFPAttr():
run(testDenseFPAttr)
# CHECK-LABEL: TEST: testTypeAttr
def testTypeAttr():
with Context():
raw = Attribute.parse("vector<4xf32>")
# CHECK: attr: vector<4xf32>
print("attr:", raw)
type_attr = TypeAttr(raw)
# CHECK: f32
print(ShapedType(type_attr.value).element_type)
run(testTypeAttr)