From 8bb93411f3845ceac91171a3c511f42bf88ca245 Mon Sep 17 00:00:00 2001 From: fary86 Date: Thu, 16 Apr 2020 11:41:51 +0800 Subject: [PATCH] Add prim name to error message for _grad_ops.py --- mindspore/_checkparam.py | 2 +- mindspore/ops/_utils/utils.py | 20 +- mindspore/ops/operations/_grad_ops.py | 271 ++++++++++++-------------- mindspore/ops/operations/array_ops.py | 2 +- 4 files changed, 132 insertions(+), 163 deletions(-) diff --git a/mindspore/_checkparam.py b/mindspore/_checkparam.py index 3543f58cf5a..bf24b7e522a 100644 --- a/mindspore/_checkparam.py +++ b/mindspore/_checkparam.py @@ -206,8 +206,8 @@ class Validator: def _check_tensor_type(arg): arg_key, arg_val = arg elem_type = arg_val - type_names = [] if not elem_type in valid_values: + type_names = [] for t in valid_values: type_names.append(str(t)) types_info = '[' + ", ".join(type_names) + ']' diff --git a/mindspore/ops/_utils/utils.py b/mindspore/ops/_utils/utils.py index fbd81c4f0d9..90496afc9bd 100644 --- a/mindspore/ops/_utils/utils.py +++ b/mindspore/ops/_utils/utils.py @@ -15,7 +15,7 @@ """utils for operator""" -from ..._checkparam import ParamValidator as validator +from ..._checkparam import Validator as validator from ..._checkparam import Rel from ...common import dtype as mstype @@ -62,25 +62,25 @@ def _get_broadcast_shape(x_shape, y_shape, prim_name): return broadcast_shape -def _get_concat_offset(x_shp, x_type, axis): +def _get_concat_offset(x_shp, x_type, axis, prim_name): """for concat and concatoffset check args and compute offset""" - validator.check_type("shape", x_shp, [tuple]) - validator.check_integer("len of input_x shape", len(x_shp), 0, Rel.GT) - validator.check_subclass("shape0", x_type[0], mstype.tensor) - validator.check_integer("len of input_x0 shape", len(x_shp[0]), 0, Rel.GT) + validator.check_value_type("shape", x_shp, [tuple], prim_name) + validator.check_integer("input_x rank", len(x_shp), 0, Rel.GT, prim_name) + validator.check_subclass("shape0", x_type[0], mstype.tensor, prim_name) + validator.check_integer("len of x_shp[0]", len(x_shp[0]), 0, Rel.GT, prim_name) rank_base = len(x_shp[0]) - validator.check_int_range('axis', axis, -rank_base - 1, rank_base, Rel.INC_BOTH) + validator.check_int_range('axis', axis, -rank_base - 1, rank_base, Rel.INC_BOTH, prim_name) if axis < 0: axis = axis + rank_base all_shp = x_shp[0][axis] offset = [0,] for i in range(1, len(x_shp)): v = x_shp[i] - validator.check('len of x_shp[%d]' % i, len(v), 'len of base', len(x_shp[0])) - validator.check('x_type[%d]' % i, x_type[i], 'base', x_type[0]) + validator.check('len of x_shp[%d]' % i, len(v), 'len of x_shp[0]', len(x_shp[0]), Rel.EQ, prim_name) + validator.check('x_type[%d]' % i, x_type[i], 'x_type[0]', x_type[0], Rel.EQ, prim_name) for j in range(rank_base): if j != axis and v[j] != x_shp[0][j]: - raise ValueError("Concat evaluator element %d shape in input can not concat with first element" % i) + raise ValueError(f"For \'{prim_name}\' element {i} shape in input can not concat with first element") offset.append(all_shp) all_shp += v[axis] return offset, all_shp, axis diff --git a/mindspore/ops/operations/_grad_ops.py b/mindspore/ops/operations/_grad_ops.py index c29832dcb76..e130dcc3825 100644 --- a/mindspore/ops/operations/_grad_ops.py +++ b/mindspore/ops/operations/_grad_ops.py @@ -18,8 +18,7 @@ from ..._c_expression import signature_rw as sig_rw from ..._c_expression import signature_kind as sig_kind from ..primitive import Primitive, PrimitiveWithInfer, prim_attr_register -from ..._checkparam import ParamValidator as validator -from ..._checkparam import Rel, check_int_positive, check_bool +from ..._checkparam import Validator as validator, Rel from .._utils import _get_concat_offset from ...common import dtype as mstype @@ -51,12 +50,12 @@ class ACosGrad(PrimitiveWithInfer): """init ACosGrad""" def infer_shape(self, x, dout): - validator.check_param_equal("x", x, "dout", dout) + validator.check("x shape", x, "dout shape", dout, Rel.EQ, self.name) return x def infer_dtype(self, x, dout): args = {"x": x, "dout": dout} - validator.check_type_same(args, mstype.number_type) + validator.check_tensor_type_same(args, mstype.number_type, self.name) return x @@ -65,8 +64,8 @@ class BatchNormGrad(PrimitiveWithInfer): @prim_attr_register def __init__(self, is_training=False, epsilon=1e-5): - self.is_training = validator.check_type('is_training', is_training, (bool,)) - self.epsilon = validator.check_number_range('epsilon', epsilon, 0, 1, Rel.INC_RIGHT) + self.is_training = validator.check_value_type('is_training', is_training, (bool,), self.name) + self.epsilon = validator.check_number_range('epsilon', epsilon, 0, 1, Rel.INC_RIGHT, self.name) self.add_prim_attr('data_format', "NCHW") def infer_shape(self, y_backprop_shape, x_shape, scale_shape, reserve_1_shape, reserve_2_shape, reserve_3_shape): @@ -93,19 +92,19 @@ class BinaryCrossEntropyGrad(PrimitiveWithInfer): """Computes gradients for `BinaryCrossEntropy` operation.""" @prim_attr_register def __init__(self, reduction='mean'): - self.reduction = validator.check_string('reduction', reduction, ['none', 'mean', 'sum']) + self.reduction = validator.check_string('reduction', reduction, ['none', 'mean', 'sum'], self.name) def infer_shape(self, x_shape, y_shape, doutput_shape, weight_shape): - validator.check_param_equal('x_shape', x_shape, 'y_shape', y_shape) + validator.check('x_shape', x_shape, 'y_shape', y_shape, Rel.EQ, self.name) if weight_shape: - validator.check_param_equal('y_shape', y_shape, 'weight_shape', weight_shape) + validator.check('y_shape', y_shape, 'weight_shape', weight_shape, Rel.EQ, self.name) return x_shape def infer_dtype(self, x_type, y_type, doutput_type, weight_type): args = {'x_type': x_type, 'y_type': y_type, 'doutput_type': doutput_type} - validator.check_type_same(args, (mstype.float16, mstype.float32)) + validator.check_tensor_type_same(args, (mstype.float16, mstype.float32), self.name) if weight_type: - validator.check_two_types_same('x_type', x_type, 'weight_type', weight_type) + validator.check('x_type', x_type, 'weight_type', weight_type, Rel.EQ, TypeError) return x_type @@ -120,7 +119,7 @@ class ConcatOffset(PrimitiveWithInfer): axis = self.axis x_shp = input_x['shape'] x_type = input_x['dtype'] - offset, _, axis = _get_concat_offset(x_shp, x_type, axis) + offset, _, axis = _get_concat_offset(x_shp, x_type, axis, self.name) self.add_prim_attr('T', x_type[0].element_type()) offset_values = [] for i in range(len(x_shp)): @@ -184,11 +183,11 @@ class Conv2DBackpropFilter(PrimitiveWithInfer): def __infer__(self, doutput, x, w_size): w_size_v = w_size['value'] - validator.check_type('w_size', w_size_v, [tuple]) + validator.check_value_type('w_size', w_size_v, [tuple], self.name) for i, dim_len in enumerate(w_size_v): - validator.check_type("w_size[%d]" % i, dim_len, [int]) - validator.check_typename('x_dtype', x['dtype'], [mstype.int8, mstype.int32, mstype.float16, mstype.float32]) - validator.check_two_types_same('doutput_dtype', doutput['dtype'], 'x_dtype', x['dtype']) + validator.check_value_type("w_size[%d]" % i, dim_len, [int], self.name) + args = {"x": x['dtype'], "doutput": doutput['dtype']} + validator.check_tensor_type_same(args, [mstype.int8, mstype.int32, mstype.float16, mstype.float32], self.name) out = { 'value': None, 'shape': w_size_v, @@ -250,8 +249,8 @@ class DepthwiseConv2dNativeBackpropFilter(PrimitiveWithInfer): def __infer__(self, x, w_size, dout): w_size_v = w_size['value'] - args = {'x_dtype': x['dtype'], 'dout_type': dout['dtype']} - validator.check_type_same(args, mstype.number_type) + args = {'x': x['dtype'], 'dout': dout['dtype']} + validator.check_tensor_type_same(args, mstype.number_type, self.name) out = { 'value': None, 'shape': w_size_v, @@ -310,8 +309,8 @@ class DepthwiseConv2dNativeBackpropInput(PrimitiveWithInfer): raise NotImplementedError def __infer__(self, x_size, w, dout): - args = {'w_dtype': w['dtype'], 'dout_type': dout['dtype']} - validator.check_type_same(args, mstype.number_type) + args = {'w': w['dtype'], 'dout': dout['dtype']} + validator.check_tensor_type_same(args, mstype.number_type, self.name) x_size_v = x_size['value'] out = { 'value': None, @@ -360,9 +359,9 @@ class GeluGrad(PrimitiveWithInfer): return x_shape def infer_dtype(self, y_backprop_dtype, x_dtype, y_dtype): - validator.check_typename("y_backprop_dtype", y_backprop_dtype, (mstype.float16, mstype.float32)) - validator.check_typename("x_dtype", x_dtype, (mstype.float16, mstype.float32)) - validator.check_typename("y_dtype", y_dtype, (mstype.float16, mstype.float32)) + validator.check_tensor_type_same({"y_backprop": y_backprop_dtype}, (mstype.float16, mstype.float32), self.name) + validator.check_tensor_type_same({"x": x_dtype}, (mstype.float16, mstype.float32), self.name) + validator.check_tensor_type_same({"y": y_dtype}, (mstype.float16, mstype.float32), self.name) return x_dtype @@ -373,56 +372,36 @@ class _PoolGrad(PrimitiveWithInfer): def __init__(self, ksize, strides, padding="VALID"): self.init_prim_io_names(inputs=['x_origin', 'out_origin', 'grad'], outputs=['output']) - validator.check_type('ksize', ksize, [int, tuple]) - validator.check_type('strides', strides, [int, tuple]) - self.padding = validator.check_string('padding', padding.upper(), ['VALID', 'SAME']) + validator.check_value_type('ksize', ksize, [int, tuple], self.name) + validator.check_value_type('strides', strides, [int, tuple], self.name) + self.padding = validator.check_string('padding', padding.upper(), ['VALID', 'SAME'], self.name) self.add_prim_attr("padding", self.padding) self.is_maxpoolgradwithargmax = (self.name == "MaxPoolGradWithArgmax") if not self.is_maxpoolgradwithargmax: self.add_prim_attr('data_format', "NCHW") - if isinstance(ksize, int): - validator.check_integer("ksize", ksize, 1, Rel.GE) - if self.is_maxpoolgradwithargmax: - self.ksize = (1, ksize, ksize, 1) + def _grad_check_int_or_tuple(arg_name, arg_val, is_argmax): + validator.check_value_type(arg_name, arg_val, (int, tuple), self.name) + error_msg = ValueError(f"For '{self.name}' the '{arg_name}' should be an positive int number " + f"or a tuple of two or four positive int numbers, but got {arg_val}") + if isinstance(arg_val, int): + ret = (1, arg_val, arg_val, 1) if is_argmax else (1, 1, arg_val, arg_val) + elif len(arg_val) == 2: + ret = (1, arg_val[0], arg_val[1], 1) if is_argmax else (1, 1, arg_val[0], arg_val[1]) + elif len(arg_val) == 4: + ret = arg_val else: - self.ksize = (1, 1, ksize, ksize) - else: - ksize_error = ValueError(f"The 'ksize' passed to operator {self.name} should be an positive int number" - f"or a tuple of two or four positive int numbers, but got {ksize}") - if len(ksize) != 2 and len(ksize) != 4: - raise ksize_error - for ksize_val in ksize: - if not isinstance(ksize_val, int) or (ksize_val <= 0): - raise ksize_error - if len(ksize) == 2 and self.is_maxpoolgradwithargmax: - self.ksize = (1, ksize[0], ksize[1], 1) - elif len(ksize) == 2 and not self.is_maxpoolgradwithargmax: - self.ksize = (1, 1, ksize[0], ksize[1]) - else: - self.ksize = ksize + raise error_msg + # whether all elements of tuple are positive integers + for item in ret: + if not isinstance(item, int) or item <= 0: + raise error_msg + return ret + + self.ksize = _grad_check_int_or_tuple("ksize", ksize, self.is_maxpoolgradwithargmax) self.add_prim_attr("ksize", self.ksize) - if isinstance(strides, int): - validator.check_integer("strides", strides, 1, Rel.GE) - if self.is_maxpoolgradwithargmax: - self.strides = (1, strides, strides, 1) - else: - self.strides = (1, 1, strides, strides) - else: - strides_error = ValueError(f"The 'strides' passed to operator {self.name} should be an positive int number" - f"or a tuple of two or four positive int numbers, but got {strides}") - if len(strides) != 2 and len(strides) != 4: - raise strides_error - for strides_val in strides: - if not isinstance(strides_val, int) or (strides_val <= 0): - raise strides_error - if len(strides) == 2 and self.is_maxpoolgradwithargmax: - self.strides = (1, strides[0], strides[1], 1) - elif len(strides) == 2 and not self.is_maxpoolgradwithargmax: - self.strides = (1, 1, strides[0], strides[1]) - else: - self.strides = strides + self.strides = _grad_check_int_or_tuple("strides", strides, self.is_maxpoolgradwithargmax) self.add_prim_attr("strides", self.strides) @@ -529,17 +508,17 @@ class L2NormalizeGrad(PrimitiveWithInfer): @prim_attr_register def __init__(self, axis=0, epsilon=1e-4): - validator.check_type('axis', axis, [int]) - validator.check_type('epsilon', epsilon, [int, float]) + validator.check_value_type('axis', axis, [int], self.name) + validator.check_value_type('epsilon', epsilon, [int, float], self.name) def infer_shape(self, input_x, out, dout): - validator.check_param_equal('input_x', input_x, 'out', out) - validator.check_param_equal('input_x', input_x, 'dout', dout) + validator.check('input_x shape', input_x, 'out shape', out, Rel.EQ, self.name) + validator.check('input_x shape', input_x, 'dout shape', dout, Rel.EQ, self.name) return input_x def infer_dtype(self, input_x, out, dout): args = {'input_x': input_x, 'out': out, 'dout': dout} - validator.check_type_same(args, mstype.number_type) + validator.check_tensor_type_same(args, mstype.number_type, self.name) return input_x @@ -560,8 +539,8 @@ class LayerNormGrad(Primitive): @prim_attr_register def __init__(self, begin_norm_axis=1, begin_params_axis=1): """init""" - self.begin_norm_axis = validator.check_type('begin_norm_axis', begin_norm_axis, [int]) - self.begin_params_axis = validator.check_type('begin_params_axis', begin_params_axis, [int]) + self.begin_norm_axis = validator.check_value_type('begin_norm_axis', begin_norm_axis, [int], self.name) + self.begin_params_axis = validator.check_value_type('begin_params_axis', begin_params_axis, [int], self.name) def __call__(self, x, dy, variance, mean, gamma): raise NotImplementedError @@ -573,15 +552,15 @@ class LogSoftmaxGrad(PrimitiveWithInfer): @prim_attr_register def __init__(self, axis=-1): """init LogSoftmaxGrad""" - validator.check_type("axis", axis, [int]) + validator.check_value_type("axis", axis, [int], self.name) def infer_shape(self, dout, logits): rank = len(logits) - validator.check_int_range('axis', self.axis, -rank - 1, rank, Rel.INC_BOTH) + validator.check_int_range('axis', self.axis, -rank - 1, rank, Rel.INC_BOTH, self.name) return logits def infer_dtype(self, dout, logits): - validator.check_subclass("logits", logits, mstype.tensor) + validator.check_subclass("logits", logits, mstype.tensor, self.name) return logits @@ -590,13 +569,13 @@ class LSTMGradData(PrimitiveWithInfer): @prim_attr_register def __init__(self, input_size, hidden_size, num_layers, has_bias, bidirectional, dropout): - self.input_size = check_int_positive(input_size) - self.hidden_size = check_int_positive(hidden_size) - self.num_layers = check_int_positive(num_layers) - self.has_bias = check_bool(has_bias) - self.bidirectional = check_bool(bidirectional) - self.dropout = validator.check_type("dropout", dropout, [float]) - self.dropout = validator.check_number_range('dropout', dropout, 0, 1, Rel.INC_BOTH) + self.input_size = validator.check_integer('input_size', input_size, 0, Rel.GT, self.name) + self.hidden_size = validator.check_integer('hidden_size', hidden_size, 0, Rel.GT, self.name) + self.num_layers = validator.check_integer('num_layers', num_layers, 0, Rel.GT, self.name) + self.has_bias = validator.check_value_type('has_bias', has_bias, (bool,), self.name) + self.bidirectional = validator.check_value_type('bidirectional', bidirectional, (bool,), self.name) + self.dropout = validator.check_value_type("dropout", dropout, [float], self.name) + self.dropout = validator.check_number_range('dropout', dropout, 0, 1, Rel.INC_BOTH, self.name) if bidirectional: self.num_directions = 2 @@ -606,19 +585,19 @@ class LSTMGradData(PrimitiveWithInfer): def infer_shape(self, y_shape, dy_shape, dhy_shape, dcy_shape, w_shape, hx_shape, cx_shape, reserve_shape, state_shape): # dhy and dcy should be same shape - validator.check_integer("h_shape", len(dhy_shape), 3, Rel.EQ) - validator.check_integer("h_shape", len(dhy_shape), len(dcy_shape), Rel.EQ) - validator.check_integer("h_shape[0]", dhy_shape[0], dcy_shape[0], Rel.EQ) - validator.check_integer("h_shape[1]", dhy_shape[1], dcy_shape[1], Rel.EQ) - validator.check_integer("h_shape[2]", dhy_shape[2], dcy_shape[2], Rel.EQ) + validator.check_integer("h_shape", len(dhy_shape), 3, Rel.EQ, self.name) + validator.check_integer("h_shape", len(dhy_shape), len(dcy_shape), Rel.EQ, self.name) + validator.check_integer("h_shape[0]", dhy_shape[0], dcy_shape[0], Rel.EQ, self.name) + validator.check_integer("h_shape[1]", dhy_shape[1], dcy_shape[1], Rel.EQ, self.name) + validator.check_integer("h_shape[2]", dhy_shape[2], dcy_shape[2], Rel.EQ, self.name) - validator.check_integer("h_shape[0]", dhy_shape[0], self.num_layers * self.num_directions, Rel.EQ) - validator.check_integer("h_shape[2]", dhy_shape[2], self.hidden_size, Rel.EQ) + validator.check_integer("h_shape[0]", dhy_shape[0], self.num_layers * self.num_directions, Rel.EQ, self.name) + validator.check_integer("h_shape[2]", dhy_shape[2], self.hidden_size, Rel.EQ, self.name) # dy: (seq_len, batch_size, hidden_size * num_directions) - validator.check_integer("dy_shape", len(dy_shape), 3, Rel.EQ) - validator.check_integer("dy[1]", dy_shape[1], dhy_shape[1], Rel.EQ) - validator.check_integer("dy[2]", dy_shape[2], self.hidden_size * self.num_directions, Rel.EQ) + validator.check_integer("dy_shape", len(dy_shape), 3, Rel.EQ, self.name) + validator.check_integer("dy[1]", dy_shape[1], dhy_shape[1], Rel.EQ, self.name) + validator.check_integer("dy[2]", dy_shape[2], self.hidden_size * self.num_directions, Rel.EQ, self.name) # (seq_len, batch_size, input_size) dx_shape = (y_shape[0], y_shape[1], self.input_size) @@ -629,11 +608,8 @@ class LSTMGradData(PrimitiveWithInfer): def infer_dtype(self, y_dtype, dy_dtype, dhy_dtype, dcy_dtype, w_dtype, hx_dtype, cx_dtype, reserve_dtype, state_dtype): - validator.check_typename("dy_dtype", dy_dtype, (mstype.float32, mstype.float16)) - validator.check_typename("dhy_dtype", dhy_dtype, (mstype.float32, mstype.float16)) - validator.check_typename("dcy_dtype", dcy_dtype, (mstype.float32, mstype.float16)) - validator.check_typename("datatype", dy_dtype, (dhy_dtype.element_type(),)) - validator.check_typename("datatype", dy_dtype, (dcy_dtype.element_type(),)) + args = {"dy": dy_dtype, "dhy": dhy_dtype, "dcy": dcy_dtype} + validator.check_tensor_type_same(args, (mstype.float32, mstype.float16), self.name) return (dy_dtype, dy_dtype, dy_dtype) @@ -642,13 +618,13 @@ class LSTMGradWeight(PrimitiveWithInfer): @prim_attr_register def __init__(self, input_size, hidden_size, num_layers, has_bias, bidirectional, dropout): - self.input_size = check_int_positive(input_size) - self.hidden_size = check_int_positive(hidden_size) - self.num_layers = check_int_positive(num_layers) - self.has_bias = check_bool(has_bias) - self.bidirectional = check_bool(bidirectional) - self.dropout = validator.check_type("dropout", dropout, [float]) - self.dropout = validator.check_number_range('dropout', dropout, 0, 1, Rel.INC_BOTH) + self.input_size = validator.check_integer('input_size', input_size, 0, Rel.GT, self.name) + self.hidden_size = validator.check_integer('hidden_size', hidden_size, 0, Rel.GT, self.name) + self.num_layers = validator.check_integer('num_layers', num_layers, 0, Rel.GT, self.name) + self.has_bias = validator.check_value_type('has_bias', has_bias, (bool,), self.name) + self.bidirectional = validator.check_value_type('bidirectional', bidirectional, (bool,), self.name) + self.dropout = validator.check_value_type("dropout", dropout, [float], self.name) + self.dropout = validator.check_number_range('dropout', dropout, 0, 1, Rel.INC_BOTH, self.name) if bidirectional: self.num_directions = 2 @@ -693,9 +669,10 @@ class PReLUGrad(PrimitiveWithInfer): return y_backprop_shape, w_shape def infer_dtype(self, y_backprop_dtype, A_dtype, w_dtype): - validator.check_typename("y_backprop_dtype", y_backprop_dtype, (mstype.float16, mstype.float32)) - validator.check_typename("A_dtype", A_dtype, (mstype.float16, mstype.float32)) - validator.check_typename("w_dtype", w_dtype, (mstype.float16, mstype.float32)) + valid_types = (mstype.float16, mstype.float32) + validator.check_tensor_type_same({"y_backprop": y_backprop_dtype}, valid_types, self.name) + validator.check_tensor_type_same({"A_dtype": A_dtype}, valid_types, self.name) + validator.check_tensor_type_same({"w_dtype": w_dtype}, valid_types, self.name) return y_backprop_dtype, w_dtype @@ -725,8 +702,8 @@ class ReLU6Grad(PrimitiveWithInfer): return x_shape def infer_dtype(self, y_grad_dtype, x_dtype): - validator.check_typename("y_grad_dtype", y_grad_dtype, (mstype.float16, mstype.float32)) - validator.check_typename("x_dtype", x_dtype, (mstype.float16, mstype.float32)) + validator.check_tensor_type_same({"y_grad": y_grad_dtype}, (mstype.float16, mstype.float32), self.name) + validator.check_tensor_type_same({"x": x_dtype}, (mstype.float16, mstype.float32), self.name) return x_dtype @@ -744,10 +721,8 @@ class ReluGradV2(PrimitiveWithInfer): return gradients_shape def infer_dtype(self, gradients_dtype, mask_dtype): - args_type = {'gradients': gradients_dtype, 'mask': mask_dtype} - validator.check_args_tensor(args_type) - validator.check_typename("gradients_dtype", gradients_dtype, mstype.number_type) - validator.check_typename("mask_dtype", mask_dtype, (mstype.uint8,)) + validator.check_tensor_type_same({'gradients': gradients_dtype}, mstype.number_type, self.name) + validator.check_tensor_type_same({'mask': mask_dtype}, (mstype.uint8,), self.name) return gradients_dtype @@ -762,10 +737,8 @@ class EluGrad(PrimitiveWithInfer): return x_shape def infer_dtype(self, y_grad_dtype, x_dtype): - args_type = {'y_grad': y_grad_dtype, 'x': x_dtype} - validator.check_args_tensor(args_type) - args_dtype = {'y_grad_dtype': y_grad_dtype, 'x_dtype': x_dtype} - validator.check_type_same(args_dtype, mstype.float_type) + args = {'y_grad': y_grad_dtype, 'x': x_dtype} + validator.check_tensor_type_same(args, mstype.float_type, self.name) return x_dtype @@ -821,11 +794,11 @@ class ROIAlignGrad(PrimitiveWithInfer): @prim_attr_register def __init__(self, xdiff_shape, pooled_height, pooled_width, spatial_scale, sample_num=2): """init ROIAlignGrad""" - validator.check_type("pooled_height", pooled_height, [int]) - validator.check_type("pooled_width", pooled_width, [int]) - validator.check_type("spatial_scale", spatial_scale, [float]) - validator.check_type("sample_num", sample_num, [int]) - validator.check_type("xdiff_shape", xdiff_shape, [tuple]) + validator.check_value_type("pooled_height", pooled_height, [int], self.name) + validator.check_value_type("pooled_width", pooled_width, [int], self.name) + validator.check_value_type("spatial_scale", spatial_scale, [float], self.name) + validator.check_value_type("sample_num", sample_num, [int], self.name) + validator.check_value_type("xdiff_shape", xdiff_shape, [tuple], self.name) self.xdiff_shape = xdiff_shape self.pooled_height = pooled_height self.pooled_width = pooled_width @@ -850,10 +823,8 @@ class SigmoidGrad(PrimitiveWithInfer): return out def infer_dtype(self, out, dout): - validator.check_typename("dout dtype", dout, (mstype.float16, mstype.float32)) - validator.check_typename("out dtype", out, (mstype.float16, mstype.float32)) - args = {"out type": out, "dout type": dout} - validator.check_type_same(args, mstype.number_type) + args = {'out': out, 'dout': dout} + validator.check_tensor_type_same(args, mstype.number_type, self.name) return out @@ -868,8 +839,8 @@ class HSigmoidGrad(PrimitiveWithInfer): return x_shape def infer_dtype(self, y_grad_dtype, x_dtype): - validator.check_typename("y_grad dtype", y_grad_dtype, (mstype.float16, mstype.float32)) - validator.check_typename("x dtype", x_dtype, (mstype.float16, mstype.float32)) + validator.check_tensor_type_same({"y_grad": y_grad_dtype}, (mstype.float16, mstype.float32), self.name) + validator.check_tensor_type_same({"x": x_dtype}, (mstype.float16, mstype.float32), self.name) return x_dtype @@ -884,8 +855,8 @@ class HSwishGrad(PrimitiveWithInfer): return x_shape def infer_dtype(self, y_grad_dtype, x_dtype): - validator.check_typename("y_grad dtype", y_grad_dtype, (mstype.float16, mstype.float32)) - validator.check_typename("x_ dtype", x_dtype, (mstype.float16, mstype.float32)) + validator.check_tensor_type_same({"y_grad": y_grad_dtype}, (mstype.float16, mstype.float32), self.name) + validator.check_tensor_type_same({"x": x_dtype}, (mstype.float16, mstype.float32), self.name) return x_dtype @@ -898,13 +869,13 @@ class SigmoidCrossEntropyWithLogitsGrad(PrimitiveWithInfer): self.init_prim_io_names(inputs=['x', 'y', 'dout'], outputs=['x_grad']) def infer_shape(self, x_shape, y_shape, dout_shape): - validator.check_param_equal("x_shape", x_shape, "y_shape", y_shape) - validator.check_param_equal("x_shape", x_shape, "dout_shape", dout_shape) + validator.check("x_shape", x_shape, "y_shape", y_shape, Rel.EQ, self.name) + validator.check("x_shape", x_shape, "dout_shape", dout_shape, Rel.EQ, self.name) return x_shape def infer_dtype(self, x_dtype, y_dtype, dout_dtype): args = {"x_dtype": x_dtype, "y_dtype": y_dtype, 'dout_dtype': dout_dtype} - validator.check_type_same(args, mstype.number_type) + validator.check_tensor_type_same(args, mstype.number_type, self.name) return dout_dtype @@ -920,8 +891,8 @@ class SliceGrad(PrimitiveWithInfer): dy_shape, x_shape, size_value = dy['shape'], x['shape'], size['value'] dy_shape_len = len(dy_shape) for i in range(dy_shape_len): - validator.check(f'dy_shape[{i}]', dy_shape[i], f'x_shape[{i}]', x_shape[i], Rel.LE) - validator.check(f'dy_shape[{i}]', dy_shape[i], f'size_shape[{i}]', size_value[i], Rel.EQ) + validator.check(f'dy_shape[{i}]', dy_shape[i], f'x_shape[{i}]', x_shape[i], Rel.LE, self.name) + validator.check(f'dy_shape[{i}]', dy_shape[i], f'size_shape[{i}]', size_value[i], Rel.EQ, self.name) return {'shape': x_shape, 'dtype': x['dtype'], 'value': None} @@ -935,13 +906,13 @@ class SmoothL1LossGrad(PrimitiveWithInfer): pass def infer_shape(self, prediction, target, dloss): - validator.check_param_equal('prediction', prediction, 'target', target) - validator.check_param_equal('prediction', prediction, 'dloss', dloss) + validator.check('prediction shape', prediction, 'target shape', target, Rel.EQ, self.name) + validator.check('prediction shape', prediction, 'dloss shape', dloss, Rel.EQ, self.name) return prediction def infer_dtype(self, prediction, target, dloss): args = {"prediction": prediction, "target": target, 'dloss': dloss} - validator.check_type_same(args, mstype.number_type) + validator.check_tensor_type_same(args, mstype.number_type, self.name) return dloss @@ -968,11 +939,11 @@ class StridedSliceGrad(PrimitiveWithInfer): new_axis_mask=0, shrink_axis_mask=0): """init StrideSliceGrad""" - validator.check_type('begin_mask', begin_mask, [int]) - validator.check_type('end_mask', end_mask, [int]) - validator.check_type('ellipsis_mask', ellipsis_mask, [int]) - validator.check_type('new_axis_mask', new_axis_mask, [int]) - validator.check_type('shrink_axis_mask', shrink_axis_mask, [int]) + validator.check_value_type('begin_mask', begin_mask, [int], self.name) + validator.check_value_type('end_mask', end_mask, [int], self.name) + validator.check_value_type('ellipsis_mask', ellipsis_mask, [int], self.name) + validator.check_value_type('new_axis_mask', new_axis_mask, [int], self.name) + validator.check_value_type('shrink_axis_mask', shrink_axis_mask, [int], self.name) self.init_prim_io_names(inputs=['dy', 'shapex', 'begin', 'end', 'strides'], outputs=['output']) def __infer__(self, dy, shapex, begin, end, strides): @@ -992,10 +963,8 @@ class TanhGrad(PrimitiveWithInfer): return out def infer_dtype(self, out, dout): - validator.check_subclass("out", out, mstype.tensor) - validator.check_subclass("dout", dout, mstype.tensor) - args = {"out type": out, "dout type": dout} - validator.check_type_same(args, mstype.number_type) + args = {"out": out, "dout": dout} + validator.check_tensor_type_same(args, mstype.number_type, self.name) return out @@ -1005,13 +974,13 @@ class MirrorPadGrad(PrimitiveWithInfer): @prim_attr_register def __init__(self, mode="REFLECT"): """init MirrorPad""" - validator.check_string('mode', mode, ['REFLECT', 'SYMMETRIC']) + validator.check_string('mode', mode, ['REFLECT', 'SYMMETRIC'], self.name) self.mode = mode def __infer__(self, dout, paddings, x): - validator.check_subclass("dout", dout['dtype'], mstype.tensor) - validator.check_subclass("paddings", paddings['dtype'], mstype.tensor) - validator.check_subclass("input_x", x['dtype'], mstype.tensor) + validator.check_subclass("dout", dout['dtype'], mstype.tensor, self.name) + validator.check_subclass("paddings", paddings['dtype'], mstype.tensor, self.name) + validator.check_subclass("input_x", x['dtype'], mstype.tensor, self.name) return {'shape': x['shape'], 'dtype': dout['dtype'], 'value': None} diff --git a/mindspore/ops/operations/array_ops.py b/mindspore/ops/operations/array_ops.py index bb6755bef3a..43398a5f296 100644 --- a/mindspore/ops/operations/array_ops.py +++ b/mindspore/ops/operations/array_ops.py @@ -1316,7 +1316,7 @@ class Concat(PrimitiveWithInfer): axis = self.axis x_shp = input_x['shape'] x_type = input_x['dtype'] - _, all_shp, _ = _get_concat_offset(x_shp, x_type, axis) + _, all_shp, _ = _get_concat_offset(x_shp, x_type, axis, self.name) self.add_prim_attr('T', x_type[0].element_type()) self.add_prim_attr('inputNums', len(x_shp)) ret_shp = x_shp[0].copy()