diff --git a/mindspore/nn/probability/distribution/__init__.py b/mindspore/nn/probability/distribution/__init__.py index db8dad5c3cf..4842a0cb68a 100644 --- a/mindspore/nn/probability/distribution/__init__.py +++ b/mindspore/nn/probability/distribution/__init__.py @@ -25,6 +25,7 @@ from .uniform import Uniform from .geometric import Geometric from .categorical import Categorical from .log_normal import LogNormal +from .logistic import Logistic __all__ = ['Distribution', 'TransformedDistribution', @@ -35,4 +36,5 @@ __all__ = ['Distribution', 'Categorical', 'Geometric', 'LogNormal', + 'Logistic', ] diff --git a/mindspore/nn/probability/distribution/logistic.py b/mindspore/nn/probability/distribution/logistic.py new file mode 100644 index 00000000000..0a8ff8fdb54 --- /dev/null +++ b/mindspore/nn/probability/distribution/logistic.py @@ -0,0 +1,327 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Logistic Distribution""" +import numpy as np +from mindspore.ops import operations as P +from mindspore.ops import composite as C +from mindspore.common import dtype as mstype +from .distribution import Distribution +from ._utils.utils import check_greater_zero, check_type +from ._utils.custom_ops import exp_generic, expm1_generic, log_generic, log1p_generic + + +class Logistic(Distribution): + """ + Logistic distribution. + + Args: + loc (int, float, list, numpy.ndarray, Tensor, Parameter): The location of the Logistic distribution. + scale (int, float, list, numpy.ndarray, Tensor, Parameter): The scale of the Logistic distribution. + seed (int): The seed used in sampling. The global seed is used if it is None. Default: None. + dtype (mindspore.dtype): The type of the event samples. Default: mstype.float32. + name (str): The name of the distribution. Default: 'Logistic'. + + Note: + `scale` must be greater than zero. + `dist_spec_args` are `loc` and `scale`. + `dtype` must be a float type because Logistic distributions are continuous. + + Examples: + >>> # To initialize a Logistic distribution of loc 3.0 and scale 4.0. + >>> import mindspore.nn.probability.distribution as msd + >>> n = msd.Logistic(3.0, 4.0, dtype=mstype.float32) + >>> + >>> # The following creates two independent Logistic distributions. + >>> n = msd.Logistic([3.0, 3.0], [4.0, 4.0], dtype=mstype.float32) + >>> + >>> # A Logistic distribution can be initilize without arguments. + >>> # In this case, `loc` and `scale` must be passed in through arguments. + >>> n = msd.Logistic(dtype=mstype.float32) + >>> + >>> # To use a Normal distribution in a network. + >>> class net(Cell): + >>> def __init__(self): + >>> super(net, self).__init__(): + >>> self.l1 = msd.Logistic(0.0, 1.0, dtype=mstype.float32) + >>> self.l2 = msd.Logistic(dtype=mstype.float32) + >>> + >>> # The following calls are valid in construct. + >>> def construct(self, value, loc_b, scale_b, loc_a, scale_a): + >>> + >>> # Private interfaces of probability functions corresponding to public interfaces, including + >>> # `prob`, `log_prob`, `cdf`, `log_cdf`, `survival_function`, and `log_survival`, have the same arguments as follows. + >>> # Args: + >>> # value (Tensor): the value to be evaluated. + >>> # loc (Tensor): the location of the distribution. Default: self.loc. + >>> # scale (Tensor): the scale of the distribution. Default: self.scale. + >>> + >>> # Examples of `prob`. + >>> # Similar calls can be made to other probability functions + >>> # by replacing 'prob' by the name of the function + >>> ans = self.l1.prob(value) + >>> # Evaluate with respect to distribution b. + >>> ans = self.l1.prob(value, loc_b, scale_b) + >>> # `loc` and `scale` must be passed in during function calls + >>> ans = self.l2.prob(value, loc_a, scale_a) + >>> + >>> # Functions `mean`, `mode`, `sd`, `var`, and `entropy` have the same arguments. + >>> # Args: + >>> # loc (Tensor): the location of the distribution. Default: self.loc. + >>> # scale (Tensor): the scale of the distribution. Default: self.scale. + >>> + >>> # Example of `mean`. `mode`, `sd`, `var`, and `entropy` are similar. + >>> ans = self.l1.mean() # return 0.0 + >>> ans = self.l1.mean(loc_b, scale_b) # return loc_b + >>> # `loc` and `scale` must be passed in during function calls. + >>> ans = self.l2.mean(loc_a, scale_a) + >>> + >>> # Examples of `sample`. + >>> # Args: + >>> # shape (tuple): the shape of the sample. Default: () + >>> # loc (Tensor): the location of the distribution. Default: self.loc. + >>> # scale (Tensor): the scale of the distribution. Default: self.scale. + >>> ans = self.l1.sample() + >>> ans = self.l1.sample((2,3)) + >>> ans = self.l1.sample((2,3), scale_b, scale_b) + >>> ans = self.l2.sample((2,3), scale_a, scale_a) + """ + + def __init__(self, + loc=None, + scale=None, + seed=None, + dtype=mstype.float32, + name="Logistic"): + """ + Constructor of Logistic. + """ + param = dict(locals()) + param['param_dict'] = {'loc': loc, 'scale': scale} + valid_dtype = mstype.float_type + check_type(dtype, valid_dtype, type(self).__name__) + super(Logistic, self).__init__(seed, dtype, name, param) + + self._loc = self._add_parameter(loc, 'loc') + self._scale = self._add_parameter(scale, 'scale') + if self._scale is not None: + check_greater_zero(self._scale, "scale") + + # ops needed for the class + self.cast = P.Cast() + self.const = P.ScalarToArray() + self.dtypeop = P.DType() + self.exp = exp_generic + self.expm1 = expm1_generic + self.fill = P.Fill() + self.less = P.Less() + self.log = log_generic + self.log1p = log1p_generic + self.logicalor = P.LogicalOr() + self.erf = P.Erf() + self.greater = P.Greater() + self.sigmoid = P.Sigmoid() + self.squeeze = P.Squeeze(0) + self.select = P.Select() + self.shape = P.Shape() + self.softplus = self._softplus + self.sqrt = P.Sqrt() + self.uniform = C.uniform + + self.threshold = np.log(np.finfo(np.float32).eps) + 1. + self.tiny = np.finfo(np.float).tiny + + def _softplus(self, x): + too_small = self.less(x, self.threshold) + too_large = self.greater(x, -self.threshold) + too_small_value = self.exp(x) + too_large_value = x + ones = self.fill(self.dtypeop(x), self.shape(x), 1.0) + too_small_or_too_large = self.logicalor(too_small, too_large) + x = self.select(too_small_or_too_large, ones, x) + y = self.log(self.exp(x) + 1.0) + return self.select(too_small, too_small_value, self.select(too_large, too_large_value, y)) + + def extend_repr(self): + if self.is_scalar_batch: + str_info = f'location = {self._loc}, scale = {self._scale}' + else: + str_info = f'batch_shape = {self._broadcast_shape}' + return str_info + + @property + def loc(self): + """ + Return the location of the distribution. + """ + return self._loc + + @property + def scale(self): + """ + Return the scale of the distribution. + """ + return self._scale + + def _mean(self, loc=None, scale=None): + """ + The mean of the distribution. + """ + loc, scale = self._check_param_type(loc, scale) + return loc + + def _mode(self, loc=None, scale=None): + """ + The mode of the distribution. + """ + loc, scale = self._check_param_type(loc, scale) + return loc + + def _sd(self, loc=None, scale=None): + """ + The standard deviation of the distribution. + """ + loc, scale = self._check_param_type(loc, scale) + return scale * self.const(np.pi) / self.sqrt(self.const(3.0)) + + def _entropy(self, loc=None, scale=None): + r""" + Evaluate entropy. + + .. math:: + H(X) = \log(scale) + 2. + """ + loc, scale = self._check_param_type(loc, scale) + return self.log(scale) + 2. + + def _log_prob(self, value, loc=None, scale=None): + r""" + Evaluate log probability. + + Args: + value (Tensor): The value to be evaluated. + loc (Tensor): The location of the distribution. Default: self.loc. + scale (Tensor): The scale of the distribution. Default: self.scale. + + .. math:: + z = (x - \mu) / \sigma + L(x) = -z * -2. * softplus(-z) - \log(\sigma) + """ + value = self._check_value(value, 'value') + value = self.cast(value, self.dtype) + loc, scale = self._check_param_type(loc, scale) + z = (value - loc) / scale + return -z - 2. * self.softplus(-z) - self.log(scale) + + def _cdf(self, value, loc=None, scale=None): + r""" + Evaluate the cumulative distribution function on the given value. + + Args: + value (Tensor): The value to be evaluated. + loc (Tensor): The location of the distribution. Default: self.loc. + scale (Tensor): The scale the distribution. Default: self.scale. + + .. math:: + cdf(x) = sigmoid((x - loc) / scale) + """ + value = self._check_value(value, 'value') + value = self.cast(value, self.dtype) + loc, scale = self._check_param_type(loc, scale) + z = (value - loc) / scale + return self.sigmoid(z) + + def _log_cdf(self, value, loc=None, scale=None): + r""" + Evaluate the log cumulative distribution function on the given value. + + Args: + value (Tensor): The value to be evaluated. + loc (Tensor): The location of the distribution. Default: self.loc. + scale (Tensor): The scale the distribution. Default: self.scale. + + .. math:: + log_cdf(x) = -softplus(-(x - loc) / scale) + """ + value = self._check_value(value, 'value') + value = self.cast(value, self.dtype) + loc, scale = self._check_param_type(loc, scale) + z = (value - loc) / scale + return -self.softplus(-z) + + def _survival_function(self, value, loc=None, scale=None): + r""" + Evaluate the survival function on the given value. + + Args: + value (Tensor): The value to be evaluated. + loc (Tensor): The location of the distribution. Default: self.loc. + scale (Tensor): The scale the distribution. Default: self.scale. + + .. math:: + survival(x) = sigmoid(-(x - loc) / scale) + """ + value = self._check_value(value, 'value') + value = self.cast(value, self.dtype) + loc, scale = self._check_param_type(loc, scale) + z = (value - loc) / scale + return self.sigmoid(-z) + + def _log_survival(self, value, loc=None, scale=None): + r""" + Evaluate the log survival function on the given value. + + Args: + value (Tensor): The value to be evaluated. + loc (Tensor): The location of the distribution. Default: self.loc. + scale (Tensor): The scale the distribution. Default: self.scale. + + .. math:: + survival(x) = -softplus((x - loc) / scale) + """ + value = self._check_value(value, 'value') + value = self.cast(value, self.dtype) + loc, scale = self._check_param_type(loc, scale) + z = (value - loc) / scale + return -self.softplus(z) + + def _sample(self, shape=(), loc=None, scale=None): + """ + Sampling. + + Args: + shape (tuple): The shape of the sample. Default: (). + loc (Tensor): The location of the samples. Default: self.loc. + scale (Tensor): The scale of the samples. Default: self.scale. + + Returns: + Tensor, with the shape being shape + batch_shape. + """ + shape = self.checktuple(shape, 'shape') + loc, scale = self._check_param_type(loc, scale) + batch_shape = self.shape(loc + scale) + origin_shape = shape + batch_shape + if origin_shape == (): + sample_shape = (1,) + else: + sample_shape = origin_shape + l_zero = self.const(self.tiny) + h_one = self.const(1.0) + sample_uniform = self.uniform(sample_shape, l_zero, h_one, self.seed) + sample = self.log(sample_uniform) - self.log1p(sample_uniform) + sample = sample * scale + loc + value = self.cast(sample, self.dtype) + if origin_shape == (): + value = self.squeeze(value) + return value diff --git a/tests/st/probability/distribution/test_logistic.py b/tests/st/probability/distribution/test_logistic.py new file mode 100644 index 00000000000..2292f37ab56 --- /dev/null +++ b/tests/st/probability/distribution/test_logistic.py @@ -0,0 +1,227 @@ +# Copyright 2019 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""test cases for Logistic distribution""" +import numpy as np +from scipy import stats +import mindspore.context as context +import mindspore.nn as nn +import mindspore.nn.probability.distribution as msd +from mindspore import Tensor +from mindspore import dtype + +context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") + +class Prob(nn.Cell): + """ + Test class: probability of Logistic distribution. + """ + def __init__(self): + super(Prob, self).__init__() + self.l = msd.Logistic(np.array([3.0]), np.array([[2.0], [4.0]]), dtype=dtype.float32) + + def construct(self, x_): + return self.l.prob(x_) + +def test_pdf(): + """ + Test pdf. + """ + logistic_benchmark = stats.logistic(np.array([3.0]), np.array([[2.0], [4.0]])) + expect_pdf = logistic_benchmark.pdf([1.0, 2.0]).astype(np.float32) + pdf = Prob() + output = pdf(Tensor([1.0, 2.0], dtype=dtype.float32)) + tol = 1e-6 + assert (np.abs(output.asnumpy() - expect_pdf) < tol).all() + +class LogProb(nn.Cell): + """ + Test class: log probability of Logistic distribution. + """ + def __init__(self): + super(LogProb, self).__init__() + self.l = msd.Logistic(np.array([3.0]), np.array([[2.0], [4.0]]), dtype=dtype.float32) + + def construct(self, x_): + return self.l.log_prob(x_) + +def test_log_likelihood(): + """ + Test log_pdf. + """ + logistic_benchmark = stats.logistic(np.array([3.0]), np.array([[2.0], [4.0]])) + expect_logpdf = logistic_benchmark.logpdf([1.0, 2.0]).astype(np.float32) + logprob = LogProb() + output = logprob(Tensor([1.0, 2.0], dtype=dtype.float32)) + tol = 1e-6 + assert (np.abs(output.asnumpy() - expect_logpdf) < tol).all() + +class Basics(nn.Cell): + """ + Test class: mean/sd/mode of Logistic distribution. + """ + def __init__(self): + super(Basics, self).__init__() + self.l = msd.Logistic(np.array([3.0]), np.array([2.0, 4.0]), dtype=dtype.float32) + + def construct(self): + return self.l.mean(), self.l.sd(), self.l.mode() + +def test_basics(): + """ + Test mean/standard deviation/mode. + """ + basics = Basics() + mean, sd, mode = basics() + expect_mean = [3.0, 3.0] + expect_sd = np.pi * np.array([2.0, 4.0]) / np.sqrt(np.array([3.0])) + tol = 1e-6 + assert (np.abs(mean.asnumpy() - expect_mean) < tol).all() + assert (np.abs(mode.asnumpy() - expect_mean) < tol).all() + assert (np.abs(sd.asnumpy() - expect_sd) < tol).all() + +class Sampling(nn.Cell): + """ + Test class: sample of Logistic distribution. + """ + def __init__(self, shape, seed=0): + super(Sampling, self).__init__() + self.l = msd.Logistic(np.array([3.0]), np.array([[2.0], [4.0]]), seed=seed, dtype=dtype.float32) + self.shape = shape + + def construct(self, mean=None, sd=None): + return self.l.sample(self.shape, mean, sd) + +def test_sample(): + """ + Test sample. + """ + shape = (2, 3) + seed = 10 + mean = Tensor([2.0], dtype=dtype.float32) + sd = Tensor([2.0, 2.0, 2.0], dtype=dtype.float32) + sample = Sampling(shape, seed=seed) + output = sample(mean, sd) + assert output.shape == (2, 3, 3) + +class CDF(nn.Cell): + """ + Test class: cdf of Logistic distribution. + """ + def __init__(self): + super(CDF, self).__init__() + self.l = msd.Logistic(np.array([3.0]), np.array([[2.0], [4.0]]), dtype=dtype.float32) + + def construct(self, x_): + return self.l.cdf(x_) + + +def test_cdf(): + """ + Test cdf. + """ + logistic_benchmark = stats.logistic(np.array([3.0]), np.array([[2.0], [4.0]])) + expect_cdf = logistic_benchmark.cdf([1.0, 2.0]).astype(np.float32) + cdf = CDF() + output = cdf(Tensor([1.0, 2.0], dtype=dtype.float32)) + tol = 2e-5 + assert (np.abs(output.asnumpy() - expect_cdf) < tol).all() + +class LogCDF(nn.Cell): + """ + Test class: log_cdf of Logistic distribution. + """ + def __init__(self): + super(LogCDF, self).__init__() + self.l = msd.Logistic(np.array([3.0]), np.array([[2.0], [4.0]]), dtype=dtype.float32) + + def construct(self, x_): + return self.l.log_cdf(x_) + +def test_log_cdf(): + """ + Test log cdf. + """ + logistic_benchmark = stats.logistic(np.array([3.0]), np.array([[2.0], [4.0]])) + expect_logcdf = logistic_benchmark.logcdf([1.0, 2.0]).astype(np.float32) + logcdf = LogCDF() + output = logcdf(Tensor([1.0, 2.0], dtype=dtype.float32)) + tol = 5e-5 + assert (np.abs(output.asnumpy() - expect_logcdf) < tol).all() + +class SF(nn.Cell): + """ + Test class: survival function of Logistic distribution. + """ + def __init__(self): + super(SF, self).__init__() + self.l = msd.Logistic(np.array([3.0]), np.array([[2.0], [4.0]]), dtype=dtype.float32) + + def construct(self, x_): + return self.l.survival_function(x_) + +def test_survival(): + """ + Test log_survival. + """ + logistic_benchmark = stats.logistic(np.array([3.0]), np.array([[2.0], [4.0]])) + expect_survival = logistic_benchmark.sf([1.0, 2.0]).astype(np.float32) + survival_function = SF() + output = survival_function(Tensor([1.0, 2.0], dtype=dtype.float32)) + tol = 2e-5 + assert (np.abs(output.asnumpy() - expect_survival) < tol).all() + +class LogSF(nn.Cell): + """ + Test class: log survival function of Logistic distribution. + """ + def __init__(self): + super(LogSF, self).__init__() + self.l = msd.Logistic(np.array([3.0]), np.array([[2.0], [4.0]]), dtype=dtype.float32) + + def construct(self, x_): + return self.l.log_survival(x_) + +def test_log_survival(): + """ + Test log_survival. + """ + logistic_benchmark = stats.logistic(np.array([3.0]), np.array([[2.0], [4.0]])) + expect_log_survival = logistic_benchmark.logsf([1.0, 2.0]).astype(np.float32) + log_survival = LogSF() + output = log_survival(Tensor([1.0, 2.0], dtype=dtype.float32)) + tol = 2e-5 + assert (np.abs(output.asnumpy() - expect_log_survival) < tol).all() + +class EntropyH(nn.Cell): + """ + Test class: entropy of Logistic distribution. + """ + def __init__(self): + super(EntropyH, self).__init__() + self.l = msd.Logistic(np.array([3.0]), np.array([[2.0], [4.0]]), dtype=dtype.float32) + + def construct(self): + return self.l.entropy() + +def test_entropy(): + """ + Test entropy. + """ + logistic_benchmark = stats.logistic(np.array([3.0]), np.array([[2.0], [4.0]])) + expect_entropy = logistic_benchmark.entropy().astype(np.float32) + entropy = EntropyH() + output = entropy() + tol = 1e-6 + assert (np.abs(output.asnumpy() - expect_entropy) < tol).all() diff --git a/tests/ut/python/nn/probability/distribution/test_logistic.py b/tests/ut/python/nn/probability/distribution/test_logistic.py new file mode 100644 index 00000000000..404cf71ee2b --- /dev/null +++ b/tests/ut/python/nn/probability/distribution/test_logistic.py @@ -0,0 +1,195 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" +Test nn.probability.distribution.logistic. +""" +import pytest + +import mindspore.nn as nn +import mindspore.nn.probability.distribution as msd +from mindspore import dtype +from mindspore import Tensor + +def test_logistic_shape_errpr(): + """ + Invalid shapes. + """ + with pytest.raises(ValueError): + msd.Logistic([[2.], [1.]], [[2.], [3.], [4.]], dtype=dtype.float32) + +def test_type(): + with pytest.raises(TypeError): + msd.Logistic(0., 1., dtype=dtype.int32) + +def test_name(): + with pytest.raises(TypeError): + msd.Logistic(0., 1., name=1.0) + +def test_seed(): + with pytest.raises(TypeError): + msd.Logistic(0., 1., seed='seed') + +def test_scale(): + with pytest.raises(ValueError): + msd.Logistic(0., 0.) + with pytest.raises(ValueError): + msd.Logistic(0., -1.) + +def test_arguments(): + """ + args passing during initialization. + """ + l = msd.Logistic() + assert isinstance(l, msd.Distribution) + l = msd.Logistic([3.0], [4.0], dtype=dtype.float32) + assert isinstance(l, msd.Distribution) + + +class LogisticProb(nn.Cell): + """ + logistic distribution: initialize with loc/scale. + """ + def __init__(self): + super(LogisticProb, self).__init__() + self.logistic = msd.Logistic(3.0, 4.0, dtype=dtype.float32) + + def construct(self, value): + prob = self.logistic.prob(value) + log_prob = self.logistic.log_prob(value) + cdf = self.logistic.cdf(value) + log_cdf = self.logistic.log_cdf(value) + sf = self.logistic.survival_function(value) + log_sf = self.logistic.log_survival(value) + return prob + log_prob + cdf + log_cdf + sf + log_sf + +def test_logistic_prob(): + """ + Test probability functions: passing value through construct. + """ + net = LogisticProb() + value = Tensor([0.5, 1.0], dtype=dtype.float32) + ans = net(value) + assert isinstance(ans, Tensor) + + +class LogisticProb1(nn.Cell): + """ + logistic distribution: initialize without loc/scale. + """ + def __init__(self): + super(LogisticProb1, self).__init__() + self.logistic = msd.Logistic() + + def construct(self, value, mu, s): + prob = self.logistic.prob(value, mu, s) + log_prob = self.logistic.log_prob(value, mu, s) + cdf = self.logistic.cdf(value, mu, s) + log_cdf = self.logistic.log_cdf(value, mu, s) + sf = self.logistic.survival_function(value, mu, s) + log_sf = self.logistic.log_survival(value, mu, s) + return prob + log_prob + cdf + log_cdf + sf + log_sf + +def test_logistic_prob1(): + """ + Test probability functions: passing loc/scale, value through construct. + """ + net = LogisticProb1() + value = Tensor([0.5, 1.0], dtype=dtype.float32) + mu = Tensor([0.0], dtype=dtype.float32) + s = Tensor([1.0], dtype=dtype.float32) + ans = net(value, mu, s) + assert isinstance(ans, Tensor) + +class KL(nn.Cell): + """ + Test kl_loss. Should raise NotImplementedError. + """ + def __init__(self): + super(KL, self).__init__() + self.logistic = msd.Logistic(3.0, 4.0) + + def construct(self, mu, s): + kl = self.logistic.kl_loss('Logistic', mu, s) + return kl + +class Crossentropy(nn.Cell): + """ + Test cross entropy. Should raise NotImplementedError. + """ + def __init__(self): + super(Crossentropy, self).__init__() + self.logistic = msd.Logistic(3.0, 4.0) + + def construct(self, mu, s): + cross_entropy = self.logistic.cross_entropy('Logistic', mu, s) + return cross_entropy + + +class LogisticBasics(nn.Cell): + """ + Test class: basic loc/scale function. + """ + def __init__(self): + super(LogisticBasics, self).__init__() + self.logistic = msd.Logistic(3.0, 4.0, dtype=dtype.float32) + + def construct(self): + mean = self.logistic.mean() + sd = self.logistic.sd() + mode = self.logistic.mode() + entropy = self.logistic.entropy() + return mean + sd + mode + entropy + +def test_bascis(): + """ + Test mean/sd/mode/entropy functionality of logistic. + """ + net = LogisticBasics() + ans = net() + assert isinstance(ans, Tensor) + mu = Tensor(1.0, dtype=dtype.float32) + s = Tensor(1.0, dtype=dtype.float32) + with pytest.raises(NotImplementedError): + kl = KL() + ans = kl(mu, s) + with pytest.raises(NotImplementedError): + crossentropy = Crossentropy() + ans = crossentropy(mu, s) + +class LogisticConstruct(nn.Cell): + """ + logistic distribution: going through construct. + """ + def __init__(self): + super(LogisticConstruct, self).__init__() + self.logistic = msd.Logistic(3.0, 4.0) + self.logistic1 = msd.Logistic() + + def construct(self, value, mu, s): + prob = self.logistic('prob', value) + prob1 = self.logistic('prob', value, mu, s) + prob2 = self.logistic1('prob', value, mu, s) + return prob + prob1 + prob2 + +def test_logistic_construct(): + """ + Test probability function going through construct. + """ + net = LogisticConstruct() + value = Tensor([0.5, 1.0], dtype=dtype.float32) + mu = Tensor([0.0], dtype=dtype.float32) + s = Tensor([1.0], dtype=dtype.float32) + ans = net(value, mu, s) + assert isinstance(ans, Tensor)