diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp index ab988f69dc3d..8313a48c3467 100644 --- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp @@ -7996,6 +7996,18 @@ void SelectionDAGBuilder::lowerWorkItem(SwitchWorkListItem W, Value *Cond, } } +unsigned SelectionDAGBuilder::caseClusterRank(const CaseCluster &CC, + CaseClusterIt First, + CaseClusterIt Last) { + return std::count_if(First, Last + 1, [&](const CaseCluster &X) { + if (X.Weight != CC.Weight) + return X.Weight > CC.Weight; + + // Ties are broken by comparing the case value. + return X.Low->getValue().slt(CC.Low->getValue()); + }); +} + void SelectionDAGBuilder::splitWorkItem(SwitchWorkList &WorkList, const SwitchWorkListItem &W, Value *Cond, @@ -8025,6 +8037,48 @@ void SelectionDAGBuilder::splitWorkItem(SwitchWorkList &WorkList, RightWeight += (--FirstRight)->Weight; I++; } + + for (;;) { + // Our binary search tree differs from a typical BST in that ours can have up + // to three values in each leaf. The pivot selection above doesn't take that + // into account, which means the tree might require more nodes and be less + // efficient. We compensate for this here. + + unsigned NumLeft = LastLeft - W.FirstCluster + 1; + unsigned NumRight = W.LastCluster - FirstRight + 1; + + if (std::min(NumLeft, NumRight) < 3 && std::max(NumLeft, NumRight) > 3) { + // If one side has less than 3 clusters, and the other has more than 3, + // consider taking a cluster from the other side. + + if (NumLeft < NumRight) { + // Consider moving the first cluster on the right to the left side. + CaseCluster &CC = *FirstRight; + unsigned RightSideRank = caseClusterRank(CC, FirstRight, W.LastCluster); + unsigned LeftSideRank = caseClusterRank(CC, W.FirstCluster, LastLeft); + if (LeftSideRank <= RightSideRank) { + // Moving the cluster to the left does not demote it. + ++LastLeft; + ++FirstRight; + continue; + } + } else { + assert(NumRight < NumLeft); + // Consider moving the last element on the left to the right side. + CaseCluster &CC = *LastLeft; + unsigned LeftSideRank = caseClusterRank(CC, W.FirstCluster, LastLeft); + unsigned RightSideRank = caseClusterRank(CC, FirstRight, W.LastCluster); + if (RightSideRank <= LeftSideRank) { + // Moving the cluster to the right does not demot it. + --LastLeft; + --FirstRight; + continue; + } + } + } + break; + } + assert(LastLeft + 1 == FirstRight); assert(LastLeft >= W.FirstCluster); assert(FirstRight <= W.LastCluster); diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.h b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.h index f0c03af3f64b..f225d54d189d 100644 --- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.h +++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.h @@ -342,6 +342,11 @@ private: }; typedef SmallVector SwitchWorkList; + /// Determine the rank by weight of CC in [First,Last]. If CC has more weight + /// than each cluster in the range, its rank is 0. + static unsigned caseClusterRank(const CaseCluster &CC, CaseClusterIt First, + CaseClusterIt Last); + /// Emit comparison and split W into two subtrees. void splitWorkItem(SwitchWorkList &WorkList, const SwitchWorkListItem &W, Value *Cond, MachineBasicBlock *SwitchMBB); diff --git a/llvm/test/CodeGen/X86/switch.ll b/llvm/test/CodeGen/X86/switch.ll index fc217d5203d8..748fd6f238b1 100644 --- a/llvm/test/CodeGen/X86/switch.ll +++ b/llvm/test/CodeGen/X86/switch.ll @@ -499,6 +499,8 @@ entry: i32 30, label %bb3 i32 40, label %bb4 i32 50, label %bb5 + i32 60, label %bb6 + i32 70, label %bb6 ], !prof !4 bb0: tail call void @g(i32 0) br label %return bb1: tail call void @g(i32 1) br label %return @@ -506,16 +508,87 @@ bb2: tail call void @g(i32 2) br label %return bb3: tail call void @g(i32 3) br label %return bb4: tail call void @g(i32 4) br label %return bb5: tail call void @g(i32 5) br label %return +bb6: tail call void @g(i32 6) br label %return +bb7: tail call void @g(i32 7) br label %return return: ret void -; To balance the tree by weight, the pivot is shifted to the right, moving hot -; cases closer to the root. +; Without branch probabilities, the pivot would be 40, since that would yield +; equal-sized sub-trees. When taking weights into account, case 70 becomes the +; pivot. Since there is room for 3 cases in a leaf, cases 50 and 60 are also +; included in the right-hand side because that doesn't reduce their rank. + ; CHECK-LABEL: left_leaning_weight_balanced_tree ; CHECK-NOT: cmpl -; CHECK: cmpl $39 +; CHECK: cmpl $49 } -!4 = !{!"branch_weights", i32 1, i32 10, i32 1, i32 1, i32 1, i32 10, i32 10} +!4 = !{!"branch_weights", i32 1, i32 10, i32 1, i32 1, i32 1, i32 1, i32 1, i32 1, i32 1000} + + +define void @left_leaning_weight_balanced_tree2(i32 %x) { +entry: + switch i32 %x, label %return [ + i32 0, label %bb0 + i32 10, label %bb1 + i32 20, label %bb2 + i32 30, label %bb3 + i32 40, label %bb4 + i32 50, label %bb5 + i32 60, label %bb6 + i32 70, label %bb6 + ], !prof !5 +bb0: tail call void @g(i32 0) br label %return +bb1: tail call void @g(i32 1) br label %return +bb2: tail call void @g(i32 2) br label %return +bb3: tail call void @g(i32 3) br label %return +bb4: tail call void @g(i32 4) br label %return +bb5: tail call void @g(i32 5) br label %return +bb6: tail call void @g(i32 6) br label %return +bb7: tail call void @g(i32 7) br label %return +return: ret void + +; Same as the previous test, except case 50 has higher rank to the left than it +; would have on the right. Case 60 would have the same rank on both sides, so is +; moved into the leaf. + +; CHECK-LABEL: left_leaning_weight_balanced_tree2 +; CHECK-NOT: cmpl +; CHECK: cmpl $59 +} + +!5 = !{!"branch_weights", i32 1, i32 10, i32 1, i32 1, i32 1, i32 1, i32 90, i32 70, i32 1000} + + +define void @right_leaning_weight_balanced_tree(i32 %x) { +entry: + switch i32 %x, label %return [ + i32 0, label %bb0 + i32 10, label %bb1 + i32 20, label %bb2 + i32 30, label %bb3 + i32 40, label %bb4 + i32 50, label %bb5 + i32 60, label %bb6 + i32 70, label %bb6 + ], !prof !6 +bb0: tail call void @g(i32 0) br label %return +bb1: tail call void @g(i32 1) br label %return +bb2: tail call void @g(i32 2) br label %return +bb3: tail call void @g(i32 3) br label %return +bb4: tail call void @g(i32 4) br label %return +bb5: tail call void @g(i32 5) br label %return +bb6: tail call void @g(i32 6) br label %return +bb7: tail call void @g(i32 7) br label %return +return: ret void + +; Analogous to left_leaning_weight_balanced_tree. + +; CHECK-LABEL: right_leaning_weight_balanced_tree +; CHECK-NOT: cmpl +; CHECK: cmpl $19 +} + +!6 = !{!"branch_weights", i32 1, i32 1000, i32 1, i32 1, i32 1, i32 1, i32 1, i32 1, i32 10} define void @jump_table_affects_balance(i32 %x) {