diff --git a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp index 5af62dafe6d9..d1646e92b8d4 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp @@ -145,11 +145,12 @@ generateFusedElementwiseOpRegion(PatternRewriter &rewriter, Operation *fusedOp, fusedBlock->getArguments().take_front(numFusedOpIndices)); mapper.map(std::get<0>(it), newIndex); } - // 2b. Replace the producer index operations by index operations placed in the - // fused block using the `consumerToProducerLoopsMap` to map the index spaces. - unsigned numFusedOpLoops = - std::max(producer.getNumLoops(), consumer.getNumLoops()); + // 2b. Add an index operation for every fused loop dimension and use the + // `consumerToProducerLoopsMap` to map the producer indices. if (producer.hasIndexSemantics()) { + // Add an index operation for every fused loop dimension. + unsigned numFusedOpLoops = + std::max(producer.getNumLoops(), consumer.getNumLoops()); SmallVector fusedIndices; fusedIndices.reserve(numFusedOpLoops); llvm::transform(llvm::seq(0, numFusedOpLoops), @@ -161,10 +162,7 @@ generateFusedElementwiseOpRegion(PatternRewriter &rewriter, Operation *fusedOp, Value newIndex = rewriter.create( producer.getLoc(), consumerToProducerLoopsMap.getSubMap(indexOp.dim()), fusedIndices); - // Replace the producer index operation by the index value computed in the - // fused block. All remaining operations in the producer block are later - // on cloned to the fused block. - rewriter.replaceOp(indexOp, newIndex); + mapper.map(indexOp.getResult(), newIndex); } } // TODO: allow fusing the producer of an output operand. @@ -210,10 +208,12 @@ generateFusedElementwiseOpRegion(PatternRewriter &rewriter, Operation *fusedOp, // TODO: allow fusion of multi-result producers. assert(producer->getNumResults() == 1 && "expected single result producer"); - // 8. Clone operations from producer (except the yield operation) to the fused - // op. - for (auto &op : producerBlock.without_terminator()) - rewriter.clone(op, mapper); + // 8. Clone all producer operations except for the yield and index operations + // to the fused operation. + for (auto &op : producerBlock.without_terminator()) { + if (!isa(op)) + rewriter.clone(op, mapper); + } // 9. Now we can map the consumerBlock's `consumerIdx` block argument. Just // forward the yield operand. auto yieldOp = cast(producerBlock.getTerminator()); diff --git a/mlir/test/Dialect/Linalg/fusion-tensor.mlir b/mlir/test/Dialect/Linalg/fusion-tensor.mlir index 40c52657a853..1ba2d37fff3e 100644 --- a/mlir/test/Dialect/Linalg/fusion-tensor.mlir +++ b/mlir/test/Dialect/Linalg/fusion-tensor.mlir @@ -462,8 +462,7 @@ func @indexed_generic_op_generic_op_fusion(%arg0: tensor, // ----- #map0 = affine_map<(d0, d1) -> (d0, d1)> -func @indexed_producer_consumer_fusion(%arg0: tensor, - %arg1: tensor) -> tensor { +func @indexed_producer_consumer_fusion(%arg0: tensor) -> tensor { %c0 = constant 0 : index %c1 = constant 1 : index %0 = memref.dim %arg0, %c0 : tensor @@ -486,7 +485,7 @@ func @indexed_producer_consumer_fusion(%arg0: tensor, %4 = linalg.generic { indexing_maps = [#map0, #map0, #map0], iterator_types = ["parallel", "parallel"] } - ins(%3, %arg1 : tensor, tensor) + ins(%3, %arg0 : tensor, tensor) outs(%2 : tensor) { ^bb0(%arg2: i32, %arg3: i32, %arg4: i32): // no predecessors %10 = addi %arg2, %arg3 : i32 @@ -497,7 +496,7 @@ func @indexed_producer_consumer_fusion(%arg0: tensor, // CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1) -> (d0, d1)> // CHECK-LABEL: func @indexed_producer_consumer_fusion // CHECK: linalg.generic -// CHECK-SAME: indexing_maps = [#[[$MAP0]], #[[$MAP0]], #[[$MAP0]]] +// CHECK-SAME: indexing_maps = [#[[$MAP0]], #[[$MAP0]]] // CHECK: ^{{[a-zA-Z0-9_]*}} // CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]]: i32 // CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]]: i32 @@ -507,7 +506,7 @@ func @indexed_producer_consumer_fusion(%arg0: tensor, // CHECK: %[[SUB_OPERAND:.+]] = index_cast %[[IDX1]] : index to i32 // CHECK: %[[VAL1:.+]] = addi %[[ARG0]], %[[ADD_OPERAND]] : i32 // CHECK: %[[VAL2:.+]] = subi %[[VAL1]], %[[SUB_OPERAND]] : i32 -// CHECK: %[[VAL3:.+]] = addi %[[VAL2]], %[[ARG1]] : i32 +// CHECK: %[[VAL3:.+]] = addi %[[VAL2]], %[[ARG0]] : i32 // CHECK: linalg.yield %[[VAL3]] : i32 // CHECK-NOT: linalg.generic