[RFC][MLIR][SCF] Enable better bufferization for `TileConsumerAndFuseProducersUsingSCFForOp`
Replace the iterators (iter args) of the outermost loop with the region arguments of the innermost one. This change keeps later `bufferization` passes from inserting allocations within the body of the innermost loop.

Reviewed By: mravishankar

Differential Revision: https://reviews.llvm.org/D130083
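To make the motivation concrete, here is a minimal sketch of the tiled IR before and after the rewrite (loop bounds, tile sizes, and value names are illustrative, modeled on the `gemm_fill_fusion` test updated below). Before the change, the destination tile is sliced from the init tensor defined above the loop nest, so a later bufferization pass cannot tie that read to the value the loops carry in `iter_args` and ends up allocating inside the innermost body; slicing from the innermost region iter arg instead makes the in-place update explicit.

```mlir
// Before: %slice reads from %init, which is defined above the loop nest,
// even though the nest already threads %init through iter_args.
%res = scf.for %iv0 = %c0 to %d0 step %c10
    iter_args(%arg0 = %init) -> (tensor<?x?xf32>) {
  %inner = scf.for %iv1 = %c0 to %d1 step %c20
      iter_args(%arg1 = %arg0) -> (tensor<?x?xf32>) {
    %slice = tensor.extract_slice %init[%iv0, %iv1] [10, 20] [1, 1]
        : tensor<?x?xf32> to tensor<10x20xf32>
    // ... fused fill + matmul on %slice, tensor.insert_slice back into
    // %arg1, and scf.yield of the updated tensor ...
  }
  scf.yield %inner : tensor<?x?xf32>
}

// After replaceIterArgs: the same use reads from %arg1, the innermost
// region iter arg, i.e. the tensor the loop nest actually carries.
    %slice = tensor.extract_slice %arg1[%iv0, %iv1] [10, 20] [1, 1]
        : tensor<?x?xf32> to tensor<10x20xf32>
```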
commit 9e65850305 (parent cc72af4e13)
```diff
@@ -355,6 +355,23 @@ static Optional<OpResult> getFusableProducer(Value v) {
   return v.cast<OpResult>();
 }
 
+// Replace iter args of the outer most loop with region args of the inner most
+// one.
+static void replaceIterArgs(scf::ForOp outerFor, scf::ForOp innerFor,
+                            PatternRewriter &rewriter) {
+  assert(outerFor.getNumIterOperands() == innerFor.getNumIterOperands() &&
+         "expect same number of iter args");
+  Block *block = &(*innerFor.getRegion().begin());
+  for (auto it :
+       llvm::zip(outerFor.getIterOperands(), innerFor.getRegionIterArgs())) {
+    Value source = std::get<0>(it);
+    Value target = std::get<1>(it);
+    source.replaceUsesWithIf(target, [&](OpOperand &use) {
+      return use.getOwner()->getBlock() == block;
+    });
+  }
+}
+
 FailureOr<scf::SCFTileAndFuseResult>
 scf::TileConsumerAndFuseProducersUsingSCFForOp::returningMatchAndRewrite(
     TilingInterface op, PatternRewriter &rewriter) const {
```
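Note how the lambda scopes the replacement: only uses whose owning op sits in the innermost loop's body block are redirected. This matters because the loops themselves also use the value being replaced: the outermost `scf.for` consumes the original init in its own `iter_args` binding, and any uses below the loop nest must likewise keep seeing the original SSA value, or the iter-arg chain would be broken. A sketch of which uses are and are not rewritten, with illustrative names:

```mlir
// Use of %init here is kept: the owning scf.for lives in the parent block.
%res = scf.for %iv0 = %c0 to %d0 step %c10
    iter_args(%arg0 = %init) -> (tensor<?x?xf32>) {
  %inner = scf.for %iv1 = %c0 to %d1 step %c20
      iter_args(%arg1 = %arg0) -> (tensor<?x?xf32>) {
    // Owner is in the innermost block: this use of %init becomes %arg1.
    %slice = tensor.extract_slice %init[%iv0, %iv1] [10, 20] [1, 1]
        : tensor<?x?xf32> to tensor<10x20xf32>
    // ...
  }
  scf.yield %inner : tensor<?x?xf32>
}
// Below the nest: also kept, the owner is outside the innermost block.
%dim = tensor.dim %init, %c0 : tensor<?x?xf32>
```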
```diff
@@ -470,5 +487,7 @@ scf::TileConsumerAndFuseProducersUsingSCFForOp::returningMatchAndRewrite(
       }
     }
   }
+  replaceIterArgs(tileAndFuseResult.loops.front(),
+                  tileAndFuseResult.loops.back(), rewriter);
   return tileAndFuseResult;
 }
```
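The call site assumes `tileAndFuseResult.loops` is ordered from outermost to innermost, so `front()` and `back()` line up with the helper's `outerFor`/`innerFor` parameters; schematically:

```mlir
scf.for ... {       // tileAndFuseResult.loops.front()  -> outerFor
  scf.for ... {     // intermediate loops are left untouched
    scf.for ... {   // tileAndFuseResult.loops.back()   -> innerFor
      // uses of the outermost iter operands in this block get rewritten
    }
  }
}
```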
```diff
@@ -23,7 +23,7 @@ func.func @gemm_fill_fusion(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?x?xf32>) ->
 // CHECK-SAME:     iter_args(%[[ITERARG1:.+]] = %[[ITERARG0]])
 //  CHECK-DAG:     %[[LHS_TILE:.+]] = tensor.extract_slice %[[ARG0]][%[[IV0]], 0]
 //  CHECK-DAG:     %[[RHS_TILE:.+]] = tensor.extract_slice %[[ARG1]][0, %[[IV1]]]
-//  CHECK-DAG:     %[[INIT_TILE:.+]] = tensor.extract_slice %[[INIT]][%[[IV0]], %[[IV1]]]
+//  CHECK-DAG:     %[[INIT_TILE:.+]] = tensor.extract_slice %[[ITERARG1]][%[[IV0]], %[[IV1]]]
 //      CHECK:     %[[FILL_TILE:.+]] = linalg.fill
 // CHECK-SAME:         outs(%[[INIT_TILE]] :
 //      CHECK:     %[[GEMM_TILE:.+]] = linalg.matmul
@@ -68,7 +68,7 @@ func.func @gemm_generic_fusion(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?x?xf32>,
 // CHECK-SAME:     iter_args(%[[ITERARG1:.+]] = %[[ITERARG0]])
 //  CHECK-DAG:     %[[LHS_TILE:.+]] = tensor.extract_slice %[[ARG0]][%[[IV0]], 0]
 //  CHECK-DAG:     %[[RHS_TILE:.+]] = tensor.extract_slice %[[ARG1]][0, %[[IV1]]]
-//  CHECK-DAG:     %[[INIT_TILE:.+]] = tensor.extract_slice %[[INIT]][%[[IV0]], %[[IV1]]]
+//  CHECK-DAG:     %[[INIT_TILE:.+]] = tensor.extract_slice %[[ITERARG1]][%[[IV0]], %[[IV1]]]
 //      CHECK:     %[[FILL_TILE:.+]] = linalg.fill
 // CHECK-SAME:         outs(%[[INIT_TILE]] :
 //      CHECK:     %[[GEMM_TILE:.+]] = linalg.matmul
@@ -123,7 +123,7 @@ func.func @gemm_gemm_fusion(%lhs0 : tensor<?x?xf32>, %rhs0 : tensor<?x?xf32>, %r
 // CHECK-SAME:     ins(%[[LHS0_TILE]], %[[RHS0_TILE]] :
 // CHECK-SAME:     outs(%[[FILL0_TILE]] :
 //  CHECK-DAG:   %[[RHS1_TILE:.+]] = tensor.extract_slice %[[RHS1]][0, 0]
-//  CHECK-DAG:   %[[INIT1_TILE:.+]] = tensor.extract_slice %[[INIT1]][%[[IV]], 0]
+//  CHECK-DAG:   %[[INIT1_TILE:.+]] = tensor.extract_slice %[[ITERARG]][%[[IV]], 0]
 //      CHECK:   %[[FILL1_TILE:.+]] = linalg.fill
 // CHECK-SAME:       outs(%[[INIT1_TILE]] :
 //      CHECK:   %[[GEMM1_TILE:.+]] = linalg.matmul
```