forked from OSchip/llvm-project
[mlir] Fix loop unrolling: properly replace the arguments of the epilogue loop.
Using "replaceUsesOfWith" is incorrect because the same initializer value may appear multiple times. For example, if the epilogue is needed when this loop is unrolled ``` %x:2 = scf.for ... iter_args(%arg1 = %c1, %arg2 = %c1) { ... } ``` then both epilogue's arguments will be incorrectly renamed to use the same result index (note #1 in both cases): ``` %x_unrolled:2 = scf.for ... iter_args(%arg1 = %c1, %arg2 = %c1) { ... } %x_epilogue:2 = scf.for ... iter_args(%arg1 = %x_unrolled#1, %arg2 = %x_unrolled#1) { ... } ```
This commit is contained in:
parent
24532d05f8
commit
a4b227c28a
|
@ -474,12 +474,12 @@ LogicalResult mlir::loopUnrollByFactor(
|
|||
// Update uses of loop results.
|
||||
auto results = forOp.getResults();
|
||||
auto epilogueResults = epilogueForOp.getResults();
|
||||
auto epilogueIterOperands = epilogueForOp.getIterOperands();
|
||||
|
||||
for (auto e : llvm::zip(results, epilogueResults, epilogueIterOperands)) {
|
||||
for (auto e : llvm::zip(results, epilogueResults)) {
|
||||
std::get<0>(e).replaceAllUsesWith(std::get<1>(e));
|
||||
epilogueForOp->replaceUsesOfWith(std::get<2>(e), std::get<0>(e));
|
||||
}
|
||||
epilogueForOp->setOperands(epilogueForOp.getNumControlOperands(),
|
||||
epilogueForOp.getNumIterOperands(), results);
|
||||
(void)promoteIfSingleIteration(epilogueForOp);
|
||||
}
|
||||
|
||||
|
|
|
@ -276,3 +276,41 @@ func.func @static_loop_unroll_up_to_factor(%arg0 : memref<?xf32>) {
|
|||
// UNROLL-UP-TO-NEXT: affine.store %{{.*}}, %[[MEM]][%[[V1]]] : memref<?xf32>
|
||||
// UNROLL-UP-TO-NEXT: return
|
||||
|
||||
// Test that epilogue's arguments are correctly renamed.
|
||||
func.func @static_loop_unroll_by_3_rename_epilogue_arguments() -> (f32, f32) {
|
||||
%0 = arith.constant 7.0 : f32
|
||||
%lb = arith.constant 0 : index
|
||||
%ub = arith.constant 20 : index
|
||||
%step = arith.constant 1 : index
|
||||
%result:2 = scf.for %i0 = %lb to %ub step %step iter_args(%arg0 = %0, %arg1 = %0) -> (f32, f32) {
|
||||
%add = arith.addf %arg0, %arg1 : f32
|
||||
%mul = arith.mulf %arg0, %arg1 : f32
|
||||
scf.yield %add, %mul : f32, f32
|
||||
}
|
||||
return %result#0, %result#1 : f32, f32
|
||||
}
|
||||
// UNROLL-BY-3-LABEL: func @static_loop_unroll_by_3_rename_epilogue_arguments
|
||||
//
|
||||
// UNROLL-BY-3-DAG: %[[CST:.*]] = arith.constant {{.*}} : f32
|
||||
// UNROLL-BY-3-DAG: %[[C0:.*]] = arith.constant 0 : index
|
||||
// UNROLL-BY-3-DAG: %[[C1:.*]] = arith.constant 1 : index
|
||||
// UNROLL-BY-3-DAG: %[[C20:.*]] = arith.constant 20 : index
|
||||
// UNROLL-BY-3-DAG: %[[C18:.*]] = arith.constant 18 : index
|
||||
// UNROLL-BY-3-DAG: %[[C3:.*]] = arith.constant 3 : index
|
||||
// UNROLL-BY-3: %[[FOR:.*]]:2 = scf.for %[[IV:.*]] = %[[C0]] to %[[C18]] step %[[C3]]
|
||||
// UNROLL-BY-3-SAME: iter_args(%[[ARG0:.*]] = %[[CST]], %[[ARG1:.*]] = %[[CST]]) -> (f32, f32) {
|
||||
// UNROLL-BY-3-NEXT: %[[ADD0:.*]] = arith.addf %[[ARG0]], %[[ARG1]] : f32
|
||||
// UNROLL-BY-3-NEXT: %[[MUL0:.*]] = arith.mulf %[[ARG0]], %[[ARG1]] : f32
|
||||
// UNROLL-BY-3-NEXT: %[[ADD1:.*]] = arith.addf %[[ADD0]], %[[MUL0]] : f32
|
||||
// UNROLL-BY-3-NEXT: %[[MUL1:.*]] = arith.mulf %[[ADD0]], %[[MUL0]] : f32
|
||||
// UNROLL-BY-3-NEXT: %[[ADD2:.*]] = arith.addf %[[ADD1]], %[[MUL1]] : f32
|
||||
// UNROLL-BY-3-NEXT: %[[MUL2:.*]] = arith.mulf %[[ADD1]], %[[MUL1]] : f32
|
||||
// UNROLL-BY-3-NEXT: scf.yield %[[ADD2]], %[[MUL2]] : f32, f32
|
||||
// UNROLL-BY-3-NEXT: }
|
||||
// UNROLL-BY-3: %[[EFOR:.*]]:2 = scf.for %[[EIV:.*]] = %[[C18]] to %[[C20]] step %[[C1]]
|
||||
// UNROLL-BY-3-SAME: iter_args(%[[EARG0:.*]] = %[[FOR]]#0, %[[EARG1:.*]] = %[[FOR]]#1) -> (f32, f32) {
|
||||
// UNROLL-BY-3-NEXT: %[[EADD:.*]] = arith.addf %[[EARG0]], %[[EARG1]] : f32
|
||||
// UNROLL-BY-3-NEXT: %[[EMUL:.*]] = arith.mulf %[[EARG0]], %[[EARG1]] : f32
|
||||
// UNROLL-BY-3-NEXT: scf.yield %[[EADD]], %[[EMUL]] : f32, f32
|
||||
// UNROLL-BY-3-NEXT: }
|
||||
// UNROLL-BY-3-NEXT: return %[[EFOR]]#0, %[[EFOR]]#1 : f32, f32
|
||||
|
|
Loading…
Reference in New Issue