fix bernoulli prob formula; fix some other minor bugs

update threshold of softplus computation

support fp for bernoulli and geometric distribution
This commit is contained in:
Zichun Ye 2020-08-23 15:40:09 -04:00
parent befc209480
commit 9e7d6e2397
8 changed files with 113 additions and 40 deletions

View File

@ -20,6 +20,7 @@ from ..distribution._utils.utils import CheckTensor
from ..distribution._utils.custom_ops import exp_by_step, expm1_by_step, log_by_step, log1p_by_step
from .bijector import Bijector
class PowerTransform(Bijector):
r"""
Power Bijector.
@ -49,6 +50,7 @@ class PowerTransform(Bijector):
>>> # by replacing 'forward' with the name of the function
>>> ans = self.p1.forward(, value)
"""
def __init__(self,
power=0,
name='PowerTransform',
@ -78,13 +80,13 @@ class PowerTransform(Bijector):
return shape
def _forward(self, x):
self.checktensor(x, 'x')
self.checktensor(x, 'value')
if self.power == 0:
return self.exp(x)
return self.exp(self.log1p(x * self.power) / self.power)
def _inverse(self, y):
self.checktensor(y, 'y')
self.checktensor(y, 'value')
if self.power == 0:
return self.log(y)
return self.expm1(self.log(y) * self.power) / self.power
@ -101,7 +103,7 @@ class PowerTransform(Bijector):
f'(x) = e^\frac{\log(xc + 1)}{c} * \frac{1}{xc + 1}
\log(f'(x)) = (\frac{1}{c} - 1) * \log(xc + 1)
"""
self.checktensor(x, 'x')
self.checktensor(x, 'value')
if self.power == 0:
return x
return (1. / self.power - 1) * self.log1p(x * self.power)
@ -118,5 +120,5 @@ class PowerTransform(Bijector):
f'(x) = \frac{e^c\log(y)}{y}
\log(f'(x)) = \log(\frac{e^c\log(y)}{y}) = (c-1) * \log(y)
"""
self.checktensor(y, 'y')
self.checktensor(y, 'value')
return (self.power - 1) * self.log(y)

View File

@ -19,6 +19,7 @@ from ..distribution._utils.utils import cast_to_tensor, CheckTensor
from ..distribution._utils.custom_ops import log_by_step
from .bijector import Bijector
class ScalarAffine(Bijector):
"""
Scalar Affine Bijector.
@ -47,6 +48,7 @@ class ScalarAffine(Bijector):
>>> ans = self.s1.forward_log_jacobian(value)
>>> ans = self.s1.inverse_log_jacobian(value)
"""
def __init__(self,
scale=1.0,
shift=0.0,
@ -91,7 +93,7 @@ class ScalarAffine(Bijector):
.. math::
f(x) = a * x + b
"""
self.checktensor(x, 'x')
self.checktensor(x, 'value')
return self.scale * x + self.shift
def _inverse(self, y):
@ -99,7 +101,7 @@ class ScalarAffine(Bijector):
.. math::
f(y) = \frac{y - b}{a}
"""
self.checktensor(y, 'y')
self.checktensor(y, 'value')
return (y - self.shift) / self.scale
def _forward_log_jacobian(self, x):
@ -109,7 +111,7 @@ class ScalarAffine(Bijector):
f'(x) = a
\log(f'(x)) = \log(a)
"""
self.checktensor(x, 'x')
self.checktensor(x, 'value')
return self.log(self.abs(self.scale))
def _inverse_log_jacobian(self, y):
@ -119,5 +121,5 @@ class ScalarAffine(Bijector):
f'(x) = \frac{1.0}{a}
\log(f'(x)) = - \log(a)
"""
self.checktensor(y, 'y')
self.checktensor(y, 'value')
return -1. * self.log(self.abs(self.scale))

View File

@ -22,6 +22,7 @@ from ..distribution._utils.utils import cast_to_tensor, CheckTensor
from ..distribution._utils.custom_ops import exp_by_step, expm1_by_step, log_by_step
from .bijector import Bijector
class Softplus(Bijector):
r"""
Softplus Bijector.
@ -51,6 +52,7 @@ class Softplus(Bijector):
>>> ans = self.sp1.forward_log_jacobian(value)
>>> ans = self.sp1.inverse_log_jacobian(value)
"""
def __init__(self,
sharpness=1.0,
name='Softplus'):
@ -76,6 +78,7 @@ class Softplus(Bijector):
self.checktensor = CheckTensor()
self.threshold = np.log(np.finfo(np.float32).eps) + 1
self.tiny = np.exp(self.threshold)
def _softplus(self, x):
too_small = self.less(x, self.threshold)
@ -94,7 +97,7 @@ class Softplus(Bijector):
f(x) = \frac{\log(1 + e^{x}))}
f^{-1}(y) = \frac{\log(e^{y} - 1)}
"""
too_small = self.less(x, self.threshold)
too_small = self.less(x, self.tiny)
too_large = self.greater(x, -self.threshold)
too_small_value = self.log(x)
too_large_value = x
@ -116,7 +119,7 @@ class Softplus(Bijector):
return shape
def _forward(self, x):
self.checktensor(x, 'x')
self.checktensor(x, 'value')
scaled_value = self.sharpness * x
return self.softplus(scaled_value) / self.sharpness
@ -126,7 +129,7 @@ class Softplus(Bijector):
f(x) = \frac{\log(1 + e^{kx}))}{k}
f^{-1}(y) = \frac{\log(e^{ky} - 1)}{k}
"""
self.checktensor(y, 'y')
self.checktensor(y, 'value')
scaled_value = self.sharpness * y
return self.inverse_softplus(scaled_value) / self.sharpness
@ -137,7 +140,7 @@ class Softplus(Bijector):
f'(x) = \frac{e^{kx}}{ 1 + e^{kx}}
\log(f'(x)) = kx - \log(1 + e^{kx}) = kx - f(kx)
"""
self.checktensor(x, 'x')
self.checktensor(x, 'value')
scaled_value = self.sharpness * x
return self.log_sigmoid(scaled_value)
@ -148,6 +151,6 @@ class Softplus(Bijector):
f'(y) = \frac{e^{ky}}{e^{ky} - 1}
\log(f'(y)) = ky - \log(e^{ky} - 1) = ky - f(ky)
"""
self.checktensor(y, 'y')
self.checktensor(y, 'value')
scaled_value = self.sharpness * y
return scaled_value - self.inverse_softplus(scaled_value)

View File

@ -26,6 +26,7 @@ from mindspore import context
import mindspore.nn as nn
import mindspore.nn.probability as msp
def cast_to_tensor(t, hint_type=mstype.float32):
"""
Cast an user input value into a Tensor of dtype.
@ -47,7 +48,7 @@ def cast_to_tensor(t, hint_type=mstype.float32):
return t
t_type = hint_type
if isinstance(t, Tensor):
#convert the type of tensor to dtype
# convert the type of tensor to dtype
return Tensor(t.asnumpy(), dtype=t_type)
if isinstance(t, (list, np.ndarray)):
return Tensor(t, dtype=t_type)
@ -56,7 +57,8 @@ def cast_to_tensor(t, hint_type=mstype.float32):
if isinstance(t, (int, float)):
return Tensor(t, dtype=t_type)
invalid_type = type(t)
raise TypeError(f"Unable to convert input of type {invalid_type} to a Tensor of type {t_type}")
raise TypeError(
f"Unable to convert input of type {invalid_type} to a Tensor of type {t_type}")
def convert_to_batch(t, batch_shape, required_type):
@ -79,6 +81,7 @@ def convert_to_batch(t, batch_shape, required_type):
t = cast_to_tensor(t, required_type)
return Tensor(np.broadcast_to(t.asnumpy(), batch_shape), dtype=required_type)
def check_scalar_from_param(params):
"""
Check if params are all scalars.
@ -93,11 +96,7 @@ def check_scalar_from_param(params):
return params['distribution'].is_scalar_batch
if isinstance(value, Parameter):
return False
if isinstance(value, (str, type(params['dtype']))):
continue
elif isinstance(value, (int, float)):
continue
else:
if not isinstance(value, (int, float, str, type(params['dtype']))):
return False
return True
@ -124,7 +123,8 @@ def calc_broadcast_shape_from_param(params):
value_t = value.default_input
else:
value_t = cast_to_tensor(value, mstype.float32)
broadcast_shape = utils.get_broadcast_shape(broadcast_shape, list(value_t.shape), params['name'])
broadcast_shape = utils.get_broadcast_shape(
broadcast_shape, list(value_t.shape), params['name'])
return tuple(broadcast_shape)
@ -148,6 +148,7 @@ def check_greater_equal_zero(value, name):
if comp.any():
raise ValueError(f'{name} should be greater than ot equal to zero.')
def check_greater_zero(value, name):
"""
Check if the given Tensor is strictly greater than zero.
@ -251,6 +252,7 @@ def probs_to_logits(probs, is_binary=False):
return P.Log()(ps_clamped) - P.Log()(1-ps_clamped)
return P.Log()(ps_clamped)
def check_tensor_type(name, inputs, valid_type):
"""
Check if inputs is proper.
@ -268,25 +270,34 @@ def check_tensor_type(name, inputs, valid_type):
if input_type not in valid_type:
raise TypeError(f"{name} dtype is invalid")
def check_type(data_type, value_type, name):
if not data_type in value_type:
raise TypeError(f"For {name}, valid type include {value_type}, {data_type} is invalid")
raise TypeError(
f"For {name}, valid type include {value_type}, {data_type} is invalid")
@constexpr
def raise_none_error(name):
raise TypeError(f"the type {name} should be subclass of Tensor."
f" It should not be None since it is not specified during initialization.")
@constexpr
def raise_not_impl_error(name):
raise ValueError(f"{name} function should be implemented for non-linear transformation")
raise ValueError(
f"{name} function should be implemented for non-linear transformation")
@constexpr
def check_distribution_name(name, expected_name):
if name is None:
raise ValueError(f"Distribution should be a constant which is not None.")
raise ValueError(
f"Distribution should be a constant which is not None.")
if name != expected_name:
raise ValueError(f"Expected distribution name is {expected_name}, but got {name}.")
raise ValueError(
f"Expected distribution name is {expected_name}, but got {name}.")
class CheckTuple(PrimitiveWithInfer):
"""
@ -294,13 +305,13 @@ class CheckTuple(PrimitiveWithInfer):
"""
@prim_attr_register
def __init__(self):
"""init Cast"""
super(CheckTuple, self).__init__("CheckTuple")
self.init_prim_io_names(inputs=['x'], outputs=['dummy_output'])
self.init_prim_io_names(inputs=['x', 'name'], outputs=['dummy_output'])
def __infer__(self, x, name):
if not isinstance(x['dtype'], tuple):
raise TypeError(f"For {name['value']}, Input type should b a tuple.")
raise TypeError(
f"For {name['value']}, Input type should b a tuple.")
out = {'shape': None,
'dtype': None,
@ -310,24 +321,25 @@ class CheckTuple(PrimitiveWithInfer):
def __call__(self, x, name):
if context.get_context("mode") == 0:
return x["value"]
#Pynative mode
# Pynative mode
if isinstance(x, tuple):
return x
raise TypeError(f"For {name['value']}, Input type should b a tuple.")
class CheckTensor(PrimitiveWithInfer):
"""
Check if input is a Tensor.
"""
@prim_attr_register
def __init__(self):
"""init Cast"""
super(CheckTensor, self).__init__("CheckTensor")
self.init_prim_io_names(inputs=['x'], outputs=['dummy_output'])
self.init_prim_io_names(inputs=['x', 'name'], outputs=['dummy_output'])
def __infer__(self, x, name):
src_type = x['dtype']
validator.check_subclass("input", src_type, [mstype.tensor], name["value"])
validator.check_subclass(
"input", src_type, [mstype.tensor], name["value"])
out = {'shape': None,
'dtype': None,

View File

@ -20,6 +20,7 @@ from .distribution import Distribution
from ._utils.utils import cast_to_tensor, check_prob, check_type, check_distribution_name, raise_none_error
from ._utils.custom_ops import exp_by_step, log_by_step
class Bernoulli(Distribution):
"""
Bernoulli Distribution.
@ -97,7 +98,7 @@ class Bernoulli(Distribution):
Constructor of Bernoulli distribution.
"""
param = dict(locals())
valid_dtype = mstype.int_type + mstype.uint_type
valid_dtype = mstype.int_type + mstype.uint_type + mstype.float_type
check_type(dtype, valid_dtype, "Bernoulli")
super(Bernoulli, self).__init__(seed, dtype, name, param)
self.parameter_type = mstype.float32
@ -211,7 +212,6 @@ class Bernoulli(Distribution):
"""
self.checktensor(value, 'value')
value = self.cast(value, mstype.float32)
value = self.floor(value)
probs1 = self._check_param(probs1)
probs0 = 1.0 - probs1
return self.log(probs1) * value + self.log(probs0) * (1.0 - value)

View File

@ -19,9 +19,10 @@ from mindspore.ops import composite as C
from mindspore.common import dtype as mstype
from .distribution import Distribution
from ._utils.utils import cast_to_tensor, check_prob, check_type, check_distribution_name,\
raise_none_error
raise_none_error
from ._utils.custom_ops import exp_by_step, log_by_step
class Geometric(Distribution):
"""
Geometric Distribution.
@ -100,7 +101,7 @@ class Geometric(Distribution):
Constructor of Geometric distribution.
"""
param = dict(locals())
valid_dtype = mstype.int_type + mstype.uint_type
valid_dtype = mstype.int_type + mstype.uint_type + mstype.float_type
check_type(dtype, valid_dtype, "Geometric")
super(Geometric, self).__init__(seed, dtype, name, param)
self.parameter_type = mstype.float32
@ -130,7 +131,6 @@ class Geometric(Distribution):
self.sqrt = P.Sqrt()
self.uniform = C.uniform
def extend_repr(self):
if self.is_scalar_batch:
str_info = f'probs = {self.probs}'
@ -243,7 +243,6 @@ class Geometric(Distribution):
comp = self.less(value, zeros)
return self.select(comp, zeros, cdf)
def _kl_loss(self, dist, probs1_b, probs1=None):
r"""
Evaluate Geometric-Geometric kl divergence, i.e. KL(a||b).

View File

@ -22,6 +22,7 @@ import mindspore.nn.probability.distribution as msd
from mindspore import dtype
from mindspore import Tensor
def test_arguments():
"""
Args passing during initialization.
@ -31,18 +32,22 @@ def test_arguments():
b = msd.Bernoulli([0.1, 0.3, 0.5, 0.9], dtype=dtype.int32)
assert isinstance(b, msd.Distribution)
def test_type():
with pytest.raises(TypeError):
msd.Bernoulli([0.1], dtype=dtype.float32)
msd.Bernoulli([0.1], dtype=dtype.bool_)
def test_name():
with pytest.raises(TypeError):
msd.Bernoulli([0.1], name=1.0)
def test_seed():
with pytest.raises(TypeError):
msd.Bernoulli([0.1], seed='seed')
def test_prob():
"""
Invalid probability.
@ -56,10 +61,12 @@ def test_prob():
with pytest.raises(ValueError):
msd.Bernoulli([1.0], dtype=dtype.int32)
class BernoulliProb(nn.Cell):
"""
Bernoulli distribution: initialize with probs.
"""
def __init__(self):
super(BernoulliProb, self).__init__()
self.b = msd.Bernoulli(0.5, dtype=dtype.int32)
@ -73,6 +80,7 @@ class BernoulliProb(nn.Cell):
log_sf = self.b.log_survival(value)
return prob + log_prob + cdf + log_cdf + sf + log_sf
def test_bernoulli_prob():
"""
Test probability functions: passing value through construct.
@ -82,10 +90,12 @@ def test_bernoulli_prob():
ans = net(value)
assert isinstance(ans, Tensor)
class BernoulliProb1(nn.Cell):
"""
Bernoulli distribution: initialize without probs.
"""
def __init__(self):
super(BernoulliProb1, self).__init__()
self.b = msd.Bernoulli(dtype=dtype.int32)
@ -99,6 +109,7 @@ class BernoulliProb1(nn.Cell):
log_sf = self.b.log_survival(value, probs)
return prob + log_prob + cdf + log_cdf + sf + log_sf
def test_bernoulli_prob1():
"""
Test probability functions: passing value/probs through construct.
@ -109,10 +120,12 @@ def test_bernoulli_prob1():
ans = net(value, probs)
assert isinstance(ans, Tensor)
class BernoulliKl(nn.Cell):
"""
Test class: kl_loss between Bernoulli distributions.
"""
def __init__(self):
super(BernoulliKl, self).__init__()
self.b1 = msd.Bernoulli(0.7, dtype=dtype.int32)
@ -123,6 +136,7 @@ class BernoulliKl(nn.Cell):
kl2 = self.b2.kl_loss('Bernoulli', probs_b, probs_a)
return kl1 + kl2
def test_kl():
"""
Test kl_loss function.
@ -133,10 +147,12 @@ def test_kl():
ans = ber_net(probs_b, probs_a)
assert isinstance(ans, Tensor)
class BernoulliCrossEntropy(nn.Cell):
"""
Test class: cross_entropy of Bernoulli distribution.
"""
def __init__(self):
super(BernoulliCrossEntropy, self).__init__()
self.b1 = msd.Bernoulli(0.7, dtype=dtype.int32)
@ -147,6 +163,7 @@ class BernoulliCrossEntropy(nn.Cell):
h2 = self.b2.cross_entropy('Bernoulli', probs_b, probs_a)
return h1 + h2
def test_cross_entropy():
"""
Test cross_entropy between Bernoulli distributions.
@ -157,10 +174,12 @@ def test_cross_entropy():
ans = net(probs_b, probs_a)
assert isinstance(ans, Tensor)
class BernoulliConstruct(nn.Cell):
"""
Bernoulli distribution: going through construct.
"""
def __init__(self):
super(BernoulliConstruct, self).__init__()
self.b = msd.Bernoulli(0.5, dtype=dtype.int32)
@ -172,6 +191,7 @@ class BernoulliConstruct(nn.Cell):
prob2 = self.b1('prob', value, probs)
return prob + prob1 + prob2
def test_bernoulli_construct():
"""
Test probability function going through construct.
@ -182,10 +202,12 @@ def test_bernoulli_construct():
ans = net(value, probs)
assert isinstance(ans, Tensor)
class BernoulliMean(nn.Cell):
"""
Test class: basic mean/sd/var/mode/entropy function.
"""
def __init__(self):
super(BernoulliMean, self).__init__()
self.b = msd.Bernoulli([0.3, 0.5], dtype=dtype.int32)
@ -194,6 +216,7 @@ class BernoulliMean(nn.Cell):
mean = self.b.mean()
return mean
def test_mean():
"""
Test mean/sd/var/mode/entropy functionality of Bernoulli distribution.
@ -202,10 +225,12 @@ def test_mean():
ans = net()
assert isinstance(ans, Tensor)
class BernoulliSd(nn.Cell):
"""
Test class: basic mean/sd/var/mode/entropy function.
"""
def __init__(self):
super(BernoulliSd, self).__init__()
self.b = msd.Bernoulli([0.3, 0.5], dtype=dtype.int32)
@ -214,6 +239,7 @@ class BernoulliSd(nn.Cell):
sd = self.b.sd()
return sd
def test_sd():
"""
Test mean/sd/var/mode/entropy functionality of Bernoulli distribution.
@ -222,10 +248,12 @@ def test_sd():
ans = net()
assert isinstance(ans, Tensor)
class BernoulliVar(nn.Cell):
"""
Test class: basic mean/sd/var/mode/entropy function.
"""
def __init__(self):
super(BernoulliVar, self).__init__()
self.b = msd.Bernoulli([0.3, 0.5], dtype=dtype.int32)
@ -234,6 +262,7 @@ class BernoulliVar(nn.Cell):
var = self.b.var()
return var
def test_var():
"""
Test mean/sd/var/mode/entropy functionality of Bernoulli distribution.
@ -242,10 +271,12 @@ def test_var():
ans = net()
assert isinstance(ans, Tensor)
class BernoulliMode(nn.Cell):
"""
Test class: basic mean/sd/var/mode/entropy function.
"""
def __init__(self):
super(BernoulliMode, self).__init__()
self.b = msd.Bernoulli([0.3, 0.5], dtype=dtype.int32)
@ -254,6 +285,7 @@ class BernoulliMode(nn.Cell):
mode = self.b.mode()
return mode
def test_mode():
"""
Test mean/sd/var/mode/entropy functionality of Bernoulli distribution.
@ -262,10 +294,12 @@ def test_mode():
ans = net()
assert isinstance(ans, Tensor)
class BernoulliEntropy(nn.Cell):
"""
Test class: basic mean/sd/var/mode/entropy function.
"""
def __init__(self):
super(BernoulliEntropy, self).__init__()
self.b = msd.Bernoulli([0.3, 0.5], dtype=dtype.int32)
@ -274,6 +308,7 @@ class BernoulliEntropy(nn.Cell):
entropy = self.b.entropy()
return entropy
def test_entropy():
"""
Test mean/sd/var/mode/entropy functionality of Bernoulli distribution.

View File

@ -32,18 +32,22 @@ def test_arguments():
g = msd.Geometric([0.1, 0.3, 0.5, 0.9], dtype=dtype.int32)
assert isinstance(g, msd.Distribution)
def test_type():
with pytest.raises(TypeError):
msd.Geometric([0.1], dtype=dtype.float32)
msd.Geometric([0.1], dtype=dtype.bool_)
def test_name():
with pytest.raises(TypeError):
msd.Geometric([0.1], name=1.0)
def test_seed():
with pytest.raises(TypeError):
msd.Geometric([0.1], seed='seed')
def test_prob():
"""
Invalid probability.
@ -57,10 +61,12 @@ def test_prob():
with pytest.raises(ValueError):
msd.Geometric([1.0], dtype=dtype.int32)
class GeometricProb(nn.Cell):
"""
Geometric distribution: initialize with probs.
"""
def __init__(self):
super(GeometricProb, self).__init__()
self.g = msd.Geometric(0.5, dtype=dtype.int32)
@ -74,6 +80,7 @@ class GeometricProb(nn.Cell):
log_sf = self.g.log_survival(value)
return prob + log_prob + cdf + log_cdf + sf + log_sf
def test_geometric_prob():
"""
Test probability functions: passing value through construct.
@ -83,10 +90,12 @@ def test_geometric_prob():
ans = net(value)
assert isinstance(ans, Tensor)
class GeometricProb1(nn.Cell):
"""
Geometric distribution: initialize without probs.
"""
def __init__(self):
super(GeometricProb1, self).__init__()
self.g = msd.Geometric(dtype=dtype.int32)
@ -100,6 +109,7 @@ class GeometricProb1(nn.Cell):
log_sf = self.g.log_survival(value, probs)
return prob + log_prob + cdf + log_cdf + sf + log_sf
def test_geometric_prob1():
"""
Test probability functions: passing value/probs through construct.
@ -115,6 +125,7 @@ class GeometricKl(nn.Cell):
"""
Test class: kl_loss between Geometric distributions.
"""
def __init__(self):
super(GeometricKl, self).__init__()
self.g1 = msd.Geometric(0.7, dtype=dtype.int32)
@ -125,6 +136,7 @@ class GeometricKl(nn.Cell):
kl2 = self.g2.kl_loss('Geometric', probs_b, probs_a)
return kl1 + kl2
def test_kl():
"""
Test kl_loss function.
@ -135,10 +147,12 @@ def test_kl():
ans = ber_net(probs_b, probs_a)
assert isinstance(ans, Tensor)
class GeometricCrossEntropy(nn.Cell):
"""
Test class: cross_entropy of Geometric distribution.
"""
def __init__(self):
super(GeometricCrossEntropy, self).__init__()
self.g1 = msd.Geometric(0.3, dtype=dtype.int32)
@ -149,6 +163,7 @@ class GeometricCrossEntropy(nn.Cell):
h2 = self.g2.cross_entropy('Geometric', probs_b, probs_a)
return h1 + h2
def test_cross_entropy():
"""
Test cross_entropy between Geometric distributions.
@ -159,10 +174,12 @@ def test_cross_entropy():
ans = net(probs_b, probs_a)
assert isinstance(ans, Tensor)
class GeometricBasics(nn.Cell):
"""
Test class: basic mean/sd/mode/entropy function.
"""
def __init__(self):
super(GeometricBasics, self).__init__()
self.g = msd.Geometric([0.3, 0.5], dtype=dtype.int32)
@ -175,6 +192,7 @@ class GeometricBasics(nn.Cell):
entropy = self.g.entropy()
return mean + sd + var + mode + entropy
def test_bascis():
"""
Test mean/sd/mode/entropy functionality of Geometric distribution.
@ -188,6 +206,7 @@ class GeoConstruct(nn.Cell):
"""
Bernoulli distribution: going through construct.
"""
def __init__(self):
super(GeoConstruct, self).__init__()
self.g = msd.Geometric(0.5, dtype=dtype.int32)
@ -199,6 +218,7 @@ class GeoConstruct(nn.Cell):
prob2 = self.g1('prob', value, probs)
return prob + prob1 + prob2
def test_geo_construct():
"""
Test probability function going through construct.