[ME] delete check_bool and replace with Validate.check_bool

This commit is contained in:
chenzomi 2020-10-09 12:26:24 +08:00
parent 6c9b6d491d
commit d4e8e94981
33 changed files with 129 additions and 152 deletions

View File

@ -26,10 +26,6 @@ from mindspore import log as logger
from mindspore.common import dtype as mstype
# Named string regular expression
_name_re = r"^\w+[0-9a-zA-Z\_\.]*$"
class Rel(Enum):
"""Numerical relationship between variables, logical relationship enumeration definition of range."""
# scalar compare
@ -114,7 +110,7 @@ class Validator:
@staticmethod
def check_integer(arg_name, arg_value, value, rel, prim_name=None):
"""Integer value judgment."""
"""Check argument is integer"""
rel_fn = Rel.get_fns(rel)
type_mismatch = not isinstance(arg_value, int) or isinstance(arg_value, bool)
excp_cls = TypeError if type_mismatch else ValueError
@ -125,6 +121,7 @@ class Validator:
f' with type `{type(arg_value).__name__}`.')
return arg_value
@staticmethod
def check_number(arg_name, arg_value, value, rel, prim_name):
"""Number value judgment."""
@ -142,10 +139,11 @@ class Validator:
return arg_value
@staticmethod
def check_bool(arg_name, arg_value):
"""Check arg isinstance of bool"""
def check_bool(arg_value, arg_name=None):
"""Check argument is instance of bool"""
if not isinstance(arg_value, bool):
raise ValueError(f'The `{arg_name}` should be isinstance of bool, but got {arg_value}.')
arg_name = arg_name if arg_name else "Parameter"
raise TypeError(f'`{arg_name}` should be isinstance of bool, but got `{arg_value}`.')
return arg_value
@staticmethod
@ -170,15 +168,14 @@ class Validator:
return arg_value
@staticmethod
def check_string(arg_name, arg_value, valid_values, prim_name):
def check_string(arg_value, valid_values, arg_name=None, prim_name=None):
"""Checks whether a string is in some value list"""
if isinstance(arg_value, str) and arg_value in valid_values:
return arg_value
if len(valid_values) == 1:
raise ValueError(f'For \'{prim_name}\' the `{arg_name}` should be str and must be {valid_values[0]},'
f' but got {arg_value}.')
raise ValueError(f'For \'{prim_name}\' the `{arg_name}` should be str and must be one of {valid_values},'
f' but got {arg_value}.')
arg_name = arg_name if arg_name else "Parameter"
msg_prefix = f'For \'{prim_name}\' the' if prim_name else "The"
raise ValueError(f'{msg_prefix} `{arg_name}` should be str and must be in `{valid_values}`,'
f' but got `{arg_value}`.')
@staticmethod
def check_pad_value_by_mode(pad_mode, padding, prim_name):
@ -404,24 +401,6 @@ def check_int_zero_one(input_param):
raise ValueError("The data must be 0 or 1.")
def check_bool(input_param):
"""Bool type judgment."""
if isinstance(input_param, bool):
return input_param
raise TypeError("Input type must be bool!")
def check_string(input_param, valid_values):
"""String type judgment."""
if isinstance(input_param, str) and input_param in valid_values:
return input_param
if len(valid_values) == 1:
raise ValueError(f'Input should be str and must be {valid_values[0]},'
f' but got {input_param}.')
raise ValueError(f'Input should be str and must be one of {valid_values},'
f' but got {input_param}.')
def check_input_format(input_param):
"""Judge input format."""
if input_param == "NCHW":
@ -587,7 +566,8 @@ def check_shape(arg_name, arg_value):
def _check_str_by_regular(target, reg=None, flag=re.ASCII):
if reg is None:
reg = _name_re
# Named string regular expression
reg = r"^\w+[0-9a-zA-Z\_\.]*$"
if re.match(reg, target, flag) is None:
raise ValueError("'{}' is illegal, it should be match regular'{}' by flags'{}'".format(target, reg, flag))
return True

View File

@ -27,7 +27,7 @@ from mindspore.ops.operations import _inner_ops as inner
from mindspore.ops.primitive import constexpr
from mindspore.common.parameter import Parameter
from mindspore._extends import cell_attr_register
from mindspore._checkparam import Rel, Validator as validator, check_int_positive, check_bool
from mindspore._checkparam import Rel, Validator, check_int_positive
from mindspore.common.api import ms_function
from mindspore import context
from ..cell import Cell
@ -86,8 +86,8 @@ class Dropout(Cell):
super(Dropout, self).__init__()
if keep_prob <= 0 or keep_prob > 1:
raise ValueError("dropout probability should be a number in range (0, 1], but got {}".format(keep_prob))
validator.check_subclass("dtype", dtype, mstype.number_type, self.cls_name)
validator.check_value_type('keep_prob', keep_prob, [float], self.cls_name)
Validator.check_subclass("dtype", dtype, mstype.number_type, self.cls_name)
Validator.check_value_type('keep_prob', keep_prob, [float], self.cls_name)
self.keep_prob = keep_prob
seed0 = get_seed()
self.seed0 = seed0 if seed0 is not None else 0
@ -205,7 +205,7 @@ class Dense(Cell):
super(Dense, self).__init__()
self.in_channels = check_int_positive(in_channels)
self.out_channels = check_int_positive(out_channels)
self.has_bias = check_bool(has_bias)
self.has_bias = Validator.check_bool(has_bias)
if isinstance(weight_init, Tensor):
if weight_init.dim() != 2 or weight_init.shape[0] != out_channels or \
@ -348,7 +348,7 @@ class Norm(Cell):
def __init__(self, axis=(), keep_dims=False):
super(Norm, self).__init__()
validator.check_value_type("keep_dims", keep_dims, [bool], self.cls_name)
Validator.check_value_type("keep_dims", keep_dims, [bool], self.cls_name)
self.axis = axis
self.keep_dims = keep_dims
self.reduce_sum = P.ReduceSum(True)
@ -472,7 +472,7 @@ class Pad(Cell):
super(Pad, self).__init__()
self.mode = mode
self.paddings = paddings
validator.check_string('mode', self.mode, ["CONSTANT", "REFLECT", "SYMMETRIC"], self.cls_name)
Validator.check_string(self.mode, ["CONSTANT", "REFLECT", "SYMMETRIC"], 'mode', self.cls_name)
if not isinstance(paddings, tuple):
raise TypeError('Paddings must be tuple type.')
for item in paddings:
@ -549,7 +549,7 @@ class Unfold(Cell):
@constexpr
def _get_matrix_diag_assist(x_shape, x_dtype):
validator.check_integer("x rank", len(x_shape), 1, Rel.GE, "_get_matrix_diag_assist")
Validator.check_integer("x rank", len(x_shape), 1, Rel.GE, "_get_matrix_diag_assist")
base_eye = np.eye(x_shape[-1], x_shape[-1]).reshape(-1)
assist = np.tile(base_eye, x_shape[:-1]).reshape(x_shape + (x_shape[-1],))
return Tensor(assist, x_dtype)
@ -557,7 +557,7 @@ def _get_matrix_diag_assist(x_shape, x_dtype):
@constexpr
def _get_matrix_diag_part_assist(x_shape, x_dtype):
validator.check_integer("x rank", len(x_shape), 2, Rel.GE, "_get_matrix_diag_part_assist")
Validator.check_integer("x rank", len(x_shape), 2, Rel.GE, "_get_matrix_diag_part_assist")
base_eye = np.eye(x_shape[-2], x_shape[-1]).reshape(-1)
assist = np.tile(base_eye, x_shape[:-2]).reshape(x_shape)
return Tensor(assist, x_dtype)

View File

@ -21,7 +21,7 @@ from mindspore.ops.primitive import constexpr
from mindspore.common.parameter import Parameter
from mindspore.common.initializer import initializer, Initializer
from mindspore.common.tensor import Tensor
from mindspore._checkparam import Validator, Rel, check_bool, twice, check_int_positive
from mindspore._checkparam import Validator, Rel, twice, check_int_positive
from mindspore._extends import cell_attr_register
from ..cell import Cell
@ -92,7 +92,7 @@ class _Conv(Cell):
shape = [out_channels, in_channels // group, *kernel_size]
self.weight = Parameter(initializer(self.weight_init, shape), name='weight')
if check_bool(has_bias):
if Validator.check_bool(has_bias):
self.bias = Parameter(initializer(self.bias_init, [out_channels]), name='bias')
else:
if self.bias_init != 'zeros':
@ -566,7 +566,7 @@ class Conv2dTranspose(_Conv):
self.is_valid = self.pad_mode == 'valid'
self.is_same = self.pad_mode == 'same'
self.is_pad = self.pad_mode == 'pad'
if check_bool(has_bias):
if Validator.check_bool(has_bias):
self.bias = Parameter(initializer(bias_init, [out_channels]), name='bias')
# cause Conv2DBackpropInput's out_channel refers to Conv2D's out_channel.
@ -745,7 +745,7 @@ class Conv1dTranspose(_Conv):
self.is_valid = self.pad_mode == 'valid'
self.is_same = self.pad_mode == 'same'
self.is_pad = self.pad_mode == 'pad'
if check_bool(has_bias):
if Validator.check_bool(has_bias):
self.bias = Parameter(initializer(bias_init, [out_channels]), name='bias')
# cause Conv2DBackpropInput's out_channel refers to Conv2D's out_channel.

View File

@ -19,7 +19,7 @@ from mindspore.common.parameter import Parameter
from mindspore.common.initializer import initializer
from mindspore.ops.primitive import constexpr
import mindspore.context as context
from mindspore._checkparam import check_bool, check_typename, check_int_positive
from mindspore._checkparam import Validator, check_typename, check_int_positive
from mindspore._extends import cell_attr_register
from mindspore.communication.management import get_group_size, get_rank
from mindspore.communication import management
@ -604,7 +604,7 @@ class GroupNorm(Cell):
if num_channels % num_groups != 0:
raise ValueError("num_channels should be divided by num_groups")
self.eps = check_typename('eps', eps, (float,))
self.affine = check_bool(affine)
self.affine = Validator.check_bool(affine)
gamma = initializer(gamma_init, num_channels)
beta = initializer(beta_init, num_channels)

View File

@ -27,7 +27,7 @@ class _PoolNd(Cell):
def __init__(self, kernel_size, stride, pad_mode):
super(_PoolNd, self).__init__()
self.pad_mode = validator.check_string('pad_mode', pad_mode.upper(), ['VALID', 'SAME'], self.cls_name)
self.pad_mode = validator.check_string(pad_mode.upper(), ['VALID', 'SAME'], 'pad_mode', self.cls_name)
def _check_int_or_tuple(arg_name, arg_value):
validator.check_value_type(arg_name, arg_value, [int, tuple], self.cls_name)
@ -270,7 +270,7 @@ class AvgPool1d(_PoolNd):
super(AvgPool1d, self).__init__(kernel_size, stride, pad_mode)
validator.check_value_type('kernel_size', kernel_size, [int], self.cls_name)
validator.check_value_type('stride', stride, [int], self.cls_name)
self.pad_mode = validator.check_string('pad_mode', pad_mode.upper(), ['VALID', 'SAME'], self.cls_name)
self.pad_mode = validator.check_string(pad_mode.upper(), ['VALID', 'SAME'], 'pad_mode', self.cls_name)
validator.check_integer("kernel_size", kernel_size, 1, Rel.GE, self.cls_name)
validator.check_integer("stride", stride, 1, Rel.GE, self.cls_name)
self.kernel_size = (1, kernel_size)

View File

@ -23,7 +23,7 @@ from mindspore.ops import functional as F
from mindspore.common.parameter import Parameter
from mindspore.common.initializer import initializer
from mindspore.common.tensor import Tensor
from mindspore._checkparam import Rel, check_int_positive, check_bool, twice, Validator
from mindspore._checkparam import Validator, Rel, check_int_positive, twice
import mindspore.context as context
from .normalization import BatchNorm2d, BatchNorm1d
from .activation import get_activation, ReLU, LeakyReLU
@ -133,7 +133,7 @@ class Conv2dBnAct(Cell):
has_bias=has_bias,
weight_init=weight_init,
bias_init=bias_init)
self.has_bn = Validator.check_bool("has_bn", has_bn)
self.has_bn = Validator.check_bool(has_bn, "has_bn")
self.has_act = activation is not None
self.after_fake = after_fake
if has_bn:
@ -201,7 +201,7 @@ class DenseBnAct(Cell):
weight_init,
bias_init,
has_bias)
self.has_bn = Validator.check_bool("has_bn", has_bn)
self.has_bn = Validator.check_bool(has_bn, "has_bn")
self.has_act = activation is not None
self.after_fake = after_fake
if has_bn:
@ -511,7 +511,7 @@ class Conv2dBnFoldQuant(Cell):
channel_axis = 0
self.weight = Parameter(initializer(weight_init, weight_shape), name='weight')
self.bias_add = P.BiasAdd()
if check_bool(has_bias):
if Validator.check_bool(has_bias):
self.bias = Parameter(initializer(bias_init, [out_channels]), name='bias')
else:
self.bias = None
@ -668,7 +668,7 @@ class Conv2dBnWithoutFoldQuant(Cell):
self.quant_delay = quant_delay
self.bias_add = P.BiasAdd()
if check_bool(has_bias):
if Validator.check_bool(has_bias):
self.bias = Parameter(initializer(bias_init, [out_channels]), name='bias')
else:
self.bias = None
@ -799,7 +799,7 @@ class Conv2dQuant(Cell):
self.weight = Parameter(initializer(weight_init, weight_shape), name='weight')
self.bias_add = P.BiasAdd()
if check_bool(has_bias):
if Validator.check_bool(has_bias):
self.bias = Parameter(initializer(bias_init, [out_channels]), name='bias')
else:
self.bias = None
@ -888,7 +888,7 @@ class DenseQuant(Cell):
super(DenseQuant, self).__init__()
self.in_channels = check_int_positive(in_channels)
self.out_channels = check_int_positive(out_channels)
self.has_bias = check_bool(has_bias)
self.has_bias = Validator.check_bool(has_bias)
if isinstance(weight_init, Tensor):
if weight_init.dim() != 2 or weight_init.shape[0] != out_channels or \

View File

@ -18,8 +18,7 @@ from mindspore.ops import _selected_ops
from mindspore.common.parameter import Parameter
from mindspore.common.tensor import Tensor
import mindspore.common.dtype as mstype
from mindspore._checkparam import check_bool
from mindspore._checkparam import Validator as validator
from mindspore._checkparam import Validator
from .optimizer import Optimizer
_momentum_opt = C.MultitypeFuncGraph("momentum_opt")
@ -126,12 +125,12 @@ class Momentum(Optimizer):
"""
def __init__(self, params, learning_rate, momentum, weight_decay=0.0, loss_scale=1.0, use_nesterov=False):
super(Momentum, self).__init__(learning_rate, params, weight_decay, loss_scale)
validator.check_value_type("momentum", momentum, [float], self.cls_name)
Validator.check_value_type("momentum", momentum, [float], self.cls_name)
if isinstance(momentum, float) and momentum < 0.0:
raise ValueError("momentum should be at least 0.0, but got momentum {}".format(momentum))
self.momentum = Parameter(Tensor(momentum, mstype.float32), name="momentum")
self.params = self.parameters
self.use_nesterov = check_bool(use_nesterov)
self.use_nesterov = Validator.check_bool(use_nesterov)
self.moments = self.params.clone(prefix="moments", init='zeros')
self.hyper_map = C.HyperMap()
self.opt = _selected_ops.ApplyMomentum(use_nesterov=self.use_nesterov)

View File

@ -15,7 +15,7 @@
"""dense_variational"""
from mindspore.ops import operations as P
from mindspore.common.tensor import Tensor
from mindspore._checkparam import check_int_positive, check_bool
from mindspore._checkparam import check_int_positive, Validator
from ...cell import Cell
from ...layer.activation import get_activation
from .layer_distribution import NormalPrior, NormalPosterior
@ -41,7 +41,7 @@ class _DenseVariational(Cell):
super(_DenseVariational, self).__init__()
self.in_channels = check_int_positive(in_channels)
self.out_channels = check_int_positive(out_channels)
self.has_bias = check_bool(has_bias)
self.has_bias = Validator.check_bool(has_bias)
if isinstance(weight_prior_fn, Cell):
self.weight_prior = weight_prior_fn

View File

@ -16,7 +16,7 @@
from copy import deepcopy
import numpy as np
from mindspore._checkparam import check_int_positive, check_bool
from mindspore._checkparam import check_int_positive, Validator
from mindspore.ops import composite as C
from mindspore.ops import operations as P
from mindspore.train import Model
@ -84,7 +84,7 @@ class UncertaintyEvaluation:
self.epochs = check_int_positive(epochs)
self.epi_uncer_model_path = epi_uncer_model_path
self.ale_uncer_model_path = ale_uncer_model_path
self.save_model = check_bool(save_model)
self.save_model = Validator.check_bool(save_model)
self.epi_uncer_model = None
self.ale_uncer_model = None
self.concat = P.Concat(axis=0)

View File

@ -216,7 +216,7 @@ class KLDivLossGrad(PrimitiveWithInfer):
@prim_attr_register
def __init__(self, reduction='mean'):
self.reduction = validator.check_string('reduction', reduction, ['none', 'mean', 'sum'], self.name)
self.reduction = validator.check_string(reduction, ['none', 'mean', 'sum'], 'reduction', self.name)
def infer_shape(self, x_shape, y_shape, doutput_shape):
validator.check('x_shape', x_shape, 'y_shape', y_shape, Rel.EQ, self.name)
@ -233,7 +233,7 @@ class BinaryCrossEntropyGrad(PrimitiveWithInfer):
@prim_attr_register
def __init__(self, reduction='mean'):
self.reduction = validator.check_string('reduction', reduction, ['none', 'mean', 'sum'], self.name)
self.reduction = validator.check_string(reduction, ['none', 'mean', 'sum'], 'reduction', self.name)
def infer_shape(self, x_shape, y_shape, doutput_shape, weight_shape):
validator.check('x_shape', x_shape, 'y_shape', y_shape, Rel.EQ, self.name)
@ -609,7 +609,7 @@ class _PoolGrad(PrimitiveWithInfer):
validator.check_value_type('ksize', ksize, [int, tuple], self.name)
validator.check_value_type('strides', strides, [int, tuple], self.name)
self.padding = validator.check_string('padding', padding.upper(), ['VALID', 'SAME'], self.name)
self.padding = validator.check_string(padding.upper(), ['VALID', 'SAME'], 'padding', self.name)
self.add_prim_attr("padding", self.padding)
self.is_maxpoolgradwithargmax = (self.name == "MaxPoolGradWithArgmax")
if not self.is_maxpoolgradwithargmax:
@ -1457,7 +1457,7 @@ class MirrorPadGrad(PrimitiveWithInfer):
@prim_attr_register
def __init__(self, mode="REFLECT"):
"""Initialize MirrorPad"""
validator.check_string('mode', mode, ['REFLECT', 'SYMMETRIC'], self.name)
validator.check_string(mode, ['REFLECT', 'SYMMETRIC'], 'mode', self.name)
self.mode = mode
def __infer__(self, dout, paddings):
@ -1570,7 +1570,7 @@ class BasicLSTMCellCStateGrad(PrimitiveWithInfer):
@prim_attr_register
def __init__(self, forget_bias, activation):
self.forget_bias = validator.check_value_type("forget_bias", forget_bias, [float], self.name)
self.activation = validator.check_string("activation", activation, ['tanh'], self.name)
self.activation = validator.check_string(activation, ['tanh'], "activation", self.name)
self.add_prim_attr("io_format", "ND")
def infer_shape(self, c_shape, dht_shape, dct_shape, it_shape, jt_shape, ft_shape, ot_shape, tanhct_shape):

View File

@ -67,7 +67,7 @@ class ExtractImagePatches(PrimitiveWithInfer):
_check_tuple_or_list("ksize", ksizes, self.name)
_check_tuple_or_list("stride", strides, self.name)
_check_tuple_or_list("rate", rates, self.name)
self.padding = validator.check_string('padding', padding.upper(), ['VALID', 'SAME'], self.name)
self.padding = validator.check_string(padding.upper(), ['VALID', 'SAME'], 'padding', self.name)
self.add_prim_attr("padding", self.padding)
self.add_prim_attr("io_format", "NHWC")
self.is_ge = context.get_context("enable_ge")
@ -206,8 +206,8 @@ class Quant(PrimitiveWithInfer):
self.scale = validator.check_value_type("scale", scale, [float], self.name)
self.offset = validator.check_value_type("offset", offset, [float], self.name)
self.sqrt_mode = validator.check_value_type("sqrt_mode", sqrt_mode, [bool], self.name)
self.round_mode = validator.check_string("round_mode", round_mode,
["Round", "Floor", "Ceil", "Trunc"], self.name)
self.round_mode = validator.check_string(round_mode, ["Round", "Floor", "Ceil", "Trunc"],
"round_mode", self.name)
self.add_prim_attr("io_format", "ND")
def infer_shape(self, x_shape):

View File

@ -513,7 +513,7 @@ class Im2Col(PrimitiveWithInfer):
self.dilation = _check_positive_int_or_tuple('dilation', dilation, self.name, allow_four=True, ret_four=True)
self.add_prim_attr('dilation', self.dilation)
validator.check_value_type('pad', pad, (int,), self.name)
self.pad_mode = validator.check_string('pad_mode', pad_mode, ['valid', 'same', 'pad'], self.name)
self.pad_mode = validator.check_string(pad_mode, ['valid', 'same', 'pad'], 'pad_mode', self.name)
self.pad = validator.check_pad_value_by_mode(pad_mode, pad, self.name)
if self.pad_mode == 'pad':
validator.check_integer('pad', self.pad, 0, Rel.GE, self.name)

View File

@ -82,7 +82,7 @@ class CropAndResize(PrimitiveWithInfer):
"""Initialize CropAndResize"""
self.init_prim_io_names(inputs=['x', 'boxes', 'box_index', 'crop_size'], outputs=['y'])
validator.check_value_type("method", method, [str], self.name)
validator.check_string("method", method, ["bilinear", "nearest", "bilinear_v2"], self.name)
validator.check_string(method, ["bilinear", "nearest", "bilinear_v2"], "method", self.name)
self.method = method
validator.check_value_type("extrapolation_value", extrapolation_value, [float], self.name)
self.extrapolation_value = extrapolation_value

View File

@ -1484,7 +1484,7 @@ class HistogramFixedWidth(PrimitiveWithInfer):
self.nbins = validator.check_value_type("nbins", nbins, [int], self.name)
validator.check_integer("nbins", nbins, 1, Rel.GE, self.name)
valid_values = ['int32', 'int64']
self.dtype = validator.check_string("dtype", dtype, valid_values, self.name)
self.dtype = validator.check_string(dtype, valid_values, "dtype", self.name)
self.init_prim_io_names(inputs=['x', 'range'], outputs=['y'])
def infer_shape(self, x_shape, range_shape):

View File

@ -995,7 +995,7 @@ class Conv2D(PrimitiveWithInfer):
else:
validator.check_integer('pad size', len(pad), 4, Rel.EQ, self.name)
self.padding = pad
self.pad_mode = validator.check_string('pad_mode', pad_mode, ['valid', 'same', 'pad'], self.name)
self.pad_mode = validator.check_string(pad_mode, ['valid', 'same', 'pad'], 'pad_mode', self.name)
if pad_mode != 'pad' and pad != (0, 0, 0, 0):
raise ValueError(f"For '{self.name}', padding must be zero when pad_mode is '{pad_mode}'.")
@ -1134,7 +1134,7 @@ class DepthwiseConv2dNative(PrimitiveWithInfer):
else:
validator.check_integer('pad size', len(pad), 4, Rel.EQ, self.name)
self.padding = pad
self.pad_mode = validator.check_string('pad_mode', pad_mode, ['valid', 'same', 'pad'], self.name)
self.pad_mode = validator.check_string(pad_mode, ['valid', 'same', 'pad'], 'pad_mode', self.name)
if pad_mode != 'pad' and pad != (0, 0, 0, 0):
raise ValueError(f"For '{self.name}', padding must be zero when pad_mode is '{pad_mode}'.")
if self.pad_mode == 'pad':
@ -1216,7 +1216,7 @@ class _Pool(PrimitiveWithInfer):
self.init_prim_io_names(inputs=['x'], outputs=['output'])
validator.check_value_type('ksize', ksize, [int, tuple], self.name)
validator.check_value_type('strides', strides, [int, tuple], self.name)
self.padding = validator.check_string('padding', padding.upper(), ['VALID', 'SAME'], self.name)
self.padding = validator.check_string(padding.upper(), ['VALID', 'SAME'], 'padding', self.name)
self.add_prim_attr("padding", self.padding)
self.is_maxpoolwithargmax = (self.name == "MaxPoolWithArgmax")
if not self.is_maxpoolwithargmax:
@ -1521,7 +1521,7 @@ class Conv2DBackpropInput(PrimitiveWithInfer):
else:
validator.check_integer('pad size', len(pad), 4, Rel.EQ, self.name)
self.padding = pad
self.pad_mode = validator.check_string('pad_mode', pad_mode, ['valid', 'same', 'pad'], self.name)
self.pad_mode = validator.check_string(pad_mode, ['valid', 'same', 'pad'], 'pad_mode', self.name)
if pad_mode != 'pad' and pad != (0, 0, 0, 0):
raise ValueError(f"For '{self.name}', padding must be zero when pad_mode is '{pad_mode}'.")
if self.pad_mode == 'pad':
@ -1942,8 +1942,8 @@ class DataFormatDimMap(PrimitiveWithInfer):
@prim_attr_register
def __init__(self, src_format='NHWC', dst_format='NCHW'):
valid_values = ['NHWC', 'NCHW']
self.src_format = validator.check_string("src_format", src_format, valid_values, self.name)
self.dst_format = validator.check_string("dst_format", dst_format, valid_values, self.name)
self.src_format = validator.check_string(src_format, valid_values, "src_format", self.name)
self.dst_format = validator.check_string(dst_format, valid_values, "dst_format", self.name)
self.init_prim_io_names(inputs=['input_x'], outputs=['output'])
def infer_shape(self, x_shape):
@ -2961,7 +2961,7 @@ class MirrorPad(PrimitiveWithInfer):
@prim_attr_register
def __init__(self, mode='REFLECT'):
"""Initialize Pad"""
validator.check_string('mode', mode, ['REFLECT', 'SYMMETRIC'], self.name)
validator.check_string(mode, ['REFLECT', 'SYMMETRIC'], 'mode', self.name)
self.mode = mode
self.set_const_input_indexes([1])
@ -3651,7 +3651,7 @@ class KLDivLoss(PrimitiveWithInfer):
@prim_attr_register
def __init__(self, reduction='mean'):
self.reduction = validator.check_string('reduction', reduction, ['none', 'mean', 'sum'], self.name)
self.reduction = validator.check_string(reduction, ['none', 'mean', 'sum'], 'reduction', self.name)
def infer_shape(self, x_shape, y_shape):
validator.check('x_shape', x_shape, 'y_shape', y_shape, Rel.EQ, self.name)
@ -3727,7 +3727,7 @@ class BinaryCrossEntropy(PrimitiveWithInfer):
@prim_attr_register
def __init__(self, reduction='mean'):
self.reduction = validator.check_string('reduction', reduction, ['none', 'mean', 'sum'], self.name)
self.reduction = validator.check_string(reduction, ['none', 'mean', 'sum'], 'reduction', self.name)
def infer_shape(self, x_shape, y_shape, weight_shape):
validator.check('x_shape', x_shape, 'y_shape', y_shape, Rel.EQ, self.name)
@ -5487,7 +5487,7 @@ class BasicLSTMCell(PrimitiveWithInfer):
self.keep_prob = validator.check_number_range("keep_prob", keep_prob, 0.0, 1.0, Rel.INC_BOTH, self.name)
self.forget_bias = validator.check_value_type("forget_bias", forget_bias, [float], self.name)
self.state_is_tuple = validator.check_value_type("state_is_tuple", state_is_tuple, [bool], self.name)
self.activation = validator.check_string("activation", activation, ['tanh'], self.name)
self.activation = validator.check_string(activation, ['tanh'], "activation", self.name)
self.add_prim_attr("io_format", "ND")
def infer_shape(self, x_shape, h_shape, c_shape, w_shape, b_shape):
@ -5605,9 +5605,9 @@ class DynamicRNN(PrimitiveWithInfer):
self.use_peephole = validator.check_value_type("use_peephole", use_peephole, [bool], self.name)
self.time_major = validator.check_value_type("time_major", time_major, [bool], self.name)
self.is_training = validator.check_value_type("is_training", is_training, [bool], self.name)
self.cell_type = validator.check_string("cell_type", cell_type, ['LSTM'], self.name)
self.direction = validator.check_string("direction", direction, ['UNIDIRECTIONAL'], self.name)
self.activation = validator.check_string("activation", activation, ['tanh'], self.name)
self.cell_type = validator.check_string(cell_type, ['LSTM'], "cell_type", self.name)
self.direction = validator.check_string(direction, ['UNIDIRECTIONAL'], "direction", self.name)
self.activation = validator.check_string(activation, ['tanh'], "activation", self.name)
self.add_prim_attr("io_format", "ND")
def infer_shape(self, x_shape, w_shape, b_shape, seq_shape, h_shape, c_shape):
@ -5720,7 +5720,7 @@ class LRN(PrimitiveWithInfer):
validator.check_value_type("alpha", alpha, [float], self.name)
validator.check_value_type("beta", beta, [float], self.name)
validator.check_value_type("norm_region", norm_region, [str], self.name)
validator.check_string('norm_region', norm_region, ['ACROSS_CHANNELS'], self.name)
validator.check_string(norm_region, ['ACROSS_CHANNELS'], 'norm_region', self.name)
validator.check_integer("depth_radius", depth_radius, 0, Rel.GE, self.name)
def infer_dtype(self, x_dtype):

View File

@ -21,7 +21,7 @@ import time
import threading
import mindspore.context as context
from mindspore import log as logger
from mindspore._checkparam import check_bool, check_int_non_negative
from mindspore._checkparam import Validator, check_int_non_negative
from mindspore.train._utils import _make_directory
from mindspore.train.serialization import save_checkpoint, _save_graph
from mindspore.parallel._ps_context import _is_role_pserver, _get_ps_mode_rank
@ -132,8 +132,8 @@ class CheckpointConfig:
if not self._keep_checkpoint_per_n_minutes or self._keep_checkpoint_per_n_minutes == 0:
self._keep_checkpoint_max = 1
self._integrated_save = check_bool(integrated_save)
self._async_save = check_bool(async_save)
self._integrated_save = Validator.check_bool(integrated_save)
self._async_save = Validator.check_bool(async_save)
@property
def save_checkpoint_steps(self):

View File

@ -16,7 +16,7 @@
import math
import os
from mindspore._checkparam import check_bool, check_int
from mindspore._checkparam import Validator, check_int
from .. import context, nn
from ._utils import _exec_datagraph, _get_types_and_shapes, _construct_tensor_list
from ..nn.wrap import GetNextSingleOp
@ -123,7 +123,7 @@ class DatasetHelper:
"""
def __init__(self, dataset, dataset_sink_mode=True, sink_size=-1, epoch_num=1):
check_bool(dataset_sink_mode)
dataset_sink_mode = Validator.check_bool(dataset_sink_mode)
check_int(sink_size)
if sink_size < -1 or sink_size == 0:
raise ValueError("The sink_size must be -1 or positive, but got sink_size {}.".format(sink_size))

View File

@ -22,7 +22,7 @@ import numpy as np
from mindspore import log as logger
from ..common.tensor import Tensor
from ..nn.metrics import get_metrics
from .._checkparam import check_input_data, check_output_data, check_int_positive, check_bool, check_int
from .._checkparam import check_input_data, check_output_data, check_int_positive, Validator, check_int
from .callback import _InternalCallbackParam, RunContext, _CallbackManager
from .. import context
from ..parallel._utils import _get_parallel_mode, _get_device_num, _get_global_rank, \
@ -548,7 +548,7 @@ class Model:
>>> model = Model(net, loss_fn=loss, optimizer=optim, metrics=None, loss_scale_manager=loss_scale_manager)
>>> model.train(2, dataset)
"""
check_bool(dataset_sink_mode)
dataset_sink_mode = Validator.check_bool(dataset_sink_mode)
if sink_size == -1:
sink_size = train_dataset.get_dataset_size()
check_int(sink_size)
@ -664,7 +664,7 @@ class Model:
>>> model = Model(net, loss_fn=loss, optimizer=None, metrics={'acc'})
>>> model.eval(dataset)
"""
check_bool(dataset_sink_mode)
dataset_sink_mode = Validator.check_bool(dataset_sink_mode)
_device_number_check(self._parallel_mode, self._device_number)
if not self._metric_fns:
raise ValueError("metric fn can not be None or empty.")

View File

@ -22,8 +22,7 @@ import mindspore.context as context
from ... import log as logger
from ... import nn, ops
from ..._checkparam import Validator
from ..._checkparam import Rel
from ..._checkparam import Validator, Rel
from ...common import Tensor
from ...common import dtype as mstype
from ...common.api import _executor
@ -92,16 +91,16 @@ class ConvertToQuantNetwork:
self.network = Validator.check_isinstance('network', kwargs["network"], (nn.Cell,))
self.weight_qdelay = Validator.check_integer("quant delay", kwargs["quant_delay"][0], 0, Rel.GE)
self.act_qdelay = Validator.check_integer("quant delay", kwargs["quant_delay"][-1], 0, Rel.GE)
self.bn_fold = Validator.check_bool("bn fold", kwargs["bn_fold"])
self.bn_fold = Validator.check_bool(kwargs["bn_fold"], "bn fold")
self.freeze_bn = Validator.check_integer("freeze bn", kwargs["freeze_bn"], 0, Rel.GE)
self.weight_bits = Validator.check_integer("weights bit", kwargs["num_bits"][0], 0, Rel.GE)
self.act_bits = Validator.check_integer("activations bit", kwargs["num_bits"][-1], 0, Rel.GE)
self.weight_channel = Validator.check_bool("per channel", kwargs["per_channel"][0])
self.act_channel = Validator.check_bool("per channel", kwargs["per_channel"][-1])
self.weight_symmetric = Validator.check_bool("symmetric", kwargs["symmetric"][0])
self.act_symmetric = Validator.check_bool("symmetric", kwargs["symmetric"][-1])
self.weight_range = Validator.check_bool("narrow range", kwargs["narrow_range"][0])
self.act_range = Validator.check_bool("narrow range", kwargs["narrow_range"][-1])
self.weight_channel = Validator.check_bool(kwargs["per_channel"][0], "per channel")
self.act_channel = Validator.check_bool(kwargs["per_channel"][-1], "per channel")
self.weight_symmetric = Validator.check_bool(kwargs["symmetric"][0], "symmetric")
self.act_symmetric = Validator.check_bool(kwargs["symmetric"][-1], "symmetric")
self.weight_range = Validator.check_bool(kwargs["narrow_range"][0], "narrow range")
self.act_range = Validator.check_bool(kwargs["narrow_range"][-1], "narrow range")
self._convert_method_map = {quant.Conv2dBnAct: self._convert_conv,
quant.DenseBnAct: self._convert_dense}

View File

@ -15,7 +15,7 @@
"""Dataset help for minddata dataset"""
import math
import os
from mindspore._checkparam import check_bool, check_int
from mindspore._checkparam import Validator, check_int
from mindspore import context
from mindspore.train._utils import _exec_datagraph, _get_types_and_shapes
from mindspore.nn.wrap import GetNextSingleOp
@ -61,7 +61,7 @@ class DatasetHelper:
"""
def __init__(self, dataset, dataset_sink_mode=True, sink_size=-1, epoch_num=1, iter_first_order=1):
check_bool(dataset_sink_mode)
dataset_sink_mode = Validator.check_bool(dataset_sink_mode)
check_int(sink_size)
if sink_size < -1 or sink_size == 0:
raise ValueError("The sink_size must be -1 or positive, but got sink_size {}.".format(sink_size))

View File

@ -18,8 +18,7 @@ from mindspore.common.initializer import initializer
from mindspore.common.parameter import Parameter, ParameterTuple
from mindspore.common.tensor import Tensor
import mindspore.common.dtype as mstype
from mindspore._checkparam import check_bool
from mindspore._checkparam import Validator as validator
from mindspore._checkparam import Validator
from mindspore.nn.optim.optimizer import Optimizer
from mindspore.parallel._utils import _get_device_num, _get_gradients_mean
from src.grad_reducer_thor import DistributedGradReducerThor
@ -53,12 +52,12 @@ class THOR_GPU(Optimizer):
def __init__(self, params, learning_rate, momentum, matrix_A, matrix_G, A_inv_max, G_inv_max,
weight_decay=0.0, loss_scale=1.0, use_nesterov=False, decay_filter=lambda x: x.name not in []):
super(THOR_GPU, self).__init__(learning_rate, params, weight_decay, loss_scale)
validator.check_value_type("momentum", momentum, [float], self.cls_name)
Validator.check_value_type("momentum", momentum, [float], self.cls_name)
if isinstance(momentum, float) and momentum < 0.0:
raise ValueError("momentum should be at least 0.0, but got momentum {}".format(momentum))
self.momentum = Parameter(Tensor(momentum, mstype.float32), name="momentum")
self.params = self.parameters
self.use_nesterov = check_bool(use_nesterov)
self.use_nesterov = Validator.check_bool(use_nesterov)
self.moments = self.params.clone(prefix="moments", init='zeros')
self.hyper_map = C.HyperMap()
self.opt = P.ApplyMomentum(use_nesterov=self.use_nesterov)

View File

@ -16,7 +16,7 @@
import numpy as np
import mindspore.common.dtype as mstype
from mindspore._checkparam import check_bool, twice, check_int_positive
from mindspore._checkparam import Validator, twice, check_int_positive
from mindspore._extends import cell_attr_register
from mindspore.common.initializer import initializer
from mindspore.common.parameter import Parameter
@ -111,7 +111,7 @@ class _Conv(Cell):
self.weight = Parameter(initializer(
weight_init, [out_channels, in_channels // group, *kernel_size]), name='weight')
if check_bool(has_bias):
if Validator.check_bool(has_bias):
self.bias = Parameter(_initializer(
bias_init, [out_channels]), name='bias')
else:
@ -294,7 +294,7 @@ class Dense_Thor_GPU(Cell):
super(Dense_Thor_GPU, self).__init__()
self.in_channels = check_int_positive(in_channels)
self.out_channels = check_int_positive(out_channels)
self.has_bias = check_bool(has_bias)
self.has_bias = Validator.check_bool(has_bias)
self.thor = True
if isinstance(weight_init, Tensor):
if weight_init.dim() != 2 or weight_init.shape[0] != out_channels or \
@ -643,7 +643,7 @@ class Dense_Thor(Cell):
super(Dense_Thor, self).__init__()
self.in_channels = check_int_positive(in_channels)
self.out_channels = check_int_positive(out_channels)
self.has_bias = check_bool(has_bias)
self.has_bias = Validator.check_bool(has_bias)
self.thor = True
self.batch_size = batch_size
if isinstance(weight_init, Tensor):

View File

@ -19,7 +19,7 @@ from mindspore.ops import functional as F
from mindspore._extends import cell_attr_register
from mindspore import Tensor, Parameter
from mindspore.common.initializer import initializer
from mindspore._checkparam import check_int_positive, check_bool
from mindspore._checkparam import check_int_positive, Validator
from mindspore.nn.layer.activation import get_activation
@ -74,7 +74,7 @@ class GNNFeatureTransform(nn.Cell):
super(GNNFeatureTransform, self).__init__()
self.in_channels = check_int_positive(in_channels)
self.out_channels = check_int_positive(out_channels)
self.has_bias = check_bool(has_bias)
self.has_bias = Validator.check_bool(has_bias)
if isinstance(weight_init, Tensor):
if weight_init.dim() != 2 or weight_init.shape[0] != out_channels or \
@ -284,7 +284,7 @@ class AttentionHead(nn.Cell):
self.matmul = P.MatMul()
self.bias_add = P.BiasAdd()
self.bias = Parameter(initializer('zeros', self.out_channel), name='bias')
self.residual = check_bool(residual)
self.residual = Validator.check_bool(residual)
if self.residual:
if in_channel != out_channel:
self.residual_transform_flag = True
@ -458,7 +458,7 @@ class GAT(nn.Cell):
self.attn_drop = attn_drop
self.ftr_drop = ftr_drop
self.activation = activation
self.residual = check_bool(residual)
self.residual = Validator.check_bool(residual)
self.layers = []
# first layer
self.layers.append(AttentionAggregator(

View File

@ -16,7 +16,7 @@
import os
from mindspore import context
from mindspore._checkparam import check_bool, check_int
from mindspore._checkparam import Validator, check_int
from mindspore.parallel._utils import _get_device_num, _need_to_full, _to_full_shapes
from mindspore.train._utils import _exec_datagraph, _get_types_and_shapes
@ -58,7 +58,7 @@ class DatasetHelper:
"""
def __init__(self, dataset, dataset_sink_mode=True, sink_size=-1, epoch_num=1, iter_first_order=0):
check_bool(dataset_sink_mode)
dataset_sink_mode = Validator.check_bool(dataset_sink_mode)
check_int(sink_size)
if sink_size < -1 or sink_size == 0:
raise ValueError("The sink_size must be -1 or positive, but got sink_size {}.".format(sink_size))

View File

@ -22,7 +22,7 @@ from mindspore._c_expression import init_exec_dataset
from mindspore import context
from mindspore import log as logger
from mindspore import nn
from mindspore._checkparam import check_input_data, check_output_data, check_int_positive, check_bool, check_int
from mindspore._checkparam import check_input_data, check_output_data, check_int_positive, Validator, check_int
from mindspore.common import dtype as mstype
from mindspore.common.dtype import pytype_to_dtype
from mindspore.common.tensor import Tensor
@ -603,7 +603,7 @@ class Model:
>>> model = Model(net, loss_fn=loss, optimizer=optim, metrics=None, loss_scale_manager=loss_scale_manager)
>>> model.train(2, dataset)
"""
check_bool(dataset_sink_mode)
dataset_sink_mode = Validator.check_bool(dataset_sink_mode)
check_int(sink_size)
if sink_size < -1 or sink_size == 0:
raise ValueError("The sink_size must be -1 or positive, but got sink_size {}.".format(sink_size))
@ -718,7 +718,7 @@ class Model:
>>> model = Model(net, loss_fn=loss, optimizer=None, metrics={'acc'})
>>> model.eval(dataset)
"""
check_bool(dataset_sink_mode)
dataset_sink_mode = Validator.check_bool(dataset_sink_mode)
_device_number_check(self._parallel_mode, self._device_number)
if not self._metric_fns:
raise ValueError("metric fn can not be None or empty.")

View File

@ -15,7 +15,7 @@
"""thor_layer"""
import numpy as np
import mindspore.common.dtype as mstype
from mindspore._checkparam import check_bool, check_int_positive
from mindspore._checkparam import Validator, check_int_positive
from mindspore.common.initializer import TruncatedNormal, initializer
from mindspore.common.parameter import Parameter
from mindspore.common.tensor import Tensor
@ -162,7 +162,7 @@ class Dense_Thor(Cell):
super(Dense_Thor, self).__init__()
self.in_channels = check_int_positive(in_channels)
self.out_channels = check_int_positive(out_channels)
self.has_bias = check_bool(has_bias)
self.has_bias = Validator.check_bool(has_bias)
self.thor = True
if isinstance(weight_init, Tensor):
if weight_init.dim() != 2 or weight_init.shape()[0] != out_channels or \

View File

@ -15,7 +15,7 @@
"""Aggregator."""
import mindspore.nn as nn
from mindspore import Tensor, Parameter
from mindspore._checkparam import check_int_positive, check_bool
from mindspore._checkparam import check_int_positive, Validator
from mindspore._extends import cell_attr_register
from mindspore.common.initializer import initializer
from mindspore.nn.layer.activation import get_activation
@ -75,7 +75,7 @@ class GNNFeatureTransform(nn.Cell):
super(GNNFeatureTransform, self).__init__()
self.in_channels = check_int_positive(in_channels)
self.out_channels = check_int_positive(out_channels)
self.has_bias = check_bool(has_bias)
self.has_bias = Validator.check_bool(has_bias)
if isinstance(weight_init, Tensor):
if weight_init.dim() != 2 or weight_init.shape[0] != out_channels or \
@ -284,7 +284,7 @@ class AttentionHead(nn.Cell):
self.batch_matmul = P.BatchMatMul()
self.bias_add = P.BiasAdd()
self.bias = Parameter(initializer('zeros', self.out_channel), name='bias')
self.residual = check_bool(residual)
self.residual = Validator.check_bool(residual)
if self.residual:
if in_channel != out_channel:
self.residual_transform_flag = True

View File

@ -14,7 +14,7 @@
# ============================================================================
"""Graph Attention Networks."""
import mindspore.nn as nn
from mindspore._checkparam import check_bool, check_int_positive
from mindspore._checkparam import Validator, check_int_positive
from aggregator import AttentionAggregator
@ -79,7 +79,7 @@ class GAT(nn.Cell):
self.attn_drop = attn_drop
self.ftr_drop = ftr_drop
self.activation = activation
self.residual = check_bool(residual)
self.residual = Validator.check_bool(residual)
self.layers = []
# first layer
self.layers.append(AttentionAggregator(

View File

@ -13,7 +13,7 @@
# limitations under the License.
# ============================================================================
"""Dataset help for minddata dataset"""
from mindspore._checkparam import check_bool
from mindspore._checkparam import Validator
from mindspore.parallel._utils import _get_device_num, _get_parallel_mode, _to_full_shapes
from mindspore.train._utils import _exec_datagraph, _get_types_and_shapes
from mindspore.context import ParallelMode
@ -50,7 +50,7 @@ class DatasetHelper:
"""
def __init__(self, dataset, dataset_sink_mode=True, iter_first_order=0):
check_bool(dataset_sink_mode)
Validator.check_bool(dataset_sink_mode)
self.iter = _DatasetIterMSLoopSink(dataset, iter_first_order)
def __iter__(self):

View File

@ -19,7 +19,7 @@ from mindspore import context
from mindspore import log as logger
from mindspore import nn
from mindspore._c_expression import init_exec_dataset
from mindspore._checkparam import check_input_data, check_output_data, check_int_positive, check_bool
from mindspore._checkparam import check_input_data, check_output_data, check_int_positive, Validator
from mindspore.common import dtype as mstype
from mindspore.common.dtype import pytype_to_dtype
from mindspore.common.tensor import Tensor
@ -575,7 +575,7 @@ class Model:
repeat_count = train_dataset.get_repeat_count()
if epoch != repeat_count and dataset_sink_mode is True:
logger.warning(f"The epoch_size {epoch} is not the same with dataset repeat_count {repeat_count}")
check_bool(dataset_sink_mode)
dataset_sink_mode = Validator.check_bool(dataset_sink_mode)
_device_number_check(self._parallel_mode, self._device_number)
_parameter_broadcast_check(self._parallel_mode, self._parameter_broadcast)
@ -682,7 +682,7 @@ class Model:
>>> model = Model(net, loss_fn=loss, optimizer=None, metrics={'acc'})
>>> model.eval(dataset)
"""
check_bool(dataset_sink_mode)
dataset_sink_mode = Validator.check_bool(dataset_sink_mode)
_device_number_check(self._parallel_mode, self._device_number)
if not self._metric_fns:
raise ValueError("metric fn can not be None or empty.")

View File

@ -16,7 +16,7 @@
import numpy as np
import mindspore as ms
import mindspore.common.dtype as mstype
from mindspore._checkparam import check_bool, twice, check_int_positive
from mindspore._checkparam import Validator, twice, check_int_positive
from mindspore._extends import cell_attr_register
from mindspore.common.initializer import initializer
from mindspore.common.parameter import Parameter
@ -87,7 +87,7 @@ class _Conv(Cell):
self.weight = Parameter(initializer(
weight_init, [out_channels, in_channels // group, *kernel_size]), name='weight')
if check_bool(has_bias):
if Validator.check_bool(has_bias):
self.bias = Parameter(_initializer(
bias_init, [out_channels]), name='bias')
else:
@ -339,7 +339,7 @@ class Dense_Thor(Cell):
super(Dense_Thor, self).__init__()
self.in_channels = check_int_positive(in_channels)
self.out_channels = check_int_positive(out_channels)
self.has_bias = check_bool(has_bias)
self.has_bias = Validator.check_bool(has_bias)
self.thor = True
if isinstance(weight_init, Tensor):
if weight_init.dim() != 2 or weight_init.shape[0] != out_channels or \

View File

@ -16,7 +16,7 @@
import pytest
from mindspore._checkparam import check_int, check_int_positive, \
check_input_format, check_bool, twice
check_input_format, Validator, twice
kernel_size = 5
kernel_size1 = twice(kernel_size)
@ -66,26 +66,26 @@ def test_check_int_5():
def test_check_bool_1():
assert check_bool(True)
assert Validator.check_bool(True)
def test_check_bool_2():
assert check_bool(False) is not True
assert Validator.check_bool(False) is not True
def test_check_bool_3():
with pytest.raises(TypeError):
check_bool("str")
Validator.check_bool("str")
def test_check_bool_4():
with pytest.raises(TypeError):
check_bool(1)
Validator.check_bool(1)
def test_check_bool_5():
with pytest.raises(TypeError):
check_bool(3.5)
Validator.check_bool(3.5)
def test_twice_1():

View File

@ -16,7 +16,7 @@
import pytest
from mindspore._checkparam import check_int, check_int_positive, \
check_bool, check_input_format, _expand_tuple
Validator, check_input_format, _expand_tuple
once = _expand_tuple(1)
twice = _expand_tuple(2)
@ -60,26 +60,26 @@ def test_check_int_4():
def test_check_bool_1():
assert check_bool(True)
assert Validator.check_bool(True)
def test_check_bool_2():
assert check_bool(False) is not True
assert Validator.check_bool(False) is not True
def test_check_bool_3():
with pytest.raises(TypeError):
check_bool("str")
Validator.check_bool("str")
def test_check_bool_4():
with pytest.raises(TypeError):
check_bool(1)
Validator.check_bool(1)
def test_check_bool_5():
with pytest.raises(TypeError):
check_bool(3.5)
Validator.check_bool(3.5)
def test_twice_1():