!2605 High level abstraction of mathematical distributions

Merge pull request !2605 from XunDeng/pp_poc_v3
mindspore-ci-bot 2020-07-08 12:44:52 +08:00 committed by Gitee
commit dffb76a0a9
10 changed files with 1459 additions and 1 deletion


@@ -17,13 +17,15 @@ Neural Networks Cells.
Pre-defined building blocks or computing units to construct Neural Networks.
"""
-from . import layer, loss, optim, metrics, wrap
+from . import layer, loss, optim, metrics, wrap, distribution
from .cell import Cell, GraphKernel
from .layer import *
from .loss import *
from .optim import *
from .metrics import *
from .wrap import *
+from .distribution import *
__all__ = ["Cell", "GraphKernel"]
__all__.extend(layer.__all__)
@@ -31,5 +33,7 @@ __all__.extend(loss.__all__)
__all__.extend(optim.__all__)
__all__.extend(metrics.__all__)
__all__.extend(wrap.__all__)
+__all__.extend(distribution.__all__)
__all__.sort()


@@ -0,0 +1,27 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Distribution.
The high-level components (Distributions) used to construct probabilistic networks.
"""
from .distribution import Distribution
from .normal import Normal
from .bernoulli import Bernoulli
__all__ = ['Distribution',
'Normal',
'Bernoulli',]


@@ -0,0 +1,24 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Distribution operation utility functions.
"""
from .utils import *
__all__ = ['check_scalar', 'convert_to_batch', 'cast_to_tensor',
'calc_batch_size', 'check_greater',
'check_greater_equal_zero',
'calc_broadcast_shape_from_param',
'check_scalar_from_param', 'check_prob']


@@ -0,0 +1,199 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Utitly functions to help distribution class."""
import numpy as np
from mindspore.ops import _utils as utils
from ....common.tensor import Tensor, MetaTensor
from ....common.parameter import Parameter
from ....common import dtype as mstype
def check_scalar(value):
"""
Check if input value is a scalar.
"""
return np.isscalar(value)
def cast_to_tensor(t, dtype=mstype.float32):
"""
Cast a user input value into a Tensor of the given dtype.
Args:
t (int, float, list, numpy.ndarray, Tensor, Parameter): object to be cast to Tensor.
dtype (mindspore.dtype): dtype of the Tensor. Default: mstype.float32.
Raises:
RuntimeError: if t cannot be cast to Tensor.
Returns:
Tensor.
"""
if isinstance(t, Parameter):
return t
if isinstance(t, Tensor):
# check if t is a scalar Tensor in the form of Tensor(4)
if t.dim() == 0:
value = t.asnumpy()
return Tensor([value], dtype=dtype)
#convert the type of tensor to dtype
t.set_dtype(dtype)
return t
if isinstance(t, (list, np.ndarray)):
return Tensor(t, dtype=dtype)
if check_scalar(t):
return Tensor([t], dtype=dtype)
raise RuntimeError("Input type is not supported.")
def calc_batch_size(batch_shape):
"""
Calculate the size of a given batch_shape.
Args:
batch_shape (tuple): batch shape to be calculated.
Returns:
int.
"""
return int(np.prod(batch_shape))
def convert_to_batch(t, batch_shape, dtype):
"""
Convert a Tensor to a given batch shape.
Args:
t (Tensor, Parameter): Tensor to be converted.
batch_shape (tuple): desired batch shape.
dtype (mindspore.dtype): desired dtype.
Raises:
RuntimeError: if the conversion cannot be done.
Returns:
Tensor, with shape of batch_shape.
"""
if isinstance(t, Parameter):
return t
t = cast_to_tensor(t, dtype)
if t.shape != batch_shape:
mul = calc_batch_size(batch_shape) // t.size()
if (calc_batch_size(batch_shape) % t.size()) != 0:
raise RuntimeError("Cannot cast the tensor to the given batch shape.")
temp = list(t.asnumpy()) * mul
temp = np.reshape(temp, batch_shape)
return Tensor(temp, dtype)
return t
def check_scalar_from_param(params):
"""
Check if params are all scalars.
Args:
params (dict): parameters used to initialize distribution.
Notes: String parameters are excluded.
Returns:
bool, True if all parameters are scalars.
"""
for value in params.values():
if isinstance(value, (str, type(params['dtype']))):
continue
elif check_scalar(value):
continue
else:
return False
return True
def calc_broadcast_shape_from_param(params):
"""
Calculate the broadcast shape from params.
Args:
params (dict): parameters used to initialize distribution.
Returns:
tuple.
"""
broadcast_shape = []
for value in params.values():
if isinstance(value, (str, type(params['dtype']))):
continue
if value is None:
return None
if isinstance(value, Parameter):
value_t = value.default_input
else:
value_t = cast_to_tensor(value, params['dtype'])
broadcast_shape = utils.get_broadcast_shape(broadcast_shape, list(value_t.shape), params['name'])
return tuple(broadcast_shape)
def check_greater_equal_zero(value, name):
"""
Check if the given Tensor is greater than or equal to zero.
Args:
value (Tensor, Parameter): value to be checked.
name (str) : name of the value.
Raises:
ValueError: if the input value is less than zero.
"""
if isinstance(value, Parameter):
if isinstance(value.default_input, MetaTensor):
return
value = value.default_input
comp = np.less(value.asnumpy(), np.zeros(value.shape))
if comp.any():
raise ValueError(f'{name} should be greater than or equal to zero.')
def check_greater(a, b, name_a, name_b):
"""
Check if Tensor b is strictly greater than Tensor a.
Args:
a (Tensor): input tensor a.
b (Tensor): input tensor b.
name_a (str): name of Tensor_a.
name_b (str): name of Tensor_b.
Raises:
ValueError: if b is less than or equal to a.
"""
comp = np.less(a.asnumpy(), b.asnumpy())
if not comp.all():
raise ValueError(f'{name_a} should be less than {name_b}')
def check_prob(p):
"""
Check if p is a proper probability, i.e. 0 <= p <= 1.
Args:
p (Tensor, Parameter): value to be checked.
Raises:
ValueError: if p is not a proper probability.
"""
if isinstance(p, Parameter):
if isinstance(p.default_input, MetaTensor):
return
p = p.default_input
comp = np.less(p.asnumpy(), np.zeros(p.shape))
if comp.any():
raise ValueError('Probabilities should be greater than or equal to zero.')
comp = np.greater(p.asnumpy(), np.ones(p.shape))
if comp.any():
raise ValueError('Probabilities should be less than or equal to one.')
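
For reference, the tiling behaviour of convert_to_batch can be reproduced with plain NumPy. A minimal sketch (illustrative only, not part of this patch; convert_to_batch_np is a hypothetical name):

import numpy as np

# Sketch: tile a value whose size divides the batch size, then
# reshape to batch_shape, mirroring convert_to_batch above.
def convert_to_batch_np(t, batch_shape, dtype=np.float32):
    t = np.asarray(t, dtype=dtype).reshape(-1)
    batch_size = int(np.prod(batch_shape))
    if batch_size % t.size != 0:
        raise RuntimeError("Cannot cast the tensor to the given batch shape.")
    return np.tile(t, batch_size // t.size).reshape(batch_shape)

print(convert_to_batch_np([3.0], (2, 2)))
# [[3. 3.]
#  [3. 3.]]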


@@ -0,0 +1,167 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Bernoulli Distribution"""
from mindspore.ops import operations as P
from .distribution import Distribution
from ._utils.utils import cast_to_tensor, check_prob
from ...common import dtype as mstype
class Bernoulli(Distribution):
"""
Example class: Bernoulli Distribution.
Args:
probs (int, float, list, numpy.ndarray, Tensor, Parameter): probability that the outcome is 1.
seed (int): seed to use in sampling. Default: 0.
dtype (mindspore.dtype): type of the distribution. Default: mstype.int32.
name (str): name of the distribution. Default: Bernoulli.
Note:
probs should be proper probabilities (0 <= p <= 1).
Examples:
>>> # To initialize a Bernoulli distribution with equal probabilities of 1 and 0
>>> b = nn.Bernoulli(0.5, dtype=mstype.int32)
>>> # The following creates two independent Bernoulli distributions
>>> b = nn.Bernoulli([0.7, 0.2], dtype=mstype.int32)
"""
def __init__(self,
probs=None,
seed=0,
dtype=mstype.int32,
name="Bernoulli"):
"""
Constructor of Bernoulli distribution.
"""
param = dict(locals())
super(Bernoulli, self).__init__(dtype, name, param)
if probs is not None:
self._probs = cast_to_tensor(probs)
check_prob(self._probs)
else:
self._probs = probs
# ops needed for the class
self.log = P.Log()
self.add = P.TensorAdd()
self.mul = P.Mul()
self.sqrt = P.Sqrt()
self.realdiv = P.RealDiv()
self.shape = P.Shape()
self.const = P.ScalarToArray()
self.less = P.Less()
self.cast = P.Cast()
self.normal = P.Normal(seed=seed)
self.erf = P.Erf()
def extend_repr(self):
str_info = f'probs = {self._probs}'
return str_info
def probs(self):
"""
Return the probability that the outcome is 1.
"""
return self._probs
def _mean(self, name='mean', probs1=None):
r"""
.. math::
MEAN(B) = probs1
"""
if name == 'mean':
return self._probs if probs1 is None else probs1
return None
def _var(self, name='var', probs1=None):
r"""
.. math::
VAR(B) = probs1 * probs0
"""
if name in ('sd', 'var'):
probs1 = self._probs if probs1 is None else probs1
probs0 = self.add(1, -1 * probs1)
return self.mul(probs0, probs1)
return None
def _prob(self, name, value, probs=None):
r"""
pmf of Bernoulli distribution.
Args:
name (str): name of the function. Should be "prob" when passed in from construct.
value (Tensor): a Tensor composed of only zeros and ones.
probs (Tensor): probability that the outcome is 1. Default: self._probs.
.. math::
pmf(k) = probs1 if k = 1;
pmf(k) = probs0 if k = 0;
"""
if name in ('prob', 'log_prob'):
probs1 = self._probs if probs is None else probs
probs0 = self.add(1, -1 * probs1)
return self.add(self.mul(probs1, value),
self.mul(probs0, self.add(1, -1 * value)))
return None
def _kl_loss(self, name, dist, probs1_b, probs1_a=None):
r"""
Evaluate Bernoulli-Bernoulli KL divergence, i.e. KL(a||b).
Args:
name (str): name of the function. Should always be "kl_loss" when passed in from construct.
dist (str): type of the distributions. Should be "Bernoulli" in this case.
probs1_b (Tensor): probs1 of distribution b.
probs1_a (Tensor): probs1 of distribution a. Default: self._probs.
.. math::
KL(a||b) = probs1_a * \log(\frac{probs1_a}{probs1_b}) +
probs0_a * \log(\frac{probs0_a}{probs0_b})
"""
if name == 'kl_loss' and dist == 'Bernoulli':
probs1_a = self._probs if probs1_a is None else probs1_a
probs0_a = self.add(1, -1 * probs1_a)
probs0_b = self.add(1, -1 * probs1_b)
return self.add(probs1_a * self.log(self.realdiv(probs1_a, probs1_b)),
probs0_a * self.log(self.realdiv(probs0_a, probs0_b)))
return None
def _sample(self, name, shape=(), probs=None):
"""
Sampling.
Args:
name (str): name of the function. Should always be 'sample' when passed in from construct.
shape (tuple): shape of the sample. Default: ().
probs (Tensor): probs1 of the samples. Default: self._probs.
Returns:
Tensor, shape is shape + batch_shape.
"""
if name == 'sample':
probs1 = self._probs if probs is None else probs
batch_shape = self.shape(probs1)
sample_shape = shape + batch_shape
mean_zero = self.const(0.0)
sd_one = self.const(1.0)
sqrt_two = self.sqrt(self.const(2.0))
sample_norm = self.normal(sample_shape, mean_zero, sd_one)
sample_uniform = 0.5 * (1 + self.erf(self.realdiv(sample_norm, sqrt_two)))
sample = self.less(sample_uniform, probs1)
sample = self.cast(sample, self._dtype)
return sample
return None
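
A note on the sampler above: if Z is standard normal, then 0.5 * (1 + erf(Z / sqrt(2))) is the standard normal CDF applied to Z, which is uniformly distributed on [0, 1]; thresholding that uniform against probs1 yields Bernoulli samples. A NumPy/SciPy sketch of the same transform (illustrative only, not part of this patch):

import numpy as np
from scipy.special import erf

# Sketch of the sampling path in Bernoulli._sample:
# normal noise -> uniform via the standard normal CDF -> threshold.
rng = np.random.default_rng(0)
probs1 = np.array([0.7, 0.5])
sample_norm = rng.standard_normal((10000,) + probs1.shape)
sample_uniform = 0.5 * (1 + erf(sample_norm / np.sqrt(2)))
sample = (sample_uniform < probs1).astype(np.int32)
print(sample.mean(axis=0))  # approximately [0.7, 0.5]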


@@ -0,0 +1,200 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""basic"""
from ..cell import Cell
from ._utils.utils import calc_broadcast_shape_from_param
class Distribution(Cell):
"""
Base class for all mathematical distributions.
Args:
dtype (mindspore.dtype): type of the distribution.
name (str): name of the distribution.
param (dict): parameters used to initialize the distribution.
Note:
Derived classes should override operations such as _mean, _prob,
and _log_prob. When used inside a network, these functions are
called through construct by passing the function name followed by
its arguments.
Examples:
>>> class MyNormalDistribution(Distribution):
>>> def __init__(self):
>>> super(MyNormalDistribution, self).__init__()
>>> self._mean_value = Tensor([2.0,3.0])
>>> self._sd_value = Tensor([2.0,3.0])
>>>
>>> def _mean(self):
>>> return self._mean_value
"""
def __init__(self,
dtype,
name,
param):
"""
Constructor of distribution class.
"""
super(Distribution, self).__init__()
self._name = name
self._dtype = dtype
self._parameters = {}
# parsing parameters
for k in param.keys():
if not(k == 'self' or k.startswith('_')):
self._parameters[k] = param[k]
# some attributes
self._broadcast_shape = calc_broadcast_shape_from_param(
self._parameters)
# set the function to call according to the derived class's attributes
self._set_prob()
self._set_log_prob()
self._set_sd()
def _set_prob(self):
"""
Set probability function based on the availability of _prob and _log_likelihood.
"""
if hasattr(self, '_prob'):
self._call_prob = self._prob
elif hasattr(self, '_log_likelihood'):
self._call_prob = self._calc_prob_from_log_likelihood
def _set_sd(self):
"""
Set standard deviation based on the availability of _sd and _var.
"""
if hasattr(self, '_sd'):
self._call_sd = self._sd
elif hasattr(self, '_var'):
self._call_sd = self._calc_sd_from_var
def _set_log_prob(self):
"""
Set log probability based on the availability of _prob and _log_likelihood.
"""
if hasattr(self, '_log_likelihood'):
self._call_log_prob = self._log_likelihood
elif hasattr(self, '_prob'):
self._call_log_prob = self._calc_log_prob_from_prob
def log_likelihood(self, *args):
"""
Evaluate the log probability at the given value.
Note:
value is cast to a Tensor for further calculation.
Returns:
Tensor, shape is the broadcast_shape of the distribution.
"""
return self._call_log_prob(*args)
def _calc_prob_from_log_likelihood(self, *args):
r"""
Evaluate prob from log probability.
.. math::
probability(x) = \exp(log_likelihood(x))
"""
return self.exp(self._log_likelihood(*args))
def prob(self, *args):
"""
Evaluate the prob (pdf or pmf) at given value.
Note:
value is cast to a Tensor for further calculation.
Returns:
Tensor, shape is the broadcast_shape of the distribution.
"""
return self._call_prob(*args)
def _calc_log_prob_from_prob(self, *args):
r"""
Evaluate log probability from probability.
.. math::
log_prob(x) = \log(prob(x))
"""
return self.log(self._prob(*args))
def kl_loss(self, **kwargs):
"""
Evaluate the KL divergence. Parameters of the second distribution should be
passed in through **kwargs.
Returns:
Tensor, shape is the broadcast_shape of the distribution and input distribution.
"""
return self._kl_loss(**kwargs)
def mean(self, **kwargs):
"""
Evaluate the mean.
Returns:
Tensor, shape is the broadcast_shape of the distribution.
"""
return self._mean(**kwargs)
def sd(self, **kwargs):
"""
Evaluate the standard deviation.
Returns:
Tensor, shape is the broadcast_shape of the distribution.
"""
return self._call_sd(**kwargs)
def _calc_sd_from_var(self, *args):
r"""
Evaluate standard deviation from variance.
.. math::
STD(x) = \sqrt{VAR(x)}
"""
return self.sqrt(self._var(*args))
def construct(self, *inputs):
"""
Override construct in Cell.
Args:
*inputs: inputs[0] is always the name of the function.
Notes:
Dispatches to the corresponding function based on inputs[0]; returns None if the name is not recognized.
"""
if inputs[0] == 'log_prob':
return self._call_log_prob(*inputs)
if inputs[0] == 'prob':
return self._call_prob(*inputs)
if inputs[0] == 'kl_loss':
return self._kl_loss(*inputs)
if inputs[0] == 'mean':
return self._mean(*inputs)
if inputs[0] == 'sd':
return self._call_sd(*inputs)
if inputs[0] == 'sample':
return self._sample(*inputs)
return None
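
The string-based dispatch in construct exists because a Cell is invoked through construct in graph mode, so the first input names the operation to run. A minimal plain-Python sketch of the pattern (TinyDist is a hypothetical name, illustrative only):

# Sketch of the name-based dispatch used by Distribution.construct.
class TinyDist:
    def __init__(self, mean_value):
        self._mean_value = mean_value

    def _mean(self, name, *args):
        # 'name' mirrors the function-name argument threaded through construct.
        return self._mean_value

    def construct(self, *inputs):
        # inputs[0] selects the operation; remaining inputs are its arguments.
        if inputs[0] == 'mean':
            return self._mean(*inputs)
        return None

d = TinyDist(3.0)
print(d.construct('mean'))  # 3.0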


@@ -0,0 +1,169 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Normal Distribution"""
import numpy as np
from mindspore.ops import operations as P
from .distribution import Distribution
from ._utils.utils import convert_to_batch, check_greater_equal_zero
from ...common import dtype as mstype
from ...context import get_context
class Normal(Distribution):
"""
Example class: Normal distribution.
Args:
mean (int, float, list, numpy.ndarray, Tensor, Parameter): mean of the Gaussian distribution.
sd (int, float, list, numpy.ndarray, Tensor, Parameter): stddev of the Gaussian distribution.
seed (int): seed to use in sampling. Default: 0.
dtype (mindspore.dtype): type of the distribution. Default: mstype.float32.
name (str): name of the distribution. Default: Normal.
Note:
Standard deviation should be greater than zero.
Examples:
>>> # To initialize a normal distribution of mean 3.0 and standard deviation 4.0
>>> n = nn.Normal(3.0, 4.0, dtype=mstype.float32)
>>> # The following creates two independent normal distributions
>>> n = nn.Normal([3.0, 3.0], [4.0, 4.0], dtype=mstype.float32)
"""
def __init__(self,
mean=None,
sd=None,
seed=0,
dtype=mstype.float32,
name="Normal"):
"""
Constructor of normal distribution.
"""
param = dict(locals())
super(Normal, self).__init__(dtype, name, param)
if mean is not None and sd is not None:
self._mean_value = convert_to_batch(mean, self._broadcast_shape, dtype)
self._sd_value = convert_to_batch(sd, self._broadcast_shape, dtype)
check_greater_equal_zero(self._sd_value, "Standard deviation")
else:
self._mean_value = mean
self._sd_value = sd
#ops needed for the class
self.exp = P.Exp()
self.add = P.TensorAdd()
self.mul = P.Mul()
self.sq = P.Square()
self.log = P.Log()
self.sqrt = P.Sqrt()
self.realdiv = P.RealDiv()
self.expm1 = P.Expm1() if get_context('device_target') == 'Ascend' else self._expm1_by_step
self.normal = P.Normal(seed=seed)
self.shape = P.Shape()
self.zeroslike = P.ZerosLike()
self.const = P.ScalarToArray()
def extend_repr(self):
str_info = f'mean = {self._mean_value}, standard deviation = {self._sd_value}'
return str_info
def _expm1_by_step(self, x):
"""
Fallback for expm1 on non-Ascend targets, computed as exp(x) - 1.
"""
return self.add(self.exp(x), -1)
def _mean(self, name='mean', mean=None, sd=None):
"""
Mean of the distribution.
"""
if name == 'mean':
mean = self._mean_value if mean is None or sd is None else mean
return mean
return None
def _sd(self, name='sd', mean=None, sd=None):
"""
Standard deviation of the distribution.
"""
if name in ('sd', 'var'):
sd = self._sd_value if mean is None or sd is None else sd
return sd
return None
def _log_likelihood(self, name, value, mean=None, sd=None):
r"""
Evaluate log probability.
.. math::
L(x) = -\frac{(x - \mu)^2}{2 \sigma^2} - \log(\sqrt{2 \pi \sigma^2})
"""
if name in ('prob', 'log_prob'):
mean = self._mean_value if mean is None else mean
sd = self._sd_value if sd is None else sd
unnormalized_log_prob = -1. * self.realdiv(self.sq(self.add(value, -1. * mean)),
2. * self.sq(sd))
neg_normalization = -1. * self.log(self.sqrt(2. * np.pi * self.sq(sd)))
return self.add(unnormalized_log_prob, neg_normalization)
return None
def _kl_loss(self, name, dist, mean_b, sd_b, mean_a=None, sd_a=None):
r"""
Evaluate Normal-Normal kl divergence, i.e. KL(a||b).
Args:
name (str): name of the function passed in from construct. Should always be "kl_loss".
dist (str): type of the distributions. Should be "Normal" in this case.
mean_b (Tensor): mean of distribution b.
sd_b (Tensor): standard deviation of distribution b.
mean_a (Tensor): mean of distribution a. Default: self._mean_value.
sd_a (Tensor): standard deviation of distribution a. Default: self._sd_value.
.. math::
KL(a||b) = 0.5 * (\frac{MEAN(a)}{STD(b)} - \frac{MEAN(b)}{STD(b)})^2 +
0.5 * EXPM1(2 * (\log(STD(a)) - \log(STD(b)))) - (\log(STD(a)) - \log(STD(b)))
"""
if name == 'kl_loss' and dist == 'Normal':
mean_a = self._mean_value if mean_a is None else mean_a
sd_a = self._sd_value if sd_a is None else sd_a
diff_log_scale = self.add(self.log(sd_a), - self.log(sd_b))
squared_diff = self.sq(self.add(self.realdiv(mean_a, sd_b), - self.realdiv(mean_b, sd_b)))
return self.add(self.add(0.5 * squared_diff, 0.5 * self.expm1(2 * diff_log_scale)), - diff_log_scale)
return None
def _sample(self, name, shape=(), mean=None, sd=None):
"""
Sampling.
Args:
name (str): name of the function. Should always be 'sample' when passed in from construct.
shape (tuple): shape of the sample. Default: ().
mean (Tensor): mean of the samples. Default: self._mean_value.
sd (Tensor): standard deviation of the samples. Default: self._sd_value.
Returns:
Tensor, shape is shape + batch_shape.
"""
if name == 'sample':
mean = self._mean_value if mean is None else mean
sd = self._sd_value if sd is None else sd
batch_shape = self.shape(self.add(self.zeroslike(mean), self.zeroslike(sd)))
sample_shape = shape + batch_shape
mean_zero = self.const(0.0)
sd_one = self.const(1.0)
sample_norm = self.normal(sample_shape, mean_zero, sd_one)
sample = self.add(mean, self.mul(sample_norm, sd))
return sample
return None
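
The closed-form KL above can be sanity-checked against numerical integration of KL(a||b) = E_a[log p_a(X) - log p_b(X)]. A SciPy sketch (illustrative only, not part of this patch), using the same parameters as the kl_loss test below:

import numpy as np
from scipy import stats
from scipy.integrate import quad

mean_a, sd_a, mean_b, sd_b = 3.0, 4.0, 1.0, 1.0

# Closed form, written exactly as in Normal._kl_loss.
diff_log_scale = np.log(sd_a) - np.log(sd_b)
closed_form = (0.5 * np.square(mean_a / sd_b - mean_b / sd_b)
               + 0.5 * np.expm1(2 * diff_log_scale) - diff_log_scale)

# Numerical integration of the KL integrand.
pa, pb = stats.norm(mean_a, sd_a), stats.norm(mean_b, sd_b)
numeric, _ = quad(lambda x: pa.pdf(x) * (pa.logpdf(x) - pb.logpdf(x)),
                  -np.inf, np.inf)
print(closed_form, numeric)  # both approximately 8.1137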


@@ -0,0 +1,147 @@
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""test cases for bernoulli distribution"""
import numpy as np
from scipy import stats
import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.common.api import ms_function
from mindspore import dtype
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
class Net(nn.Cell):
"""
Test class: probability of bernoulli distribution.
"""
def __init__(self):
super(Net, self).__init__()
self.b = nn.Bernoulli(0.7, dtype=dtype.int32)
@ms_function
def construct(self, x_):
return self.b('prob', x_)
class Net1(nn.Cell):
"""
Test class: log probability of bernoulli distribution.
"""
def __init__(self):
super(Net1, self).__init__()
self.b = nn.Bernoulli(0.7, dtype=dtype.int32)
@ms_function
def construct(self, x_):
return self.b('log_prob', x_)
class Net2(nn.Cell):
"""
Test class: kl_loss between bernoulli distributions.
"""
def __init__(self):
super(Net2, self).__init__()
self.b = nn.Bernoulli(0.7, dtype=dtype.int32)
@ms_function
def construct(self, x_):
return self.b('kl_loss', 'Bernoulli', x_)
class Net3(nn.Cell):
"""
Test class: mean/sd of bernoulli distribution.
"""
def __init__(self):
super(Net3, self).__init__()
self.b = nn.Bernoulli([0.5, 0.5], dtype=dtype.int32)
@ms_function
def construct(self):
return self.b('mean'), self.b('sd')
class Net4(nn.Cell):
"""
Test class: sample of bernoulli distribution.
"""
def __init__(self, shape, seed=0):
super(Net4, self).__init__()
self.b = nn.Bernoulli([0.7, 0.5], seed=seed, dtype=dtype.int32)
self.shape = shape
@ms_function
def construct(self, probs=None):
return self.b('sample', self.shape, probs)
def test_pmf():
"""
Test pmf.
"""
bernoulli_benchmark = stats.bernoulli(0.7)
expect_pmf = bernoulli_benchmark.pmf([0, 1, 0, 1, 1]).astype(np.float32)
pdf = Net()
x_ = Tensor(np.array([0, 1, 0, 1, 1]).astype(np.int32), dtype=dtype.float32)
output = pdf(x_)
tol = 1e-6
assert (np.abs(output.asnumpy() - expect_pmf) < tol).all()
def test_log_likelihood():
"""
Test log_pmf.
"""
bernoulli_benchmark = stats.bernoulli(0.7)
expect_logpmf = bernoulli_benchmark.logpmf([0, 1, 0, 1, 1]).astype(np.float32)
logprob = Net1()
x_ = Tensor(np.array([0, 1, 0, 1, 1]).astype(np.int32), dtype=dtype.float32)
output = logprob(x_)
tol = 1e-6
assert (np.abs(output.asnumpy() - expect_logpmf) < tol).all()
def test_kl_loss():
"""
Test kl_loss.
"""
probs1_a = 0.7
probs1_b = 0.5
probs0_a = 1 - probs1_a
probs0_b = 1 - probs1_b
expect_kl_loss = probs1_a * np.log(probs1_a / probs1_b) + probs0_a * np.log(probs0_a / probs0_b)
kl_loss = Net2()
output = kl_loss(Tensor([probs1_b], dtype=dtype.float32))
tol = 1e-6
assert (np.abs(output.asnumpy() - expect_kl_loss) < tol).all()
def test_basics():
"""
Test mean/standard deviation and probs.
"""
basics = Net3()
mean, sd = basics()
expect_mean = [0.5, 0.5]
assert (mean.asnumpy() == expect_mean).all()
assert (sd.asnumpy() == expect_mean).all()
b = nn.Bernoulli([0.7, 0.5], dtype=dtype.int32)
probs = b.probs()
expect_probs = [0.7, 0.5]
tol = 1e-6
assert (np.abs(probs.asnumpy() - expect_probs) < tol).all()
def test_sample():
"""
Test sample.
"""
shape = (2, 3)
sample = Net4(shape)
output = sample()
assert output.shape == (2, 3, 2)


@@ -0,0 +1,152 @@
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""test cases for normal distribution"""
import numpy as np
from scipy import stats
import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.common.api import ms_function
from mindspore import dtype
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
class Net(nn.Cell):
"""
Test class: probability of normal distribution.
"""
def __init__(self):
super(Net, self).__init__()
self.n = nn.Normal(np.array([3.0]), np.array([[2.0], [4.0]]), dtype=dtype.float32)
@ms_function
def construct(self, x_):
return self.n('prob', x_)
class Net1(nn.Cell):
"""
Test class: log probability of normal distribution.
"""
def __init__(self):
super(Net1, self).__init__()
self.n = nn.Normal(np.array([3.0]), np.array([[2.0], [4.0]]), dtype=dtype.float32)
@ms_function
def construct(self, x_):
return self.n('log_prob', x_)
class Net2(nn.Cell):
"""
Test class: kl_loss of normal distribution.
"""
def __init__(self):
super(Net2, self).__init__()
self.n = nn.Normal(np.array([3.0]), np.array([4.0]), dtype=dtype.float32)
@ms_function
def construct(self, x_, y_):
return self.n('kl_loss', 'Normal', x_, y_)
class Net3(nn.Cell):
"""
Test class: mean/sd of normal distribution.
"""
def __init__(self):
super(Net3, self).__init__()
self.n = nn.Normal(np.array([3.0]), np.array([2.0, 4.0]), dtype=dtype.float32)
@ms_function
def construct(self):
return self.n('mean'), self.n('sd')
class Net4(nn.Cell):
"""
Test class: sample of normal distribution.
"""
def __init__(self, shape, seed=0):
super(Net4, self).__init__()
self.n = nn.Normal(np.array([3.0]), np.array([[2.0], [4.0]]), seed=seed, dtype=dtype.float32)
self.shape = shape
@ms_function
def construct(self, mean=None, sd=None):
return self.n('sample', self.shape, mean, sd)
def test_pdf():
"""
Test pdf.
"""
norm_benchmark = stats.norm(np.array([3.0]), np.array([[2.0], [4.0]]))
expect_pdf = norm_benchmark.pdf([1.0, 2.0]).astype(np.float32)
pdf = Net()
output = pdf(Tensor([1.0, 2.0], dtype=dtype.float32))
tol = 1e-6
assert (np.abs(output.asnumpy() - expect_pdf) < tol).all()
def test_log_likelihood():
"""
Test log_pdf.
"""
norm_benchmark = stats.norm(np.array([3.0]), np.array([[2.0], [4.0]]))
expect_logpdf = norm_benchmark.logpdf([1.0, 2.0]).astype(np.float32)
logprob = Net1()
output = logprob(Tensor([1.0, 2.0], dtype=dtype.float32))
tol = 1e-6
assert (np.abs(output.asnumpy() - expect_logpdf) < tol).all()
def test_kl_loss():
"""
Test kl_loss.
"""
mean_a = np.array([3.0]).astype(np.float32)
sd_a = np.array([4.0]).astype(np.float32)
mean_b = np.array([1.0]).astype(np.float32)
sd_b = np.array([1.0]).astype(np.float32)
diff_log_scale = np.log(sd_a) - np.log(sd_b)
squared_diff = np.square(mean_a / sd_b - mean_b / sd_b)
expect_kl_loss = 0.5 * squared_diff + 0.5 * np.expm1(2 * diff_log_scale) - diff_log_scale
kl_loss = Net2()
mean = Tensor(mean_b, dtype=dtype.float32)
sd = Tensor(sd_b, dtype=dtype.float32)
output = kl_loss(mean, sd)
tol = 1e-6
assert (np.abs(output.asnumpy() - expect_kl_loss) < tol).all()
def test_basics():
"""
Test mean/standard deviation.
"""
basics = Net3()
mean, sd = basics()
expect_mean = [3.0, 3.0]
expect_sd = [2.0, 4.0]
tol = 1e-6
assert (np.abs(mean.asnumpy() - expect_mean) < tol).all()
assert (np.abs(sd.asnumpy() - expect_sd) < tol).all()
def test_sample():
"""
Test sample.
"""
shape = (2, 3)
seed = 10
mean = Tensor([2.0], dtype=dtype.float32)
sd = Tensor([2.0, 2.0, 2.0], dtype=dtype.float32)
sample = Net4(shape, seed=seed)
output = sample(mean, sd)
assert output.shape == (2, 3, 3)


@@ -0,0 +1,369 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Test nn.Distribution.
Includes the Normal and Bernoulli distributions.
"""
import pytest
import numpy as np
import mindspore.nn as nn
from mindspore import dtype
from mindspore import Tensor
def test_normal_shape_error():
"""
Invalid shapes.
"""
with pytest.raises(ValueError):
nn.Normal([[2.], [1.]], [[2.], [3.], [4.]], dtype=dtype.float32)
def test_no_arguments():
"""
No args passed in during initialization.
"""
n = nn.Normal()
assert isinstance(n, nn.Distribution)
b = nn.Bernoulli()
assert isinstance(b, nn.Distribution)
def test_with_arguments():
"""
Args passed in during initialization.
"""
n = nn.Normal([3.0], [4.0], dtype=dtype.float32)
assert isinstance(n, nn.Distribution)
b = nn.Bernoulli([0.3, 0.5], dtype=dtype.int32)
assert isinstance(b, nn.Distribution)
class NormalProb(nn.Cell):
"""
Normal distribution: initialize with mean/sd.
"""
def __init__(self):
super(NormalProb, self).__init__()
self.normal = nn.Normal(3.0, 4.0, dtype=dtype.float32)
def construct(self, value):
x = self.normal('prob', value)
y = self.normal('log_prob', value)
return x, y
def test_normal_prob():
"""
Test pdf/log_pdf: passing value through construct.
"""
net = NormalProb()
value = Tensor([0.5, 1.0], dtype=dtype.float32)
pdf, log_pdf = net(value)
assert isinstance(pdf, Tensor)
assert isinstance(log_pdf, Tensor)
class NormalProb1(nn.Cell):
"""
Normal distribution: initialize without mean/sd.
"""
def __init__(self):
super(NormalProb1, self).__init__()
self.normal = nn.Normal()
def construct(self, value, mean, sd):
x = self.normal('prob', value, mean, sd)
y = self.normal('log_prob', value, mean, sd)
return x, y
def test_normal_prob1():
"""
Test pdf/logpdf: passing mean/sd, value through construct.
"""
net = NormalProb1()
value = Tensor([0.5, 1.0], dtype=dtype.float32)
mean = Tensor([0.0], dtype=dtype.float32)
sd = Tensor([1.0], dtype=dtype.float32)
pdf, log_pdf = net(value, mean, sd)
assert isinstance(pdf, Tensor)
assert isinstance(log_pdf, Tensor)
class NormalProb2(nn.Cell):
"""
Normal distribution: initialize with mean/sd.
"""
def __init__(self):
super(NormalProb2, self).__init__()
self.normal = nn.Normal(3.0, 4.0, dtype=dtype.float32)
def construct(self, value, mean, sd):
x = self.normal('prob', value, mean, sd)
y = self.normal('log_prob', value, mean, sd)
return x, y
def test_normal_prob2():
"""
Test pdf/log_pdf: passing mean/sd through construct.
Overwrite original mean/sd.
"""
net = NormalProb2()
value = Tensor([0.5, 1.0], dtype=dtype.float32)
mean = Tensor([0.0], dtype=dtype.float32)
sd = Tensor([1.0], dtype=dtype.float32)
pdf, log_pdf = net(value, mean, sd)
assert isinstance(pdf, Tensor)
assert isinstance(log_pdf, Tensor)
class BernoulliProb(nn.Cell):
"""
Bernoulli distribution: initialize with probs.
"""
def __init__(self):
super(BernoulliProb, self).__init__()
self.bernoulli = nn.Bernoulli(0.5, dtype=dtype.int32)
def construct(self, value):
return self.bernoulli('prob', value)
class BernoulliLogProb(nn.Cell):
"""
Bernoulli distribution: initialize with probs.
"""
def __init__(self):
super(BernoulliLogProb, self).__init__()
self.bernoulli = nn.Bernoulli(0.5, dtype=dtype.int32)
def construct(self, value):
return self.bernoulli('log_prob', value)
def test_bernoulli_prob():
"""
Test pmf/log_pmf: passing value through construct.
"""
net = BernoulliProb()
value = Tensor([1, 0, 1, 0, 1], dtype=dtype.float32)
pmf = net(value)
assert isinstance(pmf, Tensor)
def test_bernoulli_log_prob():
"""
Test pmf/log_pmf: passing value through construct.
"""
net = BernoulliLogProb()
value = Tensor([1, 0, 1, 0, 1], dtype=dtype.float32)
log_pmf = net(value)
assert isinstance(log_pmf, Tensor)
class BernoulliProb1(nn.Cell):
"""
Bernoulli distribution: initialize without probs.
"""
def __init__(self):
super(BernoulliProb1, self).__init__()
self.bernoulli = nn.Bernoulli()
def construct(self, value, probs):
return self.bernoulli('prob', value, probs)
class BernoulliLogProb1(nn.Cell):
"""
Bernoulli distribution: initialize without probs.
"""
def __init__(self):
super(BernoulliLogProb1, self).__init__()
self.bernoulli = nn.Bernoulli()
def construct(self, value, probs):
return self.bernoulli('log_prob', value, probs)
def test_bernoulli_prob1():
"""
Test pmf/log_pmf: passing probs through construct.
"""
net = BernoulliProb1()
value = Tensor([1, 0, 1, 0, 1], dtype=dtype.float32)
probs = Tensor([0.3], dtype=dtype.float32)
pmf = net(value, probs)
assert isinstance(pmf, Tensor)
def test_bernoulli_log_prob1():
"""
Test pmf/log_pmf: passing probs through construct.
"""
net = BernoulliLogProb1()
value = Tensor([1, 0, 1, 0, 1], dtype=dtype.float32)
probs = Tensor([0.3], dtype=dtype.float32)
log_pmf = net(value, probs)
assert isinstance(log_pmf, Tensor)
class BernoulliProb2(nn.Cell):
"""
Bernoulli distribution: initialize with probs.
"""
def __init__(self):
super(BernoulliProb2, self).__init__()
self.bernoulli = nn.Bernoulli(0.5)
def construct(self, value, probs):
return self.bernoulli('prob', value, probs)
class BernoulliLogProb2(nn.Cell):
"""
Bernoulli distribution: initialize with probs.
"""
def __init__(self):
super(BernoulliLogProb2, self).__init__()
self.bernoulli = nn.Bernoulli(0.5)
def construct(self, value, probs):
return self.bernoulli('log_prob', value, probs)
def test_bernoulli_prob2():
"""
Test pmf/log_pmf: passing probs/value through construct.
Overwrite original probs.
"""
net = BernoulliProb2()
value = Tensor([1, 0, 1, 0, 1], dtype=dtype.float32)
probs = Tensor([0.3], dtype=dtype.float32)
pmf = net(value, probs)
assert isinstance(pmf, Tensor)
def test_bernoulli_log_prob2():
"""
Test pmf/log_pmf: passing probs/value through construct.
Overwrite original probs.
"""
net = BernoulliLogProb2()
value = Tensor([1, 0, 1, 0, 1], dtype=dtype.float32)
probs = Tensor([0.3], dtype=dtype.float32)
log_pmf = net(value, probs)
assert isinstance(log_pmf, Tensor)
class NormalKl(nn.Cell):
"""
Test class: kl_loss of Normal distribution.
"""
def __init__(self):
super(NormalKl, self).__init__()
self.n = nn.Normal(np.array([3.0]), np.array([4.0]), dtype=dtype.float32)
def construct(self, x_, y_):
return self.n('kl_loss', 'Normal', x_, y_)
class BernoulliKl(nn.Cell):
"""
Test class: kl_loss between Bernoulli distributions.
"""
def __init__(self):
super(BernoulliKl, self).__init__()
self.b = nn.Bernoulli(0.7, dtype=dtype.int32)
def construct(self, x_):
return self.b('kl_loss', 'Bernoulli', x_)
def test_kl():
"""
Test kl_loss function.
"""
nor_net = NormalKl()
mean_b = np.array([1.0]).astype(np.float32)
sd_b = np.array([1.0]).astype(np.float32)
mean = Tensor(mean_b, dtype=dtype.float32)
sd = Tensor(sd_b, dtype=dtype.float32)
loss = nor_net(mean, sd)
assert isinstance(loss, Tensor)
ber_net = BernoulliKl()
probs_b = Tensor([0.3], dtype=dtype.float32)
loss = ber_net(probs_b)
assert isinstance(loss, Tensor)
class NormalKlNoArgs(nn.Cell):
"""
Test class: kl_loss of Normal distribution.
No args during initialization.
"""
def __init__(self):
super(NormalKlNoArgs, self).__init__()
self.n = nn.Normal(dtype=dtype.float32)
def construct(self, x_, y_, w_, v_):
return self.n('kl_loss', 'Normal', x_, y_, w_, v_)
class BernoulliKlNoArgs(nn.Cell):
"""
Test class: kl_loss between Bernoulli distributions.
No args during initialization.
"""
def __init__(self):
super(BernoulliKlNoArgs, self).__init__()
self.b = nn.Bernoulli(dtype=dtype.int32)
def construct(self, x_, y_):
return self.b('kl_loss', 'Bernoulli', x_, y_)
def test_kl_no_args():
"""
Test kl_loss function.
"""
nor_net = NormalKlNoArgs()
mean_b = np.array([1.0]).astype(np.float32)
sd_b = np.array([1.0]).astype(np.float32)
mean_a = np.array([2.0]).astype(np.float32)
sd_a = np.array([3.0]).astype(np.float32)
mean_b = Tensor(mean_b, dtype=dtype.float32)
sd_b = Tensor(sd_b, dtype=dtype.float32)
mean_a = Tensor(mean_a, dtype=dtype.float32)
sd_a = Tensor(sd_a, dtype=dtype.float32)
loss = nor_net(mean_b, sd_b, mean_a, sd_a)
assert isinstance(loss, Tensor)
ber_net = BernoulliKlNoArgs()
probs_b = Tensor([0.3], dtype=dtype.float32)
probs_a = Tensor([0.7], dtype=dtype.float32)
loss = ber_net(probs_b, probs_a)
assert isinstance(loss, Tensor)
class NormalBernoulli(nn.Cell):
"""
Test class: basic mean/sd function.
"""
def __init__(self):
super(NormalBernoulli, self).__init__()
self.n = nn.Normal(3.0, 4.0, dtype=dtype.float32)
self.b = nn.Bernoulli(0.5, dtype=dtype.int32)
def construct(self):
normal_mean = self.n('mean')
normal_sd = self.n('sd')
bernoulli_mean = self.b('mean')
bernoulli_sd = self.b('sd')
return normal_mean, normal_sd, bernoulli_mean, bernoulli_sd
def test_basics():
"""
Test mean/sd functionality of Normal and Bernoulli.
"""
net = NormalBernoulli()
normal_mean, normal_sd, bernoulli_mean, bernoulli_sd = net()
assert isinstance(normal_mean, Tensor)
assert isinstance(normal_sd, Tensor)
assert isinstance(bernoulli_mean, Tensor)
assert isinstance(bernoulli_sd, Tensor)