forked from OSchip/llvm-project
[mlir][linalg][bufferize] Support scf::IfOp
This commit adds support for scf::IfOp to comprehensive bufferization. Support is currently limited to cases where both branches yield tensors that bufferize to the same buffer. To keep the analysis simple, scf::IfOps are treated as memory writes for analysis purposes, even if no op inside any branch is writing. (scf::ForOps are handled in the same way.) Differential Revision: https://reviews.llvm.org/D111929
This commit is contained in:
parent
4976be1e95
commit
3bbc869e2e
|
@ -442,6 +442,7 @@ static bool hasKnownBufferizationAliasingBehavior(Operation *op) {
|
|||
ConstantOp,
|
||||
tensor::DimOp,
|
||||
ExtractSliceOp,
|
||||
scf::IfOp,
|
||||
scf::ForOp,
|
||||
InsertSliceOp,
|
||||
InitTensorOp,
|
||||
|
@ -550,6 +551,16 @@ static OpResult getInplaceableOpResult(OpOperand &opOperand) {
|
|||
// clang-format on
|
||||
}
|
||||
|
||||
/// Either one of the corresponding yield values from the then/else branches
|
||||
/// may alias with the result.
///
/// Appends to `operands` the operand of the then-yield and the operand of the
/// else-yield that sit at the same position as `result` among the results of
/// `op`. The then-branch operand is appended first.
|
||||
static void populateAliasingOpOperands(scf::IfOp op, OpResult result,
|
||||
SmallVector<OpOperand *> &operands) {
|
||||
// Position of `result` within the results of `op`; the same index is used
// below to pick the matching operand of each yield.
size_t resultNum = std::distance(op->getOpResults().begin(),
|
||||
llvm::find(op->getOpResults(), result));
|
||||
operands.push_back(&op.thenYield()->getOpOperand(resultNum));
|
||||
// NOTE(review): assumes `op` has an else branch with a yield; presumably
// guaranteed whenever the op has results -- confirm.
operands.push_back(&op.elseYield()->getOpOperand(resultNum));
|
||||
}
|
||||
|
||||
/// Determine which OpOperand* will alias with `result` if the op is bufferized
|
||||
/// in place. Note that multiple OpOperands can potentially alias with an
|
||||
/// OpResult. E.g.: std.select in the future.
|
||||
|
@ -561,6 +572,7 @@ static SmallVector<OpOperand *> getAliasingOpOperand(OpResult result) {
|
|||
TypeSwitch<Operation *>(result.getDefiningOp())
|
||||
.Case([&](tensor::CastOp op) { r.push_back(&op->getOpOperand(0)); })
|
||||
.Case([&](ExtractSliceOp op) { r.push_back(&op->getOpOperand(0)); })
|
||||
.Case([&](scf::IfOp op) { populateAliasingOpOperands(op, result, r); })
|
||||
// In the case of scf::ForOp, this currently assumes the iter_args / yield
|
||||
// are 1-1. This may fail and is verified at the end.
|
||||
// TODO: update this.
|
||||
|
@ -730,6 +742,19 @@ BufferizationAliasInfo::BufferizationAliasInfo(Operation *rootOp) {
|
|||
if (bbArg.getType().isa<TensorType>())
|
||||
createAliasInfoEntry(bbArg);
|
||||
});
|
||||
|
||||
// The return value of an scf::IfOp aliases with both yield values.
|
||||
rootOp->walk([&](scf::IfOp ifOp) {
|
||||
if (ifOp->getNumResults() > 0) {
|
||||
for (auto it : llvm::zip(ifOp.thenYield().results(),
|
||||
ifOp.elseYield().results(), ifOp.results())) {
|
||||
aliasInfo.unionSets(std::get<0>(it), std::get<1>(it));
|
||||
aliasInfo.unionSets(std::get<0>(it), std::get<2>(it));
|
||||
equivalentInfo.unionSets(std::get<0>(it), std::get<1>(it));
|
||||
equivalentInfo.unionSets(std::get<0>(it), std::get<2>(it));
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
/// Add a new entry for `v` in the `aliasInfo` and `equivalentInfo`. In the
|
||||
|
@ -834,13 +859,28 @@ void BufferizationAliasInfo::bufferizeOutOfPlace(OpResult result) {
|
|||
}
|
||||
|
||||
/// Starting from `value`, follow the use-def chain in reverse, always selecting
|
||||
/// the corresponding aliasing OpOperand. Try to find and return a Value for
|
||||
/// which `condition` evaluates to true.
|
||||
/// the aliasing OpOperands. Find and return Values for which `condition`
|
||||
/// evaluates to true. OpOperands of such matching Values are not traversed any
|
||||
/// further.
|
||||
///
|
||||
/// When reaching the end of the chain (BlockArgument or Value without aliasing
|
||||
/// OpOperands), return the last Value of the chain.
|
||||
/// When reaching the end of a chain (BlockArgument or Value without aliasing
|
||||
/// OpOperands), also return the last Value of that chain.
|
||||
///
|
||||
/// Note: The returned SetVector contains exactly one element.
|
||||
/// Example:
|
||||
///
|
||||
/// 8
|
||||
/// |
|
||||
/// 6* 7* +-----+----+
|
||||
/// | | | |
|
||||
/// 2* 3 4* 5
|
||||
/// | | | |
|
||||
/// +----------+----------+----------+
|
||||
/// |
|
||||
/// 1
|
||||
///
|
||||
/// In the above example, Values with a star satisfy the condition. When
|
||||
/// starting the traversal from Value 1, the resulting SetVector is:
|
||||
/// { 2, 7, 8, 5 }
|
||||
static llvm::SetVector<Value>
|
||||
findValueInReverseUseDefChain(Value value,
|
||||
std::function<bool(Value)> condition) {
|
||||
|
@ -861,18 +901,22 @@ findValueInReverseUseDefChain(Value value,
|
|||
continue;
|
||||
}
|
||||
|
||||
assert(opOperands.size() == 1 && "multiple OpOperands not supported yet");
|
||||
workingSet.insert(opOperands.front()->get());
|
||||
for (OpOperand *o : opOperands)
|
||||
workingSet.insert(o->get());
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
/// Find the Value (result) of the last preceding write of a given Value.
|
||||
/// Find the Value of the last preceding write of a given Value.
|
||||
///
|
||||
/// Note: Unknown ops are handled conservatively and assumed to be writes.
|
||||
/// Furthermore, BlockArguments are also assumed to be writes. There is no
|
||||
/// analysis across block boundaries.
|
||||
///
|
||||
/// Note: To simplify the analysis, scf.if ops are considered writes. Treating
|
||||
/// a non-writing op as a writing op may introduce unnecessary out-of-place
|
||||
/// bufferizations, but is always safe from a correctness point of view.
|
||||
static Value findLastPrecedingWrite(Value value) {
|
||||
SetVector<Value> result =
|
||||
findValueInReverseUseDefChain(value, [](Value value) {
|
||||
|
@ -881,6 +925,8 @@ static Value findLastPrecedingWrite(Value value) {
|
|||
return true;
|
||||
if (!hasKnownBufferizationAliasingBehavior(op))
|
||||
return true;
|
||||
if (isa<scf::IfOp>(op))
|
||||
return true;
|
||||
|
||||
SmallVector<OpOperand *> opOperands =
|
||||
getAliasingOpOperand(value.cast<OpResult>());
|
||||
|
@ -911,6 +957,21 @@ bool BufferizationAliasInfo::hasMatchingExtractSliceOp(
|
|||
condition);
|
||||
}
|
||||
|
||||
/// Return true if `a` happens before `b`, i.e., `a` or one of its ancestors
|
||||
/// properly dominates `b` and `b` is not inside `a`.
///
/// The check starts at `a` and is repeated for each parent op of `a` in turn.
/// If no op on that chain properly dominates `b`, the result is false.
|
||||
static bool happensBefore(Operation *a, Operation *b,
|
||||
const DominanceInfo &domInfo) {
|
||||
do {
|
||||
// TODO: Instead of isProperAncestor + properlyDominates, we should use
|
||||
// properlyDominatesImpl(a, b, /*enclosingOpOk=*/false)
|
||||
// `b` is nested inside the current `a`: this is not treated as "before".
if (a->isProperAncestor(b))
|
||||
return false;
|
||||
if (domInfo.properlyDominates(a, b))
|
||||
return true;
|
||||
// Move `a` to its parent op; the loop stops once getParentOp() yields null.
} while ((a = a->getParentOp()));
|
||||
return false;
|
||||
}
|
||||
|
||||
/// Given sets of uses and writes, return true if there is a RaW conflict under
|
||||
/// the assumption that all given reads/writes alias the same buffer and that
|
||||
/// all given writes bufferize inplace.
|
||||
|
@ -935,7 +996,6 @@ bool BufferizationAliasInfo::hasReadAfterWriteInterference(
|
|||
// In the above example, if uRead is the OpOperand of reading_op, lastWrite
|
||||
// is %0. Note that operations that create an alias but do not write (such
|
||||
// as ExtractSliceOp) are skipped.
|
||||
// TODO: With branches this should probably be a list of Values.
|
||||
Value lastWrite = findLastPrecedingWrite(uRead->get());
|
||||
|
||||
// Look for conflicting memory writes. Potential conflicts are writes to an
|
||||
|
@ -949,21 +1009,35 @@ bool BufferizationAliasInfo::hasReadAfterWriteInterference(
|
|||
LDBG("Found potential conflict:\n");
|
||||
LDBG("READ = #" << uRead->getOperandNumber() << " of "
|
||||
<< printOperationInfo(readingOp) << "\n");
|
||||
LDBG("WRITE = #" << printValueInfo(lastWrite) << "\n");
|
||||
LDBG("CONFLICTING WRITE = #"
|
||||
<< uConflictingWrite->getOperandNumber() << " of "
|
||||
<< printOperationInfo(conflictingWritingOp) << "\n");
|
||||
|
||||
// No conflict if the readingOp dominates conflictingWritingOp, i.e., the
|
||||
// write is not visible when reading.
|
||||
if (domInfo.properlyDominates(readingOp, conflictingWritingOp))
|
||||
if (happensBefore(readingOp, conflictingWritingOp, domInfo))
|
||||
continue;
|
||||
|
||||
// No conflict if the conflicting write happens before the last write.
|
||||
// No conflict if the reading use equals the use of the conflicting write.
|
||||
// A use cannot conflict with itself. Note: Just being the same op is not
|
||||
// enough. It has to be the same use.
|
||||
if (uConflictingWrite == uRead)
|
||||
continue;
|
||||
|
||||
if (scf::insideMutuallyExclusiveBranches(readingOp, conflictingWritingOp))
|
||||
continue;
|
||||
|
||||
LDBG("WRITE = #" << printValueInfo(lastWrite) << "\n");
|
||||
|
||||
// No conflict if the conflicting write happens before the last
|
||||
// write.
|
||||
if (Operation *writingOp = lastWrite.getDefiningOp()) {
|
||||
if (domInfo.properlyDominates(conflictingWritingOp, writingOp))
|
||||
if (happensBefore(conflictingWritingOp, writingOp, domInfo))
|
||||
// conflictingWritingOp happens before writingOp. No conflict.
|
||||
continue;
|
||||
// No conflict if conflictingWritingOp is contained in writingOp.
|
||||
if (writingOp->isProperAncestor(conflictingWritingOp))
|
||||
continue;
|
||||
} else {
|
||||
auto bbArg = lastWrite.cast<BlockArgument>();
|
||||
Block *block = bbArg.getOwner();
|
||||
|
@ -978,11 +1052,6 @@ bool BufferizationAliasInfo::hasReadAfterWriteInterference(
|
|||
if (getAliasingOpResult(*uConflictingWrite) == lastWrite)
|
||||
continue;
|
||||
|
||||
// No conflict is the same use is the read and the conflicting write. A
|
||||
// use cannot conflict with itself.
|
||||
if (uConflictingWrite == uRead)
|
||||
continue;
|
||||
|
||||
// Special rules for matching ExtractSliceOp/InsertSliceOp pairs. If
|
||||
// uRead is an InsertSliceOp...
|
||||
if (auto insertSliceOp = dyn_cast<InsertSliceOp>(readingOp)) {
|
||||
|
@ -1423,15 +1492,27 @@ static Value getResultBuffer(OpBuilder &b, OpResult result,
|
|||
OpBuilder::InsertionGuard guard(b);
|
||||
Operation *op = result.getOwner();
|
||||
SmallVector<OpOperand *> aliasingOperands = getAliasingOpOperand(result);
|
||||
// TODO: Support multiple OpOperands.
|
||||
assert(aliasingOperands.size() == 1 &&
|
||||
"more than 1 OpOperand not supported yet");
|
||||
assert(!aliasingOperands.empty() && "could not get aliasing OpOperand");
|
||||
Value operand = aliasingOperands.front()->get();
|
||||
Value operandBuffer = lookup(bvm, operand);
|
||||
assert(operandBuffer && "operand buffer not found");
|
||||
// Make sure that all OpOperands are the same buffer. If this is not the case,
|
||||
// we would have to materialize a memref value.
|
||||
if (!llvm::all_of(aliasingOperands, [&](OpOperand *o) {
|
||||
return lookup(bvm, o->get()) == operandBuffer;
|
||||
})) {
|
||||
op->emitError("result buffer is ambiguous");
|
||||
return Value();
|
||||
}
|
||||
|
||||
// If bufferizing out-of-place, allocate a new buffer.
|
||||
if (getInPlace(result) != InPlaceSpec::True) {
|
||||
bool needCopy =
|
||||
getInPlace(result) != InPlaceSpec::True && !isa<scf::IfOp>(op);
|
||||
if (needCopy) {
|
||||
// Ops such as scf::IfOp can currently not bufferize out-of-place.
|
||||
assert(
|
||||
aliasingOperands.size() == 1 &&
|
||||
"ops with multiple aliasing OpOperands cannot bufferize out-of-place");
|
||||
Location loc = op->getLoc();
|
||||
// Allocate the result buffer.
|
||||
Value resultBuffer =
|
||||
|
@ -1771,6 +1852,31 @@ static LogicalResult bufferize(OpBuilder &b, scf::ForOp forOp,
|
|||
return success();
|
||||
}
|
||||
|
||||
/// Bufferize the results of an scf::IfOp.
///
/// For every tensor-typed result of `ifOp`, a buffer is obtained through
/// getResultBuffer, an alias-info entry is created for it and the result is
/// mapped to that buffer in `bvm`. Non-tensor results are skipped.
///
/// Returns failure() as soon as getResultBuffer returns a null Value, and
/// success() otherwise.
static LogicalResult bufferize(OpBuilder &b, scf::IfOp ifOp,
|
||||
BlockAndValueMapping &bvm,
|
||||
BufferizationAliasInfo &aliasInfo) {
|
||||
// Take a guard before anything else.
|
||||
OpBuilder::InsertionGuard g(b);
|
||||
|
||||
for (OpResult opResult : ifOp->getResults()) {
|
||||
// Only tensor results need a buffer; everything else is left alone.
if (!opResult.getType().isa<TensorType>())
|
||||
continue;
|
||||
// TODO: Atm we bail on unranked TensorType because we don't know how to
|
||||
// alloc an UnrankedMemRefType + its underlying ranked MemRefType.
|
||||
assert(opResult.getType().isa<RankedTensorType>() &&
|
||||
"unsupported unranked tensor");
|
||||
|
||||
// A null Value signals that no result buffer could be determined.
Value resultBuffer = getResultBuffer(b, opResult, bvm, aliasInfo);
|
||||
if (!resultBuffer)
|
||||
return failure();
|
||||
|
||||
// Track the buffer in the alias analysis and associate it with the tensor
// result in `bvm`.
aliasInfo.createAliasInfoEntry(resultBuffer);
|
||||
map(bvm, opResult, resultBuffer);
|
||||
}
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
/// FuncOp always creates TensorToMemRef ops.
|
||||
static LogicalResult bufferize(OpBuilder &b, FuncOp funcOp,
|
||||
BlockAndValueMapping &bvm,
|
||||
|
@ -2038,7 +2144,6 @@ static LogicalResult bufferize(OpBuilder &b, InsertSliceOp insertSliceOp,
|
|||
getResultBuffer(b, insertSliceOp->getResult(0), bvm, aliasInfo);
|
||||
if (!dstMemref)
|
||||
return failure();
|
||||
|
||||
auto dstMemrefType = dstMemref.getType().cast<MemRefType>();
|
||||
|
||||
Value srcMemref = lookup(bvm, insertSliceOp.source());
|
||||
|
@ -2127,6 +2232,9 @@ static LogicalResult bufferize(OpBuilder &b, scf::YieldOp yieldOp,
|
|||
return success();
|
||||
}
|
||||
|
||||
if (auto ifOp = dyn_cast<scf::IfOp>(yieldOp->getParentOp()))
|
||||
return success();
|
||||
|
||||
scf::ForOp forOp = dyn_cast<scf::ForOp>(yieldOp->getParentOp());
|
||||
if (!forOp)
|
||||
return yieldOp->emitError("expected scf::ForOp parent for scf::YieldOp");
|
||||
|
@ -2344,6 +2452,13 @@ LogicalResult mlir::linalg::bufferizeOp(
|
|||
LDBG("Begin bufferize:\n" << op << '\n');
|
||||
return bufferize(b, op, bvm, aliasInfo);
|
||||
})
|
||||
.Case<tensor::CastOp, tensor::DimOp, ExtractSliceOp, InitTensorOp,
|
||||
InsertSliceOp, tensor::ExtractOp, LinalgOp, ReturnOp,
|
||||
VectorTransferOpInterface, linalg::YieldOp, scf::YieldOp,
|
||||
scf::IfOp>([&](auto op) {
|
||||
LDBG("Begin bufferize:\n" << op << '\n');
|
||||
return bufferize(b, op, bvm, aliasInfo);
|
||||
})
|
||||
.Case([&](CallOpInterface op) {
|
||||
LDBG("Begin bufferize:\n" << op << '\n');
|
||||
if (!bufferizedFunctionTypes)
|
||||
|
|
|
@ -1087,3 +1087,291 @@ func @buffer_forwarding_no_conflict(%arg0: tensor<?xf32> {linalg.inplaceable = t
|
|||
%2 = tensor.insert_slice %1 into %arg0[42] [%arg1] [1] : tensor<?xf32> into tensor<?xf32>
|
||||
return %2, %2 : tensor<?xf32>, tensor<?xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// scf.if cases
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
// This example passes analysis, but it fails when bufferizing.
|
||||
// CHECK-LABEL: func @scf_if_inplace1
|
||||
func @scf_if_inplace1(%t1: tensor<?xf32> {linalg.inplaceable = true},
|
||||
%t2: tensor<?xf32> {linalg.inplaceable = true},
|
||||
%cond: i1) -> tensor<?xf32> {
|
||||
%r = scf.if %cond -> (tensor<?xf32>) {
|
||||
scf.yield %t1 : tensor<?xf32>
|
||||
} else {
|
||||
scf.yield %t2 : tensor<?xf32>
|
||||
}
|
||||
return %r : tensor<?xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @scf_if_inplace2
|
||||
func @scf_if_inplace2(%t1: tensor<?xf32> {linalg.inplaceable = true},
|
||||
%v: vector<5xf32>, %idx: index,
|
||||
%cond: i1) -> tensor<?xf32> {
|
||||
%r = scf.if %cond -> (tensor<?xf32>) {
|
||||
scf.yield %t1 : tensor<?xf32>
|
||||
} else {
|
||||
// CHECK: vector.transfer_write
|
||||
// CHECK-SAME: {__inplace_results_attr__ = ["true"]
|
||||
%t2 = vector.transfer_write %v, %t1[%idx] : vector<5xf32>, tensor<?xf32>
|
||||
scf.yield %t2 : tensor<?xf32>
|
||||
}
|
||||
return %r : tensor<?xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @scf_if_inplace3
|
||||
func @scf_if_inplace3(%t1: tensor<?xf32> {linalg.inplaceable = true},
|
||||
%v1: vector<5xf32>, %v2: vector<5xf32>, %idx: index,
|
||||
%cond: i1) -> tensor<?xf32> {
|
||||
// CHECK: tensor.extract_slice
|
||||
// CHECK-SAME: {__inplace_results_attr__ = ["true"]
|
||||
%e = tensor.extract_slice %t1[%idx][%idx][1] : tensor<?xf32> to tensor<?xf32>
|
||||
%r = scf.if %cond -> (tensor<?xf32>) {
|
||||
// CHECK: vector.transfer_write
|
||||
// CHECK-SAME: {__inplace_results_attr__ = ["true"]
|
||||
%t2 = vector.transfer_write %v1, %e[%idx] : vector<5xf32>, tensor<?xf32>
|
||||
scf.yield %t2 : tensor<?xf32>
|
||||
} else {
|
||||
// Writing the same tensor through an alias. This is OK.
|
||||
// CHECK: vector.transfer_write
|
||||
// CHECK-SAME: {__inplace_results_attr__ = ["true"]
|
||||
%t3 = vector.transfer_write %v2, %t1[%idx] : vector<5xf32>, tensor<?xf32>
|
||||
scf.yield %t3 : tensor<?xf32>
|
||||
}
|
||||
return %r : tensor<?xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @scf_if_in_place4
|
||||
func @scf_if_in_place4(%t1: tensor<?xf32> {linalg.inplaceable = true},
|
||||
%v: vector<5xf32>, %idx: index,
|
||||
%cond: i1, %cond2: i1) -> (tensor<?xf32>, vector<10xf32>) {
|
||||
%cst = arith.constant 0.0 : f32
|
||||
%r = scf.if %cond -> (tensor<?xf32>) {
|
||||
scf.yield %t1 : tensor<?xf32>
|
||||
} else {
|
||||
// CHECK: vector.transfer_write
|
||||
// CHECK-SAME: {__inplace_results_attr__ = ["true"]
|
||||
%t2 = vector.transfer_write %v, %t1[%idx] : vector<5xf32>, tensor<?xf32>
|
||||
scf.yield %t2 : tensor<?xf32>
|
||||
}
|
||||
%r_alias = scf.if %cond2 -> (tensor<?xf32>) {
|
||||
// Reading %r is OK. No conflict.
|
||||
scf.yield %r : tensor<?xf32>
|
||||
} else {
|
||||
scf.yield %r : tensor<?xf32>
|
||||
}
|
||||
%v2 = vector.transfer_read %r_alias[%idx], %cst : tensor<?xf32>, vector<10xf32>
|
||||
return %r_alias, %v2 : tensor<?xf32>, vector<10xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @scf_if_inplace5
|
||||
func @scf_if_inplace5(%t1: tensor<?xf32> {linalg.inplaceable = true},
|
||||
%idx: index, %cond: i1) -> tensor<?xf32> {
|
||||
%r = scf.if %cond -> (tensor<?xf32>) {
|
||||
// CHECK: tensor.extract_slice
|
||||
// CHECK-SAME: {__inplace_results_attr__ = ["true"]
|
||||
%e = tensor.extract_slice %t1[%idx][%idx][1] : tensor<?xf32> to tensor<?xf32>
|
||||
scf.yield %e : tensor<?xf32>
|
||||
} else {
|
||||
// CHECK: tensor.extract_slice
|
||||
// CHECK-SAME: {__inplace_results_attr__ = ["true"]
|
||||
%f = tensor.extract_slice %t1[%idx][%idx][1] : tensor<?xf32> to tensor<?xf32>
|
||||
scf.yield %f : tensor<?xf32>
|
||||
}
|
||||
|
||||
// Inserting into an equivalent tensor at the same offset. This bufferizes
|
||||
// inplace.
|
||||
// CHECK: tensor.insert_slice
|
||||
// CHECK-SAME: {__inplace_results_attr__ = ["true"]
|
||||
%r2 = tensor.insert_slice %r into %t1[%idx][%idx][1] : tensor<?xf32> into tensor<?xf32>
|
||||
return %r2 : tensor<?xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @scf_if_inplace6
|
||||
func @scf_if_inplace6(%t1: tensor<?xf32> {linalg.inplaceable = true},
|
||||
%v1: vector<5xf32>, %v2: vector<5xf32>,
|
||||
%v3: vector<5xf32>, %idx: index,
|
||||
%cond: i1, %cond2: i1) -> tensor<?xf32> {
|
||||
// Test nested scf.if ops.
|
||||
%r = scf.if %cond -> (tensor<?xf32>) {
|
||||
%t2 = scf.if %cond2 -> (tensor<?xf32>) {
|
||||
// CHECK: vector.transfer_write
|
||||
// CHECK-SAME: {__inplace_results_attr__ = ["true"]
|
||||
%t3 = vector.transfer_write %v1, %t1[%idx] : vector<5xf32>, tensor<?xf32>
|
||||
scf.yield %t3 : tensor<?xf32>
|
||||
} else {
|
||||
// CHECK: vector.transfer_write
|
||||
// CHECK-SAME: {__inplace_results_attr__ = ["true"]
|
||||
%t4 = vector.transfer_write %v3, %t1[%idx] : vector<5xf32>, tensor<?xf32>
|
||||
scf.yield %t4 : tensor<?xf32>
|
||||
}
|
||||
scf.yield %t2 : tensor<?xf32>
|
||||
} else {
|
||||
// CHECK: vector.transfer_write
|
||||
// CHECK-SAME: {__inplace_results_attr__ = ["true"]
|
||||
%t3 = vector.transfer_write %v2, %t1[%idx] : vector<5xf32>, tensor<?xf32>
|
||||
scf.yield %t3 : tensor<?xf32>
|
||||
}
|
||||
return %r : tensor<?xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @scf_if_inplace7
|
||||
func @scf_if_inplace7(%t1: tensor<?xf32> {linalg.inplaceable = true},
|
||||
%v1: vector<5xf32>, %v2: vector<5xf32>, %idx: index,
|
||||
%idx2: index, %cond: i1) -> (tensor<?xf32>, vector<5xf32>) {
|
||||
%cst = arith.constant 0.0 : f32
|
||||
%r, %v_r2 = scf.if %cond -> (tensor<?xf32>, vector<5xf32>) {
|
||||
// CHECK: vector.transfer_write
|
||||
// CHECK-SAME: {__inplace_results_attr__ = ["true"]
|
||||
%t2 = vector.transfer_write %v1, %t1[%idx] : vector<5xf32>, tensor<?xf32>
|
||||
scf.yield %t2, %v1 : tensor<?xf32>, vector<5xf32>
|
||||
} else {
|
||||
// Writing the same tensor through an alias.
|
||||
// CHECK: vector.transfer_write
|
||||
// CHECK-SAME: {__inplace_results_attr__ = ["false"]
|
||||
%t3 = vector.transfer_write %v2, %t1[%idx] : vector<5xf32>, tensor<?xf32>
|
||||
// Read the original value of %t1. This requires the write in this branch
|
||||
// to be out-of-place. But the write in the other branch can still be
|
||||
// inplace.
|
||||
%v_r = vector.transfer_read %t1[%idx2], %cst : tensor<?xf32>, vector<5xf32>
|
||||
scf.yield %t3, %v_r : tensor<?xf32>, vector<5xf32>
|
||||
}
|
||||
return %r, %v_r2 : tensor<?xf32>, vector<5xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @scf_if_out_of_place1a
|
||||
func @scf_if_out_of_place1a(%t1: tensor<?xf32> {linalg.inplaceable = true},
|
||||
%idx: index, %idx2: index,
|
||||
%cond: i1) -> tensor<?xf32> {
|
||||
%r = scf.if %cond -> (tensor<?xf32>) {
|
||||
// CHECK: tensor.extract_slice
|
||||
// CHECK-SAME: {__inplace_results_attr__ = ["true"]
|
||||
%e = tensor.extract_slice %t1[%idx][%idx][1] : tensor<?xf32> to tensor<?xf32>
|
||||
scf.yield %e : tensor<?xf32>
|
||||
} else {
|
||||
scf.yield %t1 : tensor<?xf32>
|
||||
}
|
||||
|
||||
// Reading from and writing to the same tensor via different args. This is a
|
||||
// conflict.
|
||||
// CHECK: tensor.insert_slice
|
||||
// CHECK-SAME: {__inplace_results_attr__ = ["false"]
|
||||
%r2 = tensor.insert_slice %r into %t1[%idx2][%idx2][1] : tensor<?xf32> into tensor<?xf32>
|
||||
return %r2 : tensor<?xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @scf_if_out_of_place1b
|
||||
func @scf_if_out_of_place1b(%t1: tensor<?xf32> {linalg.inplaceable = true},
|
||||
%idx: index, %idx2: index, %idx3: index,
|
||||
%cond: i1) -> tensor<?xf32> {
|
||||
%r = scf.if %cond -> (tensor<?xf32>) {
|
||||
// CHECK: tensor.extract_slice
|
||||
// CHECK-SAME: {__inplace_results_attr__ = ["false"]
|
||||
%e = tensor.extract_slice %t1[%idx][%idx][1] : tensor<?xf32> to tensor<?xf32>
|
||||
scf.yield %e : tensor<?xf32>
|
||||
} else {
|
||||
// CHECK: tensor.extract_slice
|
||||
// CHECK-SAME: {__inplace_results_attr__ = ["false"]
|
||||
%f = tensor.extract_slice %t1[%idx2][%idx2][1] : tensor<?xf32> to tensor<?xf32>
|
||||
scf.yield %f : tensor<?xf32>
|
||||
}
|
||||
|
||||
// Reading from and writing to the same tensor via different args. This is a
|
||||
// conflict. In contrast to scf_if_out_of_place1a, the fact that %r aliases
|
||||
// with %t1 is only detected when analyzing the tensor.extract_slices. That's
|
||||
// why the tensor.insert_slice is inplace and the two extract_slices are
|
||||
// out-of-place.
|
||||
// CHECK: tensor.insert_slice
|
||||
// CHECK-SAME: {__inplace_results_attr__ = ["true"]
|
||||
%r2 = tensor.insert_slice %r into %t1[%idx3][%idx3][1] : tensor<?xf32> into tensor<?xf32>
|
||||
return %r2 : tensor<?xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @scf_if_out_of_place1c
|
||||
func @scf_if_out_of_place1c(%t1: tensor<?xf32> {linalg.inplaceable = true},
|
||||
%idx: index, %idx2: index, %cond: i1) -> tensor<?xf32> {
|
||||
%r = scf.if %cond -> (tensor<?xf32>) {
|
||||
// CHECK: tensor.extract_slice
|
||||
// CHECK-SAME: {__inplace_results_attr__ = ["false"]
|
||||
%e = tensor.extract_slice %t1[%idx][%idx][1] : tensor<?xf32> to tensor<?xf32>
|
||||
scf.yield %e : tensor<?xf32>
|
||||
} else {
|
||||
// TODO: This one could bufferize inplace, but the analysis is too restrictive.
|
||||
// CHECK: tensor.extract_slice
|
||||
// CHECK-SAME: {__inplace_results_attr__ = ["false"]
|
||||
%f = tensor.extract_slice %t1[%idx2][%idx2][1] : tensor<?xf32> to tensor<?xf32>
|
||||
scf.yield %f : tensor<?xf32>
|
||||
}
|
||||
|
||||
// CHECK: tensor.insert_slice
|
||||
// CHECK-SAME: {__inplace_results_attr__ = ["true"]
|
||||
%r2 = tensor.insert_slice %r into %t1[%idx2][%idx2][1] : tensor<?xf32> into tensor<?xf32>
|
||||
return %r2 : tensor<?xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @scf_if_out_of_place2
|
||||
func @scf_if_out_of_place2(%t1: tensor<?xf32> {linalg.inplaceable = true},
|
||||
%v: vector<5xf32>, %idx: index,
|
||||
%cond: i1) -> (tensor<?xf32>, vector<10xf32>) {
|
||||
%cst = arith.constant 0.0 : f32
|
||||
%r = scf.if %cond -> (tensor<?xf32>) {
|
||||
scf.yield %t1 : tensor<?xf32>
|
||||
} else {
|
||||
// CHECK: vector.transfer_write
|
||||
// CHECK-SAME: {__inplace_results_attr__ = ["false"]
|
||||
%t2 = vector.transfer_write %v, %t1[%idx] : vector<5xf32>, tensor<?xf32>
|
||||
scf.yield %t2 : tensor<?xf32>
|
||||
}
|
||||
|
||||
// Read the old value of %t1. Forces the transfer_write to bufferize
|
||||
// out-of-place.
|
||||
%v2 = vector.transfer_read %t1[%idx], %cst : tensor<?xf32>, vector<10xf32>
|
||||
return %r, %v2 : tensor<?xf32>, vector<10xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @scf_if_out_of_place3
|
||||
func @scf_if_out_of_place3(%t1: tensor<?xf32> {linalg.inplaceable = true},
|
||||
%v: vector<5xf32>, %idx: index,
|
||||
%cond: i1, %cond2: i1) -> (tensor<?xf32>, vector<10xf32>) {
|
||||
%cst = arith.constant 0.0 : f32
|
||||
%r = scf.if %cond -> (tensor<?xf32>) {
|
||||
scf.yield %t1 : tensor<?xf32>
|
||||
} else {
|
||||
// CHECK: vector.transfer_write
|
||||
// CHECK-SAME: {__inplace_results_attr__ = ["false"]
|
||||
%t2 = vector.transfer_write %v, %t1[%idx] : vector<5xf32>, tensor<?xf32>
|
||||
scf.yield %t2 : tensor<?xf32>
|
||||
}
|
||||
%t1_alias = scf.if %cond2 -> (tensor<?xf32>) {
|
||||
// scf.yield bufferizes to a read. That is a conflict in this example.
|
||||
scf.yield %t1 : tensor<?xf32>
|
||||
} else {
|
||||
scf.yield %t1 : tensor<?xf32>
|
||||
}
|
||||
%v2 = vector.transfer_read %t1_alias[%idx], %cst : tensor<?xf32>, vector<10xf32>
|
||||
return %r, %v2 : tensor<?xf32>, vector<10xf32>
|
||||
}
|
||||
|
||||
|
|
|
@ -113,8 +113,8 @@ func @extract_slice_fun(%A : tensor<?xf32> {linalg.inplaceable = true})
|
|||
|
||||
func @scf_yield(%b : i1, %A : tensor<4xf32>, %B : tensor<4xf32>) -> tensor<4xf32>
|
||||
{
|
||||
// expected-error @+1 {{result buffer is ambiguous}}
|
||||
%r = scf.if %b -> (tensor<4xf32>) {
|
||||
// expected-error @+1 {{expected scf::ForOp parent for scf::YieldOp}}
|
||||
scf.yield %A : tensor<4xf32>
|
||||
} else {
|
||||
scf.yield %B : tensor<4xf32>
|
||||
|
|
|
@ -861,3 +861,25 @@ func @buffer_forwarding_no_conflict(
|
|||
return %r1: tensor<?xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @scf_if_inplace(
|
||||
// CHECK-SAME: %[[cond:.*]]: i1, %[[t1:.*]]: memref<?xf32{{.*}}>, %[[v:.*]]: vector
|
||||
func @scf_if_inplace(%cond: i1,
|
||||
%t1: tensor<?xf32> {linalg.inplaceable = true},
|
||||
%v: vector<5xf32>, %idx: index) -> tensor<?xf32> {
|
||||
|
||||
// CHECK: scf.if %[[cond]] {
|
||||
// CHECK-NEXT: } else {
|
||||
// CHECK-NEXT: vector.transfer_write %[[v]], %[[t1]]
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: return
|
||||
%r = scf.if %cond -> (tensor<?xf32>) {
|
||||
scf.yield %t1 : tensor<?xf32>
|
||||
} else {
|
||||
%t2 = vector.transfer_write %v, %t1[%idx] : vector<5xf32>, tensor<?xf32>
|
||||
scf.yield %t2 : tensor<?xf32>
|
||||
}
|
||||
return %r : tensor<?xf32>
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue