diff --git a/mlir/lib/Dialect/SCF/Utils/Utils.cpp b/mlir/lib/Dialect/SCF/Utils/Utils.cpp index 206e61c6ac92..d4c96e51d549 100644 --- a/mlir/lib/Dialect/SCF/Utils/Utils.cpp +++ b/mlir/lib/Dialect/SCF/Utils/Utils.cpp @@ -474,12 +474,12 @@ LogicalResult mlir::loopUnrollByFactor( // Update uses of loop results. auto results = forOp.getResults(); auto epilogueResults = epilogueForOp.getResults(); - auto epilogueIterOperands = epilogueForOp.getIterOperands(); - for (auto e : llvm::zip(results, epilogueResults, epilogueIterOperands)) { + for (auto e : llvm::zip(results, epilogueResults)) { std::get<0>(e).replaceAllUsesWith(std::get<1>(e)); - epilogueForOp->replaceUsesOfWith(std::get<2>(e), std::get<0>(e)); } + epilogueForOp->setOperands(epilogueForOp.getNumControlOperands(), + epilogueForOp.getNumIterOperands(), results); (void)promoteIfSingleIteration(epilogueForOp); } diff --git a/mlir/test/Dialect/SCF/loop-unroll.mlir b/mlir/test/Dialect/SCF/loop-unroll.mlir index 6a832578d581..dc2c07f291e7 100644 --- a/mlir/test/Dialect/SCF/loop-unroll.mlir +++ b/mlir/test/Dialect/SCF/loop-unroll.mlir @@ -276,3 +276,41 @@ func.func @static_loop_unroll_up_to_factor(%arg0 : memref) { // UNROLL-UP-TO-NEXT: affine.store %{{.*}}, %[[MEM]][%[[V1]]] : memref // UNROLL-UP-TO-NEXT: return +// Test that epilogue's arguments are correctly renamed. +func.func @static_loop_unroll_by_3_rename_epilogue_arguments() -> (f32, f32) { + %0 = arith.constant 7.0 : f32 + %lb = arith.constant 0 : index + %ub = arith.constant 20 : index + %step = arith.constant 1 : index + %result:2 = scf.for %i0 = %lb to %ub step %step iter_args(%arg0 = %0, %arg1 = %0) -> (f32, f32) { + %add = arith.addf %arg0, %arg1 : f32 + %mul = arith.mulf %arg0, %arg1 : f32 + scf.yield %add, %mul : f32, f32 + } + return %result#0, %result#1 : f32, f32 +} +// UNROLL-BY-3-LABEL: func @static_loop_unroll_by_3_rename_epilogue_arguments +// +// UNROLL-BY-3-DAG: %[[CST:.*]] = arith.constant {{.*}} : f32 +// UNROLL-BY-3-DAG: %[[C0:.*]] = arith.constant 0 : index +// UNROLL-BY-3-DAG: %[[C1:.*]] = arith.constant 1 : index +// UNROLL-BY-3-DAG: %[[C20:.*]] = arith.constant 20 : index +// UNROLL-BY-3-DAG: %[[C18:.*]] = arith.constant 18 : index +// UNROLL-BY-3-DAG: %[[C3:.*]] = arith.constant 3 : index +// UNROLL-BY-3: %[[FOR:.*]]:2 = scf.for %[[IV:.*]] = %[[C0]] to %[[C18]] step %[[C3]] +// UNROLL-BY-3-SAME: iter_args(%[[ARG0:.*]] = %[[CST]], %[[ARG1:.*]] = %[[CST]]) -> (f32, f32) { +// UNROLL-BY-3-NEXT: %[[ADD0:.*]] = arith.addf %[[ARG0]], %[[ARG1]] : f32 +// UNROLL-BY-3-NEXT: %[[MUL0:.*]] = arith.mulf %[[ARG0]], %[[ARG1]] : f32 +// UNROLL-BY-3-NEXT: %[[ADD1:.*]] = arith.addf %[[ADD0]], %[[MUL0]] : f32 +// UNROLL-BY-3-NEXT: %[[MUL1:.*]] = arith.mulf %[[ADD0]], %[[MUL0]] : f32 +// UNROLL-BY-3-NEXT: %[[ADD2:.*]] = arith.addf %[[ADD1]], %[[MUL1]] : f32 +// UNROLL-BY-3-NEXT: %[[MUL2:.*]] = arith.mulf %[[ADD1]], %[[MUL1]] : f32 +// UNROLL-BY-3-NEXT: scf.yield %[[ADD2]], %[[MUL2]] : f32, f32 +// UNROLL-BY-3-NEXT: } +// UNROLL-BY-3: %[[EFOR:.*]]:2 = scf.for %[[EIV:.*]] = %[[C18]] to %[[C20]] step %[[C1]] +// UNROLL-BY-3-SAME: iter_args(%[[EARG0:.*]] = %[[FOR]]#0, %[[EARG1:.*]] = %[[FOR]]#1) -> (f32, f32) { +// UNROLL-BY-3-NEXT: %[[EADD:.*]] = arith.addf %[[EARG0]], %[[EARG1]] : f32 +// UNROLL-BY-3-NEXT: %[[EMUL:.*]] = arith.mulf %[[EARG0]], %[[EARG1]] : f32 +// UNROLL-BY-3-NEXT: scf.yield %[[EADD]], %[[EMUL]] : f32, f32 +// UNROLL-BY-3-NEXT: } +// UNROLL-BY-3-NEXT: return %[[EFOR]]#0, %[[EFOR]]#1 : f32, f32