[mlir] Add Shaped Type, Tensor Type and MemRef Type to python bindings.

Based on the PyType and PyConcreteType classes, this patch implements the bindings of Shaped Type, Tensor Type and MemRef Type subclasses.
The Tensor Type and MemRef Type are bound as ranked and unranked separately.
This patch adds the ***GetChecked C API to make sure the python side can get a valid type or a nullptr.
Shaped type is not itself one of the standard types; it is the base class for vectors, memrefs and tensors. This patch binds the PyShapedType class as the base class of the Vector Type, Tensor Type and MemRef Type subclasses.

Reviewed By: stellaraccident

Differential Revision: https://reviews.llvm.org/D87091
This commit is contained in:
zhanghb97 2020-09-06 11:37:16 -07:00 committed by Stella Laurenzo
parent bbb3baf620
commit 54d432aa6b
4 changed files with 493 additions and 23 deletions

View File

@ -162,6 +162,11 @@ int mlirTypeIsAVector(MlirType type);
* is owned by the context. */
MlirType mlirVectorTypeGet(intptr_t rank, int64_t *shape, MlirType elementType);
/** Same as "mlirVectorTypeGet" but returns a nullptr wrapping MlirType on
* illegal arguments, emitting appropriate diagnostics. */
MlirType mlirVectorTypeGetChecked(intptr_t rank, int64_t *shape,
MlirType elementType, MlirLocation loc);
/*============================================================================*/
/* Ranked / Unranked Tensor type. */
/*============================================================================*/
@ -180,10 +185,20 @@ int mlirTypeIsAUnrankedTensor(MlirType type);
MlirType mlirRankedTensorTypeGet(intptr_t rank, int64_t *shape,
MlirType elementType);
/** Same as "mlirRankedTensorTypeGet" but returns a nullptr wrapping MlirType on
* illegal arguments, emitting appropriate diagnostics. */
MlirType mlirRankedTensorTypeGetChecked(intptr_t rank, int64_t *shape,
MlirType elementType, MlirLocation loc);
/** Creates an unranked tensor type with the given element type in the same
* context as the element type. The type is owned by the context. */
MlirType mlirUnrankedTensorTypeGet(MlirType elementType);
/** Same as "mlirUnrankedTensorTypeGet" but returns a nullptr wrapping MlirType
* on illegal arguments, emitting appropriate diagnostics. */
MlirType mlirUnrankedTensorTypeGetChecked(MlirType elementType,
MlirLocation loc);
/*============================================================================*/
/* Ranked / Unranked MemRef type. */
/*============================================================================*/
@ -208,10 +223,23 @@ MlirType mlirMemRefTypeGet(MlirType elementType, intptr_t rank, int64_t *shape,
MlirType mlirMemRefTypeContiguousGet(MlirType elementType, intptr_t rank,
int64_t *shape, unsigned memorySpace);
/** Same as "mlirMemRefTypeContiguousGet" but returns a nullptr wrapping
* MlirType on illegal arguments, emitting appropriate diagnostics. */
MlirType mlirMemRefTypeContiguousGetChecked(MlirType elementType, intptr_t rank,
int64_t *shape,
unsigned memorySpace,
MlirLocation loc);
/** Creates an Unranked MemRef type with the given element type and in the given
* memory space. The type is owned by the context of element type. */
MlirType mlirUnrankedMemRefTypeGet(MlirType elementType, unsigned memorySpace);
/** Same as "mlirUnrankedMemRefTypeGet" but returns a nullptr wrapping
* MlirType on illegal arguments, emitting appropriate diagnostics. */
MlirType mlirUnrankedMemRefTypeGetChecked(MlirType elementType,
unsigned memorySpace,
MlirLocation loc);
/** Returns the number of affine layout maps in the given MemRef type. */
intptr_t mlirMemRefTypeGetNumAffineMaps(MlirType type);

View File

@ -516,30 +516,269 @@ public:
}
};
/// Shaped Type. Shaped type is not itself one of the standard types; it is
/// the common base class bound here so that the VectorType, TensorType and
/// MemRefType subclasses below inherit its accessors.
class PyShapedType : public PyConcreteType<PyShapedType> {
public:
  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAShaped;
  static constexpr const char *pyClassName = "ShapedType";
  using PyConcreteType::PyConcreteType;

  static void bindDerived(ClassTy &c) {
    c.def_property_readonly(
        "element_type",
        [](PyShapedType &self) {
          MlirType t = mlirShapedTypeGetElementType(self.type);
          return PyType(t);
        },
        // Keep the shaped type alive as long as the returned element type
        // wrapper is alive.
        py::keep_alive<0, 1>(), "Returns the element type of the shaped type.");
    c.def_property_readonly(
        "has_rank",
        [](PyShapedType &self) -> bool {
          return mlirShapedTypeHasRank(self.type);
        },
        "Returns whether the given shaped type is ranked.");
    c.def_property_readonly(
        "rank",
        [](PyShapedType &self) {
          self.requireHasRank();
          return mlirShapedTypeGetRank(self.type);
        },
        "Returns the rank of the given ranked shaped type.");
    c.def_property_readonly(
        "has_static_shape",
        [](PyShapedType &self) -> bool {
          return mlirShapedTypeHasStaticShape(self.type);
        },
        "Returns whether the given shaped type has a static shape.");
    c.def(
        "is_dynamic_dim",
        [](PyShapedType &self, intptr_t dim) -> bool {
          self.requireHasRank();
          return mlirShapedTypeIsDynamicDim(self.type, dim);
        },
        "Returns whether the dim-th dimension of the given shaped type is "
        "dynamic.");
    c.def(
        "get_dim_size",
        [](PyShapedType &self, intptr_t dim) {
          self.requireHasRank();
          return mlirShapedTypeGetDimSize(self.type, dim);
        },
        "Returns the dim-th dimension of the given ranked shaped type.");
    c.def_static(
        "is_dynamic_size",
        [](int64_t size) -> bool { return mlirShapedTypeIsDynamicSize(size); },
        "Returns whether the given dimension size indicates a dynamic "
        "dimension.");
    c.def(
        "is_dynamic_stride_or_offset",
        [](PyShapedType &self, int64_t val) -> bool {
          // The predicate only inspects `val`, never the type itself, so no
          // rank check is required; the previous requireHasRank() call here
          // made a type-independent query throw on unranked types.
          return mlirShapedTypeIsDynamicStrideOrOffset(val);
        },
        "Returns whether the given value is used as a placeholder for dynamic "
        "strides and offsets in shaped types.");
  }

private:
  /// Throws a Python ValueError when the wrapped type is unranked; used by
  /// the accessors that are only meaningful on ranked shaped types.
  void requireHasRank() {
    if (!mlirShapedTypeHasRank(type)) {
      throw SetPyError(
          PyExc_ValueError,
          "calling this method requires that the type has a rank.");
    }
  }
};
/// Vector Type subclass - VectorType.
class PyVectorType : public PyShapedType {
public:
static constexpr IsAFunctionTy isaFunction = mlirTypeIsAVector;
static constexpr const char *pyClassName = "VectorType";
using PyShapedType::PyShapedType;
// TODO: Switch back to bindDerived by making the ClassTy modifiable by
// subclasses, exposing the ShapedType hierarchy.
static void bind(py::module &m) {
py::class_<PyVectorType, PyShapedType>(m, pyClassName)
.def(py::init<PyType &>(), py::keep_alive<0, 1>())
.def_static(
"get_vector",
// TODO: Make the location optional and create a default location.
[](std::vector<int64_t> shape, PyType &elementType,
PyLocation &loc) {
MlirType t = mlirVectorTypeGetChecked(shape.size(), shape.data(),
elementType.type, loc.loc);
// TODO: Rework error reporting once diagnostic engine is exposed
// in C API.
if (mlirTypeIsNull(t)) {
throw SetPyError(
PyExc_ValueError,
llvm::Twine("invalid '") +
py::repr(py::cast(elementType)).cast<std::string>() +
"' and expected floating point or integer type.");
}
return PyVectorType(t);
},
py::keep_alive<0, 2>(), "Create a vector type");
}
};
/// Ranked Tensor Type subclass - RankedTensorType.
class PyRankedTensorType : public PyShapedType {
public:
static constexpr IsAFunctionTy isaFunction = mlirTypeIsARankedTensor;
static constexpr const char *pyClassName = "RankedTensorType";
using PyShapedType::PyShapedType;
// TODO: Switch back to bindDerived by making the ClassTy modifiable by
// subclasses, exposing the ShapedType hierarchy.
static void bind(py::module &m) {
py::class_<PyRankedTensorType, PyShapedType>(m, pyClassName)
.def(py::init<PyType &>(), py::keep_alive<0, 1>())
.def_static(
"get_ranked_tensor",
// TODO: Make the location optional and create a default location.
[](std::vector<int64_t> shape, PyType &elementType,
PyLocation &loc) {
MlirType t = mlirRankedTensorTypeGetChecked(
shape.size(), shape.data(), elementType.type, loc.loc);
// TODO: Rework error reporting once diagnostic engine is exposed
// in C API.
if (mlirTypeIsNull(t)) {
throw SetPyError(
PyExc_ValueError,
llvm::Twine("invalid '") +
py::repr(py::cast(elementType)).cast<std::string>() +
"' and expected floating point, integer, vector or "
"complex "
"type.");
}
return PyRankedTensorType(t);
},
py::keep_alive<0, 2>(), "Create a ranked tensor type");
}
};
/// Unranked Tensor Type subclass - UnrankedTensorType.
class PyUnrankedTensorType : public PyShapedType {
public:
static constexpr IsAFunctionTy isaFunction = mlirTypeIsAUnrankedTensor;
static constexpr const char *pyClassName = "UnrankedTensorType";
using PyShapedType::PyShapedType;
// TODO: Switch back to bindDerived by making the ClassTy modifiable by
// subclasses, exposing the ShapedType hierarchy.
static void bind(py::module &m) {
py::class_<PyUnrankedTensorType, PyShapedType>(m, pyClassName)
.def(py::init<PyType &>(), py::keep_alive<0, 1>())
.def_static(
"get_unranked_tensor",
// TODO: Make the location optional and create a default location.
[](PyType &elementType, PyLocation &loc) {
MlirType t =
mlirUnrankedTensorTypeGetChecked(elementType.type, loc.loc);
// TODO: Rework error reporting once diagnostic engine is exposed
// in C API.
if (mlirTypeIsNull(t)) {
throw SetPyError(
PyExc_ValueError,
llvm::Twine("invalid '") +
py::repr(py::cast(elementType)).cast<std::string>() +
"' and expected floating point, integer, vector or "
"complex "
"type.");
}
return PyUnrankedTensorType(t);
},
py::keep_alive<0, 1>(), "Create a unranked tensor type");
}
};
/// Ranked MemRef Type subclass - MemRefType.
class PyMemRefType : public PyShapedType {
public:
  // Fix: this previously used mlirTypeIsARankedTensor (copy-paste from the
  // tensor binding), which made MemRefType(...) accept ranked tensors and
  // reject actual memrefs. Use the memref isa predicate instead.
  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAMemRef;
  static constexpr const char *pyClassName = "MemRefType";
  using PyShapedType::PyShapedType;

  // TODO: Switch back to bindDerived by making the ClassTy modifiable by
  // subclasses, exposing the ShapedType hierarchy.
  static void bind(py::module &m) {
    py::class_<PyMemRefType, PyShapedType>(m, pyClassName)
        .def(py::init<PyType &>(), py::keep_alive<0, 1>())
        // TODO: Add mlirMemRefTypeGet and mlirMemRefTypeGetAffineMap binding
        // once the affine map binding is completed.
        .def_static(
            "get_contiguous_memref",
            // TODO: Make the location optional and create a default location.
            [](PyType &elementType, std::vector<int64_t> shape,
               unsigned memorySpace, PyLocation &loc) {
              MlirType t = mlirMemRefTypeContiguousGetChecked(
                  elementType.type, shape.size(), shape.data(), memorySpace,
                  loc.loc);
              // TODO: Rework error reporting once diagnostic engine is exposed
              // in C API.
              if (mlirTypeIsNull(t)) {
                throw SetPyError(
                    PyExc_ValueError,
                    llvm::Twine("invalid '") +
                        py::repr(py::cast(elementType)).cast<std::string>() +
                        "' and expected floating point, integer, vector or "
                        "complex "
                        "type.");
              }
              return PyMemRefType(t);
            },
            py::keep_alive<0, 1>(), "Create a memref type")
        .def_property_readonly(
            "num_affine_maps",
            [](PyMemRefType &self) -> intptr_t {
              return mlirMemRefTypeGetNumAffineMaps(self.type);
            },
            "Returns the number of affine layout maps in the given MemRef "
            "type.")
        .def_property_readonly(
            "memory_space",
            [](PyMemRefType &self) -> unsigned {
              return mlirMemRefTypeGetMemorySpace(self.type);
            },
            "Returns the memory space of the given MemRef type.");
  }
};
/// Unranked MemRef Type subclass - UnrankedMemRefType.
class PyUnrankedMemRefType : public PyShapedType {
public:
static constexpr IsAFunctionTy isaFunction = mlirTypeIsAUnrankedMemRef;
static constexpr const char *pyClassName = "UnrankedMemRefType";
using PyShapedType::PyShapedType;
// TODO: Switch back to bindDerived by making the ClassTy modifiable by
// subclasses, exposing the ShapedType hierarchy.
static void bind(py::module &m) {
py::class_<PyUnrankedMemRefType, PyShapedType>(m, pyClassName)
.def(py::init<PyType &>(), py::keep_alive<0, 1>())
.def_static(
"get_unranked_memref",
// TODO: Make the location optional and create a default location.
[](PyType &elementType, unsigned memorySpace, PyLocation &loc) {
MlirType t = mlirUnrankedMemRefTypeGetChecked(
elementType.type, memorySpace, loc.loc);
// TODO: Rework error reporting once diagnostic engine is exposed
// in C API.
if (mlirTypeIsNull(t)) {
throw SetPyError(
PyExc_ValueError,
llvm::Twine("invalid '") +
py::repr(py::cast(elementType)).cast<std::string>() +
"' and expected floating point, integer, vector or "
"complex "
"type.");
}
return PyUnrankedMemRefType(t);
},
py::keep_alive<0, 1>(), "Create a unranked memref type")
.def_property_readonly(
"memory_space",
[](PyUnrankedMemRefType &self) -> unsigned {
return mlirUnrankedMemrefGetMemorySpace(self.type);
},
"Returns the memory space of the given Unranked MemRef type.");
}
};
@ -886,6 +1125,11 @@ void mlir::python::populateIRSubmodule(py::module &m) {
PyF64Type::bind(m);
PyNoneType::bind(m);
PyComplexType::bind(m);
PyShapedType::bind(m);
PyVectorType::bind(m);
PyRankedTensorType::bind(m);
PyUnrankedTensorType::bind(m);
PyMemRefType::bind(m);
PyUnrankedMemRefType::bind(m);
PyTupleType::bind(m);
}

View File

@ -168,6 +168,13 @@ MlirType mlirVectorTypeGet(intptr_t rank, int64_t *shape,
unwrap(elementType)));
}
MlirType mlirVectorTypeGetChecked(intptr_t rank, int64_t *shape,
                                  MlirType elementType, MlirLocation loc) {
  // Checked builder: emits a diagnostic at `loc` and returns a null-wrapping
  // MlirType on invalid arguments instead of asserting.
  auto shapeRef = llvm::makeArrayRef(shape, static_cast<size_t>(rank));
  return wrap(
      VectorType::getChecked(shapeRef, unwrap(elementType), unwrap(loc)));
}
/* ========================================================================== */
/* Ranked / Unranked tensor type. */
/* ========================================================================== */
@ -189,10 +196,23 @@ MlirType mlirRankedTensorTypeGet(intptr_t rank, int64_t *shape,
unwrap(elementType)));
}
MlirType mlirRankedTensorTypeGetChecked(intptr_t rank, int64_t *shape,
                                        MlirType elementType,
                                        MlirLocation loc) {
  // Checked builder: emits a diagnostic at `loc` and returns a null-wrapping
  // MlirType on invalid arguments instead of asserting.
  auto shapeRef = llvm::makeArrayRef(shape, static_cast<size_t>(rank));
  return wrap(
      RankedTensorType::getChecked(shapeRef, unwrap(elementType), unwrap(loc)));
}
MlirType mlirUnrankedTensorTypeGet(MlirType elementType) {
  // The resulting type is owned by the element type's context.
  auto elem = unwrap(elementType);
  return wrap(UnrankedTensorType::get(elem));
}
MlirType mlirUnrankedTensorTypeGetChecked(MlirType elementType,
                                          MlirLocation loc) {
  // Checked builder: emits a diagnostic at `loc` and returns a null-wrapping
  // MlirType on invalid arguments instead of asserting.
  auto elem = unwrap(elementType);
  return wrap(UnrankedTensorType::getChecked(elem, unwrap(loc)));
}
/* ========================================================================== */
/* Ranked / Unranked MemRef type. */
/* ========================================================================== */
@ -216,6 +236,15 @@ MlirType mlirMemRefTypeContiguousGet(MlirType elementType, intptr_t rank,
unwrap(elementType), llvm::None, memorySpace));
}
MlirType mlirMemRefTypeContiguousGetChecked(MlirType elementType, intptr_t rank,
                                            int64_t *shape,
                                            unsigned memorySpace,
                                            MlirLocation loc) {
  // Checked builder for a contiguous memref: an empty affine-map list
  // (llvm::None) selects the default contiguous layout.
  auto shapeRef = llvm::makeArrayRef(shape, static_cast<size_t>(rank));
  return wrap(MemRefType::getChecked(shapeRef, unwrap(elementType), llvm::None,
                                     memorySpace, unwrap(loc)));
}
intptr_t mlirMemRefTypeGetNumAffineMaps(MlirType type) {
return static_cast<intptr_t>(
unwrap(type).cast<MemRefType>().getAffineMaps().size());
@ -237,6 +266,13 @@ MlirType mlirUnrankedMemRefTypeGet(MlirType elementType, unsigned memorySpace) {
return wrap(UnrankedMemRefType::get(unwrap(elementType), memorySpace));
}
MlirType mlirUnrankedMemRefTypeGetChecked(MlirType elementType,
                                          unsigned memorySpace,
                                          MlirLocation loc) {
  // Checked builder: emits a diagnostic at `loc` and returns a null-wrapping
  // MlirType on invalid arguments instead of asserting.
  auto elem = unwrap(elementType);
  return wrap(UnrankedMemRefType::getChecked(elem, memorySpace, unwrap(loc)));
}
unsigned mlirUnrankedMemrefGetMemorySpace(MlirType type) {
  // Caller must pass an unranked memref type; the cast asserts otherwise.
  auto memref = unwrap(type).cast<UnrankedMemRefType>();
  return memref.getMemorySpace();
}

View File

@ -177,25 +177,187 @@ def testComplexType():
run(testComplexType)
# CHECK-LABEL: TEST: testShapedType
# Shaped type is not itself one of the standard types; it is the base class
# for vectors, memrefs and tensors, so this test case uses an instance of
# vector to test the shaped type.
def testShapedType():
  """Exercises the ShapedType accessors through a ranked vector type."""
  ctx = mlir.ir.Context()
  vector = mlir.ir.VectorType(ctx.parse_type("vector<2x3xf32>"))
  # CHECK: element type: f32
  print("element type:", vector.element_type)
  # CHECK: whether the given shaped type is ranked: True
  print("whether the given shaped type is ranked:", vector.has_rank)
  # CHECK: rank: 2
  print("rank:", vector.rank)
  # CHECK: whether the shaped type has a static shape: True
  print("whether the shaped type has a static shape:", vector.has_static_shape)
  # CHECK: whether the dim-th dimension is dynamic: False
  print("whether the dim-th dimension is dynamic:", vector.is_dynamic_dim(0))
  # CHECK: dim size: 3
  print("dim size:", vector.get_dim_size(1))
  # CHECK: False
  print(vector.is_dynamic_size(3))
  # CHECK: False
  print(vector.is_dynamic_stride_or_offset(1))
run(testShapedType)
# CHECK-LABEL: TEST: testVectorType
def testVectorType():
  """Creates a vector type and checks the invalid-element-type error path."""
  ctx = mlir.ir.Context()
  f32 = mlir.ir.F32Type(ctx)
  shape = [2, 3]
  loc = ctx.get_unknown_location()
  # CHECK: vector type: vector<2x3xf32>
  print("vector type:", mlir.ir.VectorType.get_vector(shape, f32, loc))
  none = mlir.ir.NoneType(ctx)
  try:
    vector_invalid = mlir.ir.VectorType.get_vector(shape, none, loc)
  except ValueError as e:
    # CHECK: invalid 'Type(none)' and expected floating point or integer type.
    print(e)
  else:
    print("Exception not produced")
run(testVectorType)
# CHECK-LABEL: TEST: testRankedTensorType
def testRankedTensorType():
  """Creates a ranked tensor type and checks the invalid-element-type error."""
  ctx = mlir.ir.Context()
  f32 = mlir.ir.F32Type(ctx)
  shape = [2, 3]
  loc = ctx.get_unknown_location()
  # CHECK: ranked tensor type: tensor<2x3xf32>
  print("ranked tensor type:",
        mlir.ir.RankedTensorType.get_ranked_tensor(shape, f32, loc))
  none = mlir.ir.NoneType(ctx)
  try:
    tensor_invalid = mlir.ir.RankedTensorType.get_ranked_tensor(shape, none,
                                                                loc)
  except ValueError as e:
    # CHECK: invalid 'Type(none)' and expected floating point, integer, vector
    # CHECK: or complex type.
    print(e)
  else:
    print("Exception not produced")
run(testRankedTensorType)
# CHECK-LABEL: TEST: testUnrankedTensorType
def testUnrankedTensorType():
  """Creates an unranked tensor type and checks rank-requiring accessors and
  the invalid-element-type error path."""
  ctx = mlir.ir.Context()
  f32 = mlir.ir.F32Type(ctx)
  loc = ctx.get_unknown_location()
  unranked_tensor = mlir.ir.UnrankedTensorType.get_unranked_tensor(f32, loc)
  # CHECK: unranked tensor type: tensor<*xf32>
  print("unranked tensor type:", unranked_tensor)
  try:
    invalid_rank = unranked_tensor.rank
  except ValueError as e:
    # CHECK: calling this method requires that the type has a rank.
    print(e)
  else:
    print("Exception not produced")
  try:
    invalid_is_dynamic_dim = unranked_tensor.is_dynamic_dim(0)
  except ValueError as e:
    # CHECK: calling this method requires that the type has a rank.
    print(e)
  else:
    print("Exception not produced")
  try:
    invalid_get_dim_size = unranked_tensor.get_dim_size(1)
  except ValueError as e:
    # CHECK: calling this method requires that the type has a rank.
    print(e)
  else:
    print("Exception not produced")
  none = mlir.ir.NoneType(ctx)
  try:
    tensor_invalid = mlir.ir.UnrankedTensorType.get_unranked_tensor(none, loc)
  except ValueError as e:
    # CHECK: invalid 'Type(none)' and expected floating point, integer, vector
    # CHECK: or complex type.
    print(e)
  else:
    print("Exception not produced")
run(testUnrankedTensorType)
# CHECK-LABEL: TEST: testMemRefType
def testMemRefType():
  """Creates a contiguous memref and checks its accessors and error path."""
  ctx = mlir.ir.Context()
  f32 = mlir.ir.F32Type(ctx)
  shape = [2, 3]
  loc = ctx.get_unknown_location()
  memref = mlir.ir.MemRefType.get_contiguous_memref(f32, shape, 2, loc)
  # CHECK: memref type: memref<2x3xf32, 2>
  print("memref type:", memref)
  # CHECK: number of affine layout maps: 0
  print("number of affine layout maps:", memref.num_affine_maps)
  # CHECK: memory space: 2
  print("memory space:", memref.memory_space)
  none = mlir.ir.NoneType(ctx)
  try:
    memref_invalid = mlir.ir.MemRefType.get_contiguous_memref(none, shape, 2,
                                                              loc)
  except ValueError as e:
    # CHECK: invalid 'Type(none)' and expected floating point, integer, vector
    # CHECK: or complex type.
    print(e)
  else:
    print("Exception not produced")
run(testMemRefType)
# CHECK-LABEL: TEST: testUnrankedMemRefType
def testUnrankedMemRefType():
  """Creates an unranked memref and checks rank-requiring accessors and the
  invalid-element-type error path."""
  ctx = mlir.ir.Context()
  f32 = mlir.ir.F32Type(ctx)
  loc = ctx.get_unknown_location()
  unranked_memref = mlir.ir.UnrankedMemRefType.get_unranked_memref(f32, 2, loc)
  # CHECK: unranked memref type: memref<*xf32, 2>
  print("unranked memref type:", unranked_memref)
  try:
    invalid_rank = unranked_memref.rank
  except ValueError as e:
    # CHECK: calling this method requires that the type has a rank.
    print(e)
  else:
    print("Exception not produced")
  try:
    invalid_is_dynamic_dim = unranked_memref.is_dynamic_dim(0)
  except ValueError as e:
    # CHECK: calling this method requires that the type has a rank.
    print(e)
  else:
    print("Exception not produced")
  try:
    invalid_get_dim_size = unranked_memref.get_dim_size(1)
  except ValueError as e:
    # CHECK: calling this method requires that the type has a rank.
    print(e)
  else:
    print("Exception not produced")
  none = mlir.ir.NoneType(ctx)
  try:
    memref_invalid = mlir.ir.UnrankedMemRefType.get_unranked_memref(none, 2,
                                                                    loc)
  except ValueError as e:
    # CHECK: invalid 'Type(none)' and expected floating point, integer, vector
    # CHECK: or complex type.
    print(e)
  else:
    print("Exception not produced")
run(testUnrankedMemRefType)
# CHECK-LABEL: TEST: testTupleType
def testTupleType():
ctx = mlir.ir.Context()