[mlir][Linalg] Add utility function that return static loop bounds of Linalg ops

Differential Revision: https://reviews.llvm.org/D91749
This commit is contained in:
MaheshRavishankar 2020-11-19 18:59:48 -08:00
parent b2f6630739
commit 8b525c9c19
2 changed files with 29 additions and 0 deletions

View File

@ -114,12 +114,23 @@ SmallVector<Value, 8> getShape(OpBuilder &builder, ConcreteOpTy linalgOp) {
return getShape(builder, cast<linalg::LinalgOp>(linalgOp.getOperation()));
}
/// Like `getShape`, but only returns statically-known information, without
/// generating any new IR. For each shape dimension, returns >=0 if that
/// dimension is statically known, or -1 otherwise.
SmallVector<int64_t, 8> getStaticShape(LinalgOp linalgOp);
/// Returns the loop ranges of the `linalgOp`. Applies the inverse of the
/// concatenated indexing maps to the result of `getShape`. Returns None if
/// the bounds computation fails.
Optional<SmallVector<Value, 4>> getLoopRanges(OpBuilder &builder,
LinalgOp linalgOp);
/// Returns the statically-known loop ranges of the `linalgOp`. Applies the
/// inverse of the concatenated indexing maps to the result of `getStaticShape`.
/// Returns None if inverting the concatenated indexing map fails. Returns -1
/// for non-statically-known loop ranges.
Optional<SmallVector<int64_t, 4>> getStaticLoopRanges(LinalgOp linalgOp);
/// Returns the values obtained by applying `map` to the list of values.
SmallVector<Value, 4> applyMapToValues(OpBuilder &b, Location loc,
AffineMap map, ValueRange values);

View File

@ -156,6 +156,15 @@ SmallVector<Value, 8> getShape(OpBuilder &builder, LinalgOp linalgOp) {
return res;
}
SmallVector<int64_t, 8> getStaticShape(LinalgOp linalgOp) {
SmallVector<int64_t, 8> res;
for (Value v : linalgOp.getShapedOperands()) {
auto shape = v.getType().cast<ShapedType>().getShape();
res.append(shape.begin(), shape.end());
}
return res;
}
Optional<SmallVector<Value, 4>> getLoopRanges(OpBuilder &builder,
LinalgOp linalgOp) {
SmallVector<Value, 8> viewSizes = getShape(builder, linalgOp);
@ -166,6 +175,15 @@ Optional<SmallVector<Value, 4>> getLoopRanges(OpBuilder &builder,
return applyMapToValues(builder, linalgOp.getLoc(), invertedMap, viewSizes);
}
Optional<SmallVector<int64_t, 4>> getStaticLoopRanges(LinalgOp linalgOp) {
SmallVector<int64_t, 8> viewSizes = getStaticShape(linalgOp);
AffineMap invertedMap =
inversePermutation(concatAffineMaps(linalgOp.getIndexingMaps()));
if (!invertedMap)
return {};
return invertedMap.compose(viewSizes);
}
/// Specialization to build an scf "for" nest.
template <>
void GenerateLoopNest<scf::ForOp>::doit(