diff --git a/mindspore/_checkparam.py b/mindspore/_checkparam.py index 6e8d6d8b4ec..40bb684eb10 100644 --- a/mindspore/_checkparam.py +++ b/mindspore/_checkparam.py @@ -92,7 +92,7 @@ rel_strs = { } -def _check_integer(arg_value, value, rel, arg_name=None, prim_name=None): +def check_number(arg_value, value, rel, arg_type=int, arg_name=None, prim_name=None): """ Check argument integer. @@ -100,13 +100,13 @@ def _check_integer(arg_value, value, rel, arg_name=None, prim_name=None): - number = check_integer(number, 0, Rel.GE, "number", None) # number >= 0 """ rel_fn = Rel.get_fns(rel) - type_mismatch = not isinstance(arg_value, int) or isinstance(arg_value, bool) + type_mismatch = not isinstance(arg_value, arg_type) or isinstance(arg_value, bool) type_except = TypeError if type_mismatch else ValueError if type_mismatch or not rel_fn(arg_value, value): rel_str = Rel.get_strs(rel).format(value) arg_name = arg_name if arg_name else "parameter" msg_prefix = f'For \'{prim_name}\' the' if prim_name else "The" - raise type_except(f'{msg_prefix} `{arg_name}` should be an int and must {rel_str}, but got `{arg_value}`' + raise type_except(f'{msg_prefix} `{arg_name}` should be an {arg_type} and must {rel_str}, but got `{arg_value}`' f' with type `{type(arg_value).__name__}`.') return arg_value @@ -149,7 +149,7 @@ class Validator: - number = check_positive_int(number) - number = check_positive_int(number, "bias") """ - return _check_integer(arg_value, 0, Rel.GT, arg_name, prim_name) + return check_number(arg_value, 0, Rel.GT, int, arg_name, prim_name) @staticmethod def check_negative_int(arg_value, arg_name=None, prim_name=None): @@ -160,7 +160,7 @@ class Validator: - number = check_negative_int(number) - number = check_negative_int(number, "bias") """ - return _check_integer(arg_value, 0, Rel.LT, arg_name, prim_name) + return check_number(arg_value, 0, Rel.LT, int, arg_name, prim_name) @staticmethod def check_non_positive_int(arg_value, arg_name=None, prim_name=None): @@ -171,7 +171,7 @@ class Validator: - number = check_non_positive_int(number) - number = check_non_positive_int(number, "bias") """ - return _check_integer(arg_value, 0, Rel.LE, arg_name, prim_name) + return check_number(arg_value, 0, Rel.LE, int, arg_name, prim_name) @staticmethod def check_non_negative_int(arg_value, arg_name=None, prim_name=None): @@ -182,7 +182,52 @@ class Validator: - number = check_non_negative_int(number) - number = check_non_negative_int(number, "bias") """ - return _check_integer(arg_value, 0, Rel.GE, arg_name, prim_name) + return check_number(arg_value, 0, Rel.GE, int, arg_name, prim_name) + + @staticmethod + def check_positive_float(arg_value, arg_name=None, prim_name=None): + """ + Check argument is positive float, which mean arg_value > 0. + + Usage: + - number = check_positive_float(number) + - number = check_positive_float(number, "bias") + - number = check_positive_float(number, "bias", "bias_class") + """ + return check_number(arg_value, 0, Rel.GT, float, arg_name, prim_name) + + @staticmethod + def check_negative_float(arg_value, arg_name=None, prim_name=None): + """ + Check argument is negative float, which mean arg_value < 0. + + Usage: + - number = check_negative_float(number) + - number = check_negative_float(number, "bias") + """ + return check_number(arg_value, 0, Rel.LT, float, arg_name, prim_name) + + @staticmethod + def check_non_positive_float(arg_value, arg_name=None, prim_name=None): + """ + Check argument is non-negative float, which mean arg_value <= 0. + + Usage: + - number = check_non_positive_float(number) + - number = check_non_positive_float(number, "bias") + """ + return check_number(arg_value, 0, Rel.LE, float, arg_name, prim_name) + + @staticmethod + def check_non_negative_float(arg_value, arg_name=None, prim_name=None): + """ + Check argument is non-negative float, which mean arg_value >= 0. + + Usage: + - number = check_non_negative_float(number) + - number = check_non_negative_float(number, "bias") + """ + return check_number(arg_value, 0, Rel.GE, float, arg_name, prim_name) @staticmethod def check_number(arg_name, arg_value, value, rel, prim_name): @@ -257,16 +302,6 @@ class Validator: raise ValueError(f"For '{prim_name}', padding must be zero when pad_mode is '{pad_mode}'.") return padding - @staticmethod - def check_float_positive(arg_name, arg_value, prim_name): - """Float type judgment.""" - msg_prefix = f'For \'{prim_name}\' the' if prim_name else "The" - if isinstance(arg_value, float): - if arg_value > 0: - return arg_value - raise ValueError(f"{msg_prefix} `{arg_name}` must be positive, but got {arg_value}.") - raise TypeError(f"{msg_prefix} `{arg_name}` must be float.") - @staticmethod def check_subclass(arg_name, type_, template_types, prim_name): """Checks whether some type is subclass of another type""" diff --git a/mindspore/dataset/core/validator_helpers.py b/mindspore/dataset/core/validator_helpers.py index 1bb87e1a668..7bfdccf4277 100644 --- a/mindspore/dataset/core/validator_helpers.py +++ b/mindspore/dataset/core/validator_helpers.py @@ -82,12 +82,6 @@ def check_positive(value, arg_name=""): raise ValueError("Input {0}must be greater than 0.".format(arg_name)) -def check_positive_float(value, arg_name=""): - arg_name = pad_arg_name(arg_name) - type_check(value, (float,), arg_name) - check_positive(value, arg_name) - - def check_2tuple(value, arg_name=""): if not (isinstance(value, tuple) and len(value) == 2): raise ValueError("Value {0}needs to be a 2-tuple.".format(arg_name)) diff --git a/mindspore/nn/dynamic_lr.py b/mindspore/nn/dynamic_lr.py index 3d36305e4ce..71102dadd3f 100644 --- a/mindspore/nn/dynamic_lr.py +++ b/mindspore/nn/dynamic_lr.py @@ -66,9 +66,9 @@ def _check_inputs(learning_rate, decay_rate, total_step, step_per_epoch, decay_e validator.check_positive_int(total_step, 'total_step') validator.check_positive_int(step_per_epoch, 'step_per_epoch') validator.check_positive_int(decay_epoch, 'decay_epoch') - validator.check_float_positive('learning_rate', learning_rate, None) + validator.check_positive_float(learning_rate, 'learning_rate') validator.check_float_legal_value('learning_rate', learning_rate, None) - validator.check_float_positive('decay_rate', decay_rate, None) + validator.check_positive_float(decay_rate, 'decay_rate') validator.check_float_legal_value('decay_rate', decay_rate, None) validator.check_value_type('is_stair', is_stair, [bool], None) @@ -234,7 +234,7 @@ def cosine_decay_lr(min_lr, max_lr, total_step, step_per_epoch, decay_epoch): if not isinstance(min_lr, float): raise TypeError("min_lr must be float.") validator.check_number_range("min_lr", min_lr, 0.0, float("inf"), Rel.INC_LEFT, None) - validator.check_float_positive('max_lr', max_lr, None) + validator.check_positive_float(max_lr, 'max_lr') validator.check_float_legal_value('max_lr', max_lr, None) validator.check_positive_int(total_step, 'total_step') validator.check_positive_int(step_per_epoch, 'step_per_epoch') @@ -299,12 +299,12 @@ def polynomial_decay_lr(learning_rate, end_learning_rate, total_step, step_per_e >>> polynomial_decay_lr(learning_rate, end_learning_rate, total_step, step_per_epoch, decay_epoch, power) [0.1, 0.1, 0.07363961030678928, 0.07363961030678928, 0.01, 0.01] """ - validator.check_float_positive('learning_rate', learning_rate, None) + validator.check_positive_float(learning_rate, 'learning_rate') validator.check_float_legal_value('learning_rate', learning_rate, None) if not isinstance(end_learning_rate, float): raise TypeError("end_learning_rate must be float.") validator.check_number_range("end_learning_rate", end_learning_rate, 0.0, float("inf"), Rel.INC_LEFT, None) - validator.check_float_positive('power', power, None) + validator.check_positive_float(power, 'power') validator.check_float_legal_value('power', power, None) validator.check_positive_int(total_step, 'total_step') validator.check_positive_int(step_per_epoch, 'step_per_epoch') diff --git a/mindspore/nn/layer/image.py b/mindspore/nn/layer/image.py index eaa8810d3ab..8a25869deff 100644 --- a/mindspore/nn/layer/image.py +++ b/mindspore/nn/layer/image.py @@ -221,7 +221,7 @@ class SSIM(Cell): validator.check_number('max_val', max_val, 0.0, Rel.GT, self.cls_name) self.max_val = max_val self.filter_size = validator.check_integer('filter_size', filter_size, 1, Rel.GE, self.cls_name) - self.filter_sigma = validator.check_float_positive('filter_sigma', filter_sigma, self.cls_name) + self.filter_sigma = validator.check_positive_float(filter_sigma, 'filter_sigma', self.cls_name) self.k1 = validator.check_value_type('k1', k1, [float], self.cls_name) self.k2 = validator.check_value_type('k2', k2, [float], self.cls_name) window = _create_window(filter_size, filter_sigma) @@ -299,7 +299,7 @@ class MSSSIM(Cell): self.max_val = max_val validator.check_value_type('power_factors', power_factors, [tuple, list], self.cls_name) self.filter_size = validator.check_integer('filter_size', filter_size, 1, Rel.GE, self.cls_name) - self.filter_sigma = validator.check_float_positive('filter_sigma', filter_sigma, self.cls_name) + self.filter_sigma = validator.check_positive_float(filter_sigma, 'filter_sigma', self.cls_name) self.k1 = validator.check_value_type('k1', k1, [float], self.cls_name) self.k2 = validator.check_value_type('k2', k2, [float], self.cls_name) window = _create_window(filter_size, filter_sigma) diff --git a/mindspore/nn/learning_rate_schedule.py b/mindspore/nn/learning_rate_schedule.py index 0cddd2c6c5e..6bcc7a8265e 100644 --- a/mindspore/nn/learning_rate_schedule.py +++ b/mindspore/nn/learning_rate_schedule.py @@ -45,9 +45,9 @@ class LearningRateSchedule(Cell): def _check_inputs(learning_rate, decay_rate, decay_steps, is_stair, cls_name): validator.check_positive_int(decay_steps, 'decay_steps', cls_name) - validator.check_float_positive('learning_rate', learning_rate, cls_name) + validator.check_positive_float(learning_rate, 'learning_rate', cls_name) validator.check_float_legal_value('learning_rate', learning_rate, cls_name) - validator.check_float_positive('decay_rate', decay_rate, cls_name) + validator.check_positive_float(decay_rate, 'decay_rate', cls_name) validator.check_float_legal_value('decay_rate', decay_rate, cls_name) validator.check_value_type('is_stair', is_stair, [bool], cls_name) @@ -255,7 +255,7 @@ class CosineDecayLR(LearningRateSchedule): if not isinstance(min_lr, float): raise TypeError("min_lr must be float.") validator.check_number_range("min_lr", min_lr, 0.0, float("inf"), Rel.INC_LEFT, self.cls_name) - validator.check_float_positive('max_lr', max_lr, self.cls_name) + validator.check_positive_float(max_lr, 'max_lr', self.cls_name) validator.check_float_legal_value('max_lr', max_lr, self.cls_name) validator.check_positive_int(decay_steps, "decay_steps", self.cls_name) if min_lr >= max_lr: @@ -318,7 +318,7 @@ class PolynomialDecayLR(LearningRateSchedule): """ def __init__(self, learning_rate, end_learning_rate, decay_steps, power, update_decay_steps=False): super(PolynomialDecayLR, self).__init__() - validator.check_float_positive('learning_rate', learning_rate, None) + validator.check_positive_float(learning_rate, 'learning_rate') validator.check_float_legal_value('learning_rate', learning_rate, None) if not isinstance(end_learning_rate, float): raise TypeError("end_learning_rate must be float.") @@ -326,7 +326,7 @@ class PolynomialDecayLR(LearningRateSchedule): self.cls_name) validator.check_positive_int(decay_steps, 'decay_steps', self.cls_name) validator.check_value_type('update_decay_steps', update_decay_steps, [bool], self.cls_name) - validator.check_float_positive('power', power, self.cls_name) + validator.check_positive_float(power, 'power', self.cls_name) validator.check_float_legal_value('power', power, self.cls_name) self.decay_steps = decay_steps diff --git a/mindspore/ops/operations/_quant_ops.py b/mindspore/ops/operations/_quant_ops.py index 5b07ce460a8..feb7d066b56 100644 --- a/mindspore/ops/operations/_quant_ops.py +++ b/mindspore/ops/operations/_quant_ops.py @@ -503,7 +503,7 @@ class BatchNormFold(PrimitiveWithInfer): def __init__(self, momentum=0.9, epsilon=1e-5, is_training=True, freeze_bn=0): """Initialize batch norm fold layer""" self.momentum = validator.check_number_range('momentum', momentum, 0, 1, Rel.INC_BOTH, self.name) - self.epsilon = validator.check_float_positive('epsilon', epsilon, self.name) + self.epsilon = validator.check_positive_float(epsilon, 'epsilon', self.name) self.is_training = validator.check_value_type('is_training', is_training, (bool,), self.name) self.freeze_bn = validator.check_value_type('freeze_bn', freeze_bn, (int,), self.name) @@ -546,7 +546,7 @@ class BatchNormFoldGrad(PrimitiveWithInfer): """Initialize BatchNormGrad layer""" self.is_training = validator.check_value_type('is_training', is_training, (bool,), self.name) self.freeze_bn = validator.check_value_type('freeze_bn', freeze_bn, (int,), self.name) - self.epsilon = validator.check_float_positive('epsilon', epsilon, self.name) + self.epsilon = validator.check_positive_float(epsilon, 'epsilon', self.name) self.init_prim_io_names(inputs=['d_batch_mean', 'd_batch_std', 'x', 'batch_mean', 'batch_std', 'global_step'], outputs=['dx']) @@ -814,7 +814,7 @@ class BatchNormFoldD(PrimitiveWithInfer): """Initialize _BatchNormFold layer""" from mindspore.ops._op_impl._custom_op import batchnorm_fold self.momentum = validator.check_number_range('momentum', momentum, 0, 1, Rel.INC_BOTH, self.name) - self.epsilon = validator.check_float_positive('epsilon', epsilon, self.name) + self.epsilon = validator.check_positive_float(epsilon, 'epsilon', self.name) self.is_training = validator.check_value_type('is_training', is_training, (bool,), self.name) self.freeze_bn = validator.check_value_type('freeze_bn', freeze_bn, (int,), self.name) self.data_format = "NCHW" @@ -842,7 +842,7 @@ class BatchNormFoldGradD(PrimitiveWithInfer): def __init__(self, epsilon=1e-5, is_training=True, freeze_bn=0): """Initialize _BatchNormFoldGrad layer""" from mindspore.ops._op_impl._custom_op import batchnorm_fold_grad - self.epsilon = validator.check_float_positive('epsilon', epsilon, self.name) + self.epsilon = validator.check_positive_float(epsilon, 'epsilon', self.name) self.is_training = validator.check_value_type('is_training', is_training, (bool,), self.name) self.freeze_bn = validator.check_value_type('freeze_bn', freeze_bn, (int,), self.name) self.init_prim_io_names(inputs=['d_batch_mean', 'd_batch_std', 'x', 'batch_mean', 'batch_std'], diff --git a/mindspore/ops/operations/math_ops.py b/mindspore/ops/operations/math_ops.py index 3b26079fdb9..e963e28cb78 100644 --- a/mindspore/ops/operations/math_ops.py +++ b/mindspore/ops/operations/math_ops.py @@ -3560,7 +3560,7 @@ class IFMR(PrimitiveWithInfer): validator.check_value_type("max_percentile", max_percentile, [float], self.name) validator.check_value_type("search_range", search_range, [list, tuple], self.name) for item in search_range: - validator.check_float_positive("item of search_range", item, self.name) + validator.check_positive_float(item, "item of search_range", self.name) validator.check('search_range[1]', search_range[1], 'search_range[0]', search_range[0], Rel.GE, self.name) validator.check_value_type("search_step", search_step, [float], self.name) validator.check_value_type("offset_flag", with_offset, [bool], self.name)