llvm-project/llvm/lib/Analysis/SyntheticCountsUtils.cpp

123 lines
4.2 KiB
C++
Raw Normal View History

//===--- SyntheticCountsUtils.cpp - synthetic counts propagation utils ---===//
//
// The LLVM Compiler Infrastructure
//
// This file is distributed under the University of Illinois Open Source
// License. See LICENSE.TXT for details.
//
//===----------------------------------------------------------------------===//
//
// This file defines utilities for propagating synthetic counts.
//
//===----------------------------------------------------------------------===//
#include "llvm/Analysis/SyntheticCountsUtils.h"
#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/SCCIterator.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/Analysis/CallGraph.h"
#include "llvm/IR/CallSite.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/InstIterator.h"
#include "llvm/IR/Instructions.h"
using namespace llvm;
// Given a set of functions in an SCC, propagate entry counts to functions
// called by the SCC.
static void
propagateFromSCC(const SmallPtrSetImpl<Function *> &SCCFunctions,
function_ref<Scaled64(CallSite CS)> GetCallSiteRelFreq,
function_ref<uint64_t(Function *F)> GetCount,
function_ref<void(Function *F, uint64_t)> AddToCount) {
SmallVector<CallSite, 16> CallSites;
// Gather all callsites in the SCC.
auto GatherCallSites = [&]() {
for (auto *F : SCCFunctions) {
assert(F && !F->isDeclaration());
for (auto &I : instructions(F)) {
if (auto CS = CallSite(&I)) {
CallSites.push_back(CS);
}
}
}
};
GatherCallSites();
// Partition callsites so that the callsites that call functions in the same
// SCC come first.
auto Mid = partition(CallSites, [&](CallSite &CS) {
auto *Callee = CS.getCalledFunction();
if (Callee)
return SCCFunctions.count(Callee);
// FIXME: Use the !callees metadata to propagate counts through indirect
// calls.
return 0U;
});
// For functions in the same SCC, update the counts in two steps:
// 1. Compute the additional count for each function by propagating the counts
// along all incoming edges to the function that originate from the same SCC
// and summing them up.
// 2. Add the additional counts to the functions in the SCC.
// This ensures that the order of
// traversal of functions within the SCC doesn't change the final result.
DenseMap<Function *, uint64_t> AdditionalCounts;
for (auto It = CallSites.begin(); It != Mid; It++) {
auto &CS = *It;
auto RelFreq = GetCallSiteRelFreq(CS);
Function *Callee = CS.getCalledFunction();
Function *Caller = CS.getCaller();
RelFreq *= Scaled64(GetCount(Caller), 0);
uint64_t AdditionalCount = RelFreq.toInt<uint64_t>();
AdditionalCounts[Callee] += AdditionalCount;
}
// Update the counts for the functions in the SCC.
for (auto &Entry : AdditionalCounts)
AddToCount(Entry.first, Entry.second);
// Now update the counts for functions not in SCC.
for (auto It = Mid; It != CallSites.end(); It++) {
auto &CS = *It;
auto Weight = GetCallSiteRelFreq(CS);
Function *Callee = CS.getCalledFunction();
Function *Caller = CS.getCaller();
Weight *= Scaled64(GetCount(Caller), 0);
AddToCount(Callee, Weight.toInt<uint64_t>());
}
}
/// Propgate synthetic entry counts on a callgraph.
///
/// This performs a reverse post-order traversal of the callgraph SCC. For each
/// SCC, it first propagates the entry counts to the functions within the SCC
/// through call edges and updates them in one shot. Then the entry counts are
/// propagated to functions outside the SCC.
void llvm::propagateSyntheticCounts(
const CallGraph &CG, function_ref<Scaled64(CallSite CS)> GetCallSiteRelFreq,
function_ref<uint64_t(Function *F)> GetCount,
function_ref<void(Function *F, uint64_t)> AddToCount) {
SmallVector<SmallPtrSet<Function *, 8>, 16> SCCs;
for (auto I = scc_begin(&CG); !I.isAtEnd(); ++I) {
auto SCC = *I;
SmallPtrSet<Function *, 8> SCCFunctions;
for (auto *Node : SCC) {
Function *F = Node->getFunction();
if (F && !F->isDeclaration()) {
SCCFunctions.insert(F);
}
}
SCCs.push_back(SCCFunctions);
}
for (auto &SCCFunctions : reverse(SCCs))
propagateFromSCC(SCCFunctions, GetCallSiteRelFreq, GetCount, AddToCount);
}