fix minor issues in distribution classes
This commit is contained in:
parent
a15356c779
commit
4ef9e9e57f
|
@ -36,6 +36,7 @@ class GumbelCDF(Bijector):
|
|||
``Ascend`` ``GPU``
|
||||
|
||||
Note:
|
||||
`scale` must be greater than zero.
|
||||
For `inverse` and `inverse_log_jacobian`, input should be in range of (0, 1).
|
||||
The dtype of `loc` and `scale` must be float.
|
||||
If `loc`, `scale` are passed in as numpy.ndarray or tensor, they have to have
|
||||
|
|
|
@ -29,16 +29,16 @@ class Beta(Distribution):
|
|||
Beta distribution.
|
||||
|
||||
Args:
|
||||
concentration1 (int, float, list, numpy.ndarray, Tensor, Parameter): The concentration1,
|
||||
concentration1 (list, numpy.ndarray, Tensor, Parameter): The concentration1,
|
||||
also known as alpha of the Beta distribution.
|
||||
concentration0 (int, float, list, numpy.ndarray, Tensor, Parameter): The concentration0, also known as
|
||||
concentration0 (list, numpy.ndarray, Tensor, Parameter): The concentration0, also known as
|
||||
beta of the Beta distribution.
|
||||
seed (int): The seed used in sampling. The global seed is used if it is None. Default: None.
|
||||
dtype (mindspore.dtype): The type of the event samples. Default: mstype.float32.
|
||||
name (str): The name of the distribution. Default: 'Beta'.
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend`` ``GPU``
|
||||
``Ascend``
|
||||
|
||||
Note:
|
||||
`concentration1` and `concentration0` must be greater than zero.
|
||||
|
@ -148,8 +148,16 @@ class Beta(Distribution):
|
|||
"""
|
||||
param = dict(locals())
|
||||
param['param_dict'] = {'concentration1': concentration1, 'concentration0': concentration0}
|
||||
|
||||
valid_dtype = mstype.float_type
|
||||
Validator.check_type_name("dtype", dtype, valid_dtype, type(self).__name__)
|
||||
|
||||
# As some operators can't accept scalar input, check the type here
|
||||
if isinstance(concentration0, float):
|
||||
raise TypeError("Parameter concentration0 can't be scalar")
|
||||
if isinstance(concentration1, float):
|
||||
raise TypeError("Parameter concentration1 can't be scalar")
|
||||
|
||||
super(Beta, self).__init__(seed, dtype, name, param)
|
||||
|
||||
self._concentration1 = self._add_parameter(concentration1, 'concentration1')
|
||||
|
@ -251,7 +259,7 @@ class Beta(Distribution):
|
|||
- (concentration0 - 1.) * self.digamma(concentration0) \
|
||||
+ (total_concentration - 2.) * self.digamma(total_concentration)
|
||||
|
||||
def _cross_entropy(self, dist, concentration1_b, concentration0_b, concentration1=None, concentration0=None):
|
||||
def _cross_entropy(self, dist, concentration1_b, concentration0_b, concentration1_a=None, concentration0_a=None):
|
||||
r"""
|
||||
Evaluate cross entropy between Beta distributions.
|
||||
|
||||
|
@ -263,8 +271,8 @@ class Beta(Distribution):
|
|||
concentration0_a (Tensor): concentration0 of distribution a. Default: self._concentration0.
|
||||
"""
|
||||
check_distribution_name(dist, 'Beta')
|
||||
return self._entropy(concentration1, concentration0) \
|
||||
+ self._kl_loss(dist, concentration1_b, concentration0_b, concentration1, concentration0)
|
||||
return self._entropy(concentration1_a, concentration0_a) \
|
||||
+ self._kl_loss(dist, concentration1_b, concentration0_b, concentration1_a, concentration0_a)
|
||||
|
||||
def _log_prob(self, value, concentration1=None, concentration0=None):
|
||||
r"""
|
||||
|
@ -285,7 +293,7 @@ class Beta(Distribution):
|
|||
+ (concentration0 - 1.) * self.log1p(self.neg(value))
|
||||
return log_unnormalized_prob - self.lbeta(concentration1, concentration0)
|
||||
|
||||
def _kl_loss(self, dist, concentration1_b, concentration0_b, concentration1=None, concentration0=None):
|
||||
def _kl_loss(self, dist, concentration1_b, concentration0_b, concentration1_a=None, concentration0_a=None):
|
||||
r"""
|
||||
Evaluate Beta-Beta KL divergence, i.e. KL(a||b).
|
||||
|
||||
|
@ -307,7 +315,7 @@ class Beta(Distribution):
|
|||
concentration0_b = self._check_value(concentration0_b, 'concentration0_b')
|
||||
concentration1_b = self.cast(concentration1_b, self.parameter_type)
|
||||
concentration0_b = self.cast(concentration0_b, self.parameter_type)
|
||||
concentration1_a, concentration0_a = self._check_param_type(concentration1, concentration0)
|
||||
concentration1_a, concentration0_a = self._check_param_type(concentration1_a, concentration0_a)
|
||||
total_concentration_a = concentration1_a + concentration0_a
|
||||
total_concentration_b = concentration1_b + concentration0_b
|
||||
log_normalization_a = self.lbeta(concentration1_a, concentration0_a)
|
||||
|
|
|
@ -29,16 +29,16 @@ class Gamma(Distribution):
|
|||
Gamma distribution.
|
||||
|
||||
Args:
|
||||
concentration (int, float, list, numpy.ndarray, Tensor, Parameter): The concentration,
|
||||
concentration (list, numpy.ndarray, Tensor, Parameter): The concentration,
|
||||
also known as alpha of the Gamma distribution.
|
||||
rate (int, float, list, numpy.ndarray, Tensor, Parameter): The rate, also known as
|
||||
rate (list, numpy.ndarray, Tensor, Parameter): The rate, also known as
|
||||
beta of the Gamma distribution.
|
||||
seed (int): The seed used in sampling. The global seed is used if it is None. Default: None.
|
||||
dtype (mindspore.dtype): The type of the event samples. Default: mstype.float32.
|
||||
name (str): The name of the distribution. Default: 'Gamma'.
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend`` ``GPU``
|
||||
``Ascend``
|
||||
|
||||
Note:
|
||||
`concentration` and `rate` must be greater than zero.
|
||||
|
@ -147,6 +147,13 @@ class Gamma(Distribution):
|
|||
param['param_dict'] = {'concentration': concentration, 'rate': rate}
|
||||
valid_dtype = mstype.float_type
|
||||
Validator.check_type_name("dtype", dtype, valid_dtype, type(self).__name__)
|
||||
|
||||
# As some operators can't accept scalar input, check the type here
|
||||
if isinstance(concentration, (int, float)):
|
||||
raise TypeError("Parameter concentration can't be scalar")
|
||||
if isinstance(rate, (int, float)):
|
||||
raise TypeError("Parameter rate can't be scalar")
|
||||
|
||||
super(Gamma, self).__init__(seed, dtype, name, param)
|
||||
|
||||
self._concentration = self._add_parameter(concentration, 'concentration')
|
||||
|
@ -248,7 +255,7 @@ class Gamma(Distribution):
|
|||
return concentration - self.log(rate) + self.lgamma(concentration) \
|
||||
+ (1. - concentration) * self.digamma(concentration)
|
||||
|
||||
def _cross_entropy(self, dist, concentration_b, rate_b, concentration=None, rate=None):
|
||||
def _cross_entropy(self, dist, concentration_b, rate_b, concentration_a=None, rate_a=None):
|
||||
r"""
|
||||
Evaluate cross entropy between Gamma distributions.
|
||||
|
||||
|
@ -260,7 +267,8 @@ class Gamma(Distribution):
|
|||
rate_a (Tensor): rate of distribution a. Default: self._rate.
|
||||
"""
|
||||
check_distribution_name(dist, 'Gamma')
|
||||
return self._entropy(concentration, rate) + self._kl_loss(dist, concentration_b, rate_b, concentration, rate)
|
||||
return self._entropy(concentration_a, rate_a) +\
|
||||
self._kl_loss(dist, concentration_b, rate_b, concentration_a, rate_a)
|
||||
|
||||
def _log_prob(self, value, concentration=None, rate=None):
|
||||
r"""
|
||||
|
@ -299,7 +307,7 @@ class Gamma(Distribution):
|
|||
concentration, rate = self._check_param_type(concentration, rate)
|
||||
return self.igamma(concentration, rate * value)
|
||||
|
||||
def _kl_loss(self, dist, concentration_b, rate_b, concentration=None, rate=None):
|
||||
def _kl_loss(self, dist, concentration_b, rate_b, concentration_a=None, rate_a=None):
|
||||
r"""
|
||||
Evaluate Gamma-Gamma KL divergence, i.e. KL(a||b).
|
||||
|
||||
|
@ -320,7 +328,7 @@ class Gamma(Distribution):
|
|||
rate_b = self._check_value(rate_b, 'rate_b')
|
||||
concentration_b = self.cast(concentration_b, self.parameter_type)
|
||||
rate_b = self.cast(rate_b, self.parameter_type)
|
||||
concentration_a, rate_a = self._check_param_type(concentration, rate)
|
||||
concentration_a, rate_a = self._check_param_type(concentration_a, rate_a)
|
||||
return (concentration_a - concentration_b) * self.digamma(concentration_a) \
|
||||
+ self.lgamma(concentration_b) - self.lgamma(concentration_a) \
|
||||
+ concentration_b * self.log(rate_a) - concentration_b * self.log(rate_b) \
|
||||
|
|
|
@ -29,13 +29,13 @@ class Poisson(Distribution):
|
|||
Poisson Distribution.
|
||||
|
||||
Args:
|
||||
rate (float, list, numpy.ndarray, Tensor, Parameter): The rate of the Poisson distribution.
|
||||
rate (list, numpy.ndarray, Tensor, Parameter): The rate of the Poisson distribution.
|
||||
seed (int): The seed used in sampling. The global seed is used if it is None. Default: None.
|
||||
dtype (mindspore.dtype): The type of the event samples. Default: mstype.float32.
|
||||
name (str): The name of the distribution. Default: 'Poisson'.
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend`` ``GPU``
|
||||
``Ascend``
|
||||
|
||||
Note:
|
||||
`rate` must be strictly greater than 0.
|
||||
|
@ -47,7 +47,7 @@ class Poisson(Distribution):
|
|||
>>> import mindspore.nn.probability.distribution as msd
|
||||
>>> from mindspore import Tensor
|
||||
>>> # To initialize a Poisson distribution of the rate 0.5.
|
||||
>>> p1 = msd.Poisson(0.5, dtype=mindspore.float32)
|
||||
>>> p1 = msd.Poisson([0.5], dtype=mindspore.float32)
|
||||
>>> # A Poisson distribution can be initialized without arguments.
|
||||
>>> # In this case, `rate` must be passed in through `args` during function calls.
|
||||
>>> p2 = msd.Poisson(dtype=mindspore.float32)
|
||||
|
@ -79,7 +79,7 @@ class Poisson(Distribution):
|
|||
>>> # Functions `mean`, `mode`, `sd`, and 'var' have the same arguments as follows.
|
||||
>>> # Args:
|
||||
>>> # rate (Tensor): the rate of the distribution. Default: self.rate.
|
||||
>>> # Examples of `mean`, `sd`, `mode`, `var`, and `entropy` are similar.
|
||||
>>> # Examples of `mean`, `sd`, `mode`, and `var` are similar.
|
||||
>>> ans = p1.mean() # return 0.5
|
||||
>>> print(ans)
|
||||
0.5
|
||||
|
@ -96,10 +96,10 @@ class Poisson(Distribution):
|
|||
>>> # probs1 (Tensor): the rate of the distribution. Default: self.rate.
|
||||
>>> ans = p1.sample()
|
||||
>>> print(ans.shape)
|
||||
()
|
||||
(1,)
|
||||
>>> ans = p1.sample((2,3))
|
||||
>>> print(ans.shape)
|
||||
(2, 3)
|
||||
(2, 3, 1)
|
||||
>>> ans = p1.sample((2,3), rate_b)
|
||||
>>> print(ans.shape)
|
||||
(2, 3, 3)
|
||||
|
@ -120,6 +120,11 @@ class Poisson(Distribution):
|
|||
param['param_dict'] = {'rate': rate}
|
||||
valid_dtype = mstype.int_type + mstype.uint_type + mstype.float_type
|
||||
Validator.check_type_name("dtype", dtype, valid_dtype, type(self).__name__)
|
||||
|
||||
# As some operators can't accept scalar input, check the type here
|
||||
if isinstance(rate, (int, float)):
|
||||
raise TypeError("Parameter rate can't be scalar")
|
||||
|
||||
super(Poisson, self).__init__(seed, dtype, name, param)
|
||||
|
||||
self._rate = self._add_parameter(rate, 'rate')
|
||||
|
|
|
@ -52,7 +52,7 @@ class LogProb(nn.Cell):
|
|||
"""
|
||||
def __init__(self):
|
||||
super(LogProb, self).__init__()
|
||||
self.p = msd.Poisson(0.5, dtype=dtype.float32)
|
||||
self.p = msd.Poisson([0.5], dtype=dtype.float32)
|
||||
|
||||
def construct(self, x_):
|
||||
return self.p.log_prob(x_)
|
||||
|
@ -169,7 +169,7 @@ class SF(nn.Cell):
|
|||
"""
|
||||
def __init__(self):
|
||||
super(SF, self).__init__()
|
||||
self.p = msd.Poisson(0.5, dtype=dtype.float32)
|
||||
self.p = msd.Poisson([0.5], dtype=dtype.float32)
|
||||
|
||||
def construct(self, x_):
|
||||
return self.p.survival_function(x_)
|
||||
|
@ -192,7 +192,7 @@ class LogSF(nn.Cell):
|
|||
"""
|
||||
def __init__(self):
|
||||
super(LogSF, self).__init__()
|
||||
self.p = msd.Poisson(0.5, dtype=dtype.float32)
|
||||
self.p = msd.Poisson([0.5], dtype=dtype.float32)
|
||||
|
||||
def construct(self, x_):
|
||||
return self.p.log_survival(x_)
|
||||
|
|
|
@ -32,27 +32,33 @@ def test_gamma_shape_errpr():
|
|||
|
||||
def test_type():
|
||||
with pytest.raises(TypeError):
|
||||
msd.Gamma(0., 1., dtype=dtype.int32)
|
||||
msd.Gamma([0.], [1.], dtype=dtype.int32)
|
||||
|
||||
def test_name():
|
||||
with pytest.raises(TypeError):
|
||||
msd.Gamma(0., 1., name=1.0)
|
||||
msd.Gamma([0.], [1.], name=1.0)
|
||||
|
||||
def test_seed():
|
||||
with pytest.raises(TypeError):
|
||||
msd.Gamma(0., 1., seed='seed')
|
||||
msd.Gamma([0.], [1.], seed='seed')
|
||||
|
||||
def test_concentration1():
|
||||
with pytest.raises(ValueError):
|
||||
msd.Gamma(0., 1.)
|
||||
msd.Gamma([0.], [1.])
|
||||
with pytest.raises(ValueError):
|
||||
msd.Gamma(-1., 1.)
|
||||
msd.Gamma([-1.], [1.])
|
||||
|
||||
def test_concentration0():
|
||||
with pytest.raises(ValueError):
|
||||
msd.Gamma(1., 0.)
|
||||
msd.Gamma([1.], [0.])
|
||||
with pytest.raises(ValueError):
|
||||
msd.Gamma(1., -1.)
|
||||
msd.Gamma([1.], [-1.])
|
||||
|
||||
def test_scalar():
|
||||
with pytest.raises(TypeError):
|
||||
msd.Gamma(3., [4.])
|
||||
with pytest.raises(TypeError):
|
||||
msd.Gamma([3.], -4.)
|
||||
|
||||
def test_arguments():
|
||||
"""
|
||||
|
|
|
@ -32,21 +32,27 @@ def test_gamma_shape_errpr():
|
|||
|
||||
def test_type():
|
||||
with pytest.raises(TypeError):
|
||||
msd.Gamma(0., 1., dtype=dtype.int32)
|
||||
msd.Gamma([0.], [1.], dtype=dtype.int32)
|
||||
|
||||
def test_name():
|
||||
with pytest.raises(TypeError):
|
||||
msd.Gamma(0., 1., name=1.0)
|
||||
msd.Gamma([0.], [1.], name=1.0)
|
||||
|
||||
def test_seed():
|
||||
with pytest.raises(TypeError):
|
||||
msd.Gamma(0., 1., seed='seed')
|
||||
msd.Gamma([0.], [1.], seed='seed')
|
||||
|
||||
def test_rate():
|
||||
with pytest.raises(ValueError):
|
||||
msd.Gamma(0., 0.)
|
||||
msd.Gamma([0.], [0.])
|
||||
with pytest.raises(ValueError):
|
||||
msd.Gamma(0., -1.)
|
||||
msd.Gamma([0.], [-1.])
|
||||
|
||||
def test_scalar():
|
||||
with pytest.raises(TypeError):
|
||||
msd.Gamma(3., [4.])
|
||||
with pytest.raises(TypeError):
|
||||
msd.Gamma([3.], -4.)
|
||||
|
||||
def test_arguments():
|
||||
"""
|
||||
|
|
|
@ -53,6 +53,10 @@ def test_rate():
|
|||
with pytest.raises(ValueError):
|
||||
msd.Poisson([0.0], dtype=dtype.float32)
|
||||
|
||||
def test_scalar():
|
||||
with pytest.raises(TypeError):
|
||||
msd.Poisson(0.1, seed='seed')
|
||||
|
||||
class PoissonProb(nn.Cell):
|
||||
"""
|
||||
Poisson distribution: initialize with rate.
|
||||
|
|
Loading…
Reference in New Issue