Add IGamma operator

This commit is contained in:
peixu_ren 2020-10-01 11:08:37 -04:00
parent 52eb2d3401
commit 5e9178c5b6
2 changed files with 268 additions and 9 deletions

View File

@ -25,7 +25,12 @@ from ...common import dtype as mstype
from ..._checkparam import Validator as validator
__all__ = ['ReduceLogSumExp', 'Range', 'LinSpace', 'LGamma', 'MatMul', 'Moments']
__all__ = ['ReduceLogSumExp', 'Range', 'LinSpace', 'LGamma', 'IGamma', 'MatMul', 'Moments']
@constexpr
def _check_input_dtype(param_name, input_dtype, allow_dtypes, cls_name):
validator.check_type_name(param_name, input_dtype, allow_dtypes, cls_name)
class ReduceLogSumExp(Cell):
@ -43,7 +48,7 @@ class ReduceLogSumExp(Cell):
Default : False.
Inputs:
- **input_x** (Tensor[Number]) - The input tensor. With float16 or float32 data type.
- **input_x** (Tensor) - The input tensor. With float16 or float32 data type.
Outputs:
Tensor, has the same dtype as the `input_x`.
@ -213,7 +218,7 @@ class LGamma(Cell):
when x = +/- inf, return +inf
Inputs:
- **input_x** (Tensor[Number]) - The input tensor. Only float16, float32 are supported.
- **input_x** (Tensor) - The input tensor. Only float16, float32 are supported.
Outputs:
Tensor, has the same shape and dtype as the `input_x`.
@ -267,7 +272,7 @@ class LGamma(Cell):
def construct(self, input_x):
input_dtype = self.dtype(input_x)
check_tensors_dtype_same(input_dtype, [mstype.float16, mstype.float32], "LGamma")
_check_input_dtype("input", input_dtype, [mstype.float16, mstype.float32], self.cls_name)
infinity = self.fill(input_dtype, self.shape(input_x), self.inf)
need_to_reflect = self.less(input_x, 0.5)
@ -307,6 +312,260 @@ class LGamma(Cell):
return self.select(self.isfinite(input_x), result, infinity)
eps_fp16 = Tensor(np.finfo(np.float16).eps, mstype.float16)
eps_fp32 = Tensor(np.finfo(np.float32).eps, mstype.float32)
def _while_helper_func(cond, body, vals):
while cond(vals).any():
vals = body(vals)
return vals
def _IgammaSeries(ax, x, a, enabled):
"""Helper function for computing Igamma using a power series."""
logicaland = P.LogicalAnd()
greater = P.Greater()
fill = P.Fill()
shape = P.Shape()
dtype = P.DType()
select = P.Select()
if dtype(ax) == mstype.float16:
epsilon = eps_fp16
else:
epsilon = eps_fp32
def cond(vals):
enabled = vals[0]
return enabled
def body(vals):
enabled = vals[0]
r = vals[1]
c = vals[2]
ans = vals[3]
x = vals[4]
dc_da = vals[5]
dans_da = vals[6]
r = r + 1
dc_da = dc_da * (x / r) + (-1 * c * x) / (r * r)
dans_da = dans_da + dc_da
c = c * (x / r)
ans = ans + c
conditional = logicaland(enabled, greater(c / ans, epsilon))
return (conditional, select(enabled, r, vals[1]),
select(enabled, c, vals[2]), select(enabled, ans, vals[3]),
select(enabled, x, vals[4]), select(enabled, dc_da, vals[5]),
select(enabled, dans_da, vals[6]))
ones = fill(dtype(a), shape(a), 1)
zeros = fill(dtype(a), shape(a), 0)
vals = (enabled, a, ones, ones, x, zeros, zeros)
vals = _while_helper_func(cond, body, vals)
ans = vals[3]
return (ans * ax) / a
def _IgammacContinuedFraction(ax, x, a, enabled):
"""Helper function for computing Igammac using a continued fraction."""
abs_x = P.Abs()
logicaland = P.LogicalAnd()
greater = P.Greater()
less = P.Less()
notequal = P.NotEqual()
fill = P.Fill()
shape = P.Shape()
dtype = P.DType()
select = P.Select()
if dtype(ax) == mstype.float16:
epsilon = eps_fp16
else:
epsilon = eps_fp32
def cond(vals):
enabled = vals[0]
c = vals[5]
return logicaland(less(c, 2000), enabled)
def body(vals):
enabled = vals[0]
ans = vals[1]
t = vals[2]
y = vals[3]
z = vals[4]
c = vals[5]
pkm1 = vals[6]
qkm1 = vals[7]
pkm2 = vals[8]
qkm2 = vals[9]
dpkm2_da = vals[10]
dqkm2_da = vals[11]
dpkm1_da = vals[12]
dqkm1_da = vals[13]
dans_da = vals[14]
c = c + 1
y = y + 1
z = z + 2
yc = y * c
pk = pkm1 * z - pkm2 * yc
qk = qkm1 * z - qkm2 * yc
qk_is_nonzero = notequal(qk, 0)
r = pk / qk
t = select(qk_is_nonzero, abs_x((ans - r) / r), fill(dtype(t), shape(t), 1))
ans = select(qk_is_nonzero, r, ans)
dpk_da = dpkm1_da * z - pkm1 - dpkm2_da * yc + pkm2 * c
dqk_da = dqkm1_da * z - qkm1 - dqkm2_da * yc + qkm2 * c
dans_da_new = select(qk_is_nonzero, (dpk_da - ans * dqk_da) / qk, dans_da)
grad_conditional = select(qk_is_nonzero,
abs_x(dans_da_new - dans_da),
fill(dtype(dans_da), shape(dans_da), 1))
pkm2 = pkm1
pkm1 = pk
qkm2 = qkm1
qkm1 = qk
dpkm2_da = dpkm1_da
dqkm2_da = dqkm1_da
dpkm1_da = dpk_da
dqkm1_da = dqk_da
rescale = greater(abs_x(pk), 1 / epsilon)
pkm2 = select(rescale, pkm2 * epsilon, pkm2)
pkm1 = select(rescale, pkm1 * epsilon, pkm1)
qkm2 = select(rescale, qkm2 * epsilon, qkm2)
qkm1 = select(rescale, qkm1 * epsilon, qkm1)
dpkm2_da = select(rescale, dpkm2_da * epsilon, dpkm2_da)
dqkm2_da = select(rescale, dqkm2_da * epsilon, dqkm2_da)
dpkm1_da = select(rescale, dpkm1_da * epsilon, dpkm1_da)
dqkm1_da = select(rescale, dqkm1_da * epsilon, dqkm1_da)
conditional = logicaland(enabled, greater(grad_conditional, epsilon))
return (conditional, select(enabled, ans, vals[1]), select(enabled, t, vals[2]),
select(enabled, y, vals[3]), select(enabled, z, vals[4]),
c, select(enabled, pkm1, vals[6]),
select(enabled, qkm1, vals[7]), select(enabled, pkm2, vals[8]),
select(enabled, qkm2, vals[9]), select(enabled, dpkm2_da, vals[10]),
select(enabled, dqkm2_da, vals[11]), select(enabled, dpkm1_da, vals[12]),
select(enabled, dqkm1_da, vals[13]), select(enabled, dans_da_new, vals[14]))
y = 1 - a
z = x + y + 1
c = fill(dtype(x), shape(x), 0)
pkm2 = fill(dtype(x), shape(x), 1)
qkm2 = x
pkm1 = x + 1
qkm1 = z * x
ans = pkm1 / qkm1
t = fill(dtype(x), shape(x), 1)
dpkm2_da = fill(dtype(x), shape(x), 0)
dqkm2_da = fill(dtype(x), shape(x), 0)
dpkm1_da = fill(dtype(x), shape(x), 0)
dqkm1_da = -x
dans_da = (dpkm1_da - ans * dqkm1_da) / qkm1
vals = (enabled, ans, t, y, z, c, pkm1, qkm1, pkm2, qkm2, dpkm2_da, dqkm2_da, dpkm1_da, dqkm1_da, dans_da)
vals = _while_helper_func(cond, body, vals)
ans = vals[1]
return ans * ax
class IGamma(Cell):
r"""
Calculate lower regularized incomplete Gamma function.
The lower regularized incomplete Gamma function is defined as:
.. math::
P(a, x) = gamma(a, x) / Gamma(a) = 1 - Q(a, x)
where
.. math::
gamma(a, x) = \int_0^x t^{a-1} \exp^{-t} dt
is the lower incomplete Gamma function.
Above :math:`Q(a, x)` is the upper regularized complete Gamma function.
Inputs:
- **a** (Tensor) - The input tensor. With float16 or float32 data type. `a` should have
the same dtype with `x`.
- **x** (Tensor) - The input tensor. With float16 or float32 data type. `x` should have
the same dtype with `a`.
Outputs:
Tensor, has the same dtype as `a` and `x`.
Examples:
>>> input_a = Tensor(np.array([2.0, 4.0, 6.0, 8.0]).astype(np.float32))
>>> input_x = Tensor(np.array([2.0, 3.0, 4.0, 5.0]).astype(np.float32))
>>> igamma = nn.IGamma()
>>> output = igamma(input_a, input_x)
>>> print (output)
[0.593994 0.35276785 0.21486944 0.13337152]
"""
def __init__(self):
super(IGamma, self).__init__()
# const numbers
self.log_maxfloat16 = Tensor(np.log(np.finfo(np.float16).max), mstype.float16)
self.log_maxfloat32 = Tensor(np.log(np.finfo(np.float32).max), mstype.float32)
# operations
self.logicaland = P.LogicalAnd()
self.logicalor = P.LogicalOr()
self.logicalnot = P.LogicalNot()
self.equal = P.Equal()
self.greater = P.Greater()
self.less = P.Less()
self.neg = P.Neg()
self.log = P.Log()
self.exp = P.Exp()
self.select = P.Select()
self.zeroslike = P.ZerosLike()
self.fill = P.Fill()
self.shape = P.Shape()
self.dtype = P.DType()
self.lgamma = LGamma()
self.const = P.ScalarToArray()
self.cast = P.Cast()
def construct(self, a, x):
a_dtype = self.dtype(a)
x_dtype = self.dtype(x)
_check_input_dtype("input_a", a_dtype, [mstype.float16, mstype.float32], self.cls_name)
_check_input_dtype("input_x", x_dtype, a_dtype, self.cls_name)
x_is_zero = self.equal(x, 0)
domain_error = self.logicalor(self.less(x, 0), self.less(a, 0))
use_igammac = self.logicaland(self.greater(x, 1), self.greater(x, a))
ax = a * self.log(x) - x - self.lgamma(a)
if a_dtype == mstype.float16:
log_maxfloat = self.log_maxfloat16
else:
log_maxfloat = self.log_maxfloat32
underflow = self.less(ax, self.neg(log_maxfloat))
ax = self.exp(ax)
enabled = self.logicalnot(self.logicalor(self.logicalor(x_is_zero, domain_error), underflow))
output = self.select(use_igammac,
1 - _IgammacContinuedFraction(ax, x, a, self.logicaland(enabled, use_igammac)),
_IgammaSeries(ax, x, a, self.logicaland(enabled, self.logicalnot(use_igammac))))
output = self.select(x_is_zero, self.zeroslike(output), output)
output = self.select(domain_error, self.fill(self.dtype(a), self.shape(a), np.nan), output)
return output
@constexpr
def get_broadcast_matmul_shape(x_shape, y_shape):
"""get broadcast_matmul shape"""
@ -453,11 +712,6 @@ class MatMul(Cell):
return matmul_broadcast
@constexpr
def _check_input_dtype(param_name, input_dtype, allow_dtypes, cls_name):
validator.check_type_name(param_name, input_dtype, allow_dtypes, cls_name)
class Moments(Cell):
"""
Calculate the mean and variance of `x`.

View File

@ -593,6 +593,11 @@ test_cases = [
'block': nn.LGamma(),
'desc_inputs': [Tensor(np.array([3, 4, 5, 6]).astype(np.float32))],
'skip': ['backward']}),
('IGamma', {
'block': nn.IGamma(),
'desc_inputs': [Tensor(np.array([3, 4, 5, 6]).astype(np.float32)),
Tensor(np.array([3, 4, 5, 6]).astype(np.float32))],
'skip': ['backward']}),
('FlattenNet', {
'block': FlattenNet(),
'desc_inputs': [Tensor(np.ones([1, 2, 3, 4], np.float32))],