!7048 Fix some type casting errors in probabilistic programming code

Merge pull request !7048 from peixu_ren/custom_bijector
mindspore-ci-bot authored 2020-10-09 09:16:44 +08:00, committed by Gitee
commit 12094c97ca
5 changed files with 71 additions and 28 deletions

View File

@@ -15,6 +15,7 @@
 """Bijector"""
 from mindspore import context
 from mindspore.nn.cell import Cell
+from mindspore.ops import operations as P
 from mindspore._checkparam import Validator as validator
 from ..distribution._utils.utils import CheckTensor
 from ..distribution import Distribution
@@ -62,6 +63,10 @@ class Bijector(Cell):
         self.context_mode = context.get_context('mode')
         self.checktensor = CheckTensor()

+        # ops needed for the base class
+        self.cast_base = P.Cast()
+        self.dtype_base = P.DType()
+
     @property
     def name(self):
         return self._name
@@ -91,6 +96,10 @@
             return value
         return self.checktensor(value, name)

+    def cast_param_by_value(self, value, para):
+        local = self.cast_base(para, self.dtype_base(value))
+        return local
+
     def forward(self, *args, **kwargs):
         """
         Forward transformation: transform the input value to another distribution.
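The new `cast_param_by_value` helper is the core of this change: it casts a bijector parameter to the dtype of the incoming value, so a float16 input no longer collides with a float32 parameter inside the bijector's ops. Below is a minimal standalone sketch of the same idea, assuming PyNative mode; the free-function wrapper and the tensor values are illustrative, not library code.

import numpy as np
from mindspore import Tensor, dtype as mstype
from mindspore.ops import operations as P

cast_base = P.Cast()
dtype_base = P.DType()

def cast_param_by_value(value, para):
    # Mirror of the new base-class helper: cast `para` to the dtype of `value`.
    return cast_base(para, dtype_base(value))

x = Tensor(np.array([1.0, 2.0]), mstype.float16)   # network input in float16
power = Tensor(2.0, mstype.float32)                # parameter stored as float32
power_local = cast_param_by_value(x, power)
print(power_local.dtype)                           # expected: Float16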

View File

@@ -69,6 +69,8 @@ class PowerTransform(Bijector):
         validator.check_number("power", power, 0, Rel.GE, self.name)
         self._power = power
         self.pow = P.Pow()
+        self.dtypeop = P.DType()
+        self.cast = P.Cast()
         self.exp = exp_generic
         self.expm1 = expm1_generic
         self.log = log_generic
@@ -87,15 +89,21 @@
     def _forward(self, x):
         x = self._check_value(x, 'value')
-        if self.power == 0:
-            return self.exp(x)
-        return self.exp(self.log1p(x * self.power) / self.power)
+        power_local = self.cast_param_by_value(x, self.power)
+        if power_local == 0:
+            forward_v = self.exp(x)
+        else:
+            forward_v = self.exp(self.log1p(x * power_local) / power_local)
+        return forward_v

     def _inverse(self, y):
         y = self._check_value(y, 'value')
-        if self.power == 0:
-            return self.log(y)
-        return self.expm1(self.log(y) * self.power) / self.power
+        power_local = self.cast_param_by_value(y, self.power)
+        if power_local == 0:
+            inverse_v = self.log(y)
+        else:
+            inverse_v = self.expm1(self.log(y) * power_local) / power_local
+        return inverse_v

     def _forward_log_jacobian(self, x):
         r"""
@@ -110,9 +118,12 @@
             \log(f'(x)) = (\frac{1}{c} - 1) * \log(xc + 1)
         """
         x = self._check_value(x, 'value')
-        if self.power == 0:
-            return x
-        return (1. / self.power - 1) * self.log1p(x * self.power)
+        power_local = self.cast_param_by_value(x, self.power)
+        if power_local == 0:
+            forward_log_j = x
+        else:
+            forward_log_j = (1. / power_local - 1) * self.log1p(x * power_local)
+        return forward_log_j

     def _inverse_log_jacobian(self, y):
         r"""
@@ -127,4 +138,6 @@
             \log(f'(x)) = \log(\frac{e^c\log(y)}{y}) = (c-1) * \log(y)
         """
         y = self._check_value(y, 'value')
-        return (self.power - 1) * self.log(y)
+        power_local = self.cast_param_by_value(y, self.power)
+        inverse_log_j = (power_local - 1) * self.log(y)
+        return inverse_log_j
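With the parameter now cast to the input dtype, PowerTransform can be driven with inputs whose dtype differs from the stored `power`. A usage sketch, assuming PyNative mode and the `mindspore.nn.probability.bijector` package; the values are illustrative.

import numpy as np
from mindspore import Tensor, dtype as mstype
import mindspore.nn.probability.bijector as msb

power_transform = msb.PowerTransform(power=2.0)     # scalar parameter of the bijector
x = Tensor(np.array([0.5, 1.0, 2.0]), mstype.float16)

y = power_transform.forward(x)                      # exp(log1p(c*x) / c)
x_back = power_transform.inverse(y)                 # expm1(c*log(y)) / c, should recover x
print(y.dtype, x_back.dtype)                        # both expected to be Float16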

View File

@@ -76,6 +76,8 @@ class ScalarAffine(Bijector):
         self.abs = P.Abs()
         self.oneslike = P.OnesLike()
+        self.dtypeop = P.DType()
+        self.cast = P.Cast()
         self.log = log_generic

     @property
@@ -99,7 +101,10 @@
             f(x) = a * x + b
         """
         x = self._check_value(x, 'value')
-        return self.scale * x + self.shift * self.oneslike(x)
+        scale_local = self.cast_param_by_value(x, self.scale)
+        shift_local = self.cast_param_by_value(x, self.shift)
+        forward_v = scale_local * x + shift_local * self.oneslike(x)
+        return forward_v

     def _inverse(self, y):
         r"""
@@ -107,7 +112,10 @@
             f(y) = \frac{y - b}{a}
         """
         y = self._check_value(y, 'value')
-        return (y - self.shift) / self.scale
+        scale_local = self.cast_param_by_value(y, self.scale)
+        shift_local = self.cast_param_by_value(y, self.shift)
+        inverse_v = (y - shift_local) / scale_local
+        return inverse_v

     def _forward_log_jacobian(self, x):
         r"""
@@ -117,7 +125,9 @@
             \log(f'(x)) = \log(a)
         """
         x = self._check_value(x, 'value')
-        return self.log(self.abs(self.scale))
+        scale_local = self.cast_param_by_value(x, self.scale)
+        forward_log_j = self.log(self.abs(scale_local))
+        return forward_log_j

     def _inverse_log_jacobian(self, y):
         r"""
@@ -127,4 +137,6 @@
             \log(f'(x)) = - \log(a)
         """
         y = self._check_value(y, 'value')
-        return -1. * self.log(self.abs(self.scale))
+        scale_local = self.cast_param_by_value(y, self.scale)
+        inverse_log_j = -1. * self.log(self.abs(scale_local))
+        return inverse_log_j
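Same pattern for ScalarAffine: `scale` and `shift` are cast to the dtype of the input before f(x) = a * x + b is evaluated. A usage sketch, assuming PyNative mode; the constructor arguments are illustrative.

import numpy as np
from mindspore import Tensor, dtype as mstype
import mindspore.nn.probability.bijector as msb

affine = msb.ScalarAffine(scale=2.0, shift=1.0)
x = Tensor(np.array([1.0, 2.0, 3.0]), mstype.float16)

print(affine.forward(x))                  # 2 * x + 1
print(affine.inverse(affine.forward(x)))  # should recover x
print(affine.forward_log_jacobian(x))     # log(|scale|), a constant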

View File

@@ -71,6 +71,7 @@ class Softplus(Bijector):
         self.expm1 = expm1_generic
         self.abs = P.Abs()
         self.dtypeop = P.DType()
+        self.cast = P.Cast()
         self.fill = P.Fill()
         self.greater = P.Greater()
         self.less = P.Less()
@@ -125,8 +126,10 @@
     def _forward(self, x):
         x = self._check_value(x, 'value')
-        scaled_value = self.sharpness * x
-        return self.softplus(scaled_value) / self.sharpness
+        sharpness_local = self.cast_param_by_value(x, self.sharpness)
+        scaled_value = sharpness_local * x
+        forward_v = self.softplus(scaled_value) / sharpness_local
+        return forward_v

     def _inverse(self, y):
         r"""
@@ -135,8 +138,10 @@
             f^{-1}(y) = \frac{\log(e^{ky} - 1)}{k}
         """
         y = self._check_value(y, 'value')
-        scaled_value = self.sharpness * y
-        return self.inverse_softplus(scaled_value) / self.sharpness
+        sharpness_local = self.cast_param_by_value(y, self.sharpness)
+        scaled_value = sharpness_local * y
+        inverse_v = self.inverse_softplus(scaled_value) / sharpness_local
+        return inverse_v

     def _forward_log_jacobian(self, x):
         r"""
@@ -146,8 +151,10 @@
             \log(f'(x)) = kx - \log(1 + e^{kx}) = kx - f(kx)
         """
         x = self._check_value(x, 'value')
-        scaled_value = self.sharpness * x
-        return self.log_sigmoid(scaled_value)
+        sharpness_local = self.cast_param_by_value(x, self.sharpness)
+        scaled_value = sharpness_local * x
+        forward_log_j = self.log_sigmoid(scaled_value)
+        return forward_log_j

     def _inverse_log_jacobian(self, y):
         r"""
@@ -157,5 +164,7 @@
             \log(f'(y)) = ky - \log(e^{ky} - 1) = ky - f(ky)
         """
         y = self._check_value(y, 'value')
-        scaled_value = self.sharpness * y
-        return scaled_value - self.inverse_softplus(scaled_value)
+        sharpness_local = self.cast_param_by_value(y, self.sharpness)
+        scaled_value = sharpness_local * y
+        inverse_log_j = scaled_value - self.inverse_softplus(scaled_value)
+        return inverse_log_j
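And for Softplus, where `sharpness` (k) scales the input before softplus(kx)/k is computed. A usage sketch under the same assumptions (PyNative mode, illustrative values):

import numpy as np
from mindspore import Tensor, dtype as mstype
import mindspore.nn.probability.bijector as msb

softplus = msb.Softplus(sharpness=2.0)
x = Tensor(np.array([-1.0, 0.0, 1.0]), mstype.float16)

y = softplus.forward(x)        # log(1 + exp(k*x)) / k
print(y.dtype)                 # expected: Float16
print(softplus.inverse(y))     # log(exp(k*y) - 1) / k, approximately x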

View File

@@ -66,10 +66,10 @@ def normal(shape, mean, stddev, seed=None):
     Args:
         shape (tuple): The shape of random tensor to be generated.
-        mean (Tensor): The mean μ distribution parameter, which specifies the location of the peak.
-          with float32 data type.
-        stddev (Tensor): The deviation σ distribution parameter. It should be greater than 0.
-          with float32 data type.
+        mean (Tensor): The mean μ distribution parameter, which specifies the location of the peak,
+          with data type in [int8, int16, int32, int64, float16, float32].
+        stddev (Tensor): The deviation σ distribution parameter. It should be greater than 0,
+          with data type in [int8, int16, int32, int64, float16, float32].
         seed (int): Seed is used as entropy source for the Random number engines to generate pseudo-random numbers.
             must be non-negative. Default: None, which will be treated as 0.
@ -86,8 +86,8 @@ def normal(shape, mean, stddev, seed=None):
"""
mean_dtype = F.dtype(mean)
stddev_dtype = F.dtype(stddev)
const_utils.check_tensors_dtype_same(mean_dtype, mstype.float32, "normal")
const_utils.check_tensors_dtype_same(stddev_dtype, mstype.float32, "normal")
const_utils.check_valid_type(mean_dtype, mstype.int_type + (mstype.float16, mstype.float32), 'normal')
const_utils.check_valid_type(stddev_dtype, mstype.int_type + (mstype.float16, mstype.float32), 'normal')
seed1, seed2 = get_seed(seed, "normal")
stdnormal = P.StandardNormal(seed1, seed2)
random_normal = stdnormal(shape)
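The relaxed check means `mean` and `stddev` may now carry any of the listed integer or floating types instead of float32 only. A sketch of a call that the old check rejected, assuming this is the `normal` helper in `mindspore.ops.composite`; shape, values, and seed are illustrative.

from mindspore import Tensor
from mindspore import dtype as mstype
from mindspore.ops import composite as C

mean = Tensor(0, mstype.int32)      # previously failed the float32-only check
stddev = Tensor(1, mstype.int32)
samples = C.normal((2, 3), mean, stddev, seed=5)
print(samples.shape)                # (2, 3)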