forked from mindspore-Ecosystem/mindspore
!7311 [ME] change _check_parameter format
Merge pull request !7311 from chenzhongming/zomi_master
This commit is contained in:
commit
c68f36c81e
|
@ -94,10 +94,10 @@ rel_strs = {
|
|||
|
||||
def check_number(arg_value, value, rel, arg_type=int, arg_name=None, prim_name=None):
|
||||
"""
|
||||
Check argument integer.
|
||||
Check argument integer.
|
||||
|
||||
Usage:
|
||||
- number = check_integer(number, 0, Rel.GE, "number", None) # number >= 0
|
||||
Usage:
|
||||
- number = check_integer(number, 0, Rel.GE, "number", None) # number >= 0
|
||||
"""
|
||||
rel_fn = Rel.get_fns(rel)
|
||||
type_mismatch = not isinstance(arg_value, arg_type) or isinstance(arg_value, bool)
|
||||
|
@ -122,13 +122,33 @@ def check_is_number(arg_value, arg_type, arg_name=None, prim_name=None):
|
|||
"""
|
||||
prim_name = f'in \'{prim_name}\'' if prim_name else ''
|
||||
arg_name = f'\'{prim_name}\'' if arg_name else 'Input value'
|
||||
if isinstance(arg_value, arg_type):
|
||||
if isinstance(arg_value, arg_type) and not isinstance(arg_value, bool):
|
||||
if math.isinf(arg_value) or math.isnan(arg_value):
|
||||
raise ValueError(f'{arg_name} {prim_name} must be legal float, but got `{arg_value}`.')
|
||||
return arg_value
|
||||
raise TypeError(f'{arg_name} {prim_name} must be float, but got `{type(arg_value).__name__}`')
|
||||
|
||||
|
||||
def check_number_range(arg_value, lower_limit, upper_limit, rel, value_type, arg_name=None, prim_name=None):
|
||||
"""
|
||||
Method for checking whether an int value is in some range.
|
||||
|
||||
Usage:
|
||||
- number = check_number_range(number, 0.0, 1.0, Rel.INC_NEITHER, "number", float) # number in [0.0, 1.0]
|
||||
- number = check_number_range(number, 0, 1, Rel.INC_NEITHER, "number", int) # number in [0, 1]
|
||||
"""
|
||||
prim_name = f'in `{prim_name}`' if prim_name else ''
|
||||
arg_name = f'`{arg_name}`' if arg_name else ''
|
||||
rel_fn = Rel.get_fns(rel)
|
||||
type_mismatch = not isinstance(arg_value, (np.ndarray, np.generic, value_type)) or isinstance(arg_value, bool)
|
||||
excp_cls = TypeError if type_mismatch else ValueError
|
||||
if type_mismatch or not rel_fn(arg_value, lower_limit, upper_limit):
|
||||
rel_str = Rel.get_strs(rel).format(lower_limit, upper_limit)
|
||||
raise excp_cls("{} {} should be in range of {}, but got {:.3f} with type {}.".format(
|
||||
arg_name, prim_name, rel_str, arg_value, type(arg_value).__name__))
|
||||
return arg_value
|
||||
|
||||
|
||||
class Validator:
|
||||
"""validator for checking input parameters"""
|
||||
|
||||
|
@ -147,16 +167,13 @@ class Validator:
|
|||
|
||||
@staticmethod
|
||||
def check_integer(arg_name, arg_value, value, rel, prim_name=None):
|
||||
"""Check argument is integer"""
|
||||
rel_fn = Rel.get_fns(rel)
|
||||
type_mismatch = not isinstance(arg_value, int) or isinstance(arg_value, bool)
|
||||
excp_cls = TypeError if type_mismatch else ValueError
|
||||
if type_mismatch or not rel_fn(arg_value, value):
|
||||
rel_str = Rel.get_strs(rel).format(value)
|
||||
msg_prefix = f'For \'{prim_name}\' the' if prim_name else "The"
|
||||
raise excp_cls(f'{msg_prefix} `{arg_name}` should be an int and must {rel_str}, but got `{arg_value}`'
|
||||
f' with type `{type(arg_value).__name__}`.')
|
||||
return arg_value
|
||||
"""
|
||||
Checks input integer value `arg_value` compare to `value`.
|
||||
|
||||
Usage:
|
||||
- number = check_integer(number, 0, Rel.GE, "number", None) # number >= 0
|
||||
"""
|
||||
return check_number(arg_value, value, rel, int, arg_name, prim_name)
|
||||
|
||||
@staticmethod
|
||||
def check_is_int(arg_value, arg_name=None, prim_name=None):
|
||||
|
@ -168,7 +185,7 @@ class Validator:
|
|||
- number = check_is_int(number, int, "bias")
|
||||
- number = check_is_int(number, int, "bias", "bias_class")
|
||||
"""
|
||||
check_is_number(arg_value, int, arg_name, prim_name)
|
||||
return check_is_number(arg_value, int, arg_name, prim_name)
|
||||
|
||||
@staticmethod
|
||||
def check_positive_int(arg_value, arg_name=None, prim_name=None):
|
||||
|
@ -214,6 +231,16 @@ class Validator:
|
|||
"""
|
||||
return check_number(arg_value, 0, Rel.GE, int, arg_name, prim_name)
|
||||
|
||||
@staticmethod
|
||||
def check_float(arg_value, value, rel, arg_name=None, prim_name=None):
|
||||
"""
|
||||
Checks input float value `arg_value` compare to `value`.
|
||||
|
||||
Usage:
|
||||
- number = check_float(number, 0.0, Rel.GE, "number", None) # number >= 0
|
||||
"""
|
||||
return check_number(arg_value, value, rel, float, arg_name, prim_name)
|
||||
|
||||
@staticmethod
|
||||
def check_is_float(arg_value, arg_name=None, prim_name=None):
|
||||
"""
|
||||
|
@ -224,7 +251,7 @@ class Validator:
|
|||
- number = check_is_float(number, int, "bias")
|
||||
- number = check_is_float(number, int, "bias", "bias_class")
|
||||
"""
|
||||
check_is_number(arg_value, float, arg_name, prim_name)
|
||||
return check_is_number(arg_value, float, arg_name, prim_name)
|
||||
|
||||
@staticmethod
|
||||
def check_positive_float(arg_value, arg_name=None, prim_name=None):
|
||||
|
@ -302,25 +329,26 @@ class Validator:
|
|||
return arg_value
|
||||
|
||||
@staticmethod
|
||||
def check_int_range(arg_name, arg_value, lower_limit, upper_limit, rel, prim_name):
|
||||
"""Method for checking whether an int value is in some range."""
|
||||
rel_fn = Rel.get_fns(rel)
|
||||
type_mismatch = not isinstance(arg_value, int) or isinstance(arg_value, bool)
|
||||
excp_cls = TypeError if type_mismatch else ValueError
|
||||
if type_mismatch or not rel_fn(arg_value, lower_limit, upper_limit):
|
||||
rel_str = Rel.get_strs(rel).format(lower_limit, upper_limit)
|
||||
raise excp_cls(f'For \'{prim_name}\' the `{arg_name}` should be an int in range {rel_str},'
|
||||
f' but got `{arg_value}` with type `{type(arg_value).__name__}`.')
|
||||
return arg_value
|
||||
def check_int_range(arg_value, lower_limit, upper_limit, rel, arg_name=None, prim_name=None):
|
||||
"""
|
||||
Method for checking whether input value is in int range.
|
||||
|
||||
Usage:
|
||||
- number = check_int_range(number, 0, 1, Rel.INC_NEITHER) # number in [0, 1]
|
||||
- number = check_int_range(number, 0, 1, Rel.INC_NEITHER, "number") # number in [0, 1]
|
||||
"""
|
||||
return check_number_range(arg_value, lower_limit, upper_limit, rel, int, arg_name, prim_name)
|
||||
|
||||
@staticmethod
|
||||
def check_number_range(arg_name, arg_value, lower_limit, upper_limit, rel, prim_name):
|
||||
"""Method for checking whether a numeric value is in some range."""
|
||||
rel_fn = Rel.get_fns(rel)
|
||||
if not rel_fn(arg_value, lower_limit, upper_limit):
|
||||
rel_str = Rel.get_strs(rel).format(lower_limit, upper_limit)
|
||||
raise ValueError(f'For \'{prim_name}\' the `{arg_name}` should be in range {rel_str}, but got {arg_value}.')
|
||||
return arg_value
|
||||
def check_float_range(arg_value, lower_limit, upper_limit, rel, arg_name=None, prim_name=None):
|
||||
"""
|
||||
Method for checking whether input value is in float range.
|
||||
|
||||
Usage:
|
||||
- number = check_float_range(number, 0.0, 1.0, Rel.INC_NEITHER) # number in [0.0, 1.0]
|
||||
- number = check_float_range(number, 0.0, 1.0, Rel.INC_NEITHER, "number") # number in [0.0, 1.0]
|
||||
"""
|
||||
return check_number_range(arg_value, lower_limit, upper_limit, rel, float, arg_name, prim_name)
|
||||
|
||||
@staticmethod
|
||||
def check_string(arg_value, valid_values, arg_name=None, prim_name=None):
|
||||
|
@ -502,13 +530,6 @@ class Validator:
|
|||
f'{tuple(exp_shape)}, but got {shape}.')
|
||||
|
||||
|
||||
def check_int(input_param):
|
||||
"""Int type judgment."""
|
||||
if isinstance(input_param, int) and not isinstance(input_param, bool):
|
||||
return input_param
|
||||
raise TypeError("Input type must be int!")
|
||||
|
||||
|
||||
def check_int_zero_one(input_param):
|
||||
"""Judge whether it is 0 or 1."""
|
||||
if input_param in (0, 1):
|
||||
|
|
|
@ -233,7 +233,7 @@ def cosine_decay_lr(min_lr, max_lr, total_step, step_per_epoch, decay_epoch):
|
|||
"""
|
||||
if not isinstance(min_lr, float):
|
||||
raise TypeError("min_lr must be float.")
|
||||
validator.check_number_range("min_lr", min_lr, 0.0, float("inf"), Rel.INC_LEFT, None)
|
||||
validator.check_non_negative_float(min_lr, "min_lr", None)
|
||||
validator.check_positive_float(max_lr, 'max_lr')
|
||||
validator.check_is_float(max_lr, 'max_lr')
|
||||
validator.check_positive_int(total_step, 'total_step')
|
||||
|
@ -303,7 +303,7 @@ def polynomial_decay_lr(learning_rate, end_learning_rate, total_step, step_per_e
|
|||
validator.check_is_float(learning_rate, 'learning_rate')
|
||||
if not isinstance(end_learning_rate, float):
|
||||
raise TypeError("end_learning_rate must be float.")
|
||||
validator.check_number_range("end_learning_rate", end_learning_rate, 0.0, float("inf"), Rel.INC_LEFT, None)
|
||||
validator.check_non_negative_float(end_learning_rate, "end_learning_rate", None)
|
||||
validator.check_positive_float(power, 'power')
|
||||
validator.check_is_float(power, 'power')
|
||||
validator.check_positive_int(total_step, 'total_step')
|
||||
|
@ -356,7 +356,7 @@ def warmup_lr(learning_rate, total_step, step_per_epoch, warmup_epoch):
|
|||
"""
|
||||
if not isinstance(learning_rate, float):
|
||||
raise TypeError("learning_rate must be float.")
|
||||
validator.check_number_range("learning_rate", learning_rate, 0.0, float("inf"), Rel.INC_LEFT, None)
|
||||
validator.check_non_negative_float(learning_rate, "learning_rate", None)
|
||||
validator.check_positive_int(warmup_epoch, 'warmup_epoch')
|
||||
validator.check_positive_int(total_step, 'total_step')
|
||||
validator.check_positive_int(step_per_epoch, 'step_per_epoch')
|
||||
|
|
|
@ -451,8 +451,7 @@ class CentralCrop(Cell):
|
|||
def __init__(self, central_fraction):
|
||||
super(CentralCrop, self).__init__()
|
||||
validator.check_value_type("central_fraction", central_fraction, [float], self.cls_name)
|
||||
self.central_fraction = validator.check_number_range('central_fraction', central_fraction,
|
||||
0.0, 1.0, Rel.INC_RIGHT, self.cls_name)
|
||||
self.central_fraction = validator.check_float_range(0.0, 1.0, Rel.INC_RIGHT, 'central_fraction', central_fraction, self.cls_name)
|
||||
self.slice = P.Slice()
|
||||
|
||||
def construct(self, image):
|
||||
|
|
|
@ -254,7 +254,7 @@ class CosineDecayLR(LearningRateSchedule):
|
|||
super(CosineDecayLR, self).__init__()
|
||||
if not isinstance(min_lr, float):
|
||||
raise TypeError("min_lr must be float.")
|
||||
validator.check_number_range("min_lr", min_lr, 0.0, float("inf"), Rel.INC_LEFT, self.cls_name)
|
||||
validator.check_non_negative_float(min_lr, "min_lr", self.cls_name)
|
||||
validator.check_positive_float(max_lr, 'max_lr', self.cls_name)
|
||||
validator.check_is_float(max_lr, 'max_lr', self.cls_name)
|
||||
validator.check_positive_int(decay_steps, "decay_steps", self.cls_name)
|
||||
|
@ -322,8 +322,7 @@ class PolynomialDecayLR(LearningRateSchedule):
|
|||
validator.check_is_float(learning_rate, 'learning_rate')
|
||||
if not isinstance(end_learning_rate, float):
|
||||
raise TypeError("end_learning_rate must be float.")
|
||||
validator.check_number_range("end_learning_rate", end_learning_rate, 0.0, float("inf"), Rel.INC_LEFT,
|
||||
self.cls_name)
|
||||
validator.check_non_negative_float(end_learning_rate, "end_learning_rate", self.cls_name)
|
||||
validator.check_positive_int(decay_steps, 'decay_steps', self.cls_name)
|
||||
validator.check_value_type('update_decay_steps', update_decay_steps, [bool], self.cls_name)
|
||||
validator.check_positive_float(power, 'power', self.cls_name)
|
||||
|
@ -387,7 +386,7 @@ class WarmUpLR(LearningRateSchedule):
|
|||
super(WarmUpLR, self).__init__()
|
||||
if not isinstance(learning_rate, float):
|
||||
raise TypeError("learning_rate must be float.")
|
||||
validator.check_number_range("learning_rate", learning_rate, 0.0, float("inf"), Rel.INC_LEFT, self.cls_name)
|
||||
validator.check_non_negative_float(learning_rate, "learning_rate", self.cls_name)
|
||||
validator.check_positive_int(warmup_steps, 'warmup_steps', self.cls_name)
|
||||
self.warmup_steps = warmup_steps
|
||||
self.learning_rate = learning_rate
|
||||
|
|
|
@ -368,7 +368,7 @@ class CosineEmbeddingLoss(_Loss):
|
|||
self.reduce_sum = P.ReduceSum()
|
||||
self.maximum = P.Maximum()
|
||||
validator.check_value_type("margin", margin, [float], self.cls_name)
|
||||
self.margin = validator.check_number_range("margin", margin, -1.0, 1.0, Rel.INC_BOTH, self.cls_name)
|
||||
self.margin = validator.check_float_range(margin, -1.0, 1.0, Rel.INC_BOTH, "margin", self.cls_name)
|
||||
|
||||
def construct(self, x1, x2, y):
|
||||
F.same_type_shape(x1, x2)
|
||||
|
|
|
@ -126,9 +126,9 @@ def _check_param_value(beta1, beta2, eps, prim_name):
|
|||
validator.check_value_type("beta1", beta1, [float], prim_name)
|
||||
validator.check_value_type("beta2", beta2, [float], prim_name)
|
||||
validator.check_value_type("eps", eps, [float], prim_name)
|
||||
validator.check_number_range("beta1", beta1, 0.0, 1.0, Rel.INC_NEITHER, prim_name)
|
||||
validator.check_number_range("beta2", beta2, 0.0, 1.0, Rel.INC_NEITHER, prim_name)
|
||||
validator.check_number_range("eps", eps, 0.0, float("inf"), Rel.INC_NEITHER, prim_name)
|
||||
validator.check_float_range(beta1, 0.0, 1.0, Rel.INC_NEITHER, "beta1", prim_name)
|
||||
validator.check_float_range(beta2, 0.0, 1.0, Rel.INC_NEITHER, "beta2", prim_name)
|
||||
validator.check_positive_float(eps, "eps", prim_name)
|
||||
|
||||
|
||||
class Adam(Optimizer):
|
||||
|
|
|
@ -177,9 +177,9 @@ def _check_param_value(beta1, beta2, eps, prim_name):
|
|||
validator.check_value_type("beta1", beta1, [float], prim_name)
|
||||
validator.check_value_type("beta2", beta2, [float], prim_name)
|
||||
validator.check_value_type("eps", eps, [float], prim_name)
|
||||
validator.check_number_range("beta1", beta1, 0.0, 1.0, Rel.INC_NEITHER, prim_name)
|
||||
validator.check_number_range("beta2", beta2, 0.0, 1.0, Rel.INC_NEITHER, prim_name)
|
||||
validator.check_number_range("eps", eps, 0.0, float("inf"), Rel.INC_NEITHER, prim_name)
|
||||
validator.check_float_range(beta1, 0.0, 1.0, Rel.INC_NEITHER, "beta1", prim_name)
|
||||
validator.check_float_range(beta2, 0.0, 1.0, Rel.INC_NEITHER, "beta2", prim_name)
|
||||
validator.check_positive_float(eps, "eps", prim_name)
|
||||
|
||||
|
||||
class Lamb(Optimizer):
|
||||
|
|
|
@ -70,10 +70,10 @@ def _check_param_value(beta1, beta2, eps, weight_decay, prim_name):
|
|||
validator.check_value_type("beta2", beta2, [float], prim_name)
|
||||
validator.check_value_type("eps", eps, [float], prim_name)
|
||||
validator.check_value_type("weight_dacay", weight_decay, [float], prim_name)
|
||||
validator.check_number_range("beta1", beta1, 0.0, 1.0, Rel.INC_NEITHER, prim_name)
|
||||
validator.check_number_range("beta2", beta2, 0.0, 1.0, Rel.INC_NEITHER, prim_name)
|
||||
validator.check_number_range("eps", eps, 0.0, float("inf"), Rel.INC_NEITHER, prim_name)
|
||||
validator.check_number_range("weight_decay", weight_decay, 0.0, float("inf"), Rel.INC_LEFT, prim_name)
|
||||
validator.check_float_range(beta1, 0.0, 1.0, Rel.INC_NEITHER, "beta1", prim_name)
|
||||
validator.check_float_range(beta2, 0.0, 1.0, Rel.INC_NEITHER, "beta2", prim_name)
|
||||
validator.check_positive_float(eps, "eps", prim_name)
|
||||
validator.check_non_negative_float(weight_decay, "weight_decay", prim_name)
|
||||
|
||||
|
||||
class LazyAdam(Optimizer):
|
||||
|
|
|
@ -100,7 +100,7 @@ class Optimizer(Cell):
|
|||
if isinstance(loss_scale, int):
|
||||
loss_scale = float(loss_scale)
|
||||
validator.check_value_type("loss_scale", loss_scale, [float], self.cls_name)
|
||||
validator.check_number_range("loss_scale", loss_scale, 0.0, float("inf"), Rel.INC_NEITHER, self.cls_name)
|
||||
validator.check_positive_float(loss_scale, "loss_scale", self.cls_name)
|
||||
self.loss_scale = loss_scale
|
||||
|
||||
weight_decay = self._preprocess_weight_decay(weight_decay)
|
||||
|
@ -221,7 +221,7 @@ class Optimizer(Cell):
|
|||
"""Check weight decay, and convert int to float."""
|
||||
if isinstance(weight_decay, (float, int)):
|
||||
weight_decay = float(weight_decay)
|
||||
validator.check_number_range("weight_decay", weight_decay, 0.0, float("inf"), Rel.INC_LEFT, self.cls_name)
|
||||
validator.check_non_negative_float(weight_decay, "weight_decay", self.cls_name)
|
||||
return weight_decay
|
||||
raise TypeError("Weight decay should be int or float.")
|
||||
|
||||
|
@ -229,7 +229,7 @@ class Optimizer(Cell):
|
|||
"""Check lr value, and convert lr to a float, a Tensor or a LearningRateSchedule."""
|
||||
if isinstance(learning_rate, (float, int)):
|
||||
learning_rate = float(learning_rate)
|
||||
validator.check_number_range("learning rate", learning_rate, 0.0, float("inf"), Rel.INC_LEFT, self.cls_name)
|
||||
validator.check_non_negative_float(learning_rate, "learning rate", self.cls_name)
|
||||
return learning_rate
|
||||
if isinstance(learning_rate, Tensor) and learning_rate.dim() == 0:
|
||||
return learning_rate
|
||||
|
|
|
@ -45,9 +45,9 @@ def _check_param_value(accum, l1, l2, use_locking, prim_name=None):
|
|||
validator.check_value_type("l1", l1, [float], prim_name)
|
||||
validator.check_value_type("l2", l2, [float], prim_name)
|
||||
validator.check_value_type("use_locking", use_locking, [bool], prim_name)
|
||||
validator.check_number_range("accum", accum, 0.0, float("inf"), Rel.INC_LEFT, prim_name)
|
||||
validator.check_number_range("l1", l1, 0.0, float("inf"), Rel.INC_LEFT, prim_name)
|
||||
validator.check_number_range("l2", l2, 0.0, float("inf"), Rel.INC_LEFT, prim_name)
|
||||
validator.check_non_negative_float(accum, "accum", prim_name)
|
||||
validator.check_non_negative_float(l1, "l1", prim_name)
|
||||
validator.check_non_negative_float(l2, "l2", prim_name)
|
||||
|
||||
|
||||
class ProximalAdagrad(Optimizer):
|
||||
|
|
|
@ -154,11 +154,11 @@ class RMSProp(Optimizer):
|
|||
use_locking=False, centered=False, loss_scale=1.0, weight_decay=0.0):
|
||||
super(RMSProp, self).__init__(learning_rate, params, weight_decay, loss_scale)
|
||||
validator.check_value_type("decay", decay, [float], self.cls_name)
|
||||
validator.check_number_range("decay", decay, 0.0, float("inf"), Rel.INC_LEFT, self.cls_name)
|
||||
validator.check_non_negative_float(decay, "decay", self.cls_name)
|
||||
validator.check_value_type("momentum", momentum, [float], self.cls_name)
|
||||
validator.check_number_range("momentum", momentum, 0.0, float("inf"), Rel.INC_LEFT, self.cls_name)
|
||||
validator.check_non_negative_float(momentum, "momentum", self.cls_name)
|
||||
validator.check_value_type("epsilon", epsilon, [float], self.cls_name)
|
||||
validator.check_number_range("epsilon", epsilon, 0.0, float("inf"), Rel.INC_NEITHER, self.cls_name)
|
||||
validator.check_positive_float(epsilon, "epsilon", self.cls_name)
|
||||
validator.check_value_type("use_locking", use_locking, [bool], self.cls_name)
|
||||
validator.check_value_type("centered", centered, [bool], self.cls_name)
|
||||
|
||||
|
|
|
@ -69,7 +69,7 @@ def get_concat_offset(x_shp, x_type, axis, prim_name):
|
|||
validator.check_subclass("shape0", x_type[0], mstype.tensor, prim_name)
|
||||
validator.check_positive_int(len(x_shp[0]), "len of x_shp[0]", prim_name)
|
||||
rank_base = len(x_shp[0])
|
||||
validator.check_int_range('axis', axis, -rank_base - 1, rank_base, Rel.INC_BOTH, prim_name)
|
||||
validator.check_int_range(axis, -rank_base - 1, rank_base, Rel.INC_BOTH, 'axis', prim_name)
|
||||
if axis < 0:
|
||||
axis = axis + rank_base
|
||||
all_shp = x_shp[0][axis]
|
||||
|
|
|
@ -188,7 +188,7 @@ class BatchNormGrad(PrimitiveWithInfer):
|
|||
@prim_attr_register
|
||||
def __init__(self, is_training=False, epsilon=1e-5):
|
||||
self.is_training = validator.check_value_type('is_training', is_training, (bool,), self.name)
|
||||
self.epsilon = validator.check_number_range('epsilon', epsilon, 0, 1, Rel.INC_RIGHT, self.name)
|
||||
self.epsilon = validator.check_float_range(epsilon, 0, 1, Rel.INC_RIGHT, 'epsilon', self.name)
|
||||
self.add_prim_attr('data_format', "NCHW")
|
||||
|
||||
def infer_shape(self, y_backprop_shape, x_shape, scale_shape, reserve_1_shape, reserve_2_shape):
|
||||
|
@ -485,7 +485,7 @@ class DropoutGrad(PrimitiveWithInfer):
|
|||
|
||||
@prim_attr_register
|
||||
def __init__(self, keep_prob=0.5):
|
||||
self.keep_prob = validator.check_number_range("keep_prob", keep_prob, 0, 1, Rel.INC_RIGHT, self.name)
|
||||
self.keep_prob = validator.check_float_range(keep_prob, 0, 1, Rel.INC_RIGHT, "keep_prob", self.name)
|
||||
|
||||
def infer_shape(self, dy_shape, mask_shape):
|
||||
return dy_shape
|
||||
|
@ -902,7 +902,7 @@ class LogSoftmaxGrad(PrimitiveWithInfer):
|
|||
|
||||
def infer_shape(self, dout, logits):
|
||||
rank = len(logits)
|
||||
validator.check_int_range('axis', self.axis, -rank - 1, rank, Rel.INC_BOTH, self.name)
|
||||
validator.check_int_range(self.axis, -rank - 1, rank, Rel.INC_BOTH, 'axis', self.name)
|
||||
return logits
|
||||
|
||||
def infer_dtype(self, dout, logits):
|
||||
|
@ -921,7 +921,7 @@ class LSTMGradData(PrimitiveWithInfer):
|
|||
self.has_bias = validator.check_value_type('has_bias', has_bias, (bool,), self.name)
|
||||
self.bidirectional = validator.check_value_type('bidirectional', bidirectional, (bool,), self.name)
|
||||
self.dropout = validator.check_value_type("dropout", dropout, [float], self.name)
|
||||
self.dropout = validator.check_number_range('dropout', dropout, 0, 1, Rel.INC_BOTH, self.name)
|
||||
self.dropout = validator.check_float_range(dropout, 0, 1, Rel.INC_BOTH, 'dropout', self.name)
|
||||
|
||||
if bidirectional:
|
||||
self.num_directions = 2
|
||||
|
@ -970,7 +970,7 @@ class LSTMGradWeight(PrimitiveWithInfer):
|
|||
self.has_bias = validator.check_value_type('has_bias', has_bias, (bool,), self.name)
|
||||
self.bidirectional = validator.check_value_type('bidirectional', bidirectional, (bool,), self.name)
|
||||
self.dropout = validator.check_value_type("dropout", dropout, [float], self.name)
|
||||
self.dropout = validator.check_number_range('dropout', dropout, 0, 1, Rel.INC_BOTH, self.name)
|
||||
self.dropout = validator.check_float_range(dropout, 0, 1, Rel.INC_BOTH, 'dropout', self.name)
|
||||
|
||||
if bidirectional:
|
||||
self.num_directions = 2
|
||||
|
@ -1005,7 +1005,7 @@ class LSTMGrad(PrimitiveWithInfer):
|
|||
self.has_bias = validator.check_value_type('has_bias', has_bias, (bool,), self.name)
|
||||
self.bidirectional = validator.check_value_type('bidirectional', bidirectional, (bool,), self.name)
|
||||
self.dropout = validator.check_value_type("dropout", dropout, [float], self.name)
|
||||
self.dropout = validator.check_number_range('dropout', dropout, 0, 1, Rel.INC_BOTH, self.name)
|
||||
self.dropout = validator.check_float_range(dropout, 0, 1, Rel.INC_BOTH, 'dropout', self.name)
|
||||
|
||||
if bidirectional:
|
||||
self.num_directions = 2
|
||||
|
@ -1652,7 +1652,7 @@ class BasicLSTMCellInputGrad(PrimitiveWithInfer):
|
|||
@prim_attr_register
|
||||
def __init__(self, keep_prob):
|
||||
self.keep_prob = validator.check_value_type("keep_prob", keep_prob, [float], self.name)
|
||||
self.keep_prob = validator.check_number_range("keep_prob", keep_prob, 0.0, 1.0, Rel.INC_BOTH, self.name)
|
||||
self.keep_prob = validator.check_float_range(keep_prob, 0.0, 1.0, Rel.INC_BOTH, "keep_prob", self.name)
|
||||
self.add_prim_attr("io_format", "ND")
|
||||
|
||||
def infer_shape(self, dgate_shape, w_shape):
|
||||
|
|
|
@ -76,8 +76,7 @@ class MinMaxUpdatePerLayer(PrimitiveWithInfer):
|
|||
f"For '{self.name}' attr \'ema\' and \'ema_decay\' should set together.")
|
||||
|
||||
self.ema = validator.check_value_type('ema', ema, (bool,), self.name)
|
||||
self.ema_decay = validator.check_number_range(
|
||||
'ema_decay', ema_decay, 0, 1, Rel.INC_BOTH, self.name)
|
||||
self.ema_decay = validator.check_float_range(ema_decay, 0, 1, Rel.INC_BOTH, 'ema_decay', self.name)
|
||||
self.init_prim_io_names(inputs=['x', 'min', 'max'],
|
||||
outputs=['min_up', 'max_up'])
|
||||
|
||||
|
@ -136,10 +135,9 @@ class MinMaxUpdatePerChannel(PrimitiveWithInfer):
|
|||
f"For '{self.name}' attr \'ema\' and \'ema_decay\' should set together.")
|
||||
|
||||
self.ema = validator.check_value_type('ema', ema, (bool,), self.name)
|
||||
self.ema_decay = validator.check_number_range(
|
||||
'ema_decay', ema_decay, 0, 1, Rel.INC_BOTH, self.name)
|
||||
self.ema_decay = validator.check_float_range(ema_decay, 0, 1, Rel.INC_BOTH, 'ema_decay', self.name)
|
||||
if self.is_ascend:
|
||||
self.channel_axis = validator.check_int_range('channel_axis', channel_axis, 0, 1, Rel.INC_BOTH, self.name)
|
||||
self.channel_axis = validator.check_int_range(channel_axis, 0, 1, Rel.INC_BOTH, 'channel_axis', self.name)
|
||||
else:
|
||||
self.channel_axis = validator.check_non_negative_int(channel_axis, 'channel_axis', self.name)
|
||||
self.init_prim_io_names(
|
||||
|
@ -222,10 +220,8 @@ class FakeQuantPerLayer(PrimitiveWithInfer):
|
|||
'symmetric', symmetric, (bool,), self.name)
|
||||
self.narrow_range = validator.check_value_type(
|
||||
'narrow_range', narrow_range, (bool,), self.name)
|
||||
self.training = validator.check_value_type(
|
||||
'training', training, (bool,), self.name)
|
||||
self.ema_decay = validator.check_number_range(
|
||||
'ema_decay', ema_decay, 0, 1, Rel.INC_BOTH, self.name)
|
||||
self.training = validator.check_value_type('training', training, (bool,), self.name)
|
||||
self.ema_decay = validator.check_float_range(ema_decay, 0, 1, Rel.INC_BOTH, 'ema_decay', self.name)
|
||||
self.num_bits = validator.check_positive_int(num_bits, 'num_bits', self.name)
|
||||
self.quant_delay = validator.check_non_negative_int(quant_delay, 'quant_delay', self.name)
|
||||
self.init_prim_io_names(inputs=['x', 'min', 'max'],
|
||||
|
@ -366,12 +362,11 @@ class FakeQuantPerChannel(PrimitiveWithInfer):
|
|||
'narrow_range', narrow_range, (bool,), self.name)
|
||||
self.training = validator.check_value_type(
|
||||
'training', training, (bool,), self.name)
|
||||
self.ema_decay = validator.check_number_range(
|
||||
'ema_decay', ema_decay, 0, 1, Rel.INC_BOTH, self.name)
|
||||
self.ema_decay = validator.check_float_range(ema_decay, 0, 1, Rel.INC_BOTH, 'ema_decay', self.name)
|
||||
self.num_bits = validator.check_positive_int(num_bits, 'num_bits', self.name)
|
||||
self.quant_delay = validator.check_non_negative_int(quant_delay, 'quant_delay', self.name)
|
||||
if self.is_ascend:
|
||||
self.channel_axis = validator.check_int_range('channel_axis', channel_axis, 0, 1, Rel.INC_BOTH, self.name)
|
||||
self.channel_axis = validator.check_int_range(channel_axis, 0, 1, Rel.INC_BOTH, 'channel_axis', self.name)
|
||||
else:
|
||||
self.channel_axis = validator.check_non_negative_int(channel_axis, 'channel_axis', self.name)
|
||||
self.init_prim_io_names(inputs=['x', 'min', 'max'], outputs=['out'])
|
||||
|
@ -495,7 +490,7 @@ class BatchNormFold(PrimitiveWithInfer):
|
|||
@prim_attr_register
|
||||
def __init__(self, momentum=0.9, epsilon=1e-5, is_training=True, freeze_bn=0):
|
||||
"""Initialize batch norm fold layer"""
|
||||
self.momentum = validator.check_number_range('momentum', momentum, 0, 1, Rel.INC_BOTH, self.name)
|
||||
self.momentum = validator.check_float_range(momentum, 0, 1, Rel.INC_BOTH, 'momentum', self.name)
|
||||
self.epsilon = validator.check_positive_float(epsilon, 'epsilon', self.name)
|
||||
self.is_training = validator.check_value_type('is_training', is_training, (bool,), self.name)
|
||||
self.freeze_bn = validator.check_value_type('freeze_bn', freeze_bn, (int,), self.name)
|
||||
|
@ -806,7 +801,7 @@ class BatchNormFoldD(PrimitiveWithInfer):
|
|||
def __init__(self, momentum=0.9, epsilon=1e-5, is_training=True, freeze_bn=0):
|
||||
"""Initialize _BatchNormFold layer"""
|
||||
from mindspore.ops._op_impl._custom_op import batchnorm_fold
|
||||
self.momentum = validator.check_number_range('momentum', momentum, 0, 1, Rel.INC_BOTH, self.name)
|
||||
self.momentum = validator.check_float_range(momentum, 0, 1, Rel.INC_BOTH, 'momentum', self.name)
|
||||
self.epsilon = validator.check_positive_float(epsilon, 'epsilon', self.name)
|
||||
self.is_training = validator.check_value_type('is_training', is_training, (bool,), self.name)
|
||||
self.freeze_bn = validator.check_value_type('freeze_bn', freeze_bn, (int,), self.name)
|
||||
|
|
|
@ -129,7 +129,7 @@ class ExpandDims(PrimitiveWithInfer):
|
|||
x_shape = list(x['shape'])
|
||||
axis_v = axis['value']
|
||||
rank = len(x_shape)
|
||||
validator.check_int_range('axis', axis_v, -rank - 1, rank, Rel.INC_BOTH, self.name)
|
||||
validator.check_int_range(axis_v, -rank - 1, rank, Rel.INC_BOTH, 'axis', self.name)
|
||||
value = None
|
||||
if x['value'] is not None:
|
||||
value = x['value'].asnumpy()
|
||||
|
@ -534,7 +534,7 @@ class Squeeze(PrimitiveWithInfer):
|
|||
ret = [d for d in x_shape if d != 1]
|
||||
else:
|
||||
for a in axis:
|
||||
validator.check_int_range('axis or its elements', a, -ndim, ndim - 1, Rel.INC_BOTH, self.name)
|
||||
validator.check_int_range(a, -ndim, ndim - 1, Rel.INC_BOTH, 'axis or its elements', self.name)
|
||||
if x_shape[a] != 1:
|
||||
raise ValueError('Cannot select an axis to squeeze out which has size not equal to one.')
|
||||
ret = [x_shape[i] for i in range(ndim) if not (i in axis or (i - ndim) in axis)]
|
||||
|
@ -658,7 +658,7 @@ class GatherV2(PrimitiveWithCheck):
|
|||
axis_v = axis['value']
|
||||
params_shp = params['shape']
|
||||
rank = len(params_shp)
|
||||
validator.check_int_range("axis", axis_v, -rank, rank, Rel.INC_LEFT, self.name)
|
||||
validator.check_int_range(axis_v, -rank, rank, Rel.INC_LEFT, "axis", self.name)
|
||||
|
||||
if axis_v < 0:
|
||||
axis_v += rank
|
||||
|
@ -777,7 +777,7 @@ class Split(PrimitiveWithInfer):
|
|||
validator.check_subclass("x", x['dtype'], mstype.tensor, self.name)
|
||||
x_shape = list(x['shape'])
|
||||
dim = len(x_shape)
|
||||
validator.check_int_range('axis value', self.axis, -dim, dim, Rel.INC_LEFT, self.name)
|
||||
validator.check_int_range(self.axis, -dim, dim, Rel.INC_LEFT, 'axis value', self.name)
|
||||
validator.check_positive_int(self.output_num, "output_num", self.name)
|
||||
output_valid_check = x_shape[self.axis] % self.output_num
|
||||
if output_valid_check != 0:
|
||||
|
@ -1224,7 +1224,7 @@ class Argmax(PrimitiveWithInfer):
|
|||
if axis is None:
|
||||
axis = 0
|
||||
x_rank = len(x_shape)
|
||||
validator.check_int_range("axis", axis, -x_rank, x_rank, Rel.INC_LEFT, self.name)
|
||||
validator.check_int_range(axis, -x_rank, x_rank, Rel.INC_LEFT, "axis", self.name)
|
||||
axis = axis + x_rank if axis < 0 else axis
|
||||
ouput_shape = [x_shape[i] for i in range(x_rank) if i != axis]
|
||||
return ouput_shape
|
||||
|
@ -1272,7 +1272,7 @@ class Argmin(PrimitiveWithInfer):
|
|||
if axis is None:
|
||||
axis = 0
|
||||
x_rank = len(x_shape)
|
||||
validator.check_int_range("axis", axis, -x_rank, x_rank, Rel.INC_LEFT, self.name)
|
||||
validator.check_int_range(axis, -x_rank, x_rank, Rel.INC_LEFT, "axis", self.name)
|
||||
axis = axis + x_rank if axis < 0 else axis
|
||||
ouput_shape = [x_shape[i] for i in range(x_rank) if i != axis]
|
||||
return ouput_shape
|
||||
|
@ -1325,7 +1325,7 @@ class ArgMaxWithValue(PrimitiveWithInfer):
|
|||
def infer_shape(self, x_shape):
|
||||
axis = self.axis
|
||||
x_rank = len(x_shape)
|
||||
validator.check_int_range("axis", axis, -x_rank, x_rank, Rel.INC_LEFT, self.name)
|
||||
validator.check_int_range(axis, -x_rank, x_rank, Rel.INC_LEFT, "axis", self.name)
|
||||
ouput_shape = _infer_shape_reduce(x_shape, self.axis, self.keep_dims, self.name)
|
||||
return ouput_shape, ouput_shape
|
||||
|
||||
|
@ -1377,7 +1377,7 @@ class ArgMinWithValue(PrimitiveWithInfer):
|
|||
def infer_shape(self, x_shape):
|
||||
axis = self.axis
|
||||
x_rank = len(x_shape)
|
||||
validator.check_int_range("axis", axis, -x_rank, x_rank, Rel.INC_LEFT, self.name)
|
||||
validator.check_int_range(axis, -x_rank, x_rank, Rel.INC_LEFT, "axis", self.name)
|
||||
ouput_shape = _infer_shape_reduce(x_shape, self.axis, self.keep_dims, self.name)
|
||||
return ouput_shape, ouput_shape
|
||||
|
||||
|
@ -1760,7 +1760,7 @@ def _get_pack_shape(x_shape, x_type, axis, prim_name):
|
|||
rank_base = len(x_shape[0])
|
||||
N = len(x_shape)
|
||||
out_shape = x_shape[0]
|
||||
validator.check_int_range('axis', axis, -rank_base - 1, rank_base, Rel.INC_BOTH, prim_name)
|
||||
validator.check_int_range(axis, -rank_base - 1, rank_base, Rel.INC_BOTH, 'axis', prim_name)
|
||||
if axis < 0:
|
||||
axis = axis + rank_base + 1
|
||||
for i in range(1, N):
|
||||
|
@ -1863,7 +1863,7 @@ class Unpack(PrimitiveWithInfer):
|
|||
validator.check_subclass("x", x['dtype'], mstype.tensor, self.name)
|
||||
x_shape = list(x['shape'])
|
||||
dim = len(x_shape)
|
||||
validator.check_int_range('axis value', self.axis, -dim, dim, Rel.INC_LEFT, self.name)
|
||||
validator.check_int_range(self.axis, -dim, dim, Rel.INC_LEFT, 'axis value', self.name)
|
||||
if self.axis < 0:
|
||||
self.axis = self.axis + dim
|
||||
output_num = x_shape[self.axis]
|
||||
|
@ -1965,7 +1965,7 @@ class ReverseV2(PrimitiveWithInfer):
|
|||
def infer_shape(self, x_shape):
|
||||
dim = len(x_shape)
|
||||
for i, each in enumerate(self.axis):
|
||||
validator.check_int_range(f'axis[{i}]', each, -dim, dim, Rel.INC_LEFT, self.name)
|
||||
validator.check_int_range(each, -dim, dim, Rel.INC_LEFT, f'axis[{i}]', self.name)
|
||||
return x_shape
|
||||
|
||||
def infer_dtype(self, x_dtype):
|
||||
|
|
|
@ -206,7 +206,7 @@ class _HostAllGather(PrimitiveWithInfer):
|
|||
validator.check_value_type('group', group, (tuple, list), self.name)
|
||||
validator.check_integer("group size", len(group), 2, Rel.GE, self.name)
|
||||
for r in group:
|
||||
validator.check_int_range("rank_id", r, 0, 7, Rel.INC_BOTH, self.name)
|
||||
validator.check_int_range(r, 0, 7, Rel.INC_BOTH, "rank_id", self.name)
|
||||
validator.check_value_type("rank_id", r, (int,), self.name)
|
||||
self.group_size = len(group)
|
||||
self.add_prim_attr('group', group)
|
||||
|
@ -315,7 +315,7 @@ class _HostReduceScatter(PrimitiveWithInfer):
|
|||
validator.check_value_type('group', group, (tuple, list), self.name)
|
||||
validator.check_integer("group size", len(group), 2, Rel.GE, self.name)
|
||||
for r in group:
|
||||
validator.check_int_range("rank_id", r, 0, 7, Rel.INC_BOTH, self.name)
|
||||
validator.check_int_range(r, 0, 7, Rel.INC_BOTH, "rank_id", self.name)
|
||||
validator.check_value_type("rank_id", r, (int,), self.name)
|
||||
self.op = op
|
||||
self.group_size = len(group)
|
||||
|
|
|
@ -70,8 +70,7 @@ class ControlDepend(Primitive):
|
|||
@prim_attr_register
|
||||
def __init__(self, depend_mode=0):
|
||||
"""init"""
|
||||
validator.check_int_range(
|
||||
"depend_mode", depend_mode, 0, 1, Rel.INC_BOTH, self.name)
|
||||
validator.check_int_range(depend_mode, 0, 1, Rel.INC_BOTH, "depend_mode", self.name)
|
||||
|
||||
def __call__(self, src, dst):
|
||||
return src
|
||||
|
|
|
@ -31,7 +31,7 @@ def _infer_shape_reduce(x, axis, keep_dims, prim_name):
|
|||
"""Common infer for reduce operator"""
|
||||
|
||||
def reduce_one_axis(one_axis):
|
||||
validator.check_int_range('axis', one_axis, -dim, dim, Rel.INC_LEFT, prim_name)
|
||||
validator.check_int_range(one_axis, -dim, dim, Rel.INC_LEFT, 'axis', prim_name)
|
||||
if one_axis < 0:
|
||||
one_axis += dim
|
||||
axis_reduce.add(one_axis)
|
||||
|
|
|
@ -149,7 +149,7 @@ class Softmax(PrimitiveWithInfer):
|
|||
validator.check_integer("length of axis", len(self.axis), 1, Rel.GE, self.name)
|
||||
rank = len(logits)
|
||||
for axis_v in self.axis:
|
||||
validator.check_int_range("axis", axis_v, -rank, rank, Rel.INC_LEFT, self.name)
|
||||
validator.check_int_range(axis_v, -rank, rank, Rel.INC_LEFT, "axis", self.name)
|
||||
return logits
|
||||
|
||||
def infer_dtype(self, logits):
|
||||
|
@ -193,7 +193,7 @@ class LogSoftmax(PrimitiveWithInfer):
|
|||
|
||||
def infer_shape(self, logits):
|
||||
rank = len(logits)
|
||||
validator.check_int_range('axis', self.axis, -rank, rank, Rel.INC_LEFT, self.name)
|
||||
validator.check_int_range(self.axis, -rank, rank, Rel.INC_LEFT, 'axis', self.name)
|
||||
return logits
|
||||
|
||||
def infer_dtype(self, logits):
|
||||
|
@ -637,8 +637,8 @@ class FusedBatchNorm(Primitive):
|
|||
self.init_prim_io_names(inputs=['x', 'scale', 'b', 'mean', 'variance'],
|
||||
outputs=['y', 'running_mean', 'running_variance', 'save_mean', 'save_inv_variance'])
|
||||
self.mode = validator.check_integer('mode', mode, [0, 1], Rel.IN, self.name)
|
||||
self.epsilon = validator.check_number_range('epsilon', epsilon, 0, 1, Rel.INC_RIGHT, self.name)
|
||||
self.momentum = validator.check_number_range('momentum', momentum, 0, 1, Rel.INC_BOTH, self.name)
|
||||
self.epsilon = validator.check_float_range(epsilon, 0, 1, Rel.INC_RIGHT, 'epsilon', self.name)
|
||||
self.momentum = validator.check_float_range(momentum, 0, 1, Rel.INC_BOTH, 'momentum', self.name)
|
||||
self._update_parameter = True
|
||||
|
||||
|
||||
|
@ -710,8 +710,8 @@ class FusedBatchNormEx(PrimitiveWithInfer):
|
|||
self.init_prim_io_names(inputs=['x', 'scale', 'b', 'mean', 'variance'],
|
||||
outputs=['y', 'save_scale', 'save_bias', 'save_mean', 'save_inv_variance', 'reserve'])
|
||||
self.mode = validator.check_integer('mode', mode, [0, 1], Rel.IN, self.name)
|
||||
self.epsilon = validator.check_number_range('epsilon', epsilon, 0, 1, Rel.INC_RIGHT, self.name)
|
||||
self.momentum = validator.check_number_range('momentum', momentum, 0, 1, Rel.INC_BOTH, self.name)
|
||||
self.epsilon = validator.check_float_range(epsilon, 0, 1, Rel.INC_RIGHT, 'epsilon', self.name)
|
||||
self.momentum = validator.check_float_range(momentum, 0, 1, Rel.INC_BOTH, 'momentum', self.name)
|
||||
self._update_parameter = True
|
||||
self.add_prim_attr('data_format', "NCHW")
|
||||
|
||||
|
@ -818,8 +818,8 @@ class BNTrainingUpdate(PrimitiveWithInfer):
|
|||
validator.check_value_type("isRef", isRef, [bool], self.name)
|
||||
validator.check_value_type("epsilon", epsilon, [float], self.name)
|
||||
validator.check_value_type("factor", factor, [float], self.name)
|
||||
self.epsilon = validator.check_number_range('epsilon', epsilon, 0, 1, Rel.INC_RIGHT, 'BNTrainingUpdate')
|
||||
self.factor = validator.check_number_range('factor', factor, 0, 1, Rel.INC_BOTH, 'BNTrainingUpdate')
|
||||
self.epsilon = validator.check_float_range(epsilon, 0, 1, Rel.INC_RIGHT, 'epsilon', 'BNTrainingUpdate')
|
||||
self.factor = validator.check_float_range(factor, 0, 1, Rel.INC_BOTH, 'factor', 'BNTrainingUpdate')
|
||||
|
||||
def infer_shape(self, x, sum, square_sum, scale, b, mean, variance):
|
||||
validator.check_integer("x rank", len(x), 4, Rel.EQ, self.name)
|
||||
|
@ -898,7 +898,7 @@ class BatchNorm(PrimitiveWithInfer):
|
|||
@prim_attr_register
|
||||
def __init__(self, is_training=False, epsilon=1e-5):
|
||||
validator.check_value_type('is_training', is_training, (bool,), self.name)
|
||||
validator.check_number_range('epsilon', epsilon, 0, 1, Rel.INC_RIGHT, self.name)
|
||||
validator.check_float_range(epsilon, 0, 1, Rel.INC_RIGHT, 'epsilon', self.name)
|
||||
self.add_prim_attr('data_format', "NCHW")
|
||||
self.init_prim_io_names(inputs=['x', 'scale', 'offset', 'mean', 'variance'],
|
||||
outputs=['y', 'batch_mean', 'batch_variance', 'reserve_space_1', 'reserve_space_2'])
|
||||
|
@ -2383,7 +2383,7 @@ class L2Normalize(PrimitiveWithInfer):
|
|||
|
||||
def infer_shape(self, input_x):
|
||||
dim = len(input_x)
|
||||
validator.check_int_range('axis value', self.axis, -dim, dim, Rel.INC_LEFT, self.name)
|
||||
validator.check_int_range(self.axis, -dim, dim, Rel.INC_LEFT, 'axis value', self.name)
|
||||
return input_x
|
||||
|
||||
def infer_dtype(self, input_x):
|
||||
|
@ -2481,10 +2481,10 @@ class DropoutDoMask(PrimitiveWithInfer):
|
|||
keep_prob_v = keep_prob['value']
|
||||
if keep_prob_v is not None:
|
||||
if isinstance(keep_prob['dtype'], type(mstype.tensor)):
|
||||
validator.check_number_range('keep_prob', keep_prob_v.asnumpy(), 0, 1, Rel.INC_BOTH, self.name)
|
||||
validator.check_float_range(keep_prob_v.asnumpy(), 0, 1, Rel.INC_BOTH, 'keep_prob', self.name)
|
||||
else:
|
||||
validator.check_value_type("keep_prob", keep_prob_v, [float], self.name)
|
||||
validator.check_number_range('keep_prob', keep_prob_v, 0, 1, Rel.INC_BOTH, self.name)
|
||||
validator.check_float_range(keep_prob_v, 0, 1, Rel.INC_BOTH, 'keep_prob', self.name)
|
||||
|
||||
out = {'shape': input_x_shape,
|
||||
'dtype': input_x['dtype'],
|
||||
|
@ -2584,7 +2584,7 @@ class OneHot(PrimitiveWithInfer):
|
|||
|
||||
# check shape
|
||||
indices_shp = indices['shape']
|
||||
validator.check_int_range("axis", self.axis, -1, len(indices_shp), Rel.INC_BOTH, self.name)
|
||||
validator.check_int_range(self.axis, -1, len(indices_shp), Rel.INC_BOTH, "axis", self.name)
|
||||
depth_val = depth['value']
|
||||
validator.check_non_negative_int(depth_val, "depth", self.name)
|
||||
# create new dimension at end if self.axis is -1
|
||||
|
@ -2771,7 +2771,7 @@ class LSTM(PrimitiveWithInfer):
|
|||
self.has_bias = validator.check_value_type("has_bias", has_bias, (bool,), self.name)
|
||||
self.bidirectional = validator.check_value_type("bidirectional", bidirectional, (bool,), self.name)
|
||||
self.dropout = validator.check_value_type("dropout", dropout, [float], self.name)
|
||||
self.dropout = validator.check_number_range('dropout', dropout, 0, 1, Rel.INC_BOTH, self.name)
|
||||
self.dropout = validator.check_float_range(dropout, 0, 1, Rel.INC_BOTH, 'dropout', self.name)
|
||||
|
||||
if bidirectional:
|
||||
self.num_directions = 2
|
||||
|
@ -3054,7 +3054,7 @@ class ROIAlign(PrimitiveWithInfer):
|
|||
validator.check_value_type("spatial_scale", spatial_scale, [float], self.name)
|
||||
validator.check_value_type("sample_num", sample_num, [int], self.name)
|
||||
validator.check_value_type("roi_end_mode", roi_end_mode, [int], self.name)
|
||||
validator.check_int_range("roi_end_mode", roi_end_mode, 0, 1, Rel.INC_BOTH, self.name)
|
||||
validator.check_int_range(roi_end_mode, 0, 1, Rel.INC_BOTH, "roi_end_mode", self.name)
|
||||
self.pooled_height = pooled_height
|
||||
self.pooled_width = pooled_width
|
||||
self.spatial_scale = spatial_scale
|
||||
|
@ -3502,9 +3502,9 @@ class FusedSparseFtrl(PrimitiveWithInfer):
|
|||
validator.check_value_type("l1", l1, [float], self.name)
|
||||
validator.check_value_type("l2", l2, [float], self.name)
|
||||
validator.check_value_type("lr_power", lr_power, [float], self.name)
|
||||
self.lr = validator.check_number_range("lr", lr, 0.0, float("inf"), Rel.INC_NEITHER, self.name)
|
||||
self.l1 = validator.check_number_range("l1", l1, 0.0, float("inf"), Rel.INC_LEFT, self.name)
|
||||
self.l2 = validator.check_number_range("l2", l2, 0.0, float("inf"), Rel.INC_LEFT, self.name)
|
||||
self.lr = validator.check_positive_float(lr, "lr", self.name)
|
||||
self.l1 = validator.check_non_negative_float(l1, "l1", self.name)
|
||||
self.l2 = validator.check_non_negative_float(l2, "l2", self.name)
|
||||
self.lr_power = validator.check_number("lr_power", lr_power, 0, Rel.LE, self.name)
|
||||
self.use_locking = validator.check_value_type("use_locking", use_locking, [bool], self.name)
|
||||
|
||||
|
@ -4240,7 +4240,7 @@ class SparseApplyAdagrad(PrimitiveWithInfer):
|
|||
@prim_attr_register
|
||||
def __init__(self, lr, update_slots=True, use_locking=False):
|
||||
validator.check_value_type("lr", lr, [float], self.name)
|
||||
validator.check_number_range("lr", lr, float("-inf"), float("inf"), Rel.INC_NEITHER, self.name)
|
||||
validator.check_is_float(lr, "lr", self.name)
|
||||
validator.check_value_type("update_slots", update_slots, [bool], self.name)
|
||||
validator.check_value_type("use_locking", use_locking, [bool], self.name)
|
||||
|
||||
|
@ -5142,9 +5142,9 @@ class SparseApplyFtrl(PrimitiveWithCheck):
|
|||
validator.check_value_type("l1", l1, [float], self.name)
|
||||
validator.check_value_type("l2", l2, [float], self.name)
|
||||
validator.check_value_type("lr_power", lr_power, [float], self.name)
|
||||
self.lr = validator.check_number_range("lr", lr, 0.0, float("inf"), Rel.INC_NEITHER, self.name)
|
||||
self.l1 = validator.check_number_range("l1", l1, 0.0, float("inf"), Rel.INC_LEFT, self.name)
|
||||
self.l2 = validator.check_number_range("l2", l2, 0.0, float("inf"), Rel.INC_LEFT, self.name)
|
||||
self.lr = validator.check_positive_float(lr, "lr", self.name)
|
||||
self.l1 = validator.check_non_negative_float(l1, "l1", self.name)
|
||||
self.l2 = validator.check_non_negative_float(l2, "l2", self.name)
|
||||
self.lr_power = validator.check_number("lr_power", lr_power, 0, Rel.LE, self.name)
|
||||
self.use_locking = validator.check_value_type("use_locking", use_locking, [bool], self.name)
|
||||
self.init_prim_io_names(inputs=['var', 'accum', 'linear', 'grad', 'indices'],
|
||||
|
@ -5239,9 +5239,9 @@ class SparseApplyFtrlV2(PrimitiveWithInfer):
|
|||
validator.check_value_type("l1", l1, [float], self.name)
|
||||
validator.check_value_type("l2", l2, [float], self.name)
|
||||
validator.check_value_type("lr_power", lr_power, [float], self.name)
|
||||
self.lr = validator.check_number_range("lr", lr, 0.0, float("inf"), Rel.INC_NEITHER, self.name)
|
||||
self.l1 = validator.check_number_range("l1", l1, 0.0, float("inf"), Rel.INC_LEFT, self.name)
|
||||
self.l2 = validator.check_number_range("l2", l2, 0.0, float("inf"), Rel.INC_LEFT, self.name)
|
||||
self.lr = validator.check_positive_float(lr, "lr", self.name)
|
||||
self.l1 = validator.check_non_negative_float(l1, "l1", self.name)
|
||||
self.l2 = validator.check_non_negative_float(l2, "l2", self.name)
|
||||
self.lr_power = validator.check_number("lr_power", lr_power, 0, Rel.LE, self.name)
|
||||
self.l2_shrinkage = validator.check_value_type("l2_shrinkage", l2_shrinkage, [float], self.name)
|
||||
self.use_locking = validator.check_value_type("use_locking", use_locking, [bool], self.name)
|
||||
|
@ -5285,7 +5285,7 @@ class Dropout(PrimitiveWithInfer):
|
|||
|
||||
@prim_attr_register
|
||||
def __init__(self, keep_prob=0.5):
|
||||
self.keep_prob = validator.check_number_range("keep_prob", keep_prob, 0, 1, Rel.INC_RIGHT, self.name)
|
||||
self.keep_prob = validator.check_float_range(keep_prob, 0, 1, Rel.INC_RIGHT, "keep_prob", self.name)
|
||||
|
||||
def infer_shape(self, x_shape):
|
||||
validator.check_integer("x_shape", len(x_shape), 1, Rel.GE, self.name)
|
||||
|
@ -5510,7 +5510,7 @@ class BasicLSTMCell(PrimitiveWithInfer):
|
|||
@prim_attr_register
|
||||
def __init__(self, keep_prob=1.0, forget_bias=1.0, state_is_tuple=True, activation='tanh'):
|
||||
self.keep_prob = validator.check_value_type("keep_prob", keep_prob, [float], self.name)
|
||||
self.keep_prob = validator.check_number_range("keep_prob", keep_prob, 0.0, 1.0, Rel.INC_BOTH, self.name)
|
||||
self.keep_prob = validator.check_float_range(keep_prob, 0.0, 1.0, Rel.INC_BOTH, "keep_prob", self.name)
|
||||
self.forget_bias = validator.check_value_type("forget_bias", forget_bias, [float], self.name)
|
||||
self.state_is_tuple = validator.check_value_type("state_is_tuple", state_is_tuple, [bool], self.name)
|
||||
self.activation = validator.check_string(activation, ['tanh'], "activation", self.name)
|
||||
|
|
|
@ -100,10 +100,8 @@ def _generate_cosine_lr(lr_init, lr_end, lr_max, total_steps, warmup_steps):
|
|||
lr_inc = (float(lr_max) - float(lr_init)) / float(warmup_steps)
|
||||
lr = float(lr_init) + lr_inc * (i + 1)
|
||||
else:
|
||||
linear_decay = (total_steps - i) / decay_steps
|
||||
cosine_decay = 0.5 * (1 + math.cos(math.pi * 2 * 0.47 * i / decay_steps))
|
||||
decayed = linear_decay * cosine_decay + 0.00001
|
||||
lr = lr_max * decayed
|
||||
cosine_decay = 0.5 * (1 + math.cos(math.pi * (i - warmup_steps) / decay_steps))
|
||||
lr = (lr_max - lr_end) * cosine_decay + lr_end
|
||||
lr_each_step.append(lr)
|
||||
return lr_each_step
|
||||
|
||||
|
|
|
@ -122,7 +122,7 @@ class MySparseGatherV2(PrimitiveWithInfer):
|
|||
axis_v = axis['value']
|
||||
params_shp = params['shape']
|
||||
rank = len(params_shp)
|
||||
validator.check_int_range("axis", axis_v, -rank, rank, Rel.INC_LEFT, self.name)
|
||||
validator.check_int_range(axis_v, -rank, rank, Rel.INC_LEFT, "axis", self.name)
|
||||
if axis_v < 0:
|
||||
axis_v += rank
|
||||
out_shape = params_shp[:axis_v] + indices['shape'] + params_shp[axis_v + 1:]
|
||||
|
@ -208,10 +208,10 @@ def _check_param_value(beta1, beta2, eps, weight_decay, prim_name):
|
|||
validator.check_value_type("beta2", beta2, [float], prim_name)
|
||||
validator.check_value_type("eps", eps, [float], prim_name)
|
||||
validator.check_value_type("weight_dacay", weight_decay, [float], prim_name)
|
||||
validator.check_number_range("beta1", beta1, 0.0, 1.0, Rel.INC_NEITHER, prim_name)
|
||||
validator.check_number_range("beta2", beta2, 0.0, 1.0, Rel.INC_NEITHER, prim_name)
|
||||
validator.check_number_range("eps", eps, 0.0, float("inf"), Rel.INC_NEITHER, prim_name)
|
||||
validator.check_number_range("weight_decay", weight_decay, 0.0, float("inf"), Rel.INC_LEFT, prim_name)
|
||||
validator.check_float_range(beta1, 0.0, 1.0, Rel.INC_NEITHER, "beta1", prim_name)
|
||||
validator.check_float_range(beta2, 0.0, 1.0, Rel.INC_NEITHER, "beta2", prim_name)
|
||||
validator.check_positive_float(eps, "eps", prim_name)
|
||||
validator.check_non_negative_float(weight_decay, "weight_decay", prim_name)
|
||||
|
||||
|
||||
class AdamWeightDecaySparse(Optimizer):
|
||||
|
|
|
@ -14,55 +14,97 @@
|
|||
# ============================================================================
|
||||
""" test checkparameter """
|
||||
import pytest
|
||||
|
||||
from mindspore._checkparam import check_int, check_input_format, Validator, twice
|
||||
import numpy as np
|
||||
from mindspore._checkparam import check_input_format, Validator, twice, Rel
|
||||
|
||||
kernel_size = 5
|
||||
kernel_size1 = twice(kernel_size)
|
||||
assert kernel_size1 == (5, 5)
|
||||
|
||||
def test_check_integer1():
|
||||
with pytest.raises(TypeError):
|
||||
Validator.check_integer("input", 0, Rel.GE, "number")
|
||||
|
||||
def test_check_int_1():
|
||||
assert check_int(3) == 3
|
||||
|
||||
|
||||
def check_int_positive_1():
|
||||
def test_check_integer2():
|
||||
with pytest.raises(ValueError):
|
||||
Validator.check_positive_int(-1)
|
||||
Validator.check_integer(-1, 0, Rel.GE, "number")
|
||||
|
||||
def test_check_integer3():
|
||||
input = np.random.randint(0, 100)
|
||||
assert Validator.check_integer(input, 0, Rel.GE, "number") == input
|
||||
|
||||
def test_NCHW1():
|
||||
assert check_input_format("NCHW") == "NCHW"
|
||||
def test_check_int1():
|
||||
input = np.random.randint(-100, 100)
|
||||
assert Validator.check_is_int(input) == input
|
||||
|
||||
def test_check_int2():
|
||||
with pytest.raises(TypeError):
|
||||
Validator.check_is_int(3.3)
|
||||
|
||||
def test_NCHW3():
|
||||
def test_check_int3():
|
||||
with pytest.raises(TypeError):
|
||||
Validator.check_is_int("str")
|
||||
|
||||
def test_check_int4():
|
||||
with pytest.raises(TypeError):
|
||||
Validator.check_is_int(True)
|
||||
|
||||
def test_check_is_int5():
|
||||
with pytest.raises(TypeError):
|
||||
Validator.check_is_int(True)
|
||||
with pytest.raises(TypeError):
|
||||
Validator.check_is_int(False)
|
||||
|
||||
def test_check_positive_int1():
|
||||
input = np.random.randint(0, 100)
|
||||
assert Validator.check_positive_int(input) == input
|
||||
|
||||
def test_check_positive_int2():
|
||||
input = np.random.randint(-100, 0)
|
||||
with pytest.raises(ValueError):
|
||||
check_input_format("rt")
|
||||
Validator.check_positive_int(input)
|
||||
|
||||
def test_check_positive_int3():
|
||||
with pytest.raises(ValueError):
|
||||
Validator.check_positive_int(3.3)
|
||||
|
||||
def test_check_int_2():
|
||||
def test_check_positive_int4():
|
||||
with pytest.raises(TypeError):
|
||||
check_int(3.3)
|
||||
Validator.check_positive_int("str")
|
||||
|
||||
def test_check_negative_int1():
|
||||
input = np.random.randint(-100, -1)
|
||||
assert Validator.check_negative_int(input) == input
|
||||
|
||||
def test_check_int_3():
|
||||
def test_check_negative_int2():
|
||||
input = np.random.randint(0, 100)
|
||||
with pytest.raises(ValueError):
|
||||
Validator.check_negative_int(input)
|
||||
|
||||
def test_check_negative_int3():
|
||||
with pytest.raises(ValueError):
|
||||
Validator.check_negative_int(3.3)
|
||||
|
||||
def test_check_negative_int4():
|
||||
with pytest.raises(TypeError):
|
||||
check_int("str")
|
||||
Validator.check_negative_int("str")
|
||||
|
||||
def test_check_non_positive_int1():
|
||||
input = np.random.randint(-100, 0)
|
||||
assert Validator.check_non_positive_int(input) == input
|
||||
|
||||
def test_check_int_4():
|
||||
def test_check_non_positive_int2():
|
||||
input = np.random.randint(1, 100)
|
||||
with pytest.raises(ValueError):
|
||||
Validator.check_non_positive_int(input)
|
||||
|
||||
def test_check_non_positive_int3():
|
||||
with pytest.raises(ValueError):
|
||||
Validator.check_non_positive_int(3.3)
|
||||
|
||||
def test_check_non_positive_int4():
|
||||
with pytest.raises(TypeError):
|
||||
check_int(True)
|
||||
|
||||
|
||||
def test_check_int_5():
|
||||
check_int(0)
|
||||
check_int(1)
|
||||
with pytest.raises(TypeError):
|
||||
check_int(True)
|
||||
with pytest.raises(TypeError):
|
||||
check_int(False)
|
||||
|
||||
Validator.check_non_positive_int("str")
|
||||
|
||||
def test_check_bool_1():
|
||||
assert Validator.check_bool(True)
|
||||
|
|
Loading…
Reference in New Issue