diff --git a/clang/lib/CodeGen/CodeGenPGO.cpp b/clang/lib/CodeGen/CodeGenPGO.cpp index d9016774fa11..158a6ac17ba3 100644 --- a/clang/lib/CodeGen/CodeGenPGO.cpp +++ b/clang/lib/CodeGen/CodeGenPGO.cpp @@ -242,6 +242,9 @@ struct ComputeRegionCounts : public ConstStmtVisitor { /// next statement, such as at the exit of a loop. bool RecordNextStmtCount; + /// The count at the current location in the traversal. + uint64_t CurrentCount; + /// The map of statements to count values. llvm::DenseMap &CountMap; @@ -259,14 +262,14 @@ struct ComputeRegionCounts : public ConstStmtVisitor { void RecordStmtCount(const Stmt *S) { if (RecordNextStmtCount) { - CountMap[S] = PGO.getCurrentRegionCount(); + CountMap[S] = CurrentCount; RecordNextStmtCount = false; } } /// Set and return the current count. uint64_t setCount(uint64_t Count) { - PGO.setCurrentRegionCount(Count); + CurrentCount = Count; return Count; } @@ -315,7 +318,7 @@ struct ComputeRegionCounts : public ConstStmtVisitor { RecordStmtCount(S); if (S->getRetValue()) Visit(S->getRetValue()); - PGO.setCurrentRegionUnreachable(); + CurrentCount = 0; RecordNextStmtCount = true; } @@ -323,13 +326,13 @@ struct ComputeRegionCounts : public ConstStmtVisitor { RecordStmtCount(E); if (E->getSubExpr()) Visit(E->getSubExpr()); - PGO.setCurrentRegionUnreachable(); + CurrentCount = 0; RecordNextStmtCount = true; } void VisitGotoStmt(const GotoStmt *S) { RecordStmtCount(S); - PGO.setCurrentRegionUnreachable(); + CurrentCount = 0; RecordNextStmtCount = true; } @@ -344,30 +347,30 @@ struct ComputeRegionCounts : public ConstStmtVisitor { void VisitBreakStmt(const BreakStmt *S) { RecordStmtCount(S); assert(!BreakContinueStack.empty() && "break not in a loop or switch!"); - BreakContinueStack.back().BreakCount += PGO.getCurrentRegionCount(); - PGO.setCurrentRegionUnreachable(); + BreakContinueStack.back().BreakCount += CurrentCount; + CurrentCount = 0; RecordNextStmtCount = true; } void VisitContinueStmt(const ContinueStmt *S) { RecordStmtCount(S); assert(!BreakContinueStack.empty() && "continue stmt not in a loop!"); - BreakContinueStack.back().ContinueCount += PGO.getCurrentRegionCount(); - PGO.setCurrentRegionUnreachable(); + BreakContinueStack.back().ContinueCount += CurrentCount; + CurrentCount = 0; RecordNextStmtCount = true; } void VisitWhileStmt(const WhileStmt *S) { RecordStmtCount(S); - uint64_t ParentCount = PGO.getCurrentRegionCount(); + uint64_t ParentCount = CurrentCount; BreakContinueStack.push_back(BreakContinue()); // Visit the body region first so the break/continue adjustments can be // included when visiting the condition. uint64_t BodyCount = setCount(PGO.getRegionCount(S)); - CountMap[S->getBody()] = PGO.getCurrentRegionCount(); + CountMap[S->getBody()] = CurrentCount; Visit(S->getBody()); - uint64_t BackedgeCount = PGO.getCurrentRegionCount(); + uint64_t BackedgeCount = CurrentCount; // ...then go back and propagate counts through the condition. The count // at the start of the condition is the sum of the incoming edges, @@ -388,10 +391,10 @@ struct ComputeRegionCounts : public ConstStmtVisitor { BreakContinueStack.push_back(BreakContinue()); // The count doesn't include the fallthrough from the parent scope. Add it. - uint64_t BodyCount = setCount(LoopCount + PGO.getCurrentRegionCount()); + uint64_t BodyCount = setCount(LoopCount + CurrentCount); CountMap[S->getBody()] = BodyCount; Visit(S->getBody()); - uint64_t BackedgeCount = PGO.getCurrentRegionCount(); + uint64_t BackedgeCount = CurrentCount; BreakContinue BC = BreakContinueStack.pop_back_val(); // The count at the start of the condition is equal to the count at the @@ -408,7 +411,7 @@ struct ComputeRegionCounts : public ConstStmtVisitor { if (S->getInit()) Visit(S->getInit()); - uint64_t ParentCount = PGO.getCurrentRegionCount(); + uint64_t ParentCount = CurrentCount; BreakContinueStack.push_back(BreakContinue()); // Visit the body region first. (This is basically the same as a while @@ -416,7 +419,7 @@ struct ComputeRegionCounts : public ConstStmtVisitor { uint64_t BodyCount = setCount(PGO.getRegionCount(S)); CountMap[S->getBody()] = BodyCount; Visit(S->getBody()); - uint64_t BackedgeCount = PGO.getCurrentRegionCount(); + uint64_t BackedgeCount = CurrentCount; BreakContinue BC = BreakContinueStack.pop_back_val(); // The increment is essentially part of the body but it needs to include @@ -444,15 +447,14 @@ struct ComputeRegionCounts : public ConstStmtVisitor { Visit(S->getRangeStmt()); Visit(S->getBeginEndStmt()); - uint64_t ParentCount = PGO.getCurrentRegionCount(); - + uint64_t ParentCount = CurrentCount; BreakContinueStack.push_back(BreakContinue()); // Visit the body region first. (This is basically the same as a while // loop; see further comments in VisitWhileStmt.) uint64_t BodyCount = setCount(PGO.getRegionCount(S)); CountMap[S->getBody()] = BodyCount; Visit(S->getBody()); - uint64_t BackedgeCount = PGO.getCurrentRegionCount(); + uint64_t BackedgeCount = CurrentCount; BreakContinue BC = BreakContinueStack.pop_back_val(); // The increment is essentially part of the body but it needs to include @@ -473,13 +475,13 @@ struct ComputeRegionCounts : public ConstStmtVisitor { void VisitObjCForCollectionStmt(const ObjCForCollectionStmt *S) { RecordStmtCount(S); Visit(S->getElement()); - uint64_t ParentCount = PGO.getCurrentRegionCount(); + uint64_t ParentCount = CurrentCount; BreakContinueStack.push_back(BreakContinue()); // Counter tracks the body of the loop. uint64_t BodyCount = setCount(PGO.getRegionCount(S)); CountMap[S->getBody()] = BodyCount; Visit(S->getBody()); - uint64_t BackedgeCount = PGO.getCurrentRegionCount(); + uint64_t BackedgeCount = CurrentCount; BreakContinue BC = BreakContinueStack.pop_back_val(); setCount(BC.BreakCount + ParentCount + BackedgeCount + BC.ContinueCount - @@ -490,7 +492,7 @@ struct ComputeRegionCounts : public ConstStmtVisitor { void VisitSwitchStmt(const SwitchStmt *S) { RecordStmtCount(S); Visit(S->getCond()); - PGO.setCurrentRegionUnreachable(); + CurrentCount = 0; BreakContinueStack.push_back(BreakContinue()); Visit(S->getBody()); // If the switch is inside a loop, add the continue counts. @@ -508,7 +510,7 @@ struct ComputeRegionCounts : public ConstStmtVisitor { // switch header and does not include fallthrough from the case before // this one. uint64_t CaseCount = PGO.getRegionCount(S); - setCount(PGO.getCurrentRegionCount() + CaseCount); + setCount(CurrentCount + CaseCount); // We need the count without fallthrough in the mapping, so it's more useful // for branch probabilities. CountMap[S] = CaseCount; @@ -518,7 +520,7 @@ struct ComputeRegionCounts : public ConstStmtVisitor { void VisitIfStmt(const IfStmt *S) { RecordStmtCount(S); - uint64_t ParentCount = PGO.getCurrentRegionCount(); + uint64_t ParentCount = CurrentCount; Visit(S->getCond()); // Counter tracks the "then" part of an if statement. The count for @@ -526,14 +528,14 @@ struct ComputeRegionCounts : public ConstStmtVisitor { uint64_t ThenCount = setCount(PGO.getRegionCount(S)); CountMap[S->getThen()] = ThenCount; Visit(S->getThen()); - uint64_t OutCount = PGO.getCurrentRegionCount(); + uint64_t OutCount = CurrentCount; uint64_t ElseCount = ParentCount - ThenCount; if (S->getElse()) { setCount(ElseCount); CountMap[S->getElse()] = ElseCount; Visit(S->getElse()); - OutCount += PGO.getCurrentRegionCount(); + OutCount += CurrentCount; } else OutCount += ElseCount; setCount(OutCount); @@ -560,7 +562,7 @@ struct ComputeRegionCounts : public ConstStmtVisitor { void VisitAbstractConditionalOperator(const AbstractConditionalOperator *E) { RecordStmtCount(E); - uint64_t ParentCount = PGO.getCurrentRegionCount(); + uint64_t ParentCount = CurrentCount; Visit(E->getCond()); // Counter tracks the "true" part of a conditional operator. The @@ -568,12 +570,12 @@ struct ComputeRegionCounts : public ConstStmtVisitor { uint64_t TrueCount = setCount(PGO.getRegionCount(E)); CountMap[E->getTrueExpr()] = TrueCount; Visit(E->getTrueExpr()); - uint64_t OutCount = PGO.getCurrentRegionCount(); + uint64_t OutCount = CurrentCount; uint64_t FalseCount = setCount(ParentCount - TrueCount); CountMap[E->getFalseExpr()] = FalseCount; Visit(E->getFalseExpr()); - OutCount += PGO.getCurrentRegionCount(); + OutCount += CurrentCount; setCount(OutCount); RecordNextStmtCount = true; @@ -581,25 +583,25 @@ struct ComputeRegionCounts : public ConstStmtVisitor { void VisitBinLAnd(const BinaryOperator *E) { RecordStmtCount(E); - uint64_t ParentCount = PGO.getCurrentRegionCount(); + uint64_t ParentCount = CurrentCount; Visit(E->getLHS()); // Counter tracks the right hand side of a logical and operator. uint64_t RHSCount = setCount(PGO.getRegionCount(E)); CountMap[E->getRHS()] = RHSCount; Visit(E->getRHS()); - setCount(ParentCount + RHSCount - PGO.getCurrentRegionCount()); + setCount(ParentCount + RHSCount - CurrentCount); RecordNextStmtCount = true; } void VisitBinLOr(const BinaryOperator *E) { RecordStmtCount(E); - uint64_t ParentCount = PGO.getCurrentRegionCount(); + uint64_t ParentCount = CurrentCount; Visit(E->getLHS()); // Counter tracks the right hand side of a logical or operator. uint64_t RHSCount = setCount(PGO.getRegionCount(E)); CountMap[E->getRHS()] = RHSCount; Visit(E->getRHS()); - setCount(ParentCount + RHSCount - PGO.getCurrentRegionCount()); + setCount(ParentCount + RHSCount - CurrentCount); RecordNextStmtCount = true; } }; diff --git a/clang/lib/CodeGen/CodeGenPGO.h b/clang/lib/CodeGen/CodeGenPGO.h index d73d030e7b0d..13bb5a2fc86f 100644 --- a/clang/lib/CodeGen/CodeGenPGO.h +++ b/clang/lib/CodeGen/CodeGenPGO.h @@ -60,11 +60,6 @@ public: /// exits. void setCurrentRegionCount(uint64_t Count) { CurrentRegionCount = Count; } - /// Indicate that the current region is never reached, and thus should have a - /// counter value of zero. This is important so that subsequent regions can - /// correctly track their parent counts. - void setCurrentRegionUnreachable() { setCurrentRegionCount(0); } - /// Check if an execution count is known for a given statement. If so, return /// true and put the value in Count; else return false. Optional getStmtCount(const Stmt *S) {