!33193 [assistant][ops]New operator implementation, include igamma igammac igammagrad

Merge pull request !33193 from ganqijun/ig
This commit is contained in:
i-robot 2022-05-12 03:44:53 +00:00 committed by Gitee
commit 5e5606e82f
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
21 changed files with 1997 additions and 2 deletions

View File

@ -0,0 +1,385 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "plugin/device/cpu/kernel/igamma_cpu_kernel.h"
#include <limits>
#include "plugin/device/cpu/hal/device/cpu_device_address.h"
namespace mindspore {
namespace kernel {
namespace {
/**
* Coefficients for the Lanczos approximation of the gamma function. The
* coefficients are uniquely determined by the choice of g and n (kLanczosGamma
* and kLanczosCoefficients.size() + 1). The coefficients below correspond to
* [7, 9]. [5, 7], [7, 9], [9, 10], and [607/128.0, 15] were evaluated and [7,
* 9] seemed to be the least sensitive to the quality of the log function. In
* particular, [5, 7] is the only choice where -1.5e-5 <= lgamma(2) <= 1.5e-5
* for a particularly inaccurate log function.
* */
// Lanczos parameter g used by Lgamma.
static constexpr double kLanczosGamma = 7;  // aka g
// Small index constants used to address positional slots of state vectors.
constexpr size_t kInputIndex0 = 0;
constexpr size_t kInputIndex1 = 1;
constexpr size_t kInputIndex2 = 2;
constexpr size_t kInputIndex3 = 3;
constexpr size_t kInputIndex4 = 4;
constexpr size_t kInputIndex5 = 5;
constexpr size_t kInputIndex6 = 6;
constexpr size_t kInputIndex7 = 7;
constexpr size_t kInputIndex8 = 8;
constexpr size_t kInputIndex9 = 9;
constexpr size_t kInputIndex10 = 10;
constexpr size_t kInputIndex11 = 11;
constexpr size_t kInputIndex12 = 12;
constexpr size_t kInputIndex13 = 13;
constexpr size_t kInputIndex14 = 14;
// Leading (constant) term of the Lanczos series.
static constexpr double kBaseLanczosCoeff = 0.99999999999980993227684700473478;
// Pi; spelled M_pi so it does not collide with the M_PI macro some <cmath> implementations define.
static constexpr double M_pi = 3.141592653589793238462643383279;
// Remaining Lanczos series coefficients; entry i is divided by (z + i + 1) in Lgamma.
static constexpr std::array<double, 8> kLanczosCoefficients = {
  676.520368121885098567009190444019, -1259.13921672240287047156078755283,
  771.3234287776530788486528258894,   -176.61502916214059906584551354,
  12.507343278686904814458936853,     -0.13857109526572011689554707,
  9.984369578019570859563e-6,         1.50563273514931155834e-7};
// log(g + 1/2). std::log is not constexpr in standard C++17, so this is initialised at load time;
// it is declared const because nothing may modify it afterwards.
const double log_lanczos_gamma_plus_one_half = std::log(kLanczosGamma + 0.5);
// Outputs with fewer elements than this are computed on the calling thread.
constexpr int64_t kParallelDataNums = 256;
// Element-wise computation modes passed to SpecialCompute.
constexpr int64_t kSameShape = 0;    // both inputs have the same number of elements
constexpr int64_t kXOneElement = 1;  // the first input holds a single element
constexpr int64_t kYOneElement = 2;  // the second input holds a single element
// Expected number of kernel inputs / outputs.
constexpr size_t kInputNum = 2;
constexpr size_t kOutputNum = 1;
// Returns the product of all dimensions in `shape`; an empty shape yields 1.
size_t get_element_num(const std::vector<size_t> &shape) {
  size_t count = 1;
  for (const size_t dim : shape) {
    count *= dim;
  }
  return count;
}
} // namespace
/**
 * Compute log-gamma using Lanczos' approximation from "A Precision
 * Approximation of the Gamma Function", SIAM Journal on Numerical Analysis
 * series B, Vol. 1:
 *   lgamma(z + 1) = (log(2) + log(pi)) / 2 + (z + 1/2) * log(t(z)) - t(z) + log(A(z))
 *   t(z) = z + kLanczosGamma + 1/2
 *   A(z) = kBaseLanczosCoeff + sigma(k = 1, n, kLanczosCoefficients[k - 1] / (z + k))
 *
 * Inputs below 0.5 go through Euler's reflection formula. An infinite input
 * (of either sign) returns +infinity.
 *
 * @param input  argument of log-gamma.
 * @return       approximation of lgamma(input) in type T.
 */
template <typename T>
T Lgamma(const T &input) {
  T log_pi = std::log(M_pi);
  T log_sqrt_two_pi = (std::log(2) + std::log(M_pi)) / 2;
  /** If the input is less than 0.5 use Euler's reflection formula:
   * gamma(x) = pi / (sin(pi * x) * gamma(1 - x))
   */
  bool need_to_reflect = (input < 0.5);
  T input_after_reflect = need_to_reflect ? -input : input - 1;
  T sum = kBaseLanczosCoeff;
  for (size_t i = 0, end = kLanczosCoefficients.size(); i < end; ++i) {
    T lanczos_coefficient = kLanczosCoefficients[i];
    sum += lanczos_coefficient / (input_after_reflect + i + 1);
  }
  T gamma_plus_onehalf_plus_z = kLanczosGamma + 0.5 + input_after_reflect;
  // Use the log1p overload that matches T. The previous std::log1pf call forced the
  // argument through float even when T is double, throwing away double precision.
  T log_t = log_lanczos_gamma_plus_one_half +
            std::log1p(static_cast<T>(input_after_reflect / (kLanczosGamma + 0.5)));
  // (z + 1/2) * log(t) - t is written as (z + 1/2 - t / log(t)) * log(t).
  T log_y = log_sqrt_two_pi + (input_after_reflect + 0.5 - gamma_plus_onehalf_plus_z / log_t) * log_t + std::log(sum);
  // Reduce the fractional part of |input| into [0, 0.5] before taking sin(pi * x).
  T abs_input = std::abs(input);
  T abs_frac_input = abs_input - std::floor(abs_input);
  T reduced_frac_input = (abs_frac_input > 0.5) ? 1 - abs_frac_input : abs_frac_input;
  T reflection_denom = std::log(std::sin(M_pi * reduced_frac_input));
  // When log(sin(...)) is not finite, return its negation instead of the reflected value.
  T reflection = std::isfinite(reflection_denom) ? log_pi - reflection_denom - log_y : -reflection_denom;
  T result = need_to_reflect ? reflection : log_y;
  return std::isinf(input) ? std::numeric_limits<T>::infinity() : result;
}
/**
 * Evaluates the continued fraction for the upper incomplete gamma function and
 * returns its complement, 1 - ans * ax.
 *
 * The loop state is held in plain locals: the previous version packed it into a
 * heap-allocated std::vector on every call and also carried derivative terms
 * that never influenced the returned value.
 *
 * @param ax       prefactor multiplied with the continued-fraction value.
 * @param a        shape parameter.
 * @param x        evaluation point.
 * @param enabled  zero skips the iteration entirely, leaving the initial estimate.
 * @return         1 - ans * ax.
 */
template <typename T>
T use_igammact(const T &ax, const T &a, const T &x, T enabled) {
  constexpr int kMaxIterations = 2000;
  const T eps = std::numeric_limits<T>::epsilon();
  T y = 1 - a;
  T z = x + y + 1;
  T c = 0;
  T pkm2 = 1;
  T qkm2 = x;
  T pkm1 = x + 1;
  T qkm1 = z * x;
  T ans = pkm1 / qkm1;
  bool keep_going = static_cast<bool>(enabled);
  while (keep_going && c < kMaxIterations) {
    c += 1;
    y += 1;
    z += 2;
    const T yc = y * c;
    const T pk = pkm1 * z - pkm2 * yc;
    const T qk = qkm1 * z - qkm2 * yc;
    // Relative change of the estimate; stays at 1 (not converged) when qk is zero.
    T t = 1;
    if (qk != 0) {
      const T r = pk / qk;
      t = std::abs((ans - r) / r);
      ans = r;
    }
    pkm2 = pkm1;
    pkm1 = pk;
    qkm2 = qkm1;
    qkm1 = qk;
    // Scale the recurrence terms down once they grow past 1 / epsilon.
    if (std::abs(pk) > (1 / eps)) {
      pkm2 = pkm2 * eps;
      pkm1 = pkm1 * eps;
      qkm2 = qkm2 * eps;
      qkm1 = qkm1 * eps;
    }
    keep_going = (t > eps);
  }
  return 1 - ans * ax;
}
/**
 * Sums the power series for the lower incomplete gamma function and returns
 * (ans * ax) / a.
 *
 * The loop state is held in plain locals: the previous version rebuilt a
 * std::vector of state on every call and tracked derivative terms that never
 * reached the return value.
 *
 * @param ax       prefactor multiplied with the series sum.
 * @param a        shape parameter; a == 0 yields NaN.
 * @param x        evaluation point.
 * @param enabled  zero skips the iteration, leaving the series sum at 1.
 * @return         (ans * ax) / a, or NaN when a == 0.
 */
template <typename T>
T use_igammacf(T ax, T a, T x, T enabled) {
  const T eps = std::numeric_limits<T>::epsilon();
  T r = a;
  T c = 1;
  T ans = 1;
  bool keep_going = (enabled != 0);
  while (keep_going) {
    r += 1;
    c = c * (x / r);
    ans = ans + c;
    // Stop once the newest term no longer changes the sum at machine precision.
    keep_going = (c / ans > eps);
  }
  if (a == 0) {
    return NAN;
  }
  return (ans * ax) / a;
}
template <typename T>
T IgammaSingle(const T &a, const T &x) {
if (!std::isinf(a) && (a > 0) && std::isinf(x) && x > 0) {
return 1;
}
bool is_nan = std::isnan(a) || std::isnan(x);
bool x_is_zero = (x == 0);
bool domain_error = (x < 0) || (a <= 0);
bool use_igammac = (x > 1) && (x > a);
T ax = a * std::log(x) - x - Lgamma<T>(a);
bool underflow = (ax < -std::log(std::numeric_limits<T>::max()));
ax = std::exp(ax);
T enabled = static_cast<T>(!(x_is_zero || domain_error || underflow || is_nan));
T output;
if (use_igammac != 0) {
enabled = static_cast<T>(enabled && use_igammac);
output = use_igammact(ax, a, x, enabled);
} else {
enabled = static_cast<T>(enabled && !(use_igammac));
output = use_igammacf(ax, a, x, enabled);
}
output = (domain_error || is_nan || std::isnan(output)) ? std::numeric_limits<double>::quiet_NaN() : output;
output = x_is_zero ? 0 : output;
return output;
}
/**
 * Computes Igamma element-wise when the two inputs need broadcasting.
 *
 * Input positions for each output index come from a BroadcastIterator. Every
 * shard works on its own copy of the iterator: the previous version captured a
 * single iterator by reference and called SetPos/GenNextPos on it from every
 * parallel shard, which is a data race.
 *
 * @param inputs   inputs[0] holds `a`, inputs[1] holds `x`.
 * @param outputs  outputs[0] receives the result.
 */
template <typename T>
void IgammaCpuKernelMod::BcastCompute(const std::vector<kernel::AddressPtr> &inputs,
                                      const std::vector<kernel::AddressPtr> &outputs) {
  auto a_data_addr = reinterpret_cast<T *>(inputs[0]->addr);
  auto x_data_addr = reinterpret_cast<T *>(inputs[1]->addr);
  auto z_data_addr = reinterpret_cast<T *>(outputs[0]->addr);
  size_t data_num = get_element_num(z_shape_);
  auto output_shape = CPUKernelUtils::GetBroadcastShape(a_shape_, x_shape_);
  BroadcastIterator base_iter(a_shape_, x_shape_, output_shape);
  auto shard_igamma = [z_data_addr, a_data_addr, x_data_addr, &base_iter](size_t start, size_t end) {
    auto iter = base_iter;  // private copy so concurrent shards do not share a cursor
    iter.SetPos(start);
    for (size_t i = start; i < end; i++) {
      T *a_index = a_data_addr + iter.GetInputPosA();  // i-th value of input0
      T *x_index = x_data_addr + iter.GetInputPosB();  // i-th value of input1
      *(z_data_addr + i) = IgammaSingle<T>(*a_index, *x_index);
      iter.GenNextPos();
    }
  };
  if (data_num < kParallelDataNums) {
    // Small outputs are computed on the calling thread.
    shard_igamma(0, data_num);
  } else {
    ParallelLaunchAutoSearch(shard_igamma, data_num, this, &parallel_search_info_);
  }
}
/* Element-wise compute for the cases that need no broadcast iterator.
 * `type` selects how the inputs are indexed over [start, end):
 *   kSameShape   - both inputs are read at index i
 *   kXOneElement - input1 is a single value, input2 is read at index i
 *   kYOneElement - input1 is read at index i, input2 is a single value
 * Any other `type` does nothing. Results are written consecutively starting
 * at `output`, i.e. output[0] corresponds to index `start`.
 **/
template <typename T>
void IgammaCpuKernelMod::SpecialCompute(int64_t type, int64_t start, int64_t end, const T *input1, const T *input2,
                                        T *output) {
  const int64_t count = end - start;
  if (type == kSameShape) {
    for (int64_t k = 0; k < count; ++k) {
      output[k] = IgammaSingle<T>(input1[start + k], input2[start + k]);
    }
  } else if (type == kXOneElement) {
    for (int64_t k = 0; k < count; ++k) {
      output[k] = IgammaSingle<T>(*input1, input2[start + k]);
    }
  } else if (type == kYOneElement) {
    for (int64_t k = 0; k < count; ++k) {
      output[k] = IgammaSingle<T>(input1[start + k], *input2);
    }
  }
}
/**
 * Computes Igamma when no broadcast iterator is required: the inputs have the
 * same element count, or one of them holds exactly one element.
 * Outputs with fewer than kParallelDataNums elements are computed inline;
 * larger ones are split into shards via ParallelLaunchAutoSearch.
 */
template <typename T>
void IgammaCpuKernelMod::NoBcastCompute(const std::vector<kernel::AddressPtr> &inputs,
                                        const std::vector<kernel::AddressPtr> &outputs) {
  auto a_ptr = reinterpret_cast<T *>(inputs[0]->addr);
  auto x_ptr = reinterpret_cast<T *>(inputs[1]->addr);
  auto z_ptr = reinterpret_cast<T *>(outputs[0]->addr);
  const size_t a_count = get_element_num(a_shape_);
  const size_t x_count = get_element_num(x_shape_);
  const size_t total = get_element_num(z_shape_);

  // Choose the indexing mode understood by SpecialCompute.
  int64_t type = kYOneElement;
  if (a_count == x_count) {
    type = kSameShape;
  } else if (a_count == 1) {
    type = kXOneElement;
  }

  if (total >= kParallelDataNums) {
    auto shard_igamma = [type, a_ptr, x_ptr, z_ptr, this](int64_t start, int64_t end) {
      // Each shard writes to the slice of the output that begins at `start`.
      SpecialCompute<T>(type, start, end, a_ptr, x_ptr, z_ptr + start);
    };
    ParallelLaunchAutoSearch(shard_igamma, total, this, &parallel_search_info_);
    return;
  }
  SpecialCompute<T>(type, 0, total, a_ptr, x_ptr, z_ptr);
}
// Caches what Launch needs from the graph node: the device shapes of input 0 (a),
// input 1 (x) and output 0, the device data type of input 0, and the node's name
// (used in error messages).
void IgammaCpuKernelMod::InitKernel(const CNodePtr &kernel_node) {
a_shape_ = AnfAlgo::GetInputDeviceShape(kernel_node, 0);
x_shape_ = AnfAlgo::GetInputDeviceShape(kernel_node, 1);
z_shape_ = AnfAlgo::GetOutputDeviceShape(kernel_node, 0);
// Only input 0's dtype is stored; Launch dispatches on it.
dtype_ = AnfAlgo::GetInputDeviceDataType(kernel_node, 0);
kernel_name_ = common::AnfAlgo::GetCNodeName(kernel_node);
}
/**
 * Dispatches to the typed implementation according to the dtype recorded by
 * InitKernel (the dtype of input 0, `a`). The workspace argument is unused.
 *
 * The error message previously referred to an input called 'var', which this
 * operator does not have; it now names input 'a'.
 *
 * @return true after the typed kernel has run; unsupported dtypes are reported
 *         through MS_LOG(EXCEPTION).
 */
bool IgammaCpuKernelMod::Launch(const std::vector<kernel::AddressPtr> &inputs, const std::vector<kernel::AddressPtr> &,
                                const std::vector<kernel::AddressPtr> &outputs) {
  if (dtype_ == kNumberTypeFloat32) {
    LaunchKernel<float>(inputs, outputs);
  } else if (dtype_ == kNumberTypeFloat64) {
    LaunchKernel<double>(inputs, outputs);
  } else {
    MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the dtype of 'a' should be float32 or float64, but got "
                      << TypeIdToType(dtype_)->ToString();
  }
  return true;
}
/**
 * Validates the input/output counts, then picks the compute path:
 * the direct element-wise path when the shapes are equal or either input has a
 * single element, and the broadcast-iterator path otherwise.
 */
template <typename T>
void IgammaCpuKernelMod::LaunchKernel(const std::vector<kernel::AddressPtr> &inputs,
                                      const std::vector<kernel::AddressPtr> &outputs) {
  CHECK_KERNEL_INPUTS_NUM(inputs.size(), kInputNum, kernel_name_);
  CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kOutputNum, kernel_name_);
  const size_t a_count = get_element_num(a_shape_);
  const size_t x_count = get_element_num(x_shape_);
  // True when no broadcast iterator is needed.
  const bool direct_indexing = (a_shape_ == x_shape_) || (a_count == 1) || (x_count == 1);
  if (!direct_indexing) {
    BcastCompute<T>(inputs, outputs);
    return;
  }
  NoBcastCompute<T>(inputs, outputs);
}
// NOTE(review): presumably registers IgammaCpuKernelMod with the NativeCpuKernelMod factory
// under the key "Igamma" — confirm against the macro definition in ms_factory.h.
MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, Igamma, IgammaCpuKernelMod);
} // namespace kernel
} // namespace mindspore

View File

@ -0,0 +1,70 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_IGAMMA_CPU_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_IGAMMA_CPU_KERNEL_H_
#include <vector>
#include <array>
#include <cmath>
#include <iostream>
#include <limits>
#include <map>
#include <string>
#include <tuple>
#include "plugin/device/cpu/kernel/cpu_kernel.h"
#include "plugin/factory/ms_factory.h"
namespace mindspore {
namespace kernel {
// CPU kernel for the Igamma operator. Takes two inputs (a, x) and produces one
// output; GetOpSupport below lists float32 and float64 as the supported dtypes.
class IgammaCpuKernelMod : public DeprecatedNativeCpuKernelMod {
public:
IgammaCpuKernelMod() = default;
~IgammaCpuKernelMod() override = default;
// Caches shapes, dtype and kernel name from the graph node.
void InitKernel(const CNodePtr &kernel_node) override;
// Runs the kernel on the given buffers, dispatching on the cached dtype.
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs) override;
protected:
// Supported signatures: (float32, float32) -> float32 and (float64, float64) -> float64.
std::vector<KernelAttr> GetOpSupport() override {
static std::vector<KernelAttr> support_list = {
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64)};
return support_list;
}
private:
// Device shapes of input 0 (a), input 1 (x) and output 0 (z).
std::vector<size_t> a_shape_;
std::vector<size_t> x_shape_;
std::vector<size_t> z_shape_;
// Device dtype of input 0; kTypeUnknown until InitKernel runs.
TypeId dtype_{kTypeUnknown};
// Typed entry point: checks argument counts and selects the compute path.
template <typename T>
void LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs);
// Element-wise compute through a broadcast iterator (arguments: inputs, outputs).
template <typename T>
void BcastCompute(const std::vector<AddressPtr> &, const std::vector<AddressPtr> &);
// Element-wise compute over [start, end) (arguments: type, start, end, input1, input2, output).
template <typename T>
void SpecialCompute(int64_t, int64_t, int64_t, const T *, const T *, T *);
// Element-wise compute without a broadcast iterator (arguments: inputs, outputs).
template <typename T>
void NoBcastCompute(const std::vector<AddressPtr> &, const std::vector<AddressPtr> &);
};
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_IGAMMA_CPU_KERNEL_H_

View File

@ -0,0 +1,379 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "plugin/device/cpu/kernel/igammac_cpu_kernel.h"
#include "plugin/device/cpu/hal/device/cpu_device_address.h"
namespace mindspore {
namespace kernel {
namespace {
/**
* Coefficients for the Lanczos approximation of the gamma function. The
* coefficients are uniquely determined by the choice of g and n (kLanczosGamma
* and kLanczosCoefficients.size() + 1). The coefficients below correspond to
* [7, 9]. [5, 7], [7, 9], [9, 10], and [607/128.0, 15] were evaluated and [7,
* 9] seemed to be the least sensitive to the quality of the log function. In
* particular, [5, 7] is the only choice where -1.5e-5 <= lgamma(2) <= 1.5e-5
* for a particularly inaccurate log function.
* */
// Lanczos parameter g used by Lgamma.
static constexpr double kLanczosGamma = 7;  // aka g
// Small index constants used to address positional slots of state vectors.
constexpr size_t kInputIndex0 = 0;
constexpr size_t kInputIndex1 = 1;
constexpr size_t kInputIndex2 = 2;
constexpr size_t kInputIndex3 = 3;
constexpr size_t kInputIndex4 = 4;
constexpr size_t kInputIndex5 = 5;
constexpr size_t kInputIndex6 = 6;
constexpr size_t kInputIndex7 = 7;
constexpr size_t kInputIndex8 = 8;
constexpr size_t kInputIndex9 = 9;
constexpr size_t kInputIndex10 = 10;
constexpr size_t kInputIndex11 = 11;
constexpr size_t kInputIndex12 = 12;
constexpr size_t kInputIndex13 = 13;
constexpr size_t kInputIndex14 = 14;
// Leading (constant) term of the Lanczos series.
static constexpr double kBaseLanczosCoeff = 0.99999999999980993227684700473478;
// Pi; spelled M_pi so it does not collide with the M_PI macro some <cmath> implementations define.
static constexpr double M_pi = 3.141592653589793238462643383279;
// Remaining Lanczos series coefficients; entry i is divided by (z + i + 1) in Lgamma.
static constexpr std::array<double, 8> kLanczosCoefficients = {
  676.520368121885098567009190444019, -1259.13921672240287047156078755283,
  771.3234287776530788486528258894,   -176.61502916214059906584551354,
  12.507343278686904814458936853,     -0.13857109526572011689554707,
  9.984369578019570859563e-6,         1.50563273514931155834e-7};
// log(g + 1/2). std::log is not constexpr in standard C++17, so this is initialised at load time;
// it is declared const because nothing may modify it afterwards.
const double log_lanczos_gamma_plus_one_half = std::log(kLanczosGamma + 0.5);
// Outputs with fewer elements than this are computed on the calling thread.
constexpr int64_t kParallelDataNums = 256;
// Element-wise computation modes passed to SpecialCompute.
constexpr int64_t kSameShape = 0;    // both inputs have the same number of elements
constexpr int64_t kXOneElement = 1;  // the first input holds a single element
constexpr int64_t kYOneElement = 2;  // the second input holds a single element
// Expected number of kernel inputs / outputs.
constexpr size_t kInputNum = 2;
constexpr size_t kOutputNum = 1;
// Returns the product of all dimensions in `shape`; an empty shape yields 1.
size_t get_element_num(const std::vector<size_t> &shape) {
  size_t count = 1;
  for (const size_t dim : shape) {
    count *= dim;
  }
  return count;
}
} // namespace
/**
 * Compute log-gamma using Lanczos' approximation from "A Precision
 * Approximation of the Gamma Function", SIAM Journal on Numerical Analysis
 * series B, Vol. 1:
 *   lgamma(z + 1) = (log(2) + log(pi)) / 2 + (z + 1/2) * log(t(z)) - t(z) + log(A(z))
 *   t(z) = z + kLanczosGamma + 1/2
 *   A(z) = kBaseLanczosCoeff + sigma(k = 1, n, kLanczosCoefficients[k - 1] / (z + k))
 *
 * Inputs below 0.5 go through Euler's reflection formula. An infinite input
 * (of either sign) returns +infinity.
 *
 * @param input  argument of log-gamma.
 * @return       approximation of lgamma(input) in type T.
 */
template <typename T>
T Lgamma(const T &input) {
  T log_pi = std::log(M_pi);
  T log_sqrt_two_pi = (std::log(2) + std::log(M_pi)) / 2;
  /** If the input is less than 0.5 use Euler's reflection formula:
   * gamma(x) = pi / (sin(pi * x) * gamma(1 - x))
   */
  bool need_to_reflect = (input < 0.5);
  T input_after_reflect = need_to_reflect ? -input : input - 1;
  T sum = kBaseLanczosCoeff;
  for (size_t i = 0, end = kLanczosCoefficients.size(); i < end; ++i) {
    T lanczos_coefficient = kLanczosCoefficients[i];
    sum += lanczos_coefficient / (input_after_reflect + i + 1);
  }
  T gamma_plus_onehalf_plus_z = kLanczosGamma + 0.5 + input_after_reflect;
  // Use the log1p overload that matches T. The previous std::log1pf call forced the
  // argument through float even when T is double, throwing away double precision.
  T log_t = log_lanczos_gamma_plus_one_half +
            std::log1p(static_cast<T>(input_after_reflect / (kLanczosGamma + 0.5)));
  // (z + 1/2) * log(t) - t is written as (z + 1/2 - t / log(t)) * log(t).
  T log_y = log_sqrt_two_pi + (input_after_reflect + 0.5 - gamma_plus_onehalf_plus_z / log_t) * log_t + std::log(sum);
  // Reduce the fractional part of |input| into [0, 0.5] before taking sin(pi * x).
  T abs_input = std::abs(input);
  T abs_frac_input = abs_input - std::floor(abs_input);
  T reduced_frac_input = (abs_frac_input > 0.5) ? 1 - abs_frac_input : abs_frac_input;
  T reflection_denom = std::log(std::sin(M_pi * reduced_frac_input));
  // When log(sin(...)) is not finite, return its negation instead of the reflected value.
  T reflection = std::isfinite(reflection_denom) ? log_pi - reflection_denom - log_y : -reflection_denom;
  T result = need_to_reflect ? reflection : log_y;
  return std::isinf(input) ? std::numeric_limits<T>::infinity() : result;
}
/**
 * Evaluates the continued fraction for the upper incomplete gamma function and
 * returns ans * ax.
 *
 * The loop state is held in plain locals: the previous version packed it into a
 * heap-allocated std::vector on every call and also carried derivative terms
 * that never influenced the returned value.
 *
 * @param ax       prefactor multiplied with the continued-fraction value.
 * @param a        shape parameter.
 * @param x        evaluation point.
 * @param enabled  zero skips the iteration entirely, leaving the initial estimate.
 * @return         ans * ax.
 */
template <typename T>
T use_igammaf(const T &ax, const T &a, const T &x, T enabled) {
  constexpr int kMaxIterations = 2000;
  const T eps = std::numeric_limits<T>::epsilon();
  T y = 1 - a;
  T z = x + y + 1;
  T c = 0;
  T pkm2 = 1;
  T qkm2 = x;
  T pkm1 = x + 1;
  T qkm1 = z * x;
  T ans = pkm1 / qkm1;
  bool keep_going = static_cast<bool>(enabled);
  while (keep_going && c < kMaxIterations) {
    c += 1;
    y += 1;
    z += 2;
    const T yc = y * c;
    const T pk = pkm1 * z - pkm2 * yc;
    const T qk = qkm1 * z - qkm2 * yc;
    // Relative change of the estimate; stays at 1 (not converged) when qk is zero.
    T t = 1;
    if (qk != 0) {
      const T r = pk / qk;
      t = std::abs((ans - r) / r);
      ans = r;
    }
    pkm2 = pkm1;
    pkm1 = pk;
    qkm2 = qkm1;
    qkm1 = qk;
    // Scale the recurrence terms down once they grow past 1 / epsilon.
    if (std::abs(pk) > (1 / eps)) {
      pkm2 = pkm2 * eps;
      pkm1 = pkm1 * eps;
      qkm2 = qkm2 * eps;
      qkm1 = qkm1 * eps;
    }
    keep_going = (t > eps);
  }
  return ans * ax;
}
/**
 * Sums the power series for the lower incomplete gamma function and returns
 * its complement, 1 - (ans * ax) / a.
 *
 * The loop state is held in plain locals: the previous version rebuilt a
 * std::vector of state on every call and tracked derivative terms that never
 * reached the return value.
 *
 * @param ax       prefactor multiplied with the series sum.
 * @param a        shape parameter; a == 0 yields NaN.
 * @param x        evaluation point.
 * @param enabled  zero skips the iteration, leaving the series sum at 1.
 * @return         1 - (ans * ax) / a, or NaN when a == 0.
 */
template <typename T>
T use_igammat(T ax, T a, T x, T enabled) {
  const T eps = std::numeric_limits<T>::epsilon();
  T r = a;
  T c = 1;
  T ans = 1;
  bool keep_going = (enabled != 0);
  while (keep_going) {
    r += 1;
    c = c * (x / r);
    ans = ans + c;
    // Stop once the newest term no longer changes the sum at machine precision.
    keep_going = (c / ans > eps);
  }
  if (a == 0) {
    return NAN;
  }
  return 1 - (ans * ax) / a;
}
/**
 * Regularized upper incomplete gamma function Q(a, x) for one pair of scalars.
 *
 * The series helper is used when x < 1 or x < a, the continued-fraction helper
 * otherwise. After the numeric evaluation the result is overridden, in order:
 *   - x <= 0 or a <= 0                              -> 1
 *   - x < 0, a <= 0, NaN input, or x == +inf        -> NaN
 *   - x == +inf and a > 0                           -> 0
 */
template <typename T>
T IgammacSingle(const T &a, const T &x) {
  const bool out_of_range = (x <= 0) || (a <= 0);
  const bool use_igamma = (x < 1) || (x < a);
  T ax = a * std::log(x) - x - Lgamma(a);
  // The helpers are disabled when exp(ax) would underflow or the inputs are out of range.
  const bool underflow = (ax < -std::log(std::numeric_limits<T>::max()));
  const T enabled = static_cast<T>(!(out_of_range || underflow));
  ax = std::exp(ax);

  T output = use_igamma ? use_igammat(ax, a, x, enabled) : use_igammaf(ax, a, x, enabled);

  if (out_of_range) {
    output = 1;
  }
  const bool x_is_pos_inf = std::isinf(x) && (x > 0);
  if (x < 0 || a <= 0 || std::isnan(x) || x_is_pos_inf || std::isnan(a)) {
    output = std::numeric_limits<T>::quiet_NaN();
  }
  if (x_is_pos_inf && a > 0) {
    output = 0;
  }
  return output;
}
/**
 * Computes Igammac element-wise when the two inputs need broadcasting.
 *
 * Input positions for each output index come from a BroadcastIterator. Every
 * shard works on its own copy of the iterator: the previous version captured a
 * single iterator by reference and called SetPos/GenNextPos on it from every
 * parallel shard, which is a data race.
 *
 * @param inputs   inputs[0] holds `a`, inputs[1] holds `x`.
 * @param outputs  outputs[0] receives the result.
 */
template <typename T>
void IgammacCpuKernelMod::BcastCompute(const std::vector<kernel::AddressPtr> &inputs,
                                       const std::vector<kernel::AddressPtr> &outputs) {
  auto a_data_addr = reinterpret_cast<T *>(inputs[0]->addr);
  auto x_data_addr = reinterpret_cast<T *>(inputs[1]->addr);
  auto z_data_addr = reinterpret_cast<T *>(outputs[0]->addr);
  size_t data_num = get_element_num(z_shape_);
  auto output_shape = CPUKernelUtils::GetBroadcastShape(a_shape_, x_shape_);
  BroadcastIterator base_iter(a_shape_, x_shape_, output_shape);
  auto shard_igammac = [z_data_addr, a_data_addr, x_data_addr, &base_iter](size_t start, size_t end) {
    auto iter = base_iter;  // private copy so concurrent shards do not share a cursor
    iter.SetPos(start);
    for (size_t i = start; i < end; i++) {
      T *a_index = a_data_addr + iter.GetInputPosA();  // i-th value of input0
      T *x_index = x_data_addr + iter.GetInputPosB();  // i-th value of input1
      *(z_data_addr + i) = IgammacSingle<T>(*a_index, *x_index);
      iter.GenNextPos();
    }
  };
  if (data_num < kParallelDataNums) {
    // Small outputs are computed on the calling thread.
    shard_igammac(0, data_num);
  } else {
    ParallelLaunchAutoSearch(shard_igammac, data_num, this, &parallel_search_info_);
  }
}
/* Element-wise compute for the cases that need no broadcast iterator.
 * `type` selects how the inputs are indexed over [start, end):
 *   kSameShape   - both inputs are read at index i
 *   kXOneElement - input1 is a single value, input2 is read at index i
 *   kYOneElement - input1 is read at index i, input2 is a single value
 * Any other `type` does nothing. Results are written consecutively starting
 * at `output`, i.e. output[0] corresponds to index `start`.
 **/
template <typename T>
void IgammacCpuKernelMod::SpecialCompute(int64_t type, int64_t start, int64_t end, const T *input1, const T *input2,
                                         T *output) {
  const int64_t count = end - start;
  if (type == kSameShape) {
    for (int64_t k = 0; k < count; ++k) {
      output[k] = IgammacSingle<T>(input1[start + k], input2[start + k]);
    }
  } else if (type == kXOneElement) {
    for (int64_t k = 0; k < count; ++k) {
      output[k] = IgammacSingle<T>(*input1, input2[start + k]);
    }
  } else if (type == kYOneElement) {
    for (int64_t k = 0; k < count; ++k) {
      output[k] = IgammacSingle<T>(input1[start + k], *input2);
    }
  }
}
/**
 * Computes Igammac when no broadcast iterator is required: the inputs have the
 * same element count, or one of them holds exactly one element.
 * Outputs with fewer than kParallelDataNums elements are computed inline;
 * larger ones are split into shards via ParallelLaunchAutoSearch.
 */
template <typename T>
void IgammacCpuKernelMod::NoBcastCompute(const std::vector<kernel::AddressPtr> &inputs,
                                         const std::vector<kernel::AddressPtr> &outputs) {
  auto a_ptr = reinterpret_cast<T *>(inputs[0]->addr);
  auto x_ptr = reinterpret_cast<T *>(inputs[1]->addr);
  auto z_ptr = reinterpret_cast<T *>(outputs[0]->addr);
  const size_t a_count = get_element_num(a_shape_);
  const size_t x_count = get_element_num(x_shape_);
  const size_t total = get_element_num(z_shape_);

  // Choose the indexing mode understood by SpecialCompute.
  int64_t type = kYOneElement;
  if (a_count == x_count) {
    type = kSameShape;
  } else if (a_count == 1) {
    type = kXOneElement;
  }

  if (total >= kParallelDataNums) {
    auto shard_igammac = [type, a_ptr, x_ptr, z_ptr, this](int64_t start, int64_t end) {
      // Each shard writes to the slice of the output that begins at `start`.
      SpecialCompute<T>(type, start, end, a_ptr, x_ptr, z_ptr + start);
    };
    ParallelLaunchAutoSearch(shard_igammac, total, this, &parallel_search_info_);
    return;
  }
  SpecialCompute<T>(type, 0, total, a_ptr, x_ptr, z_ptr);
}
// Caches what Launch needs from the graph node: the device shapes of input 0 (a),
// input 1 (x) and output 0, the device data type of input 0, and the node's name
// (used in error messages).
void IgammacCpuKernelMod::InitKernel(const CNodePtr &kernel_node) {
a_shape_ = AnfAlgo::GetInputDeviceShape(kernel_node, 0);
x_shape_ = AnfAlgo::GetInputDeviceShape(kernel_node, 1);
z_shape_ = AnfAlgo::GetOutputDeviceShape(kernel_node, 0);
// Only input 0's dtype is stored; Launch dispatches on it.
dtype_ = AnfAlgo::GetInputDeviceDataType(kernel_node, 0);
kernel_name_ = common::AnfAlgo::GetCNodeName(kernel_node);
}
/**
 * Dispatches to the typed implementation according to the dtype recorded by
 * InitKernel (the dtype of input 0, `a`). The workspace argument is unused.
 *
 * The error message previously referred to an input called 'var', which this
 * operator does not have; it now names input 'a'.
 *
 * @return true after the typed kernel has run; unsupported dtypes are reported
 *         through MS_LOG(EXCEPTION).
 */
bool IgammacCpuKernelMod::Launch(const std::vector<kernel::AddressPtr> &inputs, const std::vector<kernel::AddressPtr> &,
                                 const std::vector<kernel::AddressPtr> &outputs) {
  if (dtype_ == kNumberTypeFloat32) {
    LaunchKernel<float>(inputs, outputs);
  } else if (dtype_ == kNumberTypeFloat64) {
    LaunchKernel<double>(inputs, outputs);
  } else {
    MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the dtype of 'a' should be float32 or float64, but got "
                      << TypeIdToType(dtype_)->ToString();
  }
  return true;
}
/**
 * Validates the input/output counts, then picks the compute path:
 * the direct element-wise path when the shapes are equal or either input has a
 * single element, and the broadcast-iterator path otherwise.
 */
template <typename T>
void IgammacCpuKernelMod::LaunchKernel(const std::vector<kernel::AddressPtr> &inputs,
                                       const std::vector<kernel::AddressPtr> &outputs) {
  CHECK_KERNEL_INPUTS_NUM(inputs.size(), kInputNum, kernel_name_);
  CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kOutputNum, kernel_name_);
  const size_t a_count = get_element_num(a_shape_);
  const size_t x_count = get_element_num(x_shape_);
  // True when no broadcast iterator is needed.
  const bool direct_indexing = (a_shape_ == x_shape_) || (a_count == 1) || (x_count == 1);
  if (!direct_indexing) {
    BcastCompute<T>(inputs, outputs);
    return;
  }
  NoBcastCompute<T>(inputs, outputs);
}
// NOTE(review): presumably registers IgammacCpuKernelMod with the NativeCpuKernelMod factory
// under the key "Igammac" — confirm against the macro definition in ms_factory.h.
MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, Igammac, IgammacCpuKernelMod);
} // namespace kernel
} // namespace mindspore

View File

@ -0,0 +1,69 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_IGAMMAC_CPU_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_IGAMMAC_CPU_KERNEL_H_
#include <cmath>
#include <vector>
#include <array>
#include <iostream>
#include <limits>
#include <map>
#include <string>
#include <tuple>
#include "plugin/device/cpu/kernel/cpu_kernel.h"
#include "plugin/factory/ms_factory.h"
namespace mindspore {
namespace kernel {
// CPU kernel for the Igammac operator. Takes two inputs (a, x) and produces one
// output; GetOpSupport below lists float32 and float64 as the supported dtypes.
class IgammacCpuKernelMod : public DeprecatedNativeCpuKernelMod {
public:
IgammacCpuKernelMod() = default;
~IgammacCpuKernelMod() override = default;
// Caches shapes, dtype and kernel name from the graph node.
void InitKernel(const CNodePtr &kernel_node) override;
// Runs the kernel on the given buffers, dispatching on the cached dtype.
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs) override;
protected:
// Supported signatures: (float32, float32) -> float32 and (float64, float64) -> float64.
std::vector<KernelAttr> GetOpSupport() override {
static std::vector<KernelAttr> support_list = {
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64)};
return support_list;
}
private:
// Device shapes of input 0 (a), input 1 (x) and output 0 (z).
std::vector<size_t> a_shape_;
std::vector<size_t> x_shape_;
std::vector<size_t> z_shape_;
// Device dtype of input 0; kTypeUnknown until InitKernel runs.
TypeId dtype_{kTypeUnknown};
// Typed entry point: checks argument counts and selects the compute path.
template <typename T>
void LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs);
// Element-wise compute through a broadcast iterator (arguments: inputs, outputs).
template <typename T>
void BcastCompute(const std::vector<AddressPtr> &, const std::vector<AddressPtr> &);
// Element-wise compute over [start, end) (arguments: type, start, end, input1, input2, output).
template <typename T>
void SpecialCompute(int64_t, int64_t, int64_t, const T *, const T *, T *);
// Element-wise compute without a broadcast iterator (arguments: inputs, outputs).
template <typename T>
void NoBcastCompute(const std::vector<AddressPtr> &, const std::vector<AddressPtr> &);
};
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_IGAMMAC_CPU_KERNEL_H_

View File

@ -0,0 +1,413 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "plugin/device/cpu/kernel/igammagrada_cpu_kernel.h"
#include <limits>
#include "plugin/device/cpu/hal/device/cpu_device_address.h"
namespace mindspore {
namespace kernel {
namespace {
/**
* Coefficients for the Lanczos approximation of the gamma function. The
* coefficients are uniquely determined by the choice of g and n (kLanczosGamma
* and kLanczosCoefficients.size() + 1). The coefficients below correspond to
* [7, 9]. [5, 7], [7, 9], [9, 10], and [607/128.0, 15] were evaluated and [7,
* 9] seemed to be the least sensitive to the quality of the log function. In
* particular, [5, 7] is the only choice where -1.5e-5 <= lgamma(2) <= 1.5e-5
* for a particularly inaccurate log function.
* */
static constexpr double kLanczosGamma = 7;  // aka g
// Generic positional index constants.
constexpr size_t kInputIndex0 = 0;
constexpr size_t kInputIndex1 = 1;
constexpr size_t kInputIndex2 = 2;
constexpr size_t kInputIndex3 = 3;
constexpr size_t kInputIndex4 = 4;
constexpr size_t kInputIndex5 = 5;
constexpr size_t kInputIndex6 = 6;
constexpr size_t kInputIndex7 = 7;
constexpr size_t kInputIndex8 = 8;
constexpr size_t kInputIndex9 = 9;
constexpr size_t kInputIndex10 = 10;
constexpr size_t kInputIndex11 = 11;
constexpr size_t kInputIndex12 = 12;
constexpr size_t kInputIndex13 = 13;
constexpr size_t kInputIndex14 = 14;
// Leading (constant) term of the Lanczos series; the remaining terms are in kLanczosCoefficients.
static constexpr double kBaseLanczosCoeff = 0.99999999999980993227684700473478;
static constexpr double M_pi = 3.141592653589793238462643383279;
static constexpr std::array<double, 8> kLanczosCoefficients = {
  676.520368121885098567009190444019, -1259.13921672240287047156078755283,
  771.3234287776530788486528258894,   -176.61502916214059906584551354,
  12.507343278686904814458936853,     -0.13857109526572011689554707,
  9.984369578019570859563e-6,         1.50563273514931155834e-7};
// log(g + 1/2), precomputed once. It is only ever read, so it is const: a mutable
// namespace-scope value read by the parallel kernel shards would be an accident waiting to happen.
const double log_lanczos_gamma_plus_one_half = std::log(kLanczosGamma + 0.5);
// Outputs with fewer elements than this are computed serially instead of being sharded.
constexpr int64_t kParallelDataNums = 256;
// Selector values passed to SpecialCompute as `type`.
constexpr int64_t kSameShape = 0;
constexpr int64_t kXOneElement = 1;
constexpr int64_t kYOneElement = 2;
// Expected number of kernel inputs / outputs.
constexpr size_t kInputNum = 2;
constexpr size_t kOutputNum = 1;
// Values of the `mode` argument of use_igammact.
constexpr int64_t VALUE = 1;
constexpr int64_t DERIVATIVE = 2;
// Returns the product of all dimensions in `shape`; an empty shape yields 1.
size_t get_element_num(const std::vector<size_t> &shape) {
  size_t total = 1;
  for (const size_t dim : shape) {
    total *= dim;
  }
  return total;
}
} // namespace
/** Compute the Lgamma function using Lanczos' approximation from "A Precision
 * Approximation of the Gamma Function". SIAM Journal on Numerical Analysis
 * series B. Vol. 1:
 * lgamma(z + 1) = (log(2) + log(pi)) / 2 + (z + 1/2) * log(t(z)) - t(z) + log(A(z))
 * t(z) = z + kLanczosGamma + 1/2
 * A(z) = kBaseLanczosCoeff + sigma(k = 1, n, kLanczosCoefficients[k - 1] / (z + k))
 *
 * @param input point at which to evaluate log|gamma(input)|.
 * @return the approximation; +infinity when `input` is infinite.
 */
template <typename T>
T Lgamma(const T &input) {
  T log_pi = std::log(M_pi);
  T log_sqrt_two_pi = (std::log(2) + std::log(M_pi)) / 2;
  /** If the input is less than 0.5 use Euler's reflection formula:
   * gamma(x) = pi / (sin(pi * x) * gamma(1 - x))
   */
  bool need_to_reflect = (input < 0.5);
  T input_after_reflect = need_to_reflect ? -input : input - 1;
  T sum = kBaseLanczosCoeff;
  for (size_t i = 0, end = kLanczosCoefficients.size(); i < end; ++i) {
    T lanczos_coefficient = kLanczosCoefficients[i];
    sum += lanczos_coefficient / (input_after_reflect + i + 1);
  }
  T gamma_plus_onehalf_plus_z = kLanczosGamma + 0.5 + input_after_reflect;
  // log(t) = log(g + 1/2) + log1p(z / (g + 1/2)). Use the log1p overload that matches T:
  // the float-only log1pf would silently round the argument and result to float when T is double.
  T log_t =
    log_lanczos_gamma_plus_one_half + std::log1p(static_cast<T>(input_after_reflect / (kLanczosGamma + 0.5)));
  // Written as (z + 1/2 - t / log_t) * log_t rather than (z + 1/2) * log_t - t.
  T log_y = log_sqrt_two_pi + (input_after_reflect + 0.5 - gamma_plus_onehalf_plus_z / log_t) * log_t + std::log(sum);
  // Reflection term: fold |input| into [0, 0.5] before taking sin(pi * .).
  T abs_input = std::abs(input);
  T abs_frac_input = abs_input - std::floor(abs_input);
  T reduced_frac_input = (abs_frac_input > 0.5) ? 1 - abs_frac_input : abs_frac_input;
  T reflection_denom = std::log(std::sin(M_pi * reduced_frac_input));
  // A non-finite log(sin(.)) is negated and returned as-is instead of being combined with log_y.
  T reflection = std::isfinite(reflection_denom) ? log_pi - reflection_denom - log_y : -reflection_denom;
  T result = need_to_reflect ? reflection : log_y;
  return std::isinf(input) ? std::numeric_limits<T>::infinity() : result;
}
/**
 * Digamma function evaluated from the same Lanczos series that Lgamma uses:
 *   result = log(t(z)) + A'(z) / A(z) - kLanczosGamma / t(z),  t(z) = z + kLanczosGamma + 1/2
 * where `denom` accumulates A(z) and `num` accumulates its derivative A'(z).
 * Inputs below 0.5 are evaluated at -input and corrected with the reflection
 * term pi * cos(pi * x) / sin(pi * x).
 *
 * @param input point at which to evaluate digamma.
 * @return the approximation; quiet NaN when `input` is a negative integer.
 */
template <typename T>
T Digamma(const T &input) {
  bool need_to_reflect = (input < 0.5);
  T reflected_input = need_to_reflect ? -input : input - 1;
  T num = 0;
  T denom = kBaseLanczosCoeff;
  for (size_t i = 0, end = kLanczosCoefficients.size(); i < end; ++i) {
    T lanczos_coefficient = kLanczosCoefficients[i];
    num -= lanczos_coefficient / ((reflected_input + i + 1) * (reflected_input + i + 1));
    denom += lanczos_coefficient / (reflected_input + i + 1);
  }
  T gamma_plus_onehalf_plus_z = kLanczosGamma + 0.5 + reflected_input;
  // Use the log1p overload that matches T: the float-only log1pf would silently round the
  // argument and result to float when T is double.
  T log_t = log_lanczos_gamma_plus_one_half + std::log1p(static_cast<T>(reflected_input / (kLanczosGamma + 0.5)));
  T result = log_t + num / denom - kLanczosGamma / gamma_plus_onehalf_plus_z;
  // Shift the argument by an integer before the trigonometric calls of the reflection term.
  T reduced_input = input + std::abs(std::floor(input + 0.5));
  T reflection = result - M_pi * std::cos(M_pi * reduced_input) / std::sin(M_pi * reduced_input);
  T real_result = need_to_reflect ? reflection : result;
  // Negative integers are poles of digamma; return NaN there. input == 0 is not matched by
  // this test and goes through the reflection branch above.
  return (input < 0 && input == std::floor(input)) ? std::numeric_limits<T>::quiet_NaN() : real_result;
}
/**
 * Continued-fraction evaluation used by IgammaGradASingle.
 *
 * Iterates the recurrence for the convergents pk / qk together with their derivatives
 * with respect to `a`, for at most 2000 steps or until the change of the derivative
 * estimate drops to machine epsilon or below.
 *
 * @param ax      scale factor applied to the result.
 * @param a       first argument of the incomplete gamma function.
 * @param x       second argument of the incomplete gamma function.
 * @param enabled zero skips the iteration entirely; the initial estimates are then used.
 * @param mode    VALUE returns ans * ax, DERIVATIVE returns the derivative with respect to `a`,
 *                any other value returns -(dans_da + ans * dlogax_da) * x.
 */
template <typename T>
T use_igammact(const T &ax, const T &a, const T &x, T enabled, int mode) {
  constexpr int kMaxIterations = 2000;
  constexpr int kZStep = 2;
  const T eps = std::numeric_limits<T>::epsilon();

  // Initial state of the recurrence.
  T y = 1 - a;
  T z = x + y + 1;
  T c = 0;
  T pkm2 = 1;
  T qkm2 = x;
  T pkm1 = x + 1;
  T qkm1 = z * x;
  T ans = pkm1 / qkm1;
  T dpkm2_da = 0;
  T dqkm2_da = 0;
  T dpkm1_da = 0;
  T dqkm1_da = -x;
  T dans_da = (dpkm1_da - ans * dqkm1_da) / qkm1;

  while (enabled && c < kMaxIterations) {
    c += 1;
    y += 1;
    z += kZStep;

    const T yc = y * c;
    const T pk = pkm1 * z - pkm2 * yc;
    const T qk = qkm1 * z - qkm2 * yc;
    const bool qk_is_nonzero = (qk != 0);
    const T ratio = pk / qk;
    // Keep the previous estimate when the denominator vanishes.
    ans = qk_is_nonzero ? ratio : ans;

    const T dpk_da = dpkm1_da * z - pkm1 - dpkm2_da * yc + pkm2 * c;
    const T dqk_da = dqkm1_da * z - qkm1 - dqkm2_da * yc + qkm2 * c;
    const T dans_da_new = qk_is_nonzero ? (dpk_da - ans * dqk_da) / qk : dans_da;
    const T grad_conditional = qk_is_nonzero ? std::abs(dans_da_new - dans_da) : 1;

    // Shift the two-term history one step forward.
    pkm2 = pkm1;
    pkm1 = pk;
    qkm2 = qkm1;
    qkm1 = qk;
    dpkm2_da = dpkm1_da;
    dqkm2_da = dqkm1_da;
    dpkm1_da = dpk_da;
    dqkm1_da = dqk_da;

    // Scale the whole history down once the numerator grows past 1 / epsilon.
    if (std::abs(pk) > (1 / eps)) {
      pkm2 = pkm2 * eps;
      pkm1 = pkm1 * eps;
      qkm2 = qkm2 * eps;
      qkm1 = qkm1 * eps;
      dpkm2_da = dpkm2_da * eps;
      dqkm2_da = dqkm2_da * eps;
      dpkm1_da = dpkm1_da * eps;
      dqkm1_da = dqkm1_da * eps;
    }

    // Continue only while the derivative estimate is still moving by more than epsilon.
    enabled = enabled && (grad_conditional > eps);
    dans_da = dans_da_new;
  }

  if (mode == VALUE) {
    return ans * ax;
  }
  const T dlogax_da = std::log(x) - Digamma<T>(a);
  if (mode == DERIVATIVE) {
    return ax * (ans * dlogax_da + dans_da);
  }
  return -(dans_da + ans * dlogax_da) * x;
}
/**
 * Power-series evaluation used by IgammaGradASingle.
 *
 * Accumulates the series `ans` and its derivative with respect to `a` (`dans_da`) until the
 * relative size of the latest derivative term is no larger than machine epsilon.
 *
 * @param ax      scale factor applied to the result.
 * @param a       first argument of the incomplete gamma function.
 * @param x       second argument of the incomplete gamma function.
 * @param enabled zero skips the iteration entirely; the initial estimates are then used.
 * @return ax * (ans * dlogax_da + dans_da) / a, or NaN when a == 0.
 */
template <typename T>
T use_igammacf(const T &ax, const T &a, T x, T enabled) {
  const T eps = std::numeric_limits<T>::epsilon();

  T r = a;
  T c = 1;
  T ans = 1;
  T dc_da = 0;
  T dans_da = 0;

  while (enabled != 0) {
    r += 1;
    // The derivative term must be updated before `c`, because it uses the previous `c`.
    dc_da = dc_da * (x / r) + (-1 * c * x) / (r * r);
    dans_da = dans_da + dc_da;
    c = c * (x / r);
    ans = ans + c;
    enabled = enabled && (std::abs(dc_da / dans_da) > eps);
  }

  if (a == 0) {
    return NAN;
  }
  const T dlogax_da = std::log(x) - Digamma<T>(a + 1);
  return ax * (ans * dlogax_da + dans_da) / a;
}
/**
 * Gradient of igamma(a, x) with respect to `a` for one pair of scalars.
 *
 * Uses the continued fraction (use_igammact) when x > 1 and x > a, and the power series
 * (use_igammacf) otherwise. Special cases, in the order they are applied:
 *   - NaN when x < 0, a <= 0, either input is NaN, or the computed value is NaN;
 *   - 0 when x == 0, or when x is infinite while `a` is finite and no NaN / domain error
 *     occurred. This step runs last, so x == 0 yields 0 even for an invalid `a`.
 */
template <typename T>
T IgammaGradASingle(const T &a, const T &x) {
  const bool is_nan = std::isnan(a) || std::isnan(x);
  const bool x_is_zero = (x == 0);
  const bool domain_error = (x < 0) || (a <= 0);
  const bool use_igammac = (x > 1) && (x > a);

  // log of x^a * exp(-x) / gamma(a); checked for underflow before being exponentiated.
  const T log_ax = a * std::log(x) - x - Lgamma<T>(a);
  const bool underflow = (log_ax < -std::log(std::numeric_limits<T>::max()));
  const T ax = std::exp(log_ax);

  // The iterative solvers are skipped whenever the answer is fixed by a special case.
  const T enabled = static_cast<T>(!(x_is_zero || domain_error || underflow || is_nan));

  T output;
  if (use_igammac) {
    output = -use_igammact(ax, a, x, enabled, DERIVATIVE);
  } else {
    output = use_igammacf(ax, a, x, enabled);
  }

  if (domain_error || is_nan || std::isnan(output)) {
    output = std::numeric_limits<double>::quiet_NaN();
  }
  if (x_is_zero || (std::isinf(x) && !is_nan && !domain_error && !std::isinf(a))) {
    output = 0;
  }
  return output;
}
/**
 * Computes the output when `a` and `x` have to be broadcast against each other.
 *
 * Walks the output linearly and uses a BroadcastIterator to find the matching element of each
 * input. Outputs with fewer than kParallelDataNums elements are processed serially; larger
 * ones are split into [start, end) shards through ParallelLaunchAutoSearch.
 *
 * @param inputs  inputs[0] is the buffer of `a`, inputs[1] the buffer of `x`.
 * @param outputs outputs[0] receives one value per output element.
 */
template <typename T>
void IgammaGradACpuKernelMod::BcastCompute(const std::vector<kernel::AddressPtr> &inputs,
                                           const std::vector<kernel::AddressPtr> &outputs) {
  auto a_data_addr = reinterpret_cast<T *>(inputs[0]->addr);
  auto x_data_addr = reinterpret_cast<T *>(inputs[1]->addr);
  auto z_data_addr = reinterpret_cast<T *>(outputs[0]->addr);
  size_t data_num = get_element_num(z_shape_);
  auto output_shape = CPUKernelUtils::GetBroadcastShape(a_shape_, x_shape_);
  BroadcastIterator base_iter(a_shape_, x_shape_, output_shape);
  auto shard_igammaGradA = [z_data_addr, a_data_addr, x_data_addr, &base_iter](size_t start, size_t end) {
    // Every shard works on its own copy of the iterator. Sharing one iterator between
    // concurrently running shards would let them overwrite each other's position.
    auto iter = base_iter;
    iter.SetPos(start);
    for (size_t i = start; i < end; i++) {
      T *a_index = a_data_addr + iter.GetInputPosA();  // i-th value of input0
      T *x_index = x_data_addr + iter.GetInputPosB();  // i-th value of input1
      *(z_data_addr + i) = IgammaGradASingle<T>(*a_index, *x_index);
      iter.GenNextPos();
    }
  };
  if (data_num < kParallelDataNums) {
    // Small workloads: run the whole range on the calling thread.
    shard_igammaGradA(0, data_num);
  } else {
    ParallelLaunchAutoSearch(shard_igammaGradA, data_num, this, &parallel_search_info_);
  }
}
/* SpecialCompute handles the cases that need no broadcast iterator:
 *   kSameShape   - both inputs are read element by element;
 *   kXOneElement - input1 holds a single value that is paired with every element of input2;
 *   kYOneElement - input2 holds a single value that is paired with every element of input1.
 * Any other `type` leaves `output` untouched.
 *
 * `start` / `end` index into the inputs, while `output` must already point at the slot for
 * element `start`; results are written to output[0 .. end - start).
 **/
template <typename T>
void IgammaGradACpuKernelMod::SpecialCompute(int64_t type, int64_t start, int64_t end, const T *input1, const T *input2,
                                             T *output) {
  const bool known_type = (type == kSameShape) || (type == kXOneElement) || (type == kYOneElement);
  if (!known_type) {
    return;
  }
  const bool input1_is_single = (type == kXOneElement);
  const bool input2_is_single = (type == kYOneElement);
  for (int64_t i = start; i < end; ++i) {
    const T &lhs = input1_is_single ? input1[0] : input1[i];
    const T &rhs = input2_is_single ? input2[0] : input2[i];
    output[i - start] = IgammaGradASingle<T>(lhs, rhs);
  }
}
/**
 * Computes the output when no broadcast iterator is needed.
 *
 * Picks the SpecialCompute variant from the element counts of the two inputs, then runs it
 * serially for fewer than kParallelDataNums output elements, or in shards otherwise.
 */
template <typename T>
void IgammaGradACpuKernelMod::NoBcastCompute(const std::vector<kernel::AddressPtr> &inputs,
                                             const std::vector<kernel::AddressPtr> &outputs) {
  auto a_ptr = reinterpret_cast<T *>(inputs[0]->addr);
  auto x_ptr = reinterpret_cast<T *>(inputs[1]->addr);
  auto z_ptr = reinterpret_cast<T *>(outputs[0]->addr);

  const size_t a_count = get_element_num(a_shape_);
  const size_t x_count = get_element_num(x_shape_);
  const size_t total = get_element_num(z_shape_);

  // Equal element counts take precedence over the single-element cases.
  int64_t type = kYOneElement;
  if (a_count == x_count) {
    type = kSameShape;
  } else if (a_count == 1) {
    type = kXOneElement;
  }

  if (total < kParallelDataNums) {
    SpecialCompute<T>(type, 0, total, a_ptr, x_ptr, z_ptr);
    return;
  }
  // Each shard writes to its own slice of the output, starting at z_ptr + start.
  auto shard_task = [type, a_ptr, x_ptr, z_ptr, this](int64_t start, int64_t end) {
    SpecialCompute<T>(type, start, end, a_ptr, x_ptr, z_ptr + start);
  };
  ParallelLaunchAutoSearch(shard_task, total, this, &parallel_search_info_);
}
// Caches what the compute functions need from the graph node: the device shapes of input 0 (a),
// input 1 (x) and output 0 (z), the device data type of input 0, and the node name used in
// error messages.
void IgammaGradACpuKernelMod::InitKernel(const CNodePtr &kernel_node) {
a_shape_ = AnfAlgo::GetInputDeviceShape(kernel_node, 0);
x_shape_ = AnfAlgo::GetInputDeviceShape(kernel_node, 1);
z_shape_ = AnfAlgo::GetOutputDeviceShape(kernel_node, 0);
// Launch dispatches on this value; only the dtype of input 0 is queried.
dtype_ = AnfAlgo::GetInputDeviceDataType(kernel_node, 0);
kernel_name_ = common::AnfAlgo::GetCNodeName(kernel_node);
}
/**
 * Dispatches to the typed implementation selected by the dtype cached in InitKernel.
 *
 * @param inputs  buffers of `a` and `x`.
 * @param outputs buffer of the result.
 * @return true after a successful run; an unsupported dtype is reported through
 *         MS_LOG(EXCEPTION).
 */
bool IgammaGradACpuKernelMod::Launch(const std::vector<kernel::AddressPtr> &inputs,
                                     const std::vector<kernel::AddressPtr> &,
                                     const std::vector<kernel::AddressPtr> &outputs) {
  if (dtype_ == kNumberTypeFloat32) {
    LaunchKernel<float>(inputs, outputs);
  } else if (dtype_ == kNumberTypeFloat64) {
    LaunchKernel<double>(inputs, outputs);
  } else {
    // dtype_ is the dtype of input 0, which this operator calls 'a'; it has no input named 'var'.
    MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the dtype of 'a' should be float32 or float64, but got "
                      << TypeIdToType(dtype_)->ToString();
  }
  return true;
}
/**
 * Validates the number of inputs/outputs and picks the compute path: NoBcastCompute when the
 * two shapes are identical or either input holds exactly one element, BcastCompute otherwise.
 */
template <typename T>
void IgammaGradACpuKernelMod::LaunchKernel(const std::vector<kernel::AddressPtr> &inputs,
                                           const std::vector<kernel::AddressPtr> &outputs) {
  CHECK_KERNEL_INPUTS_NUM(inputs.size(), kInputNum, kernel_name_);
  CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kOutputNum, kernel_name_);

  const bool a_is_single = (get_element_num(a_shape_) == 1);
  const bool x_is_single = (get_element_num(x_shape_) == 1);
  const bool no_broadcast_needed = (a_shape_ == x_shape_) || a_is_single || x_is_single;

  if (!no_broadcast_needed) {
    BcastCompute<T>(inputs, outputs);
    return;
  }
  NoBcastCompute<T>(inputs, outputs);
}
MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, IgammaGradA, IgammaGradACpuKernelMod);
} // namespace kernel
} // namespace mindspore

View File

@ -0,0 +1,69 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_IGAMMAGRADA_CPU_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_IGAMMAGRADA_CPU_KERNEL_H_
#include <vector>
#include <array>
#include <cmath>
#include <iostream>
#include <map>
#include <string>
#include <tuple>
#include "plugin/device/cpu/kernel/cpu_kernel.h"
#include "plugin/factory/ms_factory.h"
namespace mindspore {
namespace kernel {
// CPU kernel module class for the IgammaGradA operator. Member functions are defined in the
// matching .cc file.
class IgammaGradACpuKernelMod : public DeprecatedNativeCpuKernelMod {
public:
IgammaGradACpuKernelMod() = default;
~IgammaGradACpuKernelMod() override = default;
// Called with the graph node before Launch.
void InitKernel(const CNodePtr &kernel_node) override;
// Runs the kernel on the given input/workspace/output buffers and returns a success flag.
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs) override;
protected:
// Supported signatures: (float32, float32) -> float32 and (float64, float64) -> float64.
std::vector<KernelAttr> GetOpSupport() override {
static std::vector<KernelAttr> support_list = {
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64)};
return support_list;
}
private:
// NOTE(review): presumably the shapes of input `a`, input `x` and output `z` -- confirm in the .cc file.
std::vector<size_t> a_shape_;
std::vector<size_t> x_shape_;
std::vector<size_t> z_shape_;
// Element type selector; stays kTypeUnknown until assigned elsewhere.
TypeId dtype_{kTypeUnknown};
// Typed helpers, templated on the element type T; defined in the .cc file.
template <typename T>
void LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs);
template <typename T>
void BcastCompute(const std::vector<AddressPtr> &, const std::vector<AddressPtr> &);
template <typename T>
void SpecialCompute(int64_t, int64_t, int64_t, const T *, const T *, T *);
template <typename T>
void NoBcastCompute(const std::vector<AddressPtr> &, const std::vector<AddressPtr> &);
};
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_IGAMMAGRADA_CPU_KERNEL_H_

View File

@ -818,6 +818,9 @@ GVAR_DEF(PrimitivePtr, kPrimEinsumGrad, std::make_shared<Primitive>("EinsumGrad"
GVAR_DEF(PrimitivePtr, kPrimTrace, std::make_shared<Primitive>("Trace"));
GVAR_DEF(PrimitivePtr, kPrimTraceGrad, std::make_shared<Primitive>("TraceGrad"));
GVAR_DEF(PrimitivePtr, kPrimZeta, std::make_shared<Primitive>("Zeta"));
GVAR_DEF(PrimitivePtr, kPrimIgamma, std::make_shared<Primitive>("Igamma"));
GVAR_DEF(PrimitivePtr, kPrimIgammac, std::make_shared<Primitive>("Igammac"));
GVAR_DEF(PrimitivePtr, kPrimIgammaGradA, std::make_shared<Primitive>("IgammaGradA"));
// Image
GVAR_DEF(PrimitivePtr, kPrimNonMaxSuppressionV3, std::make_shared<Primitive>("NonMaxSuppressionV3"));

View File

@ -0,0 +1,61 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "ops/grad/igammagrada.h"
#include <string>
#include <set>
#include <map>
#include "abstract/ops/primitive_infer_map.h"
#include "ops/op_utils.h"
#include "utils/check_convert_utils.h"
#include "utils/tensor_construct_utils.h"
#include "mindapi/src/helper.h"
namespace mindspore {
namespace ops {
namespace {
// Shape inference for IgammaGradA: delegates to the shared BroadCastInferShape helper.
abstract::ShapePtr IgammaGradAInferShape(const PrimitivePtr &primitive,
                                         const std::vector<AbstractBasePtr> &input_args) {
  return BroadCastInferShape(primitive->name(), input_args);
}
// Type inference for IgammaGradA: passes the types of "a" and "x" together with the allowed
// set {float32, float64} to CheckTensorTypeSame and returns its result.
TypePtr IgammaGradAInferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
  auto op_name = primitive->name();
  const std::set<TypePtr> valid_types = {kFloat32, kFloat64};
  std::map<std::string, TypePtr> named_types;
  (void)named_types.insert({"a", input_args[kInputIndex0]->BuildType()});
  (void)named_types.insert({"x", input_args[kInputIndex1]->BuildType()});
  return CheckAndConvertUtils::CheckTensorTypeSame(named_types, valid_types, op_name);
}
} // namespace
MIND_API_OPERATOR_IMPL(IgammaGradA, BaseOperator);
// Abstract inference entry point for IgammaGradA: checks the argument list (at least two
// arguments), then infers the type and the shape, in that order.
AbstractBasePtr IgammaGradAInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                 const std::vector<AbstractBasePtr> &input_args) {
  const int64_t min_input_count = 2;
  (void)CheckAndConvertUtils::CheckInputArgs(input_args, kGreaterEqual, min_input_count, primitive->name());
  auto out_type = IgammaGradAInferType(primitive, input_args);
  auto out_shape = IgammaGradAInferShape(primitive, input_args);
  return abstract::MakeAbstract(out_shape, out_type);
}
REGISTER_PRIMITIVE_EVAL_IMPL(IgammaGradA, prim::kPrimIgammaGradA, IgammaGradAInfer, nullptr, true);
} // namespace ops
} // namespace mindspore

View File

@ -0,0 +1,42 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CORE_OPS_IGAMMAGRADA_H
#define MINDSPORE_CORE_OPS_IGAMMAGRADA_H
#include <vector>
#include <memory>
#include "ops/base_operator.h"
#include "mindapi/base/types.h"
namespace mindspore {
namespace ops {
constexpr auto kNameIgammaGradA = "IgammaGradA";
/// \brief Computes the gradient of igamma(a, x) wrt a.
/// The operator is constructed under the name kNameIgammaGradA ("IgammaGradA").
class MIND_API IgammaGradA : public BaseOperator {
public:
MIND_API_BASE_MEMBER(IgammaGradA);
/// \brief Constructor. Passes the I/O names to InitIOName: "a" and "x" in, "z" out.
IgammaGradA() : BaseOperator(kNameIgammaGradA) { InitIOName({"a", "x"}, {"z"}); }
};
abstract::AbstractBasePtr IgammaGradAInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<abstract::AbstractBasePtr> &input_args);
using kPrimIgammaGradAPtr = std::shared_ptr<IgammaGradA>;
} // namespace ops
} // namespace mindspore
#endif // MINDSPORE_CORE_OPS_IGAMMAGRADA_H

View File

@ -0,0 +1,59 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "ops/igamma.h"
#include <map>
#include <set>
#include <string>
#include "abstract/ops/primitive_infer_map.h"
#include "ops/op_utils.h"
#include "utils/check_convert_utils.h"
#include "utils/tensor_construct_utils.h"
#include "mindapi/src/helper.h"
namespace mindspore {
namespace ops {
namespace {
// Shape inference for Igamma: delegates to the shared BroadCastInferShape helper.
abstract::ShapePtr IgammaInferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
  return BroadCastInferShape(primitive->name(), input_args);
}
// Type inference for Igamma: passes the types of "a" and "x" together with the allowed
// set {float32, float64} to CheckTensorTypeSame and returns its result.
TypePtr IgammaInferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
  auto op_name = primitive->name();
  const std::set<TypePtr> valid_types = {kFloat32, kFloat64};
  std::map<std::string, TypePtr> named_types;
  (void)named_types.insert({"a", input_args[kInputIndex0]->BuildType()});
  (void)named_types.insert({"x", input_args[kInputIndex1]->BuildType()});
  return CheckAndConvertUtils::CheckTensorTypeSame(named_types, valid_types, op_name);
}
} // namespace
MIND_API_OPERATOR_IMPL(Igamma, BaseOperator);
// Abstract inference entry point for Igamma: checks the argument list (at least two
// arguments), then infers the type and the shape, in that order.
AbstractBasePtr IgammaInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
                            const std::vector<AbstractBasePtr> &input_args) {
  const int64_t min_input_count = 2;
  (void)CheckAndConvertUtils::CheckInputArgs(input_args, kGreaterEqual, min_input_count, primitive->name());
  auto out_type = IgammaInferType(primitive, input_args);
  auto out_shape = IgammaInferShape(primitive, input_args);
  return abstract::MakeAbstract(out_shape, out_type);
}
REGISTER_PRIMITIVE_EVAL_IMPL(Igamma, prim::kPrimIgamma, IgammaInfer, nullptr, true);
} // namespace ops
} // namespace mindspore

View File

@ -0,0 +1,43 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CORE_OPS_IGAMMA_H
#define MINDSPORE_CORE_OPS_IGAMMA_H
#include <vector>
#include <memory>
#include "ops/base_operator.h"
#include "mindapi/base/types.h"
namespace mindspore {
namespace ops {
constexpr auto kNameIgamma = "Igamma";
/// \brief Calculates lower regularized incomplete Gamma function.
/// Refer to Python API @ref mindspore.ops.Igamma for more details.
/// The operator is constructed under the name kNameIgamma ("Igamma").
class MIND_API Igamma : public BaseOperator {
public:
MIND_API_BASE_MEMBER(Igamma);
/// \brief Constructor. Passes the I/O names to InitIOName: "a" and "x" in, "z" out.
Igamma() : BaseOperator(kNameIgamma) { InitIOName({"a", "x"}, {"z"}); }
};
abstract::AbstractBasePtr IgammaInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<abstract::AbstractBasePtr> &input_args);
using kPrimIgammaPtr = std::shared_ptr<Igamma>;
} // namespace ops
} // namespace mindspore
#endif // MINDSPORE_CORE_OPS_IGAMMA_H

View File

@ -0,0 +1,60 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "ops/igammac.h"
#include <string>
#include <set>
#include <map>
#include "abstract/ops/primitive_infer_map.h"
#include "ops/op_utils.h"
#include "utils/check_convert_utils.h"
#include "utils/tensor_construct_utils.h"
#include "mindapi/src/helper.h"
namespace mindspore {
namespace ops {
namespace {
// Shape inference for Igammac: delegates to the shared BroadCastInferShape helper.
abstract::ShapePtr IgammacInferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
  return BroadCastInferShape(primitive->name(), input_args);
}
// Type inference for Igammac: passes the types of "a" and "x" together with the allowed
// set {float32, float64} to CheckTensorTypeSame and returns its result.
TypePtr IgammacInferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
  auto op_name = primitive->name();
  const std::set<TypePtr> valid_types = {kFloat32, kFloat64};
  std::map<std::string, TypePtr> named_types;
  (void)named_types.insert({"a", input_args[kInputIndex0]->BuildType()});
  (void)named_types.insert({"x", input_args[kInputIndex1]->BuildType()});
  return CheckAndConvertUtils::CheckTensorTypeSame(named_types, valid_types, op_name);
}
} // namespace
MIND_API_OPERATOR_IMPL(Igammac, BaseOperator);
// Abstract inference entry point for Igammac: checks the argument list (at least two
// arguments), then infers the type and the shape, in that order.
AbstractBasePtr IgammacInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
                             const std::vector<AbstractBasePtr> &input_args) {
  const int64_t min_input_count = 2;
  (void)CheckAndConvertUtils::CheckInputArgs(input_args, kGreaterEqual, min_input_count, primitive->name());
  auto out_type = IgammacInferType(primitive, input_args);
  auto out_shape = IgammacInferShape(primitive, input_args);
  return abstract::MakeAbstract(out_shape, out_type);
}
REGISTER_PRIMITIVE_EVAL_IMPL(Igammac, prim::kPrimIgammac, IgammacInfer, nullptr, true);
} // namespace ops
} // namespace mindspore

View File

@ -0,0 +1,43 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CORE_OPS_IGAMMAC_H
#define MINDSPORE_CORE_OPS_IGAMMAC_H
#include <vector>
#include <memory>
#include "ops/base_operator.h"
#include "mindapi/base/types.h"
namespace mindspore {
namespace ops {
constexpr auto kNameIgammac = "Igammac";
/// \brief Compute the upper regularized incomplete Gamma function Q(a, x).
/// Refer to Python API @ref mindspore.ops.Igammac for more details.
/// The operator is constructed under the name kNameIgammac ("Igammac").
class MIND_API Igammac : public BaseOperator {
public:
MIND_API_BASE_MEMBER(Igammac);
/// \brief Constructor. Passes the I/O names to InitIOName: "a" and "x" in, "z" out.
Igammac() : BaseOperator(kNameIgammac) { InitIOName({"a", "x"}, {"z"}); }
};
abstract::AbstractBasePtr IgammacInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<abstract::AbstractBasePtr> &input_args);
using kPrimIgammacPtr = std::shared_ptr<Igammac>;
} // namespace ops
} // namespace mindspore
#endif // MINDSPORE_CORE_OPS_IGAMMAC_H

View File

@ -19,14 +19,17 @@ from mindspore.common import dtype as mstype
from mindspore import nn
import mindspore.numpy as mnp
import numpy as np
from ...nn.layer import math
from .. import functional as F
from .. import operations as P
from ..operations.math_ops import Trace
from ..functional import broadcast_gradient_args
from .._grad.grad_base import bprop_getters
from .._grad.grad_math_ops import binop_grad_common
from ..composite.multitype_ops.zeros_like_impl import zeros_like
from ..operations import _grad_ops as G
from ..operations import math_ops as math
from ..operations.math_ops import Igamma, Igammac
from ..primitive import constexpr
from ..operations.math_ops import ReduceStd
@ -512,3 +515,51 @@ def get_bprop_trace(self):
return (dx,)
return bprop
@bprop_getters.register(Igamma)
def get_bprop_igamma(self):
    """Grad definition for `Igamma` operation."""
    shape_ = P.Shape()
    igammagrada = G.IgammaGradA()
    # `math` is bound twice by this file's imports (nn.layer.math, then operations.math_ops),
    # so the later binding wins. Take LGamma from `nn` to get the nn-layer implementation
    # regardless of the import order.
    lgamma = nn.LGamma()
    log_ = P.Log()
    exp_ = P.Exp()
    reshape_ = P.Reshape()
    reduce_sum_ = P.ReduceSum()

    def bprop(a, x, out, dout):
        """Return (d/da, d/dx) of igamma(a, x), each scaled by `dout` and shaped like its input."""
        sa = shape_(a)
        sx = shape_(x)
        ra, rx = broadcast_gradient_args(sa, sx)
        partial_a = igammagrada(a, x)
        # d/dx igamma(a, x) = exp(-x) * x^(a - 1) / gamma(a), evaluated in log space.
        partial_x = exp_(-x + (a - 1) * log_(x) - lgamma(a))
        grad_a = partial_a * dout
        grad_x = partial_x * dout
        # Reduce each gradient only over its own broadcast axes. ReduceSum with an empty
        # axis tuple sums over every dimension, so it must be skipped when nothing was broadcast.
        if ra != ():
            grad_a = reduce_sum_(grad_a, ra)
        if rx != ():
            grad_x = reduce_sum_(grad_x, rx)
        return reshape_(grad_a, sa), reshape_(grad_x, sx)

    return bprop
@bprop_getters.register(Igammac)
def get_bprop_igammac(self):
    """Grad definition for `Igammac` operation."""
    shape_ = P.Shape()
    igammagrada = G.IgammaGradA()
    # `math` is bound twice by this file's imports (nn.layer.math, then operations.math_ops),
    # so the later binding wins. Take LGamma from `nn` to get the nn-layer implementation
    # regardless of the import order.
    lgamma = nn.LGamma()
    log_ = P.Log()
    exp_ = P.Exp()
    reshape_ = P.Reshape()
    reduce_sum_ = P.ReduceSum()
    neg_ = P.Neg()

    def bprop(a, x, out, dout):
        """Return (d/da, d/dx) of igammac(a, x): the negated Igamma gradients, scaled by `dout`."""
        sa = shape_(a)
        sx = shape_(x)
        ra, rx = broadcast_gradient_args(sa, sx)
        partial_a = igammagrada(a, x)
        # d/dx igamma(a, x) = exp(-x) * x^(a - 1) / gamma(a), evaluated in log space.
        partial_x = exp_(-x + (a - 1) * log_(x) - lgamma(a))
        grad_a = partial_a * dout
        grad_x = partial_x * dout
        # Reduce each gradient only over its own broadcast axes. ReduceSum with an empty
        # axis tuple sums over every dimension, so it must be skipped when nothing was broadcast.
        if ra != ():
            grad_a = reduce_sum_(grad_a, ra)
        if rx != ():
            grad_x = reduce_sum_(grad_x, rx)
        return neg_(reshape_(grad_a, sa)), neg_(reshape_(grad_x, sx))

    return bprop

View File

@ -49,6 +49,9 @@ from .asin_grad import _asin_grad_aicpu
from .is_finite import _is_finite_aicpu
from .is_inf import _is_inf_aicpu
from .is_nan import _is_nan_aicpu
from .igamma import _igamma_aicpu
from .igammac import _igammac_aicpu
from .igammagrada import _igammagrada_aicpu
from .reshape import _reshape_aicpu
from .fill_v2 import _fill_v2_aicpu
from .flatten import _flatten_aicpu

View File

@ -0,0 +1,30 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Igamma op"""
from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType
# AiCPU registration record for Igamma: inputs "a" and "x", output "z",
# registered for matching float32 and float64 dtype combinations.
igamma_op_info = (
    AiCPURegOp("Igamma")
    .fusion_type("OPAQUE")
    .input(0, "a", "required")
    .input(1, "x", "required")
    .output(0, "z", "required")
    .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default)
    .dtype_format(DataType.F64_Default, DataType.F64_Default, DataType.F64_Default)
    .get_op_info()
)


# The function body is a no-op; the decorator call carries the registration.
@op_info_register(igamma_op_info)
def _igamma_aicpu():
    """Igamma aicpu register"""
    return

View File

@ -0,0 +1,30 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Igammac op"""
from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType
# AiCPU registration record for Igammac: inputs "a" and "x", output "z",
# registered for matching float32 and float64 dtype combinations.
igammac_op_info = (
    AiCPURegOp("Igammac")
    .fusion_type("OPAQUE")
    .input(0, "a", "required")
    .input(1, "x", "required")
    .output(0, "z", "required")
    .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default)
    .dtype_format(DataType.F64_Default, DataType.F64_Default, DataType.F64_Default)
    .get_op_info()
)


# The function body is a no-op; the decorator call carries the registration.
@op_info_register(igammac_op_info)
def _igammac_aicpu():
    """Igammac aicpu register"""
    return

View File

@ -0,0 +1,30 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""IgammaGradA op"""
from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType
# AiCPU registration record for IgammaGradA: inputs "a" and "x", output "z",
# registered for matching float32 and float64 dtype combinations.
igammagrada_op_info = (
    AiCPURegOp("IgammaGradA")
    .fusion_type("OPAQUE")
    .input(0, "a", "required")
    .input(1, "x", "required")
    .output(0, "z", "required")
    .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default)
    .dtype_format(DataType.F64_Default, DataType.F64_Default, DataType.F64_Default)
    .get_op_info()
)


# The function body is a no-op; the decorator call carries the registration.
@op_info_register(igammagrada_op_info)
def _igammagrada_aicpu():
    """IgammaGradA aicpu register"""
    return

View File

@ -2527,3 +2527,39 @@ class TraceGrad(Primitive):
@prim_attr_register
def __init__(self):
pass
class IgammaGradA(Primitive):
    r"""
    Computes the gradient of igamma(a, x) with respect to `a`.

    Inputs:
        - **a** (Tensor) - The input tensor. With float32 or float64 data type.
        - **x** (Tensor) - The input tensor. With float32 or float64 data type. `x` should have
          the same dtype as `a`.

    Outputs:
        Tensor, has the same dtype as `a` and `x`.

    Raises:
        TypeError: If `a` or `x` is not a Tensor.
        TypeError: If dtype of input `x` and `a` is neither float32 nor float64.
        TypeError: If `x` has a different dtype from `a`.
        ValueError: If `a` could not be broadcast to a tensor with shape of `x`.

    Supported Platforms:
        ``Ascend`` ``CPU``

    Examples:
        >>> a = Tensor(np.array([2.0, 4.0, 6.0, 8.0]).astype(np.float32))
        >>> x = Tensor(np.array([2.0, 3.0, 4.0, 5.0]).astype(np.float32))
        >>> igammagrada = G.IgammaGradA()
        >>> output = igammagrada(a, x)
        >>> print(output)
        [-0.2940046 -0.20153049 -0.13028376 -0.08352186]
    """

    @prim_attr_register
    def __init__(self):
        """Initialize IgammaGradA"""
        # Register the I/O names: inputs 'a' and 'x', output 'z'.
        self.init_prim_io_names(inputs=['a', 'x'], outputs=['z'])

View File

@ -5380,6 +5380,106 @@ class Trunc(Primitive):
"""Initialize Trunc"""
class Igamma(Primitive):
    r"""
    Calculates the lower regularized incomplete Gamma function.

    The lower regularized incomplete Gamma function is defined as:

    .. math::
        P(a, x) = \gamma(a, x) / \Gamma(a) = 1 - Q(a, x)

    where

    .. math::
        \gamma(a, x) = \int_0^x t^{a-1} e^{-t} dt

    is the lower incomplete Gamma function.

    Above, :math:`Q(a, x)` is the upper regularized incomplete Gamma function.

    .. warning::
        This is an experimental prototype that is subject to change and/or deletion.

    Inputs:
        - **a** (Tensor) - The input tensor. With type of float32 or float64.
        - **x** (Tensor) - The input tensor. With float32 or float64 type. `x` should have
          the same dtype as `a`.

    Outputs:
        Tensor, has the same dtype as `a` and `x`.

    Raises:
        TypeError: If `a` or `x` is not a Tensor.
        TypeError: If dtype of input `x` and `a` is neither float32 nor float64.
        TypeError: If `x` has a different dtype from `a`.
        ValueError: If `a` could not be broadcast to a tensor with shape of `x`.

    Supported Platforms:
        ``Ascend`` ``CPU``

    Examples:
        >>> a = Tensor(np.array([2.0, 4.0, 6.0, 8.0]).astype(np.float32))
        >>> x = Tensor(np.array([2.0, 3.0, 4.0, 5.0]).astype(np.float32))
        >>> igamma = P.Igamma()
        >>> output = igamma(a, x)
        >>> print(output)
        [0.593994 0.35276785 0.21486944 0.13337152]
    """

    @prim_attr_register
    def __init__(self):
        """Initialize Igamma"""
        # Register the I/O names: inputs 'a' and 'x', output 'z'.
        self.init_prim_io_names(inputs=['a', 'x'], outputs=['z'])
class Igammac(Primitive):
    r"""
    Calculates the upper regularized incomplete Gamma function Q(a, x).

    The upper regularized incomplete Gamma function is defined as:

    .. math::
        Q(a, x) = \Gamma(a, x) / \Gamma(a) = 1 - P(a, x)

    where

    .. math::
        \Gamma(a, x) = \int_{x}^{\infty} t^{a-1} e^{-t} dt

    is the upper incomplete Gamma function.

    Note, above :math:`P(a, x)` (Igamma) is the lower regularized incomplete Gamma function.

    .. warning::
        This is an experimental prototype that is subject to change and/or deletion.

    Inputs:
        - **a** (Tensor) - The input tensor of igammac. With float32 or float64 data type.
        - **x** (Tensor) - The input tensor of igammac. With float32 or float64 type. `x` should have
          the same type as `a`.

    Outputs:
        A Tensor, has the same dtype as `a` and `x`.

    Raises:
        TypeError: If dtype of input `x` and `a` is neither float32 nor float64.
        TypeError: If `a` or `x` is not a Tensor.
        TypeError: If `x` has a different dtype from `a`.
        ValueError: If `a` could not be broadcast to a tensor with shape of `x`.

    Supported Platforms:
        ``Ascend`` ``CPU``

    Examples:
        >>> a = Tensor(np.array([2.0, 4.0, 6.0, 8.0]).astype(np.float32))
        >>> x = Tensor(np.array([2.0, 3.0, 4.0, 5.0]).astype(np.float32))
        >>> igammac = P.Igammac()
        >>> output = igammac(a, x)
        >>> print(output)
        [0.40600586 0.6472318 0.7851304 0.8666283 ]
    """

    @prim_attr_register
    def __init__(self):
        """Initialize Igammac"""
        # Register the I/O names: inputs 'a' and 'x', output 'z'.
        self.init_prim_io_names(inputs=['a', 'x'], outputs=['z'])
class IsClose(Primitive):
r"""
Returns a boolean tensor where two tensors are element-wise equal within a tolerance.

View File

@ -25,8 +25,9 @@ from mindspore.common import dtype as mstype
from mindspore.ops import composite as C
from mindspore.ops import operations as P
from mindspore.ops import functional as F
from mindspore.ops.operations._grad_ops import IgammaGradA
from mindspore.ops import prim_attr_register, PrimitiveWithInfer
from mindspore.ops.operations.math_ops import Zeta
from mindspore.ops.operations.math_ops import Zeta, Igamma, Igammac
from ..ut_filter import non_graph_engine
from ....mindspore_test_framework.mindspore_test import mindspore_test
from ....mindspore_test_framework.pipeline.forward.compile_forward \
@ -478,7 +479,25 @@ raise_set = [
('Zeta', {
'block': Zeta(),
'desc_inputs': [Tensor(np.array([1, 1, 1, 1], np.float32)),
Tensor([0.5, 0.5, 0.5, 0.5], mstype.float32)],
Tensor([0.5, 0.5, 0.5, 0.5], mstype.float32)]}),
('Igamma', {
'block': Igamma(),
'desc_inputs': [Tensor(np.array([1.1, 2.2, -4.1], np.float32)),
Tensor(np.array([0.2, 1.2, 2.1], np.float32))],
'desc_bprop': [Tensor(np.array([2, 3], np.float32)),
Tensor(np.array([2, 3], np.float32))],
'skip': ['backward']}),
('Igammac', {
'block': Igammac(),
'desc_inputs': [Tensor(np.array([1.1, 2.2, -4.1], np.float32)),
Tensor(np.array([0.2, 1.2, 2.1], np.float32))],
'desc_bprop': [Tensor(np.array([2, 3], np.float32)),
Tensor(np.array([2, 3], np.float32))],
'skip': ['backward']}),
('IgammaGradA', {
'block': IgammaGradA(),
'desc_inputs': [Tensor(np.array([1.1, 2.2, 8.1, 2.1], np.float32)),
Tensor(np.array([0.2, 1.2, 2.1, 3.4], np.float32))],
'skip': ['backward']}),
]