[BOLT] Support duplicating jump tables

Summary:
If two indirect branches use the same jump table, we need to
detect this and duplicate jump tables so we can modify the CFG
correctly. This is necessary for instrumentation and for shrink
wrapping. For the latter, we currently only detect the situation
and bail, which fixes an old known issue with shrink wrapping.
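
To see the hazard, consider a minimal model (plain C++, not BOLT code):
two indirect branches that dispatch through one table. Redirecting an
entry on behalf of one branch silently retargets the other, which is
exactly what edge splitting must not do:

    #include <cassert>
    #include <vector>

    struct Block { int id; };

    int main() {
      Block B0{0}, B1{1}, Split{42};
      std::vector<Block *> Table{&B0, &B1}; // one jump table in the binary
      auto *JTofBranchA = &Table;           // indirect branch A dispatches here
      auto *JTofBranchB = &Table;           // branch B reuses the same table

      (*JTofBranchA)[0] = &Split;           // split an edge for branch A only...
      assert((*JTofBranchB)[0] == &Split);  // ...and branch B was retargeted too
    }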

Other minor changes to better support instrumentation: add an option
to instrument only hot functions, add a LOCK prefix to the
instrumentation counter-increment instruction, and speed up splitting
critical edges by avoiding unnecessary calls to recomputeLandingPads().
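
The LOCK prefix matters because an instrumented binary bumps shared
counters from many threads; a plain memory increment is a non-atomic
read-modify-write and can lose updates. In C++ terms the counter needs
the equivalent of the following (relaxed ordering assumed adequate for
profile counts):

    #include <atomic>
    #include <cstdint>

    std::atomic<uint64_t> Counter{0};

    void hit() {
      // On x86-64 this fetch_add lowers to a LOCK-prefixed RMW, the same
      // guarantee the pass now gets by emitting LOCK_INC64m directly.
      Counter.fetch_add(1, std::memory_order_relaxed);
    }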

(cherry picked from FBD16101312)
Authored by Rafael Auler on 2019-07-02 16:56:41 -07:00; committed by Maksim Panchenko
parent 8880969ced
commit 1169f1fdd8
10 changed files with 220 additions and 33 deletions

View File

@@ -289,9 +289,7 @@ BinaryContext::handleAddressRef(uint64_t Address, BinaryFunction &BF,
     // better heuristic.
     if (opts::StrictMode &&
         MemType == MemoryContentsType::POSSIBLE_PIC_JUMP_TABLE && IsPCRel) {
-      JumpTable *JT;
-      const MCSymbol *Symbol;
-      std::tie(JT, Symbol) =
+      const MCSymbol *Symbol =
           getOrCreateJumpTable(BF, Address, JumpTable::JTT_PIC);
       return std::make_pair(Symbol, Addend);
@@ -506,9 +504,8 @@ BinaryFunction *BinaryContext::createBinaryFunction(
   return BF;
 }
 
-std::pair<JumpTable *, const MCSymbol *>
-BinaryContext::getOrCreateJumpTable(BinaryFunction &Function,
-                                    uint64_t Address,
-                                    JumpTable::JumpTableType Type) {
+const MCSymbol *
+BinaryContext::getOrCreateJumpTable(BinaryFunction &Function, uint64_t Address,
+                                    JumpTable::JumpTableType Type) {
   if (auto *JT = getJumpTableContainingAddress(Address)) {
     assert(JT->Type == Type && "jump table types have to match");
@@ -516,7 +513,7 @@ BinaryContext::getOrCreateJumpTable(BinaryFunction &Function,
            "cannot re-use jump table of a different function");
     assert(Address == JT->getAddress() && "unexpected non-empty jump table");
-    return std::make_pair(JT, JT->getFirstLabel());
+    return JT->getFirstLabel();
   }
 
   const auto EntrySize =
@@ -551,7 +548,40 @@ BinaryContext::getOrCreateJumpTable(BinaryFunction &Function,
   // Duplicate the entry for the parent function for easy access.
   Function.JumpTables.emplace(Address, JT);
-  return std::make_pair(JT, JTLabel);
+  return JTLabel;
+}
+
+std::pair<uint64_t, const MCSymbol *>
+BinaryContext::duplicateJumpTable(BinaryFunction &Function, JumpTable *JT,
+                                  const MCSymbol *OldLabel) {
+  unsigned Offset = 0;
+  bool Found = false;
+  for (auto Elmt : JT->Labels) {
+    if (Elmt.second != OldLabel)
+      continue;
+    Offset = Elmt.first;
+    Found = true;
+    break;
+  }
+  assert(Found && "Label not found");
+  auto *NewLabel = Ctx->createTempSymbol("duplicatedJT", true);
+  auto *NewJT = new JumpTable(NewLabel->getName(),
+                              JT->getAddress(),
+                              JT->EntrySize,
+                              JT->Type,
+                              {},
+                              JumpTable::LabelMapType{{Offset, NewLabel}},
+                              Function,
+                              *getSectionForAddress(JT->getAddress()));
+  NewJT->Entries = JT->Entries;
+  NewJT->Counts = JT->Counts;
+  uint64_t JumpTableID = ++DuplicatedJumpTables;
+  // Invert it to differentiate from regular jump tables whose IDs are their
+  // addresses in the input binary memory space.
+  JumpTableID = ~JumpTableID;
+  JumpTables.emplace(JumpTableID, NewJT);
+  Function.JumpTables.emplace(JumpTableID, NewJT);
+  return std::make_pair(JumpTableID, NewLabel);
 }
 
 std::string BinaryContext::generateJumpTableName(const BinaryFunction &BF,
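
The clone-ID scheme above is compact but subtle. A minimal worked
example (not BOLT code, 64-bit addresses assumed) of why ~(++counter)
with a high starting seed yields IDs that cannot collide with real
addresses and keep the containing-address range checks safe:

    #include <cassert>
    #include <cstdint>

    int main() {
      uint32_t DuplicatedJumpTables = 0x10000000; // same seed as BinaryContext
      uint64_t ID = ~uint64_t(++DuplicatedJumpTables);

      assert(ID == 0xFFFFFFFFEFFFFFFEULL); // the first clone's ID
      assert((int64_t)ID < 0);             // high bit set: not an input address
      // Each later clone gets a *smaller* ID, and the seed keeps every ID at
      // least 0x10000000 below UINT64_MAX, so ID + table-size comparisons in
      // getJumpTableContainingAddress cannot wrap around.
    }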

View File

@@ -165,6 +165,11 @@ class BinaryContext {
   /// Jump tables for all functions mapped by address.
   std::map<uint64_t, JumpTable *> JumpTables;
 
+  /// Used in duplicateJumpTable() to uniquely identify a JT clone.
+  /// Start our IDs with a high number so getJumpTableContainingAddress checks
+  /// with size won't overflow.
+  uint32_t DuplicatedJumpTables{0x10000000};
+
 public:
   /// [name] -> [BinaryData*] map used for global symbol resolution.
   using SymbolMapType = std::map<std::string, BinaryData *>;
@@ -313,8 +318,7 @@ public:
   ///
   /// May create an embedded jump table and return its label as the second
   /// element of the pair.
-  std::pair<JumpTable *, const MCSymbol *>
-  getOrCreateJumpTable(BinaryFunction &Function,
-                       uint64_t Address,
-                       JumpTable::JumpTableType Type);
+  const MCSymbol *getOrCreateJumpTable(BinaryFunction &Function,
+                                       uint64_t Address,
+                                       JumpTable::JumpTableType Type);
@@ -322,11 +326,18 @@ public:
   /// their OffsetEntries based on memory contents.
   void populateJumpTables();
 
+  /// Returns a jump table ID and label pointing to the duplicated jump table.
+  /// Ordinarily, jump tables are identified by their address in the input
+  /// binary. We return an ID with the high bit set to differentiate it from
+  /// regular addresses, avoiding conflicts with standard jump tables.
+  std::pair<uint64_t, const MCSymbol *>
+  duplicateJumpTable(BinaryFunction &Function, JumpTable *JT,
+                     const MCSymbol *OldLabel);
+
   /// Generate a unique name for jump table at a given \p Address belonging
   /// to function \p BF.
   std::string generateJumpTableName(const BinaryFunction &BF, uint64_t Address);
 
 public:
   /// Regular page size.
   static constexpr unsigned RegularPageSize = 0x1000;

View File

@@ -807,9 +807,8 @@ BinaryFunction::processIndirectBranch(MCInst &Instruction,
   }
 
   auto useJumpTableForInstruction = [&](JumpTable::JumpTableType JTType) {
-    JumpTable *JT;
-    const MCSymbol *JTLabel;
-    std::tie(JT, JTLabel) = BC.getOrCreateJumpTable(*this, ArrayStart, JTType);
+    const MCSymbol *JTLabel =
+        BC.getOrCreateJumpTable(*this, ArrayStart, JTType);
 
     BC.MIB->replaceMemOperandDisp(const_cast<MCInst &>(*MemLocInstr),
                                   JTLabel, BC.Ctx.get());
@@ -3503,7 +3502,8 @@ void BinaryFunction::insertBasicBlocks(
     BinaryBasicBlock *Start,
     std::vector<std::unique_ptr<BinaryBasicBlock>> &&NewBBs,
     const bool UpdateLayout,
-    const bool UpdateCFIState) {
+    const bool UpdateCFIState,
+    const bool RecomputeLandingPads) {
   const auto StartIndex = Start ? getIndex(Start) : -1;
   const auto NumNewBlocks = NewBBs.size();
@@ -3517,7 +3517,11 @@ void BinaryFunction::insertBasicBlocks(
     BasicBlocks[I++] = BB.release();
   }
 
+  if (RecomputeLandingPads) {
     recomputeLandingPads();
+  } else {
+    updateBBIndices(0);
+  }
 
   if (UpdateLayout) {
     updateLayout(Start, NumNewBlocks);
@@ -3532,7 +3536,8 @@ BinaryFunction::iterator BinaryFunction::insertBasicBlocks(
     BinaryFunction::iterator StartBB,
     std::vector<std::unique_ptr<BinaryBasicBlock>> &&NewBBs,
     const bool UpdateLayout,
-    const bool UpdateCFIState) {
+    const bool UpdateCFIState,
+    const bool RecomputeLandingPads) {
   const auto StartIndex = getIndex(&*StartBB);
   const auto NumNewBlocks = NewBBs.size();
@@ -3547,7 +3552,11 @@ BinaryFunction::iterator BinaryFunction::insertBasicBlocks(
     BasicBlocks[I++] = BB.release();
   }
 
+  if (RecomputeLandingPads) {
     recomputeLandingPads();
+  } else {
+    updateBBIndices(0);
+  }
 
   if (UpdateLayout) {
     updateLayout(*std::prev(RetIter), NumNewBlocks);
@@ -3594,6 +3603,106 @@ void BinaryFunction::updateLayout(BinaryBasicBlock *Start,
   updateLayoutIndices();
 }
 
+bool BinaryFunction::checkForAmbiguousJumpTables() {
+  SmallSet<uint64_t, 4> JumpTables;
+  for (auto &BB : BasicBlocks) {
+    for (auto &Inst : *BB) {
+      if (!BC.MIB->isIndirectBranch(Inst))
+        continue;
+      auto JTAddress = BC.MIB->getJumpTable(Inst);
+      if (!JTAddress)
+        continue;
+      // This address can be inside another jump table, but we only consider
+      // it ambiguous when the same start address is used, not the same JT
+      // object.
+      if (!JumpTables.count(JTAddress)) {
+        JumpTables.insert(JTAddress);
+        continue;
+      }
+      return true;
+    }
+  }
+  return false;
+}
+
+void BinaryFunction::disambiguateJumpTables() {
+  assert((opts::JumpTables != JTS_BASIC && isSimple()) || BC.HasRelocations);
+  SmallPtrSet<JumpTable *, 4> JumpTables;
+  for (auto &BB : BasicBlocks) {
+    for (auto &Inst : *BB) {
+      if (!BC.MIB->isIndirectBranch(Inst))
+        continue;
+      auto *JT = getJumpTable(Inst);
+      if (!JT)
+        continue;
+      auto Iter = JumpTables.find(JT);
+      if (Iter == JumpTables.end()) {
+        JumpTables.insert(JT);
+        continue;
+      }
+      // This instruction is an indirect jump using a jump table, but it is
+      // using the same jump table as another jump. Try all our tricks to
+      // extract the jump table symbol and make it point to a new,
+      // duplicated JT.
+      uint64_t Scale;
+      const MCSymbol *Target;
+      MCInst *JTLoadInst = &Inst;
+      // Try a standard indirect jump matcher, scale 8.
+      auto IndJmpMatcher = BC.MIB->matchIndJmp(
+          BC.MIB->matchReg(), BC.MIB->matchImm(Scale), BC.MIB->matchReg(),
+          /*Offset=*/BC.MIB->matchSymbol(Target));
+      if (!BC.MIB->hasPCRelOperand(Inst) ||
+          !IndJmpMatcher->match(
+              *BC.MRI, *BC.MIB,
+              MutableArrayRef<MCInst>(&*BB->begin(), &Inst + 1), -1) ||
+          Scale != 8) {
+        // Standard JT matching failed. Trying now:
+        // PIC-style matcher, scale 4:
+        //    addq    %rdx, %rsi
+        //    addq    %rdx, %rdi
+        //    leaq    DATAat0x402450(%rip), %r11
+        //    movslq  (%r11,%rdx,4), %rcx
+        //    addq    %r11, %rcx
+        //    jmpq    *%rcx # JUMPTABLE @0x402450
+        MCPhysReg BaseReg1;
+        MCPhysReg BaseReg2;
+        uint64_t Offset;
+        auto PICIndJmpMatcher = BC.MIB->matchIndJmp(BC.MIB->matchAdd(
+            BC.MIB->matchReg(BaseReg1),
+            BC.MIB->matchLoad(BC.MIB->matchReg(BaseReg2),
+                              BC.MIB->matchImm(Scale), BC.MIB->matchReg(),
+                              BC.MIB->matchImm(Offset))));
+        auto LEAMatcherOwner =
+            BC.MIB->matchLoadAddr(BC.MIB->matchSymbol(Target));
+        auto LEAMatcher = LEAMatcherOwner.get();
+        auto PICBaseAddrMatcher = BC.MIB->matchIndJmp(BC.MIB->matchAdd(
+            std::move(LEAMatcherOwner), BC.MIB->matchAnyOperand()));
+        if (!PICIndJmpMatcher->match(
+                *BC.MRI, *BC.MIB,
+                MutableArrayRef<MCInst>(&*BB->begin(), &Inst + 1), -1) ||
+            Scale != 4 || BaseReg1 != BaseReg2 || Offset != 0 ||
+            !PICBaseAddrMatcher->match(
+                *BC.MRI, *BC.MIB,
+                MutableArrayRef<MCInst>(&*BB->begin(), &Inst + 1), -1)) {
+          llvm_unreachable("Failed to extract jump table base");
+          continue;
+        }
+        // Matched the PIC pattern.
+        JTLoadInst = &*LEAMatcher->CurInst;
+      }
+
+      uint64_t NewJumpTableID{0};
+      const MCSymbol *NewJTLabel;
+      std::tie(NewJumpTableID, NewJTLabel) =
+          BC.duplicateJumpTable(*this, JT, Target);
+      BC.MIB->replaceMemOperandDisp(*JTLoadInst, NewJTLabel, BC.Ctx.get());
+      // We use a unique ID with the high bit set as the address for this
+      // "injected" jump table (not originally in the input binary).
+      BC.MIB->setJumpTable(Inst, NewJumpTableID, 0);
+    }
+  }
+}
+
 bool BinaryFunction::replaceJumpTableEntryIn(BinaryBasicBlock *BB,
                                              BinaryBasicBlock *OldDest,
                                              BinaryBasicBlock *NewDest) {
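
Taken together, the two new hooks give a pass a cheap detection path and
a more expensive repair path. A minimal sketch (stand-in type, not BOLT
code) of how the two consumers in this commit use them:

    #include <cstdio>

    struct Function {
      bool SharesJumpTables = false;        // stand-in for the real analysis
      bool checkForAmbiguousJumpTables() { return SharesJumpTables; }
      void disambiguateJumpTables() { SharesJumpTables = false; }
    };

    int main() {
      Function BF;
      BF.SharesJumpTables = true;

      // Instrumentation: clone shared tables up front so later edge
      // splitting can patch one branch's table without touching the other.
      BF.disambiguateJumpTables();

      // Shrink wrapping: cloning large tables could regress the data cache,
      // so it merely detects sharing and bails on the function.
      if (BF.checkForAmbiguousJumpTables())
        std::puts("bail: ambiguous jump tables");
    }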
@@ -3641,7 +3750,8 @@ BinaryBasicBlock *BinaryFunction::splitEdge(BinaryBasicBlock *From,
   // Update CFI and BB layout with new intermediate BB.
   std::vector<std::unique_ptr<BinaryBasicBlock>> NewBBs;
   NewBBs.emplace_back(std::move(NewBB));
-  insertBasicBlocks(From, std::move(NewBBs), true, true);
+  insertBasicBlocks(From, std::move(NewBBs), true, true,
+                    /*RecomputeLandingPads=*/false);
   return NewBBPtr;
 }
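
This hunk is the source of the edge-splitting speedup mentioned in the
summary: the trampoline block created by splitEdge() introduces no
exception landing pads, so the function-wide recomputeLandingPads() walk
is unnecessary here, and passing RecomputeLandingPads=false lets
insertBasicBlocks() fall back to the much cheaper updateBBIndices(0).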

View File

@@ -1441,13 +1441,15 @@ public:
       BinaryBasicBlock *Start,
       std::vector<std::unique_ptr<BinaryBasicBlock>> &&NewBBs,
       const bool UpdateLayout = true,
-      const bool UpdateCFIState = true);
+      const bool UpdateCFIState = true,
+      const bool RecomputeLandingPads = true);
 
   iterator insertBasicBlocks(
       iterator StartBB,
       std::vector<std::unique_ptr<BinaryBasicBlock>> &&NewBBs,
       const bool UpdateLayout = true,
-      const bool UpdateCFIState = true);
+      const bool UpdateCFIState = true,
+      const bool RecomputeLandingPads = true);
 
   /// Update the basic block layout for this function. The BBs from
   /// [Start->Index, Start->Index + NumNewBlocks) are inserted into the
@@ -1466,6 +1468,20 @@ public:
   /// new blocks into the CFG. This must be called after updateLayout.
   void updateCFIState(BinaryBasicBlock *Start, const unsigned NumNewBlocks);
 
+  /// Return true if we detected ambiguous jump tables in this function, which
+  /// happens when one JT is used by more than one indirect jump. This
+  /// precludes us from splitting edges for this JT unless we duplicate the JT
+  /// (see disambiguateJumpTables).
+  bool checkForAmbiguousJumpTables();
+
+  /// Detect when two distinct indirect jumps are using the same jump table
+  /// and duplicate it, allocating a separate JT for each indirect branch.
+  /// This is necessary for code transformations on the CFG that change an
+  /// edge induced by an indirect branch, e.g. instrumentation or shrink
+  /// wrapping. However, this is only possible if we are not updating jump
+  /// tables in place, but are writing them to a new location (moving them).
+  void disambiguateJumpTables();
+
   /// Change \p OrigDest to \p NewDest in the jump table used at the end of
   /// \p BB. Returns false if \p OrigDest couldn't be found as a valid target
   /// and no replacement took place.

View File

@@ -47,6 +47,10 @@ JumpTable::JumpTable(StringRef Name,
 std::pair<size_t, size_t>
 JumpTable::getEntriesForAddress(const uint64_t Addr) const {
+  // Check if this is not an address, but a cloned JT id.
+  if ((int64_t)Addr < 0ll)
+    return std::make_pair(0, Entries.size());
+
   const uint64_t InstOffset = Addr - getAddress();
   size_t StartIndex = 0, EndIndex = 0;
   uint64_t Offset = 0;
@@ -73,13 +77,12 @@ JumpTable::getEntriesForAddress(const uint64_t Addr) const {
   return std::make_pair(StartIndex, EndIndex);
 }
 
-bool JumpTable::replaceDestination(uint64_t JTAddress,
-                                   const MCSymbol *OldDest,
+bool JumpTable::replaceDestination(uint64_t JTAddress, const MCSymbol *OldDest,
                                    MCSymbol *NewDest) {
   bool Patched{false};
   const auto Range = getEntriesForAddress(JTAddress);
-  for (auto I = &Entries[Range.first], E = &Entries[Range.second];
-       I != E; ++I) {
+  for (auto I = &Entries[Range.first], E = &Entries[Range.second]; I != E;
+       ++I) {
     auto &Entry = *I;
     if (Entry == OldDest) {
       Patched = true;

View File

@@ -49,7 +49,6 @@ public:
     JTT_PIC,
   };
 
-public:
   /// Branch statistics for jump table entries.
   struct JumpInfo {
     uint64_t Mispreds{0};

View File

@@ -171,9 +171,8 @@ bool MCPlusBuilder::setJumpTable(MCInst &Inst, uint64_t Value,
                                  uint16_t IndexReg) {
   if (!isIndirectBranch(Inst))
     return false;
-  assert(getJumpTable(Inst) == 0 && "jump table already set");
   setAnnotationOpValue(Inst, MCAnnotation::kJumpTable, Value);
-  addAnnotation<>(Inst, "JTIndexReg", IndexReg);
+  getOrCreateAnnotationAs<uint16_t>(Inst, "JTIndexReg") = IndexReg;
   return true;
 }
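
The dropped assert is what lets disambiguateJumpTables() call
setJumpTable() on an instruction that already carries a jump-table
annotation, and getOrCreateAnnotationAs gives create-or-overwrite
semantics. A toy model of that behavior (a plain std::map, not the
MCPlus annotation machinery):

    #include <cassert>
    #include <cstdint>
    #include <map>
    #include <string>

    int main() {
      // One instruction's annotations, keyed by name.
      std::map<std::string, uint64_t> Annotations;
      Annotations["JumpTable"] = 0x402450;              // set at discovery time
      Annotations["JumpTable"] = 0xFFFFFFFFEFFFFFFEULL; // re-pointed at a clone
      assert(Annotations.size() == 1 &&
             Annotations["JumpTable"] == 0xFFFFFFFFEFFFFFFEULL);
    }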

View File

@@ -28,6 +28,13 @@ cl::opt<std::string> InstrumentationFilename(
     cl::init("/tmp/prof.fdata"),
     cl::Optional,
     cl::cat(BoltCategory));
 
+cl::opt<bool> InstrumentHotOnly(
+    "instrument-hot-only",
+    cl::desc("only insert instrumentation on hot functions (need profile)"),
+    cl::init(false),
+    cl::Optional,
+    cl::cat(BoltCategory));
+
 }
 
 namespace llvm {
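
The flag is consumed in runOnFunctions() below. A hypothetical
invocation for reference (only -instrument-hot-only comes from this
diff; the remaining flags are assumed standard BOLT driver options):

    llvm-bolt a.out -o a.out.instrumented -instrument -instrument-hot-only \
        -data perf.fdata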
@@ -138,8 +145,10 @@ void Instrumentation::runOnFunctions(BinaryContext &BC) {
   uint64_t InstrumentationSitesSavingFlags{0ULL};
   for (auto &BFI : BC.getBinaryFunctions()) {
     BinaryFunction &Function = BFI.second;
-    if (!Function.isSimple() || !opts::shouldProcess(Function))
+    if (!Function.isSimple() || !opts::shouldProcess(Function)
+        || (opts::InstrumentHotOnly && !Function.getKnownExecutionCount()))
       continue;
+    Function.disambiguateJumpTables();
     SplitWorklist.clear();
     SplitInstrs.clear();
@@ -193,6 +202,7 @@ void Instrumentation::runOnFunctions(BinaryContext &BC) {
       if (!HasUnconditionalBranch && !HasJumpTable && BB.succ_size() > 0 &&
           BB.size() > 0) {
         auto *FTBB = BB.getFallthrough();
+        assert(FTBB && "expected valid fall-through basic block");
         auto I = BB.begin();
         auto LastInstr = BB.end();
         --LastInstr;

View File

@@ -1910,6 +1910,15 @@ bool ShrinkWrapping::perform() {
   PopOffsetByReg = std::vector<int64_t>(BC.MRI->getNumRegs(), 0LL);
   DomOrder = std::vector<MCPhysReg>(BC.MRI->getNumRegs(), 0);
 
+  if (BF.checkForAmbiguousJumpTables()) {
+    DEBUG(dbgs() << "BOLT-DEBUG: ambiguous JTs in " << BF.getPrintName()
+                 << ".\n");
+    // We could call disambiguateJumpTables here, but it is probably not
+    // worth the cost (duplicating potentially large jump tables could
+    // regress dcache misses). Moreover, ambiguous JTs are rare and usually
+    // come from code written in assembly language. Just bail.
+    return false;
+  }
+
   SLM.initialize();
   CSA.compute();
   classifyCSRUses();

View File

@@ -2736,7 +2736,7 @@ public:
   bool createIncMemory(MCInst &Inst, const MCSymbol *Target,
                        MCContext *Ctx) const override {
-    Inst.setOpcode(X86::INC64m);
+    Inst.setOpcode(X86::LOCK_INC64m);
     Inst.clear();
     Inst.addOperand(MCOperand::createReg(X86::RIP)); // BaseReg
     Inst.addOperand(MCOperand::createImm(1));        // ScaleAmt