forked from OSchip/llvm-project
[mlir][linalg][bufferize] Handle scf::ForOp correctly in bufferizesToMemoryRead
From the perspective of analysis, scf::ForOp is treated as a black box. Basic block arguments do not alias with their respective OpOperands on the ForOp, so they do not participate in conflict analysis with ops defined outside of the loop. However, bufferizesToMemoryRead and bufferizesToMemoryWrite on the scf::ForOp itself are used to determine how the scf::ForOp interacts with its surrounding ops. Differential Revision: https://reviews.llvm.org/D111775
This commit is contained in:
parent
d3cb6bf2d4
commit
7dd7078760
|
@ -612,6 +612,31 @@ static OpResult getAliasingOpResult(OpOperand &opOperand) {
|
|||
[&](Operation *op) { return getInplaceableOpResult(opOperand); });
|
||||
}
|
||||
|
||||
// Predeclaration of function.
|
||||
static bool bufferizesToMemoryRead(OpOperand &opOperand);
|
||||
|
||||
/// scf::ForOp alone doesn't bufferize to a memory read, one of the uses of its
|
||||
/// matching bbArg may.
|
||||
static bool bufferizesToMemoryRead(scf::ForOp forOp, OpOperand &opOperand) {
|
||||
SmallVector<OpOperand *> workingSet;
|
||||
for (OpOperand &use : forOp.getRegionIterArgForOpOperand(opOperand).getUses())
|
||||
workingSet.push_back(&use);
|
||||
|
||||
while (!workingSet.empty()) {
|
||||
OpOperand *uMaybeReading = workingSet.pop_back_val();
|
||||
// Skip over all ExtractSliceOps. These do not read by themselves but just
|
||||
// add a new alias.
|
||||
if (auto extractSliceOp =
|
||||
dyn_cast<ExtractSliceOp>(uMaybeReading->getOwner()))
|
||||
for (OpOperand &use : extractSliceOp.result().getUses())
|
||||
workingSet.push_back(&use);
|
||||
if (bufferizesToMemoryRead(*uMaybeReading))
|
||||
return true;
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
/// Return true if `opOperand` bufferizes to a memory read.
|
||||
static bool bufferizesToMemoryRead(OpOperand &opOperand) {
|
||||
// Unknown op that returns a tensor. The inplace analysis does not support
|
||||
|
@ -622,15 +647,8 @@ static bool bufferizesToMemoryRead(OpOperand &opOperand) {
|
|||
// may.
|
||||
if (isa<ExtractSliceOp>(opOperand.getOwner()))
|
||||
return false;
|
||||
// scf::ForOp alone doesn't bufferize to a memory read, one of the uses of its
|
||||
// matching bbArg may.
|
||||
if (auto forOp = dyn_cast<scf::ForOp>(opOperand.getOwner())) {
|
||||
for (OpOperand &use :
|
||||
forOp.getRegionIterArgForOpOperand(opOperand).getUses())
|
||||
if (bufferizesToMemoryRead(use))
|
||||
return true;
|
||||
return false;
|
||||
}
|
||||
if (auto forOp = dyn_cast<scf::ForOp>(opOperand.getOwner()))
|
||||
return bufferizesToMemoryRead(forOp, opOperand);
|
||||
// TiledLoop alone doesn't bufferize to a memory read, one of the uses of its
|
||||
// matching bbArg may.
|
||||
if (auto tiledLoopOp = dyn_cast<TiledLoopOp>(opOperand.getOwner())) {
|
||||
|
|
|
@ -912,3 +912,104 @@ func @interleaved_extract_insert_slice_chain_2(
|
|||
|
||||
return %15 : tensor<62x90xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
#accesses = [
|
||||
affine_map<(i) -> (i)>
|
||||
]
|
||||
#trait = {
|
||||
indexing_maps = #accesses,
|
||||
iterator_types = ["parallel"]
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @reading_scf_for
|
||||
func @reading_scf_for(%t1: tensor<?xf32> {linalg.inplaceable = true},
|
||||
%s: index, %v: vector<5xf32>) -> (tensor<?xf32>, vector<5xf32>) {
|
||||
|
||||
%c0 = arith.constant 0 : index
|
||||
%c1 = arith.constant 1 : index
|
||||
%cst = arith.constant 0.0 : f32
|
||||
|
||||
// Write to %t1.
|
||||
// CHECK: vector.transfer_write
|
||||
// CHECK-SAME: __inplace_results_attr__ = ["false"]
|
||||
%t3 = vector.transfer_write %v, %t1[%s] : vector<5xf32>, tensor<?xf32>
|
||||
|
||||
// Read the old value of %t1 inside the loop via an alias.
|
||||
// CHECK: scf.for
|
||||
%r, %v3 = scf.for %i = %c0 to %s step %c1 iter_args(%t2 = %t1, %v0 = %v) -> (tensor<?xf32>, vector<5xf32>) {
|
||||
// CHECK: tensor.extract_slice
|
||||
// CHECK-SAME: __inplace_results_attr__ = ["true"]
|
||||
%e = tensor.extract_slice %t2[%s][%s][1] : tensor<?xf32> to tensor<?xf32>
|
||||
|
||||
// Read from %t1 via alias %e.
|
||||
%v2 = vector.transfer_read %e[%s], %cst : tensor<?xf32>, vector<5xf32>
|
||||
scf.yield %e, %v2 : tensor<?xf32>, vector<5xf32>
|
||||
}
|
||||
// CHECK: __inplace_results_attr__ = ["true", "none"]
|
||||
|
||||
// Use %t3 in some way without reading it, so that it does not get DCE'd.
|
||||
// CHECK: linalg.generic
|
||||
// CHECK-SAME: __inplace_results_attr__ = ["true"]
|
||||
%o = linalg.generic #trait outs (%t3 : tensor<?xf32>) {
|
||||
^bb(%0: f32) :
|
||||
linalg.yield %cst : f32
|
||||
} -> (tensor<?xf32>)
|
||||
|
||||
return %o, %v3 : tensor<?xf32>, vector<5xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
#accesses = [
|
||||
affine_map<(i) -> (i)>
|
||||
]
|
||||
#trait = {
|
||||
indexing_maps = #accesses,
|
||||
iterator_types = ["parallel"]
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @non_reading_scf_for
|
||||
func @non_reading_scf_for(%t1: tensor<?xf32> {linalg.inplaceable = true},
|
||||
%s: index, %v: vector<5xf32>) -> (tensor<?xf32>, vector<5xf32>) {
|
||||
|
||||
%c0 = arith.constant 0 : index
|
||||
%c1 = arith.constant 1 : index
|
||||
%cst = arith.constant 0.0 : f32
|
||||
|
||||
// Write to %t1.
|
||||
// CHECK: vector.transfer_write
|
||||
// CHECK-SAME: __inplace_results_attr__ = ["true"]
|
||||
%t3 = vector.transfer_write %v, %t1[%s] : vector<5xf32>, tensor<?xf32>
|
||||
|
||||
// This loop does not read from %t1. It only writes to it.
|
||||
// CHECK: scf.for
|
||||
%r, %v3 = scf.for %i = %c0 to %s step %c1 iter_args(%t2 = %t1, %v0 = %v) -> (tensor<?xf32>, vector<5xf32>) {
|
||||
// CHECK: tensor.extract_slice
|
||||
// CHECK-SAME: __inplace_results_attr__ = ["true"]
|
||||
%e = tensor.extract_slice %t2[%s][%s][1] : tensor<?xf32> to tensor<?xf32>
|
||||
|
||||
// Write to %t1 via alias. (Overwrite %t3.)
|
||||
// CHECK: linalg.generic
|
||||
// CHECK-SAME: __inplace_results_attr__ = ["true"]
|
||||
%o2 = linalg.generic #trait outs (%e : tensor<?xf32>) {
|
||||
^bb(%0: f32) :
|
||||
linalg.yield %cst : f32
|
||||
} -> (tensor<?xf32>)
|
||||
|
||||
// Read overwritten value. This is not a read of %t1.
|
||||
%v2 = vector.transfer_read %o2[%s], %cst : tensor<?xf32>, vector<5xf32>
|
||||
scf.yield %o2, %v2 : tensor<?xf32>, vector<5xf32>
|
||||
}
|
||||
|
||||
// Use %t3 in some way without reading it, so that it does not get DCE'd.
|
||||
// CHECK: linalg.generic
|
||||
// CHECK-SAME: __inplace_results_attr__ = ["true"]
|
||||
%o = linalg.generic #trait outs (%t3 : tensor<?xf32>) {
|
||||
^bb(%0: f32) :
|
||||
linalg.yield %cst : f32
|
||||
} -> (tensor<?xf32>)
|
||||
|
||||
return %o, %v3 : tensor<?xf32>, vector<5xf32>
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue