forked from OSchip/llvm-project
First pass on MLIR python context lifetime management.
* Per thread https://llvm.discourse.group/t/revisiting-ownership-and-lifetime-in-the-python-bindings/1769 * Reworks contexts so it is always possible to get back to a py::object that holds the reference count for an arbitrary MlirContext. * Retrofits some of the base classes to automatically take a reference to the context, elimintating keep_alives. * More needs to be done, as discussed, when moving on to the operations/blocks/regions. Differential Revision: https://reviews.llvm.org/D87886
This commit is contained in:
parent
7bd75b6301
commit
85185b61b6
|
@ -103,6 +103,9 @@ MlirLocation mlirLocationFileLineColGet(MlirContext context,
|
|||
/** Creates a location with unknown position owned by the given context. */
|
||||
MlirLocation mlirLocationUnknownGet(MlirContext context);
|
||||
|
||||
/** Gets the context that a location was created with. */
|
||||
MlirContext mlirLocationGetContext(MlirLocation location);
|
||||
|
||||
/** Prints a location by sending chunks of the string representation and
|
||||
* forwarding `userData to `callback`. Note that the callback may be called
|
||||
* several times with consecutive chunks of the string. */
|
||||
|
@ -119,6 +122,9 @@ MlirModule mlirModuleCreateEmpty(MlirLocation location);
|
|||
/** Parses a module from the string and transfers ownership to the caller. */
|
||||
MlirModule mlirModuleCreateParse(MlirContext context, const char *module);
|
||||
|
||||
/** Gets the context that a module was created with. */
|
||||
MlirContext mlirModuleGetContext(MlirModule module);
|
||||
|
||||
/** Checks whether a module is null. */
|
||||
inline int mlirModuleIsNull(MlirModule module) { return !module.ptr; }
|
||||
|
||||
|
@ -342,6 +348,9 @@ void mlirTypeDump(MlirType type);
|
|||
/** Parses an attribute. The attribute is owned by the context. */
|
||||
MlirAttribute mlirAttributeParseGet(MlirContext context, const char *attr);
|
||||
|
||||
/** Gets the context that an attribute was created with. */
|
||||
MlirContext mlirAttributeGetContext(MlirAttribute attribute);
|
||||
|
||||
/** Checks whether an attribute is null. */
|
||||
inline int mlirAttributeIsNull(MlirAttribute attr) { return !attr.ptr; }
|
||||
|
||||
|
|
|
@ -170,6 +170,51 @@ int mlirTypeIsAIntegerOrFloat(MlirType type) {
|
|||
|
||||
} // namespace
|
||||
|
||||
//------------------------------------------------------------------------------
|
||||
// PyMlirContext
|
||||
//------------------------------------------------------------------------------
|
||||
|
||||
PyMlirContext *PyMlirContextRef::release() {
|
||||
object.release();
|
||||
return &referrent;
|
||||
}
|
||||
|
||||
PyMlirContext::PyMlirContext(MlirContext context) : context(context) {}
|
||||
|
||||
PyMlirContext::~PyMlirContext() {
|
||||
// Note that the only public way to construct an instance is via the
|
||||
// forContext method, which always puts the associated handle into
|
||||
// liveContexts.
|
||||
py::gil_scoped_acquire acquire;
|
||||
getLiveContexts().erase(context.ptr);
|
||||
mlirContextDestroy(context);
|
||||
}
|
||||
|
||||
PyMlirContextRef PyMlirContext::forContext(MlirContext context) {
|
||||
py::gil_scoped_acquire acquire;
|
||||
auto &liveContexts = getLiveContexts();
|
||||
auto it = liveContexts.find(context.ptr);
|
||||
if (it == liveContexts.end()) {
|
||||
// Create.
|
||||
PyMlirContext *unownedContextWrapper = new PyMlirContext(context);
|
||||
py::object pyRef = py::cast(unownedContextWrapper);
|
||||
unownedContextWrapper->handle = pyRef;
|
||||
liveContexts[context.ptr] = std::make_pair(pyRef, unownedContextWrapper);
|
||||
return PyMlirContextRef(*unownedContextWrapper, std::move(pyRef));
|
||||
} else {
|
||||
// Use existing.
|
||||
py::object pyRef = py::reinterpret_borrow<py::object>(it->second.first);
|
||||
return PyMlirContextRef(*it->second.second, std::move(pyRef));
|
||||
}
|
||||
}
|
||||
|
||||
PyMlirContext::LiveContextMap &PyMlirContext::getLiveContexts() {
|
||||
static LiveContextMap liveContexts;
|
||||
return liveContexts;
|
||||
}
|
||||
|
||||
size_t PyMlirContext::getLiveCount() { return getLiveContexts().size(); }
|
||||
|
||||
//------------------------------------------------------------------------------
|
||||
// PyBlock, PyRegion, and PyOperation.
|
||||
//------------------------------------------------------------------------------
|
||||
|
@ -234,9 +279,10 @@ public:
|
|||
using IsAFunctionTy = int (*)(MlirAttribute);
|
||||
|
||||
PyConcreteAttribute() = default;
|
||||
PyConcreteAttribute(MlirAttribute attr) : BaseTy(attr) {}
|
||||
PyConcreteAttribute(PyMlirContextRef contextRef, MlirAttribute attr)
|
||||
: BaseTy(std::move(contextRef), attr) {}
|
||||
PyConcreteAttribute(PyAttribute &orig)
|
||||
: PyConcreteAttribute(castFrom(orig)) {}
|
||||
: PyConcreteAttribute(orig.getContext(), castFrom(orig)) {}
|
||||
|
||||
static MlirAttribute castFrom(PyAttribute &orig) {
|
||||
if (!DerivedTy::isaFunction(orig.attr)) {
|
||||
|
@ -269,18 +315,18 @@ public:
|
|||
"get",
|
||||
[](PyMlirContext &context, std::string value) {
|
||||
MlirAttribute attr =
|
||||
mlirStringAttrGet(context.context, value.size(), &value[0]);
|
||||
return PyStringAttribute(attr);
|
||||
mlirStringAttrGet(context.get(), value.size(), &value[0]);
|
||||
return PyStringAttribute(context.getRef(), attr);
|
||||
},
|
||||
py::keep_alive<0, 1>(), "Gets a uniqued string attribute");
|
||||
"Gets a uniqued string attribute");
|
||||
c.def_static(
|
||||
"get_typed",
|
||||
[](PyType &type, std::string value) {
|
||||
MlirAttribute attr =
|
||||
mlirStringAttrTypedGet(type.type, value.size(), &value[0]);
|
||||
return PyStringAttribute(attr);
|
||||
return PyStringAttribute(type.getContext(), attr);
|
||||
},
|
||||
py::keep_alive<0, 1>(),
|
||||
|
||||
"Gets a uniqued string attribute associated to a type");
|
||||
c.def_property_readonly(
|
||||
"value",
|
||||
|
@ -315,8 +361,10 @@ public:
|
|||
using IsAFunctionTy = int (*)(MlirType);
|
||||
|
||||
PyConcreteType() = default;
|
||||
PyConcreteType(MlirType t) : BaseTy(t) {}
|
||||
PyConcreteType(PyType &orig) : PyConcreteType(castFrom(orig)) {}
|
||||
PyConcreteType(PyMlirContextRef contextRef, MlirType t)
|
||||
: BaseTy(std::move(contextRef), t) {}
|
||||
PyConcreteType(PyType &orig)
|
||||
: PyConcreteType(orig.getContext(), castFrom(orig)) {}
|
||||
|
||||
static MlirType castFrom(PyType &orig) {
|
||||
if (!DerivedTy::isaFunction(orig.type)) {
|
||||
|
@ -348,24 +396,24 @@ public:
|
|||
c.def_static(
|
||||
"get_signless",
|
||||
[](PyMlirContext &context, unsigned width) {
|
||||
MlirType t = mlirIntegerTypeGet(context.context, width);
|
||||
return PyIntegerType(t);
|
||||
MlirType t = mlirIntegerTypeGet(context.get(), width);
|
||||
return PyIntegerType(context.getRef(), t);
|
||||
},
|
||||
py::keep_alive<0, 1>(), "Create a signless integer type");
|
||||
"Create a signless integer type");
|
||||
c.def_static(
|
||||
"get_signed",
|
||||
[](PyMlirContext &context, unsigned width) {
|
||||
MlirType t = mlirIntegerTypeSignedGet(context.context, width);
|
||||
return PyIntegerType(t);
|
||||
MlirType t = mlirIntegerTypeSignedGet(context.get(), width);
|
||||
return PyIntegerType(context.getRef(), t);
|
||||
},
|
||||
py::keep_alive<0, 1>(), "Create a signed integer type");
|
||||
"Create a signed integer type");
|
||||
c.def_static(
|
||||
"get_unsigned",
|
||||
[](PyMlirContext &context, unsigned width) {
|
||||
MlirType t = mlirIntegerTypeUnsignedGet(context.context, width);
|
||||
return PyIntegerType(t);
|
||||
MlirType t = mlirIntegerTypeUnsignedGet(context.get(), width);
|
||||
return PyIntegerType(context.getRef(), t);
|
||||
},
|
||||
py::keep_alive<0, 1>(), "Create an unsigned integer type");
|
||||
"Create an unsigned integer type");
|
||||
c.def_property_readonly(
|
||||
"width",
|
||||
[](PyIntegerType &self) { return mlirIntegerTypeGetWidth(self.type); },
|
||||
|
@ -400,10 +448,10 @@ public:
|
|||
|
||||
static void bindDerived(ClassTy &c) {
|
||||
c.def(py::init([](PyMlirContext &context) {
|
||||
MlirType t = mlirIndexTypeGet(context.context);
|
||||
return PyIndexType(t);
|
||||
MlirType t = mlirIndexTypeGet(context.get());
|
||||
return PyIndexType(context.getRef(), t);
|
||||
}),
|
||||
py::keep_alive<0, 1>(), "Create a index type.");
|
||||
"Create a index type.");
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -416,10 +464,10 @@ public:
|
|||
|
||||
static void bindDerived(ClassTy &c) {
|
||||
c.def(py::init([](PyMlirContext &context) {
|
||||
MlirType t = mlirBF16TypeGet(context.context);
|
||||
return PyBF16Type(t);
|
||||
MlirType t = mlirBF16TypeGet(context.get());
|
||||
return PyBF16Type(context.getRef(), t);
|
||||
}),
|
||||
py::keep_alive<0, 1>(), "Create a bf16 type.");
|
||||
"Create a bf16 type.");
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -432,10 +480,10 @@ public:
|
|||
|
||||
static void bindDerived(ClassTy &c) {
|
||||
c.def(py::init([](PyMlirContext &context) {
|
||||
MlirType t = mlirF16TypeGet(context.context);
|
||||
return PyF16Type(t);
|
||||
MlirType t = mlirF16TypeGet(context.get());
|
||||
return PyF16Type(context.getRef(), t);
|
||||
}),
|
||||
py::keep_alive<0, 1>(), "Create a f16 type.");
|
||||
"Create a f16 type.");
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -448,10 +496,10 @@ public:
|
|||
|
||||
static void bindDerived(ClassTy &c) {
|
||||
c.def(py::init([](PyMlirContext &context) {
|
||||
MlirType t = mlirF32TypeGet(context.context);
|
||||
return PyF32Type(t);
|
||||
MlirType t = mlirF32TypeGet(context.get());
|
||||
return PyF32Type(context.getRef(), t);
|
||||
}),
|
||||
py::keep_alive<0, 1>(), "Create a f32 type.");
|
||||
"Create a f32 type.");
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -464,10 +512,10 @@ public:
|
|||
|
||||
static void bindDerived(ClassTy &c) {
|
||||
c.def(py::init([](PyMlirContext &context) {
|
||||
MlirType t = mlirF64TypeGet(context.context);
|
||||
return PyF64Type(t);
|
||||
MlirType t = mlirF64TypeGet(context.get());
|
||||
return PyF64Type(context.getRef(), t);
|
||||
}),
|
||||
py::keep_alive<0, 1>(), "Create a f64 type.");
|
||||
"Create a f64 type.");
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -480,10 +528,10 @@ public:
|
|||
|
||||
static void bindDerived(ClassTy &c) {
|
||||
c.def(py::init([](PyMlirContext &context) {
|
||||
MlirType t = mlirNoneTypeGet(context.context);
|
||||
return PyNoneType(t);
|
||||
MlirType t = mlirNoneTypeGet(context.get());
|
||||
return PyNoneType(context.getRef(), t);
|
||||
}),
|
||||
py::keep_alive<0, 1>(), "Create a none type.");
|
||||
"Create a none type.");
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -501,7 +549,7 @@ public:
|
|||
// The element must be a floating point or integer scalar type.
|
||||
if (mlirTypeIsAIntegerOrFloat(elementType.type)) {
|
||||
MlirType t = mlirComplexTypeGet(elementType.type);
|
||||
return PyComplexType(t);
|
||||
return PyComplexType(elementType.getContext(), t);
|
||||
}
|
||||
throw SetPyError(
|
||||
PyExc_ValueError,
|
||||
|
@ -509,12 +557,12 @@ public:
|
|||
py::repr(py::cast(elementType)).cast<std::string>() +
|
||||
"' and expected floating point or integer type.");
|
||||
},
|
||||
py::keep_alive<0, 1>(), "Create a complex type");
|
||||
"Create a complex type");
|
||||
c.def_property_readonly(
|
||||
"element_type",
|
||||
[](PyComplexType &self) -> PyType {
|
||||
MlirType t = mlirComplexTypeGetElementType(self.type);
|
||||
return PyType(t);
|
||||
return PyType(self.getContext(), t);
|
||||
},
|
||||
"Returns element type.");
|
||||
}
|
||||
|
@ -531,9 +579,9 @@ public:
|
|||
"element_type",
|
||||
[](PyShapedType &self) {
|
||||
MlirType t = mlirShapedTypeGetElementType(self.type);
|
||||
return PyType(t);
|
||||
return PyType(self.getContext(), t);
|
||||
},
|
||||
py::keep_alive<0, 1>(), "Returns the element type of the shaped type.");
|
||||
"Returns the element type of the shaped type.");
|
||||
c.def_property_readonly(
|
||||
"has_rank",
|
||||
[](PyShapedType &self) -> bool {
|
||||
|
@ -616,9 +664,9 @@ public:
|
|||
py::repr(py::cast(elementType)).cast<std::string>() +
|
||||
"' and expected floating point or integer type.");
|
||||
}
|
||||
return PyVectorType(t);
|
||||
return PyVectorType(elementType.getContext(), t);
|
||||
},
|
||||
py::keep_alive<0, 2>(), "Create a vector type");
|
||||
"Create a vector type");
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -648,9 +696,9 @@ public:
|
|||
"complex "
|
||||
"type.");
|
||||
}
|
||||
return PyRankedTensorType(t);
|
||||
return PyRankedTensorType(elementType.getContext(), t);
|
||||
},
|
||||
py::keep_alive<0, 2>(), "Create a ranked tensor type");
|
||||
"Create a ranked tensor type");
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -680,9 +728,9 @@ public:
|
|||
"complex "
|
||||
"type.");
|
||||
}
|
||||
return PyUnrankedTensorType(t);
|
||||
return PyUnrankedTensorType(elementType.getContext(), t);
|
||||
},
|
||||
py::keep_alive<0, 1>(), "Create a unranked tensor type");
|
||||
"Create a unranked tensor type");
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -715,9 +763,9 @@ public:
|
|||
"complex "
|
||||
"type.");
|
||||
}
|
||||
return PyMemRefType(t);
|
||||
return PyMemRefType(elementType.getContext(), t);
|
||||
},
|
||||
py::keep_alive<0, 1>(), "Create a memref type")
|
||||
"Create a memref type")
|
||||
.def_property_readonly(
|
||||
"num_affine_maps",
|
||||
[](PyMemRefType &self) -> intptr_t {
|
||||
|
@ -760,9 +808,9 @@ public:
|
|||
"complex "
|
||||
"type.");
|
||||
}
|
||||
return PyUnrankedMemRefType(t);
|
||||
return PyUnrankedMemRefType(elementType.getContext(), t);
|
||||
},
|
||||
py::keep_alive<0, 1>(), "Create a unranked memref type")
|
||||
"Create a unranked memref type")
|
||||
.def_property_readonly(
|
||||
"memory_space",
|
||||
[](PyUnrankedMemRefType &self) -> unsigned {
|
||||
|
@ -788,17 +836,17 @@ public:
|
|||
SmallVector<MlirType, 4> elements;
|
||||
for (auto element : elementList)
|
||||
elements.push_back(element.cast<PyType>().type);
|
||||
MlirType t = mlirTupleTypeGet(context.context, num, elements.data());
|
||||
return PyTupleType(t);
|
||||
MlirType t = mlirTupleTypeGet(context.get(), num, elements.data());
|
||||
return PyTupleType(context.getRef(), t);
|
||||
},
|
||||
py::keep_alive<0, 1>(), "Create a tuple type");
|
||||
"Create a tuple type");
|
||||
c.def(
|
||||
"get_type",
|
||||
[](PyTupleType &self, intptr_t pos) -> PyType {
|
||||
MlirType t = mlirTupleTypeGetType(self.type, pos);
|
||||
return PyType(t);
|
||||
return PyType(self.getContext(), t);
|
||||
},
|
||||
py::keep_alive<0, 1>(), "Returns the pos-th type in the tuple type.");
|
||||
"Returns the pos-th type in the tuple type.");
|
||||
c.def_property_readonly(
|
||||
"num_types",
|
||||
[](PyTupleType &self) -> intptr_t {
|
||||
|
@ -817,12 +865,21 @@ public:
|
|||
void mlir::python::populateIRSubmodule(py::module &m) {
|
||||
// Mapping of MlirContext
|
||||
py::class_<PyMlirContext>(m, "Context")
|
||||
.def(py::init<>())
|
||||
.def(py::init<>([]() {
|
||||
MlirContext context = mlirContextCreate();
|
||||
auto contextRef = PyMlirContext::forContext(context);
|
||||
return contextRef.release();
|
||||
}))
|
||||
.def_static("_get_live_count", &PyMlirContext::getLiveCount)
|
||||
.def("_get_context_again",
|
||||
[](PyMlirContext &self) {
|
||||
auto ref = PyMlirContext::forContext(self.get());
|
||||
return ref.release();
|
||||
})
|
||||
.def(
|
||||
"parse_module",
|
||||
[](PyMlirContext &self, const std::string module) {
|
||||
auto moduleRef =
|
||||
mlirModuleCreateParse(self.context, module.c_str());
|
||||
auto moduleRef = mlirModuleCreateParse(self.get(), module.c_str());
|
||||
// TODO: Rework error reporting once diagnostic engine is exposed
|
||||
// in C API.
|
||||
if (mlirModuleIsNull(moduleRef)) {
|
||||
|
@ -830,14 +887,14 @@ void mlir::python::populateIRSubmodule(py::module &m) {
|
|||
PyExc_ValueError,
|
||||
"Unable to parse module assembly (see diagnostics)");
|
||||
}
|
||||
return PyModule(moduleRef);
|
||||
return PyModule(self.getRef(), moduleRef);
|
||||
},
|
||||
py::keep_alive<0, 1>(), kContextParseDocstring)
|
||||
kContextParseDocstring)
|
||||
.def(
|
||||
"parse_attr",
|
||||
[](PyMlirContext &self, std::string attrSpec) {
|
||||
MlirAttribute type =
|
||||
mlirAttributeParseGet(self.context, attrSpec.c_str());
|
||||
mlirAttributeParseGet(self.get(), attrSpec.c_str());
|
||||
// TODO: Rework error reporting once diagnostic engine is exposed
|
||||
// in C API.
|
||||
if (mlirAttributeIsNull(type)) {
|
||||
|
@ -845,13 +902,13 @@ void mlir::python::populateIRSubmodule(py::module &m) {
|
|||
llvm::Twine("Unable to parse attribute: '") +
|
||||
attrSpec + "'");
|
||||
}
|
||||
return PyAttribute(type);
|
||||
return PyAttribute(self.getRef(), type);
|
||||
},
|
||||
py::keep_alive<0, 1>())
|
||||
.def(
|
||||
"parse_type",
|
||||
[](PyMlirContext &self, std::string typeSpec) {
|
||||
MlirType type = mlirTypeParseGet(self.context, typeSpec.c_str());
|
||||
MlirType type = mlirTypeParseGet(self.get(), typeSpec.c_str());
|
||||
// TODO: Rework error reporting once diagnostic engine is exposed
|
||||
// in C API.
|
||||
if (mlirTypeIsNull(type)) {
|
||||
|
@ -859,30 +916,32 @@ void mlir::python::populateIRSubmodule(py::module &m) {
|
|||
llvm::Twine("Unable to parse type: '") +
|
||||
typeSpec + "'");
|
||||
}
|
||||
return PyType(type);
|
||||
return PyType(self.getRef(), type);
|
||||
},
|
||||
py::keep_alive<0, 1>(), kContextParseTypeDocstring)
|
||||
kContextParseTypeDocstring)
|
||||
.def(
|
||||
"get_unknown_location",
|
||||
[](PyMlirContext &self) {
|
||||
return PyLocation(mlirLocationUnknownGet(self.context));
|
||||
return PyLocation(self.getRef(),
|
||||
mlirLocationUnknownGet(self.get()));
|
||||
},
|
||||
py::keep_alive<0, 1>(), kContextGetUnknownLocationDocstring)
|
||||
kContextGetUnknownLocationDocstring)
|
||||
.def(
|
||||
"get_file_location",
|
||||
[](PyMlirContext &self, std::string filename, int line, int col) {
|
||||
return PyLocation(mlirLocationFileLineColGet(
|
||||
self.context, filename.c_str(), line, col));
|
||||
return PyLocation(self.getRef(),
|
||||
mlirLocationFileLineColGet(
|
||||
self.get(), filename.c_str(), line, col));
|
||||
},
|
||||
py::keep_alive<0, 1>(), kContextGetFileLocationDocstring,
|
||||
py::arg("filename"), py::arg("line"), py::arg("col"))
|
||||
kContextGetFileLocationDocstring, py::arg("filename"),
|
||||
py::arg("line"), py::arg("col"))
|
||||
.def(
|
||||
"create_region",
|
||||
[](PyMlirContext &self) {
|
||||
// The creating context is explicitly captured on regions to
|
||||
// facilitate illegal assemblies of objects from multiple contexts
|
||||
// that would invalidate the memory model.
|
||||
return PyRegion(self.context, mlirRegionCreate(),
|
||||
return PyRegion(self.get(), mlirRegionCreate(),
|
||||
/*detached=*/true);
|
||||
},
|
||||
py::keep_alive<0, 1>(), kContextCreateRegionDocstring)
|
||||
|
@ -893,7 +952,7 @@ void mlir::python::populateIRSubmodule(py::module &m) {
|
|||
// types must be from the same context.
|
||||
for (auto pyType : pyTypes) {
|
||||
if (!mlirContextEqual(mlirTypeGetContext(pyType.type),
|
||||
self.context)) {
|
||||
self.get())) {
|
||||
throw SetPyError(
|
||||
PyExc_ValueError,
|
||||
"All types used to construct a block must be from "
|
||||
|
@ -902,8 +961,7 @@ void mlir::python::populateIRSubmodule(py::module &m) {
|
|||
}
|
||||
llvm::SmallVector<MlirType, 4> types(pyTypes.begin(),
|
||||
pyTypes.end());
|
||||
return PyBlock(self.context,
|
||||
mlirBlockCreate(types.size(), &types[0]),
|
||||
return PyBlock(self.get(), mlirBlockCreate(types.size(), &types[0]),
|
||||
/*detached=*/true);
|
||||
},
|
||||
py::keep_alive<0, 1>(), kContextCreateBlockDocstring);
|
||||
|
@ -1063,7 +1121,11 @@ void mlir::python::populateIRSubmodule(py::module &m) {
|
|||
.def_property_readonly(
|
||||
"attr",
|
||||
[](PyNamedAttribute &self) {
|
||||
return PyAttribute(self.namedAttr.attribute);
|
||||
// TODO: When named attribute is removed/refactored, also remove
|
||||
// this constructor (it does an inefficient table lookup).
|
||||
auto contextRef = PyMlirContext::forContext(
|
||||
mlirAttributeGetContext(self.namedAttr.attribute));
|
||||
return PyAttribute(std::move(contextRef), self.namedAttr.attribute);
|
||||
},
|
||||
py::keep_alive<0, 1>(),
|
||||
"The underlying generic attribute of the NamedAttribute binding");
|
||||
|
|
|
@ -12,6 +12,7 @@
|
|||
#include <pybind11/pybind11.h>
|
||||
|
||||
#include "mlir-c/IR.h"
|
||||
#include "llvm/ADT/DenseMap.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace python {
|
||||
|
@ -19,28 +20,105 @@ namespace python {
|
|||
class PyMlirContext;
|
||||
class PyModule;
|
||||
|
||||
/// Holds a C++ PyMlirContext and associated py::object, making it convenient
|
||||
/// to have an auto-releasing C++-side keep-alive reference to the context.
|
||||
/// The reference to the PyMlirContext is a simple C++ reference and the
|
||||
/// py::object holds the reference count which keeps it alive.
|
||||
class PyMlirContextRef {
|
||||
public:
|
||||
PyMlirContextRef(PyMlirContext &referrent, pybind11::object object)
|
||||
: referrent(referrent), object(std::move(object)) {}
|
||||
~PyMlirContextRef() {}
|
||||
|
||||
/// Releases the object held by this instance, causing its reference count
|
||||
/// to remain artifically inflated by one. This must be used to return
|
||||
/// the referenced PyMlirContext from a function. Otherwise, the destructor
|
||||
/// of this reference would be called prior to the default take_ownership
|
||||
/// policy assuming that the reference count has been transferred to it.
|
||||
PyMlirContext *release();
|
||||
|
||||
PyMlirContext &operator->() { return referrent; }
|
||||
pybind11::object getObject() { return object; }
|
||||
|
||||
private:
|
||||
PyMlirContext &referrent;
|
||||
pybind11::object object;
|
||||
};
|
||||
|
||||
/// Wrapper around MlirContext.
|
||||
class PyMlirContext {
|
||||
public:
|
||||
PyMlirContext() { context = mlirContextCreate(); }
|
||||
~PyMlirContext() { mlirContextDestroy(context); }
|
||||
PyMlirContext() = delete;
|
||||
PyMlirContext(const PyMlirContext &) = delete;
|
||||
PyMlirContext(PyMlirContext &&) = delete;
|
||||
|
||||
/// Returns a context reference for the singleton PyMlirContext wrapper for
|
||||
/// the given context.
|
||||
static PyMlirContextRef forContext(MlirContext context);
|
||||
~PyMlirContext();
|
||||
|
||||
/// Accesses the underlying MlirContext.
|
||||
MlirContext get() { return context; }
|
||||
|
||||
/// Gets a strong reference to this context, which will ensure it is kept
|
||||
/// alive for the life of the reference.
|
||||
PyMlirContextRef getRef() {
|
||||
return PyMlirContextRef(
|
||||
*this, pybind11::reinterpret_borrow<pybind11::object>(handle));
|
||||
}
|
||||
|
||||
/// Gets the count of live context objects. Used for testing.
|
||||
static size_t getLiveCount();
|
||||
|
||||
private:
|
||||
PyMlirContext(MlirContext context);
|
||||
|
||||
// Interns the mapping of live MlirContext::ptr to PyMlirContext instances,
|
||||
// preserving the relationship that an MlirContext maps to a single
|
||||
// PyMlirContext wrapper. This could be replaced in the future with an
|
||||
// extension mechanism on the MlirContext for stashing user pointers.
|
||||
// Note that this holds a handle, which does not imply ownership.
|
||||
// Mappings will be removed when the context is destructed.
|
||||
using LiveContextMap =
|
||||
llvm::DenseMap<void *, std::pair<pybind11::handle, PyMlirContext *>>;
|
||||
static LiveContextMap &getLiveContexts();
|
||||
|
||||
MlirContext context;
|
||||
// The handle is set as part of lookup with forContext() (post construction).
|
||||
pybind11::handle handle;
|
||||
};
|
||||
|
||||
/// Base class for all objects that directly or indirectly depend on an
|
||||
/// MlirContext. The lifetime of the context will extend at least to the
|
||||
/// lifetime of these instances.
|
||||
/// Immutable objects that depend on a context extend this directly.
|
||||
class BaseContextObject {
|
||||
public:
|
||||
BaseContextObject(PyMlirContextRef ref) : contextRef(std::move(ref)) {}
|
||||
|
||||
/// Accesses the context reference.
|
||||
PyMlirContextRef &getContext() { return contextRef; }
|
||||
|
||||
private:
|
||||
PyMlirContextRef contextRef;
|
||||
};
|
||||
|
||||
/// Wrapper around an MlirLocation.
|
||||
class PyLocation {
|
||||
class PyLocation : public BaseContextObject {
|
||||
public:
|
||||
PyLocation(MlirLocation loc) : loc(loc) {}
|
||||
PyLocation(PyMlirContextRef contextRef, MlirLocation loc)
|
||||
: BaseContextObject(std::move(contextRef)), loc(loc) {}
|
||||
MlirLocation loc;
|
||||
};
|
||||
|
||||
/// Wrapper around MlirModule.
|
||||
class PyModule {
|
||||
class PyModule : public BaseContextObject {
|
||||
public:
|
||||
PyModule(MlirModule module) : module(module) {}
|
||||
PyModule(PyMlirContextRef contextRef, MlirModule module)
|
||||
: BaseContextObject(std::move(contextRef)), module(module) {}
|
||||
PyModule(PyModule &) = delete;
|
||||
PyModule(PyModule &&other) {
|
||||
PyModule(PyModule &&other)
|
||||
: BaseContextObject(std::move(other.getContext())) {
|
||||
module = other.module;
|
||||
other.module.ptr = nullptr;
|
||||
}
|
||||
|
@ -120,9 +198,10 @@ private:
|
|||
|
||||
/// Wrapper around the generic MlirAttribute.
|
||||
/// The lifetime of a type is bound by the PyContext that created it.
|
||||
class PyAttribute {
|
||||
class PyAttribute : public BaseContextObject {
|
||||
public:
|
||||
PyAttribute(MlirAttribute attr) : attr(attr) {}
|
||||
PyAttribute(PyMlirContextRef contextRef, MlirAttribute attr)
|
||||
: BaseContextObject(std::move(contextRef)), attr(attr) {}
|
||||
bool operator==(const PyAttribute &other);
|
||||
|
||||
MlirAttribute attr;
|
||||
|
@ -153,9 +232,10 @@ private:
|
|||
|
||||
/// Wrapper around the generic MlirType.
|
||||
/// The lifetime of a type is bound by the PyContext that created it.
|
||||
class PyType {
|
||||
class PyType : public BaseContextObject {
|
||||
public:
|
||||
PyType(MlirType type) : type(type) {}
|
||||
PyType(PyMlirContextRef contextRef, MlirType type)
|
||||
: BaseContextObject(std::move(contextRef)), type(type) {}
|
||||
bool operator==(const PyType &other);
|
||||
operator MlirType() const { return type; }
|
||||
|
||||
|
|
|
@ -48,6 +48,10 @@ MlirLocation mlirLocationUnknownGet(MlirContext context) {
|
|||
return wrap(UnknownLoc::get(unwrap(context)));
|
||||
}
|
||||
|
||||
MlirContext mlirLocationGetContext(MlirLocation location) {
|
||||
return wrap(unwrap(location).getContext());
|
||||
}
|
||||
|
||||
void mlirLocationPrint(MlirLocation location, MlirStringCallback callback,
|
||||
void *userData) {
|
||||
detail::CallbackOstream stream(callback, userData);
|
||||
|
@ -70,6 +74,10 @@ MlirModule mlirModuleCreateParse(MlirContext context, const char *module) {
|
|||
return MlirModule{owning.release().getOperation()};
|
||||
}
|
||||
|
||||
MlirContext mlirModuleGetContext(MlirModule module) {
|
||||
return wrap(unwrap(module).getContext());
|
||||
}
|
||||
|
||||
void mlirModuleDestroy(MlirModule module) {
|
||||
// Transfer ownership to an OwningModuleRef so that its destructor is called.
|
||||
OwningModuleRef(unwrap(module));
|
||||
|
@ -349,6 +357,10 @@ MlirAttribute mlirAttributeParseGet(MlirContext context, const char *attr) {
|
|||
return wrap(mlir::parseAttribute(attr, unwrap(context)));
|
||||
}
|
||||
|
||||
MlirContext mlirAttributeGetContext(MlirAttribute attribute) {
|
||||
return wrap(unwrap(attribute).getContext());
|
||||
}
|
||||
|
||||
int mlirAttributeEqual(MlirAttribute a1, MlirAttribute a2) {
|
||||
return unwrap(a1) == unwrap(a2);
|
||||
}
|
||||
|
|
|
@ -0,0 +1,42 @@
|
|||
# RUN: %PYTHON %s
|
||||
# Standalone sanity check of context life-cycle.
|
||||
import gc
|
||||
import mlir
|
||||
|
||||
assert mlir.ir.Context._get_live_count() == 0
|
||||
|
||||
# Create first context.
|
||||
print("CREATE C1")
|
||||
c1 = mlir.ir.Context()
|
||||
assert mlir.ir.Context._get_live_count() == 1
|
||||
c1_repr = repr(c1)
|
||||
print("C1 = ", c1_repr)
|
||||
|
||||
print("GETTING AGAIN...")
|
||||
c2 = c1._get_context_again()
|
||||
c2_repr = repr(c2)
|
||||
assert mlir.ir.Context._get_live_count() == 1
|
||||
assert c1_repr == c2_repr
|
||||
|
||||
print("C2 =", c2)
|
||||
|
||||
# Make sure new contexts on constructor.
|
||||
print("CREATE C3")
|
||||
c3 = mlir.ir.Context()
|
||||
assert mlir.ir.Context._get_live_count() == 2
|
||||
c3_repr = repr(c3)
|
||||
print("C3 =", c3)
|
||||
assert c3_repr != c1_repr
|
||||
print("FREE C3")
|
||||
c3 = None
|
||||
gc.collect()
|
||||
assert mlir.ir.Context._get_live_count() == 1
|
||||
|
||||
print("Free C1")
|
||||
c1 = None
|
||||
gc.collect()
|
||||
assert mlir.ir.Context._get_live_count() == 1
|
||||
print("Free C2")
|
||||
c2 = None
|
||||
gc.collect()
|
||||
assert mlir.ir.Context._get_live_count() == 0
|
Loading…
Reference in New Issue