clean the warning code

This commit is contained in:
王程浩 2022-03-23 08:21:23 +00:00 committed by cheng-hao-wang
parent 5c72a3c2bc
commit 77769f2a41
3 changed files with 6 additions and 6 deletions

View File

@ -190,7 +190,7 @@ class Bijector(Cell):
""" """
if 'param_dict' not in self.parameters.keys(): if 'param_dict' not in self.parameters.keys():
return None return None
param_dict = self.parameters['param_dict'] param_dict = self.parameters.get('param_dict')
broadcast_shape_tensor = None broadcast_shape_tensor = None
for value in param_dict.values(): for value in param_dict.values():
if value is None: if value is None:
@ -208,7 +208,7 @@ class Bijector(Cell):
""" """
if 'param_dict' not in self.parameters.keys(): if 'param_dict' not in self.parameters.keys():
return False return False
param_dict = self.parameters['param_dict'] param_dict = self.parameters.get('param_dict')
for value in param_dict.values(): for value in param_dict.values():
if value is None: if value is None:
continue continue
@ -327,4 +327,4 @@ class Bijector(Cell):
return self.forward_log_jacobian(*args, **kwargs) return self.forward_log_jacobian(*args, **kwargs)
if name == 'inverse_log_jacobian': if name == 'inverse_log_jacobian':
return self.inverse_log_jacobian(*args, **kwargs) return self.inverse_log_jacobian(*args, **kwargs)
return None raise Exception('Invalid name')

View File

@ -218,7 +218,7 @@ class Distribution(Cell):
""" """
Check if the parameters used during initialization are scalars. Check if the parameters used during initialization are scalars.
""" """
param_dict = self.parameters['param_dict'] param_dict = self.parameters.get('param_dict')
for value in param_dict.values(): for value in param_dict.values():
if value is None: if value is None:
continue continue

View File

@ -67,10 +67,10 @@ class Gumbel(TransformedDistribution):
TypeError: When the input `dtype` is not a subclass of float. TypeError: When the input `dtype` is not a subclass of float.
Examples: Examples:
>>> import numpy as np
>>> import mindspore >>> import mindspore
>>> import mindspore.nn as nn >>> import numpy as np
>>> import mindspore.nn.probability.distribution as msd >>> import mindspore.nn.probability.distribution as msd
>>> import mindspore.nn as nn
>>> from mindspore import Tensor >>> from mindspore import Tensor
>>> class Prob(nn.Cell): >>> class Prob(nn.Cell):
... def __init__(self): ... def __init__(self):