diff --git a/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp b/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp index b256852ced26..7fb1018fc588 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp @@ -58,31 +58,17 @@ static SmallVector permuteIvs(ArrayRef ivs, : SmallVector(ivs.begin(), ivs.end()); } -// Creates a number of ranges equal to the number of results in `map`. -// The returned ranges correspond to the loop ranges, in the proper order, for -// which new loops will be created. -static SmallVector -emitLoopRanges(OpBuilder &b, Location loc, AffineMap map, - ArrayRef allViewSizes) { - // Apply `map` to get view sizes in loop order. - auto sizes = applyMapToValues(b, loc, map, allViewSizes); - // Create a new range with the applied tile sizes. - ScopedContext scope(b, loc); - SmallVector res; - for (unsigned idx = 0, e = map.getNumResults(); idx < e; ++idx) { - res.push_back(SubViewOp::Range{std_constant_index(0), sizes[idx], - std_constant_index(1)}); - } - return res; -} - /// Creates a number of ranges equal to the number of dimensions in the `map`. -/// The function supports for now only limited number of expressions inside -/// map results. It expects a non-inverted, concatenated map and last values in -/// viewSizes will be applied to the symbols in the map. -static SmallVector -emitLoopRangesWithSymbols(OpBuilder &b, Location loc, AffineMap map, - ValueRange viewSizes) { +/// The returned ranges correspond to the loop ranges, in the proper order, for +/// which new loops will be created. +/// The function supports only maps that are invertible and have results of type +/// DimExpr or (DimExpr + DimExpr - SymbolExpr floordiv ConstExpr). +/// It expects a non-inverted, concatenated map and last values in +/// allViewSizes will be applied to the symbols in the map if it contains any. +static SmallVector emitLoopRanges(OpBuilder &b, + Location loc, + AffineMap map, + ValueRange viewSizes) { unsigned numDims = map.getNumDims(), numRes = map.getNumResults(); unsigned numSym = map.getNumSymbols(); assert(viewSizes.size() == numRes + numSym && @@ -537,23 +523,8 @@ Optional linalgOpToLoopsImpl(Operation *op, OpBuilder &builder) { llvm::map_range(mapsRange, [](AffineMapAttr a) { return a.getValue(); })); SmallVector sizes = getViewSizes(builder, linalgOp); AffineMap map = concatAffineMaps(maps); - SmallVector loopRanges; - - if (map.getNumSymbols()) { - loopRanges = emitLoopRangesWithSymbols(scope.getBuilderRef(), - scope.getLocation(), map, sizes); - } else { - AffineMap invertedMap = inversePermutation(map); - if (!invertedMap) - return {}; - if (invertedMap.isEmpty()) { - emitScalarImplementation({}, linalgOp); - return LinalgLoops(); - } - - loopRanges = emitLoopRanges(scope.getBuilderRef(), scope.getLocation(), - invertedMap, sizes); - } + auto loopRanges = emitLoopRanges(scope.getBuilderRef(), scope.getLocation(), + map, getViewSizes(builder, linalgOp)); SmallVector allIvs; GenerateLoopNest::doit( loopRanges, linalgOp.iterator_types().getValue(), [&](ValueRange ivs) {