forked from OSchip/llvm-project
[mlir][Linalg] Proper handling of ForOp and TiledLoopOp
The `bufferizesToMemoryRead` condition was too optimistics in the case of operands that map to a block argument. This is the case for ForOp and TiledLoopOp. For such ops, forward the call to all uses of the matching BBArg. Differential Revision: https://reviews.llvm.org/D105540
This commit is contained in:
parent
a7da0296a6
commit
9a0af63d05
|
@ -542,12 +542,21 @@ static bool bufferizesToMemoryRead(OpOperand &opOperand) {
|
|||
return false;
|
||||
// scf::ForOp alone doesn't bufferize to a memory read, one of the uses of its
|
||||
// matching bbArg may.
|
||||
if (isa<scf::ForOp>(opOperand.getOwner()))
|
||||
if (auto forOp = dyn_cast<scf::ForOp>(opOperand.getOwner())) {
|
||||
for (OpOperand &use :
|
||||
forOp.getRegionIterArgForOpOperand(opOperand).getUses())
|
||||
if (bufferizesToMemoryRead(use))
|
||||
return true;
|
||||
return false;
|
||||
}
|
||||
// TiledLoop alone doesn't bufferize to a memory read, one of the uses of its
|
||||
// matching bbArg may.
|
||||
if (isa<TiledLoopOp>(opOperand.getOwner()))
|
||||
if (auto tiledLoopOp = dyn_cast<TiledLoopOp>(opOperand.getOwner())) {
|
||||
for (OpOperand &use : tiledLoopOp.getTiedBlockArgument(opOperand).getUses())
|
||||
if (bufferizesToMemoryRead(use))
|
||||
return true;
|
||||
return false;
|
||||
}
|
||||
// CallOpInterface alone doesn't bufferize to a memory read, one of the uses
|
||||
// of the matching bbArg may. It is the responsibility of the caller to
|
||||
// inspect bbArgs. In the absence of a BufferizationAliasInfo, we need to be
|
||||
|
@ -1685,6 +1694,8 @@ static LogicalResult bufferize(OpBuilder &b, scf::ForOp forOp,
|
|||
b.create<linalg::CopyOp>(forOp.getLoc(), operandBuffer, resultBuffer);
|
||||
}
|
||||
BlockArgument bbArg = forOp.getRegionIterArgForOpOperand(opOperand);
|
||||
aliasInfo.createAliasInfoEntry(resultBuffer);
|
||||
aliasInfo.insertNewBufferEquivalence(bbArg, resultBuffer);
|
||||
map(bvm, bbArg, resultBuffer);
|
||||
map(bvm, opResult, resultBuffer);
|
||||
}
|
||||
|
|
|
@ -474,6 +474,61 @@ func @scf_for_with_tensor.insert_slice(%A : tensor<?xf32>,
|
|||
|
||||
// -----
|
||||
|
||||
func private @some_use(tensor<?xf32>) -> ()
|
||||
|
||||
// CHECK-LABEL: func @scf_for_deps
|
||||
func @scf_for_deps(%A : tensor<?xf32> {linalg.inplaceable = true},
|
||||
%B : tensor<?xf32> {linalg.inplaceable = true},
|
||||
%lb : index, %ub : index, %step : index)
|
||||
-> (tensor<?xf32>, tensor<?xf32>)
|
||||
{
|
||||
// %r0 must be out of place because one use of %t in the subsequent production
|
||||
// of %r1 is read.
|
||||
// CHECK: scf.for
|
||||
// CHECK-NEXT: scf.yield
|
||||
// CHECK-NEXT: {__inplace_results_attr__ = ["false"]}
|
||||
%r0 = scf.for %i = %lb to %ub step %step iter_args(%t = %A) -> (tensor<?xf32>) {
|
||||
scf.yield %t : tensor<?xf32>
|
||||
}
|
||||
|
||||
// %r1 bufferizes inplace fine.
|
||||
// CHECK: scf.for
|
||||
// CHECK-NEXT: call
|
||||
// CHECK-NEXT: scf.yield
|
||||
// CHECK-NEXT: {__inplace_results_attr__ = ["true"]}
|
||||
%r1 = scf.for %i = %lb to %ub step %step iter_args(%t = %A) -> (tensor<?xf32>) {
|
||||
call @some_use(%t) : (tensor<?xf32>) -> ()
|
||||
scf.yield %t : tensor<?xf32>
|
||||
}
|
||||
|
||||
// %r2 must be out of place because one use of %t in the subsequent production
|
||||
// of %r3 is read.
|
||||
// CHECK: linalg.tiled_loop
|
||||
// CHECK-NEXT: linalg.yield
|
||||
// CHECK-NEXT: {__inplace_results_attr__ = ["false"]}
|
||||
%r2 = linalg.tiled_loop (%i) = (%lb) to (%ub) step (%step)
|
||||
ins()
|
||||
outs(%t = %B: tensor<?xf32>) {
|
||||
linalg.yield %t : tensor<?xf32>
|
||||
}
|
||||
|
||||
// %r3 bufferizes inplace fine.
|
||||
// CHECK: linalg.tiled_loop
|
||||
// CHECK-NEXT: call
|
||||
// CHECK-NEXT: linalg.yield
|
||||
// CHECK-NEXT: {__inplace_results_attr__ = ["true"]}
|
||||
%r3 = linalg.tiled_loop (%i) = (%lb) to (%ub) step (%step)
|
||||
ins()
|
||||
outs(%t = %B: tensor<?xf32>) {
|
||||
call @some_use(%t) : (tensor<?xf32>) -> ()
|
||||
linalg.yield %t : tensor<?xf32>
|
||||
}
|
||||
|
||||
return %r1, %r3: tensor<?xf32>, tensor<?xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Cross function boundary cases.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
Loading…
Reference in New Issue