diff --git a/llvm/lib/Transforms/Utils/SimplifyCFG.cpp b/llvm/lib/Transforms/Utils/SimplifyCFG.cpp index 32d7fa1213f4..142f157c3c84 100644 --- a/llvm/lib/Transforms/Utils/SimplifyCFG.cpp +++ b/llvm/lib/Transforms/Utils/SimplifyCFG.cpp @@ -752,38 +752,27 @@ static inline bool HasBranchWeights(const Instruction* I) { return false; } -/// Tries to get a branch weight for the given instruction, returns NULL if it -/// can't. Pos starts at 0. -static ConstantInt* GetWeight(Instruction* I, int Pos) { - MDNode* ProfMD = I->getMetadata(LLVMContext::MD_prof); - if (ProfMD && ProfMD->getOperand(0)) { - if (MDString* MDS = dyn_cast(ProfMD->getOperand(0))) { - if (MDS->getString().equals("branch_weights")) { - assert(ProfMD->getNumOperands() >= 3); - return dyn_cast(ProfMD->getOperand(1 + Pos)); - } - } - } - - return 0; -} - -/// Scale the given weights based on the successor TI's metadata. Scaling is -/// done by multiplying every weight by the sum of the successor's weights. -static void ScaleWeights(Instruction* STI, MutableArrayRef Weights) { - // Sum the successor's weights - assert(HasBranchWeights(STI)); - unsigned Scale = 0; - MDNode* ProfMD = STI->getMetadata(LLVMContext::MD_prof); - for (unsigned i = 1; i < ProfMD->getNumOperands(); ++i) { - ConstantInt* CI = dyn_cast(ProfMD->getOperand(i)); +/// Get Weights of a given TerminatorInst, the default weight is at the front +/// of the vector. If TI is a conditional eq, we need to swap the branch-weight +/// metadata. +static void GetBranchWeights(TerminatorInst *TI, + SmallVectorImpl &Weights) { + MDNode* MD = TI->getMetadata(LLVMContext::MD_prof); + assert(MD); + for (unsigned i = 1, e = MD->getNumOperands(); i < e; ++i) { + ConstantInt* CI = dyn_cast(MD->getOperand(i)); assert(CI); - Scale += CI->getValue().getZExtValue(); + Weights.push_back(CI->getValue().getZExtValue()); } - // Skip default, as it's replaced during the folding - for (unsigned i = 1; i < Weights.size(); ++i) { - Weights[i] *= Scale; + // If TI is a conditional eq, the default case is the false case, + // and the corresponding branch-weight data is at index 2. We swap the + // default weight to be the first entry. + if (BranchInst* BI = dyn_cast(TI)) { + assert(Weights.size() == 2); + ICmpInst *ICI = cast(BI->getCondition()); + if (ICI->getPredicate() == ICmpInst::ICMP_EQ) + std::swap(Weights.front(), Weights.back()); } } @@ -838,52 +827,22 @@ bool SimplifyCFGOpt::FoldValueComparisonIntoPredecessors(TerminatorInst *TI, // Update the branch weight metadata along the way SmallVector Weights; - uint64_t PredDefaultWeight = 0; bool PredHasWeights = HasBranchWeights(PTI); bool SuccHasWeights = HasBranchWeights(TI); - if (PredHasWeights) { - MDNode* MD = PTI->getMetadata(LLVMContext::MD_prof); - assert(MD); - for (unsigned i = 1, e = MD->getNumOperands(); i < e; ++i) { - ConstantInt* CI = dyn_cast(MD->getOperand(i)); - assert(CI); - Weights.push_back(CI->getValue().getZExtValue()); - } - - // If the predecessor is a conditional eq, then swap the default weight - // to be the first entry. - if (BranchInst* BI = dyn_cast(PTI)) { - assert(Weights.size() == 2); - ICmpInst *ICI = cast(BI->getCondition()); - - if (ICI->getPredicate() == ICmpInst::ICMP_EQ) { - std::swap(Weights.front(), Weights.back()); - } - } - - PredDefaultWeight = Weights.front(); - } else if (SuccHasWeights) { + if (PredHasWeights) + GetBranchWeights(PTI, Weights); + else if (SuccHasWeights) // If there are no predecessor weights but there are successor weights, // populate Weights with 1, which will later be scaled to the sum of // successor's weights Weights.assign(1 + PredCases.size(), 1); - PredDefaultWeight = 1; - } - uint64_t SuccDefaultWeight = 0; - if (SuccHasWeights) { - int Index = 0; - if (BranchInst* BI = dyn_cast(TI)) { - ICmpInst* ICI = dyn_cast(BI->getCondition()); - assert(ICI); - - if (ICI->getPredicate() == ICmpInst::ICMP_EQ) - Index = 1; - } - - SuccDefaultWeight = GetWeight(TI, Index)->getValue().getZExtValue(); - } + SmallVector SuccWeights; + if (SuccHasWeights) + GetBranchWeights(TI, SuccWeights); + else if (PredHasWeights) + SuccWeights.assign(1 + BBCases.size(), 1); if (PredDefault == BB) { // If this is the default destination from PTI, only the edges in TI @@ -896,7 +855,9 @@ bool SimplifyCFGOpt::FoldValueComparisonIntoPredecessors(TerminatorInst *TI, // The default destination is BB, we don't need explicit targets. std::swap(PredCases[i], PredCases.back()); - if (PredHasWeights) { + if (PredHasWeights || SuccHasWeights) { + // Increase weight for the default case. + Weights[0] += Weights[i+1]; std::swap(Weights[i+1], Weights.back()); Weights.pop_back(); } @@ -912,27 +873,30 @@ bool SimplifyCFGOpt::FoldValueComparisonIntoPredecessors(TerminatorInst *TI, NewSuccessors.push_back(BBDefault); } - if (SuccHasWeights) { - ScaleWeights(TI, Weights); - Weights.front() *= SuccDefaultWeight; - } else if (PredHasWeights) { - Weights.front() /= (1 + BBCases.size()); - } - + unsigned CasesFromPred = Weights.size(); + uint64_t ValidTotalSuccWeight = 0; for (unsigned i = 0, e = BBCases.size(); i != e; ++i) if (!PTIHandled.count(BBCases[i].Value) && BBCases[i].Dest != BBDefault) { PredCases.push_back(BBCases[i]); NewSuccessors.push_back(BBCases[i].Dest); - if (SuccHasWeights) { - Weights.push_back(PredDefaultWeight * - GetWeight(TI, i)->getValue().getZExtValue()); - } else if (PredHasWeights) { - // Split the old default's weight amongst the children - Weights.push_back(PredDefaultWeight / (1 + BBCases.size())); + if (SuccHasWeights || PredHasWeights) { + // The default weight is at index 0, so weight for the ith case + // should be at index i+1. Scale the cases from successor by + // PredDefaultWeight (Weights[0]). + Weights.push_back(Weights[0] * SuccWeights[i+1]); + ValidTotalSuccWeight += SuccWeights[i+1]; } } + if (SuccHasWeights || PredHasWeights) { + ValidTotalSuccWeight += SuccWeights[0]; + // Scale the cases from predecessor by ValidTotalSuccWeight. + for (unsigned i = 1; i < CasesFromPred; ++i) + Weights[i] *= ValidTotalSuccWeight; + // Scale the default weight by SuccDefaultWeight (SuccWeights[0]). + Weights[0] *= SuccWeights[0]; + } } else { // FIXME: preserve branch weight metadata, similarly to the 'then' // above. For now, drop it. diff --git a/llvm/test/Transforms/SimplifyCFG/preserve-branchweights-switch-create.ll b/llvm/test/Transforms/SimplifyCFG/preserve-branchweights-switch-create.ll new file mode 100644 index 000000000000..75f5f06daa7d --- /dev/null +++ b/llvm/test/Transforms/SimplifyCFG/preserve-branchweights-switch-create.ll @@ -0,0 +1,92 @@ +; RUN: opt -simplifycfg -S -o - < %s | FileCheck %s + +declare void @func2(i32) +declare void @func4(i32) +declare void @func6(i32) +declare void @func8(i32) + +;; test1 - create a switch with case 2 and case 4 from two branches: N == 2 +;; and N == 4. +define void @test1(i32 %N) nounwind uwtable { +entry: + %cmp = icmp eq i32 %N, 2 + br i1 %cmp, label %if.then, label %if.else, !prof !0 +; CHECK: test1 +; CHECK: switch i32 %N +; CHECK: ], !prof !0 + +if.then: + call void @func2(i32 %N) nounwind + br label %if.end9 + +if.else: + %cmp2 = icmp eq i32 %N, 4 + br i1 %cmp2, label %if.then7, label %if.else8, !prof !1 + +if.then7: + call void @func4(i32 %N) nounwind + br label %if.end + +if.else8: + call void @func8(i32 %N) nounwind + br label %if.end + +if.end: + br label %if.end9 + +if.end9: + ret void +} + +;; test2 - Merge two switches where PredDefault == BB. +define void @test2(i32 %M, i32 %N) nounwind uwtable { +entry: + %cmp = icmp sgt i32 %M, 2 + br i1 %cmp, label %sw1, label %sw2 + +sw1: + switch i32 %N, label %sw2 [ + i32 2, label %sw.bb + i32 3, label %sw.bb1 + ], !prof !2 +; CHECK: test2 +; CHECK: switch i32 %N, label %sw.epilog +; CHECK: i32 2, label %sw.bb +; CHECK: i32 3, label %sw.bb1 +; CHECK: i32 4, label %sw.bb5 +; CHECK: ], !prof !1 + +sw.bb: + call void @func2(i32 %N) nounwind + br label %sw.epilog + +sw.bb1: + call void @func4(i32 %N) nounwind + br label %sw.epilog + +sw2: +;; Here "case 2" is invalidated if control is transferred through default case +;; of the first switch. + switch i32 %N, label %sw.epilog [ + i32 2, label %sw.bb4 + i32 4, label %sw.bb5 + ], !prof !3 + +sw.bb4: + call void @func6(i32 %N) nounwind + br label %sw.epilog + +sw.bb5: + call void @func8(i32 %N) nounwind + br label %sw.epilog + +sw.epilog: + ret void +} + +!0 = metadata !{metadata !"branch_weights", i32 64, i32 4} +!1 = metadata !{metadata !"branch_weights", i32 4, i32 64} +; CHECK: !0 = metadata !{metadata !"branch_weights", i32 256, i32 4352, i32 16} +!2 = metadata !{metadata !"branch_weights", i32 4, i32 4, i32 8} +!3 = metadata !{metadata !"branch_weights", i32 8, i32 8, i32 4} +; CHECK: !1 = metadata !{metadata !"branch_weights", i32 32, i32 48, i32 96, i32 16}