!7254 [ME] reused `check_type` function

Merge pull request !7254 from chenzhongming/zomi_master
This commit is contained in:
mindspore-ci-bot 2020-10-14 19:02:11 +08:00 committed by Gitee
commit 044a511726
11 changed files with 25 additions and 28 deletions

View File

@@ -375,17 +375,14 @@ class Validator:
"""Type checking."""
def raise_error_msg():
"""func for raising error message when check failed"""
type_names = [t.__name__ for t in valid_types]
num_types = len(valid_types)
raise TypeError(f'The type of `{arg_name}` should be {"one of " if num_types > 1 else ""}'
f'{type_names if num_types > 1 else type_names[0]}, but got {type(arg_value).__name__}.')
raise TypeError(f'The type of `{arg_name}` should be in {valid_types}, but got {type(arg_value).__name__}.')
if isinstance(arg_value, type(mstype.tensor)):
arg_value = arg_value.element_type()
# Notice: bool is subclass of int, so `check_type('x', True, [int])` will check fail, and
# `check_type('x', True, [bool, int])` will check pass
if isinstance(arg_value, bool) and bool not in tuple(valid_types):
raise_error_msg()
if arg_value in valid_types:
return arg_value
if isinstance(arg_value, tuple(valid_types)):
return arg_value
raise_error_msg()

View File

@@ -118,7 +118,7 @@ number_type = (int8,
float64,)
int_type = (int8, int16, int32, int64,)
uint_type = (uint8, uint16, uint32, uint64)
uint_type = (uint8, uint16, uint32, uint64,)
float_type = (float16, float32, float64,)
implicit_conversion_seq = {t: idx for idx, t in enumerate((

View File

@@ -24,7 +24,6 @@ __all__ = [
'check_greater_equal_zero',
'check_greater_zero',
'check_prob',
'check_type',
'exp_generic',
'expm1_generic',
'log_generic',

View File

@@ -206,12 +206,6 @@ def probs_to_logits(probs, is_binary=False):
return P.Log()(ps_clamped)
def check_type(data_type, value_type, name):
    """Validate that `data_type` is one of the accepted types.

    Args:
        data_type: the type (typically a mindspore dtype) to validate
            via a membership test against `value_type`.
        value_type: a container (tuple/list/set) of accepted types.
        name: caller/argument name, interpolated into the error message.

    Raises:
        TypeError: if `data_type` is not contained in `value_type`.
    """
    # PEP 8: `x not in y` is the idiomatic spelling of `not x in y`.
    if data_type not in value_type:
        raise TypeError(
            f"For {name}, valid type include {value_type}, {data_type} is invalid")
@constexpr
def raise_none_error(name):
raise TypeError(f"the type {name} should be subclass of Tensor."

View File

@@ -16,8 +16,9 @@
from mindspore.common import dtype as mstype
from mindspore.ops import operations as P
from mindspore.ops import composite as C
from mindspore._checkparam import Validator
from .distribution import Distribution
from ._utils.utils import check_prob, check_type, check_distribution_name
from ._utils.utils import check_prob, check_distribution_name
from ._utils.custom_ops import exp_generic, log_generic
@@ -118,7 +119,7 @@ class Bernoulli(Distribution):
param = dict(locals())
param['param_dict'] = {'probs': probs}
valid_dtype = mstype.int_type + mstype.uint_type + mstype.float_type
check_type(dtype, valid_dtype, type(self).__name__)
Validator.check_type(type(self).__name__, dtype, valid_dtype)
super(Bernoulli, self).__init__(seed, dtype, name, param)
self._probs = self._add_parameter(probs, 'probs')

View File

@@ -16,10 +16,11 @@
import numpy as np
from mindspore.ops import operations as P
from mindspore.ops import composite as C
from mindspore._checkparam import Validator
import mindspore.nn as nn
from mindspore.common import dtype as mstype
from .distribution import Distribution
from ._utils.utils import check_prob, check_sum_equal_one, check_type, check_rank,\
from ._utils.utils import check_prob, check_sum_equal_one, check_rank,\
check_distribution_name, raise_not_implemented_util
from ._utils.custom_ops import exp_generic, log_generic, broadcast_to
@@ -107,7 +108,7 @@ class Categorical(Distribution):
param = dict(locals())
param['param_dict'] = {'probs': probs}
valid_dtype = mstype.int_type
check_type(dtype, valid_dtype, "Categorical")
Validator.check_type("Categorical", dtype, valid_dtype)
super(Categorical, self).__init__(seed, dtype, name, param)
self._probs = self._add_parameter(probs, 'probs')

View File

@@ -16,9 +16,10 @@
import numpy as np
from mindspore.ops import operations as P
from mindspore.ops import composite as C
from mindspore._checkparam import Validator
from mindspore.common import dtype as mstype
from .distribution import Distribution
from ._utils.utils import check_greater_zero, check_type, check_distribution_name
from ._utils.utils import check_greater_zero, check_distribution_name
from ._utils.custom_ops import exp_generic, log_generic
@@ -120,7 +121,7 @@ class Exponential(Distribution):
param = dict(locals())
param['param_dict'] = {'rate': rate}
valid_dtype = mstype.float_type
check_type(dtype, valid_dtype, type(self).__name__)
Validator.check_type(type(self).__name__, dtype, valid_dtype)
super(Exponential, self).__init__(seed, dtype, name, param)
self._rate = self._add_parameter(rate, 'rate')

View File

@@ -16,9 +16,10 @@
import numpy as np
from mindspore.ops import operations as P
from mindspore.ops import composite as C
from mindspore._checkparam import Validator
from mindspore.common import dtype as mstype
from .distribution import Distribution
from ._utils.utils import check_prob, check_type, check_distribution_name
from ._utils.utils import check_prob, check_distribution_name
from ._utils.custom_ops import exp_generic, log_generic
@@ -121,7 +122,7 @@ class Geometric(Distribution):
param = dict(locals())
param['param_dict'] = {'probs': probs}
valid_dtype = mstype.int_type + mstype.uint_type + mstype.float_type
check_type(dtype, valid_dtype, type(self).__name__)
Validator.check_type(type(self).__name__, dtype, valid_dtype)
super(Geometric, self).__init__(seed, dtype, name, param)
self._probs = self._add_parameter(probs, 'probs')

View File

@@ -16,9 +16,10 @@
import numpy as np
from mindspore.ops import operations as P
from mindspore.ops import composite as C
from mindspore._checkparam import Validator
from mindspore.common import dtype as mstype
from .distribution import Distribution
from ._utils.utils import check_greater_zero, check_type
from ._utils.utils import check_greater_zero
from ._utils.custom_ops import exp_generic, expm1_generic, log_generic, log1p_generic
@@ -110,7 +111,7 @@ class Logistic(Distribution):
param = dict(locals())
param['param_dict'] = {'loc': loc, 'scale': scale}
valid_dtype = mstype.float_type
check_type(dtype, valid_dtype, type(self).__name__)
Validator.check_type(type(self).__name__, dtype, valid_dtype)
super(Logistic, self).__init__(seed, dtype, name, param)
self._loc = self._add_parameter(loc, 'loc')

View File

@@ -16,9 +16,10 @@
import numpy as np
from mindspore.ops import operations as P
from mindspore.ops import composite as C
from mindspore._checkparam import Validator
from mindspore.common import dtype as mstype
from .distribution import Distribution
from ._utils.utils import check_greater_zero, check_type, check_distribution_name
from ._utils.utils import check_greater_zero, check_distribution_name
from ._utils.custom_ops import exp_generic, expm1_generic, log_generic
@@ -126,7 +127,7 @@ class Normal(Distribution):
param = dict(locals())
param['param_dict'] = {'mean': mean, 'sd': sd}
valid_dtype = mstype.float_type
check_type(dtype, valid_dtype, type(self).__name__)
Validator.check_type(type(self).__name__, dtype, valid_dtype)
super(Normal, self).__init__(seed, dtype, name, param)
self._mean_value = self._add_parameter(mean, 'mean')

View File

@@ -15,9 +15,10 @@
"""Uniform Distribution"""
from mindspore.ops import operations as P
from mindspore.ops import composite as C
from mindspore._checkparam import Validator
from mindspore.common import dtype as mstype
from .distribution import Distribution
from ._utils.utils import check_greater, check_type, check_distribution_name
from ._utils.utils import check_greater, check_distribution_name
from ._utils.custom_ops import exp_generic, log_generic
@@ -125,7 +126,7 @@ class Uniform(Distribution):
param = dict(locals())
param['param_dict'] = {'low': low, 'high': high}
valid_dtype = mstype.float_type
check_type(dtype, valid_dtype, type(self).__name__)
Validator.check_type(type(self).__name__, dtype, valid_dtype)
super(Uniform, self).__init__(seed, dtype, name, param)
self._low = self._add_parameter(low, 'low')