!4369 Added distribution_specific_args checking to some functions and fixed minor issues in docs

Merge pull request !4369 from XunDeng/pp_poc_v3
This commit is contained in:
mindspore-ci-bot 2020-08-14 09:33:23 +08:00 committed by Gitee
commit 93f60269c3
12 changed files with 63 additions and 40 deletions

View File

@ -26,7 +26,7 @@ class Bijector(Cell):
is_constant_jacobian (bool): if the bijector has constant derivative. Default: False.
is_injective (bool): if the bijector is an one-to-one mapping. Default: True.
name (str): name of the bijector. Default: None.
dtype (mstype): type of the distribution the bijector can operate on. Default: None.
dtype (mindspore.dtype): type of the distribution the bijector can operate on. Default: None.
param (dict): parameters used to initialize the bijector. Default: None.
"""
def __init__(self,
@ -110,7 +110,7 @@ class Bijector(Cell):
*args: args[0] shall be either a distribution or the name of a bijector function.
"""
if isinstance(args[0], Distribution):
return TransformedDistribution(self, args[0])
return TransformedDistribution(self, args[0], self.distribution.dtype)
return super(Bijector, self).__call__(*args, **kwargs)
def construct(self, name, *args, **kwargs):

View File

@ -22,7 +22,10 @@ from .bijector import Bijector
class Softplus(Bijector):
r"""
Softplus Bijector.
This Bijector performs the operation: Y = \frac{\log(1 + e ^ {kX})}{k}, where k is the sharpness factor.
This Bijector performs the operation, where k is the sharpness factor.
.. math::
Y = \frac{\log(1 + e ^ {kX})}{k}
Args:
sharpness (float): scale factor. Default: 1.0.

View File

@ -184,7 +184,7 @@ def check_greater(a, b, name_a, name_b):
def check_prob(p):
"""
Check if p is a proper probability, i.e. 0 <= p <=1.
Check if p is a proper probability, i.e. 0 < p <1.
Args:
p (Tensor, Parameter): value to be checked.
@ -196,12 +196,12 @@ def check_prob(p):
if not isinstance(p.default_input, Tensor):
return
p = p.default_input
comp = np.less(p.asnumpy(), np.zeros(p.shape))
if comp.any():
raise ValueError('Probabilities should be greater than or equal to zero')
comp = np.greater(p.asnumpy(), np.ones(p.shape))
if comp.any():
raise ValueError('Probabilities should be less than or equal to one')
comp = np.less(np.zeros(p.shape), p.asnumpy())
if not comp.all():
raise ValueError('Probabilities should be greater than zero')
comp = np.greater(np.ones(p.shape), p.asnumpy())
if not comp.all():
raise ValueError('Probabilities should be less than one')
def logits_to_probs(logits, is_binary=False):

View File

@ -110,6 +110,7 @@ class Bernoulli(Distribution):
self.const = P.ScalarToArray()
self.dtypeop = P.DType()
self.erf = P.Erf()
self.exp = P.Exp()
self.fill = P.Fill()
self.log = P.Log()
self.less = P.Less()
@ -159,7 +160,7 @@ class Bernoulli(Distribution):
"""
probs1 = self.probs if probs1 is None else probs1
probs0 = 1.0 - probs1
return probs0 * probs1
return self.exp(self.log(probs0) + self.log(probs1))
def _entropy(self, probs=None):
r"""
@ -183,7 +184,7 @@ class Bernoulli(Distribution):
return self._entropy(probs=probs1_a) + self._kl_loss(dist, probs1_b, probs1_a)
return None
def _prob(self, value, probs=None):
def _log_prob(self, value, probs=None):
r"""
pmf of Bernoulli distribution.
@ -197,7 +198,7 @@ class Bernoulli(Distribution):
"""
probs1 = self.probs if probs is None else probs
probs0 = 1.0 - probs1
return (probs1 * value) + (probs0 * (1.0 - value))
return self.log(probs1) * value + self.log(probs0) * (1.0 - value)
def _cdf(self, value, probs=None):
r"""

View File

@ -15,6 +15,7 @@
"""basic"""
from mindspore.nn.cell import Cell
from mindspore._checkparam import Validator as validator
from mindspore._checkparam import Rel
from ._utils.utils import calc_broadcast_shape_from_param, check_scalar_from_param
class Distribution(Cell):
@ -28,12 +29,15 @@ class Distribution(Cell):
Note:
Derived class should override operations such as ,_mean, _prob,
and _log_prob. Arguments should be passed in through *args or **kwargs.
and _log_prob. Required arguments, such as value for _prob,
should be passed in through args or kwargs. dist_spec_args which specify
a new distribution are optional.
Dist_spec_args are unique for each type of distribution. For example, mean and sd
are the dist_spec_args for a Normal distribution.
dist_spec_args are unique for each type of distribution. For example, mean and sd
are the dist_spec_args for a Normal distribution, while rate is the dist_spec_args
for exponential distribution.
For all functions, passing in dist_spec_args, are optional.
For all functions, passing in dist_spec_args, is optional.
Passing in the additional dist_spec_args will make the result to be evaluated with
new distribution specified by the dist_spec_args. But it won't change the
original distribuion.
@ -49,7 +53,7 @@ class Distribution(Cell):
"""
super(Distribution, self).__init__()
validator.check_value_type('name', name, [str], 'distribution_name')
validator.check_value_type('seed', seed, [int], name)
validator.check_integer('seed', seed, 0, Rel.GE, name)
self._name = name
self._seed = seed
@ -191,7 +195,7 @@ class Distribution(Cell):
Note:
Args must include value.
Dist_spec_args are optional.
dist_spec_args are optional.
"""
return self._call_log_prob(*args, **kwargs)
@ -210,7 +214,7 @@ class Distribution(Cell):
Note:
Args must include value.
Dist_spec_args are optional.
dist_spec_args are optional.
"""
return self._call_prob(*args, **kwargs)
@ -229,7 +233,7 @@ class Distribution(Cell):
Note:
Args must include value.
Dist_spec_args are optional.
dist_spec_args are optional.
"""
return self._call_cdf(*args, **kwargs)
@ -266,7 +270,7 @@ class Distribution(Cell):
Note:
Args must include value.
Dist_spec_args are optional.
dist_spec_args are optional.
"""
return self._call_log_cdf(*args, **kwargs)
@ -285,7 +289,7 @@ class Distribution(Cell):
Note:
Args must include value.
Dist_spec_args are optional.
dist_spec_args are optional.
"""
return self._call_survival(*args, **kwargs)
@ -313,7 +317,7 @@ class Distribution(Cell):
Note:
Args must include value.
Dist_spec_args are optional.
dist_spec_args are optional.
"""
return self._call_log_survival(*args, **kwargs)
@ -341,7 +345,7 @@ class Distribution(Cell):
Evaluate the mean.
Note:
Dist_spec_args are optional.
dist_spec_args are optional.
"""
return self._mean(*args, **kwargs)
@ -350,7 +354,7 @@ class Distribution(Cell):
Evaluate the mode.
Note:
Dist_spec_args are optional.
dist_spec_args are optional.
"""
return self._mode(*args, **kwargs)
@ -359,7 +363,7 @@ class Distribution(Cell):
Evaluate the standard deviation.
Note:
Dist_spec_args are optional.
dist_spec_args are optional.
"""
return self._call_sd(*args, **kwargs)
@ -368,7 +372,7 @@ class Distribution(Cell):
Evaluate the variance.
Note:
Dist_spec_args are optional.
dist_spec_args are optional.
"""
return self._call_var(*args, **kwargs)
@ -395,7 +399,7 @@ class Distribution(Cell):
Evaluate the entropy.
Note:
Dist_spec_args are optional.
dist_spec_args are optional.
"""
return self._entropy(*args, **kwargs)
@ -424,7 +428,7 @@ class Distribution(Cell):
Note:
Shape of the sample is default to ().
Dist_spec_args are optional.
dist_spec_args are optional.
"""
return self._sample(*args, **kwargs)

View File

@ -199,7 +199,7 @@ class Exponential(Distribution):
pdf(x) = rate * \exp(-1 * \lambda * x) if x >= 0 else 0
"""
rate = self.rate if rate is None else rate
prob = rate * self.exp(-1. * rate * value)
prob = self.exp(self.log(rate) - rate * value)
zeros = self.fill(self.dtypeop(prob), self.shape(prob), 0.0)
comp = self.less(value, zeros)
return self.select(comp, zeros, prob)

View File

@ -113,6 +113,7 @@ class Geometric(Distribution):
self.cast = P.Cast()
self.const = P.ScalarToArray()
self.dtypeop = P.DType()
self.exp = P.Exp()
self.fill = P.Fill()
self.floor = P.Floor()
self.issubclass = P.IsSubClass()
@ -205,7 +206,7 @@ class Geometric(Distribution):
value = self.floor(value)
else:
return None
pmf = self.pow((1.0 - probs1), value) * probs1
pmf = self.exp(self.log(1.0 - probs1) * value + self.log(probs1))
zeros = self.fill(self.dtypeop(probs1), self.shape(pmf), 0.0)
comp = self.less(value, zeros)
return self.select(comp, zeros, pmf)

View File

@ -18,7 +18,7 @@ from mindspore.ops import operations as P
from mindspore.ops import composite as C
from mindspore.common import dtype as mstype
from .distribution import Distribution
from ._utils.utils import convert_to_batch, check_greater_equal_zero, check_type
from ._utils.utils import convert_to_batch, check_greater_zero, check_type
class Normal(Distribution):
@ -106,7 +106,7 @@ class Normal(Distribution):
if mean is not None and sd is not None:
self._mean_value = convert_to_batch(mean, self.broadcast_shape, dtype)
self._sd_value = convert_to_batch(sd, self.broadcast_shape, dtype)
check_greater_equal_zero(self._sd_value, "Standard deviation")
check_greater_zero(self._sd_value, "Standard deviation")
else:
self._mean_value = mean
self._sd_value = sd
@ -166,7 +166,7 @@ class Normal(Distribution):
H(X) = \log(\sqrt(numpy.e * 2. * numpy.pi * \sq(\sigma)))
"""
sd = self._sd_value if sd is None else sd
return self.log(self.sqrt(np.e * 2. * np.pi * self.sq(sd)))
return self.log(self.sqrt(self.const(np.e * 2. * np.pi))) + self.log(sd)
def _cross_entropy(self, dist, mean_b, sd_b, mean_a=None, sd_a=None):
r"""
@ -198,7 +198,7 @@ class Normal(Distribution):
mean = self._mean_value if mean is None else mean
sd = self._sd_value if sd is None else sd
unnormalized_log_prob = -1. * (self.sq(value - mean)) / (2. * self.sq(sd))
neg_normalization = -1. * self.log(self.sqrt(2. * np.pi * self.sq(sd)))
neg_normalization = -1. * self.log(self.sqrt(self.const(2. * np.pi))) - self.log(sd)
return unnormalized_log_prob + neg_normalization
def _cdf(self, value, mean=None, sd=None):

View File

@ -216,8 +216,8 @@ class Uniform(Distribution):
"""
low = self.low if low is None else low
high = self.high if high is None else high
ones = self.fill(self.dtype, self.shape(value), 1.0)
prob = ones / (high - low)
neg_ones = self.fill(self.dtype, self.shape(value), -1.0)
prob = self.exp(neg_ones * self.log(high - low))
broadcast_shape = self.shape(prob)
zeros = self.fill(self.dtypeop(prob), broadcast_shape, 0.0)
comp_lo = self.less(value, low)

View File

@ -28,7 +28,7 @@ def test_arguments():
"""
b = msd.Bernoulli()
assert isinstance(b, msd.Distribution)
b = msd.Bernoulli([0.0, 0.3, 0.5, 1.0], dtype=dtype.int32)
b = msd.Bernoulli([0.1, 0.3, 0.5, 0.9], dtype=dtype.int32)
assert isinstance(b, msd.Distribution)
def test_type():
@ -51,6 +51,10 @@ def test_prob():
msd.Bernoulli([-0.1], dtype=dtype.int32)
with pytest.raises(ValueError):
msd.Bernoulli([1.1], dtype=dtype.int32)
with pytest.raises(ValueError):
msd.Bernoulli([0.0], dtype=dtype.int32)
with pytest.raises(ValueError):
msd.Bernoulli([1.0], dtype=dtype.int32)
class BernoulliProb(nn.Cell):
"""

View File

@ -29,7 +29,7 @@ def test_arguments():
"""
g = msd.Geometric()
assert isinstance(g, msd.Distribution)
g = msd.Geometric([0.0, 0.3, 0.5, 1.0], dtype=dtype.int32)
g = msd.Geometric([0.1, 0.3, 0.5, 0.9], dtype=dtype.int32)
assert isinstance(g, msd.Distribution)
def test_type():
@ -52,6 +52,10 @@ def test_prob():
msd.Geometric([-0.1], dtype=dtype.int32)
with pytest.raises(ValueError):
msd.Geometric([1.1], dtype=dtype.int32)
with pytest.raises(ValueError):
msd.Geometric([0.0], dtype=dtype.int32)
with pytest.raises(ValueError):
msd.Geometric([1.0], dtype=dtype.int32)
class GeometricProb(nn.Cell):
"""

View File

@ -42,6 +42,12 @@ def test_seed():
with pytest.raises(TypeError):
msd.Normal(0., 1., seed='seed')
def test_sd():
with pytest.raises(ValueError):
msd.Normal(0., 0.)
with pytest.raises(ValueError):
msd.Normal(0., -1.)
def test_arguments():
"""
args passing during initialization.