diff --git a/mindspore/nn/probability/bnn_layers/bnn_cell_wrapper.py b/mindspore/nn/probability/bnn_layers/bnn_cell_wrapper.py
index 86d91c03d5f..887fe427e32 100644
--- a/mindspore/nn/probability/bnn_layers/bnn_cell_wrapper.py
+++ b/mindspore/nn/probability/bnn_layers/bnn_cell_wrapper.py
@@ -67,8 +67,13 @@ class WithBNNLossCell:
     def __init__(self, backbone, loss_fn, dnn_factor=1, bnn_factor=1):
         if isinstance(dnn_factor, bool) or not isinstance(dnn_factor, (int, float)):
             raise TypeError('The type of `dnn_factor` should be `int` or `float`')
+        if dnn_factor < 0:
+            raise ValueError('The value of `dnn_factor` should >= 0')
+
         if isinstance(bnn_factor, bool) or not isinstance(bnn_factor, (int, float)):
             raise TypeError('The type of `bnn_factor` should be `int` or `float`')
+        if bnn_factor < 0:
+            raise ValueError('The value of `bnn_factor` should >= 0')

         self.backbone = backbone
         self.loss_fn = loss_fn
diff --git a/mindspore/nn/probability/bnn_layers/conv_variational.py b/mindspore/nn/probability/bnn_layers/conv_variational.py
index f0087870d69..0dabc0d0538 100644
--- a/mindspore/nn/probability/bnn_layers/conv_variational.py
+++ b/mindspore/nn/probability/bnn_layers/conv_variational.py
@@ -61,12 +61,6 @@ class _ConvVariational(_Conv):
             raise ValueError('Attr \'pad_mode\' of \'Conv2d\' Op passed '
                              + str(pad_mode) + ', should be one of values in \'valid\', \'same\', \'pad\'.')

-        if isinstance(stride, bool) or not isinstance(stride, (int, tuple)):
-            raise TypeError('The type of `stride` should be `int` of `tuple`')
-
-        if isinstance(dilation, bool) or not isinstance(dilation, (int, tuple)):
-            raise TypeError('The type of `dilation` should be `int` of `tuple`')
-
         # convolution args
         self.in_channels = in_channels
         self.out_channels = out_channels
diff --git a/mindspore/nn/probability/bnn_layers/layer_distribution.py b/mindspore/nn/probability/bnn_layers/layer_distribution.py
index 778d4914f77..87c41682b55 100644
--- a/mindspore/nn/probability/bnn_layers/layer_distribution.py
+++ b/mindspore/nn/probability/bnn_layers/layer_distribution.py
@@ -29,7 +29,7 @@ class NormalPrior(Cell):
     To initialize a normal distribution of mean 0 and standard deviation 0.1.

     Args:
-        dtype (class `mindspore.dtype`): The argument is used to define the data type of the output tensor.
+        dtype (:class:`mindspore.dtype`): The argument is used to define the data type of the output tensor.
             Default: mindspore.float32.
         mean (int, float): Mean of normal distribution.
         std (int, float): Standard deviation of normal distribution.
@@ -52,7 +52,7 @@ class NormalPosterior(Cell):
     Args:
         name (str): Name prepended to trainable parameter.
         shape (list, tuple): Shape of the mean and standard deviation.
-        dtype (class `mindspore.dtype`): The argument is used to define the data type of the output tensor.
+        dtype (:class:`mindspore.dtype`): The argument is used to define the data type of the output tensor.
             Default: mindspore.float32.
         loc_mean (int, float): Mean of distribution to initialize trainable parameters. Default: 0.
         loc_std (int, float): Standard deviation of distribution to initialize trainable parameters. Default: 0.1.
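
The checks added to WithBNNLossCell above (and to TransformToBNN below) follow one pattern. A minimal sketch of that pattern, with `_check_factor` as a purely illustrative helper name that is not part of the patch:

def _check_factor(name, value):
    # bool is a subclass of int in Python, so it has to be rejected explicitly.
    if isinstance(value, bool) or not isinstance(value, (int, float)):
        raise TypeError('The type of `{}` should be `int` or `float`'.format(name))
    # New in this change set: negative loss factors are rejected up front.
    if value < 0:
        raise ValueError('The value of `{}` should >= 0'.format(name))
    return value

_check_factor('dnn_factor', 1)        # passes
try:
    _check_factor('bnn_factor', -0.5)
except ValueError as err:
    print(err)                        # The value of `bnn_factor` should >= 0
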
diff --git a/mindspore/nn/probability/transforms/transform_bnn.py b/mindspore/nn/probability/transforms/transform_bnn.py
index dd3fe68e725..3a728e73927 100644
--- a/mindspore/nn/probability/transforms/transform_bnn.py
+++ b/mindspore/nn/probability/transforms/transform_bnn.py
@@ -63,8 +63,13 @@ class TransformToBNN:
     def __init__(self, trainable_dnn, dnn_factor=1, bnn_factor=1):
         if isinstance(dnn_factor, bool) or not isinstance(dnn_factor, (int, float)):
             raise TypeError('The type of `dnn_factor` should be `int` or `float`')
+        if dnn_factor < 0:
+            raise ValueError('The value of `dnn_factor` should >= 0')
+
         if isinstance(bnn_factor, bool) or not isinstance(bnn_factor, (int, float)):
             raise TypeError('The type of `bnn_factor` should be `int` or `float`')
+        if bnn_factor < 0:
+            raise ValueError('The value of `bnn_factor` should >= 0')

         net_with_loss = trainable_dnn.network
         self.optimizer = trainable_dnn.optimizer
@@ -88,9 +93,9 @@ class TransformToBNN:
         Transform the whole DNN model to BNN model, and wrap BNN model by TrainOneStepCell.

         Args:
-            get_dense_args (function): The arguments gotten from the DNN full connection layer. Default: lambda dp:
+            get_dense_args (:class:`function`): The arguments gotten from the DNN full connection layer. Default: lambda dp:
                 {"in_channels": dp.in_channels, "out_channels": dp.out_channels, "has_bias": dp.has_bias}.
-            get_conv_args (function): The arguments gotten from the DNN convolutional layer. Default: lambda dp:
+            get_conv_args (:class:`function`): The arguments gotten from the DNN convolutional layer. Default: lambda dp:
                 {"in_channels": dp.in_channels, "out_channels": dp.out_channels, "pad_mode": dp.pad_mode,
                  "kernel_size": dp.kernel_size, "stride": dp.stride, "has_bias": dp.has_bias}.
             add_dense_args (dict): The new arguments added to BNN full connection layer. Note that the arguments in
@@ -134,10 +139,10 @@ class TransformToBNN:

         Args:
             dnn_layer_type (Cell): The type of DNN layer to be transformed to BNN layer. The optional values are
-            nn.Dense, nn.Conv2d.
+                nn.Dense, nn.Conv2d.
             bnn_layer_type (Cell): The type of BNN layer to be transformed to. The optional values are
-            DenseReparameterization, ConvReparameterization.
-            get_args (dict): The arguments gotten from the DNN layer. Default: None.
+                DenseReparam, ConvReparam.
+            get_args (:class:`function`): The arguments gotten from the DNN layer. Default: None.
             add_args (dict): The new arguments added to BNN layer. Note that the arguments in `add_args` should not
                 duplicate arguments in `get_args`. Default: None.
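
A minimal usage sketch matching the corrected docstrings above. The small Net, criterion, and optimizer below are illustrative placeholders and not part of the patch; TransformToBNN, transform_to_bnn_layer, and DenseReparam are the names the updated docstrings document, and `get_args` is a function rather than a dict:

import mindspore.nn as nn
from mindspore.nn.probability import bnn_layers, transforms

class Net(nn.Cell):
    # Placeholder network used only to build a trainable wrapper.
    def __init__(self):
        super(Net, self).__init__()
        self.fc = nn.Dense(10, 2)

    def construct(self, x):
        return self.fc(x)

network = Net()
criterion = nn.SoftmaxCrossEntropyWithLogits(sparse=True)
optimizer = nn.Adam(params=network.trainable_params(), learning_rate=0.001)
train_network = nn.TrainOneStepCell(nn.WithLossCell(network, criterion), optimizer)

# dnn_factor and bnn_factor must now be non-negative ints or floats.
bnn_transformer = transforms.TransformToBNN(train_network, dnn_factor=60000, bnn_factor=1)

# Transform only the fully connected layers; the BNN layer type is DenseReparam.
train_bnn_network = bnn_transformer.transform_to_bnn_layer(nn.Dense, bnn_layers.DenseReparam)
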
diff --git a/tests/st/probability/test_gpu_vae_gan.py b/tests/st/probability/test_gpu_vae_gan.py
index adf09275695..7a1108c7154 100644
--- a/tests/st/probability/test_gpu_vae_gan.py
+++ b/tests/st/probability/test_gpu_vae_gan.py
@@ -108,22 +108,22 @@ class VaeGan(nn.Cell):
         return ld_real, ld_fake, ld_p, recon_x, x, mu, std


-class VaeGanLoss(nn.Cell):
+class VaeGanLoss(ELBO):
     def __init__(self):
         super(VaeGanLoss, self).__init__()
         self.zeros = P.ZerosLike()
         self.mse = nn.MSELoss(reduction='sum')
-        self.elbo = ELBO(latent_prior='Normal', output_prior='Normal')

     def construct(self, data, label):
-        ld_real, ld_fake, ld_p, recon_x, x, mean, std = data
+        ld_real, ld_fake, ld_p, recon_x, x, mu, std = data
         y_real = self.zeros(ld_real) + 1
         y_fake = self.zeros(ld_fake)
-        elbo_data = (recon_x, x, mean, std)
         loss_D = self.mse(ld_real, y_real)
         loss_GD = self.mse(ld_p, y_fake)
         loss_G = self.mse(ld_fake, y_real)
-        elbo_loss = self.elbo(elbo_data, label)
+        reconstruct_loss = self.recon_loss(x, recon_x)
+        kl_loss = self.posterior('kl_loss', 'Normal', self.zeros(mu), self.zeros(mu) + 1, mu, std)
+        elbo_loss = reconstruct_loss + self.sum(kl_loss)
         return loss_D + loss_G + loss_GD + elbo_loss
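
For reference, the quantity the refactored VaeGanLoss.construct assembles from ELBO's own members is a reconstruction term plus the KL divergence between N(mu, std^2) and the standard normal prior. An illustrative NumPy sketch of that quantity (the squared-error reconstruction term is an assumption; the actual term depends on how ELBO's recon_loss is configured):

import numpy as np

def elbo_term(x, recon_x, mu, std):
    # Reconstruction term (squared error assumed here for illustration).
    recon = np.sum((recon_x - x) ** 2)
    # Closed-form KL divergence KL(N(mu, std^2) || N(0, 1)), summed over latents.
    kl = 0.5 * np.sum(std ** 2 + mu ** 2 - 1.0 - 2.0 * np.log(std))
    return recon + kl

rng = np.random.default_rng(0)
x = rng.normal(size=(4, 8))
mu = rng.normal(size=(4, 2))
std = np.abs(rng.normal(size=(4, 2))) + 0.5
print(elbo_term(x, x + 0.1, mu, std))
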