forked from OSchip/llvm-project
[mlir][linalg] Update tiling to support linalg index operations.

The patch updates the tiling pass to add the tile offsets to the indices returned by the linalg index operations.

Differential Revision: https://reviews.llvm.org/D100379
This commit is contained in:
parent
d7ce89c769
commit
8ea5d190ec
|
@ -172,6 +172,85 @@ static void transformIndexedGenericOpIndices(
|
|||
}
|
||||
}
|
||||
|
||||
// All indices returned by IndexOp should be invariant with respect to tiling.
// Therefore, if an operation is tiled, we have to transform the indices
// accordingly, i.e. offset them by the values of the corresponding induction
// variables that are captured implicitly in the body of the op.
//
// Example. `linalg.generic` before tiling:
//
// #id_2d = (i, j) -> (i, j)
// #pointwise_2d_trait = {
//   indexing_maps = [#id_2d, #id_2d],
//   iterator_types = ["parallel", "parallel"]
// }
// linalg.generic #pointwise_2d_trait %operand, %result {
//   ^bb0(%operand_in: f32, %result_in: f32):
//     %i = linalg.index 0 : index
//     %j = linalg.index 1 : index
//     <some operations that use %i, %j>
// }: memref<50x100xf32>, memref<50x100xf32>
//
// After tiling pass with tiles sizes 10 and 25:
//
// #strided = (i, j)[s0, s1, s2] -> (i * s1 + s0 + j * s2)
//
// %c1 = constant 1 : index
// %c0 = constant 0 : index
// %c25 = constant 25 : index
// %c10 = constant 10 : index
// operand_dim_0 = dim %operand, 0 : memref<50x100xf32>
// operand_dim_1 = dim %operand, 1 : memref<50x100xf32>
// scf.for %k = %c0 to operand_dim_0 step %c10 {
//   scf.for %l = %c0 to operand_dim_1 step %c25 {
//     %4 = std.subview %operand[%k, %l][%c10, %c25][%c1, %c1]
//       : memref<50x100xf32> to memref<?x?xf32, #strided>
//     %5 = std.subview %result[%k, %l][%c10, %c25][%c1, %c1]
//       : memref<50x100xf32> to memref<?x?xf32, #strided>
//     linalg.generic pointwise_2d_trait %4, %5 {
//     ^bb0(%operand_in: f32, %result_in: f32):
//       %i = linalg.index 0 : index
//       %j = linalg.index 1 : index
//       // Indices `k` and `l` are implicitly captured in the body.
//       %transformed_i = addi %i, %k : index // index `i` is offset by %k
//       %transformed_j = addi %j, %l : index // index `j` is offset by %l
//       // Every use of %i, %j is replaced with %transformed_i, %transformed_j
//       <some operations that use %transformed_i, %transformed_j>
//     }: memref<?x?xf32, #strided>, memref<?x?xf32, #strided>
//   }
// }
//
// TODO: Investigate whether mixing implicit and explicit indices
// does not lead to losing information.
static void
transformIndexOps(OpBuilder &b, LinalgOp op, SmallVectorImpl<Value> &ivs,
                  const LoopIndexToRangeIndexMap &loopIndexToRangeIndex) {
  // Skip operations that have no region attached.
  if (op->getNumRegions() == 0)
    return;
  assert(op->getNumRegions() == 1 && op->getRegion(0).getBlocks().size() == 1 &&
         "expected linalg operation to have one block.");
  Block &block = op->getRegion(0).front();

  // Early-increment range: we insert new ops right after each visited IndexOp
  // and rewrite its uses, so the iterator must advance before mutation.
  for (IndexOp indexOp :
       llvm::make_early_inc_range(block.getOps<linalg::IndexOp>())) {
    // Indices of loops that were not tiled (not present in the map) need no
    // offsetting and keep their original values.
    auto rangeIndex = loopIndexToRangeIndex.find(indexOp.dim());
    if (rangeIndex == loopIndexToRangeIndex.end())
      continue;
    // Offset the index by the value of the corresponding induction variable and
    // replace all uses of the previous value.
    OpBuilder::InsertionGuard g(b);
    b.setInsertionPointAfter(indexOp);
    AffineExpr index, iv;
    bindDims(b.getContext(), index, iv);
    // new_index = old_index + induction_variable_of_the_tiled_loop.
    AffineApplyOp applyOp = b.create<AffineApplyOp>(
        indexOp.getLoc(), index + iv,
        ValueRange{indexOp.getResult(), ivs[rangeIndex->second]});
    // Exclude the apply op itself so its own operand is not rewritten.
    indexOp.getResult().replaceAllUsesExcept(
        applyOp.getResult(), SmallPtrSet<Operation *, 1>{applyOp});
  }
}
|
||||
|
||||
template <typename LoopTy>
|
||||
static Optional<TiledLinalgOp>
|
||||
tileLinalgOpImpl(OpBuilder &b, LinalgOp op, ValueRange tileSizes,
|
||||
|
@ -299,8 +378,10 @@ tileLinalgOpImpl(OpBuilder &b, LinalgOp op, ValueRange tileSizes,
|
|||
},
|
||||
options.distribution);
|
||||
|
||||
// 3. Transforms index arguments of `linalg.generic` w.r.t. to the tiling.
|
||||
// 3a. Transforms index arguments of `linalg.generic` w.r.t. to the tiling.
|
||||
transformIndexedGenericOpIndices(b, res, ivs, loopIndexToRangeIndex);
|
||||
// 3b. Transform IndexOp results w.r.t. the tiling.
|
||||
transformIndexOps(b, res, ivs, loopIndexToRangeIndex);
|
||||
|
||||
// 4. Gather the newly created loops and return them with the new op.
|
||||
SmallVector<Operation *, 8> loops;
|
||||
|
|
|
@ -246,8 +246,7 @@ mlir::linalg::LinalgBaseTilingPattern::LinalgBaseTilingPattern(
|
|||
LogicalResult mlir::linalg::LinalgBaseTilingPattern::matchAndRewriteBase(
|
||||
Operation *op, PatternRewriter &rewriter, TiledLinalgOp &result) const {
|
||||
LinalgOp linalgOp = dyn_cast<LinalgOp>(op);
|
||||
// TODO: remove hasIndexSemantics check once index ops are supported.
|
||||
if (!linalgOp || linalgOp.hasIndexSemantics())
|
||||
if (!linalgOp)
|
||||
return failure();
|
||||
if (failed(filter.checkAndNotify(rewriter, linalgOp)))
|
||||
return failure();
|
||||
|
|
|
@ -0,0 +1,85 @@
|
|||
// RUN: mlir-opt %s -linalg-tile="linalg-tile-sizes=10,25" -split-input-file | FileCheck %s -check-prefix=TILE-10n25
// RUN: mlir-opt %s -linalg-tile="linalg-tile-sizes=25,0" -split-input-file | FileCheck %s -check-prefix=TILE-25n0
// RUN: mlir-opt %s -linalg-tile="linalg-tile-sizes=0,25" -split-input-file | FileCheck %s -check-prefix=TILE-0n25

// 1-d generic op that yields its own iteration index. After tiling, the index
// inside the tiled op must be offset by the induction variable of the
// surrounding loop; with a zero tile size no loop (and no offset) is created.
func @indexed_vector(%arg0: memref<50xindex>) {
  linalg.generic {indexing_maps = [affine_map<(i) -> (i)>],
                  iterator_types = ["parallel"]}
    outs(%arg0 : memref<50xindex>) {
  ^bb0(%a: index):
    %i = linalg.index 0 : index
    linalg.yield %i : index
  }
  return
}
// TILE-10n25-DAG: [[$MAP:#[a-zA-Z0-9_]*]] = affine_map<(d0, d1) -> (d0 + d1)>
// TILE-10n25-LABEL: func @indexed_vector
// TILE-10n25: %[[C10:.*]] = constant 10 : index
// TILE-10n25: scf.for %[[J:.*]] = {{.*}} step %[[C10]]
// TILE-10n25: linalg.generic
// TILE-10n25: %[[I:.*]] = linalg.index 0 : index
// TILE-10n25: %[[NEW_I:.*]] = affine.apply [[$MAP]](%[[I]], %[[J]])
// TILE-10n25: linalg.yield %[[NEW_I]] : index

// TILE-25n0-DAG: [[$MAP:#[a-zA-Z0-9_]*]] = affine_map<(d0, d1) -> (d0 + d1)>
// TILE-25n0-LABEL: func @indexed_vector
// TILE-25n0: %[[C25:.*]] = constant 25 : index
// TILE-25n0: scf.for %[[J:.*]] = {{.*}} step %[[C25]]
// TILE-25n0: linalg.generic
// TILE-25n0: %[[I:.*]] = linalg.index 0 : index
// TILE-25n0: %[[NEW_I:.*]] = affine.apply [[$MAP]](%[[I]], %[[J]])
// TILE-25n0: linalg.yield %[[NEW_I]] : index

// Tile size 0 for the only dimension: no loop is generated and the index is
// left untouched.
// TILE-0n25-LABEL: func @indexed_vector
// TILE-0n25-NOT: scf.for %[[J:.*]] = {{.*}} step %
// TILE-0n25: linalg.generic
||||
// -----
|
||||
|
||||
// 2-d generic op that yields the sum of both iteration indices. Each tiled
// dimension's index must be offset by the corresponding induction variable;
// a dimension with tile size 0 keeps its original index.
func @indexed_matrix(%arg0: memref<50x50xindex>) {
  linalg.generic {indexing_maps = [affine_map<(i, j) -> (i, j)>],
                  iterator_types = ["parallel", "parallel"]}
    outs(%arg0 : memref<50x50xindex>) {
  ^bb0(%a: index):
    %i = linalg.index 0 : index
    %j = linalg.index 1 : index
    %sum = addi %i, %j : index
    linalg.yield %sum : index
  }
  return
}
// Both dimensions tiled: both indices are offset.
// TILE-10n25-DAG: [[$MAP:#[a-zA-Z0-9_]*]] = affine_map<(d0, d1) -> (d0 + d1)>
// TILE-10n25-LABEL: func @indexed_matrix
// TILE-10n25-DAG: %[[C25:.*]] = constant 25 : index
// TILE-10n25-DAG: %[[C10:.*]] = constant 10 : index
// TILE-10n25: scf.for %[[K:.*]] = {{.*}} step %[[C10]]
// TILE-10n25: scf.for %[[L:.*]] = {{.*}} step %[[C25]]
// TILE-10n25: linalg.generic
// TILE-10n25: %[[I:.*]] = linalg.index 0 : index
// TILE-10n25: %[[NEW_I:.*]] = affine.apply [[$MAP]](%[[I]], %[[K]])
// TILE-10n25: %[[J:.*]] = linalg.index 1 : index
// TILE-10n25: %[[NEW_J:.*]] = affine.apply [[$MAP]](%[[J]], %[[L]])
// TILE-10n25: %[[SUM:.*]] = addi %[[NEW_I]], %[[NEW_J]] : index
// TILE-10n25: linalg.yield %[[SUM]] : index

// Only the first dimension tiled: only index 0 is offset.
// TILE-25n0-DAG: [[$MAP:#[a-zA-Z0-9_]*]] = affine_map<(d0, d1) -> (d0 + d1)>
// TILE-25n0-LABEL: func @indexed_matrix
// TILE-25n0: %[[C25:.*]] = constant 25 : index
// TILE-25n0: scf.for %[[L:.*]] = {{.*}} step %[[C25]]
// TILE-25n0: linalg.generic
// TILE-25n0: %[[I:.*]] = linalg.index 0 : index
// TILE-25n0: %[[NEW_I:.*]] = affine.apply [[$MAP]](%[[I]], %[[L]])
// TILE-25n0: %[[J:.*]] = linalg.index 1 : index
// TILE-25n0: %[[SUM:.*]] = addi %[[NEW_I]], %[[J]] : index
// TILE-25n0: linalg.yield %[[SUM]] : index

// Only the second dimension tiled: only index 1 is offset.
// TILE-0n25-DAG: [[$MAP:#[a-zA-Z0-9_]*]] = affine_map<(d0, d1) -> (d0 + d1)>
// TILE-0n25-LABEL: func @indexed_matrix
// TILE-0n25: %[[C25:.*]] = constant 25 : index
// TILE-0n25: scf.for %[[L:.*]] = {{.*}} step %[[C25]]
// TILE-0n25: linalg.generic
// TILE-0n25: %[[I:.*]] = linalg.index 0 : index
// TILE-0n25: %[[J:.*]] = linalg.index 1 : index
// TILE-0n25: %[[NEW_J:.*]] = affine.apply [[$MAP]](%[[J]], %[[L]])
// TILE-0n25: %[[SUM:.*]] = addi %[[I]], %[[NEW_J]] : index
// TILE-0n25: linalg.yield %[[SUM]] : index
|
|
@ -377,18 +377,3 @@ func @pointwise(%arg0: memref<?x?xf32, offset: ?, strides: [?, 1]>, %arg1: memre
|
|||
// TILE-234: for
|
||||
// TILE-234-NOT: for
|
||||
// TILE-234: linalg.generic
|
||||
|
||||
// Generic op with linalg.index in its body; with the current configuration no
// tile loop is expected around it.
// TILE-2-LABEL: func @index_op
// TILE-2-NOT: for
// TILE-2: linalg.generic
func @index_op(%arg0: memref<?x?xindex>) {
  linalg.generic {
    indexing_maps = [affine_map<(i, j) -> (i, j)>],
    iterator_types = ["parallel", "parallel"]}
    outs(%arg0 : memref<?x?xindex>) {
  ^bb0(%arg1: index): // no predecessors
    %0 = linalg.index 1 : index
    linalg.yield %0 : index
  }
  return
}
|
||||
|
|
Loading…
Reference in New Issue