forked from OSchip/llvm-project
[Linalg] Add tiling of Linalg to parallel loops.
Differential Revision: https://reviews.llvm.org/D73955
This commit is contained in:
parent
399887c9e4
commit
baecae838d
|
@ -26,6 +26,9 @@ std::unique_ptr<OpPassBase<FuncOp>> createLinalgFusionPass();
|
|||
std::unique_ptr<OpPassBase<FuncOp>>
|
||||
createLinalgTilingPass(ArrayRef<int64_t> tileSizes = {});
|
||||
|
||||
std::unique_ptr<OpPassBase<FuncOp>>
|
||||
createLinalgTilingToParallelLoopsPass(ArrayRef<int64_t> tileSizes = {});
|
||||
|
||||
std::unique_ptr<OpPassBase<FuncOp>>
|
||||
createLinalgPromotionPass(bool dynamicBuffers);
|
||||
|
||||
|
|
|
@ -118,6 +118,9 @@ Optional<TiledLinalgOp> tileLinalgOp(OpBuilder &b, LinalgOp op,
|
|||
ArrayRef<Value> tileSizes,
|
||||
ArrayRef<unsigned> permutation = {},
|
||||
OperationFolder *folder = nullptr);
|
||||
Optional<TiledLinalgOp> tileLinalgOpToParallelLoops(
|
||||
OpBuilder &b, LinalgOp op, ArrayRef<Value> tileSizes,
|
||||
ArrayRef<unsigned> permutation = {}, OperationFolder *folder = nullptr);
|
||||
|
||||
/// Performs standalone tiling of a single LinalgOp by constant `tileSizes`.
|
||||
/// and permute the loop nest according to `permutation`
|
||||
|
@ -138,6 +141,9 @@ Optional<TiledLinalgOp> tileLinalgOp(OpBuilder &b, LinalgOp op,
|
|||
ArrayRef<int64_t> tileSizes,
|
||||
ArrayRef<unsigned> permutation = {},
|
||||
OperationFolder *folder = nullptr);
|
||||
Optional<TiledLinalgOp> tileLinalgOpToParallelLoops(
|
||||
OpBuilder &b, LinalgOp op, ArrayRef<int64_t> tileSizes,
|
||||
ArrayRef<unsigned> permutation = {}, OperationFolder *folder = nullptr);
|
||||
|
||||
template <typename... Args>
|
||||
Optional<TiledLinalgOp> tileLinalgOperation(OpBuilder &b, Operation *op,
|
||||
|
|
|
@ -451,12 +451,26 @@ mlir::linalg::tileLinalgOp(OpBuilder &b, LinalgOp op, ArrayRef<Value> tileSizes,
|
|||
return tileLinalgOpImpl<loop::ForOp>(b, op, tileSizes, permutation, folder);
|
||||
}
|
||||
|
||||
Optional<TiledLinalgOp> mlir::linalg::tileLinalgOpToParallelLoops(
|
||||
OpBuilder &b, LinalgOp op, ArrayRef<Value> tileSizes,
|
||||
ArrayRef<unsigned> permutation, OperationFolder *folder) {
|
||||
return tileLinalgOpImpl<loop::ParallelOp>(b, op, tileSizes, permutation,
|
||||
folder);
|
||||
}
|
||||
|
||||
Optional<TiledLinalgOp> mlir::linalg::tileLinalgOp(
|
||||
OpBuilder &b, LinalgOp op, ArrayRef<int64_t> tileSizes,
|
||||
ArrayRef<unsigned> permutation, OperationFolder *folder) {
|
||||
return tileLinalgOpImpl<loop::ForOp>(b, op, tileSizes, permutation, folder);
|
||||
}
|
||||
|
||||
Optional<TiledLinalgOp> mlir::linalg::tileLinalgOpToParallelLoops(
|
||||
OpBuilder &b, LinalgOp op, ArrayRef<int64_t> tileSizes,
|
||||
ArrayRef<unsigned> permutation, OperationFolder *folder) {
|
||||
return tileLinalgOpImpl<loop::ParallelOp>(b, op, tileSizes, permutation,
|
||||
folder);
|
||||
}
|
||||
|
||||
template <typename LoopTy>
|
||||
static void tileLinalgOps(FuncOp f, ArrayRef<int64_t> tileSizes) {
|
||||
OpBuilder b(f);
|
||||
|
@ -501,9 +515,23 @@ mlir::createLinalgTilingPass(ArrayRef<int64_t> tileSizes) {
|
|||
return std::make_unique<LinalgTilingPass<loop::ForOp>>(tileSizes);
|
||||
}
|
||||
|
||||
std::unique_ptr<OpPassBase<FuncOp>>
|
||||
mlir::createLinalgTilingToParallelLoopsPass(ArrayRef<int64_t> tileSizes) {
|
||||
return std::make_unique<LinalgTilingPass<loop::ParallelOp>>(tileSizes);
|
||||
}
|
||||
|
||||
static PassRegistration<LinalgTilingPass<loop::ForOp>>
|
||||
pass("linalg-tile", "Tile operations in the linalg dialect", [] {
|
||||
tiling_pass("linalg-tile", "Tile operations in the linalg dialect", [] {
|
||||
auto pass = std::make_unique<LinalgTilingPass<loop::ForOp>>();
|
||||
pass->tileSizes.assign(clTileSizes.begin(), clTileSizes.end());
|
||||
return pass;
|
||||
});
|
||||
|
||||
static PassRegistration<LinalgTilingPass<loop::ParallelOp>>
|
||||
tiling_to_parallel_loops(
|
||||
"linalg-tile-to-parallel-loops",
|
||||
"Tile operations in the linalg dialect to parallel loops", [] {
|
||||
auto pass = std::make_unique<LinalgTilingPass<loop::ParallelOp>>();
|
||||
pass->tileSizes.assign(clTileSizes.begin(), clTileSizes.end());
|
||||
return pass;
|
||||
});
|
||||
|
|
|
@ -0,0 +1,70 @@
|
|||
// RUN: mlir-opt %s -linalg-tile-to-parallel-loops -linalg-tile-sizes=2 | FileCheck %s -check-prefix=TILE-2 --dump-input-on-failure
|
||||
// RUN: mlir-opt %s -linalg-tile-to-parallel-loops -linalg-tile-sizes=0,2 | FileCheck %s -check-prefix=TILE-02 --dump-input-on-failure
|
||||
// RUN: mlir-opt %s -linalg-tile-to-parallel-loops -linalg-tile-sizes=0,0,2 | FileCheck %s -check-prefix=TILE-002 --dump-input-on-failure
|
||||
// RUN: mlir-opt %s -linalg-tile-to-parallel-loops -linalg-tile-sizes=2,3,4 | FileCheck %s -check-prefix=TILE-234 --dump-input-on-failure
|
||||
|
||||
#id_2d = affine_map<(i, j) -> (i, j)>
|
||||
#pointwise_2d_trait = {
|
||||
args_in = 2,
|
||||
args_out = 1,
|
||||
indexing_maps = [#id_2d, #id_2d, #id_2d],
|
||||
iterator_types = ["parallel", "parallel"]
|
||||
}
|
||||
|
||||
func @sum(%lhs: memref<?x?xf32, offset: ?, strides: [?, 1]>,
|
||||
%rhs: memref<?x?xf32, offset: ?, strides: [?, 1]>,
|
||||
%sum: memref<?x?xf32, offset: ?, strides: [?, 1]>) {
|
||||
linalg.generic #pointwise_2d_trait %lhs, %rhs, %sum {
|
||||
^bb0(%lhs_in: f32, %rhs_in: f32, %sum_out: f32):
|
||||
%result = addf %lhs_in, %rhs_in : f32
|
||||
linalg.yield %result : f32
|
||||
}: memref<?x?xf32, offset: ?, strides: [?, 1]>,
|
||||
memref<?x?xf32, offset: ?, strides: [?, 1]>,
|
||||
memref<?x?xf32, offset: ?, strides: [?, 1]>
|
||||
return
|
||||
}
|
||||
// TILE-2-LABEL: func @sum(
|
||||
// TILE-2-SAME: [[LHS:%.*]]: {{.*}}, [[RHS:%.*]]: {{.*}}, [[SUM:%.*]]: {{.*}}) {
|
||||
// TILE-2-DAG: [[C0:%.*]] = constant 0 : index
|
||||
// TILE-2-DAG: [[C1:%.*]] = constant 1 : index
|
||||
// TILE-2-DAG: [[C2:%.*]] = constant 2 : index
|
||||
// TILE-2: [[LHS_ROWS:%.*]] = dim [[LHS]], 0
|
||||
// TILE-2: loop.parallel ([[I:%.*]]) = ([[C0]]) to ([[LHS_ROWS]]) step ([[C2]]) {
|
||||
// TILE-2-NO: loop.parallel
|
||||
// TILE-2: [[LHS_SUBVIEW:%.*]] = std.subview [[LHS]]
|
||||
// TILE-2: [[RHS_SUBVIEW:%.*]] = std.subview [[RHS]]
|
||||
// TILE-2: [[SUM_SUBVIEW:%.*]] = std.subview [[SUM]]
|
||||
// TILE-2: linalg.generic {{.*}} [[LHS_SUBVIEW]], [[RHS_SUBVIEW]], [[SUM_SUBVIEW]] {
|
||||
|
||||
// TILE-02-LABEL: func @sum(
|
||||
// TILE-02-SAME: [[LHS:%.*]]: {{.*}}, [[RHS:%.*]]: {{.*}}, [[SUM:%.*]]: {{.*}}) {
|
||||
// TILE-02-DAG: [[C0:%.*]] = constant 0 : index
|
||||
// TILE-02-DAG: [[C1:%.*]] = constant 1 : index
|
||||
// TILE-02-DAG: [[C2:%.*]] = constant 2 : index
|
||||
// TILE-02: [[LHS_COLS:%.*]] = dim [[LHS]], 1
|
||||
// TILE-02: loop.parallel ([[I:%.*]]) = ([[C0]]) to ([[LHS_COLS]]) step ([[C2]]) {
|
||||
// TILE-02-NO: loop.parallel
|
||||
// TILE-02: [[LHS_SUBVIEW:%.*]] = std.subview [[LHS]]
|
||||
// TILE-02: [[RHS_SUBVIEW:%.*]] = std.subview [[RHS]]
|
||||
// TILE-02: [[SUM_SUBVIEW:%.*]] = std.subview [[SUM]]
|
||||
// TILE-02: linalg.generic {{.*}} [[LHS_SUBVIEW]], [[RHS_SUBVIEW]], [[SUM_SUBVIEW]] {
|
||||
|
||||
// TILE-002-LABEL: func @sum(
|
||||
// TILE-002-SAME: [[LHS:%.*]]: {{.*}}, [[RHS:%.*]]: {{.*}}, [[SUM:%.*]]: {{.*}}) {
|
||||
// TILE-002-NO: loop.parallel
|
||||
// TILE-002: linalg.generic {{.*}} [[LHS]], [[RHS]], [[SUM]] {
|
||||
|
||||
// TILE-234-LABEL: func @sum(
|
||||
// TILE-234-SAME: [[LHS:%.*]]: {{.*}}, [[RHS:%.*]]: {{.*}}, [[SUM:%.*]]: {{.*}}) {
|
||||
// TILE-234-DAG: [[C0:%.*]] = constant 0 : index
|
||||
// TILE-234-DAG: [[C1:%.*]] = constant 1 : index
|
||||
// TILE-234-DAG: [[C2:%.*]] = constant 2 : index
|
||||
// TILE-234-DAG: [[C3:%.*]] = constant 3 : index
|
||||
// TILE-234: [[LHS_ROWS:%.*]] = dim [[LHS]], 0
|
||||
// TILE-234: [[LHS_COLS:%.*]] = dim [[LHS]], 1
|
||||
// TILE-234: loop.parallel ([[I:%.*]], [[J:%.*]]) = ([[C0]], [[C0]]) to ([[LHS_ROWS]], [[LHS_COLS]]) step ([[C2]], [[C3]]) {
|
||||
// TILE-234-NO: loop.parallel
|
||||
// TILE-234: [[LHS_SUBVIEW:%.*]] = std.subview [[LHS]]
|
||||
// TILE-234: [[RHS_SUBVIEW:%.*]] = std.subview [[RHS]]
|
||||
// TILE-234: [[SUM_SUBVIEW:%.*]] = std.subview [[SUM]]
|
||||
// TILE-234: linalg.generic {{.*}} [[LHS_SUBVIEW]], [[RHS_SUBVIEW]], [[SUM_SUBVIEW]] {
|
Loading…
Reference in New Issue