From 507dd40a4abadbfdaa4f49a4823ddae6c7dfec4f Mon Sep 17 00:00:00 2001 From: Sanjoy Das Date: Tue, 18 Oct 2016 17:45:16 +0000 Subject: [PATCH] [SCEV] Make CompareValueComplexity a little bit smarter This helps canonicalization in some cases. Thanks to Pankaj Chawla for the investigation and the test case! llvm-svn: 284501 --- llvm/lib/Analysis/ScalarEvolution.cpp | 14 ++- .../Analysis/ScalarEvolutionTest.cpp | 112 ++++++++++++++++++ 2 files changed, 124 insertions(+), 2 deletions(-) diff --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp index fd3cd17ec19f..9fa0de1aff81 100644 --- a/llvm/lib/Analysis/ScalarEvolution.cpp +++ b/llvm/lib/Analysis/ScalarEvolution.cpp @@ -449,7 +449,10 @@ bool SCEVUnknown::isOffsetOf(Type *&CTy, Constant *&FieldNo) const { //===----------------------------------------------------------------------===// static int CompareValueComplexity(const LoopInfo *const LI, Value *LV, - Value *RV) { + Value *RV, unsigned DepthLeft = 2) { + if (DepthLeft == 0) + return 0; + // Order pointer values after integer values. This helps SCEVExpander form // GEPs. bool LIsPointer = LV->getType()->isPointerTy(), @@ -487,7 +490,14 @@ static int CompareValueComplexity(const LoopInfo *const LI, Value *LV, // Compare the number of operands. unsigned LNumOps = LInst->getNumOperands(), RNumOps = RInst->getNumOperands(); - return (int)LNumOps - (int)RNumOps; + if (LNumOps != RNumOps || LNumOps != 1) + return (int)LNumOps - (int)RNumOps; + + // We only bother "recursing" if we have one operand to look at (so we don't + // really recurse as much as we iterate). We can consider expanding this + // logic in the future. + return CompareValueComplexity(LI, LInst->getOperand(0), + RInst->getOperand(0), DepthLeft - 1); } return 0; diff --git a/llvm/unittests/Analysis/ScalarEvolutionTest.cpp b/llvm/unittests/Analysis/ScalarEvolutionTest.cpp index 0bc99a37dd8e..91a9c73cc42c 100644 --- a/llvm/unittests/Analysis/ScalarEvolutionTest.cpp +++ b/llvm/unittests/Analysis/ScalarEvolutionTest.cpp @@ -14,12 +14,16 @@ #include "llvm/Analysis/TargetLibraryInfo.h" #include "llvm/ADT/SmallVector.h" #include "llvm/Analysis/LoopInfo.h" +#include "llvm/AsmParser/Parser.h" #include "llvm/IR/Constants.h" #include "llvm/IR/Dominators.h" #include "llvm/IR/GlobalVariable.h" +#include "llvm/IR/InstIterator.h" #include "llvm/IR/LLVMContext.h" #include "llvm/IR/Module.h" #include "llvm/IR/LegacyPassManager.h" +#include "llvm/IR/Verifier.h" +#include "llvm/Support/SourceMgr.h" #include "gtest/gtest.h" namespace llvm { @@ -329,5 +333,113 @@ TEST_F(ScalarEvolutionsTest, ExpandPtrTypeSCEV) { EXPECT_TRUE(isa(Gep->getPrevNode())); } +static Instruction *getInstructionByName(Module &M, StringRef Name) { + for (auto &F : M) + for (auto &I : instructions(F)) + if (I.getName() == Name) + return &I; + llvm_unreachable("Expected to find instruction!"); +} + +TEST_F(ScalarEvolutionsTest, CommutativeExprOperandOrder) { + LLVMContext C; + SMDiagnostic Err; + std::unique_ptr M = parseAssemblyString( + "target datalayout = \"e-m:e-p:32:32-f64:32:64-f80:32-n8:16:32-S128\" " + "define void @foo(i8* nocapture %arr, i32 %n, i32* %A, i32* %B) " + " local_unnamed_addr { " + "entry: " + " %entrycond = icmp sgt i32 %n, 0 " + " br i1 %entrycond, label %loop.ph, label %for.end " + " " + "loop.ph: " + " %a = load i32, i32* %A, align 4 " + " %b = load i32, i32* %B, align 4 " + " %mul = mul nsw i32 %b, %a " + " %iv0.init = getelementptr inbounds i8, i8* %arr, i32 %mul " + " br label %loop " + " " + "loop: " + " %iv0 = phi i8* [ %iv0.inc, %loop ], [ %iv0.init, %loop.ph ] " + " %iv1 = phi i32 [ %iv1.inc, %loop ], [ 0, %loop.ph ] " + " %conv = trunc i32 %iv1 to i8 " + " store i8 %conv, i8* %iv0, align 1 " + " %iv0.inc = getelementptr inbounds i8, i8* %iv0, i32 %b " + " %iv1.inc = add nuw nsw i32 %iv1, 1 " + " %exitcond = icmp eq i32 %iv1.inc, %n " + " br i1 %exitcond, label %for.end.loopexit, label %loop " + " " + "for.end.loopexit: " + " br label %for.end " + " " + "for.end: " + " ret void " + "} " + " " + "define void @bar(i32* %X, i32* %Y, i32* %Z) { " + " %x = load i32, i32* %X " + " %y = load i32, i32* %Y " + " %z = load i32, i32* %Z " + " ret void " + "} ", + Err, C); + + assert(M && "Could not parse module?"); + assert(!verifyModule(*M) && "Must have been well formed!"); + + { + auto *IV0 = getInstructionByName(*M, "iv0"); + auto *IV0Inc = getInstructionByName(*M, "iv0.inc"); + + auto *F = M->getFunction("foo"); + assert(F && "Expected!"); + + ScalarEvolution SE = buildSE(*F); + auto *FirstExprForIV0 = SE.getSCEV(IV0); + auto *FirstExprForIV0Inc = SE.getSCEV(IV0Inc); + auto *SecondExprForIV0 = SE.getSCEV(IV0); + + EXPECT_TRUE(isa(FirstExprForIV0)); + EXPECT_TRUE(isa(FirstExprForIV0Inc)); + EXPECT_TRUE(isa(SecondExprForIV0)); + } + + { + auto *F = M->getFunction("bar"); + assert(F && "Expected!"); + + ScalarEvolution SE = buildSE(*F); + + auto *LoadArg0 = SE.getSCEV(getInstructionByName(*M, "x")); + auto *LoadArg1 = SE.getSCEV(getInstructionByName(*M, "y")); + auto *LoadArg2 = SE.getSCEV(getInstructionByName(*M, "z")); + + auto *MulA = SE.getMulExpr(LoadArg0, LoadArg1); + auto *MulB = SE.getMulExpr(LoadArg1, LoadArg0); + + EXPECT_EQ(MulA, MulB); + + SmallVector Ops0 = { LoadArg0, LoadArg1, LoadArg2 }; + SmallVector Ops1 = { LoadArg0, LoadArg2, LoadArg1 }; + SmallVector Ops2 = { LoadArg1, LoadArg0, LoadArg2 }; + SmallVector Ops3 = { LoadArg1, LoadArg2, LoadArg0 }; + SmallVector Ops4 = { LoadArg2, LoadArg1, LoadArg0 }; + SmallVector Ops5 = { LoadArg2, LoadArg0, LoadArg1 }; + + auto *Mul0 = SE.getMulExpr(Ops0); + auto *Mul1 = SE.getMulExpr(Ops1); + auto *Mul2 = SE.getMulExpr(Ops2); + auto *Mul3 = SE.getMulExpr(Ops3); + auto *Mul4 = SE.getMulExpr(Ops4); + auto *Mul5 = SE.getMulExpr(Ops5); + + EXPECT_EQ(Mul0, Mul1); + EXPECT_EQ(Mul1, Mul2); + EXPECT_EQ(Mul2, Mul3); + EXPECT_EQ(Mul3, Mul4); + EXPECT_EQ(Mul4, Mul5); + } +} + } // end anonymous namespace } // end namespace llvm