update doc example in probability

fix typo in probability module

fix pylint in msp
This commit is contained in:
Zichun Ye 2021-01-13 00:12:52 -05:00
parent f87b5e0cc8
commit ca316f4422
18 changed files with 305 additions and 258 deletions

View File

@@ -41,7 +41,7 @@ class Bijector(Cell):
     Note:
         `dtype` of bijector represents the type of the distributions that the bijector could operate on.
         When `dtype` is None, there is no enforcement on the type of input value except that the input value
-        has to be float type. During initilization, when `dtype` is None, there is no enforcement on the dtype
+        has to be float type. During initialization, when `dtype` is None, there is no enforcement on the dtype
         of the parameters. All parameters should have the same float type, otherwise a TypeError will be raised.
         Specifically, the parameter type will follow the dtype of the input value, i.e. parameters of the bijector
         will be casted into the same type as input value when `dtype`is None.
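
A minimal sketch of the dtype-following rule described in this note, assuming the stock msb.Exp bijector (illustrative only, not part of this change):

import numpy as np
import mindspore.nn.probability.bijector as msb
from mindspore import Tensor

exp = msb.Exp()                          # dtype=None: nothing enforced at init
x = Tensor(np.array([1.0, 2.0], dtype=np.float32))
y = exp.forward(x)                       # parameters are cast to float32 to match x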
@@ -65,7 +65,8 @@ class Bijector(Cell):
             'is_constant_jacobian', is_constant_jacobian, [bool], name)
         validator.check_value_type('is_injective', is_injective, [bool], name)
         if dtype is not None:
-            validator.check_type_name("dtype", dtype, mstype.float_type, type(self).__name__)
+            validator.check_type_name(
+                "dtype", dtype, mstype.float_type, type(self).__name__)
         self._name = name
         self._dtype = dtype
         self._parameters = {}
@@ -76,7 +77,7 @@ class Bijector(Cell):
             if not(k == 'self' or k.startswith('_')):
                 self._parameters[k] = param[k]
-        # if no bijector is used as an argument during initilization
+        # if no bijector is used as an argument during initialization
         if 'bijector' not in param.keys():
             self._batch_shape = self._calc_batch_shape()
             self._is_scalar_batch = self._check_is_scalar_batch()
@@ -141,7 +142,8 @@ class Bijector(Cell):
     def _shape_mapping(self, shape):
         shape_tensor = self.fill_base(self.parameter_type, shape, 0.0)
-        dist_shape_tensor = self.fill_base(self.parameter_type, self.batch_shape, 0.0)
+        dist_shape_tensor = self.fill_base(
+            self.parameter_type, self.batch_shape, 0.0)
         return (shape_tensor + dist_shape_tensor).shape

     def shape_mapping(self, shape):
@@ -166,12 +168,15 @@ class Bijector(Cell):
            if self.common_dtype is None:
                self.common_dtype = value_t.dtype
            elif value_t.dtype != self.common_dtype:
-                raise TypeError(f"{name} should have the same dtype as other arguments.")
+                raise TypeError(
+                    f"{name} should have the same dtype as other arguments.")
            # check if the parameters are casted into float-type tensors
-            validator.check_type_name(f"dtype of {name}", value_t.dtype, mstype.float_type, type(self).__name__)
+            validator.check_type_name(
+                f"dtype of {name}", value_t.dtype, mstype.float_type, type(self).__name__)
        # check if the dtype of the input_parameter agrees with the bijector's dtype
        elif value_t.dtype != self.dtype:
-            raise TypeError(f"{name} should have the same dtype as the bijector's dtype.")
+            raise TypeError(
+                f"{name} should have the same dtype as the bijector's dtype.")
        self.default_parameters += [value,]
        self.parameter_names += [name,]
        return value_t

View File

@@ -33,24 +33,17 @@ class Invert(Bijector):
         >>> import mindspore.nn as nn
         >>> import mindspore.nn.probability.bijector as msb
         >>> from mindspore import Tensor
-        >>> import mindspore.context as context
-        >>> context.set_context(mode=1)
-        >>>
-        >>> # To initialize an inverse Exp bijector.
-        >>> inv_exp = msb.Invert(msb.Exp())
-        >>> value = Tensor([1, 2, 3], dtype=mindspore.float32)
-        >>> ans1 = inv_exp.forward(value)
-        >>> print(ans1.shape)
-        (3,)
-        >>> ans2 = inv_exp.inverse(value)
-        >>> print(ans2.shape)
-        (3,)
-        >>> ans3 = inv_exp.forward_log_jacobian(value)
-        >>> print(ans3.shape)
-        (3,)
-        >>> ans4 = inv_exp.inverse_log_jacobian(value)
-        >>> print(ans4.shape)
-        (3,)
+        >>> class Net(nn.Cell):
+        ...     def __init__(self):
+        ...         super(Net, self).__init__()
+        ...         self.origin = msb.ScalarAffine(scale=2.0, shift=1.0)
+        ...         self.invert = msb.Invert(self.origin)
+        ...
+        ...     def construct(self, x_):
+        ...         return self.invert.forward(x_)
+        >>> forward = Net()
+        >>> x = np.array([2.0, 3.0, 4.0, 5.0]).astype(np.float32)
+        >>> ans = forward(Tensor(x, dtype=dtype.float32))
     """

     def __init__(self,
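
The new docstring example above omits its imports; a runnable version, assuming `np` and `dtype` refer to NumPy and mindspore's dtype module, as in the test files later in this commit:

import numpy as np
import mindspore.nn as nn
import mindspore.nn.probability.bijector as msb
from mindspore import Tensor, dtype

class Net(nn.Cell):
    def __init__(self):
        super(Net, self).__init__()
        self.origin = msb.ScalarAffine(scale=2.0, shift=1.0)
        self.invert = msb.Invert(self.origin)

    def construct(self, x_):
        # forward of Invert applies the inverse of the wrapped bijector
        return self.invert.forward(x_)

forward = Net()
x = np.array([2.0, 3.0, 4.0, 5.0]).astype(np.float32)
ans = forward(Tensor(x, dtype=dtype.float32))  # maps x_ through (x_ - 1.0) / 2.0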

View File

@@ -84,6 +84,7 @@ class Bernoulli(Distribution):
         (3,)
         >>> # `probs` must be passed in during function calls.
         >>> ans = b2.mean(probs_a)
+        >>> print(ans.shape)
         (1,)
         >>> print(ans.shape)
         >>> # Interfaces of `kl_loss` and `cross_entropy` are the same as follows:
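
As a standalone sketch of the corrected call order (hypothetical probs value; `b2` is built without `probs`, so it is passed per call as the docstring describes):

import mindspore.nn.probability.distribution as msd
from mindspore import Tensor, dtype

b2 = msd.Bernoulli(dtype=dtype.int32)         # no probs at construction
probs_a = Tensor([0.6], dtype=dtype.float32)  # hypothetical probability
ans = b2.mean(probs_a)
print(ans.shape)                              # (1,), as documented above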

View File

@@ -105,22 +105,6 @@ class Categorical(Distribution):
         >>> ans = ca2.kl_loss('Categorical', probs_b, probs_a)
         >>> print(ans.shape)
         ()
-        >>> # Examples of `sample`.
-        >>> # Args:
-        >>> #     shape (tuple): the shape of the sample. Default: ().
-        >>> #     probs (Tensor): event probabilities. Default: self.probs.
-        >>> ans = ca1.sample()
-        >>> print(ans.shape)
-        ()
-        >>> ans = ca1.sample((2,3))
-        >>> print(ans.shape)
-        (2, 3)
-        >>> ans = ca1.sample((2,3), probs_b)
-        >>> print(ans.shape)
-        (2, 3)
-        >>> ans = ca2.sample((2,3), probs_a)
-        >>> print(ans.shape)
-        (2, 3)
     """

     def __init__(self,
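
For reference, the removed `sample` usage as a standalone sketch (hypothetical probs; shapes as documented before the removal):

import mindspore.nn.probability.distribution as msd
from mindspore import dtype

ca1 = msd.Categorical(probs=[0.2, 0.8], dtype=dtype.int32)
print(ca1.sample().shape)        # ()
print(ca1.sample((2, 3)).shape)  # (2, 3)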

View File

@@ -18,7 +18,7 @@ from mindspore.ops import operations as P
 from mindspore.nn.cell import Cell
 from mindspore._checkparam import Validator as validator
 from ._utils.utils import raise_none_error, cast_to_tensor, set_param_type, cast_type_for_device,\
-        raise_not_implemented_util
+    raise_not_implemented_util
 from ._utils.utils import CheckTuple, CheckTensor
 from ._utils.custom_ops import broadcast_to, exp_generic, log_generic
@@ -77,7 +77,8 @@ class Distribution(Cell):
         # if not a transformed distribution, set the following attribute
         if 'distribution' not in self.parameters.keys():
-            self.parameter_type = set_param_type(self.parameters['param_dict'], dtype)
+            self.parameter_type = set_param_type(
+                self.parameters['param_dict'], dtype)
             self._batch_shape = self._calc_batch_shape()
             self._is_scalar_batch = self._check_is_scalar_batch()
             self._broadcast_shape = self._batch_shape
@@ -152,7 +153,8 @@ class Distribution(Cell):
             self.default_parameters = []
             self.parameter_names = []
         # cast value to a tensor if it is not None
-        value_t = None if value is None else cast_to_tensor(value, self.parameter_type)
+        value_t = None if value is None else cast_to_tensor(
+            value, self.parameter_type)
         self.default_parameters += [value_t,]
         self.parameter_names += [name,]
         return value_t
@@ -180,10 +182,12 @@ class Distribution(Cell):
             if broadcast_shape is None:
                 broadcast_shape = self.shape_base(arg)
                 common_dtype = self.dtype_base(arg)
-                broadcast_shape_tensor = self.fill_base(common_dtype, broadcast_shape, 1.0)
+                broadcast_shape_tensor = self.fill_base(
+                    common_dtype, broadcast_shape, 1.0)
             else:
                 broadcast_shape = self.shape_base(arg + broadcast_shape_tensor)
-                broadcast_shape_tensor = self.fill_base(common_dtype, broadcast_shape, 1.0)
+                broadcast_shape_tensor = self.fill_base(
+                    common_dtype, broadcast_shape, 1.0)
             arg = self.broadcast(arg, broadcast_shape_tensor)
             # check if the arguments have the same dtype
             self.sametypeshape_base(arg, broadcast_shape_tensor)
@@ -240,7 +244,7 @@
     def _set_prob(self):
         """
-        Set probability funtion based on the availability of `_prob` and `_log_likehood`.
+        Set probability function based on the availability of `_prob` and `_log_likehood`.
         """
         if hasattr(self, '_prob'):
             self._call_prob = self._prob
@@ -303,9 +307,10 @@
         Set survival function based on the availability of _survival function and `_log_survival`
         and `_call_cdf`.
         """
-        if not (hasattr(self, '_survival_function') or hasattr(self, '_log_survival') or \
+        if not (hasattr(self, '_survival_function') or hasattr(self, '_log_survival') or
                 hasattr(self, '_cdf') or hasattr(self, '_log_cdf')):
-            self._call_survival = self._raise_not_implemented_error('survival_function')
+            self._call_survival = self._raise_not_implemented_error(
+                'survival_function')
         elif hasattr(self, '_survival_function'):
             self._call_survival = self._survival_function
         elif hasattr(self, '_log_survival'):
@@ -317,7 +322,7 @@
         """
         Set log cdf based on the availability of `_log_cdf` and `_call_cdf`.
         """
-        if not (hasattr(self, '_log_cdf') or hasattr(self, '_cdf') or \
+        if not (hasattr(self, '_log_cdf') or hasattr(self, '_cdf') or
                 hasattr(self, '_survival_function') or hasattr(self, '_log_survival')):
             self._call_log_cdf = self._raise_not_implemented_error('log_cdf')
         elif hasattr(self, '_log_cdf'):
@@ -329,9 +334,10 @@
         """
         Set log survival based on the availability of `_log_survival` and `_call_survival`.
         """
-        if not (hasattr(self, '_log_survival') or hasattr(self, '_survival_function') or \
+        if not (hasattr(self, '_log_survival') or hasattr(self, '_survival_function') or
                 hasattr(self, '_log_cdf') or hasattr(self, '_cdf')):
-            self._call_log_survival = self._raise_not_implemented_error('log_cdf')
+            self._call_log_survival = self._raise_not_implemented_error(
+                'log_cdf')
         elif hasattr(self, '_log_survival'):
             self._call_log_survival = self._log_survival
         elif hasattr(self, '_call_survival'):
@@ -344,7 +350,8 @@
         if hasattr(self, '_cross_entropy'):
             self._call_cross_entropy = self._cross_entropy
         else:
-            self._call_cross_entropy = self._raise_not_implemented_error('cross_entropy')
+            self._call_cross_entropy = self._raise_not_implemented_error(
+                'cross_entropy')

     def _get_dist_args(self, *args, **kwargs):
         return raise_not_implemented_util('get_dist_args', self.name, *args, **kwargs)
@@ -375,6 +382,7 @@
     def _raise_not_implemented_error(self, func_name):
         name = self.name
+
         def raise_error(*args, **kwargs):
             return raise_not_implemented_util(func_name, name, *args, **kwargs)
         return raise_error
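
All of the `_set_*` methods touched above follow one availability-based dispatch pattern; a minimal self-contained sketch of that pattern in plain Python (not MindSpore code):

import math

class Dist:
    def __init__(self):
        # pick an implementation based on what the subclass defines,
        # otherwise install a closure that raises when called
        if hasattr(self, '_prob'):
            self._call_prob = self._prob
        elif hasattr(self, '_log_prob'):
            self._call_prob = lambda x: math.exp(self._log_prob(x))
        else:
            self._call_prob = self._raise_not_implemented_error('prob')

    def _raise_not_implemented_error(self, func_name):
        def raise_error(*args, **kwargs):
            raise NotImplementedError(f'{func_name} is not implemented')
        return raise_error

class MyDist(Dist):
    def _log_prob(self, x):
        return -x  # toy log-density

print(MyDist()._call_prob(1.0))  # exp(-1), via the _log_prob fallback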

View File

@@ -46,48 +46,19 @@ class Gumbel(TransformedDistribution):
     Examples:
         >>> import mindspore
-        >>> import mindspore.context as context
         >>> import mindspore.nn as nn
         >>> import mindspore.nn.probability.distribution as msd
         >>> from mindspore import Tensor
-        >>> context.set_context(mode=1)
-        >>> # To initialize a Gumbel distribution of `loc` 3.0 and `scale` 4.0.
-        >>> gumbel = msd.Gumbel(3.0, 4.0, dtype=mindspore.float32)
-        >>> # Private interfaces of probability functions corresponding to public interfaces, including
-        >>> # `prob`, `log_prob`, `cdf`, `log_cdf`, `survival_function`, and `log_survival`, have the same
-        >>> # arguments as follows.
-        >>> # Args:
-        >>> #     value (Tensor): the value to be evaluated.
-        >>> # Examples of `prob`.
-        >>> # Similar calls can be made to other probability functions
-        >>> # by replacing 'prob' by the name of the function.
-        >>> value = Tensor([1.0, 2.0, 3.0], dtype=mindspore.float32)
-        >>> ans = gumbel.prob(value)
-        >>> print(ans.shape)
-        (3,)
-        >>> # Functions `mean`, `mode`, sd`, `var`, and `entropy` do not take in any argument.
-        >>> ans = gumbel.mean()
-        >>> print(ans.shape)
-        ()
-        >>> # Interfaces of 'kl_loss' and 'cross_entropy' are the same:
-        >>> # Args:
-        >>> #     dist (str): the type of the distributions. Only "Gumbel" is supported.
-        >>> #     loc_b (Tensor): the loc of distribution b.
-        >>> #     scale_b (Tensor): the scale distribution b.
-        >>> # Examples of `kl_loss`. `cross_entropy` is similar.
-        >>> loc_b = Tensor([1.0], dtype=mindspore.float32)
-        >>> scale_b = Tensor([1.0, 1.5, 2.0], dtype=mindspore.float32)
-        >>> ans = gumbel.kl_loss('Gumbel', loc_b, scale_b)
-        >>> print(ans.shape)
-        (3,)
-        >>> # Examples of `sample`.
-        >>> # Args:
-        >>> #     shape (tuple): the shape of the sample. Default: ()
-        >>> ans = gumbel.sample()
-        >>> print(ans.shape)
-        ()
-        >>> ans = gumbel.sample((2,3))
-        >>> print(ans.shape)
+        >>> class Prob(nn.Cell):
+        ...     def __init__(self):
+        ...         super(Prob, self).__init__()
+        ...         self.gum = msd.Gumbel(np.array([0.0]), np.array([[1.0], [2.0]]), dtype=dtype.float32)
+        ...
+        ...     def construct(self, x_):
+        ...         return self.gum.prob(x_)
+        >>> value = np.array([1.0, 2.0]).astype(np.float32)
+        >>> pdf = Prob()
+        >>> output = pdf(Tensor(value, dtype=dtype.float32))
     """

     def __init__(self,
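
A runnable version of the new Gumbel example, adding the imports the snippet assumes (`np` for NumPy, `dtype` from mindspore):

import numpy as np
import mindspore.nn as nn
import mindspore.nn.probability.distribution as msd
from mindspore import Tensor, dtype

class Prob(nn.Cell):
    def __init__(self):
        super(Prob, self).__init__()
        # one loc broadcast against two scales: a batch of two Gumbel distributions
        self.gum = msd.Gumbel(np.array([0.0]), np.array([[1.0], [2.0]]),
                              dtype=dtype.float32)

    def construct(self, x_):
        return self.gum.prob(x_)

value = np.array([1.0, 2.0]).astype(np.float32)
pdf = Prob()
output = pdf(Tensor(value, dtype=dtype.float32))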

View File

@@ -45,103 +45,17 @@ class LogNormal(msd.TransformedDistribution):
     Examples:
         >>> import mindspore
-        >>> import mindspore.context as context
         >>> import mindspore.nn as nn
         >>> import mindspore.nn.probability.distribution as msd
         >>> from mindspore import Tensor
-        >>> context.set_context(mode=1)
-        >>> # To initialize a LogNormal distribution of `loc` 3.0 and `scale` 4.0.
-        >>> n1 = msd.LogNormal(3.0, 4.0, dtype=mindspore.float32)
-        >>> # A LogNormal distribution can be initialized without arguments.
-        >>> # In this case, `loc` and `scale` must be passed in during function calls.
-        >>> n2 = msd.LogNormal(dtype=mindspore.float32)
-        >>>
-        >>> # Here are some tensors used below for testing
-        >>> value = Tensor([1.0, 2.0, 3.0], dtype=mindspore.float32)
-        >>> loc_a = Tensor([2.0], dtype=mindspore.float32)
-        >>> scale_a = Tensor([2.0, 2.0, 2.0], dtype=mindspore.float32)
-        >>> loc_b = Tensor([1.0], dtype=mindspore.float32)
-        >>> scale_b = Tensor([1.0, 1.5, 2.0], dtype=mindspore.float32)
-        >>>
-        >>> # Private interfaces of probability functions corresponding to public interfaces, including
-        >>> # `prob`, `log_prob`, `cdf`, `log_cdf`, `survival_function`, and `log_survival`, have the same
-        >>> # arguments as follows.
-        >>> # Args:
-        >>> #     value (Tensor): the value to be evaluated.
-        >>> #     loc (Tensor): the loc of distribution. Default: None. If `loc` is passed in as None,
-        >>> #         the mean of the underlying Normal distribution will be used.
-        >>> #     scale (Tensor): the scale of distribution. Default: None. If `scale` is passed in as None,
-        >>> #         the standard deviation of the underlying Normal distribution will be used.
-        >>> # Examples of `prob`.
-        >>> # Similar calls can be made to other probability functions
-        >>> # by replacing 'prob' by the name of the function.
-        >>> ans = n1.prob(value)
-        >>> print(ans.shape)
-        (3,)
-        >>> # Evaluate with respect to distribution b.
-        >>> ans = n1.prob(value, loc_b, scale_b)
-        >>> print(ans.shape)
-        (3,)
-        >>> # `loc` and `scale` must be passed in during function calls since they were not passed in construct.
-        >>> ans = n2.prob(value, loc_a, scale_a)
-        >>> print(ans.shape)
-        (3,)
-        >>> # Functions `mean`, `sd`, `var`, and `entropy` have the same arguments.
-        >>> # Args:
-        >>> #     loc (Tensor): the loc of distribution. Default: None. If `loc` is passed in as None,
-        >>> #         the mean of the underlying Normal distribution will be used.
-        >>> #     scale (Tensor): the scale of distribution. Default: None. If `scale` is passed in as None,
-        >>> #         the standard deviation of the underlying Normal distribution will be used.
-        >>> # Example of `mean`. `sd`, `var`, and `entropy` are similar.
-        >>> ans = n1.mean()
-        >>> print(ans.shape)
-        ()
-        >>> ans = n1.mean(loc_b, scale_b)
-        >>> print(ans.shape)
-        (3,)
-        >>> # `loc` and `scale` must be passed in during function calls since they were not passed in construct.
-        >>> ans = n2.mean(loc_a, scale_a)
-        >>> print(ans.shape)
-        (3,)
-        >>> # Interfaces of 'kl_loss' and 'cross_entropy' are the same:
-        >>> # Args:
-        >>> #     dist (str): the type of the distributions. Only "Normal" is supported.
-        >>> #     loc_b (Tensor): the loc of distribution b.
-        >>> #     scale_b (Tensor): the scale distribution b.
-        >>> #     loc_a (Tensor): the loc of distribution a. Default: None. If `loc` is passed in as None,
-        >>> #         the mean of the underlying Normal distribution will be used.
-        >>> #     scale_a (Tensor): the scale distribution a. Default: None. If `scale` is passed in as None,
-        >>> #         the standard deviation of the underlying Normal distribution will be used.
-        >>> # Examples of `kl_loss`. `cross_entropy` is similar.
-        >>> ans = n1.kl_loss('LogNormal', loc_b, scale_b)
-        >>> print(ans.shape)
-        (3,)
-        >>> ans = n1.kl_loss('LogNormal', loc_b, scale_b, loc_a, scale_a)
-        >>> print(ans.shape)
-        (3,)
-        >>> # Additional `loc` and `scale` must be passed in since they were not passed in construct.
-        >>> ans = n2.kl_loss('LogNormal', loc_b, scale_b, loc_a, scale_a)
-        >>> print(ans.shape)
-        (3,)
-        >>> # Examples of `sample`.
-        >>> # Args:
-        >>> #     shape (tuple): the shape of the sample. Default: ()
-        >>> #     loc (Tensor): the loc of the distribution. Default: None. If `loc` is passed in as None,
-        >>> #         the mean of the underlying Normal distribution will be used.
-        >>> #     scale (Tensor): the scale of the distribution. Default: None. If `scale` is passed in as None,
-        >>> #         the standard deviation of the underlying Normal distribution will be used.
-        >>> ans = n1.sample()
-        >>> print(ans.shape)
-        ()
-        >>> ans = n1.sample((2,3))
-        >>> print(ans.shape)
-        (2, 3)
-        >>> ans = n1.sample((2,3), loc_b, scale_b)
-        >>> print(ans.shape)
-        (2, 3, 3)
-        >>> ans = n2.sample((2,3), loc_a, scale_a)
-        >>> print(ans.shape)
-        (2, 3, 3)
+        ... class Prob(nn.Cell):
+        ...     def __init__(self):
+        ...         super(Prob, self).__init__()
+        ...         self.ln = msd.LogNormal(np.array([0.3]), np.array([[0.2], [0.4]]), dtype=dtype.float32)
+        ...     def construct(self, x_):
+        ...         return self.ln.prob(x_)
+        >>> pdf = Prob()
+        >>> output = pdf(Tensor([1.0, 2.0], dtype=dtype.float32))
     """

     def __init__(self,
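
A runnable version of the new LogNormal example, with the imports the snippet assumes added (numpy and mindspore's dtype):

import numpy as np
import mindspore.nn as nn
import mindspore.nn.probability.distribution as msd
from mindspore import Tensor, dtype

class Prob(nn.Cell):
    def __init__(self):
        super(Prob, self).__init__()
        # loc of shape (1,) against scales of shape (2, 1): a batch of two LogNormals
        self.ln = msd.LogNormal(np.array([0.3]), np.array([[0.2], [0.4]]),
                                dtype=dtype.float32)

    def construct(self, x_):
        return self.ln.prob(x_)

pdf = Prob()
output = pdf(Tensor([1.0, 2.0], dtype=dtype.float32))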

View File

@@ -89,7 +89,7 @@ class Poisson(Distribution):
         >>> # `rate` must be passed in during function calls.
         >>> ans = p2.mean(rate_a)
         >>> print(ans.shape)
-        (1,)
+        ()
         >>> # Examples of `sample`.
         >>> # Args:
         >>> #     shape (tuple): the shape of the sample. Default: ()
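
A sketch of the call pattern being documented (hypothetical rate value; `p2` is built without `rate`, so it is supplied per call, assuming `msd.Poisson` follows the same call pattern as the other distributions in this module):

import mindspore.nn.probability.distribution as msd
from mindspore import Tensor, dtype

p2 = msd.Poisson(dtype=dtype.float32)
rate_a = Tensor([0.6], dtype=dtype.float32)  # hypothetical rate
ans = p2.mean(rate_a)
print(ans.shape)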

View File

@@ -53,25 +53,28 @@ class TransformedDistribution(Distribution):
     Examples:
         >>> import mindspore
-        >>> import mindspore.context as context
         >>> import mindspore.nn as nn
         >>> import mindspore.nn.probability.distribution as msd
         >>> import mindspore.nn.probability.bijector as msb
         >>> from mindspore import Tensor
-        >>> context.set_context(mode=1)
-        >>>
-        >>> # To initialize a transformed distribution
-        >>> # using a Normal distribution as the base distribution,
-        >>> # and an Exp bijector as the bijector function.
-        >>> trans_dist = msd.TransformedDistribution(msb.Exp(), msd.Normal(0.0, 1.0))
-        >>>
-        >>> value = Tensor([1.0, 2.0, 3.0], dtype=mindspore.float32)
-        >>> prob = trans_dist.prob(value)
-        >>> print(prob.shape)
-        (3,)
-        >>> sample = trans_dist.sample(shape=(2, 3))
-        >>> print(sample.shape)
-        (2, 3)
+        >>> class Net(nn.Cell):
+        ...     def __init__(self, shape, dtype=dtype.float32, seed=0, name='transformed_distribution'):
+        ...         super(Net, self).__init__()
+        ...         # create TransformedDistribution distribution
+        ...         self.exp = msb.Exp()
+        ...         self.normal = msd.Normal(0.0, 1.0, dtype=dtype)
+        ...         self.lognormal = msd.TransformedDistribution(self.exp, self.normal, seed=seed, name=name)
+        ...         self.shape = shape
+        ...
+        ...     def construct(self, value):
+        ...         cdf = self.lognormal.cdf(value)
+        ...         sample = self.lognormal.sample(self.shape)
+        ...         return cdf, sample
+        >>> shape = (2, 3)
+        >>> net = Net(shape=shape, name="LogNormal")
+        >>> x = np.array([2.0, 3.0, 4.0, 5.0]).astype(np.float32)
+        >>> tx = Tensor(x, dtype=dtype.float32)
+        >>> cdf, sample = net(tx)
     """

     def __init__(self,
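
A runnable version of the new example (imports added; the inner `dtype` parameter is renamed here to avoid shadowing the imported module, otherwise the code follows the docstring):

import numpy as np
import mindspore.nn as nn
import mindspore.nn.probability.distribution as msd
import mindspore.nn.probability.bijector as msb
from mindspore import Tensor, dtype

class Net(nn.Cell):
    def __init__(self, shape, dist_dtype=dtype.float32, seed=0, name='transformed_distribution'):
        super(Net, self).__init__()
        # exp(Normal(0, 1)) is a LogNormal distribution
        self.exp = msb.Exp()
        self.normal = msd.Normal(0.0, 1.0, dtype=dist_dtype)
        self.lognormal = msd.TransformedDistribution(
            self.exp, self.normal, seed=seed, name=name)
        self.shape = shape

    def construct(self, value):
        cdf = self.lognormal.cdf(value)
        sample = self.lognormal.sample(self.shape)
        return cdf, sample

net = Net(shape=(2, 3), name="LogNormal")
x = np.array([2.0, 3.0, 4.0, 5.0]).astype(np.float32)
cdf, sample = net(Tensor(x, dtype=dtype.float32))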

View File

@@ -38,7 +38,7 @@ class Uniform(Distribution):
         ``Ascend`` ``GPU``

     Note:
-        `low` must be stricly less than `high`.
+        `low` must be strictly less than `high`.
         `dist_spec_args` are `high` and `low`.
         `dtype` must be float type because Uniform distributions are continuous.
@@ -143,7 +143,8 @@ class Uniform(Distribution):
         param = dict(locals())
         param['param_dict'] = {'low': low, 'high': high}
         valid_dtype = mstype.float_type
-        Validator.check_type_name("dtype", dtype, valid_dtype, type(self).__name__)
+        Validator.check_type_name(
+            "dtype", dtype, valid_dtype, type(self).__name__)
         super(Uniform, self).__init__(seed, dtype, name, param)
         self._low = self._add_parameter(low, 'low')
@@ -151,7 +152,6 @@ class Uniform(Distribution):
         if self.low is not None and self.high is not None:
             check_greater(self.low, self.high, 'low', 'high')
-
         # ops needed for the class
         self.exp = exp_generic
         self.log = log_generic
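
Illustrating the corrected Note: `low` must be strictly less than `high`, which the `check_greater` call above enforces at construction time (a sketch, not part of this diff):

import mindspore.nn.probability.distribution as msd
from mindspore import dtype

u = msd.Uniform(0.0, 1.0, dtype=dtype.float32)  # valid: low < high
# msd.Uniform(1.0, 1.0) would fail the strict low < high check above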

View File

@@ -20,11 +20,13 @@ import mindspore.nn.probability.distribution as msd
 from mindspore.common import dtype as mstype
 from mindspore.ops import operations as P

+
 class BayesianNet(nn.Cell):
     """
     We currently support 3 types of variables: x = observation, z = latent, y = condition.
-    A Bayeisian Network models a generative process for certain varaiables: p(x,z|y) or p(z|x,y) or p(x|z,y)
+    A Bayeisian Network models a generative process for certain variables: p(x,z|y) or p(z|x,y) or p(x|z,y)
     """
+
     def __init__(self):
         super().__init__()
         self.normal_dist = msd.Normal(dtype=mstype.float32)
@@ -49,14 +51,16 @@ class BayesianNet(nn.Cell):
         if observation is None:
             if reparameterize:
-                epsilon = self.normal_dist('sample', shape, self.zeros(mean.shape), self.ones(std.shape))
+                epsilon = self.normal_dist('sample', shape, self.zeros(
+                    mean.shape), self.ones(std.shape))
                 sample = mean + std * epsilon
             else:
                 sample = self.normal_dist('sample', shape, mean, std)
         else:
             sample = observation
-        log_prob = self.reduce_sum(self.normal_dist('log_prob', sample, mean, std), 1)
+        log_prob = self.reduce_sum(self.normal_dist(
+            'log_prob', sample, mean, std), 1)
         return sample, log_prob

     def Bernoulli(self,
@@ -77,7 +81,8 @@ class BayesianNet(nn.Cell):
         else:
             sample = observation
-        log_prob = self.reduce_sum(self.bernoulli_dist('log_prob', sample, probs), 1)
+        log_prob = self.reduce_sum(
+            self.bernoulli_dist('log_prob', sample, probs), 1)
         return sample, log_prob

     def construct(self, *inputs, **kwargs):
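
The `Normal` branch above implements the reparameterization trick; a NumPy sketch of the same idea:

import numpy as np

def reparameterized_normal_sample(mean, std, shape, seed=0):
    # draw noise from a fixed N(0, 1), then shift and scale it, so the
    # sample stays differentiable with respect to mean and std
    epsilon = np.random.default_rng(seed).standard_normal(shape)
    return mean + std * epsilon

sample = reparameterized_normal_sample(0.0, 1.0, (2, 3))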

View File

@@ -23,10 +23,12 @@ from mindspore import dtype
 context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")

+
 class Prob(nn.Cell):
     """
     Test class: probability of Bernoulli distribution.
     """
+
     def __init__(self):
         super(Prob, self).__init__()
         self.b = msd.Bernoulli(0.7, dtype=dtype.int32)
@@ -34,6 +36,7 @@ class Prob(nn.Cell):
     def construct(self, x_):
         return self.b.prob(x_)

+
 def test_pmf():
     """
     Test pmf.
@@ -41,7 +44,8 @@ def test_pmf():
     bernoulli_benchmark = stats.bernoulli(0.7)
     expect_pmf = bernoulli_benchmark.pmf([0, 1, 0, 1, 1]).astype(np.float32)
     pmf = Prob()
-    x_ = Tensor(np.array([0, 1, 0, 1, 1]).astype(np.int32), dtype=dtype.float32)
+    x_ = Tensor(np.array([0, 1, 0, 1, 1]).astype(
+        np.int32), dtype=dtype.float32)
     output = pmf(x_)
     tol = 1e-6
     assert (np.abs(output.asnumpy() - expect_pmf) < tol).all()
@@ -51,6 +55,7 @@ class LogProb(nn.Cell):
     """
     Test class: log probability of Bernoulli distribution.
     """
+
     def __init__(self):
         super(LogProb, self).__init__()
         self.b = msd.Bernoulli(0.7, dtype=dtype.int32)
@@ -58,22 +63,27 @@
     def construct(self, x_):
         return self.b.log_prob(x_)

+
 def test_log_likelihood():
     """
     Test log_pmf.
     """
     bernoulli_benchmark = stats.bernoulli(0.7)
-    expect_logpmf = bernoulli_benchmark.logpmf([0, 1, 0, 1, 1]).astype(np.float32)
+    expect_logpmf = bernoulli_benchmark.logpmf(
+        [0, 1, 0, 1, 1]).astype(np.float32)
     logprob = LogProb()
-    x_ = Tensor(np.array([0, 1, 0, 1, 1]).astype(np.int32), dtype=dtype.float32)
+    x_ = Tensor(np.array([0, 1, 0, 1, 1]).astype(
+        np.int32), dtype=dtype.float32)
     output = logprob(x_)
     tol = 1e-6
     assert (np.abs(output.asnumpy() - expect_logpmf) < tol).all()

+
 class KL(nn.Cell):
     """
     Test class: kl_loss between Bernoulli distributions.
     """
+
     def __init__(self):
         super(KL, self).__init__()
         self.b = msd.Bernoulli(0.7, dtype=dtype.int32)
@@ -81,6 +91,7 @@ class KL(nn.Cell):
     def construct(self, x_):
         return self.b.kl_loss('Bernoulli', x_)

+
 def test_kl_loss():
     """
     Test kl_loss.
@@ -89,16 +100,19 @@ def test_kl_loss():
     probs1_b = 0.5
     probs0_a = 1 - probs1_a
     probs0_b = 1 - probs1_b
-    expect_kl_loss = probs1_a * np.log(probs1_a / probs1_b) + probs0_a * np.log(probs0_a / probs0_b)
+    expect_kl_loss = probs1_a * \
+        np.log(probs1_a / probs1_b) + probs0_a * np.log(probs0_a / probs0_b)
     kl_loss = KL()
     output = kl_loss(Tensor([probs1_b], dtype=dtype.float32))
     tol = 1e-6
     assert (np.abs(output.asnumpy() - expect_kl_loss) < tol).all()
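
For reference, the closed form checked by test_kl_loss above, with p_a = probs1_a and p_b = probs1_b:

\mathrm{KL}\left(\mathrm{Ber}(p_a)\,\|\,\mathrm{Ber}(p_b)\right)
    = p_a \log\frac{p_a}{p_b} + (1 - p_a) \log\frac{1 - p_a}{1 - p_b}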

+
 class Basics(nn.Cell):
     """
     Test class: mean/sd/mode of Bernoulli distribution.
     """
+
     def __init__(self):
         super(Basics, self).__init__()
         self.b = msd.Bernoulli([0.3, 0.5, 0.7], dtype=dtype.int32)
@@ -106,6 +120,7 @@ class Basics(nn.Cell):
     def construct(self):
         return self.b.mean(), self.b.sd(), self.b.mode()

+
 def test_basics():
     """
     Test mean/standard deviation/mode.
@@ -120,10 +135,12 @@ def test_basics():
     assert (np.abs(sd.asnumpy() - expect_sd) < tol).all()
     assert (np.abs(mode.asnumpy() - expect_mode) < tol).all()

+
 class Sampling(nn.Cell):
     """
     Test class: log probability of Bernoulli distribution.
     """
+
     def __init__(self, shape, seed=0):
         super(Sampling, self).__init__()
         self.b = msd.Bernoulli([0.7, 0.5], seed=seed, dtype=dtype.int32)
@@ -132,6 +149,7 @@ class Sampling(nn.Cell):
     def construct(self, probs=None):
         return self.b.sample(self.shape, probs)

+
 def test_sample():
     """
     Test sample.
@@ -141,10 +159,12 @@ def test_sample():
     output = sample()
     assert output.shape == (2, 3, 2)

+
 class CDF(nn.Cell):
     """
     Test class: cdf of bernoulli distributions.
     """
+
     def __init__(self):
         super(CDF, self).__init__()
         self.b = msd.Bernoulli(0.7, dtype=dtype.int32)
@@ -152,22 +172,26 @@ class CDF(nn.Cell):
     def construct(self, x_):
         return self.b.cdf(x_)

+
 def test_cdf():
     """
     Test cdf.
     """
     bernoulli_benchmark = stats.bernoulli(0.7)
     expect_cdf = bernoulli_benchmark.cdf([0, 0, 1, 0, 1]).astype(np.float32)
-    x_ = Tensor(np.array([0, 0, 1, 0, 1]).astype(np.int32), dtype=dtype.float32)
+    x_ = Tensor(np.array([0, 0, 1, 0, 1]).astype(
+        np.int32), dtype=dtype.float32)
     cdf = CDF()
     output = cdf(x_)
     tol = 1e-6
     assert (np.abs(output.asnumpy() - expect_cdf) < tol).all()

+
 class LogCDF(nn.Cell):
     """
     Test class: log cdf of bernoulli distributions.
     """
+
     def __init__(self):
         super(LogCDF, self).__init__()
         self.b = msd.Bernoulli(0.7, dtype=dtype.int32)
@@ -175,13 +199,16 @@ class LogCDF(nn.Cell):
     def construct(self, x_):
         return self.b.log_cdf(x_)

+
 def test_logcdf():
     """
     Test log_cdf.
     """
     bernoulli_benchmark = stats.bernoulli(0.7)
-    expect_logcdf = bernoulli_benchmark.logcdf([0, 0, 1, 0, 1]).astype(np.float32)
-    x_ = Tensor(np.array([0, 0, 1, 0, 1]).astype(np.int32), dtype=dtype.float32)
+    expect_logcdf = bernoulli_benchmark.logcdf(
+        [0, 0, 1, 0, 1]).astype(np.float32)
+    x_ = Tensor(np.array([0, 0, 1, 0, 1]).astype(
+        np.int32), dtype=dtype.float32)
     logcdf = LogCDF()
     output = logcdf(x_)
     tol = 1e-6
@@ -192,6 +219,7 @@ class SF(nn.Cell):
     """
     Test class: survival function of Bernoulli distributions.
     """
+
     def __init__(self):
         super(SF, self).__init__()
         self.b = msd.Bernoulli(0.7, dtype=dtype.int32)
@@ -199,13 +227,16 @@ class SF(nn.Cell):
     def construct(self, x_):
         return self.b.survival_function(x_)

+
 def test_survival():
     """
-    Test survival funciton.
+    Test survival function.
     """
     bernoulli_benchmark = stats.bernoulli(0.7)
-    expect_survival = bernoulli_benchmark.sf([0, 1, 1, 0, 0]).astype(np.float32)
-    x_ = Tensor(np.array([0, 1, 1, 0, 0]).astype(np.int32), dtype=dtype.float32)
+    expect_survival = bernoulli_benchmark.sf(
+        [0, 1, 1, 0, 0]).astype(np.float32)
+    x_ = Tensor(np.array([0, 1, 1, 0, 0]).astype(
+        np.int32), dtype=dtype.float32)
     sf = SF()
     output = sf(x_)
     tol = 1e-6
@@ -216,6 +247,7 @@ class LogSF(nn.Cell):
     """
     Test class: log survival function of Bernoulli distributions.
     """
+
     def __init__(self):
         super(LogSF, self).__init__()
         self.b = msd.Bernoulli(0.7, dtype=dtype.int32)
@@ -223,22 +255,27 @@ class LogSF(nn.Cell):
     def construct(self, x_):
         return self.b.log_survival(x_)

+
 def test_log_survival():
     """
-    Test log survival funciton.
+    Test log survival function.
     """
     bernoulli_benchmark = stats.bernoulli(0.7)
-    expect_logsurvival = bernoulli_benchmark.logsf([-1, 0.9, 0, 0, 0]).astype(np.float32)
-    x_ = Tensor(np.array([-1, 0.9, 0, 0, 0]).astype(np.float32), dtype=dtype.float32)
+    expect_logsurvival = bernoulli_benchmark.logsf(
+        [-1, 0.9, 0, 0, 0]).astype(np.float32)
+    x_ = Tensor(np.array([-1, 0.9, 0, 0, 0]
+                         ).astype(np.float32), dtype=dtype.float32)
     log_sf = LogSF()
     output = log_sf(x_)
     tol = 1e-6
     assert (np.abs(output.asnumpy() - expect_logsurvival) < tol).all()

+
 class EntropyH(nn.Cell):
     """
     Test class: entropy of Bernoulli distributions.
     """
+
     def __init__(self):
         super(EntropyH, self).__init__()
         self.b = msd.Bernoulli(0.7, dtype=dtype.int32)
@@ -246,6 +283,7 @@ class EntropyH(nn.Cell):
     def construct(self):
         return self.b.entropy()

+
 def test_entropy():
     """
     Test entropy.
@@ -257,10 +295,12 @@ def test_entropy():
     tol = 1e-6
     assert (np.abs(output.asnumpy() - expect_entropy) < tol).all()

+
 class CrossEntropy(nn.Cell):
     """
     Test class: cross entropy between bernoulli distributions.
     """
+
     def __init__(self):
         super(CrossEntropy, self).__init__()
         self.b = msd.Bernoulli(0.7, dtype=dtype.int32)
@@ -272,6 +312,7 @@ class CrossEntropy(nn.Cell):
         cross_entropy = self.b.cross_entropy('Bernoulli', x_)
         return h_sum_kl - cross_entropy

+
 def test_cross_entropy():
     """
     Test cross_entropy.

View File

@@ -24,10 +24,12 @@ from mindspore import dtype
 context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")

+
 class Prob(nn.Cell):
     """
     Test class: probability of categorical distribution.
     """
+
     def __init__(self):
         super(Prob, self).__init__()
         self.c = msd.Categorical([0.7, 0.3], dtype=dtype.int32)
@@ -35,13 +37,15 @@ class Prob(nn.Cell):
     def construct(self, x_):
         return self.c.prob(x_)

+
 def test_pmf():
     """
     Test pmf.
     """
     expect_pmf = [0.7, 0.3, 0.7, 0.3, 0.3]
     pmf = Prob()
-    x_ = Tensor(np.array([0, 1, 0, 1, 1]).astype(np.int32), dtype=dtype.float32)
+    x_ = Tensor(np.array([0, 1, 0, 1, 1]).astype(
+        np.int32), dtype=dtype.float32)
     output = pmf(x_)
     tol = 1e-6
     assert (np.abs(output.asnumpy() - expect_pmf) < tol).all()
@@ -51,6 +55,7 @@ class LogProb(nn.Cell):
     """
     Test class: log probability of categorical distribution.
     """
+
     def __init__(self):
         super(LogProb, self).__init__()
         self.c = msd.Categorical([0.7, 0.3], dtype=dtype.int32)
@@ -58,21 +63,25 @@ class LogProb(nn.Cell):
     def construct(self, x_):
         return self.c.log_prob(x_)

+
 def test_log_likelihood():
     """
     Test log_pmf.
     """
     expect_logpmf = np.log([0.7, 0.3, 0.7, 0.3, 0.3])
     logprob = LogProb()
-    x_ = Tensor(np.array([0, 1, 0, 1, 1]).astype(np.int32), dtype=dtype.float32)
+    x_ = Tensor(np.array([0, 1, 0, 1, 1]).astype(
+        np.int32), dtype=dtype.float32)
     output = logprob(x_)
     tol = 1e-6
     assert (np.abs(output.asnumpy() - expect_logpmf) < tol).all()

+
 class KL(nn.Cell):
     """
     Test class: kl_loss between categorical distributions.
     """
+
     def __init__(self):
         super(KL, self).__init__()
         self.c = msd.Categorical([0.7, 0.3], dtype=dtype.int32)
@@ -80,6 +89,7 @@ class KL(nn.Cell):
     def construct(self, x_):
         return self.c.kl_loss('Categorical', x_)

+
 def test_kl_loss():
     """
     Test kl_loss.
@@ -89,10 +99,12 @@ def test_kl_loss():
     tol = 1e-6
     assert (np.abs(output.asnumpy()) < tol).all()

+
 class Sampling(nn.Cell):
     """
     Test class: sampling of categorical distribution.
     """
+
     def __init__(self):
         super(Sampling, self).__init__()
         self.c = msd.Categorical([0.2, 0.1, 0.7], dtype=dtype.int32)
@@ -101,6 +113,7 @@ class Sampling(nn.Cell):
     def construct(self):
         return self.c.sample(self.shape)

+
 def test_sample():
     """
     Test sample.
@@ -109,10 +122,12 @@ def test_sample():
     sample = Sampling()
     sample()

+
 class Basics(nn.Cell):
     """
     Test class: mean/var/mode of categorical distribution.
     """
+
     def __init__(self):
         super(Basics, self).__init__()
         self.c = msd.Categorical([0.2, 0.1, 0.7], dtype=dtype.int32)
@@ -120,6 +135,7 @@ class Basics(nn.Cell):
     def construct(self):
         return self.c.mean(), self.c.var(), self.c.mode()

+
 def test_basics():
     """
     Test mean/variance/mode.
@@ -139,6 +155,7 @@ class CDF(nn.Cell):
     """
     Test class: cdf of categorical distributions.
     """
+
     def __init__(self):
         super(CDF, self).__init__()
         self.c = msd.Categorical([0.7, 0.3], dtype=dtype.int32)
@@ -146,21 +163,25 @@ class CDF(nn.Cell):
     def construct(self, x_):
         return self.c.cdf(x_)

+
 def test_cdf():
     """
     Test cdf.
     """
     expect_cdf = [0.7, 0.7, 1, 0.7, 1]
-    x_ = Tensor(np.array([0, 0, 1, 0, 1]).astype(np.int32), dtype=dtype.float32)
+    x_ = Tensor(np.array([0, 0, 1, 0, 1]).astype(
+        np.int32), dtype=dtype.float32)
     cdf = CDF()
     output = cdf(x_)
     tol = 1e-6
     assert (np.abs(output.asnumpy() - expect_cdf) < tol).all()

+
 class LogCDF(nn.Cell):
     """
     Test class: log cdf of categorical distributions.
     """
+
     def __init__(self):
         super(LogCDF, self).__init__()
         self.c = msd.Categorical([0.7, 0.3], dtype=dtype.int32)
@@ -168,12 +189,14 @@ class LogCDF(nn.Cell):
     def construct(self, x_):
         return self.c.log_cdf(x_)

+
 def test_logcdf():
     """
     Test log_cdf.
     """
     expect_logcdf = np.log([0.7, 0.7, 1, 0.7, 1])
-    x_ = Tensor(np.array([0, 0, 1, 0, 1]).astype(np.int32), dtype=dtype.float32)
+    x_ = Tensor(np.array([0, 0, 1, 0, 1]).astype(
+        np.int32), dtype=dtype.float32)
     logcdf = LogCDF()
     output = logcdf(x_)
     tol = 1e-6
@@ -184,6 +207,7 @@ class SF(nn.Cell):
     """
     Test class: survival function of categorical distributions.
     """
+
     def __init__(self):
         super(SF, self).__init__()
         self.c = msd.Categorical([0.7, 0.3], dtype=dtype.int32)
@@ -191,12 +215,14 @@ class SF(nn.Cell):
     def construct(self, x_):
         return self.c.survival_function(x_)

+
 def test_survival():
     """
-    Test survival funciton.
+    Test survival function.
     """
     expect_survival = [0.3, 0., 0., 0.3, 0.3]
-    x_ = Tensor(np.array([0, 1, 1, 0, 0]).astype(np.int32), dtype=dtype.float32)
+    x_ = Tensor(np.array([0, 1, 1, 0, 0]).astype(
+        np.int32), dtype=dtype.float32)
     sf = SF()
     output = sf(x_)
     tol = 1e-6
@@ -207,6 +233,7 @@ class LogSF(nn.Cell):
     """
     Test class: log survival function of categorical distributions.
     """
+
     def __init__(self):
         super(LogSF, self).__init__()
         self.c = msd.Categorical([0.7, 0.3], dtype=dtype.int32)
@@ -214,21 +241,25 @@ class LogSF(nn.Cell):
     def construct(self, x_):
         return self.c.log_survival(x_)

+
 def test_log_survival():
     """
-    Test log survival funciton.
+    Test log survival function.
     """
     expect_logsurvival = np.log([1., 0.3, 0.3, 0.3, 0.3])
-    x_ = Tensor(np.array([-2, 0, 0, 0.5, 0.5]).astype(np.float32), dtype=dtype.float32)
+    x_ = Tensor(np.array([-2, 0, 0, 0.5, 0.5]
+                         ).astype(np.float32), dtype=dtype.float32)
     log_sf = LogSF()
     output = log_sf(x_)
     tol = 1e-6
     assert (np.abs(output.asnumpy() - expect_logsurvival) < tol).all()

+
 class EntropyH(nn.Cell):
     """
     Test class: entropy of categorical distributions.
     """
+
     def __init__(self):
         super(EntropyH, self).__init__()
         self.c = msd.Categorical([0.7, 0.3], dtype=dtype.int32)
@@ -236,6 +267,7 @@ class EntropyH(nn.Cell):
     def construct(self):
         return self.c.entropy()

+
 def test_entropy():
     """
     Test entropy.
@@ -247,10 +279,12 @@ def test_entropy():
     tol = 1e-6
     assert (np.abs(output.asnumpy() - expect_entropy) < tol).all()

+
 class CrossEntropy(nn.Cell):
     """
     Test class: cross entropy between categorical distributions.
     """
+
     def __init__(self):
         super(CrossEntropy, self).__init__()
         self.c = msd.Categorical([0.7, 0.3], dtype=dtype.int32)
@@ -262,6 +296,7 @@ class CrossEntropy(nn.Cell):
         cross_entropy = self.c.cross_entropy('Categorical', x_)
         return h_sum_kl - cross_entropy

+
 def test_cross_entropy():
     """
     Test cross_entropy.

View File

@@ -23,10 +23,12 @@ from mindspore import dtype
 context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")

+
 class Prob(nn.Cell):
     """
     Test class: probability of Geometric distribution.
     """
+
     def __init__(self):
         super(Prob, self).__init__()
         self.g = msd.Geometric(0.7, dtype=dtype.int32)
@@ -34,6 +36,7 @@ class Prob(nn.Cell):
     def construct(self, x_):
         return self.g.prob(x_)

+
 def test_pmf():
     """
     Test pmf.
@@ -41,15 +44,18 @@ def test_pmf():
     geom_benchmark = stats.geom(0.7)
     expect_pmf = geom_benchmark.pmf([0, 1, 2, 3, 4]).astype(np.float32)
     pdf = Prob()
-    x_ = Tensor(np.array([-1, 0, 1, 2, 3]).astype(np.float32), dtype=dtype.float32)
+    x_ = Tensor(np.array([-1, 0, 1, 2, 3]
+                         ).astype(np.float32), dtype=dtype.float32)
     output = pdf(x_)
     tol = 1e-6
     assert (np.abs(output.asnumpy() - expect_pmf) < tol).all()

+
 class LogProb(nn.Cell):
     """
     Test class: log probability of Geometric distribution.
     """
+
     def __init__(self):
         super(LogProb, self).__init__()
         self.g = msd.Geometric(0.7, dtype=dtype.int32)
@@ -57,6 +63,7 @@ class LogProb(nn.Cell):
     def construct(self, x_):
         return self.g.log_prob(x_)

+
 def test_log_likelihood():
     """
     Test log_pmf.
@@ -64,15 +71,18 @@ def test_log_likelihood():
     geom_benchmark = stats.geom(0.7)
     expect_logpmf = geom_benchmark.logpmf([1, 2, 3, 4, 5]).astype(np.float32)
     logprob = LogProb()
-    x_ = Tensor(np.array([0, 1, 2, 3, 4]).astype(np.int32), dtype=dtype.float32)
+    x_ = Tensor(np.array([0, 1, 2, 3, 4]).astype(
+        np.int32), dtype=dtype.float32)
     output = logprob(x_)
     tol = 1e-6
     assert (np.abs(output.asnumpy() - expect_logpmf) < tol).all()

+
 class KL(nn.Cell):
     """
     Test class: kl_loss between Geometric distributions.
     """
+
     def __init__(self):
         super(KL, self).__init__()
         self.g = msd.Geometric(0.7, dtype=dtype.int32)
@@ -80,6 +90,7 @@ class KL(nn.Cell):
     def construct(self, x_):
         return self.g.kl_loss('Geometric', x_)

+
 def test_kl_loss():
     """
     Test kl_loss.
@@ -88,16 +99,19 @@ def test_kl_loss():
     probs1_b = 0.5
     probs0_a = 1 - probs1_a
     probs0_b = 1 - probs1_b
-    expect_kl_loss = np.log(probs1_a / probs1_b) + (probs0_a / probs1_a) * np.log(probs0_a / probs0_b)
+    expect_kl_loss = np.log(probs1_a / probs1_b) + \
+        (probs0_a / probs1_a) * np.log(probs0_a / probs0_b)
     kl_loss = KL()
     output = kl_loss(Tensor([probs1_b], dtype=dtype.float32))
     tol = 1e-6
     assert (np.abs(output.asnumpy() - expect_kl_loss) < tol).all()
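
For reference, the closed form checked by test_kl_loss above, with p_a = probs1_a and p_b = probs1_b:

\mathrm{KL}\left(\mathrm{Geo}(p_a)\,\|\,\mathrm{Geo}(p_b)\right)
    = \log\frac{p_a}{p_b} + \frac{1 - p_a}{p_a} \log\frac{1 - p_a}{1 - p_b}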
class Basics(nn.Cell): class Basics(nn.Cell):
""" """
Test class: mean/sd/mode of Geometric distribution. Test class: mean/sd/mode of Geometric distribution.
""" """
def __init__(self): def __init__(self):
super(Basics, self).__init__() super(Basics, self).__init__()
self.g = msd.Geometric([0.5, 0.5], dtype=dtype.int32) self.g = msd.Geometric([0.5, 0.5], dtype=dtype.int32)
@ -105,6 +119,7 @@ class Basics(nn.Cell):
def construct(self): def construct(self):
return self.g.mean(), self.g.sd(), self.g.mode() return self.g.mean(), self.g.sd(), self.g.mode()
def test_basics(): def test_basics():
""" """
Test mean/standard deviation/mode. Test mean/standard deviation/mode.
@ -115,14 +130,16 @@ def test_basics():
expect_sd = np.sqrt(np.array([0.5, 0.5]) / np.square(np.array([0.5, 0.5]))) expect_sd = np.sqrt(np.array([0.5, 0.5]) / np.square(np.array([0.5, 0.5])))
expect_mode = [0.0, 0.0] expect_mode = [0.0, 0.0]
tol = 1e-6 tol = 1e-6
assert (np.abs(mean.asnumpy()- expect_mean) < tol).all() assert (np.abs(mean.asnumpy() - expect_mean) < tol).all()
assert (np.abs(sd.asnumpy() - expect_sd) < tol).all() assert (np.abs(sd.asnumpy() - expect_sd) < tol).all()
assert (np.abs(mode.asnumpy() - expect_mode) < tol).all() assert (np.abs(mode.asnumpy() - expect_mode) < tol).all()
class Sampling(nn.Cell): class Sampling(nn.Cell):
""" """
Test class: log probability of bernoulli distribution. Test class: log probability of bernoulli distribution.
""" """
def __init__(self, shape, seed=0): def __init__(self, shape, seed=0):
super(Sampling, self).__init__() super(Sampling, self).__init__()
self.g = msd.Geometric([0.7, 0.5], seed=seed, dtype=dtype.int32) self.g = msd.Geometric([0.7, 0.5], seed=seed, dtype=dtype.int32)
@ -131,6 +148,7 @@ class Sampling(nn.Cell):
def construct(self, probs=None): def construct(self, probs=None):
return self.g.sample(self.shape, probs) return self.g.sample(self.shape, probs)
def test_sample(): def test_sample():
""" """
Test sample. Test sample.
@ -140,10 +158,12 @@ def test_sample():
output = sample() output = sample()
assert output.shape == (2, 3, 2) assert output.shape == (2, 3, 2)
class CDF(nn.Cell): class CDF(nn.Cell):
""" """
Test class: cdf of Geometric distribution. Test class: cdf of Geometric distribution.
""" """
def __init__(self): def __init__(self):
super(CDF, self).__init__() super(CDF, self).__init__()
self.g = msd.Geometric(0.7, dtype=dtype.int32) self.g = msd.Geometric(0.7, dtype=dtype.int32)
@ -151,22 +171,26 @@ class CDF(nn.Cell):
def construct(self, x_): def construct(self, x_):
return self.g.cdf(x_) return self.g.cdf(x_)
def test_cdf(): def test_cdf():
""" """
Test cdf. Test cdf.
""" """
geom_benchmark = stats.geom(0.7) geom_benchmark = stats.geom(0.7)
expect_cdf = geom_benchmark.cdf([0, 1, 2, 3, 4]).astype(np.float32) expect_cdf = geom_benchmark.cdf([0, 1, 2, 3, 4]).astype(np.float32)
x_ = Tensor(np.array([-1, 0, 1, 2, 3]).astype(np.int32), dtype=dtype.float32) x_ = Tensor(np.array([-1, 0, 1, 2, 3]
).astype(np.int32), dtype=dtype.float32)
cdf = CDF() cdf = CDF()
output = cdf(x_) output = cdf(x_)
tol = 1e-6 tol = 1e-6
assert (np.abs(output.asnumpy() - expect_cdf) < tol).all() assert (np.abs(output.asnumpy() - expect_cdf) < tol).all()
class LogCDF(nn.Cell): class LogCDF(nn.Cell):
""" """
Test class: log cdf of Geometric distribution. Test class: log cdf of Geometric distribution.
""" """
def __init__(self): def __init__(self):
super(LogCDF, self).__init__() super(LogCDF, self).__init__()
self.g = msd.Geometric(0.7, dtype=dtype.int32) self.g = msd.Geometric(0.7, dtype=dtype.int32)
@ -174,22 +198,26 @@ class LogCDF(nn.Cell):
def construct(self, x_): def construct(self, x_):
return self.g.log_cdf(x_) return self.g.log_cdf(x_)
def test_logcdf(): def test_logcdf():
""" """
Test log_cdf. Test log_cdf.
""" """
geom_benchmark = stats.geom(0.7) geom_benchmark = stats.geom(0.7)
expect_logcdf = geom_benchmark.logcdf([1, 2, 3, 4, 5]).astype(np.float32) expect_logcdf = geom_benchmark.logcdf([1, 2, 3, 4, 5]).astype(np.float32)
x_ = Tensor(np.array([0, 1, 2, 3, 4]).astype(np.int32), dtype=dtype.float32) x_ = Tensor(np.array([0, 1, 2, 3, 4]).astype(
np.int32), dtype=dtype.float32)
logcdf = LogCDF() logcdf = LogCDF()
output = logcdf(x_) output = logcdf(x_)
tol = 1e-6 tol = 1e-6
assert (np.abs(output.asnumpy() - expect_logcdf) < tol).all() assert (np.abs(output.asnumpy() - expect_logcdf) < tol).all()
class SF(nn.Cell): class SF(nn.Cell):
""" """
Test class: survial funciton of Geometric distribution. Test class: survival function of Geometric distribution.
""" """
def __init__(self): def __init__(self):
super(SF, self).__init__() super(SF, self).__init__()
self.g = msd.Geometric(0.7, dtype=dtype.int32) self.g = msd.Geometric(0.7, dtype=dtype.int32)
@ -197,22 +225,26 @@ class SF(nn.Cell):
def construct(self, x_): def construct(self, x_):
return self.g.survival_function(x_) return self.g.survival_function(x_)
def test_survival(): def test_survival():
""" """
Test survival function. Test survival function.
""" """
geom_benchmark = stats.geom(0.7) geom_benchmark = stats.geom(0.7)
expect_survival = geom_benchmark.sf([0, 1, 2, 3, 4]).astype(np.float32) expect_survival = geom_benchmark.sf([0, 1, 2, 3, 4]).astype(np.float32)
x_ = Tensor(np.array([-1, 0, 1, 2, 3]).astype(np.int32), dtype=dtype.float32) x_ = Tensor(np.array([-1, 0, 1, 2, 3]
).astype(np.int32), dtype=dtype.float32)
sf = SF() sf = SF()
output = sf(x_) output = sf(x_)
tol = 1e-6 tol = 1e-6
assert (np.abs(output.asnumpy() - expect_survival) < tol).all() assert (np.abs(output.asnumpy() - expect_survival) < tol).all()
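The survival benchmark follows the same shifted pattern, since P(X > k) = (1 - p)^(k + 1) in the failures-based parameterization; a quick check (sketch):

import numpy as np
from scipy import stats

p = 0.7
k = np.array([0, 1, 2, 3])
assert np.allclose((1 - p) ** (k + 1), stats.geom(p).sf(k + 1))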
class LogSF(nn.Cell): class LogSF(nn.Cell):
""" """
Test class: log survial funciton of Geometric distribution. Test class: log survival function of Geometric distribution.
""" """
def __init__(self): def __init__(self):
super(LogSF, self).__init__() super(LogSF, self).__init__()
self.g = msd.Geometric(0.7, dtype=dtype.int32) self.g = msd.Geometric(0.7, dtype=dtype.int32)
@ -220,22 +252,27 @@ class LogSF(nn.Cell):
def construct(self, x_): def construct(self, x_):
return self.g.log_survival(x_) return self.g.log_survival(x_)
def test_log_survival(): def test_log_survival():
""" """
Test log_survival function. Test log_survival function.
""" """
geom_benchmark = stats.geom(0.7) geom_benchmark = stats.geom(0.7)
expect_logsurvival = geom_benchmark.logsf([0, 1, 2, 3, 4]).astype(np.float32) expect_logsurvival = geom_benchmark.logsf(
x_ = Tensor(np.array([-1, 0, 1, 2, 3]).astype(np.float32), dtype=dtype.float32) [0, 1, 2, 3, 4]).astype(np.float32)
x_ = Tensor(np.array([-1, 0, 1, 2, 3]
).astype(np.float32), dtype=dtype.float32)
log_sf = LogSF() log_sf = LogSF()
output = log_sf(x_) output = log_sf(x_)
tol = 5e-6 tol = 5e-6
assert (np.abs(output.asnumpy() - expect_logsurvival) < tol).all() assert (np.abs(output.asnumpy() - expect_logsurvival) < tol).all()
class EntropyH(nn.Cell): class EntropyH(nn.Cell):
""" """
Test class: entropy of Geometric distribution. Test class: entropy of Geometric distribution.
""" """
def __init__(self): def __init__(self):
super(EntropyH, self).__init__() super(EntropyH, self).__init__()
self.g = msd.Geometric(0.7, dtype=dtype.int32) self.g = msd.Geometric(0.7, dtype=dtype.int32)
@ -243,6 +280,7 @@ class EntropyH(nn.Cell):
def construct(self): def construct(self):
return self.g.entropy() return self.g.entropy()
def test_entropy(): def test_entropy():
""" """
Test entropy. Test entropy.
@ -254,10 +292,12 @@ def test_entropy():
tol = 1e-6 tol = 1e-6
assert (np.abs(output.asnumpy() - expect_entropy) < tol).all() assert (np.abs(output.asnumpy() - expect_entropy) < tol).all()
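The expected entropy can also be checked against the closed form H(p) = (-(1 - p) log(1 - p) - p log p) / p, which holds for either parameterization since entropy is shift-invariant. Sketch:

import numpy as np
from scipy import stats

p = 0.7
h = (-(1 - p) * np.log(1 - p) - p * np.log(p)) / p
assert np.isclose(h, stats.geom(p).entropy())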
class CrossEntropy(nn.Cell): class CrossEntropy(nn.Cell):
""" """
Test class: cross entropy between Geometric distributions. Test class: cross entropy between Geometric distributions.
""" """
def __init__(self): def __init__(self):
super(CrossEntropy, self).__init__() super(CrossEntropy, self).__init__()
self.g = msd.Geometric(0.7, dtype=dtype.int32) self.g = msd.Geometric(0.7, dtype=dtype.int32)
@ -269,6 +309,7 @@ class CrossEntropy(nn.Cell):
ans = self.g.cross_entropy('Geometric', x_) ans = self.g.cross_entropy('Geometric', x_)
return h_sum_kl - ans return h_sum_kl - ans
def test_cross_entropy(): def test_cross_entropy():
""" """
Test cross_entropy. Test cross_entropy.
View File
@ -43,20 +43,29 @@ def create_dataset(data_path, batch_size=32, repeat_size=1,
shift = 0.0 shift = 0.0
# define map operations # define map operations
resize_op = CV.Resize((resize_height, resize_width), interpolation=Inter.LINEAR) # resize images to (32, 32) # resize images to (32, 32)
resize_op = CV.Resize((resize_height, resize_width),
interpolation=Inter.LINEAR)
rescale_op = CV.Rescale(rescale, shift) # rescale images rescale_op = CV.Rescale(rescale, shift) # rescale images
hwc2chw_op = CV.HWC2CHW() # change shape from (height, width, channel) to (channel, height, width) to fit network. # change shape from (height, width, channel) to (channel, height, width) to fit network.
type_cast_op = C.TypeCast(mstype.int32) # change data type of label to int32 to fit network hwc2chw_op = CV.HWC2CHW()
# change data type of label to int32 to fit network
type_cast_op = C.TypeCast(mstype.int32)
# apply map operations on images # apply map operations on images
mnist_ds = mnist_ds.map(input_columns="label", operations=type_cast_op, num_parallel_workers=num_parallel_workers) mnist_ds = mnist_ds.map(input_columns="label", operations=type_cast_op,
mnist_ds = mnist_ds.map(input_columns="image", operations=resize_op, num_parallel_workers=num_parallel_workers) num_parallel_workers=num_parallel_workers)
mnist_ds = mnist_ds.map(input_columns="image", operations=rescale_op, num_parallel_workers=num_parallel_workers) mnist_ds = mnist_ds.map(input_columns="image", operations=resize_op,
mnist_ds = mnist_ds.map(input_columns="image", operations=hwc2chw_op, num_parallel_workers=num_parallel_workers) num_parallel_workers=num_parallel_workers)
mnist_ds = mnist_ds.map(input_columns="image", operations=rescale_op,
num_parallel_workers=num_parallel_workers)
mnist_ds = mnist_ds.map(input_columns="image", operations=hwc2chw_op,
num_parallel_workers=num_parallel_workers)
# apply DatasetOps # apply DatasetOps
buffer_size = 10000 buffer_size = 10000
mnist_ds = mnist_ds.shuffle(buffer_size=buffer_size) # 10000 as in LeNet train script # 10000 as in LeNet train script
mnist_ds = mnist_ds.shuffle(buffer_size=buffer_size)
mnist_ds = mnist_ds.batch(batch_size, drop_remainder=True) mnist_ds = mnist_ds.batch(batch_size, drop_remainder=True)
mnist_ds = mnist_ds.repeat(repeat_size) mnist_ds = mnist_ds.repeat(repeat_size)
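A minimal usage sketch of the pipeline above (the data path is hypothetical):

# hypothetical path to an extracted MNIST train split
ds_train = create_dataset("./datasets/MNIST/train", batch_size=32, repeat_size=1)
for batch in ds_train.create_dict_iterator():
    # images are (32, 1, 32, 32) after resize/rescale/HWC2CHW, labels are (32,)
    print(batch["image"].shape, batch["label"].shape)
    break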
@ -68,7 +77,7 @@ def save_img(data, name, size=32, num=32):
Visualize data and save to target files Visualize data and save to target files
Args: Args:
data: nparray of size (num, size, size) data: nparray of size (num, size, size)
name: ouput file name name: output file name
size: image size size: image size
num: number of images num: number of images
""" """
View File
@ -62,15 +62,17 @@ def test_prob():
with pytest.raises(ValueError): with pytest.raises(ValueError):
msd.Categorical([1.0], dtype=dtype.int32) msd.Categorical([1.0], dtype=dtype.int32)
def test_categorical_sum(): def test_categorical_sum():
""" """
Invaild probabilities. Invalid probabilities.
""" """
with pytest.raises(ValueError): with pytest.raises(ValueError):
msd.Categorical([[0.1, 0.2], [0.4, 0.6]], dtype=dtype.int32) msd.Categorical([[0.1, 0.2], [0.4, 0.6]], dtype=dtype.int32)
with pytest.raises(ValueError): with pytest.raises(ValueError):
msd.Categorical([[0.5, 0.7], [0.6, 0.6]], dtype=dtype.int32) msd.Categorical([[0.5, 0.7], [0.6, 0.6]], dtype=dtype.int32)
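Both cases are rejected because each row of probs must sum to exactly 1; a valid construction for contrast (sketch):

import mindspore.nn.probability.distribution as msd
from mindspore import dtype

# every event row sums to 1, so validation passes
c = msd.Categorical([[0.4, 0.6], [0.3, 0.7]], dtype=dtype.int32)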
def rank(): def rank():
""" """
Rank dimension less than 1. Rank dimension less than 1.
@ -80,7 +82,9 @@ def rank():
with pytest.raises(ValueError): with pytest.raises(ValueError):
msd.Categorical(np.array(0.3).astype(np.float32), dtype=dtype.int32) msd.Categorical(np.array(0.3).astype(np.float32), dtype=dtype.int32)
with pytest.raises(ValueError): with pytest.raises(ValueError):
msd.Categorical(Tensor(np.array(0.3).astype(np.float32)), dtype=dtype.int32) msd.Categorical(
Tensor(np.array(0.3).astype(np.float32)), dtype=dtype.int32)
class CategoricalProb(nn.Cell): class CategoricalProb(nn.Cell):
""" """
@ -211,6 +215,7 @@ class CategoricalConstruct(nn.Cell):
prob2 = self.c1('prob', value, probs) prob2 = self.c1('prob', value, probs)
return prob + prob1 + prob2 return prob + prob1 + prob2
def test_categorical_construct(): def test_categorical_construct():
""" """
Test probability function going through construct. Test probability function going through construct.
@ -235,7 +240,7 @@ class CategoricalBasics(nn.Cell):
def construct(self, probs): def construct(self, probs):
basics1 = self.c.mean() + self.c.var() + self.c.mode() + self.c.entropy() basics1 = self.c.mean() + self.c.var() + self.c.mode() + self.c.entropy()
basics2 = self.c1.mean(probs) + self.c1.var(probs) +\ basics2 = self.c1.mean(probs) + self.c1.var(probs) +\
self.c1.mode(probs) + self.c1.entropy(probs) self.c1.mode(probs) + self.c1.entropy(probs)
return basics1 + basics2 return basics1 + basics2
View File
@ -29,19 +29,23 @@ func_name_list = ['prob', 'log_prob', 'cdf', 'log_cdf',
'entropy', 'kl_loss', 'cross_entropy', 'entropy', 'kl_loss', 'cross_entropy',
'sample'] 'sample']
class MyExponential(msd.Distribution): class MyExponential(msd.Distribution):
""" """
Test distirbution class: no function is implemented. Test distribution class: no function is implemented.
""" """
def __init__(self, rate=None, seed=None, dtype=mstype.float32, name="MyExponential"): def __init__(self, rate=None, seed=None, dtype=mstype.float32, name="MyExponential"):
param = dict(locals()) param = dict(locals())
param['param_dict'] = {'rate': rate} param['param_dict'] = {'rate': rate}
super(MyExponential, self).__init__(seed, dtype, name, param) super(MyExponential, self).__init__(seed, dtype, name, param)
class Net(nn.Cell): class Net(nn.Cell):
""" """
Test Net: function called through construct. Test Net: function called through construct.
""" """
def __init__(self, func_name): def __init__(self, func_name):
super(Net, self).__init__() super(Net, self).__init__()
self.dist = MyExponential() self.dist = MyExponential()
@ -61,6 +65,7 @@ def test_raise_not_implemented_error_construct():
net = Net(func_name) net = Net(func_name)
net(value) net(value)
def test_raise_not_implemented_error_construct_graph_mode(): def test_raise_not_implemented_error_construct_graph_mode():
""" """
test raise not implemented error in graph mode. test raise not implemented error in graph mode.
@ -72,10 +77,12 @@ def test_raise_not_implemented_error_construct_graph_mode():
net = Net(func_name) net = Net(func_name)
net(value) net(value)
class Net1(nn.Cell): class Net1(nn.Cell):
""" """
Test Net: function called directly. Test Net: function called directly.
""" """
def __init__(self, func_name): def __init__(self, func_name):
super(Net1, self).__init__() super(Net1, self).__init__()
self.dist = MyExponential() self.dist = MyExponential()
@ -84,6 +91,7 @@ class Net1(nn.Cell):
def construct(self, *args, **kwargs): def construct(self, *args, **kwargs):
return self.func(*args, **kwargs) return self.func(*args, **kwargs)
def test_raise_not_implemented_error(): def test_raise_not_implemented_error():
""" """
test raise not implemented error in pynative mode. test raise not implemented error in pynative mode.
@ -94,6 +102,7 @@ def test_raise_not_implemented_error():
net = Net1(func_name) net = Net1(func_name)
net(value) net(value)
def test_raise_not_implemented_error_graph_mode(): def test_raise_not_implemented_error_graph_mode():
""" """
test raise not implemented error in graph mode. test raise not implemented error in graph mode.
View File
@ -23,6 +23,7 @@ import mindspore.nn.probability.distribution as msd
from mindspore import dtype from mindspore import dtype
from mindspore import Tensor from mindspore import Tensor
def test_uniform_shape_error(): def test_uniform_shape_error():
""" """
Invalid shapes. Invalid shapes.
@ -30,18 +31,22 @@ def test_uniform_shape_errpr():
with pytest.raises(ValueError): with pytest.raises(ValueError):
msd.Uniform([[2.], [1.]], [[2.], [3.], [4.]], dtype=dtype.float32) msd.Uniform([[2.], [1.]], [[2.], [3.], [4.]], dtype=dtype.float32)
def test_type(): def test_type():
with pytest.raises(TypeError): with pytest.raises(TypeError):
msd.Uniform(0., 1., dtype=dtype.int32) msd.Uniform(0., 1., dtype=dtype.int32)
def test_name(): def test_name():
with pytest.raises(TypeError): with pytest.raises(TypeError):
msd.Uniform(0., 1., name=1.0) msd.Uniform(0., 1., name=1.0)
def test_seed(): def test_seed():
with pytest.raises(TypeError): with pytest.raises(TypeError):
msd.Uniform(0., 1., seed='seed') msd.Uniform(0., 1., seed='seed')
def test_arguments(): def test_arguments():
""" """
Args passing during initialization. Args passing during initialization.
@ -66,6 +71,7 @@ class UniformProb(nn.Cell):
""" """
Uniform distribution: initialize with low/high. Uniform distribution: initialize with low/high.
""" """
def __init__(self): def __init__(self):
super(UniformProb, self).__init__() super(UniformProb, self).__init__()
self.u = msd.Uniform(3.0, 4.0, dtype=dtype.float32) self.u = msd.Uniform(3.0, 4.0, dtype=dtype.float32)
@ -79,6 +85,7 @@ class UniformProb(nn.Cell):
log_sf = self.u.log_survival(value) log_sf = self.u.log_survival(value)
return prob + log_prob + cdf + log_cdf + sf + log_sf return prob + log_prob + cdf + log_cdf + sf + log_sf
def test_uniform_prob(): def test_uniform_prob():
""" """
Test probability functions: passing value through construct. Test probability functions: passing value through construct.
@ -88,10 +95,12 @@ def test_uniform_prob():
ans = net(value) ans = net(value)
assert isinstance(ans, Tensor) assert isinstance(ans, Tensor)
class UniformProb1(nn.Cell): class UniformProb1(nn.Cell):
""" """
Uniform distribution: initialize without low/high. Uniform distribution: initialize without low/high.
""" """
def __init__(self): def __init__(self):
super(UniformProb1, self).__init__() super(UniformProb1, self).__init__()
self.u = msd.Uniform(dtype=dtype.float32) self.u = msd.Uniform(dtype=dtype.float32)
@ -105,6 +114,7 @@ class UniformProb1(nn.Cell):
log_sf = self.u.log_survival(value, low, high) log_sf = self.u.log_survival(value, low, high)
return prob + log_prob + cdf + log_cdf + sf + log_sf return prob + log_prob + cdf + log_cdf + sf + log_sf
def test_uniform_prob1(): def test_uniform_prob1():
""" """
Test probability functions: passing low/high, value through construct. Test probability functions: passing low/high, value through construct.
@ -116,13 +126,16 @@ def test_uniform_prob1():
ans = net(value, low, high) ans = net(value, low, high)
assert isinstance(ans, Tensor) assert isinstance(ans, Tensor)
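The two styles are interchangeable: parameters fixed at initialization, or supplied per call on a parameter-free instance. A pynative-mode sketch (values hypothetical):

import numpy as np
import mindspore.nn.probability.distribution as msd
from mindspore import Tensor, dtype

u = msd.Uniform(dtype=dtype.float32)                  # no low/high at init
value = Tensor(np.array([3.5]), dtype=dtype.float32)
low = Tensor(np.array([3.0]), dtype=dtype.float32)
high = Tensor(np.array([4.0]), dtype=dtype.float32)
print(u.prob(value, low, high))                       # density 1.0 on U[3, 4]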
class UniformKl(nn.Cell): class UniformKl(nn.Cell):
""" """
Test class: kl_loss of Uniform distribution. Test class: kl_loss of Uniform distribution.
""" """
def __init__(self): def __init__(self):
super(UniformKl, self).__init__() super(UniformKl, self).__init__()
self.u1 = msd.Uniform(np.array([3.0]), np.array([4.0]), dtype=dtype.float32) self.u1 = msd.Uniform(
np.array([3.0]), np.array([4.0]), dtype=dtype.float32)
self.u2 = msd.Uniform(dtype=dtype.float32) self.u2 = msd.Uniform(dtype=dtype.float32)
def construct(self, low_b, high_b, low_a, high_a): def construct(self, low_b, high_b, low_a, high_a):
@ -130,6 +143,7 @@ class UniformKl(nn.Cell):
kl2 = self.u2.kl_loss('Uniform', low_b, high_b, low_a, high_a) kl2 = self.u2.kl_loss('Uniform', low_b, high_b, low_a, high_a)
return kl1 + kl2 return kl1 + kl2
def test_kl(): def test_kl():
""" """
Test kl_loss. Test kl_loss.
@ -142,13 +156,16 @@ def test_kl():
ans = net(low_b, high_b, low_a, high_a) ans = net(low_b, high_b, low_a, high_a)
assert isinstance(ans, Tensor) assert isinstance(ans, Tensor)
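For uniforms with the support of a contained in that of b, the expected value reduces to KL(a || b) = log((high_b - low_b) / (high_a - low_a)); a numpy check with hypothetical bounds:

import numpy as np

la, ha, lb, hb = 3.0, 4.0, 0.0, 5.0   # hypothetical bounds, [la, ha] within [lb, hb]
kl = np.log((hb - lb) / (ha - la))
print(kl)                             # ~1.6094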
class UniformCrossEntropy(nn.Cell): class UniformCrossEntropy(nn.Cell):
""" """
Test class: cross_entropy of Uniform distribution. Test class: cross_entropy of Uniform distribution.
""" """
def __init__(self): def __init__(self):
super(UniformCrossEntropy, self).__init__() super(UniformCrossEntropy, self).__init__()
self.u1 = msd.Uniform(np.array([3.0]), np.array([4.0]), dtype=dtype.float32) self.u1 = msd.Uniform(
np.array([3.0]), np.array([4.0]), dtype=dtype.float32)
self.u2 = msd.Uniform(dtype=dtype.float32) self.u2 = msd.Uniform(dtype=dtype.float32)
def construct(self, low_b, high_b, low_a, high_a): def construct(self, low_b, high_b, low_a, high_a):
@ -156,9 +173,10 @@ class UniformCrossEntropy(nn.Cell):
h2 = self.u2.cross_entropy('Uniform', low_b, high_b, low_a, high_a) h2 = self.u2.cross_entropy('Uniform', low_b, high_b, low_a, high_a)
return h1 + h2 return h1 + h2
def test_cross_entropy(): def test_cross_entropy():
""" """
Test cross_entropy between Unifrom distributions. Test cross_entropy between Uniform distributions.
""" """
net = UniformCrossEntropy() net = UniformCrossEntropy()
low_b = Tensor(np.array([0.0]).astype(np.float32), dtype=dtype.float32) low_b = Tensor(np.array([0.0]).astype(np.float32), dtype=dtype.float32)
@ -168,10 +186,12 @@ def test_cross_entropy():
ans = net(low_b, high_b, low_a, high_a) ans = net(low_b, high_b, low_a, high_a)
assert isinstance(ans, Tensor) assert isinstance(ans, Tensor)
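Cross entropy decomposes as H(a, b) = H(a) + KL(a || b); with H(a) = log(high_a - low_a) for a uniform, the terms collapse to log(high_b - low_b). Sketch with the same hypothetical bounds:

import numpy as np

la, ha, lb, hb = 3.0, 4.0, 0.0, 5.0
cross_entropy = np.log(ha - la) + np.log((hb - lb) / (ha - la))
assert np.isclose(cross_entropy, np.log(hb - lb))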
class UniformBasics(nn.Cell): class UniformBasics(nn.Cell):
""" """
Test class: basic mean/sd/var/mode/entropy function. Test class: basic mean/sd/var/mode/entropy function.
""" """
def __init__(self): def __init__(self):
super(UniformBasics, self).__init__() super(UniformBasics, self).__init__()
self.u = msd.Uniform(3.0, 4.0, dtype=dtype.float32) self.u = msd.Uniform(3.0, 4.0, dtype=dtype.float32)
@ -183,6 +203,7 @@ class UniformBasics(nn.Cell):
entropy = self.u.entropy() entropy = self.u.entropy()
return mean + sd + var + entropy return mean + sd + var + entropy
def test_basics(): def test_basics():
""" """
Test mean/sd/var/mode/entropy functionality of Uniform. Test mean/sd/var/mode/entropy functionality of Uniform.
@ -194,8 +215,9 @@ def test_bascis():
class UniConstruct(nn.Cell): class UniConstruct(nn.Cell):
""" """
Unifrom distribution: going through construct. Uniform distribution: going through construct.
""" """
def __init__(self): def __init__(self):
super(UniConstruct, self).__init__() super(UniConstruct, self).__init__()
self.u = msd.Uniform(-4.0, 4.0) self.u = msd.Uniform(-4.0, 4.0)
@ -207,6 +229,7 @@ class UniConstruct(nn.Cell):
prob2 = self.u1('prob', value, low, high) prob2 = self.u1('prob', value, low, high)
return prob + prob1 + prob2 return prob + prob1 + prob2
def test_uniform_construct(): def test_uniform_construct():
""" """
Test probability function going through construct. Test probability function going through construct.