forked from mindspore-Ecosystem/mindspore
!6020 Add common check_param_type and set_param_type in distribution
Merge pull request !6020 from XunDeng/new_check_param
This commit is contained in:
commit
c10341dfb7
|
@ -110,6 +110,8 @@ def check_scalar_from_param(params):
|
|||
Notes: String parameters are excluded.
|
||||
"""
|
||||
for value in params.values():
|
||||
if value is None:
|
||||
continue
|
||||
if isinstance(value, (msp.bijector.Bijector, msp.distribution.Distribution)):
|
||||
return params['distribution'].is_scalar_batch
|
||||
if isinstance(value, Parameter):
|
||||
|
@ -358,23 +360,29 @@ class CheckTensor(PrimitiveWithInfer):
|
|||
return x
|
||||
raise TypeError(f"For {name}, input type should be a Tensor or Parameter.")
|
||||
|
||||
def common_dtype(arg_a, name_a, arg_b, name_b, hint_type):
|
||||
def set_param_type(args, hint_type):
|
||||
"""
|
||||
check if arg_a and arg_b have the same dtype.
|
||||
Find the common type among arguments.
|
||||
|
||||
Args:
|
||||
args (dict): dictionary of arguments, {'name':value}.
|
||||
hint_type (mindspore.dtype): hint type to return.
|
||||
|
||||
Raises:
|
||||
TypeError: if tensors in args are not the same dtype.
|
||||
"""
|
||||
if hasattr(arg_a, 'dtype') and hasattr(arg_b, 'dtype'):
|
||||
if isinstance(arg_a, np.ndarray):
|
||||
a_dtype = mstype.pytype_to_dtype(arg_a.dtype)
|
||||
else:
|
||||
a_dtype = arg_a.dtype
|
||||
if isinstance(arg_b, np.ndarray):
|
||||
b_dtype = mstype.pytype_to_dtype(arg_b.dtype)
|
||||
else:
|
||||
b_dtype = arg_b.dtype
|
||||
if a_dtype != b_dtype:
|
||||
raise TypeError(f"{name_a} and {name_b} should have the same dtype.")
|
||||
int_type = mstype.int_type + mstype.uint_type
|
||||
if a_dtype in int_type or a_dtype == mstype.float64:
|
||||
return mstype.float32
|
||||
return a_dtype
|
||||
return hint_type
|
||||
common_dtype = None
|
||||
for name, arg in args.items():
|
||||
if hasattr(arg, 'dtype'):
|
||||
if isinstance(arg, np.ndarray):
|
||||
cur_dtype = mstype.pytype_to_dtype(arg.dtype)
|
||||
else:
|
||||
cur_dtype = arg.dtype
|
||||
if common_dtype is None:
|
||||
common_dtype = cur_dtype
|
||||
elif cur_dtype != common_dtype:
|
||||
raise TypeError(f"{name} should have the same dtype as other arguments.")
|
||||
int_type = mstype.int_type + mstype.uint_type
|
||||
if common_dtype in int_type or common_dtype == mstype.float64:
|
||||
return mstype.float32
|
||||
return hint_type if common_dtype is None else common_dtype
|
||||
|
|
|
@ -17,7 +17,7 @@ from mindspore.common import dtype as mstype
|
|||
from mindspore.ops import operations as P
|
||||
from mindspore.ops import composite as C
|
||||
from .distribution import Distribution
|
||||
from ._utils.utils import cast_to_tensor, check_prob, check_type, check_distribution_name, raise_none_error
|
||||
from ._utils.utils import cast_to_tensor, check_prob, check_type, check_distribution_name, set_param_type
|
||||
from ._utils.custom_ops import exp_generic, log_generic, erf_generic
|
||||
|
||||
|
||||
|
@ -119,13 +119,16 @@ class Bernoulli(Distribution):
|
|||
valid_dtype = mstype.int_type + mstype.uint_type + mstype.float_type
|
||||
check_type(dtype, valid_dtype, type(self).__name__)
|
||||
super(Bernoulli, self).__init__(seed, dtype, name, param)
|
||||
self.parameter_type = mstype.float32
|
||||
self.parameter_type = set_param_type({'probs1': probs}, mstype.float32)
|
||||
if probs is not None:
|
||||
self._probs = cast_to_tensor(probs, mstype.float32)
|
||||
self._probs = cast_to_tensor(probs, self.parameter_type)
|
||||
check_prob(self.probs)
|
||||
else:
|
||||
self._probs = probs
|
||||
|
||||
self.default_parameters = [self.probs]
|
||||
self.parameter_names = ['probs1']
|
||||
|
||||
# ops needed for the class
|
||||
self.exp = exp_generic
|
||||
self.log = log_generic
|
||||
|
@ -157,24 +160,12 @@ class Bernoulli(Distribution):
|
|||
"""
|
||||
return self._probs
|
||||
|
||||
def _check_param(self, probs1):
|
||||
"""
|
||||
Check availablity of distribution specific args `probs1`.
|
||||
"""
|
||||
if probs1 is not None:
|
||||
if self.context_mode == 0:
|
||||
self.checktensor(probs1, 'probs1')
|
||||
else:
|
||||
probs1 = self.checktensor(probs1, 'probs1')
|
||||
return self.cast(probs1, self.parameter_type)
|
||||
return self.probs if self.probs is not None else raise_none_error('probs1')
|
||||
|
||||
def _mean(self, probs1=None):
|
||||
r"""
|
||||
.. math::
|
||||
MEAN(B) = probs1
|
||||
"""
|
||||
probs1 = self._check_param(probs1)
|
||||
probs1 = self._check_param_type(probs1)
|
||||
return probs1
|
||||
|
||||
def _mode(self, probs1=None):
|
||||
|
@ -182,7 +173,7 @@ class Bernoulli(Distribution):
|
|||
.. math::
|
||||
MODE(B) = 1 if probs1 > 0.5 else = 0
|
||||
"""
|
||||
probs1 = self._check_param(probs1)
|
||||
probs1 = self._check_param_type(probs1)
|
||||
prob_type = self.dtypeop(probs1)
|
||||
zeros = self.fill(prob_type, self.shape(probs1), 0.0)
|
||||
ones = self.fill(prob_type, self.shape(probs1), 1.0)
|
||||
|
@ -194,7 +185,7 @@ class Bernoulli(Distribution):
|
|||
.. math::
|
||||
VAR(B) = probs1 * probs0
|
||||
"""
|
||||
probs1 = self._check_param(probs1)
|
||||
probs1 = self._check_param_type(probs1)
|
||||
probs0 = 1.0 - probs1
|
||||
return self.exp(self.log(probs0) + self.log(probs1))
|
||||
|
||||
|
@ -203,11 +194,11 @@ class Bernoulli(Distribution):
|
|||
.. math::
|
||||
H(B) = -probs0 * \log(probs0) - probs1 * \log(probs1)
|
||||
"""
|
||||
probs1 = self._check_param(probs1)
|
||||
probs1 = self._check_param_type(probs1)
|
||||
probs0 = 1.0 - probs1
|
||||
return -(probs0 * self.log(probs0)) - (probs1 * self.log(probs1))
|
||||
|
||||
def _cross_entropy(self, dist, probs1_b, probs1_a=None):
|
||||
def _cross_entropy(self, dist, probs1_b, probs1=None):
|
||||
"""
|
||||
Evaluate cross_entropy between Bernoulli distributions.
|
||||
|
||||
|
@ -217,7 +208,7 @@ class Bernoulli(Distribution):
|
|||
probs1_a (Tensor): `probs1` of distribution a. Default: self.probs.
|
||||
"""
|
||||
check_distribution_name(dist, 'Bernoulli')
|
||||
return self._entropy(probs1_a) + self._kl_loss(dist, probs1_b, probs1_a)
|
||||
return self._entropy(probs1) + self._kl_loss(dist, probs1_b, probs1)
|
||||
|
||||
def _log_prob(self, value, probs1=None):
|
||||
r"""
|
||||
|
@ -233,7 +224,7 @@ class Bernoulli(Distribution):
|
|||
"""
|
||||
value = self._check_value(value, 'value')
|
||||
value = self.cast(value, mstype.float32)
|
||||
probs1 = self._check_param(probs1)
|
||||
probs1 = self._check_param_type(probs1)
|
||||
probs0 = 1.0 - probs1
|
||||
return self.log(probs1) * value + self.log(probs0) * (1.0 - value)
|
||||
|
||||
|
@ -253,7 +244,7 @@ class Bernoulli(Distribution):
|
|||
value = self._check_value(value, 'value')
|
||||
value = self.cast(value, mstype.float32)
|
||||
value = self.floor(value)
|
||||
probs1 = self._check_param(probs1)
|
||||
probs1 = self._check_param_type(probs1)
|
||||
prob_type = self.dtypeop(probs1)
|
||||
value = value * self.fill(prob_type, self.shape(probs1), 1.0)
|
||||
probs0 = 1.0 - probs1 * self.fill(prob_type, self.shape(value), 1.0)
|
||||
|
@ -264,7 +255,7 @@ class Bernoulli(Distribution):
|
|||
less_than_zero = self.select(comp_zero, zeros, probs0)
|
||||
return self.select(comp_one, less_than_zero, ones)
|
||||
|
||||
def _kl_loss(self, dist, probs1_b, probs1_a=None):
|
||||
def _kl_loss(self, dist, probs1_b, probs1=None):
|
||||
r"""
|
||||
Evaluate bernoulli-bernoulli kl divergence, i.e. KL(a||b).
|
||||
|
||||
|
@ -280,7 +271,7 @@ class Bernoulli(Distribution):
|
|||
check_distribution_name(dist, 'Bernoulli')
|
||||
probs1_b = self._check_value(probs1_b, 'probs1_b')
|
||||
probs1_b = self.cast(probs1_b, self.parameter_type)
|
||||
probs1_a = self._check_param(probs1_a)
|
||||
probs1_a = self._check_param_type(probs1)
|
||||
probs0_a = 1.0 - probs1_a
|
||||
probs0_b = 1.0 - probs1_b
|
||||
return probs1_a * self.log(probs1_a / probs1_b) + probs0_a * self.log(probs0_a / probs0_b)
|
||||
|
@ -297,7 +288,7 @@ class Bernoulli(Distribution):
|
|||
Tensor, shape is shape + batch_shape.
|
||||
"""
|
||||
shape = self.checktuple(shape, 'shape')
|
||||
probs1 = self._check_param(probs1)
|
||||
probs1 = self._check_param_type(probs1)
|
||||
origin_shape = shape + self.shape(probs1)
|
||||
if origin_shape == ():
|
||||
sample_shape = (1,)
|
||||
|
|
|
@ -18,7 +18,8 @@ from mindspore.nn.cell import Cell
|
|||
from mindspore._checkparam import Validator as validator
|
||||
from mindspore._checkparam import Rel
|
||||
from mindspore.common import get_seed
|
||||
from ._utils.utils import calc_broadcast_shape_from_param, check_scalar_from_param, cast_type_for_device
|
||||
from ._utils.utils import calc_broadcast_shape_from_param, check_scalar_from_param, cast_type_for_device,\
|
||||
raise_none_error
|
||||
from ._utils.utils import CheckTuple, CheckTensor
|
||||
|
||||
|
||||
|
@ -115,6 +116,51 @@ class Distribution(Cell):
|
|||
def broadcast_shape(self):
|
||||
return self._broadcast_shape
|
||||
|
||||
def _check_param_type(self, *args):
|
||||
"""
|
||||
Check the availability and validity of default parameters and dist_spec_args.
|
||||
dist_spec_args passed in must be tensors. If default parameter of the distribution
|
||||
is None, its parameter must be passed in through `args`.
|
||||
"""
|
||||
broadcast_shape = None
|
||||
common_dtype = None
|
||||
out = []
|
||||
|
||||
for arg, name, default in zip(args, self.parameter_names, self.default_parameters):
|
||||
# check if the argument is a Tensor
|
||||
if arg is not None:
|
||||
if self.context_mode == 0:
|
||||
self.checktensor(arg, name)
|
||||
else:
|
||||
arg = self.checktensor(arg, name)
|
||||
else:
|
||||
arg = default if default is not None else raise_none_error(name)
|
||||
|
||||
# broadcast if the number of args > 1
|
||||
if broadcast_shape is None:
|
||||
broadcast_shape = self.shape(arg)
|
||||
common_dtype = self.dtypeop(arg)
|
||||
else:
|
||||
ones = self.fill(self.dtypeop(arg), broadcast_shape, 1.0)
|
||||
broadcast_shape = self.shape(arg + ones)
|
||||
|
||||
# check if the arguments have the same dtype
|
||||
arg = arg * self.fill(self.dtypeop(arg), broadcast_shape, 1.0)
|
||||
dtype_tensor = self.fill(common_dtype, broadcast_shape, 1.0)
|
||||
self.sametypeshape(arg, dtype_tensor)
|
||||
arg = self.cast(arg, self.parameter_type)
|
||||
out.append(arg)
|
||||
|
||||
if len(out) == 1:
|
||||
return out[0]
|
||||
|
||||
# broadcast all args to broadcast_shape
|
||||
result = ()
|
||||
for arg in out:
|
||||
arg = arg * self.fill(self.dtypeop(arg), broadcast_shape, 1.0)
|
||||
result = result + (arg,)
|
||||
return result
|
||||
|
||||
def _check_value(self, value, name):
|
||||
"""
|
||||
Check availability of `value` as a Tensor.
|
||||
|
@ -211,163 +257,203 @@ class Distribution(Cell):
|
|||
if hasattr(self, '_cross_entropy'):
|
||||
self._call_cross_entropy = self._cross_entropy
|
||||
|
||||
def log_prob(self, *args, **kwargs):
|
||||
def log_prob(self, value, *args, **kwargs):
|
||||
"""
|
||||
Evaluate the log probability(pdf or pmf) at the given value.
|
||||
|
||||
Note:
|
||||
The argument `args` must include `value`.
|
||||
dist_spec_args are optional.
|
||||
"""
|
||||
return self._call_log_prob(*args, **kwargs)
|
||||
Args:
|
||||
value (Tensor): value to be evaluated.
|
||||
*args (list): the list of positional arguments forwarded to subclasses.
|
||||
**kwargs (dictionary): the dictionary of keyword arguments forwarded to subclasses.
|
||||
|
||||
def _calc_prob_from_log_prob(self, *args, **kwargs):
|
||||
Note:
|
||||
A distribution can be optionally passed to the function by passing its dist_spec_args through
|
||||
`args` or `kwargs`.
|
||||
"""
|
||||
return self._call_log_prob(value, *args, **kwargs)
|
||||
|
||||
def _calc_prob_from_log_prob(self, value, *args, **kwargs):
|
||||
r"""
|
||||
Evaluate prob from log probability.
|
||||
|
||||
.. math::
|
||||
probability(x) = \exp(log_likehood(x))
|
||||
"""
|
||||
return self.exp(self._log_prob(*args, **kwargs))
|
||||
return self.exp(self._log_prob(value, *args, **kwargs))
|
||||
|
||||
def prob(self, *args, **kwargs):
|
||||
def prob(self, value, *args, **kwargs):
|
||||
"""
|
||||
Evaluate the probability (pdf or pmf) at given value.
|
||||
|
||||
Note:
|
||||
The argument `args` must include `value`.
|
||||
dist_spec_args are optional.
|
||||
"""
|
||||
return self._call_prob(*args, **kwargs)
|
||||
Args:
|
||||
value (Tensor): value to be evaluated.
|
||||
*args (list): the list of positional arguments forwarded to subclasses.
|
||||
**kwargs (dictionary): the dictionary of keyword arguments forwarded to subclasses.
|
||||
|
||||
def _calc_log_prob_from_prob(self, *args, **kwargs):
|
||||
Note:
|
||||
A distribution can be optionally passed to the function by passing its dist_spec_args through
|
||||
`args` or `kwargs`.
|
||||
"""
|
||||
return self._call_prob(value, *args, **kwargs)
|
||||
|
||||
def _calc_log_prob_from_prob(self, value, *args, **kwargs):
|
||||
r"""
|
||||
Evaluate log probability from probability.
|
||||
|
||||
.. math::
|
||||
log_prob(x) = \log(prob(x))
|
||||
"""
|
||||
return self.log(self._prob(*args, **kwargs))
|
||||
return self.log(self._prob(value, *args, **kwargs))
|
||||
|
||||
def cdf(self, *args, **kwargs):
|
||||
def cdf(self, value, *args, **kwargs):
|
||||
"""
|
||||
Evaluate the cdf at given value.
|
||||
|
||||
Note:
|
||||
The argument `args` must include `value`.
|
||||
dist_spec_args are optional.
|
||||
"""
|
||||
return self._call_cdf(*args, **kwargs)
|
||||
Args:
|
||||
value (Tensor): value to be evaluated.
|
||||
*args (list): the list of positional arguments forwarded to subclasses.
|
||||
**kwargs (dictionary): the dictionary of keyword arguments forwarded to subclasses.
|
||||
|
||||
def _calc_cdf_from_log_cdf(self, *args, **kwargs):
|
||||
Note:
|
||||
A distribution can be optionally passed to the function by passing its dist_spec_args through
|
||||
`args` or `kwargs`.
|
||||
"""
|
||||
return self._call_cdf(value, *args, **kwargs)
|
||||
|
||||
def _calc_cdf_from_log_cdf(self, value, *args, **kwargs):
|
||||
r"""
|
||||
Evaluate cdf from log_cdf.
|
||||
|
||||
.. math::
|
||||
cdf(x) = \exp(log_cdf(x))
|
||||
"""
|
||||
return self.exp(self._log_cdf(*args, **kwargs))
|
||||
return self.exp(self._log_cdf(value, *args, **kwargs))
|
||||
|
||||
def _calc_cdf_from_survival(self, *args, **kwargs):
|
||||
def _calc_cdf_from_survival(self, value, *args, **kwargs):
|
||||
r"""
|
||||
Evaluate cdf from survival function.
|
||||
|
||||
.. math::
|
||||
cdf(x) = 1 - (survival_function(x))
|
||||
"""
|
||||
return 1.0 - self._survival_function(*args, **kwargs)
|
||||
return 1.0 - self._survival_function(value, *args, **kwargs)
|
||||
|
||||
def _calc_cdf_from_log_survival(self, *args, **kwargs):
|
||||
def _calc_cdf_from_log_survival(self, value, *args, **kwargs):
|
||||
r"""
|
||||
Evaluate cdf from log survival function.
|
||||
|
||||
.. math::
|
||||
cdf(x) = 1 - (\exp(log_survival(x)))
|
||||
"""
|
||||
return 1.0 - self.exp(self._log_survival(*args, **kwargs))
|
||||
return 1.0 - self.exp(self._log_survival(value, *args, **kwargs))
|
||||
|
||||
def log_cdf(self, *args, **kwargs):
|
||||
def log_cdf(self, value, *args, **kwargs):
|
||||
"""
|
||||
Evaluate the log cdf at given value.
|
||||
|
||||
Note:
|
||||
The argument `args` must include `value`.
|
||||
dist_spec_args are optional.
|
||||
"""
|
||||
return self._call_log_cdf(*args, **kwargs)
|
||||
Args:
|
||||
value (Tensor): value to be evaluated.
|
||||
*args (list): the list of positional arguments forwarded to subclasses.
|
||||
**kwargs (dictionary): the dictionary of keyword arguments forwarded to subclasses.
|
||||
|
||||
def _calc_log_cdf_from_call_cdf(self, *args, **kwargs):
|
||||
Note:
|
||||
A distribution can be optionally passed to the function by passing its dist_spec_args through
|
||||
`args` or `kwargs`.
|
||||
"""
|
||||
return self._call_log_cdf(value, *args, **kwargs)
|
||||
|
||||
def _calc_log_cdf_from_call_cdf(self, value, *args, **kwargs):
|
||||
r"""
|
||||
Evaluate log cdf from cdf.
|
||||
|
||||
.. math::
|
||||
log_cdf(x) = \log(cdf(x))
|
||||
"""
|
||||
return self.log(self._call_cdf(*args, **kwargs))
|
||||
return self.log(self._call_cdf(value, *args, **kwargs))
|
||||
|
||||
def survival_function(self, *args, **kwargs):
|
||||
def survival_function(self, value, *args, **kwargs):
|
||||
"""
|
||||
Evaluate the survival function at given value.
|
||||
|
||||
Note:
|
||||
The argument `args` must include `value`.
|
||||
dist_spec_args are optional.
|
||||
"""
|
||||
return self._call_survival(*args, **kwargs)
|
||||
Args:
|
||||
value (Tensor): value to be evaluated.
|
||||
*args (list): the list of positional arguments forwarded to subclasses.
|
||||
**kwargs (dictionary): the dictionary of keyword arguments forwarded to subclasses.
|
||||
|
||||
def _calc_survival_from_call_cdf(self, *args, **kwargs):
|
||||
Note:
|
||||
A distribution can be optionally passed to the function by passing its dist_spec_args through
|
||||
`args` or `kwargs`.
|
||||
"""
|
||||
return self._call_survival(value, *args, **kwargs)
|
||||
|
||||
def _calc_survival_from_call_cdf(self, value, *args, **kwargs):
|
||||
r"""
|
||||
Evaluate survival function from cdf.
|
||||
|
||||
.. math::
|
||||
survival_function(x) = 1 - (cdf(x))
|
||||
"""
|
||||
return 1.0 - self._call_cdf(*args, **kwargs)
|
||||
return 1.0 - self._call_cdf(value, *args, **kwargs)
|
||||
|
||||
def _calc_survival_from_log_survival(self, *args, **kwargs):
|
||||
def _calc_survival_from_log_survival(self, value, *args, **kwargs):
|
||||
r"""
|
||||
Evaluate survival function from log survival function.
|
||||
|
||||
.. math::
|
||||
survival(x) = \exp(survival_function(x))
|
||||
"""
|
||||
return self.exp(self._log_survival(*args, **kwargs))
|
||||
return self.exp(self._log_survival(value, *args, **kwargs))
|
||||
|
||||
def log_survival(self, *args, **kwargs):
|
||||
def log_survival(self, value, *args, **kwargs):
|
||||
"""
|
||||
Evaluate the log survival function at given value.
|
||||
|
||||
Note:
|
||||
The arguments `args` must include `value`.
|
||||
dist_spec_args are optional.
|
||||
"""
|
||||
return self._call_log_survival(*args, **kwargs)
|
||||
Args:
|
||||
value (Tensor): value to be evaluated.
|
||||
*args (list): the list of positional arguments forwarded to subclasses.
|
||||
**kwargs (dictionary): the dictionary of keyword arguments forwarded to subclasses.
|
||||
|
||||
def _calc_log_survival_from_call_survival(self, *args, **kwargs):
|
||||
Note:
|
||||
A distribution can be optionally passed to the function by passing its dist_spec_args through
|
||||
`args` or `kwargs`.
|
||||
"""
|
||||
return self._call_log_survival(value, *args, **kwargs)
|
||||
|
||||
def _calc_log_survival_from_call_survival(self, value, *args, **kwargs):
|
||||
r"""
|
||||
Evaluate log survival function from survival function.
|
||||
|
||||
.. math::
|
||||
log_survival(x) = \log(survival_function(x))
|
||||
"""
|
||||
return self.log(self._call_survival(*args, **kwargs))
|
||||
return self.log(self._call_survival(value, *args, **kwargs))
|
||||
|
||||
def kl_loss(self, *args, **kwargs):
|
||||
def kl_loss(self, dist, *args, **kwargs):
|
||||
"""
|
||||
Evaluate the KL divergence, i.e. KL(a||b).
|
||||
|
||||
Args:
|
||||
dist (str): type of the distribution.
|
||||
*args (list): the list of positional arguments forwarded to subclasses.
|
||||
**kwargs (dictionary): the dictionary of keyword arguments forwarded to subclasses.
|
||||
|
||||
Note:
|
||||
The argument `args` must include the type of the distribution, parameters of distribution b.
|
||||
Parameters for distribution a are optional.
|
||||
dist_spec_args of distribution b must be passed to the function through `args` or `kwargs`.
|
||||
Passing in dist_spec_args of distribution a is optional.
|
||||
"""
|
||||
return self._kl_loss(*args, **kwargs)
|
||||
return self._kl_loss(dist, *args, **kwargs)
|
||||
|
||||
def mean(self, *args, **kwargs):
|
||||
"""
|
||||
Evaluate the mean.
|
||||
|
||||
Args:
|
||||
*args (list): the list of positional arguments forwarded to subclasses.
|
||||
**kwargs (dictionary): the dictionary of keyword arguments forwarded to subclasses.
|
||||
|
||||
Note:
|
||||
dist_spec_args are optional.
|
||||
A distribution can be optionally passed to the function by passing its *dist_spec_args* through
|
||||
*args* or *kwargs*.
|
||||
"""
|
||||
return self._mean(*args, **kwargs)
|
||||
|
||||
|
@ -375,8 +461,13 @@ class Distribution(Cell):
|
|||
"""
|
||||
Evaluate the mode.
|
||||
|
||||
Args:
|
||||
*args (list): the list of positional arguments forwarded to subclasses.
|
||||
**kwargs (dictionary): the dictionary of keyword arguments forwarded to subclasses.
|
||||
|
||||
Note:
|
||||
dist_spec_args are optional.
|
||||
A distribution can be optionally passed to the function by passing its *dist_spec_args* through
|
||||
*args* or *kwargs*.
|
||||
"""
|
||||
return self._mode(*args, **kwargs)
|
||||
|
||||
|
@ -384,8 +475,13 @@ class Distribution(Cell):
|
|||
"""
|
||||
Evaluate the standard deviation.
|
||||
|
||||
Args:
|
||||
*args (list): the list of positional arguments forwarded to subclasses.
|
||||
**kwargs (dictionary): the dictionary of keyword arguments forwarded to subclasses.
|
||||
|
||||
Note:
|
||||
dist_spec_args are optional.
|
||||
A distribution can be optionally passed to the function by passing its *dist_spec_args* through
|
||||
*args* or *kwargs*.
|
||||
"""
|
||||
return self._call_sd(*args, **kwargs)
|
||||
|
||||
|
@ -393,8 +489,13 @@ class Distribution(Cell):
|
|||
"""
|
||||
Evaluate the variance.
|
||||
|
||||
Args:
|
||||
*args (list): the list of positional arguments forwarded to subclasses.
|
||||
**kwargs (dictionary): the dictionary of keyword arguments forwarded to subclasses.
|
||||
|
||||
Note:
|
||||
dist_spec_args are optional.
|
||||
A distribution can be optionally passed to the function by passing its *dist_spec_args* through
|
||||
*args* or *kwargs*.
|
||||
"""
|
||||
return self._call_var(*args, **kwargs)
|
||||
|
||||
|
@ -420,37 +521,52 @@ class Distribution(Cell):
|
|||
"""
|
||||
Evaluate the entropy.
|
||||
|
||||
Args:
|
||||
*args (list): the list of positional arguments forwarded to subclasses.
|
||||
**kwargs (dictionary): the dictionary of keyword arguments forwarded to subclasses.
|
||||
|
||||
Note:
|
||||
dist_spec_args are optional.
|
||||
A distribution can be optionally passed to the function by passing its *dist_spec_args* through
|
||||
*args* or *kwargs*.
|
||||
"""
|
||||
return self._entropy(*args, **kwargs)
|
||||
|
||||
def cross_entropy(self, *args, **kwargs):
|
||||
def cross_entropy(self, dist, *args, **kwargs):
|
||||
"""
|
||||
Evaluate the cross_entropy between distribution a and b.
|
||||
|
||||
Note:
|
||||
The argument `args` must include the type of the distribution, parameters of distribution b.
|
||||
Parameters for distribution a are optional.
|
||||
"""
|
||||
return self._call_cross_entropy(*args, **kwargs)
|
||||
Args:
|
||||
dist (str): type of the distribution.
|
||||
*args (list): the list of positional arguments forwarded to subclasses.
|
||||
**kwargs (dictionary): the dictionary of keyword arguments forwarded to subclasses.
|
||||
|
||||
def _calc_cross_entropy(self, *args, **kwargs):
|
||||
Note:
|
||||
dist_spec_args of distribution b must be passed to the function through `args` or `kwargs`.
|
||||
Passing in dist_spec_args of distribution a is optional.
|
||||
"""
|
||||
return self._call_cross_entropy(dist, *args, **kwargs)
|
||||
|
||||
def _calc_cross_entropy(self, dist, *args, **kwargs):
|
||||
r"""
|
||||
Evaluate cross_entropy from entropy and kl divergence.
|
||||
|
||||
.. math::
|
||||
H(X, Y) = H(X) + KL(X||Y)
|
||||
"""
|
||||
return self._entropy(*args, **kwargs) + self._kl_loss(*args, **kwargs)
|
||||
return self._entropy(*args, **kwargs) + self._kl_loss(dist, *args, **kwargs)
|
||||
|
||||
def sample(self, *args, **kwargs):
|
||||
"""
|
||||
Sampling function.
|
||||
|
||||
Args:
|
||||
shape (tuple): shape of the sample.
|
||||
*args (list): the list of positional arguments forwarded to subclasses.
|
||||
**kwargs (dictionary): the dictionary of keyword arguments forwarded to subclasses.
|
||||
|
||||
Note:
|
||||
Shape of the sample is default to ().
|
||||
dist_spec_args are optional.
|
||||
A distribution can be optionally passed to the function by passing its *dist_spec_args* through
|
||||
*args* or *kwargs*.
|
||||
"""
|
||||
return self._sample(*args, **kwargs)
|
||||
|
||||
|
|
|
@ -18,8 +18,7 @@ from mindspore.ops import operations as P
|
|||
from mindspore.ops import composite as C
|
||||
from mindspore.common import dtype as mstype
|
||||
from .distribution import Distribution
|
||||
from ._utils.utils import cast_to_tensor, check_greater_zero, check_type, check_distribution_name,\
|
||||
raise_none_error
|
||||
from ._utils.utils import cast_to_tensor, check_greater_zero, check_type, check_distribution_name, set_param_type
|
||||
from ._utils.custom_ops import exp_generic, log_generic
|
||||
|
||||
class Exponential(Distribution):
|
||||
|
@ -121,15 +120,19 @@ class Exponential(Distribution):
|
|||
valid_dtype = mstype.float_type
|
||||
check_type(dtype, valid_dtype, type(self).__name__)
|
||||
super(Exponential, self).__init__(seed, dtype, name, param)
|
||||
self.parameter_type = dtype
|
||||
self.parameter_type = set_param_type({'rate': rate}, self.dtype)
|
||||
if rate is not None:
|
||||
self._rate = cast_to_tensor(rate, self.parameter_type)
|
||||
check_greater_zero(self._rate, "rate")
|
||||
else:
|
||||
self._rate = rate
|
||||
|
||||
self.default_parameters = [self.rate]
|
||||
self.parameter_names = ['rate']
|
||||
|
||||
self.minval = np.finfo(np.float).tiny
|
||||
|
||||
|
||||
# ops needed for the class
|
||||
self.exp = exp_generic
|
||||
self.log = log_generic
|
||||
|
@ -156,28 +159,16 @@ class Exponential(Distribution):
|
|||
@property
|
||||
def rate(self):
|
||||
"""
|
||||
Return rate of the distribution.
|
||||
Return `rate` of the distribution.
|
||||
"""
|
||||
return self._rate
|
||||
|
||||
def _check_param(self, rate):
|
||||
"""
|
||||
Check availablity of distribution specific argument `rate`.
|
||||
"""
|
||||
if rate is not None:
|
||||
if self.context_mode == 0:
|
||||
self.checktensor(rate, 'rate')
|
||||
else:
|
||||
rate = self.checktensor(rate, 'rate')
|
||||
return self.cast(rate, self.parameter_type)
|
||||
return self.rate if self.rate is not None else raise_none_error('rate')
|
||||
|
||||
def _mean(self, rate=None):
|
||||
r"""
|
||||
.. math::
|
||||
MEAN(EXP) = \frac{1.0}{\lambda}.
|
||||
"""
|
||||
rate = self._check_param(rate)
|
||||
rate = self._check_param_type(rate)
|
||||
return 1.0 / rate
|
||||
|
||||
def _mode(self, rate=None):
|
||||
|
@ -185,7 +176,7 @@ class Exponential(Distribution):
|
|||
.. math::
|
||||
MODE(EXP) = 0.
|
||||
"""
|
||||
rate = self._check_param(rate)
|
||||
rate = self._check_param_type(rate)
|
||||
return self.fill(self.dtype, self.shape(rate), 0.)
|
||||
|
||||
def _sd(self, rate=None):
|
||||
|
@ -193,7 +184,7 @@ class Exponential(Distribution):
|
|||
.. math::
|
||||
sd(EXP) = \frac{1.0}{\lambda}.
|
||||
"""
|
||||
rate = self._check_param(rate)
|
||||
rate = self._check_param_type(rate)
|
||||
return 1.0 / rate
|
||||
|
||||
def _entropy(self, rate=None):
|
||||
|
@ -201,7 +192,7 @@ class Exponential(Distribution):
|
|||
.. math::
|
||||
H(Exp) = 1 - \log(\lambda).
|
||||
"""
|
||||
rate = self._check_param(rate)
|
||||
rate = self._check_param_type(rate)
|
||||
return 1.0 - self.log(rate)
|
||||
|
||||
def _cross_entropy(self, dist, rate_b, rate=None):
|
||||
|
@ -234,7 +225,7 @@ class Exponential(Distribution):
|
|||
"""
|
||||
value = self._check_value(value, "value")
|
||||
value = self.cast(value, self.dtype)
|
||||
rate = self._check_param(rate)
|
||||
rate = self._check_param_type(rate)
|
||||
prob = self.log(rate) - rate * value
|
||||
zeros = self.fill(self.dtypeop(prob), self.shape(prob), 0.0)
|
||||
neginf = self.fill(self.dtypeop(prob), self.shape(prob), -np.inf)
|
||||
|
@ -257,7 +248,7 @@ class Exponential(Distribution):
|
|||
"""
|
||||
value = self._check_value(value, 'value')
|
||||
value = self.cast(value, self.dtype)
|
||||
rate = self._check_param(rate)
|
||||
rate = self._check_param_type(rate)
|
||||
cdf = 1.0 - self.exp(-1. * rate * value)
|
||||
zeros = self.fill(self.dtypeop(cdf), self.shape(cdf), 0.0)
|
||||
comp = self.less(value, zeros)
|
||||
|
@ -279,7 +270,7 @@ class Exponential(Distribution):
|
|||
"""
|
||||
value = self._check_value(value, 'value')
|
||||
value = self.cast(value, self.dtype)
|
||||
rate = self._check_param(rate)
|
||||
rate = self._check_param_type(rate)
|
||||
sf = -1. * rate * value
|
||||
zeros = self.fill(self.dtypeop(sf), self.shape(sf), 0.0)
|
||||
comp = self.less(value, zeros)
|
||||
|
@ -297,7 +288,7 @@ class Exponential(Distribution):
|
|||
check_distribution_name(dist, 'Exponential')
|
||||
rate_b = self._check_value(rate_b, 'rate_b')
|
||||
rate_b = self.cast(rate_b, self.parameter_type)
|
||||
rate_a = self._check_param(rate)
|
||||
rate_a = self._check_param_type(rate)
|
||||
return self.log(rate_a) - self.log(rate_b) + rate_b / rate_a - 1.0
|
||||
|
||||
def _sample(self, shape=(), rate=None):
|
||||
|
@ -312,7 +303,7 @@ class Exponential(Distribution):
|
|||
Tensor, shape is shape + batch_shape.
|
||||
"""
|
||||
shape = self.checktuple(shape, 'shape')
|
||||
rate = self._check_param(rate)
|
||||
rate = self._check_param_type(rate)
|
||||
origin_shape = shape + self.shape(rate)
|
||||
if origin_shape == ():
|
||||
sample_shape = (1,)
|
||||
|
|
|
@ -19,7 +19,7 @@ from mindspore.ops import composite as C
|
|||
from mindspore.common import dtype as mstype
|
||||
from .distribution import Distribution
|
||||
from ._utils.utils import cast_to_tensor, check_prob, check_type, check_distribution_name,\
|
||||
raise_none_error
|
||||
set_param_type
|
||||
from ._utils.custom_ops import exp_generic, log_generic
|
||||
|
||||
|
||||
|
@ -123,13 +123,16 @@ class Geometric(Distribution):
|
|||
valid_dtype = mstype.int_type + mstype.uint_type + mstype.float_type
|
||||
check_type(dtype, valid_dtype, type(self).__name__)
|
||||
super(Geometric, self).__init__(seed, dtype, name, param)
|
||||
self.parameter_type = mstype.float32
|
||||
self.parameter_type = set_param_type({'probs1': probs}, mstype.float32)
|
||||
if probs is not None:
|
||||
self._probs = cast_to_tensor(probs, self.parameter_type)
|
||||
check_prob(self._probs)
|
||||
else:
|
||||
self._probs = probs
|
||||
|
||||
self.default_parameters = [self.probs]
|
||||
self.parameter_names = ['probs1']
|
||||
|
||||
self.minval = np.finfo(np.float).tiny
|
||||
|
||||
# ops needed for the class
|
||||
|
@ -164,24 +167,12 @@ class Geometric(Distribution):
|
|||
"""
|
||||
return self._probs
|
||||
|
||||
def _check_param(self, probs1):
|
||||
"""
|
||||
Check availablity of distribution specific args probs1.
|
||||
"""
|
||||
if probs1 is not None:
|
||||
if self.context_mode == 0:
|
||||
self.checktensor(probs1, 'probs1')
|
||||
else:
|
||||
probs1 = self.checktensor(probs1, 'probs1')
|
||||
return self.cast(probs1, self.parameter_type)
|
||||
return self.probs if self.probs is not None else raise_none_error('probs1')
|
||||
|
||||
def _mean(self, probs1=None):
|
||||
r"""
|
||||
.. math::
|
||||
MEAN(Geo) = \fratc{1 - probs1}{probs1}
|
||||
"""
|
||||
probs1 = self._check_param(probs1)
|
||||
probs1 = self._check_param_type(probs1)
|
||||
return (1. - probs1) / probs1
|
||||
|
||||
def _mode(self, probs1=None):
|
||||
|
@ -189,7 +180,7 @@ class Geometric(Distribution):
|
|||
.. math::
|
||||
MODE(Geo) = 0
|
||||
"""
|
||||
probs1 = self._check_param(probs1)
|
||||
probs1 = self._check_param_type(probs1)
|
||||
return self.fill(self.dtypeop(probs1), self.shape(probs1), 0.)
|
||||
|
||||
def _var(self, probs1=None):
|
||||
|
@ -197,7 +188,7 @@ class Geometric(Distribution):
|
|||
.. math::
|
||||
VAR(Geo) = \frac{1 - probs1}{probs1 ^ {2}}
|
||||
"""
|
||||
probs1 = self._check_param(probs1)
|
||||
probs1 = self._check_param_type(probs1)
|
||||
return (1.0 - probs1) / self.sq(probs1)
|
||||
|
||||
def _entropy(self, probs1=None):
|
||||
|
@ -205,7 +196,7 @@ class Geometric(Distribution):
|
|||
.. math::
|
||||
H(Geo) = \frac{-1 * probs0 \log_2 (1-probs0)\ - prob1 * \log_2 (1-probs1)\ }{probs1}
|
||||
"""
|
||||
probs1 = self._check_param(probs1)
|
||||
probs1 = self._check_param_type(probs1)
|
||||
probs0 = 1.0 - probs1
|
||||
return (-probs0 * self.log(probs0) - probs1 * self.log(probs1)) / probs1
|
||||
|
||||
|
@ -236,7 +227,7 @@ class Geometric(Distribution):
|
|||
value = self._check_value(value, 'value')
|
||||
value = self.cast(value, mstype.float32)
|
||||
value = self.floor(value)
|
||||
probs1 = self._check_param(probs1)
|
||||
probs1 = self._check_param_type(probs1)
|
||||
pmf = self.exp(self.log(1.0 - probs1) * value + self.log(probs1))
|
||||
zeros = self.fill(self.dtypeop(probs1), self.shape(pmf), 0.0)
|
||||
comp = self.less(value, zeros)
|
||||
|
@ -258,7 +249,7 @@ class Geometric(Distribution):
|
|||
value = self._check_value(value, 'value')
|
||||
value = self.cast(value, mstype.float32)
|
||||
value = self.floor(value)
|
||||
probs1 = self._check_param(probs1)
|
||||
probs1 = self._check_param_type(probs1)
|
||||
probs0 = 1.0 - probs1
|
||||
cdf = 1.0 - self.pow(probs0, value + 1.0)
|
||||
zeros = self.fill(self.dtypeop(probs1), self.shape(cdf), 0.0)
|
||||
|
@ -280,7 +271,7 @@ class Geometric(Distribution):
|
|||
check_distribution_name(dist, 'Geometric')
|
||||
probs1_b = self._check_value(probs1_b, 'probs1_b')
|
||||
probs1_b = self.cast(probs1_b, self.parameter_type)
|
||||
probs1_a = self._check_param(probs1)
|
||||
probs1_a = self._check_param_type(probs1)
|
||||
probs0_a = 1.0 - probs1_a
|
||||
probs0_b = 1.0 - probs1_b
|
||||
return self.log(probs1_a / probs1_b) + (probs0_a / probs1_a) * self.log(probs0_a / probs0_b)
|
||||
|
@ -297,7 +288,7 @@ class Geometric(Distribution):
|
|||
Tensor, shape is shape + batch_shape.
|
||||
"""
|
||||
shape = self.checktuple(shape, 'shape')
|
||||
probs1 = self._check_param(probs1)
|
||||
probs1 = self._check_param_type(probs1)
|
||||
origin_shape = shape + self.shape(probs1)
|
||||
if origin_shape == ():
|
||||
sample_shape = (1,)
|
||||
|
|
|
@ -19,7 +19,7 @@ from mindspore.ops import composite as C
|
|||
from mindspore.common import dtype as mstype
|
||||
from .distribution import Distribution
|
||||
from ._utils.utils import cast_to_tensor, check_greater_zero, check_type, check_distribution_name,\
|
||||
raise_none_error, common_dtype
|
||||
set_param_type
|
||||
from ._utils.custom_ops import exp_generic, expm1_generic, log_generic, erf_generic
|
||||
|
||||
class Normal(Distribution):
|
||||
|
@ -127,14 +127,17 @@ class Normal(Distribution):
|
|||
valid_dtype = mstype.float_type
|
||||
check_type(dtype, valid_dtype, type(self).__name__)
|
||||
super(Normal, self).__init__(seed, dtype, name, param)
|
||||
self.parameter_type = common_dtype(mean, 'mean', sd, 'sd', self.dtype)
|
||||
self.parameter_type = set_param_type({'mean': mean, 'sd': sd}, self.dtype)
|
||||
if mean is not None and sd is not None:
|
||||
self._mean_value = cast_to_tensor(mean, self.parameter_type)
|
||||
self._sd_value = cast_to_tensor(sd, self.parameter_type)
|
||||
check_greater_zero(self._sd_value, "Standard deviation")
|
||||
else:
|
||||
self._mean_value = mean
|
||||
self._sd_value = sd
|
||||
self._mean_value = mean if mean is None else cast_to_tensor(mean, self.parameter_type)
|
||||
self._sd_value = sd if sd is None else cast_to_tensor(sd, self.parameter_type)
|
||||
|
||||
self.default_parameters = [self._mean_value, self._sd_value]
|
||||
self.parameter_names = ['mean', 'sd']
|
||||
|
||||
#ops needed for the class
|
||||
self.exp = exp_generic
|
||||
|
@ -159,51 +162,25 @@ class Normal(Distribution):
|
|||
str_info = f'batch_shape = {self._broadcast_shape}'
|
||||
return str_info
|
||||
|
||||
def _check_param(self, mean, sd):
|
||||
"""
|
||||
Check availablity of distribution specific args `mean` and `sd`.
|
||||
"""
|
||||
if mean is not None:
|
||||
if self.context_mode == 0:
|
||||
self.checktensor(mean, 'mean')
|
||||
else:
|
||||
mean = self.checktensor(mean, 'mean')
|
||||
else:
|
||||
mean = self._mean_value if self._mean_value is not None else raise_none_error('mean')
|
||||
if sd is not None:
|
||||
if self.context_mode == 0:
|
||||
self.checktensor(sd, 'sd')
|
||||
else:
|
||||
sd = self.checktensor(sd, 'sd')
|
||||
else:
|
||||
sd = self._sd_value if self._sd_value is not None else raise_none_error('sd')
|
||||
batch_shape = self.shape(mean + sd)
|
||||
mean = mean * self.fill(self.dtypeop(mean), batch_shape, 1.0)
|
||||
sd = sd * self.fill(self.dtypeop(sd), batch_shape, 1.0)
|
||||
self.sametypeshape(mean, sd)
|
||||
mean = self.cast(mean, self.parameter_type)
|
||||
sd = self.cast(sd, self.parameter_type)
|
||||
return mean, sd
|
||||
|
||||
def _mean(self, mean=None, sd=None):
|
||||
"""
|
||||
The mean of the distribution.
|
||||
"""
|
||||
mean, sd = self._check_param(mean, sd)
|
||||
mean, sd = self._check_param_type(mean, sd)
|
||||
return mean
|
||||
|
||||
def _mode(self, mean=None, sd=None):
|
||||
"""
|
||||
The mode of the distribution.
|
||||
"""
|
||||
mean, sd = self._check_param(mean, sd)
|
||||
mean, sd = self._check_param_type(mean, sd)
|
||||
return mean
|
||||
|
||||
def _sd(self, mean=None, sd=None):
|
||||
"""
|
||||
The standard deviation of the distribution.
|
||||
"""
|
||||
mean, sd = self._check_param(mean, sd)
|
||||
mean, sd = self._check_param_type(mean, sd)
|
||||
return sd
|
||||
|
||||
def _entropy(self, mean=None, sd=None):
|
||||
|
@ -213,7 +190,7 @@ class Normal(Distribution):
|
|||
.. math::
|
||||
H(X) = \log(\sqrt(numpy.e * 2. * numpy.pi * \sq(\sigma)))
|
||||
"""
|
||||
mean, sd = self._check_param(mean, sd)
|
||||
mean, sd = self._check_param_type(mean, sd)
|
||||
return self.log(self.sqrt(self.const(np.e * 2. * np.pi))) + self.log(sd)
|
||||
|
||||
def _cross_entropy(self, dist, mean_b, sd_b, mean=None, sd=None):
|
||||
|
@ -244,7 +221,7 @@ class Normal(Distribution):
|
|||
"""
|
||||
value = self._check_value(value, 'value')
|
||||
value = self.cast(value, self.dtype)
|
||||
mean, sd = self._check_param(mean, sd)
|
||||
mean, sd = self._check_param_type(mean, sd)
|
||||
unnormalized_log_prob = -1. * (self.sq(value - mean)) / (2. * self.sq(sd))
|
||||
neg_normalization = -1. * self.log(self.const(2. * np.pi)) / 2. - self.log(sd)
|
||||
return unnormalized_log_prob + neg_normalization
|
||||
|
@ -263,7 +240,7 @@ class Normal(Distribution):
|
|||
"""
|
||||
value = self._check_value(value, 'value')
|
||||
value = self.cast(value, self.dtype)
|
||||
mean, sd = self._check_param(mean, sd)
|
||||
mean, sd = self._check_param_type(mean, sd)
|
||||
sqrt2 = self.sqrt(self.const(2.0))
|
||||
adjusted = (value - mean) / (sd * sqrt2)
|
||||
return 0.5 * (1.0 + self.erf(adjusted))
|
||||
|
@ -288,7 +265,7 @@ class Normal(Distribution):
|
|||
sd_b = self._check_value(sd_b, 'sd_b')
|
||||
mean_b = self.cast(mean_b, self.parameter_type)
|
||||
sd_b = self.cast(sd_b, self.parameter_type)
|
||||
mean_a, sd_a = self._check_param(mean, sd)
|
||||
mean_a, sd_a = self._check_param_type(mean, sd)
|
||||
diff_log_scale = self.log(sd_a) - self.log(sd_b)
|
||||
squared_diff = self.sq(mean_a / sd_b - mean_b / sd_b)
|
||||
return 0.5 * squared_diff + 0.5 * self.expm1(2 * diff_log_scale) - diff_log_scale
|
||||
|
@ -306,7 +283,7 @@ class Normal(Distribution):
|
|||
Tensor, shape is shape + batch_shape.
|
||||
"""
|
||||
shape = self.checktuple(shape, 'shape')
|
||||
mean, sd = self._check_param(mean, sd)
|
||||
mean, sd = self._check_param_type(mean, sd)
|
||||
batch_shape = self.shape(mean + sd)
|
||||
origin_shape = shape + batch_shape
|
||||
if origin_shape == ():
|
||||
|
|
|
@ -18,7 +18,7 @@ from mindspore.ops import composite as C
|
|||
from mindspore.common import dtype as mstype
|
||||
from .distribution import Distribution
|
||||
from ._utils.utils import cast_to_tensor, check_greater, check_type, check_distribution_name,\
|
||||
raise_none_error, common_dtype
|
||||
set_param_type
|
||||
from ._utils.custom_ops import exp_generic, log_generic
|
||||
|
||||
class Uniform(Distribution):
|
||||
|
@ -126,14 +126,17 @@ class Uniform(Distribution):
|
|||
valid_dtype = mstype.float_type
|
||||
check_type(dtype, valid_dtype, type(self).__name__)
|
||||
super(Uniform, self).__init__(seed, dtype, name, param)
|
||||
self.parameter_type = common_dtype(low, 'low', high, 'high', self.dtype)
|
||||
self.parameter_type = set_param_type({'low': low, 'high': high}, self.dtype)
|
||||
if low is not None and high is not None:
|
||||
self._low = cast_to_tensor(low, dtype)
|
||||
self._high = cast_to_tensor(high, dtype)
|
||||
self._low = cast_to_tensor(low, self.parameter_type)
|
||||
self._high = cast_to_tensor(high, self.parameter_type)
|
||||
check_greater(self.low, self.high, "low value", "high value")
|
||||
else:
|
||||
self._low = low
|
||||
self._high = high
|
||||
self._low = low if low is None else cast_to_tensor(low, self.parameter_type)
|
||||
self._high = high if high is None else cast_to_tensor(high, self.parameter_type)
|
||||
|
||||
self.default_parameters = [self.low, self.high]
|
||||
self.parameter_names = ['low', 'high']
|
||||
|
||||
# ops needed for the class
|
||||
self.exp = exp_generic
|
||||
|
@ -162,32 +165,6 @@ class Uniform(Distribution):
|
|||
str_info = f'batch_shape = {self._broadcast_shape}'
|
||||
return str_info
|
||||
|
||||
def _check_param(self, low, high):
|
||||
"""
|
||||
Check availablity of distribution specific args `low` and `high`.
|
||||
"""
|
||||
if low is not None:
|
||||
if self.context_mode == 0:
|
||||
self.checktensor(low, 'low')
|
||||
else:
|
||||
low = self.checktensor(low, 'low')
|
||||
else:
|
||||
low = self.low if self.low is not None else raise_none_error('low')
|
||||
if high is not None:
|
||||
if self.context_mode == 0:
|
||||
self.checktensor(high, 'high')
|
||||
else:
|
||||
high = self.checktensor(high, 'high')
|
||||
else:
|
||||
high = self.high if self.high is not None else raise_none_error('high')
|
||||
batch_shape = self.shape(high - low)
|
||||
high = high * self.fill(self.dtypeop(high), batch_shape, 1.0)
|
||||
low = low * self.fill(self.dtypeop(low), batch_shape, 1.0)
|
||||
self.sametypeshape(high, low)
|
||||
low = self.cast(low, self.parameter_type)
|
||||
high = self.cast(high, self.parameter_type)
|
||||
return low, high
|
||||
|
||||
@property
|
||||
def low(self):
|
||||
"""
|
||||
|
@ -209,7 +186,7 @@ class Uniform(Distribution):
|
|||
.. math::
|
||||
range(U) = high -low
|
||||
"""
|
||||
low, high = self._check_param(low, high)
|
||||
low, high = self._check_param_type(low, high)
|
||||
return high - low
|
||||
|
||||
def _mean(self, low=None, high=None):
|
||||
|
@ -217,7 +194,7 @@ class Uniform(Distribution):
|
|||
.. math::
|
||||
MEAN(U) = \frac{low + high}{2}.
|
||||
"""
|
||||
low, high = self._check_param(low, high)
|
||||
low, high = self._check_param_type(low, high)
|
||||
return (low + high) / 2.
|
||||
|
||||
def _var(self, low=None, high=None):
|
||||
|
@ -225,7 +202,7 @@ class Uniform(Distribution):
|
|||
.. math::
|
||||
VAR(U) = \frac{(high -low) ^ 2}{12}.
|
||||
"""
|
||||
low, high = self._check_param(low, high)
|
||||
low, high = self._check_param_type(low, high)
|
||||
return self.sq(high - low) / 12.0
|
||||
|
||||
def _entropy(self, low=None, high=None):
|
||||
|
@ -233,7 +210,7 @@ class Uniform(Distribution):
|
|||
.. math::
|
||||
H(U) = \log(high - low).
|
||||
"""
|
||||
low, high = self._check_param(low, high)
|
||||
low, high = self._check_param_type(low, high)
|
||||
return self.log(high - low)
|
||||
|
||||
def _cross_entropy(self, dist, low_b, high_b, low=None, high=None):
|
||||
|
@ -266,7 +243,7 @@ class Uniform(Distribution):
|
|||
"""
|
||||
value = self._check_value(value, 'value')
|
||||
value = self.cast(value, self.dtype)
|
||||
low, high = self._check_param(low, high)
|
||||
low, high = self._check_param_type(low, high)
|
||||
neg_ones = self.fill(self.dtype, self.shape(value), -1.0)
|
||||
prob = self.exp(neg_ones * self.log(high - low))
|
||||
broadcast_shape = self.shape(prob)
|
||||
|
@ -292,7 +269,7 @@ class Uniform(Distribution):
|
|||
low_b = self.cast(low_b, self.parameter_type)
|
||||
high_b = self._check_value(high_b, 'high_b')
|
||||
high_b = self.cast(high_b, self.parameter_type)
|
||||
low_a, high_a = self._check_param(low, high)
|
||||
low_a, high_a = self._check_param_type(low, high)
|
||||
kl = self.log(high_b - low_b) - self.log(high_a - low_a)
|
||||
comp = self.logicaland(self.lessequal(low_b, low_a), self.lessequal(high_a, high_b))
|
||||
return self.select(comp, kl, self.log(self.zeroslike(kl)))
|
||||
|
@ -313,7 +290,7 @@ class Uniform(Distribution):
|
|||
"""
|
||||
value = self._check_value(value, 'value')
|
||||
value = self.cast(value, self.dtype)
|
||||
low, high = self._check_param(low, high)
|
||||
low, high = self._check_param_type(low, high)
|
||||
prob = (value - low) / (high - low)
|
||||
broadcast_shape = self.shape(prob)
|
||||
zeros = self.fill(self.dtypeop(prob), broadcast_shape, 0.0)
|
||||
|
@ -336,7 +313,7 @@ class Uniform(Distribution):
|
|||
Tensor, shape is shape + batch_shape.
|
||||
"""
|
||||
shape = self.checktuple(shape, 'shape')
|
||||
low, high = self._check_param(low, high)
|
||||
low, high = self._check_param_type(low, high)
|
||||
broadcast_shape = self.shape(low + high)
|
||||
origin_shape = shape + broadcast_shape
|
||||
if origin_shape == ():
|
||||
|
|
|
@ -0,0 +1,182 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""
|
||||
Test util functions used in distribution classes.
|
||||
"""
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from mindspore.nn.cell import Cell
|
||||
from mindspore import context
|
||||
from mindspore import dtype
|
||||
from mindspore import Tensor
|
||||
from mindspore.common.parameter import Parameter
|
||||
from mindspore.nn.probability.distribution._utils.utils import set_param_type, \
|
||||
cast_to_tensor, CheckTuple, CheckTensor
|
||||
|
||||
def test_set_param_type():
|
||||
"""
|
||||
Test set_param_type function.
|
||||
"""
|
||||
tensor_fp16 = Tensor(0.1, dtype=dtype.float16)
|
||||
tensor_fp32 = Tensor(0.1, dtype=dtype.float32)
|
||||
tensor_fp64 = Tensor(0.1, dtype=dtype.float64)
|
||||
tensor_int32 = Tensor(0.1, dtype=dtype.int32)
|
||||
array_fp32 = np.array(1.0).astype(np.float32)
|
||||
array_fp64 = np.array(1.0).astype(np.float64)
|
||||
array_int32 = np.array(1.0).astype(np.int32)
|
||||
|
||||
dict1 = {'a': tensor_fp32, 'b': 1.0, 'c': tensor_fp32}
|
||||
dict2 = {'a': tensor_fp32, 'b': 1.0, 'c': tensor_fp64}
|
||||
dict3 = {'a': tensor_int32, 'b': 1.0, 'c': tensor_int32}
|
||||
dict4 = {'a': array_fp32, 'b': 1.0, 'c': tensor_fp32}
|
||||
dict5 = {'a': array_fp32, 'b': 1.0, 'c': array_fp64}
|
||||
dict6 = {'a': array_fp32, 'b': 1.0, 'c': array_int32}
|
||||
dict7 = {'a': 1.0}
|
||||
dict8 = {'a': 1.0, 'b': 1.0, 'c': 1.0}
|
||||
dict9 = {'a': tensor_fp16, 'b': tensor_fp16, 'c': tensor_fp16}
|
||||
dict10 = {'a': tensor_fp64, 'b': tensor_fp64, 'c': tensor_fp64}
|
||||
dict11 = {'a': array_fp64, 'b': array_fp64, 'c': tensor_fp64}
|
||||
|
||||
ans1 = set_param_type(dict1, dtype.float16)
|
||||
assert ans1 == dtype.float32
|
||||
|
||||
with pytest.raises(TypeError):
|
||||
set_param_type(dict2, dtype.float32)
|
||||
|
||||
ans3 = set_param_type(dict3, dtype.float16)
|
||||
assert ans3 == dtype.float32
|
||||
ans4 = set_param_type(dict4, dtype.float16)
|
||||
assert ans4 == dtype.float32
|
||||
|
||||
with pytest.raises(TypeError):
|
||||
set_param_type(dict5, dtype.float32)
|
||||
with pytest.raises(TypeError):
|
||||
set_param_type(dict6, dtype.float32)
|
||||
|
||||
ans7 = set_param_type(dict7, dtype.float32)
|
||||
assert ans7 == dtype.float32
|
||||
ans8 = set_param_type(dict8, dtype.float32)
|
||||
assert ans8 == dtype.float32
|
||||
ans9 = set_param_type(dict9, dtype.float32)
|
||||
assert ans9 == dtype.float16
|
||||
ans10 = set_param_type(dict10, dtype.float32)
|
||||
assert ans10 == dtype.float32
|
||||
ans11 = set_param_type(dict11, dtype.float32)
|
||||
assert ans11 == dtype.float32
|
||||
|
||||
def test_cast_to_tensor():
|
||||
"""
|
||||
Test cast_to_tensor.
|
||||
"""
|
||||
with pytest.raises(ValueError):
|
||||
cast_to_tensor(None, dtype.float32)
|
||||
with pytest.raises(TypeError):
|
||||
cast_to_tensor(True, dtype.float32)
|
||||
with pytest.raises(TypeError):
|
||||
cast_to_tensor({'a': 1, 'b': 2}, dtype.float32)
|
||||
with pytest.raises(TypeError):
|
||||
cast_to_tensor('tensor', dtype.float32)
|
||||
|
||||
ans1 = cast_to_tensor(Parameter(Tensor(0.1, dtype=dtype.float32), 'param'))
|
||||
assert isinstance(ans1, Parameter)
|
||||
ans2 = cast_to_tensor(np.array(1.0).astype(np.float32))
|
||||
assert isinstance(ans2, Tensor)
|
||||
ans3 = cast_to_tensor([1.0, 2.0])
|
||||
assert isinstance(ans3, Tensor)
|
||||
ans4 = cast_to_tensor(Tensor(0.1, dtype=dtype.float32), dtype.float32)
|
||||
assert isinstance(ans4, Tensor)
|
||||
ans5 = cast_to_tensor(0.1, dtype.float32)
|
||||
assert isinstance(ans5, Tensor)
|
||||
ans6 = cast_to_tensor(1, dtype.float32)
|
||||
assert isinstance(ans6, Tensor)
|
||||
|
||||
class Net(Cell):
|
||||
"""
|
||||
Test class: CheckTuple.
|
||||
"""
|
||||
def __init__(self, value):
|
||||
super(Net, self).__init__()
|
||||
self.checktuple = CheckTuple()
|
||||
self.value = value
|
||||
|
||||
def construct(self, value=None):
|
||||
if value is None:
|
||||
return self.checktuple(self.value, 'input')
|
||||
return self.checktuple(value, 'input')
|
||||
|
||||
def test_check_tuple():
|
||||
"""
|
||||
Test CheckTuple.
|
||||
"""
|
||||
net1 = Net((1, 2, 3))
|
||||
ans1 = net1()
|
||||
assert isinstance(ans1, tuple)
|
||||
|
||||
with pytest.raises(TypeError):
|
||||
net2 = Net('tuple')
|
||||
net2()
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
net3 = Net((1, 2, 3))
|
||||
ans3 = net3()
|
||||
assert isinstance(ans3, tuple)
|
||||
|
||||
with pytest.raises(TypeError):
|
||||
net4 = Net('tuple')
|
||||
net4()
|
||||
|
||||
class Net1(Cell):
|
||||
"""
|
||||
Test class: CheckTensor.
|
||||
"""
|
||||
def __init__(self, value):
|
||||
super(Net1, self).__init__()
|
||||
self.checktensor = CheckTensor()
|
||||
self.value = value
|
||||
self.context = context.get_context('mode')
|
||||
|
||||
def construct(self, value=None):
|
||||
value = self.value if value is None else value
|
||||
if self.context == 0:
|
||||
self.checktensor(value, 'input')
|
||||
return value
|
||||
return self.checktensor(value, 'input')
|
||||
|
||||
def test_check_tensor():
|
||||
"""
|
||||
Test CheckTensor.
|
||||
"""
|
||||
value = Tensor(0.1, dtype=dtype.float32)
|
||||
net1 = Net1(value)
|
||||
ans1 = net1()
|
||||
assert isinstance(ans1, Tensor)
|
||||
ans1 = net1(value)
|
||||
assert isinstance(ans1, Tensor)
|
||||
|
||||
with pytest.raises(TypeError):
|
||||
net2 = Net1('tuple')
|
||||
net2()
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
net3 = Net1(value)
|
||||
ans3 = net3()
|
||||
assert isinstance(ans3, Tensor)
|
||||
ans3 = net3(value)
|
||||
assert isinstance(ans3, Tensor)
|
||||
|
||||
with pytest.raises(TypeError):
|
||||
net4 = Net1('tuple')
|
||||
net4()
|
Loading…
Reference in New Issue