forked from OSchip/llvm-project
[mlir][python] Directly implement sequence protocol on Sliceable.
* While annoying, this is the only way to get C++ exception handling out of the happy path for normal iteration. * Implements sq_length and sq_item for the sequence protocol (used for iteration, including list() construction). * Implements mp_subscript for general use (i.e. foo[1] and foo[1:1]). * For constructing a `list(op.results)`, this reduces the time from ~4-5us to ~1.5us on my machine (give or take measurement overhead) and eliminates C++ exceptions, which is a worthy goal in itself. * Compared to a baseline of similar construction of a three-integer list, which takes 450ns (might just be measuring function call overhead). * See issue discussed on the pybind side: https://github.com/pybind/pybind11/issues/2842 Differential Revision: https://reviews.llvm.org/D119691
This commit is contained in:
parent
e404e22587
commit
429b0cf1de
|
@ -207,6 +207,8 @@ private:
|
|||
/// constructs a new instance of the derived pseudo-container with the
|
||||
/// given slice parameters (to be forwarded to the Sliceable constructor).
|
||||
///
|
||||
/// The getNumElements() and getElement(intptr_t) callbacks must not throw.
|
||||
///
|
||||
/// A derived class may additionally define:
|
||||
/// - a `static void bindDerived(ClassTy &)` method to bind additional methods
|
||||
/// the python class.
|
||||
|
@ -215,49 +217,53 @@ class Sliceable {
|
|||
protected:
|
||||
using ClassTy = pybind11::class_<Derived>;
|
||||
|
||||
// Transforms `index` into a legal value to access the underlying sequence.
|
||||
// Returns <0 on failure.
|
||||
intptr_t wrapIndex(intptr_t index) {
|
||||
if (index < 0)
|
||||
index = length + index;
|
||||
if (index < 0 || index >= length) {
|
||||
throw python::SetPyError(PyExc_IndexError,
|
||||
"attempt to access out of bounds");
|
||||
}
|
||||
if (index < 0 || index >= length)
|
||||
return -1;
|
||||
return index;
|
||||
}
|
||||
|
||||
public:
|
||||
explicit Sliceable(intptr_t startIndex, intptr_t length, intptr_t step)
|
||||
: startIndex(startIndex), length(length), step(step) {
|
||||
assert(length >= 0 && "expected non-negative slice length");
|
||||
}
|
||||
|
||||
/// Returns the length of the slice.
|
||||
intptr_t dunderLen() const { return length; }
|
||||
|
||||
/// Returns the element at the given slice index. Supports negative indices
|
||||
/// by taking elements in inverse order. Throws if the index is out of bounds.
|
||||
ElementTy dunderGetItem(intptr_t index) {
|
||||
/// by taking elements in inverse order. Returns a nullptr object if out
|
||||
/// of bounds.
|
||||
pybind11::object getItem(intptr_t index) {
|
||||
// Negative indices mean we count from the end.
|
||||
index = wrapIndex(index);
|
||||
if (index < 0) {
|
||||
PyErr_SetString(PyExc_IndexError, "index out of range");
|
||||
return {};
|
||||
}
|
||||
|
||||
// Compute the linear index given the current slice properties.
|
||||
int linearIndex = index * step + startIndex;
|
||||
assert(linearIndex >= 0 &&
|
||||
linearIndex < static_cast<Derived *>(this)->getNumElements() &&
|
||||
"linear index out of bounds, the slice is ill-formed");
|
||||
return static_cast<Derived *>(this)->getElement(linearIndex);
|
||||
return pybind11::cast(
|
||||
static_cast<Derived *>(this)->getElement(linearIndex));
|
||||
}
|
||||
|
||||
/// Returns a new instance of the pseudo-container restricted to the given
|
||||
/// slice.
|
||||
Derived dunderGetItemSlice(pybind11::slice slice) {
|
||||
/// slice. Returns a nullptr object on failure.
|
||||
pybind11::object getItemSlice(PyObject *slice) {
|
||||
ssize_t start, stop, extraStep, sliceLength;
|
||||
if (!slice.compute(dunderLen(), &start, &stop, &extraStep, &sliceLength)) {
|
||||
throw python::SetPyError(PyExc_IndexError,
|
||||
"attempt to access out of bounds");
|
||||
if (PySlice_GetIndicesEx(slice, length, &start, &stop, &extraStep,
|
||||
&sliceLength) != 0) {
|
||||
PyErr_SetString(PyExc_IndexError, "index out of range");
|
||||
return {};
|
||||
}
|
||||
return static_cast<Derived *>(this)->slice(startIndex + start * step,
|
||||
sliceLength, step * extraStep);
|
||||
return pybind11::cast(static_cast<Derived *>(this)->slice(
|
||||
startIndex + start * step, sliceLength, step * extraStep));
|
||||
}
|
||||
|
||||
public:
|
||||
explicit Sliceable(intptr_t startIndex, intptr_t length, intptr_t step)
|
||||
: startIndex(startIndex), length(length), step(step) {
|
||||
assert(length >= 0 && "expected non-negative slice length");
|
||||
}
|
||||
|
||||
/// Returns a new vector (mapped to Python list) containing elements from two
|
||||
|
@ -267,10 +273,10 @@ public:
|
|||
std::vector<ElementTy> elements;
|
||||
elements.reserve(length + other.length);
|
||||
for (intptr_t i = 0; i < length; ++i) {
|
||||
elements.push_back(dunderGetItem(i));
|
||||
elements.push_back(static_cast<Derived *>(this)->getElement(i));
|
||||
}
|
||||
for (intptr_t i = 0; i < other.length; ++i) {
|
||||
elements.push_back(other.dunderGetItem(i));
|
||||
elements.push_back(static_cast<Derived *>(this)->getElement(i));
|
||||
}
|
||||
return elements;
|
||||
}
|
||||
|
@ -279,11 +285,51 @@ public:
|
|||
static void bind(pybind11::module &m) {
|
||||
auto clazz = pybind11::class_<Derived>(m, Derived::pyClassName,
|
||||
pybind11::module_local())
|
||||
.def("__len__", &Sliceable::dunderLen)
|
||||
.def("__getitem__", &Sliceable::dunderGetItem)
|
||||
.def("__getitem__", &Sliceable::dunderGetItemSlice)
|
||||
.def("__add__", &Sliceable::dunderAdd);
|
||||
Derived::bindDerived(clazz);
|
||||
|
||||
// Manually implement the sequence protocol via the C API. We do this
|
||||
// because it is approx 4x faster than via pybind11, largely because that
|
||||
// formulation requires a C++ exception to be thrown to detect end of
|
||||
// sequence.
|
||||
// Since we are in a C-context, any C++ exception that happens here
|
||||
// will terminate the program. There is nothing in this implementation
|
||||
// that should throw in a non-terminal way, so we forgo further
|
||||
// exception marshalling.
|
||||
// See: https://github.com/pybind/pybind11/issues/2842
|
||||
auto heap_type = reinterpret_cast<PyHeapTypeObject *>(clazz.ptr());
|
||||
assert(heap_type->ht_type.tp_flags & Py_TPFLAGS_HEAPTYPE &&
|
||||
"must be heap type");
|
||||
heap_type->as_sequence.sq_length = +[](PyObject *rawSelf) -> Py_ssize_t {
|
||||
auto self = pybind11::cast<Derived *>(rawSelf);
|
||||
return self->length;
|
||||
};
|
||||
// sq_item is called as part of the sequence protocol for iteration,
|
||||
// list construction, etc.
|
||||
heap_type->as_sequence.sq_item =
|
||||
+[](PyObject *rawSelf, Py_ssize_t index) -> PyObject * {
|
||||
auto self = pybind11::cast<Derived *>(rawSelf);
|
||||
return self->getItem(index).release().ptr();
|
||||
};
|
||||
// mp_subscript is used for both slices and integer lookups.
|
||||
heap_type->as_mapping.mp_subscript =
|
||||
+[](PyObject *rawSelf, PyObject *rawSubscript) -> PyObject * {
|
||||
auto self = pybind11::cast<Derived *>(rawSelf);
|
||||
Py_ssize_t index = PyNumber_AsSsize_t(rawSubscript, PyExc_IndexError);
|
||||
if (!PyErr_Occurred()) {
|
||||
// Integer indexing.
|
||||
return self->getItem(index).release().ptr();
|
||||
}
|
||||
PyErr_Clear();
|
||||
|
||||
// Assume slice-based indexing.
|
||||
if (PySlice_Check(rawSubscript)) {
|
||||
return self->getItemSlice(rawSubscript).release().ptr();
|
||||
}
|
||||
|
||||
PyErr_SetString(PyExc_ValueError, "expected integer or slice");
|
||||
return nullptr;
|
||||
};
|
||||
}
|
||||
|
||||
/// Hook for derived classes willing to bind more methods.
|
||||
|
|
|
@ -14,6 +14,14 @@ def run(f):
|
|||
return f
|
||||
|
||||
|
||||
def expect_index_error(callback):
|
||||
try:
|
||||
_ = callback()
|
||||
raise RuntimeError("Expected IndexError")
|
||||
except IndexError:
|
||||
pass
|
||||
|
||||
|
||||
# Verify iterator based traversal of the op/region/block hierarchy.
|
||||
# CHECK-LABEL: TEST: testTraverseOpRegionBlockIterators
|
||||
@run
|
||||
|
@ -418,7 +426,9 @@ def testOperationResultList():
|
|||
for t in call.results.types:
|
||||
print(f"Result type {t}")
|
||||
|
||||
|
||||
# Out of range
|
||||
expect_index_error(lambda: call.results[3])
|
||||
expect_index_error(lambda: call.results[-4])
|
||||
|
||||
|
||||
# CHECK-LABEL: TEST: testOperationResultListSlice
|
||||
|
@ -470,8 +480,6 @@ def testOperationResultListSlice():
|
|||
print(f"Result {res.result_number}, type {res.type}")
|
||||
|
||||
|
||||
|
||||
|
||||
# CHECK-LABEL: TEST: testOperationAttributes
|
||||
@run
|
||||
def testOperationAttributes():
|
||||
|
|
Loading…
Reference in New Issue