diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.h b/mlir/include/mlir/Dialect/Linalg/Passes.h index 2eda5dcd2af0..61f88c10470a 100644 --- a/mlir/include/mlir/Dialect/Linalg/Passes.h +++ b/mlir/include/mlir/Dialect/Linalg/Passes.h @@ -26,6 +26,9 @@ std::unique_ptr> createLinalgFusionPass(); std::unique_ptr> createLinalgTilingPass(ArrayRef tileSizes = {}); +std::unique_ptr> +createLinalgTilingToParallelLoopsPass(ArrayRef tileSizes = {}); + std::unique_ptr> createLinalgPromotionPass(bool dynamicBuffers); diff --git a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h index a8ff5c297c10..bc316d814a04 100644 --- a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h +++ b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h @@ -118,6 +118,9 @@ Optional tileLinalgOp(OpBuilder &b, LinalgOp op, ArrayRef tileSizes, ArrayRef permutation = {}, OperationFolder *folder = nullptr); +Optional tileLinalgOpToParallelLoops( + OpBuilder &b, LinalgOp op, ArrayRef tileSizes, + ArrayRef permutation = {}, OperationFolder *folder = nullptr); /// Performs standalone tiling of a single LinalgOp by constant `tileSizes`. /// and permute the loop nest according to `permutation` @@ -138,6 +141,9 @@ Optional tileLinalgOp(OpBuilder &b, LinalgOp op, ArrayRef tileSizes, ArrayRef permutation = {}, OperationFolder *folder = nullptr); +Optional tileLinalgOpToParallelLoops( + OpBuilder &b, LinalgOp op, ArrayRef tileSizes, + ArrayRef permutation = {}, OperationFolder *folder = nullptr); template Optional tileLinalgOperation(OpBuilder &b, Operation *op, diff --git a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp index b69665b73bbf..51565be16572 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp @@ -451,12 +451,26 @@ mlir::linalg::tileLinalgOp(OpBuilder &b, LinalgOp op, ArrayRef tileSizes, return tileLinalgOpImpl(b, op, tileSizes, permutation, folder); } +Optional mlir::linalg::tileLinalgOpToParallelLoops( + OpBuilder &b, LinalgOp op, ArrayRef tileSizes, + ArrayRef permutation, OperationFolder *folder) { + return tileLinalgOpImpl(b, op, tileSizes, permutation, + folder); +} + Optional mlir::linalg::tileLinalgOp( OpBuilder &b, LinalgOp op, ArrayRef tileSizes, ArrayRef permutation, OperationFolder *folder) { return tileLinalgOpImpl(b, op, tileSizes, permutation, folder); } +Optional mlir::linalg::tileLinalgOpToParallelLoops( + OpBuilder &b, LinalgOp op, ArrayRef tileSizes, + ArrayRef permutation, OperationFolder *folder) { + return tileLinalgOpImpl(b, op, tileSizes, permutation, + folder); +} + template static void tileLinalgOps(FuncOp f, ArrayRef tileSizes) { OpBuilder b(f); @@ -501,9 +515,23 @@ mlir::createLinalgTilingPass(ArrayRef tileSizes) { return std::make_unique>(tileSizes); } +std::unique_ptr> +mlir::createLinalgTilingToParallelLoopsPass(ArrayRef tileSizes) { + return std::make_unique>(tileSizes); +} + static PassRegistration> - pass("linalg-tile", "Tile operations in the linalg dialect", [] { + tiling_pass("linalg-tile", "Tile operations in the linalg dialect", [] { auto pass = std::make_unique>(); pass->tileSizes.assign(clTileSizes.begin(), clTileSizes.end()); return pass; }); + +static PassRegistration> + tiling_to_parallel_loops( + "linalg-tile-to-parallel-loops", + "Tile operations in the linalg dialect to parallel loops", [] { + auto pass = std::make_unique>(); + pass->tileSizes.assign(clTileSizes.begin(), clTileSizes.end()); + return pass; + }); diff --git a/mlir/test/Dialect/Linalg/tile_parallel.mlir b/mlir/test/Dialect/Linalg/tile_parallel.mlir new file mode 100644 index 000000000000..7dbcce9e3a8a --- /dev/null +++ b/mlir/test/Dialect/Linalg/tile_parallel.mlir @@ -0,0 +1,70 @@ +// RUN: mlir-opt %s -linalg-tile-to-parallel-loops -linalg-tile-sizes=2 | FileCheck %s -check-prefix=TILE-2 --dump-input-on-failure +// RUN: mlir-opt %s -linalg-tile-to-parallel-loops -linalg-tile-sizes=0,2 | FileCheck %s -check-prefix=TILE-02 --dump-input-on-failure +// RUN: mlir-opt %s -linalg-tile-to-parallel-loops -linalg-tile-sizes=0,0,2 | FileCheck %s -check-prefix=TILE-002 --dump-input-on-failure +// RUN: mlir-opt %s -linalg-tile-to-parallel-loops -linalg-tile-sizes=2,3,4 | FileCheck %s -check-prefix=TILE-234 --dump-input-on-failure + +#id_2d = affine_map<(i, j) -> (i, j)> +#pointwise_2d_trait = { + args_in = 2, + args_out = 1, + indexing_maps = [#id_2d, #id_2d, #id_2d], + iterator_types = ["parallel", "parallel"] +} + +func @sum(%lhs: memref, + %rhs: memref, + %sum: memref) { + linalg.generic #pointwise_2d_trait %lhs, %rhs, %sum { + ^bb0(%lhs_in: f32, %rhs_in: f32, %sum_out: f32): + %result = addf %lhs_in, %rhs_in : f32 + linalg.yield %result : f32 + }: memref, + memref, + memref + return +} +// TILE-2-LABEL: func @sum( +// TILE-2-SAME: [[LHS:%.*]]: {{.*}}, [[RHS:%.*]]: {{.*}}, [[SUM:%.*]]: {{.*}}) { +// TILE-2-DAG: [[C0:%.*]] = constant 0 : index +// TILE-2-DAG: [[C1:%.*]] = constant 1 : index +// TILE-2-DAG: [[C2:%.*]] = constant 2 : index +// TILE-2: [[LHS_ROWS:%.*]] = dim [[LHS]], 0 +// TILE-2: loop.parallel ([[I:%.*]]) = ([[C0]]) to ([[LHS_ROWS]]) step ([[C2]]) { +// TILE-2-NO: loop.parallel +// TILE-2: [[LHS_SUBVIEW:%.*]] = std.subview [[LHS]] +// TILE-2: [[RHS_SUBVIEW:%.*]] = std.subview [[RHS]] +// TILE-2: [[SUM_SUBVIEW:%.*]] = std.subview [[SUM]] +// TILE-2: linalg.generic {{.*}} [[LHS_SUBVIEW]], [[RHS_SUBVIEW]], [[SUM_SUBVIEW]] { + +// TILE-02-LABEL: func @sum( +// TILE-02-SAME: [[LHS:%.*]]: {{.*}}, [[RHS:%.*]]: {{.*}}, [[SUM:%.*]]: {{.*}}) { +// TILE-02-DAG: [[C0:%.*]] = constant 0 : index +// TILE-02-DAG: [[C1:%.*]] = constant 1 : index +// TILE-02-DAG: [[C2:%.*]] = constant 2 : index +// TILE-02: [[LHS_COLS:%.*]] = dim [[LHS]], 1 +// TILE-02: loop.parallel ([[I:%.*]]) = ([[C0]]) to ([[LHS_COLS]]) step ([[C2]]) { +// TILE-02-NO: loop.parallel +// TILE-02: [[LHS_SUBVIEW:%.*]] = std.subview [[LHS]] +// TILE-02: [[RHS_SUBVIEW:%.*]] = std.subview [[RHS]] +// TILE-02: [[SUM_SUBVIEW:%.*]] = std.subview [[SUM]] +// TILE-02: linalg.generic {{.*}} [[LHS_SUBVIEW]], [[RHS_SUBVIEW]], [[SUM_SUBVIEW]] { + +// TILE-002-LABEL: func @sum( +// TILE-002-SAME: [[LHS:%.*]]: {{.*}}, [[RHS:%.*]]: {{.*}}, [[SUM:%.*]]: {{.*}}) { +// TILE-002-NO: loop.parallel +// TILE-002: linalg.generic {{.*}} [[LHS]], [[RHS]], [[SUM]] { + +// TILE-234-LABEL: func @sum( +// TILE-234-SAME: [[LHS:%.*]]: {{.*}}, [[RHS:%.*]]: {{.*}}, [[SUM:%.*]]: {{.*}}) { +// TILE-234-DAG: [[C0:%.*]] = constant 0 : index +// TILE-234-DAG: [[C1:%.*]] = constant 1 : index +// TILE-234-DAG: [[C2:%.*]] = constant 2 : index +// TILE-234-DAG: [[C3:%.*]] = constant 3 : index +// TILE-234: [[LHS_ROWS:%.*]] = dim [[LHS]], 0 +// TILE-234: [[LHS_COLS:%.*]] = dim [[LHS]], 1 +// TILE-234: loop.parallel ([[I:%.*]], [[J:%.*]]) = ([[C0]], [[C0]]) to ([[LHS_ROWS]], [[LHS_COLS]]) step ([[C2]], [[C3]]) { +// TILE-234-NO: loop.parallel +// TILE-234: [[LHS_SUBVIEW:%.*]] = std.subview [[LHS]] +// TILE-234: [[RHS_SUBVIEW:%.*]] = std.subview [[RHS]] +// TILE-234: [[SUM_SUBVIEW:%.*]] = std.subview [[SUM]] +// TILE-234: linalg.generic {{.*}} [[LHS_SUBVIEW]], [[RHS_SUBVIEW]], [[SUM_SUBVIEW]] {