Add cell name to error message

This commit is contained in:
fary86 2020-04-17 01:12:34 +08:00
parent 2d31ae97e8
commit 8cbbbd950e
25 changed files with 270 additions and 173 deletions

View File

@ -17,7 +17,7 @@ import re
from enum import Enum from enum import Enum
from functools import reduce from functools import reduce
from itertools import repeat from itertools import repeat
from collections import Iterable from collections.abc import Iterable
import numpy as np import numpy as np
from mindspore import log as logger from mindspore import log as logger
@ -98,7 +98,7 @@ class Validator:
"""validator for checking input parameters""" """validator for checking input parameters"""
@staticmethod @staticmethod
def check(arg_name, arg_value, value_name, value, rel=Rel.EQ, prim_name=None): def check(arg_name, arg_value, value_name, value, rel=Rel.EQ, prim_name=None, excp_cls=ValueError):
""" """
Method for judging relation between two int values or list/tuple made up of ints. Method for judging relation between two int values or list/tuple made up of ints.
@ -108,8 +108,8 @@ class Validator:
rel_fn = Rel.get_fns(rel) rel_fn = Rel.get_fns(rel)
if not rel_fn(arg_value, value): if not rel_fn(arg_value, value):
rel_str = Rel.get_strs(rel).format(f'{value_name}: {value}') rel_str = Rel.get_strs(rel).format(f'{value_name}: {value}')
msg_prefix = f'For {prim_name} the' if prim_name else "The" msg_prefix = f'For \'{prim_name}\' the' if prim_name else "The"
raise ValueError(f'{msg_prefix} `{arg_name}` should be {rel_str}, but got {arg_value}.') raise excp_cls(f'{msg_prefix} `{arg_name}` should be {rel_str}, but got {arg_value}.')
@staticmethod @staticmethod
def check_integer(arg_name, arg_value, value, rel, prim_name): def check_integer(arg_name, arg_value, value, rel, prim_name):
@ -118,8 +118,17 @@ class Validator:
type_mismatch = not isinstance(arg_value, int) or isinstance(arg_value, bool) type_mismatch = not isinstance(arg_value, int) or isinstance(arg_value, bool)
if type_mismatch or not rel_fn(arg_value, value): if type_mismatch or not rel_fn(arg_value, value):
rel_str = Rel.get_strs(rel).format(value) rel_str = Rel.get_strs(rel).format(value)
raise ValueError(f'For {prim_name} the `{arg_name}` should be an int and must {rel_str},' msg_prefix = f'For \'{prim_name}\' the' if prim_name else "The"
f' but got {arg_value}.') raise ValueError(f'{msg_prefix} `{arg_name}` should be an int and must {rel_str}, but got {arg_value}.')
return arg_value
@staticmethod
def check_number(arg_name, arg_value, value, rel, prim_name):
"""Integer value judgment."""
rel_fn = Rel.get_fns(rel)
if not rel_fn(arg_value, value):
rel_str = Rel.get_strs(rel).format(value)
raise ValueError(f'For \'{prim_name}\' the `{arg_name}` must {rel_str}, but got {arg_value}.')
return arg_value return arg_value
@staticmethod @staticmethod
@ -133,9 +142,46 @@ class Validator:
f' but got {arg_value}.') f' but got {arg_value}.')
return arg_value return arg_value
@staticmethod
def check_number_range(arg_name, arg_value, lower_limit, upper_limit, rel, prim_name):
"""Method for checking whether a numeric value is in some range."""
rel_fn = Rel.get_fns(rel)
if not rel_fn(arg_value, lower_limit, upper_limit):
rel_str = Rel.get_strs(rel).format(lower_limit, upper_limit)
raise ValueError(f'For \'{prim_name}\' the `{arg_name}` should be in range {rel_str}, but got {arg_value}.')
return arg_value
@staticmethod
def check_string(arg_name, arg_value, valid_values, prim_name):
"""Checks whether a string is in some value list"""
if isinstance(arg_value, str) and arg_value in valid_values:
return arg_value
if len(valid_values) == 1:
raise ValueError(f'For \'{prim_name}\' the `{arg_name}` should be str and must be {valid_values[0]},'
f' but got {arg_value}.')
raise ValueError(f'For \'{prim_name}\' the `{arg_name}` should be str and must be one of {valid_values},'
f' but got {arg_value}.')
@staticmethod
def check_pad_value_by_mode(pad_mode, padding, prim_name):
"""Validates value of padding according to pad_mode"""
if pad_mode != 'pad' and padding != 0:
raise ValueError(f"For '{prim_name}', padding must be zero when pad_mode is '{pad_mode}'.")
return padding
@staticmethod
def check_float_positive(arg_name, arg_value, prim_name):
"""Float type judgment."""
msg_prefix = f'For \'{prim_name}\' the' if prim_name else "The"
if isinstance(arg_value, float):
if arg_value > 0:
return arg_value
raise ValueError(f"{msg_prefix} `{arg_name}` must be positive, but got {arg_value}.")
raise TypeError(f"{msg_prefix} `{arg_name}` must be float.")
@staticmethod @staticmethod
def check_subclass(arg_name, type_, template_type, prim_name): def check_subclass(arg_name, type_, template_type, prim_name):
"""Check whether some type is sublcass of another type""" """Checks whether some type is sublcass of another type"""
if not isinstance(template_type, Iterable): if not isinstance(template_type, Iterable):
template_type = (template_type,) template_type = (template_type,)
if not any([mstype.issubclass_(type_, x) for x in template_type]): if not any([mstype.issubclass_(type_, x) for x in template_type]):
@ -143,16 +189,44 @@ class Validator:
raise TypeError(f'For \'{prim_name}\' the type of `{arg_name}` should be subclass' raise TypeError(f'For \'{prim_name}\' the type of `{arg_name}` should be subclass'
f' of {",".join((str(x) for x in template_type))}, but got {type_str}.') f' of {",".join((str(x) for x in template_type))}, but got {type_str}.')
@staticmethod
def check_const_input(arg_name, arg_value, prim_name):
"""Check valid value."""
if arg_value is None:
raise ValueError(f'For \'{prim_name}\' the `{arg_name}` must be a const input, but got {arg_value}.')
@staticmethod
def check_scalar_type_same(args, valid_values, prim_name):
"""check whether the types of inputs are the same."""
def _check_tensor_type(arg):
arg_key, arg_val = arg
elem_type = arg_val
if not elem_type in valid_values:
raise TypeError(f'For \'{prim_name}\' type of `{arg_key}` should be in {valid_values},'
f' but `{arg_key}` is {elem_type}.')
return (arg_key, elem_type)
def _check_types_same(arg1, arg2):
arg1_name, arg1_type = arg1
arg2_name, arg2_type = arg2
if arg1_type != arg2_type:
raise TypeError(f'For \'{prim_name}\' type of `{arg2_name}` should be same as `{arg1_name}`,'
f' but `{arg1_name}` is {arg1_type} and `{arg2_name}` is {arg2_type}.')
return arg1
elem_types = map(_check_tensor_type, args.items())
reduce(_check_types_same, elem_types)
@staticmethod @staticmethod
def check_tensor_type_same(args, valid_values, prim_name): def check_tensor_type_same(args, valid_values, prim_name):
"""check whether the element types of input tensors are the same.""" """Checks whether the element types of input tensors are the same."""
def _check_tensor_type(arg): def _check_tensor_type(arg):
arg_key, arg_val = arg arg_key, arg_val = arg
Validator.check_subclass(arg_key, arg_val, mstype.tensor, prim_name) Validator.check_subclass(arg_key, arg_val, mstype.tensor, prim_name)
elem_type = arg_val.element_type() elem_type = arg_val.element_type()
if not elem_type in valid_values: if not elem_type in valid_values:
raise TypeError(f'For \'{prim_name}\' element type of `{arg_key}` should be in {valid_values},' raise TypeError(f'For \'{prim_name}\' element type of `{arg_key}` should be in {valid_values},'
f' but `{arg_key}` is {elem_type}.') f' but element type of `{arg_key}` is {elem_type}.')
return (arg_key, elem_type) return (arg_key, elem_type)
def _check_types_same(arg1, arg2): def _check_types_same(arg1, arg2):
@ -168,8 +242,13 @@ class Validator:
@staticmethod @staticmethod
def check_scalar_or_tensor_type_same(args, valid_values, prim_name): def check_scalar_or_tensor_type_same(args, valid_values, prim_name, allow_mix=False):
"""check whether the types of inputs are the same. if the input args are tensors, check their element types""" """
Checks whether the types of inputs are the same. If the input args are tensors, checks their element types.
If `allow_mix` is True, Tensor(float32) and float32 are type compatible, otherwise an exception will be raised.
"""
def _check_argument_type(arg): def _check_argument_type(arg):
arg_key, arg_val = arg arg_key, arg_val = arg
if isinstance(arg_val, type(mstype.tensor)): if isinstance(arg_val, type(mstype.tensor)):
@ -188,6 +267,9 @@ class Validator:
arg2_type = arg2_type.element_type() arg2_type = arg2_type.element_type()
elif not (isinstance(arg1_type, type(mstype.tensor)) or isinstance(arg2_type, type(mstype.tensor))): elif not (isinstance(arg1_type, type(mstype.tensor)) or isinstance(arg2_type, type(mstype.tensor))):
pass pass
elif allow_mix:
arg1_type = arg1_type.element_type() if isinstance(arg1_type, type(mstype.tensor)) else arg1_type
arg2_type = arg2_type.element_type() if isinstance(arg2_type, type(mstype.tensor)) else arg2_type
else: else:
excp_flag = True excp_flag = True
@ -199,13 +281,14 @@ class Validator:
@staticmethod @staticmethod
def check_value_type(arg_name, arg_value, valid_types, prim_name): def check_value_type(arg_name, arg_value, valid_types, prim_name):
"""Check whether a values is instance of some types.""" """Checks whether a value is instance of some types."""
valid_types = valid_types if isinstance(valid_types, Iterable) else (valid_types,)
def raise_error_msg(): def raise_error_msg():
"""func for raising error message when check failed""" """func for raising error message when check failed"""
type_names = [t.__name__ for t in valid_types] type_names = [t.__name__ for t in valid_types]
num_types = len(valid_types) num_types = len(valid_types)
raise TypeError(f'For \'{prim_name}\' the type of `{arg_name}` should be ' msg_prefix = f'For \'{prim_name}\' the' if prim_name else 'The'
f'{"one of " if num_types > 1 else ""}' raise TypeError(f'{msg_prefix} type of `{arg_name}` should be {"one of " if num_types > 1 else ""}'
f'{type_names if num_types > 1 else type_names[0]}, but got {type(arg_value).__name__}.') f'{type_names if num_types > 1 else type_names[0]}, but got {type(arg_value).__name__}.')
# Notice: bool is subclass of int, so `check_value_type('x', True, [int])` will check fail, and # Notice: bool is subclass of int, so `check_value_type('x', True, [int])` will check fail, and
@ -216,6 +299,23 @@ class Validator:
return arg_value return arg_value
raise_error_msg() raise_error_msg()
@staticmethod
def check_type_name(arg_name, arg_type, valid_types, prim_name):
"""Checks whether a type in some specified types"""
valid_types = valid_types if isinstance(valid_types, Iterable) else (valid_types,)
def get_typename(t):
return t.__name__ if hasattr(t, '__name__') else str(t)
if arg_type in valid_types:
return arg_type
type_names = [get_typename(t) for t in valid_types]
msg_prefix = f'For \'{prim_name}\' the' if prim_name else 'The'
if len(valid_types) == 1:
raise ValueError(f'{msg_prefix} type of `{arg_name}` should be {type_names[0]},'
f' but got {get_typename(arg_type)}.')
raise ValueError(f'{msg_prefix} type of `{arg_name}` should be one of {type_names},'
f' but got {get_typename(arg_type)}.')
class ParamValidator: class ParamValidator:
"""Parameter validator. NOTICE: this class will be replaced by `class Validator`""" """Parameter validator. NOTICE: this class will be replaced by `class Validator`"""

View File

@ -103,6 +103,10 @@ class Cell:
def parameter_layout_dict(self): def parameter_layout_dict(self):
return self._parameter_layout_dict return self._parameter_layout_dict
@property
def cls_name(self):
return self.__class__.__name__
@parameter_layout_dict.setter @parameter_layout_dict.setter
def parameter_layout_dict(self, value): def parameter_layout_dict(self, value):
if not isinstance(value, dict): if not isinstance(value, dict):

View File

@ -15,7 +15,7 @@
"""dynamic learning rate""" """dynamic learning rate"""
import math import math
from mindspore._checkparam import ParamValidator as validator from mindspore._checkparam import Validator as validator
from mindspore._checkparam import Rel from mindspore._checkparam import Rel
@ -43,16 +43,16 @@ def piecewise_constant_lr(milestone, learning_rates):
>>> lr = piecewise_constant_lr(milestone, learning_rates) >>> lr = piecewise_constant_lr(milestone, learning_rates)
[0.1, 0.1, 0.05, 0.05, 0.05, 0.01, 0.01, 0.01, 0.01, 0.01] [0.1, 0.1, 0.05, 0.05, 0.05, 0.01, 0.01, 0.01, 0.01, 0.01]
""" """
validator.check_type('milestone', milestone, (tuple, list)) validator.check_value_type('milestone', milestone, (tuple, list), None)
validator.check_type('learning_rates', learning_rates, (tuple, list)) validator.check_value_type('learning_rates', learning_rates, (tuple, list), None)
if len(milestone) != len(learning_rates): if len(milestone) != len(learning_rates):
raise ValueError('The size of `milestone` must be same with the size of `learning_rates`.') raise ValueError('The size of `milestone` must be same with the size of `learning_rates`.')
lr = [] lr = []
last_item = 0 last_item = 0
for i, item in enumerate(milestone): for i, item in enumerate(milestone):
validator.check_integer(f'milestone[{i}]', item, 0, Rel.GT) validator.check_integer(f'milestone[{i}]', item, 0, Rel.GT, None)
validator.check_type(f'learning_rates[{i}]', learning_rates[i], [float]) validator.check_value_type(f'learning_rates[{i}]', learning_rates[i], [float], None)
if item < last_item: if item < last_item:
raise ValueError(f'The value of milestone[{i}] must be greater than milestone[{i - 1}]') raise ValueError(f'The value of milestone[{i}] must be greater than milestone[{i - 1}]')
lr += [learning_rates[i]] * (item - last_item) lr += [learning_rates[i]] * (item - last_item)
@ -62,12 +62,12 @@ def piecewise_constant_lr(milestone, learning_rates):
def _check_inputs(learning_rate, decay_rate, total_step, step_per_epoch, decay_epoch, is_stair): def _check_inputs(learning_rate, decay_rate, total_step, step_per_epoch, decay_epoch, is_stair):
validator.check_integer('total_step', total_step, 0, Rel.GT) validator.check_integer('total_step', total_step, 0, Rel.GT, None)
validator.check_integer('step_per_epoch', step_per_epoch, 0, Rel.GT) validator.check_integer('step_per_epoch', step_per_epoch, 0, Rel.GT, None)
validator.check_integer('decay_epoch', decay_epoch, 0, Rel.GT) validator.check_integer('decay_epoch', decay_epoch, 0, Rel.GT, None)
validator.check_float_positive('learning_rate', learning_rate) validator.check_float_positive('learning_rate', learning_rate, None)
validator.check_float_positive('decay_rate', decay_rate) validator.check_float_positive('decay_rate', decay_rate, None)
validator.check_type('is_stair', is_stair, [bool]) validator.check_value_type('is_stair', is_stair, [bool], None)
def exponential_decay_lr(learning_rate, decay_rate, total_step, step_per_epoch, decay_epoch, is_stair=False): def exponential_decay_lr(learning_rate, decay_rate, total_step, step_per_epoch, decay_epoch, is_stair=False):
@ -228,11 +228,11 @@ def cosine_decay_lr(min_lr, max_lr, total_step, step_per_epoch, decay_epoch):
>>> lr = cosine_decay_lr(min_lr, max_lr, total_step, step_per_epoch, decay_epoch) >>> lr = cosine_decay_lr(min_lr, max_lr, total_step, step_per_epoch, decay_epoch)
[0.1, 0.1, 0.05500000000000001, 0.05500000000000001, 0.01, 0.01] [0.1, 0.1, 0.05500000000000001, 0.05500000000000001, 0.01, 0.01]
""" """
validator.check_float_positive('min_lr', min_lr) validator.check_float_positive('min_lr', min_lr, None)
validator.check_float_positive('max_lr', max_lr) validator.check_float_positive('max_lr', max_lr, None)
validator.check_integer('total_step', total_step, 0, Rel.GT) validator.check_integer('total_step', total_step, 0, Rel.GT, None)
validator.check_integer('step_per_epoch', step_per_epoch, 0, Rel.GT) validator.check_integer('step_per_epoch', step_per_epoch, 0, Rel.GT, None)
validator.check_integer('decay_epoch', decay_epoch, 0, Rel.GT) validator.check_integer('decay_epoch', decay_epoch, 0, Rel.GT, None)
delta = 0.5 * (max_lr - min_lr) delta = 0.5 * (max_lr - min_lr)
lr = [] lr = []
@ -279,13 +279,13 @@ def polynomial_decay_lr(learning_rate, end_learning_rate, total_step, step_per_e
>>> lr = polynomial_decay_lr(learning_rate, end_learning_rate, total_step, step_per_epoch, decay_epoch, power) >>> lr = polynomial_decay_lr(learning_rate, end_learning_rate, total_step, step_per_epoch, decay_epoch, power)
[0.1, 0.1, 0.07363961030678928, 0.07363961030678928, 0.01, 0.01] [0.1, 0.1, 0.07363961030678928, 0.07363961030678928, 0.01, 0.01]
""" """
validator.check_float_positive('learning_rate', learning_rate) validator.check_float_positive('learning_rate', learning_rate, None)
validator.check_float_positive('end_learning_rate', end_learning_rate) validator.check_float_positive('end_learning_rate', end_learning_rate, None)
validator.check_integer('total_step', total_step, 0, Rel.GT) validator.check_integer('total_step', total_step, 0, Rel.GT, None)
validator.check_integer('step_per_epoch', step_per_epoch, 0, Rel.GT) validator.check_integer('step_per_epoch', step_per_epoch, 0, Rel.GT, None)
validator.check_integer('decay_epoch', decay_epoch, 0, Rel.GT) validator.check_integer('decay_epoch', decay_epoch, 0, Rel.GT, None)
validator.check_type('power', power, [float]) validator.check_value_type('power', power, [float], None)
validator.check_type('update_decay_epoch', update_decay_epoch, [bool]) validator.check_value_type('update_decay_epoch', update_decay_epoch, [bool], None)
function = lambda x, y: (x, min(x, y)) function = lambda x, y: (x, min(x, y))
if update_decay_epoch: if update_decay_epoch:

View File

@ -25,7 +25,7 @@ from mindspore.common.parameter import Parameter
from mindspore._extends import cell_attr_register from mindspore._extends import cell_attr_register
from ..cell import Cell from ..cell import Cell
from .activation import get_activation from .activation import get_activation
from ..._checkparam import ParamValidator as validator from ..._checkparam import Validator as validator
class Dropout(Cell): class Dropout(Cell):
@ -73,7 +73,7 @@ class Dropout(Cell):
super(Dropout, self).__init__() super(Dropout, self).__init__()
if keep_prob <= 0 or keep_prob > 1: if keep_prob <= 0 or keep_prob > 1:
raise ValueError("dropout probability should be a number in range (0, 1], but got {}".format(keep_prob)) raise ValueError("dropout probability should be a number in range (0, 1], but got {}".format(keep_prob))
validator.check_subclass("dtype", dtype, mstype.number_type) validator.check_subclass("dtype", dtype, mstype.number_type, self.cls_name)
self.keep_prob = Tensor(keep_prob) self.keep_prob = Tensor(keep_prob)
self.seed0 = seed0 self.seed0 = seed0
self.seed1 = seed1 self.seed1 = seed1
@ -421,7 +421,7 @@ class Pad(Cell):
super(Pad, self).__init__() super(Pad, self).__init__()
self.mode = mode self.mode = mode
self.paddings = paddings self.paddings = paddings
validator.check_string('mode', self.mode, ["CONSTANT", "REFLECT", "SYMMETRIC"]) validator.check_string('mode', self.mode, ["CONSTANT", "REFLECT", "SYMMETRIC"], self.cls_name)
if not isinstance(paddings, tuple): if not isinstance(paddings, tuple):
raise TypeError('Paddings must be tuple type.') raise TypeError('Paddings must be tuple type.')
for item in paddings: for item in paddings:

View File

@ -19,7 +19,7 @@ from mindspore.ops import operations as P
from mindspore.common.parameter import Parameter from mindspore.common.parameter import Parameter
from mindspore.common.initializer import initializer from mindspore.common.initializer import initializer
from ..cell import Cell from ..cell import Cell
from ..._checkparam import ParamValidator as validator from ..._checkparam import Validator as validator
class Embedding(Cell): class Embedding(Cell):
@ -59,7 +59,7 @@ class Embedding(Cell):
""" """
def __init__(self, vocab_size, embedding_size, use_one_hot=False, embedding_table='normal', dtype=mstype.float32): def __init__(self, vocab_size, embedding_size, use_one_hot=False, embedding_table='normal', dtype=mstype.float32):
super(Embedding, self).__init__() super(Embedding, self).__init__()
validator.check_subclass("dtype", dtype, mstype.number_type) validator.check_subclass("dtype", dtype, mstype.number_type, self.cls_name)
self.vocab_size = vocab_size self.vocab_size = vocab_size
self.embedding_size = embedding_size self.embedding_size = embedding_size
self.use_one_hot = use_one_hot self.use_one_hot = use_one_hot

View File

@ -19,7 +19,7 @@ from mindspore.common.tensor import Tensor
from mindspore.ops import operations as P from mindspore.ops import operations as P
from mindspore.ops import functional as F from mindspore.ops import functional as F
from mindspore.ops.primitive import constexpr from mindspore.ops.primitive import constexpr
from mindspore._checkparam import ParamValidator as validator from mindspore._checkparam import Validator as validator
from mindspore._checkparam import Rel from mindspore._checkparam import Rel
from ..cell import Cell from ..cell import Cell
@ -134,15 +134,15 @@ class SSIM(Cell):
""" """
def __init__(self, max_val=1.0, filter_size=11, filter_sigma=1.5, k1=0.01, k2=0.03): def __init__(self, max_val=1.0, filter_size=11, filter_sigma=1.5, k1=0.01, k2=0.03):
super(SSIM, self).__init__() super(SSIM, self).__init__()
validator.check_type('max_val', max_val, [int, float]) validator.check_value_type('max_val', max_val, [int, float], self.cls_name)
validator.check('max_val', max_val, '', 0.0, Rel.GT) validator.check_number('max_val', max_val, 0.0, Rel.GT, self.cls_name)
self.max_val = max_val self.max_val = max_val
self.filter_size = validator.check_integer('filter_size', filter_size, 1, Rel.GE) self.filter_size = validator.check_integer('filter_size', filter_size, 1, Rel.GE, self.cls_name)
self.filter_sigma = validator.check_float_positive('filter_sigma', filter_sigma) self.filter_sigma = validator.check_float_positive('filter_sigma', filter_sigma, self.cls_name)
validator.check_type('k1', k1, [float]) validator.check_value_type('k1', k1, [float], self.cls_name)
self.k1 = validator.check_number_range('k1', k1, 0.0, 1.0, Rel.INC_NEITHER) self.k1 = validator.check_number_range('k1', k1, 0.0, 1.0, Rel.INC_NEITHER, self.cls_name)
validator.check_type('k2', k2, [float]) validator.check_value_type('k2', k2, [float], self.cls_name)
self.k2 = validator.check_number_range('k2', k2, 0.0, 1.0, Rel.INC_NEITHER) self.k2 = validator.check_number_range('k2', k2, 0.0, 1.0, Rel.INC_NEITHER, self.cls_name)
self.mean = P.DepthwiseConv2dNative(channel_multiplier=1, kernel_size=filter_size) self.mean = P.DepthwiseConv2dNative(channel_multiplier=1, kernel_size=filter_size)
def construct(self, img1, img2): def construct(self, img1, img2):
@ -231,8 +231,8 @@ class PSNR(Cell):
""" """
def __init__(self, max_val=1.0): def __init__(self, max_val=1.0):
super(PSNR, self).__init__() super(PSNR, self).__init__()
validator.check_type('max_val', max_val, [int, float]) validator.check_value_type('max_val', max_val, [int, float], self.cls_name)
validator.check('max_val', max_val, '', 0.0, Rel.GT) validator.check_number('max_val', max_val, 0.0, Rel.GT, self.cls_name)
self.max_val = max_val self.max_val = max_val
def construct(self, img1, img2): def construct(self, img1, img2):

View File

@ -17,7 +17,7 @@ from mindspore.ops import operations as P
from mindspore.nn.cell import Cell from mindspore.nn.cell import Cell
from mindspore.common.parameter import Parameter from mindspore.common.parameter import Parameter
from mindspore.common.initializer import initializer from mindspore.common.initializer import initializer
from mindspore._checkparam import ParamValidator as validator from mindspore._checkparam import Validator as validator
class LSTM(Cell): class LSTM(Cell):
@ -114,7 +114,7 @@ class LSTM(Cell):
self.hidden_size = hidden_size self.hidden_size = hidden_size
self.num_layers = num_layers self.num_layers = num_layers
self.has_bias = has_bias self.has_bias = has_bias
self.batch_first = validator.check_type("batch_first", batch_first, [bool]) self.batch_first = validator.check_value_type("batch_first", batch_first, [bool], self.cls_name)
self.dropout = float(dropout) self.dropout = float(dropout)
self.bidirectional = bidirectional self.bidirectional = bidirectional

View File

@ -14,8 +14,7 @@
# ============================================================================ # ============================================================================
"""pooling""" """pooling"""
from mindspore.ops import operations as P from mindspore.ops import operations as P
from mindspore._checkparam import ParamValidator as validator from mindspore._checkparam import Validator as validator
from mindspore._checkparam import Rel
from ... import context from ... import context
from ..cell import Cell from ..cell import Cell
@ -24,35 +23,27 @@ class _PoolNd(Cell):
"""N-D AvgPool""" """N-D AvgPool"""
def __init__(self, kernel_size, stride, pad_mode): def __init__(self, kernel_size, stride, pad_mode):
name = self.__class__.__name__
super(_PoolNd, self).__init__() super(_PoolNd, self).__init__()
validator.check_type('kernel_size', kernel_size, [int, tuple]) self.pad_mode = validator.check_string('pad_mode', pad_mode.upper(), ['VALID', 'SAME'], self.cls_name)
validator.check_type('stride', stride, [int, tuple])
self.pad_mode = validator.check_string('pad_mode', pad_mode.upper(), ['VALID', 'SAME'])
if isinstance(kernel_size, int): def _check_int_or_tuple(arg_name, arg_value):
validator.check_integer("kernel_size", kernel_size, 1, Rel.GE) validator.check_value_type(arg_name, arg_value, [int, tuple], self.cls_name)
else: error_msg = f'For \'{self.cls_name}\' the {arg_name} should be an positive int number or ' \
if (len(kernel_size) != 2 or f'a tuple of two positive int numbers, but got {arg_value}'
(not isinstance(kernel_size[0], int)) or if isinstance(arg_value, int):
(not isinstance(kernel_size[1], int)) or if arg_value <= 0:
kernel_size[0] <= 0 or raise ValueError(error_msg)
kernel_size[1] <= 0): elif len(arg_value) == 2:
raise ValueError(f'The kernel_size passed to cell {name} should be an positive int number or' for item in arg_value:
f'a tuple of two positive int numbers, but got {kernel_size}') if isinstance(item, int) and item > 0:
self.kernel_size = kernel_size continue
raise ValueError(error_msg)
else:
raise ValueError(error_msg)
return arg_value
if isinstance(stride, int): self.kernel_size = _check_int_or_tuple('kernel_size', kernel_size)
validator.check_integer("stride", stride, 1, Rel.GE) self.stride = _check_int_or_tuple('stride', stride)
else:
if (len(stride) != 2 or
(not isinstance(stride[0], int)) or
(not isinstance(stride[1], int)) or
stride[0] <= 0 or
stride[1] <= 0):
raise ValueError(f'The stride passed to cell {name} should be an positive int number or'
f'a tuple of two positive int numbers, but got {stride}')
self.stride = stride
def construct(self, *inputs): def construct(self, *inputs):
pass pass

View File

@ -15,7 +15,7 @@
"""Fbeta.""" """Fbeta."""
import sys import sys
import numpy as np import numpy as np
from mindspore._checkparam import ParamValidator as validator from mindspore._checkparam import Validator as validator
from .metric import Metric from .metric import Metric
@ -104,7 +104,7 @@ class Fbeta(Metric):
Returns: Returns:
Float, computed result. Float, computed result.
""" """
validator.check_type("average", average, [bool]) validator.check_value_type("average", average, [bool], self.__class__.__name__)
if self._class_num == 0: if self._class_num == 0:
raise RuntimeError('Input number of samples can not be 0.') raise RuntimeError('Input number of samples can not be 0.')

View File

@ -17,7 +17,7 @@ import sys
import numpy as np import numpy as np
from mindspore._checkparam import ParamValidator as validator from mindspore._checkparam import Validator as validator
from .evaluation import EvaluationBase from .evaluation import EvaluationBase
@ -136,7 +136,7 @@ class Precision(EvaluationBase):
if self._class_num == 0: if self._class_num == 0:
raise RuntimeError('Input number of samples can not be 0.') raise RuntimeError('Input number of samples can not be 0.')
validator.check_type("average", average, [bool]) validator.check_value_type("average", average, [bool], self.__class__.__name__)
result = self._true_positives / (self._positives + self.eps) result = self._true_positives / (self._positives + self.eps)
if average: if average:

View File

@ -17,7 +17,7 @@ import sys
import numpy as np import numpy as np
from mindspore._checkparam import ParamValidator as validator from mindspore._checkparam import Validator as validator
from .evaluation import EvaluationBase from .evaluation import EvaluationBase
@ -136,7 +136,7 @@ class Recall(EvaluationBase):
if self._class_num == 0: if self._class_num == 0:
raise RuntimeError('Input number of samples can not be 0.') raise RuntimeError('Input number of samples can not be 0.')
validator.check_type("average", average, [bool]) validator.check_value_type("average", average, [bool], self.__class__.__name__)
result = self._true_positives / (self._actual_positives + self.eps) result = self._true_positives / (self._actual_positives + self.eps)
if average: if average:

View File

@ -22,7 +22,7 @@ from mindspore.ops import composite as C
from mindspore.ops import functional as F from mindspore.ops import functional as F
from mindspore.common.parameter import Parameter from mindspore.common.parameter import Parameter
from mindspore.common.tensor import Tensor from mindspore.common.tensor import Tensor
from mindspore._checkparam import ParamValidator as validator from mindspore._checkparam import Validator as validator
from mindspore._checkparam import Rel from mindspore._checkparam import Rel
from .optimizer import Optimizer from .optimizer import Optimizer
@ -78,16 +78,16 @@ def _update_run_op(beta1, beta2, eps, lr, weight_decay_tensor, param, m, v, grad
return next_v return next_v
def _check_param_value(beta1, beta2, eps, weight_decay): def _check_param_value(beta1, beta2, eps, weight_decay, prim_name):
"""Check the type of inputs.""" """Check the type of inputs."""
validator.check_type("beta1", beta1, [float]) validator.check_value_type("beta1", beta1, [float], prim_name)
validator.check_type("beta2", beta2, [float]) validator.check_value_type("beta2", beta2, [float], prim_name)
validator.check_type("eps", eps, [float]) validator.check_value_type("eps", eps, [float], prim_name)
validator.check_type("weight_dacay", weight_decay, [float]) validator.check_value_type("weight_dacay", weight_decay, [float], prim_name)
validator.check_number_range("beta1", beta1, 0.0, 1.0, Rel.INC_NEITHER) validator.check_number_range("beta1", beta1, 0.0, 1.0, Rel.INC_NEITHER, prim_name)
validator.check_number_range("beta2", beta2, 0.0, 1.0, Rel.INC_NEITHER) validator.check_number_range("beta2", beta2, 0.0, 1.0, Rel.INC_NEITHER, prim_name)
validator.check_number_range("eps", eps, 0.0, float("inf"), Rel.INC_NEITHER) validator.check_number_range("eps", eps, 0.0, float("inf"), Rel.INC_NEITHER, prim_name)
validator.check_number_range("weight_decay", weight_decay, 0.0, float("inf"), Rel.INC_LEFT) validator.check_number_range("weight_decay", weight_decay, 0.0, float("inf"), Rel.INC_LEFT, prim_name)
@adam_opt.register("Function", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Number", "Tensor", "Tensor", "Tensor", @adam_opt.register("Function", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Number", "Tensor", "Tensor", "Tensor",
@ -168,11 +168,11 @@ class Adam(Optimizer):
use_nesterov=False, weight_decay=0.0, loss_scale=1.0, use_nesterov=False, weight_decay=0.0, loss_scale=1.0,
decay_filter=lambda x: 'beta' not in x.name and 'gamma' not in x.name): decay_filter=lambda x: 'beta' not in x.name and 'gamma' not in x.name):
super(Adam, self).__init__(learning_rate, params, weight_decay, loss_scale, decay_filter) super(Adam, self).__init__(learning_rate, params, weight_decay, loss_scale, decay_filter)
_check_param_value(beta1, beta2, eps, weight_decay) _check_param_value(beta1, beta2, eps, weight_decay, self.cls_name)
validator.check_type("use_locking", use_locking, [bool]) validator.check_value_type("use_locking", use_locking, [bool], self.cls_name)
validator.check_type("use_nesterov", use_nesterov, [bool]) validator.check_value_type("use_nesterov", use_nesterov, [bool], self.cls_name)
validator.check_type("loss_scale", loss_scale, [float]) validator.check_value_type("loss_scale", loss_scale, [float], self.cls_name)
validator.check_number_range("loss_scale", loss_scale, 1.0, float("inf"), Rel.INC_LEFT) validator.check_number_range("loss_scale", loss_scale, 1.0, float("inf"), Rel.INC_LEFT, self.cls_name)
self.beta1 = Tensor(beta1, mstype.float32) self.beta1 = Tensor(beta1, mstype.float32)
self.beta2 = Tensor(beta2, mstype.float32) self.beta2 = Tensor(beta2, mstype.float32)
@ -241,7 +241,7 @@ class AdamWeightDecay(Optimizer):
""" """
def __init__(self, params, learning_rate=1e-3, beta1=0.9, beta2=0.999, eps=1e-6, weight_decay=0.0): def __init__(self, params, learning_rate=1e-3, beta1=0.9, beta2=0.999, eps=1e-6, weight_decay=0.0):
super(AdamWeightDecay, self).__init__(learning_rate, params) super(AdamWeightDecay, self).__init__(learning_rate, params)
_check_param_value(beta1, beta2, eps, weight_decay) _check_param_value(beta1, beta2, eps, weight_decay, self.cls_name)
self.lr = Tensor(np.array([learning_rate]).astype(np.float32)) self.lr = Tensor(np.array([learning_rate]).astype(np.float32))
self.beta1 = Tensor(np.array([beta1]).astype(np.float32)) self.beta1 = Tensor(np.array([beta1]).astype(np.float32))
self.beta2 = Tensor(np.array([beta2]).astype(np.float32)) self.beta2 = Tensor(np.array([beta2]).astype(np.float32))
@ -304,7 +304,7 @@ class AdamWeightDecayDynamicLR(Optimizer):
eps=1e-6, eps=1e-6,
weight_decay=0.0): weight_decay=0.0):
super(AdamWeightDecayDynamicLR, self).__init__(learning_rate, params) super(AdamWeightDecayDynamicLR, self).__init__(learning_rate, params)
_check_param_value(beta1, beta2, eps, weight_decay) _check_param_value(beta1, beta2, eps, weight_decay, self.cls_name)
# turn them to scalar when me support scalar/tensor mix operations # turn them to scalar when me support scalar/tensor mix operations
self.global_step = Parameter(initializer(0, [1]), name="global_step") self.global_step = Parameter(initializer(0, [1]), name="global_step")

View File

@ -18,7 +18,7 @@ from mindspore.common.initializer import initializer
from mindspore.common.parameter import Parameter from mindspore.common.parameter import Parameter
from mindspore.common import Tensor from mindspore.common import Tensor
import mindspore.common.dtype as mstype import mindspore.common.dtype as mstype
from mindspore._checkparam import ParamValidator as validator from mindspore._checkparam import Validator as validator
from mindspore._checkparam import Rel from mindspore._checkparam import Rel
from .optimizer import Optimizer, apply_decay, grad_scale from .optimizer import Optimizer, apply_decay, grad_scale
@ -30,29 +30,30 @@ def _tensor_run_opt(opt, learning_rate, l1, l2, lr_power, linear, gradient, weig
success = F.depend(success, opt(weight, moment, linear, gradient, learning_rate, l1, l2, lr_power)) success = F.depend(success, opt(weight, moment, linear, gradient, learning_rate, l1, l2, lr_power))
return success return success
def _check_param(initial_accum, learning_rate, lr_power, l1, l2, use_locking, loss_scale=1.0, weight_decay=0.0): def _check_param(initial_accum, learning_rate, lr_power, l1, l2, use_locking, loss_scale=1.0, weight_decay=0.0,
validator.check_type("initial_accum", initial_accum, [float]) prim_name=None):
validator.check("initial_accum", initial_accum, "", 0.0, Rel.GE) validator.check_value_type("initial_accum", initial_accum, [float], prim_name)
validator.check_number("initial_accum", initial_accum, 0.0, Rel.GE, prim_name)
validator.check_type("learning_rate", learning_rate, [float]) validator.check_value_type("learning_rate", learning_rate, [float], prim_name)
validator.check("learning_rate", learning_rate, "", 0.0, Rel.GT) validator.check_number("learning_rate", learning_rate, 0.0, Rel.GT, prim_name)
validator.check_type("lr_power", lr_power, [float]) validator.check_value_type("lr_power", lr_power, [float], prim_name)
validator.check("lr_power", lr_power, "", 0.0, Rel.LE) validator.check_number("lr_power", lr_power, 0.0, Rel.LE, prim_name)
validator.check_type("l1", l1, [float]) validator.check_value_type("l1", l1, [float], prim_name)
validator.check("l1", l1, "", 0.0, Rel.GE) validator.check_number("l1", l1, 0.0, Rel.GE, prim_name)
validator.check_type("l2", l2, [float]) validator.check_value_type("l2", l2, [float], prim_name)
validator.check("l2", l2, "", 0.0, Rel.GE) validator.check_number("l2", l2, 0.0, Rel.GE, prim_name)
validator.check_type("use_locking", use_locking, [bool]) validator.check_value_type("use_locking", use_locking, [bool], prim_name)
validator.check_type("loss_scale", loss_scale, [float]) validator.check_value_type("loss_scale", loss_scale, [float], prim_name)
validator.check("loss_scale", loss_scale, "", 1.0, Rel.GE) validator.check_number("loss_scale", loss_scale, 1.0, Rel.GE, prim_name)
validator.check_type("weight_decay", weight_decay, [float]) validator.check_value_type("weight_decay", weight_decay, [float], prim_name)
validator.check("weight_decay", weight_decay, "", 0.0, Rel.GE) validator.check_number("weight_decay", weight_decay, 0.0, Rel.GE, prim_name)
class FTRL(Optimizer): class FTRL(Optimizer):
@ -94,7 +95,8 @@ class FTRL(Optimizer):
use_locking=False, loss_scale=1.0, weight_decay=0.0): use_locking=False, loss_scale=1.0, weight_decay=0.0):
super(FTRL, self).__init__(learning_rate, params) super(FTRL, self).__init__(learning_rate, params)
_check_param(initial_accum, learning_rate, lr_power, l1, l2, use_locking, loss_scale, weight_decay) _check_param(initial_accum, learning_rate, lr_power, l1, l2, use_locking, loss_scale, weight_decay,
self.cls_name)
self.moments = self.parameters.clone(prefix="moments", init=initial_accum) self.moments = self.parameters.clone(prefix="moments", init=initial_accum)
self.linear = self.parameters.clone(prefix="linear", init='zeros') self.linear = self.parameters.clone(prefix="linear", init='zeros')
self.l1 = l1 self.l1 = l1

View File

@ -21,7 +21,7 @@ from mindspore.ops import composite as C
from mindspore.ops import functional as F from mindspore.ops import functional as F
from mindspore.common.parameter import Parameter from mindspore.common.parameter import Parameter
from mindspore.common.tensor import Tensor from mindspore.common.tensor import Tensor
from mindspore._checkparam import ParamValidator as validator from mindspore._checkparam import Validator as validator
from mindspore._checkparam import Rel from mindspore._checkparam import Rel
from .optimizer import Optimizer from .optimizer import Optimizer
from .. import layer from .. import layer
@ -109,23 +109,23 @@ def _update_run_op(beta1, beta2, eps, lr, weight_decay_tensor, global_step, para
def _check_param_value(decay_steps, warmup_steps, start_learning_rate, def _check_param_value(decay_steps, warmup_steps, start_learning_rate,
end_learning_rate, power, beta1, beta2, eps, weight_decay): end_learning_rate, power, beta1, beta2, eps, weight_decay, prim_name):
"""Check the type of inputs.""" """Check the type of inputs."""
validator.check_type("decay_steps", decay_steps, [int]) validator.check_value_type("decay_steps", decay_steps, [int], prim_name)
validator.check_type("warmup_steps", warmup_steps, [int]) validator.check_value_type("warmup_steps", warmup_steps, [int], prim_name)
validator.check_type("start_learning_rate", start_learning_rate, [float]) validator.check_value_type("start_learning_rate", start_learning_rate, [float], prim_name)
validator.check_type("end_learning_rate", end_learning_rate, [float]) validator.check_value_type("end_learning_rate", end_learning_rate, [float], prim_name)
validator.check_type("power", power, [float]) validator.check_value_type("power", power, [float], prim_name)
validator.check_type("beta1", beta1, [float]) validator.check_value_type("beta1", beta1, [float], prim_name)
validator.check_type("beta2", beta2, [float]) validator.check_value_type("beta2", beta2, [float], prim_name)
validator.check_type("eps", eps, [float]) validator.check_value_type("eps", eps, [float], prim_name)
validator.check_type("weight_dacay", weight_decay, [float]) validator.check_value_type("weight_dacay", weight_decay, [float], prim_name)
validator.check_number_range("decay_steps", decay_steps, 1, float("inf"), Rel.INC_LEFT) validator.check_number_range("decay_steps", decay_steps, 1, float("inf"), Rel.INC_LEFT, prim_name)
validator.check_number_range("beta1", beta1, 0.0, 1.0, Rel.INC_NEITHER) validator.check_number_range("beta1", beta1, 0.0, 1.0, Rel.INC_NEITHER, prim_name)
validator.check_number_range("beta2", beta2, 0.0, 1.0, Rel.INC_NEITHER) validator.check_number_range("beta2", beta2, 0.0, 1.0, Rel.INC_NEITHER, prim_name)
validator.check_number_range("eps", eps, 0.0, float("inf"), Rel.INC_NEITHER) validator.check_number_range("eps", eps, 0.0, float("inf"), Rel.INC_NEITHER, prim_name)
validator.check_number_range("weight_decay", weight_decay, 0.0, float("inf"), Rel.INC_LEFT) validator.check_number_range("weight_decay", weight_decay, 0.0, float("inf"), Rel.INC_LEFT, prim_name)
class Lamb(Optimizer): class Lamb(Optimizer):
@ -182,7 +182,7 @@ class Lamb(Optimizer):
super(Lamb, self).__init__(start_learning_rate, params) super(Lamb, self).__init__(start_learning_rate, params)
_check_param_value(decay_steps, warmup_steps, start_learning_rate, end_learning_rate, _check_param_value(decay_steps, warmup_steps, start_learning_rate, end_learning_rate,
power, beta1, beta2, eps, weight_decay) power, beta1, beta2, eps, weight_decay, self.cls_name)
# turn them to scalar when me support scalar/tensor mix operations # turn them to scalar when me support scalar/tensor mix operations
self.global_step = Parameter(initializer(0, [1]), name="global_step") self.global_step = Parameter(initializer(0, [1]), name="global_step")

View File

@ -22,7 +22,7 @@ from mindspore.ops import functional as F, composite as C, operations as P
from mindspore.nn.cell import Cell from mindspore.nn.cell import Cell
from mindspore.common.parameter import Parameter, ParameterTuple from mindspore.common.parameter import Parameter, ParameterTuple
from mindspore.common.initializer import initializer from mindspore.common.initializer import initializer
from mindspore._checkparam import ParamValidator as validator from mindspore._checkparam import Validator as validator
from mindspore._checkparam import Rel from mindspore._checkparam import Rel
from mindspore.common.tensor import Tensor from mindspore.common.tensor import Tensor
from mindspore import log as logger from mindspore import log as logger
@ -63,7 +63,7 @@ class Optimizer(Cell):
self.gather = None self.gather = None
self.assignadd = None self.assignadd = None
self.global_step = None self.global_step = None
validator.check_number_range("learning rate", learning_rate, 0.0, float("inf"), Rel.INC_LEFT) validator.check_number_range("learning rate", learning_rate, 0.0, float("inf"), Rel.INC_LEFT, self.cls_name)
else: else:
self.dynamic_lr = True self.dynamic_lr = True
self.gather = P.GatherV2() self.gather = P.GatherV2()

View File

@ -14,7 +14,7 @@
# ============================================================================ # ============================================================================
"""rmsprop""" """rmsprop"""
from mindspore.ops import functional as F, composite as C, operations as P from mindspore.ops import functional as F, composite as C, operations as P
from mindspore._checkparam import ParamValidator as validator from mindspore._checkparam import Validator as validator
from .optimizer import Optimizer from .optimizer import Optimizer
rmsprop_opt = C.MultitypeFuncGraph("rmsprop_opt") rmsprop_opt = C.MultitypeFuncGraph("rmsprop_opt")
@ -144,8 +144,8 @@ class RMSProp(Optimizer):
self.decay = decay self.decay = decay
self.epsilon = epsilon self.epsilon = epsilon
validator.check_type("use_locking", use_locking, [bool]) validator.check_value_type("use_locking", use_locking, [bool], self.cls_name)
validator.check_type("centered", centered, [bool]) validator.check_value_type("centered", centered, [bool], self.cls_name)
self.centered = centered self.centered = centered
if centered: if centered:
self.opt = P.ApplyCenteredRMSProp(use_locking) self.opt = P.ApplyCenteredRMSProp(use_locking)

View File

@ -15,7 +15,7 @@
"""sgd""" """sgd"""
from mindspore.ops import functional as F, composite as C, operations as P from mindspore.ops import functional as F, composite as C, operations as P
from mindspore.common.parameter import Parameter from mindspore.common.parameter import Parameter
from mindspore._checkparam import ParamValidator as validator from mindspore._checkparam import Validator as validator
from .optimizer import Optimizer from .optimizer import Optimizer
sgd_opt = C.MultitypeFuncGraph("sgd_opt") sgd_opt = C.MultitypeFuncGraph("sgd_opt")
@ -100,7 +100,7 @@ class SGD(Optimizer):
raise ValueError("dampening should be at least 0.0, but got dampening {}".format(dampening)) raise ValueError("dampening should be at least 0.0, but got dampening {}".format(dampening))
self.dampening = dampening self.dampening = dampening
validator.check_type("nesterov", nesterov, [bool]) validator.check_value_type("nesterov", nesterov, [bool], self.cls_name)
self.nesterov = nesterov self.nesterov = nesterov
self.opt = P.SGD(dampening, weight_decay, nesterov) self.opt = P.SGD(dampening, weight_decay, nesterov)

View File

@ -19,7 +19,7 @@ import os
import json import json
import inspect import inspect
from mindspore._c_expression import Oplib from mindspore._c_expression import Oplib
from mindspore._checkparam import ParamValidator as validator from mindspore._checkparam import Validator as validator
# path of built-in op info register. # path of built-in op info register.
BUILT_IN_OPS_REGISTER_PATH = "mindspore/ops/_op_impl" BUILT_IN_OPS_REGISTER_PATH = "mindspore/ops/_op_impl"
@ -43,7 +43,7 @@ def op_info_register(op_info):
op_info_real = json.dumps(op_info) op_info_real = json.dumps(op_info)
else: else:
op_info_real = op_info op_info_real = op_info
validator.check_type("op_info", op_info_real, [str]) validator.check_value_type("op_info", op_info_real, [str], None)
op_lib = Oplib() op_lib = Oplib()
file_path = os.path.realpath(inspect.getfile(func)) file_path = os.path.realpath(inspect.getfile(func))
# keep the path custom ops implementation. # keep the path custom ops implementation.

View File

@ -16,7 +16,7 @@
from easydict import EasyDict as edict from easydict import EasyDict as edict
from .. import nn from .. import nn
from .._checkparam import ParamValidator as validator from .._checkparam import Validator as validator
from .._checkparam import Rel from .._checkparam import Rel
from ..common import dtype as mstype from ..common import dtype as mstype
from ..nn.wrap.cell_wrapper import _VirtualDatasetCell from ..nn.wrap.cell_wrapper import _VirtualDatasetCell
@ -73,14 +73,14 @@ def _check_kwargs(key_words):
raise ValueError(f"Unsupported arg '{arg}'") raise ValueError(f"Unsupported arg '{arg}'")
if 'cast_model_type' in key_words: if 'cast_model_type' in key_words:
validator.check('cast_model_type', key_words['cast_model_type'], validator.check_type_name('cast_model_type', key_words['cast_model_type'],
[mstype.float16, mstype.float32], Rel.IN) [mstype.float16, mstype.float32], None)
if 'keep_batchnorm_fp32' in key_words: if 'keep_batchnorm_fp32' in key_words:
validator.check_isinstance('keep_batchnorm_fp32', key_words['keep_batchnorm_fp32'], bool) validator.check_value_type('keep_batchnorm_fp32', key_words['keep_batchnorm_fp32'], bool, None)
if 'loss_scale_manager' in key_words: if 'loss_scale_manager' in key_words:
loss_scale_manager = key_words['loss_scale_manager'] loss_scale_manager = key_words['loss_scale_manager']
if loss_scale_manager: if loss_scale_manager:
validator.check_isinstance('loss_scale_manager', loss_scale_manager, LossScaleManager) validator.check_value_type('loss_scale_manager', loss_scale_manager, LossScaleManager, None)
def _add_loss_network(network, loss_fn, cast_model_type): def _add_loss_network(network, loss_fn, cast_model_type):
@ -97,7 +97,7 @@ def _add_loss_network(network, loss_fn, cast_model_type):
label = _mp_cast_helper(mstype.float32, label) label = _mp_cast_helper(mstype.float32, label)
return self._loss_fn(F.cast(out, mstype.float32), label) return self._loss_fn(F.cast(out, mstype.float32), label)
validator.check_isinstance('loss_fn', loss_fn, nn.Cell) validator.check_value_type('loss_fn', loss_fn, nn.Cell, None)
if cast_model_type == mstype.float16: if cast_model_type == mstype.float16:
network = WithLossCell(network, loss_fn) network = WithLossCell(network, loss_fn)
else: else:
@ -126,9 +126,9 @@ def build_train_network(network, optimizer, loss_fn=None, level='O0', **kwargs):
loss_scale_manager (Union[None, LossScaleManager]): If None, not scale the loss, or else loss_scale_manager (Union[None, LossScaleManager]): If None, not scale the loss, or else
scale the loss by LossScaleManager. If set, overwrite the level setting. scale the loss by LossScaleManager. If set, overwrite the level setting.
""" """
validator.check_isinstance('network', network, nn.Cell) validator.check_value_type('network', network, nn.Cell, None)
validator.check_isinstance('optimizer', optimizer, nn.Optimizer) validator.check_value_type('optimizer', optimizer, nn.Optimizer, None)
validator.check('level', level, "", ['O0', 'O2'], Rel.IN) validator.check('level', level, "", ['O0', 'O2'], Rel.IN, None)
_check_kwargs(kwargs) _check_kwargs(kwargs)
config = dict(_config_level[level], **kwargs) config = dict(_config_level[level], **kwargs)
config = edict(config) config = edict(config)

View File

@ -13,7 +13,7 @@
# limitations under the License. # limitations under the License.
# ============================================================================ # ============================================================================
"""Loss scale manager abstract class.""" """Loss scale manager abstract class."""
from .._checkparam import ParamValidator as validator from .._checkparam import Validator as validator
from .._checkparam import Rel from .._checkparam import Rel
from .. import nn from .. import nn
@ -97,7 +97,7 @@ class DynamicLossScaleManager(LossScaleManager):
if init_loss_scale < 1.0: if init_loss_scale < 1.0:
raise ValueError("Loss scale value should be > 1") raise ValueError("Loss scale value should be > 1")
self.loss_scale = init_loss_scale self.loss_scale = init_loss_scale
validator.check_integer("scale_window", scale_window, 0, Rel.GT) validator.check_integer("scale_window", scale_window, 0, Rel.GT, self.__class__.__name__)
self.scale_window = scale_window self.scale_window = scale_window
if scale_factor <= 0: if scale_factor <= 0:
raise ValueError("Scale factor should be > 1") raise ValueError("Scale factor should be > 1")

View File

@ -32,7 +32,7 @@ power = 0.5
class TestInputs: class TestInputs:
def test_milestone1(self): def test_milestone1(self):
milestone1 = 1 milestone1 = 1
with pytest.raises(ValueError): with pytest.raises(TypeError):
dr.piecewise_constant_lr(milestone1, learning_rates) dr.piecewise_constant_lr(milestone1, learning_rates)
def test_milestone2(self): def test_milestone2(self):
@ -46,12 +46,12 @@ class TestInputs:
def test_learning_rates1(self): def test_learning_rates1(self):
lr = True lr = True
with pytest.raises(ValueError): with pytest.raises(TypeError):
dr.piecewise_constant_lr(milestone, lr) dr.piecewise_constant_lr(milestone, lr)
def test_learning_rates2(self): def test_learning_rates2(self):
lr = [1, 2, 1] lr = [1, 2, 1]
with pytest.raises(ValueError): with pytest.raises(TypeError):
dr.piecewise_constant_lr(milestone, lr) dr.piecewise_constant_lr(milestone, lr)
def test_learning_rate_type(self): def test_learning_rate_type(self):
@ -158,7 +158,7 @@ class TestInputs:
def test_is_stair(self): def test_is_stair(self):
is_stair = 1 is_stair = 1
with pytest.raises(ValueError): with pytest.raises(TypeError):
dr.exponential_decay_lr(learning_rate, decay_rate, total_step, step_per_epoch, decay_epoch, is_stair) dr.exponential_decay_lr(learning_rate, decay_rate, total_step, step_per_epoch, decay_epoch, is_stair)
def test_min_lr_type(self): def test_min_lr_type(self):
@ -183,12 +183,12 @@ class TestInputs:
def test_power(self): def test_power(self):
power1 = True power1 = True
with pytest.raises(ValueError): with pytest.raises(TypeError):
dr.polynomial_decay_lr(learning_rate, end_learning_rate, total_step, step_per_epoch, decay_epoch, power1) dr.polynomial_decay_lr(learning_rate, end_learning_rate, total_step, step_per_epoch, decay_epoch, power1)
def test_update_decay_epoch(self): def test_update_decay_epoch(self):
update_decay_epoch = 1 update_decay_epoch = 1
with pytest.raises(ValueError): with pytest.raises(TypeError):
dr.polynomial_decay_lr(learning_rate, end_learning_rate, total_step, step_per_epoch, decay_epoch, dr.polynomial_decay_lr(learning_rate, end_learning_rate, total_step, step_per_epoch, decay_epoch,
power, update_decay_epoch) power, update_decay_epoch)

View File

@ -52,7 +52,7 @@ def test_psnr_max_val_negative():
def test_psnr_max_val_bool(): def test_psnr_max_val_bool():
max_val = True max_val = True
with pytest.raises(ValueError): with pytest.raises(TypeError):
net = PSNRNet(max_val) net = PSNRNet(max_val)
def test_psnr_max_val_zero(): def test_psnr_max_val_zero():

View File

@ -51,7 +51,7 @@ def test_ssim_max_val_negative():
def test_ssim_max_val_bool(): def test_ssim_max_val_bool():
max_val = True max_val = True
with pytest.raises(ValueError): with pytest.raises(TypeError):
net = SSIMNet(max_val) net = SSIMNet(max_val)
def test_ssim_max_val_zero(): def test_ssim_max_val_zero():

View File

@ -577,14 +577,14 @@ test_cases_for_verify_exception = [
('MaxPool2d_ValueError_2', { ('MaxPool2d_ValueError_2', {
'block': ( 'block': (
lambda _: nn.MaxPool2d(kernel_size=120, stride=True, pad_mode="valid"), lambda _: nn.MaxPool2d(kernel_size=120, stride=True, pad_mode="valid"),
{'exception': ValueError}, {'exception': TypeError},
), ),
'desc_inputs': [Tensor(np.random.randn(32, 3, 112, 112).astype(np.float32).transpose(0, 3, 1, 2))], 'desc_inputs': [Tensor(np.random.randn(32, 3, 112, 112).astype(np.float32).transpose(0, 3, 1, 2))],
}), }),
('MaxPool2d_ValueError_3', { ('MaxPool2d_ValueError_3', {
'block': ( 'block': (
lambda _: nn.MaxPool2d(kernel_size=3, stride=True, pad_mode="valid"), lambda _: nn.MaxPool2d(kernel_size=3, stride=True, pad_mode="valid"),
{'exception': ValueError}, {'exception': TypeError},
), ),
'desc_inputs': [Tensor(np.random.randn(32, 3, 112, 112).astype(np.float32).transpose(0, 3, 1, 2))], 'desc_inputs': [Tensor(np.random.randn(32, 3, 112, 112).astype(np.float32).transpose(0, 3, 1, 2))],
}), }),

View File

@ -38,7 +38,7 @@ def test_avgpool2d_error_input():
""" test_avgpool2d_error_input """ """ test_avgpool2d_error_input """
kernel_size = 5 kernel_size = 5
stride = 2.3 stride = 2.3
with pytest.raises(ValueError): with pytest.raises(TypeError):
nn.AvgPool2d(kernel_size, stride) nn.AvgPool2d(kernel_size, stride)