[mlir][Python] Python API cleanups and additions found during code audit.

* Add capsule get/create for Attribute and Type, which already had capsule interop defined.
* Add capsule interop and get/create for Location.
* Add Location __eq__.
* Use get() and implicit cast to go from PyAttribute, PyType, PyLocation to MlirAttribute, MlirType, MlirLocation (bundled with this change because I didn't want to continue the pattern one more time).

Differential Revision: https://reviews.llvm.org/D92283
This commit is contained in:
Stella Laurenzo 2020-11-29 13:30:23 -08:00
parent e6c1777685
commit bd2083c2fa
9 changed files with 237 additions and 85 deletions

View File

@ -28,6 +28,7 @@
#define MLIR_PYTHON_CAPSULE_ATTRIBUTE "mlir.ir.Attribute._CAPIPtr"
#define MLIR_PYTHON_CAPSULE_CONTEXT "mlir.ir.Context._CAPIPtr"
#define MLIR_PYTHON_CAPSULE_LOCATION "mlir.ir.Location._CAPIPtr"
#define MLIR_PYTHON_CAPSULE_MODULE "mlir.ir.Module._CAPIPtr"
#define MLIR_PYTHON_CAPSULE_OPERATION "mlir.ir.Operation._CAPIPtr"
#define MLIR_PYTHON_CAPSULE_TYPE "mlir.ir.Type._CAPIPtr"
@ -106,6 +107,24 @@ static inline MlirContext mlirPythonCapsuleToContext(PyObject *capsule) {
return context;
}
/** Creates a capsule object encapsulating the raw C-API MlirLocation.
* The returned capsule does not extend or affect ownership of any Python
* objects that reference the location in any way. */
static inline PyObject *mlirPythonLocationToCapsule(MlirLocation loc) {
return PyCapsule_New(MLIR_PYTHON_GET_WRAPPED_POINTER(loc),
MLIR_PYTHON_CAPSULE_LOCATION, NULL);
}
/** Extracts an MlirLocation from a capsule as produced from
* mlirPythonLocationToCapsule. If the capsule is not of the right type, then
* a null module is returned (as checked via mlirLocationIsNull). In such a
* case, the Python APIs will have already set an error. */
static inline MlirLocation mlirPythonCapsuleToLocation(PyObject *capsule) {
void *ptr = PyCapsule_GetPointer(capsule, MLIR_PYTHON_CAPSULE_LOCATION);
MlirLocation loc = {ptr};
return loc;
}
/** Creates a capsule object encapsulating the raw C-API MlirModule.
* The returned capsule does not extend or affect ownership of any Python
* objects that reference the module in any way. */

View File

@ -153,6 +153,14 @@ MLIR_CAPI_EXPORTED MlirLocation mlirLocationUnknownGet(MlirContext context);
/// Gets the context that a location was created with.
MLIR_CAPI_EXPORTED MlirContext mlirLocationGetContext(MlirLocation location);
/// Checks if the location is null.
static inline int mlirLocationIsNull(MlirLocation location) {
return !location.ptr;
}
/// Checks if two locations are equal.
MLIR_CAPI_EXPORTED int mlirLocationEqual(MlirLocation l1, MlirLocation l2);
/** Prints a location by sending chunks of the string representation and
* forwarding `userData to `callback`. Note that the callback may be called
* several times with consecutive chunks of the string. */

View File

@ -289,7 +289,7 @@ public:
llvm::SmallVector<MlirType, 4> argTypes;
argTypes.reserve(pyArgTypes.size());
for (auto &pyArg : pyArgTypes) {
argTypes.push_back(pyArg.cast<PyType &>().type);
argTypes.push_back(pyArg.cast<PyType &>());
}
MlirBlock block = mlirBlockCreate(argTypes.size(), argTypes.data());
@ -640,6 +640,18 @@ MlirDialect PyDialects::getDialectForKey(const std::string &key,
// PyLocation
//------------------------------------------------------------------------------
py::object PyLocation::getCapsule() {
return py::reinterpret_steal<py::object>(mlirPythonLocationToCapsule(*this));
}
PyLocation PyLocation::createFromCapsule(py::object capsule) {
MlirLocation rawLoc = mlirPythonCapsuleToLocation(capsule.ptr());
if (mlirLocationIsNull(rawLoc))
throw py::error_already_set();
return PyLocation(PyMlirContext::forContext(mlirLocationGetContext(rawLoc)),
rawLoc);
}
py::object PyLocation::contextEnter() {
return PyThreadContextEntry::pushLocation(*this);
}
@ -879,7 +891,7 @@ py::object PyOperation::create(
// TODO: Verify result type originate from the same context.
if (!result)
throw SetPyError(PyExc_ValueError, "result type cannot be None");
mlirResults.push_back(result->type);
mlirResults.push_back(*result);
}
}
// Unpack/validate attributes.
@ -890,7 +902,7 @@ py::object PyOperation::create(
auto name = it.first.cast<std::string>();
auto &attribute = it.second.cast<PyAttribute &>();
// TODO: Verify attribute originates from the same context.
mlirAttributes.emplace_back(std::move(name), attribute.attr);
mlirAttributes.emplace_back(std::move(name), attribute);
}
}
// Unpack/validate successors.
@ -908,7 +920,7 @@ py::object PyOperation::create(
// Apply unpacked/validated to the operation state. Beyond this
// point, exceptions cannot be thrown or else the state will leak.
MlirOperationState state =
mlirOperationStateGet(toMlirStringRef(name), location->loc);
mlirOperationStateGet(toMlirStringRef(name), location);
if (!mlirOperands.empty())
mlirOperationStateAddOperands(&state, mlirOperands.size(),
mlirOperands.data());
@ -1076,6 +1088,18 @@ bool PyAttribute::operator==(const PyAttribute &other) {
return mlirAttributeEqual(attr, other.attr);
}
py::object PyAttribute::getCapsule() {
return py::reinterpret_steal<py::object>(mlirPythonAttributeToCapsule(*this));
}
PyAttribute PyAttribute::createFromCapsule(py::object capsule) {
MlirAttribute rawAttr = mlirPythonCapsuleToAttribute(capsule.ptr());
if (mlirAttributeIsNull(rawAttr))
throw py::error_already_set();
return PyAttribute(
PyMlirContext::forContext(mlirAttributeGetContext(rawAttr)), rawAttr);
}
//------------------------------------------------------------------------------
// PyNamedAttribute.
//------------------------------------------------------------------------------
@ -1093,6 +1117,18 @@ bool PyType::operator==(const PyType &other) {
return mlirTypeEqual(type, other.type);
}
py::object PyType::getCapsule() {
return py::reinterpret_steal<py::object>(mlirPythonTypeToCapsule(*this));
}
PyType PyType::createFromCapsule(py::object capsule) {
MlirType rawType = mlirPythonCapsuleToType(capsule.ptr());
if (mlirTypeIsNull(rawType))
throw py::error_already_set();
return PyType(PyMlirContext::forContext(mlirTypeGetContext(rawType)),
rawType);
}
//------------------------------------------------------------------------------
// PyValue and subclases.
//------------------------------------------------------------------------------
@ -1315,7 +1351,7 @@ public:
void dunderSetItem(const std::string &name, PyAttribute attr) {
mlirOperationSetAttributeByName(operation->get(), toMlirStringRef(name),
attr.attr);
attr);
}
void dunderDelItem(const std::string &name) {
@ -1378,13 +1414,13 @@ public:
: PyConcreteAttribute(orig.getContext(), castFrom(orig)) {}
static MlirAttribute castFrom(PyAttribute &orig) {
if (!DerivedTy::isaFunction(orig.attr)) {
if (!DerivedTy::isaFunction(orig)) {
auto origRepr = py::repr(py::cast(orig)).cast<std::string>();
throw SetPyError(PyExc_ValueError,
llvm::Twine("Cannot cast attribute to ") +
DerivedTy::pyClassName + " (from " + origRepr + ")");
}
return orig.attr;
return orig;
}
static void bind(py::module &m) {
@ -1408,8 +1444,7 @@ public:
c.def_static(
"get",
[](PyType &type, double value, DefaultingPyLocation loc) {
MlirAttribute attr =
mlirFloatAttrDoubleGetChecked(type.type, value, loc->loc);
MlirAttribute attr = mlirFloatAttrDoubleGetChecked(type, value, loc);
// TODO: Rework error reporting once diagnostic engine is exposed
// in C API.
if (mlirAttributeIsNull(attr)) {
@ -1443,7 +1478,7 @@ public:
c.def_property_readonly(
"value",
[](PyFloatAttribute &self) {
return mlirFloatAttrGetValueDouble(self.attr);
return mlirFloatAttrGetValueDouble(self);
},
"Returns the value of the float point attribute");
}
@ -1460,7 +1495,7 @@ public:
c.def_static(
"get",
[](PyType &type, int64_t value) {
MlirAttribute attr = mlirIntegerAttrGet(type.type, value);
MlirAttribute attr = mlirIntegerAttrGet(type, value);
return PyIntegerAttribute(type.getContext(), attr);
},
py::arg("type"), py::arg("value"),
@ -1468,7 +1503,7 @@ public:
c.def_property_readonly(
"value",
[](PyIntegerAttribute &self) {
return mlirIntegerAttrGetValueInt(self.attr);
return mlirIntegerAttrGetValueInt(self);
},
"Returns the value of the integer attribute");
}
@ -1492,7 +1527,7 @@ public:
"Gets an uniqued bool attribute");
c.def_property_readonly(
"value",
[](PyBoolAttribute &self) { return mlirBoolAttrGetValue(self.attr); },
[](PyBoolAttribute &self) { return mlirBoolAttrGetValue(self); },
"Returns the value of the bool attribute");
}
};
@ -1517,7 +1552,7 @@ public:
"get_typed",
[](PyType &type, std::string value) {
MlirAttribute attr =
mlirStringAttrTypedGet(type.type, value.size(), &value[0]);
mlirStringAttrTypedGet(type, value.size(), &value[0]);
return PyStringAttribute(type.getContext(), attr);
},
@ -1525,7 +1560,7 @@ public:
c.def_property_readonly(
"value",
[](PyStringAttribute &self) {
MlirStringRef stringRef = mlirStringAttrGetValue(self.attr);
MlirStringRef stringRef = mlirStringAttrGetValue(self);
return py::str(stringRef.data, stringRef.length);
},
"Returns the value of the string attribute");
@ -1621,8 +1656,8 @@ public:
PyAttribute &elementAttr) {
auto contextWrapper =
PyMlirContext::forContext(mlirTypeGetContext(shapedType));
if (!mlirAttributeIsAInteger(elementAttr.attr) &&
!mlirAttributeIsAFloat(elementAttr.attr)) {
if (!mlirAttributeIsAInteger(elementAttr) &&
!mlirAttributeIsAFloat(elementAttr)) {
std::string message = "Illegal element type for DenseElementsAttr: ";
message.append(py::repr(py::cast(elementAttr)));
throw SetPyError(PyExc_ValueError, message);
@ -1634,8 +1669,8 @@ public:
message.append(py::repr(py::cast(shapedType)));
throw SetPyError(PyExc_ValueError, message);
}
MlirType shapedElementType = mlirShapedTypeGetElementType(shapedType.type);
MlirType attrType = mlirAttributeGetType(elementAttr.attr);
MlirType shapedElementType = mlirShapedTypeGetElementType(shapedType);
MlirType attrType = mlirAttributeGetType(elementAttr);
if (!mlirTypeEqual(shapedElementType, attrType)) {
std::string message =
"Shaped element type and attribute type must be equal: shaped=";
@ -1646,14 +1681,14 @@ public:
}
MlirAttribute elements =
mlirDenseElementsAttrSplatGet(shapedType.type, elementAttr.attr);
mlirDenseElementsAttrSplatGet(shapedType, elementAttr);
return PyDenseElementsAttribute(contextWrapper->getRef(), elements);
}
intptr_t dunderLen() { return mlirElementsAttrGetNumElements(attr); }
intptr_t dunderLen() { return mlirElementsAttrGetNumElements(*this); }
py::buffer_info accessBuffer() {
MlirType shapedType = mlirAttributeGetType(this->attr);
MlirType shapedType = mlirAttributeGetType(*this);
MlirType elementType = mlirShapedTypeGetElementType(shapedType);
if (mlirTypeIsAF32(elementType)) {
@ -1699,7 +1734,7 @@ public:
"Gets a DenseElementsAttr where all values are the same")
.def_property_readonly("is_splat",
[](PyDenseElementsAttribute &self) -> bool {
return mlirDenseElementsAttrIsSplat(self.attr);
return mlirDenseElementsAttrIsSplat(self);
})
.def_buffer(&PyDenseElementsAttribute::accessBuffer);
}
@ -1742,7 +1777,7 @@ private:
// Prepare the data for the buffer_info.
// Buffer is configured for read-only access below.
Type *data = static_cast<Type *>(
const_cast<void *>(mlirDenseElementsAttrGetRawData(this->attr)));
const_cast<void *>(mlirDenseElementsAttrGetRawData(*this)));
// Prepare the shape for the buffer_info.
SmallVector<intptr_t, 4> shape;
for (intptr_t i = 0; i < rank; ++i)
@ -1782,7 +1817,7 @@ public:
"attempt to access out of bounds element");
}
MlirType type = mlirAttributeGetType(attr);
MlirType type = mlirAttributeGetType(*this);
type = mlirShapedTypeGetElementType(type);
assert(mlirTypeIsAInteger(type) &&
"expected integer element type in dense int elements attribute");
@ -1795,23 +1830,23 @@ public:
bool isUnsigned = mlirIntegerTypeIsUnsigned(type);
if (isUnsigned) {
if (width == 1) {
return mlirDenseElementsAttrGetBoolValue(attr, pos);
return mlirDenseElementsAttrGetBoolValue(*this, pos);
}
if (width == 32) {
return mlirDenseElementsAttrGetUInt32Value(attr, pos);
return mlirDenseElementsAttrGetUInt32Value(*this, pos);
}
if (width == 64) {
return mlirDenseElementsAttrGetUInt64Value(attr, pos);
return mlirDenseElementsAttrGetUInt64Value(*this, pos);
}
} else {
if (width == 1) {
return mlirDenseElementsAttrGetBoolValue(attr, pos);
return mlirDenseElementsAttrGetBoolValue(*this, pos);
}
if (width == 32) {
return mlirDenseElementsAttrGetInt32Value(attr, pos);
return mlirDenseElementsAttrGetInt32Value(*this, pos);
}
if (width == 64) {
return mlirDenseElementsAttrGetInt64Value(attr, pos);
return mlirDenseElementsAttrGetInt64Value(*this, pos);
}
}
throw SetPyError(PyExc_TypeError, "Unsupported integer type");
@ -1838,7 +1873,7 @@ public:
"attempt to access out of bounds element");
}
MlirType type = mlirAttributeGetType(attr);
MlirType type = mlirAttributeGetType(*this);
type = mlirShapedTypeGetElementType(type);
// Dispatch element extraction to an appropriate C function based on the
// elemental type of the attribute. py::float_ is implicitly constructible
@ -1846,10 +1881,10 @@ public:
// TODO: consider caching the type properties in the constructor to avoid
// querying them on each element access.
if (mlirTypeIsAF32(type)) {
return mlirDenseElementsAttrGetFloatValue(attr, pos);
return mlirDenseElementsAttrGetFloatValue(*this, pos);
}
if (mlirTypeIsAF64(type)) {
return mlirDenseElementsAttrGetDoubleValue(attr, pos);
return mlirDenseElementsAttrGetDoubleValue(*this, pos);
}
throw SetPyError(PyExc_TypeError, "Unsupported floating-point type");
}
@ -1906,13 +1941,13 @@ public:
: PyConcreteType(orig.getContext(), castFrom(orig)) {}
static MlirType castFrom(PyType &orig) {
if (!DerivedTy::isaFunction(orig.type)) {
if (!DerivedTy::isaFunction(orig)) {
auto origRepr = py::repr(py::cast(orig)).cast<std::string>();
throw SetPyError(PyExc_ValueError, llvm::Twine("Cannot cast type to ") +
DerivedTy::pyClassName +
" (from " + origRepr + ")");
}
return orig.type;
return orig;
}
static void bind(py::module &m) {
@ -1958,24 +1993,24 @@ public:
"Create an unsigned integer type");
c.def_property_readonly(
"width",
[](PyIntegerType &self) { return mlirIntegerTypeGetWidth(self.type); },
[](PyIntegerType &self) { return mlirIntegerTypeGetWidth(self); },
"Returns the width of the integer type");
c.def_property_readonly(
"is_signless",
[](PyIntegerType &self) -> bool {
return mlirIntegerTypeIsSignless(self.type);
return mlirIntegerTypeIsSignless(self);
},
"Returns whether this is a signless integer");
c.def_property_readonly(
"is_signed",
[](PyIntegerType &self) -> bool {
return mlirIntegerTypeIsSigned(self.type);
return mlirIntegerTypeIsSigned(self);
},
"Returns whether this is a signed integer");
c.def_property_readonly(
"is_unsigned",
[](PyIntegerType &self) -> bool {
return mlirIntegerTypeIsUnsigned(self.type);
return mlirIntegerTypeIsUnsigned(self);
},
"Returns whether this is an unsigned integer");
}
@ -2101,8 +2136,8 @@ public:
"get",
[](PyType &elementType) {
// The element must be a floating point or integer scalar type.
if (mlirTypeIsAIntegerOrFloat(elementType.type)) {
MlirType t = mlirComplexTypeGet(elementType.type);
if (mlirTypeIsAIntegerOrFloat(elementType)) {
MlirType t = mlirComplexTypeGet(elementType);
return PyComplexType(elementType.getContext(), t);
}
throw SetPyError(
@ -2115,7 +2150,7 @@ public:
c.def_property_readonly(
"element_type",
[](PyComplexType &self) -> PyType {
MlirType t = mlirComplexTypeGetElementType(self.type);
MlirType t = mlirComplexTypeGetElementType(self);
return PyType(self.getContext(), t);
},
"Returns element type.");
@ -2132,34 +2167,32 @@ public:
c.def_property_readonly(
"element_type",
[](PyShapedType &self) {
MlirType t = mlirShapedTypeGetElementType(self.type);
MlirType t = mlirShapedTypeGetElementType(self);
return PyType(self.getContext(), t);
},
"Returns the element type of the shaped type.");
c.def_property_readonly(
"has_rank",
[](PyShapedType &self) -> bool {
return mlirShapedTypeHasRank(self.type);
},
[](PyShapedType &self) -> bool { return mlirShapedTypeHasRank(self); },
"Returns whether the given shaped type is ranked.");
c.def_property_readonly(
"rank",
[](PyShapedType &self) {
self.requireHasRank();
return mlirShapedTypeGetRank(self.type);
return mlirShapedTypeGetRank(self);
},
"Returns the rank of the given ranked shaped type.");
c.def_property_readonly(
"has_static_shape",
[](PyShapedType &self) -> bool {
return mlirShapedTypeHasStaticShape(self.type);
return mlirShapedTypeHasStaticShape(self);
},
"Returns whether the given shaped type has a static shape.");
c.def(
"is_dynamic_dim",
[](PyShapedType &self, intptr_t dim) -> bool {
self.requireHasRank();
return mlirShapedTypeIsDynamicDim(self.type, dim);
return mlirShapedTypeIsDynamicDim(self, dim);
},
"Returns whether the dim-th dimension of the given shaped type is "
"dynamic.");
@ -2167,7 +2200,7 @@ public:
"get_dim_size",
[](PyShapedType &self, intptr_t dim) {
self.requireHasRank();
return mlirShapedTypeGetDimSize(self.type, dim);
return mlirShapedTypeGetDimSize(self, dim);
},
"Returns the dim-th dimension of the given ranked shaped type.");
c.def_static(
@ -2187,7 +2220,7 @@ public:
private:
void requireHasRank() {
if (!mlirShapedTypeHasRank(type)) {
if (!mlirShapedTypeHasRank(*this)) {
throw SetPyError(
PyExc_ValueError,
"calling this method requires that the type has a rank.");
@ -2208,7 +2241,7 @@ public:
[](std::vector<int64_t> shape, PyType &elementType,
DefaultingPyLocation loc) {
MlirType t = mlirVectorTypeGetChecked(shape.size(), shape.data(),
elementType.type, loc->loc);
elementType, loc);
// TODO: Rework error reporting once diagnostic engine is exposed
// in C API.
if (mlirTypeIsNull(t)) {
@ -2239,7 +2272,7 @@ public:
[](std::vector<int64_t> shape, PyType &elementType,
DefaultingPyLocation loc) {
MlirType t = mlirRankedTensorTypeGetChecked(
shape.size(), shape.data(), elementType.type, loc->loc);
shape.size(), shape.data(), elementType, loc);
// TODO: Rework error reporting once diagnostic engine is exposed
// in C API.
if (mlirTypeIsNull(t)) {
@ -2270,8 +2303,7 @@ public:
c.def_static(
"get",
[](PyType &elementType, DefaultingPyLocation loc) {
MlirType t =
mlirUnrankedTensorTypeGetChecked(elementType.type, loc->loc);
MlirType t = mlirUnrankedTensorTypeGetChecked(elementType, loc);
// TODO: Rework error reporting once diagnostic engine is exposed
// in C API.
if (mlirTypeIsNull(t)) {
@ -2306,8 +2338,7 @@ public:
[](PyType &elementType, std::vector<int64_t> shape,
unsigned memorySpace, DefaultingPyLocation loc) {
MlirType t = mlirMemRefTypeContiguousGetChecked(
elementType.type, shape.size(), shape.data(), memorySpace,
loc->loc);
elementType, shape.size(), shape.data(), memorySpace, loc);
// TODO: Rework error reporting once diagnostic engine is exposed
// in C API.
if (mlirTypeIsNull(t)) {
@ -2326,14 +2357,14 @@ public:
.def_property_readonly(
"num_affine_maps",
[](PyMemRefType &self) -> intptr_t {
return mlirMemRefTypeGetNumAffineMaps(self.type);
return mlirMemRefTypeGetNumAffineMaps(self);
},
"Returns the number of affine layout maps in the given MemRef "
"type.")
.def_property_readonly(
"memory_space",
[](PyMemRefType &self) -> unsigned {
return mlirMemRefTypeGetMemorySpace(self.type);
return mlirMemRefTypeGetMemorySpace(self);
},
"Returns the memory space of the given MemRef type.");
}
@ -2352,8 +2383,8 @@ public:
"get",
[](PyType &elementType, unsigned memorySpace,
DefaultingPyLocation loc) {
MlirType t = mlirUnrankedMemRefTypeGetChecked(elementType.type,
memorySpace, loc->loc);
MlirType t =
mlirUnrankedMemRefTypeGetChecked(elementType, memorySpace, loc);
// TODO: Rework error reporting once diagnostic engine is exposed
// in C API.
if (mlirTypeIsNull(t)) {
@ -2372,7 +2403,7 @@ public:
.def_property_readonly(
"memory_space",
[](PyUnrankedMemRefType &self) -> unsigned {
return mlirUnrankedMemrefGetMemorySpace(self.type);
return mlirUnrankedMemrefGetMemorySpace(self);
},
"Returns the memory space of the given Unranked MemRef type.");
}
@ -2393,7 +2424,7 @@ public:
// Mapping py::list to SmallVector.
SmallVector<MlirType, 4> elements;
for (auto element : elementList)
elements.push_back(element.cast<PyType>().type);
elements.push_back(element.cast<PyType>());
MlirType t = mlirTupleTypeGet(context->get(), num, elements.data());
return PyTupleType(context->getRef(), t);
},
@ -2402,14 +2433,14 @@ public:
c.def(
"get_type",
[](PyTupleType &self, intptr_t pos) -> PyType {
MlirType t = mlirTupleTypeGetType(self.type, pos);
MlirType t = mlirTupleTypeGetType(self, pos);
return PyType(self.getContext(), t);
},
"Returns the pos-th type in the tuple type.");
c.def_property_readonly(
"num_types",
[](PyTupleType &self) -> intptr_t {
return mlirTupleTypeGetNumTypes(self.type);
return mlirTupleTypeGetNumTypes(self);
},
"Returns the number of types contained in a tuple.");
}
@ -2439,11 +2470,11 @@ public:
c.def_property_readonly(
"inputs",
[](PyFunctionType &self) {
MlirType t = self.type;
MlirType t = self;
auto contextRef = self.getContext();
py::list types;
for (intptr_t i = 0, e = mlirFunctionTypeGetNumInputs(self.type);
i < e; ++i) {
for (intptr_t i = 0, e = mlirFunctionTypeGetNumInputs(self); i < e;
++i) {
types.append(PyType(contextRef, mlirFunctionTypeGetInput(t, i)));
}
return types;
@ -2452,12 +2483,12 @@ public:
c.def_property_readonly(
"results",
[](PyFunctionType &self) {
MlirType t = self.type;
auto contextRef = self.getContext();
py::list types;
for (intptr_t i = 0, e = mlirFunctionTypeGetNumResults(self.type);
i < e; ++i) {
types.append(PyType(contextRef, mlirFunctionTypeGetResult(t, i)));
for (intptr_t i = 0, e = mlirFunctionTypeGetNumResults(self); i < e;
++i) {
types.append(
PyType(contextRef, mlirFunctionTypeGetResult(self, i)));
}
return types;
},
@ -2584,8 +2615,15 @@ void mlir::python::populateIRSubmodule(py::module &m) {
// Mapping of Location
//----------------------------------------------------------------------------
py::class_<PyLocation>(m, "Location")
.def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyLocation::getCapsule)
.def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyLocation::createFromCapsule)
.def("__enter__", &PyLocation::contextEnter)
.def("__exit__", &PyLocation::contextExit)
.def("__eq__",
[](PyLocation &self, PyLocation &other) -> bool {
return mlirLocationEqual(self, other);
})
.def("__eq__", [](PyLocation &self, py::object other) { return false; })
.def_property_readonly_static(
"current",
[](py::object & /*class*/) {
@ -2620,7 +2658,7 @@ void mlir::python::populateIRSubmodule(py::module &m) {
"Context that owns the Location")
.def("__repr__", [](PyLocation &self) {
PyPrintAccumulator printAccum;
mlirLocationPrint(self.loc, printAccum.getCallback(),
mlirLocationPrint(self, printAccum.getCallback(),
printAccum.getUserData());
return printAccum.join();
});
@ -2650,7 +2688,7 @@ void mlir::python::populateIRSubmodule(py::module &m) {
.def_static(
"create",
[](DefaultingPyLocation loc) {
MlirModule module = mlirModuleCreateEmpty(loc->loc);
MlirModule module = mlirModuleCreateEmpty(loc);
return PyModule::forModule(module).releaseObject();
},
py::arg("loc") = py::none(), "Creates an empty module")
@ -2881,6 +2919,9 @@ void mlir::python::populateIRSubmodule(py::module &m) {
// Mapping of PyAttribute.
//----------------------------------------------------------------------------
py::class_<PyAttribute>(m, "Attribute")
.def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
&PyAttribute::getCapsule)
.def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyAttribute::createFromCapsule)
.def_static(
"parse",
[](std::string attrSpec, DefaultingPyMlirContext context) {
@ -2904,25 +2945,25 @@ void mlir::python::populateIRSubmodule(py::module &m) {
.def_property_readonly("type",
[](PyAttribute &self) {
return PyType(self.getContext()->getRef(),
mlirAttributeGetType(self.attr));
mlirAttributeGetType(self));
})
.def(
"get_named",
[](PyAttribute &self, std::string name) {
return PyNamedAttribute(self.attr, std::move(name));
return PyNamedAttribute(self, std::move(name));
},
py::keep_alive<0, 1>(), "Binds a name to the attribute")
.def("__eq__",
[](PyAttribute &self, PyAttribute &other) { return self == other; })
.def("__eq__", [](PyAttribute &self, py::object &other) { return false; })
.def(
"dump", [](PyAttribute &self) { mlirAttributeDump(self.attr); },
"dump", [](PyAttribute &self) { mlirAttributeDump(self); },
kDumpDocstring)
.def(
"__str__",
[](PyAttribute &self) {
PyPrintAccumulator printAccum;
mlirAttributePrint(self.attr, printAccum.getCallback(),
mlirAttributePrint(self, printAccum.getCallback(),
printAccum.getUserData());
return printAccum.join();
},
@ -2935,7 +2976,7 @@ void mlir::python::populateIRSubmodule(py::module &m) {
// being excessive.
PyPrintAccumulator printAccum;
printAccum.parts.append("Attribute(");
mlirAttributePrint(self.attr, printAccum.getCallback(),
mlirAttributePrint(self, printAccum.getCallback(),
printAccum.getUserData());
printAccum.parts.append(")");
return printAccum.join();
@ -2990,6 +3031,8 @@ void mlir::python::populateIRSubmodule(py::module &m) {
// Mapping of PyType.
//----------------------------------------------------------------------------
py::class_<PyType>(m, "Type")
.def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyType::getCapsule)
.def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyType::createFromCapsule)
.def_static(
"parse",
[](std::string typeSpec, DefaultingPyMlirContext context) {
@ -3012,12 +3055,12 @@ void mlir::python::populateIRSubmodule(py::module &m) {
.def("__eq__", [](PyType &self, PyType &other) { return self == other; })
.def("__eq__", [](PyType &self, py::object &other) { return false; })
.def(
"dump", [](PyType &self) { mlirTypeDump(self.type); }, kDumpDocstring)
"dump", [](PyType &self) { mlirTypeDump(self); }, kDumpDocstring)
.def(
"__str__",
[](PyType &self) {
PyPrintAccumulator printAccum;
mlirTypePrint(self.type, printAccum.getCallback(),
mlirTypePrint(self, printAccum.getCallback(),
printAccum.getUserData());
return printAccum.join();
},
@ -3029,8 +3072,7 @@ void mlir::python::populateIRSubmodule(py::module &m) {
// assembly forms and printing them is useful.
PyPrintAccumulator printAccum;
printAccum.parts.append("Type(");
mlirTypePrint(self.type, printAccum.getCallback(),
printAccum.getUserData());
mlirTypePrint(self, printAccum.getCallback(), printAccum.getUserData());
printAccum.parts.append(")");
return printAccum.join();
});

View File

@ -307,11 +307,24 @@ public:
PyLocation(PyMlirContextRef contextRef, MlirLocation loc)
: BaseContextObject(std::move(contextRef)), loc(loc) {}
operator MlirLocation() const { return loc; }
MlirLocation get() const { return loc; }
/// Enter and exit the context manager.
pybind11::object contextEnter();
void contextExit(pybind11::object excType, pybind11::object excVal,
pybind11::object excTb);
/// Gets a capsule wrapping the void* within the MlirContext.
pybind11::object getCapsule();
/// Creates a PyMlirContext from the MlirContext wrapped by a capsule.
/// Note that PyMlirContext instances are uniqued, so the returned object
/// may be a pre-existing object. Ownership of the underlying MlirContext
/// is taken by calling this function.
static PyLocation createFromCapsule(pybind11::object capsule);
private:
MlirLocation loc;
};
@ -324,6 +337,8 @@ public:
static constexpr const char kTypeDescription[] =
"[ThreadContextAware] mlir.ir.Location";
static PyLocation &resolve();
operator MlirLocation() const { return *get(); }
};
/// Wrapper around MlirModule.
@ -568,7 +583,19 @@ public:
PyAttribute(PyMlirContextRef contextRef, MlirAttribute attr)
: BaseContextObject(std::move(contextRef)), attr(attr) {}
bool operator==(const PyAttribute &other);
operator MlirAttribute() const { return attr; }
MlirAttribute get() const { return attr; }
/// Gets a capsule wrapping the void* within the MlirContext.
pybind11::object getCapsule();
/// Creates a PyMlirContext from the MlirContext wrapped by a capsule.
/// Note that PyMlirContext instances are uniqued, so the returned object
/// may be a pre-existing object. Ownership of the underlying MlirContext
/// is taken by calling this function.
static PyAttribute createFromCapsule(pybind11::object capsule);
private:
MlirAttribute attr;
};
@ -603,7 +630,18 @@ public:
: BaseContextObject(std::move(contextRef)), type(type) {}
bool operator==(const PyType &other);
operator MlirType() const { return type; }
MlirType get() const { return type; }
/// Gets a capsule wrapping the void* within the MlirContext.
pybind11::object getCapsule();
/// Creates a PyMlirContext from the MlirContext wrapped by a capsule.
/// Note that PyMlirContext instances are uniqued, so the returned object
/// may be a pre-existing object. Ownership of the underlying MlirContext
/// is taken by calling this function.
static PyType createFromCapsule(pybind11::object capsule);
private:
MlirType type;
};

View File

@ -50,7 +50,7 @@ public:
Defaulting() = default;
Defaulting(ReferrentTy &referrent) : referrent(&referrent) {}
ReferrentTy *get() { return referrent; }
ReferrentTy *get() const { return referrent; }
ReferrentTy *operator->() { return referrent; }
private:

View File

@ -119,6 +119,10 @@ MlirLocation mlirLocationUnknownGet(MlirContext context) {
return wrap(UnknownLoc::get(unwrap(context)));
}
int mlirLocationEqual(MlirLocation l1, MlirLocation l2) {
return unwrap(l1) == unwrap(l2);
}
MlirContext mlirLocationGetContext(MlirLocation location) {
return wrap(unwrap(location).getContext());
}

View File

@ -74,6 +74,20 @@ def testAttrEqDoesNotRaise():
run(testAttrEqDoesNotRaise)
# CHECK-LABEL: TEST: testAttrCapsule
def testAttrCapsule():
with Context() as ctx:
a1 = Attribute.parse('"attr1"')
# CHECK: mlir.ir.Attribute._CAPIPtr
attr_capsule = a1._CAPIPtr
print(attr_capsule)
a2 = Attribute._CAPICreate(attr_capsule)
assert a2 == a1
assert a2.context is ctx
run(testAttrCapsule)
# CHECK-LABEL: TEST: testStandardAttrCasts
def testStandardAttrCasts():
with Context():

View File

@ -38,3 +38,16 @@ def testFileLineCol():
run(testFileLineCol)
# CHECK-LABEL: TEST: testLocationCapsule
def testLocationCapsule():
with Context() as ctx:
loc1 = Location.file("foo.txt", 123, 56)
# CHECK: mlir.ir.Location._CAPIPtr
loc_capsule = loc1._CAPIPtr
print(loc_capsule)
loc2 = Location._CAPICreate(loc_capsule)
assert loc2 == loc1
assert loc2.context is ctx
run(testLocationCapsule)

View File

@ -74,6 +74,20 @@ def testTypeEqDoesNotRaise():
run(testTypeEqDoesNotRaise)
# CHECK-LABEL: TEST: testTypeCapsule
def testTypeCapsule():
with Context() as ctx:
t1 = Type.parse("i32", ctx)
# CHECK: mlir.ir.Type._CAPIPtr
type_capsule = t1._CAPIPtr
print(type_capsule)
t2 = Type._CAPICreate(type_capsule)
assert t2 == t1
assert t2.context is ctx
run(testTypeCapsule)
# CHECK-LABEL: TEST: testStandardTypeCasts
def testStandardTypeCasts():
ctx = Context()