forked from mindspore-Ecosystem/mindspore
Added lognormal distribuition
This commit is contained in:
parent
3eff68f8aa
commit
23ff21edd8
|
@ -24,6 +24,7 @@ from .exponential import Exponential
|
||||||
from .uniform import Uniform
|
from .uniform import Uniform
|
||||||
from .geometric import Geometric
|
from .geometric import Geometric
|
||||||
from .categorical import Categorical
|
from .categorical import Categorical
|
||||||
|
from .log_normal import LogNormal
|
||||||
|
|
||||||
__all__ = ['Distribution',
|
__all__ = ['Distribution',
|
||||||
'TransformedDistribution',
|
'TransformedDistribution',
|
||||||
|
@ -32,4 +33,6 @@ __all__ = ['Distribution',
|
||||||
'Exponential',
|
'Exponential',
|
||||||
'Uniform',
|
'Uniform',
|
||||||
'Categorical',
|
'Categorical',
|
||||||
'Geometric',]
|
'Geometric',
|
||||||
|
'LogNormal',
|
||||||
|
]
|
||||||
|
|
|
@ -76,7 +76,10 @@ class Distribution(Cell):
|
||||||
self._parameters[k] = param[k]
|
self._parameters[k] = param[k]
|
||||||
|
|
||||||
# some attributes
|
# some attributes
|
||||||
self.parameter_type = set_param_type(self.parameters['param_dict'], dtype)
|
if 'distribution' in self.parameters.keys():
|
||||||
|
self.parameter_type = self.parameters['distribution'].parameter_type
|
||||||
|
else:
|
||||||
|
self.parameter_type = set_param_type(self.parameters['param_dict'], dtype)
|
||||||
self._broadcast_shape = self._calc_broadcast_shape()
|
self._broadcast_shape = self._calc_broadcast_shape()
|
||||||
self._is_scalar_batch = self._check_is_scalar_batch()
|
self._is_scalar_batch = self._check_is_scalar_batch()
|
||||||
|
|
||||||
|
@ -206,8 +209,8 @@ class Distribution(Cell):
|
||||||
"""
|
"""
|
||||||
Check if the parameters used during initialization are scalars.
|
Check if the parameters used during initialization are scalars.
|
||||||
"""
|
"""
|
||||||
if hasattr(self, 'distribution'):
|
if 'distribution' in self.parameters.keys():
|
||||||
return self._distribution.is_scalar_batch
|
return self.parameters['distribution'].is_scalar_batch
|
||||||
param_dict = self.parameters['param_dict']
|
param_dict = self.parameters['param_dict']
|
||||||
for value in param_dict.values():
|
for value in param_dict.values():
|
||||||
if value is None:
|
if value is None:
|
||||||
|
@ -220,8 +223,8 @@ class Distribution(Cell):
|
||||||
"""
|
"""
|
||||||
Calculate the broadcast shape of the parameters used during initialization.
|
Calculate the broadcast shape of the parameters used during initialization.
|
||||||
"""
|
"""
|
||||||
if hasattr(self, 'distribution'):
|
if 'distribution' in self.parameters.keys():
|
||||||
return self._distribution.broadcast_shape
|
return self.parameters['distribution'].broadcast_shape
|
||||||
param_dict = self.parameters['param_dict']
|
param_dict = self.parameters['param_dict']
|
||||||
broadcast_shape_tensor = None
|
broadcast_shape_tensor = None
|
||||||
for value in param_dict.values():
|
for value in param_dict.values():
|
||||||
|
|
|
@ -0,0 +1,235 @@
|
||||||
|
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
"""LogNormal Distribution"""
|
||||||
|
import numpy as np
|
||||||
|
from mindspore.ops import operations as P
|
||||||
|
from mindspore.common import dtype as mstype
|
||||||
|
import mindspore.nn.probability.bijector as msb
|
||||||
|
import mindspore.nn.probability.distribution as msd
|
||||||
|
from ._utils.utils import check_distribution_name
|
||||||
|
from ._utils.custom_ops import exp_generic, expm1_generic, log_generic
|
||||||
|
|
||||||
|
class LogNormal(msd.TransformedDistribution):
|
||||||
|
"""
|
||||||
|
LogNormal distribution.
|
||||||
|
A log-normal (or lognormal) distribution is a continuous probability distribution of a random variable whose
|
||||||
|
logarithm is normally distributed. It is constructed as the exponential transformation of a Normal distribution.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
loc (int, float, list, numpy.ndarray, Tensor, Parameter): The mean of the underlying Normal distribution.
|
||||||
|
scale (int, float, list, numpy.ndarray, Tensor, Parameter): The standard deviation of the underlying
|
||||||
|
Normal distribution.
|
||||||
|
seed (int): the seed used in sampling. The global seed is used if it is None. Default: None.
|
||||||
|
dtype (mindspore.dtype): type of the distribution. Default: mstype.float32.
|
||||||
|
name (str): the name of the distribution. Default: 'LogNormal'.
|
||||||
|
|
||||||
|
Note:
|
||||||
|
`scale` must be greater than zero.
|
||||||
|
`dist_spec_args` are `loc` and `scale`.
|
||||||
|
`dtype` must be a float type because LogNormal distributions are continuous.
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
>>> # To initialize a LogNormal distribution of `loc` 3.0 and `scale` 4.0.
|
||||||
|
>>> n = msd.LogNormal(3.0, 4.0, dtype=mstype.float32)
|
||||||
|
>>>
|
||||||
|
>>> # The following creates two independent LogNormal distributions.
|
||||||
|
>>> n = msd.LogNormal([3.0, 3.0], [4.0, 4.0], dtype=mstype.float32)
|
||||||
|
>>>
|
||||||
|
>>> # A LogNormal distribution can be initilize without arguments.
|
||||||
|
>>> # In this case, `loc` and `scale` must be passed in during function calls.
|
||||||
|
>>> n = msd.LogNormal(dtype=mstype.float32)
|
||||||
|
>>>
|
||||||
|
>>> # To use a LogNormal distribution in a network.
|
||||||
|
>>> class net(Cell):
|
||||||
|
>>> def __init__(self):
|
||||||
|
>>> super(net, self).__init__():
|
||||||
|
>>> self.n1 = msd.LogNormal(0.0, 1.0, dtype=mstype.float32)
|
||||||
|
>>> self.n2 = msd.LogNormal(dtype=mstype.float32)
|
||||||
|
>>>
|
||||||
|
>>> # The following calls are valid in construct.
|
||||||
|
>>> def construct(self, value, loc_b, scale_b, loc_a, scale_a):
|
||||||
|
>>>
|
||||||
|
>>> # Private interfaces of probability functions corresponding to public interfaces, including
|
||||||
|
>>> # `prob`, `log_prob`, `cdf`, `log_cdf`, `survival_function`, and `log_survival`, have the same
|
||||||
|
>>> # arguments as follows.
|
||||||
|
>>> # Args:
|
||||||
|
>>> # value (Tensor): the value to be evaluated.
|
||||||
|
>>> # loc (Tensor): the loc of distribution. Default: None. If `loc` is passed in as None,
|
||||||
|
>>> # the mean of the underlying Normal distribution will be used.
|
||||||
|
>>> # scale (Tensor): the scale of distribution. Default: None. If `scale` is passed in as None,
|
||||||
|
>>> # the standard deviation of the underlying Normal distribution will be used.
|
||||||
|
>>>
|
||||||
|
>>> # Examples of `prob`.
|
||||||
|
>>> # Similar calls can be made to other probability functions
|
||||||
|
>>> # by replacing 'prob' by the name of the function.
|
||||||
|
>>> ans = self.n1.prob(value)
|
||||||
|
>>> # Evaluate with respect to distribution b.
|
||||||
|
>>> ans = self.n1.prob(value, loc_b, scale_b)
|
||||||
|
>>> # `loc` and `scale` must be passed in during function calls since they were not passed in construct.
|
||||||
|
>>> ans = self.n2.prob(value, loc_a, scale_a)
|
||||||
|
>>>
|
||||||
|
>>>
|
||||||
|
>>> # Functions `mean`, `sd`, `var`, and `entropy` have the same arguments.
|
||||||
|
>>> # Args:
|
||||||
|
>>> # loc (Tensor): the loc of distribution. Default: None. If `loc` is passed in as None,
|
||||||
|
>>> # the mean of the underlying Normal distribution will be used.
|
||||||
|
>>> # scale (Tensor): the scale of distribution. Default: None. If `scale` is passed in as None,
|
||||||
|
>>> # the standard deviation of the underlying Normal distribution will be used.
|
||||||
|
>>>
|
||||||
|
>>> # Example of `mean`. `sd`, `var`, and `entropy` are similar.
|
||||||
|
>>> ans = self.n1.mean() # return 0.0
|
||||||
|
>>> ans = self.n1.mean(loc_b, scale_b) # return mean_b
|
||||||
|
>>> # `loc` and `scale` must be passed in during function calls since they were not passed in construct.
|
||||||
|
>>> ans = self.n2.mean(loc_a, scale_a)
|
||||||
|
>>>
|
||||||
|
>>>
|
||||||
|
>>> # Interfaces of 'kl_loss' and 'cross_entropy' are the same:
|
||||||
|
>>> # Args:
|
||||||
|
>>> # dist (str): the type of the distributions. Only "Normal" is supported.
|
||||||
|
>>> # loc_b (Tensor): the loc of distribution b.
|
||||||
|
>>> # scale_b (Tensor): the scale distribution b.
|
||||||
|
>>> # loc_a (Tensor): the loc of distribution a. Default: None. If `loc` is passed in as None,
|
||||||
|
>>> # the mean of the underlying Normal distribution will be used.
|
||||||
|
>>> # scale_a (Tensor): the scale distribution a. Default: None. If `scale` is passed in as None,
|
||||||
|
>>> # the standard deviation of the underlying Normal distribution will be used.
|
||||||
|
>>>
|
||||||
|
>>> # Examples of `kl_loss`. `cross_entropy` is similar.
|
||||||
|
>>> ans = self.n1.kl_loss('Normal', loc_b, scale_b)
|
||||||
|
>>> ans = self.n1.kl_loss('Normal', loc_b, scale_b, loc_a, scale_a)
|
||||||
|
>>> # Additional `loc` and `scale` must be passed in since they were not passed in construct.
|
||||||
|
>>> ans = self.n2.kl_loss('Normal', loc_b, scale_b, loc_a, scale_a)
|
||||||
|
>>>
|
||||||
|
>>> # Examples of `sample`.
|
||||||
|
>>> # Args:
|
||||||
|
>>> # shape (tuple): the shape of the sample. Default: ()
|
||||||
|
>>> # loc (Tensor): the loc of the distribution. Default: None. If `loc` is passed in as None,
|
||||||
|
>>> # the mean of the underlying Normal distribution will be used.
|
||||||
|
>>> # scale (Tensor): the scale of the distribution. Default: None. If `scale` is passed in as None,
|
||||||
|
>>> # the standard deviation of the underlying Normal distribution will be used.
|
||||||
|
>>> ans = self.n1.sample()
|
||||||
|
>>> ans = self.n1.sample((2,3))
|
||||||
|
>>> ans = self.n1.sample((2,3), loc_b, scale_b)
|
||||||
|
>>> ans = self.n2.sample((2,3), loc_a, scale_a)
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
loc=None,
|
||||||
|
scale=None,
|
||||||
|
seed=0,
|
||||||
|
dtype=mstype.float32,
|
||||||
|
name="LogNormal"):
|
||||||
|
"""
|
||||||
|
Constructor of LogNormal distribution.
|
||||||
|
"""
|
||||||
|
super(LogNormal, self).__init__(distribution=msd.Normal(loc, scale, dtype=dtype),
|
||||||
|
bijector=msb.Exp(),
|
||||||
|
dtype=dtype, seed=seed, name=name)
|
||||||
|
|
||||||
|
self.log_2pi = np.log(2 * np.pi)
|
||||||
|
|
||||||
|
#ops needed for the class
|
||||||
|
self.exp = exp_generic
|
||||||
|
self.expm1 = expm1_generic
|
||||||
|
self.log = log_generic
|
||||||
|
self.const = P.ScalarToArray()
|
||||||
|
self.erf = P.Erf()
|
||||||
|
self.fill = P.Fill()
|
||||||
|
self.shape = P.Shape()
|
||||||
|
self.sq = P.Square()
|
||||||
|
self.sqrt = P.Sqrt()
|
||||||
|
self.zeroslike = P.ZerosLike()
|
||||||
|
|
||||||
|
@property
|
||||||
|
def loc(self):
|
||||||
|
"""Distribution parameter for the pre-transformed mean."""
|
||||||
|
return self.distribution("mean")
|
||||||
|
|
||||||
|
@property
|
||||||
|
def scale(self):
|
||||||
|
"""Distribution parameter for the pre-transformed standard deviation."""
|
||||||
|
return self.distribution("sd")
|
||||||
|
|
||||||
|
def extend_repr(self):
|
||||||
|
if self.is_scalar_batch:
|
||||||
|
str_info = f'loc = {self._mean_value}, scale = {self._sd_value}'
|
||||||
|
else:
|
||||||
|
str_info = f'batch_shape = {self._broadcast_shape}'
|
||||||
|
return str_info
|
||||||
|
|
||||||
|
def _mean(self, loc=None, scale=None):
|
||||||
|
"""
|
||||||
|
The mean of the distribution.
|
||||||
|
"""
|
||||||
|
mean, sd = self._check_param_type(loc, scale)
|
||||||
|
var = self.distribution("var", mean=mean, sd=sd)
|
||||||
|
return self.exp(mean + 0.5 * var)
|
||||||
|
|
||||||
|
def _mode(self, loc=None, scale=None):
|
||||||
|
"""
|
||||||
|
The mode of the distribution.
|
||||||
|
"""
|
||||||
|
mean, sd = self._check_param_type(loc, scale)
|
||||||
|
var = self.distribution("var", mean=mean, sd=sd)
|
||||||
|
return self.exp(mean - var)
|
||||||
|
|
||||||
|
def _var(self, loc=None, scale=None):
|
||||||
|
"""
|
||||||
|
The varience of the distribution.
|
||||||
|
"""
|
||||||
|
mean, sd = self._check_param_type(loc, scale)
|
||||||
|
var = self.distribution("var", mean=mean, sd=sd)
|
||||||
|
return self.expm1(var) * self.exp(2. * mean + var)
|
||||||
|
|
||||||
|
def _entropy(self, loc=None, scale=None):
|
||||||
|
r"""
|
||||||
|
Evaluate entropy.
|
||||||
|
|
||||||
|
.. math::
|
||||||
|
H(X) = μ + 0.5 + \log(σ) + 0.5 * \log(2pi)
|
||||||
|
"""
|
||||||
|
mean, sd = self._check_param_type(loc, scale)
|
||||||
|
return mean + 0.5 + self.log(sd) + 0.5 * self.log_2pi
|
||||||
|
|
||||||
|
def _cross_entropy(self, dist, loc_b, scale_b, loc_a=None, scale_a=None):
|
||||||
|
r"""
|
||||||
|
Evaluate cross entropy between lognormal distributions.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
dist (str): The type of the distributions. Should be "LogNormal" in this case.
|
||||||
|
loc_b (Tensor): The loc of distribution b.
|
||||||
|
scale_b (Tensor): The scale of distribution b.
|
||||||
|
loc_a (Tensor): The loc of distribution a. Default: None.
|
||||||
|
scale_a (Tensor): The scale of distribution a. Default: None.
|
||||||
|
"""
|
||||||
|
check_distribution_name(dist, 'LogNormal')
|
||||||
|
return self._entropy(loc_a, scale_a) + self._kl_loss(dist, loc_b, scale_b, loc_a, scale_a)
|
||||||
|
|
||||||
|
def _kl_loss(self, dist, loc_b, scale_b, loc_a=None, scale_a=None):
|
||||||
|
r"""
|
||||||
|
Evaluate LogNormal-LogNormal kl divergence, i.e. KL(a||b).
|
||||||
|
|
||||||
|
Args:
|
||||||
|
dist (str): The type of the distributions. Should be "LogNormal" in this case.
|
||||||
|
loc_b (Tensor): The loc of distribution b.
|
||||||
|
scale_b (Tensor): The scale of distribution b.
|
||||||
|
loc_a (Tensor): The loc of distribution a. Default: None.
|
||||||
|
scale_a (Tensor): The scale of distribution a. Default: None.
|
||||||
|
|
||||||
|
.. math::
|
||||||
|
KL(a||b) = 0.5 * (\fract{MEAN(a)}{STD(b)} - \fract{MEAN(b)}{STD(b)}) ^ 2 +
|
||||||
|
0.5 * EXPM1(2 * (\log(STD(a)) - \log(STD(b))) - (\log(STD(a)) - \log(STD(b)))
|
||||||
|
"""
|
||||||
|
check_distribution_name(dist, 'LogNormal')
|
||||||
|
return self.distribution("kl_loss", 'Normal', loc_b, scale_b, loc_a, scale_a)
|
|
@ -30,6 +30,8 @@ class TransformedDistribution(Distribution):
|
||||||
Args:
|
Args:
|
||||||
bijector (Bijector): The transformation to perform.
|
bijector (Bijector): The transformation to perform.
|
||||||
distribution (Distribution): The original distribution.
|
distribution (Distribution): The original distribution.
|
||||||
|
dtype (mindspore.dtype): The type of the event samples.
|
||||||
|
seed (int): The seed is used in sampling. The global seed is used if it is None.
|
||||||
name (str): The name of the transformed distribution. Default: 'transformed_distribution'.
|
name (str): The name of the transformed distribution. Default: 'transformed_distribution'.
|
||||||
|
|
||||||
Note:
|
Note:
|
||||||
|
@ -98,38 +100,38 @@ class TransformedDistribution(Distribution):
|
||||||
def is_linear_transformation(self):
|
def is_linear_transformation(self):
|
||||||
return self._is_linear_transformation
|
return self._is_linear_transformation
|
||||||
|
|
||||||
def _cdf(self, *args, **kwargs):
|
def _cdf(self, value, *args, **kwargs):
|
||||||
r"""
|
r"""
|
||||||
.. math::
|
.. math::
|
||||||
Y = g(X)
|
Y = g(X)
|
||||||
P(Y <= a) = P(X <= g^{-1}(a))
|
P(Y <= a) = P(X <= g^{-1}(a))
|
||||||
"""
|
"""
|
||||||
inverse_value = self.bijector("inverse", *args, **kwargs)
|
inverse_value = self.bijector("inverse", value)
|
||||||
return self.distribution("cdf", inverse_value)
|
return self.distribution("cdf", inverse_value, *args, **kwargs)
|
||||||
|
|
||||||
def _log_cdf(self, *args, **kwargs):
|
def _log_cdf(self, value, *args, **kwargs):
|
||||||
return self.log(self._cdf(*args, **kwargs))
|
return self.log(self._cdf(value, *args, **kwargs))
|
||||||
|
|
||||||
def _survival_function(self, *args, **kwargs):
|
def _survival_function(self, value, *args, **kwargs):
|
||||||
return 1.0 - self._cdf(*args, **kwargs)
|
return 1.0 - self._cdf(value, *args, **kwargs)
|
||||||
|
|
||||||
def _log_survival(self, *args, **kwargs):
|
def _log_survival(self, value, *args, **kwargs):
|
||||||
return self.log(self._survival_function(*args, **kwargs))
|
return self.log(self._survival_function(value, *args, **kwargs))
|
||||||
|
|
||||||
def _log_prob(self, *args, **kwargs):
|
def _log_prob(self, value, *args, **kwargs):
|
||||||
r"""
|
r"""
|
||||||
.. math::
|
.. math::
|
||||||
Y = g(X)
|
Y = g(X)
|
||||||
Py(a) = Px(g^{-1}(a)) * (g^{-1})'(a)
|
Py(a) = Px(g^{-1}(a)) * (g^{-1})'(a)
|
||||||
\log(Py(a)) = \log(Px(g^{-1}(a))) + \log((g^{-1})'(a))
|
\log(Py(a)) = \log(Px(g^{-1}(a))) + \log((g^{-1})'(a))
|
||||||
"""
|
"""
|
||||||
inverse_value = self.bijector("inverse", *args, **kwargs)
|
inverse_value = self.bijector("inverse", value)
|
||||||
unadjust_prob = self.distribution("log_prob", inverse_value)
|
unadjust_prob = self.distribution("log_prob", inverse_value, *args, **kwargs)
|
||||||
log_jacobian = self.bijector("inverse_log_jacobian", *args, **kwargs)
|
log_jacobian = self.bijector("inverse_log_jacobian", value)
|
||||||
return unadjust_prob + log_jacobian
|
return unadjust_prob + log_jacobian
|
||||||
|
|
||||||
def _prob(self, *args, **kwargs):
|
def _prob(self, value, *args, **kwargs):
|
||||||
return self.exp(self._log_prob(*args, **kwargs))
|
return self.exp(self._log_prob(value, *args, **kwargs))
|
||||||
|
|
||||||
def _sample(self, *args, **kwargs):
|
def _sample(self, *args, **kwargs):
|
||||||
org_sample = self.distribution("sample", *args, **kwargs)
|
org_sample = self.distribution("sample", *args, **kwargs)
|
||||||
|
|
|
@ -0,0 +1,322 @@
|
||||||
|
# Copyright 2019 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
"""test cases for LogNormal distribution"""
|
||||||
|
import numpy as np
|
||||||
|
from scipy import stats
|
||||||
|
import mindspore.context as context
|
||||||
|
import mindspore.nn as nn
|
||||||
|
import mindspore.nn.probability.distribution as msd
|
||||||
|
from mindspore import Tensor
|
||||||
|
from mindspore import dtype
|
||||||
|
|
||||||
|
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
|
||||||
|
|
||||||
|
class Prob(nn.Cell):
|
||||||
|
"""
|
||||||
|
Test class: probability of LogNormal distribution.
|
||||||
|
"""
|
||||||
|
def __init__(self):
|
||||||
|
super(Prob, self).__init__()
|
||||||
|
self.ln = msd.LogNormal(np.array([0.3]), np.array([[0.2], [0.4]]), dtype=dtype.float32)
|
||||||
|
|
||||||
|
def construct(self, x_):
|
||||||
|
return self.ln.prob(x_)
|
||||||
|
|
||||||
|
def test_pdf():
|
||||||
|
"""
|
||||||
|
Test pdf.
|
||||||
|
"""
|
||||||
|
lognorm_benchmark = stats.lognorm(s=np.array([[0.2], [0.4]]), scale=np.exp(np.array([0.3])))
|
||||||
|
expect_pdf = lognorm_benchmark.pdf([1.0, 2.0]).astype(np.float32)
|
||||||
|
pdf = Prob()
|
||||||
|
output = pdf(Tensor([1.0, 2.0], dtype=dtype.float32))
|
||||||
|
tol = 1e-6
|
||||||
|
assert (np.abs(output.asnumpy() - expect_pdf) < tol).all()
|
||||||
|
|
||||||
|
class LogProb(nn.Cell):
|
||||||
|
"""
|
||||||
|
Test class: log probability of LogNormal distribution.
|
||||||
|
"""
|
||||||
|
def __init__(self):
|
||||||
|
super(LogProb, self).__init__()
|
||||||
|
self.ln = msd.LogNormal(np.array([0.3]), np.array([[0.2], [0.4]]), dtype=dtype.float32)
|
||||||
|
|
||||||
|
def construct(self, x_):
|
||||||
|
return self.ln.log_prob(x_)
|
||||||
|
|
||||||
|
def test_log_likelihood():
|
||||||
|
"""
|
||||||
|
Test log_pdf.
|
||||||
|
"""
|
||||||
|
lognorm_benchmark = stats.lognorm(s=np.array([[0.2], [0.4]]), scale=np.exp(np.array([0.3])))
|
||||||
|
expect_logpdf = lognorm_benchmark.logpdf([1.0, 2.0]).astype(np.float32)
|
||||||
|
logprob = LogProb()
|
||||||
|
output = logprob(Tensor([1.0, 2.0], dtype=dtype.float32))
|
||||||
|
tol = 1e-6
|
||||||
|
assert (np.abs(output.asnumpy() - expect_logpdf) < tol).all()
|
||||||
|
|
||||||
|
class KL(nn.Cell):
|
||||||
|
"""
|
||||||
|
Test class: kl_loss of LogNormal distribution.
|
||||||
|
"""
|
||||||
|
def __init__(self):
|
||||||
|
super(KL, self).__init__()
|
||||||
|
self.ln = msd.LogNormal(np.array([0.3]), np.array([0.4]), dtype=dtype.float32)
|
||||||
|
|
||||||
|
def construct(self, x_, y_):
|
||||||
|
return self.ln.kl_loss('LogNormal', x_, y_)
|
||||||
|
|
||||||
|
def test_kl_loss():
|
||||||
|
"""
|
||||||
|
Test kl_loss.
|
||||||
|
"""
|
||||||
|
mean_a = np.array([0.3]).astype(np.float32)
|
||||||
|
sd_a = np.array([0.4]).astype(np.float32)
|
||||||
|
|
||||||
|
mean_b = np.array([1.0]).astype(np.float32)
|
||||||
|
sd_b = np.array([1.0]).astype(np.float32)
|
||||||
|
|
||||||
|
diff_log_scale = np.log(sd_a) - np.log(sd_b)
|
||||||
|
squared_diff = np.square(mean_a / sd_b - mean_b / sd_b)
|
||||||
|
expect_kl_loss = 0.5 * squared_diff + 0.5 * np.expm1(2 * diff_log_scale) - diff_log_scale
|
||||||
|
|
||||||
|
kl_loss = KL()
|
||||||
|
mean = Tensor(mean_b, dtype=dtype.float32)
|
||||||
|
sd = Tensor(sd_b, dtype=dtype.float32)
|
||||||
|
output = kl_loss(mean, sd)
|
||||||
|
tol = 1e-6
|
||||||
|
assert (np.abs(output.asnumpy() - expect_kl_loss) < tol).all()
|
||||||
|
|
||||||
|
class Basics(nn.Cell):
|
||||||
|
"""
|
||||||
|
Test class: mean/sd/mode of LogNormal distribution.
|
||||||
|
"""
|
||||||
|
def __init__(self):
|
||||||
|
super(Basics, self).__init__()
|
||||||
|
self.ln = msd.LogNormal(np.array([0.3]), np.array([[0.2], [0.4]]), dtype=dtype.float32)
|
||||||
|
|
||||||
|
def construct(self):
|
||||||
|
return self.ln.mean(), self.ln.sd(), self.ln.mode()
|
||||||
|
|
||||||
|
def test_basics():
|
||||||
|
"""
|
||||||
|
Test mean/standard deviation/mode.
|
||||||
|
"""
|
||||||
|
basics = Basics()
|
||||||
|
mean, sd, mode = basics()
|
||||||
|
lognorm_benchmark = stats.lognorm(s=np.array([[0.2], [0.4]]), scale=np.exp(np.array([0.3])))
|
||||||
|
expect_mean = lognorm_benchmark.mean().astype(np.float32)
|
||||||
|
expect_sd = lognorm_benchmark.std().astype(np.float32)
|
||||||
|
expect_mode = (lognorm_benchmark.median() / np.exp(np.square([[0.2], [0.4]]))).astype(np.float32)
|
||||||
|
tol = 1e-6
|
||||||
|
assert (np.abs(mean.asnumpy() - expect_mean) < tol).all()
|
||||||
|
assert (np.abs(mode.asnumpy() - expect_mode) < tol).all()
|
||||||
|
assert (np.abs(sd.asnumpy() - expect_sd) < tol).all()
|
||||||
|
|
||||||
|
class Sampling(nn.Cell):
|
||||||
|
"""
|
||||||
|
Test class: sample of LogNormal distribution.
|
||||||
|
"""
|
||||||
|
def __init__(self, shape, seed=0):
|
||||||
|
super(Sampling, self).__init__()
|
||||||
|
self.ln = msd.LogNormal(np.array([0.3]), np.array([[0.2], [0.4]]), seed=seed, dtype=dtype.float32)
|
||||||
|
self.shape = shape
|
||||||
|
|
||||||
|
def construct(self, mean=None, sd=None):
|
||||||
|
return self.ln.sample(self.shape, mean, sd)
|
||||||
|
|
||||||
|
def test_sample():
|
||||||
|
"""
|
||||||
|
Test sample.
|
||||||
|
"""
|
||||||
|
shape = (2, 3)
|
||||||
|
seed = 10
|
||||||
|
mean = Tensor([2.0], dtype=dtype.float32)
|
||||||
|
sd = Tensor([2.0, 2.0, 2.0], dtype=dtype.float32)
|
||||||
|
sample = Sampling(shape, seed=seed)
|
||||||
|
output = sample(mean, sd)
|
||||||
|
assert output.shape == (2, 3, 3)
|
||||||
|
|
||||||
|
class CDF(nn.Cell):
|
||||||
|
"""
|
||||||
|
Test class: cdf of LogNormal distribution.
|
||||||
|
"""
|
||||||
|
def __init__(self):
|
||||||
|
super(CDF, self).__init__()
|
||||||
|
self.ln = msd.LogNormal(np.array([0.3]), np.array([[0.2], [0.4]]), dtype=dtype.float32)
|
||||||
|
|
||||||
|
def construct(self, x_):
|
||||||
|
return self.ln.cdf(x_)
|
||||||
|
|
||||||
|
def test_cdf():
|
||||||
|
"""
|
||||||
|
Test cdf.
|
||||||
|
"""
|
||||||
|
lognorm_benchmark = stats.lognorm(s=np.array([[0.2], [0.4]]), scale=np.exp(np.array([0.3])))
|
||||||
|
expect_cdf = lognorm_benchmark.cdf([1.0, 2.0]).astype(np.float32)
|
||||||
|
cdf = CDF()
|
||||||
|
output = cdf(Tensor([1.0, 2.0], dtype=dtype.float32))
|
||||||
|
tol = 2e-5
|
||||||
|
assert (np.abs(output.asnumpy() - expect_cdf) < tol).all()
|
||||||
|
|
||||||
|
class LogCDF(nn.Cell):
|
||||||
|
"""
|
||||||
|
Test class: log_cdf of Mormal distribution.
|
||||||
|
"""
|
||||||
|
def __init__(self):
|
||||||
|
super(LogCDF, self).__init__()
|
||||||
|
self.ln = msd.LogNormal(np.array([0.3]), np.array([[0.2], [0.4]]), dtype=dtype.float32)
|
||||||
|
|
||||||
|
def construct(self, x_):
|
||||||
|
return self.ln.log_cdf(x_)
|
||||||
|
|
||||||
|
def test_log_cdf():
|
||||||
|
"""
|
||||||
|
Test log cdf.
|
||||||
|
"""
|
||||||
|
lognorm_benchmark = stats.lognorm(s=np.array([[0.2], [0.4]]), scale=np.exp(np.array([0.3])))
|
||||||
|
expect_logcdf = lognorm_benchmark.logcdf([1.0, 2.0]).astype(np.float32)
|
||||||
|
logcdf = LogCDF()
|
||||||
|
output = logcdf(Tensor([1.0, 2.0], dtype=dtype.float32))
|
||||||
|
tol = 1e-4
|
||||||
|
assert (np.abs(output.asnumpy() - expect_logcdf) < tol).all()
|
||||||
|
|
||||||
|
class SF(nn.Cell):
|
||||||
|
"""
|
||||||
|
Test class: survival function of LogNormal distribution.
|
||||||
|
"""
|
||||||
|
def __init__(self):
|
||||||
|
super(SF, self).__init__()
|
||||||
|
self.ln = msd.LogNormal(np.array([0.3]), np.array([[0.2], [0.4]]), dtype=dtype.float32)
|
||||||
|
|
||||||
|
def construct(self, x_):
|
||||||
|
return self.ln.survival_function(x_)
|
||||||
|
|
||||||
|
def test_survival():
|
||||||
|
"""
|
||||||
|
Test log_survival.
|
||||||
|
"""
|
||||||
|
lognorm_benchmark = stats.lognorm(s=np.array([[0.2], [0.4]]), scale=np.exp(np.array([0.3])))
|
||||||
|
expect_survival = lognorm_benchmark.sf([1.0, 2.0]).astype(np.float32)
|
||||||
|
survival_function = SF()
|
||||||
|
output = survival_function(Tensor([1.0, 2.0], dtype=dtype.float32))
|
||||||
|
tol = 2e-5
|
||||||
|
assert (np.abs(output.asnumpy() - expect_survival) < tol).all()
|
||||||
|
|
||||||
|
class LogSF(nn.Cell):
|
||||||
|
"""
|
||||||
|
Test class: log survival function of LogNormal distribution.
|
||||||
|
"""
|
||||||
|
def __init__(self):
|
||||||
|
super(LogSF, self).__init__()
|
||||||
|
self.ln = msd.LogNormal(np.array([0.3]), np.array([[0.2], [0.4]]), dtype=dtype.float32)
|
||||||
|
|
||||||
|
def construct(self, x_):
|
||||||
|
return self.ln.log_survival(x_)
|
||||||
|
|
||||||
|
def test_log_survival():
|
||||||
|
"""
|
||||||
|
Test log_survival.
|
||||||
|
"""
|
||||||
|
lognorm_benchmark = stats.lognorm(s=np.array([[0.2], [0.4]]), scale=np.exp(np.array([0.3])))
|
||||||
|
expect_log_survival = lognorm_benchmark.logsf([1.0, 2.0]).astype(np.float32)
|
||||||
|
log_survival = LogSF()
|
||||||
|
output = log_survival(Tensor([1.0, 2.0], dtype=dtype.float32))
|
||||||
|
tol = 5e-4
|
||||||
|
assert (np.abs(output.asnumpy() - expect_log_survival) < tol).all()
|
||||||
|
|
||||||
|
class EntropyH(nn.Cell):
|
||||||
|
"""
|
||||||
|
Test class: entropy of LogNormal distribution.
|
||||||
|
"""
|
||||||
|
def __init__(self):
|
||||||
|
super(EntropyH, self).__init__()
|
||||||
|
self.ln = msd.LogNormal(np.array([0.3]), np.array([[0.2], [0.4]]), dtype=dtype.float32)
|
||||||
|
|
||||||
|
def construct(self):
|
||||||
|
return self.ln.entropy()
|
||||||
|
|
||||||
|
def test_entropy():
|
||||||
|
"""
|
||||||
|
Test entropy.
|
||||||
|
"""
|
||||||
|
lognorm_benchmark = stats.lognorm(s=np.array([[0.2], [0.4]]), scale=np.exp(np.array([0.3])))
|
||||||
|
expect_entropy = lognorm_benchmark.entropy().astype(np.float32)
|
||||||
|
entropy = EntropyH()
|
||||||
|
output = entropy()
|
||||||
|
tol = 1e-6
|
||||||
|
assert (np.abs(output.asnumpy() - expect_entropy) < tol).all()
|
||||||
|
|
||||||
|
class CrossEntropy(nn.Cell):
|
||||||
|
"""
|
||||||
|
Test class: cross entropy between LogNormal distributions.
|
||||||
|
"""
|
||||||
|
def __init__(self):
|
||||||
|
super(CrossEntropy, self).__init__()
|
||||||
|
self.ln = msd.LogNormal(np.array([0.3]), np.array([0.4]), dtype=dtype.float32)
|
||||||
|
|
||||||
|
def construct(self, x_, y_):
|
||||||
|
entropy = self.ln.entropy()
|
||||||
|
kl_loss = self.ln.kl_loss('LogNormal', x_, y_)
|
||||||
|
h_sum_kl = entropy + kl_loss
|
||||||
|
cross_entropy = self.ln.cross_entropy('LogNormal', x_, y_)
|
||||||
|
return h_sum_kl - cross_entropy
|
||||||
|
|
||||||
|
def test_cross_entropy():
|
||||||
|
"""
|
||||||
|
Test cross_entropy.
|
||||||
|
"""
|
||||||
|
cross_entropy = CrossEntropy()
|
||||||
|
mean = Tensor([1.0], dtype=dtype.float32)
|
||||||
|
sd = Tensor([1.0], dtype=dtype.float32)
|
||||||
|
diff = cross_entropy(mean, sd)
|
||||||
|
tol = 1e-6
|
||||||
|
assert (np.abs(diff.asnumpy() - np.zeros(diff.shape)) < tol).all()
|
||||||
|
|
||||||
|
class Net(nn.Cell):
|
||||||
|
"""
|
||||||
|
Test class: expand single distribution instance to multiple graphs
|
||||||
|
by specifying the attributes.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super(Net, self).__init__()
|
||||||
|
self.LogNormal = msd.LogNormal(0., 1., dtype=dtype.float32)
|
||||||
|
|
||||||
|
def construct(self, x_, y_):
|
||||||
|
kl = self.LogNormal.kl_loss('LogNormal', x_, y_)
|
||||||
|
prob = self.LogNormal.prob(kl)
|
||||||
|
return prob
|
||||||
|
|
||||||
|
def test_multiple_graphs():
|
||||||
|
"""
|
||||||
|
Test multiple graphs case.
|
||||||
|
"""
|
||||||
|
prob = Net()
|
||||||
|
mean_a = np.array([0.0]).astype(np.float32)
|
||||||
|
sd_a = np.array([1.0]).astype(np.float32)
|
||||||
|
mean_b = np.array([1.0]).astype(np.float32)
|
||||||
|
sd_b = np.array([1.0]).astype(np.float32)
|
||||||
|
ans = prob(Tensor(mean_b), Tensor(sd_b))
|
||||||
|
|
||||||
|
diff_log_scale = np.log(sd_a) - np.log(sd_b)
|
||||||
|
squared_diff = np.square(mean_a / sd_b - mean_b / sd_b)
|
||||||
|
expect_kl_loss = 0.5 * squared_diff + 0.5 * \
|
||||||
|
np.expm1(2 * diff_log_scale) - diff_log_scale
|
||||||
|
lognorm_benchmark = stats.lognorm(s=np.array([1.]), scale=np.exp(np.array([0.])))
|
||||||
|
expect_prob = lognorm_benchmark.pdf(expect_kl_loss).astype(np.float32)
|
||||||
|
|
||||||
|
tol = 1e-6
|
||||||
|
assert (np.abs(ans.asnumpy() - expect_prob) < tol).all()
|
|
@ -0,0 +1,216 @@
|
||||||
|
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
"""
|
||||||
|
Test nn.probability.distribution.LogNormal.
|
||||||
|
"""
|
||||||
|
import numpy as np
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
import mindspore.nn as nn
|
||||||
|
import mindspore.nn.probability.distribution as msd
|
||||||
|
from mindspore import dtype
|
||||||
|
from mindspore import Tensor
|
||||||
|
|
||||||
|
def test_lognormal_shape_errpr():
|
||||||
|
"""
|
||||||
|
Invalid shapes.
|
||||||
|
"""
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
msd.LogNormal([[2.], [1.]], [[2.], [3.], [4.]], dtype=dtype.float32)
|
||||||
|
|
||||||
|
def test_type():
|
||||||
|
with pytest.raises(TypeError):
|
||||||
|
msd.LogNormal(0., 1., dtype=dtype.int32)
|
||||||
|
|
||||||
|
def test_name():
|
||||||
|
with pytest.raises(TypeError):
|
||||||
|
msd.LogNormal(0., 1., name=1.0)
|
||||||
|
|
||||||
|
def test_seed():
|
||||||
|
with pytest.raises(TypeError):
|
||||||
|
msd.LogNormal(0., 1., seed='seed')
|
||||||
|
|
||||||
|
def test_sd():
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
msd.LogNormal(0., 0.)
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
msd.LogNormal(0., -1.)
|
||||||
|
|
||||||
|
def test_arguments():
|
||||||
|
"""
|
||||||
|
args passing during initialization.
|
||||||
|
"""
|
||||||
|
n = msd.LogNormal()
|
||||||
|
assert isinstance(n, msd.Distribution)
|
||||||
|
n = msd.LogNormal([3.0], [4.0], dtype=dtype.float32)
|
||||||
|
assert isinstance(n, msd.Distribution)
|
||||||
|
|
||||||
|
|
||||||
|
class LogNormalProb(nn.Cell):
|
||||||
|
"""
|
||||||
|
LogNormal distribution: initialize with mean/sd.
|
||||||
|
"""
|
||||||
|
def __init__(self):
|
||||||
|
super(LogNormalProb, self).__init__()
|
||||||
|
self.lognormal = msd.LogNormal(3.0, 4.0, dtype=dtype.float32)
|
||||||
|
|
||||||
|
def construct(self, value):
|
||||||
|
prob = self.lognormal.prob(value)
|
||||||
|
log_prob = self.lognormal.log_prob(value)
|
||||||
|
cdf = self.lognormal.cdf(value)
|
||||||
|
log_cdf = self.lognormal.log_cdf(value)
|
||||||
|
sf = self.lognormal.survival_function(value)
|
||||||
|
log_sf = self.lognormal.log_survival(value)
|
||||||
|
return prob + log_prob + cdf + log_cdf + sf + log_sf
|
||||||
|
|
||||||
|
def test_lognormal_prob():
|
||||||
|
"""
|
||||||
|
Test probability functions: passing value through construct.
|
||||||
|
"""
|
||||||
|
net = LogNormalProb()
|
||||||
|
value = Tensor([0.5, 1.0], dtype=dtype.float32)
|
||||||
|
ans = net(value)
|
||||||
|
assert isinstance(ans, Tensor)
|
||||||
|
|
||||||
|
|
||||||
|
class LogNormalProb1(nn.Cell):
|
||||||
|
"""
|
||||||
|
LogNormal distribution: initialize without mean/sd.
|
||||||
|
"""
|
||||||
|
def __init__(self):
|
||||||
|
super(LogNormalProb1, self).__init__()
|
||||||
|
self.lognormal = msd.LogNormal()
|
||||||
|
|
||||||
|
def construct(self, value, mean, sd):
|
||||||
|
prob = self.lognormal.prob(value, mean, sd)
|
||||||
|
log_prob = self.lognormal.log_prob(value, mean, sd)
|
||||||
|
cdf = self.lognormal.cdf(value, mean, sd)
|
||||||
|
log_cdf = self.lognormal.log_cdf(value, mean, sd)
|
||||||
|
sf = self.lognormal.survival_function(value, mean, sd)
|
||||||
|
log_sf = self.lognormal.log_survival(value, mean, sd)
|
||||||
|
return prob + log_prob + cdf + log_cdf + sf + log_sf
|
||||||
|
|
||||||
|
def test_lognormal_prob1():
|
||||||
|
"""
|
||||||
|
Test probability functions: passing mean/sd, value through construct.
|
||||||
|
"""
|
||||||
|
net = LogNormalProb1()
|
||||||
|
value = Tensor([0.5, 1.0], dtype=dtype.float32)
|
||||||
|
mean = Tensor([0.0], dtype=dtype.float32)
|
||||||
|
sd = Tensor([1.0], dtype=dtype.float32)
|
||||||
|
ans = net(value, mean, sd)
|
||||||
|
assert isinstance(ans, Tensor)
|
||||||
|
|
||||||
|
class LogNormalKl(nn.Cell):
|
||||||
|
"""
|
||||||
|
Test class: kl_loss of LogNormal distribution.
|
||||||
|
"""
|
||||||
|
def __init__(self):
|
||||||
|
super(LogNormalKl, self).__init__()
|
||||||
|
self.n1 = msd.LogNormal(np.array([3.0]), np.array([4.0]), dtype=dtype.float32)
|
||||||
|
self.n2 = msd.LogNormal(dtype=dtype.float32)
|
||||||
|
|
||||||
|
def construct(self, mean_b, sd_b, mean_a, sd_a):
|
||||||
|
kl1 = self.n1.kl_loss('LogNormal', mean_b, sd_b)
|
||||||
|
kl2 = self.n2.kl_loss('LogNormal', mean_b, sd_b, mean_a, sd_a)
|
||||||
|
return kl1 + kl2
|
||||||
|
|
||||||
|
def test_kl():
|
||||||
|
"""
|
||||||
|
Test kl_loss.
|
||||||
|
"""
|
||||||
|
net = LogNormalKl()
|
||||||
|
mean_b = Tensor(np.array([1.0]).astype(np.float32), dtype=dtype.float32)
|
||||||
|
sd_b = Tensor(np.array([1.0]).astype(np.float32), dtype=dtype.float32)
|
||||||
|
mean_a = Tensor(np.array([2.0]).astype(np.float32), dtype=dtype.float32)
|
||||||
|
sd_a = Tensor(np.array([3.0]).astype(np.float32), dtype=dtype.float32)
|
||||||
|
ans = net(mean_b, sd_b, mean_a, sd_a)
|
||||||
|
assert isinstance(ans, Tensor)
|
||||||
|
|
||||||
|
class LogNormalCrossEntropy(nn.Cell):
|
||||||
|
"""
|
||||||
|
Test class: cross_entropy of LogNormal distribution.
|
||||||
|
"""
|
||||||
|
def __init__(self):
|
||||||
|
super(LogNormalCrossEntropy, self).__init__()
|
||||||
|
self.n1 = msd.LogNormal(np.array([3.0]), np.array([4.0]), dtype=dtype.float32)
|
||||||
|
self.n2 = msd.LogNormal(dtype=dtype.float32)
|
||||||
|
|
||||||
|
def construct(self, mean_b, sd_b, mean_a, sd_a):
|
||||||
|
h1 = self.n1.cross_entropy('LogNormal', mean_b, sd_b)
|
||||||
|
h2 = self.n2.cross_entropy('LogNormal', mean_b, sd_b, mean_a, sd_a)
|
||||||
|
return h1 + h2
|
||||||
|
|
||||||
|
def test_cross_entropy():
|
||||||
|
"""
|
||||||
|
Test cross entropy between LogNormal distributions.
|
||||||
|
"""
|
||||||
|
net = LogNormalCrossEntropy()
|
||||||
|
mean_b = Tensor(np.array([1.0]).astype(np.float32), dtype=dtype.float32)
|
||||||
|
sd_b = Tensor(np.array([1.0]).astype(np.float32), dtype=dtype.float32)
|
||||||
|
mean_a = Tensor(np.array([2.0]).astype(np.float32), dtype=dtype.float32)
|
||||||
|
sd_a = Tensor(np.array([3.0]).astype(np.float32), dtype=dtype.float32)
|
||||||
|
ans = net(mean_b, sd_b, mean_a, sd_a)
|
||||||
|
assert isinstance(ans, Tensor)
|
||||||
|
|
||||||
|
class LogNormalBasics(nn.Cell):
|
||||||
|
"""
|
||||||
|
Test class: basic mean/sd function.
|
||||||
|
"""
|
||||||
|
def __init__(self):
|
||||||
|
super(LogNormalBasics, self).__init__()
|
||||||
|
self.n = msd.LogNormal(3.0, 4.0, dtype=dtype.float32)
|
||||||
|
|
||||||
|
def construct(self):
|
||||||
|
mean = self.n.mean()
|
||||||
|
sd = self.n.sd()
|
||||||
|
mode = self.n.mode()
|
||||||
|
entropy = self.n.entropy()
|
||||||
|
return mean + sd + mode + entropy
|
||||||
|
|
||||||
|
def test_bascis():
|
||||||
|
"""
|
||||||
|
Test mean/sd/mode/entropy functionality of LogNormal.
|
||||||
|
"""
|
||||||
|
net = LogNormalBasics()
|
||||||
|
ans = net()
|
||||||
|
assert isinstance(ans, Tensor)
|
||||||
|
|
||||||
|
|
||||||
|
class LogNormalConstruct(nn.Cell):
|
||||||
|
"""
|
||||||
|
LogNormal distribution: going through construct.
|
||||||
|
"""
|
||||||
|
def __init__(self):
|
||||||
|
super(LogNormalConstruct, self).__init__()
|
||||||
|
self.lognormal = msd.LogNormal(3.0, 4.0)
|
||||||
|
self.lognormal1 = msd.LogNormal()
|
||||||
|
|
||||||
|
def construct(self, value, mean, sd):
|
||||||
|
prob = self.lognormal('prob', value)
|
||||||
|
prob1 = self.lognormal('prob', value, mean, sd)
|
||||||
|
prob2 = self.lognormal1('prob', value, mean, sd)
|
||||||
|
return prob + prob1 + prob2
|
||||||
|
|
||||||
|
def test_lognormal_construct():
|
||||||
|
"""
|
||||||
|
Test probability function going through construct.
|
||||||
|
"""
|
||||||
|
net = LogNormalConstruct()
|
||||||
|
value = Tensor([0.5, 1.0], dtype=dtype.float32)
|
||||||
|
mean = Tensor([0.0], dtype=dtype.float32)
|
||||||
|
sd = Tensor([1.0], dtype=dtype.float32)
|
||||||
|
ans = net(value, mean, sd)
|
||||||
|
assert isinstance(ans, Tensor)
|
Loading…
Reference in New Issue