!7311 [ME] change _check_parameter format

Merge pull request !7311 from chenzhongming/zomi_master
This commit is contained in:
mindspore-ci-bot 2020-10-15 20:52:29 +08:00 committed by Gitee
commit c68f36c81e
22 changed files with 224 additions and 171 deletions

View File

@ -94,10 +94,10 @@ rel_strs = {
def check_number(arg_value, value, rel, arg_type=int, arg_name=None, prim_name=None):
"""
Check argument integer.
Check argument integer.
Usage:
- number = check_integer(number, 0, Rel.GE, "number", None) # number >= 0
Usage:
- number = check_integer(number, 0, Rel.GE, "number", None) # number >= 0
"""
rel_fn = Rel.get_fns(rel)
type_mismatch = not isinstance(arg_value, arg_type) or isinstance(arg_value, bool)
@ -122,13 +122,33 @@ def check_is_number(arg_value, arg_type, arg_name=None, prim_name=None):
"""
prim_name = f'in \'{prim_name}\'' if prim_name else ''
arg_name = f'\'{prim_name}\'' if arg_name else 'Input value'
if isinstance(arg_value, arg_type):
if isinstance(arg_value, arg_type) and not isinstance(arg_value, bool):
if math.isinf(arg_value) or math.isnan(arg_value):
raise ValueError(f'{arg_name} {prim_name} must be legal float, but got `{arg_value}`.')
return arg_value
raise TypeError(f'{arg_name} {prim_name} must be float, but got `{type(arg_value).__name__}`')
def check_number_range(arg_value, lower_limit, upper_limit, rel, value_type, arg_name=None, prim_name=None):
"""
Method for checking whether an int value is in some range.
Usage:
- number = check_number_range(number, 0.0, 1.0, Rel.INC_NEITHER, "number", float) # number in [0.0, 1.0]
- number = check_number_range(number, 0, 1, Rel.INC_NEITHER, "number", int) # number in [0, 1]
"""
prim_name = f'in `{prim_name}`' if prim_name else ''
arg_name = f'`{arg_name}`' if arg_name else ''
rel_fn = Rel.get_fns(rel)
type_mismatch = not isinstance(arg_value, (np.ndarray, np.generic, value_type)) or isinstance(arg_value, bool)
excp_cls = TypeError if type_mismatch else ValueError
if type_mismatch or not rel_fn(arg_value, lower_limit, upper_limit):
rel_str = Rel.get_strs(rel).format(lower_limit, upper_limit)
raise excp_cls("{} {} should be in range of {}, but got {:.3f} with type {}.".format(
arg_name, prim_name, rel_str, arg_value, type(arg_value).__name__))
return arg_value
class Validator:
"""validator for checking input parameters"""
@ -147,16 +167,13 @@ class Validator:
@staticmethod
def check_integer(arg_name, arg_value, value, rel, prim_name=None):
"""Check argument is integer"""
rel_fn = Rel.get_fns(rel)
type_mismatch = not isinstance(arg_value, int) or isinstance(arg_value, bool)
excp_cls = TypeError if type_mismatch else ValueError
if type_mismatch or not rel_fn(arg_value, value):
rel_str = Rel.get_strs(rel).format(value)
msg_prefix = f'For \'{prim_name}\' the' if prim_name else "The"
raise excp_cls(f'{msg_prefix} `{arg_name}` should be an int and must {rel_str}, but got `{arg_value}`'
f' with type `{type(arg_value).__name__}`.')
return arg_value
"""
Checks input integer value `arg_value` compare to `value`.
Usage:
- number = check_integer(number, 0, Rel.GE, "number", None) # number >= 0
"""
return check_number(arg_value, value, rel, int, arg_name, prim_name)
@staticmethod
def check_is_int(arg_value, arg_name=None, prim_name=None):
@ -168,7 +185,7 @@ class Validator:
- number = check_is_int(number, int, "bias")
- number = check_is_int(number, int, "bias", "bias_class")
"""
check_is_number(arg_value, int, arg_name, prim_name)
return check_is_number(arg_value, int, arg_name, prim_name)
@staticmethod
def check_positive_int(arg_value, arg_name=None, prim_name=None):
@ -214,6 +231,16 @@ class Validator:
"""
return check_number(arg_value, 0, Rel.GE, int, arg_name, prim_name)
@staticmethod
def check_float(arg_value, value, rel, arg_name=None, prim_name=None):
"""
Checks input float value `arg_value` compare to `value`.
Usage:
- number = check_float(number, 0.0, Rel.GE, "number", None) # number >= 0
"""
return check_number(arg_value, value, rel, float, arg_name, prim_name)
@staticmethod
def check_is_float(arg_value, arg_name=None, prim_name=None):
"""
@ -224,7 +251,7 @@ class Validator:
- number = check_is_float(number, int, "bias")
- number = check_is_float(number, int, "bias", "bias_class")
"""
check_is_number(arg_value, float, arg_name, prim_name)
return check_is_number(arg_value, float, arg_name, prim_name)
@staticmethod
def check_positive_float(arg_value, arg_name=None, prim_name=None):
@ -302,25 +329,26 @@ class Validator:
return arg_value
@staticmethod
def check_int_range(arg_name, arg_value, lower_limit, upper_limit, rel, prim_name):
"""Method for checking whether an int value is in some range."""
rel_fn = Rel.get_fns(rel)
type_mismatch = not isinstance(arg_value, int) or isinstance(arg_value, bool)
excp_cls = TypeError if type_mismatch else ValueError
if type_mismatch or not rel_fn(arg_value, lower_limit, upper_limit):
rel_str = Rel.get_strs(rel).format(lower_limit, upper_limit)
raise excp_cls(f'For \'{prim_name}\' the `{arg_name}` should be an int in range {rel_str},'
f' but got `{arg_value}` with type `{type(arg_value).__name__}`.')
return arg_value
def check_int_range(arg_value, lower_limit, upper_limit, rel, arg_name=None, prim_name=None):
"""
Method for checking whether input value is in int range.
Usage:
- number = check_int_range(number, 0, 1, Rel.INC_NEITHER) # number in [0, 1]
- number = check_int_range(number, 0, 1, Rel.INC_NEITHER, "number") # number in [0, 1]
"""
return check_number_range(arg_value, lower_limit, upper_limit, rel, int, arg_name, prim_name)
@staticmethod
def check_number_range(arg_name, arg_value, lower_limit, upper_limit, rel, prim_name):
"""Method for checking whether a numeric value is in some range."""
rel_fn = Rel.get_fns(rel)
if not rel_fn(arg_value, lower_limit, upper_limit):
rel_str = Rel.get_strs(rel).format(lower_limit, upper_limit)
raise ValueError(f'For \'{prim_name}\' the `{arg_name}` should be in range {rel_str}, but got {arg_value}.')
return arg_value
def check_float_range(arg_value, lower_limit, upper_limit, rel, arg_name=None, prim_name=None):
"""
Method for checking whether input value is in float range.
Usage:
- number = check_float_range(number, 0.0, 1.0, Rel.INC_NEITHER) # number in [0.0, 1.0]
- number = check_float_range(number, 0.0, 1.0, Rel.INC_NEITHER, "number") # number in [0.0, 1.0]
"""
return check_number_range(arg_value, lower_limit, upper_limit, rel, float, arg_name, prim_name)
@staticmethod
def check_string(arg_value, valid_values, arg_name=None, prim_name=None):
@ -502,13 +530,6 @@ class Validator:
f'{tuple(exp_shape)}, but got {shape}.')
def check_int(input_param):
"""Int type judgment."""
if isinstance(input_param, int) and not isinstance(input_param, bool):
return input_param
raise TypeError("Input type must be int!")
def check_int_zero_one(input_param):
"""Judge whether it is 0 or 1."""
if input_param in (0, 1):

View File

@ -233,7 +233,7 @@ def cosine_decay_lr(min_lr, max_lr, total_step, step_per_epoch, decay_epoch):
"""
if not isinstance(min_lr, float):
raise TypeError("min_lr must be float.")
validator.check_number_range("min_lr", min_lr, 0.0, float("inf"), Rel.INC_LEFT, None)
validator.check_non_negative_float(min_lr, "min_lr", None)
validator.check_positive_float(max_lr, 'max_lr')
validator.check_is_float(max_lr, 'max_lr')
validator.check_positive_int(total_step, 'total_step')
@ -303,7 +303,7 @@ def polynomial_decay_lr(learning_rate, end_learning_rate, total_step, step_per_e
validator.check_is_float(learning_rate, 'learning_rate')
if not isinstance(end_learning_rate, float):
raise TypeError("end_learning_rate must be float.")
validator.check_number_range("end_learning_rate", end_learning_rate, 0.0, float("inf"), Rel.INC_LEFT, None)
validator.check_non_negative_float(end_learning_rate, "end_learning_rate", None)
validator.check_positive_float(power, 'power')
validator.check_is_float(power, 'power')
validator.check_positive_int(total_step, 'total_step')
@ -356,7 +356,7 @@ def warmup_lr(learning_rate, total_step, step_per_epoch, warmup_epoch):
"""
if not isinstance(learning_rate, float):
raise TypeError("learning_rate must be float.")
validator.check_number_range("learning_rate", learning_rate, 0.0, float("inf"), Rel.INC_LEFT, None)
validator.check_non_negative_float(learning_rate, "learning_rate", None)
validator.check_positive_int(warmup_epoch, 'warmup_epoch')
validator.check_positive_int(total_step, 'total_step')
validator.check_positive_int(step_per_epoch, 'step_per_epoch')

View File

@ -451,8 +451,7 @@ class CentralCrop(Cell):
def __init__(self, central_fraction):
super(CentralCrop, self).__init__()
validator.check_value_type("central_fraction", central_fraction, [float], self.cls_name)
self.central_fraction = validator.check_number_range('central_fraction', central_fraction,
0.0, 1.0, Rel.INC_RIGHT, self.cls_name)
self.central_fraction = validator.check_float_range(0.0, 1.0, Rel.INC_RIGHT, 'central_fraction', central_fraction, self.cls_name)
self.slice = P.Slice()
def construct(self, image):

View File

@ -254,7 +254,7 @@ class CosineDecayLR(LearningRateSchedule):
super(CosineDecayLR, self).__init__()
if not isinstance(min_lr, float):
raise TypeError("min_lr must be float.")
validator.check_number_range("min_lr", min_lr, 0.0, float("inf"), Rel.INC_LEFT, self.cls_name)
validator.check_non_negative_float(min_lr, "min_lr", self.cls_name)
validator.check_positive_float(max_lr, 'max_lr', self.cls_name)
validator.check_is_float(max_lr, 'max_lr', self.cls_name)
validator.check_positive_int(decay_steps, "decay_steps", self.cls_name)
@ -322,8 +322,7 @@ class PolynomialDecayLR(LearningRateSchedule):
validator.check_is_float(learning_rate, 'learning_rate')
if not isinstance(end_learning_rate, float):
raise TypeError("end_learning_rate must be float.")
validator.check_number_range("end_learning_rate", end_learning_rate, 0.0, float("inf"), Rel.INC_LEFT,
self.cls_name)
validator.check_non_negative_float(end_learning_rate, "end_learning_rate", self.cls_name)
validator.check_positive_int(decay_steps, 'decay_steps', self.cls_name)
validator.check_value_type('update_decay_steps', update_decay_steps, [bool], self.cls_name)
validator.check_positive_float(power, 'power', self.cls_name)
@ -387,7 +386,7 @@ class WarmUpLR(LearningRateSchedule):
super(WarmUpLR, self).__init__()
if not isinstance(learning_rate, float):
raise TypeError("learning_rate must be float.")
validator.check_number_range("learning_rate", learning_rate, 0.0, float("inf"), Rel.INC_LEFT, self.cls_name)
validator.check_non_negative_float(learning_rate, "learning_rate", self.cls_name)
validator.check_positive_int(warmup_steps, 'warmup_steps', self.cls_name)
self.warmup_steps = warmup_steps
self.learning_rate = learning_rate

View File

@ -368,7 +368,7 @@ class CosineEmbeddingLoss(_Loss):
self.reduce_sum = P.ReduceSum()
self.maximum = P.Maximum()
validator.check_value_type("margin", margin, [float], self.cls_name)
self.margin = validator.check_number_range("margin", margin, -1.0, 1.0, Rel.INC_BOTH, self.cls_name)
self.margin = validator.check_float_range(margin, -1.0, 1.0, Rel.INC_BOTH, "margin", self.cls_name)
def construct(self, x1, x2, y):
F.same_type_shape(x1, x2)

View File

@ -126,9 +126,9 @@ def _check_param_value(beta1, beta2, eps, prim_name):
validator.check_value_type("beta1", beta1, [float], prim_name)
validator.check_value_type("beta2", beta2, [float], prim_name)
validator.check_value_type("eps", eps, [float], prim_name)
validator.check_number_range("beta1", beta1, 0.0, 1.0, Rel.INC_NEITHER, prim_name)
validator.check_number_range("beta2", beta2, 0.0, 1.0, Rel.INC_NEITHER, prim_name)
validator.check_number_range("eps", eps, 0.0, float("inf"), Rel.INC_NEITHER, prim_name)
validator.check_float_range(beta1, 0.0, 1.0, Rel.INC_NEITHER, "beta1", prim_name)
validator.check_float_range(beta2, 0.0, 1.0, Rel.INC_NEITHER, "beta2", prim_name)
validator.check_positive_float(eps, "eps", prim_name)
class Adam(Optimizer):

View File

@ -177,9 +177,9 @@ def _check_param_value(beta1, beta2, eps, prim_name):
validator.check_value_type("beta1", beta1, [float], prim_name)
validator.check_value_type("beta2", beta2, [float], prim_name)
validator.check_value_type("eps", eps, [float], prim_name)
validator.check_number_range("beta1", beta1, 0.0, 1.0, Rel.INC_NEITHER, prim_name)
validator.check_number_range("beta2", beta2, 0.0, 1.0, Rel.INC_NEITHER, prim_name)
validator.check_number_range("eps", eps, 0.0, float("inf"), Rel.INC_NEITHER, prim_name)
validator.check_float_range(beta1, 0.0, 1.0, Rel.INC_NEITHER, "beta1", prim_name)
validator.check_float_range(beta2, 0.0, 1.0, Rel.INC_NEITHER, "beta2", prim_name)
validator.check_positive_float(eps, "eps", prim_name)
class Lamb(Optimizer):

View File

@ -70,10 +70,10 @@ def _check_param_value(beta1, beta2, eps, weight_decay, prim_name):
validator.check_value_type("beta2", beta2, [float], prim_name)
validator.check_value_type("eps", eps, [float], prim_name)
validator.check_value_type("weight_dacay", weight_decay, [float], prim_name)
validator.check_number_range("beta1", beta1, 0.0, 1.0, Rel.INC_NEITHER, prim_name)
validator.check_number_range("beta2", beta2, 0.0, 1.0, Rel.INC_NEITHER, prim_name)
validator.check_number_range("eps", eps, 0.0, float("inf"), Rel.INC_NEITHER, prim_name)
validator.check_number_range("weight_decay", weight_decay, 0.0, float("inf"), Rel.INC_LEFT, prim_name)
validator.check_float_range(beta1, 0.0, 1.0, Rel.INC_NEITHER, "beta1", prim_name)
validator.check_float_range(beta2, 0.0, 1.0, Rel.INC_NEITHER, "beta2", prim_name)
validator.check_positive_float(eps, "eps", prim_name)
validator.check_non_negative_float(weight_decay, "weight_decay", prim_name)
class LazyAdam(Optimizer):

View File

@ -100,7 +100,7 @@ class Optimizer(Cell):
if isinstance(loss_scale, int):
loss_scale = float(loss_scale)
validator.check_value_type("loss_scale", loss_scale, [float], self.cls_name)
validator.check_number_range("loss_scale", loss_scale, 0.0, float("inf"), Rel.INC_NEITHER, self.cls_name)
validator.check_positive_float(loss_scale, "loss_scale", self.cls_name)
self.loss_scale = loss_scale
weight_decay = self._preprocess_weight_decay(weight_decay)
@ -221,7 +221,7 @@ class Optimizer(Cell):
"""Check weight decay, and convert int to float."""
if isinstance(weight_decay, (float, int)):
weight_decay = float(weight_decay)
validator.check_number_range("weight_decay", weight_decay, 0.0, float("inf"), Rel.INC_LEFT, self.cls_name)
validator.check_non_negative_float(weight_decay, "weight_decay", self.cls_name)
return weight_decay
raise TypeError("Weight decay should be int or float.")
@ -229,7 +229,7 @@ class Optimizer(Cell):
"""Check lr value, and convert lr to a float, a Tensor or a LearningRateSchedule."""
if isinstance(learning_rate, (float, int)):
learning_rate = float(learning_rate)
validator.check_number_range("learning rate", learning_rate, 0.0, float("inf"), Rel.INC_LEFT, self.cls_name)
validator.check_non_negative_float(learning_rate, "learning rate", self.cls_name)
return learning_rate
if isinstance(learning_rate, Tensor) and learning_rate.dim() == 0:
return learning_rate

View File

@ -45,9 +45,9 @@ def _check_param_value(accum, l1, l2, use_locking, prim_name=None):
validator.check_value_type("l1", l1, [float], prim_name)
validator.check_value_type("l2", l2, [float], prim_name)
validator.check_value_type("use_locking", use_locking, [bool], prim_name)
validator.check_number_range("accum", accum, 0.0, float("inf"), Rel.INC_LEFT, prim_name)
validator.check_number_range("l1", l1, 0.0, float("inf"), Rel.INC_LEFT, prim_name)
validator.check_number_range("l2", l2, 0.0, float("inf"), Rel.INC_LEFT, prim_name)
validator.check_non_negative_float(accum, "accum", prim_name)
validator.check_non_negative_float(l1, "l1", prim_name)
validator.check_non_negative_float(l2, "l2", prim_name)
class ProximalAdagrad(Optimizer):

View File

@ -154,11 +154,11 @@ class RMSProp(Optimizer):
use_locking=False, centered=False, loss_scale=1.0, weight_decay=0.0):
super(RMSProp, self).__init__(learning_rate, params, weight_decay, loss_scale)
validator.check_value_type("decay", decay, [float], self.cls_name)
validator.check_number_range("decay", decay, 0.0, float("inf"), Rel.INC_LEFT, self.cls_name)
validator.check_non_negative_float(decay, "decay", self.cls_name)
validator.check_value_type("momentum", momentum, [float], self.cls_name)
validator.check_number_range("momentum", momentum, 0.0, float("inf"), Rel.INC_LEFT, self.cls_name)
validator.check_non_negative_float(momentum, "momentum", self.cls_name)
validator.check_value_type("epsilon", epsilon, [float], self.cls_name)
validator.check_number_range("epsilon", epsilon, 0.0, float("inf"), Rel.INC_NEITHER, self.cls_name)
validator.check_positive_float(epsilon, "epsilon", self.cls_name)
validator.check_value_type("use_locking", use_locking, [bool], self.cls_name)
validator.check_value_type("centered", centered, [bool], self.cls_name)

View File

@ -69,7 +69,7 @@ def get_concat_offset(x_shp, x_type, axis, prim_name):
validator.check_subclass("shape0", x_type[0], mstype.tensor, prim_name)
validator.check_positive_int(len(x_shp[0]), "len of x_shp[0]", prim_name)
rank_base = len(x_shp[0])
validator.check_int_range('axis', axis, -rank_base - 1, rank_base, Rel.INC_BOTH, prim_name)
validator.check_int_range(axis, -rank_base - 1, rank_base, Rel.INC_BOTH, 'axis', prim_name)
if axis < 0:
axis = axis + rank_base
all_shp = x_shp[0][axis]

View File

@ -188,7 +188,7 @@ class BatchNormGrad(PrimitiveWithInfer):
@prim_attr_register
def __init__(self, is_training=False, epsilon=1e-5):
self.is_training = validator.check_value_type('is_training', is_training, (bool,), self.name)
self.epsilon = validator.check_number_range('epsilon', epsilon, 0, 1, Rel.INC_RIGHT, self.name)
self.epsilon = validator.check_float_range(epsilon, 0, 1, Rel.INC_RIGHT, 'epsilon', self.name)
self.add_prim_attr('data_format', "NCHW")
def infer_shape(self, y_backprop_shape, x_shape, scale_shape, reserve_1_shape, reserve_2_shape):
@ -485,7 +485,7 @@ class DropoutGrad(PrimitiveWithInfer):
@prim_attr_register
def __init__(self, keep_prob=0.5):
self.keep_prob = validator.check_number_range("keep_prob", keep_prob, 0, 1, Rel.INC_RIGHT, self.name)
self.keep_prob = validator.check_float_range(keep_prob, 0, 1, Rel.INC_RIGHT, "keep_prob", self.name)
def infer_shape(self, dy_shape, mask_shape):
return dy_shape
@ -902,7 +902,7 @@ class LogSoftmaxGrad(PrimitiveWithInfer):
def infer_shape(self, dout, logits):
rank = len(logits)
validator.check_int_range('axis', self.axis, -rank - 1, rank, Rel.INC_BOTH, self.name)
validator.check_int_range(self.axis, -rank - 1, rank, Rel.INC_BOTH, 'axis', self.name)
return logits
def infer_dtype(self, dout, logits):
@ -921,7 +921,7 @@ class LSTMGradData(PrimitiveWithInfer):
self.has_bias = validator.check_value_type('has_bias', has_bias, (bool,), self.name)
self.bidirectional = validator.check_value_type('bidirectional', bidirectional, (bool,), self.name)
self.dropout = validator.check_value_type("dropout", dropout, [float], self.name)
self.dropout = validator.check_number_range('dropout', dropout, 0, 1, Rel.INC_BOTH, self.name)
self.dropout = validator.check_float_range(dropout, 0, 1, Rel.INC_BOTH, 'dropout', self.name)
if bidirectional:
self.num_directions = 2
@ -970,7 +970,7 @@ class LSTMGradWeight(PrimitiveWithInfer):
self.has_bias = validator.check_value_type('has_bias', has_bias, (bool,), self.name)
self.bidirectional = validator.check_value_type('bidirectional', bidirectional, (bool,), self.name)
self.dropout = validator.check_value_type("dropout", dropout, [float], self.name)
self.dropout = validator.check_number_range('dropout', dropout, 0, 1, Rel.INC_BOTH, self.name)
self.dropout = validator.check_float_range(dropout, 0, 1, Rel.INC_BOTH, 'dropout', self.name)
if bidirectional:
self.num_directions = 2
@ -1005,7 +1005,7 @@ class LSTMGrad(PrimitiveWithInfer):
self.has_bias = validator.check_value_type('has_bias', has_bias, (bool,), self.name)
self.bidirectional = validator.check_value_type('bidirectional', bidirectional, (bool,), self.name)
self.dropout = validator.check_value_type("dropout", dropout, [float], self.name)
self.dropout = validator.check_number_range('dropout', dropout, 0, 1, Rel.INC_BOTH, self.name)
self.dropout = validator.check_float_range(dropout, 0, 1, Rel.INC_BOTH, 'dropout', self.name)
if bidirectional:
self.num_directions = 2
@ -1652,7 +1652,7 @@ class BasicLSTMCellInputGrad(PrimitiveWithInfer):
@prim_attr_register
def __init__(self, keep_prob):
self.keep_prob = validator.check_value_type("keep_prob", keep_prob, [float], self.name)
self.keep_prob = validator.check_number_range("keep_prob", keep_prob, 0.0, 1.0, Rel.INC_BOTH, self.name)
self.keep_prob = validator.check_float_range(keep_prob, 0.0, 1.0, Rel.INC_BOTH, "keep_prob", self.name)
self.add_prim_attr("io_format", "ND")
def infer_shape(self, dgate_shape, w_shape):

View File

@ -76,8 +76,7 @@ class MinMaxUpdatePerLayer(PrimitiveWithInfer):
f"For '{self.name}' attr \'ema\' and \'ema_decay\' should set together.")
self.ema = validator.check_value_type('ema', ema, (bool,), self.name)
self.ema_decay = validator.check_number_range(
'ema_decay', ema_decay, 0, 1, Rel.INC_BOTH, self.name)
self.ema_decay = validator.check_float_range(ema_decay, 0, 1, Rel.INC_BOTH, 'ema_decay', self.name)
self.init_prim_io_names(inputs=['x', 'min', 'max'],
outputs=['min_up', 'max_up'])
@ -136,10 +135,9 @@ class MinMaxUpdatePerChannel(PrimitiveWithInfer):
f"For '{self.name}' attr \'ema\' and \'ema_decay\' should set together.")
self.ema = validator.check_value_type('ema', ema, (bool,), self.name)
self.ema_decay = validator.check_number_range(
'ema_decay', ema_decay, 0, 1, Rel.INC_BOTH, self.name)
self.ema_decay = validator.check_float_range(ema_decay, 0, 1, Rel.INC_BOTH, 'ema_decay', self.name)
if self.is_ascend:
self.channel_axis = validator.check_int_range('channel_axis', channel_axis, 0, 1, Rel.INC_BOTH, self.name)
self.channel_axis = validator.check_int_range(channel_axis, 0, 1, Rel.INC_BOTH, 'channel_axis', self.name)
else:
self.channel_axis = validator.check_non_negative_int(channel_axis, 'channel_axis', self.name)
self.init_prim_io_names(
@ -222,10 +220,8 @@ class FakeQuantPerLayer(PrimitiveWithInfer):
'symmetric', symmetric, (bool,), self.name)
self.narrow_range = validator.check_value_type(
'narrow_range', narrow_range, (bool,), self.name)
self.training = validator.check_value_type(
'training', training, (bool,), self.name)
self.ema_decay = validator.check_number_range(
'ema_decay', ema_decay, 0, 1, Rel.INC_BOTH, self.name)
self.training = validator.check_value_type('training', training, (bool,), self.name)
self.ema_decay = validator.check_float_range(ema_decay, 0, 1, Rel.INC_BOTH, 'ema_decay', self.name)
self.num_bits = validator.check_positive_int(num_bits, 'num_bits', self.name)
self.quant_delay = validator.check_non_negative_int(quant_delay, 'quant_delay', self.name)
self.init_prim_io_names(inputs=['x', 'min', 'max'],
@ -366,12 +362,11 @@ class FakeQuantPerChannel(PrimitiveWithInfer):
'narrow_range', narrow_range, (bool,), self.name)
self.training = validator.check_value_type(
'training', training, (bool,), self.name)
self.ema_decay = validator.check_number_range(
'ema_decay', ema_decay, 0, 1, Rel.INC_BOTH, self.name)
self.ema_decay = validator.check_float_range(ema_decay, 0, 1, Rel.INC_BOTH, 'ema_decay', self.name)
self.num_bits = validator.check_positive_int(num_bits, 'num_bits', self.name)
self.quant_delay = validator.check_non_negative_int(quant_delay, 'quant_delay', self.name)
if self.is_ascend:
self.channel_axis = validator.check_int_range('channel_axis', channel_axis, 0, 1, Rel.INC_BOTH, self.name)
self.channel_axis = validator.check_int_range(channel_axis, 0, 1, Rel.INC_BOTH, 'channel_axis', self.name)
else:
self.channel_axis = validator.check_non_negative_int(channel_axis, 'channel_axis', self.name)
self.init_prim_io_names(inputs=['x', 'min', 'max'], outputs=['out'])
@ -495,7 +490,7 @@ class BatchNormFold(PrimitiveWithInfer):
@prim_attr_register
def __init__(self, momentum=0.9, epsilon=1e-5, is_training=True, freeze_bn=0):
"""Initialize batch norm fold layer"""
self.momentum = validator.check_number_range('momentum', momentum, 0, 1, Rel.INC_BOTH, self.name)
self.momentum = validator.check_float_range(momentum, 0, 1, Rel.INC_BOTH, 'momentum', self.name)
self.epsilon = validator.check_positive_float(epsilon, 'epsilon', self.name)
self.is_training = validator.check_value_type('is_training', is_training, (bool,), self.name)
self.freeze_bn = validator.check_value_type('freeze_bn', freeze_bn, (int,), self.name)
@ -806,7 +801,7 @@ class BatchNormFoldD(PrimitiveWithInfer):
def __init__(self, momentum=0.9, epsilon=1e-5, is_training=True, freeze_bn=0):
"""Initialize _BatchNormFold layer"""
from mindspore.ops._op_impl._custom_op import batchnorm_fold
self.momentum = validator.check_number_range('momentum', momentum, 0, 1, Rel.INC_BOTH, self.name)
self.momentum = validator.check_float_range(momentum, 0, 1, Rel.INC_BOTH, 'momentum', self.name)
self.epsilon = validator.check_positive_float(epsilon, 'epsilon', self.name)
self.is_training = validator.check_value_type('is_training', is_training, (bool,), self.name)
self.freeze_bn = validator.check_value_type('freeze_bn', freeze_bn, (int,), self.name)

View File

@ -129,7 +129,7 @@ class ExpandDims(PrimitiveWithInfer):
x_shape = list(x['shape'])
axis_v = axis['value']
rank = len(x_shape)
validator.check_int_range('axis', axis_v, -rank - 1, rank, Rel.INC_BOTH, self.name)
validator.check_int_range(axis_v, -rank - 1, rank, Rel.INC_BOTH, 'axis', self.name)
value = None
if x['value'] is not None:
value = x['value'].asnumpy()
@ -534,7 +534,7 @@ class Squeeze(PrimitiveWithInfer):
ret = [d for d in x_shape if d != 1]
else:
for a in axis:
validator.check_int_range('axis or its elements', a, -ndim, ndim - 1, Rel.INC_BOTH, self.name)
validator.check_int_range(a, -ndim, ndim - 1, Rel.INC_BOTH, 'axis or its elements', self.name)
if x_shape[a] != 1:
raise ValueError('Cannot select an axis to squeeze out which has size not equal to one.')
ret = [x_shape[i] for i in range(ndim) if not (i in axis or (i - ndim) in axis)]
@ -658,7 +658,7 @@ class GatherV2(PrimitiveWithCheck):
axis_v = axis['value']
params_shp = params['shape']
rank = len(params_shp)
validator.check_int_range("axis", axis_v, -rank, rank, Rel.INC_LEFT, self.name)
validator.check_int_range(axis_v, -rank, rank, Rel.INC_LEFT, "axis", self.name)
if axis_v < 0:
axis_v += rank
@ -777,7 +777,7 @@ class Split(PrimitiveWithInfer):
validator.check_subclass("x", x['dtype'], mstype.tensor, self.name)
x_shape = list(x['shape'])
dim = len(x_shape)
validator.check_int_range('axis value', self.axis, -dim, dim, Rel.INC_LEFT, self.name)
validator.check_int_range(self.axis, -dim, dim, Rel.INC_LEFT, 'axis value', self.name)
validator.check_positive_int(self.output_num, "output_num", self.name)
output_valid_check = x_shape[self.axis] % self.output_num
if output_valid_check != 0:
@ -1224,7 +1224,7 @@ class Argmax(PrimitiveWithInfer):
if axis is None:
axis = 0
x_rank = len(x_shape)
validator.check_int_range("axis", axis, -x_rank, x_rank, Rel.INC_LEFT, self.name)
validator.check_int_range(axis, -x_rank, x_rank, Rel.INC_LEFT, "axis", self.name)
axis = axis + x_rank if axis < 0 else axis
ouput_shape = [x_shape[i] for i in range(x_rank) if i != axis]
return ouput_shape
@ -1272,7 +1272,7 @@ class Argmin(PrimitiveWithInfer):
if axis is None:
axis = 0
x_rank = len(x_shape)
validator.check_int_range("axis", axis, -x_rank, x_rank, Rel.INC_LEFT, self.name)
validator.check_int_range(axis, -x_rank, x_rank, Rel.INC_LEFT, "axis", self.name)
axis = axis + x_rank if axis < 0 else axis
ouput_shape = [x_shape[i] for i in range(x_rank) if i != axis]
return ouput_shape
@ -1325,7 +1325,7 @@ class ArgMaxWithValue(PrimitiveWithInfer):
def infer_shape(self, x_shape):
axis = self.axis
x_rank = len(x_shape)
validator.check_int_range("axis", axis, -x_rank, x_rank, Rel.INC_LEFT, self.name)
validator.check_int_range(axis, -x_rank, x_rank, Rel.INC_LEFT, "axis", self.name)
ouput_shape = _infer_shape_reduce(x_shape, self.axis, self.keep_dims, self.name)
return ouput_shape, ouput_shape
@ -1377,7 +1377,7 @@ class ArgMinWithValue(PrimitiveWithInfer):
def infer_shape(self, x_shape):
axis = self.axis
x_rank = len(x_shape)
validator.check_int_range("axis", axis, -x_rank, x_rank, Rel.INC_LEFT, self.name)
validator.check_int_range(axis, -x_rank, x_rank, Rel.INC_LEFT, "axis", self.name)
ouput_shape = _infer_shape_reduce(x_shape, self.axis, self.keep_dims, self.name)
return ouput_shape, ouput_shape
@ -1760,7 +1760,7 @@ def _get_pack_shape(x_shape, x_type, axis, prim_name):
rank_base = len(x_shape[0])
N = len(x_shape)
out_shape = x_shape[0]
validator.check_int_range('axis', axis, -rank_base - 1, rank_base, Rel.INC_BOTH, prim_name)
validator.check_int_range(axis, -rank_base - 1, rank_base, Rel.INC_BOTH, 'axis', prim_name)
if axis < 0:
axis = axis + rank_base + 1
for i in range(1, N):
@ -1863,7 +1863,7 @@ class Unpack(PrimitiveWithInfer):
validator.check_subclass("x", x['dtype'], mstype.tensor, self.name)
x_shape = list(x['shape'])
dim = len(x_shape)
validator.check_int_range('axis value', self.axis, -dim, dim, Rel.INC_LEFT, self.name)
validator.check_int_range(self.axis, -dim, dim, Rel.INC_LEFT, 'axis value', self.name)
if self.axis < 0:
self.axis = self.axis + dim
output_num = x_shape[self.axis]
@ -1965,7 +1965,7 @@ class ReverseV2(PrimitiveWithInfer):
def infer_shape(self, x_shape):
dim = len(x_shape)
for i, each in enumerate(self.axis):
validator.check_int_range(f'axis[{i}]', each, -dim, dim, Rel.INC_LEFT, self.name)
validator.check_int_range(each, -dim, dim, Rel.INC_LEFT, f'axis[{i}]', self.name)
return x_shape
def infer_dtype(self, x_dtype):

View File

@ -206,7 +206,7 @@ class _HostAllGather(PrimitiveWithInfer):
validator.check_value_type('group', group, (tuple, list), self.name)
validator.check_integer("group size", len(group), 2, Rel.GE, self.name)
for r in group:
validator.check_int_range("rank_id", r, 0, 7, Rel.INC_BOTH, self.name)
validator.check_int_range(r, 0, 7, Rel.INC_BOTH, "rank_id", self.name)
validator.check_value_type("rank_id", r, (int,), self.name)
self.group_size = len(group)
self.add_prim_attr('group', group)
@ -315,7 +315,7 @@ class _HostReduceScatter(PrimitiveWithInfer):
validator.check_value_type('group', group, (tuple, list), self.name)
validator.check_integer("group size", len(group), 2, Rel.GE, self.name)
for r in group:
validator.check_int_range("rank_id", r, 0, 7, Rel.INC_BOTH, self.name)
validator.check_int_range(r, 0, 7, Rel.INC_BOTH, "rank_id", self.name)
validator.check_value_type("rank_id", r, (int,), self.name)
self.op = op
self.group_size = len(group)

View File

@ -70,8 +70,7 @@ class ControlDepend(Primitive):
@prim_attr_register
def __init__(self, depend_mode=0):
"""init"""
validator.check_int_range(
"depend_mode", depend_mode, 0, 1, Rel.INC_BOTH, self.name)
validator.check_int_range(depend_mode, 0, 1, Rel.INC_BOTH, "depend_mode", self.name)
def __call__(self, src, dst):
return src

View File

@ -31,7 +31,7 @@ def _infer_shape_reduce(x, axis, keep_dims, prim_name):
"""Common infer for reduce operator"""
def reduce_one_axis(one_axis):
validator.check_int_range('axis', one_axis, -dim, dim, Rel.INC_LEFT, prim_name)
validator.check_int_range(one_axis, -dim, dim, Rel.INC_LEFT, 'axis', prim_name)
if one_axis < 0:
one_axis += dim
axis_reduce.add(one_axis)

View File

@ -149,7 +149,7 @@ class Softmax(PrimitiveWithInfer):
validator.check_integer("length of axis", len(self.axis), 1, Rel.GE, self.name)
rank = len(logits)
for axis_v in self.axis:
validator.check_int_range("axis", axis_v, -rank, rank, Rel.INC_LEFT, self.name)
validator.check_int_range(axis_v, -rank, rank, Rel.INC_LEFT, "axis", self.name)
return logits
def infer_dtype(self, logits):
@ -193,7 +193,7 @@ class LogSoftmax(PrimitiveWithInfer):
def infer_shape(self, logits):
rank = len(logits)
validator.check_int_range('axis', self.axis, -rank, rank, Rel.INC_LEFT, self.name)
validator.check_int_range(self.axis, -rank, rank, Rel.INC_LEFT, 'axis', self.name)
return logits
def infer_dtype(self, logits):
@ -637,8 +637,8 @@ class FusedBatchNorm(Primitive):
self.init_prim_io_names(inputs=['x', 'scale', 'b', 'mean', 'variance'],
outputs=['y', 'running_mean', 'running_variance', 'save_mean', 'save_inv_variance'])
self.mode = validator.check_integer('mode', mode, [0, 1], Rel.IN, self.name)
self.epsilon = validator.check_number_range('epsilon', epsilon, 0, 1, Rel.INC_RIGHT, self.name)
self.momentum = validator.check_number_range('momentum', momentum, 0, 1, Rel.INC_BOTH, self.name)
self.epsilon = validator.check_float_range(epsilon, 0, 1, Rel.INC_RIGHT, 'epsilon', self.name)
self.momentum = validator.check_float_range(momentum, 0, 1, Rel.INC_BOTH, 'momentum', self.name)
self._update_parameter = True
@ -710,8 +710,8 @@ class FusedBatchNormEx(PrimitiveWithInfer):
self.init_prim_io_names(inputs=['x', 'scale', 'b', 'mean', 'variance'],
outputs=['y', 'save_scale', 'save_bias', 'save_mean', 'save_inv_variance', 'reserve'])
self.mode = validator.check_integer('mode', mode, [0, 1], Rel.IN, self.name)
self.epsilon = validator.check_number_range('epsilon', epsilon, 0, 1, Rel.INC_RIGHT, self.name)
self.momentum = validator.check_number_range('momentum', momentum, 0, 1, Rel.INC_BOTH, self.name)
self.epsilon = validator.check_float_range(epsilon, 0, 1, Rel.INC_RIGHT, 'epsilon', self.name)
self.momentum = validator.check_float_range(momentum, 0, 1, Rel.INC_BOTH, 'momentum', self.name)
self._update_parameter = True
self.add_prim_attr('data_format', "NCHW")
@ -818,8 +818,8 @@ class BNTrainingUpdate(PrimitiveWithInfer):
validator.check_value_type("isRef", isRef, [bool], self.name)
validator.check_value_type("epsilon", epsilon, [float], self.name)
validator.check_value_type("factor", factor, [float], self.name)
self.epsilon = validator.check_number_range('epsilon', epsilon, 0, 1, Rel.INC_RIGHT, 'BNTrainingUpdate')
self.factor = validator.check_number_range('factor', factor, 0, 1, Rel.INC_BOTH, 'BNTrainingUpdate')
self.epsilon = validator.check_float_range(epsilon, 0, 1, Rel.INC_RIGHT, 'epsilon', 'BNTrainingUpdate')
self.factor = validator.check_float_range(factor, 0, 1, Rel.INC_BOTH, 'factor', 'BNTrainingUpdate')
def infer_shape(self, x, sum, square_sum, scale, b, mean, variance):
validator.check_integer("x rank", len(x), 4, Rel.EQ, self.name)
@ -898,7 +898,7 @@ class BatchNorm(PrimitiveWithInfer):
@prim_attr_register
def __init__(self, is_training=False, epsilon=1e-5):
validator.check_value_type('is_training', is_training, (bool,), self.name)
validator.check_number_range('epsilon', epsilon, 0, 1, Rel.INC_RIGHT, self.name)
validator.check_float_range(epsilon, 0, 1, Rel.INC_RIGHT, 'epsilon', self.name)
self.add_prim_attr('data_format', "NCHW")
self.init_prim_io_names(inputs=['x', 'scale', 'offset', 'mean', 'variance'],
outputs=['y', 'batch_mean', 'batch_variance', 'reserve_space_1', 'reserve_space_2'])
@ -2383,7 +2383,7 @@ class L2Normalize(PrimitiveWithInfer):
def infer_shape(self, input_x):
dim = len(input_x)
validator.check_int_range('axis value', self.axis, -dim, dim, Rel.INC_LEFT, self.name)
validator.check_int_range(self.axis, -dim, dim, Rel.INC_LEFT, 'axis value', self.name)
return input_x
def infer_dtype(self, input_x):
@ -2481,10 +2481,10 @@ class DropoutDoMask(PrimitiveWithInfer):
keep_prob_v = keep_prob['value']
if keep_prob_v is not None:
if isinstance(keep_prob['dtype'], type(mstype.tensor)):
validator.check_number_range('keep_prob', keep_prob_v.asnumpy(), 0, 1, Rel.INC_BOTH, self.name)
validator.check_float_range(keep_prob_v.asnumpy(), 0, 1, Rel.INC_BOTH, 'keep_prob', self.name)
else:
validator.check_value_type("keep_prob", keep_prob_v, [float], self.name)
validator.check_number_range('keep_prob', keep_prob_v, 0, 1, Rel.INC_BOTH, self.name)
validator.check_float_range(keep_prob_v, 0, 1, Rel.INC_BOTH, 'keep_prob', self.name)
out = {'shape': input_x_shape,
'dtype': input_x['dtype'],
@ -2584,7 +2584,7 @@ class OneHot(PrimitiveWithInfer):
# check shape
indices_shp = indices['shape']
validator.check_int_range("axis", self.axis, -1, len(indices_shp), Rel.INC_BOTH, self.name)
validator.check_int_range(self.axis, -1, len(indices_shp), Rel.INC_BOTH, "axis", self.name)
depth_val = depth['value']
validator.check_non_negative_int(depth_val, "depth", self.name)
# create new dimension at end if self.axis is -1
@ -2771,7 +2771,7 @@ class LSTM(PrimitiveWithInfer):
self.has_bias = validator.check_value_type("has_bias", has_bias, (bool,), self.name)
self.bidirectional = validator.check_value_type("bidirectional", bidirectional, (bool,), self.name)
self.dropout = validator.check_value_type("dropout", dropout, [float], self.name)
self.dropout = validator.check_number_range('dropout', dropout, 0, 1, Rel.INC_BOTH, self.name)
self.dropout = validator.check_float_range(dropout, 0, 1, Rel.INC_BOTH, 'dropout', self.name)
if bidirectional:
self.num_directions = 2
@ -3054,7 +3054,7 @@ class ROIAlign(PrimitiveWithInfer):
validator.check_value_type("spatial_scale", spatial_scale, [float], self.name)
validator.check_value_type("sample_num", sample_num, [int], self.name)
validator.check_value_type("roi_end_mode", roi_end_mode, [int], self.name)
validator.check_int_range("roi_end_mode", roi_end_mode, 0, 1, Rel.INC_BOTH, self.name)
validator.check_int_range(roi_end_mode, 0, 1, Rel.INC_BOTH, "roi_end_mode", self.name)
self.pooled_height = pooled_height
self.pooled_width = pooled_width
self.spatial_scale = spatial_scale
@ -3502,9 +3502,9 @@ class FusedSparseFtrl(PrimitiveWithInfer):
validator.check_value_type("l1", l1, [float], self.name)
validator.check_value_type("l2", l2, [float], self.name)
validator.check_value_type("lr_power", lr_power, [float], self.name)
self.lr = validator.check_number_range("lr", lr, 0.0, float("inf"), Rel.INC_NEITHER, self.name)
self.l1 = validator.check_number_range("l1", l1, 0.0, float("inf"), Rel.INC_LEFT, self.name)
self.l2 = validator.check_number_range("l2", l2, 0.0, float("inf"), Rel.INC_LEFT, self.name)
self.lr = validator.check_positive_float(lr, "lr", self.name)
self.l1 = validator.check_non_negative_float(l1, "l1", self.name)
self.l2 = validator.check_non_negative_float(l2, "l2", self.name)
self.lr_power = validator.check_number("lr_power", lr_power, 0, Rel.LE, self.name)
self.use_locking = validator.check_value_type("use_locking", use_locking, [bool], self.name)
@ -4240,7 +4240,7 @@ class SparseApplyAdagrad(PrimitiveWithInfer):
@prim_attr_register
def __init__(self, lr, update_slots=True, use_locking=False):
validator.check_value_type("lr", lr, [float], self.name)
validator.check_number_range("lr", lr, float("-inf"), float("inf"), Rel.INC_NEITHER, self.name)
validator.check_is_float(lr, "lr", self.name)
validator.check_value_type("update_slots", update_slots, [bool], self.name)
validator.check_value_type("use_locking", use_locking, [bool], self.name)
@ -5142,9 +5142,9 @@ class SparseApplyFtrl(PrimitiveWithCheck):
validator.check_value_type("l1", l1, [float], self.name)
validator.check_value_type("l2", l2, [float], self.name)
validator.check_value_type("lr_power", lr_power, [float], self.name)
self.lr = validator.check_number_range("lr", lr, 0.0, float("inf"), Rel.INC_NEITHER, self.name)
self.l1 = validator.check_number_range("l1", l1, 0.0, float("inf"), Rel.INC_LEFT, self.name)
self.l2 = validator.check_number_range("l2", l2, 0.0, float("inf"), Rel.INC_LEFT, self.name)
self.lr = validator.check_positive_float(lr, "lr", self.name)
self.l1 = validator.check_non_negative_float(l1, "l1", self.name)
self.l2 = validator.check_non_negative_float(l2, "l2", self.name)
self.lr_power = validator.check_number("lr_power", lr_power, 0, Rel.LE, self.name)
self.use_locking = validator.check_value_type("use_locking", use_locking, [bool], self.name)
self.init_prim_io_names(inputs=['var', 'accum', 'linear', 'grad', 'indices'],
@ -5239,9 +5239,9 @@ class SparseApplyFtrlV2(PrimitiveWithInfer):
validator.check_value_type("l1", l1, [float], self.name)
validator.check_value_type("l2", l2, [float], self.name)
validator.check_value_type("lr_power", lr_power, [float], self.name)
self.lr = validator.check_number_range("lr", lr, 0.0, float("inf"), Rel.INC_NEITHER, self.name)
self.l1 = validator.check_number_range("l1", l1, 0.0, float("inf"), Rel.INC_LEFT, self.name)
self.l2 = validator.check_number_range("l2", l2, 0.0, float("inf"), Rel.INC_LEFT, self.name)
self.lr = validator.check_positive_float(lr, "lr", self.name)
self.l1 = validator.check_non_negative_float(l1, "l1", self.name)
self.l2 = validator.check_non_negative_float(l2, "l2", self.name)
self.lr_power = validator.check_number("lr_power", lr_power, 0, Rel.LE, self.name)
self.l2_shrinkage = validator.check_value_type("l2_shrinkage", l2_shrinkage, [float], self.name)
self.use_locking = validator.check_value_type("use_locking", use_locking, [bool], self.name)
@ -5285,7 +5285,7 @@ class Dropout(PrimitiveWithInfer):
@prim_attr_register
def __init__(self, keep_prob=0.5):
self.keep_prob = validator.check_number_range("keep_prob", keep_prob, 0, 1, Rel.INC_RIGHT, self.name)
self.keep_prob = validator.check_float_range(keep_prob, 0, 1, Rel.INC_RIGHT, "keep_prob", self.name)
def infer_shape(self, x_shape):
validator.check_integer("x_shape", len(x_shape), 1, Rel.GE, self.name)
@ -5510,7 +5510,7 @@ class BasicLSTMCell(PrimitiveWithInfer):
@prim_attr_register
def __init__(self, keep_prob=1.0, forget_bias=1.0, state_is_tuple=True, activation='tanh'):
self.keep_prob = validator.check_value_type("keep_prob", keep_prob, [float], self.name)
self.keep_prob = validator.check_number_range("keep_prob", keep_prob, 0.0, 1.0, Rel.INC_BOTH, self.name)
self.keep_prob = validator.check_float_range(keep_prob, 0.0, 1.0, Rel.INC_BOTH, "keep_prob", self.name)
self.forget_bias = validator.check_value_type("forget_bias", forget_bias, [float], self.name)
self.state_is_tuple = validator.check_value_type("state_is_tuple", state_is_tuple, [bool], self.name)
self.activation = validator.check_string(activation, ['tanh'], "activation", self.name)

View File

@ -100,10 +100,8 @@ def _generate_cosine_lr(lr_init, lr_end, lr_max, total_steps, warmup_steps):
lr_inc = (float(lr_max) - float(lr_init)) / float(warmup_steps)
lr = float(lr_init) + lr_inc * (i + 1)
else:
linear_decay = (total_steps - i) / decay_steps
cosine_decay = 0.5 * (1 + math.cos(math.pi * 2 * 0.47 * i / decay_steps))
decayed = linear_decay * cosine_decay + 0.00001
lr = lr_max * decayed
cosine_decay = 0.5 * (1 + math.cos(math.pi * (i - warmup_steps) / decay_steps))
lr = (lr_max - lr_end) * cosine_decay + lr_end
lr_each_step.append(lr)
return lr_each_step

View File

@ -122,7 +122,7 @@ class MySparseGatherV2(PrimitiveWithInfer):
axis_v = axis['value']
params_shp = params['shape']
rank = len(params_shp)
validator.check_int_range("axis", axis_v, -rank, rank, Rel.INC_LEFT, self.name)
validator.check_int_range(axis_v, -rank, rank, Rel.INC_LEFT, "axis", self.name)
if axis_v < 0:
axis_v += rank
out_shape = params_shp[:axis_v] + indices['shape'] + params_shp[axis_v + 1:]
@ -208,10 +208,10 @@ def _check_param_value(beta1, beta2, eps, weight_decay, prim_name):
validator.check_value_type("beta2", beta2, [float], prim_name)
validator.check_value_type("eps", eps, [float], prim_name)
validator.check_value_type("weight_dacay", weight_decay, [float], prim_name)
validator.check_number_range("beta1", beta1, 0.0, 1.0, Rel.INC_NEITHER, prim_name)
validator.check_number_range("beta2", beta2, 0.0, 1.0, Rel.INC_NEITHER, prim_name)
validator.check_number_range("eps", eps, 0.0, float("inf"), Rel.INC_NEITHER, prim_name)
validator.check_number_range("weight_decay", weight_decay, 0.0, float("inf"), Rel.INC_LEFT, prim_name)
validator.check_float_range(beta1, 0.0, 1.0, Rel.INC_NEITHER, "beta1", prim_name)
validator.check_float_range(beta2, 0.0, 1.0, Rel.INC_NEITHER, "beta2", prim_name)
validator.check_positive_float(eps, "eps", prim_name)
validator.check_non_negative_float(weight_decay, "weight_decay", prim_name)
class AdamWeightDecaySparse(Optimizer):

View File

@ -14,55 +14,97 @@
# ============================================================================
""" test checkparameter """
import pytest
from mindspore._checkparam import check_int, check_input_format, Validator, twice
import numpy as np
from mindspore._checkparam import check_input_format, Validator, twice, Rel
kernel_size = 5
kernel_size1 = twice(kernel_size)
assert kernel_size1 == (5, 5)
def test_check_integer1():
with pytest.raises(TypeError):
Validator.check_integer("input", 0, Rel.GE, "number")
def test_check_int_1():
assert check_int(3) == 3
def check_int_positive_1():
def test_check_integer2():
with pytest.raises(ValueError):
Validator.check_positive_int(-1)
Validator.check_integer(-1, 0, Rel.GE, "number")
def test_check_integer3():
input = np.random.randint(0, 100)
assert Validator.check_integer(input, 0, Rel.GE, "number") == input
def test_NCHW1():
assert check_input_format("NCHW") == "NCHW"
def test_check_int1():
input = np.random.randint(-100, 100)
assert Validator.check_is_int(input) == input
def test_check_int2():
with pytest.raises(TypeError):
Validator.check_is_int(3.3)
def test_NCHW3():
def test_check_int3():
with pytest.raises(TypeError):
Validator.check_is_int("str")
def test_check_int4():
with pytest.raises(TypeError):
Validator.check_is_int(True)
def test_check_is_int5():
with pytest.raises(TypeError):
Validator.check_is_int(True)
with pytest.raises(TypeError):
Validator.check_is_int(False)
def test_check_positive_int1():
input = np.random.randint(0, 100)
assert Validator.check_positive_int(input) == input
def test_check_positive_int2():
input = np.random.randint(-100, 0)
with pytest.raises(ValueError):
check_input_format("rt")
Validator.check_positive_int(input)
def test_check_positive_int3():
with pytest.raises(ValueError):
Validator.check_positive_int(3.3)
def test_check_int_2():
def test_check_positive_int4():
with pytest.raises(TypeError):
check_int(3.3)
Validator.check_positive_int("str")
def test_check_negative_int1():
input = np.random.randint(-100, -1)
assert Validator.check_negative_int(input) == input
def test_check_int_3():
def test_check_negative_int2():
input = np.random.randint(0, 100)
with pytest.raises(ValueError):
Validator.check_negative_int(input)
def test_check_negative_int3():
with pytest.raises(ValueError):
Validator.check_negative_int(3.3)
def test_check_negative_int4():
with pytest.raises(TypeError):
check_int("str")
Validator.check_negative_int("str")
def test_check_non_positive_int1():
input = np.random.randint(-100, 0)
assert Validator.check_non_positive_int(input) == input
def test_check_int_4():
def test_check_non_positive_int2():
input = np.random.randint(1, 100)
with pytest.raises(ValueError):
Validator.check_non_positive_int(input)
def test_check_non_positive_int3():
with pytest.raises(ValueError):
Validator.check_non_positive_int(3.3)
def test_check_non_positive_int4():
with pytest.raises(TypeError):
check_int(True)
def test_check_int_5():
check_int(0)
check_int(1)
with pytest.raises(TypeError):
check_int(True)
with pytest.raises(TypeError):
check_int(False)
Validator.check_non_positive_int("str")
def test_check_bool_1():
assert Validator.check_bool(True)