diff --git a/mindspore/nn/layer/math.py b/mindspore/nn/layer/math.py
index ddcaf2da6b9..d278437e136 100644
--- a/mindspore/nn/layer/math.py
+++ b/mindspore/nn/layer/math.py
@@ -14,16 +14,18 @@
 # ============================================================================
 """math"""
 import math
+import numpy as np
 from mindspore.ops import operations as P
 from mindspore.ops.operations import _inner_ops as inner
 from mindspore.common.tensor import Tensor
+from mindspore.ops.primitive import constexpr
 from ..cell import Cell
 from ...common import dtype as mstype
 from ..._checkparam import Validator as validator
 from ..._checkparam import Rel
 
 
-__all__ = ['ReduceLogSumExp', 'Range', 'LinSpace']
+__all__ = ['ReduceLogSumExp', 'Range', 'LinSpace', 'LGamma']
 
 
 class ReduceLogSumExp(Cell):
@@ -169,3 +171,134 @@ class LinSpace(Cell):
         lin_space_out = self.lin_space(self.assist, self.start, self.stop, self.num)
 
         return lin_space_out
+
+
+@constexpr
+def check_tensors_dtype_same(data_dtype, value_dtype, op_name):
+    """Check that the tensor data type is one of the expected data types."""
+    if data_dtype in value_dtype:
+        return True
+    raise TypeError(f"For '{op_name}', the value data type '{value_dtype}' "
+                    f"is not consistent with the assigned tensor data type {data_dtype}.")
+
+
+class LGamma(Cell):
+    r"""
+    Calculates LGamma using Lanczos' approximation, referring to "A Precision Approximation of the Gamma Function".
+    The algorithm is:
+
+    .. math::
+        lgamma(z + 1) = \frac{\log(2) + \log(\pi)}{2} + (z + 1/2) * \log(t(z)) - t(z) + A(z)
+
+        t(z) = z + kLanczosGamma + 1/2
+
+        A(z) = kBaseLanczosCoeff + \sum_{i=1}^{n} \frac{kLanczosCoefficients[i]}{z + i}
+
+    However, if the input is less than 0.5, use Euler's reflection formula:
+
+    .. math::
+
+        lgamma(x) = \log(\pi) - lgamma(1 - x) - \log(|\sin(\pi x)|)
+
+    And please note that
+
+    .. math::
+
+        lgamma(\pm \infty) = +\infty
+
+    Thus, the behaviour of LGamma follows:
+
+    - When x > 0.5, return log(Gamma(x)).
+    - When x < 0.5 and x is not an integer, return the real part of Log(Gamma(x)), where Log is the complex logarithm.
+    - When x is an integer less than or equal to 0, return +inf.
+    - When x = +/-inf, return +inf.
+
+    Inputs:
+        - **input_x** (Tensor[Number]) - The input tensor. Only float16 and float32 are supported.
+
+    Outputs:
+        Tensor, with the same shape and dtype as 'input_x'.
+
+    Examples:
+        >>> input_x = Tensor(np.array([2, 3, 4]).astype(np.float32))
+        >>> op = nn.LGamma()
+        >>> output = op(input_x)
+    """
+
+    def __init__(self):
+        super(LGamma, self).__init__()
+        # const numbers
+        self.k_lanczos_gamma = 7
+        self.k_base_lanczos_coeff = 0.99999999999980993227684700473478
+        self.k_lanczos_coefficients = [676.520368121885098567009190444019,
+                                       -1259.13921672240287047156078755283,
+                                       771.3234287776530788486528258894,
+                                       -176.61502916214059906584551354,
+                                       12.507343278686904814458936853,
+                                       -0.13857109526572011689554707,
+                                       9.984369578019570859563e-6,
+                                       1.50563273514931155834e-7]
+        self.one_half = 0.5
+        self.one = 1
+        self.two = 2
+        self.inf = np.inf
+        self.pi = np.pi
+        self.log_2 = np.log(self.two)
+        self.log_pi = np.log(np.pi)
+        self.log_sqrt_two_pi = (self.log_2 + self.log_pi) / self.two
+        self.lanczos_gamma_plus_one_half = self.k_lanczos_gamma + 0.5
+        self.log_lanczos_gamma_plus_one_half = np.log(self.lanczos_gamma_plus_one_half)
+
+        # operations
+        self.log = P.Log()
+        self.log1p = P.Log1p()
+        self.abs = P.Abs()
+        self.shape = P.Shape()
+        self.dtype = P.DType()
+        self.fill = P.Fill()
+        self.floor = P.Floor()
+        self.equal = P.Equal()
+        self.greater = P.Greater()
+        self.less = P.Less()
+        self.lessequal = P.LessEqual()
+        self.select = P.Select()
+        self.sin = P.Sin()
+        self.isfinite = P.IsFinite()
+
+    def construct(self, input_x):
+        input_dtype = self.dtype(input_x)
+        check_tensors_dtype_same(input_dtype, [mstype.float16, mstype.float32], "LGamma")
+        infinity = self.fill(input_dtype, self.shape(input_x), self.inf)
+
+        need_to_reflect = self.less(input_x, 0.5)
+        neg_input = -input_x
+        z = self.select(need_to_reflect, neg_input, input_x - 1)
+
+        # Lanczos series A(z); the Python loop unrolls at graph compile time.
+        x = self.k_base_lanczos_coeff
+        for i in range(8):
+            product_ = self.k_lanczos_coefficients[i] / (z + i + 1)
+            x = product_ + x
+
+        t = z + self.lanczos_gamma_plus_one_half
+        log_t = self.log1p(z / self.lanczos_gamma_plus_one_half) + self.log_lanczos_gamma_plus_one_half
+
+        log_y = self.log(x) + (z + self.one_half - t / log_t) * log_t + self.log_sqrt_two_pi
+
+        abs_input = self.abs(input_x)
+        abs_frac_input = abs_input - self.floor(abs_input)
+        input_x = self.select(self.lessequal(input_x, 0.0),
+                              self.select(self.equal(abs_frac_input, 0.0),
+                                          infinity, input_x),
+                              input_x)
+        reduced_frac_input = self.select(self.greater(abs_frac_input, 0.5),
+                                         1 - abs_frac_input, abs_frac_input)
+        reflection_denom = self.log(self.sin(self.pi * reduced_frac_input))
+
+        reflection = self.select(self.isfinite(reflection_denom),
+                                 -reflection_denom - log_y + self.log_pi,
+                                 -reflection_denom)
+
+        result = self.select(need_to_reflect, reflection, log_y)
+
+        return self.select(self.isfinite(input_x), result, infinity)
diff --git a/tests/ut/python/ops/test_nn_ops.py b/tests/ut/python/ops/test_nn_ops.py
index ed7a8e695e2..9eb1653ed6e 100644
--- a/tests/ut/python/ops/test_nn_ops.py
+++ b/tests/ut/python/ops/test_nn_ops.py
@@ -584,7 +584,10 @@ test_cases = [
     ('ReduceLogSumExp', {
         'block': nn.ReduceLogSumExp((0,), False),
         'desc_inputs': [Tensor(np.array([3, 4, 5, 6]).astype(np.float32))],
-        'desc_bprop': [Tensor(np.array([1, 2, 3, 4]).astype(np.float32))],
+        'skip': ['backward']}),
+    ('LGamma', {
+        'block': nn.LGamma(),
+        'desc_inputs': [Tensor(np.array([3, 4, 5, 6]).astype(np.float32))],
         'skip': ['backward']}),
     ('FlattenNet', {
         'block': FlattenNet(),
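
For reviewers who want to check the numerics without building MindSpore, below is a minimal pure-NumPy sketch of the same Lanczos series and reflection branch, validated against Python's `math.lgamma`. It is not part of the patch: the name `lgamma_reference`, the sample values, and the `1e-6` tolerance are illustrative choices; only the constants and the branch structure mirror the `construct` above.

```python
import math

import numpy as np

# Constants copied from the nn.LGamma Cell in the patch.
K_LANCZOS_GAMMA = 7
K_BASE_LANCZOS_COEFF = 0.99999999999980993227684700473478
K_LANCZOS_COEFFICIENTS = [676.520368121885098567009190444019,
                          -1259.13921672240287047156078755283,
                          771.3234287776530788486528258894,
                          -176.61502916214059906584551354,
                          12.507343278686904814458936853,
                          -0.13857109526572011689554707,
                          9.984369578019570859563e-6,
                          1.50563273514931155834e-7]
LOG_SQRT_TWO_PI = (np.log(2) + np.log(np.pi)) / 2


def lgamma_reference(x):
    """Scalar Lanczos lgamma following the same branches as the Cell."""
    if not np.isfinite(x):
        return np.inf                    # lgamma(+/-inf) = +inf
    if x <= 0 and x == np.floor(x):
        return np.inf                    # poles at non-positive integers
    need_to_reflect = x < 0.5
    z = -x if need_to_reflect else x - 1
    # A(z): base coefficient plus the eight Lanczos correction terms.
    a = K_BASE_LANCZOS_COEFF
    for i in range(8):
        a += K_LANCZOS_COEFFICIENTS[i] / (z + i + 1)
    t = z + K_LANCZOS_GAMMA + 0.5
    # log(t) via log1p, as in construct, for accuracy near z = 0.
    log_t = np.log1p(z / (K_LANCZOS_GAMMA + 0.5)) + np.log(K_LANCZOS_GAMMA + 0.5)
    log_y = np.log(a) + (z + 0.5 - t / log_t) * log_t + LOG_SQRT_TWO_PI
    if not need_to_reflect:
        return log_y                     # log_y = lgamma(z + 1) = lgamma(x)
    # Euler reflection: lgamma(x) = log(pi) - log|sin(pi*x)| - lgamma(1 - x),
    # with the sine argument reduced to [0, 0.5] for accuracy.
    frac = abs(x) - np.floor(abs(x))
    reduced = 1 - frac if frac > 0.5 else frac
    return np.log(np.pi) - np.log(np.sin(np.pi * reduced)) - log_y


for v in [0.3, 2.5, 7.0, -1.5]:
    assert abs(lgamma_reference(v) - math.lgamma(v)) < 1e-6, v
```

Note the `(z + 1/2 - t / log_t) * log_t` grouping: it is an algebraic rearrangement of the docstring's `(z + 1/2) * log(t(z)) - t(z)`, so the reference and the kernel compute the same quantity.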