[Linalg] Add tiling for IndexedGenericOp with a region.

PiperOrigin-RevId: 284949355
Alexander Belyaev 2019-12-11 02:56:06 -08:00 committed by A. Unique TensorFlower
parent 98fbf41044
commit bae8a7a724
2 changed files with 217 additions and 9 deletions


@@ -58,14 +58,18 @@ static bool isZero(Value *v) {
cast<ConstantIndexOp>(v->getDefiningOp()).getValue() == 0;
}
using LoopIndexToRangeIndexMap = DenseMap<int, int>;
// Creates a number of ranges equal to the number of non-zero entries in
// `tileSizes`, one for each loop of the LinalgOp that is tiled. The
// `tileSizes` argument has
// one entry per surrounding loop. It uses zero as the convention that a
// particular loop is not tiled. This convention simplifies implementations by
// avoiding affine map manipulations.
// The returned ranges correspond to the loop ranges, in the proper order, that
-// are tiled and for which new loops will be created.
-static SmallVector<SubViewOp::Range, 4>
+// are tiled and for which new loops will be created. The function also returns
+// a map from the loop indices of the LinalgOp to the corresponding range
+// indices of the newly created loops; untiled loops have no entry.
+static std::tuple<SmallVector<SubViewOp::Range, 4>, LoopIndexToRangeIndexMap>
makeTiledLoopRanges(OpBuilder &b, Location loc, AffineMap map,
ArrayRef<Value *> allViewSizes,
ArrayRef<Value *> allTileSizes, OperationFolder *folder) {
@@ -75,11 +79,15 @@ makeTiledLoopRanges(OpBuilder &b, Location loc, AffineMap map,
SmallVector<Value *, 4> tileSizes(allTileSizes.begin(), allTileSizes.end());
// Traverse the tile sizes, which are in loop order, and erase zeros everywhere.
-for (int idx = tileSizes.size() - 1; idx >= 0; --idx) {
-  if (isZero(tileSizes[idx])) {
-    viewSizes.erase(viewSizes.begin() + idx);
-    tileSizes.erase(tileSizes.begin() + idx);
-  }
+LoopIndexToRangeIndexMap loopIndexToRangeIndex;
+for (int idx = 0, e = tileSizes.size(), zerosCount = 0; idx < e; ++idx) {
+  if (isZero(tileSizes[idx - zerosCount])) {
+    viewSizes.erase(viewSizes.begin() + idx - zerosCount);
+    tileSizes.erase(tileSizes.begin() + idx - zerosCount);
+    ++zerosCount;
+    continue;
+  }
+  loopIndexToRangeIndex[idx] = idx - zerosCount;
}
// Create a new range with the applied tile sizes.
@@ -88,10 +96,11 @@ makeTiledLoopRanges(OpBuilder &b, Location loc, AffineMap map,
res.push_back(SubViewOp::Range{constant_index(folder, 0), viewSizes[idx],
tileSizes[idx]});
}
-return res;
+return std::make_tuple(res, loopIndexToRangeIndex);
}
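For intuition, the mapping this function computes can be reproduced with a small self-contained snippet (plain C++, no MLIR dependencies; all names are illustrative and not part of the patch):

#include <cstdio>
#include <map>
#include <vector>

int main() {
  // Tile sizes in loop order; zero means "do not tile this loop".
  std::vector<int> tileSizes = {10, 0, 25};
  std::map<int, int> loopIndexToRangeIndex;
  int zerosCount = 0;
  for (int idx = 0, e = tileSizes.size(); idx < e; ++idx) {
    if (tileSizes[idx] == 0) {
      ++zerosCount;
      continue;
    }
    // A tiled loop gets a range; its range index is shifted down by the
    // number of untiled loops seen so far.
    loopIndexToRangeIndex[idx] = idx - zerosCount;
  }
  // Prints "0 -> 0" and "2 -> 1": loops 0 and 2 are tiled, loop 1 is not.
  for (const auto &kv : loopIndexToRangeIndex)
    std::printf("%d -> %d\n", kv.first, kv.second);
}

Unlike this sketch, the function above also erases the zero entries from `viewSizes` and `tileSizes` in place while iterating, which is why it indexes both vectors with `idx - zerosCount`.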
namespace {
// Helper visitor to determine whether an AffineExpr is tiled.
// This is achieved by traversing every AffineDimExpr with position `pos` and
// checking whether the corresponding `tileSizes[pos]` is non-zero.
@@ -117,8 +126,99 @@ struct TileCheck : public AffineExprVisitor<TileCheck> {
bool isTiled;
ArrayRef<Value *> tileSizes;
};
} // namespace
// IndexedGenericOp explicitly uses induction variables in the loop body. The
// values of the indices that are used in the loop body for any given access of
// an input/output memref must be the same as they were before the `subview`
// op was applied, i.e. they must be invariant with respect to tiling.
//
// Therefore, if the operation is tiled, we have to transform the indices
// accordingly, i.e. offset them by the values of the corresponding induction
// variables that are captured implicitly in the body of the op.
//
// Example. `linalg.indexed_generic` before tiling:
//
// #id_2d = (i, j) -> (i, j)
// #pointwise_2d_trait = {
// indexing_maps = [#id_2d, #id_2d],
// iterator_types = ["parallel", "parallel"],
// n_views = [1, 1]
// }
// linalg.indexed_generic #pointwise_2d_trait %operand, %result {
// ^bb0(%i: index, %j: index, %operand_in: f32, %result_in: f32):
// <some operations that use %i, %j>
// }: memref<50x100xf32>, memref<50x100xf32>
//
// After the tiling pass with tile sizes 10 and 25:
//
// #strided = (i, j)[s0, s1, s2] -> (i * s1 + s0 + j * s2)
//
// %c1 = constant 1 : index
// %c0 = constant 0 : index
// %c25 = constant 25 : index
// %c10 = constant 10 : index
// %operand_dim_0 = dim %operand, 0 : memref<50x100xf32>
// %operand_dim_1 = dim %operand, 1 : memref<50x100xf32>
// loop.for %k = %c0 to operand_dim_0 step %c10 {
// loop.for %l = %c0 to operand_dim_1 step %c25 {
// %4 = std.subview %operand[%k, %l][%c10, %c25][%c1, %c1]
// : memref<50x100xf32> to memref<?x?xf32, #strided>
// %5 = std.subview %result[%k, %l][%c10, %c25][%c1, %c1]
// : memref<50x100xf32> to memref<?x?xf32, #strided>
// linalg.indexed_generic #pointwise_2d_trait %4, %5 {
// ^bb0(%i: index, %j: index, %operand_in: f32, %result_in: f32):
// // Indices `k` and `l` are implicitly captured in the body.
// %transformed_i = addi %i, %k : index // index `i` is offset by %k
// %transformed_j = addi %j, %l : index // index `j` is offset by %l
// // Every use of %i, %j is replaced with %transformed_i, %transformed_j
// <some operations that use %transformed_i, %transformed_j>
// }: memref<?x?xf32, #strided>, memref<?x?xf32, #strided>
// }
// }
//
// TODO(pifon, ntv): Investigate whether mixing implicit and explicit indices
// can lead to losing information.
void transformIndexedGenericOpIndices(
OpBuilder &b, LinalgOp op, ArrayRef<ValueHandle *> pivs,
const LoopIndexToRangeIndexMap &loopIndexToRangeIndex) {
auto indexedGenericOp = dyn_cast<IndexedGenericOp>(op.getOperation());
if (!indexedGenericOp)
return;
// `linalg.indexed_generic` comes in two flavours. One has a region with a
// single block that defines the loop body. The other has a `fun` attribute
// that refers to an existing function symbol; in that case, a call to `fun`
// is inserted into the loop body.
//
// TODO(pifon): Add support for `linalg.indexed_generic` with `fun` attribute.
auto &region = indexedGenericOp.region();
if (region.empty()) {
indexedGenericOp.emitError("op expected a region");
return;
}
auto &block = region.getBlocks().front();
OpBuilder::InsertionGuard g(b);
b.setInsertionPointToStart(&block);
for (unsigned i = 0; i < indexedGenericOp.getNumLoops(); ++i) {
auto rangeIndex = loopIndexToRangeIndex.find(i);
if (rangeIndex == loopIndexToRangeIndex.end())
continue;
Value *oldIndex = block.getArgument(i);
// Offset the index argument `i` by the value of the corresponding induction
// variable and replace all uses of the previous value.
Value *newIndex = b.create<AddIOp>(indexedGenericOp.getLoc(), oldIndex,
pivs[rangeIndex->second]->getValue());
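// Replace every use of the old index with the offset index, except the use
// inside the `addi` just created (rewriting that operand would make the op
// consume its own result).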
for (auto &use : oldIndex->getUses()) {
if (use.getOwner() == newIndex->getDefiningOp())
continue;
use.set(newIndex);
}
}
}
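The invariant stated in the comment above, namely that offsetting the tile-local index by the captured induction variable reproduces exactly the original index values, can be sanity-checked with a minimal standalone sketch (plain C++, independent of MLIR; it assumes the tile size divides the trip count evenly):

#include <cassert>

int main() {
  const int N = 50, tileSize = 10;
  int visited = 0;
  // `k` plays the role of the outer induction variable that the tiled body
  // captures implicitly; `i` is the tile-local index argument.
  for (int k = 0; k < N; k += tileSize) {
    for (int i = 0; i < tileSize; ++i) {
      int transformedI = i + k; // the `addi` inserted by the pass
      assert(0 <= transformedI && transformedI < N);
      ++visited;
    }
  }
  assert(visited == N); // every original index value occurs exactly once
}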
static bool isTiled(AffineExpr expr, ArrayRef<Value *> tileSizes) {
if (!expr)
return false;
@@ -244,7 +344,10 @@ llvm::Optional<TiledLinalgOp> mlir::linalg::tileLinalgOp(
auto viewSizesToLoopsMap =
inversePermutation(concatAffineMaps(loopToOperandRangesMaps(op)));
assert(viewSizesToLoopsMap && "expected invertible map");
-auto loopRanges =
+SmallVector<SubViewOp::Range, 4> loopRanges;
+LoopIndexToRangeIndexMap loopIndexToRangeIndex;
+std::tie(loopRanges, loopIndexToRangeIndex) =
makeTiledLoopRanges(b, scope.getLocation(), viewSizesToLoopsMap,
viewSizes, tileSizes, folder);
if (!permutation.empty())
@@ -274,7 +377,10 @@ llvm::Optional<TiledLinalgOp> mlir::linalg::tileLinalgOp(
res = op.clone(b, loc, views);
});
-// 4. Gather the newly created loops and return them with the new op.
+// 4. Transform the index arguments of `linalg.indexed_generic` w.r.t. the
+// tiling.
+transformIndexedGenericOpIndices(b, res, pivs, loopIndexToRangeIndex);
+// 5. Gather the newly created loops and return them with the new op.
SmallVector<ForOp, 8> loops;
loops.reserve(ivs.size());
for (auto iv : ivs)


@@ -0,0 +1,102 @@
// RUN: mlir-opt %s -linalg-tile -linalg-tile-sizes=10,25 | FileCheck %s -check-prefix=TILE-10n25
// RUN: mlir-opt %s -linalg-tile -linalg-tile-sizes=25,0 | FileCheck %s -check-prefix=TILE-25n0
// RUN: mlir-opt %s -linalg-tile -linalg-tile-sizes=0,25 | FileCheck %s -check-prefix=TILE-0n25
#id_1d = (i) -> (i)
#pointwise_1d_trait = {
indexing_maps = [#id_1d, #id_1d],
iterator_types = ["parallel"],
n_views = [1, 1]
}
func @indexed_generic_vector(%operand: memref<50xf32>, %result: memref<50xf32>) {
linalg.indexed_generic #pointwise_1d_trait %operand, %result {
^bb0(%i: index, %operand_in: f32, %result_in: f32):
%i_int = index_cast %i: index to i32
%i_float = sitofp %i_int : i32 to f32
%out = addf %operand_in, %i_float : f32
linalg.yield %out : f32
}: memref<50xf32>, memref<50xf32>
return
}
// TILE-10n25-LABEL: func @indexed_generic_vector
// TILE-10n25: %[[C10:.*]] = constant 10 : index
// TILE-10n25: loop.for %[[J:.*]] = {{.*}} step %[[C10]]
// TILE-10n25: linalg.indexed_generic
// TILE-10n25: ^bb0(%[[I:.*]]: index, %[[IN:.*]]: f32, %[[OUT:.*]]: f32)
// TILE-10n25: %[[NEW_I:.*]] = addi %[[I]], %[[J]] : index
// TILE-10n25: %[[NEW_I_INT:.*]] = index_cast %[[NEW_I]] : index to i32
// TILE-10n25: %[[NEW_I_FLOAT:.*]] = sitofp %[[NEW_I_INT]] : i32 to f32
// TILE-10n25: %[[OUT:.*]] = addf %[[IN]], %[[NEW_I_FLOAT]] : f32
// TILE-25n0-LABEL: func @indexed_generic_vector
// TILE-25n0: %[[C25:.*]] = constant 25 : index
// TILE-25n0: loop.for %[[J:.*]] = {{.*}} step %[[C25]]
// TILE-25n0: linalg.indexed_generic
// TILE-25n0: ^bb0(%[[I:.*]]: index, %[[IN:.*]]: f32, %[[OUT:.*]]: f32)
// TILE-25n0: %[[NEW_I:.*]] = addi %[[I]], %[[J]] : index
// TILE-25n0: %[[NEW_I_INT:.*]] = index_cast %[[NEW_I]] : index to i32
// TILE-25n0: %[[NEW_I_FLOAT:.*]] = sitofp %[[NEW_I_INT]] : i32 to f32
// TILE-25n0: %[[OUT:.*]] = addf %[[IN]], %[[NEW_I_FLOAT]] : f32
// TILE-0n25-LABEL: func @indexed_generic_vector
// TILE-0n25-NOT: loop.for %[[J:.*]] = {{.*}} step %[[C25]]
// TILE-0n25: linalg.indexed_generic
#combined_indices_trait = {
indexing_maps = [
(i, j) -> (j, i + j),
(i, j) -> (i, j)
],
iterator_types = ["parallel", "parallel"],
n_views = [1, 1]
}
func @indexed_generic_matrix(%operand: memref<50x100xf32>, %result: memref<50x100xf32>) {
linalg.indexed_generic #combined_indices_trait %operand, %result {
^bb0(%i: index, %j: index, %operand_in: f32, %result_in: f32):
%i_int = index_cast %i: index to i32
%i_float = sitofp %i_int : i32 to f32
%j_int = index_cast %j: index to i32
%j_float = sitofp %j_int : i32 to f32
%out = addf %i_float, %j_float : f32
linalg.yield %out : f32
}: memref<50x100xf32>, memref<50x100xf32>
return
}
// TILE-10n25-LABEL: func @indexed_generic_matrix
// TILE-10n25: %[[C25:.*]] = constant 25 : index
// TILE-10n25: %[[C10:.*]] = constant 10 : index
// TILE-10n25: loop.for %[[K:.*]] = {{.*}} step %[[C10]]
// TILE-10n25: loop.for %[[L:.*]] = {{.*}} step %[[C25]]
// TILE-10n25: linalg.indexed_generic
// TILE-10n25: ^bb0(%[[I:.*]]: index, %[[J:.*]]: index, %[[IN:.*]]: f32, %[[OUT:.*]]: f32):
// TILE-10n25: %[[NEW_I:.*]] = addi %[[I]], %[[K]] : index
// TILE-10n25: %[[NEW_J:.*]] = addi %[[J]], %[[L]] : index
// TILE-10n25: %[[NEW_INT_I:.*]] = index_cast %[[NEW_I]] : index to i32
// TILE-10n25: %[[NEW_FLOAT_I:.*]] = sitofp %[[NEW_INT_I]] : i32 to f32
// TILE-10n25: %[[NEW_INT_J:.*]] = index_cast %[[NEW_J]] : index to i32
// TILE-10n25: %[[NEW_FLOAT_J:.*]] = sitofp %[[NEW_INT_J]] : i32 to f32
// TILE-10n25: %[[OUT:.*]] = addf %[[NEW_FLOAT_I]], %[[NEW_FLOAT_J]] : f32
// TILE-25n0-LABEL: func @indexed_generic_matrix
// TILE-25n0: %[[C25:.*]] = constant 25 : index
// TILE-25n0: loop.for %[[L:.*]] = {{.*}} step %[[C25]]
// TILE-25n0: linalg.indexed_generic
// TILE-25n0: ^bb0(%[[I:.*]]: index, %[[J:.*]]: index, %[[IN:.*]]: f32, %[[OUT:.*]]: f32):
// TILE-25n0: %[[NEW_I:.*]] = addi %[[I]], %[[L]] : index
// TILE-25n0: %[[NEW_INT_I:.*]] = index_cast %[[NEW_I]] : index to i32
// TILE-25n0: %[[NEW_FLOAT_I:.*]] = sitofp %[[NEW_INT_I]] : i32 to f32
// TILE-25n0: %[[INT_J:.*]] = index_cast %[[J]] : index to i32
// TILE-25n0: %[[FLOAT_J:.*]] = sitofp %[[INT_J]] : i32 to f32
// TILE-25n0: %[[OUT:.*]] = addf %[[NEW_FLOAT_I]], %[[FLOAT_J]] : f32
// TILE-0n25-LABEL: func @indexed_generic_matrix
// TILE-0n25: %[[C25:.*]] = constant 25 : index
// TILE-0n25: loop.for %[[L:.*]] = {{.*}} step %[[C25]]
// TILE-0n25: linalg.indexed_generic
// TILE-0n25: ^bb0(%[[I:.*]]: index, %[[J:.*]]: index, %[[IN:.*]]: f32, %[[OUT:.*]]: f32):
// TILE-0n25: %[[NEW_J:.*]] = addi %[[J]], %[[L]] : index
// TILE-0n25: %[[INT_I:.*]] = index_cast %[[I]] : index to i32
// TILE-0n25: %[[FLOAT_I:.*]] = sitofp %[[INT_I]] : i32 to f32
// TILE-0n25: %[[NEW_INT_J:.*]] = index_cast %[[NEW_J]] : index to i32
// TILE-0n25: %[[NEW_FLOAT_J:.*]] = sitofp %[[NEW_INT_J]] : i32 to f32
// TILE-0n25: %[[OUT:.*]] = addf %[[FLOAT_I]], %[[NEW_FLOAT_J]] : f32