forked from mindspore-Ecosystem/mindspore
!7254 [ME] reused `check_type` function
Merge pull request !7254 from chenzhongming/zomi_master
This commit is contained in:
commit
044a511726
|
@ -375,17 +375,14 @@ class Validator:
|
|||
"""Type checking."""
|
||||
def raise_error_msg():
|
||||
"""func for raising error message when check failed"""
|
||||
type_names = [t.__name__ for t in valid_types]
|
||||
num_types = len(valid_types)
|
||||
raise TypeError(f'The type of `{arg_name}` should be {"one of " if num_types > 1 else ""}'
|
||||
f'{type_names if num_types > 1 else type_names[0]}, but got {type(arg_value).__name__}.')
|
||||
raise TypeError(f'The type of `{arg_name}` should be in {valid_types}, but got {type(arg_value).__name__}.')
|
||||
|
||||
if isinstance(arg_value, type(mstype.tensor)):
|
||||
arg_value = arg_value.element_type()
|
||||
# Notice: bool is subclass of int, so `check_type('x', True, [int])` will check fail, and
|
||||
# `check_type('x', True, [bool, int])` will check pass
|
||||
if isinstance(arg_value, bool) and bool not in tuple(valid_types):
|
||||
raise_error_msg()
|
||||
if arg_value in valid_types:
|
||||
return arg_value
|
||||
if isinstance(arg_value, tuple(valid_types)):
|
||||
return arg_value
|
||||
raise_error_msg()
|
||||
|
|
|
@ -118,7 +118,7 @@ number_type = (int8,
|
|||
float64,)
|
||||
|
||||
int_type = (int8, int16, int32, int64,)
|
||||
uint_type = (uint8, uint16, uint32, uint64)
|
||||
uint_type = (uint8, uint16, uint32, uint64,)
|
||||
float_type = (float16, float32, float64,)
|
||||
|
||||
implicit_conversion_seq = {t: idx for idx, t in enumerate((
|
||||
|
|
|
@ -24,7 +24,6 @@ __all__ = [
|
|||
'check_greater_equal_zero',
|
||||
'check_greater_zero',
|
||||
'check_prob',
|
||||
'check_type',
|
||||
'exp_generic',
|
||||
'expm1_generic',
|
||||
'log_generic',
|
||||
|
|
|
@ -206,12 +206,6 @@ def probs_to_logits(probs, is_binary=False):
|
|||
return P.Log()(ps_clamped)
|
||||
|
||||
|
||||
def check_type(data_type, value_type, name):
|
||||
if not data_type in value_type:
|
||||
raise TypeError(
|
||||
f"For {name}, valid type include {value_type}, {data_type} is invalid")
|
||||
|
||||
|
||||
@constexpr
|
||||
def raise_none_error(name):
|
||||
raise TypeError(f"the type {name} should be subclass of Tensor."
|
||||
|
|
|
@ -16,8 +16,9 @@
|
|||
from mindspore.common import dtype as mstype
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore.ops import composite as C
|
||||
from mindspore._checkparam import Validator
|
||||
from .distribution import Distribution
|
||||
from ._utils.utils import check_prob, check_type, check_distribution_name
|
||||
from ._utils.utils import check_prob, check_distribution_name
|
||||
from ._utils.custom_ops import exp_generic, log_generic
|
||||
|
||||
|
||||
|
@ -118,7 +119,7 @@ class Bernoulli(Distribution):
|
|||
param = dict(locals())
|
||||
param['param_dict'] = {'probs': probs}
|
||||
valid_dtype = mstype.int_type + mstype.uint_type + mstype.float_type
|
||||
check_type(dtype, valid_dtype, type(self).__name__)
|
||||
Validator.check_type(type(self).__name__, dtype, valid_dtype)
|
||||
super(Bernoulli, self).__init__(seed, dtype, name, param)
|
||||
|
||||
self._probs = self._add_parameter(probs, 'probs')
|
||||
|
|
|
@ -16,10 +16,11 @@
|
|||
import numpy as np
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore.ops import composite as C
|
||||
from mindspore._checkparam import Validator
|
||||
import mindspore.nn as nn
|
||||
from mindspore.common import dtype as mstype
|
||||
from .distribution import Distribution
|
||||
from ._utils.utils import check_prob, check_sum_equal_one, check_type, check_rank,\
|
||||
from ._utils.utils import check_prob, check_sum_equal_one, check_rank,\
|
||||
check_distribution_name, raise_not_implemented_util
|
||||
from ._utils.custom_ops import exp_generic, log_generic, broadcast_to
|
||||
|
||||
|
@ -107,7 +108,7 @@ class Categorical(Distribution):
|
|||
param = dict(locals())
|
||||
param['param_dict'] = {'probs': probs}
|
||||
valid_dtype = mstype.int_type
|
||||
check_type(dtype, valid_dtype, "Categorical")
|
||||
Validator.check_type("Categorical", dtype, valid_dtype)
|
||||
super(Categorical, self).__init__(seed, dtype, name, param)
|
||||
|
||||
self._probs = self._add_parameter(probs, 'probs')
|
||||
|
|
|
@ -16,9 +16,10 @@
|
|||
import numpy as np
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore.ops import composite as C
|
||||
from mindspore._checkparam import Validator
|
||||
from mindspore.common import dtype as mstype
|
||||
from .distribution import Distribution
|
||||
from ._utils.utils import check_greater_zero, check_type, check_distribution_name
|
||||
from ._utils.utils import check_greater_zero, check_distribution_name
|
||||
from ._utils.custom_ops import exp_generic, log_generic
|
||||
|
||||
|
||||
|
@ -120,7 +121,7 @@ class Exponential(Distribution):
|
|||
param = dict(locals())
|
||||
param['param_dict'] = {'rate': rate}
|
||||
valid_dtype = mstype.float_type
|
||||
check_type(dtype, valid_dtype, type(self).__name__)
|
||||
Validator.check_type(type(self).__name__, dtype, valid_dtype)
|
||||
super(Exponential, self).__init__(seed, dtype, name, param)
|
||||
|
||||
self._rate = self._add_parameter(rate, 'rate')
|
||||
|
|
|
@ -16,9 +16,10 @@
|
|||
import numpy as np
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore.ops import composite as C
|
||||
from mindspore._checkparam import Validator
|
||||
from mindspore.common import dtype as mstype
|
||||
from .distribution import Distribution
|
||||
from ._utils.utils import check_prob, check_type, check_distribution_name
|
||||
from ._utils.utils import check_prob, check_distribution_name
|
||||
from ._utils.custom_ops import exp_generic, log_generic
|
||||
|
||||
|
||||
|
@ -121,7 +122,7 @@ class Geometric(Distribution):
|
|||
param = dict(locals())
|
||||
param['param_dict'] = {'probs': probs}
|
||||
valid_dtype = mstype.int_type + mstype.uint_type + mstype.float_type
|
||||
check_type(dtype, valid_dtype, type(self).__name__)
|
||||
Validator.check_type(type(self).__name__, dtype, valid_dtype)
|
||||
super(Geometric, self).__init__(seed, dtype, name, param)
|
||||
|
||||
self._probs = self._add_parameter(probs, 'probs')
|
||||
|
|
|
@ -16,9 +16,10 @@
|
|||
import numpy as np
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore.ops import composite as C
|
||||
from mindspore._checkparam import Validator
|
||||
from mindspore.common import dtype as mstype
|
||||
from .distribution import Distribution
|
||||
from ._utils.utils import check_greater_zero, check_type
|
||||
from ._utils.utils import check_greater_zero
|
||||
from ._utils.custom_ops import exp_generic, expm1_generic, log_generic, log1p_generic
|
||||
|
||||
|
||||
|
@ -110,7 +111,7 @@ class Logistic(Distribution):
|
|||
param = dict(locals())
|
||||
param['param_dict'] = {'loc': loc, 'scale': scale}
|
||||
valid_dtype = mstype.float_type
|
||||
check_type(dtype, valid_dtype, type(self).__name__)
|
||||
Validator.check_type(type(self).__name__, dtype, valid_dtype)
|
||||
super(Logistic, self).__init__(seed, dtype, name, param)
|
||||
|
||||
self._loc = self._add_parameter(loc, 'loc')
|
||||
|
|
|
@ -16,9 +16,10 @@
|
|||
import numpy as np
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore.ops import composite as C
|
||||
from mindspore._checkparam import Validator
|
||||
from mindspore.common import dtype as mstype
|
||||
from .distribution import Distribution
|
||||
from ._utils.utils import check_greater_zero, check_type, check_distribution_name
|
||||
from ._utils.utils import check_greater_zero, check_distribution_name
|
||||
from ._utils.custom_ops import exp_generic, expm1_generic, log_generic
|
||||
|
||||
|
||||
|
@ -126,7 +127,7 @@ class Normal(Distribution):
|
|||
param = dict(locals())
|
||||
param['param_dict'] = {'mean': mean, 'sd': sd}
|
||||
valid_dtype = mstype.float_type
|
||||
check_type(dtype, valid_dtype, type(self).__name__)
|
||||
Validator.check_type(type(self).__name__, dtype, valid_dtype)
|
||||
super(Normal, self).__init__(seed, dtype, name, param)
|
||||
|
||||
self._mean_value = self._add_parameter(mean, 'mean')
|
||||
|
|
|
@ -15,9 +15,10 @@
|
|||
"""Uniform Distribution"""
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore.ops import composite as C
|
||||
from mindspore._checkparam import Validator
|
||||
from mindspore.common import dtype as mstype
|
||||
from .distribution import Distribution
|
||||
from ._utils.utils import check_greater, check_type, check_distribution_name
|
||||
from ._utils.utils import check_greater, check_distribution_name
|
||||
from ._utils.custom_ops import exp_generic, log_generic
|
||||
|
||||
|
||||
|
@ -125,7 +126,7 @@ class Uniform(Distribution):
|
|||
param = dict(locals())
|
||||
param['param_dict'] = {'low': low, 'high': high}
|
||||
valid_dtype = mstype.float_type
|
||||
check_type(dtype, valid_dtype, type(self).__name__)
|
||||
Validator.check_type(type(self).__name__, dtype, valid_dtype)
|
||||
super(Uniform, self).__init__(seed, dtype, name, param)
|
||||
|
||||
self._low = self._add_parameter(low, 'low')
|
||||
|
|
Loading…
Reference in New Issue