forked from mindspore-Ecosystem/mindspore
update doc example in probability
fix typo in probability module fix pylint in msp
This commit is contained in:
parent
f87b5e0cc8
commit
ca316f4422
|
@ -41,7 +41,7 @@ class Bijector(Cell):
|
||||||
Note:
|
Note:
|
||||||
`dtype` of bijector represents the type of the distributions that the bijector could operate on.
|
`dtype` of bijector represents the type of the distributions that the bijector could operate on.
|
||||||
When `dtype` is None, there is no enforcement on the type of input value except that the input value
|
When `dtype` is None, there is no enforcement on the type of input value except that the input value
|
||||||
has to be float type. During initilization, when `dtype` is None, there is no enforcement on the dtype
|
has to be float type. During initialization, when `dtype` is None, there is no enforcement on the dtype
|
||||||
of the parameters. All parameters should have the same float type, otherwise a TypeError will be raised.
|
of the parameters. All parameters should have the same float type, otherwise a TypeError will be raised.
|
||||||
Specifically, the parameter type will follow the dtype of the input value, i.e. parameters of the bijector
|
Specifically, the parameter type will follow the dtype of the input value, i.e. parameters of the bijector
|
||||||
will be casted into the same type as input value when `dtype`is None.
|
will be casted into the same type as input value when `dtype`is None.
|
||||||
|
@ -65,7 +65,8 @@ class Bijector(Cell):
|
||||||
'is_constant_jacobian', is_constant_jacobian, [bool], name)
|
'is_constant_jacobian', is_constant_jacobian, [bool], name)
|
||||||
validator.check_value_type('is_injective', is_injective, [bool], name)
|
validator.check_value_type('is_injective', is_injective, [bool], name)
|
||||||
if dtype is not None:
|
if dtype is not None:
|
||||||
validator.check_type_name("dtype", dtype, mstype.float_type, type(self).__name__)
|
validator.check_type_name(
|
||||||
|
"dtype", dtype, mstype.float_type, type(self).__name__)
|
||||||
self._name = name
|
self._name = name
|
||||||
self._dtype = dtype
|
self._dtype = dtype
|
||||||
self._parameters = {}
|
self._parameters = {}
|
||||||
|
@ -76,7 +77,7 @@ class Bijector(Cell):
|
||||||
if not(k == 'self' or k.startswith('_')):
|
if not(k == 'self' or k.startswith('_')):
|
||||||
self._parameters[k] = param[k]
|
self._parameters[k] = param[k]
|
||||||
|
|
||||||
# if no bijector is used as an argument during initilization
|
# if no bijector is used as an argument during initialization
|
||||||
if 'bijector' not in param.keys():
|
if 'bijector' not in param.keys():
|
||||||
self._batch_shape = self._calc_batch_shape()
|
self._batch_shape = self._calc_batch_shape()
|
||||||
self._is_scalar_batch = self._check_is_scalar_batch()
|
self._is_scalar_batch = self._check_is_scalar_batch()
|
||||||
|
@ -141,7 +142,8 @@ class Bijector(Cell):
|
||||||
|
|
||||||
def _shape_mapping(self, shape):
|
def _shape_mapping(self, shape):
|
||||||
shape_tensor = self.fill_base(self.parameter_type, shape, 0.0)
|
shape_tensor = self.fill_base(self.parameter_type, shape, 0.0)
|
||||||
dist_shape_tensor = self.fill_base(self.parameter_type, self.batch_shape, 0.0)
|
dist_shape_tensor = self.fill_base(
|
||||||
|
self.parameter_type, self.batch_shape, 0.0)
|
||||||
return (shape_tensor + dist_shape_tensor).shape
|
return (shape_tensor + dist_shape_tensor).shape
|
||||||
|
|
||||||
def shape_mapping(self, shape):
|
def shape_mapping(self, shape):
|
||||||
|
@ -166,12 +168,15 @@ class Bijector(Cell):
|
||||||
if self.common_dtype is None:
|
if self.common_dtype is None:
|
||||||
self.common_dtype = value_t.dtype
|
self.common_dtype = value_t.dtype
|
||||||
elif value_t.dtype != self.common_dtype:
|
elif value_t.dtype != self.common_dtype:
|
||||||
raise TypeError(f"{name} should have the same dtype as other arguments.")
|
raise TypeError(
|
||||||
|
f"{name} should have the same dtype as other arguments.")
|
||||||
# check if the parameters are casted into float-type tensors
|
# check if the parameters are casted into float-type tensors
|
||||||
validator.check_type_name(f"dtype of {name}", value_t.dtype, mstype.float_type, type(self).__name__)
|
validator.check_type_name(
|
||||||
|
f"dtype of {name}", value_t.dtype, mstype.float_type, type(self).__name__)
|
||||||
# check if the dtype of the input_parameter agrees with the bijector's dtype
|
# check if the dtype of the input_parameter agrees with the bijector's dtype
|
||||||
elif value_t.dtype != self.dtype:
|
elif value_t.dtype != self.dtype:
|
||||||
raise TypeError(f"{name} should have the same dtype as the bijector's dtype.")
|
raise TypeError(
|
||||||
|
f"{name} should have the same dtype as the bijector's dtype.")
|
||||||
self.default_parameters += [value,]
|
self.default_parameters += [value,]
|
||||||
self.parameter_names += [name,]
|
self.parameter_names += [name,]
|
||||||
return value_t
|
return value_t
|
||||||
|
|
|
@ -33,24 +33,17 @@ class Invert(Bijector):
|
||||||
>>> import mindspore.nn as nn
|
>>> import mindspore.nn as nn
|
||||||
>>> import mindspore.nn.probability.bijector as msb
|
>>> import mindspore.nn.probability.bijector as msb
|
||||||
>>> from mindspore import Tensor
|
>>> from mindspore import Tensor
|
||||||
>>> import mindspore.context as context
|
>>> class Net(nn.Cell):
|
||||||
>>> context.set_context(mode=1)
|
... def __init__(self):
|
||||||
>>>
|
... super(Net, self).__init__()
|
||||||
>>> # To initialize an inverse Exp bijector.
|
... self.origin = msb.ScalarAffine(scale=2.0, shift=1.0)
|
||||||
>>> inv_exp = msb.Invert(msb.Exp())
|
... self.invert = msb.Invert(self.origin)
|
||||||
>>> value = Tensor([1, 2, 3], dtype=mindspore.float32)
|
...
|
||||||
>>> ans1 = inv_exp.forward(value)
|
... def construct(self, x_):
|
||||||
>>> print(ans1.shape)
|
... return self.invert.forward(x_)
|
||||||
(3,)
|
>>> forward = Net()
|
||||||
>>> ans2 = inv_exp.inverse(value)
|
>>> x = np.array([2.0, 3.0, 4.0, 5.0]).astype(np.float32)
|
||||||
>>> print(ans2.shape)
|
>>> ans = forward(Tensor(x, dtype=dtype.float32))
|
||||||
(3,)
|
|
||||||
>>> ans3 = inv_exp.forward_log_jacobian(value)
|
|
||||||
>>> print(ans3.shape)
|
|
||||||
(3,)
|
|
||||||
>>> ans4 = inv_exp.inverse_log_jacobian(value)
|
|
||||||
>>> print(ans4.shape)
|
|
||||||
(3,)
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
|
|
|
@ -84,6 +84,7 @@ class Bernoulli(Distribution):
|
||||||
(3,)
|
(3,)
|
||||||
>>> # `probs` must be passed in during function calls.
|
>>> # `probs` must be passed in during function calls.
|
||||||
>>> ans = b2.mean(probs_a)
|
>>> ans = b2.mean(probs_a)
|
||||||
|
>>> print(ans.shape)
|
||||||
(1,)
|
(1,)
|
||||||
>>> print(ans.shape)
|
>>> print(ans.shape)
|
||||||
>>> # Interfaces of `kl_loss` and `cross_entropy` are the same as follows:
|
>>> # Interfaces of `kl_loss` and `cross_entropy` are the same as follows:
|
||||||
|
|
|
@ -105,22 +105,6 @@ class Categorical(Distribution):
|
||||||
>>> ans = ca2.kl_loss('Categorical', probs_b, probs_a)
|
>>> ans = ca2.kl_loss('Categorical', probs_b, probs_a)
|
||||||
>>> print(ans.shape)
|
>>> print(ans.shape)
|
||||||
()
|
()
|
||||||
>>> # Examples of `sample`.
|
|
||||||
>>> # Args:
|
|
||||||
>>> # shape (tuple): the shape of the sample. Default: ().
|
|
||||||
>>> # probs (Tensor): event probabilities. Default: self.probs.
|
|
||||||
>>> ans = ca1.sample()
|
|
||||||
>>> print(ans.shape)
|
|
||||||
()
|
|
||||||
>>> ans = ca1.sample((2,3))
|
|
||||||
>>> print(ans.shape)
|
|
||||||
(2, 3)
|
|
||||||
>>> ans = ca1.sample((2,3), probs_b)
|
|
||||||
>>> print(ans.shape)
|
|
||||||
(2, 3)
|
|
||||||
>>> ans = ca2.sample((2,3), probs_a)
|
|
||||||
>>> print(ans.shape)
|
|
||||||
(2, 3)
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
|
|
|
@ -18,7 +18,7 @@ from mindspore.ops import operations as P
|
||||||
from mindspore.nn.cell import Cell
|
from mindspore.nn.cell import Cell
|
||||||
from mindspore._checkparam import Validator as validator
|
from mindspore._checkparam import Validator as validator
|
||||||
from ._utils.utils import raise_none_error, cast_to_tensor, set_param_type, cast_type_for_device,\
|
from ._utils.utils import raise_none_error, cast_to_tensor, set_param_type, cast_type_for_device,\
|
||||||
raise_not_implemented_util
|
raise_not_implemented_util
|
||||||
from ._utils.utils import CheckTuple, CheckTensor
|
from ._utils.utils import CheckTuple, CheckTensor
|
||||||
from ._utils.custom_ops import broadcast_to, exp_generic, log_generic
|
from ._utils.custom_ops import broadcast_to, exp_generic, log_generic
|
||||||
|
|
||||||
|
@ -77,7 +77,8 @@ class Distribution(Cell):
|
||||||
|
|
||||||
# if not a transformed distribution, set the following attribute
|
# if not a transformed distribution, set the following attribute
|
||||||
if 'distribution' not in self.parameters.keys():
|
if 'distribution' not in self.parameters.keys():
|
||||||
self.parameter_type = set_param_type(self.parameters['param_dict'], dtype)
|
self.parameter_type = set_param_type(
|
||||||
|
self.parameters['param_dict'], dtype)
|
||||||
self._batch_shape = self._calc_batch_shape()
|
self._batch_shape = self._calc_batch_shape()
|
||||||
self._is_scalar_batch = self._check_is_scalar_batch()
|
self._is_scalar_batch = self._check_is_scalar_batch()
|
||||||
self._broadcast_shape = self._batch_shape
|
self._broadcast_shape = self._batch_shape
|
||||||
|
@ -152,7 +153,8 @@ class Distribution(Cell):
|
||||||
self.default_parameters = []
|
self.default_parameters = []
|
||||||
self.parameter_names = []
|
self.parameter_names = []
|
||||||
# cast value to a tensor if it is not None
|
# cast value to a tensor if it is not None
|
||||||
value_t = None if value is None else cast_to_tensor(value, self.parameter_type)
|
value_t = None if value is None else cast_to_tensor(
|
||||||
|
value, self.parameter_type)
|
||||||
self.default_parameters += [value_t,]
|
self.default_parameters += [value_t,]
|
||||||
self.parameter_names += [name,]
|
self.parameter_names += [name,]
|
||||||
return value_t
|
return value_t
|
||||||
|
@ -180,10 +182,12 @@ class Distribution(Cell):
|
||||||
if broadcast_shape is None:
|
if broadcast_shape is None:
|
||||||
broadcast_shape = self.shape_base(arg)
|
broadcast_shape = self.shape_base(arg)
|
||||||
common_dtype = self.dtype_base(arg)
|
common_dtype = self.dtype_base(arg)
|
||||||
broadcast_shape_tensor = self.fill_base(common_dtype, broadcast_shape, 1.0)
|
broadcast_shape_tensor = self.fill_base(
|
||||||
|
common_dtype, broadcast_shape, 1.0)
|
||||||
else:
|
else:
|
||||||
broadcast_shape = self.shape_base(arg + broadcast_shape_tensor)
|
broadcast_shape = self.shape_base(arg + broadcast_shape_tensor)
|
||||||
broadcast_shape_tensor = self.fill_base(common_dtype, broadcast_shape, 1.0)
|
broadcast_shape_tensor = self.fill_base(
|
||||||
|
common_dtype, broadcast_shape, 1.0)
|
||||||
arg = self.broadcast(arg, broadcast_shape_tensor)
|
arg = self.broadcast(arg, broadcast_shape_tensor)
|
||||||
# check if the arguments have the same dtype
|
# check if the arguments have the same dtype
|
||||||
self.sametypeshape_base(arg, broadcast_shape_tensor)
|
self.sametypeshape_base(arg, broadcast_shape_tensor)
|
||||||
|
@ -240,7 +244,7 @@ class Distribution(Cell):
|
||||||
|
|
||||||
def _set_prob(self):
|
def _set_prob(self):
|
||||||
"""
|
"""
|
||||||
Set probability funtion based on the availability of `_prob` and `_log_likehood`.
|
Set probability function based on the availability of `_prob` and `_log_likehood`.
|
||||||
"""
|
"""
|
||||||
if hasattr(self, '_prob'):
|
if hasattr(self, '_prob'):
|
||||||
self._call_prob = self._prob
|
self._call_prob = self._prob
|
||||||
|
@ -303,9 +307,10 @@ class Distribution(Cell):
|
||||||
Set survival function based on the availability of _survival function and `_log_survival`
|
Set survival function based on the availability of _survival function and `_log_survival`
|
||||||
and `_call_cdf`.
|
and `_call_cdf`.
|
||||||
"""
|
"""
|
||||||
if not (hasattr(self, '_survival_function') or hasattr(self, '_log_survival') or \
|
if not (hasattr(self, '_survival_function') or hasattr(self, '_log_survival') or
|
||||||
hasattr(self, '_cdf') or hasattr(self, '_log_cdf')):
|
hasattr(self, '_cdf') or hasattr(self, '_log_cdf')):
|
||||||
self._call_survival = self._raise_not_implemented_error('survival_function')
|
self._call_survival = self._raise_not_implemented_error(
|
||||||
|
'survival_function')
|
||||||
elif hasattr(self, '_survival_function'):
|
elif hasattr(self, '_survival_function'):
|
||||||
self._call_survival = self._survival_function
|
self._call_survival = self._survival_function
|
||||||
elif hasattr(self, '_log_survival'):
|
elif hasattr(self, '_log_survival'):
|
||||||
|
@ -317,7 +322,7 @@ class Distribution(Cell):
|
||||||
"""
|
"""
|
||||||
Set log cdf based on the availability of `_log_cdf` and `_call_cdf`.
|
Set log cdf based on the availability of `_log_cdf` and `_call_cdf`.
|
||||||
"""
|
"""
|
||||||
if not (hasattr(self, '_log_cdf') or hasattr(self, '_cdf') or \
|
if not (hasattr(self, '_log_cdf') or hasattr(self, '_cdf') or
|
||||||
hasattr(self, '_survival_function') or hasattr(self, '_log_survival')):
|
hasattr(self, '_survival_function') or hasattr(self, '_log_survival')):
|
||||||
self._call_log_cdf = self._raise_not_implemented_error('log_cdf')
|
self._call_log_cdf = self._raise_not_implemented_error('log_cdf')
|
||||||
elif hasattr(self, '_log_cdf'):
|
elif hasattr(self, '_log_cdf'):
|
||||||
|
@ -329,9 +334,10 @@ class Distribution(Cell):
|
||||||
"""
|
"""
|
||||||
Set log survival based on the availability of `_log_survival` and `_call_survival`.
|
Set log survival based on the availability of `_log_survival` and `_call_survival`.
|
||||||
"""
|
"""
|
||||||
if not (hasattr(self, '_log_survival') or hasattr(self, '_survival_function') or \
|
if not (hasattr(self, '_log_survival') or hasattr(self, '_survival_function') or
|
||||||
hasattr(self, '_log_cdf') or hasattr(self, '_cdf')):
|
hasattr(self, '_log_cdf') or hasattr(self, '_cdf')):
|
||||||
self._call_log_survival = self._raise_not_implemented_error('log_cdf')
|
self._call_log_survival = self._raise_not_implemented_error(
|
||||||
|
'log_cdf')
|
||||||
elif hasattr(self, '_log_survival'):
|
elif hasattr(self, '_log_survival'):
|
||||||
self._call_log_survival = self._log_survival
|
self._call_log_survival = self._log_survival
|
||||||
elif hasattr(self, '_call_survival'):
|
elif hasattr(self, '_call_survival'):
|
||||||
|
@ -344,7 +350,8 @@ class Distribution(Cell):
|
||||||
if hasattr(self, '_cross_entropy'):
|
if hasattr(self, '_cross_entropy'):
|
||||||
self._call_cross_entropy = self._cross_entropy
|
self._call_cross_entropy = self._cross_entropy
|
||||||
else:
|
else:
|
||||||
self._call_cross_entropy = self._raise_not_implemented_error('cross_entropy')
|
self._call_cross_entropy = self._raise_not_implemented_error(
|
||||||
|
'cross_entropy')
|
||||||
|
|
||||||
def _get_dist_args(self, *args, **kwargs):
|
def _get_dist_args(self, *args, **kwargs):
|
||||||
return raise_not_implemented_util('get_dist_args', self.name, *args, **kwargs)
|
return raise_not_implemented_util('get_dist_args', self.name, *args, **kwargs)
|
||||||
|
@ -375,6 +382,7 @@ class Distribution(Cell):
|
||||||
|
|
||||||
def _raise_not_implemented_error(self, func_name):
|
def _raise_not_implemented_error(self, func_name):
|
||||||
name = self.name
|
name = self.name
|
||||||
|
|
||||||
def raise_error(*args, **kwargs):
|
def raise_error(*args, **kwargs):
|
||||||
return raise_not_implemented_util(func_name, name, *args, **kwargs)
|
return raise_not_implemented_util(func_name, name, *args, **kwargs)
|
||||||
return raise_error
|
return raise_error
|
||||||
|
|
|
@ -46,48 +46,19 @@ class Gumbel(TransformedDistribution):
|
||||||
|
|
||||||
Examples:
|
Examples:
|
||||||
>>> import mindspore
|
>>> import mindspore
|
||||||
>>> import mindspore.context as context
|
|
||||||
>>> import mindspore.nn as nn
|
>>> import mindspore.nn as nn
|
||||||
>>> import mindspore.nn.probability.distribution as msd
|
>>> import mindspore.nn.probability.distribution as msd
|
||||||
>>> from mindspore import Tensor
|
>>> from mindspore import Tensor
|
||||||
>>> context.set_context(mode=1)
|
>>> class Prob(nn.Cell):
|
||||||
>>> # To initialize a Gumbel distribution of `loc` 3.0 and `scale` 4.0.
|
... def __init__(self):
|
||||||
>>> gumbel = msd.Gumbel(3.0, 4.0, dtype=mindspore.float32)
|
... super(Prob, self).__init__()
|
||||||
>>> # Private interfaces of probability functions corresponding to public interfaces, including
|
... self.gum = msd.Gumbel(np.array([0.0]), np.array([[1.0], [2.0]]), dtype=dtype.float32)
|
||||||
>>> # `prob`, `log_prob`, `cdf`, `log_cdf`, `survival_function`, and `log_survival`, have the same
|
...
|
||||||
>>> # arguments as follows.
|
... def construct(self, x_):
|
||||||
>>> # Args:
|
... return self.gum.prob(x_)
|
||||||
>>> # value (Tensor): the value to be evaluated.
|
>>> value = np.array([1.0, 2.0]).astype(np.float32)
|
||||||
>>> # Examples of `prob`.
|
>>> pdf = Prob()
|
||||||
>>> # Similar calls can be made to other probability functions
|
>>> output = pdf(Tensor(value, dtype=dtype.float32))
|
||||||
>>> # by replacing 'prob' by the name of the function.
|
|
||||||
>>> value = Tensor([1.0, 2.0, 3.0], dtype=mindspore.float32)
|
|
||||||
>>> ans = gumbel.prob(value)
|
|
||||||
>>> print(ans.shape)
|
|
||||||
(3,)
|
|
||||||
>>> # Functions `mean`, `mode`, sd`, `var`, and `entropy` do not take in any argument.
|
|
||||||
>>> ans = gumbel.mean()
|
|
||||||
>>> print(ans.shape)
|
|
||||||
()
|
|
||||||
>>> # Interfaces of 'kl_loss' and 'cross_entropy' are the same:
|
|
||||||
>>> # Args:
|
|
||||||
>>> # dist (str): the type of the distributions. Only "Gumbel" is supported.
|
|
||||||
>>> # loc_b (Tensor): the loc of distribution b.
|
|
||||||
>>> # scale_b (Tensor): the scale distribution b.
|
|
||||||
>>> # Examples of `kl_loss`. `cross_entropy` is similar.
|
|
||||||
>>> loc_b = Tensor([1.0], dtype=mindspore.float32)
|
|
||||||
>>> scale_b = Tensor([1.0, 1.5, 2.0], dtype=mindspore.float32)
|
|
||||||
>>> ans = gumbel.kl_loss('Gumbel', loc_b, scale_b)
|
|
||||||
>>> print(ans.shape)
|
|
||||||
(3,)
|
|
||||||
>>> # Examples of `sample`.
|
|
||||||
>>> # Args:
|
|
||||||
>>> # shape (tuple): the shape of the sample. Default: ()
|
|
||||||
>>> ans = gumbel.sample()
|
|
||||||
>>> print(ans.shape)
|
|
||||||
()
|
|
||||||
>>> ans = gumbel.sample((2,3))
|
|
||||||
>>> print(ans.shape)
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
|
|
|
@ -45,103 +45,17 @@ class LogNormal(msd.TransformedDistribution):
|
||||||
|
|
||||||
Examples:
|
Examples:
|
||||||
>>> import mindspore
|
>>> import mindspore
|
||||||
>>> import mindspore.context as context
|
|
||||||
>>> import mindspore.nn as nn
|
>>> import mindspore.nn as nn
|
||||||
>>> import mindspore.nn.probability.distribution as msd
|
>>> import mindspore.nn.probability.distribution as msd
|
||||||
>>> from mindspore import Tensor
|
>>> from mindspore import Tensor
|
||||||
>>> context.set_context(mode=1)
|
... class Prob(nn.Cell):
|
||||||
>>> # To initialize a LogNormal distribution of `loc` 3.0 and `scale` 4.0.
|
... def __init__(self):
|
||||||
>>> n1 = msd.LogNormal(3.0, 4.0, dtype=mindspore.float32)
|
... super(Prob, self).__init__()
|
||||||
>>> # A LogNormal distribution can be initialized without arguments.
|
... self.ln = msd.LogNormal(np.array([0.3]), np.array([[0.2], [0.4]]), dtype=dtype.float32)
|
||||||
>>> # In this case, `loc` and `scale` must be passed in during function calls.
|
... def construct(self, x_):
|
||||||
>>> n2 = msd.LogNormal(dtype=mindspore.float32)
|
... return self.ln.prob(x_)
|
||||||
>>>
|
>>> pdf = Prob()
|
||||||
>>> # Here are some tensors used below for testing
|
>>> output = pdf(Tensor([1.0, 2.0], dtype=dtype.float32))
|
||||||
>>> value = Tensor([1.0, 2.0, 3.0], dtype=mindspore.float32)
|
|
||||||
>>> loc_a = Tensor([2.0], dtype=mindspore.float32)
|
|
||||||
>>> scale_a = Tensor([2.0, 2.0, 2.0], dtype=mindspore.float32)
|
|
||||||
>>> loc_b = Tensor([1.0], dtype=mindspore.float32)
|
|
||||||
>>> scale_b = Tensor([1.0, 1.5, 2.0], dtype=mindspore.float32)
|
|
||||||
>>>
|
|
||||||
>>> # Private interfaces of probability functions corresponding to public interfaces, including
|
|
||||||
>>> # `prob`, `log_prob`, `cdf`, `log_cdf`, `survival_function`, and `log_survival`, have the same
|
|
||||||
>>> # arguments as follows.
|
|
||||||
>>> # Args:
|
|
||||||
>>> # value (Tensor): the value to be evaluated.
|
|
||||||
>>> # loc (Tensor): the loc of distribution. Default: None. If `loc` is passed in as None,
|
|
||||||
>>> # the mean of the underlying Normal distribution will be used.
|
|
||||||
>>> # scale (Tensor): the scale of distribution. Default: None. If `scale` is passed in as None,
|
|
||||||
>>> # the standard deviation of the underlying Normal distribution will be used.
|
|
||||||
>>> # Examples of `prob`.
|
|
||||||
>>> # Similar calls can be made to other probability functions
|
|
||||||
>>> # by replacing 'prob' by the name of the function.
|
|
||||||
>>> ans = n1.prob(value)
|
|
||||||
>>> print(ans.shape)
|
|
||||||
(3,)
|
|
||||||
>>> # Evaluate with respect to distribution b.
|
|
||||||
>>> ans = n1.prob(value, loc_b, scale_b)
|
|
||||||
>>> print(ans.shape)
|
|
||||||
(3,)
|
|
||||||
>>> # `loc` and `scale` must be passed in during function calls since they were not passed in construct.
|
|
||||||
>>> ans = n2.prob(value, loc_a, scale_a)
|
|
||||||
>>> print(ans.shape)
|
|
||||||
(3,)
|
|
||||||
>>> # Functions `mean`, `sd`, `var`, and `entropy` have the same arguments.
|
|
||||||
>>> # Args:
|
|
||||||
>>> # loc (Tensor): the loc of distribution. Default: None. If `loc` is passed in as None,
|
|
||||||
>>> # the mean of the underlying Normal distribution will be used.
|
|
||||||
>>> # scale (Tensor): the scale of distribution. Default: None. If `scale` is passed in as None,
|
|
||||||
>>> # the standard deviation of the underlying Normal distribution will be used.
|
|
||||||
>>> # Example of `mean`. `sd`, `var`, and `entropy` are similar.
|
|
||||||
>>> ans = n1.mean()
|
|
||||||
>>> print(ans.shape)
|
|
||||||
()
|
|
||||||
>>> ans = n1.mean(loc_b, scale_b)
|
|
||||||
>>> print(ans.shape)
|
|
||||||
(3,)
|
|
||||||
>>> # `loc` and `scale` must be passed in during function calls since they were not passed in construct.
|
|
||||||
>>> ans = n2.mean(loc_a, scale_a)
|
|
||||||
>>> print(ans.shape)
|
|
||||||
(3,)
|
|
||||||
>>> # Interfaces of 'kl_loss' and 'cross_entropy' are the same:
|
|
||||||
>>> # Args:
|
|
||||||
>>> # dist (str): the type of the distributions. Only "Normal" is supported.
|
|
||||||
>>> # loc_b (Tensor): the loc of distribution b.
|
|
||||||
>>> # scale_b (Tensor): the scale distribution b.
|
|
||||||
>>> # loc_a (Tensor): the loc of distribution a. Default: None. If `loc` is passed in as None,
|
|
||||||
>>> # the mean of the underlying Normal distribution will be used.
|
|
||||||
>>> # scale_a (Tensor): the scale distribution a. Default: None. If `scale` is passed in as None,
|
|
||||||
>>> # the standard deviation of the underlying Normal distribution will be used.
|
|
||||||
>>> # Examples of `kl_loss`. `cross_entropy` is similar.
|
|
||||||
>>> ans = n1.kl_loss('LogNormal', loc_b, scale_b)
|
|
||||||
>>> print(ans.shape)
|
|
||||||
(3,)
|
|
||||||
>>> ans = n1.kl_loss('LogNormal', loc_b, scale_b, loc_a, scale_a)
|
|
||||||
>>> print(ans.shape)
|
|
||||||
(3,)
|
|
||||||
>>> # Additional `loc` and `scale` must be passed in since they were not passed in construct.
|
|
||||||
>>> ans = n2.kl_loss('LogNormal', loc_b, scale_b, loc_a, scale_a)
|
|
||||||
>>> print(ans.shape)
|
|
||||||
(3,)
|
|
||||||
>>> # Examples of `sample`.
|
|
||||||
>>> # Args:
|
|
||||||
>>> # shape (tuple): the shape of the sample. Default: ()
|
|
||||||
>>> # loc (Tensor): the loc of the distribution. Default: None. If `loc` is passed in as None,
|
|
||||||
>>> # the mean of the underlying Normal distribution will be used.
|
|
||||||
>>> # scale (Tensor): the scale of the distribution. Default: None. If `scale` is passed in as None,
|
|
||||||
>>> # the standard deviation of the underlying Normal distribution will be used.
|
|
||||||
>>> ans = n1.sample()
|
|
||||||
>>> print(ans.shape)
|
|
||||||
()
|
|
||||||
>>> ans = n1.sample((2,3))
|
|
||||||
>>> print(ans.shape)
|
|
||||||
(2, 3)
|
|
||||||
>>> ans = n1.sample((2,3), loc_b, scale_b)
|
|
||||||
>>> print(ans.shape)
|
|
||||||
(2, 3, 3)
|
|
||||||
>>> ans = n2.sample((2,3), loc_a, scale_a)
|
|
||||||
>>> print(ans.shape)
|
|
||||||
(2, 3, 3)
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
|
|
|
@ -89,7 +89,7 @@ class Poisson(Distribution):
|
||||||
>>> # `rate` must be passed in during function calls.
|
>>> # `rate` must be passed in during function calls.
|
||||||
>>> ans = p2.mean(rate_a)
|
>>> ans = p2.mean(rate_a)
|
||||||
>>> print(ans.shape)
|
>>> print(ans.shape)
|
||||||
(1,)
|
()
|
||||||
>>> # Examples of `sample`.
|
>>> # Examples of `sample`.
|
||||||
>>> # Args:
|
>>> # Args:
|
||||||
>>> # shape (tuple): the shape of the sample. Default: ()
|
>>> # shape (tuple): the shape of the sample. Default: ()
|
||||||
|
|
|
@ -53,25 +53,28 @@ class TransformedDistribution(Distribution):
|
||||||
|
|
||||||
Examples:
|
Examples:
|
||||||
>>> import mindspore
|
>>> import mindspore
|
||||||
>>> import mindspore.context as context
|
|
||||||
>>> import mindspore.nn as nn
|
>>> import mindspore.nn as nn
|
||||||
>>> import mindspore.nn.probability.distribution as msd
|
>>> import mindspore.nn.probability.distribution as msd
|
||||||
>>> import mindspore.nn.probability.bijector as msb
|
>>> import mindspore.nn.probability.bijector as msb
|
||||||
>>> from mindspore import Tensor
|
>>> from mindspore import Tensor
|
||||||
>>> context.set_context(mode=1)
|
>>> class Net(nn.Cell):
|
||||||
>>>
|
... def __init__(self, shape, dtype=dtype.float32, seed=0, name='transformed_distribution'):
|
||||||
>>> # To initialize a transformed distribution
|
... super(Net, self).__init__()
|
||||||
>>> # using a Normal distribution as the base distribution,
|
... # create TransformedDistribution distribution
|
||||||
>>> # and an Exp bijector as the bijector function.
|
... self.exp = msb.Exp()
|
||||||
>>> trans_dist = msd.TransformedDistribution(msb.Exp(), msd.Normal(0.0, 1.0))
|
... self.normal = msd.Normal(0.0, 1.0, dtype=dtype)
|
||||||
>>>
|
... self.lognormal = msd.TransformedDistribution(self.exp, self.normal, seed=seed, name=name)
|
||||||
>>> value = Tensor([1.0, 2.0, 3.0], dtype=mindspore.float32)
|
... self.shape = shape
|
||||||
>>> prob = trans_dist.prob(value)
|
...
|
||||||
>>> print(prob.shape)
|
... def construct(self, value):
|
||||||
(3,)
|
... cdf = self.lognormal.cdf(value)
|
||||||
>>> sample = trans_dist.sample(shape=(2, 3))
|
... sample = self.lognormal.sample(self.shape)
|
||||||
>>> print(sample.shape)
|
... return cdf, sample
|
||||||
(2, 3)
|
>>> shape = (2, 3)
|
||||||
|
>>> net = Net(shape=shape, name="LogNormal")
|
||||||
|
>>> x = np.array([2.0, 3.0, 4.0, 5.0]).astype(np.float32)
|
||||||
|
>>> tx = Tensor(x, dtype=dtype.float32)
|
||||||
|
>>> cdf, sample = net(tx)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
|
|
|
@ -38,7 +38,7 @@ class Uniform(Distribution):
|
||||||
``Ascend`` ``GPU``
|
``Ascend`` ``GPU``
|
||||||
|
|
||||||
Note:
|
Note:
|
||||||
`low` must be stricly less than `high`.
|
`low` must be strictly less than `high`.
|
||||||
`dist_spec_args` are `high` and `low`.
|
`dist_spec_args` are `high` and `low`.
|
||||||
`dtype` must be float type because Uniform distributions are continuous.
|
`dtype` must be float type because Uniform distributions are continuous.
|
||||||
|
|
||||||
|
@ -143,7 +143,8 @@ class Uniform(Distribution):
|
||||||
param = dict(locals())
|
param = dict(locals())
|
||||||
param['param_dict'] = {'low': low, 'high': high}
|
param['param_dict'] = {'low': low, 'high': high}
|
||||||
valid_dtype = mstype.float_type
|
valid_dtype = mstype.float_type
|
||||||
Validator.check_type_name("dtype", dtype, valid_dtype, type(self).__name__)
|
Validator.check_type_name(
|
||||||
|
"dtype", dtype, valid_dtype, type(self).__name__)
|
||||||
super(Uniform, self).__init__(seed, dtype, name, param)
|
super(Uniform, self).__init__(seed, dtype, name, param)
|
||||||
|
|
||||||
self._low = self._add_parameter(low, 'low')
|
self._low = self._add_parameter(low, 'low')
|
||||||
|
@ -151,7 +152,6 @@ class Uniform(Distribution):
|
||||||
if self.low is not None and self.high is not None:
|
if self.low is not None and self.high is not None:
|
||||||
check_greater(self.low, self.high, 'low', 'high')
|
check_greater(self.low, self.high, 'low', 'high')
|
||||||
|
|
||||||
|
|
||||||
# ops needed for the class
|
# ops needed for the class
|
||||||
self.exp = exp_generic
|
self.exp = exp_generic
|
||||||
self.log = log_generic
|
self.log = log_generic
|
||||||
|
|
|
@ -20,11 +20,13 @@ import mindspore.nn.probability.distribution as msd
|
||||||
from mindspore.common import dtype as mstype
|
from mindspore.common import dtype as mstype
|
||||||
from mindspore.ops import operations as P
|
from mindspore.ops import operations as P
|
||||||
|
|
||||||
|
|
||||||
class BayesianNet(nn.Cell):
|
class BayesianNet(nn.Cell):
|
||||||
"""
|
"""
|
||||||
We currently support 3 types of variables: x = observation, z = latent, y = condition.
|
We currently support 3 types of variables: x = observation, z = latent, y = condition.
|
||||||
A Bayeisian Network models a generative process for certain varaiables: p(x,z|y) or p(z|x,y) or p(x|z,y)
|
A Bayeisian Network models a generative process for certain variables: p(x,z|y) or p(z|x,y) or p(x|z,y)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.normal_dist = msd.Normal(dtype=mstype.float32)
|
self.normal_dist = msd.Normal(dtype=mstype.float32)
|
||||||
|
@ -49,14 +51,16 @@ class BayesianNet(nn.Cell):
|
||||||
|
|
||||||
if observation is None:
|
if observation is None:
|
||||||
if reparameterize:
|
if reparameterize:
|
||||||
epsilon = self.normal_dist('sample', shape, self.zeros(mean.shape), self.ones(std.shape))
|
epsilon = self.normal_dist('sample', shape, self.zeros(
|
||||||
|
mean.shape), self.ones(std.shape))
|
||||||
sample = mean + std * epsilon
|
sample = mean + std * epsilon
|
||||||
else:
|
else:
|
||||||
sample = self.normal_dist('sample', shape, mean, std)
|
sample = self.normal_dist('sample', shape, mean, std)
|
||||||
else:
|
else:
|
||||||
sample = observation
|
sample = observation
|
||||||
|
|
||||||
log_prob = self.reduce_sum(self.normal_dist('log_prob', sample, mean, std), 1)
|
log_prob = self.reduce_sum(self.normal_dist(
|
||||||
|
'log_prob', sample, mean, std), 1)
|
||||||
return sample, log_prob
|
return sample, log_prob
|
||||||
|
|
||||||
def Bernoulli(self,
|
def Bernoulli(self,
|
||||||
|
@ -77,7 +81,8 @@ class BayesianNet(nn.Cell):
|
||||||
else:
|
else:
|
||||||
sample = observation
|
sample = observation
|
||||||
|
|
||||||
log_prob = self.reduce_sum(self.bernoulli_dist('log_prob', sample, probs), 1)
|
log_prob = self.reduce_sum(
|
||||||
|
self.bernoulli_dist('log_prob', sample, probs), 1)
|
||||||
return sample, log_prob
|
return sample, log_prob
|
||||||
|
|
||||||
def construct(self, *inputs, **kwargs):
|
def construct(self, *inputs, **kwargs):
|
||||||
|
|
|
@ -23,10 +23,12 @@ from mindspore import dtype
|
||||||
|
|
||||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
|
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
|
||||||
|
|
||||||
|
|
||||||
class Prob(nn.Cell):
|
class Prob(nn.Cell):
|
||||||
"""
|
"""
|
||||||
Test class: probability of Bernoulli distribution.
|
Test class: probability of Bernoulli distribution.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super(Prob, self).__init__()
|
super(Prob, self).__init__()
|
||||||
self.b = msd.Bernoulli(0.7, dtype=dtype.int32)
|
self.b = msd.Bernoulli(0.7, dtype=dtype.int32)
|
||||||
|
@ -34,6 +36,7 @@ class Prob(nn.Cell):
|
||||||
def construct(self, x_):
|
def construct(self, x_):
|
||||||
return self.b.prob(x_)
|
return self.b.prob(x_)
|
||||||
|
|
||||||
|
|
||||||
def test_pmf():
|
def test_pmf():
|
||||||
"""
|
"""
|
||||||
Test pmf.
|
Test pmf.
|
||||||
|
@ -41,7 +44,8 @@ def test_pmf():
|
||||||
bernoulli_benchmark = stats.bernoulli(0.7)
|
bernoulli_benchmark = stats.bernoulli(0.7)
|
||||||
expect_pmf = bernoulli_benchmark.pmf([0, 1, 0, 1, 1]).astype(np.float32)
|
expect_pmf = bernoulli_benchmark.pmf([0, 1, 0, 1, 1]).astype(np.float32)
|
||||||
pmf = Prob()
|
pmf = Prob()
|
||||||
x_ = Tensor(np.array([0, 1, 0, 1, 1]).astype(np.int32), dtype=dtype.float32)
|
x_ = Tensor(np.array([0, 1, 0, 1, 1]).astype(
|
||||||
|
np.int32), dtype=dtype.float32)
|
||||||
output = pmf(x_)
|
output = pmf(x_)
|
||||||
tol = 1e-6
|
tol = 1e-6
|
||||||
assert (np.abs(output.asnumpy() - expect_pmf) < tol).all()
|
assert (np.abs(output.asnumpy() - expect_pmf) < tol).all()
|
||||||
|
@ -51,6 +55,7 @@ class LogProb(nn.Cell):
|
||||||
"""
|
"""
|
||||||
Test class: log probability of Bernoulli distribution.
|
Test class: log probability of Bernoulli distribution.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super(LogProb, self).__init__()
|
super(LogProb, self).__init__()
|
||||||
self.b = msd.Bernoulli(0.7, dtype=dtype.int32)
|
self.b = msd.Bernoulli(0.7, dtype=dtype.int32)
|
||||||
|
@ -58,22 +63,27 @@ class LogProb(nn.Cell):
|
||||||
def construct(self, x_):
|
def construct(self, x_):
|
||||||
return self.b.log_prob(x_)
|
return self.b.log_prob(x_)
|
||||||
|
|
||||||
|
|
||||||
def test_log_likelihood():
|
def test_log_likelihood():
|
||||||
"""
|
"""
|
||||||
Test log_pmf.
|
Test log_pmf.
|
||||||
"""
|
"""
|
||||||
bernoulli_benchmark = stats.bernoulli(0.7)
|
bernoulli_benchmark = stats.bernoulli(0.7)
|
||||||
expect_logpmf = bernoulli_benchmark.logpmf([0, 1, 0, 1, 1]).astype(np.float32)
|
expect_logpmf = bernoulli_benchmark.logpmf(
|
||||||
|
[0, 1, 0, 1, 1]).astype(np.float32)
|
||||||
logprob = LogProb()
|
logprob = LogProb()
|
||||||
x_ = Tensor(np.array([0, 1, 0, 1, 1]).astype(np.int32), dtype=dtype.float32)
|
x_ = Tensor(np.array([0, 1, 0, 1, 1]).astype(
|
||||||
|
np.int32), dtype=dtype.float32)
|
||||||
output = logprob(x_)
|
output = logprob(x_)
|
||||||
tol = 1e-6
|
tol = 1e-6
|
||||||
assert (np.abs(output.asnumpy() - expect_logpmf) < tol).all()
|
assert (np.abs(output.asnumpy() - expect_logpmf) < tol).all()
|
||||||
|
|
||||||
|
|
||||||
class KL(nn.Cell):
|
class KL(nn.Cell):
|
||||||
"""
|
"""
|
||||||
Test class: kl_loss between Bernoulli distributions.
|
Test class: kl_loss between Bernoulli distributions.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super(KL, self).__init__()
|
super(KL, self).__init__()
|
||||||
self.b = msd.Bernoulli(0.7, dtype=dtype.int32)
|
self.b = msd.Bernoulli(0.7, dtype=dtype.int32)
|
||||||
|
@ -81,6 +91,7 @@ class KL(nn.Cell):
|
||||||
def construct(self, x_):
|
def construct(self, x_):
|
||||||
return self.b.kl_loss('Bernoulli', x_)
|
return self.b.kl_loss('Bernoulli', x_)
|
||||||
|
|
||||||
|
|
||||||
def test_kl_loss():
|
def test_kl_loss():
|
||||||
"""
|
"""
|
||||||
Test kl_loss.
|
Test kl_loss.
|
||||||
|
@ -89,16 +100,19 @@ def test_kl_loss():
|
||||||
probs1_b = 0.5
|
probs1_b = 0.5
|
||||||
probs0_a = 1 - probs1_a
|
probs0_a = 1 - probs1_a
|
||||||
probs0_b = 1 - probs1_b
|
probs0_b = 1 - probs1_b
|
||||||
expect_kl_loss = probs1_a * np.log(probs1_a / probs1_b) + probs0_a * np.log(probs0_a / probs0_b)
|
expect_kl_loss = probs1_a * \
|
||||||
|
np.log(probs1_a / probs1_b) + probs0_a * np.log(probs0_a / probs0_b)
|
||||||
kl_loss = KL()
|
kl_loss = KL()
|
||||||
output = kl_loss(Tensor([probs1_b], dtype=dtype.float32))
|
output = kl_loss(Tensor([probs1_b], dtype=dtype.float32))
|
||||||
tol = 1e-6
|
tol = 1e-6
|
||||||
assert (np.abs(output.asnumpy() - expect_kl_loss) < tol).all()
|
assert (np.abs(output.asnumpy() - expect_kl_loss) < tol).all()
|
||||||
|
|
||||||
|
|
||||||
class Basics(nn.Cell):
|
class Basics(nn.Cell):
|
||||||
"""
|
"""
|
||||||
Test class: mean/sd/mode of Bernoulli distribution.
|
Test class: mean/sd/mode of Bernoulli distribution.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super(Basics, self).__init__()
|
super(Basics, self).__init__()
|
||||||
self.b = msd.Bernoulli([0.3, 0.5, 0.7], dtype=dtype.int32)
|
self.b = msd.Bernoulli([0.3, 0.5, 0.7], dtype=dtype.int32)
|
||||||
|
@ -106,6 +120,7 @@ class Basics(nn.Cell):
|
||||||
def construct(self):
|
def construct(self):
|
||||||
return self.b.mean(), self.b.sd(), self.b.mode()
|
return self.b.mean(), self.b.sd(), self.b.mode()
|
||||||
|
|
||||||
|
|
||||||
def test_basics():
|
def test_basics():
|
||||||
"""
|
"""
|
||||||
Test mean/standard deviation/mode.
|
Test mean/standard deviation/mode.
|
||||||
|
@ -120,10 +135,12 @@ def test_basics():
|
||||||
assert (np.abs(sd.asnumpy() - expect_sd) < tol).all()
|
assert (np.abs(sd.asnumpy() - expect_sd) < tol).all()
|
||||||
assert (np.abs(mode.asnumpy() - expect_mode) < tol).all()
|
assert (np.abs(mode.asnumpy() - expect_mode) < tol).all()
|
||||||
|
|
||||||
|
|
||||||
class Sampling(nn.Cell):
|
class Sampling(nn.Cell):
|
||||||
"""
|
"""
|
||||||
Test class: log probability of Bernoulli distribution.
|
Test class: log probability of Bernoulli distribution.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, shape, seed=0):
|
def __init__(self, shape, seed=0):
|
||||||
super(Sampling, self).__init__()
|
super(Sampling, self).__init__()
|
||||||
self.b = msd.Bernoulli([0.7, 0.5], seed=seed, dtype=dtype.int32)
|
self.b = msd.Bernoulli([0.7, 0.5], seed=seed, dtype=dtype.int32)
|
||||||
|
@ -132,6 +149,7 @@ class Sampling(nn.Cell):
|
||||||
def construct(self, probs=None):
|
def construct(self, probs=None):
|
||||||
return self.b.sample(self.shape, probs)
|
return self.b.sample(self.shape, probs)
|
||||||
|
|
||||||
|
|
||||||
def test_sample():
|
def test_sample():
|
||||||
"""
|
"""
|
||||||
Test sample.
|
Test sample.
|
||||||
|
@ -141,10 +159,12 @@ def test_sample():
|
||||||
output = sample()
|
output = sample()
|
||||||
assert output.shape == (2, 3, 2)
|
assert output.shape == (2, 3, 2)
|
||||||
|
|
||||||
|
|
||||||
class CDF(nn.Cell):
|
class CDF(nn.Cell):
|
||||||
"""
|
"""
|
||||||
Test class: cdf of bernoulli distributions.
|
Test class: cdf of bernoulli distributions.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super(CDF, self).__init__()
|
super(CDF, self).__init__()
|
||||||
self.b = msd.Bernoulli(0.7, dtype=dtype.int32)
|
self.b = msd.Bernoulli(0.7, dtype=dtype.int32)
|
||||||
|
@ -152,22 +172,26 @@ class CDF(nn.Cell):
|
||||||
def construct(self, x_):
|
def construct(self, x_):
|
||||||
return self.b.cdf(x_)
|
return self.b.cdf(x_)
|
||||||
|
|
||||||
|
|
||||||
def test_cdf():
|
def test_cdf():
|
||||||
"""
|
"""
|
||||||
Test cdf.
|
Test cdf.
|
||||||
"""
|
"""
|
||||||
bernoulli_benchmark = stats.bernoulli(0.7)
|
bernoulli_benchmark = stats.bernoulli(0.7)
|
||||||
expect_cdf = bernoulli_benchmark.cdf([0, 0, 1, 0, 1]).astype(np.float32)
|
expect_cdf = bernoulli_benchmark.cdf([0, 0, 1, 0, 1]).astype(np.float32)
|
||||||
x_ = Tensor(np.array([0, 0, 1, 0, 1]).astype(np.int32), dtype=dtype.float32)
|
x_ = Tensor(np.array([0, 0, 1, 0, 1]).astype(
|
||||||
|
np.int32), dtype=dtype.float32)
|
||||||
cdf = CDF()
|
cdf = CDF()
|
||||||
output = cdf(x_)
|
output = cdf(x_)
|
||||||
tol = 1e-6
|
tol = 1e-6
|
||||||
assert (np.abs(output.asnumpy() - expect_cdf) < tol).all()
|
assert (np.abs(output.asnumpy() - expect_cdf) < tol).all()
|
||||||
|
|
||||||
|
|
||||||
class LogCDF(nn.Cell):
|
class LogCDF(nn.Cell):
|
||||||
"""
|
"""
|
||||||
Test class: log cdf of bernoulli distributions.
|
Test class: log cdf of bernoulli distributions.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super(LogCDF, self).__init__()
|
super(LogCDF, self).__init__()
|
||||||
self.b = msd.Bernoulli(0.7, dtype=dtype.int32)
|
self.b = msd.Bernoulli(0.7, dtype=dtype.int32)
|
||||||
|
@ -175,13 +199,16 @@ class LogCDF(nn.Cell):
|
||||||
def construct(self, x_):
|
def construct(self, x_):
|
||||||
return self.b.log_cdf(x_)
|
return self.b.log_cdf(x_)
|
||||||
|
|
||||||
|
|
||||||
def test_logcdf():
|
def test_logcdf():
|
||||||
"""
|
"""
|
||||||
Test log_cdf.
|
Test log_cdf.
|
||||||
"""
|
"""
|
||||||
bernoulli_benchmark = stats.bernoulli(0.7)
|
bernoulli_benchmark = stats.bernoulli(0.7)
|
||||||
expect_logcdf = bernoulli_benchmark.logcdf([0, 0, 1, 0, 1]).astype(np.float32)
|
expect_logcdf = bernoulli_benchmark.logcdf(
|
||||||
x_ = Tensor(np.array([0, 0, 1, 0, 1]).astype(np.int32), dtype=dtype.float32)
|
[0, 0, 1, 0, 1]).astype(np.float32)
|
||||||
|
x_ = Tensor(np.array([0, 0, 1, 0, 1]).astype(
|
||||||
|
np.int32), dtype=dtype.float32)
|
||||||
logcdf = LogCDF()
|
logcdf = LogCDF()
|
||||||
output = logcdf(x_)
|
output = logcdf(x_)
|
||||||
tol = 1e-6
|
tol = 1e-6
|
||||||
|
@ -192,6 +219,7 @@ class SF(nn.Cell):
|
||||||
"""
|
"""
|
||||||
Test class: survival function of Bernoulli distributions.
|
Test class: survival function of Bernoulli distributions.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super(SF, self).__init__()
|
super(SF, self).__init__()
|
||||||
self.b = msd.Bernoulli(0.7, dtype=dtype.int32)
|
self.b = msd.Bernoulli(0.7, dtype=dtype.int32)
|
||||||
|
@ -199,13 +227,16 @@ class SF(nn.Cell):
|
||||||
def construct(self, x_):
|
def construct(self, x_):
|
||||||
return self.b.survival_function(x_)
|
return self.b.survival_function(x_)
|
||||||
|
|
||||||
|
|
||||||
def test_survival():
|
def test_survival():
|
||||||
"""
|
"""
|
||||||
Test survival funciton.
|
Test survival function.
|
||||||
"""
|
"""
|
||||||
bernoulli_benchmark = stats.bernoulli(0.7)
|
bernoulli_benchmark = stats.bernoulli(0.7)
|
||||||
expect_survival = bernoulli_benchmark.sf([0, 1, 1, 0, 0]).astype(np.float32)
|
expect_survival = bernoulli_benchmark.sf(
|
||||||
x_ = Tensor(np.array([0, 1, 1, 0, 0]).astype(np.int32), dtype=dtype.float32)
|
[0, 1, 1, 0, 0]).astype(np.float32)
|
||||||
|
x_ = Tensor(np.array([0, 1, 1, 0, 0]).astype(
|
||||||
|
np.int32), dtype=dtype.float32)
|
||||||
sf = SF()
|
sf = SF()
|
||||||
output = sf(x_)
|
output = sf(x_)
|
||||||
tol = 1e-6
|
tol = 1e-6
|
||||||
|
@ -216,6 +247,7 @@ class LogSF(nn.Cell):
|
||||||
"""
|
"""
|
||||||
Test class: log survival function of Bernoulli distributions.
|
Test class: log survival function of Bernoulli distributions.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super(LogSF, self).__init__()
|
super(LogSF, self).__init__()
|
||||||
self.b = msd.Bernoulli(0.7, dtype=dtype.int32)
|
self.b = msd.Bernoulli(0.7, dtype=dtype.int32)
|
||||||
|
@ -223,22 +255,27 @@ class LogSF(nn.Cell):
|
||||||
def construct(self, x_):
|
def construct(self, x_):
|
||||||
return self.b.log_survival(x_)
|
return self.b.log_survival(x_)
|
||||||
|
|
||||||
|
|
||||||
def test_log_survival():
|
def test_log_survival():
|
||||||
"""
|
"""
|
||||||
Test log survival funciton.
|
Test log survival function.
|
||||||
"""
|
"""
|
||||||
bernoulli_benchmark = stats.bernoulli(0.7)
|
bernoulli_benchmark = stats.bernoulli(0.7)
|
||||||
expect_logsurvival = bernoulli_benchmark.logsf([-1, 0.9, 0, 0, 0]).astype(np.float32)
|
expect_logsurvival = bernoulli_benchmark.logsf(
|
||||||
x_ = Tensor(np.array([-1, 0.9, 0, 0, 0]).astype(np.float32), dtype=dtype.float32)
|
[-1, 0.9, 0, 0, 0]).astype(np.float32)
|
||||||
|
x_ = Tensor(np.array([-1, 0.9, 0, 0, 0]
|
||||||
|
).astype(np.float32), dtype=dtype.float32)
|
||||||
log_sf = LogSF()
|
log_sf = LogSF()
|
||||||
output = log_sf(x_)
|
output = log_sf(x_)
|
||||||
tol = 1e-6
|
tol = 1e-6
|
||||||
assert (np.abs(output.asnumpy() - expect_logsurvival) < tol).all()
|
assert (np.abs(output.asnumpy() - expect_logsurvival) < tol).all()
|
||||||
|
|
||||||
|
|
||||||
class EntropyH(nn.Cell):
|
class EntropyH(nn.Cell):
|
||||||
"""
|
"""
|
||||||
Test class: entropy of Bernoulli distributions.
|
Test class: entropy of Bernoulli distributions.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super(EntropyH, self).__init__()
|
super(EntropyH, self).__init__()
|
||||||
self.b = msd.Bernoulli(0.7, dtype=dtype.int32)
|
self.b = msd.Bernoulli(0.7, dtype=dtype.int32)
|
||||||
|
@ -246,6 +283,7 @@ class EntropyH(nn.Cell):
|
||||||
def construct(self):
|
def construct(self):
|
||||||
return self.b.entropy()
|
return self.b.entropy()
|
||||||
|
|
||||||
|
|
||||||
def test_entropy():
|
def test_entropy():
|
||||||
"""
|
"""
|
||||||
Test entropy.
|
Test entropy.
|
||||||
|
@ -257,10 +295,12 @@ def test_entropy():
|
||||||
tol = 1e-6
|
tol = 1e-6
|
||||||
assert (np.abs(output.asnumpy() - expect_entropy) < tol).all()
|
assert (np.abs(output.asnumpy() - expect_entropy) < tol).all()
|
||||||
|
|
||||||
|
|
||||||
class CrossEntropy(nn.Cell):
|
class CrossEntropy(nn.Cell):
|
||||||
"""
|
"""
|
||||||
Test class: cross entropy between bernoulli distributions.
|
Test class: cross entropy between bernoulli distributions.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super(CrossEntropy, self).__init__()
|
super(CrossEntropy, self).__init__()
|
||||||
self.b = msd.Bernoulli(0.7, dtype=dtype.int32)
|
self.b = msd.Bernoulli(0.7, dtype=dtype.int32)
|
||||||
|
@ -272,6 +312,7 @@ class CrossEntropy(nn.Cell):
|
||||||
cross_entropy = self.b.cross_entropy('Bernoulli', x_)
|
cross_entropy = self.b.cross_entropy('Bernoulli', x_)
|
||||||
return h_sum_kl - cross_entropy
|
return h_sum_kl - cross_entropy
|
||||||
|
|
||||||
|
|
||||||
def test_cross_entropy():
|
def test_cross_entropy():
|
||||||
"""
|
"""
|
||||||
Test cross_entropy.
|
Test cross_entropy.
|
||||||
|
|
|
@ -24,10 +24,12 @@ from mindspore import dtype
|
||||||
|
|
||||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
|
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
|
||||||
|
|
||||||
|
|
||||||
class Prob(nn.Cell):
|
class Prob(nn.Cell):
|
||||||
"""
|
"""
|
||||||
Test class: probability of categorical distribution.
|
Test class: probability of categorical distribution.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super(Prob, self).__init__()
|
super(Prob, self).__init__()
|
||||||
self.c = msd.Categorical([0.7, 0.3], dtype=dtype.int32)
|
self.c = msd.Categorical([0.7, 0.3], dtype=dtype.int32)
|
||||||
|
@ -35,13 +37,15 @@ class Prob(nn.Cell):
|
||||||
def construct(self, x_):
|
def construct(self, x_):
|
||||||
return self.c.prob(x_)
|
return self.c.prob(x_)
|
||||||
|
|
||||||
|
|
||||||
def test_pmf():
|
def test_pmf():
|
||||||
"""
|
"""
|
||||||
Test pmf.
|
Test pmf.
|
||||||
"""
|
"""
|
||||||
expect_pmf = [0.7, 0.3, 0.7, 0.3, 0.3]
|
expect_pmf = [0.7, 0.3, 0.7, 0.3, 0.3]
|
||||||
pmf = Prob()
|
pmf = Prob()
|
||||||
x_ = Tensor(np.array([0, 1, 0, 1, 1]).astype(np.int32), dtype=dtype.float32)
|
x_ = Tensor(np.array([0, 1, 0, 1, 1]).astype(
|
||||||
|
np.int32), dtype=dtype.float32)
|
||||||
output = pmf(x_)
|
output = pmf(x_)
|
||||||
tol = 1e-6
|
tol = 1e-6
|
||||||
assert (np.abs(output.asnumpy() - expect_pmf) < tol).all()
|
assert (np.abs(output.asnumpy() - expect_pmf) < tol).all()
|
||||||
|
@ -51,6 +55,7 @@ class LogProb(nn.Cell):
|
||||||
"""
|
"""
|
||||||
Test class: log probability of categorical distribution.
|
Test class: log probability of categorical distribution.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super(LogProb, self).__init__()
|
super(LogProb, self).__init__()
|
||||||
self.c = msd.Categorical([0.7, 0.3], dtype=dtype.int32)
|
self.c = msd.Categorical([0.7, 0.3], dtype=dtype.int32)
|
||||||
|
@ -58,21 +63,25 @@ class LogProb(nn.Cell):
|
||||||
def construct(self, x_):
|
def construct(self, x_):
|
||||||
return self.c.log_prob(x_)
|
return self.c.log_prob(x_)
|
||||||
|
|
||||||
|
|
||||||
def test_log_likelihood():
|
def test_log_likelihood():
|
||||||
"""
|
"""
|
||||||
Test log_pmf.
|
Test log_pmf.
|
||||||
"""
|
"""
|
||||||
expect_logpmf = np.log([0.7, 0.3, 0.7, 0.3, 0.3])
|
expect_logpmf = np.log([0.7, 0.3, 0.7, 0.3, 0.3])
|
||||||
logprob = LogProb()
|
logprob = LogProb()
|
||||||
x_ = Tensor(np.array([0, 1, 0, 1, 1]).astype(np.int32), dtype=dtype.float32)
|
x_ = Tensor(np.array([0, 1, 0, 1, 1]).astype(
|
||||||
|
np.int32), dtype=dtype.float32)
|
||||||
output = logprob(x_)
|
output = logprob(x_)
|
||||||
tol = 1e-6
|
tol = 1e-6
|
||||||
assert (np.abs(output.asnumpy() - expect_logpmf) < tol).all()
|
assert (np.abs(output.asnumpy() - expect_logpmf) < tol).all()
|
||||||
|
|
||||||
|
|
||||||
class KL(nn.Cell):
|
class KL(nn.Cell):
|
||||||
"""
|
"""
|
||||||
Test class: kl_loss between categorical distributions.
|
Test class: kl_loss between categorical distributions.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super(KL, self).__init__()
|
super(KL, self).__init__()
|
||||||
self.c = msd.Categorical([0.7, 0.3], dtype=dtype.int32)
|
self.c = msd.Categorical([0.7, 0.3], dtype=dtype.int32)
|
||||||
|
@ -80,6 +89,7 @@ class KL(nn.Cell):
|
||||||
def construct(self, x_):
|
def construct(self, x_):
|
||||||
return self.c.kl_loss('Categorical', x_)
|
return self.c.kl_loss('Categorical', x_)
|
||||||
|
|
||||||
|
|
||||||
def test_kl_loss():
|
def test_kl_loss():
|
||||||
"""
|
"""
|
||||||
Test kl_loss.
|
Test kl_loss.
|
||||||
|
@ -89,10 +99,12 @@ def test_kl_loss():
|
||||||
tol = 1e-6
|
tol = 1e-6
|
||||||
assert (np.abs(output.asnumpy()) < tol).all()
|
assert (np.abs(output.asnumpy()) < tol).all()
|
||||||
|
|
||||||
|
|
||||||
class Sampling(nn.Cell):
|
class Sampling(nn.Cell):
|
||||||
"""
|
"""
|
||||||
Test class: sampling of categorical distribution.
|
Test class: sampling of categorical distribution.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super(Sampling, self).__init__()
|
super(Sampling, self).__init__()
|
||||||
self.c = msd.Categorical([0.2, 0.1, 0.7], dtype=dtype.int32)
|
self.c = msd.Categorical([0.2, 0.1, 0.7], dtype=dtype.int32)
|
||||||
|
@ -101,6 +113,7 @@ class Sampling(nn.Cell):
|
||||||
def construct(self):
|
def construct(self):
|
||||||
return self.c.sample(self.shape)
|
return self.c.sample(self.shape)
|
||||||
|
|
||||||
|
|
||||||
def test_sample():
|
def test_sample():
|
||||||
"""
|
"""
|
||||||
Test sample.
|
Test sample.
|
||||||
|
@ -109,10 +122,12 @@ def test_sample():
|
||||||
sample = Sampling()
|
sample = Sampling()
|
||||||
sample()
|
sample()
|
||||||
|
|
||||||
|
|
||||||
class Basics(nn.Cell):
|
class Basics(nn.Cell):
|
||||||
"""
|
"""
|
||||||
Test class: mean/var/mode of categorical distribution.
|
Test class: mean/var/mode of categorical distribution.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super(Basics, self).__init__()
|
super(Basics, self).__init__()
|
||||||
self.c = msd.Categorical([0.2, 0.1, 0.7], dtype=dtype.int32)
|
self.c = msd.Categorical([0.2, 0.1, 0.7], dtype=dtype.int32)
|
||||||
|
@ -120,6 +135,7 @@ class Basics(nn.Cell):
|
||||||
def construct(self):
|
def construct(self):
|
||||||
return self.c.mean(), self.c.var(), self.c.mode()
|
return self.c.mean(), self.c.var(), self.c.mode()
|
||||||
|
|
||||||
|
|
||||||
def test_basics():
|
def test_basics():
|
||||||
"""
|
"""
|
||||||
Test mean/variance/mode.
|
Test mean/variance/mode.
|
||||||
|
@ -139,6 +155,7 @@ class CDF(nn.Cell):
|
||||||
"""
|
"""
|
||||||
Test class: cdf of categorical distributions.
|
Test class: cdf of categorical distributions.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super(CDF, self).__init__()
|
super(CDF, self).__init__()
|
||||||
self.c = msd.Categorical([0.7, 0.3], dtype=dtype.int32)
|
self.c = msd.Categorical([0.7, 0.3], dtype=dtype.int32)
|
||||||
|
@ -146,21 +163,25 @@ class CDF(nn.Cell):
|
||||||
def construct(self, x_):
|
def construct(self, x_):
|
||||||
return self.c.cdf(x_)
|
return self.c.cdf(x_)
|
||||||
|
|
||||||
|
|
||||||
def test_cdf():
|
def test_cdf():
|
||||||
"""
|
"""
|
||||||
Test cdf.
|
Test cdf.
|
||||||
"""
|
"""
|
||||||
expect_cdf = [0.7, 0.7, 1, 0.7, 1]
|
expect_cdf = [0.7, 0.7, 1, 0.7, 1]
|
||||||
x_ = Tensor(np.array([0, 0, 1, 0, 1]).astype(np.int32), dtype=dtype.float32)
|
x_ = Tensor(np.array([0, 0, 1, 0, 1]).astype(
|
||||||
|
np.int32), dtype=dtype.float32)
|
||||||
cdf = CDF()
|
cdf = CDF()
|
||||||
output = cdf(x_)
|
output = cdf(x_)
|
||||||
tol = 1e-6
|
tol = 1e-6
|
||||||
assert (np.abs(output.asnumpy() - expect_cdf) < tol).all()
|
assert (np.abs(output.asnumpy() - expect_cdf) < tol).all()
|
||||||
|
|
||||||
|
|
||||||
class LogCDF(nn.Cell):
|
class LogCDF(nn.Cell):
|
||||||
"""
|
"""
|
||||||
Test class: log cdf of categorical distributions.
|
Test class: log cdf of categorical distributions.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super(LogCDF, self).__init__()
|
super(LogCDF, self).__init__()
|
||||||
self.c = msd.Categorical([0.7, 0.3], dtype=dtype.int32)
|
self.c = msd.Categorical([0.7, 0.3], dtype=dtype.int32)
|
||||||
|
@ -168,12 +189,14 @@ class LogCDF(nn.Cell):
|
||||||
def construct(self, x_):
|
def construct(self, x_):
|
||||||
return self.c.log_cdf(x_)
|
return self.c.log_cdf(x_)
|
||||||
|
|
||||||
|
|
||||||
def test_logcdf():
|
def test_logcdf():
|
||||||
"""
|
"""
|
||||||
Test log_cdf.
|
Test log_cdf.
|
||||||
"""
|
"""
|
||||||
expect_logcdf = np.log([0.7, 0.7, 1, 0.7, 1])
|
expect_logcdf = np.log([0.7, 0.7, 1, 0.7, 1])
|
||||||
x_ = Tensor(np.array([0, 0, 1, 0, 1]).astype(np.int32), dtype=dtype.float32)
|
x_ = Tensor(np.array([0, 0, 1, 0, 1]).astype(
|
||||||
|
np.int32), dtype=dtype.float32)
|
||||||
logcdf = LogCDF()
|
logcdf = LogCDF()
|
||||||
output = logcdf(x_)
|
output = logcdf(x_)
|
||||||
tol = 1e-6
|
tol = 1e-6
|
||||||
|
@ -184,6 +207,7 @@ class SF(nn.Cell):
|
||||||
"""
|
"""
|
||||||
Test class: survival function of categorical distributions.
|
Test class: survival function of categorical distributions.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super(SF, self).__init__()
|
super(SF, self).__init__()
|
||||||
self.c = msd.Categorical([0.7, 0.3], dtype=dtype.int32)
|
self.c = msd.Categorical([0.7, 0.3], dtype=dtype.int32)
|
||||||
|
@ -191,12 +215,14 @@ class SF(nn.Cell):
|
||||||
def construct(self, x_):
|
def construct(self, x_):
|
||||||
return self.c.survival_function(x_)
|
return self.c.survival_function(x_)
|
||||||
|
|
||||||
|
|
||||||
def test_survival():
|
def test_survival():
|
||||||
"""
|
"""
|
||||||
Test survival funciton.
|
Test survival function.
|
||||||
"""
|
"""
|
||||||
expect_survival = [0.3, 0., 0., 0.3, 0.3]
|
expect_survival = [0.3, 0., 0., 0.3, 0.3]
|
||||||
x_ = Tensor(np.array([0, 1, 1, 0, 0]).astype(np.int32), dtype=dtype.float32)
|
x_ = Tensor(np.array([0, 1, 1, 0, 0]).astype(
|
||||||
|
np.int32), dtype=dtype.float32)
|
||||||
sf = SF()
|
sf = SF()
|
||||||
output = sf(x_)
|
output = sf(x_)
|
||||||
tol = 1e-6
|
tol = 1e-6
|
||||||
|
@ -207,6 +233,7 @@ class LogSF(nn.Cell):
|
||||||
"""
|
"""
|
||||||
Test class: log survival function of categorical distributions.
|
Test class: log survival function of categorical distributions.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super(LogSF, self).__init__()
|
super(LogSF, self).__init__()
|
||||||
self.c = msd.Categorical([0.7, 0.3], dtype=dtype.int32)
|
self.c = msd.Categorical([0.7, 0.3], dtype=dtype.int32)
|
||||||
|
@ -214,21 +241,25 @@ class LogSF(nn.Cell):
|
||||||
def construct(self, x_):
|
def construct(self, x_):
|
||||||
return self.c.log_survival(x_)
|
return self.c.log_survival(x_)
|
||||||
|
|
||||||
|
|
||||||
def test_log_survival():
|
def test_log_survival():
|
||||||
"""
|
"""
|
||||||
Test log survival funciton.
|
Test log survival function.
|
||||||
"""
|
"""
|
||||||
expect_logsurvival = np.log([1., 0.3, 0.3, 0.3, 0.3])
|
expect_logsurvival = np.log([1., 0.3, 0.3, 0.3, 0.3])
|
||||||
x_ = Tensor(np.array([-2, 0, 0, 0.5, 0.5]).astype(np.float32), dtype=dtype.float32)
|
x_ = Tensor(np.array([-2, 0, 0, 0.5, 0.5]
|
||||||
|
).astype(np.float32), dtype=dtype.float32)
|
||||||
log_sf = LogSF()
|
log_sf = LogSF()
|
||||||
output = log_sf(x_)
|
output = log_sf(x_)
|
||||||
tol = 1e-6
|
tol = 1e-6
|
||||||
assert (np.abs(output.asnumpy() - expect_logsurvival) < tol).all()
|
assert (np.abs(output.asnumpy() - expect_logsurvival) < tol).all()
|
||||||
|
|
||||||
|
|
||||||
class EntropyH(nn.Cell):
|
class EntropyH(nn.Cell):
|
||||||
"""
|
"""
|
||||||
Test class: entropy of categorical distributions.
|
Test class: entropy of categorical distributions.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super(EntropyH, self).__init__()
|
super(EntropyH, self).__init__()
|
||||||
self.c = msd.Categorical([0.7, 0.3], dtype=dtype.int32)
|
self.c = msd.Categorical([0.7, 0.3], dtype=dtype.int32)
|
||||||
|
@ -236,6 +267,7 @@ class EntropyH(nn.Cell):
|
||||||
def construct(self):
|
def construct(self):
|
||||||
return self.c.entropy()
|
return self.c.entropy()
|
||||||
|
|
||||||
|
|
||||||
def test_entropy():
|
def test_entropy():
|
||||||
"""
|
"""
|
||||||
Test entropy.
|
Test entropy.
|
||||||
|
@ -247,10 +279,12 @@ def test_entropy():
|
||||||
tol = 1e-6
|
tol = 1e-6
|
||||||
assert (np.abs(output.asnumpy() - expect_entropy) < tol).all()
|
assert (np.abs(output.asnumpy() - expect_entropy) < tol).all()
|
||||||
|
|
||||||
|
|
||||||
class CrossEntropy(nn.Cell):
|
class CrossEntropy(nn.Cell):
|
||||||
"""
|
"""
|
||||||
Test class: cross entropy between categorical distributions.
|
Test class: cross entropy between categorical distributions.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super(CrossEntropy, self).__init__()
|
super(CrossEntropy, self).__init__()
|
||||||
self.c = msd.Categorical([0.7, 0.3], dtype=dtype.int32)
|
self.c = msd.Categorical([0.7, 0.3], dtype=dtype.int32)
|
||||||
|
@ -262,6 +296,7 @@ class CrossEntropy(nn.Cell):
|
||||||
cross_entropy = self.c.cross_entropy('Categorical', x_)
|
cross_entropy = self.c.cross_entropy('Categorical', x_)
|
||||||
return h_sum_kl - cross_entropy
|
return h_sum_kl - cross_entropy
|
||||||
|
|
||||||
|
|
||||||
def test_cross_entropy():
|
def test_cross_entropy():
|
||||||
"""
|
"""
|
||||||
Test cross_entropy.
|
Test cross_entropy.
|
||||||
|
|
|
@ -23,10 +23,12 @@ from mindspore import dtype
|
||||||
|
|
||||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
|
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
|
||||||
|
|
||||||
|
|
||||||
class Prob(nn.Cell):
|
class Prob(nn.Cell):
|
||||||
"""
|
"""
|
||||||
Test class: probability of Geometric distribution.
|
Test class: probability of Geometric distribution.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super(Prob, self).__init__()
|
super(Prob, self).__init__()
|
||||||
self.g = msd.Geometric(0.7, dtype=dtype.int32)
|
self.g = msd.Geometric(0.7, dtype=dtype.int32)
|
||||||
|
@ -34,6 +36,7 @@ class Prob(nn.Cell):
|
||||||
def construct(self, x_):
|
def construct(self, x_):
|
||||||
return self.g.prob(x_)
|
return self.g.prob(x_)
|
||||||
|
|
||||||
|
|
||||||
def test_pmf():
|
def test_pmf():
|
||||||
"""
|
"""
|
||||||
Test pmf.
|
Test pmf.
|
||||||
|
@ -41,15 +44,18 @@ def test_pmf():
|
||||||
geom_benchmark = stats.geom(0.7)
|
geom_benchmark = stats.geom(0.7)
|
||||||
expect_pmf = geom_benchmark.pmf([0, 1, 2, 3, 4]).astype(np.float32)
|
expect_pmf = geom_benchmark.pmf([0, 1, 2, 3, 4]).astype(np.float32)
|
||||||
pdf = Prob()
|
pdf = Prob()
|
||||||
x_ = Tensor(np.array([-1, 0, 1, 2, 3]).astype(np.float32), dtype=dtype.float32)
|
x_ = Tensor(np.array([-1, 0, 1, 2, 3]
|
||||||
|
).astype(np.float32), dtype=dtype.float32)
|
||||||
output = pdf(x_)
|
output = pdf(x_)
|
||||||
tol = 1e-6
|
tol = 1e-6
|
||||||
assert (np.abs(output.asnumpy() - expect_pmf) < tol).all()
|
assert (np.abs(output.asnumpy() - expect_pmf) < tol).all()
|
||||||
|
|
||||||
|
|
||||||
class LogProb(nn.Cell):
|
class LogProb(nn.Cell):
|
||||||
"""
|
"""
|
||||||
Test class: log probability of Geometric distribution.
|
Test class: log probability of Geometric distribution.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super(LogProb, self).__init__()
|
super(LogProb, self).__init__()
|
||||||
self.g = msd.Geometric(0.7, dtype=dtype.int32)
|
self.g = msd.Geometric(0.7, dtype=dtype.int32)
|
||||||
|
@ -57,6 +63,7 @@ class LogProb(nn.Cell):
|
||||||
def construct(self, x_):
|
def construct(self, x_):
|
||||||
return self.g.log_prob(x_)
|
return self.g.log_prob(x_)
|
||||||
|
|
||||||
|
|
||||||
def test_log_likelihood():
|
def test_log_likelihood():
|
||||||
"""
|
"""
|
||||||
Test log_pmf.
|
Test log_pmf.
|
||||||
|
@ -64,15 +71,18 @@ def test_log_likelihood():
|
||||||
geom_benchmark = stats.geom(0.7)
|
geom_benchmark = stats.geom(0.7)
|
||||||
expect_logpmf = geom_benchmark.logpmf([1, 2, 3, 4, 5]).astype(np.float32)
|
expect_logpmf = geom_benchmark.logpmf([1, 2, 3, 4, 5]).astype(np.float32)
|
||||||
logprob = LogProb()
|
logprob = LogProb()
|
||||||
x_ = Tensor(np.array([0, 1, 2, 3, 4]).astype(np.int32), dtype=dtype.float32)
|
x_ = Tensor(np.array([0, 1, 2, 3, 4]).astype(
|
||||||
|
np.int32), dtype=dtype.float32)
|
||||||
output = logprob(x_)
|
output = logprob(x_)
|
||||||
tol = 1e-6
|
tol = 1e-6
|
||||||
assert (np.abs(output.asnumpy() - expect_logpmf) < tol).all()
|
assert (np.abs(output.asnumpy() - expect_logpmf) < tol).all()
|
||||||
|
|
||||||
|
|
||||||
class KL(nn.Cell):
|
class KL(nn.Cell):
|
||||||
"""
|
"""
|
||||||
Test class: kl_loss between Geometric distributions.
|
Test class: kl_loss between Geometric distributions.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super(KL, self).__init__()
|
super(KL, self).__init__()
|
||||||
self.g = msd.Geometric(0.7, dtype=dtype.int32)
|
self.g = msd.Geometric(0.7, dtype=dtype.int32)
|
||||||
|
@ -80,6 +90,7 @@ class KL(nn.Cell):
|
||||||
def construct(self, x_):
|
def construct(self, x_):
|
||||||
return self.g.kl_loss('Geometric', x_)
|
return self.g.kl_loss('Geometric', x_)
|
||||||
|
|
||||||
|
|
||||||
def test_kl_loss():
|
def test_kl_loss():
|
||||||
"""
|
"""
|
||||||
Test kl_loss.
|
Test kl_loss.
|
||||||
|
@ -88,16 +99,19 @@ def test_kl_loss():
|
||||||
probs1_b = 0.5
|
probs1_b = 0.5
|
||||||
probs0_a = 1 - probs1_a
|
probs0_a = 1 - probs1_a
|
||||||
probs0_b = 1 - probs1_b
|
probs0_b = 1 - probs1_b
|
||||||
expect_kl_loss = np.log(probs1_a / probs1_b) + (probs0_a / probs1_a) * np.log(probs0_a / probs0_b)
|
expect_kl_loss = np.log(probs1_a / probs1_b) + \
|
||||||
|
(probs0_a / probs1_a) * np.log(probs0_a / probs0_b)
|
||||||
kl_loss = KL()
|
kl_loss = KL()
|
||||||
output = kl_loss(Tensor([probs1_b], dtype=dtype.float32))
|
output = kl_loss(Tensor([probs1_b], dtype=dtype.float32))
|
||||||
tol = 1e-6
|
tol = 1e-6
|
||||||
assert (np.abs(output.asnumpy() - expect_kl_loss) < tol).all()
|
assert (np.abs(output.asnumpy() - expect_kl_loss) < tol).all()
|
||||||
|
|
||||||
|
|
||||||
class Basics(nn.Cell):
|
class Basics(nn.Cell):
|
||||||
"""
|
"""
|
||||||
Test class: mean/sd/mode of Geometric distribution.
|
Test class: mean/sd/mode of Geometric distribution.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super(Basics, self).__init__()
|
super(Basics, self).__init__()
|
||||||
self.g = msd.Geometric([0.5, 0.5], dtype=dtype.int32)
|
self.g = msd.Geometric([0.5, 0.5], dtype=dtype.int32)
|
||||||
|
@ -105,6 +119,7 @@ class Basics(nn.Cell):
|
||||||
def construct(self):
|
def construct(self):
|
||||||
return self.g.mean(), self.g.sd(), self.g.mode()
|
return self.g.mean(), self.g.sd(), self.g.mode()
|
||||||
|
|
||||||
|
|
||||||
def test_basics():
|
def test_basics():
|
||||||
"""
|
"""
|
||||||
Test mean/standard deviation/mode.
|
Test mean/standard deviation/mode.
|
||||||
|
@ -115,14 +130,16 @@ def test_basics():
|
||||||
expect_sd = np.sqrt(np.array([0.5, 0.5]) / np.square(np.array([0.5, 0.5])))
|
expect_sd = np.sqrt(np.array([0.5, 0.5]) / np.square(np.array([0.5, 0.5])))
|
||||||
expect_mode = [0.0, 0.0]
|
expect_mode = [0.0, 0.0]
|
||||||
tol = 1e-6
|
tol = 1e-6
|
||||||
assert (np.abs(mean.asnumpy()- expect_mean) < tol).all()
|
assert (np.abs(mean.asnumpy() - expect_mean) < tol).all()
|
||||||
assert (np.abs(sd.asnumpy() - expect_sd) < tol).all()
|
assert (np.abs(sd.asnumpy() - expect_sd) < tol).all()
|
||||||
assert (np.abs(mode.asnumpy() - expect_mode) < tol).all()
|
assert (np.abs(mode.asnumpy() - expect_mode) < tol).all()
|
||||||
|
|
||||||
|
|
||||||
class Sampling(nn.Cell):
|
class Sampling(nn.Cell):
|
||||||
"""
|
"""
|
||||||
Test class: log probability of bernoulli distribution.
|
Test class: log probability of bernoulli distribution.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, shape, seed=0):
|
def __init__(self, shape, seed=0):
|
||||||
super(Sampling, self).__init__()
|
super(Sampling, self).__init__()
|
||||||
self.g = msd.Geometric([0.7, 0.5], seed=seed, dtype=dtype.int32)
|
self.g = msd.Geometric([0.7, 0.5], seed=seed, dtype=dtype.int32)
|
||||||
|
@ -131,6 +148,7 @@ class Sampling(nn.Cell):
|
||||||
def construct(self, probs=None):
|
def construct(self, probs=None):
|
||||||
return self.g.sample(self.shape, probs)
|
return self.g.sample(self.shape, probs)
|
||||||
|
|
||||||
|
|
||||||
def test_sample():
|
def test_sample():
|
||||||
"""
|
"""
|
||||||
Test sample.
|
Test sample.
|
||||||
|
@ -140,10 +158,12 @@ def test_sample():
|
||||||
output = sample()
|
output = sample()
|
||||||
assert output.shape == (2, 3, 2)
|
assert output.shape == (2, 3, 2)
|
||||||
|
|
||||||
|
|
||||||
class CDF(nn.Cell):
|
class CDF(nn.Cell):
|
||||||
"""
|
"""
|
||||||
Test class: cdf of Geometric distribution.
|
Test class: cdf of Geometric distribution.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super(CDF, self).__init__()
|
super(CDF, self).__init__()
|
||||||
self.g = msd.Geometric(0.7, dtype=dtype.int32)
|
self.g = msd.Geometric(0.7, dtype=dtype.int32)
|
||||||
|
@ -151,22 +171,26 @@ class CDF(nn.Cell):
|
||||||
def construct(self, x_):
|
def construct(self, x_):
|
||||||
return self.g.cdf(x_)
|
return self.g.cdf(x_)
|
||||||
|
|
||||||
|
|
||||||
def test_cdf():
|
def test_cdf():
|
||||||
"""
|
"""
|
||||||
Test cdf.
|
Test cdf.
|
||||||
"""
|
"""
|
||||||
geom_benchmark = stats.geom(0.7)
|
geom_benchmark = stats.geom(0.7)
|
||||||
expect_cdf = geom_benchmark.cdf([0, 1, 2, 3, 4]).astype(np.float32)
|
expect_cdf = geom_benchmark.cdf([0, 1, 2, 3, 4]).astype(np.float32)
|
||||||
x_ = Tensor(np.array([-1, 0, 1, 2, 3]).astype(np.int32), dtype=dtype.float32)
|
x_ = Tensor(np.array([-1, 0, 1, 2, 3]
|
||||||
|
).astype(np.int32), dtype=dtype.float32)
|
||||||
cdf = CDF()
|
cdf = CDF()
|
||||||
output = cdf(x_)
|
output = cdf(x_)
|
||||||
tol = 1e-6
|
tol = 1e-6
|
||||||
assert (np.abs(output.asnumpy() - expect_cdf) < tol).all()
|
assert (np.abs(output.asnumpy() - expect_cdf) < tol).all()
|
||||||
|
|
||||||
|
|
||||||
class LogCDF(nn.Cell):
|
class LogCDF(nn.Cell):
|
||||||
"""
|
"""
|
||||||
Test class: log cdf of Geometric distribution.
|
Test class: log cdf of Geometric distribution.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super(LogCDF, self).__init__()
|
super(LogCDF, self).__init__()
|
||||||
self.g = msd.Geometric(0.7, dtype=dtype.int32)
|
self.g = msd.Geometric(0.7, dtype=dtype.int32)
|
||||||
|
@ -174,22 +198,26 @@ class LogCDF(nn.Cell):
|
||||||
def construct(self, x_):
|
def construct(self, x_):
|
||||||
return self.g.log_cdf(x_)
|
return self.g.log_cdf(x_)
|
||||||
|
|
||||||
|
|
||||||
def test_logcdf():
|
def test_logcdf():
|
||||||
"""
|
"""
|
||||||
Test log_cdf.
|
Test log_cdf.
|
||||||
"""
|
"""
|
||||||
geom_benchmark = stats.geom(0.7)
|
geom_benchmark = stats.geom(0.7)
|
||||||
expect_logcdf = geom_benchmark.logcdf([1, 2, 3, 4, 5]).astype(np.float32)
|
expect_logcdf = geom_benchmark.logcdf([1, 2, 3, 4, 5]).astype(np.float32)
|
||||||
x_ = Tensor(np.array([0, 1, 2, 3, 4]).astype(np.int32), dtype=dtype.float32)
|
x_ = Tensor(np.array([0, 1, 2, 3, 4]).astype(
|
||||||
|
np.int32), dtype=dtype.float32)
|
||||||
logcdf = LogCDF()
|
logcdf = LogCDF()
|
||||||
output = logcdf(x_)
|
output = logcdf(x_)
|
||||||
tol = 1e-6
|
tol = 1e-6
|
||||||
assert (np.abs(output.asnumpy() - expect_logcdf) < tol).all()
|
assert (np.abs(output.asnumpy() - expect_logcdf) < tol).all()
|
||||||
|
|
||||||
|
|
||||||
class SF(nn.Cell):
|
class SF(nn.Cell):
|
||||||
"""
|
"""
|
||||||
Test class: survial funciton of Geometric distribution.
|
Test class: survial function of Geometric distribution.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super(SF, self).__init__()
|
super(SF, self).__init__()
|
||||||
self.g = msd.Geometric(0.7, dtype=dtype.int32)
|
self.g = msd.Geometric(0.7, dtype=dtype.int32)
|
||||||
|
@ -197,22 +225,26 @@ class SF(nn.Cell):
|
||||||
def construct(self, x_):
|
def construct(self, x_):
|
||||||
return self.g.survival_function(x_)
|
return self.g.survival_function(x_)
|
||||||
|
|
||||||
|
|
||||||
def test_survival():
|
def test_survival():
|
||||||
"""
|
"""
|
||||||
Test survival function.
|
Test survival function.
|
||||||
"""
|
"""
|
||||||
geom_benchmark = stats.geom(0.7)
|
geom_benchmark = stats.geom(0.7)
|
||||||
expect_survival = geom_benchmark.sf([0, 1, 2, 3, 4]).astype(np.float32)
|
expect_survival = geom_benchmark.sf([0, 1, 2, 3, 4]).astype(np.float32)
|
||||||
x_ = Tensor(np.array([-1, 0, 1, 2, 3]).astype(np.int32), dtype=dtype.float32)
|
x_ = Tensor(np.array([-1, 0, 1, 2, 3]
|
||||||
|
).astype(np.int32), dtype=dtype.float32)
|
||||||
sf = SF()
|
sf = SF()
|
||||||
output = sf(x_)
|
output = sf(x_)
|
||||||
tol = 1e-6
|
tol = 1e-6
|
||||||
assert (np.abs(output.asnumpy() - expect_survival) < tol).all()
|
assert (np.abs(output.asnumpy() - expect_survival) < tol).all()
|
||||||
|
|
||||||
|
|
||||||
class LogSF(nn.Cell):
|
class LogSF(nn.Cell):
|
||||||
"""
|
"""
|
||||||
Test class: log survial funciton of Geometric distribution.
|
Test class: log survial function of Geometric distribution.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super(LogSF, self).__init__()
|
super(LogSF, self).__init__()
|
||||||
self.g = msd.Geometric(0.7, dtype=dtype.int32)
|
self.g = msd.Geometric(0.7, dtype=dtype.int32)
|
||||||
|
@ -220,22 +252,27 @@ class LogSF(nn.Cell):
|
||||||
def construct(self, x_):
|
def construct(self, x_):
|
||||||
return self.g.log_survival(x_)
|
return self.g.log_survival(x_)
|
||||||
|
|
||||||
|
|
||||||
def test_log_survival():
|
def test_log_survival():
|
||||||
"""
|
"""
|
||||||
Test log_survival function.
|
Test log_survival function.
|
||||||
"""
|
"""
|
||||||
geom_benchmark = stats.geom(0.7)
|
geom_benchmark = stats.geom(0.7)
|
||||||
expect_logsurvival = geom_benchmark.logsf([0, 1, 2, 3, 4]).astype(np.float32)
|
expect_logsurvival = geom_benchmark.logsf(
|
||||||
x_ = Tensor(np.array([-1, 0, 1, 2, 3]).astype(np.float32), dtype=dtype.float32)
|
[0, 1, 2, 3, 4]).astype(np.float32)
|
||||||
|
x_ = Tensor(np.array([-1, 0, 1, 2, 3]
|
||||||
|
).astype(np.float32), dtype=dtype.float32)
|
||||||
log_sf = LogSF()
|
log_sf = LogSF()
|
||||||
output = log_sf(x_)
|
output = log_sf(x_)
|
||||||
tol = 5e-6
|
tol = 5e-6
|
||||||
assert (np.abs(output.asnumpy() - expect_logsurvival) < tol).all()
|
assert (np.abs(output.asnumpy() - expect_logsurvival) < tol).all()
|
||||||
|
|
||||||
|
|
||||||
class EntropyH(nn.Cell):
|
class EntropyH(nn.Cell):
|
||||||
"""
|
"""
|
||||||
Test class: entropy of Geometric distribution.
|
Test class: entropy of Geometric distribution.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super(EntropyH, self).__init__()
|
super(EntropyH, self).__init__()
|
||||||
self.g = msd.Geometric(0.7, dtype=dtype.int32)
|
self.g = msd.Geometric(0.7, dtype=dtype.int32)
|
||||||
|
@ -243,6 +280,7 @@ class EntropyH(nn.Cell):
|
||||||
def construct(self):
|
def construct(self):
|
||||||
return self.g.entropy()
|
return self.g.entropy()
|
||||||
|
|
||||||
|
|
||||||
def test_entropy():
|
def test_entropy():
|
||||||
"""
|
"""
|
||||||
Test entropy.
|
Test entropy.
|
||||||
|
@ -254,10 +292,12 @@ def test_entropy():
|
||||||
tol = 1e-6
|
tol = 1e-6
|
||||||
assert (np.abs(output.asnumpy() - expect_entropy) < tol).all()
|
assert (np.abs(output.asnumpy() - expect_entropy) < tol).all()
|
||||||
|
|
||||||
|
|
||||||
class CrossEntropy(nn.Cell):
|
class CrossEntropy(nn.Cell):
|
||||||
"""
|
"""
|
||||||
Test class: cross entropy between Geometric distributions.
|
Test class: cross entropy between Geometric distributions.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super(CrossEntropy, self).__init__()
|
super(CrossEntropy, self).__init__()
|
||||||
self.g = msd.Geometric(0.7, dtype=dtype.int32)
|
self.g = msd.Geometric(0.7, dtype=dtype.int32)
|
||||||
|
@ -269,6 +309,7 @@ class CrossEntropy(nn.Cell):
|
||||||
ans = self.g.cross_entropy('Geometric', x_)
|
ans = self.g.cross_entropy('Geometric', x_)
|
||||||
return h_sum_kl - ans
|
return h_sum_kl - ans
|
||||||
|
|
||||||
|
|
||||||
def test_cross_entropy():
|
def test_cross_entropy():
|
||||||
"""
|
"""
|
||||||
Test cross_entropy.
|
Test cross_entropy.
|
||||||
|
|
|
@ -43,20 +43,29 @@ def create_dataset(data_path, batch_size=32, repeat_size=1,
|
||||||
shift = 0.0
|
shift = 0.0
|
||||||
|
|
||||||
# define map operations
|
# define map operations
|
||||||
resize_op = CV.Resize((resize_height, resize_width), interpolation=Inter.LINEAR) # resize images to (32, 32)
|
# resize images to (32, 32)
|
||||||
|
resize_op = CV.Resize((resize_height, resize_width),
|
||||||
|
interpolation=Inter.LINEAR)
|
||||||
rescale_op = CV.Rescale(rescale, shift) # rescale images
|
rescale_op = CV.Rescale(rescale, shift) # rescale images
|
||||||
hwc2chw_op = CV.HWC2CHW() # change shape from (height, width, channel) to (channel, height, width) to fit network.
|
# change shape from (height, width, channel) to (channel, height, width) to fit network.
|
||||||
type_cast_op = C.TypeCast(mstype.int32) # change data type of label to int32 to fit network
|
hwc2chw_op = CV.HWC2CHW()
|
||||||
|
# change data type of label to int32 to fit network
|
||||||
|
type_cast_op = C.TypeCast(mstype.int32)
|
||||||
|
|
||||||
# apply map operations on images
|
# apply map operations on images
|
||||||
mnist_ds = mnist_ds.map(input_columns="label", operations=type_cast_op, num_parallel_workers=num_parallel_workers)
|
mnist_ds = mnist_ds.map(input_columns="label", operations=type_cast_op,
|
||||||
mnist_ds = mnist_ds.map(input_columns="image", operations=resize_op, num_parallel_workers=num_parallel_workers)
|
num_parallel_workers=num_parallel_workers)
|
||||||
mnist_ds = mnist_ds.map(input_columns="image", operations=rescale_op, num_parallel_workers=num_parallel_workers)
|
mnist_ds = mnist_ds.map(input_columns="image", operations=resize_op,
|
||||||
mnist_ds = mnist_ds.map(input_columns="image", operations=hwc2chw_op, num_parallel_workers=num_parallel_workers)
|
num_parallel_workers=num_parallel_workers)
|
||||||
|
mnist_ds = mnist_ds.map(input_columns="image", operations=rescale_op,
|
||||||
|
num_parallel_workers=num_parallel_workers)
|
||||||
|
mnist_ds = mnist_ds.map(input_columns="image", operations=hwc2chw_op,
|
||||||
|
num_parallel_workers=num_parallel_workers)
|
||||||
|
|
||||||
# apply DatasetOps
|
# apply DatasetOps
|
||||||
buffer_size = 10000
|
buffer_size = 10000
|
||||||
mnist_ds = mnist_ds.shuffle(buffer_size=buffer_size) # 10000 as in LeNet train script
|
# 10000 as in LeNet train script
|
||||||
|
mnist_ds = mnist_ds.shuffle(buffer_size=buffer_size)
|
||||||
mnist_ds = mnist_ds.batch(batch_size, drop_remainder=True)
|
mnist_ds = mnist_ds.batch(batch_size, drop_remainder=True)
|
||||||
mnist_ds = mnist_ds.repeat(repeat_size)
|
mnist_ds = mnist_ds.repeat(repeat_size)
|
||||||
|
|
||||||
|
@ -68,7 +77,7 @@ def save_img(data, name, size=32, num=32):
|
||||||
Visualize data and save to target files
|
Visualize data and save to target files
|
||||||
Args:
|
Args:
|
||||||
data: nparray of size (num, size, size)
|
data: nparray of size (num, size, size)
|
||||||
name: ouput file name
|
name: output file name
|
||||||
size: image size
|
size: image size
|
||||||
num: number of images
|
num: number of images
|
||||||
"""
|
"""
|
||||||
|
|
|
@ -62,15 +62,17 @@ def test_prob():
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(ValueError):
|
||||||
msd.Categorical([1.0], dtype=dtype.int32)
|
msd.Categorical([1.0], dtype=dtype.int32)
|
||||||
|
|
||||||
|
|
||||||
def test_categorical_sum():
|
def test_categorical_sum():
|
||||||
"""
|
"""
|
||||||
Invaild probabilities.
|
Invalid probabilities.
|
||||||
"""
|
"""
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(ValueError):
|
||||||
msd.Categorical([[0.1, 0.2], [0.4, 0.6]], dtype=dtype.int32)
|
msd.Categorical([[0.1, 0.2], [0.4, 0.6]], dtype=dtype.int32)
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(ValueError):
|
||||||
msd.Categorical([[0.5, 0.7], [0.6, 0.6]], dtype=dtype.int32)
|
msd.Categorical([[0.5, 0.7], [0.6, 0.6]], dtype=dtype.int32)
|
||||||
|
|
||||||
|
|
||||||
def rank():
|
def rank():
|
||||||
"""
|
"""
|
||||||
Rank dimenshion less than 1.
|
Rank dimenshion less than 1.
|
||||||
|
@ -80,7 +82,9 @@ def rank():
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(ValueError):
|
||||||
msd.Categorical(np.array(0.3).astype(np.float32), dtype=dtype.int32)
|
msd.Categorical(np.array(0.3).astype(np.float32), dtype=dtype.int32)
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(ValueError):
|
||||||
msd.Categorical(Tensor(np.array(0.3).astype(np.float32)), dtype=dtype.int32)
|
msd.Categorical(
|
||||||
|
Tensor(np.array(0.3).astype(np.float32)), dtype=dtype.int32)
|
||||||
|
|
||||||
|
|
||||||
class CategoricalProb(nn.Cell):
|
class CategoricalProb(nn.Cell):
|
||||||
"""
|
"""
|
||||||
|
@ -211,6 +215,7 @@ class CategoricalConstruct(nn.Cell):
|
||||||
prob2 = self.c1('prob', value, probs)
|
prob2 = self.c1('prob', value, probs)
|
||||||
return prob + prob1 + prob2
|
return prob + prob1 + prob2
|
||||||
|
|
||||||
|
|
||||||
def test_categorical_construct():
|
def test_categorical_construct():
|
||||||
"""
|
"""
|
||||||
Test probability function going through construct.
|
Test probability function going through construct.
|
||||||
|
@ -235,7 +240,7 @@ class CategoricalBasics(nn.Cell):
|
||||||
def construct(self, probs):
|
def construct(self, probs):
|
||||||
basics1 = self.c.mean() + self.c.var() + self.c.mode() + self.c.entropy()
|
basics1 = self.c.mean() + self.c.var() + self.c.mode() + self.c.entropy()
|
||||||
basics2 = self.c1.mean(probs) + self.c1.var(probs) +\
|
basics2 = self.c1.mean(probs) + self.c1.var(probs) +\
|
||||||
self.c1.mode(probs) + self.c1.entropy(probs)
|
self.c1.mode(probs) + self.c1.entropy(probs)
|
||||||
return basics1 + basics2
|
return basics1 + basics2
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -29,19 +29,23 @@ func_name_list = ['prob', 'log_prob', 'cdf', 'log_cdf',
|
||||||
'entropy', 'kl_loss', 'cross_entropy',
|
'entropy', 'kl_loss', 'cross_entropy',
|
||||||
'sample']
|
'sample']
|
||||||
|
|
||||||
|
|
||||||
class MyExponential(msd.Distribution):
|
class MyExponential(msd.Distribution):
|
||||||
"""
|
"""
|
||||||
Test distirbution class: no function is implemented.
|
Test distribution class: no function is implemented.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, rate=None, seed=None, dtype=mstype.float32, name="MyExponential"):
|
def __init__(self, rate=None, seed=None, dtype=mstype.float32, name="MyExponential"):
|
||||||
param = dict(locals())
|
param = dict(locals())
|
||||||
param['param_dict'] = {'rate': rate}
|
param['param_dict'] = {'rate': rate}
|
||||||
super(MyExponential, self).__init__(seed, dtype, name, param)
|
super(MyExponential, self).__init__(seed, dtype, name, param)
|
||||||
|
|
||||||
|
|
||||||
class Net(nn.Cell):
|
class Net(nn.Cell):
|
||||||
"""
|
"""
|
||||||
Test Net: function called through construct.
|
Test Net: function called through construct.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, func_name):
|
def __init__(self, func_name):
|
||||||
super(Net, self).__init__()
|
super(Net, self).__init__()
|
||||||
self.dist = MyExponential()
|
self.dist = MyExponential()
|
||||||
|
@ -61,6 +65,7 @@ def test_raise_not_implemented_error_construct():
|
||||||
net = Net(func_name)
|
net = Net(func_name)
|
||||||
net(value)
|
net(value)
|
||||||
|
|
||||||
|
|
||||||
def test_raise_not_implemented_error_construct_graph_mode():
|
def test_raise_not_implemented_error_construct_graph_mode():
|
||||||
"""
|
"""
|
||||||
test raise not implemented error in graph mode.
|
test raise not implemented error in graph mode.
|
||||||
|
@ -72,10 +77,12 @@ def test_raise_not_implemented_error_construct_graph_mode():
|
||||||
net = Net(func_name)
|
net = Net(func_name)
|
||||||
net(value)
|
net(value)
|
||||||
|
|
||||||
|
|
||||||
class Net1(nn.Cell):
|
class Net1(nn.Cell):
|
||||||
"""
|
"""
|
||||||
Test Net: function called directly.
|
Test Net: function called directly.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, func_name):
|
def __init__(self, func_name):
|
||||||
super(Net1, self).__init__()
|
super(Net1, self).__init__()
|
||||||
self.dist = MyExponential()
|
self.dist = MyExponential()
|
||||||
|
@ -84,6 +91,7 @@ class Net1(nn.Cell):
|
||||||
def construct(self, *args, **kwargs):
|
def construct(self, *args, **kwargs):
|
||||||
return self.func(*args, **kwargs)
|
return self.func(*args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
def test_raise_not_implemented_error():
|
def test_raise_not_implemented_error():
|
||||||
"""
|
"""
|
||||||
test raise not implemented error in pynative mode.
|
test raise not implemented error in pynative mode.
|
||||||
|
@ -94,6 +102,7 @@ def test_raise_not_implemented_error():
|
||||||
net = Net1(func_name)
|
net = Net1(func_name)
|
||||||
net(value)
|
net(value)
|
||||||
|
|
||||||
|
|
||||||
def test_raise_not_implemented_error_graph_mode():
|
def test_raise_not_implemented_error_graph_mode():
|
||||||
"""
|
"""
|
||||||
test raise not implemented error in graph mode.
|
test raise not implemented error in graph mode.
|
||||||
|
|
|
@ -23,6 +23,7 @@ import mindspore.nn.probability.distribution as msd
|
||||||
from mindspore import dtype
|
from mindspore import dtype
|
||||||
from mindspore import Tensor
|
from mindspore import Tensor
|
||||||
|
|
||||||
|
|
||||||
def test_uniform_shape_errpr():
|
def test_uniform_shape_errpr():
|
||||||
"""
|
"""
|
||||||
Invalid shapes.
|
Invalid shapes.
|
||||||
|
@ -30,18 +31,22 @@ def test_uniform_shape_errpr():
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(ValueError):
|
||||||
msd.Uniform([[2.], [1.]], [[2.], [3.], [4.]], dtype=dtype.float32)
|
msd.Uniform([[2.], [1.]], [[2.], [3.], [4.]], dtype=dtype.float32)
|
||||||
|
|
||||||
|
|
||||||
def test_type():
|
def test_type():
|
||||||
with pytest.raises(TypeError):
|
with pytest.raises(TypeError):
|
||||||
msd.Uniform(0., 1., dtype=dtype.int32)
|
msd.Uniform(0., 1., dtype=dtype.int32)
|
||||||
|
|
||||||
|
|
||||||
def test_name():
|
def test_name():
|
||||||
with pytest.raises(TypeError):
|
with pytest.raises(TypeError):
|
||||||
msd.Uniform(0., 1., name=1.0)
|
msd.Uniform(0., 1., name=1.0)
|
||||||
|
|
||||||
|
|
||||||
def test_seed():
|
def test_seed():
|
||||||
with pytest.raises(TypeError):
|
with pytest.raises(TypeError):
|
||||||
msd.Uniform(0., 1., seed='seed')
|
msd.Uniform(0., 1., seed='seed')
|
||||||
|
|
||||||
|
|
||||||
def test_arguments():
|
def test_arguments():
|
||||||
"""
|
"""
|
||||||
Args passing during initialization.
|
Args passing during initialization.
|
||||||
|
@ -66,6 +71,7 @@ class UniformProb(nn.Cell):
|
||||||
"""
|
"""
|
||||||
Uniform distribution: initialize with low/high.
|
Uniform distribution: initialize with low/high.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super(UniformProb, self).__init__()
|
super(UniformProb, self).__init__()
|
||||||
self.u = msd.Uniform(3.0, 4.0, dtype=dtype.float32)
|
self.u = msd.Uniform(3.0, 4.0, dtype=dtype.float32)
|
||||||
|
@ -79,6 +85,7 @@ class UniformProb(nn.Cell):
|
||||||
log_sf = self.u.log_survival(value)
|
log_sf = self.u.log_survival(value)
|
||||||
return prob + log_prob + cdf + log_cdf + sf + log_sf
|
return prob + log_prob + cdf + log_cdf + sf + log_sf
|
||||||
|
|
||||||
|
|
||||||
def test_uniform_prob():
|
def test_uniform_prob():
|
||||||
"""
|
"""
|
||||||
Test probability functions: passing value through construct.
|
Test probability functions: passing value through construct.
|
||||||
|
@ -88,10 +95,12 @@ def test_uniform_prob():
|
||||||
ans = net(value)
|
ans = net(value)
|
||||||
assert isinstance(ans, Tensor)
|
assert isinstance(ans, Tensor)
|
||||||
|
|
||||||
|
|
||||||
class UniformProb1(nn.Cell):
|
class UniformProb1(nn.Cell):
|
||||||
"""
|
"""
|
||||||
Uniform distribution: initialize without low/high.
|
Uniform distribution: initialize without low/high.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super(UniformProb1, self).__init__()
|
super(UniformProb1, self).__init__()
|
||||||
self.u = msd.Uniform(dtype=dtype.float32)
|
self.u = msd.Uniform(dtype=dtype.float32)
|
||||||
|
@ -105,6 +114,7 @@ class UniformProb1(nn.Cell):
|
||||||
log_sf = self.u.log_survival(value, low, high)
|
log_sf = self.u.log_survival(value, low, high)
|
||||||
return prob + log_prob + cdf + log_cdf + sf + log_sf
|
return prob + log_prob + cdf + log_cdf + sf + log_sf
|
||||||
|
|
||||||
|
|
||||||
def test_uniform_prob1():
|
def test_uniform_prob1():
|
||||||
"""
|
"""
|
||||||
Test probability functions: passing low/high, value through construct.
|
Test probability functions: passing low/high, value through construct.
|
||||||
|
@ -116,13 +126,16 @@ def test_uniform_prob1():
|
||||||
ans = net(value, low, high)
|
ans = net(value, low, high)
|
||||||
assert isinstance(ans, Tensor)
|
assert isinstance(ans, Tensor)
|
||||||
|
|
||||||
|
|
||||||
class UniformKl(nn.Cell):
|
class UniformKl(nn.Cell):
|
||||||
"""
|
"""
|
||||||
Test class: kl_loss of Uniform distribution.
|
Test class: kl_loss of Uniform distribution.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super(UniformKl, self).__init__()
|
super(UniformKl, self).__init__()
|
||||||
self.u1 = msd.Uniform(np.array([3.0]), np.array([4.0]), dtype=dtype.float32)
|
self.u1 = msd.Uniform(
|
||||||
|
np.array([3.0]), np.array([4.0]), dtype=dtype.float32)
|
||||||
self.u2 = msd.Uniform(dtype=dtype.float32)
|
self.u2 = msd.Uniform(dtype=dtype.float32)
|
||||||
|
|
||||||
def construct(self, low_b, high_b, low_a, high_a):
|
def construct(self, low_b, high_b, low_a, high_a):
|
||||||
|
@ -130,6 +143,7 @@ class UniformKl(nn.Cell):
|
||||||
kl2 = self.u2.kl_loss('Uniform', low_b, high_b, low_a, high_a)
|
kl2 = self.u2.kl_loss('Uniform', low_b, high_b, low_a, high_a)
|
||||||
return kl1 + kl2
|
return kl1 + kl2
|
||||||
|
|
||||||
|
|
||||||
def test_kl():
|
def test_kl():
|
||||||
"""
|
"""
|
||||||
Test kl_loss.
|
Test kl_loss.
|
||||||
|
@ -142,13 +156,16 @@ def test_kl():
|
||||||
ans = net(low_b, high_b, low_a, high_a)
|
ans = net(low_b, high_b, low_a, high_a)
|
||||||
assert isinstance(ans, Tensor)
|
assert isinstance(ans, Tensor)
|
||||||
|
|
||||||
|
|
||||||
class UniformCrossEntropy(nn.Cell):
|
class UniformCrossEntropy(nn.Cell):
|
||||||
"""
|
"""
|
||||||
Test class: cross_entropy of Uniform distribution.
|
Test class: cross_entropy of Uniform distribution.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super(UniformCrossEntropy, self).__init__()
|
super(UniformCrossEntropy, self).__init__()
|
||||||
self.u1 = msd.Uniform(np.array([3.0]), np.array([4.0]), dtype=dtype.float32)
|
self.u1 = msd.Uniform(
|
||||||
|
np.array([3.0]), np.array([4.0]), dtype=dtype.float32)
|
||||||
self.u2 = msd.Uniform(dtype=dtype.float32)
|
self.u2 = msd.Uniform(dtype=dtype.float32)
|
||||||
|
|
||||||
def construct(self, low_b, high_b, low_a, high_a):
|
def construct(self, low_b, high_b, low_a, high_a):
|
||||||
|
@ -156,9 +173,10 @@ class UniformCrossEntropy(nn.Cell):
|
||||||
h2 = self.u2.cross_entropy('Uniform', low_b, high_b, low_a, high_a)
|
h2 = self.u2.cross_entropy('Uniform', low_b, high_b, low_a, high_a)
|
||||||
return h1 + h2
|
return h1 + h2
|
||||||
|
|
||||||
|
|
||||||
def test_cross_entropy():
|
def test_cross_entropy():
|
||||||
"""
|
"""
|
||||||
Test cross_entropy between Unifrom distributions.
|
Test cross_entropy between Uniform distributions.
|
||||||
"""
|
"""
|
||||||
net = UniformCrossEntropy()
|
net = UniformCrossEntropy()
|
||||||
low_b = Tensor(np.array([0.0]).astype(np.float32), dtype=dtype.float32)
|
low_b = Tensor(np.array([0.0]).astype(np.float32), dtype=dtype.float32)
|
||||||
|
@ -168,10 +186,12 @@ def test_cross_entropy():
|
||||||
ans = net(low_b, high_b, low_a, high_a)
|
ans = net(low_b, high_b, low_a, high_a)
|
||||||
assert isinstance(ans, Tensor)
|
assert isinstance(ans, Tensor)
|
||||||
|
|
||||||
|
|
||||||
class UniformBasics(nn.Cell):
|
class UniformBasics(nn.Cell):
|
||||||
"""
|
"""
|
||||||
Test class: basic mean/sd/var/mode/entropy function.
|
Test class: basic mean/sd/var/mode/entropy function.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super(UniformBasics, self).__init__()
|
super(UniformBasics, self).__init__()
|
||||||
self.u = msd.Uniform(3.0, 4.0, dtype=dtype.float32)
|
self.u = msd.Uniform(3.0, 4.0, dtype=dtype.float32)
|
||||||
|
@ -183,6 +203,7 @@ class UniformBasics(nn.Cell):
|
||||||
entropy = self.u.entropy()
|
entropy = self.u.entropy()
|
||||||
return mean + sd + var + entropy
|
return mean + sd + var + entropy
|
||||||
|
|
||||||
|
|
||||||
def test_bascis():
|
def test_bascis():
|
||||||
"""
|
"""
|
||||||
Test mean/sd/var/mode/entropy functionality of Uniform.
|
Test mean/sd/var/mode/entropy functionality of Uniform.
|
||||||
|
@ -194,8 +215,9 @@ def test_bascis():
|
||||||
|
|
||||||
class UniConstruct(nn.Cell):
|
class UniConstruct(nn.Cell):
|
||||||
"""
|
"""
|
||||||
Unifrom distribution: going through construct.
|
Uniform distribution: going through construct.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super(UniConstruct, self).__init__()
|
super(UniConstruct, self).__init__()
|
||||||
self.u = msd.Uniform(-4.0, 4.0)
|
self.u = msd.Uniform(-4.0, 4.0)
|
||||||
|
@ -207,6 +229,7 @@ class UniConstruct(nn.Cell):
|
||||||
prob2 = self.u1('prob', value, low, high)
|
prob2 = self.u1('prob', value, low, high)
|
||||||
return prob + prob1 + prob2
|
return prob + prob1 + prob2
|
||||||
|
|
||||||
|
|
||||||
def test_uniform_construct():
|
def test_uniform_construct():
|
||||||
"""
|
"""
|
||||||
Test probability function going through construct.
|
Test probability function going through construct.
|
||||||
|
|
Loading…
Reference in New Issue