[mlir][Python] Add checking process before create an AffineMap from a permutation.

An invalid permutation will trigger a C++ assertion when attempting to create an AffineMap from the permutation.
This patch adds an `isPermutation` function to check the given permutation before creating the AffineMap.

Reviewed By: mehdi_amini

Differential Revision: https://reviews.llvm.org/D94492
This commit is contained in:
zhanghb97 2021-01-12 21:40:27 +08:00
parent 25b3921f2f
commit c0f3ea8a08
2 changed files with 24 additions and 0 deletions

View File

@ -153,6 +153,21 @@ static MlirStringRef toMlirStringRef(const std::string &s) {
return mlirStringRefCreate(s.data(), s.size());
}
template <typename PermutationTy>
static bool isPermutation(std::vector<PermutationTy> permutation) {
llvm::SmallVector<bool, 8> seen(permutation.size(), false);
for (auto val : permutation) {
if (val < permutation.size()) {
if (seen[val])
return false;
seen[val] = true;
continue;
}
return false;
}
return true;
}
//------------------------------------------------------------------------------
// Collections.
//------------------------------------------------------------------------------
@ -3914,6 +3929,9 @@ void mlir::python::populateIRSubmodule(py::module &m) {
"get_permutation",
[](std::vector<unsigned> permutation,
DefaultingPyMlirContext context) {
if (!isPermutation(permutation))
throw py::cast_error("Invalid permutation when attempting to "
"create an AffineMap");
MlirAffineMap affineMap = mlirAffineMapPermutationGet(
context->get(), permutation.size(), permutation.data());
return PyAffineMap(context->getRef(), affineMap);

View File

@ -73,6 +73,12 @@ def testAffineMapGet():
# CHECK: Invalid expression (None?) when attempting to create an AffineMap
print(e)
try:
AffineMap.get_permutation([1, 0, 1])
except RuntimeError as e:
# CHECK: Invalid permutation when attempting to create an AffineMap
print(e)
try:
map3.get_submap([42])
except ValueError as e: