From 70dbd5fbd0a5f869126a944ae7e23058a106f8b7 Mon Sep 17 00:00:00 2001 From: Simon Dardis Date: Sat, 9 Dec 2017 23:25:57 +0000 Subject: [PATCH] Infer lowest bits of an integer Multiply when the low bits of the operands are known When the lowest bits of the operands to an integer multiply are known, the low bits of the result are deducible. Code to deduce known-zero bottom bits already existed, but this change improves on that by deducing known-ones. Patch by: Pedro Ferreira Reviewers: craig.topper, sanjoy, efriedma Differential Revision: https://reviews.llvm.org/D34029 llvm-svn: 320269 --- llvm/lib/Analysis/ValueTracking.cpp | 75 ++++++++++++++++--- llvm/unittests/Analysis/ValueTrackingTest.cpp | 55 ++++++++++++++ 2 files changed, 121 insertions(+), 9 deletions(-) diff --git a/llvm/lib/Analysis/ValueTracking.cpp b/llvm/lib/Analysis/ValueTracking.cpp index 4f7039c6aa7a..e086d27005cc 100644 --- a/llvm/lib/Analysis/ValueTracking.cpp +++ b/llvm/lib/Analysis/ValueTracking.cpp @@ -336,21 +336,78 @@ static void computeKnownBitsMul(const Value *Op0, const Value *Op1, bool NSW, } } - // If low bits are zero in either operand, output low known-0 bits. - // Also compute a conservative estimate for high known-0 bits. - // More trickiness is possible, but this is sufficient for the - // interesting case of alignment computation. - unsigned TrailZ = Known.countMinTrailingZeros() + - Known2.countMinTrailingZeros(); + assert(!Known.hasConflict() && !Known2.hasConflict()); + // Compute a conservative estimate for high known-0 bits. unsigned LeadZ = std::max(Known.countMinLeadingZeros() + Known2.countMinLeadingZeros(), BitWidth) - BitWidth; - - TrailZ = std::min(TrailZ, BitWidth); LeadZ = std::min(LeadZ, BitWidth); + + // The result of the bottom bits of an integer multiply can be + // inferred by looking at the bottom bits of both operands and + // multiplying them together. + // We can infer at least the minimum number of known trailing bits + // of both operands. 
Depending on number of trailing zeros, we can + // infer more bits, because (a*b) <=> ((a/m) * (b/n)) * (m*n) assuming + // a and b are divisible by m and n respectively. + // We then calculate how many of those bits are inferrable and set + // the output. For example, the i8 mul: + // a = XXXX1100 (12) + // b = XXXX1110 (14) + // We know the bottom 3 bits are zero since the first can be divided by + // 4 and the second by 2, thus having ((12/4) * (14/2)) * (2*4). + // Applying the multiplication to the trimmed arguments gets: + // XX11 (3) + // X111 (7) + // ------- + // XX11 + // XX11 + // XX11 + // XX11 + // ------- + // XXXXX01 + // Which allows us to infer the 2 LSBs. Since we're multiplying the result + // by 8, the bottom 3 bits will be 0, so we can infer a total of 5 bits. + // The proof for this can be described as: + // Pre: (C1 >= 0) && (C1 < (1 << C5)) && (C2 >= 0) && (C2 < (1 << C6)) && + // (C7 == (1 << (umin(countTrailingZeros(C1), C5) + + // umin(countTrailingZeros(C2), C6) + + // umin(C5 - umin(countTrailingZeros(C1), C5), + // C6 - umin(countTrailingZeros(C2), C6)))) - 1) + // %aa = shl i8 %a, C5 + // %bb = shl i8 %b, C6 + // %aaa = or i8 %aa, C1 + // %bbb = or i8 %bb, C2 + // %mul = mul i8 %aaa, %bbb + // %mask = and i8 %mul, C7 + // => + // %mask = i8 ((C1*C2)&C7) + // Where C5, C6 describe the known bits of %a, %b + // C1, C2 describe the known bottom bits of %a, %b. + // C7 describes the mask of the known bits of the result. + APInt Bottom0 = Known.One; + APInt Bottom1 = Known2.One; + + // How many times we'd be able to divide each argument by 2 (shr by 1). + // This gives us the number of trailing zeros on the multiplication result. 
+ unsigned TrailBitsKnown0 = (Known.Zero | Known.One).countTrailingOnes(); + unsigned TrailBitsKnown1 = (Known2.Zero | Known2.One).countTrailingOnes(); + unsigned TrailZero0 = Known.countMinTrailingZeros(); + unsigned TrailZero1 = Known2.countMinTrailingZeros(); + unsigned TrailZ = TrailZero0 + TrailZero1; + + // Figure out the fewest known-bits operand. + unsigned SmallestOperand = std::min(TrailBitsKnown0 - TrailZero0, + TrailBitsKnown1 - TrailZero1); + unsigned ResultBitsKnown = std::min(SmallestOperand + TrailZ, BitWidth); + + APInt BottomKnown = Bottom0.getLoBits(TrailBitsKnown0) * + Bottom1.getLoBits(TrailBitsKnown1); + Known.resetAll(); - Known.Zero.setLowBits(TrailZ); Known.Zero.setHighBits(LeadZ); + Known.Zero |= (~BottomKnown).getLoBits(ResultBitsKnown); + Known.One |= BottomKnown.getLoBits(ResultBitsKnown); // Only make use of no-wrap flags if we failed to compute the sign bit // directly. This matters if the multiplication always overflows, in diff --git a/llvm/unittests/Analysis/ValueTrackingTest.cpp b/llvm/unittests/Analysis/ValueTrackingTest.cpp index 3c8ecfbe1ee2..cfdf264da310 100644 --- a/llvm/unittests/Analysis/ValueTrackingTest.cpp +++ b/llvm/unittests/Analysis/ValueTrackingTest.cpp @@ -15,6 +15,7 @@ #include "llvm/IR/Module.h" #include "llvm/Support/ErrorHandling.h" #include "llvm/Support/SourceMgr.h" +#include "llvm/Support/KnownBits.h" #include "gtest/gtest.h" using namespace llvm; @@ -258,3 +259,57 @@ TEST(ValueTracking, ComputeNumSignBits_PR32045) { cast<ReturnInst>(F->getEntryBlock().getTerminator())->getOperand(0); EXPECT_EQ(ComputeNumSignBits(RVal, M->getDataLayout()), 1u); } + +TEST(ValueTracking, ComputeKnownBits) { + StringRef Assembly = "define i32 @f(i32 %a, i32 %b) { " + " %ash = mul i32 %a, 8 " + " %aad = add i32 %ash, 7 " + " %aan = and i32 %aad, 4095 " + " %bsh = shl i32 %b, 4 " + " %bad = or i32 %bsh, 6 " + " %ban = and i32 %bad, 4095 " + " %mul = mul i32 %aan, %ban " + " ret i32 %mul " + "} "; + + LLVMContext Context; + SMDiagnostic Error; 
+ auto M = parseAssemblyString(Assembly, Error, Context); + assert(M && "Bad assembly?"); + + auto *F = M->getFunction("f"); + assert(F && "Bad assembly?"); + + auto *RVal = + cast<ReturnInst>(F->getEntryBlock().getTerminator())->getOperand(0); + auto Known = computeKnownBits(RVal, M->getDataLayout()); + ASSERT_FALSE(Known.hasConflict()); + EXPECT_EQ(Known.One.getZExtValue(), 10u); + EXPECT_EQ(Known.Zero.getZExtValue(), 4278190085u); +} + +TEST(ValueTracking, ComputeKnownMulBits) { + StringRef Assembly = "define i32 @f(i32 %a, i32 %b) { " + " %aa = shl i32 %a, 5 " + " %bb = shl i32 %b, 5 " + " %aaa = or i32 %aa, 24 " + " %bbb = or i32 %bb, 28 " + " %mul = mul i32 %aaa, %bbb " + " ret i32 %mul " + "} "; + + LLVMContext Context; + SMDiagnostic Error; + auto M = parseAssemblyString(Assembly, Error, Context); + assert(M && "Bad assembly?"); + + auto *F = M->getFunction("f"); + assert(F && "Bad assembly?"); + + auto *RVal = + cast<ReturnInst>(F->getEntryBlock().getTerminator())->getOperand(0); + auto Known = computeKnownBits(RVal, M->getDataLayout()); + ASSERT_FALSE(Known.hasConflict()); + EXPECT_EQ(Known.One.getZExtValue(), 32u); + EXPECT_EQ(Known.Zero.getZExtValue(), 95u); +}