!6020 Add common check_param_type and set_param_type in distribution

Merge pull request !6020 from XunDeng/new_check_param
mindspore-ci-bot 2020-09-13 13:54:39 +08:00 committed by Gitee
commit c10341dfb7
8 changed files with 476 additions and 243 deletions

View File

@ -110,6 +110,8 @@ def check_scalar_from_param(params):
Notes: String parameters are excluded.
"""
for value in params.values():
if value is None:
continue
if isinstance(value, (msp.bijector.Bijector, msp.distribution.Distribution)):
return params['distribution'].is_scalar_batch
if isinstance(value, Parameter):
@ -358,23 +360,29 @@ class CheckTensor(PrimitiveWithInfer):
return x
raise TypeError(f"For {name}, input type should be a Tensor or Parameter.")
def common_dtype(arg_a, name_a, arg_b, name_b, hint_type):
def set_param_type(args, hint_type):
"""
check if arg_a and arg_b have the same dtype.
Find the common type among arguments.
Args:
args (dict): dictionary of arguments, {'name':value}.
hint_type (mindspore.dtype): hint type to return.
Raises:
TypeError: if tensors in `args` do not have the same dtype.
"""
if hasattr(arg_a, 'dtype') and hasattr(arg_b, 'dtype'):
if isinstance(arg_a, np.ndarray):
a_dtype = mstype.pytype_to_dtype(arg_a.dtype)
else:
a_dtype = arg_a.dtype
if isinstance(arg_b, np.ndarray):
b_dtype = mstype.pytype_to_dtype(arg_b.dtype)
else:
b_dtype = arg_b.dtype
if a_dtype != b_dtype:
raise TypeError(f"{name_a} and {name_b} should have the same dtype.")
int_type = mstype.int_type + mstype.uint_type
if a_dtype in int_type or a_dtype == mstype.float64:
return mstype.float32
return a_dtype
return hint_type
common_dtype = None
for name, arg in args.items():
if hasattr(arg, 'dtype'):
if isinstance(arg, np.ndarray):
cur_dtype = mstype.pytype_to_dtype(arg.dtype)
else:
cur_dtype = arg.dtype
if common_dtype is None:
common_dtype = cur_dtype
elif cur_dtype != common_dtype:
raise TypeError(f"{name} should have the same dtype as other arguments.")
int_type = mstype.int_type + mstype.uint_type
if common_dtype in int_type or common_dtype == mstype.float64:
return mstype.float32
return hint_type if common_dtype is None else common_dtype
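A quick usage sketch of the promotion rules above (the import path matches the new test file at the bottom of this diff; assumes a build containing this patch; the argument names are just illustrative):

import numpy as np
from mindspore import Tensor, dtype as mstype
from mindspore.nn.probability.distribution._utils.utils import set_param_type

# All dtype-carrying arguments agree on float16, so float16 beats the hint.
args = {'probs': Tensor(0.1, dtype=mstype.float16), 'scale': 1.0}
assert set_param_type(args, mstype.float32) == mstype.float16

# Integer dtypes (and float64) are promoted to float32.
assert set_param_type({'rate': np.array(1, np.int32)}, mstype.float16) == mstype.float32

# No dtype-carrying arguments at all: fall back to the hint type.
assert set_param_type({'probs': 1.0}, mstype.float32) == mstype.float32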

View File

@ -17,7 +17,7 @@ from mindspore.common import dtype as mstype
from mindspore.ops import operations as P
from mindspore.ops import composite as C
from .distribution import Distribution
from ._utils.utils import cast_to_tensor, check_prob, check_type, check_distribution_name, raise_none_error
from ._utils.utils import cast_to_tensor, check_prob, check_type, check_distribution_name, set_param_type
from ._utils.custom_ops import exp_generic, log_generic, erf_generic
@ -119,13 +119,16 @@ class Bernoulli(Distribution):
valid_dtype = mstype.int_type + mstype.uint_type + mstype.float_type
check_type(dtype, valid_dtype, type(self).__name__)
super(Bernoulli, self).__init__(seed, dtype, name, param)
self.parameter_type = mstype.float32
self.parameter_type = set_param_type({'probs1': probs}, mstype.float32)
if probs is not None:
self._probs = cast_to_tensor(probs, mstype.float32)
self._probs = cast_to_tensor(probs, self.parameter_type)
check_prob(self.probs)
else:
self._probs = probs
self.default_parameters = [self.probs]
self.parameter_names = ['probs1']
# ops needed for the class
self.exp = exp_generic
self.log = log_generic
@ -157,24 +160,12 @@ class Bernoulli(Distribution):
"""
return self._probs
def _check_param(self, probs1):
"""
Check availability of distribution specific argument `probs1`.
"""
if probs1 is not None:
if self.context_mode == 0:
self.checktensor(probs1, 'probs1')
else:
probs1 = self.checktensor(probs1, 'probs1')
return self.cast(probs1, self.parameter_type)
return self.probs if self.probs is not None else raise_none_error('probs1')
def _mean(self, probs1=None):
r"""
.. math::
MEAN(B) = probs1
"""
probs1 = self._check_param(probs1)
probs1 = self._check_param_type(probs1)
return probs1
def _mode(self, probs1=None):
@ -182,7 +173,7 @@ class Bernoulli(Distribution):
.. math::
MODE(B) = 1 if probs1 > 0.5 else 0
"""
probs1 = self._check_param(probs1)
probs1 = self._check_param_type(probs1)
prob_type = self.dtypeop(probs1)
zeros = self.fill(prob_type, self.shape(probs1), 0.0)
ones = self.fill(prob_type, self.shape(probs1), 1.0)
@ -194,7 +185,7 @@ class Bernoulli(Distribution):
.. math::
VAR(B) = probs1 * probs0
"""
probs1 = self._check_param(probs1)
probs1 = self._check_param_type(probs1)
probs0 = 1.0 - probs1
return self.exp(self.log(probs0) + self.log(probs1))
@ -203,11 +194,11 @@ class Bernoulli(Distribution):
.. math::
H(B) = -probs0 * \log(probs0) - probs1 * \log(probs1)
"""
probs1 = self._check_param(probs1)
probs1 = self._check_param_type(probs1)
probs0 = 1.0 - probs1
return -(probs0 * self.log(probs0)) - (probs1 * self.log(probs1))
def _cross_entropy(self, dist, probs1_b, probs1_a=None):
def _cross_entropy(self, dist, probs1_b, probs1=None):
"""
Evaluate cross_entropy between Bernoulli distributions.
@ -217,7 +208,7 @@ class Bernoulli(Distribution):
probs1_a (Tensor): `probs1` of distribution a. Default: self.probs.
"""
check_distribution_name(dist, 'Bernoulli')
return self._entropy(probs1_a) + self._kl_loss(dist, probs1_b, probs1_a)
return self._entropy(probs1) + self._kl_loss(dist, probs1_b, probs1)
def _log_prob(self, value, probs1=None):
r"""
@ -233,7 +224,7 @@ class Bernoulli(Distribution):
"""
value = self._check_value(value, 'value')
value = self.cast(value, mstype.float32)
probs1 = self._check_param(probs1)
probs1 = self._check_param_type(probs1)
probs0 = 1.0 - probs1
return self.log(probs1) * value + self.log(probs0) * (1.0 - value)
@ -253,7 +244,7 @@ class Bernoulli(Distribution):
value = self._check_value(value, 'value')
value = self.cast(value, mstype.float32)
value = self.floor(value)
probs1 = self._check_param(probs1)
probs1 = self._check_param_type(probs1)
prob_type = self.dtypeop(probs1)
value = value * self.fill(prob_type, self.shape(probs1), 1.0)
probs0 = 1.0 - probs1 * self.fill(prob_type, self.shape(value), 1.0)
@ -264,7 +255,7 @@ class Bernoulli(Distribution):
less_than_zero = self.select(comp_zero, zeros, probs0)
return self.select(comp_one, less_than_zero, ones)
def _kl_loss(self, dist, probs1_b, probs1_a=None):
def _kl_loss(self, dist, probs1_b, probs1=None):
r"""
Evaluate Bernoulli-Bernoulli KL divergence, i.e. KL(a||b).
@ -280,7 +271,7 @@ class Bernoulli(Distribution):
check_distribution_name(dist, 'Bernoulli')
probs1_b = self._check_value(probs1_b, 'probs1_b')
probs1_b = self.cast(probs1_b, self.parameter_type)
probs1_a = self._check_param(probs1_a)
probs1_a = self._check_param_type(probs1)
probs0_a = 1.0 - probs1_a
probs0_b = 1.0 - probs1_b
return probs1_a * self.log(probs1_a / probs1_b) + probs0_a * self.log(probs0_a / probs0_b)
@ -297,7 +288,7 @@ class Bernoulli(Distribution):
Tensor, shape is shape + batch_shape.
"""
shape = self.checktuple(shape, 'shape')
probs1 = self._check_param(probs1)
probs1 = self._check_param_type(probs1)
origin_shape = shape + self.shape(probs1)
if origin_shape == ():
sample_shape = (1,)
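As a usage sketch of what this file's changes enable: `probs` can either be bound at construction (with `parameter_type` inferred via `set_param_type`) or supplied per call as a dist_spec_arg, validated by `_check_param_type`. A minimal PyNative-mode example, assuming a build with this patch:

import mindspore.nn.probability.distribution as msd
from mindspore import Tensor, dtype as mstype

# probs bound at construction; parameter_type inferred from the input.
b1 = msd.Bernoulli(0.7, dtype=mstype.int32)
print(b1.mean())        # ~0.7

# probs left as None; each call must pass probs1 as a dist_spec_arg.
b2 = msd.Bernoulli(dtype=mstype.int32)
print(b2.mean(Tensor(0.7, dtype=mstype.float32)))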

View File

@ -18,7 +18,8 @@ from mindspore.nn.cell import Cell
from mindspore._checkparam import Validator as validator
from mindspore._checkparam import Rel
from mindspore.common import get_seed
from ._utils.utils import calc_broadcast_shape_from_param, check_scalar_from_param, cast_type_for_device
from ._utils.utils import calc_broadcast_shape_from_param, check_scalar_from_param, cast_type_for_device,\
raise_none_error
from ._utils.utils import CheckTuple, CheckTensor
@ -115,6 +116,51 @@ class Distribution(Cell):
def broadcast_shape(self):
return self._broadcast_shape
def _check_param_type(self, *args):
"""
Check the availability and validity of default parameters and dist_spec_args.
dist_spec_args passed in must be tensors. If a default parameter of the distribution
is None, the corresponding dist_spec_arg must be passed in through `args`.
"""
broadcast_shape = None
common_dtype = None
out = []
for arg, name, default in zip(args, self.parameter_names, self.default_parameters):
# check if the argument is a Tensor
if arg is not None:
if self.context_mode == 0:
self.checktensor(arg, name)
else:
arg = self.checktensor(arg, name)
else:
arg = default if default is not None else raise_none_error(name)
# broadcast if the number of args > 1
if broadcast_shape is None:
broadcast_shape = self.shape(arg)
common_dtype = self.dtypeop(arg)
else:
ones = self.fill(self.dtypeop(arg), broadcast_shape, 1.0)
broadcast_shape = self.shape(arg + ones)
# check if the arguments have the same dtype
arg = arg * self.fill(self.dtypeop(arg), broadcast_shape, 1.0)
dtype_tensor = self.fill(common_dtype, broadcast_shape, 1.0)
self.sametypeshape(arg, dtype_tensor)
arg = self.cast(arg, self.parameter_type)
out.append(arg)
if len(out) == 1:
return out[0]
# broadcast all args to broadcast_shape
result = ()
for arg in out:
arg = arg * self.fill(self.dtypeop(arg), broadcast_shape, 1.0)
result = result + (arg,)
return result
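The contract here is that each subclass populates `parameter_names` and `default_parameters` in matching order, then funnels its dist_spec_args through `_check_param_type`. A minimal sketch of that wiring (hypothetical `MyDist`; it mirrors the Bernoulli/Normal hunks elsewhere in this PR, which import `set_param_type` and `cast_to_tensor` from `_utils.utils`):

class MyDist(Distribution):
    def __init__(self, loc=None, scale=None, seed=0, dtype=mstype.float32, name="MyDist"):
        param = dict(locals())
        super(MyDist, self).__init__(seed, dtype, name, param)
        self.parameter_type = set_param_type({'loc': loc, 'scale': scale}, self.dtype)
        self._loc = loc if loc is None else cast_to_tensor(loc, self.parameter_type)
        self._scale = scale if scale is None else cast_to_tensor(scale, self.parameter_type)
        # order must match the order of *args in every _xxx method below
        self.default_parameters = [self._loc, self._scale]
        self.parameter_names = ['loc', 'scale']

    def _mean(self, loc=None, scale=None):
        # validates, falls back to defaults, broadcasts, and casts in one call
        loc, scale = self._check_param_type(loc, scale)
        return loc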
def _check_value(self, value, name):
"""
Check availability of `value` as a Tensor.
@ -211,163 +257,203 @@ class Distribution(Cell):
if hasattr(self, '_cross_entropy'):
self._call_cross_entropy = self._cross_entropy
def log_prob(self, *args, **kwargs):
def log_prob(self, value, *args, **kwargs):
"""
Evaluate the log probability (pdf or pmf) at the given value.
Note:
The argument `args` must include `value`.
dist_spec_args are optional.
"""
return self._call_log_prob(*args, **kwargs)
Args:
value (Tensor): value to be evaluated.
*args (list): the list of positional arguments forwarded to subclasses.
**kwargs (dictionary): the dictionary of keyword arguments forwarded to subclasses.
def _calc_prob_from_log_prob(self, *args, **kwargs):
Note:
A distribution can be optionally passed to the function by passing its dist_spec_args through
`args` or `kwargs`.
"""
return self._call_log_prob(value, *args, **kwargs)
def _calc_prob_from_log_prob(self, value, *args, **kwargs):
r"""
Evaluate prob from log probability.
.. math::
probability(x) = \exp(log_likelihood(x))
"""
return self.exp(self._log_prob(*args, **kwargs))
return self.exp(self._log_prob(value, *args, **kwargs))
def prob(self, *args, **kwargs):
def prob(self, value, *args, **kwargs):
"""
Evaluate the probability (pdf or pmf) at the given value.
Note:
The argument `args` must include `value`.
dist_spec_args are optional.
"""
return self._call_prob(*args, **kwargs)
Args:
value (Tensor): value to be evaluated.
*args (list): the list of positional arguments forwarded to subclasses.
**kwargs (dictionary): the dictionary of keyword arguments forwarded to subclasses.
def _calc_log_prob_from_prob(self, *args, **kwargs):
Note:
A distribution can be optionally passed to the function by passing its dist_spec_args through
`args` or `kwargs`.
"""
return self._call_prob(value, *args, **kwargs)
def _calc_log_prob_from_prob(self, value, *args, **kwargs):
r"""
Evaluate log probability from probability.
.. math::
log_prob(x) = \log(prob(x))
"""
return self.log(self._prob(*args, **kwargs))
return self.log(self._prob(value, *args, **kwargs))
def cdf(self, *args, **kwargs):
def cdf(self, value, *args, **kwargs):
"""
Evaluate the cdf at the given value.
Note:
The argument `args` must include `value`.
dist_spec_args are optional.
"""
return self._call_cdf(*args, **kwargs)
Args:
value (Tensor): value to be evaluated.
*args (list): the list of positional arguments forwarded to subclasses.
**kwargs (dictionary): the dictionary of keyword arguments forwarded to subclasses.
def _calc_cdf_from_log_cdf(self, *args, **kwargs):
Note:
A distribution can be optionally passed to the function by passing its dist_spec_args through
`args` or `kwargs`.
"""
return self._call_cdf(value, *args, **kwargs)
def _calc_cdf_from_log_cdf(self, value, *args, **kwargs):
r"""
Evaluate cdf from log_cdf.
.. math::
cdf(x) = \exp(log_cdf(x))
"""
return self.exp(self._log_cdf(*args, **kwargs))
return self.exp(self._log_cdf(value, *args, **kwargs))
def _calc_cdf_from_survival(self, *args, **kwargs):
def _calc_cdf_from_survival(self, value, *args, **kwargs):
r"""
Evaluate cdf from survival function.
.. math::
cdf(x) = 1 - (survival_function(x))
"""
return 1.0 - self._survival_function(*args, **kwargs)
return 1.0 - self._survival_function(value, *args, **kwargs)
def _calc_cdf_from_log_survival(self, *args, **kwargs):
def _calc_cdf_from_log_survival(self, value, *args, **kwargs):
r"""
Evaluate cdf from log survival function.
.. math::
cdf(x) = 1 - (\exp(log_survival(x)))
"""
return 1.0 - self.exp(self._log_survival(*args, **kwargs))
return 1.0 - self.exp(self._log_survival(value, *args, **kwargs))
def log_cdf(self, *args, **kwargs):
def log_cdf(self, value, *args, **kwargs):
"""
Evaluate the log cdf at the given value.
Note:
The argument `args` must include `value`.
dist_spec_args are optional.
"""
return self._call_log_cdf(*args, **kwargs)
Args:
value (Tensor): value to be evaluated.
*args (list): the list of positional arguments forwarded to subclasses.
**kwargs (dictionary): the dictionary of keyword arguments forwarded to subclasses.
def _calc_log_cdf_from_call_cdf(self, *args, **kwargs):
Note:
A distribution can be optionally passed to the function by passing its dist_spec_args through
`args` or `kwargs`.
"""
return self._call_log_cdf(value, *args, **kwargs)
def _calc_log_cdf_from_call_cdf(self, value, *args, **kwargs):
r"""
Evaluate log cdf from cdf.
.. math::
log_cdf(x) = \log(cdf(x))
"""
return self.log(self._call_cdf(*args, **kwargs))
return self.log(self._call_cdf(value, *args, **kwargs))
def survival_function(self, *args, **kwargs):
def survival_function(self, value, *args, **kwargs):
"""
Evaluate the survival function at the given value.
Note:
The argument `args` must include `value`.
dist_spec_args are optional.
"""
return self._call_survival(*args, **kwargs)
Args:
value (Tensor): value to be evaluated.
*args (list): the list of positional arguments forwarded to subclasses.
**kwargs (dictionary): the dictionary of keyword arguments forwarded to subclasses.
def _calc_survival_from_call_cdf(self, *args, **kwargs):
Note:
A distribution can be optionally passed to the function by passing its dist_spec_args through
`args` or `kwargs`.
"""
return self._call_survival(value, *args, **kwargs)
def _calc_survival_from_call_cdf(self, value, *args, **kwargs):
r"""
Evaluate survival function from cdf.
.. math::
survival_function(x) = 1 - (cdf(x))
"""
return 1.0 - self._call_cdf(*args, **kwargs)
return 1.0 - self._call_cdf(value, *args, **kwargs)
def _calc_survival_from_log_survival(self, *args, **kwargs):
def _calc_survival_from_log_survival(self, value, *args, **kwargs):
r"""
Evaluate survival function from log survival function.
.. math::
survival(x) = \exp(log_survival(x))
"""
return self.exp(self._log_survival(*args, **kwargs))
return self.exp(self._log_survival(value, *args, **kwargs))
def log_survival(self, *args, **kwargs):
def log_survival(self, value, *args, **kwargs):
"""
Evaluate the log survival function at the given value.
Note:
The arguments `args` must include `value`.
dist_spec_args are optional.
"""
return self._call_log_survival(*args, **kwargs)
Args:
value (Tensor): value to be evaluated.
*args (list): the list of positional arguments forwarded to subclasses.
**kwargs (dictionary): the dictionary of keyword arguments forwarded to subclasses.
def _calc_log_survival_from_call_survival(self, *args, **kwargs):
Note:
A distribution can be optionally passed to the function by passing its dist_spec_args through
`args` or `kwargs`.
"""
return self._call_log_survival(value, *args, **kwargs)
def _calc_log_survival_from_call_survival(self, value, *args, **kwargs):
r"""
Evaluate log survival function from survival function.
.. math::
log_survival(x) = \log(survival_function(x))
"""
return self.log(self._call_survival(*args, **kwargs))
return self.log(self._call_survival(value, *args, **kwargs))
def kl_loss(self, *args, **kwargs):
def kl_loss(self, dist, *args, **kwargs):
"""
Evaluate the KL divergence, i.e. KL(a||b).
Args:
dist (str): type of the distribution.
*args (list): the list of positional arguments forwarded to subclasses.
**kwargs (dictionary): the dictionary of keyword arguments forwarded to subclasses.
Note:
The argument `args` must include the type of the distribution, parameters of distribution b.
Parameters for distribution a are optional.
dist_spec_args of distribution b must be passed to the function through `args` or `kwargs`.
Passing in dist_spec_args of distribution a is optional.
"""
return self._kl_loss(*args, **kwargs)
return self._kl_loss(dist, *args, **kwargs)
def mean(self, *args, **kwargs):
"""
Evaluate the mean.
Args:
*args (list): the list of positional arguments forwarded to subclasses.
**kwargs (dictionary): the dictionary of keyword arguments forwarded to subclasses.
Note:
dist_spec_args are optional.
A distribution can be optionally passed to the function by passing its *dist_spec_args* through
*args* or *kwargs*.
"""
return self._mean(*args, **kwargs)
@ -375,8 +461,13 @@ class Distribution(Cell):
"""
Evaluate the mode.
Args:
*args (list): the list of positional arguments forwarded to subclasses.
**kwargs (dictionary): the dictionary of keyword arguments forwarded to subclasses.
Note:
dist_spec_args are optional.
A distribution can be optionally passed to the function by passing its *dist_spec_args* through
*args* or *kwargs*.
"""
return self._mode(*args, **kwargs)
@ -384,8 +475,13 @@ class Distribution(Cell):
"""
Evaluate the standard deviation.
Args:
*args (list): the list of positional arguments forwarded to subclasses.
**kwargs (dictionary): the dictionary of keyword arguments forwarded to subclasses.
Note:
dist_spec_args are optional.
A distribution can be optionally passed to the function by passing its *dist_spec_args* through
*args* or *kwargs*.
"""
return self._call_sd(*args, **kwargs)
@ -393,8 +489,13 @@ class Distribution(Cell):
"""
Evaluate the variance.
Args:
*args (list): the list of positional arguments forwarded to subclasses.
**kwargs (dictionary): the dictionary of keyword arguments forwarded to subclasses.
Note:
dist_spec_args are optional.
A distribution can be optionally passed to the function by passing its *dist_spec_args* through
*args* or *kwargs*.
"""
return self._call_var(*args, **kwargs)
@ -420,37 +521,52 @@ class Distribution(Cell):
"""
Evaluate the entropy.
Args:
*args (list): the list of positional arguments forwarded to subclasses.
**kwargs (dictionary): the dictionary of keyword arguments forwarded to subclasses.
Note:
dist_spec_args are optional.
A distribution can be optionally passed to the function by passing its *dist_spec_args* through
*args* or *kwargs*.
"""
return self._entropy(*args, **kwargs)
def cross_entropy(self, *args, **kwargs):
def cross_entropy(self, dist, *args, **kwargs):
"""
Evaluate the cross_entropy between distributions a and b.
Note:
The argument `args` must include the type of the distribution, parameters of distribution b.
Parameters for distribution a are optional.
"""
return self._call_cross_entropy(*args, **kwargs)
Args:
dist (str): type of the distribution.
*args (list): the list of positional arguments forwarded to subclasses.
**kwargs (dictionary): the dictionary of keyword arguments forwarded to subclasses.
def _calc_cross_entropy(self, *args, **kwargs):
Note:
dist_spec_args of distribution b must be passed to the function through `args` or `kwargs`.
Passing in dist_spec_args of distribution a is optional.
"""
return self._call_cross_entropy(dist, *args, **kwargs)
def _calc_cross_entropy(self, dist, *args, **kwargs):
r"""
Evaluate cross_entropy from entropy and kl divergence.
.. math::
H(X, Y) = H(X) + KL(X||Y)
"""
return self._entropy(*args, **kwargs) + self._kl_loss(*args, **kwargs)
return self._entropy(*args, **kwargs) + self._kl_loss(dist, *args, **kwargs)
def sample(self, *args, **kwargs):
"""
Sampling function.
Args:
shape (tuple): shape of the sample.
*args (list): the list of positional arguments forwarded to subclasses.
**kwargs (dictionary): the dictionary of keyword arguments forwarded to subclasses.
Note:
The shape of the sample defaults to ().
dist_spec_args are optional.
A distribution can be optionally passed to the function by passing its *dist_spec_args* through
*args* or *kwargs*.
"""
return self._sample(*args, **kwargs)
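The net effect of this file's hunks: `value` and `dist` are now explicit leading parameters instead of being buried in `*args`. A minimal sketch of the resulting call patterns (PyNative mode, build with this patch assumed):

import mindspore.nn.probability.distribution as msd
from mindspore import Tensor, dtype as mstype

n = msd.Normal(0.0, 1.0, dtype=mstype.float32)
x = Tensor(0.5, dtype=mstype.float32)

n.log_prob(x)                                    # value is explicit
n.log_prob(x, Tensor(1.0), Tensor(2.0))          # optional dist_spec_args: mean, sd
n.kl_loss('Normal', Tensor(1.0), Tensor(2.0))    # dist first, then b's parameters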

View File

@ -18,8 +18,7 @@ from mindspore.ops import operations as P
from mindspore.ops import composite as C
from mindspore.common import dtype as mstype
from .distribution import Distribution
from ._utils.utils import cast_to_tensor, check_greater_zero, check_type, check_distribution_name,\
raise_none_error
from ._utils.utils import cast_to_tensor, check_greater_zero, check_type, check_distribution_name, set_param_type
from ._utils.custom_ops import exp_generic, log_generic
class Exponential(Distribution):
@ -121,15 +120,19 @@ class Exponential(Distribution):
valid_dtype = mstype.float_type
check_type(dtype, valid_dtype, type(self).__name__)
super(Exponential, self).__init__(seed, dtype, name, param)
self.parameter_type = dtype
self.parameter_type = set_param_type({'rate': rate}, self.dtype)
if rate is not None:
self._rate = cast_to_tensor(rate, self.parameter_type)
check_greater_zero(self._rate, "rate")
else:
self._rate = rate
self.default_parameters = [self.rate]
self.parameter_names = ['rate']
self.minval = np.finfo(np.float).tiny
# ops needed for the class
self.exp = exp_generic
self.log = log_generic
@ -156,28 +159,16 @@ class Exponential(Distribution):
@property
def rate(self):
"""
Return rate of the distribution.
Return `rate` of the distribution.
"""
return self._rate
def _check_param(self, rate):
"""
Check availability of distribution specific argument `rate`.
"""
if rate is not None:
if self.context_mode == 0:
self.checktensor(rate, 'rate')
else:
rate = self.checktensor(rate, 'rate')
return self.cast(rate, self.parameter_type)
return self.rate if self.rate is not None else raise_none_error('rate')
def _mean(self, rate=None):
r"""
.. math::
MEAN(EXP) = \frac{1.0}{\lambda}.
"""
rate = self._check_param(rate)
rate = self._check_param_type(rate)
return 1.0 / rate
def _mode(self, rate=None):
@ -185,7 +176,7 @@ class Exponential(Distribution):
.. math::
MODE(EXP) = 0.
"""
rate = self._check_param(rate)
rate = self._check_param_type(rate)
return self.fill(self.dtype, self.shape(rate), 0.)
def _sd(self, rate=None):
@ -193,7 +184,7 @@ class Exponential(Distribution):
.. math::
sd(EXP) = \frac{1.0}{\lambda}.
"""
rate = self._check_param(rate)
rate = self._check_param_type(rate)
return 1.0 / rate
def _entropy(self, rate=None):
@ -201,7 +192,7 @@ class Exponential(Distribution):
.. math::
H(Exp) = 1 - \log(\lambda).
"""
rate = self._check_param(rate)
rate = self._check_param_type(rate)
return 1.0 - self.log(rate)
def _cross_entropy(self, dist, rate_b, rate=None):
@ -234,7 +225,7 @@ class Exponential(Distribution):
"""
value = self._check_value(value, "value")
value = self.cast(value, self.dtype)
rate = self._check_param(rate)
rate = self._check_param_type(rate)
prob = self.log(rate) - rate * value
zeros = self.fill(self.dtypeop(prob), self.shape(prob), 0.0)
neginf = self.fill(self.dtypeop(prob), self.shape(prob), -np.inf)
@ -257,7 +248,7 @@ class Exponential(Distribution):
"""
value = self._check_value(value, 'value')
value = self.cast(value, self.dtype)
rate = self._check_param(rate)
rate = self._check_param_type(rate)
cdf = 1.0 - self.exp(-1. * rate * value)
zeros = self.fill(self.dtypeop(cdf), self.shape(cdf), 0.0)
comp = self.less(value, zeros)
@ -279,7 +270,7 @@ class Exponential(Distribution):
"""
value = self._check_value(value, 'value')
value = self.cast(value, self.dtype)
rate = self._check_param(rate)
rate = self._check_param_type(rate)
sf = -1. * rate * value
zeros = self.fill(self.dtypeop(sf), self.shape(sf), 0.0)
comp = self.less(value, zeros)
@ -297,7 +288,7 @@ class Exponential(Distribution):
check_distribution_name(dist, 'Exponential')
rate_b = self._check_value(rate_b, 'rate_b')
rate_b = self.cast(rate_b, self.parameter_type)
rate_a = self._check_param(rate)
rate_a = self._check_param_type(rate)
return self.log(rate_a) - self.log(rate_b) + rate_b / rate_a - 1.0
def _sample(self, shape=(), rate=None):
@ -312,7 +303,7 @@ class Exponential(Distribution):
Tensor, shape is shape + batch_shape.
"""
shape = self.checktuple(shape, 'shape')
rate = self._check_param(rate)
rate = self._check_param_type(rate)
origin_shape = shape + self.shape(rate)
if origin_shape == ():
sample_shape = (1,)
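The closed form returned by `_kl_loss` above, \log(\lambda_a) - \log(\lambda_b) + \lambda_b / \lambda_a - 1, can be sanity-checked with plain NumPy (no MindSpore needed):

import numpy as np

rate_a, rate_b = 2.0, 0.5
x = np.linspace(1e-6, 50.0, 1_000_000)
dx = x[1] - x[0]
pdf_a = rate_a * np.exp(-rate_a * x)
pdf_b = rate_b * np.exp(-rate_b * x)
kl_numeric = np.sum(pdf_a * np.log(pdf_a / pdf_b)) * dx   # integral of p_a * log(p_a / p_b)
kl_closed = np.log(rate_a) - np.log(rate_b) + rate_b / rate_a - 1.0
assert abs(kl_numeric - kl_closed) < 1e-3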

View File

@ -19,7 +19,7 @@ from mindspore.ops import composite as C
from mindspore.common import dtype as mstype
from .distribution import Distribution
from ._utils.utils import cast_to_tensor, check_prob, check_type, check_distribution_name,\
raise_none_error
set_param_type
from ._utils.custom_ops import exp_generic, log_generic
@ -123,13 +123,16 @@ class Geometric(Distribution):
valid_dtype = mstype.int_type + mstype.uint_type + mstype.float_type
check_type(dtype, valid_dtype, type(self).__name__)
super(Geometric, self).__init__(seed, dtype, name, param)
self.parameter_type = mstype.float32
self.parameter_type = set_param_type({'probs1': probs}, mstype.float32)
if probs is not None:
self._probs = cast_to_tensor(probs, self.parameter_type)
check_prob(self._probs)
else:
self._probs = probs
self.default_parameters = [self.probs]
self.parameter_names = ['probs1']
self.minval = np.finfo(np.float).tiny
# ops needed for the class
@ -164,24 +167,12 @@ class Geometric(Distribution):
"""
return self._probs
def _check_param(self, probs1):
"""
Check availability of distribution specific argument `probs1`.
"""
if probs1 is not None:
if self.context_mode == 0:
self.checktensor(probs1, 'probs1')
else:
probs1 = self.checktensor(probs1, 'probs1')
return self.cast(probs1, self.parameter_type)
return self.probs if self.probs is not None else raise_none_error('probs1')
def _mean(self, probs1=None):
r"""
.. math::
MEAN(Geo) = \frac{1 - probs1}{probs1}
"""
probs1 = self._check_param(probs1)
probs1 = self._check_param_type(probs1)
return (1. - probs1) / probs1
def _mode(self, probs1=None):
@ -189,7 +180,7 @@ class Geometric(Distribution):
.. math::
MODE(Geo) = 0
"""
probs1 = self._check_param(probs1)
probs1 = self._check_param_type(probs1)
return self.fill(self.dtypeop(probs1), self.shape(probs1), 0.)
def _var(self, probs1=None):
@ -197,7 +188,7 @@ class Geometric(Distribution):
.. math::
VAR(Geo) = \frac{1 - probs1}{probs1 ^ {2}}
"""
probs1 = self._check_param(probs1)
probs1 = self._check_param_type(probs1)
return (1.0 - probs1) / self.sq(probs1)
def _entropy(self, probs1=None):
@ -205,7 +196,7 @@ class Geometric(Distribution):
.. math::
H(Geo) = \frac{-probs0 * \log(probs0) - probs1 * \log(probs1)}{probs1}
"""
probs1 = self._check_param(probs1)
probs1 = self._check_param_type(probs1)
probs0 = 1.0 - probs1
return (-probs0 * self.log(probs0) - probs1 * self.log(probs1)) / probs1
@ -236,7 +227,7 @@ class Geometric(Distribution):
value = self._check_value(value, 'value')
value = self.cast(value, mstype.float32)
value = self.floor(value)
probs1 = self._check_param(probs1)
probs1 = self._check_param_type(probs1)
pmf = self.exp(self.log(1.0 - probs1) * value + self.log(probs1))
zeros = self.fill(self.dtypeop(probs1), self.shape(pmf), 0.0)
comp = self.less(value, zeros)
@ -258,7 +249,7 @@ class Geometric(Distribution):
value = self._check_value(value, 'value')
value = self.cast(value, mstype.float32)
value = self.floor(value)
probs1 = self._check_param(probs1)
probs1 = self._check_param_type(probs1)
probs0 = 1.0 - probs1
cdf = 1.0 - self.pow(probs0, value + 1.0)
zeros = self.fill(self.dtypeop(probs1), self.shape(cdf), 0.0)
@ -280,7 +271,7 @@ class Geometric(Distribution):
check_distribution_name(dist, 'Geometric')
probs1_b = self._check_value(probs1_b, 'probs1_b')
probs1_b = self.cast(probs1_b, self.parameter_type)
probs1_a = self._check_param(probs1)
probs1_a = self._check_param_type(probs1)
probs0_a = 1.0 - probs1_a
probs0_b = 1.0 - probs1_b
return self.log(probs1_a / probs1_b) + (probs0_a / probs1_a) * self.log(probs0_a / probs0_b)
@ -297,7 +288,7 @@ class Geometric(Distribution):
Tensor, shape is shape + batch_shape.
"""
shape = self.checktuple(shape, 'shape')
probs1 = self._check_param(probs1)
probs1 = self._check_param_type(probs1)
origin_shape = shape + self.shape(probs1)
if origin_shape == ():
sample_shape = (1,)
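The entropy formula in `_entropy` above (MindSpore's Geometric counts failures before the first success, so the support starts at 0) can be checked against a direct series sum with NumPy alone:

import numpy as np

p = 0.3
k = np.arange(0, 500)
pmf = (1.0 - p) ** k * p                  # P(K = k), k = 0, 1, 2, ...
h_direct = -np.sum(pmf * np.log(pmf))
h_closed = (-(1.0 - p) * np.log(1.0 - p) - p * np.log(p)) / p
assert abs(h_direct - h_closed) < 1e-8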

View File

@ -19,7 +19,7 @@ from mindspore.ops import composite as C
from mindspore.common import dtype as mstype
from .distribution import Distribution
from ._utils.utils import cast_to_tensor, check_greater_zero, check_type, check_distribution_name,\
raise_none_error, common_dtype
set_param_type
from ._utils.custom_ops import exp_generic, expm1_generic, log_generic, erf_generic
class Normal(Distribution):
@ -127,14 +127,17 @@ class Normal(Distribution):
valid_dtype = mstype.float_type
check_type(dtype, valid_dtype, type(self).__name__)
super(Normal, self).__init__(seed, dtype, name, param)
self.parameter_type = common_dtype(mean, 'mean', sd, 'sd', self.dtype)
self.parameter_type = set_param_type({'mean': mean, 'sd': sd}, self.dtype)
if mean is not None and sd is not None:
self._mean_value = cast_to_tensor(mean, self.parameter_type)
self._sd_value = cast_to_tensor(sd, self.parameter_type)
check_greater_zero(self._sd_value, "Standard deviation")
else:
self._mean_value = mean
self._sd_value = sd
self._mean_value = mean if mean is None else cast_to_tensor(mean, self.parameter_type)
self._sd_value = sd if sd is None else cast_to_tensor(sd, self.parameter_type)
self.default_parameters = [self._mean_value, self._sd_value]
self.parameter_names = ['mean', 'sd']
#ops needed for the class
self.exp = exp_generic
@ -159,51 +162,25 @@ class Normal(Distribution):
str_info = f'batch_shape = {self._broadcast_shape}'
return str_info
def _check_param(self, mean, sd):
"""
Check availability of distribution specific arguments `mean` and `sd`.
"""
if mean is not None:
if self.context_mode == 0:
self.checktensor(mean, 'mean')
else:
mean = self.checktensor(mean, 'mean')
else:
mean = self._mean_value if self._mean_value is not None else raise_none_error('mean')
if sd is not None:
if self.context_mode == 0:
self.checktensor(sd, 'sd')
else:
sd = self.checktensor(sd, 'sd')
else:
sd = self._sd_value if self._sd_value is not None else raise_none_error('sd')
batch_shape = self.shape(mean + sd)
mean = mean * self.fill(self.dtypeop(mean), batch_shape, 1.0)
sd = sd * self.fill(self.dtypeop(sd), batch_shape, 1.0)
self.sametypeshape(mean, sd)
mean = self.cast(mean, self.parameter_type)
sd = self.cast(sd, self.parameter_type)
return mean, sd
def _mean(self, mean=None, sd=None):
"""
The mean of the distribution.
"""
mean, sd = self._check_param(mean, sd)
mean, sd = self._check_param_type(mean, sd)
return mean
def _mode(self, mean=None, sd=None):
"""
The mode of the distribution.
"""
mean, sd = self._check_param(mean, sd)
mean, sd = self._check_param_type(mean, sd)
return mean
def _sd(self, mean=None, sd=None):
"""
The standard deviation of the distribution.
"""
mean, sd = self._check_param(mean, sd)
mean, sd = self._check_param_type(mean, sd)
return sd
def _entropy(self, mean=None, sd=None):
@ -213,7 +190,7 @@ class Normal(Distribution):
.. math::
H(X) = \log(\sqrt{2 * numpy.e * numpy.pi * \sigma^2})
"""
mean, sd = self._check_param(mean, sd)
mean, sd = self._check_param_type(mean, sd)
return self.log(self.sqrt(self.const(np.e * 2. * np.pi))) + self.log(sd)
def _cross_entropy(self, dist, mean_b, sd_b, mean=None, sd=None):
@ -244,7 +221,7 @@ class Normal(Distribution):
"""
value = self._check_value(value, 'value')
value = self.cast(value, self.dtype)
mean, sd = self._check_param(mean, sd)
mean, sd = self._check_param_type(mean, sd)
unnormalized_log_prob = -1. * (self.sq(value - mean)) / (2. * self.sq(sd))
neg_normalization = -1. * self.log(self.const(2. * np.pi)) / 2. - self.log(sd)
return unnormalized_log_prob + neg_normalization
@ -263,7 +240,7 @@ class Normal(Distribution):
"""
value = self._check_value(value, 'value')
value = self.cast(value, self.dtype)
mean, sd = self._check_param(mean, sd)
mean, sd = self._check_param_type(mean, sd)
sqrt2 = self.sqrt(self.const(2.0))
adjusted = (value - mean) / (sd * sqrt2)
return 0.5 * (1.0 + self.erf(adjusted))
@ -288,7 +265,7 @@ class Normal(Distribution):
sd_b = self._check_value(sd_b, 'sd_b')
mean_b = self.cast(mean_b, self.parameter_type)
sd_b = self.cast(sd_b, self.parameter_type)
mean_a, sd_a = self._check_param(mean, sd)
mean_a, sd_a = self._check_param_type(mean, sd)
diff_log_scale = self.log(sd_a) - self.log(sd_b)
squared_diff = self.sq(mean_a / sd_b - mean_b / sd_b)
return 0.5 * squared_diff + 0.5 * self.expm1(2 * diff_log_scale) - diff_log_scale
@ -306,7 +283,7 @@ class Normal(Distribution):
Tensor, shape is shape + batch_shape.
"""
shape = self.checktuple(shape, 'shape')
mean, sd = self._check_param(mean, sd)
mean, sd = self._check_param_type(mean, sd)
batch_shape = self.shape(mean + sd)
origin_shape = shape + batch_shape
if origin_shape == ():
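The `_kl_loss` expression above is an algebraic rearrangement of the textbook Gaussian KL divergence; a NumPy check of the equivalence:

import numpy as np

mean_a, sd_a, mean_b, sd_b = 1.0, 2.0, -0.5, 3.0
diff_log_scale = np.log(sd_a) - np.log(sd_b)
squared_diff = (mean_a / sd_b - mean_b / sd_b) ** 2
kl_code = 0.5 * squared_diff + 0.5 * np.expm1(2.0 * diff_log_scale) - diff_log_scale
kl_textbook = np.log(sd_b / sd_a) + (sd_a ** 2 + (mean_a - mean_b) ** 2) / (2.0 * sd_b ** 2) - 0.5
assert abs(kl_code - kl_textbook) < 1e-12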

View File

@ -18,7 +18,7 @@ from mindspore.ops import composite as C
from mindspore.common import dtype as mstype
from .distribution import Distribution
from ._utils.utils import cast_to_tensor, check_greater, check_type, check_distribution_name,\
raise_none_error, common_dtype
set_param_type
from ._utils.custom_ops import exp_generic, log_generic
class Uniform(Distribution):
@ -126,14 +126,17 @@ class Uniform(Distribution):
valid_dtype = mstype.float_type
check_type(dtype, valid_dtype, type(self).__name__)
super(Uniform, self).__init__(seed, dtype, name, param)
self.parameter_type = common_dtype(low, 'low', high, 'high', self.dtype)
self.parameter_type = set_param_type({'low': low, 'high': high}, self.dtype)
if low is not None and high is not None:
self._low = cast_to_tensor(low, dtype)
self._high = cast_to_tensor(high, dtype)
self._low = cast_to_tensor(low, self.parameter_type)
self._high = cast_to_tensor(high, self.parameter_type)
check_greater(self.low, self.high, "low value", "high value")
else:
self._low = low
self._high = high
self._low = low if low is None else cast_to_tensor(low, self.parameter_type)
self._high = high if high is None else cast_to_tensor(high, self.parameter_type)
self.default_parameters = [self.low, self.high]
self.parameter_names = ['low', 'high']
# ops needed for the class
self.exp = exp_generic
@ -162,32 +165,6 @@ class Uniform(Distribution):
str_info = f'batch_shape = {self._broadcast_shape}'
return str_info
def _check_param(self, low, high):
"""
Check availability of distribution specific arguments `low` and `high`.
"""
if low is not None:
if self.context_mode == 0:
self.checktensor(low, 'low')
else:
low = self.checktensor(low, 'low')
else:
low = self.low if self.low is not None else raise_none_error('low')
if high is not None:
if self.context_mode == 0:
self.checktensor(high, 'high')
else:
high = self.checktensor(high, 'high')
else:
high = self.high if self.high is not None else raise_none_error('high')
batch_shape = self.shape(high - low)
high = high * self.fill(self.dtypeop(high), batch_shape, 1.0)
low = low * self.fill(self.dtypeop(low), batch_shape, 1.0)
self.sametypeshape(high, low)
low = self.cast(low, self.parameter_type)
high = self.cast(high, self.parameter_type)
return low, high
@property
def low(self):
"""
@ -209,7 +186,7 @@ class Uniform(Distribution):
.. math::
range(U) = high - low
"""
low, high = self._check_param(low, high)
low, high = self._check_param_type(low, high)
return high - low
def _mean(self, low=None, high=None):
@ -217,7 +194,7 @@ class Uniform(Distribution):
.. math::
MEAN(U) = \frac{low + high}{2}.
"""
low, high = self._check_param(low, high)
low, high = self._check_param_type(low, high)
return (low + high) / 2.
def _var(self, low=None, high=None):
@ -225,7 +202,7 @@ class Uniform(Distribution):
.. math::
VAR(U) = \frac{(high - low) ^ 2}{12}.
"""
low, high = self._check_param(low, high)
low, high = self._check_param_type(low, high)
return self.sq(high - low) / 12.0
def _entropy(self, low=None, high=None):
@ -233,7 +210,7 @@ class Uniform(Distribution):
.. math::
H(U) = \log(high - low).
"""
low, high = self._check_param(low, high)
low, high = self._check_param_type(low, high)
return self.log(high - low)
def _cross_entropy(self, dist, low_b, high_b, low=None, high=None):
@ -266,7 +243,7 @@ class Uniform(Distribution):
"""
value = self._check_value(value, 'value')
value = self.cast(value, self.dtype)
low, high = self._check_param(low, high)
low, high = self._check_param_type(low, high)
neg_ones = self.fill(self.dtype, self.shape(value), -1.0)
prob = self.exp(neg_ones * self.log(high - low))
broadcast_shape = self.shape(prob)
@ -292,7 +269,7 @@ class Uniform(Distribution):
low_b = self.cast(low_b, self.parameter_type)
high_b = self._check_value(high_b, 'high_b')
high_b = self.cast(high_b, self.parameter_type)
low_a, high_a = self._check_param(low, high)
low_a, high_a = self._check_param_type(low, high)
kl = self.log(high_b - low_b) - self.log(high_a - low_a)
comp = self.logicaland(self.lessequal(low_b, low_a), self.lessequal(high_a, high_b))
return self.select(comp, kl, self.log(self.zeroslike(kl)))
@ -313,7 +290,7 @@ class Uniform(Distribution):
"""
value = self._check_value(value, 'value')
value = self.cast(value, self.dtype)
low, high = self._check_param(low, high)
low, high = self._check_param_type(low, high)
prob = (value - low) / (high - low)
broadcast_shape = self.shape(prob)
zeros = self.fill(self.dtypeop(prob), broadcast_shape, 0.0)
@ -336,7 +313,7 @@ class Uniform(Distribution):
Tensor, shape is shape + batch_shape.
"""
shape = self.checktuple(shape, 'shape')
low, high = self._check_param(low, high)
low, high = self._check_param_type(low, high)
broadcast_shape = self.shape(low + high)
origin_shape = shape + broadcast_shape
if origin_shape == ():
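A usage sketch of the per-call pattern for Uniform after this change: with no parameters bound at construction, `low` and `high` must be supplied (and are broadcast against each other) on every call. PyNative mode, build with this patch assumed:

import mindspore.nn.probability.distribution as msd
from mindspore import Tensor, dtype as mstype

u = msd.Uniform(dtype=mstype.float32)    # no low/high bound at construction
low = Tensor([0.0, 0.0], dtype=mstype.float32)
high = Tensor([1.0, 2.0], dtype=mstype.float32)
print(u.mean(low, high))       # -> [0.5, 1.0]
print(u.entropy(low, high))    # -> [0.0, log(2)]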

View File

@ -0,0 +1,182 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Test util functions used in distribution classes.
"""
import numpy as np
import pytest
from mindspore.nn.cell import Cell
from mindspore import context
from mindspore import dtype
from mindspore import Tensor
from mindspore.common.parameter import Parameter
from mindspore.nn.probability.distribution._utils.utils import set_param_type, \
cast_to_tensor, CheckTuple, CheckTensor
def test_set_param_type():
"""
Test set_param_type function.
"""
tensor_fp16 = Tensor(0.1, dtype=dtype.float16)
tensor_fp32 = Tensor(0.1, dtype=dtype.float32)
tensor_fp64 = Tensor(0.1, dtype=dtype.float64)
tensor_int32 = Tensor(0.1, dtype=dtype.int32)
array_fp32 = np.array(1.0).astype(np.float32)
array_fp64 = np.array(1.0).astype(np.float64)
array_int32 = np.array(1.0).astype(np.int32)
dict1 = {'a': tensor_fp32, 'b': 1.0, 'c': tensor_fp32}
dict2 = {'a': tensor_fp32, 'b': 1.0, 'c': tensor_fp64}
dict3 = {'a': tensor_int32, 'b': 1.0, 'c': tensor_int32}
dict4 = {'a': array_fp32, 'b': 1.0, 'c': tensor_fp32}
dict5 = {'a': array_fp32, 'b': 1.0, 'c': array_fp64}
dict6 = {'a': array_fp32, 'b': 1.0, 'c': array_int32}
dict7 = {'a': 1.0}
dict8 = {'a': 1.0, 'b': 1.0, 'c': 1.0}
dict9 = {'a': tensor_fp16, 'b': tensor_fp16, 'c': tensor_fp16}
dict10 = {'a': tensor_fp64, 'b': tensor_fp64, 'c': tensor_fp64}
dict11 = {'a': array_fp64, 'b': array_fp64, 'c': tensor_fp64}
ans1 = set_param_type(dict1, dtype.float16)
assert ans1 == dtype.float32
with pytest.raises(TypeError):
set_param_type(dict2, dtype.float32)
ans3 = set_param_type(dict3, dtype.float16)
assert ans3 == dtype.float32
ans4 = set_param_type(dict4, dtype.float16)
assert ans4 == dtype.float32
with pytest.raises(TypeError):
set_param_type(dict5, dtype.float32)
with pytest.raises(TypeError):
set_param_type(dict6, dtype.float32)
ans7 = set_param_type(dict7, dtype.float32)
assert ans7 == dtype.float32
ans8 = set_param_type(dict8, dtype.float32)
assert ans8 == dtype.float32
ans9 = set_param_type(dict9, dtype.float32)
assert ans9 == dtype.float16
ans10 = set_param_type(dict10, dtype.float32)
assert ans10 == dtype.float32
ans11 = set_param_type(dict11, dtype.float32)
assert ans11 == dtype.float32
def test_cast_to_tensor():
"""
Test cast_to_tensor.
"""
with pytest.raises(ValueError):
cast_to_tensor(None, dtype.float32)
with pytest.raises(TypeError):
cast_to_tensor(True, dtype.float32)
with pytest.raises(TypeError):
cast_to_tensor({'a': 1, 'b': 2}, dtype.float32)
with pytest.raises(TypeError):
cast_to_tensor('tensor', dtype.float32)
ans1 = cast_to_tensor(Parameter(Tensor(0.1, dtype=dtype.float32), 'param'))
assert isinstance(ans1, Parameter)
ans2 = cast_to_tensor(np.array(1.0).astype(np.float32))
assert isinstance(ans2, Tensor)
ans3 = cast_to_tensor([1.0, 2.0])
assert isinstance(ans3, Tensor)
ans4 = cast_to_tensor(Tensor(0.1, dtype=dtype.float32), dtype.float32)
assert isinstance(ans4, Tensor)
ans5 = cast_to_tensor(0.1, dtype.float32)
assert isinstance(ans5, Tensor)
ans6 = cast_to_tensor(1, dtype.float32)
assert isinstance(ans6, Tensor)
class Net(Cell):
"""
Test class: CheckTuple.
"""
def __init__(self, value):
super(Net, self).__init__()
self.checktuple = CheckTuple()
self.value = value
def construct(self, value=None):
if value is None:
return self.checktuple(self.value, 'input')
return self.checktuple(value, 'input')
def test_check_tuple():
"""
Test CheckTuple.
"""
net1 = Net((1, 2, 3))
ans1 = net1()
assert isinstance(ans1, tuple)
with pytest.raises(TypeError):
net2 = Net('tuple')
net2()
context.set_context(mode=context.GRAPH_MODE)
net3 = Net((1, 2, 3))
ans3 = net3()
assert isinstance(ans3, tuple)
with pytest.raises(TypeError):
net4 = Net('tuple')
net4()
class Net1(Cell):
"""
Test class: CheckTensor.
"""
def __init__(self, value):
super(Net1, self).__init__()
self.checktensor = CheckTensor()
self.value = value
self.context = context.get_context('mode')
def construct(self, value=None):
value = self.value if value is None else value
if self.context == 0:
self.checktensor(value, 'input')
return value
return self.checktensor(value, 'input')
def test_check_tensor():
"""
Test CheckTensor.
"""
value = Tensor(0.1, dtype=dtype.float32)
net1 = Net1(value)
ans1 = net1()
assert isinstance(ans1, Tensor)
ans1 = net1(value)
assert isinstance(ans1, Tensor)
with pytest.raises(TypeError):
net2 = Net1('tuple')
net2()
context.set_context(mode=context.GRAPH_MODE)
net3 = Net1(value)
ans3 = net3()
assert isinstance(ans3, Tensor)
ans3 = net3(value)
assert isinstance(ans3, Tensor)
with pytest.raises(TypeError):
net4 = Net1('tuple')
net4()