diff --git a/mlir/include/mlir-c/StandardTypes.h b/mlir/include/mlir-c/StandardTypes.h index ad28ea546717..eacfe0d39b6a 100644 --- a/mlir/include/mlir-c/StandardTypes.h +++ b/mlir/include/mlir-c/StandardTypes.h @@ -162,6 +162,11 @@ int mlirTypeIsAVector(MlirType type); * is owned by the context. */ MlirType mlirVectorTypeGet(intptr_t rank, int64_t *shape, MlirType elementType); +/** Same as "mlirVectorTypeGet" but returns a nullptr wrapping MlirType on + * illegal arguments, emitting appropriate diagnostics. */ +MlirType mlirVectorTypeGetChecked(intptr_t rank, int64_t *shape, + MlirType elementType, MlirLocation loc); + /*============================================================================*/ /* Ranked / Unranked Tensor type. */ /*============================================================================*/ @@ -180,10 +185,20 @@ int mlirTypeIsAUnrankedTensor(MlirType type); MlirType mlirRankedTensorTypeGet(intptr_t rank, int64_t *shape, MlirType elementType); +/** Same as "mlirRankedTensorTypeGet" but returns a nullptr wrapping MlirType on + * illegal arguments, emitting appropriate diagnostics. */ +MlirType mlirRankedTensorTypeGetChecked(intptr_t rank, int64_t *shape, + MlirType elementType, MlirLocation loc); + /** Creates an unranked tensor type with the given element type in the same * context as the element type. The type is owned by the context. */ MlirType mlirUnrankedTensorTypeGet(MlirType elementType); +/** Same as "mlirUnrankedTensorTypeGet" but returns a nullptr wrapping MlirType + * on illegal arguments, emitting appropriate diagnostics. */ +MlirType mlirUnrankedTensorTypeGetChecked(MlirType elementType, + MlirLocation loc); + /*============================================================================*/ /* Ranked / Unranked MemRef type. */ /*============================================================================*/ @@ -208,10 +223,23 @@ MlirType mlirMemRefTypeGet(MlirType elementType, intptr_t rank, int64_t *shape, MlirType mlirMemRefTypeContiguousGet(MlirType elementType, intptr_t rank, int64_t *shape, unsigned memorySpace); +/** Same as "mlirMemRefTypeContiguousGet" but returns a nullptr wrapping + * MlirType on illegal arguments, emitting appropriate diagnostics. */ +MlirType mlirMemRefTypeContiguousGetChecked(MlirType elementType, intptr_t rank, + int64_t *shape, + unsigned memorySpace, + MlirLocation loc); + /** Creates an Unranked MemRef type with the given element type and in the given * memory space. The type is owned by the context of element type. */ MlirType mlirUnrankedMemRefTypeGet(MlirType elementType, unsigned memorySpace); +/** Same as "mlirUnrankedMemRefTypeGet" but returns a nullptr wrapping + * MlirType on illegal arguments, emitting appropriate diagnostics. */ +MlirType mlirUnrankedMemRefTypeGetChecked(MlirType elementType, + unsigned memorySpace, + MlirLocation loc); + /** Returns the number of affine layout maps in the given MemRef type. */ intptr_t mlirMemRefTypeGetNumAffineMaps(MlirType type); diff --git a/mlir/lib/Bindings/Python/IRModules.cpp b/mlir/lib/Bindings/Python/IRModules.cpp index 70c1a28e92be..149e231aed0b 100644 --- a/mlir/lib/Bindings/Python/IRModules.cpp +++ b/mlir/lib/Bindings/Python/IRModules.cpp @@ -516,30 +516,269 @@ public: } }; -/// Vector Type subclass - VectorType. -class PyVectorType : public PyConcreteType { +class PyShapedType : public PyConcreteType { public: - static constexpr IsAFunctionTy isaFunction = mlirTypeIsAVector; - static constexpr const char *pyClassName = "VectorType"; + static constexpr IsAFunctionTy isaFunction = mlirTypeIsAShaped; + static constexpr const char *pyClassName = "ShapedType"; using PyConcreteType::PyConcreteType; static void bindDerived(ClassTy &c) { - c.def_static( - "get_vector", - [](std::vector shape, PyType &elementType) { - // The element must be a floating point or integer scalar type. - if (mlirTypeIsAIntegerOrFloat(elementType.type)) { - MlirType t = - mlirVectorTypeGet(shape.size(), shape.data(), elementType.type); - return PyVectorType(t); - } - throw SetPyError( - PyExc_ValueError, - llvm::Twine("invalid '") + - py::repr(py::cast(elementType)).cast() + - "' and expected floating point or integer type."); + c.def_property_readonly( + "element_type", + [](PyShapedType &self) { + MlirType t = mlirShapedTypeGetElementType(self.type); + return PyType(t); }, - py::keep_alive<0, 2>(), "Create a vector type"); + py::keep_alive<0, 1>(), "Returns the element type of the shaped type."); + c.def_property_readonly( + "has_rank", + [](PyShapedType &self) -> bool { + return mlirShapedTypeHasRank(self.type); + }, + "Returns whether the given shaped type is ranked."); + c.def_property_readonly( + "rank", + [](PyShapedType &self) { + self.requireHasRank(); + return mlirShapedTypeGetRank(self.type); + }, + "Returns the rank of the given ranked shaped type."); + c.def_property_readonly( + "has_static_shape", + [](PyShapedType &self) -> bool { + return mlirShapedTypeHasStaticShape(self.type); + }, + "Returns whether the given shaped type has a static shape."); + c.def( + "is_dynamic_dim", + [](PyShapedType &self, intptr_t dim) -> bool { + self.requireHasRank(); + return mlirShapedTypeIsDynamicDim(self.type, dim); + }, + "Returns whether the dim-th dimension of the given shaped type is " + "dynamic."); + c.def( + "get_dim_size", + [](PyShapedType &self, intptr_t dim) { + self.requireHasRank(); + return mlirShapedTypeGetDimSize(self.type, dim); + }, + "Returns the dim-th dimension of the given ranked shaped type."); + c.def_static( + "is_dynamic_size", + [](int64_t size) -> bool { return mlirShapedTypeIsDynamicSize(size); }, + "Returns whether the given dimension size indicates a dynamic " + "dimension."); + c.def( + "is_dynamic_stride_or_offset", + [](PyShapedType &self, int64_t val) -> bool { + self.requireHasRank(); + return mlirShapedTypeIsDynamicStrideOrOffset(val); + }, + "Returns whether the given value is used as a placeholder for dynamic " + "strides and offsets in shaped types."); + } + +private: + void requireHasRank() { + if (!mlirShapedTypeHasRank(type)) { + throw SetPyError( + PyExc_ValueError, + "calling this method requires that the type has a rank."); + } + } +}; + +/// Vector Type subclass - VectorType. +class PyVectorType : public PyShapedType { +public: + static constexpr IsAFunctionTy isaFunction = mlirTypeIsAVector; + static constexpr const char *pyClassName = "VectorType"; + using PyShapedType::PyShapedType; + // TODO: Switch back to bindDerived by making the ClassTy modifiable by + // subclasses, exposing the ShapedType hierarchy. + static void bind(py::module &m) { + py::class_(m, pyClassName) + .def(py::init(), py::keep_alive<0, 1>()) + .def_static( + "get_vector", + // TODO: Make the location optional and create a default location. + [](std::vector shape, PyType &elementType, + PyLocation &loc) { + MlirType t = mlirVectorTypeGetChecked(shape.size(), shape.data(), + elementType.type, loc.loc); + // TODO: Rework error reporting once diagnostic engine is exposed + // in C API. + if (mlirTypeIsNull(t)) { + throw SetPyError( + PyExc_ValueError, + llvm::Twine("invalid '") + + py::repr(py::cast(elementType)).cast() + + "' and expected floating point or integer type."); + } + return PyVectorType(t); + }, + py::keep_alive<0, 2>(), "Create a vector type"); + } +}; + +/// Ranked Tensor Type subclass - RankedTensorType. +class PyRankedTensorType : public PyShapedType { +public: + static constexpr IsAFunctionTy isaFunction = mlirTypeIsARankedTensor; + static constexpr const char *pyClassName = "RankedTensorType"; + using PyShapedType::PyShapedType; + // TODO: Switch back to bindDerived by making the ClassTy modifiable by + // subclasses, exposing the ShapedType hierarchy. + static void bind(py::module &m) { + py::class_(m, pyClassName) + .def(py::init(), py::keep_alive<0, 1>()) + .def_static( + "get_ranked_tensor", + // TODO: Make the location optional and create a default location. + [](std::vector shape, PyType &elementType, + PyLocation &loc) { + MlirType t = mlirRankedTensorTypeGetChecked( + shape.size(), shape.data(), elementType.type, loc.loc); + // TODO: Rework error reporting once diagnostic engine is exposed + // in C API. + if (mlirTypeIsNull(t)) { + throw SetPyError( + PyExc_ValueError, + llvm::Twine("invalid '") + + py::repr(py::cast(elementType)).cast() + + "' and expected floating point, integer, vector or " + "complex " + "type."); + } + return PyRankedTensorType(t); + }, + py::keep_alive<0, 2>(), "Create a ranked tensor type"); + } +}; + +/// Unranked Tensor Type subclass - UnrankedTensorType. +class PyUnrankedTensorType : public PyShapedType { +public: + static constexpr IsAFunctionTy isaFunction = mlirTypeIsAUnrankedTensor; + static constexpr const char *pyClassName = "UnrankedTensorType"; + using PyShapedType::PyShapedType; + // TODO: Switch back to bindDerived by making the ClassTy modifiable by + // subclasses, exposing the ShapedType hierarchy. + static void bind(py::module &m) { + py::class_(m, pyClassName) + .def(py::init(), py::keep_alive<0, 1>()) + .def_static( + "get_unranked_tensor", + // TODO: Make the location optional and create a default location. + [](PyType &elementType, PyLocation &loc) { + MlirType t = + mlirUnrankedTensorTypeGetChecked(elementType.type, loc.loc); + // TODO: Rework error reporting once diagnostic engine is exposed + // in C API. + if (mlirTypeIsNull(t)) { + throw SetPyError( + PyExc_ValueError, + llvm::Twine("invalid '") + + py::repr(py::cast(elementType)).cast() + + "' and expected floating point, integer, vector or " + "complex " + "type."); + } + return PyUnrankedTensorType(t); + }, + py::keep_alive<0, 1>(), "Create a unranked tensor type"); + } +}; + +/// Ranked MemRef Type subclass - MemRefType. +class PyMemRefType : public PyShapedType { +public: + static constexpr IsAFunctionTy isaFunction = mlirTypeIsARankedTensor; + static constexpr const char *pyClassName = "MemRefType"; + using PyShapedType::PyShapedType; + // TODO: Switch back to bindDerived by making the ClassTy modifiable by + // subclasses, exposing the ShapedType hierarchy. + static void bind(py::module &m) { + py::class_(m, pyClassName) + .def(py::init(), py::keep_alive<0, 1>()) + // TODO: Add mlirMemRefTypeGet and mlirMemRefTypeGetAffineMap binding + // once the affine map binding is completed. + .def_static( + "get_contiguous_memref", + // TODO: Make the location optional and create a default location. + [](PyType &elementType, std::vector shape, + unsigned memorySpace, PyLocation &loc) { + MlirType t = mlirMemRefTypeContiguousGetChecked( + elementType.type, shape.size(), shape.data(), memorySpace, + loc.loc); + // TODO: Rework error reporting once diagnostic engine is exposed + // in C API. + if (mlirTypeIsNull(t)) { + throw SetPyError( + PyExc_ValueError, + llvm::Twine("invalid '") + + py::repr(py::cast(elementType)).cast() + + "' and expected floating point, integer, vector or " + "complex " + "type."); + } + return PyMemRefType(t); + }, + py::keep_alive<0, 1>(), "Create a memref type") + .def_property_readonly( + "num_affine_maps", + [](PyMemRefType &self) -> intptr_t { + return mlirMemRefTypeGetNumAffineMaps(self.type); + }, + "Returns the number of affine layout maps in the given MemRef " + "type.") + .def_property_readonly( + "memory_space", + [](PyMemRefType &self) -> unsigned { + return mlirMemRefTypeGetMemorySpace(self.type); + }, + "Returns the memory space of the given MemRef type."); + } +}; + +/// Unranked MemRef Type subclass - UnrankedMemRefType. +class PyUnrankedMemRefType : public PyShapedType { +public: + static constexpr IsAFunctionTy isaFunction = mlirTypeIsAUnrankedMemRef; + static constexpr const char *pyClassName = "UnrankedMemRefType"; + using PyShapedType::PyShapedType; + // TODO: Switch back to bindDerived by making the ClassTy modifiable by + // subclasses, exposing the ShapedType hierarchy. + static void bind(py::module &m) { + py::class_(m, pyClassName) + .def(py::init(), py::keep_alive<0, 1>()) + .def_static( + "get_unranked_memref", + // TODO: Make the location optional and create a default location. + [](PyType &elementType, unsigned memorySpace, PyLocation &loc) { + MlirType t = mlirUnrankedMemRefTypeGetChecked( + elementType.type, memorySpace, loc.loc); + // TODO: Rework error reporting once diagnostic engine is exposed + // in C API. + if (mlirTypeIsNull(t)) { + throw SetPyError( + PyExc_ValueError, + llvm::Twine("invalid '") + + py::repr(py::cast(elementType)).cast() + + "' and expected floating point, integer, vector or " + "complex " + "type."); + } + return PyUnrankedMemRefType(t); + }, + py::keep_alive<0, 1>(), "Create a unranked memref type") + .def_property_readonly( + "memory_space", + [](PyUnrankedMemRefType &self) -> unsigned { + return mlirUnrankedMemrefGetMemorySpace(self.type); + }, + "Returns the memory space of the given Unranked MemRef type."); } }; @@ -886,6 +1125,11 @@ void mlir::python::populateIRSubmodule(py::module &m) { PyF64Type::bind(m); PyNoneType::bind(m); PyComplexType::bind(m); + PyShapedType::bind(m); PyVectorType::bind(m); + PyRankedTensorType::bind(m); + PyUnrankedTensorType::bind(m); + PyMemRefType::bind(m); + PyUnrankedMemRefType::bind(m); PyTupleType::bind(m); } diff --git a/mlir/lib/CAPI/IR/StandardTypes.cpp b/mlir/lib/CAPI/IR/StandardTypes.cpp index eb006242e880..ddd3a5e93147 100644 --- a/mlir/lib/CAPI/IR/StandardTypes.cpp +++ b/mlir/lib/CAPI/IR/StandardTypes.cpp @@ -168,6 +168,13 @@ MlirType mlirVectorTypeGet(intptr_t rank, int64_t *shape, unwrap(elementType))); } +MlirType mlirVectorTypeGetChecked(intptr_t rank, int64_t *shape, + MlirType elementType, MlirLocation loc) { + return wrap(VectorType::getChecked( + llvm::makeArrayRef(shape, static_cast(rank)), unwrap(elementType), + unwrap(loc))); +} + /* ========================================================================== */ /* Ranked / Unranked tensor type. */ /* ========================================================================== */ @@ -189,10 +196,23 @@ MlirType mlirRankedTensorTypeGet(intptr_t rank, int64_t *shape, unwrap(elementType))); } +MlirType mlirRankedTensorTypeGetChecked(intptr_t rank, int64_t *shape, + MlirType elementType, + MlirLocation loc) { + return wrap(RankedTensorType::getChecked( + llvm::makeArrayRef(shape, static_cast(rank)), unwrap(elementType), + unwrap(loc))); +} + MlirType mlirUnrankedTensorTypeGet(MlirType elementType) { return wrap(UnrankedTensorType::get(unwrap(elementType))); } +MlirType mlirUnrankedTensorTypeGetChecked(MlirType elementType, + MlirLocation loc) { + return wrap(UnrankedTensorType::getChecked(unwrap(elementType), unwrap(loc))); +} + /* ========================================================================== */ /* Ranked / Unranked MemRef type. */ /* ========================================================================== */ @@ -216,6 +236,15 @@ MlirType mlirMemRefTypeContiguousGet(MlirType elementType, intptr_t rank, unwrap(elementType), llvm::None, memorySpace)); } +MlirType mlirMemRefTypeContiguousGetChecked(MlirType elementType, intptr_t rank, + int64_t *shape, + unsigned memorySpace, + MlirLocation loc) { + return wrap(MemRefType::getChecked( + llvm::makeArrayRef(shape, static_cast(rank)), unwrap(elementType), + llvm::None, memorySpace, unwrap(loc))); +} + intptr_t mlirMemRefTypeGetNumAffineMaps(MlirType type) { return static_cast( unwrap(type).cast().getAffineMaps().size()); @@ -237,6 +266,13 @@ MlirType mlirUnrankedMemRefTypeGet(MlirType elementType, unsigned memorySpace) { return wrap(UnrankedMemRefType::get(unwrap(elementType), memorySpace)); } +MlirType mlirUnrankedMemRefTypeGetChecked(MlirType elementType, + unsigned memorySpace, + MlirLocation loc) { + return wrap(UnrankedMemRefType::getChecked(unwrap(elementType), memorySpace, + unwrap(loc))); +} + unsigned mlirUnrankedMemrefGetMemorySpace(MlirType type) { return unwrap(type).cast().getMemorySpace(); } diff --git a/mlir/test/Bindings/Python/ir_types.py b/mlir/test/Bindings/Python/ir_types.py index a8f3a3840497..00cd595843aa 100644 --- a/mlir/test/Bindings/Python/ir_types.py +++ b/mlir/test/Bindings/Python/ir_types.py @@ -177,25 +177,187 @@ def testComplexType(): run(testComplexType) +# CHECK-LABEL: TEST: testShapedType +# Shaped type is not a kind of standard types, it is the base class for +# vectors, memrefs and tensors, so this test case uses an instance of vector +# to test the shaped type. +def testShapedType(): + ctx = mlir.ir.Context() + vector = mlir.ir.VectorType(ctx.parse_type("vector<2x3xf32>")) + # CHECK: element type: f32 + print("element type:", vector.element_type) + # CHECK: whether the given shaped type is ranked: True + print("whether the given shaped type is ranked:", vector.has_rank) + # CHECK: rank: 2 + print("rank:", vector.rank) + # CHECK: whether the shaped type has a static shape: True + print("whether the shaped type has a static shape:", vector.has_static_shape) + # CHECK: whether the dim-th dimension is dynamic: False + print("whether the dim-th dimension is dynamic:", vector.is_dynamic_dim(0)) + # CHECK: dim size: 3 + print("dim size:", vector.get_dim_size(1)) + # CHECK: False + print(vector.is_dynamic_size(3)) + # CHECK: False + print(vector.is_dynamic_stride_or_offset(1)) + +run(testShapedType) + # CHECK-LABEL: TEST: testVectorType def testVectorType(): ctx = mlir.ir.Context() f32 = mlir.ir.F32Type(ctx) shape = [2, 3] + loc = ctx.get_unknown_location() # CHECK: vector type: vector<2x3xf32> - print("vector type:", mlir.ir.VectorType.get_vector(shape, f32)) + print("vector type:", mlir.ir.VectorType.get_vector(shape, f32, loc)) - index = mlir.ir.IndexType(ctx) + none = mlir.ir.NoneType(ctx) try: - vector_invalid = mlir.ir.VectorType.get_vector(shape, index) + vector_invalid = mlir.ir.VectorType.get_vector(shape, none, loc) except ValueError as e: - # CHECK: invalid 'Type(index)' and expected floating point or integer type. + # CHECK: invalid 'Type(none)' and expected floating point or integer type. print(e) else: print("Exception not produced") run(testVectorType) +# CHECK-LABEL: TEST: testRankedTensorType +def testRankedTensorType(): + ctx = mlir.ir.Context() + f32 = mlir.ir.F32Type(ctx) + shape = [2, 3] + loc = ctx.get_unknown_location() + # CHECK: ranked tensor type: tensor<2x3xf32> + print("ranked tensor type:", + mlir.ir.RankedTensorType.get_ranked_tensor(shape, f32, loc)) + + none = mlir.ir.NoneType(ctx) + try: + tensor_invalid = mlir.ir.RankedTensorType.get_ranked_tensor(shape, none, + loc) + except ValueError as e: + # CHECK: invalid 'Type(none)' and expected floating point, integer, vector + # CHECK: or complex type. + print(e) + else: + print("Exception not produced") + +run(testRankedTensorType) + +# CHECK-LABEL: TEST: testUnrankedTensorType +def testUnrankedTensorType(): + ctx = mlir.ir.Context() + f32 = mlir.ir.F32Type(ctx) + loc = ctx.get_unknown_location() + unranked_tensor = mlir.ir.UnrankedTensorType.get_unranked_tensor(f32, loc) + # CHECK: unranked tensor type: tensor<*xf32> + print("unranked tensor type:", unranked_tensor) + try: + invalid_rank = unranked_tensor.rank + except ValueError as e: + # CHECK: calling this method requires that the type has a rank. + print(e) + else: + print("Exception not produced") + try: + invalid_is_dynamic_dim = unranked_tensor.is_dynamic_dim(0) + except ValueError as e: + # CHECK: calling this method requires that the type has a rank. + print(e) + else: + print("Exception not produced") + try: + invalid_get_dim_size = unranked_tensor.get_dim_size(1) + except ValueError as e: + # CHECK: calling this method requires that the type has a rank. + print(e) + else: + print("Exception not produced") + + none = mlir.ir.NoneType(ctx) + try: + tensor_invalid = mlir.ir.UnrankedTensorType.get_unranked_tensor(none, loc) + except ValueError as e: + # CHECK: invalid 'Type(none)' and expected floating point, integer, vector + # CHECK: or complex type. + print(e) + else: + print("Exception not produced") + +run(testUnrankedTensorType) + +# CHECK-LABEL: TEST: testMemRefType +def testMemRefType(): + ctx = mlir.ir.Context() + f32 = mlir.ir.F32Type(ctx) + shape = [2, 3] + loc = ctx.get_unknown_location() + memref = mlir.ir.MemRefType.get_contiguous_memref(f32, shape, 2, loc) + # CHECK: memref type: memref<2x3xf32, 2> + print("memref type:", memref) + # CHECK: number of affine layout maps: 0 + print("number of affine layout maps:", memref.num_affine_maps) + # CHECK: memory space: 2 + print("memory space:", memref.memory_space) + + none = mlir.ir.NoneType(ctx) + try: + memref_invalid = mlir.ir.MemRefType.get_contiguous_memref(none, shape, 2, + loc) + except ValueError as e: + # CHECK: invalid 'Type(none)' and expected floating point, integer, vector + # CHECK: or complex type. + print(e) + else: + print("Exception not produced") + +run(testMemRefType) + +# CHECK-LABEL: TEST: testUnrankedMemRefType +def testUnrankedMemRefType(): + ctx = mlir.ir.Context() + f32 = mlir.ir.F32Type(ctx) + loc = ctx.get_unknown_location() + unranked_memref = mlir.ir.UnrankedMemRefType.get_unranked_memref(f32, 2, loc) + # CHECK: unranked memref type: memref<*xf32, 2> + print("unranked memref type:", unranked_memref) + try: + invalid_rank = unranked_memref.rank + except ValueError as e: + # CHECK: calling this method requires that the type has a rank. + print(e) + else: + print("Exception not produced") + try: + invalid_is_dynamic_dim = unranked_memref.is_dynamic_dim(0) + except ValueError as e: + # CHECK: calling this method requires that the type has a rank. + print(e) + else: + print("Exception not produced") + try: + invalid_get_dim_size = unranked_memref.get_dim_size(1) + except ValueError as e: + # CHECK: calling this method requires that the type has a rank. + print(e) + else: + print("Exception not produced") + + none = mlir.ir.NoneType(ctx) + try: + memref_invalid = mlir.ir.UnrankedMemRefType.get_unranked_memref(none, 2, + loc) + except ValueError as e: + # CHECK: invalid 'Type(none)' and expected floating point, integer, vector + # CHECK: or complex type. + print(e) + else: + print("Exception not produced") + +run(testUnrankedMemRefType) + # CHECK-LABEL: TEST: testTupleType def testTupleType(): ctx = mlir.ir.Context()