[mlir][linalg][bufferize] Support scf::IfOp

This commit adds support for scf::IfOp to comprehensive bufferization. Support is currently limited to cases where both branches yield tensors that bufferize to the same buffer.

To keep the analysis simple, scf::IfOps are treated as memory writes for analysis purposes, even if no op inside any branch is writing. (scf::ForOps are handled in the same way.)

Differential Revision: https://reviews.llvm.org/D111929
This commit is contained in:
Matthias Springer 2021-10-22 09:58:41 +09:00
parent 4976be1e95
commit 3bbc869e2e
4 changed files with 449 additions and 24 deletions

View File

@ -442,6 +442,7 @@ static bool hasKnownBufferizationAliasingBehavior(Operation *op) {
ConstantOp,
tensor::DimOp,
ExtractSliceOp,
scf::IfOp,
scf::ForOp,
InsertSliceOp,
InitTensorOp,
@ -550,6 +551,16 @@ static OpResult getInplaceableOpResult(OpOperand &opOperand) {
// clang-format on
}
/// Either one of the corresponding yield values from the then/else branches
/// may alias with the result.
///
/// `result` must be an OpResult of `op`. Both yield operands at the result's
/// position are added to `operands`, since either branch may be taken at
/// runtime.
static void populateAliasingOpOperands(scf::IfOp op, OpResult result,
                                       SmallVector<OpOperand *> &operands) {
  // An OpResult knows its own position; no need for a linear scan over
  // op->getOpResults() with std::distance/llvm::find.
  size_t resultNum = result.getResultNumber();
  operands.push_back(&op.thenYield()->getOpOperand(resultNum));
  operands.push_back(&op.elseYield()->getOpOperand(resultNum));
}
/// Determine which OpOperand* will alias with `result` if the op is bufferized
/// in place. Note that multiple OpOperands may potentially alias with an
/// OpResult. E.g.: std.select in the future.
@ -561,6 +572,7 @@ static SmallVector<OpOperand *> getAliasingOpOperand(OpResult result) {
TypeSwitch<Operation *>(result.getDefiningOp())
.Case([&](tensor::CastOp op) { r.push_back(&op->getOpOperand(0)); })
.Case([&](ExtractSliceOp op) { r.push_back(&op->getOpOperand(0)); })
.Case([&](scf::IfOp op) { populateAliasingOpOperands(op, result, r); })
// In the case of scf::ForOp, this currently assumes the iter_args / yield
// are 1-1. This may fail and is verified at the end.
// TODO: update this.
@ -730,6 +742,19 @@ BufferizationAliasInfo::BufferizationAliasInfo(Operation *rootOp) {
if (bbArg.getType().isa<TensorType>())
createAliasInfoEntry(bbArg);
});
// The return value of an scf::IfOp aliases with both yield values.
rootOp->walk([&](scf::IfOp ifOp) {
if (ifOp->getNumResults() > 0) {
for (auto it : llvm::zip(ifOp.thenYield().results(),
ifOp.elseYield().results(), ifOp.results())) {
aliasInfo.unionSets(std::get<0>(it), std::get<1>(it));
aliasInfo.unionSets(std::get<0>(it), std::get<2>(it));
equivalentInfo.unionSets(std::get<0>(it), std::get<1>(it));
equivalentInfo.unionSets(std::get<0>(it), std::get<2>(it));
}
}
});
}
/// Add a new entry for `v` in the `aliasInfo` and `equivalentInfo`. In the
@ -834,13 +859,28 @@ void BufferizationAliasInfo::bufferizeOutOfPlace(OpResult result) {
}
/// Starting from `value`, follow the use-def chain in reverse, always selecting
/// the corresponding aliasing OpOperand. Try to find and return a Value for
/// which `condition` evaluates to true.
/// the aliasing OpOperands. Find and return Values for which `condition`
/// evaluates to true. OpOperands of such matching Values are not traversed any
/// further.
///
/// When reaching the end of the chain (BlockArgument or Value without aliasing
/// OpOperands), return the last Value of the chain.
/// When reaching the end of a chain (BlockArgument or Value without aliasing
/// OpOperands), also return the last Value of that chain.
///
/// Note: The returned SetVector contains exactly one element.
/// Example:
///
/// 8
/// |
/// 6* 7* +-----+----+
/// | | | |
/// 2* 3 4* 5
/// | | | |
/// +----------+----------+----------+
/// |
/// 1
///
/// In the above example, Values with a star satisfy the condition. When
/// starting the traversal from Value 1, the resulting SetVector is:
/// { 2, 7, 8, 5 }
static llvm::SetVector<Value>
findValueInReverseUseDefChain(Value value,
std::function<bool(Value)> condition) {
@ -861,18 +901,22 @@ findValueInReverseUseDefChain(Value value,
continue;
}
assert(opOperands.size() == 1 && "multiple OpOperands not supported yet");
workingSet.insert(opOperands.front()->get());
for (OpOperand *o : opOperands)
workingSet.insert(o->get());
}
return result;
}
/// Find the Value (result) of the last preceding write of a given Value.
/// Find the Value of the last preceding write of a given Value.
///
/// Note: Unknown ops are handled conservatively and assumed to be writes.
/// Furthermore, BlockArguments are also assumed to be writes. There is no
/// analysis across block boundaries.
///
/// Note: To simplify the analysis, scf.if ops are considered writes. Treating
/// a non-writing op as a writing op may introduce unnecessary out-of-place
/// bufferizations, but is always safe from a correctness point of view.
static Value findLastPrecedingWrite(Value value) {
SetVector<Value> result =
findValueInReverseUseDefChain(value, [](Value value) {
@ -881,6 +925,8 @@ static Value findLastPrecedingWrite(Value value) {
return true;
if (!hasKnownBufferizationAliasingBehavior(op))
return true;
if (isa<scf::IfOp>(op))
return true;
SmallVector<OpOperand *> opOperands =
getAliasingOpOperand(value.cast<OpResult>());
@ -911,6 +957,21 @@ bool BufferizationAliasInfo::hasMatchingExtractSliceOp(
condition);
}
/// Return true if `a` happens before `b`, i.e., `a` or one of its ancestors
/// properly dominates `b` and `b` is not inside `a`.
static bool happensBefore(Operation *a, Operation *b,
                          const DominanceInfo &domInfo) {
  // Walk up the ancestor chain of `a`. If `b` is nested inside the current
  // ancestor, no "happens before" ordering can be established. Otherwise, if
  // the ancestor properly dominates `b`, then `a` executes before `b`.
  // TODO: Instead of isProperAncestor + properlyDominates, we should use
  // properlyDominatesImpl(a, b, /*enclosingOpOk=*/false)
  for (Operation *cur = a; cur; cur = cur->getParentOp()) {
    if (cur->isProperAncestor(b))
      return false;
    if (domInfo.properlyDominates(cur, b))
      return true;
  }
  return false;
}
/// Given sets of uses and writes, return true if there is a RaW conflict under
/// the assumption that all given reads/writes alias the same buffer and that
/// all given writes bufferize inplace.
@ -935,7 +996,6 @@ bool BufferizationAliasInfo::hasReadAfterWriteInterference(
// In the above example, if uRead is the OpOperand of reading_op, lastWrite
// is %0. Note that operations that create an alias but do not write (such
// as ExtractSliceOp) are skipped.
// TODO: With branches this should probably be a list of Values.
Value lastWrite = findLastPrecedingWrite(uRead->get());
// Look for conflicting memory writes. Potential conflicts are writes to an
@ -949,21 +1009,35 @@ bool BufferizationAliasInfo::hasReadAfterWriteInterference(
LDBG("Found potential conflict:\n");
LDBG("READ = #" << uRead->getOperandNumber() << " of "
<< printOperationInfo(readingOp) << "\n");
LDBG("WRITE = #" << printValueInfo(lastWrite) << "\n");
LDBG("CONFLICTING WRITE = #"
<< uConflictingWrite->getOperandNumber() << " of "
<< printOperationInfo(conflictingWritingOp) << "\n");
// No conflict if the readingOp dominates conflictingWritingOp, i.e., the
// write is not visible when reading.
if (domInfo.properlyDominates(readingOp, conflictingWritingOp))
if (happensBefore(readingOp, conflictingWritingOp, domInfo))
continue;
// No conflict if the conflicting write happens before the last write.
// No conflict if the reading use equals the use of the conflicting write.
// A use cannot conflict with itself. Note: Just being the same op is not
// enough. It has to be the same use.
if (uConflictingWrite == uRead)
continue;
if (scf::insideMutuallyExclusiveBranches(readingOp, conflictingWritingOp))
continue;
LDBG("WRITE = #" << printValueInfo(lastWrite) << "\n");
// No conflict if the conflicting write happens before the last
// write.
if (Operation *writingOp = lastWrite.getDefiningOp()) {
if (domInfo.properlyDominates(conflictingWritingOp, writingOp))
if (happensBefore(conflictingWritingOp, writingOp, domInfo))
// conflictingWritingOp happens before writingOp. No conflict.
continue;
// No conflict if conflictingWritingOp is contained in writingOp.
if (writingOp->isProperAncestor(conflictingWritingOp))
continue;
} else {
auto bbArg = lastWrite.cast<BlockArgument>();
Block *block = bbArg.getOwner();
@ -978,11 +1052,6 @@ bool BufferizationAliasInfo::hasReadAfterWriteInterference(
if (getAliasingOpResult(*uConflictingWrite) == lastWrite)
continue;
// No conflict if the same use is the read and the conflicting write. A
// use cannot conflict with itself.
if (uConflictingWrite == uRead)
continue;
// Special rules for matching ExtractSliceOp/InsertSliceOp pairs. If
// uRead is an InsertSliceOp...
if (auto insertSliceOp = dyn_cast<InsertSliceOp>(readingOp)) {
@ -1423,15 +1492,27 @@ static Value getResultBuffer(OpBuilder &b, OpResult result,
OpBuilder::InsertionGuard guard(b);
Operation *op = result.getOwner();
SmallVector<OpOperand *> aliasingOperands = getAliasingOpOperand(result);
// TODO: Support multiple OpOperands.
assert(aliasingOperands.size() == 1 &&
"more than 1 OpOperand not supported yet");
assert(!aliasingOperands.empty() && "could not get aliasing OpOperand");
Value operand = aliasingOperands.front()->get();
Value operandBuffer = lookup(bvm, operand);
assert(operandBuffer && "operand buffer not found");
// Make sure that all OpOperands are the same buffer. If this is not the case,
// we would have to materialize a memref value.
if (!llvm::all_of(aliasingOperands, [&](OpOperand *o) {
return lookup(bvm, o->get()) == operandBuffer;
})) {
op->emitError("result buffer is ambiguous");
return Value();
}
// If bufferizing out-of-place, allocate a new buffer.
if (getInPlace(result) != InPlaceSpec::True) {
bool needCopy =
getInPlace(result) != InPlaceSpec::True && !isa<scf::IfOp>(op);
if (needCopy) {
// Ops such as scf::IfOp can currently not bufferize out-of-place.
assert(
aliasingOperands.size() == 1 &&
"ops with multiple aliasing OpOperands cannot bufferize out-of-place");
Location loc = op->getLoc();
// Allocate the result buffer.
Value resultBuffer =
@ -1771,6 +1852,31 @@ static LogicalResult bufferize(OpBuilder &b, scf::ForOp forOp,
return success();
}
/// Bufferize an scf::IfOp: for every tensor result, look up the single buffer
/// that both branches yield (via getResultBuffer) and map the result to it.
/// Fails if the two branches do not bufferize to the same buffer.
static LogicalResult bufferize(OpBuilder &b, scf::IfOp ifOp,
                               BlockAndValueMapping &bvm,
                               BufferizationAliasInfo &aliasInfo) {
  // Take a guard before anything else.
  OpBuilder::InsertionGuard g(b);

  for (OpResult opResult : ifOp->getResults()) {
    // Non-tensor results do not bufferize; skip them.
    if (!opResult.getType().isa<TensorType>())
      continue;
    // TODO: Atm we bail on unranked TensorType because we don't know how to
    // alloc an UnrankedMemRefType + its underlying ranked MemRefType.
    assert(opResult.getType().isa<RankedTensorType>() &&
           "unsupported unranked tensor");

    Value buffer = getResultBuffer(b, opResult, bvm, aliasInfo);
    if (!buffer)
      return failure();
    // Register the new buffer and map the tensor result onto it.
    aliasInfo.createAliasInfoEntry(buffer);
    map(bvm, opResult, buffer);
  }

  return success();
}
/// FuncOp always creates TensorToMemRef ops.
static LogicalResult bufferize(OpBuilder &b, FuncOp funcOp,
BlockAndValueMapping &bvm,
@ -2038,7 +2144,6 @@ static LogicalResult bufferize(OpBuilder &b, InsertSliceOp insertSliceOp,
getResultBuffer(b, insertSliceOp->getResult(0), bvm, aliasInfo);
if (!dstMemref)
return failure();
auto dstMemrefType = dstMemref.getType().cast<MemRefType>();
Value srcMemref = lookup(bvm, insertSliceOp.source());
@ -2127,6 +2232,9 @@ static LogicalResult bufferize(OpBuilder &b, scf::YieldOp yieldOp,
return success();
}
if (auto ifOp = dyn_cast<scf::IfOp>(yieldOp->getParentOp()))
return success();
scf::ForOp forOp = dyn_cast<scf::ForOp>(yieldOp->getParentOp());
if (!forOp)
return yieldOp->emitError("expected scf::ForOp parent for scf::YieldOp");
@ -2344,6 +2452,13 @@ LogicalResult mlir::linalg::bufferizeOp(
LDBG("Begin bufferize:\n" << op << '\n');
return bufferize(b, op, bvm, aliasInfo);
})
.Case<tensor::CastOp, tensor::DimOp, ExtractSliceOp, InitTensorOp,
InsertSliceOp, tensor::ExtractOp, LinalgOp, ReturnOp,
VectorTransferOpInterface, linalg::YieldOp, scf::YieldOp,
scf::IfOp>([&](auto op) {
LDBG("Begin bufferize:\n" << op << '\n');
return bufferize(b, op, bvm, aliasInfo);
})
.Case([&](CallOpInterface op) {
LDBG("Begin bufferize:\n" << op << '\n');
if (!bufferizedFunctionTypes)

View File

@ -1087,3 +1087,291 @@ func @buffer_forwarding_no_conflict(%arg0: tensor<?xf32> {linalg.inplaceable = t
%2 = tensor.insert_slice %1 into %arg0[42] [%arg1] [1] : tensor<?xf32> into tensor<?xf32>
return %2, %2 : tensor<?xf32>, tensor<?xf32>
}
// -----
//===----------------------------------------------------------------------===//
// scf.if cases
//===----------------------------------------------------------------------===//
// This example passes analysis, but it fails when bufferizing.
// CHECK-LABEL: func @scf_if_inplace1
// Each branch directly yields a different inplaceable function argument.
// NOTE(review): the two yielded tensors are distinct values, so the branches
// presumably bufferize to different buffers — likely the case meant by the
// "fails when bufferizing" note above; confirm against the bufferizer.
func @scf_if_inplace1(%t1: tensor<?xf32> {linalg.inplaceable = true},
                      %t2: tensor<?xf32> {linalg.inplaceable = true},
                      %cond: i1) -> tensor<?xf32> {
  %r = scf.if %cond -> (tensor<?xf32>) {
    scf.yield %t1 : tensor<?xf32>
  } else {
    scf.yield %t2 : tensor<?xf32>
  }
  return %r : tensor<?xf32>
}
// CHECK-LABEL: func @scf_if_inplace2
// The then-branch yields %t1 unchanged; the else-branch writes into %t1 and
// yields the result. The transfer_write is expected to bufferize inplace.
func @scf_if_inplace2(%t1: tensor<?xf32> {linalg.inplaceable = true},
                      %v: vector<5xf32>, %idx: index,
                      %cond: i1) -> tensor<?xf32> {
  %r = scf.if %cond -> (tensor<?xf32>) {
    scf.yield %t1 : tensor<?xf32>
  } else {
    // CHECK: vector.transfer_write
    // CHECK-SAME: {__inplace_results_attr__ = ["true"]
    %t2 = vector.transfer_write %v, %t1[%idx] : vector<5xf32>, tensor<?xf32>
    scf.yield %t2 : tensor<?xf32>
  }
  return %r : tensor<?xf32>
}
// -----
// CHECK-LABEL: func @scf_if_inplace3
// One branch writes into a slice of %t1, the other writes into %t1 directly
// (an alias). All ops are expected to bufferize inplace.
func @scf_if_inplace3(%t1: tensor<?xf32> {linalg.inplaceable = true},
                      %v1: vector<5xf32>, %v2: vector<5xf32>, %idx: index,
                      %cond: i1) -> tensor<?xf32> {
  // CHECK: tensor.extract_slice
  // CHECK-SAME: {__inplace_results_attr__ = ["true"]
  %e = tensor.extract_slice %t1[%idx][%idx][1] : tensor<?xf32> to tensor<?xf32>
  %r = scf.if %cond -> (tensor<?xf32>) {
    // CHECK: vector.transfer_write
    // CHECK-SAME: {__inplace_results_attr__ = ["true"]
    %t2 = vector.transfer_write %v1, %e[%idx] : vector<5xf32>, tensor<?xf32>
    scf.yield %t2 : tensor<?xf32>
  } else {
    // Writing the same tensor through an alias. This is OK.
    // CHECK: vector.transfer_write
    // CHECK-SAME: {__inplace_results_attr__ = ["true"]
    %t3 = vector.transfer_write %v2, %t1[%idx] : vector<5xf32>, tensor<?xf32>
    scf.yield %t3 : tensor<?xf32>
  }
  return %r : tensor<?xf32>
}
// -----
// CHECK-LABEL: func @scf_if_in_place4
// Two chained scf.if ops: the second one only forwards the result of the
// first. Reading the forwarded value afterwards does not create a conflict.
func @scf_if_in_place4(%t1: tensor<?xf32> {linalg.inplaceable = true},
                       %v: vector<5xf32>, %idx: index,
                       %cond: i1, %cond2: i1) -> (tensor<?xf32>, vector<10xf32>) {
  %cst = arith.constant 0.0 : f32
  %r = scf.if %cond -> (tensor<?xf32>) {
    scf.yield %t1 : tensor<?xf32>
  } else {
    // CHECK: vector.transfer_write
    // CHECK-SAME: {__inplace_results_attr__ = ["true"]
    %t2 = vector.transfer_write %v, %t1[%idx] : vector<5xf32>, tensor<?xf32>
    scf.yield %t2 : tensor<?xf32>
  }
  %r_alias = scf.if %cond2 -> (tensor<?xf32>) {
    // Reading %r is OK. No conflict.
    scf.yield %r : tensor<?xf32>
  } else {
    scf.yield %r : tensor<?xf32>
  }
  %v2 = vector.transfer_read %r_alias[%idx], %cst : tensor<?xf32>, vector<10xf32>
  return %r_alias, %v2 : tensor<?xf32>, vector<10xf32>
}
// -----
// CHECK-LABEL: func @scf_if_inplace5
// Both branches yield a slice of %t1 taken at the same offset/size; the
// result is then inserted back at that same position. Everything bufferizes
// inplace.
func @scf_if_inplace5(%t1: tensor<?xf32> {linalg.inplaceable = true},
                      %idx: index, %cond: i1) -> tensor<?xf32> {
  %r = scf.if %cond -> (tensor<?xf32>) {
    // CHECK: tensor.extract_slice
    // CHECK-SAME: {__inplace_results_attr__ = ["true"]
    %e = tensor.extract_slice %t1[%idx][%idx][1] : tensor<?xf32> to tensor<?xf32>
    scf.yield %e : tensor<?xf32>
  } else {
    // CHECK: tensor.extract_slice
    // CHECK-SAME: {__inplace_results_attr__ = ["true"]
    %f = tensor.extract_slice %t1[%idx][%idx][1] : tensor<?xf32> to tensor<?xf32>
    scf.yield %f : tensor<?xf32>
  }
  // Inserting into an equivalent tensor at the same offset. This bufferizes
  // inplace.
  // CHECK: tensor.insert_slice
  // CHECK-SAME: {__inplace_results_attr__ = ["true"]
  %r2 = tensor.insert_slice %r into %t1[%idx][%idx][1] : tensor<?xf32> into tensor<?xf32>
  return %r2 : tensor<?xf32>
}
// -----
// CHECK-LABEL: func @scf_if_inplace6
// Nested scf.if ops: every branch (including both branches of the inner if)
// writes into %t1; all writes bufferize inplace.
func @scf_if_inplace6(%t1: tensor<?xf32> {linalg.inplaceable = true},
                      %v1: vector<5xf32>, %v2: vector<5xf32>,
                      %v3: vector<5xf32>, %idx: index,
                      %cond: i1, %cond2: i1) -> tensor<?xf32> {
  // Test nested scf.if ops.
  %r = scf.if %cond -> (tensor<?xf32>) {
    %t2 = scf.if %cond2 -> (tensor<?xf32>) {
      // CHECK: vector.transfer_write
      // CHECK-SAME: {__inplace_results_attr__ = ["true"]
      %t3 = vector.transfer_write %v1, %t1[%idx] : vector<5xf32>, tensor<?xf32>
      scf.yield %t3 : tensor<?xf32>
    } else {
      // CHECK: vector.transfer_write
      // CHECK-SAME: {__inplace_results_attr__ = ["true"]
      %t4 = vector.transfer_write %v3, %t1[%idx] : vector<5xf32>, tensor<?xf32>
      scf.yield %t4 : tensor<?xf32>
    }
    scf.yield %t2 : tensor<?xf32>
  } else {
    // CHECK: vector.transfer_write
    // CHECK-SAME: {__inplace_results_attr__ = ["true"]
    %t3 = vector.transfer_write %v2, %t1[%idx] : vector<5xf32>, tensor<?xf32>
    scf.yield %t3 : tensor<?xf32>
  }
  return %r : tensor<?xf32>
}
// -----
// CHECK-LABEL: func @scf_if_inplace7
// Multi-result scf.if. The else-branch reads the original %t1 after writing,
// so its write must be out-of-place; the then-branch's write stays inplace.
func @scf_if_inplace7(%t1: tensor<?xf32> {linalg.inplaceable = true},
                      %v1: vector<5xf32>, %v2: vector<5xf32>, %idx: index,
                      %idx2: index, %cond: i1) -> (tensor<?xf32>, vector<5xf32>) {
  %cst = arith.constant 0.0 : f32
  %r, %v_r2 = scf.if %cond -> (tensor<?xf32>, vector<5xf32>) {
    // CHECK: vector.transfer_write
    // CHECK-SAME: {__inplace_results_attr__ = ["true"]
    %t2 = vector.transfer_write %v1, %t1[%idx] : vector<5xf32>, tensor<?xf32>
    scf.yield %t2, %v1 : tensor<?xf32>, vector<5xf32>
  } else {
    // Writing the same tensor through an alias.
    // CHECK: vector.transfer_write
    // CHECK-SAME: {__inplace_results_attr__ = ["false"]
    %t3 = vector.transfer_write %v2, %t1[%idx] : vector<5xf32>, tensor<?xf32>
    // Read the original value of %t1. This requires the write in this branch
    // to be out-of-place. But the write in the other branch can still be
    // inplace.
    %v_r = vector.transfer_read %t1[%idx2], %cst : tensor<?xf32>, vector<5xf32>
    scf.yield %t3, %v_r : tensor<?xf32>, vector<5xf32>
  }
  return %r, %v_r2 : tensor<?xf32>, vector<5xf32>
}
// -----
// CHECK-LABEL: func @scf_if_out_of_place1a
// The scf.if result aliases %t1 (the else-branch yields it directly), so
// inserting %r into %t1 at a different offset is a read-after-write conflict:
// the insert_slice must bufferize out-of-place.
func @scf_if_out_of_place1a(%t1: tensor<?xf32> {linalg.inplaceable = true},
                            %idx: index, %idx2: index,
                            %cond: i1) -> tensor<?xf32> {
  %r = scf.if %cond -> (tensor<?xf32>) {
    // CHECK: tensor.extract_slice
    // CHECK-SAME: {__inplace_results_attr__ = ["true"]
    %e = tensor.extract_slice %t1[%idx][%idx][1] : tensor<?xf32> to tensor<?xf32>
    scf.yield %e : tensor<?xf32>
  } else {
    scf.yield %t1 : tensor<?xf32>
  }
  // Reading from and writing to the same tensor via different args. This is a
  // conflict.
  // CHECK: tensor.insert_slice
  // CHECK-SAME: {__inplace_results_attr__ = ["false"]
  %r2 = tensor.insert_slice %r into %t1[%idx2][%idx2][1] : tensor<?xf32> into tensor<?xf32>
  return %r2 : tensor<?xf32>
}
// -----
// CHECK-LABEL: func @scf_if_out_of_place1b
// Both branches yield extract_slices of %t1 at different offsets. The alias
// with %t1 is discovered while analyzing the extract_slices, so they go
// out-of-place and the insert_slice stays inplace.
func @scf_if_out_of_place1b(%t1: tensor<?xf32> {linalg.inplaceable = true},
                            %idx: index, %idx2: index, %idx3: index,
                            %cond: i1) -> tensor<?xf32> {
  %r = scf.if %cond -> (tensor<?xf32>) {
    // CHECK: tensor.extract_slice
    // CHECK-SAME: {__inplace_results_attr__ = ["false"]
    %e = tensor.extract_slice %t1[%idx][%idx][1] : tensor<?xf32> to tensor<?xf32>
    scf.yield %e : tensor<?xf32>
  } else {
    // CHECK: tensor.extract_slice
    // CHECK-SAME: {__inplace_results_attr__ = ["false"]
    %f = tensor.extract_slice %t1[%idx2][%idx2][1] : tensor<?xf32> to tensor<?xf32>
    scf.yield %f : tensor<?xf32>
  }
  // Reading from and writing to the same tensor via different args. This is a
  // conflict. In contrast to scf_if_out_of_place1a, the fact that %r aliases
  // with %t1 is only detected when analyzing the tensor.extract_slices. That's
  // why the tensor.insert_slice is inplace and the two extract_slices are
  // out-of-place.
  // CHECK: tensor.insert_slice
  // CHECK-SAME: {__inplace_results_attr__ = ["true"]
  %r2 = tensor.insert_slice %r into %t1[%idx3][%idx3][1] : tensor<?xf32> into tensor<?xf32>
  return %r2 : tensor<?xf32>
}
// -----
// CHECK-LABEL: func @scf_if_out_of_place1c
// Variant of scf_if_out_of_place1b where the else-branch slice matches the
// insert_slice position; the analysis is conservative and still marks both
// extract_slices out-of-place.
func @scf_if_out_of_place1c(%t1: tensor<?xf32> {linalg.inplaceable = true},
                            %idx: index, %idx2: index, %cond: i1) -> tensor<?xf32> {
  %r = scf.if %cond -> (tensor<?xf32>) {
    // CHECK: tensor.extract_slice
    // CHECK-SAME: {__inplace_results_attr__ = ["false"]
    %e = tensor.extract_slice %t1[%idx][%idx][1] : tensor<?xf32> to tensor<?xf32>
    scf.yield %e : tensor<?xf32>
  } else {
    // TODO: This one could bufferize inplace, but the analysis is too restrictive.
    // CHECK: tensor.extract_slice
    // CHECK-SAME: {__inplace_results_attr__ = ["false"]
    %f = tensor.extract_slice %t1[%idx2][%idx2][1] : tensor<?xf32> to tensor<?xf32>
    scf.yield %f : tensor<?xf32>
  }
  // CHECK: tensor.insert_slice
  // CHECK-SAME: {__inplace_results_attr__ = ["true"]
  %r2 = tensor.insert_slice %r into %t1[%idx2][%idx2][1] : tensor<?xf32> into tensor<?xf32>
  return %r2 : tensor<?xf32>
}
// -----
// CHECK-LABEL: func @scf_if_out_of_place2
// The old value of %t1 is read after the scf.if, so the transfer_write inside
// the else-branch must bufferize out-of-place.
func @scf_if_out_of_place2(%t1: tensor<?xf32> {linalg.inplaceable = true},
                           %v: vector<5xf32>, %idx: index,
                           %cond: i1) -> (tensor<?xf32>, vector<10xf32>) {
  %cst = arith.constant 0.0 : f32
  %r = scf.if %cond -> (tensor<?xf32>) {
    scf.yield %t1 : tensor<?xf32>
  } else {
    // CHECK: vector.transfer_write
    // CHECK-SAME: {__inplace_results_attr__ = ["false"]
    %t2 = vector.transfer_write %v, %t1[%idx] : vector<5xf32>, tensor<?xf32>
    scf.yield %t2 : tensor<?xf32>
  }
  // Read the old value of %t1. Forces the transfer_write to bufferize
  // out-of-place.
  %v2 = vector.transfer_read %t1[%idx], %cst : tensor<?xf32>, vector<10xf32>
  return %r, %v2 : tensor<?xf32>, vector<10xf32>
}
// -----
// CHECK-LABEL: func @scf_if_out_of_place3
// The second scf.if yields %t1, which counts as a read of %t1 — a conflict
// with the (potential) write in the first scf.if, forcing it out-of-place.
func @scf_if_out_of_place3(%t1: tensor<?xf32> {linalg.inplaceable = true},
                           %v: vector<5xf32>, %idx: index,
                           %cond: i1, %cond2: i1) -> (tensor<?xf32>, vector<10xf32>) {
  %cst = arith.constant 0.0 : f32
  %r = scf.if %cond -> (tensor<?xf32>) {
    scf.yield %t1 : tensor<?xf32>
  } else {
    // CHECK: vector.transfer_write
    // CHECK-SAME: {__inplace_results_attr__ = ["false"]
    %t2 = vector.transfer_write %v, %t1[%idx] : vector<5xf32>, tensor<?xf32>
    scf.yield %t2 : tensor<?xf32>
  }
  %t1_alias = scf.if %cond2 -> (tensor<?xf32>) {
    // scf.yield bufferizes to a read. That is a conflict in this example.
    scf.yield %t1 : tensor<?xf32>
  } else {
    scf.yield %t1 : tensor<?xf32>
  }
  %v2 = vector.transfer_read %t1_alias[%idx], %cst : tensor<?xf32>, vector<10xf32>
  return %r, %v2 : tensor<?xf32>, vector<10xf32>
}

View File

@ -113,8 +113,8 @@ func @extract_slice_fun(%A : tensor<?xf32> {linalg.inplaceable = true})
func @scf_yield(%b : i1, %A : tensor<4xf32>, %B : tensor<4xf32>) -> tensor<4xf32>
{
// expected-error @+1 {{result buffer is ambiguous}}
%r = scf.if %b -> (tensor<4xf32>) {
// expected-error @+1 {{expected scf::ForOp parent for scf::YieldOp}}
scf.yield %A : tensor<4xf32>
} else {
scf.yield %B : tensor<4xf32>

View File

@ -861,3 +861,25 @@ func @buffer_forwarding_no_conflict(
return %r1: tensor<?xf32>
}
// -----
// CHECK-LABEL: func @scf_if_inplace(
// CHECK-SAME: %[[cond:.*]]: i1, %[[t1:.*]]: memref<?xf32{{.*}}>, %[[v:.*]]: vector
// End-to-end bufferization test: both branches bufferize into %t1's buffer,
// so the scf.if is rewritten to carry no results — only the transfer_write
// remains, in the else-branch, writing directly into the %t1 memref.
func @scf_if_inplace(%cond: i1,
                     %t1: tensor<?xf32> {linalg.inplaceable = true},
                     %v: vector<5xf32>, %idx: index) -> tensor<?xf32> {
  // CHECK: scf.if %[[cond]] {
  // CHECK-NEXT: } else {
  // CHECK-NEXT: vector.transfer_write %[[v]], %[[t1]]
  // CHECK-NEXT: }
  // CHECK-NEXT: return
  %r = scf.if %cond -> (tensor<?xf32>) {
    scf.yield %t1 : tensor<?xf32>
  } else {
    %t2 = vector.transfer_write %v, %t1[%idx] : vector<5xf32>, tensor<?xf32>
    scf.yield %t2 : tensor<?xf32>
  }
  return %r : tensor<?xf32>
}