forked from OSchip/llvm-project
410 lines
14 KiB
C++
410 lines
14 KiB
C++
//===---- DivergenceAnalysis.cpp --- Divergence Analysis Implementation ----==//
|
|
//
|
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
//
|
|
// This file implements a general divergence analysis for loop vectorization
|
|
// and GPU programs. It determines which branches and values in a loop or GPU
|
|
// program are divergent. It can help branch optimizations such as jump
|
|
// threading and loop unswitching to make better decisions.
|
|
//
|
|
// GPU programs typically use the SIMD execution model, where multiple threads
|
|
// in the same execution group have to execute in lock-step. Therefore, if the
|
|
// code contains divergent branches (i.e., threads in a group do not agree on
|
|
// which path of the branch to take), the group of threads has to execute all
|
|
// the paths from that branch with different subsets of threads enabled until
|
|
// they re-converge.
|
|
//
|
|
// Due to this execution model, some optimizations such as jump
|
|
// threading and loop unswitching can interfere with thread re-convergence.
|
|
// Therefore, an analysis that computes which branches in a GPU program are
|
|
// divergent can help the compiler to selectively run these optimizations.
|
|
//
|
|
// This implementation is derived from the Vectorization Analysis of the
|
|
// Region Vectorizer (RV). The analysis is based on the approach described in
|
|
//
|
|
// An abstract interpretation for SPMD divergence
|
|
// on reducible control flow graphs.
|
|
// Julian Rosemann, Simon Moll and Sebastian Hack
|
|
// POPL '21
|
|
//
|
|
// This implementation is generic in the sense that it does
|
|
// not itself identify original sources of divergence.
|
|
// Instead, specialized adapter classes (LoopDivergenceAnalysis for loops and
|
|
// DivergenceAnalysis for functions) identify the sources of divergence
|
|
// (e.g., special variables that hold the thread ID or the iteration variable).
|
|
//
|
|
// The generic implementation propagates divergence to variables that are data
|
|
// or sync dependent on a source of divergence.
|
|
//
|
|
// While data dependency is a well-known concept, the notion of sync dependency
|
|
// is worth more explanation. Sync dependence characterizes the control flow
|
|
// aspect of the propagation of branch divergence. For example,
|
|
//
|
|
// %cond = icmp slt i32 %tid, 10
|
|
// br i1 %cond, label %then, label %else
|
|
// then:
|
|
// br label %merge
|
|
// else:
|
|
// br label %merge
|
|
// merge:
|
|
// %a = phi i32 [ 0, %then ], [ 1, %else ]
|
|
//
|
|
// Suppose %tid holds the thread ID. Although %a is not data dependent on %tid
|
|
// because %tid is not on its use-def chains, %a is sync dependent on %tid
|
|
// because the branch "br i1 %cond" depends on %tid and affects which value %a
|
|
// is assigned to.
|
|
//
|
|
// The sync dependence detection (which branch induces divergence in which join
|
|
// points) is implemented in the SyncDependenceAnalysis.
|
|
//
|
|
// The current implementation has the following limitations:
|
|
// 1. intra-procedural. It conservatively considers the arguments of a
|
|
// non-kernel-entry function and the return value of a function call as
|
|
// divergent.
|
|
// 2. memory as black box. It conservatively considers values loaded from
|
|
// generic or local addresses as divergent. This can be improved by leveraging
|
|
// pointer analysis and/or by modelling non-escaping memory objects in SSA
|
|
// as done in RV.
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
#include "llvm/Analysis/DivergenceAnalysis.h"
|
|
#include "llvm/ADT/PostOrderIterator.h"
|
|
#include "llvm/Analysis/CFG.h"
|
|
#include "llvm/Analysis/LoopInfo.h"
|
|
#include "llvm/Analysis/PostDominators.h"
|
|
#include "llvm/Analysis/TargetTransformInfo.h"
|
|
#include "llvm/IR/Dominators.h"
|
|
#include "llvm/IR/InstIterator.h"
|
|
#include "llvm/IR/Instructions.h"
|
|
#include "llvm/IR/Value.h"
|
|
#include "llvm/Support/Debug.h"
|
|
#include "llvm/Support/raw_ostream.h"
|
|
|
|
using namespace llvm;
|
|
|
|
#define DEBUG_TYPE "divergence"
|
|
|
|
// Sets up the analysis state. RegionLoop restricts the analyzed region to a
// single loop; when it is nullptr the whole function F is the region (see
// inRegion()). IsLCSSAForm selects the cheaper LCSSA-based loop-exit handling.
DivergenceAnalysisImpl::DivergenceAnalysisImpl(
    const Function &F, const Loop *RegionLoop, const DominatorTree &DT,
    const LoopInfo &LI, SyncDependenceAnalysis &SDA, bool IsLCSSAForm)
    : F(F),
      RegionLoop(RegionLoop),
      DT(DT),
      LI(LI),
      SDA(SDA),
      IsLCSSAForm(IsLCSSAForm) {}
|
|
|
|
bool DivergenceAnalysisImpl::markDivergent(const Value &DivVal) {
|
|
if (isAlwaysUniform(DivVal))
|
|
return false;
|
|
assert(isa<Instruction>(DivVal) || isa<Argument>(DivVal));
|
|
assert(!isAlwaysUniform(DivVal) && "cannot be a divergent");
|
|
return DivergentValues.insert(&DivVal).second;
|
|
}
|
|
|
|
void DivergenceAnalysisImpl::addUniformOverride(const Value &UniVal) {
|
|
UniformOverrides.insert(&UniVal);
|
|
}
|
|
|
|
bool DivergenceAnalysisImpl::isTemporalDivergent(
|
|
const BasicBlock &ObservingBlock, const Value &Val) const {
|
|
const auto *Inst = dyn_cast<const Instruction>(&Val);
|
|
if (!Inst)
|
|
return false;
|
|
// check whether any divergent loop carrying Val terminates before control
|
|
// proceeds to ObservingBlock
|
|
for (const auto *Loop = LI.getLoopFor(Inst->getParent());
|
|
Loop != RegionLoop && !Loop->contains(&ObservingBlock);
|
|
Loop = Loop->getParentLoop()) {
|
|
if (DivergentLoops.contains(Loop))
|
|
return true;
|
|
}
|
|
|
|
return false;
|
|
}
|
|
|
|
bool DivergenceAnalysisImpl::inRegion(const Instruction &I) const {
|
|
return I.getParent() && inRegion(*I.getParent());
|
|
}
|
|
|
|
bool DivergenceAnalysisImpl::inRegion(const BasicBlock &BB) const {
|
|
return RegionLoop ? RegionLoop->contains(&BB) : (BB.getParent() == &F);
|
|
}
|
|
|
|
void DivergenceAnalysisImpl::pushUsers(const Value &V) {
|
|
const auto *I = dyn_cast<const Instruction>(&V);
|
|
|
|
if (I && I->isTerminator()) {
|
|
analyzeControlDivergence(*I);
|
|
return;
|
|
}
|
|
|
|
for (const auto *User : V.users()) {
|
|
const auto *UserInst = dyn_cast<const Instruction>(User);
|
|
if (!UserInst)
|
|
continue;
|
|
|
|
// only compute divergent inside loop
|
|
if (!inRegion(*UserInst))
|
|
continue;
|
|
|
|
// All users of divergent values are immediate divergent
|
|
if (markDivergent(*UserInst))
|
|
Worklist.push_back(UserInst);
|
|
}
|
|
}
|
|
|
|
// Returns the instruction defining the value used by U if that instruction
// lies inside DivLoop, i.e. the value is carried out of the loop. Returns
// nullptr for non-instruction operands and for definitions outside DivLoop.
static const Instruction *getIfCarriedInstruction(const Use &U,
                                                  const Loop &DivLoop) {
  // Inspect the *used value*, not the Use object itself.
  const auto *I = dyn_cast_or_null<const Instruction>(U.get());
  if (!I)
    return nullptr;
  if (!DivLoop.contains(I))
    return nullptr;
  return I;
}
|
|
|
|
void DivergenceAnalysisImpl::analyzeTemporalDivergence(
|
|
const Instruction &I, const Loop &OuterDivLoop) {
|
|
if (isAlwaysUniform(I))
|
|
return;
|
|
if (isDivergent(I))
|
|
return;
|
|
|
|
LLVM_DEBUG(dbgs() << "Analyze temporal divergence: " << I.getName() << "\n");
|
|
assert((isa<PHINode>(I) || !IsLCSSAForm) &&
|
|
"In LCSSA form all users of loop-exiting defs are Phi nodes.");
|
|
for (const Use &Op : I.operands()) {
|
|
const auto *OpInst = getIfCarriedInstruction(Op, OuterDivLoop);
|
|
if (!OpInst)
|
|
continue;
|
|
if (markDivergent(I))
|
|
pushUsers(I);
|
|
return;
|
|
}
|
|
}
|
|
|
|
// marks all users of loop-carried values of the loop headed by LoopHeader as
|
|
// divergent
|
|
void DivergenceAnalysisImpl::analyzeLoopExitDivergence(
|
|
const BasicBlock &DivExit, const Loop &OuterDivLoop) {
|
|
// All users are in immediate exit blocks
|
|
if (IsLCSSAForm) {
|
|
for (const auto &Phi : DivExit.phis()) {
|
|
analyzeTemporalDivergence(Phi, OuterDivLoop);
|
|
}
|
|
return;
|
|
}
|
|
|
|
// For non-LCSSA we have to follow all live out edges wherever they may lead.
|
|
const BasicBlock &LoopHeader = *OuterDivLoop.getHeader();
|
|
SmallVector<const BasicBlock *, 8> TaintStack;
|
|
TaintStack.push_back(&DivExit);
|
|
|
|
// Otherwise potential users of loop-carried values could be anywhere in the
|
|
// dominance region of DivLoop (including its fringes for phi nodes)
|
|
DenseSet<const BasicBlock *> Visited;
|
|
Visited.insert(&DivExit);
|
|
|
|
do {
|
|
auto *UserBlock = TaintStack.pop_back_val();
|
|
|
|
// don't spread divergence beyond the region
|
|
if (!inRegion(*UserBlock))
|
|
continue;
|
|
|
|
assert(!OuterDivLoop.contains(UserBlock) &&
|
|
"irreducible control flow detected");
|
|
|
|
// phi nodes at the fringes of the dominance region
|
|
if (!DT.dominates(&LoopHeader, UserBlock)) {
|
|
// all PHI nodes of UserBlock become divergent
|
|
for (const auto &Phi : UserBlock->phis()) {
|
|
analyzeTemporalDivergence(Phi, OuterDivLoop);
|
|
}
|
|
continue;
|
|
}
|
|
|
|
// Taint outside users of values carried by OuterDivLoop.
|
|
for (const auto &I : *UserBlock) {
|
|
analyzeTemporalDivergence(I, OuterDivLoop);
|
|
}
|
|
|
|
// visit all blocks in the dominance region
|
|
for (const auto *SuccBlock : successors(UserBlock)) {
|
|
if (!Visited.insert(SuccBlock).second) {
|
|
continue;
|
|
}
|
|
TaintStack.push_back(SuccBlock);
|
|
}
|
|
} while (!TaintStack.empty());
|
|
}
|
|
|
|
void DivergenceAnalysisImpl::propagateLoopExitDivergence(
|
|
const BasicBlock &DivExit, const Loop &InnerDivLoop) {
|
|
LLVM_DEBUG(dbgs() << "\tpropLoopExitDiv " << DivExit.getName() << "\n");
|
|
|
|
// Find outer-most loop that does not contain \p DivExit
|
|
const Loop *DivLoop = &InnerDivLoop;
|
|
const Loop *OuterDivLoop = DivLoop;
|
|
const Loop *ExitLevelLoop = LI.getLoopFor(&DivExit);
|
|
const unsigned LoopExitDepth =
|
|
ExitLevelLoop ? ExitLevelLoop->getLoopDepth() : 0;
|
|
while (DivLoop && DivLoop->getLoopDepth() > LoopExitDepth) {
|
|
DivergentLoops.insert(DivLoop); // all crossed loops are divergent
|
|
OuterDivLoop = DivLoop;
|
|
DivLoop = DivLoop->getParentLoop();
|
|
}
|
|
LLVM_DEBUG(dbgs() << "\tOuter-most left loop: " << OuterDivLoop->getName()
|
|
<< "\n");
|
|
|
|
analyzeLoopExitDivergence(DivExit, *OuterDivLoop);
|
|
}
|
|
|
|
// this is a divergent join point - mark all phi nodes as divergent and push
|
|
// them onto the stack.
|
|
void DivergenceAnalysisImpl::taintAndPushPhiNodes(const BasicBlock &JoinBlock) {
|
|
LLVM_DEBUG(dbgs() << "taintAndPushPhiNodes in " << JoinBlock.getName()
|
|
<< "\n");
|
|
|
|
// ignore divergence outside the region
|
|
if (!inRegion(JoinBlock)) {
|
|
return;
|
|
}
|
|
|
|
// push non-divergent phi nodes in JoinBlock to the worklist
|
|
for (const auto &Phi : JoinBlock.phis()) {
|
|
if (isDivergent(Phi))
|
|
continue;
|
|
// FIXME Theoretically ,the 'undef' value could be replaced by any other
|
|
// value causing spurious divergence.
|
|
if (Phi.hasConstantOrUndefValue())
|
|
continue;
|
|
if (markDivergent(Phi))
|
|
Worklist.push_back(&Phi);
|
|
}
|
|
}
|
|
|
|
void DivergenceAnalysisImpl::analyzeControlDivergence(const Instruction &Term) {
|
|
LLVM_DEBUG(dbgs() << "analyzeControlDiv " << Term.getParent()->getName()
|
|
<< "\n");
|
|
|
|
// Don't propagate divergence from unreachable blocks.
|
|
if (!DT.isReachableFromEntry(Term.getParent()))
|
|
return;
|
|
|
|
const auto *BranchLoop = LI.getLoopFor(Term.getParent());
|
|
|
|
const auto &DivDesc = SDA.getJoinBlocks(Term);
|
|
|
|
// Iterate over all blocks now reachable by a disjoint path join
|
|
for (const auto *JoinBlock : DivDesc.JoinDivBlocks) {
|
|
taintAndPushPhiNodes(*JoinBlock);
|
|
}
|
|
|
|
assert(DivDesc.LoopDivBlocks.empty() || BranchLoop);
|
|
for (const auto *DivExitBlock : DivDesc.LoopDivBlocks) {
|
|
propagateLoopExitDivergence(*DivExitBlock, *BranchLoop);
|
|
}
|
|
}
|
|
|
|
void DivergenceAnalysisImpl::compute() {
|
|
// Initialize worklist.
|
|
auto DivValuesCopy = DivergentValues;
|
|
for (const auto *DivVal : DivValuesCopy) {
|
|
assert(isDivergent(*DivVal) && "Worklist invariant violated!");
|
|
pushUsers(*DivVal);
|
|
}
|
|
|
|
// All values on the Worklist are divergent.
|
|
// Their users may not have been updated yed.
|
|
while (!Worklist.empty()) {
|
|
const Instruction &I = *Worklist.back();
|
|
Worklist.pop_back();
|
|
|
|
// propagate value divergence to users
|
|
assert(isDivergent(I) && "Worklist invariant violated!");
|
|
pushUsers(I);
|
|
}
|
|
}
|
|
|
|
bool DivergenceAnalysisImpl::isAlwaysUniform(const Value &V) const {
|
|
return UniformOverrides.contains(&V);
|
|
}
|
|
|
|
bool DivergenceAnalysisImpl::isDivergent(const Value &V) const {
|
|
return DivergentValues.contains(&V);
|
|
}
|
|
|
|
bool DivergenceAnalysisImpl::isDivergentUse(const Use &U) const {
|
|
Value &V = *U.get();
|
|
Instruction &I = *cast<Instruction>(U.getUser());
|
|
return isDivergent(V) || isTemporalDivergent(*I.getParent(), V);
|
|
}
|
|
|
|
// Builds and runs the function-level divergence analysis for F.
// - Unless the caller guarantees reducibility (KnownReducible), the CFG is
//   checked first. If an irreducible cycle is found, only ContainsIrreducible
//   is set and no analysis objects are created.
// - Otherwise, sources of divergence and always-uniform values are seeded from
//   TTI, and the propagation is run.
DivergenceInfo::DivergenceInfo(Function &F, const DominatorTree &DT,
                               const PostDominatorTree &PDT, const LoopInfo &LI,
                               const TargetTransformInfo &TTI,
                               bool KnownReducible)
    : F(F) {
  if (!KnownReducible) {
    using RPOTraversal = ReversePostOrderTraversal<const Function *>;
    RPOTraversal FuncRPOT(&F);
    const bool HasIrreducibleCycle =
        containsIrreducibleCFG<const BasicBlock *, const RPOTraversal,
                               const LoopInfo>(FuncRPOT, LI);
    if (HasIrreducibleCycle) {
      ContainsIrreducible = true;
      return;
    }
  }

  SDA = std::make_unique<SyncDependenceAnalysis>(DT, PDT, LI);
  // No region loop: the whole function is analyzed, without LCSSA assumptions.
  DA = std::make_unique<DivergenceAnalysisImpl>(F, /* RegionLoop */ nullptr, DT,
                                                LI, *SDA, /* LCSSA */ false);

  // Seed instructions: divergence sources take precedence over uniformity.
  for (auto &I : instructions(F)) {
    if (TTI.isSourceOfDivergence(&I))
      DA->markDivergent(I);
    else if (TTI.isAlwaysUniform(&I))
      DA->addUniformOverride(I);
  }

  // Seed arguments that the target reports as divergent.
  for (const auto &Arg : F.args()) {
    if (TTI.isSourceOfDivergence(&Arg))
      DA->markDivergent(Arg);
  }

  DA->compute();
}
|
|
|
|
// Definition of DivergenceAnalysis's static AnalysisKey. Presumably this is the
// identity the analysis manager uses when FunctionAnalysisManager::getResult<
// DivergenceAnalysis> is requested (as in the printer pass below); confirm
// against the header.
AnalysisKey DivergenceAnalysis::Key;
|
|
|
|
// New-pass-manager entry point. It fetches the dominator tree, post-dominator
// tree, loop info and target info for F, then builds a DivergenceInfo that
// checks reducibility itself.
DivergenceAnalysis::Result
DivergenceAnalysis::run(Function &F, FunctionAnalysisManager &AM) {
  auto &DomTree = AM.getResult<DominatorTreeAnalysis>(F);
  auto &PostDomTree = AM.getResult<PostDominatorTreeAnalysis>(F);
  auto &Loops = AM.getResult<LoopAnalysis>(F);
  auto &TargetInfo = AM.getResult<TargetIRAnalysis>(F);

  const bool KnownReducible = false;
  return DivergenceInfo(F, DomTree, PostDomTree, Loops, TargetInfo,
                        KnownReducible);
}
|
|
|
|
// Prints the divergence analysis result for F to OS. Arguments are listed
// first, then each block's non-debug instructions; divergent entries are
// prefixed with "DIVERGENT: ". The pass changes nothing, so all analyses are
// preserved.
PreservedAnalyses
DivergenceAnalysisPrinterPass::run(Function &F, FunctionAnalysisManager &FAM) {
  auto &DI = FAM.getResult<DivergenceAnalysis>(F);
  OS << "'Divergence Analysis' for function '" << F.getName() << "':\n";

  // Without any divergence only the header line is printed.
  if (!DI.hasDivergence())
    return PreservedAnalyses::all();

  for (auto &Arg : F.args())
    OS << (DI.isDivergent(Arg) ? "DIVERGENT: " : " ") << Arg << "\n";

  for (const BasicBlock &BB : F) {
    OS << "\n " << BB.getName() << ":\n";
    for (const auto &I : BB.instructionsWithoutDebug())
      OS << (DI.isDivergent(I) ? "DIVERGENT: " : " ") << I << "\n";
  }

  return PreservedAnalyses::all();
}
|