forked from mindspore-Ecosystem/mindspore
!8752 Add Poisson distribution
From: @peixu_ren Reviewed-by: Signed-off-by:
This commit is contained in:
commit
21b501bb94
|
@ -18,27 +18,29 @@ Distributions are the high-level components used to construct the probabilistic
|
|||
|
||||
from .distribution import Distribution
|
||||
from .transformed_distribution import TransformedDistribution
|
||||
from .normal import Normal
|
||||
from .bernoulli import Bernoulli
|
||||
from .exponential import Exponential
|
||||
from .uniform import Uniform
|
||||
from .geometric import Geometric
|
||||
from .categorical import Categorical
|
||||
from .log_normal import LogNormal
|
||||
from .logistic import Logistic
|
||||
from .gumbel import Gumbel
|
||||
from .cauchy import Cauchy
|
||||
from .exponential import Exponential
|
||||
from .geometric import Geometric
|
||||
from .gumbel import Gumbel
|
||||
from .logistic import Logistic
|
||||
from .log_normal import LogNormal
|
||||
from .normal import Normal
|
||||
from .poisson import Poisson
|
||||
from .uniform import Uniform
|
||||
|
||||
__all__ = ['Distribution',
|
||||
'TransformedDistribution',
|
||||
'Normal',
|
||||
'Bernoulli',
|
||||
'Exponential',
|
||||
'Uniform',
|
||||
'Categorical',
|
||||
'Geometric',
|
||||
'LogNormal',
|
||||
'Logistic',
|
||||
'Gumbel',
|
||||
'Cauchy',
|
||||
'Exponential',
|
||||
'Geometric',
|
||||
'Gumbel',
|
||||
'Logistic',
|
||||
'LogNormal',
|
||||
'Normal',
|
||||
'Poisson',
|
||||
'Uniform',
|
||||
]
|
||||
|
|
|
@ -0,0 +1,255 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Poisson Distribution"""
|
||||
import numpy as np
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore.ops import composite as C
|
||||
import mindspore.nn as nn
|
||||
from mindspore._checkparam import Validator
|
||||
from mindspore.common import dtype as mstype
|
||||
from .distribution import Distribution
|
||||
from ._utils.utils import check_greater_zero
|
||||
from ._utils.custom_ops import exp_generic, log_generic
|
||||
|
||||
|
||||
class Poisson(Distribution):
|
||||
"""
|
||||
Poisson Distribution.
|
||||
|
||||
Args:
|
||||
rate (float, list, numpy.ndarray, Tensor, Parameter): The rate of the Poisson distribution..
|
||||
seed (int): The seed used in sampling. The global seed is used if it is None. Default: None.
|
||||
dtype (mindspore.dtype): The type of the event samples. Default: mstype.float32.
|
||||
name (str): The name of the distribution. Default: 'Poisson'.
|
||||
|
||||
Note:
|
||||
`rate` must be strictly greater than 0.
|
||||
`dist_spec_args` is `rate`.
|
||||
|
||||
Examples:
|
||||
>>> # To initialize an Poisson distribution of the rate 0.5.
|
||||
>>> import mindspore.nn.probability.distribution as msd
|
||||
>>> p = msd.Poisson(0.5, dtype=mstype.float32)
|
||||
>>>
|
||||
>>> # The following creates two independent Poisson distributions.
|
||||
>>> p = msd.Poisson([0.5, 0.5], dtype=mstype.float32)
|
||||
>>>
|
||||
>>> # An Poisson distribution can be initilized without arguments.
|
||||
>>> # In this case, `rate` must be passed in through `args` during function calls.
|
||||
>>> p = msd.Poisson(dtype=mstype.float32)
|
||||
>>>
|
||||
>>> # To use an Poisson distribution in a network.
|
||||
>>> class net(Cell):
|
||||
... def __init__(self):
|
||||
... super(net, self).__init__():
|
||||
... self.p1 = msd.Poisson(0.5, dtype=mstype.float32)
|
||||
... self.p2 = msd.Poisson(dtype=mstype.float32)
|
||||
...
|
||||
... # All the following calls in construct are valid.
|
||||
... def construct(self, value, rate_b, rate_a):
|
||||
...
|
||||
... # Private interfaces of probability functions corresponding to public interfaces, including
|
||||
... # `prob`, `log_prob`, `cdf`, `log_cdf`, `survival_function`, and `log_survival`, are the same as follows.
|
||||
... # Args:
|
||||
... # value (Tensor): the value to be evaluated.
|
||||
... # rate (Tensor): the rate of the distribution. Default: self.rate.
|
||||
...
|
||||
... # Examples of `prob`.
|
||||
... # Similar calls can be made to other probability functions
|
||||
... # by replacing `prob` by the name of the function.
|
||||
... ans = self.p1.prob(value)
|
||||
... # Evaluate with respect to distribution b.
|
||||
... ans = self.p1.prob(value, rate_b)
|
||||
... # `rate` must be passed in during function calls.
|
||||
... ans = self.p2.prob(value, rate_a)
|
||||
...
|
||||
...
|
||||
... # Functions `mean`, `sd`, and 'var' have the same arguments as follows.
|
||||
... # Args:
|
||||
... # rate (Tensor): the rate of the distribution. Default: self.rate.
|
||||
...
|
||||
... # Examples of `mean`. `sd`, `var`, and `entropy` are similar.
|
||||
... ans = self.p1.mean() # return 2
|
||||
... ans = self.p1.mean(rate_b) # return 1 / rate_b
|
||||
... # `rate` must be passed in during function calls.
|
||||
... ans = self.p2.mean(rate_a)
|
||||
...
|
||||
...
|
||||
... # Examples of `sample`.
|
||||
... # Args:
|
||||
... # shape (tuple): the shape of the sample. Default: ()
|
||||
... # probs1 (Tensor): the rate of the distribution. Default: self.rate.
|
||||
... ans = self.p1.sample()
|
||||
... ans = self.p1.sample((2,3))
|
||||
... ans = self.p1.sample((2,3), rate_b)
|
||||
... ans = self.p2.sample((2,3), rate_a)
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
rate=None,
|
||||
seed=None,
|
||||
dtype=mstype.float32,
|
||||
name="Poisson"):
|
||||
"""
|
||||
Constructor of Poisson.
|
||||
"""
|
||||
param = dict(locals())
|
||||
param['param_dict'] = {'rate': rate}
|
||||
valid_dtype = mstype.int_type + mstype.uint_type + mstype.float_type
|
||||
Validator.check_type_name("dtype", dtype, valid_dtype, type(self).__name__)
|
||||
super(Poisson, self).__init__(seed, dtype, name, param)
|
||||
|
||||
self._rate = self._add_parameter(rate, 'rate')
|
||||
if self.rate is not None:
|
||||
check_greater_zero(self.rate, 'rate')
|
||||
|
||||
# ops needed for the class
|
||||
self.exp = exp_generic
|
||||
self.log = log_generic
|
||||
self.squeeze = P.Squeeze(0)
|
||||
self.cast = P.Cast()
|
||||
self.floor = P.Floor()
|
||||
self.dtypeop = P.DType()
|
||||
self.shape = P.Shape()
|
||||
self.fill = P.Fill()
|
||||
self.less = P.Less()
|
||||
self.equal = P.Equal()
|
||||
self.select = P.Select()
|
||||
self.lgamma = nn.LGamma()
|
||||
self.igamma = nn.IGamma()
|
||||
self.poisson = C.poisson
|
||||
|
||||
def extend_repr(self):
|
||||
if self.is_scalar_batch:
|
||||
s = f'rate = {self.rate}'
|
||||
else:
|
||||
s = f'batch_shape = {self._broadcast_shape}'
|
||||
return s
|
||||
|
||||
@property
|
||||
def rate(self):
|
||||
"""
|
||||
Return `rate` of the distribution.
|
||||
"""
|
||||
return self._rate
|
||||
|
||||
def _get_dist_type(self):
|
||||
return "Poisson"
|
||||
|
||||
def _get_dist_args(self, rate=None):
|
||||
if rate is not None:
|
||||
self.checktensor(rate, 'rate')
|
||||
else:
|
||||
rate = self.rate
|
||||
return (rate,)
|
||||
|
||||
def _mean(self, rate=None):
|
||||
r"""
|
||||
.. math::
|
||||
MEAN(POISSON) = \lambda.
|
||||
"""
|
||||
rate = self._check_param_type(rate)
|
||||
return rate
|
||||
|
||||
def _mode(self, rate=None):
|
||||
r"""
|
||||
.. math::
|
||||
MODE(POISSON) = \lfloor{\lambda}.
|
||||
"""
|
||||
rate = self._check_param_type(rate)
|
||||
return self.floor(rate)
|
||||
|
||||
def _var(self, rate=None):
|
||||
r"""
|
||||
.. math::
|
||||
VAR(POISSON) = \lambda.
|
||||
"""
|
||||
rate = self._check_param_type(rate)
|
||||
return rate
|
||||
|
||||
def _log_prob(self, value, rate=None):
|
||||
r"""
|
||||
Log probability density function of Poisson distributions.
|
||||
|
||||
Args:
|
||||
Args:
|
||||
value (Tensor): The value to be evaluated.
|
||||
rate (Tensor): The rate of the distribution. Default: self.rate.
|
||||
|
||||
Note:
|
||||
`value` must be greater or equal to zero.
|
||||
|
||||
.. math::
|
||||
log_pdf(x) = x * \log(\lambda) - \lambda - \log(\Gamma(x)) if x >= 0 else -inf
|
||||
"""
|
||||
value = self._check_value(value, "value")
|
||||
value = self.cast(value, self.dtype)
|
||||
rate = self._check_param_type(rate)
|
||||
log_rate = self.log(rate)
|
||||
zeros = self.fill(self.dtypeop(value), self.shape(value), 0.0)
|
||||
inf = self.fill(self.dtypeop(value), self.shape(value), np.inf)
|
||||
safe_x = self.select(self.less(value, zeros), zeros, value)
|
||||
y = log_rate * safe_x - self.lgamma(safe_x + 1.)
|
||||
comp = self.equal(value, safe_x)
|
||||
log_unnormalized_prob = self.select(comp, y, -inf)
|
||||
log_normalization = self.exp(log_rate)
|
||||
return log_unnormalized_prob - log_normalization
|
||||
|
||||
def _cdf(self, value, rate=None):
|
||||
r"""
|
||||
Cumulative distribution function (cdf) of Poisson distributions.
|
||||
|
||||
Args:
|
||||
value (Tensor): The value to be evaluated.
|
||||
rate (Tensor): The rate of the distribution. Default: self.rate.
|
||||
|
||||
Note:
|
||||
`value` must be greater or equal to zero.
|
||||
|
||||
.. math::
|
||||
cdf(x) = \Gamma(x + 1) if x >= 0 else 0
|
||||
"""
|
||||
value = self._check_value(value, 'value')
|
||||
value = self.cast(value, self.dtype)
|
||||
rate = self._check_param_type(rate)
|
||||
zeros = self.fill(self.dtypeop(value), self.shape(value), 0.0)
|
||||
comp = self.less(value, zeros)
|
||||
safe_x = self.select(comp, zeros, value)
|
||||
cdf = 1. - self.igamma(1. + safe_x, rate)
|
||||
return self.select(comp, zeros, cdf)
|
||||
|
||||
def _sample(self, shape=(), rate=None):
|
||||
"""
|
||||
Sampling.
|
||||
|
||||
Args:
|
||||
shape (tuple): The shape of the sample. Default: ().
|
||||
rate (Tensor): The rate of the distribution. Default: self.rate.
|
||||
|
||||
Returns:
|
||||
Tensor, shape is shape + batch_shape.
|
||||
"""
|
||||
shape = self.checktuple(shape, 'shape')
|
||||
rate = self._check_param_type(rate)
|
||||
origin_shape = shape + self.shape(rate)
|
||||
if origin_shape == ():
|
||||
sample_shape = (1,)
|
||||
else:
|
||||
sample_shape = origin_shape
|
||||
sample_poisson = self.poisson(sample_shape, rate, self.seed)
|
||||
value = self.cast(sample_poisson, self.dtype)
|
||||
if origin_shape == ():
|
||||
value = self.squeeze(value)
|
||||
return value
|
|
@ -0,0 +1,210 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""test cases for Poisson distribution"""
|
||||
import numpy as np
|
||||
from scipy import stats
|
||||
import mindspore.context as context
|
||||
import mindspore.nn as nn
|
||||
import mindspore.nn.probability.distribution as msd
|
||||
from mindspore import Tensor
|
||||
from mindspore import dtype
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
|
||||
|
||||
class Prob(nn.Cell):
|
||||
"""
|
||||
Test class: probability of Poisson distribution.
|
||||
"""
|
||||
def __init__(self):
|
||||
super(Prob, self).__init__()
|
||||
self.p = msd.Poisson([0.5], dtype=dtype.float32)
|
||||
|
||||
def construct(self, x_):
|
||||
return self.p.prob(x_)
|
||||
|
||||
def test_pdf():
|
||||
"""
|
||||
Test pdf.
|
||||
"""
|
||||
poisson_benchmark = stats.poisson(mu=0.5)
|
||||
expect_pdf = poisson_benchmark.pmf([-1.0, 0.0, 1.0]).astype(np.float32)
|
||||
pdf = Prob()
|
||||
x_ = Tensor(np.array([-1.0, 0.0, 1.0]).astype(np.float32), dtype=dtype.float32)
|
||||
output = pdf(x_)
|
||||
tol = 1e-6
|
||||
assert (np.abs(output.asnumpy() - expect_pdf) < tol).all()
|
||||
|
||||
class LogProb(nn.Cell):
|
||||
"""
|
||||
Test class: log probability of Poisson distribution.
|
||||
"""
|
||||
def __init__(self):
|
||||
super(LogProb, self).__init__()
|
||||
self.p = msd.Poisson(0.5, dtype=dtype.float32)
|
||||
|
||||
def construct(self, x_):
|
||||
return self.p.log_prob(x_)
|
||||
|
||||
def test_log_likelihood():
|
||||
"""
|
||||
Test log_pdf.
|
||||
"""
|
||||
poisson_benchmark = stats.poisson(mu=0.5)
|
||||
expect_logpdf = poisson_benchmark.logpmf([1.0, 2.0]).astype(np.float32)
|
||||
logprob = LogProb()
|
||||
x_ = Tensor(np.array([1.0, 2.0]).astype(np.float32), dtype=dtype.float32)
|
||||
output = logprob(x_)
|
||||
tol = 1e-6
|
||||
assert (np.abs(output.asnumpy() - expect_logpdf) < tol).all()
|
||||
|
||||
class Basics(nn.Cell):
|
||||
"""
|
||||
Test class: mean/sd/mode of Poisson distribution.
|
||||
"""
|
||||
def __init__(self):
|
||||
super(Basics, self).__init__()
|
||||
self.p = msd.Poisson([1.44], dtype=dtype.float32)
|
||||
|
||||
def construct(self):
|
||||
return self.p.mean(), self.p.sd(), self.p.mode()
|
||||
|
||||
def test_basics():
|
||||
"""
|
||||
Test mean/standard/mode deviation.
|
||||
"""
|
||||
basics = Basics()
|
||||
mean, sd, mode = basics()
|
||||
expect_mean = 1.44
|
||||
expect_sd = 1.2
|
||||
expect_mode = 1
|
||||
tol = 1e-6
|
||||
assert (np.abs(mean.asnumpy() - expect_mean) < tol).all()
|
||||
assert (np.abs(sd.asnumpy() - expect_sd) < tol).all()
|
||||
assert (np.abs(mode.asnumpy() - expect_mode) < tol).all()
|
||||
|
||||
class Sampling(nn.Cell):
|
||||
"""
|
||||
Test class: sample of Poisson distribution.
|
||||
"""
|
||||
def __init__(self, shape, seed=0):
|
||||
super(Sampling, self).__init__()
|
||||
self.p = msd.Poisson([[1.0], [0.5]], seed=seed, dtype=dtype.float32)
|
||||
self.shape = shape
|
||||
|
||||
def construct(self, rate=None):
|
||||
return self.p.sample(self.shape, rate)
|
||||
|
||||
def test_sample():
|
||||
"""
|
||||
Test sample.
|
||||
"""
|
||||
shape = (2, 3)
|
||||
seed = 10
|
||||
rate = Tensor([1.0, 2.0, 3.0], dtype=dtype.float32)
|
||||
sample = Sampling(shape, seed=seed)
|
||||
output = sample(rate)
|
||||
assert output.shape == (2, 3, 3)
|
||||
|
||||
class CDF(nn.Cell):
|
||||
"""
|
||||
Test class: cdf of Poisson distribution.
|
||||
"""
|
||||
def __init__(self):
|
||||
super(CDF, self).__init__()
|
||||
self.p = msd.Poisson([0.5], dtype=dtype.float32)
|
||||
|
||||
def construct(self, x_):
|
||||
return self.p.cdf(x_)
|
||||
|
||||
def test_cdf():
|
||||
"""
|
||||
Test cdf.
|
||||
"""
|
||||
poisson_benchmark = stats.poisson(mu=0.5)
|
||||
expect_cdf = poisson_benchmark.cdf([-1.0, 0.0, 1.0]).astype(np.float32)
|
||||
cdf = CDF()
|
||||
x_ = Tensor(np.array([-1.0, 0.0, 1.0]).astype(np.float32), dtype=dtype.float32)
|
||||
output = cdf(x_)
|
||||
tol = 1e-6
|
||||
assert (np.abs(output.asnumpy() - expect_cdf) < tol).all()
|
||||
|
||||
class LogCDF(nn.Cell):
|
||||
"""
|
||||
Test class: log_cdf of Poisson distribution.
|
||||
"""
|
||||
def __init__(self):
|
||||
super(LogCDF, self).__init__()
|
||||
self.p = msd.Poisson([0.5], dtype=dtype.float32)
|
||||
|
||||
def construct(self, x_):
|
||||
return self.p.log_cdf(x_)
|
||||
|
||||
def test_log_cdf():
|
||||
"""
|
||||
Test log_cdf.
|
||||
"""
|
||||
poisson_benchmark = stats.poisson(mu=0.5)
|
||||
expect_logcdf = poisson_benchmark.logcdf([0.5, 1.0, 2.5]).astype(np.float32)
|
||||
logcdf = LogCDF()
|
||||
x_ = Tensor(np.array([0.5, 1.0, 2.5]).astype(np.float32), dtype=dtype.float32)
|
||||
output = logcdf(x_)
|
||||
tol = 1e-6
|
||||
assert (np.abs(output.asnumpy() - expect_logcdf) < tol).all()
|
||||
|
||||
class SF(nn.Cell):
|
||||
"""
|
||||
Test class: survival function of Poisson distribution.
|
||||
"""
|
||||
def __init__(self):
|
||||
super(SF, self).__init__()
|
||||
self.p = msd.Poisson(0.5, dtype=dtype.float32)
|
||||
|
||||
def construct(self, x_):
|
||||
return self.p.survival_function(x_)
|
||||
|
||||
def test_survival():
|
||||
"""
|
||||
Test survival function.
|
||||
"""
|
||||
poisson_benchmark = stats.poisson(mu=0.5)
|
||||
expect_survival = poisson_benchmark.sf([-1.0, 0.0, 1.0]).astype(np.float32)
|
||||
survival = SF()
|
||||
x_ = Tensor(np.array([-1.0, 0.0, 1.0]).astype(np.float32), dtype=dtype.float32)
|
||||
output = survival(x_)
|
||||
tol = 1e-6
|
||||
assert (np.abs(output.asnumpy() - expect_survival) < tol).all()
|
||||
|
||||
class LogSF(nn.Cell):
|
||||
"""
|
||||
Test class: log survival function of Poisson distribution.
|
||||
"""
|
||||
def __init__(self):
|
||||
super(LogSF, self).__init__()
|
||||
self.p = msd.Poisson(0.5, dtype=dtype.float32)
|
||||
|
||||
def construct(self, x_):
|
||||
return self.p.log_survival(x_)
|
||||
|
||||
def test_log_survival():
|
||||
"""
|
||||
Test log survival function.
|
||||
"""
|
||||
poisson_benchmark = stats.poisson(mu=0.5)
|
||||
expect_logsurvival = poisson_benchmark.logsf([-1.0, 0.0, 1.0]).astype(np.float32)
|
||||
logsurvival = LogSF()
|
||||
x_ = Tensor(np.array([-1.0, 0.0, 1.0]).astype(np.float32), dtype=dtype.float32)
|
||||
output = logsurvival(x_)
|
||||
tol = 1e-6
|
||||
assert (np.abs(output.asnumpy() - expect_logsurvival) < tol).all()
|
|
@ -0,0 +1,154 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""
|
||||
Test nn.probability.distribution.Poisson.
|
||||
"""
|
||||
import pytest
|
||||
|
||||
import mindspore.nn as nn
|
||||
import mindspore.nn.probability.distribution as msd
|
||||
from mindspore import dtype
|
||||
from mindspore import Tensor
|
||||
|
||||
|
||||
def test_arguments():
|
||||
"""
|
||||
Args passing during initialization.
|
||||
"""
|
||||
p = msd.Poisson()
|
||||
assert isinstance(p, msd.Distribution)
|
||||
p = msd.Poisson([0.1, 0.3, 0.5, 1.0], dtype=dtype.float32)
|
||||
assert isinstance(p, msd.Distribution)
|
||||
|
||||
def test_type():
|
||||
with pytest.raises(TypeError):
|
||||
msd.Poisson([0.1], dtype=dtype.bool_)
|
||||
|
||||
def test_name():
|
||||
with pytest.raises(TypeError):
|
||||
msd.Poisson([0.1], name=1.0)
|
||||
|
||||
def test_seed():
|
||||
with pytest.raises(TypeError):
|
||||
msd.Poisson([0.1], seed='seed')
|
||||
|
||||
def test_rate():
|
||||
"""
|
||||
Invalid rate.
|
||||
"""
|
||||
with pytest.raises(ValueError):
|
||||
msd.Poisson([-0.1], dtype=dtype.float32)
|
||||
with pytest.raises(ValueError):
|
||||
msd.Poisson([0.0], dtype=dtype.float32)
|
||||
|
||||
class PoissonProb(nn.Cell):
|
||||
"""
|
||||
Poisson distribution: initialize with rate.
|
||||
"""
|
||||
def __init__(self):
|
||||
super(PoissonProb, self).__init__()
|
||||
self.p = msd.Poisson([0.5, 0.5, 0.5, 0.5, 0.5], dtype=dtype.float32)
|
||||
|
||||
def construct(self, value):
|
||||
prob = self.p.prob(value)
|
||||
log_prob = self.p.log_prob(value)
|
||||
cdf = self.p.cdf(value)
|
||||
log_cdf = self.p.log_cdf(value)
|
||||
sf = self.p.survival_function(value)
|
||||
log_sf = self.p.log_survival(value)
|
||||
return prob + log_prob + cdf + log_cdf + sf + log_sf
|
||||
|
||||
def test_poisson_prob():
|
||||
"""
|
||||
Test probability functions: passing value through construct.
|
||||
"""
|
||||
net = PoissonProb()
|
||||
value = Tensor([0.2, 0.3, 5.0, 2, 3.9], dtype=dtype.float32)
|
||||
ans = net(value)
|
||||
assert isinstance(ans, Tensor)
|
||||
|
||||
class PoissonProb1(nn.Cell):
|
||||
"""
|
||||
Poisson distribution: initialize without rate.
|
||||
"""
|
||||
def __init__(self):
|
||||
super(PoissonProb1, self).__init__()
|
||||
self.p = msd.Poisson(dtype=dtype.float32)
|
||||
|
||||
def construct(self, value, rate):
|
||||
prob = self.p.prob(value, rate)
|
||||
log_prob = self.p.log_prob(value, rate)
|
||||
cdf = self.p.cdf(value, rate)
|
||||
log_cdf = self.p.log_cdf(value, rate)
|
||||
sf = self.p.survival_function(value, rate)
|
||||
log_sf = self.p.log_survival(value, rate)
|
||||
return prob + log_prob + cdf + log_cdf + sf + log_sf
|
||||
|
||||
def test_poisson_prob1():
|
||||
"""
|
||||
Test probability functions: passing value/rate through construct.
|
||||
"""
|
||||
net = PoissonProb1()
|
||||
value = Tensor([0.2, 0.9, 1, 2, 3], dtype=dtype.float32)
|
||||
rate = Tensor([0.5, 0.5, 0.5, 0.5, 0.5], dtype=dtype.float32)
|
||||
ans = net(value, rate)
|
||||
assert isinstance(ans, Tensor)
|
||||
|
||||
class PoissonBasics(nn.Cell):
|
||||
"""
|
||||
Test class: basic mean/sd/var/mode function.
|
||||
"""
|
||||
def __init__(self):
|
||||
super(PoissonBasics, self).__init__()
|
||||
self.p = msd.Poisson([2.3, 2.5], dtype=dtype.float32)
|
||||
|
||||
def construct(self):
|
||||
mean = self.p.mean()
|
||||
sd = self.p.sd()
|
||||
var = self.p.var()
|
||||
return mean + sd + var
|
||||
|
||||
def test_bascis():
|
||||
"""
|
||||
Test mean/sd/var/mode functionality of Poisson distribution.
|
||||
"""
|
||||
net = PoissonBasics()
|
||||
ans = net()
|
||||
assert isinstance(ans, Tensor)
|
||||
|
||||
class PoissonConstruct(nn.Cell):
|
||||
"""
|
||||
Poisson distribution: going through construct.
|
||||
"""
|
||||
def __init__(self):
|
||||
super(PoissonConstruct, self).__init__()
|
||||
self.p = msd.Poisson([0.5, 0.5, 0.5, 0.5, 0.5], dtype=dtype.float32)
|
||||
self.p1 = msd.Poisson(dtype=dtype.float32)
|
||||
|
||||
def construct(self, value, rate):
|
||||
prob = self.p('prob', value)
|
||||
prob1 = self.p('prob', value, rate)
|
||||
prob2 = self.p1('prob', value, rate)
|
||||
return prob + prob1 + prob2
|
||||
|
||||
def test_poisson_construct():
|
||||
"""
|
||||
Test probability function going through construct.
|
||||
"""
|
||||
net = PoissonConstruct()
|
||||
value = Tensor([0, 0, 0, 0, 0], dtype=dtype.float32)
|
||||
probs = Tensor([0.5, 0.5, 0.5, 0.5, 0.5], dtype=dtype.float32)
|
||||
ans = net(value, probs)
|
||||
assert isinstance(ans, Tensor)
|
Loading…
Reference in New Issue