forked from mindspore-Ecosystem/mindspore
Add cell name to error message
This commit is contained in:
parent
2d31ae97e8
commit
8cbbbd950e
|
@ -17,7 +17,7 @@ import re
|
|||
from enum import Enum
|
||||
from functools import reduce
|
||||
from itertools import repeat
|
||||
from collections import Iterable
|
||||
from collections.abc import Iterable
|
||||
|
||||
import numpy as np
|
||||
from mindspore import log as logger
|
||||
|
@ -98,7 +98,7 @@ class Validator:
|
|||
"""validator for checking input parameters"""
|
||||
|
||||
@staticmethod
|
||||
def check(arg_name, arg_value, value_name, value, rel=Rel.EQ, prim_name=None):
|
||||
def check(arg_name, arg_value, value_name, value, rel=Rel.EQ, prim_name=None, excp_cls=ValueError):
|
||||
"""
|
||||
Method for judging relation between two int values or list/tuple made up of ints.
|
||||
|
||||
|
@ -108,8 +108,8 @@ class Validator:
|
|||
rel_fn = Rel.get_fns(rel)
|
||||
if not rel_fn(arg_value, value):
|
||||
rel_str = Rel.get_strs(rel).format(f'{value_name}: {value}')
|
||||
msg_prefix = f'For {prim_name} the' if prim_name else "The"
|
||||
raise ValueError(f'{msg_prefix} `{arg_name}` should be {rel_str}, but got {arg_value}.')
|
||||
msg_prefix = f'For \'{prim_name}\' the' if prim_name else "The"
|
||||
raise excp_cls(f'{msg_prefix} `{arg_name}` should be {rel_str}, but got {arg_value}.')
|
||||
|
||||
@staticmethod
|
||||
def check_integer(arg_name, arg_value, value, rel, prim_name):
|
||||
|
@ -118,8 +118,17 @@ class Validator:
|
|||
type_mismatch = not isinstance(arg_value, int) or isinstance(arg_value, bool)
|
||||
if type_mismatch or not rel_fn(arg_value, value):
|
||||
rel_str = Rel.get_strs(rel).format(value)
|
||||
raise ValueError(f'For {prim_name} the `{arg_name}` should be an int and must {rel_str},'
|
||||
f' but got {arg_value}.')
|
||||
msg_prefix = f'For \'{prim_name}\' the' if prim_name else "The"
|
||||
raise ValueError(f'{msg_prefix} `{arg_name}` should be an int and must {rel_str}, but got {arg_value}.')
|
||||
return arg_value
|
||||
|
||||
@staticmethod
|
||||
def check_number(arg_name, arg_value, value, rel, prim_name):
|
||||
"""Integer value judgment."""
|
||||
rel_fn = Rel.get_fns(rel)
|
||||
if not rel_fn(arg_value, value):
|
||||
rel_str = Rel.get_strs(rel).format(value)
|
||||
raise ValueError(f'For \'{prim_name}\' the `{arg_name}` must {rel_str}, but got {arg_value}.')
|
||||
return arg_value
|
||||
|
||||
@staticmethod
|
||||
|
@ -133,9 +142,46 @@ class Validator:
|
|||
f' but got {arg_value}.')
|
||||
return arg_value
|
||||
|
||||
@staticmethod
|
||||
def check_number_range(arg_name, arg_value, lower_limit, upper_limit, rel, prim_name):
|
||||
"""Method for checking whether a numeric value is in some range."""
|
||||
rel_fn = Rel.get_fns(rel)
|
||||
if not rel_fn(arg_value, lower_limit, upper_limit):
|
||||
rel_str = Rel.get_strs(rel).format(lower_limit, upper_limit)
|
||||
raise ValueError(f'For \'{prim_name}\' the `{arg_name}` should be in range {rel_str}, but got {arg_value}.')
|
||||
return arg_value
|
||||
|
||||
@staticmethod
|
||||
def check_string(arg_name, arg_value, valid_values, prim_name):
|
||||
"""Checks whether a string is in some value list"""
|
||||
if isinstance(arg_value, str) and arg_value in valid_values:
|
||||
return arg_value
|
||||
if len(valid_values) == 1:
|
||||
raise ValueError(f'For \'{prim_name}\' the `{arg_name}` should be str and must be {valid_values[0]},'
|
||||
f' but got {arg_value}.')
|
||||
raise ValueError(f'For \'{prim_name}\' the `{arg_name}` should be str and must be one of {valid_values},'
|
||||
f' but got {arg_value}.')
|
||||
|
||||
@staticmethod
|
||||
def check_pad_value_by_mode(pad_mode, padding, prim_name):
|
||||
"""Validates value of padding according to pad_mode"""
|
||||
if pad_mode != 'pad' and padding != 0:
|
||||
raise ValueError(f"For '{prim_name}', padding must be zero when pad_mode is '{pad_mode}'.")
|
||||
return padding
|
||||
|
||||
@staticmethod
|
||||
def check_float_positive(arg_name, arg_value, prim_name):
|
||||
"""Float type judgment."""
|
||||
msg_prefix = f'For \'{prim_name}\' the' if prim_name else "The"
|
||||
if isinstance(arg_value, float):
|
||||
if arg_value > 0:
|
||||
return arg_value
|
||||
raise ValueError(f"{msg_prefix} `{arg_name}` must be positive, but got {arg_value}.")
|
||||
raise TypeError(f"{msg_prefix} `{arg_name}` must be float.")
|
||||
|
||||
@staticmethod
|
||||
def check_subclass(arg_name, type_, template_type, prim_name):
|
||||
"""Check whether some type is sublcass of another type"""
|
||||
"""Checks whether some type is sublcass of another type"""
|
||||
if not isinstance(template_type, Iterable):
|
||||
template_type = (template_type,)
|
||||
if not any([mstype.issubclass_(type_, x) for x in template_type]):
|
||||
|
@ -143,16 +189,44 @@ class Validator:
|
|||
raise TypeError(f'For \'{prim_name}\' the type of `{arg_name}` should be subclass'
|
||||
f' of {",".join((str(x) for x in template_type))}, but got {type_str}.')
|
||||
|
||||
@staticmethod
|
||||
def check_const_input(arg_name, arg_value, prim_name):
|
||||
"""Check valid value."""
|
||||
if arg_value is None:
|
||||
raise ValueError(f'For \'{prim_name}\' the `{arg_name}` must be a const input, but got {arg_value}.')
|
||||
|
||||
@staticmethod
|
||||
def check_scalar_type_same(args, valid_values, prim_name):
|
||||
"""check whether the types of inputs are the same."""
|
||||
def _check_tensor_type(arg):
|
||||
arg_key, arg_val = arg
|
||||
elem_type = arg_val
|
||||
if not elem_type in valid_values:
|
||||
raise TypeError(f'For \'{prim_name}\' type of `{arg_key}` should be in {valid_values},'
|
||||
f' but `{arg_key}` is {elem_type}.')
|
||||
return (arg_key, elem_type)
|
||||
|
||||
def _check_types_same(arg1, arg2):
|
||||
arg1_name, arg1_type = arg1
|
||||
arg2_name, arg2_type = arg2
|
||||
if arg1_type != arg2_type:
|
||||
raise TypeError(f'For \'{prim_name}\' type of `{arg2_name}` should be same as `{arg1_name}`,'
|
||||
f' but `{arg1_name}` is {arg1_type} and `{arg2_name}` is {arg2_type}.')
|
||||
return arg1
|
||||
|
||||
elem_types = map(_check_tensor_type, args.items())
|
||||
reduce(_check_types_same, elem_types)
|
||||
|
||||
@staticmethod
|
||||
def check_tensor_type_same(args, valid_values, prim_name):
|
||||
"""check whether the element types of input tensors are the same."""
|
||||
"""Checks whether the element types of input tensors are the same."""
|
||||
def _check_tensor_type(arg):
|
||||
arg_key, arg_val = arg
|
||||
Validator.check_subclass(arg_key, arg_val, mstype.tensor, prim_name)
|
||||
elem_type = arg_val.element_type()
|
||||
if not elem_type in valid_values:
|
||||
raise TypeError(f'For \'{prim_name}\' element type of `{arg_key}` should be in {valid_values},'
|
||||
f' but `{arg_key}` is {elem_type}.')
|
||||
f' but element type of `{arg_key}` is {elem_type}.')
|
||||
return (arg_key, elem_type)
|
||||
|
||||
def _check_types_same(arg1, arg2):
|
||||
|
@ -168,8 +242,13 @@ class Validator:
|
|||
|
||||
|
||||
@staticmethod
|
||||
def check_scalar_or_tensor_type_same(args, valid_values, prim_name):
|
||||
"""check whether the types of inputs are the same. if the input args are tensors, check their element types"""
|
||||
def check_scalar_or_tensor_type_same(args, valid_values, prim_name, allow_mix=False):
|
||||
"""
|
||||
Checks whether the types of inputs are the same. If the input args are tensors, checks their element types.
|
||||
|
||||
If `allow_mix` is True, Tensor(float32) and float32 are type compatible, otherwise an exception will be raised.
|
||||
"""
|
||||
|
||||
def _check_argument_type(arg):
|
||||
arg_key, arg_val = arg
|
||||
if isinstance(arg_val, type(mstype.tensor)):
|
||||
|
@ -188,6 +267,9 @@ class Validator:
|
|||
arg2_type = arg2_type.element_type()
|
||||
elif not (isinstance(arg1_type, type(mstype.tensor)) or isinstance(arg2_type, type(mstype.tensor))):
|
||||
pass
|
||||
elif allow_mix:
|
||||
arg1_type = arg1_type.element_type() if isinstance(arg1_type, type(mstype.tensor)) else arg1_type
|
||||
arg2_type = arg2_type.element_type() if isinstance(arg2_type, type(mstype.tensor)) else arg2_type
|
||||
else:
|
||||
excp_flag = True
|
||||
|
||||
|
@ -199,13 +281,14 @@ class Validator:
|
|||
|
||||
@staticmethod
|
||||
def check_value_type(arg_name, arg_value, valid_types, prim_name):
|
||||
"""Check whether a values is instance of some types."""
|
||||
"""Checks whether a value is instance of some types."""
|
||||
valid_types = valid_types if isinstance(valid_types, Iterable) else (valid_types,)
|
||||
def raise_error_msg():
|
||||
"""func for raising error message when check failed"""
|
||||
type_names = [t.__name__ for t in valid_types]
|
||||
num_types = len(valid_types)
|
||||
raise TypeError(f'For \'{prim_name}\' the type of `{arg_name}` should be '
|
||||
f'{"one of " if num_types > 1 else ""}'
|
||||
msg_prefix = f'For \'{prim_name}\' the' if prim_name else 'The'
|
||||
raise TypeError(f'{msg_prefix} type of `{arg_name}` should be {"one of " if num_types > 1 else ""}'
|
||||
f'{type_names if num_types > 1 else type_names[0]}, but got {type(arg_value).__name__}.')
|
||||
|
||||
# Notice: bool is subclass of int, so `check_value_type('x', True, [int])` will check fail, and
|
||||
|
@ -216,6 +299,23 @@ class Validator:
|
|||
return arg_value
|
||||
raise_error_msg()
|
||||
|
||||
@staticmethod
|
||||
def check_type_name(arg_name, arg_type, valid_types, prim_name):
|
||||
"""Checks whether a type in some specified types"""
|
||||
valid_types = valid_types if isinstance(valid_types, Iterable) else (valid_types,)
|
||||
def get_typename(t):
|
||||
return t.__name__ if hasattr(t, '__name__') else str(t)
|
||||
|
||||
if arg_type in valid_types:
|
||||
return arg_type
|
||||
type_names = [get_typename(t) for t in valid_types]
|
||||
msg_prefix = f'For \'{prim_name}\' the' if prim_name else 'The'
|
||||
if len(valid_types) == 1:
|
||||
raise ValueError(f'{msg_prefix} type of `{arg_name}` should be {type_names[0]},'
|
||||
f' but got {get_typename(arg_type)}.')
|
||||
raise ValueError(f'{msg_prefix} type of `{arg_name}` should be one of {type_names},'
|
||||
f' but got {get_typename(arg_type)}.')
|
||||
|
||||
|
||||
class ParamValidator:
|
||||
"""Parameter validator. NOTICE: this class will be replaced by `class Validator`"""
|
||||
|
|
|
@ -103,6 +103,10 @@ class Cell:
|
|||
def parameter_layout_dict(self):
|
||||
return self._parameter_layout_dict
|
||||
|
||||
@property
|
||||
def cls_name(self):
|
||||
return self.__class__.__name__
|
||||
|
||||
@parameter_layout_dict.setter
|
||||
def parameter_layout_dict(self, value):
|
||||
if not isinstance(value, dict):
|
||||
|
|
|
@ -15,7 +15,7 @@
|
|||
"""dynamic learning rate"""
|
||||
import math
|
||||
|
||||
from mindspore._checkparam import ParamValidator as validator
|
||||
from mindspore._checkparam import Validator as validator
|
||||
from mindspore._checkparam import Rel
|
||||
|
||||
|
||||
|
@ -43,16 +43,16 @@ def piecewise_constant_lr(milestone, learning_rates):
|
|||
>>> lr = piecewise_constant_lr(milestone, learning_rates)
|
||||
[0.1, 0.1, 0.05, 0.05, 0.05, 0.01, 0.01, 0.01, 0.01, 0.01]
|
||||
"""
|
||||
validator.check_type('milestone', milestone, (tuple, list))
|
||||
validator.check_type('learning_rates', learning_rates, (tuple, list))
|
||||
validator.check_value_type('milestone', milestone, (tuple, list), None)
|
||||
validator.check_value_type('learning_rates', learning_rates, (tuple, list), None)
|
||||
if len(milestone) != len(learning_rates):
|
||||
raise ValueError('The size of `milestone` must be same with the size of `learning_rates`.')
|
||||
|
||||
lr = []
|
||||
last_item = 0
|
||||
for i, item in enumerate(milestone):
|
||||
validator.check_integer(f'milestone[{i}]', item, 0, Rel.GT)
|
||||
validator.check_type(f'learning_rates[{i}]', learning_rates[i], [float])
|
||||
validator.check_integer(f'milestone[{i}]', item, 0, Rel.GT, None)
|
||||
validator.check_value_type(f'learning_rates[{i}]', learning_rates[i], [float], None)
|
||||
if item < last_item:
|
||||
raise ValueError(f'The value of milestone[{i}] must be greater than milestone[{i - 1}]')
|
||||
lr += [learning_rates[i]] * (item - last_item)
|
||||
|
@ -62,12 +62,12 @@ def piecewise_constant_lr(milestone, learning_rates):
|
|||
|
||||
|
||||
def _check_inputs(learning_rate, decay_rate, total_step, step_per_epoch, decay_epoch, is_stair):
|
||||
validator.check_integer('total_step', total_step, 0, Rel.GT)
|
||||
validator.check_integer('step_per_epoch', step_per_epoch, 0, Rel.GT)
|
||||
validator.check_integer('decay_epoch', decay_epoch, 0, Rel.GT)
|
||||
validator.check_float_positive('learning_rate', learning_rate)
|
||||
validator.check_float_positive('decay_rate', decay_rate)
|
||||
validator.check_type('is_stair', is_stair, [bool])
|
||||
validator.check_integer('total_step', total_step, 0, Rel.GT, None)
|
||||
validator.check_integer('step_per_epoch', step_per_epoch, 0, Rel.GT, None)
|
||||
validator.check_integer('decay_epoch', decay_epoch, 0, Rel.GT, None)
|
||||
validator.check_float_positive('learning_rate', learning_rate, None)
|
||||
validator.check_float_positive('decay_rate', decay_rate, None)
|
||||
validator.check_value_type('is_stair', is_stair, [bool], None)
|
||||
|
||||
|
||||
def exponential_decay_lr(learning_rate, decay_rate, total_step, step_per_epoch, decay_epoch, is_stair=False):
|
||||
|
@ -228,11 +228,11 @@ def cosine_decay_lr(min_lr, max_lr, total_step, step_per_epoch, decay_epoch):
|
|||
>>> lr = cosine_decay_lr(min_lr, max_lr, total_step, step_per_epoch, decay_epoch)
|
||||
[0.1, 0.1, 0.05500000000000001, 0.05500000000000001, 0.01, 0.01]
|
||||
"""
|
||||
validator.check_float_positive('min_lr', min_lr)
|
||||
validator.check_float_positive('max_lr', max_lr)
|
||||
validator.check_integer('total_step', total_step, 0, Rel.GT)
|
||||
validator.check_integer('step_per_epoch', step_per_epoch, 0, Rel.GT)
|
||||
validator.check_integer('decay_epoch', decay_epoch, 0, Rel.GT)
|
||||
validator.check_float_positive('min_lr', min_lr, None)
|
||||
validator.check_float_positive('max_lr', max_lr, None)
|
||||
validator.check_integer('total_step', total_step, 0, Rel.GT, None)
|
||||
validator.check_integer('step_per_epoch', step_per_epoch, 0, Rel.GT, None)
|
||||
validator.check_integer('decay_epoch', decay_epoch, 0, Rel.GT, None)
|
||||
|
||||
delta = 0.5 * (max_lr - min_lr)
|
||||
lr = []
|
||||
|
@ -279,13 +279,13 @@ def polynomial_decay_lr(learning_rate, end_learning_rate, total_step, step_per_e
|
|||
>>> lr = polynomial_decay_lr(learning_rate, end_learning_rate, total_step, step_per_epoch, decay_epoch, power)
|
||||
[0.1, 0.1, 0.07363961030678928, 0.07363961030678928, 0.01, 0.01]
|
||||
"""
|
||||
validator.check_float_positive('learning_rate', learning_rate)
|
||||
validator.check_float_positive('end_learning_rate', end_learning_rate)
|
||||
validator.check_integer('total_step', total_step, 0, Rel.GT)
|
||||
validator.check_integer('step_per_epoch', step_per_epoch, 0, Rel.GT)
|
||||
validator.check_integer('decay_epoch', decay_epoch, 0, Rel.GT)
|
||||
validator.check_type('power', power, [float])
|
||||
validator.check_type('update_decay_epoch', update_decay_epoch, [bool])
|
||||
validator.check_float_positive('learning_rate', learning_rate, None)
|
||||
validator.check_float_positive('end_learning_rate', end_learning_rate, None)
|
||||
validator.check_integer('total_step', total_step, 0, Rel.GT, None)
|
||||
validator.check_integer('step_per_epoch', step_per_epoch, 0, Rel.GT, None)
|
||||
validator.check_integer('decay_epoch', decay_epoch, 0, Rel.GT, None)
|
||||
validator.check_value_type('power', power, [float], None)
|
||||
validator.check_value_type('update_decay_epoch', update_decay_epoch, [bool], None)
|
||||
|
||||
function = lambda x, y: (x, min(x, y))
|
||||
if update_decay_epoch:
|
||||
|
|
|
@ -25,7 +25,7 @@ from mindspore.common.parameter import Parameter
|
|||
from mindspore._extends import cell_attr_register
|
||||
from ..cell import Cell
|
||||
from .activation import get_activation
|
||||
from ..._checkparam import ParamValidator as validator
|
||||
from ..._checkparam import Validator as validator
|
||||
|
||||
|
||||
class Dropout(Cell):
|
||||
|
@ -73,7 +73,7 @@ class Dropout(Cell):
|
|||
super(Dropout, self).__init__()
|
||||
if keep_prob <= 0 or keep_prob > 1:
|
||||
raise ValueError("dropout probability should be a number in range (0, 1], but got {}".format(keep_prob))
|
||||
validator.check_subclass("dtype", dtype, mstype.number_type)
|
||||
validator.check_subclass("dtype", dtype, mstype.number_type, self.cls_name)
|
||||
self.keep_prob = Tensor(keep_prob)
|
||||
self.seed0 = seed0
|
||||
self.seed1 = seed1
|
||||
|
@ -421,7 +421,7 @@ class Pad(Cell):
|
|||
super(Pad, self).__init__()
|
||||
self.mode = mode
|
||||
self.paddings = paddings
|
||||
validator.check_string('mode', self.mode, ["CONSTANT", "REFLECT", "SYMMETRIC"])
|
||||
validator.check_string('mode', self.mode, ["CONSTANT", "REFLECT", "SYMMETRIC"], self.cls_name)
|
||||
if not isinstance(paddings, tuple):
|
||||
raise TypeError('Paddings must be tuple type.')
|
||||
for item in paddings:
|
||||
|
|
|
@ -19,7 +19,7 @@ from mindspore.ops import operations as P
|
|||
from mindspore.common.parameter import Parameter
|
||||
from mindspore.common.initializer import initializer
|
||||
from ..cell import Cell
|
||||
from ..._checkparam import ParamValidator as validator
|
||||
from ..._checkparam import Validator as validator
|
||||
|
||||
|
||||
class Embedding(Cell):
|
||||
|
@ -59,7 +59,7 @@ class Embedding(Cell):
|
|||
"""
|
||||
def __init__(self, vocab_size, embedding_size, use_one_hot=False, embedding_table='normal', dtype=mstype.float32):
|
||||
super(Embedding, self).__init__()
|
||||
validator.check_subclass("dtype", dtype, mstype.number_type)
|
||||
validator.check_subclass("dtype", dtype, mstype.number_type, self.cls_name)
|
||||
self.vocab_size = vocab_size
|
||||
self.embedding_size = embedding_size
|
||||
self.use_one_hot = use_one_hot
|
||||
|
|
|
@ -19,7 +19,7 @@ from mindspore.common.tensor import Tensor
|
|||
from mindspore.ops import operations as P
|
||||
from mindspore.ops import functional as F
|
||||
from mindspore.ops.primitive import constexpr
|
||||
from mindspore._checkparam import ParamValidator as validator
|
||||
from mindspore._checkparam import Validator as validator
|
||||
from mindspore._checkparam import Rel
|
||||
from ..cell import Cell
|
||||
|
||||
|
@ -134,15 +134,15 @@ class SSIM(Cell):
|
|||
"""
|
||||
def __init__(self, max_val=1.0, filter_size=11, filter_sigma=1.5, k1=0.01, k2=0.03):
|
||||
super(SSIM, self).__init__()
|
||||
validator.check_type('max_val', max_val, [int, float])
|
||||
validator.check('max_val', max_val, '', 0.0, Rel.GT)
|
||||
validator.check_value_type('max_val', max_val, [int, float], self.cls_name)
|
||||
validator.check_number('max_val', max_val, 0.0, Rel.GT, self.cls_name)
|
||||
self.max_val = max_val
|
||||
self.filter_size = validator.check_integer('filter_size', filter_size, 1, Rel.GE)
|
||||
self.filter_sigma = validator.check_float_positive('filter_sigma', filter_sigma)
|
||||
validator.check_type('k1', k1, [float])
|
||||
self.k1 = validator.check_number_range('k1', k1, 0.0, 1.0, Rel.INC_NEITHER)
|
||||
validator.check_type('k2', k2, [float])
|
||||
self.k2 = validator.check_number_range('k2', k2, 0.0, 1.0, Rel.INC_NEITHER)
|
||||
self.filter_size = validator.check_integer('filter_size', filter_size, 1, Rel.GE, self.cls_name)
|
||||
self.filter_sigma = validator.check_float_positive('filter_sigma', filter_sigma, self.cls_name)
|
||||
validator.check_value_type('k1', k1, [float], self.cls_name)
|
||||
self.k1 = validator.check_number_range('k1', k1, 0.0, 1.0, Rel.INC_NEITHER, self.cls_name)
|
||||
validator.check_value_type('k2', k2, [float], self.cls_name)
|
||||
self.k2 = validator.check_number_range('k2', k2, 0.0, 1.0, Rel.INC_NEITHER, self.cls_name)
|
||||
self.mean = P.DepthwiseConv2dNative(channel_multiplier=1, kernel_size=filter_size)
|
||||
|
||||
def construct(self, img1, img2):
|
||||
|
@ -231,8 +231,8 @@ class PSNR(Cell):
|
|||
"""
|
||||
def __init__(self, max_val=1.0):
|
||||
super(PSNR, self).__init__()
|
||||
validator.check_type('max_val', max_val, [int, float])
|
||||
validator.check('max_val', max_val, '', 0.0, Rel.GT)
|
||||
validator.check_value_type('max_val', max_val, [int, float], self.cls_name)
|
||||
validator.check_number('max_val', max_val, 0.0, Rel.GT, self.cls_name)
|
||||
self.max_val = max_val
|
||||
|
||||
def construct(self, img1, img2):
|
||||
|
|
|
@ -17,7 +17,7 @@ from mindspore.ops import operations as P
|
|||
from mindspore.nn.cell import Cell
|
||||
from mindspore.common.parameter import Parameter
|
||||
from mindspore.common.initializer import initializer
|
||||
from mindspore._checkparam import ParamValidator as validator
|
||||
from mindspore._checkparam import Validator as validator
|
||||
|
||||
|
||||
class LSTM(Cell):
|
||||
|
@ -114,7 +114,7 @@ class LSTM(Cell):
|
|||
self.hidden_size = hidden_size
|
||||
self.num_layers = num_layers
|
||||
self.has_bias = has_bias
|
||||
self.batch_first = validator.check_type("batch_first", batch_first, [bool])
|
||||
self.batch_first = validator.check_value_type("batch_first", batch_first, [bool], self.cls_name)
|
||||
self.dropout = float(dropout)
|
||||
self.bidirectional = bidirectional
|
||||
|
||||
|
|
|
@ -14,8 +14,7 @@
|
|||
# ============================================================================
|
||||
"""pooling"""
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore._checkparam import ParamValidator as validator
|
||||
from mindspore._checkparam import Rel
|
||||
from mindspore._checkparam import Validator as validator
|
||||
from ... import context
|
||||
from ..cell import Cell
|
||||
|
||||
|
@ -24,35 +23,27 @@ class _PoolNd(Cell):
|
|||
"""N-D AvgPool"""
|
||||
|
||||
def __init__(self, kernel_size, stride, pad_mode):
|
||||
name = self.__class__.__name__
|
||||
super(_PoolNd, self).__init__()
|
||||
validator.check_type('kernel_size', kernel_size, [int, tuple])
|
||||
validator.check_type('stride', stride, [int, tuple])
|
||||
self.pad_mode = validator.check_string('pad_mode', pad_mode.upper(), ['VALID', 'SAME'])
|
||||
self.pad_mode = validator.check_string('pad_mode', pad_mode.upper(), ['VALID', 'SAME'], self.cls_name)
|
||||
|
||||
if isinstance(kernel_size, int):
|
||||
validator.check_integer("kernel_size", kernel_size, 1, Rel.GE)
|
||||
else:
|
||||
if (len(kernel_size) != 2 or
|
||||
(not isinstance(kernel_size[0], int)) or
|
||||
(not isinstance(kernel_size[1], int)) or
|
||||
kernel_size[0] <= 0 or
|
||||
kernel_size[1] <= 0):
|
||||
raise ValueError(f'The kernel_size passed to cell {name} should be an positive int number or'
|
||||
f'a tuple of two positive int numbers, but got {kernel_size}')
|
||||
self.kernel_size = kernel_size
|
||||
def _check_int_or_tuple(arg_name, arg_value):
|
||||
validator.check_value_type(arg_name, arg_value, [int, tuple], self.cls_name)
|
||||
error_msg = f'For \'{self.cls_name}\' the {arg_name} should be an positive int number or ' \
|
||||
f'a tuple of two positive int numbers, but got {arg_value}'
|
||||
if isinstance(arg_value, int):
|
||||
if arg_value <= 0:
|
||||
raise ValueError(error_msg)
|
||||
elif len(arg_value) == 2:
|
||||
for item in arg_value:
|
||||
if isinstance(item, int) and item > 0:
|
||||
continue
|
||||
raise ValueError(error_msg)
|
||||
else:
|
||||
raise ValueError(error_msg)
|
||||
return arg_value
|
||||
|
||||
if isinstance(stride, int):
|
||||
validator.check_integer("stride", stride, 1, Rel.GE)
|
||||
else:
|
||||
if (len(stride) != 2 or
|
||||
(not isinstance(stride[0], int)) or
|
||||
(not isinstance(stride[1], int)) or
|
||||
stride[0] <= 0 or
|
||||
stride[1] <= 0):
|
||||
raise ValueError(f'The stride passed to cell {name} should be an positive int number or'
|
||||
f'a tuple of two positive int numbers, but got {stride}')
|
||||
self.stride = stride
|
||||
self.kernel_size = _check_int_or_tuple('kernel_size', kernel_size)
|
||||
self.stride = _check_int_or_tuple('stride', stride)
|
||||
|
||||
def construct(self, *inputs):
|
||||
pass
|
||||
|
|
|
@ -15,7 +15,7 @@
|
|||
"""Fbeta."""
|
||||
import sys
|
||||
import numpy as np
|
||||
from mindspore._checkparam import ParamValidator as validator
|
||||
from mindspore._checkparam import Validator as validator
|
||||
from .metric import Metric
|
||||
|
||||
|
||||
|
@ -104,7 +104,7 @@ class Fbeta(Metric):
|
|||
Returns:
|
||||
Float, computed result.
|
||||
"""
|
||||
validator.check_type("average", average, [bool])
|
||||
validator.check_value_type("average", average, [bool], self.__class__.__name__)
|
||||
if self._class_num == 0:
|
||||
raise RuntimeError('Input number of samples can not be 0.')
|
||||
|
||||
|
|
|
@ -17,7 +17,7 @@ import sys
|
|||
|
||||
import numpy as np
|
||||
|
||||
from mindspore._checkparam import ParamValidator as validator
|
||||
from mindspore._checkparam import Validator as validator
|
||||
from .evaluation import EvaluationBase
|
||||
|
||||
|
||||
|
@ -136,7 +136,7 @@ class Precision(EvaluationBase):
|
|||
if self._class_num == 0:
|
||||
raise RuntimeError('Input number of samples can not be 0.')
|
||||
|
||||
validator.check_type("average", average, [bool])
|
||||
validator.check_value_type("average", average, [bool], self.__class__.__name__)
|
||||
result = self._true_positives / (self._positives + self.eps)
|
||||
|
||||
if average:
|
||||
|
|
|
@ -17,7 +17,7 @@ import sys
|
|||
|
||||
import numpy as np
|
||||
|
||||
from mindspore._checkparam import ParamValidator as validator
|
||||
from mindspore._checkparam import Validator as validator
|
||||
from .evaluation import EvaluationBase
|
||||
|
||||
|
||||
|
@ -136,7 +136,7 @@ class Recall(EvaluationBase):
|
|||
if self._class_num == 0:
|
||||
raise RuntimeError('Input number of samples can not be 0.')
|
||||
|
||||
validator.check_type("average", average, [bool])
|
||||
validator.check_value_type("average", average, [bool], self.__class__.__name__)
|
||||
result = self._true_positives / (self._actual_positives + self.eps)
|
||||
|
||||
if average:
|
||||
|
|
|
@ -22,7 +22,7 @@ from mindspore.ops import composite as C
|
|||
from mindspore.ops import functional as F
|
||||
from mindspore.common.parameter import Parameter
|
||||
from mindspore.common.tensor import Tensor
|
||||
from mindspore._checkparam import ParamValidator as validator
|
||||
from mindspore._checkparam import Validator as validator
|
||||
from mindspore._checkparam import Rel
|
||||
from .optimizer import Optimizer
|
||||
|
||||
|
@ -78,16 +78,16 @@ def _update_run_op(beta1, beta2, eps, lr, weight_decay_tensor, param, m, v, grad
|
|||
return next_v
|
||||
|
||||
|
||||
def _check_param_value(beta1, beta2, eps, weight_decay):
|
||||
def _check_param_value(beta1, beta2, eps, weight_decay, prim_name):
|
||||
"""Check the type of inputs."""
|
||||
validator.check_type("beta1", beta1, [float])
|
||||
validator.check_type("beta2", beta2, [float])
|
||||
validator.check_type("eps", eps, [float])
|
||||
validator.check_type("weight_dacay", weight_decay, [float])
|
||||
validator.check_number_range("beta1", beta1, 0.0, 1.0, Rel.INC_NEITHER)
|
||||
validator.check_number_range("beta2", beta2, 0.0, 1.0, Rel.INC_NEITHER)
|
||||
validator.check_number_range("eps", eps, 0.0, float("inf"), Rel.INC_NEITHER)
|
||||
validator.check_number_range("weight_decay", weight_decay, 0.0, float("inf"), Rel.INC_LEFT)
|
||||
validator.check_value_type("beta1", beta1, [float], prim_name)
|
||||
validator.check_value_type("beta2", beta2, [float], prim_name)
|
||||
validator.check_value_type("eps", eps, [float], prim_name)
|
||||
validator.check_value_type("weight_dacay", weight_decay, [float], prim_name)
|
||||
validator.check_number_range("beta1", beta1, 0.0, 1.0, Rel.INC_NEITHER, prim_name)
|
||||
validator.check_number_range("beta2", beta2, 0.0, 1.0, Rel.INC_NEITHER, prim_name)
|
||||
validator.check_number_range("eps", eps, 0.0, float("inf"), Rel.INC_NEITHER, prim_name)
|
||||
validator.check_number_range("weight_decay", weight_decay, 0.0, float("inf"), Rel.INC_LEFT, prim_name)
|
||||
|
||||
|
||||
@adam_opt.register("Function", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Number", "Tensor", "Tensor", "Tensor",
|
||||
|
@ -168,11 +168,11 @@ class Adam(Optimizer):
|
|||
use_nesterov=False, weight_decay=0.0, loss_scale=1.0,
|
||||
decay_filter=lambda x: 'beta' not in x.name and 'gamma' not in x.name):
|
||||
super(Adam, self).__init__(learning_rate, params, weight_decay, loss_scale, decay_filter)
|
||||
_check_param_value(beta1, beta2, eps, weight_decay)
|
||||
validator.check_type("use_locking", use_locking, [bool])
|
||||
validator.check_type("use_nesterov", use_nesterov, [bool])
|
||||
validator.check_type("loss_scale", loss_scale, [float])
|
||||
validator.check_number_range("loss_scale", loss_scale, 1.0, float("inf"), Rel.INC_LEFT)
|
||||
_check_param_value(beta1, beta2, eps, weight_decay, self.cls_name)
|
||||
validator.check_value_type("use_locking", use_locking, [bool], self.cls_name)
|
||||
validator.check_value_type("use_nesterov", use_nesterov, [bool], self.cls_name)
|
||||
validator.check_value_type("loss_scale", loss_scale, [float], self.cls_name)
|
||||
validator.check_number_range("loss_scale", loss_scale, 1.0, float("inf"), Rel.INC_LEFT, self.cls_name)
|
||||
|
||||
self.beta1 = Tensor(beta1, mstype.float32)
|
||||
self.beta2 = Tensor(beta2, mstype.float32)
|
||||
|
@ -241,7 +241,7 @@ class AdamWeightDecay(Optimizer):
|
|||
"""
|
||||
def __init__(self, params, learning_rate=1e-3, beta1=0.9, beta2=0.999, eps=1e-6, weight_decay=0.0):
|
||||
super(AdamWeightDecay, self).__init__(learning_rate, params)
|
||||
_check_param_value(beta1, beta2, eps, weight_decay)
|
||||
_check_param_value(beta1, beta2, eps, weight_decay, self.cls_name)
|
||||
self.lr = Tensor(np.array([learning_rate]).astype(np.float32))
|
||||
self.beta1 = Tensor(np.array([beta1]).astype(np.float32))
|
||||
self.beta2 = Tensor(np.array([beta2]).astype(np.float32))
|
||||
|
@ -304,7 +304,7 @@ class AdamWeightDecayDynamicLR(Optimizer):
|
|||
eps=1e-6,
|
||||
weight_decay=0.0):
|
||||
super(AdamWeightDecayDynamicLR, self).__init__(learning_rate, params)
|
||||
_check_param_value(beta1, beta2, eps, weight_decay)
|
||||
_check_param_value(beta1, beta2, eps, weight_decay, self.cls_name)
|
||||
|
||||
# turn them to scalar when me support scalar/tensor mix operations
|
||||
self.global_step = Parameter(initializer(0, [1]), name="global_step")
|
||||
|
|
|
@ -18,7 +18,7 @@ from mindspore.common.initializer import initializer
|
|||
from mindspore.common.parameter import Parameter
|
||||
from mindspore.common import Tensor
|
||||
import mindspore.common.dtype as mstype
|
||||
from mindspore._checkparam import ParamValidator as validator
|
||||
from mindspore._checkparam import Validator as validator
|
||||
from mindspore._checkparam import Rel
|
||||
from .optimizer import Optimizer, apply_decay, grad_scale
|
||||
|
||||
|
@ -30,29 +30,30 @@ def _tensor_run_opt(opt, learning_rate, l1, l2, lr_power, linear, gradient, weig
|
|||
success = F.depend(success, opt(weight, moment, linear, gradient, learning_rate, l1, l2, lr_power))
|
||||
return success
|
||||
|
||||
def _check_param(initial_accum, learning_rate, lr_power, l1, l2, use_locking, loss_scale=1.0, weight_decay=0.0):
|
||||
validator.check_type("initial_accum", initial_accum, [float])
|
||||
validator.check("initial_accum", initial_accum, "", 0.0, Rel.GE)
|
||||
def _check_param(initial_accum, learning_rate, lr_power, l1, l2, use_locking, loss_scale=1.0, weight_decay=0.0,
|
||||
prim_name=None):
|
||||
validator.check_value_type("initial_accum", initial_accum, [float], prim_name)
|
||||
validator.check_number("initial_accum", initial_accum, 0.0, Rel.GE, prim_name)
|
||||
|
||||
validator.check_type("learning_rate", learning_rate, [float])
|
||||
validator.check("learning_rate", learning_rate, "", 0.0, Rel.GT)
|
||||
validator.check_value_type("learning_rate", learning_rate, [float], prim_name)
|
||||
validator.check_number("learning_rate", learning_rate, 0.0, Rel.GT, prim_name)
|
||||
|
||||
validator.check_type("lr_power", lr_power, [float])
|
||||
validator.check("lr_power", lr_power, "", 0.0, Rel.LE)
|
||||
validator.check_value_type("lr_power", lr_power, [float], prim_name)
|
||||
validator.check_number("lr_power", lr_power, 0.0, Rel.LE, prim_name)
|
||||
|
||||
validator.check_type("l1", l1, [float])
|
||||
validator.check("l1", l1, "", 0.0, Rel.GE)
|
||||
validator.check_value_type("l1", l1, [float], prim_name)
|
||||
validator.check_number("l1", l1, 0.0, Rel.GE, prim_name)
|
||||
|
||||
validator.check_type("l2", l2, [float])
|
||||
validator.check("l2", l2, "", 0.0, Rel.GE)
|
||||
validator.check_value_type("l2", l2, [float], prim_name)
|
||||
validator.check_number("l2", l2, 0.0, Rel.GE, prim_name)
|
||||
|
||||
validator.check_type("use_locking", use_locking, [bool])
|
||||
validator.check_value_type("use_locking", use_locking, [bool], prim_name)
|
||||
|
||||
validator.check_type("loss_scale", loss_scale, [float])
|
||||
validator.check("loss_scale", loss_scale, "", 1.0, Rel.GE)
|
||||
validator.check_value_type("loss_scale", loss_scale, [float], prim_name)
|
||||
validator.check_number("loss_scale", loss_scale, 1.0, Rel.GE, prim_name)
|
||||
|
||||
validator.check_type("weight_decay", weight_decay, [float])
|
||||
validator.check("weight_decay", weight_decay, "", 0.0, Rel.GE)
|
||||
validator.check_value_type("weight_decay", weight_decay, [float], prim_name)
|
||||
validator.check_number("weight_decay", weight_decay, 0.0, Rel.GE, prim_name)
|
||||
|
||||
|
||||
class FTRL(Optimizer):
|
||||
|
@ -94,7 +95,8 @@ class FTRL(Optimizer):
|
|||
use_locking=False, loss_scale=1.0, weight_decay=0.0):
|
||||
super(FTRL, self).__init__(learning_rate, params)
|
||||
|
||||
_check_param(initial_accum, learning_rate, lr_power, l1, l2, use_locking, loss_scale, weight_decay)
|
||||
_check_param(initial_accum, learning_rate, lr_power, l1, l2, use_locking, loss_scale, weight_decay,
|
||||
self.cls_name)
|
||||
self.moments = self.parameters.clone(prefix="moments", init=initial_accum)
|
||||
self.linear = self.parameters.clone(prefix="linear", init='zeros')
|
||||
self.l1 = l1
|
||||
|
|
|
@ -21,7 +21,7 @@ from mindspore.ops import composite as C
|
|||
from mindspore.ops import functional as F
|
||||
from mindspore.common.parameter import Parameter
|
||||
from mindspore.common.tensor import Tensor
|
||||
from mindspore._checkparam import ParamValidator as validator
|
||||
from mindspore._checkparam import Validator as validator
|
||||
from mindspore._checkparam import Rel
|
||||
from .optimizer import Optimizer
|
||||
from .. import layer
|
||||
|
@ -109,23 +109,23 @@ def _update_run_op(beta1, beta2, eps, lr, weight_decay_tensor, global_step, para
|
|||
|
||||
|
||||
def _check_param_value(decay_steps, warmup_steps, start_learning_rate,
|
||||
end_learning_rate, power, beta1, beta2, eps, weight_decay):
|
||||
end_learning_rate, power, beta1, beta2, eps, weight_decay, prim_name):
|
||||
|
||||
"""Check the type of inputs."""
|
||||
validator.check_type("decay_steps", decay_steps, [int])
|
||||
validator.check_type("warmup_steps", warmup_steps, [int])
|
||||
validator.check_type("start_learning_rate", start_learning_rate, [float])
|
||||
validator.check_type("end_learning_rate", end_learning_rate, [float])
|
||||
validator.check_type("power", power, [float])
|
||||
validator.check_type("beta1", beta1, [float])
|
||||
validator.check_type("beta2", beta2, [float])
|
||||
validator.check_type("eps", eps, [float])
|
||||
validator.check_type("weight_dacay", weight_decay, [float])
|
||||
validator.check_number_range("decay_steps", decay_steps, 1, float("inf"), Rel.INC_LEFT)
|
||||
validator.check_number_range("beta1", beta1, 0.0, 1.0, Rel.INC_NEITHER)
|
||||
validator.check_number_range("beta2", beta2, 0.0, 1.0, Rel.INC_NEITHER)
|
||||
validator.check_number_range("eps", eps, 0.0, float("inf"), Rel.INC_NEITHER)
|
||||
validator.check_number_range("weight_decay", weight_decay, 0.0, float("inf"), Rel.INC_LEFT)
|
||||
validator.check_value_type("decay_steps", decay_steps, [int], prim_name)
|
||||
validator.check_value_type("warmup_steps", warmup_steps, [int], prim_name)
|
||||
validator.check_value_type("start_learning_rate", start_learning_rate, [float], prim_name)
|
||||
validator.check_value_type("end_learning_rate", end_learning_rate, [float], prim_name)
|
||||
validator.check_value_type("power", power, [float], prim_name)
|
||||
validator.check_value_type("beta1", beta1, [float], prim_name)
|
||||
validator.check_value_type("beta2", beta2, [float], prim_name)
|
||||
validator.check_value_type("eps", eps, [float], prim_name)
|
||||
validator.check_value_type("weight_dacay", weight_decay, [float], prim_name)
|
||||
validator.check_number_range("decay_steps", decay_steps, 1, float("inf"), Rel.INC_LEFT, prim_name)
|
||||
validator.check_number_range("beta1", beta1, 0.0, 1.0, Rel.INC_NEITHER, prim_name)
|
||||
validator.check_number_range("beta2", beta2, 0.0, 1.0, Rel.INC_NEITHER, prim_name)
|
||||
validator.check_number_range("eps", eps, 0.0, float("inf"), Rel.INC_NEITHER, prim_name)
|
||||
validator.check_number_range("weight_decay", weight_decay, 0.0, float("inf"), Rel.INC_LEFT, prim_name)
|
||||
|
||||
|
||||
class Lamb(Optimizer):
|
||||
|
@ -182,7 +182,7 @@ class Lamb(Optimizer):
|
|||
|
||||
super(Lamb, self).__init__(start_learning_rate, params)
|
||||
_check_param_value(decay_steps, warmup_steps, start_learning_rate, end_learning_rate,
|
||||
power, beta1, beta2, eps, weight_decay)
|
||||
power, beta1, beta2, eps, weight_decay, self.cls_name)
|
||||
|
||||
# turn them to scalar when me support scalar/tensor mix operations
|
||||
self.global_step = Parameter(initializer(0, [1]), name="global_step")
|
||||
|
|
|
@ -22,7 +22,7 @@ from mindspore.ops import functional as F, composite as C, operations as P
|
|||
from mindspore.nn.cell import Cell
|
||||
from mindspore.common.parameter import Parameter, ParameterTuple
|
||||
from mindspore.common.initializer import initializer
|
||||
from mindspore._checkparam import ParamValidator as validator
|
||||
from mindspore._checkparam import Validator as validator
|
||||
from mindspore._checkparam import Rel
|
||||
from mindspore.common.tensor import Tensor
|
||||
from mindspore import log as logger
|
||||
|
@ -63,7 +63,7 @@ class Optimizer(Cell):
|
|||
self.gather = None
|
||||
self.assignadd = None
|
||||
self.global_step = None
|
||||
validator.check_number_range("learning rate", learning_rate, 0.0, float("inf"), Rel.INC_LEFT)
|
||||
validator.check_number_range("learning rate", learning_rate, 0.0, float("inf"), Rel.INC_LEFT, self.cls_name)
|
||||
else:
|
||||
self.dynamic_lr = True
|
||||
self.gather = P.GatherV2()
|
||||
|
|
|
@ -14,7 +14,7 @@
|
|||
# ============================================================================
|
||||
"""rmsprop"""
|
||||
from mindspore.ops import functional as F, composite as C, operations as P
|
||||
from mindspore._checkparam import ParamValidator as validator
|
||||
from mindspore._checkparam import Validator as validator
|
||||
from .optimizer import Optimizer
|
||||
|
||||
rmsprop_opt = C.MultitypeFuncGraph("rmsprop_opt")
|
||||
|
@ -144,8 +144,8 @@ class RMSProp(Optimizer):
|
|||
self.decay = decay
|
||||
self.epsilon = epsilon
|
||||
|
||||
validator.check_type("use_locking", use_locking, [bool])
|
||||
validator.check_type("centered", centered, [bool])
|
||||
validator.check_value_type("use_locking", use_locking, [bool], self.cls_name)
|
||||
validator.check_value_type("centered", centered, [bool], self.cls_name)
|
||||
self.centered = centered
|
||||
if centered:
|
||||
self.opt = P.ApplyCenteredRMSProp(use_locking)
|
||||
|
|
|
@ -15,7 +15,7 @@
|
|||
"""sgd"""
|
||||
from mindspore.ops import functional as F, composite as C, operations as P
|
||||
from mindspore.common.parameter import Parameter
|
||||
from mindspore._checkparam import ParamValidator as validator
|
||||
from mindspore._checkparam import Validator as validator
|
||||
from .optimizer import Optimizer
|
||||
|
||||
sgd_opt = C.MultitypeFuncGraph("sgd_opt")
|
||||
|
@ -100,7 +100,7 @@ class SGD(Optimizer):
|
|||
raise ValueError("dampening should be at least 0.0, but got dampening {}".format(dampening))
|
||||
self.dampening = dampening
|
||||
|
||||
validator.check_type("nesterov", nesterov, [bool])
|
||||
validator.check_value_type("nesterov", nesterov, [bool], self.cls_name)
|
||||
self.nesterov = nesterov
|
||||
|
||||
self.opt = P.SGD(dampening, weight_decay, nesterov)
|
||||
|
|
|
@ -19,7 +19,7 @@ import os
|
|||
import json
|
||||
import inspect
|
||||
from mindspore._c_expression import Oplib
|
||||
from mindspore._checkparam import ParamValidator as validator
|
||||
from mindspore._checkparam import Validator as validator
|
||||
|
||||
# path of built-in op info register.
|
||||
BUILT_IN_OPS_REGISTER_PATH = "mindspore/ops/_op_impl"
|
||||
|
@ -43,7 +43,7 @@ def op_info_register(op_info):
|
|||
op_info_real = json.dumps(op_info)
|
||||
else:
|
||||
op_info_real = op_info
|
||||
validator.check_type("op_info", op_info_real, [str])
|
||||
validator.check_value_type("op_info", op_info_real, [str], None)
|
||||
op_lib = Oplib()
|
||||
file_path = os.path.realpath(inspect.getfile(func))
|
||||
# keep the path custom ops implementation.
|
||||
|
|
|
@ -16,7 +16,7 @@
|
|||
from easydict import EasyDict as edict
|
||||
|
||||
from .. import nn
|
||||
from .._checkparam import ParamValidator as validator
|
||||
from .._checkparam import Validator as validator
|
||||
from .._checkparam import Rel
|
||||
from ..common import dtype as mstype
|
||||
from ..nn.wrap.cell_wrapper import _VirtualDatasetCell
|
||||
|
@ -73,14 +73,14 @@ def _check_kwargs(key_words):
|
|||
raise ValueError(f"Unsupported arg '{arg}'")
|
||||
|
||||
if 'cast_model_type' in key_words:
|
||||
validator.check('cast_model_type', key_words['cast_model_type'],
|
||||
[mstype.float16, mstype.float32], Rel.IN)
|
||||
validator.check_type_name('cast_model_type', key_words['cast_model_type'],
|
||||
[mstype.float16, mstype.float32], None)
|
||||
if 'keep_batchnorm_fp32' in key_words:
|
||||
validator.check_isinstance('keep_batchnorm_fp32', key_words['keep_batchnorm_fp32'], bool)
|
||||
validator.check_value_type('keep_batchnorm_fp32', key_words['keep_batchnorm_fp32'], bool, None)
|
||||
if 'loss_scale_manager' in key_words:
|
||||
loss_scale_manager = key_words['loss_scale_manager']
|
||||
if loss_scale_manager:
|
||||
validator.check_isinstance('loss_scale_manager', loss_scale_manager, LossScaleManager)
|
||||
validator.check_value_type('loss_scale_manager', loss_scale_manager, LossScaleManager, None)
|
||||
|
||||
|
||||
def _add_loss_network(network, loss_fn, cast_model_type):
|
||||
|
@ -97,7 +97,7 @@ def _add_loss_network(network, loss_fn, cast_model_type):
|
|||
label = _mp_cast_helper(mstype.float32, label)
|
||||
return self._loss_fn(F.cast(out, mstype.float32), label)
|
||||
|
||||
validator.check_isinstance('loss_fn', loss_fn, nn.Cell)
|
||||
validator.check_value_type('loss_fn', loss_fn, nn.Cell, None)
|
||||
if cast_model_type == mstype.float16:
|
||||
network = WithLossCell(network, loss_fn)
|
||||
else:
|
||||
|
@ -126,9 +126,9 @@ def build_train_network(network, optimizer, loss_fn=None, level='O0', **kwargs):
|
|||
loss_scale_manager (Union[None, LossScaleManager]): If None, not scale the loss, or else
|
||||
scale the loss by LossScaleManager. If set, overwrite the level setting.
|
||||
"""
|
||||
validator.check_isinstance('network', network, nn.Cell)
|
||||
validator.check_isinstance('optimizer', optimizer, nn.Optimizer)
|
||||
validator.check('level', level, "", ['O0', 'O2'], Rel.IN)
|
||||
validator.check_value_type('network', network, nn.Cell, None)
|
||||
validator.check_value_type('optimizer', optimizer, nn.Optimizer, None)
|
||||
validator.check('level', level, "", ['O0', 'O2'], Rel.IN, None)
|
||||
_check_kwargs(kwargs)
|
||||
config = dict(_config_level[level], **kwargs)
|
||||
config = edict(config)
|
||||
|
|
|
@ -13,7 +13,7 @@
|
|||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Loss scale manager abstract class."""
|
||||
from .._checkparam import ParamValidator as validator
|
||||
from .._checkparam import Validator as validator
|
||||
from .._checkparam import Rel
|
||||
from .. import nn
|
||||
|
||||
|
@ -97,7 +97,7 @@ class DynamicLossScaleManager(LossScaleManager):
|
|||
if init_loss_scale < 1.0:
|
||||
raise ValueError("Loss scale value should be > 1")
|
||||
self.loss_scale = init_loss_scale
|
||||
validator.check_integer("scale_window", scale_window, 0, Rel.GT)
|
||||
validator.check_integer("scale_window", scale_window, 0, Rel.GT, self.__class__.__name__)
|
||||
self.scale_window = scale_window
|
||||
if scale_factor <= 0:
|
||||
raise ValueError("Scale factor should be > 1")
|
||||
|
|
|
@ -32,7 +32,7 @@ power = 0.5
|
|||
class TestInputs:
|
||||
def test_milestone1(self):
|
||||
milestone1 = 1
|
||||
with pytest.raises(ValueError):
|
||||
with pytest.raises(TypeError):
|
||||
dr.piecewise_constant_lr(milestone1, learning_rates)
|
||||
|
||||
def test_milestone2(self):
|
||||
|
@ -46,12 +46,12 @@ class TestInputs:
|
|||
|
||||
def test_learning_rates1(self):
|
||||
lr = True
|
||||
with pytest.raises(ValueError):
|
||||
with pytest.raises(TypeError):
|
||||
dr.piecewise_constant_lr(milestone, lr)
|
||||
|
||||
def test_learning_rates2(self):
|
||||
lr = [1, 2, 1]
|
||||
with pytest.raises(ValueError):
|
||||
with pytest.raises(TypeError):
|
||||
dr.piecewise_constant_lr(milestone, lr)
|
||||
|
||||
def test_learning_rate_type(self):
|
||||
|
@ -158,7 +158,7 @@ class TestInputs:
|
|||
|
||||
def test_is_stair(self):
|
||||
is_stair = 1
|
||||
with pytest.raises(ValueError):
|
||||
with pytest.raises(TypeError):
|
||||
dr.exponential_decay_lr(learning_rate, decay_rate, total_step, step_per_epoch, decay_epoch, is_stair)
|
||||
|
||||
def test_min_lr_type(self):
|
||||
|
@ -183,12 +183,12 @@ class TestInputs:
|
|||
|
||||
def test_power(self):
|
||||
power1 = True
|
||||
with pytest.raises(ValueError):
|
||||
with pytest.raises(TypeError):
|
||||
dr.polynomial_decay_lr(learning_rate, end_learning_rate, total_step, step_per_epoch, decay_epoch, power1)
|
||||
|
||||
def test_update_decay_epoch(self):
|
||||
update_decay_epoch = 1
|
||||
with pytest.raises(ValueError):
|
||||
with pytest.raises(TypeError):
|
||||
dr.polynomial_decay_lr(learning_rate, end_learning_rate, total_step, step_per_epoch, decay_epoch,
|
||||
power, update_decay_epoch)
|
||||
|
||||
|
|
|
@ -52,7 +52,7 @@ def test_psnr_max_val_negative():
|
|||
|
||||
def test_psnr_max_val_bool():
|
||||
max_val = True
|
||||
with pytest.raises(ValueError):
|
||||
with pytest.raises(TypeError):
|
||||
net = PSNRNet(max_val)
|
||||
|
||||
def test_psnr_max_val_zero():
|
||||
|
|
|
@ -51,7 +51,7 @@ def test_ssim_max_val_negative():
|
|||
|
||||
def test_ssim_max_val_bool():
|
||||
max_val = True
|
||||
with pytest.raises(ValueError):
|
||||
with pytest.raises(TypeError):
|
||||
net = SSIMNet(max_val)
|
||||
|
||||
def test_ssim_max_val_zero():
|
||||
|
@ -92,4 +92,4 @@ def test_ssim_k1_k2_wrong_value():
|
|||
with pytest.raises(ValueError):
|
||||
net = SSIMNet(k2=0.0)
|
||||
with pytest.raises(ValueError):
|
||||
net = SSIMNet(k2=-1.0)
|
||||
net = SSIMNet(k2=-1.0)
|
||||
|
|
|
@ -577,14 +577,14 @@ test_cases_for_verify_exception = [
|
|||
('MaxPool2d_ValueError_2', {
|
||||
'block': (
|
||||
lambda _: nn.MaxPool2d(kernel_size=120, stride=True, pad_mode="valid"),
|
||||
{'exception': ValueError},
|
||||
{'exception': TypeError},
|
||||
),
|
||||
'desc_inputs': [Tensor(np.random.randn(32, 3, 112, 112).astype(np.float32).transpose(0, 3, 1, 2))],
|
||||
}),
|
||||
('MaxPool2d_ValueError_3', {
|
||||
'block': (
|
||||
lambda _: nn.MaxPool2d(kernel_size=3, stride=True, pad_mode="valid"),
|
||||
{'exception': ValueError},
|
||||
{'exception': TypeError},
|
||||
),
|
||||
'desc_inputs': [Tensor(np.random.randn(32, 3, 112, 112).astype(np.float32).transpose(0, 3, 1, 2))],
|
||||
}),
|
||||
|
|
|
@ -38,7 +38,7 @@ def test_avgpool2d_error_input():
|
|||
""" test_avgpool2d_error_input """
|
||||
kernel_size = 5
|
||||
stride = 2.3
|
||||
with pytest.raises(ValueError):
|
||||
with pytest.raises(TypeError):
|
||||
nn.AvgPool2d(kernel_size, stride)
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue