[mlir][Python] Fix lifetime of ExecutionEngine runtime functions.

We weren't retaining the ctypes closures that the ExecutionEngine was
calling back into, leading to mysterious errors.

Open to feedback about how to test this. And an extra pair of eyes to
make sure I caught all the places that need to be aware of this.

Differential Revision: https://reviews.llvm.org/D110661
This commit is contained in:
Sean Silva 2021-09-28 21:58:51 +00:00
parent aa53785f23
commit 204d301bb1
2 changed files with 22 additions and 5 deletions

View File

@ -31,12 +31,21 @@ public:
}
MlirExecutionEngine get() { return executionEngine; }
void release() { executionEngine.ptr = nullptr; }
void release() {
executionEngine.ptr = nullptr;
referencedObjects.clear();
}
pybind11::object getCapsule() {
return py::reinterpret_steal<py::object>(
mlirPythonExecutionEngineToCapsule(get()));
}
// Add an object to the list of referenced objects whose lifetime must exceed
// those of the ExecutionEngine.
void addReferencedObject(pybind11::object obj) {
referencedObjects.push_back(obj);
}
static pybind11::object createFromCapsule(pybind11::object capsule) {
MlirExecutionEngine rawPm =
mlirPythonCapsuleToExecutionEngine(capsule.ptr());
@ -47,6 +56,10 @@ public:
private:
MlirExecutionEngine executionEngine;
// We support Python ctypes closures as callbacks. Keep a list of the objects
// so that they don't get garbage collected. (The ExecutionEngine itself
// just holds raw pointers with no lifetime semantics).
std::vector<py::object> referencedObjects;
};
} // anonymous namespace
@ -96,13 +109,17 @@ PYBIND11_MODULE(_mlirExecutionEngine, m) {
.def(
"raw_register_runtime",
[](PyExecutionEngine &executionEngine, const std::string &name,
uintptr_t sym) {
py::object callbackObj) {
executionEngine.addReferencedObject(callbackObj);
uintptr_t rawSym =
py::cast<uintptr_t>(py::getattr(callbackObj, "value"));
mlirExecutionEngineRegisterSymbol(
executionEngine.get(),
mlirStringRefCreate(name.c_str(), name.size()),
reinterpret_cast<void *>(sym));
reinterpret_cast<void *>(rawSym));
},
"Lookup function `func` in the ExecutionEngine.")
py::arg("name"), py::arg("callback"),
"Register `callback` as the runtime symbol `name`.")
.def(
"dump_to_object_file",
[](PyExecutionEngine &executionEngine, const std::string &fileName) {

View File

@ -39,5 +39,5 @@ class ExecutionEngine(_execution_engine.ExecutionEngine):
under the provided `name`. The `ctypes_callback` must be a
`CFuncType` that outlives the execution engine.
"""
callback = ctypes.cast(ctypes_callback, ctypes.c_void_p).value
callback = ctypes.cast(ctypes_callback, ctypes.c_void_p)
self.raw_register_runtime("_mlir_ciface_" + name, callback)