diff --git a/llvm/include/llvm/Support/KnownBits.h b/llvm/include/llvm/Support/KnownBits.h index ed32cd2b576f..07fd94e29a1f 100644 --- a/llvm/include/llvm/Support/KnownBits.h +++ b/llvm/include/llvm/Support/KnownBits.h @@ -202,6 +202,10 @@ public: return getBitWidth() - Zero.countPopulation(); } + /// Compute known bits resulting from adding LHS, RHS and a 1-bit Carry. + static KnownBits computeForAddCarry( + const KnownBits &LHS, const KnownBits &RHS, const KnownBits &Carry); + /// Compute known bits resulting from adding LHS and RHS. static KnownBits computeForAddSub(bool Add, bool NSW, const KnownBits &LHS, KnownBits RHS); diff --git a/llvm/lib/Support/KnownBits.cpp b/llvm/lib/Support/KnownBits.cpp index 9988314fabb9..a6c591fca312 100644 --- a/llvm/lib/Support/KnownBits.cpp +++ b/llvm/lib/Support/KnownBits.cpp @@ -15,18 +15,14 @@ using namespace llvm; -KnownBits KnownBits::computeForAddSub(bool Add, bool NSW, - const KnownBits &LHS, KnownBits RHS) { - // Carry in a 1 for a subtract, rather than 0. - bool CarryIn = false; - if (!Add) { - // Sum = LHS + ~RHS + 1 - std::swap(RHS.Zero, RHS.One); - CarryIn = true; - } +static KnownBits computeForAddCarry( + const KnownBits &LHS, const KnownBits &RHS, + bool CarryZero, bool CarryOne) { + assert(!(CarryZero && CarryOne) && + "Carry can't be zero and one at the same time"); - APInt PossibleSumZero = ~LHS.Zero + ~RHS.Zero + CarryIn; - APInt PossibleSumOne = LHS.One + RHS.One + CarryIn; + APInt PossibleSumZero = ~LHS.Zero + ~RHS.Zero + !CarryZero; + APInt PossibleSumOne = LHS.One + RHS.One + CarryOne; // Compute known bits of the carry. APInt CarryKnownZero = ~(PossibleSumZero ^ LHS.Zero ^ RHS.Zero); @@ -45,9 +41,32 @@ KnownBits KnownBits::computeForAddSub(bool Add, bool NSW, KnownBits KnownOut; KnownOut.Zero = ~std::move(PossibleSumZero) & Known; KnownOut.One = std::move(PossibleSumOne) & Known; + return KnownOut; +} + +KnownBits KnownBits::computeForAddCarry( + const KnownBits &LHS, const KnownBits &RHS, const KnownBits &Carry) { + assert(Carry.getBitWidth() == 1 && "Carry must be 1-bit"); + return ::computeForAddCarry( + LHS, RHS, Carry.Zero.getBoolValue(), Carry.One.getBoolValue()); +} + +KnownBits KnownBits::computeForAddSub(bool Add, bool NSW, + const KnownBits &LHS, KnownBits RHS) { + KnownBits KnownOut; + if (Add) { + // Sum = LHS + RHS + 0 + KnownOut = ::computeForAddCarry( + LHS, RHS, /*CarryZero*/true, /*CarryOne*/false); + } else { + // Sum = LHS + ~RHS + 1 + std::swap(RHS.Zero, RHS.One); + KnownOut = ::computeForAddCarry( + LHS, RHS, /*CarryZero*/false, /*CarryOne*/true); + } // Are we still trying to solve for the sign bit? - if (!Known.isSignBitSet()) { + if (!KnownOut.isNegative() && !KnownOut.isNonNegative()) { if (NSW) { // Adding two non-negative numbers, or subtracting a negative number from // a non-negative one, can't wrap into negative. diff --git a/llvm/unittests/Support/CMakeLists.txt b/llvm/unittests/Support/CMakeLists.txt index 12c983df50cc..b1cefddf9dee 100644 --- a/llvm/unittests/Support/CMakeLists.txt +++ b/llvm/unittests/Support/CMakeLists.txt @@ -34,6 +34,7 @@ add_llvm_unittest(SupportTests Host.cpp ItaniumManglingCanonicalizerTest.cpp JSONTest.cpp + KnownBitsTest.cpp LEB128Test.cpp LineIteratorTest.cpp LockFileManagerTest.cpp diff --git a/llvm/unittests/Support/KnownBitsTest.cpp b/llvm/unittests/Support/KnownBitsTest.cpp new file mode 100644 index 000000000000..c2b3b127cf17 --- /dev/null +++ b/llvm/unittests/Support/KnownBitsTest.cpp @@ -0,0 +1,130 @@ +//===- llvm/unittest/Support/KnownBitsTest.cpp - KnownBits tests ----------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements unit tests for KnownBits functions. +// +//===----------------------------------------------------------------------===// + +#include "llvm/Support/KnownBits.h" +#include "gtest/gtest.h" + +using namespace llvm; + +namespace { + +template +void ForeachKnownBits(unsigned Bits, FnTy Fn) { + unsigned Max = 1 << Bits; + KnownBits Known(Bits); + for (unsigned Zero = 0; Zero < Max; ++Zero) { + for (unsigned One = 0; One < Max; ++One) { + Known.Zero = Zero; + Known.One = One; + if (Known.hasConflict()) + continue; + + Fn(Known); + } + } +} + +template +void ForeachNumInKnownBits(const KnownBits &Known, FnTy Fn) { + unsigned Bits = Known.getBitWidth(); + unsigned Max = 1 << Bits; + for (unsigned N = 0; N < Max; ++N) { + APInt Num(Bits, N); + if ((Num & Known.Zero) != 0 || (~Num & Known.One) != 0) + continue; + + Fn(Num); + } +} + +TEST(KnownBitsTest, AddCarryExhaustive) { + unsigned Bits = 4; + ForeachKnownBits(Bits, [&](const KnownBits &Known1) { + ForeachKnownBits(Bits, [&](const KnownBits &Known2) { + ForeachKnownBits(1, [&](const KnownBits &KnownCarry) { + // Explicitly compute known bits of the addition by trying all + // possibilities. + KnownBits Known(Bits); + Known.Zero.setAllBits(); + Known.One.setAllBits(); + ForeachNumInKnownBits(Known1, [&](const APInt &N1) { + ForeachNumInKnownBits(Known2, [&](const APInt &N2) { + ForeachNumInKnownBits(KnownCarry, [&](const APInt &Carry) { + APInt Add = N1 + N2; + if (Carry.getBoolValue()) + ++Add; + + Known.One &= Add; + Known.Zero &= ~Add; + }); + }); + }); + + KnownBits KnownComputed = KnownBits::computeForAddCarry( + Known1, Known2, KnownCarry); + EXPECT_EQ(Known.Zero, KnownComputed.Zero); + EXPECT_EQ(Known.One, KnownComputed.One); + }); + }); + }); +} + +static void TestAddSubExhaustive(bool IsAdd) { + unsigned Bits = 4; + ForeachKnownBits(Bits, [&](const KnownBits &Known1) { + ForeachKnownBits(Bits, [&](const KnownBits &Known2) { + KnownBits Known(Bits), KnownNSW(Bits); + Known.Zero.setAllBits(); + Known.One.setAllBits(); + KnownNSW.Zero.setAllBits(); + KnownNSW.One.setAllBits(); + + ForeachNumInKnownBits(Known1, [&](const APInt &N1) { + ForeachNumInKnownBits(Known2, [&](const APInt &N2) { + bool Overflow; + APInt Res; + if (IsAdd) + Res = N1.sadd_ov(N2, Overflow); + else + Res = N1.ssub_ov(N2, Overflow); + + Known.One &= Res; + Known.Zero &= ~Res; + + if (!Overflow) { + KnownNSW.One &= Res; + KnownNSW.Zero &= ~Res; + } + }); + }); + + KnownBits KnownComputed = KnownBits::computeForAddSub( + IsAdd, /*NSW*/false, Known1, Known2); + EXPECT_EQ(Known.Zero, KnownComputed.Zero); + EXPECT_EQ(Known.One, KnownComputed.One); + + // The NSW calculation is not precise, only check that it's + // conservatively correct. + KnownBits KnownNSWComputed = KnownBits::computeForAddSub( + IsAdd, /*NSW*/true, Known1, Known2); + EXPECT_TRUE(KnownNSWComputed.Zero.isSubsetOf(KnownNSW.Zero)); + EXPECT_TRUE(KnownNSWComputed.One.isSubsetOf(KnownNSW.One)); + }); + }); +} + +TEST(KnownBitsTest, AddSubExhaustive) { + TestAddSubExhaustive(true); + TestAddSubExhaustive(false); +} + +} // end anonymous namespace