!7169 [ME] format `check_type` and `check_type_name`

Merge pull request !7169 from chenzhongming/zomi_master
mindspore-ci-bot 2020-10-13 16:17:07 +08:00 committed by Gitee
commit 0d9fece038
7 changed files with 69 additions and 40 deletions

View File

@@ -92,7 +92,7 @@ rel_strs = {
}
def _check_integer(arg_value, value, rel, arg_name=None, prim_name=None):
def check_number(arg_value, value, rel, arg_type=int, arg_name=None, prim_name=None):
"""
Check argument integer.
@@ -100,13 +100,13 @@ def _check_integer(arg_value, value, rel, arg_name=None, prim_name=None):
- number = check_integer(number, 0, Rel.GE, "number", None) # number >= 0
"""
rel_fn = Rel.get_fns(rel)
type_mismatch = not isinstance(arg_value, int) or isinstance(arg_value, bool)
type_mismatch = not isinstance(arg_value, arg_type) or isinstance(arg_value, bool)
type_except = TypeError if type_mismatch else ValueError
if type_mismatch or not rel_fn(arg_value, value):
rel_str = Rel.get_strs(rel).format(value)
arg_name = arg_name if arg_name else "parameter"
msg_prefix = f'For \'{prim_name}\' the' if prim_name else "The"
raise type_except(f'{msg_prefix} `{arg_name}` should be an int and must {rel_str}, but got `{arg_value}`'
raise type_except(f'{msg_prefix} `{arg_name}` should be an {arg_type} and must {rel_str}, but got `{arg_value}`'
f' with type `{type(arg_value).__name__}`.')
return arg_value
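The hunk above turns the old integer-only `_check_integer` into `check_number` with an `arg_type` parameter, so one helper backs both the `*_int` and the new `*_float` validators. A minimal, self-contained sketch of that behavior (illustrative only, not MindSpore's code; the `_RELS` table stands in for MindSpore's `Rel` enum and `rel_strs`):

```python
# Minimal sketch of the generalized numeric check; _RELS is a stand-in
# for MindSpore's Rel enum / rel_strs table.
import operator

_RELS = {
    "GT": (operator.gt, "> {}"),
    "GE": (operator.ge, ">= {}"),
    "LT": (operator.lt, "< {}"),
    "LE": (operator.le, "<= {}"),
}

def check_number(arg_value, value, rel, arg_type=int, arg_name=None, prim_name=None):
    rel_fn, rel_fmt = _RELS[rel]
    # bool is a subclass of int, so it must be rejected explicitly.
    type_mismatch = not isinstance(arg_value, arg_type) or isinstance(arg_value, bool)
    excp = TypeError if type_mismatch else ValueError
    if type_mismatch or not rel_fn(arg_value, value):
        prefix = f"For '{prim_name}' the" if prim_name else "The"
        raise excp(f"{prefix} `{arg_name or 'parameter'}` should be an {arg_type.__name__} "
                   f"and must be {rel_fmt.format(value)}, but got `{arg_value}`.")
    return arg_value

check_number(5, 0, "GT")             # ok: 5 > 0 and is an int
check_number(0.1, 0, "GT", float)    # ok: 0.1 > 0 and is a float
# check_number(True, 0, "GE")        # TypeError: bool is rejected even though it subclasses int
# check_number(-1.0, 0, "GT", float) # ValueError: -1.0 is not > 0
```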
@@ -149,7 +149,7 @@ class Validator:
- number = check_positive_int(number)
- number = check_positive_int(number, "bias")
"""
return _check_integer(arg_value, 0, Rel.GT, arg_name, prim_name)
return check_number(arg_value, 0, Rel.GT, int, arg_name, prim_name)
@staticmethod
def check_negative_int(arg_value, arg_name=None, prim_name=None):
@@ -160,7 +160,7 @@ class Validator:
- number = check_negative_int(number)
- number = check_negative_int(number, "bias")
"""
return _check_integer(arg_value, 0, Rel.LT, arg_name, prim_name)
return check_number(arg_value, 0, Rel.LT, int, arg_name, prim_name)
@staticmethod
def check_non_positive_int(arg_value, arg_name=None, prim_name=None):
@@ -171,7 +171,7 @@ class Validator:
- number = check_non_positive_int(number)
- number = check_non_positive_int(number, "bias")
"""
return _check_integer(arg_value, 0, Rel.LE, arg_name, prim_name)
return check_number(arg_value, 0, Rel.LE, int, arg_name, prim_name)
@staticmethod
def check_non_negative_int(arg_value, arg_name=None, prim_name=None):
@@ -182,7 +182,52 @@ class Validator:
- number = check_non_negative_int(number)
- number = check_non_negative_int(number, "bias")
"""
return _check_integer(arg_value, 0, Rel.GE, arg_name, prim_name)
return check_number(arg_value, 0, Rel.GE, int, arg_name, prim_name)
@staticmethod
def check_positive_float(arg_value, arg_name=None, prim_name=None):
"""
Check argument is positive float, which means arg_value > 0.
Usage:
- number = check_positive_float(number)
- number = check_positive_float(number, "bias")
- number = check_positive_float(number, "bias", "bias_class")
"""
return check_number(arg_value, 0, Rel.GT, float, arg_name, prim_name)
@staticmethod
def check_negative_float(arg_value, arg_name=None, prim_name=None):
"""
Check argument is negative float, which means arg_value < 0.
Usage:
- number = check_negative_float(number)
- number = check_negative_float(number, "bias")
"""
return check_number(arg_value, 0, Rel.LT, float, arg_name, prim_name)
@staticmethod
def check_non_positive_float(arg_value, arg_name=None, prim_name=None):
"""
Check argument is non-positive float, which means arg_value <= 0.
Usage:
- number = check_non_positive_float(number)
- number = check_non_positive_float(number, "bias")
"""
return check_number(arg_value, 0, Rel.LE, float, arg_name, prim_name)
@staticmethod
def check_non_negative_float(arg_value, arg_name=None, prim_name=None):
"""
Check argument is non-negative float, which means arg_value >= 0.
Usage:
- number = check_non_negative_float(number)
- number = check_non_negative_float(number, "bias")
"""
return check_number(arg_value, 0, Rel.GE, float, arg_name, prim_name)
@staticmethod
def check_number(arg_name, arg_value, value, rel, prim_name):
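A hedged usage sketch of the new float validators added above, together with the existing int validators they mirror. It assumes they stay reachable through the `Validator` class via the usual MindSpore import path; note the argument order is value first, then the optional argument name and primitive/class name:

```python
# Usage sketch; assumes a MindSpore build where these validators exist with the
# signatures shown in this diff and the usual import path.
from mindspore._checkparam import Validator as validator

lr = validator.check_positive_float(0.01, "learning_rate")       # returns 0.01
steps = validator.check_positive_int(1000, "total_step")         # returns 1000
# validator.check_positive_float(1, "learning_rate")             # TypeError: int, not float
# validator.check_positive_float(-0.1, "learning_rate", "Adam")  # ValueError: -0.1 is not > 0
```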
@@ -257,16 +302,6 @@ class Validator:
raise ValueError(f"For '{prim_name}', padding must be zero when pad_mode is '{pad_mode}'.")
return padding
@staticmethod
def check_float_positive(arg_name, arg_value, prim_name):
"""Float type judgment."""
msg_prefix = f'For \'{prim_name}\' the' if prim_name else "The"
if isinstance(arg_value, float):
if arg_value > 0:
return arg_value
raise ValueError(f"{msg_prefix} `{arg_name}` must be positive, but got {arg_value}.")
raise TypeError(f"{msg_prefix} `{arg_name}` must be float.")
@staticmethod
def check_subclass(arg_name, type_, template_types, prim_name):
"""Checks whether some type is subclass of another type"""

View File

@@ -82,12 +82,6 @@ def check_positive(value, arg_name=""):
raise ValueError("Input {0}must be greater than 0.".format(arg_name))
def check_positive_float(value, arg_name=""):
arg_name = pad_arg_name(arg_name)
type_check(value, (float,), arg_name)
check_positive(value, arg_name)
def check_2tuple(value, arg_name=""):
if not (isinstance(value, tuple) and len(value) == 2):
raise ValueError("Value {0}needs to be a 2-tuple.".format(arg_name))

View File

@@ -66,9 +66,9 @@ def _check_inputs(learning_rate, decay_rate, total_step, step_per_epoch, decay_e
validator.check_positive_int(total_step, 'total_step')
validator.check_positive_int(step_per_epoch, 'step_per_epoch')
validator.check_positive_int(decay_epoch, 'decay_epoch')
validator.check_float_positive('learning_rate', learning_rate, None)
validator.check_positive_float(learning_rate, 'learning_rate')
validator.check_float_legal_value('learning_rate', learning_rate, None)
validator.check_float_positive('decay_rate', decay_rate, None)
validator.check_positive_float(decay_rate, 'decay_rate')
validator.check_float_legal_value('decay_rate', decay_rate, None)
validator.check_value_type('is_stair', is_stair, [bool], None)
@@ -234,7 +234,7 @@ def cosine_decay_lr(min_lr, max_lr, total_step, step_per_epoch, decay_epoch):
if not isinstance(min_lr, float):
raise TypeError("min_lr must be float.")
validator.check_number_range("min_lr", min_lr, 0.0, float("inf"), Rel.INC_LEFT, None)
validator.check_float_positive('max_lr', max_lr, None)
validator.check_positive_float(max_lr, 'max_lr')
validator.check_float_legal_value('max_lr', max_lr, None)
validator.check_positive_int(total_step, 'total_step')
validator.check_positive_int(step_per_epoch, 'step_per_epoch')
@@ -299,12 +299,12 @@ def polynomial_decay_lr(learning_rate, end_learning_rate, total_step, step_per_e
>>> polynomial_decay_lr(learning_rate, end_learning_rate, total_step, step_per_epoch, decay_epoch, power)
[0.1, 0.1, 0.07363961030678928, 0.07363961030678928, 0.01, 0.01]
"""
validator.check_float_positive('learning_rate', learning_rate, None)
validator.check_positive_float(learning_rate, 'learning_rate')
validator.check_float_legal_value('learning_rate', learning_rate, None)
if not isinstance(end_learning_rate, float):
raise TypeError("end_learning_rate must be float.")
validator.check_number_range("end_learning_rate", end_learning_rate, 0.0, float("inf"), Rel.INC_LEFT, None)
validator.check_float_positive('power', power, None)
validator.check_positive_float(power, 'power')
validator.check_float_legal_value('power', power, None)
validator.check_positive_int(total_step, 'total_step')
validator.check_positive_int(step_per_epoch, 'step_per_epoch')

View File

@@ -221,7 +221,7 @@ class SSIM(Cell):
validator.check_number('max_val', max_val, 0.0, Rel.GT, self.cls_name)
self.max_val = max_val
self.filter_size = validator.check_integer('filter_size', filter_size, 1, Rel.GE, self.cls_name)
self.filter_sigma = validator.check_float_positive('filter_sigma', filter_sigma, self.cls_name)
self.filter_sigma = validator.check_positive_float(filter_sigma, 'filter_sigma', self.cls_name)
self.k1 = validator.check_value_type('k1', k1, [float], self.cls_name)
self.k2 = validator.check_value_type('k2', k2, [float], self.cls_name)
window = _create_window(filter_size, filter_sigma)
@@ -299,7 +299,7 @@ class MSSSIM(Cell):
self.max_val = max_val
validator.check_value_type('power_factors', power_factors, [tuple, list], self.cls_name)
self.filter_size = validator.check_integer('filter_size', filter_size, 1, Rel.GE, self.cls_name)
self.filter_sigma = validator.check_float_positive('filter_sigma', filter_sigma, self.cls_name)
self.filter_sigma = validator.check_positive_float(filter_sigma, 'filter_sigma', self.cls_name)
self.k1 = validator.check_value_type('k1', k1, [float], self.cls_name)
self.k2 = validator.check_value_type('k2', k2, [float], self.cls_name)
window = _create_window(filter_size, filter_sigma)

View File

@@ -45,9 +45,9 @@ class LearningRateSchedule(Cell):
def _check_inputs(learning_rate, decay_rate, decay_steps, is_stair, cls_name):
validator.check_positive_int(decay_steps, 'decay_steps', cls_name)
validator.check_float_positive('learning_rate', learning_rate, cls_name)
validator.check_positive_float(learning_rate, 'learning_rate', cls_name)
validator.check_float_legal_value('learning_rate', learning_rate, cls_name)
validator.check_float_positive('decay_rate', decay_rate, cls_name)
validator.check_positive_float(decay_rate, 'decay_rate', cls_name)
validator.check_float_legal_value('decay_rate', decay_rate, cls_name)
validator.check_value_type('is_stair', is_stair, [bool], cls_name)
@@ -255,7 +255,7 @@ class CosineDecayLR(LearningRateSchedule):
if not isinstance(min_lr, float):
raise TypeError("min_lr must be float.")
validator.check_number_range("min_lr", min_lr, 0.0, float("inf"), Rel.INC_LEFT, self.cls_name)
validator.check_float_positive('max_lr', max_lr, self.cls_name)
validator.check_positive_float(max_lr, 'max_lr', self.cls_name)
validator.check_float_legal_value('max_lr', max_lr, self.cls_name)
validator.check_positive_int(decay_steps, "decay_steps", self.cls_name)
if min_lr >= max_lr:
@@ -318,7 +318,7 @@ class PolynomialDecayLR(LearningRateSchedule):
"""
def __init__(self, learning_rate, end_learning_rate, decay_steps, power, update_decay_steps=False):
super(PolynomialDecayLR, self).__init__()
validator.check_float_positive('learning_rate', learning_rate, None)
validator.check_positive_float(learning_rate, 'learning_rate')
validator.check_float_legal_value('learning_rate', learning_rate, None)
if not isinstance(end_learning_rate, float):
raise TypeError("end_learning_rate must be float.")
@@ -326,7 +326,7 @@ class PolynomialDecayLR(LearningRateSchedule):
self.cls_name)
validator.check_positive_int(decay_steps, 'decay_steps', self.cls_name)
validator.check_value_type('update_decay_steps', update_decay_steps, [bool], self.cls_name)
validator.check_float_positive('power', power, self.cls_name)
validator.check_positive_float(power, 'power', self.cls_name)
validator.check_float_legal_value('power', power, self.cls_name)
self.decay_steps = decay_steps

View File

@@ -503,7 +503,7 @@ class BatchNormFold(PrimitiveWithInfer):
def __init__(self, momentum=0.9, epsilon=1e-5, is_training=True, freeze_bn=0):
"""Initialize batch norm fold layer"""
self.momentum = validator.check_number_range('momentum', momentum, 0, 1, Rel.INC_BOTH, self.name)
self.epsilon = validator.check_float_positive('epsilon', epsilon, self.name)
self.epsilon = validator.check_positive_float(epsilon, 'epsilon', self.name)
self.is_training = validator.check_value_type('is_training', is_training, (bool,), self.name)
self.freeze_bn = validator.check_value_type('freeze_bn', freeze_bn, (int,), self.name)
@@ -546,7 +546,7 @@ class BatchNormFoldGrad(PrimitiveWithInfer):
"""Initialize BatchNormGrad layer"""
self.is_training = validator.check_value_type('is_training', is_training, (bool,), self.name)
self.freeze_bn = validator.check_value_type('freeze_bn', freeze_bn, (int,), self.name)
self.epsilon = validator.check_float_positive('epsilon', epsilon, self.name)
self.epsilon = validator.check_positive_float(epsilon, 'epsilon', self.name)
self.init_prim_io_names(inputs=['d_batch_mean', 'd_batch_std', 'x', 'batch_mean', 'batch_std', 'global_step'],
outputs=['dx'])
@@ -814,7 +814,7 @@ class BatchNormFoldD(PrimitiveWithInfer):
"""Initialize _BatchNormFold layer"""
from mindspore.ops._op_impl._custom_op import batchnorm_fold
self.momentum = validator.check_number_range('momentum', momentum, 0, 1, Rel.INC_BOTH, self.name)
self.epsilon = validator.check_float_positive('epsilon', epsilon, self.name)
self.epsilon = validator.check_positive_float(epsilon, 'epsilon', self.name)
self.is_training = validator.check_value_type('is_training', is_training, (bool,), self.name)
self.freeze_bn = validator.check_value_type('freeze_bn', freeze_bn, (int,), self.name)
self.data_format = "NCHW"
@@ -842,7 +842,7 @@ class BatchNormFoldGradD(PrimitiveWithInfer):
def __init__(self, epsilon=1e-5, is_training=True, freeze_bn=0):
"""Initialize _BatchNormFoldGrad layer"""
from mindspore.ops._op_impl._custom_op import batchnorm_fold_grad
self.epsilon = validator.check_float_positive('epsilon', epsilon, self.name)
self.epsilon = validator.check_positive_float(epsilon, 'epsilon', self.name)
self.is_training = validator.check_value_type('is_training', is_training, (bool,), self.name)
self.freeze_bn = validator.check_value_type('freeze_bn', freeze_bn, (int,), self.name)
self.init_prim_io_names(inputs=['d_batch_mean', 'd_batch_std', 'x', 'batch_mean', 'batch_std'],

View File

@@ -3560,7 +3560,7 @@ class IFMR(PrimitiveWithInfer):
validator.check_value_type("max_percentile", max_percentile, [float], self.name)
validator.check_value_type("search_range", search_range, [list, tuple], self.name)
for item in search_range:
validator.check_float_positive("item of search_range", item, self.name)
validator.check_positive_float(item, "item of search_range", self.name)
validator.check('search_range[1]', search_range[1], 'search_range[0]', search_range[0], Rel.GE, self.name)
validator.check_value_type("search_step", search_step, [float], self.name)
validator.check_value_type("offset_flag", with_offset, [bool], self.name)