forked from mindspore-Ecosystem/mindspore
fic param check
This commit is contained in:
parent
5ab41b1c26
commit
9cc415728b
|
@ -65,9 +65,9 @@ class WithBNNLossCell:
|
|||
"""
|
||||
|
||||
def __init__(self, backbone, loss_fn, dnn_factor=1, bnn_factor=1):
|
||||
if not isinstance(dnn_factor, (int, float)):
|
||||
if isinstance(dnn_factor, bool) or not isinstance(dnn_factor, (int, float)):
|
||||
raise TypeError('The type of `dnn_factor` should be `int` or `float`')
|
||||
if not isinstance(bnn_factor, (int, float)):
|
||||
if isinstance(bnn_factor, bool) or not isinstance(bnn_factor, (int, float)):
|
||||
raise TypeError('The type of `bnn_factor` should be `int` or `float`')
|
||||
|
||||
self.backbone = backbone
|
||||
|
|
|
@ -173,13 +173,12 @@ class ConvReparam(_ConvVariational):
|
|||
r"""
|
||||
Convolutional variational layers with Reparameterization.
|
||||
|
||||
See more details in paper `Auto-Encoding Variational Bayes
|
||||
<https://arxiv.org/abs/1312.6114>`
|
||||
See more details in paper `Auto-Encoding Variational Bayes <https://arxiv.org/abs/1312.6114>`_.
|
||||
|
||||
Args:
|
||||
in_channels (int): The number of input channel :math:`C_{in}`.
|
||||
out_channels (int): The number of output channel :math:`C_{out}`.
|
||||
kernel_size (Union[int, tuple[int]]): The data type is int or
|
||||
kernel_size (Union[int, tuple[int]]): The data type is int or
|
||||
tuple with 2 integers. Specifies the height and width of the 2D
|
||||
convolution window. Single int means the value if for both
|
||||
height and width of the kernel. A tuple of 2 ints means the
|
||||
|
|
|
@ -132,8 +132,7 @@ class DenseReparam(_DenseVariational):
|
|||
r"""
|
||||
Dense variational layers with Reparameterization.
|
||||
|
||||
See more details in paper `Auto-Encoding Variational Bayes
|
||||
<https://arxiv.org/abs/1312.6114>`
|
||||
See more details in paper `Auto-Encoding Variational Bayes <https://arxiv.org/abs/1312.6114>`_.
|
||||
|
||||
Applies dense-connected layer for the input. This layer implements the operation as:
|
||||
|
||||
|
|
|
@ -78,16 +78,17 @@ class NormalPosterior(Cell):
|
|||
if not isinstance(shape, (tuple, list)):
|
||||
raise TypeError('The type of `shape` should be `tuple` or `list`')
|
||||
|
||||
if not isinstance(loc_mean, (int, float)):
|
||||
if isinstance(loc_mean, bool) or not isinstance(loc_mean, (int, float)):
|
||||
raise TypeError('The type of `loc_mean` should be `int` or `float`')
|
||||
|
||||
if not isinstance(untransformed_scale_mean, (int, float)):
|
||||
if isinstance(untransformed_scale_mean, bool) or not isinstance(untransformed_scale_mean, (int, float)):
|
||||
raise TypeError('The type of `untransformed_scale_mean` should be `int` or `float`')
|
||||
|
||||
if not (isinstance(loc_std, (int, float)) and loc_std >= 0):
|
||||
if isinstance(loc_std, bool) or not (isinstance(loc_std, (int, float)) and loc_std >= 0):
|
||||
raise TypeError('The type of `loc_std` should be `int` or `float` and its value should > 0')
|
||||
|
||||
if not (isinstance(untransformed_scale_std, (int, float)) and untransformed_scale_std >= 0):
|
||||
if isinstance(loc_std, bool) or not (isinstance(untransformed_scale_std, (int, float)) and
|
||||
untransformed_scale_std >= 0):
|
||||
raise TypeError('The type of `untransformed_scale_std` should be `int` or `float` and '
|
||||
'its value should > 0')
|
||||
|
||||
|
|
|
@ -61,9 +61,9 @@ class TransformToBNN:
|
|||
"""
|
||||
|
||||
def __init__(self, trainable_dnn, dnn_factor=1, bnn_factor=1):
|
||||
if not isinstance(dnn_factor, (int, float)):
|
||||
if isinstance(dnn_factor, bool) or not isinstance(dnn_factor, (int, float)):
|
||||
raise TypeError('The type of `dnn_factor` should be `int` or `float`')
|
||||
if not isinstance(bnn_factor, (int, float)):
|
||||
if isinstance(bnn_factor, bool) or not isinstance(bnn_factor, (int, float)):
|
||||
raise TypeError('The type of `bnn_factor` should be `int` or `float`')
|
||||
|
||||
net_with_loss = trainable_dnn.network
|
||||
|
|
Loading…
Reference in New Issue