forked from OSchip/llvm-project
[mlir] Move PyConcreteAttribute to header. NFC.
This allows out-of-tree users to derive PyConcreteAttribute to bind custom attributes. Reviewed By: nicolasvasilache Differential Revision: https://reviews.llvm.org/D101063
This commit is contained in:
parent
4138e7bd76
commit
0b10fdedf9
|
@ -27,46 +27,6 @@ static MlirStringRef toMlirStringRef(const std::string &s) {
|
|||
return mlirStringRefCreate(s.data(), s.size());
|
||||
}
|
||||
|
||||
/// CRTP base classes for Python attributes that subclass Attribute and should
|
||||
/// be castable from it (i.e. via something like StringAttr(attr)).
|
||||
/// By default, attribute class hierarchies are one level deep (i.e. a
|
||||
/// concrete attribute class extends PyAttribute); however, intermediate
|
||||
/// python-visible base classes can be modeled by specifying a BaseTy.
|
||||
template <typename DerivedTy, typename BaseTy = PyAttribute>
|
||||
class PyConcreteAttribute : public BaseTy {
|
||||
public:
|
||||
// Derived classes must define statics for:
|
||||
// IsAFunctionTy isaFunction
|
||||
// const char *pyClassName
|
||||
using ClassTy = py::class_<DerivedTy, BaseTy>;
|
||||
using IsAFunctionTy = bool (*)(MlirAttribute);
|
||||
|
||||
PyConcreteAttribute() = default;
|
||||
PyConcreteAttribute(PyMlirContextRef contextRef, MlirAttribute attr)
|
||||
: BaseTy(std::move(contextRef), attr) {}
|
||||
PyConcreteAttribute(PyAttribute &orig)
|
||||
: PyConcreteAttribute(orig.getContext(), castFrom(orig)) {}
|
||||
|
||||
static MlirAttribute castFrom(PyAttribute &orig) {
|
||||
if (!DerivedTy::isaFunction(orig)) {
|
||||
auto origRepr = py::repr(py::cast(orig)).cast<std::string>();
|
||||
throw SetPyError(PyExc_ValueError, Twine("Cannot cast attribute to ") +
|
||||
DerivedTy::pyClassName +
|
||||
" (from " + origRepr + ")");
|
||||
}
|
||||
return orig;
|
||||
}
|
||||
|
||||
static void bind(py::module &m) {
|
||||
auto cls = ClassTy(m, DerivedTy::pyClassName, py::buffer_protocol());
|
||||
cls.def(py::init<PyAttribute &>(), py::keep_alive<0, 1>());
|
||||
DerivedTy::bindDerived(cls);
|
||||
}
|
||||
|
||||
/// Implemented by derived classes to add methods to the Python subclass.
|
||||
static void bindDerived(ClassTy &m) {}
|
||||
};
|
||||
|
||||
class PyAffineMapAttribute : public PyConcreteAttribute<PyAffineMapAttribute> {
|
||||
public:
|
||||
static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAAffineMap;
|
||||
|
|
|
@ -642,6 +642,46 @@ private:
|
|||
std::unique_ptr<std::string> ownedName;
|
||||
};
|
||||
|
||||
/// CRTP base classes for Python attributes that subclass Attribute and should
|
||||
/// be castable from it (i.e. via something like StringAttr(attr)).
|
||||
/// By default, attribute class hierarchies are one level deep (i.e. a
|
||||
/// concrete attribute class extends PyAttribute); however, intermediate
|
||||
/// python-visible base classes can be modeled by specifying a BaseTy.
|
||||
template <typename DerivedTy, typename BaseTy = PyAttribute>
|
||||
class PyConcreteAttribute : public BaseTy {
|
||||
public:
|
||||
// Derived classes must define statics for:
|
||||
// IsAFunctionTy isaFunction
|
||||
// const char *pyClassName
|
||||
using ClassTy = pybind11::class_<DerivedTy, BaseTy>;
|
||||
using IsAFunctionTy = bool (*)(MlirAttribute);
|
||||
|
||||
PyConcreteAttribute() = default;
|
||||
PyConcreteAttribute(PyMlirContextRef contextRef, MlirAttribute attr)
|
||||
: BaseTy(std::move(contextRef), attr) {}
|
||||
PyConcreteAttribute(PyAttribute &orig)
|
||||
: PyConcreteAttribute(orig.getContext(), castFrom(orig)) {}
|
||||
|
||||
static MlirAttribute castFrom(PyAttribute &orig) {
|
||||
if (!DerivedTy::isaFunction(orig)) {
|
||||
auto origRepr = pybind11::repr(pybind11::cast(orig)).cast<std::string>();
|
||||
throw SetPyError(PyExc_ValueError,
|
||||
llvm::Twine("Cannot cast attribute to ") +
|
||||
DerivedTy::pyClassName + " (from " + origRepr + ")");
|
||||
}
|
||||
return orig;
|
||||
}
|
||||
|
||||
static void bind(pybind11::module &m) {
|
||||
auto cls = ClassTy(m, DerivedTy::pyClassName, pybind11::buffer_protocol());
|
||||
cls.def(pybind11::init<PyAttribute &>(), pybind11::keep_alive<0, 1>());
|
||||
DerivedTy::bindDerived(cls);
|
||||
}
|
||||
|
||||
/// Implemented by derived classes to add methods to the Python subclass.
|
||||
static void bindDerived(ClassTy &m) {}
|
||||
};
|
||||
|
||||
/// Wrapper around the generic MlirType.
|
||||
/// The lifetime of a type is bound by the PyContext that created it.
|
||||
class PyType : public BaseContextObject {
|
||||
|
|
Loading…
Reference in New Issue