add 3 distributions
This commit is contained in:
parent
66270a8f84
commit
41e9f05a0b
|
@ -32,9 +32,12 @@ Distributions
|
|||
mindspore.nn.probability.distribution.Gamma
|
||||
mindspore.nn.probability.distribution.Geometric
|
||||
mindspore.nn.probability.distribution.Gumbel
|
||||
mindspore.nn.probability.distribution.HalfNormal
|
||||
mindspore.nn.probability.distribution.Laplace
|
||||
mindspore.nn.probability.distribution.Logistic
|
||||
mindspore.nn.probability.distribution.LogNormal
|
||||
mindspore.nn.probability.distribution.Normal
|
||||
mindspore.nn.probability.distribution.Poisson
|
||||
mindspore.nn.probability.distribution.StudentT
|
||||
mindspore.nn.probability.distribution.TransformedDistribution
|
||||
mindspore.nn.probability.distribution.Uniform
|
||||
|
|
|
@ -0,0 +1,39 @@
|
|||
mindspore.nn.probability.distribution.HalfNormal
|
||||
================================================
|
||||
|
||||
.. py:class:: mindspore.nn.probability.distribution.HalfNormal(mean=None, sd=None, seed=None, dtype=mstype.float32, name='HalfNormal')
|
||||
|
||||
半正态分布(HalfNormal distribution)。
|
||||
连续随机分布,取值范围为 :math:`[\mu, \inf)` ,概率密度函数为
|
||||
|
||||
.. math::
|
||||
f(x; \mu, \sigma) = 1 / \sigma\sqrt{2\pi} \exp(-(x - \mu)^2 / 2\sigma^2).
|
||||
|
||||
其中 :math:`\mu, \sigma` 为分别为半正态分布的期望与标准差。
|
||||
|
||||
参数:
|
||||
- **mean** (int, float, list, numpy.ndarray, Tensor) - 半正态分布的平均值。默认值:None。
|
||||
- **sd** (int, float, list, numpy.ndarray, Tensor) - 半正态分布的标准差。默认值:None。
|
||||
- **seed** (int) - 采样时使用的种子。如果为None,则使用全局种子。默认值:None。
|
||||
- **dtype** (mindspore.dtype) - 事件样例的类型。默认值:mstype.float32。
|
||||
- **name** (str) - 分布的名称。默认值:'HalfNormal'。
|
||||
|
||||
.. note::
|
||||
- `sd` 必须大于0。
|
||||
- `dtype` 必须是float,因为半正态分布是连续的。
|
||||
|
||||
异常:
|
||||
- **ValueError** - `sd` 中元素不大于0。
|
||||
- **TypeError** - `dtype` 不是float的子类。
|
||||
|
||||
.. py:method:: log_prob(value, mean, sd)
|
||||
|
||||
计算给定值对应的概率的对数。
|
||||
|
||||
参数:
|
||||
- **value** (Tensor) - 要计算的值。
|
||||
- **mean** (Tensor) - 分布的期望。默认值:None。
|
||||
- **sd** (Tensor) - 分布的标准差。默认值:None。
|
||||
|
||||
返回:
|
||||
Tensor,概率密度函数的对数。
|
|
@ -0,0 +1,39 @@
|
|||
mindspore.nn.probability.distribution.Laplace
|
||||
================================================
|
||||
|
||||
.. py:class:: mindspore.nn.probability.distribution.Laplace(mean=None, sd=None, seed=None, dtype=mstype.float32, name='Laplace')
|
||||
|
||||
拉普拉斯分布(Laplace distribution)。
|
||||
连续随机分布,取值范围为 :math:`(-\inf, \inf)` ,概率密度函数为
|
||||
|
||||
.. math::
|
||||
f(x, \mu, b) = 1 / (2. * b) * \exp(-abs(x - \mu) / b).
|
||||
|
||||
其中 :math:`\mu, b` 为分别为拉普拉斯分布的期望与扩散度。
|
||||
|
||||
参数:
|
||||
- **mean** (int, float, list, numpy.ndarray, Tensor) - 拉普拉斯分布的平均值。默认值:None。
|
||||
- **sd** (int, float, list, numpy.ndarray, Tensor) - 拉普拉斯分布的扩散度。默认值:None。
|
||||
- **seed** (int) - 采样时使用的种子。如果为None,则使用全局种子。默认值:None。
|
||||
- **dtype** (mindspore.dtype) - 事件样例的类型。默认值:mstype.float32。
|
||||
- **name** (str) - 分布的名称。默认值:'Laplace'。
|
||||
|
||||
.. note::
|
||||
- `sd` 必须大于0。
|
||||
- `dtype` 必须是float,因为拉普拉斯分布是连续的。
|
||||
|
||||
异常:
|
||||
- **ValueError** - `sd` 中元素不大于0。
|
||||
- **TypeError** - `dtype` 不是float的子类。
|
||||
|
||||
.. py:method:: log_prob(value, mean, sd)
|
||||
|
||||
计算给定值对应的概率的对数。
|
||||
|
||||
参数:
|
||||
- **value** (Tensor) - 要计算的值。
|
||||
- **mean** (Tensor) - 分布的期望。默认值:None。
|
||||
- **sd** (Tensor) - 分布的扩散度。默认值:None。
|
||||
|
||||
返回:
|
||||
Tensor,概率密度函数的对数。
|
|
@ -0,0 +1,44 @@
|
|||
mindspore.nn.probability.distribution.StudentT
|
||||
================================================
|
||||
|
||||
.. py:class:: mindspore.nn.probability.distribution.StudentT(df=None, mean=None, sd=None, seed=None, dtype=mstype.float32, name='StudentT')
|
||||
|
||||
StudentT分布(StudentT distribution)。
|
||||
连续随机分布,取值范围为 :math:`(-\inf, \inf)` ,概率密度函数为
|
||||
|
||||
.. math::
|
||||
f(x, \nu, \mu, \sigma) = (1 + y^2 / \nu)^(-0.5*(\nu + 1)) / Z
|
||||
|
||||
|
||||
其中 :math:`y = (x - \mu) / \sigma`, :math:`Z = abs(\sigma) * \sqrt(\nu * \pi) * \Gamma(0.5 * \nu) / \Gamma(0.5 * (\nu + 1))`, :math:`\nu, \mu, \sigma` 为分别为StudentT分布的自由度,期望与标准差。
|
||||
|
||||
参数:
|
||||
- **df** (int, float, list, numpy.ndarray, Tensor) - StudentT分布的自由度。默认值:None。
|
||||
- **mean** (int, float, list, numpy.ndarray, Tensor) - StudentT分布的平均值。默认值:None。
|
||||
- **sd** (int, float, list, numpy.ndarray, Tensor) - StudentT分布的扩散度。默认值:None。
|
||||
- **seed** (int) - 采样时使用的种子。如果为None,则使用全局种子。默认值:None。
|
||||
- **dtype** (mindspore.dtype) - 事件样例的类型。默认值:mstype.float32。
|
||||
- **name** (str) - 分布的名称。默认值:'StudentT'。
|
||||
|
||||
.. note::
|
||||
- `df` 必须大于0。
|
||||
- `sd` 必须大于0。
|
||||
- `dtype` 必须是float,因为StudentT分布是连续的。
|
||||
|
||||
异常:
|
||||
- **ValueError** - `df` 中元素不大于0。
|
||||
- **ValueError** - `sd` 中元素不大于0。
|
||||
- **TypeError** - `dtype` 不是float的子类。
|
||||
|
||||
.. py:method:: log_prob(value, df, mean, sd)
|
||||
|
||||
计算给定值对应的概率的对数。
|
||||
|
||||
参数:
|
||||
- **value** (Tensor) - 要计算的值。
|
||||
- **df** (Tensor) - 分布的自由度。默认值:None。
|
||||
- **mean** (Tensor) - 分布的期望。默认值:None。
|
||||
- **sd** (Tensor) - 分布的扩散度。默认值:None。
|
||||
|
||||
返回:
|
||||
Tensor,概率密度函数的对数。
|
|
@ -34,9 +34,12 @@ Distributions
|
|||
mindspore.nn.probability.distribution.Gamma
|
||||
mindspore.nn.probability.distribution.Geometric
|
||||
mindspore.nn.probability.distribution.Gumbel
|
||||
mindspore.nn.probability.distribution.HalfNormal
|
||||
mindspore.nn.probability.distribution.Laplace
|
||||
mindspore.nn.probability.distribution.Logistic
|
||||
mindspore.nn.probability.distribution.LogNormal
|
||||
mindspore.nn.probability.distribution.Normal
|
||||
mindspore.nn.probability.distribution.Poisson
|
||||
mindspore.nn.probability.distribution.StudentT
|
||||
mindspore.nn.probability.distribution.TransformedDistribution
|
||||
mindspore.nn.probability.distribution.Uniform
|
||||
|
|
|
@ -0,0 +1,13 @@
|
|||
.. py:method:: log_prob(value, mean, sd)
|
||||
|
||||
the log value of the probability.
|
||||
|
||||
**Parameters**
|
||||
|
||||
- **value** (Tensor) - the value to compute.
|
||||
- **mean** (Tensor) - the mean of the distribution. Default value: None.
|
||||
- **sd** (Tensor) - the standard deviation of the distribution. Default value: None.
|
||||
|
||||
**Returns**
|
||||
|
||||
Tensor, the log value of the probability.
|
|
@ -0,0 +1,13 @@
|
|||
.. py:method:: log_prob(value, mean, sd)
|
||||
|
||||
the log value of the probability.
|
||||
|
||||
**Parameters**
|
||||
|
||||
- **value** (Tensor) - the value to compute.
|
||||
- **mean** (Tensor) - the mean of the distribution. Default value: None.
|
||||
- **sd** (Tensor) - the standard deviation of the distribution. Default value: None.
|
||||
|
||||
**Returns**
|
||||
|
||||
Tensor, the log value of the probability.
|
|
@ -0,0 +1,14 @@
|
|||
.. py:method:: log_prob(value, mean, sd)
|
||||
|
||||
the log value of the probability.
|
||||
|
||||
**Parameters**
|
||||
|
||||
- **value** (Tensor) - the value to compute.
|
||||
- **df** (Tensor) - the degrees of freedom of the distribution. Default value: None.
|
||||
- **mean** (Tensor) - the mean of the distribution. Default value: None.
|
||||
- **sd** (Tensor) - the standard deviation of the distribution. Default value: None.
|
||||
|
||||
**Returns**
|
||||
|
||||
Tensor, the log value of the probability.
|
|
@ -31,6 +31,9 @@ from .log_normal import LogNormal
|
|||
from .normal import Normal
|
||||
from .poisson import Poisson
|
||||
from .uniform import Uniform
|
||||
from .half_normal import HalfNormal
|
||||
from .laplace import Laplace
|
||||
from .student_t import StudentT
|
||||
|
||||
__all__ = ['Distribution',
|
||||
'TransformedDistribution',
|
||||
|
@ -47,4 +50,7 @@ __all__ = ['Distribution',
|
|||
'Normal',
|
||||
'Poisson',
|
||||
'Uniform',
|
||||
'HalfNormal',
|
||||
'Laplace',
|
||||
'StudentT',
|
||||
]
|
||||
|
|
|
@ -113,13 +113,17 @@ class Distribution(Cell):
|
|||
# ops needed for the base class
|
||||
self.cast_base = P.Cast()
|
||||
self.dtype_base = P.DType()
|
||||
self.exp_base = exp_generic
|
||||
self.fill_base = P.Fill()
|
||||
self.log_base = log_generic
|
||||
self.sametypeshape_base = inner.SameTypeShape()
|
||||
self.sq_base = P.Square()
|
||||
self.sqrt_base = P.Sqrt()
|
||||
self.shape_base = P.Shape()
|
||||
if self.device_target != "Ascend":
|
||||
self.log_base = P.Log()
|
||||
self.exp_base = P.Exp()
|
||||
else:
|
||||
self.exp_base = exp_generic
|
||||
self.log_base = log_generic
|
||||
|
||||
@property
|
||||
def name(self):
|
||||
|
|
|
@ -0,0 +1,129 @@
|
|||
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""HalfNormal Distribution"""
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
import numpy as np
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore._checkparam import Validator
|
||||
from mindspore.common import dtype as mstype
|
||||
from mindspore.nn.probability.distribution import Distribution
|
||||
from mindspore.nn.probability.distribution._utils.utils import check_greater_zero
|
||||
|
||||
|
||||
class HalfNormal(Distribution):
|
||||
r"""
|
||||
HalfNormal distribution.
|
||||
A HalfNormal distribution is a continuous distribution with the range :math:`[\mu, \inf)`
|
||||
and the probability density function:
|
||||
|
||||
.. math::
|
||||
f(x, \mu, \sigma) = 1 / \sigma\sqrt{2\pi} \exp(-(x - \mu)^2 / 2\sigma^2).
|
||||
|
||||
where :math:`\mu, \sigma` are the mean and the standard deviation of the half normal distribution respectively.
|
||||
|
||||
Args:
|
||||
mean (int, float, list, numpy.ndarray, Tensor): The mean of the distribution. Default: None.
|
||||
sd (int, float, list, numpy.ndarray, Tensor): The standard deviation of the distribution. Default: None.
|
||||
seed (int): The seed used in sampling. The global seed is used if it is None. Default: None.
|
||||
dtype (mindspore.dtype): The type of the event samples. Default: mstype.float32.
|
||||
name (str): The name of the distribution. Default: 'HalfNormal'.
|
||||
|
||||
Note:
|
||||
- `sd` must be greater than zero.
|
||||
- `dist_spec_args` are `mean` and `sd`.
|
||||
- `dtype` must be a float type because HalfNormal distributions are continuous.
|
||||
|
||||
Raises:
|
||||
ValueError: When sd <= 0.
|
||||
TypeError: When the input `dtype` is not a subclass of float.
|
||||
|
||||
Supported Platforms:
|
||||
``CPU``
|
||||
|
||||
Examples:
|
||||
>>> import mindspore
|
||||
>>> import mindspore.nn as nn
|
||||
>>> from mindspore.nn.probability.distribution import HalfNormal
|
||||
>>> from mindspore import Tensor
|
||||
>>> # To initialize a HalfNormal distribution of the mean 3.0 and the standard deviation 4.0.
|
||||
>>> n1 = HalfNormal(3.0, 4.0, dtype=mindspore.float32)
|
||||
>>> # A HalfNormal distribution can be initialized without arguments.
|
||||
>>> # In this case, `mean` and `sd` must be passed in through arguments.
|
||||
>>> hn = HalfNormal(dtype=mindspore.float32)
|
||||
>>> # Here are some tensors used below for testing
|
||||
>>> value = Tensor([1.0, 2.0, 3.0], dtype=mindspore.float32)
|
||||
>>> mean_a = Tensor([2.0], dtype=mindspore.float32)
|
||||
>>> sd_a = Tensor([2.0, 2.0, 2.0], dtype=mindspore.float32)
|
||||
>>> mean_b = Tensor([1.0], dtype=mindspore.float32)
|
||||
>>> sd_b = Tensor([1.0, 1.5, 2.5], dtype=mindspore.float32)
|
||||
>>> ans = n1.log_prob(value)
|
||||
>>> print(ans.shape)
|
||||
(3,)
|
||||
>>> # Evaluate with respect to the distribution b.
|
||||
>>> ans = n1.log_prob(value, mean_b, sd_b)
|
||||
>>> print(ans.shape)
|
||||
(3,)
|
||||
>>> # `mean` and `sd` must be passed in during function calls
|
||||
>>> ans = hn.log_prob(value, mean_a, sd_a)
|
||||
>>> print(ans.shape)
|
||||
(3,)
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
mean=None,
|
||||
sd=None,
|
||||
seed=None,
|
||||
dtype=mstype.float32,
|
||||
name="HalfNormal"):
|
||||
"""
|
||||
Constructor of HalfNormal.
|
||||
"""
|
||||
param = dict(locals())
|
||||
param['param_dict'] = {'mean': mean, 'sd': sd}
|
||||
valid_dtype = mstype.float_type
|
||||
Validator.check_type_name("dtype", dtype, valid_dtype, type(self).__name__)
|
||||
super(HalfNormal, self).__init__(seed, dtype, name, param)
|
||||
|
||||
self._mean_value = self._add_parameter(mean, 'mean')
|
||||
self._sd_value = self._add_parameter(sd, 'sd')
|
||||
if self._sd_value is not None:
|
||||
check_greater_zero(self._sd_value, "Standard deviation")
|
||||
|
||||
self.exp = P.Exp()
|
||||
self.cast = P.Cast()
|
||||
self.const = np.sqrt(2. / np.pi)
|
||||
self.sq = P.Square()
|
||||
self.type = dtype
|
||||
|
||||
def _prob(self, value, mean=None, sd=None):
|
||||
r"""
|
||||
Evaluate probability.
|
||||
|
||||
Args:
|
||||
value (Tensor): The value to be evaluated.
|
||||
mean (Tensor): The mean of the distribution. Default: self._mean_value.
|
||||
sd (Tensor): The standard deviation the distribution. Default: self._sd_value.
|
||||
|
||||
.. math::
|
||||
P(x) = 1 / \sigma \sqrt{2\pi} \exp(-(x - \mu)^2 / 2\sigma^2)
|
||||
"""
|
||||
value = self._check_value(value, 'value')
|
||||
value = self.cast(value, self.dtype)
|
||||
mean, sd = self._check_param_type(mean, sd)
|
||||
|
||||
coeff = self.const / sd
|
||||
pdf = coeff * self.exp(-0.5 * self.sq((value - mean) / sd))
|
||||
return pdf * self.cast(value >= 0, self.type)
|
|
@ -0,0 +1,125 @@
|
|||
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Laplace Distribution"""
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore._checkparam import Validator
|
||||
from mindspore.common import dtype as mstype
|
||||
from mindspore.nn.probability.distribution import Distribution
|
||||
from mindspore.nn.probability.distribution._utils.utils import check_greater_zero
|
||||
|
||||
|
||||
class Laplace(Distribution):
|
||||
r"""
|
||||
Laplace distribution.
|
||||
A Laplace distribution is a continuous distribution with the range :math:`[-\inf, \inf)`
|
||||
and the probability density function:
|
||||
|
||||
.. math::
|
||||
f(x, \mu, b) = 1 / (2. * b) * \exp(-abs(x - \mu) / b).
|
||||
|
||||
where :math:`\mu, b` are the mean and the scale of the laplace distribution respectively.
|
||||
|
||||
Args:
|
||||
mean (int, float, list, numpy.ndarray, Tensor): The mean of the distribution. Default: None.
|
||||
sd (int, float, list, numpy.ndarray, Tensor): The standard deviation of the distribution. Default: None.
|
||||
seed (int): The seed used in sampling. The global seed is used if it is None. Default: None.
|
||||
dtype (mindspore.dtype): The type of the event samples. Default: mstype.float32.
|
||||
name (str): The name of the distribution. Default: 'Laplace'.
|
||||
|
||||
Note:
|
||||
- `sd` must be greater than zero.
|
||||
- `dist_spec_args` are `mean` and `sd`.
|
||||
- `dtype` must be a float type because Laplace distributions are continuous.
|
||||
|
||||
Raises:
|
||||
ValueError: When sd <= 0.
|
||||
TypeError: When the input `dtype` is not a subclass of float.
|
||||
|
||||
Supported Platforms:
|
||||
``CPU``
|
||||
|
||||
Examples:
|
||||
>>> import mindspore
|
||||
>>> import mindspore.nn as nn
|
||||
>>> from mindspore.nn.probability.distribution import Laplace
|
||||
>>> from mindspore import Tensor
|
||||
>>> # To initialize a Laplace distribution of the mean 3.0 and the standard deviation 4.0.
|
||||
>>> n1 = Laplace(3.0, 4.0, dtype=mindspore.float32)
|
||||
>>> # A Laplace distribution can be initialized without arguments.
|
||||
>>> # In this case, `mean` and `sd` must be passed in through arguments.
|
||||
>>> n2 = Laplace(dtype=mindspore.float32)
|
||||
>>> # Here are some tensors used below for testing
|
||||
>>> value = Tensor([1.0, 2.0, 3.0], dtype=mindspore.float32)
|
||||
>>> mean_a = Tensor([2.0], dtype=mindspore.float32)
|
||||
>>> sd_a = Tensor([2.0, 2.0, 2.0], dtype=mindspore.float32)
|
||||
>>> mean_b = Tensor([1.0], dtype=mindspore.float32)
|
||||
>>> sd_b = Tensor([1.0, 1.5, 2.0], dtype=mindspore.float32)
|
||||
>>> ans = n1.log_prob(value)
|
||||
>>> print(ans.shape)
|
||||
(3,)
|
||||
>>> # Evaluate with respect to the distribution b.
|
||||
>>> ans = n1.log_prob(value, mean_b, sd_b)
|
||||
>>> print(ans.shape)
|
||||
(3,)
|
||||
>>> # `mean` and `sd` must be passed in during function calls
|
||||
>>> ans = n2.log_prob(value, mean_a, sd_a)
|
||||
>>> print(ans.shape)
|
||||
(3,)
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
mean=None,
|
||||
sd=None,
|
||||
seed=None,
|
||||
dtype=mstype.float32,
|
||||
name="Laplace"):
|
||||
"""
|
||||
Constructor of Laplace.
|
||||
"""
|
||||
param = dict(locals())
|
||||
param['param_dict'] = {'mean': mean, 'sd': sd}
|
||||
valid_dtype = mstype.float_type
|
||||
Validator.check_type_name("dtype", dtype, valid_dtype, type(self).__name__)
|
||||
super(Laplace, self).__init__(seed, dtype, name, param)
|
||||
|
||||
self._mean_value = self._add_parameter(mean, 'mean')
|
||||
self._sd_value = self._add_parameter(sd, 'sd')
|
||||
if self._sd_value is not None:
|
||||
check_greater_zero(self._sd_value, "Standard deviation")
|
||||
|
||||
self.log = P.Log()
|
||||
self.cast = P.Cast()
|
||||
self.abs = P.Abs()
|
||||
|
||||
def _log_prob(self, value, mean=None, sd=None):
|
||||
r"""
|
||||
Evaluate log probability.
|
||||
|
||||
Args:
|
||||
value (Tensor): The value to be evaluated.
|
||||
mean (Tensor): The mean of the distribution. Default: self._mean_value.
|
||||
sd (Tensor): The standard deviation the distribution. Default: self._sd_value.
|
||||
|
||||
.. math::
|
||||
L(x) = -1* \abs{\frac{x - \mu}{\sigma}} - \log(2. * \sigma))
|
||||
"""
|
||||
value = self._check_value(value, 'value')
|
||||
value = self.cast(value, self.dtype)
|
||||
mean, sd = self._check_param_type(mean, sd)
|
||||
|
||||
pdf = -1.0 * (self.abs((value - mean) / sd)) - self.log(2. * sd)
|
||||
return pdf
|
|
@ -14,20 +14,19 @@
|
|||
# ============================================================================
|
||||
"""Normal Distribution"""
|
||||
import numpy as np
|
||||
from mindspore import context
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore.ops import composite as C
|
||||
from mindspore._checkparam import Validator
|
||||
from mindspore.common import dtype as mstype
|
||||
from mindspore.common import Tensor
|
||||
from .distribution import Distribution
|
||||
from ._utils.utils import check_greater_zero, check_distribution_name
|
||||
from ._utils.custom_ops import exp_generic, log_generic
|
||||
|
||||
|
||||
class Normal(Distribution):
|
||||
r"""
|
||||
Normal distribution.
|
||||
A Normal distributio is a continuous distribution with the range :math:`(-\inf, \inf)`
|
||||
A Normal distribution is a continuous distribution with the range :math:`(-\inf, \inf)`
|
||||
and the probability density function:
|
||||
|
||||
.. math::
|
||||
|
@ -166,11 +165,9 @@ class Normal(Distribution):
|
|||
check_greater_zero(self._sd_value, "Standard deviation")
|
||||
|
||||
# ops needed for the class
|
||||
self.exp = exp_generic
|
||||
self.exp = self.exp_base
|
||||
self.log = self.log_base
|
||||
self.expm1 = P.Expm1()
|
||||
# when the graph kernel mode is enable
|
||||
# use Log directly as akg will handle the corner cases
|
||||
self.log = P.Log() if context.get_context("enable_graph_kernel") else log_generic
|
||||
self.erf = P.Erf()
|
||||
self.squeeze = P.Squeeze(0)
|
||||
self.cast = P.Cast()
|
||||
|
@ -178,6 +175,7 @@ class Normal(Distribution):
|
|||
self.shape = P.Shape()
|
||||
self.sq = P.Square()
|
||||
self.sqrt = P.Sqrt()
|
||||
self.coff = Tensor(-0.5 * np.log(2. * np.pi), dtype=dtype)
|
||||
|
||||
def extend_repr(self):
|
||||
"""Display instance object as string."""
|
||||
|
@ -262,10 +260,8 @@ class Normal(Distribution):
|
|||
value = self._check_value(value, 'value')
|
||||
value = self.cast(value, self.dtype)
|
||||
mean, sd = self._check_param_type(mean, sd)
|
||||
unnormalized_log_prob = -1. * \
|
||||
(self.sq(value - mean)) / (2. * self.sq(sd))
|
||||
neg_normalization = -1. * \
|
||||
self.log(self.const(2. * np.pi, mstype.float32)) / 2. - self.log(sd)
|
||||
unnormalized_log_prob = -0.5 * (self.sq((value - mean) / sd))
|
||||
neg_normalization = self.coff - self.log(sd)
|
||||
return unnormalized_log_prob + neg_normalization
|
||||
|
||||
def _cdf(self, value, mean=None, sd=None):
|
||||
|
|
|
@ -0,0 +1,145 @@
|
|||
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""StudentT Distribution"""
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
import numpy as np
|
||||
import mindspore.nn as nn
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore._checkparam import Validator
|
||||
from mindspore.common import dtype as mstype
|
||||
from mindspore.nn.probability.distribution import Distribution
|
||||
from mindspore.nn.probability.distribution._utils.utils import check_greater_zero
|
||||
|
||||
|
||||
class StudentT(Distribution):
|
||||
r"""
|
||||
StudentT distribution.
|
||||
A StudentT distribution is a continuous distribution with the range :math:`[-\inf, \inf)`
|
||||
and the probability density function:
|
||||
|
||||
.. math::
|
||||
f(x, \nu, \mu, \sigma) = (1 + y^2 / \nu)^(-0.5*(\nu + 1)) / Z
|
||||
|
||||
where :math:`y = (x-\mu)/\sigma`, :math:`Z = abs(\sigma)*\sqrt(\nu * \pi)*\Gamma(0.5 * \nu)/\Gamma(0.5*(\nu + 1))`,
|
||||
:math:`\nu, \mu, \sigma` are the degrees of freedom , mean and scale of the laplace distribution respectively.
|
||||
|
||||
Args:
|
||||
df (int, float, list, numpy.ndarray, Tensor): The degrees of freedom. Default: None.
|
||||
mean (int, float, list, numpy.ndarray, Tensor): The mean of the distribution. Default: None.
|
||||
sd (int, float, list, numpy.ndarray, Tensor): The standard deviation of the distribution. Default: None.
|
||||
seed (int): The seed used in sampling. The global seed is used if it is None. Default: None.
|
||||
dtype (mindspore.dtype): The type of the event samples. Default: mstype.float32.
|
||||
name (str): The name of the distribution. Default: 'StudentT'.
|
||||
|
||||
Note:
|
||||
- `df` must be greater than zero.
|
||||
- `sd` must be greater than zero.
|
||||
- `dist_spec_args` are `mean` and `sd`.
|
||||
- `dtype` must be a float type because StudentT distributions are continuous.
|
||||
|
||||
Raises:
|
||||
ValueError: When df <= 0.
|
||||
ValueError: When sd <= 0.
|
||||
TypeError: When the input `dtype` is not a subclass of float.
|
||||
|
||||
Supported Platforms:
|
||||
``CPU``
|
||||
|
||||
Examples:
|
||||
>>> import mindspore
|
||||
>>> import mindspore.nn as nn
|
||||
>>> import mindspore.nn.probability.distribution as msd
|
||||
>>> from mindspore import Tensor
|
||||
>>> # To initialize a StudentT distribution of the df 2.0, the mean 3.0 and the standard deviation 4.0.
|
||||
>>> n1 = msd.StudentT(2.0, 3.0, 4.0, dtype=mindspore.float32)
|
||||
>>> # A StudentT distribution can be initialized without arguments.
|
||||
>>> # In this case, `df`, `mean` and `sd` must be passed in through arguments.
|
||||
>>> n2 = msd.StudentT(dtype=mindspore.float32)
|
||||
>>> # Here are some tensors used below for testing
|
||||
>>> value = Tensor([1.0, 2.0, 3.0], dtype=mindspore.float32)
|
||||
>>> df_a = Tensor([2.0], dtype=mindspore.float32)
|
||||
>>> mean_a = Tensor([2.0], dtype=mindspore.float32)
|
||||
>>> sd_a = Tensor([2.0, 2.0, 2.0], dtype=mindspore.float32)
|
||||
>>> df_b = Tensor([1.0], dtype=mindspore.float32)
|
||||
>>> mean_b = Tensor([1.0], dtype=mindspore.float32)
|
||||
>>> sd_b = Tensor([1.0, 1.5, 2.0], dtype=mindspore.float32)
|
||||
>>> ans = n1.log_prob(value)
|
||||
>>> print(ans.shape)
|
||||
(3,)
|
||||
>>> # Evaluate with respect to the distribution b.
|
||||
>>> ans = n1.log_prob(value, df_b, mean_b, sd_b)
|
||||
>>> print(ans.shape)
|
||||
(3,)
|
||||
>>> # `mean` and `sd` must be passed in during function calls
|
||||
>>> ans = n2.log_prob(value, df_a, mean_a, sd_a)
|
||||
>>> print(ans.shape)
|
||||
(3,)
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
df=None,
|
||||
mean=None,
|
||||
sd=None,
|
||||
seed=None,
|
||||
dtype=mstype.float32,
|
||||
name="StudentT"):
|
||||
"""
|
||||
Constructor of StudentT.
|
||||
"""
|
||||
param = dict(locals())
|
||||
param['param_dict'] = {'df': df, 'mean': mean, 'sd': sd}
|
||||
valid_dtype = mstype.float_type
|
||||
Validator.check_type_name("dtype", dtype, valid_dtype, type(self).__name__)
|
||||
super(StudentT, self).__init__(seed, dtype, name, param)
|
||||
|
||||
self._df_value = self._add_parameter(df, 'df')
|
||||
self._mean_value = self._add_parameter(mean, 'mean')
|
||||
self._sd_value = self._add_parameter(sd, 'sd')
|
||||
if self._sd_value is not None:
|
||||
check_greater_zero(self._sd_value, "Standard deviation")
|
||||
if self._df_value is not None:
|
||||
check_greater_zero(self._df_value, "Degrees of freedom")
|
||||
self.log1p = P.Log1p()
|
||||
self.log = P.Log()
|
||||
self.cast = P.Cast()
|
||||
self.abs = P.Abs()
|
||||
self.half = 0.5
|
||||
self.half_log_pi = 0.5 * np.log(np.pi)
|
||||
self.lgamma = nn.LGamma()
|
||||
|
||||
def _log_prob(self, value, df=None, mean=None, sd=None):
|
||||
r"""
|
||||
Evaluate log probability.
|
||||
|
||||
Args:
|
||||
value (Tensor): The value to be evaluated.
|
||||
df (Tensor): The degrees of freedom of the distribution. Default: self._df_value.
|
||||
mean (Tensor): The mean of the distribution. Default: self._mean_value.
|
||||
sd (Tensor): The standard deviation the distribution. Default: self._sd_value.
|
||||
|
||||
.. math::
|
||||
L(x) = -0.5 * (\nu + 1.) * \log((x - \mu) / \sigma + 1.)) + \log(\sqrt(\pi * \mu * \sigma^2))
|
||||
+ log(\Gamma(\nu / 2.)) - log(\Gamma((\nu + 1.) / 2.))
|
||||
"""
|
||||
value = self._check_value(value, 'value')
|
||||
value = self.cast(value, self.dtype)
|
||||
df, mean, sd = self._check_param_type(df, mean, sd)
|
||||
|
||||
y = (value - mean) / sd
|
||||
log_unnormalized_prob = -0.5 * (df + 1.) * self.log1p(y**2. / df)
|
||||
log_normalization = self.log(self.abs(sd)) + 0.5 * self.log(df) + self.half_log_pi + \
|
||||
self.lgamma(self.half * df) - self.lgamma(self.half * (df + 1.))
|
||||
return log_unnormalized_prob - log_normalization
|
|
@ -0,0 +1,83 @@
|
|||
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""test cases for HalfNormal distribution"""
|
||||
import pytest
|
||||
import numpy as np
|
||||
from scipy import stats
|
||||
import mindspore.context as context
|
||||
import mindspore.common.dtype as mstype
|
||||
import mindspore.nn as nn
|
||||
import mindspore.nn.probability.distribution as msd
|
||||
from mindspore import Tensor
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
|
||||
|
||||
|
||||
class LogProb(nn.Cell):
|
||||
"""
|
||||
Test class: log probability of HalfNormal distribution.
|
||||
"""
|
||||
def __init__(self, loc, scale):
|
||||
super(LogProb, self).__init__()
|
||||
self.n = msd.HalfNormal(loc, scale, dtype=mstype.float32)
|
||||
|
||||
def construct(self, x_):
|
||||
return self.n.log_prob(x_)
|
||||
|
||||
|
||||
class LogProb2(nn.Cell):
|
||||
"""
|
||||
Test class: log probability of HalfNormal distribution.
|
||||
"""
|
||||
def __init__(self):
|
||||
super(LogProb2, self).__init__()
|
||||
self.n = msd.HalfNormal(dtype=mstype.float32)
|
||||
|
||||
def construct(self, x_, loc, scale):
|
||||
return self.n.log_prob(x_, loc, scale)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_log_likelihood():
|
||||
"""
|
||||
Feature: HalfNormal distribution
|
||||
Description: test cases for log_prob() of HalfNormal distribution
|
||||
Expectation: the result match to stats
|
||||
"""
|
||||
x = np.array([0.3, 4.0, np.pi, np.e, -2.0], dtype=np.float32)
|
||||
loc = np.array([0.0, 0.0, 0.5, 0.7, 1.0], dtype=np.float32)
|
||||
scale = np.array([1.5, 1.0, 2.0, 3.0, 2.0], dtype=np.float32)
|
||||
|
||||
# stats as benchmark
|
||||
expected = stats.halfnorm.logpdf(x, loc=loc, scale=scale).astype(np.float32)
|
||||
|
||||
log_prob = LogProb(loc, scale)
|
||||
output = log_prob(Tensor(x, dtype=mstype.float32))
|
||||
|
||||
log_prob2 = LogProb2()
|
||||
output2 = log_prob2(Tensor(x, dtype=mstype.float32), Tensor(loc, dtype=mstype.float32),
|
||||
Tensor(scale, dtype=mstype.float32))
|
||||
|
||||
tol = 1e-5
|
||||
|
||||
output = output.asnumpy()
|
||||
assert (output[np.isinf(output)] == expected[np.isinf(expected)]).all()
|
||||
assert (np.abs(output[~np.isinf(output)] - expected[~np.isinf(expected)]) < tol).all()
|
||||
|
||||
output2 = output2.asnumpy()
|
||||
assert (output2[np.isinf(output2)] == expected[np.isinf(expected)]).all()
|
||||
assert (np.abs(output2[~np.isinf(output2)] - expected[~np.isinf(expected)]) < tol).all()
|
|
@ -0,0 +1,83 @@
|
|||
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""test cases for Laplace distribution"""
|
||||
import pytest
|
||||
import numpy as np
|
||||
from scipy import stats
|
||||
import mindspore.context as context
|
||||
import mindspore.common.dtype as mstype
|
||||
import mindspore.nn as nn
|
||||
import mindspore.nn.probability.distribution as msd
|
||||
from mindspore import Tensor
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
|
||||
|
||||
|
||||
class LogProb(nn.Cell):
|
||||
"""
|
||||
Test class: log probability of Laplace distribution.
|
||||
"""
|
||||
def __init__(self, loc, scale):
|
||||
super(LogProb, self).__init__()
|
||||
self.n = msd.Laplace(loc, scale, dtype=mstype.float32)
|
||||
|
||||
def construct(self, x_):
|
||||
return self.n.log_prob(x_)
|
||||
|
||||
|
||||
class LogProb2(nn.Cell):
|
||||
"""
|
||||
Test class: log probability of Laplace distribution.
|
||||
"""
|
||||
def __init__(self):
|
||||
super(LogProb2, self).__init__()
|
||||
self.n = msd.Laplace(dtype=mstype.float32)
|
||||
|
||||
def construct(self, x_, loc, scale):
|
||||
return self.n.log_prob(x_, loc, scale)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_log_likelihood():
|
||||
"""
|
||||
Feature: Laplace distribution
|
||||
Description: test cases for log_prob() of Laplace distribution
|
||||
Expectation: the result match to stats
|
||||
"""
|
||||
x = np.array([0.3, 4.0, np.pi, np.e, -2.0], dtype=np.float32)
|
||||
loc = np.array([0.0, 0.0, 0.5, 0.7, 1.0], dtype=np.float32)
|
||||
scale = np.array([1.5, 1.0, 2.0, 3.0, 2.0], dtype=np.float32)
|
||||
|
||||
# stats as benchmark
|
||||
expected = stats.laplace.logpdf(x, loc=loc, scale=scale).astype(np.float32)
|
||||
|
||||
log_prob = LogProb(loc, scale)
|
||||
output = log_prob(Tensor(x, dtype=mstype.float32))
|
||||
|
||||
log_prob2 = LogProb2()
|
||||
output2 = log_prob2(Tensor(x, dtype=mstype.float32), Tensor(loc, dtype=mstype.float32),
|
||||
Tensor(scale, dtype=mstype.float32))
|
||||
|
||||
tol = 1e-5
|
||||
|
||||
output = output.asnumpy()
|
||||
assert (output[np.isinf(output)] == expected[np.isinf(expected)]).all()
|
||||
assert (np.abs(output[~np.isinf(output)] - expected[~np.isinf(expected)]) < tol).all()
|
||||
|
||||
output2 = output2.asnumpy()
|
||||
assert (output2[np.isinf(output2)] == expected[np.isinf(expected)]).all()
|
||||
assert (np.abs(output2[~np.isinf(output2)] - expected[~np.isinf(expected)]) < tol).all()
|
|
@ -0,0 +1,81 @@
|
|||
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""test cases for Normal distribution"""
|
||||
import pytest
|
||||
import numpy as np
|
||||
from scipy import stats
|
||||
import mindspore.context as context
|
||||
import mindspore.common.dtype as mstype
|
||||
import mindspore.nn as nn
|
||||
import mindspore.nn.probability.distribution as msd
|
||||
from mindspore import Tensor
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
|
||||
|
||||
|
||||
class LogProb(nn.Cell):
|
||||
"""
|
||||
Test class: log probability of Normal distribution.
|
||||
"""
|
||||
def __init__(self, loc, scale):
|
||||
super(LogProb, self).__init__()
|
||||
self.n = msd.Normal(loc, scale, dtype=mstype.float32)
|
||||
|
||||
def construct(self, x_):
|
||||
return self.n.log_prob(x_)
|
||||
|
||||
|
||||
class LogProb2(nn.Cell):
|
||||
"""
|
||||
Test class: log probability of Normal distribution.
|
||||
"""
|
||||
def __init__(self):
|
||||
super(LogProb2, self).__init__()
|
||||
self.n = msd.Normal(dtype=mstype.float32)
|
||||
|
||||
def construct(self, x_, loc, scale):
|
||||
return self.n.log_prob(x_, loc, scale)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_log_likelihood():
|
||||
"""
|
||||
Test log_pdf.
|
||||
"""
|
||||
x = np.array([0.3, 4.0, np.pi, np.e, -2.0], dtype=np.float32)
|
||||
loc = np.array([0.0, 0.0, 0.5, 0.7, 1.0], dtype=np.float32)
|
||||
scale = np.array([1.5, 1.0, 2.0, 3.0, 2.0], dtype=np.float32)
|
||||
|
||||
# stats as benchmark
|
||||
expected = stats.norm.logpdf(x, loc=loc, scale=scale).astype(np.float32)
|
||||
|
||||
log_prob = LogProb(loc, scale)
|
||||
output = log_prob(Tensor(x, dtype=mstype.float32))
|
||||
|
||||
log_prob2 = LogProb2()
|
||||
output2 = log_prob2(Tensor(x, dtype=mstype.float32), Tensor(loc, dtype=mstype.float32),
|
||||
Tensor(scale, dtype=mstype.float32))
|
||||
|
||||
tol = 1e-5
|
||||
|
||||
output = output.asnumpy()
|
||||
assert (output[np.isinf(output)] == expected[np.isinf(expected)]).all()
|
||||
assert (np.abs(output[~np.isinf(output)] - expected[~np.isinf(expected)]) < tol).all()
|
||||
|
||||
output2 = output2.asnumpy()
|
||||
assert (output2[np.isinf(output2)] == expected[np.isinf(expected)]).all()
|
||||
assert (np.abs(output2[~np.isinf(output2)] - expected[~np.isinf(expected)]) < tol).all()
|
|
@ -0,0 +1,84 @@
|
|||
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""test cases for StudentT distribution"""
|
||||
import pytest
|
||||
import numpy as np
|
||||
from scipy import stats
|
||||
import mindspore.context as context
|
||||
import mindspore.common.dtype as mstype
|
||||
import mindspore.nn as nn
|
||||
import mindspore.nn.probability.distribution as msd
|
||||
from mindspore import Tensor
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
|
||||
|
||||
|
||||
class LogProb(nn.Cell):
|
||||
"""
|
||||
Test class: log probability of StudentT distribution.
|
||||
"""
|
||||
def __init__(self, df, loc, scale):
|
||||
super(LogProb, self).__init__()
|
||||
self.n = msd.StudentT(df, loc, scale, dtype=mstype.float32)
|
||||
|
||||
def construct(self, x_):
|
||||
return self.n.log_prob(x_)
|
||||
|
||||
|
||||
class LogProb2(nn.Cell):
|
||||
"""
|
||||
Test class: log probability of StudentT distribution.
|
||||
"""
|
||||
def __init__(self):
|
||||
super(LogProb2, self).__init__()
|
||||
self.n = msd.StudentT(dtype=mstype.float32)
|
||||
|
||||
def construct(self, x_, df, loc, scale):
|
||||
return self.n.log_prob(x_, df, loc, scale)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_log_likelihood():
|
||||
"""
|
||||
Feature: StudentT distribution
|
||||
Description: test cases for log_prob() of StudentT distribution
|
||||
Expectation: the result match to stats
|
||||
"""
|
||||
x = np.array([0.3, 4.0, np.pi, np.e, -2.0], dtype=np.float32)
|
||||
df = np.array([0.1, 0.3, 0.5, 0.7, 1.0], dtype=np.float32)
|
||||
loc = np.array([0.0, 0.0, 0.5, 0.7, 1.0], dtype=np.float32)
|
||||
scale = np.array([1.5, 1.0, 2.0, 3.0, 2.0], dtype=np.float32)
|
||||
|
||||
# stats as benchmark
|
||||
expected = stats.t.logpdf(x, df=df, loc=loc, scale=scale).astype(np.float32)
|
||||
|
||||
log_prob = LogProb(df, loc, scale)
|
||||
output = log_prob(Tensor(x, dtype=mstype.float32))
|
||||
|
||||
log_prob2 = LogProb2()
|
||||
output2 = log_prob2(Tensor(x, dtype=mstype.float32), Tensor(df, dtype=mstype.float32),
|
||||
Tensor(loc, dtype=mstype.float32), Tensor(scale, dtype=mstype.float32))
|
||||
|
||||
tol = 1e-5
|
||||
|
||||
output = output.asnumpy()
|
||||
assert (output[np.isinf(output)] == expected[np.isinf(expected)]).all()
|
||||
assert (np.abs(output[~np.isinf(output)] - expected[~np.isinf(expected)]) < tol).all()
|
||||
|
||||
output2 = output2.asnumpy()
|
||||
assert (output2[np.isinf(output2)] == expected[np.isinf(expected)]).all()
|
||||
assert (np.abs(output2[~np.isinf(output2)] - expected[~np.isinf(expected)]) < tol).all()
|
Loading…
Reference in New Issue