!6020 Add common check_param_type and set_param_type in distribution

Merge pull request !6020 from XunDeng/new_check_param
mindspore-ci-bot 2020-09-13 13:54:39 +08:00 committed by Gitee
commit c10341dfb7
8 changed files with 476 additions and 243 deletions

File: mindspore/nn/probability/distribution/_utils/utils.py

@@ -110,6 +110,8 @@ def check_scalar_from_param(params):
     Notes: String parameters are excluded.
     """
     for value in params.values():
+        if value is None:
+            continue
         if isinstance(value, (msp.bijector.Bijector, msp.distribution.Distribution)):
             return params['distribution'].is_scalar_batch
         if isinstance(value, Parameter):
@@ -358,23 +360,29 @@ class CheckTensor(PrimitiveWithInfer):
             return x
         raise TypeError(f"For {name}, input type should be a Tensor or Parameter.")

-def common_dtype(arg_a, name_a, arg_b, name_b, hint_type):
+def set_param_type(args, hint_type):
     """
-    check if arg_a and arg_b have the same dtype.
+    Find the common type among arguments.
+
+    Args:
+        args (dict): dictionary of arguments, {'name':value}.
+        hint_type (mindspore.dtype): hint type to return.
+
+    Raises:
+        TypeError: if tensors in args are not the same dtype.
     """
-    if hasattr(arg_a, 'dtype') and hasattr(arg_b, 'dtype'):
-        if isinstance(arg_a, np.ndarray):
-            a_dtype = mstype.pytype_to_dtype(arg_a.dtype)
-        else:
-            a_dtype = arg_a.dtype
-        if isinstance(arg_b, np.ndarray):
-            b_dtype = mstype.pytype_to_dtype(arg_b.dtype)
-        else:
-            b_dtype = arg_b.dtype
-        if a_dtype != b_dtype:
-            raise TypeError(f"{name_a} and {name_b} should have the same dtype.")
-        int_type = mstype.int_type + mstype.uint_type
-        if a_dtype in int_type or a_dtype == mstype.float64:
-            return mstype.float32
-        return a_dtype
-    return hint_type
+    common_dtype = None
+    for name, arg in args.items():
+        if hasattr(arg, 'dtype'):
+            if isinstance(arg, np.ndarray):
+                cur_dtype = mstype.pytype_to_dtype(arg.dtype)
+            else:
+                cur_dtype = arg.dtype
+            if common_dtype is None:
+                common_dtype = cur_dtype
+            elif cur_dtype != common_dtype:
+                raise TypeError(f"{name} should have the same dtype as other arguments.")
+    int_type = mstype.int_type + mstype.uint_type
+    if common_dtype in int_type or common_dtype == mstype.float64:
+        return mstype.float32
+    return hint_type if common_dtype is None else common_dtype
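
For reference, a minimal usage sketch of the new helper (the expected results mirror the unit tests added at the end of this PR; the import path is taken from those tests):

```python
import numpy as np
from mindspore import Tensor, dtype as mstype
from mindspore.nn.probability.distribution._utils.utils import set_param_type

# A typed argument fixes the common dtype; plain Python scalars are skipped.
assert set_param_type({'a': Tensor(0.1, dtype=mstype.float32), 'b': 1.0},
                      mstype.float16) == mstype.float32
# Integer and float64 arguments are mapped to float32.
assert set_param_type({'a': np.array(1).astype(np.int32)},
                      mstype.float16) == mstype.float32
# With no typed argument at all, the hint type is returned.
assert set_param_type({'a': 1.0}, mstype.float32) == mstype.float32
```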

File: mindspore/nn/probability/distribution/bernoulli.py

@@ -17,7 +17,7 @@ from mindspore.common import dtype as mstype
 from mindspore.ops import operations as P
 from mindspore.ops import composite as C
 from .distribution import Distribution
-from ._utils.utils import cast_to_tensor, check_prob, check_type, check_distribution_name, raise_none_error
+from ._utils.utils import cast_to_tensor, check_prob, check_type, check_distribution_name, set_param_type
 from ._utils.custom_ops import exp_generic, log_generic, erf_generic
@@ -119,13 +119,16 @@ class Bernoulli(Distribution):
         valid_dtype = mstype.int_type + mstype.uint_type + mstype.float_type
         check_type(dtype, valid_dtype, type(self).__name__)
         super(Bernoulli, self).__init__(seed, dtype, name, param)
-        self.parameter_type = mstype.float32
+        self.parameter_type = set_param_type({'probs1': probs}, mstype.float32)
         if probs is not None:
-            self._probs = cast_to_tensor(probs, mstype.float32)
+            self._probs = cast_to_tensor(probs, self.parameter_type)
             check_prob(self.probs)
         else:
             self._probs = probs

+        self.default_parameters = [self.probs]
+        self.parameter_names = ['probs1']
+
         # ops needed for the class
         self.exp = exp_generic
         self.log = log_generic
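
Illustration (not part of the diff): with `default_parameters` and `parameter_names` registered, the base class can resolve `probs1` whether it was bound at construction or passed per call. Example values here are assumptions:

```python
import mindspore.nn.probability.distribution as msd
from mindspore import Tensor, dtype as mstype

b = msd.Bernoulli(0.5)          # probs1 bound at construction
b_runtime = msd.Bernoulli()     # probs1 to be supplied per call

value = Tensor([0.0, 1.0, 1.0], dtype=mstype.float32)
b.prob(value)                                             # uses self.probs
b_runtime.prob(value, Tensor(0.3, dtype=mstype.float32))  # per-call probs1
```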
@@ -157,24 +160,12 @@ class Bernoulli(Distribution):
         """
         return self._probs

-    def _check_param(self, probs1):
-        """
-        Check availablity of distribution specific args `probs1`.
-        """
-        if probs1 is not None:
-            if self.context_mode == 0:
-                self.checktensor(probs1, 'probs1')
-            else:
-                probs1 = self.checktensor(probs1, 'probs1')
-            return self.cast(probs1, self.parameter_type)
-        return self.probs if self.probs is not None else raise_none_error('probs1')
-
     def _mean(self, probs1=None):
         r"""
         .. math::
             MEAN(B) = probs1
         """
-        probs1 = self._check_param(probs1)
+        probs1 = self._check_param_type(probs1)
         return probs1

     def _mode(self, probs1=None):
@@ -182,7 +173,7 @@ class Bernoulli(Distribution):
         .. math::
             MODE(B) = 1 if probs1 > 0.5 else = 0
         """
-        probs1 = self._check_param(probs1)
+        probs1 = self._check_param_type(probs1)
         prob_type = self.dtypeop(probs1)
         zeros = self.fill(prob_type, self.shape(probs1), 0.0)
         ones = self.fill(prob_type, self.shape(probs1), 1.0)
@@ -194,7 +185,7 @@ class Bernoulli(Distribution):
         .. math::
             VAR(B) = probs1 * probs0
         """
-        probs1 = self._check_param(probs1)
+        probs1 = self._check_param_type(probs1)
         probs0 = 1.0 - probs1
         return self.exp(self.log(probs0) + self.log(probs1))
@@ -203,11 +194,11 @@ class Bernoulli(Distribution):
         .. math::
             H(B) = -probs0 * \log(probs0) - probs1 * \log(probs1)
         """
-        probs1 = self._check_param(probs1)
+        probs1 = self._check_param_type(probs1)
         probs0 = 1.0 - probs1
         return -(probs0 * self.log(probs0)) - (probs1 * self.log(probs1))

-    def _cross_entropy(self, dist, probs1_b, probs1_a=None):
+    def _cross_entropy(self, dist, probs1_b, probs1=None):
         """
         Evaluate cross_entropy between Bernoulli distributions.
@@ -217,7 +208,7 @@ class Bernoulli(Distribution):
             probs1_a (Tensor): `probs1` of distribution a. Default: self.probs.
         """
         check_distribution_name(dist, 'Bernoulli')
-        return self._entropy(probs1_a) + self._kl_loss(dist, probs1_b, probs1_a)
+        return self._entropy(probs1) + self._kl_loss(dist, probs1_b, probs1)

     def _log_prob(self, value, probs1=None):
         r"""
@@ -233,7 +224,7 @@ class Bernoulli(Distribution):
         """
         value = self._check_value(value, 'value')
         value = self.cast(value, mstype.float32)
-        probs1 = self._check_param(probs1)
+        probs1 = self._check_param_type(probs1)
         probs0 = 1.0 - probs1
         return self.log(probs1) * value + self.log(probs0) * (1.0 - value)
@@ -253,7 +244,7 @@ class Bernoulli(Distribution):
         value = self._check_value(value, 'value')
         value = self.cast(value, mstype.float32)
         value = self.floor(value)
-        probs1 = self._check_param(probs1)
+        probs1 = self._check_param_type(probs1)
         prob_type = self.dtypeop(probs1)
         value = value * self.fill(prob_type, self.shape(probs1), 1.0)
         probs0 = 1.0 - probs1 * self.fill(prob_type, self.shape(value), 1.0)
@@ -264,7 +255,7 @@ class Bernoulli(Distribution):
         less_than_zero = self.select(comp_zero, zeros, probs0)
         return self.select(comp_one, less_than_zero, ones)

-    def _kl_loss(self, dist, probs1_b, probs1_a=None):
+    def _kl_loss(self, dist, probs1_b, probs1=None):
         r"""
         Evaluate bernoulli-bernoulli kl divergence, i.e. KL(a||b).
@@ -280,7 +271,7 @@ class Bernoulli(Distribution):
         check_distribution_name(dist, 'Bernoulli')
         probs1_b = self._check_value(probs1_b, 'probs1_b')
         probs1_b = self.cast(probs1_b, self.parameter_type)
-        probs1_a = self._check_param(probs1_a)
+        probs1_a = self._check_param_type(probs1)
         probs0_a = 1.0 - probs1_a
         probs0_b = 1.0 - probs1_b
         return probs1_a * self.log(probs1_a / probs1_b) + probs0_a * self.log(probs0_a / probs0_b)
@@ -297,7 +288,7 @@ class Bernoulli(Distribution):
             Tensor, shape is shape + batch_shape.
         """
         shape = self.checktuple(shape, 'shape')
-        probs1 = self._check_param(probs1)
+        probs1 = self._check_param_type(probs1)
         origin_shape = shape + self.shape(probs1)
         if origin_shape == ():
             sample_shape = (1,)

File: mindspore/nn/probability/distribution/distribution.py

@@ -18,7 +18,8 @@ from mindspore.nn.cell import Cell
 from mindspore._checkparam import Validator as validator
 from mindspore._checkparam import Rel
 from mindspore.common import get_seed
-from ._utils.utils import calc_broadcast_shape_from_param, check_scalar_from_param, cast_type_for_device
+from ._utils.utils import calc_broadcast_shape_from_param, check_scalar_from_param, cast_type_for_device,\
+    raise_none_error
 from ._utils.utils import CheckTuple, CheckTensor
@@ -115,6 +116,51 @@ class Distribution(Cell):
     def broadcast_shape(self):
         return self._broadcast_shape

+    def _check_param_type(self, *args):
+        """
+        Check the availability and validity of default parameters and dist_spec_args.
+        dist_spec_args passed in must be tensors. If default parameter of the distribution
+        is None, its parameter must be passed in through `args`.
+        """
+        broadcast_shape = None
+        common_dtype = None
+        out = []
+
+        for arg, name, default in zip(args, self.parameter_names, self.default_parameters):
+            # check if the argument is a Tensor
+            if arg is not None:
+                if self.context_mode == 0:
+                    self.checktensor(arg, name)
+                else:
+                    arg = self.checktensor(arg, name)
+            else:
+                arg = default if default is not None else raise_none_error(name)
+
+            # broadcast if the number of args > 1
+            if broadcast_shape is None:
+                broadcast_shape = self.shape(arg)
+                common_dtype = self.dtypeop(arg)
+            else:
+                ones = self.fill(self.dtypeop(arg), broadcast_shape, 1.0)
+                broadcast_shape = self.shape(arg + ones)
+
+                # check if the arguments have the same dtype
+                arg = arg * self.fill(self.dtypeop(arg), broadcast_shape, 1.0)
+                dtype_tensor = self.fill(common_dtype, broadcast_shape, 1.0)
+                self.sametypeshape(arg, dtype_tensor)
+
+            arg = self.cast(arg, self.parameter_type)
+            out.append(arg)
+
+        if len(out) == 1:
+            return out[0]
+
+        # broadcast all args to broadcast_shape
+        result = ()
+        for arg in out:
+            arg = arg * self.fill(self.dtypeop(arg), broadcast_shape, 1.0)
+            result = result + (arg,)
+        return result
+
     def _check_value(self, value, name):
         """
         Check availability of `value` as a Tensor.
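
The new `_check_param_type` centralizes what each subclass previously did in its own `_check_param`: validate passed-in tensors, fall back to the registered defaults (raising if both are None), enforce a common dtype, broadcast multi-parameter arguments against each other, and cast everything to `parameter_type`. A hedged sketch of the resulting behavior (example values are assumptions):

```python
import mindspore.nn.probability.distribution as msd
from mindspore import Tensor, dtype as mstype

n = msd.Normal(3.0, 4.0, dtype=mstype.float32)
# `mean` falls back to its default (3.0); both parameters are broadcast to a
# common shape and cast to n.parameter_type before _sd sees them.
n.sd(sd=Tensor([1.0, 2.0], dtype=mstype.float32))   # -> [1. 2.]
n.mean()                                            # -> 3.0 (all defaults)
```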
@@ -211,163 +257,203 @@ class Distribution(Cell):
         if hasattr(self, '_cross_entropy'):
             self._call_cross_entropy = self._cross_entropy

-    def log_prob(self, *args, **kwargs):
+    def log_prob(self, value, *args, **kwargs):
         """
         Evaluate the log probability(pdf or pmf) at the given value.

-        Note:
-            The argument `args` must include `value`.
-            dist_spec_args are optional.
+        Args:
+            value (Tensor): value to be evaluated.
+            *args (list): the list of positional arguments forwarded to subclasses.
+            **kwargs (dictionary): the dictionary of keyword arguments forwarded to subclasses.
+
+        Note:
+            A distribution can be optionally passed to the function by passing its dist_spec_args through
+            `args` or `kwargs`.
         """
-        return self._call_log_prob(*args, **kwargs)
+        return self._call_log_prob(value, *args, **kwargs)

-    def _calc_prob_from_log_prob(self, *args, **kwargs):
+    def _calc_prob_from_log_prob(self, value, *args, **kwargs):
         r"""
         Evaluate prob from log probability.

         .. math::
             probability(x) = \exp(log_likehood(x))
         """
-        return self.exp(self._log_prob(*args, **kwargs))
+        return self.exp(self._log_prob(value, *args, **kwargs))

-    def prob(self, *args, **kwargs):
+    def prob(self, value, *args, **kwargs):
         """
         Evaluate the probability (pdf or pmf) at given value.

-        Note:
-            The argument `args` must include `value`.
-            dist_spec_args are optional.
+        Args:
+            value (Tensor): value to be evaluated.
+            *args (list): the list of positional arguments forwarded to subclasses.
+            **kwargs (dictionary): the dictionary of keyword arguments forwarded to subclasses.
+
+        Note:
+            A distribution can be optionally passed to the function by passing its dist_spec_args through
+            `args` or `kwargs`.
         """
-        return self._call_prob(*args, **kwargs)
+        return self._call_prob(value, *args, **kwargs)

-    def _calc_log_prob_from_prob(self, *args, **kwargs):
+    def _calc_log_prob_from_prob(self, value, *args, **kwargs):
         r"""
         Evaluate log probability from probability.

         .. math::
             log_prob(x) = \log(prob(x))
         """
-        return self.log(self._prob(*args, **kwargs))
+        return self.log(self._prob(value, *args, **kwargs))

-    def cdf(self, *args, **kwargs):
+    def cdf(self, value, *args, **kwargs):
         """
         Evaluate the cdf at given value.

-        Note:
-            The argument `args` must include `value`.
-            dist_spec_args are optional.
+        Args:
+            value (Tensor): value to be evaluated.
+            *args (list): the list of positional arguments forwarded to subclasses.
+            **kwargs (dictionary): the dictionary of keyword arguments forwarded to subclasses.
+
+        Note:
+            A distribution can be optionally passed to the function by passing its dist_spec_args through
+            `args` or `kwargs`.
         """
-        return self._call_cdf(*args, **kwargs)
+        return self._call_cdf(value, *args, **kwargs)

-    def _calc_cdf_from_log_cdf(self, *args, **kwargs):
+    def _calc_cdf_from_log_cdf(self, value, *args, **kwargs):
         r"""
         Evaluate cdf from log_cdf.

         .. math::
             cdf(x) = \exp(log_cdf(x))
         """
-        return self.exp(self._log_cdf(*args, **kwargs))
+        return self.exp(self._log_cdf(value, *args, **kwargs))

-    def _calc_cdf_from_survival(self, *args, **kwargs):
+    def _calc_cdf_from_survival(self, value, *args, **kwargs):
         r"""
         Evaluate cdf from survival function.

         .. math::
             cdf(x) = 1 - (survival_function(x))
         """
-        return 1.0 - self._survival_function(*args, **kwargs)
+        return 1.0 - self._survival_function(value, *args, **kwargs)

-    def _calc_cdf_from_log_survival(self, *args, **kwargs):
+    def _calc_cdf_from_log_survival(self, value, *args, **kwargs):
         r"""
         Evaluate cdf from log survival function.

         .. math::
             cdf(x) = 1 - (\exp(log_survival(x)))
         """
-        return 1.0 - self.exp(self._log_survival(*args, **kwargs))
+        return 1.0 - self.exp(self._log_survival(value, *args, **kwargs))

-    def log_cdf(self, *args, **kwargs):
+    def log_cdf(self, value, *args, **kwargs):
         """
         Evaluate the log cdf at given value.

-        Note:
-            The argument `args` must include `value`.
-            dist_spec_args are optional.
+        Args:
+            value (Tensor): value to be evaluated.
+            *args (list): the list of positional arguments forwarded to subclasses.
+            **kwargs (dictionary): the dictionary of keyword arguments forwarded to subclasses.
+
+        Note:
+            A distribution can be optionally passed to the function by passing its dist_spec_args through
+            `args` or `kwargs`.
         """
-        return self._call_log_cdf(*args, **kwargs)
+        return self._call_log_cdf(value, *args, **kwargs)

-    def _calc_log_cdf_from_call_cdf(self, *args, **kwargs):
+    def _calc_log_cdf_from_call_cdf(self, value, *args, **kwargs):
         r"""
         Evaluate log cdf from cdf.

         .. math::
             log_cdf(x) = \log(cdf(x))
         """
-        return self.log(self._call_cdf(*args, **kwargs))
+        return self.log(self._call_cdf(value, *args, **kwargs))

-    def survival_function(self, *args, **kwargs):
+    def survival_function(self, value, *args, **kwargs):
         """
         Evaluate the survival function at given value.

-        Note:
-            The argument `args` must include `value`.
-            dist_spec_args are optional.
+        Args:
+            value (Tensor): value to be evaluated.
+            *args (list): the list of positional arguments forwarded to subclasses.
+            **kwargs (dictionary): the dictionary of keyword arguments forwarded to subclasses.
+
+        Note:
+            A distribution can be optionally passed to the function by passing its dist_spec_args through
+            `args` or `kwargs`.
         """
-        return self._call_survival(*args, **kwargs)
+        return self._call_survival(value, *args, **kwargs)

-    def _calc_survival_from_call_cdf(self, *args, **kwargs):
+    def _calc_survival_from_call_cdf(self, value, *args, **kwargs):
         r"""
         Evaluate survival function from cdf.

         .. math::
             survival_function(x) = 1 - (cdf(x))
         """
-        return 1.0 - self._call_cdf(*args, **kwargs)
+        return 1.0 - self._call_cdf(value, *args, **kwargs)

-    def _calc_survival_from_log_survival(self, *args, **kwargs):
+    def _calc_survival_from_log_survival(self, value, *args, **kwargs):
         r"""
         Evaluate survival function from log survival function.

         .. math::
             survival(x) = \exp(survival_function(x))
         """
-        return self.exp(self._log_survival(*args, **kwargs))
+        return self.exp(self._log_survival(value, *args, **kwargs))

-    def log_survival(self, *args, **kwargs):
+    def log_survival(self, value, *args, **kwargs):
         """
         Evaluate the log survival function at given value.

-        Note:
-            The arguments `args` must include `value`.
-            dist_spec_args are optional.
+        Args:
+            value (Tensor): value to be evaluated.
+            *args (list): the list of positional arguments forwarded to subclasses.
+            **kwargs (dictionary): the dictionary of keyword arguments forwarded to subclasses.
+
+        Note:
+            A distribution can be optionally passed to the function by passing its dist_spec_args through
+            `args` or `kwargs`.
         """
-        return self._call_log_survival(*args, **kwargs)
+        return self._call_log_survival(value, *args, **kwargs)

-    def _calc_log_survival_from_call_survival(self, *args, **kwargs):
+    def _calc_log_survival_from_call_survival(self, value, *args, **kwargs):
         r"""
         Evaluate log survival function from survival function.

         .. math::
             log_survival(x) = \log(survival_function(x))
         """
-        return self.log(self._call_survival(*args, **kwargs))
+        return self.log(self._call_survival(value, *args, **kwargs))

-    def kl_loss(self, *args, **kwargs):
+    def kl_loss(self, dist, *args, **kwargs):
         """
         Evaluate the KL divergence, i.e. KL(a||b).

+        Args:
+            dist (str): type of the distribution.
+            *args (list): the list of positional arguments forwarded to subclasses.
+            **kwargs (dictionary): the dictionary of keyword arguments forwarded to subclasses.
+
         Note:
-            The argument `args` must include the type of the distribution, parameters of distribution b.
-            Parameters for distribution a are optional.
+            dist_spec_args of distribution b must be passed to the function through `args` or `kwargs`.
+            Passing in dist_spec_args of distribution a is optional.
         """
-        return self._kl_loss(*args, **kwargs)
+        return self._kl_loss(dist, *args, **kwargs)

     def mean(self, *args, **kwargs):
         """
         Evaluate the mean.

+        Args:
+            *args (list): the list of positional arguments forwarded to subclasses.
+            **kwargs (dictionary): the dictionary of keyword arguments forwarded to subclasses.
+
         Note:
-            dist_spec_args are optional.
+            A distribution can be optionally passed to the function by passing its *dist_spec_args* through
+            *args* or *kwargs*.
         """
         return self._mean(*args, **kwargs)
@@ -375,8 +461,13 @@ class Distribution(Cell):
         """
         Evaluate the mode.

+        Args:
+            *args (list): the list of positional arguments forwarded to subclasses.
+            **kwargs (dictionary): the dictionary of keyword arguments forwarded to subclasses.
+
         Note:
-            dist_spec_args are optional.
+            A distribution can be optionally passed to the function by passing its *dist_spec_args* through
+            *args* or *kwargs*.
         """
         return self._mode(*args, **kwargs)
@@ -384,8 +475,13 @@ class Distribution(Cell):
         """
         Evaluate the standard deviation.

+        Args:
+            *args (list): the list of positional arguments forwarded to subclasses.
+            **kwargs (dictionary): the dictionary of keyword arguments forwarded to subclasses.
+
         Note:
-            dist_spec_args are optional.
+            A distribution can be optionally passed to the function by passing its *dist_spec_args* through
+            *args* or *kwargs*.
         """
         return self._call_sd(*args, **kwargs)
@@ -393,8 +489,13 @@ class Distribution(Cell):
         """
         Evaluate the variance.

+        Args:
+            *args (list): the list of positional arguments forwarded to subclasses.
+            **kwargs (dictionary): the dictionary of keyword arguments forwarded to subclasses.
+
         Note:
-            dist_spec_args are optional.
+            A distribution can be optionally passed to the function by passing its *dist_spec_args* through
+            *args* or *kwargs*.
         """
         return self._call_var(*args, **kwargs)
@@ -420,37 +521,52 @@ class Distribution(Cell):
         """
         Evaluate the entropy.

+        Args:
+            *args (list): the list of positional arguments forwarded to subclasses.
+            **kwargs (dictionary): the dictionary of keyword arguments forwarded to subclasses.
+
         Note:
-            dist_spec_args are optional.
+            A distribution can be optionally passed to the function by passing its *dist_spec_args* through
+            *args* or *kwargs*.
         """
         return self._entropy(*args, **kwargs)

-    def cross_entropy(self, *args, **kwargs):
+    def cross_entropy(self, dist, *args, **kwargs):
         """
         Evaluate the cross_entropy between distribution a and b.

-        Note:
-            The argument `args` must include the type of the distribution, parameters of distribution b.
-            Parameters for distribution a are optional.
+        Args:
+            dist (str): type of the distribution.
+            *args (list): the list of positional arguments forwarded to subclasses.
+            **kwargs (dictionary): the dictionary of keyword arguments forwarded to subclasses.
+
+        Note:
+            dist_spec_args of distribution b must be passed to the function through `args` or `kwargs`.
+            Passing in dist_spec_args of distribution a is optional.
         """
-        return self._call_cross_entropy(*args, **kwargs)
+        return self._call_cross_entropy(dist, *args, **kwargs)

-    def _calc_cross_entropy(self, *args, **kwargs):
+    def _calc_cross_entropy(self, dist, *args, **kwargs):
         r"""
         Evaluate cross_entropy from entropy and kl divergence.

         .. math::
             H(X, Y) = H(X) + KL(X||Y)
         """
-        return self._entropy(*args, **kwargs) + self._kl_loss(*args, **kwargs)
+        return self._entropy(*args, **kwargs) + self._kl_loss(dist, *args, **kwargs)

     def sample(self, *args, **kwargs):
         """
         Sampling function.

+        Args:
+            shape (tuple): shape of the sample.
+            *args (list): the list of positional arguments forwarded to subclasses.
+            **kwargs (dictionary): the dictionary of keyword arguments forwarded to subclasses.
+
         Note:
-            Shape of the sample is default to ().
-            dist_spec_args are optional.
+            A distribution can be optionally passed to the function by passing its *dist_spec_args* through
+            *args* or *kwargs*.
         """
         return self._sample(*args, **kwargs)
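
Taken together, the public wrappers now name `value` and `dist` explicitly instead of burying them in `*args`. A hedged usage sketch (example values are assumptions):

```python
import mindspore.nn.probability.distribution as msd
from mindspore import Tensor, dtype as mstype

b = msd.Bernoulli(0.7)
value = Tensor([1.0, 0.0], dtype=mstype.float32)
probs1_b = Tensor(0.5, dtype=mstype.float32)

b.log_prob(value)                       # `value` is an explicit argument
b.kl_loss('Bernoulli', probs1_b)        # the `dist` string is explicit too
b.cross_entropy('Bernoulli', probs1_b)  # = entropy(a) + KL(a||b)
b.sample((2, 3))                        # shape first, dist_spec_args optional
```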

File: mindspore/nn/probability/distribution/exponential.py

@@ -18,8 +18,7 @@ from mindspore.ops import operations as P
 from mindspore.ops import composite as C
 from mindspore.common import dtype as mstype
 from .distribution import Distribution
-from ._utils.utils import cast_to_tensor, check_greater_zero, check_type, check_distribution_name,\
-    raise_none_error
+from ._utils.utils import cast_to_tensor, check_greater_zero, check_type, check_distribution_name, set_param_type
 from ._utils.custom_ops import exp_generic, log_generic

 class Exponential(Distribution):
@@ -121,15 +120,19 @@ class Exponential(Distribution):
         valid_dtype = mstype.float_type
         check_type(dtype, valid_dtype, type(self).__name__)
         super(Exponential, self).__init__(seed, dtype, name, param)
-        self.parameter_type = dtype
+        self.parameter_type = set_param_type({'rate': rate}, self.dtype)
         if rate is not None:
             self._rate = cast_to_tensor(rate, self.parameter_type)
             check_greater_zero(self._rate, "rate")
         else:
             self._rate = rate

+        self.default_parameters = [self.rate]
+        self.parameter_names = ['rate']
+
         self.minval = np.finfo(np.float).tiny

         # ops needed for the class
         self.exp = exp_generic
         self.log = log_generic
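
Note the behavioral change: `parameter_type` is now inferred from `rate` via `set_param_type` instead of copied from `dtype`. A hedged illustration (values are assumptions):

```python
import mindspore.nn.probability.distribution as msd
from mindspore import Tensor, dtype as mstype

# A float16 rate now yields parameter_type float16 even with dtype float32.
e = msd.Exponential(Tensor(0.5, dtype=mstype.float16), dtype=mstype.float32)
e.mean()        # 1 / rate, computed in the inferred parameter_type
e.sample((2,))  # a per-call `rate` remains optional, as before
```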
@@ -156,28 +159,16 @@ class Exponential(Distribution):
     @property
     def rate(self):
         """
-        Return rate of the distribution.
+        Return `rate` of the distribution.
         """
         return self._rate

-    def _check_param(self, rate):
-        """
-        Check availablity of distribution specific argument `rate`.
-        """
-        if rate is not None:
-            if self.context_mode == 0:
-                self.checktensor(rate, 'rate')
-            else:
-                rate = self.checktensor(rate, 'rate')
-            return self.cast(rate, self.parameter_type)
-        return self.rate if self.rate is not None else raise_none_error('rate')
-
     def _mean(self, rate=None):
         r"""
         .. math::
             MEAN(EXP) = \frac{1.0}{\lambda}.
         """
-        rate = self._check_param(rate)
+        rate = self._check_param_type(rate)
         return 1.0 / rate

     def _mode(self, rate=None):
@@ -185,7 +176,7 @@ class Exponential(Distribution):
         .. math::
             MODE(EXP) = 0.
         """
-        rate = self._check_param(rate)
+        rate = self._check_param_type(rate)
         return self.fill(self.dtype, self.shape(rate), 0.)

     def _sd(self, rate=None):
@@ -193,7 +184,7 @@ class Exponential(Distribution):
         .. math::
             sd(EXP) = \frac{1.0}{\lambda}.
         """
-        rate = self._check_param(rate)
+        rate = self._check_param_type(rate)
         return 1.0 / rate

     def _entropy(self, rate=None):
@@ -201,7 +192,7 @@ class Exponential(Distribution):
         .. math::
             H(Exp) = 1 - \log(\lambda).
         """
-        rate = self._check_param(rate)
+        rate = self._check_param_type(rate)
         return 1.0 - self.log(rate)

     def _cross_entropy(self, dist, rate_b, rate=None):
@@ -234,7 +225,7 @@ class Exponential(Distribution):
         """
         value = self._check_value(value, "value")
         value = self.cast(value, self.dtype)
-        rate = self._check_param(rate)
+        rate = self._check_param_type(rate)
         prob = self.log(rate) - rate * value
         zeros = self.fill(self.dtypeop(prob), self.shape(prob), 0.0)
         neginf = self.fill(self.dtypeop(prob), self.shape(prob), -np.inf)
@@ -257,7 +248,7 @@ class Exponential(Distribution):
         """
         value = self._check_value(value, 'value')
         value = self.cast(value, self.dtype)
-        rate = self._check_param(rate)
+        rate = self._check_param_type(rate)
         cdf = 1.0 - self.exp(-1. * rate * value)
         zeros = self.fill(self.dtypeop(cdf), self.shape(cdf), 0.0)
         comp = self.less(value, zeros)
@@ -279,7 +270,7 @@ class Exponential(Distribution):
         """
         value = self._check_value(value, 'value')
         value = self.cast(value, self.dtype)
-        rate = self._check_param(rate)
+        rate = self._check_param_type(rate)
         sf = -1. * rate * value
         zeros = self.fill(self.dtypeop(sf), self.shape(sf), 0.0)
         comp = self.less(value, zeros)
@@ -297,7 +288,7 @@ class Exponential(Distribution):
         check_distribution_name(dist, 'Exponential')
         rate_b = self._check_value(rate_b, 'rate_b')
         rate_b = self.cast(rate_b, self.parameter_type)
-        rate_a = self._check_param(rate)
+        rate_a = self._check_param_type(rate)
         return self.log(rate_a) - self.log(rate_b) + rate_b / rate_a - 1.0

     def _sample(self, shape=(), rate=None):
@@ -312,7 +303,7 @@ class Exponential(Distribution):
             Tensor, shape is shape + batch_shape.
         """
         shape = self.checktuple(shape, 'shape')
-        rate = self._check_param(rate)
+        rate = self._check_param_type(rate)
         origin_shape = shape + self.shape(rate)
         if origin_shape == ():
             sample_shape = (1,)

File: mindspore/nn/probability/distribution/geometric.py

@@ -19,7 +19,7 @@ from mindspore.ops import composite as C
 from mindspore.common import dtype as mstype
 from .distribution import Distribution
 from ._utils.utils import cast_to_tensor, check_prob, check_type, check_distribution_name,\
-    raise_none_error
+    set_param_type
 from ._utils.custom_ops import exp_generic, log_generic
@@ -123,13 +123,16 @@ class Geometric(Distribution):
         valid_dtype = mstype.int_type + mstype.uint_type + mstype.float_type
         check_type(dtype, valid_dtype, type(self).__name__)
         super(Geometric, self).__init__(seed, dtype, name, param)
-        self.parameter_type = mstype.float32
+        self.parameter_type = set_param_type({'probs1': probs}, mstype.float32)
         if probs is not None:
             self._probs = cast_to_tensor(probs, self.parameter_type)
             check_prob(self._probs)
         else:
             self._probs = probs

+        self.default_parameters = [self.probs]
+        self.parameter_names = ['probs1']
+
         self.minval = np.finfo(np.float).tiny

         # ops needed for the class
@@ -164,24 +167,12 @@ class Geometric(Distribution):
         """
         return self._probs

-    def _check_param(self, probs1):
-        """
-        Check availablity of distribution specific args probs1.
-        """
-        if probs1 is not None:
-            if self.context_mode == 0:
-                self.checktensor(probs1, 'probs1')
-            else:
-                probs1 = self.checktensor(probs1, 'probs1')
-            return self.cast(probs1, self.parameter_type)
-        return self.probs if self.probs is not None else raise_none_error('probs1')
-
     def _mean(self, probs1=None):
         r"""
         .. math::
             MEAN(Geo) = \fratc{1 - probs1}{probs1}
         """
-        probs1 = self._check_param(probs1)
+        probs1 = self._check_param_type(probs1)
         return (1. - probs1) / probs1

     def _mode(self, probs1=None):
@@ -189,7 +180,7 @@ class Geometric(Distribution):
         .. math::
             MODE(Geo) = 0
         """
-        probs1 = self._check_param(probs1)
+        probs1 = self._check_param_type(probs1)
         return self.fill(self.dtypeop(probs1), self.shape(probs1), 0.)

     def _var(self, probs1=None):
@@ -197,7 +188,7 @@ class Geometric(Distribution):
         .. math::
             VAR(Geo) = \frac{1 - probs1}{probs1 ^ {2}}
         """
-        probs1 = self._check_param(probs1)
+        probs1 = self._check_param_type(probs1)
         return (1.0 - probs1) / self.sq(probs1)

     def _entropy(self, probs1=None):
@@ -205,7 +196,7 @@ class Geometric(Distribution):
         .. math::
             H(Geo) = \frac{-1 * probs0 \log_2 (1-probs0)\ - prob1 * \log_2 (1-probs1)\ }{probs1}
         """
-        probs1 = self._check_param(probs1)
+        probs1 = self._check_param_type(probs1)
         probs0 = 1.0 - probs1
         return (-probs0 * self.log(probs0) - probs1 * self.log(probs1)) / probs1
@@ -236,7 +227,7 @@ class Geometric(Distribution):
         value = self._check_value(value, 'value')
         value = self.cast(value, mstype.float32)
         value = self.floor(value)
-        probs1 = self._check_param(probs1)
+        probs1 = self._check_param_type(probs1)
         pmf = self.exp(self.log(1.0 - probs1) * value + self.log(probs1))
         zeros = self.fill(self.dtypeop(probs1), self.shape(pmf), 0.0)
         comp = self.less(value, zeros)
@@ -258,7 +249,7 @@ class Geometric(Distribution):
         value = self._check_value(value, 'value')
         value = self.cast(value, mstype.float32)
         value = self.floor(value)
-        probs1 = self._check_param(probs1)
+        probs1 = self._check_param_type(probs1)
         probs0 = 1.0 - probs1
         cdf = 1.0 - self.pow(probs0, value + 1.0)
         zeros = self.fill(self.dtypeop(probs1), self.shape(cdf), 0.0)
@@ -280,7 +271,7 @@ class Geometric(Distribution):
         check_distribution_name(dist, 'Geometric')
         probs1_b = self._check_value(probs1_b, 'probs1_b')
         probs1_b = self.cast(probs1_b, self.parameter_type)
-        probs1_a = self._check_param(probs1)
+        probs1_a = self._check_param_type(probs1)
         probs0_a = 1.0 - probs1_a
         probs0_b = 1.0 - probs1_b
         return self.log(probs1_a / probs1_b) + (probs0_a / probs1_a) * self.log(probs0_a / probs0_b)
@@ -297,7 +288,7 @@ class Geometric(Distribution):
             Tensor, shape is shape + batch_shape.
         """
         shape = self.checktuple(shape, 'shape')
-        probs1 = self._check_param(probs1)
+        probs1 = self._check_param_type(probs1)
         origin_shape = shape + self.shape(probs1)
         if origin_shape == ():
             sample_shape = (1,)

File: mindspore/nn/probability/distribution/normal.py

@@ -19,7 +19,7 @@ from mindspore.ops import composite as C
 from mindspore.common import dtype as mstype
 from .distribution import Distribution
 from ._utils.utils import cast_to_tensor, check_greater_zero, check_type, check_distribution_name,\
-    raise_none_error, common_dtype
+    set_param_type
 from ._utils.custom_ops import exp_generic, expm1_generic, log_generic, erf_generic

 class Normal(Distribution):
@@ -127,14 +127,17 @@ class Normal(Distribution):
         valid_dtype = mstype.float_type
         check_type(dtype, valid_dtype, type(self).__name__)
         super(Normal, self).__init__(seed, dtype, name, param)
-        self.parameter_type = common_dtype(mean, 'mean', sd, 'sd', self.dtype)
+        self.parameter_type = set_param_type({'mean': mean, 'sd': sd}, self.dtype)
         if mean is not None and sd is not None:
             self._mean_value = cast_to_tensor(mean, self.parameter_type)
             self._sd_value = cast_to_tensor(sd, self.parameter_type)
             check_greater_zero(self._sd_value, "Standard deviation")
         else:
-            self._mean_value = mean
-            self._sd_value = sd
+            self._mean_value = mean if mean is None else cast_to_tensor(mean, self.parameter_type)
+            self._sd_value = sd if sd is None else cast_to_tensor(sd, self.parameter_type)
+
+        self.default_parameters = [self._mean_value, self._sd_value]
+        self.parameter_names = ['mean', 'sd']

         #ops needed for the class
         self.exp = exp_generic
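
Two changes worth noting here: a one-sided parameter (only `mean` or only `sd`) is now cast to `parameter_type` eagerly, and the dtype check moves from `common_dtype` to `set_param_type`, which rejects mismatched `mean`/`sd` dtypes. A hedged sketch (values are assumptions; the TypeError case mirrors the new unit tests):

```python
import numpy as np
import mindspore.nn.probability.distribution as msd
from mindspore import Tensor, dtype as mstype

n = msd.Normal(3.0, None, dtype=mstype.float32)  # mean is cast up front now
n.sd(sd=Tensor(1.0, dtype=mstype.float32))       # sd supplied per call

# Mismatched parameter dtypes fail fast in set_param_type:
# msd.Normal(np.array(0.0).astype(np.float32),
#            np.array(1.0).astype(np.float64))   # raises TypeError
```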
@@ -159,51 +162,25 @@ class Normal(Distribution):
         str_info = f'batch_shape = {self._broadcast_shape}'
         return str_info

-    def _check_param(self, mean, sd):
-        """
-        Check availablity of distribution specific args `mean` and `sd`.
-        """
-        if mean is not None:
-            if self.context_mode == 0:
-                self.checktensor(mean, 'mean')
-            else:
-                mean = self.checktensor(mean, 'mean')
-        else:
-            mean = self._mean_value if self._mean_value is not None else raise_none_error('mean')
-        if sd is not None:
-            if self.context_mode == 0:
-                self.checktensor(sd, 'sd')
-            else:
-                sd = self.checktensor(sd, 'sd')
-        else:
-            sd = self._sd_value if self._sd_value is not None else raise_none_error('sd')
-        batch_shape = self.shape(mean + sd)
-        mean = mean * self.fill(self.dtypeop(mean), batch_shape, 1.0)
-        sd = sd * self.fill(self.dtypeop(sd), batch_shape, 1.0)
-        self.sametypeshape(mean, sd)
-        mean = self.cast(mean, self.parameter_type)
-        sd = self.cast(sd, self.parameter_type)
-        return mean, sd
-
     def _mean(self, mean=None, sd=None):
         """
         The mean of the distribution.
         """
-        mean, sd = self._check_param(mean, sd)
+        mean, sd = self._check_param_type(mean, sd)
         return mean

     def _mode(self, mean=None, sd=None):
         """
         The mode of the distribution.
         """
-        mean, sd = self._check_param(mean, sd)
+        mean, sd = self._check_param_type(mean, sd)
         return mean

     def _sd(self, mean=None, sd=None):
         """
         The standard deviation of the distribution.
         """
-        mean, sd = self._check_param(mean, sd)
+        mean, sd = self._check_param_type(mean, sd)
         return sd

     def _entropy(self, mean=None, sd=None):
@@ -213,7 +190,7 @@ class Normal(Distribution):
         .. math::
             H(X) = \log(\sqrt(numpy.e * 2. * numpy.pi * \sq(\sigma)))
         """
-        mean, sd = self._check_param(mean, sd)
+        mean, sd = self._check_param_type(mean, sd)
         return self.log(self.sqrt(self.const(np.e * 2. * np.pi))) + self.log(sd)

     def _cross_entropy(self, dist, mean_b, sd_b, mean=None, sd=None):
@@ -244,7 +221,7 @@ class Normal(Distribution):
         """
         value = self._check_value(value, 'value')
         value = self.cast(value, self.dtype)
-        mean, sd = self._check_param(mean, sd)
+        mean, sd = self._check_param_type(mean, sd)
         unnormalized_log_prob = -1. * (self.sq(value - mean)) / (2. * self.sq(sd))
         neg_normalization = -1. * self.log(self.const(2. * np.pi)) / 2. - self.log(sd)
         return unnormalized_log_prob + neg_normalization
@@ -263,7 +240,7 @@ class Normal(Distribution):
         """
         value = self._check_value(value, 'value')
         value = self.cast(value, self.dtype)
-        mean, sd = self._check_param(mean, sd)
+        mean, sd = self._check_param_type(mean, sd)
         sqrt2 = self.sqrt(self.const(2.0))
         adjusted = (value - mean) / (sd * sqrt2)
         return 0.5 * (1.0 + self.erf(adjusted))
@@ -288,7 +265,7 @@ class Normal(Distribution):
         sd_b = self._check_value(sd_b, 'sd_b')
         mean_b = self.cast(mean_b, self.parameter_type)
         sd_b = self.cast(sd_b, self.parameter_type)
-        mean_a, sd_a = self._check_param(mean, sd)
+        mean_a, sd_a = self._check_param_type(mean, sd)
         diff_log_scale = self.log(sd_a) - self.log(sd_b)
         squared_diff = self.sq(mean_a / sd_b - mean_b / sd_b)
         return 0.5 * squared_diff + 0.5 * self.expm1(2 * diff_log_scale) - diff_log_scale
@@ -306,7 +283,7 @@ class Normal(Distribution):
             Tensor, shape is shape + batch_shape.
         """
         shape = self.checktuple(shape, 'shape')
-        mean, sd = self._check_param(mean, sd)
+        mean, sd = self._check_param_type(mean, sd)
         batch_shape = self.shape(mean + sd)
         origin_shape = shape + batch_shape
         if origin_shape == ():

File: mindspore/nn/probability/distribution/uniform.py

@@ -18,7 +18,7 @@ from mindspore.ops import composite as C
 from mindspore.common import dtype as mstype
 from .distribution import Distribution
 from ._utils.utils import cast_to_tensor, check_greater, check_type, check_distribution_name,\
-    raise_none_error, common_dtype
+    set_param_type
 from ._utils.custom_ops import exp_generic, log_generic

 class Uniform(Distribution):
@@ -126,14 +126,17 @@ class Uniform(Distribution):
         valid_dtype = mstype.float_type
         check_type(dtype, valid_dtype, type(self).__name__)
         super(Uniform, self).__init__(seed, dtype, name, param)
-        self.parameter_type = common_dtype(low, 'low', high, 'high', self.dtype)
+        self.parameter_type = set_param_type({'low': low, 'high': high}, self.dtype)
         if low is not None and high is not None:
-            self._low = cast_to_tensor(low, dtype)
-            self._high = cast_to_tensor(high, dtype)
+            self._low = cast_to_tensor(low, self.parameter_type)
+            self._high = cast_to_tensor(high, self.parameter_type)
             check_greater(self.low, self.high, "low value", "high value")
         else:
-            self._low = low
-            self._high = high
+            self._low = low if low is None else cast_to_tensor(low, self.parameter_type)
+            self._high = high if high is None else cast_to_tensor(high, self.parameter_type)
+
+        self.default_parameters = [self.low, self.high]
+        self.parameter_names = ['low', 'high']

         # ops needed for the class
         self.exp = exp_generic
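
This hunk also aligns the construction-time cast with the inferred `parameter_type` (previously `low`/`high` were hard-cast to `dtype`). A hedged illustration (values are assumptions; the float64 demotion mirrors the new unit tests):

```python
import numpy as np
import mindspore.nn.probability.distribution as msd
from mindspore import dtype as mstype

# float64 inputs share a common dtype, which set_param_type maps to float32;
# low and high are then cast to that parameter_type.
u = msd.Uniform(np.array(0.0).astype(np.float64),
                np.array(1.0).astype(np.float64))
assert u.parameter_type == mstype.float32
```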
@@ -162,32 +165,6 @@ class Uniform(Distribution):
         str_info = f'batch_shape = {self._broadcast_shape}'
         return str_info

-    def _check_param(self, low, high):
-        """
-        Check availablity of distribution specific args `low` and `high`.
-        """
-        if low is not None:
-            if self.context_mode == 0:
-                self.checktensor(low, 'low')
-            else:
-                low = self.checktensor(low, 'low')
-        else:
-            low = self.low if self.low is not None else raise_none_error('low')
-        if high is not None:
-            if self.context_mode == 0:
-                self.checktensor(high, 'high')
-            else:
-                high = self.checktensor(high, 'high')
-        else:
-            high = self.high if self.high is not None else raise_none_error('high')
-        batch_shape = self.shape(high - low)
-        high = high * self.fill(self.dtypeop(high), batch_shape, 1.0)
-        low = low * self.fill(self.dtypeop(low), batch_shape, 1.0)
-        self.sametypeshape(high, low)
-        low = self.cast(low, self.parameter_type)
-        high = self.cast(high, self.parameter_type)
-        return low, high
-
     @property
     def low(self):
         """
@@ -209,7 +186,7 @@ class Uniform(Distribution):
         .. math::
             range(U) = high -low
         """
-        low, high = self._check_param(low, high)
+        low, high = self._check_param_type(low, high)
         return high - low

     def _mean(self, low=None, high=None):
@@ -217,7 +194,7 @@ class Uniform(Distribution):
         .. math::
             MEAN(U) = \frac{low + high}{2}.
         """
-        low, high = self._check_param(low, high)
+        low, high = self._check_param_type(low, high)
         return (low + high) / 2.

     def _var(self, low=None, high=None):
@@ -225,7 +202,7 @@ class Uniform(Distribution):
         .. math::
             VAR(U) = \frac{(high -low) ^ 2}{12}.
         """
-        low, high = self._check_param(low, high)
+        low, high = self._check_param_type(low, high)
         return self.sq(high - low) / 12.0

     def _entropy(self, low=None, high=None):
@@ -233,7 +210,7 @@ class Uniform(Distribution):
         .. math::
             H(U) = \log(high - low).
         """
-        low, high = self._check_param(low, high)
+        low, high = self._check_param_type(low, high)
         return self.log(high - low)

     def _cross_entropy(self, dist, low_b, high_b, low=None, high=None):
@@ -266,7 +243,7 @@ class Uniform(Distribution):
         """
         value = self._check_value(value, 'value')
         value = self.cast(value, self.dtype)
-        low, high = self._check_param(low, high)
+        low, high = self._check_param_type(low, high)
         neg_ones = self.fill(self.dtype, self.shape(value), -1.0)
         prob = self.exp(neg_ones * self.log(high - low))
         broadcast_shape = self.shape(prob)
@@ -292,7 +269,7 @@ class Uniform(Distribution):
         low_b = self.cast(low_b, self.parameter_type)
         high_b = self._check_value(high_b, 'high_b')
         high_b = self.cast(high_b, self.parameter_type)
-        low_a, high_a = self._check_param(low, high)
+        low_a, high_a = self._check_param_type(low, high)
         kl = self.log(high_b - low_b) - self.log(high_a - low_a)
         comp = self.logicaland(self.lessequal(low_b, low_a), self.lessequal(high_a, high_b))
         return self.select(comp, kl, self.log(self.zeroslike(kl)))
@@ -313,7 +290,7 @@ class Uniform(Distribution):
         """
         value = self._check_value(value, 'value')
         value = self.cast(value, self.dtype)
-        low, high = self._check_param(low, high)
+        low, high = self._check_param_type(low, high)
         prob = (value - low) / (high - low)
         broadcast_shape = self.shape(prob)
         zeros = self.fill(self.dtypeop(prob), broadcast_shape, 0.0)
@@ -336,7 +313,7 @@ class Uniform(Distribution):
             Tensor, shape is shape + batch_shape.
         """
         shape = self.checktuple(shape, 'shape')
-        low, high = self._check_param(low, high)
+        low, high = self._check_param_type(low, high)
         broadcast_shape = self.shape(low + high)
         origin_shape = shape + broadcast_shape
         if origin_shape == ():

New file: unit tests for the distribution _utils helpers

@@ -0,0 +1,182 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Test util functions used in distribution classes.
"""
import numpy as np
import pytest

from mindspore.nn.cell import Cell
from mindspore import context
from mindspore import dtype
from mindspore import Tensor
from mindspore.common.parameter import Parameter
from mindspore.nn.probability.distribution._utils.utils import set_param_type, \
    cast_to_tensor, CheckTuple, CheckTensor
def test_set_param_type():
    """
    Test set_param_type function.
    """
    tensor_fp16 = Tensor(0.1, dtype=dtype.float16)
    tensor_fp32 = Tensor(0.1, dtype=dtype.float32)
    tensor_fp64 = Tensor(0.1, dtype=dtype.float64)
    tensor_int32 = Tensor(0.1, dtype=dtype.int32)
    array_fp32 = np.array(1.0).astype(np.float32)
    array_fp64 = np.array(1.0).astype(np.float64)
    array_int32 = np.array(1.0).astype(np.int32)

    dict1 = {'a': tensor_fp32, 'b': 1.0, 'c': tensor_fp32}
    dict2 = {'a': tensor_fp32, 'b': 1.0, 'c': tensor_fp64}
    dict3 = {'a': tensor_int32, 'b': 1.0, 'c': tensor_int32}
    dict4 = {'a': array_fp32, 'b': 1.0, 'c': tensor_fp32}
    dict5 = {'a': array_fp32, 'b': 1.0, 'c': array_fp64}
    dict6 = {'a': array_fp32, 'b': 1.0, 'c': array_int32}
    dict7 = {'a': 1.0}
    dict8 = {'a': 1.0, 'b': 1.0, 'c': 1.0}
    dict9 = {'a': tensor_fp16, 'b': tensor_fp16, 'c': tensor_fp16}
    dict10 = {'a': tensor_fp64, 'b': tensor_fp64, 'c': tensor_fp64}
    dict11 = {'a': array_fp64, 'b': array_fp64, 'c': tensor_fp64}

    ans1 = set_param_type(dict1, dtype.float16)
    assert ans1 == dtype.float32
    with pytest.raises(TypeError):
        set_param_type(dict2, dtype.float32)
    ans3 = set_param_type(dict3, dtype.float16)
    assert ans3 == dtype.float32
    ans4 = set_param_type(dict4, dtype.float16)
    assert ans4 == dtype.float32
    with pytest.raises(TypeError):
        set_param_type(dict5, dtype.float32)
    with pytest.raises(TypeError):
        set_param_type(dict6, dtype.float32)
    ans7 = set_param_type(dict7, dtype.float32)
    assert ans7 == dtype.float32
    ans8 = set_param_type(dict8, dtype.float32)
    assert ans8 == dtype.float32
    ans9 = set_param_type(dict9, dtype.float32)
    assert ans9 == dtype.float16
    ans10 = set_param_type(dict10, dtype.float32)
    assert ans10 == dtype.float32
    ans11 = set_param_type(dict11, dtype.float32)
    assert ans11 == dtype.float32
def test_cast_to_tensor():
    """
    Test cast_to_tensor.
    """
    with pytest.raises(ValueError):
        cast_to_tensor(None, dtype.float32)
    with pytest.raises(TypeError):
        cast_to_tensor(True, dtype.float32)
    with pytest.raises(TypeError):
        cast_to_tensor({'a': 1, 'b': 2}, dtype.float32)
    with pytest.raises(TypeError):
        cast_to_tensor('tensor', dtype.float32)
    ans1 = cast_to_tensor(Parameter(Tensor(0.1, dtype=dtype.float32), 'param'))
    assert isinstance(ans1, Parameter)
    ans2 = cast_to_tensor(np.array(1.0).astype(np.float32))
    assert isinstance(ans2, Tensor)
    ans3 = cast_to_tensor([1.0, 2.0])
    assert isinstance(ans3, Tensor)
    ans4 = cast_to_tensor(Tensor(0.1, dtype=dtype.float32), dtype.float32)
    assert isinstance(ans4, Tensor)
    ans5 = cast_to_tensor(0.1, dtype.float32)
    assert isinstance(ans5, Tensor)
    ans6 = cast_to_tensor(1, dtype.float32)
    assert isinstance(ans6, Tensor)
class Net(Cell):
    """
    Test class: CheckTuple.
    """
    def __init__(self, value):
        super(Net, self).__init__()
        self.checktuple = CheckTuple()
        self.value = value

    def construct(self, value=None):
        if value is None:
            return self.checktuple(self.value, 'input')
        return self.checktuple(value, 'input')
def test_check_tuple():
    """
    Test CheckTuple.
    """
    net1 = Net((1, 2, 3))
    ans1 = net1()
    assert isinstance(ans1, tuple)
    with pytest.raises(TypeError):
        net2 = Net('tuple')
        net2()

    context.set_context(mode=context.GRAPH_MODE)
    net3 = Net((1, 2, 3))
    ans3 = net3()
    assert isinstance(ans3, tuple)
    with pytest.raises(TypeError):
        net4 = Net('tuple')
        net4()
class Net1(Cell):
    """
    Test class: CheckTensor.
    """
    def __init__(self, value):
        super(Net1, self).__init__()
        self.checktensor = CheckTensor()
        self.value = value
        self.context = context.get_context('mode')

    def construct(self, value=None):
        value = self.value if value is None else value
        if self.context == 0:
            self.checktensor(value, 'input')
            return value
        return self.checktensor(value, 'input')
def test_check_tensor():
    """
    Test CheckTensor.
    """
    value = Tensor(0.1, dtype=dtype.float32)
    net1 = Net1(value)
    ans1 = net1()
    assert isinstance(ans1, Tensor)
    ans1 = net1(value)
    assert isinstance(ans1, Tensor)
    with pytest.raises(TypeError):
        net2 = Net1('tuple')
        net2()

    context.set_context(mode=context.GRAPH_MODE)
    net3 = Net1(value)
    ans3 = net3()
    assert isinstance(ans3, Tensor)
    ans3 = net3(value)
    assert isinstance(ans3, Tensor)
    with pytest.raises(TypeError):
        net4 = Net1('tuple')
        net4()