forked from OSchip/llvm-project
[mlir][linalg] Fix bug in the fusion on tensors index op handling.
The old index op handling let the new index operations point back to the producer block. As a result, after fusion some index operations in the fused block had back references to the old producer block resulting in illegal IR. The patch now relies on a block and value mapping to avoid such back references. Differential Revision: https://reviews.llvm.org/D101887
This commit is contained in:
parent
1f5cacfcb8
commit
4a6ee23d83
|
@ -145,11 +145,12 @@ generateFusedElementwiseOpRegion(PatternRewriter &rewriter, Operation *fusedOp,
|
|||
fusedBlock->getArguments().take_front(numFusedOpIndices));
|
||||
mapper.map(std::get<0>(it), newIndex);
|
||||
}
|
||||
// 2b. Replace the producer index operations by index operations placed in the
|
||||
// fused block using the `consumerToProducerLoopsMap` to map the index spaces.
|
||||
unsigned numFusedOpLoops =
|
||||
std::max(producer.getNumLoops(), consumer.getNumLoops());
|
||||
// 2b. Add an index operation for every fused loop dimension and use the
|
||||
// `consumerToProducerLoopsMap` to map the producer indices.
|
||||
if (producer.hasIndexSemantics()) {
|
||||
// Add an index operation for every fused loop dimension.
|
||||
unsigned numFusedOpLoops =
|
||||
std::max(producer.getNumLoops(), consumer.getNumLoops());
|
||||
SmallVector<Value> fusedIndices;
|
||||
fusedIndices.reserve(numFusedOpLoops);
|
||||
llvm::transform(llvm::seq<int64_t>(0, numFusedOpLoops),
|
||||
|
@ -161,10 +162,7 @@ generateFusedElementwiseOpRegion(PatternRewriter &rewriter, Operation *fusedOp,
|
|||
Value newIndex = rewriter.create<mlir::AffineApplyOp>(
|
||||
producer.getLoc(),
|
||||
consumerToProducerLoopsMap.getSubMap(indexOp.dim()), fusedIndices);
|
||||
// Replace the producer index operation by the index value computed in the
|
||||
// fused block. All remaining operations in the producer block are later
|
||||
// on cloned to the fused block.
|
||||
rewriter.replaceOp(indexOp, newIndex);
|
||||
mapper.map(indexOp.getResult(), newIndex);
|
||||
}
|
||||
}
|
||||
// TODO: allow fusing the producer of an output operand.
|
||||
|
@ -210,10 +208,12 @@ generateFusedElementwiseOpRegion(PatternRewriter &rewriter, Operation *fusedOp,
|
|||
// TODO: allow fusion of multi-result producers.
|
||||
assert(producer->getNumResults() == 1 && "expected single result producer");
|
||||
|
||||
// 8. Clone operations from producer (except the yield operation) to the fused
|
||||
// op.
|
||||
for (auto &op : producerBlock.without_terminator())
|
||||
rewriter.clone(op, mapper);
|
||||
// 8. Clone all producer operations except for the yield and index operations
|
||||
// to the fused operation.
|
||||
for (auto &op : producerBlock.without_terminator()) {
|
||||
if (!isa<IndexOp>(op))
|
||||
rewriter.clone(op, mapper);
|
||||
}
|
||||
// 9. Now we can map the consumerBlock's `consumerIdx` block argument. Just
|
||||
// forward the yield operand.
|
||||
auto yieldOp = cast<linalg::YieldOp>(producerBlock.getTerminator());
|
||||
|
|
|
@ -462,8 +462,7 @@ func @indexed_generic_op_generic_op_fusion(%arg0: tensor<?x?xi32>,
|
|||
// -----
|
||||
|
||||
#map0 = affine_map<(d0, d1) -> (d0, d1)>
|
||||
func @indexed_producer_consumer_fusion(%arg0: tensor<?x?xi32>,
|
||||
%arg1: tensor<?x?xi32>) -> tensor<?x?xi32> {
|
||||
func @indexed_producer_consumer_fusion(%arg0: tensor<?x?xi32>) -> tensor<?x?xi32> {
|
||||
%c0 = constant 0 : index
|
||||
%c1 = constant 1 : index
|
||||
%0 = memref.dim %arg0, %c0 : tensor<?x?xi32>
|
||||
|
@ -486,7 +485,7 @@ func @indexed_producer_consumer_fusion(%arg0: tensor<?x?xi32>,
|
|||
%4 = linalg.generic {
|
||||
indexing_maps = [#map0, #map0, #map0],
|
||||
iterator_types = ["parallel", "parallel"] }
|
||||
ins(%3, %arg1 : tensor<?x?xi32>, tensor<?x?xi32>)
|
||||
ins(%3, %arg0 : tensor<?x?xi32>, tensor<?x?xi32>)
|
||||
outs(%2 : tensor<?x?xi32>) {
|
||||
^bb0(%arg2: i32, %arg3: i32, %arg4: i32): // no predecessors
|
||||
%10 = addi %arg2, %arg3 : i32
|
||||
|
@ -497,7 +496,7 @@ func @indexed_producer_consumer_fusion(%arg0: tensor<?x?xi32>,
|
|||
// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1) -> (d0, d1)>
|
||||
// CHECK-LABEL: func @indexed_producer_consumer_fusion
|
||||
// CHECK: linalg.generic
|
||||
// CHECK-SAME: indexing_maps = [#[[$MAP0]], #[[$MAP0]], #[[$MAP0]]]
|
||||
// CHECK-SAME: indexing_maps = [#[[$MAP0]], #[[$MAP0]]]
|
||||
// CHECK: ^{{[a-zA-Z0-9_]*}}
|
||||
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]]: i32
|
||||
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]]: i32
|
||||
|
@ -507,7 +506,7 @@ func @indexed_producer_consumer_fusion(%arg0: tensor<?x?xi32>,
|
|||
// CHECK: %[[SUB_OPERAND:.+]] = index_cast %[[IDX1]] : index to i32
|
||||
// CHECK: %[[VAL1:.+]] = addi %[[ARG0]], %[[ADD_OPERAND]] : i32
|
||||
// CHECK: %[[VAL2:.+]] = subi %[[VAL1]], %[[SUB_OPERAND]] : i32
|
||||
// CHECK: %[[VAL3:.+]] = addi %[[VAL2]], %[[ARG1]] : i32
|
||||
// CHECK: %[[VAL3:.+]] = addi %[[VAL2]], %[[ARG0]] : i32
|
||||
// CHECK: linalg.yield %[[VAL3]] : i32
|
||||
// CHECK-NOT: linalg.generic
|
||||
|
||||
|
|
Loading…
Reference in New Issue