forked from OSchip/llvm-project
[mlir][scf] Canonicalize scf.for last tensor iteration result.
Canonicalize the iter_args of an scf::ForOp that involve a tensor_load and
for which only the last loop iteration is actually visible outside of the
loop. The canonicalization looks for a pattern such as:
```
%t0 = ... : tensor_type
%0 = scf.for ... iter_args(%bb0 = %t0) -> (tensor_type) {
  ...
  // %m is either tensor_to_memref(%bb0) or defined above the loop
  %m... : memref_type
  ... // uses of %m with potential inplace updates
  %new_tensor = tensor_load %m : memref_type
  ...
  scf.yield %new_tensor : tensor_type
}
```
`%bb0` may have either 0 or 1 use. If it has 1 use it must be exactly a
`%m = tensor_to_memref %bb0` op that feeds into the yielded `tensor_load`
op. If no aliasing write to the memref `%m` (from which `%new_tensor` is
loaded) occurs between the tensor_load and the yield, then the value %0
visible outside of the loop is the last `tensor_load` produced in the loop.

For now, we approximate the absence of aliasing by only supporting the case
when the tensor_load is the operation immediately preceding the yield.

The canonicalization rewrites the pattern as:
```
// %m is either a tensor_to_memref or defined above
%m... : memref_type
scf.for ... { // no iter_args
  ... // uses of %m with potential inplace updates
}
%0 = tensor_load %m : memref_type
```

Differential revision: https://reviews.llvm.org/D97953
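For reference, a minimal sketch (not part of this change) of how these canonicalization patterns are driven, using the greedy rewrite API of this era; the enclosing `FuncOp` and pass boilerplate are assumed:
```
// Sketch only: apply scf.for canonicalization patterns over a function.
// Assumes the pre-RewritePatternSet API in use at the time of this commit.
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

static void canonicalizeSCFForOps(mlir::FuncOp funcOp) {
  mlir::OwningRewritePatternList patterns;
  // Collects ForOpIterArgsFolder, SimplifyTrivialLoops and, with this
  // change, LastTensorLoadCanonicalization.
  mlir::scf::ForOp::getCanonicalizationPatterns(patterns,
                                                funcOp.getContext());
  (void)mlir::applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
}
```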
This commit is contained in:
parent 43e4214173
commit 35908406dc
@@ -560,11 +560,137 @@ struct SimplifyTrivialLoops : public OpRewritePattern<ForOp> {
    return failure();
  }
};

/// Canonicalize the iter_args of an scf::ForOp that involve a tensor_load and
/// for which only the last loop iteration is actually visible outside of the
/// loop. The canonicalization looks for a pattern such as:
/// ```
///    %t0 = ... : tensor_type
///    %0 = scf.for ... iter_args(%bb0 = %t0) -> (tensor_type) {
///      ...
///      // %m is either tensor_to_memref(%bb0) or defined above the loop
///      %m... : memref_type
///      ... // uses of %m with potential inplace updates
///      %new_tensor = tensor_load %m : memref_type
///      ...
///      scf.yield %new_tensor : tensor_type
///    }
/// ```
///
/// `%bb0` may have either 0 or 1 use. If it has 1 use it must be exactly a
/// `%m = tensor_to_memref %bb0` op that feeds into the yielded `tensor_load`
/// op.
///
/// If no aliasing write to the memref `%m`, from which `%new_tensor` is
/// loaded, occurs between tensor_load and yield, then the value %0 visible
/// outside of the loop is the last `tensor_load` produced in the loop.
///
/// For now, we approximate the absence of aliasing by only supporting the case
/// when the tensor_load is the operation immediately preceding the yield.
///
/// The canonicalization rewrites the pattern as:
/// ```
///    // %m is either a tensor_to_memref or defined above
///    %m... : memref_type
///    scf.for ... iter_args(%bb0 = %t0) -> (tensor_type) {
///      ... // uses of %m with potential inplace updates
///      scf.yield %bb0 : tensor_type
///    }
///    %0 = tensor_load %m : memref_type
/// ```
///
/// A later bbArg canonicalization (`ForOpIterArgsFolder`) will further
/// rewrite as:
/// ```
///    // %m is either a tensor_to_memref or defined above
///    %m... : memref_type
///    scf.for ... { // no iter_args
///      ... // uses of %m with potential inplace updates
///    }
///    %0 = tensor_load %m : memref_type
/// ```
struct LastTensorLoadCanonicalization : public OpRewritePattern<ForOp> {
  using OpRewritePattern<ForOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ForOp forOp,
                                PatternRewriter &rewriter) const override {
    assert(std::next(forOp.region().begin()) == forOp.region().end() &&
           "unexpected multiple blocks");

    Location loc = forOp.getLoc();
    DenseMap<Value, Value> replacements;
    for (BlockArgument bbArg : forOp.getRegionIterArgs()) {
      unsigned idx = bbArg.getArgNumber() - /*numIv=*/1;
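      // Block arguments of the loop region are laid out as [%iv,
      // iter_args...], so subtracting the single induction variable gives the
      // position of `bbArg` among the yielded operands and loop results.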
      auto yieldOp = cast<scf::YieldOp>(forOp.region().front().getTerminator());
      Value yieldVal = yieldOp->getOperand(idx);
      auto tensorLoadOp = yieldVal.getDefiningOp<TensorLoadOp>();
      bool isTensor = bbArg.getType().isa<TensorType>();

      TensorToMemrefOp tensorToMemRefOp;
      // Either bbArg has no use or it has a single tensor_to_memref use.
      if (bbArg.hasOneUse())
        tensorToMemRefOp =
            dyn_cast<TensorToMemrefOp>(*bbArg.getUsers().begin());
      if (!isTensor || !tensorLoadOp ||
          (!bbArg.use_empty() && !tensorToMemRefOp))
        continue;
      // If tensorToMemRefOp is present, it must feed into the `tensorLoadOp`.
      if (tensorToMemRefOp && tensorLoadOp.memref() != tensorToMemRefOp)
        continue;
      // TODO: Any aliasing write of tensorLoadOp.memref() nested under `forOp`
      // must be before `tensorLoadOp` in the block so that the lastWrite
      // property is not subject to additional side-effects.
      // For now, we only support the case when tensorLoadOp appears
      // immediately before the terminator.
      if (tensorLoadOp->getNextNode() != yieldOp)
        continue;

      // Clone the optional tensorToMemRefOp before forOp.
      if (tensorToMemRefOp) {
        rewriter.setInsertionPoint(forOp);
        rewriter.replaceOpWithNewOp<TensorToMemrefOp>(
            tensorToMemRefOp, tensorToMemRefOp.memref().getType(),
            tensorToMemRefOp.tensor());
      }

      // Clone the tensorLoad after forOp.
      rewriter.setInsertionPointAfter(forOp);
      Value newTensorLoad =
          rewriter.create<TensorLoadOp>(loc, tensorLoadOp.memref());
      Value forOpResult = forOp.getResult(bbArg.getArgNumber() - /*iv=*/1);
      replacements.insert(std::make_pair(forOpResult, newTensorLoad));

      // Make the terminator just yield the bbArg; the old tensorLoadOp and the
      // old bbArg (now directly yielded) will canonicalize away.
      rewriter.startRootUpdate(yieldOp);
      yieldOp.setOperand(idx, bbArg);
      rewriter.finalizeRootUpdate(yieldOp);
    }
    if (replacements.empty())
      return failure();

    // We want to replace a subset of the results of `forOp`.
    // rewriter.replaceOp replaces the whole op and erases it unconditionally;
    // this is wrong for `forOp` as it generally contains ops with side
    // effects. Instead, use `rewriter.replaceOpWithIf`.
    SmallVector<Value> newResults;
    newResults.reserve(forOp.getNumResults());
    for (Value v : forOp.getResults()) {
      auto it = replacements.find(v);
      newResults.push_back((it != replacements.end()) ? it->second : v);
    }
    unsigned idx = 0;
    rewriter.replaceOpWithIf(forOp, newResults, [&](OpOperand &op) {
      return op.get() != newResults[idx++];
    });
    return success();
  }
};
} // namespace

void ForOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
                                        MLIRContext *context) {
-  results.insert<ForOpIterArgsFolder, SimplifyTrivialLoops>(context);
+  results.insert<ForOpIterArgsFolder, SimplifyTrivialLoops,
+                 LastTensorLoadCanonicalization>(context);
}

//===----------------------------------------------------------------------===//
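The remaining hunks update the SCF canonicalization tests: the RUN line gains `-split-input-file`, `// -----` separators are inserted between the existing cases, and a new `@last_value` test exercises the pattern above.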
@@ -1,4 +1,7 @@
-// RUN: mlir-opt %s -pass-pipeline='func(canonicalize)' | FileCheck %s
+// RUN: mlir-opt %s -pass-pipeline='func(canonicalize)' -split-input-file | FileCheck %s
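
// Note: -split-input-file makes mlir-opt process each chunk delimited by
// "// -----" separately; the separators added throughout this file mark the
// boundaries between the individual test cases.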

// -----

func @single_iteration(%A: memref<?x?x?xi32>) {
  %c0 = constant 0 : index
@@ -143,6 +146,8 @@ func @for_yields_3(%lb : index, %ub : index, %step : index) -> (i32, i32, i32) {
// CHECK-NEXT: }
// CHECK-NEXT: return %[[a]], %[[r1]], %[[b]] : i32, i32, i32

// -----

// CHECK-LABEL: @replace_true_if
func @replace_true_if() {
  %true = constant true
@@ -155,6 +160,8 @@ func @replace_true_if() {
  return
}

// -----

// CHECK-LABEL: @remove_false_if
func @remove_false_if() {
  %false = constant false
@@ -167,6 +174,8 @@ func @remove_false_if() {
  return
}

// -----

// CHECK-LABEL: @replace_true_if_with_values
func @replace_true_if_with_values() {
  %true = constant true
@@ -184,6 +193,8 @@ func @replace_true_if_with_values() {
  return
}

// -----

// CHECK-LABEL: @replace_false_if_with_values
func @replace_false_if_with_values() {
  %false = constant false
@@ -201,6 +212,8 @@ func @replace_false_if_with_values() {
  return
}

// -----

// CHECK-LABEL: @remove_zero_iteration_loop
func @remove_zero_iteration_loop() {
  %c42 = constant 42 : index
@@ -217,6 +230,8 @@ func @remove_zero_iteration_loop() {
  return
}

// -----

// CHECK-LABEL: @remove_zero_iteration_loop_vals
func @remove_zero_iteration_loop_vals(%arg0: index) {
  %c2 = constant 2 : index
@@ -233,6 +248,8 @@ func @remove_zero_iteration_loop_vals(%arg0: index) {
  return
}

// -----

// CHECK-LABEL: @replace_single_iteration_loop_1
func @replace_single_iteration_loop_1() {
  // CHECK: %[[LB:.*]] = constant 42
@@ -252,6 +269,8 @@ func @replace_single_iteration_loop_1() {
  return
}

// -----

// CHECK-LABEL: @replace_single_iteration_loop_2
func @replace_single_iteration_loop_2() {
  // CHECK: %[[LB:.*]] = constant 5
@@ -271,6 +290,7 @@ func @replace_single_iteration_loop_2() {
  return
}

// -----

// CHECK-LABEL: @replace_single_iteration_loop_non_unit_step
func @replace_single_iteration_loop_non_unit_step() {
@@ -291,6 +311,8 @@ func @replace_single_iteration_loop_non_unit_step() {
  return
}

// -----

// CHECK-LABEL: @remove_empty_parallel_loop
func @remove_empty_parallel_loop(%lb: index, %ub: index, %s: index) {
  // CHECK: %[[INIT:.*]] = "test.init"
@@ -311,3 +333,52 @@ func @remove_empty_parallel_loop(%lb: index, %ub: index, %s: index) {
"test.consume"(%0) : (f32) -> ()
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
func private @process(%0 : memref<128x128xf32>)
|
||||
func private @process_tensor(%0 : tensor<128x128xf32>) -> memref<128x128xf32>
|
||||
|
||||
// CHECK-LABEL: last_value
|
||||
// CHECK-SAME: %[[T0:[0-9a-z]*]]: tensor<128x128xf32>
|
||||
// CHECK-SAME: %[[T1:[0-9a-z]*]]: tensor<128x128xf32>
|
||||
// CHECK-SAME: %[[T2:[0-9a-z]*]]: tensor<128x128xf32>
|
||||
// CHECK-SAME: %[[M0:[0-9a-z]*]]: memref<128x128xf32>
|
||||
func @last_value(%t0: tensor<128x128xf32>, %t1: tensor<128x128xf32>,
|
||||
%t2: tensor<128x128xf32>, %m0: memref<128x128xf32>,
|
||||
%lb : index, %ub : index, %step : index)
|
||||
-> (tensor<128x128xf32>, tensor<128x128xf32>, tensor<128x128xf32>)
|
||||
{
|
||||
// CHECK-NEXT: %[[M1:.*]] = tensor_to_memref %[[T1]] : memref<128x128xf32>
|
||||
// CHECK-NEXT: %[[FOR_RES:.*]] = scf.for {{.*}} iter_args(%[[BBARG_T2:.*]] = %[[T2]]) -> (tensor<128x128xf32>) {
|
||||
%0:3 = scf.for %arg0 = %lb to %ub step %step iter_args(%arg1 = %t0, %arg2 = %t1, %arg3 = %t2)
|
||||
-> (tensor<128x128xf32>, tensor<128x128xf32>, tensor<128x128xf32>)
|
||||
{
|
||||
%m1 = tensor_to_memref %arg2 : memref<128x128xf32>
|
||||
|
||||
// CHECK-NEXT: call @process(%[[M0]]) : (memref<128x128xf32>) -> ()
|
||||
call @process(%m0) : (memref<128x128xf32>) -> ()
|
||||
|
||||
// CHECK-NEXT: call @process(%[[M1]]) : (memref<128x128xf32>) -> ()
|
||||
call @process(%m1) : (memref<128x128xf32>) -> ()
|
||||
|
||||
// This does not hoist (fails the bbArg has at most a single check).
|
||||
// CHECK-NEXT: %[[T:.*]] = call @process_tensor(%[[BBARG_T2]]) : (tensor<128x128xf32>) -> memref<128x128xf32>
|
||||
// CHECK-NEXT: %[[YIELD_T:.*]] = tensor_load %[[T:.*]]
|
||||
%m2 = call @process_tensor(%arg3): (tensor<128x128xf32>) -> memref<128x128xf32>
|
||||
%3 = tensor_load %m2 : memref<128x128xf32>
|
||||
|
||||
// All this stuff goes away, incrementally
|
||||
%1 = tensor_load %m0 : memref<128x128xf32>
|
||||
%2 = tensor_load %m1 : memref<128x128xf32>
|
||||
|
||||
// CHECK-NEXT: scf.yield %[[YIELD_T]] : tensor<128x128xf32>
|
||||
scf.yield %1, %2, %3 : tensor<128x128xf32>, tensor<128x128xf32>, tensor<128x128xf32>
|
||||
|
||||
// CHECK-NEXT: }
|
||||
}
|
||||
|
||||
// CHECK-NEXT: %[[R0:.*]] = tensor_load %[[M0]] : memref<128x128xf32>
|
||||
// CHECK-NEXT: %[[R1:.*]] = tensor_load %[[M1]] : memref<128x128xf32>
|
||||
// CHECK-NEXT: return %[[R0]], %[[R1]], %[[FOR_RES]] : tensor<128x128xf32>, tensor<128x128xf32>, tensor<128x128xf32>
|
||||
return %0#0, %0#1, %0#2 : tensor<128x128xf32>, tensor<128x128xf32>, tensor<128x128xf32>
|
||||
}
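
// Of the three iter_args above, only the third keeps a loop result
// (%[[FOR_RES]]): %arg3's single use feeds call @process_tensor rather than a
// tensor_to_memref, so the single tensor_to_memref use condition fails. The
// other two iter_args are peeled away by successive greedy rewrites, which is
// why the tensor_loads for %[[R0]] and %[[R1]] end up after the loop.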