!4875 Fix issues related to parameter checking, formulas in distributions and bijectors
Merge pull request !4875 from XunDeng/pp_issue_branch
Commit: 8021dc587d
Changes to the PowerTransform bijector:

@@ -16,6 +16,7 @@
 from mindspore.ops import operations as P
 from mindspore._checkparam import Validator as validator
 from mindspore._checkparam import Rel
+from ..distribution._utils.utils import CheckTensor
 from .bijector import Bijector
 
 class PowerTransform(Bijector):
@@ -62,6 +63,8 @@ class PowerTransform(Bijector):
         self.log1p = self._log1p_by_step
         self.expm1 = self._expm1_by_step
 
+        self.checktensor = CheckTensor()
+
     def _log1p_by_step(self, x):
         """
         Log1p ops on GPU device or when device_target == GPU.
@@ -86,11 +89,13 @@ class PowerTransform(Bijector):
         return shape
 
     def _forward(self, x):
+        self.checktensor(x, 'x')
         if self.power == 0:
             return self.exp(x)
         return self.exp(self.log1p(x * self.power) / self.power)
 
     def _inverse(self, y):
+        self.checktensor(y, 'y')
         if self.power == 0:
             return self.log(y)
         return self.expm1(self.log(y) * self.power) / self.power
@@ -107,6 +112,7 @@ class PowerTransform(Bijector):
            f'(x) = e^\frac{\log(xc + 1)}{c} * \frac{1}{xc + 1}
            \log(f'(x)) = (\frac{1}{c} - 1) * \log(xc + 1)
        """
+        self.checktensor(x, 'x')
         if self.power == 0:
             return x
         return (1. / self.power - 1) * self.log1p(x * self.power)
@@ -123,4 +129,5 @@ class PowerTransform(Bijector):
            f'(x) = \frac{e^c\log(y)}{y}
            \log(f'(x)) = \log(\frac{e^c\log(y)}{y}) = (c-1) * \log(y)
        """
+        self.checktensor(y, 'y')
         return (self.power - 1) * self.log(y)
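As a quick sanity check of the formulas quoted in the docstrings above, the following NumPy sketch (illustrative only; the function names and test values are ours, not part of MindSpore) confirms that the forward map, inverse map, and forward log-Jacobian are mutually consistent for a non-zero power c:

# Hypothetical NumPy check of the PowerTransform formulas quoted above.
import numpy as np

def forward(x, c):
    # f(x) = exp(log1p(x * c) / c) = (1 + c * x) ** (1 / c), c != 0
    return np.exp(np.log1p(x * c) / c)

def inverse(y, c):
    # f^{-1}(y) = expm1(log(y) * c) / c = (y ** c - 1) / c
    return np.expm1(np.log(y) * c) / c

def forward_log_jacobian(x, c):
    # log f'(x) = (1 / c - 1) * log1p(x * c)
    return (1.0 / c - 1.0) * np.log1p(x * c)

x = np.array([0.1, 0.5, 2.0])
c = 0.5
y = forward(x, c)
assert np.allclose(inverse(y, c), x)                       # round trip
eps = 1e-6                                                 # numerical Jacobian
num_jac = (forward(x + eps, c) - forward(x - eps, c)) / (2 * eps)
assert np.allclose(np.exp(forward_log_jacobian(x, c)), num_jac, rtol=1e-4)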
Changes to the ScalarAffine bijector:

@@ -15,7 +15,7 @@
 """Scalar Affine Bijector"""
 from mindspore.ops import operations as P
 from mindspore._checkparam import Validator as validator
-from ..distribution._utils.utils import cast_to_tensor
+from ..distribution._utils.utils import cast_to_tensor, CheckTensor
 from .bijector import Bijector
 
 class ScalarAffine(Bijector):
@@ -54,8 +54,8 @@ class ScalarAffine(Bijector):
         Constructor of scalar affine bijector.
         """
         param = dict(locals())
-        validator.check_value_type('scale', scale, [float], name)
-        validator.check_value_type('shift', shift, [float], name)
+        validator.check_value_type('scale', scale, [int, float], name)
+        validator.check_value_type('shift', shift, [int, float], name)
         self._scale = cast_to_tensor(scale)
         self._shift = cast_to_tensor(shift)
         super(ScalarAffine, self).__init__(
@@ -65,8 +65,10 @@ class ScalarAffine(Bijector):
             dtype=None,
             param=param)
 
+        self.abs = P.Abs()
         self.log = P.Log()
-        self.oneslike = P.OnesLike()
+
+        self.checktensor = CheckTensor()
 
     @property
     def scale(self):
@@ -88,6 +90,7 @@ class ScalarAffine(Bijector):
         .. math::
            f(x) = a * x + b
        """
+        self.checktensor(x, 'x')
         return self.scale * x + self.shift
 
     def _inverse(self, y):
@@ -95,22 +98,25 @@ class ScalarAffine(Bijector):
         .. math::
            f(y) = \frac{y - b}{a}
        """
+        self.checktensor(y, 'y')
         return (y - self.shift) / self.scale
 
-    def _forward_log_jacobian(self, value):
+    def _forward_log_jacobian(self, x):
         r"""
         .. math::
            f(x) = a * x + b
            f'(x) = a
            \log(f'(x)) = \log(a)
        """
-        return self.log(self.scale) * self.oneslike(value)
+        self.checktensor(x, 'x')
+        return self.log(self.abs(self.scale))
 
-    def _inverse_log_jacobian(self, value):
+    def _inverse_log_jacobian(self, y):
         r"""
         .. math::
            f(y) = \frac{(y - b)}{a}
            f'(x) = \frac{1.0}{a}
            \log(f'(x)) = - \log(a)
        """
-        return -1. * self.log(self.scale) * self.oneslike(value)
+        self.checktensor(y, 'y')
+        return -1. * self.log(self.abs(self.scale))
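The log-Jacobian change above swaps log(scale) for log(|scale|): the density change of f(x) = a * x + b is the magnitude of the slope, so log|a| stays finite for negative scales, and dropping the OnesLike broadcast simply returns a scalar. A small NumPy illustration (not MindSpore code; values are made up) of the point:

# The forward log-Jacobian of f(x) = a * x + b is log(|a|).
import numpy as np

a, b = -2.0, 1.0
x = np.linspace(-1.0, 1.0, 5)
eps = 1e-6
num_jac = np.abs((a * (x + eps) + b) - (a * (x - eps) + b)) / (2 * eps)
assert np.allclose(np.log(np.abs(a)), np.log(num_jac))   # log|a| at every x
# np.log(a) would be NaN here, which is what the switch to Abs avoids.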
Changes to the Softplus bijector:

@@ -13,10 +13,12 @@
 # limitations under the License.
 # ============================================================================
 """Softplus Bijector"""
+import numpy as np
 from mindspore.ops import operations as P
+from mindspore.common import dtype as mstype
 from mindspore.nn.layer.activation import LogSigmoid
 from mindspore._checkparam import Validator as validator
-from ..distribution._utils.utils import cast_to_tensor
+from ..distribution._utils.utils import cast_to_tensor, CheckTensor
 from .bijector import Bijector
 
 class Softplus(Bijector):
@@ -52,19 +54,28 @@ class Softplus(Bijector):
                  sharpness=1.0,
                  name='Softplus'):
         param = dict(locals())
-        validator.check_value_type('sharpness', sharpness, [float], name)
+        validator.check_value_type('sharpness', sharpness, [int, float], name)
         super(Softplus, self).__init__(name=name, param=param)
         self._sharpness = cast_to_tensor(sharpness)
 
+        self.abs = P.Abs()
         self.exp = P.Exp()
         self.expm1 = self._expm1_by_step
+        self.fill = P.Fill()
+        self.greater = P.Greater()
+        self.less = P.Less()
         self.log_sigmoid = LogSigmoid()
         self.log = P.Log()
+        self.logicalor = P.LogicalOr()
+        self.select = P.Select()
+        self.shape = P.Shape()
         self.sigmoid = P.Sigmoid()
 
         self.softplus = self._softplus
         self.inverse_softplus = self._inverse_softplus
 
+        self.checktensor = CheckTensor()
+        self.threshold = np.log(np.finfo(np.float32).eps) + 1
+
     def _expm1_by_step(self, x):
         """
         Expm1 ops under GPU context.
@@ -72,7 +83,15 @@ class Softplus(Bijector):
         return self.exp(x) - 1.0
 
     def _softplus(self, x):
-        return self.log(self.exp(x) + 1.0)
+        too_small = self.less(x, self.threshold)
+        too_large = self.greater(x, -self.threshold)
+        too_small_value = self.exp(x)
+        too_large_value = x
+        ones = self.fill(mstype.float32, self.shape(x), 1.0)
+        too_small_or_too_large = self.logicalor(too_small, too_large)
+        x = self.select(too_small_or_too_large, ones, x)
+        y = self.log(self.exp(x) + 1.0)
+        return self.select(too_small, too_small_value, self.select(too_large, too_large_value, y))
 
     def _inverse_softplus(self, x):
         r"""
@@ -80,7 +99,15 @@ class Softplus(Bijector):
            f(x) = \frac{\log(1 + e^{x}))}
            f^{-1}(y) = \frac{\log(e^{y} - 1)}
        """
-        return self.log(self.expm1(x))
+        too_small = self.less(x, self.threshold)
+        too_large = self.greater(x, -self.threshold)
+        too_small_value = self.log(x)
+        too_large_value = x
+        ones = self.fill(mstype.float32, self.shape(x), 1.0)
+        too_small_or_too_large = self.logicalor(too_small, too_large)
+        x = self.select(too_small_or_too_large, ones, x)
+        y = x + self.log(self.abs(self.expm1(-x)))
+        return self.select(too_small, too_small_value, self.select(too_large, too_large_value, y))
 
     @property
     def sharpness(self):
@@ -94,6 +121,7 @@ class Softplus(Bijector):
         return shape
 
     def _forward(self, x):
+        self.checktensor(x, 'x')
         scaled_value = self.sharpness * x
         return self.softplus(scaled_value) / self.sharpness
 
@@ -103,6 +131,7 @@ class Softplus(Bijector):
            f(x) = \frac{\log(1 + e^{kx}))}{k}
            f^{-1}(y) = \frac{\log(e^{ky} - 1)}{k}
        """
+        self.checktensor(y, 'y')
         scaled_value = self.sharpness * y
         return self.inverse_softplus(scaled_value) / self.sharpness
 
@@ -113,6 +142,7 @@ class Softplus(Bijector):
            f'(x) = \frac{e^{kx}}{ 1 + e^{kx}}
            \log(f'(x)) = kx - \log(1 + e^{kx}) = kx - f(kx)
        """
+        self.checktensor(x, 'x')
         scaled_value = self.sharpness * x
         return self.log_sigmoid(scaled_value)
 
@@ -123,5 +153,6 @@ class Softplus(Bijector):
            f'(y) = \frac{e^{ky}}{e^{ky} - 1}
            \log(f'(y)) = ky - \log(e^{ky} - 1) = ky - f(ky)
        """
+        self.checktensor(y, 'y')
         scaled_value = self.sharpness * y
         return scaled_value - self.inverse_softplus(scaled_value)
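The rewritten _softplus and _inverse_softplus clamp their argument against a threshold of log(float32 eps) + 1 and route extreme values through a select, so exp never overflows and log never sees a value indistinguishable from zero. A rough NumPy analogue of the same threshold-and-select pattern (illustrative only; the real code builds it from MindSpore ops):

import numpy as np

THRESHOLD = np.log(np.finfo(np.float32).eps) + 1.0   # about -15.9

def softplus(x):
    too_small = x < THRESHOLD          # log(1 + exp(x)) ~ exp(x)
    too_large = x > -THRESHOLD         # log(1 + exp(x)) ~ x
    safe_x = np.where(too_small | too_large, 1.0, x)
    y = np.log(np.exp(safe_x) + 1.0)
    return np.where(too_small, np.exp(x), np.where(too_large, x, y))

def inverse_softplus(x):
    too_small = x < THRESHOLD          # log(expm1(x)) ~ log(x)
    too_large = x > -THRESHOLD         # log(exp(x) - 1) ~ x
    safe_x = np.where(too_small | too_large, 1.0, x)
    y = safe_x + np.log(np.abs(np.expm1(-safe_x)))   # log(exp(x) - 1)
    return np.where(too_small, np.log(x), np.where(too_large, x, y))

x = np.array([0.05, 0.5, 50.0])
assert np.allclose(inverse_softplus(softplus(x)), x)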
Changes to the distribution utility module (_utils/utils.py):

@@ -15,7 +15,8 @@
 """Utitly functions to help distribution class."""
 import numpy as np
 from mindspore.ops import _utils as utils
-from mindspore.ops.primitive import constexpr
+from mindspore.ops.primitive import constexpr, PrimitiveWithInfer, prim_attr_register
+from mindspore._checkparam import Validator as validator
 from mindspore.common.tensor import Tensor
 from mindspore.common.parameter import Parameter
 from mindspore.common import dtype as mstype
@@ -53,7 +54,9 @@ def cast_to_tensor(t, hint_type=mstype.float32):
         raise TypeError(f'Input cannot be Type Bool')
     if isinstance(t, (int, float)):
         return Tensor(t, dtype=t_type)
-    raise TypeError("Input type is not supported.")
+    invalid_type = type(t)
+    raise TypeError(f"Unable to convert input of type {invalid_type} to a Tensor of type {t_type}")
+
 
 def convert_to_batch(t, batch_shape, required_type):
     """
@@ -274,5 +277,51 @@ def raise_none_error(name):
 
 @constexpr
 def check_distribution_name(name, expected_name):
+    if name is None:
+        raise ValueError(f"Distribution should be a constant which is not None.")
     if name != expected_name:
-        raise ValueError(f"Distribution should be {expected_name}.")
+        raise ValueError(f"Expected distribution name is {expected_name}, but got {name}.")
+
+class CheckTuple(PrimitiveWithInfer):
+    """
+    Check if input is a tuple.
+    """
+    @prim_attr_register
+    def __init__(self):
+        """init Cast"""
+        super(CheckTuple, self).__init__("CheckTuple")
+        self.init_prim_io_names(inputs=['x'], outputs=['dummy_output'])
+
+    def __infer__(self, x, name):
+        if not isinstance(x['dtype'], tuple):
+            raise TypeError("Input type should be a tuple: " + name["value"])
+
+        out = {'shape': None,
+               'dtype': None,
+               'value': None}
+        return out
+
+    def __call__(self, *args):
+        return
+
+class CheckTensor(PrimitiveWithInfer):
+    """
+    Check if input is a Tensor.
+    """
+    @prim_attr_register
+    def __init__(self):
+        """init Cast"""
+        super(CheckTensor, self).__init__("CheckTensor")
+        self.init_prim_io_names(inputs=['x'], outputs=['dummy_output'])
+
+    def __infer__(self, x, name):
+        src_type = x['dtype']
+        validator.check_subclass("input", src_type, [mstype.tensor], name["value"])
+
+        out = {'shape': None,
+               'dtype': None,
+               'value': None}
+        return out
+
+    def __call__(self, *args):
+        return
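CheckTensor and CheckTuple do their validation in __infer__, i.e. while the graph is being compiled, and their __call__ is a no-op, so they add no runtime cost. A hedged sketch of how such a primitive is typically wired into a Cell (the import path and model are assumptions for illustration, not taken from this diff):

import mindspore.nn as nn
from mindspore import Tensor
import mindspore.common.dtype as mstype
# assumed location of the helper introduced above
from mindspore.nn.probability.distribution._utils.utils import CheckTensor

class TinyModel(nn.Cell):
    def __init__(self):
        super(TinyModel, self).__init__()
        self.checktensor = CheckTensor()   # validation primitive, no runtime work

    def construct(self, x):
        self.checktensor(x, 'x')           # raises at compile time if x is not a Tensor
        return x * 2

model = TinyModel()
out = model(Tensor([1.0, 2.0], dtype=mstype.float32))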
Changes to the Bernoulli distribution:

@@ -18,6 +18,7 @@ from mindspore.ops import operations as P
 from mindspore.ops import composite as C
 from .distribution import Distribution
 from ._utils.utils import cast_to_tensor, check_prob, check_type, check_distribution_name, raise_none_error
+from ._utils.utils import CheckTensor, CheckTuple
 
 class Bernoulli(Distribution):
     """
@@ -123,6 +124,9 @@ class Bernoulli(Distribution):
         self.sqrt = P.Sqrt()
         self.uniform = C.uniform
 
+        self.checktensor = CheckTensor()
+        self.checktuple = CheckTuple()
+
     def extend_repr(self):
         if self.is_scalar_batch:
             str_info = f'probs = {self.probs}'
@@ -137,14 +141,21 @@ class Bernoulli(Distribution):
         """
         return self._probs
 
+    def _check_param(self, probs1):
+        """
+        Check availablity of distribution specific args probs1.
+        """
+        if probs1 is not None:
+            self.checktensor(probs1, 'probs1')
+            return self.cast(probs1, self.parameter_type)
+        return self.probs if self.probs is not None else raise_none_error('probs1')
+
     def _mean(self, probs1=None):
         r"""
         .. math::
            MEAN(B) = probs1
        """
-        probs1 = self.cast(probs1, self.parameter_type) if probs1 is not None else self.probs
-        if probs1 is None:
-            raise_none_error("probs1")
+        probs1 = self._check_param(probs1)
         return probs1
 
     def _mode(self, probs1=None):
@@ -152,9 +163,7 @@ class Bernoulli(Distribution):
         .. math::
            MODE(B) = 1 if probs1 > 0.5 else = 0
        """
-        probs1 = self.cast(probs1, self.parameter_type) if probs1 is not None else self.probs
-        if probs1 is None:
-            raise_none_error("probs1")
+        probs1 = self._check_param(probs1)
         prob_type = self.dtypeop(probs1)
         zeros = self.fill(prob_type, self.shape(probs1), 0.0)
         ones = self.fill(prob_type, self.shape(probs1), 1.0)
@@ -166,24 +175,20 @@ class Bernoulli(Distribution):
         .. math::
            VAR(B) = probs1 * probs0
        """
-        probs1 = self.cast(probs1, self.parameter_type) if probs1 is not None else self.probs
-        if probs1 is None:
-            raise_none_error("probs1")
+        probs1 = self._check_param(probs1)
         probs0 = 1.0 - probs1
         return self.exp(self.log(probs0) + self.log(probs1))
 
-    def _entropy(self, probs=None):
+    def _entropy(self, probs1=None):
         r"""
         .. math::
            H(B) = -probs0 * \log(probs0) - probs1 * \log(probs1)
        """
-        probs1 = self.cast(probs, self.parameter_type) if probs is not None else self.probs
-        if probs1 is None:
-            raise_none_error("probs")
+        probs1 = self._check_param(probs1)
         probs0 = 1 - probs1
         return -1 * (probs0 * self.log(probs0)) - (probs1 * self.log(probs1))
 
-    def _cross_entropy(self, dist, probs1_b, probs1_a=None):
+    def _cross_entropy(self, dist, probs1_b, probs1=None):
         """
         Evaluate cross_entropy between Bernoulli distributions.
 
@@ -193,9 +198,9 @@ class Bernoulli(Distribution):
            probs1_a (Tensor): probs1 of distribution a. Default: self.probs.
        """
         check_distribution_name(dist, 'Bernoulli')
-        return self._entropy(probs=probs1_a) + self._kl_loss(dist, probs1_b, probs1_a)
+        return self._entropy(probs1) + self._kl_loss(dist, probs1_b, probs1)
 
-    def _log_prob(self, value, probs=None):
+    def _log_prob(self, value, probs1=None):
         r"""
         pmf of Bernoulli distribution.
 
@@ -207,17 +212,14 @@ class Bernoulli(Distribution):
            pmf(k) = probs1 if k = 1;
            pmf(k) = probs0 if k = 0;
        """
-        if value is None:
-            raise_none_error("value")
+        self.checktensor(value, 'value')
         value = self.cast(value, mstype.float32)
         value = self.floor(value)
-        probs1 = self.cast(probs, self.parameter_type) if probs is not None else self.probs
-        if probs1 is None:
-            raise_none_error("probs")
+        probs1 = self._check_param(probs1)
         probs0 = 1.0 - probs1
         return self.log(probs1) * value + self.log(probs0) * (1.0 - value)
 
-    def _cdf(self, value, probs=None):
+    def _cdf(self, value, probs1=None):
         r"""
         cdf of Bernoulli distribution.
 
@@ -230,13 +232,10 @@ class Bernoulli(Distribution):
            cdf(k) = probs0 if 0 <= k <1;
            cdf(k) = 1 if k >=1;
        """
-        if value is None:
-            raise_none_error("value")
+        self.checktensor(value, 'value')
         value = self.cast(value, mstype.float32)
         value = self.floor(value)
-        probs1 = self.cast(probs, self.parameter_type) if probs is not None else self.probs
-        if probs1 is None:
-            raise_none_error("probs")
+        probs1 = self._check_param(probs1)
         prob_type = self.dtypeop(probs1)
         value = value * self.fill(prob_type, self.shape(probs1), 1.0)
         probs0 = 1.0 - probs1 * self.fill(prob_type, self.shape(value), 1.0)
@@ -247,7 +246,7 @@ class Bernoulli(Distribution):
         less_than_zero = self.select(comp_zero, zeros, probs0)
         return self.select(comp_one, less_than_zero, ones)
 
-    def _kl_loss(self, dist, probs1_b, probs1_a=None):
+    def _kl_loss(self, dist, probs1_b, probs1=None):
         r"""
         Evaluate bernoulli-bernoulli kl divergence, i.e. KL(a||b).
 
@@ -261,17 +260,14 @@ class Bernoulli(Distribution):
            probs0_a * \log(\frac{probs0_a}{probs0_b})
        """
         check_distribution_name(dist, 'Bernoulli')
-        if probs1_b is None:
-            raise_none_error("probs1_b")
+        self.checktensor(probs1_b, 'probs1_b')
         probs1_b = self.cast(probs1_b, self.parameter_type)
-        probs1_a = self.cast(probs1_a, self.parameter_type) if probs1_a is not None else self.probs
-        if probs1_a is None:
-            raise_none_error("probs1_a")
+        probs1_a = self._check_param(probs1)
         probs0_a = 1.0 - probs1_a
         probs0_b = 1.0 - probs1_b
         return probs1_a * self.log(probs1_a / probs1_b) + probs0_a * self.log(probs0_a / probs0_b)
 
-    def _sample(self, shape=(), probs=None):
+    def _sample(self, shape=(), probs1=None):
         """
         Sampling.
 
@@ -282,9 +278,8 @@ class Bernoulli(Distribution):
         Returns:
             Tensor, shape is shape + batch_shape.
        """
-        probs1 = self.cast(probs, self.parameter_type) if probs is not None else self.probs
-        if probs1 is None:
-            raise_none_error("probs")
+        self.checktuple(shape, 'shape')
+        probs1 = self._check_param(probs1)
         origin_shape = shape + self.shape(probs1)
         if origin_shape == ():
             sample_shape = (1,)
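Every Bernoulli accessor now funnels its optional probs1 argument through _check_param, which type-checks the tensor, casts it to parameter_type, and otherwise falls back to the stored parameter (raising only when neither is available). A hedged usage sketch, assuming the public msd.Bernoulli API forwards to the internal methods changed above:

import mindspore.nn.probability.distribution as msd
from mindspore import Tensor
import mindspore.common.dtype as mstype

b = msd.Bernoulli(probs=0.7, dtype=mstype.int32)
mean_default = b.mean()                                    # falls back to self.probs
mean_override = b.mean(Tensor(0.3, dtype=mstype.float32))  # checked and cast per call
# b.mean(0.3) is rejected by CheckTensor once the call is compiled into a graph,
# instead of being silently cast as before.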
Changes to the Exponential distribution:

@@ -20,6 +20,7 @@ from mindspore.common import dtype as mstype
 from .distribution import Distribution
 from ._utils.utils import cast_to_tensor, check_greater_zero, check_type, check_distribution_name,\
     raise_none_error
+from ._utils.utils import CheckTensor, CheckTuple
 
 class Exponential(Distribution):
     """
@@ -125,6 +126,9 @@ class Exponential(Distribution):
         self.sq = P.Square()
         self.uniform = C.uniform
 
+        self.checktensor = CheckTensor()
+        self.checktuple = CheckTuple()
+
     def extend_repr(self):
         if self.is_scalar_batch:
             str_info = f'rate = {self.rate}'
@@ -139,14 +143,21 @@ class Exponential(Distribution):
         """
         return self._rate
 
+    def _check_param(self, rate):
+        """
+        Check availablity of distribution specific args rate.
+        """
+        if rate is not None:
+            self.checktensor(rate, 'rate')
+            return self.cast(rate, self.parameter_type)
+        return self.rate if self.rate is not None else raise_none_error('rate')
+
     def _mean(self, rate=None):
         r"""
         .. math::
            MEAN(EXP) = \frac{1.0}{\lambda}.
        """
-        rate = self.cast(rate, self.parameter_type) if rate is not None else self.rate
-        if rate is None:
-            raise_none_error("rate")
+        rate = self._check_param(rate)
         return 1.0 / rate
 
     def _mode(self, rate=None):
@@ -154,9 +165,7 @@ class Exponential(Distribution):
         .. math::
            MODE(EXP) = 0.
        """
-        rate = self.cast(rate, self.parameter_type) if rate is not None else self.rate
-        if rate is None:
-            raise_none_error("rate")
+        rate = self._check_param(rate)
         return self.fill(self.dtype, self.shape(rate), 0.)
 
     def _sd(self, rate=None):
@@ -164,9 +173,7 @@ class Exponential(Distribution):
         .. math::
            sd(EXP) = \frac{1.0}{\lambda}.
        """
-        rate = self.cast(rate, self.parameter_type) if rate is not None else self.rate
-        if rate is None:
-            raise_none_error("rate")
+        rate = self._check_param(rate)
         return 1.0 / rate
 
     def _entropy(self, rate=None):
@@ -174,13 +181,10 @@ class Exponential(Distribution):
         .. math::
            H(Exp) = 1 - \log(\lambda).
        """
-        rate = self.cast(rate, self.parameter_type) if rate is not None else self.rate
-        if rate is None:
-            raise_none_error("rate")
+        rate = self._check_param(rate)
         return 1.0 - self.log(rate)
 
-    def _cross_entropy(self, dist, rate_b, rate_a=None):
+    def _cross_entropy(self, dist, rate_b, rate=None):
         """
         Evaluate cross_entropy between Exponential distributions.
 
@@ -190,7 +194,7 @@ class Exponential(Distribution):
            rate_a (Tensor): rate of distribution a. Default: self.rate.
        """
         check_distribution_name(dist, 'Exponential')
-        return self._entropy(rate=rate_a) + self._kl_loss(dist, rate_b, rate_a)
+        return self._entropy(rate) + self._kl_loss(dist, rate_b, rate)
 
 
     def _prob(self, value, rate=None):
@@ -208,12 +212,9 @@ class Exponential(Distribution):
         .. math::
            pdf(x) = rate * \exp(-1 * \lambda * x) if x >= 0 else 0
        """
-        if value is None:
-            raise_none_error("value")
+        self.checktensor(value, "value")
         value = self.cast(value, self.dtype)
-        rate = self.cast(rate, self.parameter_type) if rate is not None else self.rate
-        if rate is None:
-            raise_none_error("rate")
+        rate = self._check_param(rate)
         prob = self.exp(self.log(rate) - rate * value)
         zeros = self.fill(self.dtypeop(prob), self.shape(prob), 0.0)
         comp = self.less(value, zeros)
@@ -233,19 +234,16 @@ class Exponential(Distribution):
         .. math::
            cdf(x) = 1.0 - \exp(-1 * \lambda * x) if x >= 0 else 0
        """
-        if value is None:
-            raise_none_error("value")
+        self.checktensor(value, 'value')
         value = self.cast(value, self.dtype)
-        rate = self.cast(rate, self.parameter_type) if rate is not None else self.rate
-        if rate is None:
-            raise_none_error("rate")
+        rate = self._check_param(rate)
         cdf = 1.0 - self.exp(-1. * rate * value)
         zeros = self.fill(self.dtypeop(cdf), self.shape(cdf), 0.0)
         comp = self.less(value, zeros)
         return self.select(comp, zeros, cdf)
 
 
-    def _kl_loss(self, dist, rate_b, rate_a=None):
+    def _kl_loss(self, dist, rate_b, rate=None):
         """
         Evaluate exp-exp kl divergence, i.e. KL(a||b).
 
@@ -255,12 +253,9 @@ class Exponential(Distribution):
            rate_a (Tensor): rate of distribution a. Default: self.rate.
        """
         check_distribution_name(dist, 'Exponential')
-        if rate_b is None:
-            raise_none_error("rate_b")
+        self.checktensor(rate_b, 'rate_b')
         rate_b = self.cast(rate_b, self.parameter_type)
-        rate_a = self.cast(rate_a, self.parameter_type) if rate_a is not None else self.rate
-        if rate_a is None:
-            raise_none_error("rate_a")
+        rate_a = self._check_param(rate)
         return self.log(rate_a) - self.log(rate_b) + rate_b / rate_a - 1.0
 
     def _sample(self, shape=(), rate=None):
@@ -274,9 +269,8 @@ class Exponential(Distribution):
         Returns:
             Tensor, shape is shape + batch_shape.
        """
-        rate = self.cast(rate, self.parameter_type) if rate is not None else self.rate
-        if rate is None:
-            raise_none_error("rate")
+        self.checktuple(shape, 'shape')
+        rate = self._check_param(rate)
         origin_shape = shape + self.shape(rate)
         if origin_shape == ():
             sample_shape = (1,)
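The exp-exp KL expression kept as context above, log(rate_a) - log(rate_b) + rate_b / rate_a - 1, can be sanity-checked against a Monte-Carlo estimate. This NumPy sketch is illustrative only (rates and sample size are made up):

import numpy as np

rng = np.random.default_rng(0)
rate_a, rate_b = 2.0, 0.5
x = rng.exponential(scale=1.0 / rate_a, size=1_000_000)   # samples from dist a
log_pa = np.log(rate_a) - rate_a * x
log_pb = np.log(rate_b) - rate_b * x
mc_kl = np.mean(log_pa - log_pb)                           # E_a[log a - log b]
closed_form = np.log(rate_a) - np.log(rate_b) + rate_b / rate_a - 1.0
assert abs(mc_kl - closed_form) < 1e-2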
Changes to the Geometric distribution:

@@ -20,6 +20,7 @@ from mindspore.common import dtype as mstype
 from .distribution import Distribution
 from ._utils.utils import cast_to_tensor, check_prob, check_type, check_distribution_name,\
     raise_none_error
+from ._utils.utils import CheckTensor, CheckTuple
 
 class Geometric(Distribution):
     """
@@ -129,6 +130,9 @@ class Geometric(Distribution):
         self.sqrt = P.Sqrt()
         self.uniform = C.uniform
 
+        self.checktensor = CheckTensor()
+        self.checktuple = CheckTuple()
+
     def extend_repr(self):
         if self.is_scalar_batch:
             str_info = f'probs = {self.probs}'
@@ -143,14 +147,21 @@ class Geometric(Distribution):
         """
         return self._probs
 
+    def _check_param(self, probs1):
+        """
+        Check availablity of distribution specific args probs1.
+        """
+        if probs1 is not None:
+            self.checktensor(probs1, 'probs1')
+            return self.cast(probs1, self.parameter_type)
+        return self.probs if self.probs is not None else raise_none_error('probs1')
+
     def _mean(self, probs1=None):
         r"""
         .. math::
            MEAN(Geo) = \fratc{1 - probs1}{probs1}
        """
-        probs1 = self.cast(probs1, self.parameter_type) if probs1 is not None else self.probs
-        if probs1 is None:
-            raise_none_error("probs1")
+        probs1 = self._check_param(probs1)
         return (1. - probs1) / probs1
 
     def _mode(self, probs1=None):
@@ -158,9 +169,7 @@ class Geometric(Distribution):
         .. math::
            MODE(Geo) = 0
        """
-        probs1 = self.cast(probs1, self.parameter_type) if probs1 is not None else self.probs
-        if probs1 is None:
-            raise_none_error("probs1")
+        probs1 = self._check_param(probs1)
         return self.fill(self.dtypeop(probs1), self.shape(probs1), 0.)
 
     def _var(self, probs1=None):
@@ -168,23 +177,19 @@ class Geometric(Distribution):
         .. math::
            VAR(Geo) = \frac{1 - probs1}{probs1 ^ {2}}
        """
-        probs1 = self.cast(probs1, self.parameter_type) if probs1 is not None else self.probs
-        if probs1 is None:
-            raise_none_error("probs1")
+        probs1 = self._check_param(probs1)
         return (1.0 - probs1) / self.sq(probs1)
 
-    def _entropy(self, probs=None):
+    def _entropy(self, probs1=None):
         r"""
         .. math::
            H(Geo) = \frac{-1 * probs0 \log_2 (1-probs0)\ - prob1 * \log_2 (1-probs1)\ }{probs1}
        """
-        probs1 = self.cast(probs, self.parameter_type) if probs is not None else self.probs
-        if probs1 is None:
-            raise_none_error("probs")
+        probs1 = self._check_param(probs1)
         probs0 = 1.0 - probs1
         return (-probs0 * self.log(probs0) - probs1 * self.log(probs1)) / probs1
 
-    def _cross_entropy(self, dist, probs1_b, probs1_a=None):
+    def _cross_entropy(self, dist, probs1_b, probs1=None):
         r"""
         Evaluate cross_entropy between Geometric distributions.
 
@@ -194,9 +199,9 @@ class Geometric(Distribution):
            probs1_a (Tensor): probability of success of distribution a. Default: self.probs.
        """
         check_distribution_name(dist, 'Geometric')
-        return self._entropy(probs=probs1_a) + self._kl_loss(dist, probs1_b, probs1_a)
+        return self._entropy(probs1) + self._kl_loss(dist, probs1_b, probs1)
 
-    def _prob(self, value, probs=None):
+    def _prob(self, value, probs1=None):
         r"""
         pmf of Geometric distribution.
 
@@ -208,19 +213,16 @@ class Geometric(Distribution):
            pmf(k) = probs0 ^k * probs1 if k >= 0;
            pmf(k) = 0 if k < 0.
        """
-        if value is None:
-            raise_none_error("value")
+        self.checktensor(value, 'value')
         value = self.cast(value, mstype.float32)
         value = self.floor(value)
-        probs1 = self.cast(probs, self.parameter_type) if probs is not None else self.probs
-        if probs1 is None:
-            raise_none_error("probs")
+        probs1 = self._check_param(probs1)
         pmf = self.exp(self.log(1.0 - probs1) * value + self.log(probs1))
         zeros = self.fill(self.dtypeop(probs1), self.shape(pmf), 0.0)
         comp = self.less(value, zeros)
         return self.select(comp, zeros, pmf)
 
-    def _cdf(self, value, probs=None):
+    def _cdf(self, value, probs1=None):
         r"""
         cdf of Geometric distribution.
 
@@ -233,13 +235,10 @@ class Geometric(Distribution):
            cdf(k) = 0 if k < 0.
 
        """
-        if value is None:
-            raise_none_error("value")
+        self.checktensor(value, 'value')
         value = self.cast(value, mstype.float32)
         value = self.floor(value)
-        probs1 = self.cast(probs, self.parameter_type) if probs is not None else self.probs
-        if probs1 is None:
-            raise_none_error("probs")
+        probs1 = self._check_param(probs1)
         probs0 = 1.0 - probs1
         cdf = 1.0 - self.pow(probs0, value + 1.0)
         zeros = self.fill(self.dtypeop(probs1), self.shape(cdf), 0.0)
@@ -247,7 +246,7 @@ class Geometric(Distribution):
         return self.select(comp, zeros, cdf)
 
 
-    def _kl_loss(self, dist, probs1_b, probs1_a=None):
+    def _kl_loss(self, dist, probs1_b, probs1=None):
         r"""
         Evaluate Geometric-Geometric kl divergence, i.e. KL(a||b).
 
@@ -260,17 +259,14 @@ class Geometric(Distribution):
            KL(a||b) = \log(\frac{probs1_a}{probs1_b}) + \frac{probs0_a}{probs1_a} * \log(\frac{probs0_a}{probs0_b})
        """
         check_distribution_name(dist, 'Geometric')
-        if probs1_b is None:
-            raise_none_error("probs1_b")
+        self.checktensor(probs1_b, 'probs1_b')
         probs1_b = self.cast(probs1_b, self.parameter_type)
-        probs1_a = self.cast(probs1_a, self.parameter_type) if probs1_a is not None else self.probs
-        if probs1_a is None:
-            raise_none_error("probs1_a")
+        probs1_a = self._check_param(probs1)
         probs0_a = 1.0 - probs1_a
         probs0_b = 1.0 - probs1_b
         return self.log(probs1_a / probs1_b) + (probs0_a / probs1_a) * self.log(probs0_a / probs0_b)
 
-    def _sample(self, shape=(), probs=None):
+    def _sample(self, shape=(), probs1=None):
         """
         Sampling.
 
@@ -281,9 +277,8 @@ class Geometric(Distribution):
         Returns:
             Tensor, shape is shape + batch_shape.
        """
-        probs1 = self.cast(probs, self.parameter_type) if probs is not None else self.probs
-        if probs1 is None:
-            raise_none_error("probs")
+        self.checktuple(shape, 'shape')
+        probs1 = self._check_param(probs1)
         origin_shape = shape + self.shape(probs1)
         if origin_shape == ():
             sample_shape = (1,)
Changes to the Normal distribution:

@@ -20,6 +20,7 @@ from mindspore.common import dtype as mstype
 from .distribution import Distribution
 from ._utils.utils import convert_to_batch, check_greater_zero, check_type, check_distribution_name,\
     raise_none_error
+from ._utils.utils import CheckTensor, CheckTuple
 
 class Normal(Distribution):
     """
@@ -112,7 +113,6 @@ class Normal(Distribution):
         self._mean_value = mean
         self._sd_value = sd
 
-
         #ops needed for the class
         self.squeeze = P.Squeeze(0)
         self.cast = P.Cast()
@@ -127,6 +127,9 @@ class Normal(Distribution):
         self.sqrt = P.Sqrt()
         self.zeroslike = P.ZerosLike()
 
+        self.checktensor = CheckTensor()
+        self.checktuple = CheckTuple()
+
     def extend_repr(self):
         if self.is_scalar_batch:
             str_info = f'mean = {self._mean_value}, standard deviation = {self._sd_value}'
@@ -140,40 +143,44 @@ class Normal(Distribution):
         """
         return self.exp(x) - 1.0
 
+    def _check_param(self, mean, sd):
+        """
+        Check availablity of distribution specific args mean and sd.
+        """
+        if mean is not None:
+            self.checktensor(mean, 'mean')
+            mean = self.cast(mean, self.parameter_type)
+        else:
+            mean = self._mean_value if self._mean_value is not None else raise_none_error('mean')
+        if sd is not None:
+            self.checktensor(sd, 'sd')
+            sd = self.cast(sd, self.parameter_type)
+        else:
+            sd = self._sd_value if self._sd_value is not None else raise_none_error('sd')
+        batch_shape = self.shape(mean + sd)
+        mean = mean * self.fill(self.dtype, batch_shape, 1.0)
+        sd = sd * self.fill(self.dtype, batch_shape, 1.0)
+        return mean, sd
+
     def _mean(self, mean=None, sd=None):
         """
         Mean of the distribution.
        """
-        mean = self.cast(mean, self.parameter_type) if mean is not None else self._mean_value
-        if mean is None:
-            raise_none_error("mean")
-        sd = self.cast(sd, self.parameter_type) if sd is not None else self._sd_value
-        if sd is None:
-            raise_none_error("sd")
+        mean, sd = self._check_param(mean, sd)
         return mean
 
     def _mode(self, mean=None, sd=None):
         """
         Mode of the distribution.
        """
-        mean = self.cast(mean, self.parameter_type) if mean is not None else self._mean_value
-        if mean is None:
-            raise_none_error("mean")
-        sd = self.cast(sd, self.parameter_type) if sd is not None else self._sd_value
-        if sd is None:
-            raise_none_error("sd")
+        mean, sd = self._check_param(mean, sd)
         return mean
 
     def _sd(self, mean=None, sd=None):
         """
         Standard deviation of the distribution.
        """
-        mean = self.cast(mean, self.parameter_type) if mean is not None else self._mean_value
-        if mean is None:
-            raise_none_error("mean")
-        sd = self.cast(sd, self.parameter_type) if sd is not None else self._sd_value
-        if sd is None:
-            raise_none_error("sd")
+        mean, sd = self._check_param(mean, sd)
         return sd
 
     def _entropy(self, mean=None, sd=None):
@@ -183,15 +190,10 @@ class Normal(Distribution):
         .. math::
            H(X) = \log(\sqrt(numpy.e * 2. * numpy.pi * \sq(\sigma)))
        """
-        mean = self.cast(mean, self.parameter_type) if mean is not None else self._mean_value
-        if mean is None:
-            raise_none_error("mean")
-        sd = self.cast(sd, self.parameter_type) if sd is not None else self._sd_value
-        if sd is None:
-            raise_none_error("sd")
+        mean, sd = self._check_param(mean, sd)
         return self.log(self.sqrt(self.const(np.e * 2. * np.pi))) + self.log(sd)
 
-    def _cross_entropy(self, dist, mean_b, sd_b, mean_a=None, sd_a=None):
+    def _cross_entropy(self, dist, mean_b, sd_b, mean=None, sd=None):
         r"""
         Evaluate cross_entropy between normal distributions.
 
@@ -203,7 +205,7 @@ class Normal(Distribution):
            sd_a (Tensor): standard deviation distribution a. Default: self._sd_value.
        """
         check_distribution_name(dist, 'Normal')
-        return self._entropy(mean=mean_a, sd=sd_a) + self._kl_loss(dist, mean_b, sd_b, mean_a, sd_a)
+        return self._entropy(mean, sd) + self._kl_loss(dist, mean_b, sd_b, mean, sd)
 
     def _log_prob(self, value, mean=None, sd=None):
         r"""
@@ -217,15 +219,9 @@ class Normal(Distribution):
         .. math::
            L(x) = -1* \frac{(x - \mu)^2}{2. * \sigma^2} - \log(\sqrt(2* \pi * \sigma^2))
        """
-        if value is None:
-            raise_none_error("value")
+        self.checktensor(value, 'value')
         value = self.cast(value, self.dtype)
-        mean = self.cast(mean, self.parameter_type) if mean is not None else self._mean_value
-        if mean is None:
-            raise_none_error("mean")
-        sd = self.cast(sd, self.parameter_type) if sd is not None else self._sd_value
-        if sd is None:
-            raise_none_error("sd")
+        mean, sd = self._check_param(mean, sd)
         unnormalized_log_prob = -1. * (self.sq(value - mean)) / (2. * self.sq(sd))
         neg_normalization = -1. * self.log(self.const(2. * np.pi)) / 2. - self.log(sd)
         return unnormalized_log_prob + neg_normalization
@@ -242,20 +238,14 @@ class Normal(Distribution):
         .. math::
            cdf(x) = 0.5 * (1+ Erf((x - \mu) / ( \sigma * \sqrt(2))))
        """
-        if value is None:
-            raise_none_error("value")
+        self.checktensor(value, 'value')
         value = self.cast(value, self.dtype)
-        mean = self.cast(mean, self.parameter_type) if mean is not None else self._mean_value
-        if mean is None:
-            raise_none_error("mean")
-        sd = self.cast(sd, self.parameter_type) if sd is not None else self._sd_value
-        if sd is None:
-            raise_none_error("sd")
+        mean, sd = self._check_param(mean, sd)
         sqrt2 = self.sqrt(self.const(2.0))
         adjusted = (value - mean) / (sd * sqrt2)
         return 0.5 * (1.0 + self.erf(adjusted))
 
-    def _kl_loss(self, dist, mean_b, sd_b, mean_a=None, sd_a=None):
+    def _kl_loss(self, dist, mean_b, sd_b, mean=None, sd=None):
         r"""
         Evaluate Normal-Normal kl divergence, i.e. KL(a||b).
 
@@ -271,23 +261,15 @@ class Normal(Distribution):
            0.5 * EXPM1(2 * (\log(STD(a)) - \log(STD(b))) - (\log(STD(a)) - \log(STD(b)))
        """
         check_distribution_name(dist, 'Normal')
-        if mean_b is None:
-            raise_none_error("mean_b")
-        if sd_b is None:
-            raise_none_error("sd_b")
+        self.checktensor(mean_b, 'mean_b')
+        self.checktensor(sd_b, 'sd_b')
         mean_b = self.cast(mean_b, self.parameter_type)
         sd_b = self.cast(sd_b, self.parameter_type)
-        mean_a = self.cast(mean_a, self.parameter_type) if mean_a is not None else self._mean_value
-        sd_a = self.cast(sd_a, self.parameter_type) if sd_a is not None else self._sd_value
-        if mean_a is None:
-            raise_none_error("mean_a")
-        if sd_a is None:
-            raise_none_error("sd_a")
+        mean_a, sd_a = self._check_param(mean, sd)
         diff_log_scale = self.log(sd_a) - self.log(sd_b)
         squared_diff = self.sq(mean_a / sd_b - mean_b / sd_b)
         return 0.5 * squared_diff + 0.5 * self.expm1(2 * diff_log_scale) - diff_log_scale
 
 
     def _sample(self, shape=(), mean=None, sd=None):
         """
         Sampling.
@@ -300,12 +282,8 @@ class Normal(Distribution):
         Returns:
             Tensor, shape is shape + batch_shape.
        """
-        mean = self.cast(mean, self.parameter_type) if mean is not None else self._mean_value
-        if mean is None:
-            raise_none_error("mean")
-        sd = self.cast(sd, self.parameter_type) if sd is not None else self._sd_value
-        if sd is None:
-            raise_none_error("sd")
+        self.checktuple(shape, 'shape')
+        mean, sd = self._check_param(mean, sd)
         batch_shape = self.shape(mean + sd)
         origin_shape = shape + batch_shape
         if origin_shape == ():
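The Normal KL kept as context above is written with expm1 for numerical stability, but it is algebraically the textbook Normal-Normal divergence. An illustrative NumPy check (parameter values are made up):

import numpy as np

mean_a, sd_a = 0.3, 1.5
mean_b, sd_b = -1.0, 0.7

diff_log_scale = np.log(sd_a) - np.log(sd_b)
squared_diff = (mean_a / sd_b - mean_b / sd_b) ** 2
kl_code = 0.5 * squared_diff + 0.5 * np.expm1(2.0 * diff_log_scale) - diff_log_scale

# KL(a||b) = log(sd_b/sd_a) + (sd_a^2 + (mean_a - mean_b)^2) / (2 sd_b^2) - 1/2
kl_textbook = (np.log(sd_b / sd_a)
               + (sd_a ** 2 + (mean_a - mean_b) ** 2) / (2.0 * sd_b ** 2)
               - 0.5)
assert np.isclose(kl_code, kl_textbook)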
@ -19,6 +19,7 @@ from mindspore.common import dtype as mstype
|
||||||
from .distribution import Distribution
|
from .distribution import Distribution
|
||||||
from ._utils.utils import convert_to_batch, check_greater, check_type, check_distribution_name,\
|
from ._utils.utils import convert_to_batch, check_greater, check_type, check_distribution_name,\
|
||||||
raise_none_error
|
raise_none_error
|
||||||
|
from ._utils.utils import CheckTensor, CheckTuple
|
||||||
|
|
||||||
class Uniform(Distribution):
|
class Uniform(Distribution):
|
||||||
"""
|
"""
|
||||||
|
@ -129,6 +130,9 @@ class Uniform(Distribution):
|
||||||
self.zeroslike = P.ZerosLike()
|
self.zeroslike = P.ZerosLike()
|
||||||
self.uniform = C.uniform
|
self.uniform = C.uniform
|
||||||
|
|
||||||
|
self.checktensor = CheckTensor()
|
||||||
|
self.checktuple = CheckTuple()
|
||||||
|
|
||||||
def extend_repr(self):
|
def extend_repr(self):
|
||||||
if self.is_scalar_batch:
|
if self.is_scalar_batch:
|
||||||
str_info = f'low = {self.low}, high = {self.high}'
|
str_info = f'low = {self.low}, high = {self.high}'
|
||||||
|
@ -136,6 +140,25 @@ class Uniform(Distribution):
|
||||||
str_info = f'batch_shape = {self._broadcast_shape}'
|
str_info = f'batch_shape = {self._broadcast_shape}'
|
||||||
return str_info
|
return str_info
|
||||||
|
|
||||||
|
def _check_param(self, low, high):
|
||||||
|
"""
|
||||||
|
Check availablity of distribution specific args low and high.
|
||||||
|
"""
|
||||||
|
if low is not None:
|
||||||
|
self.checktensor(low, 'low')
|
||||||
|
low = self.cast(low, self.parameter_type)
|
||||||
|
else:
|
||||||
|
low = self.low if self.low is not None else raise_none_error('low')
|
||||||
|
if high is not None:
|
||||||
|
self.checktensor(high, 'high')
|
||||||
|
high = self.cast(high, self.parameter_type)
|
||||||
|
else:
|
||||||
|
high = self.high if self.high is not None else raise_none_error('high')
|
||||||
|
batch_shape = self.shape(high - low)
|
||||||
|
high = high * self.fill(self.dtype, batch_shape, 1.0)
|
||||||
|
low = low * self.fill(self.dtype, batch_shape, 1.0)
|
||||||
|
return low, high
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def low(self):
|
def low(self):
|
||||||
"""
|
"""
|
||||||
|
@@ -156,12 +179,7 @@ class Uniform(Distribution):
         .. math::
             range(U) = high -low
         """
-        low = self.cast(low, self.parameter_type) if low is not None else self.low
-        if low is None:
-            raise_none_error("low")
-        high = self.cast(high, self.parameter_type) if high is not None else self.high
-        if high is None:
-            raise_none_error("high")
+        low, high = self._check_param(low, high)
         return high - low

     def _mean(self, low=None, high=None):
@@ -169,12 +187,7 @@ class Uniform(Distribution):
         .. math::
             MEAN(U) = \frac{low + high}{2}.
         """
-        low = self.cast(low, self.parameter_type) if low is not None else self.low
-        if low is None:
-            raise_none_error("low")
-        high = self.cast(high, self.parameter_type) if high is not None else self.high
-        if high is None:
-            raise_none_error("high")
+        low, high = self._check_param(low, high)
         return (low + high) / 2.

     def _var(self, low=None, high=None):
@@ -182,12 +195,7 @@ class Uniform(Distribution):
         .. math::
             VAR(U) = \frac{(high -low) ^ 2}{12}.
         """
-        low = self.cast(low, self.parameter_type) if low is not None else self.low
-        if low is None:
-            raise_none_error("low")
-        high = self.cast(high, self.parameter_type) if high is not None else self.high
-        if high is None:
-            raise_none_error("high")
+        low, high = self._check_param(low, high)
         return self.sq(high - low) / 12.0

     def _entropy(self, low=None, high=None):
@@ -195,15 +203,10 @@ class Uniform(Distribution):
         .. math::
             H(U) = \log(high - low).
         """
-        low = self.cast(low, self.parameter_type) if low is not None else self.low
-        if low is None:
-            raise_none_error("low")
-        high = self.cast(high, self.parameter_type) if high is not None else self.high
-        if high is None:
-            raise_none_error("high")
+        low, high = self._check_param(low, high)
         return self.log(high - low)

-    def _cross_entropy(self, dist, low_b, high_b, low_a=None, high_a=None):
+    def _cross_entropy(self, dist, low_b, high_b, low=None, high=None):
         """
         Evaluate cross_entropy between Uniform distributoins.

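The returned expressions match the docstring formulas for a uniform distribution. A quick, illustrative spot-check of those formulas for U(low=0, high=2) — not part of the diff:

    import numpy as np

    low, high = 0.0, 2.0
    mean = (low + high) / 2.          # 1.0
    var = (high - low) ** 2 / 12.0    # 0.333...
    entropy = np.log(high - low)      # log(2) ~= 0.6931
    samples = np.random.uniform(low, high, 1_000_000)
    print(mean, samples.mean())       # Monte Carlo estimate agrees to ~3 decimals
    print(var, samples.var())
    print(entropy)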
@@ -215,7 +218,7 @@ class Uniform(Distribution):
             high_a (Tensor): upper bound of distribution a. Default: self.high.
         """
         check_distribution_name(dist, 'Uniform')
-        return self._entropy(low=low_a, high=high_a) + self._kl_loss(dist, low_b, high_b, low_a, high_a)
+        return self._entropy(low, high) + self._kl_loss(dist, low_b, high_b, low, high)

     def _prob(self, value, low=None, high=None):
         r"""
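The new return line keeps the same decomposition as before, only routing the renamed `low`/`high` arguments into both terms. For reference, the identity it computes is

    H(a, b) = -\int p_a(x) \log p_b(x)\, dx = H(a) + D_{KL}(a \| b)

i.e. cross-entropy is the entropy of distribution a plus the KL divergence from a to b.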
@@ -231,15 +234,9 @@ class Uniform(Distribution):
             pdf(x) = \frac{1.0}{high -low} if low <= x <= high;
             pdf(x) = 0 if x > high;
         """
-        if value is None:
-            raise_none_error("value")
+        self.checktensor(value, 'value')
         value = self.cast(value, self.dtype)
-        low = self.cast(low, self.parameter_type) if low is not None else self.low
-        if low is None:
-            raise_none_error("low")
-        high = self.cast(high, self.parameter_type) if high is not None else self.high
-        if high is None:
-            raise_none_error("high")
+        low, high = self._check_param(low, high)
         neg_ones = self.fill(self.dtype, self.shape(value), -1.0)
         prob = self.exp(neg_ones * self.log(high - low))
         broadcast_shape = self.shape(prob)
@@ -249,7 +246,7 @@ class Uniform(Distribution):
         less_than_low = self.select(comp_lo, zeros, prob)
         return self.select(comp_hi, less_than_low, zeros)

-    def _kl_loss(self, dist, low_b, high_b, low_a=None, high_a=None):
+    def _kl_loss(self, dist, low_b, high_b, low=None, high=None):
         """
         Evaluate uniform-uniform kl divergence, i.e. KL(a||b).

@@ -261,19 +258,12 @@ class Uniform(Distribution):
             high_a (Tensor): upper bound of distribution a. Default: self.high.
         """
         check_distribution_name(dist, 'Uniform')
-        if low_b is None:
-            raise_none_error("low_b")
-        if high_b is None:
-            raise_none_error("high_b")
+        self.checktensor(low_b, 'low_b')
         low_b = self.cast(low_b, self.parameter_type)
+        self.checktensor(high_b, 'high_b')
         high_b = self.cast(high_b, self.parameter_type)
-        low_a = self.cast(low_a, self.parameter_type) if low_a is not None else self.low
-        if low_a is None:
-            raise_none_error("low_a")
-        high_a = self.cast(high_a, self.parameter_type) if high_a is not None else self.high
-        if high_a is None:
-            raise_none_error("high_a")
-        kl = self.log(high_b - low_b) / self.log(high_a - low_a)
+        low_a, high_a = self._check_param(low, high)
+        kl = self.log(high_b - low_b) - self.log(high_a - low_a)
         comp = self.logicaland(self.lessequal(low_b, low_a), self.lessequal(high_a, high_b))
         return self.select(comp, kl, self.log(self.zeroslike(kl)))
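The substantive fix in this hunk is the `kl` line: a difference of logs replaces a quotient of logs. For uniform a on [low_a, high_a] and b on [low_b, high_b] with the support of a contained in the support of b (exactly the condition tested by `comp`), the divergence is

    D_{KL}(a \| b) = \int_{low_a}^{high_a} \frac{1}{high_a - low_a}
                     \log \frac{high_b - low_b}{high_a - low_a}\, dx
                   = \log(high_b - low_b) - \log(high_a - low_a)

so the corrected expression is the log of the ratio of the two interval widths, not the ratio of their logs.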
@@ -291,15 +281,9 @@ class Uniform(Distribution):
             cdf(x) = \frac{x - low}{high -low} if low <= x <= high;
             cdf(x) = 1 if x > high;
         """
-        if value is None:
-            raise_none_error("value")
+        self.checktensor(value, 'value')
         value = self.cast(value, self.dtype)
-        low = self.cast(low, self.parameter_type) if low is not None else self.low
-        if low is None:
-            raise_none_error("low")
-        high = self.cast(high, self.parameter_type) if high is not None else self.high
-        if high is None:
-            raise_none_error("high")
+        low, high = self._check_param(low, high)
         prob = (value - low) / (high - low)
         broadcast_shape = self.shape(prob)
         zeros = self.fill(self.dtypeop(prob), broadcast_shape, 0.0)
@@ -321,12 +305,8 @@ class Uniform(Distribution):
         Returns:
             Tensor, shape is shape + batch_shape.
         """
-        low = self.cast(low, self.parameter_type) if low is not None else self.low
-        if low is None:
-            raise_none_error("low")
-        high = self.cast(high, self.parameter_type) if high is not None else self.high
-        if high is None:
-            raise_none_error("high")
+        self.checktuple(shape, 'shape')
+        low, high = self._check_param(low, high)
         broadcast_shape = self.shape(low + high)
         origin_shape = shape + broadcast_shape
         if origin_shape == ():
@@ -75,7 +75,7 @@ def test_forward_jacobian():
     forward_jacobian = Net2()
     x = Tensor([2.0, 3.0, 4.0, 5.0], dtype=dtype.float32)
     ans = forward_jacobian(x)
-    expected = np.log([2.0, 2.0, 2.0, 2.0])
+    expected = np.log([2.0])
     tol = 1e-6
     assert (np.abs(ans.asnumpy() - expected) < tol).all()

@@ -94,6 +94,6 @@ def test_backward_jacobian():
     backward_jacobian = Net3()
     x = Tensor([2.0, 3.0, 4.0, 5.0], dtype=dtype.float32)
     ans = backward_jacobian(x)
-    expected = np.log([0.5, 0.5, 0.5, 0.5])
+    expected = np.log([0.5])
     tol = 1e-6
     assert (np.abs(ans.asnumpy() - expected) < tol).all()
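Shrinking `expected` to a single element is safe here because every entry of the old arrays was identical (the per-element log-Jacobian is constant), so NumPy broadcasting inside the assert compares the same value against every output element. A minimal stand-alone check of that equivalence, using a stand-in for `ans.asnumpy()`:

    import numpy as np

    ans = np.log([2.0, 2.0, 2.0, 2.0])   # stand-in for ans.asnumpy()
    tol = 1e-6
    assert (np.abs(ans - np.log([2.0, 2.0, 2.0, 2.0])) < tol).all()
    assert (np.abs(ans - np.log([2.0])) < tol).all()   # (4,) broadcasts against (1,)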
@@ -20,7 +20,7 @@ import mindspore.nn.probability.bijector as msb
 from mindspore import Tensor
 from mindspore import dtype

-context.set_context(device_target="Ascend")
+context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")

 class Net(nn.Cell):
     """
@@ -88,7 +88,7 @@ def test_kl_loss():
     high_a = 1.5
     low_b = -1.0
     high_b = 2.0
-    expect_kl_loss = np.log(high_b - low_b) / np.log(high_a - low_a)
+    expect_kl_loss = np.log(high_b - low_b) - np.log(high_a - low_a)
     kl = KL()
     output = kl(Tensor(low_b, dtype=dtype.float32), Tensor(high_b, dtype=dtype.float32))
     tol = 1e-6
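The test expectation is updated to match the corrected formula in `Uniform._kl_loss`. A quick numeric sanity check — `low_a` is hypothetical here, since the real value is assigned just above the visible hunk:

    import numpy as np

    low_a, high_a = 0.0, 1.5      # low_a is an illustrative assumption
    low_b, high_b = -1.0, 2.0
    old = np.log(high_b - low_b) / np.log(high_a - low_a)   # quotient of logs (wrong)
    new = np.log(high_b - low_b) - np.log(high_a - low_a)   # difference of logs
    print(old, new)   # ~2.709 vs ~0.693; only `new` equals log((high_b - low_b) / (high_a - low_a))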