[mlir][linalg][bufferize] Fix insertion point InitTensorElimination

There was a bug where some of the OpOperands needed in the replacement op were not in scope.

It does not matter where the replacement op is inserted. Any insertion point is OK as long as there are no dominance errors. In the worst case, the newly inserted op will bufferize out-of-place. This is no worse than not eliminating the InitTensorOp at all.

Differential Revision: https://reviews.llvm.org/D117685
This commit is contained in:
Matthias Springer 2022-01-30 22:19:06 +09:00
parent ab0554b2ec
commit 6700a26d5f
3 changed files with 162 additions and 6 deletions

View File

@ -20,7 +20,10 @@ namespace linalg_ext {
struct InitTensorEliminationStep : public bufferization::PostAnalysisStep {
/// A function that matches anchor OpOperands for InitTensorOp elimination.
using AnchorMatchFn = std::function<bool(OpOperand &)>;
/// If an OpOperand is matched, the function should populate the SmallVector
/// with all values that are needed during `RewriteFn` to produce the
/// replacement value.
using AnchorMatchFn = std::function<bool(OpOperand &, SmallVector<Value> &)>;
/// A function that rewrites matched anchors.
using RewriteFn = std::function<Value(OpBuilder &, Location, OpOperand &)>;

View File

@ -12,6 +12,7 @@
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/Dominance.h"
#include "mlir/IR/Operation.h"
using namespace mlir;
@ -444,6 +445,79 @@ struct LinalgOpInterfaceHelper<> {
} // namespace
/// Return true if all `neededValues` are in scope at the given
/// `insertionPoint`.
static bool
neededValuesDominateInsertionPoint(const DominanceInfo &domInfo,
Operation *insertionPoint,
const SmallVector<Value> &neededValues) {
for (Value val : neededValues) {
if (auto bbArg = val.dyn_cast<BlockArgument>()) {
Block *owner = bbArg.getOwner();
if (!owner->findAncestorOpInBlock(*insertionPoint))
return false;
} else {
auto opResult = val.cast<OpResult>();
if (!domInfo.dominates(opResult.getOwner(), insertionPoint))
return false;
}
}
return true;
}
/// Return true if the given `insertionPoint` dominates all uses of
/// `initTensorOp`.
static bool insertionPointDominatesUses(const DominanceInfo &domInfo,
Operation *insertionPoint,
Operation *initTensorOp) {
for (Operation *user : initTensorOp->getUsers())
if (!domInfo.dominates(insertionPoint, user))
return false;
return true;
}
/// Find a valid insertion point for a replacement of `initTensorOp`, assuming
/// that the replacement may use any value from `neededValues`.
static Operation *
findValidInsertionPoint(Operation *initTensorOp,
const SmallVector<Value> &neededValues) {
DominanceInfo domInfo;
// Gather all possible insertion points: the location of `initTensorOp` and
// right after the definition of each value in `neededValues`.
SmallVector<Operation *> insertionPointCandidates;
insertionPointCandidates.push_back(initTensorOp);
for (Value val : neededValues) {
// Note: The anchor op is using all of `neededValues`, so:
// * in case of a block argument: There must be at least one op in the block
// (the anchor op or one of its parents).
// * in case of an OpResult: There must be at least one op right after the
// defining op (the anchor op or one of its
// parents).
if (auto bbArg = val.dyn_cast<BlockArgument>()) {
insertionPointCandidates.push_back(
&bbArg.getOwner()->getOperations().front());
} else {
insertionPointCandidates.push_back(val.getDefiningOp()->getNextNode());
}
}
// Select first matching insertion point.
for (Operation *insertionPoint : insertionPointCandidates) {
// Check if all needed values are in scope.
if (!neededValuesDominateInsertionPoint(domInfo, insertionPoint,
neededValues))
continue;
// Check if the insertion point is before all uses.
if (!insertionPointDominatesUses(domInfo, insertionPoint, initTensorOp))
continue;
return insertionPoint;
}
// No suitable insertion point was found.
return nullptr;
}
/// Try to eliminate InitTensorOps inside `op`. An InitTensorOp is replaced
/// with the the result of `rewriteFunc` if it is anchored on a matching
/// OpOperand. "Anchored" means that there is a path on the reverse SSA use-def
@ -462,8 +536,10 @@ mlir::linalg::comprehensive_bufferize::linalg_ext::InitTensorEliminationStep::
// Skip operands that do not bufferize inplace.
if (!aliasInfo.isInPlace(operand))
continue;
// All values that are needed to create the replacement op.
SmallVector<Value> neededValues;
// Is this a matching OpOperand?
if (!anchorMatchFunc(operand))
if (!anchorMatchFunc(operand, neededValues))
continue;
SetVector<Value> maybeInitTensor =
state.findValueInReverseUseDefChain(operand.get(), [&](Value val) {
@ -492,8 +568,14 @@ mlir::linalg::comprehensive_bufferize::linalg_ext::InitTensorEliminationStep::
return WalkResult::skip();
Value initTensor = maybeInitTensor.front();
// Find a suitable insertion point.
Operation *insertionPoint =
findValidInsertionPoint(initTensor.getDefiningOp(), neededValues);
if (!insertionPoint)
continue;
// Create a replacement for the InitTensorOp.
b.setInsertionPoint(initTensor.getDefiningOp());
b.setInsertionPoint(insertionPoint);
Value replacement = rewriteFunc(b, initTensor.getLoc(), operand);
if (!replacement)
continue;
@ -552,7 +634,7 @@ LogicalResult mlir::linalg::comprehensive_bufferize::linalg_ext::
return eliminateInitTensors(
op, state, aliasInfo,
/*anchorMatchFunc=*/
[&](OpOperand &operand) {
[&](OpOperand &operand, SmallVector<Value> &neededValues) {
auto insertSliceOp =
dyn_cast<tensor::InsertSliceOp>(operand.getOwner());
if (!insertSliceOp)
@ -560,7 +642,19 @@ LogicalResult mlir::linalg::comprehensive_bufferize::linalg_ext::
// Only inplace bufferized InsertSliceOps are eligible.
if (!aliasInfo.isInPlace(insertSliceOp->getOpOperand(1) /*dest*/))
return false;
return &operand == &insertSliceOp->getOpOperand(0) /*source*/;
if (&operand != &insertSliceOp->getOpOperand(0) /*source*/)
return false;
// Collect all values that are needed to construct the replacement op.
neededValues.append(insertSliceOp.offsets().begin(),
insertSliceOp.offsets().end());
neededValues.append(insertSliceOp.sizes().begin(),
insertSliceOp.sizes().end());
neededValues.append(insertSliceOp.strides().begin(),
insertSliceOp.strides().end());
neededValues.push_back(insertSliceOp.dest());
return true;
},
/*rewriteFunc=*/
[](OpBuilder &b, Location loc, OpOperand &operand) {

View File

@ -1,4 +1,4 @@
// RUN: mlir-opt %s -linalg-comprehensive-module-bufferize="allow-return-memref init-tensor-elimination" -split-input-file | FileCheck %s
// RUN: mlir-opt %s -linalg-comprehensive-module-bufferize="allow-return-memref init-tensor-elimination" -canonicalize -split-input-file | FileCheck %s
// -----
@ -62,3 +62,62 @@ func @buffer_forwarding_no_conflict(
return %r1: tensor<?xf32>
}
// -----
// CHECK: func @insertion_point_inside_loop(
// CHECK-SAME: %[[t:.*]]: memref<?xf32, #{{.*}}>, %[[sz:.*]]: index)
func @insertion_point_inside_loop(%t : tensor<?xf32>, %sz : index) -> (tensor<?xf32>) {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%c5 = arith.constant 5 : index
// CHECK-NOT: memref.alloc
%blank = linalg.init_tensor [5] : tensor<5xf32>
// CHECK: scf.for %[[iv:.*]] = %{{.*}} to %[[sz]] step %{{.*}} {
%r = scf.for %iv = %c0 to %sz step %c5 iter_args(%bb = %t) -> (tensor<?xf32>) {
// CHECK: %[[subview:.*]] = memref.subview %[[t]][%[[iv]]] [5] [1]
%iv_i32 = arith.index_cast %iv : index to i32
%f = arith.sitofp %iv_i32 : i32 to f32
// CHECK: linalg.fill(%{{.*}}, %[[subview]])
%filled = linalg.fill(%f, %blank) : f32, tensor<5xf32> -> tensor<5xf32>
// CHECK-NOT: memref.copy
%inserted = tensor.insert_slice %filled into %bb[%iv][5][1] : tensor<5xf32> into tensor<?xf32>
scf.yield %inserted : tensor<?xf32>
}
return %r : tensor<?xf32>
}
// -----
// CHECK: func @insertion_point_outside_loop(
// CHECK-SAME: %[[t:.*]]: memref<?xf32, #{{.*}}>, %[[sz:.*]]: index, %[[idx:.*]]: index)
func @insertion_point_outside_loop(%t : tensor<?xf32>, %sz : index,
%idx : index) -> (tensor<?xf32>) {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%c5 = arith.constant 5 : index
// CHECK-NOT: memref.alloc
// CHECK: %[[subview:.*]] = memref.subview %[[t]][%[[idx]]] [5] [1]
%blank = linalg.init_tensor [5] : tensor<5xf32>
// CHECK: scf.for %[[iv:.*]] = %{{.*}} to %[[sz]] step %{{.*}} {
%r = scf.for %iv = %c0 to %sz step %c5 iter_args(%bb = %t) -> (tensor<?xf32>) {
%iv_i32 = arith.index_cast %iv : index to i32
%f = arith.sitofp %iv_i32 : i32 to f32
// CHECK: linalg.fill(%{{.*}}, %[[subview]])
%filled = linalg.fill(%f, %blank) : f32, tensor<5xf32> -> tensor<5xf32>
// CHECK-NOT: memref.copy
%inserted = tensor.insert_slice %filled into %bb[%idx][5][1] : tensor<5xf32> into tensor<?xf32>
scf.yield %inserted : tensor<?xf32>
}
return %r : tensor<?xf32>
}