forked from OSchip/llvm-project
[libc] Add implementations for sqrt, sqrtf, and sqrtl.
Differential Revision: https://reviews.llvm.org/D84726
This commit is contained in:
parent
75d159f924
commit
5078825aa9
|
@ -75,6 +75,9 @@ set(TARGET_LIBM_ENTRYPOINTS
|
|||
libc.src.math.roundl
|
||||
libc.src.math.sincosf
|
||||
libc.src.math.sinf
|
||||
libc.src.math.sqrt
|
||||
libc.src.math.sqrtf
|
||||
libc.src.math.sqrtl
|
||||
libc.src.math.trunc
|
||||
libc.src.math.truncf
|
||||
libc.src.math.truncl
|
||||
|
|
|
@ -204,6 +204,9 @@ def MathAPI : PublicAPI<"math.h"> {
|
|||
"roundl",
|
||||
"sincosf",
|
||||
"sinf",
|
||||
"sqrt",
|
||||
"sqrtf",
|
||||
"sqrtl",
|
||||
"trunc",
|
||||
"truncf",
|
||||
"truncl",
|
||||
|
|
|
@ -108,6 +108,9 @@ set(TARGET_LIBM_ENTRYPOINTS
|
|||
libc.src.math.roundl
|
||||
libc.src.math.sincosf
|
||||
libc.src.math.sinf
|
||||
libc.src.math.sqrt
|
||||
libc.src.math.sqrtf
|
||||
libc.src.math.sqrtl
|
||||
libc.src.math.trunc
|
||||
libc.src.math.truncf
|
||||
libc.src.math.truncl
|
||||
|
|
|
@ -314,6 +314,10 @@ def StdC : StandardSpec<"stdc"> {
|
|||
FunctionSpec<"roundf", RetValSpec<FloatType>, [ArgSpec<FloatType>]>,
|
||||
FunctionSpec<"roundl", RetValSpec<LongDoubleType>, [ArgSpec<LongDoubleType>]>,
|
||||
|
||||
FunctionSpec<"sqrt", RetValSpec<DoubleType>, [ArgSpec<DoubleType>]>,
|
||||
FunctionSpec<"sqrtf", RetValSpec<FloatType>, [ArgSpec<FloatType>]>,
|
||||
FunctionSpec<"sqrtl", RetValSpec<LongDoubleType>, [ArgSpec<LongDoubleType>]>,
|
||||
|
||||
FunctionSpec<"trunc", RetValSpec<DoubleType>, [ArgSpec<DoubleType>]>,
|
||||
FunctionSpec<"truncf", RetValSpec<FloatType>, [ArgSpec<FloatType>]>,
|
||||
FunctionSpec<"truncl", RetValSpec<LongDoubleType>, [ArgSpec<LongDoubleType>]>,
|
||||
|
|
|
@ -485,3 +485,39 @@ add_entrypoint_object(
|
|||
COMPILE_OPTIONS
|
||||
-O2
|
||||
)
|
||||
|
||||
add_entrypoint_object(
|
||||
sqrt
|
||||
SRCS
|
||||
sqrt.cpp
|
||||
HDRS
|
||||
sqrt.h
|
||||
DEPENDS
|
||||
libc.utils.FPUtil.fputil
|
||||
COMPILE_OPTIONS
|
||||
-O2
|
||||
)
|
||||
|
||||
add_entrypoint_object(
|
||||
sqrtf
|
||||
SRCS
|
||||
sqrtf.cpp
|
||||
HDRS
|
||||
sqrtf.h
|
||||
DEPENDS
|
||||
libc.utils.FPUtil.fputil
|
||||
COMPILE_OPTIONS
|
||||
-O2
|
||||
)
|
||||
|
||||
add_entrypoint_object(
|
||||
sqrtl
|
||||
SRCS
|
||||
sqrtl.cpp
|
||||
HDRS
|
||||
sqrtl.h
|
||||
DEPENDS
|
||||
libc.utils.FPUtil.fputil
|
||||
COMPILE_OPTIONS
|
||||
-O2
|
||||
)
|
||||
|
|
|
@ -0,0 +1,16 @@
|
|||
//===-- Implementation of sqrt function -----------------------------------===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "utils/FPUtil/Sqrt.h"
|
||||
#include "src/__support/common.h"
|
||||
|
||||
namespace __llvm_libc {
|
||||
|
||||
double LLVM_LIBC_ENTRYPOINT(sqrt)(double x) { return fputil::sqrt(x); }
|
||||
|
||||
} // namespace __llvm_libc
|
|
@ -0,0 +1,18 @@
|
|||
//===-- Implementation header for sqrt --------------------------*- C++ -*-===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifndef LLVM_LIBC_SRC_MATH_SQRT_H
|
||||
#define LLVM_LIBC_SRC_MATH_SQRT_H
|
||||
|
||||
namespace __llvm_libc {
|
||||
|
||||
double sqrt(double x);
|
||||
|
||||
} // namespace __llvm_libc
|
||||
|
||||
#endif // LLVM_LIBC_SRC_MATH_SQRT_H
|
|
@ -0,0 +1,16 @@
|
|||
//===-- Implementation of sqrtf function ----------------------------------===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "src/__support/common.h"
|
||||
#include "utils/FPUtil/Sqrt.h"
|
||||
|
||||
namespace __llvm_libc {
|
||||
|
||||
float LLVM_LIBC_ENTRYPOINT(sqrtf)(float x) { return fputil::sqrt(x); }
|
||||
|
||||
} // namespace __llvm_libc
|
|
@ -0,0 +1,18 @@
|
|||
//===-- Implementation header for sqrtf -------------------------*- C++ -*-===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifndef LLVM_LIBC_SRC_MATH_SQRTF_H
|
||||
#define LLVM_LIBC_SRC_MATH_SQRTF_H
|
||||
|
||||
namespace __llvm_libc {
|
||||
|
||||
float sqrtf(float x);
|
||||
|
||||
} // namespace __llvm_libc
|
||||
|
||||
#endif // LLVM_LIBC_SRC_MATH_SQRTF_H
|
|
@ -0,0 +1,18 @@
|
|||
//===-- Implementation of sqrtl function ----------------------------------===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "src/__support/common.h"
|
||||
#include "utils/FPUtil/Sqrt.h"
|
||||
|
||||
namespace __llvm_libc {
|
||||
|
||||
long double LLVM_LIBC_ENTRYPOINT(sqrtl)(long double x) {
|
||||
return fputil::sqrt(x);
|
||||
}
|
||||
|
||||
} // namespace __llvm_libc
|
|
@ -0,0 +1,18 @@
|
|||
//===-- Implementation header for sqrtl -------------------------*- C++ -*-===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifndef LLVM_LIBC_SRC_MATH_SQRTL_H
|
||||
#define LLVM_LIBC_SRC_MATH_SQRTL_H
|
||||
|
||||
namespace __llvm_libc {
|
||||
|
||||
long double sqrtl(long double x);
|
||||
|
||||
} // namespace __llvm_libc
|
||||
|
||||
#endif // LLVM_LIBC_SRC_MATH_SQRTL_H
|
|
@ -513,3 +513,42 @@ add_fp_unittest(
|
|||
libc.src.math.fmaxl
|
||||
libc.utils.FPUtil.fputil
|
||||
)
|
||||
|
||||
add_fp_unittest(
|
||||
sqrtf_test
|
||||
NEED_MPFR
|
||||
SUITE
|
||||
libc_math_unittests
|
||||
SRCS
|
||||
sqrtf_test.cpp
|
||||
DEPENDS
|
||||
libc.include.math
|
||||
libc.src.math.sqrtf
|
||||
libc.utils.FPUtil.fputil
|
||||
)
|
||||
|
||||
add_fp_unittest(
|
||||
sqrt_test
|
||||
NEED_MPFR
|
||||
SUITE
|
||||
libc_math_unittests
|
||||
SRCS
|
||||
sqrt_test.cpp
|
||||
DEPENDS
|
||||
libc.include.math
|
||||
libc.src.math.sqrt
|
||||
libc.utils.FPUtil.fputil
|
||||
)
|
||||
|
||||
add_fp_unittest(
|
||||
sqrtl_test
|
||||
NEED_MPFR
|
||||
SUITE
|
||||
libc_math_unittests
|
||||
SRCS
|
||||
sqrtl_test.cpp
|
||||
DEPENDS
|
||||
libc.include.math
|
||||
libc.src.math.sqrtl
|
||||
libc.utils.FPUtil.fputil
|
||||
)
|
||||
|
|
|
@ -0,0 +1,67 @@
|
|||
//===-- Unittests for sqrt -----------------------------------------------===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===---------------------------------------------------------------------===//
|
||||
|
||||
#include "include/math.h"
|
||||
#include "src/math/sqrt.h"
|
||||
#include "utils/FPUtil/FPBits.h"
|
||||
#include "utils/FPUtil/TestHelpers.h"
|
||||
#include "utils/MPFRWrapper/MPFRUtils.h"
|
||||
|
||||
using FPBits = __llvm_libc::fputil::FPBits<double>;
|
||||
using UIntType = typename FPBits::UIntType;
|
||||
|
||||
namespace mpfr = __llvm_libc::testing::mpfr;
|
||||
|
||||
constexpr UIntType HiddenBit =
|
||||
UIntType(1) << __llvm_libc::fputil::MantissaWidth<double>::value;
|
||||
|
||||
double nan = FPBits::buildNaN(1);
|
||||
double inf = FPBits::inf();
|
||||
double negInf = FPBits::negInf();
|
||||
|
||||
TEST(SqrtTest, SpecialValues) {
|
||||
ASSERT_FP_EQ(nan, __llvm_libc::sqrt(nan));
|
||||
ASSERT_FP_EQ(inf, __llvm_libc::sqrt(inf));
|
||||
ASSERT_FP_EQ(nan, __llvm_libc::sqrt(negInf));
|
||||
ASSERT_FP_EQ(0.0, __llvm_libc::sqrt(0.0));
|
||||
ASSERT_FP_EQ(-0.0, __llvm_libc::sqrt(-0.0));
|
||||
ASSERT_FP_EQ(nan, __llvm_libc::sqrt(-1.0));
|
||||
ASSERT_FP_EQ(1.0, __llvm_libc::sqrt(1.0));
|
||||
ASSERT_FP_EQ(2.0, __llvm_libc::sqrt(4.0));
|
||||
ASSERT_FP_EQ(3.0, __llvm_libc::sqrt(9.0));
|
||||
}
|
||||
|
||||
TEST(SqrtTest, DenormalValues) {
|
||||
for (UIntType mant = 1; mant < HiddenBit; mant <<= 1) {
|
||||
FPBits denormal(0.0);
|
||||
denormal.mantissa = mant;
|
||||
|
||||
ASSERT_MPFR_MATCH(mpfr::Operation::Sqrt, double(denormal),
|
||||
__llvm_libc::sqrt(denormal), 0.5);
|
||||
}
|
||||
|
||||
constexpr UIntType count = 1'000'001;
|
||||
constexpr UIntType step = HiddenBit / count;
|
||||
for (UIntType i = 0, v = 0; i <= count; ++i, v += step) {
|
||||
double x = *reinterpret_cast<double *>(&v);
|
||||
ASSERT_MPFR_MATCH(mpfr::Operation::Sqrt, x, __llvm_libc::sqrt(x), 0.5);
|
||||
}
|
||||
}
|
||||
|
||||
TEST(SqrtTest, InDoubleRange) {
|
||||
constexpr UIntType count = 10'000'001;
|
||||
constexpr UIntType step = UIntType(-1) / count;
|
||||
for (UIntType i = 0, v = 0; i <= count; ++i, v += step) {
|
||||
double x = *reinterpret_cast<double *>(&v);
|
||||
if (isnan(x) || (x < 0)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
ASSERT_MPFR_MATCH(mpfr::Operation::Sqrt, x, __llvm_libc::sqrt(x), 0.5);
|
||||
}
|
||||
}
|
|
@ -0,0 +1,67 @@
|
|||
//===-- Unittests for sqrtf -----------------------------------------------===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===---------------------------------------------------------------------===//
|
||||
|
||||
#include "include/math.h"
|
||||
#include "src/math/sqrtf.h"
|
||||
#include "utils/FPUtil/FPBits.h"
|
||||
#include "utils/FPUtil/TestHelpers.h"
|
||||
#include "utils/MPFRWrapper/MPFRUtils.h"
|
||||
|
||||
using FPBits = __llvm_libc::fputil::FPBits<float>;
|
||||
using UIntType = typename FPBits::UIntType;
|
||||
|
||||
namespace mpfr = __llvm_libc::testing::mpfr;
|
||||
|
||||
constexpr UIntType HiddenBit =
|
||||
UIntType(1) << __llvm_libc::fputil::MantissaWidth<float>::value;
|
||||
|
||||
float nan = FPBits::buildNaN(1);
|
||||
float inf = FPBits::inf();
|
||||
float negInf = FPBits::negInf();
|
||||
|
||||
TEST(SqrtfTest, SpecialValues) {
|
||||
ASSERT_FP_EQ(nan, __llvm_libc::sqrtf(nan));
|
||||
ASSERT_FP_EQ(inf, __llvm_libc::sqrtf(inf));
|
||||
ASSERT_FP_EQ(nan, __llvm_libc::sqrtf(negInf));
|
||||
ASSERT_FP_EQ(0.0f, __llvm_libc::sqrtf(0.0f));
|
||||
ASSERT_FP_EQ(-0.0f, __llvm_libc::sqrtf(-0.0f));
|
||||
ASSERT_FP_EQ(nan, __llvm_libc::sqrtf(-1.0f));
|
||||
ASSERT_FP_EQ(1.0f, __llvm_libc::sqrtf(1.0f));
|
||||
ASSERT_FP_EQ(2.0f, __llvm_libc::sqrtf(4.0f));
|
||||
ASSERT_FP_EQ(3.0f, __llvm_libc::sqrtf(9.0f));
|
||||
}
|
||||
|
||||
TEST(SqrtfTest, DenormalValues) {
|
||||
for (UIntType mant = 1; mant < HiddenBit; mant <<= 1) {
|
||||
FPBits denormal(0.0f);
|
||||
denormal.mantissa = mant;
|
||||
|
||||
ASSERT_MPFR_MATCH(mpfr::Operation::Sqrt, float(denormal),
|
||||
__llvm_libc::sqrtf(denormal), 0.5);
|
||||
}
|
||||
|
||||
constexpr UIntType count = 1'000'001;
|
||||
constexpr UIntType step = HiddenBit / count;
|
||||
for (UIntType i = 0, v = 0; i <= count; ++i, v += step) {
|
||||
float x = *reinterpret_cast<float *>(&v);
|
||||
ASSERT_MPFR_MATCH(mpfr::Operation::Sqrt, x, __llvm_libc::sqrtf(x), 0.5);
|
||||
}
|
||||
}
|
||||
|
||||
TEST(SqrtfTest, InFloatRange) {
|
||||
constexpr UIntType count = 10'000'001;
|
||||
constexpr UIntType step = UIntType(-1) / count;
|
||||
for (UIntType i = 0, v = 0; i <= count; ++i, v += step) {
|
||||
float x = *reinterpret_cast<float *>(&v);
|
||||
if (isnan(x) || (x < 0)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
ASSERT_MPFR_MATCH(mpfr::Operation::Sqrt, x, __llvm_libc::sqrtf(x), 0.5);
|
||||
}
|
||||
}
|
|
@ -0,0 +1,67 @@
|
|||
//===-- Unittests for sqrtl ----------------------------------------------===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===---------------------------------------------------------------------===//
|
||||
|
||||
#include "include/math.h"
|
||||
#include "src/math/sqrtl.h"
|
||||
#include "utils/FPUtil/FPBits.h"
|
||||
#include "utils/FPUtil/TestHelpers.h"
|
||||
#include "utils/MPFRWrapper/MPFRUtils.h"
|
||||
|
||||
using FPBits = __llvm_libc::fputil::FPBits<long double>;
|
||||
using UIntType = typename FPBits::UIntType;
|
||||
|
||||
namespace mpfr = __llvm_libc::testing::mpfr;
|
||||
|
||||
constexpr UIntType HiddenBit =
|
||||
UIntType(1) << __llvm_libc::fputil::MantissaWidth<long double>::value;
|
||||
|
||||
long double nan = FPBits::buildNaN(1);
|
||||
long double inf = FPBits::inf();
|
||||
long double negInf = FPBits::negInf();
|
||||
|
||||
TEST(SqrtlTest, SpecialValues) {
|
||||
ASSERT_FP_EQ(nan, __llvm_libc::sqrtl(nan));
|
||||
ASSERT_FP_EQ(inf, __llvm_libc::sqrtl(inf));
|
||||
ASSERT_FP_EQ(nan, __llvm_libc::sqrtl(negInf));
|
||||
ASSERT_FP_EQ(0.0L, __llvm_libc::sqrtl(0.0L));
|
||||
ASSERT_FP_EQ(-0.0L, __llvm_libc::sqrtl(-0.0L));
|
||||
ASSERT_FP_EQ(nan, __llvm_libc::sqrtl(-1.0L));
|
||||
ASSERT_FP_EQ(1.0L, __llvm_libc::sqrtl(1.0L));
|
||||
ASSERT_FP_EQ(2.0L, __llvm_libc::sqrtl(4.0L));
|
||||
ASSERT_FP_EQ(3.0L, __llvm_libc::sqrtl(9.0L));
|
||||
}
|
||||
|
||||
TEST(SqrtlTest, DenormalValues) {
|
||||
for (UIntType mant = 1; mant < HiddenBit; mant <<= 1) {
|
||||
FPBits denormal(0.0L);
|
||||
denormal.mantissa = mant;
|
||||
|
||||
ASSERT_MPFR_MATCH(mpfr::Operation::Sqrt, static_cast<long double>(denormal),
|
||||
__llvm_libc::sqrtl(denormal), 0.5);
|
||||
}
|
||||
|
||||
constexpr UIntType count = 1'000'001;
|
||||
constexpr UIntType step = HiddenBit / count;
|
||||
for (UIntType i = 0, v = 0; i <= count; ++i, v += step) {
|
||||
long double x = *reinterpret_cast<long double *>(&v);
|
||||
ASSERT_MPFR_MATCH(mpfr::Operation::Sqrt, x, __llvm_libc::sqrtl(x), 0.5);
|
||||
}
|
||||
}
|
||||
|
||||
TEST(SqrtlTest, InLongDoubleRange) {
|
||||
constexpr UIntType count = 10'000'001;
|
||||
constexpr UIntType step = UIntType(-1) / count;
|
||||
for (UIntType i = 0, v = 0; i <= count; ++i, v += step) {
|
||||
long double x = *reinterpret_cast<long double *>(&v);
|
||||
if (isnan(x) || (x < 0)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
ASSERT_MPFR_MATCH(mpfr::Operation::Sqrt, x, __llvm_libc::sqrtl(x), 0.5);
|
||||
}
|
||||
}
|
|
@ -0,0 +1,186 @@
|
|||
//===-- Square root of IEEE 754 floating point numbers ----------*- C++ -*-===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifndef LLVM_LIBC_UTILS_FPUTIL_SQRT_H
|
||||
#define LLVM_LIBC_UTILS_FPUTIL_SQRT_H
|
||||
|
||||
#include "FPBits.h"
|
||||
|
||||
#include "utils/CPP/TypeTraits.h"
|
||||
|
||||
namespace __llvm_libc {
|
||||
namespace fputil {
|
||||
|
||||
namespace internal {
|
||||
|
||||
template <typename T>
|
||||
static inline void normalize(int &exponent,
|
||||
typename FPBits<T>::UIntType &mantissa);
|
||||
|
||||
template <> inline void normalize<float>(int &exponent, uint32_t &mantissa) {
|
||||
// Use binary search to shift the leading 1 bit.
|
||||
// With MantissaWidth<float> = 23, it will take
|
||||
// ceil(log2(23)) = 5 steps checking the mantissa bits as followed:
|
||||
// Step 1: 0000 0000 0000 XXXX XXXX XXXX
|
||||
// Step 2: 0000 00XX XXXX XXXX XXXX XXXX
|
||||
// Step 3: 000X XXXX XXXX XXXX XXXX XXXX
|
||||
// Step 4: 00XX XXXX XXXX XXXX XXXX XXXX
|
||||
// Step 5: 0XXX XXXX XXXX XXXX XXXX XXXX
|
||||
constexpr int nsteps = 5; // = ceil(log2(MantissaWidth))
|
||||
constexpr uint32_t bounds[nsteps] = {1 << 12, 1 << 18, 1 << 21, 1 << 22,
|
||||
1 << 23};
|
||||
constexpr int shifts[nsteps] = {12, 6, 3, 2, 1};
|
||||
|
||||
for (int i = 0; i < nsteps; ++i) {
|
||||
if (mantissa < bounds[i]) {
|
||||
exponent -= shifts[i];
|
||||
mantissa <<= shifts[i];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <> inline void normalize<double>(int &exponent, uint64_t &mantissa) {
|
||||
// Use binary search to shift the leading 1 bit similar to float.
|
||||
// With MantissaWidth<double> = 52, it will take
|
||||
// ceil(log2(52)) = 6 steps checking the mantissa bits.
|
||||
constexpr int nsteps = 6; // = ceil(log2(MantissaWidth))
|
||||
constexpr uint64_t bounds[nsteps] = {1ULL << 26, 1ULL << 39, 1ULL << 46,
|
||||
1ULL << 49, 1ULL << 51, 1ULL << 52};
|
||||
constexpr int shifts[nsteps] = {27, 14, 7, 4, 2, 1};
|
||||
|
||||
for (int i = 0; i < nsteps; ++i) {
|
||||
if (mantissa < bounds[i]) {
|
||||
exponent -= shifts[i];
|
||||
mantissa <<= shifts[i];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#if !(defined(__x86_64__) || defined(__i386__))
|
||||
template <>
|
||||
inline void normalize<long double>(int &exponent, __uint128_t &mantissa) {
|
||||
// Use binary search to shift the leading 1 bit similar to float.
|
||||
// With MantissaWidth<long double> = 112, it will take
|
||||
// ceil(log2(112)) = 7 steps checking the mantissa bits.
|
||||
constexpr int nsteps = 7; // = ceil(log2(MantissaWidth))
|
||||
constexpr __uint128_t bounds[nsteps] = {
|
||||
__uint128_t(1) << 56, __uint128_t(1) << 84, __uint128_t(1) << 98,
|
||||
__uint128_t(1) << 105, __uint128_t(1) << 109, __uint128_t(1) << 111,
|
||||
__uint128_t(1) << 112};
|
||||
constexpr int shifts[nsteps] = {57, 29, 15, 8, 4, 2, 1};
|
||||
|
||||
for (int i = 0; i < nsteps; ++i) {
|
||||
if (mantissa < bounds[i]) {
|
||||
exponent -= shifts[i];
|
||||
mantissa <<= shifts[i];
|
||||
}
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
} // namespace internal
|
||||
|
||||
// Correctly rounded IEEE 754 SQRT with round to nearest, ties to even.
|
||||
// Shift-and-add algorithm.
|
||||
template <typename T,
|
||||
cpp::EnableIfType<cpp::IsFloatingPointType<T>::Value, int> = 0>
|
||||
static inline T sqrt(T x) {
|
||||
using UIntType = typename FPBits<T>::UIntType;
|
||||
constexpr UIntType One = UIntType(1) << MantissaWidth<T>::value;
|
||||
|
||||
FPBits<T> bits(x);
|
||||
|
||||
if (bits.isInfOrNaN()) {
|
||||
if (bits.sign && (bits.mantissa == 0)) {
|
||||
// sqrt(-Inf) = NaN
|
||||
return FPBits<T>::buildNaN(One >> 1);
|
||||
} else {
|
||||
// sqrt(NaN) = NaN
|
||||
// sqrt(+Inf) = +Inf
|
||||
return x;
|
||||
}
|
||||
} else if (bits.isZero()) {
|
||||
// sqrt(+0) = +0
|
||||
// sqrt(-0) = -0
|
||||
return x;
|
||||
} else if (bits.sign) {
|
||||
// sqrt( negative numbers ) = NaN
|
||||
return FPBits<T>::buildNaN(One >> 1);
|
||||
} else {
|
||||
int xExp = bits.getExponent();
|
||||
UIntType xMant = bits.mantissa;
|
||||
|
||||
// Step 1a: Normalize denormal input and append hiddent bit to the mantissa
|
||||
if (bits.exponent == 0) {
|
||||
++xExp; // let xExp be the correct exponent of One bit.
|
||||
internal::normalize<T>(xExp, xMant);
|
||||
} else {
|
||||
xMant |= One;
|
||||
}
|
||||
|
||||
// Step 1b: Make sure the exponent is even.
|
||||
if (xExp & 1) {
|
||||
--xExp;
|
||||
xMant <<= 1;
|
||||
}
|
||||
|
||||
// After step 1b, x = 2^(xExp) * xMant, where xExp is even, and
|
||||
// 1 <= xMant < 4. So sqrt(x) = 2^(xExp / 2) * y, with 1 <= y < 2.
|
||||
// Notice that the output of sqrt is always in the normal range.
|
||||
// To perform shift-and-add algorithm to find y, let denote:
|
||||
// y(n) = 1.y_1 y_2 ... y_n, we can define the nth residue to be:
|
||||
// r(n) = 2^n ( xMant - y(n)^2 ).
|
||||
// That leads to the following recurrence formula:
|
||||
// r(n) = 2*r(n-1) - y_n*[ 2*y(n-1) + 2^(-n-1) ]
|
||||
// with the initial conditions: y(0) = 1, and r(0) = x - 1.
|
||||
// So the nth digit y_n of the mantissa of sqrt(x) can be found by:
|
||||
// y_n = 1 if 2*r(n-1) >= 2*y(n - 1) + 2^(-n-1)
|
||||
// 0 otherwise.
|
||||
UIntType y = One;
|
||||
UIntType r = xMant - One;
|
||||
|
||||
for (UIntType current_bit = One >> 1; current_bit; current_bit >>= 1) {
|
||||
r <<= 1;
|
||||
UIntType tmp = (y << 1) + current_bit; // 2*y(n - 1) + 2^(-n-1)
|
||||
if (r >= tmp) {
|
||||
r -= tmp;
|
||||
y += current_bit;
|
||||
}
|
||||
}
|
||||
|
||||
// We compute one more iteration in order to round correctly.
|
||||
bool lsb = y & 1; // Least significant bit
|
||||
bool rb = false; // Round bit
|
||||
r <<= 2;
|
||||
UIntType tmp = (y << 2) + 1;
|
||||
if (r >= tmp) {
|
||||
r -= tmp;
|
||||
rb = true;
|
||||
}
|
||||
|
||||
// Remove hidden bit and append the exponent field.
|
||||
xExp = ((xExp >> 1) + FPBits<T>::exponentBias);
|
||||
|
||||
y = (y - One) | (static_cast<UIntType>(xExp) << MantissaWidth<T>::value);
|
||||
// Round to nearest, ties to even
|
||||
if (rb && (lsb || (r != 0))) {
|
||||
++y;
|
||||
}
|
||||
|
||||
return *reinterpret_cast<T *>(&y);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace fputil
|
||||
} // namespace __llvm_libc
|
||||
|
||||
#if (defined(__x86_64__) || defined(__i386__))
|
||||
#include "SqrtLongDoubleX86.h"
|
||||
#endif // defined(__x86_64__) || defined(__i386__)
|
||||
|
||||
#endif // LLVM_LIBC_UTILS_FPUTIL_SQRT_H
|
|
@ -0,0 +1,142 @@
|
|||
//===-- Square root of x86 long double numbers ------------------*- C++ -*-===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifndef LLVM_LIBC_UTILS_FPUTIL_SQRT_LONG_DOUBLE_X86_H
|
||||
#define LLVM_LIBC_UTILS_FPUTIL_SQRT_LONG_DOUBLE_X86_H
|
||||
|
||||
#include "FPBits.h"
|
||||
#include "utils/CPP/TypeTraits.h"
|
||||
|
||||
namespace __llvm_libc {
|
||||
namespace fputil {
|
||||
|
||||
#if (defined(__x86_64__) || defined(__i386__))
|
||||
namespace internal {
|
||||
|
||||
template <>
|
||||
inline void normalize<long double>(int &exponent, __uint128_t &mantissa) {
|
||||
// Use binary search to shift the leading 1 bit similar to float.
|
||||
// With MantissaWidth<long double> = 63, it will take
|
||||
// ceil(log2(63)) = 6 steps checking the mantissa bits.
|
||||
constexpr int nsteps = 6; // = ceil(log2(MantissaWidth))
|
||||
constexpr __uint128_t bounds[nsteps] = {
|
||||
__uint128_t(1) << 32, __uint128_t(1) << 48, __uint128_t(1) << 56,
|
||||
__uint128_t(1) << 60, __uint128_t(1) << 62, __uint128_t(1) << 63};
|
||||
constexpr int shifts[nsteps] = {32, 16, 8, 4, 2, 1};
|
||||
|
||||
for (int i = 0; i < nsteps; ++i) {
|
||||
if (mantissa < bounds[i]) {
|
||||
exponent -= shifts[i];
|
||||
mantissa <<= shifts[i];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace internal
|
||||
|
||||
// Correctly rounded SQRT with round to nearest, ties to even.
|
||||
// Shift-and-add algorithm.
|
||||
template <> inline long double sqrt<long double, 0>(long double x) {
|
||||
using UIntType = typename FPBits<long double>::UIntType;
|
||||
constexpr UIntType One = UIntType(1)
|
||||
<< int(MantissaWidth<long double>::value);
|
||||
|
||||
FPBits<long double> bits(x);
|
||||
|
||||
if (bits.isInfOrNaN()) {
|
||||
if (bits.sign && (bits.mantissa == 0)) {
|
||||
// sqrt(-Inf) = NaN
|
||||
return FPBits<long double>::buildNaN(One >> 1);
|
||||
} else {
|
||||
// sqrt(NaN) = NaN
|
||||
// sqrt(+Inf) = +Inf
|
||||
return x;
|
||||
}
|
||||
} else if (bits.isZero()) {
|
||||
// sqrt(+0) = +0
|
||||
// sqrt(-0) = -0
|
||||
return x;
|
||||
} else if (bits.sign) {
|
||||
// sqrt( negative numbers ) = NaN
|
||||
return FPBits<long double>::buildNaN(One >> 1);
|
||||
} else {
|
||||
int xExp = bits.getExponent();
|
||||
UIntType xMant = bits.mantissa;
|
||||
|
||||
// Step 1a: Normalize denormal input
|
||||
if (bits.implicitBit) {
|
||||
xMant |= One;
|
||||
} else if (bits.exponent == 0) {
|
||||
internal::normalize<long double>(xExp, xMant);
|
||||
}
|
||||
|
||||
// Step 1b: Make sure the exponent is even.
|
||||
if (xExp & 1) {
|
||||
--xExp;
|
||||
xMant <<= 1;
|
||||
}
|
||||
|
||||
// After step 1b, x = 2^(xExp) * xMant, where xExp is even, and
|
||||
// 1 <= xMant < 4. So sqrt(x) = 2^(xExp / 2) * y, with 1 <= y < 2.
|
||||
// Notice that the output of sqrt is always in the normal range.
|
||||
// To perform shift-and-add algorithm to find y, let denote:
|
||||
// y(n) = 1.y_1 y_2 ... y_n, we can define the nth residue to be:
|
||||
// r(n) = 2^n ( xMant - y(n)^2 ).
|
||||
// That leads to the following recurrence formula:
|
||||
// r(n) = 2*r(n-1) - y_n*[ 2*y(n-1) + 2^(-n-1) ]
|
||||
// with the initial conditions: y(0) = 1, and r(0) = x - 1.
|
||||
// So the nth digit y_n of the mantissa of sqrt(x) can be found by:
|
||||
// y_n = 1 if 2*r(n-1) >= 2*y(n - 1) + 2^(-n-1)
|
||||
// 0 otherwise.
|
||||
UIntType y = One;
|
||||
UIntType r = xMant - One;
|
||||
|
||||
for (UIntType current_bit = One >> 1; current_bit; current_bit >>= 1) {
|
||||
r <<= 1;
|
||||
UIntType tmp = (y << 1) + current_bit; // 2*y(n - 1) + 2^(-n-1)
|
||||
if (r >= tmp) {
|
||||
r -= tmp;
|
||||
y += current_bit;
|
||||
}
|
||||
}
|
||||
|
||||
// We compute one more iteration in order to round correctly.
|
||||
bool lsb = y & 1; // Least significant bit
|
||||
bool rb = false; // Round bit
|
||||
r <<= 2;
|
||||
UIntType tmp = (y << 2) + 1;
|
||||
if (r >= tmp) {
|
||||
r -= tmp;
|
||||
rb = true;
|
||||
}
|
||||
|
||||
// Append the exponent field.
|
||||
xExp = ((xExp >> 1) + FPBits<long double>::exponentBias);
|
||||
y |= (static_cast<UIntType>(xExp)
|
||||
<< (MantissaWidth<long double>::value + 1));
|
||||
|
||||
// Round to nearest, ties to even
|
||||
if (rb && (lsb || (r != 0))) {
|
||||
++y;
|
||||
}
|
||||
|
||||
// Extract output
|
||||
FPBits<long double> out(0.0L);
|
||||
out.exponent = xExp;
|
||||
out.implicitBit = 1;
|
||||
out.mantissa = (y & (One - 1));
|
||||
|
||||
return out;
|
||||
}
|
||||
}
|
||||
#endif // defined(__x86_64__) || defined(__i386__)
|
||||
|
||||
} // namespace fputil
|
||||
} // namespace __llvm_libc
|
||||
|
||||
#endif // LLVM_LIBC_UTILS_FPUTIL_SQRT_LONG_DOUBLE_X86_H
|
Loading…
Reference in New Issue