!5932 Modify API comments

Merge pull request !5932 from byweng/master
Commit 36f370b72f by mindspore-ci-bot, 2020-09-11 09:23:47 +08:00, committed by Gitee
29 changed files with 314 additions and 307 deletions


@@ -124,7 +124,7 @@ if __name__ == "__main__":
     epoch = 100
     for i in range(epoch):
-        train_loss, train_acc = train_model(train_bnn_network, test_set)
+        train_loss, train_acc = train_model(train_bnn_network, network, train_set)
         valid_acc = validate_model(network, test_set)


@@ -13,9 +13,7 @@
 # limitations under the License.
 # ============================================================================
 """
-Bijector.
-The high-level components(Bijectors) used to construct the probabilistic network.
+Bijectors are the high-level components used to construct the probabilistic network.
 """
 from .bijector import Bijector


@@ -25,11 +25,11 @@ class Bijector(Cell):
     Bijecotr class.

     Args:
-        is_constant_jacobian (bool): if the bijector has constant derivative. Default: False.
-        is_injective (bool): if the bijector is an one-to-one mapping. Default: True.
-        name (str): name of the bijector. Default: None.
-        dtype (mindspore.dtype): type of the distribution the bijector can operate on. Default: None.
-        param (dict): parameters used to initialize the bijector. Default: None.
+        is_constant_jacobian (bool): Whether the Bijector has constant derivative. Default: False.
+        is_injective (bool): Whether the Bijector is a one-to-one mapping. Default: True.
+        name (str): The name of the Bijector. Default: None.
+        dtype (mindspore.dtype): The type of the distribution the Bijector can operate on. Default: None.
+        param (dict): The parameters used to initialize the Bijector. Default: None.
     """
     def __init__(self,
                  is_constant_jacobian=False,
@@ -82,7 +82,7 @@ class Bijector(Cell):
     def _check_value(self, value, name):
         """
-        Check availability fo value as a Tensor.
+        Check availability of `value` as a Tensor.
         """
         if self.context_mode == 0:
             self.checktensor(value, name)
@@ -119,11 +119,11 @@ class Bijector(Cell):
         This __call__ may go into two directions:
         If args[0] is a distribution instance, the call will generate a new distribution derived from
         the input distribution.
-        Otherwise, input[0] should be the name of a bijector function, e.g. "forward", then this call will
-        go in the construct and invoke the correstpoding bijector function.
+        Otherwise, input[0] should be the name of a Bijector function, e.g. "forward", then this call will
+        go in the construct and invoke the correstpoding Bijector function.

         Args:
-            *args: args[0] shall be either a distribution or the name of a bijector function.
+            *args: args[0] shall be either a distribution or the name of a Bijector function.
         """
         if isinstance(args[0], Distribution):
             return TransformedDistribution(self, args[0], self.distribution.dtype)
@@ -131,16 +131,16 @@ class Bijector(Cell):
     def construct(self, name, *args, **kwargs):
         """
-        Override construct in Cell.
+        Override `construct` in Cell.

         Note:
             Names of supported functions include:
-            'forward', 'inverse', 'forward_log_jacobian', 'inverse_log_jacobian'.
+            'forward', 'inverse', 'forward_log_jacobian', and 'inverse_log_jacobian'.

         Args:
-            name (str): name of the function.
-            *args (list): list of positional arguments needed for the function.
-            **kwargs (dictionary): dictionary of keyword arguments needed for the function.
+            name (str): The name of the function.
+            *args (list): A list of positional arguments that the function needs.
+            **kwargs (dictionary): A dictionary of keyword arguments that the function needs.
         """
         if name == 'forward':
             return self.forward(*args, **kwargs)
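A minimal sketch (not part of the commit) of the two call paths the construct docstring above describes, reusing the Exp bijector from this patch's own Examples; the tensor values are illustrative:

    >>> import mindspore.nn.probability.bijector as msb
    >>> from mindspore import Tensor
    >>> import mindspore.common.dtype as mstype
    >>> bij = msb.Exp()
    >>> x = Tensor([1.0, 2.0], dtype=mstype.float32)
    >>> y_direct = bij.forward(x)    # call the bijector function directly
    >>> y_named = bij('forward', x)  # name dispatch: routed through construct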


@@ -18,17 +18,20 @@ from .power_transform import PowerTransform
 class Exp(PowerTransform):
     r"""
     Exponential Bijector.
-    This Bijector performs the operation: Y = exp(x).
+    This Bijector performs the operation:
+
+    .. math::
+        Y = exp(x).

     Args:
-        name (str): name of the bijector. Default: 'Exp'.
+        name (str): The name of the Bijector. Default: 'Exp'.

     Examples:
-        >>> # To initialize a Exp bijector
+        >>> # To initialize an Exp bijector
         >>> import mindspore.nn.probability.bijector as msb
         >>> n = msb.Exp()
         >>>
-        >>> # To use Exp distribution in a network
+        >>> # To use Exp bijector in a network
         >>> class net(Cell):
         >>>     def __init__(self):
         >>>         super(net, self).__init__():
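A hedged round-trip sketch of the Exp bijector, using the `msb` import from the Examples above; the input values are illustrative:

    >>> import mindspore.nn.probability.bijector as msb
    >>> from mindspore import Tensor
    >>> import mindspore.common.dtype as mstype
    >>> exp = msb.Exp()
    >>> x = Tensor([0.0, 1.0, 2.0], dtype=mstype.float32)
    >>> y = exp.forward(x)        # elementwise exp(x)
    >>> x_back = exp.inverse(y)   # recovers x, since Exp is injective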


@@ -23,20 +23,25 @@ from .bijector import Bijector
 class PowerTransform(Bijector):
     r"""
     Power Bijector.
-    This Bijector performs the operation: Y = g(X) = (1 + X * c)^(1 / c), X >= -1 / c, where c >= 0 is the power.
+    This Bijector performs the operation:
+
+    .. math::
+        Y = g(X) = (1 + X * c)^{1 / c}, X >= -1 / c
+
+    where c >= 0 is the power.

     The power transform maps inputs from `[-1/c, inf]` to `[0, inf]`.
-    This bijector is equivalent to the `Exp` bijector when `c=0`
+    This Bijector is equivalent to the `Exp` bijector when `c=0`

     Raises:
         ValueError: If the power is less than 0 or is not known statically.

     Args:
-        power (int or float): scale factor. Default: 0.
-        name (str): name of the bijector. Default: 'PowerTransform'.
-        param (dict): parameters used to initialize the bijector. This is only used when other bijectors that inherits
-            from powertransform passing in parameters. In this case the derived bijector may overwrite the param args.
+        power (int or float): The scale factor. Default: 0.
+        name (str): The name of the bijector. Default: 'PowerTransform'.
+        param (dict): The parameters used to initialize the bijector. These parameters are only used when other
+            Bijectors inherit from powertransform to pass in parameters. In this case the derived Bijector may
+            overwrite the argument `param`.
             Default: None.

     Examples:
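A small sketch of the power transform above; `power=2.0` is an arbitrary illustrative choice and the inputs respect the `X >= -1/c` constraint:

    >>> import mindspore.nn.probability.bijector as msb
    >>> from mindspore import Tensor
    >>> import mindspore.common.dtype as mstype
    >>> pt = msb.PowerTransform(power=2.0)
    >>> x = Tensor([0.0, 1.5, 4.0], dtype=mstype.float32)
    >>> y = pt.forward(x)   # (1 + 2*x)^(1/2), mapping [-1/2, inf] to [0, inf]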


@@ -23,13 +23,16 @@ from .bijector import Bijector
 class ScalarAffine(Bijector):
     """
     Scalar Affine Bijector.
-    This Bijector performs the operation: Y = a * X + b, where a is the scale
-    factor and b is the shift factor.
+    This Bijector performs the operation:
+
+    .. math::
+        Y = a * X + b
+
+    where a is the scale factor and b is the shift factor.

     Args:
-        scale (float): scale factor. Default: 1.0.
-        shift (float): shift factor. Default: 0.0.
-        name (str): name of the bijector. Default: 'ScalarAffine'.
+        scale (float): The scale factor. Default: 1.0.
+        shift (float): The shift factor. Default: 0.0.
+        name (str): The name of the bijector. Default: 'ScalarAffine'.

     Examples:
         >>> # To initialize a ScalarAffine bijector of scale 1 and shift 2
@@ -55,7 +58,7 @@ class ScalarAffine(Bijector):
                  shift=0.0,
                  name='ScalarAffine'):
         """
-        Constructor of scalar affine bijector.
+        Constructor of scalar affine Bijector.
         """
         param = dict(locals())
         validator.check_value_type('scale', scale, [int, float], type(self).__name__)
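Following the Examples comment above (scale 1, shift 2), a hedged sketch with illustrative inputs:

    >>> import mindspore.nn.probability.bijector as msb
    >>> from mindspore import Tensor
    >>> import mindspore.common.dtype as mstype
    >>> sa = msb.ScalarAffine(scale=1.0, shift=2.0)
    >>> x = Tensor([1.0, 2.0, 3.0], dtype=mstype.float32)
    >>> y = sa.forward(x)       # a * x + b = 1.0 * x + 2.0
    >>> x_back = sa.inverse(y)  # (y - b) / a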


@@ -26,14 +26,15 @@ from .bijector import Bijector
 class Softplus(Bijector):
     r"""
     Softplus Bijector.
-    This Bijector performs the operation, where k is the sharpness factor.
+    This Bijector performs the operation:

     .. math::
         Y = \frac{\log(1 + e ^ {kX})}{k}
+
+    where k is the sharpness factor.

     Args:
-        sharpness (float): scale factor. Default: 1.0.
-        name (str): name of the bijector. Default: 'Softplus'.
+        sharpness (float): The scale factor. Default: 1.0.
+        name (str): The name of the Bijector. Default: 'Softplus'.

     Examples:
         >>> # To initialize a Softplus bijector of sharpness 2
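A sketch matching the Examples comment above (sharpness 2); the inputs are illustrative:

    >>> import mindspore.nn.probability.bijector as msb
    >>> from mindspore import Tensor
    >>> import mindspore.common.dtype as mstype
    >>> sp = msb.Softplus(sharpness=2.0)
    >>> x = Tensor([0.0, 1.0], dtype=mstype.float32)
    >>> y = sp.forward(x)   # log(1 + e^(2*x)) / 2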


@@ -13,9 +13,7 @@
 # limitations under the License.
 # ============================================================================
 """
-Bayesian Layer.
-The high-level components(Cells) used to construct the bayesian neural network.
+`bnn_layers` are the high-level components used to construct the bayesian neural network.
 """
 from . import conv_variational, dense_variational, layer_distribution, bnn_cell_wrapper


@@ -39,13 +39,13 @@ class ClassWrap:
 @ClassWrap
 class WithBNNLossCell:
     r"""
-    Generate WithLossCell suitable for BNN.
+    Generate a suitable WithLossCell for BNN to wrap the bayesian network with loss function.

     Args:
         backbone (Cell): The target network.
         loss_fn (Cell): The loss function used to compute loss.
-        dnn_factor(int, float): The coefficient of backbone's loss, which is computed by loss functin. Default: 1.
-        bnn_factor(int, float): The coefficient of kl loss, which is kl divergence of Bayesian layer. Default: 1.
+        dnn_factor(int, float): The coefficient of backbone's loss, which is computed by the loss function. Default: 1.
+        bnn_factor(int, float): The coefficient of KL loss, which is the KL divergence of Bayesian layer. Default: 1.

     Inputs:
         - **data** (Tensor) - Tensor of shape :math:`(N, \ldots)`.
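A hedged wiring sketch for the wrapper documented above; the backbone, the loss function, and the assumption that WithBNNLossCell is importable from bnn_layers are all illustrative:

    >>> from mindspore import nn
    >>> from mindspore.nn.probability import bnn_layers
    >>> backbone = nn.Dense(16, 10)   # hypothetical target network
    >>> loss_fn = nn.SoftmaxCrossEntropyWithLogits(sparse=True)
    >>> net_with_loss = bnn_layers.WithBNNLossCell(backbone, loss_fn, dnn_factor=1, bnn_factor=1)
    >>> # net_with_loss(data, label) returns the backbone loss plus the scaled KL loss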


@@ -178,34 +178,34 @@ class ConvReparam(_ConvVariational):
     r"""
     Convolutional variational layers with Reparameterization.

-    See more details in paper `Auto-Encoding Variational Bayes <https://arxiv.org/abs/1312.6114>`_.
+    For more details, refer to the paper `Auto-Encoding Variational Bayes <https://arxiv.org/abs/1312.6114>`_.

     Args:
         in_channels (int): The number of input channel :math:`C_{in}`.
         out_channels (int): The number of output channel :math:`C_{out}`.
-        kernel_size (Union[int, tuple[int]]): The data type is int or
-            tuple with 2 integers. Specifies the height and width of the 2D
-            convolution window. Single int means the value if for both
-            height and width of the kernel. A tuple of 2 ints means the
-            first value is for the height and the other is for the width of
-            the kernel.
+        kernel_size (Union[int, tuple[int]]): The data type is an integer or
+            a tuple of 2 integers. The kernel size specifies the height and
+            width of the 2D convolution window. a single integer stands for the
+            value is for both height and width of the kernel. With the `kernel_size`
+            being a tuple of 2 integers, the first value is for the height and the other
+            is the width of the kernel.
         stride(Union[int, tuple[int]]): The distance of kernel moving,
-            an int number that represents the height and width of movement
-            are both strides, or a tuple of two int numbers that represent
+            an integer number represents that the height and width of movement
+            are both strides, or a tuple of two integers numbers represents that
             height and width of movement respectively. Default: 1.
-        pad_mode (str): Specifies padding mode. The optional values are
-            "same", "valid", "pad". Default: "same".
+        pad_mode (str): Specifies the padding mode. The optional values are
+            "same", "valid", and "pad". Default: "same".

             - same: Adopts the way of completion. Output height and width
               will be the same as the input.
-              Total number of padding will be calculated for horizontal and
-              vertical direction and evenly distributed to top and bottom,
+              The total number of padding will be calculated for in horizontal and
+              vertical directions and evenly distributed to top and bottom,
               left and right if possible. Otherwise, the last extra padding
               will be done from the bottom and the right side. If this mode
               is set, `padding` must be 0.

-            - valid: Adopts the way of discarding. The possibly largest
-              height and width of output will be return without padding.
+            - valid: Adopts the way of discarding. The possible largest
+              height and width of the output will be returned without padding.
               Extra pixels will be discarded. If this mode is set, `padding`
               must be 0.
@@ -215,9 +215,9 @@ class ConvReparam(_ConvVariational):
         padding (Union[int, tuple[int]]): Implicit paddings on both sides of
             the input. Default: 0.
-        dilation (Union[int, tuple[int]]): The data type is int or tuple
-            with 2 integers. Specifies the dilation rate to use for dilated
-            convolution. If set to be :math:`k > 1`,
+        dilation (Union[int, tuple[int]]): The data type is an integer or a tuple
+            of 2 integers. This parameter specifies the dilation rate of the
+            dilated convolution. If set to be :math:`k > 1`,
             there will be :math:`k - 1` pixels skipped for each sampling
             location. Its value should be greater or equal to 1 and bounded
             by the height and width of the input. Default: 1.
@@ -226,28 +226,28 @@ class ConvReparam(_ConvVariational):
             Default: 1.
         has_bias (bool): Specifies whether the layer uses a bias vector.
             Default: False.
-        weight_prior_fn: prior distribution for weight.
+        weight_prior_fn: The prior distribution for weight.
             It should return a mindspore distribution instance.
             Default: NormalPrior. (which creates an instance of standard
             normal distribution). The current version only supports normal distribution.
-        weight_posterior_fn: posterior distribution for sampling weight.
+        weight_posterior_fn: The posterior distribution for sampling weight.
             It should be a function handle which returns a mindspore
             distribution instance. Default: lambda name, shape: NormalPosterior(name=name, shape=shape).
             The current version only supports normal distribution.
-        bias_prior_fn: prior distribution for bias vector. It should return
+        bias_prior_fn: The prior distribution for bias vector. It should return
             a mindspore distribution. Default: NormalPrior(which creates an
             instance of standard normal distribution). The current version
             only supports normal distribution.
-        bias_posterior_fn: posterior distribution for sampling bias vector.
+        bias_posterior_fn: The posterior distribution for sampling bias vector.
             It should be a function handle which returns a mindspore
             distribution instance. Default: lambda name, shape: NormalPosterior(name=name, shape=shape).
             The current version only supports normal distribution.

     Inputs:
-        - **input** (Tensor) - Tensor of shape :math:`(N, C_{in}, H_{in}, W_{in})`.
+        - **input** (Tensor) - The shape of the tensor is :math:`(N, C_{in}, H_{in}, W_{in})`.

     Outputs:
-        Tensor of shape :math:`(N, C_{out}, H_{out}, W_{out})`.
+        Tensor, with the shape being :math:`(N, C_{out}, H_{out}, W_{out})`.

     Examples:
         >>> net = ConvReparam(120, 240, 4, has_bias=False)
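Extending the Examples line above with the documented input/output shapes; the import path is assumed and the 32x32 spatial size is an arbitrary illustrative choice:

    >>> import numpy as np
    >>> from mindspore import Tensor
    >>> import mindspore.common.dtype as mstype
    >>> from mindspore.nn.probability.bnn_layers import ConvReparam  # assumed export path
    >>> net = ConvReparam(120, 240, 4, has_bias=False)
    >>> x = Tensor(np.ones([1, 120, 32, 32]), mstype.float32)  # (N, C_in, H_in, W_in)
    >>> out = net(x)   # shape (1, 240, H_out, W_out)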


@@ -143,12 +143,12 @@ class DenseReparam(_DenseVariational):
     r"""
     Dense variational layers with Reparameterization.

-    See more details in paper `Auto-Encoding Variational Bayes <https://arxiv.org/abs/1312.6114>`_.
+    For more details, refer to the paper `Auto-Encoding Variational Bayes <https://arxiv.org/abs/1312.6114>`_.

-    Applies dense-connected layer for the input. This layer implements the operation as:
+    Applies dense-connected layer to the input. This layer implements the operation as:

     .. math::
-        \text{outputs} = \text{activation}(\text{inputs} * \text{kernel} + \text{bias}),
+        \text{outputs} = \text{activation}(\text{inputs} * \text{weight} + \text{bias}),

     where :math:`\text{activation}` is the activation function passed as the activation
     argument (if passed in), :math:`\text{activation}` is a weight matrix with the same
@@ -162,31 +162,31 @@ class DenseReparam(_DenseVariational):
         in_channels (int): The number of input channel.
         out_channels (int): The number of output channel .
         has_bias (bool): Specifies whether the layer uses a bias vector. Default: False.
-        activation (str, Cell): Regularizer function applied to the output of the layer. The type of activation can
-            be str (eg. 'relu') or Cell (eg. nn.ReLU()). Note that if the type of activation is Cell, it must have been
-            instantiated. Default: None.
+        activation (str, Cell): A regularization function applied to the output of the layer. The type of `activation`
+            can be a string (eg. 'relu') or a Cell (eg. nn.ReLU()). Note that if the type of activation is Cell, it must
+            be instantiated beforehand. Default: None.
-        weight_prior_fn: prior distribution for weight.
+        weight_prior_fn: The prior distribution for weight.
             It should return a mindspore distribution instance.
             Default: NormalPrior. (which creates an instance of standard
             normal distribution). The current version only supports normal distribution.
-        weight_posterior_fn: posterior distribution for sampling weight.
+        weight_posterior_fn: The posterior distribution for sampling weight.
             It should be a function handle which returns a mindspore
             distribution instance. Default: lambda name, shape: NormalPosterior(name=name, shape=shape).
             The current version only supports normal distribution.
-        bias_prior_fn: prior distribution for bias vector. It should return
+        bias_prior_fn: The prior distribution for bias vector. It should return
             a mindspore distribution. Default: NormalPrior(which creates an
             instance of standard normal distribution). The current version
             only supports normal distribution.
-        bias_posterior_fn: posterior distribution for sampling bias vector.
+        bias_posterior_fn: The posterior distribution for sampling bias vector.
             It should be a function handle which returns a mindspore
             distribution instance. Default: lambda name, shape: NormalPosterior(name=name, shape=shape).
             The current version only supports normal distribution.

     Inputs:
-        - **input** (Tensor) - Tensor of shape :math:`(N, in\_channels)`.
+        - **input** (Tensor) - The shape of the tensor is :math:`(N, in\_channels)`.

     Outputs:
-        Tensor of shape :math:`(N, out\_channels)`.
+        Tensor, the shape of the tensor is :math:`(N, out\_channels)`.

     Examples:
         >>> net = DenseReparam(3, 4)
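Extending the Examples line above with the documented shapes; the import path is assumed and the batch size is illustrative:

    >>> import numpy as np
    >>> from mindspore import Tensor
    >>> import mindspore.common.dtype as mstype
    >>> from mindspore.nn.probability.bnn_layers import DenseReparam  # assumed export path
    >>> net = DenseReparam(3, 4)
    >>> x = Tensor(np.ones([2, 3]), mstype.float32)  # (N, in_channels)
    >>> out = net(x)   # (2, 4), i.e. (N, out_channels)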


@@ -13,9 +13,7 @@
 # limitations under the License.
 # ============================================================================
 """
-Distribution.
-The high-level components(Distributions) used to construct the probabilistic network.
+Distributions are the high-level components used to construct the probabilistic network.
 """
 from .distribution import Distribution


@@ -26,16 +26,16 @@ class Bernoulli(Distribution):
     Bernoulli Distribution.

     Args:
-        probs (float, list, numpy.ndarray, Tensor, Parameter): probability of 1 as outcome.
-        seed (int): seed to use in sampling. Global seed is used if it is None. Default: None.
-        dtype (mindspore.dtype): type of the distribution. Default: mstype.int32.
-        name (str): name of the distribution. Default: Bernoulli.
+        probs (float, list, numpy.ndarray, Tensor, Parameter): The probability of that the outcome is 1.
+        seed (int): The global seed is used in sampling. Global seed is used if it is None. Default: None.
+        dtype (mindspore.dtype): The type of the distribution. Default: mstype.int32.
+        name (str): The name of the distribution. Default: Bernoulli.

     Note:
-        probs should be proper probabilities (0 < p < 1).
-        dist_spec_args is probs.
+        `probs` should be a proper probability (0 < p < 1).
+        dist_spec_args is `probs`.

     Examples:
         >>> # To initialize a Bernoulli distribution of prob 0.5
         >>> import mindspore.nn.probability.distribution as msd
         >>> b = msd.Bernoulli(0.5, dtype=mstype.int32)
@@ -153,13 +153,13 @@ class Bernoulli(Distribution):
     @property
     def probs(self):
         """
-        Returns the probability for the outcome is 1.
+        Return the probability of that the outcome is 1.
         """
         return self._probs

     def _check_param(self, probs1):
         """
-        Check availablity of distribution specific args probs1.
+        Check availablity of distribution specific args `probs1`.
         """
         if probs1 is not None:
             if self.context_mode == 0:
@@ -207,25 +207,25 @@ class Bernoulli(Distribution):
         probs0 = 1.0 - probs1
         return -(probs0 * self.log(probs0)) - (probs1 * self.log(probs1))

-    def _cross_entropy(self, dist, probs1_b, probs1=None):
+    def _cross_entropy(self, dist, probs1_b, probs1_a=None):
         """
         Evaluate cross_entropy between Bernoulli distributions.

         Args:
-            dist (str): type of the distributions. Should be "Bernoulli" in this case.
-            probs1_b (Tensor): probs1 of distribution b.
-            probs1_a (Tensor): probs1 of distribution a. Default: self.probs.
+            dist (str): The type of the distributions. Should be "Bernoulli" in this case.
+            probs1_b (Tensor): `probs1` of distribution b.
+            probs1_a (Tensor): `probs1` of distribution a. Default: self.probs.
         """
         check_distribution_name(dist, 'Bernoulli')
-        return self._entropy(probs1) + self._kl_loss(dist, probs1_b, probs1)
+        return self._entropy(probs1_a) + self._kl_loss(dist, probs1_b, probs1_a)

     def _log_prob(self, value, probs1=None):
         r"""
         pmf of Bernoulli distribution.

         Args:
-            value (Tensor): a Tensor composed of only zeros and ones.
-            probs (Tensor): probability of outcome is 1. Default: self.probs.
+            value (Tensor): A Tensor composed of only zeros and ones.
+            probs (Tensor): The probability of outcome is 1. Default: self.probs.

         .. math::
             pmf(k) = probs1 if k = 1;
@@ -239,11 +239,11 @@ class Bernoulli(Distribution):
     def _cdf(self, value, probs1=None):
         r"""
-        cdf of Bernoulli distribution.
+        Cumulative distribution function (cdf) of Bernoulli distribution.

         Args:
-            value (Tensor): value to be evaluated.
-            probs (Tensor): probability of outcome is 1. Default: self.probs.
+            value (Tensor): The value to be evaluated.
+            probs (Tensor): The probability of that the outcome is 1. Default: self.probs.

         .. math::
             cdf(k) = 0 if k < 0;
@@ -264,14 +264,14 @@ class Bernoulli(Distribution):
         less_than_zero = self.select(comp_zero, zeros, probs0)
         return self.select(comp_one, less_than_zero, ones)

-    def _kl_loss(self, dist, probs1_b, probs1=None):
+    def _kl_loss(self, dist, probs1_b, probs1_a=None):
         r"""
         Evaluate bernoulli-bernoulli kl divergence, i.e. KL(a||b).

         Args:
-            dist (str): type of the distributions. Should be "Bernoulli" in this case.
-            probs1_b (Tensor, Number): probs1 of distribution b.
-            probs1_a (Tensor, Number): probs1 of distribution a. Default: self.probs.
+            dist (str): The type of the distributions. Should be "Bernoulli" in this case.
+            probs1_b (Tensor, Number): `probs1` of distribution b.
+            probs1_a (Tensor, Number): `probs1` of distribution a. Default: self.probs.

         .. math::
             KL(a||b) = probs1_a * \log(\frac{probs1_a}{probs1_b}) +
@@ -280,7 +280,7 @@ class Bernoulli(Distribution):
         check_distribution_name(dist, 'Bernoulli')
         probs1_b = self._check_value(probs1_b, 'probs1_b')
         probs1_b = self.cast(probs1_b, self.parameter_type)
-        probs1_a = self._check_param(probs1)
+        probs1_a = self._check_param(probs1_a)
         probs0_a = 1.0 - probs1_a
         probs0_b = 1.0 - probs1_b
         return probs1_a * self.log(probs1_a / probs1_b) + probs0_a * self.log(probs0_a / probs0_b)
@@ -290,8 +290,8 @@ class Bernoulli(Distribution):
         Sampling.

         Args:
-            shape (tuple): shape of the sample. Default: ().
-            probs (Tensor, Number): probs1 of the samples. Default: self.probs.
+            shape (tuple): The shape of the sample. Default: ().
+            probs1 (Tensor, Number): `probs1` of the samples. Default: self.probs.

         Returns:
             Tensor, shape is shape + batch_shape.
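A hedged sketch exercising the pmf and sampling entry points documented above, starting from the constructor in this file's Examples; the values are illustrative:

    >>> import mindspore.nn.probability.distribution as msd
    >>> from mindspore import Tensor
    >>> import mindspore.common.dtype as mstype
    >>> b = msd.Bernoulli(0.5, dtype=mstype.int32)
    >>> value = Tensor([1.0, 0.0, 1.0], dtype=mstype.float32)  # zeros and ones only
    >>> p = b.prob(value)     # pmf under probs=0.5
    >>> s = b.sample((2, 3))  # shape (2, 3) + batch_shape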


@@ -23,17 +23,17 @@ from ._utils.utils import logits_to_probs, probs_to_logits, check_type, cast_to_
 class Categorical(Distribution):
     """
-    Creates a categorical distribution parameterized by either probs or logits (but not both).
+    Create a categorical distribution parameterized by either probabilities or logits (but not both).

     Args:
-        probs (Tensor, list, numpy.ndarray, Parameter): event probabilities.
-        logits (Tensor, list, numpy.ndarray, Parameter, float): event log-odds.
-        seed (int): seed to use in sampling. Global seed is used if it is None. Default: None.
-        dtype (mstype.int32): type of the distribution. Default: mstype.int32.
-        name (str): name of the distribution. Default: Categorical.
+        probs (Tensor, list, numpy.ndarray, Parameter): Event probabilities.
+        logits (Tensor, list, numpy.ndarray, Parameter, float): Event log-odds.
+        seed (int): The global seed is used in sampling. Global seed is used if it is None. Default: None.
+        dtype (mstype.int32): The type of the distribution. Default: mstype.int32.
+        name (str): The name of the distribution. Default: Categorical.

     Note:
-        probs must be non-negative, finite and have a non-zero sum, and it will be normalized to sum to 1.
+        `probs` must be non-negative, finite and have a non-zero sum, and it will be normalized to sum to 1.

     Examples:
         >>> # To initialize a Categorical distribution of prob is [0.5, 0.5]
@@ -111,14 +111,14 @@ class Categorical(Distribution):
     @property
     def logits(self):
         """
-        Returns the logits.
+        Return the logits.
         """
         return self._logits

     @property
     def probs(self):
         """
-        Returns the probability.
+        Return the probability.
         """
         return self._probs
@@ -127,7 +127,7 @@ class Categorical(Distribution):
         Sampling.

         Args:
-            sample_shape (tuple): shape of the sample. Default: ().
+            sample_shape (tuple): The shape of the sample. Default: ().

         Returns:
             Tensor, shape is shape(probs)[:-1] + sample_shape
@@ -149,7 +149,7 @@ class Categorical(Distribution):
         Evaluate log probability.

         Args:
-            value (Tensor): value to be evaluated.
+            value (Tensor): The value to be evaluated.
         """
         value = self._check_value(value, 'value')
         value = self.expandim(self.cast(value, mstype.float32), -1)
@@ -176,8 +176,9 @@ class Categorical(Distribution):
     def enumerate_support(self, expand=True):
         r"""
         Enumerate categories.
+
         Args:
-            expand (Bool): whether to expand.
+            expand (Bool): Whether to expand.
         """
         num_events = self._num_events
         values = nn.Range(0., num_events, 1)()
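A hedged sketch of the constructor from the Examples above plus the documented sample and log_prob calls; the values are illustrative:

    >>> import mindspore.nn.probability.distribution as msd
    >>> from mindspore import Tensor
    >>> import mindspore.common.dtype as mstype
    >>> c = msd.Categorical(probs=[0.5, 0.5], dtype=mstype.int32)
    >>> s = c.sample((3,))   # shape(probs)[:-1] + sample_shape
    >>> lp = c.log_prob(Tensor([0.0, 1.0], dtype=mstype.float32))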


@@ -27,25 +27,24 @@ class Distribution(Cell):
     Base class for all mathematical distributions.

     Args:
-        seed (int): random seed used in sampling. Global seed is used if it is None. Default: None.
-        dtype (mindspore.dtype): the type of the event samples. Default: subclass dtype.
-        name (str): Python str name prefixed to Ops created by this class. Default: subclass name.
-        param (dict): parameters used to initialize the distribution.
+        seed (int): The global seed is used in sampling. Global seed is used if it is None. Default: None.
+        dtype (mindspore.dtype): The type of the event samples. Default: subclass dtype.
+        name (str): Python string name prefixed to operations created by this class. Default: subclass name.
+        param (dict): The parameters used to initialize the distribution.

     Note:
-        Derived class should override operations such as ,_mean, _prob,
-        and _log_prob. Required arguments, such as value for _prob,
-        should be passed in through args or kwargs. dist_spec_args which specify
+        Derived class should override operations such as `_mean`, `_prob`,
+        and `_log_prob`. Required arguments, such as value for `_prob`,
+        should be passed in through `args` or `kwargs`. dist_spec_args which specify
         a new distribution are optional.

-        dist_spec_args are unique for each type of distribution. For example, mean and sd
-        are the dist_spec_args for a Normal distribution, while rate is the dist_spec_args
+        dist_spec_args are unique for each type of distribution. For example, `mean` and `sd`
+        are the dist_spec_args for a Normal distribution, while `rate` is the dist_spec_args
         for exponential distribution.

         For all functions, passing in dist_spec_args, is optional.
-        Passing in the additional dist_spec_args will make the result to be evaluated with
-        new distribution specified by the dist_spec_args. But it won't change the
-        original distribuion.
+        Passing in the additional dist_spec_args will evaluate the result to be evaluated with
+        new distribution specified by the dist_spec_args. But it will not change the original distribution.
     """

     def __init__(self,
@@ -118,7 +117,7 @@ class Distribution(Cell):
     def _check_value(self, value, name):
         """
-        Check availability fo value as a Tensor.
+        Check availability of `value` as a Tensor.
         """
         if self.context_mode == 0:
             self.checktensor(value, name)
@@ -127,7 +126,7 @@ class Distribution(Cell):
     def _set_prob(self):
         """
-        Set probability funtion based on the availability of _prob and _log_likehood.
+        Set probability funtion based on the availability of `_prob` and `_log_likehood`.
         """
         if hasattr(self, '_prob'):
             self._call_prob = self._prob
@@ -136,7 +135,7 @@ class Distribution(Cell):
     def _set_sd(self):
         """
-        Set standard deviation based on the availability of _sd and _var.
+        Set standard deviation based on the availability of `_sd` and `_var`.
         """
         if hasattr(self, '_sd'):
             self._call_sd = self._sd
@@ -145,7 +144,7 @@ class Distribution(Cell):
     def _set_var(self):
         """
-        Set variance based on the availability of _sd and _var.
+        Set variance based on the availability of `_sd` and `_var`.
         """
         if hasattr(self, '_var'):
             self._call_var = self._var
@@ -154,7 +153,7 @@ class Distribution(Cell):
     def _set_log_prob(self):
         """
-        Set log probability based on the availability of _prob and _log_prob.
+        Set log probability based on the availability of `_prob` and `_log_prob`.
         """
         if hasattr(self, '_log_prob'):
             self._call_log_prob = self._log_prob
@@ -163,7 +162,8 @@ class Distribution(Cell):
     def _set_cdf(self):
         """
-        Set cdf based on the availability of _cdf and _log_cdf and survival_functions.
+        Set cumulative distribution function (cdf) based on the availability of `_cdf` and `_log_cdf` and
+        `survival_functions`.
         """
         if hasattr(self, '_cdf'):
             self._call_cdf = self._cdf
@@ -176,8 +176,8 @@ class Distribution(Cell):
     def _set_survival(self):
         """
-        Set survival function based on the availability of _survival function and _log_survival
-        and _call_cdf.
+        Set survival function based on the availability of _survival function and `_log_survival`
+        and `_call_cdf`.
         """
         if hasattr(self, '_survival_function'):
             self._call_survival = self._survival_function
@@ -188,7 +188,7 @@ class Distribution(Cell):
     def _set_log_cdf(self):
         """
-        Set log cdf based on the availability of _log_cdf and _call_cdf.
+        Set log cdf based on the availability of `_log_cdf` and `_call_cdf`.
         """
         if hasattr(self, '_log_cdf'):
             self._call_log_cdf = self._log_cdf
@@ -197,7 +197,7 @@ class Distribution(Cell):
     def _set_log_survival(self):
         """
-        Set log survival based on the availability of _log_survival and _call_survival.
+        Set log survival based on the availability of `_log_survival` and `_call_survival`.
         """
         if hasattr(self, '_log_survival'):
             self._call_log_survival = self._log_survival
@@ -206,7 +206,7 @@ class Distribution(Cell):
     def _set_cross_entropy(self):
         """
-        Set log survival based on the availability of _cross_entropy.
+        Set log survival based on the availability of `_cross_entropy`.
         """
         if hasattr(self, '_cross_entropy'):
             self._call_cross_entropy = self._cross_entropy
@@ -216,7 +216,7 @@ class Distribution(Cell):
         Evaluate the log probability(pdf or pmf) at the given value.

         Note:
-            Args must include value.
+            The argument `args` must include `value`.
             dist_spec_args are optional.
         """
         return self._call_log_prob(*args, **kwargs)
@@ -235,7 +235,7 @@ class Distribution(Cell):
         Evaluate the probability (pdf or pmf) at given value.

         Note:
-            Args must include value.
+            The argument `args` must include `value`.
             dist_spec_args are optional.
         """
         return self._call_prob(*args, **kwargs)
@@ -254,7 +254,7 @@ class Distribution(Cell):
         Evaluate the cdf at given value.

         Note:
-            Args must include value.
+            The argument `args` must include `value`.
             dist_spec_args are optional.
         """
         return self._call_cdf(*args, **kwargs)
@@ -291,7 +291,7 @@ class Distribution(Cell):
         Evaluate the log cdf at given value.

         Note:
-            Args must include value.
+            The argument `args` must include `value`.
             dist_spec_args are optional.
         """
         return self._call_log_cdf(*args, **kwargs)
@@ -310,7 +310,7 @@ class Distribution(Cell):
         Evaluate the survival function at given value.

         Note:
-            Args must include value.
+            The argument `args` must include `value`.
             dist_spec_args are optional.
         """
         return self._call_survival(*args, **kwargs)
@@ -338,7 +338,7 @@ class Distribution(Cell):
         Evaluate the log survival function at given value.

         Note:
-            Args must include value.
+            The arguments `args` must include `value`.
             dist_spec_args are optional.
         """
         return self._call_log_survival(*args, **kwargs)
@@ -357,7 +357,7 @@ class Distribution(Cell):
         Evaluate the KL divergence, i.e. KL(a||b).

         Note:
-            Args must include type of the distribution, parameters of distribution b.
+            The argument `args` must include the type of the distribution, parameters of distribution b.
             Parameters for distribution a are optional.
         """
         return self._kl_loss(*args, **kwargs)
@@ -430,7 +430,7 @@ class Distribution(Cell):
         Evaluate the cross_entropy between distribution a and b.

         Note:
-            Args must include type of the distribution, parameters of distribution b.
+            The argument `args` must include the type of the distribution, parameters of distribution b.
             Parameters for distribution a are optional.
         """
         return self._call_cross_entropy(*args, **kwargs)
@@ -456,17 +456,17 @@ class Distribution(Cell):
     def construct(self, name, *args, **kwargs):
         """
-        Override construct in Cell.
+        Override `construct` in Cell.

         Note:
             Names of supported functions include:
             'prob', 'log_prob', 'cdf', 'log_cdf', 'survival_function', 'log_survival'
-            'var', 'sd', 'entropy', 'kl_loss', 'cross_entropy', 'sample'.
+            'var', 'sd', 'entropy', 'kl_loss', 'cross_entropy', and 'sample'.

         Args:
-            name (str): name of the function.
-            *args (list): list of positional arguments needed for the function.
-            **kwargs (dictionary): dictionary of keyword arguments needed for the function.
+            name (str): The name of the function.
+            *args (list): A list of positional arguments that the function needs.
+            **kwargs (dictionary): A dictionary of keyword arguments that the function needs.
         """
         if name == 'log_prob':
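A hedged illustration of the dist_spec_args behavior described in the Note above, using the Normal distribution as that Note does; the values are illustrative:

    >>> import mindspore.nn.probability.distribution as msd
    >>> from mindspore import Tensor
    >>> import mindspore.common.dtype as mstype
    >>> n = msd.Normal(3.0, 4.0, dtype=mstype.float32)
    >>> v = Tensor([2.0, 3.0], dtype=mstype.float32)
    >>> p1 = n.prob(v)   # evaluated with the original mean=3.0, sd=4.0
    >>> # passing dist_spec_args (mean, sd) evaluates under a new Normal
    >>> # without changing the original distribution:
    >>> p2 = n.prob(v, Tensor(0.0, mstype.float32), Tensor(1.0, mstype.float32))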


@@ -27,17 +27,17 @@ class Exponential(Distribution):
     Example class: Exponential Distribution.

     Args:
-        rate (float, list, numpy.ndarray, Tensor, Parameter): inverse scale.
-        seed (int): seed to use in sampling. Global seed is used if it is None. Default: None.
-        dtype (mindspore.dtype): type of the distribution. Default: mstype.float32.
-        name (str): name of the distribution. Default: Exponential.
+        rate (float, list, numpy.ndarray, Tensor, Parameter): The inverse scale.
+        seed (int): The seed used in sampling. Global seed is used if it is None. Default: None.
+        dtype (mindspore.dtype): The type of the distribution. Default: mstype.float32.
+        name (str): The name of the distribution. Default: Exponential.

     Note:
-        rate should be strictly greater than 0.
-        dist_spec_args is rate.
-        dtype should be float type because Exponential distributions are continuous.
+        `rate` should be strictly greater than 0.
+        dist_spec_args is `rate`.
+        `dtype` should be float type because Exponential distributions are continuous.

     Examples:
         >>> # To initialize an Exponential distribution of rate 0.5
         >>> import mindspore.nn.probability.distribution as msd
         >>> e = msd.Exponential(0.5, dtype=mstype.float32)
@@ -162,7 +162,7 @@ class Exponential(Distribution):
     def _check_param(self, rate):
         """
-        Check availablity of distribution specific args rate.
+        Check availablity of distribution specific argument `rate`.
         """
         if rate is not None:
             if self.context_mode == 0:
@@ -209,9 +209,9 @@ class Exponential(Distribution):
         Evaluate cross_entropy between Exponential distributions.

         Args:
-            dist (str): type of the distributions. Should be "Exponential" in this case.
-            rate_b (Tensor): rate of distribution b.
-            rate_a (Tensor): rate of distribution a. Default: self.rate.
+            dist (str): The type of the distributions. Should be "Exponential" in this case.
+            rate_b (Tensor): The rate of distribution b.
+            rate_a (Tensor): The rate of distribution a. Default: self.rate.
         """
         check_distribution_name(dist, 'Exponential')
         return self._entropy(rate) + self._kl_loss(dist, rate_b, rate)
@@ -223,11 +223,11 @@ class Exponential(Distribution):
         Args:
         Args:
-            value (Tensor): value to be evaluated.
-            rate (Tensor): rate of the distribution. Default: self.rate.
+            value (Tensor): The value to be evaluated.
+            rate (Tensor): The rate of the distribution. Default: self.rate.

         Note:
-            Value should be greater or equal to zero.
+            `value` should be greater or equal to zero.

         .. math::
             log_pdf(x) = \log(rate) - rate * x if x >= 0 else 0
@@ -243,14 +243,14 @@ class Exponential(Distribution):
     def _cdf(self, value, rate=None):
         r"""
-        cdf of Exponential distribution.
+        Cumulative distribution function (cdf) of Exponential distribution.

         Args:
-            value (Tensor): value to be evaluated.
-            rate (Tensor): rate of the distribution. Default: self.rate.
+            value (Tensor): The value to be evaluated.
+            rate (Tensor): The rate of the distribution. Default: self.rate.

         Note:
-            Value should be greater or equal to zero.
+            `value` should be greater or equal to zero.

         .. math::
             cdf(x) = 1.0 - \exp(-1 * \lambda * x) if x >= 0 else 0
@@ -268,11 +268,11 @@ class Exponential(Distribution):
         log survival_function of Exponential distribution.

         Args:
-            value (Tensor): value to be evaluated.
-            rate (Tensor): rate of the distribution. Default: self.rate.
+            value (Tensor): The value to be evaluated.
+            rate (Tensor): The rate of the distribution. Default: self.rate.

         Note:
-            Value should be greater or equal to zero.
+            `value` should be greater or equal to zero.

         .. math::
             log_survival_function(x) = -1 * \lambda * x if x >= 0 else 0
@@ -290,9 +290,9 @@ class Exponential(Distribution):
         Evaluate exp-exp kl divergence, i.e. KL(a||b).

         Args:
-            dist (str): type of the distributions. Should be "Exponential" in this case.
-            rate_b (Tensor): rate of distribution b.
-            rate_a (Tensor): rate of distribution a. Default: self.rate.
+            dist (str): The type of the distributions. Should be "Exponential" in this case.
+            rate_b (Tensor): The rate of distribution b.
+            rate_a (Tensor): The rate of distribution a. Default: self.rate.
         """
         check_distribution_name(dist, 'Exponential')
         rate_b = self._check_value(rate_b, 'rate_b')
@@ -305,8 +305,8 @@ class Exponential(Distribution):
         Sampling.

         Args:
-            shape (tuple): shape of the sample. Default: ().
-            rate (Tensor): rate of the distribution. Default: self.rate.
+            shape (tuple): The shape of the sample. Default: ().
+            rate (Tensor): The rate of the distribution. Default: self.rate.

         Returns:
             Tensor, shape is shape + batch_shape.
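A hedged sketch of the cdf and KL entry points documented above, reusing the constructor from this file's Examples; the rates are illustrative:

    >>> import mindspore.nn.probability.distribution as msd
    >>> from mindspore import Tensor
    >>> import mindspore.common.dtype as mstype
    >>> e = msd.Exponential(0.5, dtype=mstype.float32)
    >>> v = Tensor([1.0, 2.0], dtype=mstype.float32)
    >>> c = e.cdf(v)   # 1 - exp(-0.5 * x) for x >= 0
    >>> kl = e.kl_loss('Exponential', Tensor(0.8, mstype.float32))  # KL(a||b), rate_b=0.8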


@@ -26,19 +26,20 @@ from ._utils.custom_ops import exp_generic, log_generic
 class Geometric(Distribution):
     """
     Geometric Distribution.
-    It represents k+1 Bernoulli trials needed to get one success, k is the number of failures.
+    It represents that there are k failures before the first sucess, namely taht there are in total k+1 Bernoulli
+    trails when the first success is achieved.

     Args:
-        probs (float, list, numpy.ndarray, Tensor, Parameter): probability of success.
-        seed (int): seed to use in sampling. Global seed is used if it is None. Default: None.
-        dtype (mindspore.dtype): type of the distribution. Default: mstype.int32.
-        name (str): name of the distribution. Default: Geometric.
+        probs (float, list, numpy.ndarray, Tensor, Parameter): The probability of success.
+        seed (int): The seed used in sampling. Global seed is used if it is None. Default: None.
+        dtype (mindspore.dtype): The type of the distribution. Default: mstype.int32.
+        name (str): The name of the distribution. Default: Geometric.

     Note:
-        probs should be proper probabilities (0 < p < 1).
-        dist_spec_args is probs.
+        `probs` should be a proper probability (0 < p < 1).
+        dist_spec_args is `probs`.

     Examples:
         >>> # To initialize a Geometric distribution of prob 0.5
         >>> import mindspore.nn.probability.distribution as msd
         >>> n = msd.Geometric(0.5, dtype=mstype.int32)
@@ -159,7 +160,7 @@ class Geometric(Distribution):
     @property
     def probs(self):
         """
-        Returns the probability of success of the Bernoulli trail.
+        Return the probability of success of the Bernoulli trail.
         """
         return self._probs
@@ -213,9 +214,9 @@ class Geometric(Distribution):
         Evaluate cross_entropy between Geometric distributions.

         Args:
-            dist (str): type of the distributions. Should be "Geometric" in this case.
-            probs1_b (Tensor): probability of success of distribution b.
-            probs1_a (Tensor): probability of success of distribution a. Default: self.probs.
+            dist (str): The type of the distributions. Should be "Geometric" in this case.
+            probs1_b (Tensor): The probability of success of distribution b.
+            probs1_a (Tensor): The probability of success of distribution a. Default: self.probs.
         """
         check_distribution_name(dist, 'Geometric')
         return self._entropy(probs1) + self._kl_loss(dist, probs1_b, probs1)
@@ -225,8 +226,8 @@ class Geometric(Distribution):
         pmf of Geometric distribution.

         Args:
-            value (Tensor): a Tensor composed of only natural numbers.
-            probs (Tensor): probability of success. Default: self.probs.
+            value (Tensor): A Tensor composed of only natural numbers.
+            probs (Tensor): The probability of success. Default: self.probs.

         .. math::
             pmf(k) = probs0 ^k * probs1 if k >= 0;
@@ -243,11 +244,11 @@ class Geometric(Distribution):
     def _cdf(self, value, probs1=None):
         r"""
-        cdf of Geometric distribution.
+        Cumulative distribution function (cdf) of Geometric distribution.

         Args:
-            value (Tensor): a Tensor composed of only natural numbers.
-            probs (Tensor): probability of success. Default: self.probs.
+            value (Tensor): A Tensor composed of only natural numbers.
+            probs (Tensor): The probability of success. Default: self.probs.

         .. math::
             cdf(k) = 1 - probs0 ^ (k+1) if k >= 0;
@@ -269,9 +270,9 @@ class Geometric(Distribution):
         Evaluate Geometric-Geometric kl divergence, i.e. KL(a||b).

         Args:
-            dist (str): type of the distributions. Should be "Geometric" in this case.
-            probs1_b (Tensor): probability of success of distribution b.
-            probs1_a (Tensor): probability of success of distribution a. Default: self.probs.
+            dist (str): The type of the distributions. Should be "Geometric" in this case.
+            probs1_b (Tensor): The probability of success of distribution b.
+            probs1_a (Tensor): The probability of success of distribution a. Default: self.probs.

         .. math::
             KL(a||b) = \log(\frac{probs1_a}{probs1_b}) + \frac{probs0_a}{probs1_a} * \log(\frac{probs0_a}{probs0_b})
@@ -289,8 +290,8 @@ class Geometric(Distribution):
         Sampling.

         Args:
-            shape (tuple): shape of the sample. Default: ().
-            probs (Tensor): probability of success. Default: self.probs.
+            shape (tuple): The shape of the sample. Default: ().
+            probs (Tensor): The probability of success. Default: self.probs.

         Returns:
             Tensor, shape is shape + batch_shape.
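A hedged sketch of the pmf and sampling calls documented above, reusing the constructor from this file's Examples; the values are illustrative:

    >>> import mindspore.nn.probability.distribution as msd
    >>> from mindspore import Tensor
    >>> import mindspore.common.dtype as mstype
    >>> g = msd.Geometric(0.5, dtype=mstype.int32)
    >>> k = Tensor([0.0, 1.0, 2.0], dtype=mstype.float32)  # numbers of failures
    >>> p = g.prob(k)       # probs0^k * probs1 = 0.5^k * 0.5
    >>> s = g.sample((4,))  # shape (4,) + batch_shape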


@ -27,18 +27,18 @@ class Normal(Distribution):
Normal distribution. Normal distribution.
Args: Args:
mean (int, float, list, numpy.ndarray, Tensor, Parameter): mean of the Normal distribution. mean (int, float, list, numpy.ndarray, Tensor, Parameter): The mean of the Normal distribution.
sd (int, float, list, numpy.ndarray, Tensor, Parameter): stddev of the Normal distribution. sd (int, float, list, numpy.ndarray, Tensor, Parameter): The standard deviation of the Normal distribution.
seed (int): seed to use in sampling. Global seed is used if it is None. Default: None. seed (int): The seed used in sampling. Global seed is used if it is None. Default: None.
dtype (mindspore.dtype): type of the distribution. Default: mstype.float32. dtype (mindspore.dtype): The type of the distribution. Default: mstype.float32.
name (str): name of the distribution. Default: Normal. name (str): The name of the distribution. Default: Normal.
Note: Note:
Standard deviation should be greater than zero. `sd` should be greater than zero.
dist_spec_args are mean and sd. dist_spec_args are `mean` and `sd`.
dtype should be float type because Normal distributions are continuous. `dtype` should be float type because Normal distributions are continuous.
Examples: Examples:
>>> # To initialize a Normal distribution of mean 3.0 and standard deviation 4.0 >>> # To initialize a Normal distribution of mean 3.0 and standard deviation 4.0
>>> import mindspore.nn.probability.distribution as msd >>> import mindspore.nn.probability.distribution as msd
>>> n = msd.Normal(3.0, 4.0, dtype=mstype.float32) >>> n = msd.Normal(3.0, 4.0, dtype=mstype.float32)
@ -161,7 +161,7 @@ class Normal(Distribution):
def _check_param(self, mean, sd): def _check_param(self, mean, sd):
""" """
Check availablity of distribution specific args mean and sd. Check availablity of distribution specific args `mean` and `sd`.
""" """
if mean is not None: if mean is not None:
if self.context_mode == 0: if self.context_mode == 0:
@ -187,21 +187,21 @@ class Normal(Distribution):
def _mean(self, mean=None, sd=None): def _mean(self, mean=None, sd=None):
""" """
Mean of the distribution. The mean of the distribution.
""" """
mean, sd = self._check_param(mean, sd) mean, sd = self._check_param(mean, sd)
return mean return mean
def _mode(self, mean=None, sd=None): def _mode(self, mean=None, sd=None):
""" """
Mode of the distribution. The mode of the distribution.
""" """
mean, sd = self._check_param(mean, sd) mean, sd = self._check_param(mean, sd)
return mean return mean
def _sd(self, mean=None, sd=None): def _sd(self, mean=None, sd=None):
""" """
Standard deviation of the distribution. The standard deviation of the distribution.
""" """
mean, sd = self._check_param(mean, sd) mean, sd = self._check_param(mean, sd)
return sd return sd
@ -221,11 +221,11 @@ class Normal(Distribution):
Evaluate cross_entropy between normal distributions. Evaluate cross_entropy between normal distributions.
Args: Args:
dist (str): type of the distributions. Should be "Normal" in this case. dist (str): Type of the distributions. Should be "Normal" in this case.
mean_b (Tensor): mean of distribution b. mean_b (Tensor): Mean of distribution b.
sd_b (Tensor): standard deviation distribution b. sd_b (Tensor): Standard deviation distribution b.
mean_a (Tensor): mean of distribution a. Default: self._mean_value. mean_a (Tensor): Mean of distribution a. Default: self._mean_value.
sd_a (Tensor): standard deviation distribution a. Default: self._sd_value. sd_a (Tensor): Standard deviation distribution a. Default: self._sd_value.
""" """
check_distribution_name(dist, 'Normal') check_distribution_name(dist, 'Normal')
return self._entropy(mean, sd) + self._kl_loss(dist, mean_b, sd_b, mean, sd) return self._entropy(mean, sd) + self._kl_loss(dist, mean_b, sd_b, mean, sd)
@ -235,9 +235,9 @@ class Normal(Distribution):
Evaluate log probability. Evaluate log probability.
Args: Args:
value (Tensor): value to be evaluated. value (Tensor): The value to be evaluated.
mean (Tensor): mean of the distribution. Default: self._mean_value. mean (Tensor): The mean of the distribution. Default: self._mean_value.
sd (Tensor): standard deviation the distribution. Default: self._sd_value. sd (Tensor): The standard deviation the distribution. Default: self._sd_value.
.. math:: .. math::
L(x) = -1* \frac{(x - \mu)^2}{2. * \sigma^2} - \log(\sqrt(2* \pi * \sigma^2)) L(x) = -1* \frac{(x - \mu)^2}{2. * \sigma^2} - \log(\sqrt(2* \pi * \sigma^2))
@ -254,9 +254,9 @@ class Normal(Distribution):
Evaluate the cdf of the given value.
Args:
value (Tensor): The value to be evaluated.
mean (Tensor): The mean of the distribution. Default: self._mean_value.
sd (Tensor): The standard deviation of the distribution. Default: self._sd_value.
.. math::
cdf(x) = 0.5 * (1 + Erf((x - \mu) / (\sigma * \sqrt{2})))
@ -270,14 +270,14 @@ class Normal(Distribution):
def _kl_loss(self, dist, mean_b, sd_b, mean=None, sd=None):
r"""
Evaluate Normal-Normal KL divergence, i.e. KL(a||b).
Args:
dist (str): The type of the distributions. Should be "Normal" in this case.
mean_b (Tensor): The mean of distribution b.
sd_b (Tensor): The standard deviation of distribution b.
mean_a (Tensor): The mean of distribution a. Default: self._mean_value.
sd_a (Tensor): The standard deviation of distribution a. Default: self._sd_value.
.. math::
KL(a||b) = 0.5 * (\frac{MEAN(a)}{STD(b)} - \frac{MEAN(b)}{STD(b)})^2 +
@ -298,9 +298,9 @@ class Normal(Distribution):
Sampling.
Args:
shape (tuple): The shape of the sample. Default: ().
mean (Tensor): The mean of the samples. Default: self._mean_value.
sd (Tensor): The standard deviation of the samples. Default: self._sd_value.
Returns:
Tensor, the shape is shape + batch_shape.
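A short sketch of the shape contract (the batch values are illustrative assumptions): parameters with
batch_shape (2,) sampled with shape (3,) should yield a (3, 2) tensor.

b = msd.Normal(Tensor([0.0, 2.0], dtype=mstype.float32),
               Tensor([1.0, 1.0], dtype=mstype.float32),
               dtype=mstype.float32)
samples = b.sample((3,))   # expected shape: (3, 2) == shape + batch_shape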
@ -27,14 +27,14 @@ class TransformedDistribution(Distribution):
to a new distribution through the operation defined by the bijector.
Args:
bijector (Bijector): The transformation to perform.
distribution (Distribution): The original distribution.
name (str): The name of the transformed distribution. Default: transformed_distribution.
Note:
The arguments used to initialize the original distribution cannot be None.
For example, mynormal = msd.Normal(dtype=mstype.float32) cannot be used to initialize a
TransformedDistribution, since `mean` and `sd` are not specified.
Examples:
>>> # To initialize a transformed distribution, e.g. a lognormal distribution,
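A plausible continuation of this example, sketched under the assumption that the `Exp` bijector is available
in mindspore.nn.probability.bijector (Exp maps a Normal to a LogNormal):
>>> import mindspore.nn.probability.bijector as msb
>>> ln = msd.TransformedDistribution(msb.Exp(),
>>>                                  msd.Normal(0.0, 1.0, dtype=mstype.float32),
>>>                                  dtype=mstype.float32)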
@ -26,18 +26,18 @@ class Uniform(Distribution):
Example class: Uniform Distribution.
Args:
low (int, float, list, numpy.ndarray, Tensor, Parameter): The lower bound of the distribution.
high (int, float, list, numpy.ndarray, Tensor, Parameter): The upper bound of the distribution.
seed (int): The seed used in sampling. The global seed is used if it is None. Default: None.
dtype (mindspore.dtype): The type of the distribution. Default: mstype.float32.
name (str): The name of the distribution. Default: Uniform.
Note:
`low` must be strictly less than `high`.
dist_spec_args are `high` and `low`.
`dtype` must be a float type because Uniform distributions are continuous.
Examples:
>>> # To initialize a Uniform distribution of the lower bound 0.0 and the upper bound 1.0
>>> import mindspore.nn.probability.distribution as msd
>>> u = msd.Uniform(0.0, 1.0, dtype=mstype.float32)
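A hedged sketch of calling the public interfaces on the instance `u` above; the expected values follow from
the closed-form pdf, cdf, and KL divergence documented below:

from mindspore import Tensor

value = Tensor([0.25, 0.75], dtype=mstype.float32)
p = u.prob(value)   # expected: [1.0, 1.0], since pdf(x) = 1 / (high - low) inside [low, high]
c = u.cdf(value)    # expected: [0.25, 0.75]
kl = u.kl_loss('Uniform',
               Tensor(0.0, dtype=mstype.float32),
               Tensor(2.0, dtype=mstype.float32))   # KL(U(0,1) || U(0,2)) = log(2)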
@ -164,7 +164,7 @@ class Uniform(Distribution):
def _check_param(self, low, high):
"""
Check availability of the distribution-specific arguments `low` and `high`.
"""
if low is not None:
if self.context_mode == 0:
@ -205,6 +205,7 @@ class Uniform(Distribution):
def _range(self, low=None, high=None):
r"""
Return the range of the distribution.
.. math::
range(U) = high - low
"""
@ -240,11 +241,11 @@ class Uniform(Distribution):
Evaluate cross_entropy between Uniform distributions.
Args:
dist (str): The type of the distributions. Should be "Uniform" in this case.
low_b (Tensor): The lower bound of distribution b.
high_b (Tensor): The upper bound of distribution b.
low_a (Tensor): The lower bound of distribution a. Default: self.low.
high_a (Tensor): The upper bound of distribution a. Default: self.high.
"""
check_distribution_name(dist, 'Uniform')
return self._entropy(low, high) + self._kl_loss(dist, low_b, high_b, low, high)
@ -254,9 +255,9 @@ class Uniform(Distribution):
The probability density function (pdf) of the Uniform distribution.
Args:
value (Tensor): The value to be evaluated.
low (Tensor): The lower bound of the distribution. Default: self.low.
high (Tensor): The upper bound of the distribution. Default: self.high.
.. math::
pdf(x) = 0 if x < low;
@ -277,14 +278,14 @@ class Uniform(Distribution):
def _kl_loss(self, dist, low_b, high_b, low=None, high=None):
"""
Evaluate Uniform-Uniform KL divergence, i.e. KL(a||b).
Args:
dist (str): The type of the distributions. Should be "Uniform" in this case.
low_b (Tensor): The lower bound of distribution b.
high_b (Tensor): The upper bound of distribution b.
low_a (Tensor): The lower bound of distribution a. Default: self.low.
high_a (Tensor): The upper bound of distribution a. Default: self.high.
"""
check_distribution_name(dist, 'Uniform')
low_b = self._check_value(low_b, 'low_b')
@ -301,9 +302,9 @@ class Uniform(Distribution):
The cumulative distribution function (cdf) of the Uniform distribution.
Args:
value (Tensor): The value to be evaluated.
low (Tensor): The lower bound of the distribution. Default: self.low.
high (Tensor): The upper bound of the distribution. Default: self.high.
.. math::
cdf(x) = 0 if x < low;
@ -327,9 +328,9 @@ class Uniform(Distribution):
Sampling.
Args:
shape (tuple): The shape of the sample. Default: ().
low (Tensor): The lower bound of the distribution. Default: self.low.
high (Tensor): The upper bound of the distribution. Default: self.high.
Returns:
Tensor, the shape is shape + batch_shape.
@ -13,8 +13,6 @@
# limitations under the License.
# ============================================================================
"""
Deep probability networks, such as BNN and VAE networks.
"""
@ -25,25 +25,26 @@ class ConditionalVAE(Cell):
Conditional Variational Auto-Encoder (CVAE).
The difference from VAE is that CVAE uses label information.
For more details, refer to `Learning Structured Output Representation using Deep Conditional Generative Models
<http://papers.nips.cc/paper/5775-learning-structured-output-representation-using-deep-conditional-
generative-models>`_.
Note:
When the encoder and decoder are defined, the shape of the encoder's output tensor and the decoder's input tensor
should be :math:`(N, hidden\_size)`.
The latent_size must be less than or equal to the hidden_size.
Args:
encoder(Cell): The Deep Neural Network (DNN) model defined as the encoder.
decoder(Cell): The DNN model defined as the decoder.
hidden_size(int): The size of the encoder's output tensor.
latent_size(int): The size of the latent space.
num_classes(int): The number of classes.
Inputs:
- **input_x** (Tensor) - The input tensor has the same shape as the input of the encoder, with the shape
being :math:`(N, C, H, W)`.
- **input_y** (Tensor) - The tensor of the target data, with the shape :math:`(N,)`.
Outputs:
- **output** (tuple) - (recon_x(Tensor), x(Tensor), mu(Tensor), std(Tensor)).
@ -96,15 +97,15 @@ class ConditionalVAE(Cell):
def generate_sample(self, sample_y, generate_nums, shape):
"""
Randomly sample from the latent space to generate samples.
Args:
sample_y (Tensor): Define the label of the samples. Tensor of shape (generate_nums, ) and type mindspore.int32.
generate_nums (int): The number of samples to generate.
shape(tuple): The shape of the samples, which must be (generate_nums, C, H, W) or (-1, C, H, W).
Returns:
Tensor, the generated samples.
"""
generate_nums = check_int_positive(generate_nums)
if not isinstance(shape, tuple) or len(shape) != 4 or (shape[0] != -1 and shape[0] != generate_nums):
@ -118,7 +119,7 @@ class ConditionalVAE(Cell):
def reconstruct_sample(self, x, y):
"""
Reconstruct samples from original data.
Args:
x (Tensor): The input tensor to be reconstructed, the shape is (N, C, H, W).
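A hedged end-to-end sketch; `MyEncoder` and `MyDecoder` are hypothetical Cells whose output/input shapes
satisfy the Note above, and `input_x`/`input_y` and the MNIST-like sizes are assumptions:

import numpy as np
from mindspore import Tensor
import mindspore.common.dtype as mstype

cvae = ConditionalVAE(encoder=MyEncoder(), decoder=MyDecoder(),    # MyEncoder/MyDecoder: hypothetical Cells
                      hidden_size=400, latent_size=20, num_classes=10)
sample_y = Tensor(np.random.randint(0, 10, size=64), dtype=mstype.int32)
generated = cvae.generate_sample(sample_y, 64, (64, 1, 32, 32))
recon_x = cvae.reconstruct_sample(input_x, input_y)   # input_x: (N, C, H, W), input_y: (N,)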
@ -25,21 +25,22 @@ class VAE(Cell):
Variational Auto-Encoder (VAE).
The VAE defines a generative model: `Z` is sampled from the prior and then used to reconstruct `X` by a decoder.
For more details, refer to `Auto-Encoding Variational Bayes <https://arxiv.org/abs/1312.6114>`_.
Note:
When the encoder and decoder are defined, the shape of the encoder's output tensor and the decoder's input tensor
should be :math:`(N, hidden\_size)`.
The latent_size must be less than or equal to the hidden_size.
Args:
encoder(Cell): The Deep Neural Network (DNN) model defined as the encoder.
decoder(Cell): The DNN model defined as the decoder.
hidden_size(int): The size of the encoder's output tensor.
latent_size(int): The size of the latent space.
Inputs:
- **input** (Tensor) - The input tensor has the same shape as the input of the encoder, with the shape
being :math:`(N, C, H, W)`.
Outputs:
- **output** (Tuple) - (recon_x(Tensor), x(Tensor), mu(Tensor), std(Tensor)).
@ -84,14 +85,14 @@ class VAE(Cell):
def generate_sample(self, generate_nums, shape):
"""
Randomly sample from the latent space to generate samples.
Args:
generate_nums (int): The number of samples to generate.
shape(tuple): The shape of the samples, which must be (generate_nums, C, H, W) or (-1, C, H, W).
Returns:
Tensor, the generated samples.
"""
generate_nums = check_int_positive(generate_nums)
if not isinstance(shape, tuple) or len(shape) != 4 or (shape[0] != -1 and shape[0] != generate_nums):
@ -103,7 +104,7 @@ class VAE(Cell):
def reconstruct_sample(self, x):
"""
Reconstruct samples from original data.
Args:
x (Tensor): The input tensor to be reconstructed, the shape is (N, C, H, W).
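A parallel sketch for the plain VAE, under the same hypothetical `MyEncoder`/`MyDecoder` assumption as in
the CVAE example:

vae = VAE(encoder=MyEncoder(), decoder=MyDecoder(), hidden_size=400, latent_size=20)
generated = vae.generate_sample(64, (64, 1, 32, 32))   # 64 samples of shape (1, 32, 32)
recon_x = vae.reconstruct_sample(input_x)              # input_x: (N, C, H, W), assumed defined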
@ -13,7 +13,7 @@
# limitations under the License.
# ============================================================================
"""
Inference algorithms in Probabilistic Programming.
"""
from .variational import *
@ -24,10 +24,10 @@ class ELBO(Cell):
The Evidence Lower Bound (ELBO).
Variational inference minimizes the Kullback-Leibler (KL) divergence from the variational distribution to
the posterior distribution. It maximizes the ELBO, a lower bound on the logarithm of
the marginal probability of the observations, log p(x). The ELBO is equal to the negative KL divergence up to
an additive constant.
For more details, refer to `Variational Inference: A Review for Statisticians <https://arxiv.org/abs/1601.00670>`_.
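In the usual parameterization, with variational distribution q(z|x), decoder likelihood p(x|z), and latent
prior p(z), the quantity maximized for each observation x is (a standard identity, added here for clarity):
.. math::
    ELBO(x) = \mathbb{E}_{q(z|x)}[\log p(x|z)] - KL(q(z|x) \| p(z))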
Args:
latent_prior(str): The prior distribution of the latent space. Default: Normal.
@ -26,9 +26,9 @@ class SVI:
Stochastic Variational Inference (SVI).
Variational inference casts the inference problem as an optimization: a family of distributions over the hidden
variables, indexed by a set of free parameters, is optimized to bring it as close as possible to
the posterior of interest.
For more details, refer to `Variational Inference: A Review for Statisticians <https://arxiv.org/abs/1601.00670>`_.
Args:
net_with_loss(Cell): The Cell with the loss function.
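A minimal training sketch, assuming the `vae` and `ds_train` from the VAE example; the `run` and
`get_train_loss` calls follow this package's SVI interface, but treat the exact keyword names as assumptions:

import mindspore.nn as nn

net_loss = ELBO(latent_prior='Normal')
net_with_loss = nn.WithLossCell(vae, net_loss)
optimizer = nn.Adam(params=vae.trainable_params(), learning_rate=0.001)
vi = SVI(net_with_loss=net_with_loss, optimizer=optimizer)
vae = vi.run(train_dataset=ds_train, epochs=10)   # returns the trained network
loss = vi.get_train_loss()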
@ -43,15 +43,15 @@ class UncertaintyEvaluation:
- regression: A regression model.
- classification: A classification model.
num_classes (int): The number of labels of classification.
If the task type is classification, it must be set; otherwise, it is not needed.
Default: None.
epochs (int): Total number of iterations on the data. Default: 1.
epi_uncer_model_path (str): The save or read path of the epistemic uncertainty model. Default: None.
ale_uncer_model_path (str): The save or read path of the aleatoric uncertainty model. Default: None.
save_model (bool): Whether to save the uncertainty model. If True, epi_uncer_model_path
and ale_uncer_model_path must not be None. If False, the model to evaluate is loaded from
the path of the uncertainty model; if no path is given, the uncertainty model is neither
saved nor loaded. Default: False.
Examples:
>>> network = LeNet()
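A hedged continuation sketch; `create_dataset` and `eval_data` are placeholders for user-defined data
pipelines, and the argument names follow the Args list above:

from mindspore.train.serialization import load_checkpoint, load_param_into_net

param_dict = load_checkpoint('checkpoint_lenet.ckpt')
load_param_into_net(network, param_dict)
ds_train = create_dataset('workspace/mnist/train')   # hypothetical dataset helper
evaluation = UncertaintyEvaluation(model=network,
                                   train_dataset=ds_train,
                                   task_type='classification',
                                   num_classes=10,
                                   epochs=1,
                                   save_model=False)
epistemic = evaluation.eval_epistemic_uncertainty(eval_data)   # eval_data: a Tensor batch
aleatoric = evaluation.eval_aleatoric_uncertainty(eval_data)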
@ -13,9 +13,7 @@
# limitations under the License.
# ============================================================================
"""
The high-level components used to transform models between a Deep Neural Network (DNN) and a Bayesian Neural
Network (BNN).
"""
from . import transform_bnn
from .transform_bnn import TransformToBNN
@ -32,7 +32,7 @@ class TransformToBNN:
Args:
trainable_dnn (Cell): A trainable DNN model (backbone) wrapped by TrainOneStepCell.
dnn_factor (int, float): The coefficient of the backbone's loss, which is computed by the loss function. Default: 1.
bnn_factor (int, float): The coefficient of the KL loss, which is the KL divergence of the Bayesian layer. Default: 1.
Examples:
>>> class Net(nn.Cell):
@ -139,15 +139,15 @@ class TransformToBNN:
Args:
dnn_layer_type (Cell): The type of the DNN layer to be transformed into a BNN layer. The optional values are
nn.Dense and nn.Conv2d.
bnn_layer_type (Cell): The type of the BNN layer to be transformed into. The optional values are
DenseReparam and ConvReparam.
get_args: The arguments obtained from the DNN layer. Default: None.
add_args (dict): The new arguments added to the BNN layer. Note that the arguments in `add_args` must not
duplicate arguments in `get_args`. Default: None.
Returns:
Cell, a trainable model wrapped by TrainOneStepCell, whose specific type of layer is transformed into the
corresponding Bayesian layer.
Examples:
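The original example is elided here; a hedged sketch of the call itself, assuming `train_network` is the
TrainOneStepCell-wrapped DNN from the class-level example:

from mindspore.nn.probability.bnn_layers import DenseReparam

bnn_transformer = TransformToBNN(train_network, dnn_factor=1, bnn_factor=1)
# transform only the Dense layers of the backbone into DenseReparam layers
train_bnn_network = bnn_transformer.transform_to_bnn_layer(nn.Dense, DenseReparam)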