forked from OSchip/llvm-project
[mlir][Linalg] Add utility function that return static loop bounds of Linalg ops
Differential Revision: https://reviews.llvm.org/D91749
This commit is contained in:
parent
b2f6630739
commit
8b525c9c19
|
@ -114,12 +114,23 @@ SmallVector<Value, 8> getShape(OpBuilder &builder, ConcreteOpTy linalgOp) {
|
|||
return getShape(builder, cast<linalg::LinalgOp>(linalgOp.getOperation()));
|
||||
}
|
||||
|
||||
/// Like `getShape`, but only returns statically-known information, without
|
||||
/// generating any new IR. For each shape dimension, returns >=0 if that
|
||||
/// dimension is statically known, or -1 otherwise.
|
||||
SmallVector<int64_t, 8> getStaticShape(LinalgOp linalgOp);
|
||||
|
||||
/// Returns the loop ranges of the `linalgOp`. Applies the inverse of the
|
||||
/// concatenated indexing maps to the result of `getShape`. Returns None if
|
||||
/// the bounds computation fails.
|
||||
Optional<SmallVector<Value, 4>> getLoopRanges(OpBuilder &builder,
|
||||
LinalgOp linalgOp);
|
||||
|
||||
/// Returns the statically-known loop ranges of the `linalgOp`. Applies the
|
||||
/// inverse of the concatenated indexing maps to the result of `getStaticShape`.
|
||||
/// Returns None if inverting the concatenated indexing map fails. Returns -1
|
||||
/// for non-statically-known loop ranges.
|
||||
Optional<SmallVector<int64_t, 4>> getStaticLoopRanges(LinalgOp linalgOp);
|
||||
|
||||
/// Returns the values obtained by applying `map` to the list of values.
|
||||
SmallVector<Value, 4> applyMapToValues(OpBuilder &b, Location loc,
|
||||
AffineMap map, ValueRange values);
|
||||
|
|
|
@ -156,6 +156,15 @@ SmallVector<Value, 8> getShape(OpBuilder &builder, LinalgOp linalgOp) {
|
|||
return res;
|
||||
}
|
||||
|
||||
SmallVector<int64_t, 8> getStaticShape(LinalgOp linalgOp) {
|
||||
SmallVector<int64_t, 8> res;
|
||||
for (Value v : linalgOp.getShapedOperands()) {
|
||||
auto shape = v.getType().cast<ShapedType>().getShape();
|
||||
res.append(shape.begin(), shape.end());
|
||||
}
|
||||
return res;
|
||||
}
|
||||
|
||||
Optional<SmallVector<Value, 4>> getLoopRanges(OpBuilder &builder,
|
||||
LinalgOp linalgOp) {
|
||||
SmallVector<Value, 8> viewSizes = getShape(builder, linalgOp);
|
||||
|
@ -166,6 +175,15 @@ Optional<SmallVector<Value, 4>> getLoopRanges(OpBuilder &builder,
|
|||
return applyMapToValues(builder, linalgOp.getLoc(), invertedMap, viewSizes);
|
||||
}
|
||||
|
||||
Optional<SmallVector<int64_t, 4>> getStaticLoopRanges(LinalgOp linalgOp) {
|
||||
SmallVector<int64_t, 8> viewSizes = getStaticShape(linalgOp);
|
||||
AffineMap invertedMap =
|
||||
inversePermutation(concatAffineMaps(linalgOp.getIndexingMaps()));
|
||||
if (!invertedMap)
|
||||
return {};
|
||||
return invertedMap.compose(viewSizes);
|
||||
}
|
||||
|
||||
/// Specialization to build an scf "for" nest.
|
||||
template <>
|
||||
void GenerateLoopNest<scf::ForOp>::doit(
|
||||
|
|
Loading…
Reference in New Issue