!3982 Add LGamma op
Merge pull request !3982 from peixu_ren/custom_pp_ops
This commit is contained in:
commit
e7df54166c
|
@ -14,16 +14,18 @@
|
|||
# ============================================================================
|
||||
"""math"""
|
||||
import math
|
||||
import numpy as np
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore.ops.operations import _inner_ops as inner
|
||||
from mindspore.common.tensor import Tensor
|
||||
from mindspore.ops.primitive import constexpr
|
||||
from ..cell import Cell
|
||||
from ...common import dtype as mstype
|
||||
from ..._checkparam import Validator as validator
|
||||
from ..._checkparam import Rel
|
||||
|
||||
|
||||
__all__ = ['ReduceLogSumExp', 'Range', 'LinSpace']
|
||||
__all__ = ['ReduceLogSumExp', 'Range', 'LinSpace', 'LGamma']
|
||||
|
||||
|
||||
class ReduceLogSumExp(Cell):
|
||||
|
@ -169,3 +171,134 @@ class LinSpace(Cell):
|
|||
|
||||
lin_space_out = self.lin_space(self.assist, self.start, self.stop, self.num)
|
||||
return lin_space_out
|
||||
|
||||
@constexpr
|
||||
def check_tensors_dtype_same(data_dtype, value_dtype, op_name):
|
||||
"""Check tensors data type same."""
|
||||
if data_dtype in value_dtype:
|
||||
return True
|
||||
raise TypeError(f"For '{op_name}', the value data type '{value_dtype}' "
|
||||
f"is not consistent with assigned tensor data type {data_dtype}.")
|
||||
|
||||
class LGamma(Cell):
|
||||
r"""
|
||||
Calculate LGamma using Lanczos' approximation refering to "A Precision Approximationof the Gamma Function".
|
||||
The algorithm is:
|
||||
|
||||
.. math::
|
||||
lgamma(z + 1) = \frac{(\log(2) + \log(pi))}{2} + (z + 1/2) * log(t(z)) - t(z) + A(z)
|
||||
|
||||
t(z) = z + kLanczosGamma + 1/2
|
||||
|
||||
A(z) = kBaseLanczosCoeff + \sum_{k=1}^n \frac{kLanczosCoefficients[i]}{z + k}
|
||||
|
||||
However, if the input is less than 0.5 use Euler's reflection formula:
|
||||
|
||||
.. math::
|
||||
|
||||
lgamma(x) = \log(pi) - lgamma(1-x) - \log(abs(sin(pi * x)))
|
||||
|
||||
And please note that
|
||||
|
||||
.. math::
|
||||
|
||||
lgamma(+/-inf) = +inf
|
||||
|
||||
Thus, the behaviour of LGamma follows:
|
||||
when x > 0.5, return log(Gamma(x))
|
||||
when x < 0.5 and is not an interger, return the real part of Log(Gamma(x)) where Log is the complex logarithm
|
||||
when x is an integer less or equal to 0, return +inf
|
||||
when x = +/- inf, return +inf
|
||||
|
||||
Inputs:
|
||||
- **input_x** (Tensor[Number]) - The input tensor. Only float16, float32 are supported.
|
||||
|
||||
Outputs:
|
||||
Tensor, has the same shape and dtype as the 'input_x'.
|
||||
|
||||
Examples:
|
||||
>>> input_x = Tensor(np.array(2, 3, 4).astype(np.float32))
|
||||
>>> op = nn.LGamma()
|
||||
>>> output = op(input_x)
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super(LGamma, self).__init__()
|
||||
# const numbers
|
||||
self.k_lanczos_gamma = 7
|
||||
self.k_base_lanczos_coeff = 0.99999999999980993227684700473478
|
||||
self.k_lanczos_coefficients = [676.520368121885098567009190444019,
|
||||
-1259.13921672240287047156078755283,
|
||||
771.3234287776530788486528258894,
|
||||
-176.61502916214059906584551354,
|
||||
12.507343278686904814458936853,
|
||||
-0.13857109526572011689554707,
|
||||
9.984369578019570859563e-6,
|
||||
1.50563273514931155834e-7]
|
||||
self.one_half = 0.5
|
||||
self.one = 1
|
||||
self.two = 2
|
||||
self.inf = np.inf
|
||||
self.pi = np.pi
|
||||
self.log_2 = np.log(self.two)
|
||||
self.log_pi = np.log(np.pi)
|
||||
self.log_sqrt_two_pi = (self.log_2 + self.log_pi) / self.two
|
||||
self.lanczos_gamma_plus_one_half = self.k_lanczos_gamma + 0.5
|
||||
self.log_lanczos_gamma_plus_one_half = np.log(self.lanczos_gamma_plus_one_half)
|
||||
|
||||
# operations
|
||||
self.log = P.Log()
|
||||
self.log1p = P.Log1p()
|
||||
self.abs = P.Abs()
|
||||
self.shape = P.Shape()
|
||||
self.dtype = P.DType()
|
||||
self.fill = P.Fill()
|
||||
self.floor = P.Floor()
|
||||
self.equal = P.Equal()
|
||||
self.greater = P.Greater()
|
||||
self.less = P.Less()
|
||||
self.lessequal = P.LessEqual()
|
||||
self.select = P.Select()
|
||||
self.sin = P.Sin()
|
||||
self.isfinite = P.IsFinite()
|
||||
|
||||
def construct(self, input_x):
|
||||
input_dtype = self.dtype(input_x)
|
||||
check_tensors_dtype_same(input_dtype, [mstype.float16, mstype.float32], "LGamma")
|
||||
infinity = self.fill(input_dtype, self.shape(input_x), self.inf)
|
||||
|
||||
need_to_reflect = self.less(input_x, 0.5)
|
||||
neg_input = -input_x
|
||||
z = self.select(need_to_reflect, neg_input, input_x - 1)
|
||||
|
||||
@constexpr
|
||||
def _calculate_x(z, k_base_lanczos_coeff, k_lanczos_coefficients):
|
||||
x = k_base_lanczos_coeff
|
||||
for i in range(8):
|
||||
product_ = k_lanczos_coefficients[i] / (z + i + 1)
|
||||
x = product_ + x
|
||||
return x
|
||||
x = _calculate_x(z, self.k_base_lanczos_coeff, self.k_lanczos_coefficients)
|
||||
|
||||
t = z + self.lanczos_gamma_plus_one_half
|
||||
log_t = self.log1p(z / self.lanczos_gamma_plus_one_half) + self.log_lanczos_gamma_plus_one_half
|
||||
|
||||
log_y = self.log(x) + (z + self.one_half - t / log_t) * log_t + self.log_sqrt_two_pi
|
||||
|
||||
abs_input = self.abs(input_x)
|
||||
abs_frac_input = abs_input - self.floor(abs_input)
|
||||
input_x = self.select(self.lessequal(input_x, 0.0),
|
||||
self.select(self.equal(abs_frac_input, 0.0),
|
||||
infinity, input_x),
|
||||
input_x)
|
||||
reduced_frac_input = self.select(self.greater(abs_frac_input, 0.5),
|
||||
1 - abs_frac_input, abs_frac_input)
|
||||
reflection_denom = self.log(self.sin(self.pi * reduced_frac_input))
|
||||
|
||||
reflection = self.select(self.isfinite(reflection_denom),
|
||||
-reflection_denom - log_y + self.log_pi,
|
||||
-reflection_denom)
|
||||
|
||||
result = self.select(need_to_reflect, reflection, log_y)
|
||||
|
||||
return self.select(self.isfinite(input_x), result, infinity)
|
||||
|
|
|
@ -584,7 +584,10 @@ test_cases = [
|
|||
('ReduceLogSumExp', {
|
||||
'block': nn.ReduceLogSumExp((0,), False),
|
||||
'desc_inputs': [Tensor(np.array([3, 4, 5, 6]).astype(np.float32))],
|
||||
'desc_bprop': [Tensor(np.array([1, 2, 3, 4]).astype(np.float32))],
|
||||
'skip': ['backward']}),
|
||||
('LGamma', {
|
||||
'block': nn.LGamma(),
|
||||
'desc_inputs': [Tensor(np.array([3, 4, 5, 6]).astype(np.float32))],
|
||||
'skip': ['backward']}),
|
||||
('FlattenNet', {
|
||||
'block': FlattenNet(),
|
||||
|
|
Loading…
Reference in New Issue