forked from mindspore-Ecosystem/mindspore
!6020 Add common check_param_type and set_param_type in distribution
Merge pull request !6020 from XunDeng/new_check_param
commit c10341dfb7
@@ -110,6 +110,8 @@ def check_scalar_from_param(params):
     Notes: String parameters are excluded.
     """
     for value in params.values():
+        if value is None:
+            continue
         if isinstance(value, (msp.bijector.Bijector, msp.distribution.Distribution)):
             return params['distribution'].is_scalar_batch
         if isinstance(value, Parameter):
@@ -358,23 +360,29 @@ class CheckTensor(PrimitiveWithInfer):
             return x
         raise TypeError(f"For {name}, input type should be a Tensor or Parameter.")


-def common_dtype(arg_a, name_a, arg_b, name_b, hint_type):
+def set_param_type(args, hint_type):
     """
-    check if arg_a and arg_b have the same dtype.
+    Find the common type among arguments.
+
+    Args:
+        args (dict): dictionary of arguments, {'name':value}.
+        hint_type (mindspore.dtype): hint type to return.
+
+    Raises:
+        TypeError: if tensors in args are not the same dtype.
     """
-    if hasattr(arg_a, 'dtype') and hasattr(arg_b, 'dtype'):
-        if isinstance(arg_a, np.ndarray):
-            a_dtype = mstype.pytype_to_dtype(arg_a.dtype)
-        else:
-            a_dtype = arg_a.dtype
-        if isinstance(arg_b, np.ndarray):
-            b_dtype = mstype.pytype_to_dtype(arg_b.dtype)
-        else:
-            b_dtype = arg_b.dtype
-        if a_dtype != b_dtype:
-            raise TypeError(f"{name_a} and {name_b} should have the same dtype.")
+    common_dtype = None
+    for name, arg in args.items():
+        if hasattr(arg, 'dtype'):
+            if isinstance(arg, np.ndarray):
+                cur_dtype = mstype.pytype_to_dtype(arg.dtype)
+            else:
+                cur_dtype = arg.dtype
+            if common_dtype is None:
+                common_dtype = cur_dtype
+            elif cur_dtype != common_dtype:
+                raise TypeError(f"{name} should have the same dtype as other arguments.")
     int_type = mstype.int_type + mstype.uint_type
-    if a_dtype in int_type or a_dtype == mstype.float64:
+    if common_dtype in int_type or common_dtype == mstype.float64:
         return mstype.float32
-    return a_dtype
-    return hint_type
+    return hint_type if common_dtype is None else common_dtype
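set_param_type replaces the pairwise common_dtype helper with a dict-driven scan over all parameters. Below is a minimal, self-contained sketch of the promotion rule it applies; NumPy stands in for the mindspore dtype machinery, and the function name is illustrative rather than the real API:

    import numpy as np

    def set_param_type_sketch(args, hint_type):
        """Find the common dtype among args; fall back to hint_type when all are None."""
        common_dtype = None
        for name, arg in args.items():
            if hasattr(arg, 'dtype'):
                if common_dtype is None:
                    common_dtype = arg.dtype
                elif arg.dtype != common_dtype:
                    raise TypeError(f"{name} should have the same dtype as other arguments.")
        # integer and float64 parameters are promoted to float32 for downstream math
        if common_dtype is not None and (np.issubdtype(common_dtype, np.integer)
                                         or common_dtype == np.float64):
            return np.float32
        return hint_type if common_dtype is None else common_dtype

    print(set_param_type_sketch({'probs1': np.array([0.5], np.float16)}, np.float32))  # float16
    print(set_param_type_sketch({'probs1': None}, np.float32))  # falls back to the hint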
@@ -17,7 +17,7 @@ from mindspore.common import dtype as mstype
 from mindspore.ops import operations as P
 from mindspore.ops import composite as C
 from .distribution import Distribution
-from ._utils.utils import cast_to_tensor, check_prob, check_type, check_distribution_name, raise_none_error
+from ._utils.utils import cast_to_tensor, check_prob, check_type, check_distribution_name, set_param_type
 from ._utils.custom_ops import exp_generic, log_generic, erf_generic

@@ -119,13 +119,16 @@ class Bernoulli(Distribution):
         valid_dtype = mstype.int_type + mstype.uint_type + mstype.float_type
         check_type(dtype, valid_dtype, type(self).__name__)
         super(Bernoulli, self).__init__(seed, dtype, name, param)
-        self.parameter_type = mstype.float32
+        self.parameter_type = set_param_type({'probs1': probs}, mstype.float32)
         if probs is not None:
-            self._probs = cast_to_tensor(probs, mstype.float32)
+            self._probs = cast_to_tensor(probs, self.parameter_type)
             check_prob(self.probs)
         else:
             self._probs = probs

+        self.default_parameters = [self.probs]
+        self.parameter_names = ['probs1']
+
         # ops needed for the class
         self.exp = exp_generic
         self.log = log_generic
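The visible effect on Bernoulli: parameter_type now follows the dtype of probs instead of being pinned to float32. A hedged illustration, assuming the public Bernoulli constructor and the parameter_type attribute exactly as shown in this diff:

    import mindspore.nn.probability.distribution as msd
    from mindspore import Tensor
    from mindspore.common import dtype as mstype

    # Before this change parameter_type was unconditionally mstype.float32; now
    # set_param_type({'probs1': probs}, mstype.float32) keeps the probs dtype.
    b = msd.Bernoulli(probs=Tensor([0.3, 0.7], mstype.float16), dtype=mstype.int32)
    print(b.parameter_type)  # expected Float16; float32 only when probs is None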
@@ -157,24 +160,12 @@ class Bernoulli(Distribution):
         """
         return self._probs

-    def _check_param(self, probs1):
-        """
-        Check availability of distribution specific args `probs1`.
-        """
-        if probs1 is not None:
-            if self.context_mode == 0:
-                self.checktensor(probs1, 'probs1')
-            else:
-                probs1 = self.checktensor(probs1, 'probs1')
-            return self.cast(probs1, self.parameter_type)
-        return self.probs if self.probs is not None else raise_none_error('probs1')
-
     def _mean(self, probs1=None):
         r"""
         .. math::
             MEAN(B) = probs1
         """
-        probs1 = self._check_param(probs1)
+        probs1 = self._check_param_type(probs1)
         return probs1

     def _mode(self, probs1=None):
@@ -182,7 +173,7 @@ class Bernoulli(Distribution):
         .. math::
             MODE(B) = 1 if probs1 > 0.5 else 0
         """
-        probs1 = self._check_param(probs1)
+        probs1 = self._check_param_type(probs1)
         prob_type = self.dtypeop(probs1)
         zeros = self.fill(prob_type, self.shape(probs1), 0.0)
         ones = self.fill(prob_type, self.shape(probs1), 1.0)
@@ -194,7 +185,7 @@ class Bernoulli(Distribution):
         .. math::
             VAR(B) = probs1 * probs0
         """
-        probs1 = self._check_param(probs1)
+        probs1 = self._check_param_type(probs1)
         probs0 = 1.0 - probs1
         return self.exp(self.log(probs0) + self.log(probs1))

@@ -203,11 +194,11 @@ class Bernoulli(Distribution):
         .. math::
             H(B) = -probs0 * \log(probs0) - probs1 * \log(probs1)
         """
-        probs1 = self._check_param(probs1)
+        probs1 = self._check_param_type(probs1)
         probs0 = 1.0 - probs1
         return -(probs0 * self.log(probs0)) - (probs1 * self.log(probs1))

-    def _cross_entropy(self, dist, probs1_b, probs1_a=None):
+    def _cross_entropy(self, dist, probs1_b, probs1=None):
         """
         Evaluate cross_entropy between Bernoulli distributions.

@@ -217,7 +208,7 @@ class Bernoulli(Distribution):
             probs1_a (Tensor): `probs1` of distribution a. Default: self.probs.
         """
         check_distribution_name(dist, 'Bernoulli')
-        return self._entropy(probs1_a) + self._kl_loss(dist, probs1_b, probs1_a)
+        return self._entropy(probs1) + self._kl_loss(dist, probs1_b, probs1)

     def _log_prob(self, value, probs1=None):
         r"""
@@ -233,7 +224,7 @@ class Bernoulli(Distribution):
         """
         value = self._check_value(value, 'value')
         value = self.cast(value, mstype.float32)
-        probs1 = self._check_param(probs1)
+        probs1 = self._check_param_type(probs1)
         probs0 = 1.0 - probs1
         return self.log(probs1) * value + self.log(probs0) * (1.0 - value)

@@ -253,7 +244,7 @@ class Bernoulli(Distribution):
         value = self._check_value(value, 'value')
         value = self.cast(value, mstype.float32)
         value = self.floor(value)
-        probs1 = self._check_param(probs1)
+        probs1 = self._check_param_type(probs1)
         prob_type = self.dtypeop(probs1)
         value = value * self.fill(prob_type, self.shape(probs1), 1.0)
         probs0 = 1.0 - probs1 * self.fill(prob_type, self.shape(value), 1.0)
@@ -264,7 +255,7 @@ class Bernoulli(Distribution):
         less_than_zero = self.select(comp_zero, zeros, probs0)
         return self.select(comp_one, less_than_zero, ones)

-    def _kl_loss(self, dist, probs1_b, probs1_a=None):
+    def _kl_loss(self, dist, probs1_b, probs1=None):
         r"""
         Evaluate bernoulli-bernoulli kl divergence, i.e. KL(a||b).

@@ -280,7 +271,7 @@ class Bernoulli(Distribution):
         check_distribution_name(dist, 'Bernoulli')
         probs1_b = self._check_value(probs1_b, 'probs1_b')
         probs1_b = self.cast(probs1_b, self.parameter_type)
-        probs1_a = self._check_param(probs1_a)
+        probs1_a = self._check_param_type(probs1)
         probs0_a = 1.0 - probs1_a
         probs0_b = 1.0 - probs1_b
         return probs1_a * self.log(probs1_a / probs1_b) + probs0_a * self.log(probs0_a / probs0_b)
@@ -297,7 +288,7 @@ class Bernoulli(Distribution):
             Tensor, shape is shape + batch_shape.
         """
         shape = self.checktuple(shape, 'shape')
-        probs1 = self._check_param(probs1)
+        probs1 = self._check_param_type(probs1)
         origin_shape = shape + self.shape(probs1)
         if origin_shape == ():
             sample_shape = (1,)
@@ -18,7 +18,8 @@ from mindspore.nn.cell import Cell
 from mindspore._checkparam import Validator as validator
 from mindspore._checkparam import Rel
 from mindspore.common import get_seed
-from ._utils.utils import calc_broadcast_shape_from_param, check_scalar_from_param, cast_type_for_device
+from ._utils.utils import calc_broadcast_shape_from_param, check_scalar_from_param, cast_type_for_device,\
+    raise_none_error
 from ._utils.utils import CheckTuple, CheckTensor

@@ -115,6 +116,51 @@ class Distribution(Cell):
     def broadcast_shape(self):
         return self._broadcast_shape

+    def _check_param_type(self, *args):
+        """
+        Check the availability and validity of default parameters and dist_spec_args.
+        dist_spec_args passed in must be tensors. If default parameter of the distribution
+        is None, its parameter must be passed in through `args`.
+        """
+        broadcast_shape = None
+        common_dtype = None
+        out = []
+
+        for arg, name, default in zip(args, self.parameter_names, self.default_parameters):
+            # check if the argument is a Tensor
+            if arg is not None:
+                if self.context_mode == 0:
+                    self.checktensor(arg, name)
+                else:
+                    arg = self.checktensor(arg, name)
+            else:
+                arg = default if default is not None else raise_none_error(name)
+
+            # broadcast if the number of args > 1
+            if broadcast_shape is None:
+                broadcast_shape = self.shape(arg)
+                common_dtype = self.dtypeop(arg)
+            else:
+                ones = self.fill(self.dtypeop(arg), broadcast_shape, 1.0)
+                broadcast_shape = self.shape(arg + ones)
+
+                # check if the arguments have the same dtype
+                arg = arg * self.fill(self.dtypeop(arg), broadcast_shape, 1.0)
+                dtype_tensor = self.fill(common_dtype, broadcast_shape, 1.0)
+                self.sametypeshape(arg, dtype_tensor)
+            arg = self.cast(arg, self.parameter_type)
+            out.append(arg)
+
+        if len(out) == 1:
+            return out[0]
+
+        # broadcast all args to broadcast_shape
+        result = ()
+        for arg in out:
+            arg = arg * self.fill(self.dtypeop(arg), broadcast_shape, 1.0)
+            result = result + (arg,)
+        return result
+
     def _check_value(self, value, name):
         """
         Check availability of `value` as a Tensor.
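Note how _check_param_type computes broadcast shapes by adding a filled ones tensor rather than calling a dedicated shape-inference op, which keeps the method usable under graph mode. A NumPy analogue of the two tricks (illustrative code, not the MindSpore API):

    import numpy as np

    a = np.array([[0.1], [0.2]], np.float32)   # shape (2, 1)
    b = np.array([1.0, 2.0, 3.0], np.float32)  # shape (3,)

    # shape(arg + ones): adding ones of the running shape yields the broadcast shape
    ones = np.ones(a.shape, a.dtype)
    broadcast_shape = (b + ones).shape         # (2, 3)

    # arg * fill(dtype, broadcast_shape, 1.0): multiplying by ones materializes it
    b_broadcast = b * np.ones(broadcast_shape, b.dtype)
    print(broadcast_shape, b_broadcast.shape)  # (2, 3) (2, 3)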
@@ -211,163 +257,203 @@ class Distribution(Cell):
         if hasattr(self, '_cross_entropy'):
             self._call_cross_entropy = self._cross_entropy

-    def log_prob(self, *args, **kwargs):
+    def log_prob(self, value, *args, **kwargs):
         """
         Evaluate the log probability(pdf or pmf) at the given value.

-        Note:
-            The argument `args` must include `value`.
-            dist_spec_args are optional.
+        Args:
+            value (Tensor): value to be evaluated.
+            *args (list): the list of positional arguments forwarded to subclasses.
+            **kwargs (dictionary): the dictionary of keyword arguments forwarded to subclasses.
+
+        Note:
+            A distribution can be optionally passed to the function by passing its dist_spec_args through
+            `args` or `kwargs`.
         """
-        return self._call_log_prob(*args, **kwargs)
+        return self._call_log_prob(value, *args, **kwargs)

-    def _calc_prob_from_log_prob(self, *args, **kwargs):
+    def _calc_prob_from_log_prob(self, value, *args, **kwargs):
         r"""
         Evaluate prob from log probability.

         .. math::
             probability(x) = \exp(log_likelihood(x))
         """
-        return self.exp(self._log_prob(*args, **kwargs))
+        return self.exp(self._log_prob(value, *args, **kwargs))

-    def prob(self, *args, **kwargs):
+    def prob(self, value, *args, **kwargs):
         """
         Evaluate the probability (pdf or pmf) at given value.

-        Note:
-            The argument `args` must include `value`.
-            dist_spec_args are optional.
+        Args:
+            value (Tensor): value to be evaluated.
+            *args (list): the list of positional arguments forwarded to subclasses.
+            **kwargs (dictionary): the dictionary of keyword arguments forwarded to subclasses.
+
+        Note:
+            A distribution can be optionally passed to the function by passing its dist_spec_args through
+            `args` or `kwargs`.
         """
-        return self._call_prob(*args, **kwargs)
+        return self._call_prob(value, *args, **kwargs)

-    def _calc_log_prob_from_prob(self, *args, **kwargs):
+    def _calc_log_prob_from_prob(self, value, *args, **kwargs):
         r"""
         Evaluate log probability from probability.

         .. math::
             log_prob(x) = \log(prob(x))
         """
-        return self.log(self._prob(*args, **kwargs))
+        return self.log(self._prob(value, *args, **kwargs))

-    def cdf(self, *args, **kwargs):
+    def cdf(self, value, *args, **kwargs):
         """
         Evaluate the cdf at given value.

-        Note:
-            The argument `args` must include `value`.
-            dist_spec_args are optional.
+        Args:
+            value (Tensor): value to be evaluated.
+            *args (list): the list of positional arguments forwarded to subclasses.
+            **kwargs (dictionary): the dictionary of keyword arguments forwarded to subclasses.
+
+        Note:
+            A distribution can be optionally passed to the function by passing its dist_spec_args through
+            `args` or `kwargs`.
         """
-        return self._call_cdf(*args, **kwargs)
+        return self._call_cdf(value, *args, **kwargs)

-    def _calc_cdf_from_log_cdf(self, *args, **kwargs):
+    def _calc_cdf_from_log_cdf(self, value, *args, **kwargs):
         r"""
         Evaluate cdf from log_cdf.

         .. math::
             cdf(x) = \exp(log_cdf(x))
         """
-        return self.exp(self._log_cdf(*args, **kwargs))
+        return self.exp(self._log_cdf(value, *args, **kwargs))

-    def _calc_cdf_from_survival(self, *args, **kwargs):
+    def _calc_cdf_from_survival(self, value, *args, **kwargs):
         r"""
         Evaluate cdf from survival function.

         .. math::
             cdf(x) = 1 - (survival_function(x))
         """
-        return 1.0 - self._survival_function(*args, **kwargs)
+        return 1.0 - self._survival_function(value, *args, **kwargs)

-    def _calc_cdf_from_log_survival(self, *args, **kwargs):
+    def _calc_cdf_from_log_survival(self, value, *args, **kwargs):
         r"""
         Evaluate cdf from log survival function.

         .. math::
             cdf(x) = 1 - (\exp(log_survival(x)))
         """
-        return 1.0 - self.exp(self._log_survival(*args, **kwargs))
+        return 1.0 - self.exp(self._log_survival(value, *args, **kwargs))

-    def log_cdf(self, *args, **kwargs):
+    def log_cdf(self, value, *args, **kwargs):
         """
         Evaluate the log cdf at given value.

-        Note:
-            The argument `args` must include `value`.
-            dist_spec_args are optional.
+        Args:
+            value (Tensor): value to be evaluated.
+            *args (list): the list of positional arguments forwarded to subclasses.
+            **kwargs (dictionary): the dictionary of keyword arguments forwarded to subclasses.
+
+        Note:
+            A distribution can be optionally passed to the function by passing its dist_spec_args through
+            `args` or `kwargs`.
         """
-        return self._call_log_cdf(*args, **kwargs)
+        return self._call_log_cdf(value, *args, **kwargs)

-    def _calc_log_cdf_from_call_cdf(self, *args, **kwargs):
+    def _calc_log_cdf_from_call_cdf(self, value, *args, **kwargs):
         r"""
         Evaluate log cdf from cdf.

         .. math::
             log_cdf(x) = \log(cdf(x))
         """
-        return self.log(self._call_cdf(*args, **kwargs))
+        return self.log(self._call_cdf(value, *args, **kwargs))

-    def survival_function(self, *args, **kwargs):
+    def survival_function(self, value, *args, **kwargs):
         """
         Evaluate the survival function at given value.

-        Note:
-            The argument `args` must include `value`.
-            dist_spec_args are optional.
+        Args:
+            value (Tensor): value to be evaluated.
+            *args (list): the list of positional arguments forwarded to subclasses.
+            **kwargs (dictionary): the dictionary of keyword arguments forwarded to subclasses.
+
+        Note:
+            A distribution can be optionally passed to the function by passing its dist_spec_args through
+            `args` or `kwargs`.
         """
-        return self._call_survival(*args, **kwargs)
+        return self._call_survival(value, *args, **kwargs)

-    def _calc_survival_from_call_cdf(self, *args, **kwargs):
+    def _calc_survival_from_call_cdf(self, value, *args, **kwargs):
         r"""
         Evaluate survival function from cdf.

         .. math::
             survival_function(x) = 1 - (cdf(x))
         """
-        return 1.0 - self._call_cdf(*args, **kwargs)
+        return 1.0 - self._call_cdf(value, *args, **kwargs)

-    def _calc_survival_from_log_survival(self, *args, **kwargs):
+    def _calc_survival_from_log_survival(self, value, *args, **kwargs):
         r"""
         Evaluate survival function from log survival function.

         .. math::
             survival(x) = \exp(log_survival(x))
         """
-        return self.exp(self._log_survival(*args, **kwargs))
+        return self.exp(self._log_survival(value, *args, **kwargs))

-    def log_survival(self, *args, **kwargs):
+    def log_survival(self, value, *args, **kwargs):
         """
         Evaluate the log survival function at given value.

-        Note:
-            The arguments `args` must include `value`.
-            dist_spec_args are optional.
+        Args:
+            value (Tensor): value to be evaluated.
+            *args (list): the list of positional arguments forwarded to subclasses.
+            **kwargs (dictionary): the dictionary of keyword arguments forwarded to subclasses.
+
+        Note:
+            A distribution can be optionally passed to the function by passing its dist_spec_args through
+            `args` or `kwargs`.
         """
-        return self._call_log_survival(*args, **kwargs)
+        return self._call_log_survival(value, *args, **kwargs)

-    def _calc_log_survival_from_call_survival(self, *args, **kwargs):
+    def _calc_log_survival_from_call_survival(self, value, *args, **kwargs):
         r"""
         Evaluate log survival function from survival function.

         .. math::
             log_survival(x) = \log(survival_function(x))
         """
-        return self.log(self._call_survival(*args, **kwargs))
+        return self.log(self._call_survival(value, *args, **kwargs))

-    def kl_loss(self, *args, **kwargs):
+    def kl_loss(self, dist, *args, **kwargs):
         """
         Evaluate the KL divergence, i.e. KL(a||b).

+        Args:
+            dist (str): type of the distribution.
+            *args (list): the list of positional arguments forwarded to subclasses.
+            **kwargs (dictionary): the dictionary of keyword arguments forwarded to subclasses.
+
         Note:
-            The argument `args` must include the type of the distribution, parameters of distribution b.
-            Parameters for distribution a are optional.
+            dist_spec_args of distribution b must be passed to the function through `args` or `kwargs`.
+            Passing in dist_spec_args of distribution a is optional.
         """
-        return self._kl_loss(*args, **kwargs)
+        return self._kl_loss(dist, *args, **kwargs)

     def mean(self, *args, **kwargs):
         """
         Evaluate the mean.

+        Args:
+            *args (list): the list of positional arguments forwarded to subclasses.
+            **kwargs (dictionary): the dictionary of keyword arguments forwarded to subclasses.
+
         Note:
-            dist_spec_args are optional.
+            A distribution can be optionally passed to the function by passing its *dist_spec_args* through
+            *args* or *kwargs*.
         """
         return self._mean(*args, **kwargs)

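With value now an explicit parameter, call sites read directly off the signature. A hedged usage sketch, assuming the Normal subclass updated later in this PR:

    import mindspore.nn.probability.distribution as msd
    from mindspore import Tensor
    from mindspore.common import dtype as mstype

    n = msd.Normal(3.0, 4.0, dtype=mstype.float32)
    x = Tensor([2.0, 4.0], mstype.float32)

    lp = n.log_prob(x)                                # default mean/sd from the constructor
    lp2 = n.log_prob(x, Tensor(0.0, mstype.float32),  # dist_spec_args forwarded via *args
                     Tensor(1.0, mstype.float32))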
@@ -375,8 +461,13 @@
         """
         Evaluate the mode.

+        Args:
+            *args (list): the list of positional arguments forwarded to subclasses.
+            **kwargs (dictionary): the dictionary of keyword arguments forwarded to subclasses.
+
         Note:
-            dist_spec_args are optional.
+            A distribution can be optionally passed to the function by passing its *dist_spec_args* through
+            *args* or *kwargs*.
         """
         return self._mode(*args, **kwargs)

@@ -384,8 +475,13 @@
         """
         Evaluate the standard deviation.

+        Args:
+            *args (list): the list of positional arguments forwarded to subclasses.
+            **kwargs (dictionary): the dictionary of keyword arguments forwarded to subclasses.
+
         Note:
-            dist_spec_args are optional.
+            A distribution can be optionally passed to the function by passing its *dist_spec_args* through
+            *args* or *kwargs*.
         """
         return self._call_sd(*args, **kwargs)

@@ -393,8 +489,13 @@
         """
         Evaluate the variance.

+        Args:
+            *args (list): the list of positional arguments forwarded to subclasses.
+            **kwargs (dictionary): the dictionary of keyword arguments forwarded to subclasses.
+
         Note:
-            dist_spec_args are optional.
+            A distribution can be optionally passed to the function by passing its *dist_spec_args* through
+            *args* or *kwargs*.
         """
         return self._call_var(*args, **kwargs)

@@ -420,37 +521,52 @@
         """
         Evaluate the entropy.

+        Args:
+            *args (list): the list of positional arguments forwarded to subclasses.
+            **kwargs (dictionary): the dictionary of keyword arguments forwarded to subclasses.
+
         Note:
-            dist_spec_args are optional.
+            A distribution can be optionally passed to the function by passing its *dist_spec_args* through
+            *args* or *kwargs*.
         """
         return self._entropy(*args, **kwargs)

-    def cross_entropy(self, *args, **kwargs):
+    def cross_entropy(self, dist, *args, **kwargs):
         """
         Evaluate the cross_entropy between distribution a and b.

-        Note:
-            The argument `args` must include the type of the distribution, parameters of distribution b.
-            Parameters for distribution a are optional.
+        Args:
+            dist (str): type of the distribution.
+            *args (list): the list of positional arguments forwarded to subclasses.
+            **kwargs (dictionary): the dictionary of keyword arguments forwarded to subclasses.
+
+        Note:
+            dist_spec_args of distribution b must be passed to the function through `args` or `kwargs`.
+            Passing in dist_spec_args of distribution a is optional.
         """
-        return self._call_cross_entropy(*args, **kwargs)
+        return self._call_cross_entropy(dist, *args, **kwargs)

-    def _calc_cross_entropy(self, *args, **kwargs):
+    def _calc_cross_entropy(self, dist, *args, **kwargs):
         r"""
         Evaluate cross_entropy from entropy and kl divergence.

         .. math::
             H(X, Y) = H(X) + KL(X||Y)
         """
-        return self._entropy(*args, **kwargs) + self._kl_loss(*args, **kwargs)
+        return self._entropy(*args, **kwargs) + self._kl_loss(dist, *args, **kwargs)

     def sample(self, *args, **kwargs):
         """
         Sampling function.

+        Args:
+            shape (tuple): shape of the sample.
+            *args (list): the list of positional arguments forwarded to subclasses.
+            **kwargs (dictionary): the dictionary of keyword arguments forwarded to subclasses.
+
         Note:
-            Shape of the sample is default to ().
-            dist_spec_args are optional.
+            A distribution can be optionally passed to the function by passing its *dist_spec_args* through
+            *args* or *kwargs*.
         """
         return self._sample(*args, **kwargs)

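_calc_cross_entropy leans on the identity H(X, Y) = H(X) + KL(X||Y). For two Bernoulli distributions the identity can be checked by hand with a pure-Python snippet (not MindSpore code):

    import math

    p, q = 0.3, 0.6                                        # probs1 of distributions a and b
    entropy_a = -(1 - p) * math.log(1 - p) - p * math.log(p)
    kl_ab = p * math.log(p / q) + (1 - p) * math.log((1 - p) / (1 - q))
    cross = -(1 - p) * math.log(1 - q) - p * math.log(q)   # direct H(a, b)
    assert abs((entropy_a + kl_ab) - cross) < 1e-12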
@@ -18,8 +18,7 @@ from mindspore.ops import operations as P
 from mindspore.ops import composite as C
 from mindspore.common import dtype as mstype
 from .distribution import Distribution
-from ._utils.utils import cast_to_tensor, check_greater_zero, check_type, check_distribution_name,\
-    raise_none_error
+from ._utils.utils import cast_to_tensor, check_greater_zero, check_type, check_distribution_name, set_param_type
 from ._utils.custom_ops import exp_generic, log_generic

 class Exponential(Distribution):
@@ -121,15 +120,19 @@ class Exponential(Distribution):
         valid_dtype = mstype.float_type
         check_type(dtype, valid_dtype, type(self).__name__)
         super(Exponential, self).__init__(seed, dtype, name, param)
-        self.parameter_type = dtype
+        self.parameter_type = set_param_type({'rate': rate}, self.dtype)
         if rate is not None:
             self._rate = cast_to_tensor(rate, self.parameter_type)
             check_greater_zero(self._rate, "rate")
         else:
             self._rate = rate

+        self.default_parameters = [self.rate]
+        self.parameter_names = ['rate']
+
         self.minval = np.finfo(np.float).tiny

         # ops needed for the class
         self.exp = exp_generic
         self.log = log_generic
@@ -156,28 +159,16 @@ class Exponential(Distribution):
     @property
     def rate(self):
         """
-        Return rate of the distribution.
+        Return `rate` of the distribution.
         """
         return self._rate

-    def _check_param(self, rate):
-        """
-        Check availability of distribution specific argument `rate`.
-        """
-        if rate is not None:
-            if self.context_mode == 0:
-                self.checktensor(rate, 'rate')
-            else:
-                rate = self.checktensor(rate, 'rate')
-            return self.cast(rate, self.parameter_type)
-        return self.rate if self.rate is not None else raise_none_error('rate')
-
     def _mean(self, rate=None):
         r"""
         .. math::
             MEAN(EXP) = \frac{1.0}{\lambda}.
         """
-        rate = self._check_param(rate)
+        rate = self._check_param_type(rate)
         return 1.0 / rate

     def _mode(self, rate=None):
@@ -185,7 +176,7 @@ class Exponential(Distribution):
         .. math::
             MODE(EXP) = 0.
         """
-        rate = self._check_param(rate)
+        rate = self._check_param_type(rate)
         return self.fill(self.dtype, self.shape(rate), 0.)

     def _sd(self, rate=None):
@@ -193,7 +184,7 @@ class Exponential(Distribution):
         .. math::
             sd(EXP) = \frac{1.0}{\lambda}.
         """
-        rate = self._check_param(rate)
+        rate = self._check_param_type(rate)
         return 1.0 / rate

     def _entropy(self, rate=None):
@@ -201,7 +192,7 @@ class Exponential(Distribution):
         .. math::
             H(Exp) = 1 - \log(\lambda).
         """
-        rate = self._check_param(rate)
+        rate = self._check_param_type(rate)
         return 1.0 - self.log(rate)

     def _cross_entropy(self, dist, rate_b, rate=None):
@@ -234,7 +225,7 @@ class Exponential(Distribution):
         """
         value = self._check_value(value, "value")
         value = self.cast(value, self.dtype)
-        rate = self._check_param(rate)
+        rate = self._check_param_type(rate)
         prob = self.log(rate) - rate * value
         zeros = self.fill(self.dtypeop(prob), self.shape(prob), 0.0)
         neginf = self.fill(self.dtypeop(prob), self.shape(prob), -np.inf)
@@ -257,7 +248,7 @@ class Exponential(Distribution):
         """
         value = self._check_value(value, 'value')
         value = self.cast(value, self.dtype)
-        rate = self._check_param(rate)
+        rate = self._check_param_type(rate)
         cdf = 1.0 - self.exp(-1. * rate * value)
         zeros = self.fill(self.dtypeop(cdf), self.shape(cdf), 0.0)
         comp = self.less(value, zeros)
@@ -279,7 +270,7 @@ class Exponential(Distribution):
         """
         value = self._check_value(value, 'value')
         value = self.cast(value, self.dtype)
-        rate = self._check_param(rate)
+        rate = self._check_param_type(rate)
         sf = -1. * rate * value
         zeros = self.fill(self.dtypeop(sf), self.shape(sf), 0.0)
         comp = self.less(value, zeros)
@@ -297,7 +288,7 @@ class Exponential(Distribution):
         check_distribution_name(dist, 'Exponential')
         rate_b = self._check_value(rate_b, 'rate_b')
         rate_b = self.cast(rate_b, self.parameter_type)
-        rate_a = self._check_param(rate)
+        rate_a = self._check_param_type(rate)
         return self.log(rate_a) - self.log(rate_b) + rate_b / rate_a - 1.0

     def _sample(self, shape=(), rate=None):
@@ -312,7 +303,7 @@ class Exponential(Distribution):
             Tensor, shape is shape + batch_shape.
         """
         shape = self.checktuple(shape, 'shape')
-        rate = self._check_param(rate)
+        rate = self._check_param_type(rate)
         origin_shape = shape + self.shape(rate)
         if origin_shape == ():
             sample_shape = (1,)
@@ -19,7 +19,7 @@ from mindspore.ops import composite as C
 from mindspore.common import dtype as mstype
 from .distribution import Distribution
 from ._utils.utils import cast_to_tensor, check_prob, check_type, check_distribution_name,\
-    raise_none_error
+    set_param_type
 from ._utils.custom_ops import exp_generic, log_generic

@@ -123,13 +123,16 @@ class Geometric(Distribution):
         valid_dtype = mstype.int_type + mstype.uint_type + mstype.float_type
         check_type(dtype, valid_dtype, type(self).__name__)
         super(Geometric, self).__init__(seed, dtype, name, param)
-        self.parameter_type = mstype.float32
+        self.parameter_type = set_param_type({'probs1': probs}, mstype.float32)
         if probs is not None:
             self._probs = cast_to_tensor(probs, self.parameter_type)
             check_prob(self._probs)
         else:
             self._probs = probs

+        self.default_parameters = [self.probs]
+        self.parameter_names = ['probs1']
+
         self.minval = np.finfo(np.float).tiny

         # ops needed for the class
@@ -164,24 +167,12 @@ class Geometric(Distribution):
         """
         return self._probs

-    def _check_param(self, probs1):
-        """
-        Check availability of distribution specific args probs1.
-        """
-        if probs1 is not None:
-            if self.context_mode == 0:
-                self.checktensor(probs1, 'probs1')
-            else:
-                probs1 = self.checktensor(probs1, 'probs1')
-            return self.cast(probs1, self.parameter_type)
-        return self.probs if self.probs is not None else raise_none_error('probs1')
-
     def _mean(self, probs1=None):
         r"""
         .. math::
             MEAN(Geo) = \frac{1 - probs1}{probs1}
         """
-        probs1 = self._check_param(probs1)
+        probs1 = self._check_param_type(probs1)
         return (1. - probs1) / probs1

     def _mode(self, probs1=None):
@@ -189,7 +180,7 @@ class Geometric(Distribution):
         .. math::
             MODE(Geo) = 0
         """
-        probs1 = self._check_param(probs1)
+        probs1 = self._check_param_type(probs1)
         return self.fill(self.dtypeop(probs1), self.shape(probs1), 0.)

     def _var(self, probs1=None):
@@ -197,7 +188,7 @@ class Geometric(Distribution):
         .. math::
             VAR(Geo) = \frac{1 - probs1}{probs1 ^ {2}}
         """
-        probs1 = self._check_param(probs1)
+        probs1 = self._check_param_type(probs1)
         return (1.0 - probs1) / self.sq(probs1)

     def _entropy(self, probs1=None):
@@ -205,7 +196,7 @@ class Geometric(Distribution):
         .. math::
             H(Geo) = \frac{-probs0 \log(probs0) - probs1 \log(probs1)}{probs1}
         """
-        probs1 = self._check_param(probs1)
+        probs1 = self._check_param_type(probs1)
         probs0 = 1.0 - probs1
         return (-probs0 * self.log(probs0) - probs1 * self.log(probs1)) / probs1

@@ -236,7 +227,7 @@ class Geometric(Distribution):
         value = self._check_value(value, 'value')
         value = self.cast(value, mstype.float32)
         value = self.floor(value)
-        probs1 = self._check_param(probs1)
+        probs1 = self._check_param_type(probs1)
         pmf = self.exp(self.log(1.0 - probs1) * value + self.log(probs1))
         zeros = self.fill(self.dtypeop(probs1), self.shape(pmf), 0.0)
         comp = self.less(value, zeros)
@@ -258,7 +249,7 @@ class Geometric(Distribution):
         value = self._check_value(value, 'value')
         value = self.cast(value, mstype.float32)
         value = self.floor(value)
-        probs1 = self._check_param(probs1)
+        probs1 = self._check_param_type(probs1)
         probs0 = 1.0 - probs1
         cdf = 1.0 - self.pow(probs0, value + 1.0)
         zeros = self.fill(self.dtypeop(probs1), self.shape(cdf), 0.0)
@@ -280,7 +271,7 @@ class Geometric(Distribution):
         check_distribution_name(dist, 'Geometric')
         probs1_b = self._check_value(probs1_b, 'probs1_b')
         probs1_b = self.cast(probs1_b, self.parameter_type)
-        probs1_a = self._check_param(probs1)
+        probs1_a = self._check_param_type(probs1)
         probs0_a = 1.0 - probs1_a
         probs0_b = 1.0 - probs1_b
         return self.log(probs1_a / probs1_b) + (probs0_a / probs1_a) * self.log(probs0_a / probs0_b)
@@ -297,7 +288,7 @@ class Geometric(Distribution):
             Tensor, shape is shape + batch_shape.
         """
         shape = self.checktuple(shape, 'shape')
-        probs1 = self._check_param(probs1)
+        probs1 = self._check_param_type(probs1)
         origin_shape = shape + self.shape(probs1)
         if origin_shape == ():
             sample_shape = (1,)
@@ -19,7 +19,7 @@ from mindspore.ops import composite as C
 from mindspore.common import dtype as mstype
 from .distribution import Distribution
 from ._utils.utils import cast_to_tensor, check_greater_zero, check_type, check_distribution_name,\
-    raise_none_error, common_dtype
+    set_param_type
 from ._utils.custom_ops import exp_generic, expm1_generic, log_generic, erf_generic

 class Normal(Distribution):
@@ -127,14 +127,17 @@ class Normal(Distribution):
         valid_dtype = mstype.float_type
         check_type(dtype, valid_dtype, type(self).__name__)
         super(Normal, self).__init__(seed, dtype, name, param)
-        self.parameter_type = common_dtype(mean, 'mean', sd, 'sd', self.dtype)
+        self.parameter_type = set_param_type({'mean': mean, 'sd': sd}, self.dtype)
         if mean is not None and sd is not None:
             self._mean_value = cast_to_tensor(mean, self.parameter_type)
             self._sd_value = cast_to_tensor(sd, self.parameter_type)
             check_greater_zero(self._sd_value, "Standard deviation")
         else:
-            self._mean_value = mean
-            self._sd_value = sd
+            self._mean_value = mean if mean is None else cast_to_tensor(mean, self.parameter_type)
+            self._sd_value = sd if sd is None else cast_to_tensor(sd, self.parameter_type)

+        self.default_parameters = [self._mean_value, self._sd_value]
+        self.parameter_names = ['mean', 'sd']
+
         #ops needed for the class
         self.exp = exp_generic
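For Normal (and Uniform below), set_param_type also subsumes the old common_dtype guard: mixed-precision parameters are rejected up front rather than surfacing later as op errors. A hedged sketch of the failure mode, assuming the constructor behavior shown in this diff:

    import mindspore.nn.probability.distribution as msd
    from mindspore import Tensor
    from mindspore.common import dtype as mstype

    ok = msd.Normal(Tensor(0.0, mstype.float16), Tensor(1.0, mstype.float16))   # parameter_type float16
    try:
        bad = msd.Normal(Tensor(0.0, mstype.float16), Tensor(1.0, mstype.float32))
    except TypeError as err:
        print(err)  # sd should have the same dtype as other arguments.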
@@ -159,51 +162,25 @@ class Normal(Distribution):
         str_info = f'batch_shape = {self._broadcast_shape}'
         return str_info

-    def _check_param(self, mean, sd):
-        """
-        Check availability of distribution specific args `mean` and `sd`.
-        """
-        if mean is not None:
-            if self.context_mode == 0:
-                self.checktensor(mean, 'mean')
-            else:
-                mean = self.checktensor(mean, 'mean')
-        else:
-            mean = self._mean_value if self._mean_value is not None else raise_none_error('mean')
-        if sd is not None:
-            if self.context_mode == 0:
-                self.checktensor(sd, 'sd')
-            else:
-                sd = self.checktensor(sd, 'sd')
-        else:
-            sd = self._sd_value if self._sd_value is not None else raise_none_error('sd')
-        batch_shape = self.shape(mean + sd)
-        mean = mean * self.fill(self.dtypeop(mean), batch_shape, 1.0)
-        sd = sd * self.fill(self.dtypeop(sd), batch_shape, 1.0)
-        self.sametypeshape(mean, sd)
-        mean = self.cast(mean, self.parameter_type)
-        sd = self.cast(sd, self.parameter_type)
-        return mean, sd
-
     def _mean(self, mean=None, sd=None):
         """
         The mean of the distribution.
         """
-        mean, sd = self._check_param(mean, sd)
+        mean, sd = self._check_param_type(mean, sd)
         return mean

     def _mode(self, mean=None, sd=None):
         """
         The mode of the distribution.
         """
-        mean, sd = self._check_param(mean, sd)
+        mean, sd = self._check_param_type(mean, sd)
         return mean

     def _sd(self, mean=None, sd=None):
         """
         The standard deviation of the distribution.
         """
-        mean, sd = self._check_param(mean, sd)
+        mean, sd = self._check_param_type(mean, sd)
         return sd

     def _entropy(self, mean=None, sd=None):
@@ -213,7 +190,7 @@ class Normal(Distribution):
         .. math::
             H(X) = \log(\sqrt{2 \pi e \sigma^2})
         """
-        mean, sd = self._check_param(mean, sd)
+        mean, sd = self._check_param_type(mean, sd)
         return self.log(self.sqrt(self.const(np.e * 2. * np.pi))) + self.log(sd)

     def _cross_entropy(self, dist, mean_b, sd_b, mean=None, sd=None):
@@ -244,7 +221,7 @@ class Normal(Distribution):
         """
         value = self._check_value(value, 'value')
         value = self.cast(value, self.dtype)
-        mean, sd = self._check_param(mean, sd)
+        mean, sd = self._check_param_type(mean, sd)
         unnormalized_log_prob = -1. * (self.sq(value - mean)) / (2. * self.sq(sd))
         neg_normalization = -1. * self.log(self.const(2. * np.pi)) / 2. - self.log(sd)
         return unnormalized_log_prob + neg_normalization
@@ -263,7 +240,7 @@ class Normal(Distribution):
         """
         value = self._check_value(value, 'value')
         value = self.cast(value, self.dtype)
-        mean, sd = self._check_param(mean, sd)
+        mean, sd = self._check_param_type(mean, sd)
         sqrt2 = self.sqrt(self.const(2.0))
         adjusted = (value - mean) / (sd * sqrt2)
         return 0.5 * (1.0 + self.erf(adjusted))
@@ -288,7 +265,7 @@ class Normal(Distribution):
         sd_b = self._check_value(sd_b, 'sd_b')
         mean_b = self.cast(mean_b, self.parameter_type)
         sd_b = self.cast(sd_b, self.parameter_type)
-        mean_a, sd_a = self._check_param(mean, sd)
+        mean_a, sd_a = self._check_param_type(mean, sd)
         diff_log_scale = self.log(sd_a) - self.log(sd_b)
         squared_diff = self.sq(mean_a / sd_b - mean_b / sd_b)
         return 0.5 * squared_diff + 0.5 * self.expm1(2 * diff_log_scale) - diff_log_scale
@@ -306,7 +283,7 @@ class Normal(Distribution):
             Tensor, shape is shape + batch_shape.
         """
         shape = self.checktuple(shape, 'shape')
-        mean, sd = self._check_param(mean, sd)
+        mean, sd = self._check_param_type(mean, sd)
         batch_shape = self.shape(mean + sd)
         origin_shape = shape + batch_shape
         if origin_shape == ():
@@ -18,7 +18,7 @@ from mindspore.ops import composite as C
 from mindspore.common import dtype as mstype
 from .distribution import Distribution
 from ._utils.utils import cast_to_tensor, check_greater, check_type, check_distribution_name,\
-    raise_none_error, common_dtype
+    set_param_type
 from ._utils.custom_ops import exp_generic, log_generic

 class Uniform(Distribution):
@@ -126,14 +126,17 @@ class Uniform(Distribution):
         valid_dtype = mstype.float_type
         check_type(dtype, valid_dtype, type(self).__name__)
         super(Uniform, self).__init__(seed, dtype, name, param)
-        self.parameter_type = common_dtype(low, 'low', high, 'high', self.dtype)
+        self.parameter_type = set_param_type({'low': low, 'high': high}, self.dtype)
         if low is not None and high is not None:
-            self._low = cast_to_tensor(low, dtype)
-            self._high = cast_to_tensor(high, dtype)
+            self._low = cast_to_tensor(low, self.parameter_type)
+            self._high = cast_to_tensor(high, self.parameter_type)
             check_greater(self.low, self.high, "low value", "high value")
         else:
-            self._low = low
-            self._high = high
+            self._low = low if low is None else cast_to_tensor(low, self.parameter_type)
+            self._high = high if high is None else cast_to_tensor(high, self.parameter_type)

+        self.default_parameters = [self.low, self.high]
+        self.parameter_names = ['low', 'high']
+
         # ops needed for the class
         self.exp = exp_generic
@@ -162,32 +165,6 @@ class Uniform(Distribution):
         str_info = f'batch_shape = {self._broadcast_shape}'
         return str_info

-    def _check_param(self, low, high):
-        """
-        Check availablity of distribution specific args `low` and `high`.
-        """
-        if low is not None:
-            if self.context_mode == 0:
-                self.checktensor(low, 'low')
-            else:
-                low = self.checktensor(low, 'low')
-        else:
-            low = self.low if self.low is not None else raise_none_error('low')
-        if high is not None:
-            if self.context_mode == 0:
-                self.checktensor(high, 'high')
-            else:
-                high = self.checktensor(high, 'high')
-        else:
-            high = self.high if self.high is not None else raise_none_error('high')
-        batch_shape = self.shape(high - low)
-        high = high * self.fill(self.dtypeop(high), batch_shape, 1.0)
-        low = low * self.fill(self.dtypeop(low), batch_shape, 1.0)
-        self.sametypeshape(high, low)
-        low = self.cast(low, self.parameter_type)
-        high = self.cast(high, self.parameter_type)
-        return low, high
-
     @property
     def low(self):
         """
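The deleted _check_param is superseded by a shared _check_param_type that reads the new default_parameters and parameter_names attributes; that helper lives outside this diff. A hypothetical sketch of what such a consolidated base-class method could look like, inferred from the removed logic and the call sites (not the actual implementation):

    def _check_param_type(self, *params):
        """Hypothetical: validate each given arg or fall back to its stored
        default, broadcast all args to a common batch shape, then cast them
        to self.parameter_type."""
        args = []
        for value, default, name in zip(params, self.default_parameters, self.parameter_names):
            if value is not None:
                value = self.checktensor(value, name)
            else:
                value = default if default is not None else raise_none_error(name)
            args.append(value)
        # Broadcast every parameter to the common batch shape, as the old
        # per-class check did with fill().
        broadcast = args[0]
        for arg in args[1:]:
            broadcast = broadcast + arg
        batch_shape = self.shape(broadcast)
        args = [arg * self.fill(self.dtypeop(arg), batch_shape, 1.0) for arg in args]
        return tuple(self.cast(arg, self.parameter_type) for arg in args)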
@@ -209,7 +186,7 @@ class Uniform(Distribution):
         .. math::
             range(U) = high -low
         """
-        low, high = self._check_param(low, high)
+        low, high = self._check_param_type(low, high)
         return high - low

     def _mean(self, low=None, high=None):
@@ -217,7 +194,7 @@ class Uniform(Distribution):
         .. math::
             MEAN(U) = \frac{low + high}{2}.
         """
-        low, high = self._check_param(low, high)
+        low, high = self._check_param_type(low, high)
         return (low + high) / 2.

     def _var(self, low=None, high=None):
@@ -225,7 +202,7 @@ class Uniform(Distribution):
         .. math::
             VAR(U) = \frac{(high -low) ^ 2}{12}.
         """
-        low, high = self._check_param(low, high)
+        low, high = self._check_param_type(low, high)
         return self.sq(high - low) / 12.0

     def _entropy(self, low=None, high=None):
@@ -233,7 +210,7 @@ class Uniform(Distribution):
         .. math::
             H(U) = \log(high - low).
         """
-        low, high = self._check_param(low, high)
+        low, high = self._check_param_type(low, high)
         return self.log(high - low)

     def _cross_entropy(self, dist, low_b, high_b, low=None, high=None):
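As a quick sanity check of these closed forms, for low = 0 and high = 2:

    import numpy as np

    low, high = 0.0, 2.0
    mean = (low + high) / 2            # 1.0
    var = (high - low) ** 2 / 12       # 0.3333...
    entropy = np.log(high - low)       # log(2) ~= 0.6931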
@@ -266,7 +243,7 @@ class Uniform(Distribution):
         """
         value = self._check_value(value, 'value')
         value = self.cast(value, self.dtype)
-        low, high = self._check_param(low, high)
+        low, high = self._check_param_type(low, high)
         neg_ones = self.fill(self.dtype, self.shape(value), -1.0)
         prob = self.exp(neg_ones * self.log(high - low))
         broadcast_shape = self.shape(prob)
@@ -292,7 +269,7 @@ class Uniform(Distribution):
         low_b = self.cast(low_b, self.parameter_type)
         high_b = self._check_value(high_b, 'high_b')
         high_b = self.cast(high_b, self.parameter_type)
-        low_a, high_a = self._check_param(low, high)
+        low_a, high_a = self._check_param_type(low, high)
         kl = self.log(high_b - low_b) - self.log(high_a - low_a)
         comp = self.logicaland(self.lessequal(low_b, low_a), self.lessequal(high_a, high_b))
         return self.select(comp, kl, self.log(self.zeroslike(kl)))
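Numerically, when the support of the first distribution sits inside the second's (the comp condition above), the KL divergence reduces to a log-ratio of interval lengths. For example:

    import numpy as np

    low_a, high_a = 0.0, 1.0           # U(0, 1)
    low_b, high_b = 0.0, 2.0           # U(0, 2); [0, 1] is contained in [0, 2]
    kl = np.log(high_b - low_b) - np.log(high_a - low_a)   # log(2) ~= 0.6931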
@@ -313,7 +290,7 @@ class Uniform(Distribution):
         """
         value = self._check_value(value, 'value')
         value = self.cast(value, self.dtype)
-        low, high = self._check_param(low, high)
+        low, high = self._check_param_type(low, high)
         prob = (value - low) / (high - low)
         broadcast_shape = self.shape(prob)
         zeros = self.fill(self.dtypeop(prob), broadcast_shape, 0.0)
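The CDF computed here is the usual piecewise-linear one; the zeros (and, presumably, a matching ones tensor just past this hunk) pin prob to [0, 1] outside the support. For U(0, 2):

    low, high = 0.0, 2.0
    value = 0.5
    cdf = (value - low) / (high - low)   # 0.25; 0.0 below low, 1.0 above high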
@@ -336,7 +313,7 @@ class Uniform(Distribution):
             Tensor, shape is shape + batch_shape.
         """
         shape = self.checktuple(shape, 'shape')
-        low, high = self._check_param(low, high)
+        low, high = self._check_param_type(low, high)
         broadcast_shape = self.shape(low + high)
         origin_shape = shape + broadcast_shape
         if origin_shape == ():
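A minimal end-to-end sketch of the updated Uniform, assuming the public msd.Uniform API with the constructor arguments used in this file:

    import mindspore.nn.probability.distribution as msd
    from mindspore import Tensor, dtype as mstype

    u = msd.Uniform(0.0, 1.0, seed=0, dtype=mstype.float32)
    samples = u.sample((2, 3))          # Tensor of shape (2, 3)
    m = u.mean()                        # 0.5, via the stored default parameters
    # Per-call parameters are routed through _check_param_type:
    v = u.var(Tensor(0.0, mstype.float32), Tensor(2.0, mstype.float32))  # 1/3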
@@ -0,0 +1,182 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""
+Test util functions used in distribution classes.
+"""
+import numpy as np
+import pytest
+
+from mindspore.nn.cell import Cell
+from mindspore import context
+from mindspore import dtype
+from mindspore import Tensor
+from mindspore.common.parameter import Parameter
+from mindspore.nn.probability.distribution._utils.utils import set_param_type, \
+    cast_to_tensor, CheckTuple, CheckTensor
+
+def test_set_param_type():
+    """
+    Test set_param_type function.
+    """
+    tensor_fp16 = Tensor(0.1, dtype=dtype.float16)
+    tensor_fp32 = Tensor(0.1, dtype=dtype.float32)
+    tensor_fp64 = Tensor(0.1, dtype=dtype.float64)
+    tensor_int32 = Tensor(0.1, dtype=dtype.int32)
+    array_fp32 = np.array(1.0).astype(np.float32)
+    array_fp64 = np.array(1.0).astype(np.float64)
+    array_int32 = np.array(1.0).astype(np.int32)
+
+    dict1 = {'a': tensor_fp32, 'b': 1.0, 'c': tensor_fp32}
+    dict2 = {'a': tensor_fp32, 'b': 1.0, 'c': tensor_fp64}
+    dict3 = {'a': tensor_int32, 'b': 1.0, 'c': tensor_int32}
+    dict4 = {'a': array_fp32, 'b': 1.0, 'c': tensor_fp32}
+    dict5 = {'a': array_fp32, 'b': 1.0, 'c': array_fp64}
+    dict6 = {'a': array_fp32, 'b': 1.0, 'c': array_int32}
+    dict7 = {'a': 1.0}
+    dict8 = {'a': 1.0, 'b': 1.0, 'c': 1.0}
+    dict9 = {'a': tensor_fp16, 'b': tensor_fp16, 'c': tensor_fp16}
+    dict10 = {'a': tensor_fp64, 'b': tensor_fp64, 'c': tensor_fp64}
+    dict11 = {'a': array_fp64, 'b': array_fp64, 'c': tensor_fp64}
+
+    ans1 = set_param_type(dict1, dtype.float16)
+    assert ans1 == dtype.float32
+
+    with pytest.raises(TypeError):
+        set_param_type(dict2, dtype.float32)
+
+    ans3 = set_param_type(dict3, dtype.float16)
+    assert ans3 == dtype.float32
+    ans4 = set_param_type(dict4, dtype.float16)
+    assert ans4 == dtype.float32
+
+    with pytest.raises(TypeError):
+        set_param_type(dict5, dtype.float32)
+    with pytest.raises(TypeError):
+        set_param_type(dict6, dtype.float32)
+
+    ans7 = set_param_type(dict7, dtype.float32)
+    assert ans7 == dtype.float32
+    ans8 = set_param_type(dict8, dtype.float32)
+    assert ans8 == dtype.float32
+    ans9 = set_param_type(dict9, dtype.float32)
+    assert ans9 == dtype.float16
+    ans10 = set_param_type(dict10, dtype.float32)
+    assert ans10 == dtype.float32
+    ans11 = set_param_type(dict11, dtype.float32)
+    assert ans11 == dtype.float32
+
+def test_cast_to_tensor():
+    """
+    Test cast_to_tensor.
+    """
+    with pytest.raises(ValueError):
+        cast_to_tensor(None, dtype.float32)
+    with pytest.raises(TypeError):
+        cast_to_tensor(True, dtype.float32)
+    with pytest.raises(TypeError):
+        cast_to_tensor({'a': 1, 'b': 2}, dtype.float32)
+    with pytest.raises(TypeError):
+        cast_to_tensor('tensor', dtype.float32)
+
+    ans1 = cast_to_tensor(Parameter(Tensor(0.1, dtype=dtype.float32), 'param'))
+    assert isinstance(ans1, Parameter)
+    ans2 = cast_to_tensor(np.array(1.0).astype(np.float32))
+    assert isinstance(ans2, Tensor)
+    ans3 = cast_to_tensor([1.0, 2.0])
+    assert isinstance(ans3, Tensor)
+    ans4 = cast_to_tensor(Tensor(0.1, dtype=dtype.float32), dtype.float32)
+    assert isinstance(ans4, Tensor)
+    ans5 = cast_to_tensor(0.1, dtype.float32)
+    assert isinstance(ans5, Tensor)
+    ans6 = cast_to_tensor(1, dtype.float32)
+    assert isinstance(ans6, Tensor)
+
+class Net(Cell):
+    """
+    Test class: CheckTuple.
+    """
+    def __init__(self, value):
+        super(Net, self).__init__()
+        self.checktuple = CheckTuple()
+        self.value = value
+
+    def construct(self, value=None):
+        if value is None:
+            return self.checktuple(self.value, 'input')
+        return self.checktuple(value, 'input')
+
+def test_check_tuple():
+    """
+    Test CheckTuple.
+    """
+    net1 = Net((1, 2, 3))
+    ans1 = net1()
+    assert isinstance(ans1, tuple)
+
+    with pytest.raises(TypeError):
+        net2 = Net('tuple')
+        net2()
+
+    context.set_context(mode=context.GRAPH_MODE)
+    net3 = Net((1, 2, 3))
+    ans3 = net3()
+    assert isinstance(ans3, tuple)
+
+    with pytest.raises(TypeError):
+        net4 = Net('tuple')
+        net4()
+
+class Net1(Cell):
+    """
+    Test class: CheckTensor.
+    """
+    def __init__(self, value):
+        super(Net1, self).__init__()
+        self.checktensor = CheckTensor()
+        self.value = value
+        self.context = context.get_context('mode')
+
+    def construct(self, value=None):
+        value = self.value if value is None else value
+        if self.context == 0:
+            self.checktensor(value, 'input')
+            return value
+        return self.checktensor(value, 'input')
+
+def test_check_tensor():
+    """
+    Test CheckTensor.
+    """
+    value = Tensor(0.1, dtype=dtype.float32)
+    net1 = Net1(value)
+    ans1 = net1()
+    assert isinstance(ans1, Tensor)
+    ans1 = net1(value)
+    assert isinstance(ans1, Tensor)
+
+    with pytest.raises(TypeError):
+        net2 = Net1('tuple')
+        net2()
+
+    context.set_context(mode=context.GRAPH_MODE)
+    net3 = Net1(value)
+    ans3 = net3()
+    assert isinstance(ans3, Tensor)
+    ans3 = net3(value)
+    assert isinstance(ans3, Tensor)
+
+    with pytest.raises(TypeError):
+        net4 = Net1('tuple')
+        net4()
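These tests exercise both PyNative and graph mode, since test_check_tuple and test_check_tensor switch the context to GRAPH_MODE partway through. A hedged invocation (the test file's name and location are assumed, as they are not shown on this page):

    pytest -q tests/ut/python/nn/probability/distribution/test_utils.py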