[mlir] Expose MemRef layout in Python bindings

This wasn't possible before because there was no support for affine expressions
as maps. Now that this support is available, provide the mechanism for
constructing maps with a layout and inspecting it.

Rework the `get` method on MemRefType in Python to avoid needing an explicit
memory space or layout map. Remove the `get_num_maps`, it is too low-level,
using the length of the now-avaiable pseudo-list of layout maps is more
pythonic.

Depends On D94297

Reviewed By: mehdi_amini

Differential Revision: https://reviews.llvm.org/D94302
This commit is contained in:
Alex Zinenko 2021-01-08 15:08:44 +01:00
parent e79bd0b4f2
commit 547e3eef14
4 changed files with 86 additions and 19 deletions

View File

@ -225,7 +225,13 @@ MLIR_CAPI_EXPORTED bool mlirTypeIsAUnrankedMemRef(MlirType type);
* same context as element type. The type is owned by the context. */
MLIR_CAPI_EXPORTED MlirType mlirMemRefTypeGet(
MlirType elementType, intptr_t rank, const int64_t *shape, intptr_t numMaps,
MlirAttribute const *affineMaps, unsigned memorySpace);
MlirAffineMap const *affineMaps, unsigned memorySpace);
/** Same as "mlirMemRefTypeGet" but returns a nullptr-wrapping MlirType o
* illegal arguments, emitting appropriate diagnostics. */
MLIR_CAPI_EXPORTED MlirType mlirMemRefTypeGetChecked(
MlirType elementType, intptr_t rank, const int64_t *shape, intptr_t numMaps,
MlirAffineMap const *affineMaps, unsigned memorySpace, MlirLocation loc);
/** Creates a MemRef type with the given rank, shape, memory space and element
* type in the same context as the element type. The type has no affine maps,

View File

@ -2535,6 +2535,8 @@ public:
}
};
class PyMemRefLayoutMapList;
/// Ranked MemRef Type subclass - MemRefType.
class PyMemRefType : public PyConcreteType<PyMemRefType, PyShapedType> {
public:
@ -2542,16 +2544,22 @@ public:
static constexpr const char *pyClassName = "MemRefType";
using PyConcreteType::PyConcreteType;
PyMemRefLayoutMapList getLayout();
static void bindDerived(ClassTy &c) {
// TODO: Add mlirMemRefTypeGet and mlirMemRefTypeGetAffineMap binding
// once the affine map binding is completed.
c.def_static(
"get_contiguous_memref",
// TODO: Make the location optional and create a default location.
"get",
[](PyType &elementType, std::vector<int64_t> shape,
unsigned memorySpace, DefaultingPyLocation loc) {
MlirType t = mlirMemRefTypeContiguousGetChecked(
elementType, shape.size(), shape.data(), memorySpace, loc);
std::vector<PyAffineMap> layout, unsigned memorySpace,
DefaultingPyLocation loc) {
SmallVector<MlirAffineMap> maps;
maps.reserve(layout.size());
for (PyAffineMap &map : layout)
maps.push_back(map);
MlirType t = mlirMemRefTypeGetChecked(elementType, shape.size(),
shape.data(), maps.size(),
maps.data(), memorySpace, loc);
// TODO: Rework error reporting once diagnostic engine is exposed
// in C API.
if (mlirTypeIsNull(t)) {
@ -2565,15 +2573,11 @@ public:
}
return PyMemRefType(elementType.getContext(), t);
},
py::arg("element_type"), py::arg("shape"), py::arg("memory_space"),
py::arg("element_type"), py::arg("shape"),
py::arg("layout") = py::list(), py::arg("memory_space") = 0,
py::arg("loc") = py::none(), "Create a memref type")
.def_property_readonly(
"num_affine_maps",
[](PyMemRefType &self) -> intptr_t {
return mlirMemRefTypeGetNumAffineMaps(self);
},
"Returns the number of affine layout maps in the given MemRef "
"type.")
.def_property_readonly("layout", &PyMemRefType::getLayout,
"The list of layout maps of the MemRef type.")
.def_property_readonly(
"memory_space",
[](PyMemRefType &self) -> unsigned {
@ -2583,6 +2587,41 @@ public:
}
};
/// A list of affine layout maps in a memref type. Internally, these are stored
/// as consecutive elements, random access is cheap. Both the type and the maps
/// are owned by the context, no need to worry about lifetime extension.
class PyMemRefLayoutMapList
: public Sliceable<PyMemRefLayoutMapList, PyAffineMap> {
public:
static constexpr const char *pyClassName = "MemRefLayoutMapList";
PyMemRefLayoutMapList(PyMemRefType type, intptr_t startIndex = 0,
intptr_t length = -1, intptr_t step = 1)
: Sliceable(startIndex,
length == -1 ? mlirMemRefTypeGetNumAffineMaps(type) : length,
step),
memref(type) {}
intptr_t getNumElements() { return mlirMemRefTypeGetNumAffineMaps(memref); }
PyAffineMap getElement(intptr_t index) {
return PyAffineMap(memref.getContext(),
mlirMemRefTypeGetAffineMap(memref, index));
}
PyMemRefLayoutMapList slice(intptr_t startIndex, intptr_t length,
intptr_t step) {
return PyMemRefLayoutMapList(memref, startIndex, length, step);
}
private:
PyMemRefType memref;
};
PyMemRefLayoutMapList PyMemRefType::getLayout() {
return PyMemRefLayoutMapList(*this);
}
/// Unranked MemRef Type subclass - UnrankedMemRefType.
class PyUnrankedMemRefType
: public PyConcreteType<PyUnrankedMemRefType, PyShapedType> {
@ -3631,6 +3670,7 @@ void mlir::python::populateIRSubmodule(py::module &m) {
PyRankedTensorType::bind(m);
PyUnrankedTensorType::bind(m);
PyMemRefType::bind(m);
PyMemRefLayoutMapList::bind(m);
PyUnrankedMemRefType::bind(m);
PyTupleType::bind(m);
PyFunctionType::bind(m);

View File

@ -231,6 +231,17 @@ MlirType mlirMemRefTypeGet(MlirType elementType, intptr_t rank,
unwrap(elementType), maps, memorySpace));
}
MlirType mlirMemRefTypeGetChecked(MlirType elementType, intptr_t rank,
const int64_t *shape, intptr_t numMaps,
MlirAffineMap const *affineMaps,
unsigned memorySpace, MlirLocation loc) {
SmallVector<AffineMap, 1> maps;
(void)unwrapList(numMaps, affineMaps, maps);
return wrap(MemRefType::getChecked(
unwrap(loc), llvm::makeArrayRef(shape, static_cast<size_t>(rank)),
unwrap(elementType), maps, memorySpace));
}
MlirType mlirMemRefTypeContiguousGet(MlirType elementType, intptr_t rank,
const int64_t *shape,
unsigned memorySpace) {

View File

@ -326,17 +326,27 @@ def testMemRefType():
f32 = F32Type.get()
shape = [2, 3]
loc = Location.unknown()
memref = MemRefType.get_contiguous_memref(f32, shape, 2)
memref = MemRefType.get(f32, shape, memory_space=2)
# CHECK: memref type: memref<2x3xf32, 2>
print("memref type:", memref)
# CHECK: number of affine layout maps: 0
print("number of affine layout maps:", memref.num_affine_maps)
print("number of affine layout maps:", len(memref.layout))
# CHECK: memory space: 2
print("memory space:", memref.memory_space)
layout = AffineMap.get_permutation([1, 0])
memref_layout = MemRefType.get(f32, shape, [layout])
# CHECK: memref type: memref<2x3xf32, affine_map<(d0, d1) -> (d1, d0)>>
print("memref type:", memref_layout)
assert len(memref_layout.layout) == 1
# CHECK: memref layout: (d0, d1) -> (d1, d0)
print("memref layout:", memref_layout.layout[0])
# CHECK: memory space: 0
print("memory space:", memref_layout.memory_space)
none = NoneType.get()
try:
memref_invalid = MemRefType.get_contiguous_memref(none, shape, 2)
memref_invalid = MemRefType.get(none, shape)
except ValueError as e:
# CHECK: invalid 'Type(none)' and expected floating point, integer, vector
# CHECK: or complex type.