forked from OSchip/llvm-project
[Linalg] Add tiling for IndexedGenericOp with a region.
PiperOrigin-RevId: 284949355
This commit is contained in:
parent
98fbf41044
commit
bae8a7a724
|
@ -58,14 +58,18 @@ static bool isZero(Value *v) {
|
||||||
cast<ConstantIndexOp>(v->getDefiningOp()).getValue() == 0;
|
cast<ConstantIndexOp>(v->getDefiningOp()).getValue() == 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
using LoopIndexToRangeIndexMap = DenseMap<int, int>;
|
||||||
|
|
||||||
// Creates a number of ranges equal to the number of non-zero in `tileSizes`.
|
// Creates a number of ranges equal to the number of non-zero in `tileSizes`.
|
||||||
// One for each loop of the LinalgOp that is tiled. The `tileSizes` argument has
|
// One for each loop of the LinalgOp that is tiled. The `tileSizes` argument has
|
||||||
// one entry per surrounding loop. It uses zero as the convention that a
|
// one entry per surrounding loop. It uses zero as the convention that a
|
||||||
// particular loop is not tiled. This convention simplifies implementations by
|
// particular loop is not tiled. This convention simplifies implementations by
|
||||||
// avoiding affine map manipulations.
|
// avoiding affine map manipulations.
|
||||||
// The returned ranges correspond to the loop ranges, in the proper order, that
|
// The returned ranges correspond to the loop ranges, in the proper order, that
|
||||||
// are tiled and for which new loops will be created.
|
// are tiled and for which new loops will be created. Also the function returns
|
||||||
static SmallVector<SubViewOp::Range, 4>
|
// a map from loop indices of the LinalgOp to the corresponding non-empty range
|
||||||
|
// indices of newly created loops.
|
||||||
|
static std::tuple<SmallVector<SubViewOp::Range, 4>, LoopIndexToRangeIndexMap>
|
||||||
makeTiledLoopRanges(OpBuilder &b, Location loc, AffineMap map,
|
makeTiledLoopRanges(OpBuilder &b, Location loc, AffineMap map,
|
||||||
ArrayRef<Value *> allViewSizes,
|
ArrayRef<Value *> allViewSizes,
|
||||||
ArrayRef<Value *> allTileSizes, OperationFolder *folder) {
|
ArrayRef<Value *> allTileSizes, OperationFolder *folder) {
|
||||||
|
@ -75,11 +79,15 @@ makeTiledLoopRanges(OpBuilder &b, Location loc, AffineMap map,
|
||||||
SmallVector<Value *, 4> tileSizes(allTileSizes.begin(), allTileSizes.end());
|
SmallVector<Value *, 4> tileSizes(allTileSizes.begin(), allTileSizes.end());
|
||||||
|
|
||||||
// Traverse the tile sizes, which are in loop order, erase zeros everywhere.
|
// Traverse the tile sizes, which are in loop order, erase zeros everywhere.
|
||||||
for (int idx = tileSizes.size() - 1; idx >= 0; --idx) {
|
LoopIndexToRangeIndexMap loopIndexToRangeIndex;
|
||||||
if (isZero(tileSizes[idx])) {
|
for (int idx = 0, e = tileSizes.size(), zerosCount = 0; idx < e; ++idx) {
|
||||||
viewSizes.erase(viewSizes.begin() + idx);
|
if (isZero(tileSizes[idx - zerosCount])) {
|
||||||
tileSizes.erase(tileSizes.begin() + idx);
|
viewSizes.erase(viewSizes.begin() + idx - zerosCount);
|
||||||
|
tileSizes.erase(tileSizes.begin() + idx - zerosCount);
|
||||||
|
++zerosCount;
|
||||||
|
continue;
|
||||||
}
|
}
|
||||||
|
loopIndexToRangeIndex[idx] = idx - zerosCount;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create a new range with the applied tile sizes.
|
// Create a new range with the applied tile sizes.
|
||||||
|
@ -88,10 +96,11 @@ makeTiledLoopRanges(OpBuilder &b, Location loc, AffineMap map,
|
||||||
res.push_back(SubViewOp::Range{constant_index(folder, 0), viewSizes[idx],
|
res.push_back(SubViewOp::Range{constant_index(folder, 0), viewSizes[idx],
|
||||||
tileSizes[idx]});
|
tileSizes[idx]});
|
||||||
}
|
}
|
||||||
return res;
|
return std::make_tuple(res, loopIndexToRangeIndex);
|
||||||
}
|
}
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
// Helper visitor to determine whether an AffineExpr is tiled.
|
// Helper visitor to determine whether an AffineExpr is tiled.
|
||||||
// This is achieved by traversing every AffineDimExpr with position `pos` and
|
// This is achieved by traversing every AffineDimExpr with position `pos` and
|
||||||
// checking whether the corresponding `tileSizes[pos]` is non-zero.
|
// checking whether the corresponding `tileSizes[pos]` is non-zero.
|
||||||
|
@ -117,8 +126,99 @@ struct TileCheck : public AffineExprVisitor<TileCheck> {
|
||||||
bool isTiled;
|
bool isTiled;
|
||||||
ArrayRef<Value *> tileSizes;
|
ArrayRef<Value *> tileSizes;
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
|
// IndexedGenericOp explicitly uses induction variables in the loop body. The
|
||||||
|
// values of the indices that are used in the loop body for any given access of
|
||||||
|
// input/output memref before `subview` op was applied should be invariant with
|
||||||
|
// respect to tiling.
|
||||||
|
//
|
||||||
|
// Therefore, if the operation is tiled, we have to transform the indices
|
||||||
|
// accordingly, i.e. offset them by the values of the corresponding induction
|
||||||
|
// variables that are captured implicitly in the body of the op.
|
||||||
|
//
|
||||||
|
// Example. `linalg.indexed_generic` before tiling:
|
||||||
|
//
|
||||||
|
// #id_2d = (i, j) -> (i, j)
|
||||||
|
// #pointwise_2d_trait = {
|
||||||
|
// indexing_maps = [#id_2d, #id_2d],
|
||||||
|
// iterator_types = ["parallel", "parallel"],
|
||||||
|
// n_views = [1, 1]
|
||||||
|
// }
|
||||||
|
// linalg.indexed_generic #pointwise_2d_trait %operand, %result {
|
||||||
|
// ^bb0(%i: index, %j: index, %operand_in: f32, %result_in: f32):
|
||||||
|
// <some operations that use %i, %j>
|
||||||
|
// }: memref<50x100xf32>, memref<50x100xf32>
|
||||||
|
//
|
||||||
|
// After tiling pass with tiles sizes 10 and 25:
|
||||||
|
//
|
||||||
|
// #strided = (i, j)[s0, s1, s2] -> (i * s1 + s0 + j * s2)
|
||||||
|
//
|
||||||
|
// %c1 = constant 1 : index
|
||||||
|
// %c0 = constant 0 : index
|
||||||
|
// %c25 = constant 25 : index
|
||||||
|
// %c10 = constant 10 : index
|
||||||
|
// operand_dim_0 = dim %operand, 0 : memref<50x100xf32>
|
||||||
|
// operand_dim_1 = dim %operand, 1 : memref<50x100xf32>
|
||||||
|
// loop.for %k = %c0 to operand_dim_0 step %c10 {
|
||||||
|
// loop.for %l = %c0 to operand_dim_1 step %c25 {
|
||||||
|
// %4 = std.subview %operand[%k, %l][%c10, %c25][%c1, %c1]
|
||||||
|
// : memref<50x100xf32> to memref<?x?xf32, #strided>
|
||||||
|
// %5 = std.subview %result[%k, %l][%c10, %c25][%c1, %c1]
|
||||||
|
// : memref<50x100xf32> to memref<?x?xf32, #strided>
|
||||||
|
// linalg.indexed_generic pointwise_2d_trait %4, %5 {
|
||||||
|
// ^bb0(%i: index, %j: index, %operand_in: f32, %result_in: f32):
|
||||||
|
// // Indices `k` and `l` are implicitly captured in the body.
|
||||||
|
// %transformed_i = addi %i, %k : index // index `i` is offset by %k
|
||||||
|
// %transformed_j = addi %j, %l : index // index `j` is offset by %l
|
||||||
|
// // Every use of %i, %j is replaced with %transformed_i, %transformed_j
|
||||||
|
// <some operations that use %transformed_i, %transformed_j>
|
||||||
|
// }: memref<?x?xf32, #strided>, memref<?x?xf32, #strided>
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
//
|
||||||
|
// TODO(pifon, ntv): Investigate whether mixing implicit and explicit indices
|
||||||
|
// does not lead to losing information.
|
||||||
|
void transformIndexedGenericOpIndices(
|
||||||
|
OpBuilder &b, LinalgOp op, ArrayRef<ValueHandle *> pivs,
|
||||||
|
const LoopIndexToRangeIndexMap &loopIndexToRangeIndex) {
|
||||||
|
auto indexedGenericOp = dyn_cast<IndexedGenericOp>(op.getOperation());
|
||||||
|
if (!indexedGenericOp)
|
||||||
|
return;
|
||||||
|
|
||||||
|
// `linalg.indexed_generic` comes in two flavours. One has a region with a
|
||||||
|
// single block that defines the loop body. The other has a `fun` attribute
|
||||||
|
// that refers to an existing function symbol. The `fun` function call will be
|
||||||
|
// inserted in the loop body in that case.
|
||||||
|
//
|
||||||
|
// TODO(pifon): Add support for `linalg.indexed_generic` with `fun` attribute.
|
||||||
|
auto ®ion = indexedGenericOp.region();
|
||||||
|
if (region.empty()) {
|
||||||
|
indexedGenericOp.emitError("op expected a region");
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
auto &block = region.getBlocks().front();
|
||||||
|
|
||||||
|
OpBuilder::InsertionGuard g(b);
|
||||||
|
b.setInsertionPointToStart(&block);
|
||||||
|
for (unsigned i = 0; i < indexedGenericOp.getNumLoops(); ++i) {
|
||||||
|
auto rangeIndex = loopIndexToRangeIndex.find(i);
|
||||||
|
if (rangeIndex == loopIndexToRangeIndex.end())
|
||||||
|
continue;
|
||||||
|
Value *oldIndex = block.getArgument(i);
|
||||||
|
// Offset the index argument `i` by the value of the corresponding induction
|
||||||
|
// variable and replace all uses of the previous value.
|
||||||
|
Value *newIndex = b.create<AddIOp>(indexedGenericOp.getLoc(), oldIndex,
|
||||||
|
pivs[rangeIndex->second]->getValue());
|
||||||
|
for (auto &use : oldIndex->getUses()) {
|
||||||
|
if (use.getOwner() == newIndex->getDefiningOp())
|
||||||
|
continue;
|
||||||
|
use.set(newIndex);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
static bool isTiled(AffineExpr expr, ArrayRef<Value *> tileSizes) {
|
static bool isTiled(AffineExpr expr, ArrayRef<Value *> tileSizes) {
|
||||||
if (!expr)
|
if (!expr)
|
||||||
return false;
|
return false;
|
||||||
|
@ -244,7 +344,10 @@ llvm::Optional<TiledLinalgOp> mlir::linalg::tileLinalgOp(
|
||||||
auto viewSizesToLoopsMap =
|
auto viewSizesToLoopsMap =
|
||||||
inversePermutation(concatAffineMaps(loopToOperandRangesMaps(op)));
|
inversePermutation(concatAffineMaps(loopToOperandRangesMaps(op)));
|
||||||
assert(viewSizesToLoopsMap && "expected invertible map");
|
assert(viewSizesToLoopsMap && "expected invertible map");
|
||||||
auto loopRanges =
|
|
||||||
|
SmallVector<SubViewOp::Range, 4> loopRanges;
|
||||||
|
LoopIndexToRangeIndexMap loopIndexToRangeIndex;
|
||||||
|
std::tie(loopRanges, loopIndexToRangeIndex) =
|
||||||
makeTiledLoopRanges(b, scope.getLocation(), viewSizesToLoopsMap,
|
makeTiledLoopRanges(b, scope.getLocation(), viewSizesToLoopsMap,
|
||||||
viewSizes, tileSizes, folder);
|
viewSizes, tileSizes, folder);
|
||||||
if (!permutation.empty())
|
if (!permutation.empty())
|
||||||
|
@ -274,7 +377,10 @@ llvm::Optional<TiledLinalgOp> mlir::linalg::tileLinalgOp(
|
||||||
res = op.clone(b, loc, views);
|
res = op.clone(b, loc, views);
|
||||||
});
|
});
|
||||||
|
|
||||||
// 4. Gather the newly created loops and return them with the new op.
|
// 4. Transforms index arguments of `linalg.generic` w.r.t. to the tiling.
|
||||||
|
transformIndexedGenericOpIndices(b, res, pivs, loopIndexToRangeIndex);
|
||||||
|
|
||||||
|
// 5. Gather the newly created loops and return them with the new op.
|
||||||
SmallVector<ForOp, 8> loops;
|
SmallVector<ForOp, 8> loops;
|
||||||
loops.reserve(ivs.size());
|
loops.reserve(ivs.size());
|
||||||
for (auto iv : ivs)
|
for (auto iv : ivs)
|
||||||
|
|
|
@ -0,0 +1,102 @@
|
||||||
|
// RUN: mlir-opt %s -linalg-tile -linalg-tile-sizes=10,25 | FileCheck %s -check-prefix=TILE-10n25
|
||||||
|
// RUN: mlir-opt %s -linalg-tile -linalg-tile-sizes=25,0 | FileCheck %s -check-prefix=TILE-25n0
|
||||||
|
// RUN: mlir-opt %s -linalg-tile -linalg-tile-sizes=0,25 | FileCheck %s -check-prefix=TILE-0n25
|
||||||
|
|
||||||
|
#id_1d = (i) -> (i)
|
||||||
|
#pointwise_1d_trait = {
|
||||||
|
indexing_maps = [#id_1d, #id_1d],
|
||||||
|
iterator_types = ["parallel"],
|
||||||
|
n_views = [1, 1]
|
||||||
|
}
|
||||||
|
func @indexed_generic_vector(%operand: memref<50xf32>, %result: memref<50xf32>) {
|
||||||
|
linalg.indexed_generic #pointwise_1d_trait %operand, %result {
|
||||||
|
^bb0(%i: index, %operand_in: f32, %result_in: f32):
|
||||||
|
%i_int = index_cast %i: index to i32
|
||||||
|
%i_float = sitofp %i_int : i32 to f32
|
||||||
|
%out = addf %operand_in, %i_float : f32
|
||||||
|
linalg.yield %out : f32
|
||||||
|
}: memref<50xf32>, memref<50xf32>
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// TILE-10n25-LABEL: func @indexed_generic_vector
|
||||||
|
// TILE-10n25: %[[C10:.*]] = constant 10 : index
|
||||||
|
// TILE-10n25: loop.for %[[J:.*]] = {{.*}} step %[[C10]]
|
||||||
|
// TILE-10n25: linalg.indexed_generic
|
||||||
|
// TILE-10n25: ^bb0(%[[I:.*]]: index, %[[IN:.*]]: f32, %[[OUT:.*]]: f32)
|
||||||
|
// TILE-10n25: %[[NEW_I:.*]] = addi %[[I]], %[[J]] : index
|
||||||
|
// TILE-10n25: %[[NEW_I_INT:.*]] = index_cast %[[NEW_I]] : index to i32
|
||||||
|
// TILE-10n25: %[[NEW_I_FLOAT:.*]] = sitofp %[[NEW_I_INT]] : i32 to f32
|
||||||
|
// TILE-10n25: %[[OUT:.*]] = addf %[[IN]], %[[NEW_I_FLOAT]] : f32
|
||||||
|
|
||||||
|
// TILE-25n0-LABEL: func @indexed_generic_vector
|
||||||
|
// TILE-25n0: %[[C25:.*]] = constant 25 : index
|
||||||
|
// TILE-25n0: loop.for %[[J:.*]] = {{.*}} step %[[C25]]
|
||||||
|
// TILE-25n0: linalg.indexed_generic
|
||||||
|
// TILE-25n0: ^bb0(%[[I:.*]]: index, %[[IN:.*]]: f32, %[[OUT:.*]]: f32)
|
||||||
|
// TILE-25n0: %[[NEW_I:.*]] = addi %[[I]], %[[J]] : index
|
||||||
|
// TILE-25n0: %[[NEW_I_INT:.*]] = index_cast %[[NEW_I]] : index to i32
|
||||||
|
// TILE-25n0: %[[NEW_I_FLOAT:.*]] = sitofp %[[NEW_I_INT]] : i32 to f32
|
||||||
|
// TILE-25n0: %[[OUT:.*]] = addf %[[IN]], %[[NEW_I_FLOAT]] : f32
|
||||||
|
|
||||||
|
// TILE-0n25-LABEL: func @indexed_generic_vector
|
||||||
|
// TILE-0n25-NOT: loop.for %[[J:.*]] = {{.*}} step %[[C25]]
|
||||||
|
// TILE-0n25: linalg.indexed_generic
|
||||||
|
|
||||||
|
#combined_indices_trait = {
|
||||||
|
indexing_maps = [
|
||||||
|
(i, j) -> (j, i + j),
|
||||||
|
(i, j) -> (i, j)
|
||||||
|
],
|
||||||
|
iterator_types = ["parallel", "parallel"],
|
||||||
|
n_views = [1, 1]
|
||||||
|
}
|
||||||
|
func @indexed_generic_matrix(%operand: memref<50x100xf32>, %result: memref<50x100xf32>) {
|
||||||
|
linalg.indexed_generic #combined_indices_trait %operand, %result {
|
||||||
|
^bb0(%i: index, %j: index, %operand_in: f32, %result_in: f32):
|
||||||
|
%i_int = index_cast %i: index to i32
|
||||||
|
%i_float = sitofp %i_int : i32 to f32
|
||||||
|
%j_int = index_cast %j: index to i32
|
||||||
|
%j_float = sitofp %j_int : i32 to f32
|
||||||
|
%out = addf %i_float, %j_float : f32
|
||||||
|
linalg.yield %out : f32
|
||||||
|
}: memref<50x100xf32>, memref<50x100xf32>
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// TILE-10n25-LABEL: func @indexed_generic_matrix
|
||||||
|
// TILE-10n25: %[[C25:.*]] = constant 25 : index
|
||||||
|
// TILE-10n25: %[[C10:.*]] = constant 10 : index
|
||||||
|
// TILE-10n25: loop.for %[[K:.*]] = {{.*}} step %[[C10]]
|
||||||
|
// TILE-10n25: loop.for %[[L:.*]] = {{.*}} step %[[C25]]
|
||||||
|
// TILE-10n25: linalg.indexed_generic
|
||||||
|
// TILE-10n25: ^bb0(%[[I:.*]]: index, %[[J:.*]]: index, %[[IN:.*]]: f32, %[[OUT:.*]]: f32):
|
||||||
|
// TILE-10n25: %[[NEW_I:.*]] = addi %[[I]], %[[K]] : index
|
||||||
|
// TILE-10n25: %[[NEW_J:.*]] = addi %[[J]], %[[L]] : index
|
||||||
|
// TILE-10n25: %[[NEW_INT_I:.*]] = index_cast %[[NEW_I]] : index to i32
|
||||||
|
// TILE-10n25: %[[NEW_FLOAT_I:.*]] = sitofp %[[NEW_INT_I]] : i32 to f32
|
||||||
|
// TILE-10n25: %[[NEW_INT_J:.*]] = index_cast %[[NEW_J]] : index to i32
|
||||||
|
// TILE-10n25: %[[NEW_FLOAT_J:.*]] = sitofp %[[NEW_INT_J]] : i32 to f32
|
||||||
|
// TILE-10n25: %[[OUT:.*]] = addf %[[NEW_FLOAT_I]], %[[NEW_FLOAT_J]] : f32
|
||||||
|
|
||||||
|
// TILE-25n0-LABEL: func @indexed_generic_matrix
|
||||||
|
// TILE-25n0: %[[C25:.*]] = constant 25 : index
|
||||||
|
// TILE-25n0: loop.for %[[L:.*]] = {{.*}} step %[[C25]]
|
||||||
|
// TILE-25n0: linalg.indexed_generic
|
||||||
|
// TILE-25n0: ^bb0(%[[I:.*]]: index, %[[J:.*]]: index, %[[IN:.*]]: f32, %[[OUT:.*]]: f32):
|
||||||
|
// TILE-25n0: %[[NEW_I:.*]] = addi %[[I]], %[[L]] : index
|
||||||
|
// TILE-25n0: %[[NEW_INT_I:.*]] = index_cast %[[NEW_I]] : index to i32
|
||||||
|
// TILE-25n0: %[[NEW_FLOAT_I:.*]] = sitofp %[[NEW_INT_I]] : i32 to f32
|
||||||
|
// TILE-25n0: %[[INT_J:.*]] = index_cast %[[J]] : index to i32
|
||||||
|
// TILE-25n0: %[[FLOAT_J:.*]] = sitofp %[[INT_J]] : i32 to f32
|
||||||
|
// TILE-25n0: %[[OUT:.*]] = addf %[[NEW_FLOAT_I]], %[[FLOAT_J]] : f32
|
||||||
|
|
||||||
|
// TILE-0n25-LABEL: func @indexed_generic_matrix
|
||||||
|
// TILE-0n25: %[[C25:.*]] = constant 25 : index
|
||||||
|
// TILE-0n25: loop.for %[[L:.*]] = {{.*}} step %[[C25]]
|
||||||
|
// TILE-0n25: linalg.indexed_generic
|
||||||
|
// TILE-0n25: ^bb0(%[[I:.*]]: index, %[[J:.*]]: index, %[[IN:.*]]: f32, %[[OUT:.*]]: f32):
|
||||||
|
// TILE-0n25: %[[NEW_J:.*]] = addi %[[J]], %[[L]] : index
|
||||||
|
// TILE-0n25: %[[INT_I:.*]] = index_cast %[[I]] : index to i32
|
||||||
|
// TILE-0n25: %[[FLOAT_I:.*]] = sitofp %[[INT_I]] : i32 to f32
|
||||||
|
// TILE-0n25: %[[NEW_INT_J:.*]] = index_cast %[[NEW_J]] : index to i32
|
||||||
|
// TILE-0n25: %[[NEW_FLOAT_J:.*]] = sitofp %[[NEW_INT_J]] : i32 to f32
|
||||||
|
// TILE-0n25: %[[OUT:.*]] = addf %[[FLOAT_I]], %[[NEW_FLOAT_J]] : f32
|
Loading…
Reference in New Issue