[mlir][nfc] Expose linalg tiling helpers.

Differential Revision: https://reviews.llvm.org/D119330
This commit is contained in:
Alexander Belyaev 2022-02-09 15:20:04 +01:00
parent fd0417a3cf
commit c962038914
2 changed files with 76 additions and 67 deletions

View File

@ -481,6 +481,75 @@ private:
using TileSizeComputationFunction = using TileSizeComputationFunction =
std::function<SmallVector<Value, 4>(OpBuilder &, Operation *)>; std::function<SmallVector<Value, 4>(OpBuilder &, Operation *)>;
/// Creates a number of ranges equal to the number of non-zero in `tileSizes`.
/// One for each loop of the LinalgOp that is tiled. The `tileSizes` argument
/// has one entry per surrounding loop. It uses zero as the convention that a
/// particular loop is not tiled. This convention simplifies implementations by
/// avoiding affine map manipulations.
/// The returned ranges correspond to the loop ranges, in the proper order, that
/// are tiled and for which new loops will be created. Also the function returns
/// a map from loop indices of the LinalgOp to the corresponding non-empty range
/// indices of newly created loops.
using LoopIndexToRangeIndexMap = DenseMap<int, int>;
std::tuple<SmallVector<Range, 4>, LoopIndexToRangeIndexMap>
makeTiledLoopRanges(RewriterBase &b, Location loc, AffineMap map,
ValueRange allShapeSizes, ValueRange allTileSizes);
/// All indices returned by IndexOp should be invariant with respect to tiling.
/// Therefore, if an operation is tiled, we have to transform the indices
/// accordingly, i.e. offset them by the values of the corresponding induction
/// variables that are captured implicitly in the body of the op.
///
/// Example. `linalg.generic` before tiling:
///
/// #id_2d = (i, j) -> (i, j)
/// #pointwise_2d_trait = {
/// indexing_maps = [#id_2d, #id_2d],
/// iterator_types = ["parallel", "parallel"]
/// }
/// linalg.generic #pointwise_2d_trait %operand, %result {
/// ^bb0(%operand_in: f32, %result_in: f32):
/// %i = linalg.index 0 : index
/// %j = linalg.index 1 : index
/// <some operations that use %i, %j>
/// }: memref<50x100xf32>, memref<50x100xf32>
///
/// After tiling pass with tiles sizes 10 and 25:
///
/// #strided = (i, j)[s0, s1, s2] -> (i * s1 + s0 + j * s2)
///
/// %c1 = arith.constant 1 : index
/// %c0 = arith.constant 0 : index
/// %c25 = arith.constant 25 : index
/// %c10 = arith.constant 10 : index
/// operand_dim_0 = dim %operand, 0 : memref<50x100xf32>
/// operand_dim_1 = dim %operand, 1 : memref<50x100xf32>
/// scf.for %k = %c0 to operand_dim_0 step %c10 {
/// scf.for %l = %c0 to operand_dim_1 step %c25 {
/// %4 = std.subview %operand[%k, %l][%c10, %c25][%c1, %c1]
/// : memref<50x100xf32> to memref<?x?xf32, #strided>
/// %5 = std.subview %result[%k, %l][%c10, %c25][%c1, %c1]
/// : memref<50x100xf32> to memref<?x?xf32, #strided>
/// linalg.generic pointwise_2d_trait %4, %5 {
/// ^bb0(%operand_in: f32, %result_in: f32):
/// %i = linalg.index 0 : index
/// %j = linalg.index 1 : index
/// // Indices `k` and `l` are implicitly captured in the body.
/// %transformed_i = arith.addi %i, %k : index // index `i` is offset by
/// %k %transformed_j = arith.addi %j, %l : index // index `j` is offset
/// by %l
/// // Every use of %i, %j is replaced with %transformed_i, %transformed_j
/// <some operations that use %transformed_i, %transformed_j>
/// }: memref<?x?xf32, #strided>, memref<?x?xf32, #strided>
/// }
/// }
///
/// TODO: Investigate whether mixing implicit and explicit indices
/// does not lead to losing information.
void transformIndexOps(RewriterBase &b, LinalgOp op,
SmallVectorImpl<Value> &ivs,
const LoopIndexToRangeIndexMap &loopIndexToRangeIndex);
/// Callback returning the padding value to use for a given OpOperand or failure /// Callback returning the padding value to use for a given OpOperand or failure
/// for no padding. This should be a function of both the operation and the /// for no padding. This should be a function of both the operation and the
/// operand type. /// operand type.

View File

@ -39,20 +39,10 @@ static bool isZero(Value v) {
return false; return false;
} }
using LoopIndexToRangeIndexMap = DenseMap<int, int>; std::tuple<SmallVector<Range, 4>, LoopIndexToRangeIndexMap>
mlir::linalg::makeTiledLoopRanges(RewriterBase &b, Location loc, AffineMap map,
// Creates a number of ranges equal to the number of non-zero in `tileSizes`. ValueRange allShapeSizes,
// One for each loop of the LinalgOp that is tiled. The `tileSizes` argument has ValueRange allTileSizes) {
// one entry per surrounding loop. It uses zero as the convention that a
// particular loop is not tiled. This convention simplifies implementations by
// avoiding affine map manipulations.
// The returned ranges correspond to the loop ranges, in the proper order, that
// are tiled and for which new loops will be created. Also the function returns
// a map from loop indices of the LinalgOp to the corresponding non-empty range
// indices of newly created loops.
static std::tuple<SmallVector<Range, 4>, LoopIndexToRangeIndexMap>
makeTiledLoopRanges(RewriterBase &b, Location loc, AffineMap map,
ValueRange allShapeSizes, ValueRange allTileSizes) {
assert(allTileSizes.size() == map.getNumResults()); assert(allTileSizes.size() == map.getNumResults());
// Apply `map` to get shape sizes in loop order. // Apply `map` to get shape sizes in loop order.
auto shapeSizes = applyMapToValues(b, loc, map, allShapeSizes); auto shapeSizes = applyMapToValues(b, loc, map, allShapeSizes);
@ -78,59 +68,9 @@ makeTiledLoopRanges(RewriterBase &b, Location loc, AffineMap map,
return std::make_tuple(res, loopIndexToRangeIndex); return std::make_tuple(res, loopIndexToRangeIndex);
} }
// All indices returned by IndexOp should be invariant with respect to tiling. void mlir::linalg::transformIndexOps(
// Therefore, if an operation is tiled, we have to transform the indices RewriterBase &b, LinalgOp op, SmallVectorImpl<Value> &ivs,
// accordingly, i.e. offset them by the values of the corresponding induction const LoopIndexToRangeIndexMap &loopIndexToRangeIndex) {
// variables that are captured implicitly in the body of the op.
//
// Example. `linalg.generic` before tiling:
//
// #id_2d = (i, j) -> (i, j)
// #pointwise_2d_trait = {
// indexing_maps = [#id_2d, #id_2d],
// iterator_types = ["parallel", "parallel"]
// }
// linalg.generic #pointwise_2d_trait %operand, %result {
// ^bb0(%operand_in: f32, %result_in: f32):
// %i = linalg.index 0 : index
// %j = linalg.index 1 : index
// <some operations that use %i, %j>
// }: memref<50x100xf32>, memref<50x100xf32>
//
// After tiling pass with tiles sizes 10 and 25:
//
// #strided = (i, j)[s0, s1, s2] -> (i * s1 + s0 + j * s2)
//
// %c1 = arith.constant 1 : index
// %c0 = arith.constant 0 : index
// %c25 = arith.constant 25 : index
// %c10 = arith.constant 10 : index
// operand_dim_0 = dim %operand, 0 : memref<50x100xf32>
// operand_dim_1 = dim %operand, 1 : memref<50x100xf32>
// scf.for %k = %c0 to operand_dim_0 step %c10 {
// scf.for %l = %c0 to operand_dim_1 step %c25 {
// %4 = std.subview %operand[%k, %l][%c10, %c25][%c1, %c1]
// : memref<50x100xf32> to memref<?x?xf32, #strided>
// %5 = std.subview %result[%k, %l][%c10, %c25][%c1, %c1]
// : memref<50x100xf32> to memref<?x?xf32, #strided>
// linalg.generic pointwise_2d_trait %4, %5 {
// ^bb0(%operand_in: f32, %result_in: f32):
// %i = linalg.index 0 : index
// %j = linalg.index 1 : index
// // Indices `k` and `l` are implicitly captured in the body.
// %transformed_i = arith.addi %i, %k : index // index `i` is offset by %k
// %transformed_j = arith.addi %j, %l : index // index `j` is offset by %l
// // Every use of %i, %j is replaced with %transformed_i, %transformed_j
// <some operations that use %transformed_i, %transformed_j>
// }: memref<?x?xf32, #strided>, memref<?x?xf32, #strided>
// }
// }
//
// TODO: Investigate whether mixing implicit and explicit indices
// does not lead to losing information.
static void
transformIndexOps(RewriterBase &b, LinalgOp op, SmallVectorImpl<Value> &ivs,
const LoopIndexToRangeIndexMap &loopIndexToRangeIndex) {
SmallVector<Value> allIvs(op.getNumLoops(), nullptr); SmallVector<Value> allIvs(op.getNumLoops(), nullptr);
for (auto &en : enumerate(allIvs)) { for (auto &en : enumerate(allIvs)) {
auto rangeIndex = loopIndexToRangeIndex.find(en.index()); auto rangeIndex = loopIndexToRangeIndex.find(en.index());