forked from mindspore-Ecosystem/mindspore
!4924 Modify API comments and fix error of st
Merge pull request !4924 from byweng/fix_param_check
This commit is contained in:
commit
ab45bec828
|
@ -67,8 +67,13 @@ class WithBNNLossCell:
|
||||||
def __init__(self, backbone, loss_fn, dnn_factor=1, bnn_factor=1):
|
def __init__(self, backbone, loss_fn, dnn_factor=1, bnn_factor=1):
|
||||||
if isinstance(dnn_factor, bool) or not isinstance(dnn_factor, (int, float)):
|
if isinstance(dnn_factor, bool) or not isinstance(dnn_factor, (int, float)):
|
||||||
raise TypeError('The type of `dnn_factor` should be `int` or `float`')
|
raise TypeError('The type of `dnn_factor` should be `int` or `float`')
|
||||||
|
if dnn_factor < 0:
|
||||||
|
raise ValueError('The value of `dnn_factor` should >= 0')
|
||||||
|
|
||||||
if isinstance(bnn_factor, bool) or not isinstance(bnn_factor, (int, float)):
|
if isinstance(bnn_factor, bool) or not isinstance(bnn_factor, (int, float)):
|
||||||
raise TypeError('The type of `bnn_factor` should be `int` or `float`')
|
raise TypeError('The type of `bnn_factor` should be `int` or `float`')
|
||||||
|
if bnn_factor < 0:
|
||||||
|
raise ValueError('The value of `bnn_factor` should >= 0')
|
||||||
|
|
||||||
self.backbone = backbone
|
self.backbone = backbone
|
||||||
self.loss_fn = loss_fn
|
self.loss_fn = loss_fn
|
||||||
|
|
|
@ -61,12 +61,6 @@ class _ConvVariational(_Conv):
|
||||||
raise ValueError('Attr \'pad_mode\' of \'Conv2d\' Op passed '
|
raise ValueError('Attr \'pad_mode\' of \'Conv2d\' Op passed '
|
||||||
+ str(pad_mode) + ', should be one of values in \'valid\', \'same\', \'pad\'.')
|
+ str(pad_mode) + ', should be one of values in \'valid\', \'same\', \'pad\'.')
|
||||||
|
|
||||||
if isinstance(stride, bool) or not isinstance(stride, (int, tuple)):
|
|
||||||
raise TypeError('The type of `stride` should be `int` of `tuple`')
|
|
||||||
|
|
||||||
if isinstance(dilation, bool) or not isinstance(dilation, (int, tuple)):
|
|
||||||
raise TypeError('The type of `dilation` should be `int` of `tuple`')
|
|
||||||
|
|
||||||
# convolution args
|
# convolution args
|
||||||
self.in_channels = in_channels
|
self.in_channels = in_channels
|
||||||
self.out_channels = out_channels
|
self.out_channels = out_channels
|
||||||
|
|
|
@ -29,7 +29,7 @@ class NormalPrior(Cell):
|
||||||
To initialize a normal distribution of mean 0 and standard deviation 0.1.
|
To initialize a normal distribution of mean 0 and standard deviation 0.1.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
dtype (class `mindspore.dtype`): The argument is used to define the data type of the output tensor.
|
dtype (:class:`mindspore.dtype`): The argument is used to define the data type of the output tensor.
|
||||||
Default: mindspore.float32.
|
Default: mindspore.float32.
|
||||||
mean (int, float): Mean of normal distribution.
|
mean (int, float): Mean of normal distribution.
|
||||||
std (int, float): Standard deviation of normal distribution.
|
std (int, float): Standard deviation of normal distribution.
|
||||||
|
@ -52,7 +52,7 @@ class NormalPosterior(Cell):
|
||||||
Args:
|
Args:
|
||||||
name (str): Name prepended to trainable parameter.
|
name (str): Name prepended to trainable parameter.
|
||||||
shape (list, tuple): Shape of the mean and standard deviation.
|
shape (list, tuple): Shape of the mean and standard deviation.
|
||||||
dtype (class `mindspore.dtype`): The argument is used to define the data type of the output tensor.
|
dtype (:class:`mindspore.dtype`): The argument is used to define the data type of the output tensor.
|
||||||
Default: mindspore.float32.
|
Default: mindspore.float32.
|
||||||
loc_mean (int, float): Mean of distribution to initialize trainable parameters. Default: 0.
|
loc_mean (int, float): Mean of distribution to initialize trainable parameters. Default: 0.
|
||||||
loc_std (int, float): Standard deviation of distribution to initialize trainable parameters. Default: 0.1.
|
loc_std (int, float): Standard deviation of distribution to initialize trainable parameters. Default: 0.1.
|
||||||
|
|
|
@ -63,8 +63,13 @@ class TransformToBNN:
|
||||||
def __init__(self, trainable_dnn, dnn_factor=1, bnn_factor=1):
|
def __init__(self, trainable_dnn, dnn_factor=1, bnn_factor=1):
|
||||||
if isinstance(dnn_factor, bool) or not isinstance(dnn_factor, (int, float)):
|
if isinstance(dnn_factor, bool) or not isinstance(dnn_factor, (int, float)):
|
||||||
raise TypeError('The type of `dnn_factor` should be `int` or `float`')
|
raise TypeError('The type of `dnn_factor` should be `int` or `float`')
|
||||||
|
if dnn_factor < 0:
|
||||||
|
raise ValueError('The value of `dnn_factor` should >= 0')
|
||||||
|
|
||||||
if isinstance(bnn_factor, bool) or not isinstance(bnn_factor, (int, float)):
|
if isinstance(bnn_factor, bool) or not isinstance(bnn_factor, (int, float)):
|
||||||
raise TypeError('The type of `bnn_factor` should be `int` or `float`')
|
raise TypeError('The type of `bnn_factor` should be `int` or `float`')
|
||||||
|
if bnn_factor < 0:
|
||||||
|
raise ValueError('The value of `bnn_factor` should >= 0')
|
||||||
|
|
||||||
net_with_loss = trainable_dnn.network
|
net_with_loss = trainable_dnn.network
|
||||||
self.optimizer = trainable_dnn.optimizer
|
self.optimizer = trainable_dnn.optimizer
|
||||||
|
@ -88,9 +93,9 @@ class TransformToBNN:
|
||||||
Transform the whole DNN model to BNN model, and wrap BNN model by TrainOneStepCell.
|
Transform the whole DNN model to BNN model, and wrap BNN model by TrainOneStepCell.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
get_dense_args (function): The arguments gotten from the DNN full connection layer. Default: lambda dp:
|
get_dense_args (:class:`function`): The arguments gotten from the DNN full connection layer. Default: lambda dp:
|
||||||
{"in_channels": dp.in_channels, "out_channels": dp.out_channels, "has_bias": dp.has_bias}.
|
{"in_channels": dp.in_channels, "out_channels": dp.out_channels, "has_bias": dp.has_bias}.
|
||||||
get_conv_args (function): The arguments gotten from the DNN convolutional layer. Default: lambda dp:
|
get_conv_args (:class:`function`): The arguments gotten from the DNN convolutional layer. Default: lambda dp:
|
||||||
{"in_channels": dp.in_channels, "out_channels": dp.out_channels, "pad_mode": dp.pad_mode,
|
{"in_channels": dp.in_channels, "out_channels": dp.out_channels, "pad_mode": dp.pad_mode,
|
||||||
"kernel_size": dp.kernel_size, "stride": dp.stride, "has_bias": dp.has_bias}.
|
"kernel_size": dp.kernel_size, "stride": dp.stride, "has_bias": dp.has_bias}.
|
||||||
add_dense_args (dict): The new arguments added to BNN full connection layer. Note that the arguments in
|
add_dense_args (dict): The new arguments added to BNN full connection layer. Note that the arguments in
|
||||||
|
@ -134,10 +139,10 @@ class TransformToBNN:
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
dnn_layer_type (Cell): The type of DNN layer to be transformed to BNN layer. The optional values are
|
dnn_layer_type (Cell): The type of DNN layer to be transformed to BNN layer. The optional values are
|
||||||
nn.Dense, nn.Conv2d.
|
nn.Dense, nn.Conv2d.
|
||||||
bnn_layer_type (Cell): The type of BNN layer to be transformed to. The optional values are
|
bnn_layer_type (Cell): The type of BNN layer to be transformed to. The optional values are
|
||||||
DenseReparameterization, ConvReparameterization.
|
DenseReparam, ConvReparam.
|
||||||
get_args (dict): The arguments gotten from the DNN layer. Default: None.
|
get_args (:class:`function`): The arguments gotten from the DNN layer. Default: None.
|
||||||
add_args (dict): The new arguments added to BNN layer. Note that the arguments in `add_args` should not
|
add_args (dict): The new arguments added to BNN layer. Note that the arguments in `add_args` should not
|
||||||
duplicate arguments in `get_args`. Default: None.
|
duplicate arguments in `get_args`. Default: None.
|
||||||
|
|
||||||
|
|
|
@ -108,22 +108,22 @@ class VaeGan(nn.Cell):
|
||||||
return ld_real, ld_fake, ld_p, recon_x, x, mu, std
|
return ld_real, ld_fake, ld_p, recon_x, x, mu, std
|
||||||
|
|
||||||
|
|
||||||
class VaeGanLoss(nn.Cell):
|
class VaeGanLoss(ELBO):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super(VaeGanLoss, self).__init__()
|
super(VaeGanLoss, self).__init__()
|
||||||
self.zeros = P.ZerosLike()
|
self.zeros = P.ZerosLike()
|
||||||
self.mse = nn.MSELoss(reduction='sum')
|
self.mse = nn.MSELoss(reduction='sum')
|
||||||
self.elbo = ELBO(latent_prior='Normal', output_prior='Normal')
|
|
||||||
|
|
||||||
def construct(self, data, label):
|
def construct(self, data, label):
|
||||||
ld_real, ld_fake, ld_p, recon_x, x, mean, std = data
|
ld_real, ld_fake, ld_p, recon_x, x, mu, std = data
|
||||||
y_real = self.zeros(ld_real) + 1
|
y_real = self.zeros(ld_real) + 1
|
||||||
y_fake = self.zeros(ld_fake)
|
y_fake = self.zeros(ld_fake)
|
||||||
elbo_data = (recon_x, x, mean, std)
|
|
||||||
loss_D = self.mse(ld_real, y_real)
|
loss_D = self.mse(ld_real, y_real)
|
||||||
loss_GD = self.mse(ld_p, y_fake)
|
loss_GD = self.mse(ld_p, y_fake)
|
||||||
loss_G = self.mse(ld_fake, y_real)
|
loss_G = self.mse(ld_fake, y_real)
|
||||||
elbo_loss = self.elbo(elbo_data, label)
|
reconstruct_loss = self.recon_loss(x, recon_x)
|
||||||
|
kl_loss = self.posterior('kl_loss', 'Normal', self.zeros(mu), self.zeros(mu) + 1, mu, std)
|
||||||
|
elbo_loss = reconstruct_loss + self.sum(kl_loss)
|
||||||
return loss_D + loss_G + loss_GD + elbo_loss
|
return loss_D + loss_G + loss_GD + elbo_loss
|
||||||
|
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue