[mlir] Canonicalize single-iteration ParallelOp

Differential Revision: https://reviews.llvm.org/D100248
This commit is contained in:
Butygin 2021-04-10 19:38:11 +03:00
parent eae2d4b852
commit eb31540066
2 changed files with 97 additions and 7 deletions

View File

@ -1239,10 +1239,36 @@ struct CollapseSingleIterationLoops : public OpRewritePattern<ParallelOp> {
newSteps.push_back(step);
}
}
// Exit if all or none of the loop dimensions perform a single iteration.
if (newLowerBounds.size() == 0 ||
newLowerBounds.size() == op.lowerBound().size())
// Exit if none of the loop dimensions perform a single iteration.
if (newLowerBounds.size() == op.lowerBound().size())
return failure();
if (newLowerBounds.empty()) {
// All of the loop dimensions perform a single iteration. Inline
// loop body and nested ReduceOp's
SmallVector<Value> results;
results.reserve(op.initVals().size());
for (auto &bodyOp : op.getLoopBody().front().without_terminator()) {
auto reduce = dyn_cast<ReduceOp>(bodyOp);
if (!reduce) {
rewriter.clone(bodyOp, mapping);
continue;
}
Block &reduceBlock = reduce.reductionOperator().front();
auto initValIndex = results.size();
mapping.map(reduceBlock.getArgument(0), op.initVals()[initValIndex]);
mapping.map(reduceBlock.getArgument(1),
mapping.lookupOrDefault(reduce.operand()));
for (auto &reduceBodyOp : reduceBlock.without_terminator())
rewriter.clone(reduceBodyOp, mapping);
auto result = mapping.lookupOrDefault(
cast<ReduceReturnOp>(reduceBlock.getTerminator()).result());
results.push_back(result);
}
rewriter.replaceOp(op, results);
return success();
}
// Replace the parallel loop by lower-dimensional parallel loop.
auto newOp =
rewriter.create<ParallelOp>(op.getLoc(), newLowerBounds, newUpperBounds,

View File

@ -3,7 +3,7 @@
// -----
func @single_iteration(%A: memref<?x?x?xi32>) {
func @single_iteration_some(%A: memref<?x?x?xi32>) {
%c0 = constant 0 : index
%c1 = constant 1 : index
%c2 = constant 2 : index
@ -19,7 +19,7 @@ func @single_iteration(%A: memref<?x?x?xi32>) {
return
}
// CHECK-LABEL: func @single_iteration(
// CHECK-LABEL: func @single_iteration_some(
// CHECK-SAME: [[ARG0:%.*]]: memref<?x?x?xi32>) {
// CHECK-DAG: [[C42:%.*]] = constant 42 : i32
// CHECK-DAG: [[C7:%.*]] = constant 7 : index
@ -35,6 +35,70 @@ func @single_iteration(%A: memref<?x?x?xi32>) {
// -----
func @single_iteration_all(%A: memref<?x?x?xi32>) {
%c0 = constant 0 : index
%c1 = constant 1 : index
%c3 = constant 3 : index
%c6 = constant 6 : index
%c7 = constant 7 : index
%c10 = constant 10 : index
scf.parallel (%i0, %i1, %i2) = (%c0, %c3, %c7) to (%c1, %c6, %c10) step (%c1, %c3, %c3) {
%c42 = constant 42 : i32
memref.store %c42, %A[%i0, %i1, %i2] : memref<?x?x?xi32>
scf.yield
}
return
}
// CHECK-LABEL: func @single_iteration_all(
// CHECK-SAME: [[ARG0:%.*]]: memref<?x?x?xi32>) {
// CHECK-DAG: [[C42:%.*]] = constant 42 : i32
// CHECK-DAG: [[C7:%.*]] = constant 7 : index
// CHECK-DAG: [[C3:%.*]] = constant 3 : index
// CHECK-DAG: [[C0:%.*]] = constant 0 : index
// CHECK-NOT: scf.parallel
// CHECK: memref.store [[C42]], [[ARG0]]{{\[}}[[C0]], [[C3]], [[C7]]] : memref<?x?x?xi32>
// CHECK-NOT: scf.yield
// CHECK: return
// -----
func @single_iteration_reduce(%A: index, %B: index) -> (index, index) {
%c0 = constant 0 : index
%c1 = constant 1 : index
%c2 = constant 2 : index
%c3 = constant 3 : index
%c6 = constant 6 : index
%0:2 = scf.parallel (%i0, %i1) = (%c1, %c3) to (%c2, %c6) step (%c1, %c3) init(%A, %B) -> (index, index) {
scf.reduce(%i0) : index {
^bb0(%lhs: index, %rhs: index):
%1 = addi %lhs, %rhs : index
scf.reduce.return %1 : index
}
scf.reduce(%i1) : index {
^bb0(%lhs: index, %rhs: index):
%2 = muli %lhs, %rhs : index
scf.reduce.return %2 : index
}
scf.yield
}
return %0#0, %0#1 : index, index
}
// CHECK-LABEL: func @single_iteration_reduce(
// CHECK-SAME: [[ARG0:%.*]]: index, [[ARG1:%.*]]: index)
// CHECK-DAG: [[C3:%.*]] = constant 3 : index
// CHECK-DAG: [[C1:%.*]] = constant 1 : index
// CHECK-NOT: scf.parallel
// CHECK-NOT: scf.reduce
// CHECK-NOT: scf.reduce.return
// CHECK-NOT: scf.yield
// CHECK: [[V0:%.*]] = addi [[ARG0]], [[C1]]
// CHECK: [[V1:%.*]] = muli [[ARG1]], [[C3]]
// CHECK: return [[V0]], [[V1]]
// -----
func private @side_effect()
func @one_unused(%cond: i1) -> (index) {
%c0 = constant 0 : index
@ -488,7 +552,7 @@ func @fold_away_iter_with_no_use_and_yielded_input(%arg0 : i32,
%ub : index, %lb : index, %step : index) -> (i32, i32) {
// CHECK-NEXT: %[[C32:.*]] = constant 32 : i32
%cst = constant 32 : i32
// CHECK-NEXT: %[[FOR_RES:.*]] = scf.for {{.*}} iter_args({{.*}} = %[[A0]]) -> (i32) {
// CHECK-NEXT: %[[FOR_RES:.*]] = scf.for {{.*}} iter_args({{.*}} = %[[A0]]) -> (i32) {
%0:2 = scf.for %arg1 = %lb to %ub step %step iter_args(%arg2 = %arg0, %arg3 = %cst)
-> (i32, i32) {
%1 = addi %arg2, %cst : i32
@ -512,7 +576,7 @@ func @fold_away_iter_and_result_with_no_use(%arg0 : i32,
%1 = addi %arg2, %cst : i32
scf.yield %1, %1 : i32, i32
}
// CHECK: return %[[FOR_RES]] : i32
return %0#0 : i32
}