forked from OSchip/llvm-project
[mlir][python] Usability improvements for Python bindings
Provide a couple of quality-of-life usability improvements for Python bindings, in particular: * give access to the list of types for the list of op results or block arguments, similarly to ValueRange->TypeRange, * allow for constructing empty dictionary arrays, * support construction of array attributes by concatenating an existing attribute with a Python list of attributes. All these are required for the upcoming customization of builtin and standard ops. Reviewed By: stellaraccident Differential Revision: https://reviews.llvm.org/D110946
This commit is contained in:
parent
c7bd643599
commit
ed9e52f3af
|
@ -18,7 +18,6 @@ using namespace mlir;
|
|||
using namespace mlir::python;
|
||||
|
||||
using llvm::SmallVector;
|
||||
using llvm::StringRef;
|
||||
using llvm::Twine;
|
||||
|
||||
namespace {
|
||||
|
@ -44,6 +43,24 @@ public:
|
|||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
static T pyTryCast(py::handle object) {
|
||||
try {
|
||||
return object.cast<T>();
|
||||
} catch (py::cast_error &err) {
|
||||
std::string msg =
|
||||
std::string(
|
||||
"Invalid attribute when attempting to create an ArrayAttribute (") +
|
||||
err.what() + ")";
|
||||
throw py::cast_error(msg);
|
||||
} catch (py::reference_cast_error &err) {
|
||||
std::string msg = std::string("Invalid attribute (None?) when attempting "
|
||||
"to create an ArrayAttribute (") +
|
||||
err.what() + ")";
|
||||
throw py::cast_error(msg);
|
||||
}
|
||||
}
|
||||
|
||||
class PyArrayAttribute : public PyConcreteAttribute<PyArrayAttribute> {
|
||||
public:
|
||||
static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAArray;
|
||||
|
@ -76,6 +93,10 @@ public:
|
|||
int nextIndex = 0;
|
||||
};
|
||||
|
||||
PyAttribute getItem(intptr_t i) {
|
||||
return PyAttribute(getContext(), mlirArrayAttrGetElement(*this, i));
|
||||
}
|
||||
|
||||
static void bindDerived(ClassTy &c) {
|
||||
c.def_static(
|
||||
"get",
|
||||
|
@ -83,21 +104,7 @@ public:
|
|||
SmallVector<MlirAttribute> mlirAttributes;
|
||||
mlirAttributes.reserve(py::len(attributes));
|
||||
for (auto attribute : attributes) {
|
||||
try {
|
||||
mlirAttributes.push_back(attribute.cast<PyAttribute>());
|
||||
} catch (py::cast_error &err) {
|
||||
std::string msg = std::string("Invalid attribute when attempting "
|
||||
"to create an ArrayAttribute (") +
|
||||
err.what() + ")";
|
||||
throw py::cast_error(msg);
|
||||
} catch (py::reference_cast_error &err) {
|
||||
// This exception seems thrown when the value is "None".
|
||||
std::string msg =
|
||||
std::string("Invalid attribute (None?) when attempting to "
|
||||
"create an ArrayAttribute (") +
|
||||
err.what() + ")";
|
||||
throw py::cast_error(msg);
|
||||
}
|
||||
mlirAttributes.push_back(pyTryCast<PyAttribute>(attribute));
|
||||
}
|
||||
MlirAttribute attr = mlirArrayAttrGet(
|
||||
context->get(), mlirAttributes.size(), mlirAttributes.data());
|
||||
|
@ -109,8 +116,7 @@ public:
|
|||
[](PyArrayAttribute &arr, intptr_t i) {
|
||||
if (i >= mlirArrayAttrGetNumElements(arr))
|
||||
throw py::index_error("ArrayAttribute index out of range");
|
||||
return PyAttribute(arr.getContext(),
|
||||
mlirArrayAttrGetElement(arr, i));
|
||||
return arr.getItem(i);
|
||||
})
|
||||
.def("__len__",
|
||||
[](const PyArrayAttribute &arr) {
|
||||
|
@ -119,6 +125,18 @@ public:
|
|||
.def("__iter__", [](const PyArrayAttribute &arr) {
|
||||
return PyArrayAttributeIterator(arr);
|
||||
});
|
||||
c.def("__add__", [](PyArrayAttribute arr, py::list extras) {
|
||||
std::vector<MlirAttribute> attributes;
|
||||
intptr_t numOldElements = mlirArrayAttrGetNumElements(arr);
|
||||
attributes.reserve(numOldElements + py::len(extras));
|
||||
for (intptr_t i = 0; i < numOldElements; ++i)
|
||||
attributes.push_back(arr.getItem(i));
|
||||
for (py::handle attr : extras)
|
||||
attributes.push_back(pyTryCast<PyAttribute>(attr));
|
||||
MlirAttribute arrayAttr = mlirArrayAttrGet(
|
||||
arr.getContext()->get(), attributes.size(), attributes.data());
|
||||
return PyArrayAttribute(arr.getContext(), arrayAttr);
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -602,7 +620,7 @@ public:
|
|||
mlirNamedAttributes.data());
|
||||
return PyDictAttribute(context->getRef(), attr);
|
||||
},
|
||||
py::arg("value"), py::arg("context") = py::none(),
|
||||
py::arg("value") = py::dict(), py::arg("context") = py::none(),
|
||||
"Gets an uniqued dict attribute");
|
||||
c.def("__getitem__", [](PyDictAttribute &self, const std::string &name) {
|
||||
MlirAttribute attr =
|
||||
|
|
|
@ -1590,6 +1590,19 @@ public:
|
|||
}
|
||||
};
|
||||
|
||||
/// Returns the list of types of the values held by container.
|
||||
template <typename Container>
|
||||
static std::vector<PyType> getValueTypes(Container &container,
|
||||
PyMlirContextRef &context) {
|
||||
std::vector<PyType> result;
|
||||
result.reserve(container.getNumElements());
|
||||
for (int i = 0, e = container.getNumElements(); i < e; ++i) {
|
||||
result.push_back(
|
||||
PyType(context, mlirValueGetType(container.getElement(i).get())));
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
/// A list of block arguments. Internally, these are stored as consecutive
|
||||
/// elements, random access is cheap. The argument list is associated with the
|
||||
/// operation that contains the block (detached blocks are not allowed in
|
||||
|
@ -1625,6 +1638,12 @@ public:
|
|||
return PyBlockArgumentList(operation, block, startIndex, length, step);
|
||||
}
|
||||
|
||||
static void bindDerived(ClassTy &c) {
|
||||
c.def_property_readonly("types", [](PyBlockArgumentList &self) {
|
||||
return getValueTypes(self, self.operation->getContext());
|
||||
});
|
||||
}
|
||||
|
||||
private:
|
||||
PyOperationRef operation;
|
||||
MlirBlock block;
|
||||
|
@ -1712,6 +1731,12 @@ public:
|
|||
return PyOpResultList(operation, startIndex, length, step);
|
||||
}
|
||||
|
||||
static void bindDerived(ClassTy &c) {
|
||||
c.def_property_readonly("types", [](PyOpResultList &self) {
|
||||
return getValueTypes(self, self.operation->getContext());
|
||||
});
|
||||
}
|
||||
|
||||
private:
|
||||
PyOperationRef operation;
|
||||
};
|
||||
|
|
|
@ -343,6 +343,9 @@ def testDictAttr():
|
|||
else:
|
||||
assert False, "expected IndexError on accessing an out-of-bounds attribute"
|
||||
|
||||
# CHECK "empty: {}"
|
||||
print("empty: ", DictAttr.get())
|
||||
|
||||
|
||||
# CHECK-LABEL: TEST: testTypeAttr
|
||||
@run
|
||||
|
@ -404,3 +407,9 @@ def testArrayAttr():
|
|||
except RuntimeError as e:
|
||||
# CHECK: Error: Invalid attribute when attempting to create an ArrayAttribute
|
||||
print("Error: ", e)
|
||||
|
||||
with Context():
|
||||
array = ArrayAttr.get([StringAttr.get("a"), StringAttr.get("b")])
|
||||
array = array + [StringAttr.get("c")]
|
||||
# CHECK: concat: ["a", "b", "c"]
|
||||
print("concat: ", array)
|
||||
|
|
|
@ -145,6 +145,12 @@ def testBlockArgumentList():
|
|||
print("Length: ",
|
||||
len(entry_block.arguments[:2] + entry_block.arguments[1:]))
|
||||
|
||||
# CHECK: Type: i8
|
||||
# CHECK: Type: i16
|
||||
# CHECK: Type: i24
|
||||
for t in entry_block.arguments.types:
|
||||
print("Type: ", t)
|
||||
|
||||
|
||||
run(testBlockArgumentList)
|
||||
|
||||
|
@ -380,6 +386,12 @@ def testOperationResultList():
|
|||
for res in call.results:
|
||||
print(f"Result {res.result_number}, type {res.type}")
|
||||
|
||||
# CHECK: Result type i32
|
||||
# CHECK: Result type f64
|
||||
# CHECK: Result type index
|
||||
for t in call.results.types:
|
||||
print(f"Result type {t}")
|
||||
|
||||
|
||||
run(testOperationResultList)
|
||||
|
||||
|
|
Loading…
Reference in New Issue