forked from mindspore-Ecosystem/mindspore
!4924 Modify API comments and fix error of st
Merge pull request !4924 from byweng/fix_param_check
This commit is contained in:
commit
ab45bec828
|
@ -67,8 +67,13 @@ class WithBNNLossCell:
|
|||
def __init__(self, backbone, loss_fn, dnn_factor=1, bnn_factor=1):
|
||||
if isinstance(dnn_factor, bool) or not isinstance(dnn_factor, (int, float)):
|
||||
raise TypeError('The type of `dnn_factor` should be `int` or `float`')
|
||||
if dnn_factor < 0:
|
||||
raise ValueError('The value of `dnn_factor` should >= 0')
|
||||
|
||||
if isinstance(bnn_factor, bool) or not isinstance(bnn_factor, (int, float)):
|
||||
raise TypeError('The type of `bnn_factor` should be `int` or `float`')
|
||||
if bnn_factor < 0:
|
||||
raise ValueError('The value of `bnn_factor` should >= 0')
|
||||
|
||||
self.backbone = backbone
|
||||
self.loss_fn = loss_fn
|
||||
|
|
|
@ -61,12 +61,6 @@ class _ConvVariational(_Conv):
|
|||
raise ValueError('Attr \'pad_mode\' of \'Conv2d\' Op passed '
|
||||
+ str(pad_mode) + ', should be one of values in \'valid\', \'same\', \'pad\'.')
|
||||
|
||||
if isinstance(stride, bool) or not isinstance(stride, (int, tuple)):
|
||||
raise TypeError('The type of `stride` should be `int` of `tuple`')
|
||||
|
||||
if isinstance(dilation, bool) or not isinstance(dilation, (int, tuple)):
|
||||
raise TypeError('The type of `dilation` should be `int` of `tuple`')
|
||||
|
||||
# convolution args
|
||||
self.in_channels = in_channels
|
||||
self.out_channels = out_channels
|
||||
|
|
|
@ -29,7 +29,7 @@ class NormalPrior(Cell):
|
|||
To initialize a normal distribution of mean 0 and standard deviation 0.1.
|
||||
|
||||
Args:
|
||||
dtype (class `mindspore.dtype`): The argument is used to define the data type of the output tensor.
|
||||
dtype (:class:`mindspore.dtype`): The argument is used to define the data type of the output tensor.
|
||||
Default: mindspore.float32.
|
||||
mean (int, float): Mean of normal distribution.
|
||||
std (int, float): Standard deviation of normal distribution.
|
||||
|
@ -52,7 +52,7 @@ class NormalPosterior(Cell):
|
|||
Args:
|
||||
name (str): Name prepended to trainable parameter.
|
||||
shape (list, tuple): Shape of the mean and standard deviation.
|
||||
dtype (class `mindspore.dtype`): The argument is used to define the data type of the output tensor.
|
||||
dtype (:class:`mindspore.dtype`): The argument is used to define the data type of the output tensor.
|
||||
Default: mindspore.float32.
|
||||
loc_mean (int, float): Mean of distribution to initialize trainable parameters. Default: 0.
|
||||
loc_std (int, float): Standard deviation of distribution to initialize trainable parameters. Default: 0.1.
|
||||
|
|
|
@ -63,8 +63,13 @@ class TransformToBNN:
|
|||
def __init__(self, trainable_dnn, dnn_factor=1, bnn_factor=1):
|
||||
if isinstance(dnn_factor, bool) or not isinstance(dnn_factor, (int, float)):
|
||||
raise TypeError('The type of `dnn_factor` should be `int` or `float`')
|
||||
if dnn_factor < 0:
|
||||
raise ValueError('The value of `dnn_factor` should >= 0')
|
||||
|
||||
if isinstance(bnn_factor, bool) or not isinstance(bnn_factor, (int, float)):
|
||||
raise TypeError('The type of `bnn_factor` should be `int` or `float`')
|
||||
if bnn_factor < 0:
|
||||
raise ValueError('The value of `bnn_factor` should >= 0')
|
||||
|
||||
net_with_loss = trainable_dnn.network
|
||||
self.optimizer = trainable_dnn.optimizer
|
||||
|
@ -88,9 +93,9 @@ class TransformToBNN:
|
|||
Transform the whole DNN model to BNN model, and wrap BNN model by TrainOneStepCell.
|
||||
|
||||
Args:
|
||||
get_dense_args (function): The arguments gotten from the DNN full connection layer. Default: lambda dp:
|
||||
get_dense_args (:class:`function`): The arguments gotten from the DNN full connection layer. Default: lambda dp:
|
||||
{"in_channels": dp.in_channels, "out_channels": dp.out_channels, "has_bias": dp.has_bias}.
|
||||
get_conv_args (function): The arguments gotten from the DNN convolutional layer. Default: lambda dp:
|
||||
get_conv_args (:class:`function`): The arguments gotten from the DNN convolutional layer. Default: lambda dp:
|
||||
{"in_channels": dp.in_channels, "out_channels": dp.out_channels, "pad_mode": dp.pad_mode,
|
||||
"kernel_size": dp.kernel_size, "stride": dp.stride, "has_bias": dp.has_bias}.
|
||||
add_dense_args (dict): The new arguments added to BNN full connection layer. Note that the arguments in
|
||||
|
@ -134,10 +139,10 @@ class TransformToBNN:
|
|||
|
||||
Args:
|
||||
dnn_layer_type (Cell): The type of DNN layer to be transformed to BNN layer. The optional values are
|
||||
nn.Dense, nn.Conv2d.
|
||||
nn.Dense, nn.Conv2d.
|
||||
bnn_layer_type (Cell): The type of BNN layer to be transformed to. The optional values are
|
||||
DenseReparameterization, ConvReparameterization.
|
||||
get_args (dict): The arguments gotten from the DNN layer. Default: None.
|
||||
DenseReparam, ConvReparam.
|
||||
get_args (:class:`function`): The arguments gotten from the DNN layer. Default: None.
|
||||
add_args (dict): The new arguments added to BNN layer. Note that the arguments in `add_args` should not
|
||||
duplicate arguments in `get_args`. Default: None.
|
||||
|
||||
|
|
|
@ -108,22 +108,22 @@ class VaeGan(nn.Cell):
|
|||
return ld_real, ld_fake, ld_p, recon_x, x, mu, std
|
||||
|
||||
|
||||
class VaeGanLoss(nn.Cell):
|
||||
class VaeGanLoss(ELBO):
|
||||
def __init__(self):
|
||||
super(VaeGanLoss, self).__init__()
|
||||
self.zeros = P.ZerosLike()
|
||||
self.mse = nn.MSELoss(reduction='sum')
|
||||
self.elbo = ELBO(latent_prior='Normal', output_prior='Normal')
|
||||
|
||||
def construct(self, data, label):
|
||||
ld_real, ld_fake, ld_p, recon_x, x, mean, std = data
|
||||
ld_real, ld_fake, ld_p, recon_x, x, mu, std = data
|
||||
y_real = self.zeros(ld_real) + 1
|
||||
y_fake = self.zeros(ld_fake)
|
||||
elbo_data = (recon_x, x, mean, std)
|
||||
loss_D = self.mse(ld_real, y_real)
|
||||
loss_GD = self.mse(ld_p, y_fake)
|
||||
loss_G = self.mse(ld_fake, y_real)
|
||||
elbo_loss = self.elbo(elbo_data, label)
|
||||
reconstruct_loss = self.recon_loss(x, recon_x)
|
||||
kl_loss = self.posterior('kl_loss', 'Normal', self.zeros(mu), self.zeros(mu) + 1, mu, std)
|
||||
elbo_loss = reconstruct_loss + self.sum(kl_loss)
|
||||
return loss_D + loss_G + loss_GD + elbo_loss
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue