From 90de7be0fa08dcb16c81e58a9b3cd2375180ea97 Mon Sep 17 00:00:00 2001 From: chenzomi Date: Tue, 13 Oct 2020 17:53:21 +0800 Subject: [PATCH] [ME] reused `check_type` function --- mindspore/_checkparam.py | 9 +++------ mindspore/common/dtype.py | 2 +- mindspore/nn/probability/distribution/_utils/__init__.py | 1 - mindspore/nn/probability/distribution/_utils/utils.py | 6 ------ mindspore/nn/probability/distribution/bernoulli.py | 5 +++-- mindspore/nn/probability/distribution/categorical.py | 5 +++-- mindspore/nn/probability/distribution/exponential.py | 5 +++-- mindspore/nn/probability/distribution/geometric.py | 5 +++-- mindspore/nn/probability/distribution/logistic.py | 5 +++-- mindspore/nn/probability/distribution/normal.py | 5 +++-- mindspore/nn/probability/distribution/uniform.py | 5 +++-- 11 files changed, 25 insertions(+), 28 deletions(-) diff --git a/mindspore/_checkparam.py b/mindspore/_checkparam.py index 8b020797edd..75b4993bd71 100644 --- a/mindspore/_checkparam.py +++ b/mindspore/_checkparam.py @@ -375,17 +375,14 @@ class Validator: """Type checking.""" def raise_error_msg(): """func for raising error message when check failed""" - type_names = [t.__name__ for t in valid_types] - num_types = len(valid_types) - raise TypeError(f'The type of `{arg_name}` should be {"one of " if num_types > 1 else ""}' - f'{type_names if num_types > 1 else type_names[0]}, but got {type(arg_value).__name__}.') + raise TypeError(f'The type of `{arg_name}` should be in {valid_types}, but got {type(arg_value).__name__}.') if isinstance(arg_value, type(mstype.tensor)): arg_value = arg_value.element_type() - # Notice: bool is subclass of int, so `check_type('x', True, [int])` will check fail, and - # `check_type('x', True, [bool, int])` will check pass if isinstance(arg_value, bool) and bool not in tuple(valid_types): raise_error_msg() + if arg_value in valid_types: + return arg_value if isinstance(arg_value, tuple(valid_types)): return arg_value raise_error_msg() diff --git a/mindspore/common/dtype.py b/mindspore/common/dtype.py index 8351bfb51c0..3289e87d2b6 100644 --- a/mindspore/common/dtype.py +++ b/mindspore/common/dtype.py @@ -118,7 +118,7 @@ number_type = (int8, float64,) int_type = (int8, int16, int32, int64,) -uint_type = (uint8, uint16, uint32, uint64) +uint_type = (uint8, uint16, uint32, uint64,) float_type = (float16, float32, float64,) implicit_conversion_seq = {t: idx for idx, t in enumerate(( diff --git a/mindspore/nn/probability/distribution/_utils/__init__.py b/mindspore/nn/probability/distribution/_utils/__init__.py index 49150d99a65..f4d66d83364 100644 --- a/mindspore/nn/probability/distribution/_utils/__init__.py +++ b/mindspore/nn/probability/distribution/_utils/__init__.py @@ -24,7 +24,6 @@ __all__ = [ 'check_greater_equal_zero', 'check_greater_zero', 'check_prob', - 'check_type', 'exp_generic', 'expm1_generic', 'log_generic', diff --git a/mindspore/nn/probability/distribution/_utils/utils.py b/mindspore/nn/probability/distribution/_utils/utils.py index 30f127b7615..c344e9c06f4 100644 --- a/mindspore/nn/probability/distribution/_utils/utils.py +++ b/mindspore/nn/probability/distribution/_utils/utils.py @@ -206,12 +206,6 @@ def probs_to_logits(probs, is_binary=False): return P.Log()(ps_clamped) -def check_type(data_type, value_type, name): - if not data_type in value_type: - raise TypeError( - f"For {name}, valid type include {value_type}, {data_type} is invalid") - - @constexpr def raise_none_error(name): raise TypeError(f"the type {name} should be subclass of Tensor." diff --git a/mindspore/nn/probability/distribution/bernoulli.py b/mindspore/nn/probability/distribution/bernoulli.py index 120e9b6359d..4181947e728 100644 --- a/mindspore/nn/probability/distribution/bernoulli.py +++ b/mindspore/nn/probability/distribution/bernoulli.py @@ -16,8 +16,9 @@ from mindspore.common import dtype as mstype from mindspore.ops import operations as P from mindspore.ops import composite as C +from mindspore._checkparam import Validator from .distribution import Distribution -from ._utils.utils import check_prob, check_type, check_distribution_name +from ._utils.utils import check_prob, check_distribution_name from ._utils.custom_ops import exp_generic, log_generic @@ -118,7 +119,7 @@ class Bernoulli(Distribution): param = dict(locals()) param['param_dict'] = {'probs': probs} valid_dtype = mstype.int_type + mstype.uint_type + mstype.float_type - check_type(dtype, valid_dtype, type(self).__name__) + Validator.check_type(type(self).__name__, dtype, valid_dtype) super(Bernoulli, self).__init__(seed, dtype, name, param) self._probs = self._add_parameter(probs, 'probs') diff --git a/mindspore/nn/probability/distribution/categorical.py b/mindspore/nn/probability/distribution/categorical.py index 0fce3655577..7cb8705c659 100644 --- a/mindspore/nn/probability/distribution/categorical.py +++ b/mindspore/nn/probability/distribution/categorical.py @@ -16,10 +16,11 @@ import numpy as np from mindspore.ops import operations as P from mindspore.ops import composite as C +from mindspore._checkparam import Validator import mindspore.nn as nn from mindspore.common import dtype as mstype from .distribution import Distribution -from ._utils.utils import check_prob, check_sum_equal_one, check_type, check_rank,\ +from ._utils.utils import check_prob, check_sum_equal_one, check_rank,\ check_distribution_name, raise_not_implemented_util from ._utils.custom_ops import exp_generic, log_generic, broadcast_to @@ -107,7 +108,7 @@ class Categorical(Distribution): param = dict(locals()) param['param_dict'] = {'probs': probs} valid_dtype = mstype.int_type - check_type(dtype, valid_dtype, "Categorical") + Validator.check_type("Categorical", dtype, valid_dtype) super(Categorical, self).__init__(seed, dtype, name, param) self._probs = self._add_parameter(probs, 'probs') diff --git a/mindspore/nn/probability/distribution/exponential.py b/mindspore/nn/probability/distribution/exponential.py index 3edc040f318..c21b20b6122 100644 --- a/mindspore/nn/probability/distribution/exponential.py +++ b/mindspore/nn/probability/distribution/exponential.py @@ -16,9 +16,10 @@ import numpy as np from mindspore.ops import operations as P from mindspore.ops import composite as C +from mindspore._checkparam import Validator from mindspore.common import dtype as mstype from .distribution import Distribution -from ._utils.utils import check_greater_zero, check_type, check_distribution_name +from ._utils.utils import check_greater_zero, check_distribution_name from ._utils.custom_ops import exp_generic, log_generic @@ -120,7 +121,7 @@ class Exponential(Distribution): param = dict(locals()) param['param_dict'] = {'rate': rate} valid_dtype = mstype.float_type - check_type(dtype, valid_dtype, type(self).__name__) + Validator.check_type(type(self).__name__, dtype, valid_dtype) super(Exponential, self).__init__(seed, dtype, name, param) self._rate = self._add_parameter(rate, 'rate') diff --git a/mindspore/nn/probability/distribution/geometric.py b/mindspore/nn/probability/distribution/geometric.py index 9a37d308c92..86c0ca5f852 100644 --- a/mindspore/nn/probability/distribution/geometric.py +++ b/mindspore/nn/probability/distribution/geometric.py @@ -16,9 +16,10 @@ import numpy as np from mindspore.ops import operations as P from mindspore.ops import composite as C +from mindspore._checkparam import Validator from mindspore.common import dtype as mstype from .distribution import Distribution -from ._utils.utils import check_prob, check_type, check_distribution_name +from ._utils.utils import check_prob, check_distribution_name from ._utils.custom_ops import exp_generic, log_generic @@ -121,7 +122,7 @@ class Geometric(Distribution): param = dict(locals()) param['param_dict'] = {'probs': probs} valid_dtype = mstype.int_type + mstype.uint_type + mstype.float_type - check_type(dtype, valid_dtype, type(self).__name__) + Validator.check_type(type(self).__name__, dtype, valid_dtype) super(Geometric, self).__init__(seed, dtype, name, param) self._probs = self._add_parameter(probs, 'probs') diff --git a/mindspore/nn/probability/distribution/logistic.py b/mindspore/nn/probability/distribution/logistic.py index 0a8ff8fdb54..7ee6cb484db 100644 --- a/mindspore/nn/probability/distribution/logistic.py +++ b/mindspore/nn/probability/distribution/logistic.py @@ -16,9 +16,10 @@ import numpy as np from mindspore.ops import operations as P from mindspore.ops import composite as C +from mindspore._checkparam import Validator from mindspore.common import dtype as mstype from .distribution import Distribution -from ._utils.utils import check_greater_zero, check_type +from ._utils.utils import check_greater_zero from ._utils.custom_ops import exp_generic, expm1_generic, log_generic, log1p_generic @@ -110,7 +111,7 @@ class Logistic(Distribution): param = dict(locals()) param['param_dict'] = {'loc': loc, 'scale': scale} valid_dtype = mstype.float_type - check_type(dtype, valid_dtype, type(self).__name__) + Validator.check_type(type(self).__name__, dtype, valid_dtype) super(Logistic, self).__init__(seed, dtype, name, param) self._loc = self._add_parameter(loc, 'loc') diff --git a/mindspore/nn/probability/distribution/normal.py b/mindspore/nn/probability/distribution/normal.py index 0df0d2b8e40..4226cccd161 100644 --- a/mindspore/nn/probability/distribution/normal.py +++ b/mindspore/nn/probability/distribution/normal.py @@ -16,9 +16,10 @@ import numpy as np from mindspore.ops import operations as P from mindspore.ops import composite as C +from mindspore._checkparam import Validator from mindspore.common import dtype as mstype from .distribution import Distribution -from ._utils.utils import check_greater_zero, check_type, check_distribution_name +from ._utils.utils import check_greater_zero, check_distribution_name from ._utils.custom_ops import exp_generic, expm1_generic, log_generic @@ -126,7 +127,7 @@ class Normal(Distribution): param = dict(locals()) param['param_dict'] = {'mean': mean, 'sd': sd} valid_dtype = mstype.float_type - check_type(dtype, valid_dtype, type(self).__name__) + Validator.check_type(type(self).__name__, dtype, valid_dtype) super(Normal, self).__init__(seed, dtype, name, param) self._mean_value = self._add_parameter(mean, 'mean') diff --git a/mindspore/nn/probability/distribution/uniform.py b/mindspore/nn/probability/distribution/uniform.py index 21a1754bac5..feafe87b147 100644 --- a/mindspore/nn/probability/distribution/uniform.py +++ b/mindspore/nn/probability/distribution/uniform.py @@ -15,9 +15,10 @@ """Uniform Distribution""" from mindspore.ops import operations as P from mindspore.ops import composite as C +from mindspore._checkparam import Validator from mindspore.common import dtype as mstype from .distribution import Distribution -from ._utils.utils import check_greater, check_type, check_distribution_name +from ._utils.utils import check_greater, check_distribution_name from ._utils.custom_ops import exp_generic, log_generic @@ -125,7 +126,7 @@ class Uniform(Distribution): param = dict(locals()) param['param_dict'] = {'low': low, 'high': high} valid_dtype = mstype.float_type - check_type(dtype, valid_dtype, type(self).__name__) + Validator.check_type(type(self).__name__, dtype, valid_dtype) super(Uniform, self).__init__(seed, dtype, name, param) self._low = self._add_parameter(low, 'low')