forked from mindspore-Ecosystem/mindspore
!7048 Fix some type casting errors in probabilistic programming code
Merge pull request !7048 from peixu_ren/custom_bijector
commit 12094c97ca
@@ -15,6 +15,7 @@
 """Bijector"""
 from mindspore import context
 from mindspore.nn.cell import Cell
+from mindspore.ops import operations as P
 from mindspore._checkparam import Validator as validator
 from ..distribution._utils.utils import CheckTensor
 from ..distribution import Distribution
@@ -62,6 +63,10 @@ class Bijector(Cell):
         self.context_mode = context.get_context('mode')
         self.checktensor = CheckTensor()

+        # ops needed for the base class
+        self.cast_base = P.Cast()
+        self.dtype_base = P.DType()
+
     @property
     def name(self):
         return self._name
@@ -91,6 +96,10 @@ class Bijector(Cell):
             return value
         return self.checktensor(value, name)

+    def cast_param_by_value(self, value, para):
+        local = self.cast_base(para, self.dtype_base(value))
+        return local
+
     def forward(self, *args, **kwargs):
         """
         Forward transformation: transform the input value to another distribution.
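
All of the bijector changes in this commit follow one pattern: instead of combining a stored parameter (self.power, self.scale, self.shift, self.sharpness) directly with the input tensor, each method first casts the parameter to the input's dtype through the new cast_param_by_value helper, so mixed-precision inputs no longer trigger type errors. A minimal standalone sketch of that idea, assuming eager (PyNative-style) execution; the tensor values are illustrative and not taken from the commit:

    import numpy as np
    from mindspore import Tensor
    from mindspore import dtype as mstype
    from mindspore.ops import operations as P

    cast_base = P.Cast()
    dtype_base = P.DType()

    def cast_param_by_value(value, para):
        # Cast parameter `para` to the dtype of the incoming tensor `value`,
        # so the arithmetic below never mixes float32 parameters with float16 inputs.
        return cast_base(para, dtype_base(value))

    x = Tensor(np.array([0.5, 1.5], dtype=np.float16))  # user input in float16
    power = Tensor(2.0, mstype.float32)                 # parameter stored as float32
    power_local = cast_param_by_value(x, power)         # now float16, matching x
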
@@ -69,6 +69,8 @@ class PowerTransform(Bijector):
         validator.check_number("power", power, 0, Rel.GE, self.name)
         self._power = power
         self.pow = P.Pow()
+        self.dtypeop = P.DType()
+        self.cast = P.Cast()
         self.exp = exp_generic
         self.expm1 = expm1_generic
         self.log = log_generic
@@ -87,15 +89,21 @@ class PowerTransform(Bijector):

     def _forward(self, x):
         x = self._check_value(x, 'value')
-        if self.power == 0:
-            return self.exp(x)
-        return self.exp(self.log1p(x * self.power) / self.power)
+        power_local = self.cast_param_by_value(x, self.power)
+        if power_local == 0:
+            forward_v = self.exp(x)
+        else:
+            forward_v = self.exp(self.log1p(x * power_local) / power_local)
+        return forward_v

     def _inverse(self, y):
         y = self._check_value(y, 'value')
-        if self.power == 0:
-            return self.log(y)
-        return self.expm1(self.log(y) * self.power) / self.power
+        power_local = self.cast_param_by_value(y, self.power)
+        if power_local == 0:
+            inverse_v = self.log(y)
+        else:
+            inverse_v = self.expm1(self.log(y) * power_local) / power_local
+        return inverse_v

     def _forward_log_jacobian(self, x):
         r"""
@@ -110,9 +118,12 @@ class PowerTransform(Bijector):
             \log(f'(x)) = (\frac{1}{c} - 1) * \log(xc + 1)
         """
         x = self._check_value(x, 'value')
-        if self.power == 0:
-            return x
-        return (1. / self.power - 1) * self.log1p(x * self.power)
+        power_local = self.cast_param_by_value(x, self.power)
+        if power_local == 0:
+            forward_log_j = x
+        else:
+            forward_log_j = (1. / power_local - 1) * self.log1p(x * power_local)
+        return forward_log_j

     def _inverse_log_jacobian(self, y):
         r"""
@@ -127,4 +138,6 @@ class PowerTransform(Bijector):
             \log(f'(x)) = \log(\frac{e^c\log(y)}{y}) = (c-1) * \log(y)
         """
         y = self._check_value(y, 'value')
-        return (self.power - 1) * self.log(y)
+        power_local = self.cast_param_by_value(y, self.power)
+        inverse_log_j = (power_local - 1) * self.log(y)
+        return inverse_log_j
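
The PowerTransform math itself is unchanged by this commit; a small NumPy sketch (arbitrary power c and sample points, not from the diff) checks that the forward/inverse and log-Jacobian formulas above are mutually consistent:

    import numpy as np

    c = 0.5                                  # power parameter, assumed nonzero here
    x = np.array([0.3, 1.0, 2.5])

    y = np.exp(np.log1p(x * c) / c)          # _forward: (1 + c*x)**(1/c)
    x_back = np.expm1(np.log(y) * c) / c     # _inverse
    fwd_lj = (1. / c - 1) * np.log1p(x * c)  # _forward_log_jacobian at x
    inv_lj = (c - 1) * np.log(y)             # _inverse_log_jacobian at y = forward(x)

    assert np.allclose(x_back, x)
    assert np.allclose(fwd_lj, -inv_lj)      # log-Jacobians of mutually inverse maps cancel
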
@@ -76,6 +76,8 @@ class ScalarAffine(Bijector):

         self.abs = P.Abs()
         self.oneslike = P.OnesLike()
+        self.dtypeop = P.DType()
+        self.cast = P.Cast()
         self.log = log_generic

     @property
@@ -99,7 +101,10 @@ class ScalarAffine(Bijector):
             f(x) = a * x + b
         """
         x = self._check_value(x, 'value')
-        return self.scale * x + self.shift * self.oneslike(x)
+        scale_local = self.cast_param_by_value(x, self.scale)
+        shift_local = self.cast_param_by_value(x, self.shift)
+        forward_v = scale_local * x + shift_local * self.oneslike(x)
+        return forward_v

     def _inverse(self, y):
         r"""
@@ -107,7 +112,10 @@ class ScalarAffine(Bijector):
             f(y) = \frac{y - b}{a}
         """
         y = self._check_value(y, 'value')
-        return (y - self.shift) / self.scale
+        scale_local = self.cast_param_by_value(y, self.scale)
+        shift_local = self.cast_param_by_value(y, self.shift)
+        inverse_v = (y - shift_local) / scale_local
+        return inverse_v

     def _forward_log_jacobian(self, x):
         r"""
@@ -117,7 +125,9 @@ class ScalarAffine(Bijector):
             \log(f'(x)) = \log(a)
         """
         x = self._check_value(x, 'value')
-        return self.log(self.abs(self.scale))
+        scale_local = self.cast_param_by_value(x, self.scale)
+        forward_log_j = self.log(self.abs(scale_local))
+        return forward_log_j

     def _inverse_log_jacobian(self, y):
         r"""
@@ -127,4 +137,6 @@ class ScalarAffine(Bijector):
             \log(f'(x)) = - \log(a)
         """
         y = self._check_value(y, 'value')
-        return -1. * self.log(self.abs(self.scale))
+        scale_local = self.cast_param_by_value(y, self.scale)
+        inverse_log_j = -1. * self.log(self.abs(scale_local))
+        return inverse_log_j
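
The same kind of consistency check for ScalarAffine, with arbitrary example values for the scale a and shift b (not taken from the diff):

    import numpy as np

    a, b = 2.0, 0.5                   # scale and shift
    x = np.array([-1.0, 0.0, 3.0])

    y = a * x + b                     # _forward
    x_back = (y - b) / a              # _inverse
    fwd_lj = np.log(np.abs(a))        # _forward_log_jacobian, constant in x
    inv_lj = -1. * np.log(np.abs(a))  # _inverse_log_jacobian

    assert np.allclose(x_back, x)
    assert np.isclose(fwd_lj, -inv_lj)
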
@@ -71,6 +71,7 @@ class Softplus(Bijector):
         self.expm1 = expm1_generic
         self.abs = P.Abs()
         self.dtypeop = P.DType()
+        self.cast = P.Cast()
         self.fill = P.Fill()
         self.greater = P.Greater()
         self.less = P.Less()
@@ -125,8 +126,10 @@ class Softplus(Bijector):

     def _forward(self, x):
         x = self._check_value(x, 'value')
-        scaled_value = self.sharpness * x
-        return self.softplus(scaled_value) / self.sharpness
+        sharpness_local = self.cast_param_by_value(x, self.sharpness)
+        scaled_value = sharpness_local * x
+        forward_v = self.softplus(scaled_value) / sharpness_local
+        return forward_v

     def _inverse(self, y):
         r"""
@@ -135,8 +138,10 @@ class Softplus(Bijector):
             f^{-1}(y) = \frac{\log(e^{ky} - 1)}{k}
         """
         y = self._check_value(y, 'value')
-        scaled_value = self.sharpness * y
-        return self.inverse_softplus(scaled_value) / self.sharpness
+        sharpness_local = self.cast_param_by_value(y, self.sharpness)
+        scaled_value = sharpness_local * y
+        inverse_v = self.inverse_softplus(scaled_value) / sharpness_local
+        return inverse_v

     def _forward_log_jacobian(self, x):
         r"""
@@ -146,8 +151,10 @@ class Softplus(Bijector):
             \log(f'(x)) = kx - \log(1 + e^{kx}) = kx - f(kx)
         """
         x = self._check_value(x, 'value')
-        scaled_value = self.sharpness * x
-        return self.log_sigmoid(scaled_value)
+        sharpness_local = self.cast_param_by_value(x, self.sharpness)
+        scaled_value = sharpness_local * x
+        forward_log_j = self.log_sigmoid(scaled_value)
+        return forward_log_j

     def _inverse_log_jacobian(self, y):
         r"""
@@ -157,5 +164,7 @@ class Softplus(Bijector):
             \log(f'(y)) = ky - \log(e^{ky} - 1) = ky - f(ky)
         """
         y = self._check_value(y, 'value')
-        scaled_value = self.sharpness * y
-        return scaled_value - self.inverse_softplus(scaled_value)
+        sharpness_local = self.cast_param_by_value(y, self.sharpness)
+        scaled_value = sharpness_local * y
+        inverse_log_j = scaled_value - self.inverse_softplus(scaled_value)
+        return inverse_log_j
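
And for Softplus, a NumPy sketch with an arbitrary sharpness k, mirroring the softplus / inverse-softplus formulas in the docstrings above:

    import numpy as np

    k = 2.0                                    # sharpness
    x = np.array([-1.0, 0.5, 3.0])

    y = np.log1p(np.exp(k * x)) / k            # _forward: softplus(kx) / k
    x_back = np.log(np.expm1(k * y)) / k       # _inverse: inverse_softplus(ky) / k
    fwd_lj = -np.log1p(np.exp(-k * x))         # _forward_log_jacobian: log_sigmoid(kx)
    inv_lj = k * y - np.log(np.expm1(k * y))   # _inverse_log_jacobian: ky - inverse_softplus(ky)

    assert np.allclose(x_back, x)
    assert np.allclose(fwd_lj, -inv_lj)
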
@@ -66,10 +66,10 @@ def normal(shape, mean, stddev, seed=None):

     Args:
         shape (tuple): The shape of random tensor to be generated.
-        mean (Tensor): The mean μ distribution parameter, which specifies the location of the peak.
-            with float32 data type.
-        stddev (Tensor): The deviation σ distribution parameter. It should be greater than 0.
-            with float32 data type.
+        mean (Tensor): The mean μ distribution parameter, which specifies the location of the peak,
+            with data type in [int8, int16, int32, int64, float16, float32].
+        stddev (Tensor): The deviation σ distribution parameter. It should be greater than 0,
+            with data type in [int8, int16, int32, int64, float16, float32].
         seed (int): Seed is used as entropy source for the Random number engines to generate pseudo-random numbers.
             must be non-negative. Default: None, which will be treated as 0.

@@ -86,8 +86,8 @@ def normal(shape, mean, stddev, seed=None):
     """
     mean_dtype = F.dtype(mean)
     stddev_dtype = F.dtype(stddev)
-    const_utils.check_tensors_dtype_same(mean_dtype, mstype.float32, "normal")
-    const_utils.check_tensors_dtype_same(stddev_dtype, mstype.float32, "normal")
+    const_utils.check_valid_type(mean_dtype, mstype.int_type + (mstype.float16, mstype.float32), 'normal')
+    const_utils.check_valid_type(stddev_dtype, mstype.int_type + (mstype.float16, mstype.float32), 'normal')
     seed1, seed2 = get_seed(seed, "normal")
     stdnormal = P.StandardNormal(seed1, seed2)
     random_normal = stdnormal(shape)
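
With the relaxed check, mean and stddev no longer have to be float32 tensors. A hedged usage sketch: the mindspore.ops.composite import path, the seed value, and the sample tensors are illustrative assumptions that do not appear in the diff:

    import numpy as np
    from mindspore import Tensor
    from mindspore import dtype as mstype
    from mindspore.ops import composite as C

    shape = (2, 3)
    mean = Tensor(np.array([0.0, 1.0, 2.0], dtype=np.float16))  # float16 mean, rejected before this change
    stddev = Tensor(1.0, mstype.float16)
    samples = C.normal(shape, mean, stddev, seed=5)             # mean/stddev broadcast over `shape`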