Rectify and optimize the type checking function

This commit is contained in:
buxue 2020-11-02 16:00:30 +08:00
parent 9ae5f96988
commit 346bcfa3fd
30 changed files with 648 additions and 712 deletions

View File

@ -415,37 +415,20 @@ class Validator:
break
if not hit:
type_str = (type(type_).__name__ if isinstance(type_, (tuple, list)) else "") + str(type_)
raise TypeError(f'For \'{prim_name}\' the type of `{arg_name}` should be subclass'
f' of {",".join((str(x) for x in template_types))}, but got {type_str}.')
raise TypeError(f'For \'{prim_name}\', the type of `{arg_name}` should be subclass'
f' of {", ".join((str(x) for x in template_types))}, but got {type_str}.')
@staticmethod
def check_const_input(arg_name, arg_value, prim_name):
"""Checks valid value."""
if arg_value is None:
raise ValueError(f'For \'{prim_name}\' the `{arg_name}` must be a const input, but got {arg_value}.')
raise ValueError(f'For \'{prim_name}\', the `{arg_name}` must be a const input, but got {arg_value}.')
return arg_value
@staticmethod
def check_type(arg_name, arg_value, valid_types):
"""Type checking."""
def raise_error_msg():
"""func for raising error message when check failed"""
raise TypeError(f'The type of `{arg_name}` should be in {valid_types}, but got {type(arg_value).__name__}.')
if isinstance(arg_value, type(mstype.tensor)):
arg_value = arg_value.element_type()
if isinstance(arg_value, bool) and bool not in tuple(valid_types):
raise_error_msg()
if arg_value in valid_types:
return arg_value
if isinstance(arg_value, tuple(valid_types)):
return arg_value
raise_error_msg()
@staticmethod
def check_type_same(args, valid_values, prim_name):
"""Checks whether the types of inputs are the same."""
def _check_tensor_type(arg):
def check_types_same_and_valid(args, valid_values, prim_name):
"""Checks whether the types of inputs are the same and valid."""
def _check_type_valid(arg):
arg_key, arg_val = arg
elem_type = arg_val
Validator.check_subclass(arg_key, elem_type, valid_values, prim_name)
@ -455,21 +438,27 @@ class Validator:
arg1_name, arg1_type = arg1
arg2_name, arg2_type = arg2
if arg1_type != arg2_type:
raise TypeError(f'For \'{prim_name}\' type of `{arg2_name}` should be same as `{arg1_name}`,'
raise TypeError(f'For \'{prim_name}\', type of `{arg2_name}` should be same as `{arg1_name}`,'
f' but `{arg1_name}` with type {arg1_type} and `{arg2_name}` with type {arg2_type}.')
return arg1
elem_types = map(_check_tensor_type, args.items())
elem_types = map(_check_type_valid, args.items())
reduce(_check_types_same, elem_types)
@staticmethod
def check_tensor_type_same(args, valid_values, prim_name):
"""Checks whether the element types of input tensors are the same."""
tensor_types = [mstype.tensor_type(t) for t in valid_values]
Validator.check_type_same(args, tensor_types, prim_name)
def check_tensors_dtypes_same_and_valid(args, valid_dtypes, prim_name):
"""Checks whether the element types of input tensors are the same and valid."""
tensor_types = [mstype.tensor_type(t) for t in valid_dtypes]
Validator.check_types_same_and_valid(args, tensor_types, prim_name)
@staticmethod
def check_scalar_or_tensor_type_same(args, valid_values, prim_name, allow_mix=False):
def check_tensor_dtype_valid(arg_name, arg_type, valid_dtypes, prim_name):
"""Checks whether the element types of input tensors are valid."""
tensor_types = [mstype.tensor_type(t) for t in valid_dtypes]
Validator.check_subclass(arg_name, arg_type, tensor_types, prim_name)
@staticmethod
def check_scalar_or_tensor_types_same(args, valid_values, prim_name, allow_mix=False):
"""
Checks whether the types of inputs are the same. If the input args are tensors, checks their element types.
If `allow_mix` is True, Tensor(float32) and float32 are type compatible, otherwise an exception will be raised.
@ -480,7 +469,7 @@ class Validator:
if isinstance(arg_val, type(mstype.tensor)):
arg_val = arg_val.element_type()
if not arg_val in valid_values:
raise TypeError(f'For \'{prim_name}\' the `{arg_key}` should be in {valid_values},'
raise TypeError(f'For \'{prim_name}\', the `{arg_key}` should be in {valid_values},'
f' but `{arg_key}` is {arg_val}.')
return arg
@ -512,40 +501,40 @@ class Validator:
def raise_error_msg():
"""func for raising error message when check failed"""
type_names = [t.__name__ for t in valid_types]
type_names = [t.__name__ if hasattr(t, '__name__') else str(t) for t in valid_types]
num_types = len(valid_types)
msg_prefix = f'For \'{prim_name}\' the' if prim_name else 'The'
msg_prefix = f"For '{prim_name}', the" if prim_name else "The"
raise TypeError(f'{msg_prefix} type of `{arg_name}` should be {"one of " if num_types > 1 else ""}'
f'{type_names if num_types > 1 else type_names[0]}, but got {type(arg_value).__name__}.')
f'{type_names if num_types > 1 else type_names[0]}, '
f'but got {arg_value} with type {type(arg_value).__name__}.')
# Notice: bool is subclass of int, so `check_value_type('x', True, [int])` will check fail, and
# `check_value_type('x', True, [bool, int])` will check pass
if isinstance(arg_value, bool) and bool not in tuple(valid_types):
raise_error_msg()
if isinstance(arg_value, tuple(valid_types)):
return arg_value
raise_error_msg()
if not isinstance(arg_value, tuple(valid_types)):
raise_error_msg()
return arg_value
@staticmethod
def check_type_name(arg_name, arg_type, valid_types, prim_name):
"""Checks whether a type in some specified types"""
valid_types = valid_types if isinstance(valid_types, Iterable) else (valid_types,)
def get_typename(t):
return t.__name__ if hasattr(t, '__name__') else str(t)
def raise_error_msg():
"""func for raising error message when check failed"""
type_names = [t.__name__ if hasattr(t, '__name__') else t for t in valid_types]
num_types = len(valid_types)
msg_prefix = f"For '{prim_name}', the" if prim_name else "The"
raise TypeError(f"{msg_prefix} '{arg_name}' should be {'one of ' if num_types > 1 else ''}"
f"{type_names if num_types > 1 else type_names[0]}, "
f"but got {arg_type.__name__ if hasattr(arg_type, '__name__') else repr(arg_type)}.")
if isinstance(arg_type, type(mstype.tensor)):
arg_type = arg_type.element_type()
if arg_type in valid_types:
return arg_type
type_names = [get_typename(t) for t in valid_types]
msg_prefix = f'For \'{prim_name}\' the' if prim_name else 'The'
if len(valid_types) == 1:
raise TypeError(f'{msg_prefix} type of `{arg_name}` should be {type_names[0]},'
f' but got {get_typename(arg_type)}.')
raise TypeError(f'{msg_prefix} type of `{arg_name}` should be one of {type_names},'
f' but got {get_typename(arg_type)}.')
if arg_type not in valid_types:
raise_error_msg()
return arg_type
@staticmethod
def check_reduce_shape(ori_shape, shape, axis, prim_name):
@ -611,65 +600,6 @@ def check_output_data(data):
once = _expand_tuple(1)
twice = _expand_tuple(2)
triple = _expand_tuple(3)
valid_data_types = (int, float, np.int8, np.int16, np.int32, np.int64,
np.uint8, np.uint16, np.uint32, np.uint64, np.float16,
np.float32, np.float64, bool, np.bool_)
def check_type(arg_name, arg_value, valid_types):
"""Check value type."""
# if input type is Tensor ,get element type
if isinstance(arg_value, type(mstype.tensor)):
arg_value = arg_value.element_type()
# First, check if arg_value has argvalid_types
if isinstance(arg_value, tuple(valid_types)):
return type(arg_value).__name__
# Second, wrap arg_value with numpy array so that it can be checked through numpy api
if isinstance(arg_value, (list, tuple)):
arg_value = np.array(arg_value)
# Thirdly, check the data type by numpy's dtype api
valid = False
if isinstance(arg_value, np.ndarray):
valid = arg_value.dtype in valid_data_types
# Notice: bool is subclass of int, so `check_type('x', True, [int])` will check fail, and
# `check_type('x', True, [bool, int])` will check pass
if isinstance(arg_value, bool) and bool not in tuple(valid_types):
valid = False
if not valid:
type_names = [t.__name__ for t in valid_types]
if len(valid_types) == 1:
raise TypeError(f'The type of `{arg_name}` should be {type_names[0]},'
f' but got {type(arg_value).__name__}.')
raise TypeError(f'The type of `{arg_name}` should be one of {type_names},'
f' but got {type(arg_value).__name__}.')
return type(arg_value).__name__
def check_typename(arg_name, arg_type, valid_types):
"""Check type name."""
def get_typename(t):
return t.__name__ if hasattr(t, '__name__') else str(t)
if isinstance(arg_type, type(mstype.tensor)):
arg_type = arg_type.element_type()
if arg_type in valid_types:
return arg_type
if isinstance(arg_type, tuple(valid_types)):
return arg_type
type_names = [get_typename(t) for t in valid_types]
if len(valid_types) == 1:
raise TypeError(f'The type of `{arg_name}` should be {type_names[0]},'
f' but got {get_typename(arg_type)}.')
raise TypeError(f'The type of `{arg_name}` should be one of {type_names},'
f' but got {get_typename(arg_type)}.')
def args_type_check(*type_args, **type_kwargs):

View File

@ -19,7 +19,7 @@ from mindspore import log as logger
from mindspore.communication.management import get_rank, get_group_size
from .._c_expression import Tensor as Tensor_
from .._c_expression import MetaTensor as MetaTensor_
from .._checkparam import check_type, check_typename
from .._checkparam import Validator as validator
from . import dtype as mstype
from ._register_for_tensor import tensor_operator_registry
@ -64,9 +64,19 @@ class Tensor(Tensor_):
input_data = np.array(input_data)
# If input_data is tuple/list/numpy.ndarray, it's support in check_type method.
check_type('tensor input_data', input_data, (Tensor_, float, int))
validator.check_value_type('input_data', input_data, (Tensor_, np.ndarray, list, tuple, float, int, bool),
'Tensor')
valid_dtypes = (np.int8, np.int16, np.int32, np.int64, np.uint8, np.uint16, np.uint32, np.uint64,
np.float16, np.float32, np.float64, np.bool_)
if isinstance(input_data, np.ndarray) and input_data.dtype not in valid_dtypes:
raise TypeError(f"For Tensor, the input_data is a numpy array whose data type is "
f"{input_data.dtype} that is not supported to initialize a Tensor.")
if isinstance(input_data, (tuple, list)):
if np.array(input_data).dtype not in valid_dtypes:
raise TypeError(f"For Tensor, the input_data is {input_data} that contain unsupported element.")
if dtype is not None:
check_typename('dtype', dtype, mstype.number_type + (mstype.bool_,))
validator.check_type_name('dtype', dtype, mstype.number_type + (mstype.bool_,), "Tensor")
if isinstance(input_data, np.ndarray) and (not input_data.flags['FORC']):
input_data = np.ascontiguousarray(input_data)
if dtype is None:
@ -405,8 +415,9 @@ class MetaTensor(MetaTensor_):
Returns:
Array, an array after being initialized.
"""
def __init__(self, dtype, shape, init=None):
#check param
# check param
self.init = init
MetaTensor_.__init__(self, dtype, shape)
@ -434,8 +445,10 @@ class MetaTensor(MetaTensor_):
msg = "Error shape={}".format(shape)
logger.error(msg)
raise ValueError(msg)
class seed_context:
'''set and restore seed'''
def __init__(self, init):
self.init = init
from .seed import get_seed
@ -482,4 +495,5 @@ def _vm_compare(*args):
y = args[0]
return Tensor(np.array(fn(y)))
tensor_operator_registry.register('vm_compare', _vm_compare)

View File

@ -21,7 +21,7 @@ from ...ops import operations as P
from ...ops.primitive import PrimitiveWithInfer, prim_attr_register
from ...ops.composite import multitype_ops as C
from ...ops.operations import _grad_ops as G
from ..._checkparam import Validator
from ..._checkparam import Validator as validator
from ..cell import Cell, GraphKernel
@ -194,7 +194,7 @@ class ApplyMomentum(GraphKernel):
use_locking=False,
gradient_scale=1.0):
super(ApplyMomentum, self).__init__()
self.gradient_scale = Validator.check_type('gradient_scale', gradient_scale, [float])
self.gradient_scale = validator.check_value_type('gradient_scale', gradient_scale, [float], type(self).__name__)
self.fake_output_assign_1 = InplaceAssign()
self.fake_output_assign_1.add_prim_attr("fake_output", True)
self.fake_output_assign_2 = InplaceAssign()
@ -334,7 +334,7 @@ class ReduceMean(GraphKernel):
def __init__(self, keep_dims=True):
super(ReduceMean, self).__init__()
self.keep_dims = Validator.check_type('keep_dims', keep_dims, [bool])
self.keep_dims = validator.check_value_type('keep_dims', keep_dims, [bool], type(self).__name__)
self.sum = P.ReduceSum(self.keep_dims)
def construct(self, x, axis):
@ -431,8 +431,10 @@ class LayerNormForward(GraphKernel):
""" Forward function of the LayerNorm operator. """
def __init__(self, begin_norm_axis=1, begin_params_axis=1):
super(LayerNormForward, self).__init__()
self.begin_norm_axis = Validator.check_type('begin_norm_axis', begin_norm_axis, [int])
self.begin_params_axis = Validator.check_type('begin_params_axis', begin_params_axis, [int])
self.begin_norm_axis = validator.check_value_type('begin_norm_axis', begin_norm_axis, [int],
type(self).__name__)
self.begin_params_axis = validator.check_value_type('begin_params_axis', begin_params_axis, [int],
type(self).__name__)
self.mul = P.Mul()
self.sum_keep_dims = P.ReduceSum(keep_dims=True)
self.sub = P.Sub()
@ -686,7 +688,7 @@ class LogSoftmax(GraphKernel):
def __init__(self, axis=-1):
super(LogSoftmax, self).__init__()
self.axis = Validator.check_type('axis', axis, [int])
self.axis = validator.check_value_type('axis', axis, [int], type(self).__name__)
self.max_keep_dims = P.ReduceMax(keep_dims=True)
self.sub = P.Sub()
self.exp = P.Exp()
@ -952,13 +954,13 @@ class Softmax(GraphKernel):
def __init__(self, axis):
super(Softmax, self).__init__()
Validator.check_type("axis", axis, [int, tuple])
validator.check_value_type("axis", axis, [int, tuple], type(self).__name__)
if isinstance(axis, int):
self.axis = (axis,)
else:
self.axis = axis
for item in self.axis:
Validator.check_type("item of axis", item, [int])
validator.check_value_type("item of axis", item, [int], type(self).__name__)
self.max = P.ReduceMax(keep_dims=True)
self.sub = P.Sub()
self.exp = P.Exp()

View File

@ -19,7 +19,7 @@ from mindspore.common.parameter import Parameter
from mindspore.common.initializer import initializer
from mindspore.ops.primitive import constexpr
import mindspore.context as context
from mindspore._checkparam import Validator, check_typename
from mindspore._checkparam import Validator as validator
from mindspore._extends import cell_attr_register
from mindspore.communication.management import get_group_size, get_rank
from mindspore.communication import management
@ -52,7 +52,7 @@ class _BatchNorm(Cell):
if momentum < 0 or momentum > 1:
raise ValueError("momentum should be a number in range [0, 1], but got {}".format(momentum))
self.format = Validator.check_string(data_format, ['NCHW', 'NHWC'], 'format', self.cls_name)
self.format = validator.check_string(data_format, ['NCHW', 'NHWC'], 'format', self.cls_name)
if context.get_context("device_target") != "GPU" and self.format == "NHWC":
raise ValueError("NHWC format only support in GPU target.")
self.use_batch_statistics = use_batch_statistics
@ -67,7 +67,7 @@ class _BatchNorm(Cell):
gamma_init, num_features), name="gamma", requires_grad=affine)
self.beta = Parameter(initializer(
beta_init, num_features), name="beta", requires_grad=affine)
self.group = Validator.check_positive_int(device_num_each_group)
self.group = validator.check_positive_int(device_num_each_group)
self.is_global = False
if self.group != 1:
self.rank_id = get_rank()
@ -472,7 +472,7 @@ class GlobalBatchNorm(_BatchNorm):
use_batch_statistics,
device_num_each_group,
input_dims='both')
self.group = Validator.check_positive_int(device_num_each_group)
self.group = validator.check_positive_int(device_num_each_group)
if self.group <= 1:
raise ValueError("the number of group must be greater than 1.")
@ -607,12 +607,12 @@ class GroupNorm(Cell):
def __init__(self, num_groups, num_channels, eps=1e-05, affine=True, gamma_init='ones', beta_init='zeros'):
super(GroupNorm, self).__init__()
self.num_groups = Validator.check_positive_int(num_groups)
self.num_channels = Validator.check_positive_int(num_channels)
self.num_groups = validator.check_positive_int(num_groups)
self.num_channels = validator.check_positive_int(num_channels)
if num_channels % num_groups != 0:
raise ValueError("num_channels should be divided by num_groups")
self.eps = check_typename('eps', eps, (float,))
self.affine = Validator.check_bool(affine)
self.eps = validator.check_value_type('eps', eps, (float,), type(self).__name__)
self.affine = validator.check_bool(affine)
gamma = initializer(gamma_init, num_channels)
beta = initializer(beta_init, num_channels)

View File

@ -442,8 +442,8 @@ class FakeQuantWithMinMaxObserver(UniformQuantObserver):
super(FakeQuantWithMinMaxObserver, self).__init__(quant_dtype=quant_dtype, per_channel=per_channel,
symmetric=symmetric, narrow_range=narrow_range,
num_channels=num_channels)
Validator.check_type("min_init", min_init, [int, float])
Validator.check_type("max_init", max_init, [int, float])
Validator.check_value_type("min_init", min_init, [int, float], type(self).__name__)
Validator.check_value_type("max_init", max_init, [int, float], type(self).__name__)
Validator.check("min_init", min_init, "max_init", max_init, rel=Rel.LT)
Validator.check_non_negative_int(quant_delay, 'quant_delay')
self.min_init = min_init

View File

@ -68,7 +68,7 @@ class GumbelCDF(Bijector):
"""
param = dict(locals())
valid_dtype = mstype.float_type + mstype.int_type + mstype.uint_type
Validator.check_type(type(self).__name__, dtype, valid_dtype)
Validator.check_type_name("dtype", dtype, valid_dtype, type(self).__name__)
parameter_type = set_param_type({'loc': loc, "scale": scale}, dtype)
super(GumbelCDF, self).__init__(name=name, dtype=dtype, param=param)

View File

@ -119,7 +119,7 @@ class Bernoulli(Distribution):
param = dict(locals())
param['param_dict'] = {'probs': probs}
valid_dtype = mstype.int_type + mstype.uint_type + mstype.float_type
Validator.check_type(type(self).__name__, dtype, valid_dtype)
Validator.check_type_name("dtype", dtype, valid_dtype, type(self).__name__)
super(Bernoulli, self).__init__(seed, dtype, name, param)
self._probs = self._add_parameter(probs, 'probs')

View File

@ -109,7 +109,7 @@ class Categorical(Distribution):
param = dict(locals())
param['param_dict'] = {'probs': probs}
valid_dtype = mstype.int_type
Validator.check_type("Categorical", dtype, valid_dtype)
Validator.check_type_name("dtype", dtype, valid_dtype, type(self).__name__)
super(Categorical, self).__init__(seed, dtype, name, param)
self._probs = self._add_parameter(probs, 'probs')

View File

@ -121,7 +121,7 @@ class Exponential(Distribution):
param = dict(locals())
param['param_dict'] = {'rate': rate}
valid_dtype = mstype.float_type
Validator.check_type(type(self).__name__, dtype, valid_dtype)
Validator.check_type_name("dtype", dtype, valid_dtype, type(self).__name__)
super(Exponential, self).__init__(seed, dtype, name, param)
self._rate = self._add_parameter(rate, 'rate')

View File

@ -122,7 +122,7 @@ class Geometric(Distribution):
param = dict(locals())
param['param_dict'] = {'probs': probs}
valid_dtype = mstype.int_type + mstype.uint_type + mstype.float_type
Validator.check_type(type(self).__name__, dtype, valid_dtype)
Validator.check_type_name("dtype", dtype, valid_dtype, type(self).__name__)
super(Geometric, self).__init__(seed, dtype, name, param)
self._probs = self._add_parameter(probs, 'probs')

View File

@ -102,7 +102,7 @@ class Gumbel(TransformedDistribution):
Constructor of Gumbel distribution.
"""
valid_dtype = mstype.float_type
Validator.check_type(type(self).__name__, dtype, valid_dtype)
Validator.check_type_name("dtype", dtype, valid_dtype, type(self).__name__)
gumbel_cdf = msb.GumbelCDF(loc, scale, dtype)
super(Gumbel, self).__init__(
distribution=msd.Uniform(0.0, 1.0, dtype=dtype),

View File

@ -111,7 +111,7 @@ class Logistic(Distribution):
param = dict(locals())
param['param_dict'] = {'loc': loc, 'scale': scale}
valid_dtype = mstype.float_type
Validator.check_type(type(self).__name__, dtype, valid_dtype)
Validator.check_type_name("dtype", dtype, valid_dtype, type(self).__name__)
super(Logistic, self).__init__(seed, dtype, name, param)
self._loc = self._add_parameter(loc, 'loc')

View File

@ -127,7 +127,7 @@ class Normal(Distribution):
param = dict(locals())
param['param_dict'] = {'mean': mean, 'sd': sd}
valid_dtype = mstype.float_type
Validator.check_type(type(self).__name__, dtype, valid_dtype)
Validator.check_type_name("dtype", dtype, valid_dtype, type(self).__name__)
super(Normal, self).__init__(seed, dtype, name, param)
self._mean_value = self._add_parameter(mean, 'mean')

View File

@ -126,7 +126,7 @@ class Uniform(Distribution):
param = dict(locals())
param['param_dict'] = {'low': low, 'high': high}
valid_dtype = mstype.float_type
Validator.check_type(type(self).__name__, dtype, valid_dtype)
Validator.check_type_name("dtype", dtype, valid_dtype, type(self).__name__)
super(Uniform, self).__init__(seed, dtype, name, param)
self._low = self._add_parameter(low, 'low')

View File

@ -55,8 +55,7 @@ class UpdateCache(PrimitiveWithInfer):
return [1]
def infer_dtype(self, input_x_dtype, indices_dtype, update_dtype, max_num_dtype):
args = {"indices": indices_dtype}
validator.check_tensor_type_same(args, mstype.int_type, self.name)
validator.check_tensor_dtype_valid("indices", indices_dtype, mstype.int_type, self.name)
return input_x_dtype
@ -140,7 +139,7 @@ class SearchCacheIdx(PrimitiveWithInfer):
def infer_dtype(self, hashmap_dtype, indices_dtype, step_dtype, emb_max_num_dtype, cache_max_num_dtype):
args = {"hashmap": hashmap_dtype, "indices": indices_dtype}
validator.check_tensor_type_same(args, mstype.int_type, self.name)
validator.check_tensors_dtypes_same_and_valid(args, mstype.int_type, self.name)
out_dtype = (hashmap_dtype, hashmap_dtype, hashmap_dtype)
return out_dtype
@ -182,8 +181,7 @@ class CacheSwapHashmap(PrimitiveWithInfer):
return out_shape
def infer_dtype(self, hashmap_dtype, miss_emb_idx_dtype, step_dtype):
args = {"miss_emb_idx": miss_emb_idx_dtype}
validator.check_tensor_type_same(args, mstype.int_type, self.name)
validator.check_tensor_dtype_valid("miss_emb_idx", miss_emb_idx_dtype, mstype.int_type, self.name)
out_dtype = (miss_emb_idx_dtype, miss_emb_idx_dtype)
return out_dtype
@ -224,8 +222,7 @@ class CacheSwapTable(PrimitiveWithInfer):
return miss_value_shape
def infer_dtype(self, cache_table_dtype, swap_cache_idx_dtype, miss_value_dtype):
args = {"swap_cache_idx": swap_cache_idx_dtype}
validator.check_tensor_type_same(args, mstype.int_type, self.name)
validator.check_tensor_dtype_valid("swap_cache_idx", swap_cache_idx_dtype, mstype.int_type, self.name)
return miss_value_dtype
@ -261,7 +258,7 @@ class MapCacheIdx(PrimitiveWithInfer):
def infer_dtype(self, hashmap_dtype, indices_dtype, step_dtype, emb_max_num_dtype, cache_max_num_dtype):
args = {"hashmap": hashmap_dtype, "indices": indices_dtype}
validator.check_tensor_type_same(args, mstype.int_type, self.name)
validator.check_tensors_dtypes_same_and_valid(args, mstype.int_type, self.name)
out_dtype = (hashmap_dtype, hashmap_dtype,
hashmap_dtype, hashmap_dtype)
return out_dtype

View File

@ -14,6 +14,7 @@
# ============================================================================
"""Operators for gradients."""
from functools import partial
from .. import signature as sig
from ..primitive import Primitive, PrimitiveWithInfer, prim_attr_register
@ -23,6 +24,7 @@ from ...common import dtype as mstype
from .. import functional as F
from ... import context
class AbsGrad(PrimitiveWithInfer):
"""Computes gradients for abs operation."""
@ -55,7 +57,7 @@ class ACosGrad(PrimitiveWithInfer):
def infer_dtype(self, x, dout):
args = {"x": x, "dout": dout}
validator.check_tensor_type_same(args, mstype.number_type, self.name)
validator.check_tensors_dtypes_same_and_valid(args, mstype.number_type, self.name)
return x
@ -72,7 +74,7 @@ class AcoshGrad(PrimitiveWithInfer):
def infer_dtype(self, x, dout):
args = {"x": x, "dout": dout}
validator.check_tensor_type_same(args, mstype.number_type, self.name)
validator.check_tensors_dtypes_same_and_valid(args, mstype.number_type, self.name)
return x
@ -94,7 +96,7 @@ class AsinGrad(PrimitiveWithInfer):
def infer_dtype(self, x, dout):
args = {"x": x, "dout": dout}
validator.check_tensor_type_same(args, mstype.number_type, self.name)
validator.check_tensors_dtypes_same_and_valid(args, mstype.number_type, self.name)
return x
@ -111,7 +113,7 @@ class AsinhGrad(PrimitiveWithInfer):
def infer_dtype(self, x, dout):
args = {"x": x, "dout": dout}
validator.check_tensor_type_same(args, mstype.number_type, self.name)
validator.check_tensors_dtypes_same_and_valid(args, mstype.number_type, self.name)
return x
@ -128,7 +130,7 @@ class ReciprocalGrad(PrimitiveWithInfer):
def infer_dtype(self, x_dtype, dout_dtype):
args = {"x": x_dtype, "dout": dout_dtype}
validator.check_tensor_type_same(args, [mstype.float16, mstype.float32], self.name)
validator.check_tensors_dtypes_same_and_valid(args, [mstype.float16, mstype.float32], self.name)
return x_dtype
@ -145,7 +147,8 @@ class RsqrtGrad(PrimitiveWithInfer):
def infer_dtype(self, x_dtype, dout_dtype):
args = {"x": x_dtype, "dout": dout_dtype}
validator.check_tensor_type_same(args, [mstype.float16, mstype.float32, mstype.int32, mstype.int8], self.name)
validator.check_tensors_dtypes_same_and_valid(args, [mstype.float16, mstype.float32, mstype.int32, mstype.int8],
self.name)
return x_dtype
@ -162,7 +165,7 @@ class SoftmaxGrad(PrimitiveWithInfer):
def infer_dtype(self, x_dtype, dout_dtype):
args = {"x": x_dtype, "dout": dout_dtype}
validator.check_tensor_type_same(args, [mstype.float16, mstype.float32], self.name)
validator.check_tensors_dtypes_same_and_valid(args, [mstype.float16, mstype.float32], self.name)
return x_dtype
@ -179,7 +182,7 @@ class SqrtGrad(PrimitiveWithInfer):
def infer_dtype(self, x_dtype, dout_dtype):
args = {"x": x_dtype, "dout": dout_dtype}
validator.check_tensor_type_same(args, [mstype.float16, mstype.float32], self.name)
validator.check_tensors_dtypes_same_and_valid(args, [mstype.float16, mstype.float32], self.name)
return x_dtype
@ -232,7 +235,7 @@ class KLDivLossGrad(PrimitiveWithInfer):
def infer_dtype(self, x_type, y_type, doutput_type):
args = {'x_type': x_type, 'y_type': y_type, 'doutput_type': doutput_type}
validator.check_tensor_type_same(args, (mstype.float16, mstype.float32), self.name)
validator.check_tensors_dtypes_same_and_valid(args, (mstype.float16, mstype.float32), self.name)
return x_type, y_type
@ -251,7 +254,7 @@ class BinaryCrossEntropyGrad(PrimitiveWithInfer):
def infer_dtype(self, x_type, y_type, doutput_type, weight_type):
args = {'x_type': x_type, 'y_type': y_type, 'doutput_type': doutput_type}
validator.check_tensor_type_same(args, (mstype.float16, mstype.float32), self.name)
validator.check_tensors_dtypes_same_and_valid(args, (mstype.float16, mstype.float32), self.name)
if weight_type:
validator.check('x_type', x_type, 'weight_type', weight_type, Rel.EQ, TypeError)
return x_type
@ -343,7 +346,8 @@ class Conv2DBackpropFilter(PrimitiveWithInfer):
for i, dim_len in enumerate(w_size_v):
validator.check_value_type("w_size[%d]" % i, dim_len, [int], self.name)
args = {"x": x['dtype'], "doutput": doutput['dtype']}
validator.check_tensor_type_same(args, [mstype.int8, mstype.int32, mstype.float16, mstype.float32], self.name)
validator.check_tensors_dtypes_same_and_valid(args, [mstype.int8, mstype.int32, mstype.float16, mstype.float32],
self.name)
out = {
'value': None,
'shape': w_size_v,
@ -406,7 +410,7 @@ class DepthwiseConv2dNativeBackpropFilter(PrimitiveWithInfer):
def __infer__(self, x, w_size, dout):
w_size_v = w_size['value']
args = {'x': x['dtype'], 'dout': dout['dtype']}
validator.check_tensor_type_same(args, mstype.number_type, self.name)
validator.check_tensors_dtypes_same_and_valid(args, mstype.number_type, self.name)
out = {
'value': None,
'shape': w_size_v,
@ -466,7 +470,7 @@ class DepthwiseConv2dNativeBackpropInput(PrimitiveWithInfer):
def __infer__(self, x_size, w, dout):
args = {'w': w['dtype'], 'dout': dout['dtype']}
validator.check_tensor_type_same(args, mstype.number_type, self.name)
validator.check_tensors_dtypes_same_and_valid(args, mstype.number_type, self.name)
x_size_v = x_size['value']
out = {
'value': None,
@ -505,10 +509,9 @@ class DropoutGrad(PrimitiveWithInfer):
return dy_shape
def infer_dtype(self, dy_dtype, mask_dtype):
valid_types = (mstype.float16, mstype.float32)
validator.check_subclass("dy", dy_dtype, mstype.tensor, self.name)
valid_dtypes = (mstype.float16, mstype.float32)
validator.check_subclass("mask", mask_dtype, mstype.tensor, self.name)
validator.check_tensor_type_same({"dy_dtype": dy_dtype}, valid_types, self.name)
validator.check_tensor_dtype_valid("dy", dy_dtype, valid_dtypes, self.name)
return dy_dtype
@ -627,9 +630,10 @@ class GeluGrad(PrimitiveWithInfer):
return x_shape
def infer_dtype(self, y_backprop_dtype, x_dtype, y_dtype):
validator.check_tensor_type_same({"y_backprop": y_backprop_dtype}, (mstype.float16, mstype.float32), self.name)
validator.check_tensor_type_same({"x": x_dtype}, (mstype.float16, mstype.float32), self.name)
validator.check_tensor_type_same({"y": y_dtype}, (mstype.float16, mstype.float32), self.name)
tuple(map(partial(validator.check_tensor_dtype_valid,
valid_dtypes=(mstype.float16, mstype.float32), prim_name=self.name),
("y_backprop", "x", "y"),
(y_backprop_dtype, x_dtype, y_dtype)))
return x_dtype
@ -782,7 +786,7 @@ class MaxPoolGradGrad(_PoolGrad):
def infer_dtype(self, x1_dtype, x2_dtype, grad_dtype):
args = {'x1_dtype': x1_dtype, 'x2_dtype': x2_dtype, 'grad_dtype': grad_dtype}
validator.check_tensor_type_same(args, [mstype.float16], self.name)
validator.check_tensors_dtypes_same_and_valid(args, [mstype.float16], self.name)
return x1_dtype
@ -858,7 +862,7 @@ class MaxPoolGradGradWithArgmax(_PoolGrad):
def infer_dtype(self, x_dtype, grad_dtype, argmax_dtype):
args = {'x_dtype': x_dtype, 'grad_dtype': grad_dtype}
validator.check_tensor_type_same(args, [mstype.float16], self.name)
validator.check_tensors_dtypes_same_and_valid(args, [mstype.float16], self.name)
return grad_dtype
@ -902,7 +906,7 @@ class L2NormalizeGrad(PrimitiveWithInfer):
def infer_dtype(self, input_x, out, dout):
args = {'input_x': input_x, 'out': out, 'dout': dout}
validator.check_tensor_type_same(args, mstype.number_type, self.name)
validator.check_tensors_dtypes_same_and_valid(args, mstype.number_type, self.name)
return input_x
@ -993,7 +997,7 @@ class LSTMGradData(PrimitiveWithInfer):
def infer_dtype(self, y_dtype, dy_dtype, dhy_dtype, dcy_dtype, w_dtype,
hx_dtype, cx_dtype, reserve_dtype, state_dtype):
args = {"dy": dy_dtype, "dhy": dhy_dtype, "dcy": dcy_dtype}
validator.check_tensor_type_same(args, (mstype.float32, mstype.float16), self.name)
validator.check_tensors_dtypes_same_and_valid(args, (mstype.float32, mstype.float16), self.name)
return (dy_dtype, dy_dtype, dy_dtype)
@ -1265,14 +1269,14 @@ class DynamicGRUV2Grad(PrimitiveWithInfer):
args = {"y_dtype": y_dtype, "init_h_dtype": init_h_dtype, "h_dtype": h_dtype,
"dy_dtype": dy_dtype, "dh_dtype": dh_dtype, "update_dtype": update_dtype,
"reset_dtype": reset_dtype, "new_dtype": new_dtype, "hnew_dtype": hnew_dtype}
validator.check_tensor_type_same({"x_dtype": x_dtype}, valid_types, self.name)
validator.check_tensor_type_same({"winput_dtype": winput_dtype}, valid_types, self.name)
validator.check_tensor_type_same({"whidden_dtype": whidden_dtype}, valid_types, self.name)
validator.check_tensor_type_same(args, valid_types, self.name)
validator.check_tensor_dtype_valid("x_dtype", x_dtype, valid_types, self.name)
validator.check_tensor_dtype_valid("winput_dtype", winput_dtype, valid_types, self.name)
validator.check_tensor_dtype_valid("whidden_dtype", whidden_dtype, valid_types, self.name)
validator.check_tensors_dtypes_same_and_valid(args, valid_types, self.name)
if seq_dtype is not None:
validator.check_tensor_type_same({"seq_dtype": seq_dtype}, (mstype.float32, mstype.float16), self.name)
validator.check_tensor_dtype_valid("seq_dtype", seq_dtype, valid_types, self.name)
if mask_dtype is not None:
validator.check_tensor_type_same({"mask_dtype": mask_dtype}, (mstype.float32, mstype.float16), self.name)
validator.check_tensor_dtype_valid("mask_dtype", mask_dtype, valid_types, self.name)
return x_dtype, x_dtype, x_dtype, x_dtype, x_dtype, x_dtype
@ -1302,10 +1306,10 @@ class PReLUGrad(PrimitiveWithInfer):
return y_backprop_shape, w_shape
def infer_dtype(self, y_backprop_dtype, A_dtype, w_dtype):
valid_types = (mstype.float16, mstype.float32)
validator.check_tensor_type_same({"y_backprop": y_backprop_dtype}, valid_types, self.name)
validator.check_tensor_type_same({"A_dtype": A_dtype}, valid_types, self.name)
validator.check_tensor_type_same({"w_dtype": w_dtype}, valid_types, self.name)
tuple(map(partial(validator.check_tensor_dtype_valid,
valid_dtypes=(mstype.float16, mstype.float32), prim_name=self.name),
('y_backprop', "input_x", "weight"),
(y_backprop_dtype, A_dtype, w_dtype)))
return y_backprop_dtype, w_dtype
@ -1335,8 +1339,9 @@ class ReLU6Grad(PrimitiveWithInfer):
return x_shape
def infer_dtype(self, y_grad_dtype, x_dtype):
validator.check_tensor_type_same({"y_grad": y_grad_dtype}, (mstype.float16, mstype.float32), self.name)
validator.check_tensor_type_same({"x": x_dtype}, (mstype.float16, mstype.float32), self.name)
valid_dtypes = (mstype.float16, mstype.float32)
validator.check_tensor_dtype_valid("y_grad", y_grad_dtype, valid_dtypes, self.name)
validator.check_tensor_dtype_valid("x", x_dtype, valid_dtypes, self.name)
return x_dtype
@ -1354,8 +1359,8 @@ class ReluGradV2(PrimitiveWithInfer):
return gradients_shape
def infer_dtype(self, gradients_dtype, mask_dtype):
validator.check_tensor_type_same({'gradients': gradients_dtype}, mstype.number_type, self.name)
validator.check_tensor_type_same({'mask': mask_dtype}, (mstype.uint8,), self.name)
validator.check_tensor_dtype_valid('gradients', gradients_dtype, mstype.number_type, self.name)
validator.check_tensor_dtype_valid('mask', mask_dtype, (mstype.uint8,), self.name)
return gradients_dtype
@ -1371,7 +1376,7 @@ class EluGrad(PrimitiveWithInfer):
def infer_dtype(self, y_grad_dtype, x_dtype):
args = {'y_grad': y_grad_dtype, 'x': x_dtype}
validator.check_tensor_type_same(args, mstype.float_type, self.name)
validator.check_tensors_dtypes_same_and_valid(args, mstype.float_type, self.name)
return x_dtype
@ -1474,7 +1479,7 @@ class SigmoidGrad(PrimitiveWithInfer):
def infer_dtype(self, out, dout):
args = {'out': out, 'dout': dout}
validator.check_tensor_type_same(args, mstype.number_type, self.name)
validator.check_tensors_dtypes_same_and_valid(args, mstype.number_type, self.name)
return out
@ -1489,8 +1494,9 @@ class HSigmoidGrad(PrimitiveWithInfer):
return x_shape
def infer_dtype(self, y_grad_dtype, x_dtype):
validator.check_tensor_type_same({"y_grad": y_grad_dtype}, (mstype.float16, mstype.float32), self.name)
validator.check_tensor_type_same({"x": x_dtype}, (mstype.float16, mstype.float32), self.name)
valid_dtypes = (mstype.float16, mstype.float32)
validator.check_tensor_dtype_valid("y_grad", y_grad_dtype, valid_dtypes, self.name)
validator.check_tensor_dtype_valid("x", x_dtype, valid_dtypes, self.name)
return x_dtype
@ -1505,8 +1511,9 @@ class HSwishGrad(PrimitiveWithInfer):
return x_shape
def infer_dtype(self, y_grad_dtype, x_dtype):
validator.check_tensor_type_same({"y_grad": y_grad_dtype}, (mstype.float16, mstype.float32), self.name)
validator.check_tensor_type_same({"x": x_dtype}, (mstype.float16, mstype.float32), self.name)
valid_dtypes = (mstype.float16, mstype.float32)
validator.check_tensor_dtype_valid("y_grad", y_grad_dtype, valid_dtypes, self.name)
validator.check_tensor_dtype_valid("x", x_dtype, valid_dtypes, self.name)
return x_dtype
@ -1525,7 +1532,7 @@ class SigmoidCrossEntropyWithLogitsGrad(PrimitiveWithInfer):
def infer_dtype(self, x_dtype, y_dtype, dout_dtype):
args = {"x_dtype": x_dtype, "y_dtype": y_dtype, 'dout_dtype': dout_dtype}
validator.check_tensor_type_same(args, mstype.number_type, self.name)
validator.check_tensors_dtypes_same_and_valid(args, mstype.number_type, self.name)
return dout_dtype
@ -1562,7 +1569,7 @@ class SmoothL1LossGrad(PrimitiveWithInfer):
def infer_dtype(self, prediction, target, dloss):
args = {"prediction": prediction, "target": target, 'dloss': dloss}
validator.check_tensor_type_same(args, mstype.number_type, self.name)
validator.check_tensors_dtypes_same_and_valid(args, mstype.number_type, self.name)
return dloss
@ -1597,8 +1604,7 @@ class StridedSliceGrad(PrimitiveWithInfer):
self.init_prim_io_names(inputs=['dy', 'shapex', 'begin', 'end', 'strides'], outputs=['output'])
def __infer__(self, dy, shapex, begin, end, strides):
args = {"dy": dy['dtype']}
validator.check_tensor_type_same(args, mstype.number_type + (mstype.bool_,), self.name)
validator.check_tensor_dtype_valid("dy", dy['dtype'], mstype.number_type + (mstype.bool_,), self.name)
for idx, item in enumerate(shapex['value']):
validator.check_value_type("shapex[%d]" % idx, item, [int], self.name)
@ -1627,7 +1633,7 @@ class SoftplusGrad(PrimitiveWithInfer):
def infer_dtype(self, dout_dtype, x_dtype):
args = {"x_dtype": x_dtype, "dout_dtype": dout_dtype}
validator.check_tensor_type_same(args, mstype.float_type, self.name)
validator.check_tensors_dtypes_same_and_valid(args, mstype.float_type, self.name)
return x_dtype
@ -1643,7 +1649,7 @@ class TanhGrad(PrimitiveWithInfer):
def infer_dtype(self, out, dout):
args = {"out": out, "dout": dout}
validator.check_tensor_type_same(args, mstype.number_type, self.name)
validator.check_tensors_dtypes_same_and_valid(args, mstype.number_type, self.name)
return out
@ -1756,7 +1762,7 @@ class AtanGrad(PrimitiveWithInfer):
def infer_dtype(self, x, dout):
args = {"x": x, "dout": dout}
validator.check_tensor_type_same(args, mstype.number_type, self.name)
validator.check_tensors_dtypes_same_and_valid(args, mstype.number_type, self.name)
return x
@ -1900,7 +1906,7 @@ class LRNGrad(PrimitiveWithInfer):
def infer_dtype(self, grads, x, y):
args = {"grads": grads, "x": x, "y": y}
validator.check_tensor_type_same(args, (mstype.float16, mstype.float32,), self.name)
validator.check_tensors_dtypes_same_and_valid(args, (mstype.float16, mstype.float32,), self.name)
return x
def infer_shape(self, grads, x, y):

View File

@ -54,6 +54,7 @@ class ExtractImagePatches(PrimitiveWithInfer):
@prim_attr_register
def __init__(self, ksizes, strides, rates, padding="valid"):
"""init"""
def _check_tuple_or_list(arg_name, arg_val, prim_name):
validator.check_value_type(f"{arg_name}s", ksizes, [tuple, list], self.name)
if len(arg_val) != 4 or arg_val[0] != 1 or arg_val[3] != 1:
@ -103,7 +104,7 @@ class ExtractImagePatches(PrimitiveWithInfer):
def infer_dtype(self, input_x):
"""infer dtype"""
validator.check_tensor_type_same({"input_x": input_x}, mstype.number_type, self.name)
validator.check_tensor_dtype_valid("input_x", input_x, mstype.number_type, self.name)
return input_x
@ -161,7 +162,7 @@ class Range(PrimitiveWithInfer):
return x_shape
def infer_dtype(self, x_dtype):
validator.check_tensor_type_same({'x_dtype': x_dtype}, [mstype.float32, mstype.int32], self.name)
validator.check_tensor_dtype_valid('x', x_dtype, [mstype.float32, mstype.int32], self.name)
return x_dtype
@ -254,6 +255,7 @@ class Dequant(PrimitiveWithInfer):
>>> dequant = P.Dequant(False, False)
>>> y = dequant(input_x)
"""
@prim_attr_register
def __init__(self, sqrt_mode=False, relu_flag=False):
self.sqrt_mode = validator.check_value_type("sqrt_mode", sqrt_mode, [bool], self.name)
@ -303,10 +305,9 @@ class LinSpace(PrimitiveWithInfer):
return assist
def infer_dtype(self, assist, start, stop, num):
args = {"num": num}
validator.check_tensor_type_same(args, (mstype.int32,), self.name)
validator.check_tensor_dtype_valid("num", num, (mstype.int32,), self.name)
args = {"assist": assist, "start": start, "stop": stop}
validator.check_tensor_type_same(args, (mstype.float32,), self.name)
validator.check_tensors_dtypes_same_and_valid(args, (mstype.float32,), self.name)
return assist
@ -343,12 +344,12 @@ class MatrixDiag(PrimitiveWithInfer):
def infer_dtype(self, x_dtype, assist_dtype):
valid_type = [mstype.float16, mstype.float32, mstype.int32, mstype.int8, mstype.uint8]
args = {"x": x_dtype, "assist": assist_dtype}
validator.check_tensor_type_same(args, valid_type, self.name)
validator.check_tensors_dtypes_same_and_valid(args, valid_type, self.name)
return x_dtype
def infer_shape(self, x_shape, assist_shape):
validator.check_int(len(assist_shape), 2, Rel.GE, "assist rank", self.name)
validator.check('rank of x', len(x_shape)+1,
validator.check('rank of x', len(x_shape) + 1,
'rank of assist', len(assist_shape), Rel.LE, self.name)
validator.check('assist\'s penultimate dimension', assist_shape[-2], 'assist\'s last dimension',
assist_shape[-1], Rel.EQ, self.name)
@ -358,7 +359,7 @@ class MatrixDiag(PrimitiveWithInfer):
while r_idx >= r_end_dim:
if x_shape[r_idx] != 1:
validator.check("reverse x dim %d" % r_idx, x_shape[r_idx], "reverse assist dim %d" %
assist_shape[r_idx-1], assist_shape[r_idx-1], Rel.EQ, self.name)
assist_shape[r_idx - 1], assist_shape[r_idx - 1], Rel.EQ, self.name)
r_idx = r_idx - 1
return assist_shape
@ -391,7 +392,7 @@ class MatrixDiagPart(PrimitiveWithInfer):
def infer_dtype(self, x_dtype, assist_dtype):
valid_type = [mstype.float16, mstype.float32, mstype.int32, mstype.int8, mstype.uint8]
args = {"x": x_dtype, "assist": assist_dtype}
validator.check_tensor_type_same(args, valid_type, self.name)
validator.check_tensors_dtypes_same_and_valid(args, valid_type, self.name)
return x_dtype
def infer_shape(self, x_shape, assist_shape):
@ -434,7 +435,7 @@ class MatrixSetDiag(PrimitiveWithInfer):
def infer_dtype(self, x_dtype, diagonal_dtype, assist_dtype):
valid_type = [mstype.float16, mstype.float32, mstype.int32, mstype.int8, mstype.uint8]
args = {"x": x_dtype, "diagonal": diagonal_dtype, "assist": assist_dtype}
validator.check_tensor_type_same(args, valid_type, self.name)
validator.check_tensors_dtypes_same_and_valid(args, valid_type, self.name)
return x_dtype
def infer_shape(self, x_shape, diagonal_shape, assist_shape):
@ -583,21 +584,21 @@ class DynamicGRUV2(PrimitiveWithInfer):
return y_shape, outh_shape, outh_shape, outh_shape, outh_shape, outh_shape
def infer_dtype(self, x_dtype, winput_dtype, whidden_dtype, binput_dtype, bhidden_dtype, seq_dtype, h_dtype):
validator.check_tensor_type_same({"x dtype": x_dtype}, (mstype.float16,), self.name)
validator.check_tensor_type_same({"weight input dtype": winput_dtype}, (mstype.float16,), self.name)
validator.check_tensor_type_same({"weight hidden dtype": whidden_dtype}, (mstype.float16,), self.name)
validator.check_tensor_dtype_valid("x dtype", x_dtype, [mstype.float16], self.name)
validator.check_tensor_dtype_valid("weight input dtype", winput_dtype, [mstype.float16], self.name)
validator.check_tensor_dtype_valid("weight hidden dtype", whidden_dtype, [mstype.float16], self.name)
b_dtype = mstype.float32
if binput_dtype is not None:
validator.check_tensor_type_same({"bias input dtype": binput_dtype},
(mstype.float16, mstype.float32), self.name)
validator.check_tensor_dtype_valid("bias input dtype", binput_dtype,
(mstype.float16, mstype.float32), self.name)
b_dtype = binput_dtype
elif bhidden_dtype is not None:
validator.check_tensor_type_same({"bias hidden dtype": bhidden_dtype},
(mstype.float16, mstype.float32), self.name)
validator.check_tensor_dtype_valid("bias hidden dtype", bhidden_dtype,
(mstype.float16, mstype.float32), self.name)
b_dtype = bhidden_dtype
elif h_dtype is not None:
validator.check_tensor_type_same({"init_h dtype": h_dtype},
(mstype.float16, mstype.float32), self.name)
validator.check_tensor_dtype_valid("init_h dtype", h_dtype,
(mstype.float16, mstype.float32), self.name)
b_dtype = h_dtype
return b_dtype, b_dtype, b_dtype, b_dtype, b_dtype, b_dtype

View File

@ -14,6 +14,7 @@
# ============================================================================
"""Operators for quantization."""
from functools import partial
import mindspore.context as context
from ..._checkparam import Validator as validator
@ -92,12 +93,10 @@ class MinMaxUpdatePerLayer(PrimitiveWithInfer):
return min_shape, max_shape
def infer_dtype(self, x_type, min_type, max_type):
valid_types = (mstype.float16, mstype.float32)
validator.check_tensor_type_same({"x": x_type}, valid_types, self.name)
validator.check_tensor_type_same(
{"min": min_type}, valid_types, self.name)
validator.check_tensor_type_same(
{"max": max_type}, valid_types, self.name)
tuple(map(partial(validator.check_tensor_dtype_valid,
valid_dtypes=(mstype.float16, mstype.float32), prim_name=self.name),
("x", "min", "max"),
(x_type, min_type, max_type)))
return min_type, max_type
@ -157,13 +156,10 @@ class MinMaxUpdatePerChannel(PrimitiveWithInfer):
return min_shape, max_shape
def infer_dtype(self, x_type, min_type, max_type):
valid_types = (mstype.float16, mstype.float32)
validator.check_tensor_type_same(
{"x": x_type}, valid_types, self.name)
validator.check_tensor_type_same(
{"min": min_type}, valid_types, self.name)
validator.check_tensor_type_same(
{"max": max_type}, valid_types, self.name)
tuple(map(partial(validator.check_tensor_dtype_valid,
valid_dtypes=(mstype.float16, mstype.float32), prim_name=self.name),
("x", "min", "max"),
(x_type, min_type, max_type)))
return min_type, max_type
@ -193,6 +189,7 @@ class FakeQuantWithMinMaxVars(PrimitiveWithInfer):
>>> input_tensor, min_tensor, max_tensor)
>>> output_tensor shape: (3, 16, 5, 5) data type: mstype.float32
"""
@prim_attr_register
def __init__(self,
num_bits=8,
@ -217,10 +214,10 @@ class FakeQuantWithMinMaxVars(PrimitiveWithInfer):
return x_shape
def infer_dtype(self, x_type, min_type, max_type):
valid_types = (mstype.float16, mstype.float32)
validator.check_tensor_type_same({'x': x_type}, valid_types, self.name)
validator.check_tensor_type_same({'min': min_type}, valid_types, self.name)
validator.check_tensor_type_same({'max': max_type}, valid_types, self.name)
tuple(map(partial(validator.check_tensor_dtype_valid,
valid_dtypes=(mstype.float16, mstype.float32), prim_name=self.name),
("x", "min", "max"),
(x_type, min_type, max_type)))
return x_type
@ -256,6 +253,7 @@ class FakeQuantWithMinMaxVarsGradient(PrimitiveWithInfer):
>>> min_gradient shape: (1,) data type: mstype.float32
>>> max_gradient shape: (1,) data type: mstype.float32
"""
@prim_attr_register
def __init__(self,
num_bits=8,
@ -281,11 +279,10 @@ class FakeQuantWithMinMaxVarsGradient(PrimitiveWithInfer):
return x_shape, min_shape, max_shape
def infer_dtype(self, dout_type, x_type, min_type, max_type):
valid_types = (mstype.float16, mstype.float32)
validator.check_tensor_type_same({'dout': dout_type}, valid_types, self.name)
validator.check_tensor_type_same({'x': x_type}, valid_types, self.name)
validator.check_tensor_type_same({'min': min_type}, valid_types, self.name)
validator.check_tensor_type_same({'max': max_type}, valid_types, self.name)
tuple(map(partial(validator.check_tensor_dtype_valid,
valid_dtypes=(mstype.float16, mstype.float32), prim_name=self.name),
('dout', "x", "min", "max"),
(dout_type, x_type, min_type, max_type)))
return x_type, min_type, max_type
@ -315,6 +312,7 @@ class FakeQuantWithMinMaxVarsPerChannel(PrimitiveWithInfer):
>>> input_tensor, min_tensor, max_tensor)
>>> output_tensor shape: (3, 16, 3, 4) data type: mstype.float32
"""
@prim_attr_register
def __init__(self,
num_bits=8,
@ -332,10 +330,10 @@ class FakeQuantWithMinMaxVarsPerChannel(PrimitiveWithInfer):
return x_shape
def infer_dtype(self, x_type, min_type, max_type):
valid_types = (mstype.float16, mstype.float32)
validator.check_tensor_type_same({'x': x_type}, valid_types, self.name)
validator.check_tensor_type_same({'min': min_type}, valid_types, self.name)
validator.check_tensor_type_same({'max': max_type}, valid_types, self.name)
tuple(map(partial(validator.check_tensor_dtype_valid,
valid_dtypes=(mstype.float16, mstype.float32), prim_name=self.name),
("x", "min", "max"),
(x_type, min_type, max_type)))
return x_type
@ -372,6 +370,7 @@ class FakeQuantWithMinMaxVarsPerChannelGradient(PrimitiveWithInfer):
>>> min_gradient shape: (4,) data type: mstype.float32
>>> max_gradient shape: (4,) data type: mstype.float32
"""
@prim_attr_register
def __init__(self,
num_bits=8,
@ -390,11 +389,10 @@ class FakeQuantWithMinMaxVarsPerChannelGradient(PrimitiveWithInfer):
return x_shape, min_shape, max_shape
def infer_dtype(self, dout_type, x_type, min_type, max_type):
valid_types = (mstype.float16, mstype.float32)
validator.check_tensor_type_same({'dout': dout_type}, valid_types, self.name)
validator.check_tensor_type_same({'x': x_type}, valid_types, self.name)
validator.check_tensor_type_same({'min': min_type}, valid_types, self.name)
validator.check_tensor_type_same({'max': max_type}, valid_types, self.name)
tuple(map(partial(validator.check_tensor_dtype_valid,
valid_dtypes=(mstype.float16, mstype.float32), prim_name=self.name),
("dout", "x", "min", "max"),
(dout_type, x_type, min_type, max_type)))
return x_type, min_type, max_type
@ -468,14 +466,12 @@ class FakeQuantPerLayer(PrimitiveWithInfer):
def infer_dtype(self, x_type, min_type, max_type):
if context.get_context('device_target') == "GPU":
valid_types = (mstype.float32,)
valid_dtypes = (mstype.float32,)
else:
valid_types = (mstype.float16, mstype.float32)
validator.check_tensor_type_same({"x": x_type}, valid_types, self.name)
validator.check_tensor_type_same(
{"min": min_type}, valid_types, self.name)
validator.check_tensor_type_same(
{"max": max_type}, valid_types, self.name)
valid_dtypes = (mstype.float16, mstype.float32)
tuple(map(partial(validator.check_tensor_dtype_valid, valid_dtypes=valid_dtypes, prim_name=self.name),
("x", "min", "max"),
(x_type, min_type, max_type)))
return x_type
@ -525,16 +521,12 @@ class FakeQuantPerLayerGrad(PrimitiveWithInfer):
def infer_dtype(self, dout_type, x_type, min_type, max_type):
if context.get_context('device_target') == "GPU":
valid_types = (mstype.float32,)
valid_dtypes = (mstype.float32,)
else:
valid_types = (mstype.float16, mstype.float32)
validator.check_tensor_type_same(
{"dout": dout_type}, valid_types, self.name)
validator.check_tensor_type_same({"x": x_type}, valid_types, self.name)
validator.check_tensor_type_same(
{"min": min_type}, valid_types, self.name)
validator.check_tensor_type_same(
{"max": max_type}, valid_types, self.name)
valid_dtypes = (mstype.float16, mstype.float32)
tuple(map(partial(validator.check_tensor_dtype_valid, valid_dtypes=valid_dtypes, prim_name=self.name),
("dout", "x", "min", "max"),
(dout_type, x_type, min_type, max_type)))
return dout_type
@ -623,14 +615,12 @@ class FakeQuantPerChannel(PrimitiveWithInfer):
def infer_dtype(self, x_type, min_type, max_type):
if context.get_context('device_target') == "GPU":
valid_types = (mstype.float32,)
valid_dtypes = (mstype.float32,)
else:
valid_types = (mstype.float16, mstype.float32)
validator.check_tensor_type_same({"x": x_type}, valid_types, self.name)
validator.check_tensor_type_same(
{"min": min_type}, valid_types, self.name)
validator.check_tensor_type_same(
{"max": max_type}, valid_types, self.name)
valid_dtypes = (mstype.float16, mstype.float32)
tuple(map(partial(validator.check_tensor_dtype_valid, valid_dtypes=valid_dtypes, prim_name=self.name),
("x", "min", "max"),
(x_type, min_type, max_type)))
return x_type
@ -680,16 +670,12 @@ class FakeQuantPerChannelGrad(PrimitiveWithInfer):
def infer_dtype(self, dout_type, x_type, min_type, max_type):
if context.get_context('device_target') == "GPU":
valid_types = (mstype.float32,)
valid_dtypes = (mstype.float32,)
else:
valid_types = (mstype.float16, mstype.float32)
validator.check_tensor_type_same(
{"dout": dout_type}, valid_types, self.name)
validator.check_tensor_type_same({"x": x_type}, valid_types, self.name)
validator.check_tensor_type_same(
{"min": min_type}, valid_types, self.name)
validator.check_tensor_type_same(
{"max": max_type}, valid_types, self.name)
valid_dtypes = (mstype.float16, mstype.float32)
tuple(map(partial(validator.check_tensor_dtype_valid, valid_dtypes=valid_dtypes, prim_name=self.name),
("dout", "x", "min", "max"),
(dout_type, x_type, min_type, max_type)))
return dout_type
@ -750,8 +736,8 @@ class BatchNormFold(PrimitiveWithInfer):
validator.check("input type", x_type, "mean type", mean_type)
validator.check("input type", x_type, "variance type", variance_type)
args = {"x": x_type, "mean": mean_type, "variance": variance_type}
validator.check_tensor_type_same(args, (mstype.float16, mstype.float32), self.name)
validator.check_tensor_type_same({"global_step": global_step_type}, (mstype.int32,), self.name)
validator.check_tensors_dtypes_same_and_valid(args, (mstype.float16, mstype.float32), self.name)
validator.check_tensor_dtype_valid("global_step", global_step_type, (mstype.int32,), self.name)
return x_type, x_type, x_type, x_type
@ -797,8 +783,8 @@ class BatchNormFoldGrad(PrimitiveWithInfer):
global_step_type):
args = {"input": x_type, "d_batch_mean": d_batch_mean_type, "d_batch_std": d_batch_std_type,
"batch_mean": batch_mean_type, "batch_std": batch_std_type}
validator.check_tensor_type_same(args, (mstype.float16, mstype.float32), self.name)
validator.check_tensor_type_same({"global_step": global_step_type}, (mstype.int32,), self.name)
validator.check_tensors_dtypes_same_and_valid(args, (mstype.float16, mstype.float32), self.name)
validator.check_tensor_dtype_valid("global_step", global_step_type, (mstype.int32,), self.name)
return x_type
@ -841,7 +827,7 @@ class CorrectionMul(PrimitiveWithInfer):
def infer_dtype(self, x_type, batch_std_type, running_std_type):
args = {"x": x_type, "batch_std": batch_std_type, "running_std": running_std_type}
validator.check_tensor_type_same(args, (mstype.float16, mstype.float32), self.name)
validator.check_tensors_dtypes_same_and_valid(args, (mstype.float16, mstype.float32), self.name)
return x_type
@ -879,7 +865,7 @@ class CorrectionMulGrad(PrimitiveWithInfer):
def infer_dtype(self, dout_type, x_type, gamma_type, running_std_type):
args = {"dout": dout_type, "x": x_type, "gamma": gamma_type, "running_std": running_std_type}
validator.check_tensor_type_same(args, (mstype.float16, mstype.float32), self.name)
validator.check_tensors_dtypes_same_and_valid(args, (mstype.float16, mstype.float32), self.name)
if context.get_context('device_target') == "Ascend":
return x_type, x_type
return x_type, gamma_type
@ -972,8 +958,8 @@ class BatchNormFold2(PrimitiveWithInfer):
running_mean_type, global_step_type):
args = {"batch_std": batch_std_type, "running_std": running_std_type, "batch_mean": batch_mean_type,
"beta": beta_type, "running_mean": running_mean_type, "gamma": gamma_type, "x": x_type}
validator.check_tensor_type_same(args, (mstype.float16, mstype.float32), self.name)
validator.check_tensor_type_same({"global_step": global_step_type}, (mstype.int32,), self.name)
validator.check_tensors_dtypes_same_and_valid(args, (mstype.float16, mstype.float32), self.name)
validator.check_tensor_dtype_valid("global_step", global_step_type, (mstype.int32,), self.name)
return x_type
@ -1031,8 +1017,8 @@ class BatchNormFold2Grad(PrimitiveWithInfer):
"dout type", dout_type)
args = {"batch_std": batch_std_type, "batch_mean": batch_mean_type, "gamma": gamma_type,
"running_std": running_std_type, "running_mean": running_mean_type, "dout": dout_type}
validator.check_tensor_type_same(args, (mstype.float16, mstype.float32), self.name)
validator.check_tensor_type_same({"global_step": global_step_type}, (mstype.int32,), self.name)
validator.check_tensors_dtypes_same_and_valid(args, (mstype.float16, mstype.float32), self.name)
validator.check_tensor_dtype_valid("global_step", global_step_type, (mstype.int32,), self.name)
return gamma_type, gamma_type, gamma_type, gamma_type, gamma_type
@ -1061,7 +1047,7 @@ class BatchNormFoldD(PrimitiveWithInfer):
validator.check("input type", x_type, "mean type", mean_type)
validator.check("input type", x_type, "variance type", variance_type)
args = {"x": x_type, "mean": mean_type, "variance": variance_type}
validator.check_tensor_type_same(args, (mstype.float16, mstype.float32), self.name)
validator.check_tensors_dtypes_same_and_valid(args, (mstype.float16, mstype.float32), self.name)
return x_type, x_type, x_type, x_type, x_type, x_type, x_type
@ -1090,8 +1076,7 @@ class BatchNormFoldGradD(PrimitiveWithInfer):
validator.check("input type", x_type, "d_batch_std type", d_batch_std_type)
validator.check("input type", x_type, "batch_mean type", batch_mean_type)
validator.check("input type", x_type, "batch_std type", batch_std_type)
args = {"input type": x_type}
validator.check_tensor_type_same(args, (mstype.float16, mstype.float32), self.name)
validator.check_tensor_dtype_valid("input type", x_type, (mstype.float16, mstype.float32), self.name)
return x_type
@ -1136,7 +1121,7 @@ class BatchNormFold2_D(PrimitiveWithInfer):
def infer_dtype(self, x_type, beta_type, gamma_type, batch_std_type, running_std_type, batch_mean_type):
args = {"batch_std": batch_std_type, "running_std": running_std_type, "batch_mean": batch_mean_type,
"beta": beta_type, "gamma": gamma_type, "x": x_type}
validator.check_tensor_type_same(args, (mstype.float16, mstype.float32), self.name)
validator.check_tensors_dtypes_same_and_valid(args, (mstype.float16, mstype.float32), self.name)
return x_type
@ -1174,7 +1159,7 @@ class BatchNormFold2GradD(PrimitiveWithInfer):
"dout type", dout_type)
args = {"batch_std": batch_std_type, "batch_mean": batch_mean_type, "gamma": gamma_type,
"running_std": running_std_type, "dout": dout_type}
validator.check_tensor_type_same(args, (mstype.float16, mstype.float32), self.name)
validator.check_tensors_dtypes_same_and_valid(args, (mstype.float16, mstype.float32), self.name)
return gamma_type, gamma_type, gamma_type, gamma_type

View File

@ -165,7 +165,7 @@ class CusFusedAbsMax1(PrimitiveWithInfer):
def infer_shape(self, data1_shape):
ll = []
if len(data1_shape) == 2:
ll = [1,]
ll = [1]
else:
ll = [32, 64]
return ll
@ -497,6 +497,7 @@ class Im2Col(PrimitiveWithInfer):
>>> img2col = P.CusMatMulCubeDenseLeft(kernel_size=7, pad=3, stride=2)
>>> output = img2col(input_x)
"""
@prim_attr_register
def __init__(self,
kernel_size,
@ -556,9 +557,8 @@ class Im2Col(PrimitiveWithInfer):
return out_shape
def infer_dtype(self, x_dtype):
args = {'x': x_dtype}
valid_types = [mstype.float16, mstype.float32]
validator.check_tensor_type_same(args, valid_types, self.name)
valid_dtypes = [mstype.float16, mstype.float32]
validator.check_tensor_dtype_valid('x', x_dtype, valid_dtypes, self.name)
return x_dtype
@ -602,14 +602,17 @@ class UpdateThorGradient(PrimitiveWithInfer):
return x2_shape
def infer_dtype(self, x1_dtype, x2_dtype, x3_dtype):
validator.check_tensor_type_same({'x1_dtype': x1_dtype, 'x2_dtype': x2_dtype, 'x3_dtype': x3_dtype},
[mstype.float32], self.name)
validator.check_tensors_dtypes_same_and_valid(
{'x1_dtype': x1_dtype, 'x2_dtype': x2_dtype, 'x3_dtype': x3_dtype},
[mstype.float32], self.name)
return x2_dtype
class Cholesky(PrimitiveWithInfer):
"""
Inner API for resnet50 THOR GPU backend
"""
@prim_attr_register
def __init__(self, split_dim=0):
self.init_prim_io_names(inputs=['x1'], outputs=['y'])
@ -634,13 +637,15 @@ class Cholesky(PrimitiveWithInfer):
return out_shape
def infer_dtype(self, x1_dtype):
validator.check_tensor_type_same({'x1_dtype': x1_dtype}, [mstype.float32], self.name)
validator.check_tensor_dtype_valid('x1', x1_dtype, [mstype.float32], self.name)
return x1_dtype
class DetTriangle(PrimitiveWithInfer):
"""
Calculate the determinant of triangle matrices
"""
@prim_attr_register
def __init__(self, fill_mode=0):
self.init_prim_io_names(inputs=['x1'], outputs=['y'])
@ -653,5 +658,5 @@ class DetTriangle(PrimitiveWithInfer):
return out_shape
def infer_dtype(self, x1_dtype):
validator.check_tensor_type_same({'x1_dtype': x1_dtype}, [mstype.float32], self.name)
validator.check_tensor_dtype_valid('x1', x1_dtype, [mstype.float32], self.name)
return x1_dtype

View File

@ -63,9 +63,9 @@ class _ScatterOp(PrimitiveWithInfer):
return x_shape
def infer_dtype(self, x_dtype, indices_dtype, updates_dtype):
validator.check_tensor_type_same({'indices': indices_dtype}, [mstype.int32], self.name)
validator.check_tensor_dtype_valid('indices', indices_dtype, [mstype.int32], self.name)
args = {"x": x_dtype, "updates": updates_dtype}
validator.check_tensor_type_same(args, mstype.number_type, self.name)
validator.check_tensors_dtypes_same_and_valid(args, mstype.number_type, self.name)
return x_dtype
@ -73,6 +73,7 @@ class _ScatterNdOp(_ScatterOp):
"""
Defines _ScatterNd operators
"""
def _check_scatter_shape(self, x_shape, indices_shape, updates_shape, prim_name):
validator.check('the dimension of x', len(x_shape),
'the dimension of indices', indices_shape[-1], Rel.GE)
@ -627,6 +628,7 @@ class Unique(Primitive):
>>> out = P.Unique()(x)
(Tensor([1, 2, 5], mindspore.int32), Tensor([0, 1, 2, 1], mindspore.int32))
"""
@prim_attr_register
def __init__(self):
self.init_prim_io_names(inputs=['x'], outputs=['output'])
@ -661,11 +663,11 @@ class GatherV2(PrimitiveWithCheck):
def __init__(self):
"""Initialize index_select"""
self.init_prim_io_names(inputs=['params', 'indices', 'axis'], outputs=['output'])
self.add_prim_attr("dynamic_shape_depends", [2,])
self.add_prim_attr("dynamic_shape_depends", [2])
def __check__(self, params, indices, axis):
validator.check_subclass("params", params['dtype'], mstype.tensor, self.name)
validator.check_tensor_type_same({"indices": indices['dtype']}, mstype.int_type, self.name)
validator.check_tensor_dtype_valid("indices", indices['dtype'], mstype.int_type, self.name)
validator.check_subclass("axis", axis['dtype'], mstype.int_, self.name)
axis_v = axis['value']
params_shp = params['shape']
@ -727,6 +729,7 @@ class Padding(PrimitiveWithInfer):
>>> out = P.Padding(pad_dim_size)(x)
[[8, 0, 0, 0], [10, 0, 0, 0]]
"""
@prim_attr_register
def __init__(self, pad_dim_size=8):
"""Initialize padding"""
@ -766,12 +769,13 @@ class UniqueWithPad(PrimitiveWithInfer):
>>> out = P.UniqueWithPad()(x, pad_num)
([1, 5, 4, 3, 2, 8, 8, 8, 8, 8], [0, 0, 1, 1, 2, 2, 3, 3, 4, 4])
"""
@prim_attr_register
def __init__(self):
"""init UniqueWithPad"""
def __infer__(self, x, pad_num):
validator.check_tensor_type_same({"x": x['dtype']}, [mstype.int32, mstype.int64], self.name)
validator.check_tensor_dtype_valid("x", x['dtype'], [mstype.int32, mstype.int64], self.name)
validator.check_subclass("pad_num", pad_num['dtype'], [mstype.int32, mstype.int64], self.name)
x_shape = list(x['shape'])
validator.check("rank of x", len(x_shape), "expected", 1, Rel.EQ, self.name)
@ -903,7 +907,7 @@ class TruncatedNormal(PrimitiveWithInfer):
def __init__(self, seed=0, dtype=mstype.float32):
"""Initialize TruncatedNormal"""
validator.check_value_type('seed', seed, [int], self.name)
validator.check_type_same({'dtype': dtype}, mstype.number_type, self.name)
validator.check_types_same_and_valid({'dtype': dtype}, mstype.number_type, self.name)
def __infer__(self, shape):
shape_value = shape['value']
@ -984,10 +988,10 @@ class Fill(PrimitiveWithInfer):
validator.check_value_type("value", x['value'], [numbers.Number, bool], self.name)
for i, item in enumerate(dims['value']):
validator.check_positive_int(item, f'dims[{i}]', self.name)
valid_types = [mstype.bool_, mstype.int8, mstype.int16, mstype.int32, mstype.int64,
mstype.uint8, mstype.uint32, mstype.uint64,
mstype.float16, mstype.float32, mstype.float64]
validator.check_type_same({"value": dtype['value']}, valid_types, self.name)
valid_dtypes = [mstype.bool_, mstype.int8, mstype.int16, mstype.int32, mstype.int64,
mstype.uint8, mstype.uint32, mstype.uint64,
mstype.float16, mstype.float32, mstype.float64]
validator.check_types_same_and_valid({"value": dtype['value']}, valid_dtypes, self.name)
x_nptype = mstype.dtype_to_nptype(dtype['value'])
ret = np.full(dims['value'], x['value'], x_nptype)
out = {
@ -1026,7 +1030,7 @@ class OnesLike(PrimitiveWithInfer):
return x_shape
def infer_dtype(self, x_dtype):
validator.check_tensor_type_same({'x': x_dtype}, mstype.number_type + (mstype.bool_,), self.name)
validator.check_tensor_dtype_valid('x', x_dtype, mstype.number_type + (mstype.bool_,), self.name)
return x_dtype
@ -1059,7 +1063,7 @@ class ZerosLike(PrimitiveWithInfer):
return x_shape
def infer_dtype(self, x_dtype):
validator.check_tensor_type_same({'x': x_dtype}, mstype.number_type + (mstype.bool_,), self.name)
validator.check_tensor_dtype_valid('x', x_dtype, mstype.number_type + (mstype.bool_,), self.name)
return x_dtype
@ -1264,7 +1268,7 @@ class Argmax(PrimitiveWithInfer):
"""Initialize Argmax"""
self.init_prim_io_names(inputs=['x'], outputs=['output'])
validator.check_value_type("axis", axis, [int], self.name)
validator.check_type_same({'output': output_type}, [mstype.int32], self.name)
validator.check_types_same_and_valid({'output': output_type}, [mstype.int32], self.name)
self.axis = axis
self.add_prim_attr('output_type', output_type)
@ -1547,7 +1551,7 @@ class UnsortedSegmentSum(PrimitiveWithInfer):
def __init__(self):
"""Initialize UnsortedSegmentSum"""
self.init_prim_io_names(inputs=['x', 'segment_ids', 'num_segments'], outputs=['y'])
self.add_prim_attr("dynamic_shape_depends", [2,])
self.add_prim_attr("dynamic_shape_depends", [2])
def __infer__(self, x, segment_ids, num_segments):
x_type = x['dtype']
@ -1570,7 +1574,7 @@ class UnsortedSegmentSum(PrimitiveWithInfer):
num_segments_type = num_segments['dtype']
validator.check_subclass("num_segments", num_segments_type, [mstype.tensor, mstype.number], self.name)
if isinstance(num_segments_type, type(mstype.tensor)):
validator.check_tensor_type_same({"num_segments": num_segments_type}, [mstype.int32], self.name)
validator.check_tensor_dtype_valid("num_segments", num_segments_type, [mstype.int32], self.name)
shp = [-1]
else:
validator.check_value_type('num_segments', num_segments_v, [int], self.name)
@ -1623,8 +1627,8 @@ class UnsortedSegmentMin(PrimitiveWithInfer):
x_shape = x['shape']
segment_ids_shape = segment_ids['shape']
valid_type = [mstype.float16, mstype.float32, mstype.int32]
validator.check_tensor_type_same({"x": x['dtype']}, valid_type, self.name)
validator.check_tensor_type_same({"segment_ids": segment_ids['dtype']}, [mstype.int32], self.name)
validator.check_tensor_dtype_valid("x", x['dtype'], valid_type, self.name)
validator.check_tensor_dtype_valid("segment_ids", segment_ids['dtype'], [mstype.int32], self.name)
validator.check_equal_int(len(segment_ids_shape), 1, "rank of segment_ids_shape", self.name)
validator.check(f'first shape of input_x', x_shape[0],
'length of segments_id', segment_ids_shape[0], Rel.EQ, self.name)
@ -1673,8 +1677,8 @@ class UnsortedSegmentMax(PrimitiveWithInfer):
x_shape = x['shape']
segment_ids_shape = segment_ids['shape']
valid_type = [mstype.float16, mstype.float32, mstype.int32]
validator.check_tensor_type_same({"x": x['dtype']}, valid_type, self.name)
validator.check_tensor_type_same({"segment_ids": segment_ids['dtype']}, [mstype.int32], self.name)
validator.check_tensor_dtype_valid("x", x['dtype'], valid_type, self.name)
validator.check_tensors_dtypes_same_and_valid({"segment_ids": segment_ids['dtype']}, [mstype.int32], self.name)
validator.check_equal_int(len(segment_ids_shape), 1, "rank of segment_ids_shape", self.name)
validator.check(f'first shape of input_x', x_shape[0],
'length of segments_id', segment_ids_shape[0], Rel.EQ, self.name)
@ -1726,8 +1730,8 @@ class UnsortedSegmentProd(PrimitiveWithInfer):
validator.check_subclass("input_x", x_type, mstype.tensor, self.name)
validator.check_value_type("x_shape", x_shape, [list], self.name)
valid_type = [mstype.float16, mstype.float32, mstype.int32]
validator.check_tensor_type_same({"x": x['dtype']}, valid_type, self.name)
validator.check_tensor_type_same({"segment_ids": segment_ids['dtype']}, [mstype.int32], self.name)
validator.check_tensor_dtype_valid("x", x['dtype'], valid_type, self.name)
validator.check_tensor_dtype_valid("segment_ids", segment_ids['dtype'], [mstype.int32], self.name)
validator.check_equal_int(len(segment_ids_shape), 1, "rank of segment_ids_shape", self.name)
validator.check(f'first shape of input_x', x_shape[0],
'length of segments_id', segment_ids_shape[0], Rel.EQ, self.name)
@ -1833,7 +1837,7 @@ class ParallelConcat(PrimitiveWithInfer):
validator.check_int(len(x_shp), 1, Rel.GE, f'x_shp length', self.name)
args = {f"x_type[{i}]": elem for i, elem in enumerate(x_type)}
validator.check_tensor_type_same(args, mstype.number_type + (mstype.bool_,), self.name)
validator.check_tensors_dtypes_same_and_valid(args, mstype.number_type + (mstype.bool_,), self.name)
first_elem = x_shp[0]
for i, elem in enumerate(x_shp[1:]):
@ -2070,7 +2074,7 @@ class ReverseV2(PrimitiveWithInfer):
return x_shape
def infer_dtype(self, x_dtype):
validator.check_tensor_type_same({'x': x_dtype}, (mstype.bool_,) + mstype.number_type, self.name)
validator.check_tensor_dtype_valid('x', x_dtype, (mstype.bool_,) + mstype.number_type, self.name)
return x_dtype
@ -2100,7 +2104,7 @@ class Rint(PrimitiveWithInfer):
return x_shape
def infer_dtype(self, x_dtype):
validator.check_tensor_type_same({'x': x_dtype}, [mstype.float16, mstype.float32], self.name)
validator.check_tensor_dtype_valid('x', x_dtype, [mstype.float16, mstype.float32], self.name)
return x_dtype
@ -2167,7 +2171,7 @@ class Select(PrimitiveWithInfer):
self.add_prim_attr('T', x_type)
validator.check_subclass("x_type", x_type, mstype.tensor, self.name)
validator.check_subclass("y_type", y_type, mstype.tensor, self.name)
validator.check_tensor_type_same({"cond": cond_type}, [mstype.bool_], self.name)
validator.check_tensor_dtype_valid("cond", cond_type, [mstype.bool_], self.name)
if x_type != y_type:
raise TypeError('\'%s\' the x_type %s must be the same as y_type %s.' % (self.name, x_type, y_type))
return x_type
@ -2542,7 +2546,7 @@ class Eye(PrimitiveWithInfer):
validator.check_positive_int(n, "n", self.name)
validator.check_positive_int(m, "m", self.name)
args = {"dtype": t}
validator.check_type_same(args, mstype.number_type + (mstype.bool_,), self.name)
validator.check_types_same_and_valid(args, mstype.number_type + (mstype.bool_,), self.name)
np_type = mstype.dtype_to_nptype(t)
ret = np.eye(n, m, dtype=np_type)
return Tensor(ret)
@ -2581,7 +2585,7 @@ class ScatterNd(PrimitiveWithInfer):
def __infer__(self, indices, update, shape):
shp = shape['value']
validator.check_subclass("update_dtype", update['dtype'], mstype.tensor, self.name)
validator.check_tensor_type_same({"indices": indices['dtype']}, [mstype.int32], self.name)
validator.check_tensor_dtype_valid("indices", indices['dtype'], [mstype.int32], self.name)
validator.check_value_type("shape", shp, [tuple], self.name)
for i, x in enumerate(shp):
validator.check_positive_int(x, f'shape[{i}]', self.name)
@ -2632,14 +2636,13 @@ class ResizeNearestNeighbor(PrimitiveWithInfer):
validator.check_non_negative_int(value, f'{i}th value of size', self.name)
self.init_prim_io_names(inputs=['image_in'], outputs=['image_out'])
def infer_shape(self, x):
validator.check('the dimension of input_x', len(x), '', 4, Rel.EQ, self.name)
return tuple(x)[:-2] + tuple(self.size)
def infer_shape(self, x_shape):
validator.check('the dimension of input_x', len(x_shape), '', 4, Rel.EQ, self.name)
return tuple(x_shape)[:-2] + tuple(self.size)
def infer_dtype(self, x):
validator.check_subclass("x", x, mstype.tensor, self.name)
validator.check_tensor_type_same({"x": x}, mstype.number_type, self.name)
return x
def infer_dtype(self, x_dtype):
validator.check_tensor_dtype_valid("x", x_dtype, mstype.number_type, self.name)
return x_dtype
class GatherNd(PrimitiveWithInfer):
@ -2674,8 +2677,7 @@ class GatherNd(PrimitiveWithInfer):
return indices_shape[:-1] + x_shape[indices_shape[-1]:]
def infer_dtype(self, x_dtype, indices_dtype):
validator.check_subclass("x_dtype", x_dtype, mstype.tensor, self.name)
validator.check_tensor_type_same({"indices": indices_dtype}, mstype.int_type, self.name)
validator.check_tensor_dtype_valid("indices", indices_dtype, mstype.int_type, self.name)
return x_dtype
@ -2715,9 +2717,9 @@ class TensorScatterUpdate(PrimitiveWithInfer):
return x_shape
def infer_dtype(self, x_dtype, indices_dtype, value_dtype):
validator.check_tensor_type_same({'indices': indices_dtype}, [mstype.int32], self.name)
validator.check_tensor_dtype_valid('indices', indices_dtype, [mstype.int32], self.name)
args = {"x": x_dtype, "value": value_dtype}
validator.check_tensor_type_same(args, (mstype.bool_,) + mstype.number_type, self.name)
validator.check_tensors_dtypes_same_and_valid(args, (mstype.bool_,) + mstype.number_type, self.name)
return x_dtype
@ -2763,9 +2765,9 @@ class ScatterUpdate(_ScatterOp):
self.init_prim_io_names(inputs=['x', 'indices', 'updates'], outputs=['y'])
def infer_dtype(self, x_dtype, indices_dtype, value_dtype):
validator.check_tensor_type_same({'indices': indices_dtype}, [mstype.int32], self.name)
validator.check_tensor_dtype_valid('indices', indices_dtype, [mstype.int32], self.name)
args = {"x": x_dtype, "value": value_dtype}
validator.check_tensor_type_same(args, (mstype.bool_,) + mstype.number_type, self.name)
validator.check_tensors_dtypes_same_and_valid(args, (mstype.bool_,) + mstype.number_type, self.name)
return x_dtype
@ -2802,7 +2804,6 @@ class ScatterNdUpdate(_ScatterNdOp):
[0.4 2.2 -3.2]]
"""
@prim_attr_register
def __init__(self, use_locking=True):
"""Initialize ScatterNdUpdate"""
@ -2810,9 +2811,9 @@ class ScatterNdUpdate(_ScatterNdOp):
self.init_prim_io_names(inputs=['x', 'indices', 'value'], outputs=['y'])
def infer_dtype(self, x_dtype, indices_dtype, value_dtype):
validator.check_tensor_type_same({'indices': indices_dtype}, [mstype.int32], self.name)
validator.check_tensor_dtype_valid('indices', indices_dtype, [mstype.int32], self.name)
args = {"x": x_dtype, "value": value_dtype}
validator.check_tensor_type_same(args, (mstype.bool_,) + mstype.number_type, self.name)
validator.check_tensors_dtypes_same_and_valid(args, (mstype.bool_,) + mstype.number_type, self.name)
return x_dtype
@ -3131,9 +3132,9 @@ class ScatterNonAliasingAdd(_ScatterNdOp):
self.init_prim_io_names(inputs=['x', 'indices', 'updates'], outputs=['y'])
def infer_dtype(self, x_dtype, indices_dtype, updates_dtype):
validator.check_tensor_type_same({'indices': indices_dtype}, [mstype.int32], self.name)
validator.check_tensor_dtype_valid('indices', indices_dtype, [mstype.int32], self.name)
args = {"x": x_dtype, "updates": updates_dtype}
validator.check_tensor_type_same(args, [mstype.float16, mstype.float32, mstype.int32], self.name)
validator.check_tensors_dtypes_same_and_valid(args, [mstype.float16, mstype.float32, mstype.int32], self.name)
return x_dtype
@ -3304,7 +3305,7 @@ class SpaceToBatch(PrimitiveWithInfer):
self.paddings = paddings
def infer_dtype(self, x_dtype):
validator.check_tensor_type_same({'input_x': x_dtype}, mstype.number_type, self.name)
validator.check_tensor_dtype_valid('input_x', x_dtype, mstype.number_type, self.name)
return x_dtype
def infer_shape(self, x_shape):
@ -3376,7 +3377,7 @@ class BatchToSpace(PrimitiveWithInfer):
self.crops = crops
def infer_dtype(self, x_dtype):
validator.check_tensor_type_same({'input_x': x_dtype}, mstype.number_type, self.name)
validator.check_tensor_dtype_valid('input_x', x_dtype, mstype.number_type, self.name)
return x_dtype
def infer_shape(self, x_shape):
@ -3465,7 +3466,7 @@ class SpaceToBatchND(PrimitiveWithInfer):
self.add_prim_attr("paddings", paddings_append)
def infer_dtype(self, x_dtype):
validator.check_tensor_type_same({'input_x': x_dtype}, mstype.number_type, self.name)
validator.check_tensor_dtype_valid('input_x', x_dtype, mstype.number_type, self.name)
return x_dtype
def infer_shape(self, x_shape):
@ -3558,7 +3559,7 @@ class BatchToSpaceND(PrimitiveWithInfer):
self.add_prim_attr("crops", crops_append)
def infer_dtype(self, x_dtype):
validator.check_tensor_type_same({'input_x': x_dtype}, mstype.number_type, self.name)
validator.check_tensor_dtype_valid('input_x', x_dtype, mstype.number_type, self.name)
return x_dtype
def infer_shape(self, x_shape):
@ -3721,7 +3722,6 @@ class Meshgrid(PrimitiveWithInfer):
out_shape = tuple(tuple(shape_0) for _ in range(n))
return out_shape
def infer_dtype(self, x_type):
validator.check_subclass("input_x[0]", x_type[0], mstype.tensor, self.name)
n = len(x_type)
@ -3729,6 +3729,7 @@ class Meshgrid(PrimitiveWithInfer):
validator.check('x_type[%d]' % i, x_type[i], 'base', x_type[0], Rel.EQ, self.name, TypeError)
return x_type
class InplaceUpdate(PrimitiveWithInfer):
r"""
Updates specified rows with values in `v`.
@ -3771,7 +3772,7 @@ class InplaceUpdate(PrimitiveWithInfer):
def infer_dtype(self, x_dtype, v_dtype):
args = {'x': x_dtype, 'v': v_dtype}
valid_type = [mstype.int32, mstype.float16, mstype.float32]
validator.check_tensor_type_same(args, valid_type, self.name)
validator.check_tensors_dtypes_same_and_valid(args, valid_type, self.name)
return x_dtype
def infer_shape(self, x_shape, v_shape):
@ -3831,8 +3832,8 @@ class ReverseSequence(PrimitiveWithInfer):
return x
def infer_dtype(self, x, seq_lengths):
validator.check_tensor_type_same({"x_dtype": x}, mstype.number_type + (mstype.bool_,), self.name)
validator.check_tensor_type_same({"seq_lengths_dtype": seq_lengths}, [mstype.int32, mstype.int64], self.name)
validator.check_tensor_dtype_valid("x_dtype", x, mstype.number_type + (mstype.bool_,), self.name)
validator.check_tensor_dtype_valid("seq_lengths_dtype", seq_lengths, [mstype.int32, mstype.int64], self.name)
return x
@ -3899,9 +3900,9 @@ class EditDistance(PrimitiveWithInfer):
validator.check_const_input('truth_shape', truth_shape['value'], self.name)
args_int = {"hypothesis_indices": h_indices['dtype'], "hypothesis_shape": h_shape['dtype'],
"truth_indices": truth_indices['dtype'], "truth_shape": truth_shape['dtype']}
validator.check_tensor_type_same(args_int, [mstype.int64], self.name)
validator.check_tensors_dtypes_same_and_valid(args_int, [mstype.int64], self.name)
args = {"hypothesis_values": h_values['dtype'], "truth_values": truth_values['dtype']}
validator.check_tensor_type_same(args, mstype.number_type, self.name)
validator.check_tensors_dtypes_same_and_valid(args, mstype.number_type, self.name)
hypothesis_indices_shp, truth_indices_shp = h_indices['shape'], truth_indices['shape']
validator.check("hypothesis_indices rank", len(hypothesis_indices_shp), "expected", 2, Rel.EQ, self.name)
@ -3941,6 +3942,7 @@ class TransShape(PrimitiveWithInfer):
Outputs:
Tensor, a tensor whose data type is same as 'input_x', and the shape is the same as the `out_shape`.
"""
@prim_attr_register
def __init__(self):
self.__setattr_flag__ = True
@ -3948,7 +3950,7 @@ class TransShape(PrimitiveWithInfer):
def __infer__(self, x, shape):
shp = shape['value']
dtype = x['dtype']
validator.check_tensor_type_same({'x': dtype}, mstype.number_type + (mstype.bool_,), self.name)
validator.check_tensor_dtype_valid('x', dtype, mstype.number_type + (mstype.bool_,), self.name)
self.add_prim_attr('out_shape', tuple(shp))
return {'shape': shp,
'dtype': dtype,
@ -3989,7 +3991,7 @@ class Sort(PrimitiveWithInfer):
return x_shape, x_shape
def infer_dtype(self, x_dtype):
validator.check_tensor_type_same({"x_dtype": x_dtype}, [mstype.float32, mstype.float16], self.name)
validator.check_tensor_dtype_valid("x_dtype", x_dtype, [mstype.float32, mstype.float16], self.name)
return x_dtype, mstype.tensor_type(mstype.int32)
@ -4019,6 +4021,7 @@ class EmbeddingLookup(PrimitiveWithInfer):
>>> out = P.EmbeddingLookup()(input_params, input_indices, offset)
[[[10, 11], [0 ,0]], [[0, 0], [10, 11]]]
"""
@prim_attr_register
def __init__(self):
"""Initialize index_select"""
@ -4028,7 +4031,7 @@ class EmbeddingLookup(PrimitiveWithInfer):
def __infer__(self, params, indices, offset):
validator.check_subclass("params", params['dtype'], mstype.tensor, self.name)
validator.check_tensor_type_same({"indices": indices['dtype']}, mstype.int_type, self.name)
validator.check_tensor_dtype_valid("indices", indices['dtype'], mstype.int_type, self.name)
validator.check_subclass("offset", offset['dtype'], mstype.int_, self.name)
params_shp = params['shape']
if len(params_shp) != 2:
@ -4060,6 +4063,7 @@ class GatherD(PrimitiveWithInfer):
>>> out = P.GatherD()(x, dim, index)
[[1, 1], [4, 3]]
"""
@prim_attr_register
def __init__(self):
"""Initialize GatherD"""
@ -4067,7 +4071,7 @@ class GatherD(PrimitiveWithInfer):
def __infer__(self, x, dim, index):
validator.check_subclass("x", x['dtype'], mstype.tensor, self.name)
validator.check_tensor_type_same({"index": index['dtype']}, [mstype.int32, mstype.int64], self.name)
validator.check_tensor_dtype_valid("index", index['dtype'], [mstype.int32, mstype.int64], self.name)
validator.check_subclass("dim", dim['dtype'], mstype.int32, self.name)
x_shp = x['shape']
idx_shp = index['shape']
@ -4103,6 +4107,7 @@ class Identity(PrimitiveWithInfer):
>>> y = P.Identity()(x)
[1, 2, 3, 4]
"""
@prim_attr_register
def __init__(self):
"""Initialize identity"""

View File

@ -105,7 +105,7 @@ class AllReduce(PrimitiveWithInfer):
return x_shape
def infer_dtype(self, x_dtype):
validator.check_tensor_type_same({'x': x_dtype}, target_dtypes, self.name)
validator.check_tensor_dtype_valid('x', x_dtype, target_dtypes, self.name)
return x_dtype
@ -167,7 +167,7 @@ class AllGather(PrimitiveWithInfer):
return x_shape
def infer_dtype(self, x_dtype):
validator.check_tensor_type_same({'x': x_dtype}, target_dtypes, self.name)
validator.check_tensor_dtype_valid('x', x_dtype, target_dtypes, self.name)
return x_dtype
def __call__(self, tensor):
@ -217,7 +217,7 @@ class _HostAllGather(PrimitiveWithInfer):
return x_shape
def infer_dtype(self, x_dtype):
validator.check_tensor_type_same({'x': x_dtype}, target_dtypes, self.name)
validator.check_tensor_dtype_valid('x', x_dtype, target_dtypes, self.name)
return x_dtype
def __call__(self, tensor):
@ -279,7 +279,7 @@ class ReduceScatter(PrimitiveWithInfer):
return x_shape
def infer_dtype(self, x_dtype):
validator.check_tensor_type_same({'x': x_dtype}, target_dtypes, self.name)
validator.check_tensor_dtype_valid('x', x_dtype, target_dtypes, self.name)
return x_dtype
def __call__(self, tensor):
@ -328,7 +328,7 @@ class _HostReduceScatter(PrimitiveWithInfer):
return x_shape
def infer_dtype(self, x_dtype):
validator.check_tensor_type_same({'x': x_dtype}, target_dtypes, self.name)
validator.check_tensor_dtype_valid('x', x_dtype, target_dtypes, self.name)
return x_dtype
def __call__(self, tensor):
@ -390,7 +390,7 @@ class Broadcast(PrimitiveWithInfer):
if not isinstance(x_dtype, tuple):
raise TypeError(f"{self.name}'s input should be a tuple!")
for _ele in x_dtype:
validator.check_tensor_type_same({'x': _ele}, target_dtypes, self.name)
validator.check_tensor_dtype_valid('x', _ele, target_dtypes, self.name)
return x_dtype
@ -432,7 +432,7 @@ class _AlltoAll(PrimitiveWithInfer):
return x_shape
def infer_dtype(self, x_dtype):
validator.check_tensor_type_same({'x': x_dtype}, target_dtypes, self.name)
validator.check_tensor_dtype_valid('x', x_dtype, target_dtypes, self.name)
return x_dtype
def __call__(self, tensor):

View File

@ -132,8 +132,7 @@ class GeSwitch(PrimitiveWithInfer):
def infer_dtype(self, data_type, pred_type):
validator.check_subclass(
"data", data_type, (mstype.tensor,) + mstype.number_type, self.name)
validator.check_tensor_type_same(
{"pred": pred_type}, [mstype.bool_], self.name)
validator.check_tensor_dtype_valid("pred", pred_type, [mstype.bool_], self.name)
return (data_type, data_type)
@ -171,5 +170,5 @@ class Merge(PrimitiveWithInfer):
for i, item in enumerate(inputs):
args['inputs[%d]' % i] = item
validator.check_scalar_or_tensor_type_same(args, (mstype.bool_,) + mstype.number_type, self.name)
validator.check_scalar_or_tensor_types_same(args, (mstype.bool_,) + mstype.number_type, self.name)
return (inputs[0], mstype.int32)

View File

@ -380,7 +380,7 @@ class Assert(PrimitiveWithInfer):
return [1]
def infer_dtype(self, condition, inputs):
validator.check_scalar_or_tensor_type_same({"condition": condition}, [mstype.bool_], self.name)
validator.check_scalar_or_tensor_types_same({"condition": condition}, [mstype.bool_], self.name)
for dtype in inputs:
validator.check_subclass("input", dtype, [mstype.tensor], self.name)
return mstype.int32

View File

@ -104,11 +104,11 @@ class CropAndResize(PrimitiveWithInfer):
box_index_dtype = box_index['dtype']
crop_size_dtype = crop_size['dtype']
# check dytpe
validator.check_tensor_type_same({"x": x_dtype},
[mstype.int8, mstype.int16, mstype.int32, mstype.int64, mstype.float16,
mstype.float32, mstype.float64, mstype.uint8, mstype.uint16], self.name)
validator.check_tensor_type_same({"boxes": boxes_dtype}, [mstype.float32], self.name)
validator.check_tensor_type_same({"box_index": box_index_dtype}, [mstype.int32], self.name)
validator.check_tensor_dtype_valid("x", x_dtype,
[mstype.int8, mstype.int16, mstype.int32, mstype.int64, mstype.float16,
mstype.float32, mstype.float64, mstype.uint8, mstype.uint16], self.name)
validator.check_tensor_dtype_valid("boxes", boxes_dtype, [mstype.float32], self.name)
validator.check_tensor_dtype_valid("box_index", box_index_dtype, [mstype.int32], self.name)
validator.check_value_type("crop_size", crop_size_value, [tuple], self.name)
# check input shape rank
validator.check("x rank", len(x_shape), "expected", 4, Rel.EQ, self.name)

View File

@ -16,6 +16,8 @@
"""Operators for math."""
import copy
from functools import partial
import numpy as np
from ... import context
from .. import signature as sig
@ -85,7 +87,7 @@ class _MathBinaryOp(_BinaryOp):
@staticmethod
def do_infer_dtype(x_dtype, y_dtype, valid_dtype=mstype.number_type, prim_name=None):
args_type = {"x": x_dtype, "y": y_dtype}
validator.check_tensor_type_same(args_type, valid_dtype, prim_name)
validator.check_tensors_dtypes_same_and_valid(args_type, valid_dtype, prim_name)
return x_dtype
def infer_dtype(self, x_dtype, y_dtype):
@ -105,8 +107,8 @@ class _BitwiseBinaryOp(_MathBinaryOp):
@staticmethod
def _check_bitwise_op_input_type(x1_type, x2_type, prim):
args = {'x1': x1_type, 'x2': x2_type}
valid_types = mstype.int_type + mstype.uint_type
validator.check_tensor_type_same(args, valid_types, prim)
valid_dtypes = mstype.int_type + mstype.uint_type
validator.check_tensors_dtypes_same_and_valid(args, valid_dtypes, prim)
return x1_type
def infer_dtype(self, x1_type, x2_type):
@ -198,7 +200,7 @@ class AssignAdd(PrimitiveWithInfer):
def infer_dtype(self, variable, value):
args = {"variable": variable, "value": value}
validator.check_scalar_or_tensor_type_same(args, mstype.number_type, self.name)
validator.check_scalar_or_tensor_types_same(args, mstype.number_type, self.name)
return value
@ -248,7 +250,7 @@ class AssignSub(PrimitiveWithInfer):
def infer_dtype(self, variable, value):
args = {"variable": variable, "value": value}
validator.check_scalar_or_tensor_type_same(args, mstype.number_type, self.name)
validator.check_scalar_or_tensor_types_same(args, mstype.number_type, self.name)
return value
@ -283,7 +285,7 @@ class _Reduce(PrimitiveWithInfer):
axis_v = axis['value']
input_shp = input_x['shape']
args = {'input_x': input_x['dtype']}
validator.check_tensor_type_same(args, valid_dtype, self.name)
validator.check_tensors_dtypes_same_and_valid(args, valid_dtype, self.name)
if axis_v is None:
raise ValueError(f"For {self.name}, axis must be const.")
@ -504,6 +506,7 @@ class ReduceMax(_Reduce):
def __infer__(self, input_x, axis):
return self.do_infer(input_x, axis, mstype.number_type + (mstype.bool_,))
class ReduceMin(_Reduce):
"""
Reduce a dimension of a tensor by the minimum value in the dimension.
@ -612,7 +615,7 @@ class CumProd(PrimitiveWithInfer):
def infer_dtype(self, x_type, axis_type):
cls_name = self.name
validator.check_tensor_type_same({'x': x_type}, mstype.number_type, cls_name)
validator.check_tensor_dtype_valid('x', x_type, mstype.number_type, cls_name)
validator.check_subclass("axis", axis_type, mstype.int_, cls_name)
return x_type
@ -689,7 +692,7 @@ class MatMul(PrimitiveWithInfer):
def infer_dtype(self, x1, x2):
args = {"x1": x1, "x2": x2}
validator.check_tensor_type_same(args, mstype.float_type + mstype.int_type, self.name)
validator.check_tensors_dtypes_same_and_valid(args, mstype.float_type + mstype.int_type, self.name)
if x1.element_type() == mstype.int8:
return mstype.tensor_type(mstype.int32)
return x1
@ -801,10 +804,10 @@ class TensorDot(PrimitiveWithInfer):
self.axes = axes
validator.check_value_type('axes', axes, [int, tuple, list], self.name)
if not isinstance(self.axes, int):
self.axes = list(self.axes) # to avoid immutability issues
self.axes = list(self.axes) # to avoid immutability issues
if len(self.axes) != 2:
raise ValueError("Require two axes inputs, given less")
self.int_to_tuple_conv() # convert before length checks
self.int_to_tuple_conv() # convert before length checks
if len(self.axes[0]) != len(self.axes[1]):
raise ValueError("Axes have to be the same size/length")
if len(self.axes[0]) != len(set(self.axes[0])) or len(self.axes[1]) != len(set(self.axes[1])):
@ -825,7 +828,7 @@ class TensorDot(PrimitiveWithInfer):
if isinstance(self.axes, int):
if self.axes <= 0:
# outer product, no input validation required
self.axes = ([], []) # no axes selected for either
self.axes = ([], []) # no axes selected for either
return
if self.axes > len(x1_shape) or self.axes > len(x2_shape):
raise ValueError(
@ -877,8 +880,8 @@ class TensorDot(PrimitiveWithInfer):
def infer_dtype(self, x1, x2):
args = {"x1": x1, "x2": x2}
valid_types = [mstype.float16, mstype.float32]
validator.check_tensor_type_same(args, valid_types, self.name)
valid_dtypes = [mstype.float16, mstype.float32]
validator.check_tensors_dtypes_same_and_valid(args, valid_dtypes, self.name)
return x1
@ -922,8 +925,8 @@ class CumSum(PrimitiveWithInfer):
if axis['value'] is None:
raise ValueError(f"For {self.name}, axis must be const.")
validator.check_value_type('axis', axis['value'], [int], cls_name)
valid_types = [mstype.uint8, mstype.int8, mstype.int32, mstype.float16, mstype.float32]
validator.check_tensor_type_same({'x': x['dtype']}, valid_types, cls_name)
valid_dtypes = [mstype.uint8, mstype.int8, mstype.int32, mstype.float16, mstype.float32]
validator.check_tensor_dtype_valid('x', x['dtype'], valid_dtypes, cls_name)
return {'shape': x_shp,
'dtype': x['dtype'],
'value': None}
@ -989,7 +992,7 @@ class AddN(PrimitiveWithInfer):
if dtype == mstype.undetermined:
contains_undetermined = True
if not contains_undetermined:
validator.check_tensor_type_same(args, mstype.number_type + (mstype.bool_,), cls_name)
validator.check_tensors_dtypes_same_and_valid(args, mstype.number_type + (mstype.bool_,), cls_name)
return inputs[0]
def infer_value(self, inputs):
@ -1068,7 +1071,7 @@ class AccumulateNV2(PrimitiveWithInfer):
args = {}
for i, dtype in enumerate(inputs):
args[f"inputs[{i}]"] = dtype
validator.check_tensor_type_same(args, mstype.number_type + (mstype.bool_,), cls_name)
validator.check_tensors_dtypes_same_and_valid(args, mstype.number_type + (mstype.bool_,), cls_name)
return inputs[0]
@ -1094,12 +1097,12 @@ class Neg(PrimitiveWithInfer):
"""Initialize Neg"""
self.init_prim_io_names(inputs=['x'], outputs=['y'])
def infer_shape(self, input_x):
return input_x
def infer_shape(self, x_shape):
return x_shape
def infer_dtype(self, input_x):
validator.check_tensor_type_same({"input_x": input_x}, mstype.number_type, self.name)
return input_x
def infer_dtype(self, x_dtype):
validator.check_tensor_dtype_valid("x", x_dtype, mstype.number_type, self.name)
return x_dtype
def infer_value(self, input_x):
if input_x is not None:
@ -1151,7 +1154,7 @@ class InplaceAdd(PrimitiveWithInfer):
def infer_dtype(self, x_dtype, v_dtype):
args = {'x': x_dtype, 'v': v_dtype}
valid_type = [mstype.int32, mstype.float16, mstype.float32]
validator.check_tensor_type_same(args, valid_type, self.name)
validator.check_tensors_dtypes_same_and_valid(args, valid_type, self.name)
return x_dtype
def infer_shape(self, x_shape, v_shape):
@ -1209,7 +1212,7 @@ class InplaceSub(PrimitiveWithInfer):
def infer_dtype(self, x_dtype, v_dtype):
args = {'x': x_dtype, 'v': v_dtype}
valid_type = [mstype.int32, mstype.float16, mstype.float32]
validator.check_tensor_type_same(args, valid_type, self.name)
validator.check_tensors_dtypes_same_and_valid(args, valid_type, self.name)
return x_dtype
def infer_shape(self, x_shape, v_shape):
@ -1363,9 +1366,9 @@ class Square(PrimitiveWithInfer):
def infer_shape(self, x_shape):
return x_shape
def infer_dtype(self, x_type):
validator.check_tensor_type_same({"x": x_type}, mstype.number_type, self.name)
return x_type
def infer_dtype(self, x_dtype):
validator.check_tensor_dtype_valid("x", x_dtype, mstype.number_type, self.name)
return x_dtype
def infer_value(self, x):
if x is not None:
@ -1401,9 +1404,9 @@ class Rsqrt(PrimitiveWithInfer):
def infer_shape(self, x_shape):
return x_shape
def infer_dtype(self, x_type):
validator.check_tensor_type_same({"x": x_type}, mstype.number_type, self.name)
return x_type
def infer_dtype(self, x_dtype):
validator.check_tensor_dtype_valid("x", x_dtype, mstype.number_type, self.name)
return x_dtype
def infer_value(self, x):
if x is not None:
@ -1437,7 +1440,7 @@ class Sqrt(PrimitiveWithCheck):
self.init_prim_io_names(inputs=['x'], outputs=['output'])
def check_dtype(self, x_type):
validator.check_tensor_type_same({"x": x_type}, mstype.number_type, self.name)
validator.check_tensor_dtype_valid("x", x_type, mstype.number_type, self.name)
def infer_value(self, x):
if x is not None:
@ -1599,8 +1602,7 @@ class Expm1(PrimitiveWithInfer):
return x_shape
def infer_dtype(self, x_type):
validator.check_subclass("x", x_type, mstype.tensor, self.name)
validator.check_tensor_type_same({"x": x_type}, [mstype.float16, mstype.float32], self.name)
validator.check_tensor_dtype_valid("x", x_type, [mstype.float16, mstype.float32], self.name)
return x_type
@ -1641,10 +1643,9 @@ class HistogramFixedWidth(PrimitiveWithInfer):
return (self.nbins,)
def infer_dtype(self, x_dtype, range_dtype):
validator.check_subclass("x", x_dtype, mstype.tensor, self.name)
valid_types = (mstype.float16, mstype.float32, mstype.int32)
validator.check_tensor_type_same({"x": x_dtype}, valid_types, self.name)
validator.check_tensor_type_same({"range": range_dtype}, valid_types, self.name)
valid_dtypes = (mstype.float16, mstype.float32, mstype.int32)
validator.check_tensor_dtype_valid("x", x_dtype, valid_dtypes, self.name)
validator.check_tensor_dtype_valid("range", range_dtype, valid_dtypes, self.name)
y_dtype = mstype.int32
return y_dtype
@ -1707,13 +1708,13 @@ class Log1p(PrimitiveWithInfer):
def __init__(self):
self.init_prim_io_names(inputs=['x'], outputs=['y'])
def infer_shape(self, x):
return x
def infer_shape(self, x_shape):
return x_shape
def infer_dtype(self, x):
validator.check_subclass("x", x, mstype.tensor, self.name)
validator.check_tensor_type_same({"x": x}, [mstype.float16, mstype.float32], self.name)
return x
def infer_dtype(self, x_dtype):
validator.check_subclass("x", x_dtype, mstype.tensor, self.name)
validator.check_tensor_dtype_valid("x", x_dtype, [mstype.float16, mstype.float32], self.name)
return x_dtype
class Erf(PrimitiveWithInfer):
@ -1741,9 +1742,9 @@ class Erf(PrimitiveWithInfer):
def infer_shape(self, x_shape):
return x_shape
def infer_dtype(self, x_type):
validator.check_tensor_type_same({"x": x_type}, [mstype.float16, mstype.float32], self.name)
return x_type
def infer_dtype(self, x_dtype):
validator.check_tensor_dtype_valid("x", x_dtype, [mstype.float16, mstype.float32], self.name)
return x_dtype
class Erfc(PrimitiveWithInfer):
@ -1772,7 +1773,7 @@ class Erfc(PrimitiveWithInfer):
return x_shape
def infer_dtype(self, x_type):
validator.check_tensor_type_same({"x": x_type}, [mstype.float16, mstype.float32], self.name)
validator.check_tensor_dtype_valid("x", x_type, [mstype.float16, mstype.float32], self.name)
return x_type
@ -2126,7 +2127,7 @@ class Floor(PrimitiveWithInfer):
return x_shape
def infer_dtype(self, x_dtype):
validator.check_tensor_type_same({"x": x_dtype}, mstype.float_type, self.name)
validator.check_tensor_dtype_valid("x", x_dtype, mstype.float_type, self.name)
return x_dtype
@ -2185,7 +2186,7 @@ class Ceil(PrimitiveWithInfer):
return x_shape
def infer_dtype(self, x_dtype):
validator.check_tensor_type_same({"x": x_dtype}, [mstype.float16, mstype.float32], self.name)
validator.check_tensor_dtype_valid("x", x_dtype, [mstype.float16, mstype.float32], self.name)
return x_dtype
@ -2281,7 +2282,7 @@ class Acosh(PrimitiveWithInfer):
return x_shape
def infer_dtype(self, x_dtype):
validator.check_tensor_type_same({'x': x_dtype}, mstype.number_type, self.name)
validator.check_tensor_dtype_valid('x', x_dtype, mstype.number_type, self.name)
return x_dtype
@ -2310,7 +2311,7 @@ class Cosh(PrimitiveWithInfer):
return x_shape
def infer_dtype(self, x_dtype):
validator.check_tensor_type_same({'x': x_dtype}, mstype.number_type, self.name)
validator.check_tensor_dtype_valid('x', x_dtype, mstype.number_type, self.name)
return x_dtype
@ -2339,7 +2340,7 @@ class Asinh(PrimitiveWithInfer):
return x_shape
def infer_dtype(self, x_dtype):
validator.check_tensor_type_same({'x': x_dtype}, mstype.number_type, self.name)
validator.check_tensor_dtype_valid('x', x_dtype, mstype.number_type, self.name)
return x_dtype
@ -2368,7 +2369,7 @@ class Sinh(PrimitiveWithInfer):
return x_shape
def infer_dtype(self, x_dtype):
validator.check_tensor_type_same({'x': x_dtype}, mstype.number_type, self.name)
validator.check_tensor_dtype_valid('x', x_dtype, mstype.number_type, self.name)
return x_dtype
@ -2380,7 +2381,7 @@ class _LogicBinaryOp(_BinaryOp):
@staticmethod
def do_infer_dtype(x_dtype, y_dtype, valid_type=mstype.number_type, prim_name=None):
args_dtype = {"x": x_dtype, "y": y_dtype}
validator.check_tensor_type_same(args_dtype, valid_type, prim_name)
validator.check_tensors_dtypes_same_and_valid(args_dtype, valid_type, prim_name)
return mstype.tensor_type(mstype.bool_)
def infer_dtype(self, x_dtype, y_dtype):
@ -2461,7 +2462,7 @@ class ApproximateEqual(_LogicBinaryOp):
def infer_dtype(self, x_dtype, y_dtype):
args_dtype = {"x": x_dtype, "y": y_dtype}
valid_type = [mstype.float32, mstype.float16]
validator.check_tensor_type_same(args_dtype, valid_type, prim_name=self.name)
validator.check_tensors_dtypes_same_and_valid(args_dtype, valid_type, prim_name=self.name)
return mstype.tensor_type(mstype.bool_)
@ -2498,7 +2499,7 @@ class EqualCount(PrimitiveWithInfer):
def infer_dtype(self, x_dtype, y_dtype):
args = {'x': x_dtype, 'y': y_dtype}
validator.check_tensor_type_same(args, mstype.number_type + (mstype.bool_,), self.name)
validator.check_tensors_dtypes_same_and_valid(args, mstype.number_type + (mstype.bool_,), self.name)
return x_dtype
@ -2711,7 +2712,7 @@ class LogicalNot(PrimitiveWithInfer):
return x_shape
def infer_dtype(self, x_dtype):
validator.check_tensor_type_same({"x": x_dtype}, [mstype.bool_], self.name)
validator.check_tensor_dtype_valid("x", x_dtype, [mstype.bool_], self.name)
return mstype.tensor_type(mstype.bool_)
@ -2859,8 +2860,7 @@ class IsFinite(PrimitiveWithInfer):
return x_shape
def infer_dtype(self, x_dtype):
validator.check_subclass("x", x_dtype, mstype.tensor, self.name)
validator.check_tensor_type_same({'x': x_dtype}, mstype.number_type + (mstype.bool_,), self.name)
validator.check_tensor_dtype_valid('x', x_dtype, mstype.number_type + (mstype.bool_,), self.name)
return mstype.bool_
@ -2890,7 +2890,7 @@ class FloatStatus(PrimitiveWithInfer):
return [1]
def infer_dtype(self, x_dtype):
validator.check_tensor_type_same({'x': x_dtype}, [mstype.float32, mstype.float16], self.name)
validator.check_tensor_dtype_valid('x', x_dtype, [mstype.float32, mstype.float16], self.name)
return x_dtype
@ -2959,7 +2959,7 @@ class NPUGetFloatStatus(PrimitiveWithInfer):
return [8]
def infer_dtype(self, x_dtype):
validator.check_tensor_type_same({'x': x_dtype}, [mstype.float16, mstype.float32], self.name)
validator.check_tensor_dtype_valid('x', x_dtype, [mstype.float16, mstype.float32], self.name)
return mstype.float32
@ -3002,7 +3002,7 @@ class NPUClearFloatStatus(PrimitiveWithInfer):
return [8]
def infer_dtype(self, x_dtype):
validator.check_tensor_type_same({'x': x_dtype}, [mstype.float16, mstype.float32], self.name)
validator.check_tensor_dtype_valid('x', x_dtype, [mstype.float16, mstype.float32], self.name)
return mstype.float32
@ -3030,7 +3030,7 @@ class Cos(PrimitiveWithInfer):
return x_shape
def infer_dtype(self, x_dtype):
validator.check_tensor_type_same({'x': x_dtype}, mstype.number_type, self.name)
validator.check_tensor_dtype_valid('x', x_dtype, mstype.number_type, self.name)
return x_dtype
@ -3058,7 +3058,7 @@ class ACos(PrimitiveWithInfer):
return x_shape
def infer_dtype(self, x_dtype):
validator.check_tensor_type_same({'x': x_dtype}, mstype.number_type, self.name)
validator.check_tensor_dtype_valid('x', x_dtype, mstype.number_type, self.name)
return x_dtype
@ -3087,7 +3087,7 @@ class Sin(PrimitiveWithInfer):
return x_shape
def infer_dtype(self, x_dtype):
validator.check_tensor_type_same({'x': x_dtype}, mstype.number_type, self.name)
validator.check_tensor_dtype_valid('x', x_dtype, mstype.number_type, self.name)
return x_dtype
@ -3116,7 +3116,7 @@ class Asin(PrimitiveWithInfer):
return x_shape
def infer_dtype(self, x_dtype):
validator.check_tensor_type_same({'x': x_dtype}, mstype.number_type, self.name)
validator.check_tensor_dtype_valid('x', x_dtype, mstype.number_type, self.name)
return x_dtype
@ -3175,7 +3175,7 @@ class NMSWithMask(PrimitiveWithInfer):
return (bboxes_shape, (num,), (num,))
def infer_dtype(self, bboxes_dtype):
validator.check_tensor_type_same({"bboxes": bboxes_dtype}, [mstype.float16, mstype.float32], self.name)
validator.check_tensor_dtype_valid("bboxes", bboxes_dtype, [mstype.float16, mstype.float32], self.name)
return (bboxes_dtype, mstype.int32, mstype.bool_)
@ -3205,7 +3205,7 @@ class Abs(PrimitiveWithInfer):
return x_shape
def infer_dtype(self, x_type):
validator.check_tensor_type_same({'x': x_type}, mstype.number_type, self.name)
validator.check_tensor_dtype_valid('x', x_type, mstype.number_type, self.name)
return x_type
def infer_value(self, x):
@ -3247,7 +3247,7 @@ class Sign(PrimitiveWithInfer):
return x_shape
def infer_dtype(self, x_dtype):
validator.check_tensor_type_same({'x': x_dtype}, mstype.number_type, self.name)
validator.check_tensor_dtype_valid('x', x_dtype, mstype.number_type, self.name)
return x_dtype
@ -3276,9 +3276,9 @@ class Round(PrimitiveWithInfer):
def infer_shape(self, x_shape):
return x_shape
def infer_dtype(self, x_type):
validator.check_tensor_type_same({'x': x_type}, mstype.number_type, self.name)
return x_type
def infer_dtype(self, x_dtype):
validator.check_tensor_dtype_valid('x', x_dtype, mstype.number_type, self.name)
return x_dtype
class Tan(PrimitiveWithInfer):
@ -3306,8 +3306,8 @@ class Tan(PrimitiveWithInfer):
return x_shape
def infer_dtype(self, x_type):
valid_types = [mstype.float16, mstype.float32, mstype.int32]
validator.check_tensor_type_same({'x': x_type}, valid_types, self.name)
valid_dtypes = [mstype.float16, mstype.float32, mstype.int32]
validator.check_tensor_dtype_valid('x', x_type, valid_dtypes, self.name)
return x_type
@ -3338,7 +3338,7 @@ class Atan(PrimitiveWithInfer):
return x_shape
def infer_dtype(self, x_type):
validator.check_tensor_type_same({'x': x_type}, mstype.number_type, self.name)
validator.check_tensor_dtype_valid('x', x_type, mstype.number_type, self.name)
return x_type
@ -3367,7 +3367,7 @@ class Atanh(PrimitiveWithInfer):
return x_shape
def infer_dtype(self, x_type):
validator.check_tensor_type_same({'x': x_type}, mstype.number_type, self.name)
validator.check_tensor_dtype_valid('x', x_type, mstype.number_type, self.name)
return x_type
@ -3431,8 +3431,9 @@ class SquareSumAll(PrimitiveWithInfer):
return [], []
def infer_dtype(self, x_type, y_type):
validator.check_tensor_type_same({'x1_type': x_type}, [mstype.float16, mstype.float32], self.name)
validator.check_tensor_type_same({'x2_type': y_type}, [mstype.float16, mstype.float32], self.name)
valid_types = (mstype.float16, mstype.float32)
validator.check_tensor_dtype_valid('x1_type', x_type, valid_types, self.name)
validator.check_tensor_dtype_valid('x2_type', y_type, valid_types, self.name)
return x_type, y_type
@ -3539,7 +3540,7 @@ class BesselI0e(PrimitiveWithInfer):
return x
def infer_dtype(self, x):
validator.check_tensor_type_same({'x': x}, mstype.number_type, self.name)
validator.check_tensor_dtype_valid('x', x, mstype.number_type, self.name)
return x
@ -3568,7 +3569,7 @@ class BesselI1e(PrimitiveWithInfer):
return x
def infer_dtype(self, x):
validator.check_tensor_type_same({'x': x}, mstype.number_type, self.name)
validator.check_tensor_dtype_valid('x', x, mstype.number_type, self.name)
return x
@ -3598,7 +3599,7 @@ class Inv(PrimitiveWithInfer):
return x_shape
def infer_dtype(self, x_dtype):
validator.check_tensor_type_same({'x_dtype': x_dtype}, [mstype.float16, mstype.float32,
validator.check_tensor_dtype_valid('x_dtype', x_dtype, [mstype.float16, mstype.float32,
mstype.int32], self.name)
return x_dtype
@ -3628,7 +3629,7 @@ class Invert(PrimitiveWithInfer):
return x_shape
def infer_dtype(self, x_dtype):
validator.check_tensor_type_same({'x_dtype': x_dtype}, [mstype.int16, mstype.uint16], self.name)
validator.check_tensor_dtype_valid('x_dtype', x_dtype, [mstype.int16, mstype.uint16], self.name)
return x_dtype
@ -3654,8 +3655,8 @@ class Eps(PrimitiveWithInfer):
self.init_prim_io_names(inputs=['input_x'], outputs=['y'])
def __infer__(self, input_x):
valid_types = [mstype.float16, mstype.float32]
validator.check_tensor_type_same({'input_x': input_x['dtype']}, valid_types, self.name)
valid_dtypes = [mstype.float16, mstype.float32]
validator.check_tensor_dtype_valid('input_x', input_x['dtype'], valid_dtypes, self.name)
x_nptype = mstype.dtype_to_nptype(input_x['dtype'].element_type())
if x_nptype == np.float16:
@ -3725,9 +3726,9 @@ class IFMR(PrimitiveWithInfer):
return (1,), (1,)
def infer_dtype(self, data_dtype, data_min_dtype, data_max_dtype, cumsum_dtype):
valid_types = [mstype.float32, mstype.float16]
validator.check_tensor_type_same({"input_value": data_dtype}, valid_types, self.name)
validator.check_tensor_type_same({"input_min": data_min_dtype}, valid_types, self.name)
validator.check_tensor_type_same({"input_max": data_max_dtype}, valid_types, self.name)
validator.check_tensor_type_same({"input_bins": cumsum_dtype}, [mstype.int32], self.name)
tuple(map(partial(validator.check_tensor_dtype_valid,
valid_dtypes=(mstype.float16, mstype.float32), prim_name=self.name),
("input_value", "input_min", "input_max"),
(data_dtype, data_min_dtype, data_max_dtype)))
validator.check_tensor_dtype_valid("input_bins", cumsum_dtype, [mstype.int32], self.name)
return mstype.tensor_type(mstype.float32), mstype.tensor_type(mstype.float32)

View File

@ -17,7 +17,7 @@
import math
import operator
from functools import reduce
from functools import reduce, partial
import numpy as np
from ... import context
from .. import signature as sig
@ -153,8 +153,7 @@ class Softmax(PrimitiveWithInfer):
return logits
def infer_dtype(self, logits):
validator.check_subclass("logits", logits, mstype.tensor, self.name)
validator.check_tensor_type_same({"logits": logits}, mstype.float_type, self.name)
validator.check_tensor_dtype_valid("logits", logits, mstype.float_type, self.name)
return logits
@ -197,8 +196,7 @@ class LogSoftmax(PrimitiveWithInfer):
return logits
def infer_dtype(self, logits):
validator.check_subclass("logits", logits, mstype.tensor, self.name)
validator.check_tensor_type_same({"logits": logits}, mstype.float_type, self.name)
validator.check_tensor_dtype_valid("logits", logits, mstype.float_type, self.name)
return logits
@ -230,12 +228,12 @@ class Softplus(PrimitiveWithInfer):
"""Initialize Softplus"""
self.init_prim_io_names(inputs=['x'], outputs=['output'])
def infer_shape(self, input_x):
return input_x
def infer_shape(self, x_shape):
return x_shape
def infer_dtype(self, input_x):
validator.check_tensor_type_same({'input_x': input_x}, mstype.float_type, self.name)
return input_x
def infer_dtype(self, x_dtype):
validator.check_tensor_dtype_valid('x', x_dtype, mstype.float_type, self.name)
return x_dtype
class Softsign(PrimitiveWithInfer):
@ -269,7 +267,7 @@ class Softsign(PrimitiveWithInfer):
return input_x
def infer_dtype(self, input_x):
validator.check_tensor_type_same({'input_x': input_x}, [mstype.float16, mstype.float32], self.name)
validator.check_tensor_dtype_valid('input_x', input_x, [mstype.float16, mstype.float32], self.name)
return input_x
@ -301,7 +299,7 @@ class ReLU(PrimitiveWithInfer):
return input_x
def infer_dtype(self, input_x):
validator.check_tensor_type_same({'input_x': input_x}, mstype.number_type, self.name)
validator.check_tensor_dtype_valid('input_x', input_x, mstype.number_type, self.name)
return input_x
@ -332,7 +330,7 @@ class ReLU6(PrimitiveWithInfer):
return input_x
def infer_dtype(self, input_x):
validator.check_tensor_type_same({'input_x': input_x}, (mstype.float16, mstype.float32), self.name)
validator.check_tensor_dtype_valid('input_x', input_x, (mstype.float16, mstype.float32), self.name)
return input_x
@ -384,7 +382,7 @@ class ReLUV2(PrimitiveWithInfer):
output_shape = (input_x['shape'], mask_shape)
validator.check_subclass("input_x", input_dtype, mstype.tensor, self.name)
validator.check_tensor_type_same({'input_x': input_dtype}, mstype.number_type, self.name)
validator.check_tensor_dtype_valid('input_x', input_dtype, mstype.number_type, self.name)
mask_dtype = mstype.uint8
output_dtype = (input_dtype, mask_dtype)
@ -426,7 +424,7 @@ class Elu(PrimitiveWithInfer):
return input_x
def infer_dtype(self, input_x):
validator.check_tensor_type_same({'input_x': input_x}, mstype.float_type, self.name)
validator.check_tensor_dtype_valid('input_x', input_x, mstype.float_type, self.name)
return input_x
@ -463,7 +461,7 @@ class HSwish(PrimitiveWithInfer):
return xshape
def infer_dtype(self, x_dtype):
validator.check_tensor_type_same({"x": x_dtype}, (mstype.float16, mstype.float32), self.name)
validator.check_tensor_dtype_valid("x", x_dtype, (mstype.float16, mstype.float32), self.name)
return x_dtype
@ -499,7 +497,7 @@ class Sigmoid(PrimitiveWithInfer):
return input_x
def infer_dtype(self, input_x):
validator.check_tensor_type_same({"input_x": input_x}, (mstype.float16, mstype.float32), self.name)
validator.check_tensor_dtype_valid("input_x", input_x, (mstype.float16, mstype.float32), self.name)
return input_x
@ -536,7 +534,7 @@ class HSigmoid(PrimitiveWithInfer):
return x_shape
def infer_dtype(self, x_dtype):
validator.check_tensor_type_same({"x": x_dtype}, (mstype.float16, mstype.float32), self.name)
validator.check_tensor_dtype_valid("x", x_dtype, (mstype.float16, mstype.float32), self.name)
return x_dtype
@ -733,12 +731,12 @@ class FusedBatchNormEx(PrimitiveWithInfer):
return (input_x, scale, scale, scale, scale, scale)
def infer_dtype(self, input_x, scale, bias, mean, variance):
validator.check_tensor_type_same({"input_x": input_x}, [mstype.float16, mstype.float32], self.name)
validator.check_tensor_dtype_valid("input_x", input_x, [mstype.float16, mstype.float32], self.name)
args = {"scale": scale, "bias": bias}
validator.check_tensor_type_same(args, [mstype.float32], self.name)
validator.check_tensors_dtypes_same_and_valid(args, [mstype.float32], self.name)
args_moving = {"mean": mean, "variance": variance}
valid_types = [mstype.tensor_type(mstype.float32)]
validator.check_type_same(args_moving, valid_types, self.name)
valid_dtypes = [mstype.tensor_type(mstype.float32)]
validator.check_types_same_and_valid(args_moving, valid_dtypes, self.name)
return (input_x, scale, scale, scale, scale, scale)
@ -769,7 +767,7 @@ class BNTrainingReduce(PrimitiveWithInfer):
return ([x_shape[1]], [x_shape[1]])
def infer_dtype(self, x_type):
validator.check_tensor_type_same({"x_type": x_type}, [mstype.float16, mstype.float32], self.name)
validator.check_tensor_dtype_valid("x", x_type, [mstype.float16, mstype.float32], self.name)
return (x_type, x_type)
@ -819,6 +817,7 @@ class BNTrainingUpdate(PrimitiveWithInfer):
>>> bn_training_update = P.BNTrainingUpdate()
>>> output = bn_training_update(input_x, sum, square_sum, scale, offset, mean, variance)
"""
@prim_attr_register
def __init__(self, isRef=True, epsilon=1e-5, factor=0.1):
self.init_prim_io_names(inputs=['x', 'sum', 'square_sum', 'scale', 'b', 'mean', 'variance'],
@ -846,13 +845,10 @@ class BNTrainingUpdate(PrimitiveWithInfer):
return (x, variance, variance, variance, variance)
def infer_dtype(self, x, sum, square_sum, scale, b, mean, variance):
validator.check_tensor_type_same({"x_type": x}, [mstype.float16, mstype.float32], self.name)
validator.check_tensor_type_same({"sum_type": sum}, [mstype.float16, mstype.float32], self.name)
validator.check_tensor_type_same({"square_sum_type": square_sum}, [mstype.float16, mstype.float32], self.name)
validator.check_tensor_type_same({"scale_type": scale}, [mstype.float16, mstype.float32], self.name)
validator.check_tensor_type_same({"b_type": b}, [mstype.float16, mstype.float32], self.name)
validator.check_tensor_type_same({"mean_type": mean}, [mstype.float16, mstype.float32], self.name)
validator.check_tensor_type_same({"variance_type": variance}, [mstype.float16, mstype.float32], self.name)
tuple(map(partial(validator.check_tensor_dtype_valid,
valid_dtypes=(mstype.float16, mstype.float32), prim_name=self.name),
("x", "sum", "square_sum", "scale", "b", "mean", "variance"),
(x, sum, square_sum, scale, b, mean, variance)))
return (x, variance, variance, variance, variance)
@ -928,16 +924,16 @@ class BatchNorm(PrimitiveWithInfer):
return (input_x, scale, scale, scale, scale)
def infer_dtype(self, input_x, scale, bias, mean, variance):
validator.check_tensor_type_same({"input_x": input_x}, [mstype.float16, mstype.float32], self.name)
validator.check_tensor_dtype_valid("input_x", input_x, [mstype.float16, mstype.float32], self.name)
args = {"scale": scale, "bias": bias}
validator.check_tensor_type_same(args, [mstype.float16, mstype.float32], self.name)
validator.check_tensors_dtypes_same_and_valid(args, [mstype.float16, mstype.float32], self.name)
args_moving = {"mean": mean, "variance": variance}
if self.is_training:
valid_types = [mstype.tensor_type(mstype.float16), mstype.tensor_type(mstype.float32), None]
validator.check_type_same(args_moving, valid_types, self.name)
valid_dtypes = [mstype.tensor_type(mstype.float16), mstype.tensor_type(mstype.float32), None]
validator.check_types_same_and_valid(args_moving, valid_dtypes, self.name)
else:
args_moving = {"mean": mean, "variance": variance}
validator.check_tensor_type_same(args_moving, [mstype.float16, mstype.float32], self.name)
validator.check_tensors_dtypes_same_and_valid(args_moving, [mstype.float16, mstype.float32], self.name)
return (input_x, scale, bias, input_x, input_x)
@ -1053,7 +1049,7 @@ class Conv2D(PrimitiveWithInfer):
validator.check_equal_int(len(w_shape_norm), 4, "weight rank", self.name)
validator.check_equal_int(len(x_shape_norm), 4, "x rank", self.name)
validator.check(f"x_shape[1] / group", x_shape_norm[1] // self.group, "w_shape[1]", w_shape_norm[1], \
Rel.EQ, self.name)
Rel.EQ, self.name)
validator.check('out_channel', self.out_channel, 'w_shape[0]', w_shape_norm[0], Rel.EQ, self.name)
validator.check('kernel_size', self.kernel_size, 'w_shape[2:4]', tuple(w_shape_norm[2:4]), Rel.EQ, self.name)
@ -1084,24 +1080,24 @@ class Conv2D(PrimitiveWithInfer):
pad_top, pad_bottom, pad_left, pad_right = self.padding
h_out = 1 + (x_shape_norm[2] + pad_top + pad_bottom - kernel_size_h - (kernel_size_h - 1) \
* (dilation_h - 1)) / stride_h
* (dilation_h - 1)) / stride_h
w_out = 1 + (x_shape_norm[3] + pad_left + pad_right - kernel_size_w - (kernel_size_w - 1) \
* (dilation_w - 1)) / stride_w
* (dilation_w - 1)) / stride_w
h_out = math.floor(h_out)
w_out = math.floor(w_out)
self.pad_list = [pad_top, pad_bottom, pad_left, pad_right]
self.add_prim_attr('pad_list', (pad_top, pad_bottom, pad_left, pad_right))
out_channel = self.out_channel
out_shape = [x_shape_norm[0], out_channel, h_out, w_out] if self.format == "NCHW" else\
out_shape = [x_shape_norm[0], out_channel, h_out, w_out] if self.format == "NCHW" else \
[x_shape_norm[0], h_out, w_out, out_channel]
_check_shape('output', out_shape, self.name)
return out_shape
def infer_dtype(self, x_dtype, w_dtype, b_dtype=None):
args = {'x': x_dtype, 'w': w_dtype}
valid_types = [mstype.int8, mstype.int32, mstype.float16, mstype.float32]
validator.check_tensor_type_same(args, valid_types, self.name)
valid_dtypes = [mstype.int8, mstype.int32, mstype.float16, mstype.float32]
validator.check_tensors_dtypes_same_and_valid(args, valid_dtypes, self.name)
if x_dtype.element_type() == mstype.int8:
return mstype.tensor_type(mstype.int32)
return x_dtype
@ -1220,9 +1216,9 @@ class DepthwiseConv2dNative(PrimitiveWithInfer):
pad_top, pad_bottom, pad_left, pad_right = self.padding
h_out = 1 + (x_shape[2] + pad_top + pad_bottom - kernel_size_h - (kernel_size_h - 1) * (dilation_h - 1)) \
/ stride_h
/ stride_h
w_out = 1 + (x_shape[3] + pad_left + pad_right - kernel_size_w - (kernel_size_w - 1) * (dilation_w - 1)) \
/ stride_w
/ stride_w
h_out = math.floor(h_out)
w_out = math.floor(w_out)
@ -1235,7 +1231,7 @@ class DepthwiseConv2dNative(PrimitiveWithInfer):
def infer_dtype(self, x_dtype, w_dtype, b_dtype=None):
args = {'x': x_dtype, 'w': w_dtype}
validator.check_tensor_type_same(args, mstype.number_type, self.name)
validator.check_tensors_dtypes_same_and_valid(args, mstype.number_type, self.name)
if x_dtype.element_type() == mstype.int8:
return mstype.tensor_type(mstype.int32)
return x_dtype
@ -1436,7 +1432,7 @@ class MaxPoolWithArgmax(_Pool):
def infer_dtype(self, x_dtype):
out_dtype = x_dtype
validator.check_tensor_type_same({"x": x_dtype}, (mstype.float16, mstype.float32), self.name)
validator.check_tensor_dtype_valid("x", x_dtype, (mstype.float16, mstype.float32), self.name)
argmax_dtype = mstype.uint16
if self.is_gpu:
argmax_dtype = mstype.int32
@ -1604,12 +1600,12 @@ class Conv2DBackpropInput(PrimitiveWithInfer):
for i, dim_len in enumerate(x_size_v):
validator.check_value_type("x_size[%d]" % i, dim_len, [int], self.name)
args = {'doutput': doutput['dtype'], 'w': w['dtype']}
valid_types = [mstype.int8, mstype.int32, mstype.float16, mstype.float32]
validator.check_tensor_type_same(args, valid_types, self.name)
valid_dtypes = [mstype.int8, mstype.int32, mstype.float16, mstype.float32]
validator.check_tensors_dtypes_same_and_valid(args, valid_dtypes, self.name)
# infer shape
dout_shape = doutput['shape']
dout_shape_norm = dout_shape if self.format == "NCHW" else\
dout_shape_norm = dout_shape if self.format == "NCHW" else \
[dout_shape[0], dout_shape[2], dout_shape[3], dout_shape[1]]
kernel_h = self.kernel_size[0]
kernel_w = self.kernel_size[1]
@ -1682,7 +1678,7 @@ class BiasAdd(PrimitiveWithInfer):
def infer_dtype(self, x_type, b_type):
args = {"input_x": x_type, "bias": b_type}
validator.check_tensor_type_same(args, mstype.number_type, self.name)
validator.check_tensors_dtypes_same_and_valid(args, mstype.number_type, self.name)
return x_type
@ -1721,8 +1717,8 @@ class TopK(PrimitiveWithInfer):
def __infer__(self, input_x, k):
x_dtype = input_x['dtype']
valid_types = (mstype.int32, mstype.float16, mstype.float32)
validator.check_tensor_type_same({'x': x_dtype}, valid_types, self.name)
valid_dtypes = (mstype.int32, mstype.float16, mstype.float32)
validator.check_tensor_dtype_valid('x', x_dtype, valid_dtypes, self.name)
k_v = k['value']
validator.check_value_type('k', k_v, (int,), self.name)
x_shape = list(input_x['shape'])
@ -1774,7 +1770,7 @@ class SoftmaxCrossEntropyWithLogits(PrimitiveWithInfer):
def infer_dtype(self, logits_type, labels_type):
args = {"logits": logits_type, "labels": labels_type}
validator.check_tensor_type_same(args, (mstype.float16, mstype.float32), self.name)
validator.check_tensors_dtypes_same_and_valid(args, (mstype.float16, mstype.float32), self.name)
return (logits_type, logits_type)
@ -1825,8 +1821,9 @@ class SparseSoftmaxCrossEntropyWithLogits(PrimitiveWithInfer):
return loss_shape
def infer_dtype(self, logits_type, labels_type):
validator.check_tensor_type_same({"logits": logits_type}, (mstype.float16, mstype.float32), self.name)
validator.check_tensor_type_same({"labels": labels_type}, (mstype.int32, mstype.int64), self.name)
validator.check_tensor_dtype_valid("logits", logits_type, (mstype.float16, mstype.float32),
self.name)
validator.check_tensor_dtype_valid("labels", labels_type, (mstype.int32, mstype.int64), self.name)
return logits_type
@ -1886,13 +1883,13 @@ class ApplyMomentum(PrimitiveWithInfer):
return v_shape
def infer_dtype(self, v_dtype, a_dtype, l_dtype, g_dtype, m_dtype):
valid_types = [mstype.float16, mstype.float32, mstype.float64]
valid_dtypes = [mstype.float16, mstype.float32, mstype.float64]
if v_dtype != mstype.type_refkey and a_dtype != mstype.type_refkey:
validator.check_tensor_type_same({"v": v_dtype}, valid_types, self.name)
validator.check_tensor_type_same({"a": a_dtype}, valid_types, self.name)
validator.check_scalar_or_tensor_type_same({"l_dtype": l_dtype}, valid_types, self.name)
validator.check_scalar_or_tensor_type_same({"g_dtype": g_dtype}, valid_types, self.name)
validator.check_scalar_or_tensor_type_same({"m_dtype": m_dtype}, valid_types, self.name)
validator.check_tensor_dtype_valid("v", v_dtype, valid_dtypes, self.name)
validator.check_tensor_dtype_valid("a", a_dtype, valid_dtypes, self.name)
validator.check_scalar_or_tensor_types_same({"l_dtype": l_dtype}, valid_dtypes, self.name)
validator.check_scalar_or_tensor_types_same({"g_dtype": g_dtype}, valid_dtypes, self.name)
validator.check_scalar_or_tensor_types_same({"m_dtype": m_dtype}, valid_dtypes, self.name)
if not self.is_ge and self.is_tbe:
return g_dtype, g_dtype
return g_dtype
@ -1944,7 +1941,7 @@ class SmoothL1Loss(PrimitiveWithInfer):
def infer_dtype(self, prediction, target):
args = {"prediction": prediction, "target": target}
validator.check_tensor_type_same(args, (mstype.float16, mstype.float32), self.name)
validator.check_tensors_dtypes_same_and_valid(args, (mstype.float16, mstype.float32), self.name)
return prediction
@ -1981,9 +1978,8 @@ class L2Loss(PrimitiveWithInfer):
return loss_shape
def infer_dtype(self, x_type):
validator.check_subclass("x_type", x_type, mstype.tensor, self.name)
valid_types = [mstype.float16, mstype.float32]
validator.check_tensor_type_same({'x_type': x_type}, valid_types, self.name)
valid_dtypes = [mstype.float16, mstype.float32]
validator.check_tensor_dtype_valid('x_type', x_type, valid_dtypes, self.name)
return x_type
@ -2019,11 +2015,10 @@ class DataFormatDimMap(PrimitiveWithInfer):
def infer_shape(self, x_shape):
return x_shape
def infer_dtype(self, x_type):
validator.check_subclass("x", x_type, mstype.tensor, self.name)
valid_types = [mstype.int32]
validator.check_tensor_type_same({"x": x_type}, valid_types, self.name)
return x_type
def infer_dtype(self, x_dtype):
valid_dtypes = [mstype.int32]
validator.check_tensor_dtype_valid("x", x_dtype, valid_dtypes, self.name)
return x_dtype
class RNNTLoss(PrimitiveWithInfer):
@ -2065,21 +2060,18 @@ class RNNTLoss(PrimitiveWithInfer):
validator.check_equal_int(len(input_length_shape), 1, 'input_length_rank', self.name)
validator.check_equal_int(len(label_length_shape), 1, 'label_length_rank', self.name)
validator.check('labels shape[0]', labels_shape[0], 'acts shape[0]', acts_shape[0], Rel.EQ, self.name)
validator.check('labels shape[1]', labels_shape[1], 'acts shape[2]-1', acts_shape[2]-1, Rel.EQ, self.name)
validator.check('labels shape[1]', labels_shape[1], 'acts shape[2]-1', acts_shape[2] - 1, Rel.EQ, self.name)
validator.check('input_length size', input_length_shape[0], 'acts shape[0]', acts_shape[0], Rel.EQ, self.name)
validator.check('label_length size', label_length_shape[0], 'acts shape[0]', acts_shape[0], Rel.EQ, self.name)
costs_shape = (acts_shape[0],)
return (costs_shape, acts_shape)
def infer_dtype(self, acts_type, labels_type, input_length_type, label_length_type):
validator.check_subclass("acts_type", acts_type, mstype.tensor, self.name)
validator.check_subclass("labels_type", labels_type, mstype.tensor, self.name)
validator.check_subclass("input_length_type", input_length_type, mstype.tensor, self.name)
validator.check_subclass("label_length_type", label_length_type, mstype.tensor, self.name)
validator.check_tensor_type_same({"acts_type": acts_type}, [mstype.float32, mstype.float16], self.name)
validator.check_tensor_type_same({"labels_type": labels_type}, [mstype.int32], self.name)
validator.check_tensor_type_same({"input_length_type": input_length_type}, [mstype.int32], self.name)
validator.check_tensor_type_same({"label_length_type": label_length_type}, [mstype.int32], self.name)
validator.check_tensor_dtype_valid("acts_type", acts_type, [mstype.float32, mstype.float16], self.name)
tuple(map(partial(validator.check_tensor_dtype_valid,
valid_dtypes=(mstype.int32,), prim_name=self.name),
("labels", "input_length", "label_length"),
(labels_type, input_length_type, label_length_type)))
return (acts_type, acts_type)
@ -2143,13 +2135,10 @@ class SGD(PrimitiveWithInfer):
def infer_dtype(self, parameters_dtype, gradient_dtype, learning_rate_dtype,
accum_dtype, momentum_dtype, stat_dtype):
valid_types = [mstype.float16, mstype.float32]
validator.check_tensor_type_same({"parameters": parameters_dtype}, valid_types, self.name)
validator.check_tensor_type_same({"gradient": gradient_dtype}, valid_types, self.name)
validator.check_tensor_type_same({"learning_rate": learning_rate_dtype}, valid_types, self.name)
validator.check_tensor_type_same({"accum": accum_dtype}, valid_types, self.name)
validator.check_tensor_type_same({"momentum": momentum_dtype}, valid_types, self.name)
validator.check_tensor_type_same({"stat": stat_dtype}, valid_types, self.name)
tuple(map(partial(validator.check_tensor_dtype_valid,
valid_dtypes=(mstype.float16, mstype.float32), prim_name=self.name),
("parameters", "gradient", "learning_rate", "accum", "momentum", "stat"),
(parameters_dtype, gradient_dtype, learning_rate_dtype, accum_dtype, momentum_dtype, stat_dtype)))
return parameters_dtype
@ -2229,13 +2218,13 @@ class ApplyRMSProp(PrimitiveWithInfer):
def infer_dtype(self, var_dtype, mean_square_dtype, moment_dtype, learning_rate_dtype, grad_dtype, decay_dtype,
momentum_dtype, epsilon_dtype):
args = {"var": var_dtype, "mean_square": mean_square_dtype, "moment": moment_dtype, "grad": grad_dtype}
validator.check_tensor_type_same(args, mstype.number_type, self.name)
validator.check_tensors_dtypes_same_and_valid(args, mstype.number_type, self.name)
valid_types = [mstype.float16, mstype.float32]
valid_dtypes = [mstype.float16, mstype.float32]
args_decay = {"decay": decay_dtype, 'momentum': momentum_dtype, "epsilon": epsilon_dtype}
validator.check_type_same(args_decay, valid_types, self.name)
validator.check_types_same_and_valid(args_decay, valid_dtypes, self.name)
args_lr = {"learning_rate": learning_rate_dtype, "decay": decay_dtype}
validator.check_scalar_or_tensor_type_same(args_lr, valid_types, self.name, allow_mix=True)
validator.check_scalar_or_tensor_types_same(args_lr, valid_dtypes, self.name, allow_mix=True)
if not self.is_ge and self.is_d:
return var_dtype, var_dtype, var_dtype
return var_dtype
@ -2332,13 +2321,13 @@ class ApplyCenteredRMSProp(PrimitiveWithInfer):
learning_rate_dtype, rho_dtype, momentum_dtype, epsilon_dtype):
args = {"var": var_dtype, "mean_gradient": mean_gradient_dtype,
"mean_square": mean_square_dtype, "moment": moment_dtype, "grad": grad_dtype}
validator.check_tensor_type_same(args, mstype.number_type, self.name)
validator.check_tensors_dtypes_same_and_valid(args, mstype.number_type, self.name)
valid_types = [mstype.float16, mstype.float32]
valid_dtypes = [mstype.float16, mstype.float32]
args_rho = {"rho": rho_dtype, 'momentum': momentum_dtype, "epsilon": epsilon_dtype}
validator.check_type_same(args_rho, valid_types, self.name)
validator.check_types_same_and_valid(args_rho, valid_dtypes, self.name)
args_lr = {"learning_rate": learning_rate_dtype, "rho": rho_dtype}
validator.check_scalar_or_tensor_type_same(args_lr, valid_types, self.name, allow_mix=True)
validator.check_scalar_or_tensor_types_same(args_lr, valid_dtypes, self.name, allow_mix=True)
if self.is_ascend:
return var_dtype, mean_gradient_dtype, mean_square_dtype, moment_dtype
return var_dtype
@ -2440,8 +2429,7 @@ class L2Normalize(PrimitiveWithInfer):
return input_x
def infer_dtype(self, input_x):
validator.check_subclass("x", input_x, mstype.tensor, self.name)
validator.check_tensor_type_same({"input_x": input_x}, [mstype.float16, mstype.float32], self.name)
validator.check_tensor_dtype_valid("input_x", input_x, [mstype.float16, mstype.float32], self.name)
return input_x
@ -2527,9 +2515,9 @@ class DropoutDoMask(PrimitiveWithInfer):
raise ValueError(f"DropoutDoMask y mask do not math input input_x shape:"
"{input_x_shape}, mask shape: {mask_shape}.")
validator.check_tensor_type_same({"input_x": input_x['dtype']}, [mstype.float32, mstype.float16, mstype.int32],
self.name)
validator.check_tensor_type_same({"input_mask": mask['dtype']}, [mstype.uint8], self.name)
validator.check_tensor_dtype_valid("input_x", input_x['dtype'], [mstype.float32, mstype.float16, mstype.int32],
self.name)
validator.check_tensor_dtype_valid("input_mask", mask['dtype'], [mstype.uint8], self.name)
keep_prob_v = keep_prob['value']
if keep_prob_v is not None:
@ -2587,7 +2575,8 @@ class ResizeBilinear(PrimitiveWithInfer):
return out_shape
def infer_dtype(self, input_dtype):
validator.check_tensor_type_same({'input_dtype': input_dtype}, [mstype.float16, mstype.float32], self.name)
validator.check_tensor_dtype_valid('input_dtype', input_dtype, [mstype.float16, mstype.float32],
self.name)
return mstype.tensor_type(mstype.float32)
@ -2631,10 +2620,10 @@ class OneHot(PrimitiveWithInfer):
def __infer__(self, indices, depth, on_value, off_value):
# check type
validator.check_tensor_type_same({"indices": indices['dtype']}, (mstype.int32,), self.name)
validator.check_tensor_dtype_valid("indices", indices['dtype'], (mstype.int32,), self.name)
validator.check_type_name("depth", depth['dtype'], mstype.int_type, self.name)
args = {"on_value": on_value['dtype'], "off_value": off_value['dtype']}
validator.check_tensor_type_same(args, (mstype.float16, mstype.float32), self.name)
validator.check_tensors_dtypes_same_and_valid(args, (mstype.float16, mstype.float32), self.name)
# check shape
indices_shp = indices['shape']
@ -2685,7 +2674,7 @@ class Gelu(PrimitiveWithInfer):
return input_x
def infer_dtype(self, input_x):
validator.check_tensor_type_same({"input_x": input_x}, (mstype.float16, mstype.float32), self.name)
validator.check_tensor_dtype_valid("input_x", input_x, (mstype.float16, mstype.float32), self.name)
return input_x
@ -2804,9 +2793,9 @@ class PReLU(PrimitiveWithInfer):
return input_x_shape
def infer_dtype(self, input_x_dtype, weight_dtype):
valid_types = (mstype.float16, mstype.float32)
validator.check_tensor_type_same({"input_x": input_x_dtype}, valid_types, self.name)
validator.check_tensor_type_same({"weight": weight_dtype}, valid_types, self.name)
valid_dtypes = (mstype.float16, mstype.float32)
validator.check_tensor_dtype_valid("input_x", input_x_dtype, valid_dtypes, self.name)
validator.check_tensor_dtype_valid("weight", weight_dtype, valid_dtypes, self.name)
return input_x_dtype
@ -2877,7 +2866,7 @@ class LSTM(PrimitiveWithInfer):
def infer_dtype(self, x_dtype, h_dtype, c_dtype, w_dtype):
args = {'x': x_dtype, 'h': h_dtype, 'c': c_dtype, 'w': w_dtype}
validator.check_tensor_type_same(args, (mstype.float32, mstype.float16), self.name)
validator.check_tensors_dtypes_same_and_valid(args, (mstype.float32, mstype.float16), self.name)
return (x_dtype, x_dtype, x_dtype, x_dtype, x_dtype)
def rnd_up(self, current_offset, page_size):
@ -2930,7 +2919,7 @@ class SigmoidCrossEntropyWithLogits(PrimitiveWithInfer):
def infer_dtype(self, x_dtype, y_dtype):
args = {"x_dtype": x_dtype, "y_dtype": y_dtype}
validator.check_tensor_type_same(args, mstype.number_type, self.name)
validator.check_tensors_dtypes_same_and_valid(args, mstype.number_type, self.name)
return x_dtype
@ -3123,9 +3112,9 @@ class ROIAlign(PrimitiveWithInfer):
return [rois_shape[0], inputs_shape[1], self.pooled_height, self.pooled_width]
def infer_dtype(self, inputs_type, rois_type):
valid_types = (mstype.float16, mstype.float32)
validator.check_tensor_type_same({"inputs_type": inputs_type}, valid_types, self.name)
validator.check_tensor_type_same({"rois_type": rois_type}, valid_types, self.name)
valid_dtypes = (mstype.float16, mstype.float32)
validator.check_tensor_dtype_valid("inputs_type", inputs_type, valid_dtypes, self.name)
validator.check_tensor_dtype_valid("rois_type", rois_type, valid_dtypes, self.name)
return inputs_type
@ -3199,6 +3188,7 @@ class Adam(PrimitiveWithInfer):
>>> gradient = Tensor(np.random.rand(3, 3, 3).astype(np.float32))
>>> result = net(0.9, 0.999, 0.001, 0.9, 0.999, 1e-8, gradient)
"""
@prim_attr_register
def __init__(self, use_locking=False, use_nesterov=False):
validator.check_value_type("use_locking", use_locking, [bool], self.name)
@ -3214,11 +3204,11 @@ class Adam(PrimitiveWithInfer):
def infer_dtype(self, var_dtype, m_dtype, v_dtype, beta1_power_dtype, beta2_power_dtype, lr_dtype,
beta1_dtype, beta2_dtype, epsilon_dtype, grad_dtype):
args = {"var": var_dtype, "m": m_dtype, "v": v_dtype, "grad": grad_dtype}
validator.check_tensor_type_same(args, mstype.number_type, self.name)
validator.check_tensors_dtypes_same_and_valid(args, mstype.number_type, self.name)
args = {"beta1_power": beta1_power_dtype, "beta2_power": beta2_power_dtype, 'lr': lr_dtype,
"beta1": beta1_dtype, "beta2": beta2_dtype, "epsilon": epsilon_dtype}
validator.check_scalar_or_tensor_type_same(args, [mstype.float16, mstype.float32], self.name, True)
validator.check_scalar_or_tensor_types_same(args, [mstype.float16, mstype.float32], self.name, True)
return var_dtype, m_dtype, v_dtype
@ -3345,12 +3335,12 @@ class FusedSparseAdam(PrimitiveWithInfer):
def infer_dtype(self, var_dtype, m_dtype, v_dtype, beta1_power_dtype, beta2_power_dtype, lr_dtype,
beta1_dtype, beta2_dtype, epsilon_dtype, grad_dtype, indices_dtype):
args = {"var": var_dtype, "m": m_dtype, "v": v_dtype, "grad": grad_dtype}
validator.check_tensor_type_same(args, mstype.number_type, self.name)
validator.check_tensors_dtypes_same_and_valid(args, mstype.number_type, self.name)
args = {"beta1_power": beta1_power_dtype, "beta2_power": beta2_power_dtype, 'lr': lr_dtype,
"beta1": beta1_dtype, "beta2": beta2_dtype, "epsilon": epsilon_dtype}
validator.check_scalar_or_tensor_type_same(args, [mstype.float16, mstype.float32], self.name, True)
validator.check_tensor_type_same({"indices_dtype": indices_dtype}, [mstype.int32], self.name)
validator.check_scalar_or_tensor_types_same(args, [mstype.float16, mstype.float32], self.name, True)
validator.check_tensor_dtype_valid("indices_dtype", indices_dtype, [mstype.int32], self.name)
return var_dtype, m_dtype, v_dtype
@ -3478,13 +3468,13 @@ class FusedSparseLazyAdam(PrimitiveWithInfer):
def infer_dtype(self, var_dtype, m_dtype, v_dtype, beta1_power_dtype, beta2_power_dtype, lr_dtype,
beta1_dtype, beta2_dtype, epsilon_dtype, grad_dtype, indices_dtype):
args = {"var": var_dtype, "m": m_dtype, "v": v_dtype, "grad": grad_dtype}
validator.check_tensor_type_same(args, mstype.number_type, self.name)
validator.check_tensors_dtypes_same_and_valid(args, mstype.number_type, self.name)
args = {"beta1_power": beta1_power_dtype, "beta2_power": beta2_power_dtype, 'lr': lr_dtype,
"beta1": beta1_dtype, "beta2": beta2_dtype, "epsilon": epsilon_dtype}
validator.check_scalar_or_tensor_type_same(args, [mstype.float16, mstype.float32], self.name, True)
validator.check_scalar_or_tensor_types_same(args, [mstype.float16, mstype.float32], self.name, True)
validator.check_tensor_type_same({"indices_dtype": indices_dtype}, [mstype.int32], self.name)
validator.check_tensor_dtype_valid("indices_dtype", indices_dtype, [mstype.int32], self.name)
return var_dtype, m_dtype, v_dtype
@ -3578,8 +3568,8 @@ class FusedSparseFtrl(PrimitiveWithInfer):
def infer_dtype(self, var_dtype, accum_dtype, linear_dtype, grad_dtype, indices_dtype):
args = {"var_dtype": var_dtype, "accum_dtype": accum_dtype,
"linear_dtype": linear_dtype, "grad_dtype": grad_dtype}
validator.check_tensor_type_same(args, [mstype.float32], self.name)
validator.check_tensor_type_same({"indices_dtype": indices_dtype}, [mstype.int32], self.name)
validator.check_tensors_dtypes_same_and_valid(args, [mstype.float32], self.name)
validator.check_tensor_dtype_valid("indices_dtype", indices_dtype, [mstype.int32], self.name)
return var_dtype, accum_dtype, linear_dtype
@ -3665,13 +3655,13 @@ class FusedSparseProximalAdagrad(PrimitiveWithInfer):
def infer_dtype(self, var_dtype, accum_dtype, lr_dtype, l1_dtype, l2_dtype, grad_dtype, indices_dtype):
args = {'var': var_dtype, 'accum': accum_dtype, 'grad': grad_dtype}
validator.check_tensor_type_same(args, [mstype.float32], self.name)
validator.check_scalar_or_tensor_type_same({"lr": lr_dtype}, [mstype.float32], self.name)
validator.check_scalar_or_tensor_type_same({"l1": l1_dtype}, [mstype.float32], self.name)
validator.check_scalar_or_tensor_type_same({"l2": l2_dtype}, [mstype.float32], self.name)
valid_types = [mstype.int16, mstype.int32, mstype.int64,
mstype.uint16, mstype.uint32, mstype.uint64]
validator.check_tensor_type_same({'indices': indices_dtype}, valid_types, self.name)
validator.check_tensors_dtypes_same_and_valid(args, [mstype.float32], self.name)
validator.check_scalar_or_tensor_types_same({"lr": lr_dtype}, [mstype.float32], self.name)
validator.check_scalar_or_tensor_types_same({"l1": l1_dtype}, [mstype.float32], self.name)
validator.check_scalar_or_tensor_types_same({"l2": l2_dtype}, [mstype.float32], self.name)
valid_dtypes = [mstype.int16, mstype.int32, mstype.int64,
mstype.uint16, mstype.uint32, mstype.uint64]
validator.check_tensor_dtype_valid('indices', indices_dtype, valid_dtypes, self.name)
return var_dtype, accum_dtype
@ -3742,8 +3732,8 @@ class KLDivLoss(PrimitiveWithInfer):
def infer_dtype(self, x_type, y_type):
args = {'x': x_type, 'y': y_type}
valid_types = (mstype.float16, mstype.float32)
validator.check_tensor_type_same(args, valid_types, self.name)
valid_dtypes = (mstype.float16, mstype.float32)
validator.check_tensors_dtypes_same_and_valid(args, valid_dtypes, self.name)
return x_type
@ -3820,10 +3810,10 @@ class BinaryCrossEntropy(PrimitiveWithInfer):
def infer_dtype(self, x_type, y_type, weight_type):
args = {'x': x_type, 'y': y_type}
valid_types = (mstype.float16, mstype.float32)
validator.check_tensor_type_same(args, valid_types, self.name)
valid_dtypes = (mstype.float16, mstype.float32)
validator.check_tensors_dtypes_same_and_valid(args, valid_dtypes, self.name)
if weight_type:
validator.check_tensor_type_same({'x': x_type, 'weight': weight_type}, valid_types, self.name)
validator.check_tensors_dtypes_same_and_valid({'x': x_type, 'weight': weight_type}, valid_dtypes, self.name)
return x_type
@ -3950,14 +3940,14 @@ class ApplyAdaMax(PrimitiveWithInfer):
def infer_dtype(self, var_dtype, m_dtype, v_dtype, beta1_power_dtype, lr_dtype,
beta1_dtype, beta2_dtype, epsilon_dtype, grad_dtype):
valid_types = [mstype.float16, mstype.float32]
valid_dtypes = [mstype.float16, mstype.float32]
args = {"var": var_dtype, "m": m_dtype, "v": v_dtype, "grad": grad_dtype}
validator.check_tensor_type_same(args, valid_types, self.name)
validator.check_scalar_or_tensor_type_same({"beta1_power": beta1_power_dtype}, valid_types, self.name)
validator.check_scalar_or_tensor_type_same({"lr": lr_dtype}, valid_types, self.name)
validator.check_scalar_or_tensor_type_same({"beta1": beta1_dtype}, valid_types, self.name)
validator.check_scalar_or_tensor_type_same({"beta2": beta2_dtype}, valid_types, self.name)
validator.check_scalar_or_tensor_type_same({"epsilon": epsilon_dtype}, valid_types, self.name)
validator.check_tensors_dtypes_same_and_valid(args, valid_dtypes, self.name)
validator.check_scalar_or_tensor_types_same({"beta1_power": beta1_power_dtype}, valid_dtypes, self.name)
validator.check_scalar_or_tensor_types_same({"lr": lr_dtype}, valid_dtypes, self.name)
validator.check_scalar_or_tensor_types_same({"beta1": beta1_dtype}, valid_dtypes, self.name)
validator.check_scalar_or_tensor_types_same({"beta2": beta2_dtype}, valid_dtypes, self.name)
validator.check_scalar_or_tensor_types_same({"epsilon": epsilon_dtype}, valid_dtypes, self.name)
return var_dtype, m_dtype, v_dtype
@ -4058,12 +4048,12 @@ class ApplyAdadelta(PrimitiveWithInfer):
def infer_dtype(self, var_dtype, accum_dtype, accum_update_dtype, lr_dtype, rho_dtype,
epsilon_dtype, grad_dtype):
valid_types = [mstype.float16, mstype.float32]
valid_dtypes = [mstype.float16, mstype.float32]
args = {"var": var_dtype, "accum": accum_dtype, "accum_update": accum_update_dtype, "grad": grad_dtype}
validator.check_tensor_type_same(args, valid_types, self.name)
validator.check_scalar_or_tensor_type_same({"lr": lr_dtype}, valid_types, self.name)
validator.check_scalar_or_tensor_type_same({"rho": rho_dtype}, valid_types, self.name)
validator.check_scalar_or_tensor_type_same({"epsilon": epsilon_dtype}, valid_types, self.name)
validator.check_tensors_dtypes_same_and_valid(args, valid_dtypes, self.name)
validator.check_scalar_or_tensor_types_same({"lr": lr_dtype}, valid_dtypes, self.name)
validator.check_scalar_or_tensor_types_same({"rho": rho_dtype}, valid_dtypes, self.name)
validator.check_scalar_or_tensor_types_same({"epsilon": epsilon_dtype}, valid_dtypes, self.name)
return var_dtype, accum_dtype, accum_update_dtype
@ -4142,9 +4132,9 @@ class ApplyAdagrad(PrimitiveWithInfer):
def infer_dtype(self, var_dtype, accum_dtype, lr_dtype, grad_dtype):
args = {'var': var_dtype, 'accum': accum_dtype, 'grad': grad_dtype}
valid_types = [mstype.float16, mstype.float32]
validator.check_tensor_type_same(args, valid_types, self.name)
validator.check_scalar_or_tensor_type_same({'lr': lr_dtype}, valid_types, self.name)
valid_dtypes = [mstype.float16, mstype.float32]
validator.check_tensors_dtypes_same_and_valid(args, valid_dtypes, self.name)
validator.check_scalar_or_tensor_types_same({'lr': lr_dtype}, valid_dtypes, self.name)
return var_dtype, accum_dtype
@ -4226,8 +4216,8 @@ class ApplyAdagradV2(PrimitiveWithInfer):
def infer_dtype(self, var_dtype, accum_dtype, lr_dtype, grad_dtype):
args = {'var': var_dtype, 'accum': accum_dtype, 'grad': grad_dtype}
validator.check_tensor_type_same(args, [mstype.float16, mstype.float32], self.name)
validator.check_scalar_or_tensor_type_same({'lr': lr_dtype}, [mstype.float16, mstype.float32], self.name)
validator.check_tensors_dtypes_same_and_valid(args, [mstype.float16, mstype.float32], self.name)
validator.check_scalar_or_tensor_types_same({'lr': lr_dtype}, [mstype.float16, mstype.float32], self.name)
return var_dtype, accum_dtype
@ -4313,8 +4303,8 @@ class SparseApplyAdagrad(PrimitiveWithInfer):
def infer_dtype(self, var_type, accum_type, grad_type, indices_type):
args = {'var': var_type, 'accum': accum_type, 'grad': grad_type}
validator.check_tensor_type_same(args, [mstype.float16, mstype.float32], self.name)
validator.check_tensor_type_same({'indices': indices_type}, [mstype.int32], self.name)
validator.check_tensors_dtypes_same_and_valid(args, [mstype.float16, mstype.float32], self.name)
validator.check_tensor_dtype_valid('indices', indices_type, [mstype.int32], self.name)
return var_type, accum_type
@ -4402,8 +4392,8 @@ class SparseApplyAdagradV2(PrimitiveWithInfer):
def infer_dtype(self, var_type, accum_type, grad_type, indices_type):
args = {'var': var_type, 'accum': accum_type, 'grad': grad_type}
validator.check_tensor_type_same(args, [mstype.float16, mstype.float32], self.name)
validator.check_tensor_type_same({'indices': indices_type}, [mstype.int32], self.name)
validator.check_tensors_dtypes_same_and_valid(args, [mstype.float16, mstype.float32], self.name)
validator.check_tensor_dtype_valid('indices', indices_type, [mstype.int32], self.name)
return var_type, accum_type
@ -4500,12 +4490,12 @@ class ApplyProximalAdagrad(PrimitiveWithInfer):
return var_shape, accum_shape
def infer_dtype(self, var_dtype, accum_dtype, lr_dtype, l1_dtype, l2_dtype, grad_dtype):
valid_types = [mstype.float16, mstype.float32]
valid_dtypes = [mstype.float16, mstype.float32]
args = {'var': var_dtype, 'accum': accum_dtype, 'grad': grad_dtype}
validator.check_tensor_type_same(args, valid_types, self.name)
validator.check_scalar_or_tensor_type_same({"lr": lr_dtype}, valid_types, self.name)
validator.check_scalar_or_tensor_type_same({"l1": l1_dtype}, valid_types, self.name)
validator.check_scalar_or_tensor_type_same({"l2": l2_dtype}, valid_types, self.name)
validator.check_tensors_dtypes_same_and_valid(args, valid_dtypes, self.name)
validator.check_scalar_or_tensor_types_same({"lr": lr_dtype}, valid_dtypes, self.name)
validator.check_scalar_or_tensor_types_same({"l1": l1_dtype}, valid_dtypes, self.name)
validator.check_scalar_or_tensor_types_same({"l2": l2_dtype}, valid_dtypes, self.name)
return var_dtype, accum_dtype
@ -4594,13 +4584,13 @@ class SparseApplyProximalAdagrad(PrimitiveWithCheck):
def check_dtype(self, var_dtype, accum_dtype, lr_dtype, l1_dtype, l2_dtype, grad_dtype, indices_dtype):
args = {'var': var_dtype, 'accum': accum_dtype, 'grad': grad_dtype}
validator.check_tensor_type_same(args, [mstype.float16, mstype.float32], self.name)
validator.check_scalar_or_tensor_type_same({"lr": lr_dtype}, [mstype.float16, mstype.float32], self.name)
validator.check_scalar_or_tensor_type_same({"l1": l1_dtype}, [mstype.float16, mstype.float32], self.name)
validator.check_scalar_or_tensor_type_same({"l2": l2_dtype}, [mstype.float16, mstype.float32], self.name)
valid_types = [mstype.int16, mstype.int32, mstype.int64,
mstype.uint16, mstype.uint32, mstype.uint64]
validator.check_tensor_type_same({'indices': indices_dtype}, valid_types, self.name)
validator.check_tensors_dtypes_same_and_valid(args, [mstype.float16, mstype.float32], self.name)
validator.check_scalar_or_tensor_types_same({"lr": lr_dtype}, [mstype.float16, mstype.float32], self.name)
validator.check_scalar_or_tensor_types_same({"l1": l1_dtype}, [mstype.float16, mstype.float32], self.name)
validator.check_scalar_or_tensor_types_same({"l2": l2_dtype}, [mstype.float16, mstype.float32], self.name)
valid_dtypes = [mstype.int16, mstype.int32, mstype.int64,
mstype.uint16, mstype.uint32, mstype.uint64]
validator.check_tensor_dtype_valid('indices', indices_dtype, valid_dtypes, self.name)
class ApplyAddSign(PrimitiveWithInfer):
@ -4699,13 +4689,13 @@ class ApplyAddSign(PrimitiveWithInfer):
return var_shape, m_shape
def infer_dtype(self, var_dtype, m_dtype, lr_dtype, alpha_dtype, sign_decay_dtype, beta_dtype, grad_dtype):
valid_types = [mstype.float16, mstype.float32]
valid_dtypes = [mstype.float16, mstype.float32]
args = {'var': var_dtype, 'm': m_dtype, 'grad': grad_dtype}
validator.check_tensor_type_same(args, valid_types, self.name)
validator.check_scalar_or_tensor_type_same({"lr": lr_dtype}, valid_types, self.name)
validator.check_scalar_or_tensor_type_same({"alpha": alpha_dtype}, valid_types, self.name)
validator.check_scalar_or_tensor_type_same({"sign_decay": sign_decay_dtype}, valid_types, self.name)
validator.check_scalar_or_tensor_type_same({"beta": beta_dtype}, valid_types, self.name)
validator.check_tensors_dtypes_same_and_valid(args, valid_dtypes, self.name)
validator.check_scalar_or_tensor_types_same({"lr": lr_dtype}, valid_dtypes, self.name)
validator.check_scalar_or_tensor_types_same({"alpha": alpha_dtype}, valid_dtypes, self.name)
validator.check_scalar_or_tensor_types_same({"sign_decay": sign_decay_dtype}, valid_dtypes, self.name)
validator.check_scalar_or_tensor_types_same({"beta": beta_dtype}, valid_dtypes, self.name)
return var_dtype, m_dtype
@ -4808,13 +4798,13 @@ class ApplyPowerSign(PrimitiveWithInfer):
return var_shape, m_shape
def infer_dtype(self, var_dtype, m_dtype, lr_dtype, logbase_dtype, sign_decay_dtype, beta_dtype, grad_dtype):
valid_types = [mstype.float16, mstype.float32]
valid_dtypes = [mstype.float16, mstype.float32]
args = {'var': var_dtype, 'm': m_dtype, 'grad': grad_dtype}
validator.check_tensor_type_same(args, valid_types, self.name)
validator.check_scalar_or_tensor_type_same({"lr": lr_dtype}, valid_types, self.name)
validator.check_scalar_or_tensor_type_same({"logbase": logbase_dtype}, valid_types, self.name)
validator.check_scalar_or_tensor_type_same({"sign_decay": sign_decay_dtype}, valid_types, self.name)
validator.check_scalar_or_tensor_type_same({"beta": beta_dtype}, valid_types, self.name)
validator.check_tensors_dtypes_same_and_valid(args, valid_dtypes, self.name)
validator.check_scalar_or_tensor_types_same({"lr": lr_dtype}, valid_dtypes, self.name)
validator.check_scalar_or_tensor_types_same({"logbase": logbase_dtype}, valid_dtypes, self.name)
validator.check_scalar_or_tensor_types_same({"sign_decay": sign_decay_dtype}, valid_dtypes, self.name)
validator.check_scalar_or_tensor_types_same({"beta": beta_dtype}, valid_dtypes, self.name)
return var_dtype, m_dtype
@ -4876,10 +4866,10 @@ class ApplyGradientDescent(PrimitiveWithInfer):
return var_shape
def infer_dtype(self, var_dtype, alpha_dtype, delta_dtype):
valid_types = [mstype.float16, mstype.float32]
valid_dtypes = [mstype.float16, mstype.float32]
args = {'var': var_dtype, 'delta': delta_dtype}
validator.check_tensor_type_same(args, valid_types, self.name)
validator.check_scalar_or_tensor_type_same({"alpha": alpha_dtype}, valid_types, self.name)
validator.check_tensors_dtypes_same_and_valid(args, valid_dtypes, self.name)
validator.check_scalar_or_tensor_types_same({"alpha": alpha_dtype}, valid_dtypes, self.name)
return var_dtype
@ -4959,12 +4949,12 @@ class ApplyProximalGradientDescent(PrimitiveWithInfer):
return var_shape
def infer_dtype(self, var_dtype, alpha_dtype, l1_dtype, l2_dtype, delta_dtype):
valid_types = [mstype.float16, mstype.float32]
valid_dtypes = [mstype.float16, mstype.float32]
args = {'var': var_dtype, 'delta': delta_dtype}
validator.check_tensor_type_same(args, valid_types, self.name)
validator.check_scalar_or_tensor_type_same({"alpha": alpha_dtype}, valid_types, self.name)
validator.check_scalar_or_tensor_type_same({"l1": l1_dtype}, valid_types, self.name)
validator.check_scalar_or_tensor_type_same({"l2": l2_dtype}, valid_types, self.name)
validator.check_tensors_dtypes_same_and_valid(args, valid_dtypes, self.name)
validator.check_scalar_or_tensor_types_same({"alpha": alpha_dtype}, valid_dtypes, self.name)
validator.check_scalar_or_tensor_types_same({"l1": l1_dtype}, valid_dtypes, self.name)
validator.check_scalar_or_tensor_types_same({"l2": l2_dtype}, valid_dtypes, self.name)
return var_dtype
@ -5036,11 +5026,13 @@ class LARSUpdate(PrimitiveWithInfer):
weight_decay_dtype, learning_rate_dtype):
args = {"Weight dtype": weight_dtype, "gradient dtype": gradient_dtype, "norm weight dtype": norm_weight_dtype,
"norm gradient dtype": norm_gradient_dtype}
validator.check_tensor_type_same(args, [mstype.float16, mstype.float32, mstype.int16, mstype.int32], self.name)
validator.check_scalar_or_tensor_type_same({"weight_decay": weight_decay_dtype},
[mstype.float16, mstype.float32, mstype.float64], self.name)
validator.check_scalar_or_tensor_type_same({"learning_rate": learning_rate_dtype},
[mstype.float16, mstype.float32, mstype.float64], self.name)
validator.check_tensors_dtypes_same_and_valid(args,
[mstype.float16, mstype.float32, mstype.int16, mstype.int32],
self.name)
validator.check_scalar_or_tensor_types_same({"weight_decay": weight_decay_dtype},
[mstype.float16, mstype.float32, mstype.float64], self.name)
validator.check_scalar_or_tensor_types_same({"learning_rate": learning_rate_dtype},
[mstype.float16, mstype.float32, mstype.float64], self.name)
return weight_dtype
@ -5117,14 +5109,14 @@ class ApplyFtrl(PrimitiveWithInfer):
return var_shape
def infer_dtype(self, var_type, accum_type, linear_type, grad_type, lr_type, l1_type, l2_type, lr_power_type):
valid_types = [mstype.float16, mstype.float32]
valid_dtypes = [mstype.float16, mstype.float32]
args = {'var': var_type, 'accum': accum_type, 'linear': linear_type, 'grad': grad_type}
validator.check_tensor_type_same(args, valid_types, self.name)
validator.check_tensors_dtypes_same_and_valid(args, valid_dtypes, self.name)
validator.check_scalar_or_tensor_type_same({"lr": lr_type}, valid_types, self.name)
validator.check_scalar_or_tensor_type_same({"l1": l1_type}, valid_types, self.name)
validator.check_scalar_or_tensor_type_same({"l2": l2_type}, valid_types, self.name)
validator.check_scalar_or_tensor_type_same({"lr_power": lr_power_type}, valid_types, self.name)
validator.check_scalar_or_tensor_types_same({"lr": lr_type}, valid_dtypes, self.name)
validator.check_scalar_or_tensor_types_same({"l1": l1_type}, valid_dtypes, self.name)
validator.check_scalar_or_tensor_types_same({"l2": l2_type}, valid_dtypes, self.name)
validator.check_scalar_or_tensor_types_same({"lr_power": lr_power_type}, valid_dtypes, self.name)
if self.is_tbe:
return var_type, var_type, var_type
return var_type
@ -5219,8 +5211,8 @@ class SparseApplyFtrl(PrimitiveWithCheck):
def check_dtype(self, var_dtype, accum_dtype, linear_dtype, grad_dtype, indices_dtype):
args = {"var_dtype": var_dtype, "accum_dtype": accum_dtype,
"linear_dtype": linear_dtype, "grad_dtype": grad_dtype}
validator.check_tensor_type_same(args, [mstype.float16, mstype.float32], self.name)
validator.check_tensor_type_same({"indices_dtype": indices_dtype}, [mstype.int32, mstype.int64], self.name)
validator.check_tensors_dtypes_same_and_valid(args, [mstype.float16, mstype.float32], self.name)
validator.check_tensor_dtype_valid("indices_dtype", indices_dtype, [mstype.int32, mstype.int64], self.name)
class SparseApplyFtrlV2(PrimitiveWithInfer):
@ -5316,8 +5308,8 @@ class SparseApplyFtrlV2(PrimitiveWithInfer):
def infer_dtype(self, var_dtype, accum_dtype, linear_dtype, grad_dtype, indices_dtype):
args = {"var_dtype": var_dtype, "accum_dtype": accum_dtype,
"linear_dtype": linear_dtype, "grad_dtype": grad_dtype}
validator.check_tensor_type_same(args, [mstype.float16, mstype.float32], self.name)
validator.check_tensor_type_same({"indices_dtype": indices_dtype}, [mstype.int32], self.name)
validator.check_tensors_dtypes_same_and_valid(args, [mstype.float16, mstype.float32], self.name)
validator.check_tensor_dtype_valid("indicese", indices_dtype, [mstype.int32], self.name)
return var_dtype, accum_dtype, linear_dtype
@ -5351,9 +5343,8 @@ class Dropout(PrimitiveWithInfer):
return x_shape, mask_shape
def infer_dtype(self, x_dtype):
valid_types = (mstype.float16, mstype.float32)
validator.check_subclass("x", x_dtype, mstype.tensor, self.name)
validator.check_tensor_type_same({"x_dtype": x_dtype}, valid_types, self.name)
valid_dtypes = (mstype.float16, mstype.float32)
validator.check_tensor_dtype_valid("x", x_dtype, valid_dtypes, self.name)
return x_dtype, x_dtype
@ -5425,10 +5416,10 @@ class CTCLoss(PrimitiveWithInfer):
def infer_dtype(self, inputs, labels_indices, labels_values, sequence_length):
valid_dtype = [mstype.float16, mstype.float32, mstype.double]
validator.check_tensor_type_same({"inputs_dtype": inputs}, valid_dtype, self.name)
validator.check_tensor_type_same({"labels_indices_dtype": labels_indices}, [mstype.int64], self.name)
validator.check_tensor_type_same({"labels_values_dtype": labels_values}, [mstype.int32], self.name)
validator.check_tensor_type_same({"sequence_length_dtype": sequence_length}, [mstype.int32], self.name)
validator.check_tensor_dtype_valid("inputs", inputs, valid_dtype, self.name)
validator.check_tensor_dtype_valid("labels_indices", labels_indices, [mstype.int64], self.name)
validator.check_tensor_dtype_valid("labels_values", labels_values, [mstype.int32], self.name)
validator.check_tensor_dtype_valid("sequence_length", sequence_length, [mstype.int32], self.name)
return inputs, inputs
@ -5492,8 +5483,8 @@ class CTCGreedyDecoder(PrimitiveWithInfer):
return decoded_indices_shape, decoded_values, decoded_shape, log_probability_shape
def infer_dtype(self, inputs_dtype, sequence_length_dtype):
validator.check_tensor_type_same({"inputs_dtype": inputs_dtype}, [mstype.float32, mstype.double], self.name)
validator.check_tensor_type_same({"sequence_length_dtype": sequence_length_dtype}, [mstype.int32], self.name)
validator.check_tensor_dtype_valid("inputs_dtype", inputs_dtype, [mstype.float32, mstype.double], self.name)
validator.check_tensor_dtype_valid("sequence_length_dtype", sequence_length_dtype, [mstype.int32], self.name)
decoded_type = mstype.tensor_type(mstype.int64)
return decoded_type, decoded_type, decoded_type, inputs_dtype
@ -5597,12 +5588,12 @@ class BasicLSTMCell(PrimitiveWithInfer):
return (ct_shape, ht_shape, it_shape, jt_shape, ft_shape, ot_shape, tanhct_shape)
def infer_dtype(self, x_dtype, h_dtype, c_dtype, w_dtype, b_dtype):
validator.check_tensor_type_same({"x_dtype": x_dtype}, [mstype.float16, mstype.float32], self.name)
validator.check_tensor_type_same({"h_dtype": h_dtype}, [mstype.float16, mstype.float32], self.name)
validator.check_tensor_type_same({"w_dtype": w_dtype}, [mstype.float16, mstype.float32], self.name)
tuple(map(partial(validator.check_tensor_dtype_valid,
valid_dtypes=(mstype.float16, mstype.float32), prim_name=self.name),
("x_dtype", "h_dtype", "w_dtype"),
(x_dtype, h_dtype, w_dtype)))
args = {"c_dtype": c_dtype, "b_dtype": b_dtype}
validator.check_tensor_type_same(args, [mstype.float16, mstype.float32], self.name)
validator.check_tensors_dtypes_same_and_valid(args, [mstype.float16, mstype.float32], self.name)
return (c_dtype, mstype.float16, c_dtype, c_dtype, c_dtype, c_dtype, c_dtype)
@ -5725,11 +5716,10 @@ class DynamicRNN(PrimitiveWithInfer):
return y_shape, y_shape, y_shape, y_shape, y_shape, y_shape, y_shape, y_shape
def infer_dtype(self, x_dtype, w_dtype, b_dtype, seq_dtype, h_dtype, c_dtype):
validator.check_tensor_type_same({"x dtype": x_dtype}, (mstype.float16,), self.name)
validator.check_tensor_type_same({"w dtype": w_dtype}, (mstype.float16,), self.name)
validator.check_tensor_type_same({"b dtype": b_dtype}, (mstype.float32, mstype.float16), self.name)
validator.check_tensor_type_same({"h dtype": h_dtype}, (mstype.float16,), self.name)
validator.check_tensor_type_same({"c dtype": c_dtype}, (mstype.float16,), self.name)
tuple(map(partial(validator.check_tensor_dtype_valid, valid_dtypes=mstype.float16, prim_name=self.name),
("x", "w", "h", "c"),
(x_dtype, w_dtype, h_dtype, c_dtype)))
validator.check_tensor_dtype_valid("b", b_dtype, (mstype.float16, mstype.float32), self.name)
return b_dtype, x_dtype, b_dtype, b_dtype, b_dtype, b_dtype, b_dtype, b_dtype
@ -5765,8 +5755,8 @@ class InTopK(PrimitiveWithInfer):
validator.check_value_type("k", k, [int], self.name)
def infer_dtype(self, x1_dtype, x2_dtype):
validator.check_tensor_type_same({"x1": x1_dtype}, (mstype.float16, mstype.float32,), self.name)
validator.check_tensor_type_same({"x2": x2_dtype}, (mstype.int32,), self.name)
validator.check_tensor_dtype_valid("x1", x1_dtype, (mstype.float16, mstype.float32,), self.name)
validator.check_tensor_dtype_valid("x2", x2_dtype, (mstype.int32,), self.name)
return mstype.tensor_type(mstype.bool_)
@ -5803,6 +5793,7 @@ class LRN(PrimitiveWithInfer):
[[0.6258911 0.4964315 ]
[0.3141494 0.43636137]]]]
"""
@prim_attr_register
def __init__(self, depth_radius=5, bias=1.0, alpha=1.0, beta=0.5, norm_region="ACROSS_CHANNELS"):
"""Initialize LRN"""
@ -5816,7 +5807,7 @@ class LRN(PrimitiveWithInfer):
validator.check_non_negative_int(depth_radius, "depth_radius", self.name)
def infer_dtype(self, x_dtype):
validator.check_tensor_type_same({"x": x_dtype}, (mstype.float16, mstype.float32,), self.name)
validator.check_tensor_dtype_valid("x", x_dtype, (mstype.float16, mstype.float32,), self.name)
return x_dtype
def infer_shape(self, x_shape):
@ -5857,6 +5848,7 @@ class UniformSampler(PrimitiveWithInfer):
[3]], dtype=np.int32)))
[1, 1, 3], [[0.75], [0.75], [0.75], [0.75], [0.75]], [0.75, 0.75, 0.75]
"""
@prim_attr_register
def __init__(self, num_true, num_sampled, unique, range_max, seed=0, remove_accidental_hits=False):
"""Initialize UniformSampler"""

View File

@ -61,8 +61,8 @@ class Assign(PrimitiveWithCheck):
def check_dtype(self, variable, value):
if variable != mstype.type_refkey:
validator.check_tensor_type_same({"variable": variable}, mstype.number_type, self.name)
validator.check_scalar_or_tensor_type_same({"value": value}, mstype.number_type, self.name)
validator.check_tensor_dtype_valid("variable", variable, mstype.number_type, self.name)
validator.check_scalar_or_tensor_types_same({"value": value}, mstype.number_type, self.name)
class BoundingBoxEncode(PrimitiveWithInfer):
@ -112,7 +112,7 @@ class BoundingBoxEncode(PrimitiveWithInfer):
def infer_dtype(self, anchor_box, groundtruth_box):
args = {"anchor_box": anchor_box, "groundtruth_box": groundtruth_box}
validator.check_tensor_type_same(args, mstype.number_type, self.name)
validator.check_tensors_dtypes_same_and_valid(args, mstype.number_type, self.name)
return anchor_box
@ -169,7 +169,7 @@ class BoundingBoxDecode(PrimitiveWithInfer):
def infer_dtype(self, anchor_box, deltas):
args = {"anchor_box": anchor_box, "deltas": deltas}
validator.check_tensor_type_same(args, mstype.number_type, self.name)
validator.check_tensors_dtypes_same_and_valid(args, mstype.number_type, self.name)
return anchor_box
@ -221,8 +221,8 @@ class CheckValid(PrimitiveWithInfer):
def infer_dtype(self, bboxes_type, metas_type):
valid_type = [mstype.float32, mstype.float16, mstype.int16, mstype.uint8]
validator.check_tensor_type_same({"bboxes_type": bboxes_type}, valid_type, self.name)
validator.check_tensor_type_same({"metas_type": metas_type}, valid_type, self.name)
validator.check_tensor_dtype_valid("bboxes_type", bboxes_type, valid_type, self.name)
validator.check_tensor_dtype_valid("metas_type", metas_type, valid_type, self.name)
return mstype.bool_
@ -281,8 +281,8 @@ class IOU(PrimitiveWithInfer):
def infer_dtype(self, anchor_boxes, gt_boxes):
valid_type = [mstype.float32, mstype.float16]
validator.check_tensor_type_same({"anchor_boxes": anchor_boxes}, valid_type, self.name)
validator.check_tensor_type_same({"gt_boxes": gt_boxes}, valid_type, self.name)
validator.check_tensor_dtype_valid("anchor_boxes", anchor_boxes, valid_type, self.name)
validator.check_tensor_dtype_valid("gt_boxes", gt_boxes, valid_type, self.name)
return anchor_boxes
@ -478,7 +478,7 @@ class ConfusionMatrix(PrimitiveWithInfer):
if weights is not None:
validator.check_subclass('weights', weights, mstype.tensor, self.name)
args = {"labels": labels, "predictions": predictions}
validator.check_tensor_type_same(args, (mstype.number_type), self.name)
validator.check_tensors_dtypes_same_and_valid(args, (mstype.number_type), self.name)
return labels
@ -506,8 +506,7 @@ class PopulationCount(PrimitiveWithInfer):
return x_shape
def infer_dtype(self, x_dtype):
args = {"x": x_dtype}
validator.check_tensor_type_same(args, (mstype.int16, mstype.uint16,), self.name)
validator.check_tensor_dtype_valid("x", x_dtype, (mstype.int16, mstype.uint16,), self.name)
return mstype.tensor_type(mstype.uint8)
class Push(PrimitiveWithInfer):

View File

@ -151,8 +151,8 @@ class Gamma(PrimitiveWithInfer):
Validator.check_value_type("shape", shape_v, [tuple], self.name)
for i, shape_i in enumerate(shape_v):
Validator.check_positive_int(shape_i, f'shape[{i}]', self.name)
Validator.check_tensor_type_same({"alpha": alpha["dtype"]}, [mstype.float32], self.name)
Validator.check_tensor_type_same({"beta": beta["dtype"]}, [mstype.float32], self.name)
Validator.check_tensor_dtype_valid("alpha", alpha["dtype"], [mstype.float32], self.name)
Validator.check_tensor_dtype_valid("beta", beta["dtype"], [mstype.float32], self.name)
broadcast_shape = get_broadcast_shape(alpha['shape'], beta['shape'], self.name)
broadcast_shape = get_broadcast_shape(broadcast_shape, shape_v, self.name)
out = {
@ -203,7 +203,7 @@ class Poisson(PrimitiveWithInfer):
Validator.check_value_type("shape", shape_v, [tuple], self.name)
for i, shape_i in enumerate(shape_v):
Validator.check_positive_int(shape_i, f'shape[{i}]', self.name)
Validator.check_tensor_type_same({"mean": mean["dtype"]}, [mstype.float32], self.name)
Validator.check_tensor_dtype_valid("mean", mean["dtype"], [mstype.float32], self.name)
broadcast_shape = get_broadcast_shape(mean['shape'], shape_v, self.name)
out = {
'shape': broadcast_shape,
@ -259,8 +259,8 @@ class UniformInt(PrimitiveWithInfer):
Validator.check_value_type("shape", shape_v, [tuple], self.name)
for i, shape_i in enumerate(shape_v):
Validator.check_positive_int(shape_i, f'shape[{i}]', self.name)
Validator.check_tensor_type_same({"minval": minval["dtype"]}, [mstype.int32], self.name)
Validator.check_tensor_type_same({"maxval": maxval["dtype"]}, [mstype.int32], self.name)
Validator.check_tensor_dtype_valid("minval", minval["dtype"], [mstype.int32], self.name)
Validator.check_tensor_dtype_valid("maxval", maxval["dtype"], [mstype.int32], self.name)
minval_shape = minval['shape']
maxval_shape = maxval['shape']
Validator.check("dim of minval", len(minval_shape), '0(scalar)', 0, Rel.EQ, self.name)
@ -361,7 +361,7 @@ class RandomChoiceWithMask(PrimitiveWithInfer):
return ([self.count, len(x_shape)], [self.count])
def infer_dtype(self, x_dtype):
Validator.check_tensor_type_same({'x': x_dtype}, [mstype.bool_], self.name)
Validator.check_tensor_dtype_valid('x', x_dtype, [mstype.bool_], self.name)
return (mstype.int32, mstype.bool_)
@ -407,8 +407,8 @@ class RandomCategorical(PrimitiveWithInfer):
def __infer__(self, logits, num_samples, seed):
logits_dtype = logits['dtype']
valid_types = (mstype.float32, mstype.float16, mstype.float64)
Validator.check_tensor_type_same({'logits': logits_dtype}, valid_types, self.name)
valid_dtypes = (mstype.float32, mstype.float16, mstype.float64)
Validator.check_tensor_dtype_valid('logits', logits_dtype, valid_dtypes, self.name)
num_samples_v = num_samples['value']
seed_v = seed['value']
Validator.check_value_type('num_samples', num_samples_v, (int,), self.name)
@ -460,7 +460,7 @@ class Multinomial(PrimitiveWithInfer):
input_shape = inputs["shape"]
if len(input_shape) != 1 and len(input_shape) != 2:
raise ValueError("input dim must be 1 or 2")
Validator.check_tensor_type_same({'inputs': inputs['dtype']}, [mstype.float32], self.name)
Validator.check_tensor_dtype_valid('inputs', inputs['dtype'], [mstype.float32], self.name)
num_samples_value = num_samples["value"]
if num_samples_value is None:
raise ValueError(f"For {self.name}, shape nust be const")

View File

@ -588,8 +588,8 @@ def _quant_export(network, *inputs, file_format, **kwargs):
if quant_mode not in quant_mode_formats:
raise KeyError(f'Quant_mode input is wrong, Please choose the right mode of the quant_mode.')
mean = Validator.check_type("mean", mean, (int, float))
std_dev = Validator.check_type("std_dev", std_dev, (int, float))
mean = Validator.check_value_type("mean", mean, (int, float))
std_dev = Validator.check_value_type("std_dev", std_dev, (int, float))
if context.get_context('device_target') not in supported_device:
raise KeyError("Unsupported {} device target.".format(context.get_context('device_target')))

View File

@ -117,7 +117,7 @@ class MySparseGatherV2(PrimitiveWithInfer):
def __infer__(self, params, indices, axis):
validator.check_subclass("params", params['dtype'], mstype.tensor, self.name)
validator.check_tensor_type_same({"indices": indices['dtype']}, mstype.int_type, self.name)
validator.check_tensor_dtype_valid("indices", indices['dtype'], mstype.int_type, self.name)
validator.check_subclass("axis", axis['dtype'], mstype.int_, self.name)
axis_v = axis['value']
params_shp = params['shape']