[ME] delete reduant function in check_parameter

This commit is contained in:
chenzomi 2020-10-16 14:22:24 +08:00
parent 5b769dfb20
commit acadb694aa
28 changed files with 307 additions and 367 deletions

View File

@ -97,7 +97,7 @@ def check_number(arg_value, value, rel, arg_type=int, arg_name=None, prim_name=N
Check argument integer.
Usage:
- number = check_integer(number, 0, Rel.GE, "number", None) # number >= 0
- number = check_int(number, 0, Rel.GE, "number", None) # number >= 0
"""
rel_fn = Rel.get_fns(rel)
type_mismatch = not isinstance(arg_value, arg_type) or isinstance(arg_value, bool)
@ -166,12 +166,12 @@ class Validator:
return arg_value
@staticmethod
def check_integer(arg_name, arg_value, value, rel, prim_name=None):
def check_int(arg_value, value, rel, arg_name=None, prim_name=None):
"""
Checks input integer value `arg_value` compare to `value`.
Usage:
- number = check_integer(number, 0, Rel.GE, "number", None) # number >= 0
- number = check_int(number, 0, Rel.GE, "number", None) # number >= 0
"""
return check_number(arg_value, value, rel, int, arg_name, prim_name)
@ -187,6 +187,16 @@ class Validator:
"""
return check_is_number(arg_value, int, arg_name, prim_name)
@staticmethod
def check_equal_int(arg_value, value, arg_name=None, prim_name=None):
"""
Checks input integer value `arg_value` compare to `value`.
Usage:
- number = check_int(number, 0, Rel.GE, "number", None) # number >= 0
"""
return check_number(arg_value, value, Rel.EQ, int, arg_name, prim_name)
@staticmethod
def check_positive_int(arg_value, arg_name=None, prim_name=None):
"""
@ -365,6 +375,17 @@ class Validator:
raise ValueError(f'{msg_prefix} `{arg_name}` should be str and must be in `{valid_values}`,'
f' but got `{arg_value}`.')
@staticmethod
def check_str_by_regular(target, reg=None, flag=re.ASCII, prim_name=None):
if reg is None:
# Named string regular expression
reg = r"^\w+[0-9a-zA-Z\_\.]*$"
if re.match(reg, target, flag) is None:
prim_name = f'in `{prim_name}`' if prim_name else ""
raise ValueError("'{}' {} is illegal, it should be match regular'{}' by flags'{}'".format(
target, prim_name, reg, flag))
return True
@staticmethod
def check_pad_value_by_mode(pad_mode, padding, prim_name):
"""Validates value of padding according to pad_mode"""
@ -530,13 +551,6 @@ class Validator:
f'{tuple(exp_shape)}, but got {shape}.')
def check_int_zero_one(input_param):
"""Judge whether it is 0 or 1."""
if input_param in (0, 1):
return input_param
raise ValueError("The data must be 0 or 1.")
def check_input_format(input_param):
"""Judge input format."""
if input_param == "NCHW":
@ -544,27 +558,6 @@ def check_input_format(input_param):
raise ValueError("The data format must be NCHW.")
def check_padding(padding):
"""Check padding."""
if padding >= 0:
return padding
raise ValueError("The padding must be at least 0,"" but got padding {}.".format(padding))
def check_padmode(mode):
"""Check padmode."""
if mode in ("same", "valid", "pad"):
return mode
raise ValueError("The pad mode must be same or valid or pad,"" but got mode {}.".format(mode))
def check_tensor_supported_type(dtype):
"""Check tensor dtype."""
if dtype in (mstype.int32, mstype.float32):
return dtype
raise ValueError("The dtype must be mstype.int32 or mstype.float32, but got mstype {}.".format(dtype))
def _expand_tuple(n_dimensions):
"""To expand a number to tuple."""
@ -673,42 +666,6 @@ def check_typename(arg_name, arg_type, valid_types):
f' but got {get_typename(arg_type)}.')
def check_shape(arg_name, arg_value):
"""Check shape."""
# First, check if shape is a tuple
if not isinstance(arg_value, tuple):
raise TypeError(f'The type of `{arg_name}` should be one of {tuple.__name__},'
f' but got {type(arg_value).__name__}.')
# Second, wrap arg_value with numpy array so that it can be checked through numpy api
arg_value = np.array(arg_value)
# shape can not be ()
if arg_value.size == 0:
raise ValueError('Shape can not be empty.')
# shape's dimension should be 1
if arg_value.ndim != 1:
raise ValueError('Shape of tensor should be 1-dim vector, but got {}-dim.'.format(arg_value.ndim))
# Thirdly, check each element's type of the shape
valid_types = (int, np.int8, np.int16, np.int32, np.int64,
np.uint8, np.uint16, np.uint32, np.uint64)
for dim_size in arg_value:
if not isinstance(dim_size, valid_types) or dim_size <= 0:
raise ValueError('Every dimension size of the tensor shape should be a positive integer,'
' but got {}.'.format(dim_size))
def _check_str_by_regular(target, reg=None, flag=re.ASCII):
if reg is None:
# Named string regular expression
reg = r"^\w+[0-9a-zA-Z\_\.]*$"
if re.match(reg, target, flag) is None:
raise ValueError("'{}' is illegal, it should be match regular'{}' by flags'{}'".format(target, reg, flag))
return True
def args_type_check(*type_args, **type_kwargs):
"""Check whether input data type is correct."""

View File

@ -19,7 +19,7 @@ from .._c_expression import ParamInfo
from . import dtype as mstype
from .initializer import initializer, Initializer
from .tensor import Tensor, MetaTensor
from .._checkparam import _check_str_by_regular
from .._checkparam import Validator
from ..parallel._tensor import _get_slice_index
from ..parallel._auto_parallel_context import auto_parallel_context
from ..parallel._ps_context import _is_role_worker, _is_role_pserver, _is_role_sched
@ -263,7 +263,7 @@ class Parameter(MetaTensor):
Returns:
Parameter, a new parameter.
"""
_check_str_by_regular(prefix)
Validator.check_str_by_regular(prefix)
x = copy(self)
# pylint: disable=protected-access
x._param_info = self._param_info.clone()
@ -446,7 +446,7 @@ class ParameterTuple(tuple):
Returns:
Tuple, the new Parameter tuple.
"""
_check_str_by_regular(prefix)
Validator.check_str_by_regular(prefix)
new = []
for x in self:
x1 = x.clone(prefix, init)

View File

@ -23,7 +23,7 @@ from collections import namedtuple
from types import FunctionType
from mindspore import log as logger
from mindspore._c_expression import MSContext, ms_ctx_param
from mindspore._checkparam import args_type_check
from mindspore._checkparam import args_type_check, Validator
from mindspore.parallel._auto_parallel_context import _set_auto_parallel_context, _get_auto_parallel_context, \
_reset_auto_parallel_context
from mindspore.parallel._ps_context import _set_ps_context, _get_ps_context, _reset_ps_context
@ -35,9 +35,9 @@ __all__ = ['GRAPH_MODE', 'PYNATIVE_MODE', 'set_context', 'get_context', 'set_aut
GRAPH_MODE = 0
PYNATIVE_MODE = 1
# The max memory size of graph plus variable.
_DEVICE_APP_MEMORY_SIZE = 31
_DEVICE_APP_MEMORY_SIZE = 31 # The max memory size of graph plus variable.
_re_pattern = r'[1-9][0-9]*(\.)?[0-9]*GB|0\.[0-9]*GB'
_k_context = None
def _make_directory(path):
"""Make directory."""
@ -223,7 +223,7 @@ class _Context:
def set_variable_memory_max_size(self, variable_memory_max_size):
"""set values of variable_memory_max_size and graph_memory_max_size"""
if not _check_input_format(variable_memory_max_size):
if not Validator.check_str_by_regular(variable_memory_max_size, _re_pattern):
raise ValueError("Context param variable_memory_max_size should be in correct format! Such as \"5GB\"")
if int(variable_memory_max_size[:-2]) >= _DEVICE_APP_MEMORY_SIZE:
raise ValueError("Context param variable_memory_max_size should be less than 31GB.")
@ -235,7 +235,7 @@ class _Context:
self.set_param(ms_ctx_param._graph_memory_max_size, graph_memory_max_size_)
def set_max_device_memory(self, max_device_memory):
if not _check_input_format(max_device_memory):
if not Validator.check_str_by_regular(max_device_memory, _re_pattern):
raise ValueError("Context param max_device_memory should be in correct format! Such as \"3.5GB\"")
max_device_memory_value = float(max_device_memory[:-2])
if max_device_memory_value == 0:
@ -294,16 +294,6 @@ class _Context:
thread_info.debug_runtime = enable
def _check_input_format(x):
import re
pattern = r'[1-9][0-9]*(\.)?[0-9]*GB|0\.[0-9]*GB'
result = re.match(pattern, x)
return result is not None
_k_context = None
def _context():
"""
Get the global _context, if context is not created, create a new one.

View File

@ -23,7 +23,7 @@ from mindspore import log as logger
from .. import context
from ..common import dtype as mstype
from ..common.api import _executor, _pynative_exec
from .._checkparam import _check_str_by_regular
from .._checkparam import Validator
from ..common.parameter import Parameter, ParameterTuple
from .._c_expression import init_backend, Cell_
from ..ops.primitive import Primitive
@ -715,7 +715,7 @@ class Cell(Cell_):
recurse (bool): Whether contains the parameters of subcells. Default: True.
"""
_check_str_by_regular(prefix)
Validator.check_str_by_regular(prefix)
for name, param in self.parameters_and_names(expand=recurse):
if prefix != '':
param.is_init = False

View File

@ -549,7 +549,7 @@ class Unfold(Cell):
@constexpr
def _get_matrix_diag_assist(x_shape, x_dtype):
Validator.check_integer("x rank", len(x_shape), 1, Rel.GE, "_get_matrix_diag_assist")
Validator.check_int(len(x_shape), 1, Rel.GE, "x rank", "_get_matrix_diag_assist")
base_eye = np.eye(x_shape[-1], x_shape[-1]).reshape(-1)
assist = np.tile(base_eye, x_shape[:-1]).reshape(x_shape + (x_shape[-1],))
return Tensor(assist, x_dtype)
@ -557,7 +557,7 @@ def _get_matrix_diag_assist(x_shape, x_dtype):
@constexpr
def _get_matrix_diag_part_assist(x_shape, x_dtype):
Validator.check_integer("x rank", len(x_shape), 2, Rel.GE, "_get_matrix_diag_part_assist")
Validator.check_int(len(x_shape), 2, Rel.GE, "x rank", "_get_matrix_diag_part_assist")
base_eye = np.eye(x_shape[-2], x_shape[-1]).reshape(-1)
assist = np.tile(base_eye, x_shape[:-2]).reshape(x_shape)
return Tensor(assist, x_dtype)

View File

@ -239,8 +239,8 @@ class Conv2d(_Conv):
"""Initialize depthwise conv2d op"""
if context.get_context("device_target") == "Ascend" and self.group > 1:
self.dilation = self._dilation
Validator.check_integer('group', self.group, self.in_channels, Rel.EQ)
Validator.check_integer('group', self.group, self.out_channels, Rel.EQ)
Validator.check_equal_int(self.group, self.in_channels, 'group')
Validator.check_equal_int(self.group, self.out_channels, 'group')
self.conv2d = P.DepthwiseConv2dNative(channel_multiplier=1,
kernel_size=self.kernel_size,
pad_mode=self.pad_mode,
@ -384,10 +384,10 @@ class Conv1d(_Conv):
Validator.check_value_type("stride", stride, [int], self.cls_name)
Validator.check_value_type("padding", padding, [int], self.cls_name)
Validator.check_value_type("dilation", dilation, [int], self.cls_name)
Validator.check_integer('kernel_size', kernel_size, 1, Rel.GE, self.cls_name)
Validator.check_integer('stride', stride, 1, Rel.GE, self.cls_name)
Validator.check_int(kernel_size, 1, Rel.GE, 'kernel_size', self.cls_name)
Validator.check_int(stride, 1, Rel.GE, 'stride', self.cls_name)
Validator.check_non_negative_int(padding, 'padding', self.cls_name)
Validator.check_integer('dilation', dilation, 1, Rel.GE, self.cls_name)
Validator.check_int(dilation, 1, Rel.GE, 'dilation', self.cls_name)
kernel_size = (1, kernel_size)
stride = (1, stride)
dilation = (1, dilation)
@ -395,7 +395,7 @@ class Conv1d(_Conv):
get_dtype = P.DType()
if isinstance(weight_init, Tensor):
weight_init_shape = get_shape(weight_init)
Validator.check_integer('weight_init_shape', len(weight_init_shape), 3, Rel.EQ, self.cls_name)
Validator.check_equal_int(len(weight_init_shape), 3, 'weight_init_shape', self.cls_name)
weight_init_dtype = get_dtype(weight_init)
weight_init_value = weight_init.asnumpy()
weight_init_value = np.expand_dims(weight_init_value, 2)
@ -539,7 +539,7 @@ class Conv2dTranspose(_Conv):
dilation = twice(dilation)
Validator.check_value_type('padding', padding, (int, tuple), self.cls_name)
if isinstance(padding, tuple):
Validator.check_integer('padding size', len(padding), 4, Rel.EQ, self.cls_name)
Validator.check_equal_int(len(padding), 4, 'padding size', self.cls_name)
# out_channels and in_channels swap.
# cause Conv2DBackpropInput's out_channel refers to Conv2D's out_channel,
# then Conv2dTranspose's out_channel refers to Conv2DBackpropInput's in_channel.
@ -703,10 +703,10 @@ class Conv1dTranspose(_Conv):
Validator.check_value_type("stride", stride, [int], self.cls_name)
Validator.check_value_type("padding", padding, [int], self.cls_name)
Validator.check_value_type("dilation", dilation, [int], self.cls_name)
Validator.check_integer('kernel_size', kernel_size, 1, Rel.GE, self.cls_name)
Validator.check_integer('stride', stride, 1, Rel.GE, self.cls_name)
Validator.check_int(kernel_size, 1, Rel.GE, 'kernel_size', self.cls_name)
Validator.check_int(stride, 1, Rel.GE, 'stride', self.cls_name)
Validator.check_non_negative_int(padding, 'padding', self.cls_name)
Validator.check_integer('dilation', dilation, 1, Rel.GE, self.cls_name)
Validator.check_int(dilation, 1, Rel.GE, 'dilation', self.cls_name)
kernel_size = (1, kernel_size)
stride = (1, stride)
dilation = (1, dilation)
@ -714,7 +714,7 @@ class Conv1dTranspose(_Conv):
get_dtype = P.DType()
if isinstance(weight_init, Tensor):
weight_init_shape = get_shape(weight_init)
Validator.check_integer('weight_init_shape', len(weight_init_shape), 3, Rel.EQ, self.cls_name)
Validator.check_equal_int(len(weight_init_shape), 3, 'weight_init_shape', self.cls_name)
weight_init_dtype = get_dtype(weight_init)
weight_init_value = weight_init.asnumpy()
weight_init_value = np.expand_dims(weight_init_value, 2)

View File

@ -220,7 +220,7 @@ class SSIM(Cell):
validator.check_value_type('max_val', max_val, [int, float], self.cls_name)
validator.check_number('max_val', max_val, 0.0, Rel.GT, self.cls_name)
self.max_val = max_val
self.filter_size = validator.check_integer('filter_size', filter_size, 1, Rel.GE, self.cls_name)
self.filter_size = validator.check_int(filter_size, 1, Rel.GE, 'filter_size', self.cls_name)
self.filter_sigma = validator.check_positive_float(filter_sigma, 'filter_sigma', self.cls_name)
self.k1 = validator.check_value_type('k1', k1, [float], self.cls_name)
self.k2 = validator.check_value_type('k2', k2, [float], self.cls_name)
@ -298,7 +298,7 @@ class MSSSIM(Cell):
validator.check_number('max_val', max_val, 0.0, Rel.GT, self.cls_name)
self.max_val = max_val
validator.check_value_type('power_factors', power_factors, [tuple, list], self.cls_name)
self.filter_size = validator.check_integer('filter_size', filter_size, 1, Rel.GE, self.cls_name)
self.filter_size = validator.check_int(filter_size, 1, Rel.GE, 'filter_size', self.cls_name)
self.filter_sigma = validator.check_positive_float(filter_sigma, 'filter_sigma', self.cls_name)
self.k1 = validator.check_value_type('k1', k1, [float], self.cls_name)
self.k2 = validator.check_value_type('k2', k2, [float], self.cls_name)

View File

@ -190,8 +190,8 @@ class MaxPool1d(_PoolNd):
validator.check_value_type('kernel_size', kernel_size, [int], self.cls_name)
validator.check_value_type('stride', stride, [int], self.cls_name)
self.pad_mode = validator.check_string(pad_mode.upper(), ['VALID', 'SAME'], 'pad_mode', self.cls_name)
validator.check_integer("kernel_size", kernel_size, 1, Rel.GE, self.cls_name)
validator.check_integer("stride", stride, 1, Rel.GE, self.cls_name)
validator.check_int(kernel_size, 1, Rel.GE, "kernel_size", self.cls_name)
validator.check_int(stride, 1, Rel.GE, "stride", self.cls_name)
self.kernel_size = (1, kernel_size)
self.stride = (1, stride)
self.max_pool = P.MaxPool(ksize=self.kernel_size,
@ -349,8 +349,8 @@ class AvgPool1d(_PoolNd):
validator.check_value_type('kernel_size', kernel_size, [int], self.cls_name)
validator.check_value_type('stride', stride, [int], self.cls_name)
self.pad_mode = validator.check_string(pad_mode.upper(), ['VALID', 'SAME'], 'pad_mode', self.cls_name)
validator.check_integer("kernel_size", kernel_size, 1, Rel.GE, self.cls_name)
validator.check_integer("stride", stride, 1, Rel.GE, self.cls_name)
validator.check_int(kernel_size, 1, Rel.GE, "kernel_size", self.cls_name)
validator.check_int(stride, 1, Rel.GE, "stride", self.cls_name)
self.kernel_size = (1, kernel_size)
self.stride = (1, stride)
self.avg_pool = P.AvgPool(ksize=self.kernel_size,

View File

@ -323,7 +323,7 @@ class FakeQuantWithMinMax(Cell):
Validator.check_type("min_init", min_init, [int, float])
Validator.check_type("max_init", max_init, [int, float])
Validator.check("min_init", min_init, "max_init", max_init, rel=Rel.LT)
Validator.check_integer('quant_delay', quant_delay, 0, Rel.GE)
Validator.check_non_negative_int(quant_delay, 'quant_delay')
self.min_init = min_init
self.max_init = max_init
self.num_bits = num_bits
@ -489,8 +489,8 @@ class Conv2dBnFoldQuant(Cell):
# initialize convolution op and Parameter
if context.get_context('device_target') == "Ascend" and group > 1:
Validator.check_integer('group', group, in_channels, Rel.EQ)
Validator.check_integer('group', group, out_channels, Rel.EQ)
Validator.check_equal_int(group, in_channels, 'group')
Validator.check_equal_int(group, out_channels, 'group')
self.conv = P.DepthwiseConv2dNative(channel_multiplier=1,
kernel_size=self.kernel_size,
pad_mode=pad_mode,
@ -674,8 +674,8 @@ class Conv2dBnWithoutFoldQuant(Cell):
self.bias = None
# initialize convolution op and Parameter
if context.get_context('device_target') == "Ascend" and group > 1:
Validator.check_integer('group', group, in_channels, Rel.EQ)
Validator.check_integer('group', group, out_channels, Rel.EQ)
Validator.check_equal_int(group, in_channels, 'group')
Validator.check_equal_int(group, out_channels, 'group')
self.conv = P.DepthwiseConv2dNative(channel_multiplier=1,
kernel_size=self.kernel_size,
pad_mode=pad_mode,

View File

@ -931,19 +931,19 @@ class LSTMGradData(PrimitiveWithInfer):
def infer_shape(self, y_shape, dy_shape, dhy_shape, dcy_shape, w_shape,
hx_shape, cx_shape, reserve_shape, state_shape):
# dhy and dcy should be same shape
validator.check_integer("h_shape", len(dhy_shape), 3, Rel.EQ, self.name)
validator.check_integer("h_shape", len(dhy_shape), len(dcy_shape), Rel.EQ, self.name)
validator.check_integer("h_shape[0]", dhy_shape[0], dcy_shape[0], Rel.EQ, self.name)
validator.check_integer("h_shape[1]", dhy_shape[1], dcy_shape[1], Rel.EQ, self.name)
validator.check_integer("h_shape[2]", dhy_shape[2], dcy_shape[2], Rel.EQ, self.name)
validator.check_equal_int(len(dhy_shape), 3, "h_shape", self.name)
validator.check_equal_int(len(dhy_shape), len(dcy_shape), "h_shape", self.name)
validator.check_equal_int(dhy_shape[0], dcy_shape[0], "h_shape[0]", self.name)
validator.check_equal_int(dhy_shape[1], dcy_shape[1], "h_shape[1]", self.name)
validator.check_equal_int(dhy_shape[2], dcy_shape[2], "h_shape[2]", self.name)
validator.check_integer("h_shape[0]", dhy_shape[0], self.num_layers * self.num_directions, Rel.EQ, self.name)
validator.check_integer("h_shape[2]", dhy_shape[2], self.hidden_size, Rel.EQ, self.name)
validator.check_int(dhy_shape[0], self.num_layers * self.num_directions, Rel.EQ, "h_shape[0]", self.name)
validator.check_equal_int(dhy_shape[2], self.hidden_size, "h_shape[2]", self.name)
# dy: (seq_len, batch_size, hidden_size * num_directions)
validator.check_integer("dy_shape", len(dy_shape), 3, Rel.EQ, self.name)
validator.check_integer("dy[1]", dy_shape[1], dhy_shape[1], Rel.EQ, self.name)
validator.check_integer("dy[2]", dy_shape[2], self.hidden_size * self.num_directions, Rel.EQ, self.name)
validator.check_equal_int(len(dy_shape), 3, "dy_shape", self.name)
validator.check_equal_int(dy_shape[1], dhy_shape[1], "dy[1]", self.name)
validator.check_int(dy_shape[2], self.hidden_size * self.num_directions, Rel.EQ, "dy[2]", self.name)
# (seq_len, batch_size, input_size)
dx_shape = (y_shape[0], y_shape[1], self.input_size)
@ -1015,19 +1015,19 @@ class LSTMGrad(PrimitiveWithInfer):
def infer_shape(self, x_shape, hx_shape, cx_shape, w_shape, y_shape, hy_shape, cy_shape, dy_shape, dhy_shape,
dcy_shape, reserve_shape):
# dhy and dcy should be same shape
validator.check_integer("h_shape", len(dhy_shape), 3, Rel.EQ, self.name)
validator.check_integer("h_shape", len(dhy_shape), len(dcy_shape), Rel.EQ, self.name)
validator.check_integer("h_shape[0]", dhy_shape[0], dcy_shape[0], Rel.EQ, self.name)
validator.check_integer("h_shape[1]", dhy_shape[1], dcy_shape[1], Rel.EQ, self.name)
validator.check_integer("h_shape[2]", dhy_shape[2], dcy_shape[2], Rel.EQ, self.name)
validator.check_equal_int(len(dhy_shape), 3, "h_shape", self.name)
validator.check_equal_int(len(dhy_shape), len(dcy_shape), "h_shape", self.name)
validator.check_equal_int(dhy_shape[0], dcy_shape[0], "h_shape[0]", self.name)
validator.check_equal_int(dhy_shape[1], dcy_shape[1], "h_shape[1]", self.name)
validator.check_equal_int(dhy_shape[2], dcy_shape[2], "h_shape[2]", self.name)
validator.check_integer("h_shape[0]", dhy_shape[0], self.num_layers * self.num_directions, Rel.EQ, self.name)
validator.check_integer("h_shape[2]", dhy_shape[2], self.hidden_size, Rel.EQ, self.name)
validator.check_int(dhy_shape[0], self.num_layers * self.num_directions, Rel.EQ, "h_shape[0]", self.name)
validator.check_equal_int(dhy_shape[2], self.hidden_size, "h_shape[2]", self.name)
# dy: (seq_len, batch_size, hidden_size * num_directions)
validator.check_integer("dy_shape", len(dy_shape), 3, Rel.EQ, self.name)
validator.check_integer("dy[1]", dy_shape[1], dhy_shape[1], Rel.EQ, self.name)
validator.check_integer("dy[2]", dy_shape[2], self.hidden_size * self.num_directions, Rel.EQ, self.name)
validator.check_equal_int(len(dy_shape), 3, "dy_shape", self.name)
validator.check_equal_int(dy_shape[1], dhy_shape[1], "dy[1]", self.name)
validator.check_int(dy_shape[2], self.hidden_size * self.num_directions, Rel.EQ, "dy[2]", self.name)
# (seq_len, batch_size, input_size)
dx_shape = (y_shape[0], y_shape[1], self.input_size)
@ -1069,7 +1069,7 @@ class DynamicRNNGrad(PrimitiveWithInfer):
def infer_shape(self, x_shape, w_shape, b_shape, y_shape, init_h_shape, init_c_shape, h_shape,
c_shape, dy_shape, dh_shape, dc_shape, i_shape, j_shape, f_shape, o_shape, tanhc_shape):
validator.check_integer("x_shape", len(x_shape), 3, Rel.EQ, self.name)
validator.check_equal_int(len(x_shape), 3, "x_shape", self.name)
num_step, batch_size, input_size = x_shape
hidden_size = w_shape[-1] // 4
if w_shape[-1] % 4 != 0:
@ -1575,7 +1575,7 @@ class BasicLSTMCellCStateGrad(PrimitiveWithInfer):
def infer_shape(self, c_shape, dht_shape, dct_shape, it_shape, jt_shape, ft_shape, ot_shape, tanhct_shape):
# dhy and dcy should be same shape
validator.check_integer("c rank", len(c_shape), 2, Rel.EQ, self.name)
validator.check_equal_int(len(c_shape), 2, "c rank", self.name)
validator.check("dht rank", len(dht_shape), "c rank", len(c_shape), Rel.EQ, self.name)
validator.check("dct rank", len(dct_shape), "c rank", len(c_shape), Rel.EQ, self.name)
validator.check("it rank", len(it_shape), "c rank", len(c_shape), Rel.EQ, self.name)
@ -1624,7 +1624,7 @@ class BasicLSTMCellWeightGrad(PrimitiveWithInfer):
self.add_prim_attr("io_format", "HWCN")
def infer_shape(self, x_shape, h_shape, dgate_shape):
validator.check_integer("x rank", len(x_shape), 2, Rel.EQ, self.name)
validator.check_equal_int(len(x_shape), 2, "x rank", self.name)
validator.check("h rank", len(h_shape), " x rank", len(x_shape), Rel.EQ, self.name)
validator.check("dgate rank", len(dgate_shape), "x rank", len(x_shape), Rel.EQ, self.name)
validator.check("h_shape[0]", h_shape[0], "x_shape[0]", x_shape[0], Rel.EQ, self.name)
@ -1656,8 +1656,8 @@ class BasicLSTMCellInputGrad(PrimitiveWithInfer):
self.add_prim_attr("io_format", "ND")
def infer_shape(self, dgate_shape, w_shape):
validator.check_integer("dgate rank", len(dgate_shape), 2, Rel.EQ, self.name)
validator.check_integer("w rank", len(w_shape), 2, Rel.EQ, self.name)
validator.check_equal_int(len(dgate_shape), 2, "dgate rank", self.name)
validator.check_equal_int(len(w_shape), 2, "w rank", self.name)
validator.check("dgate_shape[1]", dgate_shape[1], "w_shape[1]", w_shape[1], Rel.EQ, self.name)
batch_size = dgate_shape[0]
hidden_size = dgate_shape[1] // 4

View File

@ -347,7 +347,7 @@ class MatrixDiag(PrimitiveWithInfer):
return x_dtype
def infer_shape(self, x_shape, assist_shape):
validator.check_integer("assist rank", len(assist_shape), 2, Rel.GE, self.name)
validator.check_int(len(assist_shape), 2, Rel.GE, "assist rank", self.name)
validator.check('rank of x', len(x_shape)+1,
'rank of assist', len(assist_shape), Rel.LE, self.name)
validator.check('assist\'s penultimate dimension', assist_shape[-2], 'assist\'s last dimension',
@ -395,7 +395,7 @@ class MatrixDiagPart(PrimitiveWithInfer):
return x_dtype
def infer_shape(self, x_shape, assist_shape):
validator.check_integer("x rank", len(x_shape), 2, Rel.GE, self.name)
validator.check_int(len(x_shape), 2, Rel.GE, "x rank", self.name)
validator.check("x shape", x_shape, "assist shape", assist_shape, Rel.EQ, self.name)
if assist_shape[-2] < assist_shape[-1]:
@ -438,7 +438,7 @@ class MatrixSetDiag(PrimitiveWithInfer):
return x_dtype
def infer_shape(self, x_shape, diagonal_shape, assist_shape):
validator.check_integer("x rank", len(x_shape), 2, Rel.GE, self.name)
validator.check_int(len(x_shape), 2, Rel.GE, "x rank", self.name)
validator.check("x shape", x_shape, "assist shape", assist_shape, Rel.EQ, self.name)
if x_shape[-2] < x_shape[-1]:

View File

@ -81,11 +81,10 @@ class MinMaxUpdatePerLayer(PrimitiveWithInfer):
outputs=['min_up', 'max_up'])
def infer_shape(self, x_shape, min_shape, max_shape):
validator.check_integer("x rank", len(x_shape), 1, Rel.GE, self.name)
validator.check_int(len(x_shape), 1, Rel.GE, "x rank", self.name)
validator.check("min shape", min_shape, "max shape",
max_shape, Rel.EQ, self.name)
validator.check_integer("min shape", len(
min_shape), 1, Rel.EQ, self.name)
validator.check_equal_int(len(min_shape), 1, "min shape", self.name)
return min_shape, max_shape
def infer_dtype(self, x_type, min_type, max_type):
@ -147,11 +146,10 @@ class MinMaxUpdatePerChannel(PrimitiveWithInfer):
if self.is_ascend and len(x_shape) not in self.ascend_support_x_rank:
raise ValueError(f"For '{self.name}' x rank should be in '{self.ascend_support_x_rank}'")
if not self.is_ascend:
validator.check_integer("x rank", len(x_shape), 1, Rel.GE, self.name)
validator.check_int(len(x_shape), 1, Rel.GE, "x rank", self.name)
validator.check("min shape", min_shape, "max shape",
max_shape, Rel.EQ, self.name)
validator.check_integer("min shape", len(
min_shape), 1, Rel.EQ, self.name)
validator.check_equal_int(len(min_shape), 1, "min shape", self.name)
return min_shape, max_shape
def infer_dtype(self, x_type, min_type, max_type):
@ -228,9 +226,9 @@ class FakeQuantPerLayer(PrimitiveWithInfer):
outputs=['out'])
def infer_shape(self, x_shape, min_shape, max_shape):
validator.check_integer("x rank", len(x_shape), 1, Rel.GE, self.name)
validator.check_int(len(x_shape), 1, Rel.GE, "x rank", self.name)
validator.check("min shape", min_shape, "max shape", max_shape, Rel.EQ, self.name)
validator.check_integer("min shape", len(min_shape), 1, Rel.EQ, self.name)
validator.check_equal_int(len(min_shape), 1, "min shape", self.name)
return x_shape
def infer_dtype(self, x_type, min_type, max_type):
@ -284,8 +282,7 @@ class FakeQuantPerLayerGrad(PrimitiveWithInfer):
x_shape, Rel.EQ, self.name)
validator.check("min shape", min_shape, "max shape",
max_shape, Rel.EQ, self.name)
validator.check_integer("min shape", len(
min_shape), 1, Rel.EQ, self.name)
validator.check_equal_int(len(min_shape), 1, "min shape", self.name)
return dout_shape
def infer_dtype(self, dout_type, x_type, min_type, max_type):
@ -375,14 +372,12 @@ class FakeQuantPerChannel(PrimitiveWithInfer):
if self.is_ascend and len(x_shape) not in self.ascend_support_x_rank:
raise ValueError(f"For '{self.name}' x rank should be in '{self.ascend_support_x_rank}'")
if not self.is_ascend:
validator.check_integer("x rank", len(x_shape), 1, Rel.GE, self.name)
validator.check_int(len(x_shape), 1, Rel.GE, "x rank", self.name)
if len(x_shape) == 1:
self.channel_axis = 0
validator.check("min shape", min_shape, "max shape", max_shape, Rel.EQ, self.name)
validator.check_integer(
"min shape", min_shape[0], x_shape[self.channel_axis], Rel.EQ, self.name)
validator.check_integer(
"max shape", max_shape[0], x_shape[self.channel_axis], Rel.EQ, self.name)
validator.check_equal_int(min_shape[0], x_shape[self.channel_axis], "min shape", self.name)
validator.check_equal_int(max_shape[0], x_shape[self.channel_axis], "max shape", self.name)
return x_shape
def infer_dtype(self, x_type, min_type, max_type):
@ -501,7 +496,7 @@ class BatchNormFold(PrimitiveWithInfer):
def infer_shape(self, x_shape, mean_shape, variance_shape, global_step_shape):
validator.check("mean shape", mean_shape, "gamma_shape", variance_shape, Rel.EQ, self.name)
validator.check("mean_shape[0]", mean_shape[0], "input channel", x_shape[self.channel_axis], Rel.EQ, self.name)
validator.check_integer("global step shape len", len(global_step_shape), 1, Rel.EQ, self.name)
validator.check_equal_int(len(global_step_shape), 1, "global step shape len", self.name)
return mean_shape, mean_shape, mean_shape, mean_shape
def infer_dtype(self, x_type, mean_type, variance_type, global_step_type):
@ -548,7 +543,7 @@ class BatchNormFoldGrad(PrimitiveWithInfer):
"batch_std shape", batch_std_shape, Rel.EQ, self.name)
validator.check("d_batch_mean_shape[0]", d_batch_mean_shape[0],
"input channel", x_shape[self.channel_axis], Rel.EQ, self.name)
validator.check_integer("global step shape len", len(global_step_shape), 1, Rel.EQ, self.name)
validator.check_equal_int(len(global_step_shape), 1, "global step shape len", self.name)
return x_shape
def infer_dtype(self, d_batch_mean_type, d_batch_std_type, x_type, batch_mean_type, batch_std_type,
@ -723,7 +718,7 @@ class BatchNormFold2(PrimitiveWithInfer):
validator.check("batch_std shape", batch_std_shape, "batch_mean shape", gamma_shape, Rel.EQ, self.name)
validator.check("batch_std_shape[0]", batch_std_shape[0], "x_shape channel size", x_shape[self.channel_axis],
Rel.EQ, self.name)
validator.check_integer("global step shape len", len(global_step_shape), 1, Rel.EQ, self.name)
validator.check_equal_int(len(global_step_shape), 1, "global step shape len", self.name)
return x_shape
def infer_dtype(self, x_type, beta_type, gamma_type, batch_std_type, running_std_type, batch_mean_type,
@ -771,7 +766,7 @@ class BatchNormFold2Grad(PrimitiveWithInfer):
validator.check("batch_std shape", batch_std_shape, "gamma shape", gamma_shape, Rel.EQ, self.name)
validator.check("batch_std size", batch_std_shape[0], "dout channel size", dout_shape[self.channel_axis],
Rel.EQ, self.name)
validator.check_integer("global step shape len", len(global_step_shape), 1, Rel.EQ, self.name)
validator.check_equal_int(len(global_step_shape), 1, "global step shape len", self.name)
return gamma_shape, gamma_shape, gamma_shape, gamma_shape, x_shape
def infer_dtype(self, dout_type, x_type, gamma_type,

View File

@ -520,7 +520,7 @@ class Im2Col(PrimitiveWithInfer):
self.add_prim_attr('data_format', "NCHW")
def infer_shape(self, x_shape):
validator.check_integer("x rank", len(x_shape), 4, Rel.EQ, self.name)
validator.check_equal_int(len(x_shape), 4, "x rank", self.name)
kernel_size_h = self.kernel_size[0]
kernel_size_w = self.kernel_size[1]
stride_h = self.stride[2]

View File

@ -583,8 +583,8 @@ class Transpose(PrimitiveWithInfer):
tmp = list(p_value)
for i, dim in enumerate(p_value):
validator.check_integer("perm[%d]" % i, dim, 0, Rel.GE, self.name)
validator.check_integer("perm[%d]" % i, dim, len(p_value), Rel.LT, self.name)
validator.check_int(dim, 0, Rel.GE, f'perm[{i}]', self.name)
validator.check_int(dim, len(p_value), Rel.LT, f'perm[{i}]', self.name)
tmp.remove(dim)
if dim in tmp:
raise ValueError('The value of perm is wrong.')
@ -725,8 +725,8 @@ class Padding(PrimitiveWithInfer):
def __infer__(self, x):
validator.check_subclass("x", x['dtype'], mstype.tensor, self.name)
x_shape = list(x['shape'])
validator.check_integer("rank of x", len(x_shape), 1, Rel.GT, self.name)
validator.check_integer("last dim of x", x_shape[-1], 1, Rel.EQ, self.name)
validator.check_int(len(x_shape), 1, Rel.GT, "rank of x", self.name)
validator.check_int(x_shape[-1], 1, Rel.EQ, "last dim of x", self.name)
out_shape = x_shape
out_shape[-1] = self.pad_dim_size
out = {'shape': out_shape,
@ -1575,7 +1575,7 @@ class UnsortedSegmentMin(PrimitiveWithInfer):
valid_type = [mstype.float16, mstype.float32, mstype.int32]
validator.check_tensor_type_same({"x": x['dtype']}, valid_type, self.name)
validator.check_tensor_type_same({"segment_ids": segment_ids['dtype']}, [mstype.int32], self.name)
validator.check_integer("rank of segment_ids_shape", len(segment_ids_shape), 1, Rel.EQ, self.name)
validator.check_equal_int(len(segment_ids_shape), 1, "rank of segment_ids_shape", self.name)
validator.check(f'first shape of input_x', x_shape[0],
'length of segments_id', segment_ids_shape[0], Rel.EQ, self.name)
num_segments_v = num_segments['value']
@ -1628,7 +1628,7 @@ class UnsortedSegmentProd(PrimitiveWithInfer):
valid_type = [mstype.float16, mstype.float32, mstype.int32]
validator.check_tensor_type_same({"x": x['dtype']}, valid_type, self.name)
validator.check_tensor_type_same({"segment_ids": segment_ids['dtype']}, [mstype.int32], self.name)
validator.check_integer("rank of segment_ids_shape", len(segment_ids_shape), 1, Rel.EQ, self.name)
validator.check_equal_int(len(segment_ids_shape), 1, "rank of segment_ids_shape", self.name)
validator.check(f'first shape of input_x', x_shape[0],
'length of segments_id', segment_ids_shape[0], Rel.EQ, self.name)
num_segments_v = num_segments['value']
@ -1730,7 +1730,7 @@ class ParallelConcat(PrimitiveWithInfer):
x_shp = values['shape']
x_type = values['dtype']
validator.check_integer(f'x_shp length', len(x_shp), 1, Rel.GE, self.name)
validator.check_int(len(x_shp), 1, Rel.GE, f'x_shp length', self.name)
args = {f"x_type[{i}]": elem for i, elem in enumerate(x_type)}
validator.check_tensor_type_same(args, mstype.number_type + (mstype.bool_,), self.name)
@ -1738,7 +1738,7 @@ class ParallelConcat(PrimitiveWithInfer):
first_elem = x_shp[0]
for i, elem in enumerate(x_shp[1:]):
j = i + 1
validator.check_integer(f'x_shp[{j}][0]', elem[0], 1, Rel.EQ, self.name)
validator.check_equal_int(elem[0], 1, f'x_shp[{j}][0]', self.name)
validator.check(f"x_shp[0] shape", first_elem, f"x_shp[{j}] shape", elem, Rel.EQ, self.name)
ret_shp = x_shp[0].copy()
@ -1755,7 +1755,7 @@ class ParallelConcat(PrimitiveWithInfer):
def _get_pack_shape(x_shape, x_type, axis, prim_name):
"""for pack output shape"""
validator.check_value_type("shape", x_shape, [tuple, list], prim_name)
validator.check_integer("len of input_x", len(x_shape), 1, Rel.GE, prim_name)
validator.check_int(len(x_shape), 1, Rel.GE, "len of input_x", prim_name)
validator.check_subclass("input_x[0]", x_type[0], mstype.tensor, prim_name)
rank_base = len(x_shape[0])
N = len(x_shape)
@ -1871,8 +1871,8 @@ class Unpack(PrimitiveWithInfer):
validator.check_positive_int(output_num, "output_num", self.name)
self.add_prim_attr('num', output_num)
output_valid_check = x_shape[self.axis] - output_num
validator.check_integer("The dimension which to unpack divides output_num", output_valid_check, 0, Rel.EQ,
self.name)
validator.check_int(output_valid_check, 0, Rel.EQ,
"The dimension which to unpack divides output_num", self.name)
out_shapes = []
out_dtypes = []
out_shape = x_shape[:self.axis] + x_shape[self.axis + 1:]
@ -2523,7 +2523,7 @@ class ResizeNearestNeighbor(PrimitiveWithInfer):
"""Initialize ResizeNearestNeighbor"""
validator.check_value_type("size", size, [tuple, list], self.name)
validator.check_value_type("align_corners", align_corners, [bool], self.name)
validator.check_integer("length of size", len(size), 2, Rel.EQ, self.name)
validator.check_equal_int(len(size), 2, "length of size", self.name)
for i, value in enumerate(size):
validator.check_non_negative_int(value, f'{i}th value of size', self.name)
self.init_prim_io_names(inputs=['image_in'], outputs=['image_out'])
@ -3134,9 +3134,8 @@ class DepthToSpace(PrimitiveWithInfer):
for i in range(2):
out_shape[i + 2] *= self.block_size
validator.check_integer('x_shape[1] % (block_size*block_size)',
x_shape[1] % (self.block_size * self.block_size),
0, Rel.EQ, self.name)
validator.check_int(x_shape[1] % (self.block_size * self.block_size),
0, Rel.EQ, 'x_shape[1] % (block_size*block_size)', self.name)
out_shape[1] //= self.block_size * self.block_size
return out_shape
@ -3205,7 +3204,7 @@ class SpaceToBatch(PrimitiveWithInfer):
return x_dtype
def infer_shape(self, x_shape):
validator.check_integer('rank of input_x', len(x_shape), 4, Rel.EQ, self.name)
validator.check_equal_int(len(x_shape), 4, 'rank of input_x', self.name)
out_shape = copy.deepcopy(x_shape)
for i in range(2):
padded = out_shape[i + 2] + self.paddings[i][0] + self.paddings[i][1]
@ -3367,7 +3366,7 @@ class SpaceToBatchND(PrimitiveWithInfer):
def infer_shape(self, x_shape):
x_rank = len(x_shape)
validator.check_integer('x_shape rank', x_rank, 4, Rel.EQ, self.name)
validator.check_equal_int(x_rank, 4, 'x_shape rank', self.name)
out_shape = copy.deepcopy(x_shape)
block_shape_prod = 1
@ -3460,7 +3459,7 @@ class BatchToSpaceND(PrimitiveWithInfer):
def infer_shape(self, x_shape):
x_rank = len(x_shape)
validator.check_integer('x_shape rank', x_rank, 4, Rel.EQ, self.name)
validator.check_int(x_rank, 4, Rel.EQ, 'x_shape rank', self.name)
out_shape = copy.deepcopy(x_shape)
block_shape_prod = 1
@ -3607,11 +3606,11 @@ class Meshgrid(PrimitiveWithInfer):
def infer_shape(self, x_shape):
validator.check_value_type("shape", x_shape, [tuple, list], self.name)
validator.check_integer("len of input_x", len(x_shape), 2, Rel.GE, self.name)
validator.check_int(len(x_shape), 2, Rel.GE, "len of input_x", self.name)
n = len(x_shape)
shape_0 = []
for s in x_shape:
validator.check_integer('each_input_rank', len(s), 1, Rel.EQ, self.name)
validator.check_int(len(s), 1, Rel.EQ, 'each_input_rank', self.name)
shape_0.append(s[0])
if self.indexing == "xy":
shape_0[0], shape_0[1] = shape_0[1], shape_0[0]

View File

@ -204,7 +204,7 @@ class _HostAllGather(PrimitiveWithInfer):
if group is None:
raise ValueError(f"For '{self.name}' group must be set.")
validator.check_value_type('group', group, (tuple, list), self.name)
validator.check_integer("group size", len(group), 2, Rel.GE, self.name)
validator.check_int(len(group), 2, Rel.GE, "group size", self.name)
for r in group:
validator.check_int_range(r, 0, 7, Rel.INC_BOTH, "rank_id", self.name)
validator.check_value_type("rank_id", r, (int,), self.name)
@ -313,7 +313,7 @@ class _HostReduceScatter(PrimitiveWithInfer):
raise ValueError(f"For '{self.name}' group must be set.")
validator.check_value_type('op', op, (type(ReduceOp.SUM),), self.name)
validator.check_value_type('group', group, (tuple, list), self.name)
validator.check_integer("group size", len(group), 2, Rel.GE, self.name)
validator.check_int(len(group), 2, Rel.GE, "group size", self.name)
for r in group:
validator.check_int_range(r, 0, 7, Rel.INC_BOTH, "rank_id", self.name)
validator.check_value_type("rank_id", r, (int,), self.name)

View File

@ -126,7 +126,7 @@ class GeSwitch(PrimitiveWithInfer):
raise NotImplementedError
def infer_shape(self, data, pred):
validator.check_integer("pred rank", len(pred), 0, Rel.EQ, self.name)
validator.check_equal_int(len(pred), 0, "pred rank", self.name)
return (data, data)
def infer_dtype(self, data_type, pred_type):

View File

@ -374,9 +374,9 @@ class Assert(PrimitiveWithInfer):
def infer_shape(self, condition, inputs):
condition_len = len(condition)
validator.check_integer("condition's rank", condition_len, 1, Rel.LE, self.name)
validator.check_int(condition_len, 1, Rel.LE, "condition's rank", self.name)
if condition_len == 1:
validator.check_integer("condition[0]", condition[0], 1, Rel.EQ, self.name)
validator.check_equal_int(condition[0], 1, "condition[0]", self.name)
return [1]
def infer_dtype(self, condition, inputs):

View File

@ -17,7 +17,6 @@
import numbers
from ..._checkparam import Validator as validator
from ..._checkparam import Rel
from ...common.dtype import tensor, dtype_to_pytype
from ..primitive import prim_attr_register, PrimitiveWithInfer
@ -43,7 +42,7 @@ class ScalarCast(PrimitiveWithInfer):
pass
def __infer__(self, x, t):
validator.check_integer('x shape', len(x['shape']), 0, Rel.EQ, self.name)
validator.check_equal_int(len(x['shape']), 0, 'x shape', self.name)
value, to = x['value'], t['value']
if value is not None:
validator.check_value_type("value", value, [numbers.Number, bool], self.name)

View File

@ -827,7 +827,7 @@ class AddN(PrimitiveWithInfer):
def infer_shape(self, inputs):
cls_name = self.name
validator.check_integer("inputs", len(inputs), 1, Rel.GE, cls_name)
validator.check_int(len(inputs), 1, Rel.GE, "inputs", cls_name)
self.add_prim_attr('n', len(inputs))
shp0 = inputs[0]
for i, shp in enumerate(inputs):
@ -837,7 +837,7 @@ class AddN(PrimitiveWithInfer):
def infer_dtype(self, inputs):
cls_name = self.name
validator.check_value_type("inputs", inputs, [tuple, list], cls_name)
validator.check_integer("inputs", len(inputs), 1, Rel.GE, cls_name)
validator.check_int(len(inputs), 1, Rel.GE, "inputs", cls_name)
args = {}
contains_undetermined = False
for i, dtype in enumerate(inputs):
@ -910,7 +910,7 @@ class AccumulateNV2(PrimitiveWithInfer):
def infer_shape(self, inputs):
cls_name = self.name
validator.check_integer("inputs", len(inputs), 1, Rel.GE, cls_name)
validator.check_int(len(inputs), 1, Rel.GE, "inputs", cls_name)
self.add_prim_attr('n', len(inputs))
shp0 = inputs[0]
for i, shp in enumerate(inputs):
@ -920,7 +920,7 @@ class AccumulateNV2(PrimitiveWithInfer):
def infer_dtype(self, inputs):
cls_name = self.name
validator.check_value_type("inputs", inputs, [tuple, list], cls_name)
validator.check_integer("inputs", len(inputs), 1, Rel.GE, cls_name)
validator.check_int(len(inputs), 1, Rel.GE, "inputs", cls_name)
args = {}
for i, dtype in enumerate(inputs):
args[f"inputs[{i}]"] = dtype
@ -1488,7 +1488,7 @@ class HistogramFixedWidth(PrimitiveWithInfer):
@prim_attr_register
def __init__(self, nbins, dtype='int32'):
self.nbins = validator.check_value_type("nbins", nbins, [int], self.name)
validator.check_integer("nbins", nbins, 1, Rel.GE, self.name)
validator.check_int(nbins, 1, Rel.GE, "nbins", self.name)
valid_values = ['int32', 'int64']
self.dtype = validator.check_string(dtype, valid_values, "dtype", self.name)
self.init_prim_io_names(inputs=['x', 'range'], outputs=['y'])
@ -2810,8 +2810,8 @@ class NPUGetFloatStatus(PrimitiveWithInfer):
def infer_shape(self, x_shape):
cls_name = self.name
validator.check_integer("len(x_shape)", len(x_shape), 1, Rel.EQ, cls_name)
validator.check_integer("x_shape[0]", x_shape[0], 8, Rel.EQ, cls_name)
validator.check_equal_int(len(x_shape), 1, "len(x_shape)", cls_name)
validator.check_equal_int(x_shape[0], 8, "x_shape[0]", cls_name)
return [8]
def infer_dtype(self, x_dtype):
@ -2853,8 +2853,8 @@ class NPUClearFloatStatus(PrimitiveWithInfer):
def infer_shape(self, x_shape):
cls_name = self.name
validator.check_integer("len(x_shape)", len(x_shape), 1, Rel.EQ, cls_name)
validator.check_integer("x_shape[0]", x_shape[0], 8, Rel.EQ, cls_name)
validator.check_equal_int(len(x_shape), 1, "len(x_shape)", cls_name)
validator.check_equal_int(x_shape[0], 8, "x_shape[0]", cls_name)
return [8]
def infer_dtype(self, x_dtype):
@ -3023,9 +3023,9 @@ class NMSWithMask(PrimitiveWithInfer):
def infer_shape(self, bboxes_shape):
cls_name = self.name
validator.check_integer("bboxes rank", len(bboxes_shape), 2, Rel.EQ, cls_name)
validator.check_equal_int(len(bboxes_shape), 2, "bboxes rank", cls_name)
validator.check_positive_int(bboxes_shape[0], "bboxes.shape[0]", cls_name)
validator.check_integer("bboxes.shape[1]", bboxes_shape[1], 5, Rel.EQ, cls_name)
validator.check_equal_int(bboxes_shape[1], 5, "bboxes.shape[1]", cls_name)
num = bboxes_shape[0]
return (bboxes_shape, (num,), (num,))
@ -3572,11 +3572,11 @@ class IFMR(PrimitiveWithInfer):
validator.check_value_type("offset_flag", with_offset, [bool], self.name)
def infer_shape(self, data_shape, data_min_shape, data_max_shape, cumsum_shape):
validator.check_integer("dims of data_min", len(data_min_shape), 1, Rel.EQ, self.name)
validator.check_integer("data_min[0]", data_min_shape[0], 1, Rel.EQ, self.name)
validator.check_integer("dims of data_max", len(data_max_shape), 1, Rel.EQ, self.name)
validator.check_integer("data_max[0]", data_max_shape[0], 1, Rel.EQ, self.name)
validator.check_integer("dims of cumsum", len(cumsum_shape), 1, Rel.EQ, self.name)
validator.check_equal_int(len(data_min_shape), 1, "dims of data_min", self.name)
validator.check_equal_int(data_min_shape[0], 1, "data_min[0]", self.name)
validator.check_equal_int(len(data_max_shape), 1, "dims of data_max", self.name)
validator.check_equal_int(data_max_shape[0], 1, "data_max[0]", self.name)
validator.check_equal_int(len(cumsum_shape), 1, "dims of cumsum", self.name)
return (1,), (1,)
def infer_dtype(self, data_dtype, data_min_dtype, data_max_dtype, cumsum_dtype):

View File

@ -98,7 +98,7 @@ class Flatten(PrimitiveWithInfer):
pass
def infer_shape(self, input_x):
validator.check_integer('input_x rank', len(input_x), 1, Rel.GE, self.name)
validator.check_int(len(input_x), 1, Rel.GE, 'input_x rank', self.name)
prod = 1 if len(input_x) == 1 else reduce(operator.mul, input_x[1:])
return input_x[0], prod
@ -146,7 +146,7 @@ class Softmax(PrimitiveWithInfer):
validator.check_value_type("item of axis", item, [int], self.name)
def infer_shape(self, logits):
validator.check_integer("length of axis", len(self.axis), 1, Rel.GE, self.name)
validator.check_int(len(self.axis), 1, Rel.GE, "length of axis", self.name)
rank = len(logits)
for axis_v in self.axis:
validator.check_int_range(axis_v, -rank, rank, Rel.INC_LEFT, "axis", self.name)
@ -636,7 +636,7 @@ class FusedBatchNorm(Primitive):
def __init__(self, mode=0, epsilon=1e-5, momentum=0.1):
self.init_prim_io_names(inputs=['x', 'scale', 'b', 'mean', 'variance'],
outputs=['y', 'running_mean', 'running_variance', 'save_mean', 'save_inv_variance'])
self.mode = validator.check_integer('mode', mode, [0, 1], Rel.IN, self.name)
self.mode = validator.check_int(mode, [0, 1], Rel.IN, 'mode', self.name)
self.epsilon = validator.check_float_range(epsilon, 0, 1, Rel.INC_RIGHT, 'epsilon', self.name)
self.momentum = validator.check_float_range(momentum, 0, 1, Rel.INC_BOTH, 'momentum', self.name)
self._update_parameter = True
@ -709,17 +709,17 @@ class FusedBatchNormEx(PrimitiveWithInfer):
def __init__(self, mode=0, epsilon=1e-5, momentum=0.1):
self.init_prim_io_names(inputs=['x', 'scale', 'b', 'mean', 'variance'],
outputs=['y', 'save_scale', 'save_bias', 'save_mean', 'save_inv_variance', 'reserve'])
self.mode = validator.check_integer('mode', mode, [0, 1], Rel.IN, self.name)
self.mode = validator.check_int(mode, [0, 1], Rel.IN, 'mode', self.name)
self.epsilon = validator.check_float_range(epsilon, 0, 1, Rel.INC_RIGHT, 'epsilon', self.name)
self.momentum = validator.check_float_range(momentum, 0, 1, Rel.INC_BOTH, 'momentum', self.name)
self._update_parameter = True
self.add_prim_attr('data_format', "NCHW")
def infer_shape(self, input_x, scale, bias, mean, variance):
validator.check_integer("scale rank", len(scale), 1, Rel.EQ, self.name)
validator.check_equal_int(len(scale), 1, "scale rank", self.name)
validator.check("scale shape", scale, "bias shape", bias, Rel.EQ, self.name)
validator.check("scale shape[0]", scale[0], "input_x shape[1]", input_x[1], Rel.EQ, self.name)
validator.check_integer("mean rank", len(mean), 1, Rel.EQ, self.name)
validator.check_equal_int(len(mean), 1, "mean rank", self.name)
validator.check("mean shape", mean, "variance shape", variance, Rel.EQ, self.name)
validator.check("mean shape", mean, "scale shape", scale, Rel.EQ, self.name)
return (input_x, scale, scale, scale, scale, scale)
@ -757,7 +757,7 @@ class BNTrainingReduce(PrimitiveWithInfer):
self.init_prim_io_names(inputs=['x'], outputs=['sum', 'square_sum'])
def infer_shape(self, x_shape):
validator.check_integer("x rank", len(x_shape), 4, Rel.EQ, self.name)
validator.check_equal_int(len(x_shape), 4, "x rank", self.name)
return ([x_shape[1]], [x_shape[1]])
def infer_dtype(self, x_type):
@ -822,13 +822,13 @@ class BNTrainingUpdate(PrimitiveWithInfer):
self.factor = validator.check_float_range(factor, 0, 1, Rel.INC_BOTH, 'factor', 'BNTrainingUpdate')
def infer_shape(self, x, sum, square_sum, scale, b, mean, variance):
validator.check_integer("x rank", len(x), 4, Rel.EQ, self.name)
validator.check_integer("sum rank", len(sum), 1, Rel.EQ, self.name)
validator.check_integer("square_sum rank", len(square_sum), 1, Rel.EQ, self.name)
validator.check_integer("scale rank", len(scale), 1, Rel.EQ, self.name)
validator.check_integer("b rank", len(b), 1, Rel.EQ, self.name)
validator.check_integer("mean rank", len(mean), 1, Rel.EQ, self.name)
validator.check_integer("variance rank", len(variance), 1, Rel.EQ, self.name)
validator.check_equal_int(len(x), 4, "x rank", self.name)
validator.check_equal_int(len(sum), 1, "sum rank", self.name)
validator.check_equal_int(len(square_sum), 1, "square_sum rank", self.name)
validator.check_equal_int(len(scale), 1, "scale rank", self.name)
validator.check_equal_int(len(b), 1, "b rank", self.name)
validator.check_equal_int(len(mean), 1, "mean rank", self.name)
validator.check_equal_int(len(variance), 1, "variance rank", self.name)
validator.check("sum shape", sum, "x_shape[1]", x[1], Rel.EQ, self.name)
validator.check("square_sum shape", square_sum, "sum", sum, Rel.EQ, self.name)
validator.check("scale shape", scale, "x_shape[1]", x[1], Rel.EQ, self.name)
@ -904,11 +904,11 @@ class BatchNorm(PrimitiveWithInfer):
outputs=['y', 'batch_mean', 'batch_variance', 'reserve_space_1', 'reserve_space_2'])
def infer_shape(self, input_x, scale, bias, mean, variance):
validator.check_integer("scale rank", len(scale), 1, Rel.EQ, self.name)
validator.check_equal_int(len(scale), 1, "scale rank", self.name)
validator.check("scale shape", scale, "bias shape", bias, Rel.EQ, self.name)
validator.check("scale shape[0]", scale[0], "input_x shape[1]", input_x[1], Rel.EQ, self.name)
if not self.is_training:
validator.check_integer("mean rank", len(mean), 1, Rel.EQ, self.name)
validator.check_equal_int(len(mean), 1, "mean rank", self.name)
validator.check("mean shape", mean, "variance shape", variance, Rel.EQ, self.name)
validator.check("mean shape", mean, "scale shape", scale, Rel.EQ, self.name)
return (input_x, scale, scale, scale, scale)
@ -1010,7 +1010,7 @@ class Conv2D(PrimitiveWithInfer):
if isinstance(pad, int):
pad = (pad,) * 4
else:
validator.check_integer('pad size', len(pad), 4, Rel.EQ, self.name)
validator.check_equal_int(len(pad), 4, 'pad size', self.name)
self.padding = pad
self.pad_mode = validator.check_string(pad_mode, ['valid', 'same', 'pad'], 'pad_mode', self.name)
@ -1020,15 +1020,15 @@ class Conv2D(PrimitiveWithInfer):
for item in pad:
validator.check_non_negative_int(item, 'pad item', self.name)
self.mode = validator.check_integer('mode', mode, 1, Rel.EQ, self.name)
self.mode = validator.check_equal_int(mode, 1, 'mode', self.name)
self.add_prim_attr('data_format', "NCHW")
self.out_channel = validator.check_positive_int(out_channel, 'out_channel', self.name)
self.group = validator.check_positive_int(group, 'group', self.name)
self.add_prim_attr('offset_a', 0)
def infer_shape(self, x_shape, w_shape, b_shape=None):
validator.check_integer("weight rank", len(w_shape), 4, Rel.EQ, self.name)
validator.check_integer("x rank", len(x_shape), 4, Rel.EQ, self.name)
validator.check_equal_int(len(w_shape), 4, "weight rank", self.name)
validator.check_equal_int(len(x_shape), 4, "x rank", self.name)
validator.check(f"x_shape[1] / group", x_shape[1] // self.group, "w_shape[1]", w_shape[1], Rel.EQ, self.name)
validator.check('out_channel', self.out_channel, 'w_shape[0]', w_shape[0], Rel.EQ, self.name)
validator.check('kernel_size', self.kernel_size, 'w_shape[2:4]', tuple(w_shape[2:4]), Rel.EQ, self.name)
@ -1150,7 +1150,7 @@ class DepthwiseConv2dNative(PrimitiveWithInfer):
if isinstance(pad, int):
pad = (pad,) * 4
else:
validator.check_integer('pad size', len(pad), 4, Rel.EQ, self.name)
validator.check_equal_int(len(pad), 4, 'pad size', self.name)
self.padding = pad
self.pad_mode = validator.check_string(pad_mode, ['valid', 'same', 'pad'], 'pad_mode', self.name)
if pad_mode != 'pad' and pad != (0, 0, 0, 0):
@ -1158,15 +1158,15 @@ class DepthwiseConv2dNative(PrimitiveWithInfer):
if self.pad_mode == 'pad':
for item in pad:
validator.check_non_negative_int(item, 'pad item', self.name)
self.mode = validator.check_integer("mode", mode, 3, Rel.EQ, self.name)
self.mode = validator.check_equal_int(mode, 3, "mode", self.name)
self.add_prim_attr('data_format', "NCHW")
self.channel_multiplier = validator.check_positive_int(channel_multiplier, "channel_multiplier", self.name)
self.group = validator.check_positive_int(group, "group", self.name)
self.add_prim_attr('offset_a', 0)
def infer_shape(self, x_shape, w_shape, b_shape=None):
validator.check_integer("weight rank", len(w_shape), 4, Rel.EQ, self.name)
validator.check_integer("x rank", len(x_shape), 4, Rel.EQ, self.name)
validator.check_equal_int(len(w_shape), 4, "weight rank", self.name)
validator.check_equal_int(len(x_shape), 4, "x rank", self.name)
validator.check("x_shape[1]", x_shape[1], "w_shape[1]", w_shape[1], Rel.EQ, self.name)
validator.check('kernel_size', self.kernel_size, 'w_shape[2:4]', tuple(w_shape[2:4]), Rel.EQ, self.name)
@ -1250,7 +1250,7 @@ class _Pool(PrimitiveWithInfer):
self.add_prim_attr("strides", self.strides)
def infer_shape(self, x_shape):
validator.check_integer("x rank", len(x_shape), 4, Rel.EQ, self.name)
validator.check_equal_int(len(x_shape), 4, "x rank", self.name)
batch, channel, input_h, input_w = x_shape
if self.is_maxpoolwithargmax:
_, kernel_h, kernel_w, _ = self.ksize
@ -1536,7 +1536,7 @@ class Conv2DBackpropInput(PrimitiveWithInfer):
if isinstance(pad, int):
pad = (pad,) * 4
else:
validator.check_integer('pad size', len(pad), 4, Rel.EQ, self.name)
validator.check_equal_int(len(pad), 4, 'pad size', self.name)
self.padding = pad
self.pad_mode = validator.check_string(pad_mode, ['valid', 'same', 'pad'], 'pad_mode', self.name)
if pad_mode != 'pad' and pad != (0, 0, 0, 0):
@ -1547,7 +1547,7 @@ class Conv2DBackpropInput(PrimitiveWithInfer):
pad_mode = pad_mode.upper()
self.add_prim_attr('pad_mode', pad_mode)
self.mode = validator.check_integer('mode', mode, 1, Rel.EQ, self.name)
self.mode = validator.check_equal_int(mode, 1, 'mode', self.name)
self.group = validator.check_positive_int(group, 'group', self.name)
self.add_prim_attr('data_format', "NCHW")
if pad_list:
@ -1624,8 +1624,8 @@ class BiasAdd(PrimitiveWithInfer):
self.add_prim_attr('data_format', 'NCHW')
def infer_shape(self, x_shape, b_shape):
validator.check_integer("x rank", len(x_shape), 2, Rel.GE, self.name)
validator.check_integer("bias rank", len(b_shape), 1, Rel.EQ, self.name)
validator.check_int(len(x_shape), 2, Rel.GE, "x rank", self.name)
validator.check_equal_int(len(b_shape), 1, "bias rank", self.name)
validator.check("b_shape[0]", b_shape[0], "x_shape[1]", x_shape[1], Rel.EQ, self.name)
return x_shape
@ -2007,10 +2007,10 @@ class RNNTLoss(PrimitiveWithInfer):
outputs=['costs', 'grads'])
def infer_shape(self, acts_shape, labels_shape, input_length_shape, label_length_shape):
validator.check_integer('acts_rank', len(acts_shape), 4, Rel.EQ, self.name)
validator.check_integer('labels_rank', len(labels_shape), 2, Rel.EQ, self.name)
validator.check_integer('input_length_rank', len(input_length_shape), 1, Rel.EQ, self.name)
validator.check_integer('label_length_rank', len(label_length_shape), 1, Rel.EQ, self.name)
validator.check_equal_int(len(acts_shape), 4, 'acts_rank', self.name)
validator.check_equal_int(len(labels_shape), 2, 'labels_rank', self.name)
validator.check_equal_int(len(input_length_shape), 1, 'input_length_rank', self.name)
validator.check_equal_int(len(label_length_shape), 1, 'label_length_rank', self.name)
validator.check('labels shape[0]', labels_shape[0], 'acts shape[0]', acts_shape[0], Rel.EQ, self.name)
validator.check('labels shape[1]', labels_shape[1], 'acts shape[2]-1', acts_shape[2]-1, Rel.EQ, self.name)
validator.check('input_length size', input_length_shape[0], 'acts shape[0]', acts_shape[0], Rel.EQ, self.name)
@ -2080,11 +2080,11 @@ class SGD(PrimitiveWithInfer):
def infer_shape(self, parameters_shape, gradient_shape, learning_rate_shape,
accum_shape, momentum_shape, stat_shape):
validator.check_positive_int(len(parameters_shape), "parameters rank", self.name)
validator.check_integer(f'gradient rank', len(gradient_shape), 0, Rel.GE, self.name)
validator.check_integer(f'learning rate rank', len(learning_rate_shape), 0, Rel.GE, self.name)
validator.check_int(len(gradient_shape), 0, Rel.GE, f'gradient rank', self.name)
validator.check_int(len(learning_rate_shape), 0, Rel.GE, f'learning rate rank', self.name)
validator.check_positive_int(len(accum_shape), "accumulation rank", self.name)
validator.check_integer(f'momentum rank', len(momentum_shape), 0, Rel.GE, self.name)
validator.check_integer(f'stat rank', len(stat_shape), 0, Rel.GE, self.name)
validator.check_int(len(momentum_shape), 0, Rel.GE, f'momentum rank', self.name)
validator.check_int(len(stat_shape), 0, Rel.GE, f'stat rank', self.name)
validator.check("gradient shape", gradient_shape, "stat shape", stat_shape, Rel.EQ, self.name)
return parameters_shape
@ -2780,17 +2780,17 @@ class LSTM(PrimitiveWithInfer):
def infer_shape(self, x_shape, h_shape, c_shape, w_shape):
# (seq, batch_size, feature)
validator.check_integer("x rank", len(x_shape), 3, Rel.EQ, self.name)
validator.check_integer("x[2]", x_shape[2], self.input_size, Rel.EQ, self.name)
validator.check_equal_int(len(x_shape), 3, "x rank", self.name)
validator.check_equal_int(x_shape[2], self.input_size, "x[2]", self.name)
# h and c should be same shape
validator.check_integer("h rank", len(h_shape), 3, Rel.EQ, self.name)
validator.check_equal_int(len(h_shape), 3, "h rank", self.name)
validator.check("h_shape", h_shape, "c_shape", c_shape, Rel.EQ, self.name)
# (num_layers * num_directions, batch, hidden_size)
validator.check_integer("h[0]", h_shape[0], self.num_layers * self.num_directions, Rel.EQ, self.name)
validator.check_integer("h[1]", h_shape[1], x_shape[1], Rel.EQ, self.name)
validator.check_integer("h[2]", h_shape[2], self.hidden_size, Rel.EQ, self.name)
validator.check_int(h_shape[0], self.num_layers * self.num_directions, Rel.EQ, "h[0]", self.name)
validator.check_equal_int(h_shape[1], x_shape[1], "h[1]", self.name)
validator.check_int(h_shape[2], self.hidden_size, Rel.EQ, "h[2]", self.name)
y_shape = (x_shape[0], x_shape[1], self.hidden_size * self.num_directions)
@ -2918,7 +2918,7 @@ class Pad(PrimitiveWithInfer):
def infer_shape(self, x):
paddings = np.array(self.paddings)
validator.check_integer('paddings.shape', paddings.size, len(x) * 2, Rel.EQ, self.name)
validator.check_int(paddings.size, len(x) * 2, Rel.EQ, 'paddings.shape', self.name)
if not np.all(paddings >= 0):
raise ValueError('All elements of paddings must be >= 0.')
y_shape = ()
@ -2992,7 +2992,7 @@ class MirrorPad(PrimitiveWithInfer):
x_shape = list(input_x['shape'])
paddings_value = paddings['value'].asnumpy()
paddings_size = paddings_value.size
validator.check_integer('paddings.shape', paddings_size, len(x_shape) * 2, Rel.EQ, self.name)
validator.check_int(paddings_size, len(x_shape) * 2, Rel.EQ, 'paddings.shape', self.name)
if not np.all(paddings_value >= 0):
raise ValueError('All elements of paddings must be >= 0.')
adjust = 0
@ -3276,7 +3276,7 @@ class FusedSparseAdam(PrimitiveWithInfer):
beta1_shape, beta2_shape, epsilon_shape, grad_shape, indices_shape):
validator.check("var_shape", var_shape, "m_shape", m_shape, Rel.EQ, self.name)
validator.check("var_shape", var_shape, "v_shape", v_shape, Rel.EQ, self.name)
validator.check_integer("indices rank", len(indices_shape), 1, Rel.EQ, self.name)
validator.check_int(len(indices_shape), 1, Rel.EQ, "indices rank", self.name)
validator.check('grad_shape[0]', grad_shape[0], 'indices_shape[0]', indices_shape[0], Rel.EQ, self.name)
if len(var_shape) > 1 and grad_shape != indices_shape + var_shape[1:]:
raise ValueError(f"For '{self.name}', the shape of updates should be [] or "
@ -3409,7 +3409,7 @@ class FusedSparseLazyAdam(PrimitiveWithInfer):
beta1_shape, beta2_shape, epsilon_shape, grad_shape, indices_shape):
validator.check("var_shape", var_shape, "m_shape", m_shape, Rel.EQ, self.name)
validator.check("var_shape", var_shape, "v_shape", v_shape, Rel.EQ, self.name)
validator.check_integer("indices rank", len(indices_shape), 1, Rel.EQ, self.name)
validator.check_int(len(indices_shape), 1, Rel.EQ, "indices rank", self.name)
validator.check('grad_shape[0]', grad_shape[0], 'indices_shape[0]', indices_shape[0], Rel.EQ, self.name)
if len(var_shape) > 1 and grad_shape != indices_shape + var_shape[1:]:
raise ValueError(f"For '{self.name}', the shape of updates should be [] or "
@ -3513,7 +3513,7 @@ class FusedSparseFtrl(PrimitiveWithInfer):
validator.check('var shape', var_shape, 'linear shape', linear_shape, Rel.EQ, self.name)
if len(var_shape) > 1:
validator.check('var_shape[1:]', var_shape[1:], 'grad_shape[1:]', grad_shape[1:], Rel.EQ, self.name)
validator.check_integer("indices rank", len(indices_shape), 1, Rel.EQ, self.name)
validator.check_int(len(indices_shape), 1, Rel.EQ, "indices rank", self.name)
validator.check('grad_shape[0]', grad_shape[0], 'indices_shape[0]', indices_shape[0], Rel.EQ, self.name)
return [1], [1], [1]
@ -3602,7 +3602,7 @@ class FusedSparseProximalAdagrad(PrimitiveWithInfer):
self.use_locking = validator.check_value_type("use_locking", use_locking, [bool], self.name)
def infer_shape(self, var_shape, accum_shape, lr_shape, l1_shape, l2_shape, grad_shape, indices_shape):
validator.check_integer("indices rank", len(indices_shape), 1, Rel.EQ, self.name)
validator.check_int(len(indices_shape), 1, Rel.EQ, "indices rank", self.name)
return [1], [1]
def infer_dtype(self, var_dtype, accum_dtype, lr_dtype, l1_dtype, l2_dtype, grad_dtype, indices_dtype):
@ -3869,25 +3869,25 @@ class ApplyAdaMax(PrimitiveWithInfer):
validator.check("v_shape", v_shape, "var_shape", var_shape, Rel.EQ, self.name)
validator.check("grad_shape", grad_shape, "var_shape", var_shape, Rel.EQ, self.name)
beta1_power_shp_len = len(beta1_power_shape)
validator.check_integer("beta1 power's rank", beta1_power_shp_len, 1, Rel.LE, self.name)
validator.check_int(beta1_power_shp_len, 1, Rel.LE, "beta1 power's rank", self.name)
if beta1_power_shp_len == 1:
validator.check_integer("beta1_power_shape[0]", beta1_power_shape[0], 1, Rel.EQ, self.name)
validator.check_int(beta1_power_shape[0], 1, Rel.EQ, "beta1_power_shape[0]", self.name)
lr_shp_len = len(lr_shape)
validator.check_integer("lr's rank", lr_shp_len, 1, Rel.LE, self.name)
validator.check_int(lr_shp_len, 1, Rel.LE, "lr's rank", self.name)
if lr_shp_len == 1:
validator.check_integer("lr_shape[0]", lr_shape[0], 1, Rel.EQ, self.name)
validator.check_int(lr_shape[0], 1, Rel.EQ, "lr_shape[0]", self.name)
beta1_shp_len = len(beta1_shape)
validator.check_integer("beta1's rank", beta1_shp_len, 1, Rel.LE, self.name)
validator.check_int(beta1_shp_len, 1, Rel.LE, "beta1's rank", self.name)
if beta1_shp_len == 1:
validator.check_integer("beta1_shape[0]", beta1_shape[0], 1, Rel.EQ, self.name)
validator.check_int(beta1_shape[0], 1, Rel.EQ, "beta1_shape[0]", self.name)
beta2_shp_len = len(beta2_shape)
validator.check_integer("beta2's rank", beta2_shp_len, 1, Rel.LE, self.name)
validator.check_int(beta2_shp_len, 1, Rel.LE, "beta2's rank", self.name)
if beta2_shp_len == 1:
validator.check_integer("beta2_shape[0]", beta2_shape[0], 1, Rel.EQ, self.name)
validator.check_int(beta2_shape[0], 1, Rel.EQ, "beta2_shape[0]", self.name)
epsilon_shp_len = len(epsilon_shape)
validator.check_integer("epsilon's rank", epsilon_shp_len, 1, Rel.LE, self.name)
validator.check_int(epsilon_shp_len, 1, Rel.LE, "epsilon's rank", self.name)
if epsilon_shp_len == 1:
validator.check_integer("epsilon_shape[0]", epsilon_shape[0], 1, Rel.EQ, self.name)
validator.check_int(epsilon_shape[0], 1, Rel.EQ, "epsilon_shape[0]", self.name)
return var_shape, m_shape, v_shape
def infer_dtype(self, var_dtype, m_dtype, v_dtype, beta1_power_dtype, lr_dtype,
@ -3985,17 +3985,17 @@ class ApplyAdadelta(PrimitiveWithInfer):
validator.check("accum_update_shape", accum_update_shape, "var_shape", var_shape, Rel.EQ, self.name)
validator.check("grad_shape", grad_shape, "var_shape", var_shape, Rel.EQ, self.name)
lr_shp_len = len(lr_shape)
validator.check_integer("lr's rank", lr_shp_len, 1, Rel.LE, self.name)
validator.check_int(lr_shp_len, 1, Rel.LE, "lr's rank", self.name)
if lr_shp_len == 1:
validator.check_integer("lr_shape[0]", lr_shape[0], 1, Rel.EQ, self.name)
validator.check_int(lr_shape[0], 1, Rel.EQ, "lr_shape[0]", self.name)
rho_shp_len = len(rho_shape)
validator.check_integer("rho's rank", rho_shp_len, 1, Rel.LE, self.name)
validator.check_int(rho_shp_len, 1, Rel.LE, "rho's rank", self.name)
if rho_shp_len == 1:
validator.check_integer("rho_shape[0]", rho_shape[0], 1, Rel.EQ, self.name)
validator.check_int(rho_shape[0], 1, Rel.EQ, "rho_shape[0]", self.name)
epsilon_shp_len = len(epsilon_shape)
validator.check_integer("lepsilon's rank", epsilon_shp_len, 1, Rel.LE, self.name)
validator.check_int(epsilon_shp_len, 1, Rel.LE, "lepsilon's rank", self.name)
if epsilon_shp_len == 1:
validator.check_integer("epsilon_shape[0]", epsilon_shape[0], 1, Rel.EQ, self.name)
validator.check_int(epsilon_shape[0], 1, Rel.EQ, "epsilon_shape[0]", self.name)
return var_shape, accum_shape, accum_update_shape
def infer_dtype(self, var_dtype, accum_dtype, accum_update_dtype, lr_dtype, rho_dtype,
@ -4077,9 +4077,9 @@ class ApplyAdagrad(PrimitiveWithInfer):
validator.check('accum shape', accum_shape, 'var shape', var_shape, Rel.EQ, self.name)
validator.check('grad shape', grad_shape, 'var shape', var_shape, Rel.EQ, self.name)
lr_shp_len = len(lr_shape)
validator.check_integer("lr's rank", lr_shp_len, 1, Rel.LE, self.name)
validator.check_int(lr_shp_len, 1, Rel.LE, "lr's rank", self.name)
if lr_shp_len == 1:
validator.check_integer("lr_shape[0]", lr_shape[0], 1, Rel.EQ, self.name)
validator.check_int(lr_shape[0], 1, Rel.EQ, "lr_shape[0]", self.name)
return var_shape, accum_shape
def infer_dtype(self, var_dtype, accum_dtype, lr_dtype, grad_dtype):
@ -4161,9 +4161,9 @@ class ApplyAdagradV2(PrimitiveWithInfer):
validator.check('var shape', var_shape, 'accum shape', accum_shape, Rel.EQ, self.name)
validator.check('var shape', var_shape, 'grad shape', grad_shape, Rel.EQ, self.name)
lr_shp_len = len(lr_shape)
validator.check_integer("lr's rank", lr_shp_len, 1, Rel.LE, self.name)
validator.check_int(lr_shp_len, 1, Rel.LE, "lr's rank", self.name)
if lr_shp_len == 1:
validator.check_integer("lr_shape[0]", lr_shape[0], 1, Rel.EQ, self.name)
validator.check_int(lr_shape[0], 1, Rel.EQ, "lr_shape[0]", self.name)
return var_shape, accum_shape
def infer_dtype(self, var_dtype, accum_dtype, lr_dtype, grad_dtype):
@ -4249,7 +4249,7 @@ class SparseApplyAdagrad(PrimitiveWithInfer):
validator.check('len of var shape', len(var_shape), 'len of grad shape', len(grad_shape), Rel.EQ, self.name)
if len(var_shape) > 1:
validator.check('var_shape[1:]', var_shape[1:], 'grad_shape[1:]', grad_shape[1:], Rel.EQ, self.name)
validator.check_integer("indices rank", len(indices_shape), 1, Rel.EQ, self.name)
validator.check_int(len(indices_shape), 1, Rel.EQ, "indices rank", self.name)
validator.check('grad_shape[0]', grad_shape[0], 'indices_shape[0]', indices_shape[0], Rel.EQ, self.name)
return var_shape, accum_shape
@ -4338,7 +4338,7 @@ class SparseApplyAdagradV2(PrimitiveWithInfer):
validator.check('len of var shape', len(var_shape), 'len of grad shape', len(grad_shape), Rel.EQ, self.name)
if len(var_shape) > 1:
validator.check('var_shape[1:]', var_shape[1:], 'grad_shape[1:]', grad_shape[1:], Rel.EQ, self.name)
validator.check_integer("indices rank", len(indices_shape), 1, Rel.EQ, self.name)
validator.check_int(len(indices_shape), 1, Rel.EQ, "indices rank", self.name)
validator.check('grad_shape[0]', grad_shape[0], 'indices_shape[0]', indices_shape[0], Rel.EQ, self.name)
return var_shape, accum_shape
@ -4428,17 +4428,17 @@ class ApplyProximalAdagrad(PrimitiveWithInfer):
validator.check('accum shape', accum_shape, 'var shape', var_shape, Rel.EQ, self.name)
validator.check('grad shape', grad_shape, 'var shape', var_shape, Rel.EQ, self.name)
lr_shp_len = len(lr_shape)
validator.check_integer("lr's rank", lr_shp_len, 1, Rel.LE, self.name)
validator.check_int(lr_shp_len, 1, Rel.LE, "lr's rank", self.name)
if lr_shp_len == 1:
validator.check_integer("lr_shape[0]", lr_shape[0], 1, Rel.EQ, self.name)
validator.check_int(lr_shape[0], 1, Rel.EQ, "lr_shape[0]", self.name)
l1_shp_len = len(l1_shape)
validator.check_integer("l1's rank", l1_shp_len, 1, Rel.LE, self.name)
validator.check_int(l1_shp_len, 1, Rel.LE, "l1's rank", self.name)
if l1_shp_len == 1:
validator.check_integer("l1_shape[0]", l1_shape[0], 1, Rel.EQ, self.name)
validator.check_int(l1_shape[0], 1, Rel.EQ, "l1_shape[0]", self.name)
l2_shp_len = len(l2_shape)
validator.check_integer("l2's rank", l2_shp_len, 1, Rel.LE, self.name)
validator.check_int(l2_shp_len, 1, Rel.LE, "l2's rank", self.name)
if l2_shp_len == 1:
validator.check_integer("l2_shape[0]", l2_shape[0], 1, Rel.EQ, self.name)
validator.check_int(l2_shape[0], 1, Rel.EQ, "l2_shape[0]", self.name)
return var_shape, accum_shape
def infer_dtype(self, var_dtype, accum_dtype, lr_dtype, l1_dtype, l2_dtype, grad_dtype):
@ -4532,7 +4532,7 @@ class SparseApplyProximalAdagrad(PrimitiveWithCheck):
self.use_locking = validator.check_value_type("use_locking", use_locking, [bool], self.name)
def check_shape(self, var_shape, accum_shape, lr_shape, l1_shape, l2_shape, grad_shape, indices_shape):
validator.check_integer("indices rank", len(indices_shape), 1, Rel.EQ, self.name)
validator.check_int(len(indices_shape), 1, Rel.EQ, "indices rank", self.name)
def check_dtype(self, var_dtype, accum_dtype, lr_dtype, l1_dtype, l2_dtype, grad_dtype, indices_dtype):
args = {'var': var_dtype, 'accum': accum_dtype, 'grad': grad_dtype}
@ -4623,21 +4623,21 @@ class ApplyAddSign(PrimitiveWithInfer):
validator.check('m_shape', m_shape, 'var_shape', var_shape, Rel.EQ, self.name)
validator.check('grad_shape', grad_shape, 'var_shape', var_shape, Rel.EQ, self.name)
lr_shape_len = len(lr_shape)
validator.check_integer("lr's rank", lr_shape_len, 1, Rel.LE, self.name)
validator.check_int(lr_shape_len, 1, Rel.LE, "lr's rank", self.name)
if lr_shape_len == 1:
validator.check_integer("lr_shape[0]", lr_shape[0], 1, Rel.EQ, self.name)
validator.check_int(lr_shape[0], 1, Rel.EQ, "lr_shape[0]", self.name)
alpha_shape_len = len(alpha_shape)
validator.check_integer("alpha's rank", alpha_shape_len, 1, Rel.LE, self.name)
validator.check_int(alpha_shape_len, 1, Rel.LE, "alpha's rank", self.name)
if alpha_shape_len == 1:
validator.check_integer("alpha_shape[0]", alpha_shape[0], 1, Rel.EQ, self.name)
validator.check_int(alpha_shape[0], 1, Rel.EQ, "alpha_shape[0]", self.name)
sign_decay_shape_len = len(sign_decay_shape)
validator.check_integer("sign_decay's rank", sign_decay_shape_len, 1, Rel.LE, self.name)
validator.check_int(sign_decay_shape_len, 1, Rel.LE, "sign_decay's rank", self.name)
if sign_decay_shape_len == 1:
validator.check_integer("sign_decay_shape[0]", sign_decay_shape[0], 1, Rel.EQ, self.name)
validator.check_int(sign_decay_shape[0], 1, Rel.EQ, "sign_decay_shape[0]", self.name)
beta_shape_len = len(beta_shape)
validator.check_integer("beta's rank", beta_shape_len, 1, Rel.LE, self.name)
validator.check_int(beta_shape_len, 1, Rel.LE, "beta's rank", self.name)
if beta_shape_len == 1:
validator.check_integer("beta_shape[0]", beta_shape[0], 1, Rel.EQ, self.name)
validator.check_int(beta_shape[0], 1, Rel.EQ, "beta_shape[0]", self.name)
return var_shape, m_shape
def infer_dtype(self, var_dtype, m_dtype, lr_dtype, alpha_dtype, sign_decay_dtype, beta_dtype, grad_dtype):
@ -4732,21 +4732,21 @@ class ApplyPowerSign(PrimitiveWithInfer):
validator.check('m_shape', m_shape, 'var_shape', var_shape, Rel.EQ, self.name)
validator.check('grad_shape', grad_shape, 'var_shape', var_shape, Rel.EQ, self.name)
lr_shape_len = len(lr_shape)
validator.check_integer("lr's rank", lr_shape_len, 1, Rel.LE, self.name)
validator.check_int(lr_shape_len, 1, Rel.LE, "lr's rank", self.name)
if lr_shape_len == 1:
validator.check_integer("lr_shape[0]", lr_shape[0], 1, Rel.EQ, self.name)
validator.check_int(lr_shape[0], 1, Rel.EQ, "lr_shape[0]", self.name)
logbase_shape_len = len(logbase_shape)
validator.check_integer("logbase's rank", logbase_shape_len, 1, Rel.LE, self.name)
validator.check_int(logbase_shape_len, 1, Rel.LE, "logbase's rank", self.name)
if logbase_shape_len == 1:
validator.check_integer("logbase_shape[0]", logbase_shape[0], 1, Rel.EQ, self.name)
validator.check_int(logbase_shape[0], 1, Rel.EQ, "logbase_shape[0]", self.name)
sign_decay_shape_len = len(sign_decay_shape)
validator.check_integer("sign_decay's rank", sign_decay_shape_len, 1, Rel.LE, self.name)
validator.check_int(sign_decay_shape_len, 1, Rel.LE, "sign_decay's rank", self.name)
if sign_decay_shape_len == 1:
validator.check_integer("sign_decay_shape[0]", sign_decay_shape[0], 1, Rel.EQ, self.name)
validator.check_int(sign_decay_shape[0], 1, Rel.EQ, "sign_decay_shape[0]", self.name)
beta_shape_len = len(beta_shape)
validator.check_integer("beta's rank", beta_shape_len, 1, Rel.LE, self.name)
validator.check_int(beta_shape_len, 1, Rel.LE, "beta's rank", self.name)
if beta_shape_len == 1:
validator.check_integer("beta_shape[0]", beta_shape[0], 1, Rel.EQ, self.name)
validator.check_int(beta_shape[0], 1, Rel.EQ, "beta_shape[0]", self.name)
return var_shape, m_shape
def infer_dtype(self, var_dtype, m_dtype, lr_dtype, logbase_dtype, sign_decay_dtype, beta_dtype, grad_dtype):
@ -4812,9 +4812,9 @@ class ApplyGradientDescent(PrimitiveWithInfer):
def infer_shape(self, var_shape, alpha_shape, delta_shape):
validator.check('delta shape', delta_shape, 'var shape', var_shape, Rel.EQ, self.name)
alpha_shape_len = len(alpha_shape)
validator.check_integer("alpha's rank", alpha_shape_len, 1, Rel.LE, self.name)
validator.check_int(alpha_shape_len, 1, Rel.LE, "alpha's rank", self.name)
if alpha_shape_len == 1:
validator.check_integer("alpha_shape[0]", alpha_shape[0], 1, Rel.EQ, self.name)
validator.check_int(alpha_shape[0], 1, Rel.EQ, "alpha_shape[0]", self.name)
return var_shape
def infer_dtype(self, var_dtype, alpha_dtype, delta_dtype):
@ -4887,17 +4887,17 @@ class ApplyProximalGradientDescent(PrimitiveWithInfer):
def infer_shape(self, var_shape, alpha_shape, l1_shape, l2_shape, delta_shape):
validator.check('delta shape', delta_shape, 'var shape', var_shape, Rel.EQ, self.name)
alpha_shape_len = len(alpha_shape)
validator.check_integer("alpha's rank", alpha_shape_len, 1, Rel.LE, self.name)
validator.check_int(alpha_shape_len, 1, Rel.LE, "alpha's rank", self.name)
if alpha_shape_len == 1:
validator.check_integer("alpha_shape[0]", alpha_shape[0], 1, Rel.EQ, self.name)
validator.check_int(alpha_shape[0], 1, Rel.EQ, "alpha_shape[0]", self.name)
l1_shape_len = len(l1_shape)
validator.check_integer("l1's rank", l1_shape_len, 1, Rel.LE, self.name)
validator.check_int(l1_shape_len, 1, Rel.LE, "l1's rank", self.name)
if l1_shape_len == 1:
validator.check_integer("l1_shape[0]", l1_shape[0], 1, Rel.EQ, self.name)
validator.check_int(l1_shape[0], 1, Rel.EQ, "l1_shape[0]", self.name)
l2_shape_len = len(l2_shape)
validator.check_integer("l2's rank", l2_shape_len, 1, Rel.LE, self.name)
validator.check_int(l2_shape_len, 1, Rel.LE, "l2's rank", self.name)
if l2_shape_len == 1:
validator.check_integer("l2_shape[0]", l2_shape[0], 1, Rel.EQ, self.name)
validator.check_int(l2_shape[0], 1, Rel.EQ, "l2_shape[0]", self.name)
return var_shape
def infer_dtype(self, var_dtype, alpha_dtype, l1_dtype, l2_dtype, delta_dtype):
@ -4965,13 +4965,13 @@ class LARSUpdate(PrimitiveWithInfer):
validator.check("norm weight shape", norm_weight_shape, "norm gradient shape", norm_gradient_shape, Rel.EQ,
self.name)
shp_len = len(weight_decay_shape)
validator.check_integer("weight decay's rank", shp_len, 1, Rel.LE, self.name)
validator.check_int(shp_len, 1, Rel.LE, "weight decay's rank", self.name)
if shp_len == 1:
validator.check_integer("weight_decay_shape[0]", weight_decay_shape[0], 1, Rel.EQ, self.name)
validator.check_int(weight_decay_shape[0], 1, Rel.EQ, "weight_decay_shape[0]", self.name)
shp_len = len(learning_rate_shape)
validator.check_integer("learning rate's rank", shp_len, 1, Rel.LE, self.name)
validator.check_int(shp_len, 1, Rel.LE, "learning rate's rank", self.name)
if shp_len == 1:
validator.check_integer("learning_rate_shape[0]", learning_rate_shape[0], 1, Rel.EQ, self.name)
validator.check_int(learning_rate_shape[0], 1, Rel.EQ, "learning_rate_shape[0]", self.name)
return weight_shape
def infer_dtype(self, weight_dtype, gradient_dtype, norm_weight_dtype, norm_gradient_dtype,
@ -5155,7 +5155,7 @@ class SparseApplyFtrl(PrimitiveWithCheck):
validator.check('var shape', var_shape, 'linear shape', linear_shape, Rel.EQ, self.name)
if len(var_shape) > 1:
validator.check('var_shape[1:]', var_shape[1:], 'grad_shape[1:]', grad_shape[1:], Rel.EQ, self.name)
validator.check_integer("indices rank", len(indices_shape), 1, Rel.EQ, self.name)
validator.check_int(len(indices_shape), 1, Rel.EQ, "indices rank", self.name)
validator.check('grad_shape[0]', grad_shape[0], 'indices_shape[0]', indices_shape[0], Rel.EQ, self.name)
def check_dtype(self, var_dtype, accum_dtype, linear_dtype, grad_dtype, indices_dtype):
@ -5251,7 +5251,7 @@ class SparseApplyFtrlV2(PrimitiveWithInfer):
validator.check('var shape', var_shape, 'linear shape', linear_shape, Rel.EQ, self.name)
if len(var_shape) > 1:
validator.check('var_shape[1:]', var_shape[1:], 'grad_shape[1:]', grad_shape[1:], Rel.EQ, self.name)
validator.check_integer("indices rank", len(indices_shape), 1, Rel.EQ, self.name)
validator.check_int(len(indices_shape), 1, Rel.EQ, "indices rank", self.name)
validator.check('grad_shape[0]', grad_shape[0], 'indices_shape[0]', indices_shape[0], Rel.EQ, self.name)
return var_shape, accum_shape, linear_shape
@ -5288,7 +5288,7 @@ class Dropout(PrimitiveWithInfer):
self.keep_prob = validator.check_float_range(keep_prob, 0, 1, Rel.INC_RIGHT, "keep_prob", self.name)
def infer_shape(self, x_shape):
validator.check_integer("x_shape", len(x_shape), 1, Rel.GE, self.name)
validator.check_int(len(x_shape), 1, Rel.GE, "x_shape", self.name)
mask_shape = x_shape
return x_shape, mask_shape
@ -5352,11 +5352,11 @@ class CTCLoss(PrimitiveWithInfer):
self.ignore_longer_outputs_than_inputs_ = ignore_longer_outputs_than_inputs
def infer_shape(self, inputs, labels_indices, labels_values, sequence_length):
validator.check_integer("inputs rank", len(inputs), 3, Rel.EQ, self.name)
validator.check_integer("labels_indices rank", len(labels_indices), 2, Rel.EQ, self.name)
validator.check_integer("labels_indices dim one", labels_indices[1], 2, Rel.EQ, self.name)
validator.check_integer("labels_values rank", len(labels_values), 1, Rel.EQ, self.name)
validator.check_integer("sequence_length rank", len(sequence_length), 1, Rel.EQ, self.name)
validator.check_int(len(inputs), 3, Rel.EQ, "inputs rank", self.name)
validator.check_int(len(labels_indices), 2, Rel.EQ, "labels_indices rank", self.name)
validator.check_int(labels_indices[1], 2, Rel.EQ, "labels_indices dim one", self.name)
validator.check_int(len(labels_values), 1, Rel.EQ, "labels_values rank", self.name)
validator.check_int(len(sequence_length), 1, Rel.EQ, "sequence_length rank", self.name)
validator.check('labels_indices size', labels_indices[0], 'labels_values size',
labels_values[0], Rel.EQ, self.name)
validator.check('inputs batch_size', inputs[1], 'sequence_length batch_size',
@ -5422,8 +5422,8 @@ class CTCGreedyDecoder(PrimitiveWithInfer):
self.merge_repeated = validator.check_value_type("merge_repeated", merge_repeated, [bool], self.name)
def infer_shape(self, inputs_shape, sequence_length_shape):
validator.check_integer("inputs rank", len(inputs_shape), 3, Rel.EQ, self.name)
validator.check_integer("sequence_length rank", len(sequence_length_shape), 1, Rel.EQ, self.name)
validator.check_int(len(inputs_shape), 3, Rel.EQ, "inputs rank", self.name)
validator.check_int(len(sequence_length_shape), 1, Rel.EQ, "sequence_length rank", self.name)
validator.check('inputs batch_size', inputs_shape[1], 'sequence_length batch_size',
sequence_length_shape[0], Rel.EQ, self.name)
total_decoded_outputs = -1
@ -5517,11 +5517,11 @@ class BasicLSTMCell(PrimitiveWithInfer):
self.add_prim_attr("io_format", "ND")
def infer_shape(self, x_shape, h_shape, c_shape, w_shape, b_shape):
validator.check_integer("x rank", len(x_shape), 2, Rel.EQ, self.name)
validator.check_integer("h rank", len(h_shape), 2, Rel.EQ, self.name)
validator.check_integer("c rank", len(c_shape), 2, Rel.EQ, self.name)
validator.check_integer("w rank", len(w_shape), 2, Rel.EQ, self.name)
validator.check_integer("b rank", len(b_shape), 1, Rel.EQ, self.name)
validator.check_int(len(x_shape), 2, Rel.EQ, "x rank", self.name)
validator.check_int(len(h_shape), 2, Rel.EQ, "h rank", self.name)
validator.check_int(len(c_shape), 2, Rel.EQ, "c rank", self.name)
validator.check_int(len(w_shape), 2, Rel.EQ, "w rank", self.name)
validator.check_int(len(b_shape), 1, Rel.EQ, "b rank", self.name)
validator.check("x_shape[0]", x_shape[0], "h_shape[0]", h_shape[0], Rel.EQ, self.name)
validator.check("c_shape[0]", c_shape[0], "h_shape[0]", h_shape[0], Rel.EQ, self.name)
validator.check("c_shape[1]", c_shape[1], "h_shape[1]", h_shape[1], Rel.EQ, self.name)
@ -5637,11 +5637,11 @@ class DynamicRNN(PrimitiveWithInfer):
self.add_prim_attr("io_format", "ND")
def infer_shape(self, x_shape, w_shape, b_shape, seq_shape, h_shape, c_shape):
validator.check_integer("x_shape", len(x_shape), 3, Rel.EQ, self.name)
validator.check_integer("w rank", len(w_shape), 2, Rel.EQ, self.name)
validator.check_integer("b rank", len(b_shape), 1, Rel.EQ, self.name)
validator.check_integer("h_shape", len(h_shape), 3, Rel.EQ, self.name)
validator.check_integer("c_shape", len(c_shape), 3, Rel.EQ, self.name)
validator.check_int(len(x_shape), 3, Rel.EQ, "x_shape", self.name)
validator.check_int(len(w_shape), 2, Rel.EQ, "w rank", self.name)
validator.check_int(len(b_shape), 1, Rel.EQ, "b rank", self.name)
validator.check_int(len(h_shape), 3, Rel.EQ, "h_shape", self.name)
validator.check_int(len(c_shape), 3, Rel.EQ, "c_shape", self.name)
if seq_shape is not None:
raise ValueError(f"For {self.name}, seq_shape should be None.")
@ -5654,7 +5654,7 @@ class DynamicRNN(PrimitiveWithInfer):
validator.check("w_shape[0]", w_shape[0], "input_size + hidden_size",
input_size + hidden_size, Rel.EQ, self.name)
validator.check("b_shape[0]", b_shape[0], "w_shape[1]", w_shape[1], Rel.EQ, self.name)
validator.check_integer("h_shape[0]", h_shape[0], 1, Rel.EQ, self.name)
validator.check_int(h_shape[0], 1, Rel.EQ, "h_shape[0]", self.name)
validator.check("h_shape[1]", h_shape[1], "batch_size", batch_size, Rel.EQ, self.name)
validator.check("h_shape[2]", h_shape[2], "hidden_size", hidden_size, Rel.EQ, self.name)
validator.check("c_shape", c_shape, "h_shape", h_shape, Rel.EQ, self.name)
@ -5754,5 +5754,5 @@ class LRN(PrimitiveWithInfer):
return x_dtype
def infer_shape(self, x_shape):
validator.check_integer("x_shape", len(x_shape), 4, Rel.EQ, self.name)
validator.check_int(len(x_shape), 4, Rel.EQ, "x_shape", self.name)
return x_shape

View File

@ -98,16 +98,16 @@ class BoundingBoxEncode(PrimitiveWithInfer):
validator.check_value_type("means[%d]" % i, value, [float], self.name)
for i, value in enumerate(stds):
validator.check_value_type("stds[%d]" % i, value, [float], self.name)
validator.check_integer("means len", len(means), 4, Rel.EQ, self.name)
validator.check_integer("stds len", len(stds), 4, Rel.EQ, self.name)
validator.check_equal_int(len(means), 4, "means len", self.name)
validator.check_equal_int(len(stds), 4, "stds len", self.name)
def infer_shape(self, anchor_box, groundtruth_box):
validator.check('anchor_box shape[0]', anchor_box[0], 'groundtruth_box shape[0]', groundtruth_box[0], Rel.EQ,
self.name)
validator.check("anchor_box rank", len(anchor_box), "", 2, Rel.EQ, self.name)
validator.check("groundtruth_box rank", len(groundtruth_box), "", 2, Rel.EQ, self.name)
validator.check_integer('anchor_box shape[1]', anchor_box[1], 4, Rel.EQ, self.name)
validator.check_integer('groundtruth_box shape[1]', groundtruth_box[1], 4, Rel.EQ, self.name)
validator.check_equal_int(anchor_box[1], 4, 'anchor_box shape[1]', self.name)
validator.check_equal_int(groundtruth_box[1], 4, 'groundtruth_box shape[1]', self.name)
return anchor_box
def infer_dtype(self, anchor_box, groundtruth_box):
@ -153,18 +153,18 @@ class BoundingBoxDecode(PrimitiveWithInfer):
for i, value in enumerate(stds):
validator.check_value_type("stds[%d]" % i, value, [float], self.name)
validator.check_value_type('wh_ratio_clip', wh_ratio_clip, [float], self.name)
validator.check_integer("means len", len(means), 4, Rel.EQ, self.name)
validator.check_integer("stds len", len(stds), 4, Rel.EQ, self.name)
validator.check_equal_int(len(means), 4, "means len", self.name)
validator.check_equal_int(len(stds), 4, "stds len", self.name)
if max_shape is not None:
validator.check_value_type('max_shape', max_shape, [tuple], self.name)
validator.check_integer("max_shape len", len(max_shape), 2, Rel.EQ, self.name)
validator.check_equal_int(len(max_shape), 2, "max_shape len", self.name)
def infer_shape(self, anchor_box, deltas):
validator.check('anchor_box shape[0]', anchor_box[0], 'deltas shape[0]', deltas[0], Rel.EQ, self.name)
validator.check("anchor_box rank", len(anchor_box), "", 2, Rel.EQ, self.name)
validator.check("deltas rank", len(deltas), "", 2, Rel.EQ, self.name)
validator.check_integer('anchor_box shape[1]', anchor_box[1], 4, Rel.EQ, self.name)
validator.check_integer('deltas shape[1]', deltas[1], 4, Rel.EQ, self.name)
validator.check_equal_int(anchor_box[1], 4, 'anchor_box shape[1]', self.name)
validator.check_equal_int(deltas[1], 4, 'deltas shape[1]', self.name)
return anchor_box
def infer_dtype(self, anchor_box, deltas):
@ -272,10 +272,10 @@ class IOU(PrimitiveWithInfer):
self.init_prim_io_names(inputs=['anchor_boxes', 'gt_boxes'], outputs=['overlap'])
def infer_shape(self, anchor_boxes, gt_boxes):
validator.check_integer('gt_boxes shape[1]', gt_boxes[1], 4, Rel.EQ, self.name)
validator.check_integer('anchor_boxes shape[1]', anchor_boxes[1], 4, Rel.EQ, self.name)
validator.check_integer('anchor_boxes rank', len(anchor_boxes), 2, Rel.EQ, self.name)
validator.check_integer('gt_boxes rank', len(gt_boxes), 2, Rel.EQ, self.name)
validator.check_equal_int(gt_boxes[1], 4, 'gt_boxes shape[1]', self.name)
validator.check_equal_int(anchor_boxes[1], 4, 'anchor_boxes shape[1]', self.name)
validator.check_equal_int(len(anchor_boxes), 2, 'anchor_boxes rank', self.name)
validator.check_equal_int(len(gt_boxes), 2, 'gt_boxes rank', self.name)
iou = [gt_boxes[0], anchor_boxes[0]]
return iou

View File

@ -356,8 +356,8 @@ class RandomChoiceWithMask(PrimitiveWithInfer):
Validator.check_value_type('seed2', seed2, [int], self.name)
def infer_shape(self, x_shape):
Validator.check_integer("input_x rank", len(x_shape), 1, Rel.GE, self.name)
Validator.check_integer("input_x rank", len(x_shape), 5, Rel.LE, self.name)
Validator.check_int(len(x_shape), 1, Rel.GE, "input_x rank", self.name)
Validator.check_int(len(x_shape), 5, Rel.LE, "input_x rank", self.name)
return ([self.count, len(x_shape)], [self.count])
def infer_dtype(self, x_dtype):

View File

@ -227,7 +227,7 @@ class PrimitiveWithCheck(Primitive):
>>> def __init__(self):
>>> pass
>>> def check_shape(self, input_x):
>>> validator.check_integer('input_x rank', len(input_x), 1, Rel.GE, self.name)
>>> validator.check_int(len(input_x), 1, Rel.GE, 'input_x rank', self.name)
>>>
>>> def check_dtype(self, input_x):
>>> validator.check_subclass("input_x", input_x, mstype.tensor, self.name)

View File

@ -89,12 +89,12 @@ class ConvertToQuantNetwork:
def __init__(self, **kwargs):
self.network = Validator.check_isinstance('network', kwargs["network"], (nn.Cell,))
self.weight_qdelay = Validator.check_integer("quant delay", kwargs["quant_delay"][0], 0, Rel.GE)
self.act_qdelay = Validator.check_integer("quant delay", kwargs["quant_delay"][-1], 0, Rel.GE)
self.weight_qdelay = Validator.check_non_negative_int(kwargs["quant_delay"][0], "quant delay")
self.act_qdelay = Validator.check_int(kwargs["quant_delay"][-1], 0, Rel.GE, "quant delay")
self.bn_fold = Validator.check_bool(kwargs["bn_fold"], "bn fold")
self.freeze_bn = Validator.check_integer("freeze bn", kwargs["freeze_bn"], 0, Rel.GE)
self.weight_bits = Validator.check_integer("weights bit", kwargs["num_bits"][0], 0, Rel.GE)
self.act_bits = Validator.check_integer("activations bit", kwargs["num_bits"][-1], 0, Rel.GE)
self.freeze_bn = Validator.check_non_negative_int(kwargs["freeze_bn"], "freeze bn")
self.weight_bits = Validator.check_non_negative_int(kwargs["num_bits"][0], "weights bit")
self.act_bits = Validator.check_int(kwargs["num_bits"][-1], 0, Rel.GE, "activations bit")
self.weight_channel = Validator.check_bool(kwargs["per_channel"][0], "per channel")
self.act_channel = Validator.check_bool(kwargs["per_channel"][-1], "per channel")
self.weight_symmetric = Validator.check_bool(kwargs["symmetric"][0], "symmetric")

View File

@ -21,7 +21,7 @@ from PIL import Image
from mindspore import log as logger
from ..._checkparam import _check_str_by_regular
from ..._checkparam import Validator
from ..anf_ir_pb2 import DataType, ModelProto
from ..summary_pb2 import Event
@ -47,8 +47,8 @@ def get_event_file_name(prefix, suffix):
Returns:
String, the name of event log file.
"""
_check_str_by_regular(prefix)
_check_str_by_regular(suffix)
Validator.check_str_by_regular(prefix)
Validator.check_str_by_regular(suffix)
file_name = ""
time_second = str(int(time.time()))
hostname = platform.node()

View File

@ -21,7 +21,7 @@ import threading
from mindspore import log as logger
from ..._c_expression import Tensor
from ..._checkparam import _check_str_by_regular
from ..._checkparam import Validator
from .._utils import _check_lineage_value, _check_to_numpy, _make_directory
from ._summary_adapter import get_event_file_name, package_graph_event
from ._writer_pool import WriterPool
@ -103,8 +103,8 @@ class SummaryRecord:
self._closed, self._event_writer = False, None
self._mode, self._data_pool = 'train', _dictlist()
_check_str_by_regular(file_prefix)
_check_str_by_regular(file_suffix)
Validator.check_str_by_regular(file_prefix)
Validator.check_str_by_regular(file_suffix)
self.log_path = _make_directory(log_dir)

View File

@ -12,10 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
""" test checkparameter """
""" test check parameter """
import pytest
import numpy as np
from mindspore._checkparam import twice, Validator
from mindspore._checkparam import Validator, twice
kernel_size = 5
kernel_size1 = twice(kernel_size)

View File

@ -18,7 +18,7 @@ import numpy as np
import pytest
from mindspore import context, Tensor, Parameter, ParameterTuple, nn
from mindspore._checkparam import _check_str_by_regular
from mindspore._checkparam import Validator
from mindspore.common import dtype as mstype
from mindspore.common.initializer import initializer
@ -124,15 +124,15 @@ def test_check_str_by_regular():
str4 = ".12_sf.asdf"
str5 = "12_sf.a$sdf."
str6 = "12+sf.asdf"
_check_str_by_regular(str1)
_check_str_by_regular(str2)
_check_str_by_regular(str3)
Validator.check_str_by_regular(str1)
Validator.check_str_by_regular(str2)
Validator.check_str_by_regular(str3)
with pytest.raises(ValueError):
_check_str_by_regular(str4)
Validator.check_str_by_regular(str4)
with pytest.raises(ValueError):
_check_str_by_regular(str5)
Validator.check_str_by_regular(str5)
with pytest.raises(ValueError):
_check_str_by_regular(str6)
Validator.check_str_by_regular(str6)
def test_parameter_compute():
para_1 = Parameter(initializer('ones', [1, 2, 3], mstype.int32), 'test1')