[mlir][linalg] update tiling to support linalg index operations.

The patch updates the tiling pass to add the tile offsets to the indices returned by the linalg operations.

Differential Revision: https://reviews.llvm.org/D100379
This commit is contained in:
Tobias Gysi 2021-04-13 14:04:45 +00:00
parent d7ce89c769
commit 8ea5d190ec
4 changed files with 168 additions and 18 deletions

View File

@@ -172,6 +172,85 @@ static void transformIndexedGenericOpIndices(
}
}
// All indices returned by IndexOp should be invariant with respect to tiling.
// Therefore, if an operation is tiled, we have to transform the indices
// accordingly, i.e. offset them by the values of the corresponding induction
// variables that are captured implicitly in the body of the op.
//
// Example. `linalg.generic` before tiling:
//
// #id_2d = (i, j) -> (i, j)
// #pointwise_2d_trait = {
// indexing_maps = [#id_2d, #id_2d],
// iterator_types = ["parallel", "parallel"]
// }
// linalg.generic #pointwise_2d_trait %operand, %result {
// ^bb0(%operand_in: f32, %result_in: f32):
// %i = linalg.index 0 : index
// %j = linalg.index 1 : index
// <some operations that use %i, %j>
// }: memref<50x100xf32>, memref<50x100xf32>
//
// After the tiling pass with tile sizes 10 and 25:
//
// #strided = (i, j)[s0, s1, s2] -> (i * s1 + s0 + j * s2)
//
// %c1 = constant 1 : index
// %c0 = constant 0 : index
// %c25 = constant 25 : index
// %c10 = constant 10 : index
// operand_dim_0 = dim %operand, 0 : memref<50x100xf32>
// operand_dim_1 = dim %operand, 1 : memref<50x100xf32>
// scf.for %k = %c0 to operand_dim_0 step %c10 {
// scf.for %l = %c0 to operand_dim_1 step %c25 {
// %4 = std.subview %operand[%k, %l][%c10, %c25][%c1, %c1]
// : memref<50x100xf32> to memref<?x?xf32, #strided>
// %5 = std.subview %result[%k, %l][%c10, %c25][%c1, %c1]
// : memref<50x100xf32> to memref<?x?xf32, #strided>
// linalg.generic #pointwise_2d_trait %4, %5 {
// ^bb0(%operand_in: f32, %result_in: f32):
// %i = linalg.index 0 : index
// %j = linalg.index 1 : index
// // Indices `k` and `l` are implicitly captured in the body.
// %transformed_i = addi %i, %k : index // index `i` is offset by %k
// %transformed_j = addi %j, %l : index // index `j` is offset by %l
// // Every use of %i, %j is replaced with %transformed_i, %transformed_j
// <some operations that use %transformed_i, %transformed_j>
// }: memref<?x?xf32, #strided>, memref<?x?xf32, #strided>
// }
// }
//
// TODO: Investigate whether mixing implicit and explicit indices
// can lead to losing information.
/// Offsets every `linalg.index` result in the body of `op` by the induction
/// variable of the corresponding tile loop, and redirects all other uses of
/// the original index value to the offset value.
static void
transformIndexOps(OpBuilder &b, LinalgOp op, SmallVectorImpl<Value> &ivs,
                  const LoopIndexToRangeIndexMap &loopIndexToRangeIndex) {
  // Operations without an attached region contain no index ops to rewrite.
  if (op->getNumRegions() == 0)
    return;
  assert(op->getNumRegions() == 1 && op->getRegion(0).getBlocks().size() == 1 &&
         "expected linalg operation to have one block.");
  Block &body = op->getRegion(0).front();
  for (IndexOp indexOp :
       llvm::make_early_inc_range(body.getOps<linalg::IndexOp>())) {
    // Dimensions that were not tiled keep their original index value.
    auto rangeIt = loopIndexToRangeIndex.find(indexOp.dim());
    if (rangeIt == loopIndexToRangeIndex.end())
      continue;
    // Materialize `index + iv` via an affine apply directly after the index
    // op, then replace every use of the index except the apply op itself.
    OpBuilder::InsertionGuard guard(b);
    b.setInsertionPointAfter(indexOp);
    AffineExpr indexExpr, ivExpr;
    bindDims(b.getContext(), indexExpr, ivExpr);
    auto offsetOp = b.create<AffineApplyOp>(
        indexOp.getLoc(), indexExpr + ivExpr,
        ValueRange{indexOp.getResult(), ivs[rangeIt->second]});
    indexOp.getResult().replaceAllUsesExcept(
        offsetOp.getResult(), SmallPtrSet<Operation *, 1>{offsetOp});
  }
}
template <typename LoopTy>
static Optional<TiledLinalgOp>
tileLinalgOpImpl(OpBuilder &b, LinalgOp op, ValueRange tileSizes,
@@ -299,8 +378,10 @@ tileLinalgOpImpl(OpBuilder &b, LinalgOp op, ValueRange tileSizes,
},
options.distribution);
// 3. Transforms index arguments of `linalg.generic` w.r.t. to the tiling.
// 3a. Transforms index arguments of `linalg.generic` w.r.t. to the tiling.
transformIndexedGenericOpIndices(b, res, ivs, loopIndexToRangeIndex);
// 3b. Transform IndexOp results w.r.t. the tiling.
transformIndexOps(b, res, ivs, loopIndexToRangeIndex);
// 4. Gather the newly created loops and return them with the new op.
SmallVector<Operation *, 8> loops;

View File

@@ -246,8 +246,7 @@ mlir::linalg::LinalgBaseTilingPattern::LinalgBaseTilingPattern(
LogicalResult mlir::linalg::LinalgBaseTilingPattern::matchAndRewriteBase(
Operation *op, PatternRewriter &rewriter, TiledLinalgOp &result) const {
LinalgOp linalgOp = dyn_cast<LinalgOp>(op);
// TODO: remove hasIndexSemantics check once index ops are supported.
if (!linalgOp || linalgOp.hasIndexSemantics())
if (!linalgOp)
return failure();
if (failed(filter.checkAndNotify(rewriter, linalgOp)))
return failure();

View File

@@ -0,0 +1,85 @@
// RUN: mlir-opt %s -linalg-tile="linalg-tile-sizes=10,25" -split-input-file | FileCheck %s -check-prefix=TILE-10n25
// RUN: mlir-opt %s -linalg-tile="linalg-tile-sizes=25,0" -split-input-file | FileCheck %s -check-prefix=TILE-25n0
// RUN: mlir-opt %s -linalg-tile="linalg-tile-sizes=0,25" -split-input-file | FileCheck %s -check-prefix=TILE-0n25
// 1-d pointwise op that yields its own iteration index via linalg.index;
// exercises the offsetting of index results under tiling (checked below).
func @indexed_vector(%arg0: memref<50xindex>) {
linalg.generic {indexing_maps = [affine_map<(i) -> (i)>],
iterator_types = ["parallel"]}
outs(%arg0 : memref<50xindex>) {
^bb0(%a: index):
// Index of the single parallel dimension; tiling must offset it.
%i = linalg.index 0 : index
linalg.yield %i : index
}
return
}
// TILE-10n25-DAG: [[$MAP:#[a-zA-Z0-9_]*]] = affine_map<(d0, d1) -> (d0 + d1)>
// TILE-10n25-LABEL: func @indexed_vector
// TILE-10n25: %[[C10:.*]] = constant 10 : index
// TILE-10n25: scf.for %[[J:.*]] = {{.*}} step %[[C10]]
// TILE-10n25: linalg.generic
// TILE-10n25: %[[I:.*]] = linalg.index 0 : index
// TILE-10n25: %[[NEW_I:.*]] = affine.apply [[$MAP]](%[[I]], %[[J]])
// TILE-10n25: linalg.yield %[[NEW_I]] : index
// TILE-25n0-DAG: [[$MAP:#[a-zA-Z0-9_]*]] = affine_map<(d0, d1) -> (d0 + d1)>
// TILE-25n0-LABEL: func @indexed_vector
// TILE-25n0: %[[C25:.*]] = constant 25 : index
// TILE-25n0: scf.for %[[J:.*]] = {{.*}} step %[[C25]]
// TILE-25n0: linalg.generic
// TILE-25n0: %[[I:.*]] = linalg.index 0 : index
// TILE-25n0: %[[NEW_I:.*]] = affine.apply [[$MAP]](%[[I]], %[[J]])
// TILE-25n0: linalg.yield %[[NEW_I]] : index
// TILE-0n25-LABEL: func @indexed_vector
// TILE-0n25-NOT: scf.for %[[J:.*]] = {{.*}} step %
// TILE-0n25: linalg.generic
// -----
// 2-d pointwise op yielding the sum of both iteration indices; exercises
// offsetting of multiple linalg.index results under tiling (checked below).
func @indexed_matrix(%arg0: memref<50x50xindex>) {
linalg.generic {indexing_maps = [affine_map<(i, j) -> (i, j)>],
iterator_types = ["parallel", "parallel"]}
outs(%arg0 : memref<50x50xindex>) {
^bb0(%a: index):
// Row and column indices; each tiled dimension gets its own offset.
%i = linalg.index 0 : index
%j = linalg.index 1 : index
%sum = addi %i, %j : index
linalg.yield %sum : index
}
return
}
// TILE-10n25-DAG: [[$MAP:#[a-zA-Z0-9_]*]] = affine_map<(d0, d1) -> (d0 + d1)>
// TILE-10n25-LABEL: func @indexed_matrix
// TILE-10n25-DAG: %[[C25:.*]] = constant 25 : index
// TILE-10n25-DAG: %[[C10:.*]] = constant 10 : index
// TILE-10n25: scf.for %[[K:.*]] = {{.*}} step %[[C10]]
// TILE-10n25: scf.for %[[L:.*]] = {{.*}} step %[[C25]]
// TILE-10n25: linalg.generic
// TILE-10n25: %[[I:.*]] = linalg.index 0 : index
// TILE-10n25: %[[NEW_I:.*]] = affine.apply [[$MAP]](%[[I]], %[[K]])
// TILE-10n25: %[[J:.*]] = linalg.index 1 : index
// TILE-10n25: %[[NEW_J:.*]] = affine.apply [[$MAP]](%[[J]], %[[L]])
// TILE-10n25: %[[SUM:.*]] = addi %[[NEW_I]], %[[NEW_J]] : index
// TILE-10n25: linalg.yield %[[SUM]] : index
// TILE-25n0-DAG: [[$MAP:#[a-zA-Z0-9_]*]] = affine_map<(d0, d1) -> (d0 + d1)>
// TILE-25n0-LABEL: func @indexed_matrix
// TILE-25n0: %[[C25:.*]] = constant 25 : index
// TILE-25n0: scf.for %[[L:.*]] = {{.*}} step %[[C25]]
// TILE-25n0: linalg.generic
// TILE-25n0: %[[I:.*]] = linalg.index 0 : index
// TILE-25n0: %[[NEW_I:.*]] = affine.apply [[$MAP]](%[[I]], %[[L]])
// TILE-25n0: %[[J:.*]] = linalg.index 1 : index
// TILE-25n0: %[[SUM:.*]] = addi %[[NEW_I]], %[[J]] : index
// TILE-25n0: linalg.yield %[[SUM]] : index
// TILE-0n25-DAG: [[$MAP:#[a-zA-Z0-9_]*]] = affine_map<(d0, d1) -> (d0 + d1)>
// TILE-0n25-LABEL: func @indexed_matrix
// TILE-0n25: %[[C25:.*]] = constant 25 : index
// TILE-0n25: scf.for %[[L:.*]] = {{.*}} step %[[C25]]
// TILE-0n25: linalg.generic
// TILE-0n25: %[[I:.*]] = linalg.index 0 : index
// TILE-0n25: %[[J:.*]] = linalg.index 1 : index
// TILE-0n25: %[[NEW_J:.*]] = affine.apply [[$MAP]](%[[J]], %[[L]])
// TILE-0n25: %[[SUM:.*]] = addi %[[I]], %[[NEW_J]] : index
// TILE-0n25: linalg.yield %[[SUM]] : index

View File

@@ -377,18 +377,3 @@ func @pointwise(%arg0: memref<?x?xf32, offset: ?, strides: [?, 1]>, %arg1: memre
// TILE-234: for
// TILE-234-NOT: for
// TILE-234: linalg.generic
// TILE-2-LABEL: func @index_op
// TILE-2-NOT: for
// TILE-2: linalg.generic
// Generic op that yields the index of dimension 1 for every element.
// NOTE(review): this test is removed by the commit — it checked that ops
// with index semantics were NOT tiled, which the patch now supports.
func @index_op(%arg0: memref<?x?xindex>) {
linalg.generic {
indexing_maps = [affine_map<(i, j) -> (i, j)>],
iterator_types = ["parallel", "parallel"]}
outs(%arg0 : memref<?x?xindex>) {
^bb0(%arg1: index): // no predecessors
%0 = linalg.index 1 : index
linalg.yield %0 : index
}
return
}