From 41fc7fb0363f35c606ce0f5999fbad57df8a7ba5 Mon Sep 17 00:00:00 2001 From: zong-shuai Date: Tue, 10 Jan 2023 20:32:50 +0800 Subject: [PATCH] debug --- .../kernel/cuda_impl/cuda_ops/zeta_impl.cu | 86 ++----------------- 1 file changed, 6 insertions(+), 80 deletions(-) diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/zeta_impl.cu b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/zeta_impl.cu index e343753fba0..421e9dccdab 100644 --- a/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/zeta_impl.cu +++ b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/zeta_impl.cu @@ -19,90 +19,16 @@ #include #include "include/cuda_runtime.h" #include "include/cuda_fp16.h" +#include "unsupported/Eigen/CXX11/Tensor" +template +__device__ __forceinline__ T zeta(T x, T q) { + return Eigen::internal::scalar_zeta_op()(x, q); +} template __global__ void ZetaKernel(const size_t size, const T *x, const T *dimension, T *output) { for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < size; pos += blockDim.x * gridDim.x) { - double p = static_cast(x[pos]); - double q = static_cast(dimension[pos]); - const double MACHEP = static_cast(1.11022302462515654042E-16); - constexpr double zero = static_cast(0.0); - constexpr double half = static_cast(0.5); - constexpr double one = static_cast(1.0); - static const double A[] = { - 12.0, - -720.0, - 30240.0, - -1209600.0, - 47900160.0, - -1.8924375803183791606e9, /*1.307674368e12/691*/ - 7.47242496e10, - -2.950130727918164224e12, /*1.067062284288e16/3617*/ - 1.1646782814350067249e14, /*5.109094217170944e18/43867*/ - -4.5979787224074726105e15, /*8.028576626982912e20/174611*/ - 1.8152105401943546773e17, /*1.5511210043330985984e23/854513*/ - -7.1661652561756670113e18 /*1.6938241367317436694528e27/236364091*/ - }; - int i = 0; - double a, b, k, s, t, w; - bool flag = false; - if (p == one) { - output[pos] = std::numeric_limits::infinity(); - continue; - } - if (p < one) { - output[pos] = std::numeric_limits::quiet_NaN(); - continue; - } - if (q <= zero) { - if (q == std::floor(q)) { - output[pos] = std::numeric_limits::infinity(); - continue; - } - if (p != std::floor(p)) { - output[pos] = std::numeric_limits::quiet_NaN(); - continue; - } - } - s = pow(q, -p); - a = q; - i = 0; - b = zero; - while ((i < 9) || (a <= T(9.0))) { - i += 1; - a += one; - b = pow(a, -p); - s += b; - if ((-MACHEP * s < b) && (b < MACHEP * s)) { - output[pos] = static_cast(s); - flag = true; - break; - } - } - if (flag) { - continue; - } - w = a; - s += b * w / (p - one); - s -= half * b; - a = one; - k = zero; - for (int i = 0; i < 12; i++) { - a *= p + k; - b /= w; - t = a * b / A[i]; - s = s + t; - t = fabs(t / s); - if (t < MACHEP) { - output[pos] = static_cast(s); - break; - } - k += one; - a *= p + k; - b /= w; - k += one; - } - output[pos] = static_cast(s); + output[pos] = zeta(x[pos], dimension[pos]); } return; }