diff --git a/mindspore/nn/probability/bijector/bijector.py b/mindspore/nn/probability/bijector/bijector.py index ac011fda33..1e36ae0906 100644 --- a/mindspore/nn/probability/bijector/bijector.py +++ b/mindspore/nn/probability/bijector/bijector.py @@ -26,7 +26,7 @@ class Bijector(Cell): is_constant_jacobian (bool): if the bijector has constant derivative. Default: False. is_injective (bool): if the bijector is an one-to-one mapping. Default: True. name (str): name of the bijector. Default: None. - dtype (mstype): type of the distribution the bijector can operate on. Default: None. + dtype (mindspore.dtype): type of the distribution the bijector can operate on. Default: None. param (dict): parameters used to initialize the bijector. Default: None. """ def __init__(self, @@ -110,7 +110,7 @@ class Bijector(Cell): *args: args[0] shall be either a distribution or the name of a bijector function. """ if isinstance(args[0], Distribution): - return TransformedDistribution(self, args[0]) + return TransformedDistribution(self, args[0], self.distribution.dtype) return super(Bijector, self).__call__(*args, **kwargs) def construct(self, name, *args, **kwargs): diff --git a/mindspore/nn/probability/bijector/softplus.py b/mindspore/nn/probability/bijector/softplus.py index 26f70c8fc7..105beb061b 100644 --- a/mindspore/nn/probability/bijector/softplus.py +++ b/mindspore/nn/probability/bijector/softplus.py @@ -22,7 +22,10 @@ from .bijector import Bijector class Softplus(Bijector): r""" Softplus Bijector. - This Bijector performs the operation: Y = \frac{\log(1 + e ^ {kX})}{k}, where k is the sharpness factor. + This Bijector performs the operation, where k is the sharpness factor. + + .. math:: + Y = \frac{\log(1 + e ^ {kX})}{k} Args: sharpness (float): scale factor. Default: 1.0. 
diff --git a/mindspore/nn/probability/distribution/_utils/utils.py b/mindspore/nn/probability/distribution/_utils/utils.py index 2c11106be5..2a5cf5e7bf 100644 --- a/mindspore/nn/probability/distribution/_utils/utils.py +++ b/mindspore/nn/probability/distribution/_utils/utils.py @@ -184,7 +184,7 @@ def check_greater(a, b, name_a, name_b): def check_prob(p): """ - Check if p is a proper probability, i.e. 0 <= p <=1. + Check if p is a proper probability, i.e. 0 < p < 1. Args: p (Tensor, Parameter): value to be checked. @@ -196,12 +196,12 @@ def check_prob(p): if not isinstance(p.default_input, Tensor): return p = p.default_input - comp = np.less(p.asnumpy(), np.zeros(p.shape)) - if comp.any(): - raise ValueError('Probabilities should be greater than or equal to zero') - comp = np.greater(p.asnumpy(), np.ones(p.shape)) - if comp.any(): - raise ValueError('Probabilities should be less than or equal to one') + comp = np.less(np.zeros(p.shape), p.asnumpy()) + if not comp.all(): + raise ValueError('Probabilities should be greater than zero') + comp = np.greater(np.ones(p.shape), p.asnumpy()) + if not comp.all(): + raise ValueError('Probabilities should be less than one') def logits_to_probs(logits, is_binary=False): diff --git a/mindspore/nn/probability/distribution/bernoulli.py b/mindspore/nn/probability/distribution/bernoulli.py index 028ee175f1..9b48dd7a5e 100644 --- a/mindspore/nn/probability/distribution/bernoulli.py +++ b/mindspore/nn/probability/distribution/bernoulli.py @@ -110,6 +110,7 @@ class Bernoulli(Distribution): self.const = P.ScalarToArray() self.dtypeop = P.DType() self.erf = P.Erf() + self.exp = P.Exp() self.fill = P.Fill() self.log = P.Log() self.less = P.Less() @@ -159,7 +160,7 @@ class Bernoulli(Distribution): """ probs1 = self.probs if probs1 is None else probs1 probs0 = 1.0 - probs1 - return probs0 * probs1 + return self.exp(self.log(probs0) + self.log(probs1)) def _entropy(self, probs=None): r""" @@ -183,7 +184,7 @@ class Bernoulli(Distribution): 
return self._entropy(probs=probs1_a) + self._kl_loss(dist, probs1_b, probs1_a) return None - def _prob(self, value, probs=None): + def _log_prob(self, value, probs=None): r""" pmf of Bernoulli distribution. @@ -197,7 +198,7 @@ class Bernoulli(Distribution): """ probs1 = self.probs if probs is None else probs probs0 = 1.0 - probs1 - return (probs1 * value) + (probs0 * (1.0 - value)) + return self.log(probs1) * value + self.log(probs0) * (1.0 - value) def _cdf(self, value, probs=None): r""" diff --git a/mindspore/nn/probability/distribution/distribution.py b/mindspore/nn/probability/distribution/distribution.py index d3e7d3ccd9..fffc7ed69e 100644 --- a/mindspore/nn/probability/distribution/distribution.py +++ b/mindspore/nn/probability/distribution/distribution.py @@ -15,6 +15,7 @@ """basic""" from mindspore.nn.cell import Cell from mindspore._checkparam import Validator as validator +from mindspore._checkparam import Rel from ._utils.utils import calc_broadcast_shape_from_param, check_scalar_from_param class Distribution(Cell): @@ -28,12 +29,15 @@ class Distribution(Cell): Note: Derived class should override operations such as ,_mean, _prob, - and _log_prob. Arguments should be passed in through *args or **kwargs. + and _log_prob. Required arguments, such as value for _prob, + should be passed in through args or kwargs. dist_spec_args which specify + a new distribution are optional. - Dist_spec_args are unique for each type of distribution. For example, mean and sd - are the dist_spec_args for a Normal distribution. + dist_spec_args are unique for each type of distribution. For example, mean and sd + are the dist_spec_args for a Normal distribution, while rate is the dist_spec_args + for an Exponential distribution. - For all functions, passing in dist_spec_args, are optional. + For all functions, passing in dist_spec_args is optional. Passing in the additional dist_spec_args will make the result to be evaluated with new distribution specified by the dist_spec_args. 
But it won't change the original distribuion. @@ -49,7 +53,7 @@ class Distribution(Cell): """ super(Distribution, self).__init__() validator.check_value_type('name', name, [str], 'distribution_name') - validator.check_value_type('seed', seed, [int], name) + validator.check_integer('seed', seed, 0, Rel.GE, name) self._name = name self._seed = seed @@ -191,7 +195,7 @@ class Distribution(Cell): Note: Args must include value. - Dist_spec_args are optional. + dist_spec_args are optional. """ return self._call_log_prob(*args, **kwargs) @@ -210,7 +214,7 @@ class Distribution(Cell): Note: Args must include value. - Dist_spec_args are optional. + dist_spec_args are optional. """ return self._call_prob(*args, **kwargs) @@ -229,7 +233,7 @@ class Distribution(Cell): Note: Args must include value. - Dist_spec_args are optional. + dist_spec_args are optional. """ return self._call_cdf(*args, **kwargs) @@ -266,7 +270,7 @@ class Distribution(Cell): Note: Args must include value. - Dist_spec_args are optional. + dist_spec_args are optional. """ return self._call_log_cdf(*args, **kwargs) @@ -285,7 +289,7 @@ class Distribution(Cell): Note: Args must include value. - Dist_spec_args are optional. + dist_spec_args are optional. """ return self._call_survival(*args, **kwargs) @@ -313,7 +317,7 @@ class Distribution(Cell): Note: Args must include value. - Dist_spec_args are optional. + dist_spec_args are optional. """ return self._call_log_survival(*args, **kwargs) @@ -341,7 +345,7 @@ class Distribution(Cell): Evaluate the mean. Note: - Dist_spec_args are optional. + dist_spec_args are optional. """ return self._mean(*args, **kwargs) @@ -350,7 +354,7 @@ class Distribution(Cell): Evaluate the mode. Note: - Dist_spec_args are optional. + dist_spec_args are optional. """ return self._mode(*args, **kwargs) @@ -359,7 +363,7 @@ class Distribution(Cell): Evaluate the standard deviation. Note: - Dist_spec_args are optional. + dist_spec_args are optional. 
""" return self._call_sd(*args, **kwargs) @@ -368,7 +372,7 @@ class Distribution(Cell): Evaluate the variance. Note: - Dist_spec_args are optional. + dist_spec_args are optional. """ return self._call_var(*args, **kwargs) @@ -395,7 +399,7 @@ class Distribution(Cell): Evaluate the entropy. Note: - Dist_spec_args are optional. + dist_spec_args are optional. """ return self._entropy(*args, **kwargs) @@ -424,7 +428,7 @@ class Distribution(Cell): Note: Shape of the sample is default to (). - Dist_spec_args are optional. + dist_spec_args are optional. """ return self._sample(*args, **kwargs) diff --git a/mindspore/nn/probability/distribution/exponential.py b/mindspore/nn/probability/distribution/exponential.py index 8564935e09..410f829f5d 100644 --- a/mindspore/nn/probability/distribution/exponential.py +++ b/mindspore/nn/probability/distribution/exponential.py @@ -199,7 +199,7 @@ class Exponential(Distribution): pdf(x) = rate * \exp(-1 * \lambda * x) if x >= 0 else 0 """ rate = self.rate if rate is None else rate - prob = rate * self.exp(-1. 
* rate * value) + prob = self.exp(self.log(rate) - rate * value) zeros = self.fill(self.dtypeop(prob), self.shape(prob), 0.0) comp = self.less(value, zeros) return self.select(comp, zeros, prob) diff --git a/mindspore/nn/probability/distribution/geometric.py b/mindspore/nn/probability/distribution/geometric.py index 2c67bb5588..fbdfc8263b 100644 --- a/mindspore/nn/probability/distribution/geometric.py +++ b/mindspore/nn/probability/distribution/geometric.py @@ -113,6 +113,7 @@ class Geometric(Distribution): self.cast = P.Cast() self.const = P.ScalarToArray() self.dtypeop = P.DType() + self.exp = P.Exp() self.fill = P.Fill() self.floor = P.Floor() self.issubclass = P.IsSubClass() @@ -205,7 +206,7 @@ class Geometric(Distribution): value = self.floor(value) else: return None - pmf = self.pow((1.0 - probs1), value) * probs1 + pmf = self.exp(self.log(1.0 - probs1) * value + self.log(probs1)) zeros = self.fill(self.dtypeop(probs1), self.shape(pmf), 0.0) comp = self.less(value, zeros) return self.select(comp, zeros, pmf) diff --git a/mindspore/nn/probability/distribution/normal.py b/mindspore/nn/probability/distribution/normal.py index e1dfbee89d..1db22d9f73 100644 --- a/mindspore/nn/probability/distribution/normal.py +++ b/mindspore/nn/probability/distribution/normal.py @@ -18,7 +18,7 @@ from mindspore.ops import operations as P from mindspore.ops import composite as C from mindspore.common import dtype as mstype from .distribution import Distribution -from ._utils.utils import convert_to_batch, check_greater_equal_zero, check_type +from ._utils.utils import convert_to_batch, check_greater_zero, check_type class Normal(Distribution): @@ -106,7 +106,7 @@ class Normal(Distribution): if mean is not None and sd is not None: self._mean_value = convert_to_batch(mean, self.broadcast_shape, dtype) self._sd_value = convert_to_batch(sd, self.broadcast_shape, dtype) - check_greater_equal_zero(self._sd_value, "Standard deviation") + check_greater_zero(self._sd_value, "Standard 
deviation") else: self._mean_value = mean self._sd_value = sd @@ -166,7 +166,7 @@ class Normal(Distribution): H(X) = \log(\sqrt(numpy.e * 2. * numpy.pi * \sq(\sigma))) """ sd = self._sd_value if sd is None else sd - return self.log(self.sqrt(np.e * 2. * np.pi * self.sq(sd))) + return self.log(self.sqrt(self.const(np.e * 2. * np.pi))) + self.log(sd) def _cross_entropy(self, dist, mean_b, sd_b, mean_a=None, sd_a=None): r""" @@ -198,7 +198,7 @@ class Normal(Distribution): mean = self._mean_value if mean is None else mean sd = self._sd_value if sd is None else sd unnormalized_log_prob = -1. * (self.sq(value - mean)) / (2. * self.sq(sd)) - neg_normalization = -1. * self.log(self.sqrt(2. * np.pi * self.sq(sd))) + neg_normalization = -1. * self.log(self.sqrt(self.const(2. * np.pi))) - self.log(sd) return unnormalized_log_prob + neg_normalization def _cdf(self, value, mean=None, sd=None): diff --git a/mindspore/nn/probability/distribution/uniform.py b/mindspore/nn/probability/distribution/uniform.py index 6aff1ef775..db248c24d7 100644 --- a/mindspore/nn/probability/distribution/uniform.py +++ b/mindspore/nn/probability/distribution/uniform.py @@ -216,8 +216,8 @@ class Uniform(Distribution): """ low = self.low if low is None else low high = self.high if high is None else high - ones = self.fill(self.dtype, self.shape(value), 1.0) - prob = ones / (high - low) + neg_ones = self.fill(self.dtype, self.shape(value), -1.0) + prob = self.exp(neg_ones * self.log(high - low)) broadcast_shape = self.shape(prob) zeros = self.fill(self.dtypeop(prob), broadcast_shape, 0.0) comp_lo = self.less(value, low) diff --git a/tests/ut/python/nn/distribution/test_bernoulli.py b/tests/ut/python/nn/distribution/test_bernoulli.py index cecf563219..e04438f0a9 100644 --- a/tests/ut/python/nn/distribution/test_bernoulli.py +++ b/tests/ut/python/nn/distribution/test_bernoulli.py @@ -28,7 +28,7 @@ def test_arguments(): """ b = msd.Bernoulli() assert isinstance(b, msd.Distribution) - b = 
msd.Bernoulli([0.0, 0.3, 0.5, 1.0], dtype=dtype.int32) + b = msd.Bernoulli([0.1, 0.3, 0.5, 0.9], dtype=dtype.int32) assert isinstance(b, msd.Distribution) def test_type(): @@ -51,6 +51,10 @@ def test_prob(): msd.Bernoulli([-0.1], dtype=dtype.int32) with pytest.raises(ValueError): msd.Bernoulli([1.1], dtype=dtype.int32) + with pytest.raises(ValueError): + msd.Bernoulli([0.0], dtype=dtype.int32) + with pytest.raises(ValueError): + msd.Bernoulli([1.0], dtype=dtype.int32) class BernoulliProb(nn.Cell): """ diff --git a/tests/ut/python/nn/distribution/test_geometric.py b/tests/ut/python/nn/distribution/test_geometric.py index 11c12f62dc..4685adaa42 100644 --- a/tests/ut/python/nn/distribution/test_geometric.py +++ b/tests/ut/python/nn/distribution/test_geometric.py @@ -29,7 +29,7 @@ def test_arguments(): """ g = msd.Geometric() assert isinstance(g, msd.Distribution) - g = msd.Geometric([0.0, 0.3, 0.5, 1.0], dtype=dtype.int32) + g = msd.Geometric([0.1, 0.3, 0.5, 0.9], dtype=dtype.int32) assert isinstance(g, msd.Distribution) def test_type(): @@ -52,6 +52,10 @@ def test_prob(): msd.Geometric([-0.1], dtype=dtype.int32) with pytest.raises(ValueError): msd.Geometric([1.1], dtype=dtype.int32) + with pytest.raises(ValueError): + msd.Geometric([0.0], dtype=dtype.int32) + with pytest.raises(ValueError): + msd.Geometric([1.0], dtype=dtype.int32) class GeometricProb(nn.Cell): """ diff --git a/tests/ut/python/nn/distribution/test_normal.py b/tests/ut/python/nn/distribution/test_normal.py index 76602bde80..d7ca2f4954 100644 --- a/tests/ut/python/nn/distribution/test_normal.py +++ b/tests/ut/python/nn/distribution/test_normal.py @@ -42,6 +42,12 @@ def test_seed(): with pytest.raises(TypeError): msd.Normal(0., 1., seed='seed') +def test_sd(): + with pytest.raises(ValueError): + msd.Normal(0., 0.) + with pytest.raises(ValueError): + msd.Normal(0., -1.) + def test_arguments(): """ args passing during initialization.