[mlir][bufferize][NFC] Clean up ModuleBufferizationState

* Store bbArg indices instead of BlockArguments, so that args can be changed during bufferizationn.
* Use type aliases for better readability.

Differential Revision: https://reviews.llvm.org/D123191
This commit is contained in:
Matthias Springer 2022-04-06 18:01:41 +09:00
parent 4d21497006
commit 7a50560354
1 changed files with 37 additions and 16 deletions

View File

@ -93,27 +93,48 @@ enum class FuncOpAnalysisState { NotAnalyzed, InProgress, Analyzed };
/// Extra analysis state that is required for bufferization of function
/// boundaries.
struct ModuleAnalysisState : public DialectAnalysisState {
/// A set of block argument indices.
using BbArgIndexSet = DenseSet<int64_t>;
/// A mapping of indices to indices.
using IndexMapping = DenseMap<int64_t, int64_t>;
/// A mapping of ReturnOp OpOperand indices to equivalent FuncOp BBArg
/// indices.
DenseMap<FuncOp, DenseMap<int64_t, int64_t>> equivalentFuncArgs;
DenseMap<FuncOp, IndexMapping> equivalentFuncArgs;
/// A set of all read BlockArguments of FuncOps.
// Note: BlockArgument knows about its owner, so we do not need to store
// FuncOps here.
DenseSet<BlockArgument> readBbArgs;
DenseMap<FuncOp, BbArgIndexSet> readBbArgs;
/// A set of all written-to BlockArguments of FuncOps.
DenseSet<BlockArgument> writtenBbArgs;
DenseMap<FuncOp, BbArgIndexSet> writtenBbArgs;
/// Keep track of which FuncOps are fully analyzed or currently being
/// analyzed.
DenseMap<FuncOp, FuncOpAnalysisState> analyzedFuncOps;
// A list of functions in the order in which they are analyzed + bufferized.
/// A list of functions in the order in which they are analyzed + bufferized.
SmallVector<FuncOp> orderedFuncOps;
// A mapping of FuncOps to their callers.
/// A mapping of FuncOps to their callers.
DenseMap<FuncOp, DenseSet<Operation *>> callerMap;
/// This function is called right before analyzing the given FuncOp. It
/// initializes the data structures for the FuncOp in this state object.
void startFunctionAnalysis(FuncOp funcOp) {
analyzedFuncOps[funcOp] = FuncOpAnalysisState::InProgress;
auto createdEquiv = equivalentFuncArgs.try_emplace(funcOp, IndexMapping());
auto createdRead = readBbArgs.try_emplace(funcOp, BbArgIndexSet());
auto createdWritten = writtenBbArgs.try_emplace(funcOp, BbArgIndexSet());
(void)createdEquiv;
(void)createdRead;
(void)createdWritten;
#ifndef NDEBUG
assert(createdEquiv.second && "equivalence info exists already");
assert(createdRead.second && "bbarg access info exists already");
assert(createdWritten.second && "bbarg access info exists already");
#endif // NDEBUG
}
};
} // namespace
@ -267,8 +288,8 @@ funcOpBbArgReadWriteAnalysis(Operation *op, AnalysisState &state,
// read + written.
if (funcOp.getBody().empty()) {
for (BlockArgument bbArg : funcOp.getArguments()) {
moduleState.readBbArgs.insert(bbArg);
moduleState.writtenBbArgs.insert(bbArg);
moduleState.readBbArgs[funcOp].insert(bbArg.getArgNumber());
moduleState.writtenBbArgs[funcOp].insert(bbArg.getArgNumber());
}
return success();
@ -282,9 +303,9 @@ funcOpBbArgReadWriteAnalysis(Operation *op, AnalysisState &state,
if (state.getOptions().testAnalysisOnly)
annotateFuncArgAccess(funcOp, bbArg, isRead, isWritten);
if (isRead)
moduleState.readBbArgs.insert(bbArg);
moduleState.readBbArgs[funcOp].insert(bbArg.getArgNumber());
if (isWritten)
moduleState.writtenBbArgs.insert(bbArg);
moduleState.writtenBbArgs[funcOp].insert(bbArg.getArgNumber());
}
return success();
@ -704,8 +725,8 @@ struct CallOpInterface
// FuncOp not analyzed yet. Assume that OpOperand is read.
return true;
return moduleState.readBbArgs.contains(
funcOp.getArgument(opOperand.getOperandNumber()));
return moduleState.readBbArgs.lookup(funcOp).contains(
opOperand.getOperandNumber());
}
bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
@ -719,8 +740,8 @@ struct CallOpInterface
// FuncOp not analyzed yet. Assume that OpOperand is written.
return true;
return moduleState.writtenBbArgs.contains(
funcOp.getArgument(opOperand.getOperandNumber()));
return moduleState.writtenBbArgs.lookup(funcOp).contains(
opOperand.getOperandNumber());
}
SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
@ -1010,7 +1031,7 @@ LogicalResult mlir::linalg::comprehensive_bufferize::runModuleBufferize(
continue;
// Now analyzing function.
moduleState.analyzedFuncOps[funcOp] = FuncOpAnalysisState::InProgress;
moduleState.startFunctionAnalysis(funcOp);
// Analyze funcOp.
if (failed(analyzeOp(funcOp, analysisState)))