diff --git a/mindspore/nn/distribution/__init__.py b/mindspore/nn/distribution/__init__.py index 55b4b03ef7..2f36c51bcc 100644 --- a/mindspore/nn/distribution/__init__.py +++ b/mindspore/nn/distribution/__init__.py @@ -21,7 +21,13 @@ The high-level components(Distributions) used to construct the probabilistic net from .distribution import Distribution from .normal import Normal from .bernoulli import Bernoulli +from .exponential import Exponential +from .uniform import Uniform +from .geometric import Geometric __all__ = ['Distribution', 'Normal', - 'Bernoulli',] + 'Bernoulli', + 'Exponential', + 'Uniform', + 'Geometric',] diff --git a/mindspore/nn/distribution/_utils/__init__.py b/mindspore/nn/distribution/_utils/__init__.py index 816485643a..f9cd3d3c2e 100644 --- a/mindspore/nn/distribution/_utils/__init__.py +++ b/mindspore/nn/distribution/_utils/__init__.py @@ -17,8 +17,11 @@ Distribution operation utility functions. """ from .utils import * -__all__ = ['check_scalar', 'convert_to_batch', 'cast_to_tensor', - 'calc_batch_size', 'check_greater', +__all__ = ['convert_to_batch', + 'cast_to_tensor', + 'check_greater', 'check_greater_equal_zero', + 'check_greater_zero', 'calc_broadcast_shape_from_param', - 'check_scalar_from_param', 'check_prob'] + 'check_scalar_from_param', + 'check_prob'] diff --git a/mindspore/nn/distribution/_utils/utils.py b/mindspore/nn/distribution/_utils/utils.py index c790a66f25..e37f9d632c 100644 --- a/mindspore/nn/distribution/_utils/utils.py +++ b/mindspore/nn/distribution/_utils/utils.py @@ -20,17 +20,10 @@ from ....common.tensor import Tensor from ....common.parameter import Parameter from ....common import dtype as mstype - -def check_scalar(value): - """ - Check if input value is a scalar. - """ - return np.isscalar(value) - - def cast_to_tensor(t, dtype=mstype.float32): """ Cast an user input value into a Tensor of dtype. + If the input t is of type Parameter, t is directly returned as a Parameter. 
Args: t (int, float, list, numpy.ndarray, Tensor, Parameter): object to be cast to Tensor. @@ -54,22 +47,10 @@ def cast_to_tensor(t, dtype=mstype.float32): return t if isinstance(t, (list, np.ndarray)): return Tensor(t, dtype=dtype) - if check_scalar(t): + if np.isscalar(t): return Tensor([t], dtype=dtype) raise RuntimeError("Input type is not supported.") -def calc_batch_size(batch_shape): - """ - Calculate the size of a given batch_shape. - - Args: - batch_shape (tuple): batch shape to be calculated. - - Returns: - int. - """ - return int(np.prod(batch_shape)) - def convert_to_batch(t, batch_shape, dtype): """ Convert a Tensor to a given batch shape. @@ -87,15 +68,9 @@ def convert_to_batch(t, batch_shape, dtype): """ if isinstance(t, Parameter): return t - t = cast_to_tensor(t, dtype) - if t.shape != batch_shape: - mul = calc_batch_size(batch_shape) // t.size() - if (calc_batch_size(batch_shape) % t.size()) != 0: - raise RuntimeError("Cannot cast the tensor to the given batch shape.") - temp = list(t.asnumpy()) * mul - temp = np.reshape(temp, batch_shape) - return Tensor(temp, dtype) - return t + if isinstance(t, Tensor): + return Tensor(np.broadcast_to(t.asnumpy(), batch_shape), dtype=dtype) + return Tensor(np.broadcast_to(t, batch_shape), dtype=dtype) def check_scalar_from_param(params): """ @@ -107,9 +82,11 @@ def check_scalar_from_param(params): Notes: String parameters are excluded. """ for value in params.values(): + if isinstance(value, Parameter): + return False if isinstance(value, (str, type(params['dtype']))): continue - elif check_scalar(value): + elif np.isscalar(value): continue else: return False @@ -157,6 +134,26 @@ def check_greater_equal_zero(value, name): value = value.default_input comp = np.less(value.asnumpy(), np.zeros(value.shape)) if comp.any(): + raise ValueError(f'{name} should be greater than or equal to zero.') + +def check_greater_zero(value, name): + """ + Check if the given Tensor is strictly greater than zero. 
+ + Args: + value (Tensor, Parameter): value to be checked. + name (str) : name of the value. + + Raises: + ValueError: if the input value is less than or equal to zero. + + """ + if isinstance(value, Parameter): + if isinstance(value.default_input, MetaTensor): + return + value = value.default_input + comp = np.less(np.zeros(value.shape), value.asnumpy()) + if not comp.all(): raise ValueError(f'{name} should be greater than zero.') def check_greater(a, b, name_a, name_b): @@ -164,14 +161,16 @@ def check_greater(a, b, name_a, name_b): Check if Tensor b is strictly greater than Tensor a. Args: - a (Tensor): input tensor a. - b (Tensor): input tensor b. + a (Tensor, Parameter): input tensor a. + b (Tensor, Parameter): input tensor b. name_a (str): name of Tensor_a. name_b (str): name of Tensor_b. Raises: ValueError: if b is less than or equal to a """ + if isinstance(a, Parameter) or isinstance(b, Parameter): + return comp = np.less(a.asnumpy(), b.asnumpy()) if not comp.all(): raise ValueError(f'{name_a} should be less than {name_b}') diff --git a/mindspore/nn/distribution/bernoulli.py b/mindspore/nn/distribution/bernoulli.py index 9aa20d668f..f047326798 100644 --- a/mindspore/nn/distribution/bernoulli.py +++ b/mindspore/nn/distribution/bernoulli.py @@ -14,29 +14,75 @@ # ============================================================================ """Bernoulli Distribution""" from mindspore.ops import operations as P -from mindspore.ops import composite as C from .distribution import Distribution from ._utils.utils import cast_to_tensor, check_prob from ...common import dtype as mstype class Bernoulli(Distribution): """ - Example class: Bernoulli Distribution. + Bernoulli Distribution. Args: - probs (int, float, list, numpy.ndarray, Tensor, Parameter): probability of 1 as outcome. + probs (float, list, numpy.ndarray, Tensor, Parameter): probability of 1 as outcome. seed (int): seed to use in sampling. Default: 0. dtype (mindspore.dtype): type of the distribution. 
Default: mstype.int32. name (str): name of the distribution. Default: Bernoulli. Note: probs should be proper probabilities (0 <= p <= 1). + Dist_spec_args is probs. Examples: - >>> # To initialize a Bernoulli distribution which has equal probability of getting 1 and 0 - >>> b = nn.Bernoulli(0.5, dtype = mstype.int32) - >>> # The following create two independent Bernoulli distributions - >>> b = nn.Bernoulli([0.7, 0.2], dtype = mstype.int32) + >>> # To initialize a Bernoulli distribution of prob 0.5 + >>> n = nn.Bernoulli(0.5, dtype=mstype.int32) + >>> + >>> # The following creates two independent Bernoulli distributions + >>> n = nn.Bernoulli([0.5, 0.5], dtype=mstype.int32) + >>> + >>> # A Bernoulli distribution can be initialized without arguments + >>> # In this case, probs must be passed in through construct. + >>> n = nn.Bernoulli(dtype=mstype.int32) + >>> + >>> # To use Bernoulli distribution in a network + >>> class net(Cell): + >>> def __init__(self): + >>> super(net, self).__init__() + >>> self.b1 = nn.Bernoulli(0.5, dtype=mstype.int32) + >>> self.b2 = nn.Bernoulli(dtype=mstype.int32) + >>> + >>> # All the following calls in construct are valid + >>> def construct(self, value, probs_b, probs_a): + >>> + >>> # Similar calls can be made to other probability functions + >>> # by replacing 'prob' with the name of the function + >>> ans = self.b1('prob', value) + >>> # Evaluate with respect to distribution b + >>> ans = self.b1('prob', value, probs_b) + >>> + >>> # probs must be passed in through construct + >>> ans = self.b2('prob', value, probs_a) + >>> + >>> # Functions 'sd', 'var', 'entropy' have the same usage as 'mean' + >>> # Will return [0.5] + >>> ans = self.b1('mean') + >>> # Will return probs_b + >>> ans = self.b1('mean', probs_b) + >>> + >>> # probs must be passed in through construct + >>> ans = self.b2('mean', probs_a) + >>> + >>> # Usage of 'kl_loss' and 'cross_entropy' are similar + >>> ans = self.b1('kl_loss', 'Bernoulli', probs_b) + >>> 
ans = self.b1('kl_loss', 'Bernoulli', probs_b, probs_a) + >>> + >>> # Additional probs_a must be passed in through construct + >>> ans = self.b2('kl_loss', 'Bernoulli', probs_b, probs_a) + >>> + >>> # Sample Usage + >>> ans = self.b1('sample') + >>> ans = self.b1('sample', (2,3)) + >>> ans = self.b1('sample', (2,3), probs_b) + >>> ans = self.b2('sample', (2,3), probs_a) """ def __init__(self, @@ -50,29 +96,34 @@ class Bernoulli(Distribution): param = dict(locals()) super(Bernoulli, self).__init__(dtype, name, param) if probs is not None: - self._probs = cast_to_tensor(probs) - check_prob(self._probs) + self._probs = cast_to_tensor(probs, dtype=mstype.float32) + check_prob(self.probs) else: self._probs = probs self.seed = seed # ops needed for the class - self.log = P.Log() - self.add = P.TensorAdd() - self.mul = P.Mul() - self.sqrt = P.Sqrt() - self.realdiv = P.RealDiv() - self.shape = P.Shape() - self.const = P.ScalarToArray() - self.less = P.Less() self.cast = P.Cast() + self.const = P.ScalarToArray() + self.dtypeop = P.DType() self.erf = P.Erf() + self.fill = P.Fill() + self.log = P.Log() + self.less = P.Less() + self.shape = P.Shape() + self.select = P.Select() + self.sq = P.Square() self.sqrt = P.Sqrt() + self.uniform = P.UniformReal(seed=seed) def extend_repr(self): - str_info = f'probs = {self._probs}' + if self.is_scalar_batch: + str_info = f'probs = {self.probs}' + else: + str_info = f'batch_shape = {self._broadcast_shape}' return str_info + @property def probs(self): """ Returns the probability for the outcome is 1. @@ -85,7 +136,21 @@ class Bernoulli(Distribution): MEAN(B) = probs1 """ if name == 'mean': - return self._probs if probs1 is None else probs1 + return self.probs if probs1 is None else probs1 + return None + + def _mode(self, name='mode', probs1=None): + r""" + .. 
math:: + MODE(B) = 1 if probs1 > 0.5 else = 0 + """ + if name == 'mode': + probs1 = self.probs if probs1 is None else probs1 + prob_type = self.dtypeop(probs1) + zeros = self.fill(prob_type, self.shape(probs1), 0.0) + ones = self.fill(prob_type, self.shape(probs1), 1.0) + comp = self.less(0.5, probs1) + return self.select(comp, ones, zeros) return None def _var(self, name='var', probs1=None): @@ -93,10 +158,35 @@ class Bernoulli(Distribution): .. math:: VAR(B) = probs1 * probs0 """ - if name in ('sd', 'var'): - probs1 = self._probs if probs1 is None else probs1 - probs0 = self.add(1, -1 * probs1) - return self.mul(probs0, probs1) + if name in self._variance_functions: + probs1 = self.probs if probs1 is None else probs1 + probs0 = 1.0 - probs1 + return probs0 * probs1 + return None + + def _entropy(self, name='entropy', probs=None): + r""" + .. math:: + H(B) = -probs0 * \log(probs0) - probs1 * \log(probs1) + """ + if name == 'entropy': + probs1 = self.probs if probs is None else probs + probs0 = 1 - probs1 + return -1 * (probs0 * self.log(probs0)) - (probs1 * self.log(probs1)) + return None + + def _cross_entropy(self, name, dist, probs1_b, probs1_a=None): + """ + Evaluate cross_entropy between Bernoulli distributions. + + Args: + name (str): name of the funtion. + dist (str): type of the distributions. Should be "Bernoulli" in this case. + probs1_b (Tensor): probs1 of distribution b. + probs1_a (Tensor): probs1 of distribution a. Default: self.probs. + """ + if name == 'cross_entropy' and dist == 'Bernoulli': + return self._entropy(probs=probs1_a) + self._kl_loss(name, dist, probs1_b, probs1_a) return None def _prob(self, name, value, probs=None): @@ -106,17 +196,43 @@ class Bernoulli(Distribution): Args: name (str): name of the function. Should be "prob" when passed in from construct. value (Tensor): a Tensor composed of only zeros and ones. - probs (Tensor): probability of outcome is 1. Default: self._probs. + probs (Tensor): probability of outcome is 1. 
Default: self.probs. .. math:: pmf(k) = probs1 if k = 1; pmf(k) = probs0 if k = 0; """ - if name in ('prob', 'log_prob'): - probs1 = self._probs if probs is None else probs - probs0 = self.add(1, -1 * probs1) - return self.add(self.mul(probs1, value), - self.mul(probs0, self.add(1, -1 * value))) + if name in self._prob_functions: + probs1 = self.probs if probs is None else probs + probs0 = 1.0 - probs1 + return (probs1 * value) + (probs0 * (1.0 - value)) + return None + + def _cdf(self, name, value, probs=None): + r""" + cdf of Bernoulli distribution. + + Args: + name (str): name of the function. + value (Tensor): value to be evaluated. + probs (Tensor): probability of outcome is 1. Default: self.probs. + + .. math:: + cdf(k) = 0 if k < 0; + cdf(k) = probs0 if 0 <= k <1; + cdf(k) = 1 if k >=1; + """ + if name in self._cdf_survival_functions: + probs1 = self.probs if probs is None else probs + prob_type = self.dtypeop(probs1) + value = value * self.fill(prob_type, self.shape(probs1), 1.0) + probs0 = 1.0 - probs1 * self.fill(prob_type, self.shape(value), 1.0) + comp_zero = self.less(value, 0.0) + comp_one = self.less(value, 1.0) + zeros = self.fill(prob_type, self.shape(value), 0.0) + ones = self.fill(prob_type, self.shape(value), 1.0) + less_than_zero = self.select(comp_zero, zeros, probs0) + return self.select(comp_one, less_than_zero, ones) return None def _kl_loss(self, name, dist, probs1_b, probs1_a=None): @@ -124,21 +240,20 @@ class Bernoulli(Distribution): Evaluate bernoulli-bernoulli kl divergence, i.e. KL(a||b). Args: - name (str): name of the funtion. Should always be "kl_loss" when passed in from construct. + name (str): name of the funtion. dist (str): type of the distributions. Should be "Bernoulli" in this case. probs1_b (Tensor): probs1 of distribution b. - probs1_a (Tensor): probs1 of distribution a. Default: self._probs. + probs1_a (Tensor): probs1 of distribution a. Default: self.probs. .. 
math:: KL(a||b) = probs1_a * \log(\fract{probs1_a}{probs1_b}) + probs0_a * \log(\fract{probs0_a}{probs0_b}) """ - if name == 'kl_loss' and dist == 'Bernoulli': - probs1_a = self._probs if probs1_a is None else probs1_a - probs0_a = self.add(1, -1 * probs1_a) - probs0_b = self.add(1, -1 * probs1_b) - return self.add(probs1_a * self.log(self.realdiv(probs1_a, probs1_b)), - probs0_a * self.log(self.realdiv(probs0_a, probs0_b))) + if name in self._divergence_functions and dist == 'Bernoulli': + probs1_a = self.probs if probs1_a is None else probs1_a + probs0_a = 1.0 - probs1_a + probs0_b = 1.0 - probs1_b + return probs1_a * self.log(probs1_a / probs1_b) + probs0_a * self.log(probs0_a / probs0_b) return None def _sample(self, name, shape=(), probs=None): @@ -148,21 +263,17 @@ class Bernoulli(Distribution): Args: name (str): name of the function. Should always be 'sample' when passed in from construct. shape (tuple): shape of the sample. Default: (). - probs (Tensor): probs1 of the samples. Default: self._probs. + probs (Tensor): probs1 of the samples. Default: self.probs. Returns: Tensor, shape is shape + batch_shape. 
""" if name == 'sample': - probs1 = self._probs if probs is None else probs - batch_shape = self.shape(probs1) - sample_shape = shape + batch_shape - mean_zero = self.const(0.0) - sd_one = self.const(1.0) - sqrt_two = self.sqrt(self.const(2.0)) - sample_norm = C.normal(sample_shape, mean_zero, sd_one, self.seed) - sample_uniform = 0.5 * (1 + self.erf(self.realdiv(sample_norm, sqrt_two))) + probs1 = self.probs if probs is None else probs + l_zero = self.const(0.0) + h_one = self.const(1.0) + sample_uniform = self.uniform(shape + self.shape(probs1), l_zero, h_one) sample = self.less(sample_uniform, probs1) - sample = self.cast(sample, self._dtype) + sample = self.cast(sample, self.dtype) return sample return None diff --git a/mindspore/nn/distribution/distribution.py b/mindspore/nn/distribution/distribution.py index 1ed7906a9e..52e23f0e9a 100644 --- a/mindspore/nn/distribution/distribution.py +++ b/mindspore/nn/distribution/distribution.py @@ -14,8 +14,7 @@ # ============================================================================ """basic""" from ..cell import Cell -from ._utils.utils import calc_broadcast_shape_from_param - +from ._utils.utils import calc_broadcast_shape_from_param, check_scalar_from_param class Distribution(Cell): """ @@ -29,19 +28,18 @@ class Distribution(Cell): Note: Derived class should override operations such as ,_mean, _prob, and _log_prob. Functions should be called through construct when - used inside a network in the form of function name followed by - arguments. + used inside a network. Arguments should be passed in through *args + in the form of function name followed by additional arguments. + Functions such as cdf and prob, require a value to be passed in while + functions such as mean and sd do not require arguments other than name. 
- Examples: - >>> class MyNormalDistribution(Distribution): - >>> def __init__(self): - >>> super(MyDistribution, self).__init__() - >>> self._mean_value = Tensor([2.0,3.0]) - >>> self._sd_value = Tensor([2.0,3.0]) - >>> - >>> def _mean(self): - >>> return self._mean_value + Dist_spec_args are unique for each type of distribution. For example, mean and sd + are the dist_spec_args for a Normal distribution. + For all functions, passing in dist_spec_args is optional. + Passing in the additional dist_spec_args will make the result to be evaluated with + new distribution specified by the dist_spec_args. But it won't change the + original distribution. """ def __init__(self, dtype, @@ -61,12 +59,40 @@ class Distribution(Cell): self._parameters[k] = param[k] # some attributes self._broadcast_shape = calc_broadcast_shape_from_param( - self._parameters) + self.parameters) + self._is_scalar_batch = check_scalar_from_param(self.parameters) # set the function to call according to the derived class's attributes self._set_prob() self._set_log_prob() self._set_sd() + self._set_var() + self._set_cdf() + self._set_survival() + self._set_log_cdf() + self._set_log_survival() + self._set_cross_entropy() + + self._prob_functions = ('prob', 'log_prob') + self._cdf_survival_functions = ('cdf', 'log_cdf', 'survival_function', 'log_survival') + self._variance_functions = ('var', 'sd') + self._divergence_functions = ('kl_loss', 'cross_entropy') + + @property + def name(self): + return self._name + + @property + def dtype(self): + return self._dtype + + @property + def parameters(self): + return self._parameters + + @property + def is_scalar_batch(self): + return self._is_scalar_batch def _set_prob(self): """ @@ -74,8 +100,8 @@ class Distribution(Cell): """ if hasattr(self, '_prob'): self._call_prob = self._prob - elif hasattr(self, '_log_likelihood'): - self._call_prob = self._calc_prob_from_log_likelihood + elif hasattr(self, '_log_prob'): + self._call_prob = 
self._calc_prob_from_log_prob def _set_sd(self): """ @@ -86,45 +112,100 @@ class Distribution(Cell): elif hasattr(self, '_var'): self._call_sd = self._calc_sd_from_var + def _set_var(self): + """ + Set variance based on the availability of _sd and _var. + """ + if hasattr(self, '_var'): + self._call_var = self._var + elif hasattr(self, '_sd'): + self._call_var = self._calc_var_from_sd + def _set_log_prob(self): """ - Set log probability based on the availability of _prob and _log_likelihood. + Set log probability based on the availability of _prob and _log_prob. """ - if hasattr(self, '_log_likelihood'): - self._call_log_prob = self._log_likelihood - if hasattr(self, '_prob'): + if hasattr(self, '_log_prob'): + self._call_log_prob = self._log_prob + elif hasattr(self, '_prob'): self._call_log_prob = self._calc_log_prob_from_prob - def log_likelihood(self, *args): + def _set_cdf(self): """ - Evaluate the log probability at the given value. + Set cdf based on the availability of _cdf and _log_cdf and survival_functions. + """ + if hasattr(self, '_cdf'): + self._call_cdf = self._cdf + elif hasattr(self, '_log_cdf'): + self._call_cdf = self._calc_cdf_from_log_cdf + elif hasattr(self, '_survival_function'): + self._call_cdf = self._calc_cdf_from_survival + elif hasattr(self, '_log_survival'): + self._call_cdf = self._calc_cdf_from_log_survival + + def _set_survival(self): + """ + Set survival function based on the availability of _survival function and _log_survival + and _call_cdf. + """ + if hasattr(self, '_survival_function'): + self._call_survival = self._survival_function + elif hasattr(self, '_log_survival'): + self._call_survival = self._calc_survival_from_log_survival + elif hasattr(self, '_call_cdf'): + self._call_survival = self._calc_survival_from_call_cdf + + def _set_log_cdf(self): + """ + Set log cdf based on the availability of _log_cdf and _call_cdf. 
+ """ + if hasattr(self, '_log_cdf'): + self._call_log_cdf = self._log_cdf + elif hasattr(self, '_call_cdf'): + self._call_log_cdf = self._calc_log_cdf_from_call_cdf + + def _set_log_survival(self): + """ + Set log survival based on the availability of _log_survival and _call_survival. + """ + if hasattr(self, '_log_survival'): + self._call_log_survival = self._log_survival + elif hasattr(self, '_call_survival'): + self._call_log_survival = self._calc_log_survival_from_call_survival + + def _set_cross_entropy(self): + """ + Set cross entropy based on the availability of _cross_entropy. + """ + if hasattr(self, '_cross_entropy'): + self._call_cross_entropy = self._cross_entropy + + def log_prob(self, *args): + """ + Evaluate the log probability (pdf or pmf) at the given value. Note: - value is casted to Tensor for further calculation. - - Returns: - Tensor, shape is the broadcast_shape of the distribution. + Args must include name of the function and value. + Dist_spec_args are optional. """ return self._call_log_prob(*args) - def _calc_prob_from_log_likelihood(self, *args): + def _calc_prob_from_log_prob(self, *args): r""" Evaluate prob from log probability. .. math:: probability(x) = \exp(log_likehood(x)) """ - return self.exp(self._log_likelihood(*args)) + return self.exp(self._log_prob(*args)) def prob(self, *args): """ - Evaluate the prob (pdf or pmf) at given value. + Evaluate the probability (pdf or pmf) at given value. Note: - value is casted to Tensor for further calculation. - - Returns: - Tensor, shape is the broadcast_shape of the distribution. + Args must include name of the function and value. + Dist_spec_args are optional. """ return self._call_prob(*args) @@ -137,33 +218,154 @@ class Distribution(Cell): """ return self.log(self._prob(*args)) - def kl_loss(self, **kwargs): + def cdf(self, *args): """ - Evaluate the KL divergence. Parameters of the second distribution should be - passed in through **kwargs. + Evaluate the cdf at given value. 
- Returns: - Tensor, shape is the broadcast_shape of the distribution and input distribution. + Note: + Args must include name of the function and value. + Dist_spec_args are optional. """ - return self._kl_loss(**kwargs) + return self._call_cdf(*args) - def mean(self, **kwargs): + def _calc_cdf_from_log_cdf(self, *args): + r""" + Evaluate cdf from log_cdf. + + .. math:: + cdf(x) = \exp(log_cdf(x)) + """ + return self.exp(self._log_cdf(*args)) + + def _calc_cdf_from_survival(self, *args): + r""" + Evaluate cdf from survival function. + + .. math:: + cdf(x) = 1 - (survival_function(x)) + """ + return 1.0 - self._survival_function(*args) + + def _calc_cdf_from_log_survival(self, *args): + r""" + Evaluate cdf from log survival function. + + .. math:: + cdf(x) = 1 - (\exp(log_survival(x))) + """ + return 1.0 - self.exp(self._log_survival(*args)) + + def log_cdf(self, *args): + """ + Evaluate the log cdf at given value. + + Note: + Args must include name of the function and value. + Dist_spec_args are optional. + """ + return self._call_log_cdf(*args) + + def _calc_log_cdf_from_call_cdf(self, *args): + r""" + Evaluate log cdf from cdf. + + .. math:: + log_cdf(x) = \log(cdf(x)) + """ + return self.log(self._call_cdf(*args)) + + def survival_function(self, *args): + """ + Evaluate the survival function at given value. + + Note: + Args must include name of the function and value. + Dist_spec_args are optional. + """ + return self._call_survival(*args) + + def _calc_survival_from_call_cdf(self, *args): + r""" + Evaluate survival function from cdf. + + .. math:: + survival_function(x) = 1 - (cdf(x)) + """ + return 1.0 - self._call_cdf(*args) + + def _calc_survival_from_log_survival(self, *args): + r""" + Evaluate survival function from log survival function. + + .. math:: + survival_function(x) = \exp(log_survival(x)) + """ + return self.exp(self._log_survival(*args)) + + def log_survival(self, *args): + """ + Evaluate the log survival function at given value. 
+ + Note: + Args must include name of the function and value. + Dist_spec_args are optional. + """ + return self._call_log_survival(*args) + + def _calc_log_survival_from_call_survival(self, *args): + r""" + Evaluate log survival function from survival function. + + .. math:: + log_survival(x) = \log(survival_function(x)) + """ + return self.log(self._call_survival(*args)) + + def kl_loss(self, *args): + """ + Evaluate the KL divergence, i.e. KL(a||b). + + Note: + Args must include name of the function, type of the distribution, parameters of distribution b. + Parameters for distribution a are optional. + """ + return self._kl_loss(*args) + + def mean(self, *args): """ Evaluate the mean. - Returns: - Tensor, shape is the broadcast_shape of the distribution. + Note: + Args must include the name of function. Dist_spec_args are optional. """ - return self._mean(**kwargs) + return self._mean(*args) - def sd(self, **kwargs): + def mode(self, *args): + """ + Evaluate the mode. + + Note: + Args must include the name of function. Dist_spec_args are optional. + """ + return self._mode(*args) + + def sd(self, *args): """ Evaluate the standard deviation. - Returns: - Tensor, shape is the broadcast_shape of the distribution. + Note: + Args must include the name of function. Dist_spec_args are optional. """ - return self._call_sd(**kwargs) + return self._call_sd(*args) + + def var(self, *args): + """ + Evaluate the variance. + + Note: + Args must include the name of function. Dist_spec_args are optional. + """ + return self._call_var(*args) def _calc_sd_from_var(self, *args): r""" @@ -174,27 +376,96 @@ class Distribution(Cell): """ return self.sqrt(self._var(*args)) + def _calc_var_from_sd(self, *args): + r""" + Evaluate variance from standard deviation. + + .. math:: + VAR(x) = STD(x) ^ 2 + """ + return self.sq(self._sd(*args)) + + def entropy(self, *args): + """ + Evaluate the entropy. + + Note: + Args must include the name of function. Dist_spec_args are optional. 
+ """ + return self._entropy(*args) + + def cross_entropy(self, *args): + """ + Evaluate the cross_entropy between distribution a and b. + + Note: + Args must include name of the function, type of the distribution, parameters of distribution b. + Parameters for distribution a are optional. + """ + return self._call_cross_entropy(*args) + + def _calc_cross_entropy(self, *args): + r""" + Evaluate cross_entropy from entropy and kl divergence. + + .. math:: + H(X, Y) = H(X) + KL(X||Y) + """ + return self._entropy(*args) + self._kl_loss(*args) + + def sample(self, *args): + """ + Sampling function. + + Args: + *args (list): arguments passed in through construct. + + Note: + Args must include name of the function. + Shape of the sample and dist_spec_args are optional. + """ + return self._sample(*args) + + def construct(self, *inputs): """ Override construct in Cell. - Args: - *inputs: inputs[0] is always the name of the function. + Note: + Names of supported functions: + 'prob', 'log_prob', 'cdf', 'log_cdf', 'survival_function', 'log_survival', + 'mean', 'mode', 'var', 'sd', 'entropy', 'kl_loss', 'cross_entropy', 'sample'. - Notes: - Always raise RuntimeError as Distribution should not be called directly. + Args: + *inputs (list): inputs[0] is always the name of the function. 
""" if inputs[0] == 'log_prob': return self._call_log_prob(*inputs) if inputs[0] == 'prob': return self._call_prob(*inputs) + if inputs[0] == 'cdf': + return self._call_cdf(*inputs) + if inputs[0] == 'log_cdf': + return self._call_log_cdf(*inputs) + if inputs[0] == 'survival_function': + return self._call_survival(*inputs) + if inputs[0] == 'log_survival': + return self._call_log_survival(*inputs) if inputs[0] == 'kl_loss': return self._kl_loss(*inputs) if inputs[0] == 'mean': return self._mean(*inputs) + if inputs[0] == 'mode': + return self._mode(*inputs) if inputs[0] == 'sd': return self._call_sd(*inputs) + if inputs[0] == 'var': + return self._call_var(*inputs) + if inputs[0] == 'entropy': + return self._entropy(*inputs) + if inputs[0] == 'cross_entropy': + return self._call_cross_entropy(*inputs) if inputs[0] == 'sample': return self._sample(*inputs) return None diff --git a/mindspore/nn/distribution/exponential.py b/mindspore/nn/distribution/exponential.py new file mode 100644 index 0000000000..9816369e0b --- /dev/null +++ b/mindspore/nn/distribution/exponential.py @@ -0,0 +1,268 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ +"""Exponential Distribution""" +import numpy as np +from mindspore.ops import operations as P +from .distribution import Distribution +from ...common import dtype as mstype +from ._utils.utils import cast_to_tensor, check_greater_zero + +class Exponential(Distribution): + """ + Example class: Exponential Distribution. + + Args: + rate (float, list, numpy.ndarray, Tensor, Parameter): inverse scale. + seed (int): seed to use in sampling. Default: 0. + dtype (mindspore.dtype): type of the distribution. Default: mstype.float32. + name (str): name of the distribution. Default: Exponential. + + Note: + rate should be strictly greater than 0. + Dist_spec_args is rate. + + Examples: + >>> # To initialize an Exponential distribution of rate 0.5 + >>> n = nn.Exponential(0.5, dtype=mstype.float32) + >>> + >>> # The following creates two independent Exponential distributions + >>> n = nn.Exponential([0.5, 0.5], dtype=mstype.float32) + >>> + >>> # A Exponential distribution can be initilized without arguments + >>> # In this case, rate must be passed in through construct. 
+ >>> n = nn.Exponential(dtype=mstype.float32) + >>> + >>> # To use Exponential distribution in a network + >>> class net(Cell): + >>> def __init__(self): + >>> super(net, self).__init__() + >>> self.e1 = nn.Exponential(0.5, dtype=mstype.float32) + >>> self.e2 = nn.Exponential(dtype=mstype.float32) + >>> + >>> # All the following calls in construct are valid + >>> def construct(self, value, rate_b, rate_a): + >>> + >>> # Similar calls can be made to other probability functions + >>> # by replacing 'prob' with the name of the function + >>> ans = self.e1('prob', value) + >>> # Evaluate with respect to distribution b + >>> ans = self.e1('prob', value, rate_b) + >>> + >>> # Rate must be passed in through construct + >>> ans = self.e2('prob', value, rate_a) + >>> + >>> # Functions 'sd', 'var', 'entropy' have the same usage as 'mean' + >>> # Will return [2.0] + >>> ans = self.e1('mean') + >>> # Will return 1.0 / rate_b + >>> ans = self.e1('mean', rate_b) + >>> + >>> # Rate must be passed in through construct + >>> ans = self.e2('mean', rate_a) + >>> + >>> # Usage of 'kl_loss' and 'cross_entropy' are similar + >>> ans = self.e1('kl_loss', 'Exponential', rate_b) + >>> ans = self.e1('kl_loss', 'Exponential', rate_b, rate_a) + >>> + >>> # Additional rate must be passed in through construct + >>> ans = self.e2('kl_loss', 'Exponential', rate_b, rate_a) + >>> + >>> # Sample Usage + >>> ans = self.e1('sample') + >>> ans = self.e1('sample', (2,3)) + >>> ans = self.e1('sample', (2,3), rate_b) + >>> ans = self.e2('sample', (2,3), rate_a) + """ + + def __init__(self, + rate=None, + seed=0, + dtype=mstype.float32, + name="Exponential"): + """ + Constructor of Exponential distribution. 
+ """ + param = dict(locals()) + super(Exponential, self).__init__(dtype, name, param) + if rate is not None: + self._rate = cast_to_tensor(rate, mstype.float32) + check_greater_zero(self._rate, "rate") + else: + self._rate = rate + + self.minval = np.finfo(np.float).tiny + + # ops needed for the class + self.const = P.ScalarToArray() + self.dtypeop = P.DType() + self.exp = P.Exp() + self.fill = P.Fill() + self.less = P.Less() + self.log = P.Log() + self.select = P.Select() + self.shape = P.Shape() + self.sqrt = P.Sqrt() + self.sq = P.Square() + self.uniform = P.UniformReal(seed=seed) + + def extend_repr(self): + if self.is_scalar_batch: + str_info = f'rate = {self.rate}' + else: + str_info = f'batch_shape = {self._broadcast_shape}' + return str_info + + @property + def rate(self): + """ + Return rate of the distribution. + """ + return self._rate + + def _mean(self, name='mean', rate=None): + r""" + .. math:: + MEAN(EXP) = \fract{1.0}{\lambda}. + """ + if name == 'mean': + rate = self.rate if rate is None else rate + return 1.0 / rate + return None + + def _mode(self, name='mode', rate=None): + r""" + .. math:: + MODE(EXP) = 0. + """ + if name == 'mode': + rate = self.rate if rate is None else rate + return self.fill(self.dtype, self.shape(rate), 0.) + return None + + def _sd(self, name='sd', rate=None): + r""" + .. math:: + sd(EXP) = \fract{1.0}{\lambda}. + """ + if name in self._variance_functions: + rate = self.rate if rate is None else rate + return 1.0 / rate + return None + + def _entropy(self, name='entropy', rate=None): + r""" + .. math:: + H(Exp) = 1 - \log(\lambda). + """ + rate = self.rate if rate is None else rate + if name == 'entropy': + return 1.0 - self.log(rate) + return None + + def _cross_entropy(self, name, dist, rate_b, rate_a=None): + """ + Evaluate cross_entropy between Exponential distributions. + + Args: + name (str): name of the funtion. Should always be "cross_entropy" when passed in from construct. 
+ dist (str): type of the distributions. Should be "Exponential" in this case. + rate_b (Tensor): rate of distribution b. + rate_a (Tensor): rate of distribution a. Default: self.rate. + """ + if name == 'cross_entropy' and dist == 'Exponential': + return self._entropy(rate=rate_a) + self._kl_loss(name, dist, rate_b, rate_a) + return None + + def _prob(self, name, value, rate=None): + r""" + pdf of Exponential distribution. + + Args: + Args: + name (str): name of the function. + value (Tensor): value to be evaluated. + rate (Tensor): rate of the distribution. Default: self.rate. + + Note: + Value should be greater or equal to zero. + + .. math:: + pdf(x) = rate * \exp(-1 * \lambda * x) if x >= 0 else 0 + """ + if name in self._prob_functions: + rate = self.rate if rate is None else rate + prob = rate * self.exp(-1. * rate * value) + zeros = self.fill(self.dtypeop(prob), self.shape(prob), 0.0) + comp = self.less(value, zeros) + return self.select(comp, zeros, prob) + return None + + def _cdf(self, name, value, rate=None): + r""" + cdf of Exponential distribution. + + Args: + name (str): name of the function. + value (Tensor): value to be evaluated. + rate (Tensor): rate of the distribution. Default: self.rate. + + Note: + Value should be greater or equal to zero. + + .. math:: + cdf(x) = 1.0 - \exp(-1 * \lambda * x) if x >= 0 else 0 + """ + if name in self._cdf_survival_functions: + rate = self.rate if rate is None else rate + cdf = 1.0 - self.exp(-1. * rate * value) + zeros = self.fill(self.dtypeop(cdf), self.shape(cdf), 0.0) + comp = self.less(value, zeros) + return self.select(comp, zeros, cdf) + return None + + def _kl_loss(self, name, dist, rate_b, rate_a=None): + """ + Evaluate exp-exp kl divergence, i.e. KL(a||b). + + Args: + name (str): name of the funtion. + dist (str): type of the distributions. Should be "Exponential" in this case. + rate_b (Tensor): rate of distribution b. + rate_a (Tensor): rate of distribution a. Default: self.rate. 
+ """ + if name in self._divergence_functions and dist == 'Exponential': + rate_a = self.rate if rate_a is None else rate_a + return self.log(rate_a) - self.log(rate_b) + rate_b / rate_a - 1.0 + return None + + def _sample(self, name, shape=(), rate=None): + """ + Sampling. + + Args: + name (str): name of the function. + shape (tuple): shape of the sample. Default: (). + rate (Tensor): rate of the distribution. Default: self.rate. + + Returns: + Tensor, shape is shape + batch_shape. + """ + if name == 'sample': + rate = self.rate if rate is None else rate + minval = self.const(self.minval) + maxval = self.const(1.0) + sample = self.uniform(shape + self.shape(rate), minval, maxval) + return -self.log(sample) / rate + return None diff --git a/mindspore/nn/distribution/geometric.py b/mindspore/nn/distribution/geometric.py new file mode 100644 index 0000000000..0a9da3b244 --- /dev/null +++ b/mindspore/nn/distribution/geometric.py @@ -0,0 +1,288 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Geometric Distribution""" +import numpy as np +from mindspore.ops import operations as P +from .distribution import Distribution +from ._utils.utils import cast_to_tensor, check_prob +from ...common import dtype as mstype + +class Geometric(Distribution): + """ + Geometric Distribution. 
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Geometric Distribution"""
import numpy as np
from mindspore.ops import operations as P
from .distribution import Distribution
from ._utils.utils import cast_to_tensor, check_prob
from ...common import dtype as mstype


class Geometric(Distribution):
    """
    Geometric Distribution.
    It represents k+1 Bernoulli trials needed to get one success, k is the number of failures.

    Args:
        probs (float, list, numpy.ndarray, Tensor, Parameter): probability of success.
        seed (int): seed to use in sampling. Default: 0.
        dtype (mindspore.dtype): type of the distribution. Default: mstype.int32.
        name (str): name of the distribution. Default: Geometric.

    Note:
        probs should be proper probabilities (0 <= p <= 1).
        Dist_spec_args is probs.

    Examples:
        >>> # To initialize a Geometric distribution of prob 0.5
        >>> n = nn.Geometric(0.5, dtype=mstype.int32)
        >>>
        >>> # The following creates two independent Geometric distributions
        >>> n = nn.Geometric([0.5, 0.5], dtype=mstype.int32)
        >>>
        >>> # A Geometric distribution can be initialized without arguments
        >>> # In this case, probs must be passed in through construct.
        >>> n = nn.Geometric(dtype=mstype.int32)
        >>>
        >>> # To use Geometric distribution in a network
        >>> class net(Cell):
        >>>     def __init__(self):
        >>>         super(net, self).__init__()
        >>>         self.g1 = nn.Geometric(0.5, dtype=mstype.int32)
        >>>         self.g2 = nn.Geometric(dtype=mstype.int32)
        >>>
        >>>     # The following calls are valid in construct
        >>>     def construct(self, value, probs_b, probs_a):
        >>>
        >>>         # Similar calls can be made to other probability functions
        >>>         # by replacing 'prob' with the name of the function
        >>>         ans = self.g1('prob', value)
        >>>         # Evaluate with the respect to distribution b
        >>>         ans = self.g1('prob', value, probs_b)
        >>>
        >>>         # Probs must be passed in through construct
        >>>         ans = self.g2('prob', value, probs_a)
        >>>
        >>>         # Functions 'sd', 'var', 'entropy' have the same usage with 'mean'
        >>>         # Will return [1.0], i.e. (1 - probs) / probs
        >>>         ans = self.g1('mean')
        >>>         # Will return (1 - probs_b) / probs_b
        >>>         ans = self.g1('mean', probs_b)
        >>>
        >>>         # Probs must be passed in through construct
        >>>         ans = self.g2('mean', probs_a)
        >>>
        >>>         # Usage of 'kl_loss' and 'cross_entropy' are similar
        >>>         ans = self.g1('kl_loss', 'Geometric', probs_b)
        >>>         ans = self.g1('kl_loss', 'Geometric', probs_b, probs_a)
        >>>
        >>>         # Additional probs must be passed in through construct
        >>>         ans = self.g2('kl_loss', 'Geometric', probs_b, probs_a)
        >>>
        >>>         # Sample Usage
        >>>         ans = self.g1('sample')
        >>>         ans = self.g1('sample', (2,3))
        >>>         ans = self.g1('sample', (2,3), probs_b)
        >>>         ans = self.g2('sample', (2,3), probs_a)
    """

    def __init__(self,
                 probs=None,
                 seed=0,
                 dtype=mstype.int32,
                 name="Geometric"):
        """
        Constructor of Geometric distribution.
        """
        # Must stay the first statement: it snapshots the constructor arguments
        # that are handed to the base class.
        param = dict(locals())
        super(Geometric, self).__init__(dtype, name, param)
        if probs is not None:
            self._probs = cast_to_tensor(probs, dtype=mstype.float32)
            check_prob(self._probs)
        else:
            self._probs = probs

        # Lower bound of the uniform draw used by _sample, so that log() never
        # sees 0. The removed alias np.float (gone since NumPy 1.24) is avoided,
        # and float32's tiny is used because float64's tiny (~2.2e-308) rounds
        # to 0 once stored in a float32 tensor (probs is cast to float32 above).
        self.minval = float(np.finfo(np.float32).tiny)

        # ops needed for the class
        self.const = P.ScalarToArray()
        self.dtypeop = P.DType()
        self.fill = P.Fill()
        self.floor = P.Floor()
        self.issubclass = P.IsSubClass()
        self.less = P.Less()
        self.log = P.Log()
        self.pow = P.Pow()
        self.select = P.Select()
        self.shape = P.Shape()
        self.sq = P.Square()
        self.sqrt = P.Sqrt()
        self.uniform = P.UniformReal(seed=seed)

    def extend_repr(self):
        """
        Return probs for a scalar batch, otherwise the broadcast batch shape.
        """
        if self.is_scalar_batch:
            str_info = f'probs = {self.probs}'
        else:
            str_info = f'batch_shape = {self._broadcast_shape}'
        return str_info

    @property
    def probs(self):
        """
        Return the probability of success of a single trial.
        """
        return self._probs

    def _mean(self, name='mean', probs1=None):
        r"""
        .. math::
            MEAN(Geo) = \frac{1 - probs1}{probs1}
        """
        if name == 'mean':
            probs1 = self.probs if probs1 is None else probs1
            return (1. - probs1) / probs1
        return None

    def _mode(self, name='mode', probs1=None):
        r"""
        .. math::
            MODE(Geo) = 0
        """
        if name == 'mode':
            probs1 = self.probs if probs1 is None else probs1
            return self.fill(self.dtypeop(probs1), self.shape(probs1), 0.)
        return None

    def _var(self, name='var', probs1=None):
        r"""
        .. math::
            VAR(Geo) = \frac{1 - probs1}{probs1 ^ {2}}
        """
        if name in self._variance_functions:
            probs1 = self.probs if probs1 is None else probs1
            return (1.0 - probs1) / self.sq(probs1)
        return None

    def _entropy(self, name='entropy', probs=None):
        r"""
        .. math::
            H(Geo) = \frac{-1 * probs0 * \log(probs0) - probs1 * \log(probs1)}{probs1}

        where probs0 = 1 - probs1 and \log is the natural logarithm.
        """
        if name == 'entropy':
            probs1 = self.probs if probs is None else probs
            probs0 = 1.0 - probs1
            return (-probs0 * self.log(probs0) - probs1 * self.log(probs1)) / probs1
        return None

    def _cross_entropy(self, name, dist, probs1_b, probs1_a=None):
        r"""
        Evaluate cross_entropy between Geometric distributions.

        Computed as the entropy of distribution a plus KL(a||b).

        Args:
            name (str): name of the function. Should always be "cross_entropy" when passed in from construct.
            dist (str): type of the distributions. Should be "Geometric" in this case.
            probs1_b (Tensor): probability of success of distribution b.
            probs1_a (Tensor): probability of success of distribution a. Default: self.probs.
        """
        if name == 'cross_entropy' and dist == 'Geometric':
            return self._entropy(probs=probs1_a) + self._kl_loss(name, dist, probs1_b, probs1_a)
        return None

    def _prob(self, name, value, probs=None):
        r"""
        pmf of Geometric distribution.

        Args:
            name (str): name of the function. Should be "prob" when passed in from construct.
            value (Tensor): a Tensor composed of only natural numbers.
            probs (Tensor): probability of success. Default: self.probs.

        .. math::
            pmf(k) = probs0 ^k * probs1 if k >= 0;
            pmf(k) = 0 if k < 0.
        """
        if name in self._prob_functions:
            probs1 = self.probs if probs is None else probs
            dtype = self.dtypeop(value)
            # Integer values are used as they are, float values are floored,
            # and any other dtype is rejected.
            if self.issubclass(dtype, mstype.int_):
                pass
            elif self.issubclass(dtype, mstype.float_):
                value = self.floor(value)
            else:
                return None
            pmf = self.pow((1.0 - probs1), value) * probs1
            zeros = self.fill(self.dtypeop(probs1), self.shape(pmf), 0.0)
            # Negative counts lie outside the support: force their pmf to 0.
            comp = self.less(value, zeros)
            return self.select(comp, zeros, pmf)
        return None

    def _cdf(self, name, value, probs=None):
        r"""
        cdf of Geometric distribution.

        Args:
            name (str): name of the function.
            value (Tensor): a Tensor composed of only natural numbers.
            probs (Tensor): probability of success. Default: self.probs.

        .. math::
            cdf(k) = 1 - probs0 ^ (k+1) if k >= 0;
            cdf(k) = 0 if k < 0.

        """
        if name in self._cdf_survival_functions:
            probs1 = self.probs if probs is None else probs
            probs0 = 1.0 - probs1
            dtype = self.dtypeop(value)
            # Integer values are used as they are, float values are floored,
            # and any other dtype is rejected.
            if self.issubclass(dtype, mstype.int_):
                pass
            elif self.issubclass(dtype, mstype.float_):
                value = self.floor(value)
            else:
                return None
            cdf = 1.0 - self.pow(probs0, value + 1.0)
            zeros = self.fill(self.dtypeop(probs1), self.shape(cdf), 0.0)
            # Negative counts lie outside the support: force their cdf to 0.
            comp = self.less(value, zeros)
            return self.select(comp, zeros, cdf)
        return None

    def _kl_loss(self, name, dist, probs1_b, probs1_a=None):
        r"""
        Evaluate Geometric-Geometric kl divergence, i.e. KL(a||b).

        Args:
            name (str): name of the function.
            dist (str): type of the distributions. Should be "Geometric" in this case.
            probs1_b (Tensor): probability of success of distribution b.
            probs1_a (Tensor): probability of success of distribution a. Default: self.probs.

        .. math::
            KL(a||b) = \log(\frac{probs1_a}{probs1_b}) + \frac{probs0_a}{probs1_a} * \log(\frac{probs0_a}{probs0_b})
        """
        if name in self._divergence_functions and dist == 'Geometric':
            probs1_a = self.probs if probs1_a is None else probs1_a
            probs0_a = 1.0 - probs1_a
            probs0_b = 1.0 - probs1_b
            return self.log(probs1_a / probs1_b) + (probs0_a / probs1_a) * self.log(probs0_a / probs0_b)
        return None

    def _sample(self, name, shape=(), probs=None):
        """
        Sampling.

        Draws u uniformly from [minval, 1.0) and returns
        floor(log(u) / log(1 - probs)).

        Args:
            name (str): name of the function. Should always be 'sample' when passed in from construct.
            shape (tuple): shape of the sample. Default: ().
            probs (Tensor): probability of success. Default: self.probs.

        Returns:
            Tensor, shape is shape + batch_shape.
        """
        if name == 'sample':
            probs = self.probs if probs is None else probs
            minval = self.const(self.minval)
            maxval = self.const(1.0)
            sample_uniform = self.uniform(shape + self.shape(probs), minval, maxval)
            return self.floor(self.log(sample_uniform) / self.log(1.0 - probs))
        return None
- Note: Standard deviation should be greater than zero. + Dist_spec_args are mean and sd. Examples: - >>> # To initialize a normal distribution of mean 3.0 and standard deviation 4.0 - >>> n = nn.Normal(3.0, 4.0, dtype=mstype.float32) - >>> # The following create two independent normal distributions - >>> n = nn.Normal([3.0, 3.0], [4.0, 4.0], dtype=mstype.float32) + >>> # To initialize a Normal distribution of mean 3.0 and standard deviation 4.0 + >>> n = nn.Normal(3.0, 4.0, dtype=mstype.float32) + >>> + >>> # The following creates two independent Normal distributions + >>> n = nn.Normal([3.0, 3.0], [4.0, 4.0], dtype=mstype.float32) + >>> + >>> # A normal distribution can be initilize without arguments + >>> # In this case, mean and sd must be passed in through construct. + >>> n = nn.Normal(dtype=mstype.float32) + >>> + >>> # To use normal in a network + >>> class net(Cell): + >>> def __init__(self): + >>> super(net, self).__init__(): + >>> self.n1 = nn.Normal(0.0, 1.0, dtype=mstype.float32) + >>> self.n2 = nn.Normal(dtype=mstype.float32) + >>> + >>> # The following calls are valid in construct + >>> def construct(self, value, mean_b, sd_b, mean_a, sd_a): + >>> + >>> # Similar calls can be made to other probability functions + >>> # by replacing 'prob' with the name of the function + >>> ans = self.n1('prob', value) + >>> # Evaluate with the respect to distribution b + >>> ans = self.n1('prob', value, mean_b, sd_b) + >>> + >>> # mean and sd must be passed in through construct + >>> ans = self.n2('prob', value, mean_a, sd_a) + >>> + >>> # Functions 'sd', 'var', 'entropy' have the same usage with 'mean' + >>> # Will return [0.0] + >>> ans = self.n1('mean') + >>> # Will return mean_b + >>> ans = self.n1('mean', mean_b, sd_b) + >>> + >>> # mean and sd must be passed in through construct + >>> ans = self.n2('mean', mean_a, sd_a) + >>> + >>> # Usage of 'kl_loss' and 'cross_entropy' are similar + >>> ans = self.n1('kl_loss', 'Normal', mean_b, sd_b) + >>> ans = 
self.n1('kl_loss', 'Normal', mean_b, sd_b, mean_a, sd_a) + >>> + >>> # Additional mean and sd must be passed in through construct + >>> ans = self.n2('kl_loss', 'Normal', mean_b, sd_b, mean_a, sd_a) + >>> + >>> # Sample Usage + >>> ans = self.n1('sample') + >>> ans = self.n1('sample', (2,3)) + >>> ans = self.n1('sample', (2,3), mean_b, sd_b) + >>> ans = self.n2('sample', (2,3), mean_a, sd_a) """ def __init__(self, @@ -64,27 +110,29 @@ class Normal(Distribution): self.seed = seed #ops needed for the class - self.exp = P.Exp() - self.add = P.TensorAdd() - self.mul = P.Mul() - self.sq = P.Square() - self.log = P.Log() - self.sqrt = P.Sqrt() - self.realdiv = P.RealDiv() - self.expm1 = P.Expm1() if get_context('device_target') == 'Ascend' else self._expm1_by_step - self.shape = P.Shape() - self.zeroslike = P.ZerosLike() self.const = P.ScalarToArray() + self.erf = P.Erf() + self.exp = P.Exp() + self.expm1 = P.Expm1() if get_context('device_target') == 'Ascend' else self._expm1_by_step + self.fill = P.Fill() + self.log = P.Log() + self.shape = P.Shape() + self.sq = P.Square() + self.sqrt = P.Sqrt() + self.zeroslike = P.ZerosLike() def extend_repr(self): - str_info = f'mean = {self._mean_value}, standard deviation = {self._sd_value}' + if self.is_scalar_batch: + str_info = f'mean = {self._mean_value}, standard deviation = {self._sd_value}' + else: + str_info = f'batch_shape = {self._broadcast_shape}' return str_info def _expm1_by_step(self, x): """ Expm1 ops under GPU context. """ - return self.add(self.exp(x), -1) + return self.exp(x) - 1.0 def _mean(self, name='mean', mean=None, sd=None): """ @@ -95,29 +143,92 @@ class Normal(Distribution): return mean return None + def _mode(self, name='mode', mean=None, sd=None): + """ + Mode of the distribution. + """ + if name == 'mode': + mean = self._mean_value if mean is None or sd is None else mean + return mean + return None + def _sd(self, name='sd', mean=None, sd=None): """ Standard deviation of the distribution. 
""" - if name in ('sd', 'var'): + if name in self._variance_functions: sd = self._sd_value if mean is None or sd is None else sd return sd return None - def _log_likelihood(self, name, value, mean=None, sd=None): + def _entropy(self, name='entropy', sd=None): + r""" + Evaluate entropy. + + .. math:: + H(X) = \log(\sqrt(numpy.e * 2. * numpy.pi * \sq(\sigma))) + """ + if name == 'entropy': + sd = self._sd_value if sd is None else sd + return self.log(self.sqrt(np.e * 2. * np.pi * self.sq(sd))) + return None + + def _cross_entropy(self, name, dist, mean_b, sd_b, mean_a=None, sd_a=None): + r""" + Evaluate cross_entropy between normal distributions. + + Args: + name (str): name of the funtion passed in from construct. Should always be "cross_entropy". + dist (str): type of the distributions. Should be "Normal" in this case. + mean_b (Tensor): mean of distribution b. + sd_b (Tensor): standard deviation distribution b. + mean_a (Tensor): mean of distribution a. Default: self._mean_value. + sd_a (Tensor): standard deviation distribution a. Default: self._sd_value. + """ + if name == 'cross_entropy' and dist == 'Normal': + return self._entropy(sd=sd_a) + self._kl_loss(name, dist, mean_b, sd_b, mean_a, sd_a) + return None + + def _log_prob(self, name, value, mean=None, sd=None): r""" Evaluate log probability. + Args: + name (str): name of the funtion passed in from construct. + value (Tensor): value to be evaluated. + mean (Tensor): mean of the distribution. Default: self._mean_value. + sd (Tensor): standard deviation the distribution. Default: self._sd_value. + .. math:: L(x) = -1* \fract{(x - \mu)^2}{2. * \sigma^2} - \log(\sqrt(2* \pi * \sigma^2)) """ - if name in ('prob', 'log_prob'): + if name in self._prob_functions: mean = self._mean_value if mean is None else mean sd = self._sd_value if sd is None else sd - unnormalized_log_prob = -1. * self.realdiv(self.sq(self.add(value, -1. * mean)), - 2. * self.sq(sd)) + unnormalized_log_prob = -1. * (self.sq(value - mean)) / (2. 
* self.sq(sd)) neg_normalization = -1. * self.log(self.sqrt(2. * np.pi * self.sq(sd))) - return self.add(unnormalized_log_prob, neg_normalization) + return unnormalized_log_prob + neg_normalization + return None + + def _cdf(self, name, value, mean=None, sd=None): + r""" + Evaluate cdf of given value. + + Args: + name (str): name of the funtion passed in from construct. Should always be "cdf". + value (Tensor): value to be evaluated. + mean (Tensor): mean of the distribution. Default: self._mean_value. + sd (Tensor): standard deviation the distribution. Default: self._sd_value. + + .. math:: + cdf(x) = 0.5 * (1+ Erf((x - \mu) / ( \sigma * \sqrt(2)))) + """ + if name in self._cdf_survival_functions: + mean = self._mean_value if mean is None else mean + sd = self._sd_value if sd is None else sd + sqrt2 = self.sqrt(self.const(2.0)) + adjusted = (value - mean) / (sd * sqrt2) + return 0.5 * (1.0 + self.erf(adjusted)) return None def _kl_loss(self, name, dist, mean_b, sd_b, mean_a=None, sd_a=None): @@ -125,7 +236,7 @@ class Normal(Distribution): Evaluate Normal-Normal kl divergence, i.e. KL(a||b). Args: - name (str): name of the funtion passed in from construct. Should always be "kl_loss". + name (str): name of the funtion passed in from construct. dist (str): type of the distributions. Should be "Normal" in this case. mean_b (Tensor): mean of distribution b. sd_b (Tensor): standard deviation distribution b. 
@@ -136,12 +247,12 @@ class Normal(Distribution): KL(a||b) = 0.5 * (\fract{MEAN(a)}{STD(b)} - \fract{MEAN(b)}{STD(b)}) ^ 2 + 0.5 * EXPM1(2 * (\log(STD(a)) - \log(STD(b))) - (\log(STD(a)) - \log(STD(b))) """ - if name == 'kl_loss' and dist == 'Normal': + if name in self._divergence_functions and dist == 'Normal': mean_a = self._mean_value if mean_a is None else mean_a sd_a = self._sd_value if sd_a is None else sd_a - diff_log_scale = self.add(self.log(sd_a), - self.log(sd_b)) - squared_diff = self.sq(self.add(self.realdiv(mean_a, sd_b), - self.realdiv(mean_b, sd_b))) - return self.add(self.add(0.5 * squared_diff, 0.5 * self.expm1(2 * diff_log_scale)), - diff_log_scale) + diff_log_scale = self.log(sd_a) - self.log(sd_b) + squared_diff = self.sq(mean_a / sd_b - mean_b / sd_b) + return 0.5 * squared_diff + 0.5 * self.expm1(2 * diff_log_scale) - diff_log_scale return None def _sample(self, name, shape=(), mean=None, sd=None): @@ -160,11 +271,11 @@ class Normal(Distribution): if name == 'sample': mean = self._mean_value if mean is None else mean sd = self._sd_value if sd is None else sd - batch_shape = self.shape(self.add(self.zeroslike(mean), self.zeroslike(sd))) + batch_shape = self.shape(self.zeroslike(mean) + self.zeroslike(sd)) sample_shape = shape + batch_shape mean_zero = self.const(0.0) sd_one = self.const(1.0) sample_norm = C.normal(sample_shape, mean_zero, sd_one, self.seed) - sample = self.add(mean, self.mul(sample_norm, sd)) + sample = mean + sample_norm * sd return sample return None diff --git a/mindspore/nn/distribution/uniform.py b/mindspore/nn/distribution/uniform.py new file mode 100644 index 0000000000..3b90bbe736 --- /dev/null +++ b/mindspore/nn/distribution/uniform.py @@ -0,0 +1,304 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Uniform Distribution""" +from mindspore.ops import operations as P +from .distribution import Distribution +from ...common import dtype as mstype +from ._utils.utils import convert_to_batch, check_greater + +class Uniform(Distribution): + """ + Example class: Uniform Distribution. + + Args: + low (int, float, list, numpy.ndarray, Tensor, Parameter): lower bound of the distribution. + high (int, float, list, numpy.ndarray, Tensor, Parameter): upper bound of the distribution. + seed (int): seed to use in sampling. Default: 0. + dtype (mindspore.dtype): type of the distribution. Default: mstype.float32. + name (str): name of the distribution. Default: Uniform. + + Note: + low should be stricly less than high. + Dist_spec_args are high and low. + + Examples: + >>> # To initialize a Uniform distribution of mean 3.0 and standard deviation 4.0 + >>> n = nn.Uniform(0.0, 1.0, dtype=mstype.float32) + >>> + >>> # The following creates two independent Uniform distributions + >>> n = nn.Uniform([0.0, 0.0], [1.0, 2.0], dtype=mstype.float32) + >>> + >>> # A Uniform distribution can be initilized without arguments + >>> # In this case, high and low must be passed in through construct. 
+ >>> n = nn.Uniform(dtype=mstype.float32) + >>> + >>> # To use Uniform in a network + >>> class net(Cell): + >>> def __init__(self) + >>> super(net, self).__init__(): + >>> self.u1 = nn.Uniform(0.0, 1.0, dtype=mstype.float32) + >>> self.u2 = nn.Uniform(dtype=mstype.float32) + >>> + >>> # All the following calls in construct are valid + >>> def construct(self, value, low_b, high_b, low_a, high_a): + >>> + >>> # Similar calls can be made to other probability functions + >>> # by replacing 'prob' with the name of the function + >>> ans = self.u1('prob', value) + >>> # Evaluate with the respect to distribution b + >>> ans = self.u1('prob', value, low_b, high_b) + >>> + >>> # High and low must be passed in through construct + >>> ans = self.u2('prob', value, low_a, high_a) + >>> + >>> # Functions 'sd', 'var', 'entropy' have the same usage with 'mean' + >>> # Will return [0.0] + >>> ans = self.u1('mean') + >>> # Will return low_b + >>> ans = self.u1('mean', low_b, high_b) + >>> + >>> # High and low must be passed in through construct + >>> ans = self.u2('mean', low_a, high_a) + >>> + >>> # Usage of 'kl_loss' and 'cross_entropy' are similar + >>> ans = self.u1('kl_loss', 'Uniform', low_b, high_b) + >>> ans = self.u1('kl_loss', 'Uniform', low_b, high_b, low_a, high_a) + >>> + >>> # Additional high and low must be passed in through construct + >>> ans = self.u2('kl_loss', 'Uniform', low_b, high_b, low_a, high_a) + >>> + >>> # Sample Usage + >>> ans = self.u1('sample') + >>> ans = self.u1('sample', (2,3)) + >>> ans = self.u1('sample', (2,3), low_b, high_b) + >>> ans = self.u2('sample', (2,3), low_a, high_a) + """ + + def __init__(self, + low=None, + high=None, + seed=0, + dtype=mstype.float32, + name="Uniform"): + """ + Constructor of Uniform distribution. 
+ """ + param = dict(locals()) + super(Uniform, self).__init__(dtype, name, param) + if low is not None and high is not None: + self._low = convert_to_batch(low, self._broadcast_shape, dtype) + self._high = convert_to_batch(high, self._broadcast_shape, dtype) + check_greater(self.low, self.high, "low value", "high value") + else: + self._low = low + self._high = high + + # ops needed for the class + self.const = P.ScalarToArray() + self.dtypeop = P.DType() + self.exp = P.Exp() + self.fill = P.Fill() + self.less = P.Less() + self.lessequal = P.LessEqual() + self.log = P.Log() + self.logicaland = P.LogicalAnd() + self.select = P.Select() + self.shape = P.Shape() + self.sq = P.Square() + self.sqrt = P.Sqrt() + self.uniform = P.UniformReal(seed=seed) + self.zeroslike = P.ZerosLike() + + def extend_repr(self): + if self.is_scalar_batch: + str_info = f'low = {self.low}, high = {self.high}' + else: + str_info = f'batch_shape = {self._broadcast_shape}' + return str_info + + @property + def low(self): + """ + Return lower bound of the distribution. + """ + return self._low + + @property + def high(self): + """ + Return upper bound of the distribution. + """ + return self._high + + def _range(self, name='range', low=None, high=None): + r""" + Return the range of the distribution. + .. math:: + range(U) = high -low + """ + if name == 'range': + low = self.low if low is None else low + high = self.high if high is None else high + return high - low + return None + + def _mean(self, name='mean', low=None, high=None): + r""" + .. math:: + MEAN(U) = \fract{low + high}{2}. + """ + if name == 'mean': + low = self.low if low is None else low + high = self.high if high is None else high + return (low + high) / 2. + return None + + def _var(self, name='var', low=None, high=None): + r""" + .. math:: + VAR(U) = \fract{(high -low) ^ 2}{12}. 
+ """ + if name in self._variance_functions: + low = self.low if low is None else low + high = self.high if high is None else high + return self.sq(high - low) / 12.0 + return None + + def _entropy(self, name='entropy', low=None, high=None): + r""" + .. math:: + H(U) = \log(high - low). + """ + if name == 'entropy': + low = self.low if low is None else low + high = self.high if high is None else high + return self.log(high - low) + return None + + def _cross_entropy(self, name, dist, low_b, high_b, low_a=None, high_a=None): + """ + Evaluate cross_entropy between Uniform distributoins. + + Args: + name (str): name of the funtion. + dist (str): type of the distributions. Should be "Uniform" in this case. + low_b (Tensor): lower bound of distribution b. + high_b (Tensor): upper bound of distribution b. + low_a (Tensor): lower bound of distribution a. Default: self.low. + high_a (Tensor): upper bound of distribution a. Default: self.high. + """ + if name == 'cross_entropy' and dist == 'Uniform': + return self._entropy(low=low_a, high=high_a) + self._kl_loss(name, dist, low_b, high_b, low_a, high_a) + return None + + def _prob(self, name, value, low=None, high=None): + r""" + pdf of Uniform distribution. + + Args: + name (str): name of the function. + value (Tensor): value to be evaluated. + low (Tensor): lower bound of the distribution. Default: self.low. + high (Tensor): upper bound of the distribution. Default: self.high. + + .. 
math:: + pdf(x) = 0 if x < low; + pdf(x) = \fract{1.0}{high -low} if low <= x <= high; + pdf(x) = 0 if x > high; + """ + if name in self._prob_functions: + low = self.low if low is None else low + high = self.high if high is None else high + ones = self.fill(self.dtype, self.shape(value), 1.0) + prob = ones / (high - low) + broadcast_shape = self.shape(prob) + zeros = self.fill(self.dtypeop(prob), broadcast_shape, 0.0) + comp_lo = self.less(value, low) + comp_hi = self.lessequal(value, high) + less_than_low = self.select(comp_lo, zeros, prob) + return self.select(comp_hi, less_than_low, zeros) + return None + + def _kl_loss(self, name, dist, low_b, high_b, low_a=None, high_a=None): + """ + Evaluate uniform-uniform kl divergence, i.e. KL(a||b). + + Args: + name (str): name of the funtion. + dist (str): type of the distributions. Should be "Uniform" in this case. + low_b (Tensor): lower bound of distribution b. + high_b (Tensor): upper bound of distribution b. + low_a (Tensor): lower bound of distribution a. Default: self.low. + high_a (Tensor): upper bound of distribution a. Default: self.high. + """ + if name in self._divergence_functions and dist == 'Uniform': + low_a = self.low if low_a is None else low_a + high_a = self.high if high_a is None else high_a + kl = self.log(high_b - low_b) / self.log(high_a - low_a) + comp = self.logicaland(self.lessequal(low_b, low_a), self.lessequal(high_a, high_b)) + return self.select(comp, kl, self.log(self.zeroslike(kl))) + return None + + def _cdf(self, name, value, low=None, high=None): + r""" + cdf of Uniform distribution. + + Args: + name (str): name of the function. + value (Tensor): value to be evaluated. + low (Tensor): lower bound of the distribution. Default: self.low. + high (Tensor): upper bound of the distribution. Default: self.high. + + .. 
math:: + cdf(x) = 0 if x < low; + cdf(x) = \frac{x - low}{high - low} if low <= x <= high; + cdf(x) = 1 if x > high; + """ + if name in self._cdf_survival_functions: + low = self.low if low is None else low + high = self.high if high is None else high + prob = (value - low) / (high - low) + broadcast_shape = self.shape(prob) + zeros = self.fill(self.dtypeop(prob), broadcast_shape, 0.0) + ones = self.fill(self.dtypeop(prob), broadcast_shape, 1.0) + comp_lo = self.less(value, low) + comp_hi = self.less(value, high) + less_than_low = self.select(comp_lo, zeros, prob) + return self.select(comp_hi, less_than_low, ones) + return None + + def _sample(self, name, shape=(), low=None, high=None): + """ + Sampling. + + Args: + name (str): name of the function. Should always be 'sample' when passed in from construct. + shape (tuple): shape of the sample. Default: (). + low (Tensor): lower bound of the distribution. Default: self.low. + high (Tensor): upper bound of the distribution. Default: self.high. + + Returns: + Tensor, shape is shape + batch_shape. + """ + if name == 'sample': + low = self.low if low is None else low + high = self.high if high is None else high + broadcast_shape = self.shape(low + high) + l_zero = self.const(0.0) + h_one = self.const(1.0) + sample_uniform = self.uniform(shape + broadcast_shape, l_zero, h_one) + sample = (high - low) * sample_uniform + low + return sample + return None diff --git a/tests/st/ops/ascend/test_distribution/test_bernoulli.py b/tests/st/ops/ascend/test_distribution/test_bernoulli.py index 5652d536c7..451530116b 100644 --- a/tests/st/ops/ascend/test_distribution/test_bernoulli.py +++ b/tests/st/ops/ascend/test_distribution/test_bernoulli.py @@ -23,91 +23,67 @@ from mindspore import dtype context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") -class Net(nn.Cell): +class Prob(nn.Cell): """ - Test class: probability of bernoulli distribution. + Test class: probability of Bernoulli distribution. 
""" def __init__(self): - super(Net, self).__init__() + super(Prob, self).__init__() self.b = nn.Bernoulli(0.7, dtype=dtype.int32) @ms_function def construct(self, x_): return self.b('prob', x_) -class Net1(nn.Cell): - """ - Test class: log probability of bernoulli distribution. - """ - def __init__(self): - super(Net1, self).__init__() - self.b = nn.Bernoulli(0.7, dtype=dtype.int32) - - @ms_function - def construct(self, x_): - return self.b('log_prob', x_) - -class Net2(nn.Cell): - """ - Test class: kl_loss between bernoulli distributions. - """ - def __init__(self): - super(Net2, self).__init__() - self.b = nn.Bernoulli(0.7, dtype=dtype.int32) - - @ms_function - def construct(self, x_): - return self.b('kl_loss', 'Bernoulli', x_) - -class Net3(nn.Cell): - """ - Test class: mean/sd of bernoulli distribution. - """ - def __init__(self): - super(Net3, self).__init__() - self.b = nn.Bernoulli([0.5, 0.5], dtype=dtype.int32) - - @ms_function - def construct(self): - return self.b('mean'), self.b('sd') - -class Net4(nn.Cell): - """ - Test class: log probability of bernoulli distribution. - """ - def __init__(self, shape, seed=0): - super(Net4, self).__init__() - self.b = nn.Bernoulli([0.7, 0.5], seed=seed, dtype=dtype.int32) - self.shape = shape - - @ms_function - def construct(self, probs=None): - return self.b('sample', self.shape, probs) - def test_pmf(): """ Test pmf. """ bernoulli_benchmark = stats.bernoulli(0.7) expect_pmf = bernoulli_benchmark.pmf([0, 1, 0, 1, 1]).astype(np.float32) - pdf = Net() + pmf = Prob() x_ = Tensor(np.array([0, 1, 0, 1, 1]).astype(np.int32), dtype=dtype.float32) - output = pdf(x_) + output = pmf(x_) tol = 1e-6 assert (np.abs(output.asnumpy() - expect_pmf) < tol).all() + +class LogProb(nn.Cell): + """ + Test class: log probability of Bernoulli distribution. 
+ """ + def __init__(self): + super(LogProb, self).__init__() + self.b = nn.Bernoulli(0.7, dtype=dtype.int32) + + @ms_function + def construct(self, x_): + return self.b('log_prob', x_) + def test_log_likelihood(): """ Test log_pmf. """ bernoulli_benchmark = stats.bernoulli(0.7) expect_logpmf = bernoulli_benchmark.logpmf([0, 1, 0, 1, 1]).astype(np.float32) - logprob = Net1() + logprob = LogProb() x_ = Tensor(np.array([0, 1, 0, 1, 1]).astype(np.int32), dtype=dtype.float32) output = logprob(x_) tol = 1e-6 assert (np.abs(output.asnumpy() - expect_logpmf) < tol).all() +class KL(nn.Cell): + """ + Test class: kl_loss between Bernoulli distributions. + """ + def __init__(self): + super(KL, self).__init__() + self.b = nn.Bernoulli(0.7, dtype=dtype.int32) + + @ms_function + def construct(self, x_): + return self.b('kl_loss', 'Bernoulli', x_) + def test_kl_loss(): """ Test kl_loss. @@ -117,31 +93,203 @@ def test_kl_loss(): probs0_a = 1 - probs1_a probs0_b = 1 - probs1_b expect_kl_loss = probs1_a * np.log(probs1_a / probs1_b) + probs0_a * np.log(probs0_a / probs0_b) - kl_loss = Net2() + kl_loss = KL() output = kl_loss(Tensor([probs1_b], dtype=dtype.float32)) tol = 1e-6 assert (np.abs(output.asnumpy() - expect_kl_loss) < tol).all() +class Basics(nn.Cell): + """ + Test class: mean/sd/mode of Bernoulli distribution. + """ + def __init__(self): + super(Basics, self).__init__() + self.b = nn.Bernoulli([0.3, 0.5, 0.7], dtype=dtype.int32) + + @ms_function + def construct(self): + return self.b('mean'), self.b('sd'), self.b('mode') + def test_basics(): """ - Test mean/standard deviation and probs. + Test mean/standard deviation/mode. 
 """ - basics = Net3() - mean, sd = basics() - expect_mean = [0.5, 0.5] - assert (mean.asnumpy() == expect_mean).all() - assert (sd.asnumpy() == expect_mean).all() - b = nn.Bernoulli([0.7, 0.5], dtype=dtype.int32) - probs = b.probs() - expect_probs = [0.7, 0.5] + basics = Basics() + mean, sd, mode = basics() + expect_mean = [0.3, 0.5, 0.7] + expect_sd = np.sqrt(np.multiply([0.7, 0.5, 0.3], [0.3, 0.5, 0.7])) + expect_mode = [0.0, 0.0, 1.0] tol = 1e-6 - assert (np.abs(probs.asnumpy() - expect_probs) < tol).all() + assert (np.abs(mean.asnumpy() - expect_mean) < tol).all() + assert (np.abs(sd.asnumpy() - expect_sd) < tol).all() + assert (np.abs(mode.asnumpy() - expect_mode) < tol).all() + +class Sampling(nn.Cell): + """ + Test class: sample of Bernoulli distribution. + """ + def __init__(self, shape, seed=0): + super(Sampling, self).__init__() + self.b = nn.Bernoulli([0.7, 0.5], seed=seed, dtype=dtype.int32) + self.shape = shape + + @ms_function + def construct(self, probs=None): + return self.b('sample', self.shape, probs) def test_sample(): """ Test sample. """ shape = (2, 3) - sample = Net4(shape) + sample = Sampling(shape) output = sample() assert output.shape == (2, 3, 2) + +class CDF(nn.Cell): + """ + Test class: cdf of bernoulli distributions. + """ + def __init__(self): + super(CDF, self).__init__() + self.b = nn.Bernoulli(0.7, dtype=dtype.int32) + + @ms_function + def construct(self, x_): + return self.b('cdf', x_) + +def test_cdf(): + """ + Test cdf. + """ + bernoulli_benchmark = stats.bernoulli(0.7) + expect_cdf = bernoulli_benchmark.cdf([0, 0, 1, 0, 1]).astype(np.float32) + x_ = Tensor(np.array([0, 0, 1, 0, 1]).astype(np.int32), dtype=dtype.float32) + cdf = CDF() + output = cdf(x_) + tol = 1e-6 + assert (np.abs(output.asnumpy() - expect_cdf) < tol).all() + + +class LogCDF(nn.Cell): + """ + Test class: log cdf of bernoulli distributions. 
+ """ + def __init__(self): + super(LogCDF, self).__init__() + self.b = nn.Bernoulli(0.7, dtype=dtype.int32) + + @ms_function + def construct(self, x_): + return self.b('log_cdf', x_) + +def test_logcdf(): + """ + Test log_cdf. + """ + bernoulli_benchmark = stats.bernoulli(0.7) + expect_logcdf = bernoulli_benchmark.logcdf([0, 0, 1, 0, 1]).astype(np.float32) + x_ = Tensor(np.array([0, 0, 1, 0, 1]).astype(np.int32), dtype=dtype.float32) + logcdf = LogCDF() + output = logcdf(x_) + tol = 1e-6 + assert (np.abs(output.asnumpy() - expect_logcdf) < tol).all() + + +class SF(nn.Cell): + """ + Test class: survival function of Bernoulli distributions. + """ + def __init__(self): + super(SF, self).__init__() + self.b = nn.Bernoulli(0.7, dtype=dtype.int32) + + @ms_function + def construct(self, x_): + return self.b('survival_function', x_) + +def test_survival(): + """ + Test survival function. + """ + bernoulli_benchmark = stats.bernoulli(0.7) + expect_survival = bernoulli_benchmark.sf([0, 1, 1, 0, 0]).astype(np.float32) + x_ = Tensor(np.array([0, 1, 1, 0, 0]).astype(np.int32), dtype=dtype.float32) + sf = SF() + output = sf(x_) + tol = 1e-6 + assert (np.abs(output.asnumpy() - expect_survival) < tol).all() + + +class LogSF(nn.Cell): + """ + Test class: log survival function of Bernoulli distributions. + """ + def __init__(self): + super(LogSF, self).__init__() + self.b = nn.Bernoulli(0.7, dtype=dtype.int32) + + @ms_function + def construct(self, x_): + return self.b('log_survival', x_) + +def test_log_survival(): + """ + Test log survival function. + """ + bernoulli_benchmark = stats.bernoulli(0.7) + expect_logsurvival = bernoulli_benchmark.logsf([-1, 0.9, 0, 0, 0]).astype(np.float32) + x_ = Tensor(np.array([-1, 0.9, 0, 0, 0]).astype(np.float32), dtype=dtype.float32) + log_sf = LogSF() + output = log_sf(x_) + tol = 1e-6 + assert (np.abs(output.asnumpy() - expect_logsurvival) < tol).all() + +class EntropyH(nn.Cell): + """ + Test class: entropy of Bernoulli distributions. 
+ """ + def __init__(self): + super(EntropyH, self).__init__() + self.b = nn.Bernoulli(0.7, dtype=dtype.int32) + + @ms_function + def construct(self): + return self.b('entropy') + +def test_entropy(): + """ + Test entropy. + """ + bernoulli_benchmark = stats.bernoulli(0.7) + expect_entropy = bernoulli_benchmark.entropy().astype(np.float32) + entropy = EntropyH() + output = entropy() + tol = 1e-6 + assert (np.abs(output.asnumpy() - expect_entropy) < tol).all() + +class CrossEntropy(nn.Cell): + """ + Test class: cross entropy between bernoulli distributions. + """ + def __init__(self): + super(CrossEntropy, self).__init__() + self.b = nn.Bernoulli(0.7, dtype=dtype.int32) + + @ms_function + def construct(self, x_): + entropy = self.b('entropy') + kl_loss = self.b('kl_loss', 'Bernoulli', x_) + h_sum_kl = entropy + kl_loss + cross_entropy = self.b('cross_entropy', 'Bernoulli', x_) + return h_sum_kl - cross_entropy + +def test_cross_entropy(): + """ + Test cross_entropy. + """ + cross_entropy = CrossEntropy() + prob = Tensor([0.3], dtype=dtype.float32) + diff = cross_entropy(prob) + tol = 1e-6 + assert (np.abs(diff.asnumpy() - np.zeros(diff.shape)) < tol).all() diff --git a/tests/st/ops/ascend/test_distribution/test_exponential.py b/tests/st/ops/ascend/test_distribution/test_exponential.py new file mode 100644 index 0000000000..823f9b0e1a --- /dev/null +++ b/tests/st/ops/ascend/test_distribution/test_exponential.py @@ -0,0 +1,291 @@ +# Copyright 2019 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""test cases for exponential distribution""" +import numpy as np +from scipy import stats +import mindspore.context as context +import mindspore.nn as nn +from mindspore import Tensor +from mindspore.common.api import ms_function +from mindspore import dtype + +context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") + +class Prob(nn.Cell): + """ + Test class: probability of Exponential distribution. + """ + def __init__(self): + super(Prob, self).__init__() + self.e = nn.Exponential([[1.0], [0.5]], dtype=dtype.float32) + + @ms_function + def construct(self, x_): + return self.e('prob', x_) + +def test_pdf(): + """ + Test pdf. + """ + expon_benchmark = stats.expon(scale=[[1.0], [2.0]]) + expect_pdf = expon_benchmark.pdf([-1.0, 0.0, 1.0]).astype(np.float32) + pdf = Prob() + x_ = Tensor(np.array([-1.0, 0.0, 1.0]).astype(np.float32), dtype=dtype.float32) + output = pdf(x_) + tol = 1e-6 + assert (np.abs(output.asnumpy() - expect_pdf) < tol).all() + +class LogProb(nn.Cell): + """ + Test class: log probability of Exponential distribution. + """ + def __init__(self): + super(LogProb, self).__init__() + self.e = nn.Exponential([[1.0], [0.5]], dtype=dtype.float32) + + @ms_function + def construct(self, x_): + return self.e('log_prob', x_) + +def test_log_likelihood(): + """ + Test log_pdf. + """ + expon_benchmark = stats.expon(scale=[[1.0], [2.0]]) + expect_logpdf = expon_benchmark.logpdf([0.5, 1.0, 2.0]).astype(np.float32) + logprob = LogProb() + x_ = Tensor(np.array([0.5, 1.0, 2.0]).astype(np.float32), dtype=dtype.float32) + output = logprob(x_) + tol = 1e-6 + assert (np.abs(output.asnumpy() - expect_logpdf) < tol).all() + +class KL(nn.Cell): + """ + Test class: kl_loss between Exponential distributions. 
+ """ + def __init__(self): + super(KL, self).__init__() + self.e = nn.Exponential([1.5], dtype=dtype.float32) + + @ms_function + def construct(self, x_): + return self.e('kl_loss', 'Exponential', x_) + +def test_kl_loss(): + """ + Test kl_loss. + """ + rate_a = 1.5 + rate_b = np.array([0.5, 2.0]).astype(np.float32) + expect_kl_loss = np.log(rate_a) - np.log(rate_b) + rate_b / rate_a - 1.0 + kl = KL() + output = kl(Tensor(rate_b, dtype=dtype.float32)) + tol = 1e-6 + assert (np.abs(output.asnumpy() - expect_kl_loss) < tol).all() + +class Basics(nn.Cell): + """ + Test class: mean/sd/mode of Exponential distribution. + """ + def __init__(self): + super(Basics, self).__init__() + self.e = nn.Exponential([0.5], dtype=dtype.float32) + + @ms_function + def construct(self): + return self.e('mean'), self.e('sd'), self.e('mode') + +def test_basics(): + """ + Test mean/standard deviation/mode. + """ + basics = Basics() + mean, sd, mode = basics() + expect_mean = 2. + expect_sd = 2. + expect_mode = 0. + tol = 1e-6 + assert (np.abs(mean.asnumpy() - expect_mean) < tol).all() + assert (np.abs(sd.asnumpy() - expect_sd) < tol).all() + assert (np.abs(mode.asnumpy() - expect_mode) < tol).all() + +class Sampling(nn.Cell): + """ + Test class: sample of Exponential distribution. + """ + def __init__(self, shape, seed=0): + super(Sampling, self).__init__() + self.e = nn.Exponential([[1.0], [0.5]], seed=seed, dtype=dtype.float32) + self.shape = shape + + @ms_function + def construct(self, rate=None): + return self.e('sample', self.shape, rate) + +def test_sample(): + """ + Test sample. + """ + shape = (2, 3) + seed = 10 + rate = Tensor([1.0, 2.0, 3.0], dtype=dtype.float32) + sample = Sampling(shape, seed=seed) + output = sample(rate) + assert output.shape == (2, 3, 3) + +class CDF(nn.Cell): + """ + Test class: cdf of Exponential distribution. 
+ """ + def __init__(self): + super(CDF, self).__init__() + self.e = nn.Exponential([[1.0], [0.5]], dtype=dtype.float32) + + @ms_function + def construct(self, x_): + return self.e('cdf', x_) + +def test_cdf(): + """ + Test cdf. + """ + expon_benchmark = stats.expon(scale=[[1.0], [2.0]]) + expect_cdf = expon_benchmark.cdf([-1.0, 0.0, 1.0]).astype(np.float32) + cdf = CDF() + x_ = Tensor(np.array([-1.0, 0.0, 1.0]).astype(np.float32), dtype=dtype.float32) + output = cdf(x_) + tol = 1e-6 + assert (np.abs(output.asnumpy() - expect_cdf) < tol).all() + +class LogCDF(nn.Cell): + """ + Test class: log_cdf of Exponential distribution. + """ + def __init__(self): + super(LogCDF, self).__init__() + self.e = nn.Exponential([[1.0], [0.5]], dtype=dtype.float32) + + @ms_function + def construct(self, x_): + return self.e('log_cdf', x_) + +def test_log_cdf(): + """ + Test log_cdf. + """ + expon_benchmark = stats.expon(scale=[[1.0], [2.0]]) + expect_logcdf = expon_benchmark.logcdf([0.5, 1.0, 2.5]).astype(np.float32) + logcdf = LogCDF() + x_ = Tensor(np.array([0.5, 1.0, 2.5]).astype(np.float32), dtype=dtype.float32) + output = logcdf(x_) + tol = 1e-6 + assert (np.abs(output.asnumpy() - expect_logcdf) < tol).all() + +class SF(nn.Cell): + """ + Test class: survival function of Exponential distribution. + """ + def __init__(self): + super(SF, self).__init__() + self.e = nn.Exponential([[1.0], [0.5]], dtype=dtype.float32) + + @ms_function + def construct(self, x_): + return self.e('survival_function', x_) + +def test_survival(): + """ + Test survival function. + """ + expon_benchmark = stats.expon(scale=[[1.0], [2.0]]) + expect_survival = expon_benchmark.sf([-1.0, 0.0, 1.0]).astype(np.float32) + survival = SF() + x_ = Tensor(np.array([-1.0, 0.0, 1.0]).astype(np.float32), dtype=dtype.float32) + output = survival(x_) + tol = 1e-6 + assert (np.abs(output.asnumpy() - expect_survival) < tol).all() + +class LogSF(nn.Cell): + """ + Test class: log survival function of Exponential distribution. 
+ """ + def __init__(self): + super(LogSF, self).__init__() + self.e = nn.Exponential([[1.0], [0.5]], dtype=dtype.float32) + + @ms_function + def construct(self, x_): + return self.e('log_survival', x_) + +def test_log_survival(): + """ + Test log survival function. + """ + expon_benchmark = stats.expon(scale=[[1.0], [2.0]]) + expect_logsurvival = expon_benchmark.logsf([-1.0, 0.0, 1.0]).astype(np.float32) + logsurvival = LogSF() + x_ = Tensor(np.array([-1.0, 0.0, 1.0]).astype(np.float32), dtype=dtype.float32) + output = logsurvival(x_) + tol = 1e-6 + assert (np.abs(output.asnumpy() - expect_logsurvival) < tol).all() + +class EntropyH(nn.Cell): + """ + Test class: entropy of Exponential distribution. + """ + def __init__(self): + super(EntropyH, self).__init__() + self.e = nn.Exponential([[1.0], [0.5]], dtype=dtype.float32) + + @ms_function + def construct(self): + return self.e('entropy') + +def test_entropy(): + """ + Test entropy. + """ + expon_benchmark = stats.expon(scale=[[1.0], [2.0]]) + expect_entropy = expon_benchmark.entropy().astype(np.float32) + entropy = EntropyH() + output = entropy() + tol = 1e-6 + assert (np.abs(output.asnumpy() - expect_entropy) < tol).all() + +class CrossEntropy(nn.Cell): + """ + Test class: cross entropy between Exponential distribution. + """ + def __init__(self): + super(CrossEntropy, self).__init__() + self.e = nn.Exponential([1.0], dtype=dtype.float32) + + @ms_function + def construct(self, x_): + entropy = self.e('entropy') + kl_loss = self.e('kl_loss', 'Exponential', x_) + h_sum_kl = entropy + kl_loss + cross_entropy = self.e('cross_entropy', 'Exponential', x_) + return h_sum_kl - cross_entropy + +def test_cross_entropy(): + """ + Test cross_entropy. 
+ """ + cross_entropy = CrossEntropy() + rate = Tensor([0.5], dtype=dtype.float32) + diff = cross_entropy(rate) + tol = 1e-6 + assert (np.abs(diff.asnumpy() - np.zeros(diff.shape)) < tol).all() diff --git a/tests/st/ops/ascend/test_distribution/test_geometric.py b/tests/st/ops/ascend/test_distribution/test_geometric.py new file mode 100644 index 0000000000..b3b9995bcb --- /dev/null +++ b/tests/st/ops/ascend/test_distribution/test_geometric.py @@ -0,0 +1,291 @@ +# Copyright 2019 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""test cases for Geometric distribution""" +import numpy as np +from scipy import stats +import mindspore.context as context +import mindspore.nn as nn +from mindspore import Tensor +from mindspore.common.api import ms_function +from mindspore import dtype + +context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") + +class Prob(nn.Cell): + """ + Test class: probability of Geometric distribution. + """ + def __init__(self): + super(Prob, self).__init__() + self.g = nn.Geometric(0.7, dtype=dtype.int32) + + @ms_function + def construct(self, x_): + return self.g('prob', x_) + +def test_pmf(): + """ + Test pmf. 
+ """ + geom_benchmark = stats.geom(0.7) + expect_pmf = geom_benchmark.pmf([0, 1, 2, 3, 4]).astype(np.float32) + pdf = Prob() + x_ = Tensor(np.array([-1, 0, 1, 2, 3]).astype(np.float32), dtype=dtype.float32) + output = pdf(x_) + tol = 1e-6 + assert (np.abs(output.asnumpy() - expect_pmf) < tol).all() + +class LogProb(nn.Cell): + """ + Test class: log probability of Geometric distribution. + """ + def __init__(self): + super(LogProb, self).__init__() + self.g = nn.Geometric(0.7, dtype=dtype.int32) + + @ms_function + def construct(self, x_): + return self.g('log_prob', x_) + +def test_log_likelihood(): + """ + Test log_pmf. + """ + geom_benchmark = stats.geom(0.7) + expect_logpmf = geom_benchmark.logpmf([1, 2, 3, 4, 5]).astype(np.float32) + logprob = LogProb() + x_ = Tensor(np.array([0, 1, 2, 3, 4]).astype(np.int32), dtype=dtype.float32) + output = logprob(x_) + tol = 1e-6 + assert (np.abs(output.asnumpy() - expect_logpmf) < tol).all() + +class KL(nn.Cell): + """ + Test class: kl_loss between Geometric distributions. + """ + def __init__(self): + super(KL, self).__init__() + self.g = nn.Geometric(0.7, dtype=dtype.int32) + + @ms_function + def construct(self, x_): + return self.g('kl_loss', 'Geometric', x_) + +def test_kl_loss(): + """ + Test kl_loss. + """ + probs1_a = 0.7 + probs1_b = 0.5 + probs0_a = 1 - probs1_a + probs0_b = 1 - probs1_b + expect_kl_loss = np.log(probs1_a / probs1_b) + (probs0_a / probs1_a) * np.log(probs0_a / probs0_b) + kl_loss = KL() + output = kl_loss(Tensor([probs1_b], dtype=dtype.float32)) + tol = 1e-6 + assert (np.abs(output.asnumpy() - expect_kl_loss) < tol).all() + +class Basics(nn.Cell): + """ + Test class: mean/sd/mode of Geometric distribution. + """ + def __init__(self): + super(Basics, self).__init__() + self.g = nn.Geometric([0.5, 0.5], dtype=dtype.int32) + + @ms_function + def construct(self): + return self.g('mean'), self.g('sd'), self.g('mode') + +def test_basics(): + """ + Test mean/standard deviation/mode. 
+ """ + basics = Basics() + mean, sd, mode = basics() + expect_mean = [1.0, 1.0] + expect_sd = np.sqrt(np.array([0.5, 0.5]) / np.square(np.array([0.5, 0.5]))) + expect_mode = [0.0, 0.0] + tol = 1e-6 + assert (np.abs(mean.asnumpy()- expect_mean) < tol).all() + assert (np.abs(sd.asnumpy() - expect_sd) < tol).all() + assert (np.abs(mode.asnumpy() - expect_mode) < tol).all() + +class Sampling(nn.Cell): + """ + Test class: sample of Geometric distribution. + """ + def __init__(self, shape, seed=0): + super(Sampling, self).__init__() + self.g = nn.Geometric([0.7, 0.5], seed=seed, dtype=dtype.int32) + self.shape = shape + + @ms_function + def construct(self, probs=None): + return self.g('sample', self.shape, probs) + +def test_sample(): + """ + Test sample. + """ + shape = (2, 3) + sample = Sampling(shape) + output = sample() + assert output.shape == (2, 3, 2) + +class CDF(nn.Cell): + """ + Test class: cdf of Geometric distribution. + """ + def __init__(self): + super(CDF, self).__init__() + self.g = nn.Geometric(0.7, dtype=dtype.int32) + + @ms_function + def construct(self, x_): + return self.g('cdf', x_) + +def test_cdf(): + """ + Test cdf. + """ + geom_benchmark = stats.geom(0.7) + expect_cdf = geom_benchmark.cdf([0, 1, 2, 3, 4]).astype(np.float32) + x_ = Tensor(np.array([-1, 0, 1, 2, 3]).astype(np.int32), dtype=dtype.float32) + cdf = CDF() + output = cdf(x_) + tol = 1e-6 + assert (np.abs(output.asnumpy() - expect_cdf) < tol).all() + +class LogCDF(nn.Cell): + """ + Test class: log cdf of Geometric distribution. + """ + def __init__(self): + super(LogCDF, self).__init__() + self.g = nn.Geometric(0.7, dtype=dtype.int32) + + @ms_function + def construct(self, x_): + return self.g('log_cdf', x_) + +def test_logcdf(): + """ + Test log_cdf. 
+ """ + geom_benchmark = stats.geom(0.7) + expect_logcdf = geom_benchmark.logcdf([1, 2, 3, 4, 5]).astype(np.float32) + x_ = Tensor(np.array([0, 1, 2, 3, 4]).astype(np.int32), dtype=dtype.float32) + logcdf = LogCDF() + output = logcdf(x_) + tol = 1e-6 + assert (np.abs(output.asnumpy() - expect_logcdf) < tol).all() + +class SF(nn.Cell): + """ + Test class: survival function of Geometric distribution. + """ + def __init__(self): + super(SF, self).__init__() + self.g = nn.Geometric(0.7, dtype=dtype.int32) + + @ms_function + def construct(self, x_): + return self.g('survival_function', x_) + +def test_survival(): + """ + Test survival function. + """ + geom_benchmark = stats.geom(0.7) + expect_survival = geom_benchmark.sf([0, 1, 2, 3, 4]).astype(np.float32) + x_ = Tensor(np.array([-1, 0, 1, 2, 3]).astype(np.int32), dtype=dtype.float32) + sf = SF() + output = sf(x_) + tol = 1e-6 + assert (np.abs(output.asnumpy() - expect_survival) < tol).all() + +class LogSF(nn.Cell): + """ + Test class: log survival function of Geometric distribution. + """ + def __init__(self): + super(LogSF, self).__init__() + self.g = nn.Geometric(0.7, dtype=dtype.int32) + + @ms_function + def construct(self, x_): + return self.g('log_survival', x_) + +def test_log_survival(): + """ + Test log_survival function. + """ + geom_benchmark = stats.geom(0.7) + expect_logsurvival = geom_benchmark.logsf([0, 1, 2, 3, 4]).astype(np.float32) + x_ = Tensor(np.array([-1, 0, 1, 2, 3]).astype(np.float32), dtype=dtype.float32) + log_sf = LogSF() + output = log_sf(x_) + tol = 5e-6 + assert (np.abs(output.asnumpy() - expect_logsurvival) < tol).all() + +class EntropyH(nn.Cell): + """ + Test class: entropy of Geometric distribution. + """ + def __init__(self): + super(EntropyH, self).__init__() + self.g = nn.Geometric(0.7, dtype=dtype.int32) + + @ms_function + def construct(self): + return self.g('entropy') + +def test_entropy(): + """ + Test entropy. 
+ """ + geom_benchmark = stats.geom(0.7) + expect_entropy = geom_benchmark.entropy().astype(np.float32) + entropy = EntropyH() + output = entropy() + tol = 1e-6 + assert (np.abs(output.asnumpy() - expect_entropy) < tol).all() + +class CrossEntropy(nn.Cell): + """ + Test class: cross entropy between Geometric distributions. + """ + def __init__(self): + super(CrossEntropy, self).__init__() + self.g = nn.Geometric(0.7, dtype=dtype.int32) + + @ms_function + def construct(self, x_): + entropy = self.g('entropy') + kl_loss = self.g('kl_loss', 'Geometric', x_) + h_sum_kl = entropy + kl_loss + ans = self.g('cross_entropy', 'Geometric', x_) + return h_sum_kl - ans + +def test_cross_entropy(): + """ + Test cross_entropy. + """ + cross_entropy = CrossEntropy() + prob = Tensor([0.5], dtype=dtype.float32) + diff = cross_entropy(prob) + tol = 1e-6 + assert (np.abs(diff.asnumpy() - np.zeros(diff.shape)) < tol).all() diff --git a/tests/st/ops/ascend/test_distribution/test_normal.py b/tests/st/ops/ascend/test_distribution/test_normal.py index 52bb1173ee..d3a93c244c 100644 --- a/tests/st/ops/ascend/test_distribution/test_normal.py +++ b/tests/st/ops/ascend/test_distribution/test_normal.py @@ -23,89 +23,66 @@ from mindspore import dtype context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") -class Net(nn.Cell): +class Prob(nn.Cell): """ - Test class: probability of normal distribution. + Test class: probability of Normal distribution. """ def __init__(self): - super(Net, self).__init__() + super(Prob, self).__init__() self.n = nn.Normal(np.array([3.0]), np.array([[2.0], [4.0]]), dtype=dtype.float32) @ms_function def construct(self, x_): return self.n('prob', x_) -class Net1(nn.Cell): - """ - Test class: log probability of normal distribution. 
- """ - def __init__(self): - super(Net1, self).__init__() - self.n = nn.Normal(np.array([3.0]), np.array([[2.0], [4.0]]), dtype=dtype.float32) - - @ms_function - def construct(self, x_): - return self.n('log_prob', x_) - -class Net2(nn.Cell): - """ - Test class: kl_loss of normal distribution. - """ - def __init__(self): - super(Net2, self).__init__() - self.n = nn.Normal(np.array([3.0]), np.array([4.0]), dtype=dtype.float32) - - @ms_function - def construct(self, x_, y_): - return self.n('kl_loss', 'Normal', x_, y_) - -class Net3(nn.Cell): - """ - Test class: mean/sd of normal distribution. - """ - def __init__(self): - super(Net3, self).__init__() - self.n = nn.Normal(np.array([3.0]), np.array([2.0, 4.0]), dtype=dtype.float32) - - @ms_function - def construct(self): - return self.n('mean'), self.n('sd') - -class Net4(nn.Cell): - """ - Test class: mean/sd of normal distribution. - """ - def __init__(self, shape, seed=0): - super(Net4, self).__init__() - self.n = nn.Normal(np.array([3.0]), np.array([[2.0], [4.0]]), seed=seed, dtype=dtype.float32) - self.shape = shape - - @ms_function - def construct(self, mean=None, sd=None): - return self.n('sample', self.shape, mean, sd) - def test_pdf(): """ Test pdf. """ norm_benchmark = stats.norm(np.array([3.0]), np.array([[2.0], [4.0]])) expect_pdf = norm_benchmark.pdf([1.0, 2.0]).astype(np.float32) - pdf = Net() + pdf = Prob() output = pdf(Tensor([1.0, 2.0], dtype=dtype.float32)) tol = 1e-6 assert (np.abs(output.asnumpy() - expect_pdf) < tol).all() +class LogProb(nn.Cell): + """ + Test class: log probability of Normal distribution. + """ + def __init__(self): + super(LogProb, self).__init__() + self.n = nn.Normal(np.array([3.0]), np.array([[2.0], [4.0]]), dtype=dtype.float32) + + @ms_function + def construct(self, x_): + return self.n('log_prob', x_) + def test_log_likelihood(): """ Test log_pdf. 
""" norm_benchmark = stats.norm(np.array([3.0]), np.array([[2.0], [4.0]])) expect_logpdf = norm_benchmark.logpdf([1.0, 2.0]).astype(np.float32) - logprob = Net1() + logprob = LogProb() output = logprob(Tensor([1.0, 2.0], dtype=dtype.float32)) tol = 1e-6 assert (np.abs(output.asnumpy() - expect_logpdf) < tol).all() + +class KL(nn.Cell): + """ + Test class: kl_loss of Normal distribution. + """ + def __init__(self): + super(KL, self).__init__() + self.n = nn.Normal(np.array([3.0]), np.array([4.0]), dtype=dtype.float32) + + @ms_function + def construct(self, x_, y_): + return self.n('kl_loss', 'Normal', x_, y_) + + def test_kl_loss(): """ Test kl_loss. @@ -120,25 +97,51 @@ def test_kl_loss(): squared_diff = np.square(mean_a / sd_b - mean_b / sd_b) expect_kl_loss = 0.5 * squared_diff + 0.5 * np.expm1(2 * diff_log_scale) - diff_log_scale - kl_loss = Net2() + kl_loss = KL() mean = Tensor(mean_b, dtype=dtype.float32) sd = Tensor(sd_b, dtype=dtype.float32) output = kl_loss(mean, sd) tol = 1e-6 assert (np.abs(output.asnumpy() - expect_kl_loss) < tol).all() +class Basics(nn.Cell): + """ + Test class: mean/sd/mode of Normal distribution. + """ + def __init__(self): + super(Basics, self).__init__() + self.n = nn.Normal(np.array([3.0]), np.array([2.0, 4.0]), dtype=dtype.float32) + + @ms_function + def construct(self): + return self.n('mean'), self.n('sd'), self.n('mode') + def test_basics(): """ - Test mean/standard deviation. + Test mean/standard deviation/mode. """ - basics = Net3() - mean, sd = basics() + basics = Basics() + mean, sd, mode = basics() expect_mean = [3.0, 3.0] expect_sd = [2.0, 4.0] tol = 1e-6 assert (np.abs(mean.asnumpy() - expect_mean) < tol).all() + assert (np.abs(mode.asnumpy() - expect_mean) < tol).all() assert (np.abs(sd.asnumpy() - expect_sd) < tol).all() +class Sampling(nn.Cell): + """ + Test class: sample of Normal distribution. 
+ """ + def __init__(self, shape, seed=0): + super(Sampling, self).__init__() + self.n = nn.Normal(np.array([3.0]), np.array([[2.0], [4.0]]), seed=seed, dtype=dtype.float32) + self.shape = shape + + @ms_function + def construct(self, mean=None, sd=None): + return self.n('sample', self.shape, mean, sd) + def test_sample(): """ Test sample. @@ -147,6 +150,149 @@ def test_sample(): seed = 10 mean = Tensor([2.0], dtype=dtype.float32) sd = Tensor([2.0, 2.0, 2.0], dtype=dtype.float32) - sample = Net4(shape, seed=seed) + sample = Sampling(shape, seed=seed) output = sample(mean, sd) assert output.shape == (2, 3, 3) + +class CDF(nn.Cell): + """ + Test class: cdf of Normal distribution. + """ + def __init__(self): + super(CDF, self).__init__() + self.n = nn.Normal(np.array([3.0]), np.array([[2.0], [4.0]]), dtype=dtype.float32) + + @ms_function + def construct(self, x_): + return self.n('cdf', x_) + + +def test_cdf(): + """ + Test cdf. + """ + norm_benchmark = stats.norm(np.array([3.0]), np.array([[2.0], [4.0]])) + expect_cdf = norm_benchmark.cdf([1.0, 2.0]).astype(np.float32) + cdf = CDF() + output = cdf(Tensor([1.0, 2.0], dtype=dtype.float32)) + tol = 2e-5 + assert (np.abs(output.asnumpy() - expect_cdf) < tol).all() + +class LogCDF(nn.Cell): + """ + Test class: log_cdf of Normal distribution. + """ + def __init__(self): + super(LogCDF, self).__init__() + self.n = nn.Normal(np.array([3.0]), np.array([[2.0], [4.0]]), dtype=dtype.float32) + + @ms_function + def construct(self, x_): + return self.n('log_cdf', x_) + +def test_log_cdf(): + """ + Test log cdf. + """ + norm_benchmark = stats.norm(np.array([3.0]), np.array([[2.0], [4.0]])) + expect_logcdf = norm_benchmark.logcdf([1.0, 2.0]).astype(np.float32) + logcdf = LogCDF() + output = logcdf(Tensor([1.0, 2.0], dtype=dtype.float32)) + tol = 5e-5 + assert (np.abs(output.asnumpy() - expect_logcdf) < tol).all() + +class SF(nn.Cell): + """ + Test class: survival function of Normal distribution. 
+ """ + def __init__(self): + super(SF, self).__init__() + self.n = nn.Normal(np.array([3.0]), np.array([[2.0], [4.0]]), dtype=dtype.float32) + + @ms_function + def construct(self, x_): + return self.n('survival_function', x_) + +def test_survival(): + """ + Test survival function. + """ + norm_benchmark = stats.norm(np.array([3.0]), np.array([[2.0], [4.0]])) + expect_survival = norm_benchmark.sf([1.0, 2.0]).astype(np.float32) + survival_function = SF() + output = survival_function(Tensor([1.0, 2.0], dtype=dtype.float32)) + tol = 2e-5 + assert (np.abs(output.asnumpy() - expect_survival) < tol).all() + +class LogSF(nn.Cell): + """ + Test class: log survival function of Normal distribution. + """ + def __init__(self): + super(LogSF, self).__init__() + self.n = nn.Normal(np.array([3.0]), np.array([[2.0], [4.0]]), dtype=dtype.float32) + + @ms_function + def construct(self, x_): + return self.n('log_survival', x_) + +def test_log_survival(): + """ + Test log_survival. + """ + norm_benchmark = stats.norm(np.array([3.0]), np.array([[2.0], [4.0]])) + expect_log_survival = norm_benchmark.logsf([1.0, 2.0]).astype(np.float32) + log_survival = LogSF() + output = log_survival(Tensor([1.0, 2.0], dtype=dtype.float32)) + tol = 2e-5 + assert (np.abs(output.asnumpy() - expect_log_survival) < tol).all() + +class EntropyH(nn.Cell): + """ + Test class: entropy of Normal distribution. + """ + def __init__(self): + super(EntropyH, self).__init__() + self.n = nn.Normal(np.array([3.0]), np.array([[2.0], [4.0]]), dtype=dtype.float32) + + @ms_function + def construct(self): + return self.n('entropy') + +def test_entropy(): + """ + Test entropy.
+ """ + norm_benchmark = stats.norm(np.array([3.0]), np.array([[2.0], [4.0]])) + expect_entropy = norm_benchmark.entropy().astype(np.float32) + entropy = EntropyH() + output = entropy() + tol = 1e-6 + assert (np.abs(output.asnumpy() - expect_entropy) < tol).all() + +class CrossEntropy(nn.Cell): + """ + Test class: cross entropy between Normal distributions. + """ + def __init__(self): + super(CrossEntropy, self).__init__() + self.n = nn.Normal(np.array([3.0]), np.array([4.0]), dtype=dtype.float32) + + @ms_function + def construct(self, x_, y_): + entropy = self.n('entropy') + kl_loss = self.n('kl_loss', 'Normal', x_, y_) + h_sum_kl = entropy + kl_loss + cross_entropy = self.n('cross_entropy', 'Normal', x_, y_) + return h_sum_kl - cross_entropy + +def test_cross_entropy(): + """ + Test cross_entropy. + """ + cross_entropy = CrossEntropy() + mean = Tensor([1.0], dtype=dtype.float32) + sd = Tensor([1.0], dtype=dtype.float32) + diff = cross_entropy(mean, sd) + tol = 1e-6 + assert (np.abs(diff.asnumpy() - np.zeros(diff.shape)) < tol).all() diff --git a/tests/st/ops/ascend/test_distribution/test_uniform.py b/tests/st/ops/ascend/test_distribution/test_uniform.py new file mode 100644 index 0000000000..bfcf9b7235 --- /dev/null +++ b/tests/st/ops/ascend/test_distribution/test_uniform.py @@ -0,0 +1,293 @@ +# Copyright 2019 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ +"""test cases for uniform distribution""" +import numpy as np +from scipy import stats +import mindspore.context as context +import mindspore.nn as nn +from mindspore import Tensor +from mindspore.common.api import ms_function +from mindspore import dtype + +context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") + +class Prob(nn.Cell): + """ + Test class: probability of Uniform distribution. + """ + def __init__(self): + super(Prob, self).__init__() + self.u = nn.Uniform([0.0], [[1.0], [2.0]], dtype=dtype.float32) + + @ms_function + def construct(self, x_): + return self.u('prob', x_) + +def test_pdf(): + """ + Test pdf. + """ + uniform_benchmark = stats.uniform([0.0], [[1.0], [2.0]]) + expect_pdf = uniform_benchmark.pdf([-1.0, 0.0, 0.5, 1.0, 1.5, 3.0]).astype(np.float32) + pdf = Prob() + x_ = Tensor(np.array([-1.0, 0.0, 0.5, 1.0, 1.5, 3.0]).astype(np.float32), dtype=dtype.float32) + output = pdf(x_) + tol = 1e-6 + assert (np.abs(output.asnumpy() - expect_pdf) < tol).all() + +class LogProb(nn.Cell): + """ + Test class: log probability of Uniform distribution. + """ + def __init__(self): + super(LogProb, self).__init__() + self.u = nn.Uniform([0.0], [[1.0], [2.0]], dtype=dtype.float32) + + @ms_function + def construct(self, x_): + return self.u('log_prob', x_) + +def test_log_likelihood(): + """ + Test log_pdf. + """ + uniform_benchmark = stats.uniform([0.0], [[1.0], [2.0]]) + expect_logpdf = uniform_benchmark.logpdf([0.5]).astype(np.float32) + logprob = LogProb() + x_ = Tensor(np.array([0.5]).astype(np.float32), dtype=dtype.float32) + output = logprob(x_) + tol = 1e-6 + assert (np.abs(output.asnumpy() - expect_logpdf) < tol).all() + +class KL(nn.Cell): + """ + Test class: kl_loss between Uniform distributions. 
+ """ + def __init__(self): + super(KL, self).__init__() + self.u = nn.Uniform([0.0], [1.5], dtype=dtype.float32) + + @ms_function + def construct(self, x_, y_): + return self.u('kl_loss', 'Uniform', x_, y_) + +def test_kl_loss(): + """ + Test kl_loss. + """ + low_a = 0.0 + high_a = 1.5 + low_b = -1.0 + high_b = 2.0 + expect_kl_loss = np.log(high_b - low_b) / np.log(high_a - low_a) + kl = KL() + output = kl(Tensor(low_b, dtype=dtype.float32), Tensor(high_b, dtype=dtype.float32)) + tol = 1e-6 + assert (np.abs(output.asnumpy() - expect_kl_loss) < tol).all() + +class Basics(nn.Cell): + """ + Test class: mean/sd of Uniform distribution. + """ + def __init__(self): + super(Basics, self).__init__() + self.u = nn.Uniform([0.0], [3.0], dtype=dtype.float32) + + @ms_function + def construct(self): + return self.u('mean'), self.u('sd') + +def test_basics(): + """ + Test mean/standard deviation. + """ + basics = Basics() + mean, sd = basics() + expect_mean = [1.5] + expect_sd = np.sqrt([0.75]) + tol = 1e-6 + assert (np.abs(mean.asnumpy() - expect_mean) < tol).all() + assert (np.abs(sd.asnumpy() - expect_sd) < tol).all() + +class Sampling(nn.Cell): + """ + Test class: sample of Uniform distribution. + """ + def __init__(self, shape, seed=0): + super(Sampling, self).__init__() + self.u = nn.Uniform([0.0], [[1.0], [2.0]], seed=seed, dtype=dtype.float32) + self.shape = shape + + @ms_function + def construct(self, low=None, high=None): + return self.u('sample', self.shape, low, high) + +def test_sample(): + """ + Test sample. + """ + shape = (2, 3) + seed = 10 + low = Tensor([1.0], dtype=dtype.float32) + high = Tensor([2.0, 3.0, 4.0], dtype=dtype.float32) + sample = Sampling(shape, seed=seed) + output = sample(low, high) + assert output.shape == (2, 3, 3) + +class CDF(nn.Cell): + """ + Test class: cdf of Uniform distribution. 
+ """ + def __init__(self): + super(CDF, self).__init__() + self.u = nn.Uniform([0.0], [1.0], dtype=dtype.float32) + + @ms_function + def construct(self, x_): + return self.u('cdf', x_) + +def test_cdf(): + """ + Test cdf. + """ + uniform_benchmark = stats.uniform([0.0], [1.0]) + expect_cdf = uniform_benchmark.cdf([-1.0, 0.5, 1.0, 2.0]).astype(np.float32) + cdf = CDF() + x_ = Tensor(np.array([-1.0, 0.5, 1.0, 2.0]).astype(np.float32), dtype=dtype.float32) + output = cdf(x_) + tol = 1e-6 + assert (np.abs(output.asnumpy() - expect_cdf) < tol).all() + +class LogCDF(nn.Cell): + """ + Test class: log_cdf of Uniform distribution. + """ + def __init__(self): + super(LogCDF, self).__init__() + self.u = nn.Uniform([0.0], [1.0], dtype=dtype.float32) + + @ms_function + def construct(self, x_): + return self.u('log_cdf', x_) + +class SF(nn.Cell): + """ + Test class: survival function of Uniform distribution. + """ + def __init__(self): + super(SF, self).__init__() + self.u = nn.Uniform([0.0], [1.0], dtype=dtype.float32) + + @ms_function + def construct(self, x_): + return self.u('survival_function', x_) + +class LogSF(nn.Cell): + """ + Test class: log survival function of Uniform distribution. + """ + def __init__(self): + super(LogSF, self).__init__() + self.u = nn.Uniform([0.0], [1.0], dtype=dtype.float32) + + @ms_function + def construct(self, x_): + return self.u('log_survival', x_) + +class EntropyH(nn.Cell): + """ + Test class: entropy of Uniform distribution. + """ + def __init__(self): + super(EntropyH, self).__init__() + self.u = nn.Uniform([0.0], [1.0, 2.0], dtype=dtype.float32) + + @ms_function + def construct(self): + return self.u('entropy') + +def test_entropy(): + """ + Test entropy. 
+ """ + uniform_benchmark = stats.uniform([0.0], [1.0, 2.0]) + expect_entropy = uniform_benchmark.entropy().astype(np.float32) + entropy = EntropyH() + output = entropy() + tol = 1e-6 + assert (np.abs(output.asnumpy() - expect_entropy) < tol).all() + +class CrossEntropy(nn.Cell): + """ + Test class: cross_entropy between Uniform distributions. + """ + def __init__(self): + super(CrossEntropy, self).__init__() + self.u = nn.Uniform([0.0], [1.5], dtype=dtype.float32) + + @ms_function + def construct(self, x_, y_): + entropy = self.u('entropy') + kl_loss = self.u('kl_loss', 'Uniform', x_, y_) + h_sum_kl = entropy + kl_loss + cross_entropy = self.u('cross_entropy', 'Uniform', x_, y_) + return h_sum_kl - cross_entropy + +def test_log_cdf(): + """ + Test log_cdf. + """ + uniform_benchmark = stats.uniform([0.0], [1.0]) + expect_logcdf = uniform_benchmark.logcdf([0.5, 0.8, 2.0]).astype(np.float32) + logcdf = LogCDF() + x_ = Tensor(np.array([0.5, 0.8, 2.0]).astype(np.float32), dtype=dtype.float32) + output = logcdf(x_) + tol = 1e-6 + assert (np.abs(output.asnumpy() - expect_logcdf) < tol).all() + +def test_survival(): + """ + Test survival function. + """ + uniform_benchmark = stats.uniform([0.0], [1.0]) + expect_survival = uniform_benchmark.sf([-1.0, 0.5, 1.0, 2.0]).astype(np.float32) + survival = SF() + x_ = Tensor(np.array([-1.0, 0.5, 1.0, 2.0]).astype(np.float32), dtype=dtype.float32) + output = survival(x_) + tol = 1e-6 + assert (np.abs(output.asnumpy() - expect_survival) < tol).all() + +def test_log_survival(): + """ + Test log survival function. + """ + uniform_benchmark = stats.uniform([0.0], [1.0]) + expect_logsurvival = uniform_benchmark.logsf([0.5, 0.8, -2.0]).astype(np.float32) + logsurvival = LogSF() + x_ = Tensor(np.array([0.5, 0.8, -2.0]).astype(np.float32), dtype=dtype.float32) + output = logsurvival(x_) + tol = 1e-6 + assert (np.abs(output.asnumpy() - expect_logsurvival) < tol).all() + +def test_cross_entropy(): + """ + Test cross_entropy. 
+ """ + cross_entropy = CrossEntropy() + low_b = -1.0 + high_b = 2.0 + diff = cross_entropy(Tensor(low_b, dtype=dtype.float32), Tensor(high_b, dtype=dtype.float32)) + tol = 1e-6 + assert (np.abs(diff.asnumpy() - np.zeros(diff.shape)) < tol).all() diff --git a/tests/ut/python/nn/distribution/test_bernoulli.py b/tests/ut/python/nn/distribution/test_bernoulli.py new file mode 100644 index 0000000000..9233e2d395 --- /dev/null +++ b/tests/ut/python/nn/distribution/test_bernoulli.py @@ -0,0 +1,165 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" +Test nn.Distribution.Bernoulli. +""" +import pytest + +import mindspore.nn as nn +from mindspore import dtype +from mindspore import Tensor + +def test_arguments(): + """ + Args passing during initialization. + """ + b = nn.Bernoulli() + assert isinstance(b, nn.Distribution) + b = nn.Bernoulli([0.0, 0.3, 0.5, 1.0], dtype=dtype.int32) + assert isinstance(b, nn.Distribution) + +def test_prob(): + """ + Invalid probability. + """ + with pytest.raises(ValueError): + nn.Bernoulli([-0.1], dtype=dtype.int32) + with pytest.raises(ValueError): + nn.Bernoulli([1.1], dtype=dtype.int32) + +class BernoulliProb(nn.Cell): + """ + Bernoulli distribution: initialize with probs. 
+ """ + def __init__(self): + super(BernoulliProb, self).__init__() + self.b = nn.Bernoulli(0.5, dtype=dtype.int32) + + def construct(self, value): + prob = self.b('prob', value) + log_prob = self.b('log_prob', value) + cdf = self.b('cdf', value) + log_cdf = self.b('log_cdf', value) + sf = self.b('survival_function', value) + log_sf = self.b('log_survival', value) + return prob + log_prob + cdf + log_cdf + sf + log_sf + +def test_bernoulli_prob(): + """ + Test probability functions: passing value through construct. + """ + net = BernoulliProb() + value = Tensor([0, 0, 0, 0, 0], dtype=dtype.float32) + ans = net(value) + assert isinstance(ans, Tensor) + +class BernoulliProb1(nn.Cell): + """ + Bernoulli distribution: initialize without probs. + """ + def __init__(self): + super(BernoulliProb1, self).__init__() + self.b = nn.Bernoulli(dtype=dtype.int32) + + def construct(self, value, probs): + prob = self.b('prob', value, probs) + log_prob = self.b('log_prob', value, probs) + cdf = self.b('cdf', value, probs) + log_cdf = self.b('log_cdf', value, probs) + sf = self.b('survival_function', value, probs) + log_sf = self.b('log_survival', value, probs) + return prob + log_prob + cdf + log_cdf + sf + log_sf + +def test_bernoulli_prob1(): + """ + Test probability functions: passing value/probs through construct. + """ + net = BernoulliProb1() + value = Tensor([0, 0, 0, 0, 0], dtype=dtype.float32) + probs = Tensor([0.5], dtype=dtype.float32) + ans = net(value, probs) + assert isinstance(ans, Tensor) + +class BernoulliKl(nn.Cell): + """ + Test class: kl_loss between Bernoulli distributions. + """ + def __init__(self): + super(BernoulliKl, self).__init__() + self.b1 = nn.Bernoulli(0.7, dtype=dtype.int32) + self.b2 = nn.Bernoulli(dtype=dtype.int32) + + def construct(self, probs_b, probs_a): + kl1 = self.b1('kl_loss', 'Bernoulli', probs_b) + kl2 = self.b2('kl_loss', 'Bernoulli', probs_b, probs_a) + return kl1 + kl2 + +def test_kl(): + """ + Test kl_loss function. 
+ """ + ber_net = BernoulliKl() + probs_b = Tensor([0.3], dtype=dtype.float32) + probs_a = Tensor([0.7], dtype=dtype.float32) + ans = ber_net(probs_b, probs_a) + assert isinstance(ans, Tensor) + +class BernoulliCrossEntropy(nn.Cell): + """ + Test class: cross_entropy of Bernoulli distribution. + """ + def __init__(self): + super(BernoulliCrossEntropy, self).__init__() + self.b1 = nn.Bernoulli(0.7, dtype=dtype.int32) + self.b2 = nn.Bernoulli(dtype=dtype.int32) + + def construct(self, probs_b, probs_a): + h1 = self.b1('cross_entropy', 'Bernoulli', probs_b) + h2 = self.b2('cross_entropy', 'Bernoulli', probs_b, probs_a) + return h1 + h2 + +def test_cross_entropy(): + """ + Test cross_entropy between Bernoulli distributions. + """ + net = BernoulliCrossEntropy() + probs_b = Tensor([0.3], dtype=dtype.float32) + probs_a = Tensor([0.7], dtype=dtype.float32) + ans = net(probs_b, probs_a) + assert isinstance(ans, Tensor) + +class BernoulliBasics(nn.Cell): + """ + Test class: basic mean/sd/var/mode/entropy function. + """ + def __init__(self): + super(BernoulliBasics, self).__init__() + self.b = nn.Bernoulli([0.3, 0.5], dtype=dtype.int32) + + def construct(self): + mean = self.b('mean') + sd = self.b('sd') + var = self.b('var') + mode = self.b('mode') + entropy = self.b('entropy') + return mean + sd + var + mode + entropy + +def test_bascis(): + """ + Test mean/sd/var/mode/entropy functionality of Bernoulli distribution. + """ + net = BernoulliBasics() + ans = net() + assert isinstance(ans, Tensor) diff --git a/tests/ut/python/nn/distribution/test_exponential.py b/tests/ut/python/nn/distribution/test_exponential.py new file mode 100644 index 0000000000..57c69a4aa8 --- /dev/null +++ b/tests/ut/python/nn/distribution/test_exponential.py @@ -0,0 +1,166 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" +Test nn.Distribution.Exponential. +""" +import pytest + +import mindspore.nn as nn +from mindspore import dtype +from mindspore import Tensor + + +def test_arguments(): + """ + Args passing during initialization. + """ + e = nn.Exponential() + assert isinstance(e, nn.Distribution) + e = nn.Exponential([0.1, 0.3, 0.5, 1.0], dtype=dtype.float32) + assert isinstance(e, nn.Distribution) + +def test_rate(): + """ + Invalid rate. + """ + with pytest.raises(ValueError): + nn.Exponential([-0.1], dtype=dtype.float32) + with pytest.raises(ValueError): + nn.Exponential([0.0], dtype=dtype.float32) + +class ExponentialProb(nn.Cell): + """ + Exponential distribution: initialize with rate. + """ + def __init__(self): + super(ExponentialProb, self).__init__() + self.e = nn.Exponential(0.5, dtype=dtype.float32) + + def construct(self, value): + prob = self.e('prob', value) + log_prob = self.e('log_prob', value) + cdf = self.e('cdf', value) + log_cdf = self.e('log_cdf', value) + sf = self.e('survival_function', value) + log_sf = self.e('log_survival', value) + return prob + log_prob + cdf + log_cdf + sf + log_sf + +def test_exponential_prob(): + """ + Test probability functions: passing value through construct. + """ + net = ExponentialProb() + value = Tensor([0.2, 0.3, 5.0, 2, 3.9], dtype=dtype.float32) + ans = net(value) + assert isinstance(ans, Tensor) + +class ExponentialProb1(nn.Cell): + """ + Exponential distribution: initialize without rate. 
+ """ + def __init__(self): + super(ExponentialProb1, self).__init__() + self.e = nn.Exponential(dtype=dtype.float32) + + def construct(self, value, rate): + prob = self.e('prob', value, rate) + log_prob = self.e('log_prob', value, rate) + cdf = self.e('cdf', value, rate) + log_cdf = self.e('log_cdf', value, rate) + sf = self.e('survival_function', value, rate) + log_sf = self.e('log_survival', value, rate) + return prob + log_prob + cdf + log_cdf + sf + log_sf + +def test_exponential_prob1(): + """ + Test probability functions: passing value/rate through construct. + """ + net = ExponentialProb1() + value = Tensor([0.2, 0.9, 1, 2, 3], dtype=dtype.float32) + rate = Tensor([0.5], dtype=dtype.float32) + ans = net(value, rate) + assert isinstance(ans, Tensor) + +class ExponentialKl(nn.Cell): + """ + Test class: kl_loss between Exponential distributions. + """ + def __init__(self): + super(ExponentialKl, self).__init__() + self.e1 = nn.Exponential(0.7, dtype=dtype.float32) + self.e2 = nn.Exponential(dtype=dtype.float32) + + def construct(self, rate_b, rate_a): + kl1 = self.e1('kl_loss', 'Exponential', rate_b) + kl2 = self.e2('kl_loss', 'Exponential', rate_b, rate_a) + return kl1 + kl2 + +def test_kl(): + """ + Test kl_loss function. + """ + net = ExponentialKl() + rate_b = Tensor([0.3], dtype=dtype.float32) + rate_a = Tensor([0.7], dtype=dtype.float32) + ans = net(rate_b, rate_a) + assert isinstance(ans, Tensor) + +class ExponentialCrossEntropy(nn.Cell): + """ + Test class: cross_entropy of Exponential distribution. + """ + def __init__(self): + super(ExponentialCrossEntropy, self).__init__() + self.e1 = nn.Exponential(0.3, dtype=dtype.float32) + self.e2 = nn.Exponential(dtype=dtype.float32) + + def construct(self, rate_b, rate_a): + h1 = self.e1('cross_entropy', 'Exponential', rate_b) + h2 = self.e2('cross_entropy', 'Exponential', rate_b, rate_a) + return h1 + h2 + +def test_cross_entropy(): + """ + Test cross_entropy between Exponential distributions. 
+ """ + net = ExponentialCrossEntropy() + rate_b = Tensor([0.3], dtype=dtype.float32) + rate_a = Tensor([0.7], dtype=dtype.float32) + ans = net(rate_b, rate_a) + assert isinstance(ans, Tensor) + +class ExponentialBasics(nn.Cell): + """ + Test class: basic mean/sd/mode/entropy function. + """ + def __init__(self): + super(ExponentialBasics, self).__init__() + self.e = nn.Exponential([0.3, 0.5], dtype=dtype.float32) + + def construct(self): + mean = self.e('mean') + sd = self.e('sd') + var = self.e('var') + mode = self.e('mode') + entropy = self.e('entropy') + return mean + sd + var + mode + entropy + +def test_bascis(): + """ + Test mean/sd/var/mode/entropy functionality of Exponential distribution. + """ + net = ExponentialBasics() + ans = net() + assert isinstance(ans, Tensor) diff --git a/tests/ut/python/nn/distribution/test_geometric.py b/tests/ut/python/nn/distribution/test_geometric.py new file mode 100644 index 0000000000..6e7c73cdc2 --- /dev/null +++ b/tests/ut/python/nn/distribution/test_geometric.py @@ -0,0 +1,167 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" +Test nn.Distribution.Geometric. +""" +import pytest + +import mindspore.nn as nn +from mindspore import dtype +from mindspore import Tensor + + +def test_arguments(): + """ + Args passing during initialization. 
+ """ + g = nn.Geometric() + assert isinstance(g, nn.Distribution) + g = nn.Geometric([0.0, 0.3, 0.5, 1.0], dtype=dtype.int32) + assert isinstance(g, nn.Distribution) + +def test_prob(): + """ + Invalid probability. + """ + with pytest.raises(ValueError): + nn.Geometric([-0.1], dtype=dtype.int32) + with pytest.raises(ValueError): + nn.Geometric([1.1], dtype=dtype.int32) + +class GeometricProb(nn.Cell): + """ + Geometric distribution: initialize with probs. + """ + def __init__(self): + super(GeometricProb, self).__init__() + self.g = nn.Geometric(0.5, dtype=dtype.int32) + + def construct(self, value): + prob = self.g('prob', value) + log_prob = self.g('log_prob', value) + cdf = self.g('cdf', value) + log_cdf = self.g('log_cdf', value) + sf = self.g('survival_function', value) + log_sf = self.g('log_survival', value) + return prob + log_prob + cdf + log_cdf + sf + log_sf + +def test_geometric_prob(): + """ + Test probability functions: passing value through construct. + """ + net = GeometricProb() + value = Tensor([3, 4, 5, 6, 7], dtype=dtype.float32) + ans = net(value) + assert isinstance(ans, Tensor) + +class GeometricProb1(nn.Cell): + """ + Geometric distribution: initialize without probs. + """ + def __init__(self): + super(GeometricProb1, self).__init__() + self.g = nn.Geometric(dtype=dtype.int32) + + def construct(self, value, probs): + prob = self.g('prob', value, probs) + log_prob = self.g('log_prob', value, probs) + cdf = self.g('cdf', value, probs) + log_cdf = self.g('log_cdf', value, probs) + sf = self.g('survival_function', value, probs) + log_sf = self.g('log_survival', value, probs) + return prob + log_prob + cdf + log_cdf + sf + log_sf + +def test_geometric_prob1(): + """ + Test probability functions: passing value/probs through construct. 
+ """ + net = GeometricProb1() + value = Tensor([3, 4, 5, 6, 7], dtype=dtype.float32) + probs = Tensor([0.5], dtype=dtype.float32) + ans = net(value, probs) + assert isinstance(ans, Tensor) + + +class GeometricKl(nn.Cell): + """ + Test class: kl_loss between Geometric distributions. + """ + def __init__(self): + super(GeometricKl, self).__init__() + self.g1 = nn.Geometric(0.7, dtype=dtype.int32) + self.g2 = nn.Geometric(dtype=dtype.int32) + + def construct(self, probs_b, probs_a): + kl1 = self.g1('kl_loss', 'Geometric', probs_b) + kl2 = self.g2('kl_loss', 'Geometric', probs_b, probs_a) + return kl1 + kl2 + +def test_kl(): + """ + Test kl_loss function. + """ + ber_net = GeometricKl() + probs_b = Tensor([0.3], dtype=dtype.float32) + probs_a = Tensor([0.7], dtype=dtype.float32) + ans = ber_net(probs_b, probs_a) + assert isinstance(ans, Tensor) + +class GeometricCrossEntropy(nn.Cell): + """ + Test class: cross_entropy of Geometric distribution. + """ + def __init__(self): + super(GeometricCrossEntropy, self).__init__() + self.g1 = nn.Geometric(0.3, dtype=dtype.int32) + self.g2 = nn.Geometric(dtype=dtype.int32) + + def construct(self, probs_b, probs_a): + h1 = self.g1('cross_entropy', 'Geometric', probs_b) + h2 = self.g2('cross_entropy', 'Geometric', probs_b, probs_a) + return h1 + h2 + +def test_cross_entropy(): + """ + Test cross_entropy between Geometric distributions. + """ + net = GeometricCrossEntropy() + probs_b = Tensor([0.3], dtype=dtype.float32) + probs_a = Tensor([0.7], dtype=dtype.float32) + ans = net(probs_b, probs_a) + assert isinstance(ans, Tensor) + +class GeometricBasics(nn.Cell): + """ + Test class: basic mean/sd/mode/entropy function. 
+ """ + def __init__(self): + super(GeometricBasics, self).__init__() + self.g = nn.Geometric([0.3, 0.5], dtype=dtype.int32) + + def construct(self): + mean = self.g('mean') + sd = self.g('sd') + var = self.g('var') + mode = self.g('mode') + entropy = self.g('entropy') + return mean + sd + var + mode + entropy + +def test_bascis(): + """ + Test mean/sd/mode/entropy functionality of Geometric distribution. + """ + net = GeometricBasics() + ans = net() + assert isinstance(ans, Tensor) diff --git a/tests/ut/python/nn/distribution/test_normal.py b/tests/ut/python/nn/distribution/test_normal.py new file mode 100644 index 0000000000..87a92ad8da --- /dev/null +++ b/tests/ut/python/nn/distribution/test_normal.py @@ -0,0 +1,171 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" +Test nn.Distribution.Normal. +""" +import numpy as np +import pytest + +import mindspore.nn as nn +from mindspore import dtype +from mindspore import Tensor + +def test_normal_shape_errpr(): + """ + Invalid shapes. + """ + with pytest.raises(ValueError): + nn.Normal([[2.], [1.]], [[2.], [3.], [4.]], dtype=dtype.float32) + + +def test_arguments(): + """ + args passing during initialization. 
+ """ + n = nn.Normal() + assert isinstance(n, nn.Distribution) + n = nn.Normal([3.0], [4.0], dtype=dtype.float32) + assert isinstance(n, nn.Distribution) + + +class NormalProb(nn.Cell): + """ + Normal distribution: initialize with mean/sd. + """ + def __init__(self): + super(NormalProb, self).__init__() + self.normal = nn.Normal(3.0, 4.0, dtype=dtype.float32) + + def construct(self, value): + prob = self.normal('prob', value) + log_prob = self.normal('log_prob', value) + cdf = self.normal('cdf', value) + log_cdf = self.normal('log_cdf', value) + sf = self.normal('survival_function', value) + log_sf = self.normal('log_survival', value) + return prob + log_prob + cdf + log_cdf + sf + log_sf + +def test_normal_prob(): + """ + Test probability functions: passing value through construct. + """ + net = NormalProb() + value = Tensor([0.5, 1.0], dtype=dtype.float32) + ans = net(value) + assert isinstance(ans, Tensor) + + +class NormalProb1(nn.Cell): + """ + Normal distribution: initialize without mean/sd. + """ + def __init__(self): + super(NormalProb1, self).__init__() + self.normal = nn.Normal() + + def construct(self, value, mean, sd): + prob = self.normal('prob', value, mean, sd) + log_prob = self.normal('log_prob', value, mean, sd) + cdf = self.normal('cdf', value, mean, sd) + log_cdf = self.normal('log_cdf', value, mean, sd) + sf = self.normal('survival_function', value, mean, sd) + log_sf = self.normal('log_survival', value, mean, sd) + return prob + log_prob + cdf + log_cdf + sf + log_sf + +def test_normal_prob1(): + """ + Test probability functions: passing mean/sd, value through construct. + """ + net = NormalProb1() + value = Tensor([0.5, 1.0], dtype=dtype.float32) + mean = Tensor([0.0], dtype=dtype.float32) + sd = Tensor([1.0], dtype=dtype.float32) + ans = net(value, mean, sd) + assert isinstance(ans, Tensor) + +class NormalKl(nn.Cell): + """ + Test class: kl_loss of Normal distribution. 
+ """ + def __init__(self): + super(NormalKl, self).__init__() + self.n1 = nn.Normal(np.array([3.0]), np.array([4.0]), dtype=dtype.float32) + self.n2 = nn.Normal(dtype=dtype.float32) + + def construct(self, mean_b, sd_b, mean_a, sd_a): + kl1 = self.n1('kl_loss', 'Normal', mean_b, sd_b) + kl2 = self.n2('kl_loss', 'Normal', mean_b, sd_b, mean_a, sd_a) + return kl1 + kl2 + +def test_kl(): + """ + Test kl_loss. + """ + net = NormalKl() + mean_b = Tensor(np.array([1.0]).astype(np.float32), dtype=dtype.float32) + sd_b = Tensor(np.array([1.0]).astype(np.float32), dtype=dtype.float32) + mean_a = Tensor(np.array([2.0]).astype(np.float32), dtype=dtype.float32) + sd_a = Tensor(np.array([3.0]).astype(np.float32), dtype=dtype.float32) + ans = net(mean_b, sd_b, mean_a, sd_a) + assert isinstance(ans, Tensor) + +class NormalCrossEntropy(nn.Cell): + """ + Test class: cross_entropy of Normal distribution. + """ + def __init__(self): + super(NormalCrossEntropy, self).__init__() + self.n1 = nn.Normal(np.array([3.0]), np.array([4.0]), dtype=dtype.float32) + self.n2 = nn.Normal(dtype=dtype.float32) + + def construct(self, mean_b, sd_b, mean_a, sd_a): + h1 = self.n1('cross_entropy', 'Normal', mean_b, sd_b) + h2 = self.n2('cross_entropy', 'Normal', mean_b, sd_b, mean_a, sd_a) + return h1 + h2 + +def test_cross_entropy(): + """ + Test cross entropy between Normal distributions. + """ + net = NormalCrossEntropy() + mean_b = Tensor(np.array([1.0]).astype(np.float32), dtype=dtype.float32) + sd_b = Tensor(np.array([1.0]).astype(np.float32), dtype=dtype.float32) + mean_a = Tensor(np.array([2.0]).astype(np.float32), dtype=dtype.float32) + sd_a = Tensor(np.array([3.0]).astype(np.float32), dtype=dtype.float32) + ans = net(mean_b, sd_b, mean_a, sd_a) + assert isinstance(ans, Tensor) + +class NormalBasics(nn.Cell): + """ + Test class: basic mean/sd function. 
+ """ + def __init__(self): + super(NormalBasics, self).__init__() + self.n = nn.Normal(3.0, 4.0, dtype=dtype.float32) + + def construct(self): + mean = self.n('mean') + sd = self.n('sd') + mode = self.n('mode') + entropy = self.n('entropy') + return mean + sd + mode + entropy + +def test_bascis(): + """ + Test mean/sd/mode/entropy functionality of Normal. + """ + net = NormalBasics() + ans = net() + assert isinstance(ans, Tensor) diff --git a/tests/ut/python/nn/distribution/test_uniform.py b/tests/ut/python/nn/distribution/test_uniform.py new file mode 100644 index 0000000000..7f9b442816 --- /dev/null +++ b/tests/ut/python/nn/distribution/test_uniform.py @@ -0,0 +1,180 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" +Test nn.Distribution.Uniform. +""" +import numpy as np +import pytest + +import mindspore.nn as nn +from mindspore import dtype +from mindspore import Tensor + +def test_uniform_shape_errpr(): + """ + Invalid shapes. + """ + with pytest.raises(ValueError): + nn.Uniform([[2.], [1.]], [[2.], [3.], [4.]], dtype=dtype.float32) + + +def test_arguments(): + """ + Args passing during initialization. + """ + u = nn.Uniform() + assert isinstance(u, nn.Distribution) + u = nn.Uniform([3.0], [4.0], dtype=dtype.float32) + assert isinstance(u, nn.Distribution) + + +def test_invalid_range(): + """ + Test range of uniform distribution. 
+ """ + with pytest.raises(ValueError): + nn.Uniform(0.0, 0.0, dtype=dtype.float32) + with pytest.raises(ValueError): + nn.Uniform(1.0, 0.0, dtype=dtype.float32) + + +class UniformProb(nn.Cell): + """ + Uniform distribution: initialize with low/high. + """ + def __init__(self): + super(UniformProb, self).__init__() + self.u = nn.Uniform(3.0, 4.0, dtype=dtype.float32) + + def construct(self, value): + prob = self.u('prob', value) + log_prob = self.u('log_prob', value) + cdf = self.u('cdf', value) + log_cdf = self.u('log_cdf', value) + sf = self.u('survival_function', value) + log_sf = self.u('log_survival', value) + return prob + log_prob + cdf + log_cdf + sf + log_sf + +def test_uniform_prob(): + """ + Test probability functions: passing value through construct. + """ + net = UniformProb() + value = Tensor([3.1, 3.2, 3.3, 3.4], dtype=dtype.float32) + ans = net(value) + assert isinstance(ans, Tensor) + +class UniformProb1(nn.Cell): + """ + Uniform distribution: initialize without low/high. + """ + def __init__(self): + super(UniformProb1, self).__init__() + self.u = nn.Uniform(dtype=dtype.float32) + + def construct(self, value, low, high): + prob = self.u('prob', value, low, high) + log_prob = self.u('log_prob', value, low, high) + cdf = self.u('cdf', value, low, high) + log_cdf = self.u('log_cdf', value, low, high) + sf = self.u('survival_function', value, low, high) + log_sf = self.u('log_survival', value, low, high) + return prob + log_prob + cdf + log_cdf + sf + log_sf + +def test_uniform_prob1(): + """ + Test probability functions: passing low/high, value through construct. + """ + net = UniformProb1() + value = Tensor([0.1, 0.2, 0.3, 0.9], dtype=dtype.float32) + low = Tensor([0.0], dtype=dtype.float32) + high = Tensor([1.0], dtype=dtype.float32) + ans = net(value, low, high) + assert isinstance(ans, Tensor) + +class UniformKl(nn.Cell): + """ + Test class: kl_loss of Uniform distribution. 
+ """ + def __init__(self): + super(UniformKl, self).__init__() + self.u1 = nn.Uniform(np.array([3.0]), np.array([4.0]), dtype=dtype.float32) + self.u2 = nn.Uniform(dtype=dtype.float32) + + def construct(self, low_b, high_b, low_a, high_a): + kl1 = self.u1('kl_loss', 'Uniform', low_b, high_b) + kl2 = self.u2('kl_loss', 'Uniform', low_b, high_b, low_a, high_a) + return kl1 + kl2 + +def test_kl(): + """ + Test kl_loss. + """ + net = UniformKl() + low_b = Tensor(np.array([0.0]).astype(np.float32), dtype=dtype.float32) + high_b = Tensor(np.array([5.0]).astype(np.float32), dtype=dtype.float32) + low_a = Tensor(np.array([2.0]).astype(np.float32), dtype=dtype.float32) + high_a = Tensor(np.array([3.0]).astype(np.float32), dtype=dtype.float32) + ans = net(low_b, high_b, low_a, high_a) + assert isinstance(ans, Tensor) + +class UniformCrossEntropy(nn.Cell): + """ + Test class: cross_entropy of Uniform distribution. + """ + def __init__(self): + super(UniformCrossEntropy, self).__init__() + self.u1 = nn.Uniform(np.array([3.0]), np.array([4.0]), dtype=dtype.float32) + self.u2 = nn.Uniform(dtype=dtype.float32) + + def construct(self, low_b, high_b, low_a, high_a): + h1 = self.u1('cross_entropy', 'Uniform', low_b, high_b) + h2 = self.u2('cross_entropy', 'Uniform', low_b, high_b, low_a, high_a) + return h1 + h2 + +def test_cross_entropy(): + """ + Test cross_entropy between Uniform distributions. + """ + net = UniformCrossEntropy() + low_b = Tensor(np.array([0.0]).astype(np.float32), dtype=dtype.float32) + high_b = Tensor(np.array([5.0]).astype(np.float32), dtype=dtype.float32) + low_a = Tensor(np.array([2.0]).astype(np.float32), dtype=dtype.float32) + high_a = Tensor(np.array([3.0]).astype(np.float32), dtype=dtype.float32) + ans = net(low_b, high_b, low_a, high_a) + assert isinstance(ans, Tensor) + +class UniformBasics(nn.Cell): + """ + Test class: basic mean/sd/var/entropy function.
+ """ + def __init__(self): + super(UniformBasics, self).__init__() + self.u = nn.Uniform(3.0, 4.0, dtype=dtype.float32) + + def construct(self): + mean = self.u('mean') + sd = self.u('sd') + var = self.u('var') + entropy = self.u('entropy') + return mean + sd + var + entropy + +def test_bascis(): + """ + Test mean/sd/var/entropy functionality of Uniform. + """ + net = UniformBasics() + ans = net() + assert isinstance(ans, Tensor) diff --git a/tests/ut/python/nn/test_distribution.py b/tests/ut/python/nn/test_distribution.py deleted file mode 100644 index b779814fd5..0000000000 --- a/tests/ut/python/nn/test_distribution.py +++ /dev/null @@ -1,369 +0,0 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -""" -Test nn.Distribution. - -Including Normal Distribution and Bernoulli Distribution. -""" -import pytest -import numpy as np - -import mindspore.nn as nn -from mindspore import dtype -from mindspore import Tensor - -def test_normal_shape_errpr(): - """ - Invalid shapes. - """ - with pytest.raises(ValueError): - nn.Normal([[2.], [1.]], [[2.], [3.], [4.]], dtype=dtype.float32) - -def test_no_arguments(): - """ - No args passed in during initialization.
- """ - n = nn.Normal() - assert isinstance(n, nn.Distribution) - b = nn.Bernoulli() - assert isinstance(b, nn.Distribution) - -def test_with_arguments(): - """ - Args passed in during initialization. - """ - n = nn.Normal([3.0], [4.0], dtype=dtype.float32) - assert isinstance(n, nn.Distribution) - b = nn.Bernoulli([0.3, 0.5], dtype=dtype.int32) - assert isinstance(b, nn.Distribution) - -class NormalProb(nn.Cell): - """ - Normal distribution: initialize with mean/sd. - """ - def __init__(self): - super(NormalProb, self).__init__() - self.normal = nn.Normal(3.0, 4.0, dtype=dtype.float32) - - def construct(self, value): - x = self.normal('prob', value) - y = self.normal('log_prob', value) - return x, y - -def test_normal_prob(): - """ - Test pdf/log_pdf: passing value through construct. - """ - net = NormalProb() - value = Tensor([0.5, 1.0], dtype=dtype.float32) - pdf, log_pdf = net(value) - assert isinstance(pdf, Tensor) - assert isinstance(log_pdf, Tensor) - -class NormalProb1(nn.Cell): - """ - Normal distribution: initialize without mean/sd. - """ - def __init__(self): - super(NormalProb1, self).__init__() - self.normal = nn.Normal() - - def construct(self, value, mean, sd): - x = self.normal('prob', value, mean, sd) - y = self.normal('log_prob', value, mean, sd) - return x, y - -def test_normal_prob1(): - """ - Test pdf/logpdf: passing mean/sd, value through construct. - """ - net = NormalProb1() - value = Tensor([0.5, 1.0], dtype=dtype.float32) - mean = Tensor([0.0], dtype=dtype.float32) - sd = Tensor([1.0], dtype=dtype.float32) - pdf, log_pdf = net(value, mean, sd) - assert isinstance(pdf, Tensor) - assert isinstance(log_pdf, Tensor) - -class NormalProb2(nn.Cell): - """ - Normal distribution: initialize with mean/sd. 
- """ - def __init__(self): - super(NormalProb2, self).__init__() - self.normal = nn.Normal(3.0, 4.0, dtype=dtype.float32) - - def construct(self, value, mean, sd): - x = self.normal('prob', value, mean, sd) - y = self.normal('log_prob', value, mean, sd) - return x, y - -def test_normal_prob2(): - """ - Test pdf/log_pdf: passing mean/sd through construct. - Overwrite original mean/sd. - """ - net = NormalProb2() - value = Tensor([0.5, 1.0], dtype=dtype.float32) - mean = Tensor([0.0], dtype=dtype.float32) - sd = Tensor([1.0], dtype=dtype.float32) - pdf, log_pdf = net(value, mean, sd) - assert isinstance(pdf, Tensor) - assert isinstance(log_pdf, Tensor) - -class BernoulliProb(nn.Cell): - """ - Bernoulli distribution: initialize with probs. - """ - def __init__(self): - super(BernoulliProb, self).__init__() - self.bernoulli = nn.Bernoulli(0.5, dtype=dtype.int32) - - def construct(self, value): - return self.bernoulli('prob', value) - -class BernoulliLogProb(nn.Cell): - """ - Bernoulli distribution: initialize with probs. - """ - def __init__(self): - super(BernoulliLogProb, self).__init__() - self.bernoulli = nn.Bernoulli(0.5, dtype=dtype.int32) - - def construct(self, value): - return self.bernoulli('log_prob', value) - - -def test_bernoulli_prob(): - """ - Test pmf/log_pmf: passing value through construct. - """ - net = BernoulliProb() - value = Tensor([1, 0, 1, 0, 1], dtype=dtype.float32) - pmf = net(value) - assert isinstance(pmf, Tensor) - -def test_bernoulli_log_prob(): - """ - Test pmf/log_pmf: passing value through construct. - """ - net = BernoulliLogProb() - value = Tensor([1, 0, 1, 0, 1], dtype=dtype.float32) - log_pmf = net(value) - assert isinstance(log_pmf, Tensor) - -class BernoulliProb1(nn.Cell): - """ - Bernoulli distribution: initialize without probs. 
- """ - def __init__(self): - super(BernoulliProb1, self).__init__() - self.bernoulli = nn.Bernoulli() - - def construct(self, value, probs): - return self.bernoulli('prob', value, probs) - -class BernoulliLogProb1(nn.Cell): - """ - Bernoulli distribution: initialize without probs. - """ - def __init__(self): - super(BernoulliLogProb1, self).__init__() - self.bernoulli = nn.Bernoulli() - - def construct(self, value, probs): - return self.bernoulli('log_prob', value, probs) - - -def test_bernoulli_prob1(): - """ - Test pmf/log_pmf: passing probs through construct. - """ - net = BernoulliProb1() - value = Tensor([1, 0, 1, 0, 1], dtype=dtype.float32) - probs = Tensor([0.3], dtype=dtype.float32) - pmf = net(value, probs) - assert isinstance(pmf, Tensor) - -def test_bernoulli_log_prob1(): - """ - Test pmf/log_pmf: passing probs through construct. - """ - net = BernoulliLogProb1() - value = Tensor([1, 0, 1, 0, 1], dtype=dtype.float32) - probs = Tensor([0.3], dtype=dtype.float32) - log_pmf = net(value, probs) - assert isinstance(log_pmf, Tensor) - -class BernoulliProb2(nn.Cell): - """ - Bernoulli distribution: initialize with probs. - """ - def __init__(self): - super(BernoulliProb2, self).__init__() - self.bernoulli = nn.Bernoulli(0.5) - - def construct(self, value, probs): - return self.bernoulli('prob', value, probs) - -class BernoulliLogProb2(nn.Cell): - """ - Bernoulli distribution: initialize with probs. - """ - def __init__(self): - super(BernoulliLogProb2, self).__init__() - self.bernoulli = nn.Bernoulli(0.5) - - def construct(self, value, probs): - return self.bernoulli('log_prob', value, probs) - - -def test_bernoulli_prob2(): - """ - Test pmf/log_pmf: passing probs/value through construct. - Overwrite original probs. 
- """ - net = BernoulliProb2() - value = Tensor([1, 0, 1, 0, 1], dtype=dtype.float32) - probs = Tensor([0.3], dtype=dtype.float32) - pmf = net(value, probs) - assert isinstance(pmf, Tensor) - -def test_bernoulli_log_prob2(): - """ - Test pmf/log_pmf: passing probs/value through construct. - Overwrite original probs. - """ - net = BernoulliLogProb2() - value = Tensor([1, 0, 1, 0, 1], dtype=dtype.float32) - probs = Tensor([0.3], dtype=dtype.float32) - log_pmf = net(value, probs) - assert isinstance(log_pmf, Tensor) - - -class NormalKl(nn.Cell): - """ - Test class: kl_loss of Normal distribution. - """ - def __init__(self): - super(NormalKl, self).__init__() - self.n = nn.Normal(Tensor([3.0]), Tensor([4.0]), dtype=dtype.float32) - - def construct(self, x_, y_): - return self.n('kl_loss', 'Normal', x_, y_) - -class BernoulliKl(nn.Cell): - """ - Test class: kl_loss between Bernoulli distributions. - """ - def __init__(self): - super(BernoulliKl, self).__init__() - self.b = nn.Bernoulli(0.7, dtype=dtype.int32) - - def construct(self, x_): - return self.b('kl_loss', 'Bernoulli', x_) - -def test_kl(): - """ - Test kl_loss function. - """ - nor_net = NormalKl() - mean_b = np.array([1.0]).astype(np.float32) - sd_b = np.array([1.0]).astype(np.float32) - mean = Tensor(mean_b, dtype=dtype.float32) - sd = Tensor(sd_b, dtype=dtype.float32) - loss = nor_net(mean, sd) - assert isinstance(loss, Tensor) - - ber_net = BernoulliKl() - probs_b = Tensor([0.3], dtype=dtype.float32) - loss = ber_net(probs_b) - assert isinstance(loss, Tensor) - - -class NormalKlNoArgs(nn.Cell): - """ - Test class: kl_loss of Normal distribution. - No args during initialization. - """ - def __init__(self): - super(NormalKlNoArgs, self).__init__() - self.n = nn.Normal(dtype=dtype.float32) - - def construct(self, x_, y_, w_, v_): - return self.n('kl_loss', 'Normal', x_, y_, w_, v_) - -class BernoulliKlNoArgs(nn.Cell): - """ - Test class: kl_loss between Bernoulli distributions. - No args during initialization. 
- """ - def __init__(self): - super(BernoulliKlNoArgs, self).__init__() - self.b = nn.Bernoulli(dtype=dtype.int32) - - def construct(self, x_, y_): - return self.b('kl_loss', 'Bernoulli', x_, y_) - -def test_kl_no_args(): - """ - Test kl_loss function. - """ - nor_net = NormalKlNoArgs() - mean_b = np.array([1.0]).astype(np.float32) - sd_b = np.array([1.0]).astype(np.float32) - mean_a = np.array([2.0]).astype(np.float32) - sd_a = np.array([3.0]).astype(np.float32) - mean_b = Tensor(mean_b, dtype=dtype.float32) - sd_b = Tensor(sd_b, dtype=dtype.float32) - mean_a = Tensor(mean_a, dtype=dtype.float32) - sd_a = Tensor(sd_a, dtype=dtype.float32) - loss = nor_net(mean_b, sd_b, mean_a, sd_a) - assert isinstance(loss, Tensor) - - ber_net = BernoulliKlNoArgs() - probs_b = Tensor([0.3], dtype=dtype.float32) - probs_a = Tensor([0.7], dtype=dtype.float32) - loss = ber_net(probs_b, probs_a) - assert isinstance(loss, Tensor) - - - -class NormalBernoulli(nn.Cell): - """ - Test class: basic mean/sd function. - """ - def __init__(self): - super(NormalBernoulli, self).__init__() - self.n = nn.Normal(3.0, 4.0, dtype=dtype.float32) - self.b = nn.Bernoulli(0.5, dtype=dtype.int32) - - def construct(self): - normal_mean = self.n('mean') - normal_sd = self.n('sd') - bernoulli_mean = self.b('mean') - bernoulli_sd = self.b('sd') - return normal_mean, normal_sd, bernoulli_mean, bernoulli_sd - -def test_bascis(): - """ - Test mean/sd functionality of Normal and Bernoulli. - """ - net = NormalBernoulli() - normal_mean, normal_sd, bernoulli_mean, bernoulli_sd = net() - assert isinstance(normal_mean, Tensor) - assert isinstance(normal_sd, Tensor) - assert isinstance(bernoulli_mean, Tensor) - assert isinstance(bernoulli_sd, Tensor)