From 9e7d6e2397a2840b30ccfd13eee46f53876f2da4 Mon Sep 17 00:00:00 2001 From: Zichun Ye Date: Sun, 23 Aug 2020 15:40:09 -0400 Subject: [PATCH] fix bernoulli prob formula; fix some other minor bugs update threshold of softplus computation support fp for bernoulli and geometric distribution --- .../probability/bijector/power_transform.py | 10 ++-- .../nn/probability/bijector/scalar_affine.py | 10 ++-- mindspore/nn/probability/bijector/softplus.py | 13 +++-- .../probability/distribution/_utils/utils.py | 50 ++++++++++++------- .../nn/probability/distribution/bernoulli.py | 4 +- .../nn/probability/distribution/geometric.py | 7 ++- .../python/nn/distribution/test_bernoulli.py | 37 +++++++++++++- .../python/nn/distribution/test_geometric.py | 22 +++++++- 8 files changed, 113 insertions(+), 40 deletions(-) diff --git a/mindspore/nn/probability/bijector/power_transform.py b/mindspore/nn/probability/bijector/power_transform.py index c6e2b9a635e..76b172a1953 100644 --- a/mindspore/nn/probability/bijector/power_transform.py +++ b/mindspore/nn/probability/bijector/power_transform.py @@ -20,6 +20,7 @@ from ..distribution._utils.utils import CheckTensor from ..distribution._utils.custom_ops import exp_by_step, expm1_by_step, log_by_step, log1p_by_step from .bijector import Bijector + class PowerTransform(Bijector): r""" Power Bijector. @@ -49,6 +50,7 @@ class PowerTransform(Bijector): >>> # by replacing 'forward' with the name of the function >>> ans = self.p1.forward(, value) """ + def __init__(self, power=0, name='PowerTransform', @@ -78,13 +80,13 @@ class PowerTransform(Bijector): return shape def _forward(self, x): - self.checktensor(x, 'x') + self.checktensor(x, 'value') if self.power == 0: return self.exp(x) return self.exp(self.log1p(x * self.power) / self.power) def _inverse(self, y): - self.checktensor(y, 'y') + self.checktensor(y, 'value') if self.power == 0: return self.log(y) return self.expm1(self.log(y) * self.power) / self.power @@ -101,7 +103,7 @@ class PowerTransform(Bijector): f'(x) = e^\frac{\log(xc + 1)}{c} * \frac{1}{xc + 1} \log(f'(x)) = (\frac{1}{c} - 1) * \log(xc + 1) """ - self.checktensor(x, 'x') + self.checktensor(x, 'value') if self.power == 0: return x return (1. / self.power - 1) * self.log1p(x * self.power) @@ -118,5 +120,5 @@ class PowerTransform(Bijector): f'(x) = \frac{e^c\log(y)}{y} \log(f'(x)) = \log(\frac{e^c\log(y)}{y}) = (c-1) * \log(y) """ - self.checktensor(y, 'y') + self.checktensor(y, 'value') return (self.power - 1) * self.log(y) diff --git a/mindspore/nn/probability/bijector/scalar_affine.py b/mindspore/nn/probability/bijector/scalar_affine.py index 276009b5fce..9707c462964 100644 --- a/mindspore/nn/probability/bijector/scalar_affine.py +++ b/mindspore/nn/probability/bijector/scalar_affine.py @@ -19,6 +19,7 @@ from ..distribution._utils.utils import cast_to_tensor, CheckTensor from ..distribution._utils.custom_ops import log_by_step from .bijector import Bijector + class ScalarAffine(Bijector): """ Scalar Affine Bijector. @@ -47,6 +48,7 @@ class ScalarAffine(Bijector): >>> ans = self.s1.forward_log_jacobian(value) >>> ans = self.s1.inverse_log_jacobian(value) """ + def __init__(self, scale=1.0, shift=0.0, @@ -91,7 +93,7 @@ class ScalarAffine(Bijector): .. math:: f(x) = a * x + b """ - self.checktensor(x, 'x') + self.checktensor(x, 'value') return self.scale * x + self.shift def _inverse(self, y): @@ -99,7 +101,7 @@ class ScalarAffine(Bijector): .. math:: f(y) = \frac{y - b}{a} """ - self.checktensor(y, 'y') + self.checktensor(y, 'value') return (y - self.shift) / self.scale def _forward_log_jacobian(self, x): @@ -109,7 +111,7 @@ class ScalarAffine(Bijector): f'(x) = a \log(f'(x)) = \log(a) """ - self.checktensor(x, 'x') + self.checktensor(x, 'value') return self.log(self.abs(self.scale)) def _inverse_log_jacobian(self, y): @@ -119,5 +121,5 @@ class ScalarAffine(Bijector): f'(x) = \frac{1.0}{a} \log(f'(x)) = - \log(a) """ - self.checktensor(y, 'y') + self.checktensor(y, 'value') return -1. * self.log(self.abs(self.scale)) diff --git a/mindspore/nn/probability/bijector/softplus.py b/mindspore/nn/probability/bijector/softplus.py index 9c0fc4e5f8a..f2396b88cd1 100644 --- a/mindspore/nn/probability/bijector/softplus.py +++ b/mindspore/nn/probability/bijector/softplus.py @@ -22,6 +22,7 @@ from ..distribution._utils.utils import cast_to_tensor, CheckTensor from ..distribution._utils.custom_ops import exp_by_step, expm1_by_step, log_by_step from .bijector import Bijector + class Softplus(Bijector): r""" Softplus Bijector. @@ -51,6 +52,7 @@ class Softplus(Bijector): >>> ans = self.sp1.forward_log_jacobian(value) >>> ans = self.sp1.inverse_log_jacobian(value) """ + def __init__(self, sharpness=1.0, name='Softplus'): @@ -76,6 +78,7 @@ class Softplus(Bijector): self.checktensor = CheckTensor() self.threshold = np.log(np.finfo(np.float32).eps) + 1 + self.tiny = np.exp(self.threshold) def _softplus(self, x): too_small = self.less(x, self.threshold) @@ -94,7 +97,7 @@ class Softplus(Bijector): f(x) = \frac{\log(1 + e^{x}))} f^{-1}(y) = \frac{\log(e^{y} - 1)} """ - too_small = self.less(x, self.threshold) + too_small = self.less(x, self.tiny) too_large = self.greater(x, -self.threshold) too_small_value = self.log(x) too_large_value = x @@ -116,7 +119,7 @@ class Softplus(Bijector): return shape def _forward(self, x): - self.checktensor(x, 'x') + self.checktensor(x, 'value') scaled_value = self.sharpness * x return self.softplus(scaled_value) / self.sharpness @@ -126,7 +129,7 @@ class Softplus(Bijector): f(x) = \frac{\log(1 + e^{kx}))}{k} f^{-1}(y) = \frac{\log(e^{ky} - 1)}{k} """ - self.checktensor(y, 'y') + self.checktensor(y, 'value') scaled_value = self.sharpness * y return self.inverse_softplus(scaled_value) / self.sharpness @@ -137,7 +140,7 @@ class Softplus(Bijector): f'(x) = \frac{e^{kx}}{ 1 + e^{kx}} \log(f'(x)) = kx - \log(1 + e^{kx}) = kx - f(kx) """ - self.checktensor(x, 'x') + self.checktensor(x, 'value') scaled_value = self.sharpness * x return self.log_sigmoid(scaled_value) @@ -148,6 +151,6 @@ class Softplus(Bijector): f'(y) = \frac{e^{ky}}{e^{ky} - 1} \log(f'(y)) = ky - \log(e^{ky} - 1) = ky - f(ky) """ - self.checktensor(y, 'y') + self.checktensor(y, 'value') scaled_value = self.sharpness * y return scaled_value - self.inverse_softplus(scaled_value) diff --git a/mindspore/nn/probability/distribution/_utils/utils.py b/mindspore/nn/probability/distribution/_utils/utils.py index 2da4ca30d1e..9ad95394782 100644 --- a/mindspore/nn/probability/distribution/_utils/utils.py +++ b/mindspore/nn/probability/distribution/_utils/utils.py @@ -26,6 +26,7 @@ from mindspore import context import mindspore.nn as nn import mindspore.nn.probability as msp + def cast_to_tensor(t, hint_type=mstype.float32): """ Cast an user input value into a Tensor of dtype. @@ -47,7 +48,7 @@ def cast_to_tensor(t, hint_type=mstype.float32): return t t_type = hint_type if isinstance(t, Tensor): - #convert the type of tensor to dtype + # convert the type of tensor to dtype return Tensor(t.asnumpy(), dtype=t_type) if isinstance(t, (list, np.ndarray)): return Tensor(t, dtype=t_type) @@ -56,7 +57,8 @@ def cast_to_tensor(t, hint_type=mstype.float32): if isinstance(t, (int, float)): return Tensor(t, dtype=t_type) invalid_type = type(t) - raise TypeError(f"Unable to convert input of type {invalid_type} to a Tensor of type {t_type}") + raise TypeError( + f"Unable to convert input of type {invalid_type} to a Tensor of type {t_type}") def convert_to_batch(t, batch_shape, required_type): @@ -79,6 +81,7 @@ def convert_to_batch(t, batch_shape, required_type): t = cast_to_tensor(t, required_type) return Tensor(np.broadcast_to(t.asnumpy(), batch_shape), dtype=required_type) + def check_scalar_from_param(params): """ Check if params are all scalars. @@ -93,11 +96,7 @@ def check_scalar_from_param(params): return params['distribution'].is_scalar_batch if isinstance(value, Parameter): return False - if isinstance(value, (str, type(params['dtype']))): - continue - elif isinstance(value, (int, float)): - continue - else: + if not isinstance(value, (int, float, str, type(params['dtype']))): return False return True @@ -124,7 +123,8 @@ def calc_broadcast_shape_from_param(params): value_t = value.default_input else: value_t = cast_to_tensor(value, mstype.float32) - broadcast_shape = utils.get_broadcast_shape(broadcast_shape, list(value_t.shape), params['name']) + broadcast_shape = utils.get_broadcast_shape( + broadcast_shape, list(value_t.shape), params['name']) return tuple(broadcast_shape) @@ -148,6 +148,7 @@ def check_greater_equal_zero(value, name): if comp.any(): raise ValueError(f'{name} should be greater than ot equal to zero.') + def check_greater_zero(value, name): """ Check if the given Tensor is strictly greater than zero. @@ -251,6 +252,7 @@ def probs_to_logits(probs, is_binary=False): return P.Log()(ps_clamped) - P.Log()(1-ps_clamped) return P.Log()(ps_clamped) + def check_tensor_type(name, inputs, valid_type): """ Check if inputs is proper. @@ -268,25 +270,34 @@ def check_tensor_type(name, inputs, valid_type): if input_type not in valid_type: raise TypeError(f"{name} dtype is invalid") + def check_type(data_type, value_type, name): if not data_type in value_type: - raise TypeError(f"For {name}, valid type include {value_type}, {data_type} is invalid") + raise TypeError( + f"For {name}, valid type include {value_type}, {data_type} is invalid") + @constexpr def raise_none_error(name): raise TypeError(f"the type {name} should be subclass of Tensor." f" It should not be None since it is not specified during initialization.") + @constexpr def raise_not_impl_error(name): - raise ValueError(f"{name} function should be implemented for non-linear transformation") + raise ValueError( + f"{name} function should be implemented for non-linear transformation") + @constexpr def check_distribution_name(name, expected_name): if name is None: - raise ValueError(f"Distribution should be a constant which is not None.") + raise ValueError( + f"Distribution should be a constant which is not None.") if name != expected_name: - raise ValueError(f"Expected distribution name is {expected_name}, but got {name}.") + raise ValueError( + f"Expected distribution name is {expected_name}, but got {name}.") + class CheckTuple(PrimitiveWithInfer): """ @@ -294,13 +305,13 @@ class CheckTuple(PrimitiveWithInfer): """ @prim_attr_register def __init__(self): - """init Cast""" super(CheckTuple, self).__init__("CheckTuple") - self.init_prim_io_names(inputs=['x'], outputs=['dummy_output']) + self.init_prim_io_names(inputs=['x', 'name'], outputs=['dummy_output']) def __infer__(self, x, name): if not isinstance(x['dtype'], tuple): - raise TypeError(f"For {name['value']}, Input type should b a tuple.") + raise TypeError( + f"For {name['value']}, Input type should b a tuple.") out = {'shape': None, 'dtype': None, @@ -310,24 +321,25 @@ class CheckTuple(PrimitiveWithInfer): def __call__(self, x, name): if context.get_context("mode") == 0: return x["value"] - #Pynative mode + # Pynative mode if isinstance(x, tuple): return x raise TypeError(f"For {name['value']}, Input type should b a tuple.") + class CheckTensor(PrimitiveWithInfer): """ Check if input is a Tensor. """ @prim_attr_register def __init__(self): - """init Cast""" super(CheckTensor, self).__init__("CheckTensor") - self.init_prim_io_names(inputs=['x'], outputs=['dummy_output']) + self.init_prim_io_names(inputs=['x', 'name'], outputs=['dummy_output']) def __infer__(self, x, name): src_type = x['dtype'] - validator.check_subclass("input", src_type, [mstype.tensor], name["value"]) + validator.check_subclass( + "input", src_type, [mstype.tensor], name["value"]) out = {'shape': None, 'dtype': None, diff --git a/mindspore/nn/probability/distribution/bernoulli.py b/mindspore/nn/probability/distribution/bernoulli.py index ee673833f32..d1b48ad46c2 100644 --- a/mindspore/nn/probability/distribution/bernoulli.py +++ b/mindspore/nn/probability/distribution/bernoulli.py @@ -20,6 +20,7 @@ from .distribution import Distribution from ._utils.utils import cast_to_tensor, check_prob, check_type, check_distribution_name, raise_none_error from ._utils.custom_ops import exp_by_step, log_by_step + class Bernoulli(Distribution): """ Bernoulli Distribution. @@ -97,7 +98,7 @@ class Bernoulli(Distribution): Constructor of Bernoulli distribution. """ param = dict(locals()) - valid_dtype = mstype.int_type + mstype.uint_type + valid_dtype = mstype.int_type + mstype.uint_type + mstype.float_type check_type(dtype, valid_dtype, "Bernoulli") super(Bernoulli, self).__init__(seed, dtype, name, param) self.parameter_type = mstype.float32 @@ -211,7 +212,6 @@ class Bernoulli(Distribution): """ self.checktensor(value, 'value') value = self.cast(value, mstype.float32) - value = self.floor(value) probs1 = self._check_param(probs1) probs0 = 1.0 - probs1 return self.log(probs1) * value + self.log(probs0) * (1.0 - value) diff --git a/mindspore/nn/probability/distribution/geometric.py b/mindspore/nn/probability/distribution/geometric.py index 19531aad44b..e5315797096 100644 --- a/mindspore/nn/probability/distribution/geometric.py +++ b/mindspore/nn/probability/distribution/geometric.py @@ -19,9 +19,10 @@ from mindspore.ops import composite as C from mindspore.common import dtype as mstype from .distribution import Distribution from ._utils.utils import cast_to_tensor, check_prob, check_type, check_distribution_name,\ - raise_none_error + raise_none_error from ._utils.custom_ops import exp_by_step, log_by_step + class Geometric(Distribution): """ Geometric Distribution. @@ -100,7 +101,7 @@ class Geometric(Distribution): Constructor of Geometric distribution. """ param = dict(locals()) - valid_dtype = mstype.int_type + mstype.uint_type + valid_dtype = mstype.int_type + mstype.uint_type + mstype.float_type check_type(dtype, valid_dtype, "Geometric") super(Geometric, self).__init__(seed, dtype, name, param) self.parameter_type = mstype.float32 @@ -130,7 +131,6 @@ class Geometric(Distribution): self.sqrt = P.Sqrt() self.uniform = C.uniform - def extend_repr(self): if self.is_scalar_batch: str_info = f'probs = {self.probs}' @@ -243,7 +243,6 @@ class Geometric(Distribution): comp = self.less(value, zeros) return self.select(comp, zeros, cdf) - def _kl_loss(self, dist, probs1_b, probs1=None): r""" Evaluate Geometric-Geometric kl divergence, i.e. KL(a||b). diff --git a/tests/ut/python/nn/distribution/test_bernoulli.py b/tests/ut/python/nn/distribution/test_bernoulli.py index 29fe10a8447..c167f571dbb 100644 --- a/tests/ut/python/nn/distribution/test_bernoulli.py +++ b/tests/ut/python/nn/distribution/test_bernoulli.py @@ -22,6 +22,7 @@ import mindspore.nn.probability.distribution as msd from mindspore import dtype from mindspore import Tensor + def test_arguments(): """ Args passing during initialization. @@ -31,18 +32,22 @@ def test_arguments(): b = msd.Bernoulli([0.1, 0.3, 0.5, 0.9], dtype=dtype.int32) assert isinstance(b, msd.Distribution) + def test_type(): with pytest.raises(TypeError): - msd.Bernoulli([0.1], dtype=dtype.float32) + msd.Bernoulli([0.1], dtype=dtype.bool_) + def test_name(): with pytest.raises(TypeError): msd.Bernoulli([0.1], name=1.0) + def test_seed(): with pytest.raises(TypeError): msd.Bernoulli([0.1], seed='seed') + def test_prob(): """ Invalid probability. @@ -56,10 +61,12 @@ def test_prob(): with pytest.raises(ValueError): msd.Bernoulli([1.0], dtype=dtype.int32) + class BernoulliProb(nn.Cell): """ Bernoulli distribution: initialize with probs. """ + def __init__(self): super(BernoulliProb, self).__init__() self.b = msd.Bernoulli(0.5, dtype=dtype.int32) @@ -73,6 +80,7 @@ class BernoulliProb(nn.Cell): log_sf = self.b.log_survival(value) return prob + log_prob + cdf + log_cdf + sf + log_sf + def test_bernoulli_prob(): """ Test probability functions: passing value through construct. @@ -82,10 +90,12 @@ def test_bernoulli_prob(): ans = net(value) assert isinstance(ans, Tensor) + class BernoulliProb1(nn.Cell): """ Bernoulli distribution: initialize without probs. """ + def __init__(self): super(BernoulliProb1, self).__init__() self.b = msd.Bernoulli(dtype=dtype.int32) @@ -99,6 +109,7 @@ class BernoulliProb1(nn.Cell): log_sf = self.b.log_survival(value, probs) return prob + log_prob + cdf + log_cdf + sf + log_sf + def test_bernoulli_prob1(): """ Test probability functions: passing value/probs through construct. @@ -109,10 +120,12 @@ def test_bernoulli_prob1(): ans = net(value, probs) assert isinstance(ans, Tensor) + class BernoulliKl(nn.Cell): """ Test class: kl_loss between Bernoulli distributions. """ + def __init__(self): super(BernoulliKl, self).__init__() self.b1 = msd.Bernoulli(0.7, dtype=dtype.int32) @@ -123,6 +136,7 @@ class BernoulliKl(nn.Cell): kl2 = self.b2.kl_loss('Bernoulli', probs_b, probs_a) return kl1 + kl2 + def test_kl(): """ Test kl_loss function. @@ -133,10 +147,12 @@ def test_kl(): ans = ber_net(probs_b, probs_a) assert isinstance(ans, Tensor) + class BernoulliCrossEntropy(nn.Cell): """ Test class: cross_entropy of Bernoulli distribution. """ + def __init__(self): super(BernoulliCrossEntropy, self).__init__() self.b1 = msd.Bernoulli(0.7, dtype=dtype.int32) @@ -147,6 +163,7 @@ class BernoulliCrossEntropy(nn.Cell): h2 = self.b2.cross_entropy('Bernoulli', probs_b, probs_a) return h1 + h2 + def test_cross_entropy(): """ Test cross_entropy between Bernoulli distributions. @@ -157,10 +174,12 @@ def test_cross_entropy(): ans = net(probs_b, probs_a) assert isinstance(ans, Tensor) + class BernoulliConstruct(nn.Cell): """ Bernoulli distribution: going through construct. """ + def __init__(self): super(BernoulliConstruct, self).__init__() self.b = msd.Bernoulli(0.5, dtype=dtype.int32) @@ -172,6 +191,7 @@ class BernoulliConstruct(nn.Cell): prob2 = self.b1('prob', value, probs) return prob + prob1 + prob2 + def test_bernoulli_construct(): """ Test probability function going through construct. @@ -182,10 +202,12 @@ def test_bernoulli_construct(): ans = net(value, probs) assert isinstance(ans, Tensor) + class BernoulliMean(nn.Cell): """ Test class: basic mean/sd/var/mode/entropy function. """ + def __init__(self): super(BernoulliMean, self).__init__() self.b = msd.Bernoulli([0.3, 0.5], dtype=dtype.int32) @@ -194,6 +216,7 @@ class BernoulliMean(nn.Cell): mean = self.b.mean() return mean + def test_mean(): """ Test mean/sd/var/mode/entropy functionality of Bernoulli distribution. @@ -202,10 +225,12 @@ def test_mean(): ans = net() assert isinstance(ans, Tensor) + class BernoulliSd(nn.Cell): """ Test class: basic mean/sd/var/mode/entropy function. """ + def __init__(self): super(BernoulliSd, self).__init__() self.b = msd.Bernoulli([0.3, 0.5], dtype=dtype.int32) @@ -214,6 +239,7 @@ class BernoulliSd(nn.Cell): sd = self.b.sd() return sd + def test_sd(): """ Test mean/sd/var/mode/entropy functionality of Bernoulli distribution. @@ -222,10 +248,12 @@ def test_sd(): ans = net() assert isinstance(ans, Tensor) + class BernoulliVar(nn.Cell): """ Test class: basic mean/sd/var/mode/entropy function. """ + def __init__(self): super(BernoulliVar, self).__init__() self.b = msd.Bernoulli([0.3, 0.5], dtype=dtype.int32) @@ -234,6 +262,7 @@ class BernoulliVar(nn.Cell): var = self.b.var() return var + def test_var(): """ Test mean/sd/var/mode/entropy functionality of Bernoulli distribution. @@ -242,10 +271,12 @@ def test_var(): ans = net() assert isinstance(ans, Tensor) + class BernoulliMode(nn.Cell): """ Test class: basic mean/sd/var/mode/entropy function. """ + def __init__(self): super(BernoulliMode, self).__init__() self.b = msd.Bernoulli([0.3, 0.5], dtype=dtype.int32) @@ -254,6 +285,7 @@ class BernoulliMode(nn.Cell): mode = self.b.mode() return mode + def test_mode(): """ Test mean/sd/var/mode/entropy functionality of Bernoulli distribution. @@ -262,10 +294,12 @@ def test_mode(): ans = net() assert isinstance(ans, Tensor) + class BernoulliEntropy(nn.Cell): """ Test class: basic mean/sd/var/mode/entropy function. """ + def __init__(self): super(BernoulliEntropy, self).__init__() self.b = msd.Bernoulli([0.3, 0.5], dtype=dtype.int32) @@ -274,6 +308,7 @@ class BernoulliEntropy(nn.Cell): entropy = self.b.entropy() return entropy + def test_entropy(): """ Test mean/sd/var/mode/entropy functionality of Bernoulli distribution. diff --git a/tests/ut/python/nn/distribution/test_geometric.py b/tests/ut/python/nn/distribution/test_geometric.py index 4685adaa422..02f4033863f 100644 --- a/tests/ut/python/nn/distribution/test_geometric.py +++ b/tests/ut/python/nn/distribution/test_geometric.py @@ -32,18 +32,22 @@ def test_arguments(): g = msd.Geometric([0.1, 0.3, 0.5, 0.9], dtype=dtype.int32) assert isinstance(g, msd.Distribution) + def test_type(): with pytest.raises(TypeError): - msd.Geometric([0.1], dtype=dtype.float32) + msd.Geometric([0.1], dtype=dtype.bool_) + def test_name(): with pytest.raises(TypeError): msd.Geometric([0.1], name=1.0) + def test_seed(): with pytest.raises(TypeError): msd.Geometric([0.1], seed='seed') + def test_prob(): """ Invalid probability. @@ -57,10 +61,12 @@ def test_prob(): with pytest.raises(ValueError): msd.Geometric([1.0], dtype=dtype.int32) + class GeometricProb(nn.Cell): """ Geometric distribution: initialize with probs. """ + def __init__(self): super(GeometricProb, self).__init__() self.g = msd.Geometric(0.5, dtype=dtype.int32) @@ -74,6 +80,7 @@ class GeometricProb(nn.Cell): log_sf = self.g.log_survival(value) return prob + log_prob + cdf + log_cdf + sf + log_sf + def test_geometric_prob(): """ Test probability functions: passing value through construct. @@ -83,10 +90,12 @@ def test_geometric_prob(): ans = net(value) assert isinstance(ans, Tensor) + class GeometricProb1(nn.Cell): """ Geometric distribution: initialize without probs. """ + def __init__(self): super(GeometricProb1, self).__init__() self.g = msd.Geometric(dtype=dtype.int32) @@ -100,6 +109,7 @@ class GeometricProb1(nn.Cell): log_sf = self.g.log_survival(value, probs) return prob + log_prob + cdf + log_cdf + sf + log_sf + def test_geometric_prob1(): """ Test probability functions: passing value/probs through construct. @@ -115,6 +125,7 @@ class GeometricKl(nn.Cell): """ Test class: kl_loss between Geometric distributions. """ + def __init__(self): super(GeometricKl, self).__init__() self.g1 = msd.Geometric(0.7, dtype=dtype.int32) @@ -125,6 +136,7 @@ class GeometricKl(nn.Cell): kl2 = self.g2.kl_loss('Geometric', probs_b, probs_a) return kl1 + kl2 + def test_kl(): """ Test kl_loss function. @@ -135,10 +147,12 @@ def test_kl(): ans = ber_net(probs_b, probs_a) assert isinstance(ans, Tensor) + class GeometricCrossEntropy(nn.Cell): """ Test class: cross_entropy of Geometric distribution. """ + def __init__(self): super(GeometricCrossEntropy, self).__init__() self.g1 = msd.Geometric(0.3, dtype=dtype.int32) @@ -149,6 +163,7 @@ class GeometricCrossEntropy(nn.Cell): h2 = self.g2.cross_entropy('Geometric', probs_b, probs_a) return h1 + h2 + def test_cross_entropy(): """ Test cross_entropy between Geometric distributions. @@ -159,10 +174,12 @@ def test_cross_entropy(): ans = net(probs_b, probs_a) assert isinstance(ans, Tensor) + class GeometricBasics(nn.Cell): """ Test class: basic mean/sd/mode/entropy function. """ + def __init__(self): super(GeometricBasics, self).__init__() self.g = msd.Geometric([0.3, 0.5], dtype=dtype.int32) @@ -175,6 +192,7 @@ class GeometricBasics(nn.Cell): entropy = self.g.entropy() return mean + sd + var + mode + entropy + def test_bascis(): """ Test mean/sd/mode/entropy functionality of Geometric distribution. @@ -188,6 +206,7 @@ class GeoConstruct(nn.Cell): """ Bernoulli distribution: going through construct. """ + def __init__(self): super(GeoConstruct, self).__init__() self.g = msd.Geometric(0.5, dtype=dtype.int32) @@ -199,6 +218,7 @@ class GeoConstruct(nn.Cell): prob2 = self.g1('prob', value, probs) return prob + prob1 + prob2 + def test_geo_construct(): """ Test probability function going through construct.