!5005 Fix the bug in the formula in Bernoulli log_probs

Merge pull request !5005 from zichun_ye/fix_bernoulli_probs
This commit is contained in:
mindspore-ci-bot 2020-08-24 09:09:43 +08:00 committed by Gitee
commit 56835aaf88
8 changed files with 113 additions and 40 deletions

View File

@ -20,6 +20,7 @@ from ..distribution._utils.utils import CheckTensor
from ..distribution._utils.custom_ops import exp_by_step, expm1_by_step, log_by_step, log1p_by_step from ..distribution._utils.custom_ops import exp_by_step, expm1_by_step, log_by_step, log1p_by_step
from .bijector import Bijector from .bijector import Bijector
class PowerTransform(Bijector): class PowerTransform(Bijector):
r""" r"""
Power Bijector. Power Bijector.
@ -49,6 +50,7 @@ class PowerTransform(Bijector):
>>> # by replacing 'forward' with the name of the function >>> # by replacing 'forward' with the name of the function
>>> ans = self.p1.forward(, value) >>> ans = self.p1.forward(, value)
""" """
def __init__(self, def __init__(self,
power=0, power=0,
name='PowerTransform', name='PowerTransform',
@ -78,13 +80,13 @@ class PowerTransform(Bijector):
return shape return shape
def _forward(self, x): def _forward(self, x):
self.checktensor(x, 'x') self.checktensor(x, 'value')
if self.power == 0: if self.power == 0:
return self.exp(x) return self.exp(x)
return self.exp(self.log1p(x * self.power) / self.power) return self.exp(self.log1p(x * self.power) / self.power)
def _inverse(self, y): def _inverse(self, y):
self.checktensor(y, 'y') self.checktensor(y, 'value')
if self.power == 0: if self.power == 0:
return self.log(y) return self.log(y)
return self.expm1(self.log(y) * self.power) / self.power return self.expm1(self.log(y) * self.power) / self.power
@ -101,7 +103,7 @@ class PowerTransform(Bijector):
f'(x) = e^\frac{\log(xc + 1)}{c} * \frac{1}{xc + 1} f'(x) = e^\frac{\log(xc + 1)}{c} * \frac{1}{xc + 1}
\log(f'(x)) = (\frac{1}{c} - 1) * \log(xc + 1) \log(f'(x)) = (\frac{1}{c} - 1) * \log(xc + 1)
""" """
self.checktensor(x, 'x') self.checktensor(x, 'value')
if self.power == 0: if self.power == 0:
return x return x
return (1. / self.power - 1) * self.log1p(x * self.power) return (1. / self.power - 1) * self.log1p(x * self.power)
@ -118,5 +120,5 @@ class PowerTransform(Bijector):
f'(x) = \frac{e^c\log(y)}{y} f'(x) = \frac{e^c\log(y)}{y}
\log(f'(x)) = \log(\frac{e^c\log(y)}{y}) = (c-1) * \log(y) \log(f'(x)) = \log(\frac{e^c\log(y)}{y}) = (c-1) * \log(y)
""" """
self.checktensor(y, 'y') self.checktensor(y, 'value')
return (self.power - 1) * self.log(y) return (self.power - 1) * self.log(y)

View File

@ -19,6 +19,7 @@ from ..distribution._utils.utils import cast_to_tensor, CheckTensor
from ..distribution._utils.custom_ops import log_by_step from ..distribution._utils.custom_ops import log_by_step
from .bijector import Bijector from .bijector import Bijector
class ScalarAffine(Bijector): class ScalarAffine(Bijector):
""" """
Scalar Affine Bijector. Scalar Affine Bijector.
@ -47,6 +48,7 @@ class ScalarAffine(Bijector):
>>> ans = self.s1.forward_log_jacobian(value) >>> ans = self.s1.forward_log_jacobian(value)
>>> ans = self.s1.inverse_log_jacobian(value) >>> ans = self.s1.inverse_log_jacobian(value)
""" """
def __init__(self, def __init__(self,
scale=1.0, scale=1.0,
shift=0.0, shift=0.0,
@ -91,7 +93,7 @@ class ScalarAffine(Bijector):
.. math:: .. math::
f(x) = a * x + b f(x) = a * x + b
""" """
self.checktensor(x, 'x') self.checktensor(x, 'value')
return self.scale * x + self.shift return self.scale * x + self.shift
def _inverse(self, y): def _inverse(self, y):
@ -99,7 +101,7 @@ class ScalarAffine(Bijector):
.. math:: .. math::
f(y) = \frac{y - b}{a} f(y) = \frac{y - b}{a}
""" """
self.checktensor(y, 'y') self.checktensor(y, 'value')
return (y - self.shift) / self.scale return (y - self.shift) / self.scale
def _forward_log_jacobian(self, x): def _forward_log_jacobian(self, x):
@ -109,7 +111,7 @@ class ScalarAffine(Bijector):
f'(x) = a f'(x) = a
\log(f'(x)) = \log(a) \log(f'(x)) = \log(a)
""" """
self.checktensor(x, 'x') self.checktensor(x, 'value')
return self.log(self.abs(self.scale)) return self.log(self.abs(self.scale))
def _inverse_log_jacobian(self, y): def _inverse_log_jacobian(self, y):
@ -119,5 +121,5 @@ class ScalarAffine(Bijector):
f'(x) = \frac{1.0}{a} f'(x) = \frac{1.0}{a}
\log(f'(x)) = - \log(a) \log(f'(x)) = - \log(a)
""" """
self.checktensor(y, 'y') self.checktensor(y, 'value')
return -1. * self.log(self.abs(self.scale)) return -1. * self.log(self.abs(self.scale))

View File

@ -22,6 +22,7 @@ from ..distribution._utils.utils import cast_to_tensor, CheckTensor
from ..distribution._utils.custom_ops import exp_by_step, expm1_by_step, log_by_step from ..distribution._utils.custom_ops import exp_by_step, expm1_by_step, log_by_step
from .bijector import Bijector from .bijector import Bijector
class Softplus(Bijector): class Softplus(Bijector):
r""" r"""
Softplus Bijector. Softplus Bijector.
@ -51,6 +52,7 @@ class Softplus(Bijector):
>>> ans = self.sp1.forward_log_jacobian(value) >>> ans = self.sp1.forward_log_jacobian(value)
>>> ans = self.sp1.inverse_log_jacobian(value) >>> ans = self.sp1.inverse_log_jacobian(value)
""" """
def __init__(self, def __init__(self,
sharpness=1.0, sharpness=1.0,
name='Softplus'): name='Softplus'):
@ -76,6 +78,7 @@ class Softplus(Bijector):
self.checktensor = CheckTensor() self.checktensor = CheckTensor()
self.threshold = np.log(np.finfo(np.float32).eps) + 1 self.threshold = np.log(np.finfo(np.float32).eps) + 1
self.tiny = np.exp(self.threshold)
def _softplus(self, x): def _softplus(self, x):
too_small = self.less(x, self.threshold) too_small = self.less(x, self.threshold)
@ -94,7 +97,7 @@ class Softplus(Bijector):
f(x) = \frac{\log(1 + e^{x}))} f(x) = \frac{\log(1 + e^{x}))}
f^{-1}(y) = \frac{\log(e^{y} - 1)} f^{-1}(y) = \frac{\log(e^{y} - 1)}
""" """
too_small = self.less(x, self.threshold) too_small = self.less(x, self.tiny)
too_large = self.greater(x, -self.threshold) too_large = self.greater(x, -self.threshold)
too_small_value = self.log(x) too_small_value = self.log(x)
too_large_value = x too_large_value = x
@ -116,7 +119,7 @@ class Softplus(Bijector):
return shape return shape
def _forward(self, x): def _forward(self, x):
self.checktensor(x, 'x') self.checktensor(x, 'value')
scaled_value = self.sharpness * x scaled_value = self.sharpness * x
return self.softplus(scaled_value) / self.sharpness return self.softplus(scaled_value) / self.sharpness
@ -126,7 +129,7 @@ class Softplus(Bijector):
f(x) = \frac{\log(1 + e^{kx}))}{k} f(x) = \frac{\log(1 + e^{kx}))}{k}
f^{-1}(y) = \frac{\log(e^{ky} - 1)}{k} f^{-1}(y) = \frac{\log(e^{ky} - 1)}{k}
""" """
self.checktensor(y, 'y') self.checktensor(y, 'value')
scaled_value = self.sharpness * y scaled_value = self.sharpness * y
return self.inverse_softplus(scaled_value) / self.sharpness return self.inverse_softplus(scaled_value) / self.sharpness
@ -137,7 +140,7 @@ class Softplus(Bijector):
f'(x) = \frac{e^{kx}}{ 1 + e^{kx}} f'(x) = \frac{e^{kx}}{ 1 + e^{kx}}
\log(f'(x)) = kx - \log(1 + e^{kx}) = kx - f(kx) \log(f'(x)) = kx - \log(1 + e^{kx}) = kx - f(kx)
""" """
self.checktensor(x, 'x') self.checktensor(x, 'value')
scaled_value = self.sharpness * x scaled_value = self.sharpness * x
return self.log_sigmoid(scaled_value) return self.log_sigmoid(scaled_value)
@ -148,6 +151,6 @@ class Softplus(Bijector):
f'(y) = \frac{e^{ky}}{e^{ky} - 1} f'(y) = \frac{e^{ky}}{e^{ky} - 1}
\log(f'(y)) = ky - \log(e^{ky} - 1) = ky - f(ky) \log(f'(y)) = ky - \log(e^{ky} - 1) = ky - f(ky)
""" """
self.checktensor(y, 'y') self.checktensor(y, 'value')
scaled_value = self.sharpness * y scaled_value = self.sharpness * y
return scaled_value - self.inverse_softplus(scaled_value) return scaled_value - self.inverse_softplus(scaled_value)

View File

@ -26,6 +26,7 @@ from mindspore import context
import mindspore.nn as nn import mindspore.nn as nn
import mindspore.nn.probability as msp import mindspore.nn.probability as msp
def cast_to_tensor(t, hint_type=mstype.float32): def cast_to_tensor(t, hint_type=mstype.float32):
""" """
Cast an user input value into a Tensor of dtype. Cast an user input value into a Tensor of dtype.
@ -47,7 +48,7 @@ def cast_to_tensor(t, hint_type=mstype.float32):
return t return t
t_type = hint_type t_type = hint_type
if isinstance(t, Tensor): if isinstance(t, Tensor):
#convert the type of tensor to dtype # convert the type of tensor to dtype
return Tensor(t.asnumpy(), dtype=t_type) return Tensor(t.asnumpy(), dtype=t_type)
if isinstance(t, (list, np.ndarray)): if isinstance(t, (list, np.ndarray)):
return Tensor(t, dtype=t_type) return Tensor(t, dtype=t_type)
@ -56,7 +57,8 @@ def cast_to_tensor(t, hint_type=mstype.float32):
if isinstance(t, (int, float)): if isinstance(t, (int, float)):
return Tensor(t, dtype=t_type) return Tensor(t, dtype=t_type)
invalid_type = type(t) invalid_type = type(t)
raise TypeError(f"Unable to convert input of type {invalid_type} to a Tensor of type {t_type}") raise TypeError(
f"Unable to convert input of type {invalid_type} to a Tensor of type {t_type}")
def convert_to_batch(t, batch_shape, required_type): def convert_to_batch(t, batch_shape, required_type):
@ -79,6 +81,7 @@ def convert_to_batch(t, batch_shape, required_type):
t = cast_to_tensor(t, required_type) t = cast_to_tensor(t, required_type)
return Tensor(np.broadcast_to(t.asnumpy(), batch_shape), dtype=required_type) return Tensor(np.broadcast_to(t.asnumpy(), batch_shape), dtype=required_type)
def check_scalar_from_param(params): def check_scalar_from_param(params):
""" """
Check if params are all scalars. Check if params are all scalars.
@ -93,11 +96,7 @@ def check_scalar_from_param(params):
return params['distribution'].is_scalar_batch return params['distribution'].is_scalar_batch
if isinstance(value, Parameter): if isinstance(value, Parameter):
return False return False
if isinstance(value, (str, type(params['dtype']))): if not isinstance(value, (int, float, str, type(params['dtype']))):
continue
elif isinstance(value, (int, float)):
continue
else:
return False return False
return True return True
@ -124,7 +123,8 @@ def calc_broadcast_shape_from_param(params):
value_t = value.default_input value_t = value.default_input
else: else:
value_t = cast_to_tensor(value, mstype.float32) value_t = cast_to_tensor(value, mstype.float32)
broadcast_shape = utils.get_broadcast_shape(broadcast_shape, list(value_t.shape), params['name']) broadcast_shape = utils.get_broadcast_shape(
broadcast_shape, list(value_t.shape), params['name'])
return tuple(broadcast_shape) return tuple(broadcast_shape)
@ -148,6 +148,7 @@ def check_greater_equal_zero(value, name):
if comp.any(): if comp.any():
raise ValueError(f'{name} should be greater than ot equal to zero.') raise ValueError(f'{name} should be greater than ot equal to zero.')
def check_greater_zero(value, name): def check_greater_zero(value, name):
""" """
Check if the given Tensor is strictly greater than zero. Check if the given Tensor is strictly greater than zero.
@ -251,6 +252,7 @@ def probs_to_logits(probs, is_binary=False):
return P.Log()(ps_clamped) - P.Log()(1-ps_clamped) return P.Log()(ps_clamped) - P.Log()(1-ps_clamped)
return P.Log()(ps_clamped) return P.Log()(ps_clamped)
def check_tensor_type(name, inputs, valid_type): def check_tensor_type(name, inputs, valid_type):
""" """
Check if inputs is proper. Check if inputs is proper.
@ -268,25 +270,34 @@ def check_tensor_type(name, inputs, valid_type):
if input_type not in valid_type: if input_type not in valid_type:
raise TypeError(f"{name} dtype is invalid") raise TypeError(f"{name} dtype is invalid")
def check_type(data_type, value_type, name): def check_type(data_type, value_type, name):
if not data_type in value_type: if not data_type in value_type:
raise TypeError(f"For {name}, valid type include {value_type}, {data_type} is invalid") raise TypeError(
f"For {name}, valid type include {value_type}, {data_type} is invalid")
@constexpr @constexpr
def raise_none_error(name): def raise_none_error(name):
raise TypeError(f"the type {name} should be subclass of Tensor." raise TypeError(f"the type {name} should be subclass of Tensor."
f" It should not be None since it is not specified during initialization.") f" It should not be None since it is not specified during initialization.")
@constexpr @constexpr
def raise_not_impl_error(name): def raise_not_impl_error(name):
raise ValueError(f"{name} function should be implemented for non-linear transformation") raise ValueError(
f"{name} function should be implemented for non-linear transformation")
@constexpr @constexpr
def check_distribution_name(name, expected_name): def check_distribution_name(name, expected_name):
if name is None: if name is None:
raise ValueError(f"Distribution should be a constant which is not None.") raise ValueError(
f"Distribution should be a constant which is not None.")
if name != expected_name: if name != expected_name:
raise ValueError(f"Expected distribution name is {expected_name}, but got {name}.") raise ValueError(
f"Expected distribution name is {expected_name}, but got {name}.")
class CheckTuple(PrimitiveWithInfer): class CheckTuple(PrimitiveWithInfer):
""" """
@ -294,13 +305,13 @@ class CheckTuple(PrimitiveWithInfer):
""" """
@prim_attr_register @prim_attr_register
def __init__(self): def __init__(self):
"""init Cast"""
super(CheckTuple, self).__init__("CheckTuple") super(CheckTuple, self).__init__("CheckTuple")
self.init_prim_io_names(inputs=['x'], outputs=['dummy_output']) self.init_prim_io_names(inputs=['x', 'name'], outputs=['dummy_output'])
def __infer__(self, x, name): def __infer__(self, x, name):
if not isinstance(x['dtype'], tuple): if not isinstance(x['dtype'], tuple):
raise TypeError(f"For {name['value']}, Input type should b a tuple.") raise TypeError(
f"For {name['value']}, Input type should b a tuple.")
out = {'shape': None, out = {'shape': None,
'dtype': None, 'dtype': None,
@ -310,24 +321,25 @@ class CheckTuple(PrimitiveWithInfer):
def __call__(self, x, name): def __call__(self, x, name):
if context.get_context("mode") == 0: if context.get_context("mode") == 0:
return x["value"] return x["value"]
#Pynative mode # Pynative mode
if isinstance(x, tuple): if isinstance(x, tuple):
return x return x
raise TypeError(f"For {name['value']}, Input type should b a tuple.") raise TypeError(f"For {name['value']}, Input type should b a tuple.")
class CheckTensor(PrimitiveWithInfer): class CheckTensor(PrimitiveWithInfer):
""" """
Check if input is a Tensor. Check if input is a Tensor.
""" """
@prim_attr_register @prim_attr_register
def __init__(self): def __init__(self):
"""init Cast"""
super(CheckTensor, self).__init__("CheckTensor") super(CheckTensor, self).__init__("CheckTensor")
self.init_prim_io_names(inputs=['x'], outputs=['dummy_output']) self.init_prim_io_names(inputs=['x', 'name'], outputs=['dummy_output'])
def __infer__(self, x, name): def __infer__(self, x, name):
src_type = x['dtype'] src_type = x['dtype']
validator.check_subclass("input", src_type, [mstype.tensor], name["value"]) validator.check_subclass(
"input", src_type, [mstype.tensor], name["value"])
out = {'shape': None, out = {'shape': None,
'dtype': None, 'dtype': None,

View File

@ -20,6 +20,7 @@ from .distribution import Distribution
from ._utils.utils import cast_to_tensor, check_prob, check_type, check_distribution_name, raise_none_error from ._utils.utils import cast_to_tensor, check_prob, check_type, check_distribution_name, raise_none_error
from ._utils.custom_ops import exp_by_step, log_by_step from ._utils.custom_ops import exp_by_step, log_by_step
class Bernoulli(Distribution): class Bernoulli(Distribution):
""" """
Bernoulli Distribution. Bernoulli Distribution.
@ -97,7 +98,7 @@ class Bernoulli(Distribution):
Constructor of Bernoulli distribution. Constructor of Bernoulli distribution.
""" """
param = dict(locals()) param = dict(locals())
valid_dtype = mstype.int_type + mstype.uint_type valid_dtype = mstype.int_type + mstype.uint_type + mstype.float_type
check_type(dtype, valid_dtype, "Bernoulli") check_type(dtype, valid_dtype, "Bernoulli")
super(Bernoulli, self).__init__(seed, dtype, name, param) super(Bernoulli, self).__init__(seed, dtype, name, param)
self.parameter_type = mstype.float32 self.parameter_type = mstype.float32
@ -211,7 +212,6 @@ class Bernoulli(Distribution):
""" """
self.checktensor(value, 'value') self.checktensor(value, 'value')
value = self.cast(value, mstype.float32) value = self.cast(value, mstype.float32)
value = self.floor(value)
probs1 = self._check_param(probs1) probs1 = self._check_param(probs1)
probs0 = 1.0 - probs1 probs0 = 1.0 - probs1
return self.log(probs1) * value + self.log(probs0) * (1.0 - value) return self.log(probs1) * value + self.log(probs0) * (1.0 - value)

View File

@ -19,9 +19,10 @@ from mindspore.ops import composite as C
from mindspore.common import dtype as mstype from mindspore.common import dtype as mstype
from .distribution import Distribution from .distribution import Distribution
from ._utils.utils import cast_to_tensor, check_prob, check_type, check_distribution_name,\ from ._utils.utils import cast_to_tensor, check_prob, check_type, check_distribution_name,\
raise_none_error raise_none_error
from ._utils.custom_ops import exp_by_step, log_by_step from ._utils.custom_ops import exp_by_step, log_by_step
class Geometric(Distribution): class Geometric(Distribution):
""" """
Geometric Distribution. Geometric Distribution.
@ -100,7 +101,7 @@ class Geometric(Distribution):
Constructor of Geometric distribution. Constructor of Geometric distribution.
""" """
param = dict(locals()) param = dict(locals())
valid_dtype = mstype.int_type + mstype.uint_type valid_dtype = mstype.int_type + mstype.uint_type + mstype.float_type
check_type(dtype, valid_dtype, "Geometric") check_type(dtype, valid_dtype, "Geometric")
super(Geometric, self).__init__(seed, dtype, name, param) super(Geometric, self).__init__(seed, dtype, name, param)
self.parameter_type = mstype.float32 self.parameter_type = mstype.float32
@ -130,7 +131,6 @@ class Geometric(Distribution):
self.sqrt = P.Sqrt() self.sqrt = P.Sqrt()
self.uniform = C.uniform self.uniform = C.uniform
def extend_repr(self): def extend_repr(self):
if self.is_scalar_batch: if self.is_scalar_batch:
str_info = f'probs = {self.probs}' str_info = f'probs = {self.probs}'
@ -243,7 +243,6 @@ class Geometric(Distribution):
comp = self.less(value, zeros) comp = self.less(value, zeros)
return self.select(comp, zeros, cdf) return self.select(comp, zeros, cdf)
def _kl_loss(self, dist, probs1_b, probs1=None): def _kl_loss(self, dist, probs1_b, probs1=None):
r""" r"""
Evaluate Geometric-Geometric kl divergence, i.e. KL(a||b). Evaluate Geometric-Geometric kl divergence, i.e. KL(a||b).

View File

@ -22,6 +22,7 @@ import mindspore.nn.probability.distribution as msd
from mindspore import dtype from mindspore import dtype
from mindspore import Tensor from mindspore import Tensor
def test_arguments(): def test_arguments():
""" """
Args passing during initialization. Args passing during initialization.
@ -31,18 +32,22 @@ def test_arguments():
b = msd.Bernoulli([0.1, 0.3, 0.5, 0.9], dtype=dtype.int32) b = msd.Bernoulli([0.1, 0.3, 0.5, 0.9], dtype=dtype.int32)
assert isinstance(b, msd.Distribution) assert isinstance(b, msd.Distribution)
def test_type(): def test_type():
with pytest.raises(TypeError): with pytest.raises(TypeError):
msd.Bernoulli([0.1], dtype=dtype.float32) msd.Bernoulli([0.1], dtype=dtype.bool_)
def test_name(): def test_name():
with pytest.raises(TypeError): with pytest.raises(TypeError):
msd.Bernoulli([0.1], name=1.0) msd.Bernoulli([0.1], name=1.0)
def test_seed(): def test_seed():
with pytest.raises(TypeError): with pytest.raises(TypeError):
msd.Bernoulli([0.1], seed='seed') msd.Bernoulli([0.1], seed='seed')
def test_prob(): def test_prob():
""" """
Invalid probability. Invalid probability.
@ -56,10 +61,12 @@ def test_prob():
with pytest.raises(ValueError): with pytest.raises(ValueError):
msd.Bernoulli([1.0], dtype=dtype.int32) msd.Bernoulli([1.0], dtype=dtype.int32)
class BernoulliProb(nn.Cell): class BernoulliProb(nn.Cell):
""" """
Bernoulli distribution: initialize with probs. Bernoulli distribution: initialize with probs.
""" """
def __init__(self): def __init__(self):
super(BernoulliProb, self).__init__() super(BernoulliProb, self).__init__()
self.b = msd.Bernoulli(0.5, dtype=dtype.int32) self.b = msd.Bernoulli(0.5, dtype=dtype.int32)
@ -73,6 +80,7 @@ class BernoulliProb(nn.Cell):
log_sf = self.b.log_survival(value) log_sf = self.b.log_survival(value)
return prob + log_prob + cdf + log_cdf + sf + log_sf return prob + log_prob + cdf + log_cdf + sf + log_sf
def test_bernoulli_prob(): def test_bernoulli_prob():
""" """
Test probability functions: passing value through construct. Test probability functions: passing value through construct.
@ -82,10 +90,12 @@ def test_bernoulli_prob():
ans = net(value) ans = net(value)
assert isinstance(ans, Tensor) assert isinstance(ans, Tensor)
class BernoulliProb1(nn.Cell): class BernoulliProb1(nn.Cell):
""" """
Bernoulli distribution: initialize without probs. Bernoulli distribution: initialize without probs.
""" """
def __init__(self): def __init__(self):
super(BernoulliProb1, self).__init__() super(BernoulliProb1, self).__init__()
self.b = msd.Bernoulli(dtype=dtype.int32) self.b = msd.Bernoulli(dtype=dtype.int32)
@ -99,6 +109,7 @@ class BernoulliProb1(nn.Cell):
log_sf = self.b.log_survival(value, probs) log_sf = self.b.log_survival(value, probs)
return prob + log_prob + cdf + log_cdf + sf + log_sf return prob + log_prob + cdf + log_cdf + sf + log_sf
def test_bernoulli_prob1(): def test_bernoulli_prob1():
""" """
Test probability functions: passing value/probs through construct. Test probability functions: passing value/probs through construct.
@ -109,10 +120,12 @@ def test_bernoulli_prob1():
ans = net(value, probs) ans = net(value, probs)
assert isinstance(ans, Tensor) assert isinstance(ans, Tensor)
class BernoulliKl(nn.Cell): class BernoulliKl(nn.Cell):
""" """
Test class: kl_loss between Bernoulli distributions. Test class: kl_loss between Bernoulli distributions.
""" """
def __init__(self): def __init__(self):
super(BernoulliKl, self).__init__() super(BernoulliKl, self).__init__()
self.b1 = msd.Bernoulli(0.7, dtype=dtype.int32) self.b1 = msd.Bernoulli(0.7, dtype=dtype.int32)
@ -123,6 +136,7 @@ class BernoulliKl(nn.Cell):
kl2 = self.b2.kl_loss('Bernoulli', probs_b, probs_a) kl2 = self.b2.kl_loss('Bernoulli', probs_b, probs_a)
return kl1 + kl2 return kl1 + kl2
def test_kl(): def test_kl():
""" """
Test kl_loss function. Test kl_loss function.
@ -133,10 +147,12 @@ def test_kl():
ans = ber_net(probs_b, probs_a) ans = ber_net(probs_b, probs_a)
assert isinstance(ans, Tensor) assert isinstance(ans, Tensor)
class BernoulliCrossEntropy(nn.Cell): class BernoulliCrossEntropy(nn.Cell):
""" """
Test class: cross_entropy of Bernoulli distribution. Test class: cross_entropy of Bernoulli distribution.
""" """
def __init__(self): def __init__(self):
super(BernoulliCrossEntropy, self).__init__() super(BernoulliCrossEntropy, self).__init__()
self.b1 = msd.Bernoulli(0.7, dtype=dtype.int32) self.b1 = msd.Bernoulli(0.7, dtype=dtype.int32)
@ -147,6 +163,7 @@ class BernoulliCrossEntropy(nn.Cell):
h2 = self.b2.cross_entropy('Bernoulli', probs_b, probs_a) h2 = self.b2.cross_entropy('Bernoulli', probs_b, probs_a)
return h1 + h2 return h1 + h2
def test_cross_entropy(): def test_cross_entropy():
""" """
Test cross_entropy between Bernoulli distributions. Test cross_entropy between Bernoulli distributions.
@ -157,10 +174,12 @@ def test_cross_entropy():
ans = net(probs_b, probs_a) ans = net(probs_b, probs_a)
assert isinstance(ans, Tensor) assert isinstance(ans, Tensor)
class BernoulliConstruct(nn.Cell): class BernoulliConstruct(nn.Cell):
""" """
Bernoulli distribution: going through construct. Bernoulli distribution: going through construct.
""" """
def __init__(self): def __init__(self):
super(BernoulliConstruct, self).__init__() super(BernoulliConstruct, self).__init__()
self.b = msd.Bernoulli(0.5, dtype=dtype.int32) self.b = msd.Bernoulli(0.5, dtype=dtype.int32)
@ -172,6 +191,7 @@ class BernoulliConstruct(nn.Cell):
prob2 = self.b1('prob', value, probs) prob2 = self.b1('prob', value, probs)
return prob + prob1 + prob2 return prob + prob1 + prob2
def test_bernoulli_construct(): def test_bernoulli_construct():
""" """
Test probability function going through construct. Test probability function going through construct.
@ -182,10 +202,12 @@ def test_bernoulli_construct():
ans = net(value, probs) ans = net(value, probs)
assert isinstance(ans, Tensor) assert isinstance(ans, Tensor)
class BernoulliMean(nn.Cell): class BernoulliMean(nn.Cell):
""" """
Test class: basic mean/sd/var/mode/entropy function. Test class: basic mean/sd/var/mode/entropy function.
""" """
def __init__(self): def __init__(self):
super(BernoulliMean, self).__init__() super(BernoulliMean, self).__init__()
self.b = msd.Bernoulli([0.3, 0.5], dtype=dtype.int32) self.b = msd.Bernoulli([0.3, 0.5], dtype=dtype.int32)
@ -194,6 +216,7 @@ class BernoulliMean(nn.Cell):
mean = self.b.mean() mean = self.b.mean()
return mean return mean
def test_mean(): def test_mean():
""" """
Test mean/sd/var/mode/entropy functionality of Bernoulli distribution. Test mean/sd/var/mode/entropy functionality of Bernoulli distribution.
@ -202,10 +225,12 @@ def test_mean():
ans = net() ans = net()
assert isinstance(ans, Tensor) assert isinstance(ans, Tensor)
class BernoulliSd(nn.Cell): class BernoulliSd(nn.Cell):
""" """
Test class: basic mean/sd/var/mode/entropy function. Test class: basic mean/sd/var/mode/entropy function.
""" """
def __init__(self): def __init__(self):
super(BernoulliSd, self).__init__() super(BernoulliSd, self).__init__()
self.b = msd.Bernoulli([0.3, 0.5], dtype=dtype.int32) self.b = msd.Bernoulli([0.3, 0.5], dtype=dtype.int32)
@ -214,6 +239,7 @@ class BernoulliSd(nn.Cell):
sd = self.b.sd() sd = self.b.sd()
return sd return sd
def test_sd(): def test_sd():
""" """
Test mean/sd/var/mode/entropy functionality of Bernoulli distribution. Test mean/sd/var/mode/entropy functionality of Bernoulli distribution.
@ -222,10 +248,12 @@ def test_sd():
ans = net() ans = net()
assert isinstance(ans, Tensor) assert isinstance(ans, Tensor)
class BernoulliVar(nn.Cell): class BernoulliVar(nn.Cell):
""" """
Test class: basic mean/sd/var/mode/entropy function. Test class: basic mean/sd/var/mode/entropy function.
""" """
def __init__(self): def __init__(self):
super(BernoulliVar, self).__init__() super(BernoulliVar, self).__init__()
self.b = msd.Bernoulli([0.3, 0.5], dtype=dtype.int32) self.b = msd.Bernoulli([0.3, 0.5], dtype=dtype.int32)
@ -234,6 +262,7 @@ class BernoulliVar(nn.Cell):
var = self.b.var() var = self.b.var()
return var return var
def test_var(): def test_var():
""" """
Test mean/sd/var/mode/entropy functionality of Bernoulli distribution. Test mean/sd/var/mode/entropy functionality of Bernoulli distribution.
@ -242,10 +271,12 @@ def test_var():
ans = net() ans = net()
assert isinstance(ans, Tensor) assert isinstance(ans, Tensor)
class BernoulliMode(nn.Cell): class BernoulliMode(nn.Cell):
""" """
Test class: basic mean/sd/var/mode/entropy function. Test class: basic mean/sd/var/mode/entropy function.
""" """
def __init__(self): def __init__(self):
super(BernoulliMode, self).__init__() super(BernoulliMode, self).__init__()
self.b = msd.Bernoulli([0.3, 0.5], dtype=dtype.int32) self.b = msd.Bernoulli([0.3, 0.5], dtype=dtype.int32)
@ -254,6 +285,7 @@ class BernoulliMode(nn.Cell):
mode = self.b.mode() mode = self.b.mode()
return mode return mode
def test_mode(): def test_mode():
""" """
Test mean/sd/var/mode/entropy functionality of Bernoulli distribution. Test mean/sd/var/mode/entropy functionality of Bernoulli distribution.
@ -262,10 +294,12 @@ def test_mode():
ans = net() ans = net()
assert isinstance(ans, Tensor) assert isinstance(ans, Tensor)
class BernoulliEntropy(nn.Cell): class BernoulliEntropy(nn.Cell):
""" """
Test class: basic mean/sd/var/mode/entropy function. Test class: basic mean/sd/var/mode/entropy function.
""" """
def __init__(self): def __init__(self):
super(BernoulliEntropy, self).__init__() super(BernoulliEntropy, self).__init__()
self.b = msd.Bernoulli([0.3, 0.5], dtype=dtype.int32) self.b = msd.Bernoulli([0.3, 0.5], dtype=dtype.int32)
@ -274,6 +308,7 @@ class BernoulliEntropy(nn.Cell):
entropy = self.b.entropy() entropy = self.b.entropy()
return entropy return entropy
def test_entropy(): def test_entropy():
""" """
Test mean/sd/var/mode/entropy functionality of Bernoulli distribution. Test mean/sd/var/mode/entropy functionality of Bernoulli distribution.

View File

@ -32,18 +32,22 @@ def test_arguments():
g = msd.Geometric([0.1, 0.3, 0.5, 0.9], dtype=dtype.int32) g = msd.Geometric([0.1, 0.3, 0.5, 0.9], dtype=dtype.int32)
assert isinstance(g, msd.Distribution) assert isinstance(g, msd.Distribution)
def test_type(): def test_type():
with pytest.raises(TypeError): with pytest.raises(TypeError):
msd.Geometric([0.1], dtype=dtype.float32) msd.Geometric([0.1], dtype=dtype.bool_)
def test_name(): def test_name():
with pytest.raises(TypeError): with pytest.raises(TypeError):
msd.Geometric([0.1], name=1.0) msd.Geometric([0.1], name=1.0)
def test_seed(): def test_seed():
with pytest.raises(TypeError): with pytest.raises(TypeError):
msd.Geometric([0.1], seed='seed') msd.Geometric([0.1], seed='seed')
def test_prob(): def test_prob():
""" """
Invalid probability. Invalid probability.
@ -57,10 +61,12 @@ def test_prob():
with pytest.raises(ValueError): with pytest.raises(ValueError):
msd.Geometric([1.0], dtype=dtype.int32) msd.Geometric([1.0], dtype=dtype.int32)
class GeometricProb(nn.Cell): class GeometricProb(nn.Cell):
""" """
Geometric distribution: initialize with probs. Geometric distribution: initialize with probs.
""" """
def __init__(self): def __init__(self):
super(GeometricProb, self).__init__() super(GeometricProb, self).__init__()
self.g = msd.Geometric(0.5, dtype=dtype.int32) self.g = msd.Geometric(0.5, dtype=dtype.int32)
@ -74,6 +80,7 @@ class GeometricProb(nn.Cell):
log_sf = self.g.log_survival(value) log_sf = self.g.log_survival(value)
return prob + log_prob + cdf + log_cdf + sf + log_sf return prob + log_prob + cdf + log_cdf + sf + log_sf
def test_geometric_prob(): def test_geometric_prob():
""" """
Test probability functions: passing value through construct. Test probability functions: passing value through construct.
@ -83,10 +90,12 @@ def test_geometric_prob():
ans = net(value) ans = net(value)
assert isinstance(ans, Tensor) assert isinstance(ans, Tensor)
class GeometricProb1(nn.Cell): class GeometricProb1(nn.Cell):
""" """
Geometric distribution: initialize without probs. Geometric distribution: initialize without probs.
""" """
def __init__(self): def __init__(self):
super(GeometricProb1, self).__init__() super(GeometricProb1, self).__init__()
self.g = msd.Geometric(dtype=dtype.int32) self.g = msd.Geometric(dtype=dtype.int32)
@ -100,6 +109,7 @@ class GeometricProb1(nn.Cell):
log_sf = self.g.log_survival(value, probs) log_sf = self.g.log_survival(value, probs)
return prob + log_prob + cdf + log_cdf + sf + log_sf return prob + log_prob + cdf + log_cdf + sf + log_sf
def test_geometric_prob1(): def test_geometric_prob1():
""" """
Test probability functions: passing value/probs through construct. Test probability functions: passing value/probs through construct.
@ -115,6 +125,7 @@ class GeometricKl(nn.Cell):
""" """
Test class: kl_loss between Geometric distributions. Test class: kl_loss between Geometric distributions.
""" """
def __init__(self): def __init__(self):
super(GeometricKl, self).__init__() super(GeometricKl, self).__init__()
self.g1 = msd.Geometric(0.7, dtype=dtype.int32) self.g1 = msd.Geometric(0.7, dtype=dtype.int32)
@ -125,6 +136,7 @@ class GeometricKl(nn.Cell):
kl2 = self.g2.kl_loss('Geometric', probs_b, probs_a) kl2 = self.g2.kl_loss('Geometric', probs_b, probs_a)
return kl1 + kl2 return kl1 + kl2
def test_kl(): def test_kl():
""" """
Test kl_loss function. Test kl_loss function.
@ -135,10 +147,12 @@ def test_kl():
ans = ber_net(probs_b, probs_a) ans = ber_net(probs_b, probs_a)
assert isinstance(ans, Tensor) assert isinstance(ans, Tensor)
class GeometricCrossEntropy(nn.Cell): class GeometricCrossEntropy(nn.Cell):
""" """
Test class: cross_entropy of Geometric distribution. Test class: cross_entropy of Geometric distribution.
""" """
def __init__(self): def __init__(self):
super(GeometricCrossEntropy, self).__init__() super(GeometricCrossEntropy, self).__init__()
self.g1 = msd.Geometric(0.3, dtype=dtype.int32) self.g1 = msd.Geometric(0.3, dtype=dtype.int32)
@ -149,6 +163,7 @@ class GeometricCrossEntropy(nn.Cell):
h2 = self.g2.cross_entropy('Geometric', probs_b, probs_a) h2 = self.g2.cross_entropy('Geometric', probs_b, probs_a)
return h1 + h2 return h1 + h2
def test_cross_entropy(): def test_cross_entropy():
""" """
Test cross_entropy between Geometric distributions. Test cross_entropy between Geometric distributions.
@ -159,10 +174,12 @@ def test_cross_entropy():
ans = net(probs_b, probs_a) ans = net(probs_b, probs_a)
assert isinstance(ans, Tensor) assert isinstance(ans, Tensor)
class GeometricBasics(nn.Cell): class GeometricBasics(nn.Cell):
""" """
Test class: basic mean/sd/mode/entropy function. Test class: basic mean/sd/mode/entropy function.
""" """
def __init__(self): def __init__(self):
super(GeometricBasics, self).__init__() super(GeometricBasics, self).__init__()
self.g = msd.Geometric([0.3, 0.5], dtype=dtype.int32) self.g = msd.Geometric([0.3, 0.5], dtype=dtype.int32)
@ -175,6 +192,7 @@ class GeometricBasics(nn.Cell):
entropy = self.g.entropy() entropy = self.g.entropy()
return mean + sd + var + mode + entropy return mean + sd + var + mode + entropy
def test_bascis(): def test_bascis():
""" """
Test mean/sd/mode/entropy functionality of Geometric distribution. Test mean/sd/mode/entropy functionality of Geometric distribution.
@ -188,6 +206,7 @@ class GeoConstruct(nn.Cell):
""" """
Bernoulli distribution: going through construct. Bernoulli distribution: going through construct.
""" """
def __init__(self): def __init__(self):
super(GeoConstruct, self).__init__() super(GeoConstruct, self).__init__()
self.g = msd.Geometric(0.5, dtype=dtype.int32) self.g = msd.Geometric(0.5, dtype=dtype.int32)
@ -199,6 +218,7 @@ class GeoConstruct(nn.Cell):
prob2 = self.g1('prob', value, probs) prob2 = self.g1('prob', value, probs)
return prob + prob1 + prob2 return prob + prob1 + prob2
def test_geo_construct(): def test_geo_construct():
""" """
Test probability function going through construct. Test probability function going through construct.