forked from OSchip/llvm-project
[mlir][linalg] Add isPermutation helper (NFC).
Add a helper method to check if an index vector contains a permutation of its indices. Additionally, refactor applyPermutationToVector to take int64_t. Reviewed By: nicolasvasilache Differential Revision: https://reviews.llvm.org/D110135
This commit is contained in:
parent
3500e7d2b0
commit
9072f1b5f8
|
@ -29,16 +29,20 @@ class LinalgDependenceGraph;
|
|||
// General utilities
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
/// Check if `permutation` is a permutation of the range
|
||||
/// `[0, permutation.size())`.
|
||||
bool isPermutation(ArrayRef<int64_t> permutation);
|
||||
|
||||
/// Apply the permutation defined by `permutation` to `inVec`.
|
||||
/// Element `i` in `inVec` is mapped to location `j = permutation[i]`.
|
||||
/// E.g.: for an input vector `inVec = ['a', 'b', 'c']` and a permutation vector
|
||||
/// `permutation = [2, 0, 1]`, this function leaves `inVec = ['c', 'a', 'b']`.
|
||||
template <typename T, unsigned N>
|
||||
void applyPermutationToVector(SmallVector<T, N> &inVec,
|
||||
ArrayRef<unsigned> permutation) {
|
||||
ArrayRef<int64_t> permutation) {
|
||||
SmallVector<T, N> auxVec(inVec.size());
|
||||
for (unsigned i = 0; i < permutation.size(); ++i)
|
||||
auxVec[i] = inVec[permutation[i]];
|
||||
for (auto en : enumerate(permutation))
|
||||
auxVec[en.index()] = inVec[en.value()];
|
||||
inVec = auxVec;
|
||||
}
|
||||
|
||||
|
|
|
@ -367,6 +367,8 @@ mlir::linalg::tileConsumerAndFuseProducers(OpBuilder &b, LinalgOp consumerOp,
|
|||
ArrayRef<int64_t> tileInterchange) {
|
||||
assert(tileSizes.size() == tileInterchange.size() &&
|
||||
"expect the number of tile sizes and interchange dims to match");
|
||||
assert(isPermutation(tileInterchange) &&
|
||||
"expect tile interchange is a permutation");
|
||||
|
||||
// Create an empty tile loop nest.
|
||||
TileLoopNest tileLoopNest(consumerOp);
|
||||
|
@ -375,9 +377,7 @@ mlir::linalg::tileConsumerAndFuseProducers(OpBuilder &b, LinalgOp consumerOp,
|
|||
// inner reduction dimensions.
|
||||
SmallVector<StringAttr> iterTypes =
|
||||
llvm::to_vector<6>(consumerOp.iterator_types().getAsRange<StringAttr>());
|
||||
applyPermutationToVector(
|
||||
iterTypes,
|
||||
SmallVector<unsigned>(tileInterchange.begin(), tileInterchange.end()));
|
||||
applyPermutationToVector(iterTypes, tileInterchange);
|
||||
auto *it = find_if(iterTypes, [&](StringAttr iterType) {
|
||||
return !isParallelIterator(iterType);
|
||||
});
|
||||
|
@ -459,14 +459,10 @@ struct LinalgTileAndFuseTensorOps
|
|||
tileInterchange.begin() +
|
||||
rootOp.getNumLoops());
|
||||
|
||||
// As a tiling can only tile a loop dimension once, `rootInterchange` has to
|
||||
// be a permutation of the `rootOp` loop dimensions.
|
||||
SmallVector<AffineExpr> rootInterchangeExprs;
|
||||
transform(rootInterchange, std::back_inserter(rootInterchangeExprs),
|
||||
[&](int64_t dim) { return b.getAffineDimExpr(dim); });
|
||||
AffineMap rootInterchangeMap = AffineMap::get(
|
||||
rootOp.getNumLoops(), 0, rootInterchangeExprs, funcOp.getContext());
|
||||
if (!rootInterchangeMap.isPermutation())
|
||||
// Check `rootInterchange` is a permutation of the `rootOp` loop dimensions.
|
||||
// It has to be a permutation since the tiling cannot tile the same loop
|
||||
// dimension multiple times.
|
||||
if (!isPermutation(rootInterchange))
|
||||
return notifyFailure(
|
||||
"expect the tile interchange permutes the root loops");
|
||||
|
||||
|
|
|
@ -69,7 +69,9 @@ void mlir::linalg::interchangeGenericOp(PatternRewriter &rewriter,
|
|||
ArrayRef<Attribute> itTypes = genericOp.iterator_types().getValue();
|
||||
SmallVector<Attribute, 4> itTypesVector;
|
||||
llvm::append_range(itTypesVector, itTypes);
|
||||
applyPermutationToVector(itTypesVector, interchangeVector);
|
||||
SmallVector<int64_t> permutation(interchangeVector.begin(),
|
||||
interchangeVector.end());
|
||||
applyPermutationToVector(itTypesVector, permutation);
|
||||
genericOp->setAttr(getIteratorTypesAttrName(),
|
||||
ArrayAttr::get(context, itTypesVector));
|
||||
|
||||
|
|
|
@ -206,8 +206,10 @@ tileLinalgOpImpl(OpBuilder &b, LinalgOp op, ValueRange tileSizes,
|
|||
invPermutationMap = inversePermutation(
|
||||
AffineMap::getPermutationMap(interchangeVector, b.getContext()));
|
||||
assert(invPermutationMap);
|
||||
applyPermutationToVector(loopRanges, interchangeVector);
|
||||
applyPermutationToVector(iteratorTypes, interchangeVector);
|
||||
SmallVector<int64_t> permutation(interchangeVector.begin(),
|
||||
interchangeVector.end());
|
||||
applyPermutationToVector(loopRanges, permutation);
|
||||
applyPermutationToVector(iteratorTypes, permutation);
|
||||
}
|
||||
|
||||
// 2. Create the tiled loops.
|
||||
|
|
|
@ -138,6 +138,19 @@ static void unpackRanges(ArrayRef<Range> ranges, SmallVectorImpl<Value> &lbs,
|
|||
namespace mlir {
|
||||
namespace linalg {
|
||||
|
||||
bool isPermutation(ArrayRef<int64_t> permutation) {
|
||||
// Count the number of appearances for all indices.
|
||||
SmallVector<int64_t> indexCounts(permutation.size(), 0);
|
||||
for (auto index : permutation) {
|
||||
// Exit if the index is out-of-range.
|
||||
if (index < 0 || index >= static_cast<int64_t>(permutation.size()))
|
||||
return false;
|
||||
indexCounts[index]++;
|
||||
}
|
||||
// Return true if all indices appear once.
|
||||
return count(indexCounts, 1) == static_cast<int64_t>(permutation.size());
|
||||
}
|
||||
|
||||
/// Helper function that creates a memref::DimOp or tensor::DimOp depending on
|
||||
/// the type of `source`.
|
||||
Value createOrFoldDimOp(OpBuilder &b, Location loc, Value source, int64_t dim) {
|
||||
|
|
Loading…
Reference in New Issue