forked from OSchip/llvm-project
[mlir][nfc] Expose linalg tiling helpers.
Differential Revision: https://reviews.llvm.org/D119330
This commit is contained in:
parent fd0417a3cf
commit c962038914
@@ -481,6 +481,75 @@ private:
 using TileSizeComputationFunction =
     std::function<SmallVector<Value, 4>(OpBuilder &, Operation *)>;
 
+/// Creates a number of ranges equal to the number of non-zero entries in
+/// `tileSizes`, one for each loop of the LinalgOp that is tiled. The
+/// `tileSizes` argument has one entry per surrounding loop. It uses zero as
+/// the convention that a particular loop is not tiled. This convention
+/// simplifies implementations by avoiding affine map manipulations.
+/// The returned ranges correspond to the loop ranges, in the proper order, that
+/// are tiled and for which new loops will be created. The function also returns
+/// a map from loop indices of the LinalgOp to the corresponding non-empty range
+/// indices of newly created loops.
+using LoopIndexToRangeIndexMap = DenseMap<int, int>;
+std::tuple<SmallVector<Range, 4>, LoopIndexToRangeIndexMap>
+makeTiledLoopRanges(RewriterBase &b, Location loc, AffineMap map,
+                    ValueRange allShapeSizes, ValueRange allTileSizes);
+
+/// All indices returned by IndexOp should be invariant with respect to tiling.
+/// Therefore, if an operation is tiled, we have to transform the indices
+/// accordingly, i.e. offset them by the values of the corresponding induction
+/// variables that are captured implicitly in the body of the op.
+///
+/// Example. `linalg.generic` before tiling:
+///
+///   #id_2d = (i, j) -> (i, j)
+///   #pointwise_2d_trait = {
+///     indexing_maps = [#id_2d, #id_2d],
+///     iterator_types = ["parallel", "parallel"]
+///   }
+///   linalg.generic #pointwise_2d_trait %operand, %result {
+///     ^bb0(%operand_in: f32, %result_in: f32):
+///       %i = linalg.index 0 : index
+///       %j = linalg.index 1 : index
+///       <some operations that use %i, %j>
+///   }: memref<50x100xf32>, memref<50x100xf32>
+///
+/// After the tiling pass with tile sizes 10 and 25:
+///
+///   #strided = (i, j)[s0, s1, s2] -> (i * s1 + s0 + j * s2)
+///
+///   %c1 = arith.constant 1 : index
+///   %c0 = arith.constant 0 : index
+///   %c25 = arith.constant 25 : index
+///   %c10 = arith.constant 10 : index
+///   operand_dim_0 = dim %operand, 0 : memref<50x100xf32>
+///   operand_dim_1 = dim %operand, 1 : memref<50x100xf32>
+///   scf.for %k = %c0 to operand_dim_0 step %c10 {
+///     scf.for %l = %c0 to operand_dim_1 step %c25 {
+///       %4 = std.subview %operand[%k, %l][%c10, %c25][%c1, %c1]
+///         : memref<50x100xf32> to memref<?x?xf32, #strided>
+///       %5 = std.subview %result[%k, %l][%c10, %c25][%c1, %c1]
+///         : memref<50x100xf32> to memref<?x?xf32, #strided>
+///       linalg.generic #pointwise_2d_trait %4, %5 {
+///       ^bb0(%operand_in: f32, %result_in: f32):
+///         %i = linalg.index 0 : index
+///         %j = linalg.index 1 : index
+///         // Indices `k` and `l` are implicitly captured in the body.
+///         %transformed_i = arith.addi %i, %k : index // index `i` is offset by %k
+///         %transformed_j = arith.addi %j, %l : index // index `j` is offset by %l
+///         // Every use of %i, %j is replaced with %transformed_i, %transformed_j
+///         <some operations that use %transformed_i, %transformed_j>
+///       }: memref<?x?xf32, #strided>, memref<?x?xf32, #strided>
+///     }
+///   }
+///
+/// TODO: Investigate whether mixing implicit and explicit indices
+/// does not lead to losing information.
+void transformIndexOps(RewriterBase &b, LinalgOp op,
+                       SmallVectorImpl<Value> &ivs,
+                       const LoopIndexToRangeIndexMap &loopIndexToRangeIndex);
+
 /// Callback returning the padding value to use for a given OpOperand or failure
 /// for no padding. This should be a function of both the operation and the
 /// operand type.
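A minimal sketch of a TileSizeComputationFunction that follows the zero-means-untiled convention documented above. The lambda, the chosen tile sizes, and the include paths are illustrative assumptions, not code from this patch:

  // Hypothetical example, not from this patch: a TileSizeComputationFunction
  // that tiles the first two loops by 10 and 25 and leaves any remaining
  // loops untiled (a zero entry means "do not tile this loop").
  #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"        // path assumed
  #include "mlir/Dialect/Linalg/Transforms/Transforms.h"    // path assumed

  using namespace mlir;

  linalg::TileSizeComputationFunction fixedTileSizes =
      [](OpBuilder &b, Operation *op) -> SmallVector<Value, 4> {
    Location loc = op->getLoc();
    auto linalgOp = cast<linalg::LinalgOp>(op);
    SmallVector<Value, 4> sizes;
    for (unsigned i = 0, e = linalgOp.getNumLoops(); i < e; ++i) {
      // One entry per loop of the LinalgOp, zero meaning "not tiled".
      int64_t tileSize = i == 0 ? 10 : i == 1 ? 25 : 0;
      sizes.push_back(b.create<arith::ConstantIndexOp>(loc, tileSize));
    }
    return sizes;
  };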
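And a sketch of how a rewrite might chain the two newly exposed helpers: compute the tiled loop ranges, materialize one loop per range, then offset the op's linalg.index values. Only the two mlir::linalg:: calls and their signatures come from this patch; the driver function name, the scf.for construction, and the reading of Range members as lower bound / upper bound / step are assumptions for illustration:

  #include <tuple>
  #include "mlir/Dialect/Linalg/Utils/Utils.h"   // header path assumed
  #include "mlir/Dialect/SCF/SCF.h"              // header path assumed

  using namespace mlir;

  // Hypothetical driver, not the upstream tiling pattern.
  static LogicalResult tileWithExposedHelpers(RewriterBase &rewriter,
                                              linalg::LinalgOp op,
                                              ValueRange allShapeSizes,
                                              ValueRange allTileSizes) {
    Location loc = op.getLoc();

    // Ranges for the tiled loops plus the loop-index -> range-index map.
    // E.g. for three loops and tile sizes {10, 0, 25}, loop 1 stays untiled,
    // two ranges come back, and the map is {0 -> 0, 2 -> 1}.
    SmallVector<Range, 4> loopRanges;
    linalg::LoopIndexToRangeIndexMap loopIndexToRangeIndex;
    std::tie(loopRanges, loopIndexToRangeIndex) = linalg::makeTiledLoopRanges(
        rewriter, loc, op.getShapesToLoopsMap(), allShapeSizes, allTileSizes);

    // Materialize one scf.for per returned range and collect its induction
    // variable; Range members are read here as lb / ub / step, matching the
    // scf.for loops in the documentation example above (simplified).
    SmallVector<Value> ivs;
    for (const Range &r : loopRanges) {
      auto forOp =
          rewriter.create<scf::ForOp>(loc, r.offset, r.size, r.stride);
      ivs.push_back(forOp.getInductionVar());
      rewriter.setInsertionPointToStart(forOp.getBody());
    }

    // Offset the op's linalg.index values by the new induction variables.
    linalg::transformIndexOps(rewriter, op, ivs, loopIndexToRangeIndex);
    return success();
  }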
@@ -39,20 +39,10 @@ static bool isZero(Value v) {
   return false;
 }
 
-using LoopIndexToRangeIndexMap = DenseMap<int, int>;
-
-// Creates a number of ranges equal to the number of non-zero in `tileSizes`.
-// One for each loop of the LinalgOp that is tiled. The `tileSizes` argument has
-// one entry per surrounding loop. It uses zero as the convention that a
-// particular loop is not tiled. This convention simplifies implementations by
-// avoiding affine map manipulations.
-// The returned ranges correspond to the loop ranges, in the proper order, that
-// are tiled and for which new loops will be created. Also the function returns
-// a map from loop indices of the LinalgOp to the corresponding non-empty range
-// indices of newly created loops.
-static std::tuple<SmallVector<Range, 4>, LoopIndexToRangeIndexMap>
-makeTiledLoopRanges(RewriterBase &b, Location loc, AffineMap map,
-                    ValueRange allShapeSizes, ValueRange allTileSizes) {
+std::tuple<SmallVector<Range, 4>, LoopIndexToRangeIndexMap>
+mlir::linalg::makeTiledLoopRanges(RewriterBase &b, Location loc, AffineMap map,
+                                  ValueRange allShapeSizes,
+                                  ValueRange allTileSizes) {
   assert(allTileSizes.size() == map.getNumResults());
   // Apply `map` to get shape sizes in loop order.
   auto shapeSizes = applyMapToValues(b, loc, map, allShapeSizes);
@@ -78,59 +68,9 @@ makeTiledLoopRanges(RewriterBase &b, Location loc, AffineMap map,
   return std::make_tuple(res, loopIndexToRangeIndex);
 }
 
-// All indices returned by IndexOp should be invariant with respect to tiling.
-// Therefore, if an operation is tiled, we have to transform the indices
-// accordingly, i.e. offset them by the values of the corresponding induction
-// variables that are captured implicitly in the body of the op.
-//
-// Example. `linalg.generic` before tiling:
-//
-//   #id_2d = (i, j) -> (i, j)
-//   #pointwise_2d_trait = {
-//     indexing_maps = [#id_2d, #id_2d],
-//     iterator_types = ["parallel", "parallel"]
-//   }
-//   linalg.generic #pointwise_2d_trait %operand, %result {
-//     ^bb0(%operand_in: f32, %result_in: f32):
-//       %i = linalg.index 0 : index
-//       %j = linalg.index 1 : index
-//       <some operations that use %i, %j>
-//   }: memref<50x100xf32>, memref<50x100xf32>
-//
-// After tiling pass with tiles sizes 10 and 25:
-//
-//   #strided = (i, j)[s0, s1, s2] -> (i * s1 + s0 + j * s2)
-//
-//   %c1 = arith.constant 1 : index
-//   %c0 = arith.constant 0 : index
-//   %c25 = arith.constant 25 : index
-//   %c10 = arith.constant 10 : index
-//   operand_dim_0 = dim %operand, 0 : memref<50x100xf32>
-//   operand_dim_1 = dim %operand, 1 : memref<50x100xf32>
-//   scf.for %k = %c0 to operand_dim_0 step %c10 {
-//     scf.for %l = %c0 to operand_dim_1 step %c25 {
-//       %4 = std.subview %operand[%k, %l][%c10, %c25][%c1, %c1]
-//         : memref<50x100xf32> to memref<?x?xf32, #strided>
-//       %5 = std.subview %result[%k, %l][%c10, %c25][%c1, %c1]
-//         : memref<50x100xf32> to memref<?x?xf32, #strided>
-//       linalg.generic pointwise_2d_trait %4, %5 {
-//       ^bb0(%operand_in: f32, %result_in: f32):
-//         %i = linalg.index 0 : index
-//         %j = linalg.index 1 : index
-//         // Indices `k` and `l` are implicitly captured in the body.
-//         %transformed_i = arith.addi %i, %k : index // index `i` is offset by %k
-//         %transformed_j = arith.addi %j, %l : index // index `j` is offset by %l
-//         // Every use of %i, %j is replaced with %transformed_i, %transformed_j
-//         <some operations that use %transformed_i, %transformed_j>
-//       }: memref<?x?xf32, #strided>, memref<?x?xf32, #strided>
-//     }
-//   }
-//
-// TODO: Investigate whether mixing implicit and explicit indices
-// does not lead to losing information.
-static void
-transformIndexOps(RewriterBase &b, LinalgOp op, SmallVectorImpl<Value> &ivs,
-                  const LoopIndexToRangeIndexMap &loopIndexToRangeIndex) {
+void mlir::linalg::transformIndexOps(
+    RewriterBase &b, LinalgOp op, SmallVectorImpl<Value> &ivs,
+    const LoopIndexToRangeIndexMap &loopIndexToRangeIndex) {
   SmallVector<Value> allIvs(op.getNumLoops(), nullptr);
   for (auto &en : enumerate(allIvs)) {
     auto rangeIndex = loopIndexToRangeIndex.find(en.index());
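The new transformIndexOps definition is truncated above after its first few lines. A hedged sketch of how the mapping loop plausibly continues, inferred from the documentation comment rather than from the rest of the patch (not shown here):

  // Hypothetical continuation, not copied from the patch.
  for (auto &en : enumerate(allIvs)) {
    auto rangeIndex = loopIndexToRangeIndex.find(en.index());
    // Loops that were not tiled keep a null induction variable.
    if (rangeIndex == loopIndexToRangeIndex.end())
      continue;
    // Tiled loops pick up the induction variable of their new loop.
    en.value() = ivs[rangeIndex->second];
  }
  // Each linalg.index result along a tiled dimension is then offset by the
  // matching induction variable (an arith.addi of the index and the iv), as
  // the comment block above describes.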