diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp index 3bad54327e07..c62f27d8a22f 100644 --- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp +++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp @@ -355,6 +355,23 @@ static Optional getFusableProducer(Value v) { return v.cast(); } +// Replace iter args of the outer most loop with region args of the inner most +// one. +static void replaceIterArgs(scf::ForOp outerFor, scf::ForOp innerFor, + PatternRewriter &rewriter) { + assert(outerFor.getNumIterOperands() == innerFor.getNumIterOperands() && + "expect same number of iter args"); + Block *block = &(*innerFor.getRegion().begin()); + for (auto it : + llvm::zip(outerFor.getIterOperands(), innerFor.getRegionIterArgs())) { + Value source = std::get<0>(it); + Value target = std::get<1>(it); + source.replaceUsesWithIf(target, [&](OpOperand &use) { + return use.getOwner()->getBlock() == block; + }); + } +} + FailureOr scf::TileConsumerAndFuseProducersUsingSCFForOp::returningMatchAndRewrite( TilingInterface op, PatternRewriter &rewriter) const { @@ -470,5 +487,7 @@ scf::TileConsumerAndFuseProducersUsingSCFForOp::returningMatchAndRewrite( } } } + replaceIterArgs(tileAndFuseResult.loops.front(), + tileAndFuseResult.loops.back(), rewriter); return tileAndFuseResult; } diff --git a/mlir/test/Interfaces/TilingInterface/tile-and-fuse-using-interface.mlir b/mlir/test/Interfaces/TilingInterface/tile-and-fuse-using-interface.mlir index d1ca2d2c4625..888726995a46 100644 --- a/mlir/test/Interfaces/TilingInterface/tile-and-fuse-using-interface.mlir +++ b/mlir/test/Interfaces/TilingInterface/tile-and-fuse-using-interface.mlir @@ -23,7 +23,7 @@ func.func @gemm_fill_fusion(%arg0 : tensor, %arg1 : tensor) -> // CHECK-SAME: iter_args(%[[ITERARG1:.+]] = %[[ITERARG0]]) // CHECK-DAG: %[[LHS_TILE:.+]] = tensor.extract_slice %[[ARG0]][%[[IV0]], 0] // CHECK-DAG: %[[RHS_TILE:.+]] = tensor.extract_slice %[[ARG1]][0, %[[IV1]]] -// CHECK-DAG: %[[INIT_TILE:.+]] = tensor.extract_slice %[[INIT]][%[[IV0]], %[[IV1]]] +// CHECK-DAG: %[[INIT_TILE:.+]] = tensor.extract_slice %[[ITERARG1]][%[[IV0]], %[[IV1]]] // CHECK: %[[FILL_TILE:.+]] = linalg.fill // CHECK-SAME: outs(%[[INIT_TILE]] : // CHECK: %[[GEMM_TILE:.+]] = linalg.matmul @@ -68,7 +68,7 @@ func.func @gemm_generic_fusion(%arg0 : tensor, %arg1 : tensor, // CHECK-SAME: iter_args(%[[ITERARG1:.+]] = %[[ITERARG0]]) // CHECK-DAG: %[[LHS_TILE:.+]] = tensor.extract_slice %[[ARG0]][%[[IV0]], 0] // CHECK-DAG: %[[RHS_TILE:.+]] = tensor.extract_slice %[[ARG1]][0, %[[IV1]]] -// CHECK-DAG: %[[INIT_TILE:.+]] = tensor.extract_slice %[[INIT]][%[[IV0]], %[[IV1]]] +// CHECK-DAG: %[[INIT_TILE:.+]] = tensor.extract_slice %[[ITERARG1]][%[[IV0]], %[[IV1]]] // CHECK: %[[FILL_TILE:.+]] = linalg.fill // CHECK-SAME: outs(%[[INIT_TILE]] : // CHECK: %[[GEMM_TILE:.+]] = linalg.matmul @@ -123,7 +123,7 @@ func.func @gemm_gemm_fusion(%lhs0 : tensor, %rhs0 : tensor, %r // CHECK-SAME: ins(%[[LHS0_TILE]], %[[RHS0_TILE]] : // CHECK-SAME: outs(%[[FILL0_TILE]] : // CHECK-DAG: %[[RHS1_TILE:.+]] = tensor.extract_slice %[[RHS1]][0, 0] -// CHECK-DAG: %[[INIT1_TILE:.+]] = tensor.extract_slice %[[INIT1]][%[[IV]], 0] +// CHECK-DAG: %[[INIT1_TILE:.+]] = tensor.extract_slice %[[ITERARG]][%[[IV]], 0] // CHECK: %[[FILL1_TILE:.+]] = linalg.fill // CHECK-SAME: outs(%[[INIT1_TILE]] : // CHECK: %[[GEMM1_TILE:.+]] = linalg.matmul