diff --git a/llvm/lib/Analysis/LoopInfo.cpp b/llvm/lib/Analysis/LoopInfo.cpp index 0ee591313569..4664c78b8c0c 100644 --- a/llvm/lib/Analysis/LoopInfo.cpp +++ b/llvm/lib/Analysis/LoopInfo.cpp @@ -218,20 +218,13 @@ MDNode *Loop::getLoopID() const { } else { assert(!getLoopLatch() && "The loop should have no single latch at this point"); - // Go through each predecessor of the loop header and check the - // terminator for the metadata. - BasicBlock *H = getHeader(); - for (BasicBlock *BB : this->blocks()) { + // Go through the latch blocks and check the terminator for the metadata. + SmallVector LatchesBlocks; + getLoopLatches(LatchesBlocks); + for (BasicBlock *BB : LatchesBlocks) { TerminatorInst *TI = BB->getTerminator(); - MDNode *MD = nullptr; + MDNode *MD = TI->getMetadata(LLVMContext::MD_loop); - // Check if this terminator branches to the loop header. - for (BasicBlock *Successor : successors(TI)) { - if (Successor == H) { - MD = TI->getMetadata(LLVMContext::MD_loop); - break; - } - } if (!MD) return nullptr; diff --git a/llvm/unittests/Analysis/LoopInfoTest.cpp b/llvm/unittests/Analysis/LoopInfoTest.cpp index 647ce8a3c1ba..240785f9eb99 100644 --- a/llvm/unittests/Analysis/LoopInfoTest.cpp +++ b/llvm/unittests/Analysis/LoopInfoTest.cpp @@ -82,6 +82,62 @@ TEST(LoopInfoTest, LoopWithSingleLatch) { }); } +// Test loop id handling for a loop with multiple latches. +TEST(LoopInfoTest, LoopWithMultipleLatches) { + const char *ModuleStr = + "target datalayout = \"e-m:o-i64:64-f80:128-n8:16:32:64-S128\"\n" + "define void @foo(i32 %n) {\n" + "entry:\n" + " br i1 undef, label %for.cond, label %for.end\n" + "for.cond:\n" + " %i.0 = phi i32 [ 0, %entry ], [ %inc, %latch.1 ], [ %inc, %latch.2 ]\n" + " %inc = add nsw i32 %i.0, 1\n" + " %cmp = icmp slt i32 %i.0, %n\n" + " br i1 %cmp, label %latch.1, label %for.end\n" + "latch.1:\n" + " br i1 undef, label %for.cond, label %latch.2, !llvm.loop !0\n" + "latch.2:\n" + " br label %for.cond, !llvm.loop !0\n" + "for.end:\n" + " ret void\n" + "}\n" + "!0 = distinct !{!0, !1}\n" + "!1 = !{!\"llvm.loop.distribute.enable\", i1 true}\n"; + + // Parse the module. + LLVMContext Context; + std::unique_ptr M = makeLLVMModule(Context, ModuleStr); + + runWithLoopInfo(*M, "foo", [&](Function &F, LoopInfo &LI) { + Function::iterator FI = F.begin(); + F.dump(); + // First basic block is entry - skip it. + BasicBlock *Header = &*(++FI); + assert(Header->getName() == "for.cond"); + Loop *L = LI.getLoopFor(Header); + EXPECT_NE(L, nullptr); + + // This loop is not in simplified form. + EXPECT_FALSE(L->isLoopSimplifyForm()); + + // Try to get and set the metadata id for the loop. + MDNode *OldLoopID = L->getLoopID(); + EXPECT_NE(OldLoopID, nullptr); + + MDNode *NewLoopID = MDNode::get(Context, {nullptr}); + // Set operand 0 to refer to the loop id itself. + NewLoopID->replaceOperandWith(0, NewLoopID); + + L->setLoopID(NewLoopID); + EXPECT_EQ(L->getLoopID(), NewLoopID); + EXPECT_NE(L->getLoopID(), OldLoopID); + + L->setLoopID(OldLoopID); + EXPECT_EQ(L->getLoopID(), OldLoopID); + EXPECT_NE(L->getLoopID(), NewLoopID); + }); +} + TEST(LoopInfoTest, PreorderTraversals) { const char *ModuleStr = "define void @f() {\n" "entry:\n"