[mlir][Linalg] Proper handling of ForOp and TiledLoopOp

The `bufferizesToMemoryRead` condition was too optimistics in the case
of operands that map to a block argument.
This is the case for ForOp and TiledLoopOp.
For such ops, forward the call to all uses of the matching BBArg.

Differential Revision: https://reviews.llvm.org/D105540
This commit is contained in:
Nicolas Vasilache 2021-07-07 08:02:02 +00:00
parent a7da0296a6
commit 9a0af63d05
2 changed files with 68 additions and 2 deletions

View File

@ -542,12 +542,21 @@ static bool bufferizesToMemoryRead(OpOperand &opOperand) {
return false; return false;
// scf::ForOp alone doesn't bufferize to a memory read, one of the uses of its // scf::ForOp alone doesn't bufferize to a memory read, one of the uses of its
// matching bbArg may. // matching bbArg may.
if (isa<scf::ForOp>(opOperand.getOwner())) if (auto forOp = dyn_cast<scf::ForOp>(opOperand.getOwner())) {
for (OpOperand &use :
forOp.getRegionIterArgForOpOperand(opOperand).getUses())
if (bufferizesToMemoryRead(use))
return true;
return false; return false;
}
// TiledLoop alone doesn't bufferize to a memory read, one of the uses of its // TiledLoop alone doesn't bufferize to a memory read, one of the uses of its
// matching bbArg may. // matching bbArg may.
if (isa<TiledLoopOp>(opOperand.getOwner())) if (auto tiledLoopOp = dyn_cast<TiledLoopOp>(opOperand.getOwner())) {
for (OpOperand &use : tiledLoopOp.getTiedBlockArgument(opOperand).getUses())
if (bufferizesToMemoryRead(use))
return true;
return false; return false;
}
// CallOpInterface alone doesn't bufferize to a memory read, one of the uses // CallOpInterface alone doesn't bufferize to a memory read, one of the uses
// of the matching bbArg may. It is the responsibility of the caller to // of the matching bbArg may. It is the responsibility of the caller to
// inspect bbArgs. In the absence of a BufferizationAliasInfo, we need to be // inspect bbArgs. In the absence of a BufferizationAliasInfo, we need to be
@ -1685,6 +1694,8 @@ static LogicalResult bufferize(OpBuilder &b, scf::ForOp forOp,
b.create<linalg::CopyOp>(forOp.getLoc(), operandBuffer, resultBuffer); b.create<linalg::CopyOp>(forOp.getLoc(), operandBuffer, resultBuffer);
} }
BlockArgument bbArg = forOp.getRegionIterArgForOpOperand(opOperand); BlockArgument bbArg = forOp.getRegionIterArgForOpOperand(opOperand);
aliasInfo.createAliasInfoEntry(resultBuffer);
aliasInfo.insertNewBufferEquivalence(bbArg, resultBuffer);
map(bvm, bbArg, resultBuffer); map(bvm, bbArg, resultBuffer);
map(bvm, opResult, resultBuffer); map(bvm, opResult, resultBuffer);
} }

View File

@ -474,6 +474,61 @@ func @scf_for_with_tensor.insert_slice(%A : tensor<?xf32>,
// ----- // -----
func private @some_use(tensor<?xf32>) -> ()
// CHECK-LABEL: func @scf_for_deps
func @scf_for_deps(%A : tensor<?xf32> {linalg.inplaceable = true},
%B : tensor<?xf32> {linalg.inplaceable = true},
%lb : index, %ub : index, %step : index)
-> (tensor<?xf32>, tensor<?xf32>)
{
// %r0 must be out of place because one use of %t in the subsequent production
// of %r1 is read.
// CHECK: scf.for
// CHECK-NEXT: scf.yield
// CHECK-NEXT: {__inplace_results_attr__ = ["false"]}
%r0 = scf.for %i = %lb to %ub step %step iter_args(%t = %A) -> (tensor<?xf32>) {
scf.yield %t : tensor<?xf32>
}
// %r1 bufferizes inplace fine.
// CHECK: scf.for
// CHECK-NEXT: call
// CHECK-NEXT: scf.yield
// CHECK-NEXT: {__inplace_results_attr__ = ["true"]}
%r1 = scf.for %i = %lb to %ub step %step iter_args(%t = %A) -> (tensor<?xf32>) {
call @some_use(%t) : (tensor<?xf32>) -> ()
scf.yield %t : tensor<?xf32>
}
// %r2 must be out of place because one use of %t in the subsequent production
// of %r3 is read.
// CHECK: linalg.tiled_loop
// CHECK-NEXT: linalg.yield
// CHECK-NEXT: {__inplace_results_attr__ = ["false"]}
%r2 = linalg.tiled_loop (%i) = (%lb) to (%ub) step (%step)
ins()
outs(%t = %B: tensor<?xf32>) {
linalg.yield %t : tensor<?xf32>
}
// %r3 bufferizes inplace fine.
// CHECK: linalg.tiled_loop
// CHECK-NEXT: call
// CHECK-NEXT: linalg.yield
// CHECK-NEXT: {__inplace_results_attr__ = ["true"]}
%r3 = linalg.tiled_loop (%i) = (%lb) to (%ub) step (%step)
ins()
outs(%t = %B: tensor<?xf32>) {
call @some_use(%t) : (tensor<?xf32>) -> ()
linalg.yield %t : tensor<?xf32>
}
return %r1, %r3: tensor<?xf32>, tensor<?xf32>
}
// -----
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// Cross function boundary cases. // Cross function boundary cases.
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//