[ME] delete check_bool and replace with Validate.check_bool
This commit is contained in:
parent
6c9b6d491d
commit
d4e8e94981
|
@ -26,10 +26,6 @@ from mindspore import log as logger
|
|||
from mindspore.common import dtype as mstype
|
||||
|
||||
|
||||
# Named string regular expression
|
||||
_name_re = r"^\w+[0-9a-zA-Z\_\.]*$"
|
||||
|
||||
|
||||
class Rel(Enum):
|
||||
"""Numerical relationship between variables, logical relationship enumeration definition of range."""
|
||||
# scalar compare
|
||||
|
@ -114,7 +110,7 @@ class Validator:
|
|||
|
||||
@staticmethod
|
||||
def check_integer(arg_name, arg_value, value, rel, prim_name=None):
|
||||
"""Integer value judgment."""
|
||||
"""Check argument is integer"""
|
||||
rel_fn = Rel.get_fns(rel)
|
||||
type_mismatch = not isinstance(arg_value, int) or isinstance(arg_value, bool)
|
||||
excp_cls = TypeError if type_mismatch else ValueError
|
||||
|
@ -125,6 +121,7 @@ class Validator:
|
|||
f' with type `{type(arg_value).__name__}`.')
|
||||
return arg_value
|
||||
|
||||
|
||||
@staticmethod
|
||||
def check_number(arg_name, arg_value, value, rel, prim_name):
|
||||
"""Number value judgment."""
|
||||
|
@ -142,10 +139,11 @@ class Validator:
|
|||
return arg_value
|
||||
|
||||
@staticmethod
|
||||
def check_bool(arg_name, arg_value):
|
||||
"""Check arg isinstance of bool"""
|
||||
def check_bool(arg_value, arg_name=None):
|
||||
"""Check argument is instance of bool"""
|
||||
if not isinstance(arg_value, bool):
|
||||
raise ValueError(f'The `{arg_name}` should be isinstance of bool, but got {arg_value}.')
|
||||
arg_name = arg_name if arg_name else "Parameter"
|
||||
raise TypeError(f'`{arg_name}` should be isinstance of bool, but got `{arg_value}`.')
|
||||
return arg_value
|
||||
|
||||
@staticmethod
|
||||
|
@ -170,15 +168,14 @@ class Validator:
|
|||
return arg_value
|
||||
|
||||
@staticmethod
|
||||
def check_string(arg_name, arg_value, valid_values, prim_name):
|
||||
def check_string(arg_value, valid_values, arg_name=None, prim_name=None):
|
||||
"""Checks whether a string is in some value list"""
|
||||
if isinstance(arg_value, str) and arg_value in valid_values:
|
||||
return arg_value
|
||||
if len(valid_values) == 1:
|
||||
raise ValueError(f'For \'{prim_name}\' the `{arg_name}` should be str and must be {valid_values[0]},'
|
||||
f' but got {arg_value}.')
|
||||
raise ValueError(f'For \'{prim_name}\' the `{arg_name}` should be str and must be one of {valid_values},'
|
||||
f' but got {arg_value}.')
|
||||
arg_name = arg_name if arg_name else "Parameter"
|
||||
msg_prefix = f'For \'{prim_name}\' the' if prim_name else "The"
|
||||
raise ValueError(f'{msg_prefix} `{arg_name}` should be str and must be in `{valid_values}`,'
|
||||
f' but got `{arg_value}`.')
|
||||
|
||||
@staticmethod
|
||||
def check_pad_value_by_mode(pad_mode, padding, prim_name):
|
||||
|
@ -404,24 +401,6 @@ def check_int_zero_one(input_param):
|
|||
raise ValueError("The data must be 0 or 1.")
|
||||
|
||||
|
||||
def check_bool(input_param):
|
||||
"""Bool type judgment."""
|
||||
if isinstance(input_param, bool):
|
||||
return input_param
|
||||
raise TypeError("Input type must be bool!")
|
||||
|
||||
|
||||
def check_string(input_param, valid_values):
|
||||
"""String type judgment."""
|
||||
if isinstance(input_param, str) and input_param in valid_values:
|
||||
return input_param
|
||||
if len(valid_values) == 1:
|
||||
raise ValueError(f'Input should be str and must be {valid_values[0]},'
|
||||
f' but got {input_param}.')
|
||||
raise ValueError(f'Input should be str and must be one of {valid_values},'
|
||||
f' but got {input_param}.')
|
||||
|
||||
|
||||
def check_input_format(input_param):
|
||||
"""Judge input format."""
|
||||
if input_param == "NCHW":
|
||||
|
@ -587,7 +566,8 @@ def check_shape(arg_name, arg_value):
|
|||
|
||||
def _check_str_by_regular(target, reg=None, flag=re.ASCII):
|
||||
if reg is None:
|
||||
reg = _name_re
|
||||
# Named string regular expression
|
||||
reg = r"^\w+[0-9a-zA-Z\_\.]*$"
|
||||
if re.match(reg, target, flag) is None:
|
||||
raise ValueError("'{}' is illegal, it should be match regular'{}' by flags'{}'".format(target, reg, flag))
|
||||
return True
|
||||
|
|
|
@ -27,7 +27,7 @@ from mindspore.ops.operations import _inner_ops as inner
|
|||
from mindspore.ops.primitive import constexpr
|
||||
from mindspore.common.parameter import Parameter
|
||||
from mindspore._extends import cell_attr_register
|
||||
from mindspore._checkparam import Rel, Validator as validator, check_int_positive, check_bool
|
||||
from mindspore._checkparam import Rel, Validator, check_int_positive
|
||||
from mindspore.common.api import ms_function
|
||||
from mindspore import context
|
||||
from ..cell import Cell
|
||||
|
@ -86,8 +86,8 @@ class Dropout(Cell):
|
|||
super(Dropout, self).__init__()
|
||||
if keep_prob <= 0 or keep_prob > 1:
|
||||
raise ValueError("dropout probability should be a number in range (0, 1], but got {}".format(keep_prob))
|
||||
validator.check_subclass("dtype", dtype, mstype.number_type, self.cls_name)
|
||||
validator.check_value_type('keep_prob', keep_prob, [float], self.cls_name)
|
||||
Validator.check_subclass("dtype", dtype, mstype.number_type, self.cls_name)
|
||||
Validator.check_value_type('keep_prob', keep_prob, [float], self.cls_name)
|
||||
self.keep_prob = keep_prob
|
||||
seed0 = get_seed()
|
||||
self.seed0 = seed0 if seed0 is not None else 0
|
||||
|
@ -205,7 +205,7 @@ class Dense(Cell):
|
|||
super(Dense, self).__init__()
|
||||
self.in_channels = check_int_positive(in_channels)
|
||||
self.out_channels = check_int_positive(out_channels)
|
||||
self.has_bias = check_bool(has_bias)
|
||||
self.has_bias = Validator.check_bool(has_bias)
|
||||
|
||||
if isinstance(weight_init, Tensor):
|
||||
if weight_init.dim() != 2 or weight_init.shape[0] != out_channels or \
|
||||
|
@ -348,7 +348,7 @@ class Norm(Cell):
|
|||
|
||||
def __init__(self, axis=(), keep_dims=False):
|
||||
super(Norm, self).__init__()
|
||||
validator.check_value_type("keep_dims", keep_dims, [bool], self.cls_name)
|
||||
Validator.check_value_type("keep_dims", keep_dims, [bool], self.cls_name)
|
||||
self.axis = axis
|
||||
self.keep_dims = keep_dims
|
||||
self.reduce_sum = P.ReduceSum(True)
|
||||
|
@ -472,7 +472,7 @@ class Pad(Cell):
|
|||
super(Pad, self).__init__()
|
||||
self.mode = mode
|
||||
self.paddings = paddings
|
||||
validator.check_string('mode', self.mode, ["CONSTANT", "REFLECT", "SYMMETRIC"], self.cls_name)
|
||||
Validator.check_string(self.mode, ["CONSTANT", "REFLECT", "SYMMETRIC"], 'mode', self.cls_name)
|
||||
if not isinstance(paddings, tuple):
|
||||
raise TypeError('Paddings must be tuple type.')
|
||||
for item in paddings:
|
||||
|
@ -549,7 +549,7 @@ class Unfold(Cell):
|
|||
|
||||
@constexpr
|
||||
def _get_matrix_diag_assist(x_shape, x_dtype):
|
||||
validator.check_integer("x rank", len(x_shape), 1, Rel.GE, "_get_matrix_diag_assist")
|
||||
Validator.check_integer("x rank", len(x_shape), 1, Rel.GE, "_get_matrix_diag_assist")
|
||||
base_eye = np.eye(x_shape[-1], x_shape[-1]).reshape(-1)
|
||||
assist = np.tile(base_eye, x_shape[:-1]).reshape(x_shape + (x_shape[-1],))
|
||||
return Tensor(assist, x_dtype)
|
||||
|
@ -557,7 +557,7 @@ def _get_matrix_diag_assist(x_shape, x_dtype):
|
|||
|
||||
@constexpr
|
||||
def _get_matrix_diag_part_assist(x_shape, x_dtype):
|
||||
validator.check_integer("x rank", len(x_shape), 2, Rel.GE, "_get_matrix_diag_part_assist")
|
||||
Validator.check_integer("x rank", len(x_shape), 2, Rel.GE, "_get_matrix_diag_part_assist")
|
||||
base_eye = np.eye(x_shape[-2], x_shape[-1]).reshape(-1)
|
||||
assist = np.tile(base_eye, x_shape[:-2]).reshape(x_shape)
|
||||
return Tensor(assist, x_dtype)
|
||||
|
|
|
@ -21,7 +21,7 @@ from mindspore.ops.primitive import constexpr
|
|||
from mindspore.common.parameter import Parameter
|
||||
from mindspore.common.initializer import initializer, Initializer
|
||||
from mindspore.common.tensor import Tensor
|
||||
from mindspore._checkparam import Validator, Rel, check_bool, twice, check_int_positive
|
||||
from mindspore._checkparam import Validator, Rel, twice, check_int_positive
|
||||
from mindspore._extends import cell_attr_register
|
||||
from ..cell import Cell
|
||||
|
||||
|
@ -92,7 +92,7 @@ class _Conv(Cell):
|
|||
shape = [out_channels, in_channels // group, *kernel_size]
|
||||
self.weight = Parameter(initializer(self.weight_init, shape), name='weight')
|
||||
|
||||
if check_bool(has_bias):
|
||||
if Validator.check_bool(has_bias):
|
||||
self.bias = Parameter(initializer(self.bias_init, [out_channels]), name='bias')
|
||||
else:
|
||||
if self.bias_init != 'zeros':
|
||||
|
@ -566,7 +566,7 @@ class Conv2dTranspose(_Conv):
|
|||
self.is_valid = self.pad_mode == 'valid'
|
||||
self.is_same = self.pad_mode == 'same'
|
||||
self.is_pad = self.pad_mode == 'pad'
|
||||
if check_bool(has_bias):
|
||||
if Validator.check_bool(has_bias):
|
||||
self.bias = Parameter(initializer(bias_init, [out_channels]), name='bias')
|
||||
|
||||
# cause Conv2DBackpropInput's out_channel refers to Conv2D's out_channel.
|
||||
|
@ -745,7 +745,7 @@ class Conv1dTranspose(_Conv):
|
|||
self.is_valid = self.pad_mode == 'valid'
|
||||
self.is_same = self.pad_mode == 'same'
|
||||
self.is_pad = self.pad_mode == 'pad'
|
||||
if check_bool(has_bias):
|
||||
if Validator.check_bool(has_bias):
|
||||
self.bias = Parameter(initializer(bias_init, [out_channels]), name='bias')
|
||||
|
||||
# cause Conv2DBackpropInput's out_channel refers to Conv2D's out_channel.
|
||||
|
|
|
@ -19,7 +19,7 @@ from mindspore.common.parameter import Parameter
|
|||
from mindspore.common.initializer import initializer
|
||||
from mindspore.ops.primitive import constexpr
|
||||
import mindspore.context as context
|
||||
from mindspore._checkparam import check_bool, check_typename, check_int_positive
|
||||
from mindspore._checkparam import Validator, check_typename, check_int_positive
|
||||
from mindspore._extends import cell_attr_register
|
||||
from mindspore.communication.management import get_group_size, get_rank
|
||||
from mindspore.communication import management
|
||||
|
@ -604,7 +604,7 @@ class GroupNorm(Cell):
|
|||
if num_channels % num_groups != 0:
|
||||
raise ValueError("num_channels should be divided by num_groups")
|
||||
self.eps = check_typename('eps', eps, (float,))
|
||||
self.affine = check_bool(affine)
|
||||
self.affine = Validator.check_bool(affine)
|
||||
|
||||
gamma = initializer(gamma_init, num_channels)
|
||||
beta = initializer(beta_init, num_channels)
|
||||
|
|
|
@ -27,7 +27,7 @@ class _PoolNd(Cell):
|
|||
|
||||
def __init__(self, kernel_size, stride, pad_mode):
|
||||
super(_PoolNd, self).__init__()
|
||||
self.pad_mode = validator.check_string('pad_mode', pad_mode.upper(), ['VALID', 'SAME'], self.cls_name)
|
||||
self.pad_mode = validator.check_string(pad_mode.upper(), ['VALID', 'SAME'], 'pad_mode', self.cls_name)
|
||||
|
||||
def _check_int_or_tuple(arg_name, arg_value):
|
||||
validator.check_value_type(arg_name, arg_value, [int, tuple], self.cls_name)
|
||||
|
@ -270,7 +270,7 @@ class AvgPool1d(_PoolNd):
|
|||
super(AvgPool1d, self).__init__(kernel_size, stride, pad_mode)
|
||||
validator.check_value_type('kernel_size', kernel_size, [int], self.cls_name)
|
||||
validator.check_value_type('stride', stride, [int], self.cls_name)
|
||||
self.pad_mode = validator.check_string('pad_mode', pad_mode.upper(), ['VALID', 'SAME'], self.cls_name)
|
||||
self.pad_mode = validator.check_string(pad_mode.upper(), ['VALID', 'SAME'], 'pad_mode', self.cls_name)
|
||||
validator.check_integer("kernel_size", kernel_size, 1, Rel.GE, self.cls_name)
|
||||
validator.check_integer("stride", stride, 1, Rel.GE, self.cls_name)
|
||||
self.kernel_size = (1, kernel_size)
|
||||
|
|
|
@ -23,7 +23,7 @@ from mindspore.ops import functional as F
|
|||
from mindspore.common.parameter import Parameter
|
||||
from mindspore.common.initializer import initializer
|
||||
from mindspore.common.tensor import Tensor
|
||||
from mindspore._checkparam import Rel, check_int_positive, check_bool, twice, Validator
|
||||
from mindspore._checkparam import Validator, Rel, check_int_positive, twice
|
||||
import mindspore.context as context
|
||||
from .normalization import BatchNorm2d, BatchNorm1d
|
||||
from .activation import get_activation, ReLU, LeakyReLU
|
||||
|
@ -133,7 +133,7 @@ class Conv2dBnAct(Cell):
|
|||
has_bias=has_bias,
|
||||
weight_init=weight_init,
|
||||
bias_init=bias_init)
|
||||
self.has_bn = Validator.check_bool("has_bn", has_bn)
|
||||
self.has_bn = Validator.check_bool(has_bn, "has_bn")
|
||||
self.has_act = activation is not None
|
||||
self.after_fake = after_fake
|
||||
if has_bn:
|
||||
|
@ -201,7 +201,7 @@ class DenseBnAct(Cell):
|
|||
weight_init,
|
||||
bias_init,
|
||||
has_bias)
|
||||
self.has_bn = Validator.check_bool("has_bn", has_bn)
|
||||
self.has_bn = Validator.check_bool(has_bn, "has_bn")
|
||||
self.has_act = activation is not None
|
||||
self.after_fake = after_fake
|
||||
if has_bn:
|
||||
|
@ -511,7 +511,7 @@ class Conv2dBnFoldQuant(Cell):
|
|||
channel_axis = 0
|
||||
self.weight = Parameter(initializer(weight_init, weight_shape), name='weight')
|
||||
self.bias_add = P.BiasAdd()
|
||||
if check_bool(has_bias):
|
||||
if Validator.check_bool(has_bias):
|
||||
self.bias = Parameter(initializer(bias_init, [out_channels]), name='bias')
|
||||
else:
|
||||
self.bias = None
|
||||
|
@ -668,7 +668,7 @@ class Conv2dBnWithoutFoldQuant(Cell):
|
|||
self.quant_delay = quant_delay
|
||||
|
||||
self.bias_add = P.BiasAdd()
|
||||
if check_bool(has_bias):
|
||||
if Validator.check_bool(has_bias):
|
||||
self.bias = Parameter(initializer(bias_init, [out_channels]), name='bias')
|
||||
else:
|
||||
self.bias = None
|
||||
|
@ -799,7 +799,7 @@ class Conv2dQuant(Cell):
|
|||
self.weight = Parameter(initializer(weight_init, weight_shape), name='weight')
|
||||
|
||||
self.bias_add = P.BiasAdd()
|
||||
if check_bool(has_bias):
|
||||
if Validator.check_bool(has_bias):
|
||||
self.bias = Parameter(initializer(bias_init, [out_channels]), name='bias')
|
||||
else:
|
||||
self.bias = None
|
||||
|
@ -888,7 +888,7 @@ class DenseQuant(Cell):
|
|||
super(DenseQuant, self).__init__()
|
||||
self.in_channels = check_int_positive(in_channels)
|
||||
self.out_channels = check_int_positive(out_channels)
|
||||
self.has_bias = check_bool(has_bias)
|
||||
self.has_bias = Validator.check_bool(has_bias)
|
||||
|
||||
if isinstance(weight_init, Tensor):
|
||||
if weight_init.dim() != 2 or weight_init.shape[0] != out_channels or \
|
||||
|
|
|
@ -18,8 +18,7 @@ from mindspore.ops import _selected_ops
|
|||
from mindspore.common.parameter import Parameter
|
||||
from mindspore.common.tensor import Tensor
|
||||
import mindspore.common.dtype as mstype
|
||||
from mindspore._checkparam import check_bool
|
||||
from mindspore._checkparam import Validator as validator
|
||||
from mindspore._checkparam import Validator
|
||||
from .optimizer import Optimizer
|
||||
|
||||
_momentum_opt = C.MultitypeFuncGraph("momentum_opt")
|
||||
|
@ -126,12 +125,12 @@ class Momentum(Optimizer):
|
|||
"""
|
||||
def __init__(self, params, learning_rate, momentum, weight_decay=0.0, loss_scale=1.0, use_nesterov=False):
|
||||
super(Momentum, self).__init__(learning_rate, params, weight_decay, loss_scale)
|
||||
validator.check_value_type("momentum", momentum, [float], self.cls_name)
|
||||
Validator.check_value_type("momentum", momentum, [float], self.cls_name)
|
||||
if isinstance(momentum, float) and momentum < 0.0:
|
||||
raise ValueError("momentum should be at least 0.0, but got momentum {}".format(momentum))
|
||||
self.momentum = Parameter(Tensor(momentum, mstype.float32), name="momentum")
|
||||
self.params = self.parameters
|
||||
self.use_nesterov = check_bool(use_nesterov)
|
||||
self.use_nesterov = Validator.check_bool(use_nesterov)
|
||||
self.moments = self.params.clone(prefix="moments", init='zeros')
|
||||
self.hyper_map = C.HyperMap()
|
||||
self.opt = _selected_ops.ApplyMomentum(use_nesterov=self.use_nesterov)
|
||||
|
|
|
@ -15,7 +15,7 @@
|
|||
"""dense_variational"""
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore.common.tensor import Tensor
|
||||
from mindspore._checkparam import check_int_positive, check_bool
|
||||
from mindspore._checkparam import check_int_positive, Validator
|
||||
from ...cell import Cell
|
||||
from ...layer.activation import get_activation
|
||||
from .layer_distribution import NormalPrior, NormalPosterior
|
||||
|
@ -41,7 +41,7 @@ class _DenseVariational(Cell):
|
|||
super(_DenseVariational, self).__init__()
|
||||
self.in_channels = check_int_positive(in_channels)
|
||||
self.out_channels = check_int_positive(out_channels)
|
||||
self.has_bias = check_bool(has_bias)
|
||||
self.has_bias = Validator.check_bool(has_bias)
|
||||
|
||||
if isinstance(weight_prior_fn, Cell):
|
||||
self.weight_prior = weight_prior_fn
|
||||
|
|
|
@ -16,7 +16,7 @@
|
|||
from copy import deepcopy
|
||||
|
||||
import numpy as np
|
||||
from mindspore._checkparam import check_int_positive, check_bool
|
||||
from mindspore._checkparam import check_int_positive, Validator
|
||||
from mindspore.ops import composite as C
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore.train import Model
|
||||
|
@ -84,7 +84,7 @@ class UncertaintyEvaluation:
|
|||
self.epochs = check_int_positive(epochs)
|
||||
self.epi_uncer_model_path = epi_uncer_model_path
|
||||
self.ale_uncer_model_path = ale_uncer_model_path
|
||||
self.save_model = check_bool(save_model)
|
||||
self.save_model = Validator.check_bool(save_model)
|
||||
self.epi_uncer_model = None
|
||||
self.ale_uncer_model = None
|
||||
self.concat = P.Concat(axis=0)
|
||||
|
|
|
@ -216,7 +216,7 @@ class KLDivLossGrad(PrimitiveWithInfer):
|
|||
|
||||
@prim_attr_register
|
||||
def __init__(self, reduction='mean'):
|
||||
self.reduction = validator.check_string('reduction', reduction, ['none', 'mean', 'sum'], self.name)
|
||||
self.reduction = validator.check_string(reduction, ['none', 'mean', 'sum'], 'reduction', self.name)
|
||||
|
||||
def infer_shape(self, x_shape, y_shape, doutput_shape):
|
||||
validator.check('x_shape', x_shape, 'y_shape', y_shape, Rel.EQ, self.name)
|
||||
|
@ -233,7 +233,7 @@ class BinaryCrossEntropyGrad(PrimitiveWithInfer):
|
|||
|
||||
@prim_attr_register
|
||||
def __init__(self, reduction='mean'):
|
||||
self.reduction = validator.check_string('reduction', reduction, ['none', 'mean', 'sum'], self.name)
|
||||
self.reduction = validator.check_string(reduction, ['none', 'mean', 'sum'], 'reduction', self.name)
|
||||
|
||||
def infer_shape(self, x_shape, y_shape, doutput_shape, weight_shape):
|
||||
validator.check('x_shape', x_shape, 'y_shape', y_shape, Rel.EQ, self.name)
|
||||
|
@ -609,7 +609,7 @@ class _PoolGrad(PrimitiveWithInfer):
|
|||
|
||||
validator.check_value_type('ksize', ksize, [int, tuple], self.name)
|
||||
validator.check_value_type('strides', strides, [int, tuple], self.name)
|
||||
self.padding = validator.check_string('padding', padding.upper(), ['VALID', 'SAME'], self.name)
|
||||
self.padding = validator.check_string(padding.upper(), ['VALID', 'SAME'], 'padding', self.name)
|
||||
self.add_prim_attr("padding", self.padding)
|
||||
self.is_maxpoolgradwithargmax = (self.name == "MaxPoolGradWithArgmax")
|
||||
if not self.is_maxpoolgradwithargmax:
|
||||
|
@ -1457,7 +1457,7 @@ class MirrorPadGrad(PrimitiveWithInfer):
|
|||
@prim_attr_register
|
||||
def __init__(self, mode="REFLECT"):
|
||||
"""Initialize MirrorPad"""
|
||||
validator.check_string('mode', mode, ['REFLECT', 'SYMMETRIC'], self.name)
|
||||
validator.check_string(mode, ['REFLECT', 'SYMMETRIC'], 'mode', self.name)
|
||||
self.mode = mode
|
||||
|
||||
def __infer__(self, dout, paddings):
|
||||
|
@ -1570,7 +1570,7 @@ class BasicLSTMCellCStateGrad(PrimitiveWithInfer):
|
|||
@prim_attr_register
|
||||
def __init__(self, forget_bias, activation):
|
||||
self.forget_bias = validator.check_value_type("forget_bias", forget_bias, [float], self.name)
|
||||
self.activation = validator.check_string("activation", activation, ['tanh'], self.name)
|
||||
self.activation = validator.check_string(activation, ['tanh'], "activation", self.name)
|
||||
self.add_prim_attr("io_format", "ND")
|
||||
|
||||
def infer_shape(self, c_shape, dht_shape, dct_shape, it_shape, jt_shape, ft_shape, ot_shape, tanhct_shape):
|
||||
|
|
|
@ -67,7 +67,7 @@ class ExtractImagePatches(PrimitiveWithInfer):
|
|||
_check_tuple_or_list("ksize", ksizes, self.name)
|
||||
_check_tuple_or_list("stride", strides, self.name)
|
||||
_check_tuple_or_list("rate", rates, self.name)
|
||||
self.padding = validator.check_string('padding', padding.upper(), ['VALID', 'SAME'], self.name)
|
||||
self.padding = validator.check_string(padding.upper(), ['VALID', 'SAME'], 'padding', self.name)
|
||||
self.add_prim_attr("padding", self.padding)
|
||||
self.add_prim_attr("io_format", "NHWC")
|
||||
self.is_ge = context.get_context("enable_ge")
|
||||
|
@ -206,8 +206,8 @@ class Quant(PrimitiveWithInfer):
|
|||
self.scale = validator.check_value_type("scale", scale, [float], self.name)
|
||||
self.offset = validator.check_value_type("offset", offset, [float], self.name)
|
||||
self.sqrt_mode = validator.check_value_type("sqrt_mode", sqrt_mode, [bool], self.name)
|
||||
self.round_mode = validator.check_string("round_mode", round_mode,
|
||||
["Round", "Floor", "Ceil", "Trunc"], self.name)
|
||||
self.round_mode = validator.check_string(round_mode, ["Round", "Floor", "Ceil", "Trunc"],
|
||||
"round_mode", self.name)
|
||||
self.add_prim_attr("io_format", "ND")
|
||||
|
||||
def infer_shape(self, x_shape):
|
||||
|
|
|
@ -513,7 +513,7 @@ class Im2Col(PrimitiveWithInfer):
|
|||
self.dilation = _check_positive_int_or_tuple('dilation', dilation, self.name, allow_four=True, ret_four=True)
|
||||
self.add_prim_attr('dilation', self.dilation)
|
||||
validator.check_value_type('pad', pad, (int,), self.name)
|
||||
self.pad_mode = validator.check_string('pad_mode', pad_mode, ['valid', 'same', 'pad'], self.name)
|
||||
self.pad_mode = validator.check_string(pad_mode, ['valid', 'same', 'pad'], 'pad_mode', self.name)
|
||||
self.pad = validator.check_pad_value_by_mode(pad_mode, pad, self.name)
|
||||
if self.pad_mode == 'pad':
|
||||
validator.check_integer('pad', self.pad, 0, Rel.GE, self.name)
|
||||
|
|
|
@ -82,7 +82,7 @@ class CropAndResize(PrimitiveWithInfer):
|
|||
"""Initialize CropAndResize"""
|
||||
self.init_prim_io_names(inputs=['x', 'boxes', 'box_index', 'crop_size'], outputs=['y'])
|
||||
validator.check_value_type("method", method, [str], self.name)
|
||||
validator.check_string("method", method, ["bilinear", "nearest", "bilinear_v2"], self.name)
|
||||
validator.check_string(method, ["bilinear", "nearest", "bilinear_v2"], "method", self.name)
|
||||
self.method = method
|
||||
validator.check_value_type("extrapolation_value", extrapolation_value, [float], self.name)
|
||||
self.extrapolation_value = extrapolation_value
|
||||
|
|
|
@ -1484,7 +1484,7 @@ class HistogramFixedWidth(PrimitiveWithInfer):
|
|||
self.nbins = validator.check_value_type("nbins", nbins, [int], self.name)
|
||||
validator.check_integer("nbins", nbins, 1, Rel.GE, self.name)
|
||||
valid_values = ['int32', 'int64']
|
||||
self.dtype = validator.check_string("dtype", dtype, valid_values, self.name)
|
||||
self.dtype = validator.check_string(dtype, valid_values, "dtype", self.name)
|
||||
self.init_prim_io_names(inputs=['x', 'range'], outputs=['y'])
|
||||
|
||||
def infer_shape(self, x_shape, range_shape):
|
||||
|
|
|
@ -995,7 +995,7 @@ class Conv2D(PrimitiveWithInfer):
|
|||
else:
|
||||
validator.check_integer('pad size', len(pad), 4, Rel.EQ, self.name)
|
||||
self.padding = pad
|
||||
self.pad_mode = validator.check_string('pad_mode', pad_mode, ['valid', 'same', 'pad'], self.name)
|
||||
self.pad_mode = validator.check_string(pad_mode, ['valid', 'same', 'pad'], 'pad_mode', self.name)
|
||||
|
||||
if pad_mode != 'pad' and pad != (0, 0, 0, 0):
|
||||
raise ValueError(f"For '{self.name}', padding must be zero when pad_mode is '{pad_mode}'.")
|
||||
|
@ -1134,7 +1134,7 @@ class DepthwiseConv2dNative(PrimitiveWithInfer):
|
|||
else:
|
||||
validator.check_integer('pad size', len(pad), 4, Rel.EQ, self.name)
|
||||
self.padding = pad
|
||||
self.pad_mode = validator.check_string('pad_mode', pad_mode, ['valid', 'same', 'pad'], self.name)
|
||||
self.pad_mode = validator.check_string(pad_mode, ['valid', 'same', 'pad'], 'pad_mode', self.name)
|
||||
if pad_mode != 'pad' and pad != (0, 0, 0, 0):
|
||||
raise ValueError(f"For '{self.name}', padding must be zero when pad_mode is '{pad_mode}'.")
|
||||
if self.pad_mode == 'pad':
|
||||
|
@ -1216,7 +1216,7 @@ class _Pool(PrimitiveWithInfer):
|
|||
self.init_prim_io_names(inputs=['x'], outputs=['output'])
|
||||
validator.check_value_type('ksize', ksize, [int, tuple], self.name)
|
||||
validator.check_value_type('strides', strides, [int, tuple], self.name)
|
||||
self.padding = validator.check_string('padding', padding.upper(), ['VALID', 'SAME'], self.name)
|
||||
self.padding = validator.check_string(padding.upper(), ['VALID', 'SAME'], 'padding', self.name)
|
||||
self.add_prim_attr("padding", self.padding)
|
||||
self.is_maxpoolwithargmax = (self.name == "MaxPoolWithArgmax")
|
||||
if not self.is_maxpoolwithargmax:
|
||||
|
@ -1521,7 +1521,7 @@ class Conv2DBackpropInput(PrimitiveWithInfer):
|
|||
else:
|
||||
validator.check_integer('pad size', len(pad), 4, Rel.EQ, self.name)
|
||||
self.padding = pad
|
||||
self.pad_mode = validator.check_string('pad_mode', pad_mode, ['valid', 'same', 'pad'], self.name)
|
||||
self.pad_mode = validator.check_string(pad_mode, ['valid', 'same', 'pad'], 'pad_mode', self.name)
|
||||
if pad_mode != 'pad' and pad != (0, 0, 0, 0):
|
||||
raise ValueError(f"For '{self.name}', padding must be zero when pad_mode is '{pad_mode}'.")
|
||||
if self.pad_mode == 'pad':
|
||||
|
@ -1942,8 +1942,8 @@ class DataFormatDimMap(PrimitiveWithInfer):
|
|||
@prim_attr_register
|
||||
def __init__(self, src_format='NHWC', dst_format='NCHW'):
|
||||
valid_values = ['NHWC', 'NCHW']
|
||||
self.src_format = validator.check_string("src_format", src_format, valid_values, self.name)
|
||||
self.dst_format = validator.check_string("dst_format", dst_format, valid_values, self.name)
|
||||
self.src_format = validator.check_string(src_format, valid_values, "src_format", self.name)
|
||||
self.dst_format = validator.check_string(dst_format, valid_values, "dst_format", self.name)
|
||||
self.init_prim_io_names(inputs=['input_x'], outputs=['output'])
|
||||
|
||||
def infer_shape(self, x_shape):
|
||||
|
@ -2961,7 +2961,7 @@ class MirrorPad(PrimitiveWithInfer):
|
|||
@prim_attr_register
|
||||
def __init__(self, mode='REFLECT'):
|
||||
"""Initialize Pad"""
|
||||
validator.check_string('mode', mode, ['REFLECT', 'SYMMETRIC'], self.name)
|
||||
validator.check_string(mode, ['REFLECT', 'SYMMETRIC'], 'mode', self.name)
|
||||
self.mode = mode
|
||||
self.set_const_input_indexes([1])
|
||||
|
||||
|
@ -3651,7 +3651,7 @@ class KLDivLoss(PrimitiveWithInfer):
|
|||
|
||||
@prim_attr_register
|
||||
def __init__(self, reduction='mean'):
|
||||
self.reduction = validator.check_string('reduction', reduction, ['none', 'mean', 'sum'], self.name)
|
||||
self.reduction = validator.check_string(reduction, ['none', 'mean', 'sum'], 'reduction', self.name)
|
||||
|
||||
def infer_shape(self, x_shape, y_shape):
|
||||
validator.check('x_shape', x_shape, 'y_shape', y_shape, Rel.EQ, self.name)
|
||||
|
@ -3727,7 +3727,7 @@ class BinaryCrossEntropy(PrimitiveWithInfer):
|
|||
|
||||
@prim_attr_register
|
||||
def __init__(self, reduction='mean'):
|
||||
self.reduction = validator.check_string('reduction', reduction, ['none', 'mean', 'sum'], self.name)
|
||||
self.reduction = validator.check_string(reduction, ['none', 'mean', 'sum'], 'reduction', self.name)
|
||||
|
||||
def infer_shape(self, x_shape, y_shape, weight_shape):
|
||||
validator.check('x_shape', x_shape, 'y_shape', y_shape, Rel.EQ, self.name)
|
||||
|
@ -5487,7 +5487,7 @@ class BasicLSTMCell(PrimitiveWithInfer):
|
|||
self.keep_prob = validator.check_number_range("keep_prob", keep_prob, 0.0, 1.0, Rel.INC_BOTH, self.name)
|
||||
self.forget_bias = validator.check_value_type("forget_bias", forget_bias, [float], self.name)
|
||||
self.state_is_tuple = validator.check_value_type("state_is_tuple", state_is_tuple, [bool], self.name)
|
||||
self.activation = validator.check_string("activation", activation, ['tanh'], self.name)
|
||||
self.activation = validator.check_string(activation, ['tanh'], "activation", self.name)
|
||||
self.add_prim_attr("io_format", "ND")
|
||||
|
||||
def infer_shape(self, x_shape, h_shape, c_shape, w_shape, b_shape):
|
||||
|
@ -5605,9 +5605,9 @@ class DynamicRNN(PrimitiveWithInfer):
|
|||
self.use_peephole = validator.check_value_type("use_peephole", use_peephole, [bool], self.name)
|
||||
self.time_major = validator.check_value_type("time_major", time_major, [bool], self.name)
|
||||
self.is_training = validator.check_value_type("is_training", is_training, [bool], self.name)
|
||||
self.cell_type = validator.check_string("cell_type", cell_type, ['LSTM'], self.name)
|
||||
self.direction = validator.check_string("direction", direction, ['UNIDIRECTIONAL'], self.name)
|
||||
self.activation = validator.check_string("activation", activation, ['tanh'], self.name)
|
||||
self.cell_type = validator.check_string(cell_type, ['LSTM'], "cell_type", self.name)
|
||||
self.direction = validator.check_string(direction, ['UNIDIRECTIONAL'], "direction", self.name)
|
||||
self.activation = validator.check_string(activation, ['tanh'], "activation", self.name)
|
||||
self.add_prim_attr("io_format", "ND")
|
||||
|
||||
def infer_shape(self, x_shape, w_shape, b_shape, seq_shape, h_shape, c_shape):
|
||||
|
@ -5720,7 +5720,7 @@ class LRN(PrimitiveWithInfer):
|
|||
validator.check_value_type("alpha", alpha, [float], self.name)
|
||||
validator.check_value_type("beta", beta, [float], self.name)
|
||||
validator.check_value_type("norm_region", norm_region, [str], self.name)
|
||||
validator.check_string('norm_region', norm_region, ['ACROSS_CHANNELS'], self.name)
|
||||
validator.check_string(norm_region, ['ACROSS_CHANNELS'], 'norm_region', self.name)
|
||||
validator.check_integer("depth_radius", depth_radius, 0, Rel.GE, self.name)
|
||||
|
||||
def infer_dtype(self, x_dtype):
|
||||
|
|
|
@ -21,7 +21,7 @@ import time
|
|||
import threading
|
||||
import mindspore.context as context
|
||||
from mindspore import log as logger
|
||||
from mindspore._checkparam import check_bool, check_int_non_negative
|
||||
from mindspore._checkparam import Validator, check_int_non_negative
|
||||
from mindspore.train._utils import _make_directory
|
||||
from mindspore.train.serialization import save_checkpoint, _save_graph
|
||||
from mindspore.parallel._ps_context import _is_role_pserver, _get_ps_mode_rank
|
||||
|
@ -132,8 +132,8 @@ class CheckpointConfig:
|
|||
if not self._keep_checkpoint_per_n_minutes or self._keep_checkpoint_per_n_minutes == 0:
|
||||
self._keep_checkpoint_max = 1
|
||||
|
||||
self._integrated_save = check_bool(integrated_save)
|
||||
self._async_save = check_bool(async_save)
|
||||
self._integrated_save = Validator.check_bool(integrated_save)
|
||||
self._async_save = Validator.check_bool(async_save)
|
||||
|
||||
@property
|
||||
def save_checkpoint_steps(self):
|
||||
|
|
|
@ -16,7 +16,7 @@
|
|||
import math
|
||||
import os
|
||||
|
||||
from mindspore._checkparam import check_bool, check_int
|
||||
from mindspore._checkparam import Validator, check_int
|
||||
from .. import context, nn
|
||||
from ._utils import _exec_datagraph, _get_types_and_shapes, _construct_tensor_list
|
||||
from ..nn.wrap import GetNextSingleOp
|
||||
|
@ -123,7 +123,7 @@ class DatasetHelper:
|
|||
"""
|
||||
|
||||
def __init__(self, dataset, dataset_sink_mode=True, sink_size=-1, epoch_num=1):
|
||||
check_bool(dataset_sink_mode)
|
||||
dataset_sink_mode = Validator.check_bool(dataset_sink_mode)
|
||||
check_int(sink_size)
|
||||
if sink_size < -1 or sink_size == 0:
|
||||
raise ValueError("The sink_size must be -1 or positive, but got sink_size {}.".format(sink_size))
|
||||
|
|
|
@ -22,7 +22,7 @@ import numpy as np
|
|||
from mindspore import log as logger
|
||||
from ..common.tensor import Tensor
|
||||
from ..nn.metrics import get_metrics
|
||||
from .._checkparam import check_input_data, check_output_data, check_int_positive, check_bool, check_int
|
||||
from .._checkparam import check_input_data, check_output_data, check_int_positive, Validator, check_int
|
||||
from .callback import _InternalCallbackParam, RunContext, _CallbackManager
|
||||
from .. import context
|
||||
from ..parallel._utils import _get_parallel_mode, _get_device_num, _get_global_rank, \
|
||||
|
@ -548,7 +548,7 @@ class Model:
|
|||
>>> model = Model(net, loss_fn=loss, optimizer=optim, metrics=None, loss_scale_manager=loss_scale_manager)
|
||||
>>> model.train(2, dataset)
|
||||
"""
|
||||
check_bool(dataset_sink_mode)
|
||||
dataset_sink_mode = Validator.check_bool(dataset_sink_mode)
|
||||
if sink_size == -1:
|
||||
sink_size = train_dataset.get_dataset_size()
|
||||
check_int(sink_size)
|
||||
|
@ -664,7 +664,7 @@ class Model:
|
|||
>>> model = Model(net, loss_fn=loss, optimizer=None, metrics={'acc'})
|
||||
>>> model.eval(dataset)
|
||||
"""
|
||||
check_bool(dataset_sink_mode)
|
||||
dataset_sink_mode = Validator.check_bool(dataset_sink_mode)
|
||||
_device_number_check(self._parallel_mode, self._device_number)
|
||||
if not self._metric_fns:
|
||||
raise ValueError("metric fn can not be None or empty.")
|
||||
|
|
|
@ -22,8 +22,7 @@ import mindspore.context as context
|
|||
|
||||
from ... import log as logger
|
||||
from ... import nn, ops
|
||||
from ..._checkparam import Validator
|
||||
from ..._checkparam import Rel
|
||||
from ..._checkparam import Validator, Rel
|
||||
from ...common import Tensor
|
||||
from ...common import dtype as mstype
|
||||
from ...common.api import _executor
|
||||
|
@ -92,16 +91,16 @@ class ConvertToQuantNetwork:
|
|||
self.network = Validator.check_isinstance('network', kwargs["network"], (nn.Cell,))
|
||||
self.weight_qdelay = Validator.check_integer("quant delay", kwargs["quant_delay"][0], 0, Rel.GE)
|
||||
self.act_qdelay = Validator.check_integer("quant delay", kwargs["quant_delay"][-1], 0, Rel.GE)
|
||||
self.bn_fold = Validator.check_bool("bn fold", kwargs["bn_fold"])
|
||||
self.bn_fold = Validator.check_bool(kwargs["bn_fold"], "bn fold")
|
||||
self.freeze_bn = Validator.check_integer("freeze bn", kwargs["freeze_bn"], 0, Rel.GE)
|
||||
self.weight_bits = Validator.check_integer("weights bit", kwargs["num_bits"][0], 0, Rel.GE)
|
||||
self.act_bits = Validator.check_integer("activations bit", kwargs["num_bits"][-1], 0, Rel.GE)
|
||||
self.weight_channel = Validator.check_bool("per channel", kwargs["per_channel"][0])
|
||||
self.act_channel = Validator.check_bool("per channel", kwargs["per_channel"][-1])
|
||||
self.weight_symmetric = Validator.check_bool("symmetric", kwargs["symmetric"][0])
|
||||
self.act_symmetric = Validator.check_bool("symmetric", kwargs["symmetric"][-1])
|
||||
self.weight_range = Validator.check_bool("narrow range", kwargs["narrow_range"][0])
|
||||
self.act_range = Validator.check_bool("narrow range", kwargs["narrow_range"][-1])
|
||||
self.weight_channel = Validator.check_bool(kwargs["per_channel"][0], "per channel")
|
||||
self.act_channel = Validator.check_bool(kwargs["per_channel"][-1], "per channel")
|
||||
self.weight_symmetric = Validator.check_bool(kwargs["symmetric"][0], "symmetric")
|
||||
self.act_symmetric = Validator.check_bool(kwargs["symmetric"][-1], "symmetric")
|
||||
self.weight_range = Validator.check_bool(kwargs["narrow_range"][0], "narrow range")
|
||||
self.act_range = Validator.check_bool(kwargs["narrow_range"][-1], "narrow range")
|
||||
self._convert_method_map = {quant.Conv2dBnAct: self._convert_conv,
|
||||
quant.DenseBnAct: self._convert_dense}
|
||||
|
||||
|
|
|
@ -15,7 +15,7 @@
|
|||
"""Dataset help for minddata dataset"""
|
||||
import math
|
||||
import os
|
||||
from mindspore._checkparam import check_bool, check_int
|
||||
from mindspore._checkparam import Validator, check_int
|
||||
from mindspore import context
|
||||
from mindspore.train._utils import _exec_datagraph, _get_types_and_shapes
|
||||
from mindspore.nn.wrap import GetNextSingleOp
|
||||
|
@ -61,7 +61,7 @@ class DatasetHelper:
|
|||
"""
|
||||
|
||||
def __init__(self, dataset, dataset_sink_mode=True, sink_size=-1, epoch_num=1, iter_first_order=1):
|
||||
check_bool(dataset_sink_mode)
|
||||
dataset_sink_mode = Validator.check_bool(dataset_sink_mode)
|
||||
check_int(sink_size)
|
||||
if sink_size < -1 or sink_size == 0:
|
||||
raise ValueError("The sink_size must be -1 or positive, but got sink_size {}.".format(sink_size))
|
||||
|
|
|
@ -18,8 +18,7 @@ from mindspore.common.initializer import initializer
|
|||
from mindspore.common.parameter import Parameter, ParameterTuple
|
||||
from mindspore.common.tensor import Tensor
|
||||
import mindspore.common.dtype as mstype
|
||||
from mindspore._checkparam import check_bool
|
||||
from mindspore._checkparam import Validator as validator
|
||||
from mindspore._checkparam import Validator
|
||||
from mindspore.nn.optim.optimizer import Optimizer
|
||||
from mindspore.parallel._utils import _get_device_num, _get_gradients_mean
|
||||
from src.grad_reducer_thor import DistributedGradReducerThor
|
||||
|
@ -53,12 +52,12 @@ class THOR_GPU(Optimizer):
|
|||
def __init__(self, params, learning_rate, momentum, matrix_A, matrix_G, A_inv_max, G_inv_max,
|
||||
weight_decay=0.0, loss_scale=1.0, use_nesterov=False, decay_filter=lambda x: x.name not in []):
|
||||
super(THOR_GPU, self).__init__(learning_rate, params, weight_decay, loss_scale)
|
||||
validator.check_value_type("momentum", momentum, [float], self.cls_name)
|
||||
Validator.check_value_type("momentum", momentum, [float], self.cls_name)
|
||||
if isinstance(momentum, float) and momentum < 0.0:
|
||||
raise ValueError("momentum should be at least 0.0, but got momentum {}".format(momentum))
|
||||
self.momentum = Parameter(Tensor(momentum, mstype.float32), name="momentum")
|
||||
self.params = self.parameters
|
||||
self.use_nesterov = check_bool(use_nesterov)
|
||||
self.use_nesterov = Validator.check_bool(use_nesterov)
|
||||
self.moments = self.params.clone(prefix="moments", init='zeros')
|
||||
self.hyper_map = C.HyperMap()
|
||||
self.opt = P.ApplyMomentum(use_nesterov=self.use_nesterov)
|
||||
|
|
|
@ -16,7 +16,7 @@
|
|||
import numpy as np
|
||||
|
||||
import mindspore.common.dtype as mstype
|
||||
from mindspore._checkparam import check_bool, twice, check_int_positive
|
||||
from mindspore._checkparam import Validator, twice, check_int_positive
|
||||
from mindspore._extends import cell_attr_register
|
||||
from mindspore.common.initializer import initializer
|
||||
from mindspore.common.parameter import Parameter
|
||||
|
@ -111,7 +111,7 @@ class _Conv(Cell):
|
|||
self.weight = Parameter(initializer(
|
||||
weight_init, [out_channels, in_channels // group, *kernel_size]), name='weight')
|
||||
|
||||
if check_bool(has_bias):
|
||||
if Validator.check_bool(has_bias):
|
||||
self.bias = Parameter(_initializer(
|
||||
bias_init, [out_channels]), name='bias')
|
||||
else:
|
||||
|
@ -294,7 +294,7 @@ class Dense_Thor_GPU(Cell):
|
|||
super(Dense_Thor_GPU, self).__init__()
|
||||
self.in_channels = check_int_positive(in_channels)
|
||||
self.out_channels = check_int_positive(out_channels)
|
||||
self.has_bias = check_bool(has_bias)
|
||||
self.has_bias = Validator.check_bool(has_bias)
|
||||
self.thor = True
|
||||
if isinstance(weight_init, Tensor):
|
||||
if weight_init.dim() != 2 or weight_init.shape[0] != out_channels or \
|
||||
|
@ -643,7 +643,7 @@ class Dense_Thor(Cell):
|
|||
super(Dense_Thor, self).__init__()
|
||||
self.in_channels = check_int_positive(in_channels)
|
||||
self.out_channels = check_int_positive(out_channels)
|
||||
self.has_bias = check_bool(has_bias)
|
||||
self.has_bias = Validator.check_bool(has_bias)
|
||||
self.thor = True
|
||||
self.batch_size = batch_size
|
||||
if isinstance(weight_init, Tensor):
|
||||
|
|
|
@ -19,7 +19,7 @@ from mindspore.ops import functional as F
|
|||
from mindspore._extends import cell_attr_register
|
||||
from mindspore import Tensor, Parameter
|
||||
from mindspore.common.initializer import initializer
|
||||
from mindspore._checkparam import check_int_positive, check_bool
|
||||
from mindspore._checkparam import check_int_positive, Validator
|
||||
from mindspore.nn.layer.activation import get_activation
|
||||
|
||||
|
||||
|
@ -74,7 +74,7 @@ class GNNFeatureTransform(nn.Cell):
|
|||
super(GNNFeatureTransform, self).__init__()
|
||||
self.in_channels = check_int_positive(in_channels)
|
||||
self.out_channels = check_int_positive(out_channels)
|
||||
self.has_bias = check_bool(has_bias)
|
||||
self.has_bias = Validator.check_bool(has_bias)
|
||||
|
||||
if isinstance(weight_init, Tensor):
|
||||
if weight_init.dim() != 2 or weight_init.shape[0] != out_channels or \
|
||||
|
@ -284,7 +284,7 @@ class AttentionHead(nn.Cell):
|
|||
self.matmul = P.MatMul()
|
||||
self.bias_add = P.BiasAdd()
|
||||
self.bias = Parameter(initializer('zeros', self.out_channel), name='bias')
|
||||
self.residual = check_bool(residual)
|
||||
self.residual = Validator.check_bool(residual)
|
||||
if self.residual:
|
||||
if in_channel != out_channel:
|
||||
self.residual_transform_flag = True
|
||||
|
@ -458,7 +458,7 @@ class GAT(nn.Cell):
|
|||
self.attn_drop = attn_drop
|
||||
self.ftr_drop = ftr_drop
|
||||
self.activation = activation
|
||||
self.residual = check_bool(residual)
|
||||
self.residual = Validator.check_bool(residual)
|
||||
self.layers = []
|
||||
# first layer
|
||||
self.layers.append(AttentionAggregator(
|
||||
|
|
|
@ -16,7 +16,7 @@
|
|||
import os
|
||||
|
||||
from mindspore import context
|
||||
from mindspore._checkparam import check_bool, check_int
|
||||
from mindspore._checkparam import Validator, check_int
|
||||
from mindspore.parallel._utils import _get_device_num, _need_to_full, _to_full_shapes
|
||||
from mindspore.train._utils import _exec_datagraph, _get_types_and_shapes
|
||||
|
||||
|
@ -58,7 +58,7 @@ class DatasetHelper:
|
|||
"""
|
||||
|
||||
def __init__(self, dataset, dataset_sink_mode=True, sink_size=-1, epoch_num=1, iter_first_order=0):
|
||||
check_bool(dataset_sink_mode)
|
||||
dataset_sink_mode = Validator.check_bool(dataset_sink_mode)
|
||||
check_int(sink_size)
|
||||
if sink_size < -1 or sink_size == 0:
|
||||
raise ValueError("The sink_size must be -1 or positive, but got sink_size {}.".format(sink_size))
|
||||
|
|
|
@ -22,7 +22,7 @@ from mindspore._c_expression import init_exec_dataset
|
|||
from mindspore import context
|
||||
from mindspore import log as logger
|
||||
from mindspore import nn
|
||||
from mindspore._checkparam import check_input_data, check_output_data, check_int_positive, check_bool, check_int
|
||||
from mindspore._checkparam import check_input_data, check_output_data, check_int_positive, Validator, check_int
|
||||
from mindspore.common import dtype as mstype
|
||||
from mindspore.common.dtype import pytype_to_dtype
|
||||
from mindspore.common.tensor import Tensor
|
||||
|
@ -603,7 +603,7 @@ class Model:
|
|||
>>> model = Model(net, loss_fn=loss, optimizer=optim, metrics=None, loss_scale_manager=loss_scale_manager)
|
||||
>>> model.train(2, dataset)
|
||||
"""
|
||||
check_bool(dataset_sink_mode)
|
||||
dataset_sink_mode = Validator.check_bool(dataset_sink_mode)
|
||||
check_int(sink_size)
|
||||
if sink_size < -1 or sink_size == 0:
|
||||
raise ValueError("The sink_size must be -1 or positive, but got sink_size {}.".format(sink_size))
|
||||
|
@ -718,7 +718,7 @@ class Model:
|
|||
>>> model = Model(net, loss_fn=loss, optimizer=None, metrics={'acc'})
|
||||
>>> model.eval(dataset)
|
||||
"""
|
||||
check_bool(dataset_sink_mode)
|
||||
dataset_sink_mode = Validator.check_bool(dataset_sink_mode)
|
||||
_device_number_check(self._parallel_mode, self._device_number)
|
||||
if not self._metric_fns:
|
||||
raise ValueError("metric fn can not be None or empty.")
|
||||
|
|
|
@ -15,7 +15,7 @@
|
|||
"""thor_layer"""
|
||||
import numpy as np
|
||||
import mindspore.common.dtype as mstype
|
||||
from mindspore._checkparam import check_bool, check_int_positive
|
||||
from mindspore._checkparam import Validator, check_int_positive
|
||||
from mindspore.common.initializer import TruncatedNormal, initializer
|
||||
from mindspore.common.parameter import Parameter
|
||||
from mindspore.common.tensor import Tensor
|
||||
|
@ -162,7 +162,7 @@ class Dense_Thor(Cell):
|
|||
super(Dense_Thor, self).__init__()
|
||||
self.in_channels = check_int_positive(in_channels)
|
||||
self.out_channels = check_int_positive(out_channels)
|
||||
self.has_bias = check_bool(has_bias)
|
||||
self.has_bias = Validator.check_bool(has_bias)
|
||||
self.thor = True
|
||||
if isinstance(weight_init, Tensor):
|
||||
if weight_init.dim() != 2 or weight_init.shape()[0] != out_channels or \
|
||||
|
|
|
@ -15,7 +15,7 @@
|
|||
"""Aggregator."""
|
||||
import mindspore.nn as nn
|
||||
from mindspore import Tensor, Parameter
|
||||
from mindspore._checkparam import check_int_positive, check_bool
|
||||
from mindspore._checkparam import check_int_positive, Validator
|
||||
from mindspore._extends import cell_attr_register
|
||||
from mindspore.common.initializer import initializer
|
||||
from mindspore.nn.layer.activation import get_activation
|
||||
|
@ -75,7 +75,7 @@ class GNNFeatureTransform(nn.Cell):
|
|||
super(GNNFeatureTransform, self).__init__()
|
||||
self.in_channels = check_int_positive(in_channels)
|
||||
self.out_channels = check_int_positive(out_channels)
|
||||
self.has_bias = check_bool(has_bias)
|
||||
self.has_bias = Validator.check_bool(has_bias)
|
||||
|
||||
if isinstance(weight_init, Tensor):
|
||||
if weight_init.dim() != 2 or weight_init.shape[0] != out_channels or \
|
||||
|
@ -284,7 +284,7 @@ class AttentionHead(nn.Cell):
|
|||
self.batch_matmul = P.BatchMatMul()
|
||||
self.bias_add = P.BiasAdd()
|
||||
self.bias = Parameter(initializer('zeros', self.out_channel), name='bias')
|
||||
self.residual = check_bool(residual)
|
||||
self.residual = Validator.check_bool(residual)
|
||||
if self.residual:
|
||||
if in_channel != out_channel:
|
||||
self.residual_transform_flag = True
|
||||
|
|
|
@ -14,7 +14,7 @@
|
|||
# ============================================================================
|
||||
"""Graph Attention Networks."""
|
||||
import mindspore.nn as nn
|
||||
from mindspore._checkparam import check_bool, check_int_positive
|
||||
from mindspore._checkparam import Validator, check_int_positive
|
||||
|
||||
from aggregator import AttentionAggregator
|
||||
|
||||
|
@ -79,7 +79,7 @@ class GAT(nn.Cell):
|
|||
self.attn_drop = attn_drop
|
||||
self.ftr_drop = ftr_drop
|
||||
self.activation = activation
|
||||
self.residual = check_bool(residual)
|
||||
self.residual = Validator.check_bool(residual)
|
||||
self.layers = []
|
||||
# first layer
|
||||
self.layers.append(AttentionAggregator(
|
||||
|
|
|
@ -13,7 +13,7 @@
|
|||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Dataset help for minddata dataset"""
|
||||
from mindspore._checkparam import check_bool
|
||||
from mindspore._checkparam import Validator
|
||||
from mindspore.parallel._utils import _get_device_num, _get_parallel_mode, _to_full_shapes
|
||||
from mindspore.train._utils import _exec_datagraph, _get_types_and_shapes
|
||||
from mindspore.context import ParallelMode
|
||||
|
@ -50,7 +50,7 @@ class DatasetHelper:
|
|||
"""
|
||||
|
||||
def __init__(self, dataset, dataset_sink_mode=True, iter_first_order=0):
|
||||
check_bool(dataset_sink_mode)
|
||||
Validator.check_bool(dataset_sink_mode)
|
||||
self.iter = _DatasetIterMSLoopSink(dataset, iter_first_order)
|
||||
|
||||
def __iter__(self):
|
||||
|
|
|
@ -19,7 +19,7 @@ from mindspore import context
|
|||
from mindspore import log as logger
|
||||
from mindspore import nn
|
||||
from mindspore._c_expression import init_exec_dataset
|
||||
from mindspore._checkparam import check_input_data, check_output_data, check_int_positive, check_bool
|
||||
from mindspore._checkparam import check_input_data, check_output_data, check_int_positive, Validator
|
||||
from mindspore.common import dtype as mstype
|
||||
from mindspore.common.dtype import pytype_to_dtype
|
||||
from mindspore.common.tensor import Tensor
|
||||
|
@ -575,7 +575,7 @@ class Model:
|
|||
repeat_count = train_dataset.get_repeat_count()
|
||||
if epoch != repeat_count and dataset_sink_mode is True:
|
||||
logger.warning(f"The epoch_size {epoch} is not the same with dataset repeat_count {repeat_count}")
|
||||
check_bool(dataset_sink_mode)
|
||||
dataset_sink_mode = Validator.check_bool(dataset_sink_mode)
|
||||
_device_number_check(self._parallel_mode, self._device_number)
|
||||
_parameter_broadcast_check(self._parallel_mode, self._parameter_broadcast)
|
||||
|
||||
|
@ -682,7 +682,7 @@ class Model:
|
|||
>>> model = Model(net, loss_fn=loss, optimizer=None, metrics={'acc'})
|
||||
>>> model.eval(dataset)
|
||||
"""
|
||||
check_bool(dataset_sink_mode)
|
||||
dataset_sink_mode = Validator.check_bool(dataset_sink_mode)
|
||||
_device_number_check(self._parallel_mode, self._device_number)
|
||||
if not self._metric_fns:
|
||||
raise ValueError("metric fn can not be None or empty.")
|
||||
|
|
|
@ -16,7 +16,7 @@
|
|||
import numpy as np
|
||||
import mindspore as ms
|
||||
import mindspore.common.dtype as mstype
|
||||
from mindspore._checkparam import check_bool, twice, check_int_positive
|
||||
from mindspore._checkparam import Validator, twice, check_int_positive
|
||||
from mindspore._extends import cell_attr_register
|
||||
from mindspore.common.initializer import initializer
|
||||
from mindspore.common.parameter import Parameter
|
||||
|
@ -87,7 +87,7 @@ class _Conv(Cell):
|
|||
self.weight = Parameter(initializer(
|
||||
weight_init, [out_channels, in_channels // group, *kernel_size]), name='weight')
|
||||
|
||||
if check_bool(has_bias):
|
||||
if Validator.check_bool(has_bias):
|
||||
self.bias = Parameter(_initializer(
|
||||
bias_init, [out_channels]), name='bias')
|
||||
else:
|
||||
|
@ -339,7 +339,7 @@ class Dense_Thor(Cell):
|
|||
super(Dense_Thor, self).__init__()
|
||||
self.in_channels = check_int_positive(in_channels)
|
||||
self.out_channels = check_int_positive(out_channels)
|
||||
self.has_bias = check_bool(has_bias)
|
||||
self.has_bias = Validator.check_bool(has_bias)
|
||||
self.thor = True
|
||||
if isinstance(weight_init, Tensor):
|
||||
if weight_init.dim() != 2 or weight_init.shape[0] != out_channels or \
|
||||
|
|
|
@ -16,7 +16,7 @@
|
|||
import pytest
|
||||
|
||||
from mindspore._checkparam import check_int, check_int_positive, \
|
||||
check_input_format, check_bool, twice
|
||||
check_input_format, Validator, twice
|
||||
|
||||
kernel_size = 5
|
||||
kernel_size1 = twice(kernel_size)
|
||||
|
@ -66,26 +66,26 @@ def test_check_int_5():
|
|||
|
||||
|
||||
def test_check_bool_1():
|
||||
assert check_bool(True)
|
||||
assert Validator.check_bool(True)
|
||||
|
||||
|
||||
def test_check_bool_2():
|
||||
assert check_bool(False) is not True
|
||||
assert Validator.check_bool(False) is not True
|
||||
|
||||
|
||||
def test_check_bool_3():
|
||||
with pytest.raises(TypeError):
|
||||
check_bool("str")
|
||||
Validator.check_bool("str")
|
||||
|
||||
|
||||
def test_check_bool_4():
|
||||
with pytest.raises(TypeError):
|
||||
check_bool(1)
|
||||
Validator.check_bool(1)
|
||||
|
||||
|
||||
def test_check_bool_5():
|
||||
with pytest.raises(TypeError):
|
||||
check_bool(3.5)
|
||||
Validator.check_bool(3.5)
|
||||
|
||||
|
||||
def test_twice_1():
|
||||
|
|
|
@ -16,7 +16,7 @@
|
|||
import pytest
|
||||
|
||||
from mindspore._checkparam import check_int, check_int_positive, \
|
||||
check_bool, check_input_format, _expand_tuple
|
||||
Validator, check_input_format, _expand_tuple
|
||||
|
||||
once = _expand_tuple(1)
|
||||
twice = _expand_tuple(2)
|
||||
|
@ -60,26 +60,26 @@ def test_check_int_4():
|
|||
|
||||
|
||||
def test_check_bool_1():
|
||||
assert check_bool(True)
|
||||
assert Validator.check_bool(True)
|
||||
|
||||
|
||||
def test_check_bool_2():
|
||||
assert check_bool(False) is not True
|
||||
assert Validator.check_bool(False) is not True
|
||||
|
||||
|
||||
def test_check_bool_3():
|
||||
with pytest.raises(TypeError):
|
||||
check_bool("str")
|
||||
Validator.check_bool("str")
|
||||
|
||||
|
||||
def test_check_bool_4():
|
||||
with pytest.raises(TypeError):
|
||||
check_bool(1)
|
||||
Validator.check_bool(1)
|
||||
|
||||
|
||||
def test_check_bool_5():
|
||||
with pytest.raises(TypeError):
|
||||
check_bool(3.5)
|
||||
Validator.check_bool(3.5)
|
||||
|
||||
|
||||
def test_twice_1():
|
||||
|
|
Loading…
Reference in New Issue