!344 add prim name to error message for operators in array_ops.py
Merge pull request !344 from fary86/add_prim_name_to_error_message_for_array_ops
commit 348b0ef53c
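Context for the diff below: the commit retires the deprecated ParamValidator helpers in _checkparam.py and routes checks in array_ops.py (plus a few call sites elsewhere) through the newer Validator methods, which take the primitive's name so error messages identify the failing operator. The following is a minimal standalone sketch of that pattern, not MindSpore source; the message wording approximates the formats visible in the hunks below:

    # Hypothetical sketch of the "validator with prim_name" pattern this commit adopts.
    def check_value_type(arg_name, arg_value, valid_types, prim_name=None):
        """Raise a TypeError that names the primitive when the type check fails."""
        if not isinstance(arg_value, tuple(valid_types)):
            prefix = f"For '{prim_name}' the" if prim_name else "The"
            raise TypeError(f"{prefix} type of `{arg_name}` should be one of "
                            f"{[t.__name__ for t in valid_types]}, "
                            f"but got {type(arg_value).__name__}.")
        return arg_value

    check_value_type('axis', 1, [int], 'Argmax')    # passes, returns 1
    check_value_type('axis', 'x', [int], 'Argmax')  # TypeError: For 'Argmax' the type of `axis` ...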
@@ -210,7 +210,7 @@ class Validator:
         type_names = []
         for t in valid_values:
             type_names.append(str(t))
-        types_info = '[' + ", ".join(type_names) + ']'
+        types_info = '[' + ', '.join(type_names) + ']'
         raise TypeError(f'For \'{prim_name}\' type of `{arg_key}` should be in {types_info},'
                         f' but got {elem_type}.')
     return (arg_key, elem_type)
@@ -320,224 +320,6 @@ class Validator:
         raise TypeError(f"{msg_prefix} `{arg_name}` must be float.")
 
 
-class ParamValidator:
-    """Parameter validator. NOTICE: this class will be replaced by `class Validator`"""
-
-    @staticmethod
-    def equal(arg_name, arg_value, cond_str, cond):
-        """Judging valid value."""
-        if not cond:
-            raise ValueError(f'The `{arg_name}` must be {cond_str}, but got {arg_value}.')
-
-    @staticmethod
-    def check(arg_name, arg_value, value_name, value, rel=Rel.EQ):
-        """This method is only used for check int values, since when compare float values,
-        we need consider float error."""
-        rel_fn = Rel.get_fns(rel)
-        if not rel_fn(arg_value, value):
-            rel_str = Rel.get_strs(rel).format(f'{value_name}: {value}')
-            raise ValueError(f'The `{arg_name}` should be {rel_str}, but got {arg_value}.')
-
-    @staticmethod
-    def check_integer(arg_name, arg_value, value, rel):
-        """Integer value judgment."""
-        rel_fn = Rel.get_fns(rel)
-        type_mismatch = not isinstance(arg_value, int) or isinstance(arg_value, bool)
-        if type_mismatch or not rel_fn(arg_value, value):
-            rel_str = Rel.get_strs(rel).format(value)
-            raise ValueError(f'The `{arg_name}` should be an int and must {rel_str}, but got {arg_value}.')
-        return arg_value
-
-    @staticmethod
-    def check_shape_length(arg_name, arg_value, value, rel):
-        """Shape length judgment."""
-        rel_fn = Rel.get_fns(rel)
-        type_mismatch = not isinstance(arg_value, int)
-        if type_mismatch or not rel_fn(arg_value, value):
-            rel_str = Rel.get_strs(rel).format(value)
-            raise ValueError(f'The length of `{arg_name}` should be an int and must {rel_str}, but got {arg_value}')
-        return arg_value
-
-    @staticmethod
-    def check_int_range(arg_name, arg_value, lower_limit, upper_limit, rel):
-        """This method is only used for check int values,
-        since when compare float values, we need consider float error."""
-        rel_fn = Rel.get_fns(rel)
-        type_mismatch = not isinstance(arg_value, int)
-        if type_mismatch or not rel_fn(arg_value, lower_limit, upper_limit):
-            rel_str = Rel.get_strs(rel).format(lower_limit, upper_limit)
-            raise ValueError(f'The `{arg_name}` should be an int in range {rel_str}, but got {arg_value}.')
-        return arg_value
-
-    @staticmethod
-    def check_isinstance(arg_name, arg_value, classes):
-        """Check arg isinstance of classes"""
-        if not isinstance(arg_value, classes):
-            raise ValueError(f'The `{arg_name}` should be isinstance of {classes}, but got {arg_value}.')
-        return arg_value
-
-    @staticmethod
-    def check_number_range(arg_name, arg_value, lower_limit, upper_limit, rel):
-        """Is it necessary to consider error when comparing float values."""
-        rel_fn = Rel.get_fns(rel)
-        if not rel_fn(arg_value, lower_limit, upper_limit):
-            rel_str = Rel.get_strs(rel).format(lower_limit, upper_limit)
-            raise ValueError(f'The `{arg_name}` should be in range {rel_str}, but got {arg_value}.')
-        return arg_value
-
-    @staticmethod
-    def check_subclass(arg_name, type_, template_type, with_type_of=True):
-        """Check whether some type is subclass of another type"""
-        if not isinstance(template_type, Iterable):
-            template_type = (template_type,)
-        if not any([mstype.issubclass_(type_, x) for x in template_type]):
-            type_str = (type(type_).__name__ if isinstance(type_, (tuple, list)) else "") + str(type_)
-            raise TypeError(f'The {"type of" if with_type_of else ""} `{arg_name}` should be subclass'
-                            f' of {",".join((str(x) for x in template_type))}, but got {type_str}.')
-
-    @staticmethod
-    def check_args_tensor(args):
-        """Check whether args are all tensor."""
-        if not isinstance(args, dict):
-            raise TypeError("The args should be a dict.")
-        for arg, value in args.items():
-            ParamValidator.check_subclass(arg, value, mstype.tensor)
-
-    @staticmethod
-    def check_bool(arg_name, arg_value):
-        """Check arg isinstance of bool"""
-        if not isinstance(arg_value, bool):
-            raise ValueError(f'The `{arg_name}` should be isinstance of bool, but got {arg_value}.')
-        return arg_value
-
-    @staticmethod
-    def check_type(arg_name, arg_value, valid_types):
-        """Type checking."""
-        def raise_error_msg():
-            """func for raising error message when check failed"""
-            type_names = [t.__name__ for t in valid_types]
-            num_types = len(valid_types)
-            raise TypeError(f'The type of `{arg_name}` should be {"one of " if num_types > 1 else ""}'
-                            f'{type_names if num_types > 1 else type_names[0]}, but got {type(arg_value).__name__}.')
-
-        if isinstance(arg_value, type(mstype.tensor)):
-            arg_value = arg_value.element_type()
-        # Notice: bool is subclass of int, so `check_type('x', True, [int])` will check fail, and
-        # `check_type('x', True, [bool, int])` will check pass
-        if isinstance(arg_value, bool) and bool not in tuple(valid_types):
-            raise_error_msg()
-        if isinstance(arg_value, tuple(valid_types)):
-            return arg_value
-        raise_error_msg()
-
-    @staticmethod
-    def check_typename(arg_name, arg_type, valid_types):
-        """Does it contain the _name_ attribute."""
-
-        def get_typename(t):
-            return t.__name__ if hasattr(t, '__name__') else str(t)
-
-        if isinstance(arg_type, type(mstype.tensor)):
-            arg_type = arg_type.element_type()
-
-        if arg_type in valid_types:
-            return arg_type
-        type_names = [get_typename(t) for t in valid_types]
-        if len(valid_types) == 1:
-            raise ValueError(f'The type of `{arg_name}` should be {type_names[0]},'
-                             f' but got {get_typename(arg_type)}.')
-        raise ValueError(f'The type of `{arg_name}` should be one of {type_names},'
-                         f' but got {get_typename(arg_type)}.')
-
-    @staticmethod
-    def check_string(arg_name, arg_value, valid_values):
-        """String type judgment."""
-        if isinstance(arg_value, str) and arg_value in valid_values:
-            return arg_value
-        if len(valid_values) == 1:
-            raise ValueError(f'The `{arg_name}` should be str and must be {valid_values[0]},'
-                             f' but got {arg_value}.')
-        raise ValueError(f'The `{arg_name}` should be str and must be one of {valid_values},'
-                         f' but got {arg_value}.')
-
-    @staticmethod
-    def check_type_same(args, valid_values):
-        """Determine whether the types are the same."""
-        name = list(args.keys())[0]
-        value = list(args.values())[0]
-        if isinstance(value, type(mstype.tensor)):
-            value = value.element_type()
-        for arg_name, arg_value in args.items():
-            if isinstance(arg_value, type(mstype.tensor)):
-                arg_value = arg_value.element_type()
-
-            if arg_value not in valid_values:
-                raise TypeError(f'The `{arg_name}` should be in {valid_values},'
-                                f' but `{arg_name}` is {arg_value}.')
-            if arg_value != value:
-                raise TypeError(f'`{arg_name}` should be same as `{name}`,'
-                                f' but `{arg_name}` is {arg_value}, `{name}` is {value}.')
-
-    @staticmethod
-    def check_two_types_same(arg1_name, arg1_type, arg2_name, arg2_type):
-        """Determine whether the types of two variables are the same."""
-        if arg1_type != arg2_type:
-            raise TypeError(f'The type of `{arg1_name}` and `{arg2_name}` should be same.')
-
-    @staticmethod
-    def check_value_on_integer(arg_name, arg_value, value, rel):
-        """Judging integer type."""
-        rel_fn = Rel.get_fns(rel)
-        type_match = isinstance(arg_value, int)
-        if type_match and (not rel_fn(arg_value, value)):
-            rel_str = Rel.get_strs(rel).format(value)
-            raise ValueError(f'The `{arg_name}` should be an int and must {rel_str}, but got {arg_value}.')
-        return arg_value
-
-    @staticmethod
-    def check_param_equal(param1_name, param1_value, param2_name, param2_value):
-        """Judging the equality of parameters."""
-        if param1_value != param2_value:
-            raise ValueError(f"`{param1_name}` must equal `{param2_name}`,"
-                             f" but got `{param1_name}` = {param1_value},"
-                             f" `{param2_name}` = {param2_value}.")
-
-    @staticmethod
-    def check_const_input(arg_name, arg_value):
-        """Check valid value."""
-        if arg_value is None:
-            raise ValueError(f'The `{arg_name}` must be a const input, but got {arg_value}.')
-
-    @staticmethod
-    def check_float_positive(arg_name, arg_value):
-        """Float type judgment."""
-        if isinstance(arg_value, float):
-            if arg_value > 0:
-                return arg_value
-            raise ValueError(f"The `{arg_name}` must be positive, but got {arg_value}.")
-
-        raise TypeError(f"`{arg_name}` must be float!")
-
-    @staticmethod
-    def check_pad_value_by_mode(op_name, pad_mode, padding):
-        """Validate value of padding according to pad_mode"""
-        if pad_mode != 'pad' and padding != 0:
-            raise ValueError(f"For op '{op_name}', padding must be zero when pad_mode is '{pad_mode}'.")
-        return padding
-
-    @staticmethod
-    def check_empty_shape_input(arg_name, arg_value):
-        """Check zeros value."""
-        if 0 in arg_value:
-            raise ValueError(f"Input `{arg_name}` cannot be empty.")
-
-    @staticmethod
-    def check_scalar_shape_input(arg_name, arg_value):
-        """Check scalar shape input."""
-        if arg_value != []:
-            raise ValueError(f"Input `{arg_name}` shape should be (). got {arg_value}")
-
-
 def check_int(input_param):
     """Int type judgment."""
     if isinstance(input_param, int) and not isinstance(input_param, bool):
@@ -653,30 +435,6 @@ def check_output_data(data):
         raise RuntimeError('Executor return data ' + str(data) + ', please check your net or input data.')
 
 
-def check_axis_type_int(axis):
-    """Check axis type."""
-    if not isinstance(axis, int):
-        raise TypeError('Wrong type for axis, should be int.')
-
-
-def check_axis_range(axis, rank):
-    """Check axis range."""
-    if not -rank <= axis < rank:
-        raise ValueError('The axis should be in range [{}, {}),'' but got {}.'.format(-rank, rank, axis))
-
-
-def check_attr_int(attr_name, attr):
-    """Check int type."""
-    if not isinstance(attr, int):
-        raise TypeError("The attr {} should be int, but got {}.".format(attr_name, type(attr)))
-
-
-def check_t_in_range(t):
-    """Check input range."""
-    if t not in (mstype.float16, mstype.float32, mstype.float64, mstype.int32, mstype.int64):
-        raise ValueError("The param T should be (float16, float32, float64, int32, int64).")
-
-
 once = _expand_tuple(1)
 twice = _expand_tuple(2)
 triple = _expand_tuple(3)
@@ -175,7 +175,7 @@ AdjointPtr DFunctor::MapMorphism(const AnfNodePtr &morph) {
   UpdateAdjoint(node_adjoint);
   anfnode_to_adjoin_[morph] = node_adjoint;
   if (cnode_morph->stop_gradient()) {
-    MS_LOG(WARNING) << "MapMorphism node " << morph->ToString() << " is stopped.";
+    MS_LOG(DEBUG) << "MapMorphism node " << morph->ToString() << " is stopped.";
     return node_adjoint;
   }
 
@@ -19,7 +19,6 @@ from mindspore._checkparam import Validator as validator
 from ... import context
 from ..cell import Cell
 from ..._checkparam import Rel
-from ..._checkparam import ParamValidator
 
 
 class _PoolNd(Cell):
@@ -265,11 +264,11 @@ class AvgPool1d(_PoolNd):
                  stride=1,
                  pad_mode="valid"):
         super(AvgPool1d, self).__init__(kernel_size, stride, pad_mode)
-        ParamValidator.check_type('kernel_size', kernel_size, [int,])
-        ParamValidator.check_type('stride', stride, [int,])
-        self.pad_mode = ParamValidator.check_string('pad_mode', pad_mode.upper(), ['VALID', 'SAME'])
-        ParamValidator.check_integer("kernel_size", kernel_size, 1, Rel.GE)
-        ParamValidator.check_integer("stride", stride, 1, Rel.GE)
+        validator.check_value_type('kernel_size', kernel_size, [int], self.cls_name)
+        validator.check_value_type('stride', stride, [int], self.cls_name)
+        self.pad_mode = validator.check_string('pad_mode', pad_mode.upper(), ['VALID', 'SAME'], self.cls_name)
+        validator.check_integer("kernel_size", kernel_size, 1, Rel.GE, self.cls_name)
+        validator.check_integer("stride", stride, 1, Rel.GE, self.cls_name)
         self.kernel_size = (1, kernel_size)
         self.stride = (1, stride)
         self.avg_pool = P.AvgPool(ksize=self.kernel_size,
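Illustrative effect of the pooling change above (hypothetical snippet; the error text is approximated from the validator formats in this diff, not copied from a real run): a bad constructor argument now reports the layer's class name.

    import mindspore.nn as nn
    nn.AvgPool1d(kernel_size='3')
    # TypeError: For 'AvgPool1d' the type of `kernel_size` should be int, but got str.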
@@ -24,7 +24,7 @@ import itertools
 import numbers
 import numpy as np
 
-from ..._checkparam import ParamValidator as validator
+from ..._checkparam import Validator as validator
 from ..._checkparam import Rel
 from ...common import dtype as mstype
 from ...common.tensor import Tensor
@@ -32,12 +32,12 @@ from ..operations.math_ops import _infer_shape_reduce
 from .._utils import _get_concat_offset
 from ..primitive import Primitive, PrimitiveWithInfer, prim_attr_register
 
-def _check_infer_attr_reduce(axis, keep_dims):
-    validator.check_type('keep_dims', keep_dims, [bool])
-    validator.check_type('axis', axis, [int, tuple])
+def _check_infer_attr_reduce(axis, keep_dims, prim_name):
+    validator.check_value_type('keep_dims', keep_dims, [bool], prim_name)
+    validator.check_value_type('axis', axis, [int, tuple], prim_name)
     if isinstance(axis, tuple):
         for index, value in enumerate(axis):
-            validator.check_type('axis[%d]' % index, value, [int])
+            validator.check_value_type('axis[%d]' % index, value, [int], prim_name)
 
 
 class ExpandDims(PrimitiveWithInfer):
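Call sites thread the operator's name through the new prim_name parameter; the reduce-style operators later in this diff do it as shown here (pattern taken from the ArgMaxWithValue hunk below):

    # inside an operator's __init__, self.name identifies the primitive
    _check_infer_attr_reduce(axis, keep_dims, self.name)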
@@ -74,13 +74,11 @@ class ExpandDims(PrimitiveWithInfer):
         self.init_prim_io_names(inputs=['x', 'axis'], outputs=['output'])
 
     def __infer__(self, x, axis):
-        validator.check_subclass("input_x", x['dtype'], mstype.tensor)
+        validator.check_subclass("input_x", x['dtype'], mstype.tensor, self.name)
         x_shape = list(x['shape'])
         axis_v = axis['value']
         rank = len(x_shape)
-        validator.check_const_input('axis', axis_v)
-        validator.check_type("axis", axis_v, [int])
-        validator.check_int_range('axis', axis_v, -rank - 1, rank, Rel.INC_BOTH)
+        validator.check_int_range('axis', axis_v, -rank - 1, rank, Rel.INC_BOTH, self.name)
         if axis_v < 0:
             axis_v = rank + 1 + axis_v
         x_shape.insert(axis_v, 1)
@@ -110,7 +108,7 @@ class DType(PrimitiveWithInfer):
         """init DType"""
 
     def __infer__(self, x):
-        validator.check_subclass("input_x", x['dtype'], mstype.tensor)
+        validator.check_subclass("input_x", x['dtype'], mstype.tensor, self.name)
         out = {'shape': (),
                'dtype': mstype.type_type,
                'value': x['dtype'].element_type()}
@@ -144,19 +142,17 @@ class SameTypeShape(PrimitiveWithInfer):
 
     def __call__(self, x, y):
         """run in PyNative mode"""
-        if x.dtype() != y.dtype():
-            raise TypeError(f"The {x} and {y} should be same dtype.")
-        if x.shape() != y.shape():
-            raise TypeError(f"The {x} and {y} should have same shape.")
+        validator.check_subclass('x', x.dtype(), mstype.tensor, self.name)
+        validator.check_subclass('y', y.dtype(), mstype.tensor, self.name)
+        validator.check('x dtype', x.dtype(), 'y dtype', y.dtype(), Rel.EQ, self.name, TypeError)
+        validator.check('x shape', x.shape(), 'y shape', y.shape(), Rel.EQ, self.name)
         return x
 
     def __infer__(self, x, y):
-        if x['dtype'] != y['dtype']:
-            raise TypeError(f"The {x} and {y} should be same dtype,"
-                            f" but got {x['dtype']} {y['dtype']}.")
-        if x['shape'] != y['shape']:
-            raise ValueError(f"The {x} and {y} should be same shape,"
-                             f" but got {x['shape']} {y['shape']}.")
+        validator.check_subclass('x', x['dtype'], mstype.tensor, self.name)
+        validator.check_subclass('y', y['dtype'], mstype.tensor, self.name)
+        validator.check('x dtype', x['dtype'], 'y dtype', y['dtype'], Rel.EQ, self.name, TypeError)
+        validator.check('x shape', x['shape'], 'y shape', y['shape'], Rel.EQ, self.name)
         return x
 
 
@@ -191,8 +187,8 @@ class Cast(PrimitiveWithInfer):
         src_type = x['dtype']
         dst_type = t['value']
 
-        validator.check_subclass("input_x", src_type, [mstype.tensor, mstype.number])
-        validator.check_subclass("type", dst_type, mstype.number, with_type_of=False)
+        validator.check_subclass("input_x", src_type, [mstype.tensor, mstype.number], self.name)
+        validator.check_subclass("type", dst_type, mstype.number, self.name)
 
         if isinstance(src_type, type(mstype.tensor)):
             src_type = x['dtype'].element_type()
@@ -238,8 +234,8 @@ class IsSubClass(PrimitiveWithInfer):
         sub_type_t = sub_type['value']
         type_v = type_['value']
 
-        validator.check_type("sub_type", sub_type_t, [mstype.Type])
-        validator.check_type("type_", type_v, [mstype.Type])
+        validator.check_value_type("sub_type", sub_type_t, [mstype.Type], self.name)
+        validator.check_value_type("type_", type_v, [mstype.Type], self.name)
 
         value = mstype.issubclass_(sub_type_t, type_v)
 
@@ -273,8 +269,8 @@ class IsInstance(PrimitiveWithInfer):
         sub_type_t = inst['dtype']
         type_v = type_['value']
 
-        validator.check_const_input("inst", inst['value'])
-        validator.check_type("type_", type_v, [mstype.Type])
+        validator.check_const_input("inst", inst['value'], self.name)
+        validator.check_value_type("type_", type_v, [mstype.Type], self.name)
 
         value = mstype.issubclass_(sub_type_t, type_v)
 
@@ -316,14 +312,13 @@ class Reshape(PrimitiveWithInfer):
     def __infer__(self, x, shape):
         shape_v = shape['value']
         x_shp = x['shape']
-        validator.check_subclass("x", x['dtype'], mstype.tensor)
-        validator.check_const_input("shape", shape_v)
-        validator.check_type("shape", shape_v, [tuple])
+        validator.check_subclass("x", x['dtype'], mstype.tensor, self.name)
+        validator.check_value_type("shape", shape_v, [tuple], self.name)
         shape_v = list(shape_v)
         neg_index = -1
         dim_prod = 1
         for i, shp_i in enumerate(shape_v):
-            validator.check_type("shape[%d]" % i, shp_i, [int])
+            validator.check_value_type("shape[%d]" % i, shp_i, [int], self.name)
             if shp_i == -1:
                 if neg_index != -1:
                     raise ValueError(f'The shape can only has one -1 at most, but {shape_v}.')
@@ -332,7 +327,7 @@ class Reshape(PrimitiveWithInfer):
                 dim_prod *= shp_i
         arr_prod = np.prod(x_shp)
         if dim_prod <= 0 or arr_prod % dim_prod != 0:
-            raise ValueError(f'The product of shape should > 0 and'
+            raise ValueError(f'For \'{self.name}\' the product of shape should > 0 and'
                              f' can be divided by prod of input {arr_prod},'
                              f' but shape {shape}, product of shape {dim_prod}.')
 
@@ -340,7 +335,7 @@ class Reshape(PrimitiveWithInfer):
             shape_v[neg_index] = int(arr_prod / dim_prod)
             dim_prod *= shape_v[neg_index]
         if dim_prod != arr_prod:
-            raise ValueError(f'The shape arg for reshape must match array''s size'
+            raise ValueError(f'For \'{self.name}\' The shape arg for reshape must match array''s size'
                              f' input shape {arr_prod}, shape {dim_prod}.')
 
         value = None
@@ -406,10 +401,10 @@ class Squeeze(PrimitiveWithInfer):
     def __init__(self, axis=()):
         """init Squeeze"""
         self.init_prim_io_names(inputs=['x'], outputs=['output'])
-        validator.check_type('axis', axis, [int, tuple])
+        validator.check_value_type('axis', axis, [int, tuple], self.name)
         if isinstance(axis, tuple):
-            for item in axis:
-                validator.check_type("item", item, [int])
+            for idx, item in enumerate(axis):
+                validator.check_value_type("axis[%d]" % idx, item, [int], self.name)
         else:
             self.axis = (axis,)
             self.add_prim_attr("axis", (axis,))
@@ -422,14 +417,14 @@ class Squeeze(PrimitiveWithInfer):
             ret = [d for d in x_shape if d != 1]
         else:
             for a in axis:
-                validator.check_int_range('axis or its elements', a, -ndim, ndim - 1, Rel.INC_BOTH)
+                validator.check_int_range('axis or its elements', a, -ndim, ndim - 1, Rel.INC_BOTH, self.name)
                 if x_shape[a] != 1:
                     raise ValueError('Cannot select an axis to squeeze out which has size not equal to one.')
             ret = [x_shape[i] for i in range(ndim) if not (i in axis or (i - ndim) in axis)]
         return ret
 
     def infer_dtype(self, x_dtype):
-        validator.check_subclass("x", x_dtype, mstype.tensor)
+        validator.check_subclass("x", x_dtype, mstype.tensor, self.name)
         return x_dtype
 
 
@@ -467,14 +462,13 @@ class Transpose(PrimitiveWithInfer):
         if len(x_shape) != len(p_value):
             raise ValueError('The dimension of x and perm must be equal.')
 
-        validator.check_const_input("perm", p_value)
-        validator.check_type("p_value", p_value, [tuple])
-        validator.check_subclass("x_type", x_type, mstype.tensor)
+        validator.check_value_type("p_value", p_value, [tuple], self.name)
+        validator.check_subclass("x_type", x_type, mstype.tensor, self.name)
 
         tmp = list(p_value)
         for i, dim in enumerate(p_value):
-            validator.check_integer("perm[%d]" % i, dim, 0, Rel.GE)
-            validator.check_integer("perm[%d]" % i, dim, len(p_value), Rel.LT)
+            validator.check_integer("perm[%d]" % i, dim, 0, Rel.GE, self.name)
+            validator.check_integer("perm[%d]" % i, dim, len(p_value), Rel.LT, self.name)
             tmp.remove(dim)
             if dim in tmp:
                 raise ValueError('The value of perm is wrong.')
@@ -517,15 +511,13 @@ class GatherV2(PrimitiveWithInfer):
         self.init_prim_io_names(inputs=['params', 'indices', 'axis'], outputs=['output'])
 
     def __infer__(self, params, indices, axis):
-        validator.check_subclass("params", params['dtype'], mstype.tensor)
-        validator.check_subclass("indices", indices['dtype'], mstype.tensor)
-        validator.check_subclass("axis", axis['dtype'], mstype.int_)
-        validator.check_typename("element of indices", indices['dtype'], mstype.int_type)
-        validator.check_const_input("axis", axis['value'])
+        validator.check_subclass("params", params['dtype'], mstype.tensor, self.name)
+        validator.check_tensor_type_same({"indices": indices['dtype']}, mstype.int_type, self.name)
+        validator.check_subclass("axis", axis['dtype'], mstype.int_, self.name)
         axis_v = axis['value']
         params_shp = params['shape']
         rank = len(params_shp)
-        validator.check_int_range("axis", axis_v, -rank, rank, Rel.INC_LEFT)
+        validator.check_int_range("axis", axis_v, -rank, rank, Rel.INC_LEFT, self.name)
         if axis_v < 0:
             axis_v += rank
         out_shape = params_shp[:axis_v] + indices['shape'] + params_shp[axis_v + 1:]
@@ -564,19 +556,20 @@ class Split(PrimitiveWithInfer):
     @prim_attr_register
     def __init__(self, axis=0, output_num=1):
         """init Split"""
-        validator.check_type("axis", axis, [int])
-        validator.check_type("output_num", output_num, [int])
+        validator.check_value_type("axis", axis, [int], self.name)
+        validator.check_value_type("output_num", output_num, [int], self.name)
         self.axis = axis
         self.output_num = output_num
 
     def __infer__(self, x):
-        validator.check_subclass("x", x['dtype'], mstype.tensor)
+        validator.check_subclass("x", x['dtype'], mstype.tensor, self.name)
         x_shape = list(x['shape'])
         dim = len(x_shape)
-        validator.check_int_range('axis value', self.axis, -dim, dim, Rel.INC_LEFT)
-        validator.check_integer("output_num", self.output_num, 0, Rel.GT)
+        validator.check_int_range('axis value', self.axis, -dim, dim, Rel.INC_LEFT, self.name)
+        validator.check_integer("output_num", self.output_num, 0, Rel.GT, self.name)
         output_valid_check = x_shape[self.axis] % self.output_num
-        validator.check_integer("the dimension which to split divides output_num", output_valid_check, 0, Rel.EQ)
+        validator.check_integer("the dimension which to split divides output_num", output_valid_check, 0, Rel.EQ,
+                                self.name)
         x_shape[self.axis] = int(x_shape[self.axis] / self.output_num)
         out_shapes = []
         out_dtypes = []
@@ -615,7 +608,7 @@ class Rank(PrimitiveWithInfer):
         """init Rank"""
 
     def __infer__(self, x):
-        validator.check_subclass("x", x['dtype'], mstype.tensor)
+        validator.check_subclass("x", x['dtype'], mstype.tensor, self.name)
         out = {'shape': None,
                'dtype': None,
                'value': len(x['shape'])}
@@ -647,15 +640,14 @@ class TruncatedNormal(PrimitiveWithInfer):
     @prim_attr_register
     def __init__(self, seed=0, dtype=mstype.float32):
         """init TruncatedNormal"""
-        validator.check_type('seed', seed, [int])
-        validator.check_typename('dtype', dtype, mstype.number_type)
+        validator.check_value_type('seed', seed, [int], self.name)
+        validator.check_type_same({'dtype': dtype}, mstype.number_type, self.name)
 
     def __infer__(self, shape):
         shape_value = shape['value']
-        validator.check_const_input("shape", shape_value)
-        validator.check_type("shape", shape_value, [tuple])
+        validator.check_value_type("shape", shape_value, [tuple], self.name)
         for i, value in enumerate(shape_value):
-            validator.check_integer(f'{i}th value of shape', value, 0, Rel.GT)
+            validator.check_integer(f'{i}th value of shape', value, 0, Rel.GT, self.name)
         out = {'shape': shape_value,
                'dtype': mstype.tensor_type(self.dtype),
                'value': None}
@@ -687,7 +679,7 @@ class Size(PrimitiveWithInfer):
 
     def __infer__(self, x):
         size = 1
-        validator.check_subclass("x", x['dtype'], mstype.tensor)
+        validator.check_subclass("x", x['dtype'], mstype.tensor, self.name)
         shp = x['shape']
         if not shp:
             size = 0
@@ -723,25 +715,20 @@ class Fill(PrimitiveWithInfer):
         """init Fill"""
 
     def __infer__(self, dtype, dims, x):
-        validator.check_const_input("type", dtype['value'])
-        validator.check_const_input("shape", dims['value'])
-        validator.check_const_input("value", x['value'])
-        validator.check_type("shape", dims['value'], [tuple])
-        validator.check_type("value", x['value'], [numbers.Number, bool])
-        for item in dims['value']:
-            validator.check_type("item", item, [int])
-            validator.check_integer("item", item, 0, Rel.GT)
-        x_dtype = dtype['value']
+        validator.check_value_type("shape", dims['value'], [tuple], self.name)
+        validator.check_value_type("value", x['value'], [numbers.Number, bool], self.name)
+        for idx, item in enumerate(dims['value']):
+            validator.check_integer("dims[%d]" % idx, item, 0, Rel.GT, self.name)
         valid_types = [mstype.bool_, mstype.int8, mstype.int32, mstype.int64,
                        mstype.uint8, mstype.uint32, mstype.uint64,
                        mstype.float16, mstype.float32, mstype.float64]
-        validator.check_typename("value", x_dtype, valid_types)
-        x_nptype = mstype.dtype_to_nptype(x_dtype)
+        validator.check_type_same({"value": dtype['value']}, valid_types, self.name)
+        x_nptype = mstype.dtype_to_nptype(dtype['value'])
         ret = np.full(dims['value'], x['value'], x_nptype)
         out = {
             'value': Tensor(ret),
             'shape': dims['value'],
-            'dtype': x_dtype,
+            'dtype': x['dtype'],
         }
         return out
 
@@ -772,8 +759,7 @@ class OnesLike(PrimitiveWithInfer):
         return x_shape
 
     def infer_dtype(self, x_dtype):
-        validator.check_subclass("x", x_dtype, mstype.tensor)
-        validator.check_typename('x_dtype', x_dtype, mstype.number_type + (mstype.bool_,))
+        validator.check_tensor_type_same({'x': x_dtype}, mstype.number_type + (mstype.bool_,), self.name)
        return x_dtype
 
 
@@ -804,8 +790,7 @@ class ZerosLike(PrimitiveWithInfer):
         return x_shape
 
     def infer_dtype(self, x_dtype):
-        validator.check_subclass("x", x_dtype, mstype.tensor)
-        validator.check_typename('x_dtype', x_dtype, mstype.number_type + (mstype.bool_,))
+        validator.check_tensor_type_same({'x': x_dtype}, mstype.number_type + (mstype.bool_,), self.name)
         return x_dtype
 
 
@@ -830,14 +815,13 @@ class TupleToArray(PrimitiveWithInfer):
         """init TupleToArray"""
 
     def infer_value(self, x):
-        validator.check_const_input("x", x)
-        validator.check_type("x", x, [tuple])
-        validator.check("size of x", len(x), '', 0, Rel.GT)
+        validator.check_value_type("x", x, [tuple], self.name)
+        validator.check("size of x", len(x), '', 0, Rel.GT, self.name)
         dtype = type(x[0])
         for i, item in enumerate(x):
-            validator.check_type(f"x[{i}]", item, [numbers.Number])
+            validator.check_value_type(f"x[{i}]", item, [numbers.Number], self.name)
         if not all(isinstance(item, dtype) for item in x):
-            raise TypeError("All elements of input x must be have same type.")
+            raise TypeError("For \'{self.name}\' all elements of input x must be have same type.")
         if isinstance(x[0], int):
             ret = np.array(x, np.int32)
         else:
@@ -867,8 +851,7 @@ class ScalarToArray(PrimitiveWithInfer):
         pass
 
     def infer_value(self, x):
-        validator.check_const_input("x", x)
-        validator.check_type("x", x, [int, float])
+        validator.check_value_type("x", x, [int, float], self.name)
         if isinstance(x, int):
             ret = np.array(x, np.int32)
         else:
@@ -899,9 +882,8 @@ class ScalarToTensor(PrimitiveWithInfer):
         pass
 
     def infer_value(self, x, dtype=mstype.float32):
-        validator.check_const_input("x", x)
-        validator.check_type("x", x, [int, float])
-        validator.check_subclass("dtype", dtype, mstype.number, with_type_of=False)
+        validator.check_value_type("x", x, [int, float], self.name)
+        validator.check_subclass("dtype", dtype, mstype.number, self.name)
         data_type = mstype.dtype_to_nptype(dtype)
         return Tensor(np.array(x, data_type))
 
@@ -943,15 +925,14 @@ class InvertPermutation(PrimitiveWithInfer):
     def __infer__(self, x):
         x_shp = x['shape']
         x_value = x['value']
-        validator.check_const_input("shape", x_shp)
-        validator.check_type("shape", x_shp, [tuple])
+        validator.check_value_type("shape", x_shp, [tuple], self.name)
         z = [x_value[i] for i in range(len(x_value))]
         z.sort()
 
         y = [None]*len(x_value)
         for i, value in enumerate(x_value):
-            validator.check_type("input[%d]" % i, value, [int])
-            validator.check(f'value', z[i], f'index', i)
+            validator.check_value_type("input[%d]" % i, value, [int], self.name)
+            validator.check(f'value', z[i], f'index', i, Rel.EQ, self.name)
             y[value] = i
             z.append(value)
         return {'shape': x_shp,
@@ -986,8 +967,8 @@ class Argmax(PrimitiveWithInfer):
     def __init__(self, axis=-1, output_type=mstype.int64):
         """init Argmax"""
         self.init_prim_io_names(inputs=['x'], outputs=['output'])
-        validator.check_type("axis", axis, [int])
-        validator.check_typename('output_type', output_type, [mstype.int32, mstype.int64])
+        validator.check_value_type("axis", axis, [int], self.name)
+        validator.check_type_same({'output': output_type}, [mstype.int32, mstype.int64], self.name)
         self.axis = axis
         self.add_prim_attr('output_type', output_type)
 
@@ -996,14 +977,13 @@ class Argmax(PrimitiveWithInfer):
         if axis is None:
             axis = 0
         x_rank = len(x_shape)
-        validator.check_int_range("axis", axis, -x_rank, x_rank, Rel.INC_LEFT)
+        validator.check_int_range("axis", axis, -x_rank, x_rank, Rel.INC_LEFT, self.name)
         axis = axis + x_rank if axis < 0 else axis
         ouput_shape = [x_shape[i] for i in range(x_rank) if i != axis]
         return ouput_shape
 
     def infer_dtype(self, x_dtype):
-        validator.check_subclass("input_x", x_dtype, mstype.tensor)
-        validator.check_typename('input_x', x_dtype, [mstype.float32, mstype.float16])
+        validator.check_subclass("input_x", x_dtype, mstype.tensor, self.name)
         return mstype.tensor_type(self.output_type)
 
 
@@ -1035,7 +1015,7 @@ class Argmin(PrimitiveWithInfer):
     def __init__(self, axis=-1, output_type=mstype.int64):
         """init Argmin"""
         self.init_prim_io_names(inputs=['x'], outputs=['output'])
-        validator.check_type("axis", axis, [int])
+        validator.check_value_type("axis", axis, [int], self.name)
         self.axis = axis
         self.add_prim_attr('output_type', output_type)
 
@@ -1044,13 +1024,13 @@ class Argmin(PrimitiveWithInfer):
         if axis is None:
             axis = 0
         x_rank = len(x_shape)
-        validator.check_int_range("axis", axis, -x_rank, x_rank, Rel.INC_LEFT)
+        validator.check_int_range("axis", axis, -x_rank, x_rank, Rel.INC_LEFT, self.name)
         axis = axis + x_rank if axis < 0 else axis
         ouput_shape = [x_shape[i] for i in range(x_rank) if i != axis]
         return ouput_shape
 
     def infer_dtype(self, x_dtype):
-        validator.check_subclass("input_x", x_dtype, mstype.tensor)
+        validator.check_subclass("input_x", x_dtype, mstype.tensor, self.name)
         return mstype.tensor_type(self.output_type)
 
 
@@ -1087,17 +1067,17 @@ class ArgMaxWithValue(PrimitiveWithInfer):
         """init ArgMaxWithValue"""
         self.axis = axis
         self.keep_dims = keep_dims
-        _check_infer_attr_reduce(axis, keep_dims)
+        _check_infer_attr_reduce(axis, keep_dims, self.name)
 
     def infer_shape(self, x_shape):
         axis = self.axis
         x_rank = len(x_shape)
-        validator.check_int_range("axis", axis, -x_rank, x_rank, Rel.INC_LEFT)
+        validator.check_int_range("axis", axis, -x_rank, x_rank, Rel.INC_LEFT, self.name)
         ouput_shape = _infer_shape_reduce(x_shape, self.axis, self.keep_dims, self.name)
         return ouput_shape, ouput_shape
 
     def infer_dtype(self, x_dtype):
-        validator.check_subclass("input_x", x_dtype, mstype.tensor)
+        validator.check_subclass("input_x", x_dtype, mstype.tensor, self.name)
         return mstype.tensor_type(mstype.int32), x_dtype
 
 
@@ -1133,17 +1113,17 @@ class ArgMinWithValue(PrimitiveWithInfer):
         """init ArgMinWithValue"""
         self.axis = axis
         self.keep_dims = keep_dims
-        _check_infer_attr_reduce(axis, keep_dims)
+        _check_infer_attr_reduce(axis, keep_dims, self.name)
 
     def infer_shape(self, x_shape):
         axis = self.axis
         x_rank = len(x_shape)
-        validator.check_int_range("axis", axis, -x_rank, x_rank, Rel.INC_LEFT)
+        validator.check_int_range("axis", axis, -x_rank, x_rank, Rel.INC_LEFT, self.name)
         ouput_shape = _infer_shape_reduce(x_shape, self.axis, self.keep_dims, self.name)
         return ouput_shape, ouput_shape
 
     def infer_dtype(self, x_dtype):
-        validator.check_subclass("input_x", x_dtype, mstype.tensor)
+        validator.check_subclass("input_x", x_dtype, mstype.tensor, self.name)
         return mstype.tensor_type(mstype.int32), x_dtype
 
 
@@ -1183,13 +1163,11 @@ class Tile(PrimitiveWithInfer):
     def __infer__(self, x, multiples):
         multiples_v = multiples['value']
         x_shp = x['shape']
-        validator.check_const_input("shape", multiples_v)
-        validator.check_type("shape", multiples_v, [tuple])
+        validator.check_value_type("shape", multiples_v, [tuple], self.name)
         for i, multiple in enumerate(multiples_v):
-            validator.check_type("multiples[%d]" % i, multiple, [int])
-        validator.check_typename('x', x['dtype'],
-                                 [mstype.int16, mstype.int32, mstype.bool_,
-                                  mstype.float16, mstype.float32])
+            validator.check_value_type("multiples[%d]" % i, multiple, [int], self.name)
+        valid_types = [mstype.int16, mstype.int32, mstype.bool_, mstype.float16, mstype.float32]
+        validator.check_tensor_type_same({'x': x['dtype']}, valid_types, self.name)
         len_sub = len(multiples_v) - len(x_shp)
         multiples_w = None
         if len_sub == 0:
@@ -1199,7 +1177,8 @@ class Tile(PrimitiveWithInfer):
             x_shp.insert(0, 1)
             multiples_w = multiples_v
         elif len_sub < 0:
-            raise ValueError("The length of multiples can not be smaller than the length of dimension in input_x.")
+            raise ValueError(f'For \'{self.name}\' the length of multiples can not be smaller than '
+                             f'the length of dimension in input_x.')
         for i, a in enumerate(multiples_w):
             x_shp[i] *= a
         value = None
@@ -1246,23 +1225,23 @@ class UnsortedSegmentSum(PrimitiveWithInfer):
     def __infer__(self, x, segment_ids, num_segments):
         x_type = x['dtype']
         x_shp = x['shape']
-        validator.check_subclass("input_x", x_type, mstype.tensor)
-        validator.check_type("x_shape", x_shp, [list])
+        validator.check_subclass("input_x", x_type, mstype.tensor, self.name)
+        validator.check_value_type("x_shape", x_shp, [list], self.name)
         x_shp_len = len(x_shp)
-        validator.check_integer("rank of input_x", x_shp_len, 0, Rel.GT)
+        validator.check_integer("rank of input_x", x_shp_len, 0, Rel.GT, self.name)
         segment_ids_shp = segment_ids['shape']
         segment_ids_type = segment_ids['dtype']
-        validator.check_subclass("segment_ids", segment_ids_type, mstype.tensor)
-        validator.check_type("segment_ids", segment_ids_shp, [list])
+        validator.check_subclass("segment_ids", segment_ids_type, mstype.tensor, self.name)
+        validator.check_value_type("segment_ids", segment_ids_shp, [list], self.name)
         segment_ids_shp_len = len(segment_ids_shp)
-        validator.check_integer("rank of segment_ids", segment_ids_shp_len, 0, Rel.GT)
+        validator.check_integer("rank of segment_ids", segment_ids_shp_len, 0, Rel.GT, self.name)
         validator.check(f'rank of input_x', len(x_shp),
-                        'rank of segments_id', len(segment_ids_shp), Rel.GE)
+                        'rank of segments_id', len(segment_ids_shp), Rel.GE, self.name)
         for i, value in enumerate(segment_ids_shp):
-            validator.check("ids[%d]" % i, value, 'input[%d]' % i, x_shp[i])
+            validator.check("ids[%d]" % i, value, 'input[%d]' % i, x_shp[i], Rel.EQ, self.name)
         num_segments_v = num_segments['value']
-        validator.check_type('num_segments', num_segments_v, [int])
-        validator.check_integer("num_segments", num_segments_v, 0, Rel.GT)
+        validator.check_value_type('num_segments', num_segments_v, [int], self.name)
+        validator.check_integer("num_segments", num_segments_v, 0, Rel.GT, self.name)
         shp = [num_segments_v]
         shp += x_shp[segment_ids_shp_len:]
         out = {'shape': shp,
@@ -1306,7 +1285,7 @@ class Concat(PrimitiveWithInfer):
     def __init__(self, axis=0):
         """init Tile"""
         self.__setattr_flag__ = True
-        validator.check_type("axis", axis, [int])
+        validator.check_value_type("axis", axis, [int], self.name)
 
     def __infer__(self, input_x):
         axis = self.axis
@@ -1323,25 +1302,25 @@ class Concat(PrimitiveWithInfer):
         return out
 
 
-def _get_pack_shape(x_shape, x_type, axis):
+def _get_pack_shape(x_shape, x_type, axis, prim_name):
     """for pack output shape"""
-    validator.check_type("shape", x_shape, [tuple, list])
-    validator.check_integer("len of input_x shape", len(x_shape), 0, Rel.GT)
-    validator.check_subclass("shape0", x_type[0], mstype.tensor)
-    validator.check_integer("len of input_x0 shape", len(x_shape[0]), 0, Rel.GT)
+    validator.check_value_type("shape", x_shape, [tuple, list], prim_name)
+    validator.check_integer("len of input_x shape", len(x_shape), 0, Rel.GT, prim_name)
+    validator.check_subclass("shape0", x_type[0], mstype.tensor, prim_name)
+    validator.check_integer("len of input_x0 shape", len(x_shape[0]), 0, Rel.GT, prim_name)
     rank_base = len(x_shape[0])
     N = len(x_shape)
     out_shape = x_shape[0]
-    validator.check_int_range('axis', axis, -rank_base - 1, rank_base, Rel.INC_BOTH)
+    validator.check_int_range('axis', axis, -rank_base - 1, rank_base, Rel.INC_BOTH, prim_name)
     if axis < 0:
         axis = axis + rank_base + 1
     for i in range(1, N):
         v = x_shape[i]
-        validator.check('len of x_shape[%d]' % i, len(v), 'len of rank_base', rank_base)
-        validator.check('x_type[%d]' % i, x_type[i], 'base', x_type[0])
+        validator.check('len of x_shape[%d]' % i, len(v), 'len of rank_base', rank_base, Rel.EQ, prim_name)
+        validator.check('x_type[%d]' % i, x_type[i], 'base', x_type[0], Rel.EQ, prim_name)
         for j in range(rank_base):
             if v[j] != x_shape[0][j]:
-                raise ValueError("Pack evaluator element %d shape in input can not pack with first element" % i)
+                raise ValueError(f"For \'{prim_name}\' element {i} shape in input can not pack with first element")
     out_shape.insert(axis, N)
     return out_shape
 
@@ -1376,14 +1355,14 @@ class Pack(PrimitiveWithInfer):
     def __init__(self, axis=0):
         """init Pack"""
         self.__setattr_flag__ = True
-        validator.check_type("axis", axis, [int])
+        validator.check_value_type("axis", axis, [int], self.name)
         self.axis = axis
 
     def __infer__(self, value):
         x_shape = value['shape']
         x_type = value['dtype']
         self.add_prim_attr('num', len(x_shape))
-        all_shape = _get_pack_shape(x_shape, x_type, self.axis)
+        all_shape = _get_pack_shape(x_shape, x_type, self.axis, self.name)
         out = {'shape': all_shape,
                'dtype': x_type[0],
                'value': None}
@@ -1429,22 +1408,23 @@ class Unpack(PrimitiveWithInfer):
     def __init__(self, axis=0):
         """init Unpack"""
         self.__setattr_flag__ = True
-        validator.check_type("axis", axis, [int])
+        validator.check_value_type("axis", axis, [int], self.name)
         self.axis = axis
 
     def __infer__(self, x):
-        validator.check_subclass("x", x['dtype'], mstype.tensor)
+        validator.check_subclass("x", x['dtype'], mstype.tensor, self.name)
         x_shape = list(x['shape'])
         dim = len(x_shape)
-        validator.check_int_range('axis value', self.axis, -dim, dim, Rel.INC_LEFT)
+        validator.check_int_range('axis value', self.axis, -dim, dim, Rel.INC_LEFT, self.name)
         if self.axis < 0:
             self.axis = self.axis + dim
         output_num = x_shape[self.axis]
-        validator.check_type("num", output_num, [int])
-        validator.check_integer("output_num", output_num, 0, Rel.GT)
+        validator.check_value_type("num", output_num, [int], self.name)
+        validator.check_integer("output_num", output_num, 0, Rel.GT, self.name)
         self.add_prim_attr('num', output_num)
         output_valid_check = x_shape[self.axis] - output_num
-        validator.check_integer("The dimension which to unpack divides output_num", output_valid_check, 0, Rel.EQ)
+        validator.check_integer("The dimension which to unpack divides output_num", output_valid_check, 0, Rel.EQ,
+                                self.name)
         out_shapes = []
         out_dtypes = []
         out_shape = x_shape[:self.axis] + x_shape[self.axis + 1:]
@@ -1486,8 +1466,8 @@ class Slice(PrimitiveWithInfer):
     def __infer__(self, x, begin, size):
         x_shape = x['shape']
         x_shp_len = len(x_shape)
-        validator.check_const_input('begin', begin['value'])
-        validator.check_const_input('size', size['value'])
+        validator.check_const_input('begin', begin['value'], self.name)
+        validator.check_const_input('size', size['value'], self.name)
         begin_v, size_v = begin['value'], size['value']
         if begin_v is None or size_v is None:
             return {'shape': None,
@@ -1499,7 +1479,8 @@ class Slice(PrimitiveWithInfer):
         for i in range(x_shp_len):
             if x_shape[i] < begin_v[i] + size_v[i]:
                 y = begin_v[i] + size_v[i]
-                raise ValueError("Slice shape can not bigger than orign shape %d, %d." % (x_shape[i], y))
+                raise ValueError("For '%s' slice shape can not bigger than orign shape %d, %d." %
+                                 (self.name, x_shape[i], y))
         return {'shape': size_v,
                 'dtype': x['dtype'],
                 'value': None}
@@ -1565,11 +1546,11 @@ class Select(PrimitiveWithInfer):
 
     def infer_dtype(self, cond_type, x_type, y_type):
         self.add_prim_attr('T', x_type)
-        validator.check_subclass("x_type", x_type, mstype.tensor)
-        validator.check_subclass("y_type", y_type, mstype.tensor)
-        validator.check_typename("cond_type", cond_type, [mstype.bool_])
+        validator.check_subclass("x_type", x_type, mstype.tensor, self.name)
+        validator.check_subclass("y_type", y_type, mstype.tensor, self.name)
+        validator.check_tensor_type_same({"cond": cond_type}, [mstype.bool_], self.name)
         if x_type != y_type:
-            raise TypeError('The x_type %s must be the same as y_type %s.' % (x_type, y_type))
+            raise TypeError('\'%s\' the x_type %s must be the same as y_type %s.' % (self.name, x_type, y_type))
         return x_type
 
 
@@ -1637,26 +1618,23 @@ class StridedSlice(PrimitiveWithInfer):
                  shrink_axis_mask=0):
         """init StrideSlice"""
         self.init_prim_io_names(inputs=['x', 'begin', 'end', 'strides'], outputs=['output'])
-        validator.check_type('begin_mask', begin_mask, [int])
-        validator.check_type('end_mask', end_mask, [int])
-        validator.check_type('ellipsis_mask', ellipsis_mask, [int])
-        validator.check_type('new_axis_mask', new_axis_mask, [int])
-        validator.check_type('shrink_axis_mask', shrink_axis_mask, [int])
+        validator.check_value_type('begin_mask', begin_mask, [int], self.name)
+        validator.check_value_type('end_mask', end_mask, [int], self.name)
+        validator.check_value_type('ellipsis_mask', ellipsis_mask, [int], self.name)
+        validator.check_value_type('new_axis_mask', new_axis_mask, [int], self.name)
+        validator.check_value_type('shrink_axis_mask', shrink_axis_mask, [int], self.name)
 
     def __infer__(self, x, begin, end, strides):
         begin_v, end_v, strides_v = begin['value'], end['value'], strides['value']
-        validator.check_const_input("begin", begin_v)
-        validator.check_const_input("end", end_v)
-        validator.check_const_input("strides", strides_v)
-        validator.check_type("begin", begin_v, [tuple])
-        validator.check_type("end", end_v, [tuple])
-        validator.check_type("strides", strides_v, [tuple])
+        validator.check_value_type("begin", begin_v, [tuple], self.name)
+        validator.check_value_type("end", end_v, [tuple], self.name)
+        validator.check_value_type("strides", strides_v, [tuple], self.name)
 
         x_shape = x['shape']
         x_shp_len = len(x_shape)
         if len(begin_v) != x_shp_len or len(end_v) != x_shp_len or len(strides_v) != x_shp_len:
-            raise ValueError(f"The length of begin index{begin_v}, end index{end_v} and strides{strides_v} "
-                             f"must be equal to the dims({x_shp_len}) of input.")
+            raise ValueError(f"For \'{self.name}\' the length of begin index{begin_v}, end index{end_v} and "
+                             f"strides{strides_v} must be equal to the dims({x_shp_len}) of input.")
 
         ret_shape = []
         append_dimensions = []
@@ -1669,8 +1647,8 @@ class StridedSlice(PrimitiveWithInfer):
                 append_dimensions.append(x_shape[x_shp_len - 1 - len(append_dimensions)])
                 continue
             if i < (len(shrink_pos) - 2) and shrink_pos[i] == '1':
-                validator.check_integer(f'begin[{i}]', begin_v[i], -x_shape[i], Rel.GE)
-                validator.check_integer(f'begin[{i}]', begin_v[i], x_shape[i], Rel.LT)
+                validator.check_integer(f'begin[{i}]', begin_v[i], -x_shape[i], Rel.GE, self.name)
+                validator.check_integer(f'begin[{i}]', begin_v[i], x_shape[i], Rel.LT, self.name)
                 continue
 
             begin_idx = begin_v[i]
@@ -1680,9 +1658,9 @@ class StridedSlice(PrimitiveWithInfer):
                 begin_idx = 0
             if self.end_mask:
                 end_idx = x_shape[i]
-            validator.check_integer(f'begin[{i}]', begin_idx, x_shape[i], Rel.LE)
-            validator.check_integer(f'end[{i}]', end_idx, x_shape[i], Rel.LE)
-            validator.check_integer(f'strides[{i}]', strides_idx, 0, Rel.NE)
+            validator.check_integer(f'begin[{i}]', begin_idx, x_shape[i], Rel.LE, self.name)
+            validator.check_integer(f'end[{i}]', end_idx, x_shape[i], Rel.LE, self.name)
+            validator.check_integer(f'strides[{i}]', strides_idx, 0, Rel.NE, self.name)
             if strides_idx > 0:
                 # If sliced forward , end_idx >= begin_idx
                 validator.check(f'begin[{i}]', begin_idx, f'end[{i}]', end_idx, Rel.LE)
@@ -1736,7 +1714,7 @@ class Diag(PrimitiveWithInfer):
         """init Diag"""
 
     def infer_dtype(self, x_type):
-        validator.check_subclass('input_x', x_type, mstype.tensor)
+        validator.check_subclass('input_x', x_type, mstype.tensor, self.name)
         return x_type
 
     def infer_shape(self, x_shape):
@@ -1748,7 +1726,7 @@ class Diag(PrimitiveWithInfer):
     def infer_value(self, x):
         if x is None:
             return None
-        validator.check("input x rank", len(x.shape()), "", 1)
+        validator.check_integer("input x rank", len(x.shape()), 1, Rel.EQ, self.name)
         ret = np.diag(x.asnumpy())
         return Tensor(ret)
 
@@ -1783,13 +1761,13 @@ class DiagPart(PrimitiveWithInfer):
         """init DiagPart"""
 
     def infer_dtype(self, x_type):
-        validator.check_subclass('input_x', x_type, mstype.tensor)
+        validator.check_subclass('input_x', x_type, mstype.tensor, self.name)
         return x_type
 
     def infer_shape(self, x_shape):
         if len(x_shape)%2 != 0 or \
                 not x_shape:
-            raise ValueError(f"DiagPart input rank must be non-zero and even, but got rank {len(x_shape)}, "
+            raise ValueError(f"For \'{self.name}\' input rank must be non-zero and even, but got rank {len(x_shape)}, "
                              f"with shapes {x_shape}")
         length = len(x_shape) // 2
         ret_shape = x_shape[0:length]
@@ -1798,7 +1776,7 @@ class DiagPart(PrimitiveWithInfer):
     def infer_value(self, x):
         if x is None:
             return None
-        validator.check("x rank", len(x.shape()), "", 2)
+        validator.check("x rank", len(x.shape()), "", 2, Rel.EQ, self.name)
         ret = np.diag(x.asnumpy())
         return Tensor(ret)
 
@@ -1826,12 +1804,10 @@ class Eye(PrimitiveWithInfer):
         """init Eye"""
 
     def infer_value(self, n, m, t):
-        validator.check_type("n", n, [int])
-        validator.check_integer("n", n, 0, Rel.GT)
-        validator.check_type("m", m, [int])
-        validator.check_integer("m", m, 0, Rel.GT)
+        validator.check_integer("n", n, 0, Rel.GT, self.name)
+        validator.check_integer("m", m, 0, Rel.GT, self.name)
         args = {"dtype": t}
-        validator.check_type_same(args, mstype.number_type + (mstype.bool_,))
+        validator.check_type_same(args, mstype.number_type + (mstype.bool_,), self.name)
         np_type = mstype.dtype_to_nptype(t)
         ret = np.eye(n, m, dtype=np_type)
         return Tensor(ret)
@@ -1866,16 +1842,15 @@ class ScatterNd(PrimitiveWithInfer):
 
     def __infer__(self, indices, update, shape):
         shp = shape['value']
-        validator.check_subclass("indices_dtype", indices['dtype'], mstype.tensor)
-        validator.check_subclass("update_dtype", update['dtype'], mstype.tensor)
-        validator.check_typename("indices_dtype", indices['dtype'], mstype.int_type)
-        validator.check_type("shape", shp, [tuple])
+        validator.check_subclass("update_dtype", update['dtype'], mstype.tensor, self.name)
+        validator.check_tensor_type_same({"indices": indices['dtype']}, mstype.int_type, self.name)
+        validator.check_value_type("shape", shp, [tuple], self.name)
         for i, x in enumerate(shp):
-            validator.check_integer("shape[%d]" % i, x, 0, Rel.GT)
+            validator.check_integer("shape[%d]" % i, x, 0, Rel.GT, self.name)
 
         indices_shape, update_shape = indices["shape"], update["shape"]
         if indices_shape[0] != update_shape[0]:
-            raise ValueError('The indices_shape[0] and update_shape[0] must be equal.')
+            raise ValueError(f'For \'{self.name}\' The indices_shape[0] and update_shape[0] must be equal.')
 
         return {'shape': shp,
                 'dtype': update['dtype'],
@@ -1913,7 +1888,7 @@ class ResizeNearestNeighbor(PrimitiveWithInfer):
         self.init_prim_io_names(inputs=['image_in'], outputs=['image_out'])
 
     def infer_shape(self, x):
-        validator.check('the dimension of input_x', len(x), '', 2, Rel.GE)
+        validator.check('the dimension of input_x', len(x), '', 2, Rel.GE, self.name)
         return tuple(x)[:-2] + tuple(self.size)
 
     def infer_dtype(self, x):
@@ -1947,13 +1922,12 @@ class GatherNd(PrimitiveWithInfer):
 
     def infer_shape(self, x_shape, indices_shape):
         validator.check('the dimension of x', len(x_shape),
-                        'the dimension of indices', indices_shape[-1], Rel.GE)
+                        'the dimension of indices', indices_shape[-1], Rel.GE, self.name)
         return indices_shape[:-1] + x_shape[indices_shape[-1]:]
 
     def infer_dtype(self, x_dtype, indices_dtype):
-        validator.check_subclass("x_dtype", x_dtype, mstype.tensor)
-        validator.check_subclass("indices_dtype", indices_dtype, mstype.tensor)
-        validator.check_typename("indices_dtype", indices_dtype, mstype.int_type)
+        validator.check_subclass("x_dtype", x_dtype, mstype.tensor, self.name)
+        validator.check_tensor_type_same({"indices": indices_dtype}, mstype.int_type, self.name)
         return x_dtype
 
 
@@ -1995,12 +1969,9 @@ class ScatterNdUpdate(PrimitiveWithInfer):
         return x_shape
 
     def infer_dtype(self, x_dtype, indices_dtype, value_dtype):
-        validator.check_subclass("x_dtype", x_dtype, mstype.tensor)
-        validator.check_subclass("indices_dtype", indices_dtype, mstype.tensor)
-        validator.check_subclass("value_dtype", value_dtype, mstype.tensor)
-        validator.check_typename('indices_dtype', indices_dtype, mstype.int_type)
-        args = {"x_dtype": x_dtype, "value_dtype": value_dtype}
-        validator.check_type_same(args, (mstype.bool_,) + mstype.number_type)
+        validator.check_tensor_type_same({'indices': indices_dtype}, mstype.int_type, self.name)
+        args = {"x": x_dtype, "value": value_dtype}
+        validator.check_tensor_type_same(args, (mstype.bool_,) + mstype.number_type, self.name)
         return x_dtype
 
 
@@ -2038,7 +2009,7 @@ class SpaceToDepth(PrimitiveWithInfer):
     def __init__(self, block_size):
         """Init SpaceToDepth"""
         self.init_prim_io_names(inputs=['x'], outputs=['y'])
-        validator.check_type('block_size', block_size, [int])
+        validator.check_value_type('block_size', block_size, [int], self.name)
         validator.check('block_size', block_size, '', 2, Rel.GE)
         self.block_size = block_size
         self.add_prim_attr("data_format", "NCHW")
@@ -2048,7 +2019,7 @@ class SpaceToDepth(PrimitiveWithInfer):
         out_shape = copy.deepcopy(x_shape)
         for i in range(2):
             if out_shape[i+2] % self.block_size != 0:
-                raise ValueError(f'SpaceToDepth input shape[{i+2}] {out_shape[i+2]} should be '
+                raise ValueError(f'For \'{self.name}\' input shape[{i+2}] {out_shape[i+2]} should be '
                                  f'fully divided by block_size {self.block_size}')
             out_shape[i+2] //= self.block_size
@@ -2056,7 +2027,7 @@ class SpaceToDepth(PrimitiveWithInfer):
         return out_shape

     def infer_dtype(self, x_dtype):
-        validator.check_subclass("x_dtype", x_dtype, mstype.tensor)
+        validator.check_subclass("x_dtype", x_dtype, mstype.tensor, self.name)
         return x_dtype
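Combining the two SpaceToDepth hunks: H and W must each divide evenly by `block_size`, and the moved pixels land in the channel dimension. A sketch under NCHW layout (the channel update, `c * block_size**2`, follows the operator's standard semantics rather than a line visible in this hunk):

def space_to_depth_out_shape(x_shape, block_size, prim_name='SpaceToDepth'):
    n, c, h, w = x_shape
    for i, dim in ((2, h), (3, w)):
        if dim % block_size != 0:
            raise ValueError(f"For '{prim_name}' input shape[{i}] {dim} should be "
                             f"fully divided by block_size {block_size}")
    return [n, c * block_size * block_size, h // block_size, w // block_size]

assert space_to_depth_out_shape([1, 3, 4, 6], 2) == [1, 12, 2, 3]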
@@ -2096,8 +2067,8 @@ class DepthToSpace(PrimitiveWithInfer):
     def __init__(self, block_size):
         """Init DepthToSpace"""
         self.init_prim_io_names(inputs=['x'], outputs=['y'])
-        validator.check_type('block_size', block_size, [int])
-        validator.check('block_size', block_size, '', 2, Rel.GE)
+        validator.check_value_type('block_size', block_size, [int], self.name)
+        validator.check('block_size', block_size, '', 2, Rel.GE, self.name)
         self.block_size = block_size
         self.add_prim_attr("data_format", "NCHW")
@@ -2107,12 +2078,13 @@ class DepthToSpace(PrimitiveWithInfer):
         for i in range(2):
             out_shape[i+2] *= self.block_size

-        validator.check('x_shape[1] % (block_size*block_size)', x_shape[1] % (self.block_size*self.block_size), '', 0)
+        validator.check_integer('x_shape[1] % (block_size*block_size)', x_shape[1] % (self.block_size*self.block_size),
+                                0, Rel.EQ, self.name)
         out_shape[1] //= self.block_size * self.block_size
         return out_shape

     def infer_dtype(self, x_dtype):
-        validator.check_subclass("x_dtype", x_dtype, mstype.tensor)
+        validator.check_subclass("x_dtype", x_dtype, mstype.tensor, self.name)
         return x_dtype
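DepthToSpace is the inverse: the hunk above multiplies the two spatial dims by `block_size` and requires, then divides, the channel dim by `block_size * block_size`. As standalone arithmetic:

def depth_to_space_out_shape(x_shape, block_size):
    n, c, h, w = x_shape
    # Mirrors the check_integer call above: channel count must be a
    # multiple of block_size squared.
    assert c % (block_size * block_size) == 0
    return [n, c // (block_size * block_size), h * block_size, w * block_size]

assert depth_to_space_out_shape([1, 12, 2, 3], 2) == [1, 3, 4, 6]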
@@ -2159,27 +2131,26 @@ class SpaceToBatch(PrimitiveWithInfer):
     @prim_attr_register
     def __init__(self, block_size, paddings):
         """Init SpaceToBatch"""
-        validator.check_type('block_size', block_size, [int])
-        validator.check('block_size', block_size, '', 1, Rel.GT)
+        validator.check_value_type('block_size', block_size, [int], self.name)
+        validator.check('block_size', block_size, '', 1, Rel.GT, self.name)
         self.block_size = block_size
-        validator.check('paddings shape', np.array(paddings).shape, '', (2, 2))
+        validator.check('paddings shape', np.array(paddings).shape, '', (2, 2), Rel.EQ, self.name)
         for elem in itertools.chain(*paddings):
-            validator.check_type('paddings element', elem, [int])
+            validator.check_value_type('paddings element', elem, [int], self.name)
         self.paddings = paddings

     def infer_dtype(self, x_dtype):
-        validator.check_subclass("input_x", x_dtype, mstype.tensor)
-        validator.check_typename('input_x', x_dtype, mstype.number_type)
+        validator.check_tensor_type_same({'input_x': x_dtype}, mstype.number_type, self.name)
         return x_dtype

     def infer_shape(self, x_shape):
-        validator.check('rank of input_x', len(x_shape), '', 4)
+        validator.check_integer('rank of input_x', len(x_shape), 4, Rel.EQ, self.name)
         out_shape = copy.deepcopy(x_shape)
         for i in range(2):
             padded = out_shape[i+2] + self.paddings[i][0] + \
                      self.paddings[i][1]
             if padded % self.block_size != 0:
-                raise ValueError(f'padded[{i}] {padded} should be divisible by '
+                raise ValueError(f'For \'{self.name}\' padded[{i}] {padded} should be divisible by '
                                  f'block_size {self.block_size}')
             out_shape[i+2] = padded // self.block_size
         out_shape[0] *= self.block_size * self.block_size
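The SpaceToBatch shape logic above, restated as a standalone function for reference (same rule, minus the validator plumbing):

def space_to_batch_out_shape(x_shape, block_size, paddings, prim_name='SpaceToBatch'):
    assert len(x_shape) == 4  # the rank check from infer_shape
    out_shape = list(x_shape)
    for i in range(2):
        padded = out_shape[i + 2] + paddings[i][0] + paddings[i][1]
        if padded % block_size != 0:
            raise ValueError(f"For '{prim_name}' padded[{i}] {padded} should be divisible by "
                             f"block_size {block_size}")
        out_shape[i + 2] = padded // block_size
    out_shape[0] *= block_size * block_size
    return out_shape

# A 1x1x2x2 input with no padding and block_size 2 folds entirely into batch.
assert space_to_batch_out_shape([1, 1, 2, 2], 2, [[0, 0], [0, 0]]) == [4, 1, 1, 1]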
@@ -2227,17 +2198,16 @@ class BatchToSpace(PrimitiveWithInfer):
     @prim_attr_register
     def __init__(self, block_size, crops):
         """Init BatchToSpace"""
-        validator.check_type('block_size', block_size, [int])
-        validator.check('block_size', block_size, '', 1, Rel.GT)
+        validator.check_value_type('block_size', block_size, [int], self.name)
+        validator.check('block_size', block_size, '', 1, Rel.GT, self.name)
         self.block_size = block_size
         validator.check('crops shape', np.array(crops).shape, '', (2, 2))
         for elem in itertools.chain(*crops):
-            validator.check_type('crops element', elem, [int])
+            validator.check_value_type('crops element', elem, [int], self.name)
         self.crops = crops

     def infer_dtype(self, x_dtype):
-        validator.check_subclass("input_x", x_dtype, mstype.tensor)
-        validator.check_typename('input_x', x_dtype, mstype.number_type)
+        validator.check_tensor_type_same({'input_x': x_dtype}, mstype.number_type, self.name)
         return x_dtype

     def infer_shape(self, x_shape):
@@ -2246,11 +2216,11 @@ class BatchToSpace(PrimitiveWithInfer):
         for i in range(2):
             x_block_prod = out_shape[i+2] * self.block_size
             crops_sum = self.crops[i][0] + self.crops[i][1]
-            validator.check("x block shape prod", x_block_prod, 'crops sum', crops_sum, Rel.GT)
+            validator.check("x block shape prod", x_block_prod, 'crops sum', crops_sum, Rel.GT, self.name)
             out_shape[i+2] = x_block_prod - crops_sum
         block_size_prod = self.block_size * self.block_size
         if out_shape[0] % block_size_prod != 0:
-            raise ValueError(f'input_x dimension 0 {out_shape[0]} should be divisible by '
+            raise ValueError(f'For \'{self.name}\' input_x dimension 0 {out_shape[0]} should be divisible by '
                              f'block_size_prod {block_size_prod}')
         out_shape[0] = out_shape[0] // block_size_prod
         return out_shape
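And its inverse, from the two BatchToSpace hunks:

def batch_to_space_out_shape(x_shape, block_size, crops, prim_name='BatchToSpace'):
    out_shape = list(x_shape)
    for i in range(2):
        x_block_prod = out_shape[i + 2] * block_size
        crops_sum = crops[i][0] + crops[i][1]
        assert x_block_prod > crops_sum  # the Rel.GT check above
        out_shape[i + 2] = x_block_prod - crops_sum
    block_size_prod = block_size * block_size
    if out_shape[0] % block_size_prod != 0:
        raise ValueError(f"For '{prim_name}' input_x dimension 0 {out_shape[0]} should be "
                         f"divisible by block_size_prod {block_size_prod}")
    out_shape[0] //= block_size_prod
    return out_shape

# Round-trips the SpaceToBatch example above.
assert batch_to_space_out_shape([4, 1, 1, 1], 2, [[0, 0], [0, 0]]) == [1, 1, 2, 2]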
@@ -0,0 +1,159 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
""" test ops """
import functools
import numpy as np
from mindspore import ops
from mindspore.ops import functional as F
from mindspore.ops import operations as P
from mindspore.ops.operations import _grad_ops as G
import mindspore.ops.composite as C
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.common import dtype as mstype
from mindspore.common.parameter import Parameter
from ..ut_filter import non_graph_engine
from mindspore.common.api import _executor

from ....mindspore_test_framework.mindspore_test import mindspore_test
from ....mindspore_test_framework.pipeline.forward.compile_forward\
    import (pipeline_for_compile_forward_ge_graph_for_case_by_case_config,
            pipeline_for_compile_forward_ge_graph_for_case_by_case_config_exception)
from ....mindspore_test_framework.pipeline.gradient.compile_gradient\
    import pipeline_for_compile_grad_ge_graph_for_case_by_case_config


class ExpandDimsNet(nn.Cell):
    def __init__(self, axis):
        super(ExpandDimsNet, self).__init__()
        self.axis = axis
        self.op = P.ExpandDims()

    def construct(self, x):
        return self.op(x, self.axis)


class IsInstanceNet(nn.Cell):
    def __init__(self, inst):
        super(IsInstanceNet, self).__init__()
        self.inst = inst
        self.op = P.IsInstance()

    def construct(self, t):
        return self.op(self.inst, t)


class ReshapeNet(nn.Cell):
    def __init__(self, shape):
        super(ReshapeNet, self).__init__()
        self.shape = shape
        self.op = P.Reshape()

    def construct(self, x):
        return self.op(x, self.shape)


raise_set = [
    # input is a scalar, not a Tensor
    ('ExpandDims0', {
        'block': (P.ExpandDims(), {'exception': TypeError, 'error_keywords': ['ExpandDims']}),
        'desc_inputs': [5.0, 1],
        'skip': ['backward']}),
    # axis is passed as an input, not an attribute
    ('ExpandDims1', {
        'block': (P.ExpandDims(), {'exception': TypeError, 'error_keywords': ['ExpandDims']}),
        'desc_inputs': [Tensor(np.ones([3, 4]).astype(np.float32)), 1],
        'skip': ['backward']}),
    # axis as an attribute, but below the lower limit
    ('ExpandDims2', {
        'block': (ExpandDimsNet(-4), {'exception': ValueError, 'error_keywords': ['ExpandDims']}),
        'desc_inputs': [Tensor(np.ones([3, 4]).astype(np.float32))],
        'skip': ['backward']}),
    # axis as an attribute, but above the upper limit
    ('ExpandDims3', {
        'block': (ExpandDimsNet(3), {'exception': ValueError, 'error_keywords': ['ExpandDims']}),
        'desc_inputs': [Tensor(np.ones([3, 4]).astype(np.float32))],
        'skip': ['backward']}),

    # input is a scalar, not a Tensor
    ('DType0', {
        'block': (P.DType(), {'exception': TypeError, 'error_keywords': ['DType']}),
        'desc_inputs': [5.0],
        'skip': ['backward']}),

    # input x is a scalar, not a Tensor
    ('SameTypeShape0', {
        'block': (P.SameTypeShape(), {'exception': TypeError, 'error_keywords': ['SameTypeShape']}),
        'desc_inputs': [5.0, Tensor(np.ones([3, 4]).astype(np.float32))],
        'skip': ['backward']}),
    # input y is a scalar, not a Tensor
    ('SameTypeShape1', {
        'block': (P.SameTypeShape(), {'exception': TypeError, 'error_keywords': ['SameTypeShape']}),
        'desc_inputs': [Tensor(np.ones([3, 4]).astype(np.float32)), 5.0],
        'skip': ['backward']}),
    # types of x and y do not match
    ('SameTypeShape2', {
        'block': (P.SameTypeShape(), {'exception': TypeError, 'error_keywords': ['SameTypeShape']}),
        'desc_inputs': [Tensor(np.ones([3, 4]).astype(np.float32)), Tensor(np.ones([3, 4]).astype(np.int32))],
        'skip': ['backward']}),
    # shapes of x and y do not match
    ('SameTypeShape3', {
        'block': (P.SameTypeShape(), {'exception': ValueError, 'error_keywords': ['SameTypeShape']}),
        'desc_inputs': [Tensor(np.ones([3, 4]).astype(np.float32)), Tensor(np.ones([3, 3]).astype(np.float32))],
        'skip': ['backward']}),

    # sub_type is None
    ('IsSubClass0', {
        'block': (P.IsSubClass(), {'exception': TypeError, 'error_keywords': ['IsSubClass']}),
        'desc_inputs': [None, mstype.number],
        'skip': ['backward']}),
    # type_ is None
    ('IsSubClass1', {
        'block': (P.IsSubClass(), {'exception': TypeError, 'error_keywords': ['IsSubClass']}),
        'desc_inputs': [mstype.number, None],
        'skip': ['backward']}),

    # inst is a variable
    ('IsInstance0', {
        'block': (P.IsInstance(), {'exception': ValueError, 'error_keywords': ['IsInstance']}),
        'desc_inputs': [5.0, mstype.number],
        'skip': ['backward']}),
    # t is not an mstype.Type
    ('IsInstance1', {
        'block': (IsInstanceNet(5.0), {'exception': TypeError, 'error_keywords': ['IsInstance']}),
        'desc_inputs': [None],
        'skip': ['backward']}),

    # input x is a scalar, not a Tensor
    ('Reshape0', {
        'block': (P.Reshape(), {'exception': TypeError, 'error_keywords': ['Reshape']}),
        'desc_inputs': [5.0, (1, 2)],
        'skip': ['backward']}),
    # input shape is a variable
    ('Reshape1', {
        'block': (P.Reshape(), {'exception': TypeError, 'error_keywords': ['Reshape']}),
        'desc_inputs': [Tensor(np.ones([3, 4]).astype(np.float32)), (2, 3, 2)],
        'skip': ['backward']}),
    # an element of shape is not an int
    ('Reshape3', {
        'block': (ReshapeNet((2, 3.0, 2)), {'exception': TypeError, 'error_keywords': ['Reshape']}),
        'desc_inputs': [Tensor(np.ones([3, 4]).astype(np.float32))],
        'skip': ['backward']}),
]


@mindspore_test(pipeline_for_compile_forward_ge_graph_for_case_by_case_config_exception)
def test_check_exception():
    return raise_set
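Each `raise_set` entry above encodes one negative test: the pipeline builds `block`, feeds it `desc_inputs`, and expects the named `exception` with every string in `error_keywords` present in its message. A hand-rolled equivalent of the 'ExpandDims0' case might look like this (the direct-call style is an assumption; the real pipeline compiles the op through the GE graph path):

import pytest
import numpy as np
from mindspore.ops import operations as P

def test_expand_dims_scalar_input_raises():
    # Feeding a Python scalar where a Tensor is required should raise a
    # TypeError whose message names the operator.
    with pytest.raises(TypeError) as ex:
        P.ExpandDims()(5.0, 1)
    assert 'ExpandDims' in str(ex.value)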
@@ -383,7 +383,7 @@ def test_tensor_slice_reduce_out_of_bounds_neg():
     net = NetWork()
     with pytest.raises(ValueError) as ex:
         net(input_tensor)
-    assert "The `begin[0]` should be an int and must greater or equal to -6, but got -7" in str(ex.value)
+    assert "For 'StridedSlice' the `begin[0]` should be an int and must greater or equal to -6, but got `-7`" in str(ex.value)


 def test_tensor_slice_reduce_out_of_bounds_positive():

@@ -400,4 +400,4 @@ def test_tensor_slice_reduce_out_of_bounds_positive():
     net = NetWork()
     with pytest.raises(ValueError) as ex:
         net(input_tensor)
-    assert "The `begin[0]` should be an int and must less than 6, but got 6" in str(ex.value)
+    assert "For 'StridedSlice' the `begin[0]` should be an int and must less than 6, but got `6`" in str(ex.value)
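Both updated assertions follow the same shape: the primitive name leads, and the offending value is now back-quoted. A regex reading of the two messages in this diff (a description of these two strings, not an official format spec):

import re

MSG_PATTERN = re.compile(
    r"^For '(?P<prim>\w+)' the `(?P<arg>[^`]+)` should be an int and must .+, but got `.+`$")

assert MSG_PATTERN.match(
    "For 'StridedSlice' the `begin[0]` should be an int and must less than 6, but got `6`")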
@@ -16,7 +16,7 @@

 import numpy as np
 from mindspore._checkparam import Rel
-from mindspore._checkparam import ParamValidator as validator
+from mindspore._checkparam import Validator as validator


 def avg_pooling(x, pool_h, pool_w, stride):

@@ -32,7 +32,7 @@ def avg_pooling(x, pool_h, pool_w, stride):
     Returns:
         numpy.ndarray, an output array after applying average pooling on input array.
     """
-    validator.check_integer("stride", stride, 0, Rel.GT)
+    validator.check_integer("stride", stride, 0, Rel.GT, None)
     num, channel, height, width = x.shape
     out_h = (height - pool_h)//stride + 1
     out_w = (width - pool_w)//stride + 1
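The `(height - pool_h)//stride + 1` expression above is the usual output-size formula for an unpadded pooling window; a quick standalone check:

def pool_out_hw(height, width, pool_h, pool_w, stride):
    # Number of window positions along each spatial axis, stride apart,
    # with no padding.
    return (height - pool_h) // stride + 1, (width - pool_w) // stride + 1

assert pool_out_hw(4, 4, 2, 2, 2) == (2, 2)
assert pool_out_hw(5, 5, 3, 3, 1) == (3, 3)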
@@ -217,7 +217,7 @@ def conv2d(x, weight, bias=None, stride=1, pad=0,
            dilation=1, groups=1, padding_mode='zeros'):
     """Convolution 2D."""
     # pylint: disable=unused-argument
-    validator.check_type('stride', stride, (int, tuple))
+    validator.check_value_type('stride', stride, (int, tuple), None)
     if isinstance(stride, int):
         stride = (stride, stride)
     elif len(stride) == 4:

@@ -229,7 +229,7 @@ def conv2d(x, weight, bias=None, stride=1, pad=0,
                          f"a tuple of two positive int numbers, but got {stride}")
     stride_h = stride[0]
     stride_w = stride[1]
-    validator.check_type('dilation', dilation, (int, tuple))
+    validator.check_value_type('dilation', dilation, (int, tuple), None)
     if isinstance(dilation, int):
         dilation = (dilation, dilation)
     elif len(dilation) == 4:

@@ -384,7 +384,7 @@ def matmul(x, w, b=None):

 def max_pooling(x, pool_h, pool_w, stride):
     """Max pooling."""
-    validator.check_integer("stride", stride, 0, Rel.GT)
+    validator.check_integer("stride", stride, 0, Rel.GT, None)
     num, channel, height, width = x.shape
     out_h = (height - pool_h)//stride + 1
     out_w = (width - pool_w)//stride + 1

@@ -427,7 +427,7 @@ def max_pool_grad_with_argmax(x, dout, arg_max, pool_h, pool_w, stride):

 def max_pool_with_argmax(x, pool_h, pool_w, stride):
     """Max pooling with argmax."""
-    validator.check_integer("stride", stride, 0, Rel.GT)
+    validator.check_integer("stride", stride, 0, Rel.GT, None)
     num, channel, height, width = x.shape
     out_h = (height - pool_h)//stride + 1
     out_w = (width - pool_w)//stride + 1