added Gumbel distribution

This commit is contained in:
Xun Deng 2020-10-14 12:01:51 -04:00
parent 40b4844b76
commit ce170b2241
9 changed files with 791 additions and 8 deletions

View File

@ -17,7 +17,7 @@ from mindspore import context
from mindspore.nn.cell import Cell from mindspore.nn.cell import Cell
from mindspore.ops import operations as P from mindspore.ops import operations as P
from mindspore._checkparam import Validator as validator from mindspore._checkparam import Validator as validator
from ..distribution._utils.utils import CheckTensor from ..distribution._utils.utils import CheckTensor, cast_to_tensor
from ..distribution import Distribution from ..distribution import Distribution
from ..distribution import TransformedDistribution from ..distribution import TransformedDistribution
@ -66,6 +66,8 @@ class Bijector(Cell):
# ops needed for the base class # ops needed for the base class
self.cast_base = P.Cast() self.cast_base = P.Cast()
self.dtype_base = P.DType() self.dtype_base = P.DType()
self.shape_base = P.Shape()
self.fill_base = P.Fill()
@property @property
def name(self): def name(self):
@ -87,6 +89,36 @@ class Bijector(Cell):
def is_injective(self): def is_injective(self):
return self._is_injective return self._is_injective
def _add_parameter(self, value, name):
"""
Cast `value` to a tensor and add it to `self.default_parameters`.
Add `name` into and `self.parameter_names`.
"""
# initialize the attributes if they do not exist yet
if not hasattr(self, 'default_parameters'):
self.default_parameters = []
self.parameter_names = []
# cast value to a tensor if it is not None
value_t = None if value is None else cast_to_tensor(value, self.parameter_type)
self.default_parameters += [value_t,]
self.parameter_names += [name,]
return value_t
def _calc_event_shape(self):
"""
Calculate event_shape based on parameters.
"""
broadcast_shape = None
for param in self.default_parameters:
if broadcast_shape is None:
broadcast_shape = self.shape_base(param)
broadcast_shape_tensor = self.fill_base(self.parameter_type, broadcast_shape, 0.0)
else:
broadcast_shape = self.shape_base(param + broadcast_shape_tensor)
broadcast_shape_tensor = self.fill_base(self.parameter_type, broadcast_shape, 0.0)
return broadcast_shape
def _check_value(self, value, name): def _check_value(self, value, name):
""" """
Check availability of `value` as a Tensor. Check availability of `value` as a Tensor.

View File

@ -14,7 +14,9 @@
# ============================================================================ # ============================================================================
"""GumbelCDF Bijector""" """GumbelCDF Bijector"""
from mindspore.common import dtype as mstype from mindspore.common import dtype as mstype
from ..distribution._utils.utils import cast_to_tensor, check_greater_zero, set_param_type from mindspore._checkparam import Validator
from mindspore.ops import operations as P
from ..distribution._utils.utils import check_greater_zero, set_param_type
from ..distribution._utils.custom_ops import exp_generic, log_generic from ..distribution._utils.custom_ops import exp_generic, log_generic
from .bijector import Bijector from .bijector import Bijector
@ -33,6 +35,7 @@ class GumbelCDF(Bijector):
Args: Args:
loc (int, float, list, numpy.ndarray, Tensor): The location. Default: 0.. loc (int, float, list, numpy.ndarray, Tensor): The location. Default: 0..
scale (int, float, list, numpy.ndarray, Tensor): The scale. Default: 1.0. scale (int, float, list, numpy.ndarray, Tensor): The scale. Default: 1.0.
dtype (mindspore.dtype): Type of the distribution which the bijector operates on. Default: float32.
name (str): The name of the Bijector. Default: 'Gumbel_CDF'. name (str): The name of the Bijector. Default: 'Gumbel_CDF'.
Examples: Examples:
@ -58,17 +61,24 @@ class GumbelCDF(Bijector):
def __init__(self, def __init__(self,
loc=0.0, loc=0.0,
scale=1.0, scale=1.0,
dtype=mstype.float32,
name='GumbelCDF'): name='GumbelCDF'):
""" """
Constructor of GumbelCDF Bijector. Constructor of GumbelCDF Bijector.
""" """
param = dict(locals()) param = dict(locals())
parameter_type = set_param_type({'loc': loc, "scale": scale}, mstype.float32) valid_dtype = mstype.float_type + mstype.int_type + mstype.uint_type
super(GumbelCDF, self).__init__(name=name, dtype=parameter_type, param=param) Validator.check_type(type(self).__name__, dtype, valid_dtype)
self._loc = cast_to_tensor(loc, parameter_type) parameter_type = set_param_type({'loc': loc, "scale": scale}, dtype)
self._scale = cast_to_tensor(scale, parameter_type) super(GumbelCDF, self).__init__(name=name, dtype=dtype, param=param)
check_greater_zero(self._scale, "scale")
self._parameter_type = parameter_type
self._loc = self._add_parameter(loc, 'loc')
self._scale = self._add_parameter(scale, 'scale')
check_greater_zero(self._scale, "scale")
self._event_shape = self._calc_event_shape()
self.cast = P.Cast()
self.exp = exp_generic self.exp = exp_generic
self.log = log_generic self.log = log_generic
@ -81,6 +91,14 @@ class GumbelCDF(Bijector):
def scale(self): def scale(self):
return self._scale return self._scale
@property
def event_shape(self):
    """
    Return the event shape computed from the bijector parameters at construction.
    """
    shape = self._event_shape
    return shape
@property
def parameter_type(self):
    """
    Return the type that the parameters and the inputs of the bijector are cast to.
    """
    param_type = self._parameter_type
    return param_type
def extend_repr(self): def extend_repr(self):
str_info = f'loc = {self.loc}, scale = {self.scale}' str_info = f'loc = {self.loc}, scale = {self.scale}'
return str_info return str_info
@ -90,18 +108,22 @@ class GumbelCDF(Bijector):
def _forward(self, x): def _forward(self, x):
x = self._check_value(x, 'value') x = self._check_value(x, 'value')
x = self.cast(x, self.parameter_type)
z = (x - self.loc) / self.scale z = (x - self.loc) / self.scale
return self.exp(-self.exp(-z)) return self.exp(-self.exp(-z))
def _inverse(self, y): def _inverse(self, y):
y = self._check_value(y, 'value') y = self._check_value(y, 'value')
y = self.cast(y, self.parameter_type)
return self.loc - self.scale * self.log(-self.log(y)) return self.loc - self.scale * self.log(-self.log(y))
def _forward_log_jacobian(self, x): def _forward_log_jacobian(self, x):
x = self._check_value(x, 'value') x = self._check_value(x, 'value')
x = self.cast(x, self.parameter_type)
z = (x - self.loc) / self.scale z = (x - self.loc) / self.scale
return -z - self.exp(-z) - self.log(self.scale) return -z - self.exp(-z) - self.log(self.scale)
def _inverse_log_jacobian(self, y): def _inverse_log_jacobian(self, y):
y = self._check_value(y, 'value') y = self._check_value(y, 'value')
return self.log(self.scale / (-y * self.log(y))) y = self.cast(y, self.parameter_type)
return self.log(self.scale / (-1. * y * self.log(y)))

View File

@ -57,11 +57,19 @@ class Invert(Bijector):
name=name, name=name,
param=param) param=param)
self._bijector = bijector self._bijector = bijector
if hasattr(self._bijector, 'event_shape'):
self._event_shape = self.bijector.event_shape
else:
self._event_shape = ()
@property @property
def bijector(self): def bijector(self):
return self._bijector return self._bijector
@property
def event_shape(self):
    """
    Return the event shape stored at construction: the one of the wrapped
    bijector if it has one, otherwise an empty tuple.
    """
    shape = self._event_shape
    return shape
def inverse(self, y): def inverse(self, y):
return self.bijector("forward", y) return self.bijector("forward", y)

View File

@ -26,6 +26,7 @@ from .geometric import Geometric
from .categorical import Categorical from .categorical import Categorical
from .log_normal import LogNormal from .log_normal import LogNormal
from .logistic import Logistic from .logistic import Logistic
from .gumbel import Gumbel
__all__ = ['Distribution', __all__ = ['Distribution',
'TransformedDistribution', 'TransformedDistribution',
@ -37,4 +38,5 @@ __all__ = ['Distribution',
'Geometric', 'Geometric',
'LogNormal', 'LogNormal',
'Logistic', 'Logistic',
'Gumbel',
] ]

View File

@ -132,6 +132,10 @@ class Distribution(Cell):
def broadcast_shape(self): def broadcast_shape(self):
return self._broadcast_shape return self._broadcast_shape
def _reset_parameters(self):
self.default_parameters = []
self.parameter_names = []
def _add_parameter(self, value, name): def _add_parameter(self, value, name):
""" """
Cast `value` to a tensor and add it to `self.default_parameters`. Cast `value` to a tensor and add it to `self.default_parameters`.

View File

@ -0,0 +1,249 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Gumbel Distribution"""
import numpy as np
from mindspore.ops import operations as P
from mindspore._checkparam import Validator
from mindspore.common import dtype as mstype
import mindspore.nn as nn
import mindspore.nn.probability.bijector as msb
import mindspore.nn.probability.distribution as msd
from .transformed_distribution import TransformedDistribution
from ._utils.utils import check_distribution_name, raise_not_implemented_util
from ._utils.custom_ops import exp_generic, expm1_generic, log_generic
class Gumbel(TransformedDistribution):
    """
    Gumbel distribution.

    Args:
        loc (int, float, list, numpy.ndarray, Tensor, Parameter): The location of Gumbel distribution.
        scale (int, float, list, numpy.ndarray, Tensor, Parameter): The scale of Gumbel distribution.
        seed (int): The seed used in sampling. The global seed is used if it is None. Default: 0.
        dtype (mindspore.dtype): Type of the distribution. Default: mstype.float32.
        name (str): The name of the distribution. Default: 'Gumbel'.

    Note:
        `scale` must be greater than zero.
        `dist_spec_args` are `loc` and `scale`.
        `dtype` must be a float type because Gumbel distributions are continuous.

    Examples:
        >>> # To initialize a Gumbel distribution of `loc` 3.0 and `scale` 4.0.
        >>> gum = msd.Gumbel(3.0, 4.0, dtype=mstype.float32)
        >>>
        >>> # The following creates two independent Gumbel distributions.
        >>> gum = msd.Gumbel([3.0, 3.0], [4.0, 4.0], dtype=mstype.float32)
        >>>
        >>> # To use a Gumbel distribution in a network.
        >>> class net(Cell):
        >>>     def __init__(self):
        >>>         super(net, self).__init__()
        >>>         self.g1 = msd.Gumbel(0.0, 1.0, dtype=mstype.float32)
        >>>
        >>>     # The following calls are valid in construct.
        >>>     def construct(self, value, loc_b, scale_b):
        >>>
        >>>         # Private interfaces of probability functions corresponding to public interfaces, including
        >>>         # `prob`, `log_prob`, `cdf`, `log_cdf`, `survival_function`, and `log_survival`, have the same
        >>>         # arguments as follows.
        >>>         # Args:
        >>>         #     value (Tensor): the value to be evaluated.
        >>>
        >>>         # Examples of `prob`.
        >>>         # Similar calls can be made to other probability functions
        >>>         # by replacing 'prob' by the name of the function.
        >>>         ans = self.g1.prob(value)
        >>>
        >>>         # Functions `mean`, `mode`, `sd`, `var`, and `entropy` do not take in any argument.
        >>>         ans = self.g1.mean()
        >>>         ans = self.g1.mode()
        >>>         ans = self.g1.sd()
        >>>         ans = self.g1.entropy()
        >>>         ans = self.g1.var()
        >>>
        >>>         # Interfaces of 'kl_loss' and 'cross_entropy' are the same:
        >>>         # Args:
        >>>         #     dist (str): the type of the distributions. Only "Gumbel" is supported.
        >>>         #     loc_b (Tensor): the loc of distribution b.
        >>>         #     scale_b (Tensor): the scale distribution b.
        >>>
        >>>         # Examples of `kl_loss`. `cross_entropy` is similar.
        >>>         ans = self.g1.kl_loss('Gumbel', loc_b, scale_b)
        >>>         ans = self.g1.cross_entropy('Gumbel', loc_b, scale_b)
        >>>
        >>>         # Examples of `sample`.
        >>>         # Args:
        >>>         #     shape (tuple): the shape of the sample. Default: ()
        >>>
        >>>         ans = self.g1.sample()
        >>>         ans = self.g1.sample((2,3))
    """

    def __init__(self,
                 loc,
                 scale,
                 seed=0,
                 dtype=mstype.float32,
                 name="Gumbel"):
        """
        Constructor of Gumbel distribution.
        """
        valid_dtype = mstype.float_type
        Validator.check_type(type(self).__name__, dtype, valid_dtype)
        # the distribution is a Uniform(0, 1) pushed through the inverse of the Gumbel CDF
        gumbel_cdf = msb.GumbelCDF(loc, scale, dtype)
        super(Gumbel, self).__init__(
            distribution=msd.Uniform(0.0, 1.0, dtype=dtype),
            bijector=msb.Invert(gumbel_cdf),
            seed=seed, name=name)

        # the batch shape of the distribution is the event shape of the bijector
        self._parameter_type = gumbel_cdf.parameter_type
        self._broadcast_shape = gumbel_cdf.event_shape
        if self._broadcast_shape != ():
            self._is_scalar_batch = False

        # overwrite default_parameters and parameter_names
        self._reset_parameters()
        self._loc = self._add_parameter(loc, 'loc')
        self._scale = self._add_parameter(scale, 'scale')
        self._gumbel_bijector = gumbel_cdf

        # ops needed for the class
        self.cast = P.Cast()
        self.const = P.ScalarToArray()
        self.exp = exp_generic
        self.expm1 = expm1_generic
        self.fill = P.Fill()
        self.lgamma = nn.LGamma()
        self.log = log_generic
        self.shape = P.Shape()
        self.sqrt = P.Sqrt()
        # used by `_sample` to drop the dummy leading axis of a scalar sample
        self.squeeze = P.Squeeze(0)

    @property
    def loc(self):
        """
        Return the location of the distribution as stored at construction.
        """
        return self._loc

    @property
    def scale(self):
        """
        Return the scale of the distribution as stored at construction.
        """
        return self._scale

    def extend_repr(self):
        """
        Describe the distribution: its parameters for a scalar batch, its batch shape otherwise.
        """
        if self.is_scalar_batch:
            str_info = f'loc = {self._loc}, scale = {self._scale}'
        else:
            str_info = f'batch_shape = {self._broadcast_shape}'
        return str_info

    def _mean(self):
        r"""
        The mean of the distribution.

        .. math::
            MEAN(X) = loc + scale * Euler-Mascheroni_constant
        """
        return self.loc + self.scale * np.euler_gamma

    def _mode(self):
        """
        The mode of the distribution, i.e. `loc` broadcast against the shape of `scale`.
        """
        return self.loc * self.fill(self.parameter_type, self.shape(self.scale), 1.0)

    def _sd(self):
        r"""
        The standard deviation of the distribution.

        .. math::
            STD(X) = \frac{\pi}{\sqrt(6)} * scale
        """
        # multiply by ones so that the result has the full broadcast shape
        scale = self.scale * self.fill(self.parameter_type, self.broadcast_shape, 1.0)
        return scale * np.pi / self.sqrt(self.const(6.))

    def _entropy(self):
        r"""
        Evaluate entropy.

        .. math::
            H(X) = 1. + \log(scale) + Euler-Mascheroni_constant
        """
        # multiply by ones so that the result has the full broadcast shape
        scale = self.scale * self.fill(self.parameter_type, self.broadcast_shape, 1.0)
        return 1. + self.log(scale) + np.euler_gamma

    def _log_prob(self, value):
        r"""
        Evaluate the log probability density at `value`.

        Args:
            value (Tensor): The value to be evaluated.

        .. math::
            log_pdf(X) = -(z + \exp(-z)) - \log(scale)
            where z = \frac{x - loc}{scale}
        """
        value = self._check_value(value, 'value')
        # same cast as applied to `loc_b`/`scale_b` in `_kl_loss`
        value = self.cast(value, self.parameter_type)
        z = (value - self.loc) / self.scale
        return -(z + self.exp(-z)) - self.log(self.scale)

    def _cdf(self, value):
        r"""
        Evaluate the cumulative distribution function at `value` through the GumbelCDF bijector.

        Args:
            value (Tensor): The value to be evaluated.

        .. math::
            cdf(X) = \exp(-\exp(-\frac{x - loc}{scale}))
        """
        return self._gumbel_bijector("forward", value)

    def _cross_entropy(self, dist, loc_b, scale_b):
        r"""
        Evaluate cross entropy between Gumbel distributions.

        Computed as the entropy of this distribution plus KL(a||b).

        Args:
            dist (str): The type of the distributions. Should be "Gumbel" in this case.
            loc_b (Tensor): The loc of distribution b.
            scale_b (Tensor): The scale of distribution b.
        """
        if self.device_target == 'GPU':
            raise_not_implemented_util('On GPU backend, cross_entropy', self.name)
        check_distribution_name(dist, 'Gumbel')
        return self._entropy() + self._kl_loss(dist, loc_b, scale_b)

    def _kl_loss(self, dist, loc_b, scale_b):
        r"""
        Evaluate Gumbel-Gumbel kl divergence, i.e. KL(a||b).

        Args:
            dist (str): The type of the distributions. Should be "Gumbel" in this case.
            loc_b (Tensor): The loc of distribution b.
            scale_b (Tensor): The scale of distribution b.

        .. math::
            KL(a||b) = \log(scale_b / scale_a) + Euler-Mascheroni_constant * (scale_a / scale_b - 1.) +
                       \exp(\frac{(loc_b - loc_a)}{scale_b}) * \Gamma(scale_a / scale_b + 1.) - 1.
        """
        if self.device_target == 'GPU':
            raise_not_implemented_util('On GPU backend, kl_loss', self.name)
        check_distribution_name(dist, 'Gumbel')
        loc_b = self._check_value(loc_b, 'loc_b')
        scale_b = self._check_value(scale_b, 'scale_b')
        loc_b = self.cast(loc_b, self.parameter_type)
        scale_b = self.cast(scale_b, self.parameter_type)
        # exp(a) * Gamma(b) - 1 is evaluated as expm1(a + lgamma(b))
        return self.log(scale_b) - self.log(self.scale) +\
               np.euler_gamma * (self.scale / scale_b - 1.) +\
               self.expm1((loc_b - self.loc) / scale_b + self.lgamma(self.scale / scale_b + 1.))

    def _sample(self, shape=()):
        """
        Draw samples by transforming samples of the underlying distribution with the bijector.

        Args:
            shape (tuple): The shape of the sample. Default: ().

        Returns:
            Tensor of shape `shape + batch_shape`.
        """
        origin_shape = shape + self._broadcast_shape
        # a scalar sample is drawn with shape (1,) and squeezed back afterwards
        if origin_shape == ():
            sample_shape = (1,)
        else:
            sample_shape = origin_shape
        org_sample = self.distribution("sample", sample_shape)
        value = self.bijector("forward", org_sample)
        if origin_shape == ():
            value = self.squeeze(value)
        return value

View File

@ -82,11 +82,21 @@ class TransformedDistribution(Distribution):
self._is_linear_transformation = bijector.is_constant_jacobian self._is_linear_transformation = bijector.is_constant_jacobian
self.default_parameters = distribution.default_parameters self.default_parameters = distribution.default_parameters
self.parameter_names = distribution.parameter_names self.parameter_names = distribution.parameter_names
self.exp = exp_generic self.exp = exp_generic
self.log = log_generic self.log = log_generic
self.isnan = P.IsNan() self.isnan = P.IsNan()
self.equal_base = P.Equal() self.equal_base = P.Equal()
self.select_base = P.Select() self.select_base = P.Select()
self.fill = P.Fill()
# check if batch shape of the distribution and event shape is broadcastable
if hasattr(self.bijector, 'event_shape'):
event_shape_tensor = self.fill(self.dtype, self.bijector.event_shape, 0.0)
broadcast_shape_tensor = self.fill(self.dtype, self.broadcast_shape, 0.0)
self._batch_event = (event_shape_tensor + broadcast_shape_tensor).shape
else:
self._batch_event = self.broadcast_shape
@property @property
def bijector(self): def bijector(self):

View File

@ -0,0 +1,303 @@
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""test cases for Gumbel distribution"""
import numpy as np
from scipy import stats
from scipy import special
import mindspore.context as context
import mindspore.nn as nn
import mindspore.nn.probability.distribution as msd
from mindspore import Tensor
from mindspore import dtype
# Run every test case of this module in graph mode on the Ascend backend.
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
class Prob(nn.Cell):
    """
    Network wrapper evaluating `prob` of a Gumbel distribution.
    """

    def __init__(self):
        super().__init__()
        loc = np.array([0.0])
        scale = np.array([[1.0], [2.0]])
        self.gum = msd.Gumbel(loc, scale, dtype=dtype.float32)

    def construct(self, x_):
        return self.gum.prob(x_)
def test_pdf():
    """
    Check `prob` against the pdf of scipy's `gumbel_r`.
    """
    loc = np.array([0.0]).astype(np.float32)
    scale = np.array([[1.0], [2.0]]).astype(np.float32)
    value = np.array([1.0, 2.0]).astype(np.float32)
    expect_pdf = stats.gumbel_r(loc, scale).pdf(value).astype(np.float32)

    net = Prob()
    output = net(Tensor(value, dtype=dtype.float32))

    tol = 1e-6
    abs_err = np.abs(output.asnumpy() - expect_pdf)
    assert (abs_err < tol).all()
class LogProb(nn.Cell):
    """
    Network wrapper evaluating `log_prob` of a Gumbel distribution.
    """

    def __init__(self):
        super().__init__()
        loc = np.array([0.0])
        scale = np.array([[1.0], [2.0]])
        self.gum = msd.Gumbel(loc, scale, dtype=dtype.float32)

    def construct(self, x_):
        return self.gum.log_prob(x_)
def test_log_likelihood():
    """
    Check `log_prob` against the logpdf of scipy's `gumbel_r`.
    """
    loc = np.array([0.0]).astype(np.float32)
    scale = np.array([[1.0], [2.0]]).astype(np.float32)
    expect_logpdf = stats.gumbel_r(loc, scale).logpdf([1.0, 2.0]).astype(np.float32)

    net = LogProb()
    output = net(Tensor([1.0, 2.0], dtype=dtype.float32))

    tol = 1e-6
    abs_err = np.abs(output.asnumpy() - expect_logpdf)
    assert (abs_err < tol).all()
class KL(nn.Cell):
    """
    Network wrapper evaluating `kl_loss` between Gumbel distributions.
    """

    def __init__(self):
        super().__init__()
        loc = np.array([0.0])
        scale = np.array([1.0, 2.0])
        self.gum = msd.Gumbel(loc, scale, dtype=dtype.float32)

    def construct(self, loc_b, scale_b):
        return self.gum.kl_loss('Gumbel', loc_b, scale_b)
def test_kl_loss():
    """
    Check `kl_loss` against the closed form evaluated with numpy/scipy.
    """
    loc = np.array([0.0]).astype(np.float32)
    scale = np.array([1.0, 2.0]).astype(np.float32)
    loc_b = np.array([1.0]).astype(np.float32)
    scale_b = np.array([1.0, 2.0]).astype(np.float32)

    # the three summands of the Gumbel-Gumbel KL divergence
    log_ratio = np.log(scale_b) - np.log(scale)
    gamma_term = np.euler_gamma * (scale / scale_b - 1.)
    exp_term = np.expm1((loc_b - loc) / scale_b + special.loggamma(scale / scale_b + 1.))
    expect_kl_loss = log_ratio + gamma_term + exp_term

    net = KL()
    output = net(Tensor(loc_b, dtype=dtype.float32), Tensor(scale_b, dtype=dtype.float32))

    tol = 1e-5
    abs_err = np.abs(output.asnumpy() - expect_kl_loss)
    assert (abs_err < tol).all()
class Basics(nn.Cell):
    """
    Network wrapper returning mean, standard deviation and mode of a Gumbel distribution.
    """

    def __init__(self):
        super().__init__()
        loc = np.array([0.0])
        scale = np.array([[1.0], [2.0]])
        self.gum = msd.Gumbel(loc, scale, dtype=dtype.float32)

    def construct(self):
        mean = self.gum.mean()
        sd = self.gum.sd()
        mode = self.gum.mode()
        return mean, sd, mode
def test_basics():
    """
    Check mean, standard deviation and mode against scipy's `gumbel_r`.
    """
    net = Basics()
    mean, sd, mode = net()

    loc = np.array([0.0]).astype(np.float32)
    scale = np.array([[1.0], [2.0]]).astype(np.float32)
    benchmark = stats.gumbel_r(loc, scale)
    expect_mean = benchmark.mean().astype(np.float32)
    expect_sd = benchmark.std().astype(np.float32)
    # the mode equals loc, broadcast to the shape of scale
    expect_mode = np.array([[0.0], [0.0]]).astype(np.float32)

    tol = 1e-6
    mean_err = np.abs(mean.asnumpy() - expect_mean)
    assert (mean_err < tol).all()
    mode_err = np.abs(mode.asnumpy() - expect_mode)
    assert (mode_err < tol).all()
    sd_err = np.abs(sd.asnumpy() - expect_sd)
    assert (sd_err < tol).all()
class Sampling(nn.Cell):
    """
    Network wrapper drawing samples of a fixed shape from a Gumbel distribution.
    """

    def __init__(self, shape, seed=0):
        super().__init__()
        loc = np.array([0.0])
        scale = np.array([1.0, 2.0, 3.0])
        self.gum = msd.Gumbel(loc, scale, dtype=dtype.float32, seed=seed)
        self.shape = shape

    def construct(self):
        return self.gum.sample(self.shape)
def test_sample():
    """
    Check that samples have shape `sample shape + batch shape`.
    """
    net = Sampling((2, 3), seed=10)
    output = net()
    # batch shape of the distribution is (3,), appended to the requested (2, 3)
    assert output.shape == (2, 3, 3)
class CDF(nn.Cell):
    """
    Network wrapper evaluating `cdf` of a Gumbel distribution.
    """

    def __init__(self):
        super().__init__()
        loc = np.array([0.0])
        scale = np.array([[1.0], [2.0]])
        self.gum = msd.Gumbel(loc, scale, dtype=dtype.float32)

    def construct(self, x_):
        return self.gum.cdf(x_)
def test_cdf():
    """
    Check `cdf` against the cdf of scipy's `gumbel_r`.
    """
    loc = np.array([0.0]).astype(np.float32)
    scale = np.array([[1.0], [2.0]]).astype(np.float32)
    expect_cdf = stats.gumbel_r(loc, scale).cdf([1.0, 2.0]).astype(np.float32)

    net = CDF()
    output = net(Tensor([1.0, 2.0], dtype=dtype.float32))

    tol = 2e-5
    abs_err = np.abs(output.asnumpy() - expect_cdf)
    assert (abs_err < tol).all()
class LogCDF(nn.Cell):
    """
    Network wrapper evaluating `log_cdf` of a Gumbel distribution.
    """

    def __init__(self):
        super().__init__()
        loc = np.array([0.0])
        scale = np.array([[1.0], [2.0]])
        self.gum = msd.Gumbel(loc, scale, dtype=dtype.float32)

    def construct(self, x_):
        return self.gum.log_cdf(x_)
def test_log_cdf():
    """
    Check `log_cdf` against the logcdf of scipy's `gumbel_r`.
    """
    loc = np.array([0.0]).astype(np.float32)
    scale = np.array([[1.0], [2.0]]).astype(np.float32)
    expect_logcdf = stats.gumbel_r(loc, scale).logcdf([1.0, 2.0]).astype(np.float32)

    net = LogCDF()
    output = net(Tensor([1.0, 2.0], dtype=dtype.float32))

    tol = 1e-4
    abs_err = np.abs(output.asnumpy() - expect_logcdf)
    assert (abs_err < tol).all()
class SF(nn.Cell):
    """
    Network wrapper evaluating `survival_function` of a Gumbel distribution.
    """

    def __init__(self):
        super().__init__()
        loc = np.array([0.0])
        scale = np.array([[1.0], [2.0]])
        self.gum = msd.Gumbel(loc, scale, dtype=dtype.float32)

    def construct(self, x_):
        return self.gum.survival_function(x_)
def test_survival():
    """
    Check `survival_function` against the sf of scipy's `gumbel_r`.
    """
    loc = np.array([0.0]).astype(np.float32)
    scale = np.array([[1.0], [2.0]]).astype(np.float32)
    expect_survival = stats.gumbel_r(loc, scale).sf([1.0, 2.0]).astype(np.float32)

    net = SF()
    output = net(Tensor([1.0, 2.0], dtype=dtype.float32))

    tol = 2e-5
    abs_err = np.abs(output.asnumpy() - expect_survival)
    assert (abs_err < tol).all()
class LogSF(nn.Cell):
    """
    Network wrapper evaluating `log_survival` of a Gumbel distribution.
    """

    def __init__(self):
        super().__init__()
        loc = np.array([0.0])
        scale = np.array([[1.0], [2.0]])
        self.gum = msd.Gumbel(loc, scale, dtype=dtype.float32)

    def construct(self, x_):
        return self.gum.log_survival(x_)
def test_log_survival():
    """
    Check `log_survival` against the logsf of scipy's `gumbel_r`.
    """
    loc = np.array([0.0]).astype(np.float32)
    scale = np.array([[1.0], [2.0]]).astype(np.float32)
    expect_log_survival = stats.gumbel_r(loc, scale).logsf([1.0, 2.0]).astype(np.float32)

    net = LogSF()
    output = net(Tensor([1.0, 2.0], dtype=dtype.float32))

    tol = 5e-4
    abs_err = np.abs(output.asnumpy() - expect_log_survival)
    assert (abs_err < tol).all()
class EntropyH(nn.Cell):
    """
    Network wrapper evaluating `entropy` of a Gumbel distribution.
    """

    def __init__(self):
        super().__init__()
        loc = np.array([0.0])
        scale = np.array([[1.0], [2.0]])
        self.gum = msd.Gumbel(loc, scale, dtype=dtype.float32)

    def construct(self):
        return self.gum.entropy()
def test_entropy():
    """
    Check `entropy` against the entropy of scipy's `gumbel_r`.
    """
    loc = np.array([0.0]).astype(np.float32)
    scale = np.array([[1.0], [2.0]]).astype(np.float32)
    expect_entropy = stats.gumbel_r(loc, scale).entropy().astype(np.float32)

    net = EntropyH()
    output = net()

    tol = 1e-6
    abs_err = np.abs(output.asnumpy() - expect_entropy)
    assert (abs_err < tol).all()
class CrossEntropy(nn.Cell):
    """
    Network wrapper returning `entropy + kl_loss - cross_entropy` of Gumbel distributions.
    """

    def __init__(self):
        super().__init__()
        loc = np.array([0.0])
        scale = np.array([[1.0], [2.0]])
        self.gum = msd.Gumbel(loc, scale, dtype=dtype.float32)

    def construct(self, x_, y_):
        # cross entropy assembled by hand from entropy and KL divergence
        h_sum_kl = self.gum.entropy() + self.gum.kl_loss('Gumbel', x_, y_)
        # cross entropy as reported by the distribution itself
        cross_entropy = self.gum.cross_entropy('Gumbel', x_, y_)
        return h_sum_kl - cross_entropy
def test_cross_entropy():
    """
    Check that `cross_entropy` equals `entropy + kl_loss`.
    """
    net = CrossEntropy()
    loc_b = Tensor([1.0], dtype=dtype.float32)
    scale_b = Tensor([1.0], dtype=dtype.float32)
    diff = net(loc_b, scale_b)

    tol = 1e-6
    abs_err = np.abs(diff.asnumpy() - np.zeros(diff.shape))
    assert (abs_err < tol).all()

View File

@ -0,0 +1,153 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Test nn.probability.distribution.gumbel.
"""
import numpy as np
import pytest
import mindspore.nn as nn
import mindspore.nn.probability.distribution as msd
from mindspore import dtype
from mindspore import Tensor
def test_gumbel_shape_errpr():
    """
    `loc` and `scale` with shapes that cannot be broadcast must raise ValueError.
    """
    loc = [[2.], [1.]]
    scale = [[2.], [3.], [4.]]
    with pytest.raises(ValueError):
        msd.Gumbel(loc, scale, dtype=dtype.float32)
def test_type():
    """
    A non-float `dtype` must raise TypeError.
    """
    int_type = dtype.int32
    with pytest.raises(TypeError):
        msd.Gumbel(0., 1., dtype=int_type)
def test_name():
    """
    A non-string `name` must raise TypeError.
    """
    bad_name = 1.0
    with pytest.raises(TypeError):
        msd.Gumbel(0., 1., name=bad_name)
def test_seed():
    """
    A string `seed` must raise TypeError.
    """
    bad_seed = 'seed'
    with pytest.raises(TypeError):
        msd.Gumbel(0., 1., seed=bad_seed)
def test_scale():
    """
    A zero or negative `scale` must raise ValueError.
    """
    for bad_scale in (0., -1.):
        with pytest.raises(ValueError):
            msd.Gumbel(0., bad_scale)
def test_arguments():
    """
    A Gumbel built from list arguments is a `Distribution`.
    """
    gumbel = msd.Gumbel([3.0], [4.0], dtype=dtype.float32)
    assert isinstance(gumbel, msd.Distribution)
class GumbelProb(nn.Cell):
    """
    Network wrapper calling every probability function of a Gumbel distribution.
    """

    def __init__(self):
        super().__init__()
        self.gumbel = msd.Gumbel(3.0, 4.0, dtype=dtype.float32)

    def construct(self, value):
        # accumulate the results in the order the functions are called
        total = self.gumbel.prob(value)
        total = total + self.gumbel.log_prob(value)
        total = total + self.gumbel.cdf(value)
        total = total + self.gumbel.log_cdf(value)
        total = total + self.gumbel.survival_function(value)
        total = total + self.gumbel.log_survival(value)
        return total
def test_gumbel_prob():
    """
    The probability functions accept a value passed through construct.
    """
    value = Tensor([0.5, 1.0], dtype=dtype.float32)
    result = GumbelProb()(value)
    assert isinstance(result, Tensor)
class KL(nn.Cell):
    """
    Network wrapper returning the sum of `kl_loss` and `cross_entropy`.
    """

    def __init__(self):
        super().__init__()
        self.gumbel = msd.Gumbel(3.0, 4.0)

    def construct(self, mu, s):
        divergence = self.gumbel.kl_loss('Gumbel', mu, s)
        return divergence + self.gumbel.cross_entropy('Gumbel', mu, s)
def test_kl_cross_entropy():
    """
    `kl_loss` and `cross_entropy` run and return a Tensor.
    """
    ones = np.array([1.0]).astype(np.float32)
    loc_b = Tensor(ones, dtype=dtype.float32)
    scale_b = Tensor(np.array([1.0]).astype(np.float32), dtype=dtype.float32)
    result = KL()(loc_b, scale_b)
    assert isinstance(result, Tensor)
class GumbelBasics(nn.Cell):
    """
    Network wrapper returning the sum of mean, sd, mode and entropy.
    """

    def __init__(self):
        super().__init__()
        self.gumbel = msd.Gumbel(3.0, 4.0, dtype=dtype.float32)

    def construct(self):
        # accumulate the results in the order the functions are called
        total = self.gumbel.mean()
        total = total + self.gumbel.sd()
        total = total + self.gumbel.mode()
        total = total + self.gumbel.entropy()
        return total
def test_bascis():
    """
    mean/sd/mode/entropy of Gumbel run and return a Tensor.
    """
    result = GumbelBasics()()
    assert isinstance(result, Tensor)
class GumbelConstruct(nn.Cell):
    """
    Network wrapper calling `prob` both by name through the distribution call and as a method.
    """

    def __init__(self):
        super().__init__()
        self.gumbel = msd.Gumbel(3.0, 4.0)

    def construct(self, value):
        by_name = self.gumbel('prob', value)
        by_method = self.gumbel.prob(value)
        return by_name + by_method
def test_gumbel_construct():
    """
    The probability function can be reached through the construct of the distribution.
    """
    value = Tensor([0.5, 1.0], dtype=dtype.float32)
    result = GumbelConstruct()(value)
    assert isinstance(result, Tensor)