[mlir][Linalg] Improve comprehensive bufferization for scf.yield.

Previously, comprehensive bufferization of scf.yield did not have enough information
to detect whether an enclosing scf::for bbargs would bufferize to a buffer equivalent
to that of the matching scf::yield operand.
As a consequence a separate sanity check step would be required to determine whether
bufferization occurred properly.
This late check would miss the case of calling a function in a loop.

Instead, we now pass and update aliasInfo during bufferization and it is possible to
improve bufferization of scf::yield and drop that post-pass check.

Add an example use case that was failing previously.
This slightly modifies the error conditions, which are also updated as part of this
revision.

Differential Revision: https://reviews.llvm.org/D105803
This commit is contained in:
Nicolas Vasilache 2021-07-12 10:10:26 +00:00
parent 326b0054fd
commit 6b1668397f
2 changed files with 45 additions and 50 deletions

View File

@ -2075,14 +2075,24 @@ static LogicalResult bufferize(OpBuilder &b, scf::YieldOp yieldOp,
auto tensorType = operand.get().getType().dyn_cast<TensorType>();
if (!tensorType)
continue;
OpOperand &forOperand = forOp.getOpOperandForResult(
forOp->getResult(operand.getOperandNumber()));
auto bbArg = forOp.getRegionIterArgForOpOperand(forOperand);
if (getInPlace(bbArg) == InPlaceSpec::True)
operand.set(bbArg);
else
operand.set(
b.create<memref::TensorLoadOp>(yieldOp.getLoc(), lookup(bvm, bbArg)));
Value yieldedBuffer = lookup(bvm, operand.get());
Value bbArgBuffer = lookup(bvm, bbArg);
if (!aliasInfo.areEquivalentBufferizedValues(yieldedBuffer, bbArgBuffer)) {
// TODO: this could get resolved with copies but it can also turn into
// swaps so we need to be careful about order of copies.
return yieldOp->emitError()
<< "Yield operand #" << operand.getOperandNumber()
<< " does not bufferize to an equivalent buffer to the matching"
<< " enclosing scf::for operand";
}
// Buffers are equivalent so the work is already done and we just yield the
// bbArg so that it later canonicalizes away.
operand.set(bbArg);
}
return success();
}
@ -2205,38 +2215,6 @@ bufferizableInPlaceAnalysis(OpOperand &operand, OpResult result,
return success();
}
/// Return `failure()` if either
/// scf::YieldOp are not explicitly bufferized and we need to perform a separate
/// sanity check for now.
static LogicalResult
bufferizationSanityCheck(scf::YieldOp yieldOp,
const BufferizationAliasInfo &aliasInfo) {
auto parentForOp = yieldOp->getParentOfType<scf::ForOp>();
if (!parentForOp)
return yieldOp->emitError() << "not nested under ForOp";
for (OpOperand &operand : yieldOp->getOpOperands()) {
OpResult matchingForOpResult =
parentForOp->getResult(operand.getOperandNumber());
// Nothing to do if operand bufferizes out of place.
if (getInPlace(matchingForOpResult) != InPlaceSpec::True)
continue;
OpOperand &machingForOpOperand =
parentForOp.getOpOperandForResult(matchingForOpResult);
BlockArgument matchingForOpIterArg =
parentForOp.getRegionIterArgForOpOperand(machingForOpOperand);
if (!aliasInfo.areEquivalentBufferizedValues(matchingForOpIterArg,
operand.get())) {
return yieldOp->emitError()
<< "Yield operand #" << operand.getOperandNumber()
<< " does not bufferize to an equivalent buffer to the matching"
<< " enclosing scf::for operand -> Fail the pass\n";
}
}
return success();
}
/// Analyze the `funcOp` body to determine which OpResults are inplaceable.
static LogicalResult
inPlaceAnalysisFuncOpBody(FuncOp funcOp, BufferizationAliasInfo &aliasInfo,
@ -2275,13 +2253,14 @@ inPlaceAnalysisFuncOpBody(FuncOp funcOp, BufferizationAliasInfo &aliasInfo,
return failure();
}
// Bufferize all ops except ExtractSliceOp and InsertSliceOp which are handled
// separately.
// Analyze all ops that return tensors, except ExtractSliceOp and
// InsertSliceOp which are handled separately.
// Walk other ops in reverse for better interference behavior.
for (Operation *op : reverse(nonSliceOps))
for (OpOperand &opOperand : op->getOpOperands())
if (OpResult result = getInplaceableOpResult(opOperand))
if (failed(bufferizableInPlaceAnalysis(opOperand, result, aliasInfo,
if (result.getType().isa<TensorType>() &&
failed(bufferizableInPlaceAnalysis(opOperand, result, aliasInfo,
domInfo)))
return failure();
@ -2292,14 +2271,9 @@ inPlaceAnalysisFuncOpBody(FuncOp funcOp, BufferizationAliasInfo &aliasInfo,
if (failed(bufferizableInPlaceAnalysis(extractSliceOp, aliasInfo, domInfo)))
return failure();
// Sanity checks.
auto walkResult = funcOp.walk([&](scf::YieldOp yieldOp) -> WalkResult {
return bufferizationSanityCheck(yieldOp, aliasInfo);
});
LDBG("End InPlaceAnalysisFuncOpInternals:\n" << funcOp << '\n');
return success(!walkResult.wasInterrupted());
return success();
}
//===----------------------------------------------------------------------===//

View File

@ -18,7 +18,7 @@ func private @foo() -> tensor<?xf32>
// expected-error @+1 {{cannot bufferize a FuncOp with tensors and without a unique ReturnOp}}
func @swappy(%cond1 : i1, %cond2 : i1, %t1 : tensor<f32>, %t2 : tensor<f32>)
-> (tensor<f32>, tensor<f32>)
-> (tensor<f32>, tensor<f32>)
{
cond_br %cond1, ^bb1, ^bb2
@ -64,7 +64,7 @@ func @scf_for(%A : tensor<?xf32>,
// Throw a wrench in the system by swapping yielded values: this results in a
// ping-pong of values at each iteration on which we currently want to fail.
// expected-error @+1 {{Yield operand #1 does not bufferize to an equivalent buffer}}
// expected-error @+1 {{Yield operand #0 does not bufferize to an equivalent buffer}}
scf.yield %ttB, %ttA : tensor<?xf32>, tensor<?xf32>
}
@ -73,6 +73,27 @@ func @scf_for(%A : tensor<?xf32>,
// -----
func private @fun_with_side_effects(%A: tensor<?xf32> {linalg.inplaceable = true})
func @foo(%A: tensor<?xf32> {linalg.inplaceable = true}) -> (tensor<?xf32>) {
call @fun_with_side_effects(%A) : (tensor<?xf32>) -> ()
return %A: tensor<?xf32>
}
func @scf_yield_needs_copy(%A : tensor<?xf32> {linalg.inplaceable = true}, %iters : index) {
%c0 = constant 0 : index
%c1 = constant 1 : index
%res = scf.for %arg0 = %c0 to %iters step %c1 iter_args(%bbarg = %A) -> (tensor<?xf32>) {
%r = call @foo(%A) : (tensor<?xf32>) -> (tensor<?xf32>)
// expected-error @+1 {{Yield operand #0 does not bufferize to an equivalent buffer}}
scf.yield %r : tensor<?xf32>
}
call @fun_with_side_effects(%res) : (tensor<?xf32>) -> ()
return
}
// -----
func @extract_slice_fun(%A : tensor<?xf32> {linalg.inplaceable = true})
-> tensor<4xf32>
{
@ -92,8 +113,8 @@ func @extract_slice_fun(%A : tensor<?xf32> {linalg.inplaceable = true})
func @scf_yield(%b : i1, %A : tensor<4xf32>, %B : tensor<4xf32>) -> tensor<4xf32>
{
%r = scf.if %b -> (tensor<4xf32>) {
// expected-error @+1 {{not nested under ForOp}}
// expected-error @+1 {{unsupported op with tensors}}
%r = scf.if %b -> (tensor<4xf32>) {
scf.yield %A : tensor<4xf32>
} else {
scf.yield %B : tensor<4xf32>