forked from mindspore-Ecosystem/mindspore
!7169 [ME] format `check_type` and `check_type_name`
Merge pull request !7169 from chenzhongming/zomi_master
This commit is contained in:
commit
0d9fece038
|
@ -92,7 +92,7 @@ rel_strs = {
|
|||
}
|
||||
|
||||
|
||||
def _check_integer(arg_value, value, rel, arg_name=None, prim_name=None):
|
||||
def check_number(arg_value, value, rel, arg_type=int, arg_name=None, prim_name=None):
|
||||
"""
|
||||
Check argument integer.
|
||||
|
||||
|
@ -100,13 +100,13 @@ def _check_integer(arg_value, value, rel, arg_name=None, prim_name=None):
|
|||
- number = check_integer(number, 0, Rel.GE, "number", None) # number >= 0
|
||||
"""
|
||||
rel_fn = Rel.get_fns(rel)
|
||||
type_mismatch = not isinstance(arg_value, int) or isinstance(arg_value, bool)
|
||||
type_mismatch = not isinstance(arg_value, arg_type) or isinstance(arg_value, bool)
|
||||
type_except = TypeError if type_mismatch else ValueError
|
||||
if type_mismatch or not rel_fn(arg_value, value):
|
||||
rel_str = Rel.get_strs(rel).format(value)
|
||||
arg_name = arg_name if arg_name else "parameter"
|
||||
msg_prefix = f'For \'{prim_name}\' the' if prim_name else "The"
|
||||
raise type_except(f'{msg_prefix} `{arg_name}` should be an int and must {rel_str}, but got `{arg_value}`'
|
||||
raise type_except(f'{msg_prefix} `{arg_name}` should be an {arg_type} and must {rel_str}, but got `{arg_value}`'
|
||||
f' with type `{type(arg_value).__name__}`.')
|
||||
return arg_value
|
||||
|
||||
|
@ -149,7 +149,7 @@ class Validator:
|
|||
- number = check_positive_int(number)
|
||||
- number = check_positive_int(number, "bias")
|
||||
"""
|
||||
return _check_integer(arg_value, 0, Rel.GT, arg_name, prim_name)
|
||||
return check_number(arg_value, 0, Rel.GT, int, arg_name, prim_name)
|
||||
|
||||
@staticmethod
|
||||
def check_negative_int(arg_value, arg_name=None, prim_name=None):
|
||||
|
@ -160,7 +160,7 @@ class Validator:
|
|||
- number = check_negative_int(number)
|
||||
- number = check_negative_int(number, "bias")
|
||||
"""
|
||||
return _check_integer(arg_value, 0, Rel.LT, arg_name, prim_name)
|
||||
return check_number(arg_value, 0, Rel.LT, int, arg_name, prim_name)
|
||||
|
||||
@staticmethod
|
||||
def check_non_positive_int(arg_value, arg_name=None, prim_name=None):
|
||||
|
@ -171,7 +171,7 @@ class Validator:
|
|||
- number = check_non_positive_int(number)
|
||||
- number = check_non_positive_int(number, "bias")
|
||||
"""
|
||||
return _check_integer(arg_value, 0, Rel.LE, arg_name, prim_name)
|
||||
return check_number(arg_value, 0, Rel.LE, int, arg_name, prim_name)
|
||||
|
||||
@staticmethod
|
||||
def check_non_negative_int(arg_value, arg_name=None, prim_name=None):
|
||||
|
@ -182,7 +182,52 @@ class Validator:
|
|||
- number = check_non_negative_int(number)
|
||||
- number = check_non_negative_int(number, "bias")
|
||||
"""
|
||||
return _check_integer(arg_value, 0, Rel.GE, arg_name, prim_name)
|
||||
return check_number(arg_value, 0, Rel.GE, int, arg_name, prim_name)
|
||||
|
||||
@staticmethod
|
||||
def check_positive_float(arg_value, arg_name=None, prim_name=None):
|
||||
"""
|
||||
Check argument is positive float, which mean arg_value > 0.
|
||||
|
||||
Usage:
|
||||
- number = check_positive_float(number)
|
||||
- number = check_positive_float(number, "bias")
|
||||
- number = check_positive_float(number, "bias", "bias_class")
|
||||
"""
|
||||
return check_number(arg_value, 0, Rel.GT, float, arg_name, prim_name)
|
||||
|
||||
@staticmethod
|
||||
def check_negative_float(arg_value, arg_name=None, prim_name=None):
|
||||
"""
|
||||
Check argument is negative float, which mean arg_value < 0.
|
||||
|
||||
Usage:
|
||||
- number = check_negative_float(number)
|
||||
- number = check_negative_float(number, "bias")
|
||||
"""
|
||||
return check_number(arg_value, 0, Rel.LT, float, arg_name, prim_name)
|
||||
|
||||
@staticmethod
|
||||
def check_non_positive_float(arg_value, arg_name=None, prim_name=None):
|
||||
"""
|
||||
Check argument is non-negative float, which mean arg_value <= 0.
|
||||
|
||||
Usage:
|
||||
- number = check_non_positive_float(number)
|
||||
- number = check_non_positive_float(number, "bias")
|
||||
"""
|
||||
return check_number(arg_value, 0, Rel.LE, float, arg_name, prim_name)
|
||||
|
||||
@staticmethod
|
||||
def check_non_negative_float(arg_value, arg_name=None, prim_name=None):
|
||||
"""
|
||||
Check argument is non-negative float, which mean arg_value >= 0.
|
||||
|
||||
Usage:
|
||||
- number = check_non_negative_float(number)
|
||||
- number = check_non_negative_float(number, "bias")
|
||||
"""
|
||||
return check_number(arg_value, 0, Rel.GE, float, arg_name, prim_name)
|
||||
|
||||
@staticmethod
|
||||
def check_number(arg_name, arg_value, value, rel, prim_name):
|
||||
|
@ -257,16 +302,6 @@ class Validator:
|
|||
raise ValueError(f"For '{prim_name}', padding must be zero when pad_mode is '{pad_mode}'.")
|
||||
return padding
|
||||
|
||||
@staticmethod
|
||||
def check_float_positive(arg_name, arg_value, prim_name):
|
||||
"""Float type judgment."""
|
||||
msg_prefix = f'For \'{prim_name}\' the' if prim_name else "The"
|
||||
if isinstance(arg_value, float):
|
||||
if arg_value > 0:
|
||||
return arg_value
|
||||
raise ValueError(f"{msg_prefix} `{arg_name}` must be positive, but got {arg_value}.")
|
||||
raise TypeError(f"{msg_prefix} `{arg_name}` must be float.")
|
||||
|
||||
@staticmethod
|
||||
def check_subclass(arg_name, type_, template_types, prim_name):
|
||||
"""Checks whether some type is subclass of another type"""
|
||||
|
|
|
@ -82,12 +82,6 @@ def check_positive(value, arg_name=""):
|
|||
raise ValueError("Input {0}must be greater than 0.".format(arg_name))
|
||||
|
||||
|
||||
def check_positive_float(value, arg_name=""):
|
||||
arg_name = pad_arg_name(arg_name)
|
||||
type_check(value, (float,), arg_name)
|
||||
check_positive(value, arg_name)
|
||||
|
||||
|
||||
def check_2tuple(value, arg_name=""):
|
||||
if not (isinstance(value, tuple) and len(value) == 2):
|
||||
raise ValueError("Value {0}needs to be a 2-tuple.".format(arg_name))
|
||||
|
|
|
@ -66,9 +66,9 @@ def _check_inputs(learning_rate, decay_rate, total_step, step_per_epoch, decay_e
|
|||
validator.check_positive_int(total_step, 'total_step')
|
||||
validator.check_positive_int(step_per_epoch, 'step_per_epoch')
|
||||
validator.check_positive_int(decay_epoch, 'decay_epoch')
|
||||
validator.check_float_positive('learning_rate', learning_rate, None)
|
||||
validator.check_positive_float(learning_rate, 'learning_rate')
|
||||
validator.check_float_legal_value('learning_rate', learning_rate, None)
|
||||
validator.check_float_positive('decay_rate', decay_rate, None)
|
||||
validator.check_positive_float(decay_rate, 'decay_rate')
|
||||
validator.check_float_legal_value('decay_rate', decay_rate, None)
|
||||
validator.check_value_type('is_stair', is_stair, [bool], None)
|
||||
|
||||
|
@ -234,7 +234,7 @@ def cosine_decay_lr(min_lr, max_lr, total_step, step_per_epoch, decay_epoch):
|
|||
if not isinstance(min_lr, float):
|
||||
raise TypeError("min_lr must be float.")
|
||||
validator.check_number_range("min_lr", min_lr, 0.0, float("inf"), Rel.INC_LEFT, None)
|
||||
validator.check_float_positive('max_lr', max_lr, None)
|
||||
validator.check_positive_float(max_lr, 'max_lr')
|
||||
validator.check_float_legal_value('max_lr', max_lr, None)
|
||||
validator.check_positive_int(total_step, 'total_step')
|
||||
validator.check_positive_int(step_per_epoch, 'step_per_epoch')
|
||||
|
@ -299,12 +299,12 @@ def polynomial_decay_lr(learning_rate, end_learning_rate, total_step, step_per_e
|
|||
>>> polynomial_decay_lr(learning_rate, end_learning_rate, total_step, step_per_epoch, decay_epoch, power)
|
||||
[0.1, 0.1, 0.07363961030678928, 0.07363961030678928, 0.01, 0.01]
|
||||
"""
|
||||
validator.check_float_positive('learning_rate', learning_rate, None)
|
||||
validator.check_positive_float(learning_rate, 'learning_rate')
|
||||
validator.check_float_legal_value('learning_rate', learning_rate, None)
|
||||
if not isinstance(end_learning_rate, float):
|
||||
raise TypeError("end_learning_rate must be float.")
|
||||
validator.check_number_range("end_learning_rate", end_learning_rate, 0.0, float("inf"), Rel.INC_LEFT, None)
|
||||
validator.check_float_positive('power', power, None)
|
||||
validator.check_positive_float(power, 'power')
|
||||
validator.check_float_legal_value('power', power, None)
|
||||
validator.check_positive_int(total_step, 'total_step')
|
||||
validator.check_positive_int(step_per_epoch, 'step_per_epoch')
|
||||
|
|
|
@ -221,7 +221,7 @@ class SSIM(Cell):
|
|||
validator.check_number('max_val', max_val, 0.0, Rel.GT, self.cls_name)
|
||||
self.max_val = max_val
|
||||
self.filter_size = validator.check_integer('filter_size', filter_size, 1, Rel.GE, self.cls_name)
|
||||
self.filter_sigma = validator.check_float_positive('filter_sigma', filter_sigma, self.cls_name)
|
||||
self.filter_sigma = validator.check_positive_float(filter_sigma, 'filter_sigma', self.cls_name)
|
||||
self.k1 = validator.check_value_type('k1', k1, [float], self.cls_name)
|
||||
self.k2 = validator.check_value_type('k2', k2, [float], self.cls_name)
|
||||
window = _create_window(filter_size, filter_sigma)
|
||||
|
@ -299,7 +299,7 @@ class MSSSIM(Cell):
|
|||
self.max_val = max_val
|
||||
validator.check_value_type('power_factors', power_factors, [tuple, list], self.cls_name)
|
||||
self.filter_size = validator.check_integer('filter_size', filter_size, 1, Rel.GE, self.cls_name)
|
||||
self.filter_sigma = validator.check_float_positive('filter_sigma', filter_sigma, self.cls_name)
|
||||
self.filter_sigma = validator.check_positive_float(filter_sigma, 'filter_sigma', self.cls_name)
|
||||
self.k1 = validator.check_value_type('k1', k1, [float], self.cls_name)
|
||||
self.k2 = validator.check_value_type('k2', k2, [float], self.cls_name)
|
||||
window = _create_window(filter_size, filter_sigma)
|
||||
|
|
|
@ -45,9 +45,9 @@ class LearningRateSchedule(Cell):
|
|||
|
||||
def _check_inputs(learning_rate, decay_rate, decay_steps, is_stair, cls_name):
|
||||
validator.check_positive_int(decay_steps, 'decay_steps', cls_name)
|
||||
validator.check_float_positive('learning_rate', learning_rate, cls_name)
|
||||
validator.check_positive_float(learning_rate, 'learning_rate', cls_name)
|
||||
validator.check_float_legal_value('learning_rate', learning_rate, cls_name)
|
||||
validator.check_float_positive('decay_rate', decay_rate, cls_name)
|
||||
validator.check_positive_float(decay_rate, 'decay_rate', cls_name)
|
||||
validator.check_float_legal_value('decay_rate', decay_rate, cls_name)
|
||||
validator.check_value_type('is_stair', is_stair, [bool], cls_name)
|
||||
|
||||
|
@ -255,7 +255,7 @@ class CosineDecayLR(LearningRateSchedule):
|
|||
if not isinstance(min_lr, float):
|
||||
raise TypeError("min_lr must be float.")
|
||||
validator.check_number_range("min_lr", min_lr, 0.0, float("inf"), Rel.INC_LEFT, self.cls_name)
|
||||
validator.check_float_positive('max_lr', max_lr, self.cls_name)
|
||||
validator.check_positive_float(max_lr, 'max_lr', self.cls_name)
|
||||
validator.check_float_legal_value('max_lr', max_lr, self.cls_name)
|
||||
validator.check_positive_int(decay_steps, "decay_steps", self.cls_name)
|
||||
if min_lr >= max_lr:
|
||||
|
@ -318,7 +318,7 @@ class PolynomialDecayLR(LearningRateSchedule):
|
|||
"""
|
||||
def __init__(self, learning_rate, end_learning_rate, decay_steps, power, update_decay_steps=False):
|
||||
super(PolynomialDecayLR, self).__init__()
|
||||
validator.check_float_positive('learning_rate', learning_rate, None)
|
||||
validator.check_positive_float(learning_rate, 'learning_rate')
|
||||
validator.check_float_legal_value('learning_rate', learning_rate, None)
|
||||
if not isinstance(end_learning_rate, float):
|
||||
raise TypeError("end_learning_rate must be float.")
|
||||
|
@ -326,7 +326,7 @@ class PolynomialDecayLR(LearningRateSchedule):
|
|||
self.cls_name)
|
||||
validator.check_positive_int(decay_steps, 'decay_steps', self.cls_name)
|
||||
validator.check_value_type('update_decay_steps', update_decay_steps, [bool], self.cls_name)
|
||||
validator.check_float_positive('power', power, self.cls_name)
|
||||
validator.check_positive_float(power, 'power', self.cls_name)
|
||||
validator.check_float_legal_value('power', power, self.cls_name)
|
||||
|
||||
self.decay_steps = decay_steps
|
||||
|
|
|
@ -503,7 +503,7 @@ class BatchNormFold(PrimitiveWithInfer):
|
|||
def __init__(self, momentum=0.9, epsilon=1e-5, is_training=True, freeze_bn=0):
|
||||
"""Initialize batch norm fold layer"""
|
||||
self.momentum = validator.check_number_range('momentum', momentum, 0, 1, Rel.INC_BOTH, self.name)
|
||||
self.epsilon = validator.check_float_positive('epsilon', epsilon, self.name)
|
||||
self.epsilon = validator.check_positive_float(epsilon, 'epsilon', self.name)
|
||||
self.is_training = validator.check_value_type('is_training', is_training, (bool,), self.name)
|
||||
self.freeze_bn = validator.check_value_type('freeze_bn', freeze_bn, (int,), self.name)
|
||||
|
||||
|
@ -546,7 +546,7 @@ class BatchNormFoldGrad(PrimitiveWithInfer):
|
|||
"""Initialize BatchNormGrad layer"""
|
||||
self.is_training = validator.check_value_type('is_training', is_training, (bool,), self.name)
|
||||
self.freeze_bn = validator.check_value_type('freeze_bn', freeze_bn, (int,), self.name)
|
||||
self.epsilon = validator.check_float_positive('epsilon', epsilon, self.name)
|
||||
self.epsilon = validator.check_positive_float(epsilon, 'epsilon', self.name)
|
||||
self.init_prim_io_names(inputs=['d_batch_mean', 'd_batch_std', 'x', 'batch_mean', 'batch_std', 'global_step'],
|
||||
outputs=['dx'])
|
||||
|
||||
|
@ -814,7 +814,7 @@ class BatchNormFoldD(PrimitiveWithInfer):
|
|||
"""Initialize _BatchNormFold layer"""
|
||||
from mindspore.ops._op_impl._custom_op import batchnorm_fold
|
||||
self.momentum = validator.check_number_range('momentum', momentum, 0, 1, Rel.INC_BOTH, self.name)
|
||||
self.epsilon = validator.check_float_positive('epsilon', epsilon, self.name)
|
||||
self.epsilon = validator.check_positive_float(epsilon, 'epsilon', self.name)
|
||||
self.is_training = validator.check_value_type('is_training', is_training, (bool,), self.name)
|
||||
self.freeze_bn = validator.check_value_type('freeze_bn', freeze_bn, (int,), self.name)
|
||||
self.data_format = "NCHW"
|
||||
|
@ -842,7 +842,7 @@ class BatchNormFoldGradD(PrimitiveWithInfer):
|
|||
def __init__(self, epsilon=1e-5, is_training=True, freeze_bn=0):
|
||||
"""Initialize _BatchNormFoldGrad layer"""
|
||||
from mindspore.ops._op_impl._custom_op import batchnorm_fold_grad
|
||||
self.epsilon = validator.check_float_positive('epsilon', epsilon, self.name)
|
||||
self.epsilon = validator.check_positive_float(epsilon, 'epsilon', self.name)
|
||||
self.is_training = validator.check_value_type('is_training', is_training, (bool,), self.name)
|
||||
self.freeze_bn = validator.check_value_type('freeze_bn', freeze_bn, (int,), self.name)
|
||||
self.init_prim_io_names(inputs=['d_batch_mean', 'd_batch_std', 'x', 'batch_mean', 'batch_std'],
|
||||
|
|
|
@ -3560,7 +3560,7 @@ class IFMR(PrimitiveWithInfer):
|
|||
validator.check_value_type("max_percentile", max_percentile, [float], self.name)
|
||||
validator.check_value_type("search_range", search_range, [list, tuple], self.name)
|
||||
for item in search_range:
|
||||
validator.check_float_positive("item of search_range", item, self.name)
|
||||
validator.check_positive_float(item, "item of search_range", self.name)
|
||||
validator.check('search_range[1]', search_range[1], 'search_range[0]', search_range[0], Rel.GE, self.name)
|
||||
validator.check_value_type("search_step", search_step, [float], self.name)
|
||||
validator.check_value_type("offset_flag", with_offset, [bool], self.name)
|
||||
|
|
Loading…
Reference in New Issue