forked from mindspore-Ecosystem/mindspore
Add prim name to error message for _grad_ops.py
This commit is contained in:
parent
c803569648
commit
8bb93411f3
|
@ -206,8 +206,8 @@ class Validator:
|
|||
def _check_tensor_type(arg):
|
||||
arg_key, arg_val = arg
|
||||
elem_type = arg_val
|
||||
type_names = []
|
||||
if not elem_type in valid_values:
|
||||
type_names = []
|
||||
for t in valid_values:
|
||||
type_names.append(str(t))
|
||||
types_info = '[' + ", ".join(type_names) + ']'
|
||||
|
|
|
@ -15,7 +15,7 @@
|
|||
|
||||
"""utils for operator"""
|
||||
|
||||
from ..._checkparam import ParamValidator as validator
|
||||
from ..._checkparam import Validator as validator
|
||||
from ..._checkparam import Rel
|
||||
from ...common import dtype as mstype
|
||||
|
||||
|
@ -62,25 +62,25 @@ def _get_broadcast_shape(x_shape, y_shape, prim_name):
|
|||
return broadcast_shape
|
||||
|
||||
|
||||
def _get_concat_offset(x_shp, x_type, axis):
|
||||
def _get_concat_offset(x_shp, x_type, axis, prim_name):
|
||||
"""for concat and concatoffset check args and compute offset"""
|
||||
validator.check_type("shape", x_shp, [tuple])
|
||||
validator.check_integer("len of input_x shape", len(x_shp), 0, Rel.GT)
|
||||
validator.check_subclass("shape0", x_type[0], mstype.tensor)
|
||||
validator.check_integer("len of input_x0 shape", len(x_shp[0]), 0, Rel.GT)
|
||||
validator.check_value_type("shape", x_shp, [tuple], prim_name)
|
||||
validator.check_integer("input_x rank", len(x_shp), 0, Rel.GT, prim_name)
|
||||
validator.check_subclass("shape0", x_type[0], mstype.tensor, prim_name)
|
||||
validator.check_integer("len of x_shp[0]", len(x_shp[0]), 0, Rel.GT, prim_name)
|
||||
rank_base = len(x_shp[0])
|
||||
validator.check_int_range('axis', axis, -rank_base - 1, rank_base, Rel.INC_BOTH)
|
||||
validator.check_int_range('axis', axis, -rank_base - 1, rank_base, Rel.INC_BOTH, prim_name)
|
||||
if axis < 0:
|
||||
axis = axis + rank_base
|
||||
all_shp = x_shp[0][axis]
|
||||
offset = [0,]
|
||||
for i in range(1, len(x_shp)):
|
||||
v = x_shp[i]
|
||||
validator.check('len of x_shp[%d]' % i, len(v), 'len of base', len(x_shp[0]))
|
||||
validator.check('x_type[%d]' % i, x_type[i], 'base', x_type[0])
|
||||
validator.check('len of x_shp[%d]' % i, len(v), 'len of x_shp[0]', len(x_shp[0]), Rel.EQ, prim_name)
|
||||
validator.check('x_type[%d]' % i, x_type[i], 'x_type[0]', x_type[0], Rel.EQ, prim_name)
|
||||
for j in range(rank_base):
|
||||
if j != axis and v[j] != x_shp[0][j]:
|
||||
raise ValueError("Concat evaluator element %d shape in input can not concat with first element" % i)
|
||||
raise ValueError(f"For \'{prim_name}\' element {i} shape in input can not concat with first element")
|
||||
offset.append(all_shp)
|
||||
all_shp += v[axis]
|
||||
return offset, all_shp, axis
|
||||
|
|
|
@ -18,8 +18,7 @@
|
|||
from ..._c_expression import signature_rw as sig_rw
|
||||
from ..._c_expression import signature_kind as sig_kind
|
||||
from ..primitive import Primitive, PrimitiveWithInfer, prim_attr_register
|
||||
from ..._checkparam import ParamValidator as validator
|
||||
from ..._checkparam import Rel, check_int_positive, check_bool
|
||||
from ..._checkparam import Validator as validator, Rel
|
||||
from .._utils import _get_concat_offset
|
||||
from ...common import dtype as mstype
|
||||
|
||||
|
@ -51,12 +50,12 @@ class ACosGrad(PrimitiveWithInfer):
|
|||
"""init ACosGrad"""
|
||||
|
||||
def infer_shape(self, x, dout):
|
||||
validator.check_param_equal("x", x, "dout", dout)
|
||||
validator.check("x shape", x, "dout shape", dout, Rel.EQ, self.name)
|
||||
return x
|
||||
|
||||
def infer_dtype(self, x, dout):
|
||||
args = {"x": x, "dout": dout}
|
||||
validator.check_type_same(args, mstype.number_type)
|
||||
validator.check_tensor_type_same(args, mstype.number_type, self.name)
|
||||
return x
|
||||
|
||||
|
||||
|
@ -65,8 +64,8 @@ class BatchNormGrad(PrimitiveWithInfer):
|
|||
|
||||
@prim_attr_register
|
||||
def __init__(self, is_training=False, epsilon=1e-5):
|
||||
self.is_training = validator.check_type('is_training', is_training, (bool,))
|
||||
self.epsilon = validator.check_number_range('epsilon', epsilon, 0, 1, Rel.INC_RIGHT)
|
||||
self.is_training = validator.check_value_type('is_training', is_training, (bool,), self.name)
|
||||
self.epsilon = validator.check_number_range('epsilon', epsilon, 0, 1, Rel.INC_RIGHT, self.name)
|
||||
self.add_prim_attr('data_format', "NCHW")
|
||||
|
||||
def infer_shape(self, y_backprop_shape, x_shape, scale_shape, reserve_1_shape, reserve_2_shape, reserve_3_shape):
|
||||
|
@ -93,19 +92,19 @@ class BinaryCrossEntropyGrad(PrimitiveWithInfer):
|
|||
"""Computes gradients for `BinaryCrossEntropy` operation."""
|
||||
@prim_attr_register
|
||||
def __init__(self, reduction='mean'):
|
||||
self.reduction = validator.check_string('reduction', reduction, ['none', 'mean', 'sum'])
|
||||
self.reduction = validator.check_string('reduction', reduction, ['none', 'mean', 'sum'], self.name)
|
||||
|
||||
def infer_shape(self, x_shape, y_shape, doutput_shape, weight_shape):
|
||||
validator.check_param_equal('x_shape', x_shape, 'y_shape', y_shape)
|
||||
validator.check('x_shape', x_shape, 'y_shape', y_shape, Rel.EQ, self.name)
|
||||
if weight_shape:
|
||||
validator.check_param_equal('y_shape', y_shape, 'weight_shape', weight_shape)
|
||||
validator.check('y_shape', y_shape, 'weight_shape', weight_shape, Rel.EQ, self.name)
|
||||
return x_shape
|
||||
|
||||
def infer_dtype(self, x_type, y_type, doutput_type, weight_type):
|
||||
args = {'x_type': x_type, 'y_type': y_type, 'doutput_type': doutput_type}
|
||||
validator.check_type_same(args, (mstype.float16, mstype.float32))
|
||||
validator.check_tensor_type_same(args, (mstype.float16, mstype.float32), self.name)
|
||||
if weight_type:
|
||||
validator.check_two_types_same('x_type', x_type, 'weight_type', weight_type)
|
||||
validator.check('x_type', x_type, 'weight_type', weight_type, Rel.EQ, TypeError)
|
||||
return x_type
|
||||
|
||||
|
||||
|
@ -120,7 +119,7 @@ class ConcatOffset(PrimitiveWithInfer):
|
|||
axis = self.axis
|
||||
x_shp = input_x['shape']
|
||||
x_type = input_x['dtype']
|
||||
offset, _, axis = _get_concat_offset(x_shp, x_type, axis)
|
||||
offset, _, axis = _get_concat_offset(x_shp, x_type, axis, self.name)
|
||||
self.add_prim_attr('T', x_type[0].element_type())
|
||||
offset_values = []
|
||||
for i in range(len(x_shp)):
|
||||
|
@ -184,11 +183,11 @@ class Conv2DBackpropFilter(PrimitiveWithInfer):
|
|||
|
||||
def __infer__(self, doutput, x, w_size):
|
||||
w_size_v = w_size['value']
|
||||
validator.check_type('w_size', w_size_v, [tuple])
|
||||
validator.check_value_type('w_size', w_size_v, [tuple], self.name)
|
||||
for i, dim_len in enumerate(w_size_v):
|
||||
validator.check_type("w_size[%d]" % i, dim_len, [int])
|
||||
validator.check_typename('x_dtype', x['dtype'], [mstype.int8, mstype.int32, mstype.float16, mstype.float32])
|
||||
validator.check_two_types_same('doutput_dtype', doutput['dtype'], 'x_dtype', x['dtype'])
|
||||
validator.check_value_type("w_size[%d]" % i, dim_len, [int], self.name)
|
||||
args = {"x": x['dtype'], "doutput": doutput['dtype']}
|
||||
validator.check_tensor_type_same(args, [mstype.int8, mstype.int32, mstype.float16, mstype.float32], self.name)
|
||||
out = {
|
||||
'value': None,
|
||||
'shape': w_size_v,
|
||||
|
@ -250,8 +249,8 @@ class DepthwiseConv2dNativeBackpropFilter(PrimitiveWithInfer):
|
|||
|
||||
def __infer__(self, x, w_size, dout):
|
||||
w_size_v = w_size['value']
|
||||
args = {'x_dtype': x['dtype'], 'dout_type': dout['dtype']}
|
||||
validator.check_type_same(args, mstype.number_type)
|
||||
args = {'x': x['dtype'], 'dout': dout['dtype']}
|
||||
validator.check_tensor_type_same(args, mstype.number_type, self.name)
|
||||
out = {
|
||||
'value': None,
|
||||
'shape': w_size_v,
|
||||
|
@ -310,8 +309,8 @@ class DepthwiseConv2dNativeBackpropInput(PrimitiveWithInfer):
|
|||
raise NotImplementedError
|
||||
|
||||
def __infer__(self, x_size, w, dout):
|
||||
args = {'w_dtype': w['dtype'], 'dout_type': dout['dtype']}
|
||||
validator.check_type_same(args, mstype.number_type)
|
||||
args = {'w': w['dtype'], 'dout': dout['dtype']}
|
||||
validator.check_tensor_type_same(args, mstype.number_type, self.name)
|
||||
x_size_v = x_size['value']
|
||||
out = {
|
||||
'value': None,
|
||||
|
@ -360,9 +359,9 @@ class GeluGrad(PrimitiveWithInfer):
|
|||
return x_shape
|
||||
|
||||
def infer_dtype(self, y_backprop_dtype, x_dtype, y_dtype):
|
||||
validator.check_typename("y_backprop_dtype", y_backprop_dtype, (mstype.float16, mstype.float32))
|
||||
validator.check_typename("x_dtype", x_dtype, (mstype.float16, mstype.float32))
|
||||
validator.check_typename("y_dtype", y_dtype, (mstype.float16, mstype.float32))
|
||||
validator.check_tensor_type_same({"y_backprop": y_backprop_dtype}, (mstype.float16, mstype.float32), self.name)
|
||||
validator.check_tensor_type_same({"x": x_dtype}, (mstype.float16, mstype.float32), self.name)
|
||||
validator.check_tensor_type_same({"y": y_dtype}, (mstype.float16, mstype.float32), self.name)
|
||||
return x_dtype
|
||||
|
||||
|
||||
|
@ -373,56 +372,36 @@ class _PoolGrad(PrimitiveWithInfer):
|
|||
def __init__(self, ksize, strides, padding="VALID"):
|
||||
self.init_prim_io_names(inputs=['x_origin', 'out_origin', 'grad'], outputs=['output'])
|
||||
|
||||
validator.check_type('ksize', ksize, [int, tuple])
|
||||
validator.check_type('strides', strides, [int, tuple])
|
||||
self.padding = validator.check_string('padding', padding.upper(), ['VALID', 'SAME'])
|
||||
validator.check_value_type('ksize', ksize, [int, tuple], self.name)
|
||||
validator.check_value_type('strides', strides, [int, tuple], self.name)
|
||||
self.padding = validator.check_string('padding', padding.upper(), ['VALID', 'SAME'], self.name)
|
||||
self.add_prim_attr("padding", self.padding)
|
||||
self.is_maxpoolgradwithargmax = (self.name == "MaxPoolGradWithArgmax")
|
||||
if not self.is_maxpoolgradwithargmax:
|
||||
self.add_prim_attr('data_format', "NCHW")
|
||||
|
||||
if isinstance(ksize, int):
|
||||
validator.check_integer("ksize", ksize, 1, Rel.GE)
|
||||
if self.is_maxpoolgradwithargmax:
|
||||
self.ksize = (1, ksize, ksize, 1)
|
||||
def _grad_check_int_or_tuple(arg_name, arg_val, is_argmax):
|
||||
validator.check_value_type(arg_name, arg_val, (int, tuple), self.name)
|
||||
error_msg = ValueError(f"For '{self.name}' the '{arg_name}' should be an positive int number "
|
||||
f"or a tuple of two or four positive int numbers, but got {arg_val}")
|
||||
if isinstance(arg_val, int):
|
||||
ret = (1, arg_val, arg_val, 1) if is_argmax else (1, 1, arg_val, arg_val)
|
||||
elif len(arg_val) == 2:
|
||||
ret = (1, arg_val[0], arg_val[1], 1) if is_argmax else (1, 1, arg_val[0], arg_val[1])
|
||||
elif len(arg_val) == 4:
|
||||
ret = arg_val
|
||||
else:
|
||||
self.ksize = (1, 1, ksize, ksize)
|
||||
else:
|
||||
ksize_error = ValueError(f"The 'ksize' passed to operator {self.name} should be an positive int number"
|
||||
f"or a tuple of two or four positive int numbers, but got {ksize}")
|
||||
if len(ksize) != 2 and len(ksize) != 4:
|
||||
raise ksize_error
|
||||
for ksize_val in ksize:
|
||||
if not isinstance(ksize_val, int) or (ksize_val <= 0):
|
||||
raise ksize_error
|
||||
if len(ksize) == 2 and self.is_maxpoolgradwithargmax:
|
||||
self.ksize = (1, ksize[0], ksize[1], 1)
|
||||
elif len(ksize) == 2 and not self.is_maxpoolgradwithargmax:
|
||||
self.ksize = (1, 1, ksize[0], ksize[1])
|
||||
else:
|
||||
self.ksize = ksize
|
||||
raise error_msg
|
||||
# whether all elements of tuple are positive integers
|
||||
for item in ret:
|
||||
if not isinstance(item, int) or item <= 0:
|
||||
raise error_msg
|
||||
return ret
|
||||
|
||||
self.ksize = _grad_check_int_or_tuple("ksize", ksize, self.is_maxpoolgradwithargmax)
|
||||
self.add_prim_attr("ksize", self.ksize)
|
||||
|
||||
if isinstance(strides, int):
|
||||
validator.check_integer("strides", strides, 1, Rel.GE)
|
||||
if self.is_maxpoolgradwithargmax:
|
||||
self.strides = (1, strides, strides, 1)
|
||||
else:
|
||||
self.strides = (1, 1, strides, strides)
|
||||
else:
|
||||
strides_error = ValueError(f"The 'strides' passed to operator {self.name} should be an positive int number"
|
||||
f"or a tuple of two or four positive int numbers, but got {strides}")
|
||||
if len(strides) != 2 and len(strides) != 4:
|
||||
raise strides_error
|
||||
for strides_val in strides:
|
||||
if not isinstance(strides_val, int) or (strides_val <= 0):
|
||||
raise strides_error
|
||||
if len(strides) == 2 and self.is_maxpoolgradwithargmax:
|
||||
self.strides = (1, strides[0], strides[1], 1)
|
||||
elif len(strides) == 2 and not self.is_maxpoolgradwithargmax:
|
||||
self.strides = (1, 1, strides[0], strides[1])
|
||||
else:
|
||||
self.strides = strides
|
||||
self.strides = _grad_check_int_or_tuple("strides", strides, self.is_maxpoolgradwithargmax)
|
||||
self.add_prim_attr("strides", self.strides)
|
||||
|
||||
|
||||
|
@ -529,17 +508,17 @@ class L2NormalizeGrad(PrimitiveWithInfer):
|
|||
|
||||
@prim_attr_register
|
||||
def __init__(self, axis=0, epsilon=1e-4):
|
||||
validator.check_type('axis', axis, [int])
|
||||
validator.check_type('epsilon', epsilon, [int, float])
|
||||
validator.check_value_type('axis', axis, [int], self.name)
|
||||
validator.check_value_type('epsilon', epsilon, [int, float], self.name)
|
||||
|
||||
def infer_shape(self, input_x, out, dout):
|
||||
validator.check_param_equal('input_x', input_x, 'out', out)
|
||||
validator.check_param_equal('input_x', input_x, 'dout', dout)
|
||||
validator.check('input_x shape', input_x, 'out shape', out, Rel.EQ, self.name)
|
||||
validator.check('input_x shape', input_x, 'dout shape', dout, Rel.EQ, self.name)
|
||||
return input_x
|
||||
|
||||
def infer_dtype(self, input_x, out, dout):
|
||||
args = {'input_x': input_x, 'out': out, 'dout': dout}
|
||||
validator.check_type_same(args, mstype.number_type)
|
||||
validator.check_tensor_type_same(args, mstype.number_type, self.name)
|
||||
return input_x
|
||||
|
||||
|
||||
|
@ -560,8 +539,8 @@ class LayerNormGrad(Primitive):
|
|||
@prim_attr_register
|
||||
def __init__(self, begin_norm_axis=1, begin_params_axis=1):
|
||||
"""init"""
|
||||
self.begin_norm_axis = validator.check_type('begin_norm_axis', begin_norm_axis, [int])
|
||||
self.begin_params_axis = validator.check_type('begin_params_axis', begin_params_axis, [int])
|
||||
self.begin_norm_axis = validator.check_value_type('begin_norm_axis', begin_norm_axis, [int], self.name)
|
||||
self.begin_params_axis = validator.check_value_type('begin_params_axis', begin_params_axis, [int], self.name)
|
||||
|
||||
def __call__(self, x, dy, variance, mean, gamma):
|
||||
raise NotImplementedError
|
||||
|
@ -573,15 +552,15 @@ class LogSoftmaxGrad(PrimitiveWithInfer):
|
|||
@prim_attr_register
|
||||
def __init__(self, axis=-1):
|
||||
"""init LogSoftmaxGrad"""
|
||||
validator.check_type("axis", axis, [int])
|
||||
validator.check_value_type("axis", axis, [int], self.name)
|
||||
|
||||
def infer_shape(self, dout, logits):
|
||||
rank = len(logits)
|
||||
validator.check_int_range('axis', self.axis, -rank - 1, rank, Rel.INC_BOTH)
|
||||
validator.check_int_range('axis', self.axis, -rank - 1, rank, Rel.INC_BOTH, self.name)
|
||||
return logits
|
||||
|
||||
def infer_dtype(self, dout, logits):
|
||||
validator.check_subclass("logits", logits, mstype.tensor)
|
||||
validator.check_subclass("logits", logits, mstype.tensor, self.name)
|
||||
return logits
|
||||
|
||||
|
||||
|
@ -590,13 +569,13 @@ class LSTMGradData(PrimitiveWithInfer):
|
|||
|
||||
@prim_attr_register
|
||||
def __init__(self, input_size, hidden_size, num_layers, has_bias, bidirectional, dropout):
|
||||
self.input_size = check_int_positive(input_size)
|
||||
self.hidden_size = check_int_positive(hidden_size)
|
||||
self.num_layers = check_int_positive(num_layers)
|
||||
self.has_bias = check_bool(has_bias)
|
||||
self.bidirectional = check_bool(bidirectional)
|
||||
self.dropout = validator.check_type("dropout", dropout, [float])
|
||||
self.dropout = validator.check_number_range('dropout', dropout, 0, 1, Rel.INC_BOTH)
|
||||
self.input_size = validator.check_integer('input_size', input_size, 0, Rel.GT, self.name)
|
||||
self.hidden_size = validator.check_integer('hidden_size', hidden_size, 0, Rel.GT, self.name)
|
||||
self.num_layers = validator.check_integer('num_layers', num_layers, 0, Rel.GT, self.name)
|
||||
self.has_bias = validator.check_value_type('has_bias', has_bias, (bool,), self.name)
|
||||
self.bidirectional = validator.check_value_type('bidirectional', bidirectional, (bool,), self.name)
|
||||
self.dropout = validator.check_value_type("dropout", dropout, [float], self.name)
|
||||
self.dropout = validator.check_number_range('dropout', dropout, 0, 1, Rel.INC_BOTH, self.name)
|
||||
|
||||
if bidirectional:
|
||||
self.num_directions = 2
|
||||
|
@ -606,19 +585,19 @@ class LSTMGradData(PrimitiveWithInfer):
|
|||
def infer_shape(self, y_shape, dy_shape, dhy_shape, dcy_shape, w_shape,
|
||||
hx_shape, cx_shape, reserve_shape, state_shape):
|
||||
# dhy and dcy should be same shape
|
||||
validator.check_integer("h_shape", len(dhy_shape), 3, Rel.EQ)
|
||||
validator.check_integer("h_shape", len(dhy_shape), len(dcy_shape), Rel.EQ)
|
||||
validator.check_integer("h_shape[0]", dhy_shape[0], dcy_shape[0], Rel.EQ)
|
||||
validator.check_integer("h_shape[1]", dhy_shape[1], dcy_shape[1], Rel.EQ)
|
||||
validator.check_integer("h_shape[2]", dhy_shape[2], dcy_shape[2], Rel.EQ)
|
||||
validator.check_integer("h_shape", len(dhy_shape), 3, Rel.EQ, self.name)
|
||||
validator.check_integer("h_shape", len(dhy_shape), len(dcy_shape), Rel.EQ, self.name)
|
||||
validator.check_integer("h_shape[0]", dhy_shape[0], dcy_shape[0], Rel.EQ, self.name)
|
||||
validator.check_integer("h_shape[1]", dhy_shape[1], dcy_shape[1], Rel.EQ, self.name)
|
||||
validator.check_integer("h_shape[2]", dhy_shape[2], dcy_shape[2], Rel.EQ, self.name)
|
||||
|
||||
validator.check_integer("h_shape[0]", dhy_shape[0], self.num_layers * self.num_directions, Rel.EQ)
|
||||
validator.check_integer("h_shape[2]", dhy_shape[2], self.hidden_size, Rel.EQ)
|
||||
validator.check_integer("h_shape[0]", dhy_shape[0], self.num_layers * self.num_directions, Rel.EQ, self.name)
|
||||
validator.check_integer("h_shape[2]", dhy_shape[2], self.hidden_size, Rel.EQ, self.name)
|
||||
|
||||
# dy: (seq_len, batch_size, hidden_size * num_directions)
|
||||
validator.check_integer("dy_shape", len(dy_shape), 3, Rel.EQ)
|
||||
validator.check_integer("dy[1]", dy_shape[1], dhy_shape[1], Rel.EQ)
|
||||
validator.check_integer("dy[2]", dy_shape[2], self.hidden_size * self.num_directions, Rel.EQ)
|
||||
validator.check_integer("dy_shape", len(dy_shape), 3, Rel.EQ, self.name)
|
||||
validator.check_integer("dy[1]", dy_shape[1], dhy_shape[1], Rel.EQ, self.name)
|
||||
validator.check_integer("dy[2]", dy_shape[2], self.hidden_size * self.num_directions, Rel.EQ, self.name)
|
||||
|
||||
# (seq_len, batch_size, input_size)
|
||||
dx_shape = (y_shape[0], y_shape[1], self.input_size)
|
||||
|
@ -629,11 +608,8 @@ class LSTMGradData(PrimitiveWithInfer):
|
|||
|
||||
def infer_dtype(self, y_dtype, dy_dtype, dhy_dtype, dcy_dtype, w_dtype,
|
||||
hx_dtype, cx_dtype, reserve_dtype, state_dtype):
|
||||
validator.check_typename("dy_dtype", dy_dtype, (mstype.float32, mstype.float16))
|
||||
validator.check_typename("dhy_dtype", dhy_dtype, (mstype.float32, mstype.float16))
|
||||
validator.check_typename("dcy_dtype", dcy_dtype, (mstype.float32, mstype.float16))
|
||||
validator.check_typename("datatype", dy_dtype, (dhy_dtype.element_type(),))
|
||||
validator.check_typename("datatype", dy_dtype, (dcy_dtype.element_type(),))
|
||||
args = {"dy": dy_dtype, "dhy": dhy_dtype, "dcy": dcy_dtype}
|
||||
validator.check_tensor_type_same(args, (mstype.float32, mstype.float16), self.name)
|
||||
return (dy_dtype, dy_dtype, dy_dtype)
|
||||
|
||||
|
||||
|
@ -642,13 +618,13 @@ class LSTMGradWeight(PrimitiveWithInfer):
|
|||
|
||||
@prim_attr_register
|
||||
def __init__(self, input_size, hidden_size, num_layers, has_bias, bidirectional, dropout):
|
||||
self.input_size = check_int_positive(input_size)
|
||||
self.hidden_size = check_int_positive(hidden_size)
|
||||
self.num_layers = check_int_positive(num_layers)
|
||||
self.has_bias = check_bool(has_bias)
|
||||
self.bidirectional = check_bool(bidirectional)
|
||||
self.dropout = validator.check_type("dropout", dropout, [float])
|
||||
self.dropout = validator.check_number_range('dropout', dropout, 0, 1, Rel.INC_BOTH)
|
||||
self.input_size = validator.check_integer('input_size', input_size, 0, Rel.GT, self.name)
|
||||
self.hidden_size = validator.check_integer('hidden_size', hidden_size, 0, Rel.GT, self.name)
|
||||
self.num_layers = validator.check_integer('num_layers', num_layers, 0, Rel.GT, self.name)
|
||||
self.has_bias = validator.check_value_type('has_bias', has_bias, (bool,), self.name)
|
||||
self.bidirectional = validator.check_value_type('bidirectional', bidirectional, (bool,), self.name)
|
||||
self.dropout = validator.check_value_type("dropout", dropout, [float], self.name)
|
||||
self.dropout = validator.check_number_range('dropout', dropout, 0, 1, Rel.INC_BOTH, self.name)
|
||||
|
||||
if bidirectional:
|
||||
self.num_directions = 2
|
||||
|
@ -693,9 +669,10 @@ class PReLUGrad(PrimitiveWithInfer):
|
|||
return y_backprop_shape, w_shape
|
||||
|
||||
def infer_dtype(self, y_backprop_dtype, A_dtype, w_dtype):
|
||||
validator.check_typename("y_backprop_dtype", y_backprop_dtype, (mstype.float16, mstype.float32))
|
||||
validator.check_typename("A_dtype", A_dtype, (mstype.float16, mstype.float32))
|
||||
validator.check_typename("w_dtype", w_dtype, (mstype.float16, mstype.float32))
|
||||
valid_types = (mstype.float16, mstype.float32)
|
||||
validator.check_tensor_type_same({"y_backprop": y_backprop_dtype}, valid_types, self.name)
|
||||
validator.check_tensor_type_same({"A_dtype": A_dtype}, valid_types, self.name)
|
||||
validator.check_tensor_type_same({"w_dtype": w_dtype}, valid_types, self.name)
|
||||
return y_backprop_dtype, w_dtype
|
||||
|
||||
|
||||
|
@ -725,8 +702,8 @@ class ReLU6Grad(PrimitiveWithInfer):
|
|||
return x_shape
|
||||
|
||||
def infer_dtype(self, y_grad_dtype, x_dtype):
|
||||
validator.check_typename("y_grad_dtype", y_grad_dtype, (mstype.float16, mstype.float32))
|
||||
validator.check_typename("x_dtype", x_dtype, (mstype.float16, mstype.float32))
|
||||
validator.check_tensor_type_same({"y_grad": y_grad_dtype}, (mstype.float16, mstype.float32), self.name)
|
||||
validator.check_tensor_type_same({"x": x_dtype}, (mstype.float16, mstype.float32), self.name)
|
||||
return x_dtype
|
||||
|
||||
|
||||
|
@ -744,10 +721,8 @@ class ReluGradV2(PrimitiveWithInfer):
|
|||
return gradients_shape
|
||||
|
||||
def infer_dtype(self, gradients_dtype, mask_dtype):
|
||||
args_type = {'gradients': gradients_dtype, 'mask': mask_dtype}
|
||||
validator.check_args_tensor(args_type)
|
||||
validator.check_typename("gradients_dtype", gradients_dtype, mstype.number_type)
|
||||
validator.check_typename("mask_dtype", mask_dtype, (mstype.uint8,))
|
||||
validator.check_tensor_type_same({'gradients': gradients_dtype}, mstype.number_type, self.name)
|
||||
validator.check_tensor_type_same({'mask': mask_dtype}, (mstype.uint8,), self.name)
|
||||
return gradients_dtype
|
||||
|
||||
|
||||
|
@ -762,10 +737,8 @@ class EluGrad(PrimitiveWithInfer):
|
|||
return x_shape
|
||||
|
||||
def infer_dtype(self, y_grad_dtype, x_dtype):
|
||||
args_type = {'y_grad': y_grad_dtype, 'x': x_dtype}
|
||||
validator.check_args_tensor(args_type)
|
||||
args_dtype = {'y_grad_dtype': y_grad_dtype, 'x_dtype': x_dtype}
|
||||
validator.check_type_same(args_dtype, mstype.float_type)
|
||||
args = {'y_grad': y_grad_dtype, 'x': x_dtype}
|
||||
validator.check_tensor_type_same(args, mstype.float_type, self.name)
|
||||
return x_dtype
|
||||
|
||||
|
||||
|
@ -821,11 +794,11 @@ class ROIAlignGrad(PrimitiveWithInfer):
|
|||
@prim_attr_register
|
||||
def __init__(self, xdiff_shape, pooled_height, pooled_width, spatial_scale, sample_num=2):
|
||||
"""init ROIAlignGrad"""
|
||||
validator.check_type("pooled_height", pooled_height, [int])
|
||||
validator.check_type("pooled_width", pooled_width, [int])
|
||||
validator.check_type("spatial_scale", spatial_scale, [float])
|
||||
validator.check_type("sample_num", sample_num, [int])
|
||||
validator.check_type("xdiff_shape", xdiff_shape, [tuple])
|
||||
validator.check_value_type("pooled_height", pooled_height, [int], self.name)
|
||||
validator.check_value_type("pooled_width", pooled_width, [int], self.name)
|
||||
validator.check_value_type("spatial_scale", spatial_scale, [float], self.name)
|
||||
validator.check_value_type("sample_num", sample_num, [int], self.name)
|
||||
validator.check_value_type("xdiff_shape", xdiff_shape, [tuple], self.name)
|
||||
self.xdiff_shape = xdiff_shape
|
||||
self.pooled_height = pooled_height
|
||||
self.pooled_width = pooled_width
|
||||
|
@ -850,10 +823,8 @@ class SigmoidGrad(PrimitiveWithInfer):
|
|||
return out
|
||||
|
||||
def infer_dtype(self, out, dout):
|
||||
validator.check_typename("dout dtype", dout, (mstype.float16, mstype.float32))
|
||||
validator.check_typename("out dtype", out, (mstype.float16, mstype.float32))
|
||||
args = {"out type": out, "dout type": dout}
|
||||
validator.check_type_same(args, mstype.number_type)
|
||||
args = {'out': out, 'dout': dout}
|
||||
validator.check_tensor_type_same(args, mstype.number_type, self.name)
|
||||
return out
|
||||
|
||||
|
||||
|
@ -868,8 +839,8 @@ class HSigmoidGrad(PrimitiveWithInfer):
|
|||
return x_shape
|
||||
|
||||
def infer_dtype(self, y_grad_dtype, x_dtype):
|
||||
validator.check_typename("y_grad dtype", y_grad_dtype, (mstype.float16, mstype.float32))
|
||||
validator.check_typename("x dtype", x_dtype, (mstype.float16, mstype.float32))
|
||||
validator.check_tensor_type_same({"y_grad": y_grad_dtype}, (mstype.float16, mstype.float32), self.name)
|
||||
validator.check_tensor_type_same({"x": x_dtype}, (mstype.float16, mstype.float32), self.name)
|
||||
return x_dtype
|
||||
|
||||
|
||||
|
@ -884,8 +855,8 @@ class HSwishGrad(PrimitiveWithInfer):
|
|||
return x_shape
|
||||
|
||||
def infer_dtype(self, y_grad_dtype, x_dtype):
|
||||
validator.check_typename("y_grad dtype", y_grad_dtype, (mstype.float16, mstype.float32))
|
||||
validator.check_typename("x_ dtype", x_dtype, (mstype.float16, mstype.float32))
|
||||
validator.check_tensor_type_same({"y_grad": y_grad_dtype}, (mstype.float16, mstype.float32), self.name)
|
||||
validator.check_tensor_type_same({"x": x_dtype}, (mstype.float16, mstype.float32), self.name)
|
||||
return x_dtype
|
||||
|
||||
|
||||
|
@ -898,13 +869,13 @@ class SigmoidCrossEntropyWithLogitsGrad(PrimitiveWithInfer):
|
|||
self.init_prim_io_names(inputs=['x', 'y', 'dout'], outputs=['x_grad'])
|
||||
|
||||
def infer_shape(self, x_shape, y_shape, dout_shape):
|
||||
validator.check_param_equal("x_shape", x_shape, "y_shape", y_shape)
|
||||
validator.check_param_equal("x_shape", x_shape, "dout_shape", dout_shape)
|
||||
validator.check("x_shape", x_shape, "y_shape", y_shape, Rel.EQ, self.name)
|
||||
validator.check("x_shape", x_shape, "dout_shape", dout_shape, Rel.EQ, self.name)
|
||||
return x_shape
|
||||
|
||||
def infer_dtype(self, x_dtype, y_dtype, dout_dtype):
|
||||
args = {"x_dtype": x_dtype, "y_dtype": y_dtype, 'dout_dtype': dout_dtype}
|
||||
validator.check_type_same(args, mstype.number_type)
|
||||
validator.check_tensor_type_same(args, mstype.number_type, self.name)
|
||||
return dout_dtype
|
||||
|
||||
|
||||
|
@ -920,8 +891,8 @@ class SliceGrad(PrimitiveWithInfer):
|
|||
dy_shape, x_shape, size_value = dy['shape'], x['shape'], size['value']
|
||||
dy_shape_len = len(dy_shape)
|
||||
for i in range(dy_shape_len):
|
||||
validator.check(f'dy_shape[{i}]', dy_shape[i], f'x_shape[{i}]', x_shape[i], Rel.LE)
|
||||
validator.check(f'dy_shape[{i}]', dy_shape[i], f'size_shape[{i}]', size_value[i], Rel.EQ)
|
||||
validator.check(f'dy_shape[{i}]', dy_shape[i], f'x_shape[{i}]', x_shape[i], Rel.LE, self.name)
|
||||
validator.check(f'dy_shape[{i}]', dy_shape[i], f'size_shape[{i}]', size_value[i], Rel.EQ, self.name)
|
||||
return {'shape': x_shape,
|
||||
'dtype': x['dtype'],
|
||||
'value': None}
|
||||
|
@ -935,13 +906,13 @@ class SmoothL1LossGrad(PrimitiveWithInfer):
|
|||
pass
|
||||
|
||||
def infer_shape(self, prediction, target, dloss):
|
||||
validator.check_param_equal('prediction', prediction, 'target', target)
|
||||
validator.check_param_equal('prediction', prediction, 'dloss', dloss)
|
||||
validator.check('prediction shape', prediction, 'target shape', target, Rel.EQ, self.name)
|
||||
validator.check('prediction shape', prediction, 'dloss shape', dloss, Rel.EQ, self.name)
|
||||
return prediction
|
||||
|
||||
def infer_dtype(self, prediction, target, dloss):
|
||||
args = {"prediction": prediction, "target": target, 'dloss': dloss}
|
||||
validator.check_type_same(args, mstype.number_type)
|
||||
validator.check_tensor_type_same(args, mstype.number_type, self.name)
|
||||
return dloss
|
||||
|
||||
|
||||
|
@ -968,11 +939,11 @@ class StridedSliceGrad(PrimitiveWithInfer):
|
|||
new_axis_mask=0,
|
||||
shrink_axis_mask=0):
|
||||
"""init StrideSliceGrad"""
|
||||
validator.check_type('begin_mask', begin_mask, [int])
|
||||
validator.check_type('end_mask', end_mask, [int])
|
||||
validator.check_type('ellipsis_mask', ellipsis_mask, [int])
|
||||
validator.check_type('new_axis_mask', new_axis_mask, [int])
|
||||
validator.check_type('shrink_axis_mask', shrink_axis_mask, [int])
|
||||
validator.check_value_type('begin_mask', begin_mask, [int], self.name)
|
||||
validator.check_value_type('end_mask', end_mask, [int], self.name)
|
||||
validator.check_value_type('ellipsis_mask', ellipsis_mask, [int], self.name)
|
||||
validator.check_value_type('new_axis_mask', new_axis_mask, [int], self.name)
|
||||
validator.check_value_type('shrink_axis_mask', shrink_axis_mask, [int], self.name)
|
||||
self.init_prim_io_names(inputs=['dy', 'shapex', 'begin', 'end', 'strides'], outputs=['output'])
|
||||
|
||||
def __infer__(self, dy, shapex, begin, end, strides):
|
||||
|
@ -992,10 +963,8 @@ class TanhGrad(PrimitiveWithInfer):
|
|||
return out
|
||||
|
||||
def infer_dtype(self, out, dout):
|
||||
validator.check_subclass("out", out, mstype.tensor)
|
||||
validator.check_subclass("dout", dout, mstype.tensor)
|
||||
args = {"out type": out, "dout type": dout}
|
||||
validator.check_type_same(args, mstype.number_type)
|
||||
args = {"out": out, "dout": dout}
|
||||
validator.check_tensor_type_same(args, mstype.number_type, self.name)
|
||||
return out
|
||||
|
||||
|
||||
|
@ -1005,13 +974,13 @@ class MirrorPadGrad(PrimitiveWithInfer):
|
|||
@prim_attr_register
|
||||
def __init__(self, mode="REFLECT"):
|
||||
"""init MirrorPad"""
|
||||
validator.check_string('mode', mode, ['REFLECT', 'SYMMETRIC'])
|
||||
validator.check_string('mode', mode, ['REFLECT', 'SYMMETRIC'], self.name)
|
||||
self.mode = mode
|
||||
|
||||
def __infer__(self, dout, paddings, x):
|
||||
validator.check_subclass("dout", dout['dtype'], mstype.tensor)
|
||||
validator.check_subclass("paddings", paddings['dtype'], mstype.tensor)
|
||||
validator.check_subclass("input_x", x['dtype'], mstype.tensor)
|
||||
validator.check_subclass("dout", dout['dtype'], mstype.tensor, self.name)
|
||||
validator.check_subclass("paddings", paddings['dtype'], mstype.tensor, self.name)
|
||||
validator.check_subclass("input_x", x['dtype'], mstype.tensor, self.name)
|
||||
return {'shape': x['shape'],
|
||||
'dtype': dout['dtype'],
|
||||
'value': None}
|
||||
|
|
|
@ -1316,7 +1316,7 @@ class Concat(PrimitiveWithInfer):
|
|||
axis = self.axis
|
||||
x_shp = input_x['shape']
|
||||
x_type = input_x['dtype']
|
||||
_, all_shp, _ = _get_concat_offset(x_shp, x_type, axis)
|
||||
_, all_shp, _ = _get_concat_offset(x_shp, x_type, axis, self.name)
|
||||
self.add_prim_attr('T', x_type[0].element_type())
|
||||
self.add_prim_attr('inputNums', len(x_shp))
|
||||
ret_shp = x_shp[0].copy()
|
||||
|
|
Loading…
Reference in New Issue