diff --git a/llvm/include/llvm/IR/Instructions.h b/llvm/include/llvm/IR/Instructions.h index ba20d77215f5..43cb9a2ffde5 100644 --- a/llvm/include/llvm/IR/Instructions.h +++ b/llvm/include/llvm/IR/Instructions.h @@ -2906,9 +2906,10 @@ public: /// continues to map correctly to each operand. void swapSuccessors(); - /// Retrieve the probabilities of a conditional branch. Returns true on - /// success, or returns false if no or invalid metadata was found. - bool extractProfMetadata(uint64_t &ProbTrue, uint64_t &ProbFalse); + /// Retrieve the raw weight values of a conditional branch. + /// Returns true on success with profile weights filled in. + /// Returns false if no metadata or invalid metadata was found. + bool extractProfMetadata(uint64_t &TrueVal, uint64_t &FalseVal); // Methods for support type inquiry through isa, cast, and dyn_cast: static inline bool classof(const Instruction *I) { diff --git a/llvm/lib/IR/Instructions.cpp b/llvm/lib/IR/Instructions.cpp index 0d5bd9e9429b..d66ec8f8d455 100644 --- a/llvm/lib/IR/Instructions.cpp +++ b/llvm/lib/IR/Instructions.cpp @@ -1120,20 +1120,24 @@ void BranchInst::swapSuccessors() { MDNode::get(ProfileData->getContext(), Ops)); } -bool BranchInst::extractProfMetadata(uint64_t &ProbTrue, uint64_t &ProbFalse) { +bool BranchInst::extractProfMetadata(uint64_t &TrueVal, uint64_t &FalseVal) { assert(isConditional() && "Looking for probabilities on unconditional branch?"); auto *ProfileData = getMetadata(LLVMContext::MD_prof); if (!ProfileData || ProfileData->getNumOperands() != 3) return false; + auto *ProfDataName = dyn_cast(ProfileData->getOperand(0)); + if (!ProfDataName || !ProfDataName->getString().equals("branch_weights")) + return false; + auto *CITrue = mdconst::dyn_extract(ProfileData->getOperand(1)); auto *CIFalse = mdconst::dyn_extract(ProfileData->getOperand(2)); if (!CITrue || !CIFalse) return false; - ProbTrue = CITrue->getValue().getZExtValue(); - ProbFalse = CIFalse->getValue().getZExtValue(); + TrueVal = CITrue->getValue().getZExtValue(); + FalseVal = CIFalse->getValue().getZExtValue(); return true; } diff --git a/llvm/test/Transforms/SimplifyCFG/preserve-branchweights.ll b/llvm/test/Transforms/SimplifyCFG/preserve-branchweights.ll index ae1794b1c61a..118c4c116b58 100644 --- a/llvm/test/Transforms/SimplifyCFG/preserve-branchweights.ll +++ b/llvm/test/Transforms/SimplifyCFG/preserve-branchweights.ll @@ -21,6 +21,29 @@ Z: ret void } +; Make sure the metadata name string is "branch_weights" before propagating it. + +define void @fake_weights(i1 %a, i1 %b) { +; CHECK-LABEL: @fake_weights( +entry: + br i1 %a, label %Y, label %X, !prof !12 +; CHECK: %or.cond = and i1 %a.not, %c +; CHECK-NEXT: br i1 %or.cond, label %Z, label %Y +; CHECK-NOT: !prof !0 +; CHECK: Y: +X: + %c = or i1 %b, false + br i1 %c, label %Z, label %Y, !prof !1 + +Y: + call void @helper(i32 0) + ret void + +Z: + call void @helper(i32 1) + ret void +} + define void @test2(i1 %a, i1 %b) { ; CHECK-LABEL: @test2( entry: @@ -376,6 +399,7 @@ for.exit: !9 = !{!"branch_weights", i32 7, i32 6} !10 = !{!"branch_weights", i32 672646, i32 21604207} !11 = !{!"branch_weights", i32 6960, i32 21597248} +!12 = !{!"these_are_not_the_branch_weights_you_are_looking_for", i32 3, i32 5} ; CHECK: !0 = !{!"branch_weights", i32 5, i32 11} ; CHECK: !1 = !{!"branch_weights", i32 1, i32 5}