[mlir][bufferize] Fix bug in module equivalence analysis

CallOp result are not equivalent to an OpOperand if the OpOperand bufferizes out-of-place.

Differential Revision: https://reviews.llvm.org/D126813
This commit is contained in:
Matthias Springer 2022-06-09 18:30:42 +02:00
parent 1efe354088
commit bf58256967
2 changed files with 8 additions and 4 deletions

View File

@ -258,7 +258,8 @@ static func::FuncOp getCalledFunction(CallOpInterface callOp) {
// TODO: This does not handle cyclic function call graphs etc.
static void equivalenceAnalysis(func::FuncOp funcOp,
BufferizationAliasInfo &aliasInfo,
FuncAnalysisState &funcState) {
OneShotAnalysisState &state) {
FuncAnalysisState &funcState = getFuncAnalysisState(state);
funcOp->walk([&](func::CallOp callOp) {
func::FuncOp calledFunction = getCalledFunction(callOp);
assert(calledFunction && "could not retrieved called func::FuncOp");
@ -270,6 +271,8 @@ static void equivalenceAnalysis(func::FuncOp funcOp,
for (auto it : funcState.equivalentFuncArgs[calledFunction]) {
int64_t returnIdx = it.first;
int64_t bbargIdx = it.second;
if (!state.isInPlace(callOp->getOpOperand(bbargIdx)))
continue;
Value returnVal = callOp.getResult(returnIdx);
Value argVal = callOp->getOperand(bbargIdx);
aliasInfo.unionEquivalenceClasses(returnVal, argVal);
@ -409,7 +412,7 @@ mlir::bufferization::analyzeModuleOp(ModuleOp moduleOp,
funcState.startFunctionAnalysis(funcOp);
// Gather equivalence info for CallOps.
equivalenceAnalysis(funcOp, aliasInfo, funcState);
equivalenceAnalysis(funcOp, aliasInfo, state);
// Analyze funcOp.
if (failed(analyzeOp(funcOp, state)))

View File

@ -196,8 +196,9 @@ func.func @call_func_with_non_tensor_return(
// CHECK: %[[call:.*]] = call @inner_func(%[[casted]])
%0, %1 = call @inner_func(%t0) : (tensor<?xf32>) -> (tensor<?xf32>, f32)
// Note: The tensor return value has folded away.
// CHECK: return %[[call]] : f32
// Note: The tensor return value cannot fold away because the CallOp
// bufferized out-of-place.
// CHECK: return %[[call]], %[[alloc]] : f32, memref<?xf32>
return %1, %0 : f32, tensor<?xf32>
}