From d471d32e873fd2695fe8aa1321a0a2f69e3dff1c Mon Sep 17 00:00:00 2001 From: chenzomi Date: Sat, 10 Oct 2020 11:34:16 +0800 Subject: [PATCH] [ME] change `check_integer` to format `check_positive_int` and `check_integeter` --- mindspore/_checkparam.py | 99 ++++++++++++----- mindspore/nn/dynamic_lr.py | 26 ++--- mindspore/nn/layer/basic.py | 6 +- mindspore/nn/layer/conv.py | 8 +- mindspore/nn/layer/embedding.py | 4 +- mindspore/nn/layer/lstm.py | 6 +- mindspore/nn/layer/math.py | 4 +- mindspore/nn/layer/normalization.py | 10 +- mindspore/nn/layer/quant.py | 14 +-- mindspore/nn/learning_rate_schedule.py | 8 +- .../bnn_layers/dense_variational.py | 6 +- mindspore/nn/probability/dpn/vae/cvae.py | 10 +- mindspore/nn/probability/dpn/vae/vae.py | 8 +- .../nn/probability/infer/variational/svi.py | 4 +- .../toolbox/uncertainty_evaluation.py | 6 +- mindspore/ops/_utils/utils.py | 4 +- .../multitype_ops/_constexpr_utils.py | 26 +---- mindspore/ops/composite/random_ops.py | 4 +- mindspore/ops/operations/_grad_ops.py | 18 ++-- mindspore/ops/operations/array_ops.py | 26 ++--- mindspore/ops/operations/comm_ops.py | 4 +- mindspore/ops/operations/nn_ops.py | 23 ++-- mindspore/ops/operations/random_ops.py | 100 +++++++++--------- mindspore/train/callback/_checkpoint.py | 10 +- mindspore/train/loss_scale_manager.py | 4 +- mindspore/train/model.py | 4 +- .../official/cv/resnet_thor/src/thor_layer.py | 10 +- model_zoo/official/gnn/gat/src/gat.py | 16 +-- .../official/nlp/bert_thor/src/model_thor.py | 4 +- .../official/nlp/bert_thor/src/thor_layer.py | 6 +- tests/st/gnn/aggregator.py | 10 +- tests/st/gnn/gat.py | 8 +- .../models/resnet50/src_thor/model_thor.py | 4 +- .../models/resnet50/src_thor/thor_layer.py | 6 +- tests/ut/python/nn/test_checkparameter.py | 5 +- .../pynative_mode/nn/test_checkparameter.py | 5 +- tests/vm_impl/vm_me.py | 8 +- 37 files changed, 272 insertions(+), 252 deletions(-) diff --git a/mindspore/_checkparam.py b/mindspore/_checkparam.py index a3e27a4449b..6e8d6d8b4ec 100644 --- a/mindspore/_checkparam.py +++ b/mindspore/_checkparam.py @@ -92,6 +92,25 @@ rel_strs = { } +def _check_integer(arg_value, value, rel, arg_name=None, prim_name=None): + """ + Check argument integer. + + Usage: + - number = check_integer(number, 0, Rel.GE, "number", None) # number >= 0 + """ + rel_fn = Rel.get_fns(rel) + type_mismatch = not isinstance(arg_value, int) or isinstance(arg_value, bool) + type_except = TypeError if type_mismatch else ValueError + if type_mismatch or not rel_fn(arg_value, value): + rel_str = Rel.get_strs(rel).format(value) + arg_name = arg_name if arg_name else "parameter" + msg_prefix = f'For \'{prim_name}\' the' if prim_name else "The" + raise type_except(f'{msg_prefix} `{arg_name}` should be an int and must {rel_str}, but got `{arg_value}`' + f' with type `{type(arg_value).__name__}`.') + return arg_value + + class Validator: """validator for checking input parameters""" @@ -121,6 +140,49 @@ class Validator: f' with type `{type(arg_value).__name__}`.') return arg_value + @staticmethod + def check_positive_int(arg_value, arg_name=None, prim_name=None): + """ + Check argument is positive integer, which mean arg_value > 0. + + Usage: + - number = check_positive_int(number) + - number = check_positive_int(number, "bias") + """ + return _check_integer(arg_value, 0, Rel.GT, arg_name, prim_name) + + @staticmethod + def check_negative_int(arg_value, arg_name=None, prim_name=None): + """ + Check argument is negative integer, which mean arg_value < 0. + + Usage: + - number = check_negative_int(number) + - number = check_negative_int(number, "bias") + """ + return _check_integer(arg_value, 0, Rel.LT, arg_name, prim_name) + + @staticmethod + def check_non_positive_int(arg_value, arg_name=None, prim_name=None): + """ + Check argument is non-negative integer, which mean arg_value <= 0. + + Usage: + - number = check_non_positive_int(number) + - number = check_non_positive_int(number, "bias") + """ + return _check_integer(arg_value, 0, Rel.LE, arg_name, prim_name) + + @staticmethod + def check_non_negative_int(arg_value, arg_name=None, prim_name=None): + """ + Check argument is non-negative integer, which mean arg_value >= 0. + + Usage: + - number = check_non_negative_int(number) + - number = check_non_negative_int(number, "bias") + """ + return _check_integer(arg_value, 0, Rel.GE, arg_name, prim_name) @staticmethod def check_number(arg_name, arg_value, value, rel, prim_name): @@ -140,7 +202,13 @@ class Validator: @staticmethod def check_bool(arg_value, arg_name=None): - """Check argument is instance of bool""" + """ + Check argument is instance of bool. + + Usage: + - has_bias = check_bool(has_bias) + - has_bias = check_bool(has_bias, "has_bias") + """ if not isinstance(arg_value, bool): arg_name = arg_name if arg_name else "Parameter" raise TypeError(f'`{arg_name}` should be isinstance of bool, but got `{arg_value}`.') @@ -169,7 +237,12 @@ class Validator: @staticmethod def check_string(arg_value, valid_values, arg_name=None, prim_name=None): - """Checks whether a string is in some value list""" + """ + Check whether string is in some value list. + + Usage: + - method = check_string(method, ["string1", "string2", "string3"], "method") + """ if isinstance(arg_value, str) and arg_value in valid_values: return arg_value arg_name = arg_name if arg_name else "Parameter" @@ -372,28 +445,6 @@ def check_int(input_param): raise TypeError("Input type must be int!") -def check_int_positive(input_param): - """Int type judgment.""" - if isinstance(input_param, bool): - raise TypeError("Input type must be int cannot be bool!") - if isinstance(input_param, int): - if input_param > 0: - return input_param - raise ValueError("The input_param must be positive, but got input_param {}.".format(input_param)) - raise TypeError("Input type must be int cannot be {}!".format(type(input_param))) - - -def check_int_non_negative(input_param): - """Non_negative type judgment.""" - if isinstance(input_param, bool): - raise TypeError("Input type must be int cannot be bool!") - if isinstance(input_param, int): - if input_param >= 0: - return input_param - raise ValueError("The input_param must be non_negative, but got input_param {}.".format(input_param)) - raise TypeError("Input type must be int cannot be {}!".format(type(input_param))) - - def check_int_zero_one(input_param): """Judge whether it is 0 or 1.""" if input_param in (0, 1): diff --git a/mindspore/nn/dynamic_lr.py b/mindspore/nn/dynamic_lr.py index 98a72c444ce..3d36305e4ce 100644 --- a/mindspore/nn/dynamic_lr.py +++ b/mindspore/nn/dynamic_lr.py @@ -52,7 +52,7 @@ def piecewise_constant_lr(milestone, learning_rates): lr = [] last_item = 0 for i, item in enumerate(milestone): - validator.check_integer(f'milestone[{i}]', item, 0, Rel.GT, None) + validator.check_positive_int(item, f'milestone[{i}]') validator.check_float_legal_value(f'learning_rates[{i}]', learning_rates[i], None) if item < last_item: raise ValueError(f'The value of milestone[{i}] must be greater than milestone[{i - 1}]') @@ -63,9 +63,9 @@ def piecewise_constant_lr(milestone, learning_rates): def _check_inputs(learning_rate, decay_rate, total_step, step_per_epoch, decay_epoch, is_stair): - validator.check_integer('total_step', total_step, 0, Rel.GT, None) - validator.check_integer('step_per_epoch', step_per_epoch, 0, Rel.GT, None) - validator.check_integer('decay_epoch', decay_epoch, 0, Rel.GT, None) + validator.check_positive_int(total_step, 'total_step') + validator.check_positive_int(step_per_epoch, 'step_per_epoch') + validator.check_positive_int(decay_epoch, 'decay_epoch') validator.check_float_positive('learning_rate', learning_rate, None) validator.check_float_legal_value('learning_rate', learning_rate, None) validator.check_float_positive('decay_rate', decay_rate, None) @@ -236,9 +236,9 @@ def cosine_decay_lr(min_lr, max_lr, total_step, step_per_epoch, decay_epoch): validator.check_number_range("min_lr", min_lr, 0.0, float("inf"), Rel.INC_LEFT, None) validator.check_float_positive('max_lr', max_lr, None) validator.check_float_legal_value('max_lr', max_lr, None) - validator.check_integer('total_step', total_step, 0, Rel.GT, None) - validator.check_integer('step_per_epoch', step_per_epoch, 0, Rel.GT, None) - validator.check_integer('decay_epoch', decay_epoch, 0, Rel.GT, None) + validator.check_positive_int(total_step, 'total_step') + validator.check_positive_int(step_per_epoch, 'step_per_epoch') + validator.check_positive_int(decay_epoch, 'decay_epoch') if min_lr >= max_lr: raise ValueError('`max_lr` should be greater than `min_lr`.') @@ -306,9 +306,9 @@ def polynomial_decay_lr(learning_rate, end_learning_rate, total_step, step_per_e validator.check_number_range("end_learning_rate", end_learning_rate, 0.0, float("inf"), Rel.INC_LEFT, None) validator.check_float_positive('power', power, None) validator.check_float_legal_value('power', power, None) - validator.check_integer('total_step', total_step, 0, Rel.GT, None) - validator.check_integer('step_per_epoch', step_per_epoch, 0, Rel.GT, None) - validator.check_integer('decay_epoch', decay_epoch, 0, Rel.GT, None) + validator.check_positive_int(total_step, 'total_step') + validator.check_positive_int(step_per_epoch, 'step_per_epoch') + validator.check_positive_int(decay_epoch, 'decay_epoch') validator.check_value_type('update_decay_epoch', update_decay_epoch, [bool], None) origin_decay_epoch = decay_epoch @@ -357,9 +357,9 @@ def warmup_lr(learning_rate, total_step, step_per_epoch, warmup_epoch): if not isinstance(learning_rate, float): raise TypeError("learning_rate must be float.") validator.check_number_range("learning_rate", learning_rate, 0.0, float("inf"), Rel.INC_LEFT, None) - validator.check_integer('warmup_epoch', warmup_epoch, 0, Rel.GT, None) - validator.check_integer('total_step', total_step, 0, Rel.GT, None) - validator.check_integer('step_per_epoch', step_per_epoch, 0, Rel.GT, None) + validator.check_positive_int(warmup_epoch, 'warmup_epoch') + validator.check_positive_int(total_step, 'total_step') + validator.check_positive_int(step_per_epoch, 'step_per_epoch') function = lambda x, y: (x, min(x, y)) diff --git a/mindspore/nn/layer/basic.py b/mindspore/nn/layer/basic.py index 4acca2e9c26..8cee69aab77 100644 --- a/mindspore/nn/layer/basic.py +++ b/mindspore/nn/layer/basic.py @@ -27,7 +27,7 @@ from mindspore.ops.operations import _inner_ops as inner from mindspore.ops.primitive import constexpr from mindspore.common.parameter import Parameter from mindspore._extends import cell_attr_register -from mindspore._checkparam import Rel, Validator, check_int_positive +from mindspore._checkparam import Rel, Validator from mindspore.common.api import ms_function from mindspore import context from ..cell import Cell @@ -203,8 +203,8 @@ class Dense(Cell): has_bias=True, activation=None): super(Dense, self).__init__() - self.in_channels = check_int_positive(in_channels) - self.out_channels = check_int_positive(out_channels) + self.in_channels = Validator.check_positive_int(in_channels) + self.out_channels = Validator.check_positive_int(out_channels) self.has_bias = Validator.check_bool(has_bias) if isinstance(weight_init, Tensor): diff --git a/mindspore/nn/layer/conv.py b/mindspore/nn/layer/conv.py index 20fe0582a81..c5ba7999190 100644 --- a/mindspore/nn/layer/conv.py +++ b/mindspore/nn/layer/conv.py @@ -21,7 +21,7 @@ from mindspore.ops.primitive import constexpr from mindspore.common.parameter import Parameter from mindspore.common.initializer import initializer, Initializer from mindspore.common.tensor import Tensor -from mindspore._checkparam import Validator, Rel, twice, check_int_positive +from mindspore._checkparam import Validator, Rel, twice from mindspore._extends import cell_attr_register from ..cell import Cell @@ -47,8 +47,8 @@ class _Conv(Cell): bias_init, transposed=False): super(_Conv, self).__init__() - self.in_channels = check_int_positive(in_channels) - self.out_channels = check_int_positive(out_channels) + self.in_channels = Validator.check_positive_int(in_channels) + self.out_channels = Validator.check_positive_int(out_channels) self.kernel_size = kernel_size self.stride = stride self.pad_mode = pad_mode @@ -65,7 +65,7 @@ class _Conv(Cell): raise TypeError("padding type must be int/tuple(int) cannot be {}!".format(type(padding))) self.dilation = dilation - self.group = check_int_positive(group) + self.group = Validator.check_positive_int(group) self.has_bias = has_bias if (not isinstance(kernel_size[0], int)) or (not isinstance(kernel_size[1], int)) or \ isinstance(kernel_size[0], bool) or isinstance(kernel_size[1], bool) or \ diff --git a/mindspore/nn/layer/embedding.py b/mindspore/nn/layer/embedding.py index d5b7346071c..479ed4dea80 100755 --- a/mindspore/nn/layer/embedding.py +++ b/mindspore/nn/layer/embedding.py @@ -21,7 +21,7 @@ from mindspore.common.initializer import initializer from mindspore.communication.management import get_group_size from mindspore.context import ParallelMode from mindspore.parallel._utils import _get_parallel_mode -from mindspore._checkparam import Rel, Validator as validator +from mindspore._checkparam import Validator as validator from ..cell import Cell __all__ = ['Embedding', 'EmbeddingLookup'] @@ -170,7 +170,7 @@ class EmbeddingLookup(Cell): if not isinstance(manual_shapes, tuple): raise TypeError("manual_shapes type must be tuple(int) cannot be {}!".format(type(manual_shapes))) for dim in manual_shapes: - validator.check_integer('manul shape dim', dim, 0, Rel.GT, self.cls_name) + validator.check_positive_int(dim, 'manual shape dim', self.cls_name) self.gatherv2.add_prim_attr("manual_split", manual_shapes) self.embeddinglookup.add_prim_attr("manual_split", manual_shapes) self.gatherv2.shard(((get_group_size(), 1), (1, get_group_size()))) diff --git a/mindspore/nn/layer/lstm.py b/mindspore/nn/layer/lstm.py index 9b42f55cb44..a3ae9ca67d7 100755 --- a/mindspore/nn/layer/lstm.py +++ b/mindspore/nn/layer/lstm.py @@ -15,7 +15,7 @@ """lstm""" import math import numpy as np -from mindspore._checkparam import Rel, Validator as validator +from mindspore._checkparam import Validator as validator from mindspore.common.initializer import initializer from mindspore.common.parameter import Parameter from mindspore.common.tensor import Tensor @@ -103,8 +103,8 @@ class LSTM(Cell): bidirectional=False): super(LSTM, self).__init__() validator.check_value_type("batch_first", batch_first, [bool], self.cls_name) - validator.check_integer("hidden_size", hidden_size, 0, Rel.GT, self.cls_name) - validator.check_integer("num_layers", num_layers, 0, Rel.GT, self.cls_name) + validator.check_positive_int(hidden_size, "hidden_size", self.cls_name) + validator.check_positive_int(num_layers, "num_layers", self.cls_name) self.batch_first = batch_first self.transpose = P.Transpose() diff --git a/mindspore/nn/layer/math.py b/mindspore/nn/layer/math.py index bfd32f43a8b..6929727156f 100644 --- a/mindspore/nn/layer/math.py +++ b/mindspore/nn/layer/math.py @@ -21,7 +21,7 @@ from mindspore.common.tensor import Tensor from mindspore.ops.primitive import constexpr from ..cell import Cell from ...common import dtype as mstype -from ..._checkparam import Rel, Validator as validator +from ..._checkparam import Validator as validator __all__ = ['ReduceLogSumExp', 'Range', 'LinSpace', 'LGamma', 'MatMul'] @@ -156,7 +156,7 @@ class LinSpace(Cell): validator.check_value_type("start", start, [int, float], self.cls_name) validator.check_value_type("stop", stop, [int, float], self.cls_name) validator.check_value_type("num", num, [int], self.cls_name) - validator.check_integer("num", num, 0, Rel.GT, self.cls_name) + validator.check_positive_int(num, "num", self.cls_name) self.is_single = bool(num == 1) self.lin_space = inner.LinSpace() diff --git a/mindspore/nn/layer/normalization.py b/mindspore/nn/layer/normalization.py index 3145514e934..f340c959ecb 100644 --- a/mindspore/nn/layer/normalization.py +++ b/mindspore/nn/layer/normalization.py @@ -19,7 +19,7 @@ from mindspore.common.parameter import Parameter from mindspore.common.initializer import initializer from mindspore.ops.primitive import constexpr import mindspore.context as context -from mindspore._checkparam import Validator, check_typename, check_int_positive +from mindspore._checkparam import Validator, check_typename from mindspore._extends import cell_attr_register from mindspore.communication.management import get_group_size, get_rank from mindspore.communication import management @@ -64,7 +64,7 @@ class _BatchNorm(Cell): gamma_init, num_features), name="gamma", requires_grad=affine) self.beta = Parameter(initializer( beta_init, num_features), name="beta", requires_grad=affine) - self.group = check_int_positive(device_num_each_group) + self.group = Validator.check_positive_int(device_num_each_group) self.is_global = False if self.group != 1: self.rank_id = get_rank() @@ -464,7 +464,7 @@ class GlobalBatchNorm(_BatchNorm): use_batch_statistics, device_num_each_group, input_dims='both') - self.group = check_int_positive(device_num_each_group) + self.group = Validator.check_positive_int(device_num_each_group) if self.group <= 1: raise ValueError("the number of group must be greater than 1.") @@ -599,8 +599,8 @@ class GroupNorm(Cell): def __init__(self, num_groups, num_channels, eps=1e-05, affine=True, gamma_init='ones', beta_init='zeros'): super(GroupNorm, self).__init__() - self.num_groups = check_int_positive(num_groups) - self.num_channels = check_int_positive(num_channels) + self.num_groups = Validator.check_positive_int(num_groups) + self.num_channels = Validator.check_positive_int(num_channels) if num_channels % num_groups != 0: raise ValueError("num_channels should be divided by num_groups") self.eps = check_typename('eps', eps, (float,)) diff --git a/mindspore/nn/layer/quant.py b/mindspore/nn/layer/quant.py index 793e9612b52..16104aba2b0 100644 --- a/mindspore/nn/layer/quant.py +++ b/mindspore/nn/layer/quant.py @@ -23,7 +23,7 @@ from mindspore.ops import functional as F from mindspore.common.parameter import Parameter from mindspore.common.initializer import initializer from mindspore.common.tensor import Tensor -from mindspore._checkparam import Validator, Rel, check_int_positive, twice +from mindspore._checkparam import Validator, Rel, twice import mindspore.context as context from .normalization import BatchNorm2d, BatchNorm1d from .activation import get_activation, ReLU, LeakyReLU @@ -657,8 +657,8 @@ class Conv2dBnWithoutFoldQuant(Cell): self.kernel_size = (kernel_size, kernel_size) else: self.kernel_size = kernel_size - self.in_channels = check_int_positive(in_channels) - self.out_channels = check_int_positive(out_channels) + self.in_channels = Validator.check_positive_int(in_channels) + self.out_channels = Validator.check_positive_int(out_channels) self.has_bias = has_bias self.stride = twice(stride) self.dilation = twice(dilation) @@ -785,8 +785,8 @@ class Conv2dQuant(Cell): self.kernel_size = (kernel_size, kernel_size) else: self.kernel_size = kernel_size - self.in_channels = check_int_positive(in_channels) - self.out_channels = check_int_positive(out_channels) + self.in_channels = Validator.check_positive_int(in_channels) + self.out_channels = Validator.check_positive_int(out_channels) self.has_bias = has_bias self.stride = twice(stride) self.dilation = twice(dilation) @@ -886,8 +886,8 @@ class DenseQuant(Cell): narrow_range=False, quant_delay=0): super(DenseQuant, self).__init__() - self.in_channels = check_int_positive(in_channels) - self.out_channels = check_int_positive(out_channels) + self.in_channels = Validator.check_positive_int(in_channels) + self.out_channels = Validator.check_positive_int(out_channels) self.has_bias = Validator.check_bool(has_bias) if isinstance(weight_init, Tensor): diff --git a/mindspore/nn/learning_rate_schedule.py b/mindspore/nn/learning_rate_schedule.py index 8aea68599f1..0cddd2c6c5e 100644 --- a/mindspore/nn/learning_rate_schedule.py +++ b/mindspore/nn/learning_rate_schedule.py @@ -44,7 +44,7 @@ class LearningRateSchedule(Cell): def _check_inputs(learning_rate, decay_rate, decay_steps, is_stair, cls_name): - validator.check_integer('decay_steps', decay_steps, 0, Rel.GT, cls_name) + validator.check_positive_int(decay_steps, 'decay_steps', cls_name) validator.check_float_positive('learning_rate', learning_rate, cls_name) validator.check_float_legal_value('learning_rate', learning_rate, cls_name) validator.check_float_positive('decay_rate', decay_rate, cls_name) @@ -257,7 +257,7 @@ class CosineDecayLR(LearningRateSchedule): validator.check_number_range("min_lr", min_lr, 0.0, float("inf"), Rel.INC_LEFT, self.cls_name) validator.check_float_positive('max_lr', max_lr, self.cls_name) validator.check_float_legal_value('max_lr', max_lr, self.cls_name) - validator.check_integer('decay_steps', decay_steps, 0, Rel.GT, self.cls_name) + validator.check_positive_int(decay_steps, "decay_steps", self.cls_name) if min_lr >= max_lr: raise ValueError('`max_lr` should be greater than `min_lr`.') self.min_lr = min_lr @@ -324,7 +324,7 @@ class PolynomialDecayLR(LearningRateSchedule): raise TypeError("end_learning_rate must be float.") validator.check_number_range("end_learning_rate", end_learning_rate, 0.0, float("inf"), Rel.INC_LEFT, self.cls_name) - validator.check_integer('decay_steps', decay_steps, 0, Rel.GT, self.cls_name) + validator.check_positive_int(decay_steps, 'decay_steps', self.cls_name) validator.check_value_type('update_decay_steps', update_decay_steps, [bool], self.cls_name) validator.check_float_positive('power', power, self.cls_name) validator.check_float_legal_value('power', power, self.cls_name) @@ -388,7 +388,7 @@ class WarmUpLR(LearningRateSchedule): if not isinstance(learning_rate, float): raise TypeError("learning_rate must be float.") validator.check_number_range("learning_rate", learning_rate, 0.0, float("inf"), Rel.INC_LEFT, self.cls_name) - validator.check_integer('warmup_steps', warmup_steps, 0, Rel.GT, self.cls_name) + validator.check_positive_int(warmup_steps, 'warmup_steps', self.cls_name) self.warmup_steps = warmup_steps self.learning_rate = learning_rate self.min = P.Minimum() diff --git a/mindspore/nn/probability/bnn_layers/dense_variational.py b/mindspore/nn/probability/bnn_layers/dense_variational.py index c4bc16c602d..5dcfe5d80b4 100644 --- a/mindspore/nn/probability/bnn_layers/dense_variational.py +++ b/mindspore/nn/probability/bnn_layers/dense_variational.py @@ -15,7 +15,7 @@ """dense_variational""" from mindspore.ops import operations as P from mindspore.common.tensor import Tensor -from mindspore._checkparam import check_int_positive, Validator +from mindspore._checkparam import Validator from ...cell import Cell from ...layer.activation import get_activation from .layer_distribution import NormalPrior, NormalPosterior @@ -39,8 +39,8 @@ class _DenseVariational(Cell): bias_prior_fn=NormalPrior, bias_posterior_fn=lambda name, shape: NormalPosterior(name=name, shape=shape)): super(_DenseVariational, self).__init__() - self.in_channels = check_int_positive(in_channels) - self.out_channels = check_int_positive(out_channels) + self.in_channels = Validator.check_positive_int(in_channels) + self.out_channels = Validator.check_positive_int(out_channels) self.has_bias = Validator.check_bool(has_bias) if isinstance(weight_prior_fn, Cell): diff --git a/mindspore/nn/probability/dpn/vae/cvae.py b/mindspore/nn/probability/dpn/vae/cvae.py index 01577a46774..2028059411a 100644 --- a/mindspore/nn/probability/dpn/vae/cvae.py +++ b/mindspore/nn/probability/dpn/vae/cvae.py @@ -15,7 +15,7 @@ """Conditional Variational auto-encoder (CVAE).""" from mindspore.ops import composite as C from mindspore.ops import operations as P -from mindspore._checkparam import check_int_positive +from mindspore._checkparam import Validator from ....cell import Cell from ....layer.basic import Dense, OneHot @@ -57,11 +57,11 @@ class ConditionalVAE(Cell): self.decoder = decoder if (not isinstance(encoder, Cell)) or (not isinstance(decoder, Cell)): raise TypeError('The encoder and decoder should be Cell type.') - self.hidden_size = check_int_positive(hidden_size) - self.latent_size = check_int_positive(latent_size) + self.hidden_size = Validator.check_positive_int(hidden_size) + self.latent_size = Validator.check_positive_int(latent_size) if hidden_size < latent_size: raise ValueError('The latent_size should be less than or equal to the hidden_size.') - self.num_classes = check_int_positive(num_classes) + self.num_classes = Validator.check_positive_int(num_classes) self.normal = C.normal self.exp = P.Exp() self.reshape = P.Reshape() @@ -108,7 +108,7 @@ class ConditionalVAE(Cell): Returns: Tensor, the generated samples. """ - generate_nums = check_int_positive(generate_nums) + generate_nums = Validator.check_positive_int(generate_nums) if not isinstance(shape, tuple) or len(shape) != 4 or (shape[0] != -1 and shape[0] != generate_nums): raise ValueError('The shape should be (generate_nums, C, H, W) or (-1, C, H, W).') sample_z = self.normal((generate_nums, self.latent_size), self.to_tensor(0.0), self.to_tensor(1.0), seed=0) diff --git a/mindspore/nn/probability/dpn/vae/vae.py b/mindspore/nn/probability/dpn/vae/vae.py index c0aff3b567f..e743cab28d7 100644 --- a/mindspore/nn/probability/dpn/vae/vae.py +++ b/mindspore/nn/probability/dpn/vae/vae.py @@ -15,7 +15,7 @@ """Variational auto-encoder (VAE)""" from mindspore.ops import composite as C from mindspore.ops import operations as P -from mindspore._checkparam import check_int_positive +from mindspore._checkparam import Validator from ....cell import Cell from ....layer.basic import Dense @@ -52,8 +52,8 @@ class VAE(Cell): self.decoder = decoder if (not isinstance(encoder, Cell)) or (not isinstance(decoder, Cell)): raise TypeError('The encoder and decoder should be Cell type.') - self.hidden_size = check_int_positive(hidden_size) - self.latent_size = check_int_positive(latent_size) + self.hidden_size = Validator.check_positive_int(hidden_size) + self.latent_size = Validator.check_positive_int(latent_size) if hidden_size < latent_size: raise ValueError('The latent_size should be less than or equal to the hidden_size.') self.normal = C.normal @@ -94,7 +94,7 @@ class VAE(Cell): Returns: Tensor, the generated samples. """ - generate_nums = check_int_positive(generate_nums) + generate_nums = Validator.check_positive_int(generate_nums) if not isinstance(shape, tuple) or len(shape) != 4 or (shape[0] != -1 and shape[0] != generate_nums): raise ValueError('The shape should be (generate_nums, C, H, W) or (-1, C, H, W).') sample_z = self.normal((generate_nums, self.latent_size), self.to_tensor(0.0), self.to_tensor(1.0), seed=0) diff --git a/mindspore/nn/probability/infer/variational/svi.py b/mindspore/nn/probability/infer/variational/svi.py index 023d56188e6..f40ade88cbe 100644 --- a/mindspore/nn/probability/infer/variational/svi.py +++ b/mindspore/nn/probability/infer/variational/svi.py @@ -15,7 +15,7 @@ """Stochastic Variational Inference(SVI).""" import mindspore.common.dtype as mstype from mindspore.common.tensor import Tensor -from mindspore._checkparam import check_int_positive +from mindspore._checkparam import Validator from ....cell import Cell from ....wrap.cell_wrapper import TrainOneStepCell from .elbo import ELBO @@ -57,7 +57,7 @@ class SVI: Outputs: Cell, the trained probability network. """ - epochs = check_int_positive(epochs) + epochs = Validator.check_positive_int(epochs) train_net = TrainOneStepCell(self.net_with_loss, self.optimizer) train_net.set_train() for _ in range(1, epochs+1): diff --git a/mindspore/nn/probability/toolbox/uncertainty_evaluation.py b/mindspore/nn/probability/toolbox/uncertainty_evaluation.py index f2de2cb4bdb..96ada060899 100644 --- a/mindspore/nn/probability/toolbox/uncertainty_evaluation.py +++ b/mindspore/nn/probability/toolbox/uncertainty_evaluation.py @@ -16,7 +16,7 @@ from copy import deepcopy import numpy as np -from mindspore._checkparam import check_int_positive, Validator +from mindspore._checkparam import Validator from mindspore.ops import composite as C from mindspore.ops import operations as P from mindspore.train import Model @@ -81,7 +81,7 @@ class UncertaintyEvaluation: self.epi_train_dataset = train_dataset self.ale_train_dataset = deepcopy(train_dataset) self.task_type = task_type - self.epochs = check_int_positive(epochs) + self.epochs = Validator.check_positive_int(epochs) self.epi_uncer_model_path = epi_uncer_model_path self.ale_uncer_model_path = ale_uncer_model_path self.save_model = Validator.check_bool(save_model) @@ -95,7 +95,7 @@ class UncertaintyEvaluation: if task_type not in ('regression', 'classification'): raise ValueError('The task should be regression or classification.') if task_type == 'classification': - self.num_classes = check_int_positive(num_classes) + self.num_classes = Validator.check_positive_int(num_classes) else: self.num_classes = num_classes if save_model: diff --git a/mindspore/ops/_utils/utils.py b/mindspore/ops/_utils/utils.py index 9ee599e6efa..25ba930ef96 100644 --- a/mindspore/ops/_utils/utils.py +++ b/mindspore/ops/_utils/utils.py @@ -65,9 +65,9 @@ def get_broadcast_shape(x_shape, y_shape, prim_name): def get_concat_offset(x_shp, x_type, axis, prim_name): """for concat and concatoffset check args and compute offset""" validator.check_value_type("shape", x_shp, [tuple], prim_name) - validator.check_integer("input_x rank", len(x_shp), 0, Rel.GT, prim_name) + validator.check_positive_int(len(x_shp), "input_x rank", prim_name) validator.check_subclass("shape0", x_type[0], mstype.tensor, prim_name) - validator.check_integer("len of x_shp[0]", len(x_shp[0]), 0, Rel.GT, prim_name) + validator.check_positive_int(len(x_shp[0]), "len of x_shp[0]", prim_name) rank_base = len(x_shp[0]) validator.check_int_range('axis', axis, -rank_base - 1, rank_base, Rel.INC_BOTH, prim_name) if axis < 0: diff --git a/mindspore/ops/composite/multitype_ops/_constexpr_utils.py b/mindspore/ops/composite/multitype_ops/_constexpr_utils.py index 8c34fc8cd7e..9489991abf7 100644 --- a/mindspore/ops/composite/multitype_ops/_constexpr_utils.py +++ b/mindspore/ops/composite/multitype_ops/_constexpr_utils.py @@ -12,8 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================ - """constexpr util""" + from functools import reduce import numpy as np @@ -60,30 +60,6 @@ def check_equal(param1, param2, msg="{},{}"): return param1 -@constexpr -def check_int_positive(arg_name, arg_value, op_name): - """Int type judgment.""" - if isinstance(arg_value, bool): - raise TypeError("For \'{}\' the `{}` must be int, cannot be bool.".format(op_name, arg_name)) - if isinstance(arg_value, int): - if arg_value > 0: - return arg_value - raise ValueError("For \'{}\' the `{}` must be positive, but got {}.".format(op_name, arg_name, arg_value)) - raise TypeError("For \'{}\' the `{}` must be int, cannot be {}.".format(op_name, arg_name, type(arg_value))) - - -@constexpr -def check_int_non_negative(arg_name, arg_value, op_name): - """Int type judgment.""" - if isinstance(arg_value, bool): - raise TypeError("For \'{}\' the `{}` must be int, cannot be bool.".format(op_name, arg_name)) - if isinstance(arg_value, int): - if arg_value >= 0: - return arg_value - raise ValueError("For \'{}\' the `{}` must be non_negative, but got {}.".format(op_name, arg_name, arg_value)) - raise TypeError("For \'{}\' the `{}` must be int, cannot be {}.".format(op_name, arg_name, type(arg_value))) - - @constexpr def check_ellipsis_shape_size(data_shape, value_shape, data_size, value_size): """Checks the shape and size of the sensor and value.""" diff --git a/mindspore/ops/composite/random_ops.py b/mindspore/ops/composite/random_ops.py index 4871e52806f..e05071823a0 100644 --- a/mindspore/ops/composite/random_ops.py +++ b/mindspore/ops/composite/random_ops.py @@ -12,9 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================ - """Operations for random number generators.""" +from mindspore._checkparam import Validator from .. import operations as P from .. import functional as F from ..primitive import constexpr @@ -54,7 +54,7 @@ def get_seed(op_seed, kernel_name): if op_seed is None: temp_seed = _get_op_seed(0, kernel_name) else: - const_utils.check_int_non_negative("seed", op_seed, kernel_name) + Validator.check_non_negative_int(op_seed, "seed", kernel_name) temp_seed = _get_op_seed(op_seed, kernel_name) seeds = _truncate_seed(global_seed), _truncate_seed(temp_seed) _update_seeds(op_seed, kernel_name) diff --git a/mindspore/ops/operations/_grad_ops.py b/mindspore/ops/operations/_grad_ops.py index 2e9b21d3e20..3e15f4bcd58 100644 --- a/mindspore/ops/operations/_grad_ops.py +++ b/mindspore/ops/operations/_grad_ops.py @@ -915,9 +915,9 @@ class LSTMGradData(PrimitiveWithInfer): @prim_attr_register def __init__(self, input_size, hidden_size, num_layers, has_bias, bidirectional, dropout): - self.input_size = validator.check_integer('input_size', input_size, 0, Rel.GT, self.name) - self.hidden_size = validator.check_integer('hidden_size', hidden_size, 0, Rel.GT, self.name) - self.num_layers = validator.check_integer('num_layers', num_layers, 0, Rel.GT, self.name) + self.input_size = validator.check_positive_int(input_size, 'input_size', self.name) + self.hidden_size = validator.check_positive_int(hidden_size, 'hidden_size', self.name) + self.num_layers = validator.check_positive_int(num_layers, 'num_layers', self.name) self.has_bias = validator.check_value_type('has_bias', has_bias, (bool,), self.name) self.bidirectional = validator.check_value_type('bidirectional', bidirectional, (bool,), self.name) self.dropout = validator.check_value_type("dropout", dropout, [float], self.name) @@ -964,9 +964,9 @@ class LSTMGradWeight(PrimitiveWithInfer): @prim_attr_register def __init__(self, input_size, hidden_size, num_layers, has_bias, bidirectional, dropout): - self.input_size = validator.check_integer('input_size', input_size, 0, Rel.GT, self.name) - self.hidden_size = validator.check_integer('hidden_size', hidden_size, 0, Rel.GT, self.name) - self.num_layers = validator.check_integer('num_layers', num_layers, 0, Rel.GT, self.name) + self.input_size = validator.check_positive_int(input_size, 'input_size', self.name) + self.hidden_size = validator.check_positive_int(hidden_size, 'hidden_size', self.name) + self.num_layers = validator.check_positive_int(num_layers, 'num_layers', self.name) self.has_bias = validator.check_value_type('has_bias', has_bias, (bool,), self.name) self.bidirectional = validator.check_value_type('bidirectional', bidirectional, (bool,), self.name) self.dropout = validator.check_value_type("dropout", dropout, [float], self.name) @@ -999,9 +999,9 @@ class LSTMGrad(PrimitiveWithInfer): @prim_attr_register def __init__(self, input_size, hidden_size, num_layers, has_bias, bidirectional, dropout): - self.input_size = validator.check_integer('input_size', input_size, 0, Rel.GT, self.name) - self.hidden_size = validator.check_integer('hidden_size', hidden_size, 0, Rel.GT, self.name) - self.num_layers = validator.check_integer('num_layers', num_layers, 0, Rel.GT, self.name) + self.input_size = validator.check_positive_int(input_size, 'input_size', self.name) + self.hidden_size = validator.check_positive_int(hidden_size, 'hidden_size', self.name) + self.num_layers = validator.check_positive_int(num_layers, 'num_layers', self.name) self.has_bias = validator.check_value_type('has_bias', has_bias, (bool,), self.name) self.bidirectional = validator.check_value_type('bidirectional', bidirectional, (bool,), self.name) self.dropout = validator.check_value_type("dropout", dropout, [float], self.name) diff --git a/mindspore/ops/operations/array_ops.py b/mindspore/ops/operations/array_ops.py index 338ae9d44f2..7fe34490080 100644 --- a/mindspore/ops/operations/array_ops.py +++ b/mindspore/ops/operations/array_ops.py @@ -701,7 +701,7 @@ class Padding(PrimitiveWithInfer): def __init__(self, pad_dim_size=8): """Initialize padding""" validator.check_value_type("pad_dim_size", pad_dim_size, [int], self.name) - validator.check_integer("pad_dim_size", pad_dim_size, 0, Rel.GT, self.name) + validator.check_positive_int(pad_dim_size, "pad_dim_size", self.name) self.pad_dim_size = pad_dim_size def __infer__(self, x): @@ -911,8 +911,8 @@ class Fill(PrimitiveWithInfer): def __infer__(self, dtype, dims, x): validator.check_value_type("shape", dims['value'], [tuple], self.name) validator.check_value_type("value", x['value'], [numbers.Number, bool], self.name) - for idx, item in enumerate(dims['value']): - validator.check_integer("dims[%d]" % idx, item, 0, Rel.GT, self.name) + for i, item in enumerate(dims['value']): + validator.check_positive_int(item, f'dims[{i}]', self.name) valid_types = [mstype.bool_, mstype.int8, mstype.int16, mstype.int32, mstype.int64, mstype.uint8, mstype.uint32, mstype.uint64, mstype.float16, mstype.float32, mstype.float64] @@ -1482,20 +1482,20 @@ class UnsortedSegmentSum(PrimitiveWithInfer): validator.check_subclass("input_x", x_type, mstype.tensor, self.name) validator.check_value_type("x_shape", x_shp, [list], self.name) x_shp_len = len(x_shp) - validator.check_integer("rank of input_x", x_shp_len, 0, Rel.GT, self.name) + validator.check_positive_int(x_shp_len, "rank of input_x", self.name) segment_ids_shp = segment_ids['shape'] segment_ids_type = segment_ids['dtype'] validator.check_subclass("segment_ids", segment_ids_type, mstype.tensor, self.name) validator.check_value_type("segment_ids", segment_ids_shp, [list], self.name) segment_ids_shp_len = len(segment_ids_shp) - validator.check_integer("rank of segment_ids", segment_ids_shp_len, 0, Rel.GT, self.name) + validator.check_positive_int(segment_ids_shp_len, "rank of segment_ids", self.name) validator.check(f'rank of input_x', len(x_shp), 'rank of segments_id', len(segment_ids_shp), Rel.GE, self.name) for i, value in enumerate(segment_ids_shp): validator.check("ids[%d]" % i, value, 'input[%d]' % i, x_shp[i], Rel.EQ, self.name) num_segments_v = num_segments['value'] validator.check_value_type('num_segments', num_segments_v, [int], self.name) - validator.check_integer("num_segments", num_segments_v, 0, Rel.GT, self.name) + validator.check_positive_int(num_segments_v, "num_segments", self.name) shp = [num_segments_v] shp += x_shp[segment_ids_shp_len:] out = {'shape': shp, @@ -1544,7 +1544,7 @@ class UnsortedSegmentMin(PrimitiveWithInfer): 'length of segments_id', segment_ids_shape[0], Rel.EQ, self.name) num_segments_v = num_segments['value'] validator.check_value_type('num_segments', num_segments_v, [int], self.name) - validator.check_integer("num_segments", num_segments_v, 0, Rel.GT, self.name) + validator.check_positive_int(num_segments_v, "num_segments", self.name) segment_ids_shape_len = len(segment_ids_shape) out_shape = [num_segments_v] out_shape += x_shape[segment_ids_shape_len:] @@ -1597,7 +1597,7 @@ class UnsortedSegmentProd(PrimitiveWithInfer): 'length of segments_id', segment_ids_shape[0], Rel.EQ, self.name) num_segments_v = num_segments['value'] validator.check_value_type('num_segments', num_segments_v, [int], self.name) - validator.check_integer("num_segments", num_segments_v, 0, Rel.GT, self.name) + validator.check_positive_int(num_segments_v, "num_segments", self.name) segment_ids_shape_len = len(segment_ids_shape) out_shape = [num_segments_v] out_shape += x_shape[segment_ids_shape_len:] @@ -1832,7 +1832,7 @@ class Unpack(PrimitiveWithInfer): self.axis = self.axis + dim output_num = x_shape[self.axis] validator.check_value_type("num", output_num, [int], self.name) - validator.check_integer("output_num", output_num, 0, Rel.GT, self.name) + validator.check_positive_int(output_num, "output_num", self.name) self.add_prim_attr('num', output_num) output_valid_check = x_shape[self.axis] - output_num validator.check_integer("The dimension which to unpack divides output_num", output_valid_check, 0, Rel.EQ, @@ -2401,8 +2401,8 @@ class Eye(PrimitiveWithInfer): """Initialize Eye""" def infer_value(self, n, m, t): - validator.check_integer("n", n, 0, Rel.GT, self.name) - validator.check_integer("m", m, 0, Rel.GT, self.name) + validator.check_positive_int(n, "n", self.name) + validator.check_positive_int(m, "m", self.name) args = {"dtype": t} validator.check_type_same(args, mstype.number_type + (mstype.bool_,), self.name) np_type = mstype.dtype_to_nptype(t) @@ -2443,7 +2443,7 @@ class ScatterNd(PrimitiveWithInfer): validator.check_tensor_type_same({"indices": indices['dtype']}, [mstype.int32], self.name) validator.check_value_type("shape", shp, [tuple], self.name) for i, x in enumerate(shp): - validator.check_integer("shape[%d]" % i, x, 0, Rel.GT, self.name) + validator.check_positive_int(x, f'shape[{i}]', self.name) indices_shape, update_shape = indices["shape"], update["shape"] if indices_shape[0] != update_shape[0]: @@ -3469,7 +3469,7 @@ class BroadcastTo(PrimitiveWithInfer): validator.check_value_type("shape", shape, (tuple), self.name) validator.check("shape length", len(shape), "", 0, Rel.GT, self.name) for i in shape: - validator.check_integer("shape element", i, 0, Rel.GT, self.name) + validator.check_positive_int(i, "shape element", self.name) self.shape = shape def infer_shape(self, x_shape): diff --git a/mindspore/ops/operations/comm_ops.py b/mindspore/ops/operations/comm_ops.py index 42cdee7e265..7148a1b9bac 100644 --- a/mindspore/ops/operations/comm_ops.py +++ b/mindspore/ops/operations/comm_ops.py @@ -160,7 +160,7 @@ class AllGather(PrimitiveWithInfer): self.add_prim_attr('group', _get_group(group)) def infer_shape(self, x_shape): - validator.check_integer("x shape", len(x_shape), 0, Rel.GT, self.name) + validator.check_positive_int(len(x_shape), "x shape", self.name) x_shape[0] = x_shape[0] * self.rank_size return x_shape @@ -210,7 +210,7 @@ class _HostAllGather(PrimitiveWithInfer): self.add_prim_attr('group', group) def infer_shape(self, x_shape): - validator.check_integer("x shape", len(x_shape), 0, Rel.GT, self.name) + validator.check_positive_int(len(x_shape), "x shape", self.name) x_shape[0] = x_shape[0] * self.group_size return x_shape diff --git a/mindspore/ops/operations/nn_ops.py b/mindspore/ops/operations/nn_ops.py index e5b06ed191f..934f72d2092 100644 --- a/mindspore/ops/operations/nn_ops.py +++ b/mindspore/ops/operations/nn_ops.py @@ -1005,8 +1005,8 @@ class Conv2D(PrimitiveWithInfer): self.mode = validator.check_integer('mode', mode, 1, Rel.EQ, self.name) self.add_prim_attr('data_format', "NCHW") - self.out_channel = validator.check_integer('out_channel', out_channel, 0, Rel.GT, self.name) - self.group = validator.check_integer('group', group, 0, Rel.GT, self.name) + self.out_channel = validator.check_positive_int(out_channel, 'out_channel', self.name) + self.group = validator.check_positive_int(group, 'group', self.name) self.add_prim_attr('offset_a', 0) def infer_shape(self, x_shape, w_shape, b_shape=None): @@ -1142,9 +1142,8 @@ class DepthwiseConv2dNative(PrimitiveWithInfer): validator.check_integer('pad item', item, 0, Rel.GE, self.name) self.mode = validator.check_integer("mode", mode, 3, Rel.EQ, self.name) self.add_prim_attr('data_format', "NCHW") - self.channel_multiplier = validator.check_integer("channel_multiplier", channel_multiplier, 0, Rel.GT, - self.name) - self.group = validator.check_integer("group", group, 0, Rel.GT, self.name) + self.channel_multiplier = validator.check_positive_int(channel_multiplier, "channel_multiplier", self.name) + self.group = validator.check_positive_int(group, "group", self.name) self.add_prim_attr('offset_a', 0) def infer_shape(self, x_shape, w_shape, b_shape=None): @@ -1508,7 +1507,7 @@ class Conv2DBackpropInput(PrimitiveWithInfer): group=1): """Initialize Conv2DBackpropInput""" self.init_prim_io_names(inputs=['out_backprop', 'filter', 'input_sizes'], outputs=['output']) - self.out_channel = validator.check_integer('out_channel', out_channel, 0, Rel.GT, self.name) + self.out_channel = validator.check_positive_int(out_channel, 'out_channel', self.name) self.kernel_size = _check_positive_int_or_tuple('kernel_size', kernel_size, self.name) self.stride = _check_positive_int_or_tuple('stride', stride, self.name, allow_four=True, ret_four=False) self.add_prim_attr('stride', self.stride) @@ -1531,7 +1530,7 @@ class Conv2DBackpropInput(PrimitiveWithInfer): pad_mode = pad_mode.upper() self.add_prim_attr('pad_mode', pad_mode) self.mode = validator.check_integer('mode', mode, 1, Rel.EQ, self.name) - self.group = validator.check_integer('group', group, 0, Rel.GT, self.name) + self.group = validator.check_positive_int(group, 'group', self.name) self.add_prim_attr('data_format', "NCHW") if pad_list: for x in pad_list: @@ -2062,10 +2061,10 @@ class SGD(PrimitiveWithInfer): def infer_shape(self, parameters_shape, gradient_shape, learning_rate_shape, accum_shape, momentum_shape, stat_shape): - validator.check_integer(f'parameters rank', len(parameters_shape), 0, Rel.GT, self.name) + validator.check_positive_int(len(parameters_shape), "parameters rank", self.name) validator.check_integer(f'gradient rank', len(gradient_shape), 0, Rel.GE, self.name) validator.check_integer(f'learning rate rank', len(learning_rate_shape), 0, Rel.GE, self.name) - validator.check_integer(f'accumulation rank', len(accum_shape), 0, Rel.GT, self.name) + validator.check_positive_int(len(accum_shape), "accumulation rank", self.name) validator.check_integer(f'momentum rank', len(momentum_shape), 0, Rel.GE, self.name) validator.check_integer(f'stat rank', len(stat_shape), 0, Rel.GE, self.name) validator.check("gradient shape", gradient_shape, "stat shape", stat_shape, Rel.EQ, self.name) @@ -2748,9 +2747,9 @@ class LSTM(PrimitiveWithInfer): @prim_attr_register def __init__(self, input_size, hidden_size, num_layers, has_bias, bidirectional, dropout): - self.input_size = validator.check_integer("input_size", input_size, 0, Rel.GT, self.name) - self.hidden_size = validator.check_integer("hidden_size", hidden_size, 0, Rel.GT, self.name) - self.num_layers = validator.check_integer("num_layers", num_layers, 0, Rel.GT, self.name) + self.input_size = validator.check_positive_int(input_size, "input_size", self.name) + self.hidden_size = validator.check_positive_int(hidden_size, "hidden_size", self.name) + self.num_layers = validator.check_positive_int(num_layers, "num_layers", self.name) self.has_bias = validator.check_value_type("has_bias", has_bias, (bool,), self.name) self.bidirectional = validator.check_value_type("bidirectional", bidirectional, (bool,), self.name) self.dropout = validator.check_value_type("dropout", dropout, [float], self.name) diff --git a/mindspore/ops/operations/random_ops.py b/mindspore/ops/operations/random_ops.py index 20eb74f908b..e20d54dc648 100644 --- a/mindspore/ops/operations/random_ops.py +++ b/mindspore/ops/operations/random_ops.py @@ -12,11 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================ - """Operators for random.""" -from ..._checkparam import Validator as validator -from ..._checkparam import Rel +from ..._checkparam import Validator, Rel from ...common import dtype as mstype from ..primitive import PrimitiveWithInfer, prim_attr_register from .._utils import get_broadcast_shape @@ -46,16 +44,16 @@ class StandardNormal(PrimitiveWithInfer): def __init__(self, seed=0, seed2=0): """Initialize StandardNormal""" self.init_prim_io_names(inputs=['shape'], outputs=['output']) - validator.check_integer("seed", seed, 0, Rel.GE, self.name) - validator.check_integer("seed2", seed2, 0, Rel.GE, self.name) + Validator.check_integer("seed", seed, 0, Rel.GE, self.name) + Validator.check_integer("seed2", seed2, 0, Rel.GE, self.name) def __infer__(self, shape): shape_v = shape["value"] if shape_v is None: raise ValueError(f"For {self.name}, shape must be const.") - validator.check_value_type("shape", shape_v, [tuple], self.name) + Validator.check_value_type("shape", shape_v, [tuple], self.name) for i, shape_i in enumerate(shape_v): - validator.check_integer("shape[%d]" % i, shape_i, 0, Rel.GT, self.name) + Validator.check_positive_int(shape_i, f'shape[{i}]', self.name) out = { 'shape': shape_v, 'dtype': mstype.float32, @@ -91,16 +89,16 @@ class StandardLaplace(PrimitiveWithInfer): def __init__(self, seed=0, seed2=0): """Initialize StandardLaplace""" self.init_prim_io_names(inputs=['shape'], outputs=['output']) - validator.check_value_type('seed', seed, [int], self.name) - validator.check_value_type('seed2', seed2, [int], self.name) + Validator.check_value_type('seed', seed, [int], self.name) + Validator.check_value_type('seed2', seed2, [int], self.name) def __infer__(self, shape): shape_v = shape["value"] if shape_v is None: raise ValueError(f"For {self.name}, shape must be const.") - validator.check_value_type("shape", shape_v, [tuple], self.name) + Validator.check_value_type("shape", shape_v, [tuple], self.name) for i, shape_i in enumerate(shape_v): - validator.check_integer("shape[%d]" % i, shape_i, 0, Rel.GT, self.name) + Validator.check_positive_int(shape_i, f'shape[{i}]', self.name) out = { 'shape': shape_v, 'dtype': mstype.float32, @@ -143,18 +141,18 @@ class Gamma(PrimitiveWithInfer): def __init__(self, seed=0, seed2=0): """Initialize Gamma""" self.init_prim_io_names(inputs=['shape', 'alpha', 'beta'], outputs=['output']) - validator.check_integer("seed", seed, 0, Rel.GE, self.name) - validator.check_integer("seed2", seed2, 0, Rel.GE, self.name) + Validator.check_integer("seed", seed, 0, Rel.GE, self.name) + Validator.check_integer("seed2", seed2, 0, Rel.GE, self.name) def __infer__(self, shape, alpha, beta): shape_v = shape["value"] if shape_v is None: raise ValueError(f"For {self.name}, shape must be const.") - validator.check_value_type("shape", shape_v, [tuple], self.name) + Validator.check_value_type("shape", shape_v, [tuple], self.name) for i, shape_i in enumerate(shape_v): - validator.check_integer("shape[%d]" % i, shape_i, 0, Rel.GT, self.name) - validator.check_tensor_type_same({"alpha": alpha["dtype"]}, [mstype.float32], self.name) - validator.check_tensor_type_same({"beta": beta["dtype"]}, [mstype.float32], self.name) + Validator.check_positive_int(shape_i, f'shape[{i}]', self.name) + Validator.check_tensor_type_same({"alpha": alpha["dtype"]}, [mstype.float32], self.name) + Validator.check_tensor_type_same({"beta": beta["dtype"]}, [mstype.float32], self.name) broadcast_shape = get_broadcast_shape(alpha['shape'], beta['shape'], self.name) broadcast_shape = get_broadcast_shape(broadcast_shape, shape_v, self.name) out = { @@ -195,17 +193,17 @@ class Poisson(PrimitiveWithInfer): def __init__(self, seed=0, seed2=0): """Initialize Poisson""" self.init_prim_io_names(inputs=['shape', 'mean'], outputs=['output']) - validator.check_integer("seed", seed, 0, Rel.GE, self.name) - validator.check_integer("seed2", seed2, 0, Rel.GE, self.name) + Validator.check_integer("seed", seed, 0, Rel.GE, self.name) + Validator.check_integer("seed2", seed2, 0, Rel.GE, self.name) def __infer__(self, shape, mean): shape_v = shape["value"] if shape_v is None: raise ValueError(f"For {self.name}, shape must be const.") - validator.check_value_type("shape", shape_v, [tuple], self.name) + Validator.check_value_type("shape", shape_v, [tuple], self.name) for i, shape_i in enumerate(shape_v): - validator.check_integer("shape[%d]" % i, shape_i, 0, Rel.GT, self.name) - validator.check_tensor_type_same({"mean": mean["dtype"]}, [mstype.float32], self.name) + Validator.check_positive_int(shape_i, f'shape[{i}]', self.name) + Validator.check_tensor_type_same({"mean": mean["dtype"]}, [mstype.float32], self.name) broadcast_shape = get_broadcast_shape(mean['shape'], shape_v, self.name) out = { 'shape': broadcast_shape, @@ -251,22 +249,22 @@ class UniformInt(PrimitiveWithInfer): def __init__(self, seed=0, seed2=0): """Initialize UniformInt""" self.init_prim_io_names(inputs=['shape', 'minval', 'maxval'], outputs=['output']) - validator.check_integer("seed", seed, 0, Rel.GE, self.name) - validator.check_integer("seed2", seed2, 0, Rel.GE, self.name) + Validator.check_integer("seed", seed, 0, Rel.GE, self.name) + Validator.check_integer("seed2", seed2, 0, Rel.GE, self.name) def __infer__(self, shape, minval, maxval): shape_v = shape["value"] if shape_v is None: raise ValueError(f"For {self.name}, shape must be const.") - validator.check_value_type("shape", shape_v, [tuple], self.name) + Validator.check_value_type("shape", shape_v, [tuple], self.name) for i, shape_i in enumerate(shape_v): - validator.check_integer("shape[%d]" % i, shape_i, 0, Rel.GT, self.name) - validator.check_tensor_type_same({"minval": minval["dtype"]}, [mstype.int32], self.name) - validator.check_tensor_type_same({"maxval": maxval["dtype"]}, [mstype.int32], self.name) + Validator.check_positive_int(shape_i, f'shape[{i}]', self.name) + Validator.check_tensor_type_same({"minval": minval["dtype"]}, [mstype.int32], self.name) + Validator.check_tensor_type_same({"maxval": maxval["dtype"]}, [mstype.int32], self.name) minval_shape = minval['shape'] maxval_shape = maxval['shape'] - validator.check("dim of minval", len(minval_shape), '0(scalar)', 0, Rel.EQ, self.name) - validator.check("dim of maxval", len(maxval_shape), '0(scalar)', 0, Rel.EQ, self.name) + Validator.check("dim of minval", len(minval_shape), '0(scalar)', 0, Rel.EQ, self.name) + Validator.check("dim of maxval", len(maxval_shape), '0(scalar)', 0, Rel.EQ, self.name) out = { 'shape': shape_v, 'dtype': mstype.int32, @@ -298,16 +296,16 @@ class UniformReal(PrimitiveWithInfer): def __init__(self, seed=0, seed2=0): """Initialize UniformReal""" self.init_prim_io_names(inputs=['shape'], outputs=['output']) - validator.check_integer("seed", seed, 0, Rel.GE, self.name) - validator.check_integer("seed2", seed2, 0, Rel.GE, self.name) + Validator.check_integer("seed", seed, 0, Rel.GE, self.name) + Validator.check_integer("seed2", seed2, 0, Rel.GE, self.name) def __infer__(self, shape): shape_v = shape["value"] if shape_v is None: raise ValueError(f"For {self.name}, shape must be const.") - validator.check_value_type("shape", shape_v, [tuple], self.name) + Validator.check_value_type("shape", shape_v, [tuple], self.name) for i, shape_i in enumerate(shape_v): - validator.check_integer("shape[%d]" % i, shape_i, 0, Rel.GT, self.name) + Validator.check_positive_int(shape_i, f'shape[{i}]', self.name) out = { 'shape': shape_v, 'dtype': mstype.float32, @@ -348,18 +346,18 @@ class RandomChoiceWithMask(PrimitiveWithInfer): @prim_attr_register def __init__(self, count=256, seed=0, seed2=0): """Initialize RandomChoiceWithMask""" - validator.check_value_type("count", count, [int], self.name) - validator.check_integer("count", count, 0, Rel.GT, self.name) - validator.check_value_type('seed', seed, [int], self.name) - validator.check_value_type('seed2', seed2, [int], self.name) + Validator.check_value_type("count", count, [int], self.name) + Validator.check_positive_int(count, "count", self.name) + Validator.check_value_type('seed', seed, [int], self.name) + Validator.check_value_type('seed2', seed2, [int], self.name) def infer_shape(self, x_shape): - validator.check_integer("input_x rank", len(x_shape), 1, Rel.GE, self.name) - validator.check_integer("input_x rank", len(x_shape), 5, Rel.LE, self.name) + Validator.check_integer("input_x rank", len(x_shape), 1, Rel.GE, self.name) + Validator.check_integer("input_x rank", len(x_shape), 5, Rel.LE, self.name) return ([self.count, len(x_shape)], [self.count]) def infer_dtype(self, x_dtype): - validator.check_tensor_type_same({'x': x_dtype}, [mstype.bool_], self.name) + Validator.check_tensor_type_same({'x': x_dtype}, [mstype.bool_], self.name) return (mstype.int32, mstype.bool_) @@ -399,19 +397,19 @@ class RandomCategorical(PrimitiveWithInfer): self.dtype = dtype valid_values = (mstype.int32, mstype.int16, mstype.int64) - validator.check_type_name("dtype", dtype, valid_values, self.name) + Validator.check_type_name("dtype", dtype, valid_values, self.name) self.init_prim_io_names(inputs=['logits', 'num_samples', 'seed'], outputs=['output']) def __infer__(self, logits, num_samples, seed): logits_dtype = logits['dtype'] valid_types = (mstype.float32, mstype.float16, mstype.float64) - validator.check_tensor_type_same({'logits': logits_dtype}, valid_types, self.name) + Validator.check_tensor_type_same({'logits': logits_dtype}, valid_types, self.name) num_samples_v = num_samples['value'] seed_v = seed['value'] - validator.check_value_type('num_samples', num_samples_v, (int,), self.name) - validator.check_value_type('seed', seed_v, (int,), self.name) - validator.check_integer("num_samples", num_samples_v, 0, Rel.GT, self.name) + Validator.check_value_type('num_samples', num_samples_v, (int,), self.name) + Validator.check_value_type('seed', seed_v, (int,), self.name) + Validator.check_positive_int(num_samples_v, "num_samples", self.name) x_shape = list(logits['shape']) if len(x_shape) != 2: raise ValueError("RandomCategorical shape should be 2-dimension.") @@ -450,20 +448,20 @@ class Multinomial(PrimitiveWithInfer): @prim_attr_register def __init__(self, seed=0): """init""" - validator.check_value_type("seed", seed, [int], self.name) - validator.check_integer("seed", seed, 0, Rel.GE, self.name) + Validator.check_value_type("seed", seed, [int], self.name) + Validator.check_integer("seed", seed, 0, Rel.GE, self.name) self.init_prim_io_names(inputs=['input', 'num_sample'], outputs=['output']) def __infer__(self, inputs, num_samples): input_shape = inputs["shape"] if len(input_shape) != 1 and len(input_shape) != 2: raise ValueError("input dim must be 1 or 2") - validator.check_tensor_type_same({'inputs': inputs['dtype']}, [mstype.float32], self.name) + Validator.check_tensor_type_same({'inputs': inputs['dtype']}, [mstype.float32], self.name) num_samples_value = num_samples["value"] if num_samples_value is None: raise ValueError(f"For {self.name}, shape nust be const") - validator.check_value_type("num_samples", num_samples_value, (int,), self.name) - validator.check_integer("num_samples", num_samples_value, 0, Rel.GT, None) + Validator.check_value_type("num_samples", num_samples_value, (int,), self.name) + Validator.check_positive_int(num_samples_value, "num_samples") y_shape = (num_samples_value,) if len(input_shape) == 2: y_shape = (input_shape[0], num_samples_value) diff --git a/mindspore/train/callback/_checkpoint.py b/mindspore/train/callback/_checkpoint.py index 36a09b46e81..c59440a51d1 100644 --- a/mindspore/train/callback/_checkpoint.py +++ b/mindspore/train/callback/_checkpoint.py @@ -21,7 +21,7 @@ import time import threading import mindspore.context as context from mindspore import log as logger -from mindspore._checkparam import Validator, check_int_non_negative +from mindspore._checkparam import Validator from mindspore.train._utils import _make_directory from mindspore.train.serialization import save_checkpoint, _save_graph from mindspore.parallel._ps_context import _is_role_pserver, _get_ps_mode_rank @@ -107,13 +107,13 @@ class CheckpointConfig: async_save=False): if save_checkpoint_steps is not None: - save_checkpoint_steps = check_int_non_negative(save_checkpoint_steps) + save_checkpoint_steps = Validator.check_non_negative_int(save_checkpoint_steps) if save_checkpoint_seconds is not None: - save_checkpoint_seconds = check_int_non_negative(save_checkpoint_seconds) + save_checkpoint_seconds = Validator.check_non_negative_int(save_checkpoint_seconds) if keep_checkpoint_max is not None: - keep_checkpoint_max = check_int_non_negative(keep_checkpoint_max) + keep_checkpoint_max = Validator.check_non_negative_int(keep_checkpoint_max) if keep_checkpoint_per_n_minutes is not None: - keep_checkpoint_per_n_minutes = check_int_non_negative(keep_checkpoint_per_n_minutes) + keep_checkpoint_per_n_minutes = Validator.check_non_negative_int(keep_checkpoint_per_n_minutes) if not save_checkpoint_steps and not save_checkpoint_seconds and \ not keep_checkpoint_max and not keep_checkpoint_per_n_minutes: diff --git a/mindspore/train/loss_scale_manager.py b/mindspore/train/loss_scale_manager.py index 823994377f7..1a0f81a1b7f 100644 --- a/mindspore/train/loss_scale_manager.py +++ b/mindspore/train/loss_scale_manager.py @@ -13,8 +13,8 @@ # limitations under the License. # ============================================================================ """Loss scale manager abstract class.""" + from .._checkparam import Validator as validator -from .._checkparam import Rel from .. import nn __all__ = ["LossScaleManager", "FixedLossScaleManager", "DynamicLossScaleManager"] @@ -97,7 +97,7 @@ class DynamicLossScaleManager(LossScaleManager): if init_loss_scale < 1.0: raise ValueError("Loss scale value should be > 1") self.loss_scale = init_loss_scale - validator.check_integer("scale_window", scale_window, 0, Rel.GT, self.__class__.__name__) + validator.check_positive_int(scale_window, "scale_window", self.__class__.__name__) self.scale_window = scale_window if scale_factor <= 0: raise ValueError("Scale factor should be > 1") diff --git a/mindspore/train/model.py b/mindspore/train/model.py index aba9a3a52dd..3d8c3c73192 100755 --- a/mindspore/train/model.py +++ b/mindspore/train/model.py @@ -22,7 +22,7 @@ import numpy as np from mindspore import log as logger from ..common.tensor import Tensor from ..nn.metrics import get_metrics -from .._checkparam import check_input_data, check_output_data, check_int_positive, Validator, check_int +from .._checkparam import check_input_data, check_output_data, Validator, check_int from .callback import _InternalCallbackParam, RunContext, _CallbackManager from .. import context from ..parallel._utils import _get_parallel_mode, _get_device_num, _get_global_rank, \ @@ -339,7 +339,7 @@ class Model: dataset not sink. sink_size (int): Control the amount of data in each sink. Default: -1. """ - epoch = check_int_positive(epoch) + epoch = Validator.check_positive_int(epoch) if self._parameter_broadcast: self._train_network.set_broadcast_flag() diff --git a/model_zoo/official/cv/resnet_thor/src/thor_layer.py b/model_zoo/official/cv/resnet_thor/src/thor_layer.py index 7cc548b7110..7a5ee3bd869 100644 --- a/model_zoo/official/cv/resnet_thor/src/thor_layer.py +++ b/model_zoo/official/cv/resnet_thor/src/thor_layer.py @@ -16,7 +16,7 @@ import numpy as np import mindspore.common.dtype as mstype -from mindspore._checkparam import Validator, twice, check_int_positive +from mindspore._checkparam import Validator, twice from mindspore._extends import cell_attr_register from mindspore.common.initializer import initializer from mindspore.common.parameter import Parameter @@ -292,8 +292,8 @@ class Dense_Thor_GPU(Cell): has_bias=True, activation=None): super(Dense_Thor_GPU, self).__init__() - self.in_channels = check_int_positive(in_channels) - self.out_channels = check_int_positive(out_channels) + self.in_channels = Validator.check_positive_int(in_channels) + self.out_channels = Validator.check_positive_int(out_channels) self.has_bias = Validator.check_bool(has_bias) self.thor = True if isinstance(weight_init, Tensor): @@ -641,8 +641,8 @@ class Dense_Thor(Cell): has_bias=True, activation=None): super(Dense_Thor, self).__init__() - self.in_channels = check_int_positive(in_channels) - self.out_channels = check_int_positive(out_channels) + self.in_channels = Validator.check_positive_int(in_channels) + self.out_channels = Validator.check_positive_int(out_channels) self.has_bias = Validator.check_bool(has_bias) self.thor = True self.batch_size = batch_size diff --git a/model_zoo/official/gnn/gat/src/gat.py b/model_zoo/official/gnn/gat/src/gat.py index 08bac4d9deb..b600590ca2e 100644 --- a/model_zoo/official/gnn/gat/src/gat.py +++ b/model_zoo/official/gnn/gat/src/gat.py @@ -19,7 +19,7 @@ from mindspore.ops import functional as F from mindspore._extends import cell_attr_register from mindspore import Tensor, Parameter from mindspore.common.initializer import initializer -from mindspore._checkparam import check_int_positive, Validator +from mindspore._checkparam import Validator from mindspore.nn.layer.activation import get_activation @@ -72,8 +72,8 @@ class GNNFeatureTransform(nn.Cell): bias_init='zeros', has_bias=True): super(GNNFeatureTransform, self).__init__() - self.in_channels = check_int_positive(in_channels) - self.out_channels = check_int_positive(out_channels) + self.in_channels = Validator.check_positive_int(in_channels) + self.out_channels = Validator.check_positive_int(out_channels) self.has_bias = Validator.check_bool(has_bias) if isinstance(weight_init, Tensor): @@ -259,8 +259,8 @@ class AttentionHead(nn.Cell): coef_activation=nn.LeakyReLU(), activation=nn.ELU()): super(AttentionHead, self).__init__() - self.in_channel = check_int_positive(in_channel) - self.out_channel = check_int_positive(out_channel) + self.in_channel = Validator.check_positive_int(in_channel) + self.out_channel = Validator.check_positive_int(out_channel) self.in_drop_ratio = in_drop_ratio self.in_drop = nn.Dropout(keep_prob=1 - in_drop_ratio) self.in_drop_2 = nn.Dropout(keep_prob=1 - in_drop_ratio) @@ -450,9 +450,9 @@ class GAT(nn.Cell): super(GAT, self).__init__() self.features = Tensor(features) self.biases = Tensor(biases) - self.ftr_dims = check_int_positive(ftr_dims) - self.num_class = check_int_positive(num_class) - self.num_nodes = check_int_positive(num_nodes) + self.ftr_dims = Validator.check_positive_int(ftr_dims) + self.num_class = Validator.check_positive_int(num_class) + self.num_nodes = Validator.check_positive_int(num_nodes) self.hidden_units = hidden_units self.num_heads = num_heads self.attn_drop = attn_drop diff --git a/model_zoo/official/nlp/bert_thor/src/model_thor.py b/model_zoo/official/nlp/bert_thor/src/model_thor.py index 67cdff3c474..0b2f24020d5 100644 --- a/model_zoo/official/nlp/bert_thor/src/model_thor.py +++ b/model_zoo/official/nlp/bert_thor/src/model_thor.py @@ -22,7 +22,7 @@ from mindspore._c_expression import init_exec_dataset from mindspore import context from mindspore import log as logger from mindspore import nn -from mindspore._checkparam import check_input_data, check_output_data, check_int_positive, Validator, check_int +from mindspore._checkparam import check_input_data, check_output_data, Validator, check_int from mindspore.common import dtype as mstype from mindspore.common.dtype import pytype_to_dtype from mindspore.common.tensor import Tensor @@ -374,7 +374,7 @@ class Model: dataset not sink. sink_size (int): Control the amount of data each sink. Default: -1. """ - epoch = check_int_positive(epoch) + epoch = Validator.check_positive_int(epoch) self._train_network.set_train() if self._parameter_broadcast: diff --git a/model_zoo/official/nlp/bert_thor/src/thor_layer.py b/model_zoo/official/nlp/bert_thor/src/thor_layer.py index 6f791ef6f1f..0fb646fe259 100644 --- a/model_zoo/official/nlp/bert_thor/src/thor_layer.py +++ b/model_zoo/official/nlp/bert_thor/src/thor_layer.py @@ -15,7 +15,7 @@ """thor_layer""" import numpy as np import mindspore.common.dtype as mstype -from mindspore._checkparam import Validator, check_int_positive +from mindspore._checkparam import Validator from mindspore.common.initializer import TruncatedNormal, initializer from mindspore.common.parameter import Parameter from mindspore.common.tensor import Tensor @@ -160,8 +160,8 @@ class Dense_Thor(Cell): activation=None, batch_size=12): super(Dense_Thor, self).__init__() - self.in_channels = check_int_positive(in_channels) - self.out_channels = check_int_positive(out_channels) + self.in_channels = Validator.check_positive_int(in_channels) + self.out_channels = Validator.check_positive_int(out_channels) self.has_bias = Validator.check_bool(has_bias) self.thor = True if isinstance(weight_init, Tensor): diff --git a/tests/st/gnn/aggregator.py b/tests/st/gnn/aggregator.py index bffacef2dc4..b1c69a23a6e 100644 --- a/tests/st/gnn/aggregator.py +++ b/tests/st/gnn/aggregator.py @@ -15,7 +15,7 @@ """Aggregator.""" import mindspore.nn as nn from mindspore import Tensor, Parameter -from mindspore._checkparam import check_int_positive, Validator +from mindspore._checkparam import Validator from mindspore._extends import cell_attr_register from mindspore.common.initializer import initializer from mindspore.nn.layer.activation import get_activation @@ -73,8 +73,8 @@ class GNNFeatureTransform(nn.Cell): bias_init='zeros', has_bias=True): super(GNNFeatureTransform, self).__init__() - self.in_channels = check_int_positive(in_channels) - self.out_channels = check_int_positive(out_channels) + self.in_channels = Validator.check_positive_int(in_channels) + self.out_channels = Validator.check_positive_int(out_channels) self.has_bias = Validator.check_bool(has_bias) if isinstance(weight_init, Tensor): @@ -262,8 +262,8 @@ class AttentionHead(nn.Cell): coef_activation=nn.LeakyReLU(), activation=nn.ELU()): super(AttentionHead, self).__init__() - self.in_channel = check_int_positive(in_channel) - self.out_channel = check_int_positive(out_channel) + self.in_channel = Validator.check_positive_int(in_channel) + self.out_channel = Validator.check_positive_int(out_channel) self.in_drop_ratio = in_drop_ratio self.in_drop = nn.Dropout(keep_prob=1 - in_drop_ratio) self.in_drop_2 = nn.Dropout(keep_prob=1 - in_drop_ratio) diff --git a/tests/st/gnn/gat.py b/tests/st/gnn/gat.py index 3d12c48977b..2ea4c909ebf 100644 --- a/tests/st/gnn/gat.py +++ b/tests/st/gnn/gat.py @@ -14,7 +14,7 @@ # ============================================================================ """Graph Attention Networks.""" import mindspore.nn as nn -from mindspore._checkparam import Validator, check_int_positive +from mindspore._checkparam import Validator from aggregator import AttentionAggregator @@ -71,9 +71,9 @@ class GAT(nn.Cell): activation=nn.ELU(), residual=False): super(GAT, self).__init__() - self.ftr_dims = check_int_positive(ftr_dims) - self.num_class = check_int_positive(num_class) - self.num_nodes = check_int_positive(num_nodes) + self.ftr_dims = Validator.check_positive_int(ftr_dims) + self.num_class = Validator.check_positive_int(num_class) + self.num_nodes = Validator.check_positive_int(num_nodes) self.hidden_units = hidden_units self.num_heads = num_heads self.attn_drop = attn_drop diff --git a/tests/st/networks/models/resnet50/src_thor/model_thor.py b/tests/st/networks/models/resnet50/src_thor/model_thor.py index 7e31d23daa2..18346812cae 100644 --- a/tests/st/networks/models/resnet50/src_thor/model_thor.py +++ b/tests/st/networks/models/resnet50/src_thor/model_thor.py @@ -19,7 +19,7 @@ from mindspore import context from mindspore import log as logger from mindspore import nn from mindspore._c_expression import init_exec_dataset -from mindspore._checkparam import check_input_data, check_output_data, check_int_positive, Validator +from mindspore._checkparam import check_input_data, check_output_data, Validator from mindspore.common import dtype as mstype from mindspore.common.dtype import pytype_to_dtype from mindspore.common.tensor import Tensor @@ -377,7 +377,7 @@ class Model: Configure pynative mode, the training process will be performed with dataset not sink. """ - epoch = check_int_positive(epoch) + epoch = Validator.check_positive_int(epoch) self._train_network.set_train() if self._parameter_broadcast: diff --git a/tests/st/networks/models/resnet50/src_thor/thor_layer.py b/tests/st/networks/models/resnet50/src_thor/thor_layer.py index f4d76d07543..6b56461c45a 100644 --- a/tests/st/networks/models/resnet50/src_thor/thor_layer.py +++ b/tests/st/networks/models/resnet50/src_thor/thor_layer.py @@ -16,7 +16,7 @@ import numpy as np import mindspore as ms import mindspore.common.dtype as mstype -from mindspore._checkparam import Validator, twice, check_int_positive +from mindspore._checkparam import Validator, twice from mindspore._extends import cell_attr_register from mindspore.common.initializer import initializer from mindspore.common.parameter import Parameter @@ -337,8 +337,8 @@ class Dense_Thor(Cell): has_bias=True, activation=None): super(Dense_Thor, self).__init__() - self.in_channels = check_int_positive(in_channels) - self.out_channels = check_int_positive(out_channels) + self.in_channels = Validator.check_positive_int(in_channels) + self.out_channels = Validator.check_positive_int(out_channels) self.has_bias = Validator.check_bool(has_bias) self.thor = True if isinstance(weight_init, Tensor): diff --git a/tests/ut/python/nn/test_checkparameter.py b/tests/ut/python/nn/test_checkparameter.py index 7a68fd1ae06..7878b5b9e63 100644 --- a/tests/ut/python/nn/test_checkparameter.py +++ b/tests/ut/python/nn/test_checkparameter.py @@ -15,8 +15,7 @@ """ test checkparameter """ import pytest -from mindspore._checkparam import check_int, check_int_positive, \ - check_input_format, Validator, twice +from mindspore._checkparam import check_int, check_input_format, Validator, twice kernel_size = 5 kernel_size1 = twice(kernel_size) @@ -29,7 +28,7 @@ def test_check_int_1(): def check_int_positive_1(): with pytest.raises(ValueError): - check_int_positive(-1) + Validator.check_positive_int(-1) def test_NCHW1(): diff --git a/tests/ut/python/pynative_mode/nn/test_checkparameter.py b/tests/ut/python/pynative_mode/nn/test_checkparameter.py index 30a80df8ff2..048c4a71189 100644 --- a/tests/ut/python/pynative_mode/nn/test_checkparameter.py +++ b/tests/ut/python/pynative_mode/nn/test_checkparameter.py @@ -15,8 +15,7 @@ """ test_checkparameter """ import pytest -from mindspore._checkparam import check_int, check_int_positive, \ - Validator, check_input_format, _expand_tuple +from mindspore._checkparam import check_int, Validator, check_input_format, _expand_tuple once = _expand_tuple(1) twice = _expand_tuple(2) @@ -32,7 +31,7 @@ def test_check_int_1(): def check_int_positive_1(): with pytest.raises(ValueError): - check_int_positive(-1) + Validator.check_positive_int(-1) def test_NCHW1(): diff --git a/tests/vm_impl/vm_me.py b/tests/vm_impl/vm_me.py index 58558ffa0f1..528863ec21b 100644 --- a/tests/vm_impl/vm_me.py +++ b/tests/vm_impl/vm_me.py @@ -15,8 +15,6 @@ """VM implementations based on numpy.""" import numpy as np - -from mindspore._checkparam import Rel from mindspore._checkparam import Validator as validator @@ -33,7 +31,7 @@ def avg_pooling(x, pool_h, pool_w, stride): Returns: numpy.ndarray, an output array after applying average pooling on input array. """ - validator.check_integer("stride", stride, 0, Rel.GT, None) + validator.check_positive_int(stride, "stride") num, channel, height, width = x.shape out_h = (height - pool_h) // stride + 1 out_w = (width - pool_w) // stride + 1 @@ -423,7 +421,7 @@ def matmul(x, w, b=None): def max_pooling(x, pool_h, pool_w, stride): """Max pooling.""" - validator.check_integer("stride", stride, 0, Rel.GT, None) + validator.check_positive_int(stride, "stride") num, channel, height, width = x.shape out_h = (height - pool_h) // stride + 1 out_w = (width - pool_w) // stride + 1 @@ -466,7 +464,7 @@ def max_pool_grad_with_argmax(x, dout, arg_max, pool_h, pool_w, stride): def max_pool_with_argmax(x, pool_h, pool_w, stride): """Max pooling with argmax.""" - validator.check_integer("stride", stride, 0, Rel.GT, None) + validator.check_positive_int(stride, "stride") num, channel, height, width = x.shape out_h = (height - pool_h) // stride + 1 out_w = (width - pool_w) // stride + 1