diff --git a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h index beef1a70096e..c0c59bda1894 100644 --- a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h +++ b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h @@ -94,42 +94,22 @@ Operation *fuseTensorOps(PatternRewriter &rewriter, Operation *consumer, unsigned consumerIdx, OperationFolder *folder = nullptr); -/// Returns the linearized list of all view dimensions in a linalgOp. Applying +/// Returns the linearized list of all view dimensions in a `linalgOp`. Applying /// the inverse, concatenated loopToOperandRangeMaps to this list allows the /// derivation of loop ranges for any linalgOp. -template -SmallVector getViewSizes(OpBuilder &builder, ConcreteOp linalgOp) { - auto loc = linalgOp.getLoc(); - SmallVector res; - SmallVector ranks; - for (auto v : linalgOp.getInputsAndOutputBuffers()) { - MemRefType t = v.getType().template cast(); - ranks.push_back(t.getRank()); - for (unsigned i = 0; i < t.getRank(); ++i) - res.push_back(builder.create(loc, v, i)); - } - - auto attr = linalgOp.template getAttrOfType("symbol_source"); - if (attr) { - // Find the correct position for inserting values for symbols. - unsigned numSymb = ranks[attr.getInt()], symbolsPos = 0; - for (unsigned idx = 0; idx < attr.getInt(); idx++) - symbolsPos += ranks[idx]; - - // Append the end of the value list that corresponds to the - // values mapping to symbols. Since inside concatinated map symbols are - // repeated we have to repeat the sizes as well. - - // Reserve is mandatory to avoid a potential undefined behavior with - // pushing back to smallvector from itself. - res.reserve(res.size() + ranks.size() * numSymb); - for (unsigned idx = 0, s = ranks.size(); idx < s; ++idx) - for (unsigned idx2 = 0; idx2 < numSymb; ++idx2) - res.push_back(res[symbolsPos + idx2]); - } - return res; +SmallVector getViewSizes(OpBuilder &builder, LinalgOp linalgOp); +template +SmallVector getViewSizes(OpBuilder &builder, ConcreteOpTy linalgOp) { + return getViewSizes(builder, cast(linalgOp.getOperation())); } +/// Returns the loop ranges of the `linalgOp`. Applies the inverse of the +/// concatenated indexing maps to the result of `getViewSizes`. Returns None if +/// the bounds computation fails. +Optional> +getLoopRanges(OpBuilder &builder, LinalgOp linalgOp, + OperationFolder *folder = nullptr); + /// Returns the values obtained by applying `map` to the list of values. /// When non-null, the optional pointer `folder` is used to call into the /// `createAndFold` builder method. If `folder` is null, the regular `create` diff --git a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp index cf14555aa63f..585b00189964 100644 --- a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp +++ b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp @@ -147,6 +147,50 @@ static void unpackRanges(ArrayRef ranges, namespace mlir { namespace linalg { +/// Return the linearized list of all view dimensions in a linalgOp. +SmallVector getViewSizes(OpBuilder &builder, LinalgOp linalgOp) { + auto loc = linalgOp.getLoc(); + SmallVector res; + SmallVector ranks; + for (auto v : linalgOp.getInputsAndOutputBuffers()) { + MemRefType t = v.getType().template cast(); + ranks.push_back(t.getRank()); + for (unsigned i = 0; i < t.getRank(); ++i) + res.push_back(builder.create(loc, v, i)); + } + + auto attr = linalgOp.template getAttrOfType("symbol_source"); + if (attr) { + // Find the correct position for inserting values for symbols. + unsigned numSymb = ranks[attr.getInt()], symbolsPos = 0; + for (unsigned idx = 0; idx < attr.getInt(); idx++) + symbolsPos += ranks[idx]; + + // Append the end of the value list that corresponds to the + // values mapping to symbols. Since inside concatinated map symbols are + // repeated we have to repeat the sizes as well. + + // Reserve is mandatory to avoid a potential undefined behavior with + // pushing back to smallvector from itself. + res.reserve(res.size() + ranks.size() * numSymb); + for (unsigned idx = 0, s = ranks.size(); idx < s; ++idx) + for (unsigned idx2 = 0; idx2 < numSymb; ++idx2) + res.push_back(res[symbolsPos + idx2]); + } + return res; +} + +Optional> +getLoopRanges(OpBuilder &builder, LinalgOp linalgOp, OperationFolder *folder) { + SmallVector viewSizes = getViewSizes(builder, linalgOp); + AffineMap invertedMap = + inversePermutation(concatAffineMaps(linalgOp.getIndexingMaps())); + if (!invertedMap) + return {}; + return applyMapToValues(builder, linalgOp.getLoc(), invertedMap, viewSizes, + folder); +} + /// Specialization to build an scf "for" nest. template <> void GenerateLoopNest::doit(