diff --git a/llvm/include/llvm/IR/Instructions.h b/llvm/include/llvm/IR/Instructions.h index 82833658c418..2e35f5a7fff8 100644 --- a/llvm/include/llvm/IR/Instructions.h +++ b/llvm/include/llvm/IR/Instructions.h @@ -3439,15 +3439,24 @@ public: /// their prof branch_weights metadata. class SwitchInstProfUpdateWrapper { SwitchInst &SI; - Optional > Weights; - bool Changed = false; + Optional > Weights = None; + + // Sticky invalid state is needed to safely ignore operations with prof data + // in cases where SwitchInstProfUpdateWrapper is created from SwitchInst + // with inconsistent prof data. TODO: once we fix all prof data + // inconsistencies we can turn invalid state to assertions. + enum { + Invalid, + Initialized, + Changed + } State = Invalid; protected: static MDNode *getProfBranchWeightsMD(const SwitchInst &SI); MDNode *buildProfBranchWeightsMD(); - Optional > getProfBranchWeights(); + void init(); public: using CaseWeightOpt = Optional; @@ -3455,11 +3464,10 @@ public: SwitchInst &operator*() { return SI; } operator SwitchInst *() { return &SI; } - SwitchInstProfUpdateWrapper(SwitchInst &SI) - : SI(SI), Weights(getProfBranchWeights()) {} + SwitchInstProfUpdateWrapper(SwitchInst &SI) : SI(SI) { init(); } ~SwitchInstProfUpdateWrapper() { - if (Changed) + if (State == Changed) SI.setMetadata(LLVMContext::MD_prof, buildProfBranchWeightsMD()); } diff --git a/llvm/lib/IR/Instructions.cpp b/llvm/lib/IR/Instructions.cpp index 8812df35e26b..ad082a9c24f3 100644 --- a/llvm/lib/IR/Instructions.cpp +++ b/llvm/lib/IR/Instructions.cpp @@ -45,6 +45,12 @@ using namespace llvm; +static cl::opt SwitchInstProfUpdateWrapperStrict( + "switch-inst-prof-update-wrapper-strict", cl::Hidden, + cl::desc("Assert that prof branch_weights metadata is valid when creating " + "an instance of SwitchInstProfUpdateWrapper"), + cl::init(false)); + //===----------------------------------------------------------------------===// // AllocaInst Class //===----------------------------------------------------------------------===// @@ -3880,7 +3886,7 @@ SwitchInstProfUpdateWrapper::getProfBranchWeightsMD(const SwitchInst &SI) { } MDNode *SwitchInstProfUpdateWrapper::buildProfBranchWeightsMD() { - assert(Changed && "called only if metadata has changed"); + assert(State == Changed && "called only if metadata has changed"); if (!Weights) return nullptr; @@ -3897,11 +3903,20 @@ MDNode *SwitchInstProfUpdateWrapper::buildProfBranchWeightsMD() { return MDBuilder(SI.getParent()->getContext()).createBranchWeights(*Weights); } -Optional > -SwitchInstProfUpdateWrapper::getProfBranchWeights() { +void SwitchInstProfUpdateWrapper::init() { MDNode *ProfileData = getProfBranchWeightsMD(SI); - if (!ProfileData) - return None; + if (!ProfileData) { + State = Initialized; + return; + } + + if (ProfileData->getNumOperands() != SI.getNumSuccessors() + 1) { + State = Invalid; + if (SwitchInstProfUpdateWrapperStrict) + assert(!"number of prof branch_weights metadata operands corresponds to" + " number of succesors"); + return; + } SmallVector Weights; for (unsigned CI = 1, CE = SI.getNumSuccessors(); CI <= CE; ++CI) { @@ -3909,7 +3924,8 @@ SwitchInstProfUpdateWrapper::getProfBranchWeights() { uint32_t CW = C->getValue().getZExtValue(); Weights.push_back(CW); } - return Weights; + State = Initialized; + this->Weights = std::move(Weights); } SwitchInst::CaseIt @@ -3917,7 +3933,7 @@ SwitchInstProfUpdateWrapper::removeCase(SwitchInst::CaseIt I) { if (Weights) { assert(SI.getNumSuccessors() == Weights->size() && "num of prof branch_weights must accord with num of successors"); - Changed = true; + State = Changed; // Copy the last case to the place of the removed one and shrink. // This is tightly coupled with the way SwitchInst::removeCase() removes // the cases in SwitchInst::removeCase(CaseIt). @@ -3932,12 +3948,15 @@ void SwitchInstProfUpdateWrapper::addCase( SwitchInstProfUpdateWrapper::CaseWeightOpt W) { SI.addCase(OnVal, Dest); + if (State == Invalid) + return; + if (!Weights && W && *W) { - Changed = true; + State = Changed; Weights = SmallVector(SI.getNumSuccessors(), 0); Weights.getValue()[SI.getNumSuccessors() - 1] = *W; } else if (Weights) { - Changed = true; + State = Changed; Weights.getValue().push_back(W ? *W : 0); } if (Weights) @@ -3948,10 +3967,11 @@ void SwitchInstProfUpdateWrapper::addCase( SymbolTableList::iterator SwitchInstProfUpdateWrapper::eraseFromParent() { // Instruction is erased. Mark as unchanged to not touch it in the destructor. - Changed = false; - - if (Weights) - Weights->resize(0); + if (State != Invalid) { + State = Initialized; + if (Weights) + Weights->resize(0); + } return SI.eraseFromParent(); } @@ -3964,7 +3984,7 @@ SwitchInstProfUpdateWrapper::getSuccessorWeight(unsigned idx) { void SwitchInstProfUpdateWrapper::setSuccessorWeight( unsigned idx, SwitchInstProfUpdateWrapper::CaseWeightOpt W) { - if (!W) + if (!W || State == Invalid) return; if (!Weights && *W) @@ -3973,7 +3993,7 @@ void SwitchInstProfUpdateWrapper::setSuccessorWeight( if (Weights) { auto &OldW = Weights.getValue()[idx]; if (*W != OldW) { - Changed = true; + State = Changed; OldW = *W; } } @@ -3983,9 +4003,10 @@ SwitchInstProfUpdateWrapper::CaseWeightOpt SwitchInstProfUpdateWrapper::getSuccessorWeight(const SwitchInst &SI, unsigned idx) { if (MDNode *ProfileData = getProfBranchWeightsMD(SI)) - return mdconst::extract(ProfileData->getOperand(idx + 1)) - ->getValue() - .getZExtValue(); + if (ProfileData->getNumOperands() == SI.getNumSuccessors() + 1) + return mdconst::extract(ProfileData->getOperand(idx + 1)) + ->getValue() + .getZExtValue(); return None; } diff --git a/llvm/unittests/IR/InstructionsTest.cpp b/llvm/unittests/IR/InstructionsTest.cpp index 3b2bd6fa81b0..70d51e5fc6d1 100644 --- a/llvm/unittests/IR/InstructionsTest.cpp +++ b/llvm/unittests/IR/InstructionsTest.cpp @@ -753,6 +753,85 @@ TEST(InstructionsTest, SwitchInst) { EXPECT_EQ(BB1.get(), Handle.getCaseSuccessor()); } +TEST(InstructionsTest, SwitchInstProfUpdateWrapper) { + LLVMContext C; + + std::unique_ptr BB1, BB2, BB3; + BB1.reset(BasicBlock::Create(C)); + BB2.reset(BasicBlock::Create(C)); + BB3.reset(BasicBlock::Create(C)); + + // We create block 0 after the others so that it gets destroyed first and + // clears the uses of the other basic blocks. + std::unique_ptr BB0(BasicBlock::Create(C)); + + auto *Int32Ty = Type::getInt32Ty(C); + + SwitchInst *SI = + SwitchInst::Create(UndefValue::get(Int32Ty), BB0.get(), 4, BB0.get()); + SI->addCase(ConstantInt::get(Int32Ty, 1), BB1.get()); + SI->addCase(ConstantInt::get(Int32Ty, 2), BB2.get()); + SI->setMetadata(LLVMContext::MD_prof, + MDBuilder(C).createBranchWeights({ 9, 1, 22 })); + + { + SwitchInstProfUpdateWrapper SIW(*SI); + EXPECT_EQ(*SIW.getSuccessorWeight(0), 9u); + EXPECT_EQ(*SIW.getSuccessorWeight(1), 1u); + EXPECT_EQ(*SIW.getSuccessorWeight(2), 22u); + SIW.setSuccessorWeight(0, 99u); + SIW.setSuccessorWeight(1, 11u); + EXPECT_EQ(*SIW.getSuccessorWeight(0), 99u); + EXPECT_EQ(*SIW.getSuccessorWeight(1), 11u); + EXPECT_EQ(*SIW.getSuccessorWeight(2), 22u); + } + + { // Create another wrapper and check that the data persist. + SwitchInstProfUpdateWrapper SIW(*SI); + EXPECT_EQ(*SIW.getSuccessorWeight(0), 99u); + EXPECT_EQ(*SIW.getSuccessorWeight(1), 11u); + EXPECT_EQ(*SIW.getSuccessorWeight(2), 22u); + } + + // Make prof data invalid by adding one extra weight. + SI->setMetadata(LLVMContext::MD_prof, MDBuilder(C).createBranchWeights( + { 99, 11, 22, 33 })); // extra + { // Invalid prof data makes wrapper act as if there were no prof data. + SwitchInstProfUpdateWrapper SIW(*SI); + ASSERT_FALSE(SIW.getSuccessorWeight(0).hasValue()); + ASSERT_FALSE(SIW.getSuccessorWeight(1).hasValue()); + ASSERT_FALSE(SIW.getSuccessorWeight(2).hasValue()); + SIW.addCase(ConstantInt::get(Int32Ty, 3), BB3.get(), 39); + ASSERT_FALSE(SIW.getSuccessorWeight(3).hasValue()); // did not add weight 39 + } + + { // With added 3rd case the prof data become consistent with num of cases. + SwitchInstProfUpdateWrapper SIW(*SI); + EXPECT_EQ(*SIW.getSuccessorWeight(0), 99u); + EXPECT_EQ(*SIW.getSuccessorWeight(1), 11u); + EXPECT_EQ(*SIW.getSuccessorWeight(2), 22u); + EXPECT_EQ(*SIW.getSuccessorWeight(3), 33u); + } + + // Make prof data invalid by removing one extra weight. + SI->setMetadata(LLVMContext::MD_prof, + MDBuilder(C).createBranchWeights({ 99, 11, 22 })); // shorter + { // Invalid prof data makes wrapper act as if there were no prof data. + SwitchInstProfUpdateWrapper SIW(*SI); + ASSERT_FALSE(SIW.getSuccessorWeight(0).hasValue()); + ASSERT_FALSE(SIW.getSuccessorWeight(1).hasValue()); + ASSERT_FALSE(SIW.getSuccessorWeight(2).hasValue()); + SIW.removeCase(SwitchInst::CaseIt(SI, 2)); + } + + { // With removed 3rd case the prof data become consistent with num of cases. + SwitchInstProfUpdateWrapper SIW(*SI); + EXPECT_EQ(*SIW.getSuccessorWeight(0), 99u); + EXPECT_EQ(*SIW.getSuccessorWeight(1), 11u); + EXPECT_EQ(*SIW.getSuccessorWeight(2), 22u); + } +} + TEST(InstructionsTest, CommuteShuffleMask) { SmallVector Indices({-1, 0, 7}); ShuffleVectorInst::commuteShuffleMask(Indices, 4);