[mlir][AsmPrinter] Remove recursion while SSA naming

Address the TODO of removing recursion while SSA naming.

Reviewed By: mehdi_amini

Differential Revision: https://reviews.llvm.org/D102226
This commit is contained in:
Chia-hung Duan 2021-05-12 11:21:25 +08:00
parent 7d101e0f6a
commit f653313d4a
1 changed files with 53 additions and 20 deletions

View File

@ -36,6 +36,9 @@
#include "llvm/Support/Endian.h"
#include "llvm/Support/Regex.h"
#include "llvm/Support/SaveAndRestore.h"
#include <tuple>
using namespace mlir;
using namespace mlir::detail;
@ -835,11 +838,58 @@ private:
SSANameState::SSANameState(
Operation *op,
DialectInterfaceCollection<OpAsmDialectInterface> &interfaces) {
llvm::ScopedHashTable<StringRef, char>::ScopeTy usedNamesScope(usedNames);
llvm::SaveAndRestore<unsigned> valueIDSaver(nextValueID);
llvm::SaveAndRestore<unsigned> argumentIDSaver(nextArgumentID);
llvm::SaveAndRestore<unsigned> conflictIDSaver(nextConflictID);
// The context includes nextValueID, nextArgumentID, nextConflictID and scoped
// HashTable.
using hashTableScopeTy = llvm::ScopedHashTable<StringRef, char>::ScopeTy;
// A namingContext carries the information inherits from parent region.
using namingContext =
std::tuple<Region *, unsigned, unsigned, unsigned, hashTableScopeTy *>;
// Allocator for hashTableScopeTy
llvm::BumpPtrAllocator allocator;
SmallVector<namingContext, 8> nameContext;
for (Region &region : op->getRegions())
nameContext.push_back(std::make_tuple(&region, nextValueID, nextArgumentID,
nextConflictID, nullptr));
numberValuesInOp(*op, interfaces);
for (auto &region : op->getRegions())
numberValuesInRegion(region, interfaces);
while (!nameContext.empty()) {
Region *region;
hashTableScopeTy *parentScope;
std::tie(region, nextValueID, nextArgumentID, nextConflictID, parentScope) =
nameContext.pop_back_val();
// When we switch from one subtree to another, pop the scopes(needless)
// until the parent scope.
while (usedNames.getCurScope() != parentScope) {
usedNames.getCurScope()->~hashTableScopeTy();
assert((usedNames.getCurScope() != nullptr || parentScope == nullptr) &&
"top level parentScope must be a nullptr");
}
// Add a scope for the current region.
auto *curNamesScope = allocator.Allocate<hashTableScopeTy>();
new (curNamesScope) hashTableScopeTy(usedNames);
numberValuesInRegion(*region, interfaces);
for (Block &block : *region) {
for (Operation &op : block)
for (Region &region : op.getRegions())
nameContext.push_back(std::make_tuple(&region, nextValueID,
nextArgumentID, nextConflictID,
curNamesScope));
}
}
// Manually remove all the scopes.
while (usedNames.getCurScope() != nullptr)
usedNames.getCurScope()->~hashTableScopeTy();
}
void SSANameState::printValueID(Value value, bool printResultNo,
@ -918,15 +968,6 @@ void SSANameState::shadowRegionArgs(Region &region, ValueRange namesToUse) {
void SSANameState::numberValuesInRegion(
Region &region,
DialectInterfaceCollection<OpAsmDialectInterface> &interfaces) {
// Save the current value ids to allow for numbering values in sibling regions
// the same.
llvm::SaveAndRestore<unsigned> valueIDSaver(nextValueID);
llvm::SaveAndRestore<unsigned> argumentIDSaver(nextArgumentID);
llvm::SaveAndRestore<unsigned> conflictIDSaver(nextConflictID);
// Push a new used names scope.
llvm::ScopedHashTable<StringRef, char>::ScopeTy usedNamesScope(usedNames);
// Number the values within this region in a breadth-first order.
unsigned nextBlockID = 0;
for (auto &block : region) {
@ -935,14 +976,6 @@ void SSANameState::numberValuesInRegion(
blockIDs[&block] = nextBlockID++;
numberValuesInBlock(block, interfaces);
}
// After that we traverse the nested regions.
// TODO: Rework this loop to not use recursion.
for (auto &block : region) {
for (auto &op : block)
for (auto &nestedRegion : op.getRegions())
numberValuesInRegion(nestedRegion, interfaces);
}
}
void SSANameState::numberValuesInBlock(