[Dominators][AMDGPU] Don't use virtual exit node in findNearestCommonDominator. Cleanup MachinePostDominators.

Summary:
This patch fixes a bug that originated from passing a virtual exit block (nullptr) to `MachinePostDominatorTee::findNearestCommonDominator` and resulted in assertion failures inside its callee. It also applies a small cleanup to the class.

The patch introduces a new function in PDT that given a list of `MachineBasicBlock`s finds their NCD. The new overload of `findNearestCommonDominator` handles virtual root correctly.

Note that similar handling of virtual root nodes is not necessary in (forward) `DominatorTree`s, as right now they don't use virtual roots.

Reviewers: tstellar, tpr, nhaehnle, arsenm, NutshellySima, grosser, hliao

Reviewed By: hliao

Subscribers: hliao, kzhuravl, jvesely, wdng, yaxunl, dstuttard, t-tye, hiraditya, llvm-commits

Tags: #amdgpu, #llvm

Differential Revision: https://reviews.llvm.org/D67974

llvm-svn: 372874
This commit is contained in:
Jakub Kuderski 2019-09-25 14:04:36 +00:00
parent c5d90e4b5c
commit 269bd15c68
4 changed files with 65 additions and 46 deletions

View File

@ -16,68 +16,75 @@
#include "llvm/CodeGen/MachineDominators.h" #include "llvm/CodeGen/MachineDominators.h"
#include "llvm/CodeGen/MachineFunctionPass.h" #include "llvm/CodeGen/MachineFunctionPass.h"
#include <memory>
namespace llvm { namespace llvm {
/// ///
/// PostDominatorTree Class - Concrete subclass of DominatorTree that is used /// MachinePostDominatorTree - an analysis pass wrapper for DominatorTree
/// to compute the post-dominator tree. /// used to compute the post-dominator tree for MachineFunctions.
/// ///
struct MachinePostDominatorTree : public MachineFunctionPass { class MachinePostDominatorTree : public MachineFunctionPass {
private: using PostDomTreeT = PostDomTreeBase<MachineBasicBlock>;
PostDomTreeBase<MachineBasicBlock> *DT; std::unique_ptr<PostDomTreeT> PDT;
public: public:
static char ID; static char ID;
MachinePostDominatorTree(); MachinePostDominatorTree();
~MachinePostDominatorTree() override;
FunctionPass *createMachinePostDominatorTreePass(); FunctionPass *createMachinePostDominatorTreePass();
const SmallVectorImpl<MachineBasicBlock *> &getRoots() const { const SmallVectorImpl<MachineBasicBlock *> &getRoots() const {
return DT->getRoots(); return PDT->getRoots();
} }
MachineDomTreeNode *getRootNode() const { MachineDomTreeNode *getRootNode() const { return PDT->getRootNode(); }
return DT->getRootNode();
}
MachineDomTreeNode *operator[](MachineBasicBlock *BB) const { MachineDomTreeNode *operator[](MachineBasicBlock *BB) const {
return DT->getNode(BB); return PDT->getNode(BB);
} }
MachineDomTreeNode *getNode(MachineBasicBlock *BB) const { MachineDomTreeNode *getNode(MachineBasicBlock *BB) const {
return DT->getNode(BB); return PDT->getNode(BB);
} }
bool dominates(const MachineDomTreeNode *A, bool dominates(const MachineDomTreeNode *A,
const MachineDomTreeNode *B) const { const MachineDomTreeNode *B) const {
return DT->dominates(A, B); return PDT->dominates(A, B);
} }
bool dominates(const MachineBasicBlock *A, const MachineBasicBlock *B) const { bool dominates(const MachineBasicBlock *A, const MachineBasicBlock *B) const {
return DT->dominates(A, B); return PDT->dominates(A, B);
} }
bool properlyDominates(const MachineDomTreeNode *A, bool properlyDominates(const MachineDomTreeNode *A,
const MachineDomTreeNode *B) const { const MachineDomTreeNode *B) const {
return DT->properlyDominates(A, B); return PDT->properlyDominates(A, B);
} }
bool properlyDominates(const MachineBasicBlock *A, bool properlyDominates(const MachineBasicBlock *A,
const MachineBasicBlock *B) const { const MachineBasicBlock *B) const {
return DT->properlyDominates(A, B); return PDT->properlyDominates(A, B);
}
bool isVirtualRoot(const MachineDomTreeNode *Node) const {
return PDT->isVirtualRoot(Node);
} }
MachineBasicBlock *findNearestCommonDominator(MachineBasicBlock *A, MachineBasicBlock *findNearestCommonDominator(MachineBasicBlock *A,
MachineBasicBlock *B) { MachineBasicBlock *B) const {
return DT->findNearestCommonDominator(A, B); return PDT->findNearestCommonDominator(A, B);
} }
/// Returns the nearest common dominator of the given blocks.
/// If that tree node is a virtual root, a nullptr will be returned.
MachineBasicBlock *
findNearestCommonDominator(ArrayRef<MachineBasicBlock *> Blocks) const;
bool runOnMachineFunction(MachineFunction &MF) override; bool runOnMachineFunction(MachineFunction &MF) override;
void getAnalysisUsage(AnalysisUsage &AU) const override; void getAnalysisUsage(AnalysisUsage &AU) const override;
void releaseMemory() override { PDT.reset(nullptr); }
void print(llvm::raw_ostream &OS, const Module *M = nullptr) const override; void print(llvm::raw_ostream &OS, const Module *M = nullptr) const override;
}; };
} //end of namespace llvm } //end of namespace llvm

View File

@ -22,7 +22,7 @@
namespace llvm { namespace llvm {
struct MachinePostDominatorTree; class MachinePostDominatorTree;
class MachineRegion; class MachineRegion;
class MachineRegionNode; class MachineRegionNode;
class MachineRegionInfo; class MachineRegionInfo;

View File

@ -13,6 +13,8 @@
#include "llvm/CodeGen/MachinePostDominators.h" #include "llvm/CodeGen/MachinePostDominators.h"
#include "llvm/ADT/STLExtras.h"
using namespace llvm; using namespace llvm;
namespace llvm { namespace llvm {
@ -25,33 +27,43 @@ char MachinePostDominatorTree::ID = 0;
INITIALIZE_PASS(MachinePostDominatorTree, "machinepostdomtree", INITIALIZE_PASS(MachinePostDominatorTree, "machinepostdomtree",
"MachinePostDominator Tree Construction", true, true) "MachinePostDominator Tree Construction", true, true)
MachinePostDominatorTree::MachinePostDominatorTree() : MachineFunctionPass(ID) { MachinePostDominatorTree::MachinePostDominatorTree()
: MachineFunctionPass(ID), PDT(nullptr) {
initializeMachinePostDominatorTreePass(*PassRegistry::getPassRegistry()); initializeMachinePostDominatorTreePass(*PassRegistry::getPassRegistry());
DT = new PostDomTreeBase<MachineBasicBlock>();
} }
FunctionPass * FunctionPass *MachinePostDominatorTree::createMachinePostDominatorTreePass() {
MachinePostDominatorTree::createMachinePostDominatorTreePass() {
return new MachinePostDominatorTree(); return new MachinePostDominatorTree();
} }
bool bool MachinePostDominatorTree::runOnMachineFunction(MachineFunction &F) {
MachinePostDominatorTree::runOnMachineFunction(MachineFunction &F) { PDT = std::make_unique<PostDomTreeT>();
DT->recalculate(F); PDT->recalculate(F);
return false; return false;
} }
MachinePostDominatorTree::~MachinePostDominatorTree() { void MachinePostDominatorTree::getAnalysisUsage(AnalysisUsage &AU) const {
delete DT;
}
void
MachinePostDominatorTree::getAnalysisUsage(AnalysisUsage &AU) const {
AU.setPreservesAll(); AU.setPreservesAll();
MachineFunctionPass::getAnalysisUsage(AU); MachineFunctionPass::getAnalysisUsage(AU);
} }
void MachineBasicBlock *MachinePostDominatorTree::findNearestCommonDominator(
MachinePostDominatorTree::print(llvm::raw_ostream &OS, const Module *M) const { ArrayRef<MachineBasicBlock *> Blocks) const {
DT->print(OS); assert(!Blocks.empty());
MachineBasicBlock *NCD = Blocks.front();
for (MachineBasicBlock *BB : Blocks.drop_front()) {
NCD = PDT->findNearestCommonDominator(NCD, BB);
// Stop when the root is reached.
if (PDT->isVirtualRoot(PDT->getNode(NCD)))
return nullptr;
}
return NCD;
}
void MachinePostDominatorTree::print(llvm::raw_ostream &OS,
const Module *M) const {
PDT->print(OS);
} }

View File

@ -589,12 +589,12 @@ void SILowerI1Copies::lowerPhis() {
// Phis in a loop that are observed outside the loop receive a simple but // Phis in a loop that are observed outside the loop receive a simple but
// conservatively correct treatment. // conservatively correct treatment.
MachineBasicBlock *PostDomBound = &MBB; std::vector<MachineBasicBlock *> DomBlocks = {&MBB};
for (MachineInstr &Use : MRI->use_instructions(DstReg)) { for (MachineInstr &Use : MRI->use_instructions(DstReg))
PostDomBound = DomBlocks.push_back(Use.getParent());
PDT->findNearestCommonDominator(PostDomBound, Use.getParent());
}
MachineBasicBlock *PostDomBound =
PDT->findNearestCommonDominator(DomBlocks);
unsigned FoundLoopLevel = LF.findLoop(PostDomBound); unsigned FoundLoopLevel = LF.findLoop(PostDomBound);
SSAUpdater.Initialize(DstReg); SSAUpdater.Initialize(DstReg);
@ -711,12 +711,12 @@ void SILowerI1Copies::lowerCopiesToI1() {
// Defs in a loop that are observed outside the loop must be transformed // Defs in a loop that are observed outside the loop must be transformed
// into appropriate bit manipulation. // into appropriate bit manipulation.
MachineBasicBlock *PostDomBound = &MBB; std::vector<MachineBasicBlock *> DomBlocks = {&MBB};
for (MachineInstr &Use : MRI->use_instructions(DstReg)) { for (MachineInstr &Use : MRI->use_instructions(DstReg))
PostDomBound = DomBlocks.push_back(Use.getParent());
PDT->findNearestCommonDominator(PostDomBound, Use.getParent());
}
MachineBasicBlock *PostDomBound =
PDT->findNearestCommonDominator(DomBlocks);
unsigned FoundLoopLevel = LF.findLoop(PostDomBound); unsigned FoundLoopLevel = LF.findLoop(PostDomBound);
if (FoundLoopLevel) { if (FoundLoopLevel) {
SSAUpdater.Initialize(DstReg); SSAUpdater.Initialize(DstReg);