forked from OSchip/llvm-project
[mlir][Python] Fix lifetime of ExecutionEngine runtime functions.
We weren't retaining the ctypes closures that the ExecutionEngine was calling back into, leading to mysterious errors. Open to feedback about how to test this. And an extra pair of eyes to make sure I caught all the places that need to be aware of this. Differential Revision: https://reviews.llvm.org/D110661
This commit is contained in:
parent
aa53785f23
commit
204d301bb1
|
@ -31,12 +31,21 @@ public:
|
|||
}
|
||||
MlirExecutionEngine get() { return executionEngine; }
|
||||
|
||||
void release() { executionEngine.ptr = nullptr; }
|
||||
void release() {
|
||||
executionEngine.ptr = nullptr;
|
||||
referencedObjects.clear();
|
||||
}
|
||||
pybind11::object getCapsule() {
|
||||
return py::reinterpret_steal<py::object>(
|
||||
mlirPythonExecutionEngineToCapsule(get()));
|
||||
}
|
||||
|
||||
// Add an object to the list of referenced objects whose lifetime must exceed
|
||||
// those of the ExecutionEngine.
|
||||
void addReferencedObject(pybind11::object obj) {
|
||||
referencedObjects.push_back(obj);
|
||||
}
|
||||
|
||||
static pybind11::object createFromCapsule(pybind11::object capsule) {
|
||||
MlirExecutionEngine rawPm =
|
||||
mlirPythonCapsuleToExecutionEngine(capsule.ptr());
|
||||
|
@ -47,6 +56,10 @@ public:
|
|||
|
||||
private:
|
||||
MlirExecutionEngine executionEngine;
|
||||
// We support Python ctypes closures as callbacks. Keep a list of the objects
|
||||
// so that they don't get garbage collected. (The ExecutionEngine itself
|
||||
// just holds raw pointers with no lifetime semantics).
|
||||
std::vector<py::object> referencedObjects;
|
||||
};
|
||||
|
||||
} // anonymous namespace
|
||||
|
@ -96,13 +109,17 @@ PYBIND11_MODULE(_mlirExecutionEngine, m) {
|
|||
.def(
|
||||
"raw_register_runtime",
|
||||
[](PyExecutionEngine &executionEngine, const std::string &name,
|
||||
uintptr_t sym) {
|
||||
py::object callbackObj) {
|
||||
executionEngine.addReferencedObject(callbackObj);
|
||||
uintptr_t rawSym =
|
||||
py::cast<uintptr_t>(py::getattr(callbackObj, "value"));
|
||||
mlirExecutionEngineRegisterSymbol(
|
||||
executionEngine.get(),
|
||||
mlirStringRefCreate(name.c_str(), name.size()),
|
||||
reinterpret_cast<void *>(sym));
|
||||
reinterpret_cast<void *>(rawSym));
|
||||
},
|
||||
"Lookup function `func` in the ExecutionEngine.")
|
||||
py::arg("name"), py::arg("callback"),
|
||||
"Register `callback` as the runtime symbol `name`.")
|
||||
.def(
|
||||
"dump_to_object_file",
|
||||
[](PyExecutionEngine &executionEngine, const std::string &fileName) {
|
||||
|
|
|
@ -39,5 +39,5 @@ class ExecutionEngine(_execution_engine.ExecutionEngine):
|
|||
under the provided `name`. The `ctypes_callback` must be a
|
||||
`CFuncType` that outlives the execution engine.
|
||||
"""
|
||||
callback = ctypes.cast(ctypes_callback, ctypes.c_void_p).value
|
||||
callback = ctypes.cast(ctypes_callback, ctypes.c_void_p)
|
||||
self.raw_register_runtime("_mlir_ciface_" + name, callback)
|
||||
|
|
Loading…
Reference in New Issue