[mlir][linalg] Add isPermutation helper (NFC).

Add a helper method to check if an index vector contains a permutation of its indices. Additionally, refactor applyPermutationToVector to take int64_t.

Reviewed By: nicolasvasilache

Differential Revision: https://reviews.llvm.org/D110135
This commit is contained in:
Tobias Gysi 2021-09-21 14:53:33 +00:00
parent 3500e7d2b0
commit 9072f1b5f8
5 changed files with 34 additions and 17 deletions

View File

@ -29,16 +29,20 @@ class LinalgDependenceGraph;
// General utilities
//===----------------------------------------------------------------------===//
/// Check if `permutation` is a permutation of the range
/// `[0, permutation.size())`.
bool isPermutation(ArrayRef<int64_t> permutation);
/// Apply the permutation defined by `permutation` to `inVec`.
/// Element `i` in `inVec` is mapped to location `j = permutation[i]`.
/// E.g.: for an input vector `inVec = ['a', 'b', 'c']` and a permutation vector
/// `permutation = [2, 0, 1]`, this function leaves `inVec = ['c', 'a', 'b']`.
template <typename T, unsigned N>
void applyPermutationToVector(SmallVector<T, N> &inVec,
ArrayRef<unsigned> permutation) {
ArrayRef<int64_t> permutation) {
SmallVector<T, N> auxVec(inVec.size());
for (unsigned i = 0; i < permutation.size(); ++i)
auxVec[i] = inVec[permutation[i]];
for (auto en : enumerate(permutation))
auxVec[en.index()] = inVec[en.value()];
inVec = auxVec;
}

View File

@ -367,6 +367,8 @@ mlir::linalg::tileConsumerAndFuseProducers(OpBuilder &b, LinalgOp consumerOp,
ArrayRef<int64_t> tileInterchange) {
assert(tileSizes.size() == tileInterchange.size() &&
"expect the number of tile sizes and interchange dims to match");
assert(isPermutation(tileInterchange) &&
"expect tile interchange is a permutation");
// Create an empty tile loop nest.
TileLoopNest tileLoopNest(consumerOp);
@ -375,9 +377,7 @@ mlir::linalg::tileConsumerAndFuseProducers(OpBuilder &b, LinalgOp consumerOp,
// inner reduction dimensions.
SmallVector<StringAttr> iterTypes =
llvm::to_vector<6>(consumerOp.iterator_types().getAsRange<StringAttr>());
applyPermutationToVector(
iterTypes,
SmallVector<unsigned>(tileInterchange.begin(), tileInterchange.end()));
applyPermutationToVector(iterTypes, tileInterchange);
auto *it = find_if(iterTypes, [&](StringAttr iterType) {
return !isParallelIterator(iterType);
});
@ -459,14 +459,10 @@ struct LinalgTileAndFuseTensorOps
tileInterchange.begin() +
rootOp.getNumLoops());
// As a tiling can only tile a loop dimension once, `rootInterchange` has to
// be a permutation of the `rootOp` loop dimensions.
SmallVector<AffineExpr> rootInterchangeExprs;
transform(rootInterchange, std::back_inserter(rootInterchangeExprs),
[&](int64_t dim) { return b.getAffineDimExpr(dim); });
AffineMap rootInterchangeMap = AffineMap::get(
rootOp.getNumLoops(), 0, rootInterchangeExprs, funcOp.getContext());
if (!rootInterchangeMap.isPermutation())
// Check `rootInterchange` is a permutation of the `rootOp` loop dimensions.
// It has to be a permutation since the tiling cannot tile the same loop
// dimension multiple times.
if (!isPermutation(rootInterchange))
return notifyFailure(
"expect the tile interchange permutes the root loops");

View File

@ -69,7 +69,9 @@ void mlir::linalg::interchangeGenericOp(PatternRewriter &rewriter,
ArrayRef<Attribute> itTypes = genericOp.iterator_types().getValue();
SmallVector<Attribute, 4> itTypesVector;
llvm::append_range(itTypesVector, itTypes);
applyPermutationToVector(itTypesVector, interchangeVector);
SmallVector<int64_t> permutation(interchangeVector.begin(),
interchangeVector.end());
applyPermutationToVector(itTypesVector, permutation);
genericOp->setAttr(getIteratorTypesAttrName(),
ArrayAttr::get(context, itTypesVector));

View File

@ -206,8 +206,10 @@ tileLinalgOpImpl(OpBuilder &b, LinalgOp op, ValueRange tileSizes,
invPermutationMap = inversePermutation(
AffineMap::getPermutationMap(interchangeVector, b.getContext()));
assert(invPermutationMap);
applyPermutationToVector(loopRanges, interchangeVector);
applyPermutationToVector(iteratorTypes, interchangeVector);
SmallVector<int64_t> permutation(interchangeVector.begin(),
interchangeVector.end());
applyPermutationToVector(loopRanges, permutation);
applyPermutationToVector(iteratorTypes, permutation);
}
// 2. Create the tiled loops.

View File

@ -138,6 +138,19 @@ static void unpackRanges(ArrayRef<Range> ranges, SmallVectorImpl<Value> &lbs,
namespace mlir {
namespace linalg {
bool isPermutation(ArrayRef<int64_t> permutation) {
// Count the number of appearances for all indices.
SmallVector<int64_t> indexCounts(permutation.size(), 0);
for (auto index : permutation) {
// Exit if the index is out-of-range.
if (index < 0 || index >= static_cast<int64_t>(permutation.size()))
return false;
indexCounts[index]++;
}
// Return true if all indices appear once.
return count(indexCounts, 1) == static_cast<int64_t>(permutation.size());
}
/// Helper function that creates a memref::DimOp or tensor::DimOp depending on
/// the type of `source`.
Value createOrFoldDimOp(OpBuilder &b, Location loc, Value source, int64_t dim) {