Rectify and optimize the type checking function

buxue 2020-11-02 16:00:30 +08:00
parent 9ae5f96988
commit 346bcfa3fd
30 changed files with 648 additions and 712 deletions

View File

@@ -415,37 +415,20 @@ class Validator:
                    break
        if not hit:
            type_str = (type(type_).__name__ if isinstance(type_, (tuple, list)) else "") + str(type_)
-            raise TypeError(f'For \'{prim_name}\' the type of `{arg_name}` should be subclass'
-                            f' of {",".join((str(x) for x in template_types))}, but got {type_str}.')
+            raise TypeError(f'For \'{prim_name}\', the type of `{arg_name}` should be subclass'
+                            f' of {", ".join((str(x) for x in template_types))}, but got {type_str}.')

    @staticmethod
    def check_const_input(arg_name, arg_value, prim_name):
        """Checks valid value."""
        if arg_value is None:
-            raise ValueError(f'For \'{prim_name}\' the `{arg_name}` must be a const input, but got {arg_value}.')
+            raise ValueError(f'For \'{prim_name}\', the `{arg_name}` must be a const input, but got {arg_value}.')
        return arg_value

    @staticmethod
-    def check_type(arg_name, arg_value, valid_types):
-        """Type checking."""
-        def raise_error_msg():
-            """func for raising error message when check failed"""
-            raise TypeError(f'The type of `{arg_name}` should be in {valid_types}, but got {type(arg_value).__name__}.')
-
-        if isinstance(arg_value, type(mstype.tensor)):
-            arg_value = arg_value.element_type()
-        if isinstance(arg_value, bool) and bool not in tuple(valid_types):
-            raise_error_msg()
-        if arg_value in valid_types:
-            return arg_value
-        if isinstance(arg_value, tuple(valid_types)):
-            return arg_value
-        raise_error_msg()
-
-    @staticmethod
-    def check_type_same(args, valid_values, prim_name):
-        """Checks whether the types of inputs are the same."""
-        def _check_tensor_type(arg):
+    def check_types_same_and_valid(args, valid_values, prim_name):
+        """Checks whether the types of inputs are the same and valid."""
+        def _check_type_valid(arg):
            arg_key, arg_val = arg
            elem_type = arg_val
            Validator.check_subclass(arg_key, elem_type, valid_values, prim_name)
@@ -455,21 +438,27 @@ class Validator:
            arg1_name, arg1_type = arg1
            arg2_name, arg2_type = arg2
            if arg1_type != arg2_type:
-                raise TypeError(f'For \'{prim_name}\' type of `{arg2_name}` should be same as `{arg1_name}`,'
+                raise TypeError(f'For \'{prim_name}\', type of `{arg2_name}` should be same as `{arg1_name}`,'
                                f' but `{arg1_name}` with type {arg1_type} and `{arg2_name}` with type {arg2_type}.')
            return arg1

-        elem_types = map(_check_tensor_type, args.items())
+        elem_types = map(_check_type_valid, args.items())
        reduce(_check_types_same, elem_types)

    @staticmethod
-    def check_tensor_type_same(args, valid_values, prim_name):
-        """Checks whether the element types of input tensors are the same."""
-        tensor_types = [mstype.tensor_type(t) for t in valid_values]
-        Validator.check_type_same(args, tensor_types, prim_name)
+    def check_tensors_dtypes_same_and_valid(args, valid_dtypes, prim_name):
+        """Checks whether the element types of input tensors are the same and valid."""
+        tensor_types = [mstype.tensor_type(t) for t in valid_dtypes]
+        Validator.check_types_same_and_valid(args, tensor_types, prim_name)

    @staticmethod
-    def check_scalar_or_tensor_type_same(args, valid_values, prim_name, allow_mix=False):
+    def check_tensor_dtype_valid(arg_name, arg_type, valid_dtypes, prim_name):
+        """Checks whether the element types of input tensors are valid."""
+        tensor_types = [mstype.tensor_type(t) for t in valid_dtypes]
+        Validator.check_subclass(arg_name, arg_type, tensor_types, prim_name)
+
+    @staticmethod
+    def check_scalar_or_tensor_types_same(args, valid_values, prim_name, allow_mix=False):
        """
        Checks whether the types of inputs are the same. If the input args are tensors, checks their element types.
        If `allow_mix` is True, Tensor(float32) and float32 are type compatible, otherwise an exception will be raised.

@@ -480,7 +469,7 @@ class Validator:
            if isinstance(arg_val, type(mstype.tensor)):
                arg_val = arg_val.element_type()
            if not arg_val in valid_values:
-                raise TypeError(f'For \'{prim_name}\' the `{arg_key}` should be in {valid_values},'
+                raise TypeError(f'For \'{prim_name}\', the `{arg_key}` should be in {valid_values},'
                                f' but `{arg_key}` is {arg_val}.')
            return arg
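
Note (illustrative, not part of the diff): after this rename, an operator's dtype inference is expected to call the helpers roughly as sketched below. The operator name "Example" and the free-standing infer_dtype function are placeholders, and mindspore must be importable for the snippet to run.

from mindspore._checkparam import Validator as validator
from mindspore.common import dtype as mstype

def infer_dtype(x_dtype, y_dtype):
    # Both tensors must share one element type drawn from the whitelist.
    args = {"x": x_dtype, "y": y_dtype}
    validator.check_tensors_dtypes_same_and_valid(args, (mstype.float16, mstype.float32), "Example")
    # A single tensor argument is checked against the whitelist only.
    validator.check_tensor_dtype_valid("x", x_dtype, (mstype.float16, mstype.float32), "Example")
    return x_dtype

infer_dtype(mstype.tensor_type(mstype.float32), mstype.tensor_type(mstype.float32))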
@@ -512,40 +501,40 @@ class Validator:
        def raise_error_msg():
            """func for raising error message when check failed"""
-            type_names = [t.__name__ for t in valid_types]
+            type_names = [t.__name__ if hasattr(t, '__name__') else str(t) for t in valid_types]
            num_types = len(valid_types)
-            msg_prefix = f'For \'{prim_name}\' the' if prim_name else 'The'
+            msg_prefix = f"For '{prim_name}', the" if prim_name else "The"
            raise TypeError(f'{msg_prefix} type of `{arg_name}` should be {"one of " if num_types > 1 else ""}'
-                            f'{type_names if num_types > 1 else type_names[0]}, but got {type(arg_value).__name__}.')
+                            f'{type_names if num_types > 1 else type_names[0]}, '
+                            f'but got {arg_value} with type {type(arg_value).__name__}.')

        # Notice: bool is subclass of int, so `check_value_type('x', True, [int])` will check fail, and
        #         `check_value_type('x', True, [bool, int])` will check pass
        if isinstance(arg_value, bool) and bool not in tuple(valid_types):
            raise_error_msg()
-        if isinstance(arg_value, tuple(valid_types)):
-            return arg_value
-        raise_error_msg()
+        if not isinstance(arg_value, tuple(valid_types)):
+            raise_error_msg()
+        return arg_value

    @staticmethod
    def check_type_name(arg_name, arg_type, valid_types, prim_name):
        """Checks whether a type in some specified types"""
        valid_types = valid_types if isinstance(valid_types, Iterable) else (valid_types,)
-        def get_typename(t):
-            return t.__name__ if hasattr(t, '__name__') else str(t)
+
+        def raise_error_msg():
+            """func for raising error message when check failed"""
+            type_names = [t.__name__ if hasattr(t, '__name__') else t for t in valid_types]
+            num_types = len(valid_types)
+            msg_prefix = f"For '{prim_name}', the" if prim_name else "The"
+            raise TypeError(f"{msg_prefix} '{arg_name}' should be {'one of ' if num_types > 1 else ''}"
+                            f"{type_names if num_types > 1 else type_names[0]}, "
+                            f"but got {arg_type.__name__ if hasattr(arg_type, '__name__') else repr(arg_type)}.")

        if isinstance(arg_type, type(mstype.tensor)):
            arg_type = arg_type.element_type()
-        if arg_type in valid_types:
-            return arg_type
-        type_names = [get_typename(t) for t in valid_types]
-        msg_prefix = f'For \'{prim_name}\' the' if prim_name else 'The'
-        if len(valid_types) == 1:
-            raise TypeError(f'{msg_prefix} type of `{arg_name}` should be {type_names[0]},'
-                            f' but got {get_typename(arg_type)}.')
-        raise TypeError(f'{msg_prefix} type of `{arg_name}` should be one of {type_names},'
-                        f' but got {get_typename(arg_type)}.')
+        if arg_type not in valid_types:
+            raise_error_msg()
+        return arg_type

    @staticmethod
    def check_reduce_shape(ori_shape, shape, axis, prim_name):
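
Illustration of the bool/int corner case that the in-code comment above calls out (behaviour as implemented in the hunk; 'keep_dims' and 'ReduceSum' are example names only):

from mindspore._checkparam import Validator as validator

validator.check_value_type('keep_dims', True, [bool, int], 'ReduceSum')  # bool listed explicitly: passes, returns True
try:
    validator.check_value_type('keep_dims', True, [int], 'ReduceSum')    # bool is a subclass of int but is rejected on purpose
except TypeError as err:
    print(err)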
@@ -611,65 +600,6 @@ def check_output_data(data):
once = _expand_tuple(1)
twice = _expand_tuple(2)
triple = _expand_tuple(3)
-
-valid_data_types = (int, float, np.int8, np.int16, np.int32, np.int64,
-                    np.uint8, np.uint16, np.uint32, np.uint64, np.float16,
-                    np.float32, np.float64, bool, np.bool_)
-
-
-def check_type(arg_name, arg_value, valid_types):
-    """Check value type."""
-    # if input type is Tensor ,get element type
-    if isinstance(arg_value, type(mstype.tensor)):
-        arg_value = arg_value.element_type()
-
-    # First, check if arg_value has argvalid_types
-    if isinstance(arg_value, tuple(valid_types)):
-        return type(arg_value).__name__
-
-    # Second, wrap arg_value with numpy array so that it can be checked through numpy api
-    if isinstance(arg_value, (list, tuple)):
-        arg_value = np.array(arg_value)
-
-    # Thirdly, check the data type by numpy's dtype api
-    valid = False
-    if isinstance(arg_value, np.ndarray):
-        valid = arg_value.dtype in valid_data_types
-
-    # Notice: bool is subclass of int, so `check_type('x', True, [int])` will check fail, and
-    #         `check_type('x', True, [bool, int])` will check pass
-    if isinstance(arg_value, bool) and bool not in tuple(valid_types):
-        valid = False
-
-    if not valid:
-        type_names = [t.__name__ for t in valid_types]
-        if len(valid_types) == 1:
-            raise TypeError(f'The type of `{arg_name}` should be {type_names[0]},'
-                            f' but got {type(arg_value).__name__}.')
-        raise TypeError(f'The type of `{arg_name}` should be one of {type_names},'
-                        f' but got {type(arg_value).__name__}.')
-
-    return type(arg_value).__name__
-
-
-def check_typename(arg_name, arg_type, valid_types):
-    """Check type name."""
-    def get_typename(t):
-        return t.__name__ if hasattr(t, '__name__') else str(t)
-
-    if isinstance(arg_type, type(mstype.tensor)):
-        arg_type = arg_type.element_type()
-
-    if arg_type in valid_types:
-        return arg_type
-
-    if isinstance(arg_type, tuple(valid_types)):
-        return arg_type
-
-    type_names = [get_typename(t) for t in valid_types]
-    if len(valid_types) == 1:
-        raise TypeError(f'The type of `{arg_name}` should be {type_names[0]},'
-                        f' but got {get_typename(arg_type)}.')
-    raise TypeError(f'The type of `{arg_name}` should be one of {type_names},'
-                    f' but got {get_typename(arg_type)}.')
-
-
def args_type_check(*type_args, **type_kwargs):
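
A possible migration sketch for code that still calls the removed module-level helpers; the replacement calls mirror the ones made elsewhere in this commit, and the argument values are illustrative.

from mindspore._checkparam import Validator as validator
from mindspore.common import dtype as mstype

# before: check_type('axis', axis, [int, tuple])
axis = validator.check_value_type('axis', 0, [int, tuple], 'Softmax')

# before: check_typename('dtype', dtype, mstype.number_type + (mstype.bool_,))
dtype = validator.check_type_name('dtype', mstype.float32, mstype.number_type + (mstype.bool_,), 'Tensor')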

View File

@@ -19,7 +19,7 @@ from mindspore import log as logger
from mindspore.communication.management import get_rank, get_group_size
from .._c_expression import Tensor as Tensor_
from .._c_expression import MetaTensor as MetaTensor_
-from .._checkparam import check_type, check_typename
+from .._checkparam import Validator as validator
from . import dtype as mstype
from ._register_for_tensor import tensor_operator_registry

@@ -64,9 +64,19 @@ class Tensor(Tensor_):
            input_data = np.array(input_data)
        # If input_data is tuple/list/numpy.ndarray, it's support in check_type method.
-        check_type('tensor input_data', input_data, (Tensor_, float, int))
+        validator.check_value_type('input_data', input_data, (Tensor_, np.ndarray, list, tuple, float, int, bool),
+                                   'Tensor')
+        valid_dtypes = (np.int8, np.int16, np.int32, np.int64, np.uint8, np.uint16, np.uint32, np.uint64,
+                        np.float16, np.float32, np.float64, np.bool_)
+        if isinstance(input_data, np.ndarray) and input_data.dtype not in valid_dtypes:
+            raise TypeError(f"For Tensor, the input_data is a numpy array whose data type is "
+                            f"{input_data.dtype} that is not supported to initialize a Tensor.")
+        if isinstance(input_data, (tuple, list)):
+            if np.array(input_data).dtype not in valid_dtypes:
+                raise TypeError(f"For Tensor, the input_data is {input_data} that contain unsupported element.")
        if dtype is not None:
-            check_typename('dtype', dtype, mstype.number_type + (mstype.bool_,))
+            validator.check_type_name('dtype', dtype, mstype.number_type + (mstype.bool_,), "Tensor")
        if isinstance(input_data, np.ndarray) and (not input_data.flags['FORC']):
            input_data = np.ascontiguousarray(input_data)
        if dtype is None:

@@ -405,8 +415,9 @@ class MetaTensor(MetaTensor_):
    Returns:
        Array, an array after being initialized.
    """

    def __init__(self, dtype, shape, init=None):
-        #check param
+        # check param
        self.init = init
        MetaTensor_.__init__(self, dtype, shape)

@@ -434,8 +445,10 @@ class MetaTensor(MetaTensor_):
            msg = "Error shape={}".format(shape)
            logger.error(msg)
            raise ValueError(msg)

        class seed_context:
            '''set and restore seed'''

            def __init__(self, init):
                self.init = init
                from .seed import get_seed

@@ -482,4 +495,5 @@ def _vm_compare(*args):
        y = args[0]
    return Tensor(np.array(fn(y)))

tensor_operator_registry.register('vm_compare', _vm_compare)
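
Sketch of how the stricter constructor checks above behave (assuming mindspore and numpy are installed; the inputs are made up for illustration):

import numpy as np
from mindspore import Tensor
from mindspore.common import dtype as mstype

t = Tensor([1, 2, 3], dtype=mstype.float32)   # list input with a supported element type: accepted

try:
    Tensor(np.array(['a', 'b']))              # numpy array with an unsupported dtype: rejected with TypeError
except TypeError as err:
    print(err)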

View File

@@ -21,7 +21,7 @@ from ...ops import operations as P
from ...ops.primitive import PrimitiveWithInfer, prim_attr_register
from ...ops.composite import multitype_ops as C
from ...ops.operations import _grad_ops as G
-from ..._checkparam import Validator
+from ..._checkparam import Validator as validator
from ..cell import Cell, GraphKernel

@@ -194,7 +194,7 @@ class ApplyMomentum(GraphKernel):
                 use_locking=False,
                 gradient_scale=1.0):
        super(ApplyMomentum, self).__init__()
-        self.gradient_scale = Validator.check_type('gradient_scale', gradient_scale, [float])
+        self.gradient_scale = validator.check_value_type('gradient_scale', gradient_scale, [float], type(self).__name__)
        self.fake_output_assign_1 = InplaceAssign()
        self.fake_output_assign_1.add_prim_attr("fake_output", True)
        self.fake_output_assign_2 = InplaceAssign()

@@ -334,7 +334,7 @@ class ReduceMean(GraphKernel):
    def __init__(self, keep_dims=True):
        super(ReduceMean, self).__init__()
-        self.keep_dims = Validator.check_type('keep_dims', keep_dims, [bool])
+        self.keep_dims = validator.check_value_type('keep_dims', keep_dims, [bool], type(self).__name__)
        self.sum = P.ReduceSum(self.keep_dims)

    def construct(self, x, axis):

@@ -431,8 +431,10 @@ class LayerNormForward(GraphKernel):
    """ Forward function of the LayerNorm operator. """
    def __init__(self, begin_norm_axis=1, begin_params_axis=1):
        super(LayerNormForward, self).__init__()
-        self.begin_norm_axis = Validator.check_type('begin_norm_axis', begin_norm_axis, [int])
-        self.begin_params_axis = Validator.check_type('begin_params_axis', begin_params_axis, [int])
+        self.begin_norm_axis = validator.check_value_type('begin_norm_axis', begin_norm_axis, [int],
+                                                          type(self).__name__)
+        self.begin_params_axis = validator.check_value_type('begin_params_axis', begin_params_axis, [int],
+                                                            type(self).__name__)
        self.mul = P.Mul()
        self.sum_keep_dims = P.ReduceSum(keep_dims=True)
        self.sub = P.Sub()

@@ -686,7 +688,7 @@ class LogSoftmax(GraphKernel):
    def __init__(self, axis=-1):
        super(LogSoftmax, self).__init__()
-        self.axis = Validator.check_type('axis', axis, [int])
+        self.axis = validator.check_value_type('axis', axis, [int], type(self).__name__)
        self.max_keep_dims = P.ReduceMax(keep_dims=True)
        self.sub = P.Sub()
        self.exp = P.Exp()

@@ -952,13 +954,13 @@ class Softmax(GraphKernel):
    def __init__(self, axis):
        super(Softmax, self).__init__()
-        Validator.check_type("axis", axis, [int, tuple])
+        validator.check_value_type("axis", axis, [int, tuple], type(self).__name__)
        if isinstance(axis, int):
            self.axis = (axis,)
        else:
            self.axis = axis
        for item in self.axis:
-            Validator.check_type("item of axis", item, [int])
+            validator.check_value_type("item of axis", item, [int], type(self).__name__)
        self.max = P.ReduceMax(keep_dims=True)
        self.sub = P.Sub()
        self.exp = P.Exp()
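
The constructor-side pattern used above, reduced to a minimal sketch; ExampleKernel is a made-up stand-in for a GraphKernel subclass.

from mindspore._checkparam import Validator as validator

class ExampleKernel:
    def __init__(self, axis=-1, keep_dims=True):
        # check_value_type returns the validated value, so it can be assigned directly
        self.axis = validator.check_value_type('axis', axis, [int], type(self).__name__)
        self.keep_dims = validator.check_value_type('keep_dims', keep_dims, [bool], type(self).__name__)

ExampleKernel(axis=2)          # passes
# ExampleKernel(axis='2')      # would raise TypeError naming ExampleKernel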

View File

@@ -19,7 +19,7 @@ from mindspore.common.parameter import Parameter
from mindspore.common.initializer import initializer
from mindspore.ops.primitive import constexpr
import mindspore.context as context
-from mindspore._checkparam import Validator, check_typename
+from mindspore._checkparam import Validator as validator
from mindspore._extends import cell_attr_register
from mindspore.communication.management import get_group_size, get_rank
from mindspore.communication import management

@@ -52,7 +52,7 @@ class _BatchNorm(Cell):
        if momentum < 0 or momentum > 1:
            raise ValueError("momentum should be a number in range [0, 1], but got {}".format(momentum))
-        self.format = Validator.check_string(data_format, ['NCHW', 'NHWC'], 'format', self.cls_name)
+        self.format = validator.check_string(data_format, ['NCHW', 'NHWC'], 'format', self.cls_name)
        if context.get_context("device_target") != "GPU" and self.format == "NHWC":
            raise ValueError("NHWC format only support in GPU target.")
        self.use_batch_statistics = use_batch_statistics

@@ -67,7 +67,7 @@ class _BatchNorm(Cell):
            gamma_init, num_features), name="gamma", requires_grad=affine)
        self.beta = Parameter(initializer(
            beta_init, num_features), name="beta", requires_grad=affine)
-        self.group = Validator.check_positive_int(device_num_each_group)
+        self.group = validator.check_positive_int(device_num_each_group)
        self.is_global = False
        if self.group != 1:
            self.rank_id = get_rank()

@@ -472,7 +472,7 @@ class GlobalBatchNorm(_BatchNorm):
                                              use_batch_statistics,
                                              device_num_each_group,
                                              input_dims='both')
-        self.group = Validator.check_positive_int(device_num_each_group)
+        self.group = validator.check_positive_int(device_num_each_group)
        if self.group <= 1:
            raise ValueError("the number of group must be greater than 1.")

@@ -607,12 +607,12 @@ class GroupNorm(Cell):
    def __init__(self, num_groups, num_channels, eps=1e-05, affine=True, gamma_init='ones', beta_init='zeros'):
        super(GroupNorm, self).__init__()
-        self.num_groups = Validator.check_positive_int(num_groups)
-        self.num_channels = Validator.check_positive_int(num_channels)
+        self.num_groups = validator.check_positive_int(num_groups)
+        self.num_channels = validator.check_positive_int(num_channels)
        if num_channels % num_groups != 0:
            raise ValueError("num_channels should be divided by num_groups")
-        self.eps = check_typename('eps', eps, (float,))
-        self.affine = Validator.check_bool(affine)
+        self.eps = validator.check_value_type('eps', eps, (float,), type(self).__name__)
+        self.affine = validator.check_bool(affine)
        gamma = initializer(gamma_init, num_channels)
        beta = initializer(beta_init, num_channels)
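
For reference, the scalar validators used by the layer constructors above, called directly with illustrative values:

from mindspore._checkparam import Validator as validator

fmt = validator.check_string('NCHW', ['NCHW', 'NHWC'], 'format', 'BatchNorm2d')
groups = validator.check_positive_int(4)                                  # rejects zero and negative values
affine = validator.check_bool(True)
eps = validator.check_value_type('eps', 1e-05, (float,), 'GroupNorm')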

View File

@@ -442,8 +442,8 @@ class FakeQuantWithMinMaxObserver(UniformQuantObserver):
        super(FakeQuantWithMinMaxObserver, self).__init__(quant_dtype=quant_dtype, per_channel=per_channel,
                                                          symmetric=symmetric, narrow_range=narrow_range,
                                                          num_channels=num_channels)
-        Validator.check_type("min_init", min_init, [int, float])
-        Validator.check_type("max_init", max_init, [int, float])
+        Validator.check_value_type("min_init", min_init, [int, float], type(self).__name__)
+        Validator.check_value_type("max_init", max_init, [int, float], type(self).__name__)
        Validator.check("min_init", min_init, "max_init", max_init, rel=Rel.LT)
        Validator.check_non_negative_int(quant_delay, 'quant_delay')
        self.min_init = min_init
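
Sketch of the min/max range check kept above; it assumes Rel lives in the same _checkparam module as Validator, and the numeric values are illustrative.

from mindspore._checkparam import Validator, Rel  # Rel import path is an assumption

Validator.check_value_type("min_init", -6, [int, float], 'FakeQuantWithMinMaxObserver')
Validator.check_value_type("max_init", 6, [int, float], 'FakeQuantWithMinMaxObserver')
Validator.check("min_init", -6, "max_init", 6, rel=Rel.LT)   # min_init must be strictly less than max_init
Validator.check_non_negative_int(0, 'quant_delay')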

View File

@@ -68,7 +68,7 @@ class GumbelCDF(Bijector):
        """
        param = dict(locals())
        valid_dtype = mstype.float_type + mstype.int_type + mstype.uint_type
-        Validator.check_type(type(self).__name__, dtype, valid_dtype)
+        Validator.check_type_name("dtype", dtype, valid_dtype, type(self).__name__)
        parameter_type = set_param_type({'loc': loc, "scale": scale}, dtype)
        super(GumbelCDF, self).__init__(name=name, dtype=dtype, param=param)

View File

@@ -119,7 +119,7 @@ class Bernoulli(Distribution):
        param = dict(locals())
        param['param_dict'] = {'probs': probs}
        valid_dtype = mstype.int_type + mstype.uint_type + mstype.float_type
-        Validator.check_type(type(self).__name__, dtype, valid_dtype)
+        Validator.check_type_name("dtype", dtype, valid_dtype, type(self).__name__)
        super(Bernoulli, self).__init__(seed, dtype, name, param)
        self._probs = self._add_parameter(probs, 'probs')
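
The dtype validation now shared by the distribution constructors, shown as a standalone sketch with made-up arguments:

from mindspore._checkparam import Validator
from mindspore.common import dtype as mstype

valid_dtype = mstype.int_type + mstype.uint_type + mstype.float_type
Validator.check_type_name("dtype", mstype.int32, valid_dtype, "Bernoulli")          # accepted

try:
    Validator.check_type_name("dtype", mstype.bool_, mstype.float_type, "Normal")   # bool_ is not a float type
except TypeError as err:
    print(err)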

View File

@@ -109,7 +109,7 @@ class Categorical(Distribution):
        param = dict(locals())
        param['param_dict'] = {'probs': probs}
        valid_dtype = mstype.int_type
-        Validator.check_type("Categorical", dtype, valid_dtype)
+        Validator.check_type_name("dtype", dtype, valid_dtype, type(self).__name__)
        super(Categorical, self).__init__(seed, dtype, name, param)
        self._probs = self._add_parameter(probs, 'probs')

View File

@@ -121,7 +121,7 @@ class Exponential(Distribution):
        param = dict(locals())
        param['param_dict'] = {'rate': rate}
        valid_dtype = mstype.float_type
-        Validator.check_type(type(self).__name__, dtype, valid_dtype)
+        Validator.check_type_name("dtype", dtype, valid_dtype, type(self).__name__)
        super(Exponential, self).__init__(seed, dtype, name, param)
        self._rate = self._add_parameter(rate, 'rate')

View File

@@ -122,7 +122,7 @@ class Geometric(Distribution):
        param = dict(locals())
        param['param_dict'] = {'probs': probs}
        valid_dtype = mstype.int_type + mstype.uint_type + mstype.float_type
-        Validator.check_type(type(self).__name__, dtype, valid_dtype)
+        Validator.check_type_name("dtype", dtype, valid_dtype, type(self).__name__)
        super(Geometric, self).__init__(seed, dtype, name, param)
        self._probs = self._add_parameter(probs, 'probs')

View File

@@ -102,7 +102,7 @@ class Gumbel(TransformedDistribution):
        Constructor of Gumbel distribution.
        """
        valid_dtype = mstype.float_type
-        Validator.check_type(type(self).__name__, dtype, valid_dtype)
+        Validator.check_type_name("dtype", dtype, valid_dtype, type(self).__name__)
        gumbel_cdf = msb.GumbelCDF(loc, scale, dtype)
        super(Gumbel, self).__init__(
            distribution=msd.Uniform(0.0, 1.0, dtype=dtype),

View File

@@ -111,7 +111,7 @@ class Logistic(Distribution):
        param = dict(locals())
        param['param_dict'] = {'loc': loc, 'scale': scale}
        valid_dtype = mstype.float_type
-        Validator.check_type(type(self).__name__, dtype, valid_dtype)
+        Validator.check_type_name("dtype", dtype, valid_dtype, type(self).__name__)
        super(Logistic, self).__init__(seed, dtype, name, param)
        self._loc = self._add_parameter(loc, 'loc')

View File

@@ -127,7 +127,7 @@ class Normal(Distribution):
        param = dict(locals())
        param['param_dict'] = {'mean': mean, 'sd': sd}
        valid_dtype = mstype.float_type
-        Validator.check_type(type(self).__name__, dtype, valid_dtype)
+        Validator.check_type_name("dtype", dtype, valid_dtype, type(self).__name__)
        super(Normal, self).__init__(seed, dtype, name, param)
        self._mean_value = self._add_parameter(mean, 'mean')

View File

@@ -126,7 +126,7 @@ class Uniform(Distribution):
        param = dict(locals())
        param['param_dict'] = {'low': low, 'high': high}
        valid_dtype = mstype.float_type
-        Validator.check_type(type(self).__name__, dtype, valid_dtype)
+        Validator.check_type_name("dtype", dtype, valid_dtype, type(self).__name__)
        super(Uniform, self).__init__(seed, dtype, name, param)
        self._low = self._add_parameter(low, 'low')

View File

@@ -55,8 +55,7 @@ class UpdateCache(PrimitiveWithInfer):
        return [1]

    def infer_dtype(self, input_x_dtype, indices_dtype, update_dtype, max_num_dtype):
-        args = {"indices": indices_dtype}
-        validator.check_tensor_type_same(args, mstype.int_type, self.name)
+        validator.check_tensor_dtype_valid("indices", indices_dtype, mstype.int_type, self.name)
        return input_x_dtype

@@ -140,7 +139,7 @@ class SearchCacheIdx(PrimitiveWithInfer):
    def infer_dtype(self, hashmap_dtype, indices_dtype, step_dtype, emb_max_num_dtype, cache_max_num_dtype):
        args = {"hashmap": hashmap_dtype, "indices": indices_dtype}
-        validator.check_tensor_type_same(args, mstype.int_type, self.name)
+        validator.check_tensors_dtypes_same_and_valid(args, mstype.int_type, self.name)
        out_dtype = (hashmap_dtype, hashmap_dtype, hashmap_dtype)
        return out_dtype

@@ -182,8 +181,7 @@ class CacheSwapHashmap(PrimitiveWithInfer):
        return out_shape

    def infer_dtype(self, hashmap_dtype, miss_emb_idx_dtype, step_dtype):
-        args = {"miss_emb_idx": miss_emb_idx_dtype}
-        validator.check_tensor_type_same(args, mstype.int_type, self.name)
+        validator.check_tensor_dtype_valid("miss_emb_idx", miss_emb_idx_dtype, mstype.int_type, self.name)
        out_dtype = (miss_emb_idx_dtype, miss_emb_idx_dtype)
        return out_dtype

@@ -224,8 +222,7 @@ class CacheSwapTable(PrimitiveWithInfer):
        return miss_value_shape

    def infer_dtype(self, cache_table_dtype, swap_cache_idx_dtype, miss_value_dtype):
-        args = {"swap_cache_idx": swap_cache_idx_dtype}
-        validator.check_tensor_type_same(args, mstype.int_type, self.name)
+        validator.check_tensor_dtype_valid("swap_cache_idx", swap_cache_idx_dtype, mstype.int_type, self.name)
        return miss_value_dtype

@@ -261,7 +258,7 @@ class MapCacheIdx(PrimitiveWithInfer):
    def infer_dtype(self, hashmap_dtype, indices_dtype, step_dtype, emb_max_num_dtype, cache_max_num_dtype):
        args = {"hashmap": hashmap_dtype, "indices": indices_dtype}
-        validator.check_tensor_type_same(args, mstype.int_type, self.name)
+        validator.check_tensors_dtypes_same_and_valid(args, mstype.int_type, self.name)
        out_dtype = (hashmap_dtype, hashmap_dtype,
                     hashmap_dtype, hashmap_dtype)
        return out_dtype
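
Illustrative call of the single-tensor dtype check adopted above for index inputs (tensor types constructed by hand for the example):

from mindspore._checkparam import Validator as validator
from mindspore.common import dtype as mstype

validator.check_tensor_dtype_valid("indices", mstype.tensor_type(mstype.int32), mstype.int_type, "UpdateCache")

try:
    validator.check_tensor_dtype_valid("indices", mstype.tensor_type(mstype.float32), mstype.int_type, "UpdateCache")
except TypeError as err:
    print(err)   # float32 indices are rejected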

View File

@@ -14,6 +14,7 @@
# ============================================================================
"""Operators for gradients."""

+from functools import partial
from .. import signature as sig
from ..primitive import Primitive, PrimitiveWithInfer, prim_attr_register

@@ -23,6 +24,7 @@ from ...common import dtype as mstype
from .. import functional as F
from ... import context


class AbsGrad(PrimitiveWithInfer):
    """Computes gradients for abs operation."""

@@ -55,7 +57,7 @@ class ACosGrad(PrimitiveWithInfer):
    def infer_dtype(self, x, dout):
        args = {"x": x, "dout": dout}
-        validator.check_tensor_type_same(args, mstype.number_type, self.name)
+        validator.check_tensors_dtypes_same_and_valid(args, mstype.number_type, self.name)
        return x
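
Sketch of the paired-argument check used throughout this file: the element types must be identical and drawn from the whitelist. The tensor types below are built by hand purely for illustration.

from mindspore._checkparam import Validator as validator
from mindspore.common import dtype as mstype

args = {"x": mstype.tensor_type(mstype.float32), "dout": mstype.tensor_type(mstype.float32)}
validator.check_tensors_dtypes_same_and_valid(args, mstype.number_type, "ACosGrad")       # passes

try:
    bad = {"x": mstype.tensor_type(mstype.float32), "dout": mstype.tensor_type(mstype.float16)}
    validator.check_tensors_dtypes_same_and_valid(bad, mstype.number_type, "ACosGrad")    # dtypes differ
except TypeError as err:
    print(err)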
@@ -72,7 +74,7 @@ class AcoshGrad(PrimitiveWithInfer):
    def infer_dtype(self, x, dout):
        args = {"x": x, "dout": dout}
-        validator.check_tensor_type_same(args, mstype.number_type, self.name)
+        validator.check_tensors_dtypes_same_and_valid(args, mstype.number_type, self.name)
        return x

@@ -94,7 +96,7 @@ class AsinGrad(PrimitiveWithInfer):
    def infer_dtype(self, x, dout):
        args = {"x": x, "dout": dout}
-        validator.check_tensor_type_same(args, mstype.number_type, self.name)
+        validator.check_tensors_dtypes_same_and_valid(args, mstype.number_type, self.name)
        return x

@@ -111,7 +113,7 @@ class AsinhGrad(PrimitiveWithInfer):
    def infer_dtype(self, x, dout):
        args = {"x": x, "dout": dout}
-        validator.check_tensor_type_same(args, mstype.number_type, self.name)
+        validator.check_tensors_dtypes_same_and_valid(args, mstype.number_type, self.name)
        return x

@@ -128,7 +130,7 @@ class ReciprocalGrad(PrimitiveWithInfer):
    def infer_dtype(self, x_dtype, dout_dtype):
        args = {"x": x_dtype, "dout": dout_dtype}
-        validator.check_tensor_type_same(args, [mstype.float16, mstype.float32], self.name)
+        validator.check_tensors_dtypes_same_and_valid(args, [mstype.float16, mstype.float32], self.name)
        return x_dtype

@@ -145,7 +147,8 @@ class RsqrtGrad(PrimitiveWithInfer):
    def infer_dtype(self, x_dtype, dout_dtype):
        args = {"x": x_dtype, "dout": dout_dtype}
-        validator.check_tensor_type_same(args, [mstype.float16, mstype.float32, mstype.int32, mstype.int8], self.name)
+        validator.check_tensors_dtypes_same_and_valid(args, [mstype.float16, mstype.float32, mstype.int32, mstype.int8],
+                                                      self.name)
        return x_dtype
@@ -162,7 +165,7 @@ class SoftmaxGrad(PrimitiveWithInfer):
    def infer_dtype(self, x_dtype, dout_dtype):
        args = {"x": x_dtype, "dout": dout_dtype}
-        validator.check_tensor_type_same(args, [mstype.float16, mstype.float32], self.name)
+        validator.check_tensors_dtypes_same_and_valid(args, [mstype.float16, mstype.float32], self.name)
        return x_dtype

@@ -179,7 +182,7 @@ class SqrtGrad(PrimitiveWithInfer):
    def infer_dtype(self, x_dtype, dout_dtype):
        args = {"x": x_dtype, "dout": dout_dtype}
-        validator.check_tensor_type_same(args, [mstype.float16, mstype.float32], self.name)
+        validator.check_tensors_dtypes_same_and_valid(args, [mstype.float16, mstype.float32], self.name)
        return x_dtype

@@ -232,7 +235,7 @@ class KLDivLossGrad(PrimitiveWithInfer):
    def infer_dtype(self, x_type, y_type, doutput_type):
        args = {'x_type': x_type, 'y_type': y_type, 'doutput_type': doutput_type}
-        validator.check_tensor_type_same(args, (mstype.float16, mstype.float32), self.name)
+        validator.check_tensors_dtypes_same_and_valid(args, (mstype.float16, mstype.float32), self.name)
        return x_type, y_type

@@ -251,7 +254,7 @@ class BinaryCrossEntropyGrad(PrimitiveWithInfer):
    def infer_dtype(self, x_type, y_type, doutput_type, weight_type):
        args = {'x_type': x_type, 'y_type': y_type, 'doutput_type': doutput_type}
-        validator.check_tensor_type_same(args, (mstype.float16, mstype.float32), self.name)
+        validator.check_tensors_dtypes_same_and_valid(args, (mstype.float16, mstype.float32), self.name)
        if weight_type:
            validator.check('x_type', x_type, 'weight_type', weight_type, Rel.EQ, TypeError)
        return x_type

@@ -343,7 +346,8 @@ class Conv2DBackpropFilter(PrimitiveWithInfer):
        for i, dim_len in enumerate(w_size_v):
            validator.check_value_type("w_size[%d]" % i, dim_len, [int], self.name)
        args = {"x": x['dtype'], "doutput": doutput['dtype']}
-        validator.check_tensor_type_same(args, [mstype.int8, mstype.int32, mstype.float16, mstype.float32], self.name)
+        validator.check_tensors_dtypes_same_and_valid(args, [mstype.int8, mstype.int32, mstype.float16, mstype.float32],
+                                                      self.name)
        out = {
            'value': None,
            'shape': w_size_v,

@@ -406,7 +410,7 @@ class DepthwiseConv2dNativeBackpropFilter(PrimitiveWithInfer):
    def __infer__(self, x, w_size, dout):
        w_size_v = w_size['value']
        args = {'x': x['dtype'], 'dout': dout['dtype']}
-        validator.check_tensor_type_same(args, mstype.number_type, self.name)
+        validator.check_tensors_dtypes_same_and_valid(args, mstype.number_type, self.name)
        out = {
            'value': None,
            'shape': w_size_v,

@@ -466,7 +470,7 @@ class DepthwiseConv2dNativeBackpropInput(PrimitiveWithInfer):
    def __infer__(self, x_size, w, dout):
        args = {'w': w['dtype'], 'dout': dout['dtype']}
-        validator.check_tensor_type_same(args, mstype.number_type, self.name)
+        validator.check_tensors_dtypes_same_and_valid(args, mstype.number_type, self.name)
        x_size_v = x_size['value']
        out = {
            'value': None,

@@ -505,10 +509,9 @@ class DropoutGrad(PrimitiveWithInfer):
        return dy_shape

    def infer_dtype(self, dy_dtype, mask_dtype):
-        valid_types = (mstype.float16, mstype.float32)
-        validator.check_subclass("dy", dy_dtype, mstype.tensor, self.name)
+        valid_dtypes = (mstype.float16, mstype.float32)
        validator.check_subclass("mask", mask_dtype, mstype.tensor, self.name)
-        validator.check_tensor_type_same({"dy_dtype": dy_dtype}, valid_types, self.name)
+        validator.check_tensor_dtype_valid("dy", dy_dtype, valid_dtypes, self.name)
        return dy_dtype

@@ -627,9 +630,10 @@ class GeluGrad(PrimitiveWithInfer):
        return x_shape

    def infer_dtype(self, y_backprop_dtype, x_dtype, y_dtype):
-        validator.check_tensor_type_same({"y_backprop": y_backprop_dtype}, (mstype.float16, mstype.float32), self.name)
-        validator.check_tensor_type_same({"x": x_dtype}, (mstype.float16, mstype.float32), self.name)
-        validator.check_tensor_type_same({"y": y_dtype}, (mstype.float16, mstype.float32), self.name)
+        tuple(map(partial(validator.check_tensor_dtype_valid,
+                          valid_dtypes=(mstype.float16, mstype.float32), prim_name=self.name),
+                  ("y_backprop", "x", "y"),
+                  (y_backprop_dtype, x_dtype, y_dtype)))
        return x_dtype
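
The map/partial idiom introduced above, spelled out as a standalone sketch (dtypes fabricated for the example): it applies the same dtype check to several named inputs without repeating the call, and tuple() forces the lazy map to run.

from functools import partial
from mindspore._checkparam import Validator as validator
from mindspore.common import dtype as mstype

names = ("y_backprop", "x", "y")
dtypes = tuple(mstype.tensor_type(mstype.float16) for _ in names)

check = partial(validator.check_tensor_dtype_valid,
                valid_dtypes=(mstype.float16, mstype.float32), prim_name="GeluGrad")
tuple(map(check, names, dtypes))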
@@ -782,7 +786,7 @@ class MaxPoolGradGrad(_PoolGrad):
    def infer_dtype(self, x1_dtype, x2_dtype, grad_dtype):
        args = {'x1_dtype': x1_dtype, 'x2_dtype': x2_dtype, 'grad_dtype': grad_dtype}
-        validator.check_tensor_type_same(args, [mstype.float16], self.name)
+        validator.check_tensors_dtypes_same_and_valid(args, [mstype.float16], self.name)
        return x1_dtype

@@ -858,7 +862,7 @@ class MaxPoolGradGradWithArgmax(_PoolGrad):
    def infer_dtype(self, x_dtype, grad_dtype, argmax_dtype):
        args = {'x_dtype': x_dtype, 'grad_dtype': grad_dtype}
-        validator.check_tensor_type_same(args, [mstype.float16], self.name)
+        validator.check_tensors_dtypes_same_and_valid(args, [mstype.float16], self.name)
        return grad_dtype

@@ -902,7 +906,7 @@ class L2NormalizeGrad(PrimitiveWithInfer):
    def infer_dtype(self, input_x, out, dout):
        args = {'input_x': input_x, 'out': out, 'dout': dout}
-        validator.check_tensor_type_same(args, mstype.number_type, self.name)
+        validator.check_tensors_dtypes_same_and_valid(args, mstype.number_type, self.name)
        return input_x

@@ -993,7 +997,7 @@ class LSTMGradData(PrimitiveWithInfer):
    def infer_dtype(self, y_dtype, dy_dtype, dhy_dtype, dcy_dtype, w_dtype,
                    hx_dtype, cx_dtype, reserve_dtype, state_dtype):
        args = {"dy": dy_dtype, "dhy": dhy_dtype, "dcy": dcy_dtype}
-        validator.check_tensor_type_same(args, (mstype.float32, mstype.float16), self.name)
+        validator.check_tensors_dtypes_same_and_valid(args, (mstype.float32, mstype.float16), self.name)
        return (dy_dtype, dy_dtype, dy_dtype)

@@ -1265,14 +1269,14 @@ class DynamicGRUV2Grad(PrimitiveWithInfer):
        args = {"y_dtype": y_dtype, "init_h_dtype": init_h_dtype, "h_dtype": h_dtype,
                "dy_dtype": dy_dtype, "dh_dtype": dh_dtype, "update_dtype": update_dtype,
                "reset_dtype": reset_dtype, "new_dtype": new_dtype, "hnew_dtype": hnew_dtype}
-        validator.check_tensor_type_same({"x_dtype": x_dtype}, valid_types, self.name)
-        validator.check_tensor_type_same({"winput_dtype": winput_dtype}, valid_types, self.name)
-        validator.check_tensor_type_same({"whidden_dtype": whidden_dtype}, valid_types, self.name)
-        validator.check_tensor_type_same(args, valid_types, self.name)
+        validator.check_tensor_dtype_valid("x_dtype", x_dtype, valid_types, self.name)
+        validator.check_tensor_dtype_valid("winput_dtype", winput_dtype, valid_types, self.name)
+        validator.check_tensor_dtype_valid("whidden_dtype", whidden_dtype, valid_types, self.name)
+        validator.check_tensors_dtypes_same_and_valid(args, valid_types, self.name)
        if seq_dtype is not None:
-            validator.check_tensor_type_same({"seq_dtype": seq_dtype}, (mstype.float32, mstype.float16), self.name)
+            validator.check_tensor_dtype_valid("seq_dtype", seq_dtype, valid_types, self.name)
        if mask_dtype is not None:
-            validator.check_tensor_type_same({"mask_dtype": mask_dtype}, (mstype.float32, mstype.float16), self.name)
+            validator.check_tensor_dtype_valid("mask_dtype", mask_dtype, valid_types, self.name)
        return x_dtype, x_dtype, x_dtype, x_dtype, x_dtype, x_dtype

@@ -1302,10 +1306,10 @@ class PReLUGrad(PrimitiveWithInfer):
        return y_backprop_shape, w_shape

    def infer_dtype(self, y_backprop_dtype, A_dtype, w_dtype):
-        valid_types = (mstype.float16, mstype.float32)
-        validator.check_tensor_type_same({"y_backprop": y_backprop_dtype}, valid_types, self.name)
-        validator.check_tensor_type_same({"A_dtype": A_dtype}, valid_types, self.name)
-        validator.check_tensor_type_same({"w_dtype": w_dtype}, valid_types, self.name)
+        tuple(map(partial(validator.check_tensor_dtype_valid,
+                          valid_dtypes=(mstype.float16, mstype.float32), prim_name=self.name),
+                  ('y_backprop', "input_x", "weight"),
+                  (y_backprop_dtype, A_dtype, w_dtype)))
        return y_backprop_dtype, w_dtype
@@ -1335,8 +1339,9 @@ class ReLU6Grad(PrimitiveWithInfer):
        return x_shape

    def infer_dtype(self, y_grad_dtype, x_dtype):
-        validator.check_tensor_type_same({"y_grad": y_grad_dtype}, (mstype.float16, mstype.float32), self.name)
-        validator.check_tensor_type_same({"x": x_dtype}, (mstype.float16, mstype.float32), self.name)
+        valid_dtypes = (mstype.float16, mstype.float32)
+        validator.check_tensor_dtype_valid("y_grad", y_grad_dtype, valid_dtypes, self.name)
+        validator.check_tensor_dtype_valid("x", x_dtype, valid_dtypes, self.name)
        return x_dtype

@@ -1354,8 +1359,8 @@ class ReluGradV2(PrimitiveWithInfer):
        return gradients_shape

    def infer_dtype(self, gradients_dtype, mask_dtype):
-        validator.check_tensor_type_same({'gradients': gradients_dtype}, mstype.number_type, self.name)
-        validator.check_tensor_type_same({'mask': mask_dtype}, (mstype.uint8,), self.name)
+        validator.check_tensor_dtype_valid('gradients', gradients_dtype, mstype.number_type, self.name)
+        validator.check_tensor_dtype_valid('mask', mask_dtype, (mstype.uint8,), self.name)
        return gradients_dtype

@@ -1371,7 +1376,7 @@ class EluGrad(PrimitiveWithInfer):
    def infer_dtype(self, y_grad_dtype, x_dtype):
        args = {'y_grad': y_grad_dtype, 'x': x_dtype}
-        validator.check_tensor_type_same(args, mstype.float_type, self.name)
+        validator.check_tensors_dtypes_same_and_valid(args, mstype.float_type, self.name)
        return x_dtype

@@ -1474,7 +1479,7 @@ class SigmoidGrad(PrimitiveWithInfer):
    def infer_dtype(self, out, dout):
        args = {'out': out, 'dout': dout}
-        validator.check_tensor_type_same(args, mstype.number_type, self.name)
+        validator.check_tensors_dtypes_same_and_valid(args, mstype.number_type, self.name)
        return out

@@ -1489,8 +1494,9 @@ class HSigmoidGrad(PrimitiveWithInfer):
        return x_shape

    def infer_dtype(self, y_grad_dtype, x_dtype):
-        validator.check_tensor_type_same({"y_grad": y_grad_dtype}, (mstype.float16, mstype.float32), self.name)
-        validator.check_tensor_type_same({"x": x_dtype}, (mstype.float16, mstype.float32), self.name)
+        valid_dtypes = (mstype.float16, mstype.float32)
+        validator.check_tensor_dtype_valid("y_grad", y_grad_dtype, valid_dtypes, self.name)
+        validator.check_tensor_dtype_valid("x", x_dtype, valid_dtypes, self.name)
        return x_dtype

@@ -1505,8 +1511,9 @@ class HSwishGrad(PrimitiveWithInfer):
        return x_shape

    def infer_dtype(self, y_grad_dtype, x_dtype):
-        validator.check_tensor_type_same({"y_grad": y_grad_dtype}, (mstype.float16, mstype.float32), self.name)
-        validator.check_tensor_type_same({"x": x_dtype}, (mstype.float16, mstype.float32), self.name)
+        valid_dtypes = (mstype.float16, mstype.float32)
+        validator.check_tensor_dtype_valid("y_grad", y_grad_dtype, valid_dtypes, self.name)
+        validator.check_tensor_dtype_valid("x", x_dtype, valid_dtypes, self.name)
        return x_dtype

@@ -1525,7 +1532,7 @@ class SigmoidCrossEntropyWithLogitsGrad(PrimitiveWithInfer):
    def infer_dtype(self, x_dtype, y_dtype, dout_dtype):
        args = {"x_dtype": x_dtype, "y_dtype": y_dtype, 'dout_dtype': dout_dtype}
-        validator.check_tensor_type_same(args, mstype.number_type, self.name)
+        validator.check_tensors_dtypes_same_and_valid(args, mstype.number_type, self.name)
        return dout_dtype

@@ -1562,7 +1569,7 @@ class SmoothL1LossGrad(PrimitiveWithInfer):
    def infer_dtype(self, prediction, target, dloss):
        args = {"prediction": prediction, "target": target, 'dloss': dloss}
-        validator.check_tensor_type_same(args, mstype.number_type, self.name)
+        validator.check_tensors_dtypes_same_and_valid(args, mstype.number_type, self.name)
        return dloss

@@ -1597,8 +1604,7 @@ class StridedSliceGrad(PrimitiveWithInfer):
        self.init_prim_io_names(inputs=['dy', 'shapex', 'begin', 'end', 'strides'], outputs=['output'])

    def __infer__(self, dy, shapex, begin, end, strides):
-        args = {"dy": dy['dtype']}
-        validator.check_tensor_type_same(args, mstype.number_type + (mstype.bool_,), self.name)
+        validator.check_tensor_dtype_valid("dy", dy['dtype'], mstype.number_type + (mstype.bool_,), self.name)

        for idx, item in enumerate(shapex['value']):
            validator.check_value_type("shapex[%d]" % idx, item, [int], self.name)

@@ -1627,7 +1633,7 @@ class SoftplusGrad(PrimitiveWithInfer):
    def infer_dtype(self, dout_dtype, x_dtype):
        args = {"x_dtype": x_dtype, "dout_dtype": dout_dtype}
-        validator.check_tensor_type_same(args, mstype.float_type, self.name)
+        validator.check_tensors_dtypes_same_and_valid(args, mstype.float_type, self.name)
        return x_dtype

@@ -1643,7 +1649,7 @@ class TanhGrad(PrimitiveWithInfer):
    def infer_dtype(self, out, dout):
        args = {"out": out, "dout": dout}
-        validator.check_tensor_type_same(args, mstype.number_type, self.name)
+        validator.check_tensors_dtypes_same_and_valid(args, mstype.number_type, self.name)
        return out

@@ -1756,7 +1762,7 @@ class AtanGrad(PrimitiveWithInfer):
    def infer_dtype(self, x, dout):
        args = {"x": x, "dout": dout}
-        validator.check_tensor_type_same(args, mstype.number_type, self.name)
+        validator.check_tensors_dtypes_same_and_valid(args, mstype.number_type, self.name)
        return x

@@ -1900,7 +1906,7 @@ class LRNGrad(PrimitiveWithInfer):
    def infer_dtype(self, grads, x, y):
        args = {"grads": grads, "x": x, "y": y}
-        validator.check_tensor_type_same(args, (mstype.float16, mstype.float32,), self.name)
+        validator.check_tensors_dtypes_same_and_valid(args, (mstype.float16, mstype.float32,), self.name)
        return x

    def infer_shape(self, grads, x, y):

View File

@@ -54,6 +54,7 @@ class ExtractImagePatches(PrimitiveWithInfer):
    @prim_attr_register
    def __init__(self, ksizes, strides, rates, padding="valid"):
        """init"""

        def _check_tuple_or_list(arg_name, arg_val, prim_name):
            validator.check_value_type(f"{arg_name}s", ksizes, [tuple, list], self.name)
            if len(arg_val) != 4 or arg_val[0] != 1 or arg_val[3] != 1:

@@ -103,7 +104,7 @@ class ExtractImagePatches(PrimitiveWithInfer):
    def infer_dtype(self, input_x):
        """infer dtype"""
-        validator.check_tensor_type_same({"input_x": input_x}, mstype.number_type, self.name)
+        validator.check_tensor_dtype_valid("input_x", input_x, mstype.number_type, self.name)
        return input_x

@@ -161,7 +162,7 @@ class Range(PrimitiveWithInfer):
        return x_shape

    def infer_dtype(self, x_dtype):
-        validator.check_tensor_type_same({'x_dtype': x_dtype}, [mstype.float32, mstype.int32], self.name)
+        validator.check_tensor_dtype_valid('x', x_dtype, [mstype.float32, mstype.int32], self.name)
        return x_dtype

@@ -254,6 +255,7 @@ class Dequant(PrimitiveWithInfer):
        >>> dequant = P.Dequant(False, False)
        >>> y = dequant(input_x)
    """

    @prim_attr_register
    def __init__(self, sqrt_mode=False, relu_flag=False):
        self.sqrt_mode = validator.check_value_type("sqrt_mode", sqrt_mode, [bool], self.name)

@@ -303,10 +305,9 @@ class LinSpace(PrimitiveWithInfer):
        return assist

    def infer_dtype(self, assist, start, stop, num):
-        args = {"num": num}
-        validator.check_tensor_type_same(args, (mstype.int32,), self.name)
+        validator.check_tensor_dtype_valid("num", num, (mstype.int32,), self.name)
        args = {"assist": assist, "start": start, "stop": stop}
-        validator.check_tensor_type_same(args, (mstype.float32,), self.name)
+        validator.check_tensors_dtypes_same_and_valid(args, (mstype.float32,), self.name)
        return assist
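
LinSpace's reworked dtype inference, restated as a hypothetical free function (the name lin_space_infer_dtype and the hand-built tensor types are illustrative):

from mindspore._checkparam import Validator as validator
from mindspore.common import dtype as mstype

def lin_space_infer_dtype(assist, start, stop, num):
    validator.check_tensor_dtype_valid("num", num, (mstype.int32,), "LinSpace")          # num stands alone
    args = {"assist": assist, "start": start, "stop": stop}
    validator.check_tensors_dtypes_same_and_valid(args, (mstype.float32,), "LinSpace")   # the rest must all be float32
    return assist

f32 = mstype.tensor_type(mstype.float32)
lin_space_infer_dtype(f32, f32, f32, mstype.tensor_type(mstype.int32))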
@@ -343,12 +344,12 @@ class MatrixDiag(PrimitiveWithInfer):
    def infer_dtype(self, x_dtype, assist_dtype):
        valid_type = [mstype.float16, mstype.float32, mstype.int32, mstype.int8, mstype.uint8]
        args = {"x": x_dtype, "assist": assist_dtype}
-        validator.check_tensor_type_same(args, valid_type, self.name)
+        validator.check_tensors_dtypes_same_and_valid(args, valid_type, self.name)
        return x_dtype

    def infer_shape(self, x_shape, assist_shape):
        validator.check_int(len(assist_shape), 2, Rel.GE, "assist rank", self.name)
-        validator.check('rank of x', len(x_shape)+1,
+        validator.check('rank of x', len(x_shape) + 1,
                        'rank of assist', len(assist_shape), Rel.LE, self.name)
        validator.check('assist\'s penultimate dimension', assist_shape[-2], 'assist\'s last dimension',
                        assist_shape[-1], Rel.EQ, self.name)

@@ -358,7 +359,7 @@ class MatrixDiag(PrimitiveWithInfer):
        while r_idx >= r_end_dim:
            if x_shape[r_idx] != 1:
                validator.check("reverse x dim %d" % r_idx, x_shape[r_idx], "reverse assist dim %d" %
-                                assist_shape[r_idx-1], assist_shape[r_idx-1], Rel.EQ, self.name)
+                                assist_shape[r_idx - 1], assist_shape[r_idx - 1], Rel.EQ, self.name)
            r_idx = r_idx - 1
        return assist_shape

@@ -391,7 +392,7 @@ class MatrixDiagPart(PrimitiveWithInfer):
    def infer_dtype(self, x_dtype, assist_dtype):
        valid_type = [mstype.float16, mstype.float32, mstype.int32, mstype.int8, mstype.uint8]
        args = {"x": x_dtype, "assist": assist_dtype}
-        validator.check_tensor_type_same(args, valid_type, self.name)
+        validator.check_tensors_dtypes_same_and_valid(args, valid_type, self.name)
        return x_dtype

    def infer_shape(self, x_shape, assist_shape):

@@ -434,7 +435,7 @@ class MatrixSetDiag(PrimitiveWithInfer):
    def infer_dtype(self, x_dtype, diagonal_dtype, assist_dtype):
        valid_type = [mstype.float16, mstype.float32, mstype.int32, mstype.int8, mstype.uint8]
        args = {"x": x_dtype, "diagonal": diagonal_dtype, "assist": assist_dtype}
-        validator.check_tensor_type_same(args, valid_type, self.name)
+        validator.check_tensors_dtypes_same_and_valid(args, valid_type, self.name)
        return x_dtype

    def infer_shape(self, x_shape, diagonal_shape, assist_shape):
@ -583,21 +584,21 @@ class DynamicGRUV2(PrimitiveWithInfer):
return y_shape, outh_shape, outh_shape, outh_shape, outh_shape, outh_shape return y_shape, outh_shape, outh_shape, outh_shape, outh_shape, outh_shape
def infer_dtype(self, x_dtype, winput_dtype, whidden_dtype, binput_dtype, bhidden_dtype, seq_dtype, h_dtype): def infer_dtype(self, x_dtype, winput_dtype, whidden_dtype, binput_dtype, bhidden_dtype, seq_dtype, h_dtype):
validator.check_tensor_type_same({"x dtype": x_dtype}, (mstype.float16,), self.name) validator.check_tensor_dtype_valid("x dtype", x_dtype, [mstype.float16], self.name)
validator.check_tensor_type_same({"weight input dtype": winput_dtype}, (mstype.float16,), self.name) validator.check_tensor_dtype_valid("weight input dtype", winput_dtype, [mstype.float16], self.name)
validator.check_tensor_type_same({"weight hidden dtype": whidden_dtype}, (mstype.float16,), self.name) validator.check_tensor_dtype_valid("weight hidden dtype", whidden_dtype, [mstype.float16], self.name)
b_dtype = mstype.float32 b_dtype = mstype.float32
if binput_dtype is not None: if binput_dtype is not None:
validator.check_tensor_type_same({"bias input dtype": binput_dtype}, validator.check_tensor_dtype_valid("bias input dtype", binput_dtype,
(mstype.float16, mstype.float32), self.name) (mstype.float16, mstype.float32), self.name)
b_dtype = binput_dtype b_dtype = binput_dtype
elif bhidden_dtype is not None: elif bhidden_dtype is not None:
validator.check_tensor_type_same({"bias hidden dtype": bhidden_dtype}, validator.check_tensor_dtype_valid("bias hidden dtype", bhidden_dtype,
(mstype.float16, mstype.float32), self.name) (mstype.float16, mstype.float32), self.name)
b_dtype = bhidden_dtype b_dtype = bhidden_dtype
elif h_dtype is not None: elif h_dtype is not None:
validator.check_tensor_type_same({"init_h dtype": h_dtype}, validator.check_tensor_dtype_valid("init_h dtype", h_dtype,
(mstype.float16, mstype.float32), self.name) (mstype.float16, mstype.float32), self.name)
b_dtype = h_dtype b_dtype = h_dtype
return b_dtype, b_dtype, b_dtype, b_dtype, b_dtype, b_dtype return b_dtype, b_dtype, b_dtype, b_dtype, b_dtype, b_dtype

View File

@ -14,6 +14,7 @@
# ============================================================================ # ============================================================================
"""Operators for quantization.""" """Operators for quantization."""
from functools import partial
import mindspore.context as context import mindspore.context as context
from ..._checkparam import Validator as validator from ..._checkparam import Validator as validator
@ -92,12 +93,10 @@ class MinMaxUpdatePerLayer(PrimitiveWithInfer):
return min_shape, max_shape return min_shape, max_shape
def infer_dtype(self, x_type, min_type, max_type): def infer_dtype(self, x_type, min_type, max_type):
valid_types = (mstype.float16, mstype.float32) tuple(map(partial(validator.check_tensor_dtype_valid,
validator.check_tensor_type_same({"x": x_type}, valid_types, self.name) valid_dtypes=(mstype.float16, mstype.float32), prim_name=self.name),
validator.check_tensor_type_same( ("x", "min", "max"),
{"min": min_type}, valid_types, self.name) (x_type, min_type, max_type)))
validator.check_tensor_type_same(
{"max": max_type}, valid_types, self.name)
return min_type, max_type return min_type, max_type
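
The tuple(map(partial(...))) pattern introduced here applies the same dtype check to several inputs without repeating the call. A standalone sketch of the idiom, using a hypothetical check_dtype in place of validator.check_tensor_dtype_valid:

from functools import partial

def check_dtype(arg_name, dtype, valid_dtypes=(), prim_name=""):
    # Hypothetical stand-in used only to demonstrate the idiom.
    if dtype not in valid_dtypes:
        raise TypeError(f"For '{prim_name}', `{arg_name}` has dtype {dtype}, "
                        f"expected one of {valid_dtypes}.")
    return dtype

# partial(...) freezes the keyword arguments shared by every call; map(...) pairs
# each argument name with its dtype; tuple(...) forces the lazy map to execute.
tuple(map(partial(check_dtype,
                  valid_dtypes=("float16", "float32"),
                  prim_name="MinMaxUpdatePerLayer"),
          ("x", "min", "max"),
          ("float16", "float32", "float16")))

The outer tuple(...) is not decorative: map is lazy in Python 3, so without it none of the checks would actually run.
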
@ -157,13 +156,10 @@ class MinMaxUpdatePerChannel(PrimitiveWithInfer):
return min_shape, max_shape return min_shape, max_shape
def infer_dtype(self, x_type, min_type, max_type): def infer_dtype(self, x_type, min_type, max_type):
valid_types = (mstype.float16, mstype.float32) tuple(map(partial(validator.check_tensor_dtype_valid,
validator.check_tensor_type_same( valid_dtypes=(mstype.float16, mstype.float32), prim_name=self.name),
{"x": x_type}, valid_types, self.name) ("x", "min", "max"),
validator.check_tensor_type_same( (x_type, min_type, max_type)))
{"min": min_type}, valid_types, self.name)
validator.check_tensor_type_same(
{"max": max_type}, valid_types, self.name)
return min_type, max_type return min_type, max_type
@ -193,6 +189,7 @@ class FakeQuantWithMinMaxVars(PrimitiveWithInfer):
>>> input_tensor, min_tensor, max_tensor) >>> input_tensor, min_tensor, max_tensor)
>>> output_tensor shape: (3, 16, 5, 5) data type: mstype.float32 >>> output_tensor shape: (3, 16, 5, 5) data type: mstype.float32
""" """
@prim_attr_register @prim_attr_register
def __init__(self, def __init__(self,
num_bits=8, num_bits=8,
@ -217,10 +214,10 @@ class FakeQuantWithMinMaxVars(PrimitiveWithInfer):
return x_shape return x_shape
def infer_dtype(self, x_type, min_type, max_type): def infer_dtype(self, x_type, min_type, max_type):
valid_types = (mstype.float16, mstype.float32) tuple(map(partial(validator.check_tensor_dtype_valid,
validator.check_tensor_type_same({'x': x_type}, valid_types, self.name) valid_dtypes=(mstype.float16, mstype.float32), prim_name=self.name),
validator.check_tensor_type_same({'min': min_type}, valid_types, self.name) ("x", "min", "max"),
validator.check_tensor_type_same({'max': max_type}, valid_types, self.name) (x_type, min_type, max_type)))
return x_type return x_type
@ -256,6 +253,7 @@ class FakeQuantWithMinMaxVarsGradient(PrimitiveWithInfer):
>>> min_gradient shape: (1,) data type: mstype.float32 >>> min_gradient shape: (1,) data type: mstype.float32
>>> max_gradient shape: (1,) data type: mstype.float32 >>> max_gradient shape: (1,) data type: mstype.float32
""" """
@prim_attr_register @prim_attr_register
def __init__(self, def __init__(self,
num_bits=8, num_bits=8,
@ -281,11 +279,10 @@ class FakeQuantWithMinMaxVarsGradient(PrimitiveWithInfer):
return x_shape, min_shape, max_shape return x_shape, min_shape, max_shape
def infer_dtype(self, dout_type, x_type, min_type, max_type): def infer_dtype(self, dout_type, x_type, min_type, max_type):
valid_types = (mstype.float16, mstype.float32) tuple(map(partial(validator.check_tensor_dtype_valid,
validator.check_tensor_type_same({'dout': dout_type}, valid_types, self.name) valid_dtypes=(mstype.float16, mstype.float32), prim_name=self.name),
validator.check_tensor_type_same({'x': x_type}, valid_types, self.name) ('dout', "x", "min", "max"),
validator.check_tensor_type_same({'min': min_type}, valid_types, self.name) (dout_type, x_type, min_type, max_type)))
validator.check_tensor_type_same({'max': max_type}, valid_types, self.name)
return x_type, min_type, max_type return x_type, min_type, max_type
@ -315,6 +312,7 @@ class FakeQuantWithMinMaxVarsPerChannel(PrimitiveWithInfer):
>>> input_tensor, min_tensor, max_tensor) >>> input_tensor, min_tensor, max_tensor)
>>> output_tensor shape: (3, 16, 3, 4) data type: mstype.float32 >>> output_tensor shape: (3, 16, 3, 4) data type: mstype.float32
""" """
@prim_attr_register @prim_attr_register
def __init__(self, def __init__(self,
num_bits=8, num_bits=8,
@ -332,10 +330,10 @@ class FakeQuantWithMinMaxVarsPerChannel(PrimitiveWithInfer):
return x_shape return x_shape
def infer_dtype(self, x_type, min_type, max_type): def infer_dtype(self, x_type, min_type, max_type):
valid_types = (mstype.float16, mstype.float32) tuple(map(partial(validator.check_tensor_dtype_valid,
validator.check_tensor_type_same({'x': x_type}, valid_types, self.name) valid_dtypes=(mstype.float16, mstype.float32), prim_name=self.name),
validator.check_tensor_type_same({'min': min_type}, valid_types, self.name) ("x", "min", "max"),
validator.check_tensor_type_same({'max': max_type}, valid_types, self.name) (x_type, min_type, max_type)))
return x_type return x_type
@ -372,6 +370,7 @@ class FakeQuantWithMinMaxVarsPerChannelGradient(PrimitiveWithInfer):
>>> min_gradient shape: (4,) data type: mstype.float32 >>> min_gradient shape: (4,) data type: mstype.float32
>>> max_gradient shape: (4,) data type: mstype.float32 >>> max_gradient shape: (4,) data type: mstype.float32
""" """
@prim_attr_register @prim_attr_register
def __init__(self, def __init__(self,
num_bits=8, num_bits=8,
@ -390,11 +389,10 @@ class FakeQuantWithMinMaxVarsPerChannelGradient(PrimitiveWithInfer):
return x_shape, min_shape, max_shape return x_shape, min_shape, max_shape
def infer_dtype(self, dout_type, x_type, min_type, max_type): def infer_dtype(self, dout_type, x_type, min_type, max_type):
valid_types = (mstype.float16, mstype.float32) tuple(map(partial(validator.check_tensor_dtype_valid,
validator.check_tensor_type_same({'dout': dout_type}, valid_types, self.name) valid_dtypes=(mstype.float16, mstype.float32), prim_name=self.name),
validator.check_tensor_type_same({'x': x_type}, valid_types, self.name) ("dout", "x", "min", "max"),
validator.check_tensor_type_same({'min': min_type}, valid_types, self.name) (dout_type, x_type, min_type, max_type)))
validator.check_tensor_type_same({'max': max_type}, valid_types, self.name)
return x_type, min_type, max_type return x_type, min_type, max_type
@ -468,14 +466,12 @@ class FakeQuantPerLayer(PrimitiveWithInfer):
def infer_dtype(self, x_type, min_type, max_type): def infer_dtype(self, x_type, min_type, max_type):
if context.get_context('device_target') == "GPU": if context.get_context('device_target') == "GPU":
valid_types = (mstype.float32,) valid_dtypes = (mstype.float32,)
else: else:
valid_types = (mstype.float16, mstype.float32) valid_dtypes = (mstype.float16, mstype.float32)
validator.check_tensor_type_same({"x": x_type}, valid_types, self.name) tuple(map(partial(validator.check_tensor_dtype_valid, valid_dtypes=valid_dtypes, prim_name=self.name),
validator.check_tensor_type_same( ("x", "min", "max"),
{"min": min_type}, valid_types, self.name) (x_type, min_type, max_type)))
validator.check_tensor_type_same(
{"max": max_type}, valid_types, self.name)
return x_type return x_type
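
The FakeQuant operators above choose their dtype whitelist from the device target at infer time: float32 only on GPU, float16 or float32 otherwise. A small sketch of that selection; get_device here is a stand-in for context.get_context('device_target'), not the real context API.

def fake_quant_valid_dtypes(get_device):
    # Mirrors the branch above: the GPU kernels accept float32 only.
    if get_device() == "GPU":
        return ("float32",)
    return ("float16", "float32")

assert fake_quant_valid_dtypes(lambda: "GPU") == ("float32",)
assert fake_quant_valid_dtypes(lambda: "Ascend") == ("float16", "float32")
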
@ -525,16 +521,12 @@ class FakeQuantPerLayerGrad(PrimitiveWithInfer):
def infer_dtype(self, dout_type, x_type, min_type, max_type): def infer_dtype(self, dout_type, x_type, min_type, max_type):
if context.get_context('device_target') == "GPU": if context.get_context('device_target') == "GPU":
valid_types = (mstype.float32,) valid_dtypes = (mstype.float32,)
else: else:
valid_types = (mstype.float16, mstype.float32) valid_dtypes = (mstype.float16, mstype.float32)
validator.check_tensor_type_same( tuple(map(partial(validator.check_tensor_dtype_valid, valid_dtypes=valid_dtypes, prim_name=self.name),
{"dout": dout_type}, valid_types, self.name) ("dout", "x", "min", "max"),
validator.check_tensor_type_same({"x": x_type}, valid_types, self.name) (dout_type, x_type, min_type, max_type)))
validator.check_tensor_type_same(
{"min": min_type}, valid_types, self.name)
validator.check_tensor_type_same(
{"max": max_type}, valid_types, self.name)
return dout_type return dout_type
@ -623,14 +615,12 @@ class FakeQuantPerChannel(PrimitiveWithInfer):
def infer_dtype(self, x_type, min_type, max_type): def infer_dtype(self, x_type, min_type, max_type):
if context.get_context('device_target') == "GPU": if context.get_context('device_target') == "GPU":
valid_types = (mstype.float32,) valid_dtypes = (mstype.float32,)
else: else:
valid_types = (mstype.float16, mstype.float32) valid_dtypes = (mstype.float16, mstype.float32)
validator.check_tensor_type_same({"x": x_type}, valid_types, self.name) tuple(map(partial(validator.check_tensor_dtype_valid, valid_dtypes=valid_dtypes, prim_name=self.name),
validator.check_tensor_type_same( ("x", "min", "max"),
{"min": min_type}, valid_types, self.name) (x_type, min_type, max_type)))
validator.check_tensor_type_same(
{"max": max_type}, valid_types, self.name)
return x_type return x_type
@ -680,16 +670,12 @@ class FakeQuantPerChannelGrad(PrimitiveWithInfer):
def infer_dtype(self, dout_type, x_type, min_type, max_type): def infer_dtype(self, dout_type, x_type, min_type, max_type):
if context.get_context('device_target') == "GPU": if context.get_context('device_target') == "GPU":
valid_types = (mstype.float32,) valid_dtypes = (mstype.float32,)
else: else:
valid_types = (mstype.float16, mstype.float32) valid_dtypes = (mstype.float16, mstype.float32)
validator.check_tensor_type_same( tuple(map(partial(validator.check_tensor_dtype_valid, valid_dtypes=valid_dtypes, prim_name=self.name),
{"dout": dout_type}, valid_types, self.name) ("dout", "x", "min", "max"),
validator.check_tensor_type_same({"x": x_type}, valid_types, self.name) (dout_type, x_type, min_type, max_type)))
validator.check_tensor_type_same(
{"min": min_type}, valid_types, self.name)
validator.check_tensor_type_same(
{"max": max_type}, valid_types, self.name)
return dout_type return dout_type
@ -750,8 +736,8 @@ class BatchNormFold(PrimitiveWithInfer):
validator.check("input type", x_type, "mean type", mean_type) validator.check("input type", x_type, "mean type", mean_type)
validator.check("input type", x_type, "variance type", variance_type) validator.check("input type", x_type, "variance type", variance_type)
args = {"x": x_type, "mean": mean_type, "variance": variance_type} args = {"x": x_type, "mean": mean_type, "variance": variance_type}
validator.check_tensor_type_same(args, (mstype.float16, mstype.float32), self.name) validator.check_tensors_dtypes_same_and_valid(args, (mstype.float16, mstype.float32), self.name)
validator.check_tensor_type_same({"global_step": global_step_type}, (mstype.int32,), self.name) validator.check_tensor_dtype_valid("global_step", global_step_type, (mstype.int32,), self.name)
return x_type, x_type, x_type, x_type return x_type, x_type, x_type, x_type
@ -797,8 +783,8 @@ class BatchNormFoldGrad(PrimitiveWithInfer):
global_step_type): global_step_type):
args = {"input": x_type, "d_batch_mean": d_batch_mean_type, "d_batch_std": d_batch_std_type, args = {"input": x_type, "d_batch_mean": d_batch_mean_type, "d_batch_std": d_batch_std_type,
"batch_mean": batch_mean_type, "batch_std": batch_std_type} "batch_mean": batch_mean_type, "batch_std": batch_std_type}
validator.check_tensor_type_same(args, (mstype.float16, mstype.float32), self.name) validator.check_tensors_dtypes_same_and_valid(args, (mstype.float16, mstype.float32), self.name)
validator.check_tensor_type_same({"global_step": global_step_type}, (mstype.int32,), self.name) validator.check_tensor_dtype_valid("global_step", global_step_type, (mstype.int32,), self.name)
return x_type return x_type
@ -841,7 +827,7 @@ class CorrectionMul(PrimitiveWithInfer):
def infer_dtype(self, x_type, batch_std_type, running_std_type): def infer_dtype(self, x_type, batch_std_type, running_std_type):
args = {"x": x_type, "batch_std": batch_std_type, "running_std": running_std_type} args = {"x": x_type, "batch_std": batch_std_type, "running_std": running_std_type}
validator.check_tensor_type_same(args, (mstype.float16, mstype.float32), self.name) validator.check_tensors_dtypes_same_and_valid(args, (mstype.float16, mstype.float32), self.name)
return x_type return x_type
@ -879,7 +865,7 @@ class CorrectionMulGrad(PrimitiveWithInfer):
def infer_dtype(self, dout_type, x_type, gamma_type, running_std_type): def infer_dtype(self, dout_type, x_type, gamma_type, running_std_type):
args = {"dout": dout_type, "x": x_type, "gamma": gamma_type, "running_std": running_std_type} args = {"dout": dout_type, "x": x_type, "gamma": gamma_type, "running_std": running_std_type}
validator.check_tensor_type_same(args, (mstype.float16, mstype.float32), self.name) validator.check_tensors_dtypes_same_and_valid(args, (mstype.float16, mstype.float32), self.name)
if context.get_context('device_target') == "Ascend": if context.get_context('device_target') == "Ascend":
return x_type, x_type return x_type, x_type
return x_type, gamma_type return x_type, gamma_type
@ -972,8 +958,8 @@ class BatchNormFold2(PrimitiveWithInfer):
running_mean_type, global_step_type): running_mean_type, global_step_type):
args = {"batch_std": batch_std_type, "running_std": running_std_type, "batch_mean": batch_mean_type, args = {"batch_std": batch_std_type, "running_std": running_std_type, "batch_mean": batch_mean_type,
"beta": beta_type, "running_mean": running_mean_type, "gamma": gamma_type, "x": x_type} "beta": beta_type, "running_mean": running_mean_type, "gamma": gamma_type, "x": x_type}
validator.check_tensor_type_same(args, (mstype.float16, mstype.float32), self.name) validator.check_tensors_dtypes_same_and_valid(args, (mstype.float16, mstype.float32), self.name)
validator.check_tensor_type_same({"global_step": global_step_type}, (mstype.int32,), self.name) validator.check_tensor_dtype_valid("global_step", global_step_type, (mstype.int32,), self.name)
return x_type return x_type
@ -1031,8 +1017,8 @@ class BatchNormFold2Grad(PrimitiveWithInfer):
"dout type", dout_type) "dout type", dout_type)
args = {"batch_std": batch_std_type, "batch_mean": batch_mean_type, "gamma": gamma_type, args = {"batch_std": batch_std_type, "batch_mean": batch_mean_type, "gamma": gamma_type,
"running_std": running_std_type, "running_mean": running_mean_type, "dout": dout_type} "running_std": running_std_type, "running_mean": running_mean_type, "dout": dout_type}
validator.check_tensor_type_same(args, (mstype.float16, mstype.float32), self.name) validator.check_tensors_dtypes_same_and_valid(args, (mstype.float16, mstype.float32), self.name)
validator.check_tensor_type_same({"global_step": global_step_type}, (mstype.int32,), self.name) validator.check_tensor_dtype_valid("global_step", global_step_type, (mstype.int32,), self.name)
return gamma_type, gamma_type, gamma_type, gamma_type, gamma_type return gamma_type, gamma_type, gamma_type, gamma_type, gamma_type
@ -1061,7 +1047,7 @@ class BatchNormFoldD(PrimitiveWithInfer):
validator.check("input type", x_type, "mean type", mean_type) validator.check("input type", x_type, "mean type", mean_type)
validator.check("input type", x_type, "variance type", variance_type) validator.check("input type", x_type, "variance type", variance_type)
args = {"x": x_type, "mean": mean_type, "variance": variance_type} args = {"x": x_type, "mean": mean_type, "variance": variance_type}
validator.check_tensor_type_same(args, (mstype.float16, mstype.float32), self.name) validator.check_tensors_dtypes_same_and_valid(args, (mstype.float16, mstype.float32), self.name)
return x_type, x_type, x_type, x_type, x_type, x_type, x_type return x_type, x_type, x_type, x_type, x_type, x_type, x_type
@ -1090,8 +1076,7 @@ class BatchNormFoldGradD(PrimitiveWithInfer):
validator.check("input type", x_type, "d_batch_std type", d_batch_std_type) validator.check("input type", x_type, "d_batch_std type", d_batch_std_type)
validator.check("input type", x_type, "batch_mean type", batch_mean_type) validator.check("input type", x_type, "batch_mean type", batch_mean_type)
validator.check("input type", x_type, "batch_std type", batch_std_type) validator.check("input type", x_type, "batch_std type", batch_std_type)
args = {"input type": x_type} validator.check_tensor_dtype_valid("input type", x_type, (mstype.float16, mstype.float32), self.name)
validator.check_tensor_type_same(args, (mstype.float16, mstype.float32), self.name)
return x_type return x_type
@ -1136,7 +1121,7 @@ class BatchNormFold2_D(PrimitiveWithInfer):
def infer_dtype(self, x_type, beta_type, gamma_type, batch_std_type, running_std_type, batch_mean_type): def infer_dtype(self, x_type, beta_type, gamma_type, batch_std_type, running_std_type, batch_mean_type):
args = {"batch_std": batch_std_type, "running_std": running_std_type, "batch_mean": batch_mean_type, args = {"batch_std": batch_std_type, "running_std": running_std_type, "batch_mean": batch_mean_type,
"beta": beta_type, "gamma": gamma_type, "x": x_type} "beta": beta_type, "gamma": gamma_type, "x": x_type}
validator.check_tensor_type_same(args, (mstype.float16, mstype.float32), self.name) validator.check_tensors_dtypes_same_and_valid(args, (mstype.float16, mstype.float32), self.name)
return x_type return x_type
@ -1174,7 +1159,7 @@ class BatchNormFold2GradD(PrimitiveWithInfer):
"dout type", dout_type) "dout type", dout_type)
args = {"batch_std": batch_std_type, "batch_mean": batch_mean_type, "gamma": gamma_type, args = {"batch_std": batch_std_type, "batch_mean": batch_mean_type, "gamma": gamma_type,
"running_std": running_std_type, "dout": dout_type} "running_std": running_std_type, "dout": dout_type}
validator.check_tensor_type_same(args, (mstype.float16, mstype.float32), self.name) validator.check_tensors_dtypes_same_and_valid(args, (mstype.float16, mstype.float32), self.name)
return gamma_type, gamma_type, gamma_type, gamma_type return gamma_type, gamma_type, gamma_type, gamma_type

View File

@ -165,7 +165,7 @@ class CusFusedAbsMax1(PrimitiveWithInfer):
def infer_shape(self, data1_shape): def infer_shape(self, data1_shape):
ll = [] ll = []
if len(data1_shape) == 2: if len(data1_shape) == 2:
ll = [1,] ll = [1]
else: else:
ll = [32, 64] ll = [32, 64]
return ll return ll
@ -497,6 +497,7 @@ class Im2Col(PrimitiveWithInfer):
>>> img2col = P.CusMatMulCubeDenseLeft(kernel_size=7, pad=3, stride=2) >>> img2col = P.CusMatMulCubeDenseLeft(kernel_size=7, pad=3, stride=2)
>>> output = img2col(input_x) >>> output = img2col(input_x)
""" """
@prim_attr_register @prim_attr_register
def __init__(self, def __init__(self,
kernel_size, kernel_size,
@ -556,9 +557,8 @@ class Im2Col(PrimitiveWithInfer):
return out_shape return out_shape
def infer_dtype(self, x_dtype): def infer_dtype(self, x_dtype):
args = {'x': x_dtype} valid_dtypes = [mstype.float16, mstype.float32]
valid_types = [mstype.float16, mstype.float32] validator.check_tensor_dtype_valid('x', x_dtype, valid_dtypes, self.name)
validator.check_tensor_type_same(args, valid_types, self.name)
return x_dtype return x_dtype
@ -602,14 +602,17 @@ class UpdateThorGradient(PrimitiveWithInfer):
return x2_shape return x2_shape
def infer_dtype(self, x1_dtype, x2_dtype, x3_dtype): def infer_dtype(self, x1_dtype, x2_dtype, x3_dtype):
validator.check_tensor_type_same({'x1_dtype': x1_dtype, 'x2_dtype': x2_dtype, 'x3_dtype': x3_dtype}, validator.check_tensors_dtypes_same_and_valid(
[mstype.float32], self.name) {'x1_dtype': x1_dtype, 'x2_dtype': x2_dtype, 'x3_dtype': x3_dtype},
[mstype.float32], self.name)
return x2_dtype return x2_dtype
class Cholesky(PrimitiveWithInfer): class Cholesky(PrimitiveWithInfer):
""" """
Inner API for resnet50 THOR GPU backend Inner API for resnet50 THOR GPU backend
""" """
@prim_attr_register @prim_attr_register
def __init__(self, split_dim=0): def __init__(self, split_dim=0):
self.init_prim_io_names(inputs=['x1'], outputs=['y']) self.init_prim_io_names(inputs=['x1'], outputs=['y'])
@ -634,13 +637,15 @@ class Cholesky(PrimitiveWithInfer):
return out_shape return out_shape
def infer_dtype(self, x1_dtype): def infer_dtype(self, x1_dtype):
validator.check_tensor_type_same({'x1_dtype': x1_dtype}, [mstype.float32], self.name) validator.check_tensor_dtype_valid('x1', x1_dtype, [mstype.float32], self.name)
return x1_dtype return x1_dtype
class DetTriangle(PrimitiveWithInfer): class DetTriangle(PrimitiveWithInfer):
""" """
Calculate the determinant of triangle matrices Calculate the determinant of triangle matrices
""" """
@prim_attr_register @prim_attr_register
def __init__(self, fill_mode=0): def __init__(self, fill_mode=0):
self.init_prim_io_names(inputs=['x1'], outputs=['y']) self.init_prim_io_names(inputs=['x1'], outputs=['y'])
@ -653,5 +658,5 @@ class DetTriangle(PrimitiveWithInfer):
return out_shape return out_shape
def infer_dtype(self, x1_dtype): def infer_dtype(self, x1_dtype):
validator.check_tensor_type_same({'x1_dtype': x1_dtype}, [mstype.float32], self.name) validator.check_tensor_dtype_valid('x1', x1_dtype, [mstype.float32], self.name)
return x1_dtype return x1_dtype

View File

@ -63,9 +63,9 @@ class _ScatterOp(PrimitiveWithInfer):
return x_shape return x_shape
def infer_dtype(self, x_dtype, indices_dtype, updates_dtype): def infer_dtype(self, x_dtype, indices_dtype, updates_dtype):
validator.check_tensor_type_same({'indices': indices_dtype}, [mstype.int32], self.name) validator.check_tensor_dtype_valid('indices', indices_dtype, [mstype.int32], self.name)
args = {"x": x_dtype, "updates": updates_dtype} args = {"x": x_dtype, "updates": updates_dtype}
validator.check_tensor_type_same(args, mstype.number_type, self.name) validator.check_tensors_dtypes_same_and_valid(args, mstype.number_type, self.name)
return x_dtype return x_dtype
@ -73,6 +73,7 @@ class _ScatterNdOp(_ScatterOp):
""" """
Defines _ScatterNd operators Defines _ScatterNd operators
""" """
def _check_scatter_shape(self, x_shape, indices_shape, updates_shape, prim_name): def _check_scatter_shape(self, x_shape, indices_shape, updates_shape, prim_name):
validator.check('the dimension of x', len(x_shape), validator.check('the dimension of x', len(x_shape),
'the dimension of indices', indices_shape[-1], Rel.GE) 'the dimension of indices', indices_shape[-1], Rel.GE)
@ -627,6 +628,7 @@ class Unique(Primitive):
>>> out = P.Unique()(x) >>> out = P.Unique()(x)
(Tensor([1, 2, 5], mindspore.int32), Tensor([0, 1, 2, 1], mindspore.int32)) (Tensor([1, 2, 5], mindspore.int32), Tensor([0, 1, 2, 1], mindspore.int32))
""" """
@prim_attr_register @prim_attr_register
def __init__(self): def __init__(self):
self.init_prim_io_names(inputs=['x'], outputs=['output']) self.init_prim_io_names(inputs=['x'], outputs=['output'])
@ -661,11 +663,11 @@ class GatherV2(PrimitiveWithCheck):
def __init__(self): def __init__(self):
"""Initialize index_select""" """Initialize index_select"""
self.init_prim_io_names(inputs=['params', 'indices', 'axis'], outputs=['output']) self.init_prim_io_names(inputs=['params', 'indices', 'axis'], outputs=['output'])
self.add_prim_attr("dynamic_shape_depends", [2,]) self.add_prim_attr("dynamic_shape_depends", [2])
def __check__(self, params, indices, axis): def __check__(self, params, indices, axis):
validator.check_subclass("params", params['dtype'], mstype.tensor, self.name) validator.check_subclass("params", params['dtype'], mstype.tensor, self.name)
validator.check_tensor_type_same({"indices": indices['dtype']}, mstype.int_type, self.name) validator.check_tensor_dtype_valid("indices", indices['dtype'], mstype.int_type, self.name)
validator.check_subclass("axis", axis['dtype'], mstype.int_, self.name) validator.check_subclass("axis", axis['dtype'], mstype.int_, self.name)
axis_v = axis['value'] axis_v = axis['value']
params_shp = params['shape'] params_shp = params['shape']
@ -727,6 +729,7 @@ class Padding(PrimitiveWithInfer):
>>> out = P.Padding(pad_dim_size)(x) >>> out = P.Padding(pad_dim_size)(x)
[[8, 0, 0, 0], [10, 0, 0, 0]] [[8, 0, 0, 0], [10, 0, 0, 0]]
""" """
@prim_attr_register @prim_attr_register
def __init__(self, pad_dim_size=8): def __init__(self, pad_dim_size=8):
"""Initialize padding""" """Initialize padding"""
@ -766,12 +769,13 @@ class UniqueWithPad(PrimitiveWithInfer):
>>> out = P.UniqueWithPad()(x, pad_num) >>> out = P.UniqueWithPad()(x, pad_num)
([1, 5, 4, 3, 2, 8, 8, 8, 8, 8], [0, 0, 1, 1, 2, 2, 3, 3, 4, 4]) ([1, 5, 4, 3, 2, 8, 8, 8, 8, 8], [0, 0, 1, 1, 2, 2, 3, 3, 4, 4])
""" """
@prim_attr_register @prim_attr_register
def __init__(self): def __init__(self):
"""init UniqueWithPad""" """init UniqueWithPad"""
def __infer__(self, x, pad_num): def __infer__(self, x, pad_num):
validator.check_tensor_type_same({"x": x['dtype']}, [mstype.int32, mstype.int64], self.name) validator.check_tensor_dtype_valid("x", x['dtype'], [mstype.int32, mstype.int64], self.name)
validator.check_subclass("pad_num", pad_num['dtype'], [mstype.int32, mstype.int64], self.name) validator.check_subclass("pad_num", pad_num['dtype'], [mstype.int32, mstype.int64], self.name)
x_shape = list(x['shape']) x_shape = list(x['shape'])
validator.check("rank of x", len(x_shape), "expected", 1, Rel.EQ, self.name) validator.check("rank of x", len(x_shape), "expected", 1, Rel.EQ, self.name)
@ -903,7 +907,7 @@ class TruncatedNormal(PrimitiveWithInfer):
def __init__(self, seed=0, dtype=mstype.float32): def __init__(self, seed=0, dtype=mstype.float32):
"""Initialize TruncatedNormal""" """Initialize TruncatedNormal"""
validator.check_value_type('seed', seed, [int], self.name) validator.check_value_type('seed', seed, [int], self.name)
validator.check_type_same({'dtype': dtype}, mstype.number_type, self.name) validator.check_types_same_and_valid({'dtype': dtype}, mstype.number_type, self.name)
def __infer__(self, shape): def __infer__(self, shape):
shape_value = shape['value'] shape_value = shape['value']
@ -984,10 +988,10 @@ class Fill(PrimitiveWithInfer):
validator.check_value_type("value", x['value'], [numbers.Number, bool], self.name) validator.check_value_type("value", x['value'], [numbers.Number, bool], self.name)
for i, item in enumerate(dims['value']): for i, item in enumerate(dims['value']):
validator.check_positive_int(item, f'dims[{i}]', self.name) validator.check_positive_int(item, f'dims[{i}]', self.name)
valid_types = [mstype.bool_, mstype.int8, mstype.int16, mstype.int32, mstype.int64, valid_dtypes = [mstype.bool_, mstype.int8, mstype.int16, mstype.int32, mstype.int64,
mstype.uint8, mstype.uint32, mstype.uint64, mstype.uint8, mstype.uint32, mstype.uint64,
mstype.float16, mstype.float32, mstype.float64] mstype.float16, mstype.float32, mstype.float64]
validator.check_type_same({"value": dtype['value']}, valid_types, self.name) validator.check_types_same_and_valid({"value": dtype['value']}, valid_dtypes, self.name)
x_nptype = mstype.dtype_to_nptype(dtype['value']) x_nptype = mstype.dtype_to_nptype(dtype['value'])
ret = np.full(dims['value'], x['value'], x_nptype) ret = np.full(dims['value'], x['value'], x_nptype)
out = { out = {
@ -1026,7 +1030,7 @@ class OnesLike(PrimitiveWithInfer):
return x_shape return x_shape
def infer_dtype(self, x_dtype): def infer_dtype(self, x_dtype):
validator.check_tensor_type_same({'x': x_dtype}, mstype.number_type + (mstype.bool_,), self.name) validator.check_tensor_dtype_valid('x', x_dtype, mstype.number_type + (mstype.bool_,), self.name)
return x_dtype return x_dtype
@ -1059,7 +1063,7 @@ class ZerosLike(PrimitiveWithInfer):
return x_shape return x_shape
def infer_dtype(self, x_dtype): def infer_dtype(self, x_dtype):
validator.check_tensor_type_same({'x': x_dtype}, mstype.number_type + (mstype.bool_,), self.name) validator.check_tensor_dtype_valid('x', x_dtype, mstype.number_type + (mstype.bool_,), self.name)
return x_dtype return x_dtype
@ -1264,7 +1268,7 @@ class Argmax(PrimitiveWithInfer):
"""Initialize Argmax""" """Initialize Argmax"""
self.init_prim_io_names(inputs=['x'], outputs=['output']) self.init_prim_io_names(inputs=['x'], outputs=['output'])
validator.check_value_type("axis", axis, [int], self.name) validator.check_value_type("axis", axis, [int], self.name)
validator.check_type_same({'output': output_type}, [mstype.int32], self.name) validator.check_types_same_and_valid({'output': output_type}, [mstype.int32], self.name)
self.axis = axis self.axis = axis
self.add_prim_attr('output_type', output_type) self.add_prim_attr('output_type', output_type)
@ -1547,7 +1551,7 @@ class UnsortedSegmentSum(PrimitiveWithInfer):
def __init__(self): def __init__(self):
"""Initialize UnsortedSegmentSum""" """Initialize UnsortedSegmentSum"""
self.init_prim_io_names(inputs=['x', 'segment_ids', 'num_segments'], outputs=['y']) self.init_prim_io_names(inputs=['x', 'segment_ids', 'num_segments'], outputs=['y'])
self.add_prim_attr("dynamic_shape_depends", [2,]) self.add_prim_attr("dynamic_shape_depends", [2])
def __infer__(self, x, segment_ids, num_segments): def __infer__(self, x, segment_ids, num_segments):
x_type = x['dtype'] x_type = x['dtype']
@ -1570,7 +1574,7 @@ class UnsortedSegmentSum(PrimitiveWithInfer):
num_segments_type = num_segments['dtype'] num_segments_type = num_segments['dtype']
validator.check_subclass("num_segments", num_segments_type, [mstype.tensor, mstype.number], self.name) validator.check_subclass("num_segments", num_segments_type, [mstype.tensor, mstype.number], self.name)
if isinstance(num_segments_type, type(mstype.tensor)): if isinstance(num_segments_type, type(mstype.tensor)):
validator.check_tensor_type_same({"num_segments": num_segments_type}, [mstype.int32], self.name) validator.check_tensor_dtype_valid("num_segments", num_segments_type, [mstype.int32], self.name)
shp = [-1] shp = [-1]
else: else:
validator.check_value_type('num_segments', num_segments_v, [int], self.name) validator.check_value_type('num_segments', num_segments_v, [int], self.name)
@ -1623,8 +1627,8 @@ class UnsortedSegmentMin(PrimitiveWithInfer):
x_shape = x['shape'] x_shape = x['shape']
segment_ids_shape = segment_ids['shape'] segment_ids_shape = segment_ids['shape']
valid_type = [mstype.float16, mstype.float32, mstype.int32] valid_type = [mstype.float16, mstype.float32, mstype.int32]
validator.check_tensor_type_same({"x": x['dtype']}, valid_type, self.name) validator.check_tensor_dtype_valid("x", x['dtype'], valid_type, self.name)
validator.check_tensor_type_same({"segment_ids": segment_ids['dtype']}, [mstype.int32], self.name) validator.check_tensor_dtype_valid("segment_ids", segment_ids['dtype'], [mstype.int32], self.name)
validator.check_equal_int(len(segment_ids_shape), 1, "rank of segment_ids_shape", self.name) validator.check_equal_int(len(segment_ids_shape), 1, "rank of segment_ids_shape", self.name)
validator.check(f'first shape of input_x', x_shape[0], validator.check(f'first shape of input_x', x_shape[0],
'length of segments_id', segment_ids_shape[0], Rel.EQ, self.name) 'length of segments_id', segment_ids_shape[0], Rel.EQ, self.name)
@ -1673,8 +1677,8 @@ class UnsortedSegmentMax(PrimitiveWithInfer):
x_shape = x['shape'] x_shape = x['shape']
segment_ids_shape = segment_ids['shape'] segment_ids_shape = segment_ids['shape']
valid_type = [mstype.float16, mstype.float32, mstype.int32] valid_type = [mstype.float16, mstype.float32, mstype.int32]
validator.check_tensor_type_same({"x": x['dtype']}, valid_type, self.name) validator.check_tensor_dtype_valid("x", x['dtype'], valid_type, self.name)
validator.check_tensor_type_same({"segment_ids": segment_ids['dtype']}, [mstype.int32], self.name) validator.check_tensors_dtypes_same_and_valid({"segment_ids": segment_ids['dtype']}, [mstype.int32], self.name)
validator.check_equal_int(len(segment_ids_shape), 1, "rank of segment_ids_shape", self.name) validator.check_equal_int(len(segment_ids_shape), 1, "rank of segment_ids_shape", self.name)
validator.check(f'first shape of input_x', x_shape[0], validator.check(f'first shape of input_x', x_shape[0],
'length of segments_id', segment_ids_shape[0], Rel.EQ, self.name) 'length of segments_id', segment_ids_shape[0], Rel.EQ, self.name)
@ -1726,8 +1730,8 @@ class UnsortedSegmentProd(PrimitiveWithInfer):
validator.check_subclass("input_x", x_type, mstype.tensor, self.name) validator.check_subclass("input_x", x_type, mstype.tensor, self.name)
validator.check_value_type("x_shape", x_shape, [list], self.name) validator.check_value_type("x_shape", x_shape, [list], self.name)
valid_type = [mstype.float16, mstype.float32, mstype.int32] valid_type = [mstype.float16, mstype.float32, mstype.int32]
validator.check_tensor_type_same({"x": x['dtype']}, valid_type, self.name) validator.check_tensor_dtype_valid("x", x['dtype'], valid_type, self.name)
validator.check_tensor_type_same({"segment_ids": segment_ids['dtype']}, [mstype.int32], self.name) validator.check_tensor_dtype_valid("segment_ids", segment_ids['dtype'], [mstype.int32], self.name)
validator.check_equal_int(len(segment_ids_shape), 1, "rank of segment_ids_shape", self.name) validator.check_equal_int(len(segment_ids_shape), 1, "rank of segment_ids_shape", self.name)
validator.check(f'first shape of input_x', x_shape[0], validator.check(f'first shape of input_x', x_shape[0],
'length of segments_id', segment_ids_shape[0], Rel.EQ, self.name) 'length of segments_id', segment_ids_shape[0], Rel.EQ, self.name)
@ -1833,7 +1837,7 @@ class ParallelConcat(PrimitiveWithInfer):
validator.check_int(len(x_shp), 1, Rel.GE, f'x_shp length', self.name) validator.check_int(len(x_shp), 1, Rel.GE, f'x_shp length', self.name)
args = {f"x_type[{i}]": elem for i, elem in enumerate(x_type)} args = {f"x_type[{i}]": elem for i, elem in enumerate(x_type)}
validator.check_tensor_type_same(args, mstype.number_type + (mstype.bool_,), self.name) validator.check_tensors_dtypes_same_and_valid(args, mstype.number_type + (mstype.bool_,), self.name)
first_elem = x_shp[0] first_elem = x_shp[0]
for i, elem in enumerate(x_shp[1:]): for i, elem in enumerate(x_shp[1:]):
@ -2070,7 +2074,7 @@ class ReverseV2(PrimitiveWithInfer):
return x_shape return x_shape
def infer_dtype(self, x_dtype): def infer_dtype(self, x_dtype):
validator.check_tensor_type_same({'x': x_dtype}, (mstype.bool_,) + mstype.number_type, self.name) validator.check_tensor_dtype_valid('x', x_dtype, (mstype.bool_,) + mstype.number_type, self.name)
return x_dtype return x_dtype
@ -2100,7 +2104,7 @@ class Rint(PrimitiveWithInfer):
return x_shape return x_shape
def infer_dtype(self, x_dtype): def infer_dtype(self, x_dtype):
validator.check_tensor_type_same({'x': x_dtype}, [mstype.float16, mstype.float32], self.name) validator.check_tensor_dtype_valid('x', x_dtype, [mstype.float16, mstype.float32], self.name)
return x_dtype return x_dtype
@ -2167,7 +2171,7 @@ class Select(PrimitiveWithInfer):
self.add_prim_attr('T', x_type) self.add_prim_attr('T', x_type)
validator.check_subclass("x_type", x_type, mstype.tensor, self.name) validator.check_subclass("x_type", x_type, mstype.tensor, self.name)
validator.check_subclass("y_type", y_type, mstype.tensor, self.name) validator.check_subclass("y_type", y_type, mstype.tensor, self.name)
validator.check_tensor_type_same({"cond": cond_type}, [mstype.bool_], self.name) validator.check_tensor_dtype_valid("cond", cond_type, [mstype.bool_], self.name)
if x_type != y_type: if x_type != y_type:
raise TypeError('\'%s\' the x_type %s must be the same as y_type %s.' % (self.name, x_type, y_type)) raise TypeError('\'%s\' the x_type %s must be the same as y_type %s.' % (self.name, x_type, y_type))
return x_type return x_type
@ -2542,7 +2546,7 @@ class Eye(PrimitiveWithInfer):
validator.check_positive_int(n, "n", self.name) validator.check_positive_int(n, "n", self.name)
validator.check_positive_int(m, "m", self.name) validator.check_positive_int(m, "m", self.name)
args = {"dtype": t} args = {"dtype": t}
validator.check_type_same(args, mstype.number_type + (mstype.bool_,), self.name) validator.check_types_same_and_valid(args, mstype.number_type + (mstype.bool_,), self.name)
np_type = mstype.dtype_to_nptype(t) np_type = mstype.dtype_to_nptype(t)
ret = np.eye(n, m, dtype=np_type) ret = np.eye(n, m, dtype=np_type)
return Tensor(ret) return Tensor(ret)
@ -2581,7 +2585,7 @@ class ScatterNd(PrimitiveWithInfer):
def __infer__(self, indices, update, shape): def __infer__(self, indices, update, shape):
shp = shape['value'] shp = shape['value']
validator.check_subclass("update_dtype", update['dtype'], mstype.tensor, self.name) validator.check_subclass("update_dtype", update['dtype'], mstype.tensor, self.name)
validator.check_tensor_type_same({"indices": indices['dtype']}, [mstype.int32], self.name) validator.check_tensor_dtype_valid("indices", indices['dtype'], [mstype.int32], self.name)
validator.check_value_type("shape", shp, [tuple], self.name) validator.check_value_type("shape", shp, [tuple], self.name)
for i, x in enumerate(shp): for i, x in enumerate(shp):
validator.check_positive_int(x, f'shape[{i}]', self.name) validator.check_positive_int(x, f'shape[{i}]', self.name)
@ -2632,14 +2636,13 @@ class ResizeNearestNeighbor(PrimitiveWithInfer):
validator.check_non_negative_int(value, f'{i}th value of size', self.name) validator.check_non_negative_int(value, f'{i}th value of size', self.name)
self.init_prim_io_names(inputs=['image_in'], outputs=['image_out']) self.init_prim_io_names(inputs=['image_in'], outputs=['image_out'])
def infer_shape(self, x): def infer_shape(self, x_shape):
validator.check('the dimension of input_x', len(x), '', 4, Rel.EQ, self.name) validator.check('the dimension of input_x', len(x_shape), '', 4, Rel.EQ, self.name)
return tuple(x)[:-2] + tuple(self.size) return tuple(x_shape)[:-2] + tuple(self.size)
def infer_dtype(self, x): def infer_dtype(self, x_dtype):
validator.check_subclass("x", x, mstype.tensor, self.name) validator.check_tensor_dtype_valid("x", x_dtype, mstype.number_type, self.name)
validator.check_tensor_type_same({"x": x}, mstype.number_type, self.name) return x_dtype
return x
class GatherNd(PrimitiveWithInfer): class GatherNd(PrimitiveWithInfer):
@ -2674,8 +2677,7 @@ class GatherNd(PrimitiveWithInfer):
return indices_shape[:-1] + x_shape[indices_shape[-1]:] return indices_shape[:-1] + x_shape[indices_shape[-1]:]
def infer_dtype(self, x_dtype, indices_dtype): def infer_dtype(self, x_dtype, indices_dtype):
validator.check_subclass("x_dtype", x_dtype, mstype.tensor, self.name) validator.check_tensor_dtype_valid("indices", indices_dtype, mstype.int_type, self.name)
validator.check_tensor_type_same({"indices": indices_dtype}, mstype.int_type, self.name)
return x_dtype return x_dtype
@ -2715,9 +2717,9 @@ class TensorScatterUpdate(PrimitiveWithInfer):
return x_shape return x_shape
def infer_dtype(self, x_dtype, indices_dtype, value_dtype): def infer_dtype(self, x_dtype, indices_dtype, value_dtype):
validator.check_tensor_type_same({'indices': indices_dtype}, [mstype.int32], self.name) validator.check_tensor_dtype_valid('indices', indices_dtype, [mstype.int32], self.name)
args = {"x": x_dtype, "value": value_dtype} args = {"x": x_dtype, "value": value_dtype}
validator.check_tensor_type_same(args, (mstype.bool_,) + mstype.number_type, self.name) validator.check_tensors_dtypes_same_and_valid(args, (mstype.bool_,) + mstype.number_type, self.name)
return x_dtype return x_dtype
@ -2763,9 +2765,9 @@ class ScatterUpdate(_ScatterOp):
self.init_prim_io_names(inputs=['x', 'indices', 'updates'], outputs=['y']) self.init_prim_io_names(inputs=['x', 'indices', 'updates'], outputs=['y'])
def infer_dtype(self, x_dtype, indices_dtype, value_dtype): def infer_dtype(self, x_dtype, indices_dtype, value_dtype):
validator.check_tensor_type_same({'indices': indices_dtype}, [mstype.int32], self.name) validator.check_tensor_dtype_valid('indices', indices_dtype, [mstype.int32], self.name)
args = {"x": x_dtype, "value": value_dtype} args = {"x": x_dtype, "value": value_dtype}
validator.check_tensor_type_same(args, (mstype.bool_,) + mstype.number_type, self.name) validator.check_tensors_dtypes_same_and_valid(args, (mstype.bool_,) + mstype.number_type, self.name)
return x_dtype return x_dtype
@ -2802,7 +2804,6 @@ class ScatterNdUpdate(_ScatterNdOp):
[0.4 2.2 -3.2]] [0.4 2.2 -3.2]]
""" """
@prim_attr_register @prim_attr_register
def __init__(self, use_locking=True): def __init__(self, use_locking=True):
"""Initialize ScatterNdUpdate""" """Initialize ScatterNdUpdate"""
@ -2810,9 +2811,9 @@ class ScatterNdUpdate(_ScatterNdOp):
self.init_prim_io_names(inputs=['x', 'indices', 'value'], outputs=['y']) self.init_prim_io_names(inputs=['x', 'indices', 'value'], outputs=['y'])
def infer_dtype(self, x_dtype, indices_dtype, value_dtype): def infer_dtype(self, x_dtype, indices_dtype, value_dtype):
validator.check_tensor_type_same({'indices': indices_dtype}, [mstype.int32], self.name) validator.check_tensor_dtype_valid('indices', indices_dtype, [mstype.int32], self.name)
args = {"x": x_dtype, "value": value_dtype} args = {"x": x_dtype, "value": value_dtype}
validator.check_tensor_type_same(args, (mstype.bool_,) + mstype.number_type, self.name) validator.check_tensors_dtypes_same_and_valid(args, (mstype.bool_,) + mstype.number_type, self.name)
return x_dtype return x_dtype
@ -3131,9 +3132,9 @@ class ScatterNonAliasingAdd(_ScatterNdOp):
self.init_prim_io_names(inputs=['x', 'indices', 'updates'], outputs=['y']) self.init_prim_io_names(inputs=['x', 'indices', 'updates'], outputs=['y'])
def infer_dtype(self, x_dtype, indices_dtype, updates_dtype): def infer_dtype(self, x_dtype, indices_dtype, updates_dtype):
validator.check_tensor_type_same({'indices': indices_dtype}, [mstype.int32], self.name) validator.check_tensor_dtype_valid('indices', indices_dtype, [mstype.int32], self.name)
args = {"x": x_dtype, "updates": updates_dtype} args = {"x": x_dtype, "updates": updates_dtype}
validator.check_tensor_type_same(args, [mstype.float16, mstype.float32, mstype.int32], self.name) validator.check_tensors_dtypes_same_and_valid(args, [mstype.float16, mstype.float32, mstype.int32], self.name)
return x_dtype return x_dtype
@ -3304,7 +3305,7 @@ class SpaceToBatch(PrimitiveWithInfer):
self.paddings = paddings self.paddings = paddings
def infer_dtype(self, x_dtype): def infer_dtype(self, x_dtype):
validator.check_tensor_type_same({'input_x': x_dtype}, mstype.number_type, self.name) validator.check_tensor_dtype_valid('input_x', x_dtype, mstype.number_type, self.name)
return x_dtype return x_dtype
def infer_shape(self, x_shape): def infer_shape(self, x_shape):
@ -3376,7 +3377,7 @@ class BatchToSpace(PrimitiveWithInfer):
self.crops = crops self.crops = crops
def infer_dtype(self, x_dtype): def infer_dtype(self, x_dtype):
validator.check_tensor_type_same({'input_x': x_dtype}, mstype.number_type, self.name) validator.check_tensor_dtype_valid('input_x', x_dtype, mstype.number_type, self.name)
return x_dtype return x_dtype
def infer_shape(self, x_shape): def infer_shape(self, x_shape):
@ -3465,7 +3466,7 @@ class SpaceToBatchND(PrimitiveWithInfer):
self.add_prim_attr("paddings", paddings_append) self.add_prim_attr("paddings", paddings_append)
def infer_dtype(self, x_dtype): def infer_dtype(self, x_dtype):
validator.check_tensor_type_same({'input_x': x_dtype}, mstype.number_type, self.name) validator.check_tensor_dtype_valid('input_x', x_dtype, mstype.number_type, self.name)
return x_dtype return x_dtype
def infer_shape(self, x_shape): def infer_shape(self, x_shape):
@ -3558,7 +3559,7 @@ class BatchToSpaceND(PrimitiveWithInfer):
self.add_prim_attr("crops", crops_append) self.add_prim_attr("crops", crops_append)
def infer_dtype(self, x_dtype): def infer_dtype(self, x_dtype):
validator.check_tensor_type_same({'input_x': x_dtype}, mstype.number_type, self.name) validator.check_tensor_dtype_valid('input_x', x_dtype, mstype.number_type, self.name)
return x_dtype return x_dtype
def infer_shape(self, x_shape): def infer_shape(self, x_shape):
@ -3721,7 +3722,6 @@ class Meshgrid(PrimitiveWithInfer):
out_shape = tuple(tuple(shape_0) for _ in range(n)) out_shape = tuple(tuple(shape_0) for _ in range(n))
return out_shape return out_shape
def infer_dtype(self, x_type): def infer_dtype(self, x_type):
validator.check_subclass("input_x[0]", x_type[0], mstype.tensor, self.name) validator.check_subclass("input_x[0]", x_type[0], mstype.tensor, self.name)
n = len(x_type) n = len(x_type)
@ -3729,6 +3729,7 @@ class Meshgrid(PrimitiveWithInfer):
validator.check('x_type[%d]' % i, x_type[i], 'base', x_type[0], Rel.EQ, self.name, TypeError) validator.check('x_type[%d]' % i, x_type[i], 'base', x_type[0], Rel.EQ, self.name, TypeError)
return x_type return x_type
class InplaceUpdate(PrimitiveWithInfer): class InplaceUpdate(PrimitiveWithInfer):
r""" r"""
Updates specified rows with values in `v`. Updates specified rows with values in `v`.
@ -3771,7 +3772,7 @@ class InplaceUpdate(PrimitiveWithInfer):
def infer_dtype(self, x_dtype, v_dtype): def infer_dtype(self, x_dtype, v_dtype):
args = {'x': x_dtype, 'v': v_dtype} args = {'x': x_dtype, 'v': v_dtype}
valid_type = [mstype.int32, mstype.float16, mstype.float32] valid_type = [mstype.int32, mstype.float16, mstype.float32]
validator.check_tensor_type_same(args, valid_type, self.name) validator.check_tensors_dtypes_same_and_valid(args, valid_type, self.name)
return x_dtype return x_dtype
def infer_shape(self, x_shape, v_shape): def infer_shape(self, x_shape, v_shape):
@ -3831,8 +3832,8 @@ class ReverseSequence(PrimitiveWithInfer):
return x return x
def infer_dtype(self, x, seq_lengths): def infer_dtype(self, x, seq_lengths):
validator.check_tensor_type_same({"x_dtype": x}, mstype.number_type + (mstype.bool_,), self.name) validator.check_tensor_dtype_valid("x_dtype", x, mstype.number_type + (mstype.bool_,), self.name)
validator.check_tensor_type_same({"seq_lengths_dtype": seq_lengths}, [mstype.int32, mstype.int64], self.name) validator.check_tensor_dtype_valid("seq_lengths_dtype", seq_lengths, [mstype.int32, mstype.int64], self.name)
return x return x
@ -3899,9 +3900,9 @@ class EditDistance(PrimitiveWithInfer):
validator.check_const_input('truth_shape', truth_shape['value'], self.name) validator.check_const_input('truth_shape', truth_shape['value'], self.name)
args_int = {"hypothesis_indices": h_indices['dtype'], "hypothesis_shape": h_shape['dtype'], args_int = {"hypothesis_indices": h_indices['dtype'], "hypothesis_shape": h_shape['dtype'],
"truth_indices": truth_indices['dtype'], "truth_shape": truth_shape['dtype']} "truth_indices": truth_indices['dtype'], "truth_shape": truth_shape['dtype']}
validator.check_tensor_type_same(args_int, [mstype.int64], self.name) validator.check_tensors_dtypes_same_and_valid(args_int, [mstype.int64], self.name)
args = {"hypothesis_values": h_values['dtype'], "truth_values": truth_values['dtype']} args = {"hypothesis_values": h_values['dtype'], "truth_values": truth_values['dtype']}
validator.check_tensor_type_same(args, mstype.number_type, self.name) validator.check_tensors_dtypes_same_and_valid(args, mstype.number_type, self.name)
hypothesis_indices_shp, truth_indices_shp = h_indices['shape'], truth_indices['shape'] hypothesis_indices_shp, truth_indices_shp = h_indices['shape'], truth_indices['shape']
validator.check("hypothesis_indices rank", len(hypothesis_indices_shp), "expected", 2, Rel.EQ, self.name) validator.check("hypothesis_indices rank", len(hypothesis_indices_shp), "expected", 2, Rel.EQ, self.name)
@ -3941,6 +3942,7 @@ class TransShape(PrimitiveWithInfer):
Outputs: Outputs:
Tensor, a tensor whose data type is same as 'input_x', and the shape is the same as the `out_shape`. Tensor, a tensor whose data type is same as 'input_x', and the shape is the same as the `out_shape`.
""" """
@prim_attr_register @prim_attr_register
def __init__(self): def __init__(self):
self.__setattr_flag__ = True self.__setattr_flag__ = True
@ -3948,7 +3950,7 @@ class TransShape(PrimitiveWithInfer):
def __infer__(self, x, shape): def __infer__(self, x, shape):
shp = shape['value'] shp = shape['value']
dtype = x['dtype'] dtype = x['dtype']
validator.check_tensor_type_same({'x': dtype}, mstype.number_type + (mstype.bool_,), self.name) validator.check_tensor_dtype_valid('x', dtype, mstype.number_type + (mstype.bool_,), self.name)
self.add_prim_attr('out_shape', tuple(shp)) self.add_prim_attr('out_shape', tuple(shp))
return {'shape': shp, return {'shape': shp,
'dtype': dtype, 'dtype': dtype,
@ -3989,7 +3991,7 @@ class Sort(PrimitiveWithInfer):
return x_shape, x_shape return x_shape, x_shape
def infer_dtype(self, x_dtype): def infer_dtype(self, x_dtype):
validator.check_tensor_type_same({"x_dtype": x_dtype}, [mstype.float32, mstype.float16], self.name) validator.check_tensor_dtype_valid("x_dtype", x_dtype, [mstype.float32, mstype.float16], self.name)
return x_dtype, mstype.tensor_type(mstype.int32) return x_dtype, mstype.tensor_type(mstype.int32)
@ -4019,6 +4021,7 @@ class EmbeddingLookup(PrimitiveWithInfer):
>>> out = P.EmbeddingLookup()(input_params, input_indices, offset) >>> out = P.EmbeddingLookup()(input_params, input_indices, offset)
[[[10, 11], [0 ,0]], [[0, 0], [10, 11]]] [[[10, 11], [0 ,0]], [[0, 0], [10, 11]]]
""" """
@prim_attr_register @prim_attr_register
def __init__(self): def __init__(self):
"""Initialize index_select""" """Initialize index_select"""
@ -4028,7 +4031,7 @@ class EmbeddingLookup(PrimitiveWithInfer):
def __infer__(self, params, indices, offset): def __infer__(self, params, indices, offset):
validator.check_subclass("params", params['dtype'], mstype.tensor, self.name) validator.check_subclass("params", params['dtype'], mstype.tensor, self.name)
validator.check_tensor_type_same({"indices": indices['dtype']}, mstype.int_type, self.name) validator.check_tensor_dtype_valid("indices", indices['dtype'], mstype.int_type, self.name)
validator.check_subclass("offset", offset['dtype'], mstype.int_, self.name) validator.check_subclass("offset", offset['dtype'], mstype.int_, self.name)
params_shp = params['shape'] params_shp = params['shape']
if len(params_shp) != 2: if len(params_shp) != 2:
@ -4060,6 +4063,7 @@ class GatherD(PrimitiveWithInfer):
>>> out = P.GatherD()(x, dim, index) >>> out = P.GatherD()(x, dim, index)
[[1, 1], [4, 3]] [[1, 1], [4, 3]]
""" """
@prim_attr_register @prim_attr_register
def __init__(self): def __init__(self):
"""Initialize GatherD""" """Initialize GatherD"""
@ -4067,7 +4071,7 @@ class GatherD(PrimitiveWithInfer):
def __infer__(self, x, dim, index): def __infer__(self, x, dim, index):
validator.check_subclass("x", x['dtype'], mstype.tensor, self.name) validator.check_subclass("x", x['dtype'], mstype.tensor, self.name)
validator.check_tensor_type_same({"index": index['dtype']}, [mstype.int32, mstype.int64], self.name) validator.check_tensor_dtype_valid("index", index['dtype'], [mstype.int32, mstype.int64], self.name)
validator.check_subclass("dim", dim['dtype'], mstype.int32, self.name) validator.check_subclass("dim", dim['dtype'], mstype.int32, self.name)
x_shp = x['shape'] x_shp = x['shape']
idx_shp = index['shape'] idx_shp = index['shape']
@ -4103,6 +4107,7 @@ class Identity(PrimitiveWithInfer):
>>> y = P.Identity()(x) >>> y = P.Identity()(x)
[1, 2, 3, 4] [1, 2, 3, 4]
""" """
@prim_attr_register @prim_attr_register
def __init__(self): def __init__(self):
"""Initialize identity""" """Initialize identity"""

View File

@ -105,7 +105,7 @@ class AllReduce(PrimitiveWithInfer):
return x_shape return x_shape
def infer_dtype(self, x_dtype): def infer_dtype(self, x_dtype):
validator.check_tensor_type_same({'x': x_dtype}, target_dtypes, self.name) validator.check_tensor_dtype_valid('x', x_dtype, target_dtypes, self.name)
return x_dtype return x_dtype
@ -167,7 +167,7 @@ class AllGather(PrimitiveWithInfer):
return x_shape return x_shape
def infer_dtype(self, x_dtype): def infer_dtype(self, x_dtype):
validator.check_tensor_type_same({'x': x_dtype}, target_dtypes, self.name) validator.check_tensor_dtype_valid('x', x_dtype, target_dtypes, self.name)
return x_dtype return x_dtype
def __call__(self, tensor): def __call__(self, tensor):
@ -217,7 +217,7 @@ class _HostAllGather(PrimitiveWithInfer):
return x_shape return x_shape
def infer_dtype(self, x_dtype): def infer_dtype(self, x_dtype):
validator.check_tensor_type_same({'x': x_dtype}, target_dtypes, self.name) validator.check_tensor_dtype_valid('x', x_dtype, target_dtypes, self.name)
return x_dtype return x_dtype
def __call__(self, tensor): def __call__(self, tensor):
@ -279,7 +279,7 @@ class ReduceScatter(PrimitiveWithInfer):
return x_shape return x_shape
def infer_dtype(self, x_dtype): def infer_dtype(self, x_dtype):
validator.check_tensor_type_same({'x': x_dtype}, target_dtypes, self.name) validator.check_tensor_dtype_valid('x', x_dtype, target_dtypes, self.name)
return x_dtype return x_dtype
def __call__(self, tensor): def __call__(self, tensor):
@ -328,7 +328,7 @@ class _HostReduceScatter(PrimitiveWithInfer):
return x_shape return x_shape
def infer_dtype(self, x_dtype): def infer_dtype(self, x_dtype):
validator.check_tensor_type_same({'x': x_dtype}, target_dtypes, self.name) validator.check_tensor_dtype_valid('x', x_dtype, target_dtypes, self.name)
return x_dtype return x_dtype
def __call__(self, tensor): def __call__(self, tensor):
@ -390,7 +390,7 @@ class Broadcast(PrimitiveWithInfer):
if not isinstance(x_dtype, tuple): if not isinstance(x_dtype, tuple):
raise TypeError(f"{self.name}'s input should be a tuple!") raise TypeError(f"{self.name}'s input should be a tuple!")
for _ele in x_dtype: for _ele in x_dtype:
validator.check_tensor_type_same({'x': _ele}, target_dtypes, self.name) validator.check_tensor_dtype_valid('x', _ele, target_dtypes, self.name)
return x_dtype return x_dtype
@ -432,7 +432,7 @@ class _AlltoAll(PrimitiveWithInfer):
return x_shape return x_shape
def infer_dtype(self, x_dtype): def infer_dtype(self, x_dtype):
validator.check_tensor_type_same({'x': x_dtype}, target_dtypes, self.name) validator.check_tensor_dtype_valid('x', x_dtype, target_dtypes, self.name)
return x_dtype return x_dtype
def __call__(self, tensor): def __call__(self, tensor):

View File

@ -132,8 +132,7 @@ class GeSwitch(PrimitiveWithInfer):
def infer_dtype(self, data_type, pred_type): def infer_dtype(self, data_type, pred_type):
validator.check_subclass( validator.check_subclass(
"data", data_type, (mstype.tensor,) + mstype.number_type, self.name) "data", data_type, (mstype.tensor,) + mstype.number_type, self.name)
validator.check_tensor_type_same( validator.check_tensor_dtype_valid("pred", pred_type, [mstype.bool_], self.name)
{"pred": pred_type}, [mstype.bool_], self.name)
return (data_type, data_type) return (data_type, data_type)
@ -171,5 +170,5 @@ class Merge(PrimitiveWithInfer):
for i, item in enumerate(inputs): for i, item in enumerate(inputs):
args['inputs[%d]' % i] = item args['inputs[%d]' % i] = item
validator.check_scalar_or_tensor_type_same(args, (mstype.bool_,) + mstype.number_type, self.name) validator.check_scalar_or_tensor_types_same(args, (mstype.bool_,) + mstype.number_type, self.name)
return (inputs[0], mstype.int32) return (inputs[0], mstype.int32)

View File

@ -380,7 +380,7 @@ class Assert(PrimitiveWithInfer):
return [1] return [1]
def infer_dtype(self, condition, inputs): def infer_dtype(self, condition, inputs):
validator.check_scalar_or_tensor_type_same({"condition": condition}, [mstype.bool_], self.name) validator.check_scalar_or_tensor_types_same({"condition": condition}, [mstype.bool_], self.name)
for dtype in inputs: for dtype in inputs:
validator.check_subclass("input", dtype, [mstype.tensor], self.name) validator.check_subclass("input", dtype, [mstype.tensor], self.name)
return mstype.int32 return mstype.int32

View File

@ -104,11 +104,11 @@ class CropAndResize(PrimitiveWithInfer):
box_index_dtype = box_index['dtype'] box_index_dtype = box_index['dtype']
crop_size_dtype = crop_size['dtype'] crop_size_dtype = crop_size['dtype']
# check dtype # check dtype
validator.check_tensor_type_same({"x": x_dtype}, validator.check_tensor_dtype_valid("x", x_dtype,
[mstype.int8, mstype.int16, mstype.int32, mstype.int64, mstype.float16, [mstype.int8, mstype.int16, mstype.int32, mstype.int64, mstype.float16,
mstype.float32, mstype.float64, mstype.uint8, mstype.uint16], self.name) mstype.float32, mstype.float64, mstype.uint8, mstype.uint16], self.name)
validator.check_tensor_type_same({"boxes": boxes_dtype}, [mstype.float32], self.name) validator.check_tensor_dtype_valid("boxes", boxes_dtype, [mstype.float32], self.name)
validator.check_tensor_type_same({"box_index": box_index_dtype}, [mstype.int32], self.name) validator.check_tensor_dtype_valid("box_index", box_index_dtype, [mstype.int32], self.name)
validator.check_value_type("crop_size", crop_size_value, [tuple], self.name) validator.check_value_type("crop_size", crop_size_value, [tuple], self.name)
# check input shape rank # check input shape rank
validator.check("x rank", len(x_shape), "expected", 4, Rel.EQ, self.name) validator.check("x rank", len(x_shape), "expected", 4, Rel.EQ, self.name)

View File

@ -16,6 +16,8 @@
"""Operators for math.""" """Operators for math."""
import copy import copy
from functools import partial
import numpy as np import numpy as np
from ... import context from ... import context
from .. import signature as sig from .. import signature as sig
@ -85,7 +87,7 @@ class _MathBinaryOp(_BinaryOp):
@staticmethod @staticmethod
def do_infer_dtype(x_dtype, y_dtype, valid_dtype=mstype.number_type, prim_name=None): def do_infer_dtype(x_dtype, y_dtype, valid_dtype=mstype.number_type, prim_name=None):
args_type = {"x": x_dtype, "y": y_dtype} args_type = {"x": x_dtype, "y": y_dtype}
validator.check_tensor_type_same(args_type, valid_dtype, prim_name) validator.check_tensors_dtypes_same_and_valid(args_type, valid_dtype, prim_name)
return x_dtype return x_dtype
def infer_dtype(self, x_dtype, y_dtype): def infer_dtype(self, x_dtype, y_dtype):
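For the binary math ops the renamed helper also makes the contract explicit: check_tensors_dtypes_same_and_valid takes a dict of argument names to tensor dtypes and fails unless they all share one element type and that type is allowed, while check_tensor_dtype_valid remains the single-tensor form. A small sketch of the two call shapes with illustrative values, under the same import assumptions as above:

from mindspore.common import dtype as mstype
from mindspore._checkparam import Validator as validator

prim_name = "ExampleBinaryOp"                     # illustrative
x_dtype = mstype.tensor_type(mstype.float32)
y_dtype = mstype.tensor_type(mstype.float32)

# single tensor: argument name, dtype, allowed element types, op name
validator.check_tensor_dtype_valid("x", x_dtype, mstype.number_type, prim_name)

# several tensors that must agree: dict keyed by argument name
validator.check_tensors_dtypes_same_and_valid({"x": x_dtype, "y": y_dtype},
                                               mstype.number_type, prim_name)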
@ -105,8 +107,8 @@ class _BitwiseBinaryOp(_MathBinaryOp):
@staticmethod @staticmethod
def _check_bitwise_op_input_type(x1_type, x2_type, prim): def _check_bitwise_op_input_type(x1_type, x2_type, prim):
args = {'x1': x1_type, 'x2': x2_type} args = {'x1': x1_type, 'x2': x2_type}
valid_types = mstype.int_type + mstype.uint_type valid_dtypes = mstype.int_type + mstype.uint_type
validator.check_tensor_type_same(args, valid_types, prim) validator.check_tensors_dtypes_same_and_valid(args, valid_dtypes, prim)
return x1_type return x1_type
def infer_dtype(self, x1_type, x2_type): def infer_dtype(self, x1_type, x2_type):
@ -198,7 +200,7 @@ class AssignAdd(PrimitiveWithInfer):
def infer_dtype(self, variable, value): def infer_dtype(self, variable, value):
args = {"variable": variable, "value": value} args = {"variable": variable, "value": value}
validator.check_scalar_or_tensor_type_same(args, mstype.number_type, self.name) validator.check_scalar_or_tensor_types_same(args, mstype.number_type, self.name)
return value return value
@ -248,7 +250,7 @@ class AssignSub(PrimitiveWithInfer):
def infer_dtype(self, variable, value): def infer_dtype(self, variable, value):
args = {"variable": variable, "value": value} args = {"variable": variable, "value": value}
validator.check_scalar_or_tensor_type_same(args, mstype.number_type, self.name) validator.check_scalar_or_tensor_types_same(args, mstype.number_type, self.name)
return value return value
@ -283,7 +285,7 @@ class _Reduce(PrimitiveWithInfer):
axis_v = axis['value'] axis_v = axis['value']
input_shp = input_x['shape'] input_shp = input_x['shape']
args = {'input_x': input_x['dtype']} args = {'input_x': input_x['dtype']}
validator.check_tensor_type_same(args, valid_dtype, self.name) validator.check_tensors_dtypes_same_and_valid(args, valid_dtype, self.name)
if axis_v is None: if axis_v is None:
raise ValueError(f"For {self.name}, axis must be const.") raise ValueError(f"For {self.name}, axis must be const.")
@ -504,6 +506,7 @@ class ReduceMax(_Reduce):
def __infer__(self, input_x, axis): def __infer__(self, input_x, axis):
return self.do_infer(input_x, axis, mstype.number_type + (mstype.bool_,)) return self.do_infer(input_x, axis, mstype.number_type + (mstype.bool_,))
class ReduceMin(_Reduce): class ReduceMin(_Reduce):
""" """
Reduce a dimension of a tensor by the minimum value in the dimension. Reduce a dimension of a tensor by the minimum value in the dimension.
@ -612,7 +615,7 @@ class CumProd(PrimitiveWithInfer):
def infer_dtype(self, x_type, axis_type): def infer_dtype(self, x_type, axis_type):
cls_name = self.name cls_name = self.name
validator.check_tensor_type_same({'x': x_type}, mstype.number_type, cls_name) validator.check_tensor_dtype_valid('x', x_type, mstype.number_type, cls_name)
validator.check_subclass("axis", axis_type, mstype.int_, cls_name) validator.check_subclass("axis", axis_type, mstype.int_, cls_name)
return x_type return x_type
@ -689,7 +692,7 @@ class MatMul(PrimitiveWithInfer):
def infer_dtype(self, x1, x2): def infer_dtype(self, x1, x2):
args = {"x1": x1, "x2": x2} args = {"x1": x1, "x2": x2}
validator.check_tensor_type_same(args, mstype.float_type + mstype.int_type, self.name) validator.check_tensors_dtypes_same_and_valid(args, mstype.float_type + mstype.int_type, self.name)
if x1.element_type() == mstype.int8: if x1.element_type() == mstype.int8:
return mstype.tensor_type(mstype.int32) return mstype.tensor_type(mstype.int32)
return x1 return x1
@ -801,10 +804,10 @@ class TensorDot(PrimitiveWithInfer):
self.axes = axes self.axes = axes
validator.check_value_type('axes', axes, [int, tuple, list], self.name) validator.check_value_type('axes', axes, [int, tuple, list], self.name)
if not isinstance(self.axes, int): if not isinstance(self.axes, int):
self.axes = list(self.axes) # to avoid immutability issues self.axes = list(self.axes) # to avoid immutability issues
if len(self.axes) != 2: if len(self.axes) != 2:
raise ValueError("Require two axes inputs, given less") raise ValueError("Require two axes inputs, given less")
self.int_to_tuple_conv() # convert before length checks self.int_to_tuple_conv() # convert before length checks
if len(self.axes[0]) != len(self.axes[1]): if len(self.axes[0]) != len(self.axes[1]):
raise ValueError("Axes have to be the same size/length") raise ValueError("Axes have to be the same size/length")
if len(self.axes[0]) != len(set(self.axes[0])) or len(self.axes[1]) != len(set(self.axes[1])): if len(self.axes[0]) != len(set(self.axes[0])) or len(self.axes[1]) != len(set(self.axes[1])):
@ -825,7 +828,7 @@ class TensorDot(PrimitiveWithInfer):
if isinstance(self.axes, int): if isinstance(self.axes, int):
if self.axes <= 0: if self.axes <= 0:
# outer product, no input validation required # outer product, no input validation required
self.axes = ([], []) # no axes selected for either self.axes = ([], []) # no axes selected for either
return return
if self.axes > len(x1_shape) or self.axes > len(x2_shape): if self.axes > len(x1_shape) or self.axes > len(x2_shape):
raise ValueError( raise ValueError(
@ -877,8 +880,8 @@ class TensorDot(PrimitiveWithInfer):
def infer_dtype(self, x1, x2): def infer_dtype(self, x1, x2):
args = {"x1": x1, "x2": x2} args = {"x1": x1, "x2": x2}
valid_types = [mstype.float16, mstype.float32] valid_dtypes = [mstype.float16, mstype.float32]
validator.check_tensor_type_same(args, valid_types, self.name) validator.check_tensors_dtypes_same_and_valid(args, valid_dtypes, self.name)
return x1 return x1
@ -922,8 +925,8 @@ class CumSum(PrimitiveWithInfer):
if axis['value'] is None: if axis['value'] is None:
raise ValueError(f"For {self.name}, axis must be const.") raise ValueError(f"For {self.name}, axis must be const.")
validator.check_value_type('axis', axis['value'], [int], cls_name) validator.check_value_type('axis', axis['value'], [int], cls_name)
valid_types = [mstype.uint8, mstype.int8, mstype.int32, mstype.float16, mstype.float32] valid_dtypes = [mstype.uint8, mstype.int8, mstype.int32, mstype.float16, mstype.float32]
validator.check_tensor_type_same({'x': x['dtype']}, valid_types, cls_name) validator.check_tensor_dtype_valid('x', x['dtype'], valid_dtypes, cls_name)
return {'shape': x_shp, return {'shape': x_shp,
'dtype': x['dtype'], 'dtype': x['dtype'],
'value': None} 'value': None}
@ -989,7 +992,7 @@ class AddN(PrimitiveWithInfer):
if dtype == mstype.undetermined: if dtype == mstype.undetermined:
contains_undetermined = True contains_undetermined = True
if not contains_undetermined: if not contains_undetermined:
validator.check_tensor_type_same(args, mstype.number_type + (mstype.bool_,), cls_name) validator.check_tensors_dtypes_same_and_valid(args, mstype.number_type + (mstype.bool_,), cls_name)
return inputs[0] return inputs[0]
def infer_value(self, inputs): def infer_value(self, inputs):
@ -1068,7 +1071,7 @@ class AccumulateNV2(PrimitiveWithInfer):
args = {} args = {}
for i, dtype in enumerate(inputs): for i, dtype in enumerate(inputs):
args[f"inputs[{i}]"] = dtype args[f"inputs[{i}]"] = dtype
validator.check_tensor_type_same(args, mstype.number_type + (mstype.bool_,), cls_name) validator.check_tensors_dtypes_same_and_valid(args, mstype.number_type + (mstype.bool_,), cls_name)
return inputs[0] return inputs[0]
@ -1094,12 +1097,12 @@ class Neg(PrimitiveWithInfer):
"""Initialize Neg""" """Initialize Neg"""
self.init_prim_io_names(inputs=['x'], outputs=['y']) self.init_prim_io_names(inputs=['x'], outputs=['y'])
def infer_shape(self, input_x): def infer_shape(self, x_shape):
return input_x return x_shape
def infer_dtype(self, input_x): def infer_dtype(self, x_dtype):
validator.check_tensor_type_same({"input_x": input_x}, mstype.number_type, self.name) validator.check_tensor_dtype_valid("x", x_dtype, mstype.number_type, self.name)
return input_x return x_dtype
def infer_value(self, input_x): def infer_value(self, input_x):
if input_x is not None: if input_x is not None:
@ -1151,7 +1154,7 @@ class InplaceAdd(PrimitiveWithInfer):
def infer_dtype(self, x_dtype, v_dtype): def infer_dtype(self, x_dtype, v_dtype):
args = {'x': x_dtype, 'v': v_dtype} args = {'x': x_dtype, 'v': v_dtype}
valid_type = [mstype.int32, mstype.float16, mstype.float32] valid_type = [mstype.int32, mstype.float16, mstype.float32]
validator.check_tensor_type_same(args, valid_type, self.name) validator.check_tensors_dtypes_same_and_valid(args, valid_type, self.name)
return x_dtype return x_dtype
def infer_shape(self, x_shape, v_shape): def infer_shape(self, x_shape, v_shape):
@ -1209,7 +1212,7 @@ class InplaceSub(PrimitiveWithInfer):
def infer_dtype(self, x_dtype, v_dtype): def infer_dtype(self, x_dtype, v_dtype):
args = {'x': x_dtype, 'v': v_dtype} args = {'x': x_dtype, 'v': v_dtype}
valid_type = [mstype.int32, mstype.float16, mstype.float32] valid_type = [mstype.int32, mstype.float16, mstype.float32]
validator.check_tensor_type_same(args, valid_type, self.name) validator.check_tensors_dtypes_same_and_valid(args, valid_type, self.name)
return x_dtype return x_dtype
def infer_shape(self, x_shape, v_shape): def infer_shape(self, x_shape, v_shape):
@ -1363,9 +1366,9 @@ class Square(PrimitiveWithInfer):
def infer_shape(self, x_shape): def infer_shape(self, x_shape):
return x_shape return x_shape
def infer_dtype(self, x_type): def infer_dtype(self, x_dtype):
validator.check_tensor_type_same({"x": x_type}, mstype.number_type, self.name) validator.check_tensor_dtype_valid("x", x_dtype, mstype.number_type, self.name)
return x_type return x_dtype
def infer_value(self, x): def infer_value(self, x):
if x is not None: if x is not None:
@ -1401,9 +1404,9 @@ class Rsqrt(PrimitiveWithInfer):
def infer_shape(self, x_shape): def infer_shape(self, x_shape):
return x_shape return x_shape
def infer_dtype(self, x_type): def infer_dtype(self, x_dtype):
validator.check_tensor_type_same({"x": x_type}, mstype.number_type, self.name) validator.check_tensor_dtype_valid("x", x_dtype, mstype.number_type, self.name)
return x_type return x_dtype
def infer_value(self, x): def infer_value(self, x):
if x is not None: if x is not None:
@ -1437,7 +1440,7 @@ class Sqrt(PrimitiveWithCheck):
self.init_prim_io_names(inputs=['x'], outputs=['output']) self.init_prim_io_names(inputs=['x'], outputs=['output'])
def check_dtype(self, x_type): def check_dtype(self, x_type):
validator.check_tensor_type_same({"x": x_type}, mstype.number_type, self.name) validator.check_tensor_dtype_valid("x", x_type, mstype.number_type, self.name)
def infer_value(self, x): def infer_value(self, x):
if x is not None: if x is not None:
@ -1599,8 +1602,7 @@ class Expm1(PrimitiveWithInfer):
return x_shape return x_shape
def infer_dtype(self, x_type): def infer_dtype(self, x_type):
validator.check_subclass("x", x_type, mstype.tensor, self.name) validator.check_tensor_dtype_valid("x", x_type, [mstype.float16, mstype.float32], self.name)
validator.check_tensor_type_same({"x": x_type}, [mstype.float16, mstype.float32], self.name)
return x_type return x_type
@ -1641,10 +1643,9 @@ class HistogramFixedWidth(PrimitiveWithInfer):
return (self.nbins,) return (self.nbins,)
def infer_dtype(self, x_dtype, range_dtype): def infer_dtype(self, x_dtype, range_dtype):
validator.check_subclass("x", x_dtype, mstype.tensor, self.name) valid_dtypes = (mstype.float16, mstype.float32, mstype.int32)
valid_types = (mstype.float16, mstype.float32, mstype.int32) validator.check_tensor_dtype_valid("x", x_dtype, valid_dtypes, self.name)
validator.check_tensor_type_same({"x": x_dtype}, valid_types, self.name) validator.check_tensor_dtype_valid("range", range_dtype, valid_dtypes, self.name)
validator.check_tensor_type_same({"range": range_dtype}, valid_types, self.name)
y_dtype = mstype.int32 y_dtype = mstype.int32
return y_dtype return y_dtype
@ -1707,13 +1708,13 @@ class Log1p(PrimitiveWithInfer):
def __init__(self): def __init__(self):
self.init_prim_io_names(inputs=['x'], outputs=['y']) self.init_prim_io_names(inputs=['x'], outputs=['y'])
def infer_shape(self, x): def infer_shape(self, x_shape):
return x return x_shape
def infer_dtype(self, x): def infer_dtype(self, x_dtype):
validator.check_subclass("x", x, mstype.tensor, self.name) validator.check_subclass("x", x_dtype, mstype.tensor, self.name)
validator.check_tensor_type_same({"x": x}, [mstype.float16, mstype.float32], self.name) validator.check_tensor_dtype_valid("x", x_dtype, [mstype.float16, mstype.float32], self.name)
return x return x_dtype
class Erf(PrimitiveWithInfer): class Erf(PrimitiveWithInfer):
@ -1741,9 +1742,9 @@ class Erf(PrimitiveWithInfer):
def infer_shape(self, x_shape): def infer_shape(self, x_shape):
return x_shape return x_shape
def infer_dtype(self, x_type): def infer_dtype(self, x_dtype):
validator.check_tensor_type_same({"x": x_type}, [mstype.float16, mstype.float32], self.name) validator.check_tensor_dtype_valid("x", x_dtype, [mstype.float16, mstype.float32], self.name)
return x_type return x_dtype
class Erfc(PrimitiveWithInfer): class Erfc(PrimitiveWithInfer):
@ -1772,7 +1773,7 @@ class Erfc(PrimitiveWithInfer):
return x_shape return x_shape
def infer_dtype(self, x_type): def infer_dtype(self, x_type):
validator.check_tensor_type_same({"x": x_type}, [mstype.float16, mstype.float32], self.name) validator.check_tensor_dtype_valid("x", x_type, [mstype.float16, mstype.float32], self.name)
return x_type return x_type
@ -2126,7 +2127,7 @@ class Floor(PrimitiveWithInfer):
return x_shape return x_shape
def infer_dtype(self, x_dtype): def infer_dtype(self, x_dtype):
validator.check_tensor_type_same({"x": x_dtype}, mstype.float_type, self.name) validator.check_tensor_dtype_valid("x", x_dtype, mstype.float_type, self.name)
return x_dtype return x_dtype
@ -2185,7 +2186,7 @@ class Ceil(PrimitiveWithInfer):
return x_shape return x_shape
def infer_dtype(self, x_dtype): def infer_dtype(self, x_dtype):
validator.check_tensor_type_same({"x": x_dtype}, [mstype.float16, mstype.float32], self.name) validator.check_tensor_dtype_valid("x", x_dtype, [mstype.float16, mstype.float32], self.name)
return x_dtype return x_dtype
@ -2281,7 +2282,7 @@ class Acosh(PrimitiveWithInfer):
return x_shape return x_shape
def infer_dtype(self, x_dtype): def infer_dtype(self, x_dtype):
validator.check_tensor_type_same({'x': x_dtype}, mstype.number_type, self.name) validator.check_tensor_dtype_valid('x', x_dtype, mstype.number_type, self.name)
return x_dtype return x_dtype
@ -2310,7 +2311,7 @@ class Cosh(PrimitiveWithInfer):
return x_shape return x_shape
def infer_dtype(self, x_dtype): def infer_dtype(self, x_dtype):
validator.check_tensor_type_same({'x': x_dtype}, mstype.number_type, self.name) validator.check_tensor_dtype_valid('x', x_dtype, mstype.number_type, self.name)
return x_dtype return x_dtype
@ -2339,7 +2340,7 @@ class Asinh(PrimitiveWithInfer):
return x_shape return x_shape
def infer_dtype(self, x_dtype): def infer_dtype(self, x_dtype):
validator.check_tensor_type_same({'x': x_dtype}, mstype.number_type, self.name) validator.check_tensor_dtype_valid('x', x_dtype, mstype.number_type, self.name)
return x_dtype return x_dtype
@ -2368,7 +2369,7 @@ class Sinh(PrimitiveWithInfer):
return x_shape return x_shape
def infer_dtype(self, x_dtype): def infer_dtype(self, x_dtype):
validator.check_tensor_type_same({'x': x_dtype}, mstype.number_type, self.name) validator.check_tensor_dtype_valid('x', x_dtype, mstype.number_type, self.name)
return x_dtype return x_dtype
@ -2380,7 +2381,7 @@ class _LogicBinaryOp(_BinaryOp):
@staticmethod @staticmethod
def do_infer_dtype(x_dtype, y_dtype, valid_type=mstype.number_type, prim_name=None): def do_infer_dtype(x_dtype, y_dtype, valid_type=mstype.number_type, prim_name=None):
args_dtype = {"x": x_dtype, "y": y_dtype} args_dtype = {"x": x_dtype, "y": y_dtype}
validator.check_tensor_type_same(args_dtype, valid_type, prim_name) validator.check_tensors_dtypes_same_and_valid(args_dtype, valid_type, prim_name)
return mstype.tensor_type(mstype.bool_) return mstype.tensor_type(mstype.bool_)
def infer_dtype(self, x_dtype, y_dtype): def infer_dtype(self, x_dtype, y_dtype):
@ -2461,7 +2462,7 @@ class ApproximateEqual(_LogicBinaryOp):
def infer_dtype(self, x_dtype, y_dtype): def infer_dtype(self, x_dtype, y_dtype):
args_dtype = {"x": x_dtype, "y": y_dtype} args_dtype = {"x": x_dtype, "y": y_dtype}
valid_type = [mstype.float32, mstype.float16] valid_type = [mstype.float32, mstype.float16]
validator.check_tensor_type_same(args_dtype, valid_type, prim_name=self.name) validator.check_tensors_dtypes_same_and_valid(args_dtype, valid_type, prim_name=self.name)
return mstype.tensor_type(mstype.bool_) return mstype.tensor_type(mstype.bool_)
@ -2498,7 +2499,7 @@ class EqualCount(PrimitiveWithInfer):
def infer_dtype(self, x_dtype, y_dtype): def infer_dtype(self, x_dtype, y_dtype):
args = {'x': x_dtype, 'y': y_dtype} args = {'x': x_dtype, 'y': y_dtype}
validator.check_tensor_type_same(args, mstype.number_type + (mstype.bool_,), self.name) validator.check_tensors_dtypes_same_and_valid(args, mstype.number_type + (mstype.bool_,), self.name)
return x_dtype return x_dtype
@ -2711,7 +2712,7 @@ class LogicalNot(PrimitiveWithInfer):
return x_shape return x_shape
def infer_dtype(self, x_dtype): def infer_dtype(self, x_dtype):
validator.check_tensor_type_same({"x": x_dtype}, [mstype.bool_], self.name) validator.check_tensor_dtype_valid("x", x_dtype, [mstype.bool_], self.name)
return mstype.tensor_type(mstype.bool_) return mstype.tensor_type(mstype.bool_)
@ -2859,8 +2860,7 @@ class IsFinite(PrimitiveWithInfer):
return x_shape return x_shape
def infer_dtype(self, x_dtype): def infer_dtype(self, x_dtype):
validator.check_subclass("x", x_dtype, mstype.tensor, self.name) validator.check_tensor_dtype_valid('x', x_dtype, mstype.number_type + (mstype.bool_,), self.name)
validator.check_tensor_type_same({'x': x_dtype}, mstype.number_type + (mstype.bool_,), self.name)
return mstype.bool_ return mstype.bool_
@ -2890,7 +2890,7 @@ class FloatStatus(PrimitiveWithInfer):
return [1] return [1]
def infer_dtype(self, x_dtype): def infer_dtype(self, x_dtype):
validator.check_tensor_type_same({'x': x_dtype}, [mstype.float32, mstype.float16], self.name) validator.check_tensor_dtype_valid('x', x_dtype, [mstype.float32, mstype.float16], self.name)
return x_dtype return x_dtype
@ -2959,7 +2959,7 @@ class NPUGetFloatStatus(PrimitiveWithInfer):
return [8] return [8]
def infer_dtype(self, x_dtype): def infer_dtype(self, x_dtype):
validator.check_tensor_type_same({'x': x_dtype}, [mstype.float16, mstype.float32], self.name) validator.check_tensor_dtype_valid('x', x_dtype, [mstype.float16, mstype.float32], self.name)
return mstype.float32 return mstype.float32
@ -3002,7 +3002,7 @@ class NPUClearFloatStatus(PrimitiveWithInfer):
return [8] return [8]
def infer_dtype(self, x_dtype): def infer_dtype(self, x_dtype):
validator.check_tensor_type_same({'x': x_dtype}, [mstype.float16, mstype.float32], self.name) validator.check_tensor_dtype_valid('x', x_dtype, [mstype.float16, mstype.float32], self.name)
return mstype.float32 return mstype.float32
@ -3030,7 +3030,7 @@ class Cos(PrimitiveWithInfer):
return x_shape return x_shape
def infer_dtype(self, x_dtype): def infer_dtype(self, x_dtype):
validator.check_tensor_type_same({'x': x_dtype}, mstype.number_type, self.name) validator.check_tensor_dtype_valid('x', x_dtype, mstype.number_type, self.name)
return x_dtype return x_dtype
@ -3058,7 +3058,7 @@ class ACos(PrimitiveWithInfer):
return x_shape return x_shape
def infer_dtype(self, x_dtype): def infer_dtype(self, x_dtype):
validator.check_tensor_type_same({'x': x_dtype}, mstype.number_type, self.name) validator.check_tensor_dtype_valid('x', x_dtype, mstype.number_type, self.name)
return x_dtype return x_dtype
@ -3087,7 +3087,7 @@ class Sin(PrimitiveWithInfer):
return x_shape return x_shape
def infer_dtype(self, x_dtype): def infer_dtype(self, x_dtype):
validator.check_tensor_type_same({'x': x_dtype}, mstype.number_type, self.name) validator.check_tensor_dtype_valid('x', x_dtype, mstype.number_type, self.name)
return x_dtype return x_dtype
@ -3116,7 +3116,7 @@ class Asin(PrimitiveWithInfer):
return x_shape return x_shape
def infer_dtype(self, x_dtype): def infer_dtype(self, x_dtype):
validator.check_tensor_type_same({'x': x_dtype}, mstype.number_type, self.name) validator.check_tensor_dtype_valid('x', x_dtype, mstype.number_type, self.name)
return x_dtype return x_dtype
@ -3175,7 +3175,7 @@ class NMSWithMask(PrimitiveWithInfer):
return (bboxes_shape, (num,), (num,)) return (bboxes_shape, (num,), (num,))
def infer_dtype(self, bboxes_dtype): def infer_dtype(self, bboxes_dtype):
validator.check_tensor_type_same({"bboxes": bboxes_dtype}, [mstype.float16, mstype.float32], self.name) validator.check_tensor_dtype_valid("bboxes", bboxes_dtype, [mstype.float16, mstype.float32], self.name)
return (bboxes_dtype, mstype.int32, mstype.bool_) return (bboxes_dtype, mstype.int32, mstype.bool_)
@ -3205,7 +3205,7 @@ class Abs(PrimitiveWithInfer):
return x_shape return x_shape
def infer_dtype(self, x_type): def infer_dtype(self, x_type):
validator.check_tensor_type_same({'x': x_type}, mstype.number_type, self.name) validator.check_tensor_dtype_valid('x', x_type, mstype.number_type, self.name)
return x_type return x_type
def infer_value(self, x): def infer_value(self, x):
@ -3247,7 +3247,7 @@ class Sign(PrimitiveWithInfer):
return x_shape return x_shape
def infer_dtype(self, x_dtype): def infer_dtype(self, x_dtype):
validator.check_tensor_type_same({'x': x_dtype}, mstype.number_type, self.name) validator.check_tensor_dtype_valid('x', x_dtype, mstype.number_type, self.name)
return x_dtype return x_dtype
@ -3276,9 +3276,9 @@ class Round(PrimitiveWithInfer):
def infer_shape(self, x_shape): def infer_shape(self, x_shape):
return x_shape return x_shape
def infer_dtype(self, x_type): def infer_dtype(self, x_dtype):
validator.check_tensor_type_same({'x': x_type}, mstype.number_type, self.name) validator.check_tensor_dtype_valid('x', x_dtype, mstype.number_type, self.name)
return x_type return x_dtype
class Tan(PrimitiveWithInfer): class Tan(PrimitiveWithInfer):
@ -3306,8 +3306,8 @@ class Tan(PrimitiveWithInfer):
return x_shape return x_shape
def infer_dtype(self, x_type): def infer_dtype(self, x_type):
valid_types = [mstype.float16, mstype.float32, mstype.int32] valid_dtypes = [mstype.float16, mstype.float32, mstype.int32]
validator.check_tensor_type_same({'x': x_type}, valid_types, self.name) validator.check_tensor_dtype_valid('x', x_type, valid_dtypes, self.name)
return x_type return x_type
@ -3338,7 +3338,7 @@ class Atan(PrimitiveWithInfer):
return x_shape return x_shape
def infer_dtype(self, x_type): def infer_dtype(self, x_type):
validator.check_tensor_type_same({'x': x_type}, mstype.number_type, self.name) validator.check_tensor_dtype_valid('x', x_type, mstype.number_type, self.name)
return x_type return x_type
@ -3367,7 +3367,7 @@ class Atanh(PrimitiveWithInfer):
return x_shape return x_shape
def infer_dtype(self, x_type): def infer_dtype(self, x_type):
validator.check_tensor_type_same({'x': x_type}, mstype.number_type, self.name) validator.check_tensor_dtype_valid('x', x_type, mstype.number_type, self.name)
return x_type return x_type
@ -3431,8 +3431,9 @@ class SquareSumAll(PrimitiveWithInfer):
return [], [] return [], []
def infer_dtype(self, x_type, y_type): def infer_dtype(self, x_type, y_type):
validator.check_tensor_type_same({'x1_type': x_type}, [mstype.float16, mstype.float32], self.name) valid_types = (mstype.float16, mstype.float32)
validator.check_tensor_type_same({'x2_type': y_type}, [mstype.float16, mstype.float32], self.name) validator.check_tensor_dtype_valid('x1_type', x_type, valid_types, self.name)
validator.check_tensor_dtype_valid('x2_type', y_type, valid_types, self.name)
return x_type, y_type return x_type, y_type
@ -3539,7 +3540,7 @@ class BesselI0e(PrimitiveWithInfer):
return x return x
def infer_dtype(self, x): def infer_dtype(self, x):
validator.check_tensor_type_same({'x': x}, mstype.number_type, self.name) validator.check_tensor_dtype_valid('x', x, mstype.number_type, self.name)
return x return x
@ -3568,7 +3569,7 @@ class BesselI1e(PrimitiveWithInfer):
return x return x
def infer_dtype(self, x): def infer_dtype(self, x):
validator.check_tensor_type_same({'x': x}, mstype.number_type, self.name) validator.check_tensor_dtype_valid('x', x, mstype.number_type, self.name)
return x return x
@ -3598,7 +3599,7 @@ class Inv(PrimitiveWithInfer):
return x_shape return x_shape
def infer_dtype(self, x_dtype): def infer_dtype(self, x_dtype):
validator.check_tensor_type_same({'x_dtype': x_dtype}, [mstype.float16, mstype.float32, validator.check_tensor_dtype_valid('x_dtype', x_dtype, [mstype.float16, mstype.float32,
mstype.int32], self.name) mstype.int32], self.name)
return x_dtype return x_dtype
@ -3628,7 +3629,7 @@ class Invert(PrimitiveWithInfer):
return x_shape return x_shape
def infer_dtype(self, x_dtype): def infer_dtype(self, x_dtype):
validator.check_tensor_type_same({'x_dtype': x_dtype}, [mstype.int16, mstype.uint16], self.name) validator.check_tensor_dtype_valid('x_dtype', x_dtype, [mstype.int16, mstype.uint16], self.name)
return x_dtype return x_dtype
@ -3654,8 +3655,8 @@ class Eps(PrimitiveWithInfer):
self.init_prim_io_names(inputs=['input_x'], outputs=['y']) self.init_prim_io_names(inputs=['input_x'], outputs=['y'])
def __infer__(self, input_x): def __infer__(self, input_x):
valid_types = [mstype.float16, mstype.float32] valid_dtypes = [mstype.float16, mstype.float32]
validator.check_tensor_type_same({'input_x': input_x['dtype']}, valid_types, self.name) validator.check_tensor_dtype_valid('input_x', input_x['dtype'], valid_dtypes, self.name)
x_nptype = mstype.dtype_to_nptype(input_x['dtype'].element_type()) x_nptype = mstype.dtype_to_nptype(input_x['dtype'].element_type())
if x_nptype == np.float16: if x_nptype == np.float16:
@ -3725,9 +3726,9 @@ class IFMR(PrimitiveWithInfer):
return (1,), (1,) return (1,), (1,)
def infer_dtype(self, data_dtype, data_min_dtype, data_max_dtype, cumsum_dtype): def infer_dtype(self, data_dtype, data_min_dtype, data_max_dtype, cumsum_dtype):
valid_types = [mstype.float32, mstype.float16] tuple(map(partial(validator.check_tensor_dtype_valid,
validator.check_tensor_type_same({"input_value": data_dtype}, valid_types, self.name) valid_dtypes=(mstype.float16, mstype.float32), prim_name=self.name),
validator.check_tensor_type_same({"input_min": data_min_dtype}, valid_types, self.name) ("input_value", "input_min", "input_max"),
validator.check_tensor_type_same({"input_max": data_max_dtype}, valid_types, self.name) (data_dtype, data_min_dtype, data_max_dtype)))
validator.check_tensor_type_same({"input_bins": cumsum_dtype}, [mstype.int32], self.name) validator.check_tensor_dtype_valid("input_bins", cumsum_dtype, [mstype.int32], self.name)
return mstype.tensor_type(mstype.float32), mstype.tensor_type(mstype.float32) return mstype.tensor_type(mstype.float32), mstype.tensor_type(mstype.float32)
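The IFMR change above introduces the idiom this commit uses whenever several inputs need the same per-tensor check: partial pins the shared valid_dtypes and prim_name keywords, map pairs each argument name with its dtype, and the outer tuple(...) merely forces the lazy map so the checks actually run. A standalone sketch of the idiom with illustrative names and dtypes:

from functools import partial

from mindspore.common import dtype as mstype
from mindspore._checkparam import Validator as validator

prim_name = "ExampleOp"                           # illustrative
d_value = d_min = d_max = mstype.tensor_type(mstype.float16)

check = partial(validator.check_tensor_dtype_valid,
                valid_dtypes=(mstype.float16, mstype.float32), prim_name=prim_name)
tuple(map(check,
          ("input_value", "input_min", "input_max"),
          (d_value, d_min, d_max)))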

View File

@ -17,7 +17,7 @@
import math import math
import operator import operator
from functools import reduce from functools import reduce, partial
import numpy as np import numpy as np
from ... import context from ... import context
from .. import signature as sig from .. import signature as sig
@ -153,8 +153,7 @@ class Softmax(PrimitiveWithInfer):
return logits return logits
def infer_dtype(self, logits): def infer_dtype(self, logits):
validator.check_subclass("logits", logits, mstype.tensor, self.name) validator.check_tensor_dtype_valid("logits", logits, mstype.float_type, self.name)
validator.check_tensor_type_same({"logits": logits}, mstype.float_type, self.name)
return logits return logits
@ -197,8 +196,7 @@ class LogSoftmax(PrimitiveWithInfer):
return logits return logits
def infer_dtype(self, logits): def infer_dtype(self, logits):
validator.check_subclass("logits", logits, mstype.tensor, self.name) validator.check_tensor_dtype_valid("logits", logits, mstype.float_type, self.name)
validator.check_tensor_type_same({"logits": logits}, mstype.float_type, self.name)
return logits return logits
@ -230,12 +228,12 @@ class Softplus(PrimitiveWithInfer):
"""Initialize Softplus""" """Initialize Softplus"""
self.init_prim_io_names(inputs=['x'], outputs=['output']) self.init_prim_io_names(inputs=['x'], outputs=['output'])
def infer_shape(self, input_x): def infer_shape(self, x_shape):
return input_x return x_shape
def infer_dtype(self, input_x): def infer_dtype(self, x_dtype):
validator.check_tensor_type_same({'input_x': input_x}, mstype.float_type, self.name) validator.check_tensor_dtype_valid('x', x_dtype, mstype.float_type, self.name)
return input_x return x_dtype
class Softsign(PrimitiveWithInfer): class Softsign(PrimitiveWithInfer):
@ -269,7 +267,7 @@ class Softsign(PrimitiveWithInfer):
return input_x return input_x
def infer_dtype(self, input_x): def infer_dtype(self, input_x):
validator.check_tensor_type_same({'input_x': input_x}, [mstype.float16, mstype.float32], self.name) validator.check_tensor_dtype_valid('input_x', input_x, [mstype.float16, mstype.float32], self.name)
return input_x return input_x
@ -301,7 +299,7 @@ class ReLU(PrimitiveWithInfer):
return input_x return input_x
def infer_dtype(self, input_x): def infer_dtype(self, input_x):
validator.check_tensor_type_same({'input_x': input_x}, mstype.number_type, self.name) validator.check_tensor_dtype_valid('input_x', input_x, mstype.number_type, self.name)
return input_x return input_x
@ -332,7 +330,7 @@ class ReLU6(PrimitiveWithInfer):
return input_x return input_x
def infer_dtype(self, input_x): def infer_dtype(self, input_x):
validator.check_tensor_type_same({'input_x': input_x}, (mstype.float16, mstype.float32), self.name) validator.check_tensor_dtype_valid('input_x', input_x, (mstype.float16, mstype.float32), self.name)
return input_x return input_x
@ -384,7 +382,7 @@ class ReLUV2(PrimitiveWithInfer):
output_shape = (input_x['shape'], mask_shape) output_shape = (input_x['shape'], mask_shape)
validator.check_subclass("input_x", input_dtype, mstype.tensor, self.name) validator.check_subclass("input_x", input_dtype, mstype.tensor, self.name)
validator.check_tensor_type_same({'input_x': input_dtype}, mstype.number_type, self.name) validator.check_tensor_dtype_valid('input_x', input_dtype, mstype.number_type, self.name)
mask_dtype = mstype.uint8 mask_dtype = mstype.uint8
output_dtype = (input_dtype, mask_dtype) output_dtype = (input_dtype, mask_dtype)
@ -426,7 +424,7 @@ class Elu(PrimitiveWithInfer):
return input_x return input_x
def infer_dtype(self, input_x): def infer_dtype(self, input_x):
validator.check_tensor_type_same({'input_x': input_x}, mstype.float_type, self.name) validator.check_tensor_dtype_valid('input_x', input_x, mstype.float_type, self.name)
return input_x return input_x
@ -463,7 +461,7 @@ class HSwish(PrimitiveWithInfer):
return xshape return xshape
def infer_dtype(self, x_dtype): def infer_dtype(self, x_dtype):
validator.check_tensor_type_same({"x": x_dtype}, (mstype.float16, mstype.float32), self.name) validator.check_tensor_dtype_valid("x", x_dtype, (mstype.float16, mstype.float32), self.name)
return x_dtype return x_dtype
@ -499,7 +497,7 @@ class Sigmoid(PrimitiveWithInfer):
return input_x return input_x
def infer_dtype(self, input_x): def infer_dtype(self, input_x):
validator.check_tensor_type_same({"input_x": input_x}, (mstype.float16, mstype.float32), self.name) validator.check_tensor_dtype_valid("input_x", input_x, (mstype.float16, mstype.float32), self.name)
return input_x return input_x
@ -536,7 +534,7 @@ class HSigmoid(PrimitiveWithInfer):
return x_shape return x_shape
def infer_dtype(self, x_dtype): def infer_dtype(self, x_dtype):
validator.check_tensor_type_same({"x": x_dtype}, (mstype.float16, mstype.float32), self.name) validator.check_tensor_dtype_valid("x", x_dtype, (mstype.float16, mstype.float32), self.name)
return x_dtype return x_dtype
@ -733,12 +731,12 @@ class FusedBatchNormEx(PrimitiveWithInfer):
return (input_x, scale, scale, scale, scale, scale) return (input_x, scale, scale, scale, scale, scale)
def infer_dtype(self, input_x, scale, bias, mean, variance): def infer_dtype(self, input_x, scale, bias, mean, variance):
validator.check_tensor_type_same({"input_x": input_x}, [mstype.float16, mstype.float32], self.name) validator.check_tensor_dtype_valid("input_x", input_x, [mstype.float16, mstype.float32], self.name)
args = {"scale": scale, "bias": bias} args = {"scale": scale, "bias": bias}
validator.check_tensor_type_same(args, [mstype.float32], self.name) validator.check_tensors_dtypes_same_and_valid(args, [mstype.float32], self.name)
args_moving = {"mean": mean, "variance": variance} args_moving = {"mean": mean, "variance": variance}
valid_types = [mstype.tensor_type(mstype.float32)] valid_dtypes = [mstype.tensor_type(mstype.float32)]
validator.check_type_same(args_moving, valid_types, self.name) validator.check_types_same_and_valid(args_moving, valid_dtypes, self.name)
return (input_x, scale, scale, scale, scale, scale) return (input_x, scale, scale, scale, scale, scale)
@ -769,7 +767,7 @@ class BNTrainingReduce(PrimitiveWithInfer):
return ([x_shape[1]], [x_shape[1]]) return ([x_shape[1]], [x_shape[1]])
def infer_dtype(self, x_type): def infer_dtype(self, x_type):
validator.check_tensor_type_same({"x_type": x_type}, [mstype.float16, mstype.float32], self.name) validator.check_tensor_dtype_valid("x", x_type, [mstype.float16, mstype.float32], self.name)
return (x_type, x_type) return (x_type, x_type)
@ -819,6 +817,7 @@ class BNTrainingUpdate(PrimitiveWithInfer):
>>> bn_training_update = P.BNTrainingUpdate() >>> bn_training_update = P.BNTrainingUpdate()
>>> output = bn_training_update(input_x, sum, square_sum, scale, offset, mean, variance) >>> output = bn_training_update(input_x, sum, square_sum, scale, offset, mean, variance)
""" """
@prim_attr_register @prim_attr_register
def __init__(self, isRef=True, epsilon=1e-5, factor=0.1): def __init__(self, isRef=True, epsilon=1e-5, factor=0.1):
self.init_prim_io_names(inputs=['x', 'sum', 'square_sum', 'scale', 'b', 'mean', 'variance'], self.init_prim_io_names(inputs=['x', 'sum', 'square_sum', 'scale', 'b', 'mean', 'variance'],
@ -846,13 +845,10 @@ class BNTrainingUpdate(PrimitiveWithInfer):
return (x, variance, variance, variance, variance) return (x, variance, variance, variance, variance)
def infer_dtype(self, x, sum, square_sum, scale, b, mean, variance): def infer_dtype(self, x, sum, square_sum, scale, b, mean, variance):
validator.check_tensor_type_same({"x_type": x}, [mstype.float16, mstype.float32], self.name) tuple(map(partial(validator.check_tensor_dtype_valid,
validator.check_tensor_type_same({"sum_type": sum}, [mstype.float16, mstype.float32], self.name) valid_dtypes=(mstype.float16, mstype.float32), prim_name=self.name),
validator.check_tensor_type_same({"square_sum_type": square_sum}, [mstype.float16, mstype.float32], self.name) ("x", "sum", "square_sum", "scale", "b", "mean", "variance"),
validator.check_tensor_type_same({"scale_type": scale}, [mstype.float16, mstype.float32], self.name) (x, sum, square_sum, scale, b, mean, variance)))
validator.check_tensor_type_same({"b_type": b}, [mstype.float16, mstype.float32], self.name)
validator.check_tensor_type_same({"mean_type": mean}, [mstype.float16, mstype.float32], self.name)
validator.check_tensor_type_same({"variance_type": variance}, [mstype.float16, mstype.float32], self.name)
return (x, variance, variance, variance, variance) return (x, variance, variance, variance, variance)
@ -928,16 +924,16 @@ class BatchNorm(PrimitiveWithInfer):
return (input_x, scale, scale, scale, scale) return (input_x, scale, scale, scale, scale)
def infer_dtype(self, input_x, scale, bias, mean, variance): def infer_dtype(self, input_x, scale, bias, mean, variance):
validator.check_tensor_type_same({"input_x": input_x}, [mstype.float16, mstype.float32], self.name) validator.check_tensor_dtype_valid("input_x", input_x, [mstype.float16, mstype.float32], self.name)
args = {"scale": scale, "bias": bias} args = {"scale": scale, "bias": bias}
validator.check_tensor_type_same(args, [mstype.float16, mstype.float32], self.name) validator.check_tensors_dtypes_same_and_valid(args, [mstype.float16, mstype.float32], self.name)
args_moving = {"mean": mean, "variance": variance} args_moving = {"mean": mean, "variance": variance}
if self.is_training: if self.is_training:
valid_types = [mstype.tensor_type(mstype.float16), mstype.tensor_type(mstype.float32), None] valid_dtypes = [mstype.tensor_type(mstype.float16), mstype.tensor_type(mstype.float32), None]
validator.check_type_same(args_moving, valid_types, self.name) validator.check_types_same_and_valid(args_moving, valid_dtypes, self.name)
else: else:
args_moving = {"mean": mean, "variance": variance} args_moving = {"mean": mean, "variance": variance}
validator.check_tensor_type_same(args_moving, [mstype.float16, mstype.float32], self.name) validator.check_tensors_dtypes_same_and_valid(args_moving, [mstype.float16, mstype.float32], self.name)
return (input_x, scale, bias, input_x, input_x) return (input_x, scale, bias, input_x, input_x)
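BatchNorm keeps two different checks on purpose: in the training branch the moving statistics may legitimately arrive as None (the valid list includes None), so whole types go through check_types_same_and_valid, whereas the inference branch requires real tensors and compares their element dtypes with check_tensors_dtypes_same_and_valid. A sketch of the two call shapes with illustrative values, making no claim about either helper's internals:

from mindspore.common import dtype as mstype
from mindspore._checkparam import Validator as validator

prim_name = "ExampleBatchNorm"                    # illustrative
mean_type = variance_type = mstype.tensor_type(mstype.float32)

# training branch: whole types are compared; None is an accepted entry in the valid list
validator.check_types_same_and_valid(
    {"mean": mean_type, "variance": variance_type},
    [mstype.tensor_type(mstype.float16), mstype.tensor_type(mstype.float32), None],
    prim_name)

# inference branch: element dtypes of real tensors are compared
validator.check_tensors_dtypes_same_and_valid(
    {"mean": mean_type, "variance": variance_type},
    [mstype.float16, mstype.float32],
    prim_name)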
@ -1053,7 +1049,7 @@ class Conv2D(PrimitiveWithInfer):
validator.check_equal_int(len(w_shape_norm), 4, "weight rank", self.name) validator.check_equal_int(len(w_shape_norm), 4, "weight rank", self.name)
validator.check_equal_int(len(x_shape_norm), 4, "x rank", self.name) validator.check_equal_int(len(x_shape_norm), 4, "x rank", self.name)
validator.check(f"x_shape[1] / group", x_shape_norm[1] // self.group, "w_shape[1]", w_shape_norm[1], \ validator.check(f"x_shape[1] / group", x_shape_norm[1] // self.group, "w_shape[1]", w_shape_norm[1], \
Rel.EQ, self.name) Rel.EQ, self.name)
validator.check('out_channel', self.out_channel, 'w_shape[0]', w_shape_norm[0], Rel.EQ, self.name) validator.check('out_channel', self.out_channel, 'w_shape[0]', w_shape_norm[0], Rel.EQ, self.name)
validator.check('kernel_size', self.kernel_size, 'w_shape[2:4]', tuple(w_shape_norm[2:4]), Rel.EQ, self.name) validator.check('kernel_size', self.kernel_size, 'w_shape[2:4]', tuple(w_shape_norm[2:4]), Rel.EQ, self.name)
@ -1084,24 +1080,24 @@ class Conv2D(PrimitiveWithInfer):
pad_top, pad_bottom, pad_left, pad_right = self.padding pad_top, pad_bottom, pad_left, pad_right = self.padding
h_out = 1 + (x_shape_norm[2] + pad_top + pad_bottom - kernel_size_h - (kernel_size_h - 1) \ h_out = 1 + (x_shape_norm[2] + pad_top + pad_bottom - kernel_size_h - (kernel_size_h - 1) \
* (dilation_h - 1)) / stride_h * (dilation_h - 1)) / stride_h
w_out = 1 + (x_shape_norm[3] + pad_left + pad_right - kernel_size_w - (kernel_size_w - 1) \ w_out = 1 + (x_shape_norm[3] + pad_left + pad_right - kernel_size_w - (kernel_size_w - 1) \
* (dilation_w - 1)) / stride_w * (dilation_w - 1)) / stride_w
h_out = math.floor(h_out) h_out = math.floor(h_out)
w_out = math.floor(w_out) w_out = math.floor(w_out)
self.pad_list = [pad_top, pad_bottom, pad_left, pad_right] self.pad_list = [pad_top, pad_bottom, pad_left, pad_right]
self.add_prim_attr('pad_list', (pad_top, pad_bottom, pad_left, pad_right)) self.add_prim_attr('pad_list', (pad_top, pad_bottom, pad_left, pad_right))
out_channel = self.out_channel out_channel = self.out_channel
out_shape = [x_shape_norm[0], out_channel, h_out, w_out] if self.format == "NCHW" else\ out_shape = [x_shape_norm[0], out_channel, h_out, w_out] if self.format == "NCHW" else \
[x_shape_norm[0], h_out, w_out, out_channel] [x_shape_norm[0], h_out, w_out, out_channel]
_check_shape('output', out_shape, self.name) _check_shape('output', out_shape, self.name)
return out_shape return out_shape
def infer_dtype(self, x_dtype, w_dtype, b_dtype=None): def infer_dtype(self, x_dtype, w_dtype, b_dtype=None):
args = {'x': x_dtype, 'w': w_dtype} args = {'x': x_dtype, 'w': w_dtype}
valid_types = [mstype.int8, mstype.int32, mstype.float16, mstype.float32] valid_dtypes = [mstype.int8, mstype.int32, mstype.float16, mstype.float32]
validator.check_tensor_type_same(args, valid_types, self.name) validator.check_tensors_dtypes_same_and_valid(args, valid_dtypes, self.name)
if x_dtype.element_type() == mstype.int8: if x_dtype.element_type() == mstype.int8:
return mstype.tensor_type(mstype.int32) return mstype.tensor_type(mstype.int32)
return x_dtype return x_dtype
@ -1220,9 +1216,9 @@ class DepthwiseConv2dNative(PrimitiveWithInfer):
pad_top, pad_bottom, pad_left, pad_right = self.padding pad_top, pad_bottom, pad_left, pad_right = self.padding
h_out = 1 + (x_shape[2] + pad_top + pad_bottom - kernel_size_h - (kernel_size_h - 1) * (dilation_h - 1)) \ h_out = 1 + (x_shape[2] + pad_top + pad_bottom - kernel_size_h - (kernel_size_h - 1) * (dilation_h - 1)) \
/ stride_h / stride_h
w_out = 1 + (x_shape[3] + pad_left + pad_right - kernel_size_w - (kernel_size_w - 1) * (dilation_w - 1)) \ w_out = 1 + (x_shape[3] + pad_left + pad_right - kernel_size_w - (kernel_size_w - 1) * (dilation_w - 1)) \
/ stride_w / stride_w
h_out = math.floor(h_out) h_out = math.floor(h_out)
w_out = math.floor(w_out) w_out = math.floor(w_out)
@ -1235,7 +1231,7 @@ class DepthwiseConv2dNative(PrimitiveWithInfer):
def infer_dtype(self, x_dtype, w_dtype, b_dtype=None): def infer_dtype(self, x_dtype, w_dtype, b_dtype=None):
args = {'x': x_dtype, 'w': w_dtype} args = {'x': x_dtype, 'w': w_dtype}
validator.check_tensor_type_same(args, mstype.number_type, self.name) validator.check_tensors_dtypes_same_and_valid(args, mstype.number_type, self.name)
if x_dtype.element_type() == mstype.int8: if x_dtype.element_type() == mstype.int8:
return mstype.tensor_type(mstype.int32) return mstype.tensor_type(mstype.int32)
return x_dtype return x_dtype
@ -1436,7 +1432,7 @@ class MaxPoolWithArgmax(_Pool):
def infer_dtype(self, x_dtype): def infer_dtype(self, x_dtype):
out_dtype = x_dtype out_dtype = x_dtype
validator.check_tensor_type_same({"x": x_dtype}, (mstype.float16, mstype.float32), self.name) validator.check_tensor_dtype_valid("x", x_dtype, (mstype.float16, mstype.float32), self.name)
argmax_dtype = mstype.uint16 argmax_dtype = mstype.uint16
if self.is_gpu: if self.is_gpu:
argmax_dtype = mstype.int32 argmax_dtype = mstype.int32
@ -1604,12 +1600,12 @@ class Conv2DBackpropInput(PrimitiveWithInfer):
for i, dim_len in enumerate(x_size_v): for i, dim_len in enumerate(x_size_v):
validator.check_value_type("x_size[%d]" % i, dim_len, [int], self.name) validator.check_value_type("x_size[%d]" % i, dim_len, [int], self.name)
args = {'doutput': doutput['dtype'], 'w': w['dtype']} args = {'doutput': doutput['dtype'], 'w': w['dtype']}
valid_types = [mstype.int8, mstype.int32, mstype.float16, mstype.float32] valid_dtypes = [mstype.int8, mstype.int32, mstype.float16, mstype.float32]
validator.check_tensor_type_same(args, valid_types, self.name) validator.check_tensors_dtypes_same_and_valid(args, valid_dtypes, self.name)
# infer shape # infer shape
dout_shape = doutput['shape'] dout_shape = doutput['shape']
dout_shape_norm = dout_shape if self.format == "NCHW" else\ dout_shape_norm = dout_shape if self.format == "NCHW" else \
[dout_shape[0], dout_shape[2], dout_shape[3], dout_shape[1]] [dout_shape[0], dout_shape[2], dout_shape[3], dout_shape[1]]
kernel_h = self.kernel_size[0] kernel_h = self.kernel_size[0]
kernel_w = self.kernel_size[1] kernel_w = self.kernel_size[1]
@ -1682,7 +1678,7 @@ class BiasAdd(PrimitiveWithInfer):
def infer_dtype(self, x_type, b_type): def infer_dtype(self, x_type, b_type):
args = {"input_x": x_type, "bias": b_type} args = {"input_x": x_type, "bias": b_type}
validator.check_tensor_type_same(args, mstype.number_type, self.name) validator.check_tensors_dtypes_same_and_valid(args, mstype.number_type, self.name)
return x_type return x_type
@ -1721,8 +1717,8 @@ class TopK(PrimitiveWithInfer):
def __infer__(self, input_x, k): def __infer__(self, input_x, k):
x_dtype = input_x['dtype'] x_dtype = input_x['dtype']
valid_types = (mstype.int32, mstype.float16, mstype.float32) valid_dtypes = (mstype.int32, mstype.float16, mstype.float32)
validator.check_tensor_type_same({'x': x_dtype}, valid_types, self.name) validator.check_tensor_dtype_valid('x', x_dtype, valid_dtypes, self.name)
k_v = k['value'] k_v = k['value']
validator.check_value_type('k', k_v, (int,), self.name) validator.check_value_type('k', k_v, (int,), self.name)
x_shape = list(input_x['shape']) x_shape = list(input_x['shape'])
@ -1774,7 +1770,7 @@ class SoftmaxCrossEntropyWithLogits(PrimitiveWithInfer):
def infer_dtype(self, logits_type, labels_type): def infer_dtype(self, logits_type, labels_type):
args = {"logits": logits_type, "labels": labels_type} args = {"logits": logits_type, "labels": labels_type}
validator.check_tensor_type_same(args, (mstype.float16, mstype.float32), self.name) validator.check_tensors_dtypes_same_and_valid(args, (mstype.float16, mstype.float32), self.name)
return (logits_type, logits_type) return (logits_type, logits_type)
@ -1825,8 +1821,9 @@ class SparseSoftmaxCrossEntropyWithLogits(PrimitiveWithInfer):
return loss_shape return loss_shape
def infer_dtype(self, logits_type, labels_type): def infer_dtype(self, logits_type, labels_type):
validator.check_tensor_type_same({"logits": logits_type}, (mstype.float16, mstype.float32), self.name) validator.check_tensor_dtype_valid("logits", logits_type, (mstype.float16, mstype.float32),
validator.check_tensor_type_same({"labels": labels_type}, (mstype.int32, mstype.int64), self.name) self.name)
validator.check_tensor_dtype_valid("labels", labels_type, (mstype.int32, mstype.int64), self.name)
return logits_type return logits_type
@ -1886,13 +1883,13 @@ class ApplyMomentum(PrimitiveWithInfer):
return v_shape return v_shape
def infer_dtype(self, v_dtype, a_dtype, l_dtype, g_dtype, m_dtype): def infer_dtype(self, v_dtype, a_dtype, l_dtype, g_dtype, m_dtype):
valid_types = [mstype.float16, mstype.float32, mstype.float64] valid_dtypes = [mstype.float16, mstype.float32, mstype.float64]
if v_dtype != mstype.type_refkey and a_dtype != mstype.type_refkey: if v_dtype != mstype.type_refkey and a_dtype != mstype.type_refkey:
validator.check_tensor_type_same({"v": v_dtype}, valid_types, self.name) validator.check_tensor_dtype_valid("v", v_dtype, valid_dtypes, self.name)
validator.check_tensor_type_same({"a": a_dtype}, valid_types, self.name) validator.check_tensor_dtype_valid("a", a_dtype, valid_dtypes, self.name)
validator.check_scalar_or_tensor_type_same({"l_dtype": l_dtype}, valid_types, self.name) validator.check_scalar_or_tensor_types_same({"l_dtype": l_dtype}, valid_dtypes, self.name)
validator.check_scalar_or_tensor_type_same({"g_dtype": g_dtype}, valid_types, self.name) validator.check_scalar_or_tensor_types_same({"g_dtype": g_dtype}, valid_dtypes, self.name)
validator.check_scalar_or_tensor_type_same({"m_dtype": m_dtype}, valid_types, self.name) validator.check_scalar_or_tensor_types_same({"m_dtype": m_dtype}, valid_dtypes, self.name)
if not self.is_ge and self.is_tbe: if not self.is_ge and self.is_tbe:
return g_dtype, g_dtype return g_dtype, g_dtype
return g_dtype return g_dtype
@ -1944,7 +1941,7 @@ class SmoothL1Loss(PrimitiveWithInfer):
def infer_dtype(self, prediction, target): def infer_dtype(self, prediction, target):
args = {"prediction": prediction, "target": target} args = {"prediction": prediction, "target": target}
validator.check_tensor_type_same(args, (mstype.float16, mstype.float32), self.name) validator.check_tensors_dtypes_same_and_valid(args, (mstype.float16, mstype.float32), self.name)
return prediction return prediction
@ -1981,9 +1978,8 @@ class L2Loss(PrimitiveWithInfer):
return loss_shape return loss_shape
def infer_dtype(self, x_type): def infer_dtype(self, x_type):
validator.check_subclass("x_type", x_type, mstype.tensor, self.name) valid_dtypes = [mstype.float16, mstype.float32]
valid_types = [mstype.float16, mstype.float32] validator.check_tensor_dtype_valid('x_type', x_type, valid_dtypes, self.name)
validator.check_tensor_type_same({'x_type': x_type}, valid_types, self.name)
return x_type return x_type
@ -2019,11 +2015,10 @@ class DataFormatDimMap(PrimitiveWithInfer):
def infer_shape(self, x_shape): def infer_shape(self, x_shape):
return x_shape return x_shape
def infer_dtype(self, x_type): def infer_dtype(self, x_dtype):
validator.check_subclass("x", x_type, mstype.tensor, self.name) valid_dtypes = [mstype.int32]
valid_types = [mstype.int32] validator.check_tensor_dtype_valid("x", x_dtype, valid_dtypes, self.name)
validator.check_tensor_type_same({"x": x_type}, valid_types, self.name) return x_dtype
return x_type
class RNNTLoss(PrimitiveWithInfer): class RNNTLoss(PrimitiveWithInfer):
@ -2065,21 +2060,18 @@ class RNNTLoss(PrimitiveWithInfer):
validator.check_equal_int(len(input_length_shape), 1, 'input_length_rank', self.name) validator.check_equal_int(len(input_length_shape), 1, 'input_length_rank', self.name)
validator.check_equal_int(len(label_length_shape), 1, 'label_length_rank', self.name) validator.check_equal_int(len(label_length_shape), 1, 'label_length_rank', self.name)
validator.check('labels shape[0]', labels_shape[0], 'acts shape[0]', acts_shape[0], Rel.EQ, self.name) validator.check('labels shape[0]', labels_shape[0], 'acts shape[0]', acts_shape[0], Rel.EQ, self.name)
validator.check('labels shape[1]', labels_shape[1], 'acts shape[2]-1', acts_shape[2]-1, Rel.EQ, self.name) validator.check('labels shape[1]', labels_shape[1], 'acts shape[2]-1', acts_shape[2] - 1, Rel.EQ, self.name)
validator.check('input_length size', input_length_shape[0], 'acts shape[0]', acts_shape[0], Rel.EQ, self.name) validator.check('input_length size', input_length_shape[0], 'acts shape[0]', acts_shape[0], Rel.EQ, self.name)
validator.check('label_length size', label_length_shape[0], 'acts shape[0]', acts_shape[0], Rel.EQ, self.name) validator.check('label_length size', label_length_shape[0], 'acts shape[0]', acts_shape[0], Rel.EQ, self.name)
costs_shape = (acts_shape[0],) costs_shape = (acts_shape[0],)
return (costs_shape, acts_shape) return (costs_shape, acts_shape)
def infer_dtype(self, acts_type, labels_type, input_length_type, label_length_type): def infer_dtype(self, acts_type, labels_type, input_length_type, label_length_type):
validator.check_subclass("acts_type", acts_type, mstype.tensor, self.name) validator.check_tensor_dtype_valid("acts_type", acts_type, [mstype.float32, mstype.float16], self.name)
validator.check_subclass("labels_type", labels_type, mstype.tensor, self.name) tuple(map(partial(validator.check_tensor_dtype_valid,
validator.check_subclass("input_length_type", input_length_type, mstype.tensor, self.name) valid_dtypes=(mstype.int32,), prim_name=self.name),
validator.check_subclass("label_length_type", label_length_type, mstype.tensor, self.name) ("labels", "input_length", "label_length"),
validator.check_tensor_type_same({"acts_type": acts_type}, [mstype.float32, mstype.float16], self.name) (labels_type, input_length_type, label_length_type)))
validator.check_tensor_type_same({"labels_type": labels_type}, [mstype.int32], self.name)
validator.check_tensor_type_same({"input_length_type": input_length_type}, [mstype.int32], self.name)
validator.check_tensor_type_same({"label_length_type": label_length_type}, [mstype.int32], self.name)
return (acts_type, acts_type) return (acts_type, acts_type)
@ -2143,13 +2135,10 @@ class SGD(PrimitiveWithInfer):
def infer_dtype(self, parameters_dtype, gradient_dtype, learning_rate_dtype, def infer_dtype(self, parameters_dtype, gradient_dtype, learning_rate_dtype,
accum_dtype, momentum_dtype, stat_dtype): accum_dtype, momentum_dtype, stat_dtype):
valid_types = [mstype.float16, mstype.float32] tuple(map(partial(validator.check_tensor_dtype_valid,
validator.check_tensor_type_same({"parameters": parameters_dtype}, valid_types, self.name) valid_dtypes=(mstype.float16, mstype.float32), prim_name=self.name),
validator.check_tensor_type_same({"gradient": gradient_dtype}, valid_types, self.name) ("parameters", "gradient", "learning_rate", "accum", "momentum", "stat"),
validator.check_tensor_type_same({"learning_rate": learning_rate_dtype}, valid_types, self.name) (parameters_dtype, gradient_dtype, learning_rate_dtype, accum_dtype, momentum_dtype, stat_dtype)))
validator.check_tensor_type_same({"accum": accum_dtype}, valid_types, self.name)
validator.check_tensor_type_same({"momentum": momentum_dtype}, valid_types, self.name)
validator.check_tensor_type_same({"stat": stat_dtype}, valid_types, self.name)
return parameters_dtype return parameters_dtype
@ -2229,13 +2218,13 @@ class ApplyRMSProp(PrimitiveWithInfer):
def infer_dtype(self, var_dtype, mean_square_dtype, moment_dtype, learning_rate_dtype, grad_dtype, decay_dtype, def infer_dtype(self, var_dtype, mean_square_dtype, moment_dtype, learning_rate_dtype, grad_dtype, decay_dtype,
momentum_dtype, epsilon_dtype): momentum_dtype, epsilon_dtype):
args = {"var": var_dtype, "mean_square": mean_square_dtype, "moment": moment_dtype, "grad": grad_dtype} args = {"var": var_dtype, "mean_square": mean_square_dtype, "moment": moment_dtype, "grad": grad_dtype}
validator.check_tensor_type_same(args, mstype.number_type, self.name) validator.check_tensors_dtypes_same_and_valid(args, mstype.number_type, self.name)
valid_types = [mstype.float16, mstype.float32] valid_dtypes = [mstype.float16, mstype.float32]
args_decay = {"decay": decay_dtype, 'momentum': momentum_dtype, "epsilon": epsilon_dtype} args_decay = {"decay": decay_dtype, 'momentum': momentum_dtype, "epsilon": epsilon_dtype}
validator.check_type_same(args_decay, valid_types, self.name) validator.check_types_same_and_valid(args_decay, valid_dtypes, self.name)
args_lr = {"learning_rate": learning_rate_dtype, "decay": decay_dtype} args_lr = {"learning_rate": learning_rate_dtype, "decay": decay_dtype}
validator.check_scalar_or_tensor_type_same(args_lr, valid_types, self.name, allow_mix=True) validator.check_scalar_or_tensor_types_same(args_lr, valid_dtypes, self.name, allow_mix=True)
if not self.is_ge and self.is_d: if not self.is_ge and self.is_d:
return var_dtype, var_dtype, var_dtype return var_dtype, var_dtype, var_dtype
return var_dtype return var_dtype
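The ApplyRMSProp hunk above uses three differently named checks. As a hedged sketch of what the "same and valid" names suggest (a toy, not the validator's actual code), every dtype in the dict must be identical and drawn from the allowed set:

def check_dtypes_same_and_valid(args, valid_dtypes, prim_name):
    # Toy illustration only: args maps argument names to dtype strings.
    names, dtypes = zip(*args.items())
    if any(d not in valid_dtypes for d in dtypes):
        raise TypeError(f"For '{prim_name}', dtypes of {names} should be in {valid_dtypes},"
                        f" but got {dtypes}.")
    if len(set(dtypes)) > 1:
        raise TypeError(f"For '{prim_name}', {names} should share one dtype, but got {dtypes}.")

# Mirrors the decay/momentum/epsilon call site: one shared float16 or float32 dtype.
check_dtypes_same_and_valid({"decay": "float32", "momentum": "float32", "epsilon": "float32"},
                            ("float16", "float32"), "ApplyRMSProp")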
@ -2332,13 +2321,13 @@ class ApplyCenteredRMSProp(PrimitiveWithInfer):
learning_rate_dtype, rho_dtype, momentum_dtype, epsilon_dtype): learning_rate_dtype, rho_dtype, momentum_dtype, epsilon_dtype):
args = {"var": var_dtype, "mean_gradient": mean_gradient_dtype, args = {"var": var_dtype, "mean_gradient": mean_gradient_dtype,
"mean_square": mean_square_dtype, "moment": moment_dtype, "grad": grad_dtype} "mean_square": mean_square_dtype, "moment": moment_dtype, "grad": grad_dtype}
validator.check_tensor_type_same(args, mstype.number_type, self.name) validator.check_tensors_dtypes_same_and_valid(args, mstype.number_type, self.name)
valid_types = [mstype.float16, mstype.float32] valid_dtypes = [mstype.float16, mstype.float32]
args_rho = {"rho": rho_dtype, 'momentum': momentum_dtype, "epsilon": epsilon_dtype} args_rho = {"rho": rho_dtype, 'momentum': momentum_dtype, "epsilon": epsilon_dtype}
validator.check_type_same(args_rho, valid_types, self.name) validator.check_types_same_and_valid(args_rho, valid_dtypes, self.name)
args_lr = {"learning_rate": learning_rate_dtype, "rho": rho_dtype} args_lr = {"learning_rate": learning_rate_dtype, "rho": rho_dtype}
validator.check_scalar_or_tensor_type_same(args_lr, valid_types, self.name, allow_mix=True) validator.check_scalar_or_tensor_types_same(args_lr, valid_dtypes, self.name, allow_mix=True)
if self.is_ascend: if self.is_ascend:
return var_dtype, mean_gradient_dtype, mean_square_dtype, moment_dtype return var_dtype, mean_gradient_dtype, mean_square_dtype, moment_dtype
return var_dtype return var_dtype
@ -2440,8 +2429,7 @@ class L2Normalize(PrimitiveWithInfer):
return input_x return input_x
def infer_dtype(self, input_x): def infer_dtype(self, input_x):
validator.check_subclass("x", input_x, mstype.tensor, self.name) validator.check_tensor_dtype_valid("input_x", input_x, [mstype.float16, mstype.float32], self.name)
validator.check_tensor_type_same({"input_x": input_x}, [mstype.float16, mstype.float32], self.name)
return input_x return input_x
@ -2527,9 +2515,9 @@ class DropoutDoMask(PrimitiveWithInfer):
raise ValueError(f"DropoutDoMask y mask do not math input input_x shape:" raise ValueError(f"DropoutDoMask y mask do not math input input_x shape:"
"{input_x_shape}, mask shape: {mask_shape}.") "{input_x_shape}, mask shape: {mask_shape}.")
validator.check_tensor_type_same({"input_x": input_x['dtype']}, [mstype.float32, mstype.float16, mstype.int32], validator.check_tensor_dtype_valid("input_x", input_x['dtype'], [mstype.float32, mstype.float16, mstype.int32],
self.name) self.name)
validator.check_tensor_type_same({"input_mask": mask['dtype']}, [mstype.uint8], self.name) validator.check_tensor_dtype_valid("input_mask", mask['dtype'], [mstype.uint8], self.name)
keep_prob_v = keep_prob['value'] keep_prob_v = keep_prob['value']
if keep_prob_v is not None: if keep_prob_v is not None:
@ -2587,7 +2575,8 @@ class ResizeBilinear(PrimitiveWithInfer):
return out_shape return out_shape
def infer_dtype(self, input_dtype): def infer_dtype(self, input_dtype):
validator.check_tensor_type_same({'input_dtype': input_dtype}, [mstype.float16, mstype.float32], self.name) validator.check_tensor_dtype_valid('input_dtype', input_dtype, [mstype.float16, mstype.float32],
self.name)
return mstype.tensor_type(mstype.float32) return mstype.tensor_type(mstype.float32)
@ -2631,10 +2620,10 @@ class OneHot(PrimitiveWithInfer):
def __infer__(self, indices, depth, on_value, off_value): def __infer__(self, indices, depth, on_value, off_value):
# check type # check type
validator.check_tensor_type_same({"indices": indices['dtype']}, (mstype.int32,), self.name) validator.check_tensor_dtype_valid("indices", indices['dtype'], (mstype.int32,), self.name)
validator.check_type_name("depth", depth['dtype'], mstype.int_type, self.name) validator.check_type_name("depth", depth['dtype'], mstype.int_type, self.name)
args = {"on_value": on_value['dtype'], "off_value": off_value['dtype']} args = {"on_value": on_value['dtype'], "off_value": off_value['dtype']}
validator.check_tensor_type_same(args, (mstype.float16, mstype.float32), self.name) validator.check_tensors_dtypes_same_and_valid(args, (mstype.float16, mstype.float32), self.name)
# check shape # check shape
indices_shp = indices['shape'] indices_shp = indices['shape']
@ -2685,7 +2674,7 @@ class Gelu(PrimitiveWithInfer):
return input_x return input_x
def infer_dtype(self, input_x): def infer_dtype(self, input_x):
validator.check_tensor_type_same({"input_x": input_x}, (mstype.float16, mstype.float32), self.name) validator.check_tensor_dtype_valid("input_x", input_x, (mstype.float16, mstype.float32), self.name)
return input_x return input_x
@ -2804,9 +2793,9 @@ class PReLU(PrimitiveWithInfer):
return input_x_shape return input_x_shape
def infer_dtype(self, input_x_dtype, weight_dtype): def infer_dtype(self, input_x_dtype, weight_dtype):
valid_types = (mstype.float16, mstype.float32) valid_dtypes = (mstype.float16, mstype.float32)
validator.check_tensor_type_same({"input_x": input_x_dtype}, valid_types, self.name) validator.check_tensor_dtype_valid("input_x", input_x_dtype, valid_dtypes, self.name)
validator.check_tensor_type_same({"weight": weight_dtype}, valid_types, self.name) validator.check_tensor_dtype_valid("weight", weight_dtype, valid_dtypes, self.name)
return input_x_dtype return input_x_dtype
@ -2877,7 +2866,7 @@ class LSTM(PrimitiveWithInfer):
def infer_dtype(self, x_dtype, h_dtype, c_dtype, w_dtype): def infer_dtype(self, x_dtype, h_dtype, c_dtype, w_dtype):
args = {'x': x_dtype, 'h': h_dtype, 'c': c_dtype, 'w': w_dtype} args = {'x': x_dtype, 'h': h_dtype, 'c': c_dtype, 'w': w_dtype}
validator.check_tensor_type_same(args, (mstype.float32, mstype.float16), self.name) validator.check_tensors_dtypes_same_and_valid(args, (mstype.float32, mstype.float16), self.name)
return (x_dtype, x_dtype, x_dtype, x_dtype, x_dtype) return (x_dtype, x_dtype, x_dtype, x_dtype, x_dtype)
def rnd_up(self, current_offset, page_size): def rnd_up(self, current_offset, page_size):
@ -2930,7 +2919,7 @@ class SigmoidCrossEntropyWithLogits(PrimitiveWithInfer):
def infer_dtype(self, x_dtype, y_dtype): def infer_dtype(self, x_dtype, y_dtype):
args = {"x_dtype": x_dtype, "y_dtype": y_dtype} args = {"x_dtype": x_dtype, "y_dtype": y_dtype}
validator.check_tensor_type_same(args, mstype.number_type, self.name) validator.check_tensors_dtypes_same_and_valid(args, mstype.number_type, self.name)
return x_dtype return x_dtype
@ -3123,9 +3112,9 @@ class ROIAlign(PrimitiveWithInfer):
return [rois_shape[0], inputs_shape[1], self.pooled_height, self.pooled_width] return [rois_shape[0], inputs_shape[1], self.pooled_height, self.pooled_width]
def infer_dtype(self, inputs_type, rois_type): def infer_dtype(self, inputs_type, rois_type):
valid_types = (mstype.float16, mstype.float32) valid_dtypes = (mstype.float16, mstype.float32)
validator.check_tensor_type_same({"inputs_type": inputs_type}, valid_types, self.name) validator.check_tensor_dtype_valid("inputs_type", inputs_type, valid_dtypes, self.name)
validator.check_tensor_type_same({"rois_type": rois_type}, valid_types, self.name) validator.check_tensor_dtype_valid("rois_type", rois_type, valid_dtypes, self.name)
return inputs_type return inputs_type
@ -3199,6 +3188,7 @@ class Adam(PrimitiveWithInfer):
>>> gradient = Tensor(np.random.rand(3, 3, 3).astype(np.float32)) >>> gradient = Tensor(np.random.rand(3, 3, 3).astype(np.float32))
>>> result = net(0.9, 0.999, 0.001, 0.9, 0.999, 1e-8, gradient) >>> result = net(0.9, 0.999, 0.001, 0.9, 0.999, 1e-8, gradient)
""" """
@prim_attr_register @prim_attr_register
def __init__(self, use_locking=False, use_nesterov=False): def __init__(self, use_locking=False, use_nesterov=False):
validator.check_value_type("use_locking", use_locking, [bool], self.name) validator.check_value_type("use_locking", use_locking, [bool], self.name)
@ -3214,11 +3204,11 @@ class Adam(PrimitiveWithInfer):
def infer_dtype(self, var_dtype, m_dtype, v_dtype, beta1_power_dtype, beta2_power_dtype, lr_dtype, def infer_dtype(self, var_dtype, m_dtype, v_dtype, beta1_power_dtype, beta2_power_dtype, lr_dtype,
beta1_dtype, beta2_dtype, epsilon_dtype, grad_dtype): beta1_dtype, beta2_dtype, epsilon_dtype, grad_dtype):
args = {"var": var_dtype, "m": m_dtype, "v": v_dtype, "grad": grad_dtype} args = {"var": var_dtype, "m": m_dtype, "v": v_dtype, "grad": grad_dtype}
validator.check_tensor_type_same(args, mstype.number_type, self.name) validator.check_tensors_dtypes_same_and_valid(args, mstype.number_type, self.name)
args = {"beta1_power": beta1_power_dtype, "beta2_power": beta2_power_dtype, 'lr': lr_dtype, args = {"beta1_power": beta1_power_dtype, "beta2_power": beta2_power_dtype, 'lr': lr_dtype,
"beta1": beta1_dtype, "beta2": beta2_dtype, "epsilon": epsilon_dtype} "beta1": beta1_dtype, "beta2": beta2_dtype, "epsilon": epsilon_dtype}
validator.check_scalar_or_tensor_type_same(args, [mstype.float16, mstype.float32], self.name, True) validator.check_scalar_or_tensor_types_same(args, [mstype.float16, mstype.float32], self.name, True)
return var_dtype, m_dtype, v_dtype return var_dtype, m_dtype, v_dtype
@ -3345,12 +3335,12 @@ class FusedSparseAdam(PrimitiveWithInfer):
def infer_dtype(self, var_dtype, m_dtype, v_dtype, beta1_power_dtype, beta2_power_dtype, lr_dtype, def infer_dtype(self, var_dtype, m_dtype, v_dtype, beta1_power_dtype, beta2_power_dtype, lr_dtype,
beta1_dtype, beta2_dtype, epsilon_dtype, grad_dtype, indices_dtype): beta1_dtype, beta2_dtype, epsilon_dtype, grad_dtype, indices_dtype):
args = {"var": var_dtype, "m": m_dtype, "v": v_dtype, "grad": grad_dtype} args = {"var": var_dtype, "m": m_dtype, "v": v_dtype, "grad": grad_dtype}
validator.check_tensor_type_same(args, mstype.number_type, self.name) validator.check_tensors_dtypes_same_and_valid(args, mstype.number_type, self.name)
args = {"beta1_power": beta1_power_dtype, "beta2_power": beta2_power_dtype, 'lr': lr_dtype, args = {"beta1_power": beta1_power_dtype, "beta2_power": beta2_power_dtype, 'lr': lr_dtype,
"beta1": beta1_dtype, "beta2": beta2_dtype, "epsilon": epsilon_dtype} "beta1": beta1_dtype, "beta2": beta2_dtype, "epsilon": epsilon_dtype}
validator.check_scalar_or_tensor_type_same(args, [mstype.float16, mstype.float32], self.name, True) validator.check_scalar_or_tensor_types_same(args, [mstype.float16, mstype.float32], self.name, True)
validator.check_tensor_type_same({"indices_dtype": indices_dtype}, [mstype.int32], self.name) validator.check_tensor_dtype_valid("indices_dtype", indices_dtype, [mstype.int32], self.name)
return var_dtype, m_dtype, v_dtype return var_dtype, m_dtype, v_dtype
@ -3478,13 +3468,13 @@ class FusedSparseLazyAdam(PrimitiveWithInfer):
def infer_dtype(self, var_dtype, m_dtype, v_dtype, beta1_power_dtype, beta2_power_dtype, lr_dtype, def infer_dtype(self, var_dtype, m_dtype, v_dtype, beta1_power_dtype, beta2_power_dtype, lr_dtype,
beta1_dtype, beta2_dtype, epsilon_dtype, grad_dtype, indices_dtype): beta1_dtype, beta2_dtype, epsilon_dtype, grad_dtype, indices_dtype):
args = {"var": var_dtype, "m": m_dtype, "v": v_dtype, "grad": grad_dtype} args = {"var": var_dtype, "m": m_dtype, "v": v_dtype, "grad": grad_dtype}
validator.check_tensor_type_same(args, mstype.number_type, self.name) validator.check_tensors_dtypes_same_and_valid(args, mstype.number_type, self.name)
args = {"beta1_power": beta1_power_dtype, "beta2_power": beta2_power_dtype, 'lr': lr_dtype, args = {"beta1_power": beta1_power_dtype, "beta2_power": beta2_power_dtype, 'lr': lr_dtype,
"beta1": beta1_dtype, "beta2": beta2_dtype, "epsilon": epsilon_dtype} "beta1": beta1_dtype, "beta2": beta2_dtype, "epsilon": epsilon_dtype}
validator.check_scalar_or_tensor_type_same(args, [mstype.float16, mstype.float32], self.name, True) validator.check_scalar_or_tensor_types_same(args, [mstype.float16, mstype.float32], self.name, True)
validator.check_tensor_type_same({"indices_dtype": indices_dtype}, [mstype.int32], self.name) validator.check_tensor_dtype_valid("indices_dtype", indices_dtype, [mstype.int32], self.name)
return var_dtype, m_dtype, v_dtype return var_dtype, m_dtype, v_dtype
@ -3578,8 +3568,8 @@ class FusedSparseFtrl(PrimitiveWithInfer):
def infer_dtype(self, var_dtype, accum_dtype, linear_dtype, grad_dtype, indices_dtype): def infer_dtype(self, var_dtype, accum_dtype, linear_dtype, grad_dtype, indices_dtype):
args = {"var_dtype": var_dtype, "accum_dtype": accum_dtype, args = {"var_dtype": var_dtype, "accum_dtype": accum_dtype,
"linear_dtype": linear_dtype, "grad_dtype": grad_dtype} "linear_dtype": linear_dtype, "grad_dtype": grad_dtype}
validator.check_tensor_type_same(args, [mstype.float32], self.name) validator.check_tensors_dtypes_same_and_valid(args, [mstype.float32], self.name)
validator.check_tensor_type_same({"indices_dtype": indices_dtype}, [mstype.int32], self.name) validator.check_tensor_dtype_valid("indices_dtype", indices_dtype, [mstype.int32], self.name)
return var_dtype, accum_dtype, linear_dtype return var_dtype, accum_dtype, linear_dtype
@ -3665,13 +3655,13 @@ class FusedSparseProximalAdagrad(PrimitiveWithInfer):
def infer_dtype(self, var_dtype, accum_dtype, lr_dtype, l1_dtype, l2_dtype, grad_dtype, indices_dtype): def infer_dtype(self, var_dtype, accum_dtype, lr_dtype, l1_dtype, l2_dtype, grad_dtype, indices_dtype):
args = {'var': var_dtype, 'accum': accum_dtype, 'grad': grad_dtype} args = {'var': var_dtype, 'accum': accum_dtype, 'grad': grad_dtype}
validator.check_tensor_type_same(args, [mstype.float32], self.name) validator.check_tensors_dtypes_same_and_valid(args, [mstype.float32], self.name)
validator.check_scalar_or_tensor_type_same({"lr": lr_dtype}, [mstype.float32], self.name) validator.check_scalar_or_tensor_types_same({"lr": lr_dtype}, [mstype.float32], self.name)
validator.check_scalar_or_tensor_type_same({"l1": l1_dtype}, [mstype.float32], self.name) validator.check_scalar_or_tensor_types_same({"l1": l1_dtype}, [mstype.float32], self.name)
validator.check_scalar_or_tensor_type_same({"l2": l2_dtype}, [mstype.float32], self.name) validator.check_scalar_or_tensor_types_same({"l2": l2_dtype}, [mstype.float32], self.name)
valid_types = [mstype.int16, mstype.int32, mstype.int64, valid_dtypes = [mstype.int16, mstype.int32, mstype.int64,
mstype.uint16, mstype.uint32, mstype.uint64] mstype.uint16, mstype.uint32, mstype.uint64]
validator.check_tensor_type_same({'indices': indices_dtype}, valid_types, self.name) validator.check_tensor_dtype_valid('indices', indices_dtype, valid_dtypes, self.name)
return var_dtype, accum_dtype return var_dtype, accum_dtype
@ -3742,8 +3732,8 @@ class KLDivLoss(PrimitiveWithInfer):
def infer_dtype(self, x_type, y_type): def infer_dtype(self, x_type, y_type):
args = {'x': x_type, 'y': y_type} args = {'x': x_type, 'y': y_type}
valid_types = (mstype.float16, mstype.float32) valid_dtypes = (mstype.float16, mstype.float32)
validator.check_tensor_type_same(args, valid_types, self.name) validator.check_tensors_dtypes_same_and_valid(args, valid_dtypes, self.name)
return x_type return x_type
@ -3820,10 +3810,10 @@ class BinaryCrossEntropy(PrimitiveWithInfer):
def infer_dtype(self, x_type, y_type, weight_type): def infer_dtype(self, x_type, y_type, weight_type):
args = {'x': x_type, 'y': y_type} args = {'x': x_type, 'y': y_type}
valid_types = (mstype.float16, mstype.float32) valid_dtypes = (mstype.float16, mstype.float32)
validator.check_tensor_type_same(args, valid_types, self.name) validator.check_tensors_dtypes_same_and_valid(args, valid_dtypes, self.name)
if weight_type: if weight_type:
validator.check_tensor_type_same({'x': x_type, 'weight': weight_type}, valid_types, self.name) validator.check_tensors_dtypes_same_and_valid({'x': x_type, 'weight': weight_type}, valid_dtypes, self.name)
return x_type return x_type
@ -3950,14 +3940,14 @@ class ApplyAdaMax(PrimitiveWithInfer):
def infer_dtype(self, var_dtype, m_dtype, v_dtype, beta1_power_dtype, lr_dtype, def infer_dtype(self, var_dtype, m_dtype, v_dtype, beta1_power_dtype, lr_dtype,
beta1_dtype, beta2_dtype, epsilon_dtype, grad_dtype): beta1_dtype, beta2_dtype, epsilon_dtype, grad_dtype):
valid_types = [mstype.float16, mstype.float32] valid_dtypes = [mstype.float16, mstype.float32]
args = {"var": var_dtype, "m": m_dtype, "v": v_dtype, "grad": grad_dtype} args = {"var": var_dtype, "m": m_dtype, "v": v_dtype, "grad": grad_dtype}
validator.check_tensor_type_same(args, valid_types, self.name) validator.check_tensors_dtypes_same_and_valid(args, valid_dtypes, self.name)
validator.check_scalar_or_tensor_type_same({"beta1_power": beta1_power_dtype}, valid_types, self.name) validator.check_scalar_or_tensor_types_same({"beta1_power": beta1_power_dtype}, valid_dtypes, self.name)
validator.check_scalar_or_tensor_type_same({"lr": lr_dtype}, valid_types, self.name) validator.check_scalar_or_tensor_types_same({"lr": lr_dtype}, valid_dtypes, self.name)
validator.check_scalar_or_tensor_type_same({"beta1": beta1_dtype}, valid_types, self.name) validator.check_scalar_or_tensor_types_same({"beta1": beta1_dtype}, valid_dtypes, self.name)
validator.check_scalar_or_tensor_type_same({"beta2": beta2_dtype}, valid_types, self.name) validator.check_scalar_or_tensor_types_same({"beta2": beta2_dtype}, valid_dtypes, self.name)
validator.check_scalar_or_tensor_type_same({"epsilon": epsilon_dtype}, valid_types, self.name) validator.check_scalar_or_tensor_types_same({"epsilon": epsilon_dtype}, valid_dtypes, self.name)
return var_dtype, m_dtype, v_dtype return var_dtype, m_dtype, v_dtype
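ApplyAdaMax validates its hyperparameters with the scalar-or-tensor variant. A hedged guess at what that name implies (assumed behavior, hypothetical helpers only): each value may arrive as a plain dtype or wrapped in a 0-d tensor dtype, so the check normalizes to an element dtype before comparing.

class TensorType:
    # Hypothetical wrapper standing in for a tensor dtype with an element type.
    def __init__(self, element):
        self.element = element

def check_scalar_or_tensor_types_same(args, valid_dtypes, prim_name):
    dtypes = {name: getattr(d, "element", d) for name, d in args.items()}
    if any(d not in valid_dtypes for d in dtypes.values()) or len(set(dtypes.values())) > 1:
        raise TypeError(f"For '{prim_name}', {list(dtypes)} must share one dtype from"
                        f" {valid_dtypes}, but got {dtypes}.")

# A scalar learning rate and a 0-d tensor beta1_power would both be accepted here.
check_scalar_or_tensor_types_same({"lr": "float32", "beta1_power": TensorType("float32")},
                                  ("float16", "float32"), "ApplyAdaMax")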
@ -4058,12 +4048,12 @@ class ApplyAdadelta(PrimitiveWithInfer):
def infer_dtype(self, var_dtype, accum_dtype, accum_update_dtype, lr_dtype, rho_dtype, def infer_dtype(self, var_dtype, accum_dtype, accum_update_dtype, lr_dtype, rho_dtype,
epsilon_dtype, grad_dtype): epsilon_dtype, grad_dtype):
valid_types = [mstype.float16, mstype.float32] valid_dtypes = [mstype.float16, mstype.float32]
args = {"var": var_dtype, "accum": accum_dtype, "accum_update": accum_update_dtype, "grad": grad_dtype} args = {"var": var_dtype, "accum": accum_dtype, "accum_update": accum_update_dtype, "grad": grad_dtype}
validator.check_tensor_type_same(args, valid_types, self.name) validator.check_tensors_dtypes_same_and_valid(args, valid_dtypes, self.name)
validator.check_scalar_or_tensor_type_same({"lr": lr_dtype}, valid_types, self.name) validator.check_scalar_or_tensor_types_same({"lr": lr_dtype}, valid_dtypes, self.name)
validator.check_scalar_or_tensor_type_same({"rho": rho_dtype}, valid_types, self.name) validator.check_scalar_or_tensor_types_same({"rho": rho_dtype}, valid_dtypes, self.name)
validator.check_scalar_or_tensor_type_same({"epsilon": epsilon_dtype}, valid_types, self.name) validator.check_scalar_or_tensor_types_same({"epsilon": epsilon_dtype}, valid_dtypes, self.name)
return var_dtype, accum_dtype, accum_update_dtype return var_dtype, accum_dtype, accum_update_dtype
@ -4142,9 +4132,9 @@ class ApplyAdagrad(PrimitiveWithInfer):
def infer_dtype(self, var_dtype, accum_dtype, lr_dtype, grad_dtype): def infer_dtype(self, var_dtype, accum_dtype, lr_dtype, grad_dtype):
args = {'var': var_dtype, 'accum': accum_dtype, 'grad': grad_dtype} args = {'var': var_dtype, 'accum': accum_dtype, 'grad': grad_dtype}
valid_types = [mstype.float16, mstype.float32] valid_dtypes = [mstype.float16, mstype.float32]
validator.check_tensor_type_same(args, valid_types, self.name) validator.check_tensors_dtypes_same_and_valid(args, valid_dtypes, self.name)
validator.check_scalar_or_tensor_type_same({'lr': lr_dtype}, valid_types, self.name) validator.check_scalar_or_tensor_types_same({'lr': lr_dtype}, valid_dtypes, self.name)
return var_dtype, accum_dtype return var_dtype, accum_dtype
@ -4226,8 +4216,8 @@ class ApplyAdagradV2(PrimitiveWithInfer):
def infer_dtype(self, var_dtype, accum_dtype, lr_dtype, grad_dtype): def infer_dtype(self, var_dtype, accum_dtype, lr_dtype, grad_dtype):
args = {'var': var_dtype, 'accum': accum_dtype, 'grad': grad_dtype} args = {'var': var_dtype, 'accum': accum_dtype, 'grad': grad_dtype}
validator.check_tensor_type_same(args, [mstype.float16, mstype.float32], self.name) validator.check_tensors_dtypes_same_and_valid(args, [mstype.float16, mstype.float32], self.name)
validator.check_scalar_or_tensor_type_same({'lr': lr_dtype}, [mstype.float16, mstype.float32], self.name) validator.check_scalar_or_tensor_types_same({'lr': lr_dtype}, [mstype.float16, mstype.float32], self.name)
return var_dtype, accum_dtype return var_dtype, accum_dtype
@ -4313,8 +4303,8 @@ class SparseApplyAdagrad(PrimitiveWithInfer):
def infer_dtype(self, var_type, accum_type, grad_type, indices_type): def infer_dtype(self, var_type, accum_type, grad_type, indices_type):
args = {'var': var_type, 'accum': accum_type, 'grad': grad_type} args = {'var': var_type, 'accum': accum_type, 'grad': grad_type}
validator.check_tensor_type_same(args, [mstype.float16, mstype.float32], self.name) validator.check_tensors_dtypes_same_and_valid(args, [mstype.float16, mstype.float32], self.name)
validator.check_tensor_type_same({'indices': indices_type}, [mstype.int32], self.name) validator.check_tensor_dtype_valid('indices', indices_type, [mstype.int32], self.name)
return var_type, accum_type return var_type, accum_type
@ -4402,8 +4392,8 @@ class SparseApplyAdagradV2(PrimitiveWithInfer):
def infer_dtype(self, var_type, accum_type, grad_type, indices_type): def infer_dtype(self, var_type, accum_type, grad_type, indices_type):
args = {'var': var_type, 'accum': accum_type, 'grad': grad_type} args = {'var': var_type, 'accum': accum_type, 'grad': grad_type}
validator.check_tensor_type_same(args, [mstype.float16, mstype.float32], self.name) validator.check_tensors_dtypes_same_and_valid(args, [mstype.float16, mstype.float32], self.name)
validator.check_tensor_type_same({'indices': indices_type}, [mstype.int32], self.name) validator.check_tensor_dtype_valid('indices', indices_type, [mstype.int32], self.name)
return var_type, accum_type return var_type, accum_type
@ -4500,12 +4490,12 @@ class ApplyProximalAdagrad(PrimitiveWithInfer):
return var_shape, accum_shape return var_shape, accum_shape
def infer_dtype(self, var_dtype, accum_dtype, lr_dtype, l1_dtype, l2_dtype, grad_dtype): def infer_dtype(self, var_dtype, accum_dtype, lr_dtype, l1_dtype, l2_dtype, grad_dtype):
valid_types = [mstype.float16, mstype.float32] valid_dtypes = [mstype.float16, mstype.float32]
args = {'var': var_dtype, 'accum': accum_dtype, 'grad': grad_dtype} args = {'var': var_dtype, 'accum': accum_dtype, 'grad': grad_dtype}
validator.check_tensor_type_same(args, valid_types, self.name) validator.check_tensors_dtypes_same_and_valid(args, valid_dtypes, self.name)
validator.check_scalar_or_tensor_type_same({"lr": lr_dtype}, valid_types, self.name) validator.check_scalar_or_tensor_types_same({"lr": lr_dtype}, valid_dtypes, self.name)
validator.check_scalar_or_tensor_type_same({"l1": l1_dtype}, valid_types, self.name) validator.check_scalar_or_tensor_types_same({"l1": l1_dtype}, valid_dtypes, self.name)
validator.check_scalar_or_tensor_type_same({"l2": l2_dtype}, valid_types, self.name) validator.check_scalar_or_tensor_types_same({"l2": l2_dtype}, valid_dtypes, self.name)
return var_dtype, accum_dtype return var_dtype, accum_dtype
@ -4594,13 +4584,13 @@ class SparseApplyProximalAdagrad(PrimitiveWithCheck):
def check_dtype(self, var_dtype, accum_dtype, lr_dtype, l1_dtype, l2_dtype, grad_dtype, indices_dtype): def check_dtype(self, var_dtype, accum_dtype, lr_dtype, l1_dtype, l2_dtype, grad_dtype, indices_dtype):
args = {'var': var_dtype, 'accum': accum_dtype, 'grad': grad_dtype} args = {'var': var_dtype, 'accum': accum_dtype, 'grad': grad_dtype}
validator.check_tensor_type_same(args, [mstype.float16, mstype.float32], self.name) validator.check_tensors_dtypes_same_and_valid(args, [mstype.float16, mstype.float32], self.name)
validator.check_scalar_or_tensor_type_same({"lr": lr_dtype}, [mstype.float16, mstype.float32], self.name) validator.check_scalar_or_tensor_types_same({"lr": lr_dtype}, [mstype.float16, mstype.float32], self.name)
validator.check_scalar_or_tensor_type_same({"l1": l1_dtype}, [mstype.float16, mstype.float32], self.name) validator.check_scalar_or_tensor_types_same({"l1": l1_dtype}, [mstype.float16, mstype.float32], self.name)
validator.check_scalar_or_tensor_type_same({"l2": l2_dtype}, [mstype.float16, mstype.float32], self.name) validator.check_scalar_or_tensor_types_same({"l2": l2_dtype}, [mstype.float16, mstype.float32], self.name)
valid_types = [mstype.int16, mstype.int32, mstype.int64, valid_dtypes = [mstype.int16, mstype.int32, mstype.int64,
mstype.uint16, mstype.uint32, mstype.uint64] mstype.uint16, mstype.uint32, mstype.uint64]
validator.check_tensor_type_same({'indices': indices_dtype}, valid_types, self.name) validator.check_tensor_dtype_valid('indices', indices_dtype, valid_dtypes, self.name)
class ApplyAddSign(PrimitiveWithInfer): class ApplyAddSign(PrimitiveWithInfer):
@ -4699,13 +4689,13 @@ class ApplyAddSign(PrimitiveWithInfer):
return var_shape, m_shape return var_shape, m_shape
def infer_dtype(self, var_dtype, m_dtype, lr_dtype, alpha_dtype, sign_decay_dtype, beta_dtype, grad_dtype): def infer_dtype(self, var_dtype, m_dtype, lr_dtype, alpha_dtype, sign_decay_dtype, beta_dtype, grad_dtype):
valid_types = [mstype.float16, mstype.float32] valid_dtypes = [mstype.float16, mstype.float32]
args = {'var': var_dtype, 'm': m_dtype, 'grad': grad_dtype} args = {'var': var_dtype, 'm': m_dtype, 'grad': grad_dtype}
validator.check_tensor_type_same(args, valid_types, self.name) validator.check_tensors_dtypes_same_and_valid(args, valid_dtypes, self.name)
validator.check_scalar_or_tensor_type_same({"lr": lr_dtype}, valid_types, self.name) validator.check_scalar_or_tensor_types_same({"lr": lr_dtype}, valid_dtypes, self.name)
validator.check_scalar_or_tensor_type_same({"alpha": alpha_dtype}, valid_types, self.name) validator.check_scalar_or_tensor_types_same({"alpha": alpha_dtype}, valid_dtypes, self.name)
validator.check_scalar_or_tensor_type_same({"sign_decay": sign_decay_dtype}, valid_types, self.name) validator.check_scalar_or_tensor_types_same({"sign_decay": sign_decay_dtype}, valid_dtypes, self.name)
validator.check_scalar_or_tensor_type_same({"beta": beta_dtype}, valid_types, self.name) validator.check_scalar_or_tensor_types_same({"beta": beta_dtype}, valid_dtypes, self.name)
return var_dtype, m_dtype return var_dtype, m_dtype
@ -4808,13 +4798,13 @@ class ApplyPowerSign(PrimitiveWithInfer):
return var_shape, m_shape return var_shape, m_shape
def infer_dtype(self, var_dtype, m_dtype, lr_dtype, logbase_dtype, sign_decay_dtype, beta_dtype, grad_dtype): def infer_dtype(self, var_dtype, m_dtype, lr_dtype, logbase_dtype, sign_decay_dtype, beta_dtype, grad_dtype):
valid_types = [mstype.float16, mstype.float32] valid_dtypes = [mstype.float16, mstype.float32]
args = {'var': var_dtype, 'm': m_dtype, 'grad': grad_dtype} args = {'var': var_dtype, 'm': m_dtype, 'grad': grad_dtype}
validator.check_tensor_type_same(args, valid_types, self.name) validator.check_tensors_dtypes_same_and_valid(args, valid_dtypes, self.name)
validator.check_scalar_or_tensor_type_same({"lr": lr_dtype}, valid_types, self.name) validator.check_scalar_or_tensor_types_same({"lr": lr_dtype}, valid_dtypes, self.name)
validator.check_scalar_or_tensor_type_same({"logbase": logbase_dtype}, valid_types, self.name) validator.check_scalar_or_tensor_types_same({"logbase": logbase_dtype}, valid_dtypes, self.name)
validator.check_scalar_or_tensor_type_same({"sign_decay": sign_decay_dtype}, valid_types, self.name) validator.check_scalar_or_tensor_types_same({"sign_decay": sign_decay_dtype}, valid_dtypes, self.name)
validator.check_scalar_or_tensor_type_same({"beta": beta_dtype}, valid_types, self.name) validator.check_scalar_or_tensor_types_same({"beta": beta_dtype}, valid_dtypes, self.name)
return var_dtype, m_dtype return var_dtype, m_dtype
@ -4876,10 +4866,10 @@ class ApplyGradientDescent(PrimitiveWithInfer):
return var_shape return var_shape
def infer_dtype(self, var_dtype, alpha_dtype, delta_dtype): def infer_dtype(self, var_dtype, alpha_dtype, delta_dtype):
valid_types = [mstype.float16, mstype.float32] valid_dtypes = [mstype.float16, mstype.float32]
args = {'var': var_dtype, 'delta': delta_dtype} args = {'var': var_dtype, 'delta': delta_dtype}
validator.check_tensor_type_same(args, valid_types, self.name) validator.check_tensors_dtypes_same_and_valid(args, valid_dtypes, self.name)
validator.check_scalar_or_tensor_type_same({"alpha": alpha_dtype}, valid_types, self.name) validator.check_scalar_or_tensor_types_same({"alpha": alpha_dtype}, valid_dtypes, self.name)
return var_dtype return var_dtype
@ -4959,12 +4949,12 @@ class ApplyProximalGradientDescent(PrimitiveWithInfer):
return var_shape return var_shape
def infer_dtype(self, var_dtype, alpha_dtype, l1_dtype, l2_dtype, delta_dtype): def infer_dtype(self, var_dtype, alpha_dtype, l1_dtype, l2_dtype, delta_dtype):
valid_types = [mstype.float16, mstype.float32] valid_dtypes = [mstype.float16, mstype.float32]
args = {'var': var_dtype, 'delta': delta_dtype} args = {'var': var_dtype, 'delta': delta_dtype}
validator.check_tensor_type_same(args, valid_types, self.name) validator.check_tensors_dtypes_same_and_valid(args, valid_dtypes, self.name)
validator.check_scalar_or_tensor_type_same({"alpha": alpha_dtype}, valid_types, self.name) validator.check_scalar_or_tensor_types_same({"alpha": alpha_dtype}, valid_dtypes, self.name)
validator.check_scalar_or_tensor_type_same({"l1": l1_dtype}, valid_types, self.name) validator.check_scalar_or_tensor_types_same({"l1": l1_dtype}, valid_dtypes, self.name)
validator.check_scalar_or_tensor_type_same({"l2": l2_dtype}, valid_types, self.name) validator.check_scalar_or_tensor_types_same({"l2": l2_dtype}, valid_dtypes, self.name)
return var_dtype return var_dtype
@ -5036,11 +5026,13 @@ class LARSUpdate(PrimitiveWithInfer):
weight_decay_dtype, learning_rate_dtype): weight_decay_dtype, learning_rate_dtype):
args = {"Weight dtype": weight_dtype, "gradient dtype": gradient_dtype, "norm weight dtype": norm_weight_dtype, args = {"Weight dtype": weight_dtype, "gradient dtype": gradient_dtype, "norm weight dtype": norm_weight_dtype,
"norm gradient dtype": norm_gradient_dtype} "norm gradient dtype": norm_gradient_dtype}
validator.check_tensor_type_same(args, [mstype.float16, mstype.float32, mstype.int16, mstype.int32], self.name) validator.check_tensors_dtypes_same_and_valid(args,
validator.check_scalar_or_tensor_type_same({"weight_decay": weight_decay_dtype}, [mstype.float16, mstype.float32, mstype.int16, mstype.int32],
[mstype.float16, mstype.float32, mstype.float64], self.name) self.name)
validator.check_scalar_or_tensor_type_same({"learning_rate": learning_rate_dtype}, validator.check_scalar_or_tensor_types_same({"weight_decay": weight_decay_dtype},
[mstype.float16, mstype.float32, mstype.float64], self.name) [mstype.float16, mstype.float32, mstype.float64], self.name)
validator.check_scalar_or_tensor_types_same({"learning_rate": learning_rate_dtype},
[mstype.float16, mstype.float32, mstype.float64], self.name)
return weight_dtype return weight_dtype
@ -5117,14 +5109,14 @@ class ApplyFtrl(PrimitiveWithInfer):
return var_shape return var_shape
def infer_dtype(self, var_type, accum_type, linear_type, grad_type, lr_type, l1_type, l2_type, lr_power_type): def infer_dtype(self, var_type, accum_type, linear_type, grad_type, lr_type, l1_type, l2_type, lr_power_type):
valid_types = [mstype.float16, mstype.float32] valid_dtypes = [mstype.float16, mstype.float32]
args = {'var': var_type, 'accum': accum_type, 'linear': linear_type, 'grad': grad_type} args = {'var': var_type, 'accum': accum_type, 'linear': linear_type, 'grad': grad_type}
validator.check_tensor_type_same(args, valid_types, self.name) validator.check_tensors_dtypes_same_and_valid(args, valid_dtypes, self.name)
validator.check_scalar_or_tensor_type_same({"lr": lr_type}, valid_types, self.name) validator.check_scalar_or_tensor_types_same({"lr": lr_type}, valid_dtypes, self.name)
validator.check_scalar_or_tensor_type_same({"l1": l1_type}, valid_types, self.name) validator.check_scalar_or_tensor_types_same({"l1": l1_type}, valid_dtypes, self.name)
validator.check_scalar_or_tensor_type_same({"l2": l2_type}, valid_types, self.name) validator.check_scalar_or_tensor_types_same({"l2": l2_type}, valid_dtypes, self.name)
validator.check_scalar_or_tensor_type_same({"lr_power": lr_power_type}, valid_types, self.name) validator.check_scalar_or_tensor_types_same({"lr_power": lr_power_type}, valid_dtypes, self.name)
if self.is_tbe: if self.is_tbe:
return var_type, var_type, var_type return var_type, var_type, var_type
return var_type return var_type
@ -5219,8 +5211,8 @@ class SparseApplyFtrl(PrimitiveWithCheck):
def check_dtype(self, var_dtype, accum_dtype, linear_dtype, grad_dtype, indices_dtype): def check_dtype(self, var_dtype, accum_dtype, linear_dtype, grad_dtype, indices_dtype):
args = {"var_dtype": var_dtype, "accum_dtype": accum_dtype, args = {"var_dtype": var_dtype, "accum_dtype": accum_dtype,
"linear_dtype": linear_dtype, "grad_dtype": grad_dtype} "linear_dtype": linear_dtype, "grad_dtype": grad_dtype}
validator.check_tensor_type_same(args, [mstype.float16, mstype.float32], self.name) validator.check_tensors_dtypes_same_and_valid(args, [mstype.float16, mstype.float32], self.name)
validator.check_tensor_type_same({"indices_dtype": indices_dtype}, [mstype.int32, mstype.int64], self.name) validator.check_tensor_dtype_valid("indices_dtype", indices_dtype, [mstype.int32, mstype.int64], self.name)
class SparseApplyFtrlV2(PrimitiveWithInfer): class SparseApplyFtrlV2(PrimitiveWithInfer):
@ -5316,8 +5308,8 @@ class SparseApplyFtrlV2(PrimitiveWithInfer):
def infer_dtype(self, var_dtype, accum_dtype, linear_dtype, grad_dtype, indices_dtype): def infer_dtype(self, var_dtype, accum_dtype, linear_dtype, grad_dtype, indices_dtype):
args = {"var_dtype": var_dtype, "accum_dtype": accum_dtype, args = {"var_dtype": var_dtype, "accum_dtype": accum_dtype,
"linear_dtype": linear_dtype, "grad_dtype": grad_dtype} "linear_dtype": linear_dtype, "grad_dtype": grad_dtype}
validator.check_tensor_type_same(args, [mstype.float16, mstype.float32], self.name) validator.check_tensors_dtypes_same_and_valid(args, [mstype.float16, mstype.float32], self.name)
validator.check_tensor_type_same({"indices_dtype": indices_dtype}, [mstype.int32], self.name) validator.check_tensor_dtype_valid("indicese", indices_dtype, [mstype.int32], self.name)
return var_dtype, accum_dtype, linear_dtype return var_dtype, accum_dtype, linear_dtype
@ -5351,9 +5343,8 @@ class Dropout(PrimitiveWithInfer):
return x_shape, mask_shape return x_shape, mask_shape
def infer_dtype(self, x_dtype): def infer_dtype(self, x_dtype):
valid_types = (mstype.float16, mstype.float32) valid_dtypes = (mstype.float16, mstype.float32)
validator.check_subclass("x", x_dtype, mstype.tensor, self.name) validator.check_tensor_dtype_valid("x", x_dtype, valid_dtypes, self.name)
validator.check_tensor_type_same({"x_dtype": x_dtype}, valid_types, self.name)
return x_dtype, x_dtype return x_dtype, x_dtype
@ -5425,10 +5416,10 @@ class CTCLoss(PrimitiveWithInfer):
def infer_dtype(self, inputs, labels_indices, labels_values, sequence_length): def infer_dtype(self, inputs, labels_indices, labels_values, sequence_length):
valid_dtype = [mstype.float16, mstype.float32, mstype.double] valid_dtype = [mstype.float16, mstype.float32, mstype.double]
validator.check_tensor_type_same({"inputs_dtype": inputs}, valid_dtype, self.name) validator.check_tensor_dtype_valid("inputs", inputs, valid_dtype, self.name)
validator.check_tensor_type_same({"labels_indices_dtype": labels_indices}, [mstype.int64], self.name) validator.check_tensor_dtype_valid("labels_indices", labels_indices, [mstype.int64], self.name)
validator.check_tensor_type_same({"labels_values_dtype": labels_values}, [mstype.int32], self.name) validator.check_tensor_dtype_valid("labels_values", labels_values, [mstype.int32], self.name)
validator.check_tensor_type_same({"sequence_length_dtype": sequence_length}, [mstype.int32], self.name) validator.check_tensor_dtype_valid("sequence_length", sequence_length, [mstype.int32], self.name)
return inputs, inputs return inputs, inputs
@ -5492,8 +5483,8 @@ class CTCGreedyDecoder(PrimitiveWithInfer):
return decoded_indices_shape, decoded_values, decoded_shape, log_probability_shape return decoded_indices_shape, decoded_values, decoded_shape, log_probability_shape
def infer_dtype(self, inputs_dtype, sequence_length_dtype): def infer_dtype(self, inputs_dtype, sequence_length_dtype):
validator.check_tensor_type_same({"inputs_dtype": inputs_dtype}, [mstype.float32, mstype.double], self.name) validator.check_tensor_dtype_valid("inputs_dtype", inputs_dtype, [mstype.float32, mstype.double], self.name)
validator.check_tensor_type_same({"sequence_length_dtype": sequence_length_dtype}, [mstype.int32], self.name) validator.check_tensor_dtype_valid("sequence_length_dtype", sequence_length_dtype, [mstype.int32], self.name)
decoded_type = mstype.tensor_type(mstype.int64) decoded_type = mstype.tensor_type(mstype.int64)
return decoded_type, decoded_type, decoded_type, inputs_dtype return decoded_type, decoded_type, decoded_type, inputs_dtype
@ -5597,12 +5588,12 @@ class BasicLSTMCell(PrimitiveWithInfer):
return (ct_shape, ht_shape, it_shape, jt_shape, ft_shape, ot_shape, tanhct_shape) return (ct_shape, ht_shape, it_shape, jt_shape, ft_shape, ot_shape, tanhct_shape)
def infer_dtype(self, x_dtype, h_dtype, c_dtype, w_dtype, b_dtype): def infer_dtype(self, x_dtype, h_dtype, c_dtype, w_dtype, b_dtype):
validator.check_tensor_type_same({"x_dtype": x_dtype}, [mstype.float16, mstype.float32], self.name) tuple(map(partial(validator.check_tensor_dtype_valid,
validator.check_tensor_type_same({"h_dtype": h_dtype}, [mstype.float16, mstype.float32], self.name) valid_dtypes=(mstype.float16, mstype.float32), prim_name=self.name),
validator.check_tensor_type_same({"w_dtype": w_dtype}, [mstype.float16, mstype.float32], self.name) ("x_dtype", "h_dtype", "w_dtype"),
(x_dtype, h_dtype, w_dtype)))
args = {"c_dtype": c_dtype, "b_dtype": b_dtype} args = {"c_dtype": c_dtype, "b_dtype": b_dtype}
validator.check_tensor_type_same(args, [mstype.float16, mstype.float32], self.name) validator.check_tensors_dtypes_same_and_valid(args, [mstype.float16, mstype.float32], self.name)
return (c_dtype, mstype.float16, c_dtype, c_dtype, c_dtype, c_dtype, c_dtype) return (c_dtype, mstype.float16, c_dtype, c_dtype, c_dtype, c_dtype, c_dtype)
@ -5725,11 +5716,10 @@ class DynamicRNN(PrimitiveWithInfer):
return y_shape, y_shape, y_shape, y_shape, y_shape, y_shape, y_shape, y_shape return y_shape, y_shape, y_shape, y_shape, y_shape, y_shape, y_shape, y_shape
def infer_dtype(self, x_dtype, w_dtype, b_dtype, seq_dtype, h_dtype, c_dtype): def infer_dtype(self, x_dtype, w_dtype, b_dtype, seq_dtype, h_dtype, c_dtype):
validator.check_tensor_type_same({"x dtype": x_dtype}, (mstype.float16,), self.name) tuple(map(partial(validator.check_tensor_dtype_valid, valid_dtypes=mstype.float16, prim_name=self.name),
validator.check_tensor_type_same({"w dtype": w_dtype}, (mstype.float16,), self.name) ("x", "w", "h", "c"),
validator.check_tensor_type_same({"b dtype": b_dtype}, (mstype.float32, mstype.float16), self.name) (x_dtype, w_dtype, h_dtype, c_dtype)))
validator.check_tensor_type_same({"h dtype": h_dtype}, (mstype.float16,), self.name) validator.check_tensor_dtype_valid("b", b_dtype, (mstype.float16, mstype.float32), self.name)
validator.check_tensor_type_same({"c dtype": c_dtype}, (mstype.float16,), self.name)
return b_dtype, x_dtype, b_dtype, b_dtype, b_dtype, b_dtype, b_dtype, b_dtype return b_dtype, x_dtype, b_dtype, b_dtype, b_dtype, b_dtype, b_dtype, b_dtype
@ -5765,8 +5755,8 @@ class InTopK(PrimitiveWithInfer):
validator.check_value_type("k", k, [int], self.name) validator.check_value_type("k", k, [int], self.name)
def infer_dtype(self, x1_dtype, x2_dtype): def infer_dtype(self, x1_dtype, x2_dtype):
validator.check_tensor_type_same({"x1": x1_dtype}, (mstype.float16, mstype.float32,), self.name) validator.check_tensor_dtype_valid("x1", x1_dtype, (mstype.float16, mstype.float32,), self.name)
validator.check_tensor_type_same({"x2": x2_dtype}, (mstype.int32,), self.name) validator.check_tensor_dtype_valid("x2", x2_dtype, (mstype.int32,), self.name)
return mstype.tensor_type(mstype.bool_) return mstype.tensor_type(mstype.bool_)
@ -5803,6 +5793,7 @@ class LRN(PrimitiveWithInfer):
[[0.6258911 0.4964315 ] [[0.6258911 0.4964315 ]
[0.3141494 0.43636137]]]] [0.3141494 0.43636137]]]]
""" """
@prim_attr_register @prim_attr_register
def __init__(self, depth_radius=5, bias=1.0, alpha=1.0, beta=0.5, norm_region="ACROSS_CHANNELS"): def __init__(self, depth_radius=5, bias=1.0, alpha=1.0, beta=0.5, norm_region="ACROSS_CHANNELS"):
"""Initialize LRN""" """Initialize LRN"""
@ -5816,7 +5807,7 @@ class LRN(PrimitiveWithInfer):
validator.check_non_negative_int(depth_radius, "depth_radius", self.name) validator.check_non_negative_int(depth_radius, "depth_radius", self.name)
def infer_dtype(self, x_dtype): def infer_dtype(self, x_dtype):
validator.check_tensor_type_same({"x": x_dtype}, (mstype.float16, mstype.float32,), self.name) validator.check_tensor_dtype_valid("x", x_dtype, (mstype.float16, mstype.float32,), self.name)
return x_dtype return x_dtype
def infer_shape(self, x_shape): def infer_shape(self, x_shape):
@ -5857,6 +5848,7 @@ class UniformSampler(PrimitiveWithInfer):
[3]], dtype=np.int32))) [3]], dtype=np.int32)))
[1, 1, 3], [[0.75], [0.75], [0.75], [0.75], [0.75]], [0.75, 0.75, 0.75] [1, 1, 3], [[0.75], [0.75], [0.75], [0.75], [0.75]], [0.75, 0.75, 0.75]
""" """
@prim_attr_register @prim_attr_register
def __init__(self, num_true, num_sampled, unique, range_max, seed=0, remove_accidental_hits=False): def __init__(self, num_true, num_sampled, unique, range_max, seed=0, remove_accidental_hits=False):
"""Initialize UniformSampler""" """Initialize UniformSampler"""

View File

@ -61,8 +61,8 @@ class Assign(PrimitiveWithCheck):
def check_dtype(self, variable, value): def check_dtype(self, variable, value):
if variable != mstype.type_refkey: if variable != mstype.type_refkey:
validator.check_tensor_type_same({"variable": variable}, mstype.number_type, self.name) validator.check_tensor_dtype_valid("variable", variable, mstype.number_type, self.name)
validator.check_scalar_or_tensor_type_same({"value": value}, mstype.number_type, self.name) validator.check_scalar_or_tensor_types_same({"value": value}, mstype.number_type, self.name)
class BoundingBoxEncode(PrimitiveWithInfer): class BoundingBoxEncode(PrimitiveWithInfer):
@ -112,7 +112,7 @@ class BoundingBoxEncode(PrimitiveWithInfer):
def infer_dtype(self, anchor_box, groundtruth_box): def infer_dtype(self, anchor_box, groundtruth_box):
args = {"anchor_box": anchor_box, "groundtruth_box": groundtruth_box} args = {"anchor_box": anchor_box, "groundtruth_box": groundtruth_box}
validator.check_tensor_type_same(args, mstype.number_type, self.name) validator.check_tensors_dtypes_same_and_valid(args, mstype.number_type, self.name)
return anchor_box return anchor_box
@ -169,7 +169,7 @@ class BoundingBoxDecode(PrimitiveWithInfer):
def infer_dtype(self, anchor_box, deltas): def infer_dtype(self, anchor_box, deltas):
args = {"anchor_box": anchor_box, "deltas": deltas} args = {"anchor_box": anchor_box, "deltas": deltas}
validator.check_tensor_type_same(args, mstype.number_type, self.name) validator.check_tensors_dtypes_same_and_valid(args, mstype.number_type, self.name)
return anchor_box return anchor_box
@ -221,8 +221,8 @@ class CheckValid(PrimitiveWithInfer):
def infer_dtype(self, bboxes_type, metas_type): def infer_dtype(self, bboxes_type, metas_type):
valid_type = [mstype.float32, mstype.float16, mstype.int16, mstype.uint8] valid_type = [mstype.float32, mstype.float16, mstype.int16, mstype.uint8]
validator.check_tensor_type_same({"bboxes_type": bboxes_type}, valid_type, self.name) validator.check_tensor_dtype_valid("bboxes_type", bboxes_type, valid_type, self.name)
validator.check_tensor_type_same({"metas_type": metas_type}, valid_type, self.name) validator.check_tensor_dtype_valid("metas_type", metas_type, valid_type, self.name)
return mstype.bool_ return mstype.bool_
@ -281,8 +281,8 @@ class IOU(PrimitiveWithInfer):
def infer_dtype(self, anchor_boxes, gt_boxes): def infer_dtype(self, anchor_boxes, gt_boxes):
valid_type = [mstype.float32, mstype.float16] valid_type = [mstype.float32, mstype.float16]
validator.check_tensor_type_same({"anchor_boxes": anchor_boxes}, valid_type, self.name) validator.check_tensor_dtype_valid("anchor_boxes", anchor_boxes, valid_type, self.name)
validator.check_tensor_type_same({"gt_boxes": gt_boxes}, valid_type, self.name) validator.check_tensor_dtype_valid("gt_boxes", gt_boxes, valid_type, self.name)
return anchor_boxes return anchor_boxes
@ -478,7 +478,7 @@ class ConfusionMatrix(PrimitiveWithInfer):
if weights is not None: if weights is not None:
validator.check_subclass('weights', weights, mstype.tensor, self.name) validator.check_subclass('weights', weights, mstype.tensor, self.name)
args = {"labels": labels, "predictions": predictions} args = {"labels": labels, "predictions": predictions}
validator.check_tensor_type_same(args, (mstype.number_type), self.name) validator.check_tensors_dtypes_same_and_valid(args, (mstype.number_type), self.name)
return labels return labels
@ -506,8 +506,7 @@ class PopulationCount(PrimitiveWithInfer):
return x_shape return x_shape
def infer_dtype(self, x_dtype): def infer_dtype(self, x_dtype):
args = {"x": x_dtype} validator.check_tensor_dtype_valid("x", x_dtype, (mstype.int16, mstype.uint16,), self.name)
validator.check_tensor_type_same(args, (mstype.int16, mstype.uint16,), self.name)
return mstype.tensor_type(mstype.uint8) return mstype.tensor_type(mstype.uint8)
class Push(PrimitiveWithInfer): class Push(PrimitiveWithInfer):

View File

@ -151,8 +151,8 @@ class Gamma(PrimitiveWithInfer):
Validator.check_value_type("shape", shape_v, [tuple], self.name) Validator.check_value_type("shape", shape_v, [tuple], self.name)
for i, shape_i in enumerate(shape_v): for i, shape_i in enumerate(shape_v):
Validator.check_positive_int(shape_i, f'shape[{i}]', self.name) Validator.check_positive_int(shape_i, f'shape[{i}]', self.name)
Validator.check_tensor_type_same({"alpha": alpha["dtype"]}, [mstype.float32], self.name) Validator.check_tensor_dtype_valid("alpha", alpha["dtype"], [mstype.float32], self.name)
Validator.check_tensor_type_same({"beta": beta["dtype"]}, [mstype.float32], self.name) Validator.check_tensor_dtype_valid("beta", beta["dtype"], [mstype.float32], self.name)
broadcast_shape = get_broadcast_shape(alpha['shape'], beta['shape'], self.name) broadcast_shape = get_broadcast_shape(alpha['shape'], beta['shape'], self.name)
broadcast_shape = get_broadcast_shape(broadcast_shape, shape_v, self.name) broadcast_shape = get_broadcast_shape(broadcast_shape, shape_v, self.name)
out = { out = {
@ -203,7 +203,7 @@ class Poisson(PrimitiveWithInfer):
Validator.check_value_type("shape", shape_v, [tuple], self.name) Validator.check_value_type("shape", shape_v, [tuple], self.name)
for i, shape_i in enumerate(shape_v): for i, shape_i in enumerate(shape_v):
Validator.check_positive_int(shape_i, f'shape[{i}]', self.name) Validator.check_positive_int(shape_i, f'shape[{i}]', self.name)
Validator.check_tensor_type_same({"mean": mean["dtype"]}, [mstype.float32], self.name) Validator.check_tensor_dtype_valid("mean", mean["dtype"], [mstype.float32], self.name)
broadcast_shape = get_broadcast_shape(mean['shape'], shape_v, self.name) broadcast_shape = get_broadcast_shape(mean['shape'], shape_v, self.name)
out = { out = {
'shape': broadcast_shape, 'shape': broadcast_shape,
@ -259,8 +259,8 @@ class UniformInt(PrimitiveWithInfer):
Validator.check_value_type("shape", shape_v, [tuple], self.name) Validator.check_value_type("shape", shape_v, [tuple], self.name)
for i, shape_i in enumerate(shape_v): for i, shape_i in enumerate(shape_v):
Validator.check_positive_int(shape_i, f'shape[{i}]', self.name) Validator.check_positive_int(shape_i, f'shape[{i}]', self.name)
Validator.check_tensor_type_same({"minval": minval["dtype"]}, [mstype.int32], self.name) Validator.check_tensor_dtype_valid("minval", minval["dtype"], [mstype.int32], self.name)
Validator.check_tensor_type_same({"maxval": maxval["dtype"]}, [mstype.int32], self.name) Validator.check_tensor_dtype_valid("maxval", maxval["dtype"], [mstype.int32], self.name)
minval_shape = minval['shape'] minval_shape = minval['shape']
maxval_shape = maxval['shape'] maxval_shape = maxval['shape']
Validator.check("dim of minval", len(minval_shape), '0(scalar)', 0, Rel.EQ, self.name) Validator.check("dim of minval", len(minval_shape), '0(scalar)', 0, Rel.EQ, self.name)
@ -361,7 +361,7 @@ class RandomChoiceWithMask(PrimitiveWithInfer):
return ([self.count, len(x_shape)], [self.count]) return ([self.count, len(x_shape)], [self.count])
def infer_dtype(self, x_dtype): def infer_dtype(self, x_dtype):
Validator.check_tensor_type_same({'x': x_dtype}, [mstype.bool_], self.name) Validator.check_tensor_dtype_valid('x', x_dtype, [mstype.bool_], self.name)
return (mstype.int32, mstype.bool_) return (mstype.int32, mstype.bool_)
@ -407,8 +407,8 @@ class RandomCategorical(PrimitiveWithInfer):
def __infer__(self, logits, num_samples, seed): def __infer__(self, logits, num_samples, seed):
logits_dtype = logits['dtype'] logits_dtype = logits['dtype']
valid_types = (mstype.float32, mstype.float16, mstype.float64) valid_dtypes = (mstype.float32, mstype.float16, mstype.float64)
Validator.check_tensor_type_same({'logits': logits_dtype}, valid_types, self.name) Validator.check_tensor_dtype_valid('logits', logits_dtype, valid_dtypes, self.name)
num_samples_v = num_samples['value'] num_samples_v = num_samples['value']
seed_v = seed['value'] seed_v = seed['value']
Validator.check_value_type('num_samples', num_samples_v, (int,), self.name) Validator.check_value_type('num_samples', num_samples_v, (int,), self.name)
@ -460,7 +460,7 @@ class Multinomial(PrimitiveWithInfer):
input_shape = inputs["shape"] input_shape = inputs["shape"]
if len(input_shape) != 1 and len(input_shape) != 2: if len(input_shape) != 1 and len(input_shape) != 2:
raise ValueError("input dim must be 1 or 2") raise ValueError("input dim must be 1 or 2")
Validator.check_tensor_type_same({'inputs': inputs['dtype']}, [mstype.float32], self.name) Validator.check_tensor_dtype_valid('inputs', inputs['dtype'], [mstype.float32], self.name)
num_samples_value = num_samples["value"] num_samples_value = num_samples["value"]
if num_samples_value is None: if num_samples_value is None:
raise ValueError(f"For {self.name}, shape nust be const") raise ValueError(f"For {self.name}, shape nust be const")

View File

@ -588,8 +588,8 @@ def _quant_export(network, *inputs, file_format, **kwargs):
if quant_mode not in quant_mode_formats: if quant_mode not in quant_mode_formats:
raise KeyError(f'The quant_mode input is invalid, please choose a valid quant_mode.') raise KeyError(f'The quant_mode input is invalid, please choose a valid quant_mode.')
mean = Validator.check_type("mean", mean, (int, float)) mean = Validator.check_value_type("mean", mean, (int, float))
std_dev = Validator.check_type("std_dev", std_dev, (int, float)) std_dev = Validator.check_value_type("std_dev", std_dev, (int, float))
if context.get_context('device_target') not in supported_device: if context.get_context('device_target') not in supported_device:
raise KeyError("Unsupported {} device target.".format(context.get_context('device_target'))) raise KeyError("Unsupported {} device target.".format(context.get_context('device_target')))

View File

@ -117,7 +117,7 @@ class MySparseGatherV2(PrimitiveWithInfer):
def __infer__(self, params, indices, axis): def __infer__(self, params, indices, axis):
validator.check_subclass("params", params['dtype'], mstype.tensor, self.name) validator.check_subclass("params", params['dtype'], mstype.tensor, self.name)
validator.check_tensor_type_same({"indices": indices['dtype']}, mstype.int_type, self.name) validator.check_tensor_dtype_valid("indices", indices['dtype'], mstype.int_type, self.name)
validator.check_subclass("axis", axis['dtype'], mstype.int_, self.name) validator.check_subclass("axis", axis['dtype'], mstype.int_, self.name)
axis_v = axis['value'] axis_v = axis['value']
params_shp = params['shape'] params_shp = params['shape']