[mlir][Linalg] Improve comprehensive bufferization for scf.yield.

Previously, comprehensive bufferization of scf.yield did not have enough
information to detect whether an enclosing scf::for bbArg would bufferize to a
buffer equivalent to that of the matching scf::yield operand. As a consequence,
a separate sanity-check step was required to determine whether bufferization
occurred properly, and this late check missed the case of a function call
inside a loop. Instead, we now pass and update the aliasInfo during
bufferization, which makes it possible to improve the bufferization of
scf::yield and drop that post-pass check. An example use case that was failing
previously is added. This slightly modifies the error conditions, which are
updated as part of this revision.

Differential Revision: https://reviews.llvm.org/D105803
commit 6b1668397f
parent 326b0054fd
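The heart of the change is the aliasInfo.areEquivalentBufferizedValues query
in the first hunk below: during bufferization, values known to bufferize to
the same buffer are merged into equivalence classes, and an scf.yield operand
is legal only if it lands in the same class as the matching iter_arg. Below is
a minimal standalone sketch of that idea in plain C++ with a toy union-find;
ValueId, EquivalenceInfo, and unionBuffers are illustrative names, not the
MLIR API, and the real BufferizationAliasInfo also tracks aliasing, not just
equivalence.

#include <cassert>
#include <unordered_map>

// Stand-in for mlir::Value; any hashable handle works for the sketch.
using ValueId = int;

class EquivalenceInfo {
  std::unordered_map<ValueId, ValueId> parent;

  ValueId find(ValueId v) {
    auto it = parent.find(v);
    if (it == parent.end()) {
      parent[v] = v; // First sighting: v is its own class.
      return v;
    }
    if (it->second == v)
      return v;
    ValueId root = find(it->second); // Path compression.
    parent[v] = root;
    return root;
  }

public:
  // Record a fact discovered during bufferization: `a` and `b` are known to
  // bufferize to the same buffer (e.g. an inplace iter_arg and the result of
  // an inplace op chain rooted at it).
  void unionBuffers(ValueId a, ValueId b) {
    ValueId ra = find(a), rb = find(b);
    parent[ra] = rb;
  }

  bool areEquivalentBufferizedValues(ValueId a, ValueId b) {
    return find(a) == find(b);
  }
};

int main() {
  EquivalenceInfo info;
  ValueId bbArg = 0, yielded = 1, unrelated = 2;
  info.unionBuffers(bbArg, yielded);
  assert(info.areEquivalentBufferizedValues(bbArg, yielded));
  assert(!info.areEquivalentBufferizedValues(bbArg, unrelated));
  return 0;
}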
@@ -2075,14 +2075,24 @@ static LogicalResult bufferize(OpBuilder &b, scf::YieldOp yieldOp,
     auto tensorType = operand.get().getType().dyn_cast<TensorType>();
     if (!tensorType)
       continue;

     OpOperand &forOperand = forOp.getOpOperandForResult(
         forOp->getResult(operand.getOperandNumber()));
     auto bbArg = forOp.getRegionIterArgForOpOperand(forOperand);
-    if (getInPlace(bbArg) == InPlaceSpec::True)
-      operand.set(bbArg);
-    else
-      operand.set(
-          b.create<memref::TensorLoadOp>(yieldOp.getLoc(), lookup(bvm, bbArg)));
+    Value yieldedBuffer = lookup(bvm, operand.get());
+    Value bbArgBuffer = lookup(bvm, bbArg);
+    if (!aliasInfo.areEquivalentBufferizedValues(yieldedBuffer, bbArgBuffer)) {
+      // TODO: this could get resolved with copies but it can also turn into
+      // swaps so we need to be careful about order of copies.
+      return yieldOp->emitError()
+             << "Yield operand #" << operand.getOperandNumber()
+             << " does not bufferize to an equivalent buffer to the matching"
+             << " enclosing scf::for operand";
+    }
+
+    // Buffers are equivalent so the work is already done and we just yield the
+    // bbArg so that it later canonicalizes away.
+    operand.set(bbArg);
   }
   return success();
 }
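The TODO in the hunk above hints at why an automatic fix is not attempted:
repairing a non-equivalent yield with copies is only sound if the copies are
ordered correctly, and a swapped yield (as in the ping-pong test below) has no
valid order without a temporary. A tiny plain-C++ illustration of the hazard,
nothing MLIR-specific:

#include <cassert>

int main() {
  int A = 1, B = 2; // Stand-ins for the two iter_arg buffers.
  // The loop yields (B, A): the next iteration wants A' == 2 and B' == 1.
  A = B; // Copy #1 clobbers A...
  B = A; // ...so copy #2 reads the clobbered value.
  assert(A == 2 && B == 2); // Not a swap: the original A is lost.
  return 0;
}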
@@ -2205,38 +2215,6 @@ bufferizableInPlaceAnalysis(OpOperand &operand, OpResult result,
   return success();
 }

-/// Return `failure()` if either
-/// scf::YieldOp are not explicitly bufferized and we need to perform a separate
-/// sanity check for now.
-static LogicalResult
-bufferizationSanityCheck(scf::YieldOp yieldOp,
-                         const BufferizationAliasInfo &aliasInfo) {
-  auto parentForOp = yieldOp->getParentOfType<scf::ForOp>();
-  if (!parentForOp)
-    return yieldOp->emitError() << "not nested under ForOp";
-
-  for (OpOperand &operand : yieldOp->getOpOperands()) {
-    OpResult matchingForOpResult =
-        parentForOp->getResult(operand.getOperandNumber());
-    // Nothing to do if operand bufferizes out of place.
-    if (getInPlace(matchingForOpResult) != InPlaceSpec::True)
-      continue;
-    OpOperand &machingForOpOperand =
-        parentForOp.getOpOperandForResult(matchingForOpResult);
-    BlockArgument matchingForOpIterArg =
-        parentForOp.getRegionIterArgForOpOperand(machingForOpOperand);
-    if (!aliasInfo.areEquivalentBufferizedValues(matchingForOpIterArg,
-                                                 operand.get())) {
-      return yieldOp->emitError()
-             << "Yield operand #" << operand.getOperandNumber()
-             << " does not bufferize to an equivalent buffer to the matching"
-             << " enclosing scf::for operand -> Fail the pass\n";
-    }
-  }
-
-  return success();
-}
-
 /// Analyze the `funcOp` body to determine which OpResults are inplaceable.
 static LogicalResult
 inPlaceAnalysisFuncOpBody(FuncOp funcOp, BufferizationAliasInfo &aliasInfo,
@@ -2275,13 +2253,14 @@ inPlaceAnalysisFuncOpBody(FuncOp funcOp, BufferizationAliasInfo &aliasInfo,
       return failure();
   }

-  // Bufferize all ops except ExtractSliceOp and InsertSliceOp which are handled
-  // separately.
+  // Analyze all ops that return tensors, except ExtractSliceOp and
+  // InsertSliceOp which are handled separately.
   // Walk other ops in reverse for better interference behavior.
   for (Operation *op : reverse(nonSliceOps))
     for (OpOperand &opOperand : op->getOpOperands())
       if (OpResult result = getInplaceableOpResult(opOperand))
-        if (failed(bufferizableInPlaceAnalysis(opOperand, result, aliasInfo,
+        if (result.getType().isa<TensorType>() &&
+            failed(bufferizableInPlaceAnalysis(opOperand, result, aliasInfo,
                                                domInfo)))
           return failure();

@@ -2292,14 +2271,9 @@ inPlaceAnalysisFuncOpBody(FuncOp funcOp, BufferizationAliasInfo &aliasInfo,
     if (failed(bufferizableInPlaceAnalysis(extractSliceOp, aliasInfo, domInfo)))
       return failure();

-  // Sanity checks.
-  auto walkResult = funcOp.walk([&](scf::YieldOp yieldOp) -> WalkResult {
-    return bufferizationSanityCheck(yieldOp, aliasInfo);
-  });
-
   LDBG("End InPlaceAnalysisFuncOpInternals:\n" << funcOp << '\n');

-  return success(!walkResult.wasInterrupted());
+  return success();
 }

 //===----------------------------------------------------------------------===//
@@ -18,7 +18,7 @@ func private @foo() -> tensor<?xf32>

 // expected-error @+1 {{cannot bufferize a FuncOp with tensors and without a unique ReturnOp}}
 func @swappy(%cond1 : i1, %cond2 : i1, %t1 : tensor<f32>, %t2 : tensor<f32>)
-  -> (tensor<f32>, tensor<f32>)
+  -> (tensor<f32>, tensor<f32>)
 {
   cond_br %cond1, ^bb1, ^bb2

@@ -64,7 +64,7 @@ func @scf_for(%A : tensor<?xf32>,
   // Throw a wrench in the system by swapping yielded values: this results in a
   // ping-pong of values at each iteration on which we currently want to fail.

-  // expected-error @+1 {{Yield operand #1 does not bufferize to an equivalent buffer}}
+  // expected-error @+1 {{Yield operand #0 does not bufferize to an equivalent buffer}}
   scf.yield %ttB, %ttA : tensor<?xf32>, tensor<?xf32>
 }

@@ -73,6 +73,27 @@ func @scf_for(%A : tensor<?xf32>,

 // -----

+func private @fun_with_side_effects(%A: tensor<?xf32> {linalg.inplaceable = true})
+
+func @foo(%A: tensor<?xf32> {linalg.inplaceable = true}) -> (tensor<?xf32>) {
+  call @fun_with_side_effects(%A) : (tensor<?xf32>) -> ()
+  return %A: tensor<?xf32>
+}
+
+func @scf_yield_needs_copy(%A : tensor<?xf32> {linalg.inplaceable = true}, %iters : index) {
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+  %res = scf.for %arg0 = %c0 to %iters step %c1 iter_args(%bbarg = %A) -> (tensor<?xf32>) {
+    %r = call @foo(%A) : (tensor<?xf32>) -> (tensor<?xf32>)
+    // expected-error @+1 {{Yield operand #0 does not bufferize to an equivalent buffer}}
+    scf.yield %r : tensor<?xf32>
+  }
+  call @fun_with_side_effects(%res) : (tensor<?xf32>) -> ()
+  return
+}
+
+// -----
+
 func @extract_slice_fun(%A : tensor<?xf32> {linalg.inplaceable = true})
   -> tensor<4xf32>
 {
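To connect this new test back to the analysis: %r is produced by call @foo,
and nothing ever records its buffer as equivalent to %bbarg's, so the check
added in the first hunk rejects the yield. In terms of the toy EquivalenceInfo
sketch above (hypothetical value IDs, not real MLIR objects):

#include <cassert>
#include <unordered_map>

int main() {
  // value -> class leader; no union fact is ever recorded for %r.
  std::unordered_map<int, int> leader{{/*%bbarg*/ 0, 0}, {/*%r*/ 1, 1}};
  bool equivalent = (leader[0] == leader[1]);
  assert(!equivalent); // -> "Yield operand #0 does not bufferize to an
                       //    equivalent buffer".
  return 0;
}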
@@ -92,8 +113,8 @@ func @extract_slice_fun(%A : tensor<?xf32> {linalg.inplaceable = true})

 func @scf_yield(%b : i1, %A : tensor<4xf32>, %B : tensor<4xf32>) -> tensor<4xf32>
 {
-  %r = scf.if %b -> (tensor<4xf32>) {
-    // expected-error @+1 {{not nested under ForOp}}
+  // expected-error @+1 {{unsupported op with tensors}}
+  %r = scf.if %b -> (tensor<4xf32>) {
     scf.yield %A : tensor<4xf32>
   } else {
     scf.yield %B : tensor<4xf32>