[mlir][python] Usability improvements for Python bindings

Provide a couple of quality-of-life usability improvements for Python bindings,
in particular:

  * give access to the list of types for the list of op results or block
    arguments, similarly to ValueRange->TypeRange,

  * allow for constructing empty dictionary arrays,

  * support construction of array attributes by concatenating an existing
    attribute with a Python list of attributes.

All these are required for the upcoming customization of builtin and standard
ops.

Reviewed By: stellaraccident

Differential Revision: https://reviews.llvm.org/D110946
This commit is contained in:
Alex Zinenko 2021-10-04 11:38:20 +02:00
parent c7bd643599
commit ed9e52f3af
4 changed files with 83 additions and 19 deletions

View File

@ -18,7 +18,6 @@ using namespace mlir;
using namespace mlir::python;
using llvm::SmallVector;
using llvm::StringRef;
using llvm::Twine;
namespace {
@ -44,6 +43,24 @@ public:
}
};
template <typename T>
static T pyTryCast(py::handle object) {
try {
return object.cast<T>();
} catch (py::cast_error &err) {
std::string msg =
std::string(
"Invalid attribute when attempting to create an ArrayAttribute (") +
err.what() + ")";
throw py::cast_error(msg);
} catch (py::reference_cast_error &err) {
std::string msg = std::string("Invalid attribute (None?) when attempting "
"to create an ArrayAttribute (") +
err.what() + ")";
throw py::cast_error(msg);
}
}
class PyArrayAttribute : public PyConcreteAttribute<PyArrayAttribute> {
public:
static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAArray;
@ -76,6 +93,10 @@ public:
int nextIndex = 0;
};
PyAttribute getItem(intptr_t i) {
return PyAttribute(getContext(), mlirArrayAttrGetElement(*this, i));
}
static void bindDerived(ClassTy &c) {
c.def_static(
"get",
@ -83,21 +104,7 @@ public:
SmallVector<MlirAttribute> mlirAttributes;
mlirAttributes.reserve(py::len(attributes));
for (auto attribute : attributes) {
try {
mlirAttributes.push_back(attribute.cast<PyAttribute>());
} catch (py::cast_error &err) {
std::string msg = std::string("Invalid attribute when attempting "
"to create an ArrayAttribute (") +
err.what() + ")";
throw py::cast_error(msg);
} catch (py::reference_cast_error &err) {
// This exception seems thrown when the value is "None".
std::string msg =
std::string("Invalid attribute (None?) when attempting to "
"create an ArrayAttribute (") +
err.what() + ")";
throw py::cast_error(msg);
}
mlirAttributes.push_back(pyTryCast<PyAttribute>(attribute));
}
MlirAttribute attr = mlirArrayAttrGet(
context->get(), mlirAttributes.size(), mlirAttributes.data());
@ -109,8 +116,7 @@ public:
[](PyArrayAttribute &arr, intptr_t i) {
if (i >= mlirArrayAttrGetNumElements(arr))
throw py::index_error("ArrayAttribute index out of range");
return PyAttribute(arr.getContext(),
mlirArrayAttrGetElement(arr, i));
return arr.getItem(i);
})
.def("__len__",
[](const PyArrayAttribute &arr) {
@ -119,6 +125,18 @@ public:
.def("__iter__", [](const PyArrayAttribute &arr) {
return PyArrayAttributeIterator(arr);
});
c.def("__add__", [](PyArrayAttribute arr, py::list extras) {
std::vector<MlirAttribute> attributes;
intptr_t numOldElements = mlirArrayAttrGetNumElements(arr);
attributes.reserve(numOldElements + py::len(extras));
for (intptr_t i = 0; i < numOldElements; ++i)
attributes.push_back(arr.getItem(i));
for (py::handle attr : extras)
attributes.push_back(pyTryCast<PyAttribute>(attr));
MlirAttribute arrayAttr = mlirArrayAttrGet(
arr.getContext()->get(), attributes.size(), attributes.data());
return PyArrayAttribute(arr.getContext(), arrayAttr);
});
}
};
@ -602,7 +620,7 @@ public:
mlirNamedAttributes.data());
return PyDictAttribute(context->getRef(), attr);
},
py::arg("value"), py::arg("context") = py::none(),
py::arg("value") = py::dict(), py::arg("context") = py::none(),
"Gets an uniqued dict attribute");
c.def("__getitem__", [](PyDictAttribute &self, const std::string &name) {
MlirAttribute attr =

View File

@ -1590,6 +1590,19 @@ public:
}
};
/// Returns the list of types of the values held by container.
template <typename Container>
static std::vector<PyType> getValueTypes(Container &container,
PyMlirContextRef &context) {
std::vector<PyType> result;
result.reserve(container.getNumElements());
for (int i = 0, e = container.getNumElements(); i < e; ++i) {
result.push_back(
PyType(context, mlirValueGetType(container.getElement(i).get())));
}
return result;
}
/// A list of block arguments. Internally, these are stored as consecutive
/// elements, random access is cheap. The argument list is associated with the
/// operation that contains the block (detached blocks are not allowed in
@ -1625,6 +1638,12 @@ public:
return PyBlockArgumentList(operation, block, startIndex, length, step);
}
static void bindDerived(ClassTy &c) {
c.def_property_readonly("types", [](PyBlockArgumentList &self) {
return getValueTypes(self, self.operation->getContext());
});
}
private:
PyOperationRef operation;
MlirBlock block;
@ -1712,6 +1731,12 @@ public:
return PyOpResultList(operation, startIndex, length, step);
}
static void bindDerived(ClassTy &c) {
c.def_property_readonly("types", [](PyOpResultList &self) {
return getValueTypes(self, self.operation->getContext());
});
}
private:
PyOperationRef operation;
};

View File

@ -343,6 +343,9 @@ def testDictAttr():
else:
assert False, "expected IndexError on accessing an out-of-bounds attribute"
# CHECK "empty: {}"
print("empty: ", DictAttr.get())
# CHECK-LABEL: TEST: testTypeAttr
@run
@ -404,3 +407,9 @@ def testArrayAttr():
except RuntimeError as e:
# CHECK: Error: Invalid attribute when attempting to create an ArrayAttribute
print("Error: ", e)
with Context():
array = ArrayAttr.get([StringAttr.get("a"), StringAttr.get("b")])
array = array + [StringAttr.get("c")]
# CHECK: concat: ["a", "b", "c"]
print("concat: ", array)

View File

@ -145,6 +145,12 @@ def testBlockArgumentList():
print("Length: ",
len(entry_block.arguments[:2] + entry_block.arguments[1:]))
# CHECK: Type: i8
# CHECK: Type: i16
# CHECK: Type: i24
for t in entry_block.arguments.types:
print("Type: ", t)
run(testBlockArgumentList)
@ -380,6 +386,12 @@ def testOperationResultList():
for res in call.results:
print(f"Result {res.result_number}, type {res.type}")
# CHECK: Result type i32
# CHECK: Result type f64
# CHECK: Result type index
for t in call.results.types:
print(f"Result type {t}")
run(testOperationResultList)