Extend PyConcreteType to support intermediate base classes.

* Resolves todos from D87091.
* Also modifies PyConcreteAttribute to follow suite (should be useful for ElementsAttr and friends).
* Adds a test to ensure that the ShapedType base class functions as expected.

Differential Revision: https://reviews.llvm.org/D87208
This commit is contained in:
Stella Laurenzo 2020-09-06 12:16:40 -07:00
parent 7695332166
commit 7403e3ee32
2 changed files with 170 additions and 165 deletions

View File

@ -221,34 +221,37 @@ namespace {
/// CRTP base classes for Python attributes that subclass Attribute and should
/// be castable from it (i.e. via something like StringAttr(attr)).
template <typename T>
class PyConcreteAttribute : public PyAttribute {
/// By default, attribute class hierarchies are one level deep (i.e. a
/// concrete attribute class extends PyAttribute); however, intermediate
/// python-visible base classes can be modeled by specifying a BaseTy.
template <typename DerivedTy, typename BaseTy = PyAttribute>
class PyConcreteAttribute : public BaseTy {
public:
// Derived classes must define statics for:
// IsAFunctionTy isaFunction
// const char *pyClassName
using ClassTy = py::class_<T, PyAttribute>;
using ClassTy = py::class_<DerivedTy, PyAttribute>;
using IsAFunctionTy = int (*)(MlirAttribute);
PyConcreteAttribute() = default;
PyConcreteAttribute(MlirAttribute attr) : PyAttribute(attr) {}
PyConcreteAttribute(MlirAttribute attr) : BaseTy(attr) {}
PyConcreteAttribute(PyAttribute &orig)
: PyConcreteAttribute(castFrom(orig)) {}
static MlirAttribute castFrom(PyAttribute &orig) {
if (!T::isaFunction(orig.attr)) {
if (!DerivedTy::isaFunction(orig.attr)) {
auto origRepr = py::repr(py::cast(orig)).cast<std::string>();
throw SetPyError(PyExc_ValueError,
llvm::Twine("Cannot cast attribute to ") +
T::pyClassName + " (from " + origRepr + ")");
DerivedTy::pyClassName + " (from " + origRepr + ")");
}
return orig.attr;
}
static void bind(py::module &m) {
auto cls = ClassTy(m, T::pyClassName);
auto cls = ClassTy(m, DerivedTy::pyClassName);
cls.def(py::init<PyAttribute &>(), py::keep_alive<0, 1>());
T::bindDerived(cls);
DerivedTy::bindDerived(cls);
}
/// Implemented by derived classes to add methods to the Python subclass.
@ -301,33 +304,36 @@ namespace {
/// CRTP base classes for Python types that subclass Type and should be
/// castable from it (i.e. via something like IntegerType(t)).
template <typename T>
class PyConcreteType : public PyType {
/// By default, type class hierarchies are one level deep (i.e. a
/// concrete type class extends PyType); however, intermediate python-visible
/// base classes can be modeled by specifying a BaseTy.
template <typename DerivedTy, typename BaseTy = PyType>
class PyConcreteType : public BaseTy {
public:
// Derived classes must define statics for:
// IsAFunctionTy isaFunction
// const char *pyClassName
using ClassTy = py::class_<T, PyType>;
using ClassTy = py::class_<DerivedTy, BaseTy>;
using IsAFunctionTy = int (*)(MlirType);
PyConcreteType() = default;
PyConcreteType(MlirType t) : PyType(t) {}
PyConcreteType(PyType &orig) : PyType(castFrom(orig)) {}
PyConcreteType(MlirType t) : BaseTy(t) {}
PyConcreteType(PyType &orig) : PyConcreteType(castFrom(orig)) {}
static MlirType castFrom(PyType &orig) {
if (!T::isaFunction(orig.type)) {
if (!DerivedTy::isaFunction(orig.type)) {
auto origRepr = py::repr(py::cast(orig)).cast<std::string>();
throw SetPyError(PyExc_ValueError, llvm::Twine("Cannot cast type to ") +
T::pyClassName + " (from " +
origRepr + ")");
DerivedTy::pyClassName +
" (from " + origRepr + ")");
}
return orig.type;
}
static void bind(py::module &m) {
auto cls = ClassTy(m, T::pyClassName);
auto cls = ClassTy(m, DerivedTy::pyClassName);
cls.def(py::init<PyType &>(), py::keep_alive<0, 1>());
T::bindDerived(cls);
DerivedTy::bindDerived(cls);
}
/// Implemented by derived classes to add methods to the Python subclass.
@ -590,142 +596,130 @@ private:
};
/// Vector Type subclass - VectorType.
class PyVectorType : public PyShapedType {
class PyVectorType : public PyConcreteType<PyVectorType, PyShapedType> {
public:
static constexpr IsAFunctionTy isaFunction = mlirTypeIsAVector;
static constexpr const char *pyClassName = "VectorType";
using PyShapedType::PyShapedType;
// TODO: Switch back to bindDerived by making the ClassTy modifiable by
// subclasses, exposing the ShapedType hierarchy.
static void bind(py::module &m) {
py::class_<PyVectorType, PyShapedType>(m, pyClassName)
.def(py::init<PyType &>(), py::keep_alive<0, 1>())
.def_static(
"get_vector",
// TODO: Make the location optional and create a default location.
[](std::vector<int64_t> shape, PyType &elementType,
PyLocation &loc) {
MlirType t = mlirVectorTypeGetChecked(shape.size(), shape.data(),
elementType.type, loc.loc);
// TODO: Rework error reporting once diagnostic engine is exposed
// in C API.
if (mlirTypeIsNull(t)) {
throw SetPyError(
PyExc_ValueError,
llvm::Twine("invalid '") +
py::repr(py::cast(elementType)).cast<std::string>() +
"' and expected floating point or integer type.");
}
return PyVectorType(t);
},
py::keep_alive<0, 2>(), "Create a vector type");
using PyConcreteType::PyConcreteType;
static void bindDerived(ClassTy &c) {
c.def_static(
"get_vector",
// TODO: Make the location optional and create a default location.
[](std::vector<int64_t> shape, PyType &elementType, PyLocation &loc) {
MlirType t = mlirVectorTypeGetChecked(shape.size(), shape.data(),
elementType.type, loc.loc);
// TODO: Rework error reporting once diagnostic engine is exposed
// in C API.
if (mlirTypeIsNull(t)) {
throw SetPyError(
PyExc_ValueError,
llvm::Twine("invalid '") +
py::repr(py::cast(elementType)).cast<std::string>() +
"' and expected floating point or integer type.");
}
return PyVectorType(t);
},
py::keep_alive<0, 2>(), "Create a vector type");
}
};
/// Ranked Tensor Type subclass - RankedTensorType.
class PyRankedTensorType : public PyShapedType {
class PyRankedTensorType
: public PyConcreteType<PyRankedTensorType, PyShapedType> {
public:
static constexpr IsAFunctionTy isaFunction = mlirTypeIsARankedTensor;
static constexpr const char *pyClassName = "RankedTensorType";
using PyShapedType::PyShapedType;
// TODO: Switch back to bindDerived by making the ClassTy modifiable by
// subclasses, exposing the ShapedType hierarchy.
static void bind(py::module &m) {
py::class_<PyRankedTensorType, PyShapedType>(m, pyClassName)
.def(py::init<PyType &>(), py::keep_alive<0, 1>())
.def_static(
"get_ranked_tensor",
// TODO: Make the location optional and create a default location.
[](std::vector<int64_t> shape, PyType &elementType,
PyLocation &loc) {
MlirType t = mlirRankedTensorTypeGetChecked(
shape.size(), shape.data(), elementType.type, loc.loc);
// TODO: Rework error reporting once diagnostic engine is exposed
// in C API.
if (mlirTypeIsNull(t)) {
throw SetPyError(
PyExc_ValueError,
llvm::Twine("invalid '") +
py::repr(py::cast(elementType)).cast<std::string>() +
"' and expected floating point, integer, vector or "
"complex "
"type.");
}
return PyRankedTensorType(t);
},
py::keep_alive<0, 2>(), "Create a ranked tensor type");
using PyConcreteType::PyConcreteType;
static void bindDerived(ClassTy &c) {
c.def_static(
"get_ranked_tensor",
// TODO: Make the location optional and create a default location.
[](std::vector<int64_t> shape, PyType &elementType, PyLocation &loc) {
MlirType t = mlirRankedTensorTypeGetChecked(
shape.size(), shape.data(), elementType.type, loc.loc);
// TODO: Rework error reporting once diagnostic engine is exposed
// in C API.
if (mlirTypeIsNull(t)) {
throw SetPyError(
PyExc_ValueError,
llvm::Twine("invalid '") +
py::repr(py::cast(elementType)).cast<std::string>() +
"' and expected floating point, integer, vector or "
"complex "
"type.");
}
return PyRankedTensorType(t);
},
py::keep_alive<0, 2>(), "Create a ranked tensor type");
}
};
/// Unranked Tensor Type subclass - UnrankedTensorType.
class PyUnrankedTensorType : public PyShapedType {
class PyUnrankedTensorType
: public PyConcreteType<PyUnrankedTensorType, PyShapedType> {
public:
static constexpr IsAFunctionTy isaFunction = mlirTypeIsAUnrankedTensor;
static constexpr const char *pyClassName = "UnrankedTensorType";
using PyShapedType::PyShapedType;
// TODO: Switch back to bindDerived by making the ClassTy modifiable by
// subclasses, exposing the ShapedType hierarchy.
static void bind(py::module &m) {
py::class_<PyUnrankedTensorType, PyShapedType>(m, pyClassName)
.def(py::init<PyType &>(), py::keep_alive<0, 1>())
.def_static(
"get_unranked_tensor",
// TODO: Make the location optional and create a default location.
[](PyType &elementType, PyLocation &loc) {
MlirType t =
mlirUnrankedTensorTypeGetChecked(elementType.type, loc.loc);
// TODO: Rework error reporting once diagnostic engine is exposed
// in C API.
if (mlirTypeIsNull(t)) {
throw SetPyError(
PyExc_ValueError,
llvm::Twine("invalid '") +
py::repr(py::cast(elementType)).cast<std::string>() +
"' and expected floating point, integer, vector or "
"complex "
"type.");
}
return PyUnrankedTensorType(t);
},
py::keep_alive<0, 1>(), "Create a unranked tensor type");
using PyConcreteType::PyConcreteType;
static void bindDerived(ClassTy &c) {
c.def_static(
"get_unranked_tensor",
// TODO: Make the location optional and create a default location.
[](PyType &elementType, PyLocation &loc) {
MlirType t =
mlirUnrankedTensorTypeGetChecked(elementType.type, loc.loc);
// TODO: Rework error reporting once diagnostic engine is exposed
// in C API.
if (mlirTypeIsNull(t)) {
throw SetPyError(
PyExc_ValueError,
llvm::Twine("invalid '") +
py::repr(py::cast(elementType)).cast<std::string>() +
"' and expected floating point, integer, vector or "
"complex "
"type.");
}
return PyUnrankedTensorType(t);
},
py::keep_alive<0, 1>(), "Create a unranked tensor type");
}
};
/// Ranked MemRef Type subclass - MemRefType.
class PyMemRefType : public PyShapedType {
class PyMemRefType : public PyConcreteType<PyMemRefType, PyShapedType> {
public:
static constexpr IsAFunctionTy isaFunction = mlirTypeIsARankedTensor;
static constexpr const char *pyClassName = "MemRefType";
using PyShapedType::PyShapedType;
// TODO: Switch back to bindDerived by making the ClassTy modifiable by
// subclasses, exposing the ShapedType hierarchy.
static void bind(py::module &m) {
py::class_<PyMemRefType, PyShapedType>(m, pyClassName)
.def(py::init<PyType &>(), py::keep_alive<0, 1>())
// TODO: Add mlirMemRefTypeGet and mlirMemRefTypeGetAffineMap binding
// once the affine map binding is completed.
.def_static(
"get_contiguous_memref",
// TODO: Make the location optional and create a default location.
[](PyType &elementType, std::vector<int64_t> shape,
unsigned memorySpace, PyLocation &loc) {
MlirType t = mlirMemRefTypeContiguousGetChecked(
elementType.type, shape.size(), shape.data(), memorySpace,
loc.loc);
// TODO: Rework error reporting once diagnostic engine is exposed
// in C API.
if (mlirTypeIsNull(t)) {
throw SetPyError(
PyExc_ValueError,
llvm::Twine("invalid '") +
py::repr(py::cast(elementType)).cast<std::string>() +
"' and expected floating point, integer, vector or "
"complex "
"type.");
}
return PyMemRefType(t);
},
py::keep_alive<0, 1>(), "Create a memref type")
using PyConcreteType::PyConcreteType;
static void bindDerived(ClassTy &c) {
// TODO: Add mlirMemRefTypeGet and mlirMemRefTypeGetAffineMap binding
// once the affine map binding is completed.
c.def_static(
"get_contiguous_memref",
// TODO: Make the location optional and create a default location.
[](PyType &elementType, std::vector<int64_t> shape,
unsigned memorySpace, PyLocation &loc) {
MlirType t = mlirMemRefTypeContiguousGetChecked(
elementType.type, shape.size(), shape.data(), memorySpace,
loc.loc);
// TODO: Rework error reporting once diagnostic engine is exposed
// in C API.
if (mlirTypeIsNull(t)) {
throw SetPyError(
PyExc_ValueError,
llvm::Twine("invalid '") +
py::repr(py::cast(elementType)).cast<std::string>() +
"' and expected floating point, integer, vector or "
"complex "
"type.");
}
return PyMemRefType(t);
},
py::keep_alive<0, 1>(), "Create a memref type")
.def_property_readonly(
"num_affine_maps",
[](PyMemRefType &self) -> intptr_t {
@ -743,36 +737,34 @@ public:
};
/// Unranked MemRef Type subclass - UnrankedMemRefType.
class PyUnrankedMemRefType : public PyShapedType {
class PyUnrankedMemRefType
: public PyConcreteType<PyUnrankedMemRefType, PyShapedType> {
public:
static constexpr IsAFunctionTy isaFunction = mlirTypeIsAUnrankedMemRef;
static constexpr const char *pyClassName = "UnrankedMemRefType";
using PyShapedType::PyShapedType;
// TODO: Switch back to bindDerived by making the ClassTy modifiable by
// subclasses, exposing the ShapedType hierarchy.
static void bind(py::module &m) {
py::class_<PyUnrankedMemRefType, PyShapedType>(m, pyClassName)
.def(py::init<PyType &>(), py::keep_alive<0, 1>())
.def_static(
"get_unranked_memref",
// TODO: Make the location optional and create a default location.
[](PyType &elementType, unsigned memorySpace, PyLocation &loc) {
MlirType t = mlirUnrankedMemRefTypeGetChecked(
elementType.type, memorySpace, loc.loc);
// TODO: Rework error reporting once diagnostic engine is exposed
// in C API.
if (mlirTypeIsNull(t)) {
throw SetPyError(
PyExc_ValueError,
llvm::Twine("invalid '") +
py::repr(py::cast(elementType)).cast<std::string>() +
"' and expected floating point, integer, vector or "
"complex "
"type.");
}
return PyUnrankedMemRefType(t);
},
py::keep_alive<0, 1>(), "Create a unranked memref type")
using PyConcreteType::PyConcreteType;
static void bindDerived(ClassTy &c) {
c.def_static(
"get_unranked_memref",
// TODO: Make the location optional and create a default location.
[](PyType &elementType, unsigned memorySpace, PyLocation &loc) {
MlirType t = mlirUnrankedMemRefTypeGetChecked(elementType.type,
memorySpace, loc.loc);
// TODO: Rework error reporting once diagnostic engine is exposed
// in C API.
if (mlirTypeIsNull(t)) {
throw SetPyError(
PyExc_ValueError,
llvm::Twine("invalid '") +
py::repr(py::cast(elementType)).cast<std::string>() +
"' and expected floating point, integer, vector or "
"complex "
"type.");
}
return PyUnrankedMemRefType(t);
},
py::keep_alive<0, 1>(), "Create a unranked memref type")
.def_property_readonly(
"memory_space",
[](PyUnrankedMemRefType &self) -> unsigned {

View File

@ -177,11 +177,11 @@ def testComplexType():
run(testComplexType)
# CHECK-LABEL: TEST: testShapedType
# CHECK-LABEL: TEST: testConcreteShapedType
# Shaped type is not a kind of standard types, it is the base class for
# vectors, memrefs and tensors, so this test case uses an instance of vector
# to test the shaped type.
def testShapedType():
# to test the shaped type. The class hierarchy is preserved on the python side.
def testConcreteShapedType():
ctx = mlir.ir.Context()
vector = mlir.ir.VectorType(ctx.parse_type("vector<2x3xf32>"))
# CHECK: element type: f32
@ -196,12 +196,25 @@ def testShapedType():
print("whether the dim-th dimension is dynamic:", vector.is_dynamic_dim(0))
# CHECK: dim size: 3
print("dim size:", vector.get_dim_size(1))
# CHECK: False
print(vector.is_dynamic_size(3))
# CHECK: False
print(vector.is_dynamic_stride_or_offset(1))
# CHECK: is_dynamic_size: False
print("is_dynamic_size:", vector.is_dynamic_size(3))
# CHECK: is_dynamic_stride_or_offset: False
print("is_dynamic_stride_or_offset:", vector.is_dynamic_stride_or_offset(1))
# CHECK: isinstance(ShapedType): True
print("isinstance(ShapedType):", isinstance(vector, mlir.ir.ShapedType))
run(testShapedType)
run(testConcreteShapedType)
# CHECK-LABEL: TEST: testAbstractShapedType
# Tests that ShapedType operates as an abstract base class of a concrete
# shaped type (using vector as an example).
def testAbstractShapedType():
ctx = mlir.ir.Context()
vector = mlir.ir.ShapedType(ctx.parse_type("vector<2x3xf32>"))
# CHECK: element type: f32
print("element type:", vector.element_type)
run(testAbstractShapedType)
# CHECK-LABEL: TEST: testVectorType
def testVectorType():