[mlir] Fix loop unrolling: properly replace the arguments of the epilogue loop.

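Using "replaceUsesOfWith" is incorrect because the same initializer value may appear multiple times among the loop's iter operands, and "replaceUsesOfWith" replaces every occurrence of that value at once.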

For example, if the epilogue is needed when this loop is unrolled
```
%x:2 = scf.for ... iter_args(%arg1 = %c1, %arg2 = %c1) {
  ...
}
```
then both of the epilogue's arguments will be incorrectly renamed to use the same result index (note #1 in both cases):
```
%x_unrolled:2 = scf.for ... iter_args(%arg1 = %c1, %arg2 = %c1) {
  ...
}
%x_epilogue:2 = scf.for ... iter_args(%arg1 = %x_unrolled#1, %arg2 = %x_unrolled#1) {
  ...
}
```
This commit is contained in:
grosul1 2022-05-12 01:44:13 +00:00 committed by Mogball
parent 24532d05f8
commit a4b227c28a
2 changed files with 41 additions and 3 deletions

View File

@ -474,12 +474,12 @@ LogicalResult mlir::loopUnrollByFactor(
// Update uses of loop results.
auto results = forOp.getResults();
auto epilogueResults = epilogueForOp.getResults();
auto epilogueIterOperands = epilogueForOp.getIterOperands();
for (auto e : llvm::zip(results, epilogueResults, epilogueIterOperands)) {
for (auto e : llvm::zip(results, epilogueResults)) {
std::get<0>(e).replaceAllUsesWith(std::get<1>(e));
epilogueForOp->replaceUsesOfWith(std::get<2>(e), std::get<0>(e));
}
epilogueForOp->setOperands(epilogueForOp.getNumControlOperands(),
epilogueForOp.getNumIterOperands(), results);
(void)promoteIfSingleIteration(epilogueForOp);
}

View File

@ -276,3 +276,41 @@ func.func @static_loop_unroll_up_to_factor(%arg0 : memref<?xf32>) {
// UNROLL-UP-TO-NEXT: affine.store %{{.*}}, %[[MEM]][%[[V1]]] : memref<?xf32>
// UNROLL-UP-TO-NEXT: return
// Test that the epilogue loop's iter_args are correctly renamed when both
// iter_args of the original loop are initialized from the same SSA value.
// Each epilogue iter_arg must be fed by the matching result of the unrolled
// loop (#0 and #1 respectively), not by the same result twice.
func.func @static_loop_unroll_by_3_rename_epilogue_arguments() -> (f32, f32) {
%0 = arith.constant 7.0 : f32
%lb = arith.constant 0 : index
%ub = arith.constant 20 : index
%step = arith.constant 1 : index
// The trip count (20) is not a multiple of 3, so unrolling by 3 needs an
// epilogue loop for the leftover iterations. Both iter_args deliberately
// share the initializer %0.
%result:2 = scf.for %i0 = %lb to %ub step %step iter_args(%arg0 = %0, %arg1 = %0) -> (f32, f32) {
// Two different ops on the same operands, so the two yielded values stay
// distinguishable in the expected output.
%add = arith.addf %arg0, %arg1 : f32
%mul = arith.mulf %arg0, %arg1 : f32
scf.yield %add, %mul : f32, f32
}
return %result#0, %result#1 : f32, f32
}
// UNROLL-BY-3-LABEL: func @static_loop_unroll_by_3_rename_epilogue_arguments
//
// UNROLL-BY-3-DAG: %[[CST:.*]] = arith.constant {{.*}} : f32
// UNROLL-BY-3-DAG: %[[C0:.*]] = arith.constant 0 : index
// UNROLL-BY-3-DAG: %[[C1:.*]] = arith.constant 1 : index
// UNROLL-BY-3-DAG: %[[C20:.*]] = arith.constant 20 : index
// UNROLL-BY-3-DAG: %[[C18:.*]] = arith.constant 18 : index
// UNROLL-BY-3-DAG: %[[C3:.*]] = arith.constant 3 : index
// UNROLL-BY-3: %[[FOR:.*]]:2 = scf.for %[[IV:.*]] = %[[C0]] to %[[C18]] step %[[C3]]
// UNROLL-BY-3-SAME: iter_args(%[[ARG0:.*]] = %[[CST]], %[[ARG1:.*]] = %[[CST]]) -> (f32, f32) {
// UNROLL-BY-3-NEXT: %[[ADD0:.*]] = arith.addf %[[ARG0]], %[[ARG1]] : f32
// UNROLL-BY-3-NEXT: %[[MUL0:.*]] = arith.mulf %[[ARG0]], %[[ARG1]] : f32
// UNROLL-BY-3-NEXT: %[[ADD1:.*]] = arith.addf %[[ADD0]], %[[MUL0]] : f32
// UNROLL-BY-3-NEXT: %[[MUL1:.*]] = arith.mulf %[[ADD0]], %[[MUL0]] : f32
// UNROLL-BY-3-NEXT: %[[ADD2:.*]] = arith.addf %[[ADD1]], %[[MUL1]] : f32
// UNROLL-BY-3-NEXT: %[[MUL2:.*]] = arith.mulf %[[ADD1]], %[[MUL1]] : f32
// UNROLL-BY-3-NEXT: scf.yield %[[ADD2]], %[[MUL2]] : f32, f32
// UNROLL-BY-3-NEXT: }
// UNROLL-BY-3: %[[EFOR:.*]]:2 = scf.for %[[EIV:.*]] = %[[C18]] to %[[C20]] step %[[C1]]
// UNROLL-BY-3-SAME: iter_args(%[[EARG0:.*]] = %[[FOR]]#0, %[[EARG1:.*]] = %[[FOR]]#1) -> (f32, f32) {
// UNROLL-BY-3-NEXT: %[[EADD:.*]] = arith.addf %[[EARG0]], %[[EARG1]] : f32
// UNROLL-BY-3-NEXT: %[[EMUL:.*]] = arith.mulf %[[EARG0]], %[[EARG1]] : f32
// UNROLL-BY-3-NEXT: scf.yield %[[EADD]], %[[EMUL]] : f32, f32
// UNROLL-BY-3-NEXT: }
// UNROLL-BY-3-NEXT: return %[[EFOR]]#0, %[[EFOR]]#1 : f32, f32