[mlir][linalg] Improve codegen when tiling PadTensor evenly

Produce simpler IR with more static type information and fewer affine expressions.

Differential Revision: https://reviews.llvm.org/D105530
This commit is contained in:
Matthias Springer 2021-07-15 11:27:52 +09:00
parent 318ce4ad92
commit a0e02018be
3 changed files with 54 additions and 20 deletions

View File

@ -494,6 +494,7 @@ void mlir::linalg::populateLinalgTilingCanonicalizationPatterns(
scf::ParallelOp::getCanonicalizationPatterns(patterns, ctx);
ConstantIndexOp::getCanonicalizationPatterns(patterns, ctx);
tensor::ExtractSliceOp::getCanonicalizationPatterns(patterns, ctx);
tensor::InsertSliceOp::getCanonicalizationPatterns(patterns, ctx);
memref::SubViewOp::getCanonicalizationPatterns(patterns, ctx);
tensor::CastOp::getCanonicalizationPatterns(patterns, ctx);
memref::ViewOp::getCanonicalizationPatterns(patterns, ctx);
@ -513,7 +514,15 @@ static void insertTilingPatterns(RewritePatternSet &patterns,
#include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc"
>::insert(patterns, options);
patterns.add<PadTensorOpTilingPattern>(patterns.getContext(), options);
}
static void applyExtractSliceOfPadTensorSwapPattern(FuncOp funcOp) {
MLIRContext *ctx = funcOp.getContext();
RewritePatternSet patterns(ctx);
patterns.add<ExtractSliceOfPadTensorSwapPattern>(patterns.getContext());
(void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
(void)applyPatternsAndFoldGreedily(
funcOp, getLinalgTilingCanonicalizationPatterns(ctx));
}
static void
@ -527,6 +536,7 @@ applyTilingToLoopPatterns(LinalgTilingLoopType loopType, FuncOp funcOp,
MLIRContext *ctx = funcOp.getContext();
RewritePatternSet patterns(ctx);
insertTilingPatterns(patterns, options);
patterns.add<AffineMinSCFCanonicalizationPattern>(patterns.getContext());
(void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
(void)applyPatternsAndFoldGreedily(
funcOp, getLinalgTilingCanonicalizationPatterns(ctx));
@ -534,6 +544,10 @@ applyTilingToLoopPatterns(LinalgTilingLoopType loopType, FuncOp funcOp,
funcOp.walk([](LinalgOp op) {
op->removeAttr(LinalgTransforms::kLinalgTransformMarker);
});
// Apply swap pattern after generating loop nest and running
// canonicalizations.
applyExtractSliceOfPadTensorSwapPattern(funcOp);
}
namespace {

View File

@ -92,3 +92,33 @@ func @static_pad_tensor(%input_tensor: tensor<7x9xf32>,
} : tensor<7x9xf32> to tensor<15x16xf32>
return %0 : tensor<15x16xf32>
}
// -----
// TILE1-LABEL: func @static_pad_tile_evenly(
// TILE1-SAME: %[[IN:.*]]: tensor<7x9xf32>, %[[OUT:.*]]: tensor<14x15xf32>
// TILE1-DAG: %[[C0:.*]] = constant 0 : index
// TILE1-DAG: %[[C3:.*]] = constant 3 : index
// TILE1-DAG: %[[C15:.*]] = constant 15 : index
// TILE1: %[[RESULT:.*]] = scf.for %[[IV:.*]] = %[[C0]] to %[[C15]] step %[[C3]] iter_args(%[[INNER_OUT:.*]] =
// TILE1: %[[R2:.*]] = scf.if
// TILE1: %[[GEN:.*]] = tensor.generate
// TILE1: scf.yield %[[GEN]] : tensor<14x3xf32>
// TILE1: else
// TILE1: %[[SLICE:.*]] = tensor.extract_slice %arg0[0, %{{.*}}] [7, %{{.*}}] [1, 1] : tensor<7x9xf32> to tensor<7x?xf32>
// TILE1: %[[PAD:.*]] = linalg.pad_tensor %8 low[0, 0] high[7, %{{.*}}]
// TILE1: %[[CAST:.*]] = tensor.cast %[[PAD]] : tensor<14x?xf32> to tensor<14x3xf32>
// TILE1: scf.yield %[[CAST]] : tensor<14x3xf32>
// TILE1: %[[R3:.*]] = tensor.insert_slice %[[R2]] into %[[INNER_OUT]][0, %[[IV]]] [14, 3] [1, 1] : tensor<14x3xf32> into tensor<14x15xf32>
// TILE1: scf.yield %[[R3]] : tensor<14x15xf32>
// TILE1: return %[[RESULT]] : tensor<14x15xf32>
func @static_pad_tile_evenly(%input_tensor: tensor<7x9xf32>,
%output_tensor: tensor<14x15xf32>,
%pad_value: f32) -> tensor<14x15xf32> {
%0 = linalg.pad_tensor %input_tensor
low[0, 0] high[7, 6] into %output_tensor {
^bb0(%arg1: index, %arg2: index):
linalg.yield %pad_value : f32
} : tensor<7x9xf32> to tensor<14x15xf32>
return %0 : tensor<14x15xf32>
}

View File

@ -20,10 +20,6 @@
// TILE-234-DAG: #[[$bound_map_3:.*]] = affine_map<(d0)[s0] -> (3, -d0 + s0)>
// TILE-234-DAG: #[[$bound_map_4:.*]] = affine_map<(d0)[s0] -> (4, -d0 + s0)>
// TILE-2-DAG: #[[$bound_map_static:.*]] = affine_map<(d0) -> (2, -d0 + 10)>
// TILE-02-DAG: #[[$bound_map_static:.*]] = affine_map<(d0) -> (2, -d0 + 12)>
// TILE-002-DAG: #[[$bound_map_static:.*]] = affine_map<(d0) -> (2, -d0 + 16)>
// TILE-2-DAG: #[[$stride_99_1_layout_map:.*]] = affine_map<(d0, d1)[s0] -> (d0 * 99 + s0 + d1)>
// TILE-02-DAG: #[[$stride_99_1_layout_map:.*]] = affine_map<(d0, d1)[s0] -> (d0 * 99 + s0 + d1)>
// TILE-234-DAG: #[[$stride_99_1_layout_map:.*]] = affine_map<(d0, d1)[s0] -> (d0 * 99 + s0 + d1)>
@ -132,10 +128,8 @@ func @matmul_static(%arg0: memref<10x16xf32, offset: ?, strides: [?, 1]>,
// TILE-2-DAG: %[[C2:.*]] = constant 2 : index
// TILE-2-DAG: %[[M:.*]] = constant 10 : index
// TILE-2: scf.for %[[I:.*]] = %{{.*}} to %[[M]] step %{{.*}} {
// TILE-2: %[[MIN2:.*]] = affine.min #[[$bound_map_static]](%[[I]])
// TILE-2: %[[sAi:.*]] = memref.subview %{{.*}}[%[[I]], 0] [%[[MIN2]], 16] [1, 1] : memref<10x16xf32, #[[$strided2D]]> to memref<?x16xf32, #[[$strided2D]]>
// TILE-2: %[[MIN22:.*]] = affine.min #[[$bound_map_static]](%[[I]])
// TILE-2: %[[sCi:.*]] = memref.subview %{{.*}}[%[[I]], 0] [%[[MIN22]], 12] [1, 1] : memref<10x12xf32, #[[$strided2D]]> to memref<?x12xf32, #[[$strided2D]]>
// TILE-2: %[[sAi:.*]] = memref.subview %{{.*}}[%[[I]], 0] [2, 16] [1, 1] : memref<10x16xf32, #[[$strided2D]]> to memref<2x16xf32, #[[$strided2D]]>
// TILE-2: %[[sCi:.*]] = memref.subview %{{.*}}[%[[I]], 0] [2, 12] [1, 1] : memref<10x12xf32, #[[$strided2D]]> to memref<2x12xf32, #[[$strided2D]]>
// TILE-2: linalg.matmul ins(%[[sAi]], %{{.*}}{{.*}} outs(%[[sCi]]
// TILE-02-LABEL: func @matmul_static(
@ -143,10 +137,8 @@ func @matmul_static(%arg0: memref<10x16xf32, offset: ?, strides: [?, 1]>,
// TILE-02-DAG: %[[C2:.*]] = constant 2 : index
// TILE-02-DAG: %[[N:.*]] = constant 12 : index
// TILE-02: scf.for %[[J:.*]] = %{{.*}} to %[[N]] step %{{.*}} {
// TILE-02: %[[MIN2:.*]] = affine.min #[[$bound_map_static]](%[[J]])
// TILE-02: %[[sBj:.*]] = memref.subview %{{.*}}[0, %[[J]]] [16, %[[MIN2]]] [1, 1] : memref<16x12xf32, #[[$strided2D]]> to memref<16x?xf32, #[[$strided2D]]>
// TILE-02: %[[MIN22:.*]] = affine.min #[[$bound_map_static]](%[[J]])
// TILE-02: %[[sCj:.*]] = memref.subview %{{.*}}[0, %[[J]]] [10, %[[MIN22]]] [1, 1] : memref<10x12xf32, #[[$strided2D]]> to memref<10x?xf32, #[[$strided2D]]>
// TILE-02: %[[sBj:.*]] = memref.subview %{{.*}}[0, %[[J]]] [16, 2] [1, 1] : memref<16x12xf32, #[[$strided2D]]> to memref<16x2xf32, #[[$strided2D]]>
// TILE-02: %[[sCj:.*]] = memref.subview %{{.*}}[0, %[[J]]] [10, 2] [1, 1] : memref<10x12xf32, #[[$strided2D]]> to memref<10x2xf32, #[[$strided2D]]>
// TILE-02: linalg.matmul ins(%{{.*}}, %[[sBj]]{{.*}} outs(%[[sCj]]
// TILE-002-LABEL: func @matmul_static(
@ -154,10 +146,8 @@ func @matmul_static(%arg0: memref<10x16xf32, offset: ?, strides: [?, 1]>,
// TILE-002-DAG: %[[C2:.*]] = constant 2 : index
// TILE-002-DAG: %[[C16:.*]] = constant 16 : index
// TILE-002: scf.for %[[K:.*]] = %{{.*}}{{.*}} to %[[C16]] step %{{.*}} {
// TILE-002: %[[MIN2:.*]] = affine.min #[[$bound_map_static]](%[[K]])
// TILE-002: %[[sAj:.*]] = memref.subview %{{.*}}[0, %[[K]]] [10, %[[MIN2]]] [1, 1] : memref<10x16xf32, #[[$strided2D]]> to memref<10x?xf32, #[[$strided2D]]>
// TILE-002: %[[MIN22:.*]] = affine.min #[[$bound_map_static]](%[[K]])
// TILE-002: %[[sBj:.*]] = memref.subview %{{.*}}[%[[K]], 0] [%[[MIN22]], 12] [1, 1] : memref<16x12xf32, #[[$strided2D]]> to memref<?x12xf32, #[[$strided2D]]>
// TILE-002: %[[sAj:.*]] = memref.subview %{{.*}}[0, %[[K]]] [10, 2] [1, 1] : memref<10x16xf32, #[[$strided2D]]> to memref<10x2xf32, #[[$strided2D]]>
// TILE-002: %[[sBj:.*]] = memref.subview %{{.*}}[%[[K]], 0] [2, 12] [1, 1] : memref<16x12xf32, #[[$strided2D]]> to memref<2x12xf32, #[[$strided2D]]>
// TILE-002: linalg.matmul ins(%[[sAj]], %[[sBj]]{{.*}} outs(%{{.*}}
// TILE-234-LABEL: func @matmul_static(
@ -171,9 +161,9 @@ func @matmul_static(%arg0: memref<10x16xf32, offset: ?, strides: [?, 1]>,
// TILE-234: scf.for %[[I:.*]] = %{{.*}}{{.*}} to %[[C10]] step %{{.*}} {
// TILE-234: scf.for %[[J:.*]] = %{{.*}}{{.*}} to %[[C12]] step %{{.*}} {
// TILE-234: scf.for %[[K:.*]] = %{{.*}}{{.*}} to %[[C16]] step %{{.*}} {
// TILE-234: %[[sAik:.*]] = memref.subview %{{.*}}[%[[I]], %[[K]]] [%{{.*}}, %{{.*}}] [1, 1] : memref<10x16xf32, #[[$strided2D]]> to memref<?x?xf32, #[[$strided2D]]>
// TILE-234: %[[sBkj:.*]] = memref.subview %{{.*}}[%[[K]], %[[J]]] [%{{.*}}, %{{.*}}] [1, 1] : memref<16x12xf32, #[[$strided2D]]> to memref<?x?xf32, #[[$strided2D]]>
// TILE-234: %[[sCij:.*]] = memref.subview %{{.*}}[%[[I]], %[[J]]] [%{{.*}}, %{{.*}}] [1, 1] : memref<10x12xf32, #[[$strided2D]]> to memref<?x?xf32, #[[$strided2D]]>
// TILE-234: %[[sAik:.*]] = memref.subview %{{.*}}[%[[I]], %[[K]]] [2, 4] [1, 1] : memref<10x16xf32, #[[$strided2D]]> to memref<2x4xf32, #[[$strided2D]]>
// TILE-234: %[[sBkj:.*]] = memref.subview %{{.*}}[%[[K]], %[[J]]] [4, 3] [1, 1] : memref<16x12xf32, #[[$strided2D]]> to memref<4x3xf32, #[[$strided2D]]>
// TILE-234: %[[sCij:.*]] = memref.subview %{{.*}}[%[[I]], %[[J]]] [2, 3] [1, 1] : memref<10x12xf32, #[[$strided2D]]> to memref<2x3xf32, #[[$strided2D]]>
//
// TILE-234: linalg.matmul ins(%[[sAik]], %[[sBkj]]{{.*}} outs(%[[sCij]]
@ -312,7 +302,7 @@ func @fill_static(%arg0: memref<127x99xf32>, %arg1: f32) {
// TILE-234: for
// TILE-234-NOT: for
// TILE-234: memref.subview{{.*}} : memref<127x99xf32>
// TILE-234: linalg.fill{{.*}} : f32, memref<?x?xf32, #[[$stride_99_1_layout_map]]>
// TILE-234: linalg.fill{{.*}} : f32, memref<?x3xf32, #[[$stride_99_1_layout_map]]>
func @fill(%arg0: memref<?x?xf32, offset: ?, strides: [?, 1]>, %arg1: f32) {