From ce170b22413c32b55c587d430470988674879866 Mon Sep 17 00:00:00 2001 From: Xun Deng Date: Wed, 14 Oct 2020 12:01:51 -0400 Subject: [PATCH] added Gumbel distribution --- mindspore/nn/probability/bijector/bijector.py | 34 +- .../nn/probability/bijector/gumbel_cdf.py | 36 ++- mindspore/nn/probability/bijector/invert.py | 8 + .../nn/probability/distribution/__init__.py | 2 + .../probability/distribution/distribution.py | 4 + .../nn/probability/distribution/gumbel.py | 249 ++++++++++++++ .../distribution/transformed_distribution.py | 10 + .../probability/distribution/test_gumbel.py | 303 ++++++++++++++++++ .../probability/distribution/test_gumbel.py | 153 +++++++++ 9 files changed, 791 insertions(+), 8 deletions(-) create mode 100644 mindspore/nn/probability/distribution/gumbel.py create mode 100644 tests/st/probability/distribution/test_gumbel.py create mode 100644 tests/ut/python/nn/probability/distribution/test_gumbel.py diff --git a/mindspore/nn/probability/bijector/bijector.py b/mindspore/nn/probability/bijector/bijector.py index e9fa7942205..35abd2d5210 100644 --- a/mindspore/nn/probability/bijector/bijector.py +++ b/mindspore/nn/probability/bijector/bijector.py @@ -17,7 +17,7 @@ from mindspore import context from mindspore.nn.cell import Cell from mindspore.ops import operations as P from mindspore._checkparam import Validator as validator -from ..distribution._utils.utils import CheckTensor +from ..distribution._utils.utils import CheckTensor, cast_to_tensor from ..distribution import Distribution from ..distribution import TransformedDistribution @@ -66,6 +66,8 @@ class Bijector(Cell): # ops needed for the base class self.cast_base = P.Cast() self.dtype_base = P.DType() + self.shape_base = P.Shape() + self.fill_base = P.Fill() @property def name(self): @@ -87,6 +89,36 @@ class Bijector(Cell): def is_injective(self): return self._is_injective + def _add_parameter(self, value, name): + """ + Cast `value` to a tensor and add it to `self.default_parameters`. + Add `name` into and `self.parameter_names`. + """ + # initialize the attributes if they do not exist yet + if not hasattr(self, 'default_parameters'): + self.default_parameters = [] + self.parameter_names = [] + # cast value to a tensor if it is not None + value_t = None if value is None else cast_to_tensor(value, self.parameter_type) + self.default_parameters += [value_t,] + self.parameter_names += [name,] + return value_t + + def _calc_event_shape(self): + """ + Calculate event_shape based on parameters. + """ + broadcast_shape = None + for param in self.default_parameters: + if broadcast_shape is None: + broadcast_shape = self.shape_base(param) + broadcast_shape_tensor = self.fill_base(self.parameter_type, broadcast_shape, 0.0) + else: + broadcast_shape = self.shape_base(param + broadcast_shape_tensor) + broadcast_shape_tensor = self.fill_base(self.parameter_type, broadcast_shape, 0.0) + return broadcast_shape + + def _check_value(self, value, name): """ Check availability of `value` as a Tensor. diff --git a/mindspore/nn/probability/bijector/gumbel_cdf.py b/mindspore/nn/probability/bijector/gumbel_cdf.py index b15fff9bdc9..5cfb6b0a57c 100644 --- a/mindspore/nn/probability/bijector/gumbel_cdf.py +++ b/mindspore/nn/probability/bijector/gumbel_cdf.py @@ -14,7 +14,9 @@ # ============================================================================ """GumbelCDF Bijector""" from mindspore.common import dtype as mstype -from ..distribution._utils.utils import cast_to_tensor, check_greater_zero, set_param_type +from mindspore._checkparam import Validator +from mindspore.ops import operations as P +from ..distribution._utils.utils import check_greater_zero, set_param_type from ..distribution._utils.custom_ops import exp_generic, log_generic from .bijector import Bijector @@ -33,6 +35,7 @@ class GumbelCDF(Bijector): Args: loc (int, float, list, numpy.ndarray, Tensor): The location. Default: 0.. scale (int, float, list, numpy.ndarray, Tensor): The scale. Default: 1.0. + dtype (mindspore.dtype): Type of the distribution which the bijector operates on. Default: float32. name (str): The name of the Bijector. Default: 'Gumbel_CDF'. Examples: @@ -58,17 +61,24 @@ class GumbelCDF(Bijector): def __init__(self, loc=0.0, scale=1.0, + dtype=mstype.float32, name='GumbelCDF'): """ Constructor of GumbelCDF Bijector. """ param = dict(locals()) - parameter_type = set_param_type({'loc': loc, "scale": scale}, mstype.float32) - super(GumbelCDF, self).__init__(name=name, dtype=parameter_type, param=param) - self._loc = cast_to_tensor(loc, parameter_type) - self._scale = cast_to_tensor(scale, parameter_type) - check_greater_zero(self._scale, "scale") + valid_dtype = mstype.float_type + mstype.int_type + mstype.uint_type + Validator.check_type(type(self).__name__, dtype, valid_dtype) + parameter_type = set_param_type({'loc': loc, "scale": scale}, dtype) + super(GumbelCDF, self).__init__(name=name, dtype=dtype, param=param) + self._parameter_type = parameter_type + self._loc = self._add_parameter(loc, 'loc') + self._scale = self._add_parameter(scale, 'scale') + check_greater_zero(self._scale, "scale") + self._event_shape = self._calc_event_shape() + + self.cast = P.Cast() self.exp = exp_generic self.log = log_generic @@ -81,6 +91,14 @@ class GumbelCDF(Bijector): def scale(self): return self._scale + @property + def event_shape(self): + return self._event_shape + + @property + def parameter_type(self): + return self._parameter_type + def extend_repr(self): str_info = f'loc = {self.loc}, scale = {self.scale}' return str_info @@ -90,18 +108,22 @@ class GumbelCDF(Bijector): def _forward(self, x): x = self._check_value(x, 'value') + x = self.cast(x, self.parameter_type) z = (x - self.loc) / self.scale return self.exp(-self.exp(-z)) def _inverse(self, y): y = self._check_value(y, 'value') + y = self.cast(y, self.parameter_type) return self.loc - self.scale * self.log(-self.log(y)) def _forward_log_jacobian(self, x): x = self._check_value(x, 'value') + x = self.cast(x, self.parameter_type) z = (x - self.loc) / self.scale return -z - self.exp(-z) - self.log(self.scale) def _inverse_log_jacobian(self, y): y = self._check_value(y, 'value') - return self.log(self.scale / (-y * self.log(y))) + y = self.cast(y, self.parameter_type) + return self.log(self.scale / (-1. * y * self.log(y))) diff --git a/mindspore/nn/probability/bijector/invert.py b/mindspore/nn/probability/bijector/invert.py index efb47a51f48..17f8dbc27a0 100644 --- a/mindspore/nn/probability/bijector/invert.py +++ b/mindspore/nn/probability/bijector/invert.py @@ -57,11 +57,19 @@ class Invert(Bijector): name=name, param=param) self._bijector = bijector + if hasattr(self._bijector, 'event_shape'): + self._event_shape = self.bijector.event_shape + else: + self._event_shape = () @property def bijector(self): return self._bijector + @property + def event_shape(self): + return self._event_shape + def inverse(self, y): return self.bijector("forward", y) diff --git a/mindspore/nn/probability/distribution/__init__.py b/mindspore/nn/probability/distribution/__init__.py index 4842a0cb68a..3dd818dc220 100644 --- a/mindspore/nn/probability/distribution/__init__.py +++ b/mindspore/nn/probability/distribution/__init__.py @@ -26,6 +26,7 @@ from .geometric import Geometric from .categorical import Categorical from .log_normal import LogNormal from .logistic import Logistic +from .gumbel import Gumbel __all__ = ['Distribution', 'TransformedDistribution', @@ -37,4 +38,5 @@ __all__ = ['Distribution', 'Geometric', 'LogNormal', 'Logistic', + 'Gumbel', ] diff --git a/mindspore/nn/probability/distribution/distribution.py b/mindspore/nn/probability/distribution/distribution.py index c121100ac29..39e1073e784 100644 --- a/mindspore/nn/probability/distribution/distribution.py +++ b/mindspore/nn/probability/distribution/distribution.py @@ -132,6 +132,10 @@ class Distribution(Cell): def broadcast_shape(self): return self._broadcast_shape + def _reset_parameters(self): + self.default_parameters = [] + self.parameter_names = [] + def _add_parameter(self, value, name): """ Cast `value` to a tensor and add it to `self.default_parameters`. diff --git a/mindspore/nn/probability/distribution/gumbel.py b/mindspore/nn/probability/distribution/gumbel.py new file mode 100644 index 00000000000..c341db85376 --- /dev/null +++ b/mindspore/nn/probability/distribution/gumbel.py @@ -0,0 +1,249 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Gumbel Distribution""" +import numpy as np +from mindspore.ops import operations as P +from mindspore._checkparam import Validator +from mindspore.common import dtype as mstype +import mindspore.nn as nn +import mindspore.nn.probability.bijector as msb +import mindspore.nn.probability.distribution as msd +from .transformed_distribution import TransformedDistribution +from ._utils.utils import check_distribution_name, raise_not_implemented_util +from ._utils.custom_ops import exp_generic, expm1_generic, log_generic + +class Gumbel(TransformedDistribution): + """ + Gumbel distribution. + + Args: + loc (int, float, list, numpy.ndarray, Tensor, Parameter): The location of Gumbel distribution. + scale (int, float, list, numpy.ndarray, Tensor, Parameter): The scale of Gumbel distribution. + seed (int): the seed used in sampling. The global seed is used if it is None. Default: None. + dtype (mindspore.dtype): type of the distribution. Default: mstype.float32. + name (str): the name of the distribution. Default: 'Gumbel'. + + Note: + `scale` must be greater than zero. + `dist_spec_args` are `loc` and `scale`. + `dtype` must be a float type because Gumbel distributions are continuous. + + Examples: + >>> # To initialize a Gumbel distribution of `loc` 3.0 and `scale` 4.0. + >>> gum = msd.Gumbel(3.0, 4.0, dtype=mstype.float32) + >>> + >>> # The following creates two independent Gumbel distributions. + >>> gum = msd.Gumbel([3.0, 3.0], [4.0, 4.0], dtype=mstype.float32) + >>> + >>> # To use a Gumbel distribution in a network. + >>> class net(Cell): + >>> def __init__(self): + >>> super(net, self).__init__(): + >>> self.g1 = msd.Gumbel(0.0, 1.0, dtype=mstype.float32) + >>> + >>> # The following calls are valid in construct. + >>> def construct(self, value, loc_b, scale_b): + >>> + >>> # Private interfaces of probability functions corresponding to public interfaces, including + >>> # `prob`, `log_prob`, `cdf`, `log_cdf`, `survival_function`, and `log_survival`, have the same + >>> # arguments as follows. + >>> # Args: + >>> # value (Tensor): the value to be evaluated. + >>> + >>> # Examples of `prob`. + >>> # Similar calls can be made to other probability functions + >>> # by replacing 'prob' by the name of the function. + >>> ans = self.g1.prob(value) + >>> + >>> # Functions `mean`, `mode`, sd`, `var`, and `entropy` do not take in any argument. + >>> ans = self.g1.mean() + >>> ans = self.g1.mode() + >>> ans = self.g1.sd() + >>> ans = self.g1.entropy() + >>> ans = self.g1.var() + >>> + >>> # Interfaces of 'kl_loss' and 'cross_entropy' are the same: + >>> # Args: + >>> # dist (str): the type of the distributions. Only "Gumbel" is supported. + >>> # loc_b (Tensor): the loc of distribution b. + >>> # scale_b (Tensor): the scale distribution b. + >>> + >>> # Examples of `kl_loss`. `cross_entropy` is similar. + >>> ans = self.g1.kl_loss('Gumbel', loc_b, scale_b) + >>> ans = self.g1.cross_entropy('Gumbel', loc_b, scale_b) + >>> + >>> # Examples of `sample`. + >>> # Args: + >>> # shape (tuple): the shape of the sample. Default: () + >>> + >>> ans = self.g1.sample() + >>> ans = self.g1.sample((2,3)) + """ + + def __init__(self, + loc, + scale, + seed=0, + dtype=mstype.float32, + name="Gumbel"): + """ + Constructor of Gumbel distribution. + """ + valid_dtype = mstype.float_type + Validator.check_type(type(self).__name__, dtype, valid_dtype) + gumbel_cdf = msb.GumbelCDF(loc, scale, dtype) + super(Gumbel, self).__init__( + distribution=msd.Uniform(0.0, 1.0, dtype=dtype), + bijector=msb.Invert(gumbel_cdf), + seed=seed, name=name) + + self._parameter_type = gumbel_cdf.parameter_type + self._broadcast_shape = gumbel_cdf.event_shape + if self._broadcast_shape != (): + self._is_scalar_batch = False + + # overwrite default_parameters and parameter_names + self._reset_parameters() + self._loc = self._add_parameter(loc, 'loc') + self._scale = self._add_parameter(scale, 'scale') + self._gumbel_bijector = gumbel_cdf + + # ops needed for the class + self.cast = P.Cast() + self.const = P.ScalarToArray() + self.exp = exp_generic + self.expm1 = expm1_generic + self.fill = P.Fill() + self.lgamma = nn.LGamma() + self.log = log_generic + self.shape = P.Shape() + self.sqrt = P.Sqrt() + + @property + def loc(self): + return self._loc + + @property + def scale(self): + return self._scale + + def extend_repr(self): + if self.is_scalar_batch: + str_info = f'loc = {self._loc}, scale = {self._scale}' + else: + str_info = f'batch_shape = {self._broadcast_shape}' + return str_info + + def _mean(self): + r""" + The mean of the distribution. + + .. math:: + MEAN(X) = loc + scale * Euler-Mascheroni_constant + """ + return self.loc + self.scale * np.euler_gamma + + def _mode(self): + """ + The mode of the distribution. + """ + return self.loc * self.fill(self.parameter_type, self.shape(self.scale), 1.0) + + def _sd(self): + r""" + The standard deviation of the distribution. + + .. math:: + STD(X) = \frac{\pi}{\sqrt(6)} * scale + """ + scale = self.scale * self.fill(self.parameter_type, self.broadcast_shape, 1.0) + return scale * np.pi / self.sqrt(self.const(6.)) + + def _entropy(self): + r""" + Evaluate entropy. + + .. math:: + H(X) = 1. + \log(scale) + Euler-Mascheroni_constant + """ + scale = self.scale * self.fill(self.parameter_type, self.broadcast_shape, 1.0) + return 1. + self.log(scale) + np.euler_gamma + + def _log_prob(self, value): + r""" + .. math:: + log_pdf(X) = -(z + \exp(-z)) - \log(scale) + where z = \frac{x - loc}{scale} + """ + value = self._check_value(value, 'value') + z = (value - self.loc) / self.scale + return -(z + self.exp(-z)) - self.log(self.scale) + + def _cdf(self, value): + r""" + .. math:: + cdf_pdf(X) = \exp(-\exp(-\frac{x - loc}{scale}) + """ + return self._gumbel_bijector("forward", value) + + def _cross_entropy(self, dist, loc_b, scale_b): + r""" + Evaluate cross entropy between Gumbel distributions. + + Args: + dist (str): The type of the distributions. Should be "Gumbel" in this case. + loc_b (Tensor): The loc of distribution b. + scale_b (Tensor): The scale of distribution b. + """ + if self.device_target == 'GPU': + raise_not_implemented_util('On GPU backend, cross_entropy', self.name) + check_distribution_name(dist, 'Gumbel') + return self._entropy() + self._kl_loss(dist, loc_b, scale_b) + + def _kl_loss(self, dist, loc_b, scale_b): + r""" + Evaluate Gumbel-Gumbel kl divergence, i.e. KL(a||b). + + Args: + dist (str): The type of the distributions. Should be "Gumbel" in this case. + loc_b (Tensor): The loc of distribution b. + scale_b (Tensor): The scale of distribution b. + + .. math:: + KL(a||b) = \log(scale_b / scale_a) + Euler-Mascheroni_constant * (scale_a / scale_b - 1.) + + \exp(\frac{(loc_b - loc_a)}{scale_b}) * \Gamma(scale_a / scale_b + 1.) - 1. + """ + if self.device_target == 'GPU': + raise_not_implemented_util('On GPU backend, kl_loss', self.name) + check_distribution_name(dist, 'Gumbel') + loc_b = self._check_value(loc_b, 'loc_b') + scale_b = self._check_value(scale_b, 'scale_b') + loc_b = self.cast(loc_b, self.parameter_type) + scale_b = self.cast(scale_b, self.parameter_type) + return self.log(scale_b) - self.log(self.scale) +\ + np.euler_gamma * (self.scale / scale_b - 1.) +\ + self.expm1((loc_b - self.loc) / scale_b + self.lgamma(self.scale / scale_b + 1.)) + + def _sample(self, shape=()): + origin_shape = shape + self._broadcast_shape + if origin_shape == (): + sample_shape = (1,) + else: + sample_shape = origin_shape + org_sample = self.distribution("sample", sample_shape) + value = self.bijector("forward", org_sample) + if origin_shape == (): + value = self.squeeze(value) + return value diff --git a/mindspore/nn/probability/distribution/transformed_distribution.py b/mindspore/nn/probability/distribution/transformed_distribution.py index b1fa0fc9c25..1bcc77781df 100644 --- a/mindspore/nn/probability/distribution/transformed_distribution.py +++ b/mindspore/nn/probability/distribution/transformed_distribution.py @@ -82,11 +82,21 @@ class TransformedDistribution(Distribution): self._is_linear_transformation = bijector.is_constant_jacobian self.default_parameters = distribution.default_parameters self.parameter_names = distribution.parameter_names + self.exp = exp_generic self.log = log_generic self.isnan = P.IsNan() self.equal_base = P.Equal() self.select_base = P.Select() + self.fill = P.Fill() + + # check if batch shape of the distribution and event shape is broadcastable + if hasattr(self.bijector, 'event_shape'): + event_shape_tensor = self.fill(self.dtype, self.bijector.event_shape, 0.0) + broadcast_shape_tensor = self.fill(self.dtype, self.broadcast_shape, 0.0) + self._batch_event = (event_shape_tensor + broadcast_shape_tensor).shape + else: + self._batch_event = self.broadcast_shape @property def bijector(self): diff --git a/tests/st/probability/distribution/test_gumbel.py b/tests/st/probability/distribution/test_gumbel.py new file mode 100644 index 00000000000..313fdfe5414 --- /dev/null +++ b/tests/st/probability/distribution/test_gumbel.py @@ -0,0 +1,303 @@ +# Copyright 2019 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""test cases for Gumbel distribution""" +import numpy as np +from scipy import stats +from scipy import special +import mindspore.context as context +import mindspore.nn as nn +import mindspore.nn.probability.distribution as msd +from mindspore import Tensor +from mindspore import dtype + +context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") + +class Prob(nn.Cell): + """ + Test class: probability of Gumbel distribution. + """ + def __init__(self): + super(Prob, self).__init__() + self.gum = msd.Gumbel(np.array([0.0]), np.array([[1.0], [2.0]]), dtype=dtype.float32) + + def construct(self, x_): + return self.gum.prob(x_) + +def test_pdf(): + """ + Test pdf. + """ + loc = np.array([0.0]).astype(np.float32) + scale = np.array([[1.0], [2.0]]).astype(np.float32) + gumbel_benchmark = stats.gumbel_r(loc, scale) + value = np.array([1.0, 2.0]).astype(np.float32) + expect_pdf = gumbel_benchmark.pdf(value).astype(np.float32) + pdf = Prob() + output = pdf(Tensor(value, dtype=dtype.float32)) + tol = 1e-6 + assert (np.abs(output.asnumpy() - expect_pdf) < tol).all() + +class LogProb(nn.Cell): + """ + Test class: log probability of Gumbel distribution. + """ + def __init__(self): + super(LogProb, self).__init__() + self.gum = msd.Gumbel(np.array([0.0]), np.array([[1.0], [2.0]]), dtype=dtype.float32) + + def construct(self, x_): + return self.gum.log_prob(x_) + +def test_log_likelihood(): + """ + Test log_pdf. + """ + loc = np.array([0.0]).astype(np.float32) + scale = np.array([[1.0], [2.0]]).astype(np.float32) + gumbel_benchmark = stats.gumbel_r(loc, scale) + expect_logpdf = gumbel_benchmark.logpdf([1.0, 2.0]).astype(np.float32) + logprob = LogProb() + output = logprob(Tensor([1.0, 2.0], dtype=dtype.float32)) + tol = 1e-6 + assert (np.abs(output.asnumpy() - expect_logpdf) < tol).all() + +class KL(nn.Cell): + """ + Test class: kl_loss of Gumbel distribution. + """ + def __init__(self): + super(KL, self).__init__() + self.gum = msd.Gumbel(np.array([0.0]), np.array([1.0, 2.0]), dtype=dtype.float32) + + def construct(self, loc_b, scale_b): + return self.gum.kl_loss('Gumbel', loc_b, scale_b) + +def test_kl_loss(): + """ + Test kl_loss. + """ + loc = np.array([0.0]).astype(np.float32) + scale = np.array([1.0, 2.0]).astype(np.float32) + + loc_b = np.array([1.0]).astype(np.float32) + scale_b = np.array([1.0, 2.0]).astype(np.float32) + + expect_kl_loss = np.log(scale_b) - np.log(scale) +\ + np.euler_gamma * (scale / scale_b - 1.) +\ + np.expm1((loc_b - loc) / scale_b + special.loggamma(scale / scale_b + 1.)) + + kl_loss = KL() + loc_b = Tensor(loc_b, dtype=dtype.float32) + scale_b = Tensor(scale_b, dtype=dtype.float32) + output = kl_loss(loc_b, scale_b) + tol = 1e-5 + assert (np.abs(output.asnumpy() - expect_kl_loss) < tol).all() + +class Basics(nn.Cell): + """ + Test class: mean/sd/mode of Gumbel distribution. + """ + def __init__(self): + super(Basics, self).__init__() + self.gum = msd.Gumbel(np.array([0.0]), np.array([[1.0], [2.0]]), dtype=dtype.float32) + + def construct(self): + return self.gum.mean(), self.gum.sd(), self.gum.mode() + +def test_basics(): + """ + Test mean/standard deviation/mode. + """ + basics = Basics() + mean, sd, mode = basics() + + loc = np.array([0.0]).astype(np.float32) + scale = np.array([[1.0], [2.0]]).astype(np.float32) + gumbel_benchmark = stats.gumbel_r(loc, scale) + expect_mean = gumbel_benchmark.mean().astype(np.float32) + expect_sd = gumbel_benchmark.std().astype(np.float32) + expect_mode = np.array([[0.0], [0.0]]).astype(np.float32) + tol = 1e-6 + assert (np.abs(mean.asnumpy() - expect_mean) < tol).all() + assert (np.abs(mode.asnumpy() - expect_mode) < tol).all() + assert (np.abs(sd.asnumpy() - expect_sd) < tol).all() + +class Sampling(nn.Cell): + """ + Test class: sample of Gumbel distribution. + """ + def __init__(self, shape, seed=0): + super(Sampling, self).__init__() + self.gum = msd.Gumbel(np.array([0.0]), np.array([1.0, 2.0, 3.0]), dtype=dtype.float32, seed=seed) + self.shape = shape + + def construct(self): + return self.gum.sample(self.shape) + +def test_sample(): + """ + Test sample. + """ + shape = (2, 3) + seed = 10 + sample = Sampling(shape, seed=seed) + output = sample() + assert output.shape == (2, 3, 3) + +class CDF(nn.Cell): + """ + Test class: cdf of Gumbel distribution. + """ + def __init__(self): + super(CDF, self).__init__() + self.gum = msd.Gumbel(np.array([0.0]), np.array([[1.0], [2.0]]), dtype=dtype.float32) + + def construct(self, x_): + return self.gum.cdf(x_) + +def test_cdf(): + """ + Test cdf. + """ + loc = np.array([0.0]).astype(np.float32) + scale = np.array([[1.0], [2.0]]).astype(np.float32) + gumbel_benchmark = stats.gumbel_r(loc, scale) + expect_cdf = gumbel_benchmark.cdf([1.0, 2.0]).astype(np.float32) + cdf = CDF() + output = cdf(Tensor([1.0, 2.0], dtype=dtype.float32)) + tol = 2e-5 + assert (np.abs(output.asnumpy() - expect_cdf) < tol).all() + +class LogCDF(nn.Cell): + """ + Test class: log_cdf of Gumbel distribution. + """ + def __init__(self): + super(LogCDF, self).__init__() + self.gum = msd.Gumbel(np.array([0.0]), np.array([[1.0], [2.0]]), dtype=dtype.float32) + + def construct(self, x_): + return self.gum.log_cdf(x_) + +def test_log_cdf(): + """ + Test log cdf. + """ + loc = np.array([0.0]).astype(np.float32) + scale = np.array([[1.0], [2.0]]).astype(np.float32) + gumbel_benchmark = stats.gumbel_r(loc, scale) + expect_logcdf = gumbel_benchmark.logcdf([1.0, 2.0]).astype(np.float32) + logcdf = LogCDF() + output = logcdf(Tensor([1.0, 2.0], dtype=dtype.float32)) + tol = 1e-4 + assert (np.abs(output.asnumpy() - expect_logcdf) < tol).all() + +class SF(nn.Cell): + """ + Test class: survival function of Gumbel distribution. + """ + def __init__(self): + super(SF, self).__init__() + self.gum = msd.Gumbel(np.array([0.0]), np.array([[1.0], [2.0]]), dtype=dtype.float32) + + def construct(self, x_): + return self.gum.survival_function(x_) + +def test_survival(): + """ + Test log_survival. + """ + loc = np.array([0.0]).astype(np.float32) + scale = np.array([[1.0], [2.0]]).astype(np.float32) + gumbel_benchmark = stats.gumbel_r(loc, scale) + expect_survival = gumbel_benchmark.sf([1.0, 2.0]).astype(np.float32) + survival_function = SF() + output = survival_function(Tensor([1.0, 2.0], dtype=dtype.float32)) + tol = 2e-5 + assert (np.abs(output.asnumpy() - expect_survival) < tol).all() + +class LogSF(nn.Cell): + """ + Test class: log survival function of Gumbel distribution. + """ + def __init__(self): + super(LogSF, self).__init__() + self.gum = msd.Gumbel(np.array([0.0]), np.array([[1.0], [2.0]]), dtype=dtype.float32) + + def construct(self, x_): + return self.gum.log_survival(x_) + +def test_log_survival(): + """ + Test log_survival. + """ + loc = np.array([0.0]).astype(np.float32) + scale = np.array([[1.0], [2.0]]).astype(np.float32) + gumbel_benchmark = stats.gumbel_r(loc, scale) + expect_log_survival = gumbel_benchmark.logsf([1.0, 2.0]).astype(np.float32) + log_survival = LogSF() + output = log_survival(Tensor([1.0, 2.0], dtype=dtype.float32)) + tol = 5e-4 + assert (np.abs(output.asnumpy() - expect_log_survival) < tol).all() + +class EntropyH(nn.Cell): + """ + Test class: entropy of Gumbel distribution. + """ + def __init__(self): + super(EntropyH, self).__init__() + self.gum = msd.Gumbel(np.array([0.0]), np.array([[1.0], [2.0]]), dtype=dtype.float32) + + def construct(self): + return self.gum.entropy() + +def test_entropy(): + """ + Test entropy. + """ + loc = np.array([0.0]).astype(np.float32) + scale = np.array([[1.0], [2.0]]).astype(np.float32) + gumbel_benchmark = stats.gumbel_r(loc, scale) + expect_entropy = gumbel_benchmark.entropy().astype(np.float32) + entropy = EntropyH() + output = entropy() + tol = 1e-6 + assert (np.abs(output.asnumpy() - expect_entropy) < tol).all() + +class CrossEntropy(nn.Cell): + """ + Test class: cross entropy between Gumbel distributions. + """ + def __init__(self): + super(CrossEntropy, self).__init__() + self.gum = msd.Gumbel(np.array([0.0]), np.array([[1.0], [2.0]]), dtype=dtype.float32) + + def construct(self, x_, y_): + entropy = self.gum.entropy() + kl_loss = self.gum.kl_loss('Gumbel', x_, y_) + h_sum_kl = entropy + kl_loss + cross_entropy = self.gum.cross_entropy('Gumbel', x_, y_) + return h_sum_kl - cross_entropy + +def test_cross_entropy(): + """ + Test cross_entropy. + """ + cross_entropy = CrossEntropy() + loc = Tensor([1.0], dtype=dtype.float32) + scale = Tensor([1.0], dtype=dtype.float32) + diff = cross_entropy(loc, scale) + tol = 1e-6 + assert (np.abs(diff.asnumpy() - np.zeros(diff.shape)) < tol).all() diff --git a/tests/ut/python/nn/probability/distribution/test_gumbel.py b/tests/ut/python/nn/probability/distribution/test_gumbel.py new file mode 100644 index 00000000000..11af5b70a95 --- /dev/null +++ b/tests/ut/python/nn/probability/distribution/test_gumbel.py @@ -0,0 +1,153 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" +Test nn.probability.distribution.gumbel. +""" +import numpy as np +import pytest +import mindspore.nn as nn +import mindspore.nn.probability.distribution as msd +from mindspore import dtype +from mindspore import Tensor + +def test_gumbel_shape_errpr(): + """ + Invalid shapes. + """ + with pytest.raises(ValueError): + msd.Gumbel([[2.], [1.]], [[2.], [3.], [4.]], dtype=dtype.float32) + +def test_type(): + with pytest.raises(TypeError): + msd.Gumbel(0., 1., dtype=dtype.int32) + +def test_name(): + with pytest.raises(TypeError): + msd.Gumbel(0., 1., name=1.0) + +def test_seed(): + with pytest.raises(TypeError): + msd.Gumbel(0., 1., seed='seed') + +def test_scale(): + with pytest.raises(ValueError): + msd.Gumbel(0., 0.) + with pytest.raises(ValueError): + msd.Gumbel(0., -1.) + +def test_arguments(): + """ + args passing during initialization. + """ + l = msd.Gumbel([3.0], [4.0], dtype=dtype.float32) + assert isinstance(l, msd.Distribution) + + +class GumbelProb(nn.Cell): + """ + Gumbel distribution: initialize with loc/scale. + """ + def __init__(self): + super(GumbelProb, self).__init__() + self.gumbel = msd.Gumbel(3.0, 4.0, dtype=dtype.float32) + + def construct(self, value): + prob = self.gumbel.prob(value) + log_prob = self.gumbel.log_prob(value) + cdf = self.gumbel.cdf(value) + log_cdf = self.gumbel.log_cdf(value) + sf = self.gumbel.survival_function(value) + log_sf = self.gumbel.log_survival(value) + return prob + log_prob + cdf + log_cdf + sf + log_sf + +def test_gumbel_prob(): + """ + Test probability functions: passing value through construct. + """ + net = GumbelProb() + value = Tensor([0.5, 1.0], dtype=dtype.float32) + ans = net(value) + assert isinstance(ans, Tensor) + +class KL(nn.Cell): + """ + Test kl_loss. + """ + def __init__(self): + super(KL, self).__init__() + self.gumbel = msd.Gumbel(3.0, 4.0) + + def construct(self, mu, s): + kl = self.gumbel.kl_loss('Gumbel', mu, s) + cross_entropy = self.gumbel.cross_entropy('Gumbel', mu, s) + return kl + cross_entropy + +def test_kl_cross_entropy(): + """ + Test kl_loss and cross_entropy. + """ + net = KL() + loc_b = Tensor(np.array([1.0]).astype(np.float32), dtype=dtype.float32) + scale_b = Tensor(np.array([1.0]).astype(np.float32), dtype=dtype.float32) + ans = net(loc_b, scale_b) + assert isinstance(ans, Tensor) + + +class GumbelBasics(nn.Cell): + """ + Test class: basic loc/scale function. + """ + def __init__(self): + super(GumbelBasics, self).__init__() + self.gumbel = msd.Gumbel(3.0, 4.0, dtype=dtype.float32) + + def construct(self): + mean = self.gumbel.mean() + sd = self.gumbel.sd() + mode = self.gumbel.mode() + entropy = self.gumbel.entropy() + return mean + sd + mode + entropy + +def test_bascis(): + """ + Test mean/sd/mode/entropy functionality of Gumbel. + """ + net = GumbelBasics() + ans = net() + assert isinstance(ans, Tensor) + + +class GumbelConstruct(nn.Cell): + """ + Gumbel distribution: going through construct. + """ + def __init__(self): + super(GumbelConstruct, self).__init__() + self.gumbel = msd.Gumbel(3.0, 4.0) + + + def construct(self, value): + prob = self.gumbel('prob', value) + prob1 = self.gumbel.prob(value) + return prob + prob1 + +def test_gumbel_construct(): + """ + Test probability function going through construct. + """ + net = GumbelConstruct() + value = Tensor([0.5, 1.0], dtype=dtype.float32) + ans = net(value) + assert isinstance(ans, Tensor)