fix bug of docs and add DenseLocalReparam

parent 55cc959ac7
commit cf9ce2b8fb

@@ -18,7 +18,7 @@
 """
 from . import conv_variational, dense_variational, layer_distribution, bnn_cell_wrapper
 from .conv_variational import ConvReparam
-from .dense_variational import DenseReparam
+from .dense_variational import DenseReparam, DenseLocalReparam
 from .layer_distribution import NormalPrior, NormalPosterior
 from .bnn_cell_wrapper import WithBNNLossCell

@@ -37,6 +37,9 @@ class WithBNNLossCell(Cell):
     Outputs:
         Tensor, a scalar tensor with shape :math:`()`.

+    Supported Platforms:
+        ``Ascend`` ``GPU``
+
     Examples:
         >>> net = Net()
         >>> loss_fn = nn.SoftmaxCrossEntropyWithLogits(sparse=False)

@@ -157,18 +157,16 @@ class _ConvVariational(_Conv):

     def compute_kl_loss(self):
         """Compute kl loss"""
-        weight_post_mean = self.weight_posterior("mean")
-        weight_post_sd = self.weight_posterior("sd")
+        weight_args_list = self.weight_posterior("get_dist_args")
+        weight_type = self.weight_posterior("get_dist_type")

-        kl = self.weight_prior("kl_loss", "Normal",
-                               weight_post_mean, weight_post_sd)
+        kl = self.weight_prior("kl_loss", weight_type, *weight_args_list)
         kl_loss = self.sum(kl)
         if self.has_bias:
-            bias_post_mean = self.bias_posterior("mean")
-            bias_post_sd = self.bias_posterior("sd")
+            bias_args_list = self.bias_posterior("get_dist_args")
+            bias_type = self.bias_posterior("get_dist_type")

-            kl = self.bias_prior("kl_loss", "Normal",
-                                 bias_post_mean, bias_post_sd)
+            kl = self.bias_prior("kl_loss", bias_type, *bias_args_list)
             kl = self.sum(kl)
             kl_loss += kl
         return kl_loss

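The change above stops hard-coding "Normal" with explicit mean/sd arguments and instead asks the posterior for its distribution type and argument list, so compute_kl_loss no longer assumes which parameters the posterior exposes. For the default normal prior and posterior, the summed quantity is the closed-form KL divergence between two normal distributions; the following NumPy sketch (illustrative only, not the MindSpore `kl_loss` API) shows that quantity:

    import numpy as np

    def normal_kl(mu_q, sd_q, mu_p, sd_p):
        """Element-wise KL(N(mu_q, sd_q) || N(mu_p, sd_p)) in closed form."""
        return (np.log(sd_p / sd_q)
                + (sd_q ** 2 + (mu_q - mu_p) ** 2) / (2 * sd_p ** 2)
                - 0.5)

    # Posterior parameters for a small weight tensor against the standard
    # normal prior used by NormalPrior; summing mirrors kl_loss = self.sum(kl).
    mu_q = np.random.randn(3, 4) * 0.1
    sd_q = np.full((3, 4), 0.05)
    kl_loss = normal_kl(mu_q, sd_q, np.zeros((3, 4)), np.ones((3, 4))).sum()
    print(kl_loss)
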
@@ -249,6 +247,9 @@ class ConvReparam(_ConvVariational):
     Outputs:
         Tensor, with the shape being :math:`(N, C_{out}, H_{out}, W_{out})`.

+    Supported Platforms:
+        ``Ascend`` ``GPU``
+
     Examples:
         >>> net = ConvReparam(120, 240, 4, has_bias=False)
         >>> input = Tensor(np.ones([1, 120, 1024, 640]), mindspore.float32)

@@ -18,9 +18,10 @@ from mindspore.common.tensor import Tensor
 from mindspore._checkparam import Validator
 from ...cell import Cell
 from ...layer.activation import get_activation
+from ..distribution.normal import Normal
 from .layer_distribution import NormalPrior, NormalPosterior

-__all__ = ['DenseReparam']
+__all__ = ['DenseReparam', 'DenseLocalReparam']


 class _DenseVariational(Cell):

@@ -122,17 +123,17 @@ class _DenseVariational(Cell):
         return self.bias_add(inputs, bias_posterior_tensor)

     def compute_kl_loss(self):
-        """Compute kl loss."""
-        weight_post_mean = self.weight_posterior("mean")
-        weight_post_sd = self.weight_posterior("sd")
+        """Compute kl loss"""
+        weight_args_list = self.weight_posterior("get_dist_args")
+        weight_type = self.weight_posterior("get_dist_type")

-        kl = self.weight_prior("kl_loss", "Normal", weight_post_mean, weight_post_sd)
+        kl = self.weight_prior("kl_loss", weight_type, *weight_args_list)
         kl_loss = self.sum(kl)
         if self.has_bias:
-            bias_post_mean = self.bias_posterior("mean")
-            bias_post_sd = self.bias_posterior("sd")
+            bias_args_list = self.bias_posterior("get_dist_args")
+            bias_type = self.bias_posterior("get_dist_type")

-            kl = self.bias_prior("kl_loss", "Normal", bias_post_mean, bias_post_sd)
+            kl = self.bias_prior("kl_loss", bias_type, *bias_args_list)
             kl = self.sum(kl)
             kl_loss += kl
         return kl_loss

@@ -187,6 +188,9 @@ class DenseReparam(_DenseVariational):
     Outputs:
         Tensor, the shape of the tensor is :math:`(N, out\_channels)`.

+    Supported Platforms:
+        ``Ascend`` ``GPU``
+
     Examples:
         >>> net = DenseReparam(3, 4)
         >>> input = Tensor(np.random.randint(0, 255, [2, 3]), mindspore.float32)

@@ -220,3 +224,95 @@ class DenseReparam(_DenseVariational):
         weight_posterior_tensor = self.weight_posterior("sample")
         outputs = self.matmul(inputs, weight_posterior_tensor)
         return outputs
+
+
+class DenseLocalReparam(_DenseVariational):
+    r"""
+    Dense variational layers with Local Reparameterization.
+
+    For more details, refer to the paper `Variational Dropout and the Local Reparameterization
+    Trick <https://arxiv.org/abs/1506.02557>`_.
+
+    Applies dense-connected layer to the input. This layer implements the operation as:
+
+    .. math::
+        \text{outputs} = \text{activation}(\text{inputs} * \text{weight} + \text{bias}),
+
+    where :math:`\text{activation}` is the activation function passed as the activation
+    argument (if passed in), :math:`\text{weight}` is a weight matrix with the same data
+    type as the inputs created by the layer, sampled from the posterior distribution of
+    the weight, and :math:`\text{bias}` is a bias vector with the same data type as the
+    inputs created by the layer (only if has_bias is True), sampled from the posterior
+    distribution of :math:`\text{bias}`.
+
+    Args:
+        in_channels (int): The number of input channels.
+        out_channels (int): The number of output channels.
+        has_bias (bool): Specifies whether the layer uses a bias vector. Default: True.
+        activation (str, Cell): The activation function applied to the output of the layer. The type of `activation`
+            can be a string (eg. 'relu') or a Cell (eg. nn.ReLU()). Note that if the type of activation is Cell, it
+            must be instantiated beforehand. Default: None.
+        weight_prior_fn: The prior distribution for weight.
+            It must return a mindspore distribution instance.
+            Default: NormalPrior (which creates an instance of standard
+            normal distribution). The current version only supports normal distribution.
+        weight_posterior_fn: The posterior distribution for sampling weight.
+            It must be a function handle which returns a mindspore
+            distribution instance. Default: lambda name, shape: NormalPosterior(name=name, shape=shape).
+            The current version only supports normal distribution.
+        bias_prior_fn: The prior distribution for bias vector. It must return
+            a mindspore distribution. Default: NormalPrior (which creates an
+            instance of standard normal distribution). The current version
+            only supports normal distribution.
+        bias_posterior_fn: The posterior distribution for sampling bias vector.
+            It must be a function handle which returns a mindspore
+            distribution instance. Default: lambda name, shape: NormalPosterior(name=name, shape=shape).
+            The current version only supports normal distribution.
+
+    Inputs:
+        - **input** (Tensor) - The shape of the tensor is :math:`(N, in\_channels)`.
+
+    Outputs:
+        Tensor, the shape of the tensor is :math:`(N, out\_channels)`.
+
+    Supported Platforms:
+        ``Ascend`` ``GPU``
+
+    Examples:
+        >>> net = DenseLocalReparam(3, 4)
+        >>> input = Tensor(np.random.randint(0, 255, [2, 3]), mindspore.float32)
+        >>> output = net(input).shape
+        >>> print(output)
+        (2, 4)
+    """
+
+    def __init__(
+            self,
+            in_channels,
+            out_channels,
+            activation=None,
+            has_bias=True,
+            weight_prior_fn=NormalPrior,
+            weight_posterior_fn=lambda name, shape: NormalPosterior(name=name, shape=shape),
+            bias_prior_fn=NormalPrior,
+            bias_posterior_fn=lambda name, shape: NormalPosterior(name=name, shape=shape)):
+        super(DenseLocalReparam, self).__init__(
+            in_channels,
+            out_channels,
+            activation=activation,
+            has_bias=has_bias,
+            weight_prior_fn=weight_prior_fn,
+            weight_posterior_fn=weight_posterior_fn,
+            bias_prior_fn=bias_prior_fn,
+            bias_posterior_fn=bias_posterior_fn
+        )
+        self.sqrt = P.Sqrt()
+        self.square = P.Square()
+        self.normal = Normal()
+
+    def _apply_variational_weight(self, inputs):
+        mean = self.matmul(inputs, self.weight_posterior("mean"))
+        std = self.sqrt(self.matmul(self.square(inputs), self.square(self.weight_posterior("sd"))))
+        weight_posterior_affine_tensor = self.normal("sample", mean=mean, sd=std)
+        return weight_posterior_affine_tensor

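The `_apply_variational_weight` method added above implements the local reparameterization trick from the cited paper: instead of sampling a weight matrix and then multiplying, it computes the mean and standard deviation of the pre-activations induced by the weight posterior and samples those directly, which reduces gradient variance. A short NumPy sketch of the same computation (illustrative, not MindSpore code):

    import numpy as np

    def local_reparam_forward(x, w_mean, w_sd, rng):
        """Sample pre-activations directly, mirroring _apply_variational_weight."""
        mean = x @ w_mean                         # mean of the output distribution
        std = np.sqrt((x ** 2) @ (w_sd ** 2))     # std induced by the weight posterior
        return mean + std * rng.standard_normal(mean.shape)

    rng = np.random.default_rng(0)
    x = rng.standard_normal((2, 3))
    w_mean = rng.standard_normal((3, 4)) * 0.1
    w_sd = np.full((3, 4), 0.05)
    print(local_reparam_forward(x, w_mean, w_sd, rng).shape)  # (2, 4)
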
@@ -36,6 +36,9 @@ class NormalPrior(Cell):

     Returns:
         Cell, a normal distribution.
+
+    Supported Platforms:
+        ``Ascend`` ``GPU``
     """
     def __init__(self, dtype=mstype.float32, mean=0, std=0.1):
         super(NormalPrior, self).__init__()

@@ -62,6 +65,9 @@ class NormalPosterior(Cell):

     Returns:
         Cell, a normal distribution.
+
+    Supported Platforms:
+        ``Ascend`` ``GPU``
     """
     def __init__(self,
                  name,

@@ -49,6 +49,9 @@ class ConditionalVAE(Cell):

     Outputs:
         - **output** (tuple) - (recon_x(Tensor), x(Tensor), mu(Tensor), std(Tensor)).
+
+    Supported Platforms:
+        ``Ascend`` ``GPU``
     """

     def __init__(self, encoder, decoder, hidden_size, latent_size, num_classes):

@@ -44,6 +44,9 @@ class VAE(Cell):

     Outputs:
         - **output** (Tuple) - (recon_x(Tensor), x(Tensor), mu(Tensor), std(Tensor)).
+
+    Supported Platforms:
+        ``Ascend`` ``GPU``
     """

     def __init__(self, encoder, decoder, hidden_size, latent_size):

@@ -41,6 +41,9 @@ class ELBO(Cell):

     Outputs:
         Tensor, loss float tensor.
+
+    Supported Platforms:
+        ``Ascend`` ``GPU``
     """

     def __init__(self, latent_prior='Normal', output_prior='Normal'):

@@ -34,6 +34,9 @@ class SVI:
         net_with_loss(Cell): Cell with loss function.
         optimizer (Cell): Optimizer for updating the weights.
+
+    Supported Platforms:
+        ``Ascend`` ``GPU``

     """

     def __init__(self, net_with_loss, optimizer):

@@ -34,6 +34,9 @@ class VAEAnomalyDetection:
         hidden_size(int): The size of encoder's output tensor.
         latent_size(int): The size of the latent space.
+
+    Supported Platforms:
+        ``Ascend`` ``GPU``

     """

     def __init__(self, encoder, decoder, hidden_size=400, latent_size=20):

@@ -53,6 +53,9 @@ class UncertaintyEvaluation:
             the path of the uncertainty model; if the path is not given, it will not save or load the
             uncertainty model. Default: False.
+
+    Supported Platforms:
+        ``Ascend`` ``GPU``

     Examples:
         >>> network = LeNet()
         >>> param_dict = load_checkpoint('checkpoint_lenet.ckpt')

@@ -34,6 +34,9 @@ class TransformToBNN:
         dnn_factor (int, float): The coefficient of backbone's loss, which is computed by loss function. Default: 1.
         bnn_factor (int, float): The coefficient of KL loss, which is KL divergence of Bayesian layer. Default: 1.
+
+    Supported Platforms:
+        ``Ascend`` ``GPU``
+
     Examples:
         >>> class Net(nn.Cell):
         ...     def __init__(self):

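The two factors documented above weight the two parts of the training objective: dnn_factor scales the task loss produced by the loss function and bnn_factor scales the KL divergence contributed by the Bayesian layers. A minimal, hypothetical sketch of how such a combined objective is typically formed (plain Python, not the TransformToBNN internals):

    def combined_loss(task_loss, kl_loss, dnn_factor=1.0, bnn_factor=1.0):
        """Weight the backbone loss and the KL regularizer as described above."""
        return dnn_factor * task_loss + bnn_factor * kl_loss

    # With a small bnn_factor such as 0.0001 (see the corrected example in the
    # next hunk), the KL term acts as a mild regularizer rather than dominating.
    print(combined_loss(task_loss=2.3, kl_loss=1500.0, bnn_factor=0.0001))  # 2.45
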
@@ -57,7 +60,7 @@ class TransformToBNN:
         >>> optim = Momentum(params=net.trainable_params(), learning_rate=0.1, momentum=0.9)
         >>> net_with_loss = WithLossCell(network, criterion)
         >>> train_network = TrainOneStepCell(net_with_loss, optim)
-        >>> bnn_transformer = TransformToBNN(train_network, 60000, 0.1)
+        >>> bnn_transformer = TransformToBNN(train_network, 60000, 0.0001)
     """

     def __init__(self, trainable_dnn, dnn_factor=1, bnn_factor=1):

@@ -105,6 +108,9 @@ class TransformToBNN:
         Returns:
             Cell, a trainable BNN model wrapped by TrainOneStepCell.

+        Supported Platforms:
+            ``Ascend`` ``GPU``
+
         Examples:
             >>> net = Net()
             >>> criterion = nn.SoftmaxCrossEntropyWithLogits(sparse=True)

@@ -147,6 +153,9 @@ class TransformToBNN:
             Cell, a trainable model wrapped by TrainOneStepCell, whose specific type of layer is transformed to the
             corresponding bayesian layer.

+        Supported Platforms:
+            ``Ascend`` ``GPU``
+
         Examples:
            >>> net = Net()
            >>> criterion = nn.SoftmaxCrossEntropyWithLogits(sparse=True)