From cf9ce2b8fb5c0efa3e9f2b22479f9631251df795 Mon Sep 17 00:00:00 2001
From: bingyaweng
Date: Tue, 1 Dec 2020 19:57:52 +0800
Subject: [PATCH] fix bug of docs and add DenseLocalReparam

---
 .../nn/probability/bnn_layers/__init__.py     |   2 +-
 .../bnn_layers/bnn_cell_wrapper.py            |   3 +
 .../bnn_layers/conv_variational.py            |  17 +--
 .../bnn_layers/dense_variational.py           | 112 ++++++++++++++++--
 .../bnn_layers/layer_distribution.py          |   6 +
 mindspore/nn/probability/dpn/vae/cvae.py      |   3 +
 mindspore/nn/probability/dpn/vae/vae.py       |   3 +
 .../nn/probability/infer/variational/elbo.py  |   3 +
 .../nn/probability/infer/variational/svi.py   |   3 +
 .../probability/toolbox/anomaly_detection.py  |   3 +
 .../toolbox/uncertainty_evaluation.py         |   3 +
 .../probability/transforms/transform_bnn.py   |  11 +-
 12 files changed, 151 insertions(+), 18 deletions(-)

diff --git a/mindspore/nn/probability/bnn_layers/__init__.py b/mindspore/nn/probability/bnn_layers/__init__.py
index abc719eb5a3..905a5e1bebe 100644
--- a/mindspore/nn/probability/bnn_layers/__init__.py
+++ b/mindspore/nn/probability/bnn_layers/__init__.py
@@ -18,7 +18,7 @@
 """
 from . import conv_variational, dense_variational, layer_distribution, bnn_cell_wrapper
 from .conv_variational import ConvReparam
-from .dense_variational import DenseReparam
+from .dense_variational import DenseReparam, DenseLocalReparam
 from .layer_distribution import NormalPrior, NormalPosterior
 from .bnn_cell_wrapper import WithBNNLossCell
diff --git a/mindspore/nn/probability/bnn_layers/bnn_cell_wrapper.py b/mindspore/nn/probability/bnn_layers/bnn_cell_wrapper.py
index 72772a81518..06f6a013709 100644
--- a/mindspore/nn/probability/bnn_layers/bnn_cell_wrapper.py
+++ b/mindspore/nn/probability/bnn_layers/bnn_cell_wrapper.py
@@ -37,6 +37,9 @@ class WithBNNLossCell(Cell):
     Outputs:
         Tensor, a scalar tensor with shape :math:`()`.
 
+    Supported Platforms:
+        ``Ascend`` ``GPU``
+
     Examples:
         >>> net = Net()
         >>> loss_fn = nn.SoftmaxCrossEntropyWithLogits(sparse=False)
diff --git a/mindspore/nn/probability/bnn_layers/conv_variational.py b/mindspore/nn/probability/bnn_layers/conv_variational.py
index b4e55255ee4..fc637bbf0b4 100644
--- a/mindspore/nn/probability/bnn_layers/conv_variational.py
+++ b/mindspore/nn/probability/bnn_layers/conv_variational.py
@@ -157,18 +157,16 @@ class _ConvVariational(_Conv):
 
     def compute_kl_loss(self):
         """Compute kl loss"""
-        weight_post_mean = self.weight_posterior("mean")
-        weight_post_sd = self.weight_posterior("sd")
+        weight_args_list = self.weight_posterior("get_dist_args")
+        weight_type = self.weight_posterior("get_dist_type")
 
-        kl = self.weight_prior("kl_loss", "Normal",
-                               weight_post_mean, weight_post_sd)
+        kl = self.weight_prior("kl_loss", weight_type, *weight_args_list)
         kl_loss = self.sum(kl)
         if self.has_bias:
-            bias_post_mean = self.bias_posterior("mean")
-            bias_post_sd = self.bias_posterior("sd")
+            bias_args_list = self.bias_posterior("get_dist_args")
+            bias_type = self.bias_posterior("get_dist_type")
 
-            kl = self.bias_prior("kl_loss", "Normal",
-                                 bias_post_mean, bias_post_sd)
+            kl = self.bias_prior("kl_loss", bias_type, *bias_args_list)
             kl = self.sum(kl)
             kl_loss += kl
         return kl_loss
@@ -249,6 +247,9 @@ class ConvReparam(_ConvVariational):
     Outputs:
         Tensor, with the shape being :math:`(N, C_{out}, H_{out}, W_{out})`.
 
+    Supported Platforms:
+        ``Ascend`` ``GPU``
+
     Examples:
         >>> net = ConvReparam(120, 240, 4, has_bias=False)
         >>> input = Tensor(np.ones([1, 120, 1024, 640]), mindspore.float32)
diff --git a/mindspore/nn/probability/bnn_layers/dense_variational.py b/mindspore/nn/probability/bnn_layers/dense_variational.py
index 7400d7c5641..95989f5cd93 100644
--- a/mindspore/nn/probability/bnn_layers/dense_variational.py
+++ b/mindspore/nn/probability/bnn_layers/dense_variational.py
@@ -18,9 +18,10 @@ from mindspore.common.tensor import Tensor
 from mindspore._checkparam import Validator
 from ...cell import Cell
 from ...layer.activation import get_activation
+from ..distribution.normal import Normal
 from .layer_distribution import NormalPrior, NormalPosterior
 
-__all__ = ['DenseReparam']
+__all__ = ['DenseReparam', 'DenseLocalReparam']
 
 
 class _DenseVariational(Cell):
@@ -122,17 +123,17 @@ class _DenseVariational(Cell):
         return self.bias_add(inputs, bias_posterior_tensor)
 
     def compute_kl_loss(self):
-        """Compute kl loss."""
-        weight_post_mean = self.weight_posterior("mean")
-        weight_post_sd = self.weight_posterior("sd")
+        """Compute kl loss"""
+        weight_args_list = self.weight_posterior("get_dist_args")
+        weight_type = self.weight_posterior("get_dist_type")
 
-        kl = self.weight_prior("kl_loss", "Normal", weight_post_mean, weight_post_sd)
+        kl = self.weight_prior("kl_loss", weight_type, *weight_args_list)
         kl_loss = self.sum(kl)
         if self.has_bias:
-            bias_post_mean = self.bias_posterior("mean")
-            bias_post_sd = self.bias_posterior("sd")
+            bias_args_list = self.bias_posterior("get_dist_args")
+            bias_type = self.bias_posterior("get_dist_type")
 
-            kl = self.bias_prior("kl_loss", "Normal", bias_post_mean, bias_post_sd)
+            kl = self.bias_prior("kl_loss", bias_type, *bias_args_list)
             kl = self.sum(kl)
             kl_loss += kl
         return kl_loss
@@ -187,6 +188,9 @@ class DenseReparam(_DenseVariational):
     Outputs:
         Tensor, the shape of the tensor is :math:`(N, out\_channels)`.
 
+    Supported Platforms:
+        ``Ascend`` ``GPU``
+
     Examples:
         >>> net = DenseReparam(3, 4)
         >>> input = Tensor(np.random.randint(0, 255, [2, 3]), mindspore.float32)
@@ -220,3 +224,95 @@ class DenseReparam(_DenseVariational):
         weight_posterior_tensor = self.weight_posterior("sample")
         outputs = self.matmul(inputs, weight_posterior_tensor)
         return outputs
+
+
+class DenseLocalReparam(_DenseVariational):
+    r"""
+    Dense variational layer with Local Reparameterization.
+
+    For more details, refer to the paper `Variational Dropout and the Local Reparameterization
+    Trick <https://arxiv.org/abs/1506.02557>`_.
+
+    Applies a dense-connected layer to the input. This layer implements the operation as:
+
+    .. math::
+        \text{outputs} = \text{activation}(\text{inputs} * \text{weight} + \text{bias}),
+
+    where :math:`\text{activation}` is the activation function passed as the activation
+    argument (if passed in), :math:`\text{weight}` is a weight matrix with the same data
+    type as the inputs, sampled from the posterior distribution of the weight, and
+    :math:`\text{bias}` is a bias vector with the same data type as the inputs created by
+    the layer (only if has_bias is True). The bias vector is sampled from the posterior
+    distribution of :math:`\text{bias}`.
+
+    Args:
+        in_channels (int): The number of input channels.
+        out_channels (int): The number of output channels.
+        has_bias (bool): Specifies whether the layer uses a bias vector. Default: True.
+        activation (str, Cell): The activation function applied to the output of the layer. The type of `activation`
+            can be a string (eg. 'relu') or a Cell (eg. nn.ReLU()). Note that if the type of activation is Cell,
+            it must be instantiated beforehand. Default: None.
+        weight_prior_fn: The prior distribution for weight.
+            It must return a mindspore distribution instance.
+            Default: NormalPrior (which creates an instance of standard
+            normal distribution). The current version only supports normal distribution.
+        weight_posterior_fn: The posterior distribution for sampling weight.
+            It must be a function handle which returns a mindspore
+            distribution instance. Default: lambda name, shape: NormalPosterior(name=name, shape=shape).
+            The current version only supports normal distribution.
+        bias_prior_fn: The prior distribution for the bias vector. It must return
+            a mindspore distribution. Default: NormalPrior (which creates an
+            instance of standard normal distribution). The current version
+            only supports normal distribution.
+        bias_posterior_fn: The posterior distribution for sampling the bias vector.
+            It must be a function handle which returns a mindspore
+            distribution instance. Default: lambda name, shape: NormalPosterior(name=name, shape=shape).
+            The current version only supports normal distribution.
+
+    Inputs:
+        - **input** (Tensor) - The shape of the tensor is :math:`(N, in\_channels)`.
+
+    Outputs:
+        Tensor, the shape of the tensor is :math:`(N, out\_channels)`.
+
+    Supported Platforms:
+        ``Ascend`` ``GPU``
+
+    Examples:
+        >>> net = DenseLocalReparam(3, 4)
+        >>> input = Tensor(np.random.randint(0, 255, [2, 3]), mindspore.float32)
+        >>> output = net(input).shape
+        >>> print(output)
+        (2, 4)
+    """
+
+    def __init__(
+            self,
+            in_channels,
+            out_channels,
+            activation=None,
+            has_bias=True,
+            weight_prior_fn=NormalPrior,
+            weight_posterior_fn=lambda name, shape: NormalPosterior(name=name, shape=shape),
+            bias_prior_fn=NormalPrior,
+            bias_posterior_fn=lambda name, shape: NormalPosterior(name=name, shape=shape)):
+        super(DenseLocalReparam, self).__init__(
+            in_channels,
+            out_channels,
+            activation=activation,
+            has_bias=has_bias,
+            weight_prior_fn=weight_prior_fn,
+            weight_posterior_fn=weight_posterior_fn,
+            bias_prior_fn=bias_prior_fn,
+            bias_posterior_fn=bias_posterior_fn
+        )
+        self.sqrt = P.Sqrt()
+        self.square = P.Square()
+        self.normal = Normal()
+
+    def _apply_variational_weight(self, inputs):
+        mean = self.matmul(inputs, self.weight_posterior("mean"))
+        std = self.sqrt(self.matmul(self.square(inputs), self.square(self.weight_posterior("sd"))))
+        weight_posterior_affine_tensor = self.normal("sample", mean=mean, sd=std)
+        return weight_posterior_affine_tensor
diff --git a/mindspore/nn/probability/bnn_layers/layer_distribution.py b/mindspore/nn/probability/bnn_layers/layer_distribution.py
index 9bfe7ec2646..ea221d737b4 100644
--- a/mindspore/nn/probability/bnn_layers/layer_distribution.py
+++ b/mindspore/nn/probability/bnn_layers/layer_distribution.py
@@ -36,6 +36,9 @@ class NormalPrior(Cell):
     Returns:
         Cell, a normal distribution.
+
+    Supported Platforms:
+        ``Ascend`` ``GPU``
     """
     def __init__(self, dtype=mstype.float32, mean=0, std=0.1):
         super(NormalPrior, self).__init__()
@@ -62,6 +65,9 @@ class NormalPosterior(Cell):
     Returns:
         Cell, a normal distribution.
+
+    Supported Platforms:
+        ``Ascend`` ``GPU``
     """
     def __init__(self,
                  name,
diff --git a/mindspore/nn/probability/dpn/vae/cvae.py b/mindspore/nn/probability/dpn/vae/cvae.py
index 2028059411a..ccda4444f0c 100644
--- a/mindspore/nn/probability/dpn/vae/cvae.py
+++ b/mindspore/nn/probability/dpn/vae/cvae.py
@@ -49,6 +49,9 @@ class ConditionalVAE(Cell):
     Outputs:
         - **output** (tuple) - (recon_x(Tensor), x(Tensor), mu(Tensor), std(Tensor)).
+
+    Supported Platforms:
+        ``Ascend`` ``GPU``
     """
 
     def __init__(self, encoder, decoder, hidden_size, latent_size, num_classes):
diff --git a/mindspore/nn/probability/dpn/vae/vae.py b/mindspore/nn/probability/dpn/vae/vae.py
index e743cab28d7..37dc61b6942 100644
--- a/mindspore/nn/probability/dpn/vae/vae.py
+++ b/mindspore/nn/probability/dpn/vae/vae.py
@@ -44,6 +44,9 @@ class VAE(Cell):
     Outputs:
         - **output** (Tuple) - (recon_x(Tensor), x(Tensor), mu(Tensor), std(Tensor)).
+
+    Supported Platforms:
+        ``Ascend`` ``GPU``
     """
 
     def __init__(self, encoder, decoder, hidden_size, latent_size):
diff --git a/mindspore/nn/probability/infer/variational/elbo.py b/mindspore/nn/probability/infer/variational/elbo.py
index f76faccb1cf..35051c8639b 100644
--- a/mindspore/nn/probability/infer/variational/elbo.py
+++ b/mindspore/nn/probability/infer/variational/elbo.py
@@ -41,6 +41,9 @@ class ELBO(Cell):
     Outputs:
         Tensor, loss float tensor.
+
+    Supported Platforms:
+        ``Ascend`` ``GPU``
     """
 
     def __init__(self, latent_prior='Normal', output_prior='Normal'):
diff --git a/mindspore/nn/probability/infer/variational/svi.py b/mindspore/nn/probability/infer/variational/svi.py
index f40ade88cbe..bed8aa99478 100644
--- a/mindspore/nn/probability/infer/variational/svi.py
+++ b/mindspore/nn/probability/infer/variational/svi.py
@@ -34,6 +34,9 @@ class SVI:
         net_with_loss(Cell): Cell with loss function.
         optimizer (Cell): Optimizer for updating the weights.
 
+    Supported Platforms:
+        ``Ascend`` ``GPU``
+
     """
 
     def __init__(self, net_with_loss, optimizer):
diff --git a/mindspore/nn/probability/toolbox/anomaly_detection.py b/mindspore/nn/probability/toolbox/anomaly_detection.py
index 4673bace70a..87a635dfc8a 100644
--- a/mindspore/nn/probability/toolbox/anomaly_detection.py
+++ b/mindspore/nn/probability/toolbox/anomaly_detection.py
@@ -34,6 +34,9 @@ class VAEAnomalyDetection:
         hidden_size(int): The size of encoder's output tensor.
         latent_size(int): The size of the latent space.
 
+    Supported Platforms:
+        ``Ascend`` ``GPU``
+
     """
 
     def __init__(self, encoder, decoder, hidden_size=400, latent_size=20):
diff --git a/mindspore/nn/probability/toolbox/uncertainty_evaluation.py b/mindspore/nn/probability/toolbox/uncertainty_evaluation.py
index 113c081eee0..b624e079c8b 100644
--- a/mindspore/nn/probability/toolbox/uncertainty_evaluation.py
+++ b/mindspore/nn/probability/toolbox/uncertainty_evaluation.py
@@ -53,6 +53,9 @@ class UncertaintyEvaluation:
             the the path of the uncertainty model; if the path is not given , it will not save or load the
             uncertainty model. Default: False.
 
+    Supported Platforms:
+        ``Ascend`` ``GPU``
+
     Examples:
         >>> network = LeNet()
         >>> param_dict = load_checkpoint('checkpoint_lenet.ckpt')
diff --git a/mindspore/nn/probability/transforms/transform_bnn.py b/mindspore/nn/probability/transforms/transform_bnn.py
index 7bb15088c66..0d1a08de082 100644
--- a/mindspore/nn/probability/transforms/transform_bnn.py
+++ b/mindspore/nn/probability/transforms/transform_bnn.py
@@ -34,6 +34,9 @@ class TransformToBNN:
         dnn_factor ((int, float): The coefficient of backbone's loss, which is computed by loss function.
            Default: 1.
         bnn_factor (int, float): The coefficient of KL loss, which is KL divergence of Bayesian layer. Default: 1.
 
+    Supported Platforms:
+        ``Ascend`` ``GPU``
+
     Examples:
         >>> class Net(nn.Cell):
         ...     def __init__(self):
@@ -57,7 +60,7 @@ class TransformToBNN:
         >>> optim = Momentum(params=net.trainable_params(), learning_rate=0.1, momentum=0.9)
         >>> net_with_loss = WithLossCell(network, criterion)
         >>> train_network = TrainOneStepCell(net_with_loss, optim)
-        >>> bnn_transformer = TransformToBNN(train_network, 60000, 0.1)
+        >>> bnn_transformer = TransformToBNN(train_network, 60000, 0.0001)
     """
 
     def __init__(self, trainable_dnn, dnn_factor=1, bnn_factor=1):
@@ -105,6 +108,9 @@ class TransformToBNN:
         Returns:
             Cell, a trainable BNN model wrapped by TrainOneStepCell.
 
+        Supported Platforms:
+            ``Ascend`` ``GPU``
+
         Examples:
             >>> net = Net()
             >>> criterion = nn.SoftmaxCrossEntropyWithLogits(sparse=True)
@@ -147,6 +153,9 @@ class TransformToBNN:
             Cell, a trainable model wrapped by TrainOneStepCell, whose specific type of layer is transformed to the
             corresponding bayesian layer.
 
+        Supported Platforms:
+            ``Ascend`` ``GPU``
+
         Examples:
             >>> net = Net()
             >>> criterion = nn.SoftmaxCrossEntropyWithLogits(sparse=True)
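
The refactored compute_kl_loss now asks each posterior for its distribution type and arguments ("get_dist_type" / "get_dist_args") and lets the prior evaluate "kl_loss" generically, instead of hard-coding "Normal" with the posterior mean and sd. For reference, a minimal NumPy sketch of the quantity being summed, assuming both prior and posterior are normal (the only case the current version supports); the helper name and shapes below are illustrative and not part of the MindSpore API.

import numpy as np

def normal_kl(post_mean, post_sd, prior_mean=0.0, prior_sd=1.0):
    """Element-wise KL(N(post_mean, post_sd) || N(prior_mean, prior_sd))."""
    var_ratio = (post_sd / prior_sd) ** 2
    mean_term = ((post_mean - prior_mean) / prior_sd) ** 2
    return 0.5 * (var_ratio + mean_term - 1.0 - np.log(var_ratio))

# kl_loss in the patch is this quantity summed over all weight (and bias) entries.
post_mean = np.zeros((3, 4))
post_sd = 0.1 * np.ones((3, 4))
print(normal_kl(post_mean, post_sd).sum())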
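
The new DenseLocalReparam._apply_variational_weight implements the local reparameterization trick: rather than sampling a weight matrix and then multiplying, it samples the pre-activation directly from a normal distribution whose mean is inputs * mu_W and whose standard deviation is sqrt(inputs^2 * sigma_W^2), which gives lower-variance gradient estimates per mini-batch. Below is a plain NumPy sketch of that computation; names and shapes are illustrative and do not reflect the MindSpore weight layout.

import numpy as np

def dense_local_reparam(inputs, weight_mean, weight_sd, rng=None):
    """Sample the layer's pre-activation with the local reparameterization trick.

    inputs:      (batch, in_channels) activations
    weight_mean: (in_channels, out_channels) posterior mean of the weights
    weight_sd:   (in_channels, out_channels) posterior standard deviation of the weights
    """
    rng = np.random.default_rng() if rng is None else rng
    out_mean = inputs @ weight_mean             # mean of the pre-activation
    out_var = (inputs ** 2) @ (weight_sd ** 2)  # variance of the pre-activation
    return out_mean + np.sqrt(out_var) * rng.standard_normal(out_mean.shape)

# Illustrative shapes only, mirroring the (2, 3) -> (2, 4) docstring example.
x = np.random.rand(2, 3)
print(dense_local_reparam(x, np.zeros((3, 4)), 0.1 * np.ones((3, 4))).shape)  # (2, 4)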