forked from mindspore-Ecosystem/mindspore
Add IGamma operator
This commit is contained in:
parent
52eb2d3401
commit
5e9178c5b6
|
@ -25,7 +25,12 @@ from ...common import dtype as mstype
|
|||
from ..._checkparam import Validator as validator
|
||||
|
||||
|
||||
__all__ = ['ReduceLogSumExp', 'Range', 'LinSpace', 'LGamma', 'MatMul', 'Moments']
|
||||
__all__ = ['ReduceLogSumExp', 'Range', 'LinSpace', 'LGamma', 'IGamma', 'MatMul', 'Moments']
|
||||
|
||||
|
||||
@constexpr
|
||||
def _check_input_dtype(param_name, input_dtype, allow_dtypes, cls_name):
|
||||
validator.check_type_name(param_name, input_dtype, allow_dtypes, cls_name)
|
||||
|
||||
|
||||
class ReduceLogSumExp(Cell):
|
||||
|
@ -43,7 +48,7 @@ class ReduceLogSumExp(Cell):
|
|||
Default : False.
|
||||
|
||||
Inputs:
|
||||
- **input_x** (Tensor[Number]) - The input tensor. With float16 or float32 data type.
|
||||
- **input_x** (Tensor) - The input tensor. With float16 or float32 data type.
|
||||
|
||||
Outputs:
|
||||
Tensor, has the same dtype as the `input_x`.
|
||||
|
@ -213,7 +218,7 @@ class LGamma(Cell):
|
|||
when x = +/- inf, return +inf
|
||||
|
||||
Inputs:
|
||||
- **input_x** (Tensor[Number]) - The input tensor. Only float16, float32 are supported.
|
||||
- **input_x** (Tensor) - The input tensor. Only float16, float32 are supported.
|
||||
|
||||
Outputs:
|
||||
Tensor, has the same shape and dtype as the `input_x`.
|
||||
|
@ -267,7 +272,7 @@ class LGamma(Cell):
|
|||
|
||||
def construct(self, input_x):
|
||||
input_dtype = self.dtype(input_x)
|
||||
check_tensors_dtype_same(input_dtype, [mstype.float16, mstype.float32], "LGamma")
|
||||
_check_input_dtype("input", input_dtype, [mstype.float16, mstype.float32], self.cls_name)
|
||||
infinity = self.fill(input_dtype, self.shape(input_x), self.inf)
|
||||
|
||||
need_to_reflect = self.less(input_x, 0.5)
|
||||
|
@ -307,6 +312,260 @@ class LGamma(Cell):
|
|||
return self.select(self.isfinite(input_x), result, infinity)
|
||||
|
||||
|
||||
eps_fp16 = Tensor(np.finfo(np.float16).eps, mstype.float16)
|
||||
eps_fp32 = Tensor(np.finfo(np.float32).eps, mstype.float32)
|
||||
|
||||
def _while_helper_func(cond, body, vals):
|
||||
while cond(vals).any():
|
||||
vals = body(vals)
|
||||
return vals
|
||||
|
||||
|
||||
def _IgammaSeries(ax, x, a, enabled):
|
||||
"""Helper function for computing Igamma using a power series."""
|
||||
|
||||
logicaland = P.LogicalAnd()
|
||||
greater = P.Greater()
|
||||
fill = P.Fill()
|
||||
shape = P.Shape()
|
||||
dtype = P.DType()
|
||||
select = P.Select()
|
||||
|
||||
if dtype(ax) == mstype.float16:
|
||||
epsilon = eps_fp16
|
||||
else:
|
||||
epsilon = eps_fp32
|
||||
|
||||
def cond(vals):
|
||||
enabled = vals[0]
|
||||
return enabled
|
||||
|
||||
def body(vals):
|
||||
enabled = vals[0]
|
||||
r = vals[1]
|
||||
c = vals[2]
|
||||
ans = vals[3]
|
||||
x = vals[4]
|
||||
dc_da = vals[5]
|
||||
dans_da = vals[6]
|
||||
|
||||
r = r + 1
|
||||
dc_da = dc_da * (x / r) + (-1 * c * x) / (r * r)
|
||||
dans_da = dans_da + dc_da
|
||||
c = c * (x / r)
|
||||
ans = ans + c
|
||||
conditional = logicaland(enabled, greater(c / ans, epsilon))
|
||||
|
||||
return (conditional, select(enabled, r, vals[1]),
|
||||
select(enabled, c, vals[2]), select(enabled, ans, vals[3]),
|
||||
select(enabled, x, vals[4]), select(enabled, dc_da, vals[5]),
|
||||
select(enabled, dans_da, vals[6]))
|
||||
|
||||
ones = fill(dtype(a), shape(a), 1)
|
||||
zeros = fill(dtype(a), shape(a), 0)
|
||||
vals = (enabled, a, ones, ones, x, zeros, zeros)
|
||||
|
||||
vals = _while_helper_func(cond, body, vals)
|
||||
ans = vals[3]
|
||||
return (ans * ax) / a
|
||||
|
||||
|
||||
def _IgammacContinuedFraction(ax, x, a, enabled):
|
||||
"""Helper function for computing Igammac using a continued fraction."""
|
||||
|
||||
abs_x = P.Abs()
|
||||
logicaland = P.LogicalAnd()
|
||||
greater = P.Greater()
|
||||
less = P.Less()
|
||||
notequal = P.NotEqual()
|
||||
fill = P.Fill()
|
||||
shape = P.Shape()
|
||||
dtype = P.DType()
|
||||
select = P.Select()
|
||||
|
||||
if dtype(ax) == mstype.float16:
|
||||
epsilon = eps_fp16
|
||||
else:
|
||||
epsilon = eps_fp32
|
||||
|
||||
def cond(vals):
|
||||
enabled = vals[0]
|
||||
c = vals[5]
|
||||
return logicaland(less(c, 2000), enabled)
|
||||
|
||||
def body(vals):
|
||||
enabled = vals[0]
|
||||
ans = vals[1]
|
||||
t = vals[2]
|
||||
y = vals[3]
|
||||
z = vals[4]
|
||||
c = vals[5]
|
||||
pkm1 = vals[6]
|
||||
qkm1 = vals[7]
|
||||
pkm2 = vals[8]
|
||||
qkm2 = vals[9]
|
||||
|
||||
dpkm2_da = vals[10]
|
||||
dqkm2_da = vals[11]
|
||||
dpkm1_da = vals[12]
|
||||
dqkm1_da = vals[13]
|
||||
dans_da = vals[14]
|
||||
|
||||
c = c + 1
|
||||
y = y + 1
|
||||
z = z + 2
|
||||
|
||||
yc = y * c
|
||||
pk = pkm1 * z - pkm2 * yc
|
||||
qk = qkm1 * z - qkm2 * yc
|
||||
qk_is_nonzero = notequal(qk, 0)
|
||||
r = pk / qk
|
||||
|
||||
t = select(qk_is_nonzero, abs_x((ans - r) / r), fill(dtype(t), shape(t), 1))
|
||||
ans = select(qk_is_nonzero, r, ans)
|
||||
|
||||
dpk_da = dpkm1_da * z - pkm1 - dpkm2_da * yc + pkm2 * c
|
||||
dqk_da = dqkm1_da * z - qkm1 - dqkm2_da * yc + qkm2 * c
|
||||
dans_da_new = select(qk_is_nonzero, (dpk_da - ans * dqk_da) / qk, dans_da)
|
||||
grad_conditional = select(qk_is_nonzero,
|
||||
abs_x(dans_da_new - dans_da),
|
||||
fill(dtype(dans_da), shape(dans_da), 1))
|
||||
|
||||
pkm2 = pkm1
|
||||
pkm1 = pk
|
||||
qkm2 = qkm1
|
||||
qkm1 = qk
|
||||
|
||||
dpkm2_da = dpkm1_da
|
||||
dqkm2_da = dqkm1_da
|
||||
dpkm1_da = dpk_da
|
||||
dqkm1_da = dqk_da
|
||||
|
||||
rescale = greater(abs_x(pk), 1 / epsilon)
|
||||
pkm2 = select(rescale, pkm2 * epsilon, pkm2)
|
||||
pkm1 = select(rescale, pkm1 * epsilon, pkm1)
|
||||
qkm2 = select(rescale, qkm2 * epsilon, qkm2)
|
||||
qkm1 = select(rescale, qkm1 * epsilon, qkm1)
|
||||
|
||||
dpkm2_da = select(rescale, dpkm2_da * epsilon, dpkm2_da)
|
||||
dqkm2_da = select(rescale, dqkm2_da * epsilon, dqkm2_da)
|
||||
dpkm1_da = select(rescale, dpkm1_da * epsilon, dpkm1_da)
|
||||
dqkm1_da = select(rescale, dqkm1_da * epsilon, dqkm1_da)
|
||||
|
||||
conditional = logicaland(enabled, greater(grad_conditional, epsilon))
|
||||
|
||||
return (conditional, select(enabled, ans, vals[1]), select(enabled, t, vals[2]),
|
||||
select(enabled, y, vals[3]), select(enabled, z, vals[4]),
|
||||
c, select(enabled, pkm1, vals[6]),
|
||||
select(enabled, qkm1, vals[7]), select(enabled, pkm2, vals[8]),
|
||||
select(enabled, qkm2, vals[9]), select(enabled, dpkm2_da, vals[10]),
|
||||
select(enabled, dqkm2_da, vals[11]), select(enabled, dpkm1_da, vals[12]),
|
||||
select(enabled, dqkm1_da, vals[13]), select(enabled, dans_da_new, vals[14]))
|
||||
|
||||
y = 1 - a
|
||||
z = x + y + 1
|
||||
c = fill(dtype(x), shape(x), 0)
|
||||
pkm2 = fill(dtype(x), shape(x), 1)
|
||||
qkm2 = x
|
||||
pkm1 = x + 1
|
||||
qkm1 = z * x
|
||||
ans = pkm1 / qkm1
|
||||
t = fill(dtype(x), shape(x), 1)
|
||||
dpkm2_da = fill(dtype(x), shape(x), 0)
|
||||
dqkm2_da = fill(dtype(x), shape(x), 0)
|
||||
dpkm1_da = fill(dtype(x), shape(x), 0)
|
||||
dqkm1_da = -x
|
||||
dans_da = (dpkm1_da - ans * dqkm1_da) / qkm1
|
||||
vals = (enabled, ans, t, y, z, c, pkm1, qkm1, pkm2, qkm2, dpkm2_da, dqkm2_da, dpkm1_da, dqkm1_da, dans_da)
|
||||
vals = _while_helper_func(cond, body, vals)
|
||||
ans = vals[1]
|
||||
return ans * ax
|
||||
|
||||
|
||||
class IGamma(Cell):
|
||||
r"""
|
||||
Calculate lower regularized incomplete Gamma function.
|
||||
The lower regularized incomplete Gamma function is defined as:
|
||||
|
||||
.. math::
|
||||
P(a, x) = gamma(a, x) / Gamma(a) = 1 - Q(a, x)
|
||||
|
||||
where
|
||||
|
||||
.. math::
|
||||
gamma(a, x) = \int_0^x t^{a-1} \exp^{-t} dt
|
||||
|
||||
is the lower incomplete Gamma function.
|
||||
|
||||
Above :math:`Q(a, x)` is the upper regularized complete Gamma function.
|
||||
|
||||
Inputs:
|
||||
- **a** (Tensor) - The input tensor. With float16 or float32 data type. `a` should have
|
||||
the same dtype with `x`.
|
||||
- **x** (Tensor) - The input tensor. With float16 or float32 data type. `x` should have
|
||||
the same dtype with `a`.
|
||||
|
||||
Outputs:
|
||||
Tensor, has the same dtype as `a` and `x`.
|
||||
|
||||
Examples:
|
||||
>>> input_a = Tensor(np.array([2.0, 4.0, 6.0, 8.0]).astype(np.float32))
|
||||
>>> input_x = Tensor(np.array([2.0, 3.0, 4.0, 5.0]).astype(np.float32))
|
||||
>>> igamma = nn.IGamma()
|
||||
>>> output = igamma(input_a, input_x)
|
||||
>>> print (output)
|
||||
[0.593994 0.35276785 0.21486944 0.13337152]
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super(IGamma, self).__init__()
|
||||
# const numbers
|
||||
self.log_maxfloat16 = Tensor(np.log(np.finfo(np.float16).max), mstype.float16)
|
||||
self.log_maxfloat32 = Tensor(np.log(np.finfo(np.float32).max), mstype.float32)
|
||||
|
||||
# operations
|
||||
self.logicaland = P.LogicalAnd()
|
||||
self.logicalor = P.LogicalOr()
|
||||
self.logicalnot = P.LogicalNot()
|
||||
self.equal = P.Equal()
|
||||
self.greater = P.Greater()
|
||||
self.less = P.Less()
|
||||
self.neg = P.Neg()
|
||||
self.log = P.Log()
|
||||
self.exp = P.Exp()
|
||||
self.select = P.Select()
|
||||
self.zeroslike = P.ZerosLike()
|
||||
self.fill = P.Fill()
|
||||
self.shape = P.Shape()
|
||||
self.dtype = P.DType()
|
||||
self.lgamma = LGamma()
|
||||
self.const = P.ScalarToArray()
|
||||
self.cast = P.Cast()
|
||||
|
||||
def construct(self, a, x):
|
||||
a_dtype = self.dtype(a)
|
||||
x_dtype = self.dtype(x)
|
||||
_check_input_dtype("input_a", a_dtype, [mstype.float16, mstype.float32], self.cls_name)
|
||||
_check_input_dtype("input_x", x_dtype, a_dtype, self.cls_name)
|
||||
x_is_zero = self.equal(x, 0)
|
||||
domain_error = self.logicalor(self.less(x, 0), self.less(a, 0))
|
||||
use_igammac = self.logicaland(self.greater(x, 1), self.greater(x, a))
|
||||
ax = a * self.log(x) - x - self.lgamma(a)
|
||||
if a_dtype == mstype.float16:
|
||||
log_maxfloat = self.log_maxfloat16
|
||||
else:
|
||||
log_maxfloat = self.log_maxfloat32
|
||||
underflow = self.less(ax, self.neg(log_maxfloat))
|
||||
ax = self.exp(ax)
|
||||
enabled = self.logicalnot(self.logicalor(self.logicalor(x_is_zero, domain_error), underflow))
|
||||
output = self.select(use_igammac,
|
||||
1 - _IgammacContinuedFraction(ax, x, a, self.logicaland(enabled, use_igammac)),
|
||||
_IgammaSeries(ax, x, a, self.logicaland(enabled, self.logicalnot(use_igammac))))
|
||||
output = self.select(x_is_zero, self.zeroslike(output), output)
|
||||
output = self.select(domain_error, self.fill(self.dtype(a), self.shape(a), np.nan), output)
|
||||
return output
|
||||
|
||||
|
||||
@constexpr
|
||||
def get_broadcast_matmul_shape(x_shape, y_shape):
|
||||
"""get broadcast_matmul shape"""
|
||||
|
@ -453,11 +712,6 @@ class MatMul(Cell):
|
|||
return matmul_broadcast
|
||||
|
||||
|
||||
@constexpr
|
||||
def _check_input_dtype(param_name, input_dtype, allow_dtypes, cls_name):
|
||||
validator.check_type_name(param_name, input_dtype, allow_dtypes, cls_name)
|
||||
|
||||
|
||||
class Moments(Cell):
|
||||
"""
|
||||
Calculate the mean and variance of `x`.
|
||||
|
|
|
@ -593,6 +593,11 @@ test_cases = [
|
|||
'block': nn.LGamma(),
|
||||
'desc_inputs': [Tensor(np.array([3, 4, 5, 6]).astype(np.float32))],
|
||||
'skip': ['backward']}),
|
||||
('IGamma', {
|
||||
'block': nn.IGamma(),
|
||||
'desc_inputs': [Tensor(np.array([3, 4, 5, 6]).astype(np.float32)),
|
||||
Tensor(np.array([3, 4, 5, 6]).astype(np.float32))],
|
||||
'skip': ['backward']}),
|
||||
('FlattenNet', {
|
||||
'block': FlattenNet(),
|
||||
'desc_inputs': [Tensor(np.ones([1, 2, 3, 4], np.float32))],
|
||||
|
|
Loading…
Reference in New Issue