!4875 Fix issues related to parameter checking, formulas in distributions and bijectors

Merge pull request !4875 from XunDeng/pp_issue_branch
mindspore-ci-bot 2020-08-21 07:04:31 +08:00 committed by Gitee
commit 8021dc587d
12 changed files with 284 additions and 249 deletions

View File

@ -16,6 +16,7 @@
from mindspore.ops import operations as P
from mindspore._checkparam import Validator as validator
from mindspore._checkparam import Rel
from ..distribution._utils.utils import CheckTensor
from .bijector import Bijector

class PowerTransform(Bijector):
@ -62,6 +63,8 @@ class PowerTransform(Bijector):
self.log1p = self._log1p_by_step
self.expm1 = self._expm1_by_step
self.checktensor = CheckTensor()

def _log1p_by_step(self, x):
"""
Log1p ops on GPU device or when device_target == GPU.
@ -86,11 +89,13 @@ class PowerTransform(Bijector):
return shape

def _forward(self, x):
self.checktensor(x, 'x')
if self.power == 0:
return self.exp(x)
return self.exp(self.log1p(x * self.power) / self.power)

def _inverse(self, y):
self.checktensor(y, 'y')
if self.power == 0:
return self.log(y)
return self.expm1(self.log(y) * self.power) / self.power
@ -107,6 +112,7 @@ class PowerTransform(Bijector):
f'(x) = e^{\frac{\log(xc + 1)}{c}} * \frac{1}{xc + 1}
\log(f'(x)) = (\frac{1}{c} - 1) * \log(xc + 1)
"""
self.checktensor(x, 'x')
if self.power == 0:
return x
return (1. / self.power - 1) * self.log1p(x * self.power)
@ -123,4 +129,5 @@ class PowerTransform(Bijector):
f'(y) = \frac{e^{c\log(y)}}{y}
\log(f'(y)) = \log(\frac{e^{c\log(y)}}{y}) = (c-1) * \log(y)
"""
self.checktensor(y, 'y')
return (self.power - 1) * self.log(y)
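For reference, a minimal NumPy sketch (not part of the patch) of the formulas the PowerTransform methods above implement, with c standing for self.power. The round trip and a finite-difference check of the log-Jacobian confirm the closed forms:

```python
import numpy as np

# Standalone sketch of the PowerTransform math, assuming c >= 0 and 1 + x * c > 0;
# the c == 0 branch reduces to exp/log exactly as in the class above.
def forward(x, c):
    return np.exp(x) if c == 0 else np.exp(np.log1p(x * c) / c)

def inverse(y, c):
    return np.log(y) if c == 0 else np.expm1(np.log(y) * c) / c

def forward_log_jacobian(x, c):
    # log f'(x) = (1/c - 1) * log(1 + x*c); for c == 0, f(x) = e^x so log f'(x) = x
    return x if c == 0 else (1. / c - 1) * np.log1p(x * c)

x = np.array([0.5, 1.0, 2.0])
c, eps = 0.5, 1e-6
assert np.allclose(inverse(forward(x, c), c), x)
fprime = (forward(x + eps, c) - forward(x, c)) / eps
assert np.allclose(np.log(fprime), forward_log_jacobian(x, c), atol=1e-4)
```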

View File

@ -15,7 +15,7 @@
"""Scalar Affine Bijector""" """Scalar Affine Bijector"""
from mindspore.ops import operations as P from mindspore.ops import operations as P
from mindspore._checkparam import Validator as validator from mindspore._checkparam import Validator as validator
from ..distribution._utils.utils import cast_to_tensor from ..distribution._utils.utils import cast_to_tensor, CheckTensor
from .bijector import Bijector from .bijector import Bijector
class ScalarAffine(Bijector): class ScalarAffine(Bijector):
@ -54,8 +54,8 @@ class ScalarAffine(Bijector):
Constructor of scalar affine bijector.
"""
param = dict(locals())
- validator.check_value_type('scale', scale, [float], name)
+ validator.check_value_type('scale', scale, [int, float], name)
- validator.check_value_type('shift', shift, [float], name)
+ validator.check_value_type('shift', shift, [int, float], name)
self._scale = cast_to_tensor(scale)
self._shift = cast_to_tensor(shift)
super(ScalarAffine, self).__init__(
@ -65,8 +65,10 @@ class ScalarAffine(Bijector):
dtype=None,
param=param)
self.abs = P.Abs()
self.log = P.Log()
self.oneslike = P.OnesLike()
self.checktensor = CheckTensor()

@property
def scale(self):
@ -88,6 +90,7 @@ class ScalarAffine(Bijector):
.. math::
f(x) = a * x + b
"""
self.checktensor(x, 'x')
return self.scale * x + self.shift
def _inverse(self, y):
@ -95,22 +98,25 @@ class ScalarAffine(Bijector):
.. math::
f(y) = \frac{y - b}{a}
"""
self.checktensor(y, 'y')
return (y - self.shift) / self.scale
- def _forward_log_jacobian(self, value):
+ def _forward_log_jacobian(self, x):
r"""
.. math::
f(x) = a * x + b
f'(x) = a
\log(f'(x)) = \log(|a|)
"""
- return self.log(self.scale) * self.oneslike(value)
+ self.checktensor(x, 'x')
+ return self.log(self.abs(self.scale))

- def _inverse_log_jacobian(self, value):
+ def _inverse_log_jacobian(self, y):
r"""
.. math::
f(y) = \frac{(y - b)}{a}
f'(y) = \frac{1.0}{a}
\log(f'(y)) = - \log(|a|)
"""
- return -1. * self.log(self.scale) * self.oneslike(value)
+ self.checktensor(y, 'y')
+ return -1. * self.log(self.abs(self.scale))
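The change above wraps the scale in Abs before taking the log: the Jacobian magnitude of f(x) = a*x + b is |a|, so log|a| is correct for negative scales too, and the per-element broadcast via OnesLike is no longer needed. A small standalone check (not from the patch) against a finite-difference derivative:

```python
import numpy as np

# Sketch: scalar-affine log-Jacobians are +/- log|a| regardless of the sign of a.
def forward_log_jacobian(scale):
    return np.log(np.abs(scale))

def inverse_log_jacobian(scale):
    return -np.log(np.abs(scale))

a, b, x, eps = -3.0, 1.0, 0.7, 1e-6
fprime = ((a * (x + eps) + b) - (a * x + b)) / eps          # numeric derivative of f
assert np.isclose(np.log(np.abs(fprime)), forward_log_jacobian(a), atol=1e-4)
assert np.isclose(inverse_log_jacobian(a), -forward_log_jacobian(a))
```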

View File

@ -13,10 +13,12 @@
# limitations under the License.
# ============================================================================
"""Softplus Bijector"""
import numpy as np
from mindspore.ops import operations as P
from mindspore.common import dtype as mstype
from mindspore.nn.layer.activation import LogSigmoid
from mindspore._checkparam import Validator as validator
- from ..distribution._utils.utils import cast_to_tensor
+ from ..distribution._utils.utils import cast_to_tensor, CheckTensor
from .bijector import Bijector

class Softplus(Bijector):
@ -52,19 +54,28 @@ class Softplus(Bijector):
sharpness=1.0,
name='Softplus'):
param = dict(locals())
- validator.check_value_type('sharpness', sharpness, [float], name)
+ validator.check_value_type('sharpness', sharpness, [int, float], name)
super(Softplus, self).__init__(name=name, param=param)
self._sharpness = cast_to_tensor(sharpness)
self.abs = P.Abs()
self.exp = P.Exp()
self.expm1 = self._expm1_by_step
self.fill = P.Fill()
self.greater = P.Greater()
self.less = P.Less()
self.log_sigmoid = LogSigmoid()
self.log = P.Log()
self.logicalor = P.LogicalOr()
self.select = P.Select()
self.shape = P.Shape()
self.sigmoid = P.Sigmoid()
self.softplus = self._softplus
self.inverse_softplus = self._inverse_softplus
self.checktensor = CheckTensor()
self.threshold = np.log(np.finfo(np.float32).eps) + 1
def _expm1_by_step(self, x):
"""
Expm1 ops under GPU context.
@ -72,7 +83,15 @@ class Softplus(Bijector):
return self.exp(x) - 1.0

def _softplus(self, x):
- return self.log(self.exp(x) + 1.0)
+ too_small = self.less(x, self.threshold)
+ too_large = self.greater(x, -self.threshold)
+ too_small_value = self.exp(x)
+ too_large_value = x
+ ones = self.fill(mstype.float32, self.shape(x), 1.0)
+ too_small_or_too_large = self.logicalor(too_small, too_large)
+ x = self.select(too_small_or_too_large, ones, x)
+ y = self.log(self.exp(x) + 1.0)
+ return self.select(too_small, too_small_value, self.select(too_large, too_large_value, y))
def _inverse_softplus(self, x):
r"""
@ -80,7 +99,15 @@ class Softplus(Bijector):
f(x) = \log(1 + e^{x})
f^{-1}(y) = \log(e^{y} - 1)
"""
- return self.log(self.expm1(x))
+ too_small = self.less(x, self.threshold)
+ too_large = self.greater(x, -self.threshold)
+ too_small_value = self.log(x)
+ too_large_value = x
+ ones = self.fill(mstype.float32, self.shape(x), 1.0)
+ too_small_or_too_large = self.logicalor(too_small, too_large)
+ x = self.select(too_small_or_too_large, ones, x)
+ y = x + self.log(self.abs(self.expm1(-x)))
+ return self.select(too_small, too_small_value, self.select(too_large, too_large_value, y))
@property
def sharpness(self):
@ -94,6 +121,7 @@ class Softplus(Bijector):
return shape

def _forward(self, x):
self.checktensor(x, 'x')
scaled_value = self.sharpness * x
return self.softplus(scaled_value) / self.sharpness
@ -103,6 +131,7 @@ class Softplus(Bijector):
f(x) = \frac{\log(1 + e^{kx})}{k}
f^{-1}(y) = \frac{\log(e^{ky} - 1)}{k}
"""
self.checktensor(y, 'y')
scaled_value = self.sharpness * y
return self.inverse_softplus(scaled_value) / self.sharpness
@ -113,6 +142,7 @@ class Softplus(Bijector):
f'(x) = \frac{e^{kx}}{1 + e^{kx}}
\log(f'(x)) = kx - \log(1 + e^{kx}) = kx - f(kx)
"""
self.checktensor(x, 'x')
scaled_value = self.sharpness * x
return self.log_sigmoid(scaled_value)
@ -123,5 +153,6 @@ class Softplus(Bijector):
f'(y) = \frac{e^{ky}}{e^{ky} - 1}
\log(f'(y)) = ky - \log(e^{ky} - 1) = ky - f(ky)
"""
self.checktensor(y, 'y')
scaled_value = self.sharpness * y
return scaled_value - self.inverse_softplus(scaled_value)
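The rewritten _softplus and _inverse_softplus clamp their input before exponentiating. The threshold log(eps) + 1 marks where float32 can no longer distinguish log(1 + e^x) from its asymptotes: e^x on the far left, x itself on the far right. A standalone NumPy sketch of the same scheme (assumptions mine, not part of the patch):

```python
import numpy as np

threshold = np.log(np.finfo(np.float32).eps) + 1   # ~ -14.9 for float32

def stable_softplus(x):
    x = np.asarray(x, dtype=np.float32)
    too_small = x < threshold            # here log(1 + e^x) ~= e^x
    too_large = x > -threshold           # here log(1 + e^x) ~= x
    ones = np.ones_like(x)
    x_mid = np.where(too_small | too_large, ones, x)   # keep exp() in a safe range
    mid = np.log(np.exp(x_mid) + 1.0)
    small = np.exp(np.minimum(x, ones))                # exact enough when x < threshold
    return np.where(too_small, small, np.where(too_large, x, mid))

print(stable_softplus(np.array([-100.0, 0.0, 100.0])))   # ~[0., 0.6931, 100.]
```

_inverse_softplus uses the same clamping together with the identity log(e^x - 1) = x + log|expm1(-x)|, which stays finite for large x where computing e^x directly would overflow.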

View File

@ -15,7 +15,8 @@
"""Utitly functions to help distribution class.""" """Utitly functions to help distribution class."""
import numpy as np import numpy as np
from mindspore.ops import _utils as utils from mindspore.ops import _utils as utils
from mindspore.ops.primitive import constexpr from mindspore.ops.primitive import constexpr, PrimitiveWithInfer, prim_attr_register
from mindspore._checkparam import Validator as validator
from mindspore.common.tensor import Tensor from mindspore.common.tensor import Tensor
from mindspore.common.parameter import Parameter from mindspore.common.parameter import Parameter
from mindspore.common import dtype as mstype from mindspore.common import dtype as mstype
@ -53,7 +54,9 @@ def cast_to_tensor(t, hint_type=mstype.float32):
raise TypeError(f'Input cannot be Type Bool')
if isinstance(t, (int, float)):
return Tensor(t, dtype=t_type)
- raise TypeError("Input type is not supported.")
+ invalid_type = type(t)
+ raise TypeError(f"Unable to convert input of type {invalid_type} to a Tensor of type {t_type}")

def convert_to_batch(t, batch_shape, required_type):
"""
@ -274,5 +277,51 @@ def raise_none_error(name):
@constexpr
def check_distribution_name(name, expected_name):
if name is None:
raise ValueError(f"Distribution should be a constant which is not None.")
if name != expected_name:
- raise ValueError(f"Distribution should be {expected_name}.")
+ raise ValueError(f"Expected distribution name is {expected_name}, but got {name}.")
class CheckTuple(PrimitiveWithInfer):
"""
Check if input is a tuple.
"""
@prim_attr_register
def __init__(self):
"""init Cast"""
super(CheckTuple, self).__init__("CheckTuple")
self.init_prim_io_names(inputs=['x'], outputs=['dummy_output'])
def __infer__(self, x, name):
if not isinstance(x['dtype'], tuple):
raise TypeError("Input type should be a tuple: " + name["value"])
out = {'shape': None,
'dtype': None,
'value': None}
return out
def __call__(self, *args):
return
class CheckTensor(PrimitiveWithInfer):
"""
Check if input is a Tensor.
"""
@prim_attr_register
def __init__(self):
"""init Cast"""
super(CheckTensor, self).__init__("CheckTensor")
self.init_prim_io_names(inputs=['x'], outputs=['dummy_output'])
def __infer__(self, x, name):
src_type = x['dtype']
validator.check_subclass("input", src_type, [mstype.tensor], name["value"])
out = {'shape': None,
'dtype': None,
'value': None}
return out
def __call__(self, *args):
return
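CheckTuple and CheckTensor do their work at graph-compile time through __infer__, so an invalid argument fails when the network is built rather than deep inside a kernel. A minimal usage sketch; the internal import path and the surrounding Cell are assumptions of this sketch, not part of the patch:

```python
import mindspore.nn as nn
from mindspore import Tensor, context
from mindspore import dtype as mstype
# internal path assumed from the relative imports used elsewhere in this change
from mindspore.nn.probability.distribution._utils.utils import CheckTensor

class ValueChecker(nn.Cell):
    def __init__(self):
        super(ValueChecker, self).__init__()
        self.checktensor = CheckTensor()

    def construct(self, value):
        # __infer__ raises a TypeError at graph construction if value is not a Tensor
        self.checktensor(value, 'value')
        return value

context.set_context(mode=context.GRAPH_MODE)
checker = ValueChecker()
checker(Tensor([0.5], dtype=mstype.float32))   # passes
# checker(0.5)                                 # would fail while building the graph
```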

View File

@ -18,6 +18,7 @@ from mindspore.ops import operations as P
from mindspore.ops import composite as C from mindspore.ops import composite as C
from .distribution import Distribution from .distribution import Distribution
from ._utils.utils import cast_to_tensor, check_prob, check_type, check_distribution_name, raise_none_error from ._utils.utils import cast_to_tensor, check_prob, check_type, check_distribution_name, raise_none_error
from ._utils.utils import CheckTensor, CheckTuple
class Bernoulli(Distribution): class Bernoulli(Distribution):
""" """
@ -123,6 +124,9 @@ class Bernoulli(Distribution):
self.sqrt = P.Sqrt() self.sqrt = P.Sqrt()
self.uniform = C.uniform self.uniform = C.uniform
self.checktensor = CheckTensor()
self.checktuple = CheckTuple()
def extend_repr(self): def extend_repr(self):
if self.is_scalar_batch: if self.is_scalar_batch:
str_info = f'probs = {self.probs}' str_info = f'probs = {self.probs}'
@ -137,14 +141,21 @@ class Bernoulli(Distribution):
""" """
return self._probs return self._probs
def _check_param(self, probs1):
"""
Check availability of the distribution-specific argument probs1.
"""
if probs1 is not None:
self.checktensor(probs1, 'probs1')
return self.cast(probs1, self.parameter_type)
return self.probs if self.probs is not None else raise_none_error('probs1')
def _mean(self, probs1=None): def _mean(self, probs1=None):
r""" r"""
.. math:: .. math::
MEAN(B) = probs1 MEAN(B) = probs1
""" """
probs1 = self.cast(probs1, self.parameter_type) if probs1 is not None else self.probs probs1 = self._check_param(probs1)
if probs1 is None:
raise_none_error("probs1")
return probs1 return probs1
def _mode(self, probs1=None): def _mode(self, probs1=None):
@ -152,9 +163,7 @@ class Bernoulli(Distribution):
.. math:: .. math::
MODE(B) = 1 if probs1 > 0.5 else = 0 MODE(B) = 1 if probs1 > 0.5 else = 0
""" """
probs1 = self.cast(probs1, self.parameter_type) if probs1 is not None else self.probs probs1 = self._check_param(probs1)
if probs1 is None:
raise_none_error("probs1")
prob_type = self.dtypeop(probs1) prob_type = self.dtypeop(probs1)
zeros = self.fill(prob_type, self.shape(probs1), 0.0) zeros = self.fill(prob_type, self.shape(probs1), 0.0)
ones = self.fill(prob_type, self.shape(probs1), 1.0) ones = self.fill(prob_type, self.shape(probs1), 1.0)
@ -166,24 +175,20 @@ class Bernoulli(Distribution):
.. math:: .. math::
VAR(B) = probs1 * probs0 VAR(B) = probs1 * probs0
""" """
probs1 = self.cast(probs1, self.parameter_type) if probs1 is not None else self.probs probs1 = self._check_param(probs1)
if probs1 is None:
raise_none_error("probs1")
probs0 = 1.0 - probs1 probs0 = 1.0 - probs1
return self.exp(self.log(probs0) + self.log(probs1)) return self.exp(self.log(probs0) + self.log(probs1))
def _entropy(self, probs=None): def _entropy(self, probs1=None):
r""" r"""
.. math:: .. math::
H(B) = -probs0 * \log(probs0) - probs1 * \log(probs1) H(B) = -probs0 * \log(probs0) - probs1 * \log(probs1)
""" """
probs1 = self.cast(probs, self.parameter_type) if probs is not None else self.probs probs1 = self._check_param(probs1)
if probs1 is None:
raise_none_error("probs")
probs0 = 1 - probs1 probs0 = 1 - probs1
return -1 * (probs0 * self.log(probs0)) - (probs1 * self.log(probs1)) return -1 * (probs0 * self.log(probs0)) - (probs1 * self.log(probs1))
def _cross_entropy(self, dist, probs1_b, probs1_a=None): def _cross_entropy(self, dist, probs1_b, probs1=None):
""" """
Evaluate cross_entropy between Bernoulli distributions. Evaluate cross_entropy between Bernoulli distributions.
@ -193,9 +198,9 @@ class Bernoulli(Distribution):
probs1_a (Tensor): probs1 of distribution a. Default: self.probs. probs1_a (Tensor): probs1 of distribution a. Default: self.probs.
""" """
check_distribution_name(dist, 'Bernoulli') check_distribution_name(dist, 'Bernoulli')
return self._entropy(probs=probs1_a) + self._kl_loss(dist, probs1_b, probs1_a) return self._entropy(probs1) + self._kl_loss(dist, probs1_b, probs1)
def _log_prob(self, value, probs=None): def _log_prob(self, value, probs1=None):
r""" r"""
pmf of Bernoulli distribution. pmf of Bernoulli distribution.
@ -207,17 +212,14 @@ class Bernoulli(Distribution):
pmf(k) = probs1 if k = 1; pmf(k) = probs1 if k = 1;
pmf(k) = probs0 if k = 0; pmf(k) = probs0 if k = 0;
""" """
if value is None: self.checktensor(value, 'value')
raise_none_error("value")
value = self.cast(value, mstype.float32) value = self.cast(value, mstype.float32)
value = self.floor(value) value = self.floor(value)
probs1 = self.cast(probs, self.parameter_type) if probs is not None else self.probs probs1 = self._check_param(probs1)
if probs1 is None:
raise_none_error("probs")
probs0 = 1.0 - probs1 probs0 = 1.0 - probs1
return self.log(probs1) * value + self.log(probs0) * (1.0 - value) return self.log(probs1) * value + self.log(probs0) * (1.0 - value)
def _cdf(self, value, probs=None): def _cdf(self, value, probs1=None):
r""" r"""
cdf of Bernoulli distribution. cdf of Bernoulli distribution.
@ -230,13 +232,10 @@ class Bernoulli(Distribution):
cdf(k) = probs0 if 0 <= k <1; cdf(k) = probs0 if 0 <= k <1;
cdf(k) = 1 if k >=1; cdf(k) = 1 if k >=1;
""" """
if value is None: self.checktensor(value, 'value')
raise_none_error("value")
value = self.cast(value, mstype.float32) value = self.cast(value, mstype.float32)
value = self.floor(value) value = self.floor(value)
probs1 = self.cast(probs, self.parameter_type) if probs is not None else self.probs probs1 = self._check_param(probs1)
if probs1 is None:
raise_none_error("probs")
prob_type = self.dtypeop(probs1) prob_type = self.dtypeop(probs1)
value = value * self.fill(prob_type, self.shape(probs1), 1.0) value = value * self.fill(prob_type, self.shape(probs1), 1.0)
probs0 = 1.0 - probs1 * self.fill(prob_type, self.shape(value), 1.0) probs0 = 1.0 - probs1 * self.fill(prob_type, self.shape(value), 1.0)
@ -247,7 +246,7 @@ class Bernoulli(Distribution):
less_than_zero = self.select(comp_zero, zeros, probs0) less_than_zero = self.select(comp_zero, zeros, probs0)
return self.select(comp_one, less_than_zero, ones) return self.select(comp_one, less_than_zero, ones)
def _kl_loss(self, dist, probs1_b, probs1_a=None): def _kl_loss(self, dist, probs1_b, probs1=None):
r""" r"""
Evaluate bernoulli-bernoulli kl divergence, i.e. KL(a||b). Evaluate bernoulli-bernoulli kl divergence, i.e. KL(a||b).
@ -261,17 +260,14 @@ class Bernoulli(Distribution):
probs0_a * \log(\frac{probs0_a}{probs0_b}) probs0_a * \log(\frac{probs0_a}{probs0_b})
""" """
check_distribution_name(dist, 'Bernoulli') check_distribution_name(dist, 'Bernoulli')
if probs1_b is None: self.checktensor(probs1_b, 'probs1_b')
raise_none_error("probs1_b")
probs1_b = self.cast(probs1_b, self.parameter_type) probs1_b = self.cast(probs1_b, self.parameter_type)
probs1_a = self.cast(probs1_a, self.parameter_type) if probs1_a is not None else self.probs probs1_a = self._check_param(probs1)
if probs1_a is None:
raise_none_error("probs1_a")
probs0_a = 1.0 - probs1_a probs0_a = 1.0 - probs1_a
probs0_b = 1.0 - probs1_b probs0_b = 1.0 - probs1_b
return probs1_a * self.log(probs1_a / probs1_b) + probs0_a * self.log(probs0_a / probs0_b) return probs1_a * self.log(probs1_a / probs1_b) + probs0_a * self.log(probs0_a / probs0_b)
def _sample(self, shape=(), probs=None): def _sample(self, shape=(), probs1=None):
""" """
Sampling. Sampling.
@ -282,9 +278,8 @@ class Bernoulli(Distribution):
Returns: Returns:
Tensor, shape is shape + batch_shape. Tensor, shape is shape + batch_shape.
""" """
probs1 = self.cast(probs, self.parameter_type) if probs is not None else self.probs self.checktuple(shape, 'shape')
if probs1 is None: probs1 = self._check_param(probs1)
raise_none_error("probs")
origin_shape = shape + self.shape(probs1) origin_shape = shape + self.shape(probs1)
if origin_shape == (): if origin_shape == ():
sample_shape = (1,) sample_shape = (1,)
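The _check_param helper added to each distribution replaces the repeated cast-or-fallback-or-raise blocks above: an explicitly passed parameter must be a Tensor (enforced by CheckTensor) and is cast to parameter_type; otherwise the value stored at construction time is used; if neither is available, raise_none_error fires. A plain-Python sketch of that resolution order (the names and the tensor stand-in are illustrative, not MindSpore API):

```python
def resolve_param(passed, default, name, cast=float):
    """Mirror of the _check_param resolution order, in plain Python."""
    if passed is not None:
        if not hasattr(passed, "__float__"):       # stand-in for the CheckTensor type check
            raise TypeError(f"{name} must be a tensor-like value")
        return cast(passed)                        # stand-in for casting to parameter_type
    if default is not None:
        return default
    raise ValueError(f"{name} must be given at construction or at call time")

assert resolve_param(0.7, 0.5, "probs1") == 0.7    # explicit argument wins
assert resolve_param(None, 0.5, "probs1") == 0.5   # falls back to the stored parameter
```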

View File

@ -20,6 +20,7 @@ from mindspore.common import dtype as mstype
from .distribution import Distribution from .distribution import Distribution
from ._utils.utils import cast_to_tensor, check_greater_zero, check_type, check_distribution_name,\ from ._utils.utils import cast_to_tensor, check_greater_zero, check_type, check_distribution_name,\
raise_none_error raise_none_error
from ._utils.utils import CheckTensor, CheckTuple
class Exponential(Distribution): class Exponential(Distribution):
""" """
@ -125,6 +126,9 @@ class Exponential(Distribution):
self.sq = P.Square() self.sq = P.Square()
self.uniform = C.uniform self.uniform = C.uniform
self.checktensor = CheckTensor()
self.checktuple = CheckTuple()
def extend_repr(self): def extend_repr(self):
if self.is_scalar_batch: if self.is_scalar_batch:
str_info = f'rate = {self.rate}' str_info = f'rate = {self.rate}'
@ -139,14 +143,21 @@ class Exponential(Distribution):
""" """
return self._rate return self._rate
def _check_param(self, rate):
"""
Check availability of the distribution-specific argument rate.
"""
if rate is not None:
self.checktensor(rate, 'rate')
return self.cast(rate, self.parameter_type)
return self.rate if self.rate is not None else raise_none_error('rate')
def _mean(self, rate=None): def _mean(self, rate=None):
r""" r"""
.. math:: .. math::
MEAN(EXP) = \frac{1.0}{\lambda}. MEAN(EXP) = \frac{1.0}{\lambda}.
""" """
rate = self.cast(rate, self.parameter_type) if rate is not None else self.rate rate = self._check_param(rate)
if rate is None:
raise_none_error("rate")
return 1.0 / rate return 1.0 / rate
def _mode(self, rate=None): def _mode(self, rate=None):
@ -154,9 +165,7 @@ class Exponential(Distribution):
.. math:: .. math::
MODE(EXP) = 0. MODE(EXP) = 0.
""" """
rate = self.cast(rate, self.parameter_type) if rate is not None else self.rate rate = self._check_param(rate)
if rate is None:
raise_none_error("rate")
return self.fill(self.dtype, self.shape(rate), 0.) return self.fill(self.dtype, self.shape(rate), 0.)
def _sd(self, rate=None): def _sd(self, rate=None):
@ -164,9 +173,7 @@ class Exponential(Distribution):
.. math:: .. math::
sd(EXP) = \frac{1.0}{\lambda}. sd(EXP) = \frac{1.0}{\lambda}.
""" """
rate = self.cast(rate, self.parameter_type) if rate is not None else self.rate rate = self._check_param(rate)
if rate is None:
raise_none_error("rate")
return 1.0 / rate return 1.0 / rate
def _entropy(self, rate=None): def _entropy(self, rate=None):
@ -174,13 +181,10 @@ class Exponential(Distribution):
.. math:: .. math::
H(Exp) = 1 - \log(\lambda). H(Exp) = 1 - \log(\lambda).
""" """
rate = self.cast(rate, self.parameter_type) if rate is not None else self.rate rate = self._check_param(rate)
if rate is None:
raise_none_error("rate")
return 1.0 - self.log(rate) return 1.0 - self.log(rate)
def _cross_entropy(self, dist, rate_b, rate=None):
def _cross_entropy(self, dist, rate_b, rate_a=None):
""" """
Evaluate cross_entropy between Exponential distributions. Evaluate cross_entropy between Exponential distributions.
@ -190,7 +194,7 @@ class Exponential(Distribution):
rate_a (Tensor): rate of distribution a. Default: self.rate. rate_a (Tensor): rate of distribution a. Default: self.rate.
""" """
check_distribution_name(dist, 'Exponential') check_distribution_name(dist, 'Exponential')
return self._entropy(rate=rate_a) + self._kl_loss(dist, rate_b, rate_a) return self._entropy(rate) + self._kl_loss(dist, rate_b, rate)
def _prob(self, value, rate=None): def _prob(self, value, rate=None):
@ -208,12 +212,9 @@ class Exponential(Distribution):
.. math:: .. math::
pdf(x) = rate * \exp(-1 * \lambda * x) if x >= 0 else 0 pdf(x) = rate * \exp(-1 * \lambda * x) if x >= 0 else 0
""" """
if value is None: self.checktensor(value, "value")
raise_none_error("value")
value = self.cast(value, self.dtype) value = self.cast(value, self.dtype)
rate = self.cast(rate, self.parameter_type) if rate is not None else self.rate rate = self._check_param(rate)
if rate is None:
raise_none_error("rate")
prob = self.exp(self.log(rate) - rate * value) prob = self.exp(self.log(rate) - rate * value)
zeros = self.fill(self.dtypeop(prob), self.shape(prob), 0.0) zeros = self.fill(self.dtypeop(prob), self.shape(prob), 0.0)
comp = self.less(value, zeros) comp = self.less(value, zeros)
@ -233,19 +234,16 @@ class Exponential(Distribution):
.. math:: .. math::
cdf(x) = 1.0 - \exp(-1 * \lambda * x) if x >= 0 else 0 cdf(x) = 1.0 - \exp(-1 * \lambda * x) if x >= 0 else 0
""" """
if value is None: self.checktensor(value, 'value')
raise_none_error("value")
value = self.cast(value, self.dtype) value = self.cast(value, self.dtype)
rate = self.cast(rate, self.parameter_type) if rate is not None else self.rate rate = self._check_param(rate)
if rate is None:
raise_none_error("rate")
cdf = 1.0 - self.exp(-1. * rate * value) cdf = 1.0 - self.exp(-1. * rate * value)
zeros = self.fill(self.dtypeop(cdf), self.shape(cdf), 0.0) zeros = self.fill(self.dtypeop(cdf), self.shape(cdf), 0.0)
comp = self.less(value, zeros) comp = self.less(value, zeros)
return self.select(comp, zeros, cdf) return self.select(comp, zeros, cdf)
def _kl_loss(self, dist, rate_b, rate_a=None): def _kl_loss(self, dist, rate_b, rate=None):
""" """
Evaluate exp-exp kl divergence, i.e. KL(a||b). Evaluate exp-exp kl divergence, i.e. KL(a||b).
@ -255,12 +253,9 @@ class Exponential(Distribution):
rate_a (Tensor): rate of distribution a. Default: self.rate. rate_a (Tensor): rate of distribution a. Default: self.rate.
""" """
check_distribution_name(dist, 'Exponential') check_distribution_name(dist, 'Exponential')
if rate_b is None: self.checktensor(rate_b, 'rate_b')
raise_none_error("rate_b")
rate_b = self.cast(rate_b, self.parameter_type) rate_b = self.cast(rate_b, self.parameter_type)
rate_a = self.cast(rate_a, self.parameter_type) if rate_a is not None else self.rate rate_a = self._check_param(rate)
if rate_a is None:
raise_none_error("rate_a")
return self.log(rate_a) - self.log(rate_b) + rate_b / rate_a - 1.0 return self.log(rate_a) - self.log(rate_b) + rate_b / rate_a - 1.0
def _sample(self, shape=(), rate=None): def _sample(self, shape=(), rate=None):
@ -274,9 +269,8 @@ class Exponential(Distribution):
Returns: Returns:
Tensor, shape is shape + batch_shape. Tensor, shape is shape + batch_shape.
""" """
rate = self.cast(rate, self.parameter_type) if rate is not None else self.rate self.checktuple(shape, 'shape')
if rate is None: rate = self._check_param(rate)
raise_none_error("rate")
origin_shape = shape + self.shape(rate) origin_shape = shape + self.shape(rate)
if origin_shape == (): if origin_shape == ():
sample_shape = (1,) sample_shape = (1,)
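For reference, the closed form kept in the Exponential _kl_loss, KL(a||b) = log(rate_a) - log(rate_b) + rate_b / rate_a - 1, can be sanity-checked by numerical integration (a standalone sketch, not part of the test suite):

```python
import numpy as np

rate_a, rate_b = 1.5, 0.8
closed_form = np.log(rate_a) - np.log(rate_b) + rate_b / rate_a - 1.0

# integrate p_a(x) * log(p_a(x) / p_b(x)) on a midpoint grid; the tail beyond 60 is negligible
dx = 1e-4
x = np.arange(dx / 2, 60.0, dx)
p_a = rate_a * np.exp(-rate_a * x)
p_b = rate_b * np.exp(-rate_b * x)
numeric = np.sum(p_a * np.log(p_a / p_b)) * dx

assert abs(closed_form - numeric) < 1e-4
```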

View File

@ -20,6 +20,7 @@ from mindspore.common import dtype as mstype
from .distribution import Distribution from .distribution import Distribution
from ._utils.utils import cast_to_tensor, check_prob, check_type, check_distribution_name,\ from ._utils.utils import cast_to_tensor, check_prob, check_type, check_distribution_name,\
raise_none_error raise_none_error
from ._utils.utils import CheckTensor, CheckTuple
class Geometric(Distribution): class Geometric(Distribution):
""" """
@ -129,6 +130,9 @@ class Geometric(Distribution):
self.sqrt = P.Sqrt() self.sqrt = P.Sqrt()
self.uniform = C.uniform self.uniform = C.uniform
self.checktensor = CheckTensor()
self.checktuple = CheckTuple()
def extend_repr(self): def extend_repr(self):
if self.is_scalar_batch: if self.is_scalar_batch:
str_info = f'probs = {self.probs}' str_info = f'probs = {self.probs}'
@ -143,14 +147,21 @@ class Geometric(Distribution):
""" """
return self._probs return self._probs
def _check_param(self, probs1):
"""
Check availability of the distribution-specific argument probs1.
"""
if probs1 is not None:
self.checktensor(probs1, 'probs1')
return self.cast(probs1, self.parameter_type)
return self.probs if self.probs is not None else raise_none_error('probs1')
def _mean(self, probs1=None): def _mean(self, probs1=None):
r""" r"""
.. math:: .. math::
MEAN(Geo) = \frac{1 - probs1}{probs1}
""" """
probs1 = self.cast(probs1, self.parameter_type) if probs1 is not None else self.probs probs1 = self._check_param(probs1)
if probs1 is None:
raise_none_error("probs1")
return (1. - probs1) / probs1 return (1. - probs1) / probs1
def _mode(self, probs1=None): def _mode(self, probs1=None):
@ -158,9 +169,7 @@ class Geometric(Distribution):
.. math:: .. math::
MODE(Geo) = 0 MODE(Geo) = 0
""" """
probs1 = self.cast(probs1, self.parameter_type) if probs1 is not None else self.probs probs1 = self._check_param(probs1)
if probs1 is None:
raise_none_error("probs1")
return self.fill(self.dtypeop(probs1), self.shape(probs1), 0.) return self.fill(self.dtypeop(probs1), self.shape(probs1), 0.)
def _var(self, probs1=None): def _var(self, probs1=None):
@ -168,23 +177,19 @@ class Geometric(Distribution):
.. math:: .. math::
VAR(Geo) = \frac{1 - probs1}{probs1 ^ {2}} VAR(Geo) = \frac{1 - probs1}{probs1 ^ {2}}
""" """
probs1 = self.cast(probs1, self.parameter_type) if probs1 is not None else self.probs probs1 = self._check_param(probs1)
if probs1 is None:
raise_none_error("probs1")
return (1.0 - probs1) / self.sq(probs1) return (1.0 - probs1) / self.sq(probs1)
def _entropy(self, probs=None): def _entropy(self, probs1=None):
r""" r"""
.. math:: .. math::
H(Geo) = \frac{-probs0 * \log(probs0) - probs1 * \log(probs1)}{probs1}
""" """
probs1 = self.cast(probs, self.parameter_type) if probs is not None else self.probs probs1 = self._check_param(probs1)
if probs1 is None:
raise_none_error("probs")
probs0 = 1.0 - probs1 probs0 = 1.0 - probs1
return (-probs0 * self.log(probs0) - probs1 * self.log(probs1)) / probs1 return (-probs0 * self.log(probs0) - probs1 * self.log(probs1)) / probs1
def _cross_entropy(self, dist, probs1_b, probs1_a=None): def _cross_entropy(self, dist, probs1_b, probs1=None):
r""" r"""
Evaluate cross_entropy between Geometric distributions. Evaluate cross_entropy between Geometric distributions.
@ -194,9 +199,9 @@ class Geometric(Distribution):
probs1_a (Tensor): probability of success of distribution a. Default: self.probs. probs1_a (Tensor): probability of success of distribution a. Default: self.probs.
""" """
check_distribution_name(dist, 'Geometric') check_distribution_name(dist, 'Geometric')
return self._entropy(probs=probs1_a) + self._kl_loss(dist, probs1_b, probs1_a) return self._entropy(probs1) + self._kl_loss(dist, probs1_b, probs1)
def _prob(self, value, probs=None): def _prob(self, value, probs1=None):
r""" r"""
pmf of Geometric distribution. pmf of Geometric distribution.
@ -208,19 +213,16 @@ class Geometric(Distribution):
pmf(k) = probs0 ^k * probs1 if k >= 0; pmf(k) = probs0 ^k * probs1 if k >= 0;
pmf(k) = 0 if k < 0. pmf(k) = 0 if k < 0.
""" """
if value is None: self.checktensor(value, 'value')
raise_none_error("value")
value = self.cast(value, mstype.float32) value = self.cast(value, mstype.float32)
value = self.floor(value) value = self.floor(value)
probs1 = self.cast(probs, self.parameter_type) if probs is not None else self.probs probs1 = self._check_param(probs1)
if probs1 is None:
raise_none_error("probs")
pmf = self.exp(self.log(1.0 - probs1) * value + self.log(probs1)) pmf = self.exp(self.log(1.0 - probs1) * value + self.log(probs1))
zeros = self.fill(self.dtypeop(probs1), self.shape(pmf), 0.0) zeros = self.fill(self.dtypeop(probs1), self.shape(pmf), 0.0)
comp = self.less(value, zeros) comp = self.less(value, zeros)
return self.select(comp, zeros, pmf) return self.select(comp, zeros, pmf)
def _cdf(self, value, probs=None): def _cdf(self, value, probs1=None):
r""" r"""
cdf of Geometric distribution. cdf of Geometric distribution.
@ -233,13 +235,10 @@ class Geometric(Distribution):
cdf(k) = 0 if k < 0. cdf(k) = 0 if k < 0.
""" """
if value is None: self.checktensor(value, 'value')
raise_none_error("value")
value = self.cast(value, mstype.float32) value = self.cast(value, mstype.float32)
value = self.floor(value) value = self.floor(value)
probs1 = self.cast(probs, self.parameter_type) if probs is not None else self.probs probs1 = self._check_param(probs1)
if probs1 is None:
raise_none_error("probs")
probs0 = 1.0 - probs1 probs0 = 1.0 - probs1
cdf = 1.0 - self.pow(probs0, value + 1.0) cdf = 1.0 - self.pow(probs0, value + 1.0)
zeros = self.fill(self.dtypeop(probs1), self.shape(cdf), 0.0) zeros = self.fill(self.dtypeop(probs1), self.shape(cdf), 0.0)
@ -247,7 +246,7 @@ class Geometric(Distribution):
return self.select(comp, zeros, cdf) return self.select(comp, zeros, cdf)
def _kl_loss(self, dist, probs1_b, probs1_a=None): def _kl_loss(self, dist, probs1_b, probs1=None):
r""" r"""
Evaluate Geometric-Geometric kl divergence, i.e. KL(a||b). Evaluate Geometric-Geometric kl divergence, i.e. KL(a||b).
@ -260,17 +259,14 @@ class Geometric(Distribution):
KL(a||b) = \log(\frac{probs1_a}{probs1_b}) + \frac{probs0_a}{probs1_a} * \log(\frac{probs0_a}{probs0_b}) KL(a||b) = \log(\frac{probs1_a}{probs1_b}) + \frac{probs0_a}{probs1_a} * \log(\frac{probs0_a}{probs0_b})
""" """
check_distribution_name(dist, 'Geometric') check_distribution_name(dist, 'Geometric')
if probs1_b is None: self.checktensor(probs1_b, 'probs1_b')
raise_none_error("probs1_b")
probs1_b = self.cast(probs1_b, self.parameter_type) probs1_b = self.cast(probs1_b, self.parameter_type)
probs1_a = self.cast(probs1_a, self.parameter_type) if probs1_a is not None else self.probs probs1_a = self._check_param(probs1)
if probs1_a is None:
raise_none_error("probs1_a")
probs0_a = 1.0 - probs1_a probs0_a = 1.0 - probs1_a
probs0_b = 1.0 - probs1_b probs0_b = 1.0 - probs1_b
return self.log(probs1_a / probs1_b) + (probs0_a / probs1_a) * self.log(probs0_a / probs0_b) return self.log(probs1_a / probs1_b) + (probs0_a / probs1_a) * self.log(probs0_a / probs0_b)
def _sample(self, shape=(), probs=None): def _sample(self, shape=(), probs1=None):
""" """
Sampling. Sampling.
@ -281,9 +277,8 @@ class Geometric(Distribution):
Returns: Returns:
Tensor, shape is shape + batch_shape. Tensor, shape is shape + batch_shape.
""" """
probs1 = self.cast(probs, self.parameter_type) if probs is not None else self.probs self.checktuple(shape, 'shape')
if probs1 is None: probs1 = self._check_param(probs1)
raise_none_error("probs")
origin_shape = shape + self.shape(probs1) origin_shape = shape + self.shape(probs1)
if origin_shape == (): if origin_shape == ():
sample_shape = (1,) sample_shape = (1,)
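The Geometric _kl_loss closed form, KL(a||b) = log(p_a / p_b) + (q_a / p_a) * log(q_a / q_b) with q = 1 - p, can likewise be checked against the defining series (standalone sketch, not from the patch):

```python
import numpy as np

p_a, p_b = 0.4, 0.7
q_a, q_b = 1 - p_a, 1 - p_b
closed_form = np.log(p_a / p_b) + (q_a / p_a) * np.log(q_a / q_b)

# sum_k  q_a^k * p_a * log( (q_a^k * p_a) / (q_b^k * p_b) ); terms beyond k=200 are negligible
k = np.arange(0, 200)
pmf_a = np.power(q_a, k) * p_a
pmf_b = np.power(q_b, k) * p_b
numeric = np.sum(pmf_a * np.log(pmf_a / pmf_b))

assert abs(closed_form - numeric) < 1e-8
```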

View File

@ -20,6 +20,7 @@ from mindspore.common import dtype as mstype
from .distribution import Distribution from .distribution import Distribution
from ._utils.utils import convert_to_batch, check_greater_zero, check_type, check_distribution_name,\ from ._utils.utils import convert_to_batch, check_greater_zero, check_type, check_distribution_name,\
raise_none_error raise_none_error
from ._utils.utils import CheckTensor, CheckTuple
class Normal(Distribution): class Normal(Distribution):
""" """
@ -112,7 +113,6 @@ class Normal(Distribution):
self._mean_value = mean self._mean_value = mean
self._sd_value = sd self._sd_value = sd
#ops needed for the class #ops needed for the class
self.squeeze = P.Squeeze(0) self.squeeze = P.Squeeze(0)
self.cast = P.Cast() self.cast = P.Cast()
@ -127,6 +127,9 @@ class Normal(Distribution):
self.sqrt = P.Sqrt() self.sqrt = P.Sqrt()
self.zeroslike = P.ZerosLike() self.zeroslike = P.ZerosLike()
self.checktensor = CheckTensor()
self.checktuple = CheckTuple()
def extend_repr(self): def extend_repr(self):
if self.is_scalar_batch: if self.is_scalar_batch:
str_info = f'mean = {self._mean_value}, standard deviation = {self._sd_value}' str_info = f'mean = {self._mean_value}, standard deviation = {self._sd_value}'
@ -140,40 +143,44 @@ class Normal(Distribution):
""" """
return self.exp(x) - 1.0 return self.exp(x) - 1.0
def _check_param(self, mean, sd):
"""
Check availability of the distribution-specific arguments mean and sd.
"""
if mean is not None:
self.checktensor(mean, 'mean')
mean = self.cast(mean, self.parameter_type)
else:
mean = self._mean_value if self._mean_value is not None else raise_none_error('mean')
if sd is not None:
self.checktensor(sd, 'sd')
sd = self.cast(sd, self.parameter_type)
else:
sd = self._sd_value if self._sd_value is not None else raise_none_error('sd')
batch_shape = self.shape(mean + sd)
mean = mean * self.fill(self.dtype, batch_shape, 1.0)
sd = sd * self.fill(self.dtype, batch_shape, 1.0)
return mean, sd
def _mean(self, mean=None, sd=None): def _mean(self, mean=None, sd=None):
""" """
Mean of the distribution. Mean of the distribution.
""" """
mean = self.cast(mean, self.parameter_type) if mean is not None else self._mean_value mean, sd = self._check_param(mean, sd)
if mean is None:
raise_none_error("mean")
sd = self.cast(sd, self.parameter_type) if sd is not None else self._sd_value
if sd is None:
raise_none_error("sd")
return mean return mean
def _mode(self, mean=None, sd=None): def _mode(self, mean=None, sd=None):
""" """
Mode of the distribution. Mode of the distribution.
""" """
mean = self.cast(mean, self.parameter_type) if mean is not None else self._mean_value mean, sd = self._check_param(mean, sd)
if mean is None:
raise_none_error("mean")
sd = self.cast(sd, self.parameter_type) if sd is not None else self._sd_value
if sd is None:
raise_none_error("sd")
return mean return mean
def _sd(self, mean=None, sd=None): def _sd(self, mean=None, sd=None):
""" """
Standard deviation of the distribution. Standard deviation of the distribution.
""" """
mean = self.cast(mean, self.parameter_type) if mean is not None else self._mean_value mean, sd = self._check_param(mean, sd)
if mean is None:
raise_none_error("mean")
sd = self.cast(sd, self.parameter_type) if sd is not None else self._sd_value
if sd is None:
raise_none_error("sd")
return sd return sd
def _entropy(self, mean=None, sd=None): def _entropy(self, mean=None, sd=None):
@ -183,15 +190,10 @@ class Normal(Distribution):
.. math:: .. math::
H(X) = \log(\sqrt(numpy.e * 2. * numpy.pi * \sq(\sigma))) H(X) = \log(\sqrt(numpy.e * 2. * numpy.pi * \sq(\sigma)))
""" """
mean = self.cast(mean, self.parameter_type) if mean is not None else self._mean_value mean, sd = self._check_param(mean, sd)
if mean is None:
raise_none_error("mean")
sd = self.cast(sd, self.parameter_type) if sd is not None else self._sd_value
if sd is None:
raise_none_error("sd")
return self.log(self.sqrt(self.const(np.e * 2. * np.pi))) + self.log(sd) return self.log(self.sqrt(self.const(np.e * 2. * np.pi))) + self.log(sd)
def _cross_entropy(self, dist, mean_b, sd_b, mean_a=None, sd_a=None): def _cross_entropy(self, dist, mean_b, sd_b, mean=None, sd=None):
r""" r"""
Evaluate cross_entropy between normal distributions. Evaluate cross_entropy between normal distributions.
@ -203,7 +205,7 @@ class Normal(Distribution):
sd_a (Tensor): standard deviation distribution a. Default: self._sd_value. sd_a (Tensor): standard deviation distribution a. Default: self._sd_value.
""" """
check_distribution_name(dist, 'Normal') check_distribution_name(dist, 'Normal')
return self._entropy(mean=mean_a, sd=sd_a) + self._kl_loss(dist, mean_b, sd_b, mean_a, sd_a) return self._entropy(mean, sd) + self._kl_loss(dist, mean_b, sd_b, mean, sd)
def _log_prob(self, value, mean=None, sd=None): def _log_prob(self, value, mean=None, sd=None):
r""" r"""
@ -217,15 +219,9 @@ class Normal(Distribution):
.. math:: .. math::
L(x) = -1* \frac{(x - \mu)^2}{2. * \sigma^2} - \log(\sqrt(2* \pi * \sigma^2)) L(x) = -1* \frac{(x - \mu)^2}{2. * \sigma^2} - \log(\sqrt(2* \pi * \sigma^2))
""" """
if value is None: self.checktensor(value, 'value')
raise_none_error("value")
value = self.cast(value, self.dtype) value = self.cast(value, self.dtype)
mean = self.cast(mean, self.parameter_type) if mean is not None else self._mean_value mean, sd = self._check_param(mean, sd)
if mean is None:
raise_none_error("mean")
sd = self.cast(sd, self.parameter_type) if sd is not None else self._sd_value
if sd is None:
raise_none_error("sd")
unnormalized_log_prob = -1. * (self.sq(value - mean)) / (2. * self.sq(sd)) unnormalized_log_prob = -1. * (self.sq(value - mean)) / (2. * self.sq(sd))
neg_normalization = -1. * self.log(self.const(2. * np.pi)) / 2. - self.log(sd) neg_normalization = -1. * self.log(self.const(2. * np.pi)) / 2. - self.log(sd)
return unnormalized_log_prob + neg_normalization return unnormalized_log_prob + neg_normalization
@ -242,20 +238,14 @@ class Normal(Distribution):
.. math:: .. math::
cdf(x) = 0.5 * (1+ Erf((x - \mu) / ( \sigma * \sqrt(2)))) cdf(x) = 0.5 * (1+ Erf((x - \mu) / ( \sigma * \sqrt(2))))
""" """
if value is None: self.checktensor(value, 'value')
raise_none_error("value")
value = self.cast(value, self.dtype) value = self.cast(value, self.dtype)
mean = self.cast(mean, self.parameter_type) if mean is not None else self._mean_value mean, sd = self._check_param(mean, sd)
if mean is None:
raise_none_error("mean")
sd = self.cast(sd, self.parameter_type) if sd is not None else self._sd_value
if sd is None:
raise_none_error("sd")
sqrt2 = self.sqrt(self.const(2.0)) sqrt2 = self.sqrt(self.const(2.0))
adjusted = (value - mean) / (sd * sqrt2) adjusted = (value - mean) / (sd * sqrt2)
return 0.5 * (1.0 + self.erf(adjusted)) return 0.5 * (1.0 + self.erf(adjusted))
def _kl_loss(self, dist, mean_b, sd_b, mean_a=None, sd_a=None): def _kl_loss(self, dist, mean_b, sd_b, mean=None, sd=None):
r""" r"""
Evaluate Normal-Normal kl divergence, i.e. KL(a||b). Evaluate Normal-Normal kl divergence, i.e. KL(a||b).
@ -271,23 +261,15 @@ class Normal(Distribution):
0.5 * EXPM1(2 * (\log(STD(a)) - \log(STD(b))) - (\log(STD(a)) - \log(STD(b))) 0.5 * EXPM1(2 * (\log(STD(a)) - \log(STD(b))) - (\log(STD(a)) - \log(STD(b)))
""" """
check_distribution_name(dist, 'Normal') check_distribution_name(dist, 'Normal')
if mean_b is None: self.checktensor(mean_b, 'mean_b')
raise_none_error("mean_b") self.checktensor(sd_b, 'sd_b')
if sd_b is None:
raise_none_error("sd_b")
mean_b = self.cast(mean_b, self.parameter_type) mean_b = self.cast(mean_b, self.parameter_type)
sd_b = self.cast(sd_b, self.parameter_type) sd_b = self.cast(sd_b, self.parameter_type)
mean_a = self.cast(mean_a, self.parameter_type) if mean_a is not None else self._mean_value mean_a, sd_a = self._check_param(mean, sd)
sd_a = self.cast(sd_a, self.parameter_type) if sd_a is not None else self._sd_value
if mean_a is None:
raise_none_error("mean_a")
if sd_a is None:
raise_none_error("sd_a")
diff_log_scale = self.log(sd_a) - self.log(sd_b) diff_log_scale = self.log(sd_a) - self.log(sd_b)
squared_diff = self.sq(mean_a / sd_b - mean_b / sd_b) squared_diff = self.sq(mean_a / sd_b - mean_b / sd_b)
return 0.5 * squared_diff + 0.5 * self.expm1(2 * diff_log_scale) - diff_log_scale return 0.5 * squared_diff + 0.5 * self.expm1(2 * diff_log_scale) - diff_log_scale
def _sample(self, shape=(), mean=None, sd=None): def _sample(self, shape=(), mean=None, sd=None):
""" """
Sampling. Sampling.
@ -300,12 +282,8 @@ class Normal(Distribution):
Returns: Returns:
Tensor, shape is shape + batch_shape. Tensor, shape is shape + batch_shape.
""" """
mean = self.cast(mean, self.parameter_type) if mean is not None else self._mean_value self.checktuple(shape, 'shape')
if mean is None: mean, sd = self._check_param(mean, sd)
raise_none_error("mean")
sd = self.cast(sd, self.parameter_type) if sd is not None else self._sd_value
if sd is None:
raise_none_error("sd")
batch_shape = self.shape(mean + sd) batch_shape = self.shape(mean + sd)
origin_shape = shape + batch_shape origin_shape = shape + batch_shape
if origin_shape == (): if origin_shape == ():
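The Normal _kl_loss keeps the expm1-based form, which avoids cancellation when the two standard deviations are close. A quick check (sketch, not from the patch) that it agrees with the textbook expression:

```python
import numpy as np

mean_a, sd_a = 1.0, 2.0
mean_b, sd_b = -0.5, 1.5

# form used in _kl_loss above
diff_log_scale = np.log(sd_a) - np.log(sd_b)
squared_diff = np.square(mean_a / sd_b - mean_b / sd_b)
kl_expm1_form = 0.5 * squared_diff + 0.5 * np.expm1(2 * diff_log_scale) - diff_log_scale

# textbook form: KL(N(mean_a, sd_a^2) || N(mean_b, sd_b^2))
kl_textbook = (np.log(sd_b / sd_a)
               + (sd_a ** 2 + (mean_a - mean_b) ** 2) / (2 * sd_b ** 2)
               - 0.5)

assert np.isclose(kl_expm1_form, kl_textbook)
```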

View File

@ -19,6 +19,7 @@ from mindspore.common import dtype as mstype
from .distribution import Distribution from .distribution import Distribution
from ._utils.utils import convert_to_batch, check_greater, check_type, check_distribution_name,\ from ._utils.utils import convert_to_batch, check_greater, check_type, check_distribution_name,\
raise_none_error raise_none_error
from ._utils.utils import CheckTensor, CheckTuple
class Uniform(Distribution): class Uniform(Distribution):
""" """
@ -129,6 +130,9 @@ class Uniform(Distribution):
self.zeroslike = P.ZerosLike() self.zeroslike = P.ZerosLike()
self.uniform = C.uniform self.uniform = C.uniform
self.checktensor = CheckTensor()
self.checktuple = CheckTuple()
def extend_repr(self): def extend_repr(self):
if self.is_scalar_batch: if self.is_scalar_batch:
str_info = f'low = {self.low}, high = {self.high}' str_info = f'low = {self.low}, high = {self.high}'
@ -136,6 +140,25 @@ class Uniform(Distribution):
str_info = f'batch_shape = {self._broadcast_shape}' str_info = f'batch_shape = {self._broadcast_shape}'
return str_info return str_info
def _check_param(self, low, high):
"""
Check availability of the distribution-specific arguments low and high.
"""
if low is not None:
self.checktensor(low, 'low')
low = self.cast(low, self.parameter_type)
else:
low = self.low if self.low is not None else raise_none_error('low')
if high is not None:
self.checktensor(high, 'high')
high = self.cast(high, self.parameter_type)
else:
high = self.high if self.high is not None else raise_none_error('high')
batch_shape = self.shape(high - low)
high = high * self.fill(self.dtype, batch_shape, 1.0)
low = low * self.fill(self.dtype, batch_shape, 1.0)
return low, high
@property @property
def low(self): def low(self):
""" """
@ -156,12 +179,7 @@ class Uniform(Distribution):
.. math:: .. math::
range(U) = high -low range(U) = high -low
""" """
low = self.cast(low, self.parameter_type) if low is not None else self.low low, high = self._check_param(low, high)
if low is None:
raise_none_error("low")
high = self.cast(high, self.parameter_type) if high is not None else self.high
if high is None:
raise_none_error("high")
return high - low return high - low
def _mean(self, low=None, high=None): def _mean(self, low=None, high=None):
@ -169,12 +187,7 @@ class Uniform(Distribution):
.. math:: .. math::
MEAN(U) = \frac{low + high}{2}. MEAN(U) = \frac{low + high}{2}.
""" """
low = self.cast(low, self.parameter_type) if low is not None else self.low low, high = self._check_param(low, high)
if low is None:
raise_none_error("low")
high = self.cast(high, self.parameter_type) if high is not None else self.high
if high is None:
raise_none_error("high")
return (low + high) / 2. return (low + high) / 2.
def _var(self, low=None, high=None): def _var(self, low=None, high=None):
@ -182,12 +195,7 @@ class Uniform(Distribution):
.. math:: .. math::
VAR(U) = \frac{(high -low) ^ 2}{12}. VAR(U) = \frac{(high -low) ^ 2}{12}.
""" """
low = self.cast(low, self.parameter_type) if low is not None else self.low low, high = self._check_param(low, high)
if low is None:
raise_none_error("low")
high = self.cast(high, self.parameter_type) if high is not None else self.high
if high is None:
raise_none_error("high")
return self.sq(high - low) / 12.0 return self.sq(high - low) / 12.0
def _entropy(self, low=None, high=None): def _entropy(self, low=None, high=None):
@ -195,15 +203,10 @@ class Uniform(Distribution):
.. math:: .. math::
H(U) = \log(high - low). H(U) = \log(high - low).
""" """
low = self.cast(low, self.parameter_type) if low is not None else self.low low, high = self._check_param(low, high)
if low is None:
raise_none_error("low")
high = self.cast(high, self.parameter_type) if high is not None else self.high
if high is None:
raise_none_error("high")
return self.log(high - low) return self.log(high - low)
def _cross_entropy(self, dist, low_b, high_b, low_a=None, high_a=None): def _cross_entropy(self, dist, low_b, high_b, low=None, high=None):
""" """
Evaluate cross_entropy between Uniform distributions.
@ -215,7 +218,7 @@ class Uniform(Distribution):
high_a (Tensor): upper bound of distribution a. Default: self.high. high_a (Tensor): upper bound of distribution a. Default: self.high.
""" """
check_distribution_name(dist, 'Uniform') check_distribution_name(dist, 'Uniform')
return self._entropy(low=low_a, high=high_a) + self._kl_loss(dist, low_b, high_b, low_a, high_a) return self._entropy(low, high) + self._kl_loss(dist, low_b, high_b, low, high)
def _prob(self, value, low=None, high=None): def _prob(self, value, low=None, high=None):
r""" r"""
@ -231,15 +234,9 @@ class Uniform(Distribution):
pdf(x) = \frac{1.0}{high -low} if low <= x <= high; pdf(x) = \frac{1.0}{high -low} if low <= x <= high;
pdf(x) = 0 if x > high; pdf(x) = 0 if x > high;
""" """
if value is None: self.checktensor(value, 'value')
raise_none_error("value")
value = self.cast(value, self.dtype) value = self.cast(value, self.dtype)
low = self.cast(low, self.parameter_type) if low is not None else self.low low, high = self._check_param(low, high)
if low is None:
raise_none_error("low")
high = self.cast(high, self.parameter_type) if high is not None else self.high
if high is None:
raise_none_error("high")
neg_ones = self.fill(self.dtype, self.shape(value), -1.0) neg_ones = self.fill(self.dtype, self.shape(value), -1.0)
prob = self.exp(neg_ones * self.log(high - low)) prob = self.exp(neg_ones * self.log(high - low))
broadcast_shape = self.shape(prob) broadcast_shape = self.shape(prob)
@ -249,7 +246,7 @@ class Uniform(Distribution):
less_than_low = self.select(comp_lo, zeros, prob) less_than_low = self.select(comp_lo, zeros, prob)
return self.select(comp_hi, less_than_low, zeros) return self.select(comp_hi, less_than_low, zeros)
def _kl_loss(self, dist, low_b, high_b, low_a=None, high_a=None): def _kl_loss(self, dist, low_b, high_b, low=None, high=None):
""" """
Evaluate uniform-uniform kl divergence, i.e. KL(a||b). Evaluate uniform-uniform kl divergence, i.e. KL(a||b).
@ -261,19 +258,12 @@ class Uniform(Distribution):
high_a (Tensor): upper bound of distribution a. Default: self.high. high_a (Tensor): upper bound of distribution a. Default: self.high.
""" """
check_distribution_name(dist, 'Uniform') check_distribution_name(dist, 'Uniform')
if low_b is None: self.checktensor(low_b, 'low_b')
raise_none_error("low_b")
if high_b is None:
raise_none_error("high_b")
low_b = self.cast(low_b, self.parameter_type) low_b = self.cast(low_b, self.parameter_type)
self.checktensor(high_b, 'high_b')
high_b = self.cast(high_b, self.parameter_type) high_b = self.cast(high_b, self.parameter_type)
low_a = self.cast(low_a, self.parameter_type) if low_a is not None else self.low low_a, high_a = self._check_param(low, high)
if low_a is None: kl = self.log(high_b - low_b) - self.log(high_a - low_a)
raise_none_error("low_a")
high_a = self.cast(high_a, self.parameter_type) if high_a is not None else self.high
if high_a is None:
raise_none_error("high_a")
kl = self.log(high_b - low_b) / self.log(high_a - low_a)
comp = self.logicaland(self.lessequal(low_b, low_a), self.lessequal(high_a, high_b)) comp = self.logicaland(self.lessequal(low_b, low_a), self.lessequal(high_a, high_b))
return self.select(comp, kl, self.log(self.zeroslike(kl))) return self.select(comp, kl, self.log(self.zeroslike(kl)))
@ -291,15 +281,9 @@ class Uniform(Distribution):
cdf(x) = \frac{x - low}{high -low} if low <= x <= high; cdf(x) = \frac{x - low}{high -low} if low <= x <= high;
cdf(x) = 1 if x > high; cdf(x) = 1 if x > high;
""" """
if value is None: self.checktensor(value, 'value')
raise_none_error("value")
value = self.cast(value, self.dtype) value = self.cast(value, self.dtype)
low = self.cast(low, self.parameter_type) if low is not None else self.low low, high = self._check_param(low, high)
if low is None:
raise_none_error("low")
high = self.cast(high, self.parameter_type) if high is not None else self.high
if high is None:
raise_none_error("high")
prob = (value - low) / (high - low) prob = (value - low) / (high - low)
broadcast_shape = self.shape(prob) broadcast_shape = self.shape(prob)
zeros = self.fill(self.dtypeop(prob), broadcast_shape, 0.0) zeros = self.fill(self.dtypeop(prob), broadcast_shape, 0.0)
@ -321,12 +305,8 @@ class Uniform(Distribution):
Returns: Returns:
Tensor, shape is shape + batch_shape. Tensor, shape is shape + batch_shape.
""" """
low = self.cast(low, self.parameter_type) if low is not None else self.low self.checktuple(shape, 'shape')
if low is None: low, high = self._check_param(low, high)
raise_none_error("low")
high = self.cast(high, self.parameter_type) if high is not None else self.high
if high is None:
raise_none_error("high")
broadcast_shape = self.shape(low + high) broadcast_shape = self.shape(low + high)
origin_shape = shape + broadcast_shape origin_shape = shape + broadcast_shape
if origin_shape == (): if origin_shape == ():
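The substantive fix in the Uniform _kl_loss replaces a ratio of logs with a difference of logs. When the support of a lies inside the support of b the density ratio is constant, so KL(a||b) = log(high_b - low_b) - log(high_a - low_a); the updated test expectation at the end of this change encodes the same formula. A short check (sketch):

```python
import numpy as np

low_a, high_a = -1.0, 1.5
low_b, high_b = -1.0, 2.0

# KL(a||b) = E_a[log p_a - log p_b]; both densities are constant on [low_a, high_a],
# so the expectation is exactly log(high_b - low_b) - log(high_a - low_a).
definition = np.log(high_b - low_b) - np.log(high_a - low_a)        # ~0.1823

old_expression = np.log(high_b - low_b) / np.log(high_a - low_a)    # ~1.199, incorrect
new_expression = np.log(high_b - low_b) - np.log(high_a - low_a)

assert np.isclose(new_expression, definition)
assert not np.isclose(old_expression, definition)
```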

View File

@ -75,7 +75,7 @@ def test_forward_jacobian():
forward_jacobian = Net2() forward_jacobian = Net2()
x = Tensor([2.0, 3.0, 4.0, 5.0], dtype=dtype.float32) x = Tensor([2.0, 3.0, 4.0, 5.0], dtype=dtype.float32)
ans = forward_jacobian(x) ans = forward_jacobian(x)
- expected = np.log([2.0, 2.0, 2.0, 2.0])
+ expected = np.log([2.0])
tol = 1e-6 tol = 1e-6
assert (np.abs(ans.asnumpy() - expected) < tol).all() assert (np.abs(ans.asnumpy() - expected) < tol).all()
@ -94,6 +94,6 @@ def test_backward_jacobian():
backward_jacobian = Net3() backward_jacobian = Net3()
x = Tensor([2.0, 3.0, 4.0, 5.0], dtype=dtype.float32) x = Tensor([2.0, 3.0, 4.0, 5.0], dtype=dtype.float32)
ans = backward_jacobian(x) ans = backward_jacobian(x)
- expected = np.log([0.5, 0.5, 0.5, 0.5])
+ expected = np.log([0.5])
tol = 1e-6 tol = 1e-6
assert (np.abs(ans.asnumpy() - expected) < tol).all() assert (np.abs(ans.asnumpy() - expected) < tol).all()
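With the ScalarAffine change, the forward and inverse log-Jacobians reduce to log|a| and -log|a| and no longer depend on the input's shape, so the expected arrays in these tests shrink to a single value. NumPy broadcasting keeps the element-wise assertion meaningful either way; a sketch of the comparison, with the scale value 2.0 inferred from the expected result:

```python
import numpy as np

expected = np.log([2.0])   # single expected value after the change
tol = 1e-6
# the assertion holds whether the bijector returns one value or one value per input element
for ans in (np.float32(np.log(2.0)), np.full(4, np.log(2.0), dtype=np.float32)):
    assert (np.abs(ans - expected) < tol).all()
```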

View File

@ -20,7 +20,7 @@ import mindspore.nn.probability.bijector as msb
from mindspore import Tensor from mindspore import Tensor
from mindspore import dtype from mindspore import dtype
- context.set_context(device_target="Ascend")
+ context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
class Net(nn.Cell): class Net(nn.Cell):
""" """

View File

@ -88,7 +88,7 @@ def test_kl_loss():
high_a = 1.5 high_a = 1.5
low_b = -1.0 low_b = -1.0
high_b = 2.0 high_b = 2.0
- expect_kl_loss = np.log(high_b - low_b) / np.log(high_a - low_a)
+ expect_kl_loss = np.log(high_b - low_b) - np.log(high_a - low_a)
kl = KL() kl = KL()
output = kl(Tensor(low_b, dtype=dtype.float32), Tensor(high_b, dtype=dtype.float32)) output = kl(Tensor(low_b, dtype=dtype.float32), Tensor(high_b, dtype=dtype.float32))
tol = 1e-6 tol = 1e-6