forked from OSchip/llvm-project
Extend PyConcreteType to support intermediate base classes.
* Resolves todos from D87091. * Also modifies PyConcreteAttribute to follow suite (should be useful for ElementsAttr and friends). * Adds a test to ensure that the ShapedType base class functions as expected. Differential Revision: https://reviews.llvm.org/D87208
This commit is contained in:
parent
7695332166
commit
7403e3ee32
|
@ -221,34 +221,37 @@ namespace {
|
|||
|
||||
/// CRTP base classes for Python attributes that subclass Attribute and should
|
||||
/// be castable from it (i.e. via something like StringAttr(attr)).
|
||||
template <typename T>
|
||||
class PyConcreteAttribute : public PyAttribute {
|
||||
/// By default, attribute class hierarchies are one level deep (i.e. a
|
||||
/// concrete attribute class extends PyAttribute); however, intermediate
|
||||
/// python-visible base classes can be modeled by specifying a BaseTy.
|
||||
template <typename DerivedTy, typename BaseTy = PyAttribute>
|
||||
class PyConcreteAttribute : public BaseTy {
|
||||
public:
|
||||
// Derived classes must define statics for:
|
||||
// IsAFunctionTy isaFunction
|
||||
// const char *pyClassName
|
||||
using ClassTy = py::class_<T, PyAttribute>;
|
||||
using ClassTy = py::class_<DerivedTy, PyAttribute>;
|
||||
using IsAFunctionTy = int (*)(MlirAttribute);
|
||||
|
||||
PyConcreteAttribute() = default;
|
||||
PyConcreteAttribute(MlirAttribute attr) : PyAttribute(attr) {}
|
||||
PyConcreteAttribute(MlirAttribute attr) : BaseTy(attr) {}
|
||||
PyConcreteAttribute(PyAttribute &orig)
|
||||
: PyConcreteAttribute(castFrom(orig)) {}
|
||||
|
||||
static MlirAttribute castFrom(PyAttribute &orig) {
|
||||
if (!T::isaFunction(orig.attr)) {
|
||||
if (!DerivedTy::isaFunction(orig.attr)) {
|
||||
auto origRepr = py::repr(py::cast(orig)).cast<std::string>();
|
||||
throw SetPyError(PyExc_ValueError,
|
||||
llvm::Twine("Cannot cast attribute to ") +
|
||||
T::pyClassName + " (from " + origRepr + ")");
|
||||
DerivedTy::pyClassName + " (from " + origRepr + ")");
|
||||
}
|
||||
return orig.attr;
|
||||
}
|
||||
|
||||
static void bind(py::module &m) {
|
||||
auto cls = ClassTy(m, T::pyClassName);
|
||||
auto cls = ClassTy(m, DerivedTy::pyClassName);
|
||||
cls.def(py::init<PyAttribute &>(), py::keep_alive<0, 1>());
|
||||
T::bindDerived(cls);
|
||||
DerivedTy::bindDerived(cls);
|
||||
}
|
||||
|
||||
/// Implemented by derived classes to add methods to the Python subclass.
|
||||
|
@ -301,33 +304,36 @@ namespace {
|
|||
|
||||
/// CRTP base classes for Python types that subclass Type and should be
|
||||
/// castable from it (i.e. via something like IntegerType(t)).
|
||||
template <typename T>
|
||||
class PyConcreteType : public PyType {
|
||||
/// By default, type class hierarchies are one level deep (i.e. a
|
||||
/// concrete type class extends PyType); however, intermediate python-visible
|
||||
/// base classes can be modeled by specifying a BaseTy.
|
||||
template <typename DerivedTy, typename BaseTy = PyType>
|
||||
class PyConcreteType : public BaseTy {
|
||||
public:
|
||||
// Derived classes must define statics for:
|
||||
// IsAFunctionTy isaFunction
|
||||
// const char *pyClassName
|
||||
using ClassTy = py::class_<T, PyType>;
|
||||
using ClassTy = py::class_<DerivedTy, BaseTy>;
|
||||
using IsAFunctionTy = int (*)(MlirType);
|
||||
|
||||
PyConcreteType() = default;
|
||||
PyConcreteType(MlirType t) : PyType(t) {}
|
||||
PyConcreteType(PyType &orig) : PyType(castFrom(orig)) {}
|
||||
PyConcreteType(MlirType t) : BaseTy(t) {}
|
||||
PyConcreteType(PyType &orig) : PyConcreteType(castFrom(orig)) {}
|
||||
|
||||
static MlirType castFrom(PyType &orig) {
|
||||
if (!T::isaFunction(orig.type)) {
|
||||
if (!DerivedTy::isaFunction(orig.type)) {
|
||||
auto origRepr = py::repr(py::cast(orig)).cast<std::string>();
|
||||
throw SetPyError(PyExc_ValueError, llvm::Twine("Cannot cast type to ") +
|
||||
T::pyClassName + " (from " +
|
||||
origRepr + ")");
|
||||
DerivedTy::pyClassName +
|
||||
" (from " + origRepr + ")");
|
||||
}
|
||||
return orig.type;
|
||||
}
|
||||
|
||||
static void bind(py::module &m) {
|
||||
auto cls = ClassTy(m, T::pyClassName);
|
||||
auto cls = ClassTy(m, DerivedTy::pyClassName);
|
||||
cls.def(py::init<PyType &>(), py::keep_alive<0, 1>());
|
||||
T::bindDerived(cls);
|
||||
DerivedTy::bindDerived(cls);
|
||||
}
|
||||
|
||||
/// Implemented by derived classes to add methods to the Python subclass.
|
||||
|
@ -590,142 +596,130 @@ private:
|
|||
};
|
||||
|
||||
/// Vector Type subclass - VectorType.
|
||||
class PyVectorType : public PyShapedType {
|
||||
class PyVectorType : public PyConcreteType<PyVectorType, PyShapedType> {
|
||||
public:
|
||||
static constexpr IsAFunctionTy isaFunction = mlirTypeIsAVector;
|
||||
static constexpr const char *pyClassName = "VectorType";
|
||||
using PyShapedType::PyShapedType;
|
||||
// TODO: Switch back to bindDerived by making the ClassTy modifiable by
|
||||
// subclasses, exposing the ShapedType hierarchy.
|
||||
static void bind(py::module &m) {
|
||||
py::class_<PyVectorType, PyShapedType>(m, pyClassName)
|
||||
.def(py::init<PyType &>(), py::keep_alive<0, 1>())
|
||||
.def_static(
|
||||
"get_vector",
|
||||
// TODO: Make the location optional and create a default location.
|
||||
[](std::vector<int64_t> shape, PyType &elementType,
|
||||
PyLocation &loc) {
|
||||
MlirType t = mlirVectorTypeGetChecked(shape.size(), shape.data(),
|
||||
elementType.type, loc.loc);
|
||||
// TODO: Rework error reporting once diagnostic engine is exposed
|
||||
// in C API.
|
||||
if (mlirTypeIsNull(t)) {
|
||||
throw SetPyError(
|
||||
PyExc_ValueError,
|
||||
llvm::Twine("invalid '") +
|
||||
py::repr(py::cast(elementType)).cast<std::string>() +
|
||||
"' and expected floating point or integer type.");
|
||||
}
|
||||
return PyVectorType(t);
|
||||
},
|
||||
py::keep_alive<0, 2>(), "Create a vector type");
|
||||
using PyConcreteType::PyConcreteType;
|
||||
|
||||
static void bindDerived(ClassTy &c) {
|
||||
c.def_static(
|
||||
"get_vector",
|
||||
// TODO: Make the location optional and create a default location.
|
||||
[](std::vector<int64_t> shape, PyType &elementType, PyLocation &loc) {
|
||||
MlirType t = mlirVectorTypeGetChecked(shape.size(), shape.data(),
|
||||
elementType.type, loc.loc);
|
||||
// TODO: Rework error reporting once diagnostic engine is exposed
|
||||
// in C API.
|
||||
if (mlirTypeIsNull(t)) {
|
||||
throw SetPyError(
|
||||
PyExc_ValueError,
|
||||
llvm::Twine("invalid '") +
|
||||
py::repr(py::cast(elementType)).cast<std::string>() +
|
||||
"' and expected floating point or integer type.");
|
||||
}
|
||||
return PyVectorType(t);
|
||||
},
|
||||
py::keep_alive<0, 2>(), "Create a vector type");
|
||||
}
|
||||
};
|
||||
|
||||
/// Ranked Tensor Type subclass - RankedTensorType.
|
||||
class PyRankedTensorType : public PyShapedType {
|
||||
class PyRankedTensorType
|
||||
: public PyConcreteType<PyRankedTensorType, PyShapedType> {
|
||||
public:
|
||||
static constexpr IsAFunctionTy isaFunction = mlirTypeIsARankedTensor;
|
||||
static constexpr const char *pyClassName = "RankedTensorType";
|
||||
using PyShapedType::PyShapedType;
|
||||
// TODO: Switch back to bindDerived by making the ClassTy modifiable by
|
||||
// subclasses, exposing the ShapedType hierarchy.
|
||||
static void bind(py::module &m) {
|
||||
py::class_<PyRankedTensorType, PyShapedType>(m, pyClassName)
|
||||
.def(py::init<PyType &>(), py::keep_alive<0, 1>())
|
||||
.def_static(
|
||||
"get_ranked_tensor",
|
||||
// TODO: Make the location optional and create a default location.
|
||||
[](std::vector<int64_t> shape, PyType &elementType,
|
||||
PyLocation &loc) {
|
||||
MlirType t = mlirRankedTensorTypeGetChecked(
|
||||
shape.size(), shape.data(), elementType.type, loc.loc);
|
||||
// TODO: Rework error reporting once diagnostic engine is exposed
|
||||
// in C API.
|
||||
if (mlirTypeIsNull(t)) {
|
||||
throw SetPyError(
|
||||
PyExc_ValueError,
|
||||
llvm::Twine("invalid '") +
|
||||
py::repr(py::cast(elementType)).cast<std::string>() +
|
||||
"' and expected floating point, integer, vector or "
|
||||
"complex "
|
||||
"type.");
|
||||
}
|
||||
return PyRankedTensorType(t);
|
||||
},
|
||||
py::keep_alive<0, 2>(), "Create a ranked tensor type");
|
||||
using PyConcreteType::PyConcreteType;
|
||||
|
||||
static void bindDerived(ClassTy &c) {
|
||||
c.def_static(
|
||||
"get_ranked_tensor",
|
||||
// TODO: Make the location optional and create a default location.
|
||||
[](std::vector<int64_t> shape, PyType &elementType, PyLocation &loc) {
|
||||
MlirType t = mlirRankedTensorTypeGetChecked(
|
||||
shape.size(), shape.data(), elementType.type, loc.loc);
|
||||
// TODO: Rework error reporting once diagnostic engine is exposed
|
||||
// in C API.
|
||||
if (mlirTypeIsNull(t)) {
|
||||
throw SetPyError(
|
||||
PyExc_ValueError,
|
||||
llvm::Twine("invalid '") +
|
||||
py::repr(py::cast(elementType)).cast<std::string>() +
|
||||
"' and expected floating point, integer, vector or "
|
||||
"complex "
|
||||
"type.");
|
||||
}
|
||||
return PyRankedTensorType(t);
|
||||
},
|
||||
py::keep_alive<0, 2>(), "Create a ranked tensor type");
|
||||
}
|
||||
};
|
||||
|
||||
/// Unranked Tensor Type subclass - UnrankedTensorType.
|
||||
class PyUnrankedTensorType : public PyShapedType {
|
||||
class PyUnrankedTensorType
|
||||
: public PyConcreteType<PyUnrankedTensorType, PyShapedType> {
|
||||
public:
|
||||
static constexpr IsAFunctionTy isaFunction = mlirTypeIsAUnrankedTensor;
|
||||
static constexpr const char *pyClassName = "UnrankedTensorType";
|
||||
using PyShapedType::PyShapedType;
|
||||
// TODO: Switch back to bindDerived by making the ClassTy modifiable by
|
||||
// subclasses, exposing the ShapedType hierarchy.
|
||||
static void bind(py::module &m) {
|
||||
py::class_<PyUnrankedTensorType, PyShapedType>(m, pyClassName)
|
||||
.def(py::init<PyType &>(), py::keep_alive<0, 1>())
|
||||
.def_static(
|
||||
"get_unranked_tensor",
|
||||
// TODO: Make the location optional and create a default location.
|
||||
[](PyType &elementType, PyLocation &loc) {
|
||||
MlirType t =
|
||||
mlirUnrankedTensorTypeGetChecked(elementType.type, loc.loc);
|
||||
// TODO: Rework error reporting once diagnostic engine is exposed
|
||||
// in C API.
|
||||
if (mlirTypeIsNull(t)) {
|
||||
throw SetPyError(
|
||||
PyExc_ValueError,
|
||||
llvm::Twine("invalid '") +
|
||||
py::repr(py::cast(elementType)).cast<std::string>() +
|
||||
"' and expected floating point, integer, vector or "
|
||||
"complex "
|
||||
"type.");
|
||||
}
|
||||
return PyUnrankedTensorType(t);
|
||||
},
|
||||
py::keep_alive<0, 1>(), "Create a unranked tensor type");
|
||||
using PyConcreteType::PyConcreteType;
|
||||
|
||||
static void bindDerived(ClassTy &c) {
|
||||
c.def_static(
|
||||
"get_unranked_tensor",
|
||||
// TODO: Make the location optional and create a default location.
|
||||
[](PyType &elementType, PyLocation &loc) {
|
||||
MlirType t =
|
||||
mlirUnrankedTensorTypeGetChecked(elementType.type, loc.loc);
|
||||
// TODO: Rework error reporting once diagnostic engine is exposed
|
||||
// in C API.
|
||||
if (mlirTypeIsNull(t)) {
|
||||
throw SetPyError(
|
||||
PyExc_ValueError,
|
||||
llvm::Twine("invalid '") +
|
||||
py::repr(py::cast(elementType)).cast<std::string>() +
|
||||
"' and expected floating point, integer, vector or "
|
||||
"complex "
|
||||
"type.");
|
||||
}
|
||||
return PyUnrankedTensorType(t);
|
||||
},
|
||||
py::keep_alive<0, 1>(), "Create a unranked tensor type");
|
||||
}
|
||||
};
|
||||
|
||||
/// Ranked MemRef Type subclass - MemRefType.
|
||||
class PyMemRefType : public PyShapedType {
|
||||
class PyMemRefType : public PyConcreteType<PyMemRefType, PyShapedType> {
|
||||
public:
|
||||
static constexpr IsAFunctionTy isaFunction = mlirTypeIsARankedTensor;
|
||||
static constexpr const char *pyClassName = "MemRefType";
|
||||
using PyShapedType::PyShapedType;
|
||||
// TODO: Switch back to bindDerived by making the ClassTy modifiable by
|
||||
// subclasses, exposing the ShapedType hierarchy.
|
||||
static void bind(py::module &m) {
|
||||
py::class_<PyMemRefType, PyShapedType>(m, pyClassName)
|
||||
.def(py::init<PyType &>(), py::keep_alive<0, 1>())
|
||||
// TODO: Add mlirMemRefTypeGet and mlirMemRefTypeGetAffineMap binding
|
||||
// once the affine map binding is completed.
|
||||
.def_static(
|
||||
"get_contiguous_memref",
|
||||
// TODO: Make the location optional and create a default location.
|
||||
[](PyType &elementType, std::vector<int64_t> shape,
|
||||
unsigned memorySpace, PyLocation &loc) {
|
||||
MlirType t = mlirMemRefTypeContiguousGetChecked(
|
||||
elementType.type, shape.size(), shape.data(), memorySpace,
|
||||
loc.loc);
|
||||
// TODO: Rework error reporting once diagnostic engine is exposed
|
||||
// in C API.
|
||||
if (mlirTypeIsNull(t)) {
|
||||
throw SetPyError(
|
||||
PyExc_ValueError,
|
||||
llvm::Twine("invalid '") +
|
||||
py::repr(py::cast(elementType)).cast<std::string>() +
|
||||
"' and expected floating point, integer, vector or "
|
||||
"complex "
|
||||
"type.");
|
||||
}
|
||||
return PyMemRefType(t);
|
||||
},
|
||||
py::keep_alive<0, 1>(), "Create a memref type")
|
||||
using PyConcreteType::PyConcreteType;
|
||||
|
||||
static void bindDerived(ClassTy &c) {
|
||||
// TODO: Add mlirMemRefTypeGet and mlirMemRefTypeGetAffineMap binding
|
||||
// once the affine map binding is completed.
|
||||
c.def_static(
|
||||
"get_contiguous_memref",
|
||||
// TODO: Make the location optional and create a default location.
|
||||
[](PyType &elementType, std::vector<int64_t> shape,
|
||||
unsigned memorySpace, PyLocation &loc) {
|
||||
MlirType t = mlirMemRefTypeContiguousGetChecked(
|
||||
elementType.type, shape.size(), shape.data(), memorySpace,
|
||||
loc.loc);
|
||||
// TODO: Rework error reporting once diagnostic engine is exposed
|
||||
// in C API.
|
||||
if (mlirTypeIsNull(t)) {
|
||||
throw SetPyError(
|
||||
PyExc_ValueError,
|
||||
llvm::Twine("invalid '") +
|
||||
py::repr(py::cast(elementType)).cast<std::string>() +
|
||||
"' and expected floating point, integer, vector or "
|
||||
"complex "
|
||||
"type.");
|
||||
}
|
||||
return PyMemRefType(t);
|
||||
},
|
||||
py::keep_alive<0, 1>(), "Create a memref type")
|
||||
.def_property_readonly(
|
||||
"num_affine_maps",
|
||||
[](PyMemRefType &self) -> intptr_t {
|
||||
|
@ -743,36 +737,34 @@ public:
|
|||
};
|
||||
|
||||
/// Unranked MemRef Type subclass - UnrankedMemRefType.
|
||||
class PyUnrankedMemRefType : public PyShapedType {
|
||||
class PyUnrankedMemRefType
|
||||
: public PyConcreteType<PyUnrankedMemRefType, PyShapedType> {
|
||||
public:
|
||||
static constexpr IsAFunctionTy isaFunction = mlirTypeIsAUnrankedMemRef;
|
||||
static constexpr const char *pyClassName = "UnrankedMemRefType";
|
||||
using PyShapedType::PyShapedType;
|
||||
// TODO: Switch back to bindDerived by making the ClassTy modifiable by
|
||||
// subclasses, exposing the ShapedType hierarchy.
|
||||
static void bind(py::module &m) {
|
||||
py::class_<PyUnrankedMemRefType, PyShapedType>(m, pyClassName)
|
||||
.def(py::init<PyType &>(), py::keep_alive<0, 1>())
|
||||
.def_static(
|
||||
"get_unranked_memref",
|
||||
// TODO: Make the location optional and create a default location.
|
||||
[](PyType &elementType, unsigned memorySpace, PyLocation &loc) {
|
||||
MlirType t = mlirUnrankedMemRefTypeGetChecked(
|
||||
elementType.type, memorySpace, loc.loc);
|
||||
// TODO: Rework error reporting once diagnostic engine is exposed
|
||||
// in C API.
|
||||
if (mlirTypeIsNull(t)) {
|
||||
throw SetPyError(
|
||||
PyExc_ValueError,
|
||||
llvm::Twine("invalid '") +
|
||||
py::repr(py::cast(elementType)).cast<std::string>() +
|
||||
"' and expected floating point, integer, vector or "
|
||||
"complex "
|
||||
"type.");
|
||||
}
|
||||
return PyUnrankedMemRefType(t);
|
||||
},
|
||||
py::keep_alive<0, 1>(), "Create a unranked memref type")
|
||||
using PyConcreteType::PyConcreteType;
|
||||
|
||||
static void bindDerived(ClassTy &c) {
|
||||
c.def_static(
|
||||
"get_unranked_memref",
|
||||
// TODO: Make the location optional and create a default location.
|
||||
[](PyType &elementType, unsigned memorySpace, PyLocation &loc) {
|
||||
MlirType t = mlirUnrankedMemRefTypeGetChecked(elementType.type,
|
||||
memorySpace, loc.loc);
|
||||
// TODO: Rework error reporting once diagnostic engine is exposed
|
||||
// in C API.
|
||||
if (mlirTypeIsNull(t)) {
|
||||
throw SetPyError(
|
||||
PyExc_ValueError,
|
||||
llvm::Twine("invalid '") +
|
||||
py::repr(py::cast(elementType)).cast<std::string>() +
|
||||
"' and expected floating point, integer, vector or "
|
||||
"complex "
|
||||
"type.");
|
||||
}
|
||||
return PyUnrankedMemRefType(t);
|
||||
},
|
||||
py::keep_alive<0, 1>(), "Create a unranked memref type")
|
||||
.def_property_readonly(
|
||||
"memory_space",
|
||||
[](PyUnrankedMemRefType &self) -> unsigned {
|
||||
|
|
|
@ -177,11 +177,11 @@ def testComplexType():
|
|||
|
||||
run(testComplexType)
|
||||
|
||||
# CHECK-LABEL: TEST: testShapedType
|
||||
# CHECK-LABEL: TEST: testConcreteShapedType
|
||||
# Shaped type is not a kind of standard types, it is the base class for
|
||||
# vectors, memrefs and tensors, so this test case uses an instance of vector
|
||||
# to test the shaped type.
|
||||
def testShapedType():
|
||||
# to test the shaped type. The class hierarchy is preserved on the python side.
|
||||
def testConcreteShapedType():
|
||||
ctx = mlir.ir.Context()
|
||||
vector = mlir.ir.VectorType(ctx.parse_type("vector<2x3xf32>"))
|
||||
# CHECK: element type: f32
|
||||
|
@ -196,12 +196,25 @@ def testShapedType():
|
|||
print("whether the dim-th dimension is dynamic:", vector.is_dynamic_dim(0))
|
||||
# CHECK: dim size: 3
|
||||
print("dim size:", vector.get_dim_size(1))
|
||||
# CHECK: False
|
||||
print(vector.is_dynamic_size(3))
|
||||
# CHECK: False
|
||||
print(vector.is_dynamic_stride_or_offset(1))
|
||||
# CHECK: is_dynamic_size: False
|
||||
print("is_dynamic_size:", vector.is_dynamic_size(3))
|
||||
# CHECK: is_dynamic_stride_or_offset: False
|
||||
print("is_dynamic_stride_or_offset:", vector.is_dynamic_stride_or_offset(1))
|
||||
# CHECK: isinstance(ShapedType): True
|
||||
print("isinstance(ShapedType):", isinstance(vector, mlir.ir.ShapedType))
|
||||
|
||||
run(testShapedType)
|
||||
run(testConcreteShapedType)
|
||||
|
||||
# CHECK-LABEL: TEST: testAbstractShapedType
|
||||
# Tests that ShapedType operates as an abstract base class of a concrete
|
||||
# shaped type (using vector as an example).
|
||||
def testAbstractShapedType():
|
||||
ctx = mlir.ir.Context()
|
||||
vector = mlir.ir.ShapedType(ctx.parse_type("vector<2x3xf32>"))
|
||||
# CHECK: element type: f32
|
||||
print("element type:", vector.element_type)
|
||||
|
||||
run(testAbstractShapedType)
|
||||
|
||||
# CHECK-LABEL: TEST: testVectorType
|
||||
def testVectorType():
|
||||
|
|
Loading…
Reference in New Issue