!7311 [ME] change _check_parameter format

Merge pull request !7311 from chenzhongming/zomi_master
This commit is contained in:
mindspore-ci-bot 2020-10-15 20:52:29 +08:00 committed by Gitee
commit c68f36c81e
22 changed files with 224 additions and 171 deletions

View File

@ -94,10 +94,10 @@ rel_strs = {
def check_number(arg_value, value, rel, arg_type=int, arg_name=None, prim_name=None): def check_number(arg_value, value, rel, arg_type=int, arg_name=None, prim_name=None):
""" """
Check argument integer. Check argument integer.
Usage: Usage:
- number = check_integer(number, 0, Rel.GE, "number", None) # number >= 0 - number = check_integer(number, 0, Rel.GE, "number", None) # number >= 0
""" """
rel_fn = Rel.get_fns(rel) rel_fn = Rel.get_fns(rel)
type_mismatch = not isinstance(arg_value, arg_type) or isinstance(arg_value, bool) type_mismatch = not isinstance(arg_value, arg_type) or isinstance(arg_value, bool)
@ -122,13 +122,33 @@ def check_is_number(arg_value, arg_type, arg_name=None, prim_name=None):
""" """
prim_name = f'in \'{prim_name}\'' if prim_name else '' prim_name = f'in \'{prim_name}\'' if prim_name else ''
arg_name = f'\'{prim_name}\'' if arg_name else 'Input value' arg_name = f'\'{prim_name}\'' if arg_name else 'Input value'
if isinstance(arg_value, arg_type): if isinstance(arg_value, arg_type) and not isinstance(arg_value, bool):
if math.isinf(arg_value) or math.isnan(arg_value): if math.isinf(arg_value) or math.isnan(arg_value):
raise ValueError(f'{arg_name} {prim_name} must be legal float, but got `{arg_value}`.') raise ValueError(f'{arg_name} {prim_name} must be legal float, but got `{arg_value}`.')
return arg_value return arg_value
raise TypeError(f'{arg_name} {prim_name} must be float, but got `{type(arg_value).__name__}`') raise TypeError(f'{arg_name} {prim_name} must be float, but got `{type(arg_value).__name__}`')
def check_number_range(arg_value, lower_limit, upper_limit, rel, value_type, arg_name=None, prim_name=None):
"""
Method for checking whether an int value is in some range.
Usage:
- number = check_number_range(number, 0.0, 1.0, Rel.INC_NEITHER, "number", float) # number in [0.0, 1.0]
- number = check_number_range(number, 0, 1, Rel.INC_NEITHER, "number", int) # number in [0, 1]
"""
prim_name = f'in `{prim_name}`' if prim_name else ''
arg_name = f'`{arg_name}`' if arg_name else ''
rel_fn = Rel.get_fns(rel)
type_mismatch = not isinstance(arg_value, (np.ndarray, np.generic, value_type)) or isinstance(arg_value, bool)
excp_cls = TypeError if type_mismatch else ValueError
if type_mismatch or not rel_fn(arg_value, lower_limit, upper_limit):
rel_str = Rel.get_strs(rel).format(lower_limit, upper_limit)
raise excp_cls("{} {} should be in range of {}, but got {:.3f} with type {}.".format(
arg_name, prim_name, rel_str, arg_value, type(arg_value).__name__))
return arg_value
class Validator: class Validator:
"""validator for checking input parameters""" """validator for checking input parameters"""
@ -147,16 +167,13 @@ class Validator:
@staticmethod @staticmethod
def check_integer(arg_name, arg_value, value, rel, prim_name=None): def check_integer(arg_name, arg_value, value, rel, prim_name=None):
"""Check argument is integer""" """
rel_fn = Rel.get_fns(rel) Checks input integer value `arg_value` compare to `value`.
type_mismatch = not isinstance(arg_value, int) or isinstance(arg_value, bool)
excp_cls = TypeError if type_mismatch else ValueError Usage:
if type_mismatch or not rel_fn(arg_value, value): - number = check_integer(number, 0, Rel.GE, "number", None) # number >= 0
rel_str = Rel.get_strs(rel).format(value) """
msg_prefix = f'For \'{prim_name}\' the' if prim_name else "The" return check_number(arg_value, value, rel, int, arg_name, prim_name)
raise excp_cls(f'{msg_prefix} `{arg_name}` should be an int and must {rel_str}, but got `{arg_value}`'
f' with type `{type(arg_value).__name__}`.')
return arg_value
@staticmethod @staticmethod
def check_is_int(arg_value, arg_name=None, prim_name=None): def check_is_int(arg_value, arg_name=None, prim_name=None):
@ -168,7 +185,7 @@ class Validator:
- number = check_is_int(number, int, "bias") - number = check_is_int(number, int, "bias")
- number = check_is_int(number, int, "bias", "bias_class") - number = check_is_int(number, int, "bias", "bias_class")
""" """
check_is_number(arg_value, int, arg_name, prim_name) return check_is_number(arg_value, int, arg_name, prim_name)
@staticmethod @staticmethod
def check_positive_int(arg_value, arg_name=None, prim_name=None): def check_positive_int(arg_value, arg_name=None, prim_name=None):
@ -214,6 +231,16 @@ class Validator:
""" """
return check_number(arg_value, 0, Rel.GE, int, arg_name, prim_name) return check_number(arg_value, 0, Rel.GE, int, arg_name, prim_name)
@staticmethod
def check_float(arg_value, value, rel, arg_name=None, prim_name=None):
"""
Checks input float value `arg_value` compare to `value`.
Usage:
- number = check_float(number, 0.0, Rel.GE, "number", None) # number >= 0
"""
return check_number(arg_value, value, rel, float, arg_name, prim_name)
@staticmethod @staticmethod
def check_is_float(arg_value, arg_name=None, prim_name=None): def check_is_float(arg_value, arg_name=None, prim_name=None):
""" """
@ -224,7 +251,7 @@ class Validator:
- number = check_is_float(number, int, "bias") - number = check_is_float(number, int, "bias")
- number = check_is_float(number, int, "bias", "bias_class") - number = check_is_float(number, int, "bias", "bias_class")
""" """
check_is_number(arg_value, float, arg_name, prim_name) return check_is_number(arg_value, float, arg_name, prim_name)
@staticmethod @staticmethod
def check_positive_float(arg_value, arg_name=None, prim_name=None): def check_positive_float(arg_value, arg_name=None, prim_name=None):
@ -302,25 +329,26 @@ class Validator:
return arg_value return arg_value
@staticmethod @staticmethod
def check_int_range(arg_name, arg_value, lower_limit, upper_limit, rel, prim_name): def check_int_range(arg_value, lower_limit, upper_limit, rel, arg_name=None, prim_name=None):
"""Method for checking whether an int value is in some range.""" """
rel_fn = Rel.get_fns(rel) Method for checking whether input value is in int range.
type_mismatch = not isinstance(arg_value, int) or isinstance(arg_value, bool)
excp_cls = TypeError if type_mismatch else ValueError Usage:
if type_mismatch or not rel_fn(arg_value, lower_limit, upper_limit): - number = check_int_range(number, 0, 1, Rel.INC_NEITHER) # number in [0, 1]
rel_str = Rel.get_strs(rel).format(lower_limit, upper_limit) - number = check_int_range(number, 0, 1, Rel.INC_NEITHER, "number") # number in [0, 1]
raise excp_cls(f'For \'{prim_name}\' the `{arg_name}` should be an int in range {rel_str},' """
f' but got `{arg_value}` with type `{type(arg_value).__name__}`.') return check_number_range(arg_value, lower_limit, upper_limit, rel, int, arg_name, prim_name)
return arg_value
@staticmethod @staticmethod
def check_number_range(arg_name, arg_value, lower_limit, upper_limit, rel, prim_name): def check_float_range(arg_value, lower_limit, upper_limit, rel, arg_name=None, prim_name=None):
"""Method for checking whether a numeric value is in some range.""" """
rel_fn = Rel.get_fns(rel) Method for checking whether input value is in float range.
if not rel_fn(arg_value, lower_limit, upper_limit):
rel_str = Rel.get_strs(rel).format(lower_limit, upper_limit) Usage:
raise ValueError(f'For \'{prim_name}\' the `{arg_name}` should be in range {rel_str}, but got {arg_value}.') - number = check_float_range(number, 0.0, 1.0, Rel.INC_NEITHER) # number in [0.0, 1.0]
return arg_value - number = check_float_range(number, 0.0, 1.0, Rel.INC_NEITHER, "number") # number in [0.0, 1.0]
"""
return check_number_range(arg_value, lower_limit, upper_limit, rel, float, arg_name, prim_name)
@staticmethod @staticmethod
def check_string(arg_value, valid_values, arg_name=None, prim_name=None): def check_string(arg_value, valid_values, arg_name=None, prim_name=None):
@ -502,13 +530,6 @@ class Validator:
f'{tuple(exp_shape)}, but got {shape}.') f'{tuple(exp_shape)}, but got {shape}.')
def check_int(input_param):
"""Int type judgment."""
if isinstance(input_param, int) and not isinstance(input_param, bool):
return input_param
raise TypeError("Input type must be int!")
def check_int_zero_one(input_param): def check_int_zero_one(input_param):
"""Judge whether it is 0 or 1.""" """Judge whether it is 0 or 1."""
if input_param in (0, 1): if input_param in (0, 1):

View File

@ -233,7 +233,7 @@ def cosine_decay_lr(min_lr, max_lr, total_step, step_per_epoch, decay_epoch):
""" """
if not isinstance(min_lr, float): if not isinstance(min_lr, float):
raise TypeError("min_lr must be float.") raise TypeError("min_lr must be float.")
validator.check_number_range("min_lr", min_lr, 0.0, float("inf"), Rel.INC_LEFT, None) validator.check_non_negative_float(min_lr, "min_lr", None)
validator.check_positive_float(max_lr, 'max_lr') validator.check_positive_float(max_lr, 'max_lr')
validator.check_is_float(max_lr, 'max_lr') validator.check_is_float(max_lr, 'max_lr')
validator.check_positive_int(total_step, 'total_step') validator.check_positive_int(total_step, 'total_step')
@ -303,7 +303,7 @@ def polynomial_decay_lr(learning_rate, end_learning_rate, total_step, step_per_e
validator.check_is_float(learning_rate, 'learning_rate') validator.check_is_float(learning_rate, 'learning_rate')
if not isinstance(end_learning_rate, float): if not isinstance(end_learning_rate, float):
raise TypeError("end_learning_rate must be float.") raise TypeError("end_learning_rate must be float.")
validator.check_number_range("end_learning_rate", end_learning_rate, 0.0, float("inf"), Rel.INC_LEFT, None) validator.check_non_negative_float(end_learning_rate, "end_learning_rate", None)
validator.check_positive_float(power, 'power') validator.check_positive_float(power, 'power')
validator.check_is_float(power, 'power') validator.check_is_float(power, 'power')
validator.check_positive_int(total_step, 'total_step') validator.check_positive_int(total_step, 'total_step')
@ -356,7 +356,7 @@ def warmup_lr(learning_rate, total_step, step_per_epoch, warmup_epoch):
""" """
if not isinstance(learning_rate, float): if not isinstance(learning_rate, float):
raise TypeError("learning_rate must be float.") raise TypeError("learning_rate must be float.")
validator.check_number_range("learning_rate", learning_rate, 0.0, float("inf"), Rel.INC_LEFT, None) validator.check_non_negative_float(learning_rate, "learning_rate", None)
validator.check_positive_int(warmup_epoch, 'warmup_epoch') validator.check_positive_int(warmup_epoch, 'warmup_epoch')
validator.check_positive_int(total_step, 'total_step') validator.check_positive_int(total_step, 'total_step')
validator.check_positive_int(step_per_epoch, 'step_per_epoch') validator.check_positive_int(step_per_epoch, 'step_per_epoch')

View File

@ -451,8 +451,7 @@ class CentralCrop(Cell):
def __init__(self, central_fraction): def __init__(self, central_fraction):
super(CentralCrop, self).__init__() super(CentralCrop, self).__init__()
validator.check_value_type("central_fraction", central_fraction, [float], self.cls_name) validator.check_value_type("central_fraction", central_fraction, [float], self.cls_name)
self.central_fraction = validator.check_number_range('central_fraction', central_fraction, self.central_fraction = validator.check_float_range(0.0, 1.0, Rel.INC_RIGHT, 'central_fraction', central_fraction, self.cls_name)
0.0, 1.0, Rel.INC_RIGHT, self.cls_name)
self.slice = P.Slice() self.slice = P.Slice()
def construct(self, image): def construct(self, image):

View File

@ -254,7 +254,7 @@ class CosineDecayLR(LearningRateSchedule):
super(CosineDecayLR, self).__init__() super(CosineDecayLR, self).__init__()
if not isinstance(min_lr, float): if not isinstance(min_lr, float):
raise TypeError("min_lr must be float.") raise TypeError("min_lr must be float.")
validator.check_number_range("min_lr", min_lr, 0.0, float("inf"), Rel.INC_LEFT, self.cls_name) validator.check_non_negative_float(min_lr, "min_lr", self.cls_name)
validator.check_positive_float(max_lr, 'max_lr', self.cls_name) validator.check_positive_float(max_lr, 'max_lr', self.cls_name)
validator.check_is_float(max_lr, 'max_lr', self.cls_name) validator.check_is_float(max_lr, 'max_lr', self.cls_name)
validator.check_positive_int(decay_steps, "decay_steps", self.cls_name) validator.check_positive_int(decay_steps, "decay_steps", self.cls_name)
@ -322,8 +322,7 @@ class PolynomialDecayLR(LearningRateSchedule):
validator.check_is_float(learning_rate, 'learning_rate') validator.check_is_float(learning_rate, 'learning_rate')
if not isinstance(end_learning_rate, float): if not isinstance(end_learning_rate, float):
raise TypeError("end_learning_rate must be float.") raise TypeError("end_learning_rate must be float.")
validator.check_number_range("end_learning_rate", end_learning_rate, 0.0, float("inf"), Rel.INC_LEFT, validator.check_non_negative_float(end_learning_rate, "end_learning_rate", self.cls_name)
self.cls_name)
validator.check_positive_int(decay_steps, 'decay_steps', self.cls_name) validator.check_positive_int(decay_steps, 'decay_steps', self.cls_name)
validator.check_value_type('update_decay_steps', update_decay_steps, [bool], self.cls_name) validator.check_value_type('update_decay_steps', update_decay_steps, [bool], self.cls_name)
validator.check_positive_float(power, 'power', self.cls_name) validator.check_positive_float(power, 'power', self.cls_name)
@ -387,7 +386,7 @@ class WarmUpLR(LearningRateSchedule):
super(WarmUpLR, self).__init__() super(WarmUpLR, self).__init__()
if not isinstance(learning_rate, float): if not isinstance(learning_rate, float):
raise TypeError("learning_rate must be float.") raise TypeError("learning_rate must be float.")
validator.check_number_range("learning_rate", learning_rate, 0.0, float("inf"), Rel.INC_LEFT, self.cls_name) validator.check_non_negative_float(learning_rate, "learning_rate", self.cls_name)
validator.check_positive_int(warmup_steps, 'warmup_steps', self.cls_name) validator.check_positive_int(warmup_steps, 'warmup_steps', self.cls_name)
self.warmup_steps = warmup_steps self.warmup_steps = warmup_steps
self.learning_rate = learning_rate self.learning_rate = learning_rate

View File

@ -368,7 +368,7 @@ class CosineEmbeddingLoss(_Loss):
self.reduce_sum = P.ReduceSum() self.reduce_sum = P.ReduceSum()
self.maximum = P.Maximum() self.maximum = P.Maximum()
validator.check_value_type("margin", margin, [float], self.cls_name) validator.check_value_type("margin", margin, [float], self.cls_name)
self.margin = validator.check_number_range("margin", margin, -1.0, 1.0, Rel.INC_BOTH, self.cls_name) self.margin = validator.check_float_range(margin, -1.0, 1.0, Rel.INC_BOTH, "margin", self.cls_name)
def construct(self, x1, x2, y): def construct(self, x1, x2, y):
F.same_type_shape(x1, x2) F.same_type_shape(x1, x2)

View File

@ -126,9 +126,9 @@ def _check_param_value(beta1, beta2, eps, prim_name):
validator.check_value_type("beta1", beta1, [float], prim_name) validator.check_value_type("beta1", beta1, [float], prim_name)
validator.check_value_type("beta2", beta2, [float], prim_name) validator.check_value_type("beta2", beta2, [float], prim_name)
validator.check_value_type("eps", eps, [float], prim_name) validator.check_value_type("eps", eps, [float], prim_name)
validator.check_number_range("beta1", beta1, 0.0, 1.0, Rel.INC_NEITHER, prim_name) validator.check_float_range(beta1, 0.0, 1.0, Rel.INC_NEITHER, "beta1", prim_name)
validator.check_number_range("beta2", beta2, 0.0, 1.0, Rel.INC_NEITHER, prim_name) validator.check_float_range(beta2, 0.0, 1.0, Rel.INC_NEITHER, "beta2", prim_name)
validator.check_number_range("eps", eps, 0.0, float("inf"), Rel.INC_NEITHER, prim_name) validator.check_positive_float(eps, "eps", prim_name)
class Adam(Optimizer): class Adam(Optimizer):

View File

@ -177,9 +177,9 @@ def _check_param_value(beta1, beta2, eps, prim_name):
validator.check_value_type("beta1", beta1, [float], prim_name) validator.check_value_type("beta1", beta1, [float], prim_name)
validator.check_value_type("beta2", beta2, [float], prim_name) validator.check_value_type("beta2", beta2, [float], prim_name)
validator.check_value_type("eps", eps, [float], prim_name) validator.check_value_type("eps", eps, [float], prim_name)
validator.check_number_range("beta1", beta1, 0.0, 1.0, Rel.INC_NEITHER, prim_name) validator.check_float_range(beta1, 0.0, 1.0, Rel.INC_NEITHER, "beta1", prim_name)
validator.check_number_range("beta2", beta2, 0.0, 1.0, Rel.INC_NEITHER, prim_name) validator.check_float_range(beta2, 0.0, 1.0, Rel.INC_NEITHER, "beta2", prim_name)
validator.check_number_range("eps", eps, 0.0, float("inf"), Rel.INC_NEITHER, prim_name) validator.check_positive_float(eps, "eps", prim_name)
class Lamb(Optimizer): class Lamb(Optimizer):

View File

@ -70,10 +70,10 @@ def _check_param_value(beta1, beta2, eps, weight_decay, prim_name):
validator.check_value_type("beta2", beta2, [float], prim_name) validator.check_value_type("beta2", beta2, [float], prim_name)
validator.check_value_type("eps", eps, [float], prim_name) validator.check_value_type("eps", eps, [float], prim_name)
validator.check_value_type("weight_dacay", weight_decay, [float], prim_name) validator.check_value_type("weight_dacay", weight_decay, [float], prim_name)
validator.check_number_range("beta1", beta1, 0.0, 1.0, Rel.INC_NEITHER, prim_name) validator.check_float_range(beta1, 0.0, 1.0, Rel.INC_NEITHER, "beta1", prim_name)
validator.check_number_range("beta2", beta2, 0.0, 1.0, Rel.INC_NEITHER, prim_name) validator.check_float_range(beta2, 0.0, 1.0, Rel.INC_NEITHER, "beta2", prim_name)
validator.check_number_range("eps", eps, 0.0, float("inf"), Rel.INC_NEITHER, prim_name) validator.check_positive_float(eps, "eps", prim_name)
validator.check_number_range("weight_decay", weight_decay, 0.0, float("inf"), Rel.INC_LEFT, prim_name) validator.check_non_negative_float(weight_decay, "weight_decay", prim_name)
class LazyAdam(Optimizer): class LazyAdam(Optimizer):

View File

@ -100,7 +100,7 @@ class Optimizer(Cell):
if isinstance(loss_scale, int): if isinstance(loss_scale, int):
loss_scale = float(loss_scale) loss_scale = float(loss_scale)
validator.check_value_type("loss_scale", loss_scale, [float], self.cls_name) validator.check_value_type("loss_scale", loss_scale, [float], self.cls_name)
validator.check_number_range("loss_scale", loss_scale, 0.0, float("inf"), Rel.INC_NEITHER, self.cls_name) validator.check_positive_float(loss_scale, "loss_scale", self.cls_name)
self.loss_scale = loss_scale self.loss_scale = loss_scale
weight_decay = self._preprocess_weight_decay(weight_decay) weight_decay = self._preprocess_weight_decay(weight_decay)
@ -221,7 +221,7 @@ class Optimizer(Cell):
"""Check weight decay, and convert int to float.""" """Check weight decay, and convert int to float."""
if isinstance(weight_decay, (float, int)): if isinstance(weight_decay, (float, int)):
weight_decay = float(weight_decay) weight_decay = float(weight_decay)
validator.check_number_range("weight_decay", weight_decay, 0.0, float("inf"), Rel.INC_LEFT, self.cls_name) validator.check_non_negative_float(weight_decay, "weight_decay", self.cls_name)
return weight_decay return weight_decay
raise TypeError("Weight decay should be int or float.") raise TypeError("Weight decay should be int or float.")
@ -229,7 +229,7 @@ class Optimizer(Cell):
"""Check lr value, and convert lr to a float, a Tensor or a LearningRateSchedule.""" """Check lr value, and convert lr to a float, a Tensor or a LearningRateSchedule."""
if isinstance(learning_rate, (float, int)): if isinstance(learning_rate, (float, int)):
learning_rate = float(learning_rate) learning_rate = float(learning_rate)
validator.check_number_range("learning rate", learning_rate, 0.0, float("inf"), Rel.INC_LEFT, self.cls_name) validator.check_non_negative_float(learning_rate, "learning rate", self.cls_name)
return learning_rate return learning_rate
if isinstance(learning_rate, Tensor) and learning_rate.dim() == 0: if isinstance(learning_rate, Tensor) and learning_rate.dim() == 0:
return learning_rate return learning_rate

View File

@ -45,9 +45,9 @@ def _check_param_value(accum, l1, l2, use_locking, prim_name=None):
validator.check_value_type("l1", l1, [float], prim_name) validator.check_value_type("l1", l1, [float], prim_name)
validator.check_value_type("l2", l2, [float], prim_name) validator.check_value_type("l2", l2, [float], prim_name)
validator.check_value_type("use_locking", use_locking, [bool], prim_name) validator.check_value_type("use_locking", use_locking, [bool], prim_name)
validator.check_number_range("accum", accum, 0.0, float("inf"), Rel.INC_LEFT, prim_name) validator.check_non_negative_float(accum, "accum", prim_name)
validator.check_number_range("l1", l1, 0.0, float("inf"), Rel.INC_LEFT, prim_name) validator.check_non_negative_float(l1, "l1", prim_name)
validator.check_number_range("l2", l2, 0.0, float("inf"), Rel.INC_LEFT, prim_name) validator.check_non_negative_float(l2, "l2", prim_name)
class ProximalAdagrad(Optimizer): class ProximalAdagrad(Optimizer):

View File

@ -154,11 +154,11 @@ class RMSProp(Optimizer):
use_locking=False, centered=False, loss_scale=1.0, weight_decay=0.0): use_locking=False, centered=False, loss_scale=1.0, weight_decay=0.0):
super(RMSProp, self).__init__(learning_rate, params, weight_decay, loss_scale) super(RMSProp, self).__init__(learning_rate, params, weight_decay, loss_scale)
validator.check_value_type("decay", decay, [float], self.cls_name) validator.check_value_type("decay", decay, [float], self.cls_name)
validator.check_number_range("decay", decay, 0.0, float("inf"), Rel.INC_LEFT, self.cls_name) validator.check_non_negative_float(decay, "decay", self.cls_name)
validator.check_value_type("momentum", momentum, [float], self.cls_name) validator.check_value_type("momentum", momentum, [float], self.cls_name)
validator.check_number_range("momentum", momentum, 0.0, float("inf"), Rel.INC_LEFT, self.cls_name) validator.check_non_negative_float(momentum, "momentum", self.cls_name)
validator.check_value_type("epsilon", epsilon, [float], self.cls_name) validator.check_value_type("epsilon", epsilon, [float], self.cls_name)
validator.check_number_range("epsilon", epsilon, 0.0, float("inf"), Rel.INC_NEITHER, self.cls_name) validator.check_positive_float(epsilon, "epsilon", self.cls_name)
validator.check_value_type("use_locking", use_locking, [bool], self.cls_name) validator.check_value_type("use_locking", use_locking, [bool], self.cls_name)
validator.check_value_type("centered", centered, [bool], self.cls_name) validator.check_value_type("centered", centered, [bool], self.cls_name)

View File

@ -69,7 +69,7 @@ def get_concat_offset(x_shp, x_type, axis, prim_name):
validator.check_subclass("shape0", x_type[0], mstype.tensor, prim_name) validator.check_subclass("shape0", x_type[0], mstype.tensor, prim_name)
validator.check_positive_int(len(x_shp[0]), "len of x_shp[0]", prim_name) validator.check_positive_int(len(x_shp[0]), "len of x_shp[0]", prim_name)
rank_base = len(x_shp[0]) rank_base = len(x_shp[0])
validator.check_int_range('axis', axis, -rank_base - 1, rank_base, Rel.INC_BOTH, prim_name) validator.check_int_range(axis, -rank_base - 1, rank_base, Rel.INC_BOTH, 'axis', prim_name)
if axis < 0: if axis < 0:
axis = axis + rank_base axis = axis + rank_base
all_shp = x_shp[0][axis] all_shp = x_shp[0][axis]

View File

@ -188,7 +188,7 @@ class BatchNormGrad(PrimitiveWithInfer):
@prim_attr_register @prim_attr_register
def __init__(self, is_training=False, epsilon=1e-5): def __init__(self, is_training=False, epsilon=1e-5):
self.is_training = validator.check_value_type('is_training', is_training, (bool,), self.name) self.is_training = validator.check_value_type('is_training', is_training, (bool,), self.name)
self.epsilon = validator.check_number_range('epsilon', epsilon, 0, 1, Rel.INC_RIGHT, self.name) self.epsilon = validator.check_float_range(epsilon, 0, 1, Rel.INC_RIGHT, 'epsilon', self.name)
self.add_prim_attr('data_format', "NCHW") self.add_prim_attr('data_format', "NCHW")
def infer_shape(self, y_backprop_shape, x_shape, scale_shape, reserve_1_shape, reserve_2_shape): def infer_shape(self, y_backprop_shape, x_shape, scale_shape, reserve_1_shape, reserve_2_shape):
@ -485,7 +485,7 @@ class DropoutGrad(PrimitiveWithInfer):
@prim_attr_register @prim_attr_register
def __init__(self, keep_prob=0.5): def __init__(self, keep_prob=0.5):
self.keep_prob = validator.check_number_range("keep_prob", keep_prob, 0, 1, Rel.INC_RIGHT, self.name) self.keep_prob = validator.check_float_range(keep_prob, 0, 1, Rel.INC_RIGHT, "keep_prob", self.name)
def infer_shape(self, dy_shape, mask_shape): def infer_shape(self, dy_shape, mask_shape):
return dy_shape return dy_shape
@ -902,7 +902,7 @@ class LogSoftmaxGrad(PrimitiveWithInfer):
def infer_shape(self, dout, logits): def infer_shape(self, dout, logits):
rank = len(logits) rank = len(logits)
validator.check_int_range('axis', self.axis, -rank - 1, rank, Rel.INC_BOTH, self.name) validator.check_int_range(self.axis, -rank - 1, rank, Rel.INC_BOTH, 'axis', self.name)
return logits return logits
def infer_dtype(self, dout, logits): def infer_dtype(self, dout, logits):
@ -921,7 +921,7 @@ class LSTMGradData(PrimitiveWithInfer):
self.has_bias = validator.check_value_type('has_bias', has_bias, (bool,), self.name) self.has_bias = validator.check_value_type('has_bias', has_bias, (bool,), self.name)
self.bidirectional = validator.check_value_type('bidirectional', bidirectional, (bool,), self.name) self.bidirectional = validator.check_value_type('bidirectional', bidirectional, (bool,), self.name)
self.dropout = validator.check_value_type("dropout", dropout, [float], self.name) self.dropout = validator.check_value_type("dropout", dropout, [float], self.name)
self.dropout = validator.check_number_range('dropout', dropout, 0, 1, Rel.INC_BOTH, self.name) self.dropout = validator.check_float_range(dropout, 0, 1, Rel.INC_BOTH, 'dropout', self.name)
if bidirectional: if bidirectional:
self.num_directions = 2 self.num_directions = 2
@ -970,7 +970,7 @@ class LSTMGradWeight(PrimitiveWithInfer):
self.has_bias = validator.check_value_type('has_bias', has_bias, (bool,), self.name) self.has_bias = validator.check_value_type('has_bias', has_bias, (bool,), self.name)
self.bidirectional = validator.check_value_type('bidirectional', bidirectional, (bool,), self.name) self.bidirectional = validator.check_value_type('bidirectional', bidirectional, (bool,), self.name)
self.dropout = validator.check_value_type("dropout", dropout, [float], self.name) self.dropout = validator.check_value_type("dropout", dropout, [float], self.name)
self.dropout = validator.check_number_range('dropout', dropout, 0, 1, Rel.INC_BOTH, self.name) self.dropout = validator.check_float_range(dropout, 0, 1, Rel.INC_BOTH, 'dropout', self.name)
if bidirectional: if bidirectional:
self.num_directions = 2 self.num_directions = 2
@ -1005,7 +1005,7 @@ class LSTMGrad(PrimitiveWithInfer):
self.has_bias = validator.check_value_type('has_bias', has_bias, (bool,), self.name) self.has_bias = validator.check_value_type('has_bias', has_bias, (bool,), self.name)
self.bidirectional = validator.check_value_type('bidirectional', bidirectional, (bool,), self.name) self.bidirectional = validator.check_value_type('bidirectional', bidirectional, (bool,), self.name)
self.dropout = validator.check_value_type("dropout", dropout, [float], self.name) self.dropout = validator.check_value_type("dropout", dropout, [float], self.name)
self.dropout = validator.check_number_range('dropout', dropout, 0, 1, Rel.INC_BOTH, self.name) self.dropout = validator.check_float_range(dropout, 0, 1, Rel.INC_BOTH, 'dropout', self.name)
if bidirectional: if bidirectional:
self.num_directions = 2 self.num_directions = 2
@ -1652,7 +1652,7 @@ class BasicLSTMCellInputGrad(PrimitiveWithInfer):
@prim_attr_register @prim_attr_register
def __init__(self, keep_prob): def __init__(self, keep_prob):
self.keep_prob = validator.check_value_type("keep_prob", keep_prob, [float], self.name) self.keep_prob = validator.check_value_type("keep_prob", keep_prob, [float], self.name)
self.keep_prob = validator.check_number_range("keep_prob", keep_prob, 0.0, 1.0, Rel.INC_BOTH, self.name) self.keep_prob = validator.check_float_range(keep_prob, 0.0, 1.0, Rel.INC_BOTH, "keep_prob", self.name)
self.add_prim_attr("io_format", "ND") self.add_prim_attr("io_format", "ND")
def infer_shape(self, dgate_shape, w_shape): def infer_shape(self, dgate_shape, w_shape):

View File

@ -76,8 +76,7 @@ class MinMaxUpdatePerLayer(PrimitiveWithInfer):
f"For '{self.name}' attr \'ema\' and \'ema_decay\' should set together.") f"For '{self.name}' attr \'ema\' and \'ema_decay\' should set together.")
self.ema = validator.check_value_type('ema', ema, (bool,), self.name) self.ema = validator.check_value_type('ema', ema, (bool,), self.name)
self.ema_decay = validator.check_number_range( self.ema_decay = validator.check_float_range(ema_decay, 0, 1, Rel.INC_BOTH, 'ema_decay', self.name)
'ema_decay', ema_decay, 0, 1, Rel.INC_BOTH, self.name)
self.init_prim_io_names(inputs=['x', 'min', 'max'], self.init_prim_io_names(inputs=['x', 'min', 'max'],
outputs=['min_up', 'max_up']) outputs=['min_up', 'max_up'])
@ -136,10 +135,9 @@ class MinMaxUpdatePerChannel(PrimitiveWithInfer):
f"For '{self.name}' attr \'ema\' and \'ema_decay\' should set together.") f"For '{self.name}' attr \'ema\' and \'ema_decay\' should set together.")
self.ema = validator.check_value_type('ema', ema, (bool,), self.name) self.ema = validator.check_value_type('ema', ema, (bool,), self.name)
self.ema_decay = validator.check_number_range( self.ema_decay = validator.check_float_range(ema_decay, 0, 1, Rel.INC_BOTH, 'ema_decay', self.name)
'ema_decay', ema_decay, 0, 1, Rel.INC_BOTH, self.name)
if self.is_ascend: if self.is_ascend:
self.channel_axis = validator.check_int_range('channel_axis', channel_axis, 0, 1, Rel.INC_BOTH, self.name) self.channel_axis = validator.check_int_range(channel_axis, 0, 1, Rel.INC_BOTH, 'channel_axis', self.name)
else: else:
self.channel_axis = validator.check_non_negative_int(channel_axis, 'channel_axis', self.name) self.channel_axis = validator.check_non_negative_int(channel_axis, 'channel_axis', self.name)
self.init_prim_io_names( self.init_prim_io_names(
@ -222,10 +220,8 @@ class FakeQuantPerLayer(PrimitiveWithInfer):
'symmetric', symmetric, (bool,), self.name) 'symmetric', symmetric, (bool,), self.name)
self.narrow_range = validator.check_value_type( self.narrow_range = validator.check_value_type(
'narrow_range', narrow_range, (bool,), self.name) 'narrow_range', narrow_range, (bool,), self.name)
self.training = validator.check_value_type( self.training = validator.check_value_type('training', training, (bool,), self.name)
'training', training, (bool,), self.name) self.ema_decay = validator.check_float_range(ema_decay, 0, 1, Rel.INC_BOTH, 'ema_decay', self.name)
self.ema_decay = validator.check_number_range(
'ema_decay', ema_decay, 0, 1, Rel.INC_BOTH, self.name)
self.num_bits = validator.check_positive_int(num_bits, 'num_bits', self.name) self.num_bits = validator.check_positive_int(num_bits, 'num_bits', self.name)
self.quant_delay = validator.check_non_negative_int(quant_delay, 'quant_delay', self.name) self.quant_delay = validator.check_non_negative_int(quant_delay, 'quant_delay', self.name)
self.init_prim_io_names(inputs=['x', 'min', 'max'], self.init_prim_io_names(inputs=['x', 'min', 'max'],
@ -366,12 +362,11 @@ class FakeQuantPerChannel(PrimitiveWithInfer):
'narrow_range', narrow_range, (bool,), self.name) 'narrow_range', narrow_range, (bool,), self.name)
self.training = validator.check_value_type( self.training = validator.check_value_type(
'training', training, (bool,), self.name) 'training', training, (bool,), self.name)
self.ema_decay = validator.check_number_range( self.ema_decay = validator.check_float_range(ema_decay, 0, 1, Rel.INC_BOTH, 'ema_decay', self.name)
'ema_decay', ema_decay, 0, 1, Rel.INC_BOTH, self.name)
self.num_bits = validator.check_positive_int(num_bits, 'num_bits', self.name) self.num_bits = validator.check_positive_int(num_bits, 'num_bits', self.name)
self.quant_delay = validator.check_non_negative_int(quant_delay, 'quant_delay', self.name) self.quant_delay = validator.check_non_negative_int(quant_delay, 'quant_delay', self.name)
if self.is_ascend: if self.is_ascend:
self.channel_axis = validator.check_int_range('channel_axis', channel_axis, 0, 1, Rel.INC_BOTH, self.name) self.channel_axis = validator.check_int_range(channel_axis, 0, 1, Rel.INC_BOTH, 'channel_axis', self.name)
else: else:
self.channel_axis = validator.check_non_negative_int(channel_axis, 'channel_axis', self.name) self.channel_axis = validator.check_non_negative_int(channel_axis, 'channel_axis', self.name)
self.init_prim_io_names(inputs=['x', 'min', 'max'], outputs=['out']) self.init_prim_io_names(inputs=['x', 'min', 'max'], outputs=['out'])
@ -495,7 +490,7 @@ class BatchNormFold(PrimitiveWithInfer):
@prim_attr_register @prim_attr_register
def __init__(self, momentum=0.9, epsilon=1e-5, is_training=True, freeze_bn=0): def __init__(self, momentum=0.9, epsilon=1e-5, is_training=True, freeze_bn=0):
"""Initialize batch norm fold layer""" """Initialize batch norm fold layer"""
self.momentum = validator.check_number_range('momentum', momentum, 0, 1, Rel.INC_BOTH, self.name) self.momentum = validator.check_float_range(momentum, 0, 1, Rel.INC_BOTH, 'momentum', self.name)
self.epsilon = validator.check_positive_float(epsilon, 'epsilon', self.name) self.epsilon = validator.check_positive_float(epsilon, 'epsilon', self.name)
self.is_training = validator.check_value_type('is_training', is_training, (bool,), self.name) self.is_training = validator.check_value_type('is_training', is_training, (bool,), self.name)
self.freeze_bn = validator.check_value_type('freeze_bn', freeze_bn, (int,), self.name) self.freeze_bn = validator.check_value_type('freeze_bn', freeze_bn, (int,), self.name)
@ -806,7 +801,7 @@ class BatchNormFoldD(PrimitiveWithInfer):
def __init__(self, momentum=0.9, epsilon=1e-5, is_training=True, freeze_bn=0): def __init__(self, momentum=0.9, epsilon=1e-5, is_training=True, freeze_bn=0):
"""Initialize _BatchNormFold layer""" """Initialize _BatchNormFold layer"""
from mindspore.ops._op_impl._custom_op import batchnorm_fold from mindspore.ops._op_impl._custom_op import batchnorm_fold
self.momentum = validator.check_number_range('momentum', momentum, 0, 1, Rel.INC_BOTH, self.name) self.momentum = validator.check_float_range(momentum, 0, 1, Rel.INC_BOTH, 'momentum', self.name)
self.epsilon = validator.check_positive_float(epsilon, 'epsilon', self.name) self.epsilon = validator.check_positive_float(epsilon, 'epsilon', self.name)
self.is_training = validator.check_value_type('is_training', is_training, (bool,), self.name) self.is_training = validator.check_value_type('is_training', is_training, (bool,), self.name)
self.freeze_bn = validator.check_value_type('freeze_bn', freeze_bn, (int,), self.name) self.freeze_bn = validator.check_value_type('freeze_bn', freeze_bn, (int,), self.name)

View File

@ -129,7 +129,7 @@ class ExpandDims(PrimitiveWithInfer):
x_shape = list(x['shape']) x_shape = list(x['shape'])
axis_v = axis['value'] axis_v = axis['value']
rank = len(x_shape) rank = len(x_shape)
validator.check_int_range('axis', axis_v, -rank - 1, rank, Rel.INC_BOTH, self.name) validator.check_int_range(axis_v, -rank - 1, rank, Rel.INC_BOTH, 'axis', self.name)
value = None value = None
if x['value'] is not None: if x['value'] is not None:
value = x['value'].asnumpy() value = x['value'].asnumpy()
@ -534,7 +534,7 @@ class Squeeze(PrimitiveWithInfer):
ret = [d for d in x_shape if d != 1] ret = [d for d in x_shape if d != 1]
else: else:
for a in axis: for a in axis:
validator.check_int_range('axis or its elements', a, -ndim, ndim - 1, Rel.INC_BOTH, self.name) validator.check_int_range(a, -ndim, ndim - 1, Rel.INC_BOTH, 'axis or its elements', self.name)
if x_shape[a] != 1: if x_shape[a] != 1:
raise ValueError('Cannot select an axis to squeeze out which has size not equal to one.') raise ValueError('Cannot select an axis to squeeze out which has size not equal to one.')
ret = [x_shape[i] for i in range(ndim) if not (i in axis or (i - ndim) in axis)] ret = [x_shape[i] for i in range(ndim) if not (i in axis or (i - ndim) in axis)]
@ -658,7 +658,7 @@ class GatherV2(PrimitiveWithCheck):
axis_v = axis['value'] axis_v = axis['value']
params_shp = params['shape'] params_shp = params['shape']
rank = len(params_shp) rank = len(params_shp)
validator.check_int_range("axis", axis_v, -rank, rank, Rel.INC_LEFT, self.name) validator.check_int_range(axis_v, -rank, rank, Rel.INC_LEFT, "axis", self.name)
if axis_v < 0: if axis_v < 0:
axis_v += rank axis_v += rank
@ -777,7 +777,7 @@ class Split(PrimitiveWithInfer):
validator.check_subclass("x", x['dtype'], mstype.tensor, self.name) validator.check_subclass("x", x['dtype'], mstype.tensor, self.name)
x_shape = list(x['shape']) x_shape = list(x['shape'])
dim = len(x_shape) dim = len(x_shape)
validator.check_int_range('axis value', self.axis, -dim, dim, Rel.INC_LEFT, self.name) validator.check_int_range(self.axis, -dim, dim, Rel.INC_LEFT, 'axis value', self.name)
validator.check_positive_int(self.output_num, "output_num", self.name) validator.check_positive_int(self.output_num, "output_num", self.name)
output_valid_check = x_shape[self.axis] % self.output_num output_valid_check = x_shape[self.axis] % self.output_num
if output_valid_check != 0: if output_valid_check != 0:
@ -1224,7 +1224,7 @@ class Argmax(PrimitiveWithInfer):
if axis is None: if axis is None:
axis = 0 axis = 0
x_rank = len(x_shape) x_rank = len(x_shape)
validator.check_int_range("axis", axis, -x_rank, x_rank, Rel.INC_LEFT, self.name) validator.check_int_range(axis, -x_rank, x_rank, Rel.INC_LEFT, "axis", self.name)
axis = axis + x_rank if axis < 0 else axis axis = axis + x_rank if axis < 0 else axis
ouput_shape = [x_shape[i] for i in range(x_rank) if i != axis] ouput_shape = [x_shape[i] for i in range(x_rank) if i != axis]
return ouput_shape return ouput_shape
@ -1272,7 +1272,7 @@ class Argmin(PrimitiveWithInfer):
if axis is None: if axis is None:
axis = 0 axis = 0
x_rank = len(x_shape) x_rank = len(x_shape)
validator.check_int_range("axis", axis, -x_rank, x_rank, Rel.INC_LEFT, self.name) validator.check_int_range(axis, -x_rank, x_rank, Rel.INC_LEFT, "axis", self.name)
axis = axis + x_rank if axis < 0 else axis axis = axis + x_rank if axis < 0 else axis
ouput_shape = [x_shape[i] for i in range(x_rank) if i != axis] ouput_shape = [x_shape[i] for i in range(x_rank) if i != axis]
return ouput_shape return ouput_shape
@ -1325,7 +1325,7 @@ class ArgMaxWithValue(PrimitiveWithInfer):
def infer_shape(self, x_shape): def infer_shape(self, x_shape):
axis = self.axis axis = self.axis
x_rank = len(x_shape) x_rank = len(x_shape)
validator.check_int_range("axis", axis, -x_rank, x_rank, Rel.INC_LEFT, self.name) validator.check_int_range(axis, -x_rank, x_rank, Rel.INC_LEFT, "axis", self.name)
ouput_shape = _infer_shape_reduce(x_shape, self.axis, self.keep_dims, self.name) ouput_shape = _infer_shape_reduce(x_shape, self.axis, self.keep_dims, self.name)
return ouput_shape, ouput_shape return ouput_shape, ouput_shape
@ -1377,7 +1377,7 @@ class ArgMinWithValue(PrimitiveWithInfer):
def infer_shape(self, x_shape): def infer_shape(self, x_shape):
axis = self.axis axis = self.axis
x_rank = len(x_shape) x_rank = len(x_shape)
validator.check_int_range("axis", axis, -x_rank, x_rank, Rel.INC_LEFT, self.name) validator.check_int_range(axis, -x_rank, x_rank, Rel.INC_LEFT, "axis", self.name)
ouput_shape = _infer_shape_reduce(x_shape, self.axis, self.keep_dims, self.name) ouput_shape = _infer_shape_reduce(x_shape, self.axis, self.keep_dims, self.name)
return ouput_shape, ouput_shape return ouput_shape, ouput_shape
@ -1760,7 +1760,7 @@ def _get_pack_shape(x_shape, x_type, axis, prim_name):
rank_base = len(x_shape[0]) rank_base = len(x_shape[0])
N = len(x_shape) N = len(x_shape)
out_shape = x_shape[0] out_shape = x_shape[0]
validator.check_int_range('axis', axis, -rank_base - 1, rank_base, Rel.INC_BOTH, prim_name) validator.check_int_range(axis, -rank_base - 1, rank_base, Rel.INC_BOTH, 'axis', prim_name)
if axis < 0: if axis < 0:
axis = axis + rank_base + 1 axis = axis + rank_base + 1
for i in range(1, N): for i in range(1, N):
@ -1863,7 +1863,7 @@ class Unpack(PrimitiveWithInfer):
validator.check_subclass("x", x['dtype'], mstype.tensor, self.name) validator.check_subclass("x", x['dtype'], mstype.tensor, self.name)
x_shape = list(x['shape']) x_shape = list(x['shape'])
dim = len(x_shape) dim = len(x_shape)
validator.check_int_range('axis value', self.axis, -dim, dim, Rel.INC_LEFT, self.name) validator.check_int_range(self.axis, -dim, dim, Rel.INC_LEFT, 'axis value', self.name)
if self.axis < 0: if self.axis < 0:
self.axis = self.axis + dim self.axis = self.axis + dim
output_num = x_shape[self.axis] output_num = x_shape[self.axis]
@ -1965,7 +1965,7 @@ class ReverseV2(PrimitiveWithInfer):
def infer_shape(self, x_shape): def infer_shape(self, x_shape):
dim = len(x_shape) dim = len(x_shape)
for i, each in enumerate(self.axis): for i, each in enumerate(self.axis):
validator.check_int_range(f'axis[{i}]', each, -dim, dim, Rel.INC_LEFT, self.name) validator.check_int_range(each, -dim, dim, Rel.INC_LEFT, f'axis[{i}]', self.name)
return x_shape return x_shape
def infer_dtype(self, x_dtype): def infer_dtype(self, x_dtype):

View File

@ -206,7 +206,7 @@ class _HostAllGather(PrimitiveWithInfer):
validator.check_value_type('group', group, (tuple, list), self.name) validator.check_value_type('group', group, (tuple, list), self.name)
validator.check_integer("group size", len(group), 2, Rel.GE, self.name) validator.check_integer("group size", len(group), 2, Rel.GE, self.name)
for r in group: for r in group:
validator.check_int_range("rank_id", r, 0, 7, Rel.INC_BOTH, self.name) validator.check_int_range(r, 0, 7, Rel.INC_BOTH, "rank_id", self.name)
validator.check_value_type("rank_id", r, (int,), self.name) validator.check_value_type("rank_id", r, (int,), self.name)
self.group_size = len(group) self.group_size = len(group)
self.add_prim_attr('group', group) self.add_prim_attr('group', group)
@ -315,7 +315,7 @@ class _HostReduceScatter(PrimitiveWithInfer):
validator.check_value_type('group', group, (tuple, list), self.name) validator.check_value_type('group', group, (tuple, list), self.name)
validator.check_integer("group size", len(group), 2, Rel.GE, self.name) validator.check_integer("group size", len(group), 2, Rel.GE, self.name)
for r in group: for r in group:
validator.check_int_range("rank_id", r, 0, 7, Rel.INC_BOTH, self.name) validator.check_int_range(r, 0, 7, Rel.INC_BOTH, "rank_id", self.name)
validator.check_value_type("rank_id", r, (int,), self.name) validator.check_value_type("rank_id", r, (int,), self.name)
self.op = op self.op = op
self.group_size = len(group) self.group_size = len(group)

View File

@ -70,8 +70,7 @@ class ControlDepend(Primitive):
@prim_attr_register @prim_attr_register
def __init__(self, depend_mode=0): def __init__(self, depend_mode=0):
"""init""" """init"""
validator.check_int_range( validator.check_int_range(depend_mode, 0, 1, Rel.INC_BOTH, "depend_mode", self.name)
"depend_mode", depend_mode, 0, 1, Rel.INC_BOTH, self.name)
def __call__(self, src, dst): def __call__(self, src, dst):
return src return src

View File

@ -31,7 +31,7 @@ def _infer_shape_reduce(x, axis, keep_dims, prim_name):
"""Common infer for reduce operator""" """Common infer for reduce operator"""
def reduce_one_axis(one_axis): def reduce_one_axis(one_axis):
validator.check_int_range('axis', one_axis, -dim, dim, Rel.INC_LEFT, prim_name) validator.check_int_range(one_axis, -dim, dim, Rel.INC_LEFT, 'axis', prim_name)
if one_axis < 0: if one_axis < 0:
one_axis += dim one_axis += dim
axis_reduce.add(one_axis) axis_reduce.add(one_axis)

View File

@ -149,7 +149,7 @@ class Softmax(PrimitiveWithInfer):
validator.check_integer("length of axis", len(self.axis), 1, Rel.GE, self.name) validator.check_integer("length of axis", len(self.axis), 1, Rel.GE, self.name)
rank = len(logits) rank = len(logits)
for axis_v in self.axis: for axis_v in self.axis:
validator.check_int_range("axis", axis_v, -rank, rank, Rel.INC_LEFT, self.name) validator.check_int_range(axis_v, -rank, rank, Rel.INC_LEFT, "axis", self.name)
return logits return logits
def infer_dtype(self, logits): def infer_dtype(self, logits):
@ -193,7 +193,7 @@ class LogSoftmax(PrimitiveWithInfer):
def infer_shape(self, logits): def infer_shape(self, logits):
rank = len(logits) rank = len(logits)
validator.check_int_range('axis', self.axis, -rank, rank, Rel.INC_LEFT, self.name) validator.check_int_range(self.axis, -rank, rank, Rel.INC_LEFT, 'axis', self.name)
return logits return logits
def infer_dtype(self, logits): def infer_dtype(self, logits):
@ -637,8 +637,8 @@ class FusedBatchNorm(Primitive):
self.init_prim_io_names(inputs=['x', 'scale', 'b', 'mean', 'variance'], self.init_prim_io_names(inputs=['x', 'scale', 'b', 'mean', 'variance'],
outputs=['y', 'running_mean', 'running_variance', 'save_mean', 'save_inv_variance']) outputs=['y', 'running_mean', 'running_variance', 'save_mean', 'save_inv_variance'])
self.mode = validator.check_integer('mode', mode, [0, 1], Rel.IN, self.name) self.mode = validator.check_integer('mode', mode, [0, 1], Rel.IN, self.name)
self.epsilon = validator.check_number_range('epsilon', epsilon, 0, 1, Rel.INC_RIGHT, self.name) self.epsilon = validator.check_float_range(epsilon, 0, 1, Rel.INC_RIGHT, 'epsilon', self.name)
self.momentum = validator.check_number_range('momentum', momentum, 0, 1, Rel.INC_BOTH, self.name) self.momentum = validator.check_float_range(momentum, 0, 1, Rel.INC_BOTH, 'momentum', self.name)
self._update_parameter = True self._update_parameter = True
@ -710,8 +710,8 @@ class FusedBatchNormEx(PrimitiveWithInfer):
self.init_prim_io_names(inputs=['x', 'scale', 'b', 'mean', 'variance'], self.init_prim_io_names(inputs=['x', 'scale', 'b', 'mean', 'variance'],
outputs=['y', 'save_scale', 'save_bias', 'save_mean', 'save_inv_variance', 'reserve']) outputs=['y', 'save_scale', 'save_bias', 'save_mean', 'save_inv_variance', 'reserve'])
self.mode = validator.check_integer('mode', mode, [0, 1], Rel.IN, self.name) self.mode = validator.check_integer('mode', mode, [0, 1], Rel.IN, self.name)
self.epsilon = validator.check_number_range('epsilon', epsilon, 0, 1, Rel.INC_RIGHT, self.name) self.epsilon = validator.check_float_range(epsilon, 0, 1, Rel.INC_RIGHT, 'epsilon', self.name)
self.momentum = validator.check_number_range('momentum', momentum, 0, 1, Rel.INC_BOTH, self.name) self.momentum = validator.check_float_range(momentum, 0, 1, Rel.INC_BOTH, 'momentum', self.name)
self._update_parameter = True self._update_parameter = True
self.add_prim_attr('data_format', "NCHW") self.add_prim_attr('data_format', "NCHW")
@ -818,8 +818,8 @@ class BNTrainingUpdate(PrimitiveWithInfer):
validator.check_value_type("isRef", isRef, [bool], self.name) validator.check_value_type("isRef", isRef, [bool], self.name)
validator.check_value_type("epsilon", epsilon, [float], self.name) validator.check_value_type("epsilon", epsilon, [float], self.name)
validator.check_value_type("factor", factor, [float], self.name) validator.check_value_type("factor", factor, [float], self.name)
self.epsilon = validator.check_number_range('epsilon', epsilon, 0, 1, Rel.INC_RIGHT, 'BNTrainingUpdate') self.epsilon = validator.check_float_range(epsilon, 0, 1, Rel.INC_RIGHT, 'epsilon', 'BNTrainingUpdate')
self.factor = validator.check_number_range('factor', factor, 0, 1, Rel.INC_BOTH, 'BNTrainingUpdate') self.factor = validator.check_float_range(factor, 0, 1, Rel.INC_BOTH, 'factor', 'BNTrainingUpdate')
def infer_shape(self, x, sum, square_sum, scale, b, mean, variance): def infer_shape(self, x, sum, square_sum, scale, b, mean, variance):
validator.check_integer("x rank", len(x), 4, Rel.EQ, self.name) validator.check_integer("x rank", len(x), 4, Rel.EQ, self.name)
@ -898,7 +898,7 @@ class BatchNorm(PrimitiveWithInfer):
@prim_attr_register @prim_attr_register
def __init__(self, is_training=False, epsilon=1e-5): def __init__(self, is_training=False, epsilon=1e-5):
validator.check_value_type('is_training', is_training, (bool,), self.name) validator.check_value_type('is_training', is_training, (bool,), self.name)
validator.check_number_range('epsilon', epsilon, 0, 1, Rel.INC_RIGHT, self.name) validator.check_float_range(epsilon, 0, 1, Rel.INC_RIGHT, 'epsilon', self.name)
self.add_prim_attr('data_format', "NCHW") self.add_prim_attr('data_format', "NCHW")
self.init_prim_io_names(inputs=['x', 'scale', 'offset', 'mean', 'variance'], self.init_prim_io_names(inputs=['x', 'scale', 'offset', 'mean', 'variance'],
outputs=['y', 'batch_mean', 'batch_variance', 'reserve_space_1', 'reserve_space_2']) outputs=['y', 'batch_mean', 'batch_variance', 'reserve_space_1', 'reserve_space_2'])
@ -2383,7 +2383,7 @@ class L2Normalize(PrimitiveWithInfer):
def infer_shape(self, input_x): def infer_shape(self, input_x):
dim = len(input_x) dim = len(input_x)
validator.check_int_range('axis value', self.axis, -dim, dim, Rel.INC_LEFT, self.name) validator.check_int_range(self.axis, -dim, dim, Rel.INC_LEFT, 'axis value', self.name)
return input_x return input_x
def infer_dtype(self, input_x): def infer_dtype(self, input_x):
@ -2481,10 +2481,10 @@ class DropoutDoMask(PrimitiveWithInfer):
keep_prob_v = keep_prob['value'] keep_prob_v = keep_prob['value']
if keep_prob_v is not None: if keep_prob_v is not None:
if isinstance(keep_prob['dtype'], type(mstype.tensor)): if isinstance(keep_prob['dtype'], type(mstype.tensor)):
validator.check_number_range('keep_prob', keep_prob_v.asnumpy(), 0, 1, Rel.INC_BOTH, self.name) validator.check_float_range(keep_prob_v.asnumpy(), 0, 1, Rel.INC_BOTH, 'keep_prob', self.name)
else: else:
validator.check_value_type("keep_prob", keep_prob_v, [float], self.name) validator.check_value_type("keep_prob", keep_prob_v, [float], self.name)
validator.check_number_range('keep_prob', keep_prob_v, 0, 1, Rel.INC_BOTH, self.name) validator.check_float_range(keep_prob_v, 0, 1, Rel.INC_BOTH, 'keep_prob', self.name)
out = {'shape': input_x_shape, out = {'shape': input_x_shape,
'dtype': input_x['dtype'], 'dtype': input_x['dtype'],
@ -2584,7 +2584,7 @@ class OneHot(PrimitiveWithInfer):
# check shape # check shape
indices_shp = indices['shape'] indices_shp = indices['shape']
validator.check_int_range("axis", self.axis, -1, len(indices_shp), Rel.INC_BOTH, self.name) validator.check_int_range(self.axis, -1, len(indices_shp), Rel.INC_BOTH, "axis", self.name)
depth_val = depth['value'] depth_val = depth['value']
validator.check_non_negative_int(depth_val, "depth", self.name) validator.check_non_negative_int(depth_val, "depth", self.name)
# create new dimension at end if self.axis is -1 # create new dimension at end if self.axis is -1
@ -2771,7 +2771,7 @@ class LSTM(PrimitiveWithInfer):
self.has_bias = validator.check_value_type("has_bias", has_bias, (bool,), self.name) self.has_bias = validator.check_value_type("has_bias", has_bias, (bool,), self.name)
self.bidirectional = validator.check_value_type("bidirectional", bidirectional, (bool,), self.name) self.bidirectional = validator.check_value_type("bidirectional", bidirectional, (bool,), self.name)
self.dropout = validator.check_value_type("dropout", dropout, [float], self.name) self.dropout = validator.check_value_type("dropout", dropout, [float], self.name)
self.dropout = validator.check_number_range('dropout', dropout, 0, 1, Rel.INC_BOTH, self.name) self.dropout = validator.check_float_range(dropout, 0, 1, Rel.INC_BOTH, 'dropout', self.name)
if bidirectional: if bidirectional:
self.num_directions = 2 self.num_directions = 2
@ -3054,7 +3054,7 @@ class ROIAlign(PrimitiveWithInfer):
validator.check_value_type("spatial_scale", spatial_scale, [float], self.name) validator.check_value_type("spatial_scale", spatial_scale, [float], self.name)
validator.check_value_type("sample_num", sample_num, [int], self.name) validator.check_value_type("sample_num", sample_num, [int], self.name)
validator.check_value_type("roi_end_mode", roi_end_mode, [int], self.name) validator.check_value_type("roi_end_mode", roi_end_mode, [int], self.name)
validator.check_int_range("roi_end_mode", roi_end_mode, 0, 1, Rel.INC_BOTH, self.name) validator.check_int_range(roi_end_mode, 0, 1, Rel.INC_BOTH, "roi_end_mode", self.name)
self.pooled_height = pooled_height self.pooled_height = pooled_height
self.pooled_width = pooled_width self.pooled_width = pooled_width
self.spatial_scale = spatial_scale self.spatial_scale = spatial_scale
@ -3502,9 +3502,9 @@ class FusedSparseFtrl(PrimitiveWithInfer):
validator.check_value_type("l1", l1, [float], self.name) validator.check_value_type("l1", l1, [float], self.name)
validator.check_value_type("l2", l2, [float], self.name) validator.check_value_type("l2", l2, [float], self.name)
validator.check_value_type("lr_power", lr_power, [float], self.name) validator.check_value_type("lr_power", lr_power, [float], self.name)
self.lr = validator.check_number_range("lr", lr, 0.0, float("inf"), Rel.INC_NEITHER, self.name) self.lr = validator.check_positive_float(lr, "lr", self.name)
self.l1 = validator.check_number_range("l1", l1, 0.0, float("inf"), Rel.INC_LEFT, self.name) self.l1 = validator.check_non_negative_float(l1, "l1", self.name)
self.l2 = validator.check_number_range("l2", l2, 0.0, float("inf"), Rel.INC_LEFT, self.name) self.l2 = validator.check_non_negative_float(l2, "l2", self.name)
self.lr_power = validator.check_number("lr_power", lr_power, 0, Rel.LE, self.name) self.lr_power = validator.check_number("lr_power", lr_power, 0, Rel.LE, self.name)
self.use_locking = validator.check_value_type("use_locking", use_locking, [bool], self.name) self.use_locking = validator.check_value_type("use_locking", use_locking, [bool], self.name)
@ -4240,7 +4240,7 @@ class SparseApplyAdagrad(PrimitiveWithInfer):
@prim_attr_register @prim_attr_register
def __init__(self, lr, update_slots=True, use_locking=False): def __init__(self, lr, update_slots=True, use_locking=False):
validator.check_value_type("lr", lr, [float], self.name) validator.check_value_type("lr", lr, [float], self.name)
validator.check_number_range("lr", lr, float("-inf"), float("inf"), Rel.INC_NEITHER, self.name) validator.check_is_float(lr, "lr", self.name)
validator.check_value_type("update_slots", update_slots, [bool], self.name) validator.check_value_type("update_slots", update_slots, [bool], self.name)
validator.check_value_type("use_locking", use_locking, [bool], self.name) validator.check_value_type("use_locking", use_locking, [bool], self.name)
@ -5142,9 +5142,9 @@ class SparseApplyFtrl(PrimitiveWithCheck):
validator.check_value_type("l1", l1, [float], self.name) validator.check_value_type("l1", l1, [float], self.name)
validator.check_value_type("l2", l2, [float], self.name) validator.check_value_type("l2", l2, [float], self.name)
validator.check_value_type("lr_power", lr_power, [float], self.name) validator.check_value_type("lr_power", lr_power, [float], self.name)
self.lr = validator.check_number_range("lr", lr, 0.0, float("inf"), Rel.INC_NEITHER, self.name) self.lr = validator.check_positive_float(lr, "lr", self.name)
self.l1 = validator.check_number_range("l1", l1, 0.0, float("inf"), Rel.INC_LEFT, self.name) self.l1 = validator.check_non_negative_float(l1, "l1", self.name)
self.l2 = validator.check_number_range("l2", l2, 0.0, float("inf"), Rel.INC_LEFT, self.name) self.l2 = validator.check_non_negative_float(l2, "l2", self.name)
self.lr_power = validator.check_number("lr_power", lr_power, 0, Rel.LE, self.name) self.lr_power = validator.check_number("lr_power", lr_power, 0, Rel.LE, self.name)
self.use_locking = validator.check_value_type("use_locking", use_locking, [bool], self.name) self.use_locking = validator.check_value_type("use_locking", use_locking, [bool], self.name)
self.init_prim_io_names(inputs=['var', 'accum', 'linear', 'grad', 'indices'], self.init_prim_io_names(inputs=['var', 'accum', 'linear', 'grad', 'indices'],
@ -5239,9 +5239,9 @@ class SparseApplyFtrlV2(PrimitiveWithInfer):
validator.check_value_type("l1", l1, [float], self.name) validator.check_value_type("l1", l1, [float], self.name)
validator.check_value_type("l2", l2, [float], self.name) validator.check_value_type("l2", l2, [float], self.name)
validator.check_value_type("lr_power", lr_power, [float], self.name) validator.check_value_type("lr_power", lr_power, [float], self.name)
self.lr = validator.check_number_range("lr", lr, 0.0, float("inf"), Rel.INC_NEITHER, self.name) self.lr = validator.check_positive_float(lr, "lr", self.name)
self.l1 = validator.check_number_range("l1", l1, 0.0, float("inf"), Rel.INC_LEFT, self.name) self.l1 = validator.check_non_negative_float(l1, "l1", self.name)
self.l2 = validator.check_number_range("l2", l2, 0.0, float("inf"), Rel.INC_LEFT, self.name) self.l2 = validator.check_non_negative_float(l2, "l2", self.name)
self.lr_power = validator.check_number("lr_power", lr_power, 0, Rel.LE, self.name) self.lr_power = validator.check_number("lr_power", lr_power, 0, Rel.LE, self.name)
self.l2_shrinkage = validator.check_value_type("l2_shrinkage", l2_shrinkage, [float], self.name) self.l2_shrinkage = validator.check_value_type("l2_shrinkage", l2_shrinkage, [float], self.name)
self.use_locking = validator.check_value_type("use_locking", use_locking, [bool], self.name) self.use_locking = validator.check_value_type("use_locking", use_locking, [bool], self.name)
@ -5285,7 +5285,7 @@ class Dropout(PrimitiveWithInfer):
@prim_attr_register @prim_attr_register
def __init__(self, keep_prob=0.5): def __init__(self, keep_prob=0.5):
self.keep_prob = validator.check_number_range("keep_prob", keep_prob, 0, 1, Rel.INC_RIGHT, self.name) self.keep_prob = validator.check_float_range(keep_prob, 0, 1, Rel.INC_RIGHT, "keep_prob", self.name)
def infer_shape(self, x_shape): def infer_shape(self, x_shape):
validator.check_integer("x_shape", len(x_shape), 1, Rel.GE, self.name) validator.check_integer("x_shape", len(x_shape), 1, Rel.GE, self.name)
@ -5510,7 +5510,7 @@ class BasicLSTMCell(PrimitiveWithInfer):
@prim_attr_register @prim_attr_register
def __init__(self, keep_prob=1.0, forget_bias=1.0, state_is_tuple=True, activation='tanh'): def __init__(self, keep_prob=1.0, forget_bias=1.0, state_is_tuple=True, activation='tanh'):
self.keep_prob = validator.check_value_type("keep_prob", keep_prob, [float], self.name) self.keep_prob = validator.check_value_type("keep_prob", keep_prob, [float], self.name)
self.keep_prob = validator.check_number_range("keep_prob", keep_prob, 0.0, 1.0, Rel.INC_BOTH, self.name) self.keep_prob = validator.check_float_range(keep_prob, 0.0, 1.0, Rel.INC_BOTH, "keep_prob", self.name)
self.forget_bias = validator.check_value_type("forget_bias", forget_bias, [float], self.name) self.forget_bias = validator.check_value_type("forget_bias", forget_bias, [float], self.name)
self.state_is_tuple = validator.check_value_type("state_is_tuple", state_is_tuple, [bool], self.name) self.state_is_tuple = validator.check_value_type("state_is_tuple", state_is_tuple, [bool], self.name)
self.activation = validator.check_string(activation, ['tanh'], "activation", self.name) self.activation = validator.check_string(activation, ['tanh'], "activation", self.name)

View File

@ -100,10 +100,8 @@ def _generate_cosine_lr(lr_init, lr_end, lr_max, total_steps, warmup_steps):
lr_inc = (float(lr_max) - float(lr_init)) / float(warmup_steps) lr_inc = (float(lr_max) - float(lr_init)) / float(warmup_steps)
lr = float(lr_init) + lr_inc * (i + 1) lr = float(lr_init) + lr_inc * (i + 1)
else: else:
linear_decay = (total_steps - i) / decay_steps cosine_decay = 0.5 * (1 + math.cos(math.pi * (i - warmup_steps) / decay_steps))
cosine_decay = 0.5 * (1 + math.cos(math.pi * 2 * 0.47 * i / decay_steps)) lr = (lr_max - lr_end) * cosine_decay + lr_end
decayed = linear_decay * cosine_decay + 0.00001
lr = lr_max * decayed
lr_each_step.append(lr) lr_each_step.append(lr)
return lr_each_step return lr_each_step

View File

@ -122,7 +122,7 @@ class MySparseGatherV2(PrimitiveWithInfer):
axis_v = axis['value'] axis_v = axis['value']
params_shp = params['shape'] params_shp = params['shape']
rank = len(params_shp) rank = len(params_shp)
validator.check_int_range("axis", axis_v, -rank, rank, Rel.INC_LEFT, self.name) validator.check_int_range(axis_v, -rank, rank, Rel.INC_LEFT, "axis", self.name)
if axis_v < 0: if axis_v < 0:
axis_v += rank axis_v += rank
out_shape = params_shp[:axis_v] + indices['shape'] + params_shp[axis_v + 1:] out_shape = params_shp[:axis_v] + indices['shape'] + params_shp[axis_v + 1:]
@ -208,10 +208,10 @@ def _check_param_value(beta1, beta2, eps, weight_decay, prim_name):
validator.check_value_type("beta2", beta2, [float], prim_name) validator.check_value_type("beta2", beta2, [float], prim_name)
validator.check_value_type("eps", eps, [float], prim_name) validator.check_value_type("eps", eps, [float], prim_name)
validator.check_value_type("weight_dacay", weight_decay, [float], prim_name) validator.check_value_type("weight_dacay", weight_decay, [float], prim_name)
validator.check_number_range("beta1", beta1, 0.0, 1.0, Rel.INC_NEITHER, prim_name) validator.check_float_range(beta1, 0.0, 1.0, Rel.INC_NEITHER, "beta1", prim_name)
validator.check_number_range("beta2", beta2, 0.0, 1.0, Rel.INC_NEITHER, prim_name) validator.check_float_range(beta2, 0.0, 1.0, Rel.INC_NEITHER, "beta2", prim_name)
validator.check_number_range("eps", eps, 0.0, float("inf"), Rel.INC_NEITHER, prim_name) validator.check_positive_float(eps, "eps", prim_name)
validator.check_number_range("weight_decay", weight_decay, 0.0, float("inf"), Rel.INC_LEFT, prim_name) validator.check_non_negative_float(weight_decay, "weight_decay", prim_name)
class AdamWeightDecaySparse(Optimizer): class AdamWeightDecaySparse(Optimizer):

View File

@ -14,55 +14,97 @@
# ============================================================================ # ============================================================================
""" test checkparameter """ """ test checkparameter """
import pytest import pytest
import numpy as np
from mindspore._checkparam import check_int, check_input_format, Validator, twice from mindspore._checkparam import check_input_format, Validator, twice, Rel
kernel_size = 5 kernel_size = 5
kernel_size1 = twice(kernel_size) kernel_size1 = twice(kernel_size)
assert kernel_size1 == (5, 5) assert kernel_size1 == (5, 5)
def test_check_integer1():
with pytest.raises(TypeError):
Validator.check_integer("input", 0, Rel.GE, "number")
def test_check_int_1(): def test_check_integer2():
assert check_int(3) == 3
def check_int_positive_1():
with pytest.raises(ValueError): with pytest.raises(ValueError):
Validator.check_positive_int(-1) Validator.check_integer(-1, 0, Rel.GE, "number")
def test_check_integer3():
input = np.random.randint(0, 100)
assert Validator.check_integer(input, 0, Rel.GE, "number") == input
def test_NCHW1(): def test_check_int1():
assert check_input_format("NCHW") == "NCHW" input = np.random.randint(-100, 100)
assert Validator.check_is_int(input) == input
def test_check_int2():
with pytest.raises(TypeError):
Validator.check_is_int(3.3)
def test_NCHW3(): def test_check_int3():
with pytest.raises(TypeError):
Validator.check_is_int("str")
def test_check_int4():
with pytest.raises(TypeError):
Validator.check_is_int(True)
def test_check_is_int5():
with pytest.raises(TypeError):
Validator.check_is_int(True)
with pytest.raises(TypeError):
Validator.check_is_int(False)
def test_check_positive_int1():
input = np.random.randint(0, 100)
assert Validator.check_positive_int(input) == input
def test_check_positive_int2():
input = np.random.randint(-100, 0)
with pytest.raises(ValueError): with pytest.raises(ValueError):
check_input_format("rt") Validator.check_positive_int(input)
def test_check_positive_int3():
with pytest.raises(ValueError):
Validator.check_positive_int(3.3)
def test_check_int_2(): def test_check_positive_int4():
with pytest.raises(TypeError): with pytest.raises(TypeError):
check_int(3.3) Validator.check_positive_int("str")
def test_check_negative_int1():
input = np.random.randint(-100, -1)
assert Validator.check_negative_int(input) == input
def test_check_int_3(): def test_check_negative_int2():
input = np.random.randint(0, 100)
with pytest.raises(ValueError):
Validator.check_negative_int(input)
def test_check_negative_int3():
with pytest.raises(ValueError):
Validator.check_negative_int(3.3)
def test_check_negative_int4():
with pytest.raises(TypeError): with pytest.raises(TypeError):
check_int("str") Validator.check_negative_int("str")
def test_check_non_positive_int1():
input = np.random.randint(-100, 0)
assert Validator.check_non_positive_int(input) == input
def test_check_int_4(): def test_check_non_positive_int2():
input = np.random.randint(1, 100)
with pytest.raises(ValueError):
Validator.check_non_positive_int(input)
def test_check_non_positive_int3():
with pytest.raises(ValueError):
Validator.check_non_positive_int(3.3)
def test_check_non_positive_int4():
with pytest.raises(TypeError): with pytest.raises(TypeError):
check_int(True) Validator.check_non_positive_int("str")
def test_check_int_5():
check_int(0)
check_int(1)
with pytest.raises(TypeError):
check_int(True)
with pytest.raises(TypeError):
check_int(False)
def test_check_bool_1(): def test_check_bool_1():
assert Validator.check_bool(True) assert Validator.check_bool(True)