[mlir][Linalg] Add Utility method to get loop ranges for a LinalgOp.

Also refactor the getViewSizes method to work on LinalgOp instead of
being a templated version. Keeping the templated version for
compatibility.

Differential Revision: https://reviews.llvm.org/D87303
This commit is contained in:
MaheshRavishankar 2020-09-09 22:20:12 -07:00
parent 6afb279100
commit a7b2977aa6
2 changed files with 56 additions and 32 deletions

View File

@ -94,41 +94,21 @@ Operation *fuseTensorOps(PatternRewriter &rewriter, Operation *consumer,
unsigned consumerIdx,
OperationFolder *folder = nullptr);
/// Returns the linearized list of all view dimensions in a linalgOp. Applying
/// Returns the linearized list of all view dimensions in a `linalgOp`. Applying
/// the inverse, concatenated loopToOperandRangeMaps to this list allows the
/// derivation of loop ranges for any linalgOp.
template <typename ConcreteOp>
SmallVector<Value, 8> getViewSizes(OpBuilder &builder, ConcreteOp linalgOp) {
auto loc = linalgOp.getLoc();
SmallVector<Value, 8> res;
SmallVector<unsigned, 4> ranks;
for (auto v : linalgOp.getInputsAndOutputBuffers()) {
MemRefType t = v.getType().template cast<MemRefType>();
ranks.push_back(t.getRank());
for (unsigned i = 0; i < t.getRank(); ++i)
res.push_back(builder.create<DimOp>(loc, v, i));
SmallVector<Value, 8> getViewSizes(OpBuilder &builder, LinalgOp linalgOp);
template <typename ConcreteOpTy>
SmallVector<Value, 8> getViewSizes(OpBuilder &builder, ConcreteOpTy linalgOp) {
return getViewSizes(builder, cast<linalg::LinalgOp>(linalgOp.getOperation()));
}
auto attr = linalgOp.template getAttrOfType<IntegerAttr>("symbol_source");
if (attr) {
// Find the correct position for inserting values for symbols.
unsigned numSymb = ranks[attr.getInt()], symbolsPos = 0;
for (unsigned idx = 0; idx < attr.getInt(); idx++)
symbolsPos += ranks[idx];
// Append the end of the value list that corresponds to the
// values mapping to symbols. Since inside concatinated map symbols are
// repeated we have to repeat the sizes as well.
// Reserve is mandatory to avoid a potential undefined behavior with
// pushing back to smallvector from itself.
res.reserve(res.size() + ranks.size() * numSymb);
for (unsigned idx = 0, s = ranks.size(); idx < s; ++idx)
for (unsigned idx2 = 0; idx2 < numSymb; ++idx2)
res.push_back(res[symbolsPos + idx2]);
}
return res;
}
/// Returns the loop ranges of the `linalgOp`. Applies the inverse of the
/// concatenated indexing maps to the result of `getViewSizes`. Returns None if
/// the bounds computation fails.
Optional<SmallVector<Value, 4>>
getLoopRanges(OpBuilder &builder, LinalgOp linalgOp,
OperationFolder *folder = nullptr);
/// Returns the values obtained by applying `map` to the list of values.
/// When non-null, the optional pointer `folder` is used to call into the

View File

@ -147,6 +147,50 @@ static void unpackRanges(ArrayRef<SubViewOp::Range> ranges,
namespace mlir {
namespace linalg {
/// Return the linearized list of all view dimensions in a linalgOp.
SmallVector<Value, 8> getViewSizes(OpBuilder &builder, LinalgOp linalgOp) {
auto loc = linalgOp.getLoc();
SmallVector<Value, 8> res;
SmallVector<unsigned, 4> ranks;
for (auto v : linalgOp.getInputsAndOutputBuffers()) {
MemRefType t = v.getType().template cast<MemRefType>();
ranks.push_back(t.getRank());
for (unsigned i = 0; i < t.getRank(); ++i)
res.push_back(builder.create<DimOp>(loc, v, i));
}
auto attr = linalgOp.template getAttrOfType<IntegerAttr>("symbol_source");
if (attr) {
// Find the correct position for inserting values for symbols.
unsigned numSymb = ranks[attr.getInt()], symbolsPos = 0;
for (unsigned idx = 0; idx < attr.getInt(); idx++)
symbolsPos += ranks[idx];
// Append the end of the value list that corresponds to the
// values mapping to symbols. Since inside concatinated map symbols are
// repeated we have to repeat the sizes as well.
// Reserve is mandatory to avoid a potential undefined behavior with
// pushing back to smallvector from itself.
res.reserve(res.size() + ranks.size() * numSymb);
for (unsigned idx = 0, s = ranks.size(); idx < s; ++idx)
for (unsigned idx2 = 0; idx2 < numSymb; ++idx2)
res.push_back(res[symbolsPos + idx2]);
}
return res;
}
Optional<SmallVector<Value, 4>>
getLoopRanges(OpBuilder &builder, LinalgOp linalgOp, OperationFolder *folder) {
SmallVector<Value, 8> viewSizes = getViewSizes(builder, linalgOp);
AffineMap invertedMap =
inversePermutation(concatAffineMaps(linalgOp.getIndexingMaps()));
if (!invertedMap)
return {};
return applyMapToValues(builder, linalgOp.getLoc(), invertedMap, viewSizes,
folder);
}
/// Specialization to build an scf "for" nest.
template <>
void GenerateLoopNest<scf::ForOp>::doit(