forked from mindspore-Ecosystem/mindspore
!23069 Modify error content
Merge pull request !23069 from huchunmei/master
This commit is contained in:
commit 76fac95591
@@ -593,20 +593,24 @@ class PReLU(Cell):
  w = Tensor(tmp, dtype=mstype.float32)
  elif isinstance(w, list):
  if len(w) != channel:
- raise ValueError(f"When the 'w' is a list, the length should be equal to the channel, "
- f"but got the length {len(w)}, the channel {channel}")
+ raise ValueError(f"For '{self.cls_name}', the length of 'w' should be equal to the 'channel' when "
+ f"the 'w' is a list, but got the length of 'w': {len(w)}, the 'channel': {channel}.")

  for i in w:
  if not isinstance(i, (float, np.float32)):
- raise ValueError(f"When the 'w' is a list, the all elements should be float, but got {w}")
+ raise ValueError(f"For '{self.cls_name}', all elements in 'w' should be "
+ f"float when the 'w' is a list, but got {i}.")
  w = Tensor(w, dtype=mstype.float32)
  elif isinstance(w, Tensor):
  if w.dtype not in (mstype.float16, mstype.float32):
- raise ValueError(f"When the 'w' is a tensor, the dtype should be float16 or float32, but got {w.dtype}")
+ raise ValueError(f"For '{self.cls_name}', the dtype of 'w' should be float16 or "
+ f"float32 when the 'w' is a tensor, but got {w.dtype}.")
  if len(w.shape) != 1 or w.shape[0] != channel:
- raise ValueError(f"When the 'w' is a tensor, the rank should be 1, and the elements number "
- f"should be equal to the channel, but got w shape {w}, the channel {channel}")
+ raise ValueError(f"For '{self.cls_name}', the dimension of 'w' should be 1, and the elements number "
+ f"should be equal to the 'channel' when the 'w' is a tensor, but got 'w' shape {w}, "
+ f"the 'channel' {channel}.")
  else:
- raise TypeError(f"The 'w' only supported float list and tensor, but got {type(w)}")
+ raise TypeError(f"For '{self.cls_name}', the 'w' only supported float, list and tensor, but got {type(w)}.")
  self.w = Parameter(w, name='a')
  self.prelu = P.PReLU()
  self.relu = P.ReLU()

@@ -870,7 +874,7 @@ _activation = {
  }


- def get_activation(name):
+ def get_activation(name, prim_name=None):
  """
  Gets the activation function.


@@ -888,9 +892,11 @@ def get_activation(name):
  >>> print(sigmoid)
  Sigmoid<>
  """
+ msg_prefix = f"For '{prim_name}', the" if prim_name else "The"
  if name is None:
  return None

  if name not in _activation:
- raise KeyError(f"Unknown activation type '{name}'")
+ raise KeyError(f"{msg_prefix} 'name' should be in {_activation}, but got '{name}'. "
+ f"Refer to official documents for more information.")
  return _activation[name]()
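The change follows one pattern throughout: each validation helper gains an optional prim_name argument and builds a message prefix from it, so errors raised inside shared helpers can name the layer that called them while old call sites keep working. A minimal standalone sketch of that pattern (the helper name and the check itself are illustrative, not part of the commit):

    def _check_positive(value, prim_name=None):
        # Optional caller name: existing call sites omit it, Cells pass self.cls_name.
        msg_prefix = f"For '{prim_name}', the" if prim_name else "The"
        if value <= 0:
            raise ValueError(f"{msg_prefix} 'value' should be greater than 0, but got {value}.")

    try:
        _check_positive(-1, prim_name="Dense")
    except ValueError as err:
        print(err)  # For 'Dense', the 'value' should be greater than 0, but got -1.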
@@ -81,9 +81,9 @@ class L1Regularizer(Cell):
  super(L1Regularizer, self).__init__()
  Validator.check_value_type("scale", scale, [int, float], self.cls_name)
  if scale <= 0:
- raise ValueError("scale should be a number which greater than 0")
+ raise ValueError(f"For '{self.cls_name}', the 'scale' should be greater than 0, but got {scale}.")
  if math.isinf(scale) or math.isnan(scale):
- raise ValueError("scale can not be INF or NAN")
+ raise ValueError(f"For '{self.cls_name}', the 'scale' can not be INF or NAN, but got {scale}.")
  self.abs = P.Abs()
  self.reduce_sum = P.ReduceSum()
  self.scale = Tensor(scale, dtype=mstype.float32)

@@ -149,7 +149,8 @@ class Dropout(Cell):
  """Initialize Dropout."""
  super(Dropout, self).__init__()
  if keep_prob <= 0 or keep_prob > 1:
- raise ValueError("dropout probability should be a number in range (0, 1], but got {}".format(keep_prob))
+ raise ValueError(f"For '{self.cls_name}', the 'keep_prob' should be a number in range (0, 1], "
+ f"but got {keep_prob}.")
  Validator.check_subclass("dtype", dtype, mstype.number_type, self.cls_name)
  Validator.check_value_type('keep_prob', keep_prob, [float], self.cls_name)
  self.keep_prob = keep_prob

@@ -215,10 +216,10 @@ class Flatten(Cell):


  @constexpr
- def check_dense_input_shape(x):
+ def check_dense_input_shape(x, prim_name=None):
+ msg_prefix = f"For '{prim_name}', the" if prim_name else "The"
  if len(x) < 2:
- raise ValueError('For Dense, the dimension of input should not be less than 2, while the input dimension is '
- + f'{len(x)}.')
+ raise ValueError(f"{msg_prefix} dimension of 'x' should not be less than 2, but got {len(x)}.")


  class Dense(Cell):

@@ -292,26 +293,32 @@ class Dense(Cell):
  if isinstance(weight_init, Tensor):
  if weight_init.ndim != 2 or weight_init.shape[0] != out_channels or \
  weight_init.shape[1] != in_channels:
- raise ValueError("Weight init shape error.")
+ raise ValueError(f"For '{self.cls_name}', weight init shape error. The ndim of 'weight_init' should "
+ f"be equal to 2, and the first dim should be equal to 'out_channels', and the "
+ f"second dim should be equal to 'in_channels'. But got 'weight_init': {weight_init}, "
+ f"'out_channels': {out_channels}, 'in_channels': {in_channels}.")
  self.weight = Parameter(initializer(weight_init, [out_channels, in_channels]), name="weight")

  self.bias = None
  if self.has_bias:
  if isinstance(bias_init, Tensor):
  if bias_init.ndim != 1 or bias_init.shape[0] != out_channels:
- raise ValueError("Bias init shape error.")
+ raise ValueError(f"For '{self.cls_name}', bias init shape error. The ndim of 'bias_init' should "
+ f"be equal to 1, and the first dim should be equal to 'out_channels'. But got "
+ f"'bias_init': {bias_init}, 'out_channels': {out_channels}.")
  self.bias = Parameter(initializer(bias_init, [out_channels]), name="bias")
  self.bias_add = P.BiasAdd()

  self.matmul = P.MatMul(transpose_b=True)
  self.activation = get_activation(activation) if isinstance(activation, str) else activation
  if activation is not None and not isinstance(self.activation, (Cell, Primitive)):
- raise TypeError("The activation must be str or Cell or Primitive,"" but got {}.".format(activation))
+ raise TypeError(f"For '{self.cls_name}', the 'activation' must be str or Cell or Primitive, but got "
+ f"{type(activation)}.")
  self.activation_flag = self.activation is not None

  def construct(self, x):
  x_shape = self.shape_op(x)
- check_dense_input_shape(x_shape)
+ check_dense_input_shape(x_shape, self.cls_name)
  if len(x_shape) != 2:
  x = self.reshape(x, (-1, x_shape[-1]))
  x = self.matmul(x, self.weight)
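With the Dense change above, a shape mismatch in weight_init now reports which layer rejected it and which dimension is wrong. A hypothetical repro, assuming a MindSpore build that contains this commit:

    import numpy as np
    import mindspore.nn as nn
    from mindspore import Tensor

    # A 2x3 weight for a layer declared with in_channels=4: the second-dim check fails.
    bad_weight = Tensor(np.ones((2, 3), dtype=np.float32))
    try:
        nn.Dense(in_channels=4, out_channels=2, weight_init=bad_weight)
    except ValueError as err:
        print(err)  # "For 'Dense', weight init shape error. ..."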
@@ -341,9 +348,10 @@ def _is_equal_one(x):


  @constexpr
- def _dtype_check(x_dtype):
+ def _dtype_check(x_dtype, prim_name=None):
+ msg_prefix = f"For '{prim_name}', the" if prim_name else "The"
  if x_dtype not in [mstype.float32, mstype.float16]:
- raise TypeError("The input type must be float32 or float16.")
+ raise TypeError(f"{msg_prefix} x_dtype must be float32 or float16, but got {x_dtype}.")


  @constexpr

@@ -430,7 +438,7 @@ class ClipByNorm(Cell):
  l2sum_safe = self.select_(cond, l2sum, self.cast(ones_, self.dtype(l2sum)))
  l2norm = self.select_(cond, self.sqrt(l2sum_safe), l2sum)

- _dtype_check(self.dtype(x))
+ _dtype_check(self.dtype(x), self.cls_name)
  if _is_equal_one(clip_norm):
  intermediate = x
  else:

@@ -792,12 +800,14 @@ class Pad(Cell):
  self.paddings = paddings
  Validator.check_string(self.mode, ["CONSTANT", "REFLECT", "SYMMETRIC"], 'mode', self.cls_name)
  if not isinstance(paddings, tuple):
- raise TypeError('Paddings must be tuple type.')
+ raise TypeError(f"For '{self.cls_name}', the type of 'paddings' must be tuple, but got {type(paddings)}.")
  for item in paddings:
  if len(item) != 2:
- raise ValueError('The shape of paddings must be (n, 2).')
+ raise ValueError(f"For '{self.cls_name}', the dimension of 'paddings' must be (n, 2), "
+ f"but got {paddings}.")
  if len(paddings) > 4:
- raise ValueError('Only padding up to 4 dims is supported')
+ raise ValueError(f"For '{self.cls_name}', only 'paddings' up to 4 dims is supported, but got "
+ f"{len(paddings)}.")
  if mode == "CONSTANT":
  self.pad = P.Pad(self.paddings)
  else:

@@ -813,17 +823,18 @@ class Pad(Cell):


  @constexpr
- def bilinear(shape, size, scale, align_corners):
+ def bilinear(shape, size, scale, align_corners, prim_name=None):
  """Check input and calculate shape"""
+ msg_prefix = f"For '{prim_name}', the" if prim_name else "The"
  if not isinstance(align_corners, bool):
- raise TypeError("align_corners should be type boolean")
+ raise TypeError(f"{msg_prefix} type of 'align_corners' should be boolean, but got {type(align_corners)}.")
  if size is None and scale is None:
- raise ValueError("size and scale both none")
+ raise ValueError(f"{msg_prefix} 'size' and 'scale' both none.")
  if size is not None and scale is not None:
- raise ValueError("size and scale both not none")
+ raise ValueError(f"{msg_prefix} 'size' and 'scale' both not none.")
  if size is not None:
  if not isinstance(size, (tuple, list)):
- raise ValueError("size must be tuple or list")
+ raise ValueError(f"{msg_prefix} 'size' must be tuple or list, but got {type(size)}.")
  Validator.check_int(len(size), 2, Rel.EQ, "size", "bilinear")
  Validator.check_int(size[0], 1, Rel.GE, "size[0]", "bilinear")
  Validator.check_int(size[1], 1, Rel.GE, "size[1]", "bilinear")

@@ -958,7 +969,7 @@ class Unfold(Cell):
  def _check_tuple_or_list(arg_name, arg_val, prim_name):
  Validator.check_value_type(f"{arg_name}s", ksizes, [tuple, list], self.cls_name)
  if len(arg_val) != 4 or arg_val[0] != 1 or arg_val[3] != 1:
- raise ValueError(f"For \'{prim_name}\' the format of {arg_name}s should be [1, {arg_name}_row, "
+ raise ValueError(f"For '{prim_name}' the format of {arg_name}s should be [1, {arg_name}_row, "
  f"{arg_name}_col, 1], but got {arg_val}.")
  if not isinstance(arg_val[1], int) or not isinstance(arg_val[2], int) or arg_val[1] < 1 or arg_val[2] < 1:
  raise ValueError(f"For '{prim_name}' the {arg_name}_row and {arg_name}_col in {arg_name}s should be "

@@ -1421,7 +1432,9 @@ class Roll(Cell):
  self.op_list.append((inner.Roll(shift=self.shift, axis=0), self.axis))
  else:
  if len(self.shift) != len(self.axis):
- raise ValueError('The shape of shift and the shape of axis must be the same.')
+ raise ValueError(f"For '{self.cls_name}', the shape of 'shift' and the shape of 'axis' must be "
+ f"the same, but got the length of 'shift' {len(self.shift)} and the length of 'axis'"
+ f" {len(self.axis)}.")
  for idx, _ in enumerate(self.axis):
  self.op_list.append((inner.Roll(shift=self.shift[idx], axis=0), self.axis[idx]))

@@ -46,7 +46,8 @@ def _check_input_dtype(input_dtype, param_name, allow_dtypes, cls_name):
  @constexpr
  def _check_input_3d(input_shape, param_name, func_name):
  if len(input_shape) != 3:
- raise ValueError(f"{func_name} {param_name} should be 3d, but got shape {input_shape}")
+ raise ValueError(f"For '{func_name}', the {param_name} should be 3d, but got the length of input_shape:"
+ f" {len(input_shape)}.")


  class LSTM(Cell):

@@ -166,7 +167,8 @@ class LSTM(Cell):
  self.cast = P.Cast()
  self.shape = P.Shape()
  if dropout < 0 or dropout > 1:
- raise ValueError("For LSTM, dropout must be a number in range [0, 1], but got {}".format(dropout))
+ raise ValueError(f"For '{self.cls_name}', the 'dropout' must be a number in range [0, 1], "
+ f"but got {dropout}.")
  if dropout == 1:
  self.dropout_op = P.ZerosLike()
  else:

@@ -148,7 +148,7 @@ class Range(Cell):
  """Initialize Range."""
  super(Range, self).__init__()
  if delta == 0:
- raise ValueError("The input of `delta` can not be equal to zero.")
+ raise ValueError(f"For '{self.cls_name}', the 'delta' can not be zero.")
  data = np.arange(start, limit, delta)
  if data.dtype == np.float:
  self.ms_dtype = mstype.float32

@@ -750,11 +750,13 @@ class LBeta(Cell):


  @constexpr
- def get_broadcast_matmul_shape(x_shape, y_shape):
+ def get_broadcast_matmul_shape(x_shape, y_shape, prim_name=None):
  """get broadcast_matmul shape"""
+ msg_prefix = f"For '{prim_name}', the" if prim_name else "The"
  if (len(x_shape) < 2) or (len(y_shape) < 2):
- raise ValueError('For matmul, rank of x1 and x2 should be equal to or greater than 2, '
- + f'but got {x_shape} and {y_shape}.')
+ raise ValueError(f"{msg_prefix} length of 'x_shape' and 'y_shape' should be equal to or greater than 2, "
+ f"but got the length of 'x_shape': {len(x_shape)} and the length of 'y_shape': "
+ f"{len(y_shape)}.")
  x_shape_batch = x_shape[:-2]
  y_shape_batch = y_shape[:-2]
  if x_shape_batch == y_shape_batch:

@@ -771,7 +773,9 @@ def get_broadcast_matmul_shape(x_shape, y_shape):
  elif x_shape[i] == y_shape[i]:
  broadcast_shape_back.append(x_shape[i])
  else:
- raise ValueError(f"For MatMul, the x1_shape {x_shape} and x2_shape {y_shape} can not broadcast.")
+ raise ValueError(f"{msg_prefix} 'x_shape[{i}]' should be equal to 1, or the 'y_shape[{i}]' should be equal "
+ f"to 1, or the 'x_shape[{i}]' should be equal to 'y_shape[{i}]', but got "
+ f"'x_shape[{i}]': {x_shape[i]}, 'y_shape[{i}]': {y_shape[i]}.")

  broadcast_shape_front = y_shape[0: y_len - length] if length == x_len else x_shape[0: x_len - length]
  x_broadcast_shape = broadcast_shape_front + tuple(broadcast_shape_back) + x_shape[-2:]

@@ -780,8 +784,9 @@ def get_broadcast_matmul_shape(x_shape, y_shape):


  @constexpr
- def check_col_row_equal(x1_shape, x2_shape, transpose_x1, transpose_x2):
+ def check_col_row_equal(x1_shape, x2_shape, transpose_x1, transpose_x2, prim_name=None):
  """check col and row equal"""
+ msg_prefix = f"For '{prim_name}', the" if prim_name else "The"
  if len(x1_shape) == 1:
  transpose_x1 = False
  x1_shape = (1,) + x1_shape

@@ -793,8 +798,8 @@ def check_col_row_equal(x1_shape, x2_shape, transpose_x1, transpose_x2):
  x1_col = x1_last[not transpose_x1] # x1_col = x1_last[1] if (not transpose_a) else x1_last[0]
  x2_row = x2_last[transpose_x2] # x2_row = x2_last[0] if (not transpose_b) else x2_last[1]
  if x1_col != x2_row:
- raise ValueError('The column of matrix dimensions of x1 should be equal to '
- + f'the row of matrix dimensions of x2, but got {x1_col} and {x2_row}.')
+ raise ValueError(f"{msg_prefix} column of matrix dimensions of 'x1' should be equal to "
+ f"the row of matrix dimensions of 'x2', but got 'x1_col' {x1_col} and 'x2_row' {x2_row}.")


  def matmul_op_select(x1_shape, x2_shape, transpose_x1, transpose_x2):
@@ -62,17 +62,20 @@ class _BatchNorm(Cell):
  super(_BatchNorm, self).__init__()
  validator.check_value_type('num_features', num_features, [int], self.cls_name)
  if num_features < 1:
- raise ValueError("num_features must be at least 1")
+ raise ValueError(f"For '{self.cls_name}', the 'num_features' must be at least 1, but got {num_features}.")

  if momentum < 0 or momentum > 1:
- raise ValueError("momentum should be a number in range [0, 1], but got {}".format(momentum))
+ raise ValueError(f"For '{self.cls_name}', the 'momentum' should be a number in range [0, 1], "
+ f"but got {momentum}.")
  self.input_dims = input_dims
  self.format = validator.check_string(data_format, ['NCHW', 'NHWC'], 'format', self.cls_name)
  if context.get_context("device_target") != "GPU" and self.format == "NHWC":
- raise ValueError("NHWC format only support in GPU target.")
+ raise ValueError(f"For '{self.cls_name}', the 'NHWC' format only support in GPU target, but got device "
+ f"target {context.get_context('device_target')}.")
  self.use_batch_statistics = use_batch_statistics
  if self.use_batch_statistics is not None and not isinstance(self.use_batch_statistics, bool):
- raise ValueError("use_batch_statistics should be a boolean value or None.")
+ raise ValueError(f"For '{self.cls_name}', the 'use_batch_statistics' should be a boolean value or None,"
+ f" but got {use_batch_statistics}.")
  self.num_features = num_features
  self.eps = eps
  self.moving_mean = Parameter(initializer(

@@ -159,11 +162,14 @@ class _BatchNorm(Cell):
  raise NotImplementedError

  def list_group(self, world_rank, group_size):
+ """ Check whether world_rank and group_size are valid. """
  if group_size > get_group_size():
- raise ValueError("group size can not be greater than local rank size, group size is {}, "
- "local_rank_size is {}".format(group_size, get_group_size()))
+ raise ValueError(f"For '{self.cls_name}', the 'group_size' cannot be greater than local rank size, "
+ f"but got 'group_size': {group_size}, local rank size: {get_group_size()}.")
  if len(world_rank) % group_size != 0:
- raise ValueError("please make your group size correct.")
+ raise ValueError(f"For '{self.cls_name}', the dimension of 'world_rank' should be divisible by "
+ f"'group_size', but got the length of 'world_rank': {len(world_rank)}, "
+ f"'group_size': {group_size}.")
  world_rank_list = zip(*(iter(world_rank),) * group_size)
  group_list = [list(i) for i in world_rank_list]
  return group_list

@@ -173,7 +179,8 @@ class _BatchNorm(Cell):
  for rid in itertools.chain(*process_groups):
  validator.check_int_range(rid, 0, rank_size, Rel.INC_LEFT, "rank id in process_groups")
  if rid in seen:
- raise ValueError("rank id in process_groups should not be duplicated.")
+ raise ValueError(f"For '{self.cls_name}', rank id in 'process_groups' should not be duplicated, "
+ f"but got {rid}.")
  seen.add(rid)

  def _create_global_groups(self):

@@ -197,7 +204,7 @@ class _BatchNorm(Cell):
  management.create_group(SYNC_BN_GROUP_NAME, self.process_groups[i])

  def construct(self, x):
- _shape_check_bn(self.shape(x), self.input_dims)
+ _shape_check_bn(self.shape(x), self.input_dims, self.cls_name)
  if self.use_batch_statistics is None:
  if self.training:
  return self.bn_train(x,

@@ -231,29 +238,33 @@ class _BatchNorm(Cell):


  @constexpr
- def _channel_check(channel, num_channel):
+ def _channel_check(channel, num_channel, prim_name=None):
+ msg_prefix = f"For '{prim_name}', the" if prim_name else "The"
  if channel != num_channel:
- raise ValueError("the input channel is not equal with num_channel")
+ raise ValueError(f"{msg_prefix} channel should be equal with num_channel, but got channel: "
+ f"{channel}, num_channel: {num_channel}.")


  @constexpr
- def _shape_check(in_shape):
+ def _shape_check(in_shape, prim_name=None):
+ msg_prefix = f"For '{prim_name}', the" if prim_name else "The"
  if len(in_shape) != 4:
- raise ValueError("The input must has 4 dims.")
+ raise ValueError(f"{msg_prefix} in_shape must has 4 dims, but got the length of in_shape: {len(in_shape)}.")


  @constexpr
- def _shape_check_bn(in_shape, in_dims):
+ def _shape_check_bn(in_shape, in_dims, prim_name=None):
  """check input dims of batch norm."""
+ msg_prefix = f"For '{prim_name}', the" if prim_name else "The"
  dim = len(in_shape)
  if in_dims == '1d' and dim != 2:
- raise ValueError("The input must has 2 dims.")
+ raise ValueError(f"{msg_prefix} in_shape must have 2 dims, but got {len(in_shape)}.")
  if in_dims == '2d' and dim != 4:
- raise ValueError("The input must has 4 dims.")
+ raise ValueError(f"{msg_prefix} in_shape must have 4 dims, but got {len(in_shape)}.")
  if in_dims == '3d' and dim != 5:
- raise ValueError("The input must has 5 dims.")
+ raise ValueError(f"{msg_prefix} in_shape must have 5 dims, but got {len(in_shape)}.")
  if in_dims == 'both' and dim != 2 and dim != 4:
- raise ValueError("The input must has 2 dims or 4 dims.")
+ raise ValueError(f"{msg_prefix} in_shape must have 2 dims or 4 dims, but got {len(in_shape)}.")


  @constexpr
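These shape checkers are @constexpr helpers, so they run while the graph is being compiled; the calling Cell has to forward its own name explicitly, for example _shape_check_bn(self.shape(x), self.input_dims, self.cls_name) in construct above. A small illustrative sketch of that call-site pattern, with made-up names and assuming MindSpore is installed:

    from mindspore.ops import constexpr

    @constexpr
    def _check_rank(shape, expected, prim_name=None):
        # Evaluated at graph-compile time with a constant shape tuple.
        msg_prefix = f"For '{prim_name}', the" if prim_name else "The"
        if len(shape) != expected:
            raise ValueError(f"{msg_prefix} input must have {expected} dims, but got {len(shape)}.")

    # Inside a Cell.construct, the layer forwards its class name:
    #     _check_rank(self.shape(x), 4, self.cls_name)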
@@ -472,9 +483,11 @@ class BatchNorm2d(_BatchNorm):


  @constexpr
- def _check_3d_shape(input_shape):
+ def _check_3d_shape(input_shape, prim_name=None):
+ msg_prefix = f"For '{prim_name}', the" if prim_name else "The"
  if len(input_shape) != 5:
- raise ValueError("For BatchNorm3d, input data must be 5-dimensional.")
+ raise ValueError(f"{msg_prefix} input_shape must be 5-dimensional, but got the length of input_shape: "
+ f"{len(input_shape)}.")


  class BatchNorm3d(Cell):

@@ -570,7 +583,7 @@ class BatchNorm3d(Cell):

  def construct(self, input_x):
  x_shape = F.shape(input_x)
- _check_3d_shape(x_shape)
+ _check_3d_shape(x_shape, self.cls_name)
  input_x = self.reshape(input_x, (x_shape[0], x_shape[1], x_shape[2] * x_shape[3], x_shape[4]))
  bn2d_out = self.bn2d(input_x)
  bn3d_out = self.reshape(bn2d_out, x_shape)

@@ -686,7 +699,8 @@ class GlobalBatchNorm(_BatchNorm):
  input_dims='both')
  self.group_device_num = validator.check_positive_int(device_num_each_group)
  if self.group_device_num <= 1:
- raise ValueError("the number of group must be greater than 1.")
+ raise ValueError(f"For '{self.cls_name}', the 'device_num_each_group' must be greater than 1, "
+ f"but got {self.group_device_num}.")

  def _check_data_dim(self, x):
  if x.dim == 0:

@@ -874,8 +888,8 @@ class LayerNorm(Cell):
  """Initialize LayerNorm."""
  super(LayerNorm, self).__init__()
  if not isinstance(normalized_shape, (tuple, list)):
- raise TypeError("The type of 'normalized_shape' should be tuple[int] or list[int], but '{}' type is {}."
- .format(normalized_shape, type(normalized_shape)))
+ raise TypeError(f"For '{self.cls_name}', the type of 'normalized_shape' should be tuple[int] or list[int], "
+ f"but got '{normalized_shape}' and the type is {type(normalized_shape)}.")
  self.normalized_shape = normalized_shape
  self.begin_norm_axis = begin_norm_axis
  self.begin_params_axis = begin_params_axis

@@ -985,10 +999,11 @@ class InstanceNorm2d(Cell):
  args_input = {"gamma_init": gamma_init, "beta_init": beta_init}
  self.check_types_valid(args_input, 'InstanceNorm2d')
  if num_features < 1:
- raise ValueError("num_features must be at least 1")
+ raise ValueError(f"For '{self.cls_name}', the 'num_features' must be at least 1, but got {num_features}.")

  if momentum < 0 or momentum > 1:
- raise ValueError("momentum should be a number in range [0, 1], but got {}".format(momentum))
+ raise ValueError(f"For '{self.cls_name}', the 'momentum' should be a number in range [0, 1], "
+ f"but got {momentum}.")
  self.num_features = num_features
  self.eps = eps
  self.input_dims = '2d'

@@ -1007,7 +1022,7 @@ class InstanceNorm2d(Cell):
  raise NotImplementedError

  def construct(self, x):
- _shape_check_bn(self.shape(x), self.input_dims)
+ _shape_check_bn(self.shape(x), self.input_dims, self.cls_name)
  return self.instance_bn(x,
  self.gamma,
  self.beta,

@@ -1022,10 +1037,11 @@ class InstanceNorm2d(Cell):
  for key, _ in args_dict.items():
  val = args_dict[key]
  if not isinstance(val, (Tensor, numbers.Number, str, Initializer)):
- raise TypeError(f"[{name}]Supported type for arg {key} is [Tensor, numbers.Number, str, Initializer],"
- f"but got {type(val)}")
+ raise TypeError(f"For '{self.cls_name}', the type of args_dict['{key}'] should be in "
+ f"[Tensor, numbers.Number, str, Initializer], but got type {type(val)}.")
  if isinstance(val, Tensor) and val.dtype != mstype.float32:
- raise TypeError(f"[{name}]The type of arg {key} should be float32, but got {val.dtype}")
+ raise TypeError(f"For '{self.cls_name}', the type of args_dict['{key}'] should be float32, "
+ f"but got {val.dtype}.")


  class GroupNorm(Cell):

@@ -1090,7 +1106,8 @@ class GroupNorm(Cell):
  self.num_groups = validator.check_positive_int(num_groups)
  self.num_channels = validator.check_positive_int(num_channels)
  if num_channels % num_groups != 0:
- raise ValueError("num_channels should be divided by num_groups")
+ raise ValueError(f"For '{self.cls_name}', the 'num_channels' should be divided by 'num_groups', "
+ f"but got 'num_channels': {num_channels}, 'num_groups': {num_groups}.")
  self.eps = validator.check_value_type('eps', eps, (float,), type(self).__name__)
  self.affine = validator.check_bool(affine)


@@ -1112,7 +1129,7 @@ class GroupNorm(Cell):
  def _cal_output(self, x):
  """calculate groupnorm output"""
  batch, channel, height, width = self.shape(x)
- _channel_check(channel, self.num_channels)
+ _channel_check(channel, self.num_channels, self.cls_name)
  x = self.reshape(x, (batch, self.num_groups, -1))
  mean = self.reduce_mean(x, 2)
  var = self.reduce_sum(self.square(x - mean), 2) / (channel * height * width / self.num_groups)

@@ -1123,7 +1140,7 @@ class GroupNorm(Cell):
  return output

  def construct(self, x):
- _shape_check(self.shape(x))
+ _shape_check(self.shape(x), self.cls_name)
  output = self._cal_output(x)
  return output

@@ -32,12 +32,13 @@ class _PoolNd(Cell):
  self.pad_mode = validator.check_string(pad_mode.upper(), ['VALID', 'SAME'], 'pad_mode', self.cls_name)
  self.format = validator.check_string(data_format, ['NCHW', 'NHWC'], 'format', self.cls_name)
  if context.get_context("device_target") != "GPU" and self.format == "NHWC":
- raise ValueError("NHWC format only support in GPU target.")
+ raise ValueError(f"For '{self.cls_name}, the 'NHWC' format only support in GPU target, but got device "
+ f"target {context.get_context('device_target')}.")

  def _check_int_or_tuple(arg_name, arg_value):
  validator.check_value_type(arg_name, arg_value, [int, tuple], self.cls_name)
- error_msg = f'For \'{self.cls_name}\' the {arg_name} should be an positive int number or ' \
- f'a tuple of two positive int numbers, but got {arg_value}'
+ error_msg = f"For '{self.cls_name}', the '{arg_name}' should be an positive int number or " \
+ f"a tuple of two positive int numbers, but got {arg_value}"
  if isinstance(arg_value, int):
  if arg_value <= 0:
  raise ValueError(error_msg)

@@ -61,9 +62,10 @@ class _PoolNd(Cell):


  @constexpr
- def _shape_check(in_shape):
+ def _shape_check(in_shape, prim_name=None):
+ msg_prefix = f"For '{prim_name}', the" if prim_name else "The"
  if len(in_shape) != 3:
- raise ValueError("The input must has 3 dim")
+ raise ValueError(f"{msg_prefix} input must has 3 dim, but got {len(in_shape)}")


  class MaxPool2d(_PoolNd):

@@ -216,7 +218,7 @@ class MaxPool1d(_PoolNd):
  self.squeeze = P.Squeeze(2)

  def construct(self, x):
- _shape_check(self.shape(x))
+ _shape_check(self.shape(x), self.cls_name)
  x = self.expand(x, 2)
  output = self.max_pool(x)
  output = self.squeeze(output)

@@ -382,7 +384,7 @@ class AvgPool1d(_PoolNd):
  self.squeeze = P.Squeeze(2)

  def construct(self, x):
- x = F.depend(x, _shape_check(self.shape(x)))
+ x = F.depend(x, _shape_check(self.shape(x), self.cls_name))
  batch, channel, width = self.shape(x)
  if width == self.kernel_size[1]:
  x = self.reduce_mean(x, 2)
@@ -408,13 +408,15 @@ class FakeQuantWithMinMaxObserver(UniformQuantObserver):
  min_array = self._get_init_array(self.min_init)
  max_array = self._get_init_array(self.max_init)
  if not np.greater(max_array, min_array).all():
- raise ValueError("`min_init` is not less than `max_init`, please reset the initial value.")
+ raise ValueError(f"For '{self.cls_name}', the 'max_init' should be greater than 'min_init', "
+ f"but got 'max_array': {max_array}, 'min_init': {min_init}.")
  if self.mode == "DEFAULT":
  self._default_init(min_array, max_array)
  elif self.mode == "LEARNED_SCALE":
  self._learned_scale_init(min_array, max_array)
  else:
- raise ValueError("Invalid mode, currently only valid for `DEFAULT` and `LEARNED_SCALE` mode.")
+ raise ValueError(f"For '{self.cls_name}', only `DEFAULT` and `LEARNED_SCALE` mode are valid, but got "
+ f"'mode': {self.mode}.")

  def reset(self, quant_dtype=QuantDtype.INT8, min_init=-6, max_init=6):
  r"""

@@ -433,12 +435,14 @@ class FakeQuantWithMinMaxObserver(UniformQuantObserver):
  min_array = self._get_init_array(self.min_init)
  max_array = self._get_init_array(self.max_init)
  if not np.greater(max_array, min_array).all():
- raise ValueError("`min_init` is not less than `max_init`, please reset the initial value.")
+ raise ValueError(f"For '{self.cls_name}', the 'max_init' should be greater than 'min_init', "
+ f"but got 'max_array': {max_array}, 'min_init': {min_init}.")

  self.minq.set_data(Tensor(min_array))
  self.maxq.set_data(Tensor(max_array))
  self.quant_max.set_data(Tensor(np.array([self._quant_max]).astype(np.float32)))
  else:
- raise ValueError("The `reset` function is currently only valid for `LEARNED_SCALE` mode.")
+ raise ValueError(f"For '{self.cls_name}', only `LEARNED_SCALE` mode is valid, but got 'mode': {self.mode}.")

  def _default_init(self, min_array, max_array):
  """

@@ -479,16 +483,18 @@ class FakeQuantWithMinMaxObserver(UniformQuantObserver):
  Initialization of `LEARNED_SCALE` mode.
  """
  if not self.symmetric:
- raise ValueError("The 'LEARNED_SCALE' mode only support symmetric quant, "
- "please set symmetric to True.")
+ raise ValueError(f"For '{self.cls_name}', the 'LEARNED_SCALE' mode only support 'symmetric' quant, "
+ f"but got 'symmetric': {self.symmetric}. Please set 'symmetric' to True.")
  if self.neg_trunc:
  min_array = self._get_init_array(0)
  if self.narrow_range:
- raise ValueError("The 'LEARNED_SCALE' mode only support the combination of "
- "neg_trunc=True and narrow_range=False config scenario.")
+ raise ValueError(f"For '{self.cls_name}', the 'LEARNED_SCALE' mode only support the combination of "
+ f"'neg_trunc=True and narrow_range=False' config scenario, but got 'narrow_range': "
+ f"{self.narrow_range}.")
  elif not self.narrow_range:
- raise ValueError("The 'LEARNED_SCALE' mode only support narrow_range=True config, "
- "except for neg_trunc=True scenario.")
+ raise ValueError(f"For '{self.cls_name}', the 'LEARNED_SCALE' mode only support 'narrow_range=True' "
+ f"config, except for 'neg_trunc=True' scenario. But got 'narrow_range': "
+ f"{self.narrow_range}.")

  self._calculate_quant_max()


@@ -514,11 +520,11 @@ class FakeQuantWithMinMaxObserver(UniformQuantObserver):
  Convert the initial value to array.
  """
  if isinstance(init_date, list) and self.per_channel and len(init_date) != self.num_channels:
- raise ValueError("The length of the min_init/max_init list should be equal to num_channels for "
- "perchannel quant scenario, but get {}".format(len(init_date))
+ raise ValueError(f"For '{self.cls_name}', the length of 'min_init/max_init' list should be equal to "
+ f"'num_channels' for perchannel quant scenario, but got {len(init_date)}.")
  if isinstance(init_date, list) and not self.per_channel and len(init_date) != 1:
- raise ValueError("The length of the min_init/max_init list should be 1 for perlayer quant "
- "scenario, but get {}".format(len(init_date)))
+ raise ValueError(f"For '{self.cls_name}', the length of the 'min_init/max_init' list should be 1 for "
+ f"perlayer quant scenario, but got {len(init_date)}.")

  if isinstance(init_date, list):
  min_max_array = np.array(init_date).astype(np.float32)
@@ -690,8 +696,8 @@ class Conv2dBnFoldQuantOneConv(Cell):
  for dilation_elem in self.dilation:
  Validator.check_positive_int(dilation_elem, 'dilation item', self.cls_name)
  if pad_mode not in ('valid', 'same', 'pad'):
- raise ValueError('Attr \'pad_mode\' of \'Conv2dBnFoldQuant\' Op passed '
- + str(pad_mode) + ', should be one of values in \'valid\', \'same\', \'pad\'.')
+ raise ValueError(f"For '{self.cls_name}', the 'pad_mode' should be one of values "
+ f"in ('valid', 'same', 'pad'), but got {pad_mode}.")
  self.pad_mode = pad_mode
  if isinstance(padding, int):
  Validator.check_non_negative_int(padding, 'padding', self.cls_name)

@@ -701,7 +707,8 @@
  Validator.check_non_negative_int(pad, 'padding item', self.cls_name)
  self.padding = padding
  else:
- raise TypeError("padding type must be int/tuple(int) cannot be {}!".format(type(padding)))
+ raise TypeError(f"For '{self.cls_name}', the type of 'padding' must be int/tuple(int), but got "
+ f"{type(padding)}!")
  self.group = Validator.check_positive_int(group)
  self.eps = eps
  self.momentum = 1 - momentum

@@ -931,8 +938,8 @@ class Conv2dBnFoldQuant(Cell):
  for dilation_elem in self.dilation:
  Validator.check_positive_int(dilation_elem, 'dilation item', self.cls_name)
  if pad_mode not in ('valid', 'same', 'pad'):
- raise ValueError('Attr \'pad_mode\' of \'Conv2dBnFoldQuant\' Op passed '
- + str(pad_mode) + ', should be one of values in \'valid\', \'same\', \'pad\'.')
+ raise ValueError(f"For '{self.cls_name}', the 'pad_mode' should be one of values in "
+ f"('valid', 'same', 'pad'), but got {pad_mode}.")
  self.pad_mode = pad_mode
  if isinstance(padding, int):
  Validator.check_non_negative_int(padding, 'padding', self.cls_name)

@@ -942,7 +949,8 @@
  Validator.check_non_negative_int(pad, 'padding item', self.cls_name)
  self.padding = padding
  else:
- raise TypeError("padding type must be int/tuple(int) cannot be {}!".format(type(padding)))
+ raise TypeError(f"For '{self.cls_name}', the type of 'padding' must be int/tuple(int), "
+ f"but got {type(padding)}!")
  self.group = Validator.check_positive_int(group)
  self.eps = eps
  self.momentum = momentum

@@ -990,7 +998,8 @@
  self.batchnorm_fold2_train = Q.BatchNormFold2(freeze_bn=freeze_bn)
  self.batchnorm_fold2_infer = Q.BatchNormFold2(freeze_bn=0)
  else:
- raise ValueError("Unsupported platform: {}".format(context.get_context('device_target')))
+ raise ValueError(f"For '{self.cls_name}', only the 'Ascend' and 'GPU' platforms"
+ f" are supported, but got {context.get_context('device_target')}.")
  self.step = Parameter(initializer('normal', [1], dtype=mstype.int32), name='step', requires_grad=False)
  self.one = Tensor(1, mstype.int32)
  self.assignadd = P.AssignAdd()

@@ -1141,8 +1150,8 @@ class Conv2dBnWithoutFoldQuant(Cell):
  for dilation_elem in self.dilation:
  Validator.check_positive_int(dilation_elem, 'dilation item', self.cls_name)
  if pad_mode not in ('valid', 'same', 'pad'):
- raise ValueError('Attr \'pad_mode\' of \'Conv2dBnWithoutFoldQuant\' Op passed '
- + str(pad_mode) + ', should be one of values in \'valid\', \'same\', \'pad\'.')
+ raise ValueError(f"For '{self.cls_name}', the 'pad_mode' should be one of values in "
+ f"('valid', 'same', 'pad'), but got {pad_mode}.")
  self.pad_mode = pad_mode
  if isinstance(padding, int):
  Validator.check_non_negative_int(padding, 'padding', self.cls_name)

@@ -1152,9 +1161,9 @@
  Validator.check_non_negative_int(pad, 'padding item', self.cls_name)
  self.padding = padding
  else:
- raise TypeError("padding type must be int/tuple(int) cannot be {}!".format(type(padding)))
+ raise TypeError(f"For '{self.cls_name}', the type of 'padding' must be int/tuple(int), "
+ f"but got {type(padding)}!")
  self.group = Validator.check_positive_int(group)

  self.bias_add = P.BiasAdd()
  if Validator.check_bool(has_bias):
  self.bias = Parameter(initializer(bias_init, [out_channels]), name='bias')

@@ -1285,8 +1294,8 @@ class Conv2dQuant(Cell):
  for dilation_elem in self.dilation:
  Validator.check_positive_int(dilation_elem, 'dilation item', self.cls_name)
  if pad_mode not in ('valid', 'same', 'pad'):
- raise ValueError('Attr \'pad_mode\' of \'Conv2dQuant\' Op passed '
- + str(pad_mode) + ', should be one of values in \'valid\', \'same\', \'pad\'.')
+ raise ValueError(f"For '{self.cls_name}', the 'pad_mode' should be one of values "
+ f"in ('valid', 'same', 'pad'), but got {pad_mode}.")
  self.pad_mode = pad_mode
  if isinstance(padding, int):
  Validator.check_non_negative_int(padding, 'padding', self.cls_name)

@@ -1296,7 +1305,8 @@
  Validator.check_non_negative_int(pad, 'padding item', self.cls_name)
  self.padding = padding
  else:
- raise TypeError("padding type must be int/tuple(int) cannot be {}!".format(type(padding)))
+ raise TypeError(f"For '{self.cls_name}', the type of 'padding' must be int/tuple(int), "
+ f"but got {type(padding)}!")
  self.group = Validator.check_positive_int(group)

  weight_shape = [out_channels, in_channels // group, *self.kernel_size]

@@ -1414,7 +1424,10 @@ class DenseQuant(Cell):
  if isinstance(weight_init, Tensor):
  if weight_init.ndim != 2 or weight_init.shape[0] != out_channels or \
  weight_init.shape[1] != in_channels:
- raise ValueError("weight_init shape error")
+ raise ValueError(f"For '{self.cls_name}', weight init shape error. The ndim of 'weight_init' should "
+ f"be equal to 2, and the first dim should be equal to 'out_channels', and the "
+ f"second dim should be equal to 'in_channels'. But got 'weight_init': {weight_init}, "
+ f"'out_channels': {out_channels}, 'in_channels': {in_channels}.")

  self.weight = Parameter(initializer(
  weight_init, [out_channels, in_channels]), name="weight")

@@ -1422,7 +1435,9 @@
  if self.has_bias:
  if isinstance(bias_init, Tensor):
  if bias_init.ndim != 1 or bias_init.shape[0] != out_channels:
- raise ValueError("bias_init shape error")
+ raise ValueError(f"For '{self.cls_name}', bias init shape error. The ndim of 'bias_init' should "
+ f"be equal to 1, and the first dim should be equal to 'out_channels'. But got "
+ f"'bias_init': {bias_init}, 'out_channels': {out_channels}.")

  self.bias = Parameter(initializer(
  bias_init, [out_channels]), name="bias")

@@ -1432,7 +1447,9 @@

  self.activation = get_activation(activation) if isinstance(activation, str) else activation
  if activation is not None and not isinstance(self.activation, (Cell, Primitive)):
- raise TypeError("The activation must be str or Cell or Primitive,"" but got {}.".format(activation))
+ raise TypeError(f"For '{self.cls_name}', the 'activation' must be str or Cell or Primitive, "
+ f"but got {activation}.")

  self.activation_flag = self.activation is not None
  self.fake_quant_weight = quant_config.weight(ema=False,
  channel_axis=0,
@@ -110,7 +110,8 @@ class _DynamicRNN(Cell):
elif mode == "GRU":
cell = _gru_cell
else:
- raise ValueError("Unrecognized RNN mode: " + mode)
+ raise ValueError(f"For '{self.cls_name}', the 'mode' should be in ['RNN_RELU', 'RNN_TANH', 'LSTM', 'GRU'], "
+ f"but got {mode}.")
self.cell = cell
self.is_lstm = mode == "LSTM"

@@ -187,10 +188,8 @@ class _RNNBase(Cell):
validator.check_value_type("bidirectional", bidirectional, [bool], self.cls_name)

if not 0 <= dropout < 1:
- raise ValueError(f"For '{self.cls_name}', "
- "dropout should be a number in range [0, 1) "
- "representing the probability of an element being "
- "zeroed")
+ raise ValueError(f"For '{self.cls_name}', the 'dropout' should be a number in range [0, 1) "
+ f"representing the probability of an element being zeroed, but got {dropout}.")

if dropout > 0 and num_layers == 1:
logger.warning("dropout option adds dropout after all but last "

@@ -206,7 +205,8 @@ class _RNNBase(Cell):
elif mode == "RNN_RELU":
gate_size = hidden_size
else:
- raise ValueError("Unrecognized RNN mode: " + mode)
+ raise ValueError(f"For '{self.cls_name}', the 'mode' should be in ['RNN_RELU', 'RNN_TANH', 'LSTM', 'GRU'], "
+ f"but got {mode}.")

self.reverse = P.ReverseV2([0])
self.reverse_sequence = P.ReverseSequence(0, 1)

@@ -424,8 +424,8 @@ class RNN(_RNNBase):
elif kwargs['nonlinearity'] == 'relu':
mode = 'RNN_RELU'
else:
- raise ValueError("Unknown nonlinearity '{}'".format(
- kwargs['nonlinearity']))
+ raise ValueError(f"For '{self.cls_name}', the 'nonlinearity' should be in ['tanh', 'relu'], "
+ f"but got {kwargs['nonlinearity']}.")
del kwargs['nonlinearity']
else:
mode = 'RNN_TANH'

@@ -583,7 +583,8 @@ class RNNCell(_RNNCellBase):
def __init__(self, input_size: int, hidden_size: int, has_bias: bool = True, nonlinearity: str = "tanh"):
super().__init__(input_size, hidden_size, has_bias, num_chunks=1)
if nonlinearity not in self._non_linearity:
- raise ValueError("Unknown nonlinearity: {}".format(nonlinearity))
+ raise ValueError(f"For '{self.cls_name}', the 'nonlinearity' should be in ['tanh', 'relu'], "
+ f"but got {nonlinearity}.")
self.nonlinearity = nonlinearity

def construct(self, x, hx):
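The RNN changes collapse the old four-part dropout message into two f-strings that also report the bad value. A self-contained sketch of that check, with a stand-in for self.cls_name:

    def check_dropout(dropout, cls_name="_RNNBase"):
        # cls_name stands in for self.cls_name of the real Cell; the check itself mirrors the diff above.
        if not 0 <= dropout < 1:
            raise ValueError(f"For '{cls_name}', the 'dropout' should be a number in range [0, 1) "
                             f"representing the probability of an element being zeroed, but got {dropout}.")

    check_dropout(0.5)      # ok
    # check_dropout(1.2)    # ValueError: For '_RNNBase', the 'dropout' should be a number in range [0, 1) ...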
@@ -98,13 +98,18 @@ class DenseThor(Cell):
if isinstance(weight_init, Tensor):
if weight_init.dim() != 2 or weight_init.shape[0] != out_channels or \
weight_init.shape[1] != in_channels:
- raise ValueError("Weight init shape error.")
+ raise ValueError(f"For '{self.cls_name}', weight init shape error. The dim of 'weight_init' should "
+ f"be equal to 2, and the first dim should be equal to 'out_channels', and the "
+ f"second dim should be equal to 'in_channels'. But got 'weight_init': {weight_init}, "
+ f"'out_channels': {out_channels}, 'in_channels': {in_channels}.")
self.weight = Parameter(initializer(weight_init, [out_channels, in_channels]), name="weight")
self.bias = None
if self.has_bias:
if isinstance(bias_init, Tensor):
if bias_init.dim() != 1 or bias_init.shape[0] != out_channels:
- raise ValueError("Bias init shape error.")
+ raise ValueError(f"For '{self.cls_name}', bias init shape error. The dim of 'bias_init' should "
+ f"be equal to 1, and the first dim should be equal to 'out_channels'. But got "
+ f"'bias_init': {bias_init}, 'out_channels': {out_channels}.")
self.bias = Parameter(initializer(bias_init, [out_channels]), name="bias")
self.bias_add = P.BiasAdd()
@@ -211,7 +216,8 @@ class _ConvThor(Cell):
Validator.check_non_negative_int(padding, 'padding', self.cls_name)
self.padding = padding
else:
- raise TypeError("padding type must be int or tuple(int) cannot be {}!".format(type(padding)))
+ raise TypeError(f"For '{self.cls_name}', the type of 'padding' must be int/tuple(int), but got "
+ f"{type(padding)}.")

self.dilation = dilation
self.group = Validator.check_positive_int(group)

@@ -220,11 +226,11 @@ class _ConvThor(Cell):
self.__validate_stride(stride)
self.__validate_dilation(dilation)
if in_channels % group != 0:
- raise ValueError("Attr 'in_channels' of 'Conv2DThor' Op must be divisible by "
- "attr 'group' of 'Conv2DThor' Op.")
+ raise ValueError(f"For '{self.cls_name}', the 'in_channels' must be divisible by 'group', but got "
+ f"'in_channels': {in_channels} and 'group': {group}.")
if out_channels % group != 0:
- raise ValueError("Attr 'out_channels' of 'Conv2DThor' Op must be divisible by "
- "attr 'group' of 'Conv2DThor' Op.")
+ raise ValueError(f"For '{self.cls_name}', the 'out_channels' must be divisible by 'group', but got "
+ f"'out_channels': {out_channels} and 'group': {group}.")
if not transposed:
shape = [out_channels, in_channels // group, *kernel_size]
else:

@@ -243,22 +249,22 @@ class _ConvThor(Cell):
if (not isinstance(kernel_size[0], int)) or (not isinstance(kernel_size[1], int)) or \
isinstance(kernel_size[0], bool) or isinstance(kernel_size[1], bool) or \
kernel_size[0] < 1 or kernel_size[1] < 1:
- raise ValueError("Attr 'kernel_size' of 'Conv2D' Op passed "
- + str(self.kernel_size) + ", should be a int or tuple and equal to or greater than 1.")
+ raise ValueError(f"For '{self.cls_name}', all elements in 'kernel_size' should be int or tuple and "
+ f"equal to or greater than 1, but got 'kernel_size': {kernel_size}.")

def __validate_stride(self, stride):
"""validate stride."""
if (not isinstance(stride[0], int)) or (not isinstance(stride[1], int)) or \
isinstance(stride[0], bool) or isinstance(stride[1], bool) or stride[0] < 1 or stride[1] < 1:
- raise ValueError("Attr 'stride' of 'Conv2D' Op passed "
- + str(self.stride) + ", should be a int or tuple and equal to or greater than 1.")
+ raise ValueError(f"For '{self.cls_name}', all elements in 'stride' should be int or tuple and "
+ f"equal to or greater than 1, but got 'stride': {stride}.")

def __validate_dilation(self, dilation):
"""validate dilation."""
if (not isinstance(dilation[0], int)) or (not isinstance(dilation[1], int)) or \
isinstance(dilation[0], bool) or isinstance(dilation[1], bool) or dilation[0] < 1 or dilation[1] < 1:
- raise ValueError("Attr 'dilation' of 'Conv2D' Op passed "
- + str(self.dilation) + ", should be a int or tuple and equal to or greater than 1.")
+ raise ValueError(f"For '{self.cls_name}', all elements in 'dilation' should be int or tuple and "
+ f"equal to or greater than 1, but got 'dilation': {dilation}.")


class Conv2dThor(_ConvThor):
@@ -420,8 +426,8 @@ class Conv2dThor(_ConvThor):
"""Initialize depthwise conv2d op"""
if context.get_context("device_target") == "Ascend" and self.group > 1:
self.dilation = self._dilation
- Validator.check_integer('group', self.group, self.in_channels, Rel.EQ)
- Validator.check_integer('group', self.group, self.out_channels, Rel.EQ)
+ Validator.check_int('group', self.group, self.in_channels, Rel.EQ)
+ Validator.check_int('group', self.group, self.out_channels, Rel.EQ)
self.conv2d = P.DepthwiseConv2dNative(channel_multiplier=1,
kernel_size=self.kernel_size,
pad_mode=self.pad_mode,
@@ -724,10 +730,11 @@ class EmbeddingLookupThor(Cell):
self.forward_unique = False
self.dtype = mstype.float16
if target not in ('CPU', 'DEVICE'):
- raise ValueError('Attr \'target\' of \'EmbeddingLookup\' Op passed '
- + str(target) + ', should be one of values in \'CPU\', \'DEVICE\'.')
+ raise ValueError(f"For '{self.cls_name}', the 'target' should be one of values in ('CPU', 'DEVICE'), "
+ f"but got {target}.")
if not sparse and target == 'CPU':
- raise ValueError('When target is CPU, embedding_lookup must be sparse.')
+ raise ValueError(f"For '{self.cls_name}', embedding_lookup must be sparse when 'target' is CPU, but got "
+ f"'sparse': {sparse}, 'target': {target}.")
if sparse:
self.gatherv2 = P.SparseGatherV2()
else:

@@ -755,9 +762,11 @@ class EmbeddingLookupThor(Cell):
indices_shape_size = 2
if slice_mode == "field_slice" and is_auto_parallel:
if not manual_shapes:
- raise ValueError("in slice field mode, the manual_shapes should not be none")
+ raise ValueError(f"For '{self.cls_name}', the 'manual_shapes' should not be none "
+ f"when 'slice_mode' is 'field_slice'.")
if not isinstance(manual_shapes, tuple):
- raise TypeError("manual_shapes type must be tuple(int) cannot be {}!".format(type(manual_shapes)))
+ raise TypeError(f"For '{self.cls_name}', the type of 'manual_shapes' must be tuple(int), but got "
+ f"type {type(manual_shapes)}.")
for dim in manual_shapes:
Validator.check_positive_int(dim, 'manual shape dim', self.cls_name)
self.gatherv2.add_prim_attr("manual_split", manual_shapes)

@@ -789,11 +798,13 @@ class EmbeddingLookupThor(Cell):
self.embeddinglookup.shard(((1, 1), indices_strategy))
else:
if is_auto_parallel:
- raise ValueError("slice_mode should support mode in nn.EmbeddingLookup, but get "
- + str(slice_mode))
+ raise ValueError(f"For '{self.cls_name}', the 'slice_mode' should be one of values in "
+ f"['field_slice', 'table_row_slice', 'table_column_slice', 'batch_slice'], "
+ f"but got 'slice_mode': {slice_mode}")
if self.cache_enable and not enable_ps:
if parallel_mode != ParallelMode.STAND_ALONE:
- raise ValueError("parallel mode haven't supported cache enable yet.")
+ raise ValueError(f"For '{self.cls_name}', the 'parallel_mode' should be equal to "
+ f"'ParallelMode.STAND_ALONE', but got {parallel_mode}.")
self._set_cache_enable()
self.embedding_table.unique = self.forward_unique
self.max_norm = max_norm
@@ -834,11 +845,14 @@ class EmbeddingLookupThor(Cell):
def _set_cache_enable(self):
"""EmbeddingLookup cache check for not ps env, which is only support 'ascend'."""
if self.target != 'DEVICE':
- raise ValueError("The configuration of 'vocab_cache_size' is valid only in 'DEVICE' target.")
+ raise ValueError(f"For '{self.cls_name}', the configuration of 'vocab_cache_size' is valid "
+ f"only when 'target' is 'DEVICE', but got 'target': {self.target}.")
if not self.sparse:
- raise ValueError("The configuration of 'vocab_cache_size' is valid only 'sparse' is true.")
+ raise ValueError(f"For '{self.cls_name}', the configuration of 'vocab_cache_size' is valid "
+ f"only when 'sparse' is true, but got 'sparse': {self.sparse}.")
if context.get_context("device_target") != 'Ascend':
- raise ValueError("The configuration of 'vocab_cache_size' is valid only in 'ascend'.")
+ raise ValueError(f"For '{self.cls_name}', the configuration of 'vocab_cache_size' is valid "
+ f"only when 'device_target' is 'Ascend', but got {context.get_context('device_target')}.")

logger.info("EmbeddingLookup cache enable takes effect.")
self.forward_unique = True

@@ -869,17 +883,18 @@ class EmbeddingLookupThor(Cell):
rank_id = get_rank()
full_batch = _get_full_batch()
if rank_size > 1 and not (full_batch and slice_mode == "table_row_slice"):
- raise ValueError("The embeddingLookup cache of parameter server parallel only be used "
- "in 'full_batch' and 'table_row_slice' parallel strategy.")
+ raise ValueError(f"For '{self.cls_name}', the embeddingLookup cache of parameter server parallel "
+ f"only be used in 'full_batch' and 'table_row_slice' parallel strategy, but got "
+ f"'full_batch': {full_batch}, 'slice_mode': {slice_mode}.")
self.vocab_cache_size = self.vocab_cache_size * rank_size
_set_rank_id(rank_id)
self.cache_enable = True
if _is_role_worker():
self.vocab_size = self.vocab_cache_size
if context.get_context("enable_sparse") != self.sparse:
- raise ValueError("The value of parameter 'sparse' must be same for all EmbeddingLookup "
- "kernels and equal the value of 'enable_sparse' in context setting in "
- "parameter server cache mode")
+ raise ValueError(f"For '{self.cls_name}', the 'sparse' must be equal to the 'enable_sparse' "
+ f"in context setting in parameter server cache mode, but got 'sparse': "
+ f"{self.sparse}, 'enable_sparse': {context.get_context('enable_sparse')}.")

def _set_voacb_cache_enable_for_ps(self, vocab_cache_size, embedding_size, vocab_size):
"""PS embeddingLookup cache enable set."""
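The EmbeddingLookupThor hunks above tie the target and sparse options together in the new messages. A hypothetical, framework-free sketch of those two guards and how they interact:

    def check_lookup_config(target, sparse, cls_name="EmbeddingLookupThor"):
        # Illustrative only; reproduces the two checks from the diff outside of MindSpore.
        if target not in ('CPU', 'DEVICE'):
            raise ValueError(f"For '{cls_name}', the 'target' should be one of values in ('CPU', 'DEVICE'), "
                             f"but got {target}.")
        if not sparse and target == 'CPU':
            raise ValueError(f"For '{cls_name}', embedding_lookup must be sparse when 'target' is CPU, but got "
                             f"'sparse': {sparse}, 'target': {target}.")

    check_lookup_config('DEVICE', sparse=False)   # ok
    # check_lookup_config('CPU', sparse=False)    # raises: lookup must be sparse on CPU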
@@ -24,17 +24,22 @@ __all__ = ['TimeDistributed']

@constexpr
- def _check_reshape_pos(reshape_pos, inputs_shape, outputs_shape):
+ def _check_reshape_pos(reshape_pos, inputs_shape, outputs_shape, prim_name=None):
+ msg_prefix = f"For '{prim_name}', the" if prim_name else "The"
if reshape_pos >= len(outputs_shape) or inputs_shape[reshape_pos] != outputs_shape[reshape_pos]:
- raise ValueError("The parameter reshape_with_axis is invalid in the input and output of TimeDistributed. "
- "You may try pass parameters without reshape_with_axis.")
+ raise ValueError(f"{msg_prefix} 'reshape_with_axis' is invalid in the input and output. "
+ f"The 'reshape_pos' should be less than the length of 'outputs_shape', and the "
+ f"'inputs_shape[reshape_pos]' should be equal to 'outputs_shape[reshape_pos]', but got "
+ f"'reshape_pos': {reshape_pos}, 'inputs_shape': {inputs_shape}, 'outputs_shape': "
+ f"{outputs_shape}. You may try pass parameters without 'reshape_with_axis'.")

@constexpr
- def _check_expand_dims_axis(time_axis, ndim):
+ def _check_expand_dims_axis(time_axis, ndim, prim_name=None):
+ msg_prefix = f"For '{prim_name}', the" if prim_name else "The"
if time_axis > ndim:
- raise ValueError("The parameter time_axis is invalid in the input. "
- "The value of time_axis should be in range of [{}, {}].".format(-ndim - 1, ndim))
+ raise ValueError(f"{msg_prefix} value of 'time_axis' should be in range of [{-ndim - 1}, {ndim}], "
+ f"but got {time_axis}.")

@constexpr

@@ -45,15 +50,17 @@ def _generate_perm(axis_a, axis_b, length):

@constexpr
- def _check_data(flag):
+ def _check_data(flag, prim_name=None):
+ msg_prefix = f"For '{prim_name}', the" if prim_name else "The"
if not flag:
- raise TypeError("The inputs and outputs should be a Tensor.")
+ raise TypeError(f"{msg_prefix} inputs and outputs should be a Tensor.")

@constexpr
- def _check_inputs_dim(shape):
+ def _check_inputs_dim(shape, prim_name=None):
+ msg_prefix = f"For '{prim_name}', the" if prim_name else "The"
if len(shape) < 3:
- raise ValueError("The inputs should be at least 3D.")
+ raise ValueError(f"{msg_prefix} inputs shape should be at least 3D, but got {len(shape)}.")

class TimeDistributed(Cell):
@@ -97,8 +104,8 @@ class TimeDistributed(Cell):
def __init__(self, layer, time_axis, reshape_with_axis=None):
"""Initialize TimeDistributed."""
if not isinstance(layer, (Cell, Primitive)):
- raise TypeError("Please initialize TimeDistributed with mindspore.nn.Cell or "
- "mindspore.ops.Primitive instance. You passed: {input}".format(input=layer))
+ raise TypeError(f"For '{self.cls_name}', the type of 'layer' should be mindspore.nn.Cell or "
+ f"mindspore.ops.Primitive instance, but got type: {type(layer)}.")
super(TimeDistributed, self).__init__()
Validator.check_is_int(time_axis)
if reshape_with_axis is not None:
@@ -111,7 +118,7 @@ class TimeDistributed(Cell):

def construct(self, inputs):
_check_data(isinstance(inputs, Tensor))
- _check_inputs_dim(inputs.shape)
+ _check_inputs_dim(inputs.shape, self.cls_name)
time_axis = self.time_axis % len(inputs.shape)
if self.reshape_with_axis is not None:
reshape_with_axis = self.reshape_with_axis % len(inputs.shape)

@@ -126,7 +133,7 @@ class TimeDistributed(Cell):
inputs = self.reshape(inputs, inputs_shape_new[: reshape_pos] + (-1,) + inputs_shape_new[reshape_pos + 2:])
outputs = self.layer(inputs)
_check_data(isinstance(outputs, Tensor))
- _check_reshape_pos(reshape_pos, inputs.shape, outputs.shape)
+ _check_reshape_pos(reshape_pos, inputs.shape, outputs.shape, self.cls_name)
outputs_shape_new = outputs.shape[:reshape_pos] + inputs_shape_new[reshape_pos: reshape_pos + 2]
if reshape_pos + 1 < len(outputs.shape):
outputs_shape_new += outputs.shape[reshape_pos + 1:]

@@ -138,7 +145,7 @@ class TimeDistributed(Cell):
for item in inputs:
outputs = self.layer(item)
_check_data(isinstance(outputs, Tensor))
- _check_expand_dims_axis(time_axis, outputs.ndim)
+ _check_expand_dims_axis(time_axis, outputs.ndim, self.cls_name)
y += (outputs,)
y = Stack(time_axis)(y)
return y
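The TimeDistributed helpers introduce the msg_prefix pattern that the rest of this commit reuses: take an optional prim_name, and only prepend "For '<name>', ..." when it is given. A runnable sketch of one such helper (same shape as the code above, minus the @constexpr decorator, so it runs as plain Python):

    def _check_inputs_dim(shape, prim_name=None):
        # prim_name is optional so callers that predate this change keep working.
        msg_prefix = f"For '{prim_name}', the" if prim_name else "The"
        if len(shape) < 3:
            raise ValueError(f"{msg_prefix} inputs shape should be at least 3D, but got {len(shape)}.")

    _check_inputs_dim((8, 10, 16), "TimeDistributed")    # ok
    # _check_inputs_dim((8, 10))                         # "The inputs shape should be at least 3D, but got 2."
    # _check_inputs_dim((8, 10), "TimeDistributed")      # "For 'TimeDistributed', the inputs shape ..."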
@@ -116,18 +116,17 @@ tensor_operator_registry.register('repeat_elements', repeat_elements)

@constexpr
- def _check_sequence_mask_input_len(input_shape):
+ def _check_sequence_mask_input_len(input_shape, prim_name=None):
+ msg_prefix = f"For '{prim_name}', the" if prim_name else "The"
if not input_shape:
- raise ValueError(f"Sequence_mask lengths_shape should be > 0. "
- f"Current lengths_shape is {input_shape}.")
+ raise ValueError(f"{msg_prefix} lengths_shape should be greater than 0, but got {input_shape}.")
# broadcast only supports 7d shape
shape_size = len(input_shape)
if shape_size >= 7:
- raise ValueError(f"Sequence_mask lengths_shape's size only support a value less than 7. "
- f"Current lengths_shape is {shape_size}d.")
+ raise ValueError(f"{msg_prefix} size of lengths_shape should be less than 7, but got {shape_size}d.")

- def sequence_mask(lengths, maxlen=None):
+ def sequence_mask(lengths, maxlen=None, prim_name='sequence_mask'):
"""
Returns a mask tensor representing the first N positions of each cell.

@@ -188,7 +187,7 @@ def sequence_mask(lengths, maxlen=None):
to_tensor_op = P.ScalarToArray()

const_utils.check_type_valid(F.dtype(lengths), [mstype.int64, mstype.int32], 'lengths')
- _check_sequence_mask_input_len(shape_op(lengths))
+ _check_sequence_mask_input_len(shape_op(lengths), prim_name)

if maxlen is None:
flatten_data = reshape_op(lengths, (-1,))
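A small, hypothetical stand-in for _check_sequence_mask_input_len (without @constexpr) showing the two new messages on an ordinary Python tuple:

    def check_lengths_shape(input_shape, prim_name=None):
        # Sketch only; mirrors the lengths_shape checks from the diff above.
        msg_prefix = f"For '{prim_name}', the" if prim_name else "The"
        if not input_shape:
            raise ValueError(f"{msg_prefix} lengths_shape should be greater than 0, but got {input_shape}.")
        if len(input_shape) >= 7:
            raise ValueError(f"{msg_prefix} size of lengths_shape should be less than 7, but got {len(input_shape)}d.")

    check_lengths_shape((4, 2), 'sequence_mask')   # ok
    # check_lengths_shape((), 'sequence_mask')     # scalar lengths are rejected with the prefixed message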
@@ -319,11 +319,11 @@ class GradOperation(GradOperation_):
def __init__(self, get_all=False, get_by_list=False, sens_param=False):
"""Initialize GradOperation."""
if not isinstance(get_all, bool):
- raise TypeError(f'get_all should be bool, but got {type(get_all)}')
+ raise TypeError(f"For 'GradOperation', the 'get_all' should be bool, but got {type(get_all)}")
if not isinstance(get_by_list, bool):
- raise TypeError(f'get_by_list should be bool, but got {type(get_by_list)}')
+ raise TypeError(f"For 'GradOperation', the 'get_by_list' should be bool, but got {type(get_by_list)}")
if not isinstance(sens_param, bool):
- raise TypeError(f'sens_param should be bool, but got {type(sens_param)}')
+ raise TypeError(f"For 'GradOperation', the 'sens_param' should be bool, but got {type(sens_param)}")
self.get_all = get_all
self.get_by_list = get_by_list
self.sens_param = sens_param

@@ -443,7 +443,7 @@ class MultitypeFuncGraph(MultitypeFuncGraph_):
continue
output = fn(*args)
return output
- raise ValueError("Cannot find fn match given args.")
+ raise ValueError(f"For 'MultitypeFuncGraph', cannot find fn match given args.")

def register(self, *type_names):
"""

@@ -461,7 +461,8 @@ class MultitypeFuncGraph(MultitypeFuncGraph_):
if isinstance(type_input, str):
return mstype.typing.str_to_type(type_input)
if not isinstance(type_input, mstype.Type):
- raise TypeError(f"MultitypeFuncGraph register only support str or {mstype.Type}")
+ raise TypeError(f"For 'MultitypeFuncGraph', register only support str or {mstype.Type}, but got "
+ f"'type_input': {type_input}.")
return type_input

types = tuple(map(convert_type, type_names))
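The GradOperation hunk repeats one check three times, once per flag. A hypothetical free-standing version of the same validation, written as a loop instead of three if-blocks (a sketch, not the committed code):

    def check_grad_flags(get_all, get_by_list, sens_param):
        # Equivalent checks to GradOperation.__init__ above, folded into one loop for illustration.
        for name, value in (("get_all", get_all), ("get_by_list", get_by_list), ("sens_param", sens_param)):
            if not isinstance(value, bool):
                raise TypeError(f"For 'GradOperation', the '{name}' should be bool, but got {type(value)}")

    check_grad_flags(True, False, False)   # ok
    # check_grad_flags(1, False, False)    # TypeError: ... the 'get_all' should be bool, but got <class 'int'>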
@@ -26,9 +26,11 @@ from mindspore.ops.primitive import constexpr

@constexpr
- def _check_shape(input_shape, out_shape):
+ def _check_shape(input_shape, out_shape, prim_name=None):
+ msg_prefix = f"For '{prim_name}', the" if prim_name else "The"
if input_shape != out_shape:
- raise ValueError("Cannot broadcast the shape of x to the shape of clip_value_min or clip_value_max.")
+ raise ValueError(f"{msg_prefix} input_shape should be equal to the out_shape, but got "
+ f"input_shape {input_shape} and out_shape {out_shape}.")

def clip_by_value(x, clip_value_min, clip_value_max):

@@ -74,7 +76,7 @@ def clip_by_value(x, clip_value_min, clip_value_max):
max_op = P.Maximum()
x_min = min_op(x, clip_value_max)
x_max = max_op(x_min, clip_value_min)
- _check_shape(F.shape(x), F.shape(x_max))
+ _check_shape(F.shape(x), F.shape(x_max), 'clip_by_value')
return x_max

@@ -115,7 +117,8 @@ class _ClipByGlobalNorm(Cell):
super(_ClipByGlobalNorm, self).__init__()
# Add interface. This parameter is not used at present
if use_norm is not None:
- raise ValueError("Input 'use_norm' only supports None currently!")
+ raise ValueError(f"For '{self.cls_name}', input 'use_norm' only supports None currently, "
+ f"but got 'use_norm': {use_norm}")
validator.check_number("clip_norm", clip_norm, 0.0, Rel.GT, self.cls_name)
self.clip_norm = Tensor([clip_norm], mstype.float32)
self.hyper_map = C.HyperMap()
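The _check_shape guard in clip_by_value exists because min/max may only broadcast onto x, never enlarge it. A hypothetical sketch of that guard on plain tuples:

    def check_clip_shape(input_shape, out_shape, prim_name=None):
        # Stand-in for _check_shape above; the names and scenario are illustrative.
        msg_prefix = f"For '{prim_name}', the" if prim_name else "The"
        if input_shape != out_shape:
            raise ValueError(f"{msg_prefix} input_shape should be equal to the out_shape, but got "
                             f"input_shape {input_shape} and out_shape {out_shape}.")

    check_clip_shape((2, 3), (2, 3), 'clip_by_value')     # ok: x kept its shape
    # check_clip_shape((2, 1), (2, 3), 'clip_by_value')   # x was broadcast up, which the guard forbids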
@@ -119,48 +119,53 @@ def _int_to_tuple_conv(axes):

@constexpr
- def _check_axes(axes):
+ def _check_axes(axes, prim_name=None):
"""
Check for validity and type of axes passed to function.
"""
+ msg_prefix = f"For '{prim_name}', the" if prim_name else "The"
validator.check_value_type('axes', axes, [int, tuple, list], "tensor dot")
if not isinstance(axes, int):
axes = list(axes) # to avoid immutability issues
if len(axes) != 2:
- raise ValueError("Require two axes inputs, given less")
+ raise ValueError(f"{msg_prefix} dimension of axes should be 2, but got {axes}.")
axes = _int_to_tuple_conv(axes) # convert before length checks
if len(axes[0]) != len(axes[1]):
- raise ValueError("Axes have to be the same size/length")
+ raise ValueError(f"{msg_prefix} first and second dim of axes have to be the same size/length, "
+ f"but got {axes}.")
if len(axes[0]) != len(set(axes[0])) or len(axes[1]) != len(set(axes[1])):
- raise ValueError("Axes cannot have duplicating values")
+ raise ValueError(f"{msg_prefix} axes cannot have duplicating values, but got {axes}.")
return axes

@constexpr
- def _typecheck_input(x1_type, x2_type):
+ def _typecheck_input(x1_type, x2_type, prim_name=None):
"""
Check input tensor types to be valid and confirm they are the same type.
"""
+ msg_prefix = f"For '{prim_name}', the" if prim_name else "The"
const_utils.check_type_valid(x1_type, [mstype.float32, mstype.float16], 'x1')
const_utils.check_type_valid(x2_type, [mstype.float32, mstype.float16], 'x2')
if x1_type != x2_type:
- raise TypeError(f'Both Inputs must be the same Type. x1 is \'{x1_type}\' and x2 is \'{x2_type}\' ')
+ raise TypeError(f"{msg_prefix} inputs must be the same type, but got x1_type: {x1_type} "
+ f"and x2_type: {x2_type}.")

@constexpr
- def _axes_int_check(x1_shape, x2_shape, axes):
+ def _axes_int_check(x1_shape, x2_shape, axes, prim_name=None):
"""
Convert from single int axes to 2d tuple if required
"""
+ msg_prefix = f"For '{prim_name}', the" if prim_name else "The"
if isinstance(axes, int):
if axes < 0:
- raise ValueError(f"axes must be at least 0 for tensor dot, got {axes}")
+ raise ValueError(f"{msg_prefix} axes must be at least 0, but got {axes}.")
if axes == 0:
# outer product, no input validation required
return [], []
if axes > len(x1_shape) or axes > len(x2_shape):
- raise ValueError(
- "Axes value too high for given input arrays dimensions.")
+ raise ValueError(f"{msg_prefix} axes cannot be greater than the length of x1_shape and x2_shape, "
+ f"but got axes: {axes}, x1_shape: {x1_shape}, x2_shape: {x2_shape}.")
x1_ind = tuple(range(len(x1_shape))[-1 * axes:])
x2_ind = tuple(range(len(x2_shape))[:axes])
axes = tuple((x1_ind, x2_ind))
@@ -169,12 +174,13 @@ def _axes_int_check(x1_shape, x2_shape, axes):

@constexpr
- def _validate_axes(x1_shape, x2_shape, axes):
+ def _validate_axes(x1_shape, x2_shape, axes, prim_name=None):
"""
Checks for axes having the correct length according to input, for any value in axis
being out of range with given shape and also checking for compatible axes values
with given inputs.
"""
+ msg_prefix = f"For '{prim_name}', the" if prim_name else "The"
shapes = [x1_shape, x2_shape]

# axis length check

@@ -182,8 +188,8 @@ def _validate_axes(x1_shape, x2_shape, axes):
axes_len = len(x_axes)
shape_dim_len = len(shapes[ix_input])
if axes_len > shape_dim_len:
- raise ValueError(f"axes for input: {ix_input + 1} are of length: {axes_len} "
- f"can only be max: {shape_dim_len} due to input shape.")
+ raise ValueError(f"{msg_prefix} length of x_axes should be less than or equal to {shape_dim_len}, "
+ f"but got 'len(x_axes)': {axes_len}.")

# axis values range check
for ix_input, x_axes in enumerate(axes):

@@ -192,8 +198,8 @@ def _validate_axes(x1_shape, x2_shape, axes):
min_val = -1 * len(comp_shape)
for _, x_value in enumerate(x_axes):
if not min_val <= x_value <= max_val:
- raise ValueError(f"axes for input: {ix_input + 1} contains index: "
- f"{x_value}, but range is: [{min_val}, {max_val}]")
+ raise ValueError(f"{msg_prefix} value in axes should be in range: [{min_val}, {max_val}], "
+ f"but got {x_value}.")

# check axis value with input shape - both ways for axis valid
invalid_a = False

@@ -204,7 +210,9 @@ def _validate_axes(x1_shape, x2_shape, axes):
if x1_shape[axes[0][i]] != x2_shape[axes[1][len(axes[0])-1-i]]:
invalid_b = True
if invalid_a and invalid_b:
- raise ValueError("Given Axes are incompatible with given input arrays")
+ raise ValueError(f"{msg_prefix} 'i' should exist such that 'x1_shape[axes[0][i]]' is equal to "
+ f"'x2_shape[axes[1][i]]' or 'x2_shape[axes[1][len(axes[0])-1-i]]', but got "
+ f"x1_shape: {x1_shape}, x2_shape: {x2_shape}, axes: {axes}.")

@constexpr
@@ -224,7 +232,7 @@ def _calc_new_shape(shape, axes, position=0):
return new_shape, transpose_perm, free_dims

- def tensor_dot(x1, x2, axes):
+ def tensor_dot(x1, x2, axes, prim_name='tensor_dot'):
"""
Computation of Tensor contraction on arbitrary axes between tensors `a` and `b`.

@@ -276,11 +284,11 @@ def tensor_dot(x1, x2, axes):
x2_shape = shape_op(x2)
x1_type = F.dtype(x1)
x2_type = F.dtype(x2)
- axes = _check_axes(axes)
- _typecheck_input(x1_type, x2_type)
+ axes = _check_axes(axes, prim_name)
+ _typecheck_input(x1_type, x2_type, prim_name)
# input compatibility check & axes format update
- axes = _axes_int_check(x1_shape, x2_shape, axes)
- _validate_axes(x1_shape, x2_shape, axes)
+ axes = _axes_int_check(x1_shape, x2_shape, axes, prim_name)
+ _validate_axes(x1_shape, x2_shape, axes, prim_name)
x1_reshape_fwd, x1_transpose_fwd, x1_ret = _calc_new_shape(x1_shape, axes, 0)
x2_reshape_fwd, x2_transpose_fwd, x2_ret = _calc_new_shape(x2_shape, axes, 1)
output_shape = x1_ret + x2_ret # combine free axes from both inputs
@@ -295,21 +303,24 @@ def tensor_dot(x1, x2, axes):

@constexpr
- def _check_invalid_input(x1_shape, x2_shape):
+ def _check_invalid_input(x1_shape, x2_shape, prim_name=None):
+ msg_prefix = f"For '{prim_name}', the" if prim_name else "The"
if len(x1_shape) < 2 or len(x2_shape) < 2:
- raise ValueError('C.dot inputs x1, x2 should has dimension >= 2,'
- + f'while x1 is ({len(x1_shape)}) and x2 is ({len(x2_shape)}).')
+ raise ValueError(f"{msg_prefix} inputs x1, x2 should have 'dimension >= 2', "
+ f"but got 'len(x1_shape)': ({len(x1_shape)}) and 'len(x2_shape)': ({len(x2_shape)}).")

@constexpr
- def _typecheck_input_dot(x1_type, x2_type):
+ def _typecheck_input_dot(x1_type, x2_type, prim_name=None):
"""
Check input tensor types to be valid and confirm they are the same type for dot and batch dot ops.
"""
+ msg_prefix = f"For '{prim_name}', the" if prim_name else "The"
const_utils.check_type_valid(x1_type, [mstype.float16, mstype.float32], 'x1')
const_utils.check_type_valid(x2_type, [mstype.float16, mstype.float32], 'x2')
if x1_type != x2_type:
- raise TypeError(f'Both Inputs must be the same Type. x1 is \'{x1_type}\' and x2 is \'{x2_type}\' ')
+ raise TypeError(f"{msg_prefix} inputs must be the same type, but got "
+ f"x1_type: {x1_type} and x2_type: {x2_type}.")

@constexpr
@@ -319,7 +330,7 @@ def _get_transpose_shape(x2_shape):
return x2_shape_transpose

- def dot(x1, x2):
+ def dot(x1, x2, prim_name=None):
"""
Computation a dot product between samples in two tensors.

@@ -394,8 +405,8 @@ def dot(x1, x2):
x2_shape = shape_op(x2)
x1_type = F.dtype(x1)
x2_type = F.dtype(x2)
- _typecheck_input_dot(x1_type, x2_type)
- _check_invalid_input(x1_shape, x2_shape)
+ _typecheck_input_dot(x1_type, x2_type, prim_name)
+ _check_invalid_input(x1_shape, x2_shape, prim_name)

if len(x1_shape) > 2 or len(x2_shape) > 2:
x2_shape_transpose = _get_transpose_shape(x2_shape)
@@ -408,31 +419,36 @@ def dot(x1, x2):

@constexpr
- def _get_batch_size(x1_shape, x2_shape):
+ def _get_batch_size(x1_shape, x2_shape, prim_name=None):
"""
Get batch sizes from two inputs
"""
+ msg_prefix = f"For '{prim_name}', the" if prim_name else "The"
if len(x1_shape) < 2 or len(x2_shape) < 2:
- raise ValueError("Require both inputs with rank >= 2.")
+ raise ValueError(f"{msg_prefix} inputs x1, x2 should have 'dimension >= 2', "
+ f"but got 'len(x1_shape)': ({len(x1_shape)}) and 'len(x2_shape)': ({len(x2_shape)}).")
return x1_shape[0], x2_shape[0]

@constexpr
- def _typecheck_input_batch_dot(x1_type, x2_type):
+ def _typecheck_input_batch_dot(x1_type, x2_type, prim_name=None):
"""
Check input tensor types to be valid and confirm they are the same type for batch dot ops.
"""
+ msg_prefix = f"For '{prim_name}', the" if prim_name else "The"
const_utils.check_type_valid(x1_type, [mstype.float32], 'x1')
const_utils.check_type_valid(x2_type, [mstype.float32], 'x2')
if x1_type != x2_type:
- raise TypeError(f'Both Inputs must be the same Type. x1 is \'{x1_type}\' and x2 is \'{x2_type}\' ')
+ raise TypeError(f"{msg_prefix} inputs must be the same type, but got x1_type: {x1_type} and "
+ f"x2_type: {x2_type}.")

@constexpr
- def _check_axes_for_batch_dot(x1_shape, x2_shape, axes):
+ def _check_axes_for_batch_dot(x1_shape, x2_shape, axes, prim_name=None):
"""
Check whether axes are valid and cast axes from tuple to list
"""
+ msg_prefix = f"For '{prim_name}', the" if prim_name else "The"
if axes is None:
if len(x2_shape) == 2:
axes = [len(x1_shape) - 1, len(x2_shape) - 1]
@@ -441,9 +457,9 @@ def _check_axes_for_batch_dot(x1_shape, x2_shape, axes):

if isinstance(axes, (list, tuple)):
if 0 in axes:
- raise ValueError("Batch dim cannot be used as in axes.")
+ raise ValueError(f"{msg_prefix} axes cannot contain 0, but got axes: {axes}.")
if len(axes) != 2:
- raise ValueError("Require two axes inputs, given less")
+ raise ValueError(f"{msg_prefix} length of axes must be equal to 2, but got {len(axes)}.")
if isinstance(axes, tuple):
axes = list(axes)
validator.check_value_type('axes[0]', axes[0], [int], 'batch_dot')

@@ -456,22 +472,23 @@ def _check_axes_for_batch_dot(x1_shape, x2_shape, axes):
validator.check_non_negative_int(axes[0], 'reversed axes[0]', 'batch_dot')
validator.check_non_negative_int(axes[1], 'reversed axes[1]', 'batch_dot')
if axes[0] > len(x1_shape) or axes[1] > len(x2_shape):
- raise ValueError(
- "Axes value too high for given input arrays dimensions.")
+ raise ValueError(f"{msg_prefix} axes[0] must be less than or equal to len(x1_shape), "
+ f"and axes[1] must be less than or equal to len(x2_shape). "
+ f"But got axes: {axes}, x1_shape: {x1_shape}, x2_shape: {x2_shape}.")
elif isinstance(axes, int):
if axes == 0:
- raise ValueError("Batch dim cannot be used as in axes.")
+ raise ValueError(f"{msg_prefix} axes should not equal to 0, but got {axes}.")
if axes < 0:
axes = [axes + len(x1_shape), axes + len(x2_shape)]
validator.check_non_negative_int(axes[0], 'reversed axes', 'batch_dot')
elif axes > len(x1_shape) or axes > len(x2_shape):
- raise ValueError(
- "Axes value too high for given input arrays dimensions.")
+ raise ValueError(f"{msg_prefix} axes cannot be greater than the length of x1_shape and x2_shape, "
+ f"but got axes: {axes}, x1_shape: {x1_shape}, x2_shape: {x2_shape}.")
else:
axes = [axes, axes]
else:
- raise ValueError(
- "Axes type must be one of those: int, tuple(int), list(int).")
+ raise ValueError(f"{msg_prefix} type of axes must be one of those: int, tuple(int), list(int), "
+ f"but got {type(axes)}.")
return axes
@@ -496,12 +513,14 @@ def _calc_new_shape_batchdot(shape, axes, position=0):

@constexpr
- def _check_batch_size(x1_batch_size, x2_batch_size):
+ def _check_batch_size(x1_batch_size, x2_batch_size, prim_name=None):
"""
Check whether batch size of two inputs are the same
"""
+ msg_prefix = f"For '{prim_name}', the" if prim_name else "The"
if x1_batch_size != x2_batch_size:
- raise ValueError("Require both inputs with the same batch sizes.")
+ raise ValueError(f"{msg_prefix} both inputs x1, x2 should have the same batch sizes, but got "
+ f"x1_batch_size: {x1_batch_size} and x2_batch_size: {x2_batch_size}.")

@constexpr

@@ -513,7 +532,7 @@ def _get_output_shape(batch_size, x1_ret, x2_ret):
return output_shape

- def batch_dot(x1, x2, axes=None):
+ def batch_dot(x1, x2, axes=None, prim_name=None):
"""
Computation of batch dot product between samples in two tensors containing batch dims.
@@ -593,11 +612,11 @@ def batch_dot(x1, x2, axes=None):
x1_type = F.dtype(x1)
x2_type = F.dtype(x2)

- x1_batch_size, x2_batch_size = _get_batch_size(x1_shape, x2_shape)
+ x1_batch_size, x2_batch_size = _get_batch_size(x1_shape, x2_shape, prim_name)

- _typecheck_input_batch_dot(x1_type, x2_type)
- _check_batch_size(x1_batch_size, x2_batch_size)
- axes = _check_axes_for_batch_dot(x1_shape, x2_shape, axes)
+ _typecheck_input_batch_dot(x1_type, x2_type, prim_name)
+ _check_batch_size(x1_batch_size, x2_batch_size, prim_name)
+ axes = _check_axes_for_batch_dot(x1_shape, x2_shape, axes, prim_name)

if x1_dim_num == 2:
x1 = F.expand_dims(x1, 1)
@@ -664,19 +683,23 @@ def _infer_shape_rem(shape1, shape2, ndim1, ndim2, transpose_b):

@constexpr
- def _check_matmul_shapes(shape1, shape2):
+ def _check_matmul_shapes(shape1, shape2, prim_name=None):
"""Checks shape1 and shape2 are valid to perform matmul, and returns output shape after broadcasting."""
+ msg_prefix = f"For '{prim_name}', the" if prim_name else "The"
ndim1, ndim2 = len(shape1), len(shape2)
if ndim1 < 1 or ndim2 < 1:
- raise ValueError('input operands must have at least 1 dimension')
+ raise ValueError(f"{msg_prefix} dimension of input operands must be at least 1, but got "
+ f"the length of shape1: {ndim1}, the length of shape2: {ndim2}.")
if ndim2 >= 2 and shape1[-1] != shape2[-2]:
- raise ValueError(f'mismatch in core dimension of input operands (size '
- f'{shape1[-1]} is different from {shape2[-2]})')
+ raise ValueError(f"{msg_prefix} shape1[-1] should be equal to shape2[-2] when the length of shape2 "
+ f"is greater than or equal to 2, but got shape1[-1]: {shape1[-1]}, "
+ f"shape2[-2]: {shape2[-2]}.")
shape_out = deque()
for items in zip_longest(reversed(shape1[:-2]), reversed(shape2[:-2]), fillvalue=1):
max_size = max(items)
if any(item not in (1, max_size) for item in items):
- raise ValueError(f'operands could not be broadcast together with shapes {shape1} {shape2}')
+ raise ValueError(f"{msg_prefix} operands could not be broadcast together with shape1 {shape1} and "
+ f"shape2 {shape2}.")
shape_out.appendleft(max_size)
return tuple(shape_out)
@@ -710,7 +733,7 @@ def _broadcast_to(x, shape_cur, shape_to, ndim_to):
return F.tile(x, size)

- def matmul(x1, x2, dtype=None):
+ def matmul(x1, x2, dtype=None, prim_name=None):
"""
Returns the matrix product of two arrays.

@@ -775,7 +798,7 @@ def matmul(x1, x2, dtype=None):
ndim1_orig, ndim2_orig = F.rank(x1), F.rank(x2)
shape1_orig, shape2_orig = F.shape(x1), F.shape(x2)
transpose_b = ndim2_orig == 1
- shape_backbone = _check_matmul_shapes(shape1_orig, shape2_orig)
+ shape_backbone = _check_matmul_shapes(shape1_orig, shape2_orig, prim_name)
# infers the shape of the output
shape_out = shape_backbone + _infer_shape_rem(shape1_orig, shape2_orig,
ndim1_orig, ndim2_orig, transpose_b)
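The broadcast loop in _check_matmul_shapes only walks the batch dimensions (everything except the trailing two matrix dims). A runnable paraphrase of just that loop, with the new-style message, under the assumption that the ndim and core-dimension checks have already passed:

    from collections import deque
    from itertools import zip_longest

    def check_matmul_batch_shapes(shape1, shape2, prim_name=None):
        # Sketch of the broadcast walk only; the real helper also checks ndim and shape1[-1] == shape2[-2].
        msg_prefix = f"For '{prim_name}', the" if prim_name else "The"
        shape_out = deque()
        for items in zip_longest(reversed(shape1[:-2]), reversed(shape2[:-2]), fillvalue=1):
            max_size = max(items)
            if any(item not in (1, max_size) for item in items):
                raise ValueError(f"{msg_prefix} operands could not be broadcast together with shape1 {shape1} and "
                                 f"shape2 {shape2}.")
            shape_out.appendleft(max_size)
        return tuple(shape_out)

    print(check_matmul_batch_shapes((5, 1, 3, 4), (2, 4, 6)))   # (5, 2): batch dims broadcast like NumPy
    # check_matmul_batch_shapes((5, 3, 3, 4), (2, 4, 6))        # raises: 3 and 2 cannot broadcast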
@@ -80,12 +80,13 @@ class ReduceOp:
target_dtypes = (mstype.int8, mstype.int32, mstype.float16, mstype.float32)

- def check_hcom_group_valid(group):
+ def check_hcom_group_valid(group, prim_name=None):
"""Check if hcom group is valid."""
+ msg_prefix = f"For '{prim_name}', only" if prim_name else "Only"
if context.get_context("mode") == context.PYNATIVE_MODE and \
context.get_context("device_target") == "Ascend" and \
group != GlobalComm.WORLD_COMM_GROUP:
- raise RuntimeError("Only hccl_world_group is supported in Pynative mode, but got {}".format(group))
+ raise RuntimeError(f"{msg_prefix} hccl_world_group is supported in Pynative mode, but got 'group': {group}.")

class AllReduce(PrimitiveWithInfer):
@ -146,10 +147,11 @@ class AllReduce(PrimitiveWithInfer):
|
||||||
def __init__(self, op=ReduceOp.SUM, group=GlobalComm.WORLD_COMM_GROUP):
|
def __init__(self, op=ReduceOp.SUM, group=GlobalComm.WORLD_COMM_GROUP):
|
||||||
"""Initialize AllReduce."""
|
"""Initialize AllReduce."""
|
||||||
if not isinstance(op, type(ReduceOp.SUM)):
|
if not isinstance(op, type(ReduceOp.SUM)):
|
||||||
raise TypeError("The operation of AllReduce should be str.")
|
raise TypeError(f"For '{self.name}', the 'op' of AllReduce should be str, but got {type(op)}.")
|
||||||
if not isinstance(_get_group(group), str):
|
if not isinstance(_get_group(group), str):
|
||||||
raise TypeError("The group of AllReduce should be str.")
|
raise TypeError(f"For '{self.name}', the 'group' of AllReduce should be str, "
|
||||||
check_hcom_group_valid(group)
|
f"but got {type(_get_group(group))}.")
|
||||||
|
check_hcom_group_valid(group, prim_name=self.name)
|
||||||
self.op = op
|
self.op = op
|
||||||
self.add_prim_attr('group', _get_group(group))
|
self.add_prim_attr('group', _get_group(group))
|
||||||
self.add_prim_attr('fusion', 0)
|
self.add_prim_attr('fusion', 0)
|
||||||
|
@ -338,7 +340,7 @@ class _HostAllGather(PrimitiveWithInfer):
|
||||||
def __init__(self, group=None):
|
def __init__(self, group=None):
|
||||||
"""Initialize _HostAllGather."""
|
"""Initialize _HostAllGather."""
|
||||||
if group is None:
|
if group is None:
|
||||||
raise ValueError(f"For '{self.name}' group must be set.")
|
raise ValueError(f"For '{self.name}', the 'group' cannot be None, but got {group}.")
|
||||||
validator.check_value_type('group', group, (tuple, list), self.name)
|
validator.check_value_type('group', group, (tuple, list), self.name)
|
||||||
validator.check_int(len(group), 2, Rel.GE, "group size", self.name)
|
validator.check_int(len(group), 2, Rel.GE, "group size", self.name)
|
||||||
for r in group:
|
for r in group:
|
||||||
|
@ -425,9 +427,10 @@ class ReduceScatter(PrimitiveWithInfer):

def infer_shape(self, x_shape):
if self.rank_size == 0:
- raise ValueError(f"For '{self.name}' rank_size can not be zero.")
+ raise ValueError(f"For '{self.name}', the 'rank_size' cannot be zero, but got {self.rank_size}.")
if x_shape[0] % self.rank_size != 0:
- raise ValueError(f"For '{self.name}' the first dimension of x should be divided by rank_size.")
+ raise ValueError(f"For '{self.name}', the first dimension of 'x_shape' should be divided by 'rank_size', "
+ f"but got 'x_shape[0]': {x_shape[0]}, 'rank_size': {self.rank_size}.")
x_shape[0] = int(x_shape[0] / self.rank_size)
return x_shape
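For reference, a minimal sketch of the shape rule ReduceScatter.infer_shape enforces above (plain Python, no MindSpore dependency; 'rank_size' stands for the communication group size):

def reduce_scatter_out_shape(x_shape, rank_size):
    if rank_size == 0:
        raise ValueError(f"'rank_size' cannot be zero, but got {rank_size}.")
    if x_shape[0] % rank_size != 0:
        raise ValueError(f"the first dimension of 'x_shape' should be divided by 'rank_size', "
                         f"but got 'x_shape[0]': {x_shape[0]}, 'rank_size': {rank_size}.")
    # Each rank keeps an equal slice along the first dimension.
    return [x_shape[0] // rank_size] + list(x_shape[1:])

# reduce_scatter_out_shape([8, 16], 4) -> [2, 16]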
@ -465,7 +468,7 @@ class _HostReduceScatter(PrimitiveWithInfer):
def __init__(self, op=ReduceOp.SUM, group=None):
"""Initialize _HostReduceScatter."""
if group is None:
- raise ValueError(f"For '{self.name}' group must be set.")
+ raise ValueError(f"For '{self.name}', the 'group' cannot be None, but got {group}.")
validator.check_value_type('op', op, (type(ReduceOp.SUM),), self.name)
validator.check_value_type('group', group, (tuple, list), self.name)
validator.check_int(len(group), 2, Rel.GE, "group size", self.name)
@ -479,7 +482,8 @@ class _HostReduceScatter(PrimitiveWithInfer):

def infer_shape(self, x_shape):
if x_shape[0] % self.group_size != 0:
- raise ValueError(f"For '{self.name}' the first dimension of x should be divided by group_size.")
+ raise ValueError(f"For '{self.name}', the first dimension of 'x_shape' should be divided by 'group_size', "
+ f"but got 'x_shape[0]': {x_shape[0]}, 'rank_size': {self.group_size}.")
x_shape[0] = int(x_shape[0] / self.group_size)
return x_shape
@ -550,7 +554,7 @@ class Broadcast(PrimitiveWithInfer):
"""Initialize Broadcast."""
validator.check_value_type('root_rank', root_rank, (int,), self.name)
validator.check_value_type('group', _get_group(group), (str,), self.name)
- check_hcom_group_valid(group)
+ check_hcom_group_valid(group, prim_name=self.name)
self.add_prim_attr('group', _get_group(group))
self.add_prim_attr('no_elimilate', True)
@ -559,7 +563,7 @@ class Broadcast(PrimitiveWithInfer):

def infer_dtype(self, x_dtype):
if not isinstance(x_dtype, tuple):
- raise TypeError(f"{self.name}'s input should be a tuple!")
+ raise TypeError(f"For '{self.name}', the 'input_x' should be a tuple, but got {type(x_dtype)}!")
for _ele in x_dtype:
validator.check_tensor_dtype_valid('x', _ele, target_dtypes, self.name)
return x_dtype
@ -689,10 +693,11 @@ class AlltoAll(PrimitiveWithInfer):
def infer_shape(self, x_shape):
rank_size = get_group_size(_get_group(self.group))
if self.split_count != rank_size:
- raise ValueError(f"split count '{self.split_count}' must be equal to rank size '{rank_size}'.")
+ raise ValueError(f"For '{self.name}', the 'split_count' must be equal to 'rank_size', "
+ f"but got 'split_count': {self.split_count}, 'rank_size': {rank_size}.")
if x_shape[self.split_dim] % self.split_count != 0:
- raise ValueError(
- f"split count '{self.split_count}' must be divisible by rank size '{x_shape[self.split_dim]}'.")
+ raise ValueError(f"For '{self.name}', the 'split_count' must be divisible by 'rank_size', "
+ f"but got 'split_count' {self.split_count}, 'rank_size' {x_shape[self.split_dim]}.")
x_shape[self.concat_dim] = x_shape[self.concat_dim] * self.split_count
x_shape[self.split_dim] = int(x_shape[self.split_dim] / self.split_count)
return x_shape
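A minimal sketch of the AlltoAll shape bookkeeping checked above: the input is split into split_count pieces along split_dim and the pieces are concatenated along concat_dim (plain Python, names are illustrative only):

def alltoall_out_shape(x_shape, split_count, split_dim, concat_dim, rank_size):
    if split_count != rank_size:
        raise ValueError(f"'split_count' must be equal to 'rank_size', "
                         f"but got 'split_count': {split_count}, 'rank_size': {rank_size}.")
    if x_shape[split_dim] % split_count != 0:
        raise ValueError(f"'x_shape[{split_dim}]' must be divisible by 'split_count', "
                         f"but got {x_shape[split_dim]} and {split_count}.")
    out = list(x_shape)
    out[concat_dim] *= split_count
    out[split_dim] //= split_count
    return out

# alltoall_out_shape([2, 8], split_count=4, split_dim=1, concat_dim=0, rank_size=4) -> [8, 2]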
@ -27,7 +27,7 @@ def _check_mode(class_name):
"""Check for PyNative mode."""
mode = context.get_context('mode')
if mode == context.PYNATIVE_MODE:
- raise RuntimeError(f'{class_name} operator does not support PyNative mode.')
+ raise RuntimeError(f"For '{class_name}', the operator does not support PyNative mode.")


def _check_summary_param(name, value, class_name):
@ -37,7 +37,7 @@ def _check_summary_param(name, value, class_name):
n_value = name['value']
validator.check_value_type('name', n_type, [type(mstype.string)], class_name)
if not n_value:
- raise ValueError(f"For 'name' the value should by valid string in {class_name}, but got an empty string.")
+ raise ValueError(f"For '{class_name}', the name should be valid string, but got '{n_value}'.")

v_type = value['dtype']
validator.check_value_type('value', v_type, [type(mstype.tensor)], class_name)
@ -135,7 +135,7 @@ class ImageSummary(PrimitiveWithInfer):
v_shape = value['shape']
image_dim = 4
if len(v_shape) != image_dim:
- raise ValueError(f"For 'value' the dim should be {image_dim} in {self.__class__.__name__},"
+ raise ValueError(f"For '{self.name}', the dimension of 'value' should be {image_dim},"
f" but got {len(v_shape)}.")

return SUMMARY_RETURN_VALUE
@ -226,8 +226,8 @@ class HistogramSummary(PrimitiveWithInfer):
v_shape = value['shape']
# In the summary, the histogram value should be a tensor whose shape is not [].
if not v_shape:
- raise ValueError(f"For 'value' the type should be tensor in {self.__class__.__name__}, "
- f"shape should not be [].")
+ raise ValueError(f"For '{self.name}', the type of 'value' should be tensor, "
+ f"and whose shape should not be [], but got {v_shape}.")

return SUMMARY_RETURN_VALUE
@ -343,7 +343,8 @@ class HookBackward(PrimitiveWithInfer):
self.add_prim_attr("cell_id", cell_id)
self.init_attrs["cell_id"] = cell_id
if not isinstance(hook_fn, (FunctionType, MethodType)):
- raise TypeError("Hook function should be python function type.")
+ raise TypeError(f"For '{self.name}', the tye of 'hook_fn' should be python function, "
+ f"but got {type(hook_fn)}.")
self.register_hook(hook_fn)
self.cell_id = cell_id
@ -104,7 +104,7 @@ class CropAndResize(PrimitiveWithInfer):
box_index_shape = list(box_index['shape'])
# get value
if crop_size['value'] is None:
- raise ValueError(f"For {self.name}, crop_size must be constant.")
+ raise ValueError(f"For '{self.name}', the 'crop_size' cannot be None, but got {crop_size['value']}.")
crop_size_value = crop_size['value']
# get dtype
x_dtype = x['dtype']
@ -394,7 +394,7 @@ class _Reduce(PrimitiveWithInfer):
args = {'input_x': input_x['dtype']}
validator.check_tensors_dtypes_same_and_valid(args, valid_dtype, self.name)
if not isinstance(axis, mstype.tensor_type) and axis_v is None:
- raise ValueError(f"For {self.name}, axis must be const.")
+ raise ValueError(f"For '{self.name}', the 'axis' cannot be None, but got {axis}.")
out_shape = _infer_shape_reduce(input_shp, axis_v, self.keep_dims, self.name)
if -1 in input_shp:
if axis_v is None:
@ -404,7 +404,8 @@ class _Reduce(PrimitiveWithInfer):
max_v = max(input_max_shp)
axis_shape_list = axis['shape']
if len(axis_shape_list) != 1:
- raise ValueError("axis_shape must be 1-D, but got ", len(axis_shape_list))
+ raise ValueError(f"For '{self.name}', the shape of 'axis' must be 1-D, but "
+ f"got {len(axis_shape_list)}.")
axis_shape = axis_shape_list[0]
if len(axis_shape) == 1 and axis_shape[0] == -1 and not self.keep_dims:
out_shape = np.array([-2]).tolist()
@ -1052,7 +1053,7 @@ class CumProd(PrimitiveWithInfer):

def infer_value(self, x, axis):
if axis is None:
- raise ValueError(f"For {self.name}, axis must be const.")
+ raise ValueError(f"For '{self.name}', the 'axis' cannot be None, but got {axis}.")


class MatMul(PrimitiveWithCheck):
@ -1107,8 +1108,8 @@ class MatMul(PrimitiveWithCheck):
def check_shape_size(self, x1, x2):
"""Check the shape size of inputs for MatMul."""
if len(x1) != 2 or len(x2) != 2:
- raise ValueError('P.MatMul inputs x1, x2 should have the same dimension size and '
- + f'equal to 2, while x1 size is ({len(x1)}) and x2 size is ({len(x2)}).')
+ raise ValueError(f"For '{self.name}', inputs 'x', 'y' should have the same dimension size and "
+ f"be equal to 2, but got the size of 'x': ({len(x1)}) and the size of 'y': ({len(x2)}).")

def check_shape(self, x1, x2):
self.check_shape_size(x1, x2)
@ -1116,8 +1117,8 @@ class MatMul(PrimitiveWithCheck):
# expected dimension of x, y, x:[...,a,b] y:[..., c,d], the dim size should be the same except the last two
for i in range(len(x1) - 2):
if x1[i] != x2[i]:
- raise ValueError(f'For \'{cls_name}\' shape in dim[{i}] not the same, '
- + f'while x1 is {x1[i]}, x2 is {x2[i]}')
+ raise ValueError(f"For '{cls_name}', the dim[{i}] of 'x' should be equal to the dim[{i}] of 'y', "
+ f"but got 'x[{i}]': {x1[i]} and 'y[{i}]': {x2[i]}.")

# validate whether last two dims satisfying matrix multiply
x1_last = x1[-2:]
@ -1126,10 +1127,9 @@ class MatMul(PrimitiveWithCheck):
x2_row = x2_last[self.transpose_b]
if np.all(np.array(x1) != -1) and np.all(np.array(x2) != -1):
if x1_col != x2_row:
- raise ValueError(f'For \'{cls_name}\' evaluator shapes of inputs can not do this operator,'
- f' dimensions must be equal,'
- + f' got {x1_col} and {x2_row}, with x1 shape {x1}(transpose_a={self.transpose_a})'
- + f', x2 shape {x2}(transpose_b={self.transpose_b}).')
+ raise ValueError(f"For '{cls_name}', the input dimensions must be equal, but got 'x1_col': {x1_col} "
+ f"and 'x2_row': {x2_row}. And 'x' shape {x1}(transpose_a={self.transpose_a}), "
+ f"'y' shape {x2}(transpose_b={self.transpose_b}).")
# set attribute
self.add_prim_attr('transpose_x1', self.transpose_a)
self.add_prim_attr('transpose_x2', self.transpose_b)
@ -1212,8 +1212,8 @@ class BatchMatMul(MatMul):

def check_shape_size(self, x, y):
if len(x) != len(y) or len(x) < 3:
- raise ValueError('For \'BatchMatMul\' input x, y should be the same dimension size and should be '
- 'greater or equal to 3,' + f' while x size = {len(x)}, y size= {len(y)}')
+ raise ValueError(f"For '{self.name}', input 'x', 'y' should be the same dimension size and should be "
+ f"greater than or equal to 3, but got 'x' size: {len(x)}, 'y' size: {len(y)}.")


class CumSum(PrimitiveWithInfer):
@ -1291,7 +1291,7 @@ class CumSum(PrimitiveWithInfer):
cls_name = self.name
x_shp = x['shape']
if axis['value'] is None:
- raise ValueError(f"For {self.name}, axis must be const.")
+ raise ValueError(f"For '{self.name}', the 'axis' cannot be None, but got {axis}.")
validator.check_value_type('axis', axis['value'], [int], cls_name)
valid_dtypes = [mstype.uint8, mstype.int8, mstype.int32, mstype.float16, mstype.float32, mstype.float64]
validator.check_tensor_dtype_valid('x', x['dtype'], valid_dtypes, cls_name)
@ -1346,7 +1346,9 @@ class AddN(Primitive):
return False, None
if isinstance(inputs[0], Tensor):
return True, inputs[0]
- raise TypeError("Expecting Tensor, got : {}".format(type(inputs[0])))
+ raise TypeError(f"For '{self.name}', the type of 'inputs[0]' should be a tensor, but "
+ f"got {type(inputs[0]).__name__}, "
+ f"or the length of 'inputs' should not equal to 1, but got ({len(inputs)}).")


class AccumulateNV2(PrimitiveWithInfer):
@ -1400,7 +1402,9 @@ class AccumulateNV2(PrimitiveWithInfer):
return False, None
if isinstance(inputs[0], Tensor):
return True, inputs[0]
- raise TypeError("Expecting Tensor, got : {}".format(type(inputs[0])))
+ raise TypeError(f"For '{self.name}', the type of 'inputs[0]' should be a tensor, "
+ f"but got {type(inputs[0]).__name__}, "
+ f"or the length of 'inputs' should not equal to 1, but got ({len(inputs)}).")

def infer_shape(self, inputs):
cls_name = self.name
@ -1532,7 +1536,8 @@ class InplaceAdd(PrimitiveWithInfer):
Rel.EQ, self.name)
for i in self.indices:
if i < 0 or i >= x_shape[0]:
- raise ValueError(f'The value of indices must be in [0, {x_shape[0]}), but got {i}.')
+ raise ValueError(f"For '{self.name}', the value of 'indices' must be "
+ f"in [0, {x_shape[0]}), but got {i}.")
x_rank = len(x_shape)
for idx in range(x_rank)[1:]:
validator.check('v dim %d' % idx, v_shape[idx], "x dim %d" % idx, x_shape[idx], Rel.EQ, self.name)
@ -1600,7 +1605,8 @@ class InplaceSub(PrimitiveWithInfer):
Rel.EQ, self.name)
for i in self.indices:
if i < 0 or i >= x_shape[0]:
- raise ValueError(f'The value of indices must be in [0, {x_shape[0]}), but got {i}.')
+ raise ValueError(f"For '{self.name}', the value of 'indices' must be "
+ f"in [0, {x_shape[0]}), but got {i}.")
x_rank = len(x_shape)
for idx in range(x_rank)[1:]:
validator.check('v dim %d' % idx, v_shape[idx], "x dim %d" % idx, x_shape[idx], Rel.EQ, self.name)
@ -380,7 +380,7 @@ class IOU(PrimitiveWithInfer):
def __init__(self, mode='iou'):
"""Initialize IOU."""
if mode not in {'iou', 'iof'}:
- raise KeyError("Mode only support 'iou' or 'iof'.")
+ raise KeyError(f"For '{self.name}', only 'iou' or 'iof' are supported, but got 'mode': {mode}.")
self.init_prim_io_names(inputs=['anchor_boxes', 'gt_boxes'], outputs=['overlap'])

def infer_shape(self, anchor_boxes, gt_boxes):
@ -535,12 +535,11 @@ class CheckBprop(PrimitiveWithInfer):
self.prim_to_check = prim_to_check

def infer_shape(self, xshapes, yshapes):
- tips = f'Bprop of {self.prim_to_check}'
- validator.check_value_type('grads', xshapes, (tuple,), tips)
- validator.check_value_type('params', yshapes, (tuple,), tips)
+ validator.check_value_type('grads', xshapes, (tuple,), self.name)
+ validator.check_value_type('params', yshapes, (tuple,), self.name)
if len(xshapes) < len(yshapes):
- raise ValueError(f"{tips}, the size of output should be {len(yshapes)},"
- f" but got {len(xshapes)}.")
+ raise ValueError(f"For '{self.name}', the size of 'xshapes' should not be less than {len(yshapes)}, "
+ f"but got {len(xshapes)}.")
checking_range = len(yshapes)
for i in range(checking_range):
xshape = xshapes[i]
@ -548,16 +547,15 @@ class CheckBprop(PrimitiveWithInfer):
if not xshape or not yshape:
continue
if xshape != yshape:
- raise ValueError(f"{tips}, the shape of {i}th output should be {yshape},"
- f" but got {xshape}.")
+ raise ValueError(f"For '{self.name}', the shape of {i}th 'xshapes' should be {yshape},"
+ f" but got 'xshapes[i]': {xshape}.")
return xshapes

def infer_dtype(self, xdtypes, ydtypes):
- tips = f'Bprop of {self.prim_to_check}'
- validator.check_value_type('grads', xdtypes, (tuple,), tips)
- validator.check_value_type('params', ydtypes, (tuple,), tips)
+ validator.check_value_type('grads', xdtypes, (tuple,), self.name)
+ validator.check_value_type('params', ydtypes, (tuple,), self.name)
if len(xdtypes) < len(ydtypes):
- raise ValueError(f"{tips}, the size of output should be {len(ydtypes)},"
+ raise ValueError(f"For '{self.name}', the size of 'xdtypes' should not be less than {len(ydtypes)},"
f" but got {len(xdtypes)}.")
checking_range = len(ydtypes)
for i in range(checking_range):
@ -567,11 +565,11 @@ class CheckBprop(PrimitiveWithInfer):
continue
if isinstance(ydtype, mstype.function_type):
if not isinstance(xdtype, mstype.env_type_type):
- raise TypeError(f"{tips}, the dtype of {i}th output should be {mstype.env_type_type},"
+ raise TypeError(f"For '{self.name}', the dtype of {i}th 'xdtypes' should be {mstype.env_type_type},"
f" but got {xdtype}.")
continue
if xdtype != ydtype:
- raise TypeError(f"{tips}, the dtype of {i}th output should be {ydtype},"
+ raise TypeError(f"For '{self.name}', the shape of {i}th 'xdtypes' should be {ydtype},"
f" but got {xdtype}.")
return xdtypes
@ -64,13 +64,11 @@ class PQC(PrimitiveWithInfer):

def check_shape_size(self, encoder_data, ansatz_data):
if len(encoder_data) != 2:
- raise ValueError(
- "PQC input encoder_data should have dimension size \
- equal to 2, but got {}.".format(len(encoder_data)))
+ raise ValueError(f"For '{self.name}', the dimension of 'encoder_data' should be 2, "
+ f"but got {len(encoder_data)}.")
if len(ansatz_data) != 1:
- raise ValueError(
- "PQC input ansatz_data should have dimension size \
- equal to 1, but got {}.".format(len(ansatz_data)))
+ raise ValueError(f"For '{self.name}', the dimension of 'ansatz_data' should be 1, "
+ f"but got {len(ansatz_data)}.")

def infer_shape(self, encoder_data, ansatz_data):
self.check_shape_size(encoder_data, ansatz_data)
@ -124,8 +122,8 @@ class Evolution(PrimitiveWithInfer):

def check_shape_size(self, param_data):
if len(param_data) != 1:
- raise ValueError("PQC input param_data should have dimension size \
- equal to 1, but got {}.".format(len(param_data)))
+ raise ValueError(f"For '{self.name}', the dimension of 'param_data' should be 1, "
+ f"but got {len(param_data)}.")

def infer_shape(self, param_data):
self.check_shape_size(param_data)
@ -70,7 +70,7 @@ class StandardNormal(PrimitiveWithInfer):
def __infer__(self, shape):
shape_v = shape["value"]
if shape_v is None:
- raise ValueError(f"For {self.name}, shape must be const.")
+ raise ValueError(f"For '{self.name}', the 'shape' cannot be None, but got {shape}.")
Validator.check_value_type("shape", shape_v, [tuple], self.name)
for i, shape_i in enumerate(shape_v):
Validator.check_positive_int(shape_i, f'shape[{i}]', self.name)
@ -128,7 +128,7 @@ class StandardLaplace(PrimitiveWithInfer):
def __infer__(self, shape):
shape_v = shape["value"]
if shape_v is None:
- raise ValueError(f"For {self.name}, shape must be const.")
+ raise ValueError(f"For '{self.name}', the 'shape' cannot be None, but got {shape}.")
Validator.check_value_type("shape", shape_v, [tuple], self.name)
for i, shape_i in enumerate(shape_v):
Validator.check_positive_int(shape_i, f'shape[{i}]', self.name)
@ -192,7 +192,7 @@ class Gamma(PrimitiveWithInfer):
def __infer__(self, shape, alpha, beta):
shape_v = shape["value"]
if shape_v is None:
- raise ValueError(f"For {self.name}, shape must be const.")
+ raise ValueError(f"For '{self.name}', the 'shape' cannot be None, but got {shape}.")
Validator.check_value_type("shape", shape_v, [tuple], self.name)
for i, shape_i in enumerate(shape_v):
Validator.check_positive_int(shape_i, f'shape[{i}]', self.name)
@ -257,7 +257,7 @@ class Poisson(PrimitiveWithInfer):
def __infer__(self, shape, mean):
shape_v = shape["value"]
if shape_v is None:
- raise ValueError(f"For {self.name}, shape must be const.")
+ raise ValueError(f"For '{self.name}', the 'shape' cannot be None, but got {shape}.")
Validator.check_value_type("shape", shape_v, [tuple], self.name)
for i, shape_i in enumerate(shape_v):
Validator.check_positive_int(shape_i, f'shape[{i}]', self.name)
@ -330,7 +330,7 @@ class UniformInt(PrimitiveWithInfer):
def __infer__(self, shape, minval, maxval):
shape_v = shape["value"]
if shape_v is None:
- raise ValueError(f"For {self.name}, shape must be const.")
+ raise ValueError(f"For '{self.name}', the 'shape' cannot be None, but got {shape}.")
Validator.check_value_type("shape", shape_v, [tuple], self.name)
for i, shape_i in enumerate(shape_v):
Validator.check_positive_int(shape_i, f'shape[{i}]', self.name)
@ -507,7 +507,8 @@ class RandomCategorical(PrimitiveWithInfer):
Validator.check_positive_int(num_samples_v, "num_samples", self.name)
x_shape = list(logits['shape'])
if len(x_shape) != 2:
- raise ValueError("RandomCategorical shape should be 2-dimension.")
+ raise ValueError(f"For '{self.name}', the shape of 'logits' should be 2-dimension, "
+ f"but got {len(x_shape)}.")
ndim = len(x_shape) - 1
x_shape[ndim] = num_samples_v
self.add_prim_attr('num_samples', num_samples_v)
@ -566,11 +567,12 @@ class Multinomial(PrimitiveWithInfer):
def __infer__(self, inputs, num_samples):
input_shape = inputs["shape"]
if len(input_shape) != 1 and len(input_shape) != 2:
- raise ValueError("input dim must be 1 or 2")
+ raise ValueError(f"For '{self.name}', the dimension of 'inputs' must be 1 or 2, "
+ f"but got {len(input_shape)}.")
Validator.check_tensor_dtype_valid('inputs', inputs['dtype'], [mstype.float32], self.name)
num_samples_value = num_samples["value"]
if num_samples_value is None:
- raise ValueError(f"For {self.name}, shape nust be const")
+ raise ValueError(f"For '{self.name}', the 'num_samples' cannot be None, but got {num_samples}.")
Validator.check_value_type("num_samples", num_samples_value, (int,), self.name)
Validator.check_positive_int(num_samples_value, "num_samples")
y_shape = (num_samples_value,)
@ -209,20 +209,27 @@ class BufferAppend(PrimitiveWithInfer):
exp_batch = exp_shape[0][0]
for i in range(len(data_shape)):
if len(data_shape[i]) != len(exp_shape[i]):
- raise ValueError(f'For {self.name}, exp shape size must equal to buffer')
+ raise ValueError(f"For '{self.name}', the dimension of {i}th 'exp_shape' must equal to "
+ f"the dimension of {i}th 'data_shape', but got the {i}th 'exp_shape': "
+ f"{exp_shape[i]}, the {i}th 'data_shape': {data_shape[i]}.")
if data_shape[i][0] < exp_shape[i][0]:
- raise ValueError(f'For {self.name}, exp batch size must lessequal than buffer')
+ raise ValueError(f"For '{self.name}', the first dimension of {i}th 'data_shape' must be greater "
+ f"than or equal to the first dimension of {i}th 'exp_shape', but got the {i}th "
+ f"'exp_shape': {exp_shape[i]}, the {i}th 'data_shape': {data_shape[i]}.")
else:
for i in range(len(data_shape)):
if data_shape[i][1:] != exp_shape[i]:
- raise ValueError(f'For {self.name}, exp shape must equal to one of buffer shape')
+ raise ValueError(f"For '{self.name}', the {i}th 'exp_shape' must equal to the {i}th 'data_shape'"
+ f"which excepts the first dimension. but got the {i}th 'exp_shape': "
+ f"{exp_shape[i]}, the {i}th 'data_shape': {data_shape[i]}.")
self.add_prim_attr('exp_batch', exp_batch)
return count_shape

def infer_dtype(self, data_type, exp_type, count_type, head_type):
for i in range(len(data_type)):
if data_type[i] != exp_type[i]:
- raise TypeError(f'For {self.name}, each tensor in exp must has the same type with buffer')
+ raise TypeError(f"For '{self.name}', each tensor in 'exp' must has the same type with 'data', but got "
+ f"'data_type': {data_type}, 'exp_type': {exp_type}.")
validator.check_type_name("count type", count_type, (mstype.int32), self.name)
validator.check_type_name("head type", head_type, (mstype.int32), self.name)
return count_type
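A minimal sketch of the BufferAppend shape rule checked above: when appending a batch, each exp tensor keeps the buffer's rank and its batch size may not exceed the buffer capacity; a single exp must match the buffer shape with the first (capacity) dimension stripped (plain Python, hypothetical helper):

def check_append(data_shape, exp_shape, batched):
    for i, (data, exp) in enumerate(zip(data_shape, exp_shape)):
        if batched:
            if len(data) != len(exp) or data[0] < exp[0]:
                raise ValueError(f"{i}th 'exp_shape' {exp} does not fit 'data_shape' {data}.")
        elif tuple(data[1:]) != tuple(exp):
            raise ValueError(f"{i}th 'exp_shape' {exp} must equal 'data_shape' {data} "
                             f"without its first dimension.")

# check_append([(100, 4), (100,)], [(32, 4), (32,)], batched=True)   # passes
# check_append([(100, 4), (100,)], [(4,), ()], batched=False)        # passes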
@ -67,22 +67,22 @@ class SparseToDense(PrimitiveWithInfer):
validator.check_tensor_dtype_valid('values', values['dtype'], mstype.number_type + (mstype.bool_,), self.name)
indices_shape = indices['shape']
if len(indices_shape) != 2:
- raise ValueError("SparseToDense requires 'indices' must be a 2-D Tensor, "
- f"but got 'indices' shape: {indices_shape}")
+ raise ValueError(f"For '{self.name}', the 'indices' must be a 2-D tensor, "
+ f"but got 'indices' shape: {indices_shape}.")
values_shape = values['shape']
if len(values_shape) != 1 or values_shape[0] != indices_shape[0]:
- raise ValueError("SparseToDense requires 'values' must be a 1-D Tensor and "
- "the first dimension length must be equal to the first dimension length of 'indices', "
- f"but got 'indices' shape: {indices_shape}, 'values' shape: {values_shape}")
+ raise ValueError(f"For '{self.name}', the 'values' must be a 1-D tensor and the first dimension length "
+ f"must be equal to the first dimension length of 'indices', "
+ f"but got 'indices' shape: {indices_shape}, 'values' shape: {values_shape}.")
sparse_shape_v = sparse_shape['value']
for i in sparse_shape_v:
if isinstance(i, bool) or not isinstance(i, int) or i <= 0:
- raise ValueError("SparseToDense requires all elements in 'sparse_shape' must be "
- f"positive int number, but got 'sparse_shape': {sparse_shape_v}")
+ raise ValueError(f"For '{self.name}', all elements in 'sparse_shape' must be "
+ f"positive int number, but got 'sparse_shape': {sparse_shape_v}.")
if len(sparse_shape_v) != indices_shape[1]:
- raise ValueError("SparseToDense requires the 'sparse_shape' length should be equal to the 'indices' "
- "second dimension length, but got the 'indices' second dimension length: "
- f"{indices_shape[1]}, 'sparse_shape' length: {len(sparse_shape_v)}")
+ raise ValueError(f"For '{self.name}', the length of 'sparse_shape' should be equal to the second dimension "
+ f"length of 'indices', but got the second dimension length of 'indices': "
+ f"{indices_shape[1]}, length of 'sparse_shape': {len(sparse_shape_v)}.")
out = {'shape': sparse_shape['value'],
'dtype': values['dtype'],
'value': None}
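A small sketch of the COO layout SparseToDense validates here: 'indices' is [nnz, ndims], 'values' is [nnz], and 'sparse_shape' has one positive entry per dimension (plain numpy, illustrative only):

import numpy as np

def check_sparse_inputs(indices, values, sparse_shape):
    if indices.ndim != 2:
        raise ValueError(f"'indices' must be a 2-D tensor, but got shape: {indices.shape}.")
    if values.ndim != 1 or values.shape[0] != indices.shape[0]:
        raise ValueError(f"'values' must be 1-D with the same first dimension as 'indices', "
                         f"but got 'indices' shape: {indices.shape}, 'values' shape: {values.shape}.")
    if any(isinstance(s, bool) or not isinstance(s, int) or s <= 0 for s in sparse_shape):
        raise ValueError(f"all elements in 'sparse_shape' must be positive ints, but got {sparse_shape}.")
    if len(sparse_shape) != indices.shape[1]:
        raise ValueError(f"len('sparse_shape') must equal the second dimension of 'indices', "
                         f"but got {len(sparse_shape)} and {indices.shape[1]}.")

# check_sparse_inputs(np.array([[0, 1], [1, 2]]), np.array([1.0, 2.0]), (3, 4))  # passes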
@ -157,26 +157,29 @@ class SparseTensorDenseMatmul(PrimitiveWithInfer):
validator.check_tensors_dtypes_same_and_valid(args, valid_types, self.name)
indices_shape = indices['shape']
if len(indices_shape) != 2 or indices_shape[1] != 2:
- raise ValueError("SparseTensorDenseMatmul requires 'indices' must be a 2-D Tensor and "
- f"the second dimension length must be 2, but got 'indices' shape: {indices_shape}")
+ raise ValueError(f"For '{self.name}', the 'indices' must be a 2-D tensor and "
+ f"the second dimension length must be 2, but got 'indices' shape: {indices_shape}.")
values_shape = values['shape']
if len(values_shape) != 1 or values_shape[0] != indices_shape[0]:
- raise ValueError("SparseTensorDenseMatmul requires 'value's must be a 1-D Tensor and "
+ raise ValueError(f"For '{self.name}', the 'values' must be a 1-D tensor and "
f"the first dimension length must be equal to the first dimension length of 'indices', "
- f"but got 'indices' shape: {indices_shape}, 'values' shape: {values_shape}")
+ f"but got 'indices' shape: {indices_shape}, 'values' shape: {values_shape}.")
a_shape = sparse_shape['value'][::-1] if self.adjoint_st else sparse_shape['value']
b_shape = dense['shape'][::-1] if self.adjoint_dt else dense['shape']
for i in a_shape:
if isinstance(i, bool) or not isinstance(i, int) or i <= 0:
- raise ValueError("SparseTensorDenseMatmul requires all elements in 'sparse_shape' must be "
- f"positive int number, but got sparse shape: {a_shape}")
+ raise ValueError(f"For '{self.name}', all elements in 'sparse_shape' must be "
+ f"positive int number, but got 'sparse_shape': {a_shape}.")
if len(a_shape) != 2 or len(b_shape) != 2:
- raise ValueError("SparseTensorDenseMatmul requires both the 'sparse_shape' length and the dense tensor "
- f"rank should be equal to 2, but got 'sparse_shape' length: {len(a_shape)}, "
- f"dense tensor rank: {len(b_shape)}")
+ raise ValueError(f"For '{self.name}', both the length of 'sparse_shape' and the tensor "
+ f"rank of 'dense' should be equal to 2, but got the length of "
+ f"'sparse_shape': {len(a_shape)}, "
+ f"the tensor rank of 'dense': {len(b_shape)}.")
if a_shape[1] != b_shape[0]:
- raise ValueError(f"The sparse tensor shape: {a_shape} and the dense tensor shape: {b_shape} "
- f"don't meet the condition for matmul")
+ raise ValueError(f"For '{self.name}', the second dimension length of 'sparse_shape' must be equal to the "
+ f"first dimension length of 'dense', but got "
+ f"the tensor shape of 'sparse': {a_shape} and the tensor shape of 'dense': {b_shape}. "
+ f"Don't meet the condition for matmul")
out_shape = [a_shape[0], b_shape[1]]
out = {'shape': tuple(out_shape),
'dtype': values['dtype'],
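A minimal sketch of the matmul shape rule enforced above, assuming the adjoint flags simply reverse the corresponding 2-D shape (plain Python, illustrative names only):

def sparse_dense_matmul_shape(sparse_shape, dense_shape, adjoint_st=False, adjoint_dt=False):
    a_shape = sparse_shape[::-1] if adjoint_st else sparse_shape
    b_shape = dense_shape[::-1] if adjoint_dt else dense_shape
    if len(a_shape) != 2 or len(b_shape) != 2:
        raise ValueError(f"both operands must be rank 2, but got {len(a_shape)} and {len(b_shape)}.")
    if a_shape[1] != b_shape[0]:
        raise ValueError(f"the second dimension of 'sparse' ({a_shape[1]}) must equal "
                         f"the first dimension of 'dense' ({b_shape[0]}).")
    return [a_shape[0], b_shape[1]]

# sparse_dense_matmul_shape((3, 4), (4, 5)) -> [3, 5]
# sparse_dense_matmul_shape((4, 3), (4, 5), adjoint_st=True) -> [3, 5]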
@ -2322,7 +2322,6 @@ class LJEnergy(PrimitiveWithInfer):
validator.check_int(len(scaler), 1, Rel.EQ, "scaler_dim", cls_name)
validator.check_int(len(nl_numbers), 1, Rel.EQ, "nl_numbers_dim", cls_name)
validator.check_int(len(nl_serial), 2, Rel.EQ, "nl_serial_dim", cls_name)
- validator.check_int(len(scaler), 1, Rel.EQ, "scaler_dim", cls_name)
validator.check_int(len(d_lj_b), 1, Rel.EQ, "d_LJ_B_dim", cls_name)

validator.check_int(uint_crd[0], n, Rel.EQ, "uint_crd_shape[0]", cls_name)

@ -2333,7 +2332,6 @@ class LJEnergy(PrimitiveWithInfer):
validator.check_int(nl_numbers[0], n, Rel.EQ, "nl_numbers_shape", cls_name)
validator.check_int(nl_serial[0], n, Rel.EQ, "nl_serial_shape[0]", cls_name)
validator.check_int(nl_serial[1], 800, Rel.LE, "nl_serial_shape[1]", cls_name)
- validator.check_int(scaler[0], 3, Rel.EQ, "scaler_shape", cls_name)
validator.check_int(d_lj_b[0], q, Rel.EQ, "d_LJ_B_shape[0]", cls_name)
return charge

@ -2414,7 +2412,6 @@ class LJForce(PrimitiveWithInfer):
validator.check_int(len(scaler), 1, Rel.EQ, "scaler_dim", cls_name)
validator.check_int(len(nl_numbers), 1, Rel.EQ, "nl_numbers_dim", cls_name)
validator.check_int(len(nl_serial), 2, Rel.EQ, "nl_serial_dim", cls_name)
- validator.check_int(len(scaler), 1, Rel.EQ, "scaler_dim", cls_name)
validator.check_int(len(d_lj_b), 1, Rel.EQ, "d_LJ_B_dim", cls_name)

validator.check_int(uint_crd[0], n, Rel.EQ, "uint_crd_shape[0]", cls_name)

@ -2425,7 +2422,6 @@ class LJForce(PrimitiveWithInfer):
validator.check_int(nl_numbers[0], n, Rel.EQ, "nl_numbers_shape", cls_name)
validator.check_int(nl_serial[0], n, Rel.EQ, "nl_serial_shape[0]", cls_name)
validator.check_int(nl_serial[1], 800, Rel.LE, "nl_serial_shape[1]", cls_name)
- validator.check_int(scaler[0], 3, Rel.EQ, "scaler_shape", cls_name)
validator.check_int(d_lj_b[0], q, Rel.EQ, "d_LJ_B_shape[0]", cls_name)
return uint_crd

@ -2510,7 +2506,6 @@ class LJForceWithPMEDirectForce(PrimitiveWithInfer):
validator.check_int(len(scaler), 1, Rel.EQ, "scaler_dim", cls_name)
validator.check_int(len(nl_numbers), 1, Rel.EQ, "nl_numbers_dim", cls_name)
validator.check_int(len(nl_serial), 2, Rel.EQ, "nl_serial_dim", cls_name)
- validator.check_int(len(scaler), 1, Rel.EQ, "scaler_dim", cls_name)
validator.check_int(len(d_lj_b), 1, Rel.EQ, "d_LJ_B_dim", cls_name)

validator.check_int(uint_crd[0], n, Rel.EQ, "uint_crd_shape[0]", cls_name)

@ -2521,7 +2516,6 @@ class LJForceWithPMEDirectForce(PrimitiveWithInfer):
validator.check_int(nl_numbers[0], n, Rel.EQ, "nl_numbers_shape", cls_name)
validator.check_int(nl_serial[0], n, Rel.EQ, "nl_serial_shape[0]", cls_name)
validator.check_int(nl_serial[1], 800, Rel.EQ, "nl_serial_shape[1]", cls_name)
- validator.check_int(scaler[0], 3, Rel.EQ, "scaler_shape", cls_name)
validator.check_int(d_lj_b[0], q, Rel.EQ, "d_LJ_B_shape[0]", cls_name)
return uint_crd
@ -1278,7 +1278,6 @@ class LJForceWithVirialEnergy(PrimitiveWithInfer):
validator.check_int(len(scaler), 1, Rel.EQ, "scaler_dim", cls_name)
validator.check_int(len(nl_numbers), 1, Rel.EQ, "nl_numbers_dim", cls_name)
validator.check_int(len(nl_serial), 2, Rel.EQ, "nl_serial_dim", cls_name)
- validator.check_int(len(scaler), 1, Rel.EQ, "scaler_dim", cls_name)
validator.check_int(len(d_lj_b), 1, Rel.EQ, "d_LJ_B_dim", cls_name)

validator.check_int(uint_crd[0], n, Rel.EQ, "uint_crd_shape[0]", cls_name)

@ -1289,7 +1288,6 @@ class LJForceWithVirialEnergy(PrimitiveWithInfer):
validator.check_int(nl_numbers[0], n, Rel.EQ, "nl_numbers_shape", cls_name)
validator.check_int(nl_serial[0], n, Rel.EQ, "nl_serial_shape[0]", cls_name)
validator.check_int(nl_serial[1], 800, Rel.EQ, "nl_serial_shape[1]", cls_name)
- validator.check_int(scaler[0], 3, Rel.EQ, "scaler_shape", cls_name)
validator.check_int(d_lj_b[0], q, Rel.EQ, "d_LJ_B_shape[0]", cls_name)
return [n, 3], [n,], [n,]

@ -1381,7 +1379,6 @@ class LJForceWithPMEDirectForceUpdate(PrimitiveWithInfer):
validator.check_int(len(scaler), 1, Rel.EQ, "scaler_dim", cls_name)
validator.check_int(len(nl_numbers), 1, Rel.EQ, "nl_numbers_dim", cls_name)
validator.check_int(len(nl_serial), 2, Rel.EQ, "nl_serial_dim", cls_name)
- validator.check_int(len(scaler), 1, Rel.EQ, "scaler_dim", cls_name)
validator.check_int(len(d_lj_b), 1, Rel.EQ, "d_LJ_B_dim", cls_name)
validator.check_int(len(beta), 1, Rel.EQ, "beta_dim", cls_name)

@ -1393,7 +1390,6 @@ class LJForceWithPMEDirectForceUpdate(PrimitiveWithInfer):
validator.check_int(nl_numbers[0], n, Rel.EQ, "nl_numbers_shape", cls_name)
validator.check_int(nl_serial[0], n, Rel.EQ, "nl_serial_shape[0]", cls_name)
validator.check_int(nl_serial[1], 800, Rel.EQ, "nl_serial_shape[1]", cls_name)
- validator.check_int(scaler[0], 3, Rel.EQ, "scaler_shape", cls_name)
validator.check_int(d_lj_b[0], q, Rel.EQ, "d_LJ_B_shape[0]", cls_name)
validator.check_int(beta[0], 1, Rel.EQ, "beta_shape", cls_name)
return [n, 3]

@ -1691,7 +1687,6 @@ class LJForceWithVirialEnergyUpdate(PrimitiveWithInfer):
validator.check_int(len(scaler), 1, Rel.EQ, "scaler_dim", cls_name)
validator.check_int(len(nl_numbers), 1, Rel.EQ, "nl_numbers_dim", cls_name)
validator.check_int(len(nl_serial), 2, Rel.EQ, "nl_serial_dim", cls_name)
- validator.check_int(len(scaler), 1, Rel.EQ, "scaler_dim", cls_name)
validator.check_int(len(d_lj_b), 1, Rel.EQ, "d_LJ_B_dim", cls_name)
validator.check_int(len(beta), 1, Rel.EQ, "beta_dim", cls_name)
validator.check_int(uint_crd[0], n, Rel.EQ, "uint_crd_shape[0]", cls_name)

@ -1702,7 +1697,6 @@ class LJForceWithVirialEnergyUpdate(PrimitiveWithInfer):
validator.check_int(nl_numbers[0], n, Rel.EQ, "nl_numbers_shape", cls_name)
validator.check_int(nl_serial[0], n, Rel.EQ, "nl_serial_shape[0]", cls_name)
validator.check_int(nl_serial[1], 800, Rel.EQ, "nl_serial_shape[1]", cls_name)
- validator.check_int(scaler[0], 3, Rel.EQ, "scaler_shape", cls_name)
validator.check_int(d_lj_b[0], q, Rel.EQ, "d_LJ_B_shape[0]", cls_name)
validator.check_int(beta[0], 1, Rel.EQ, "beta_shape[0]", cls_name)
return [n, 3], [n,], [n,]
@ -90,7 +90,7 @@ def test_resizebilinear_error():
net = Net()
with pytest.raises(ValueError) as ex:
net()
- assert "size and scale both none" in str(ex.value)
+ assert "'size' and 'scale' both none" in str(ex.value)


def test_resizebilinear_error_1():

@ -106,4 +106,4 @@ def test_resizebilinear_error_1():
net = Net()
with pytest.raises(ValueError) as ex:
net()
- assert "size and scale both not none" in str(ex.value)
+ assert "'size' and 'scale' both not none" in str(ex.value)