From 7380756a0fc6202c72b1641238fc4f21dc628b97 Mon Sep 17 00:00:00 2001
From: fakeen <15355477746@163.com>
Date: Mon, 18 Apr 2022 12:49:21 +0800
Subject: [PATCH] igamma

---
 .../device/cpu/kernel/igamma_cpu_kernel.cc    | 385 ++++++++++++++++
 .../device/cpu/kernel/igamma_cpu_kernel.h     |  70 +++
 .../device/cpu/kernel/igammac_cpu_kernel.cc   | 379 ++++++++++++++++
 .../device/cpu/kernel/igammac_cpu_kernel.h    |  69 +++
 .../cpu/kernel/igammagrada_cpu_kernel.cc      | 413 ++++++++++++++++++
 .../cpu/kernel/igammagrada_cpu_kernel.h       |  69 +++
 mindspore/core/ops/core_ops.h                 |   3 +
 mindspore/core/ops/grad/igammagrada.cc        |  61 +++
 mindspore/core/ops/grad/igammagrada.h         |  42 ++
 mindspore/core/ops/igamma.cc                  |  59 +++
 mindspore/core/ops/igamma.h                   |  43 ++
 mindspore/core/ops/igammac.cc                 |  60 +++
 mindspore/core/ops/igammac.h                  |  43 ++
 .../ops/_grad_experimental/grad_math_ops.py   |  51 +++
 .../mindspore/ops/_op_impl/aicpu/__init__.py  |   3 +
 .../mindspore/ops/_op_impl/aicpu/igamma.py    |  30 ++
 .../mindspore/ops/_op_impl/aicpu/igammac.py   |  30 ++
 .../ops/_op_impl/aicpu/igammagrada.py         |  30 ++
 .../mindspore/ops/operations/_grad_ops.py     |  36 ++
 .../mindspore/ops/operations/math_ops.py      | 100 +++++
 tests/ut/python/ops/test_math_ops.py          |  23 +-
 21 files changed, 1997 insertions(+), 2 deletions(-)
 create mode 100644 mindspore/ccsrc/plugin/device/cpu/kernel/igamma_cpu_kernel.cc
 create mode 100644 mindspore/ccsrc/plugin/device/cpu/kernel/igamma_cpu_kernel.h
 create mode 100644 mindspore/ccsrc/plugin/device/cpu/kernel/igammac_cpu_kernel.cc
 create mode 100644 mindspore/ccsrc/plugin/device/cpu/kernel/igammac_cpu_kernel.h
 create mode 100644 mindspore/ccsrc/plugin/device/cpu/kernel/igammagrada_cpu_kernel.cc
 create mode 100644 mindspore/ccsrc/plugin/device/cpu/kernel/igammagrada_cpu_kernel.h
 create mode 100644 mindspore/core/ops/grad/igammagrada.cc
 create mode 100644 mindspore/core/ops/grad/igammagrada.h
 create mode 100644 mindspore/core/ops/igamma.cc
 create mode 100644 mindspore/core/ops/igamma.h
 create mode 100644 mindspore/core/ops/igammac.cc
 create mode 100644 mindspore/core/ops/igammac.h
 create mode 100644 mindspore/python/mindspore/ops/_op_impl/aicpu/igamma.py
 create mode 100644 mindspore/python/mindspore/ops/_op_impl/aicpu/igammac.py
 create mode 100644 mindspore/python/mindspore/ops/_op_impl/aicpu/igammagrada.py

diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/igamma_cpu_kernel.cc b/mindspore/ccsrc/plugin/device/cpu/kernel/igamma_cpu_kernel.cc
new file mode 100644
index 00000000000..d95e5c87c48
--- /dev/null
+++ b/mindspore/ccsrc/plugin/device/cpu/kernel/igamma_cpu_kernel.cc
@@ -0,0 +1,385 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "plugin/device/cpu/kernel/igamma_cpu_kernel.h"
+#include <cmath>
+#include "plugin/device/cpu/hal/device/cpu_device_address.h"
+namespace mindspore {
+namespace kernel {
+namespace {
+/**
+ * Coefficients for the Lanczos approximation of the gamma function.
The + * coefficients are uniquely determined by the choice of g and n (kLanczosGamma + * and kLanczosCoefficients.size() + 1). The coefficients below correspond to + * [7, 9]. [5, 7], [7, 9], [9, 10], and [607/128.0, 15] were evaluated and [7, + * 9] seemed to be the least sensitive to the quality of the log function. In + * particular, [5, 7] is the only choice where -1.5e-5 <= lgamma(2) <= 1.5e-5 + * for a particularly inaccurate log function. + * */ +static constexpr double kLanczosGamma = 7; // aka g +constexpr size_t kInputIndex0 = 0; +constexpr size_t kInputIndex1 = 1; +constexpr size_t kInputIndex2 = 2; +constexpr size_t kInputIndex3 = 3; +constexpr size_t kInputIndex4 = 4; +constexpr size_t kInputIndex5 = 5; +constexpr size_t kInputIndex6 = 6; +constexpr size_t kInputIndex7 = 7; +constexpr size_t kInputIndex8 = 8; +constexpr size_t kInputIndex9 = 9; +constexpr size_t kInputIndex10 = 10; +constexpr size_t kInputIndex11 = 11; +constexpr size_t kInputIndex12 = 12; +constexpr size_t kInputIndex13 = 13; +constexpr size_t kInputIndex14 = 14; +static constexpr double kBaseLanczosCoeff = 0.99999999999980993227684700473478; +static constexpr double M_pi = 3.141592653589793238462643383279; +static constexpr std::array kLanczosCoefficients = { + 676.520368121885098567009190444019, -1259.13921672240287047156078755283, + 771.3234287776530788486528258894, -176.61502916214059906584551354, + 12.507343278686904814458936853, -0.13857109526572011689554707, + 9.984369578019570859563e-6, 1.50563273514931155834e-7}; +double log_lanczos_gamma_plus_one_half = std::log(kLanczosGamma + 0.5); +constexpr int64_t kParallelDataNums = 256; +constexpr int64_t kSameShape = 0; +constexpr int64_t kXOneElement = 1; +constexpr int64_t kYOneElement = 2; +constexpr size_t kInputNum = 2; +constexpr size_t kOutputNum = 1; + +size_t get_element_num(const std::vector &shape) { + size_t size = 1; + for (size_t i = 0; i < shape.size(); i++) { + size *= shape[i]; + } + return size; +} +} // namespace +/** Compute the Lgamma function using Lanczos' approximation from "A Precision + * Approximation of the Gamma Function". SIAM Journal on Numerical Analysis + * series B. Vol. 1: + * lgamma(z + 1) = (log(2) + log(pi)) / 2 + (z + 1/2) * log(t(z)) - t(z) + A(z) + * t(z) = z + kLanczosGamma + 1/2 + * A(z) = kBaseLanczosCoeff + sigma(k = 1, n, kLanczosCoefficients[i] / (z + k)) + */ + +template +T Lgamma(const T &input) { + T log_pi = std::log(M_pi); + T log_sqrt_two_pi = (std::log(2) + std::log(M_pi)) / 2; + + /** If the input is less than 0.5 use Euler's reflection formula: + * gamma(x) = pi / (sin(pi * x) * gamma(1 - x)) + */ + bool need_to_reflect = (input < 0.5); + T input_after_reflect = need_to_reflect ? -input : input - 1; + T sum = kBaseLanczosCoeff; + for (size_t i = 0, end = kLanczosCoefficients.size(); i < end; ++i) { + T lanczos_coefficient = kLanczosCoefficients[i]; + + sum += lanczos_coefficient / (input_after_reflect + i + 1); + } + T gamma_plus_onehalf_plus_z = kLanczosGamma + 0.5 + input_after_reflect; + T log_t = log_lanczos_gamma_plus_one_half + std::log1pf(input_after_reflect / (kLanczosGamma + 0.5)); + T log_y = log_sqrt_two_pi + (input_after_reflect + 0.5 - gamma_plus_onehalf_plus_z / log_t) * log_t + std::log(sum); + T abs_input = std::abs(input); + T abs_frac_input = abs_input - std::floor(abs_input); + + T reduced_frac_input = (abs_frac_input > 0.5) ? 1 - abs_frac_input : abs_frac_input; + T reflection_denom = std::log(std::sin(M_pi * reduced_frac_input)); + + T reflection = std::isfinite(reflection_denom) ? 
log_pi - reflection_denom - log_y : -reflection_denom; + T result = need_to_reflect ? reflection : log_y; + + return std::isinf(input) ? std::numeric_limits::infinity() : result; +} + +template +T use_igammact(const T &ax, const T &a, const T &x, T enabled) { + T y = 1 - a; + T z = x + y + 1; + T c = 0; + T pkm2 = 1; + T qkm2 = x; + T pkm1 = x + 1; + T qkm1 = z * x; + T ans = pkm1 / qkm1; + T t = 1; + T dpkm2_da = 0; + T dqkm2_da = 0; + T dpkm1_da = 0; + T dqkm1_da = -x; + T dans_da = (dpkm1_da - ans * dqkm1_da) / qkm1; + std::vector vals = {enabled, ans, t, y, z, c, pkm1, qkm1, + pkm2, qkm2, dpkm2_da, dqkm2_da, dpkm1_da, dqkm1_da, dans_da}; + constexpr int k2000 = 2000; + while (vals[kInputIndex0] && vals[kInputIndex5] < k2000) { + enabled = vals[kInputIndex0]; + ans = vals[kInputIndex1]; + T tmp_var_t = vals[kInputIndex2]; + T tmp_var_y = vals[kInputIndex3]; + T tmp_var_z = vals[kInputIndex4]; + T tmp_var_c = vals[kInputIndex5]; + pkm1 = vals[kInputIndex6]; + qkm1 = vals[kInputIndex7]; + pkm2 = vals[kInputIndex8]; + qkm2 = vals[kInputIndex9]; + dpkm2_da = vals[kInputIndex10]; + dqkm2_da = vals[kInputIndex11]; + dpkm1_da = vals[kInputIndex12]; + dqkm1_da = vals[kInputIndex13]; + dans_da = vals[kInputIndex14]; + tmp_var_c += 1; + tmp_var_y += 1; + constexpr int TWO = 2; + tmp_var_z += TWO; + + T yc = tmp_var_y * tmp_var_c; + T pk = pkm1 * tmp_var_z - pkm2 * yc; + T qk = qkm1 * tmp_var_z - qkm2 * yc; + bool qk_is_nonzero = (qk != 0); + T r = pk / qk; + t = qk_is_nonzero ? std::abs((ans - r) / r) : 1; + ans = qk_is_nonzero ? r : ans; + + T dpk_da = dpkm1_da * tmp_var_z - pkm1 - dpkm2_da * yc + pkm2 * tmp_var_c; + T dqk_da = dqkm1_da * tmp_var_z - qkm1 - dqkm2_da * yc + qkm2 * tmp_var_c; + T dans_da_new = qk_is_nonzero ? (dpk_da - ans * dqk_da) / qk : dans_da; + pkm2 = pkm1; + pkm1 = pk; + qkm2 = qkm1; + qkm1 = qk; + + dpkm2_da = dpkm1_da; + dqkm2_da = dqkm1_da; + dpkm1_da = dpk_da; + dqkm1_da = dqk_da; + bool rescale = std::abs(pk) > (1 / std::numeric_limits::epsilon()); + + pkm2 = rescale ? pkm2 * std::numeric_limits::epsilon() : pkm2; + pkm1 = rescale ? pkm1 * std::numeric_limits::epsilon() : pkm1; + qkm2 = rescale ? qkm2 * std::numeric_limits::epsilon() : qkm2; + qkm1 = rescale ? qkm1 * std::numeric_limits::epsilon() : qkm1; + + dpkm2_da = rescale ? dpkm2_da * std::numeric_limits::epsilon() : dpkm2_da; + dqkm2_da = rescale ? dqkm2_da * std::numeric_limits::epsilon() : dqkm2_da; + dpkm1_da = rescale ? dpkm1_da * std::numeric_limits::epsilon() : dpkm1_da; + dqkm1_da = rescale ? 
dqkm1_da * std::numeric_limits::epsilon() : dqkm1_da; + + T conditional = enabled && (t > std::numeric_limits::epsilon()); + vals[kInputIndex0] = conditional; + vals[kInputIndex5] = tmp_var_c; + if (enabled) { + vals = {conditional, ans, tmp_var_t, tmp_var_y, tmp_var_z, tmp_var_c, pkm1, qkm1, + pkm2, qkm2, dpkm2_da, dqkm2_da, dpkm1_da, dqkm1_da, dans_da_new}; + } + } + ans = vals[kInputIndex1]; + return 1 - ans * ax; +} + +template +T use_igammacf(T ax, T a, T x, T enabled) { + std::vector vals = {enabled, a, 1, 1, x, 0, 0}; + while (vals[kInputIndex0] != 0) { + enabled = vals[kInputIndex0]; + T r = vals[kInputIndex1]; + T c = vals[kInputIndex2]; + T ans = vals[kInputIndex3]; + x = vals[kInputIndex4]; + T dc_da = vals[kInputIndex5]; + T dans_da = vals[kInputIndex6]; + r += 1; + dc_da = dc_da * (x / r) + (-1 * c * x) / (r * r); + dans_da = dans_da + dc_da; + c = c * (x / r); + ans = ans + c; + T conditional = enabled && (c / ans > std::numeric_limits::epsilon()); + vals[kInputIndex0] = conditional; + if (enabled) { + vals = {conditional, r, c, ans, x, dc_da, dans_da}; + } + } + T ans = vals[kInputIndex3]; + if (a == 0) { + return NAN; + } + return (ans * ax) / a; +} + +template +T IgammaSingle(const T &a, const T &x) { + if (!std::isinf(a) && (a > 0) && std::isinf(x) && x > 0) { + return 1; + } + bool is_nan = std::isnan(a) || std::isnan(x); + bool x_is_zero = (x == 0); + bool domain_error = (x < 0) || (a <= 0); + bool use_igammac = (x > 1) && (x > a); + + T ax = a * std::log(x) - x - Lgamma(a); + + bool underflow = (ax < -std::log(std::numeric_limits::max())); + + ax = std::exp(ax); + T enabled = static_cast(!(x_is_zero || domain_error || underflow || is_nan)); + T output; + if (use_igammac != 0) { + enabled = static_cast(enabled && use_igammac); + output = use_igammact(ax, a, x, enabled); + } else { + enabled = static_cast(enabled && !(use_igammac)); + output = use_igammacf(ax, a, x, enabled); + } + output = (domain_error || is_nan || std::isnan(output)) ? std::numeric_limits::quiet_NaN() : output; + output = x_is_zero ? 0 : output; + return output; +} + +template +void IgammaCpuKernelMod::BcastCompute(const std::vector &inputs, + const std::vector &outputs) { + auto a_data_addr = reinterpret_cast(inputs[0]->addr); + auto x_data_addr = reinterpret_cast(inputs[1]->addr); + auto z_data_addr = reinterpret_cast(outputs[0]->addr); + size_t data_num = get_element_num(z_shape_); + auto output_shape = CPUKernelUtils::GetBroadcastShape(a_shape_, x_shape_); + BroadcastIterator iter(a_shape_, x_shape_, output_shape); + if (data_num < kParallelDataNums) { + iter.SetPos(0); + for (size_t i = 0; i < data_num; i++) { + T *a_index = a_data_addr + iter.GetInputPosA(); // i-th value of input0 + T *x_index = x_data_addr + iter.GetInputPosB(); // i-th value of input1 + *(z_data_addr + i) = IgammaSingle(*a_index, *x_index); + iter.GenNextPos(); + } + } else { + auto shard_igamma = [z_data_addr, a_data_addr, x_data_addr, &iter](size_t start, size_t end) { + iter.SetPos(start); + for (size_t i = start; i < end; i++) { + T *a_index = a_data_addr + iter.GetInputPosA(); // i-th value of input0 + T *x_index = x_data_addr + iter.GetInputPosB(); // i-th value of input1 + *(z_data_addr + i) = IgammaSingle(*a_index, *x_index); + iter.GenNextPos(); + } + }; + ParallelLaunchAutoSearch(shard_igamma, data_num, this, ¶llel_search_info_); + } +} + +/* special compute is used in the following situations. + * 1. the shapes of input1 and input2 are the same + * 2. 
input1 is a 1D tensor with only one element or input1 is a scalar
+ * 3. input2 is a 1D tensor with only one element or input2 is a scalar
+ * 4. the shapes of input1 and input2 are different
+ **/
+template <typename T>
+void IgammaCpuKernelMod::SpecialCompute(int64_t type, int64_t start, int64_t end, const T *input1, const T *input2,
+                                        T *output) {
+  switch (type) {
+    case kSameShape: {
+      auto cur_input1 = input1 + start;
+      auto cur_input2 = input2 + start;
+      for (int64_t i = start; i < end; ++i) {
+        *output = IgammaSingle(*cur_input1, *cur_input2);
+        output = output + 1;
+        cur_input1 = cur_input1 + 1;
+        cur_input2 = cur_input2 + 1;
+      }
+      break;
+    }
+    case kXOneElement: {
+      auto cur_input2 = input2 + start;
+      for (int64_t i = start; i < end; ++i) {
+        *output = IgammaSingle(*input1, *cur_input2);
+        output = output + 1;
+        cur_input2 = cur_input2 + 1;
+      }
+      break;
+    }
+    case kYOneElement: {
+      auto cur_input1 = input1 + start;
+      for (int64_t i = start; i < end; ++i) {
+        *output = IgammaSingle(*cur_input1, *input2);
+        output = output + 1;
+        cur_input1 = cur_input1 + 1;
+      }
+      break;
+    }
+    default:
+      break;
+  }
+}
+
+template <typename T>
+void IgammaCpuKernelMod::NoBcastCompute(const std::vector<AddressPtr> &inputs,
+                                        const std::vector<AddressPtr> &outputs) {
+  auto in0 = reinterpret_cast<T *>(inputs[0]->addr);
+  auto in1 = reinterpret_cast<T *>(inputs[1]->addr);
+  auto out0 = reinterpret_cast<T *>(outputs[0]->addr);
+  size_t in0_elements_nums = get_element_num(a_shape_);
+  size_t in1_elements_nums = get_element_num(x_shape_);
+  size_t data_num = get_element_num(z_shape_);
+  int64_t type =
+    in0_elements_nums == in1_elements_nums ? kSameShape : (in0_elements_nums == 1 ? kXOneElement : kYOneElement);
+  if (data_num < kParallelDataNums) {
+    SpecialCompute(type, 0, data_num, in0, in1, out0);
+  } else {
+    auto shard_igamma = [type, in0, in1, out0, this](int64_t start, int64_t end) {
+      SpecialCompute(type, start, end, in0, in1, out0 + start);
+    };
+    ParallelLaunchAutoSearch(shard_igamma, data_num, this, &parallel_search_info_);
+  }
+}
+
+void IgammaCpuKernelMod::InitKernel(const CNodePtr &kernel_node) {
+  a_shape_ = AnfAlgo::GetInputDeviceShape(kernel_node, 0);
+  x_shape_ = AnfAlgo::GetInputDeviceShape(kernel_node, 1);
+  z_shape_ = AnfAlgo::GetOutputDeviceShape(kernel_node, 0);
+  dtype_ = AnfAlgo::GetInputDeviceDataType(kernel_node, 0);
+  kernel_name_ = common::AnfAlgo::GetCNodeName(kernel_node);
+}
+
+bool IgammaCpuKernelMod::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
+                                const std::vector<AddressPtr> &outputs) {
+  if (dtype_ == kNumberTypeFloat32) {
+    LaunchKernel<float>(inputs, outputs);
+  } else if (dtype_ == kNumberTypeFloat64) {
+    LaunchKernel<double>(inputs, outputs);
+  } else {
+    MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the dtype of 'a' should be float32 or float64, but got "
+                      << TypeIdToType(dtype_)->ToString();
+  }
+  return true;
+}
+
+template <typename T>
+void IgammaCpuKernelMod::LaunchKernel(const std::vector<AddressPtr> &inputs,
+                                      const std::vector<AddressPtr> &outputs) {
+  CHECK_KERNEL_INPUTS_NUM(inputs.size(), kInputNum, kernel_name_);
+  CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kOutputNum, kernel_name_);
+  size_t in0_elements_nums = get_element_num(a_shape_);
+  size_t in1_elements_nums = get_element_num(x_shape_);
+  bool isNeedBcast = (a_shape_ == x_shape_) || (in0_elements_nums == 1) || (in1_elements_nums == 1);
+  if (isNeedBcast) {
+    NoBcastCompute<T>(inputs, outputs);
+  } else {
+    BcastCompute<T>(inputs, outputs);
+  }
+}
+
+MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, Igamma, IgammaCpuKernelMod);
+}  // namespace kernel
+}  // namespace mindspore
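The Igamma kernel above combines two classical evaluation strategies: a power series for the lower regularized incomplete gamma (use_igammacf) and a continued fraction for its complement (use_igammact), both scaled by ax = exp(a*log(x) - x - lgamma(a)). The standalone sketch below reproduces that split outside of MindSpore so the numerics can be spot-checked. It is an illustration only: the helper names GammaSeries and GammaContinuedFraction are invented here, the continued fraction uses the modified Lentz form rather than the kernel's explicit pk/qk recurrence, and it switches branches at the textbook threshold x < a + 1 instead of the kernel's (x > 1 && x > a).

#include <cmath>
#include <cstdio>
#include <limits>

// Series expansion of the regularized lower incomplete gamma P(a, x);
// each term multiplies the previous one by x / (a + n), so it converges
// quickly when x is small relative to a.
double GammaSeries(double a, double x) {
  double ap = a;
  double sum = 1.0 / a;
  double del = sum;
  for (int n = 0; n < 200; ++n) {
    ap += 1.0;
    del *= x / ap;
    sum += del;
    if (std::fabs(del) < std::fabs(sum) * std::numeric_limits<double>::epsilon()) {
      break;
    }
  }
  // Same prefactor the kernel computes as ax = exp(a*log(x) - x - Lgamma(a)).
  return sum * std::exp(a * std::log(x) - x - std::lgamma(a));
}

// Modified Lentz continued fraction for the regularized upper incomplete
// gamma Q(a, x) = 1 - P(a, x); converges quickly when x is large.
double GammaContinuedFraction(double a, double x) {
  const double tiny = std::numeric_limits<double>::min() / std::numeric_limits<double>::epsilon();
  double b = x + 1.0 - a;
  double c = 1.0 / tiny;
  double d = 1.0 / b;
  double h = d;
  for (int i = 1; i <= 200; ++i) {
    double an = -i * (i - a);
    b += 2.0;
    d = an * d + b;
    if (std::fabs(d) < tiny) d = tiny;
    c = b + an / c;
    if (std::fabs(c) < tiny) c = tiny;
    d = 1.0 / d;
    double del = d * c;
    h *= del;
    if (std::fabs(del - 1.0) < std::numeric_limits<double>::epsilon()) {
      break;
    }
  }
  return std::exp(a * std::log(x) - x - std::lgamma(a)) * h;
}

int main() {
  double a = 2.5;
  double x = 3.0;
  bool small_x = (x < a + 1.0);
  double p = small_x ? GammaSeries(a, x) : 1.0 - GammaContinuedFraction(a, x);
  double q = small_x ? 1.0 - GammaSeries(a, x) : GammaContinuedFraction(a, x);
  std::printf("Igamma=%.12f Igammac=%.12f sum=%.12f\n", p, q, p + q);
  return 0;
}

Compiled as ordinary C++ (g++ -std=c++11 or later), the two printed values should sum to 1 up to rounding, which is the same Igamma + Igammac = 1 identity that ties the Igamma and Igammac kernels in this patch together.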
diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/igamma_cpu_kernel.h b/mindspore/ccsrc/plugin/device/cpu/kernel/igamma_cpu_kernel.h
new file mode 100644
index 00000000000..50417054153
--- /dev/null
+++ b/mindspore/ccsrc/plugin/device/cpu/kernel/igamma_cpu_kernel.h
@@ -0,0 +1,70 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_IGAMMA_CPU_KERNEL_H_
+#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_IGAMMA_CPU_KERNEL_H_
+
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include "plugin/device/cpu/kernel/cpu_kernel.h"
+#include "plugin/factory/ms_factory.h"
+
+namespace mindspore {
+namespace kernel {
+class IgammaCpuKernelMod : public DeprecatedNativeCpuKernelMod {
+ public:
+  IgammaCpuKernelMod() = default;
+  ~IgammaCpuKernelMod() override = default;
+
+  void InitKernel(const CNodePtr &kernel_node) override;
+
+  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
+              const std::vector<AddressPtr> &outputs) override;
+
+ protected:
+  std::vector<KernelAttr> GetOpSupport() override {
+    static std::vector<KernelAttr> support_list = {
+      KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
+      KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64)};
+    return support_list;
+  }
+
+ private:
+  std::vector<size_t> a_shape_;
+  std::vector<size_t> x_shape_;
+  std::vector<size_t> z_shape_;
+  TypeId dtype_{kTypeUnknown};
+  template <typename T>
+  void LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs);
+
+  template <typename T>
+  void BcastCompute(const std::vector<AddressPtr> &, const std::vector<AddressPtr> &);
+
+  template <typename T>
+  void SpecialCompute(int64_t, int64_t, int64_t, const T *, const T *, T *);
+
+  template <typename T>
+  void NoBcastCompute(const std::vector<AddressPtr> &, const std::vector<AddressPtr> &);
+};
+}  // namespace kernel
+}  // namespace mindspore
+
+#endif  // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_IGAMMA_CPU_KERNEL_H_
diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/igammac_cpu_kernel.cc b/mindspore/ccsrc/plugin/device/cpu/kernel/igammac_cpu_kernel.cc
new file mode 100644
index 00000000000..4512ed6608a
--- /dev/null
+++ b/mindspore/ccsrc/plugin/device/cpu/kernel/igammac_cpu_kernel.cc
@@ -0,0 +1,379 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ +#include "plugin/device/cpu/kernel/igammac_cpu_kernel.h" +#include "plugin/device/cpu/hal/device/cpu_device_address.h" + +namespace mindspore { +namespace kernel { +namespace { +/** + * Coefficients for the Lanczos approximation of the gamma function. The + * coefficients are uniquely determined by the choice of g and n (kLanczosGamma + * and kLanczosCoefficients.size() + 1). The coefficients below correspond to + * [7, 9]. [5, 7], [7, 9], [9, 10], and [607/128.0, 15] were evaluated and [7, + * 9] seemed to be the least sensitive to the quality of the log function. In + * particular, [5, 7] is the only choice where -1.5e-5 <= lgamma(2) <= 1.5e-5 + * for a particularly inaccurate log function. + * */ +static constexpr double kLanczosGamma = 7; // aka g +constexpr size_t kInputIndex0 = 0; +constexpr size_t kInputIndex1 = 1; +constexpr size_t kInputIndex2 = 2; +constexpr size_t kInputIndex3 = 3; +constexpr size_t kInputIndex4 = 4; +constexpr size_t kInputIndex5 = 5; +constexpr size_t kInputIndex6 = 6; +constexpr size_t kInputIndex7 = 7; +constexpr size_t kInputIndex8 = 8; +constexpr size_t kInputIndex9 = 9; +constexpr size_t kInputIndex10 = 10; +constexpr size_t kInputIndex11 = 11; +constexpr size_t kInputIndex12 = 12; +constexpr size_t kInputIndex13 = 13; +constexpr size_t kInputIndex14 = 14; +static constexpr double kBaseLanczosCoeff = 0.99999999999980993227684700473478; +static constexpr double M_pi = 3.141592653589793238462643383279; +static constexpr std::array kLanczosCoefficients = { + 676.520368121885098567009190444019, -1259.13921672240287047156078755283, + 771.3234287776530788486528258894, -176.61502916214059906584551354, + 12.507343278686904814458936853, -0.13857109526572011689554707, + 9.984369578019570859563e-6, 1.50563273514931155834e-7}; +double log_lanczos_gamma_plus_one_half = std::log(kLanczosGamma + 0.5); +constexpr int64_t kParallelDataNums = 256; +constexpr int64_t kSameShape = 0; +constexpr int64_t kXOneElement = 1; +constexpr int64_t kYOneElement = 2; +constexpr size_t kInputNum = 2; +constexpr size_t kOutputNum = 1; + +size_t get_element_num(const std::vector &shape) { + size_t size = 1; + for (size_t i = 0; i < shape.size(); i++) { + size *= shape[i]; + } + return size; +} +} // namespace +/** Compute the Lgamma function using Lanczos' approximation from "A Precision + * Approximation of the Gamma Function". SIAM Journal on Numerical Analysis + * series B. Vol. 1: + * lgamma(z + 1) = (log(2) + log(pi)) / 2 + (z + 1/2) * log(t(z)) - t(z) + A(z) + * t(z) = z + kLanczosGamma + 1/2 + * A(z) = kBaseLanczosCoeff + sigma(k = 1, n, kLanczosCoefficients[i] / (z + k)) + */ +template +T Lgamma(const T &input) { + T log_pi = std::log(M_pi); + T log_sqrt_two_pi = (std::log(2) + std::log(M_pi)) / 2; + + /** If the input is less than 0.5 use Euler's reflection formula: + * gamma(x) = pi / (sin(pi * x) * gamma(1 - x)) + */ + bool need_to_reflect = (input < 0.5); + T input_after_reflect = need_to_reflect ? 
-input : input - 1; + T sum = kBaseLanczosCoeff; + for (size_t i = 0, end = kLanczosCoefficients.size(); i < end; ++i) { + T lanczos_coefficient = kLanczosCoefficients[i]; + + sum += lanczos_coefficient / (input_after_reflect + i + 1); + } + T gamma_plus_onehalf_plus_z = kLanczosGamma + 0.5 + input_after_reflect; + T log_t = log_lanczos_gamma_plus_one_half + std::log1pf(input_after_reflect / (kLanczosGamma + 0.5)); + T log_y = log_sqrt_two_pi + (input_after_reflect + 0.5 - gamma_plus_onehalf_plus_z / log_t) * log_t + std::log(sum); + T abs_input = std::abs(input); + T abs_frac_input = abs_input - std::floor(abs_input); + + T reduced_frac_input = (abs_frac_input > 0.5) ? 1 - abs_frac_input : abs_frac_input; + T reflection_denom = std::log(std::sin(M_pi * reduced_frac_input)); + + T reflection = std::isfinite(reflection_denom) ? log_pi - reflection_denom - log_y : -reflection_denom; + T result = need_to_reflect ? reflection : log_y; + + return std::isinf(input) ? std::numeric_limits::infinity() : result; +} + +template +T use_igammaf(const T &ax, const T &a, const T &x, T enabled) { + T y = 1 - a; + T z = x + y + 1; + T c = 0; + T pkm2 = 1; + T qkm2 = x; + T pkm1 = x + 1; + T qkm1 = z * x; + T ans = pkm1 / qkm1; + T t = 1; + T dpkm2_da = 0; + T dqkm2_da = 0; + T dpkm1_da = 0; + T dqkm1_da = -x; + T dans_da = (dpkm1_da - ans * dqkm1_da) / qkm1; + std::vector vals = {enabled, ans, t, y, z, c, pkm1, qkm1, + pkm2, qkm2, dpkm2_da, dqkm2_da, dpkm1_da, dqkm1_da, dans_da}; + constexpr int k2000 = 2000; + while (vals[kInputIndex0] && vals[kInputIndex5] < k2000) { + enabled = vals[kInputIndex0]; + ans = vals[kInputIndex1]; + T tmp_var_t = vals[kInputIndex2]; + T tmp_var_y = vals[kInputIndex3]; + T tmp_var_z = vals[kInputIndex4]; + T tmp_var_c = vals[kInputIndex5]; + pkm1 = vals[kInputIndex6]; + qkm1 = vals[kInputIndex7]; + pkm2 = vals[kInputIndex8]; + qkm2 = vals[kInputIndex9]; + dpkm2_da = vals[kInputIndex10]; + dqkm2_da = vals[kInputIndex11]; + dpkm1_da = vals[kInputIndex12]; + dqkm1_da = vals[kInputIndex13]; + dans_da = vals[kInputIndex14]; + tmp_var_c += 1; + tmp_var_y += 1; + constexpr int TWO = 2; + tmp_var_z += TWO; + + T yc = tmp_var_y * tmp_var_c; + T pk = pkm1 * tmp_var_z - pkm2 * yc; + T qk = qkm1 * tmp_var_z - qkm2 * yc; + bool qk_is_nonzero = (qk != 0); + T r = pk / qk; + t = qk_is_nonzero ? std::abs((ans - r) / r) : 1; + ans = qk_is_nonzero ? r : ans; + + T dpk_da = dpkm1_da * tmp_var_z - pkm1 - dpkm2_da * yc + pkm2 * tmp_var_c; + T dqk_da = dqkm1_da * tmp_var_z - qkm1 - dqkm2_da * yc + qkm2 * tmp_var_c; + T dans_da_new = qk_is_nonzero ? (dpk_da - ans * dqk_da) / qk : dans_da; + pkm2 = pkm1; + pkm1 = pk; + qkm2 = qkm1; + qkm1 = qk; + + dpkm2_da = dpkm1_da; + dqkm2_da = dqkm1_da; + dpkm1_da = dpk_da; + dqkm1_da = dqk_da; + bool rescale = std::abs(pk) > (1 / std::numeric_limits::epsilon()); + + pkm2 = rescale ? pkm2 * std::numeric_limits::epsilon() : pkm2; + pkm1 = rescale ? pkm1 * std::numeric_limits::epsilon() : pkm1; + qkm2 = rescale ? qkm2 * std::numeric_limits::epsilon() : qkm2; + qkm1 = rescale ? qkm1 * std::numeric_limits::epsilon() : qkm1; + + dpkm2_da = rescale ? dpkm2_da * std::numeric_limits::epsilon() : dpkm2_da; + dqkm2_da = rescale ? dqkm2_da * std::numeric_limits::epsilon() : dqkm2_da; + dpkm1_da = rescale ? dpkm1_da * std::numeric_limits::epsilon() : dpkm1_da; + dqkm1_da = rescale ? 
dqkm1_da * std::numeric_limits::epsilon() : dqkm1_da; + + T conditional = enabled && (t > std::numeric_limits::epsilon()); + vals[kInputIndex0] = conditional; + vals[kInputIndex5] = tmp_var_c; + if (enabled) { + vals = {conditional, ans, tmp_var_t, tmp_var_y, tmp_var_z, tmp_var_c, pkm1, qkm1, + pkm2, qkm2, dpkm2_da, dqkm2_da, dpkm1_da, dqkm1_da, dans_da_new}; + } + } + ans = vals[kInputIndex1]; + return ans * ax; +} + +template +T use_igammat(T ax, T a, T x, T enabled) { + std::vector vals = {enabled, a, 1, 1, x, 0, 0}; + while (vals[kInputIndex0] != 0) { + enabled = vals[kInputIndex0]; + T r = vals[kInputIndex1]; + T c = vals[kInputIndex2]; + T ans = vals[kInputIndex3]; + x = vals[kInputIndex4]; + T dc_da = vals[kInputIndex5]; + T dans_da = vals[kInputIndex6]; + r += 1; + dc_da = dc_da * (x / r) + (-1 * c * x) / (r * r); + dans_da = dans_da + dc_da; + c = c * (x / r); + ans = ans + c; + T conditional = enabled && (c / ans > std::numeric_limits::epsilon()); + vals[kInputIndex0] = conditional; + if (enabled) { + vals = {conditional, r, c, ans, x, dc_da, dans_da}; + } + } + T ans = vals[kInputIndex3]; + if (a == 0) { + return NAN; + } + return 1 - (ans * ax) / a; +} + +template +T IgammacSingle(const T &a, const T &x) { + bool out_of_range = (x <= 0) || (a <= 0); + bool use_igamma = (x < 1) || (x < a); + T ax = a * std::log(x) - x - Lgamma(a); + bool underflow = (ax < -std::log(std::numeric_limits::max())); + T enabled = static_cast(!(out_of_range || underflow)); + + ax = std::exp(ax); + T output; + if (use_igamma != 0) { + enabled = static_cast(enabled && use_igamma); + output = use_igammat(ax, a, x, enabled); + } else { + enabled = static_cast(enabled && (!use_igamma)); + output = use_igammaf(ax, a, x, enabled); + } + output = out_of_range ? 1 : output; + output = x < 0 || a <= 0 || std::isnan(x) || (std::isinf(x) && (x > 0)) || std::isnan(a) + ? std::numeric_limits::quiet_NaN() + : output; + output = std::isinf(x) && x > 0 && a > 0 ? 0 : output; + return output; +} + +template +void IgammacCpuKernelMod::BcastCompute(const std::vector &inputs, + const std::vector &outputs) { + auto a_data_addr = reinterpret_cast(inputs[0]->addr); + auto x_data_addr = reinterpret_cast(inputs[1]->addr); + auto z_data_addr = reinterpret_cast(outputs[0]->addr); + size_t data_num = get_element_num(z_shape_); + auto output_shape = CPUKernelUtils::GetBroadcastShape(a_shape_, x_shape_); + BroadcastIterator iter(a_shape_, x_shape_, output_shape); + if (data_num < kParallelDataNums) { + iter.SetPos(0); + for (size_t i = 0; i < data_num; i++) { + T *a_index = a_data_addr + iter.GetInputPosA(); // i-th value of input0 + T *x_index = x_data_addr + iter.GetInputPosB(); // i-th value of input1 + *(z_data_addr + i) = IgammacSingle(*a_index, *x_index); + iter.GenNextPos(); + } + } else { + auto shard_igammac = [z_data_addr, a_data_addr, x_data_addr, &iter](size_t start, size_t end) { + iter.SetPos(start); + for (size_t i = start; i < end; i++) { + T *a_index = a_data_addr + iter.GetInputPosA(); // i-th value of input0 + T *x_index = x_data_addr + iter.GetInputPosB(); // i-th value of input1 + *(z_data_addr + i) = IgammacSingle(*a_index, *x_index); + iter.GenNextPos(); + } + }; + ParallelLaunchAutoSearch(shard_igammac, data_num, this, ¶llel_search_info_); + } +} + +/* special compute is used in the following situations. + * 1. the shapes of input1 and input2 are the same + * 2. input1 is a 1D tensor with only one element or input1 is scalar + * 3. input2 is a 1D tensor with only one element or input2 is scalar + * 4. 
the shapes of input1 and input2 are different + **/ +template +void IgammacCpuKernelMod::SpecialCompute(int64_t type, int64_t start, int64_t end, const T *input1, const T *input2, + T *output) { + switch (type) { + case kSameShape: { + auto cur_input1 = input1 + start; + auto cur_input2 = input2 + start; + for (int64_t i = start; i < end; ++i) { + *output = IgammacSingle(*cur_input1, *cur_input2); + output = output + 1; + cur_input1 = cur_input1 + 1; + cur_input2 = cur_input2 + 1; + } + break; + } + case kXOneElement: { + auto cur_input2 = input2 + start; + for (int64_t i = start; i < end; ++i) { + *output = IgammacSingle(*input1, *cur_input2); + output = output + 1; + cur_input2 = cur_input2 + 1; + } + break; + } + case kYOneElement: { + auto cur_input1 = input1 + start; + for (int64_t i = start; i < end; ++i) { + *output = IgammacSingle(*cur_input1, *input2); + output = output + 1; + cur_input1 = cur_input1 + 1; + } + break; + } + default: + break; + } +} + +template +void IgammacCpuKernelMod::NoBcastCompute(const std::vector &inputs, + const std::vector &outputs) { + auto in0 = reinterpret_cast(inputs[0]->addr); + auto in1 = reinterpret_cast(inputs[1]->addr); + auto out0 = reinterpret_cast(outputs[0]->addr); + size_t in0_elements_nums = get_element_num(a_shape_); + size_t in1_elements_nums = get_element_num(x_shape_); + size_t data_num = get_element_num(z_shape_); + int64_t type = + in0_elements_nums == in1_elements_nums ? kSameShape : (in0_elements_nums == 1 ? kXOneElement : kYOneElement); + if (data_num < kParallelDataNums) { + SpecialCompute(type, 0, data_num, in0, in1, out0); + } else { + auto shard_igammac = [type, in0, in1, out0, this](int64_t start, int64_t end) { + SpecialCompute(type, start, end, in0, in1, out0 + start); + }; + ParallelLaunchAutoSearch(shard_igammac, data_num, this, ¶llel_search_info_); + } +} + +void IgammacCpuKernelMod::InitKernel(const CNodePtr &kernel_node) { + a_shape_ = AnfAlgo::GetInputDeviceShape(kernel_node, 0); + x_shape_ = AnfAlgo::GetInputDeviceShape(kernel_node, 1); + z_shape_ = AnfAlgo::GetOutputDeviceShape(kernel_node, 0); + dtype_ = AnfAlgo::GetInputDeviceDataType(kernel_node, 0); + kernel_name_ = common::AnfAlgo::GetCNodeName(kernel_node); +} + +bool IgammacCpuKernelMod::Launch(const std::vector &inputs, const std::vector &, + const std::vector &outputs) { + if (dtype_ == kNumberTypeFloat32) { + LaunchKernel(inputs, outputs); + } else if (dtype_ == kNumberTypeFloat64) { + LaunchKernel(inputs, outputs); + } else { + MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the dtype of 'var' should be float32 or float64, but got " + << TypeIdToType(dtype_)->ToString(); + } + return true; +} + +template +void IgammacCpuKernelMod::LaunchKernel(const std::vector &inputs, + const std::vector &outputs) { + CHECK_KERNEL_INPUTS_NUM(inputs.size(), kInputNum, kernel_name_); + CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kOutputNum, kernel_name_); + size_t in0_elements_nums = get_element_num(a_shape_); + size_t in1_elements_nums = get_element_num(x_shape_); + bool isNeedBcast = (a_shape_ == x_shape_) || (in0_elements_nums == 1) || (in1_elements_nums == 1); + if (isNeedBcast) { + NoBcastCompute(inputs, outputs); + } else { + BcastCompute(inputs, outputs); + } +} +MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, Igammac, IgammacCpuKernelMod); +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/igammac_cpu_kernel.h b/mindspore/ccsrc/plugin/device/cpu/kernel/igammac_cpu_kernel.h new file mode 100644 index 
00000000000..c804d10a869 --- /dev/null +++ b/mindspore/ccsrc/plugin/device/cpu/kernel/igammac_cpu_kernel.h @@ -0,0 +1,69 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_IGAMMAC_CPU_KERNEL_H_ +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_IGAMMAC_CPU_KERNEL_H_ +#include +#include +#include +#include +#include +#include +#include +#include +#include "plugin/device/cpu/kernel/cpu_kernel.h" +#include "plugin/factory/ms_factory.h" + +namespace mindspore { +namespace kernel { +class IgammacCpuKernelMod : public DeprecatedNativeCpuKernelMod { + public: + IgammacCpuKernelMod() = default; + ~IgammacCpuKernelMod() override = default; + + void InitKernel(const CNodePtr &kernel_node) override; + + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs) override; + + protected: + std::vector GetOpSupport() override { + static std::vector support_list = { + KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), + KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64)}; + return support_list; + } + + private: + std::vector a_shape_; + std::vector x_shape_; + std::vector z_shape_; + TypeId dtype_{kTypeUnknown}; + template + void LaunchKernel(const std::vector &inputs, const std::vector &outputs); + + template + void BcastCompute(const std::vector &, const std::vector &); + + template + void SpecialCompute(int64_t, int64_t, int64_t, const T *, const T *, T *); + + template + void NoBcastCompute(const std::vector &, const std::vector &); +}; +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_IGAMMAC_CPU_KERNEL_H_ diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/igammagrada_cpu_kernel.cc b/mindspore/ccsrc/plugin/device/cpu/kernel/igammagrada_cpu_kernel.cc new file mode 100644 index 00000000000..84ab840fc7a --- /dev/null +++ b/mindspore/ccsrc/plugin/device/cpu/kernel/igammagrada_cpu_kernel.cc @@ -0,0 +1,413 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include "plugin/device/cpu/kernel/igammagrada_cpu_kernel.h" +#include +#include "plugin/device/cpu/hal/device/cpu_device_address.h" +namespace mindspore { +namespace kernel { +namespace { +/** + * Coefficients for the Lanczos approximation of the gamma function. The + * coefficients are uniquely determined by the choice of g and n (kLanczosGamma + * and kLanczosCoefficients.size() + 1). The coefficients below correspond to + * [7, 9]. [5, 7], [7, 9], [9, 10], and [607/128.0, 15] were evaluated and [7, + * 9] seemed to be the least sensitive to the quality of the log function. In + * particular, [5, 7] is the only choice where -1.5e-5 <= lgamma(2) <= 1.5e-5 + * for a particularly inaccurate log function. + * */ +static constexpr double kLanczosGamma = 7; // aka g +constexpr size_t kInputIndex0 = 0; +constexpr size_t kInputIndex1 = 1; +constexpr size_t kInputIndex2 = 2; +constexpr size_t kInputIndex3 = 3; +constexpr size_t kInputIndex4 = 4; +constexpr size_t kInputIndex5 = 5; +constexpr size_t kInputIndex6 = 6; +constexpr size_t kInputIndex7 = 7; +constexpr size_t kInputIndex8 = 8; +constexpr size_t kInputIndex9 = 9; +constexpr size_t kInputIndex10 = 10; +constexpr size_t kInputIndex11 = 11; +constexpr size_t kInputIndex12 = 12; +constexpr size_t kInputIndex13 = 13; +constexpr size_t kInputIndex14 = 14; +static constexpr double kBaseLanczosCoeff = 0.99999999999980993227684700473478; +static constexpr double M_pi = 3.141592653589793238462643383279; +static constexpr std::array kLanczosCoefficients = { + 676.520368121885098567009190444019, -1259.13921672240287047156078755283, + 771.3234287776530788486528258894, -176.61502916214059906584551354, + 12.507343278686904814458936853, -0.13857109526572011689554707, + 9.984369578019570859563e-6, 1.50563273514931155834e-7}; +double log_lanczos_gamma_plus_one_half = std::log(kLanczosGamma + 0.5); +constexpr int64_t kParallelDataNums = 256; +constexpr int64_t kSameShape = 0; +constexpr int64_t kXOneElement = 1; +constexpr int64_t kYOneElement = 2; +constexpr size_t kInputNum = 2; +constexpr size_t kOutputNum = 1; +constexpr int64_t VALUE = 1; +constexpr int64_t DERIVATIVE = 2; +size_t get_element_num(const std::vector &shape) { + size_t size = 1; + for (size_t i = 0; i < shape.size(); i++) { + size *= shape[i]; + } + return size; +} +} // namespace +/** Compute the Lgamma function using Lanczos' approximation from "A Precision + * Approximation of the Gamma Function". SIAM Journal on Numerical Analysis + * series B. Vol. 1: + * lgamma(z + 1) = (log(2) + log(pi)) / 2 + (z + 1/2) * log(t(z)) - t(z) + A(z) + * t(z) = z + kLanczosGamma + 1/2 + * A(z) = kBaseLanczosCoeff + sigma(k = 1, n, kLanczosCoefficients[i] / (z + k)) + */ +template +T Lgamma(const T &input) { + T log_pi = std::log(M_pi); + T log_sqrt_two_pi = (std::log(2) + std::log(M_pi)) / 2; + + /** If the input is less than 0.5 use Euler's reflection formula: + * gamma(x) = pi / (sin(pi * x) * gamma(1 - x)) + */ + bool need_to_reflect = (input < 0.5); + T input_after_reflect = need_to_reflect ? 
-input : input - 1; + T sum = kBaseLanczosCoeff; + for (size_t i = 0, end = kLanczosCoefficients.size(); i < end; ++i) { + T lanczos_coefficient = kLanczosCoefficients[i]; + + sum += lanczos_coefficient / (input_after_reflect + i + 1); + } + T gamma_plus_onehalf_plus_z = kLanczosGamma + 0.5 + input_after_reflect; + T log_t = log_lanczos_gamma_plus_one_half + std::log1pf(input_after_reflect / (kLanczosGamma + 0.5)); + T log_y = log_sqrt_two_pi + (input_after_reflect + 0.5 - gamma_plus_onehalf_plus_z / log_t) * log_t + std::log(sum); + T abs_input = std::abs(input); + T abs_frac_input = abs_input - std::floor(abs_input); + + T reduced_frac_input = (abs_frac_input > 0.5) ? 1 - abs_frac_input : abs_frac_input; + T reflection_denom = std::log(std::sin(M_pi * reduced_frac_input)); + + T reflection = std::isfinite(reflection_denom) ? log_pi - reflection_denom - log_y : -reflection_denom; + T result = need_to_reflect ? reflection : log_y; + + return std::isinf(input) ? std::numeric_limits::infinity() : result; +} + +template +T Digamma(const T &input) { + bool need_to_reflect = (input < 0.5); + T reflected_input = need_to_reflect ? -input : input - 1; + + T num = 0; + T denom = kBaseLanczosCoeff; + + for (size_t i = 0, end = kLanczosCoefficients.size(); i < end; ++i) { + T lanczos_coefficient = kLanczosCoefficients[i]; + num -= lanczos_coefficient / ((reflected_input + i + 1) * (reflected_input + i + 1)); + denom += lanczos_coefficient / (reflected_input + i + 1); + } + + T gamma_plus_onehalf_plus_z = kLanczosGamma + 0.5 + reflected_input; + T log_t = log_lanczos_gamma_plus_one_half + std::log1pf(reflected_input / (kLanczosGamma + 0.5)); + + T result = log_t + num / denom - kLanczosGamma / gamma_plus_onehalf_plus_z; + + T reduced_input = input + std::abs(std::floor(input + 0.5)); + T reflection = result - M_pi * std::cos(M_pi * reduced_input) / std::sin(M_pi * reduced_input); + T real_result = need_to_reflect ? reflection : result; + + // Digamma has poles at negative integers and zero; return nan for those. + return (input < 0 && input == std::floor(input)) ? std::numeric_limits::quiet_NaN() : real_result; +} + +template +T use_igammact(const T &ax, const T &a, const T &x, T enabled, int mode) { + T y = 1 - a; + T z = x + y + 1; + T c = 0; + T pkm2 = 1; + T qkm2 = x; + T pkm1 = x + 1; + T qkm1 = z * x; + T ans = pkm1 / qkm1; + T t = 1; + T dpkm2_da = 0; + T dqkm2_da = 0; + T dpkm1_da = 0; + T dqkm1_da = -x; + T dans_da = (dpkm1_da - ans * dqkm1_da) / qkm1; + std::vector vals = {enabled, ans, t, y, z, c, pkm1, qkm1, + pkm2, qkm2, dpkm2_da, dqkm2_da, dpkm1_da, dqkm1_da, dans_da}; + constexpr int k2000 = 2000; + while (vals[kInputIndex0] && vals[kInputIndex5] < k2000) { + enabled = vals[kInputIndex0]; + ans = vals[kInputIndex1]; + T tmp_var_t = vals[kInputIndex2]; + T tmp_var_y = vals[kInputIndex3]; + T tmp_var_z = vals[kInputIndex4]; + T tmp_var_c = vals[kInputIndex5]; + pkm1 = vals[kInputIndex6]; + qkm1 = vals[kInputIndex7]; + pkm2 = vals[kInputIndex8]; + qkm2 = vals[kInputIndex9]; + dpkm2_da = vals[kInputIndex10]; + dqkm2_da = vals[kInputIndex11]; + dpkm1_da = vals[kInputIndex12]; + dqkm1_da = vals[kInputIndex13]; + dans_da = vals[kInputIndex14]; + tmp_var_c += 1; + tmp_var_y += 1; + constexpr int TWO = 2; + tmp_var_z += TWO; + T yc = tmp_var_y * tmp_var_c; + T pk = pkm1 * tmp_var_z - pkm2 * yc; + T qk = qkm1 * tmp_var_z - qkm2 * yc; + bool qk_is_nonzero = (qk != 0); + T r = pk / qk; + ans = qk_is_nonzero ? 
r : ans; + T dpk_da = dpkm1_da * tmp_var_z - pkm1 - dpkm2_da * yc + pkm2 * tmp_var_c; + T dqk_da = dqkm1_da * tmp_var_z - qkm1 - dqkm2_da * yc + qkm2 * tmp_var_c; + T dans_da_new = qk_is_nonzero ? (dpk_da - ans * dqk_da) / qk : dans_da; + T grad_conditional = qk_is_nonzero ? std::abs(dans_da_new - dans_da) : 1; + pkm2 = pkm1; + pkm1 = pk; + qkm2 = qkm1; + qkm1 = qk; + dpkm2_da = dpkm1_da; + dqkm2_da = dqkm1_da; + dpkm1_da = dpk_da; + dqkm1_da = dqk_da; + bool rescale = std::abs(pk) > (1 / std::numeric_limits::epsilon()); + pkm2 = rescale ? pkm2 * std::numeric_limits::epsilon() : pkm2; + pkm1 = rescale ? pkm1 * std::numeric_limits::epsilon() : pkm1; + qkm2 = rescale ? qkm2 * std::numeric_limits::epsilon() : qkm2; + qkm1 = rescale ? qkm1 * std::numeric_limits::epsilon() : qkm1; + dpkm2_da = rescale ? dpkm2_da * std::numeric_limits::epsilon() : dpkm2_da; + dqkm2_da = rescale ? dqkm2_da * std::numeric_limits::epsilon() : dqkm2_da; + dpkm1_da = rescale ? dpkm1_da * std::numeric_limits::epsilon() : dpkm1_da; + dqkm1_da = rescale ? dqkm1_da * std::numeric_limits::epsilon() : dqkm1_da; + T conditional = enabled && (grad_conditional > std::numeric_limits::epsilon()); + vals[kInputIndex0] = conditional; + vals[kInputIndex5] = tmp_var_c; + if (enabled) { + vals = {conditional, ans, tmp_var_t, tmp_var_y, tmp_var_z, tmp_var_c, pkm1, qkm1, + pkm2, qkm2, dpkm2_da, dqkm2_da, dpkm1_da, dqkm1_da, dans_da_new}; + } + } + ans = vals[kInputIndex1]; + if (mode == VALUE) { + return ans * ax; + } + dans_da = vals[kInputIndex14]; + T dlogax_da = std::log(x) - Digamma(a); + switch (mode) { + case DERIVATIVE: + return ax * (ans * dlogax_da + dans_da); + default: + return -(dans_da + ans * dlogax_da) * x; + } +} + +template +T use_igammacf(const T &ax, const T &a, T x, T enabled) { + std::vector vals = {enabled, a, 1, 1, x, 0, 0}; + while (vals[kInputIndex0] != 0) { + enabled = vals[kInputIndex0]; + T r = vals[kInputIndex1]; + T c = vals[kInputIndex2]; + T ans = vals[kInputIndex3]; + x = vals[kInputIndex4]; + T dc_da = vals[kInputIndex5]; + T dans_da = vals[kInputIndex6]; + r += 1; + dc_da = dc_da * (x / r) + (-1 * c * x) / (r * r); + dans_da = dans_da + dc_da; + c = c * (x / r); + ans = ans + c; + T conditional = enabled && (std::abs(dc_da / dans_da) > std::numeric_limits::epsilon()); + vals[kInputIndex0] = conditional; + if (enabled) { + vals = {conditional, r, c, ans, x, dc_da, dans_da}; + } + } + T ans = vals[kInputIndex3]; + T dans_da = vals[kInputIndex6]; + if (a == 0) { + return NAN; + } + T dlogax_da = std::log(x) - Digamma(a + 1); + return ax * (ans * dlogax_da + dans_da) / a; +} + +template +T IgammaGradASingle(const T &a, const T &x) { + bool is_nan = std::isnan(a) || std::isnan(x); + bool x_is_zero = (x == 0); + bool domain_error = (x < 0) || (a <= 0); + bool use_igammac = (x > 1) && (x > a); + T ax = a * std::log(x) - x - Lgamma(a); + bool underflow = (ax < -std::log(std::numeric_limits::max())); + ax = std::exp(ax); + T enabled = static_cast(!(x_is_zero || domain_error || underflow || is_nan)); + T output; + if (use_igammac != 0) { + enabled = static_cast(enabled && use_igammac); + output = -use_igammact(ax, a, x, enabled, DERIVATIVE); + } else { + enabled = static_cast(enabled && !(use_igammac)); + output = use_igammacf(ax, a, x, enabled); + } + output = (domain_error || is_nan || std::isnan(output)) ? std::numeric_limits::quiet_NaN() : output; + output = x_is_zero || (std::isinf(x) && !is_nan && !domain_error && !std::isinf(a)) ? 
0 : output; + return output; +} + +template +void IgammaGradACpuKernelMod::BcastCompute(const std::vector &inputs, + const std::vector &outputs) { + auto a_data_addr = reinterpret_cast(inputs[0]->addr); + auto x_data_addr = reinterpret_cast(inputs[1]->addr); + auto z_data_addr = reinterpret_cast(outputs[0]->addr); + size_t data_num = get_element_num(z_shape_); + auto output_shape = CPUKernelUtils::GetBroadcastShape(a_shape_, x_shape_); + BroadcastIterator iter(a_shape_, x_shape_, output_shape); + if (data_num < kParallelDataNums) { + iter.SetPos(0); + for (size_t i = 0; i < data_num; i++) { + T *a_index = a_data_addr + iter.GetInputPosA(); // i-th value of input0 + T *x_index = x_data_addr + iter.GetInputPosB(); // i-th value of input1 + *(z_data_addr + i) = IgammaGradASingle(*a_index, *x_index); + iter.GenNextPos(); + } + } else { + auto shard_igammaGradA = [z_data_addr, a_data_addr, x_data_addr, &iter](size_t start, size_t end) { + iter.SetPos(start); + for (size_t i = start; i < end; i++) { + T *a_index = a_data_addr + iter.GetInputPosA(); // i-th value of input0 + T *x_index = x_data_addr + iter.GetInputPosB(); // i-th value of input1 + *(z_data_addr + i) = IgammaGradASingle(*a_index, *x_index); + iter.GenNextPos(); + } + }; + ParallelLaunchAutoSearch(shard_igammaGradA, data_num, this, ¶llel_search_info_); + } +} + +/* special compute is used in the following situations. + * 1. the shapes of input1 and input2 are the same + * 2. input1 is a 1D tensor with only one element or input1 is scalar + * 3. input2 is a 1D tensor with only one element or input2 is scalar + * 4. the shapes of input1 and input2 are different + **/ +template +void IgammaGradACpuKernelMod::SpecialCompute(int64_t type, int64_t start, int64_t end, const T *input1, const T *input2, + T *output) { + switch (type) { + case kSameShape: { + auto cur_input1 = input1 + start; + auto cur_input2 = input2 + start; + for (int64_t i = start; i < end; ++i) { + *output = IgammaGradASingle(*cur_input1, *cur_input2); + output = output + 1; + cur_input1 = cur_input1 + 1; + cur_input2 = cur_input2 + 1; + } + break; + } + case kXOneElement: { + auto cur_input2 = input2 + start; + for (int64_t i = start; i < end; ++i) { + *output = IgammaGradASingle(*input1, *cur_input2); + output = output + 1; + cur_input2 = cur_input2 + 1; + } + break; + } + case kYOneElement: { + auto cur_input1 = input1 + start; + for (int64_t i = start; i < end; ++i) { + *output = IgammaGradASingle(*cur_input1, *input2); + output = output + 1; + cur_input1 = cur_input1 + 1; + } + break; + } + default: + break; + } +} + +template +void IgammaGradACpuKernelMod::NoBcastCompute(const std::vector &inputs, + const std::vector &outputs) { + auto in0 = reinterpret_cast(inputs[0]->addr); + auto in1 = reinterpret_cast(inputs[1]->addr); + auto out0 = reinterpret_cast(outputs[0]->addr); + size_t in0_elements_nums = get_element_num(a_shape_); + size_t in1_elements_nums = get_element_num(x_shape_); + size_t data_num = get_element_num(z_shape_); + int64_t type = + in0_elements_nums == in1_elements_nums ? kSameShape : (in0_elements_nums == 1 ? 
kXOneElement : kYOneElement); + if (data_num < kParallelDataNums) { + SpecialCompute(type, 0, data_num, in0, in1, out0); + } else { + auto shard_igammaGradA = [type, in0, in1, out0, this](int64_t start, int64_t end) { + SpecialCompute(type, start, end, in0, in1, out0 + start); + }; + ParallelLaunchAutoSearch(shard_igammaGradA, data_num, this, ¶llel_search_info_); + } +} + +void IgammaGradACpuKernelMod::InitKernel(const CNodePtr &kernel_node) { + a_shape_ = AnfAlgo::GetInputDeviceShape(kernel_node, 0); + x_shape_ = AnfAlgo::GetInputDeviceShape(kernel_node, 1); + z_shape_ = AnfAlgo::GetOutputDeviceShape(kernel_node, 0); + dtype_ = AnfAlgo::GetInputDeviceDataType(kernel_node, 0); + kernel_name_ = common::AnfAlgo::GetCNodeName(kernel_node); +} + +bool IgammaGradACpuKernelMod::Launch(const std::vector &inputs, + const std::vector &, + const std::vector &outputs) { + if (dtype_ == kNumberTypeFloat32) { + LaunchKernel(inputs, outputs); + } else if (dtype_ == kNumberTypeFloat64) { + LaunchKernel(inputs, outputs); + } else { + MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the dtype of 'var' should be float32 or float64, but got " + << TypeIdToType(dtype_)->ToString(); + } + return true; +} + +template +void IgammaGradACpuKernelMod::LaunchKernel(const std::vector &inputs, + const std::vector &outputs) { + CHECK_KERNEL_INPUTS_NUM(inputs.size(), kInputNum, kernel_name_); + CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kOutputNum, kernel_name_); + size_t in0_elements_nums = get_element_num(a_shape_); + size_t in1_elements_nums = get_element_num(x_shape_); + bool isNeedBcast = (a_shape_ == x_shape_) || (in0_elements_nums == 1) || (in1_elements_nums == 1); + if (isNeedBcast) { + NoBcastCompute(inputs, outputs); + } else { + BcastCompute(inputs, outputs); + } +} + +MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, IgammaGradA, IgammaGradACpuKernelMod); +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/igammagrada_cpu_kernel.h b/mindspore/ccsrc/plugin/device/cpu/kernel/igammagrada_cpu_kernel.h new file mode 100644 index 00000000000..ead85215fb1 --- /dev/null +++ b/mindspore/ccsrc/plugin/device/cpu/kernel/igammagrada_cpu_kernel.h @@ -0,0 +1,69 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_IGAMMAGRADA_CPU_KERNEL_H_ +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_IGAMMAGRADA_CPU_KERNEL_H_ + +#include +#include +#include +#include +#include +#include +#include +#include "plugin/device/cpu/kernel/cpu_kernel.h" +#include "plugin/factory/ms_factory.h" + +namespace mindspore { +namespace kernel { +class IgammaGradACpuKernelMod : public DeprecatedNativeCpuKernelMod { + public: + IgammaGradACpuKernelMod() = default; + ~IgammaGradACpuKernelMod() override = default; + + void InitKernel(const CNodePtr &kernel_node) override; + + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs) override; + + protected: + std::vector GetOpSupport() override { + static std::vector support_list = { + KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), + KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64)}; + return support_list; + } + + private: + std::vector a_shape_; + std::vector x_shape_; + std::vector z_shape_; + TypeId dtype_{kTypeUnknown}; + template + void LaunchKernel(const std::vector &inputs, const std::vector &outputs); + + template + void BcastCompute(const std::vector &, const std::vector &); + + template + void SpecialCompute(int64_t, int64_t, int64_t, const T *, const T *, T *); + + template + void NoBcastCompute(const std::vector &, const std::vector &); +}; +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_IGAMMAGRADA_CPU_KERNEL_H_ diff --git a/mindspore/core/ops/core_ops.h b/mindspore/core/ops/core_ops.h index 37788417477..2906a8fd348 100644 --- a/mindspore/core/ops/core_ops.h +++ b/mindspore/core/ops/core_ops.h @@ -812,6 +812,9 @@ GVAR_DEF(PrimitivePtr, kPrimEinsumGrad, std::make_shared("EinsumGrad" GVAR_DEF(PrimitivePtr, kPrimTrace, std::make_shared("Trace")); GVAR_DEF(PrimitivePtr, kPrimTraceGrad, std::make_shared("TraceGrad")); GVAR_DEF(PrimitivePtr, kPrimZeta, std::make_shared("Zeta")); +GVAR_DEF(PrimitivePtr, kPrimIgamma, std::make_shared("Igamma")); +GVAR_DEF(PrimitivePtr, kPrimIgammac, std::make_shared("Igammac")); +GVAR_DEF(PrimitivePtr, kPrimIgammaGradA, std::make_shared("IgammaGradA")); // Image GVAR_DEF(PrimitivePtr, kPrimNonMaxSuppressionV3, std::make_shared("NonMaxSuppressionV3")); diff --git a/mindspore/core/ops/grad/igammagrada.cc b/mindspore/core/ops/grad/igammagrada.cc new file mode 100644 index 00000000000..faf61932396 --- /dev/null +++ b/mindspore/core/ops/grad/igammagrada.cc @@ -0,0 +1,61 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "ops/grad/igammagrada.h" +#include +#include +#include +#include "abstract/ops/primitive_infer_map.h" +#include "ops/op_utils.h" +#include "utils/check_convert_utils.h" +#include "utils/tensor_construct_utils.h" +#include "mindapi/src/helper.h" + +namespace mindspore { +namespace ops { +namespace { +abstract::ShapePtr IgammaGradAInferShape(const PrimitivePtr &primitive, + const std::vector &input_args) { + auto prim_name = primitive->name(); + return BroadCastInferShape(prim_name, input_args); +} + +TypePtr IgammaGradAInferType(const PrimitivePtr &primitive, const std::vector &input_args) { + auto prim_name = primitive->name(); + auto a_type = input_args[kInputIndex0]->BuildType(); + auto x_type = input_args[kInputIndex1]->BuildType(); + const std::set valid_types = {kFloat32, kFloat64}; + std::map args; + (void)args.insert({"a", a_type}); + (void)args.insert({"x", x_type}); + return CheckAndConvertUtils::CheckTensorTypeSame(args, valid_types, prim_name); +} +} // namespace + +MIND_API_OPERATOR_IMPL(IgammaGradA, BaseOperator); +AbstractBasePtr IgammaGradAInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, + const std::vector &input_args) { + auto prim_name = primitive->name(); + const int64_t kInputNum = 2; + (void)CheckAndConvertUtils::CheckInputArgs(input_args, kGreaterEqual, kInputNum, prim_name); + auto infer_type = IgammaGradAInferType(primitive, input_args); + auto infer_shape = IgammaGradAInferShape(primitive, input_args); + return abstract::MakeAbstract(infer_shape, infer_type); +} + +REGISTER_PRIMITIVE_EVAL_IMPL(IgammaGradA, prim::kPrimIgammaGradA, IgammaGradAInfer, nullptr, true); +} // namespace ops +} // namespace mindspore diff --git a/mindspore/core/ops/grad/igammagrada.h b/mindspore/core/ops/grad/igammagrada.h new file mode 100644 index 00000000000..5615d98e1bd --- /dev/null +++ b/mindspore/core/ops/grad/igammagrada.h @@ -0,0 +1,42 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CORE_OPS_IGAMMAGRADA_H +#define MINDSPORE_CORE_OPS_IGAMMAGRADA_H + +#include +#include + +#include "ops/base_operator.h" +#include "mindapi/base/types.h" + +namespace mindspore { +namespace ops { +constexpr auto kNameIgammaGradA = "IgammaGradA"; +/// \brief Computes the gradient of igamma(a, x) wrt a. +class MIND_API IgammaGradA : public BaseOperator { + public: + MIND_API_BASE_MEMBER(IgammaGradA); + /// \brief Constructor. 
+  IgammaGradA() : BaseOperator(kNameIgammaGradA) { InitIOName({"a", "x"}, {"z"}); }
+};
+
+abstract::AbstractBasePtr IgammaGradAInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                                           const std::vector<AbstractBasePtr> &input_args);
+using kPrimIgammaGradAPtr = std::shared_ptr<IgammaGradA>;
+}  // namespace ops
+}  // namespace mindspore
+#endif  // MINDSPORE_CORE_OPS_IGAMMAGRADA_H
diff --git a/mindspore/core/ops/igamma.cc b/mindspore/core/ops/igamma.cc
new file mode 100644
index 00000000000..57e560bba31
--- /dev/null
+++ b/mindspore/core/ops/igamma.cc
@@ -0,0 +1,59 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "ops/igamma.h"
+#include
+#include
+#include
+#include "abstract/ops/primitive_infer_map.h"
+#include "ops/op_utils.h"
+#include "utils/check_convert_utils.h"
+#include "utils/tensor_construct_utils.h"
+#include "mindapi/src/helper.h"
+
+namespace mindspore {
+namespace ops {
+namespace {
+abstract::ShapePtr IgammaInferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
+  auto prim_name = primitive->name();
+  return BroadCastInferShape(prim_name, input_args);
+}
+
+TypePtr IgammaInferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
+  auto prim_name = primitive->name();
+  auto a_type = input_args[kInputIndex0]->BuildType();
+  auto x_type = input_args[kInputIndex1]->BuildType();
+  const std::set<TypePtr> valid_types = {kFloat32, kFloat64};
+  std::map<std::string, TypePtr> args;
+  (void)args.insert({"a", a_type});
+  (void)args.insert({"x", x_type});
+  return CheckAndConvertUtils::CheckTensorTypeSame(args, valid_types, prim_name);
+}
+}  // namespace
+
+MIND_API_OPERATOR_IMPL(Igamma, BaseOperator);
+AbstractBasePtr IgammaInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                            const std::vector<AbstractBasePtr> &input_args) {
+  auto prim_name = primitive->name();
+  const int64_t kInputNum = 2;
+  (void)CheckAndConvertUtils::CheckInputArgs(input_args, kGreaterEqual, kInputNum, prim_name);
+  auto infer_type = IgammaInferType(primitive, input_args);
+  auto infer_shape = IgammaInferShape(primitive, input_args);
+  return abstract::MakeAbstract(infer_shape, infer_type);
+}
+
+REGISTER_PRIMITIVE_EVAL_IMPL(Igamma, prim::kPrimIgamma, IgammaInfer, nullptr, true);
+}  // namespace ops
+}  // namespace mindspore
diff --git a/mindspore/core/ops/igamma.h b/mindspore/core/ops/igamma.h
new file mode 100644
index 00000000000..c32f7c6d411
--- /dev/null
+++ b/mindspore/core/ops/igamma.h
@@ -0,0 +1,43 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CORE_OPS_IGAMMA_H
+#define MINDSPORE_CORE_OPS_IGAMMA_H
+
+#include
+#include
+
+#include "ops/base_operator.h"
+#include "mindapi/base/types.h"
+
+namespace mindspore {
+namespace ops {
+constexpr auto kNameIgamma = "Igamma";
+/// \brief Calculates lower regularized incomplete Gamma function.
+/// Refer to Python API @ref mindspore.ops.Igamma for more details.
+class MIND_API Igamma : public BaseOperator {
+ public:
+  MIND_API_BASE_MEMBER(Igamma);
+  /// \brief Constructor.
+  Igamma() : BaseOperator(kNameIgamma) { InitIOName({"a", "x"}, {"z"}); }
+};
+
+abstract::AbstractBasePtr IgammaInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                                      const std::vector<AbstractBasePtr> &input_args);
+using kPrimIgammaPtr = std::shared_ptr<Igamma>;
+}  // namespace ops
+}  // namespace mindspore
+#endif  // MINDSPORE_CORE_OPS_IGAMMA_H
diff --git a/mindspore/core/ops/igammac.cc b/mindspore/core/ops/igammac.cc
new file mode 100644
index 00000000000..dfea1d90610
--- /dev/null
+++ b/mindspore/core/ops/igammac.cc
@@ -0,0 +1,60 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ops/igammac.h"
+#include
+#include
+#include
+#include "abstract/ops/primitive_infer_map.h"
+#include "ops/op_utils.h"
+#include "utils/check_convert_utils.h"
+#include "utils/tensor_construct_utils.h"
+#include "mindapi/src/helper.h"
+
+namespace mindspore {
+namespace ops {
+namespace {
+abstract::ShapePtr IgammacInferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
+  auto prim_name = primitive->name();
+  return BroadCastInferShape(prim_name, input_args);
+}
+
+TypePtr IgammacInferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
+  auto prim_name = primitive->name();
+  auto a_type = input_args[kInputIndex0]->BuildType();
+  auto x_type = input_args[kInputIndex1]->BuildType();
+  const std::set<TypePtr> valid_types = {kFloat32, kFloat64};
+  std::map<std::string, TypePtr> args;
+  (void)args.insert({"a", a_type});
+  (void)args.insert({"x", x_type});
+  return CheckAndConvertUtils::CheckTensorTypeSame(args, valid_types, prim_name);
+}
+}  // namespace
+
+MIND_API_OPERATOR_IMPL(Igammac, BaseOperator);
+AbstractBasePtr IgammacInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                             const std::vector<AbstractBasePtr> &input_args) {
+  auto prim_name = primitive->name();
+  const int64_t kInputNum = 2;
+  (void)CheckAndConvertUtils::CheckInputArgs(input_args, kGreaterEqual, kInputNum, prim_name);
+  auto infer_type = IgammacInferType(primitive, input_args);
+  auto infer_shape = IgammacInferShape(primitive, input_args);
+  return abstract::MakeAbstract(infer_shape, infer_type);
+}
+
+REGISTER_PRIMITIVE_EVAL_IMPL(Igammac, prim::kPrimIgammac, IgammacInfer, nullptr, true);
+}  // namespace ops
+}  // namespace mindspore
diff --git a/mindspore/core/ops/igammac.h b/mindspore/core/ops/igammac.h
new file mode 100644
index 00000000000..e08a2918181
--- /dev/null
+++ b/mindspore/core/ops/igammac.h
@@ -0,0 +1,43 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CORE_OPS_IGAMMAC_H
+#define MINDSPORE_CORE_OPS_IGAMMAC_H
+
+#include
+#include
+
+#include "ops/base_operator.h"
+#include "mindapi/base/types.h"
+
+namespace mindspore {
+namespace ops {
+constexpr auto kNameIgammac = "Igammac";
+/// \brief Compute the upper regularized incomplete Gamma function Q(a, x).
+/// Refer to Python API @ref mindspore.ops.Igammac for more details.
+class MIND_API Igammac : public BaseOperator {
+ public:
+  MIND_API_BASE_MEMBER(Igammac);
+  /// \brief Constructor.
+  Igammac() : BaseOperator(kNameIgammac) { InitIOName({"a", "x"}, {"z"}); }
+};
+
+abstract::AbstractBasePtr IgammacInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                                       const std::vector<AbstractBasePtr> &input_args);
+using kPrimIgammacPtr = std::shared_ptr<Igammac>;
+}  // namespace ops
+}  // namespace mindspore
+#endif  // MINDSPORE_CORE_OPS_IGAMMAC_H
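All three operators delegate shape inference to BroadCastInferShape, so `a` and `x` only need broadcast-compatible shapes. As a rough illustration of that rule (a NumPy analogue, not the C++ helper itself):

# Illustration only: a NumPy analogue of broadcast shape inference.
# np.broadcast_shapes raises ValueError for incompatible shapes, which
# mirrors the ValueError documented for these ops further below.
import numpy as np

def broadcast_infer_shape(a_shape, x_shape):
    return np.broadcast_shapes(a_shape, x_shape)

print(broadcast_infer_shape((4, 1), (1, 3)))  # -> (4, 3)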
diff --git a/mindspore/python/mindspore/ops/_grad_experimental/grad_math_ops.py b/mindspore/python/mindspore/ops/_grad_experimental/grad_math_ops.py
index 04603f4c8b8..ac9b8b888e1 100644
--- a/mindspore/python/mindspore/ops/_grad_experimental/grad_math_ops.py
+++ b/mindspore/python/mindspore/ops/_grad_experimental/grad_math_ops.py
@@ -19,14 +19,17 @@ from mindspore.common import dtype as mstype
 from mindspore import nn
 import mindspore.numpy as mnp
 import numpy as np
+from ...nn.layer.math import LGamma
 from .. import functional as F
 from .. import operations as P
 from ..operations.math_ops import Trace
+from ..functional import broadcast_gradient_args
 from .._grad.grad_base import bprop_getters
 from .._grad.grad_math_ops import binop_grad_common
 from ..composite.multitype_ops.zeros_like_impl import zeros_like
 from ..operations import _grad_ops as G
 from ..operations import math_ops as math
+from ..operations.math_ops import Igamma, Igammac
 from ..primitive import constexpr
 from ..operations.math_ops import ReduceStd
@@ -496,3 +499,51 @@ def get_bprop_trace(self):
         return (dx,)
 
     return bprop
+
+
+@bprop_getters.register(Igamma)
+def get_bprop_igamma(self):
+    """Grad definition for `Igamma` operation."""
+    shape_ = P.Shape()
+    igammagrada = G.IgammaGradA()
+    lgamma = LGamma()
+    log_ = P.Log()
+    exp_ = P.Exp()
+    reshape_ = P.Reshape()
+    reduce_sum_ = P.ReduceSum()
+    def bprop(a, x, out, dout):
+        sa = shape_(a)
+        sx = shape_(x)
+        ra, rx = broadcast_gradient_args(sa, sx)
+        partial_a = igammagrada(a, x)
+        partial_x = exp_(-x + (a - 1) * log_(x) - lgamma(a))
+        if ra != () or rx != ():
+            return reshape_(reduce_sum_(partial_a * dout, ra), sa), reshape_(reduce_sum_(partial_x * dout, rx), sx)
+        return reshape_(partial_a * dout, sa), reshape_(partial_x * dout, sx)
+
+    return bprop
+
+
+@bprop_getters.register(Igammac)
+def get_bprop_igammac(self):
+    """Grad definition for `Igammac` operation."""
+    shape_ = P.Shape()
+    igammagrada = G.IgammaGradA()
+    lgamma = LGamma()
+    log_ = P.Log()
+    exp_ = P.Exp()
+    reshape_ = P.Reshape()
+    reduce_sum_ = P.ReduceSum()
+    neg_ = P.Neg()
+    def bprop(a, x, out, dout):
+        sa = shape_(a)
+        sx = shape_(x)
+        ra, rx = broadcast_gradient_args(sa, sx)
+        partial_a = igammagrada(a, x)
+        partial_x = exp_(-x + (a - 1) * log_(x) - lgamma(a))
+        if ra != () or rx != ():
+            return neg_(reshape_(reduce_sum_(partial_a * dout, ra), sa)), \
+                   neg_(reshape_(reduce_sum_(partial_x * dout, rx), sx))
+        return neg_(reshape_(partial_a * dout, sa)), neg_(reshape_(partial_x * dout, sx))
+
+    return bprop
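The closed form used for partial_x in both bprops is d/dx P(a, x) = exp(-x + (a - 1) * log(x) - lgamma(a)), i.e. x^(a-1) * e^(-x) / Gamma(a); Igammac just flips the sign since Q = 1 - P. A quick sanity check of that identity outside MindSpore, with SciPy used only as a reference oracle:

# Sketch: compare the closed-form derivative against a central difference
# of the regularized lower incomplete Gamma function P(a, x).
import numpy as np
from scipy.special import gammainc, gammaln

a = np.array([2.0, 4.0, 6.0], dtype=np.float64)
x = np.array([2.0, 3.0, 4.0], dtype=np.float64)
eps = 1e-6

closed_form = np.exp(-x + (a - 1.0) * np.log(x) - gammaln(a))
finite_diff = (gammainc(a, x + eps) - gammainc(a, x - eps)) / (2.0 * eps)
assert np.allclose(closed_form, finite_diff, rtol=1e-5)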
diff --git a/mindspore/python/mindspore/ops/_op_impl/aicpu/__init__.py b/mindspore/python/mindspore/ops/_op_impl/aicpu/__init__.py
index 839b49f257a..b79cf3aaae4 100644
--- a/mindspore/python/mindspore/ops/_op_impl/aicpu/__init__.py
+++ b/mindspore/python/mindspore/ops/_op_impl/aicpu/__init__.py
@@ -49,6 +49,9 @@ from .asin_grad import _asin_grad_aicpu
 from .is_finite import _is_finite_aicpu
 from .is_inf import _is_inf_aicpu
 from .is_nan import _is_nan_aicpu
+from .igamma import _igamma_aicpu
+from .igammac import _igammac_aicpu
+from .igammagrada import _igammagrada_aicpu
 from .reshape import _reshape_aicpu
 from .fill_v2 import _fill_v2_aicpu
 from .flatten import _flatten_aicpu
diff --git a/mindspore/python/mindspore/ops/_op_impl/aicpu/igamma.py b/mindspore/python/mindspore/ops/_op_impl/aicpu/igamma.py
new file mode 100644
index 00000000000..cf215649c5e
--- /dev/null
+++ b/mindspore/python/mindspore/ops/_op_impl/aicpu/igamma.py
@@ -0,0 +1,30 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""Igamma op"""
+from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType
+igamma_op_info = AiCPURegOp("Igamma") \
+    .fusion_type("OPAQUE") \
+    .input(0, "a", "required") \
+    .input(1, "x", "required") \
+    .output(0, "z", "required") \
+    .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
+    .dtype_format(DataType.F64_Default, DataType.F64_Default, DataType.F64_Default) \
+    .get_op_info()
+
+
+@op_info_register(igamma_op_info)
+def _igamma_aicpu():
+    """Igamma aicpu register"""
+    return
diff --git a/mindspore/python/mindspore/ops/_op_impl/aicpu/igammac.py b/mindspore/python/mindspore/ops/_op_impl/aicpu/igammac.py
new file mode 100644
index 00000000000..eb3f762db37
--- /dev/null
+++ b/mindspore/python/mindspore/ops/_op_impl/aicpu/igammac.py
@@ -0,0 +1,30 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""Igammac op"""
+from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType
+igammac_op_info = AiCPURegOp("Igammac") \
+    .fusion_type("OPAQUE") \
+    .input(0, "a", "required") \
+    .input(1, "x", "required") \
+    .output(0, "z", "required") \
+    .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
+    .dtype_format(DataType.F64_Default, DataType.F64_Default, DataType.F64_Default) \
+    .get_op_info()
+
+
+@op_info_register(igammac_op_info)
+def _igammac_aicpu():
+    """Igammac aicpu register"""
+    return
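The AICPU registrations here and for IgammaGradA below differ only in the op name. A hypothetical helper (not part of this patch) could capture the shared two-input float pattern:

# Hypothetical refactoring sketch only; the patch keeps the three
# registrations spelled out, matching the surrounding op files.
from mindspore.ops.op_info_register import AiCPURegOp, DataType

def binary_float_op_info(op_name):
    """Build the shared (a, x) -> z float32/float64 op info."""
    return AiCPURegOp(op_name) \
        .fusion_type("OPAQUE") \
        .input(0, "a", "required") \
        .input(1, "x", "required") \
        .output(0, "z", "required") \
        .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
        .dtype_format(DataType.F64_Default, DataType.F64_Default, DataType.F64_Default) \
        .get_op_info()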
diff --git a/mindspore/python/mindspore/ops/_op_impl/aicpu/igammagrada.py b/mindspore/python/mindspore/ops/_op_impl/aicpu/igammagrada.py
new file mode 100644
index 00000000000..7feb66f1220
--- /dev/null
+++ b/mindspore/python/mindspore/ops/_op_impl/aicpu/igammagrada.py
@@ -0,0 +1,30 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""IgammaGradA op"""
+from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType
+igammagrada_op_info = AiCPURegOp("IgammaGradA") \
+    .fusion_type("OPAQUE") \
+    .input(0, "a", "required") \
+    .input(1, "x", "required") \
+    .output(0, "z", "required") \
+    .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
+    .dtype_format(DataType.F64_Default, DataType.F64_Default, DataType.F64_Default) \
+    .get_op_info()
+
+
+@op_info_register(igammagrada_op_info)
+def _igammagrada_aicpu():
+    """IgammaGradA aicpu register"""
+    return
diff --git a/mindspore/python/mindspore/ops/operations/_grad_ops.py b/mindspore/python/mindspore/ops/operations/_grad_ops.py
index 0c74456e77d..6a805d26af2 100644
--- a/mindspore/python/mindspore/ops/operations/_grad_ops.py
+++ b/mindspore/python/mindspore/ops/operations/_grad_ops.py
@@ -2474,3 +2474,39 @@ class TraceGrad(Primitive):
     @prim_attr_register
     def __init__(self):
         pass
+
+
+class IgammaGradA(Primitive):
+    r"""
+    Computes the gradient of `igamma(a, x)` with respect to `a`.
+
+    Inputs:
+        - **a** (Tensor) - The input tensor, with float32 or float64 data type.
+        - **x** (Tensor) - The input tensor, with float32 or float64 data type. `x` should have
+          the same dtype as `a`.
+
+    Outputs:
+        Tensor, has the same dtype as `a` and `x`.
+
+    Raises:
+        TypeError: If `a` or `x` is not a Tensor.
+        TypeError: If the dtype of `a` and `x` is neither float32 nor float64.
+        TypeError: If `x` has a different dtype from `a`.
+        ValueError: If `a` could not be broadcast to a tensor with the shape of `x`.
+
+    Supported Platforms:
+        ``Ascend`` ``CPU``
+
+    Examples:
+        >>> a = Tensor(np.array([2.0, 4.0, 6.0, 8.0]).astype(np.float32))
+        >>> x = Tensor(np.array([2.0, 3.0, 4.0, 5.0]).astype(np.float32))
+        >>> igammagrada = G.IgammaGradA()
+        >>> output = igammagrada(a, x)
+        >>> print(output)
+        [-0.2940046 -0.20153049 -0.13028376 -0.08352186]
+    """
+
+    @prim_attr_register
+    def __init__(self):
+        """Initialize IgammaGradA"""
+        self.init_prim_io_names(inputs=['a', 'x'], outputs=['z'])
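IgammaGradA returns the partial derivative of P(a, x) with respect to `a`, which has no simple closed form. The docstring values above can be reproduced to a few decimals with a central difference over SciPy's reference implementation:

# Sketch: reference values for IgammaGradA via a finite difference of
# P(a, x) = gammainc(a, x) with respect to a (SciPy as oracle only).
import numpy as np
from scipy.special import gammainc

a = np.array([2.0, 4.0, 6.0, 8.0], dtype=np.float64)
x = np.array([2.0, 3.0, 4.0, 5.0], dtype=np.float64)
eps = 1e-6

dp_da = (gammainc(a + eps, x) - gammainc(a - eps, x)) / (2.0 * eps)
print(dp_da)  # approx. [-0.294 -0.2015 -0.1303 -0.0835], as in the example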
diff --git a/mindspore/python/mindspore/ops/operations/math_ops.py b/mindspore/python/mindspore/ops/operations/math_ops.py
index e744a946423..45d8bb8adb6 100644
--- a/mindspore/python/mindspore/ops/operations/math_ops.py
+++ b/mindspore/python/mindspore/ops/operations/math_ops.py
@@ -5336,6 +5336,106 @@ class Trunc(Primitive):
         """Initialize Trunc"""
 
 
+class Igamma(Primitive):
+    r"""
+    Calculates the lower regularized incomplete Gamma function.
+
+    The lower regularized incomplete Gamma function is defined as:
+
+    .. math::
+        P(a, x) = \gamma(a, x) / \Gamma(a) = 1 - Q(a, x)
+
+    where
+
+    .. math::
+        \gamma(a, x) = \int_0^x t^{a-1} e^{-t} dt
+
+    is the lower incomplete Gamma function.
+
+    Above, :math:`Q(a, x)` is the upper regularized incomplete Gamma function.
+
+    .. warning::
+        This is an experimental prototype that is subject to change and/or deletion.
+
+    Inputs:
+        - **a** (Tensor) - The input tensor, with float32 or float64 data type.
+        - **x** (Tensor) - The input tensor, with float32 or float64 data type. `x` should have
+          the same dtype as `a`.
+
+    Outputs:
+        Tensor, has the same dtype as `a` and `x`.
+
+    Raises:
+        TypeError: If `a` or `x` is not a Tensor.
+        TypeError: If the dtype of `a` and `x` is neither float32 nor float64.
+        TypeError: If `x` has a different dtype from `a`.
+        ValueError: If `a` could not be broadcast to a tensor with the shape of `x`.
+
+    Supported Platforms:
+        ``Ascend`` ``CPU``
+
+    Examples:
+        >>> a = Tensor(np.array([2.0, 4.0, 6.0, 8.0]).astype(np.float32))
+        >>> x = Tensor(np.array([2.0, 3.0, 4.0, 5.0]).astype(np.float32))
+        >>> igamma = Igamma()
+        >>> output = igamma(a, x)
+        >>> print(output)
+        [0.593994 0.35276785 0.21486944 0.13337152]
+    """
+
+    @prim_attr_register
+    def __init__(self):
+        """Initialize Igamma"""
+        self.init_prim_io_names(inputs=['a', 'x'], outputs=['z'])
+
+
+class Igammac(Primitive):
+    r"""
+    Computes the upper regularized incomplete Gamma function Q(a, x).
+
+    The upper regularized incomplete Gamma function is defined as:
+
+    .. math::
+        Q(a, x) = \Gamma(a, x) / \Gamma(a) = 1 - P(a, x)
+
+    where
+
+    .. math::
+        \Gamma(a, x) = \int_{x}^{\infty} t^{a-1} e^{-t} dt
+
+    is the upper incomplete Gamma function.
+
+    Note that :math:`P(a, x)` (Igamma) above is the lower regularized incomplete Gamma function.
+
+    .. warning::
+        This is an experimental prototype that is subject to change and/or deletion.
+
+    Inputs:
+        - **a** (Tensor) - The input tensor of igammac, with float32 or float64 data type.
+        - **x** (Tensor) - The input tensor of igammac, with float32 or float64 data type. `x` should have
+          the same dtype as `a`.
+
+    Outputs:
+        Tensor, has the same dtype as `a` and `x`.
+
+    Raises:
+        TypeError: If `a` or `x` is not a Tensor.
+        TypeError: If the dtype of `a` and `x` is neither float32 nor float64.
+        TypeError: If `x` has a different dtype from `a`.
+        ValueError: If `a` could not be broadcast to a tensor with the shape of `x`.
+
+    Supported Platforms:
+        ``Ascend`` ``CPU``
+
+    Examples:
+        >>> a = Tensor(np.array([2.0, 4.0, 6.0, 8.0]).astype(np.float32))
+        >>> x = Tensor(np.array([2.0, 3.0, 4.0, 5.0]).astype(np.float32))
+        >>> igammac = Igammac()
+        >>> output = igammac(a, x)
+        >>> print(output)
+        [0.40600586 0.6472318  0.7851304  0.8666283 ]
+    """
+
+    @prim_attr_register
+    def __init__(self):
+        """Initialize Igammac"""
+        self.init_prim_io_names(inputs=['a', 'x'], outputs=['z'])
+
+
 class IsClose(Primitive):
     r"""
     Returns a boolean tensor where two tensors are element-wise equal within a tolerance.
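Since the two primitives are complementary, P(a, x) + Q(a, x) = 1 should hold element-wise. A quick check of the docstring examples (assumes a build where these kernels are available):

# Sketch: verify Igamma and Igammac sum to one on the example inputs.
import numpy as np
from mindspore import Tensor
from mindspore.ops.operations.math_ops import Igamma, Igammac

a = Tensor(np.array([2.0, 4.0, 6.0, 8.0]).astype(np.float32))
x = Tensor(np.array([2.0, 3.0, 4.0, 5.0]).astype(np.float32))
p = Igamma()(a, x)
q = Igammac()(a, x)
assert np.allclose(p.asnumpy() + q.asnumpy(), 1.0, atol=1e-6)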
diff --git a/tests/ut/python/ops/test_math_ops.py b/tests/ut/python/ops/test_math_ops.py
index fa4d31bd064..e0117b88dcc 100755
--- a/tests/ut/python/ops/test_math_ops.py
+++ b/tests/ut/python/ops/test_math_ops.py
@@ -25,8 +25,9 @@ from mindspore.common import dtype as mstype
 from mindspore.ops import composite as C
 from mindspore.ops import operations as P
 from mindspore.ops import functional as F
+from mindspore.ops.operations._grad_ops import IgammaGradA
 from mindspore.ops import prim_attr_register, PrimitiveWithInfer
-from mindspore.ops.operations.math_ops import Zeta
+from mindspore.ops.operations.math_ops import Zeta, Igamma, Igammac
 from ..ut_filter import non_graph_engine
 from ....mindspore_test_framework.mindspore_test import mindspore_test
 from ....mindspore_test_framework.pipeline.forward.compile_forward \
@@ -478,7 +479,25 @@ raise_set = [
     ('Zeta', {
         'block': Zeta(),
         'desc_inputs': [Tensor(np.array([1, 1, 1, 1], np.float32)),
-                        Tensor([0.5, 0.5, 0.5, 0.5], mstype.float32)],
+                        Tensor([0.5, 0.5, 0.5, 0.5], mstype.float32)]}),
+    ('Igamma', {
+        'block': Igamma(),
+        'desc_inputs': [Tensor(np.array([1.1, 2.2, -4.1], np.float32)),
+                        Tensor(np.array([0.2, 1.2, 2.1], np.float32))],
+        'desc_bprop': [Tensor(np.array([2, 3], np.float32)),
+                       Tensor(np.array([2, 3], np.float32))],
+        'skip': ['backward']}),
+    ('Igammac', {
+        'block': Igammac(),
+        'desc_inputs': [Tensor(np.array([1.1, 2.2, -4.1], np.float32)),
+                        Tensor(np.array([0.2, 1.2, 2.1], np.float32))],
+        'desc_bprop': [Tensor(np.array([2, 3], np.float32)),
+                       Tensor(np.array([2, 3], np.float32))],
+        'skip': ['backward']}),
+    ('IgammaGradA', {
+        'block': IgammaGradA(),
+        'desc_inputs': [Tensor(np.array([1.1, 2.2, 8.1, 2.1], np.float32)),
+                        Tensor(np.array([0.2, 1.2, 2.1, 3.4], np.float32))],
         'skip': ['backward']}),
 ]
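The unit tests above only exercise forward compilation ('skip': ['backward']). An end-to-end bprop smoke test along the following lines (assuming a CPU build with the new kernels) would also route through IgammaGradA and the closed-form partial_x:

# Sketch: differentiate Igamma to exercise the registered bprop.
import numpy as np
import mindspore as ms
from mindspore import Tensor, nn
from mindspore.ops.operations.math_ops import Igamma

class IgammaNet(nn.Cell):
    def __init__(self):
        super().__init__()
        self.igamma = Igamma()

    def construct(self, a, x):
        return self.igamma(a, x)

grad_all = ms.ops.GradOperation(get_all=True)
a = Tensor(np.array([2.0, 4.0], np.float32))
x = Tensor(np.array([2.0, 3.0], np.float32))
da, dx = grad_all(IgammaNet())(a, x)  # gradients w.r.t. both inputs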