From 77769f2a41df90ac203eb267345a84ffdabcfa43 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=8E=8B=E7=A8=8B=E6=B5=A9?= Date: Wed, 23 Mar 2022 08:21:23 +0000 Subject: [PATCH] clean the warning code --- .../python/mindspore/nn/probability/bijector/bijector.py | 6 +++--- .../mindspore/nn/probability/distribution/distribution.py | 2 +- .../python/mindspore/nn/probability/distribution/gumbel.py | 4 ++-- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/mindspore/python/mindspore/nn/probability/bijector/bijector.py b/mindspore/python/mindspore/nn/probability/bijector/bijector.py index 8f10d54e720..07b0c930476 100644 --- a/mindspore/python/mindspore/nn/probability/bijector/bijector.py +++ b/mindspore/python/mindspore/nn/probability/bijector/bijector.py @@ -190,7 +190,7 @@ class Bijector(Cell): """ if 'param_dict' not in self.parameters.keys(): return None - param_dict = self.parameters['param_dict'] + param_dict = self.parameters.get('param_dict') broadcast_shape_tensor = None for value in param_dict.values(): if value is None: @@ -208,7 +208,7 @@ class Bijector(Cell): """ if 'param_dict' not in self.parameters.keys(): return False - param_dict = self.parameters['param_dict'] + param_dict = self.parameters.get('param_dict') for value in param_dict.values(): if value is None: continue @@ -327,4 +327,4 @@ class Bijector(Cell): return self.forward_log_jacobian(*args, **kwargs) if name == 'inverse_log_jacobian': return self.inverse_log_jacobian(*args, **kwargs) - return None + raise Exception('Invalid name') diff --git a/mindspore/python/mindspore/nn/probability/distribution/distribution.py b/mindspore/python/mindspore/nn/probability/distribution/distribution.py index 208103f181b..bbdf1568daa 100644 --- a/mindspore/python/mindspore/nn/probability/distribution/distribution.py +++ b/mindspore/python/mindspore/nn/probability/distribution/distribution.py @@ -218,7 +218,7 @@ class Distribution(Cell): """ Check if the parameters used during initialization are scalars. """ - param_dict = self.parameters['param_dict'] + param_dict = self.parameters.get('param_dict') for value in param_dict.values(): if value is None: continue diff --git a/mindspore/python/mindspore/nn/probability/distribution/gumbel.py b/mindspore/python/mindspore/nn/probability/distribution/gumbel.py index 67d6562b214..8f6a43eba06 100644 --- a/mindspore/python/mindspore/nn/probability/distribution/gumbel.py +++ b/mindspore/python/mindspore/nn/probability/distribution/gumbel.py @@ -67,10 +67,10 @@ class Gumbel(TransformedDistribution): TypeError: When the input `dtype` is not a subclass of float. Examples: - >>> import numpy as np >>> import mindspore - >>> import mindspore.nn as nn + >>> import numpy as np >>> import mindspore.nn.probability.distribution as msd + >>> import mindspore.nn as nn >>> from mindspore import Tensor >>> class Prob(nn.Cell): ... def __init__(self):