forked from mindspore-Ecosystem/mindspore
!4369 Added distribution_specific_args checking to some functions and fixed minor issues in docs
Merge pull request !4369 from XunDeng/pp_poc_v3
This commit is contained in:
commit
93f60269c3
|
@ -26,7 +26,7 @@ class Bijector(Cell):
|
|||
is_constant_jacobian (bool): if the bijector has constant derivative. Default: False.
|
||||
is_injective (bool): if the bijector is a one-to-one mapping. Default: True.
|
||||
name (str): name of the bijector. Default: None.
|
||||
dtype (mstype): type of the distribution the bijector can operate on. Default: None.
|
||||
dtype (mindspore.dtype): type of the distribution the bijector can operate on. Default: None.
|
||||
param (dict): parameters used to initialize the bijector. Default: None.
|
||||
"""
|
||||
def __init__(self,
|
||||
|
@ -110,7 +110,7 @@ class Bijector(Cell):
|
|||
*args: args[0] shall be either a distribution or the name of a bijector function.
|
||||
"""
|
||||
if isinstance(args[0], Distribution):
|
||||
return TransformedDistribution(self, args[0])
|
||||
return TransformedDistribution(self, args[0], self.distribution.dtype)
|
||||
return super(Bijector, self).__call__(*args, **kwargs)
|
||||
|
||||
def construct(self, name, *args, **kwargs):
|
||||
|
|
|
@ -22,7 +22,10 @@ from .bijector import Bijector
|
|||
class Softplus(Bijector):
|
||||
r"""
|
||||
Softplus Bijector.
|
||||
This Bijector performs the operation: Y = \frac{\log(1 + e ^ {kX})}{k}, where k is the sharpness factor.
|
||||
This Bijector performs the following operation, where k is the sharpness factor:
|
||||
|
||||
.. math::
|
||||
Y = \frac{\log(1 + e ^ {kX})}{k}
|
||||
|
||||
Args:
|
||||
sharpness (float): scale factor. Default: 1.0.
|
||||
|
|
|
@ -184,7 +184,7 @@ def check_greater(a, b, name_a, name_b):
|
|||
|
||||
def check_prob(p):
|
||||
"""
|
||||
Check if p is a proper probability, i.e. 0 <= p <=1.
|
||||
Check if p is a proper probability, i.e. 0 < p < 1.
|
||||
|
||||
Args:
|
||||
p (Tensor, Parameter): value to be checked.
|
||||
|
@ -196,12 +196,12 @@ def check_prob(p):
|
|||
if not isinstance(p.default_input, Tensor):
|
||||
return
|
||||
p = p.default_input
|
||||
comp = np.less(p.asnumpy(), np.zeros(p.shape))
|
||||
if comp.any():
|
||||
raise ValueError('Probabilities should be greater than or equal to zero')
|
||||
comp = np.greater(p.asnumpy(), np.ones(p.shape))
|
||||
if comp.any():
|
||||
raise ValueError('Probabilities should be less than or equal to one')
|
||||
comp = np.less(np.zeros(p.shape), p.asnumpy())
|
||||
if not comp.all():
|
||||
raise ValueError('Probabilities should be greater than zero')
|
||||
comp = np.greater(np.ones(p.shape), p.asnumpy())
|
||||
if not comp.all():
|
||||
raise ValueError('Probabilities should be less than one')
|
||||
|
||||
|
||||
def logits_to_probs(logits, is_binary=False):
|
||||
|
|
|
@ -110,6 +110,7 @@ class Bernoulli(Distribution):
|
|||
self.const = P.ScalarToArray()
|
||||
self.dtypeop = P.DType()
|
||||
self.erf = P.Erf()
|
||||
self.exp = P.Exp()
|
||||
self.fill = P.Fill()
|
||||
self.log = P.Log()
|
||||
self.less = P.Less()
|
||||
|
@ -159,7 +160,7 @@ class Bernoulli(Distribution):
|
|||
"""
|
||||
probs1 = self.probs if probs1 is None else probs1
|
||||
probs0 = 1.0 - probs1
|
||||
return probs0 * probs1
|
||||
return self.exp(self.log(probs0) + self.log(probs1))
|
||||
|
||||
def _entropy(self, probs=None):
|
||||
r"""
|
||||
|
@ -183,7 +184,7 @@ class Bernoulli(Distribution):
|
|||
return self._entropy(probs=probs1_a) + self._kl_loss(dist, probs1_b, probs1_a)
|
||||
return None
|
||||
|
||||
def _prob(self, value, probs=None):
|
||||
def _log_prob(self, value, probs=None):
|
||||
r"""
|
||||
Log pmf of Bernoulli distribution.
|
||||
|
||||
|
@ -197,7 +198,7 @@ class Bernoulli(Distribution):
|
|||
"""
|
||||
probs1 = self.probs if probs is None else probs
|
||||
probs0 = 1.0 - probs1
|
||||
return (probs1 * value) + (probs0 * (1.0 - value))
|
||||
return self.log(probs1) * value + self.log(probs0) * (1.0 - value)
|
||||
|
||||
def _cdf(self, value, probs=None):
|
||||
r"""
|
||||
|
|
|
@ -15,6 +15,7 @@
|
|||
"""basic"""
|
||||
from mindspore.nn.cell import Cell
|
||||
from mindspore._checkparam import Validator as validator
|
||||
from mindspore._checkparam import Rel
|
||||
from ._utils.utils import calc_broadcast_shape_from_param, check_scalar_from_param
|
||||
|
||||
class Distribution(Cell):
|
||||
|
@ -28,12 +29,15 @@ class Distribution(Cell):
|
|||
|
||||
Note:
|
||||
Derived class should override operations such as _mean, _prob,
|
||||
and _log_prob. Arguments should be passed in through *args or **kwargs.
|
||||
and _log_prob. Required arguments, such as value for _prob,
|
||||
should be passed in through args or kwargs. dist_spec_args which specify
|
||||
a new distribution are optional.
|
||||
|
||||
Dist_spec_args are unique for each type of distribution. For example, mean and sd
|
||||
are the dist_spec_args for a Normal distribution.
|
||||
dist_spec_args are unique for each type of distribution. For example, mean and sd
|
||||
are the dist_spec_args for a Normal distribution, while rate is the dist_spec_args
|
||||
for exponential distribution.
|
||||
|
||||
For all functions, passing in dist_spec_args, are optional.
|
||||
For all functions, passing in dist_spec_args is optional.
|
||||
Passing in the additional dist_spec_args will cause the result to be evaluated with
|
||||
new distribution specified by the dist_spec_args. But it won't change the
|
||||
original distribution.
|
||||
|
@ -49,7 +53,7 @@ class Distribution(Cell):
|
|||
"""
|
||||
super(Distribution, self).__init__()
|
||||
validator.check_value_type('name', name, [str], 'distribution_name')
|
||||
validator.check_value_type('seed', seed, [int], name)
|
||||
validator.check_integer('seed', seed, 0, Rel.GE, name)
|
||||
|
||||
self._name = name
|
||||
self._seed = seed
|
||||
|
@ -191,7 +195,7 @@ class Distribution(Cell):
|
|||
|
||||
Note:
|
||||
Args must include value.
|
||||
Dist_spec_args are optional.
|
||||
dist_spec_args are optional.
|
||||
"""
|
||||
return self._call_log_prob(*args, **kwargs)
|
||||
|
||||
|
@ -210,7 +214,7 @@ class Distribution(Cell):
|
|||
|
||||
Note:
|
||||
Args must include value.
|
||||
Dist_spec_args are optional.
|
||||
dist_spec_args are optional.
|
||||
"""
|
||||
return self._call_prob(*args, **kwargs)
|
||||
|
||||
|
@ -229,7 +233,7 @@ class Distribution(Cell):
|
|||
|
||||
Note:
|
||||
Args must include value.
|
||||
Dist_spec_args are optional.
|
||||
dist_spec_args are optional.
|
||||
"""
|
||||
return self._call_cdf(*args, **kwargs)
|
||||
|
||||
|
@ -266,7 +270,7 @@ class Distribution(Cell):
|
|||
|
||||
Note:
|
||||
Args must include value.
|
||||
Dist_spec_args are optional.
|
||||
dist_spec_args are optional.
|
||||
"""
|
||||
return self._call_log_cdf(*args, **kwargs)
|
||||
|
||||
|
@ -285,7 +289,7 @@ class Distribution(Cell):
|
|||
|
||||
Note:
|
||||
Args must include value.
|
||||
Dist_spec_args are optional.
|
||||
dist_spec_args are optional.
|
||||
"""
|
||||
return self._call_survival(*args, **kwargs)
|
||||
|
||||
|
@ -313,7 +317,7 @@ class Distribution(Cell):
|
|||
|
||||
Note:
|
||||
Args must include value.
|
||||
Dist_spec_args are optional.
|
||||
dist_spec_args are optional.
|
||||
"""
|
||||
return self._call_log_survival(*args, **kwargs)
|
||||
|
||||
|
@ -341,7 +345,7 @@ class Distribution(Cell):
|
|||
Evaluate the mean.
|
||||
|
||||
Note:
|
||||
Dist_spec_args are optional.
|
||||
dist_spec_args are optional.
|
||||
"""
|
||||
return self._mean(*args, **kwargs)
|
||||
|
||||
|
@ -350,7 +354,7 @@ class Distribution(Cell):
|
|||
Evaluate the mode.
|
||||
|
||||
Note:
|
||||
Dist_spec_args are optional.
|
||||
dist_spec_args are optional.
|
||||
"""
|
||||
return self._mode(*args, **kwargs)
|
||||
|
||||
|
@ -359,7 +363,7 @@ class Distribution(Cell):
|
|||
Evaluate the standard deviation.
|
||||
|
||||
Note:
|
||||
Dist_spec_args are optional.
|
||||
dist_spec_args are optional.
|
||||
"""
|
||||
return self._call_sd(*args, **kwargs)
|
||||
|
||||
|
@ -368,7 +372,7 @@ class Distribution(Cell):
|
|||
Evaluate the variance.
|
||||
|
||||
Note:
|
||||
Dist_spec_args are optional.
|
||||
dist_spec_args are optional.
|
||||
"""
|
||||
return self._call_var(*args, **kwargs)
|
||||
|
||||
|
@ -395,7 +399,7 @@ class Distribution(Cell):
|
|||
Evaluate the entropy.
|
||||
|
||||
Note:
|
||||
Dist_spec_args are optional.
|
||||
dist_spec_args are optional.
|
||||
"""
|
||||
return self._entropy(*args, **kwargs)
|
||||
|
||||
|
@ -424,7 +428,7 @@ class Distribution(Cell):
|
|||
|
||||
Note:
|
||||
The shape of the sample defaults to ().
|
||||
Dist_spec_args are optional.
|
||||
dist_spec_args are optional.
|
||||
"""
|
||||
return self._sample(*args, **kwargs)
|
||||
|
||||
|
|
|
@ -199,7 +199,7 @@ class Exponential(Distribution):
|
|||
pdf(x) = rate * \exp(-1 * \lambda * x) if x >= 0 else 0
|
||||
"""
|
||||
rate = self.rate if rate is None else rate
|
||||
prob = rate * self.exp(-1. * rate * value)
|
||||
prob = self.exp(self.log(rate) - rate * value)
|
||||
zeros = self.fill(self.dtypeop(prob), self.shape(prob), 0.0)
|
||||
comp = self.less(value, zeros)
|
||||
return self.select(comp, zeros, prob)
|
||||
|
|
|
@ -113,6 +113,7 @@ class Geometric(Distribution):
|
|||
self.cast = P.Cast()
|
||||
self.const = P.ScalarToArray()
|
||||
self.dtypeop = P.DType()
|
||||
self.exp = P.Exp()
|
||||
self.fill = P.Fill()
|
||||
self.floor = P.Floor()
|
||||
self.issubclass = P.IsSubClass()
|
||||
|
@ -205,7 +206,7 @@ class Geometric(Distribution):
|
|||
value = self.floor(value)
|
||||
else:
|
||||
return None
|
||||
pmf = self.pow((1.0 - probs1), value) * probs1
|
||||
pmf = self.exp(self.log(1.0 - probs1) * value + self.log(probs1))
|
||||
zeros = self.fill(self.dtypeop(probs1), self.shape(pmf), 0.0)
|
||||
comp = self.less(value, zeros)
|
||||
return self.select(comp, zeros, pmf)
|
||||
|
|
|
@ -18,7 +18,7 @@ from mindspore.ops import operations as P
|
|||
from mindspore.ops import composite as C
|
||||
from mindspore.common import dtype as mstype
|
||||
from .distribution import Distribution
|
||||
from ._utils.utils import convert_to_batch, check_greater_equal_zero, check_type
|
||||
from ._utils.utils import convert_to_batch, check_greater_zero, check_type
|
||||
|
||||
|
||||
class Normal(Distribution):
|
||||
|
@ -106,7 +106,7 @@ class Normal(Distribution):
|
|||
if mean is not None and sd is not None:
|
||||
self._mean_value = convert_to_batch(mean, self.broadcast_shape, dtype)
|
||||
self._sd_value = convert_to_batch(sd, self.broadcast_shape, dtype)
|
||||
check_greater_equal_zero(self._sd_value, "Standard deviation")
|
||||
check_greater_zero(self._sd_value, "Standard deviation")
|
||||
else:
|
||||
self._mean_value = mean
|
||||
self._sd_value = sd
|
||||
|
@ -166,7 +166,7 @@ class Normal(Distribution):
|
|||
H(X) = \log(\sqrt{2 \pi e \sigma^{2}})
|
||||
"""
|
||||
sd = self._sd_value if sd is None else sd
|
||||
return self.log(self.sqrt(np.e * 2. * np.pi * self.sq(sd)))
|
||||
return self.log(self.sqrt(self.const(np.e * 2. * np.pi))) + self.log(sd)
|
||||
|
||||
def _cross_entropy(self, dist, mean_b, sd_b, mean_a=None, sd_a=None):
|
||||
r"""
|
||||
|
@ -198,7 +198,7 @@ class Normal(Distribution):
|
|||
mean = self._mean_value if mean is None else mean
|
||||
sd = self._sd_value if sd is None else sd
|
||||
unnormalized_log_prob = -1. * (self.sq(value - mean)) / (2. * self.sq(sd))
|
||||
neg_normalization = -1. * self.log(self.sqrt(2. * np.pi * self.sq(sd)))
|
||||
neg_normalization = -1. * self.log(self.sqrt(self.const(2. * np.pi))) - self.log(sd)
|
||||
return unnormalized_log_prob + neg_normalization
|
||||
|
||||
def _cdf(self, value, mean=None, sd=None):
|
||||
|
|
|
@ -216,8 +216,8 @@ class Uniform(Distribution):
|
|||
"""
|
||||
low = self.low if low is None else low
|
||||
high = self.high if high is None else high
|
||||
ones = self.fill(self.dtype, self.shape(value), 1.0)
|
||||
prob = ones / (high - low)
|
||||
neg_ones = self.fill(self.dtype, self.shape(value), -1.0)
|
||||
prob = self.exp(neg_ones * self.log(high - low))
|
||||
broadcast_shape = self.shape(prob)
|
||||
zeros = self.fill(self.dtypeop(prob), broadcast_shape, 0.0)
|
||||
comp_lo = self.less(value, low)
|
||||
|
|
|
@ -28,7 +28,7 @@ def test_arguments():
|
|||
"""
|
||||
b = msd.Bernoulli()
|
||||
assert isinstance(b, msd.Distribution)
|
||||
b = msd.Bernoulli([0.0, 0.3, 0.5, 1.0], dtype=dtype.int32)
|
||||
b = msd.Bernoulli([0.1, 0.3, 0.5, 0.9], dtype=dtype.int32)
|
||||
assert isinstance(b, msd.Distribution)
|
||||
|
||||
def test_type():
|
||||
|
@ -51,6 +51,10 @@ def test_prob():
|
|||
msd.Bernoulli([-0.1], dtype=dtype.int32)
|
||||
with pytest.raises(ValueError):
|
||||
msd.Bernoulli([1.1], dtype=dtype.int32)
|
||||
with pytest.raises(ValueError):
|
||||
msd.Bernoulli([0.0], dtype=dtype.int32)
|
||||
with pytest.raises(ValueError):
|
||||
msd.Bernoulli([1.0], dtype=dtype.int32)
|
||||
|
||||
class BernoulliProb(nn.Cell):
|
||||
"""
|
||||
|
|
|
@ -29,7 +29,7 @@ def test_arguments():
|
|||
"""
|
||||
g = msd.Geometric()
|
||||
assert isinstance(g, msd.Distribution)
|
||||
g = msd.Geometric([0.0, 0.3, 0.5, 1.0], dtype=dtype.int32)
|
||||
g = msd.Geometric([0.1, 0.3, 0.5, 0.9], dtype=dtype.int32)
|
||||
assert isinstance(g, msd.Distribution)
|
||||
|
||||
def test_type():
|
||||
|
@ -52,6 +52,10 @@ def test_prob():
|
|||
msd.Geometric([-0.1], dtype=dtype.int32)
|
||||
with pytest.raises(ValueError):
|
||||
msd.Geometric([1.1], dtype=dtype.int32)
|
||||
with pytest.raises(ValueError):
|
||||
msd.Geometric([0.0], dtype=dtype.int32)
|
||||
with pytest.raises(ValueError):
|
||||
msd.Geometric([1.0], dtype=dtype.int32)
|
||||
|
||||
class GeometricProb(nn.Cell):
|
||||
"""
|
||||
|
|
|
@ -42,6 +42,12 @@ def test_seed():
|
|||
with pytest.raises(TypeError):
|
||||
msd.Normal(0., 1., seed='seed')
|
||||
|
||||
def test_sd():
|
||||
with pytest.raises(ValueError):
|
||||
msd.Normal(0., 0.)
|
||||
with pytest.raises(ValueError):
|
||||
msd.Normal(0., -1.)
|
||||
|
||||
def test_arguments():
|
||||
"""
|
||||
args passing during initialization.
|
||||
|
|
Loading…
Reference in New Issue