changed distribution api
This commit is contained in:
parent
6945eb2821
commit
e87e1fc6bc
|
@ -34,55 +34,56 @@ class Bernoulli(Distribution):
|
|||
|
||||
Examples:
|
||||
>>> # To initialize a Bernoulli distribution of prob 0.5
|
||||
>>> n = nn.Bernoulli(0.5, dtype=mstype.int32)
|
||||
>>> import mindspore.nn.probability.distribution as msd
|
||||
>>> b = msd.Bernoulli(0.5, dtype=mstype.int32)
|
||||
>>>
|
||||
>>> # The following creates two independent Bernoulli distributions
|
||||
>>> n = nn.Bernoulli([0.5, 0.5], dtype=mstype.int32)
|
||||
>>> b = msd.Bernoulli([0.5, 0.5], dtype=mstype.int32)
|
||||
>>>
|
||||
>>> # A Bernoulli distribution can be initilized without arguments
|
||||
>>> # In this case, probs must be passed in through construct.
|
||||
>>> n = nn.Bernoulli(dtype=mstype.int32)
|
||||
>>> # In this case, probs must be passed in through args during function calls.
|
||||
>>> b = msd.Bernoulli(dtype=mstype.int32)
|
||||
>>>
|
||||
>>> # To use Bernoulli distribution in a network
|
||||
>>> # To use Bernoulli in a network
|
||||
>>> class net(Cell):
|
||||
>>> def __init__(self):
|
||||
>>> super(net, self).__init__():
|
||||
>>> self.b1 = nn.Bernoulli(0.5, dtype=mstype.int32)
|
||||
>>> self.b2 = nn.Bernoulli(dtype=mstype.int32)
|
||||
>>> self.b1 = msd.Bernoulli(0.5, dtype=mstype.int32)
|
||||
>>> self.b2 = msd.Bernoulli(dtype=mstype.int32)
|
||||
>>>
|
||||
>>> # All the following calls in construct are valid
|
||||
>>> def construct(self, value, probs_b, probs_a):
|
||||
>>>
|
||||
>>> # Similar calls can be made to other probability functions
|
||||
>>> # by replacing 'prob' with the name of the function
|
||||
>>> ans = self.b1('prob', value)
|
||||
>>> ans = self.b1.prob(value)
|
||||
>>> # Evaluate with the respect to distribution b
|
||||
>>> ans = self.b1('prob', value, probs_b)
|
||||
>>> ans = self.b1.prob(value, probs_b)
|
||||
>>>
|
||||
>>> # probs must be passed in through construct
|
||||
>>> ans = self.b2('prob', value, probs_a)
|
||||
>>> # probs must be passed in during function calls
|
||||
>>> ans = self.b2.prob(value, probs_a)
|
||||
>>>
|
||||
>>> # Functions 'sd', 'var', 'entropy' have the same usage like 'mean'
|
||||
>>> # Will return [0.0]
|
||||
>>> ans = self.b1('mean')
|
||||
>>> # Will return mean_b
|
||||
>>> ans = self.b1('mean', probs_b)
|
||||
>>> # Functions 'sd', 'var', 'entropy' have the same usage as 'mean'
|
||||
>>> # Will return 0.5
|
||||
>>> ans = self.b1.mean()
|
||||
>>> # Will return probs_b
|
||||
>>> ans = self.b1.mean(probs_b)
|
||||
>>>
|
||||
>>> # probs must be passed in through construct
|
||||
>>> ans = self.b2('mean', probs_a)
|
||||
>>> # probs must be passed in during function calls
|
||||
>>> ans = self.b2.mean(probs_a)
|
||||
>>>
|
||||
>>> # Usage of 'kl_loss' and 'cross_entropy' are similar
|
||||
>>> ans = self.b1('kl_loss', 'Bernoulli', probs_b)
|
||||
>>> ans = self.b1('kl_loss', 'Bernoulli', probs_b, probs_a)
|
||||
>>> ans = self.b1.kl_loss('Bernoulli', probs_b)
|
||||
>>> ans = self.b1.kl_loss('Bernoulli', probs_b, probs_a)
|
||||
>>>
|
||||
>>> # Additional probs_a must be passed in through construct
|
||||
>>> ans = self.b2('kl_loss', 'Bernoulli', probs_b, probs_a)
|
||||
>>> # Additional probs_a must be passed in through
|
||||
>>> ans = self.b2.kl_loss('Bernoulli', probs_b, probs_a)
|
||||
>>>
|
||||
>>> # Sample Usage
|
||||
>>> ans = self.b1('sample')
|
||||
>>> ans = self.b1('sample', (2,3))
|
||||
>>> ans = self.b1('sample', (2,3), probs_b)
|
||||
>>> ans = self.b2('sample', (2,3), probs_a)
|
||||
>>> # Sample
|
||||
>>> ans = self.b1.sample()
|
||||
>>> ans = self.b1.sample((2,3))
|
||||
>>> ans = self.b1.sample((2,3), probs_b)
|
||||
>>> ans = self.b2.sample((2,3), probs_a)
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
|
@ -130,71 +131,61 @@ class Bernoulli(Distribution):
|
|||
"""
|
||||
return self._probs
|
||||
|
||||
def _mean(self, name='mean', probs1=None):
|
||||
def _mean(self, probs1=None):
|
||||
r"""
|
||||
.. math::
|
||||
MEAN(B) = probs1
|
||||
"""
|
||||
if name == 'mean':
|
||||
return self.probs if probs1 is None else probs1
|
||||
return None
|
||||
return self.probs if probs1 is None else probs1
|
||||
|
||||
def _mode(self, name='mode', probs1=None):
|
||||
def _mode(self, probs1=None):
|
||||
r"""
|
||||
.. math::
|
||||
MODE(B) = 1 if probs1 > 0.5 else = 0
|
||||
"""
|
||||
if name == 'mode':
|
||||
probs1 = self.probs if probs1 is None else probs1
|
||||
prob_type = self.dtypeop(probs1)
|
||||
zeros = self.fill(prob_type, self.shape(probs1), 0.0)
|
||||
ones = self.fill(prob_type, self.shape(probs1), 1.0)
|
||||
comp = self.less(0.5, probs1)
|
||||
return self.select(comp, ones, zeros)
|
||||
return None
|
||||
probs1 = self.probs if probs1 is None else probs1
|
||||
prob_type = self.dtypeop(probs1)
|
||||
zeros = self.fill(prob_type, self.shape(probs1), 0.0)
|
||||
ones = self.fill(prob_type, self.shape(probs1), 1.0)
|
||||
comp = self.less(0.5, probs1)
|
||||
return self.select(comp, ones, zeros)
|
||||
|
||||
def _var(self, name='var', probs1=None):
|
||||
def _var(self, probs1=None):
|
||||
r"""
|
||||
.. math::
|
||||
VAR(B) = probs1 * probs0
|
||||
"""
|
||||
if name in self._variance_functions:
|
||||
probs1 = self.probs if probs1 is None else probs1
|
||||
probs0 = 1.0 - probs1
|
||||
return probs0 * probs1
|
||||
return None
|
||||
probs1 = self.probs if probs1 is None else probs1
|
||||
probs0 = 1.0 - probs1
|
||||
return probs0 * probs1
|
||||
|
||||
def _entropy(self, name='entropy', probs=None):
|
||||
def _entropy(self, probs=None):
|
||||
r"""
|
||||
.. math::
|
||||
H(B) = -probs0 * \log(probs0) - probs1 * \log(probs1)
|
||||
"""
|
||||
if name == 'entropy':
|
||||
probs1 = self.probs if probs is None else probs
|
||||
probs0 = 1 - probs1
|
||||
return -1 * (probs0 * self.log(probs0)) - (probs1 * self.log(probs1))
|
||||
return None
|
||||
probs1 = self.probs if probs is None else probs
|
||||
probs0 = 1 - probs1
|
||||
return -1 * (probs0 * self.log(probs0)) - (probs1 * self.log(probs1))
|
||||
|
||||
def _cross_entropy(self, name, dist, probs1_b, probs1_a=None):
|
||||
def _cross_entropy(self, dist, probs1_b, probs1_a=None):
|
||||
"""
|
||||
Evaluate cross_entropy between Bernoulli distributions.
|
||||
|
||||
Args:
|
||||
name (str): name of the funtion.
|
||||
dist (str): type of the distributions. Should be "Bernoulli" in this case.
|
||||
probs1_b (Tensor): probs1 of distribution b.
|
||||
probs1_a (Tensor): probs1 of distribution a. Default: self.probs.
|
||||
"""
|
||||
if name == 'cross_entropy' and dist == 'Bernoulli':
|
||||
return self._entropy(probs=probs1_a) + self._kl_loss(name, dist, probs1_b, probs1_a)
|
||||
if dist == 'Bernoulli':
|
||||
return self._entropy(probs=probs1_a) + self._kl_loss(dist, probs1_b, probs1_a)
|
||||
return None
|
||||
|
||||
def _prob(self, name, value, probs=None):
|
||||
def _prob(self, value, probs=None):
|
||||
r"""
|
||||
pmf of Bernoulli distribution.
|
||||
|
||||
Args:
|
||||
name (str): name of the function. Should be "prob" when passed in from construct.
|
||||
value (Tensor): a Tensor composed of only zeros and ones.
|
||||
probs (Tensor): probability of outcome is 1. Default: self.probs.
|
||||
|
||||
|
@ -202,18 +193,15 @@ class Bernoulli(Distribution):
|
|||
pmf(k) = probs1 if k = 1;
|
||||
pmf(k) = probs0 if k = 0;
|
||||
"""
|
||||
if name in self._prob_functions:
|
||||
probs1 = self.probs if probs is None else probs
|
||||
probs0 = 1.0 - probs1
|
||||
return (probs1 * value) + (probs0 * (1.0 - value))
|
||||
return None
|
||||
probs1 = self.probs if probs is None else probs
|
||||
probs0 = 1.0 - probs1
|
||||
return (probs1 * value) + (probs0 * (1.0 - value))
|
||||
|
||||
def _cdf(self, name, value, probs=None):
|
||||
def _cdf(self, value, probs=None):
|
||||
r"""
|
||||
cdf of Bernoulli distribution.
|
||||
|
||||
Args:
|
||||
name (str): name of the function.
|
||||
value (Tensor): value to be evaluated.
|
||||
probs (Tensor): probability of outcome is 1. Default: self.probs.
|
||||
|
||||
|
@ -222,25 +210,22 @@ class Bernoulli(Distribution):
|
|||
cdf(k) = probs0 if 0 <= k <1;
|
||||
cdf(k) = 1 if k >=1;
|
||||
"""
|
||||
if name in self._cdf_survival_functions:
|
||||
probs1 = self.probs if probs is None else probs
|
||||
prob_type = self.dtypeop(probs1)
|
||||
value = value * self.fill(prob_type, self.shape(probs1), 1.0)
|
||||
probs0 = 1.0 - probs1 * self.fill(prob_type, self.shape(value), 1.0)
|
||||
comp_zero = self.less(value, 0.0)
|
||||
comp_one = self.less(value, 1.0)
|
||||
zeros = self.fill(prob_type, self.shape(value), 0.0)
|
||||
ones = self.fill(prob_type, self.shape(value), 1.0)
|
||||
less_than_zero = self.select(comp_zero, zeros, probs0)
|
||||
return self.select(comp_one, less_than_zero, ones)
|
||||
return None
|
||||
probs1 = self.probs if probs is None else probs
|
||||
prob_type = self.dtypeop(probs1)
|
||||
value = value * self.fill(prob_type, self.shape(probs1), 1.0)
|
||||
probs0 = 1.0 - probs1 * self.fill(prob_type, self.shape(value), 1.0)
|
||||
comp_zero = self.less(value, 0.0)
|
||||
comp_one = self.less(value, 1.0)
|
||||
zeros = self.fill(prob_type, self.shape(value), 0.0)
|
||||
ones = self.fill(prob_type, self.shape(value), 1.0)
|
||||
less_than_zero = self.select(comp_zero, zeros, probs0)
|
||||
return self.select(comp_one, less_than_zero, ones)
|
||||
|
||||
def _kl_loss(self, name, dist, probs1_b, probs1_a=None):
|
||||
def _kl_loss(self, dist, probs1_b, probs1_a=None):
|
||||
r"""
|
||||
Evaluate bernoulli-bernoulli kl divergence, i.e. KL(a||b).
|
||||
|
||||
Args:
|
||||
name (str): name of the funtion.
|
||||
dist (str): type of the distributions. Should be "Bernoulli" in this case.
|
||||
probs1_b (Tensor): probs1 of distribution b.
|
||||
probs1_a (Tensor): probs1 of distribution a. Default: self.probs.
|
||||
|
@ -249,31 +234,28 @@ class Bernoulli(Distribution):
|
|||
KL(a||b) = probs1_a * \log(\fract{probs1_a}{probs1_b}) +
|
||||
probs0_a * \log(\fract{probs0_a}{probs0_b})
|
||||
"""
|
||||
if name in self._divergence_functions and dist == 'Bernoulli':
|
||||
if dist == 'Bernoulli':
|
||||
probs1_a = self.probs if probs1_a is None else probs1_a
|
||||
probs0_a = 1.0 - probs1_a
|
||||
probs0_b = 1.0 - probs1_b
|
||||
return probs1_a * self.log(probs1_a / probs1_b) + probs0_a * self.log(probs0_a / probs0_b)
|
||||
return None
|
||||
|
||||
def _sample(self, name, shape=(), probs=None):
|
||||
def _sample(self, shape=(), probs=None):
|
||||
"""
|
||||
Sampling.
|
||||
|
||||
Args:
|
||||
name (str): name of the function. Should always be 'sample' when passed in from construct.
|
||||
shape (tuple): shape of the sample. Default: ().
|
||||
probs (Tensor): probs1 of the samples. Default: self.probs.
|
||||
|
||||
Returns:
|
||||
Tensor, shape is shape + batch_shape.
|
||||
"""
|
||||
if name == 'sample':
|
||||
probs1 = self.probs if probs is None else probs
|
||||
l_zero = self.const(0.0)
|
||||
h_one = self.const(1.0)
|
||||
sample_uniform = self.uniform(shape + self.shape(probs1), l_zero, h_one)
|
||||
sample = self.less(sample_uniform, probs1)
|
||||
sample = self.cast(sample, self.dtype)
|
||||
return sample
|
||||
return None
|
||||
probs1 = self.probs if probs is None else probs
|
||||
l_zero = self.const(0.0)
|
||||
h_one = self.const(1.0)
|
||||
sample_uniform = self.uniform(shape + self.shape(probs1), l_zero, h_one)
|
||||
sample = self.less(sample_uniform, probs1)
|
||||
sample = self.cast(sample, self.dtype)
|
||||
return sample
|
||||
|
|
|
@ -27,11 +27,7 @@ class Distribution(Cell):
|
|||
|
||||
Note:
|
||||
Derived class should override operations such as ,_mean, _prob,
|
||||
and _log_prob. Functions should be called through construct when
|
||||
used inside a network. Arguments should be passed in through *args
|
||||
in the form of function name followed by additional arguments.
|
||||
Functions such as cdf and prob, require a value to be passed in while
|
||||
functions such as mean and sd do not require arguments other than name.
|
||||
and _log_prob. Arguments should be passed in through *args.
|
||||
|
||||
Dist_spec_args are unique for each type of distribution. For example, mean and sd
|
||||
are the dist_spec_args for a Normal distribution.
|
||||
|
@ -73,11 +69,6 @@ class Distribution(Cell):
|
|||
self._set_log_survival()
|
||||
self._set_cross_entropy()
|
||||
|
||||
self._prob_functions = ('prob', 'log_prob')
|
||||
self._cdf_survival_functions = ('cdf', 'log_cdf', 'survival_function', 'log_survival')
|
||||
self._variance_functions = ('var', 'sd')
|
||||
self._divergence_functions = ('kl_loss', 'cross_entropy')
|
||||
|
||||
@property
|
||||
def name(self):
|
||||
return self._name
|
||||
|
@ -185,7 +176,7 @@ class Distribution(Cell):
|
|||
Evaluate the log probability(pdf or pmf) at the given value.
|
||||
|
||||
Note:
|
||||
Args must include name of the function and value.
|
||||
Args must include value.
|
||||
Dist_spec_args are optional.
|
||||
"""
|
||||
return self._call_log_prob(*args)
|
||||
|
@ -204,7 +195,7 @@ class Distribution(Cell):
|
|||
Evaluate the probability (pdf or pmf) at given value.
|
||||
|
||||
Note:
|
||||
Args must include name of the function and value.
|
||||
Args must include value.
|
||||
Dist_spec_args are optional.
|
||||
"""
|
||||
return self._call_prob(*args)
|
||||
|
@ -223,7 +214,7 @@ class Distribution(Cell):
|
|||
Evaluate the cdf at given value.
|
||||
|
||||
Note:
|
||||
Args must include name of the function and value.
|
||||
Args must include value.
|
||||
Dist_spec_args are optional.
|
||||
"""
|
||||
return self._call_cdf(*args)
|
||||
|
@ -260,7 +251,7 @@ class Distribution(Cell):
|
|||
Evaluate the log cdf at given value.
|
||||
|
||||
Note:
|
||||
Args must include name of the function and value.
|
||||
Args must include value.
|
||||
Dist_spec_args are optional.
|
||||
"""
|
||||
return self._call_log_cdf(*args)
|
||||
|
@ -279,7 +270,7 @@ class Distribution(Cell):
|
|||
Evaluate the survival function at given value.
|
||||
|
||||
Note:
|
||||
Args must include name of the function and value.
|
||||
Args must include value.
|
||||
Dist_spec_args are optional.
|
||||
"""
|
||||
return self._call_survival(*args)
|
||||
|
@ -307,7 +298,7 @@ class Distribution(Cell):
|
|||
Evaluate the log survival function at given value.
|
||||
|
||||
Note:
|
||||
Args must include name of the function and value.
|
||||
Args must include value.
|
||||
Dist_spec_args are optional.
|
||||
"""
|
||||
return self._call_log_survival(*args)
|
||||
|
@ -326,7 +317,7 @@ class Distribution(Cell):
|
|||
Evaluate the KL divergence, i.e. KL(a||b).
|
||||
|
||||
Note:
|
||||
Args must include name of the function, type of the distribution, parameters of distribution b.
|
||||
Args must include type of the distribution, parameters of distribution b.
|
||||
Parameters for distribution a are optional.
|
||||
"""
|
||||
return self._kl_loss(*args)
|
||||
|
@ -336,7 +327,7 @@ class Distribution(Cell):
|
|||
Evaluate the mean.
|
||||
|
||||
Note:
|
||||
Args must include the name of function. Dist_spec_args are optional.
|
||||
Dist_spec_args are optional.
|
||||
"""
|
||||
return self._mean(*args)
|
||||
|
||||
|
@ -345,7 +336,7 @@ class Distribution(Cell):
|
|||
Evaluate the mode.
|
||||
|
||||
Note:
|
||||
Args must include the name of function. Dist_spec_args are optional.
|
||||
Dist_spec_args are optional.
|
||||
"""
|
||||
return self._mode(*args)
|
||||
|
||||
|
@ -354,7 +345,7 @@ class Distribution(Cell):
|
|||
Evaluate the standard deviation.
|
||||
|
||||
Note:
|
||||
Args must include the name of function. Dist_spec_args are optional.
|
||||
Dist_spec_args are optional.
|
||||
"""
|
||||
return self._call_sd(*args)
|
||||
|
||||
|
@ -363,7 +354,7 @@ class Distribution(Cell):
|
|||
Evaluate the variance.
|
||||
|
||||
Note:
|
||||
Args must include the name of function. Dist_spec_args are optional.
|
||||
Dist_spec_args are optional.
|
||||
"""
|
||||
return self._call_var(*args)
|
||||
|
||||
|
@ -390,7 +381,7 @@ class Distribution(Cell):
|
|||
Evaluate the entropy.
|
||||
|
||||
Note:
|
||||
Args must include the name of function. Dist_spec_args are optional.
|
||||
Dist_spec_args are optional.
|
||||
"""
|
||||
return self._entropy(*args)
|
||||
|
||||
|
@ -399,7 +390,7 @@ class Distribution(Cell):
|
|||
Evaluate the cross_entropy between distribution a and b.
|
||||
|
||||
Note:
|
||||
Args must include name of the function, type of the distribution, parameters of distribution b.
|
||||
Args must include type of the distribution, parameters of distribution b.
|
||||
Parameters for distribution a are optional.
|
||||
"""
|
||||
return self._call_cross_entropy(*args)
|
||||
|
@ -421,13 +412,13 @@ class Distribution(Cell):
|
|||
*args (list): arguments passed in through construct.
|
||||
|
||||
Note:
|
||||
Args must include name of the function.
|
||||
Shape of the sample and dist_spec_args are optional.
|
||||
Shape of the sample is default to ().
|
||||
Dist_spec_args are optional.
|
||||
"""
|
||||
return self._sample(*args)
|
||||
|
||||
|
||||
def construct(self, *inputs):
|
||||
def construct(self, name, *args):
|
||||
"""
|
||||
Override construct in Cell.
|
||||
|
||||
|
@ -437,35 +428,36 @@ class Distribution(Cell):
|
|||
'var', 'sd', 'entropy', 'kl_loss', 'cross_entropy', 'sample'.
|
||||
|
||||
Args:
|
||||
*inputs (list): inputs[0] is always the name of the function.
|
||||
name (str): name of the function.
|
||||
*args (list): list of arguments needed for the function.
|
||||
"""
|
||||
|
||||
if inputs[0] == 'log_prob':
|
||||
return self._call_log_prob(*inputs)
|
||||
if inputs[0] == 'prob':
|
||||
return self._call_prob(*inputs)
|
||||
if inputs[0] == 'cdf':
|
||||
return self._call_cdf(*inputs)
|
||||
if inputs[0] == 'log_cdf':
|
||||
return self._call_log_cdf(*inputs)
|
||||
if inputs[0] == 'survival_function':
|
||||
return self._call_survival(*inputs)
|
||||
if inputs[0] == 'log_survival':
|
||||
return self._call_log_survival(*inputs)
|
||||
if inputs[0] == 'kl_loss':
|
||||
return self._kl_loss(*inputs)
|
||||
if inputs[0] == 'mean':
|
||||
return self._mean(*inputs)
|
||||
if inputs[0] == 'mode':
|
||||
return self._mode(*inputs)
|
||||
if inputs[0] == 'sd':
|
||||
return self._call_sd(*inputs)
|
||||
if inputs[0] == 'var':
|
||||
return self._call_var(*inputs)
|
||||
if inputs[0] == 'entropy':
|
||||
return self._entropy(*inputs)
|
||||
if inputs[0] == 'cross_entropy':
|
||||
return self._call_cross_entropy(*inputs)
|
||||
if inputs[0] == 'sample':
|
||||
return self._sample(*inputs)
|
||||
if name == 'log_prob':
|
||||
return self._call_log_prob(*args)
|
||||
if name == 'prob':
|
||||
return self._call_prob(*args)
|
||||
if name == 'cdf':
|
||||
return self._call_cdf(*args)
|
||||
if name == 'log_cdf':
|
||||
return self._call_log_cdf(*args)
|
||||
if name == 'survival_function':
|
||||
return self._call_survival(*args)
|
||||
if name == 'log_survival':
|
||||
return self._call_log_survival(*args)
|
||||
if name == 'kl_loss':
|
||||
return self._kl_loss(*args)
|
||||
if name == 'mean':
|
||||
return self._mean(*args)
|
||||
if name == 'mode':
|
||||
return self._mode(*args)
|
||||
if name == 'sd':
|
||||
return self._call_sd(*args)
|
||||
if name == 'var':
|
||||
return self._call_var(*args)
|
||||
if name == 'entropy':
|
||||
return self._entropy(*args)
|
||||
if name == 'cross_entropy':
|
||||
return self._call_cross_entropy(*args)
|
||||
if name == 'sample':
|
||||
return self._sample(*args)
|
||||
return None
|
||||
|
|
|
@ -35,55 +35,56 @@ class Exponential(Distribution):
|
|||
|
||||
Examples:
|
||||
>>> # To initialize an Exponential distribution of rate 0.5
|
||||
>>> n = nn.Exponential(0.5, dtype=mstype.float32)
|
||||
>>> import mindspore.nn.probability.distribution as msd
|
||||
>>> e = msd.Exponential(0.5, dtype=mstype.float32)
|
||||
>>>
|
||||
>>> # The following creates two independent Exponential distributions
|
||||
>>> n = nn.Exponential([0.5, 0.5], dtype=mstype.float32)
|
||||
>>> e = msd.Exponential([0.5, 0.5], dtype=mstype.float32)
|
||||
>>>
|
||||
>>> # A Exponential distribution can be initilized without arguments
|
||||
>>> # In this case, rate must be passed in through construct.
|
||||
>>> n = nn.Exponential(dtype=mstype.float32)
|
||||
>>> # An Exponential distribution can be initilized without arguments
|
||||
>>> # In this case, rate must be passed in through args during function calls
|
||||
>>> e = msd.Exponential(dtype=mstype.float32)
|
||||
>>>
|
||||
>>> # To use Exponential distribution in a network
|
||||
>>> # To use Exponential in a network
|
||||
>>> class net(Cell):
|
||||
>>> def __init__(self):
|
||||
>>> super(net, self).__init__():
|
||||
>>> self.e1 = nn.Exponential(0.5, dtype=mstype.float32)
|
||||
>>> self.e2 = nn.Exponential(dtype=mstype.float32)
|
||||
>>> self.e1 = msd.Exponential(0.5, dtype=mstype.float32)
|
||||
>>> self.e2 = msd.Exponential(dtype=mstype.float32)
|
||||
>>>
|
||||
>>> # All the following calls in construct are valid
|
||||
>>> def construct(self, value, rate_b, rate_a):
|
||||
>>>
|
||||
>>> # Similar calls can be made to other probability functions
|
||||
>>> # by replacing 'prob' with the name of the function
|
||||
>>> ans = self.e1('prob', value)
|
||||
>>> ans = self.e1.prob(value)
|
||||
>>> # Evaluate with the respect to distribution b
|
||||
>>> ans = self.e1('prob', value, rate_b)
|
||||
>>> ans = self.e1.prob(value, rate_b)
|
||||
>>>
|
||||
>>> # Rate must be passed in through construct
|
||||
>>> ans = self.e2('prob', value, rate_a)
|
||||
>>> # Rate must be passed in during function calls
|
||||
>>> ans = self.e2.prob(value, rate_a)
|
||||
>>>
|
||||
>>> # Functions 'sd', 'var', 'entropy' have the same usage with 'mean'
|
||||
>>> # Will return [0.0]
|
||||
>>> ans = self.e1('mean')
|
||||
>>> # Will return mean_b
|
||||
>>> ans = self.e1('mean', rate_b)
|
||||
>>> # Functions 'sd', 'var', 'entropy' have the same usage as'mean'
|
||||
>>> # Will return 2
|
||||
>>> ans = self.e1.mean()
|
||||
>>> # Will return 1 / rate_b
|
||||
>>> ans = self.e1.mean(rate_b)
|
||||
>>>
|
||||
>>> # Rate must be passed in through construct
|
||||
>>> ans = self.e2('mean', rate_a)
|
||||
>>> # Rate must be passed in during function calls
|
||||
>>> ans = self.e2.mean(rate_a)
|
||||
>>>
|
||||
>>> # Usage of 'kl_loss' and 'cross_entropy' are similar
|
||||
>>> ans = self.e1('kl_loss', 'Exponential', rate_b)
|
||||
>>> ans = self.e1('kl_loss', 'Exponential', rate_b, rate_a)
|
||||
>>> ans = self.e1.kl_loss('Exponential', rate_b)
|
||||
>>> ans = self.e1.kl_loss('Exponential', rate_b, rate_a)
|
||||
>>>
|
||||
>>> # Additional rate must be passed in through construct
|
||||
>>> ans = self.e2('kl_loss', 'Exponential', rate_b, rate_a)
|
||||
>>> # Additional rate must be passed in
|
||||
>>> ans = self.e2.kl_loss('Exponential', rate_b, rate_a)
|
||||
>>>
|
||||
>>> # Sample Usage
|
||||
>>> ans = self.e1('sample')
|
||||
>>> ans = self.e1('sample', (2,3))
|
||||
>>> ans = self.e1('sample', (2,3), rate_b)
|
||||
>>> ans = self.e2('sample', (2,3), rate_a)
|
||||
>>> # Sample
|
||||
>>> ans = self.e1.sample()
|
||||
>>> ans = self.e1.sample((2,3))
|
||||
>>> ans = self.e1.sample((2,3), rate_b)
|
||||
>>> ans = self.e2.sample((2,3), rate_a)
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
|
@ -131,67 +132,59 @@ class Exponential(Distribution):
|
|||
"""
|
||||
return self._rate
|
||||
|
||||
def _mean(self, name='mean', rate=None):
|
||||
def _mean(self, rate=None):
|
||||
r"""
|
||||
.. math::
|
||||
MEAN(EXP) = \fract{1.0}{\lambda}.
|
||||
"""
|
||||
if name == 'mean':
|
||||
rate = self.rate if rate is None else rate
|
||||
return 1.0 / rate
|
||||
return None
|
||||
rate = self.rate if rate is None else rate
|
||||
return 1.0 / rate
|
||||
|
||||
def _mode(self, name='mode', rate=None):
|
||||
|
||||
def _mode(self, rate=None):
|
||||
r"""
|
||||
.. math::
|
||||
MODE(EXP) = 0.
|
||||
"""
|
||||
if name == 'mode':
|
||||
rate = self.rate if rate is None else rate
|
||||
return self.fill(self.dtype, self.shape(rate), 0.)
|
||||
return None
|
||||
rate = self.rate if rate is None else rate
|
||||
return self.fill(self.dtype, self.shape(rate), 0.)
|
||||
|
||||
def _sd(self, name='sd', rate=None):
|
||||
def _sd(self, rate=None):
|
||||
r"""
|
||||
.. math::
|
||||
sd(EXP) = \fract{1.0}{\lambda}.
|
||||
"""
|
||||
if name in self._variance_functions:
|
||||
rate = self.rate if rate is None else rate
|
||||
return 1.0 / rate
|
||||
return None
|
||||
rate = self.rate if rate is None else rate
|
||||
return 1.0 / rate
|
||||
|
||||
def _entropy(self, name='entropy', rate=None):
|
||||
def _entropy(self, rate=None):
|
||||
r"""
|
||||
.. math::
|
||||
H(Exp) = 1 - \log(\lambda).
|
||||
"""
|
||||
rate = self.rate if rate is None else rate
|
||||
if name == 'entropy':
|
||||
return 1.0 - self.log(rate)
|
||||
return None
|
||||
return 1.0 - self.log(rate)
|
||||
|
||||
def _cross_entropy(self, name, dist, rate_b, rate_a=None):
|
||||
|
||||
def _cross_entropy(self, dist, rate_b, rate_a=None):
|
||||
"""
|
||||
Evaluate cross_entropy between Exponential distributions.
|
||||
|
||||
Args:
|
||||
name (str): name of the funtion. Should always be "cross_entropy" when passed in from construct.
|
||||
dist (str): type of the distributions. Should be "Exponential" in this case.
|
||||
rate_b (Tensor): rate of distribution b.
|
||||
rate_a (Tensor): rate of distribution a. Default: self.rate.
|
||||
"""
|
||||
if name == 'cross_entropy' and dist == 'Exponential':
|
||||
return self._entropy(rate=rate_a) + self._kl_loss(name, dist, rate_b, rate_a)
|
||||
if dist == 'Exponential':
|
||||
return self._entropy(rate=rate_a) + self._kl_loss(dist, rate_b, rate_a)
|
||||
return None
|
||||
|
||||
def _prob(self, name, value, rate=None):
|
||||
def _prob(self, value, rate=None):
|
||||
r"""
|
||||
pdf of Exponential distribution.
|
||||
|
||||
Args:
|
||||
Args:
|
||||
name (str): name of the function.
|
||||
value (Tensor): value to be evaluated.
|
||||
rate (Tensor): rate of the distribution. Default: self.rate.
|
||||
|
||||
|
@ -201,20 +194,17 @@ class Exponential(Distribution):
|
|||
.. math::
|
||||
pdf(x) = rate * \exp(-1 * \lambda * x) if x >= 0 else 0
|
||||
"""
|
||||
if name in self._prob_functions:
|
||||
rate = self.rate if rate is None else rate
|
||||
prob = rate * self.exp(-1. * rate * value)
|
||||
zeros = self.fill(self.dtypeop(prob), self.shape(prob), 0.0)
|
||||
comp = self.less(value, zeros)
|
||||
return self.select(comp, zeros, prob)
|
||||
return None
|
||||
rate = self.rate if rate is None else rate
|
||||
prob = rate * self.exp(-1. * rate * value)
|
||||
zeros = self.fill(self.dtypeop(prob), self.shape(prob), 0.0)
|
||||
comp = self.less(value, zeros)
|
||||
return self.select(comp, zeros, prob)
|
||||
|
||||
def _cdf(self, name, value, rate=None):
|
||||
def _cdf(self, value, rate=None):
|
||||
r"""
|
||||
cdf of Exponential distribution.
|
||||
|
||||
Args:
|
||||
name (str): name of the function.
|
||||
value (Tensor): value to be evaluated.
|
||||
rate (Tensor): rate of the distribution. Default: self.rate.
|
||||
|
||||
|
@ -224,45 +214,40 @@ class Exponential(Distribution):
|
|||
.. math::
|
||||
cdf(x) = 1.0 - \exp(-1 * \lambda * x) if x >= 0 else 0
|
||||
"""
|
||||
if name in self._cdf_survival_functions:
|
||||
rate = self.rate if rate is None else rate
|
||||
cdf = 1.0 - self.exp(-1. * rate * value)
|
||||
zeros = self.fill(self.dtypeop(cdf), self.shape(cdf), 0.0)
|
||||
comp = self.less(value, zeros)
|
||||
return self.select(comp, zeros, cdf)
|
||||
return None
|
||||
rate = self.rate if rate is None else rate
|
||||
cdf = 1.0 - self.exp(-1. * rate * value)
|
||||
zeros = self.fill(self.dtypeop(cdf), self.shape(cdf), 0.0)
|
||||
comp = self.less(value, zeros)
|
||||
return self.select(comp, zeros, cdf)
|
||||
|
||||
def _kl_loss(self, name, dist, rate_b, rate_a=None):
|
||||
|
||||
def _kl_loss(self, dist, rate_b, rate_a=None):
|
||||
"""
|
||||
Evaluate exp-exp kl divergence, i.e. KL(a||b).
|
||||
|
||||
Args:
|
||||
name (str): name of the funtion.
|
||||
dist (str): type of the distributions. Should be "Exponential" in this case.
|
||||
rate_b (Tensor): rate of distribution b.
|
||||
rate_a (Tensor): rate of distribution a. Default: self.rate.
|
||||
"""
|
||||
if name in self._divergence_functions and dist == 'Exponential':
|
||||
if dist == 'Exponential':
|
||||
rate_a = self.rate if rate_a is None else rate_a
|
||||
return self.log(rate_a) - self.log(rate_b) + rate_b / rate_a - 1.0
|
||||
return None
|
||||
|
||||
def _sample(self, name, shape=(), rate=None):
|
||||
def _sample(self, shape=(), rate=None):
|
||||
"""
|
||||
Sampling.
|
||||
|
||||
Args:
|
||||
name (str): name of the function.
|
||||
shape (tuple): shape of the sample. Default: ().
|
||||
rate (Tensor): rate of the distribution. Default: self.rate.
|
||||
|
||||
Returns:
|
||||
Tensor, shape is shape + batch_shape.
|
||||
"""
|
||||
if name == 'sample':
|
||||
rate = self.rate if rate is None else rate
|
||||
minval = self.const(self.minval)
|
||||
maxval = self.const(1.0)
|
||||
sample = self.uniform(shape + self.shape(rate), minval, maxval)
|
||||
return -self.log(sample) / rate
|
||||
return None
|
||||
rate = self.rate if rate is None else rate
|
||||
minval = self.const(self.minval)
|
||||
maxval = self.const(1.0)
|
||||
sample = self.uniform(shape + self.shape(rate), minval, maxval)
|
||||
return -self.log(sample) / rate
|
||||
|
|
|
@ -36,55 +36,56 @@ class Geometric(Distribution):
|
|||
|
||||
Examples:
|
||||
>>> # To initialize a Geometric distribution of prob 0.5
|
||||
>>> n = nn.Geometric(0.5, dtype=mstype.int32)
|
||||
>>> import mindspore.nn.probability.distribution as msd
|
||||
>>> n = msd.Geometric(0.5, dtype=mstype.int32)
|
||||
>>>
|
||||
>>> # The following creates two independent Geometric distributions
|
||||
>>> n = nn.Geometric([0.5, 0.5], dtype=mstype.int32)
|
||||
>>> n = msd.Geometric([0.5, 0.5], dtype=mstype.int32)
|
||||
>>>
|
||||
>>> # A Geometric distribution can be initilized without arguments
|
||||
>>> # In this case, probs must be passed in through construct.
|
||||
>>> n = nn.Geometric(dtype=mstype.int32)
|
||||
>>> # In this case, probs must be passed in through args during function calls.
|
||||
>>> n = msd.Geometric(dtype=mstype.int32)
|
||||
>>>
|
||||
>>> # To use Geometric distribution in a network
|
||||
>>> # To use Geometric in a network
|
||||
>>> class net(Cell):
|
||||
>>> def __init__(self):
|
||||
>>> super(net, self).__init__():
|
||||
>>> self.g1 = nn.Geometric(0.5, dtype=mstype.int32)
|
||||
>>> self.g2 = nn.Geometric(dtype=mstype.int32)
|
||||
>>> self.g1 = msd.Geometric(0.5, dtype=mstype.int32)
|
||||
>>> self.g2 = msd.Geometric(dtype=mstype.int32)
|
||||
>>>
|
||||
>>> # Tthe following calls are valid in construct
|
||||
>>> def construct(self, value, probs_b, probs_a):
|
||||
>>>
|
||||
>>> # Similar calls can be made to other probability functions
|
||||
>>> # by replacing 'prob' with the name of the function
|
||||
>>> ans = self.g1('prob', value)
|
||||
>>> ans = self.g1.prob(value)
|
||||
>>> # Evaluate with the respect to distribution b
|
||||
>>> ans = self.g1('prob', value, probs_b)
|
||||
>>> ans = self.g1.prob(value, probs_b)
|
||||
>>>
|
||||
>>> # Probs must be passed in through construct
|
||||
>>> ans = self.g2('prob', value, probs_a)
|
||||
>>> # Probs must be passed in during function calls
|
||||
>>> ans = self.g2.prob(value, probs_a)
|
||||
>>>
|
||||
>>> # Functions 'sd', 'var', 'entropy' have the same usage with 'mean'
|
||||
>>> # Will return [0.0]
|
||||
>>> ans = self.g1('mean')
|
||||
>>> # Will return mean_b
|
||||
>>> ans = self.g1('mean', probs_b)
|
||||
>>> # Functions 'sd', 'var', 'entropy' have the same usage as 'mean'
|
||||
>>> # Will return 1.0
|
||||
>>> ans = self.g1.mean()
|
||||
>>> # Another possible usage
|
||||
>>> ans = self.g1.mean(probs_b)
|
||||
>>>
|
||||
>>> # Probs must be passed in through construct
|
||||
>>> ans = self.g2('mean', probs_a)
|
||||
>>> # Probs must be passed in during function calls
|
||||
>>> ans = self.g2.mean(probs_a)
|
||||
>>>
|
||||
>>> # Usage of 'kl_loss' and 'cross_entropy' are similar
|
||||
>>> ans = self.g1('kl_loss', 'Geometric', probs_b)
|
||||
>>> ans = self.g1('kl_loss', 'Geometric', probs_b, probs_a)
|
||||
>>> ans = self.g1.kl_loss('Geometric', probs_b)
|
||||
>>> ans = self.g1.kl_loss('Geometric', probs_b, probs_a)
|
||||
>>>
|
||||
>>> # Additional probs must be passed in through construct
|
||||
>>> ans = self.g2('kl_loss', 'Geometric', probs_b, probs_a)
|
||||
>>> # Additional probs must be passed in
|
||||
>>> ans = self.g2.kl_loss('Geometric', probs_b, probs_a)
|
||||
>>>
|
||||
>>> # Sample Usage
|
||||
>>> ans = self.g1('sample')
|
||||
>>> ans = self.g1('sample', (2,3))
|
||||
>>> ans = self.g1('sample', (2,3), probs_b)
|
||||
>>> ans = self.g2('sample', (2,3), probs_a)
|
||||
>>> # Sample
|
||||
>>> ans = self.g1.sample()
|
||||
>>> ans = self.g1.sample((2,3))
|
||||
>>> ans = self.g1.sample((2,3), probs_b)
|
||||
>>> ans = self.g2.sample((2,3), probs_a)
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
|
@ -134,67 +135,57 @@ class Geometric(Distribution):
|
|||
"""
|
||||
return self._probs
|
||||
|
||||
def _mean(self, name='mean', probs1=None):
|
||||
def _mean(self, probs1=None):
|
||||
r"""
|
||||
.. math::
|
||||
MEAN(Geo) = \fratc{1 - probs1}{probs1}
|
||||
"""
|
||||
if name == 'mean':
|
||||
probs1 = self.probs if probs1 is None else probs1
|
||||
return (1. - probs1) / probs1
|
||||
return None
|
||||
probs1 = self.probs if probs1 is None else probs1
|
||||
return (1. - probs1) / probs1
|
||||
|
||||
def _mode(self, name='mode', probs1=None):
|
||||
def _mode(self, probs1=None):
|
||||
r"""
|
||||
.. math::
|
||||
MODE(Geo) = 0
|
||||
"""
|
||||
if name == 'mode':
|
||||
probs1 = self.probs if probs1 is None else probs1
|
||||
return self.fill(self.dtypeop(probs1), self.shape(probs1), 0.)
|
||||
return None
|
||||
probs1 = self.probs if probs1 is None else probs1
|
||||
return self.fill(self.dtypeop(probs1), self.shape(probs1), 0.)
|
||||
|
||||
def _var(self, name='var', probs1=None):
|
||||
def _var(self, probs1=None):
|
||||
r"""
|
||||
.. math::
|
||||
VAR(Geo) = \fract{1 - probs1}{probs1 ^ {2}}
|
||||
"""
|
||||
if name in self._variance_functions:
|
||||
probs1 = self.probs if probs1 is None else probs1
|
||||
return (1.0 - probs1) / self.sq(probs1)
|
||||
return None
|
||||
probs1 = self.probs if probs1 is None else probs1
|
||||
return (1.0 - probs1) / self.sq(probs1)
|
||||
|
||||
def _entropy(self, name='entropy', probs=None):
|
||||
def _entropy(self, probs=None):
|
||||
r"""
|
||||
.. math::
|
||||
H(Geo) = \fract{-1 * probs0 \log_2 (1-probs0)\ - prob1 * \log_2 (1-probs1)\ }{probs1}
|
||||
"""
|
||||
if name == 'entropy':
|
||||
probs1 = self.probs if probs is None else probs
|
||||
probs0 = 1.0 - probs1
|
||||
return (-probs0 * self.log(probs0) - probs1 * self.log(probs1)) / probs1
|
||||
return None
|
||||
probs1 = self.probs if probs is None else probs
|
||||
probs0 = 1.0 - probs1
|
||||
return (-probs0 * self.log(probs0) - probs1 * self.log(probs1)) / probs1
|
||||
|
||||
def _cross_entropy(self, name, dist, probs1_b, probs1_a=None):
|
||||
def _cross_entropy(self, dist, probs1_b, probs1_a=None):
|
||||
r"""
|
||||
Evaluate cross_entropy between Geometric distributions.
|
||||
|
||||
Args:
|
||||
name (str): name of the funtion. Should always be "cross_entropy" when passed in from construct.
|
||||
dist (str): type of the distributions. Should be "Geometric" in this case.
|
||||
probs1_b (Tensor): probability of success of distribution b.
|
||||
probs1_a (Tensor): probability of success of distribution a. Default: self.probs.
|
||||
"""
|
||||
if name == 'cross_entropy' and dist == 'Geometric':
|
||||
return self._entropy(probs=probs1_a) + self._kl_loss(name, dist, probs1_b, probs1_a)
|
||||
if dist == 'Geometric':
|
||||
return self._entropy(probs=probs1_a) + self._kl_loss(dist, probs1_b, probs1_a)
|
||||
return None
|
||||
|
||||
def _prob(self, name, value, probs=None):
|
||||
def _prob(self, value, probs=None):
|
||||
r"""
|
||||
pmf of Geometric distribution.
|
||||
|
||||
Args:
|
||||
name (str): name of the function. Should be "prob" when passed in from construct.
|
||||
value (Tensor): a Tensor composed of only natural numbers.
|
||||
probs (Tensor): probability of success. Default: self.probs.
|
||||
|
||||
|
@ -202,27 +193,24 @@ class Geometric(Distribution):
|
|||
pmf(k) = probs0 ^k * probs1 if k >= 0;
|
||||
pmf(k) = 0 if k < 0.
|
||||
"""
|
||||
if name in self._prob_functions:
|
||||
probs1 = self.probs if probs is None else probs
|
||||
dtype = self.dtypeop(value)
|
||||
if self.issubclass(dtype, mstype.int_):
|
||||
pass
|
||||
elif self.issubclass(dtype, mstype.float_):
|
||||
value = self.floor(value)
|
||||
else:
|
||||
return None
|
||||
pmf = self.pow((1.0 - probs1), value) * probs1
|
||||
zeros = self.fill(self.dtypeop(probs1), self.shape(pmf), 0.0)
|
||||
comp = self.less(value, zeros)
|
||||
return self.select(comp, zeros, pmf)
|
||||
return None
|
||||
probs1 = self.probs if probs is None else probs
|
||||
dtype = self.dtypeop(value)
|
||||
if self.issubclass(dtype, mstype.int_):
|
||||
pass
|
||||
elif self.issubclass(dtype, mstype.float_):
|
||||
value = self.floor(value)
|
||||
else:
|
||||
return None
|
||||
pmf = self.pow((1.0 - probs1), value) * probs1
|
||||
zeros = self.fill(self.dtypeop(probs1), self.shape(pmf), 0.0)
|
||||
comp = self.less(value, zeros)
|
||||
return self.select(comp, zeros, pmf)
|
||||
|
||||
def _cdf(self, name, value, probs=None):
|
||||
def _cdf(self, value, probs=None):
|
||||
r"""
|
||||
cdf of Geometric distribution.
|
||||
|
||||
Args:
|
||||
name (str): name of the function.
|
||||
value (Tensor): a Tensor composed of only natural numbers.
|
||||
probs (Tensor): probability of success. Default: self.probs.
|
||||
|
||||
|
@ -231,28 +219,26 @@ class Geometric(Distribution):
|
|||
cdf(k) = 0 if k < 0.
|
||||
|
||||
"""
|
||||
if name in self._cdf_survival_functions:
|
||||
probs1 = self.probs if probs is None else probs
|
||||
probs0 = 1.0 - probs1
|
||||
dtype = self.dtypeop(value)
|
||||
if self.issubclass(dtype, mstype.int_):
|
||||
pass
|
||||
elif self.issubclass(dtype, mstype.float_):
|
||||
value = self.floor(value)
|
||||
else:
|
||||
return None
|
||||
cdf = 1.0 - self.pow(probs0, value + 1.0)
|
||||
zeros = self.fill(self.dtypeop(probs1), self.shape(cdf), 0.0)
|
||||
comp = self.less(value, zeros)
|
||||
return self.select(comp, zeros, cdf)
|
||||
return None
|
||||
probs1 = self.probs if probs is None else probs
|
||||
probs0 = 1.0 - probs1
|
||||
dtype = self.dtypeop(value)
|
||||
if self.issubclass(dtype, mstype.int_):
|
||||
pass
|
||||
elif self.issubclass(dtype, mstype.float_):
|
||||
value = self.floor(value)
|
||||
else:
|
||||
return None
|
||||
cdf = 1.0 - self.pow(probs0, value + 1.0)
|
||||
zeros = self.fill(self.dtypeop(probs1), self.shape(cdf), 0.0)
|
||||
comp = self.less(value, zeros)
|
||||
return self.select(comp, zeros, cdf)
|
||||
|
||||
def _kl_loss(self, name, dist, probs1_b, probs1_a=None):
|
||||
|
||||
def _kl_loss(self, dist, probs1_b, probs1_a=None):
|
||||
r"""
|
||||
Evaluate Geometric-Geometric kl divergence, i.e. KL(a||b).
|
||||
|
||||
Args:
|
||||
name (str): name of the funtion.
|
||||
dist (str): type of the distributions. Should be "Geometric" in this case.
|
||||
probs1_b (Tensor): probability of success of distribution b.
|
||||
probs1_a (Tensor): probability of success of distribution a. Default: self.probs.
|
||||
|
@ -260,29 +246,26 @@ class Geometric(Distribution):
|
|||
.. math::
|
||||
KL(a||b) = \log(\fract{probs1_a}{probs1_b}) + \fract{probs0_a}{probs1_a} * \log(\fract{probs0_a}{probs0_b})
|
||||
"""
|
||||
if name in self._divergence_functions and dist == 'Geometric':
|
||||
if dist == 'Geometric':
|
||||
probs1_a = self.probs if probs1_a is None else probs1_a
|
||||
probs0_a = 1.0 - probs1_a
|
||||
probs0_b = 1.0 - probs1_b
|
||||
return self.log(probs1_a / probs1_b) + (probs0_a / probs1_a) * self.log(probs0_a / probs0_b)
|
||||
return None
|
||||
|
||||
def _sample(self, name, shape=(), probs=None):
|
||||
def _sample(self, shape=(), probs=None):
|
||||
"""
|
||||
Sampling.
|
||||
|
||||
Args:
|
||||
name (str): name of the function. Should always be 'sample' when passed in from construct.
|
||||
shape (tuple): shape of the sample. Default: ().
|
||||
probs (Tensor): probability of success. Default: self.probs.
|
||||
|
||||
Returns:
|
||||
Tensor, shape is shape + batch_shape.
|
||||
"""
|
||||
if name == 'sample':
|
||||
probs = self.probs if probs is None else probs
|
||||
minval = self.const(self.minval)
|
||||
maxval = self.const(1.0)
|
||||
sample_uniform = self.uniform(shape + self.shape(probs), minval, maxval)
|
||||
return self.floor(self.log(sample_uniform) / self.log(1.0 - probs))
|
||||
return None
|
||||
probs = self.probs if probs is None else probs
|
||||
minval = self.const(self.minval)
|
||||
maxval = self.const(1.0)
|
||||
sample_uniform = self.uniform(shape + self.shape(probs), minval, maxval)
|
||||
return self.floor(self.log(sample_uniform) / self.log(1.0 - probs))
|
||||
|
|
|
@ -17,7 +17,6 @@ import numpy as np
|
|||
from mindspore.ops import operations as P
|
||||
from mindspore.ops import composite as C
|
||||
from mindspore.common import dtype as mstype
|
||||
from mindspore.context import get_context
|
||||
from .distribution import Distribution
|
||||
from ._utils.utils import convert_to_batch, check_greater_equal_zero
|
||||
|
||||
|
@ -39,55 +38,56 @@ class Normal(Distribution):
|
|||
|
||||
Examples:
|
||||
>>> # To initialize a Normal distribution of mean 3.0 and standard deviation 4.0
|
||||
>>> n = nn.Normal(3.0, 4.0, dtype=mstype.float32)
|
||||
>>> import mindspore.nn.probability.distribution as msd
|
||||
>>> n = msd.Normal(3.0, 4.0, dtype=mstype.float32)
|
||||
>>>
|
||||
>>> # The following creates two independent Normal distributions
|
||||
>>> n = nn.Normal([3.0, 3.0], [4.0, 4.0], dtype=mstype.float32)
|
||||
>>> n = msd.Normal([3.0, 3.0], [4.0, 4.0], dtype=mstype.float32)
|
||||
>>>
|
||||
>>> # A normal distribution can be initilize without arguments
|
||||
>>> # In this case, mean and sd must be passed in through construct.
|
||||
>>> n = nn.Normal(dtype=mstype.float32)
|
||||
>>> # A Normal distribution can be initilize without arguments
|
||||
>>> # In this case, mean and sd must be passed in through args.
|
||||
>>> n = msd.Normal(dtype=mstype.float32)
|
||||
>>>
|
||||
>>> # To use normal in a network
|
||||
>>> # To use Normal in a network
|
||||
>>> class net(Cell):
|
||||
>>> def __init__(self):
|
||||
>>> super(net, self).__init__():
|
||||
>>> self.n1 = nn.Normal(0.0, 1.0, dtype=mstype.float32)
|
||||
>>> self.n2 = nn.Normal(dtype=mstype.float32)
|
||||
>>> self.n1 = msd.Nomral(0.0, 1.0, dtype=mstype.float32)
|
||||
>>> self.n2 = msd.Normal(dtype=mstype.float32)
|
||||
>>>
|
||||
>>> # The following calls are valid in construct
|
||||
>>> def construct(self, value, mean_b, sd_b, mean_a, sd_a):
|
||||
>>>
|
||||
>>> # Similar calls can be made to other probability functions
|
||||
>>> # by replacing 'prob' with the name of the function
|
||||
>>> ans = self.n1('prob', value)
|
||||
>>> ans = self.n1.prob(value)
|
||||
>>> # Evaluate with the respect to distribution b
|
||||
>>> ans = self.n1('prob', value, mean_b, sd_b)
|
||||
>>> ans = self.n1.prob(value, mean_b, sd_b)
|
||||
>>>
|
||||
>>> # mean and sd must be passed in through construct
|
||||
>>> ans = self.n2('prob', value, mean_a, sd_a)
|
||||
>>> # mean and sd must be passed in during function calls
|
||||
>>> ans = self.n2.prob(value, mean_a, sd_a)
|
||||
>>>
|
||||
>>> # Functions 'sd', 'var', 'entropy' have the same usage with 'mean'
|
||||
>>> # Will return [0.0]
|
||||
>>> ans = self.n1('mean')
|
||||
>>> # Will return mean_b
|
||||
>>> ans = self.n1('mean', mean_b, sd_b)
|
||||
>>> # Functions 'sd', 'var', 'entropy' have the same usage as 'mean'
|
||||
>>> # will return [0.0]
|
||||
>>> ans = self.n1.mean()
|
||||
>>> # will return mean_b
|
||||
>>> ans = self.n1.mean(mean_b, sd_b)
|
||||
>>>
|
||||
>>> # mean and sd must be passed in through construct
|
||||
>>> ans = self.n2('mean', mean_a, sd_a)
|
||||
>>> # mean and sd must be passed during function calls
|
||||
>>> ans = self.n2.mean(mean_a, sd_a)
|
||||
>>>
|
||||
>>> # Usage of 'kl_loss' and 'cross_entropy' are similar
|
||||
>>> ans = self.n1('kl_loss', 'Normal', mean_b, sd_b)
|
||||
>>> ans = self.n1('kl_loss', 'Normal', mean_b, sd_b, mean_a, sd_a)
|
||||
>>> ans = self.n1.kl_loss('Normal', mean_b, sd_b)
|
||||
>>> ans = self.n1.kl_loss('Normal', mean_b, sd_b, mean_a, sd_a)
|
||||
>>>
|
||||
>>> # Additional mean and sd must be passed in through construct
|
||||
>>> ans = self.n2('kl_loss', 'Normal', mean_b, sd_b, mean_a, sd_a)
|
||||
>>> # Additional mean and sd must be passed
|
||||
>>> ans = self.n2.kl_loss('Normal', mean_b, sd_b, mean_a, sd_a)
|
||||
>>>
|
||||
>>> # Sample Usage
|
||||
>>> ans = self.n1('sample')
|
||||
>>> ans = self.n1('sample', (2,3))
|
||||
>>> ans = self.n1('sample', (2,3), mean_b, sd_b)
|
||||
>>> ans = self.n2('sample', (2,3), mean_a, sd_a)
|
||||
>>> # Sample
|
||||
>>> ans = self.n1.sample()
|
||||
>>> ans = self.n1.sample((2,3))
|
||||
>>> ans = self.n1.sample((2,3), mean_b, sd_b)
|
||||
>>> ans = self.n2.sample((2,3), mean_a, sd_a)
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
|
@ -114,7 +114,7 @@ class Normal(Distribution):
|
|||
self.const = P.ScalarToArray()
|
||||
self.erf = P.Erf()
|
||||
self.exp = P.Exp()
|
||||
self.expm1 = P.Expm1() if get_context('device_target') == 'Ascend' else self._expm1_by_step
|
||||
self.expm1 = self._expm1_by_step
|
||||
self.fill = P.Fill()
|
||||
self.log = P.Log()
|
||||
self.shape = P.Shape()
|
||||
|
@ -135,67 +135,57 @@ class Normal(Distribution):
|
|||
"""
|
||||
return self.exp(x) - 1.0
|
||||
|
||||
def _mean(self, name='mean', mean=None, sd=None):
|
||||
def _mean(self, mean=None, sd=None):
|
||||
"""
|
||||
Mean of the distribution.
|
||||
"""
|
||||
if name == 'mean':
|
||||
mean = self._mean_value if mean is None or sd is None else mean
|
||||
return mean
|
||||
return None
|
||||
mean = self._mean_value if mean is None or sd is None else mean
|
||||
return mean
|
||||
|
||||
def _mode(self, name='mode', mean=None, sd=None):
|
||||
def _mode(self, mean=None, sd=None):
|
||||
"""
|
||||
Mode of the distribution.
|
||||
"""
|
||||
if name == 'mode':
|
||||
mean = self._mean_value if mean is None or sd is None else mean
|
||||
return mean
|
||||
return None
|
||||
mean = self._mean_value if mean is None or sd is None else mean
|
||||
return mean
|
||||
|
||||
def _sd(self, name='sd', mean=None, sd=None):
|
||||
def _sd(self, mean=None, sd=None):
|
||||
"""
|
||||
Standard deviation of the distribution.
|
||||
"""
|
||||
if name in self._variance_functions:
|
||||
sd = self._sd_value if mean is None or sd is None else sd
|
||||
return sd
|
||||
return None
|
||||
sd = self._sd_value if mean is None or sd is None else sd
|
||||
return sd
|
||||
|
||||
def _entropy(self, name='entropy', sd=None):
|
||||
def _entropy(self, sd=None):
|
||||
r"""
|
||||
Evaluate entropy.
|
||||
|
||||
.. math::
|
||||
H(X) = \log(\sqrt(numpy.e * 2. * numpy.pi * \sq(\sigma)))
|
||||
"""
|
||||
if name == 'entropy':
|
||||
sd = self._sd_value if sd is None else sd
|
||||
return self.log(self.sqrt(np.e * 2. * np.pi * self.sq(sd)))
|
||||
return None
|
||||
sd = self._sd_value if sd is None else sd
|
||||
return self.log(self.sqrt(np.e * 2. * np.pi * self.sq(sd)))
|
||||
|
||||
def _cross_entropy(self, name, dist, mean_b, sd_b, mean_a=None, sd_a=None):
|
||||
def _cross_entropy(self, dist, mean_b, sd_b, mean_a=None, sd_a=None):
|
||||
r"""
|
||||
Evaluate cross_entropy between normal distributions.
|
||||
|
||||
Args:
|
||||
name (str): name of the funtion passed in from construct. Should always be "cross_entropy".
|
||||
dist (str): type of the distributions. Should be "Normal" in this case.
|
||||
mean_b (Tensor): mean of distribution b.
|
||||
sd_b (Tensor): standard deviation distribution b.
|
||||
mean_a (Tensor): mean of distribution a. Default: self._mean_value.
|
||||
sd_a (Tensor): standard deviation distribution a. Default: self._sd_value.
|
||||
"""
|
||||
if name == 'cross_entropy' and dist == 'Normal':
|
||||
return self._entropy(sd=sd_a) + self._kl_loss(name, dist, mean_b, sd_b, mean_a, sd_a)
|
||||
if dist == 'Normal':
|
||||
return self._entropy(sd=sd_a) + self._kl_loss(dist, mean_b, sd_b, mean_a, sd_a)
|
||||
return None
|
||||
|
||||
def _log_prob(self, name, value, mean=None, sd=None):
|
||||
def _log_prob(self, value, mean=None, sd=None):
|
||||
r"""
|
||||
Evaluate log probability.
|
||||
|
||||
Args:
|
||||
name (str): name of the funtion passed in from construct.
|
||||
value (Tensor): value to be evaluated.
|
||||
mean (Tensor): mean of the distribution. Default: self._mean_value.
|
||||
sd (Tensor): standard deviation the distribution. Default: self._sd_value.
|
||||
|
@ -203,20 +193,17 @@ class Normal(Distribution):
|
|||
.. math::
|
||||
L(x) = -1* \fract{(x - \mu)^2}{2. * \sigma^2} - \log(\sqrt(2* \pi * \sigma^2))
|
||||
"""
|
||||
if name in self._prob_functions:
|
||||
mean = self._mean_value if mean is None else mean
|
||||
sd = self._sd_value if sd is None else sd
|
||||
unnormalized_log_prob = -1. * (self.sq(value - mean)) / (2. * self.sq(sd))
|
||||
neg_normalization = -1. * self.log(self.sqrt(2. * np.pi * self.sq(sd)))
|
||||
return unnormalized_log_prob + neg_normalization
|
||||
return None
|
||||
mean = self._mean_value if mean is None else mean
|
||||
sd = self._sd_value if sd is None else sd
|
||||
unnormalized_log_prob = -1. * (self.sq(value - mean)) / (2. * self.sq(sd))
|
||||
neg_normalization = -1. * self.log(self.sqrt(2. * np.pi * self.sq(sd)))
|
||||
return unnormalized_log_prob + neg_normalization
|
||||
|
||||
def _cdf(self, name, value, mean=None, sd=None):
|
||||
def _cdf(self, value, mean=None, sd=None):
|
||||
r"""
|
||||
Evaluate cdf of given value.
|
||||
|
||||
Args:
|
||||
name (str): name of the funtion passed in from construct. Should always be "cdf".
|
||||
value (Tensor): value to be evaluated.
|
||||
mean (Tensor): mean of the distribution. Default: self._mean_value.
|
||||
sd (Tensor): standard deviation the distribution. Default: self._sd_value.
|
||||
|
@ -224,20 +211,17 @@ class Normal(Distribution):
|
|||
.. math::
|
||||
cdf(x) = 0.5 * (1+ Erf((x - \mu) / ( \sigma * \sqrt(2))))
|
||||
"""
|
||||
if name in self._cdf_survival_functions:
|
||||
mean = self._mean_value if mean is None else mean
|
||||
sd = self._sd_value if sd is None else sd
|
||||
sqrt2 = self.sqrt(self.const(2.0))
|
||||
adjusted = (value - mean) / (sd * sqrt2)
|
||||
return 0.5 * (1.0 + self.erf(adjusted))
|
||||
return None
|
||||
mean = self._mean_value if mean is None else mean
|
||||
sd = self._sd_value if sd is None else sd
|
||||
sqrt2 = self.sqrt(self.const(2.0))
|
||||
adjusted = (value - mean) / (sd * sqrt2)
|
||||
return 0.5 * (1.0 + self.erf(adjusted))
|
||||
|
||||
def _kl_loss(self, name, dist, mean_b, sd_b, mean_a=None, sd_a=None):
|
||||
def _kl_loss(self, dist, mean_b, sd_b, mean_a=None, sd_a=None):
|
||||
r"""
|
||||
Evaluate Normal-Normal kl divergence, i.e. KL(a||b).
|
||||
|
||||
Args:
|
||||
name (str): name of the funtion passed in from construct.
|
||||
dist (str): type of the distributions. Should be "Normal" in this case.
|
||||
mean_b (Tensor): mean of distribution b.
|
||||
sd_b (Tensor): standard deviation distribution b.
|
||||
|
@ -248,7 +232,7 @@ class Normal(Distribution):
|
|||
KL(a||b) = 0.5 * (\fract{MEAN(a)}{STD(b)} - \fract{MEAN(b)}{STD(b)}) ^ 2 +
|
||||
0.5 * EXPM1(2 * (\log(STD(a)) - \log(STD(b))) - (\log(STD(a)) - \log(STD(b)))
|
||||
"""
|
||||
if name in self._divergence_functions and dist == 'Normal':
|
||||
if dist == 'Normal':
|
||||
mean_a = self._mean_value if mean_a is None else mean_a
|
||||
sd_a = self._sd_value if sd_a is None else sd_a
|
||||
diff_log_scale = self.log(sd_a) - self.log(sd_b)
|
||||
|
@ -256,12 +240,11 @@ class Normal(Distribution):
|
|||
return 0.5 * squared_diff + 0.5 * self.expm1(2 * diff_log_scale) - diff_log_scale
|
||||
return None
|
||||
|
||||
def _sample(self, name, shape=(), mean=None, sd=None):
|
||||
def _sample(self, shape=(), mean=None, sd=None):
|
||||
"""
|
||||
Sampling.
|
||||
|
||||
Args:
|
||||
name (str): name of the function. Should always be 'sample' when passed in from construct.
|
||||
shape (tuple): shape of the sample. Default: ().
|
||||
mean (Tensor): mean of the samples. Default: self._mean_value.
|
||||
sd (Tensor): standard deviation of the samples. Default: self._sd_value.
|
||||
|
@ -269,14 +252,12 @@ class Normal(Distribution):
|
|||
Returns:
|
||||
Tensor, shape is shape + batch_shape.
|
||||
"""
|
||||
if name == 'sample':
|
||||
mean = self._mean_value if mean is None else mean
|
||||
sd = self._sd_value if sd is None else sd
|
||||
batch_shape = self.shape(self.zeroslike(mean) + self.zeroslike(sd))
|
||||
sample_shape = shape + batch_shape
|
||||
mean_zero = self.const(0.0)
|
||||
sd_one = self.const(1.0)
|
||||
sample_norm = C.normal(sample_shape, mean_zero, sd_one, self.seed)
|
||||
sample = mean + sample_norm * sd
|
||||
return sample
|
||||
return None
|
||||
mean = self._mean_value if mean is None else mean
|
||||
sd = self._sd_value if sd is None else sd
|
||||
batch_shape = self.shape(self.zeroslike(mean) + self.zeroslike(sd))
|
||||
sample_shape = shape + batch_shape
|
||||
mean_zero = self.const(0.0)
|
||||
sd_one = self.const(1.0)
|
||||
sample_norm = C.normal(sample_shape, mean_zero, sd_one, self.seed)
|
||||
sample = mean + sample_norm * sd
|
||||
return sample
|
||||
|
|
|
@ -35,55 +35,56 @@ class Uniform(Distribution):
|
|||
|
||||
Examples:
|
||||
>>> # To initialize a Uniform distribution of mean 3.0 and standard deviation 4.0
|
||||
>>> n = nn.Uniform(0.0, 1.0, dtype=mstype.float32)
|
||||
>>> import mindspore.nn.probability.distribution as msd
|
||||
>>> u = msd.Uniform(0.0, 1.0, dtype=mstype.float32)
|
||||
>>>
|
||||
>>> # The following creates two independent Uniform distributions
|
||||
>>> n = nn.Uniform([0.0, 0.0], [1.0, 2.0], dtype=mstype.float32)
|
||||
>>> u = msd.Uniform([0.0, 0.0], [1.0, 2.0], dtype=mstype.float32)
|
||||
>>>
|
||||
>>> # A Uniform distribution can be initilized without arguments
|
||||
>>> # In this case, high and low must be passed in through construct.
|
||||
>>> n = nn.Uniform(dtype=mstype.float32)
|
||||
>>> # In this case, high and low must be passed in through args during function calls.
|
||||
>>> u = msd.Uniform(dtype=mstype.float32)
|
||||
>>>
|
||||
>>> # To use Uniform in a network
|
||||
>>> class net(Cell):
|
||||
>>> def __init__(self)
|
||||
>>> super(net, self).__init__():
|
||||
>>> self.u1 = nn.Uniform(0.0, 1.0, dtype=mstype.float32)
|
||||
>>> self.u2 = nn.Uniform(dtype=mstype.float32)
|
||||
>>> self.u1 = msd.Uniform(0.0, 1.0, dtype=mstype.float32)
|
||||
>>> self.u2 = msd.Uniform(dtype=mstype.float32)
|
||||
>>>
|
||||
>>> # All the following calls in construct are valid
|
||||
>>> def construct(self, value, low_b, high_b, low_a, high_a):
|
||||
>>>
|
||||
>>> # Similar calls can be made to other probability functions
|
||||
>>> # by replacing 'prob' with the name of the function
|
||||
>>> ans = self.u1('prob', value)
|
||||
>>> ans = self.u1.prob(value)
|
||||
>>> # Evaluate with the respect to distribution b
|
||||
>>> ans = self.u1('prob', value, low_b, high_b)
|
||||
>>> ans = self.u1.prob(value, low_b, high_b)
|
||||
>>>
|
||||
>>> # High and low must be passed in through construct
|
||||
>>> ans = self.u2('prob', value, low_a, high_a)
|
||||
>>> # High and low must be passed in during function calls
|
||||
>>> ans = self.u2.prob(value, low_a, high_a)
|
||||
>>>
|
||||
>>> # Functions 'sd', 'var', 'entropy' have the same usage with 'mean'
|
||||
>>> # Will return [0.0]
|
||||
>>> ans = self.u1('mean')
|
||||
>>> # Will return low_b
|
||||
>>> ans = self.u1('mean', low_b, high_b)
|
||||
>>> # Functions 'sd', 'var', 'entropy' have the same usage as 'mean'
|
||||
>>> # Will return 0.5
|
||||
>>> ans = self.u1.mean()
|
||||
>>> # Will return (low_b + high_b) / 2
|
||||
>>> ans = self.u1.mean(low_b, high_b)
|
||||
>>>
|
||||
>>> # High and low must be passed in through construct
|
||||
>>> ans = self.u2('mean', low_a, high_a)
|
||||
>>> # High and low must be passed in during function calls
|
||||
>>> ans = self.u2.mean(low_a, high_a)
|
||||
>>>
|
||||
>>> # Usage of 'kl_loss' and 'cross_entropy' are similar
|
||||
>>> ans = self.u1('kl_loss', 'Uniform', low_b, high_b)
|
||||
>>> ans = self.u1('kl_loss', 'Uniform', low_b, high_b, low_a, high_a)
|
||||
>>> ans = self.u1.kl_loss('Uniform', low_b, high_b)
|
||||
>>> ans = self.u1.kl_loss('Uniform', low_b, high_b, low_a, high_a)
|
||||
>>>
|
||||
>>> # Additional high and low must be passed in through construct
|
||||
>>> ans = self.u2('kl_loss', 'Uniform', low_b, high_b, low_a, high_a)
|
||||
>>> # Additional high and low must be passed
|
||||
>>> ans = self.u2.kl_loss('Uniform', low_b, high_b, low_a, high_a)
|
||||
>>>
|
||||
>>> # Sample Usage
|
||||
>>> ans = self.u1('sample')
|
||||
>>> ans = self.u1('sample', (2,3))
|
||||
>>> ans = self.u1('sample', (2,3), low_b, high_b)
|
||||
>>> ans = self.u2('sample', (2,3), low_a, high_a)
|
||||
>>> # Sample
|
||||
>>> ans = self.u1.sample()
|
||||
>>> ans = self.u1.sample((2,3))
|
||||
>>> ans = self.u1.sample((2,3), low_b, high_b)
|
||||
>>> ans = self.u2.sample((2,3), low_a, high_a)
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
|
@ -142,73 +143,64 @@ class Uniform(Distribution):
|
|||
"""
|
||||
return self._high
|
||||
|
||||
def _range(self, name='range', low=None, high=None):
|
||||
def _range(self, low=None, high=None):
|
||||
r"""
|
||||
Return the range of the distribution.
|
||||
.. math::
|
||||
range(U) = high -low
|
||||
"""
|
||||
if name == 'range':
|
||||
low = self.low if low is None else low
|
||||
high = self.high if high is None else high
|
||||
return high - low
|
||||
return None
|
||||
low = self.low if low is None else low
|
||||
high = self.high if high is None else high
|
||||
return high - low
|
||||
|
||||
def _mean(self, name='mean', low=None, high=None):
|
||||
def _mean(self, low=None, high=None):
|
||||
r"""
|
||||
.. math::
|
||||
MEAN(U) = \fract{low + high}{2}.
|
||||
"""
|
||||
if name == 'mean':
|
||||
low = self.low if low is None else low
|
||||
high = self.high if high is None else high
|
||||
return (low + high) / 2.
|
||||
return None
|
||||
low = self.low if low is None else low
|
||||
high = self.high if high is None else high
|
||||
return (low + high) / 2.
|
||||
|
||||
def _var(self, name='var', low=None, high=None):
|
||||
|
||||
def _var(self, low=None, high=None):
|
||||
r"""
|
||||
.. math::
|
||||
VAR(U) = \fract{(high -low) ^ 2}{12}.
|
||||
"""
|
||||
if name in self._variance_functions:
|
||||
low = self.low if low is None else low
|
||||
high = self.high if high is None else high
|
||||
return self.sq(high - low) / 12.0
|
||||
return None
|
||||
low = self.low if low is None else low
|
||||
high = self.high if high is None else high
|
||||
return self.sq(high - low) / 12.0
|
||||
|
||||
def _entropy(self, name='entropy', low=None, high=None):
|
||||
def _entropy(self, low=None, high=None):
|
||||
r"""
|
||||
.. math::
|
||||
H(U) = \log(high - low).
|
||||
"""
|
||||
if name == 'entropy':
|
||||
low = self.low if low is None else low
|
||||
high = self.high if high is None else high
|
||||
return self.log(high - low)
|
||||
return None
|
||||
low = self.low if low is None else low
|
||||
high = self.high if high is None else high
|
||||
return self.log(high - low)
|
||||
|
||||
def _cross_entropy(self, name, dist, low_b, high_b, low_a=None, high_a=None):
|
||||
def _cross_entropy(self, dist, low_b, high_b, low_a=None, high_a=None):
|
||||
"""
|
||||
Evaluate cross_entropy between Uniform distributoins.
|
||||
|
||||
Args:
|
||||
name (str): name of the funtion.
|
||||
dist (str): type of the distributions. Should be "Uniform" in this case.
|
||||
low_b (Tensor): lower bound of distribution b.
|
||||
high_b (Tensor): upper bound of distribution b.
|
||||
low_a (Tensor): lower bound of distribution a. Default: self.low.
|
||||
high_a (Tensor): upper bound of distribution a. Default: self.high.
|
||||
"""
|
||||
if name == 'cross_entropy' and dist == 'Uniform':
|
||||
return self._entropy(low=low_a, high=high_a) + self._kl_loss(name, dist, low_b, high_b, low_a, high_a)
|
||||
if dist == 'Uniform':
|
||||
return self._entropy(low=low_a, high=high_a) + self._kl_loss(dist, low_b, high_b, low_a, high_a)
|
||||
return None
|
||||
|
||||
def _prob(self, name, value, low=None, high=None):
|
||||
def _prob(self, value, low=None, high=None):
|
||||
r"""
|
||||
pdf of Uniform distribution.
|
||||
|
||||
Args:
|
||||
name (str): name of the function.
|
||||
value (Tensor): value to be evaluated.
|
||||
low (Tensor): lower bound of the distribution. Default: self.low.
|
||||
high (Tensor): upper bound of the distribution. Default: self.high.
|
||||
|
@ -218,32 +210,29 @@ class Uniform(Distribution):
|
|||
pdf(x) = \fract{1.0}{high -low} if low <= x <= high;
|
||||
pdf(x) = 0 if x > high;
|
||||
"""
|
||||
if name in self._prob_functions:
|
||||
low = self.low if low is None else low
|
||||
high = self.high if high is None else high
|
||||
ones = self.fill(self.dtype, self.shape(value), 1.0)
|
||||
prob = ones / (high - low)
|
||||
broadcast_shape = self.shape(prob)
|
||||
zeros = self.fill(self.dtypeop(prob), broadcast_shape, 0.0)
|
||||
comp_lo = self.less(value, low)
|
||||
comp_hi = self.lessequal(value, high)
|
||||
less_than_low = self.select(comp_lo, zeros, prob)
|
||||
return self.select(comp_hi, less_than_low, zeros)
|
||||
return None
|
||||
low = self.low if low is None else low
|
||||
high = self.high if high is None else high
|
||||
ones = self.fill(self.dtype, self.shape(value), 1.0)
|
||||
prob = ones / (high - low)
|
||||
broadcast_shape = self.shape(prob)
|
||||
zeros = self.fill(self.dtypeop(prob), broadcast_shape, 0.0)
|
||||
comp_lo = self.less(value, low)
|
||||
comp_hi = self.lessequal(value, high)
|
||||
less_than_low = self.select(comp_lo, zeros, prob)
|
||||
return self.select(comp_hi, less_than_low, zeros)
|
||||
|
||||
def _kl_loss(self, name, dist, low_b, high_b, low_a=None, high_a=None):
|
||||
def _kl_loss(self, dist, low_b, high_b, low_a=None, high_a=None):
|
||||
"""
|
||||
Evaluate uniform-uniform kl divergence, i.e. KL(a||b).
|
||||
|
||||
Args:
|
||||
name (str): name of the funtion.
|
||||
dist (str): type of the distributions. Should be "Uniform" in this case.
|
||||
low_b (Tensor): lower bound of distribution b.
|
||||
high_b (Tensor): upper bound of distribution b.
|
||||
low_a (Tensor): lower bound of distribution a. Default: self.low.
|
||||
high_a (Tensor): upper bound of distribution a. Default: self.high.
|
||||
"""
|
||||
if name in self._divergence_functions and dist == 'Uniform':
|
||||
if dist == 'Uniform':
|
||||
low_a = self.low if low_a is None else low_a
|
||||
high_a = self.high if high_a is None else high_a
|
||||
kl = self.log(high_b - low_b) / self.log(high_a - low_a)
|
||||
|
@ -251,12 +240,11 @@ class Uniform(Distribution):
|
|||
return self.select(comp, kl, self.log(self.zeroslike(kl)))
|
||||
return None
|
||||
|
||||
def _cdf(self, name, value, low=None, high=None):
|
||||
def _cdf(self, value, low=None, high=None):
|
||||
r"""
|
||||
cdf of Uniform distribution.
|
||||
|
||||
Args:
|
||||
name (str): name of the function.
|
||||
value (Tensor): value to be evaluated.
|
||||
low (Tensor): lower bound of the distribution. Default: self.low.
|
||||
high (Tensor): upper bound of the distribution. Default: self.high.
|
||||
|
@ -266,25 +254,22 @@ class Uniform(Distribution):
|
|||
cdf(x) = \fract{x - low}{high -low} if low <= x <= high;
|
||||
cdf(x) = 1 if x > high;
|
||||
"""
|
||||
if name in self._cdf_survival_functions:
|
||||
low = self.low if low is None else low
|
||||
high = self.high if high is None else high
|
||||
prob = (value - low) / (high - low)
|
||||
broadcast_shape = self.shape(prob)
|
||||
zeros = self.fill(self.dtypeop(prob), broadcast_shape, 0.0)
|
||||
ones = self.fill(self.dtypeop(prob), broadcast_shape, 1.0)
|
||||
comp_lo = self.less(value, low)
|
||||
comp_hi = self.less(value, high)
|
||||
less_than_low = self.select(comp_lo, zeros, prob)
|
||||
return self.select(comp_hi, less_than_low, ones)
|
||||
return None
|
||||
low = self.low if low is None else low
|
||||
high = self.high if high is None else high
|
||||
prob = (value - low) / (high - low)
|
||||
broadcast_shape = self.shape(prob)
|
||||
zeros = self.fill(self.dtypeop(prob), broadcast_shape, 0.0)
|
||||
ones = self.fill(self.dtypeop(prob), broadcast_shape, 1.0)
|
||||
comp_lo = self.less(value, low)
|
||||
comp_hi = self.less(value, high)
|
||||
less_than_low = self.select(comp_lo, zeros, prob)
|
||||
return self.select(comp_hi, less_than_low, ones)
|
||||
|
||||
def _sample(self, name, shape=(), low=None, high=None):
|
||||
def _sample(self, shape=(), low=None, high=None):
|
||||
"""
|
||||
Sampling.
|
||||
|
||||
Args:
|
||||
name (str): name of the function. Should always be 'sample' when passed in from construct.
|
||||
shape (tuple): shape of the sample. Default: ().
|
||||
low (Tensor): lower bound of the distribution. Default: self.low.
|
||||
high (Tensor): upper bound of the distribution. Default: self.high.
|
||||
|
@ -292,13 +277,11 @@ class Uniform(Distribution):
|
|||
Returns:
|
||||
Tensor, shape is shape + batch_shape.
|
||||
"""
|
||||
if name == 'sample':
|
||||
low = self.low if low is None else low
|
||||
high = self.high if high is None else high
|
||||
broadcast_shape = self.shape(low + high)
|
||||
l_zero = self.const(0.0)
|
||||
h_one = self.const(1.0)
|
||||
sample_uniform = self.uniform(shape + broadcast_shape, l_zero, h_one)
|
||||
sample = (high - low) * sample_uniform + low
|
||||
return sample
|
||||
return None
|
||||
low = self.low if low is None else low
|
||||
high = self.high if high is None else high
|
||||
broadcast_shape = self.shape(low + high)
|
||||
l_zero = self.const(0.0)
|
||||
h_one = self.const(1.0)
|
||||
sample_uniform = self.uniform(shape + broadcast_shape, l_zero, h_one)
|
||||
sample = (high - low) * sample_uniform + low
|
||||
return sample
|
||||
|
|
|
@ -19,7 +19,6 @@ import mindspore.context as context
|
|||
import mindspore.nn as nn
|
||||
import mindspore.nn.probability.distribution as msd
|
||||
from mindspore import Tensor
|
||||
from mindspore.common.api import ms_function
|
||||
from mindspore import dtype
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
|
||||
|
@ -32,9 +31,8 @@ class Prob(nn.Cell):
|
|||
super(Prob, self).__init__()
|
||||
self.b = msd.Bernoulli(0.7, dtype=dtype.int32)
|
||||
|
||||
@ms_function
|
||||
def construct(self, x_):
|
||||
return self.b('prob', x_)
|
||||
return self.b.prob(x_)
|
||||
|
||||
def test_pmf():
|
||||
"""
|
||||
|
@ -57,9 +55,8 @@ class LogProb(nn.Cell):
|
|||
super(LogProb, self).__init__()
|
||||
self.b = msd.Bernoulli(0.7, dtype=dtype.int32)
|
||||
|
||||
@ms_function
|
||||
def construct(self, x_):
|
||||
return self.b('log_prob', x_)
|
||||
return self.b.log_prob(x_)
|
||||
|
||||
def test_log_likelihood():
|
||||
"""
|
||||
|
@ -81,9 +78,8 @@ class KL(nn.Cell):
|
|||
super(KL, self).__init__()
|
||||
self.b = msd.Bernoulli(0.7, dtype=dtype.int32)
|
||||
|
||||
@ms_function
|
||||
def construct(self, x_):
|
||||
return self.b('kl_loss', 'Bernoulli', x_)
|
||||
return self.b.kl_loss('Bernoulli', x_)
|
||||
|
||||
def test_kl_loss():
|
||||
"""
|
||||
|
@ -107,9 +103,8 @@ class Basics(nn.Cell):
|
|||
super(Basics, self).__init__()
|
||||
self.b = msd.Bernoulli([0.3, 0.5, 0.7], dtype=dtype.int32)
|
||||
|
||||
@ms_function
|
||||
def construct(self):
|
||||
return self.b('mean'), self.b('sd'), self.b('mode')
|
||||
return self.b.mean(), self.b.sd(), self.b.mode()
|
||||
|
||||
def test_basics():
|
||||
"""
|
||||
|
@ -134,9 +129,8 @@ class Sampling(nn.Cell):
|
|||
self.b = msd.Bernoulli([0.7, 0.5], seed=seed, dtype=dtype.int32)
|
||||
self.shape = shape
|
||||
|
||||
@ms_function
|
||||
def construct(self, probs=None):
|
||||
return self.b('sample', self.shape, probs)
|
||||
return self.b.sample(self.shape, probs)
|
||||
|
||||
def test_sample():
|
||||
"""
|
||||
|
@ -155,9 +149,8 @@ class CDF(nn.Cell):
|
|||
super(CDF, self).__init__()
|
||||
self.b = msd.Bernoulli(0.7, dtype=dtype.int32)
|
||||
|
||||
@ms_function
|
||||
def construct(self, x_):
|
||||
return self.b('cdf', x_)
|
||||
return self.b.cdf(x_)
|
||||
|
||||
def test_cdf():
|
||||
"""
|
||||
|
@ -171,7 +164,6 @@ def test_cdf():
|
|||
tol = 1e-6
|
||||
assert (np.abs(output.asnumpy() - expect_cdf) < tol).all()
|
||||
|
||||
|
||||
class LogCDF(nn.Cell):
|
||||
"""
|
||||
Test class: log cdf of bernoulli distributions.
|
||||
|
@ -180,9 +172,8 @@ class LogCDF(nn.Cell):
|
|||
super(LogCDF, self).__init__()
|
||||
self.b = msd.Bernoulli(0.7, dtype=dtype.int32)
|
||||
|
||||
@ms_function
|
||||
def construct(self, x_):
|
||||
return self.b('log_cdf', x_)
|
||||
return self.b.log_cdf(x_)
|
||||
|
||||
def test_logcdf():
|
||||
"""
|
||||
|
@ -205,9 +196,8 @@ class SF(nn.Cell):
|
|||
super(SF, self).__init__()
|
||||
self.b = msd.Bernoulli(0.7, dtype=dtype.int32)
|
||||
|
||||
@ms_function
|
||||
def construct(self, x_):
|
||||
return self.b('survival_function', x_)
|
||||
return self.b.survival_function(x_)
|
||||
|
||||
def test_survival():
|
||||
"""
|
||||
|
@ -230,9 +220,8 @@ class LogSF(nn.Cell):
|
|||
super(LogSF, self).__init__()
|
||||
self.b = msd.Bernoulli(0.7, dtype=dtype.int32)
|
||||
|
||||
@ms_function
|
||||
def construct(self, x_):
|
||||
return self.b('log_survival', x_)
|
||||
return self.b.log_survival(x_)
|
||||
|
||||
def test_log_survival():
|
||||
"""
|
||||
|
@ -254,9 +243,8 @@ class EntropyH(nn.Cell):
|
|||
super(EntropyH, self).__init__()
|
||||
self.b = msd.Bernoulli(0.7, dtype=dtype.int32)
|
||||
|
||||
@ms_function
|
||||
def construct(self):
|
||||
return self.b('entropy')
|
||||
return self.b.entropy()
|
||||
|
||||
def test_entropy():
|
||||
"""
|
||||
|
@ -277,12 +265,11 @@ class CrossEntropy(nn.Cell):
|
|||
super(CrossEntropy, self).__init__()
|
||||
self.b = msd.Bernoulli(0.7, dtype=dtype.int32)
|
||||
|
||||
@ms_function
|
||||
def construct(self, x_):
|
||||
entropy = self.b('entropy')
|
||||
kl_loss = self.b('kl_loss', 'Bernoulli', x_)
|
||||
entropy = self.b.entropy()
|
||||
kl_loss = self.b.kl_loss('Bernoulli', x_)
|
||||
h_sum_kl = entropy + kl_loss
|
||||
cross_entropy = self.b('cross_entropy', 'Bernoulli', x_)
|
||||
cross_entropy = self.b.cross_entropy('Bernoulli', x_)
|
||||
return h_sum_kl - cross_entropy
|
||||
|
||||
def test_cross_entropy():
|
||||
|
|
|
@ -19,7 +19,6 @@ import mindspore.context as context
|
|||
import mindspore.nn as nn
|
||||
import mindspore.nn.probability.distribution as msd
|
||||
from mindspore import Tensor
|
||||
from mindspore.common.api import ms_function
|
||||
from mindspore import dtype
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
|
||||
|
@ -32,9 +31,8 @@ class Prob(nn.Cell):
|
|||
super(Prob, self).__init__()
|
||||
self.e = msd.Exponential([[1.0], [0.5]], dtype=dtype.float32)
|
||||
|
||||
@ms_function
|
||||
def construct(self, x_):
|
||||
return self.e('prob', x_)
|
||||
return self.e.prob(x_)
|
||||
|
||||
def test_pdf():
|
||||
"""
|
||||
|
@ -56,9 +54,8 @@ class LogProb(nn.Cell):
|
|||
super(LogProb, self).__init__()
|
||||
self.e = msd.Exponential([[1.0], [0.5]], dtype=dtype.float32)
|
||||
|
||||
@ms_function
|
||||
def construct(self, x_):
|
||||
return self.e('log_prob', x_)
|
||||
return self.e.log_prob(x_)
|
||||
|
||||
def test_log_likelihood():
|
||||
"""
|
||||
|
@ -80,9 +77,8 @@ class KL(nn.Cell):
|
|||
super(KL, self).__init__()
|
||||
self.e = msd.Exponential([1.5], dtype=dtype.float32)
|
||||
|
||||
@ms_function
|
||||
def construct(self, x_):
|
||||
return self.e('kl_loss', 'Exponential', x_)
|
||||
return self.e.kl_loss('Exponential', x_)
|
||||
|
||||
def test_kl_loss():
|
||||
"""
|
||||
|
@ -104,9 +100,8 @@ class Basics(nn.Cell):
|
|||
super(Basics, self).__init__()
|
||||
self.e = msd.Exponential([0.5], dtype=dtype.float32)
|
||||
|
||||
@ms_function
|
||||
def construct(self):
|
||||
return self.e('mean'), self.e('sd'), self.e('mode')
|
||||
return self.e.mean(), self.e.sd(), self.e.mode()
|
||||
|
||||
def test_basics():
|
||||
"""
|
||||
|
@ -131,9 +126,8 @@ class Sampling(nn.Cell):
|
|||
self.e = msd.Exponential([[1.0], [0.5]], seed=seed, dtype=dtype.float32)
|
||||
self.shape = shape
|
||||
|
||||
@ms_function
|
||||
def construct(self, rate=None):
|
||||
return self.e('sample', self.shape, rate)
|
||||
return self.e.sample(self.shape, rate)
|
||||
|
||||
def test_sample():
|
||||
"""
|
||||
|
@ -154,9 +148,8 @@ class CDF(nn.Cell):
|
|||
super(CDF, self).__init__()
|
||||
self.e = msd.Exponential([[1.0], [0.5]], dtype=dtype.float32)
|
||||
|
||||
@ms_function
|
||||
def construct(self, x_):
|
||||
return self.e('cdf', x_)
|
||||
return self.e.cdf(x_)
|
||||
|
||||
def test_cdf():
|
||||
"""
|
||||
|
@ -178,9 +171,8 @@ class LogCDF(nn.Cell):
|
|||
super(LogCDF, self).__init__()
|
||||
self.e = msd.Exponential([[1.0], [0.5]], dtype=dtype.float32)
|
||||
|
||||
@ms_function
|
||||
def construct(self, x_):
|
||||
return self.e('log_cdf', x_)
|
||||
return self.e.log_cdf(x_)
|
||||
|
||||
def test_log_cdf():
|
||||
"""
|
||||
|
@ -202,9 +194,8 @@ class SF(nn.Cell):
|
|||
super(SF, self).__init__()
|
||||
self.e = msd.Exponential([[1.0], [0.5]], dtype=dtype.float32)
|
||||
|
||||
@ms_function
|
||||
def construct(self, x_):
|
||||
return self.e('survival_function', x_)
|
||||
return self.e.survival_function(x_)
|
||||
|
||||
def test_survival():
|
||||
"""
|
||||
|
@ -226,9 +217,8 @@ class LogSF(nn.Cell):
|
|||
super(LogSF, self).__init__()
|
||||
self.e = msd.Exponential([[1.0], [0.5]], dtype=dtype.float32)
|
||||
|
||||
@ms_function
|
||||
def construct(self, x_):
|
||||
return self.e('log_survival', x_)
|
||||
return self.e.log_survival(x_)
|
||||
|
||||
def test_log_survival():
|
||||
"""
|
||||
|
@ -250,9 +240,8 @@ class EntropyH(nn.Cell):
|
|||
super(EntropyH, self).__init__()
|
||||
self.e = msd.Exponential([[1.0], [0.5]], dtype=dtype.float32)
|
||||
|
||||
@ms_function
|
||||
def construct(self):
|
||||
return self.e('entropy')
|
||||
return self.e.entropy()
|
||||
|
||||
def test_entropy():
|
||||
"""
|
||||
|
@ -273,12 +262,11 @@ class CrossEntropy(nn.Cell):
|
|||
super(CrossEntropy, self).__init__()
|
||||
self.e = msd.Exponential([1.0], dtype=dtype.float32)
|
||||
|
||||
@ms_function
|
||||
def construct(self, x_):
|
||||
entropy = self.e('entropy')
|
||||
kl_loss = self.e('kl_loss', 'Exponential', x_)
|
||||
entropy = self.e.entropy()
|
||||
kl_loss = self.e.kl_loss('Exponential', x_)
|
||||
h_sum_kl = entropy + kl_loss
|
||||
cross_entropy = self.e('cross_entropy', 'Exponential', x_)
|
||||
cross_entropy = self.e.cross_entropy('Exponential', x_)
|
||||
return h_sum_kl - cross_entropy
|
||||
|
||||
def test_cross_entropy():
|
||||
|
|
|
@ -19,7 +19,6 @@ import mindspore.context as context
|
|||
import mindspore.nn as nn
|
||||
import mindspore.nn.probability.distribution as msd
|
||||
from mindspore import Tensor
|
||||
from mindspore.common.api import ms_function
|
||||
from mindspore import dtype
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
|
||||
|
@ -32,9 +31,8 @@ class Prob(nn.Cell):
|
|||
super(Prob, self).__init__()
|
||||
self.g = msd.Geometric(0.7, dtype=dtype.int32)
|
||||
|
||||
@ms_function
|
||||
def construct(self, x_):
|
||||
return self.g('prob', x_)
|
||||
return self.g.prob(x_)
|
||||
|
||||
def test_pmf():
|
||||
"""
|
||||
|
@ -56,9 +54,8 @@ class LogProb(nn.Cell):
|
|||
super(LogProb, self).__init__()
|
||||
self.g = msd.Geometric(0.7, dtype=dtype.int32)
|
||||
|
||||
@ms_function
|
||||
def construct(self, x_):
|
||||
return self.g('log_prob', x_)
|
||||
return self.g.log_prob(x_)
|
||||
|
||||
def test_log_likelihood():
|
||||
"""
|
||||
|
@ -80,9 +77,8 @@ class KL(nn.Cell):
|
|||
super(KL, self).__init__()
|
||||
self.g = msd.Geometric(0.7, dtype=dtype.int32)
|
||||
|
||||
@ms_function
|
||||
def construct(self, x_):
|
||||
return self.g('kl_loss', 'Geometric', x_)
|
||||
return self.g.kl_loss('Geometric', x_)
|
||||
|
||||
def test_kl_loss():
|
||||
"""
|
||||
|
@ -106,9 +102,8 @@ class Basics(nn.Cell):
|
|||
super(Basics, self).__init__()
|
||||
self.g = msd.Geometric([0.5, 0.5], dtype=dtype.int32)
|
||||
|
||||
@ms_function
|
||||
def construct(self):
|
||||
return self.g('mean'), self.g('sd'), self.g('mode')
|
||||
return self.g.mean(), self.g.sd(), self.g.mode()
|
||||
|
||||
def test_basics():
|
||||
"""
|
||||
|
@ -133,9 +128,8 @@ class Sampling(nn.Cell):
|
|||
self.g = msd.Geometric([0.7, 0.5], seed=seed, dtype=dtype.int32)
|
||||
self.shape = shape
|
||||
|
||||
@ms_function
|
||||
def construct(self, probs=None):
|
||||
return self.g('sample', self.shape, probs)
|
||||
return self.g.sample(self.shape, probs)
|
||||
|
||||
def test_sample():
|
||||
"""
|
||||
|
@ -154,9 +148,8 @@ class CDF(nn.Cell):
|
|||
super(CDF, self).__init__()
|
||||
self.g = msd.Geometric(0.7, dtype=dtype.int32)
|
||||
|
||||
@ms_function
|
||||
def construct(self, x_):
|
||||
return self.g('cdf', x_)
|
||||
return self.g.cdf(x_)
|
||||
|
||||
def test_cdf():
|
||||
"""
|
||||
|
@ -178,9 +171,8 @@ class LogCDF(nn.Cell):
|
|||
super(LogCDF, self).__init__()
|
||||
self.g = msd.Geometric(0.7, dtype=dtype.int32)
|
||||
|
||||
@ms_function
|
||||
def construct(self, x_):
|
||||
return self.g('log_cdf', x_)
|
||||
return self.g.log_cdf(x_)
|
||||
|
||||
def test_logcdf():
|
||||
"""
|
||||
|
@ -202,9 +194,8 @@ class SF(nn.Cell):
|
|||
super(SF, self).__init__()
|
||||
self.g = msd.Geometric(0.7, dtype=dtype.int32)
|
||||
|
||||
@ms_function
|
||||
def construct(self, x_):
|
||||
return self.g('survival_function', x_)
|
||||
return self.g.survival_function(x_)
|
||||
|
||||
def test_survival():
|
||||
"""
|
||||
|
@ -226,9 +217,8 @@ class LogSF(nn.Cell):
|
|||
super(LogSF, self).__init__()
|
||||
self.g = msd.Geometric(0.7, dtype=dtype.int32)
|
||||
|
||||
@ms_function
|
||||
def construct(self, x_):
|
||||
return self.g('log_survival', x_)
|
||||
return self.g.log_survival(x_)
|
||||
|
||||
def test_log_survival():
|
||||
"""
|
||||
|
@ -250,9 +240,8 @@ class EntropyH(nn.Cell):
|
|||
super(EntropyH, self).__init__()
|
||||
self.g = msd.Geometric(0.7, dtype=dtype.int32)
|
||||
|
||||
@ms_function
|
||||
def construct(self):
|
||||
return self.g('entropy')
|
||||
return self.g.entropy()
|
||||
|
||||
def test_entropy():
|
||||
"""
|
||||
|
@ -273,12 +262,11 @@ class CrossEntropy(nn.Cell):
|
|||
super(CrossEntropy, self).__init__()
|
||||
self.g = msd.Geometric(0.7, dtype=dtype.int32)
|
||||
|
||||
@ms_function
|
||||
def construct(self, x_):
|
||||
entropy = self.g('entropy')
|
||||
kl_loss = self.g('kl_loss', 'Geometric', x_)
|
||||
entropy = self.g.entropy()
|
||||
kl_loss = self.g.kl_loss('Geometric', x_)
|
||||
h_sum_kl = entropy + kl_loss
|
||||
ans = self.g('cross_entropy', 'Geometric', x_)
|
||||
ans = self.g.cross_entropy('Geometric', x_)
|
||||
return h_sum_kl - ans
|
||||
|
||||
def test_cross_entropy():
|
||||
|
|
|
@ -19,7 +19,6 @@ import mindspore.context as context
|
|||
import mindspore.nn as nn
|
||||
import mindspore.nn.probability.distribution as msd
|
||||
from mindspore import Tensor
|
||||
from mindspore.common.api import ms_function
|
||||
from mindspore import dtype
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
|
||||
|
@ -32,9 +31,8 @@ class Prob(nn.Cell):
|
|||
super(Prob, self).__init__()
|
||||
self.n = msd.Normal(np.array([3.0]), np.array([[2.0], [4.0]]), dtype=dtype.float32)
|
||||
|
||||
@ms_function
|
||||
def construct(self, x_):
|
||||
return self.n('prob', x_)
|
||||
return self.n.prob(x_)
|
||||
|
||||
def test_pdf():
|
||||
"""
|
||||
|
@ -55,9 +53,8 @@ class LogProb(nn.Cell):
|
|||
super(LogProb, self).__init__()
|
||||
self.n = msd.Normal(np.array([3.0]), np.array([[2.0], [4.0]]), dtype=dtype.float32)
|
||||
|
||||
@ms_function
|
||||
def construct(self, x_):
|
||||
return self.n('log_prob', x_)
|
||||
return self.n.log_prob(x_)
|
||||
|
||||
def test_log_likelihood():
|
||||
"""
|
||||
|
@ -79,9 +76,8 @@ class KL(nn.Cell):
|
|||
super(KL, self).__init__()
|
||||
self.n = msd.Normal(np.array([3.0]), np.array([4.0]), dtype=dtype.float32)
|
||||
|
||||
@ms_function
|
||||
def construct(self, x_, y_):
|
||||
return self.n('kl_loss', 'Normal', x_, y_)
|
||||
return self.n.kl_loss('Normal', x_, y_)
|
||||
|
||||
|
||||
def test_kl_loss():
|
||||
|
@ -113,9 +109,8 @@ class Basics(nn.Cell):
|
|||
super(Basics, self).__init__()
|
||||
self.n = msd.Normal(np.array([3.0]), np.array([2.0, 4.0]), dtype=dtype.float32)
|
||||
|
||||
@ms_function
|
||||
def construct(self):
|
||||
return self.n('mean'), self.n('sd'), self.n('mode')
|
||||
return self.n.mean(), self.n.sd(), self.n.mode()
|
||||
|
||||
def test_basics():
|
||||
"""
|
||||
|
@ -139,9 +134,8 @@ class Sampling(nn.Cell):
|
|||
self.n = msd.Normal(np.array([3.0]), np.array([[2.0], [4.0]]), seed=seed, dtype=dtype.float32)
|
||||
self.shape = shape
|
||||
|
||||
@ms_function
|
||||
def construct(self, mean=None, sd=None):
|
||||
return self.n('sample', self.shape, mean, sd)
|
||||
return self.n.sample(self.shape, mean, sd)
|
||||
|
||||
def test_sample():
|
||||
"""
|
||||
|
@ -163,9 +157,8 @@ class CDF(nn.Cell):
|
|||
super(CDF, self).__init__()
|
||||
self.n = msd.Normal(np.array([3.0]), np.array([[2.0], [4.0]]), dtype=dtype.float32)
|
||||
|
||||
@ms_function
|
||||
def construct(self, x_):
|
||||
return self.n('cdf', x_)
|
||||
return self.n.cdf(x_)
|
||||
|
||||
|
||||
def test_cdf():
|
||||
|
@ -187,9 +180,8 @@ class LogCDF(nn.Cell):
|
|||
super(LogCDF, self).__init__()
|
||||
self.n = msd.Normal(np.array([3.0]), np.array([[2.0], [4.0]]), dtype=dtype.float32)
|
||||
|
||||
@ms_function
|
||||
def construct(self, x_):
|
||||
return self.n('log_cdf', x_)
|
||||
return self.n.log_cdf(x_)
|
||||
|
||||
def test_log_cdf():
|
||||
"""
|
||||
|
@ -210,9 +202,8 @@ class SF(nn.Cell):
|
|||
super(SF, self).__init__()
|
||||
self.n = msd.Normal(np.array([3.0]), np.array([[2.0], [4.0]]), dtype=dtype.float32)
|
||||
|
||||
@ms_function
|
||||
def construct(self, x_):
|
||||
return self.n('survival_function', x_)
|
||||
return self.n.survival_function(x_)
|
||||
|
||||
def test_survival():
|
||||
"""
|
||||
|
@ -233,9 +224,8 @@ class LogSF(nn.Cell):
|
|||
super(LogSF, self).__init__()
|
||||
self.n = msd.Normal(np.array([3.0]), np.array([[2.0], [4.0]]), dtype=dtype.float32)
|
||||
|
||||
@ms_function
|
||||
def construct(self, x_):
|
||||
return self.n('log_survival', x_)
|
||||
return self.n.log_survival(x_)
|
||||
|
||||
def test_log_survival():
|
||||
"""
|
||||
|
@ -256,9 +246,8 @@ class EntropyH(nn.Cell):
|
|||
super(EntropyH, self).__init__()
|
||||
self.n = msd.Normal(np.array([3.0]), np.array([[2.0], [4.0]]), dtype=dtype.float32)
|
||||
|
||||
@ms_function
|
||||
def construct(self):
|
||||
return self.n('entropy')
|
||||
return self.n.entropy()
|
||||
|
||||
def test_entropy():
|
||||
"""
|
||||
|
@ -279,12 +268,11 @@ class CrossEntropy(nn.Cell):
|
|||
super(CrossEntropy, self).__init__()
|
||||
self.n = msd.Normal(np.array([3.0]), np.array([4.0]), dtype=dtype.float32)
|
||||
|
||||
@ms_function
|
||||
def construct(self, x_, y_):
|
||||
entropy = self.n('entropy')
|
||||
kl_loss = self.n('kl_loss', 'Normal', x_, y_)
|
||||
entropy = self.n.entropy()
|
||||
kl_loss = self.n.kl_loss('Normal', x_, y_)
|
||||
h_sum_kl = entropy + kl_loss
|
||||
cross_entropy = self.n('cross_entropy', 'Normal', x_, y_)
|
||||
cross_entropy = self.n.cross_entropy('Normal', x_, y_)
|
||||
return h_sum_kl - cross_entropy
|
||||
|
||||
def test_cross_entropy():
|
||||
|
@ -297,3 +285,40 @@ def test_cross_entropy():
|
|||
diff = cross_entropy(mean, sd)
|
||||
tol = 1e-6
|
||||
assert (np.abs(diff.asnumpy() - np.zeros(diff.shape)) < tol).all()
|
||||
|
||||
class Net(nn.Cell):
|
||||
"""
|
||||
Test class: expand single distribution instance to multiple graphs
|
||||
by specifying the attributes.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super(Net, self).__init__()
|
||||
self.normal = msd.Normal(0., 1., dtype=dtype.float32)
|
||||
|
||||
def construct(self, x_, y_):
|
||||
kl = self.normal.kl_loss('Normal', x_, y_)
|
||||
prob = self.normal.prob(kl)
|
||||
return prob
|
||||
|
||||
def test_multiple_graphs():
|
||||
"""
|
||||
Test multiple graphs case.
|
||||
"""
|
||||
prob = Net()
|
||||
mean_a = np.array([0.0]).astype(np.float32)
|
||||
sd_a = np.array([1.0]).astype(np.float32)
|
||||
mean_b = np.array([1.0]).astype(np.float32)
|
||||
sd_b = np.array([1.0]).astype(np.float32)
|
||||
ans = prob(Tensor(mean_b), Tensor(sd_b))
|
||||
|
||||
diff_log_scale = np.log(sd_a) - np.log(sd_b)
|
||||
squared_diff = np.square(mean_a / sd_b - mean_b / sd_b)
|
||||
expect_kl_loss = 0.5 * squared_diff + 0.5 * \
|
||||
np.expm1(2 * diff_log_scale) - diff_log_scale
|
||||
|
||||
norm_benchmark = stats.norm(np.array([0.0]), np.array([1.0]))
|
||||
expect_prob = norm_benchmark.pdf(expect_kl_loss).astype(np.float32)
|
||||
|
||||
tol = 1e-6
|
||||
assert (np.abs(ans.asnumpy() - expect_prob) < tol).all()
|
||||
|
|
|
@ -1,62 +0,0 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""test cases for new api of normal distribution"""
|
||||
import numpy as np
|
||||
from scipy import stats
|
||||
import mindspore.nn as nn
|
||||
import mindspore.nn.probability.distribution as msd
|
||||
from mindspore import dtype
|
||||
from mindspore import Tensor
|
||||
import mindspore.context as context
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
|
||||
|
||||
|
||||
class Net(nn.Cell):
|
||||
"""
|
||||
Test class: new api of normal distribution.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super(Net, self).__init__()
|
||||
self.normal = msd.Normal(0., 1., dtype=dtype.float32)
|
||||
|
||||
def construct(self, x_, y_):
|
||||
kl = self.normal.kl_loss('kl_loss', 'Normal', x_, y_)
|
||||
prob = self.normal.prob('prob', kl)
|
||||
return prob
|
||||
|
||||
|
||||
def test_new_api():
|
||||
"""
|
||||
Test new api of normal distribution.
|
||||
"""
|
||||
prob = Net()
|
||||
mean_a = np.array([0.0]).astype(np.float32)
|
||||
sd_a = np.array([1.0]).astype(np.float32)
|
||||
mean_b = np.array([1.0]).astype(np.float32)
|
||||
sd_b = np.array([1.0]).astype(np.float32)
|
||||
ans = prob(Tensor(mean_b), Tensor(sd_b))
|
||||
|
||||
diff_log_scale = np.log(sd_a) - np.log(sd_b)
|
||||
squared_diff = np.square(mean_a / sd_b - mean_b / sd_b)
|
||||
expect_kl_loss = 0.5 * squared_diff + 0.5 * \
|
||||
np.expm1(2 * diff_log_scale) - diff_log_scale
|
||||
|
||||
norm_benchmark = stats.norm(np.array([0.0]), np.array([1.0]))
|
||||
expect_prob = norm_benchmark.pdf(expect_kl_loss).astype(np.float32)
|
||||
|
||||
tol = 1e-6
|
||||
assert (np.abs(ans.asnumpy() - expect_prob) < tol).all()
|
|
@ -19,7 +19,6 @@ import mindspore.context as context
|
|||
import mindspore.nn as nn
|
||||
import mindspore.nn.probability.distribution as msd
|
||||
from mindspore import Tensor
|
||||
from mindspore.common.api import ms_function
|
||||
from mindspore import dtype
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
|
||||
|
@ -32,9 +31,8 @@ class Prob(nn.Cell):
|
|||
super(Prob, self).__init__()
|
||||
self.u = msd.Uniform([0.0], [[1.0], [2.0]], dtype=dtype.float32)
|
||||
|
||||
@ms_function
|
||||
def construct(self, x_):
|
||||
return self.u('prob', x_)
|
||||
return self.u.prob(x_)
|
||||
|
||||
def test_pdf():
|
||||
"""
|
||||
|
@ -56,9 +54,8 @@ class LogProb(nn.Cell):
|
|||
super(LogProb, self).__init__()
|
||||
self.u = msd.Uniform([0.0], [[1.0], [2.0]], dtype=dtype.float32)
|
||||
|
||||
@ms_function
|
||||
def construct(self, x_):
|
||||
return self.u('log_prob', x_)
|
||||
return self.u.log_prob(x_)
|
||||
|
||||
def test_log_likelihood():
|
||||
"""
|
||||
|
@ -80,9 +77,8 @@ class KL(nn.Cell):
|
|||
super(KL, self).__init__()
|
||||
self.u = msd.Uniform([0.0], [1.5], dtype=dtype.float32)
|
||||
|
||||
@ms_function
|
||||
def construct(self, x_, y_):
|
||||
return self.u('kl_loss', 'Uniform', x_, y_)
|
||||
return self.u.kl_loss('Uniform', x_, y_)
|
||||
|
||||
def test_kl_loss():
|
||||
"""
|
||||
|
@ -106,9 +102,8 @@ class Basics(nn.Cell):
|
|||
super(Basics, self).__init__()
|
||||
self.u = msd.Uniform([0.0], [3.0], dtype=dtype.float32)
|
||||
|
||||
@ms_function
|
||||
def construct(self):
|
||||
return self.u('mean'), self.u('sd')
|
||||
return self.u.mean(), self.u.sd()
|
||||
|
||||
def test_basics():
|
||||
"""
|
||||
|
@ -131,9 +126,8 @@ class Sampling(nn.Cell):
|
|||
self.u = msd.Uniform([0.0], [[1.0], [2.0]], seed=seed, dtype=dtype.float32)
|
||||
self.shape = shape
|
||||
|
||||
@ms_function
|
||||
def construct(self, low=None, high=None):
|
||||
return self.u('sample', self.shape, low, high)
|
||||
return self.u.sample(self.shape, low, high)
|
||||
|
||||
def test_sample():
|
||||
"""
|
||||
|
@ -155,9 +149,8 @@ class CDF(nn.Cell):
|
|||
super(CDF, self).__init__()
|
||||
self.u = msd.Uniform([0.0], [1.0], dtype=dtype.float32)
|
||||
|
||||
@ms_function
|
||||
def construct(self, x_):
|
||||
return self.u('cdf', x_)
|
||||
return self.u.cdf(x_)
|
||||
|
||||
def test_cdf():
|
||||
"""
|
||||
|
@ -179,9 +172,8 @@ class LogCDF(nn.Cell):
|
|||
super(LogCDF, self).__init__()
|
||||
self.u = msd.Uniform([0.0], [1.0], dtype=dtype.float32)
|
||||
|
||||
@ms_function
|
||||
def construct(self, x_):
|
||||
return self.u('log_cdf', x_)
|
||||
return self.u.log_cdf(x_)
|
||||
|
||||
class SF(nn.Cell):
|
||||
"""
|
||||
|
@ -191,9 +183,8 @@ class SF(nn.Cell):
|
|||
super(SF, self).__init__()
|
||||
self.u = msd.Uniform([0.0], [1.0], dtype=dtype.float32)
|
||||
|
||||
@ms_function
|
||||
def construct(self, x_):
|
||||
return self.u('survival_function', x_)
|
||||
return self.u.survival_function(x_)
|
||||
|
||||
class LogSF(nn.Cell):
|
||||
"""
|
||||
|
@ -203,9 +194,8 @@ class LogSF(nn.Cell):
|
|||
super(LogSF, self).__init__()
|
||||
self.u = msd.Uniform([0.0], [1.0], dtype=dtype.float32)
|
||||
|
||||
@ms_function
|
||||
def construct(self, x_):
|
||||
return self.u('log_survival', x_)
|
||||
return self.u.log_survival(x_)
|
||||
|
||||
class EntropyH(nn.Cell):
|
||||
"""
|
||||
|
@ -215,9 +205,8 @@ class EntropyH(nn.Cell):
|
|||
super(EntropyH, self).__init__()
|
||||
self.u = msd.Uniform([0.0], [1.0, 2.0], dtype=dtype.float32)
|
||||
|
||||
@ms_function
|
||||
def construct(self):
|
||||
return self.u('entropy')
|
||||
return self.u.entropy()
|
||||
|
||||
def test_entropy():
|
||||
"""
|
||||
|
@ -238,12 +227,11 @@ class CrossEntropy(nn.Cell):
|
|||
super(CrossEntropy, self).__init__()
|
||||
self.u = msd.Uniform([0.0], [1.5], dtype=dtype.float32)
|
||||
|
||||
@ms_function
|
||||
def construct(self, x_, y_):
|
||||
entropy = self.u('entropy')
|
||||
kl_loss = self.u('kl_loss', 'Uniform', x_, y_)
|
||||
entropy = self.u.entropy()
|
||||
kl_loss = self.u.kl_loss('Uniform', x_, y_)
|
||||
h_sum_kl = entropy + kl_loss
|
||||
cross_entropy = self.u('cross_entropy', 'Uniform', x_, y_)
|
||||
cross_entropy = self.u.cross_entropy('Uniform', x_, y_)
|
||||
return h_sum_kl - cross_entropy
|
||||
|
||||
def test_log_cdf():
|
||||
|
|
|
@ -49,12 +49,12 @@ class BernoulliProb(nn.Cell):
|
|||
self.b = msd.Bernoulli(0.5, dtype=dtype.int32)
|
||||
|
||||
def construct(self, value):
|
||||
prob = self.b('prob', value)
|
||||
log_prob = self.b('log_prob', value)
|
||||
cdf = self.b('cdf', value)
|
||||
log_cdf = self.b('log_cdf', value)
|
||||
sf = self.b('survival_function', value)
|
||||
log_sf = self.b('log_survival', value)
|
||||
prob = self.b.prob(value)
|
||||
log_prob = self.b.log_prob(value)
|
||||
cdf = self.b.cdf(value)
|
||||
log_cdf = self.b.log_cdf(value)
|
||||
sf = self.b.survival_function(value)
|
||||
log_sf = self.b.log_survival(value)
|
||||
return prob + log_prob + cdf + log_cdf + sf + log_sf
|
||||
|
||||
def test_bernoulli_prob():
|
||||
|
@ -75,12 +75,12 @@ class BernoulliProb1(nn.Cell):
|
|||
self.b = msd.Bernoulli(dtype=dtype.int32)
|
||||
|
||||
def construct(self, value, probs):
|
||||
prob = self.b('prob', value, probs)
|
||||
log_prob = self.b('log_prob', value, probs)
|
||||
cdf = self.b('cdf', value, probs)
|
||||
log_cdf = self.b('log_cdf', value, probs)
|
||||
sf = self.b('survival_function', value, probs)
|
||||
log_sf = self.b('log_survival', value, probs)
|
||||
prob = self.b.prob(value, probs)
|
||||
log_prob = self.b.log_prob(value, probs)
|
||||
cdf = self.b.cdf(value, probs)
|
||||
log_cdf = self.b.log_cdf(value, probs)
|
||||
sf = self.b.survival_function(value, probs)
|
||||
log_sf = self.b.log_survival(value, probs)
|
||||
return prob + log_prob + cdf + log_cdf + sf + log_sf
|
||||
|
||||
def test_bernoulli_prob1():
|
||||
|
@ -103,8 +103,8 @@ class BernoulliKl(nn.Cell):
|
|||
self.b2 = msd.Bernoulli(dtype=dtype.int32)
|
||||
|
||||
def construct(self, probs_b, probs_a):
|
||||
kl1 = self.b1('kl_loss', 'Bernoulli', probs_b)
|
||||
kl2 = self.b2('kl_loss', 'Bernoulli', probs_b, probs_a)
|
||||
kl1 = self.b1.kl_loss('Bernoulli', probs_b)
|
||||
kl2 = self.b2.kl_loss('Bernoulli', probs_b, probs_a)
|
||||
return kl1 + kl2
|
||||
|
||||
def test_kl():
|
||||
|
@ -127,8 +127,8 @@ class BernoulliCrossEntropy(nn.Cell):
|
|||
self.b2 = msd.Bernoulli(dtype=dtype.int32)
|
||||
|
||||
def construct(self, probs_b, probs_a):
|
||||
h1 = self.b1('cross_entropy', 'Bernoulli', probs_b)
|
||||
h2 = self.b2('cross_entropy', 'Bernoulli', probs_b, probs_a)
|
||||
h1 = self.b1.cross_entropy('Bernoulli', probs_b)
|
||||
h2 = self.b2.cross_entropy('Bernoulli', probs_b, probs_a)
|
||||
return h1 + h2
|
||||
|
||||
def test_cross_entropy():
|
||||
|
@ -150,11 +150,11 @@ class BernoulliBasics(nn.Cell):
|
|||
self.b = msd.Bernoulli([0.3, 0.5], dtype=dtype.int32)
|
||||
|
||||
def construct(self):
|
||||
mean = self.b('mean')
|
||||
sd = self.b('sd')
|
||||
var = self.b('var')
|
||||
mode = self.b('mode')
|
||||
entropy = self.b('entropy')
|
||||
mean = self.b.mean()
|
||||
sd = self.b.sd()
|
||||
var = self.b.var()
|
||||
mode = self.b.mode()
|
||||
entropy = self.b.entropy()
|
||||
return mean + sd + var + mode + entropy
|
||||
|
||||
def test_bascis():
|
||||
|
@ -164,3 +164,28 @@ def test_bascis():
|
|||
net = BernoulliBasics()
|
||||
ans = net()
|
||||
assert isinstance(ans, Tensor)
|
||||
|
||||
class BernoulliConstruct(nn.Cell):
|
||||
"""
|
||||
Bernoulli distribution: going through construct.
|
||||
"""
|
||||
def __init__(self):
|
||||
super(BernoulliConstruct, self).__init__()
|
||||
self.b = msd.Bernoulli(0.5, dtype=dtype.int32)
|
||||
self.b1 = msd.Bernoulli(dtype=dtype.int32)
|
||||
|
||||
def construct(self, value, probs):
|
||||
prob = self.b('prob', value)
|
||||
prob1 = self.b('prob', value, probs)
|
||||
prob2 = self.b1('prob', value, probs)
|
||||
return prob + prob1 + prob2
|
||||
|
||||
def test_bernoulli_construct():
|
||||
"""
|
||||
Test probability function going through construct.
|
||||
"""
|
||||
net = BernoulliConstruct()
|
||||
value = Tensor([0, 0, 0, 0, 0], dtype=dtype.float32)
|
||||
probs = Tensor([0.5], dtype=dtype.float32)
|
||||
ans = net(value, probs)
|
||||
assert isinstance(ans, Tensor)
|
||||
|
|
|
@ -50,12 +50,12 @@ class ExponentialProb(nn.Cell):
|
|||
self.e = msd.Exponential(0.5, dtype=dtype.float32)
|
||||
|
||||
def construct(self, value):
|
||||
prob = self.e('prob', value)
|
||||
log_prob = self.e('log_prob', value)
|
||||
cdf = self.e('cdf', value)
|
||||
log_cdf = self.e('log_cdf', value)
|
||||
sf = self.e('survival_function', value)
|
||||
log_sf = self.e('log_survival', value)
|
||||
prob = self.e.prob(value)
|
||||
log_prob = self.e.log_prob(value)
|
||||
cdf = self.e.cdf(value)
|
||||
log_cdf = self.e.log_cdf(value)
|
||||
sf = self.e.survival_function(value)
|
||||
log_sf = self.e.log_survival(value)
|
||||
return prob + log_prob + cdf + log_cdf + sf + log_sf
|
||||
|
||||
def test_exponential_prob():
|
||||
|
@ -76,12 +76,12 @@ class ExponentialProb1(nn.Cell):
|
|||
self.e = msd.Exponential(dtype=dtype.float32)
|
||||
|
||||
def construct(self, value, rate):
|
||||
prob = self.e('prob', value, rate)
|
||||
log_prob = self.e('log_prob', value, rate)
|
||||
cdf = self.e('cdf', value, rate)
|
||||
log_cdf = self.e('log_cdf', value, rate)
|
||||
sf = self.e('survival_function', value, rate)
|
||||
log_sf = self.e('log_survival', value, rate)
|
||||
prob = self.e.prob(value, rate)
|
||||
log_prob = self.e.log_prob(value, rate)
|
||||
cdf = self.e.cdf(value, rate)
|
||||
log_cdf = self.e.log_cdf(value, rate)
|
||||
sf = self.e.survival_function(value, rate)
|
||||
log_sf = self.e.log_survival(value, rate)
|
||||
return prob + log_prob + cdf + log_cdf + sf + log_sf
|
||||
|
||||
def test_exponential_prob1():
|
||||
|
@ -104,8 +104,8 @@ class ExponentialKl(nn.Cell):
|
|||
self.e2 = msd.Exponential(dtype=dtype.float32)
|
||||
|
||||
def construct(self, rate_b, rate_a):
|
||||
kl1 = self.e1('kl_loss', 'Exponential', rate_b)
|
||||
kl2 = self.e2('kl_loss', 'Exponential', rate_b, rate_a)
|
||||
kl1 = self.e1.kl_loss('Exponential', rate_b)
|
||||
kl2 = self.e2.kl_loss('Exponential', rate_b, rate_a)
|
||||
return kl1 + kl2
|
||||
|
||||
def test_kl():
|
||||
|
@ -128,8 +128,8 @@ class ExponentialCrossEntropy(nn.Cell):
|
|||
self.e2 = msd.Exponential(dtype=dtype.float32)
|
||||
|
||||
def construct(self, rate_b, rate_a):
|
||||
h1 = self.e1('cross_entropy', 'Exponential', rate_b)
|
||||
h2 = self.e2('cross_entropy', 'Exponential', rate_b, rate_a)
|
||||
h1 = self.e1.cross_entropy('Exponential', rate_b)
|
||||
h2 = self.e2.cross_entropy('Exponential', rate_b, rate_a)
|
||||
return h1 + h2
|
||||
|
||||
def test_cross_entropy():
|
||||
|
@ -151,11 +151,11 @@ class ExponentialBasics(nn.Cell):
|
|||
self.e = msd.Exponential([0.3, 0.5], dtype=dtype.float32)
|
||||
|
||||
def construct(self):
|
||||
mean = self.e('mean')
|
||||
sd = self.e('sd')
|
||||
var = self.e('var')
|
||||
mode = self.e('mode')
|
||||
entropy = self.e('entropy')
|
||||
mean = self.e.mean()
|
||||
sd = self.e.sd()
|
||||
var = self.e.var()
|
||||
mode = self.e.mode()
|
||||
entropy = self.e.entropy()
|
||||
return mean + sd + var + mode + entropy
|
||||
|
||||
def test_bascis():
|
||||
|
@ -165,3 +165,29 @@ def test_bascis():
|
|||
net = ExponentialBasics()
|
||||
ans = net()
|
||||
assert isinstance(ans, Tensor)
|
||||
|
||||
|
||||
class ExpConstruct(nn.Cell):
|
||||
"""
|
||||
Exponential distribution: going through construct.
|
||||
"""
|
||||
def __init__(self):
|
||||
super(ExpConstruct, self).__init__()
|
||||
self.e = msd.Exponential(0.5, dtype=dtype.float32)
|
||||
self.e1 = msd.Exponential(dtype=dtype.float32)
|
||||
|
||||
def construct(self, value, rate):
|
||||
prob = self.e('prob', value)
|
||||
prob1 = self.e('prob', value, rate)
|
||||
prob2 = self.e1('prob', value, rate)
|
||||
return prob + prob1 + prob2
|
||||
|
||||
def test_exp_construct():
|
||||
"""
|
||||
Test probability function going through construct.
|
||||
"""
|
||||
net = ExpConstruct()
|
||||
value = Tensor([0, 0, 0, 0, 0], dtype=dtype.float32)
|
||||
probs = Tensor([0.5], dtype=dtype.float32)
|
||||
ans = net(value, probs)
|
||||
assert isinstance(ans, Tensor)
|
||||
|
|
|
@ -50,12 +50,12 @@ class GeometricProb(nn.Cell):
|
|||
self.g = msd.Geometric(0.5, dtype=dtype.int32)
|
||||
|
||||
def construct(self, value):
|
||||
prob = self.g('prob', value)
|
||||
log_prob = self.g('log_prob', value)
|
||||
cdf = self.g('cdf', value)
|
||||
log_cdf = self.g('log_cdf', value)
|
||||
sf = self.g('survival_function', value)
|
||||
log_sf = self.g('log_survival', value)
|
||||
prob = self.g.prob(value)
|
||||
log_prob = self.g.log_prob(value)
|
||||
cdf = self.g.cdf(value)
|
||||
log_cdf = self.g.log_cdf(value)
|
||||
sf = self.g.survival_function(value)
|
||||
log_sf = self.g.log_survival(value)
|
||||
return prob + log_prob + cdf + log_cdf + sf + log_sf
|
||||
|
||||
def test_geometric_prob():
|
||||
|
@ -76,12 +76,12 @@ class GeometricProb1(nn.Cell):
|
|||
self.g = msd.Geometric(dtype=dtype.int32)
|
||||
|
||||
def construct(self, value, probs):
|
||||
prob = self.g('prob', value, probs)
|
||||
log_prob = self.g('log_prob', value, probs)
|
||||
cdf = self.g('cdf', value, probs)
|
||||
log_cdf = self.g('log_cdf', value, probs)
|
||||
sf = self.g('survival_function', value, probs)
|
||||
log_sf = self.g('log_survival', value, probs)
|
||||
prob = self.g.prob(value, probs)
|
||||
log_prob = self.g.log_prob(value, probs)
|
||||
cdf = self.g.cdf(value, probs)
|
||||
log_cdf = self.g.log_cdf(value, probs)
|
||||
sf = self.g.survival_function(value, probs)
|
||||
log_sf = self.g.log_survival(value, probs)
|
||||
return prob + log_prob + cdf + log_cdf + sf + log_sf
|
||||
|
||||
def test_geometric_prob1():
|
||||
|
@ -105,8 +105,8 @@ class GeometricKl(nn.Cell):
|
|||
self.g2 = msd.Geometric(dtype=dtype.int32)
|
||||
|
||||
def construct(self, probs_b, probs_a):
|
||||
kl1 = self.g1('kl_loss', 'Geometric', probs_b)
|
||||
kl2 = self.g2('kl_loss', 'Geometric', probs_b, probs_a)
|
||||
kl1 = self.g1.kl_loss('Geometric', probs_b)
|
||||
kl2 = self.g2.kl_loss('Geometric', probs_b, probs_a)
|
||||
return kl1 + kl2
|
||||
|
||||
def test_kl():
|
||||
|
@ -129,8 +129,8 @@ class GeometricCrossEntropy(nn.Cell):
|
|||
self.g2 = msd.Geometric(dtype=dtype.int32)
|
||||
|
||||
def construct(self, probs_b, probs_a):
|
||||
h1 = self.g1('cross_entropy', 'Geometric', probs_b)
|
||||
h2 = self.g2('cross_entropy', 'Geometric', probs_b, probs_a)
|
||||
h1 = self.g1.cross_entropy('Geometric', probs_b)
|
||||
h2 = self.g2.cross_entropy('Geometric', probs_b, probs_a)
|
||||
return h1 + h2
|
||||
|
||||
def test_cross_entropy():
|
||||
|
@ -152,11 +152,11 @@ class GeometricBasics(nn.Cell):
|
|||
self.g = msd.Geometric([0.3, 0.5], dtype=dtype.int32)
|
||||
|
||||
def construct(self):
|
||||
mean = self.g('mean')
|
||||
sd = self.g('sd')
|
||||
var = self.g('var')
|
||||
mode = self.g('mode')
|
||||
entropy = self.g('entropy')
|
||||
mean = self.g.mean()
|
||||
sd = self.g.sd()
|
||||
var = self.g.var()
|
||||
mode = self.g.mode()
|
||||
entropy = self.g.entropy()
|
||||
return mean + sd + var + mode + entropy
|
||||
|
||||
def test_bascis():
|
||||
|
@ -166,3 +166,29 @@ def test_bascis():
|
|||
net = GeometricBasics()
|
||||
ans = net()
|
||||
assert isinstance(ans, Tensor)
|
||||
|
||||
|
||||
class GeoConstruct(nn.Cell):
|
||||
"""
|
||||
Bernoulli distribution: going through construct.
|
||||
"""
|
||||
def __init__(self):
|
||||
super(GeoConstruct, self).__init__()
|
||||
self.g = msd.Geometric(0.5, dtype=dtype.int32)
|
||||
self.g1 = msd.Geometric(dtype=dtype.int32)
|
||||
|
||||
def construct(self, value, probs):
|
||||
prob = self.g('prob', value)
|
||||
prob1 = self.g('prob', value, probs)
|
||||
prob2 = self.g1('prob', value, probs)
|
||||
return prob + prob1 + prob2
|
||||
|
||||
def test_geo_construct():
|
||||
"""
|
||||
Test probability function going through construct.
|
||||
"""
|
||||
net = GeoConstruct()
|
||||
value = Tensor([0, 0, 0, 0, 0], dtype=dtype.float32)
|
||||
probs = Tensor([0.5], dtype=dtype.float32)
|
||||
ans = net(value, probs)
|
||||
assert isinstance(ans, Tensor)
|
||||
|
|
|
@ -50,12 +50,12 @@ class NormalProb(nn.Cell):
|
|||
self.normal = msd.Normal(3.0, 4.0, dtype=dtype.float32)
|
||||
|
||||
def construct(self, value):
|
||||
prob = self.normal('prob', value)
|
||||
log_prob = self.normal('log_prob', value)
|
||||
cdf = self.normal('cdf', value)
|
||||
log_cdf = self.normal('log_cdf', value)
|
||||
sf = self.normal('survival_function', value)
|
||||
log_sf = self.normal('log_survival', value)
|
||||
prob = self.normal.prob(value)
|
||||
log_prob = self.normal.log_prob(value)
|
||||
cdf = self.normal.cdf(value)
|
||||
log_cdf = self.normal.log_cdf(value)
|
||||
sf = self.normal.survival_function(value)
|
||||
log_sf = self.normal.log_survival(value)
|
||||
return prob + log_prob + cdf + log_cdf + sf + log_sf
|
||||
|
||||
def test_normal_prob():
|
||||
|
@ -77,12 +77,12 @@ class NormalProb1(nn.Cell):
|
|||
self.normal = msd.Normal()
|
||||
|
||||
def construct(self, value, mean, sd):
|
||||
prob = self.normal('prob', value, mean, sd)
|
||||
log_prob = self.normal('log_prob', value, mean, sd)
|
||||
cdf = self.normal('cdf', value, mean, sd)
|
||||
log_cdf = self.normal('log_cdf', value, mean, sd)
|
||||
sf = self.normal('survival_function', value, mean, sd)
|
||||
log_sf = self.normal('log_survival', value, mean, sd)
|
||||
prob = self.normal.prob(value, mean, sd)
|
||||
log_prob = self.normal.log_prob(value, mean, sd)
|
||||
cdf = self.normal.cdf(value, mean, sd)
|
||||
log_cdf = self.normal.log_cdf(value, mean, sd)
|
||||
sf = self.normal.survival_function(value, mean, sd)
|
||||
log_sf = self.normal.log_survival(value, mean, sd)
|
||||
return prob + log_prob + cdf + log_cdf + sf + log_sf
|
||||
|
||||
def test_normal_prob1():
|
||||
|
@ -106,8 +106,8 @@ class NormalKl(nn.Cell):
|
|||
self.n2 = msd.Normal(dtype=dtype.float32)
|
||||
|
||||
def construct(self, mean_b, sd_b, mean_a, sd_a):
|
||||
kl1 = self.n1('kl_loss', 'Normal', mean_b, sd_b)
|
||||
kl2 = self.n2('kl_loss', 'Normal', mean_b, sd_b, mean_a, sd_a)
|
||||
kl1 = self.n1.kl_loss('Normal', mean_b, sd_b)
|
||||
kl2 = self.n2.kl_loss('Normal', mean_b, sd_b, mean_a, sd_a)
|
||||
return kl1 + kl2
|
||||
|
||||
def test_kl():
|
||||
|
@ -132,8 +132,8 @@ class NormalCrossEntropy(nn.Cell):
|
|||
self.n2 = msd.Normal(dtype=dtype.float32)
|
||||
|
||||
def construct(self, mean_b, sd_b, mean_a, sd_a):
|
||||
h1 = self.n1('cross_entropy', 'Normal', mean_b, sd_b)
|
||||
h2 = self.n2('cross_entropy', 'Normal', mean_b, sd_b, mean_a, sd_a)
|
||||
h1 = self.n1.cross_entropy('Normal', mean_b, sd_b)
|
||||
h2 = self.n2.cross_entropy('Normal', mean_b, sd_b, mean_a, sd_a)
|
||||
return h1 + h2
|
||||
|
||||
def test_cross_entropy():
|
||||
|
@ -157,10 +157,10 @@ class NormalBasics(nn.Cell):
|
|||
self.n = msd.Normal(3.0, 4.0, dtype=dtype.float32)
|
||||
|
||||
def construct(self):
|
||||
mean = self.n('mean')
|
||||
sd = self.n('sd')
|
||||
mode = self.n('mode')
|
||||
entropy = self.n('entropy')
|
||||
mean = self.n.mean()
|
||||
sd = self.n.sd()
|
||||
mode = self.n.mode()
|
||||
entropy = self.n.entropy()
|
||||
return mean + sd + mode + entropy
|
||||
|
||||
def test_bascis():
|
||||
|
@ -170,3 +170,30 @@ def test_bascis():
|
|||
net = NormalBasics()
|
||||
ans = net()
|
||||
assert isinstance(ans, Tensor)
|
||||
|
||||
|
||||
class NormalConstruct(nn.Cell):
|
||||
"""
|
||||
Normal distribution: going through construct.
|
||||
"""
|
||||
def __init__(self):
|
||||
super(NormalConstruct, self).__init__()
|
||||
self.normal = msd.Normal(3.0, 4.0)
|
||||
self.normal1 = msd.Normal()
|
||||
|
||||
def construct(self, value, mean, sd):
|
||||
prob = self.normal('prob', value)
|
||||
prob1 = self.normal('prob', value, mean, sd)
|
||||
prob2 = self.normal1('prob', value, mean, sd)
|
||||
return prob + prob1 + prob2
|
||||
|
||||
def test_normal_construct():
|
||||
"""
|
||||
Test probability function going through construct.
|
||||
"""
|
||||
net = NormalConstruct()
|
||||
value = Tensor([0.5, 1.0], dtype=dtype.float32)
|
||||
mean = Tensor([0.0], dtype=dtype.float32)
|
||||
sd = Tensor([1.0], dtype=dtype.float32)
|
||||
ans = net(value, mean, sd)
|
||||
assert isinstance(ans, Tensor)
|
||||
|
|
|
@ -60,12 +60,12 @@ class UniformProb(nn.Cell):
|
|||
self.u = msd.Uniform(3.0, 4.0, dtype=dtype.float32)
|
||||
|
||||
def construct(self, value):
|
||||
prob = self.u('prob', value)
|
||||
log_prob = self.u('log_prob', value)
|
||||
cdf = self.u('cdf', value)
|
||||
log_cdf = self.u('log_cdf', value)
|
||||
sf = self.u('survival_function', value)
|
||||
log_sf = self.u('log_survival', value)
|
||||
prob = self.u.prob(value)
|
||||
log_prob = self.u.log_prob(value)
|
||||
cdf = self.u.cdf(value)
|
||||
log_cdf = self.u.log_cdf(value)
|
||||
sf = self.u.survival_function(value)
|
||||
log_sf = self.u.log_survival(value)
|
||||
return prob + log_prob + cdf + log_cdf + sf + log_sf
|
||||
|
||||
def test_uniform_prob():
|
||||
|
@ -86,12 +86,12 @@ class UniformProb1(nn.Cell):
|
|||
self.u = msd.Uniform(dtype=dtype.float32)
|
||||
|
||||
def construct(self, value, low, high):
|
||||
prob = self.u('prob', value, low, high)
|
||||
log_prob = self.u('log_prob', value, low, high)
|
||||
cdf = self.u('cdf', value, low, high)
|
||||
log_cdf = self.u('log_cdf', value, low, high)
|
||||
sf = self.u('survival_function', value, low, high)
|
||||
log_sf = self.u('log_survival', value, low, high)
|
||||
prob = self.u.prob(value, low, high)
|
||||
log_prob = self.u.log_prob(value, low, high)
|
||||
cdf = self.u.cdf(value, low, high)
|
||||
log_cdf = self.u.log_cdf(value, low, high)
|
||||
sf = self.u.survival_function(value, low, high)
|
||||
log_sf = self.u.log_survival(value, low, high)
|
||||
return prob + log_prob + cdf + log_cdf + sf + log_sf
|
||||
|
||||
def test_uniform_prob1():
|
||||
|
@ -115,8 +115,8 @@ class UniformKl(nn.Cell):
|
|||
self.u2 = msd.Uniform(dtype=dtype.float32)
|
||||
|
||||
def construct(self, low_b, high_b, low_a, high_a):
|
||||
kl1 = self.u1('kl_loss', 'Uniform', low_b, high_b)
|
||||
kl2 = self.u2('kl_loss', 'Uniform', low_b, high_b, low_a, high_a)
|
||||
kl1 = self.u1.kl_loss('Uniform', low_b, high_b)
|
||||
kl2 = self.u2.kl_loss('Uniform', low_b, high_b, low_a, high_a)
|
||||
return kl1 + kl2
|
||||
|
||||
def test_kl():
|
||||
|
@ -141,8 +141,8 @@ class UniformCrossEntropy(nn.Cell):
|
|||
self.u2 = msd.Uniform(dtype=dtype.float32)
|
||||
|
||||
def construct(self, low_b, high_b, low_a, high_a):
|
||||
h1 = self.u1('cross_entropy', 'Uniform', low_b, high_b)
|
||||
h2 = self.u2('cross_entropy', 'Uniform', low_b, high_b, low_a, high_a)
|
||||
h1 = self.u1.cross_entropy('Uniform', low_b, high_b)
|
||||
h2 = self.u2.cross_entropy('Uniform', low_b, high_b, low_a, high_a)
|
||||
return h1 + h2
|
||||
|
||||
def test_cross_entropy():
|
||||
|
@ -166,10 +166,10 @@ class UniformBasics(nn.Cell):
|
|||
self.u = msd.Uniform(3.0, 4.0, dtype=dtype.float32)
|
||||
|
||||
def construct(self):
|
||||
mean = self.u('mean')
|
||||
sd = self.u('sd')
|
||||
var = self.u('var')
|
||||
entropy = self.u('entropy')
|
||||
mean = self.u.mean()
|
||||
sd = self.u.sd()
|
||||
var = self.u.var()
|
||||
entropy = self.u.entropy()
|
||||
return mean + sd + var + entropy
|
||||
|
||||
def test_bascis():
|
||||
|
@ -179,3 +179,30 @@ def test_bascis():
|
|||
net = UniformBasics()
|
||||
ans = net()
|
||||
assert isinstance(ans, Tensor)
|
||||
|
||||
|
||||
class UniConstruct(nn.Cell):
|
||||
"""
|
||||
Unifrom distribution: going through construct.
|
||||
"""
|
||||
def __init__(self):
|
||||
super(UniConstruct, self).__init__()
|
||||
self.u = msd.Uniform(-4.0, 4.0)
|
||||
self.u1 = msd.Uniform()
|
||||
|
||||
def construct(self, value, low, high):
|
||||
prob = self.u('prob', value)
|
||||
prob1 = self.u('prob', value, low, high)
|
||||
prob2 = self.u1('prob', value, low, high)
|
||||
return prob + prob1 + prob2
|
||||
|
||||
def test_uniform_construct():
|
||||
"""
|
||||
Test probability function going through construct.
|
||||
"""
|
||||
net = UniConstruct()
|
||||
value = Tensor([-5.0, 0.0, 1.0, 5.0], dtype=dtype.float32)
|
||||
low = Tensor([-1.0], dtype=dtype.float32)
|
||||
high = Tensor([1.0], dtype=dtype.float32)
|
||||
ans = net(value, low, high)
|
||||
assert isinstance(ans, Tensor)
|
||||
|
|
Loading…
Reference in New Issue