forked from OSchip/llvm-project
[mlir][Python] Add checking process before create an AffineMap from a permutation.
An invalid permutation will trigger a C++ assertion when attempting to create an AffineMap from the permutation. This patch adds an `isPermutation` function to check the given permutation before creating the AffineMap. Reviewed By: mehdi_amini Differential Revision: https://reviews.llvm.org/D94492
This commit is contained in:
parent
25b3921f2f
commit
c0f3ea8a08
|
@ -153,6 +153,21 @@ static MlirStringRef toMlirStringRef(const std::string &s) {
|
||||||
return mlirStringRefCreate(s.data(), s.size());
|
return mlirStringRefCreate(s.data(), s.size());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template <typename PermutationTy>
|
||||||
|
static bool isPermutation(std::vector<PermutationTy> permutation) {
|
||||||
|
llvm::SmallVector<bool, 8> seen(permutation.size(), false);
|
||||||
|
for (auto val : permutation) {
|
||||||
|
if (val < permutation.size()) {
|
||||||
|
if (seen[val])
|
||||||
|
return false;
|
||||||
|
seen[val] = true;
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
//------------------------------------------------------------------------------
|
//------------------------------------------------------------------------------
|
||||||
// Collections.
|
// Collections.
|
||||||
//------------------------------------------------------------------------------
|
//------------------------------------------------------------------------------
|
||||||
|
@ -3914,6 +3929,9 @@ void mlir::python::populateIRSubmodule(py::module &m) {
|
||||||
"get_permutation",
|
"get_permutation",
|
||||||
[](std::vector<unsigned> permutation,
|
[](std::vector<unsigned> permutation,
|
||||||
DefaultingPyMlirContext context) {
|
DefaultingPyMlirContext context) {
|
||||||
|
if (!isPermutation(permutation))
|
||||||
|
throw py::cast_error("Invalid permutation when attempting to "
|
||||||
|
"create an AffineMap");
|
||||||
MlirAffineMap affineMap = mlirAffineMapPermutationGet(
|
MlirAffineMap affineMap = mlirAffineMapPermutationGet(
|
||||||
context->get(), permutation.size(), permutation.data());
|
context->get(), permutation.size(), permutation.data());
|
||||||
return PyAffineMap(context->getRef(), affineMap);
|
return PyAffineMap(context->getRef(), affineMap);
|
||||||
|
|
|
@ -73,6 +73,12 @@ def testAffineMapGet():
|
||||||
# CHECK: Invalid expression (None?) when attempting to create an AffineMap
|
# CHECK: Invalid expression (None?) when attempting to create an AffineMap
|
||||||
print(e)
|
print(e)
|
||||||
|
|
||||||
|
try:
|
||||||
|
AffineMap.get_permutation([1, 0, 1])
|
||||||
|
except RuntimeError as e:
|
||||||
|
# CHECK: Invalid permutation when attempting to create an AffineMap
|
||||||
|
print(e)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
map3.get_submap([42])
|
map3.get_submap([42])
|
||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
|
|
Loading…
Reference in New Issue