llvm-project/llvm/lib/Analysis/SyncDependenceAnalysis.cpp

460 lines
15 KiB
C++

//===--- SyncDependenceAnalysis.cpp - Compute Control Divergence Effects --===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements an algorithm that returns for a divergent branch
// the set of basic blocks whose phi nodes become divergent due to divergent
// control. These are the blocks that are reachable by two disjoint paths from
// the branch or loop exits that have a reaching path that is disjoint from a
// path to the loop latch.
//
// The SyncDependenceAnalysis is used in the DivergenceAnalysis to model
// control-induced divergence in phi nodes.
//
// -- Summary --
// The SyncDependenceAnalysis lazily computes sync dependences [3].
// The analysis evaluates the disjoint path criterion [2] by a reduction
// to SSA construction. The SSA construction algorithm is implemented as
// a simple data-flow analysis [1].
//
// [1] "A Simple, Fast Dominance Algorithm", SPI '01, Cooper, Harvey and Kennedy
// [2] "Efficiently Computing Static Single Assignment Form
// and the Control Dependence Graph", TOPLAS '91,
// Cytron, Ferrante, Rosen, Wegman and Zadeck
// [3] "Improving Performance of OpenCL on CPUs", CC '12, Karrenberg and Hack
// [4] "Divergence Analysis", TOPLAS '13, Sampaio, Souza, Collange and Pereira
//
// -- Sync dependence --
// Sync dependence [4] characterizes the control flow aspect of the
// propagation of branch divergence. For example,
//
// %cond = icmp slt i32 %tid, 10
// br i1 %cond, label %then, label %else
// then:
// br label %merge
// else:
// br label %merge
// merge:
// %a = phi i32 [ 0, %then ], [ 1, %else ]
//
// Suppose %tid holds the thread ID. Although %a is not data dependent on %tid
// because %tid is not on its use-def chains, %a is sync dependent on %tid
// because the branch "br i1 %cond" depends on %tid and affects which value %a
// is assigned to.
//
// -- Reduction to SSA construction --
// There are two disjoint paths from A to X, if a certain variant of SSA
// construction places a phi node in X under the following set-up scheme [2].
//
// This variant of SSA construction ignores incoming undef values.
// That is paths from the entry without a definition do not result in
// phi nodes.
//
//        entry
//      /      \
//     A        \
//   /   \       Y
//  B     C     /
//   \   /  \  /
//     D     E
//      \   /
//        F
// Assume that A contains a divergent branch. We are interested
// in the set of all blocks where each block is reachable from A
// via two disjoint paths. This would be the set {D, F} in this
// case.
// To generally reduce this query to SSA construction we introduce
// a virtual variable x and assign to x different values in each
// successor block of A.
//           entry
//         /      \
//        A        \
//      /   \       Y
// x = 0   x = 1   /
//      \  /   \  /
//        D     E
//         \   /
//           F
// Our flavor of SSA construction for x will construct the following
//            entry
//          /      \
//         A        \
//       /   \       Y
// x0 = 0   x1 = 1  /
//       \   /   \ /
//      x2=phi    E
//         \   /
//          x3=phi
// The blocks D and F contain phi nodes and are thus each reachable
// by two disjoint paths from A.
//
// -- Remarks --
// In case of loop exits we need to check the disjoint path criterion for loops
// [2]. To this end, we check whether the definition of x differs between the
// loop exit and the loop header (_after_ SSA construction).
//
//===----------------------------------------------------------------------===//
#include "llvm/Analysis/SyncDependenceAnalysis.h"
#include "llvm/ADT/PostOrderIterator.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/Analysis/PostDominators.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/CFG.h"
#include "llvm/IR/Dominators.h"
#include "llvm/IR/Function.h"
#include <algorithm>
#include <functional>
#include <memory>
#include <set>
#include <stack>
#include <unordered_set>
#include <vector>
#define DEBUG_TYPE "sync-dependence"
// The SDA algorithm operates on a modified CFG - we modify the edges leaving
// loop headers as follows:
//
// * We remove all edges leaving all loop headers.
// * We add additional edges from the loop headers to their exit blocks.
//
// The modification is virtual, that is whenever we visit a loop header we
// pretend it had different successors.
namespace {
using namespace llvm;
// Custom Post-Order Traversal
//
// We cannot use the vanilla (R)PO computation of LLVM because:
// * We (virtually) modify the CFG.
// * We want a loop-compact block enumeration, that is the numbers assigned by
//   the traversal to the blocks of a loop are an interval.
// Callback invoked with a basic block at the moment the traversal finalizes it.
using POCB = std::function<void(const BasicBlock &)>;
// Set of blocks the traversal has already finalized.
using VisitedSet = std::set<const BasicBlock *>;
// Explicit work stack of blocks that still have to be processed.
using BlockStack = std::vector<const BasicBlock *>;
// Forward declaration: computeStackPO calls computeLoopPO for nested loops,
// and computeLoopPO in turn calls computeStackPO for the loop body.
static void computeLoopPO(const LoopInfo &LI, Loop &Loop, POCB CallBack,
VisitedSet &Finalized);
// Compute the modified post order of one region: the whole function when
// \p Loop is null, otherwise the blocks of \p Loop other than its header.
//
// \p Stack     Work list, seeded by the caller with the region's start blocks.
// \p LI        Loop info used to look up the innermost loop of each block.
// \p Loop      Loop whose region is traversed, or nullptr for the top level.
// \p CallBack  Invoked for every block at the moment it is finalized.
// \p Finalized Blocks that have already been finalized (shared across calls).
//
// A block whose innermost loop differs from \p Loop is not expanded through
// its real successors. The nested loop is treated as a single node whose
// successors are its unique exit blocks; once all of those are finalized the
// loop is emitted as a unit by computeLoopPO.
static void computeStackPO(BlockStack &Stack, const LoopInfo &LI, Loop *Loop,
                           POCB CallBack, VisitedSet &Finalized) {
  const auto *LoopHeader = Loop ? Loop->getHeader() : nullptr;

  while (!Stack.empty()) {
    const auto *NextBB = Stack.back();

    // The same block can be pushed by several predecessors before it is
    // processed, so stale entries for finalized blocks may remain on the
    // stack. Drop them here: re-processing a nested loop header would run
    // computeLoopPO again and report the header to CallBack a second time.
    if (Finalized.count(NextBB)) {
      Stack.pop_back();
      continue;
    }

    auto *NestedLoop = LI.getLoopFor(NextBB);
    bool IsNestedLoop = NestedLoop != Loop;

    // Treat the loop as a node
    if (IsNestedLoop) {
      SmallVector<BasicBlock *, 3> NestedExits;
      NestedLoop->getUniqueExitBlocks(NestedExits);
      bool PushedNodes = false;
      for (const auto *NestedExitBB : NestedExits) {
        // Never walk back into the header of the region being traversed.
        if (NestedExitBB == LoopHeader)
          continue;
        // Exits that leave the current region belong to an outer traversal.
        if (Loop && !Loop->contains(NestedExitBB))
          continue;
        if (Finalized.count(NestedExitBB))
          continue;
        PushedNodes = true;
        Stack.push_back(NestedExitBB);
      }
      if (!PushedNodes) {
        // All loop exits finalized -> finish this node
        Stack.pop_back();
        computeLoopPO(LI, *NestedLoop, CallBack, Finalized);
      }
      continue;
    }

    // DAG-style
    bool PushedNodes = false;
    for (const auto *SuccBB : successors(NextBB)) {
      if (SuccBB == LoopHeader)
        continue;
      if (Loop && !Loop->contains(SuccBB))
        continue;
      if (Finalized.count(SuccBB))
        continue;
      PushedNodes = true;
      Stack.push_back(SuccBB);
    }
    if (!PushedNodes) {
      // Every relevant successor is finalized, so this block is next in the
      // post order. It cannot be finalized already (checked on entry).
      Stack.pop_back();
      Finalized.insert(NextBB);
      CallBack(*NextBB);
    }
  }
}
// Enumerate the blocks of \p F, starting from its entry block, in the
// modified post order and invoke \p CallBack on each block as it is finalized.
static void computeTopLevelPO(Function &F, const LoopInfo &LI, POCB CallBack) {
  BlockStack WorkList;
  WorkList.reserve(24); // FIXME made-up number
  WorkList.push_back(&F.getEntryBlock());

  // A null loop selects the top-level (whole function) region.
  VisitedSet Done;
  computeStackPO(WorkList, LI, /*Loop=*/nullptr, CallBack, Done);
}
static void computeLoopPO(const LoopInfo &LI, Loop &Loop, POCB CallBack,
VisitedSet &Finalized) {
/// Call CallBack on all loop blocks.
std::vector<const BasicBlock *> Stack;
const auto *LoopHeader = Loop.getHeader();
// Visit the header last
Finalized.insert(LoopHeader);
CallBack(*LoopHeader);
// Initialize with immediate successors
for (const auto *BB : successors(LoopHeader)) {
if (!Loop.contains(BB))
continue;
if (BB == LoopHeader)
continue;
Stack.push_back(BB);
}
// Compute PO inside region
computeStackPO(Stack, LI, &Loop, CallBack, Finalized);
}
} // namespace
namespace llvm {
// Default-constructed descriptor that getJoinBlocks returns for terminators
// with at most one successor.
ControlDivergenceDesc SyncDependenceAnalysis::EmptyDivergenceDesc;
// Store the analyses and eagerly build the modified post order (LoopPO) of
// the function that owns the dominator tree's root block.
SyncDependenceAnalysis::SyncDependenceAnalysis(const DominatorTree &DT,
                                               const PostDominatorTree &PDT,
                                               const LoopInfo &LI)
    : DT(DT), PDT(PDT), LI(LI) {
  Function &F = *DT.getRoot()->getParent();
  // Each finalized block is appended to LoopPO in traversal order.
  auto RecordBlock = [this](const BasicBlock &BB) { LoopPO.appendBlock(BB); };
  computeTopLevelPO(F, LI, RecordBlock);
}
// Out-of-line destructor; no work beyond destroying the members.
SyncDependenceAnalysis::~SyncDependenceAnalysis() = default;
// divergence propagator for reducible CFGs
struct DivergencePropagator {
const ModifiedPO &LoopPOT;
const DominatorTree &DT;
const PostDominatorTree &PDT;
const LoopInfo &LI;
const BasicBlock &DivTermBlock;
// * if BlockLabels[IndexOf(B)] == C then C is the dominating definition at
// block B
// * if BlockLabels[IndexOf(B)] ~ undef then we haven't seen B yet
// * if BlockLabels[IndexOf(B)] == B then B is a join point of disjoint paths
// from X or B is an immediate successor of X (initial value).
using BlockLabelVec = std::vector<const BasicBlock *>;
BlockLabelVec BlockLabels;
// divergent join and loop exit descriptor.
std::unique_ptr<ControlDivergenceDesc> DivDesc;
DivergencePropagator(const ModifiedPO &LoopPOT, const DominatorTree &DT,
const PostDominatorTree &PDT, const LoopInfo &LI,
const BasicBlock &DivTermBlock)
: LoopPOT(LoopPOT), DT(DT), PDT(PDT), LI(LI), DivTermBlock(DivTermBlock),
BlockLabels(LoopPOT.size(), nullptr),
DivDesc(new ControlDivergenceDesc) {}
void printDefs(raw_ostream &Out) {
Out << "Propagator::BlockLabels {\n";
for (int BlockIdx = (int)BlockLabels.size() - 1; BlockIdx > 0; --BlockIdx) {
const auto *Label = BlockLabels[BlockIdx];
Out << LoopPOT.getBlockAt(BlockIdx)->getName().str() << "(" << BlockIdx
<< ") : ";
if (!Label) {
Out << "<null>\n";
} else {
Out << Label->getName() << "\n";
}
}
Out << "}\n";
}
// Push a definition (\p PushedLabel) to \p SuccBlock and return whether this
// causes a divergent join.
bool computeJoin(const BasicBlock &SuccBlock, const BasicBlock &PushedLabel) {
auto SuccIdx = LoopPOT.getIndexOf(SuccBlock);
// unset or same reaching label
const auto *OldLabel = BlockLabels[SuccIdx];
if (!OldLabel || (OldLabel == &PushedLabel)) {
BlockLabels[SuccIdx] = &PushedLabel;
return false;
}
// Update the definition
BlockLabels[SuccIdx] = &SuccBlock;
return true;
}
// visiting a virtual loop exit edge from the loop header --> temporal
// divergence on join
bool visitLoopExitEdge(const BasicBlock &ExitBlock,
const BasicBlock &DefBlock, bool FromParentLoop) {
// Pushing from a non-parent loop cannot cause temporal divergence.
if (!FromParentLoop)
return visitEdge(ExitBlock, DefBlock);
if (!computeJoin(ExitBlock, DefBlock))
return false;
// Identified a divergent loop exit
DivDesc->LoopDivBlocks.insert(&ExitBlock);
LLVM_DEBUG(dbgs() << "\tDivergent loop exit: " << ExitBlock.getName()
<< "\n");
return true;
}
// process \p SuccBlock with reaching definition \p DefBlock
bool visitEdge(const BasicBlock &SuccBlock, const BasicBlock &DefBlock) {
if (!computeJoin(SuccBlock, DefBlock))
return false;
// Divergent, disjoint paths join.
DivDesc->JoinDivBlocks.insert(&SuccBlock);
LLVM_DEBUG(dbgs() << "\tDivergent join: " << SuccBlock.getName());
return true;
}
std::unique_ptr<ControlDivergenceDesc> computeJoinPoints() {
assert(DivDesc);
LLVM_DEBUG(dbgs() << "SDA:computeJoinPoints: " << DivTermBlock.getName()
<< "\n");
const auto *DivBlockLoop = LI.getLoopFor(&DivTermBlock);
// Early stopping criterion
int FloorIdx = LoopPOT.size() - 1;
const BasicBlock *FloorLabel = nullptr;
// bootstrap with branch targets
int BlockIdx = 0;
for (const auto *SuccBlock : successors(&DivTermBlock)) {
auto SuccIdx = LoopPOT.getIndexOf(*SuccBlock);
BlockLabels[SuccIdx] = SuccBlock;
// Find the successor with the highest index to start with
BlockIdx = std::max<int>(BlockIdx, SuccIdx);
FloorIdx = std::min<int>(FloorIdx, SuccIdx);
// Identify immediate divergent loop exits
if (!DivBlockLoop)
continue;
const auto *BlockLoop = LI.getLoopFor(SuccBlock);
if (BlockLoop && DivBlockLoop->contains(BlockLoop))
continue;
DivDesc->LoopDivBlocks.insert(SuccBlock);
LLVM_DEBUG(dbgs() << "\tImmediate divergent loop exit: "
<< SuccBlock->getName() << "\n");
}
// propagate definitions at the immediate successors of the node in RPO
for (; BlockIdx >= FloorIdx; --BlockIdx) {
LLVM_DEBUG(dbgs() << "Before next visit:\n"; printDefs(dbgs()));
// Any label available here
const auto *Label = BlockLabels[BlockIdx];
if (!Label)
continue;
// Ok. Get the block
const auto *Block = LoopPOT.getBlockAt(BlockIdx);
LLVM_DEBUG(dbgs() << "SDA::joins. visiting " << Block->getName() << "\n");
auto *BlockLoop = LI.getLoopFor(Block);
bool IsLoopHeader = BlockLoop && BlockLoop->getHeader() == Block;
bool CausedJoin = false;
int LoweredFloorIdx = FloorIdx;
if (IsLoopHeader) {
// Disconnect from immediate successors and propagate directly to loop
// exits.
SmallVector<BasicBlock *, 4> BlockLoopExits;
BlockLoop->getExitBlocks(BlockLoopExits);
bool IsParentLoop = BlockLoop->contains(&DivTermBlock);
for (const auto *BlockLoopExit : BlockLoopExits) {
CausedJoin |= visitLoopExitEdge(*BlockLoopExit, *Label, IsParentLoop);
LoweredFloorIdx = std::min<int>(LoweredFloorIdx,
LoopPOT.getIndexOf(*BlockLoopExit));
}
} else {
// Acyclic successor case
for (const auto *SuccBlock : successors(Block)) {
CausedJoin |= visitEdge(*SuccBlock, *Label);
LoweredFloorIdx =
std::min<int>(LoweredFloorIdx, LoopPOT.getIndexOf(*SuccBlock));
}
}
// Floor update
if (CausedJoin) {
// 1. Different labels pushed to successors
FloorIdx = LoweredFloorIdx;
} else if (FloorLabel != Label) {
// 2. No join caused BUT we pushed a label that is different than the
// last pushed label
FloorIdx = LoweredFloorIdx;
FloorLabel = Label;
}
}
LLVM_DEBUG(dbgs() << "SDA::joins. After propagation:\n"; printDefs(dbgs()));
return std::move(DivDesc);
}
};
#ifndef NDEBUG
// Print the names of the blocks in \p Blocks to \p Out as a bracketed list,
// with ListSeparator's default separator between entries.
static void printBlockSet(ConstBlockSet &Blocks, raw_ostream &Out) {
  ListSeparator Sep;
  Out << "[";
  for (const auto *Member : Blocks) {
    Out << Sep;
    Out << Member->getName();
  }
  Out << "]";
}
#endif
// Return the control-divergence descriptor (divergent join blocks and
// divergent loop exits) for the terminator \p Term. The descriptor is computed
// on the first request and served from CachedControlDivDescs afterwards.
const ControlDivergenceDesc &
SyncDependenceAnalysis::getJoinBlocks(const Instruction &Term) {
  // trivial case: fewer than two successors, return the shared empty result
  if (Term.getNumSuccessors() <= 1)
    return EmptyDivergenceDesc;

  // already available in cache?
  auto CacheIt = CachedControlDivDescs.find(&Term);
  if (CacheIt != CachedControlDivDescs.end())
    return *CacheIt->second;

  // compute all join points
  // Special handling of divergent loop exits is not needed for LCSSA
  const BasicBlock &TermBlock = *Term.getParent();
  DivergencePropagator Propagator(LoopPO, DT, PDT, LI, TermBlock);
  std::unique_ptr<ControlDivergenceDesc> Desc = Propagator.computeJoinPoints();

  LLVM_DEBUG(dbgs() << "Result (" << TermBlock.getName() << "):\n";
             dbgs() << "JoinDivBlocks: ";
             printBlockSet(Desc->JoinDivBlocks, dbgs());
             dbgs() << "\nLoopDivBlocks: ";
             printBlockSet(Desc->LoopDivBlocks, dbgs()); dbgs() << "\n";);

  // The cache takes ownership; the lookup above guarantees a fresh entry.
  auto Inserted = CachedControlDivDescs.emplace(&Term, std::move(Desc));
  assert(Inserted.second);
  return *Inserted.first->second;
}
} // namespace llvm