!33193 [assistant][ops] New operator implementations: Igamma, Igammac, and IgammaGradA
Merge pull request !33193 from ganqijun/ig
commit 5e5606e82f
@ -0,0 +1,385 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "plugin/device/cpu/kernel/igamma_cpu_kernel.h"
#include <limits>
#include "plugin/device/cpu/hal/device/cpu_device_address.h"
namespace mindspore {
namespace kernel {
namespace {
/**
 * Coefficients for the Lanczos approximation of the gamma function. The
 * coefficients are uniquely determined by the choice of g and n (kLanczosGamma
 * and kLanczosCoefficients.size() + 1). The coefficients below correspond to
 * [7, 9]. [5, 7], [7, 9], [9, 10], and [607/128.0, 15] were evaluated and [7,
 * 9] seemed to be the least sensitive to the quality of the log function. In
 * particular, [5, 7] is the only choice where -1.5e-5 <= lgamma(2) <= 1.5e-5
 * for a particularly inaccurate log function.
 */
static constexpr double kLanczosGamma = 7;  // aka g
constexpr size_t kInputIndex0 = 0;
constexpr size_t kInputIndex1 = 1;
constexpr size_t kInputIndex2 = 2;
constexpr size_t kInputIndex3 = 3;
constexpr size_t kInputIndex4 = 4;
constexpr size_t kInputIndex5 = 5;
constexpr size_t kInputIndex6 = 6;
constexpr size_t kInputIndex7 = 7;
constexpr size_t kInputIndex8 = 8;
constexpr size_t kInputIndex9 = 9;
constexpr size_t kInputIndex10 = 10;
constexpr size_t kInputIndex11 = 11;
constexpr size_t kInputIndex12 = 12;
constexpr size_t kInputIndex13 = 13;
constexpr size_t kInputIndex14 = 14;
static constexpr double kBaseLanczosCoeff = 0.99999999999980993227684700473478;
static constexpr double M_pi = 3.141592653589793238462643383279;
static constexpr std::array<double, 8> kLanczosCoefficients = {
  676.520368121885098567009190444019, -1259.13921672240287047156078755283,
  771.3234287776530788486528258894,   -176.61502916214059906584551354,
  12.507343278686904814458936853,     -0.13857109526572011689554707,
  9.984369578019570859563e-6,         1.50563273514931155834e-7};
double log_lanczos_gamma_plus_one_half = std::log(kLanczosGamma + 0.5);
constexpr int64_t kParallelDataNums = 256;
constexpr int64_t kSameShape = 0;
constexpr int64_t kXOneElement = 1;
constexpr int64_t kYOneElement = 2;
constexpr size_t kInputNum = 2;
constexpr size_t kOutputNum = 1;

size_t get_element_num(const std::vector<size_t> &shape) {
  size_t size = 1;
  for (size_t i = 0; i < shape.size(); i++) {
    size *= shape[i];
  }
  return size;
}
}  // namespace
/** Compute the Lgamma function using Lanczos' approximation from "A Precision
 * Approximation of the Gamma Function". SIAM Journal on Numerical Analysis
 * series B. Vol. 1:
 * lgamma(z + 1) = (log(2) + log(pi)) / 2 + (z + 1/2) * log(t(z)) - t(z) + A(z)
 * t(z) = z + kLanczosGamma + 1/2
 * A(z) = kBaseLanczosCoeff + sigma(k = 1, n, kLanczosCoefficients[k - 1] / (z + k))
 */
template <typename T>
T Lgamma(const T &input) {
  T log_pi = std::log(M_pi);
  T log_sqrt_two_pi = (std::log(2) + std::log(M_pi)) / 2;

  /** If the input is less than 0.5 use Euler's reflection formula:
   * gamma(x) = pi / (sin(pi * x) * gamma(1 - x))
   */
  bool need_to_reflect = (input < 0.5);
  T input_after_reflect = need_to_reflect ? -input : input - 1;
  T sum = kBaseLanczosCoeff;
  for (size_t i = 0, end = kLanczosCoefficients.size(); i < end; ++i) {
    T lanczos_coefficient = kLanczosCoefficients[i];
    sum += lanczos_coefficient / (input_after_reflect + i + 1);
  }
  T gamma_plus_onehalf_plus_z = kLanczosGamma + 0.5 + input_after_reflect;
  // std::log1p (not the float-only std::log1pf) keeps full precision when T is double.
  T log_t = log_lanczos_gamma_plus_one_half + std::log1p(input_after_reflect / (kLanczosGamma + 0.5));
  T log_y = log_sqrt_two_pi + (input_after_reflect + 0.5 - gamma_plus_onehalf_plus_z / log_t) * log_t + std::log(sum);
  T abs_input = std::abs(input);
  T abs_frac_input = abs_input - std::floor(abs_input);

  T reduced_frac_input = (abs_frac_input > 0.5) ? 1 - abs_frac_input : abs_frac_input;
  T reflection_denom = std::log(std::sin(M_pi * reduced_frac_input));

  T reflection = std::isfinite(reflection_denom) ? log_pi - reflection_denom - log_y : -reflection_denom;
  T result = need_to_reflect ? reflection : log_y;

  return std::isinf(input) ? std::numeric_limits<T>::infinity() : result;
}
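A quick way to sanity-check this Lanczos Lgamma is to compare it with the standard library. The harness below is a minimal sketch, not part of the commit: it assumes Lgamma<double> is visible at the call site, and the test points are illustrative.

// Hypothetical standalone check: compare the Lanczos Lgamma above with std::lgamma.
#include <cmath>
#include <cstdio>
int main() {
  const double points[] = {0.1, 0.5, 1.0, 2.0, 7.5, 100.0};
  for (double z : points) {
    double approx = Lgamma<double>(z);  // Lanczos approximation from this file
    double exact = std::lgamma(z);      // libm reference
    std::printf("z=%6.2f lanczos=%.12f libm=%.12f diff=%.3e\n", z, approx, exact, approx - exact);
  }
  return 0;
}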
template <typename T>
T use_igammact(const T &ax, const T &a, const T &x, T enabled) {
  T y = 1 - a;
  T z = x + y + 1;
  T c = 0;
  T pkm2 = 1;
  T qkm2 = x;
  T pkm1 = x + 1;
  T qkm1 = z * x;
  T ans = pkm1 / qkm1;
  T t = 1;
  T dpkm2_da = 0;
  T dqkm2_da = 0;
  T dpkm1_da = 0;
  T dqkm1_da = -x;
  T dans_da = (dpkm1_da - ans * dqkm1_da) / qkm1;
  std::vector<T> vals = {enabled, ans, t, y, z, c, pkm1, qkm1,
                         pkm2, qkm2, dpkm2_da, dqkm2_da, dpkm1_da, dqkm1_da, dans_da};
  constexpr int k2000 = 2000;
  while (vals[kInputIndex0] && vals[kInputIndex5] < k2000) {
    enabled = vals[kInputIndex0];
    ans = vals[kInputIndex1];
    T tmp_var_t = vals[kInputIndex2];
    T tmp_var_y = vals[kInputIndex3];
    T tmp_var_z = vals[kInputIndex4];
    T tmp_var_c = vals[kInputIndex5];
    pkm1 = vals[kInputIndex6];
    qkm1 = vals[kInputIndex7];
    pkm2 = vals[kInputIndex8];
    qkm2 = vals[kInputIndex9];
    dpkm2_da = vals[kInputIndex10];
    dqkm2_da = vals[kInputIndex11];
    dpkm1_da = vals[kInputIndex12];
    dqkm1_da = vals[kInputIndex13];
    dans_da = vals[kInputIndex14];
    tmp_var_c += 1;
    tmp_var_y += 1;
    constexpr int TWO = 2;
    tmp_var_z += TWO;

    T yc = tmp_var_y * tmp_var_c;
    T pk = pkm1 * tmp_var_z - pkm2 * yc;
    T qk = qkm1 * tmp_var_z - qkm2 * yc;
    bool qk_is_nonzero = (qk != 0);
    T r = pk / qk;
    t = qk_is_nonzero ? std::abs((ans - r) / r) : 1;
    ans = qk_is_nonzero ? r : ans;

    T dpk_da = dpkm1_da * tmp_var_z - pkm1 - dpkm2_da * yc + pkm2 * tmp_var_c;
    T dqk_da = dqkm1_da * tmp_var_z - qkm1 - dqkm2_da * yc + qkm2 * tmp_var_c;
    T dans_da_new = qk_is_nonzero ? (dpk_da - ans * dqk_da) / qk : dans_da;
    pkm2 = pkm1;
    pkm1 = pk;
    qkm2 = qkm1;
    qkm1 = qk;

    dpkm2_da = dpkm1_da;
    dqkm2_da = dqkm1_da;
    dpkm1_da = dpk_da;
    dqkm1_da = dqk_da;
    bool rescale = std::abs(pk) > (1 / std::numeric_limits<T>::epsilon());

    pkm2 = rescale ? pkm2 * std::numeric_limits<T>::epsilon() : pkm2;
    pkm1 = rescale ? pkm1 * std::numeric_limits<T>::epsilon() : pkm1;
    qkm2 = rescale ? qkm2 * std::numeric_limits<T>::epsilon() : qkm2;
    qkm1 = rescale ? qkm1 * std::numeric_limits<T>::epsilon() : qkm1;

    dpkm2_da = rescale ? dpkm2_da * std::numeric_limits<T>::epsilon() : dpkm2_da;
    dqkm2_da = rescale ? dqkm2_da * std::numeric_limits<T>::epsilon() : dqkm2_da;
    dpkm1_da = rescale ? dpkm1_da * std::numeric_limits<T>::epsilon() : dpkm1_da;
    dqkm1_da = rescale ? dqkm1_da * std::numeric_limits<T>::epsilon() : dqkm1_da;

    T conditional = enabled && (t > std::numeric_limits<T>::epsilon());
    vals[kInputIndex0] = conditional;
    vals[kInputIndex5] = tmp_var_c;
    if (enabled) {
      vals = {conditional, ans, tmp_var_t, tmp_var_y, tmp_var_z, tmp_var_c, pkm1, qkm1,
              pkm2, qkm2, dpkm2_da, dqkm2_da, dpkm1_da, dqkm1_da, dans_da_new};
    }
  }
  ans = vals[kInputIndex1];
  return 1 - ans * ax;
}
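For reference, the recurrence above evaluates the classical continued fraction for the upper incomplete gamma function; a sketch of the standard identity (Numerical Recipes / Cephes form), not taken verbatim from the commit:

% Continued fraction behind use_igammact:
\Gamma(a, x) = e^{-x} x^{a}
  \cfrac{1}{x + 1 - a - \cfrac{1 \cdot (1 - a)}{x + 3 - a - \cfrac{2 \cdot (2 - a)}{x + 5 - a - \cdots}}}
% With ax = x^a e^{-x} / \Gamma(a), the loop's ans converges to \Gamma(a, x) / (\Gamma(a) \, ax),
% so the function returns 1 - ans * ax = igamma(a, x) for the large-x branch.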
template <typename T>
T use_igammacf(T ax, T a, T x, T enabled) {
  std::vector<T> vals = {enabled, a, 1, 1, x, 0, 0};
  while (vals[kInputIndex0] != 0) {
    enabled = vals[kInputIndex0];
    T r = vals[kInputIndex1];
    T c = vals[kInputIndex2];
    T ans = vals[kInputIndex3];
    x = vals[kInputIndex4];
    T dc_da = vals[kInputIndex5];
    T dans_da = vals[kInputIndex6];
    r += 1;
    dc_da = dc_da * (x / r) + (-1 * c * x) / (r * r);
    dans_da = dans_da + dc_da;
    c = c * (x / r);
    ans = ans + c;
    T conditional = enabled && (c / ans > std::numeric_limits<T>::epsilon());
    vals[kInputIndex0] = conditional;
    if (enabled) {
      vals = {conditional, r, c, ans, x, dc_da, dans_da};
    }
  }
  T ans = vals[kInputIndex3];
  if (a == 0) {
    return NAN;
  }
  return (ans * ax) / a;
}
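The power-series branch above implements the standard series for the lower regularized incomplete gamma function; again a sketch of the identity being used:

% Series behind use_igammacf:
P(a, x) = \frac{x^{a} e^{-x}}{\Gamma(a)} \sum_{n=0}^{\infty} \frac{x^{n}}{a (a+1) \cdots (a+n)}
% The loop accumulates ans = \sum_{n \ge 0} \prod_{k=1}^{n} x / (a + k)
% and the function returns ans * ax / a with ax = x^{a} e^{-x} / \Gamma(a).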
template <typename T>
T IgammaSingle(const T &a, const T &x) {
  if (!std::isinf(a) && (a > 0) && std::isinf(x) && x > 0) {
    return 1;
  }
  bool is_nan = std::isnan(a) || std::isnan(x);
  bool x_is_zero = (x == 0);
  bool domain_error = (x < 0) || (a <= 0);
  bool use_igammac = (x > 1) && (x > a);

  T ax = a * std::log(x) - x - Lgamma<T>(a);

  bool underflow = (ax < -std::log(std::numeric_limits<T>::max()));

  ax = std::exp(ax);
  T enabled = static_cast<T>(!(x_is_zero || domain_error || underflow || is_nan));
  T output;
  if (use_igammac) {
    enabled = static_cast<T>(enabled && use_igammac);
    output = use_igammact(ax, a, x, enabled);
  } else {
    enabled = static_cast<T>(enabled && !use_igammac);
    output = use_igammacf(ax, a, x, enabled);
  }
  output = (domain_error || is_nan || std::isnan(output)) ? std::numeric_limits<T>::quiet_NaN() : output;
  output = x_is_zero ? 0 : output;
  return output;
}
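Since Igamma and Igammac are regularized complements, a cheap correctness probe is to confirm IgammaSingle(a, x) + IgammacSingle(a, x) is close to 1 for in-domain points. This is a hypothetical harness: IgammacSingle lives in igammac_cpu_kernel.cc, so it assumes both functions are visible in one translation unit, and the tolerance is illustrative.

// Hypothetical complement check: igamma(a, x) + igammac(a, x) should be ~1.
#include <cassert>
#include <cmath>
void CheckComplement() {
  const double as[] = {0.5, 1.0, 3.0, 10.0};
  const double xs[] = {0.1, 1.0, 2.5, 20.0};
  for (double a : as) {
    for (double x : xs) {
      double p = IgammaSingle<double>(a, x);   // lower regularized, this file
      double q = IgammacSingle<double>(a, x);  // upper regularized, igammac_cpu_kernel.cc
      assert(std::abs(p + q - 1.0) < 1e-10);
    }
  }
}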
template <typename T>
void IgammaCpuKernelMod::BcastCompute(const std::vector<kernel::AddressPtr> &inputs,
                                      const std::vector<kernel::AddressPtr> &outputs) {
  auto a_data_addr = reinterpret_cast<T *>(inputs[0]->addr);
  auto x_data_addr = reinterpret_cast<T *>(inputs[1]->addr);
  auto z_data_addr = reinterpret_cast<T *>(outputs[0]->addr);
  size_t data_num = get_element_num(z_shape_);
  auto output_shape = CPUKernelUtils::GetBroadcastShape(a_shape_, x_shape_);
  BroadcastIterator iter(a_shape_, x_shape_, output_shape);
  if (data_num < kParallelDataNums) {
    iter.SetPos(0);
    for (size_t i = 0; i < data_num; i++) {
      T *a_index = a_data_addr + iter.GetInputPosA();  // i-th value of input0
      T *x_index = x_data_addr + iter.GetInputPosB();  // i-th value of input1
      *(z_data_addr + i) = IgammaSingle<T>(*a_index, *x_index);
      iter.GenNextPos();
    }
  } else {
    auto shard_igamma = [z_data_addr, a_data_addr, x_data_addr, &iter](size_t start, size_t end) {
      // Each shard advances its own copy of the iterator; sharing one iterator
      // across threads would race on its internal position.
      auto iter_shard = iter;
      iter_shard.SetPos(start);
      for (size_t i = start; i < end; i++) {
        T *a_index = a_data_addr + iter_shard.GetInputPosA();  // i-th value of input0
        T *x_index = x_data_addr + iter_shard.GetInputPosB();  // i-th value of input1
        *(z_data_addr + i) = IgammaSingle<T>(*a_index, *x_index);
        iter_shard.GenNextPos();
      }
    };
    ParallelLaunchAutoSearch(shard_igamma, data_num, this, &parallel_search_info_);
  }
}
/* SpecialCompute handles the cases that need no broadcasting:
 * 1. the shapes of input1 and input2 are the same (kSameShape)
 * 2. input1 is a 1-D tensor with only one element or a scalar (kXOneElement)
 * 3. input2 is a 1-D tensor with only one element or a scalar (kYOneElement)
 * Inputs whose shapes differ in any other way go through BcastCompute instead.
 */
template <typename T>
void IgammaCpuKernelMod::SpecialCompute(int64_t type, int64_t start, int64_t end, const T *input1, const T *input2,
                                        T *output) {
  switch (type) {
    case kSameShape: {
      auto cur_input1 = input1 + start;
      auto cur_input2 = input2 + start;
      for (int64_t i = start; i < end; ++i) {
        *output = IgammaSingle<T>(*cur_input1, *cur_input2);
        output = output + 1;
        cur_input1 = cur_input1 + 1;
        cur_input2 = cur_input2 + 1;
      }
      break;
    }
    case kXOneElement: {
      auto cur_input2 = input2 + start;
      for (int64_t i = start; i < end; ++i) {
        *output = IgammaSingle<T>(*input1, *cur_input2);
        output = output + 1;
        cur_input2 = cur_input2 + 1;
      }
      break;
    }
    case kYOneElement: {
      auto cur_input1 = input1 + start;
      for (int64_t i = start; i < end; ++i) {
        *output = IgammaSingle<T>(*cur_input1, *input2);
        output = output + 1;
        cur_input1 = cur_input1 + 1;
      }
      break;
    }
    default:
      break;
  }
}
template <typename T>
void IgammaCpuKernelMod::NoBcastCompute(const std::vector<kernel::AddressPtr> &inputs,
                                        const std::vector<kernel::AddressPtr> &outputs) {
  auto in0 = reinterpret_cast<T *>(inputs[0]->addr);
  auto in1 = reinterpret_cast<T *>(inputs[1]->addr);
  auto out0 = reinterpret_cast<T *>(outputs[0]->addr);
  size_t in0_elements_nums = get_element_num(a_shape_);
  size_t in1_elements_nums = get_element_num(x_shape_);
  size_t data_num = get_element_num(z_shape_);
  int64_t type =
    in0_elements_nums == in1_elements_nums ? kSameShape : (in0_elements_nums == 1 ? kXOneElement : kYOneElement);
  if (data_num < kParallelDataNums) {
    SpecialCompute<T>(type, 0, data_num, in0, in1, out0);
  } else {
    auto shard_igamma = [type, in0, in1, out0, this](int64_t start, int64_t end) {
      SpecialCompute<T>(type, start, end, in0, in1, out0 + start);
    };
    ParallelLaunchAutoSearch(shard_igamma, data_num, this, &parallel_search_info_);
  }
}
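To make the dispatch concrete, here are a few hypothetical shape pairs and the branch they take, following the logic in NoBcastCompute above and LaunchKernel below:

// Hypothetical shape pairs and the resulting branch:
//   a_shape_ = {2, 3}, x_shape_ = {2, 3}  -> NoBcastCompute, type = kSameShape
//   a_shape_ = {1},    x_shape_ = {2, 3}  -> NoBcastCompute, type = kXOneElement
//   a_shape_ = {2, 3}, x_shape_ = {1}     -> NoBcastCompute, type = kYOneElement
//   a_shape_ = {2, 1}, x_shape_ = {1, 3}  -> BcastCompute via BroadcastIterator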
void IgammaCpuKernelMod::InitKernel(const CNodePtr &kernel_node) {
  a_shape_ = AnfAlgo::GetInputDeviceShape(kernel_node, 0);
  x_shape_ = AnfAlgo::GetInputDeviceShape(kernel_node, 1);
  z_shape_ = AnfAlgo::GetOutputDeviceShape(kernel_node, 0);
  dtype_ = AnfAlgo::GetInputDeviceDataType(kernel_node, 0);
  kernel_name_ = common::AnfAlgo::GetCNodeName(kernel_node);
}

bool IgammaCpuKernelMod::Launch(const std::vector<kernel::AddressPtr> &inputs, const std::vector<kernel::AddressPtr> &,
                                const std::vector<kernel::AddressPtr> &outputs) {
  if (dtype_ == kNumberTypeFloat32) {
    LaunchKernel<float>(inputs, outputs);
  } else if (dtype_ == kNumberTypeFloat64) {
    LaunchKernel<double>(inputs, outputs);
  } else {
    MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the dtype of 'a' should be float32 or float64, but got "
                      << TypeIdToType(dtype_)->ToString();
  }
  return true;
}

template <typename T>
void IgammaCpuKernelMod::LaunchKernel(const std::vector<kernel::AddressPtr> &inputs,
                                      const std::vector<kernel::AddressPtr> &outputs) {
  CHECK_KERNEL_INPUTS_NUM(inputs.size(), kInputNum, kernel_name_);
  CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kOutputNum, kernel_name_);
  size_t in0_elements_nums = get_element_num(a_shape_);
  size_t in1_elements_nums = get_element_num(x_shape_);
  bool no_bcast = (a_shape_ == x_shape_) || (in0_elements_nums == 1) || (in1_elements_nums == 1);
  if (no_bcast) {
    NoBcastCompute<T>(inputs, outputs);
  } else {
    BcastCompute<T>(inputs, outputs);
  }
}

MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, Igamma, IgammaCpuKernelMod);
}  // namespace kernel
}  // namespace mindspore
@ -0,0 +1,70 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_IGAMMA_CPU_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_IGAMMA_CPU_KERNEL_H_

#include <vector>
#include <array>
#include <cmath>
#include <iostream>
#include <limits>
#include <map>
#include <string>
#include <tuple>
#include "plugin/device/cpu/kernel/cpu_kernel.h"
#include "plugin/factory/ms_factory.h"

namespace mindspore {
namespace kernel {
class IgammaCpuKernelMod : public DeprecatedNativeCpuKernelMod {
 public:
  IgammaCpuKernelMod() = default;
  ~IgammaCpuKernelMod() override = default;

  void InitKernel(const CNodePtr &kernel_node) override;

  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
              const std::vector<AddressPtr> &outputs) override;

 protected:
  std::vector<KernelAttr> GetOpSupport() override {
    static std::vector<KernelAttr> support_list = {
      KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
      KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64)};
    return support_list;
  }

 private:
  std::vector<size_t> a_shape_;
  std::vector<size_t> x_shape_;
  std::vector<size_t> z_shape_;
  TypeId dtype_{kTypeUnknown};
  template <typename T>
  void LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs);

  template <typename T>
  void BcastCompute(const std::vector<AddressPtr> &, const std::vector<AddressPtr> &);

  template <typename T>
  void SpecialCompute(int64_t, int64_t, int64_t, const T *, const T *, T *);

  template <typename T>
  void NoBcastCompute(const std::vector<AddressPtr> &, const std::vector<AddressPtr> &);
};
}  // namespace kernel
}  // namespace mindspore

#endif  // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_IGAMMA_CPU_KERNEL_H_
@ -0,0 +1,379 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "plugin/device/cpu/kernel/igammac_cpu_kernel.h"
#include "plugin/device/cpu/hal/device/cpu_device_address.h"

namespace mindspore {
namespace kernel {
namespace {
/**
 * Coefficients for the Lanczos approximation of the gamma function. The
 * coefficients are uniquely determined by the choice of g and n (kLanczosGamma
 * and kLanczosCoefficients.size() + 1). The coefficients below correspond to
 * [7, 9]. [5, 7], [7, 9], [9, 10], and [607/128.0, 15] were evaluated and [7,
 * 9] seemed to be the least sensitive to the quality of the log function. In
 * particular, [5, 7] is the only choice where -1.5e-5 <= lgamma(2) <= 1.5e-5
 * for a particularly inaccurate log function.
 */
static constexpr double kLanczosGamma = 7;  // aka g
constexpr size_t kInputIndex0 = 0;
constexpr size_t kInputIndex1 = 1;
constexpr size_t kInputIndex2 = 2;
constexpr size_t kInputIndex3 = 3;
constexpr size_t kInputIndex4 = 4;
constexpr size_t kInputIndex5 = 5;
constexpr size_t kInputIndex6 = 6;
constexpr size_t kInputIndex7 = 7;
constexpr size_t kInputIndex8 = 8;
constexpr size_t kInputIndex9 = 9;
constexpr size_t kInputIndex10 = 10;
constexpr size_t kInputIndex11 = 11;
constexpr size_t kInputIndex12 = 12;
constexpr size_t kInputIndex13 = 13;
constexpr size_t kInputIndex14 = 14;
static constexpr double kBaseLanczosCoeff = 0.99999999999980993227684700473478;
static constexpr double M_pi = 3.141592653589793238462643383279;
static constexpr std::array<double, 8> kLanczosCoefficients = {
  676.520368121885098567009190444019, -1259.13921672240287047156078755283,
  771.3234287776530788486528258894,   -176.61502916214059906584551354,
  12.507343278686904814458936853,     -0.13857109526572011689554707,
  9.984369578019570859563e-6,         1.50563273514931155834e-7};
double log_lanczos_gamma_plus_one_half = std::log(kLanczosGamma + 0.5);
constexpr int64_t kParallelDataNums = 256;
constexpr int64_t kSameShape = 0;
constexpr int64_t kXOneElement = 1;
constexpr int64_t kYOneElement = 2;
constexpr size_t kInputNum = 2;
constexpr size_t kOutputNum = 1;

size_t get_element_num(const std::vector<size_t> &shape) {
  size_t size = 1;
  for (size_t i = 0; i < shape.size(); i++) {
    size *= shape[i];
  }
  return size;
}
}  // namespace
/** Compute the Lgamma function using Lanczos' approximation from "A Precision
 * Approximation of the Gamma Function". SIAM Journal on Numerical Analysis
 * series B. Vol. 1:
 * lgamma(z + 1) = (log(2) + log(pi)) / 2 + (z + 1/2) * log(t(z)) - t(z) + A(z)
 * t(z) = z + kLanczosGamma + 1/2
 * A(z) = kBaseLanczosCoeff + sigma(k = 1, n, kLanczosCoefficients[k - 1] / (z + k))
 */
template <typename T>
T Lgamma(const T &input) {
  T log_pi = std::log(M_pi);
  T log_sqrt_two_pi = (std::log(2) + std::log(M_pi)) / 2;

  /** If the input is less than 0.5 use Euler's reflection formula:
   * gamma(x) = pi / (sin(pi * x) * gamma(1 - x))
   */
  bool need_to_reflect = (input < 0.5);
  T input_after_reflect = need_to_reflect ? -input : input - 1;
  T sum = kBaseLanczosCoeff;
  for (size_t i = 0, end = kLanczosCoefficients.size(); i < end; ++i) {
    T lanczos_coefficient = kLanczosCoefficients[i];
    sum += lanczos_coefficient / (input_after_reflect + i + 1);
  }
  T gamma_plus_onehalf_plus_z = kLanczosGamma + 0.5 + input_after_reflect;
  // std::log1p (not the float-only std::log1pf) keeps full precision when T is double.
  T log_t = log_lanczos_gamma_plus_one_half + std::log1p(input_after_reflect / (kLanczosGamma + 0.5));
  T log_y = log_sqrt_two_pi + (input_after_reflect + 0.5 - gamma_plus_onehalf_plus_z / log_t) * log_t + std::log(sum);
  T abs_input = std::abs(input);
  T abs_frac_input = abs_input - std::floor(abs_input);

  T reduced_frac_input = (abs_frac_input > 0.5) ? 1 - abs_frac_input : abs_frac_input;
  T reflection_denom = std::log(std::sin(M_pi * reduced_frac_input));

  T reflection = std::isfinite(reflection_denom) ? log_pi - reflection_denom - log_y : -reflection_denom;
  T result = need_to_reflect ? reflection : log_y;

  return std::isinf(input) ? std::numeric_limits<T>::infinity() : result;
}
template <typename T>
T use_igammaf(const T &ax, const T &a, const T &x, T enabled) {
  T y = 1 - a;
  T z = x + y + 1;
  T c = 0;
  T pkm2 = 1;
  T qkm2 = x;
  T pkm1 = x + 1;
  T qkm1 = z * x;
  T ans = pkm1 / qkm1;
  T t = 1;
  T dpkm2_da = 0;
  T dqkm2_da = 0;
  T dpkm1_da = 0;
  T dqkm1_da = -x;
  T dans_da = (dpkm1_da - ans * dqkm1_da) / qkm1;
  std::vector<T> vals = {enabled, ans, t, y, z, c, pkm1, qkm1,
                         pkm2, qkm2, dpkm2_da, dqkm2_da, dpkm1_da, dqkm1_da, dans_da};
  constexpr int k2000 = 2000;
  while (vals[kInputIndex0] && vals[kInputIndex5] < k2000) {
    enabled = vals[kInputIndex0];
    ans = vals[kInputIndex1];
    T tmp_var_t = vals[kInputIndex2];
    T tmp_var_y = vals[kInputIndex3];
    T tmp_var_z = vals[kInputIndex4];
    T tmp_var_c = vals[kInputIndex5];
    pkm1 = vals[kInputIndex6];
    qkm1 = vals[kInputIndex7];
    pkm2 = vals[kInputIndex8];
    qkm2 = vals[kInputIndex9];
    dpkm2_da = vals[kInputIndex10];
    dqkm2_da = vals[kInputIndex11];
    dpkm1_da = vals[kInputIndex12];
    dqkm1_da = vals[kInputIndex13];
    dans_da = vals[kInputIndex14];
    tmp_var_c += 1;
    tmp_var_y += 1;
    constexpr int TWO = 2;
    tmp_var_z += TWO;

    T yc = tmp_var_y * tmp_var_c;
    T pk = pkm1 * tmp_var_z - pkm2 * yc;
    T qk = qkm1 * tmp_var_z - qkm2 * yc;
    bool qk_is_nonzero = (qk != 0);
    T r = pk / qk;
    t = qk_is_nonzero ? std::abs((ans - r) / r) : 1;
    ans = qk_is_nonzero ? r : ans;

    T dpk_da = dpkm1_da * tmp_var_z - pkm1 - dpkm2_da * yc + pkm2 * tmp_var_c;
    T dqk_da = dqkm1_da * tmp_var_z - qkm1 - dqkm2_da * yc + qkm2 * tmp_var_c;
    T dans_da_new = qk_is_nonzero ? (dpk_da - ans * dqk_da) / qk : dans_da;
    pkm2 = pkm1;
    pkm1 = pk;
    qkm2 = qkm1;
    qkm1 = qk;

    dpkm2_da = dpkm1_da;
    dqkm2_da = dqkm1_da;
    dpkm1_da = dpk_da;
    dqkm1_da = dqk_da;
    bool rescale = std::abs(pk) > (1 / std::numeric_limits<T>::epsilon());

    pkm2 = rescale ? pkm2 * std::numeric_limits<T>::epsilon() : pkm2;
    pkm1 = rescale ? pkm1 * std::numeric_limits<T>::epsilon() : pkm1;
    qkm2 = rescale ? qkm2 * std::numeric_limits<T>::epsilon() : qkm2;
    qkm1 = rescale ? qkm1 * std::numeric_limits<T>::epsilon() : qkm1;

    dpkm2_da = rescale ? dpkm2_da * std::numeric_limits<T>::epsilon() : dpkm2_da;
    dqkm2_da = rescale ? dqkm2_da * std::numeric_limits<T>::epsilon() : dqkm2_da;
    dpkm1_da = rescale ? dpkm1_da * std::numeric_limits<T>::epsilon() : dpkm1_da;
    dqkm1_da = rescale ? dqkm1_da * std::numeric_limits<T>::epsilon() : dqkm1_da;

    T conditional = enabled && (t > std::numeric_limits<T>::epsilon());
    vals[kInputIndex0] = conditional;
    vals[kInputIndex5] = tmp_var_c;
    if (enabled) {
      vals = {conditional, ans, tmp_var_t, tmp_var_y, tmp_var_z, tmp_var_c, pkm1, qkm1,
              pkm2, qkm2, dpkm2_da, dqkm2_da, dpkm1_da, dqkm1_da, dans_da_new};
    }
  }
  ans = vals[kInputIndex1];
  return ans * ax;
}
template <typename T>
T use_igammat(T ax, T a, T x, T enabled) {
  std::vector<T> vals = {enabled, a, 1, 1, x, 0, 0};
  while (vals[kInputIndex0] != 0) {
    enabled = vals[kInputIndex0];
    T r = vals[kInputIndex1];
    T c = vals[kInputIndex2];
    T ans = vals[kInputIndex3];
    x = vals[kInputIndex4];
    T dc_da = vals[kInputIndex5];
    T dans_da = vals[kInputIndex6];
    r += 1;
    dc_da = dc_da * (x / r) + (-1 * c * x) / (r * r);
    dans_da = dans_da + dc_da;
    c = c * (x / r);
    ans = ans + c;
    T conditional = enabled && (c / ans > std::numeric_limits<T>::epsilon());
    vals[kInputIndex0] = conditional;
    if (enabled) {
      vals = {conditional, r, c, ans, x, dc_da, dans_da};
    }
  }
  T ans = vals[kInputIndex3];
  if (a == 0) {
    return NAN;
  }
  return 1 - (ans * ax) / a;
}
template <typename T>
T IgammacSingle(const T &a, const T &x) {
  bool out_of_range = (x <= 0) || (a <= 0);
  bool use_igamma = (x < 1) || (x < a);
  T ax = a * std::log(x) - x - Lgamma(a);
  bool underflow = (ax < -std::log(std::numeric_limits<T>::max()));
  T enabled = static_cast<T>(!(out_of_range || underflow));

  ax = std::exp(ax);
  T output;
  if (use_igamma) {
    enabled = static_cast<T>(enabled && use_igamma);
    output = use_igammat(ax, a, x, enabled);
  } else {
    enabled = static_cast<T>(enabled && (!use_igamma));
    output = use_igammaf(ax, a, x, enabled);
  }
  output = out_of_range ? 1 : output;
  output = x < 0 || a <= 0 || std::isnan(x) || (std::isinf(x) && (x > 0)) || std::isnan(a)
             ? std::numeric_limits<T>::quiet_NaN()
             : output;
  output = std::isinf(x) && x > 0 && a > 0 ? 0 : output;
  return output;
}
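The special-value handling above pins down the limits of igammac directly, so a tiny probe of those branches is straightforward. This is a hypothetical check, assuming IgammacSingle is visible at the call site:

// Hypothetical limit checks for IgammacSingle:
#include <cassert>
#include <cmath>
#include <limits>
void CheckIgammacLimits() {
  const double inf = std::numeric_limits<double>::infinity();
  assert(IgammacSingle<double>(2.0, 0.0) == 1.0);        // x <= 0 is out of range -> 1
  assert(IgammacSingle<double>(2.0, inf) == 0.0);        // x -> inf with a > 0 -> 0
  assert(std::isnan(IgammacSingle<double>(-1.0, 2.0)));  // a <= 0 -> NaN
}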
template <typename T>
void IgammacCpuKernelMod::BcastCompute(const std::vector<kernel::AddressPtr> &inputs,
                                       const std::vector<kernel::AddressPtr> &outputs) {
  auto a_data_addr = reinterpret_cast<T *>(inputs[0]->addr);
  auto x_data_addr = reinterpret_cast<T *>(inputs[1]->addr);
  auto z_data_addr = reinterpret_cast<T *>(outputs[0]->addr);
  size_t data_num = get_element_num(z_shape_);
  auto output_shape = CPUKernelUtils::GetBroadcastShape(a_shape_, x_shape_);
  BroadcastIterator iter(a_shape_, x_shape_, output_shape);
  if (data_num < kParallelDataNums) {
    iter.SetPos(0);
    for (size_t i = 0; i < data_num; i++) {
      T *a_index = a_data_addr + iter.GetInputPosA();  // i-th value of input0
      T *x_index = x_data_addr + iter.GetInputPosB();  // i-th value of input1
      *(z_data_addr + i) = IgammacSingle<T>(*a_index, *x_index);
      iter.GenNextPos();
    }
  } else {
    auto shard_igammac = [z_data_addr, a_data_addr, x_data_addr, &iter](size_t start, size_t end) {
      // Each shard advances its own copy of the iterator; sharing one iterator
      // across threads would race on its internal position.
      auto iter_shard = iter;
      iter_shard.SetPos(start);
      for (size_t i = start; i < end; i++) {
        T *a_index = a_data_addr + iter_shard.GetInputPosA();  // i-th value of input0
        T *x_index = x_data_addr + iter_shard.GetInputPosB();  // i-th value of input1
        *(z_data_addr + i) = IgammacSingle<T>(*a_index, *x_index);
        iter_shard.GenNextPos();
      }
    };
    ParallelLaunchAutoSearch(shard_igammac, data_num, this, &parallel_search_info_);
  }
}
/* SpecialCompute handles the cases that need no broadcasting:
 * 1. the shapes of input1 and input2 are the same (kSameShape)
 * 2. input1 is a 1-D tensor with only one element or a scalar (kXOneElement)
 * 3. input2 is a 1-D tensor with only one element or a scalar (kYOneElement)
 * Inputs whose shapes differ in any other way go through BcastCompute instead.
 */
template <typename T>
void IgammacCpuKernelMod::SpecialCompute(int64_t type, int64_t start, int64_t end, const T *input1, const T *input2,
                                         T *output) {
  switch (type) {
    case kSameShape: {
      auto cur_input1 = input1 + start;
      auto cur_input2 = input2 + start;
      for (int64_t i = start; i < end; ++i) {
        *output = IgammacSingle<T>(*cur_input1, *cur_input2);
        output = output + 1;
        cur_input1 = cur_input1 + 1;
        cur_input2 = cur_input2 + 1;
      }
      break;
    }
    case kXOneElement: {
      auto cur_input2 = input2 + start;
      for (int64_t i = start; i < end; ++i) {
        *output = IgammacSingle<T>(*input1, *cur_input2);
        output = output + 1;
        cur_input2 = cur_input2 + 1;
      }
      break;
    }
    case kYOneElement: {
      auto cur_input1 = input1 + start;
      for (int64_t i = start; i < end; ++i) {
        *output = IgammacSingle<T>(*cur_input1, *input2);
        output = output + 1;
        cur_input1 = cur_input1 + 1;
      }
      break;
    }
    default:
      break;
  }
}
template <typename T>
void IgammacCpuKernelMod::NoBcastCompute(const std::vector<kernel::AddressPtr> &inputs,
                                         const std::vector<kernel::AddressPtr> &outputs) {
  auto in0 = reinterpret_cast<T *>(inputs[0]->addr);
  auto in1 = reinterpret_cast<T *>(inputs[1]->addr);
  auto out0 = reinterpret_cast<T *>(outputs[0]->addr);
  size_t in0_elements_nums = get_element_num(a_shape_);
  size_t in1_elements_nums = get_element_num(x_shape_);
  size_t data_num = get_element_num(z_shape_);
  int64_t type =
    in0_elements_nums == in1_elements_nums ? kSameShape : (in0_elements_nums == 1 ? kXOneElement : kYOneElement);
  if (data_num < kParallelDataNums) {
    SpecialCompute<T>(type, 0, data_num, in0, in1, out0);
  } else {
    auto shard_igammac = [type, in0, in1, out0, this](int64_t start, int64_t end) {
      SpecialCompute<T>(type, start, end, in0, in1, out0 + start);
    };
    ParallelLaunchAutoSearch(shard_igammac, data_num, this, &parallel_search_info_);
  }
}
void IgammacCpuKernelMod::InitKernel(const CNodePtr &kernel_node) {
  a_shape_ = AnfAlgo::GetInputDeviceShape(kernel_node, 0);
  x_shape_ = AnfAlgo::GetInputDeviceShape(kernel_node, 1);
  z_shape_ = AnfAlgo::GetOutputDeviceShape(kernel_node, 0);
  dtype_ = AnfAlgo::GetInputDeviceDataType(kernel_node, 0);
  kernel_name_ = common::AnfAlgo::GetCNodeName(kernel_node);
}

bool IgammacCpuKernelMod::Launch(const std::vector<kernel::AddressPtr> &inputs, const std::vector<kernel::AddressPtr> &,
                                 const std::vector<kernel::AddressPtr> &outputs) {
  if (dtype_ == kNumberTypeFloat32) {
    LaunchKernel<float>(inputs, outputs);
  } else if (dtype_ == kNumberTypeFloat64) {
    LaunchKernel<double>(inputs, outputs);
  } else {
    MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the dtype of 'a' should be float32 or float64, but got "
                      << TypeIdToType(dtype_)->ToString();
  }
  return true;
}

template <typename T>
void IgammacCpuKernelMod::LaunchKernel(const std::vector<kernel::AddressPtr> &inputs,
                                       const std::vector<kernel::AddressPtr> &outputs) {
  CHECK_KERNEL_INPUTS_NUM(inputs.size(), kInputNum, kernel_name_);
  CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kOutputNum, kernel_name_);
  size_t in0_elements_nums = get_element_num(a_shape_);
  size_t in1_elements_nums = get_element_num(x_shape_);
  bool no_bcast = (a_shape_ == x_shape_) || (in0_elements_nums == 1) || (in1_elements_nums == 1);
  if (no_bcast) {
    NoBcastCompute<T>(inputs, outputs);
  } else {
    BcastCompute<T>(inputs, outputs);
  }
}
MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, Igammac, IgammacCpuKernelMod);
}  // namespace kernel
}  // namespace mindspore
@ -0,0 +1,69 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_IGAMMAC_CPU_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_IGAMMAC_CPU_KERNEL_H_
#include <cmath>
#include <vector>
#include <array>
#include <iostream>
#include <limits>
#include <map>
#include <string>
#include <tuple>
#include "plugin/device/cpu/kernel/cpu_kernel.h"
#include "plugin/factory/ms_factory.h"

namespace mindspore {
namespace kernel {
class IgammacCpuKernelMod : public DeprecatedNativeCpuKernelMod {
 public:
  IgammacCpuKernelMod() = default;
  ~IgammacCpuKernelMod() override = default;

  void InitKernel(const CNodePtr &kernel_node) override;

  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
              const std::vector<AddressPtr> &outputs) override;

 protected:
  std::vector<KernelAttr> GetOpSupport() override {
    static std::vector<KernelAttr> support_list = {
      KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
      KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64)};
    return support_list;
  }

 private:
  std::vector<size_t> a_shape_;
  std::vector<size_t> x_shape_;
  std::vector<size_t> z_shape_;
  TypeId dtype_{kTypeUnknown};
  template <typename T>
  void LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs);

  template <typename T>
  void BcastCompute(const std::vector<AddressPtr> &, const std::vector<AddressPtr> &);

  template <typename T>
  void SpecialCompute(int64_t, int64_t, int64_t, const T *, const T *, T *);

  template <typename T>
  void NoBcastCompute(const std::vector<AddressPtr> &, const std::vector<AddressPtr> &);
};
}  // namespace kernel
}  // namespace mindspore

#endif  // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_IGAMMAC_CPU_KERNEL_H_
@ -0,0 +1,413 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "plugin/device/cpu/kernel/igammagrada_cpu_kernel.h"
#include <limits>
#include "plugin/device/cpu/hal/device/cpu_device_address.h"
namespace mindspore {
namespace kernel {
namespace {
/**
 * Coefficients for the Lanczos approximation of the gamma function. The
 * coefficients are uniquely determined by the choice of g and n (kLanczosGamma
 * and kLanczosCoefficients.size() + 1). The coefficients below correspond to
 * [7, 9]. [5, 7], [7, 9], [9, 10], and [607/128.0, 15] were evaluated and [7,
 * 9] seemed to be the least sensitive to the quality of the log function. In
 * particular, [5, 7] is the only choice where -1.5e-5 <= lgamma(2) <= 1.5e-5
 * for a particularly inaccurate log function.
 */
static constexpr double kLanczosGamma = 7;  // aka g
constexpr size_t kInputIndex0 = 0;
constexpr size_t kInputIndex1 = 1;
constexpr size_t kInputIndex2 = 2;
constexpr size_t kInputIndex3 = 3;
constexpr size_t kInputIndex4 = 4;
constexpr size_t kInputIndex5 = 5;
constexpr size_t kInputIndex6 = 6;
constexpr size_t kInputIndex7 = 7;
constexpr size_t kInputIndex8 = 8;
constexpr size_t kInputIndex9 = 9;
constexpr size_t kInputIndex10 = 10;
constexpr size_t kInputIndex11 = 11;
constexpr size_t kInputIndex12 = 12;
constexpr size_t kInputIndex13 = 13;
constexpr size_t kInputIndex14 = 14;
static constexpr double kBaseLanczosCoeff = 0.99999999999980993227684700473478;
static constexpr double M_pi = 3.141592653589793238462643383279;
static constexpr std::array<double, 8> kLanczosCoefficients = {
  676.520368121885098567009190444019, -1259.13921672240287047156078755283,
  771.3234287776530788486528258894,   -176.61502916214059906584551354,
  12.507343278686904814458936853,     -0.13857109526572011689554707,
  9.984369578019570859563e-6,         1.50563273514931155834e-7};
double log_lanczos_gamma_plus_one_half = std::log(kLanczosGamma + 0.5);
constexpr int64_t kParallelDataNums = 256;
constexpr int64_t kSameShape = 0;
constexpr int64_t kXOneElement = 1;
constexpr int64_t kYOneElement = 2;
constexpr size_t kInputNum = 2;
constexpr size_t kOutputNum = 1;
constexpr int64_t VALUE = 1;
constexpr int64_t DERIVATIVE = 2;
size_t get_element_num(const std::vector<size_t> &shape) {
  size_t size = 1;
  for (size_t i = 0; i < shape.size(); i++) {
    size *= shape[i];
  }
  return size;
}
}  // namespace
/** Compute the Lgamma function using Lanczos' approximation from "A Precision
 * Approximation of the Gamma Function". SIAM Journal on Numerical Analysis
 * series B. Vol. 1:
 * lgamma(z + 1) = (log(2) + log(pi)) / 2 + (z + 1/2) * log(t(z)) - t(z) + A(z)
 * t(z) = z + kLanczosGamma + 1/2
 * A(z) = kBaseLanczosCoeff + sigma(k = 1, n, kLanczosCoefficients[k - 1] / (z + k))
 */
template <typename T>
T Lgamma(const T &input) {
  T log_pi = std::log(M_pi);
  T log_sqrt_two_pi = (std::log(2) + std::log(M_pi)) / 2;

  /** If the input is less than 0.5 use Euler's reflection formula:
   * gamma(x) = pi / (sin(pi * x) * gamma(1 - x))
   */
  bool need_to_reflect = (input < 0.5);
  T input_after_reflect = need_to_reflect ? -input : input - 1;
  T sum = kBaseLanczosCoeff;
  for (size_t i = 0, end = kLanczosCoefficients.size(); i < end; ++i) {
    T lanczos_coefficient = kLanczosCoefficients[i];
    sum += lanczos_coefficient / (input_after_reflect + i + 1);
  }
  T gamma_plus_onehalf_plus_z = kLanczosGamma + 0.5 + input_after_reflect;
  // std::log1p (not the float-only std::log1pf) keeps full precision when T is double.
  T log_t = log_lanczos_gamma_plus_one_half + std::log1p(input_after_reflect / (kLanczosGamma + 0.5));
  T log_y = log_sqrt_two_pi + (input_after_reflect + 0.5 - gamma_plus_onehalf_plus_z / log_t) * log_t + std::log(sum);
  T abs_input = std::abs(input);
  T abs_frac_input = abs_input - std::floor(abs_input);

  T reduced_frac_input = (abs_frac_input > 0.5) ? 1 - abs_frac_input : abs_frac_input;
  T reflection_denom = std::log(std::sin(M_pi * reduced_frac_input));

  T reflection = std::isfinite(reflection_denom) ? log_pi - reflection_denom - log_y : -reflection_denom;
  T result = need_to_reflect ? reflection : log_y;

  return std::isinf(input) ? std::numeric_limits<T>::infinity() : result;
}
template <typename T>
T Digamma(const T &input) {
  bool need_to_reflect = (input < 0.5);
  T reflected_input = need_to_reflect ? -input : input - 1;

  T num = 0;
  T denom = kBaseLanczosCoeff;

  for (size_t i = 0, end = kLanczosCoefficients.size(); i < end; ++i) {
    T lanczos_coefficient = kLanczosCoefficients[i];
    num -= lanczos_coefficient / ((reflected_input + i + 1) * (reflected_input + i + 1));
    denom += lanczos_coefficient / (reflected_input + i + 1);
  }

  T gamma_plus_onehalf_plus_z = kLanczosGamma + 0.5 + reflected_input;
  // std::log1p (not the float-only std::log1pf) keeps full precision when T is double.
  T log_t = log_lanczos_gamma_plus_one_half + std::log1p(reflected_input / (kLanczosGamma + 0.5));

  T result = log_t + num / denom - kLanczosGamma / gamma_plus_onehalf_plus_z;

  T reduced_input = input + std::abs(std::floor(input + 0.5));
  T reflection = result - M_pi * std::cos(M_pi * reduced_input) / std::sin(M_pi * reduced_input);
  T real_result = need_to_reflect ? reflection : result;

  // Digamma has poles at negative integers and zero; return nan for those.
  return (input < 0 && input == std::floor(input)) ? std::numeric_limits<T>::quiet_NaN() : real_result;
}
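Digamma is the derivative of Lgamma, so a finite-difference probe against the Lgamma above is a cheap sanity check. The harness below is hypothetical; the step size and test points are illustrative.

// Hypothetical check: Digamma(z) should match (Lgamma(z + h) - Lgamma(z - h)) / (2h).
#include <cmath>
#include <cstdio>
void CheckDigamma() {
  const double h = 1e-6;
  for (double z : {0.3, 1.0, 2.5, 10.0}) {
    double analytic = Digamma<double>(z);
    double numeric = (Lgamma<double>(z + h) - Lgamma<double>(z - h)) / (2 * h);
    std::printf("z=%5.2f digamma=%.9f fd=%.9f\n", z, analytic, numeric);
  }
}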
template <typename T>
T use_igammact(const T &ax, const T &a, const T &x, T enabled, int mode) {
  T y = 1 - a;
  T z = x + y + 1;
  T c = 0;
  T pkm2 = 1;
  T qkm2 = x;
  T pkm1 = x + 1;
  T qkm1 = z * x;
  T ans = pkm1 / qkm1;
  T t = 1;
  T dpkm2_da = 0;
  T dqkm2_da = 0;
  T dpkm1_da = 0;
  T dqkm1_da = -x;
  T dans_da = (dpkm1_da - ans * dqkm1_da) / qkm1;
  std::vector<T> vals = {enabled, ans, t, y, z, c, pkm1, qkm1,
                         pkm2, qkm2, dpkm2_da, dqkm2_da, dpkm1_da, dqkm1_da, dans_da};
  constexpr int k2000 = 2000;
  while (vals[kInputIndex0] && vals[kInputIndex5] < k2000) {
    enabled = vals[kInputIndex0];
    ans = vals[kInputIndex1];
    T tmp_var_t = vals[kInputIndex2];
    T tmp_var_y = vals[kInputIndex3];
    T tmp_var_z = vals[kInputIndex4];
    T tmp_var_c = vals[kInputIndex5];
    pkm1 = vals[kInputIndex6];
    qkm1 = vals[kInputIndex7];
    pkm2 = vals[kInputIndex8];
    qkm2 = vals[kInputIndex9];
    dpkm2_da = vals[kInputIndex10];
    dqkm2_da = vals[kInputIndex11];
    dpkm1_da = vals[kInputIndex12];
    dqkm1_da = vals[kInputIndex13];
    dans_da = vals[kInputIndex14];
    tmp_var_c += 1;
    tmp_var_y += 1;
    constexpr int TWO = 2;
    tmp_var_z += TWO;
    T yc = tmp_var_y * tmp_var_c;
    T pk = pkm1 * tmp_var_z - pkm2 * yc;
    T qk = qkm1 * tmp_var_z - qkm2 * yc;
    bool qk_is_nonzero = (qk != 0);
    T r = pk / qk;
    ans = qk_is_nonzero ? r : ans;
    T dpk_da = dpkm1_da * tmp_var_z - pkm1 - dpkm2_da * yc + pkm2 * tmp_var_c;
    T dqk_da = dqkm1_da * tmp_var_z - qkm1 - dqkm2_da * yc + qkm2 * tmp_var_c;
    T dans_da_new = qk_is_nonzero ? (dpk_da - ans * dqk_da) / qk : dans_da;
    T grad_conditional = qk_is_nonzero ? std::abs(dans_da_new - dans_da) : 1;
    pkm2 = pkm1;
    pkm1 = pk;
    qkm2 = qkm1;
    qkm1 = qk;
    dpkm2_da = dpkm1_da;
    dqkm2_da = dqkm1_da;
    dpkm1_da = dpk_da;
    dqkm1_da = dqk_da;
    bool rescale = std::abs(pk) > (1 / std::numeric_limits<T>::epsilon());
    pkm2 = rescale ? pkm2 * std::numeric_limits<T>::epsilon() : pkm2;
    pkm1 = rescale ? pkm1 * std::numeric_limits<T>::epsilon() : pkm1;
    qkm2 = rescale ? qkm2 * std::numeric_limits<T>::epsilon() : qkm2;
    qkm1 = rescale ? qkm1 * std::numeric_limits<T>::epsilon() : qkm1;
    dpkm2_da = rescale ? dpkm2_da * std::numeric_limits<T>::epsilon() : dpkm2_da;
    dqkm2_da = rescale ? dqkm2_da * std::numeric_limits<T>::epsilon() : dqkm2_da;
    dpkm1_da = rescale ? dpkm1_da * std::numeric_limits<T>::epsilon() : dpkm1_da;
    dqkm1_da = rescale ? dqkm1_da * std::numeric_limits<T>::epsilon() : dqkm1_da;
    T conditional = enabled && (grad_conditional > std::numeric_limits<T>::epsilon());
    vals[kInputIndex0] = conditional;
    vals[kInputIndex5] = tmp_var_c;
    if (enabled) {
      vals = {conditional, ans, tmp_var_t, tmp_var_y, tmp_var_z, tmp_var_c, pkm1, qkm1,
              pkm2, qkm2, dpkm2_da, dqkm2_da, dpkm1_da, dqkm1_da, dans_da_new};
    }
  }
  ans = vals[kInputIndex1];
  if (mode == VALUE) {
    return ans * ax;
  }
  dans_da = vals[kInputIndex14];
  T dlogax_da = std::log(x) - Digamma<T>(a);
  switch (mode) {
    case DERIVATIVE:
      return ax * (ans * dlogax_da + dans_da);
    default:
      return -(dans_da + ans * dlogax_da) * x;
  }
}
template <typename T>
T use_igammacf(const T &ax, const T &a, T x, T enabled) {
  std::vector<T> vals = {enabled, a, 1, 1, x, 0, 0};
  while (vals[kInputIndex0] != 0) {
    enabled = vals[kInputIndex0];
    T r = vals[kInputIndex1];
    T c = vals[kInputIndex2];
    T ans = vals[kInputIndex3];
    x = vals[kInputIndex4];
    T dc_da = vals[kInputIndex5];
    T dans_da = vals[kInputIndex6];
    r += 1;
    dc_da = dc_da * (x / r) + (-1 * c * x) / (r * r);
    dans_da = dans_da + dc_da;
    c = c * (x / r);
    ans = ans + c;
    T conditional = enabled && (std::abs(dc_da / dans_da) > std::numeric_limits<T>::epsilon());
    vals[kInputIndex0] = conditional;
    if (enabled) {
      vals = {conditional, r, c, ans, x, dc_da, dans_da};
    }
  }
  T ans = vals[kInputIndex3];
  T dans_da = vals[kInputIndex6];
  if (a == 0) {
    return NAN;
  }
  T dlogax_da = std::log(x) - Digamma<T>(a + 1);
  return ax * (ans * dlogax_da + dans_da) / a;
}
template <typename T>
T IgammaGradASingle(const T &a, const T &x) {
  bool is_nan = std::isnan(a) || std::isnan(x);
  bool x_is_zero = (x == 0);
  bool domain_error = (x < 0) || (a <= 0);
  bool use_igammac = (x > 1) && (x > a);
  T ax = a * std::log(x) - x - Lgamma<T>(a);
  bool underflow = (ax < -std::log(std::numeric_limits<T>::max()));
  ax = std::exp(ax);
  T enabled = static_cast<T>(!(x_is_zero || domain_error || underflow || is_nan));
  T output;
  if (use_igammac) {
    enabled = static_cast<T>(enabled && use_igammac);
    output = -use_igammact(ax, a, x, enabled, DERIVATIVE);
  } else {
    enabled = static_cast<T>(enabled && !use_igammac);
    output = use_igammacf(ax, a, x, enabled);
  }
  output = (domain_error || is_nan || std::isnan(output)) ? std::numeric_limits<T>::quiet_NaN() : output;
  output = x_is_zero || (std::isinf(x) && !is_nan && !domain_error && !std::isinf(a)) ? 0 : output;
  return output;
}
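IgammaGradA computes d/da igamma(a, x), so another hypothetical probe differentiates IgammaSingle (from igamma_cpu_kernel.cc) numerically. This sketch assumes both functions are visible in one translation unit; step size and points are illustrative.

// Hypothetical gradient check: IgammaGradASingle(a, x) vs a central difference of IgammaSingle.
#include <cmath>
#include <cstdio>
void CheckIgammaGradA() {
  const double h = 1e-6;
  for (double a : {0.5, 2.0, 8.0}) {
    for (double x : {0.5, 3.0, 12.0}) {
      double analytic = IgammaGradASingle<double>(a, x);
      double numeric = (IgammaSingle<double>(a + h, x) - IgammaSingle<double>(a - h, x)) / (2 * h);
      std::printf("a=%4.1f x=%5.2f grad=%.9f fd=%.9f\n", a, x, analytic, numeric);
    }
  }
}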
template <typename T>
void IgammaGradACpuKernelMod::BcastCompute(const std::vector<kernel::AddressPtr> &inputs,
                                           const std::vector<kernel::AddressPtr> &outputs) {
  auto a_data_addr = reinterpret_cast<T *>(inputs[0]->addr);
  auto x_data_addr = reinterpret_cast<T *>(inputs[1]->addr);
  auto z_data_addr = reinterpret_cast<T *>(outputs[0]->addr);
  size_t data_num = get_element_num(z_shape_);
  auto output_shape = CPUKernelUtils::GetBroadcastShape(a_shape_, x_shape_);
  BroadcastIterator iter(a_shape_, x_shape_, output_shape);
  if (data_num < kParallelDataNums) {
    iter.SetPos(0);
    for (size_t i = 0; i < data_num; i++) {
      T *a_index = a_data_addr + iter.GetInputPosA();  // i-th value of input0
      T *x_index = x_data_addr + iter.GetInputPosB();  // i-th value of input1
      *(z_data_addr + i) = IgammaGradASingle<T>(*a_index, *x_index);
      iter.GenNextPos();
    }
  } else {
    auto shard_igammaGradA = [z_data_addr, a_data_addr, x_data_addr, &iter](size_t start, size_t end) {
      // Each shard advances its own copy of the iterator; sharing one iterator
      // across threads would race on its internal position.
      auto iter_shard = iter;
      iter_shard.SetPos(start);
      for (size_t i = start; i < end; i++) {
        T *a_index = a_data_addr + iter_shard.GetInputPosA();  // i-th value of input0
        T *x_index = x_data_addr + iter_shard.GetInputPosB();  // i-th value of input1
        *(z_data_addr + i) = IgammaGradASingle<T>(*a_index, *x_index);
        iter_shard.GenNextPos();
      }
    };
    ParallelLaunchAutoSearch(shard_igammaGradA, data_num, this, &parallel_search_info_);
  }
}
/* SpecialCompute handles the cases that need no broadcasting:
 * 1. the shapes of input1 and input2 are the same (kSameShape)
 * 2. input1 is a 1-D tensor with only one element or a scalar (kXOneElement)
 * 3. input2 is a 1-D tensor with only one element or a scalar (kYOneElement)
 * Inputs whose shapes differ in any other way go through BcastCompute instead.
 */
template <typename T>
void IgammaGradACpuKernelMod::SpecialCompute(int64_t type, int64_t start, int64_t end, const T *input1,
                                             const T *input2, T *output) {
  switch (type) {
    case kSameShape: {
      auto cur_input1 = input1 + start;
      auto cur_input2 = input2 + start;
      for (int64_t i = start; i < end; ++i) {
        *output = IgammaGradASingle<T>(*cur_input1, *cur_input2);
        output = output + 1;
        cur_input1 = cur_input1 + 1;
        cur_input2 = cur_input2 + 1;
      }
      break;
    }
    case kXOneElement: {
      auto cur_input2 = input2 + start;
      for (int64_t i = start; i < end; ++i) {
        *output = IgammaGradASingle<T>(*input1, *cur_input2);
        output = output + 1;
        cur_input2 = cur_input2 + 1;
      }
      break;
    }
    case kYOneElement: {
      auto cur_input1 = input1 + start;
      for (int64_t i = start; i < end; ++i) {
        *output = IgammaGradASingle<T>(*cur_input1, *input2);
        output = output + 1;
        cur_input1 = cur_input1 + 1;
      }
      break;
    }
    default:
      break;
  }
}
template <typename T>
void IgammaGradACpuKernelMod::NoBcastCompute(const std::vector<kernel::AddressPtr> &inputs,
                                             const std::vector<kernel::AddressPtr> &outputs) {
  auto in0 = reinterpret_cast<T *>(inputs[0]->addr);
  auto in1 = reinterpret_cast<T *>(inputs[1]->addr);
  auto out0 = reinterpret_cast<T *>(outputs[0]->addr);
  size_t in0_elements_nums = get_element_num(a_shape_);
  size_t in1_elements_nums = get_element_num(x_shape_);
  size_t data_num = get_element_num(z_shape_);
  int64_t type =
    in0_elements_nums == in1_elements_nums ? kSameShape : (in0_elements_nums == 1 ? kXOneElement : kYOneElement);
  if (data_num < kParallelDataNums) {
    SpecialCompute<T>(type, 0, data_num, in0, in1, out0);
  } else {
    auto shard_igammaGradA = [type, in0, in1, out0, this](int64_t start, int64_t end) {
      SpecialCompute<T>(type, start, end, in0, in1, out0 + start);
    };
    ParallelLaunchAutoSearch(shard_igammaGradA, data_num, this, &parallel_search_info_);
  }
}
void IgammaGradACpuKernelMod::InitKernel(const CNodePtr &kernel_node) {
|
||||
a_shape_ = AnfAlgo::GetInputDeviceShape(kernel_node, 0);
|
||||
x_shape_ = AnfAlgo::GetInputDeviceShape(kernel_node, 1);
|
||||
z_shape_ = AnfAlgo::GetOutputDeviceShape(kernel_node, 0);
|
||||
dtype_ = AnfAlgo::GetInputDeviceDataType(kernel_node, 0);
|
||||
kernel_name_ = common::AnfAlgo::GetCNodeName(kernel_node);
|
||||
}
|
||||
|
||||
bool IgammaGradACpuKernelMod::Launch(const std::vector<kernel::AddressPtr> &inputs,
|
||||
const std::vector<kernel::AddressPtr> &,
|
||||
const std::vector<kernel::AddressPtr> &outputs) {
|
||||
if (dtype_ == kNumberTypeFloat32) {
|
||||
LaunchKernel<float>(inputs, outputs);
|
||||
} else if (dtype_ == kNumberTypeFloat64) {
|
||||
LaunchKernel<double>(inputs, outputs);
|
||||
} else {
|
||||
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the dtype of 'var' should be float32 or float64, but got "
|
||||
<< TypeIdToType(dtype_)->ToString();
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void IgammaGradACpuKernelMod::LaunchKernel(const std::vector<kernel::AddressPtr> &inputs,
|
||||
const std::vector<kernel::AddressPtr> &outputs) {
|
||||
CHECK_KERNEL_INPUTS_NUM(inputs.size(), kInputNum, kernel_name_);
|
||||
CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kOutputNum, kernel_name_);
|
||||
size_t in0_elements_nums = get_element_num(a_shape_);
|
||||
size_t in1_elements_nums = get_element_num(x_shape_);
|
||||
bool isNeedBcast = (a_shape_ == x_shape_) || (in0_elements_nums == 1) || (in1_elements_nums == 1);
|
||||
if (isNeedBcast) {
|
||||
NoBcastCompute<T>(inputs, outputs);
|
||||
} else {
|
||||
BcastCompute<T>(inputs, outputs);
|
||||
}
|
||||
}
|
||||
|
||||
MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, IgammaGradA, IgammaGradACpuKernelMod);
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
|
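For orientation, a rough Python rendering of the dispatch above (illustrative only, not part of the patch): the element counts pick one of the three special cases, and `igammagrada_single` is a hypothetical stand-in for the templated `IgammaGradASingle<T>`, approximated here with a SciPy finite difference.

import numpy as np
from scipy.special import gammainc  # regularized lower incomplete gamma P(a, x)

def igammagrada_single(a, x, eps=1e-6):
    # central finite difference of P(a, x) with respect to a
    return (gammainc(a + eps, x) - gammainc(a - eps, x)) / (2.0 * eps)

def no_bcast_compute(a, x):
    if a.size == x.size:                               # kSameShape: walk both in lockstep
        pairs = zip(a.ravel(), x.ravel())
    elif a.size == 1:                                  # kXOneElement: hold the single a fixed
        pairs = ((a.item(), xv) for xv in x.ravel())
    else:                                              # kYOneElement: hold the single x fixed
        pairs = ((av, x.item()) for av in a.ravel())
    return np.array([igammagrada_single(av, xv) for av, xv in pairs])

print(no_bcast_compute(np.array([2.0, 4.0]), np.array([2.0, 3.0])))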
@ -0,0 +1,69 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_IGAMMAGRADA_CPU_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_IGAMMAGRADA_CPU_KERNEL_H_

#include <vector>
#include <array>
#include <cmath>
#include <iostream>
#include <map>
#include <string>
#include <tuple>
#include "plugin/device/cpu/kernel/cpu_kernel.h"
#include "plugin/factory/ms_factory.h"

namespace mindspore {
namespace kernel {
class IgammaGradACpuKernelMod : public DeprecatedNativeCpuKernelMod {
 public:
  IgammaGradACpuKernelMod() = default;
  ~IgammaGradACpuKernelMod() override = default;

  void InitKernel(const CNodePtr &kernel_node) override;

  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
              const std::vector<AddressPtr> &outputs) override;

 protected:
  std::vector<KernelAttr> GetOpSupport() override {
    static std::vector<KernelAttr> support_list = {
      KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
      KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64)};
    return support_list;
  }

 private:
  std::vector<size_t> a_shape_;
  std::vector<size_t> x_shape_;
  std::vector<size_t> z_shape_;
  TypeId dtype_{kTypeUnknown};
  template <typename T>
  void LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs);

  template <typename T>
  void BcastCompute(const std::vector<AddressPtr> &, const std::vector<AddressPtr> &);

  template <typename T>
  void SpecialCompute(int64_t, int64_t, int64_t, const T *, const T *, T *);

  template <typename T>
  void NoBcastCompute(const std::vector<AddressPtr> &, const std::vector<AddressPtr> &);
};
}  // namespace kernel
}  // namespace mindspore

#endif  // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_IGAMMAGRADA_CPU_KERNEL_H_
@ -818,6 +818,9 @@ GVAR_DEF(PrimitivePtr, kPrimEinsumGrad, std::make_shared<Primitive>("EinsumGrad"
GVAR_DEF(PrimitivePtr, kPrimTrace, std::make_shared<Primitive>("Trace"));
GVAR_DEF(PrimitivePtr, kPrimTraceGrad, std::make_shared<Primitive>("TraceGrad"));
GVAR_DEF(PrimitivePtr, kPrimZeta, std::make_shared<Primitive>("Zeta"));
GVAR_DEF(PrimitivePtr, kPrimIgamma, std::make_shared<Primitive>("Igamma"));
GVAR_DEF(PrimitivePtr, kPrimIgammac, std::make_shared<Primitive>("Igammac"));
GVAR_DEF(PrimitivePtr, kPrimIgammaGradA, std::make_shared<Primitive>("IgammaGradA"));

// Image
GVAR_DEF(PrimitivePtr, kPrimNonMaxSuppressionV3, std::make_shared<Primitive>("NonMaxSuppressionV3"));
@ -0,0 +1,61 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "ops/grad/igammagrada.h"
#include <string>
#include <set>
#include <map>
#include "abstract/ops/primitive_infer_map.h"
#include "ops/op_utils.h"
#include "utils/check_convert_utils.h"
#include "utils/tensor_construct_utils.h"
#include "mindapi/src/helper.h"

namespace mindspore {
namespace ops {
namespace {
abstract::ShapePtr IgammaGradAInferShape(const PrimitivePtr &primitive,
                                         const std::vector<AbstractBasePtr> &input_args) {
  auto prim_name = primitive->name();
  return BroadCastInferShape(prim_name, input_args);
}

TypePtr IgammaGradAInferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
  auto prim_name = primitive->name();
  auto a_type = input_args[kInputIndex0]->BuildType();
  auto x_type = input_args[kInputIndex1]->BuildType();
  const std::set<TypePtr> valid_types = {kFloat32, kFloat64};
  std::map<std::string, TypePtr> args;
  (void)args.insert({"a", a_type});
  (void)args.insert({"x", x_type});
  return CheckAndConvertUtils::CheckTensorTypeSame(args, valid_types, prim_name);
}
}  // namespace

MIND_API_OPERATOR_IMPL(IgammaGradA, BaseOperator);
AbstractBasePtr IgammaGradAInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                 const std::vector<AbstractBasePtr> &input_args) {
  auto prim_name = primitive->name();
  const int64_t kInputNum = 2;
  (void)CheckAndConvertUtils::CheckInputArgs(input_args, kGreaterEqual, kInputNum, prim_name);
  auto infer_type = IgammaGradAInferType(primitive, input_args);
  auto infer_shape = IgammaGradAInferShape(primitive, input_args);
  return abstract::MakeAbstract(infer_shape, infer_type);
}

REGISTER_PRIMITIVE_EVAL_IMPL(IgammaGradA, prim::kPrimIgammaGradA, IgammaGradAInfer, nullptr, true);
}  // namespace ops
}  // namespace mindspore
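The shape half of this infer (here and in the Igamma/Igammac infer routines below) delegates to BroadCastInferShape, i.e. NumPy-style broadcasting of the two input shapes, while the type half requires a shared float32/float64 dtype. A minimal illustration of the broadcast rule (illustrative only; assumes NumPy >= 1.20 for np.broadcast_shapes):

import numpy as np

# The output shape is the broadcast of the two input shapes.
print(np.broadcast_shapes((4, 1), (1, 3)))  # (4, 3)
print(np.broadcast_shapes((4,), ()))        # (4,) -- a scalar broadcasts against anything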
@ -0,0 +1,42 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CORE_OPS_IGAMMAGRADA_H
#define MINDSPORE_CORE_OPS_IGAMMAGRADA_H

#include <vector>
#include <memory>

#include "ops/base_operator.h"
#include "mindapi/base/types.h"

namespace mindspore {
namespace ops {
constexpr auto kNameIgammaGradA = "IgammaGradA";
/// \brief Computes the gradient of igamma(a, x) with respect to a.
class MIND_API IgammaGradA : public BaseOperator {
 public:
  MIND_API_BASE_MEMBER(IgammaGradA);
  /// \brief Constructor.
  IgammaGradA() : BaseOperator(kNameIgammaGradA) { InitIOName({"a", "x"}, {"z"}); }
};

abstract::AbstractBasePtr IgammaGradAInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                           const std::vector<abstract::AbstractBasePtr> &input_args);
using kPrimIgammaGradAPtr = std::shared_ptr<IgammaGradA>;
}  // namespace ops
}  // namespace mindspore
#endif  // MINDSPORE_CORE_OPS_IGAMMAGRADA_H
@ -0,0 +1,59 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "ops/igamma.h"
#include <map>
#include <set>
#include <string>
#include "abstract/ops/primitive_infer_map.h"
#include "ops/op_utils.h"
#include "utils/check_convert_utils.h"
#include "utils/tensor_construct_utils.h"
#include "mindapi/src/helper.h"

namespace mindspore {
namespace ops {
namespace {
abstract::ShapePtr IgammaInferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
  auto prim_name = primitive->name();
  return BroadCastInferShape(prim_name, input_args);
}

TypePtr IgammaInferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
  auto prim_name = primitive->name();
  auto a_type = input_args[kInputIndex0]->BuildType();
  auto x_type = input_args[kInputIndex1]->BuildType();
  const std::set<TypePtr> valid_types = {kFloat32, kFloat64};
  std::map<std::string, TypePtr> args;
  (void)args.insert({"a", a_type});
  (void)args.insert({"x", x_type});
  return CheckAndConvertUtils::CheckTensorTypeSame(args, valid_types, prim_name);
}
}  // namespace

MIND_API_OPERATOR_IMPL(Igamma, BaseOperator);
AbstractBasePtr IgammaInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
                            const std::vector<AbstractBasePtr> &input_args) {
  auto prim_name = primitive->name();
  const int64_t kInputNum = 2;
  (void)CheckAndConvertUtils::CheckInputArgs(input_args, kGreaterEqual, kInputNum, prim_name);
  auto infer_type = IgammaInferType(primitive, input_args);
  auto infer_shape = IgammaInferShape(primitive, input_args);
  return abstract::MakeAbstract(infer_shape, infer_type);
}

REGISTER_PRIMITIVE_EVAL_IMPL(Igamma, prim::kPrimIgamma, IgammaInfer, nullptr, true);
}  // namespace ops
}  // namespace mindspore
@ -0,0 +1,43 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CORE_OPS_IGAMMA_H
#define MINDSPORE_CORE_OPS_IGAMMA_H

#include <vector>
#include <memory>

#include "ops/base_operator.h"
#include "mindapi/base/types.h"

namespace mindspore {
namespace ops {
constexpr auto kNameIgamma = "Igamma";
/// \brief Calculates the lower regularized incomplete Gamma function.
/// Refer to Python API @ref mindspore.ops.Igamma for more details.
class MIND_API Igamma : public BaseOperator {
 public:
  MIND_API_BASE_MEMBER(Igamma);
  /// \brief Constructor.
  Igamma() : BaseOperator(kNameIgamma) { InitIOName({"a", "x"}, {"z"}); }
};

abstract::AbstractBasePtr IgammaInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                      const std::vector<abstract::AbstractBasePtr> &input_args);
using kPrimIgammaPtr = std::shared_ptr<Igamma>;
}  // namespace ops
}  // namespace mindspore
#endif  // MINDSPORE_CORE_OPS_IGAMMA_H
@ -0,0 +1,60 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "ops/igammac.h"
#include <string>
#include <set>
#include <map>
#include "abstract/ops/primitive_infer_map.h"
#include "ops/op_utils.h"
#include "utils/check_convert_utils.h"
#include "utils/tensor_construct_utils.h"
#include "mindapi/src/helper.h"

namespace mindspore {
namespace ops {
namespace {
abstract::ShapePtr IgammacInferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
  auto prim_name = primitive->name();
  return BroadCastInferShape(prim_name, input_args);
}

TypePtr IgammacInferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
  auto prim_name = primitive->name();
  auto a_type = input_args[kInputIndex0]->BuildType();
  auto x_type = input_args[kInputIndex1]->BuildType();
  const std::set<TypePtr> valid_types = {kFloat32, kFloat64};
  std::map<std::string, TypePtr> args;
  (void)args.insert({"a", a_type});
  (void)args.insert({"x", x_type});
  return CheckAndConvertUtils::CheckTensorTypeSame(args, valid_types, prim_name);
}
}  // namespace

MIND_API_OPERATOR_IMPL(Igammac, BaseOperator);
AbstractBasePtr IgammacInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
                             const std::vector<AbstractBasePtr> &input_args) {
  auto prim_name = primitive->name();
  const int64_t kInputNum = 2;
  (void)CheckAndConvertUtils::CheckInputArgs(input_args, kGreaterEqual, kInputNum, prim_name);
  auto infer_type = IgammacInferType(primitive, input_args);
  auto infer_shape = IgammacInferShape(primitive, input_args);
  return abstract::MakeAbstract(infer_shape, infer_type);
}

REGISTER_PRIMITIVE_EVAL_IMPL(Igammac, prim::kPrimIgammac, IgammacInfer, nullptr, true);
}  // namespace ops
}  // namespace mindspore
@ -0,0 +1,43 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CORE_OPS_IGAMMAC_H
#define MINDSPORE_CORE_OPS_IGAMMAC_H

#include <vector>
#include <memory>

#include "ops/base_operator.h"
#include "mindapi/base/types.h"

namespace mindspore {
namespace ops {
constexpr auto kNameIgammac = "Igammac";
/// \brief Compute the upper regularized incomplete Gamma function Q(a, x).
/// Refer to Python API @ref mindspore.ops.Igammac for more details.
class MIND_API Igammac : public BaseOperator {
 public:
  MIND_API_BASE_MEMBER(Igammac);
  /// \brief Constructor.
  Igammac() : BaseOperator(kNameIgammac) { InitIOName({"a", "x"}, {"z"}); }
};

abstract::AbstractBasePtr IgammacInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                       const std::vector<abstract::AbstractBasePtr> &input_args);
using kPrimIgammacPtr = std::shared_ptr<Igammac>;
}  // namespace ops
}  // namespace mindspore
#endif  // MINDSPORE_CORE_OPS_IGAMMAC_H
@ -19,14 +19,17 @@ from mindspore.common import dtype as mstype
from mindspore import nn
import mindspore.numpy as mnp
import numpy as np
from ...nn.layer import math
from .. import functional as F
from .. import operations as P
from ..operations.math_ops import Trace
from ..functional import broadcast_gradient_args
from .._grad.grad_base import bprop_getters
from .._grad.grad_math_ops import binop_grad_common
from ..composite.multitype_ops.zeros_like_impl import zeros_like
from ..operations import _grad_ops as G
from ..operations import math_ops as math
from ..operations.math_ops import Igamma, Igammac
from ..primitive import constexpr
from ..operations.math_ops import ReduceStd
@ -512,3 +515,51 @@ def get_bprop_trace(self):
        return (dx,)

    return bprop


@bprop_getters.register(Igamma)
def get_bprop_igamma(self):
    """Grad definition for `Igamma` operation."""
    shape_ = P.Shape()
    igammagrada = G.IgammaGradA()
    lgamma = math.LGamma()
    log_ = P.Log()
    exp_ = P.Exp()
    reshape_ = P.Reshape()
    reduce_sum_ = P.ReduceSum()

    def bprop(a, x, out, dout):
        sa = shape_(a)
        sx = shape_(x)
        ra, rx = broadcast_gradient_args(sa, sx)
        partial_a = igammagrada(a, x)
        partial_x = exp_(-x + (a - 1) * log_(x) - lgamma(a))
        if ra != () or rx != ():
            return reshape_(reduce_sum_(partial_a * dout, ra), sa), reshape_(reduce_sum_(partial_x * dout, rx), sx)
        return reshape_(partial_a * dout, sa), reshape_(partial_x * dout, sx)

    return bprop


@bprop_getters.register(Igammac)
def get_bprop_igammac(self):
    """Grad definition for `Igammac` operation."""
    shape_ = P.Shape()
    igammagrada = G.IgammaGradA()
    lgamma = math.LGamma()
    log_ = P.Log()
    exp_ = P.Exp()
    reshape_ = P.Reshape()
    reduce_sum_ = P.ReduceSum()
    neg_ = P.Neg()

    def bprop(a, x, out, dout):
        sa = shape_(a)
        sx = shape_(x)
        ra, rx = broadcast_gradient_args(sa, sx)
        partial_a = igammagrada(a, x)
        partial_x = exp_(-x + (a - 1) * log_(x) - lgamma(a))
        if ra != () or rx != ():
            return neg_(reshape_(reduce_sum_(partial_a * dout, ra), sa)), \
                   neg_(reshape_(reduce_sum_(partial_x * dout, rx), sx))
        return neg_(reshape_(partial_a * dout, sa)), neg_(reshape_(partial_x * dout, sx))

    return bprop
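Both bprops use the same closed form for the gradient with respect to x: partial_x = exp(-x + (a - 1) * log(x) - lgamma(a)), i.e. d/dx P(a, x) = x^{a-1} e^{-x} / Gamma(a); the Igammac version merely negates both partials, since Q(a, x) = 1 - P(a, x). A SciPy finite-difference cross-check of that closed form (illustrative only, not part of the patch; assumes NumPy and SciPy are available):

import numpy as np
from scipy.special import gammainc, gammaln

a = np.array([2.0, 4.0, 6.0, 8.0])
x = np.array([2.0, 3.0, 4.0, 5.0])
eps = 1e-6

partial_x = np.exp(-x + (a - 1.0) * np.log(x) - gammaln(a))       # analytic form used in bprop
fd = (gammainc(a, x + eps) - gammainc(a, x - eps)) / (2.0 * eps)  # numeric d/dx P(a, x)
print(np.max(np.abs(partial_x - fd)))  # expected to be tiny, roughly 1e-9 or smaller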
@ -49,6 +49,9 @@ from .asin_grad import _asin_grad_aicpu
from .is_finite import _is_finite_aicpu
from .is_inf import _is_inf_aicpu
from .is_nan import _is_nan_aicpu
from .igamma import _igamma_aicpu
from .igammac import _igammac_aicpu
from .igammagrada import _igammagrada_aicpu
from .reshape import _reshape_aicpu
from .fill_v2 import _fill_v2_aicpu
from .flatten import _flatten_aicpu
@ -0,0 +1,30 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Igamma op"""
from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType

igamma_op_info = AiCPURegOp("Igamma") \
    .fusion_type("OPAQUE") \
    .input(0, "a", "required") \
    .input(1, "x", "required") \
    .output(0, "z", "required") \
    .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
    .dtype_format(DataType.F64_Default, DataType.F64_Default, DataType.F64_Default) \
    .get_op_info()


@op_info_register(igamma_op_info)
def _igamma_aicpu():
    """Igamma aicpu register"""
    return
@ -0,0 +1,30 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Igammac op"""
from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType

igammac_op_info = AiCPURegOp("Igammac") \
    .fusion_type("OPAQUE") \
    .input(0, "a", "required") \
    .input(1, "x", "required") \
    .output(0, "z", "required") \
    .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
    .dtype_format(DataType.F64_Default, DataType.F64_Default, DataType.F64_Default) \
    .get_op_info()


@op_info_register(igammac_op_info)
def _igammac_aicpu():
    """Igammac aicpu register"""
    return
@ -0,0 +1,30 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""IgammaGradA op"""
from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType

igammagrada_op_info = AiCPURegOp("IgammaGradA") \
    .fusion_type("OPAQUE") \
    .input(0, "a", "required") \
    .input(1, "x", "required") \
    .output(0, "z", "required") \
    .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
    .dtype_format(DataType.F64_Default, DataType.F64_Default, DataType.F64_Default) \
    .get_op_info()


@op_info_register(igammagrada_op_info)
def _igammagrada_aicpu():
    """IgammaGradA aicpu register"""
    return
@ -2527,3 +2527,39 @@ class TraceGrad(Primitive):
    @prim_attr_register
    def __init__(self):
        pass


class IgammaGradA(Primitive):
    r"""
    Computes the gradient of igamma(a, x) with respect to `a`.

    Inputs:
        - **a** (Tensor) - The input tensor, with float32 or float64 data type.
        - **x** (Tensor) - The input tensor, with float32 or float64 data type. `x` should have
          the same dtype as `a`.

    Outputs:
        Tensor, has the same dtype as `a` and `x`.

    Raises:
        TypeError: If `a` or `x` is not a Tensor.
        TypeError: If the dtype of `a` and `x` is neither float32 nor float64.
        TypeError: If `x` has a different dtype from `a`.
        ValueError: If `a` could not be broadcast to a tensor with the shape of `x`.

    Supported Platforms:
        ``Ascend`` ``CPU``

    Examples:
        >>> a = Tensor(np.array([2.0, 4.0, 6.0, 8.0]).astype(np.float32))
        >>> x = Tensor(np.array([2.0, 3.0, 4.0, 5.0]).astype(np.float32))
        >>> igammagrada = G.IgammaGradA()
        >>> output = igammagrada(a, x)
        >>> print(output)
        [-0.2940046 -0.20153049 -0.13028376 -0.08352186]
    """

    @prim_attr_register
    def __init__(self):
        """Initialize IgammaGradA"""
        self.init_prim_io_names(inputs=['a', 'x'], outputs=['z'])
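The example values can be cross-checked against a central finite difference of SciPy's regularized lower incomplete gamma function (illustrative only, not part of the patch; assumes NumPy and SciPy are available):

import numpy as np
from scipy.special import gammainc  # regularized lower incomplete gamma P(a, x)

a = np.array([2.0, 4.0, 6.0, 8.0])
x = np.array([2.0, 3.0, 4.0, 5.0])
eps = 1e-6
# numeric d/da P(a, x), the quantity IgammaGradA computes
print((gammainc(a + eps, x) - gammainc(a - eps, x)) / (2.0 * eps))
# roughly [-0.2940046 -0.20153049 -0.13028376 -0.08352186], matching the docstring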
@ -5380,6 +5380,106 @@ class Trunc(Primitive):
        """Initialize Trunc"""


class Igamma(Primitive):
    r"""
    Calculates the lower regularized incomplete Gamma function.

    The lower regularized incomplete Gamma function is defined as:

    .. math::
        P(a, x) = \gamma(a, x) / \Gamma(a) = 1 - Q(a, x)

    where

    .. math::
        \gamma(a, x) = \int_0^x t^{a-1} e^{-t} dt

    is the lower incomplete Gamma function.

    Above, :math:`Q(a, x)` is the upper regularized incomplete Gamma function.

    .. warning::
        This is an experimental prototype that is subject to change and/or deletion.

    Inputs:
        - **a** (Tensor) - The input tensor, with float32 or float64 data type.
        - **x** (Tensor) - The input tensor, with float32 or float64 data type. `x` should have
          the same dtype as `a`.

    Outputs:
        Tensor, has the same dtype as `a` and `x`.

    Raises:
        TypeError: If `a` or `x` is not a Tensor.
        TypeError: If the dtype of `a` and `x` is neither float32 nor float64.
        TypeError: If `x` has a different dtype from `a`.
        ValueError: If `a` could not be broadcast to a tensor with the shape of `x`.

    Supported Platforms:
        ``Ascend`` ``CPU``

    Examples:
        >>> a = Tensor(np.array([2.0, 4.0, 6.0, 8.0]).astype(np.float32))
        >>> x = Tensor(np.array([2.0, 3.0, 4.0, 5.0]).astype(np.float32))
        >>> igamma = Igamma()
        >>> output = igamma(a, x)
        >>> print(output)
        [0.593994 0.35276785 0.21486944 0.13337152]
    """

    @prim_attr_register
    def __init__(self):
        """Initialize Igamma"""
        self.init_prim_io_names(inputs=['a', 'x'], outputs=['z'])


class Igammac(Primitive):
    r"""
    Computes the upper regularized incomplete Gamma function Q(a, x).

    The upper regularized incomplete Gamma function is defined as:

    .. math::
        Q(a, x) = \Gamma(a, x) / \Gamma(a) = 1 - P(a, x)

    where

    .. math::
        \Gamma(a, x) = \int_{x}^{\infty} t^{a-1} e^{-t} dt

    is the upper incomplete Gamma function.

    Above, :math:`P(a, x)` (Igamma) is the lower regularized incomplete Gamma function.

    .. warning::
        This is an experimental prototype that is subject to change and/or deletion.

    Inputs:
        - **a** (Tensor) - The input tensor of igammac, with float32 or float64 data type.
        - **x** (Tensor) - The input tensor of igammac, with float32 or float64 data type. `x` should have
          the same dtype as `a`.

    Outputs:
        Tensor, has the same dtype as `a` and `x`.

    Raises:
        TypeError: If the dtype of `a` and `x` is neither float32 nor float64.
        TypeError: If `a` or `x` is not a Tensor.
        TypeError: If `x` has a different dtype from `a`.
        ValueError: If `a` could not be broadcast to a tensor with the shape of `x`.

    Supported Platforms:
        ``Ascend`` ``CPU``

    Examples:
        >>> a = Tensor(np.array([2.0, 4.0, 6.0, 8.0]).astype(np.float32))
        >>> x = Tensor(np.array([2.0, 3.0, 4.0, 5.0]).astype(np.float32))
        >>> igammac = Igammac()
        >>> output = igammac(a, x)
        >>> print(output)
        [0.40600586 0.6472318 0.7851304 0.8666283 ]
    """

    @prim_attr_register
    def __init__(self):
        """Initialize Igammac"""
        self.init_prim_io_names(inputs=['a', 'x'], outputs=['z'])
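The two example outputs are mutually consistent: since Q(a, x) = 1 - P(a, x), they must sum to one elementwise. A quick SciPy check (illustrative only, not part of the patch; assumes SciPy is available):

import numpy as np
from scipy.special import gammainc, gammaincc  # P(a, x) and Q(a, x)

a = np.array([2.0, 4.0, 6.0, 8.0])
x = np.array([2.0, 3.0, 4.0, 5.0])
print(gammainc(a, x))                    # ~[0.593994 0.35276785 0.21486944 0.13337152]
print(gammaincc(a, x))                   # ~[0.40600586 0.6472318 0.7851304 0.8666283]
print(gammainc(a, x) + gammaincc(a, x))  # all ones: P(a, x) + Q(a, x) = 1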


class IsClose(Primitive):
    r"""
    Returns a boolean tensor where two tensors are element-wise equal within a tolerance.
@ -25,8 +25,9 @@ from mindspore.common import dtype as mstype
from mindspore.ops import composite as C
from mindspore.ops import operations as P
from mindspore.ops import functional as F
from mindspore.ops.operations._grad_ops import IgammaGradA
from mindspore.ops import prim_attr_register, PrimitiveWithInfer
from mindspore.ops.operations.math_ops import Zeta
from mindspore.ops.operations.math_ops import Zeta, Igamma, Igammac
from ..ut_filter import non_graph_engine
from ....mindspore_test_framework.mindspore_test import mindspore_test
from ....mindspore_test_framework.pipeline.forward.compile_forward \
@ -478,7 +479,25 @@ raise_set = [
    ('Zeta', {
        'block': Zeta(),
        'desc_inputs': [Tensor(np.array([1, 1, 1, 1], np.float32)),
                        Tensor([0.5, 0.5, 0.5, 0.5], mstype.float32)],
                        Tensor([0.5, 0.5, 0.5, 0.5], mstype.float32)]}),
    ('Igamma', {
        'block': Igamma(),
        'desc_inputs': [Tensor(np.array([1.1, 2.2, -4.1], np.float32)),
                        Tensor(np.array([0.2, 1.2, 2.1], np.float32))],
        'desc_bprop': [Tensor(np.array([2, 3], np.float32)),
                       Tensor(np.array([2, 3], np.float32))],
        'skip': ['backward']}),
    ('Igammac', {
        'block': Igammac(),
        'desc_inputs': [Tensor(np.array([1.1, 2.2, -4.1], np.float32)),
                        Tensor(np.array([0.2, 1.2, 2.1], np.float32))],
        'desc_bprop': [Tensor(np.array([2, 3], np.float32)),
                       Tensor(np.array([2, 3], np.float32))],
        'skip': ['backward']}),
    ('IgammaGradA', {
        'block': IgammaGradA(),
        'desc_inputs': [Tensor(np.array([1.1, 2.2, 8.1, 2.1], np.float32)),
                        Tensor(np.array([0.2, 1.2, 2.1, 3.4], np.float32))],
        'skip': ['backward']}),
]