diff --git a/clang/lib/CodeGen/CodeGenPGO.cpp b/clang/lib/CodeGen/CodeGenPGO.cpp index c90b025e5511..d9016774fa11 100644 --- a/clang/lib/CodeGen/CodeGenPGO.cpp +++ b/clang/lib/CodeGen/CodeGenPGO.cpp @@ -264,6 +264,12 @@ struct ComputeRegionCounts : public ConstStmtVisitor { } } + /// Set and return the current count. + uint64_t setCount(uint64_t Count) { + PGO.setCurrentRegionCount(Count); + return Count; + } + void VisitStmt(const Stmt *S) { RecordStmtCount(S); for (Stmt::const_child_range I = S->children(); I; ++I) { @@ -274,9 +280,8 @@ struct ComputeRegionCounts : public ConstStmtVisitor { void VisitFunctionDecl(const FunctionDecl *D) { // Counter tracks entry to the function body. - RegionCounter Cnt(PGO, D->getBody()); - Cnt.beginRegion(); - CountMap[D->getBody()] = PGO.getCurrentRegionCount(); + uint64_t BodyCount = setCount(PGO.getRegionCount(D->getBody())); + CountMap[D->getBody()] = BodyCount; Visit(D->getBody()); } @@ -287,25 +292,22 @@ struct ComputeRegionCounts : public ConstStmtVisitor { void VisitCapturedDecl(const CapturedDecl *D) { // Counter tracks entry to the capture body. - RegionCounter Cnt(PGO, D->getBody()); - Cnt.beginRegion(); - CountMap[D->getBody()] = PGO.getCurrentRegionCount(); + uint64_t BodyCount = setCount(PGO.getRegionCount(D->getBody())); + CountMap[D->getBody()] = BodyCount; Visit(D->getBody()); } void VisitObjCMethodDecl(const ObjCMethodDecl *D) { // Counter tracks entry to the method body. - RegionCounter Cnt(PGO, D->getBody()); - Cnt.beginRegion(); - CountMap[D->getBody()] = PGO.getCurrentRegionCount(); + uint64_t BodyCount = setCount(PGO.getRegionCount(D->getBody())); + CountMap[D->getBody()] = BodyCount; Visit(D->getBody()); } void VisitBlockDecl(const BlockDecl *D) { // Counter tracks entry to the block body. - RegionCounter Cnt(PGO, D->getBody()); - Cnt.beginRegion(); - CountMap[D->getBody()] = PGO.getCurrentRegionCount(); + uint64_t BodyCount = setCount(PGO.getRegionCount(D->getBody())); + CountMap[D->getBody()] = BodyCount; Visit(D->getBody()); } @@ -334,9 +336,8 @@ struct ComputeRegionCounts : public ConstStmtVisitor { void VisitLabelStmt(const LabelStmt *S) { RecordNextStmtCount = false; // Counter tracks the block following the label. - RegionCounter Cnt(PGO, S); - Cnt.beginRegion(); - CountMap[S] = PGO.getCurrentRegionCount(); + uint64_t BlockCount = setCount(PGO.getRegionCount(S)); + CountMap[S] = BlockCount; Visit(S->getSubStmt()); } @@ -358,52 +359,47 @@ struct ComputeRegionCounts : public ConstStmtVisitor { void VisitWhileStmt(const WhileStmt *S) { RecordStmtCount(S); - // Counter tracks the body of the loop. - RegionCounter Cnt(PGO, S); + uint64_t ParentCount = PGO.getCurrentRegionCount(); + BreakContinueStack.push_back(BreakContinue()); // Visit the body region first so the break/continue adjustments can be // included when visiting the condition. - Cnt.beginRegion(); + uint64_t BodyCount = setCount(PGO.getRegionCount(S)); CountMap[S->getBody()] = PGO.getCurrentRegionCount(); Visit(S->getBody()); - Cnt.adjustForControlFlow(); + uint64_t BackedgeCount = PGO.getCurrentRegionCount(); // ...then go back and propagate counts through the condition. The count // at the start of the condition is the sum of the incoming edges, // the backedge from the end of the loop body, and the edges from // continue statements. BreakContinue BC = BreakContinueStack.pop_back_val(); - Cnt.setCurrentRegionCount(Cnt.getParentCount() + Cnt.getAdjustedCount() + - BC.ContinueCount); - CountMap[S->getCond()] = PGO.getCurrentRegionCount(); + uint64_t CondCount = + setCount(ParentCount + BackedgeCount + BC.ContinueCount); + CountMap[S->getCond()] = CondCount; Visit(S->getCond()); - Cnt.adjustForControlFlow(); - Cnt.applyAdjustmentsToRegion(BC.BreakCount + BC.ContinueCount); + setCount(BC.BreakCount + CondCount - BodyCount); RecordNextStmtCount = true; } void VisitDoStmt(const DoStmt *S) { RecordStmtCount(S); - // Counter tracks the body of the loop. - RegionCounter Cnt(PGO, S); + uint64_t LoopCount = PGO.getRegionCount(S); + BreakContinueStack.push_back(BreakContinue()); - Cnt.beginRegion(/*AddIncomingFallThrough=*/true); - CountMap[S->getBody()] = PGO.getCurrentRegionCount(); + // The count doesn't include the fallthrough from the parent scope. Add it. + uint64_t BodyCount = setCount(LoopCount + PGO.getCurrentRegionCount()); + CountMap[S->getBody()] = BodyCount; Visit(S->getBody()); - Cnt.adjustForControlFlow(); + uint64_t BackedgeCount = PGO.getCurrentRegionCount(); BreakContinue BC = BreakContinueStack.pop_back_val(); // The count at the start of the condition is equal to the count at the - // end of the body. The adjusted count does not include either the - // fall-through count coming into the loop or the continue count, so add - // both of those separately. This is coincidentally the same equation as - // with while loops but for different reasons. - Cnt.setCurrentRegionCount(Cnt.getParentCount() + Cnt.getAdjustedCount() + - BC.ContinueCount); - CountMap[S->getCond()] = PGO.getCurrentRegionCount(); + // end of the body, plus any continues. + uint64_t CondCount = setCount(BackedgeCount + BC.ContinueCount); + CountMap[S->getCond()] = CondCount; Visit(S->getCond()); - Cnt.adjustForControlFlow(); - Cnt.applyAdjustmentsToRegion(BC.BreakCount + BC.ContinueCount); + setCount(BC.BreakCount + CondCount - LoopCount); RecordNextStmtCount = true; } @@ -411,37 +407,34 @@ struct ComputeRegionCounts : public ConstStmtVisitor { RecordStmtCount(S); if (S->getInit()) Visit(S->getInit()); - // Counter tracks the body of the loop. - RegionCounter Cnt(PGO, S); + + uint64_t ParentCount = PGO.getCurrentRegionCount(); + BreakContinueStack.push_back(BreakContinue()); // Visit the body region first. (This is basically the same as a while // loop; see further comments in VisitWhileStmt.) - Cnt.beginRegion(); - CountMap[S->getBody()] = PGO.getCurrentRegionCount(); + uint64_t BodyCount = setCount(PGO.getRegionCount(S)); + CountMap[S->getBody()] = BodyCount; Visit(S->getBody()); - Cnt.adjustForControlFlow(); + uint64_t BackedgeCount = PGO.getCurrentRegionCount(); + BreakContinue BC = BreakContinueStack.pop_back_val(); // The increment is essentially part of the body but it needs to include // the count for all the continue statements. if (S->getInc()) { - Cnt.setCurrentRegionCount(PGO.getCurrentRegionCount() + - BreakContinueStack.back().ContinueCount); - CountMap[S->getInc()] = PGO.getCurrentRegionCount(); + uint64_t IncCount = setCount(BackedgeCount + BC.ContinueCount); + CountMap[S->getInc()] = IncCount; Visit(S->getInc()); - Cnt.adjustForControlFlow(); } - BreakContinue BC = BreakContinueStack.pop_back_val(); - // ...then go back and propagate counts through the condition. + uint64_t CondCount = + setCount(ParentCount + BackedgeCount + BC.ContinueCount); if (S->getCond()) { - Cnt.setCurrentRegionCount(Cnt.getParentCount() + Cnt.getAdjustedCount() + - BC.ContinueCount); - CountMap[S->getCond()] = PGO.getCurrentRegionCount(); + CountMap[S->getCond()] = CondCount; Visit(S->getCond()); - Cnt.adjustForControlFlow(); } - Cnt.applyAdjustmentsToRegion(BC.BreakCount + BC.ContinueCount); + setCount(BC.BreakCount + CondCount - BodyCount); RecordNextStmtCount = true; } @@ -450,47 +443,47 @@ struct ComputeRegionCounts : public ConstStmtVisitor { Visit(S->getLoopVarStmt()); Visit(S->getRangeStmt()); Visit(S->getBeginEndStmt()); - // Counter tracks the body of the loop. - RegionCounter Cnt(PGO, S); + + uint64_t ParentCount = PGO.getCurrentRegionCount(); + BreakContinueStack.push_back(BreakContinue()); // Visit the body region first. (This is basically the same as a while // loop; see further comments in VisitWhileStmt.) - Cnt.beginRegion(); - CountMap[S->getBody()] = PGO.getCurrentRegionCount(); + uint64_t BodyCount = setCount(PGO.getRegionCount(S)); + CountMap[S->getBody()] = BodyCount; Visit(S->getBody()); - Cnt.adjustForControlFlow(); + uint64_t BackedgeCount = PGO.getCurrentRegionCount(); + BreakContinue BC = BreakContinueStack.pop_back_val(); // The increment is essentially part of the body but it needs to include // the count for all the continue statements. - Cnt.setCurrentRegionCount(PGO.getCurrentRegionCount() + - BreakContinueStack.back().ContinueCount); - CountMap[S->getInc()] = PGO.getCurrentRegionCount(); + uint64_t IncCount = setCount(BackedgeCount + BC.ContinueCount); + CountMap[S->getInc()] = IncCount; Visit(S->getInc()); - Cnt.adjustForControlFlow(); - - BreakContinue BC = BreakContinueStack.pop_back_val(); // ...then go back and propagate counts through the condition. - Cnt.setCurrentRegionCount(Cnt.getParentCount() + Cnt.getAdjustedCount() + - BC.ContinueCount); - CountMap[S->getCond()] = PGO.getCurrentRegionCount(); + uint64_t CondCount = + setCount(ParentCount + BackedgeCount + BC.ContinueCount); + CountMap[S->getCond()] = CondCount; Visit(S->getCond()); - Cnt.applyAdjustmentsToRegion(BC.BreakCount + BC.ContinueCount); + setCount(BC.BreakCount + CondCount - BodyCount); RecordNextStmtCount = true; } void VisitObjCForCollectionStmt(const ObjCForCollectionStmt *S) { RecordStmtCount(S); Visit(S->getElement()); - // Counter tracks the body of the loop. - RegionCounter Cnt(PGO, S); + uint64_t ParentCount = PGO.getCurrentRegionCount(); BreakContinueStack.push_back(BreakContinue()); - Cnt.beginRegion(); - CountMap[S->getBody()] = PGO.getCurrentRegionCount(); + // Counter tracks the body of the loop. + uint64_t BodyCount = setCount(PGO.getRegionCount(S)); + CountMap[S->getBody()] = BodyCount; Visit(S->getBody()); + uint64_t BackedgeCount = PGO.getCurrentRegionCount(); BreakContinue BC = BreakContinueStack.pop_back_val(); - Cnt.adjustForControlFlow(); - Cnt.applyAdjustmentsToRegion(BC.BreakCount + BC.ContinueCount); + + setCount(BC.BreakCount + ParentCount + BackedgeCount + BC.ContinueCount - + BodyCount); RecordNextStmtCount = true; } @@ -505,53 +498,45 @@ struct ComputeRegionCounts : public ConstStmtVisitor { if (!BreakContinueStack.empty()) BreakContinueStack.back().ContinueCount += BC.ContinueCount; // Counter tracks the exit block of the switch. - RegionCounter ExitCnt(PGO, S); - ExitCnt.beginRegion(); + setCount(PGO.getRegionCount(S)); RecordNextStmtCount = true; } - void VisitCaseStmt(const CaseStmt *S) { + void VisitSwitchCase(const SwitchCase *S) { RecordNextStmtCount = false; // Counter for this particular case. This counts only jumps from the // switch header and does not include fallthrough from the case before // this one. - RegionCounter Cnt(PGO, S); - Cnt.beginRegion(/*AddIncomingFallThrough=*/true); - CountMap[S] = Cnt.getCount(); - RecordNextStmtCount = true; - Visit(S->getSubStmt()); - } - - void VisitDefaultStmt(const DefaultStmt *S) { - RecordNextStmtCount = false; - // Counter for this default case. This does not include fallthrough from - // the previous case. - RegionCounter Cnt(PGO, S); - Cnt.beginRegion(/*AddIncomingFallThrough=*/true); - CountMap[S] = Cnt.getCount(); + uint64_t CaseCount = PGO.getRegionCount(S); + setCount(PGO.getCurrentRegionCount() + CaseCount); + // We need the count without fallthrough in the mapping, so it's more useful + // for branch probabilities. + CountMap[S] = CaseCount; RecordNextStmtCount = true; Visit(S->getSubStmt()); } void VisitIfStmt(const IfStmt *S) { RecordStmtCount(S); - // Counter tracks the "then" part of an if statement. The count for - // the "else" part, if it exists, will be calculated from this counter. - RegionCounter Cnt(PGO, S); + uint64_t ParentCount = PGO.getCurrentRegionCount(); Visit(S->getCond()); - Cnt.beginRegion(); - CountMap[S->getThen()] = PGO.getCurrentRegionCount(); + // Counter tracks the "then" part of an if statement. The count for + // the "else" part, if it exists, will be calculated from this counter. + uint64_t ThenCount = setCount(PGO.getRegionCount(S)); + CountMap[S->getThen()] = ThenCount; Visit(S->getThen()); - Cnt.adjustForControlFlow(); + uint64_t OutCount = PGO.getCurrentRegionCount(); + uint64_t ElseCount = ParentCount - ThenCount; if (S->getElse()) { - Cnt.beginElseRegion(); - CountMap[S->getElse()] = PGO.getCurrentRegionCount(); + setCount(ElseCount); + CountMap[S->getElse()] = ElseCount; Visit(S->getElse()); - Cnt.adjustForControlFlow(); - } - Cnt.applyAdjustmentsToRegion(0); + OutCount += PGO.getCurrentRegionCount(); + } else + OutCount += ElseCount; + setCount(OutCount); RecordNextStmtCount = true; } @@ -561,64 +546,60 @@ struct ComputeRegionCounts : public ConstStmtVisitor { for (unsigned I = 0, E = S->getNumHandlers(); I < E; ++I) Visit(S->getHandler(I)); // Counter tracks the continuation block of the try statement. - RegionCounter Cnt(PGO, S); - Cnt.beginRegion(); + setCount(PGO.getRegionCount(S)); RecordNextStmtCount = true; } void VisitCXXCatchStmt(const CXXCatchStmt *S) { RecordNextStmtCount = false; // Counter tracks the catch statement's handler block. - RegionCounter Cnt(PGO, S); - Cnt.beginRegion(); - CountMap[S] = PGO.getCurrentRegionCount(); + uint64_t CatchCount = setCount(PGO.getRegionCount(S)); + CountMap[S] = CatchCount; Visit(S->getHandlerBlock()); } void VisitAbstractConditionalOperator(const AbstractConditionalOperator *E) { RecordStmtCount(E); - // Counter tracks the "true" part of a conditional operator. The - // count in the "false" part will be calculated from this counter. - RegionCounter Cnt(PGO, E); + uint64_t ParentCount = PGO.getCurrentRegionCount(); Visit(E->getCond()); - Cnt.beginRegion(); - CountMap[E->getTrueExpr()] = PGO.getCurrentRegionCount(); + // Counter tracks the "true" part of a conditional operator. The + // count in the "false" part will be calculated from this counter. + uint64_t TrueCount = setCount(PGO.getRegionCount(E)); + CountMap[E->getTrueExpr()] = TrueCount; Visit(E->getTrueExpr()); - Cnt.adjustForControlFlow(); + uint64_t OutCount = PGO.getCurrentRegionCount(); - Cnt.beginElseRegion(); - CountMap[E->getFalseExpr()] = PGO.getCurrentRegionCount(); + uint64_t FalseCount = setCount(ParentCount - TrueCount); + CountMap[E->getFalseExpr()] = FalseCount; Visit(E->getFalseExpr()); - Cnt.adjustForControlFlow(); + OutCount += PGO.getCurrentRegionCount(); - Cnt.applyAdjustmentsToRegion(0); + setCount(OutCount); RecordNextStmtCount = true; } void VisitBinLAnd(const BinaryOperator *E) { RecordStmtCount(E); - // Counter tracks the right hand side of a logical and operator. - RegionCounter Cnt(PGO, E); + uint64_t ParentCount = PGO.getCurrentRegionCount(); Visit(E->getLHS()); - Cnt.beginRegion(); - CountMap[E->getRHS()] = PGO.getCurrentRegionCount(); + // Counter tracks the right hand side of a logical and operator. + uint64_t RHSCount = setCount(PGO.getRegionCount(E)); + CountMap[E->getRHS()] = RHSCount; Visit(E->getRHS()); - Cnt.adjustForControlFlow(); - Cnt.applyAdjustmentsToRegion(0); + setCount(ParentCount + RHSCount - PGO.getCurrentRegionCount()); RecordNextStmtCount = true; } void VisitBinLOr(const BinaryOperator *E) { RecordStmtCount(E); - // Counter tracks the right hand side of a logical or operator. - RegionCounter Cnt(PGO, E); + uint64_t ParentCount = PGO.getCurrentRegionCount(); Visit(E->getLHS()); - Cnt.beginRegion(); - CountMap[E->getRHS()] = PGO.getCurrentRegionCount(); + // Counter tracks the right hand side of a logical or operator. + uint64_t RHSCount = setCount(PGO.getRegionCount(E)); + CountMap[E->getRHS()] = RHSCount; Visit(E->getRHS()); - Cnt.adjustForControlFlow(); - Cnt.applyAdjustmentsToRegion(0); + setCount(ParentCount + RHSCount - PGO.getCurrentRegionCount()); RecordNextStmtCount = true; } }; diff --git a/clang/lib/CodeGen/CodeGenPGO.h b/clang/lib/CodeGen/CodeGenPGO.h index 392b1b978bf8..d73d030e7b0d 100644 --- a/clang/lib/CodeGen/CodeGenPGO.h +++ b/clang/lib/CodeGen/CodeGenPGO.h @@ -125,79 +125,6 @@ public: } }; -/// A counter for a particular region. This is the primary interface through -/// which clients manage PGO counters and their values. -class RegionCounter { - CodeGenPGO *PGO; - uint64_t Count; - uint64_t ParentCount; - uint64_t RegionCount; - int64_t Adjust; - -public: - RegionCounter(CodeGenPGO &PGO, const Stmt *S) - : PGO(&PGO), Count(PGO.getRegionCount(S)), - ParentCount(PGO.getCurrentRegionCount()), Adjust(0) {} - - /// Get the value of the counter. In most cases this is the number of times - /// the region of the counter was entered, but for switch labels it's the - /// number of direct jumps to that label. - uint64_t getCount() const { return Count; } - - /// Get the value of the counter with adjustments applied. Adjustments occur - /// when control enters or leaves the region abnormally; i.e., if there is a - /// jump to a label within the region, or if the function can return from - /// within the region. The adjusted count, then, is the value of the counter - /// at the end of the region. - uint64_t getAdjustedCount() const { - return Count + Adjust; - } - - /// Get the value of the counter in this region's parent, i.e., the region - /// that was active when this region began. This is useful for deriving - /// counts in implicitly counted regions, like the false case of a condition - /// or the normal exits of a loop. - uint64_t getParentCount() const { return ParentCount; } - - void beginRegion(bool AddIncomingFallThrough=false) { - RegionCount = Count; - if (AddIncomingFallThrough) - RegionCount += PGO->getCurrentRegionCount(); - PGO->setCurrentRegionCount(RegionCount); - } - - /// For counters on boolean branches, begins tracking adjustments for the - /// uncounted path. - void beginElseRegion() { - RegionCount = ParentCount - Count; - PGO->setCurrentRegionCount(RegionCount); - } - - /// Reset the current region count. - void setCurrentRegionCount(uint64_t CurrentCount) { - RegionCount = CurrentCount; - PGO->setCurrentRegionCount(RegionCount); - } - - /// Adjust for non-local control flow after emitting a subexpression or - /// substatement. This must be called to account for constructs such as gotos, - /// labels, and returns, so that we can ensure that our region's count is - /// correct in the code that follows. - void adjustForControlFlow() { - Adjust += PGO->getCurrentRegionCount() - RegionCount; - // Reset the region count in case this is called again later. - RegionCount = PGO->getCurrentRegionCount(); - } - - /// Commit all adjustments to the current region. If the region is a loop, - /// the LoopAdjust value should be the count of all the breaks and continues - /// from the loop, to compensate for those counts being deducted from the - /// adjustments for the body of the loop. - void applyAdjustmentsToRegion(uint64_t LoopAdjust) { - PGO->setCurrentRegionCount(ParentCount + Adjust + LoopAdjust); - } -}; - } // end namespace CodeGen } // end namespace clang