[mlir][python] Directly implement sequence protocol on Sliceable.

* While annoying, this is the only way to get C++ exception handling out of the happy path for normal iteration.
* Implements sq_length and sq_item for the sequence protocol (used for iteration, including list() construction).
* Implements mp_subscript for general use (i.e. foo[1] and foo[1:1]).
* For constructing a `list(op.results)`, this reduces the time from ~4-5us to ~1.5us on my machine (give or take measurement overhead) and eliminates C++ exceptions, which is a worthy goal in itself.
  * Compared to a baseline of similar construction of a three-integer list, which takes 450ns (might just be measuring function call overhead).
  * See issue discussed on the pybind side: https://github.com/pybind/pybind11/issues/2842

Differential Revision: https://reviews.llvm.org/D119691
This commit is contained in:
Stella Laurenzo 2022-02-13 22:49:28 -08:00
parent e404e22587
commit 429b0cf1de
2 changed files with 85 additions and 31 deletions

View File

@ -207,6 +207,8 @@ private:
/// constructs a new instance of the derived pseudo-container with the /// constructs a new instance of the derived pseudo-container with the
/// given slice parameters (to be forwarded to the Sliceable constructor). /// given slice parameters (to be forwarded to the Sliceable constructor).
/// ///
/// The getNumElements() and getElement(intptr_t) callbacks must not throw.
///
/// A derived class may additionally define: /// A derived class may additionally define:
/// - a `static void bindDerived(ClassTy &)` method to bind additional methods /// - a `static void bindDerived(ClassTy &)` method to bind additional methods
/// the python class. /// the python class.
@ -215,49 +217,53 @@ class Sliceable {
protected: protected:
using ClassTy = pybind11::class_<Derived>; using ClassTy = pybind11::class_<Derived>;
// Transforms `index` into a legal value to access the underlying sequence.
// Returns <0 on failure.
intptr_t wrapIndex(intptr_t index) { intptr_t wrapIndex(intptr_t index) {
if (index < 0) if (index < 0)
index = length + index; index = length + index;
if (index < 0 || index >= length) { if (index < 0 || index >= length)
throw python::SetPyError(PyExc_IndexError, return -1;
"attempt to access out of bounds");
}
return index; return index;
} }
public:
explicit Sliceable(intptr_t startIndex, intptr_t length, intptr_t step)
: startIndex(startIndex), length(length), step(step) {
assert(length >= 0 && "expected non-negative slice length");
}
/// Returns the length of the slice.
intptr_t dunderLen() const { return length; }
/// Returns the element at the given slice index. Supports negative indices /// Returns the element at the given slice index. Supports negative indices
/// by taking elements in inverse order. Throws if the index is out of bounds. /// by taking elements in inverse order. Returns a nullptr object if out
ElementTy dunderGetItem(intptr_t index) { /// of bounds.
pybind11::object getItem(intptr_t index) {
// Negative indices mean we count from the end. // Negative indices mean we count from the end.
index = wrapIndex(index); index = wrapIndex(index);
if (index < 0) {
PyErr_SetString(PyExc_IndexError, "index out of range");
return {};
}
// Compute the linear index given the current slice properties. // Compute the linear index given the current slice properties.
int linearIndex = index * step + startIndex; int linearIndex = index * step + startIndex;
assert(linearIndex >= 0 && assert(linearIndex >= 0 &&
linearIndex < static_cast<Derived *>(this)->getNumElements() && linearIndex < static_cast<Derived *>(this)->getNumElements() &&
"linear index out of bounds, the slice is ill-formed"); "linear index out of bounds, the slice is ill-formed");
return static_cast<Derived *>(this)->getElement(linearIndex); return pybind11::cast(
static_cast<Derived *>(this)->getElement(linearIndex));
} }
/// Returns a new instance of the pseudo-container restricted to the given /// Returns a new instance of the pseudo-container restricted to the given
/// slice. /// slice. Returns a nullptr object on failure.
Derived dunderGetItemSlice(pybind11::slice slice) { pybind11::object getItemSlice(PyObject *slice) {
ssize_t start, stop, extraStep, sliceLength; ssize_t start, stop, extraStep, sliceLength;
if (!slice.compute(dunderLen(), &start, &stop, &extraStep, &sliceLength)) { if (PySlice_GetIndicesEx(slice, length, &start, &stop, &extraStep,
throw python::SetPyError(PyExc_IndexError, &sliceLength) != 0) {
"attempt to access out of bounds"); PyErr_SetString(PyExc_IndexError, "index out of range");
return {};
} }
return static_cast<Derived *>(this)->slice(startIndex + start * step, return pybind11::cast(static_cast<Derived *>(this)->slice(
sliceLength, step * extraStep); startIndex + start * step, sliceLength, step * extraStep));
}
public:
explicit Sliceable(intptr_t startIndex, intptr_t length, intptr_t step)
: startIndex(startIndex), length(length), step(step) {
assert(length >= 0 && "expected non-negative slice length");
} }
/// Returns a new vector (mapped to Python list) containing elements from two /// Returns a new vector (mapped to Python list) containing elements from two
@ -267,10 +273,10 @@ public:
std::vector<ElementTy> elements; std::vector<ElementTy> elements;
elements.reserve(length + other.length); elements.reserve(length + other.length);
for (intptr_t i = 0; i < length; ++i) { for (intptr_t i = 0; i < length; ++i) {
elements.push_back(dunderGetItem(i)); elements.push_back(static_cast<Derived *>(this)->getElement(i));
} }
for (intptr_t i = 0; i < other.length; ++i) { for (intptr_t i = 0; i < other.length; ++i) {
elements.push_back(other.dunderGetItem(i)); elements.push_back(static_cast<Derived *>(this)->getElement(i));
} }
return elements; return elements;
} }
@ -279,11 +285,51 @@ public:
static void bind(pybind11::module &m) { static void bind(pybind11::module &m) {
auto clazz = pybind11::class_<Derived>(m, Derived::pyClassName, auto clazz = pybind11::class_<Derived>(m, Derived::pyClassName,
pybind11::module_local()) pybind11::module_local())
.def("__len__", &Sliceable::dunderLen)
.def("__getitem__", &Sliceable::dunderGetItem)
.def("__getitem__", &Sliceable::dunderGetItemSlice)
.def("__add__", &Sliceable::dunderAdd); .def("__add__", &Sliceable::dunderAdd);
Derived::bindDerived(clazz); Derived::bindDerived(clazz);
// Manually implement the sequence protocol via the C API. We do this
// because it is approx 4x faster than via pybind11, largely because that
// formulation requires a C++ exception to be thrown to detect end of
// sequence.
// Since we are in a C-context, any C++ exception that happens here
// will terminate the program. There is nothing in this implementation
// that should throw in a non-terminal way, so we forgo further
// exception marshalling.
// See: https://github.com/pybind/pybind11/issues/2842
auto heap_type = reinterpret_cast<PyHeapTypeObject *>(clazz.ptr());
assert(heap_type->ht_type.tp_flags & Py_TPFLAGS_HEAPTYPE &&
"must be heap type");
heap_type->as_sequence.sq_length = +[](PyObject *rawSelf) -> Py_ssize_t {
auto self = pybind11::cast<Derived *>(rawSelf);
return self->length;
};
// sq_item is called as part of the sequence protocol for iteration,
// list construction, etc.
heap_type->as_sequence.sq_item =
+[](PyObject *rawSelf, Py_ssize_t index) -> PyObject * {
auto self = pybind11::cast<Derived *>(rawSelf);
return self->getItem(index).release().ptr();
};
// mp_subscript is used for both slices and integer lookups.
heap_type->as_mapping.mp_subscript =
+[](PyObject *rawSelf, PyObject *rawSubscript) -> PyObject * {
auto self = pybind11::cast<Derived *>(rawSelf);
Py_ssize_t index = PyNumber_AsSsize_t(rawSubscript, PyExc_IndexError);
if (!PyErr_Occurred()) {
// Integer indexing.
return self->getItem(index).release().ptr();
}
PyErr_Clear();
// Assume slice-based indexing.
if (PySlice_Check(rawSubscript)) {
return self->getItemSlice(rawSubscript).release().ptr();
}
PyErr_SetString(PyExc_ValueError, "expected integer or slice");
return nullptr;
};
} }
/// Hook for derived classes willing to bind more methods. /// Hook for derived classes willing to bind more methods.

View File

@ -14,6 +14,14 @@ def run(f):
return f return f
def expect_index_error(callback):
try:
_ = callback()
raise RuntimeError("Expected IndexError")
except IndexError:
pass
# Verify iterator based traversal of the op/region/block hierarchy. # Verify iterator based traversal of the op/region/block hierarchy.
# CHECK-LABEL: TEST: testTraverseOpRegionBlockIterators # CHECK-LABEL: TEST: testTraverseOpRegionBlockIterators
@run @run
@ -418,7 +426,9 @@ def testOperationResultList():
for t in call.results.types: for t in call.results.types:
print(f"Result type {t}") print(f"Result type {t}")
# Out of range
expect_index_error(lambda: call.results[3])
expect_index_error(lambda: call.results[-4])
# CHECK-LABEL: TEST: testOperationResultListSlice # CHECK-LABEL: TEST: testOperationResultListSlice
@ -470,8 +480,6 @@ def testOperationResultListSlice():
print(f"Result {res.result_number}, type {res.type}") print(f"Result {res.result_number}, type {res.type}")
# CHECK-LABEL: TEST: testOperationAttributes # CHECK-LABEL: TEST: testOperationAttributes
@run @run
def testOperationAttributes(): def testOperationAttributes():