diff --git a/mindspore/nn/probability/distribution/_utils/utils.py b/mindspore/nn/probability/distribution/_utils/utils.py index dbc34f79ba2..62ac681a272 100644 --- a/mindspore/nn/probability/distribution/_utils/utils.py +++ b/mindspore/nn/probability/distribution/_utils/utils.py @@ -110,6 +110,8 @@ def check_scalar_from_param(params): Notes: String parameters are excluded. """ for value in params.values(): + if value is None: + continue if isinstance(value, (msp.bijector.Bijector, msp.distribution.Distribution)): return params['distribution'].is_scalar_batch if isinstance(value, Parameter): @@ -358,23 +360,29 @@ class CheckTensor(PrimitiveWithInfer): return x raise TypeError(f"For {name}, input type should be a Tensor or Parameter.") -def common_dtype(arg_a, name_a, arg_b, name_b, hint_type): +def set_param_type(args, hint_type): """ - check if arg_a and arg_b have the same dtype. + Find the common type among arguments. + + Args: + args (dict): dictionary of arguments, {'name':value}. + hint_type (mindspore.dtype): hint type to return. + + Raises: + TypeError: if tensors in args are not the same dtype. """ - if hasattr(arg_a, 'dtype') and hasattr(arg_b, 'dtype'): - if isinstance(arg_a, np.ndarray): - a_dtype = mstype.pytype_to_dtype(arg_a.dtype) - else: - a_dtype = arg_a.dtype - if isinstance(arg_b, np.ndarray): - b_dtype = mstype.pytype_to_dtype(arg_b.dtype) - else: - b_dtype = arg_b.dtype - if a_dtype != b_dtype: - raise TypeError(f"{name_a} and {name_b} should have the same dtype.") - int_type = mstype.int_type + mstype.uint_type - if a_dtype in int_type or a_dtype == mstype.float64: - return mstype.float32 - return a_dtype - return hint_type + common_dtype = None + for name, arg in args.items(): + if hasattr(arg, 'dtype'): + if isinstance(arg, np.ndarray): + cur_dtype = mstype.pytype_to_dtype(arg.dtype) + else: + cur_dtype = arg.dtype + if common_dtype is None: + common_dtype = cur_dtype + elif cur_dtype != common_dtype: + raise TypeError(f"{name} should have the same dtype as other arguments.") + int_type = mstype.int_type + mstype.uint_type + if common_dtype in int_type or common_dtype == mstype.float64: + return mstype.float32 + return hint_type if common_dtype is None else common_dtype diff --git a/mindspore/nn/probability/distribution/bernoulli.py b/mindspore/nn/probability/distribution/bernoulli.py index 8c0d5cd5a38..c3f4fa01814 100644 --- a/mindspore/nn/probability/distribution/bernoulli.py +++ b/mindspore/nn/probability/distribution/bernoulli.py @@ -17,7 +17,7 @@ from mindspore.common import dtype as mstype from mindspore.ops import operations as P from mindspore.ops import composite as C from .distribution import Distribution -from ._utils.utils import cast_to_tensor, check_prob, check_type, check_distribution_name, raise_none_error +from ._utils.utils import cast_to_tensor, check_prob, check_type, check_distribution_name, set_param_type from ._utils.custom_ops import exp_generic, log_generic, erf_generic @@ -119,13 +119,16 @@ class Bernoulli(Distribution): valid_dtype = mstype.int_type + mstype.uint_type + mstype.float_type check_type(dtype, valid_dtype, type(self).__name__) super(Bernoulli, self).__init__(seed, dtype, name, param) - self.parameter_type = mstype.float32 + self.parameter_type = set_param_type({'probs1': probs}, mstype.float32) if probs is not None: - self._probs = cast_to_tensor(probs, mstype.float32) + self._probs = cast_to_tensor(probs, self.parameter_type) check_prob(self.probs) else: self._probs = probs + self.default_parameters = [self.probs] + self.parameter_names = ['probs1'] + # ops needed for the class self.exp = exp_generic self.log = log_generic @@ -157,24 +160,12 @@ class Bernoulli(Distribution): """ return self._probs - def _check_param(self, probs1): - """ - Check availablity of distribution specific args `probs1`. - """ - if probs1 is not None: - if self.context_mode == 0: - self.checktensor(probs1, 'probs1') - else: - probs1 = self.checktensor(probs1, 'probs1') - return self.cast(probs1, self.parameter_type) - return self.probs if self.probs is not None else raise_none_error('probs1') - def _mean(self, probs1=None): r""" .. math:: MEAN(B) = probs1 """ - probs1 = self._check_param(probs1) + probs1 = self._check_param_type(probs1) return probs1 def _mode(self, probs1=None): @@ -182,7 +173,7 @@ class Bernoulli(Distribution): .. math:: MODE(B) = 1 if probs1 > 0.5 else = 0 """ - probs1 = self._check_param(probs1) + probs1 = self._check_param_type(probs1) prob_type = self.dtypeop(probs1) zeros = self.fill(prob_type, self.shape(probs1), 0.0) ones = self.fill(prob_type, self.shape(probs1), 1.0) @@ -194,7 +185,7 @@ class Bernoulli(Distribution): .. math:: VAR(B) = probs1 * probs0 """ - probs1 = self._check_param(probs1) + probs1 = self._check_param_type(probs1) probs0 = 1.0 - probs1 return self.exp(self.log(probs0) + self.log(probs1)) @@ -203,11 +194,11 @@ class Bernoulli(Distribution): .. math:: H(B) = -probs0 * \log(probs0) - probs1 * \log(probs1) """ - probs1 = self._check_param(probs1) + probs1 = self._check_param_type(probs1) probs0 = 1.0 - probs1 return -(probs0 * self.log(probs0)) - (probs1 * self.log(probs1)) - def _cross_entropy(self, dist, probs1_b, probs1_a=None): + def _cross_entropy(self, dist, probs1_b, probs1=None): """ Evaluate cross_entropy between Bernoulli distributions. @@ -217,7 +208,7 @@ class Bernoulli(Distribution): probs1_a (Tensor): `probs1` of distribution a. Default: self.probs. """ check_distribution_name(dist, 'Bernoulli') - return self._entropy(probs1_a) + self._kl_loss(dist, probs1_b, probs1_a) + return self._entropy(probs1) + self._kl_loss(dist, probs1_b, probs1) def _log_prob(self, value, probs1=None): r""" @@ -233,7 +224,7 @@ class Bernoulli(Distribution): """ value = self._check_value(value, 'value') value = self.cast(value, mstype.float32) - probs1 = self._check_param(probs1) + probs1 = self._check_param_type(probs1) probs0 = 1.0 - probs1 return self.log(probs1) * value + self.log(probs0) * (1.0 - value) @@ -253,7 +244,7 @@ class Bernoulli(Distribution): value = self._check_value(value, 'value') value = self.cast(value, mstype.float32) value = self.floor(value) - probs1 = self._check_param(probs1) + probs1 = self._check_param_type(probs1) prob_type = self.dtypeop(probs1) value = value * self.fill(prob_type, self.shape(probs1), 1.0) probs0 = 1.0 - probs1 * self.fill(prob_type, self.shape(value), 1.0) @@ -264,7 +255,7 @@ class Bernoulli(Distribution): less_than_zero = self.select(comp_zero, zeros, probs0) return self.select(comp_one, less_than_zero, ones) - def _kl_loss(self, dist, probs1_b, probs1_a=None): + def _kl_loss(self, dist, probs1_b, probs1=None): r""" Evaluate bernoulli-bernoulli kl divergence, i.e. KL(a||b). @@ -280,7 +271,7 @@ class Bernoulli(Distribution): check_distribution_name(dist, 'Bernoulli') probs1_b = self._check_value(probs1_b, 'probs1_b') probs1_b = self.cast(probs1_b, self.parameter_type) - probs1_a = self._check_param(probs1_a) + probs1_a = self._check_param_type(probs1) probs0_a = 1.0 - probs1_a probs0_b = 1.0 - probs1_b return probs1_a * self.log(probs1_a / probs1_b) + probs0_a * self.log(probs0_a / probs0_b) @@ -297,7 +288,7 @@ class Bernoulli(Distribution): Tensor, shape is shape + batch_shape. """ shape = self.checktuple(shape, 'shape') - probs1 = self._check_param(probs1) + probs1 = self._check_param_type(probs1) origin_shape = shape + self.shape(probs1) if origin_shape == (): sample_shape = (1,) diff --git a/mindspore/nn/probability/distribution/distribution.py b/mindspore/nn/probability/distribution/distribution.py index 2388f54d7a0..0439d146285 100644 --- a/mindspore/nn/probability/distribution/distribution.py +++ b/mindspore/nn/probability/distribution/distribution.py @@ -18,7 +18,8 @@ from mindspore.nn.cell import Cell from mindspore._checkparam import Validator as validator from mindspore._checkparam import Rel from mindspore.common import get_seed -from ._utils.utils import calc_broadcast_shape_from_param, check_scalar_from_param, cast_type_for_device +from ._utils.utils import calc_broadcast_shape_from_param, check_scalar_from_param, cast_type_for_device,\ + raise_none_error from ._utils.utils import CheckTuple, CheckTensor @@ -115,6 +116,51 @@ class Distribution(Cell): def broadcast_shape(self): return self._broadcast_shape + def _check_param_type(self, *args): + """ + Check the availability and validity of default parameters and dist_spec_args. + dist_spec_args passed in must be tensors. If default parameter of the distribution + is None, its parameter must be passed in through `args`. + """ + broadcast_shape = None + common_dtype = None + out = [] + + for arg, name, default in zip(args, self.parameter_names, self.default_parameters): + # check if the argument is a Tensor + if arg is not None: + if self.context_mode == 0: + self.checktensor(arg, name) + else: + arg = self.checktensor(arg, name) + else: + arg = default if default is not None else raise_none_error(name) + + # broadcast if the number of args > 1 + if broadcast_shape is None: + broadcast_shape = self.shape(arg) + common_dtype = self.dtypeop(arg) + else: + ones = self.fill(self.dtypeop(arg), broadcast_shape, 1.0) + broadcast_shape = self.shape(arg + ones) + + # check if the arguments have the same dtype + arg = arg * self.fill(self.dtypeop(arg), broadcast_shape, 1.0) + dtype_tensor = self.fill(common_dtype, broadcast_shape, 1.0) + self.sametypeshape(arg, dtype_tensor) + arg = self.cast(arg, self.parameter_type) + out.append(arg) + + if len(out) == 1: + return out[0] + + # broadcast all args to broadcast_shape + result = () + for arg in out: + arg = arg * self.fill(self.dtypeop(arg), broadcast_shape, 1.0) + result = result + (arg,) + return result + def _check_value(self, value, name): """ Check availability of `value` as a Tensor. @@ -211,163 +257,203 @@ class Distribution(Cell): if hasattr(self, '_cross_entropy'): self._call_cross_entropy = self._cross_entropy - def log_prob(self, *args, **kwargs): + def log_prob(self, value, *args, **kwargs): """ Evaluate the log probability(pdf or pmf) at the given value. - Note: - The argument `args` must include `value`. - dist_spec_args are optional. - """ - return self._call_log_prob(*args, **kwargs) + Args: + value (Tensor): value to be evaluated. + *args (list): the list of positional arguments forwarded to subclasses. + **kwargs (dictionary): the dictionary of keyword arguments forwarded to subclasses. - def _calc_prob_from_log_prob(self, *args, **kwargs): + Note: + A distribution can be optionally passed to the function by passing its dist_spec_args through + `args` or `kwargs`. + """ + return self._call_log_prob(value, *args, **kwargs) + + def _calc_prob_from_log_prob(self, value, *args, **kwargs): r""" Evaluate prob from log probability. .. math:: probability(x) = \exp(log_likehood(x)) """ - return self.exp(self._log_prob(*args, **kwargs)) + return self.exp(self._log_prob(value, *args, **kwargs)) - def prob(self, *args, **kwargs): + def prob(self, value, *args, **kwargs): """ Evaluate the probability (pdf or pmf) at given value. - Note: - The argument `args` must include `value`. - dist_spec_args are optional. - """ - return self._call_prob(*args, **kwargs) + Args: + value (Tensor): value to be evaluated. + *args (list): the list of positional arguments forwarded to subclasses. + **kwargs (dictionary): the dictionary of keyword arguments forwarded to subclasses. - def _calc_log_prob_from_prob(self, *args, **kwargs): + Note: + A distribution can be optionally passed to the function by passing its dist_spec_args through + `args` or `kwargs`. + """ + return self._call_prob(value, *args, **kwargs) + + def _calc_log_prob_from_prob(self, value, *args, **kwargs): r""" Evaluate log probability from probability. .. math:: log_prob(x) = \log(prob(x)) """ - return self.log(self._prob(*args, **kwargs)) + return self.log(self._prob(value, *args, **kwargs)) - def cdf(self, *args, **kwargs): + def cdf(self, value, *args, **kwargs): """ Evaluate the cdf at given value. - Note: - The argument `args` must include `value`. - dist_spec_args are optional. - """ - return self._call_cdf(*args, **kwargs) + Args: + value (Tensor): value to be evaluated. + *args (list): the list of positional arguments forwarded to subclasses. + **kwargs (dictionary): the dictionary of keyword arguments forwarded to subclasses. - def _calc_cdf_from_log_cdf(self, *args, **kwargs): + Note: + A distribution can be optionally passed to the function by passing its dist_spec_args through + `args` or `kwargs`. + """ + return self._call_cdf(value, *args, **kwargs) + + def _calc_cdf_from_log_cdf(self, value, *args, **kwargs): r""" Evaluate cdf from log_cdf. .. math:: cdf(x) = \exp(log_cdf(x)) """ - return self.exp(self._log_cdf(*args, **kwargs)) + return self.exp(self._log_cdf(value, *args, **kwargs)) - def _calc_cdf_from_survival(self, *args, **kwargs): + def _calc_cdf_from_survival(self, value, *args, **kwargs): r""" Evaluate cdf from survival function. .. math:: cdf(x) = 1 - (survival_function(x)) """ - return 1.0 - self._survival_function(*args, **kwargs) + return 1.0 - self._survival_function(value, *args, **kwargs) - def _calc_cdf_from_log_survival(self, *args, **kwargs): + def _calc_cdf_from_log_survival(self, value, *args, **kwargs): r""" Evaluate cdf from log survival function. .. math:: cdf(x) = 1 - (\exp(log_survival(x))) """ - return 1.0 - self.exp(self._log_survival(*args, **kwargs)) + return 1.0 - self.exp(self._log_survival(value, *args, **kwargs)) - def log_cdf(self, *args, **kwargs): + def log_cdf(self, value, *args, **kwargs): """ Evaluate the log cdf at given value. - Note: - The argument `args` must include `value`. - dist_spec_args are optional. - """ - return self._call_log_cdf(*args, **kwargs) + Args: + value (Tensor): value to be evaluated. + *args (list): the list of positional arguments forwarded to subclasses. + **kwargs (dictionary): the dictionary of keyword arguments forwarded to subclasses. - def _calc_log_cdf_from_call_cdf(self, *args, **kwargs): + Note: + A distribution can be optionally passed to the function by passing its dist_spec_args through + `args` or `kwargs`. + """ + return self._call_log_cdf(value, *args, **kwargs) + + def _calc_log_cdf_from_call_cdf(self, value, *args, **kwargs): r""" Evaluate log cdf from cdf. .. math:: log_cdf(x) = \log(cdf(x)) """ - return self.log(self._call_cdf(*args, **kwargs)) + return self.log(self._call_cdf(value, *args, **kwargs)) - def survival_function(self, *args, **kwargs): + def survival_function(self, value, *args, **kwargs): """ Evaluate the survival function at given value. - Note: - The argument `args` must include `value`. - dist_spec_args are optional. - """ - return self._call_survival(*args, **kwargs) + Args: + value (Tensor): value to be evaluated. + *args (list): the list of positional arguments forwarded to subclasses. + **kwargs (dictionary): the dictionary of keyword arguments forwarded to subclasses. - def _calc_survival_from_call_cdf(self, *args, **kwargs): + Note: + A distribution can be optionally passed to the function by passing its dist_spec_args through + `args` or `kwargs`. + """ + return self._call_survival(value, *args, **kwargs) + + def _calc_survival_from_call_cdf(self, value, *args, **kwargs): r""" Evaluate survival function from cdf. .. math:: survival_function(x) = 1 - (cdf(x)) """ - return 1.0 - self._call_cdf(*args, **kwargs) + return 1.0 - self._call_cdf(value, *args, **kwargs) - def _calc_survival_from_log_survival(self, *args, **kwargs): + def _calc_survival_from_log_survival(self, value, *args, **kwargs): r""" Evaluate survival function from log survival function. .. math:: survival(x) = \exp(survival_function(x)) """ - return self.exp(self._log_survival(*args, **kwargs)) + return self.exp(self._log_survival(value, *args, **kwargs)) - def log_survival(self, *args, **kwargs): + def log_survival(self, value, *args, **kwargs): """ Evaluate the log survival function at given value. - Note: - The arguments `args` must include `value`. - dist_spec_args are optional. - """ - return self._call_log_survival(*args, **kwargs) + Args: + value (Tensor): value to be evaluated. + *args (list): the list of positional arguments forwarded to subclasses. + **kwargs (dictionary): the dictionary of keyword arguments forwarded to subclasses. - def _calc_log_survival_from_call_survival(self, *args, **kwargs): + Note: + A distribution can be optionally passed to the function by passing its dist_spec_args through + `args` or `kwargs`. + """ + return self._call_log_survival(value, *args, **kwargs) + + def _calc_log_survival_from_call_survival(self, value, *args, **kwargs): r""" Evaluate log survival function from survival function. .. math:: log_survival(x) = \log(survival_function(x)) """ - return self.log(self._call_survival(*args, **kwargs)) + return self.log(self._call_survival(value, *args, **kwargs)) - def kl_loss(self, *args, **kwargs): + def kl_loss(self, dist, *args, **kwargs): """ Evaluate the KL divergence, i.e. KL(a||b). + Args: + dist (str): type of the distribution. + *args (list): the list of positional arguments forwarded to subclasses. + **kwargs (dictionary): the dictionary of keyword arguments forwarded to subclasses. + Note: - The argument `args` must include the type of the distribution, parameters of distribution b. - Parameters for distribution a are optional. + dist_spec_args of distribution b must be passed to the function through `args` or `kwargs`. + Passing in dist_spec_args of distribution a is optional. """ - return self._kl_loss(*args, **kwargs) + return self._kl_loss(dist, *args, **kwargs) def mean(self, *args, **kwargs): """ Evaluate the mean. + Args: + *args (list): the list of positional arguments forwarded to subclasses. + **kwargs (dictionary): the dictionary of keyword arguments forwarded to subclasses. + Note: - dist_spec_args are optional. + A distribution can be optionally passed to the function by passing its *dist_spec_args* through + *args* or *kwargs*. """ return self._mean(*args, **kwargs) @@ -375,8 +461,13 @@ class Distribution(Cell): """ Evaluate the mode. + Args: + *args (list): the list of positional arguments forwarded to subclasses. + **kwargs (dictionary): the dictionary of keyword arguments forwarded to subclasses. + Note: - dist_spec_args are optional. + A distribution can be optionally passed to the function by passing its *dist_spec_args* through + *args* or *kwargs*. """ return self._mode(*args, **kwargs) @@ -384,8 +475,13 @@ class Distribution(Cell): """ Evaluate the standard deviation. + Args: + *args (list): the list of positional arguments forwarded to subclasses. + **kwargs (dictionary): the dictionary of keyword arguments forwarded to subclasses. + Note: - dist_spec_args are optional. + A distribution can be optionally passed to the function by passing its *dist_spec_args* through + *args* or *kwargs*. """ return self._call_sd(*args, **kwargs) @@ -393,8 +489,13 @@ class Distribution(Cell): """ Evaluate the variance. + Args: + *args (list): the list of positional arguments forwarded to subclasses. + **kwargs (dictionary): the dictionary of keyword arguments forwarded to subclasses. + Note: - dist_spec_args are optional. + A distribution can be optionally passed to the function by passing its *dist_spec_args* through + *args* or *kwargs*. """ return self._call_var(*args, **kwargs) @@ -420,37 +521,52 @@ class Distribution(Cell): """ Evaluate the entropy. + Args: + *args (list): the list of positional arguments forwarded to subclasses. + **kwargs (dictionary): the dictionary of keyword arguments forwarded to subclasses. + Note: - dist_spec_args are optional. + A distribution can be optionally passed to the function by passing its *dist_spec_args* through + *args* or *kwargs*. """ return self._entropy(*args, **kwargs) - def cross_entropy(self, *args, **kwargs): + def cross_entropy(self, dist, *args, **kwargs): """ Evaluate the cross_entropy between distribution a and b. - Note: - The argument `args` must include the type of the distribution, parameters of distribution b. - Parameters for distribution a are optional. - """ - return self._call_cross_entropy(*args, **kwargs) + Args: + dist (str): type of the distribution. + *args (list): the list of positional arguments forwarded to subclasses. + **kwargs (dictionary): the dictionary of keyword arguments forwarded to subclasses. - def _calc_cross_entropy(self, *args, **kwargs): + Note: + dist_spec_args of distribution b must be passed to the function through `args` or `kwargs`. + Passing in dist_spec_args of distribution a is optional. + """ + return self._call_cross_entropy(dist, *args, **kwargs) + + def _calc_cross_entropy(self, dist, *args, **kwargs): r""" Evaluate cross_entropy from entropy and kl divergence. .. math:: H(X, Y) = H(X) + KL(X||Y) """ - return self._entropy(*args, **kwargs) + self._kl_loss(*args, **kwargs) + return self._entropy(*args, **kwargs) + self._kl_loss(dist, *args, **kwargs) def sample(self, *args, **kwargs): """ Sampling function. + Args: + shape (tuple): shape of the sample. + *args (list): the list of positional arguments forwarded to subclasses. + **kwargs (dictionary): the dictionary of keyword arguments forwarded to subclasses. + Note: - Shape of the sample is default to (). - dist_spec_args are optional. + A distribution can be optionally passed to the function by passing its *dist_spec_args* through + *args* or *kwargs*. """ return self._sample(*args, **kwargs) diff --git a/mindspore/nn/probability/distribution/exponential.py b/mindspore/nn/probability/distribution/exponential.py index 4170cde88a7..c12234782e5 100644 --- a/mindspore/nn/probability/distribution/exponential.py +++ b/mindspore/nn/probability/distribution/exponential.py @@ -18,8 +18,7 @@ from mindspore.ops import operations as P from mindspore.ops import composite as C from mindspore.common import dtype as mstype from .distribution import Distribution -from ._utils.utils import cast_to_tensor, check_greater_zero, check_type, check_distribution_name,\ - raise_none_error +from ._utils.utils import cast_to_tensor, check_greater_zero, check_type, check_distribution_name, set_param_type from ._utils.custom_ops import exp_generic, log_generic class Exponential(Distribution): @@ -121,15 +120,19 @@ class Exponential(Distribution): valid_dtype = mstype.float_type check_type(dtype, valid_dtype, type(self).__name__) super(Exponential, self).__init__(seed, dtype, name, param) - self.parameter_type = dtype + self.parameter_type = set_param_type({'rate': rate}, self.dtype) if rate is not None: self._rate = cast_to_tensor(rate, self.parameter_type) check_greater_zero(self._rate, "rate") else: self._rate = rate + self.default_parameters = [self.rate] + self.parameter_names = ['rate'] + self.minval = np.finfo(np.float).tiny + # ops needed for the class self.exp = exp_generic self.log = log_generic @@ -156,28 +159,16 @@ class Exponential(Distribution): @property def rate(self): """ - Return rate of the distribution. + Return `rate` of the distribution. """ return self._rate - def _check_param(self, rate): - """ - Check availablity of distribution specific argument `rate`. - """ - if rate is not None: - if self.context_mode == 0: - self.checktensor(rate, 'rate') - else: - rate = self.checktensor(rate, 'rate') - return self.cast(rate, self.parameter_type) - return self.rate if self.rate is not None else raise_none_error('rate') - def _mean(self, rate=None): r""" .. math:: MEAN(EXP) = \frac{1.0}{\lambda}. """ - rate = self._check_param(rate) + rate = self._check_param_type(rate) return 1.0 / rate def _mode(self, rate=None): @@ -185,7 +176,7 @@ class Exponential(Distribution): .. math:: MODE(EXP) = 0. """ - rate = self._check_param(rate) + rate = self._check_param_type(rate) return self.fill(self.dtype, self.shape(rate), 0.) def _sd(self, rate=None): @@ -193,7 +184,7 @@ class Exponential(Distribution): .. math:: sd(EXP) = \frac{1.0}{\lambda}. """ - rate = self._check_param(rate) + rate = self._check_param_type(rate) return 1.0 / rate def _entropy(self, rate=None): @@ -201,7 +192,7 @@ class Exponential(Distribution): .. math:: H(Exp) = 1 - \log(\lambda). """ - rate = self._check_param(rate) + rate = self._check_param_type(rate) return 1.0 - self.log(rate) def _cross_entropy(self, dist, rate_b, rate=None): @@ -234,7 +225,7 @@ class Exponential(Distribution): """ value = self._check_value(value, "value") value = self.cast(value, self.dtype) - rate = self._check_param(rate) + rate = self._check_param_type(rate) prob = self.log(rate) - rate * value zeros = self.fill(self.dtypeop(prob), self.shape(prob), 0.0) neginf = self.fill(self.dtypeop(prob), self.shape(prob), -np.inf) @@ -257,7 +248,7 @@ class Exponential(Distribution): """ value = self._check_value(value, 'value') value = self.cast(value, self.dtype) - rate = self._check_param(rate) + rate = self._check_param_type(rate) cdf = 1.0 - self.exp(-1. * rate * value) zeros = self.fill(self.dtypeop(cdf), self.shape(cdf), 0.0) comp = self.less(value, zeros) @@ -279,7 +270,7 @@ class Exponential(Distribution): """ value = self._check_value(value, 'value') value = self.cast(value, self.dtype) - rate = self._check_param(rate) + rate = self._check_param_type(rate) sf = -1. * rate * value zeros = self.fill(self.dtypeop(sf), self.shape(sf), 0.0) comp = self.less(value, zeros) @@ -297,7 +288,7 @@ class Exponential(Distribution): check_distribution_name(dist, 'Exponential') rate_b = self._check_value(rate_b, 'rate_b') rate_b = self.cast(rate_b, self.parameter_type) - rate_a = self._check_param(rate) + rate_a = self._check_param_type(rate) return self.log(rate_a) - self.log(rate_b) + rate_b / rate_a - 1.0 def _sample(self, shape=(), rate=None): @@ -312,7 +303,7 @@ class Exponential(Distribution): Tensor, shape is shape + batch_shape. """ shape = self.checktuple(shape, 'shape') - rate = self._check_param(rate) + rate = self._check_param_type(rate) origin_shape = shape + self.shape(rate) if origin_shape == (): sample_shape = (1,) diff --git a/mindspore/nn/probability/distribution/geometric.py b/mindspore/nn/probability/distribution/geometric.py index b746db0ad07..0c3ca959cef 100644 --- a/mindspore/nn/probability/distribution/geometric.py +++ b/mindspore/nn/probability/distribution/geometric.py @@ -19,7 +19,7 @@ from mindspore.ops import composite as C from mindspore.common import dtype as mstype from .distribution import Distribution from ._utils.utils import cast_to_tensor, check_prob, check_type, check_distribution_name,\ - raise_none_error + set_param_type from ._utils.custom_ops import exp_generic, log_generic @@ -123,13 +123,16 @@ class Geometric(Distribution): valid_dtype = mstype.int_type + mstype.uint_type + mstype.float_type check_type(dtype, valid_dtype, type(self).__name__) super(Geometric, self).__init__(seed, dtype, name, param) - self.parameter_type = mstype.float32 + self.parameter_type = set_param_type({'probs1': probs}, mstype.float32) if probs is not None: self._probs = cast_to_tensor(probs, self.parameter_type) check_prob(self._probs) else: self._probs = probs + self.default_parameters = [self.probs] + self.parameter_names = ['probs1'] + self.minval = np.finfo(np.float).tiny # ops needed for the class @@ -164,24 +167,12 @@ class Geometric(Distribution): """ return self._probs - def _check_param(self, probs1): - """ - Check availablity of distribution specific args probs1. - """ - if probs1 is not None: - if self.context_mode == 0: - self.checktensor(probs1, 'probs1') - else: - probs1 = self.checktensor(probs1, 'probs1') - return self.cast(probs1, self.parameter_type) - return self.probs if self.probs is not None else raise_none_error('probs1') - def _mean(self, probs1=None): r""" .. math:: MEAN(Geo) = \fratc{1 - probs1}{probs1} """ - probs1 = self._check_param(probs1) + probs1 = self._check_param_type(probs1) return (1. - probs1) / probs1 def _mode(self, probs1=None): @@ -189,7 +180,7 @@ class Geometric(Distribution): .. math:: MODE(Geo) = 0 """ - probs1 = self._check_param(probs1) + probs1 = self._check_param_type(probs1) return self.fill(self.dtypeop(probs1), self.shape(probs1), 0.) def _var(self, probs1=None): @@ -197,7 +188,7 @@ class Geometric(Distribution): .. math:: VAR(Geo) = \frac{1 - probs1}{probs1 ^ {2}} """ - probs1 = self._check_param(probs1) + probs1 = self._check_param_type(probs1) return (1.0 - probs1) / self.sq(probs1) def _entropy(self, probs1=None): @@ -205,7 +196,7 @@ class Geometric(Distribution): .. math:: H(Geo) = \frac{-1 * probs0 \log_2 (1-probs0)\ - prob1 * \log_2 (1-probs1)\ }{probs1} """ - probs1 = self._check_param(probs1) + probs1 = self._check_param_type(probs1) probs0 = 1.0 - probs1 return (-probs0 * self.log(probs0) - probs1 * self.log(probs1)) / probs1 @@ -236,7 +227,7 @@ class Geometric(Distribution): value = self._check_value(value, 'value') value = self.cast(value, mstype.float32) value = self.floor(value) - probs1 = self._check_param(probs1) + probs1 = self._check_param_type(probs1) pmf = self.exp(self.log(1.0 - probs1) * value + self.log(probs1)) zeros = self.fill(self.dtypeop(probs1), self.shape(pmf), 0.0) comp = self.less(value, zeros) @@ -258,7 +249,7 @@ class Geometric(Distribution): value = self._check_value(value, 'value') value = self.cast(value, mstype.float32) value = self.floor(value) - probs1 = self._check_param(probs1) + probs1 = self._check_param_type(probs1) probs0 = 1.0 - probs1 cdf = 1.0 - self.pow(probs0, value + 1.0) zeros = self.fill(self.dtypeop(probs1), self.shape(cdf), 0.0) @@ -280,7 +271,7 @@ class Geometric(Distribution): check_distribution_name(dist, 'Geometric') probs1_b = self._check_value(probs1_b, 'probs1_b') probs1_b = self.cast(probs1_b, self.parameter_type) - probs1_a = self._check_param(probs1) + probs1_a = self._check_param_type(probs1) probs0_a = 1.0 - probs1_a probs0_b = 1.0 - probs1_b return self.log(probs1_a / probs1_b) + (probs0_a / probs1_a) * self.log(probs0_a / probs0_b) @@ -297,7 +288,7 @@ class Geometric(Distribution): Tensor, shape is shape + batch_shape. """ shape = self.checktuple(shape, 'shape') - probs1 = self._check_param(probs1) + probs1 = self._check_param_type(probs1) origin_shape = shape + self.shape(probs1) if origin_shape == (): sample_shape = (1,) diff --git a/mindspore/nn/probability/distribution/normal.py b/mindspore/nn/probability/distribution/normal.py index a6fcb01d157..a53e49efd7a 100644 --- a/mindspore/nn/probability/distribution/normal.py +++ b/mindspore/nn/probability/distribution/normal.py @@ -19,7 +19,7 @@ from mindspore.ops import composite as C from mindspore.common import dtype as mstype from .distribution import Distribution from ._utils.utils import cast_to_tensor, check_greater_zero, check_type, check_distribution_name,\ - raise_none_error, common_dtype + set_param_type from ._utils.custom_ops import exp_generic, expm1_generic, log_generic, erf_generic class Normal(Distribution): @@ -127,14 +127,17 @@ class Normal(Distribution): valid_dtype = mstype.float_type check_type(dtype, valid_dtype, type(self).__name__) super(Normal, self).__init__(seed, dtype, name, param) - self.parameter_type = common_dtype(mean, 'mean', sd, 'sd', self.dtype) + self.parameter_type = set_param_type({'mean': mean, 'sd': sd}, self.dtype) if mean is not None and sd is not None: self._mean_value = cast_to_tensor(mean, self.parameter_type) self._sd_value = cast_to_tensor(sd, self.parameter_type) check_greater_zero(self._sd_value, "Standard deviation") else: - self._mean_value = mean - self._sd_value = sd + self._mean_value = mean if mean is None else cast_to_tensor(mean, self.parameter_type) + self._sd_value = sd if sd is None else cast_to_tensor(sd, self.parameter_type) + + self.default_parameters = [self._mean_value, self._sd_value] + self.parameter_names = ['mean', 'sd'] #ops needed for the class self.exp = exp_generic @@ -159,51 +162,25 @@ class Normal(Distribution): str_info = f'batch_shape = {self._broadcast_shape}' return str_info - def _check_param(self, mean, sd): - """ - Check availablity of distribution specific args `mean` and `sd`. - """ - if mean is not None: - if self.context_mode == 0: - self.checktensor(mean, 'mean') - else: - mean = self.checktensor(mean, 'mean') - else: - mean = self._mean_value if self._mean_value is not None else raise_none_error('mean') - if sd is not None: - if self.context_mode == 0: - self.checktensor(sd, 'sd') - else: - sd = self.checktensor(sd, 'sd') - else: - sd = self._sd_value if self._sd_value is not None else raise_none_error('sd') - batch_shape = self.shape(mean + sd) - mean = mean * self.fill(self.dtypeop(mean), batch_shape, 1.0) - sd = sd * self.fill(self.dtypeop(sd), batch_shape, 1.0) - self.sametypeshape(mean, sd) - mean = self.cast(mean, self.parameter_type) - sd = self.cast(sd, self.parameter_type) - return mean, sd - def _mean(self, mean=None, sd=None): """ The mean of the distribution. """ - mean, sd = self._check_param(mean, sd) + mean, sd = self._check_param_type(mean, sd) return mean def _mode(self, mean=None, sd=None): """ The mode of the distribution. """ - mean, sd = self._check_param(mean, sd) + mean, sd = self._check_param_type(mean, sd) return mean def _sd(self, mean=None, sd=None): """ The standard deviation of the distribution. """ - mean, sd = self._check_param(mean, sd) + mean, sd = self._check_param_type(mean, sd) return sd def _entropy(self, mean=None, sd=None): @@ -213,7 +190,7 @@ class Normal(Distribution): .. math:: H(X) = \log(\sqrt(numpy.e * 2. * numpy.pi * \sq(\sigma))) """ - mean, sd = self._check_param(mean, sd) + mean, sd = self._check_param_type(mean, sd) return self.log(self.sqrt(self.const(np.e * 2. * np.pi))) + self.log(sd) def _cross_entropy(self, dist, mean_b, sd_b, mean=None, sd=None): @@ -244,7 +221,7 @@ class Normal(Distribution): """ value = self._check_value(value, 'value') value = self.cast(value, self.dtype) - mean, sd = self._check_param(mean, sd) + mean, sd = self._check_param_type(mean, sd) unnormalized_log_prob = -1. * (self.sq(value - mean)) / (2. * self.sq(sd)) neg_normalization = -1. * self.log(self.const(2. * np.pi)) / 2. - self.log(sd) return unnormalized_log_prob + neg_normalization @@ -263,7 +240,7 @@ class Normal(Distribution): """ value = self._check_value(value, 'value') value = self.cast(value, self.dtype) - mean, sd = self._check_param(mean, sd) + mean, sd = self._check_param_type(mean, sd) sqrt2 = self.sqrt(self.const(2.0)) adjusted = (value - mean) / (sd * sqrt2) return 0.5 * (1.0 + self.erf(adjusted)) @@ -288,7 +265,7 @@ class Normal(Distribution): sd_b = self._check_value(sd_b, 'sd_b') mean_b = self.cast(mean_b, self.parameter_type) sd_b = self.cast(sd_b, self.parameter_type) - mean_a, sd_a = self._check_param(mean, sd) + mean_a, sd_a = self._check_param_type(mean, sd) diff_log_scale = self.log(sd_a) - self.log(sd_b) squared_diff = self.sq(mean_a / sd_b - mean_b / sd_b) return 0.5 * squared_diff + 0.5 * self.expm1(2 * diff_log_scale) - diff_log_scale @@ -306,7 +283,7 @@ class Normal(Distribution): Tensor, shape is shape + batch_shape. """ shape = self.checktuple(shape, 'shape') - mean, sd = self._check_param(mean, sd) + mean, sd = self._check_param_type(mean, sd) batch_shape = self.shape(mean + sd) origin_shape = shape + batch_shape if origin_shape == (): diff --git a/mindspore/nn/probability/distribution/uniform.py b/mindspore/nn/probability/distribution/uniform.py index 7de0d7f33f0..87224c5ee22 100644 --- a/mindspore/nn/probability/distribution/uniform.py +++ b/mindspore/nn/probability/distribution/uniform.py @@ -18,7 +18,7 @@ from mindspore.ops import composite as C from mindspore.common import dtype as mstype from .distribution import Distribution from ._utils.utils import cast_to_tensor, check_greater, check_type, check_distribution_name,\ - raise_none_error, common_dtype + set_param_type from ._utils.custom_ops import exp_generic, log_generic class Uniform(Distribution): @@ -126,14 +126,17 @@ class Uniform(Distribution): valid_dtype = mstype.float_type check_type(dtype, valid_dtype, type(self).__name__) super(Uniform, self).__init__(seed, dtype, name, param) - self.parameter_type = common_dtype(low, 'low', high, 'high', self.dtype) + self.parameter_type = set_param_type({'low': low, 'high': high}, self.dtype) if low is not None and high is not None: - self._low = cast_to_tensor(low, dtype) - self._high = cast_to_tensor(high, dtype) + self._low = cast_to_tensor(low, self.parameter_type) + self._high = cast_to_tensor(high, self.parameter_type) check_greater(self.low, self.high, "low value", "high value") else: - self._low = low - self._high = high + self._low = low if low is None else cast_to_tensor(low, self.parameter_type) + self._high = high if high is None else cast_to_tensor(high, self.parameter_type) + + self.default_parameters = [self.low, self.high] + self.parameter_names = ['low', 'high'] # ops needed for the class self.exp = exp_generic @@ -162,32 +165,6 @@ class Uniform(Distribution): str_info = f'batch_shape = {self._broadcast_shape}' return str_info - def _check_param(self, low, high): - """ - Check availablity of distribution specific args `low` and `high`. - """ - if low is not None: - if self.context_mode == 0: - self.checktensor(low, 'low') - else: - low = self.checktensor(low, 'low') - else: - low = self.low if self.low is not None else raise_none_error('low') - if high is not None: - if self.context_mode == 0: - self.checktensor(high, 'high') - else: - high = self.checktensor(high, 'high') - else: - high = self.high if self.high is not None else raise_none_error('high') - batch_shape = self.shape(high - low) - high = high * self.fill(self.dtypeop(high), batch_shape, 1.0) - low = low * self.fill(self.dtypeop(low), batch_shape, 1.0) - self.sametypeshape(high, low) - low = self.cast(low, self.parameter_type) - high = self.cast(high, self.parameter_type) - return low, high - @property def low(self): """ @@ -209,7 +186,7 @@ class Uniform(Distribution): .. math:: range(U) = high -low """ - low, high = self._check_param(low, high) + low, high = self._check_param_type(low, high) return high - low def _mean(self, low=None, high=None): @@ -217,7 +194,7 @@ class Uniform(Distribution): .. math:: MEAN(U) = \frac{low + high}{2}. """ - low, high = self._check_param(low, high) + low, high = self._check_param_type(low, high) return (low + high) / 2. def _var(self, low=None, high=None): @@ -225,7 +202,7 @@ class Uniform(Distribution): .. math:: VAR(U) = \frac{(high -low) ^ 2}{12}. """ - low, high = self._check_param(low, high) + low, high = self._check_param_type(low, high) return self.sq(high - low) / 12.0 def _entropy(self, low=None, high=None): @@ -233,7 +210,7 @@ class Uniform(Distribution): .. math:: H(U) = \log(high - low). """ - low, high = self._check_param(low, high) + low, high = self._check_param_type(low, high) return self.log(high - low) def _cross_entropy(self, dist, low_b, high_b, low=None, high=None): @@ -266,7 +243,7 @@ class Uniform(Distribution): """ value = self._check_value(value, 'value') value = self.cast(value, self.dtype) - low, high = self._check_param(low, high) + low, high = self._check_param_type(low, high) neg_ones = self.fill(self.dtype, self.shape(value), -1.0) prob = self.exp(neg_ones * self.log(high - low)) broadcast_shape = self.shape(prob) @@ -292,7 +269,7 @@ class Uniform(Distribution): low_b = self.cast(low_b, self.parameter_type) high_b = self._check_value(high_b, 'high_b') high_b = self.cast(high_b, self.parameter_type) - low_a, high_a = self._check_param(low, high) + low_a, high_a = self._check_param_type(low, high) kl = self.log(high_b - low_b) - self.log(high_a - low_a) comp = self.logicaland(self.lessequal(low_b, low_a), self.lessequal(high_a, high_b)) return self.select(comp, kl, self.log(self.zeroslike(kl))) @@ -313,7 +290,7 @@ class Uniform(Distribution): """ value = self._check_value(value, 'value') value = self.cast(value, self.dtype) - low, high = self._check_param(low, high) + low, high = self._check_param_type(low, high) prob = (value - low) / (high - low) broadcast_shape = self.shape(prob) zeros = self.fill(self.dtypeop(prob), broadcast_shape, 0.0) @@ -336,7 +313,7 @@ class Uniform(Distribution): Tensor, shape is shape + batch_shape. """ shape = self.checktuple(shape, 'shape') - low, high = self._check_param(low, high) + low, high = self._check_param_type(low, high) broadcast_shape = self.shape(low + high) origin_shape = shape + broadcast_shape if origin_shape == (): diff --git a/tests/ut/python/nn/distribution/test_utils.py b/tests/ut/python/nn/distribution/test_utils.py new file mode 100644 index 00000000000..fbb8eb80509 --- /dev/null +++ b/tests/ut/python/nn/distribution/test_utils.py @@ -0,0 +1,182 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" +Test util functions used in distribution classes. +""" +import numpy as np +import pytest + +from mindspore.nn.cell import Cell +from mindspore import context +from mindspore import dtype +from mindspore import Tensor +from mindspore.common.parameter import Parameter +from mindspore.nn.probability.distribution._utils.utils import set_param_type, \ + cast_to_tensor, CheckTuple, CheckTensor + +def test_set_param_type(): + """ + Test set_param_type function. + """ + tensor_fp16 = Tensor(0.1, dtype=dtype.float16) + tensor_fp32 = Tensor(0.1, dtype=dtype.float32) + tensor_fp64 = Tensor(0.1, dtype=dtype.float64) + tensor_int32 = Tensor(0.1, dtype=dtype.int32) + array_fp32 = np.array(1.0).astype(np.float32) + array_fp64 = np.array(1.0).astype(np.float64) + array_int32 = np.array(1.0).astype(np.int32) + + dict1 = {'a': tensor_fp32, 'b': 1.0, 'c': tensor_fp32} + dict2 = {'a': tensor_fp32, 'b': 1.0, 'c': tensor_fp64} + dict3 = {'a': tensor_int32, 'b': 1.0, 'c': tensor_int32} + dict4 = {'a': array_fp32, 'b': 1.0, 'c': tensor_fp32} + dict5 = {'a': array_fp32, 'b': 1.0, 'c': array_fp64} + dict6 = {'a': array_fp32, 'b': 1.0, 'c': array_int32} + dict7 = {'a': 1.0} + dict8 = {'a': 1.0, 'b': 1.0, 'c': 1.0} + dict9 = {'a': tensor_fp16, 'b': tensor_fp16, 'c': tensor_fp16} + dict10 = {'a': tensor_fp64, 'b': tensor_fp64, 'c': tensor_fp64} + dict11 = {'a': array_fp64, 'b': array_fp64, 'c': tensor_fp64} + + ans1 = set_param_type(dict1, dtype.float16) + assert ans1 == dtype.float32 + + with pytest.raises(TypeError): + set_param_type(dict2, dtype.float32) + + ans3 = set_param_type(dict3, dtype.float16) + assert ans3 == dtype.float32 + ans4 = set_param_type(dict4, dtype.float16) + assert ans4 == dtype.float32 + + with pytest.raises(TypeError): + set_param_type(dict5, dtype.float32) + with pytest.raises(TypeError): + set_param_type(dict6, dtype.float32) + + ans7 = set_param_type(dict7, dtype.float32) + assert ans7 == dtype.float32 + ans8 = set_param_type(dict8, dtype.float32) + assert ans8 == dtype.float32 + ans9 = set_param_type(dict9, dtype.float32) + assert ans9 == dtype.float16 + ans10 = set_param_type(dict10, dtype.float32) + assert ans10 == dtype.float32 + ans11 = set_param_type(dict11, dtype.float32) + assert ans11 == dtype.float32 + +def test_cast_to_tensor(): + """ + Test cast_to_tensor. + """ + with pytest.raises(ValueError): + cast_to_tensor(None, dtype.float32) + with pytest.raises(TypeError): + cast_to_tensor(True, dtype.float32) + with pytest.raises(TypeError): + cast_to_tensor({'a': 1, 'b': 2}, dtype.float32) + with pytest.raises(TypeError): + cast_to_tensor('tensor', dtype.float32) + + ans1 = cast_to_tensor(Parameter(Tensor(0.1, dtype=dtype.float32), 'param')) + assert isinstance(ans1, Parameter) + ans2 = cast_to_tensor(np.array(1.0).astype(np.float32)) + assert isinstance(ans2, Tensor) + ans3 = cast_to_tensor([1.0, 2.0]) + assert isinstance(ans3, Tensor) + ans4 = cast_to_tensor(Tensor(0.1, dtype=dtype.float32), dtype.float32) + assert isinstance(ans4, Tensor) + ans5 = cast_to_tensor(0.1, dtype.float32) + assert isinstance(ans5, Tensor) + ans6 = cast_to_tensor(1, dtype.float32) + assert isinstance(ans6, Tensor) + +class Net(Cell): + """ + Test class: CheckTuple. + """ + def __init__(self, value): + super(Net, self).__init__() + self.checktuple = CheckTuple() + self.value = value + + def construct(self, value=None): + if value is None: + return self.checktuple(self.value, 'input') + return self.checktuple(value, 'input') + +def test_check_tuple(): + """ + Test CheckTuple. + """ + net1 = Net((1, 2, 3)) + ans1 = net1() + assert isinstance(ans1, tuple) + + with pytest.raises(TypeError): + net2 = Net('tuple') + net2() + + context.set_context(mode=context.GRAPH_MODE) + net3 = Net((1, 2, 3)) + ans3 = net3() + assert isinstance(ans3, tuple) + + with pytest.raises(TypeError): + net4 = Net('tuple') + net4() + +class Net1(Cell): + """ + Test class: CheckTensor. + """ + def __init__(self, value): + super(Net1, self).__init__() + self.checktensor = CheckTensor() + self.value = value + self.context = context.get_context('mode') + + def construct(self, value=None): + value = self.value if value is None else value + if self.context == 0: + self.checktensor(value, 'input') + return value + return self.checktensor(value, 'input') + +def test_check_tensor(): + """ + Test CheckTensor. + """ + value = Tensor(0.1, dtype=dtype.float32) + net1 = Net1(value) + ans1 = net1() + assert isinstance(ans1, Tensor) + ans1 = net1(value) + assert isinstance(ans1, Tensor) + + with pytest.raises(TypeError): + net2 = Net1('tuple') + net2() + + context.set_context(mode=context.GRAPH_MODE) + net3 = Net1(value) + ans3 = net3() + assert isinstance(ans3, Tensor) + ans3 = net3(value) + assert isinstance(ans3, Tensor) + + with pytest.raises(TypeError): + net4 = Net1('tuple') + net4()