First pass on MLIR python context lifetime management.

* Per thread https://llvm.discourse.group/t/revisiting-ownership-and-lifetime-in-the-python-bindings/1769
* Reworks contexts so it is always possible to get back to a py::object that holds the reference count for an arbitrary MlirContext.
* Retrofits some of the base classes to automatically take a reference to the context, elimintating keep_alives.
* More needs to be done, as discussed, when moving on to the operations/blocks/regions.

Differential Revision: https://reviews.llvm.org/D87886
This commit is contained in:
Stella Laurenzo 2020-09-18 00:21:09 -07:00
parent 7bd75b6301
commit 85185b61b6
5 changed files with 293 additions and 88 deletions

View File

@ -103,6 +103,9 @@ MlirLocation mlirLocationFileLineColGet(MlirContext context,
/** Creates a location with unknown position owned by the given context. */
MlirLocation mlirLocationUnknownGet(MlirContext context);
/** Gets the context that a location was created with. */
MlirContext mlirLocationGetContext(MlirLocation location);
/** Prints a location by sending chunks of the string representation and
* forwarding `userData to `callback`. Note that the callback may be called
* several times with consecutive chunks of the string. */
@ -119,6 +122,9 @@ MlirModule mlirModuleCreateEmpty(MlirLocation location);
/** Parses a module from the string and transfers ownership to the caller. */
MlirModule mlirModuleCreateParse(MlirContext context, const char *module);
/** Gets the context that a module was created with. */
MlirContext mlirModuleGetContext(MlirModule module);
/** Checks whether a module is null. */
inline int mlirModuleIsNull(MlirModule module) { return !module.ptr; }
@ -342,6 +348,9 @@ void mlirTypeDump(MlirType type);
/** Parses an attribute. The attribute is owned by the context. */
MlirAttribute mlirAttributeParseGet(MlirContext context, const char *attr);
/** Gets the context that an attribute was created with. */
MlirContext mlirAttributeGetContext(MlirAttribute attribute);
/** Checks whether an attribute is null. */
inline int mlirAttributeIsNull(MlirAttribute attr) { return !attr.ptr; }

View File

@ -170,6 +170,51 @@ int mlirTypeIsAIntegerOrFloat(MlirType type) {
} // namespace
//------------------------------------------------------------------------------
// PyMlirContext
//------------------------------------------------------------------------------
PyMlirContext *PyMlirContextRef::release() {
object.release();
return &referrent;
}
PyMlirContext::PyMlirContext(MlirContext context) : context(context) {}
PyMlirContext::~PyMlirContext() {
// Note that the only public way to construct an instance is via the
// forContext method, which always puts the associated handle into
// liveContexts.
py::gil_scoped_acquire acquire;
getLiveContexts().erase(context.ptr);
mlirContextDestroy(context);
}
PyMlirContextRef PyMlirContext::forContext(MlirContext context) {
py::gil_scoped_acquire acquire;
auto &liveContexts = getLiveContexts();
auto it = liveContexts.find(context.ptr);
if (it == liveContexts.end()) {
// Create.
PyMlirContext *unownedContextWrapper = new PyMlirContext(context);
py::object pyRef = py::cast(unownedContextWrapper);
unownedContextWrapper->handle = pyRef;
liveContexts[context.ptr] = std::make_pair(pyRef, unownedContextWrapper);
return PyMlirContextRef(*unownedContextWrapper, std::move(pyRef));
} else {
// Use existing.
py::object pyRef = py::reinterpret_borrow<py::object>(it->second.first);
return PyMlirContextRef(*it->second.second, std::move(pyRef));
}
}
PyMlirContext::LiveContextMap &PyMlirContext::getLiveContexts() {
static LiveContextMap liveContexts;
return liveContexts;
}
size_t PyMlirContext::getLiveCount() { return getLiveContexts().size(); }
//------------------------------------------------------------------------------
// PyBlock, PyRegion, and PyOperation.
//------------------------------------------------------------------------------
@ -234,9 +279,10 @@ public:
using IsAFunctionTy = int (*)(MlirAttribute);
PyConcreteAttribute() = default;
PyConcreteAttribute(MlirAttribute attr) : BaseTy(attr) {}
PyConcreteAttribute(PyMlirContextRef contextRef, MlirAttribute attr)
: BaseTy(std::move(contextRef), attr) {}
PyConcreteAttribute(PyAttribute &orig)
: PyConcreteAttribute(castFrom(orig)) {}
: PyConcreteAttribute(orig.getContext(), castFrom(orig)) {}
static MlirAttribute castFrom(PyAttribute &orig) {
if (!DerivedTy::isaFunction(orig.attr)) {
@ -269,18 +315,18 @@ public:
"get",
[](PyMlirContext &context, std::string value) {
MlirAttribute attr =
mlirStringAttrGet(context.context, value.size(), &value[0]);
return PyStringAttribute(attr);
mlirStringAttrGet(context.get(), value.size(), &value[0]);
return PyStringAttribute(context.getRef(), attr);
},
py::keep_alive<0, 1>(), "Gets a uniqued string attribute");
"Gets a uniqued string attribute");
c.def_static(
"get_typed",
[](PyType &type, std::string value) {
MlirAttribute attr =
mlirStringAttrTypedGet(type.type, value.size(), &value[0]);
return PyStringAttribute(attr);
return PyStringAttribute(type.getContext(), attr);
},
py::keep_alive<0, 1>(),
"Gets a uniqued string attribute associated to a type");
c.def_property_readonly(
"value",
@ -315,8 +361,10 @@ public:
using IsAFunctionTy = int (*)(MlirType);
PyConcreteType() = default;
PyConcreteType(MlirType t) : BaseTy(t) {}
PyConcreteType(PyType &orig) : PyConcreteType(castFrom(orig)) {}
PyConcreteType(PyMlirContextRef contextRef, MlirType t)
: BaseTy(std::move(contextRef), t) {}
PyConcreteType(PyType &orig)
: PyConcreteType(orig.getContext(), castFrom(orig)) {}
static MlirType castFrom(PyType &orig) {
if (!DerivedTy::isaFunction(orig.type)) {
@ -348,24 +396,24 @@ public:
c.def_static(
"get_signless",
[](PyMlirContext &context, unsigned width) {
MlirType t = mlirIntegerTypeGet(context.context, width);
return PyIntegerType(t);
MlirType t = mlirIntegerTypeGet(context.get(), width);
return PyIntegerType(context.getRef(), t);
},
py::keep_alive<0, 1>(), "Create a signless integer type");
"Create a signless integer type");
c.def_static(
"get_signed",
[](PyMlirContext &context, unsigned width) {
MlirType t = mlirIntegerTypeSignedGet(context.context, width);
return PyIntegerType(t);
MlirType t = mlirIntegerTypeSignedGet(context.get(), width);
return PyIntegerType(context.getRef(), t);
},
py::keep_alive<0, 1>(), "Create a signed integer type");
"Create a signed integer type");
c.def_static(
"get_unsigned",
[](PyMlirContext &context, unsigned width) {
MlirType t = mlirIntegerTypeUnsignedGet(context.context, width);
return PyIntegerType(t);
MlirType t = mlirIntegerTypeUnsignedGet(context.get(), width);
return PyIntegerType(context.getRef(), t);
},
py::keep_alive<0, 1>(), "Create an unsigned integer type");
"Create an unsigned integer type");
c.def_property_readonly(
"width",
[](PyIntegerType &self) { return mlirIntegerTypeGetWidth(self.type); },
@ -400,10 +448,10 @@ public:
static void bindDerived(ClassTy &c) {
c.def(py::init([](PyMlirContext &context) {
MlirType t = mlirIndexTypeGet(context.context);
return PyIndexType(t);
MlirType t = mlirIndexTypeGet(context.get());
return PyIndexType(context.getRef(), t);
}),
py::keep_alive<0, 1>(), "Create a index type.");
"Create a index type.");
}
};
@ -416,10 +464,10 @@ public:
static void bindDerived(ClassTy &c) {
c.def(py::init([](PyMlirContext &context) {
MlirType t = mlirBF16TypeGet(context.context);
return PyBF16Type(t);
MlirType t = mlirBF16TypeGet(context.get());
return PyBF16Type(context.getRef(), t);
}),
py::keep_alive<0, 1>(), "Create a bf16 type.");
"Create a bf16 type.");
}
};
@ -432,10 +480,10 @@ public:
static void bindDerived(ClassTy &c) {
c.def(py::init([](PyMlirContext &context) {
MlirType t = mlirF16TypeGet(context.context);
return PyF16Type(t);
MlirType t = mlirF16TypeGet(context.get());
return PyF16Type(context.getRef(), t);
}),
py::keep_alive<0, 1>(), "Create a f16 type.");
"Create a f16 type.");
}
};
@ -448,10 +496,10 @@ public:
static void bindDerived(ClassTy &c) {
c.def(py::init([](PyMlirContext &context) {
MlirType t = mlirF32TypeGet(context.context);
return PyF32Type(t);
MlirType t = mlirF32TypeGet(context.get());
return PyF32Type(context.getRef(), t);
}),
py::keep_alive<0, 1>(), "Create a f32 type.");
"Create a f32 type.");
}
};
@ -464,10 +512,10 @@ public:
static void bindDerived(ClassTy &c) {
c.def(py::init([](PyMlirContext &context) {
MlirType t = mlirF64TypeGet(context.context);
return PyF64Type(t);
MlirType t = mlirF64TypeGet(context.get());
return PyF64Type(context.getRef(), t);
}),
py::keep_alive<0, 1>(), "Create a f64 type.");
"Create a f64 type.");
}
};
@ -480,10 +528,10 @@ public:
static void bindDerived(ClassTy &c) {
c.def(py::init([](PyMlirContext &context) {
MlirType t = mlirNoneTypeGet(context.context);
return PyNoneType(t);
MlirType t = mlirNoneTypeGet(context.get());
return PyNoneType(context.getRef(), t);
}),
py::keep_alive<0, 1>(), "Create a none type.");
"Create a none type.");
}
};
@ -501,7 +549,7 @@ public:
// The element must be a floating point or integer scalar type.
if (mlirTypeIsAIntegerOrFloat(elementType.type)) {
MlirType t = mlirComplexTypeGet(elementType.type);
return PyComplexType(t);
return PyComplexType(elementType.getContext(), t);
}
throw SetPyError(
PyExc_ValueError,
@ -509,12 +557,12 @@ public:
py::repr(py::cast(elementType)).cast<std::string>() +
"' and expected floating point or integer type.");
},
py::keep_alive<0, 1>(), "Create a complex type");
"Create a complex type");
c.def_property_readonly(
"element_type",
[](PyComplexType &self) -> PyType {
MlirType t = mlirComplexTypeGetElementType(self.type);
return PyType(t);
return PyType(self.getContext(), t);
},
"Returns element type.");
}
@ -531,9 +579,9 @@ public:
"element_type",
[](PyShapedType &self) {
MlirType t = mlirShapedTypeGetElementType(self.type);
return PyType(t);
return PyType(self.getContext(), t);
},
py::keep_alive<0, 1>(), "Returns the element type of the shaped type.");
"Returns the element type of the shaped type.");
c.def_property_readonly(
"has_rank",
[](PyShapedType &self) -> bool {
@ -616,9 +664,9 @@ public:
py::repr(py::cast(elementType)).cast<std::string>() +
"' and expected floating point or integer type.");
}
return PyVectorType(t);
return PyVectorType(elementType.getContext(), t);
},
py::keep_alive<0, 2>(), "Create a vector type");
"Create a vector type");
}
};
@ -648,9 +696,9 @@ public:
"complex "
"type.");
}
return PyRankedTensorType(t);
return PyRankedTensorType(elementType.getContext(), t);
},
py::keep_alive<0, 2>(), "Create a ranked tensor type");
"Create a ranked tensor type");
}
};
@ -680,9 +728,9 @@ public:
"complex "
"type.");
}
return PyUnrankedTensorType(t);
return PyUnrankedTensorType(elementType.getContext(), t);
},
py::keep_alive<0, 1>(), "Create a unranked tensor type");
"Create a unranked tensor type");
}
};
@ -715,9 +763,9 @@ public:
"complex "
"type.");
}
return PyMemRefType(t);
return PyMemRefType(elementType.getContext(), t);
},
py::keep_alive<0, 1>(), "Create a memref type")
"Create a memref type")
.def_property_readonly(
"num_affine_maps",
[](PyMemRefType &self) -> intptr_t {
@ -760,9 +808,9 @@ public:
"complex "
"type.");
}
return PyUnrankedMemRefType(t);
return PyUnrankedMemRefType(elementType.getContext(), t);
},
py::keep_alive<0, 1>(), "Create a unranked memref type")
"Create a unranked memref type")
.def_property_readonly(
"memory_space",
[](PyUnrankedMemRefType &self) -> unsigned {
@ -788,17 +836,17 @@ public:
SmallVector<MlirType, 4> elements;
for (auto element : elementList)
elements.push_back(element.cast<PyType>().type);
MlirType t = mlirTupleTypeGet(context.context, num, elements.data());
return PyTupleType(t);
MlirType t = mlirTupleTypeGet(context.get(), num, elements.data());
return PyTupleType(context.getRef(), t);
},
py::keep_alive<0, 1>(), "Create a tuple type");
"Create a tuple type");
c.def(
"get_type",
[](PyTupleType &self, intptr_t pos) -> PyType {
MlirType t = mlirTupleTypeGetType(self.type, pos);
return PyType(t);
return PyType(self.getContext(), t);
},
py::keep_alive<0, 1>(), "Returns the pos-th type in the tuple type.");
"Returns the pos-th type in the tuple type.");
c.def_property_readonly(
"num_types",
[](PyTupleType &self) -> intptr_t {
@ -817,12 +865,21 @@ public:
void mlir::python::populateIRSubmodule(py::module &m) {
// Mapping of MlirContext
py::class_<PyMlirContext>(m, "Context")
.def(py::init<>())
.def(py::init<>([]() {
MlirContext context = mlirContextCreate();
auto contextRef = PyMlirContext::forContext(context);
return contextRef.release();
}))
.def_static("_get_live_count", &PyMlirContext::getLiveCount)
.def("_get_context_again",
[](PyMlirContext &self) {
auto ref = PyMlirContext::forContext(self.get());
return ref.release();
})
.def(
"parse_module",
[](PyMlirContext &self, const std::string module) {
auto moduleRef =
mlirModuleCreateParse(self.context, module.c_str());
auto moduleRef = mlirModuleCreateParse(self.get(), module.c_str());
// TODO: Rework error reporting once diagnostic engine is exposed
// in C API.
if (mlirModuleIsNull(moduleRef)) {
@ -830,14 +887,14 @@ void mlir::python::populateIRSubmodule(py::module &m) {
PyExc_ValueError,
"Unable to parse module assembly (see diagnostics)");
}
return PyModule(moduleRef);
return PyModule(self.getRef(), moduleRef);
},
py::keep_alive<0, 1>(), kContextParseDocstring)
kContextParseDocstring)
.def(
"parse_attr",
[](PyMlirContext &self, std::string attrSpec) {
MlirAttribute type =
mlirAttributeParseGet(self.context, attrSpec.c_str());
mlirAttributeParseGet(self.get(), attrSpec.c_str());
// TODO: Rework error reporting once diagnostic engine is exposed
// in C API.
if (mlirAttributeIsNull(type)) {
@ -845,13 +902,13 @@ void mlir::python::populateIRSubmodule(py::module &m) {
llvm::Twine("Unable to parse attribute: '") +
attrSpec + "'");
}
return PyAttribute(type);
return PyAttribute(self.getRef(), type);
},
py::keep_alive<0, 1>())
.def(
"parse_type",
[](PyMlirContext &self, std::string typeSpec) {
MlirType type = mlirTypeParseGet(self.context, typeSpec.c_str());
MlirType type = mlirTypeParseGet(self.get(), typeSpec.c_str());
// TODO: Rework error reporting once diagnostic engine is exposed
// in C API.
if (mlirTypeIsNull(type)) {
@ -859,30 +916,32 @@ void mlir::python::populateIRSubmodule(py::module &m) {
llvm::Twine("Unable to parse type: '") +
typeSpec + "'");
}
return PyType(type);
return PyType(self.getRef(), type);
},
py::keep_alive<0, 1>(), kContextParseTypeDocstring)
kContextParseTypeDocstring)
.def(
"get_unknown_location",
[](PyMlirContext &self) {
return PyLocation(mlirLocationUnknownGet(self.context));
return PyLocation(self.getRef(),
mlirLocationUnknownGet(self.get()));
},
py::keep_alive<0, 1>(), kContextGetUnknownLocationDocstring)
kContextGetUnknownLocationDocstring)
.def(
"get_file_location",
[](PyMlirContext &self, std::string filename, int line, int col) {
return PyLocation(mlirLocationFileLineColGet(
self.context, filename.c_str(), line, col));
return PyLocation(self.getRef(),
mlirLocationFileLineColGet(
self.get(), filename.c_str(), line, col));
},
py::keep_alive<0, 1>(), kContextGetFileLocationDocstring,
py::arg("filename"), py::arg("line"), py::arg("col"))
kContextGetFileLocationDocstring, py::arg("filename"),
py::arg("line"), py::arg("col"))
.def(
"create_region",
[](PyMlirContext &self) {
// The creating context is explicitly captured on regions to
// facilitate illegal assemblies of objects from multiple contexts
// that would invalidate the memory model.
return PyRegion(self.context, mlirRegionCreate(),
return PyRegion(self.get(), mlirRegionCreate(),
/*detached=*/true);
},
py::keep_alive<0, 1>(), kContextCreateRegionDocstring)
@ -893,7 +952,7 @@ void mlir::python::populateIRSubmodule(py::module &m) {
// types must be from the same context.
for (auto pyType : pyTypes) {
if (!mlirContextEqual(mlirTypeGetContext(pyType.type),
self.context)) {
self.get())) {
throw SetPyError(
PyExc_ValueError,
"All types used to construct a block must be from "
@ -902,8 +961,7 @@ void mlir::python::populateIRSubmodule(py::module &m) {
}
llvm::SmallVector<MlirType, 4> types(pyTypes.begin(),
pyTypes.end());
return PyBlock(self.context,
mlirBlockCreate(types.size(), &types[0]),
return PyBlock(self.get(), mlirBlockCreate(types.size(), &types[0]),
/*detached=*/true);
},
py::keep_alive<0, 1>(), kContextCreateBlockDocstring);
@ -1063,7 +1121,11 @@ void mlir::python::populateIRSubmodule(py::module &m) {
.def_property_readonly(
"attr",
[](PyNamedAttribute &self) {
return PyAttribute(self.namedAttr.attribute);
// TODO: When named attribute is removed/refactored, also remove
// this constructor (it does an inefficient table lookup).
auto contextRef = PyMlirContext::forContext(
mlirAttributeGetContext(self.namedAttr.attribute));
return PyAttribute(std::move(contextRef), self.namedAttr.attribute);
},
py::keep_alive<0, 1>(),
"The underlying generic attribute of the NamedAttribute binding");

View File

@ -12,6 +12,7 @@
#include <pybind11/pybind11.h>
#include "mlir-c/IR.h"
#include "llvm/ADT/DenseMap.h"
namespace mlir {
namespace python {
@ -19,28 +20,105 @@ namespace python {
class PyMlirContext;
class PyModule;
/// Holds a C++ PyMlirContext and associated py::object, making it convenient
/// to have an auto-releasing C++-side keep-alive reference to the context.
/// The reference to the PyMlirContext is a simple C++ reference and the
/// py::object holds the reference count which keeps it alive.
class PyMlirContextRef {
public:
PyMlirContextRef(PyMlirContext &referrent, pybind11::object object)
: referrent(referrent), object(std::move(object)) {}
~PyMlirContextRef() {}
/// Releases the object held by this instance, causing its reference count
/// to remain artifically inflated by one. This must be used to return
/// the referenced PyMlirContext from a function. Otherwise, the destructor
/// of this reference would be called prior to the default take_ownership
/// policy assuming that the reference count has been transferred to it.
PyMlirContext *release();
PyMlirContext &operator->() { return referrent; }
pybind11::object getObject() { return object; }
private:
PyMlirContext &referrent;
pybind11::object object;
};
/// Wrapper around MlirContext.
class PyMlirContext {
public:
PyMlirContext() { context = mlirContextCreate(); }
~PyMlirContext() { mlirContextDestroy(context); }
PyMlirContext() = delete;
PyMlirContext(const PyMlirContext &) = delete;
PyMlirContext(PyMlirContext &&) = delete;
/// Returns a context reference for the singleton PyMlirContext wrapper for
/// the given context.
static PyMlirContextRef forContext(MlirContext context);
~PyMlirContext();
/// Accesses the underlying MlirContext.
MlirContext get() { return context; }
/// Gets a strong reference to this context, which will ensure it is kept
/// alive for the life of the reference.
PyMlirContextRef getRef() {
return PyMlirContextRef(
*this, pybind11::reinterpret_borrow<pybind11::object>(handle));
}
/// Gets the count of live context objects. Used for testing.
static size_t getLiveCount();
private:
PyMlirContext(MlirContext context);
// Interns the mapping of live MlirContext::ptr to PyMlirContext instances,
// preserving the relationship that an MlirContext maps to a single
// PyMlirContext wrapper. This could be replaced in the future with an
// extension mechanism on the MlirContext for stashing user pointers.
// Note that this holds a handle, which does not imply ownership.
// Mappings will be removed when the context is destructed.
using LiveContextMap =
llvm::DenseMap<void *, std::pair<pybind11::handle, PyMlirContext *>>;
static LiveContextMap &getLiveContexts();
MlirContext context;
// The handle is set as part of lookup with forContext() (post construction).
pybind11::handle handle;
};
/// Base class for all objects that directly or indirectly depend on an
/// MlirContext. The lifetime of the context will extend at least to the
/// lifetime of these instances.
/// Immutable objects that depend on a context extend this directly.
class BaseContextObject {
public:
BaseContextObject(PyMlirContextRef ref) : contextRef(std::move(ref)) {}
/// Accesses the context reference.
PyMlirContextRef &getContext() { return contextRef; }
private:
PyMlirContextRef contextRef;
};
/// Wrapper around an MlirLocation.
class PyLocation {
class PyLocation : public BaseContextObject {
public:
PyLocation(MlirLocation loc) : loc(loc) {}
PyLocation(PyMlirContextRef contextRef, MlirLocation loc)
: BaseContextObject(std::move(contextRef)), loc(loc) {}
MlirLocation loc;
};
/// Wrapper around MlirModule.
class PyModule {
class PyModule : public BaseContextObject {
public:
PyModule(MlirModule module) : module(module) {}
PyModule(PyMlirContextRef contextRef, MlirModule module)
: BaseContextObject(std::move(contextRef)), module(module) {}
PyModule(PyModule &) = delete;
PyModule(PyModule &&other) {
PyModule(PyModule &&other)
: BaseContextObject(std::move(other.getContext())) {
module = other.module;
other.module.ptr = nullptr;
}
@ -120,9 +198,10 @@ private:
/// Wrapper around the generic MlirAttribute.
/// The lifetime of a type is bound by the PyContext that created it.
class PyAttribute {
class PyAttribute : public BaseContextObject {
public:
PyAttribute(MlirAttribute attr) : attr(attr) {}
PyAttribute(PyMlirContextRef contextRef, MlirAttribute attr)
: BaseContextObject(std::move(contextRef)), attr(attr) {}
bool operator==(const PyAttribute &other);
MlirAttribute attr;
@ -153,9 +232,10 @@ private:
/// Wrapper around the generic MlirType.
/// The lifetime of a type is bound by the PyContext that created it.
class PyType {
class PyType : public BaseContextObject {
public:
PyType(MlirType type) : type(type) {}
PyType(PyMlirContextRef contextRef, MlirType type)
: BaseContextObject(std::move(contextRef)), type(type) {}
bool operator==(const PyType &other);
operator MlirType() const { return type; }

View File

@ -48,6 +48,10 @@ MlirLocation mlirLocationUnknownGet(MlirContext context) {
return wrap(UnknownLoc::get(unwrap(context)));
}
MlirContext mlirLocationGetContext(MlirLocation location) {
return wrap(unwrap(location).getContext());
}
void mlirLocationPrint(MlirLocation location, MlirStringCallback callback,
void *userData) {
detail::CallbackOstream stream(callback, userData);
@ -70,6 +74,10 @@ MlirModule mlirModuleCreateParse(MlirContext context, const char *module) {
return MlirModule{owning.release().getOperation()};
}
MlirContext mlirModuleGetContext(MlirModule module) {
return wrap(unwrap(module).getContext());
}
void mlirModuleDestroy(MlirModule module) {
// Transfer ownership to an OwningModuleRef so that its destructor is called.
OwningModuleRef(unwrap(module));
@ -349,6 +357,10 @@ MlirAttribute mlirAttributeParseGet(MlirContext context, const char *attr) {
return wrap(mlir::parseAttribute(attr, unwrap(context)));
}
MlirContext mlirAttributeGetContext(MlirAttribute attribute) {
return wrap(unwrap(attribute).getContext());
}
int mlirAttributeEqual(MlirAttribute a1, MlirAttribute a2) {
return unwrap(a1) == unwrap(a2);
}

View File

@ -0,0 +1,42 @@
# RUN: %PYTHON %s
# Standalone sanity check of context life-cycle.
import gc
import mlir
assert mlir.ir.Context._get_live_count() == 0
# Create first context.
print("CREATE C1")
c1 = mlir.ir.Context()
assert mlir.ir.Context._get_live_count() == 1
c1_repr = repr(c1)
print("C1 = ", c1_repr)
print("GETTING AGAIN...")
c2 = c1._get_context_again()
c2_repr = repr(c2)
assert mlir.ir.Context._get_live_count() == 1
assert c1_repr == c2_repr
print("C2 =", c2)
# Make sure new contexts on constructor.
print("CREATE C3")
c3 = mlir.ir.Context()
assert mlir.ir.Context._get_live_count() == 2
c3_repr = repr(c3)
print("C3 =", c3)
assert c3_repr != c1_repr
print("FREE C3")
c3 = None
gc.collect()
assert mlir.ir.Context._get_live_count() == 1
print("Free C1")
c1 = None
gc.collect()
assert mlir.ir.Context._get_live_count() == 1
print("Free C2")
c2 = None
gc.collect()
assert mlir.ir.Context._get_live_count() == 0