forked from mindspore-Ecosystem/mindspore
!5005 Fix the bug in the formula in Bernoulli log_probs
Merge pull request !5005 from zichun_ye/fix_bernoulli_probs
This commit is contained in:
commit
56835aaf88
|
@ -20,6 +20,7 @@ from ..distribution._utils.utils import CheckTensor
|
|||
from ..distribution._utils.custom_ops import exp_by_step, expm1_by_step, log_by_step, log1p_by_step
|
||||
from .bijector import Bijector
|
||||
|
||||
|
||||
class PowerTransform(Bijector):
|
||||
r"""
|
||||
Power Bijector.
|
||||
|
@ -49,6 +50,7 @@ class PowerTransform(Bijector):
|
|||
>>> # by replacing 'forward' with the name of the function
|
||||
>>> ans = self.p1.forward(, value)
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
power=0,
|
||||
name='PowerTransform',
|
||||
|
@ -78,13 +80,13 @@ class PowerTransform(Bijector):
|
|||
return shape
|
||||
|
||||
def _forward(self, x):
|
||||
self.checktensor(x, 'x')
|
||||
self.checktensor(x, 'value')
|
||||
if self.power == 0:
|
||||
return self.exp(x)
|
||||
return self.exp(self.log1p(x * self.power) / self.power)
|
||||
|
||||
def _inverse(self, y):
|
||||
self.checktensor(y, 'y')
|
||||
self.checktensor(y, 'value')
|
||||
if self.power == 0:
|
||||
return self.log(y)
|
||||
return self.expm1(self.log(y) * self.power) / self.power
|
||||
|
@ -101,7 +103,7 @@ class PowerTransform(Bijector):
|
|||
f'(x) = e^\frac{\log(xc + 1)}{c} * \frac{1}{xc + 1}
|
||||
\log(f'(x)) = (\frac{1}{c} - 1) * \log(xc + 1)
|
||||
"""
|
||||
self.checktensor(x, 'x')
|
||||
self.checktensor(x, 'value')
|
||||
if self.power == 0:
|
||||
return x
|
||||
return (1. / self.power - 1) * self.log1p(x * self.power)
|
||||
|
@ -118,5 +120,5 @@ class PowerTransform(Bijector):
|
|||
f'(x) = \frac{e^c\log(y)}{y}
|
||||
\log(f'(x)) = \log(\frac{e^c\log(y)}{y}) = (c-1) * \log(y)
|
||||
"""
|
||||
self.checktensor(y, 'y')
|
||||
self.checktensor(y, 'value')
|
||||
return (self.power - 1) * self.log(y)
|
||||
|
|
|
@ -19,6 +19,7 @@ from ..distribution._utils.utils import cast_to_tensor, CheckTensor
|
|||
from ..distribution._utils.custom_ops import log_by_step
|
||||
from .bijector import Bijector
|
||||
|
||||
|
||||
class ScalarAffine(Bijector):
|
||||
"""
|
||||
Scalar Affine Bijector.
|
||||
|
@ -47,6 +48,7 @@ class ScalarAffine(Bijector):
|
|||
>>> ans = self.s1.forward_log_jacobian(value)
|
||||
>>> ans = self.s1.inverse_log_jacobian(value)
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
scale=1.0,
|
||||
shift=0.0,
|
||||
|
@ -91,7 +93,7 @@ class ScalarAffine(Bijector):
|
|||
.. math::
|
||||
f(x) = a * x + b
|
||||
"""
|
||||
self.checktensor(x, 'x')
|
||||
self.checktensor(x, 'value')
|
||||
return self.scale * x + self.shift
|
||||
|
||||
def _inverse(self, y):
|
||||
|
@ -99,7 +101,7 @@ class ScalarAffine(Bijector):
|
|||
.. math::
|
||||
f(y) = \frac{y - b}{a}
|
||||
"""
|
||||
self.checktensor(y, 'y')
|
||||
self.checktensor(y, 'value')
|
||||
return (y - self.shift) / self.scale
|
||||
|
||||
def _forward_log_jacobian(self, x):
|
||||
|
@ -109,7 +111,7 @@ class ScalarAffine(Bijector):
|
|||
f'(x) = a
|
||||
\log(f'(x)) = \log(a)
|
||||
"""
|
||||
self.checktensor(x, 'x')
|
||||
self.checktensor(x, 'value')
|
||||
return self.log(self.abs(self.scale))
|
||||
|
||||
def _inverse_log_jacobian(self, y):
|
||||
|
@ -119,5 +121,5 @@ class ScalarAffine(Bijector):
|
|||
f'(x) = \frac{1.0}{a}
|
||||
\log(f'(x)) = - \log(a)
|
||||
"""
|
||||
self.checktensor(y, 'y')
|
||||
self.checktensor(y, 'value')
|
||||
return -1. * self.log(self.abs(self.scale))
|
||||
|
|
|
@ -22,6 +22,7 @@ from ..distribution._utils.utils import cast_to_tensor, CheckTensor
|
|||
from ..distribution._utils.custom_ops import exp_by_step, expm1_by_step, log_by_step
|
||||
from .bijector import Bijector
|
||||
|
||||
|
||||
class Softplus(Bijector):
|
||||
r"""
|
||||
Softplus Bijector.
|
||||
|
@ -51,6 +52,7 @@ class Softplus(Bijector):
|
|||
>>> ans = self.sp1.forward_log_jacobian(value)
|
||||
>>> ans = self.sp1.inverse_log_jacobian(value)
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
sharpness=1.0,
|
||||
name='Softplus'):
|
||||
|
@ -76,6 +78,7 @@ class Softplus(Bijector):
|
|||
|
||||
self.checktensor = CheckTensor()
|
||||
self.threshold = np.log(np.finfo(np.float32).eps) + 1
|
||||
self.tiny = np.exp(self.threshold)
|
||||
|
||||
def _softplus(self, x):
|
||||
too_small = self.less(x, self.threshold)
|
||||
|
@ -94,7 +97,7 @@ class Softplus(Bijector):
|
|||
f(x) = \frac{\log(1 + e^{x}))}
|
||||
f^{-1}(y) = \frac{\log(e^{y} - 1)}
|
||||
"""
|
||||
too_small = self.less(x, self.threshold)
|
||||
too_small = self.less(x, self.tiny)
|
||||
too_large = self.greater(x, -self.threshold)
|
||||
too_small_value = self.log(x)
|
||||
too_large_value = x
|
||||
|
@ -116,7 +119,7 @@ class Softplus(Bijector):
|
|||
return shape
|
||||
|
||||
def _forward(self, x):
|
||||
self.checktensor(x, 'x')
|
||||
self.checktensor(x, 'value')
|
||||
scaled_value = self.sharpness * x
|
||||
return self.softplus(scaled_value) / self.sharpness
|
||||
|
||||
|
@ -126,7 +129,7 @@ class Softplus(Bijector):
|
|||
f(x) = \frac{\log(1 + e^{kx}))}{k}
|
||||
f^{-1}(y) = \frac{\log(e^{ky} - 1)}{k}
|
||||
"""
|
||||
self.checktensor(y, 'y')
|
||||
self.checktensor(y, 'value')
|
||||
scaled_value = self.sharpness * y
|
||||
return self.inverse_softplus(scaled_value) / self.sharpness
|
||||
|
||||
|
@ -137,7 +140,7 @@ class Softplus(Bijector):
|
|||
f'(x) = \frac{e^{kx}}{ 1 + e^{kx}}
|
||||
\log(f'(x)) = kx - \log(1 + e^{kx}) = kx - f(kx)
|
||||
"""
|
||||
self.checktensor(x, 'x')
|
||||
self.checktensor(x, 'value')
|
||||
scaled_value = self.sharpness * x
|
||||
return self.log_sigmoid(scaled_value)
|
||||
|
||||
|
@ -148,6 +151,6 @@ class Softplus(Bijector):
|
|||
f'(y) = \frac{e^{ky}}{e^{ky} - 1}
|
||||
\log(f'(y)) = ky - \log(e^{ky} - 1) = ky - f(ky)
|
||||
"""
|
||||
self.checktensor(y, 'y')
|
||||
self.checktensor(y, 'value')
|
||||
scaled_value = self.sharpness * y
|
||||
return scaled_value - self.inverse_softplus(scaled_value)
|
||||
|
|
|
@ -26,6 +26,7 @@ from mindspore import context
|
|||
import mindspore.nn as nn
|
||||
import mindspore.nn.probability as msp
|
||||
|
||||
|
||||
def cast_to_tensor(t, hint_type=mstype.float32):
|
||||
"""
|
||||
Cast an user input value into a Tensor of dtype.
|
||||
|
@ -47,7 +48,7 @@ def cast_to_tensor(t, hint_type=mstype.float32):
|
|||
return t
|
||||
t_type = hint_type
|
||||
if isinstance(t, Tensor):
|
||||
#convert the type of tensor to dtype
|
||||
# convert the type of tensor to dtype
|
||||
return Tensor(t.asnumpy(), dtype=t_type)
|
||||
if isinstance(t, (list, np.ndarray)):
|
||||
return Tensor(t, dtype=t_type)
|
||||
|
@ -56,7 +57,8 @@ def cast_to_tensor(t, hint_type=mstype.float32):
|
|||
if isinstance(t, (int, float)):
|
||||
return Tensor(t, dtype=t_type)
|
||||
invalid_type = type(t)
|
||||
raise TypeError(f"Unable to convert input of type {invalid_type} to a Tensor of type {t_type}")
|
||||
raise TypeError(
|
||||
f"Unable to convert input of type {invalid_type} to a Tensor of type {t_type}")
|
||||
|
||||
|
||||
def convert_to_batch(t, batch_shape, required_type):
|
||||
|
@ -79,6 +81,7 @@ def convert_to_batch(t, batch_shape, required_type):
|
|||
t = cast_to_tensor(t, required_type)
|
||||
return Tensor(np.broadcast_to(t.asnumpy(), batch_shape), dtype=required_type)
|
||||
|
||||
|
||||
def check_scalar_from_param(params):
|
||||
"""
|
||||
Check if params are all scalars.
|
||||
|
@ -93,11 +96,7 @@ def check_scalar_from_param(params):
|
|||
return params['distribution'].is_scalar_batch
|
||||
if isinstance(value, Parameter):
|
||||
return False
|
||||
if isinstance(value, (str, type(params['dtype']))):
|
||||
continue
|
||||
elif isinstance(value, (int, float)):
|
||||
continue
|
||||
else:
|
||||
if not isinstance(value, (int, float, str, type(params['dtype']))):
|
||||
return False
|
||||
return True
|
||||
|
||||
|
@ -124,7 +123,8 @@ def calc_broadcast_shape_from_param(params):
|
|||
value_t = value.default_input
|
||||
else:
|
||||
value_t = cast_to_tensor(value, mstype.float32)
|
||||
broadcast_shape = utils.get_broadcast_shape(broadcast_shape, list(value_t.shape), params['name'])
|
||||
broadcast_shape = utils.get_broadcast_shape(
|
||||
broadcast_shape, list(value_t.shape), params['name'])
|
||||
return tuple(broadcast_shape)
|
||||
|
||||
|
||||
|
@ -148,6 +148,7 @@ def check_greater_equal_zero(value, name):
|
|||
if comp.any():
|
||||
raise ValueError(f'{name} should be greater than ot equal to zero.')
|
||||
|
||||
|
||||
def check_greater_zero(value, name):
|
||||
"""
|
||||
Check if the given Tensor is strictly greater than zero.
|
||||
|
@ -251,6 +252,7 @@ def probs_to_logits(probs, is_binary=False):
|
|||
return P.Log()(ps_clamped) - P.Log()(1-ps_clamped)
|
||||
return P.Log()(ps_clamped)
|
||||
|
||||
|
||||
def check_tensor_type(name, inputs, valid_type):
|
||||
"""
|
||||
Check if inputs is proper.
|
||||
|
@ -268,25 +270,34 @@ def check_tensor_type(name, inputs, valid_type):
|
|||
if input_type not in valid_type:
|
||||
raise TypeError(f"{name} dtype is invalid")
|
||||
|
||||
|
||||
def check_type(data_type, value_type, name):
|
||||
if not data_type in value_type:
|
||||
raise TypeError(f"For {name}, valid type include {value_type}, {data_type} is invalid")
|
||||
raise TypeError(
|
||||
f"For {name}, valid type include {value_type}, {data_type} is invalid")
|
||||
|
||||
|
||||
@constexpr
|
||||
def raise_none_error(name):
|
||||
raise TypeError(f"the type {name} should be subclass of Tensor."
|
||||
f" It should not be None since it is not specified during initialization.")
|
||||
|
||||
|
||||
@constexpr
|
||||
def raise_not_impl_error(name):
|
||||
raise ValueError(f"{name} function should be implemented for non-linear transformation")
|
||||
raise ValueError(
|
||||
f"{name} function should be implemented for non-linear transformation")
|
||||
|
||||
|
||||
@constexpr
|
||||
def check_distribution_name(name, expected_name):
|
||||
if name is None:
|
||||
raise ValueError(f"Distribution should be a constant which is not None.")
|
||||
raise ValueError(
|
||||
f"Distribution should be a constant which is not None.")
|
||||
if name != expected_name:
|
||||
raise ValueError(f"Expected distribution name is {expected_name}, but got {name}.")
|
||||
raise ValueError(
|
||||
f"Expected distribution name is {expected_name}, but got {name}.")
|
||||
|
||||
|
||||
class CheckTuple(PrimitiveWithInfer):
|
||||
"""
|
||||
|
@ -294,13 +305,13 @@ class CheckTuple(PrimitiveWithInfer):
|
|||
"""
|
||||
@prim_attr_register
|
||||
def __init__(self):
|
||||
"""init Cast"""
|
||||
super(CheckTuple, self).__init__("CheckTuple")
|
||||
self.init_prim_io_names(inputs=['x'], outputs=['dummy_output'])
|
||||
self.init_prim_io_names(inputs=['x', 'name'], outputs=['dummy_output'])
|
||||
|
||||
def __infer__(self, x, name):
|
||||
if not isinstance(x['dtype'], tuple):
|
||||
raise TypeError(f"For {name['value']}, Input type should b a tuple.")
|
||||
raise TypeError(
|
||||
f"For {name['value']}, Input type should b a tuple.")
|
||||
|
||||
out = {'shape': None,
|
||||
'dtype': None,
|
||||
|
@ -310,24 +321,25 @@ class CheckTuple(PrimitiveWithInfer):
|
|||
def __call__(self, x, name):
|
||||
if context.get_context("mode") == 0:
|
||||
return x["value"]
|
||||
#Pynative mode
|
||||
# Pynative mode
|
||||
if isinstance(x, tuple):
|
||||
return x
|
||||
raise TypeError(f"For {name['value']}, Input type should b a tuple.")
|
||||
|
||||
|
||||
class CheckTensor(PrimitiveWithInfer):
|
||||
"""
|
||||
Check if input is a Tensor.
|
||||
"""
|
||||
@prim_attr_register
|
||||
def __init__(self):
|
||||
"""init Cast"""
|
||||
super(CheckTensor, self).__init__("CheckTensor")
|
||||
self.init_prim_io_names(inputs=['x'], outputs=['dummy_output'])
|
||||
self.init_prim_io_names(inputs=['x', 'name'], outputs=['dummy_output'])
|
||||
|
||||
def __infer__(self, x, name):
|
||||
src_type = x['dtype']
|
||||
validator.check_subclass("input", src_type, [mstype.tensor], name["value"])
|
||||
validator.check_subclass(
|
||||
"input", src_type, [mstype.tensor], name["value"])
|
||||
|
||||
out = {'shape': None,
|
||||
'dtype': None,
|
||||
|
|
|
@ -20,6 +20,7 @@ from .distribution import Distribution
|
|||
from ._utils.utils import cast_to_tensor, check_prob, check_type, check_distribution_name, raise_none_error
|
||||
from ._utils.custom_ops import exp_by_step, log_by_step
|
||||
|
||||
|
||||
class Bernoulli(Distribution):
|
||||
"""
|
||||
Bernoulli Distribution.
|
||||
|
@ -97,7 +98,7 @@ class Bernoulli(Distribution):
|
|||
Constructor of Bernoulli distribution.
|
||||
"""
|
||||
param = dict(locals())
|
||||
valid_dtype = mstype.int_type + mstype.uint_type
|
||||
valid_dtype = mstype.int_type + mstype.uint_type + mstype.float_type
|
||||
check_type(dtype, valid_dtype, "Bernoulli")
|
||||
super(Bernoulli, self).__init__(seed, dtype, name, param)
|
||||
self.parameter_type = mstype.float32
|
||||
|
@ -211,7 +212,6 @@ class Bernoulli(Distribution):
|
|||
"""
|
||||
self.checktensor(value, 'value')
|
||||
value = self.cast(value, mstype.float32)
|
||||
value = self.floor(value)
|
||||
probs1 = self._check_param(probs1)
|
||||
probs0 = 1.0 - probs1
|
||||
return self.log(probs1) * value + self.log(probs0) * (1.0 - value)
|
||||
|
|
|
@ -19,9 +19,10 @@ from mindspore.ops import composite as C
|
|||
from mindspore.common import dtype as mstype
|
||||
from .distribution import Distribution
|
||||
from ._utils.utils import cast_to_tensor, check_prob, check_type, check_distribution_name,\
|
||||
raise_none_error
|
||||
raise_none_error
|
||||
from ._utils.custom_ops import exp_by_step, log_by_step
|
||||
|
||||
|
||||
class Geometric(Distribution):
|
||||
"""
|
||||
Geometric Distribution.
|
||||
|
@ -100,7 +101,7 @@ class Geometric(Distribution):
|
|||
Constructor of Geometric distribution.
|
||||
"""
|
||||
param = dict(locals())
|
||||
valid_dtype = mstype.int_type + mstype.uint_type
|
||||
valid_dtype = mstype.int_type + mstype.uint_type + mstype.float_type
|
||||
check_type(dtype, valid_dtype, "Geometric")
|
||||
super(Geometric, self).__init__(seed, dtype, name, param)
|
||||
self.parameter_type = mstype.float32
|
||||
|
@ -130,7 +131,6 @@ class Geometric(Distribution):
|
|||
self.sqrt = P.Sqrt()
|
||||
self.uniform = C.uniform
|
||||
|
||||
|
||||
def extend_repr(self):
|
||||
if self.is_scalar_batch:
|
||||
str_info = f'probs = {self.probs}'
|
||||
|
@ -243,7 +243,6 @@ class Geometric(Distribution):
|
|||
comp = self.less(value, zeros)
|
||||
return self.select(comp, zeros, cdf)
|
||||
|
||||
|
||||
def _kl_loss(self, dist, probs1_b, probs1=None):
|
||||
r"""
|
||||
Evaluate Geometric-Geometric kl divergence, i.e. KL(a||b).
|
||||
|
|
|
@ -22,6 +22,7 @@ import mindspore.nn.probability.distribution as msd
|
|||
from mindspore import dtype
|
||||
from mindspore import Tensor
|
||||
|
||||
|
||||
def test_arguments():
|
||||
"""
|
||||
Args passing during initialization.
|
||||
|
@ -31,18 +32,22 @@ def test_arguments():
|
|||
b = msd.Bernoulli([0.1, 0.3, 0.5, 0.9], dtype=dtype.int32)
|
||||
assert isinstance(b, msd.Distribution)
|
||||
|
||||
|
||||
def test_type():
|
||||
with pytest.raises(TypeError):
|
||||
msd.Bernoulli([0.1], dtype=dtype.float32)
|
||||
msd.Bernoulli([0.1], dtype=dtype.bool_)
|
||||
|
||||
|
||||
def test_name():
|
||||
with pytest.raises(TypeError):
|
||||
msd.Bernoulli([0.1], name=1.0)
|
||||
|
||||
|
||||
def test_seed():
|
||||
with pytest.raises(TypeError):
|
||||
msd.Bernoulli([0.1], seed='seed')
|
||||
|
||||
|
||||
def test_prob():
|
||||
"""
|
||||
Invalid probability.
|
||||
|
@ -56,10 +61,12 @@ def test_prob():
|
|||
with pytest.raises(ValueError):
|
||||
msd.Bernoulli([1.0], dtype=dtype.int32)
|
||||
|
||||
|
||||
class BernoulliProb(nn.Cell):
|
||||
"""
|
||||
Bernoulli distribution: initialize with probs.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super(BernoulliProb, self).__init__()
|
||||
self.b = msd.Bernoulli(0.5, dtype=dtype.int32)
|
||||
|
@ -73,6 +80,7 @@ class BernoulliProb(nn.Cell):
|
|||
log_sf = self.b.log_survival(value)
|
||||
return prob + log_prob + cdf + log_cdf + sf + log_sf
|
||||
|
||||
|
||||
def test_bernoulli_prob():
|
||||
"""
|
||||
Test probability functions: passing value through construct.
|
||||
|
@ -82,10 +90,12 @@ def test_bernoulli_prob():
|
|||
ans = net(value)
|
||||
assert isinstance(ans, Tensor)
|
||||
|
||||
|
||||
class BernoulliProb1(nn.Cell):
|
||||
"""
|
||||
Bernoulli distribution: initialize without probs.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super(BernoulliProb1, self).__init__()
|
||||
self.b = msd.Bernoulli(dtype=dtype.int32)
|
||||
|
@ -99,6 +109,7 @@ class BernoulliProb1(nn.Cell):
|
|||
log_sf = self.b.log_survival(value, probs)
|
||||
return prob + log_prob + cdf + log_cdf + sf + log_sf
|
||||
|
||||
|
||||
def test_bernoulli_prob1():
|
||||
"""
|
||||
Test probability functions: passing value/probs through construct.
|
||||
|
@ -109,10 +120,12 @@ def test_bernoulli_prob1():
|
|||
ans = net(value, probs)
|
||||
assert isinstance(ans, Tensor)
|
||||
|
||||
|
||||
class BernoulliKl(nn.Cell):
|
||||
"""
|
||||
Test class: kl_loss between Bernoulli distributions.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super(BernoulliKl, self).__init__()
|
||||
self.b1 = msd.Bernoulli(0.7, dtype=dtype.int32)
|
||||
|
@ -123,6 +136,7 @@ class BernoulliKl(nn.Cell):
|
|||
kl2 = self.b2.kl_loss('Bernoulli', probs_b, probs_a)
|
||||
return kl1 + kl2
|
||||
|
||||
|
||||
def test_kl():
|
||||
"""
|
||||
Test kl_loss function.
|
||||
|
@ -133,10 +147,12 @@ def test_kl():
|
|||
ans = ber_net(probs_b, probs_a)
|
||||
assert isinstance(ans, Tensor)
|
||||
|
||||
|
||||
class BernoulliCrossEntropy(nn.Cell):
|
||||
"""
|
||||
Test class: cross_entropy of Bernoulli distribution.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super(BernoulliCrossEntropy, self).__init__()
|
||||
self.b1 = msd.Bernoulli(0.7, dtype=dtype.int32)
|
||||
|
@ -147,6 +163,7 @@ class BernoulliCrossEntropy(nn.Cell):
|
|||
h2 = self.b2.cross_entropy('Bernoulli', probs_b, probs_a)
|
||||
return h1 + h2
|
||||
|
||||
|
||||
def test_cross_entropy():
|
||||
"""
|
||||
Test cross_entropy between Bernoulli distributions.
|
||||
|
@ -157,10 +174,12 @@ def test_cross_entropy():
|
|||
ans = net(probs_b, probs_a)
|
||||
assert isinstance(ans, Tensor)
|
||||
|
||||
|
||||
class BernoulliConstruct(nn.Cell):
|
||||
"""
|
||||
Bernoulli distribution: going through construct.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super(BernoulliConstruct, self).__init__()
|
||||
self.b = msd.Bernoulli(0.5, dtype=dtype.int32)
|
||||
|
@ -172,6 +191,7 @@ class BernoulliConstruct(nn.Cell):
|
|||
prob2 = self.b1('prob', value, probs)
|
||||
return prob + prob1 + prob2
|
||||
|
||||
|
||||
def test_bernoulli_construct():
|
||||
"""
|
||||
Test probability function going through construct.
|
||||
|
@ -182,10 +202,12 @@ def test_bernoulli_construct():
|
|||
ans = net(value, probs)
|
||||
assert isinstance(ans, Tensor)
|
||||
|
||||
|
||||
class BernoulliMean(nn.Cell):
|
||||
"""
|
||||
Test class: basic mean/sd/var/mode/entropy function.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super(BernoulliMean, self).__init__()
|
||||
self.b = msd.Bernoulli([0.3, 0.5], dtype=dtype.int32)
|
||||
|
@ -194,6 +216,7 @@ class BernoulliMean(nn.Cell):
|
|||
mean = self.b.mean()
|
||||
return mean
|
||||
|
||||
|
||||
def test_mean():
|
||||
"""
|
||||
Test mean/sd/var/mode/entropy functionality of Bernoulli distribution.
|
||||
|
@ -202,10 +225,12 @@ def test_mean():
|
|||
ans = net()
|
||||
assert isinstance(ans, Tensor)
|
||||
|
||||
|
||||
class BernoulliSd(nn.Cell):
|
||||
"""
|
||||
Test class: basic mean/sd/var/mode/entropy function.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super(BernoulliSd, self).__init__()
|
||||
self.b = msd.Bernoulli([0.3, 0.5], dtype=dtype.int32)
|
||||
|
@ -214,6 +239,7 @@ class BernoulliSd(nn.Cell):
|
|||
sd = self.b.sd()
|
||||
return sd
|
||||
|
||||
|
||||
def test_sd():
|
||||
"""
|
||||
Test mean/sd/var/mode/entropy functionality of Bernoulli distribution.
|
||||
|
@ -222,10 +248,12 @@ def test_sd():
|
|||
ans = net()
|
||||
assert isinstance(ans, Tensor)
|
||||
|
||||
|
||||
class BernoulliVar(nn.Cell):
|
||||
"""
|
||||
Test class: basic mean/sd/var/mode/entropy function.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super(BernoulliVar, self).__init__()
|
||||
self.b = msd.Bernoulli([0.3, 0.5], dtype=dtype.int32)
|
||||
|
@ -234,6 +262,7 @@ class BernoulliVar(nn.Cell):
|
|||
var = self.b.var()
|
||||
return var
|
||||
|
||||
|
||||
def test_var():
|
||||
"""
|
||||
Test mean/sd/var/mode/entropy functionality of Bernoulli distribution.
|
||||
|
@ -242,10 +271,12 @@ def test_var():
|
|||
ans = net()
|
||||
assert isinstance(ans, Tensor)
|
||||
|
||||
|
||||
class BernoulliMode(nn.Cell):
|
||||
"""
|
||||
Test class: basic mean/sd/var/mode/entropy function.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super(BernoulliMode, self).__init__()
|
||||
self.b = msd.Bernoulli([0.3, 0.5], dtype=dtype.int32)
|
||||
|
@ -254,6 +285,7 @@ class BernoulliMode(nn.Cell):
|
|||
mode = self.b.mode()
|
||||
return mode
|
||||
|
||||
|
||||
def test_mode():
|
||||
"""
|
||||
Test mean/sd/var/mode/entropy functionality of Bernoulli distribution.
|
||||
|
@ -262,10 +294,12 @@ def test_mode():
|
|||
ans = net()
|
||||
assert isinstance(ans, Tensor)
|
||||
|
||||
|
||||
class BernoulliEntropy(nn.Cell):
|
||||
"""
|
||||
Test class: basic mean/sd/var/mode/entropy function.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super(BernoulliEntropy, self).__init__()
|
||||
self.b = msd.Bernoulli([0.3, 0.5], dtype=dtype.int32)
|
||||
|
@ -274,6 +308,7 @@ class BernoulliEntropy(nn.Cell):
|
|||
entropy = self.b.entropy()
|
||||
return entropy
|
||||
|
||||
|
||||
def test_entropy():
|
||||
"""
|
||||
Test mean/sd/var/mode/entropy functionality of Bernoulli distribution.
|
||||
|
|
|
@ -32,18 +32,22 @@ def test_arguments():
|
|||
g = msd.Geometric([0.1, 0.3, 0.5, 0.9], dtype=dtype.int32)
|
||||
assert isinstance(g, msd.Distribution)
|
||||
|
||||
|
||||
def test_type():
|
||||
with pytest.raises(TypeError):
|
||||
msd.Geometric([0.1], dtype=dtype.float32)
|
||||
msd.Geometric([0.1], dtype=dtype.bool_)
|
||||
|
||||
|
||||
def test_name():
|
||||
with pytest.raises(TypeError):
|
||||
msd.Geometric([0.1], name=1.0)
|
||||
|
||||
|
||||
def test_seed():
|
||||
with pytest.raises(TypeError):
|
||||
msd.Geometric([0.1], seed='seed')
|
||||
|
||||
|
||||
def test_prob():
|
||||
"""
|
||||
Invalid probability.
|
||||
|
@ -57,10 +61,12 @@ def test_prob():
|
|||
with pytest.raises(ValueError):
|
||||
msd.Geometric([1.0], dtype=dtype.int32)
|
||||
|
||||
|
||||
class GeometricProb(nn.Cell):
|
||||
"""
|
||||
Geometric distribution: initialize with probs.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super(GeometricProb, self).__init__()
|
||||
self.g = msd.Geometric(0.5, dtype=dtype.int32)
|
||||
|
@ -74,6 +80,7 @@ class GeometricProb(nn.Cell):
|
|||
log_sf = self.g.log_survival(value)
|
||||
return prob + log_prob + cdf + log_cdf + sf + log_sf
|
||||
|
||||
|
||||
def test_geometric_prob():
|
||||
"""
|
||||
Test probability functions: passing value through construct.
|
||||
|
@ -83,10 +90,12 @@ def test_geometric_prob():
|
|||
ans = net(value)
|
||||
assert isinstance(ans, Tensor)
|
||||
|
||||
|
||||
class GeometricProb1(nn.Cell):
|
||||
"""
|
||||
Geometric distribution: initialize without probs.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super(GeometricProb1, self).__init__()
|
||||
self.g = msd.Geometric(dtype=dtype.int32)
|
||||
|
@ -100,6 +109,7 @@ class GeometricProb1(nn.Cell):
|
|||
log_sf = self.g.log_survival(value, probs)
|
||||
return prob + log_prob + cdf + log_cdf + sf + log_sf
|
||||
|
||||
|
||||
def test_geometric_prob1():
|
||||
"""
|
||||
Test probability functions: passing value/probs through construct.
|
||||
|
@ -115,6 +125,7 @@ class GeometricKl(nn.Cell):
|
|||
"""
|
||||
Test class: kl_loss between Geometric distributions.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super(GeometricKl, self).__init__()
|
||||
self.g1 = msd.Geometric(0.7, dtype=dtype.int32)
|
||||
|
@ -125,6 +136,7 @@ class GeometricKl(nn.Cell):
|
|||
kl2 = self.g2.kl_loss('Geometric', probs_b, probs_a)
|
||||
return kl1 + kl2
|
||||
|
||||
|
||||
def test_kl():
|
||||
"""
|
||||
Test kl_loss function.
|
||||
|
@ -135,10 +147,12 @@ def test_kl():
|
|||
ans = ber_net(probs_b, probs_a)
|
||||
assert isinstance(ans, Tensor)
|
||||
|
||||
|
||||
class GeometricCrossEntropy(nn.Cell):
|
||||
"""
|
||||
Test class: cross_entropy of Geometric distribution.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super(GeometricCrossEntropy, self).__init__()
|
||||
self.g1 = msd.Geometric(0.3, dtype=dtype.int32)
|
||||
|
@ -149,6 +163,7 @@ class GeometricCrossEntropy(nn.Cell):
|
|||
h2 = self.g2.cross_entropy('Geometric', probs_b, probs_a)
|
||||
return h1 + h2
|
||||
|
||||
|
||||
def test_cross_entropy():
|
||||
"""
|
||||
Test cross_entropy between Geometric distributions.
|
||||
|
@ -159,10 +174,12 @@ def test_cross_entropy():
|
|||
ans = net(probs_b, probs_a)
|
||||
assert isinstance(ans, Tensor)
|
||||
|
||||
|
||||
class GeometricBasics(nn.Cell):
|
||||
"""
|
||||
Test class: basic mean/sd/mode/entropy function.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super(GeometricBasics, self).__init__()
|
||||
self.g = msd.Geometric([0.3, 0.5], dtype=dtype.int32)
|
||||
|
@ -175,6 +192,7 @@ class GeometricBasics(nn.Cell):
|
|||
entropy = self.g.entropy()
|
||||
return mean + sd + var + mode + entropy
|
||||
|
||||
|
||||
def test_bascis():
|
||||
"""
|
||||
Test mean/sd/mode/entropy functionality of Geometric distribution.
|
||||
|
@ -188,6 +206,7 @@ class GeoConstruct(nn.Cell):
|
|||
"""
|
||||
Bernoulli distribution: going through construct.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super(GeoConstruct, self).__init__()
|
||||
self.g = msd.Geometric(0.5, dtype=dtype.int32)
|
||||
|
@ -199,6 +218,7 @@ class GeoConstruct(nn.Cell):
|
|||
prob2 = self.g1('prob', value, probs)
|
||||
return prob + prob1 + prob2
|
||||
|
||||
|
||||
def test_geo_construct():
|
||||
"""
|
||||
Test probability function going through construct.
|
||||
|
|
Loading…
Reference in New Issue