forked from OSchip/llvm-project
[mlir][linalg] Fix tile and fuse for outermost reduction.
Tile and fuse failed if the outermost tile loop is a reduction dimension. Add the necessary check to handle outermost reductions and introduce a test case to verify the change. Reviewed By: nicolasvasilache Differential Revision: https://reviews.llvm.org/D114012
This commit is contained in:
parent
a9e236bed8
commit
0ccc44cec0
|
@ -225,9 +225,9 @@ public:
|
|||
LogicalResult tileRootOp(OpBuilder &b, ArrayRef<int64_t> tileSizes,
|
||||
ArrayRef<int64_t> tileInterchange);
|
||||
|
||||
/// Fuse the producer of `rootOpOperand` into the tile loop nest. Returns the
|
||||
/// fused producer of fails if fusion is not possible.
|
||||
FailureOr<LinalgOp> fuseProducer(OpBuilder &b, OpOperand *rootOpOperand);
|
||||
/// Fuse the producer of `consumerOpOperand` into the tile loop nest. Returns
|
||||
/// the fused producer or fails if fusion is not possible.
|
||||
FailureOr<LinalgOp> fuseProducer(OpBuilder &b, OpOperand *consumerOpOperand);
|
||||
|
||||
/// Returns the replacement results for the original untiled root operation.
|
||||
ValueRange getRootOpReplacementResults();
|
||||
|
|
|
@ -317,8 +317,11 @@ LogicalResult TileLoopNest::tileRootOp(OpBuilder &b,
|
|||
|
||||
FailureOr<LinalgOp> TileLoopNest::fuseProducer(OpBuilder &b,
|
||||
OpOperand *consumerOpOperand) {
|
||||
assert(tiledRootAndFusedOpsLoops.count(consumerOpOperand->getOwner()) != 0 &&
|
||||
"expect the operand owner is the root operation or a fused producer");
|
||||
// Check if the consumer has been tiled before. For example, it may not have
|
||||
// been tiled if the outermost tile loop is a reduction loop.
|
||||
if (tiledRootAndFusedOpsLoops.count(consumerOpOperand->getOwner()) == 0)
|
||||
return failure();
|
||||
|
||||
assert(this->isValid() &&
|
||||
"expect the tile loop nest to satisfy all invariants");
|
||||
|
||||
|
|
|
@ -232,6 +232,41 @@ builtin.func @fuse_indexed(%arg0: tensor<24x12xi32>,
|
|||
|
||||
// -----
|
||||
|
||||
#map0 = affine_map<(d0, d1) -> (d0, d1)>
|
||||
#map1 = affine_map<(d0, d1) -> (d0)>
|
||||
|
||||
// CHECK: fuse_outermost_reduction
|
||||
// CHECK-SAME: %[[ARG0:[0-9a-zA-Z]*]]: tensor<10x17xf32>
|
||||
// CHECK-SAME: %[[ARG1:[0-9a-zA-Z]*]]: tensor<10xf32>
|
||||
func @fuse_outermost_reduction(%arg0: tensor<10x17xf32>,
|
||||
%arg1: tensor<10xf32>) -> tensor<10xf32> {
|
||||
%cst = arith.constant 0.000000e+00 : f32
|
||||
%0 = linalg.fill(%cst, %arg0) : f32, tensor<10x17xf32> -> tensor<10x17xf32>
|
||||
|
||||
// Cannot fuse the output fill since the reduction loop is the outermost loop.
|
||||
// CHECK: %[[T0:.*]] = linalg.fill(%{{.*}}, %[[ARG1]])
|
||||
%1 = linalg.fill(%cst, %arg1) : f32, tensor<10xf32> -> tensor<10xf32>
|
||||
|
||||
// CHECK: scf.for %[[IV0:[0-9a-zA-Z]*]] = {{.*}} iter_args(%[[ARG2:.*]] = %[[T0]]
|
||||
// CHECK: scf.for %[[IV1:[0-9a-zA-Z]*]] = {{.*}} iter_args(%[[ARG3:.*]] = %[[ARG2]]
|
||||
|
||||
// Check the input fill has been fused.
|
||||
// CHECK: %[[T1:.*]] = tensor.extract_slice %[[ARG0]]
|
||||
// CHECK-SAME: %[[IV1]], %[[IV0]]
|
||||
// CHECK: %[[T2:.*]] = linalg.fill(%{{.*}}, %[[T1]])
|
||||
// CHECK: %[[T3:.*]] = tensor.extract_slice %[[ARG3]]
|
||||
// CHECK-SAME: %[[IV1]]
|
||||
// CHECK: linalg.generic {{.*}} ins(%[[T2]] {{.*}} outs(%[[T3]]
|
||||
%2 = linalg.generic {indexing_maps = [#map0, #map1], iterator_types = ["parallel", "reduction"]} ins(%0 : tensor<10x17xf32>) outs(%1 : tensor<10xf32>) {
|
||||
^bb0(%arg2: f32, %arg3: f32): // no predecessors
|
||||
%3 = arith.addf %arg2, %arg3 : f32
|
||||
linalg.yield %3 : f32
|
||||
} -> tensor<10xf32>
|
||||
return %2 : tensor<10xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0, d1) -> (d0 + d1)>
|
||||
// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1) -> (8, -d0 - d1 + 17)>
|
||||
// CHECK-DAG: #[[MAP2:.*]] = affine_map<(d0, d1, d2) -> (d0, -d1 - d2 + 17)>
|
||||
|
|
Loading…
Reference in New Issue