[ME] change `check_integer` to `check_positive_int` and related integer checkers

chenzomi 2020-10-10 11:34:16 +08:00
parent d4e8e94981
commit d471d32e87
37 changed files with 272 additions and 252 deletions

View File

@@ -92,6 +92,25 @@ rel_strs = {
 }
+
+
+def _check_integer(arg_value, value, rel, arg_name=None, prim_name=None):
+    """
+    Check argument integer.
+
+    Usage:
+    - number = _check_integer(number, 0, Rel.GE, "number", None) # number >= 0
+    """
+    rel_fn = Rel.get_fns(rel)
+    type_mismatch = not isinstance(arg_value, int) or isinstance(arg_value, bool)
+    type_except = TypeError if type_mismatch else ValueError
+    if type_mismatch or not rel_fn(arg_value, value):
+        rel_str = Rel.get_strs(rel).format(value)
+        arg_name = arg_name if arg_name else "parameter"
+        msg_prefix = f'For \'{prim_name}\' the' if prim_name else "The"
+        raise type_except(f'{msg_prefix} `{arg_name}` should be an int and must {rel_str}, but got `{arg_value}`'
+                          f' with type `{type(arg_value).__name__}`.')
+    return arg_value
+
+
 class Validator:
     """validator for checking input parameters"""
@@ -121,6 +140,49 @@ class Validator:
                              f' with type `{type(arg_value).__name__}`.')
         return arg_value
 
+    @staticmethod
+    def check_positive_int(arg_value, arg_name=None, prim_name=None):
+        """
+        Check argument is positive integer, which means arg_value > 0.
+
+        Usage:
+        - number = check_positive_int(number)
+        - number = check_positive_int(number, "bias")
+        """
+        return _check_integer(arg_value, 0, Rel.GT, arg_name, prim_name)
+
+    @staticmethod
+    def check_negative_int(arg_value, arg_name=None, prim_name=None):
+        """
+        Check argument is negative integer, which means arg_value < 0.
+
+        Usage:
+        - number = check_negative_int(number)
+        - number = check_negative_int(number, "bias")
+        """
+        return _check_integer(arg_value, 0, Rel.LT, arg_name, prim_name)
+
+    @staticmethod
+    def check_non_positive_int(arg_value, arg_name=None, prim_name=None):
+        """
+        Check argument is non-positive integer, which means arg_value <= 0.
+
+        Usage:
+        - number = check_non_positive_int(number)
+        - number = check_non_positive_int(number, "bias")
+        """
+        return _check_integer(arg_value, 0, Rel.LE, arg_name, prim_name)
+
+    @staticmethod
+    def check_non_negative_int(arg_value, arg_name=None, prim_name=None):
+        """
+        Check argument is non-negative integer, which means arg_value >= 0.
+
+        Usage:
+        - number = check_non_negative_int(number)
+        - number = check_non_negative_int(number, "bias")
+        """
+        return _check_integer(arg_value, 0, Rel.GE, arg_name, prim_name)
+
     @staticmethod
     def check_number(arg_name, arg_value, value, rel, prim_name):
@@ -140,7 +202,13 @@ class Validator:
     @staticmethod
     def check_bool(arg_value, arg_name=None):
-        """Check argument is instance of bool"""
+        """
+        Check argument is instance of bool.
+
+        Usage:
+        - has_bias = check_bool(has_bias)
+        - has_bias = check_bool(has_bias, "has_bias")
+        """
         if not isinstance(arg_value, bool):
             arg_name = arg_name if arg_name else "Parameter"
             raise TypeError(f'`{arg_name}` should be isinstance of bool, but got `{arg_value}`.')
@@ -169,7 +237,12 @@ class Validator:
     @staticmethod
     def check_string(arg_value, valid_values, arg_name=None, prim_name=None):
-        """Checks whether a string is in some value list"""
+        """
+        Check whether string is in some value list.
+
+        Usage:
+        - method = check_string(method, ["string1", "string2", "string3"], "method")
+        """
         if isinstance(arg_value, str) and arg_value in valid_values:
             return arg_value
         arg_name = arg_name if arg_name else "Parameter"
@@ -372,28 +445,6 @@ def check_int(input_param):
     raise TypeError("Input type must be int!")
 
-
-def check_int_positive(input_param):
-    """Int type judgment."""
-    if isinstance(input_param, bool):
-        raise TypeError("Input type must be int cannot be bool!")
-    if isinstance(input_param, int):
-        if input_param > 0:
-            return input_param
-        raise ValueError("The input_param must be positive, but got input_param {}.".format(input_param))
-    raise TypeError("Input type must be int cannot be {}!".format(type(input_param)))
-
-
-def check_int_non_negative(input_param):
-    """Non_negative type judgment."""
-    if isinstance(input_param, bool):
-        raise TypeError("Input type must be int cannot be bool!")
-    if isinstance(input_param, int):
-        if input_param >= 0:
-            return input_param
-        raise ValueError("The input_param must be non_negative, but got input_param {}.".format(input_param))
-    raise TypeError("Input type must be int cannot be {}!".format(type(input_param)))
-
-
 def check_int_zero_one(input_param):
     """Judge whether it is 0 or 1."""
     if input_param in (0, 1):
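
A minimal sketch of the call-site migration performed throughout this commit, assuming the placeholder value `total_step = 100` (not taken from the diff): the new helpers put the value first and the optional name second, and they reject bool even though bool is a subclass of int, exactly as `_check_integer` above does.

    from mindspore._checkparam import Rel, Validator as validator

    total_step = 100
    # old style: the relation is spelled out at every call site
    total_step = validator.check_integer('total_step', total_step, 0, Rel.GT, None)
    # new style: the relation is implied by the helper name
    total_step = validator.check_positive_int(total_step, 'total_step')
    # validator.check_positive_int(True, 'flag')  # would raise TypeError: bool is not accepted as int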

View File

@@ -52,7 +52,7 @@ def piecewise_constant_lr(milestone, learning_rates):
     lr = []
     last_item = 0
     for i, item in enumerate(milestone):
-        validator.check_integer(f'milestone[{i}]', item, 0, Rel.GT, None)
+        validator.check_positive_int(item, f'milestone[{i}]')
         validator.check_float_legal_value(f'learning_rates[{i}]', learning_rates[i], None)
         if item < last_item:
             raise ValueError(f'The value of milestone[{i}] must be greater than milestone[{i - 1}]')
@@ -63,9 +63,9 @@ def piecewise_constant_lr(milestone, learning_rates):
 def _check_inputs(learning_rate, decay_rate, total_step, step_per_epoch, decay_epoch, is_stair):
-    validator.check_integer('total_step', total_step, 0, Rel.GT, None)
-    validator.check_integer('step_per_epoch', step_per_epoch, 0, Rel.GT, None)
-    validator.check_integer('decay_epoch', decay_epoch, 0, Rel.GT, None)
+    validator.check_positive_int(total_step, 'total_step')
+    validator.check_positive_int(step_per_epoch, 'step_per_epoch')
+    validator.check_positive_int(decay_epoch, 'decay_epoch')
     validator.check_float_positive('learning_rate', learning_rate, None)
     validator.check_float_legal_value('learning_rate', learning_rate, None)
     validator.check_float_positive('decay_rate', decay_rate, None)
@@ -236,9 +236,9 @@ def cosine_decay_lr(min_lr, max_lr, total_step, step_per_epoch, decay_epoch):
     validator.check_number_range("min_lr", min_lr, 0.0, float("inf"), Rel.INC_LEFT, None)
     validator.check_float_positive('max_lr', max_lr, None)
     validator.check_float_legal_value('max_lr', max_lr, None)
-    validator.check_integer('total_step', total_step, 0, Rel.GT, None)
-    validator.check_integer('step_per_epoch', step_per_epoch, 0, Rel.GT, None)
-    validator.check_integer('decay_epoch', decay_epoch, 0, Rel.GT, None)
+    validator.check_positive_int(total_step, 'total_step')
+    validator.check_positive_int(step_per_epoch, 'step_per_epoch')
+    validator.check_positive_int(decay_epoch, 'decay_epoch')
     if min_lr >= max_lr:
         raise ValueError('`max_lr` should be greater than `min_lr`.')
@@ -306,9 +306,9 @@ def polynomial_decay_lr(learning_rate, end_learning_rate, total_step, step_per_e
     validator.check_number_range("end_learning_rate", end_learning_rate, 0.0, float("inf"), Rel.INC_LEFT, None)
     validator.check_float_positive('power', power, None)
     validator.check_float_legal_value('power', power, None)
-    validator.check_integer('total_step', total_step, 0, Rel.GT, None)
-    validator.check_integer('step_per_epoch', step_per_epoch, 0, Rel.GT, None)
-    validator.check_integer('decay_epoch', decay_epoch, 0, Rel.GT, None)
+    validator.check_positive_int(total_step, 'total_step')
+    validator.check_positive_int(step_per_epoch, 'step_per_epoch')
+    validator.check_positive_int(decay_epoch, 'decay_epoch')
     validator.check_value_type('update_decay_epoch', update_decay_epoch, [bool], None)
     origin_decay_epoch = decay_epoch
@@ -357,9 +357,9 @@ def warmup_lr(learning_rate, total_step, step_per_epoch, warmup_epoch):
     if not isinstance(learning_rate, float):
         raise TypeError("learning_rate must be float.")
     validator.check_number_range("learning_rate", learning_rate, 0.0, float("inf"), Rel.INC_LEFT, None)
-    validator.check_integer('warmup_epoch', warmup_epoch, 0, Rel.GT, None)
-    validator.check_integer('total_step', total_step, 0, Rel.GT, None)
-    validator.check_integer('step_per_epoch', step_per_epoch, 0, Rel.GT, None)
+    validator.check_positive_int(warmup_epoch, 'warmup_epoch')
+    validator.check_positive_int(total_step, 'total_step')
+    validator.check_positive_int(step_per_epoch, 'step_per_epoch')
     function = lambda x, y: (x, min(x, y))
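
A short sketch of what the new helper enforces at these learning-rate call sites, with illustrative values that are not part of the diff: `check_positive_int` routes through `_check_integer` with `Rel.GT` against 0, so zero, negative, and non-int inputs all fail fast.

    from mindspore._checkparam import Validator as validator

    validator.check_positive_int(10, 'total_step')     # returns 10
    # validator.check_positive_int(0, 'total_step')    # would raise ValueError: must be > 0
    # validator.check_positive_int(2.5, 'total_step')  # would raise TypeError: not an int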

View File

@@ -27,7 +27,7 @@ from mindspore.ops.operations import _inner_ops as inner
 from mindspore.ops.primitive import constexpr
 from mindspore.common.parameter import Parameter
 from mindspore._extends import cell_attr_register
-from mindspore._checkparam import Rel, Validator, check_int_positive
+from mindspore._checkparam import Rel, Validator
 from mindspore.common.api import ms_function
 from mindspore import context
 from ..cell import Cell
@@ -203,8 +203,8 @@ class Dense(Cell):
                  has_bias=True,
                  activation=None):
         super(Dense, self).__init__()
-        self.in_channels = check_int_positive(in_channels)
-        self.out_channels = check_int_positive(out_channels)
+        self.in_channels = Validator.check_positive_int(in_channels)
+        self.out_channels = Validator.check_positive_int(out_channels)
         self.has_bias = Validator.check_bool(has_bias)
         if isinstance(weight_init, Tensor):

View File

@@ -21,7 +21,7 @@ from mindspore.ops.primitive import constexpr
 from mindspore.common.parameter import Parameter
 from mindspore.common.initializer import initializer, Initializer
 from mindspore.common.tensor import Tensor
-from mindspore._checkparam import Validator, Rel, twice, check_int_positive
+from mindspore._checkparam import Validator, Rel, twice
 from mindspore._extends import cell_attr_register
 from ..cell import Cell
@@ -47,8 +47,8 @@ class _Conv(Cell):
                  bias_init,
                  transposed=False):
         super(_Conv, self).__init__()
-        self.in_channels = check_int_positive(in_channels)
-        self.out_channels = check_int_positive(out_channels)
+        self.in_channels = Validator.check_positive_int(in_channels)
+        self.out_channels = Validator.check_positive_int(out_channels)
         self.kernel_size = kernel_size
         self.stride = stride
         self.pad_mode = pad_mode
@@ -65,7 +65,7 @@ class _Conv(Cell):
             raise TypeError("padding type must be int/tuple(int) cannot be {}!".format(type(padding)))
         self.dilation = dilation
-        self.group = check_int_positive(group)
+        self.group = Validator.check_positive_int(group)
         self.has_bias = has_bias
         if (not isinstance(kernel_size[0], int)) or (not isinstance(kernel_size[1], int)) or \
                 isinstance(kernel_size[0], bool) or isinstance(kernel_size[1], bool) or \

View File

@@ -21,7 +21,7 @@ from mindspore.common.initializer import initializer
 from mindspore.communication.management import get_group_size
 from mindspore.context import ParallelMode
 from mindspore.parallel._utils import _get_parallel_mode
-from mindspore._checkparam import Rel, Validator as validator
+from mindspore._checkparam import Validator as validator
 from ..cell import Cell
 
 __all__ = ['Embedding', 'EmbeddingLookup']
@@ -170,7 +170,7 @@ class EmbeddingLookup(Cell):
             if not isinstance(manual_shapes, tuple):
                 raise TypeError("manual_shapes type must be tuple(int) cannot be {}!".format(type(manual_shapes)))
             for dim in manual_shapes:
-                validator.check_integer('manul shape dim', dim, 0, Rel.GT, self.cls_name)
+                validator.check_positive_int(dim, 'manual shape dim', self.cls_name)
             self.gatherv2.add_prim_attr("manual_split", manual_shapes)
             self.embeddinglookup.add_prim_attr("manual_split", manual_shapes)
             self.gatherv2.shard(((get_group_size(), 1), (1, get_group_size())))

View File

@@ -15,7 +15,7 @@
 """lstm"""
 import math
 import numpy as np
-from mindspore._checkparam import Rel, Validator as validator
+from mindspore._checkparam import Validator as validator
 from mindspore.common.initializer import initializer
 from mindspore.common.parameter import Parameter
 from mindspore.common.tensor import Tensor
@@ -103,8 +103,8 @@ class LSTM(Cell):
                  bidirectional=False):
         super(LSTM, self).__init__()
         validator.check_value_type("batch_first", batch_first, [bool], self.cls_name)
-        validator.check_integer("hidden_size", hidden_size, 0, Rel.GT, self.cls_name)
-        validator.check_integer("num_layers", num_layers, 0, Rel.GT, self.cls_name)
+        validator.check_positive_int(hidden_size, "hidden_size", self.cls_name)
+        validator.check_positive_int(num_layers, "num_layers", self.cls_name)
         self.batch_first = batch_first
         self.transpose = P.Transpose()

View File

@@ -21,7 +21,7 @@ from mindspore.common.tensor import Tensor
 from mindspore.ops.primitive import constexpr
 from ..cell import Cell
 from ...common import dtype as mstype
-from ..._checkparam import Rel, Validator as validator
+from ..._checkparam import Validator as validator
 
 __all__ = ['ReduceLogSumExp', 'Range', 'LinSpace', 'LGamma', 'MatMul']
@@ -156,7 +156,7 @@ class LinSpace(Cell):
         validator.check_value_type("start", start, [int, float], self.cls_name)
         validator.check_value_type("stop", stop, [int, float], self.cls_name)
         validator.check_value_type("num", num, [int], self.cls_name)
-        validator.check_integer("num", num, 0, Rel.GT, self.cls_name)
+        validator.check_positive_int(num, "num", self.cls_name)
         self.is_single = bool(num == 1)
         self.lin_space = inner.LinSpace()

View File

@@ -19,7 +19,7 @@ from mindspore.common.parameter import Parameter
 from mindspore.common.initializer import initializer
 from mindspore.ops.primitive import constexpr
 import mindspore.context as context
-from mindspore._checkparam import Validator, check_typename, check_int_positive
+from mindspore._checkparam import Validator, check_typename
 from mindspore._extends import cell_attr_register
 from mindspore.communication.management import get_group_size, get_rank
 from mindspore.communication import management
@@ -64,7 +64,7 @@ class _BatchNorm(Cell):
             gamma_init, num_features), name="gamma", requires_grad=affine)
         self.beta = Parameter(initializer(
             beta_init, num_features), name="beta", requires_grad=affine)
-        self.group = check_int_positive(device_num_each_group)
+        self.group = Validator.check_positive_int(device_num_each_group)
         self.is_global = False
         if self.group != 1:
             self.rank_id = get_rank()
@@ -464,7 +464,7 @@ class GlobalBatchNorm(_BatchNorm):
                                               use_batch_statistics,
                                               device_num_each_group,
                                               input_dims='both')
-        self.group = check_int_positive(device_num_each_group)
+        self.group = Validator.check_positive_int(device_num_each_group)
         if self.group <= 1:
             raise ValueError("the number of group must be greater than 1.")
@@ -599,8 +599,8 @@ class GroupNorm(Cell):
     def __init__(self, num_groups, num_channels, eps=1e-05, affine=True, gamma_init='ones', beta_init='zeros'):
         super(GroupNorm, self).__init__()
-        self.num_groups = check_int_positive(num_groups)
-        self.num_channels = check_int_positive(num_channels)
+        self.num_groups = Validator.check_positive_int(num_groups)
+        self.num_channels = Validator.check_positive_int(num_channels)
         if num_channels % num_groups != 0:
             raise ValueError("num_channels should be divided by num_groups")
         self.eps = check_typename('eps', eps, (float,))

View File

@@ -23,7 +23,7 @@ from mindspore.ops import functional as F
 from mindspore.common.parameter import Parameter
 from mindspore.common.initializer import initializer
 from mindspore.common.tensor import Tensor
-from mindspore._checkparam import Validator, Rel, check_int_positive, twice
+from mindspore._checkparam import Validator, Rel, twice
 import mindspore.context as context
 from .normalization import BatchNorm2d, BatchNorm1d
 from .activation import get_activation, ReLU, LeakyReLU
@@ -657,8 +657,8 @@ class Conv2dBnWithoutFoldQuant(Cell):
             self.kernel_size = (kernel_size, kernel_size)
         else:
             self.kernel_size = kernel_size
-        self.in_channels = check_int_positive(in_channels)
-        self.out_channels = check_int_positive(out_channels)
+        self.in_channels = Validator.check_positive_int(in_channels)
+        self.out_channels = Validator.check_positive_int(out_channels)
         self.has_bias = has_bias
         self.stride = twice(stride)
         self.dilation = twice(dilation)
@@ -785,8 +785,8 @@ class Conv2dQuant(Cell):
             self.kernel_size = (kernel_size, kernel_size)
         else:
             self.kernel_size = kernel_size
-        self.in_channels = check_int_positive(in_channels)
-        self.out_channels = check_int_positive(out_channels)
+        self.in_channels = Validator.check_positive_int(in_channels)
+        self.out_channels = Validator.check_positive_int(out_channels)
         self.has_bias = has_bias
         self.stride = twice(stride)
         self.dilation = twice(dilation)
@@ -886,8 +886,8 @@ class DenseQuant(Cell):
                  narrow_range=False,
                  quant_delay=0):
         super(DenseQuant, self).__init__()
-        self.in_channels = check_int_positive(in_channels)
-        self.out_channels = check_int_positive(out_channels)
+        self.in_channels = Validator.check_positive_int(in_channels)
+        self.out_channels = Validator.check_positive_int(out_channels)
         self.has_bias = Validator.check_bool(has_bias)
         if isinstance(weight_init, Tensor):

View File

@@ -44,7 +44,7 @@ class LearningRateSchedule(Cell):
 def _check_inputs(learning_rate, decay_rate, decay_steps, is_stair, cls_name):
-    validator.check_integer('decay_steps', decay_steps, 0, Rel.GT, cls_name)
+    validator.check_positive_int(decay_steps, 'decay_steps', cls_name)
     validator.check_float_positive('learning_rate', learning_rate, cls_name)
     validator.check_float_legal_value('learning_rate', learning_rate, cls_name)
     validator.check_float_positive('decay_rate', decay_rate, cls_name)
@@ -257,7 +257,7 @@ class CosineDecayLR(LearningRateSchedule):
         validator.check_number_range("min_lr", min_lr, 0.0, float("inf"), Rel.INC_LEFT, self.cls_name)
         validator.check_float_positive('max_lr', max_lr, self.cls_name)
         validator.check_float_legal_value('max_lr', max_lr, self.cls_name)
-        validator.check_integer('decay_steps', decay_steps, 0, Rel.GT, self.cls_name)
+        validator.check_positive_int(decay_steps, "decay_steps", self.cls_name)
         if min_lr >= max_lr:
             raise ValueError('`max_lr` should be greater than `min_lr`.')
         self.min_lr = min_lr
@@ -324,7 +324,7 @@ class PolynomialDecayLR(LearningRateSchedule):
             raise TypeError("end_learning_rate must be float.")
         validator.check_number_range("end_learning_rate", end_learning_rate, 0.0, float("inf"), Rel.INC_LEFT,
                                      self.cls_name)
-        validator.check_integer('decay_steps', decay_steps, 0, Rel.GT, self.cls_name)
+        validator.check_positive_int(decay_steps, 'decay_steps', self.cls_name)
         validator.check_value_type('update_decay_steps', update_decay_steps, [bool], self.cls_name)
         validator.check_float_positive('power', power, self.cls_name)
         validator.check_float_legal_value('power', power, self.cls_name)
@@ -388,7 +388,7 @@ class WarmUpLR(LearningRateSchedule):
         if not isinstance(learning_rate, float):
             raise TypeError("learning_rate must be float.")
         validator.check_number_range("learning_rate", learning_rate, 0.0, float("inf"), Rel.INC_LEFT, self.cls_name)
-        validator.check_integer('warmup_steps', warmup_steps, 0, Rel.GT, self.cls_name)
+        validator.check_positive_int(warmup_steps, 'warmup_steps', self.cls_name)
         self.warmup_steps = warmup_steps
         self.learning_rate = learning_rate
         self.min = P.Minimum()

View File

@@ -15,7 +15,7 @@
 """dense_variational"""
 from mindspore.ops import operations as P
 from mindspore.common.tensor import Tensor
-from mindspore._checkparam import check_int_positive, Validator
+from mindspore._checkparam import Validator
 from ...cell import Cell
 from ...layer.activation import get_activation
 from .layer_distribution import NormalPrior, NormalPosterior
@@ -39,8 +39,8 @@ class _DenseVariational(Cell):
                  bias_prior_fn=NormalPrior,
                  bias_posterior_fn=lambda name, shape: NormalPosterior(name=name, shape=shape)):
         super(_DenseVariational, self).__init__()
-        self.in_channels = check_int_positive(in_channels)
-        self.out_channels = check_int_positive(out_channels)
+        self.in_channels = Validator.check_positive_int(in_channels)
+        self.out_channels = Validator.check_positive_int(out_channels)
         self.has_bias = Validator.check_bool(has_bias)
         if isinstance(weight_prior_fn, Cell):

View File

@@ -15,7 +15,7 @@
 """Conditional Variational auto-encoder (CVAE)."""
 from mindspore.ops import composite as C
 from mindspore.ops import operations as P
-from mindspore._checkparam import check_int_positive
+from mindspore._checkparam import Validator
 from ....cell import Cell
 from ....layer.basic import Dense, OneHot
@@ -57,11 +57,11 @@ class ConditionalVAE(Cell):
         self.decoder = decoder
         if (not isinstance(encoder, Cell)) or (not isinstance(decoder, Cell)):
             raise TypeError('The encoder and decoder should be Cell type.')
-        self.hidden_size = check_int_positive(hidden_size)
-        self.latent_size = check_int_positive(latent_size)
+        self.hidden_size = Validator.check_positive_int(hidden_size)
+        self.latent_size = Validator.check_positive_int(latent_size)
         if hidden_size < latent_size:
             raise ValueError('The latent_size should be less than or equal to the hidden_size.')
-        self.num_classes = check_int_positive(num_classes)
+        self.num_classes = Validator.check_positive_int(num_classes)
         self.normal = C.normal
         self.exp = P.Exp()
         self.reshape = P.Reshape()
@@ -108,7 +108,7 @@ class ConditionalVAE(Cell):
         Returns:
             Tensor, the generated samples.
         """
-        generate_nums = check_int_positive(generate_nums)
+        generate_nums = Validator.check_positive_int(generate_nums)
         if not isinstance(shape, tuple) or len(shape) != 4 or (shape[0] != -1 and shape[0] != generate_nums):
             raise ValueError('The shape should be (generate_nums, C, H, W) or (-1, C, H, W).')
         sample_z = self.normal((generate_nums, self.latent_size), self.to_tensor(0.0), self.to_tensor(1.0), seed=0)

View File

@@ -15,7 +15,7 @@
 """Variational auto-encoder (VAE)"""
 from mindspore.ops import composite as C
 from mindspore.ops import operations as P
-from mindspore._checkparam import check_int_positive
+from mindspore._checkparam import Validator
 from ....cell import Cell
 from ....layer.basic import Dense
@@ -52,8 +52,8 @@ class VAE(Cell):
         self.decoder = decoder
         if (not isinstance(encoder, Cell)) or (not isinstance(decoder, Cell)):
             raise TypeError('The encoder and decoder should be Cell type.')
-        self.hidden_size = check_int_positive(hidden_size)
-        self.latent_size = check_int_positive(latent_size)
+        self.hidden_size = Validator.check_positive_int(hidden_size)
+        self.latent_size = Validator.check_positive_int(latent_size)
         if hidden_size < latent_size:
             raise ValueError('The latent_size should be less than or equal to the hidden_size.')
         self.normal = C.normal
@@ -94,7 +94,7 @@ class VAE(Cell):
         Returns:
             Tensor, the generated samples.
         """
-        generate_nums = check_int_positive(generate_nums)
+        generate_nums = Validator.check_positive_int(generate_nums)
         if not isinstance(shape, tuple) or len(shape) != 4 or (shape[0] != -1 and shape[0] != generate_nums):
             raise ValueError('The shape should be (generate_nums, C, H, W) or (-1, C, H, W).')
         sample_z = self.normal((generate_nums, self.latent_size), self.to_tensor(0.0), self.to_tensor(1.0), seed=0)

View File

@@ -15,7 +15,7 @@
 """Stochastic Variational Inference(SVI)."""
 import mindspore.common.dtype as mstype
 from mindspore.common.tensor import Tensor
-from mindspore._checkparam import check_int_positive
+from mindspore._checkparam import Validator
 from ....cell import Cell
 from ....wrap.cell_wrapper import TrainOneStepCell
 from .elbo import ELBO
@@ -57,7 +57,7 @@ class SVI:
         Outputs:
             Cell, the trained probability network.
         """
-        epochs = check_int_positive(epochs)
+        epochs = Validator.check_positive_int(epochs)
         train_net = TrainOneStepCell(self.net_with_loss, self.optimizer)
         train_net.set_train()
         for _ in range(1, epochs+1):

View File

@@ -16,7 +16,7 @@
 from copy import deepcopy
 
 import numpy as np
-from mindspore._checkparam import check_int_positive, Validator
+from mindspore._checkparam import Validator
 from mindspore.ops import composite as C
 from mindspore.ops import operations as P
 from mindspore.train import Model
@@ -81,7 +81,7 @@ class UncertaintyEvaluation:
         self.epi_train_dataset = train_dataset
         self.ale_train_dataset = deepcopy(train_dataset)
         self.task_type = task_type
-        self.epochs = check_int_positive(epochs)
+        self.epochs = Validator.check_positive_int(epochs)
         self.epi_uncer_model_path = epi_uncer_model_path
         self.ale_uncer_model_path = ale_uncer_model_path
         self.save_model = Validator.check_bool(save_model)
@@ -95,7 +95,7 @@ class UncertaintyEvaluation:
         if task_type not in ('regression', 'classification'):
             raise ValueError('The task should be regression or classification.')
         if task_type == 'classification':
-            self.num_classes = check_int_positive(num_classes)
+            self.num_classes = Validator.check_positive_int(num_classes)
         else:
             self.num_classes = num_classes
         if save_model:

View File

@@ -65,9 +65,9 @@ def get_broadcast_shape(x_shape, y_shape, prim_name):
 def get_concat_offset(x_shp, x_type, axis, prim_name):
     """for concat and concatoffset check args and compute offset"""
     validator.check_value_type("shape", x_shp, [tuple], prim_name)
-    validator.check_integer("input_x rank", len(x_shp), 0, Rel.GT, prim_name)
+    validator.check_positive_int(len(x_shp), "input_x rank", prim_name)
     validator.check_subclass("shape0", x_type[0], mstype.tensor, prim_name)
-    validator.check_integer("len of x_shp[0]", len(x_shp[0]), 0, Rel.GT, prim_name)
+    validator.check_positive_int(len(x_shp[0]), "len of x_shp[0]", prim_name)
     rank_base = len(x_shp[0])
     validator.check_int_range('axis', axis, -rank_base - 1, rank_base, Rel.INC_BOTH, prim_name)
     if axis < 0:

View File

@@ -12,8 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ============================================================================
 """constexpr util"""
 from functools import reduce
 import numpy as np
@@ -60,30 +60,6 @@ def check_equal(param1, param2, msg="{},{}"):
     return param1
 
-
-@constexpr
-def check_int_positive(arg_name, arg_value, op_name):
-    """Int type judgment."""
-    if isinstance(arg_value, bool):
-        raise TypeError("For \'{}\' the `{}` must be int, cannot be bool.".format(op_name, arg_name))
-    if isinstance(arg_value, int):
-        if arg_value > 0:
-            return arg_value
-        raise ValueError("For \'{}\' the `{}` must be positive, but got {}.".format(op_name, arg_name, arg_value))
-    raise TypeError("For \'{}\' the `{}` must be int, cannot be {}.".format(op_name, arg_name, type(arg_value)))
-
-
-@constexpr
-def check_int_non_negative(arg_name, arg_value, op_name):
-    """Int type judgment."""
-    if isinstance(arg_value, bool):
-        raise TypeError("For \'{}\' the `{}` must be int, cannot be bool.".format(op_name, arg_name))
-    if isinstance(arg_value, int):
-        if arg_value >= 0:
-            return arg_value
-        raise ValueError("For \'{}\' the `{}` must be non_negative, but got {}.".format(op_name, arg_name, arg_value))
-    raise TypeError("For \'{}\' the `{}` must be int, cannot be {}.".format(op_name, arg_name, type(arg_value)))
-
-
 @constexpr
 def check_ellipsis_shape_size(data_shape, value_shape, data_size, value_size):
     """Checks the shape and size of the sensor and value."""

View File

@@ -12,9 +12,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ============================================================================
 """Operations for random number generators."""
+from mindspore._checkparam import Validator
 from .. import operations as P
 from .. import functional as F
 from ..primitive import constexpr
@@ -54,7 +54,7 @@ def get_seed(op_seed, kernel_name):
     if op_seed is None:
         temp_seed = _get_op_seed(0, kernel_name)
     else:
-        const_utils.check_int_non_negative("seed", op_seed, kernel_name)
+        Validator.check_non_negative_int(op_seed, "seed", kernel_name)
         temp_seed = _get_op_seed(op_seed, kernel_name)
     seeds = _truncate_seed(global_seed), _truncate_seed(temp_seed)
     _update_seeds(op_seed, kernel_name)
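
A hedged illustration of the seed check above (the op name below is a placeholder, not taken from this file): `check_non_negative_int` calls `_check_integer` with `Rel.GE` against 0, so a negative seed raises ValueError and a bool or non-int seed raises TypeError before any kernel runs.

    from mindspore._checkparam import Validator

    Validator.check_non_negative_int(5, "seed", "StandardNormal")      # returns 5
    # Validator.check_non_negative_int(-1, "seed", "StandardNormal")   # would raise ValueError
    # Validator.check_non_negative_int(True, "seed", "StandardNormal") # would raise TypeError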

View File

@@ -915,9 +915,9 @@ class LSTMGradData(PrimitiveWithInfer):
     @prim_attr_register
     def __init__(self, input_size, hidden_size, num_layers, has_bias, bidirectional, dropout):
-        self.input_size = validator.check_integer('input_size', input_size, 0, Rel.GT, self.name)
-        self.hidden_size = validator.check_integer('hidden_size', hidden_size, 0, Rel.GT, self.name)
-        self.num_layers = validator.check_integer('num_layers', num_layers, 0, Rel.GT, self.name)
+        self.input_size = validator.check_positive_int(input_size, 'input_size', self.name)
+        self.hidden_size = validator.check_positive_int(hidden_size, 'hidden_size', self.name)
+        self.num_layers = validator.check_positive_int(num_layers, 'num_layers', self.name)
         self.has_bias = validator.check_value_type('has_bias', has_bias, (bool,), self.name)
         self.bidirectional = validator.check_value_type('bidirectional', bidirectional, (bool,), self.name)
         self.dropout = validator.check_value_type("dropout", dropout, [float], self.name)
@@ -964,9 +964,9 @@ class LSTMGradWeight(PrimitiveWithInfer):
     @prim_attr_register
     def __init__(self, input_size, hidden_size, num_layers, has_bias, bidirectional, dropout):
-        self.input_size = validator.check_integer('input_size', input_size, 0, Rel.GT, self.name)
-        self.hidden_size = validator.check_integer('hidden_size', hidden_size, 0, Rel.GT, self.name)
-        self.num_layers = validator.check_integer('num_layers', num_layers, 0, Rel.GT, self.name)
+        self.input_size = validator.check_positive_int(input_size, 'input_size', self.name)
+        self.hidden_size = validator.check_positive_int(hidden_size, 'hidden_size', self.name)
+        self.num_layers = validator.check_positive_int(num_layers, 'num_layers', self.name)
         self.has_bias = validator.check_value_type('has_bias', has_bias, (bool,), self.name)
         self.bidirectional = validator.check_value_type('bidirectional', bidirectional, (bool,), self.name)
         self.dropout = validator.check_value_type("dropout", dropout, [float], self.name)
@@ -999,9 +999,9 @@ class LSTMGrad(PrimitiveWithInfer):
     @prim_attr_register
     def __init__(self, input_size, hidden_size, num_layers, has_bias, bidirectional, dropout):
-        self.input_size = validator.check_integer('input_size', input_size, 0, Rel.GT, self.name)
-        self.hidden_size = validator.check_integer('hidden_size', hidden_size, 0, Rel.GT, self.name)
-        self.num_layers = validator.check_integer('num_layers', num_layers, 0, Rel.GT, self.name)
+        self.input_size = validator.check_positive_int(input_size, 'input_size', self.name)
+        self.hidden_size = validator.check_positive_int(hidden_size, 'hidden_size', self.name)
+        self.num_layers = validator.check_positive_int(num_layers, 'num_layers', self.name)
         self.has_bias = validator.check_value_type('has_bias', has_bias, (bool,), self.name)
         self.bidirectional = validator.check_value_type('bidirectional', bidirectional, (bool,), self.name)
         self.dropout = validator.check_value_type("dropout", dropout, [float], self.name)

View File

@@ -701,7 +701,7 @@ class Padding(PrimitiveWithInfer):
     def __init__(self, pad_dim_size=8):
         """Initialize padding"""
         validator.check_value_type("pad_dim_size", pad_dim_size, [int], self.name)
-        validator.check_integer("pad_dim_size", pad_dim_size, 0, Rel.GT, self.name)
+        validator.check_positive_int(pad_dim_size, "pad_dim_size", self.name)
         self.pad_dim_size = pad_dim_size
 
     def __infer__(self, x):
@@ -911,8 +911,8 @@ class Fill(PrimitiveWithInfer):
     def __infer__(self, dtype, dims, x):
         validator.check_value_type("shape", dims['value'], [tuple], self.name)
        validator.check_value_type("value", x['value'], [numbers.Number, bool], self.name)
-        for idx, item in enumerate(dims['value']):
-            validator.check_integer("dims[%d]" % idx, item, 0, Rel.GT, self.name)
+        for i, item in enumerate(dims['value']):
+            validator.check_positive_int(item, f'dims[{i}]', self.name)
         valid_types = [mstype.bool_, mstype.int8, mstype.int16, mstype.int32, mstype.int64,
                        mstype.uint8, mstype.uint32, mstype.uint64,
                        mstype.float16, mstype.float32, mstype.float64]
@@ -1482,20 +1482,20 @@ class UnsortedSegmentSum(PrimitiveWithInfer):
         validator.check_subclass("input_x", x_type, mstype.tensor, self.name)
         validator.check_value_type("x_shape", x_shp, [list], self.name)
         x_shp_len = len(x_shp)
-        validator.check_integer("rank of input_x", x_shp_len, 0, Rel.GT, self.name)
+        validator.check_positive_int(x_shp_len, "rank of input_x", self.name)
         segment_ids_shp = segment_ids['shape']
         segment_ids_type = segment_ids['dtype']
         validator.check_subclass("segment_ids", segment_ids_type, mstype.tensor, self.name)
         validator.check_value_type("segment_ids", segment_ids_shp, [list], self.name)
         segment_ids_shp_len = len(segment_ids_shp)
-        validator.check_integer("rank of segment_ids", segment_ids_shp_len, 0, Rel.GT, self.name)
+        validator.check_positive_int(segment_ids_shp_len, "rank of segment_ids", self.name)
         validator.check(f'rank of input_x', len(x_shp),
                         'rank of segments_id', len(segment_ids_shp), Rel.GE, self.name)
         for i, value in enumerate(segment_ids_shp):
             validator.check("ids[%d]" % i, value, 'input[%d]' % i, x_shp[i], Rel.EQ, self.name)
         num_segments_v = num_segments['value']
         validator.check_value_type('num_segments', num_segments_v, [int], self.name)
-        validator.check_integer("num_segments", num_segments_v, 0, Rel.GT, self.name)
+        validator.check_positive_int(num_segments_v, "num_segments", self.name)
         shp = [num_segments_v]
         shp += x_shp[segment_ids_shp_len:]
         out = {'shape': shp,
@@ -1544,7 +1544,7 @@ class UnsortedSegmentMin(PrimitiveWithInfer):
                         'length of segments_id', segment_ids_shape[0], Rel.EQ, self.name)
         num_segments_v = num_segments['value']
         validator.check_value_type('num_segments', num_segments_v, [int], self.name)
-        validator.check_integer("num_segments", num_segments_v, 0, Rel.GT, self.name)
+        validator.check_positive_int(num_segments_v, "num_segments", self.name)
         segment_ids_shape_len = len(segment_ids_shape)
         out_shape = [num_segments_v]
         out_shape += x_shape[segment_ids_shape_len:]
@@ -1597,7 +1597,7 @@ class UnsortedSegmentProd(PrimitiveWithInfer):
                         'length of segments_id', segment_ids_shape[0], Rel.EQ, self.name)
         num_segments_v = num_segments['value']
         validator.check_value_type('num_segments', num_segments_v, [int], self.name)
-        validator.check_integer("num_segments", num_segments_v, 0, Rel.GT, self.name)
+        validator.check_positive_int(num_segments_v, "num_segments", self.name)
         segment_ids_shape_len = len(segment_ids_shape)
         out_shape = [num_segments_v]
         out_shape += x_shape[segment_ids_shape_len:]
@@ -1832,7 +1832,7 @@ class Unpack(PrimitiveWithInfer):
             self.axis = self.axis + dim
         output_num = x_shape[self.axis]
         validator.check_value_type("num", output_num, [int], self.name)
-        validator.check_integer("output_num", output_num, 0, Rel.GT, self.name)
+        validator.check_positive_int(output_num, "output_num", self.name)
         self.add_prim_attr('num', output_num)
         output_valid_check = x_shape[self.axis] - output_num
         validator.check_integer("The dimension which to unpack divides output_num", output_valid_check, 0, Rel.EQ,
@@ -2401,8 +2401,8 @@ class Eye(PrimitiveWithInfer):
         """Initialize Eye"""
 
     def infer_value(self, n, m, t):
-        validator.check_integer("n", n, 0, Rel.GT, self.name)
-        validator.check_integer("m", m, 0, Rel.GT, self.name)
+        validator.check_positive_int(n, "n", self.name)
+        validator.check_positive_int(m, "m", self.name)
         args = {"dtype": t}
         validator.check_type_same(args, mstype.number_type + (mstype.bool_,), self.name)
         np_type = mstype.dtype_to_nptype(t)
@@ -2443,7 +2443,7 @@ class ScatterNd(PrimitiveWithInfer):
         validator.check_tensor_type_same({"indices": indices['dtype']}, [mstype.int32], self.name)
         validator.check_value_type("shape", shp, [tuple], self.name)
         for i, x in enumerate(shp):
-            validator.check_integer("shape[%d]" % i, x, 0, Rel.GT, self.name)
+            validator.check_positive_int(x, f'shape[{i}]', self.name)
         indices_shape, update_shape = indices["shape"], update["shape"]
         if indices_shape[0] != update_shape[0]:
@@ -3469,7 +3469,7 @@ class BroadcastTo(PrimitiveWithInfer):
         validator.check_value_type("shape", shape, (tuple), self.name)
         validator.check("shape length", len(shape), "", 0, Rel.GT, self.name)
         for i in shape:
-            validator.check_integer("shape element", i, 0, Rel.GT, self.name)
+            validator.check_positive_int(i, "shape element", self.name)
         self.shape = shape
 
     def infer_shape(self, x_shape):

View File

@@ -160,7 +160,7 @@ class AllGather(PrimitiveWithInfer):
         self.add_prim_attr('group', _get_group(group))
 
     def infer_shape(self, x_shape):
-        validator.check_integer("x shape", len(x_shape), 0, Rel.GT, self.name)
+        validator.check_positive_int(len(x_shape), "x shape", self.name)
         x_shape[0] = x_shape[0] * self.rank_size
         return x_shape
@@ -210,7 +210,7 @@ class _HostAllGather(PrimitiveWithInfer):
         self.add_prim_attr('group', group)
 
     def infer_shape(self, x_shape):
-        validator.check_integer("x shape", len(x_shape), 0, Rel.GT, self.name)
+        validator.check_positive_int(len(x_shape), "x shape", self.name)
         x_shape[0] = x_shape[0] * self.group_size
         return x_shape

View File

@ -1005,8 +1005,8 @@ class Conv2D(PrimitiveWithInfer):
self.mode = validator.check_integer('mode', mode, 1, Rel.EQ, self.name) self.mode = validator.check_integer('mode', mode, 1, Rel.EQ, self.name)
self.add_prim_attr('data_format', "NCHW") self.add_prim_attr('data_format', "NCHW")
self.out_channel = validator.check_integer('out_channel', out_channel, 0, Rel.GT, self.name) self.out_channel = validator.check_positive_int(out_channel, 'out_channel', self.name)
self.group = validator.check_integer('group', group, 0, Rel.GT, self.name) self.group = validator.check_positive_int(group, 'group', self.name)
self.add_prim_attr('offset_a', 0) self.add_prim_attr('offset_a', 0)
def infer_shape(self, x_shape, w_shape, b_shape=None): def infer_shape(self, x_shape, w_shape, b_shape=None):
@ -1142,9 +1142,8 @@ class DepthwiseConv2dNative(PrimitiveWithInfer):
validator.check_integer('pad item', item, 0, Rel.GE, self.name) validator.check_integer('pad item', item, 0, Rel.GE, self.name)
         self.mode = validator.check_integer("mode", mode, 3, Rel.EQ, self.name)
         self.add_prim_attr('data_format', "NCHW")
-        self.channel_multiplier = validator.check_integer("channel_multiplier", channel_multiplier, 0, Rel.GT,
-                                                          self.name)
-        self.group = validator.check_integer("group", group, 0, Rel.GT, self.name)
+        self.channel_multiplier = validator.check_positive_int(channel_multiplier, "channel_multiplier", self.name)
+        self.group = validator.check_positive_int(group, "group", self.name)
         self.add_prim_attr('offset_a', 0)

     def infer_shape(self, x_shape, w_shape, b_shape=None):
@@ -1508,7 +1507,7 @@ class Conv2DBackpropInput(PrimitiveWithInfer):
                  group=1):
         """Initialize Conv2DBackpropInput"""
         self.init_prim_io_names(inputs=['out_backprop', 'filter', 'input_sizes'], outputs=['output'])
-        self.out_channel = validator.check_integer('out_channel', out_channel, 0, Rel.GT, self.name)
+        self.out_channel = validator.check_positive_int(out_channel, 'out_channel', self.name)
         self.kernel_size = _check_positive_int_or_tuple('kernel_size', kernel_size, self.name)
         self.stride = _check_positive_int_or_tuple('stride', stride, self.name, allow_four=True, ret_four=False)
         self.add_prim_attr('stride', self.stride)
@@ -1531,7 +1530,7 @@ class Conv2DBackpropInput(PrimitiveWithInfer):
             pad_mode = pad_mode.upper()
         self.add_prim_attr('pad_mode', pad_mode)
         self.mode = validator.check_integer('mode', mode, 1, Rel.EQ, self.name)
-        self.group = validator.check_integer('group', group, 0, Rel.GT, self.name)
+        self.group = validator.check_positive_int(group, 'group', self.name)
         self.add_prim_attr('data_format', "NCHW")
         if pad_list:
             for x in pad_list:
@@ -2062,10 +2061,10 @@ class SGD(PrimitiveWithInfer):
     def infer_shape(self, parameters_shape, gradient_shape, learning_rate_shape,
                     accum_shape, momentum_shape, stat_shape):
-        validator.check_integer(f'parameters rank', len(parameters_shape), 0, Rel.GT, self.name)
+        validator.check_positive_int(len(parameters_shape), "parameters rank", self.name)
         validator.check_integer(f'gradient rank', len(gradient_shape), 0, Rel.GE, self.name)
         validator.check_integer(f'learning rate rank', len(learning_rate_shape), 0, Rel.GE, self.name)
-        validator.check_integer(f'accumulation rank', len(accum_shape), 0, Rel.GT, self.name)
+        validator.check_positive_int(len(accum_shape), "accumulation rank", self.name)
         validator.check_integer(f'momentum rank', len(momentum_shape), 0, Rel.GE, self.name)
         validator.check_integer(f'stat rank', len(stat_shape), 0, Rel.GE, self.name)
         validator.check("gradient shape", gradient_shape, "stat shape", stat_shape, Rel.EQ, self.name)
@@ -2748,9 +2747,9 @@ class LSTM(PrimitiveWithInfer):
     @prim_attr_register
     def __init__(self, input_size, hidden_size, num_layers, has_bias, bidirectional, dropout):
-        self.input_size = validator.check_integer("input_size", input_size, 0, Rel.GT, self.name)
-        self.hidden_size = validator.check_integer("hidden_size", hidden_size, 0, Rel.GT, self.name)
-        self.num_layers = validator.check_integer("num_layers", num_layers, 0, Rel.GT, self.name)
+        self.input_size = validator.check_positive_int(input_size, "input_size", self.name)
+        self.hidden_size = validator.check_positive_int(hidden_size, "hidden_size", self.name)
+        self.num_layers = validator.check_positive_int(num_layers, "num_layers", self.name)
         self.has_bias = validator.check_value_type("has_bias", has_bias, (bool,), self.name)
         self.bidirectional = validator.check_value_type("bidirectional", bidirectional, (bool,), self.name)
         self.dropout = validator.check_value_type("dropout", dropout, [float], self.name)
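
The same call-site rewrite recurs throughout this commit. A minimal sketch of the behaviour the new helper implies, inferred from how it is used in this patch (error types and return value are assumptions, not verified against the library):

    from mindspore._checkparam import Validator as validator

    group = validator.check_positive_int(4, "group", "Conv2D")    # passes and returns 4
    # validator.check_positive_int(0, "group", "Conv2D")          # 0 is not > 0: would raise ValueError
    # validator.check_positive_int(2.5, "group", "Conv2D")        # not an int: would raise TypeError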

View File

@@ -12,11 +12,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ============================================================================
 """Operators for random."""
-from ..._checkparam import Validator as validator
-from ..._checkparam import Rel
+from ..._checkparam import Validator, Rel
 from ...common import dtype as mstype
 from ..primitive import PrimitiveWithInfer, prim_attr_register
 from .._utils import get_broadcast_shape
@@ -46,16 +44,16 @@ class StandardNormal(PrimitiveWithInfer):
     def __init__(self, seed=0, seed2=0):
         """Initialize StandardNormal"""
         self.init_prim_io_names(inputs=['shape'], outputs=['output'])
-        validator.check_integer("seed", seed, 0, Rel.GE, self.name)
-        validator.check_integer("seed2", seed2, 0, Rel.GE, self.name)
+        Validator.check_integer("seed", seed, 0, Rel.GE, self.name)
+        Validator.check_integer("seed2", seed2, 0, Rel.GE, self.name)

     def __infer__(self, shape):
         shape_v = shape["value"]
         if shape_v is None:
             raise ValueError(f"For {self.name}, shape must be const.")
-        validator.check_value_type("shape", shape_v, [tuple], self.name)
+        Validator.check_value_type("shape", shape_v, [tuple], self.name)
         for i, shape_i in enumerate(shape_v):
-            validator.check_integer("shape[%d]" % i, shape_i, 0, Rel.GT, self.name)
+            Validator.check_positive_int(shape_i, f'shape[{i}]', self.name)
         out = {
             'shape': shape_v,
             'dtype': mstype.float32,
@@ -91,16 +89,16 @@ class StandardLaplace(PrimitiveWithInfer):
     def __init__(self, seed=0, seed2=0):
         """Initialize StandardLaplace"""
         self.init_prim_io_names(inputs=['shape'], outputs=['output'])
-        validator.check_value_type('seed', seed, [int], self.name)
-        validator.check_value_type('seed2', seed2, [int], self.name)
+        Validator.check_value_type('seed', seed, [int], self.name)
+        Validator.check_value_type('seed2', seed2, [int], self.name)

     def __infer__(self, shape):
         shape_v = shape["value"]
         if shape_v is None:
             raise ValueError(f"For {self.name}, shape must be const.")
-        validator.check_value_type("shape", shape_v, [tuple], self.name)
+        Validator.check_value_type("shape", shape_v, [tuple], self.name)
         for i, shape_i in enumerate(shape_v):
-            validator.check_integer("shape[%d]" % i, shape_i, 0, Rel.GT, self.name)
+            Validator.check_positive_int(shape_i, f'shape[{i}]', self.name)
         out = {
             'shape': shape_v,
             'dtype': mstype.float32,
@@ -143,18 +141,18 @@ class Gamma(PrimitiveWithInfer):
     def __init__(self, seed=0, seed2=0):
         """Initialize Gamma"""
         self.init_prim_io_names(inputs=['shape', 'alpha', 'beta'], outputs=['output'])
-        validator.check_integer("seed", seed, 0, Rel.GE, self.name)
-        validator.check_integer("seed2", seed2, 0, Rel.GE, self.name)
+        Validator.check_integer("seed", seed, 0, Rel.GE, self.name)
+        Validator.check_integer("seed2", seed2, 0, Rel.GE, self.name)

     def __infer__(self, shape, alpha, beta):
         shape_v = shape["value"]
         if shape_v is None:
             raise ValueError(f"For {self.name}, shape must be const.")
-        validator.check_value_type("shape", shape_v, [tuple], self.name)
+        Validator.check_value_type("shape", shape_v, [tuple], self.name)
         for i, shape_i in enumerate(shape_v):
-            validator.check_integer("shape[%d]" % i, shape_i, 0, Rel.GT, self.name)
-        validator.check_tensor_type_same({"alpha": alpha["dtype"]}, [mstype.float32], self.name)
-        validator.check_tensor_type_same({"beta": beta["dtype"]}, [mstype.float32], self.name)
+            Validator.check_positive_int(shape_i, f'shape[{i}]', self.name)
+        Validator.check_tensor_type_same({"alpha": alpha["dtype"]}, [mstype.float32], self.name)
+        Validator.check_tensor_type_same({"beta": beta["dtype"]}, [mstype.float32], self.name)
         broadcast_shape = get_broadcast_shape(alpha['shape'], beta['shape'], self.name)
         broadcast_shape = get_broadcast_shape(broadcast_shape, shape_v, self.name)
         out = {
@@ -195,17 +193,17 @@ class Poisson(PrimitiveWithInfer):
     def __init__(self, seed=0, seed2=0):
         """Initialize Poisson"""
         self.init_prim_io_names(inputs=['shape', 'mean'], outputs=['output'])
-        validator.check_integer("seed", seed, 0, Rel.GE, self.name)
-        validator.check_integer("seed2", seed2, 0, Rel.GE, self.name)
+        Validator.check_integer("seed", seed, 0, Rel.GE, self.name)
+        Validator.check_integer("seed2", seed2, 0, Rel.GE, self.name)

     def __infer__(self, shape, mean):
         shape_v = shape["value"]
         if shape_v is None:
             raise ValueError(f"For {self.name}, shape must be const.")
-        validator.check_value_type("shape", shape_v, [tuple], self.name)
+        Validator.check_value_type("shape", shape_v, [tuple], self.name)
         for i, shape_i in enumerate(shape_v):
-            validator.check_integer("shape[%d]" % i, shape_i, 0, Rel.GT, self.name)
-        validator.check_tensor_type_same({"mean": mean["dtype"]}, [mstype.float32], self.name)
+            Validator.check_positive_int(shape_i, f'shape[{i}]', self.name)
+        Validator.check_tensor_type_same({"mean": mean["dtype"]}, [mstype.float32], self.name)
         broadcast_shape = get_broadcast_shape(mean['shape'], shape_v, self.name)
         out = {
             'shape': broadcast_shape,
@@ -251,22 +249,22 @@ class UniformInt(PrimitiveWithInfer):
     def __init__(self, seed=0, seed2=0):
         """Initialize UniformInt"""
         self.init_prim_io_names(inputs=['shape', 'minval', 'maxval'], outputs=['output'])
-        validator.check_integer("seed", seed, 0, Rel.GE, self.name)
-        validator.check_integer("seed2", seed2, 0, Rel.GE, self.name)
+        Validator.check_integer("seed", seed, 0, Rel.GE, self.name)
+        Validator.check_integer("seed2", seed2, 0, Rel.GE, self.name)

     def __infer__(self, shape, minval, maxval):
         shape_v = shape["value"]
         if shape_v is None:
             raise ValueError(f"For {self.name}, shape must be const.")
-        validator.check_value_type("shape", shape_v, [tuple], self.name)
+        Validator.check_value_type("shape", shape_v, [tuple], self.name)
         for i, shape_i in enumerate(shape_v):
-            validator.check_integer("shape[%d]" % i, shape_i, 0, Rel.GT, self.name)
-        validator.check_tensor_type_same({"minval": minval["dtype"]}, [mstype.int32], self.name)
-        validator.check_tensor_type_same({"maxval": maxval["dtype"]}, [mstype.int32], self.name)
+            Validator.check_positive_int(shape_i, f'shape[{i}]', self.name)
+        Validator.check_tensor_type_same({"minval": minval["dtype"]}, [mstype.int32], self.name)
+        Validator.check_tensor_type_same({"maxval": maxval["dtype"]}, [mstype.int32], self.name)
         minval_shape = minval['shape']
         maxval_shape = maxval['shape']
-        validator.check("dim of minval", len(minval_shape), '0(scalar)', 0, Rel.EQ, self.name)
-        validator.check("dim of maxval", len(maxval_shape), '0(scalar)', 0, Rel.EQ, self.name)
+        Validator.check("dim of minval", len(minval_shape), '0(scalar)', 0, Rel.EQ, self.name)
+        Validator.check("dim of maxval", len(maxval_shape), '0(scalar)', 0, Rel.EQ, self.name)
         out = {
             'shape': shape_v,
             'dtype': mstype.int32,
@@ -298,16 +296,16 @@ class UniformReal(PrimitiveWithInfer):
     def __init__(self, seed=0, seed2=0):
         """Initialize UniformReal"""
         self.init_prim_io_names(inputs=['shape'], outputs=['output'])
-        validator.check_integer("seed", seed, 0, Rel.GE, self.name)
-        validator.check_integer("seed2", seed2, 0, Rel.GE, self.name)
+        Validator.check_integer("seed", seed, 0, Rel.GE, self.name)
+        Validator.check_integer("seed2", seed2, 0, Rel.GE, self.name)

     def __infer__(self, shape):
         shape_v = shape["value"]
         if shape_v is None:
             raise ValueError(f"For {self.name}, shape must be const.")
-        validator.check_value_type("shape", shape_v, [tuple], self.name)
+        Validator.check_value_type("shape", shape_v, [tuple], self.name)
         for i, shape_i in enumerate(shape_v):
-            validator.check_integer("shape[%d]" % i, shape_i, 0, Rel.GT, self.name)
+            Validator.check_positive_int(shape_i, f'shape[{i}]', self.name)
         out = {
             'shape': shape_v,
             'dtype': mstype.float32,
@@ -348,18 +346,18 @@ class RandomChoiceWithMask(PrimitiveWithInfer):
     @prim_attr_register
     def __init__(self, count=256, seed=0, seed2=0):
         """Initialize RandomChoiceWithMask"""
-        validator.check_value_type("count", count, [int], self.name)
-        validator.check_integer("count", count, 0, Rel.GT, self.name)
-        validator.check_value_type('seed', seed, [int], self.name)
-        validator.check_value_type('seed2', seed2, [int], self.name)
+        Validator.check_value_type("count", count, [int], self.name)
+        Validator.check_positive_int(count, "count", self.name)
+        Validator.check_value_type('seed', seed, [int], self.name)
+        Validator.check_value_type('seed2', seed2, [int], self.name)

     def infer_shape(self, x_shape):
-        validator.check_integer("input_x rank", len(x_shape), 1, Rel.GE, self.name)
-        validator.check_integer("input_x rank", len(x_shape), 5, Rel.LE, self.name)
+        Validator.check_integer("input_x rank", len(x_shape), 1, Rel.GE, self.name)
+        Validator.check_integer("input_x rank", len(x_shape), 5, Rel.LE, self.name)
         return ([self.count, len(x_shape)], [self.count])

     def infer_dtype(self, x_dtype):
-        validator.check_tensor_type_same({'x': x_dtype}, [mstype.bool_], self.name)
+        Validator.check_tensor_type_same({'x': x_dtype}, [mstype.bool_], self.name)
         return (mstype.int32, mstype.bool_)
@@ -399,19 +397,19 @@ class RandomCategorical(PrimitiveWithInfer):
         self.dtype = dtype
         valid_values = (mstype.int32, mstype.int16, mstype.int64)
-        validator.check_type_name("dtype", dtype, valid_values, self.name)
+        Validator.check_type_name("dtype", dtype, valid_values, self.name)
         self.init_prim_io_names(inputs=['logits', 'num_samples', 'seed'],
                                 outputs=['output'])

     def __infer__(self, logits, num_samples, seed):
         logits_dtype = logits['dtype']
         valid_types = (mstype.float32, mstype.float16, mstype.float64)
-        validator.check_tensor_type_same({'logits': logits_dtype}, valid_types, self.name)
+        Validator.check_tensor_type_same({'logits': logits_dtype}, valid_types, self.name)
         num_samples_v = num_samples['value']
         seed_v = seed['value']
-        validator.check_value_type('num_samples', num_samples_v, (int,), self.name)
-        validator.check_value_type('seed', seed_v, (int,), self.name)
-        validator.check_integer("num_samples", num_samples_v, 0, Rel.GT, self.name)
+        Validator.check_value_type('num_samples', num_samples_v, (int,), self.name)
+        Validator.check_value_type('seed', seed_v, (int,), self.name)
+        Validator.check_positive_int(num_samples_v, "num_samples", self.name)
         x_shape = list(logits['shape'])
         if len(x_shape) != 2:
             raise ValueError("RandomCategorical shape should be 2-dimension.")
@@ -450,20 +448,20 @@ class Multinomial(PrimitiveWithInfer):
     @prim_attr_register
     def __init__(self, seed=0):
        """init"""
-        validator.check_value_type("seed", seed, [int], self.name)
-        validator.check_integer("seed", seed, 0, Rel.GE, self.name)
+        Validator.check_value_type("seed", seed, [int], self.name)
+        Validator.check_integer("seed", seed, 0, Rel.GE, self.name)
         self.init_prim_io_names(inputs=['input', 'num_sample'], outputs=['output'])

     def __infer__(self, inputs, num_samples):
         input_shape = inputs["shape"]
         if len(input_shape) != 1 and len(input_shape) != 2:
             raise ValueError("input dim must be 1 or 2")
-        validator.check_tensor_type_same({'inputs': inputs['dtype']}, [mstype.float32], self.name)
+        Validator.check_tensor_type_same({'inputs': inputs['dtype']}, [mstype.float32], self.name)
         num_samples_value = num_samples["value"]
         if num_samples_value is None:
             raise ValueError(f"For {self.name}, shape nust be const")
-        validator.check_value_type("num_samples", num_samples_value, (int,), self.name)
-        validator.check_integer("num_samples", num_samples_value, 0, Rel.GT, None)
+        Validator.check_value_type("num_samples", num_samples_value, (int,), self.name)
+        Validator.check_positive_int(num_samples_value, "num_samples")
         y_shape = (num_samples_value,)
         if len(input_shape) == 2:
             y_shape = (input_shape[0], num_samples_value)
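
Besides the check_positive_int migration, this file drops the module-level alias (Validator as validator) and calls the class directly, so every call site is now Validator.*. A hedged standalone sketch of the validation pattern these __infer__ methods share (prim_name and values are illustrative):

    from mindspore._checkparam import Validator, Rel

    prim_name = "StandardNormal"
    seed, shape_v = 0, (2, 3)
    Validator.check_integer("seed", seed, 0, Rel.GE, prim_name)       # seed >= 0
    Validator.check_value_type("shape", shape_v, [tuple], prim_name)  # shape must be a tuple
    for i, dim in enumerate(shape_v):
        Validator.check_positive_int(dim, f'shape[{i}]', prim_name)   # every dimension > 0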

View File

@@ -21,7 +21,7 @@ import time
 import threading
 import mindspore.context as context
 from mindspore import log as logger
-from mindspore._checkparam import Validator, check_int_non_negative
+from mindspore._checkparam import Validator
 from mindspore.train._utils import _make_directory
 from mindspore.train.serialization import save_checkpoint, _save_graph
 from mindspore.parallel._ps_context import _is_role_pserver, _get_ps_mode_rank
@@ -107,13 +107,13 @@ class CheckpointConfig:
                  async_save=False):
         if save_checkpoint_steps is not None:
-            save_checkpoint_steps = check_int_non_negative(save_checkpoint_steps)
+            save_checkpoint_steps = Validator.check_non_negative_int(save_checkpoint_steps)
         if save_checkpoint_seconds is not None:
-            save_checkpoint_seconds = check_int_non_negative(save_checkpoint_seconds)
+            save_checkpoint_seconds = Validator.check_non_negative_int(save_checkpoint_seconds)
         if keep_checkpoint_max is not None:
-            keep_checkpoint_max = check_int_non_negative(keep_checkpoint_max)
+            keep_checkpoint_max = Validator.check_non_negative_int(keep_checkpoint_max)
         if keep_checkpoint_per_n_minutes is not None:
-            keep_checkpoint_per_n_minutes = check_int_non_negative(keep_checkpoint_per_n_minutes)
+            keep_checkpoint_per_n_minutes = Validator.check_non_negative_int(keep_checkpoint_per_n_minutes)
         if not save_checkpoint_steps and not save_checkpoint_seconds and \
                 not keep_checkpoint_max and not keep_checkpoint_per_n_minutes:
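
CheckpointConfig uses the non-negative variant here, so 0 remains an accepted value for these intervals. A hedged usage sketch; the public import path is assumed from context, not shown in this patch:

    from mindspore.train.callback import CheckpointConfig   # path assumed

    cfg = CheckpointConfig(save_checkpoint_steps=100, keep_checkpoint_max=5)   # accepted
    # CheckpointConfig(save_checkpoint_steps=-1) would now fail inside
    # Validator.check_non_negative_int rather than the removed check_int_non_negative.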

View File

@@ -13,8 +13,8 @@
 # limitations under the License.
 # ============================================================================
 """Loss scale manager abstract class."""
 from .._checkparam import Validator as validator
-from .._checkparam import Rel
 from .. import nn

 __all__ = ["LossScaleManager", "FixedLossScaleManager", "DynamicLossScaleManager"]
@@ -97,7 +97,7 @@ class DynamicLossScaleManager(LossScaleManager):
         if init_loss_scale < 1.0:
             raise ValueError("Loss scale value should be > 1")
         self.loss_scale = init_loss_scale
-        validator.check_integer("scale_window", scale_window, 0, Rel.GT, self.__class__.__name__)
+        validator.check_positive_int(scale_window, "scale_window", self.__class__.__name__)
         self.scale_window = scale_window
         if scale_factor <= 0:
             raise ValueError("Scale factor should be > 1")

View File

@@ -22,7 +22,7 @@ import numpy as np
 from mindspore import log as logger
 from ..common.tensor import Tensor
 from ..nn.metrics import get_metrics
-from .._checkparam import check_input_data, check_output_data, check_int_positive, Validator, check_int
+from .._checkparam import check_input_data, check_output_data, Validator, check_int
 from .callback import _InternalCallbackParam, RunContext, _CallbackManager
 from .. import context
 from ..parallel._utils import _get_parallel_mode, _get_device_num, _get_global_rank, \
@@ -339,7 +339,7 @@ class Model:
                 dataset not sink.
             sink_size (int): Control the amount of data in each sink. Default: -1.
         """
-        epoch = check_int_positive(epoch)
+        epoch = Validator.check_positive_int(epoch)
         if self._parameter_broadcast:
             self._train_network.set_broadcast_flag()

View File

@@ -16,7 +16,7 @@
 import numpy as np
 import mindspore.common.dtype as mstype
-from mindspore._checkparam import Validator, twice, check_int_positive
+from mindspore._checkparam import Validator, twice
 from mindspore._extends import cell_attr_register
 from mindspore.common.initializer import initializer
 from mindspore.common.parameter import Parameter
@@ -292,8 +292,8 @@ class Dense_Thor_GPU(Cell):
                  has_bias=True,
                  activation=None):
         super(Dense_Thor_GPU, self).__init__()
-        self.in_channels = check_int_positive(in_channels)
-        self.out_channels = check_int_positive(out_channels)
+        self.in_channels = Validator.check_positive_int(in_channels)
+        self.out_channels = Validator.check_positive_int(out_channels)
         self.has_bias = Validator.check_bool(has_bias)
         self.thor = True
         if isinstance(weight_init, Tensor):
@@ -641,8 +641,8 @@ class Dense_Thor(Cell):
                  has_bias=True,
                  activation=None):
         super(Dense_Thor, self).__init__()
-        self.in_channels = check_int_positive(in_channels)
-        self.out_channels = check_int_positive(out_channels)
+        self.in_channels = Validator.check_positive_int(in_channels)
+        self.out_channels = Validator.check_positive_int(out_channels)
         self.has_bias = Validator.check_bool(has_bias)
         self.thor = True
         self.batch_size = batch_size

View File

@@ -19,7 +19,7 @@ from mindspore.ops import functional as F
 from mindspore._extends import cell_attr_register
 from mindspore import Tensor, Parameter
 from mindspore.common.initializer import initializer
-from mindspore._checkparam import check_int_positive, Validator
+from mindspore._checkparam import Validator
 from mindspore.nn.layer.activation import get_activation
@@ -72,8 +72,8 @@ class GNNFeatureTransform(nn.Cell):
                  bias_init='zeros',
                  has_bias=True):
         super(GNNFeatureTransform, self).__init__()
-        self.in_channels = check_int_positive(in_channels)
-        self.out_channels = check_int_positive(out_channels)
+        self.in_channels = Validator.check_positive_int(in_channels)
+        self.out_channels = Validator.check_positive_int(out_channels)
         self.has_bias = Validator.check_bool(has_bias)
         if isinstance(weight_init, Tensor):
@@ -259,8 +259,8 @@ class AttentionHead(nn.Cell):
                  coef_activation=nn.LeakyReLU(),
                  activation=nn.ELU()):
         super(AttentionHead, self).__init__()
-        self.in_channel = check_int_positive(in_channel)
-        self.out_channel = check_int_positive(out_channel)
+        self.in_channel = Validator.check_positive_int(in_channel)
+        self.out_channel = Validator.check_positive_int(out_channel)
         self.in_drop_ratio = in_drop_ratio
         self.in_drop = nn.Dropout(keep_prob=1 - in_drop_ratio)
         self.in_drop_2 = nn.Dropout(keep_prob=1 - in_drop_ratio)
@@ -450,9 +450,9 @@ class GAT(nn.Cell):
         super(GAT, self).__init__()
         self.features = Tensor(features)
         self.biases = Tensor(biases)
-        self.ftr_dims = check_int_positive(ftr_dims)
-        self.num_class = check_int_positive(num_class)
-        self.num_nodes = check_int_positive(num_nodes)
+        self.ftr_dims = Validator.check_positive_int(ftr_dims)
+        self.num_class = Validator.check_positive_int(num_class)
+        self.num_nodes = Validator.check_positive_int(num_nodes)
         self.hidden_units = hidden_units
         self.num_heads = num_heads
         self.attn_drop = attn_drop

View File

@@ -22,7 +22,7 @@ from mindspore._c_expression import init_exec_dataset
 from mindspore import context
 from mindspore import log as logger
 from mindspore import nn
-from mindspore._checkparam import check_input_data, check_output_data, check_int_positive, Validator, check_int
+from mindspore._checkparam import check_input_data, check_output_data, Validator, check_int
 from mindspore.common import dtype as mstype
 from mindspore.common.dtype import pytype_to_dtype
 from mindspore.common.tensor import Tensor
@@ -374,7 +374,7 @@ class Model:
                 dataset not sink.
             sink_size (int): Control the amount of data each sink. Default: -1.
         """
-        epoch = check_int_positive(epoch)
+        epoch = Validator.check_positive_int(epoch)
         self._train_network.set_train()
         if self._parameter_broadcast:

View File

@@ -15,7 +15,7 @@
 """thor_layer"""
 import numpy as np
 import mindspore.common.dtype as mstype
-from mindspore._checkparam import Validator, check_int_positive
+from mindspore._checkparam import Validator
 from mindspore.common.initializer import TruncatedNormal, initializer
 from mindspore.common.parameter import Parameter
 from mindspore.common.tensor import Tensor
@@ -160,8 +160,8 @@ class Dense_Thor(Cell):
                  activation=None,
                  batch_size=12):
         super(Dense_Thor, self).__init__()
-        self.in_channels = check_int_positive(in_channels)
-        self.out_channels = check_int_positive(out_channels)
+        self.in_channels = Validator.check_positive_int(in_channels)
+        self.out_channels = Validator.check_positive_int(out_channels)
         self.has_bias = Validator.check_bool(has_bias)
         self.thor = True
         if isinstance(weight_init, Tensor):

View File

@@ -15,7 +15,7 @@
 """Aggregator."""
 import mindspore.nn as nn
 from mindspore import Tensor, Parameter
-from mindspore._checkparam import check_int_positive, Validator
+from mindspore._checkparam import Validator
 from mindspore._extends import cell_attr_register
 from mindspore.common.initializer import initializer
 from mindspore.nn.layer.activation import get_activation
@@ -73,8 +73,8 @@ class GNNFeatureTransform(nn.Cell):
                  bias_init='zeros',
                  has_bias=True):
         super(GNNFeatureTransform, self).__init__()
-        self.in_channels = check_int_positive(in_channels)
-        self.out_channels = check_int_positive(out_channels)
+        self.in_channels = Validator.check_positive_int(in_channels)
+        self.out_channels = Validator.check_positive_int(out_channels)
         self.has_bias = Validator.check_bool(has_bias)
         if isinstance(weight_init, Tensor):
@@ -262,8 +262,8 @@ class AttentionHead(nn.Cell):
                  coef_activation=nn.LeakyReLU(),
                  activation=nn.ELU()):
         super(AttentionHead, self).__init__()
-        self.in_channel = check_int_positive(in_channel)
-        self.out_channel = check_int_positive(out_channel)
+        self.in_channel = Validator.check_positive_int(in_channel)
+        self.out_channel = Validator.check_positive_int(out_channel)
         self.in_drop_ratio = in_drop_ratio
         self.in_drop = nn.Dropout(keep_prob=1 - in_drop_ratio)
         self.in_drop_2 = nn.Dropout(keep_prob=1 - in_drop_ratio)

View File

@@ -14,7 +14,7 @@
 # ============================================================================
 """Graph Attention Networks."""
 import mindspore.nn as nn
-from mindspore._checkparam import Validator, check_int_positive
+from mindspore._checkparam import Validator
 from aggregator import AttentionAggregator
@@ -71,9 +71,9 @@ class GAT(nn.Cell):
                  activation=nn.ELU(),
                  residual=False):
         super(GAT, self).__init__()
-        self.ftr_dims = check_int_positive(ftr_dims)
-        self.num_class = check_int_positive(num_class)
-        self.num_nodes = check_int_positive(num_nodes)
+        self.ftr_dims = Validator.check_positive_int(ftr_dims)
+        self.num_class = Validator.check_positive_int(num_class)
+        self.num_nodes = Validator.check_positive_int(num_nodes)
         self.hidden_units = hidden_units
         self.num_heads = num_heads
         self.attn_drop = attn_drop

View File

@@ -19,7 +19,7 @@ from mindspore import context
 from mindspore import log as logger
 from mindspore import nn
 from mindspore._c_expression import init_exec_dataset
-from mindspore._checkparam import check_input_data, check_output_data, check_int_positive, Validator
+from mindspore._checkparam import check_input_data, check_output_data, Validator
 from mindspore.common import dtype as mstype
 from mindspore.common.dtype import pytype_to_dtype
 from mindspore.common.tensor import Tensor
@@ -377,7 +377,7 @@ class Model:
             Configure pynative mode, the training process will be performed with
                 dataset not sink.
         """
-        epoch = check_int_positive(epoch)
+        epoch = Validator.check_positive_int(epoch)
         self._train_network.set_train()
         if self._parameter_broadcast:

View File

@@ -16,7 +16,7 @@
 import numpy as np
 import mindspore as ms
 import mindspore.common.dtype as mstype
-from mindspore._checkparam import Validator, twice, check_int_positive
+from mindspore._checkparam import Validator, twice
 from mindspore._extends import cell_attr_register
 from mindspore.common.initializer import initializer
 from mindspore.common.parameter import Parameter
@@ -337,8 +337,8 @@ class Dense_Thor(Cell):
                  has_bias=True,
                  activation=None):
         super(Dense_Thor, self).__init__()
-        self.in_channels = check_int_positive(in_channels)
-        self.out_channels = check_int_positive(out_channels)
+        self.in_channels = Validator.check_positive_int(in_channels)
+        self.out_channels = Validator.check_positive_int(out_channels)
         self.has_bias = Validator.check_bool(has_bias)
         self.thor = True
         if isinstance(weight_init, Tensor):

View File

@@ -15,8 +15,7 @@
 """ test checkparameter """
 import pytest
-from mindspore._checkparam import check_int, check_int_positive, \
-    check_input_format, Validator, twice
+from mindspore._checkparam import check_int, check_input_format, Validator, twice

 kernel_size = 5
 kernel_size1 = twice(kernel_size)
@@ -29,7 +28,7 @@ def test_check_int_1():
 def check_int_positive_1():
     with pytest.raises(ValueError):
-        check_int_positive(-1)
+        Validator.check_positive_int(-1)

 def test_NCHW1():

View File

@@ -15,8 +15,7 @@
 """ test_checkparameter """
 import pytest
-from mindspore._checkparam import check_int, check_int_positive, \
-    Validator, check_input_format, _expand_tuple
+from mindspore._checkparam import check_int, Validator, check_input_format, _expand_tuple

 once = _expand_tuple(1)
 twice = _expand_tuple(2)
@@ -32,7 +31,7 @@ def test_check_int_1():
 def check_int_positive_1():
     with pytest.raises(ValueError):
-        check_int_positive(-1)
+        Validator.check_positive_int(-1)

 def test_NCHW1():
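
Both test files only exercise the ValueError path. A hedged sketch of further cases the new API appears to support, judging from how it is used elsewhere in this commit (test names are new, not from the patch):

    import pytest
    from mindspore._checkparam import Validator

    def test_check_positive_int_rejects_zero():
        with pytest.raises(ValueError):
            Validator.check_positive_int(0)          # "positive" means strictly > 0

    def test_check_positive_int_returns_its_input():
        # call sites such as self.group = validator.check_positive_int(group, ...) rely on this
        assert Validator.check_positive_int(3, "kernel_size") == 3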

View File

@@ -15,8 +15,6 @@
 """VM implementations based on numpy."""
 import numpy as np

-from mindspore._checkparam import Rel
 from mindspore._checkparam import Validator as validator
@@ -33,7 +31,7 @@ def avg_pooling(x, pool_h, pool_w, stride):
     Returns:
         numpy.ndarray, an output array after applying average pooling on input array.
     """
-    validator.check_integer("stride", stride, 0, Rel.GT, None)
+    validator.check_positive_int(stride, "stride")
     num, channel, height, width = x.shape
     out_h = (height - pool_h) // stride + 1
     out_w = (width - pool_w) // stride + 1
@@ -423,7 +421,7 @@ def matmul(x, w, b=None):
 def max_pooling(x, pool_h, pool_w, stride):
     """Max pooling."""
-    validator.check_integer("stride", stride, 0, Rel.GT, None)
+    validator.check_positive_int(stride, "stride")
     num, channel, height, width = x.shape
     out_h = (height - pool_h) // stride + 1
     out_w = (width - pool_w) // stride + 1
@@ -466,7 +464,7 @@ def max_pool_grad_with_argmax(x, dout, arg_max, pool_h, pool_w, stride):
 def max_pool_with_argmax(x, pool_h, pool_w, stride):
     """Max pooling with argmax."""
-    validator.check_integer("stride", stride, 0, Rel.GT, None)
+    validator.check_positive_int(stride, "stride")
     num, channel, height, width = x.shape
     out_h = (height - pool_h) // stride + 1
     out_w = (width - pool_w) // stride + 1
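
Independent of the validator change, the out_h/out_w arithmetic repeated in these pooling helpers is easy to sanity-check with concrete numbers (values are illustrative):

    height, width = 4, 6
    pool_h = pool_w = 2
    stride = 2
    out_h = (height - pool_h) // stride + 1   # (4 - 2) // 2 + 1 = 2
    out_w = (width - pool_w) // stride + 1    # (6 - 2) // 2 + 1 = 3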